diff --git a/src/main/scala/millfork/Tarjan.scala b/src/main/scala/millfork/Tarjan.scala new file mode 100644 index 00000000..b5c85a8f --- /dev/null +++ b/src/main/scala/millfork/Tarjan.scala @@ -0,0 +1,47 @@ +package millfork + +import scala.collection.{immutable, mutable} + +/** + * @author Karol Stasiak + */ +object Tarjan { + def sort[T](vertices: Iterable[T], edges: Iterable[(T,T)]): List[T] = { + var index = 0 + val s = mutable.Stack[T]() + val indices = mutable.Map[T, Int]() + val lowlinks = mutable.Map[T, Int]() + val onStack = mutable.Set[T]() + var result = List[T]() + + def strongConnect(v: T): Unit = { + indices(v) = index + lowlinks(v) = index + index += 1 + s.push(v) + onStack += v + edges.filter(_._1 == v).foreach { + case (_, w) => + if (!indices.contains(w)) { + strongConnect(w) + lowlinks(v) = lowlinks(v) min lowlinks(w) + } else if (onStack(w)) { + lowlinks(v) = lowlinks(v) min indices(w) + } + } + if (lowlinks(v) == indices(v)) { + var w: T = v + do { + w = s.pop() + onStack -= w + result ::= w + } while (w != v) + } + } + + vertices.foreach{ v => + if (!indices.contains(v)) strongConnect(v) + } + result.reverse + } +} diff --git a/src/main/scala/millfork/node/CallGraph.scala b/src/main/scala/millfork/node/CallGraph.scala index 9fc994bd..c70b67aa 100644 --- a/src/main/scala/millfork/node/CallGraph.scala +++ b/src/main/scala/millfork/node/CallGraph.scala @@ -1,5 +1,6 @@ package millfork.node +import millfork.Tarjan import millfork.error.ErrorReporting import scala.collection.mutable @@ -20,25 +21,18 @@ case object GlobalVertex extends VariableVertex { override def function = "" } -trait CallGraph { +abstract class CallGraph(program: Program) { + def canOverlap(a: VariableVertex, b: VariableVertex): Boolean -} -object RestrictiveCallGraph extends CallGraph { - - def canOverlap(a: VariableVertex, b: VariableVertex): Boolean = false -} - -class StandardCallGraph(program: Program) extends CallGraph { - - private val entryPoints = mutable.Set[String]() + protected val entryPoints = mutable.Set[String]() // (F,G) means function F calls function G - private val callEdges = mutable.Set[(String, String)]() + protected val callEdges = mutable.Set[(String, String)]() // (F,G) means function G is called when building parameters for function F - private val paramEdges = mutable.Set[(String, String)]() - private val multiaccessibleFunctions = mutable.Set[String]() - private val everCalledFunctions = mutable.Set[String]() - private val allFunctions = mutable.Set[String]() + protected val paramEdges = mutable.Set[(String, String)]() + protected val multiaccessibleFunctions = mutable.Set[String]() + protected val everCalledFunctions = mutable.Set[String]() + protected val allFunctions = mutable.Set[String]() entryPoints += "main" program.declarations.foreach(s => add(None, Nil, s)) @@ -120,6 +114,16 @@ class StandardCallGraph(program: Program) extends CallGraph { everCalledFunctions(function) } + def recommendedCompilationOrder: List[String] = Tarjan.sort(allFunctions, callEdges) +} + +class RestrictiveCallGraph(program: Program) extends CallGraph(program) { + + def canOverlap(a: VariableVertex, b: VariableVertex): Boolean = false +} + +class StandardCallGraph(program: Program) extends CallGraph(program) { + def canOverlap(a: VariableVertex, b: VariableVertex): Boolean = { if (a.function == b.function) { return false diff --git a/src/main/scala/millfork/output/Assembler.scala b/src/main/scala/millfork/output/Assembler.scala index 42183256..31d3eb26 100644 --- a/src/main/scala/millfork/output/Assembler.scala +++ b/src/main/scala/millfork/output/Assembler.scala @@ -6,7 +6,7 @@ import millfork.compiler.{CompilationContext, MlCompiler} import millfork.env._ import millfork.error.ErrorReporting import millfork.node.CallGraph -import millfork.{CompilationFlag, CompilationOptions} +import millfork.{CompilationFlag, CompilationOptions, Tarjan} import scala.collection.mutable @@ -159,21 +159,28 @@ class Assembler(private val rootEnv: Environment) { case f: NormalFunction if f.address.isDefined => var index = f.address.get.asInstanceOf[NumericConstant].value.toInt labelMap(f.name) = index - compileFunction(f, index, optimizations, assembly, options) + val code = compileFunction(f, optimizations, options) + index = outputFunction(code, index, assembly, options) case _ => } + val compiledFunctions = mutable.Map[String, List[AssemblyLine]]() + callGraph.recommendedCompilationOrder.foreach{ f => + env.maybeGet[NormalFunction](f).foreach( function => + compiledFunctions(f) = compileFunction(function, optimizations, options) + ) + } var index = platform.org env.allPreallocatables.foreach { case f: NormalFunction if f.address.isEmpty && f.name == "main" => labelMap(f.name) = index - index = compileFunction(f, index, optimizations, assembly, options) + index = outputFunction(compiledFunctions(f.name), index, assembly, options) case _ => } env.allPreallocatables.foreach { case f: NormalFunction if f.address.isEmpty && f.name != "main" => labelMap(f.name) = index - index = compileFunction(f, index, optimizations, assembly, options) + index = outputFunction(compiledFunctions(f.name), index, assembly, options) case _ => } env.allPreallocatables.foreach { @@ -242,16 +249,20 @@ class Assembler(private val rootEnv: Environment) { AssemblerOutput(platform.outputPackager.packageOutput(mem, 0), assembly.toArray, labelMap.toList) } - private def compileFunction(f: NormalFunction, startFrom: Int, optimizations: Seq[AssemblyOptimization], assOut: mutable.ArrayBuffer[String], options: CompilationOptions): Int = { + private def compileFunction(f: NormalFunction, optimizations: Seq[AssemblyOptimization], options: CompilationOptions) :List[AssemblyLine] = { ErrorReporting.debug("Compiling: " + f.name, f.position) - var index = startFrom - assOut.append("* = $" + startFrom.toHexString) val unoptimized = MlCompiler.compile(CompilationContext(env = f.environment, function = f, extraStackOffset = 0, options = options)).linearize unoptimizedCodeSize += unoptimized.map(_.sizeInBytes).sum val code = optimizations.foldLeft(unoptimized) { (c, opt) => opt.optimize(f, c, options) } optimizedCodeSize += code.map(_.sizeInBytes).sum + code + } + + private def outputFunction(code:List[AssemblyLine], startFrom: Int, assOut: mutable.ArrayBuffer[String], options: CompilationOptions): Int = { + var index = startFrom + assOut.append("* = $" + startFrom.toHexString) import millfork.assembly.AddrMode._ import millfork.assembly.Opcode._ for (instr <- code) { diff --git a/src/test/scala/millfork/test/auxilary/TarjanTest.scala b/src/test/scala/millfork/test/auxilary/TarjanTest.scala new file mode 100644 index 00000000..7d71cb64 --- /dev/null +++ b/src/test/scala/millfork/test/auxilary/TarjanTest.scala @@ -0,0 +1,30 @@ +package millfork.test.auxilary + +import millfork.Tarjan + +/** + * @author Karol Stasiak + */ +object TarjanTest { + def main(s: Array[String]): Unit = { + println(Tarjan.sort( + List(1, 2, 3, 4, 5, 6, 7, 8), + List( + 1 -> 2, + 2 -> 3, + 3 -> 1, + 4 -> 2, + 4 -> 3, + 4 -> 5, + 5 -> 4, + 5 -> 6, + 6 -> 3, + 6 -> 7, + 7 -> 6, + 8 -> 8, + 8 -> 5, + 8 -> 7 + ) + )) + } +}