1
0
mirror of https://github.com/KarolS/millfork.git synced 2024-12-22 16:31:02 +00:00

Compile functions in topological order, in preparation for inlining optimization

This commit is contained in:
Karol Stasiak 2017-12-19 22:09:57 +01:00
parent 86ef4fcaf4
commit 4d8de94c8a
4 changed files with 114 additions and 22 deletions

View File

@ -0,0 +1,47 @@
package millfork
import scala.collection.{immutable, mutable}
/**
* @author Karol Stasiak
*/
object Tarjan {
def sort[T](vertices: Iterable[T], edges: Iterable[(T,T)]): List[T] = {
var index = 0
val s = mutable.Stack[T]()
val indices = mutable.Map[T, Int]()
val lowlinks = mutable.Map[T, Int]()
val onStack = mutable.Set[T]()
var result = List[T]()
def strongConnect(v: T): Unit = {
indices(v) = index
lowlinks(v) = index
index += 1
s.push(v)
onStack += v
edges.filter(_._1 == v).foreach {
case (_, w) =>
if (!indices.contains(w)) {
strongConnect(w)
lowlinks(v) = lowlinks(v) min lowlinks(w)
} else if (onStack(w)) {
lowlinks(v) = lowlinks(v) min indices(w)
}
}
if (lowlinks(v) == indices(v)) {
var w: T = v
do {
w = s.pop()
onStack -= w
result ::= w
} while (w != v)
}
}
vertices.foreach{ v =>
if (!indices.contains(v)) strongConnect(v)
}
result.reverse
}
}

View File

@ -1,5 +1,6 @@
package millfork.node
import millfork.Tarjan
import millfork.error.ErrorReporting
import scala.collection.mutable
@ -20,25 +21,18 @@ case object GlobalVertex extends VariableVertex {
override def function = ""
}
trait CallGraph {
abstract class CallGraph(program: Program) {
def canOverlap(a: VariableVertex, b: VariableVertex): Boolean
}
object RestrictiveCallGraph extends CallGraph {
def canOverlap(a: VariableVertex, b: VariableVertex): Boolean = false
}
class StandardCallGraph(program: Program) extends CallGraph {
private val entryPoints = mutable.Set[String]()
protected val entryPoints = mutable.Set[String]()
// (F,G) means function F calls function G
private val callEdges = mutable.Set[(String, String)]()
protected val callEdges = mutable.Set[(String, String)]()
// (F,G) means function G is called when building parameters for function F
private val paramEdges = mutable.Set[(String, String)]()
private val multiaccessibleFunctions = mutable.Set[String]()
private val everCalledFunctions = mutable.Set[String]()
private val allFunctions = mutable.Set[String]()
protected val paramEdges = mutable.Set[(String, String)]()
protected val multiaccessibleFunctions = mutable.Set[String]()
protected val everCalledFunctions = mutable.Set[String]()
protected val allFunctions = mutable.Set[String]()
entryPoints += "main"
program.declarations.foreach(s => add(None, Nil, s))
@ -120,6 +114,16 @@ class StandardCallGraph(program: Program) extends CallGraph {
everCalledFunctions(function)
}
def recommendedCompilationOrder: List[String] = Tarjan.sort(allFunctions, callEdges)
}
class RestrictiveCallGraph(program: Program) extends CallGraph(program) {
def canOverlap(a: VariableVertex, b: VariableVertex): Boolean = false
}
class StandardCallGraph(program: Program) extends CallGraph(program) {
def canOverlap(a: VariableVertex, b: VariableVertex): Boolean = {
if (a.function == b.function) {
return false

View File

@ -6,7 +6,7 @@ import millfork.compiler.{CompilationContext, MlCompiler}
import millfork.env._
import millfork.error.ErrorReporting
import millfork.node.CallGraph
import millfork.{CompilationFlag, CompilationOptions}
import millfork.{CompilationFlag, CompilationOptions, Tarjan}
import scala.collection.mutable
@ -159,21 +159,28 @@ class Assembler(private val rootEnv: Environment) {
case f: NormalFunction if f.address.isDefined =>
var index = f.address.get.asInstanceOf[NumericConstant].value.toInt
labelMap(f.name) = index
compileFunction(f, index, optimizations, assembly, options)
val code = compileFunction(f, optimizations, options)
index = outputFunction(code, index, assembly, options)
case _ =>
}
val compiledFunctions = mutable.Map[String, List[AssemblyLine]]()
callGraph.recommendedCompilationOrder.foreach{ f =>
env.maybeGet[NormalFunction](f).foreach( function =>
compiledFunctions(f) = compileFunction(function, optimizations, options)
)
}
var index = platform.org
env.allPreallocatables.foreach {
case f: NormalFunction if f.address.isEmpty && f.name == "main" =>
labelMap(f.name) = index
index = compileFunction(f, index, optimizations, assembly, options)
index = outputFunction(compiledFunctions(f.name), index, assembly, options)
case _ =>
}
env.allPreallocatables.foreach {
case f: NormalFunction if f.address.isEmpty && f.name != "main" =>
labelMap(f.name) = index
index = compileFunction(f, index, optimizations, assembly, options)
index = outputFunction(compiledFunctions(f.name), index, assembly, options)
case _ =>
}
env.allPreallocatables.foreach {
@ -242,16 +249,20 @@ class Assembler(private val rootEnv: Environment) {
AssemblerOutput(platform.outputPackager.packageOutput(mem, 0), assembly.toArray, labelMap.toList)
}
private def compileFunction(f: NormalFunction, startFrom: Int, optimizations: Seq[AssemblyOptimization], assOut: mutable.ArrayBuffer[String], options: CompilationOptions): Int = {
private def compileFunction(f: NormalFunction, optimizations: Seq[AssemblyOptimization], options: CompilationOptions) :List[AssemblyLine] = {
ErrorReporting.debug("Compiling: " + f.name, f.position)
var index = startFrom
assOut.append("* = $" + startFrom.toHexString)
val unoptimized = MlCompiler.compile(CompilationContext(env = f.environment, function = f, extraStackOffset = 0, options = options)).linearize
unoptimizedCodeSize += unoptimized.map(_.sizeInBytes).sum
val code = optimizations.foldLeft(unoptimized) { (c, opt) =>
opt.optimize(f, c, options)
}
optimizedCodeSize += code.map(_.sizeInBytes).sum
code
}
private def outputFunction(code:List[AssemblyLine], startFrom: Int, assOut: mutable.ArrayBuffer[String], options: CompilationOptions): Int = {
var index = startFrom
assOut.append("* = $" + startFrom.toHexString)
import millfork.assembly.AddrMode._
import millfork.assembly.Opcode._
for (instr <- code) {

View File

@ -0,0 +1,30 @@
package millfork.test.auxilary
import millfork.Tarjan
/**
* @author Karol Stasiak
*/
object TarjanTest {
def main(s: Array[String]): Unit = {
println(Tarjan.sort(
List(1, 2, 3, 4, 5, 6, 7, 8),
List(
1 -> 2,
2 -> 3,
3 -> 1,
4 -> 2,
4 -> 3,
4 -> 5,
5 -> 4,
5 -> 6,
6 -> 3,
6 -> 7,
7 -> 6,
8 -> 8,
8 -> 5,
8 -> 7
)
))
}
}