mirror of
https://github.com/KarolS/millfork.git
synced 2024-06-12 06:29:34 +00:00
210 lines
7.1 KiB
Scala
210 lines
7.1 KiB
Scala
package millfork.node
|
|
|
|
import millfork.Tarjan
|
|
import millfork.error.{ConsoleLogger, Logger}
|
|
|
|
import scala.collection.mutable
|
|
|
|
/**
  * A vertex used when asking a [[CallGraph]] whether two variables may overlap.
  *
  * Every vertex exposes the name of the function it is associated with;
  * the concrete subtypes decide what that name is.
  *
  * @author Karol Stasiak
  */
sealed trait VariableVertex {

  /** Name of the function this vertex is associated with. */
  def function: String

}
|
|
|
|
/**
  * Vertex tagged with the name of a function.
  * Presumably stands for the parameters of that function -- TODO confirm against callers.
  *
  * @param function name of the associated function
  */
case class ParamVertex(function: String) extends VariableVertex
|
|
|
|
/**
  * Vertex tagged with the name of a function.
  * Presumably stands for the local variables of that function -- TODO confirm against callers.
  *
  * @param function name of the associated function
  */
case class LocalVertex(function: String) extends VariableVertex
|
|
|
|
/**
  * Vertex that is not tied to any particular function:
  * its function name is the empty string.
  */
case object GlobalVertex extends VariableVertex {

  /** Always the empty string. */
  override def function: String = ""

}
|
|
|
|
/**
  * Call graph of a program.
  *
  * The graph is built eagerly in the constructor: every declaration of the program
  * is fed to [[add]], and afterwards [[fillOut]] resolves aliases, computes the
  * transitive closure of the edges and derives the helper sets.
  *
  * @param program program whose declarations are scanned
  * @param log     logger used for trace output
  */
abstract class CallGraph(program: Program, log: Logger) {

  /**
    * Decides whether the two given variable vertices may overlap.
    * The policy is supplied by subclasses.
    */
  def canOverlap(a: VariableVertex, b: VariableVertex): Boolean

  // Names treated as roots of the graph; "main" is always one of them.
  protected val entryPoints: mutable.Set[String] = mutable.Set[String]()
  // (F,G) means function F calls function G
  protected val callEdges: mutable.Set[(String, String)] = mutable.Set[(String, String)]()
  // (F,G) means function G is called when building parameters for function F
  protected val paramEdges: mutable.Set[(String, String)] = mutable.Set[(String, String)]()
  // Entry points and functions reachable from more than one entry point.
  protected val multiaccessibleFunctions: mutable.Set[String] = mutable.Set[String]()
  // Names that are called or otherwise referenced somewhere in the program.
  protected val everCalledFunctions: mutable.Set[String] = mutable.Set[String]()
  // Every declared function name, plus a "<name>.trampoline" entry for each.
  protected val allFunctions: mutable.Set[String] = mutable.Set[String]()
  // Alias name -> alias target.
  protected val aliases: mutable.Map[String, String] = mutable.Map[String, String]()

  // Construction: the statements below must stay after the field definitions,
  // because they mutate those fields.
  entryPoints += "main"
  program.declarations.foreach(s => add(None, Nil, s))
  // A call through an alias also counts as a call to the alias target.
  everCalledFunctions ++= everCalledFunctions.flatMap(aliases.get)
  // Drop every collected name that is not a known function (or its trampoline).
  everCalledFunctions.retain(allFunctions)
  fillOut()

  /**
    * Records the information contributed by a single AST node and recurses into its children.
    *
    * @param currentFunction  function whose body is being scanned, if any
    * @param callingFunctions functions whose argument lists enclose the node, innermost first
    * @param node             node to inspect
    */
  def add(currentFunction: Option[String], callingFunctions: List[String], node: Node): Unit = {
    node match {
      case AliasDefinitionStatement(name, target, _) =>
        aliases += name -> target
      case f: FunctionDeclarationStatement =>
        allFunctions += f.name
        allFunctions += f.name + ".trampoline" // TODO: ???
        // Functions with a fixed address and interrupt handlers are treated as entry points.
        if (f.address.isDefined || f.interrupt || f.kernalInterrupt) entryPoints += f.name
        // The body is scanned with a fresh (empty) list of enclosing calls.
        f.statements.getOrElse(Nil).foreach(s => this.add(Some(f.name), Nil, s))
      case s: Statement =>
        s.getAllExpressions.foreach(e => add(currentFunction, callingFunctions, e))
      case g: FunctionCallExpression =>
        everCalledFunctions += g.functionName
        currentFunction.foreach(f => callEdges += f -> g.functionName)
        callingFunctions.foreach(f => paramEdges += f -> g.functionName)
        // Arguments are evaluated "inside" the call to g, so g joins the enclosing calls.
        g.expressions.foreach(expr => add(currentFunction, g.functionName :: callingFunctions, expr))
      case s: SumExpression =>
        s.expressions.foreach(expr => add(currentFunction, callingFunctions, expr._2))
      case x: VariableExpression =>
        val varName0 = x.name.stripSuffix(".hi").stripSuffix(".lo")
        if (varName0.endsWith(".pointer")) {
          // Referencing "<f>.pointer" marks "<f>.trampoline" as called and as an entry point.
          val trampolineName = varName0.stripSuffix(".pointer") + ".trampoline"
          everCalledFunctions += trampolineName
          entryPoints += trampolineName
        }
        val varName = varName0.stripSuffix(".addr").stripSuffix(".pointer")
        everCalledFunctions += varName
        entryPoints += varName // TODO: figure out how to interpret pointed-to functions
      case i: IndexedExpression =>
        val varName = i.name.stripSuffix(".hi").stripSuffix(".lo").stripSuffix(".addr").stripSuffix(".pointer")
        everCalledFunctions += varName
        entryPoints += varName
        add(currentFunction, callingFunctions, i.index)
      case i: DerefExpression =>
        add(currentFunction, callingFunctions, i.inner)
      case i: DerefDebuggingExpression =>
        add(currentFunction, callingFunctions, i.inner)
      case IndirectFieldExpression(root, firstIndices, fields) =>
        add(currentFunction, callingFunctions, root)
        firstIndices.foreach(i => add(currentFunction, callingFunctions, i))
        fields.foreach(f => f._3.foreach(i => add(currentFunction, callingFunctions, i)))
      case _ => ()
    }
  }

  /**
    * Completes the graph after all nodes have been added:
    * resolves aliases in edge targets, adds trampoline edges, closes both edge sets
    * transitively, derives the ever-called and multiaccessible sets, links operator-like
    * functions to internal `__` functions and finally traces the result if enabled.
    */
  def fillOut(): Unit = {

    // An edge pointing at an alias also points at the alias target.
    callEdges ++= callEdges.flatMap {
      case (a, b) => aliases.get(b).map(a -> _)
    }
    paramEdges ++= paramEdges.flatMap {
      case (a, b) => aliases.get(b).map(a -> _)
    }

    // Every used trampoline calls the function it belongs to.
    callEdges ++= everCalledFunctions.filter(_.endsWith(".trampoline")).map(t => t -> t.stripSuffix(".trampoline"))

    // Transitive closure of callEdges: repeat until no new edge appears.
    var changed = true
    while (changed) {
      changed = false
      val toAdd = for {
        (a, b) <- callEdges
        (c, d) <- callEdges
        if b == c
        if !callEdges.contains(a -> d)
      } yield (a, d)
      if (toAdd.nonEmpty) {
        callEdges ++= toAdd
        changed = true
      }
    }

    // Extend paramEdges along callEdges: if G is called while building parameters
    // for F and G calls H, then H is called while building parameters for F as well.
    changed = true
    while (changed) {
      changed = false
      val toAdd = for {
        (a, b) <- paramEdges
        (c, d) <- callEdges
        if b == c
        if !paramEdges.contains(a -> d)
      } yield (a, d)
      if (toAdd.nonEmpty) {
        paramEdges ++= toAdd
        changed = true
      }
    }

    multiaccessibleFunctions ++= entryPoints
    everCalledFunctions ++= entryPoints

    // Edges that start at an entry point. This must be a Seq, not a Set:
    // mapping a Set of edges to their targets would collapse duplicate targets,
    // and the per-target count below would then never exceed 1.
    val edgesFromEntryPoints = callEdges.toSeq.filter(e => entryPoints.contains(e._1))
    edgesFromEntryPoints.foreach(e => everCalledFunctions += e._2)
    // A function reachable from more than one entry point is multiaccessible.
    multiaccessibleFunctions ++= edgesFromEntryPoints
      .map(_._2)
      .groupBy(identity)
      .filter(p => p._2.size > 1)
      .keys

    // Every called function whose name starts with a character that is neither
    // an underscore nor a letter/digit is linked to every function whose name starts with "__".
    for {
      operator <- everCalledFunctions
      if operator.headOption.exists(c => c != '_' && !c.isLetterOrDigit)
      internal <- allFunctions
      if internal.startsWith("__")
    } {
      callEdges += operator -> internal
    }

    if (log.traceEnabled) {
      log.trace("Call edges:")
      callEdges.toList.sorted.foreach(s => log.trace(s.toString))

      log.trace("Param edges:")
      paramEdges.toList.sorted.foreach(s => log.trace(s.toString))

      log.trace("Entry points:")
      entryPoints.toList.sorted.foreach(log.trace(_))

      log.trace("Multiaccessible functions:")
      multiaccessibleFunctions.toList.sorted.foreach(log.trace(_))

      log.trace("Ever called functions:")
      everCalledFunctions.toList.sorted.foreach(log.trace(_))
    }
  }

  /** Tells whether the given name ended up in the set of ever-called functions. */
  def isEverCalled(function: String): Boolean = {
    everCalledFunctions(function)
  }

  /** Order of all known functions as computed by `Tarjan.sort` over the call edges. */
  def recommendedCompilationOrder: List[String] = Tarjan.sort(allFunctions, callEdges)
}
|
|
|
|
/**
  * Call graph with the most conservative overlap policy:
  * no two variable vertices are ever allowed to overlap.
  */
class RestrictiveCallGraph(program: Program, log: Logger) extends CallGraph(program, log) {

  /** Always answers `false`, regardless of the vertices. */
  override def canOverlap(a: VariableVertex, b: VariableVertex): Boolean = false

}
|
|
|
|
/**
  * Call graph whose overlap policy is derived from the collected call and parameter edges.
  */
class StandardCallGraph(program: Program, log: Logger) extends CallGraph(program, log) {

  /**
    * Two vertices may overlap only if none of the exclusion rules below applies.
    * A trace message is emitted whenever the answer is `true`.
    */
  def canOverlap(a: VariableVertex, b: VariableVertex): Boolean = {
    val fa = a.function
    val fb = b.function

    // For a parameter vertex: is `other` called while building parameters of its function?
    def calledForParamsOf(v: VariableVertex, other: String): Boolean = v match {
      case ParamVertex(owner) => paramEdges(owner -> other)
      case _ => false
    }

    // Rules are checked left to right and stop at the first one that matches.
    val mustStayApart =
      fa == fb || // same function
        a == GlobalVertex || b == GlobalVertex || // globals never overlap
        multiaccessibleFunctions(fa) || multiaccessibleFunctions(fb) ||
        fa + ".trampoline" == fb || // a function and its own trampoline
        fa == fb + ".trampoline" ||
        callEdges(fa -> fb) || callEdges(fb -> fa) || // one function calls the other
        calledForParamsOf(a, fb) ||
        calledForParamsOf(b, fa)

    if (mustStayApart) {
      false
    } else {
      log.trace(s"$a and $b can overlap")
      true
    }
  }

}
|