From f31086e6861fe1759ae1212d2d22d5b364cf7fdf Mon Sep 17 00:00:00 2001 From: Karol Stasiak Date: Wed, 28 Feb 2018 01:11:14 +0100 Subject: [PATCH] Inlining improvements and bugfixes --- src/main/scala/millfork/node/CallGraph.scala | 6 ++ .../scala/millfork/output/Assembler.scala | 57 ++++++++++++------- .../millfork/output/InliningCalculator.scala | 16 ++++-- 3 files changed, 54 insertions(+), 25 deletions(-) diff --git a/src/main/scala/millfork/node/CallGraph.scala b/src/main/scala/millfork/node/CallGraph.scala index c70b67aa..1b92401f 100644 --- a/src/main/scala/millfork/node/CallGraph.scala +++ b/src/main/scala/millfork/node/CallGraph.scala @@ -51,9 +51,15 @@ abstract class CallGraph(program: Program) { currentFunction.foreach(f => callEdges += f -> g.functionName) callingFunctions.foreach(f => paramEdges += f -> g.functionName) g.expressions.foreach(expr => add(currentFunction, g.functionName :: callingFunctions, expr)) + case s: SumExpression => + s.expressions.foreach(expr => add(currentFunction, callingFunctions, expr._2)) case x: VariableExpression => val varName = x.name.stripSuffix(".hi").stripSuffix(".lo").stripSuffix(".addr") everCalledFunctions += varName + case i: IndexedExpression => + val varName = i.name.stripSuffix(".hi").stripSuffix(".lo").stripSuffix(".addr") + everCalledFunctions += varName + add(currentFunction, callingFunctions, i.index) case _ => () } } diff --git a/src/main/scala/millfork/output/Assembler.scala b/src/main/scala/millfork/output/Assembler.scala index abb5aef4..8b69cd41 100644 --- a/src/main/scala/millfork/output/Assembler.scala +++ b/src/main/scala/millfork/output/Assembler.scala @@ -69,22 +69,28 @@ class Assembler(private val program: Program, private val rootEnv: Environment) c match { case NumericConstant(v, _) => v case MemoryAddressConstant(th) => - if (labelMap.contains(th.name)) return labelMap(th.name) - if (labelMap.contains(th.name + "`")) return labelMap(th.name) - if (labelMap.contains(th.name + ".addr")) return labelMap(th.name) - val x1 = env.maybeGet[ConstantThing](th.name).map(_.value) - val x2 = env.maybeGet[ConstantThing](th.name + "`").map(_.value) - val x3 = env.maybeGet[NormalFunction](th.name).flatMap(_.address) - val x4 = env.maybeGet[ConstantThing](th.name + ".addr").map(_.value) - val x5 = env.maybeGet[RelativeVariable](th.name).map(_.address) - val x6 = env.maybeGet[ConstantThing](th.name.stripSuffix(".array") + ".addr").map(_.value) - val x = x1.orElse(x2).orElse(x3).orElse(x4).orElse(x5).orElse(x6) - x match { - case Some(cc) => - deepConstResolve(cc) - case None => - println(th) - ??? + try { + if (labelMap.contains(th.name)) return labelMap(th.name) + if (labelMap.contains(th.name + "`")) return labelMap(th.name) + if (labelMap.contains(th.name + ".addr")) return labelMap(th.name) + val x1 = env.maybeGet[ConstantThing](th.name).map(_.value) + val x2 = env.maybeGet[ConstantThing](th.name + "`").map(_.value) + val x3 = env.maybeGet[NormalFunction](th.name).flatMap(_.address) + val x4 = env.maybeGet[ConstantThing](th.name + ".addr").map(_.value) + val x5 = env.maybeGet[RelativeVariable](th.name).map(_.address) + val x6 = env.maybeGet[ConstantThing](th.name.stripSuffix(".array") + ".addr").map(_.value) + val x = x1.orElse(x2).orElse(x3).orElse(x4).orElse(x5).orElse(x6) + x match { + case Some(cc) => + deepConstResolve(cc) + case None => + println(th) + ??? + } + } catch { + case e: StackOverflowError => + e.printStackTrace() + ErrorReporting.fatal("Stack overflow") } case HalfWordConstant(cc, true) => deepConstResolve(cc).>>>(8).&(0xff) case HalfWordConstant(cc, false) => deepConstResolve(cc).&(0xff) @@ -157,7 +163,8 @@ class Assembler(private val program: Program, private val rootEnv: Environment) var inlinedFunctions = Map[String, List[AssemblyLine]]() val compiledFunctions = mutable.Map[String, List[AssemblyLine]]() - callGraph.recommendedCompilationOrder.foreach { f => + val recommendedCompilationOrder = callGraph.recommendedCompilationOrder + recommendedCompilationOrder.foreach { f => env.maybeGet[NormalFunction](f).foreach { function => val code = compileFunction(function, optimizations, options, inlinedFunctions) val strippedCodeForInlining = for { @@ -308,9 +315,21 @@ class Assembler(private val program: Program, private val rootEnv: Environment) val unoptimized = MfCompiler.compile(CompilationContext(env = f.environment, function = f, extraStackOffset = 0, options = options)).flatMap { case AssemblyLine(Opcode.JSR, _, p, true) if inlinedFunctions.contains(p.toString) => - inlinedFunctions(p.toString) + val labelPrefix = MfCompiler.nextLabel("ai") + inlinedFunctions(p.toString).map{ + case line@AssemblyLine(_, _, MemoryAddressConstant(Label(label)), _) => + val newLabel = MemoryAddressConstant(Label(labelPrefix + label)) + line.copy(parameter = newLabel) + case l => l + } case AssemblyLine(Opcode.JMP, AddrMode.Absolute, p, true) if inlinedFunctions.contains(p.toString) => - inlinedFunctions(p.toString) :+ AssemblyLine.implied(Opcode.RTS) + val labelPrefix = MfCompiler.nextLabel("ai") + inlinedFunctions(p.toString).map{ + case line@AssemblyLine(_, _, MemoryAddressConstant(Label(label)), _) => + val newLabel = MemoryAddressConstant(Label(labelPrefix + label)) + line.copy(parameter = newLabel) + case l => l + } :+ AssemblyLine.implied(Opcode.RTS) case x => List(x) } unoptimizedCodeSize += unoptimized.map(_.sizeInBytes).sum diff --git a/src/main/scala/millfork/output/InliningCalculator.scala b/src/main/scala/millfork/output/InliningCalculator.scala index cd12ad79..b9f9537d 100644 --- a/src/main/scala/millfork/output/InliningCalculator.scala +++ b/src/main/scala/millfork/output/InliningCalculator.scala @@ -1,7 +1,8 @@ package millfork.output -import millfork.assembly.{AssemblyLine, Opcode, OpcodeClasses} +import millfork.assembly.{AddrMode, AssemblyLine, Opcode, OpcodeClasses} import millfork.assembly.Opcode._ +import millfork.compiler.{ExpressionCompiler, MfCompiler} import millfork.env._ import millfork.node._ @@ -75,10 +76,8 @@ object InliningCalculator { case _ => Nil } - private val badOpcodes = - Set(RTI, RTS, JSR, JMP, LABEL, BRK) ++ - OpcodeClasses.ShortBranching ++ - OpcodeClasses.ChangesStack + private val badOpcodes = Set(RTI, RTS, JSR, BRK) ++ OpcodeClasses.ChangesStack + private val jumpingRelatedOpcodes = Set(LABEL, JMP) ++ OpcodeClasses.ShortBranching def codeForInlining(fname: String, code: List[AssemblyLine]): Option[List[AssemblyLine]] = { if (code.isEmpty) return None @@ -88,7 +87,12 @@ object InliningCalculator { result = result.init } if (result.head.opcode == LABEL && result.head.parameter == Label(fname).toAddress) result = result.tail - if (result.exists(l => badOpcodes(l.opcode))) return None + if (result.exists{ + case AssemblyLine(op, AddrMode.Absolute | AddrMode.Relative, MemoryAddressConstant(Label(l)), _) if jumpingRelatedOpcodes(op) => + !l.startsWith(".") + case AssemblyLine(op, _, _, _) if jumpingRelatedOpcodes(op) || badOpcodes(op) => true + case _ => false + }) return None Some(result) } }