1
0
mirror of https://github.com/KarolS/millfork.git synced 2024-11-04 09:04:33 +00:00

Refactoring the inlining calculators

This commit is contained in:
Karol Stasiak 2018-07-16 23:05:16 +02:00
parent 28e11873dc
commit f4ecc2512c
3 changed files with 78 additions and 70 deletions

View File

@ -189,7 +189,7 @@ abstract class AbstractAssembler[T <: AbstractCode](private val program: Program
val assembly = mutable.ArrayBuffer[String]() val assembly = mutable.ArrayBuffer[String]()
val inliningResult = MosInliningCalculator.calculate( val inliningResult = inliningCalculator.calculate(
program, program,
options.flags(CompilationFlag.InlineFunctions) || options.flags(CompilationFlag.OptimizeForSonicSpeed), options.flags(CompilationFlag.InlineFunctions) || options.flags(CompilationFlag.OptimizeForSonicSpeed),
if (options.flags(CompilationFlag.OptimizeForSonicSpeed)) 4.0 if (options.flags(CompilationFlag.OptimizeForSonicSpeed)) 4.0

View File

@ -1,7 +1,12 @@
package millfork.output package millfork.output
import millfork.assembly.AbstractCode import millfork.assembly.AbstractCode
import millfork.assembly.mos.Opcode
import millfork.assembly.z80.ZOpcode
import millfork.compiler.AbstractCompiler import millfork.compiler.AbstractCompiler
import millfork.node._
import scala.collection.mutable
/** /**
* @author Karol Stasiak * @author Karol Stasiak
@ -9,4 +14,76 @@ import millfork.compiler.AbstractCompiler
abstract class AbstractInliningCalculator[T <: AbstractCode] { abstract class AbstractInliningCalculator[T <: AbstractCode] {
def codeForInlining(fname: String, functionsAlreadyKnownToBeNonInlineable: Set[String], code: List[T]): Option[List[T]] def codeForInlining(fname: String, functionsAlreadyKnownToBeNonInlineable: Set[String], code: List[T]): Option[List[T]]
def inline(code: List[T], inlinedFunctions: Map[String, List[T]], compiler: AbstractCompiler[T]): List[T] def inline(code: List[T], inlinedFunctions: Map[String, List[T]], compiler: AbstractCompiler[T]): List[T]
private val sizes = Seq(64, 64, 8, 6, 5, 5, 4)
def calculate(program: Program,
inlineByDefault: Boolean,
aggressivenessForNormal: Double,
aggressivenessForRecommended: Double): InliningResult = {
val callCount = mutable.Map[String, Int]().withDefaultValue(0)
val allFunctions = mutable.Set[String]()
val badFunctions = mutable.Set[String]()
val recommendedFunctions = mutable.Set[String]()
getAllCalledFunctions(program.declarations).foreach{
case (name, true) => badFunctions += name
case (name, false) => callCount(name) += 1
}
program.declarations.foreach{
case f:FunctionDeclarationStatement =>
allFunctions += f.name
if (f.inlinable.contains(true)) {
recommendedFunctions += f.name
}
if (f.isMacro
|| f.inlinable.contains(false)
|| f.address.isDefined
|| f.interrupt
|| f.reentrant
|| f.name == "main"
|| containsReturnDispatch(f.statements.getOrElse(Nil))) badFunctions += f.name
case _ =>
}
allFunctions --= badFunctions
recommendedFunctions --= badFunctions
val map = (if (inlineByDefault) allFunctions else recommendedFunctions).map(f => f -> {
val size = sizes(callCount(f) min (sizes.size - 1))
val aggressiveness = if (recommendedFunctions(f)) aggressivenessForRecommended else aggressivenessForNormal
(size * aggressiveness).floor.toInt
}).toMap
InliningResult(map, badFunctions.toSet)
}
protected def containsReturnDispatch(statements: Seq[Statement]): Boolean = statements.exists {
case _: ReturnDispatchStatement => true
case c: CompoundStatement => containsReturnDispatch(c.getChildStatements)
case _ => false
}
protected def getAllCalledFunctions(expressions: List[Node]): List[(String, Boolean)] = expressions.flatMap {
case s: VariableDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.initialValue.toList)
case ReturnDispatchStatement(index, params, branches) =>
getAllCalledFunctions(List(index)) ++ getAllCalledFunctions(params) ++ getAllCalledFunctions(branches.map(b => b.function))
case s: ArrayDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.elements.toList)
case s: ArrayContents => getAllCalledFunctions(s.getAllExpressions)
case s: FunctionDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.statements.getOrElse(Nil))
case Assignment(VariableExpression(_), expr) => getAllCalledFunctions(expr :: Nil)
case MosAssemblyStatement(Opcode.JSR, _, VariableExpression(name), true) => (name -> false) :: Nil
case Z80AssemblyStatement(ZOpcode.CALL, _, _, VariableExpression(name), true) => (name -> false) :: Nil
case s: Statement => getAllCalledFunctions(s.getAllExpressions)
case s: VariableExpression => Set(
s.name,
s.name.stripSuffix(".addr"),
s.name.stripSuffix(".hi"),
s.name.stripSuffix(".lo"),
s.name.stripSuffix(".addr.lo"),
s.name.stripSuffix(".addr.hi")).toList.map(_ -> true)
case s: LiteralExpression => Nil
case HalfWordExpression(param, _) => getAllCalledFunctions(param :: Nil)
case SumExpression(xs, _) => getAllCalledFunctions(xs.map(_._2))
case FunctionCallExpression(name, xs) => (name -> false) :: getAllCalledFunctions(xs)
case IndexedExpression(arr, index) => (arr -> true) :: getAllCalledFunctions(List(index))
case SeparateBytesExpression(h, l) => getAllCalledFunctions(List(h, l))
case _ => Nil
}
} }

View File

@ -18,75 +18,6 @@ object MosInliningCalculator extends AbstractInliningCalculator[AssemblyLine] {
private val sizes = Seq(64, 64, 8, 6, 5, 5, 4) private val sizes = Seq(64, 64, 8, 6, 5, 5, 4)
def calculate(program: Program,
inlineByDefault: Boolean,
aggressivenessForNormal: Double,
aggressivenessForRecommended: Double): InliningResult = {
val callCount = mutable.Map[String, Int]().withDefaultValue(0)
val allFunctions = mutable.Set[String]()
val badFunctions = mutable.Set[String]()
val recommendedFunctions = mutable.Set[String]()
getAllCalledFunctions(program.declarations).foreach{
case (name, true) => badFunctions += name
case (name, false) => callCount(name) += 1
}
program.declarations.foreach{
case f:FunctionDeclarationStatement =>
allFunctions += f.name
if (f.inlinable.contains(true)) {
recommendedFunctions += f.name
}
if (f.isMacro
|| f.inlinable.contains(false)
|| f.address.isDefined
|| f.interrupt
|| f.reentrant
|| f.name == "main"
|| containsReturnDispatch(f.statements.getOrElse(Nil))) badFunctions += f.name
case _ =>
}
allFunctions --= badFunctions
recommendedFunctions --= badFunctions
val map = (if (inlineByDefault) allFunctions else recommendedFunctions).map(f => f -> {
val size = sizes(callCount(f) min (sizes.size - 1))
val aggressiveness = if (recommendedFunctions(f)) aggressivenessForRecommended else aggressivenessForNormal
(size * aggressiveness).floor.toInt
}).toMap
InliningResult(map, badFunctions.toSet)
}
private def containsReturnDispatch(statements: Seq[Statement]): Boolean = statements.exists {
case _: ReturnDispatchStatement => true
case c: CompoundStatement => containsReturnDispatch(c.getChildStatements)
case _ => false
}
private def getAllCalledFunctions(expressions: List[Node]): List[(String, Boolean)] = expressions.flatMap {
case s: VariableDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.initialValue.toList)
case ReturnDispatchStatement(index, params, branches) =>
getAllCalledFunctions(List(index)) ++ getAllCalledFunctions(params) ++ getAllCalledFunctions(branches.map(b => b.function))
case s: ArrayDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.elements.toList)
case s: ArrayContents => getAllCalledFunctions(s.getAllExpressions)
case s: FunctionDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.statements.getOrElse(Nil))
case Assignment(VariableExpression(_), expr) => getAllCalledFunctions(expr :: Nil)
case MosAssemblyStatement(JSR, _, VariableExpression(name), true) => (name -> false) :: Nil
case s: Statement => getAllCalledFunctions(s.getAllExpressions)
case s: VariableExpression => Set(
s.name,
s.name.stripSuffix(".addr"),
s.name.stripSuffix(".hi"),
s.name.stripSuffix(".lo"),
s.name.stripSuffix(".addr.lo"),
s.name.stripSuffix(".addr.hi")).toList.map(_ -> true)
case s: LiteralExpression => Nil
case HalfWordExpression(param, _) => getAllCalledFunctions(param :: Nil)
case SumExpression(xs, _) => getAllCalledFunctions(xs.map(_._2))
case FunctionCallExpression(name, xs) => (name -> false) :: getAllCalledFunctions(xs)
case IndexedExpression(arr, index) => (arr -> true) :: getAllCalledFunctions(List(index))
case SeparateBytesExpression(h, l) => getAllCalledFunctions(List(h, l))
case _ => Nil
}
private val badOpcodes = Set(RTI, RTS, JSR, BRK, RTL, BSR, BYTE) ++ OpcodeClasses.ChangesStack private val badOpcodes = Set(RTI, RTS, JSR, BRK, RTL, BSR, BYTE) ++ OpcodeClasses.ChangesStack
private val jumpingRelatedOpcodes = Set(LABEL, JMP) ++ OpcodeClasses.ShortBranching private val jumpingRelatedOpcodes = Set(LABEL, JMP) ++ OpcodeClasses.ShortBranching