mirror of
https://github.com/KarolS/millfork.git
synced 2024-12-22 16:31:02 +00:00
Compile functions in topological order, in preparation for inlining optimization
This commit is contained in:
parent
86ef4fcaf4
commit
4d8de94c8a
47
src/main/scala/millfork/Tarjan.scala
Normal file
47
src/main/scala/millfork/Tarjan.scala
Normal file
@ -0,0 +1,47 @@
|
||||
package millfork
|
||||
|
||||
import scala.collection.{immutable, mutable}
|
||||
|
||||
/**
|
||||
* @author Karol Stasiak
|
||||
*/
|
||||
object Tarjan {
|
||||
def sort[T](vertices: Iterable[T], edges: Iterable[(T,T)]): List[T] = {
|
||||
var index = 0
|
||||
val s = mutable.Stack[T]()
|
||||
val indices = mutable.Map[T, Int]()
|
||||
val lowlinks = mutable.Map[T, Int]()
|
||||
val onStack = mutable.Set[T]()
|
||||
var result = List[T]()
|
||||
|
||||
def strongConnect(v: T): Unit = {
|
||||
indices(v) = index
|
||||
lowlinks(v) = index
|
||||
index += 1
|
||||
s.push(v)
|
||||
onStack += v
|
||||
edges.filter(_._1 == v).foreach {
|
||||
case (_, w) =>
|
||||
if (!indices.contains(w)) {
|
||||
strongConnect(w)
|
||||
lowlinks(v) = lowlinks(v) min lowlinks(w)
|
||||
} else if (onStack(w)) {
|
||||
lowlinks(v) = lowlinks(v) min indices(w)
|
||||
}
|
||||
}
|
||||
if (lowlinks(v) == indices(v)) {
|
||||
var w: T = v
|
||||
do {
|
||||
w = s.pop()
|
||||
onStack -= w
|
||||
result ::= w
|
||||
} while (w != v)
|
||||
}
|
||||
}
|
||||
|
||||
vertices.foreach{ v =>
|
||||
if (!indices.contains(v)) strongConnect(v)
|
||||
}
|
||||
result.reverse
|
||||
}
|
||||
}
|
@ -1,5 +1,6 @@
|
||||
package millfork.node
|
||||
|
||||
import millfork.Tarjan
|
||||
import millfork.error.ErrorReporting
|
||||
|
||||
import scala.collection.mutable
|
||||
@ -20,25 +21,18 @@ case object GlobalVertex extends VariableVertex {
|
||||
override def function = ""
|
||||
}
|
||||
|
||||
trait CallGraph {
|
||||
abstract class CallGraph(program: Program) {
|
||||
|
||||
def canOverlap(a: VariableVertex, b: VariableVertex): Boolean
|
||||
}
|
||||
|
||||
object RestrictiveCallGraph extends CallGraph {
|
||||
|
||||
def canOverlap(a: VariableVertex, b: VariableVertex): Boolean = false
|
||||
}
|
||||
|
||||
class StandardCallGraph(program: Program) extends CallGraph {
|
||||
|
||||
private val entryPoints = mutable.Set[String]()
|
||||
protected val entryPoints = mutable.Set[String]()
|
||||
// (F,G) means function F calls function G
|
||||
private val callEdges = mutable.Set[(String, String)]()
|
||||
protected val callEdges = mutable.Set[(String, String)]()
|
||||
// (F,G) means function G is called when building parameters for function F
|
||||
private val paramEdges = mutable.Set[(String, String)]()
|
||||
private val multiaccessibleFunctions = mutable.Set[String]()
|
||||
private val everCalledFunctions = mutable.Set[String]()
|
||||
private val allFunctions = mutable.Set[String]()
|
||||
protected val paramEdges = mutable.Set[(String, String)]()
|
||||
protected val multiaccessibleFunctions = mutable.Set[String]()
|
||||
protected val everCalledFunctions = mutable.Set[String]()
|
||||
protected val allFunctions = mutable.Set[String]()
|
||||
|
||||
entryPoints += "main"
|
||||
program.declarations.foreach(s => add(None, Nil, s))
|
||||
@ -120,6 +114,16 @@ class StandardCallGraph(program: Program) extends CallGraph {
|
||||
everCalledFunctions(function)
|
||||
}
|
||||
|
||||
def recommendedCompilationOrder: List[String] = Tarjan.sort(allFunctions, callEdges)
|
||||
}
|
||||
|
||||
class RestrictiveCallGraph(program: Program) extends CallGraph(program) {
|
||||
|
||||
def canOverlap(a: VariableVertex, b: VariableVertex): Boolean = false
|
||||
}
|
||||
|
||||
class StandardCallGraph(program: Program) extends CallGraph(program) {
|
||||
|
||||
def canOverlap(a: VariableVertex, b: VariableVertex): Boolean = {
|
||||
if (a.function == b.function) {
|
||||
return false
|
||||
|
@ -6,7 +6,7 @@ import millfork.compiler.{CompilationContext, MlCompiler}
|
||||
import millfork.env._
|
||||
import millfork.error.ErrorReporting
|
||||
import millfork.node.CallGraph
|
||||
import millfork.{CompilationFlag, CompilationOptions}
|
||||
import millfork.{CompilationFlag, CompilationOptions, Tarjan}
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
@ -159,21 +159,28 @@ class Assembler(private val rootEnv: Environment) {
|
||||
case f: NormalFunction if f.address.isDefined =>
|
||||
var index = f.address.get.asInstanceOf[NumericConstant].value.toInt
|
||||
labelMap(f.name) = index
|
||||
compileFunction(f, index, optimizations, assembly, options)
|
||||
val code = compileFunction(f, optimizations, options)
|
||||
index = outputFunction(code, index, assembly, options)
|
||||
case _ =>
|
||||
}
|
||||
|
||||
val compiledFunctions = mutable.Map[String, List[AssemblyLine]]()
|
||||
callGraph.recommendedCompilationOrder.foreach{ f =>
|
||||
env.maybeGet[NormalFunction](f).foreach( function =>
|
||||
compiledFunctions(f) = compileFunction(function, optimizations, options)
|
||||
)
|
||||
}
|
||||
var index = platform.org
|
||||
env.allPreallocatables.foreach {
|
||||
case f: NormalFunction if f.address.isEmpty && f.name == "main" =>
|
||||
labelMap(f.name) = index
|
||||
index = compileFunction(f, index, optimizations, assembly, options)
|
||||
index = outputFunction(compiledFunctions(f.name), index, assembly, options)
|
||||
case _ =>
|
||||
}
|
||||
env.allPreallocatables.foreach {
|
||||
case f: NormalFunction if f.address.isEmpty && f.name != "main" =>
|
||||
labelMap(f.name) = index
|
||||
index = compileFunction(f, index, optimizations, assembly, options)
|
||||
index = outputFunction(compiledFunctions(f.name), index, assembly, options)
|
||||
case _ =>
|
||||
}
|
||||
env.allPreallocatables.foreach {
|
||||
@ -242,16 +249,20 @@ class Assembler(private val rootEnv: Environment) {
|
||||
AssemblerOutput(platform.outputPackager.packageOutput(mem, 0), assembly.toArray, labelMap.toList)
|
||||
}
|
||||
|
||||
private def compileFunction(f: NormalFunction, startFrom: Int, optimizations: Seq[AssemblyOptimization], assOut: mutable.ArrayBuffer[String], options: CompilationOptions): Int = {
|
||||
private def compileFunction(f: NormalFunction, optimizations: Seq[AssemblyOptimization], options: CompilationOptions) :List[AssemblyLine] = {
|
||||
ErrorReporting.debug("Compiling: " + f.name, f.position)
|
||||
var index = startFrom
|
||||
assOut.append("* = $" + startFrom.toHexString)
|
||||
val unoptimized = MlCompiler.compile(CompilationContext(env = f.environment, function = f, extraStackOffset = 0, options = options)).linearize
|
||||
unoptimizedCodeSize += unoptimized.map(_.sizeInBytes).sum
|
||||
val code = optimizations.foldLeft(unoptimized) { (c, opt) =>
|
||||
opt.optimize(f, c, options)
|
||||
}
|
||||
optimizedCodeSize += code.map(_.sizeInBytes).sum
|
||||
code
|
||||
}
|
||||
|
||||
private def outputFunction(code:List[AssemblyLine], startFrom: Int, assOut: mutable.ArrayBuffer[String], options: CompilationOptions): Int = {
|
||||
var index = startFrom
|
||||
assOut.append("* = $" + startFrom.toHexString)
|
||||
import millfork.assembly.AddrMode._
|
||||
import millfork.assembly.Opcode._
|
||||
for (instr <- code) {
|
||||
|
30
src/test/scala/millfork/test/auxilary/TarjanTest.scala
Normal file
30
src/test/scala/millfork/test/auxilary/TarjanTest.scala
Normal file
@ -0,0 +1,30 @@
|
||||
package millfork.test.auxilary
|
||||
|
||||
import millfork.Tarjan
|
||||
|
||||
/**
|
||||
* @author Karol Stasiak
|
||||
*/
|
||||
object TarjanTest {
|
||||
def main(s: Array[String]): Unit = {
|
||||
println(Tarjan.sort(
|
||||
List(1, 2, 3, 4, 5, 6, 7, 8),
|
||||
List(
|
||||
1 -> 2,
|
||||
2 -> 3,
|
||||
3 -> 1,
|
||||
4 -> 2,
|
||||
4 -> 3,
|
||||
4 -> 5,
|
||||
5 -> 4,
|
||||
5 -> 6,
|
||||
6 -> 3,
|
||||
6 -> 7,
|
||||
7 -> 6,
|
||||
8 -> 8,
|
||||
8 -> 5,
|
||||
8 -> 7
|
||||
)
|
||||
))
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user