From f4ecc2512c628a37ed0bfba85e082f24e2e8c94d Mon Sep 17 00:00:00 2001 From: Karol Stasiak Date: Mon, 16 Jul 2018 23:05:16 +0200 Subject: [PATCH] Refactoring the inlining calculators --- .../millfork/output/AbstractAssembler.scala | 2 +- .../output/AbstractInliningCalculator.scala | 77 +++++++++++++++++++ .../output/MosInliningCalculator.scala | 69 ----------------- 3 files changed, 78 insertions(+), 70 deletions(-) diff --git a/src/main/scala/millfork/output/AbstractAssembler.scala b/src/main/scala/millfork/output/AbstractAssembler.scala index c0563322..95446ab5 100644 --- a/src/main/scala/millfork/output/AbstractAssembler.scala +++ b/src/main/scala/millfork/output/AbstractAssembler.scala @@ -189,7 +189,7 @@ abstract class AbstractAssembler[T <: AbstractCode](private val program: Program val assembly = mutable.ArrayBuffer[String]() - val inliningResult = MosInliningCalculator.calculate( + val inliningResult = inliningCalculator.calculate( program, options.flags(CompilationFlag.InlineFunctions) || options.flags(CompilationFlag.OptimizeForSonicSpeed), if (options.flags(CompilationFlag.OptimizeForSonicSpeed)) 4.0 diff --git a/src/main/scala/millfork/output/AbstractInliningCalculator.scala b/src/main/scala/millfork/output/AbstractInliningCalculator.scala index 3fb94d0b..ef43bde0 100644 --- a/src/main/scala/millfork/output/AbstractInliningCalculator.scala +++ b/src/main/scala/millfork/output/AbstractInliningCalculator.scala @@ -1,7 +1,12 @@ package millfork.output import millfork.assembly.AbstractCode +import millfork.assembly.mos.Opcode +import millfork.assembly.z80.ZOpcode import millfork.compiler.AbstractCompiler +import millfork.node._ + +import scala.collection.mutable /** * @author Karol Stasiak @@ -9,4 +14,76 @@ import millfork.compiler.AbstractCompiler abstract class AbstractInliningCalculator[T <: AbstractCode] { def codeForInlining(fname: String, functionsAlreadyKnownToBeNonInlineable: Set[String], code: List[T]): Option[List[T]] def inline(code: List[T], inlinedFunctions: Map[String, List[T]], compiler: AbstractCompiler[T]): List[T] + + private val sizes = Seq(64, 64, 8, 6, 5, 5, 4) + + def calculate(program: Program, + inlineByDefault: Boolean, + aggressivenessForNormal: Double, + aggressivenessForRecommended: Double): InliningResult = { + val callCount = mutable.Map[String, Int]().withDefaultValue(0) + val allFunctions = mutable.Set[String]() + val badFunctions = mutable.Set[String]() + val recommendedFunctions = mutable.Set[String]() + getAllCalledFunctions(program.declarations).foreach{ + case (name, true) => badFunctions += name + case (name, false) => callCount(name) += 1 + } + program.declarations.foreach{ + case f:FunctionDeclarationStatement => + allFunctions += f.name + if (f.inlinable.contains(true)) { + recommendedFunctions += f.name + } + if (f.isMacro + || f.inlinable.contains(false) + || f.address.isDefined + || f.interrupt + || f.reentrant + || f.name == "main" + || containsReturnDispatch(f.statements.getOrElse(Nil))) badFunctions += f.name + case _ => + } + allFunctions --= badFunctions + recommendedFunctions --= badFunctions + val map = (if (inlineByDefault) allFunctions else recommendedFunctions).map(f => f -> { + val size = sizes(callCount(f) min (sizes.size - 1)) + val aggressiveness = if (recommendedFunctions(f)) aggressivenessForRecommended else aggressivenessForNormal + (size * aggressiveness).floor.toInt + }).toMap + InliningResult(map, badFunctions.toSet) + } + + protected def containsReturnDispatch(statements: Seq[Statement]): Boolean = statements.exists { + case _: ReturnDispatchStatement => true + case c: CompoundStatement => containsReturnDispatch(c.getChildStatements) + case _ => false + } + + protected def getAllCalledFunctions(expressions: List[Node]): List[(String, Boolean)] = expressions.flatMap { + case s: VariableDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.initialValue.toList) + case ReturnDispatchStatement(index, params, branches) => + getAllCalledFunctions(List(index)) ++ getAllCalledFunctions(params) ++ getAllCalledFunctions(branches.map(b => b.function)) + case s: ArrayDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.elements.toList) + case s: ArrayContents => getAllCalledFunctions(s.getAllExpressions) + case s: FunctionDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.statements.getOrElse(Nil)) + case Assignment(VariableExpression(_), expr) => getAllCalledFunctions(expr :: Nil) + case MosAssemblyStatement(Opcode.JSR, _, VariableExpression(name), true) => (name -> false) :: Nil + case Z80AssemblyStatement(ZOpcode.CALL, _, _, VariableExpression(name), true) => (name -> false) :: Nil + case s: Statement => getAllCalledFunctions(s.getAllExpressions) + case s: VariableExpression => Set( + s.name, + s.name.stripSuffix(".addr"), + s.name.stripSuffix(".hi"), + s.name.stripSuffix(".lo"), + s.name.stripSuffix(".addr.lo"), + s.name.stripSuffix(".addr.hi")).toList.map(_ -> true) + case s: LiteralExpression => Nil + case HalfWordExpression(param, _) => getAllCalledFunctions(param :: Nil) + case SumExpression(xs, _) => getAllCalledFunctions(xs.map(_._2)) + case FunctionCallExpression(name, xs) => (name -> false) :: getAllCalledFunctions(xs) + case IndexedExpression(arr, index) => (arr -> true) :: getAllCalledFunctions(List(index)) + case SeparateBytesExpression(h, l) => getAllCalledFunctions(List(h, l)) + case _ => Nil + } } diff --git a/src/main/scala/millfork/output/MosInliningCalculator.scala b/src/main/scala/millfork/output/MosInliningCalculator.scala index f1872755..5f79b087 100644 --- a/src/main/scala/millfork/output/MosInliningCalculator.scala +++ b/src/main/scala/millfork/output/MosInliningCalculator.scala @@ -18,75 +18,6 @@ object MosInliningCalculator extends AbstractInliningCalculator[AssemblyLine] { private val sizes = Seq(64, 64, 8, 6, 5, 5, 4) - def calculate(program: Program, - inlineByDefault: Boolean, - aggressivenessForNormal: Double, - aggressivenessForRecommended: Double): InliningResult = { - val callCount = mutable.Map[String, Int]().withDefaultValue(0) - val allFunctions = mutable.Set[String]() - val badFunctions = mutable.Set[String]() - val recommendedFunctions = mutable.Set[String]() - getAllCalledFunctions(program.declarations).foreach{ - case (name, true) => badFunctions += name - case (name, false) => callCount(name) += 1 - } - program.declarations.foreach{ - case f:FunctionDeclarationStatement => - allFunctions += f.name - if (f.inlinable.contains(true)) { - recommendedFunctions += f.name - } - if (f.isMacro - || f.inlinable.contains(false) - || f.address.isDefined - || f.interrupt - || f.reentrant - || f.name == "main" - || containsReturnDispatch(f.statements.getOrElse(Nil))) badFunctions += f.name - case _ => - } - allFunctions --= badFunctions - recommendedFunctions --= badFunctions - val map = (if (inlineByDefault) allFunctions else recommendedFunctions).map(f => f -> { - val size = sizes(callCount(f) min (sizes.size - 1)) - val aggressiveness = if (recommendedFunctions(f)) aggressivenessForRecommended else aggressivenessForNormal - (size * aggressiveness).floor.toInt - }).toMap - InliningResult(map, badFunctions.toSet) - } - - private def containsReturnDispatch(statements: Seq[Statement]): Boolean = statements.exists { - case _: ReturnDispatchStatement => true - case c: CompoundStatement => containsReturnDispatch(c.getChildStatements) - case _ => false - } - - private def getAllCalledFunctions(expressions: List[Node]): List[(String, Boolean)] = expressions.flatMap { - case s: VariableDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.initialValue.toList) - case ReturnDispatchStatement(index, params, branches) => - getAllCalledFunctions(List(index)) ++ getAllCalledFunctions(params) ++ getAllCalledFunctions(branches.map(b => b.function)) - case s: ArrayDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.elements.toList) - case s: ArrayContents => getAllCalledFunctions(s.getAllExpressions) - case s: FunctionDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.statements.getOrElse(Nil)) - case Assignment(VariableExpression(_), expr) => getAllCalledFunctions(expr :: Nil) - case MosAssemblyStatement(JSR, _, VariableExpression(name), true) => (name -> false) :: Nil - case s: Statement => getAllCalledFunctions(s.getAllExpressions) - case s: VariableExpression => Set( - s.name, - s.name.stripSuffix(".addr"), - s.name.stripSuffix(".hi"), - s.name.stripSuffix(".lo"), - s.name.stripSuffix(".addr.lo"), - s.name.stripSuffix(".addr.hi")).toList.map(_ -> true) - case s: LiteralExpression => Nil - case HalfWordExpression(param, _) => getAllCalledFunctions(param :: Nil) - case SumExpression(xs, _) => getAllCalledFunctions(xs.map(_._2)) - case FunctionCallExpression(name, xs) => (name -> false) :: getAllCalledFunctions(xs) - case IndexedExpression(arr, index) => (arr -> true) :: getAllCalledFunctions(List(index)) - case SeparateBytesExpression(h, l) => getAllCalledFunctions(List(h, l)) - case _ => Nil - } - private val badOpcodes = Set(RTI, RTS, JSR, BRK, RTL, BSR, BYTE) ++ OpcodeClasses.ChangesStack private val jumpingRelatedOpcodes = Set(LABEL, JMP) ++ OpcodeClasses.ShortBranching