package millfork.node

import millfork.error.ErrorReporting

import scala.collection.mutable

/**
  * The owner of a variable, as seen by the memory-overlap analysis of a [[CallGraph]].
  *
  * @author Karol Stasiak
  */
sealed trait VariableVertex {
  /** Name of the function owning the variable; empty for [[GlobalVertex]]. */
  def function: String
}

/** A parameter of `function`. */
case class ParamVertex(function: String) extends VariableVertex

/** A local variable of `function`. */
case class LocalVertex(function: String) extends VariableVertex

/** Any global variable. Both call graphs in this file report that it cannot overlap anything. */
case object GlobalVertex extends VariableVertex {
  override def function = ""
}

/** Decides whether two variables may be placed at the same memory location. */
trait CallGraph {
  def canOverlap(a: VariableVertex, b: VariableVertex): Boolean
}

/** A call graph that never allows any two variables to overlap. */
object RestrictiveCallGraph extends CallGraph {
  def canOverlap(a: VariableVertex, b: VariableVertex): Boolean = false
}

/**
  * A call graph built from the declarations of `program`, used to decide which variables
  * of different functions may share memory.
  *
  * The constructor only records direct edges. The transitive closure, the multiaccessible
  * functions and the entry-point contributions to the ever-called set are computed only
  * by [[fillOut]].
  */
class StandardCallGraph(program: Program) extends CallGraph {

  // roots of the analysis: `main`, every function with an explicit address, every interrupt function
  private val entryPoints = mutable.Set[String]()
  // (F,G) means function F calls function G
  private val callEdges = mutable.Set[(String, String)]()
  // (F,G) means function G is called when building parameters for function F
  private val paramEdges = mutable.Set[(String, String)]()
  // entry points plus functions called from more than one entry point; their variables never overlap
  private val multiaccessibleFunctions = mutable.Set[String]()
  // functions that are called, or whose name is referenced as a variable
  private val everCalledFunctions = mutable.Set[String]()
  // every function declared in the program
  private val allFunctions = mutable.Set[String]()

  // NOTE: these statements must stay below the field declarations above, which they populate
  entryPoints += "main"
  program.declarations.foreach(s => add(None, Nil, s))
  // drop collected names that are not declared functions (e.g. plain variables referenced by name)
  everCalledFunctions.retain(allFunctions)

  /**
    * Records the edges contributed by `node` and everything nested inside it.
    *
    * @param currentFunction  the function whose body contains `node`, if any
    * @param callingFunctions functions whose argument lists enclose `node`, innermost first
    * @param node             the declaration, statement or expression to scan
    */
  def add(currentFunction: Option[String], callingFunctions: List[String], node: Node): Unit = {
    node match {
      case f: FunctionDeclarationStatement =>
        allFunctions += f.name
        if (f.address.isDefined || f.interrupt) entryPoints += f.name
        f.statements.getOrElse(Nil).foreach(s => this.add(Some(f.name), Nil, s))
      case s: Statement =>
        s.getAllExpressions.foreach(e => add(currentFunction, callingFunctions, e))
      case g: FunctionCallExpression =>
        everCalledFunctions += g.functionName
        currentFunction.foreach(f => callEdges += f -> g.functionName)
        callingFunctions.foreach(f => paramEdges += f -> g.functionName)
        g.expressions.foreach(expr => add(currentFunction, g.functionName :: callingFunctions, expr))
      case x: VariableExpression =>
        // a name referenced as a variable (optionally with .hi/.lo/.addr) is treated as a call
        val varName = x.name.stripSuffix(".hi").stripSuffix(".lo").stripSuffix(".addr")
        everCalledFunctions += varName
      case _ => ()
    }
  }

  /**
    * Completes the graph and dumps it at trace level.
    *
    * This closes the call edges transitively and propagates parameter edges through calls.
    * It then marks the entry points, and functions called from several entry points, as
    * multiaccessible. Finally, it adds the entry points and their callees to the ever-called set.
    */
  def fillOut(): Unit = {
    extendUntilStable(callEdges, callEdges)
    extendUntilStable(paramEdges, callEdges)

    multiaccessibleFunctions ++= entryPoints
    everCalledFunctions ++= entryPoints
    // One element per (entry point, callee) edge. This must be a List, not a Set: projecting a Set
    // onto the callee would merge equal callees, so no group could ever have more than one element.
    val calledFromEntryPoints: List[String] =
      callEdges.toList.collect { case (caller, callee) if entryPoints(caller) => callee }
    everCalledFunctions ++= calledFromEntryPoints
    multiaccessibleFunctions ++= calledFromEntryPoints
      .groupBy(identity)
      .collect { case (callee, occurrences) if occurrences.size > 1 => callee }

    ErrorReporting.trace("Call edges:")
    callEdges.toList.sorted.foreach(s => ErrorReporting.trace(s.toString))

    ErrorReporting.trace("Param edges:")
    paramEdges.toList.sorted.foreach(s => ErrorReporting.trace(s.toString))

    ErrorReporting.trace("Entry points:")
    entryPoints.toList.sorted.foreach(ErrorReporting.trace(_))

    ErrorReporting.trace("Multiaccessible functions:")
    multiaccessibleFunctions.toList.sorted.foreach(ErrorReporting.trace(_))

    ErrorReporting.trace("Ever called functions:")
    everCalledFunctions.toList.sorted.foreach(ErrorReporting.trace(_))
  }

  /**
    * Repeatedly adds `(a, d)` to `edges` whenever `(a, b)` is in `edges` and `(b, d)` is in `via`.
    * It stops when a pass adds nothing new. With `edges eq via` this computes the transitive closure.
    */
  private def extendUntilStable(edges: mutable.Set[(String, String)], via: mutable.Set[(String, String)]): Unit = {
    var changed = true
    while (changed) {
      // computed completely before `edges` is modified, so iteration never sees concurrent updates
      val toAdd = for {
        (a, b) <- edges
        (c, d) <- via
        if b == c
        if !edges.contains(a -> d)
      } yield (a, d)
      edges ++= toAdd
      changed = toAdd.nonEmpty
    }
  }

  /** Whether `function` was called or referenced; see [[fillOut]] for the entry-point contribution. */
  def isEverCalled(function: String): Boolean = everCalledFunctions(function)

  /**
    * Two variables may overlap only when all of the following hold:
    *  - they belong to different functions, and neither is global;
    *  - neither function is multiaccessible;
    *  - neither function calls the other;
    *  - neither variable is a parameter of a function whose arguments are built by calling the other's function.
    *
    * NOTE(review): a function reached only from `main` and one reached only from a different entry point
    * (e.g. an interrupt function) share no edge here, so their variables may overlap — confirm that is safe.
    */
  def canOverlap(a: VariableVertex, b: VariableVertex): Boolean = {
    val overlap =
      a.function != b.function &&
        a != GlobalVertex && b != GlobalVertex &&
        !multiaccessibleFunctions(a.function) && !multiaccessibleFunctions(b.function) &&
        !callEdges(a.function -> b.function) && !callEdges(b.function -> a.function) &&
        !isParamClobberedBy(a, b) && !isParamClobberedBy(b, a)
    if (overlap) ErrorReporting.trace(s"$a and $b can overlap")
    overlap
  }

  /** True when `param` is a parameter of F and `other`'s function is called while building F's arguments. */
  private def isParamClobberedBy(param: VariableVertex, other: VariableVertex): Boolean = param match {
    case ParamVertex(owner) => paramEdges(owner -> other.function)
    case _ => false
  }
}