diff --git a/src/main/scala/millfork/CompilationOptions.scala b/src/main/scala/millfork/CompilationOptions.scala index bcfcb6b4..be845a12 100644 --- a/src/main/scala/millfork/CompilationOptions.scala +++ b/src/main/scala/millfork/CompilationOptions.scala @@ -75,7 +75,7 @@ object CompilationFlag extends Enumeration { // compilation options: EmitIllegals, EmitCmosOpcodes, DecimalMode, ReadOnlyArrays, PreventJmpIndirectBug, // optimization options: - DetailedFlowAnalysis, DangerousOptimizations, + DetailedFlowAnalysis, DangerousOptimizations, InlineFunctions, // memory allocation options VariableOverlap, // warning options diff --git a/src/main/scala/millfork/Main.scala b/src/main/scala/millfork/Main.scala index 21edf0ab..9a5adbd3 100644 --- a/src/main/scala/millfork/Main.scala +++ b/src/main/scala/millfork/Main.scala @@ -96,7 +96,7 @@ object Main { } // compile - val assembler = new Assembler(env) + val assembler = new Assembler(program, env) val result = assembler.assemble(callGraph, assemblyOptimizations, options) ErrorReporting.assertNoErrors("Codegen failed") ErrorReporting.debug(f"Unoptimized code size: ${assembler.unoptimizedCodeSize}%5d B") @@ -214,6 +214,9 @@ object Main { }.description("Optimize code even more.") if (i > 3) f.hidden() } + flag("--inline").action { c => + c.changeFlag(CompilationFlag.InlineFunctions, true) + }.description("Inline functions automatically (experimental).") flag("--detailed-flow").action { c => c.changeFlag(CompilationFlag.DetailedFlowAnalysis, true) }.description("Use detailed flow analysis (experimental).") diff --git a/src/main/scala/millfork/assembly/AssemblyLine.scala b/src/main/scala/millfork/assembly/AssemblyLine.scala index 308beca4..a34c3268 100644 --- a/src/main/scala/millfork/assembly/AssemblyLine.scala +++ b/src/main/scala/millfork/assembly/AssemblyLine.scala @@ -87,14 +87,16 @@ object OpcodeClasses { val ConcernsXAlways = ReadsXAlways | ChangesX val ConcernsYAlways = ReadsYAlways | ChangesY - val ConcernsStack = Set( + val ChangesStack = Set( PHA, PLA, PHP, PLP, PHX, PLX, PHY, PLY, - TXS, TSX, + TXS, JSR, RTS, RTI, TAS, LAS, ) + val ConcernsStack = ChangesStack + TSX + val ChangesNAndZ = Set( ADC, AND, ASL, BIT, CMP, CPX, CPY, DEC, DEX, DEY, EOR, INC, INX, INY, LDA, LDX, LDY, LSR, ORA, PLP, ROL, ROR, SBC, TAX, TAY, TXA, TYA, diff --git a/src/main/scala/millfork/compiler/DecimalBuiltIns.scala b/src/main/scala/millfork/compiler/DecimalBuiltIns.scala new file mode 100644 index 00000000..3aeec5f0 --- /dev/null +++ b/src/main/scala/millfork/compiler/DecimalBuiltIns.scala @@ -0,0 +1,69 @@ +package millfork.compiler +import millfork.{CompilationFlag, CompilationOptions} +import millfork.assembly._ +import millfork.env._ +import millfork.node._ +import millfork.assembly.Opcode._ +import millfork.assembly.AddrMode._ +import millfork.error.ErrorReporting + +import scala.collection.mutable +import scala.collection.mutable.ListBuffer +import millfork.assembly.AssemblyLine +import millfork.env.{NumericConstant, RegisterVariable, Type} +import millfork.error.ErrorReporting +import millfork.node.{Expression, Register} + +/** + * @author Karol Stasiak + */ +object DecimalBuiltIns { + def compileByteShiftLeft(ctx: CompilationContext, l: Expression, r: Expression, rotate: Boolean): List[AssemblyLine] = { + val b = ctx.env.get[Type]("byte") + ctx.env.eval(r) match { + case Some(NumericConstant(0, _)) => + Nil + case Some(NumericConstant(v, _)) => + val addition = BuiltIns.compileAddition(ctx, List.fill(1< l), decimal = true) + if (rotate) addition.filterNot(_.opcode == CLC) else addition + case _ => + ErrorReporting.error("Cannot shift by a non-constant amount") + Nil + } + } + + def compileByteShiftRight(ctx: CompilationContext, l: Expression, r: Expression, rotate: Boolean): List[AssemblyLine] = { + val b = ctx.env.get[Type]("byte") + ctx.env.eval(r) match { + case Some(NumericConstant(0, _)) => + Nil + case Some(NumericConstant(1, _)) => + val constantLabel = MlCompiler.nextLabel("c8") + val skipHiDigit = MlCompiler.nextLabel("ds") + val skipLoDigit = MlCompiler.nextLabel("ds") + val bit = if (ctx.options.flags(CompilationFlag.EmitCmosOpcodes)) { + AssemblyLine.immediate(BIT, 8) + } else { + AssemblyLine.absolute(BIT, Label(constantLabel)) + } + List( + if (rotate) AssemblyLine.implied(ROR) else AssemblyLine.implied(LSR), + AssemblyLine.label(constantLabel), + AssemblyLine.implied(PHP), + AssemblyLine.relative(BPL, skipHiDigit), + AssemblyLine.implied(SEC), + AssemblyLine.immediate(SBC, 0x30), + AssemblyLine.label(skipHiDigit), + bit, + AssemblyLine.relative(BPL, skipLoDigit), + AssemblyLine.implied(SEC), + AssemblyLine.immediate(SBC, 0x3), + AssemblyLine.label(skipLoDigit), + AssemblyLine.implied(PLP)) + case _ => + ErrorReporting.error("Cannot shift by a non-constant amount") + Nil + } + } + +} diff --git a/src/main/scala/millfork/output/Assembler.scala b/src/main/scala/millfork/output/Assembler.scala index 2c2d7123..db51d308 100644 --- a/src/main/scala/millfork/output/Assembler.scala +++ b/src/main/scala/millfork/output/Assembler.scala @@ -5,7 +5,7 @@ import millfork.assembly.{AddrMode, AssemblyLine, Opcode} import millfork.compiler.{CompilationContext, MlCompiler} import millfork.env._ import millfork.error.ErrorReporting -import millfork.node.CallGraph +import millfork.node.{CallGraph, Program} import millfork.{CompilationFlag, CompilationOptions, Tarjan} import scala.collection.mutable @@ -16,7 +16,7 @@ import scala.collection.mutable case class AssemblerOutput(code: Array[Byte], asm: Array[String], labels: List[(String, Int)]) -class Assembler(private val rootEnv: Environment) { +class Assembler(private val program: Program, private val rootEnv: Environment) { var env = rootEnv.allThings var unoptimizedCodeSize = 0 @@ -143,11 +143,31 @@ class Assembler(private val rootEnv: Environment) { val assembly = mutable.ArrayBuffer[String]() + val potentiallyInlineable: Map[String, Int] = + if (options.flags(CompilationFlag.InlineFunctions)) + InliningCalculator.getPotentiallyInlineableFunctions(program) + else Map() + + var inlinedFunctions = Map[String, List[AssemblyLine]]() val compiledFunctions = mutable.Map[String, List[AssemblyLine]]() callGraph.recommendedCompilationOrder.foreach{ f => - env.maybeGet[NormalFunction](f).foreach( function => - compiledFunctions(f) = compileFunction(function, optimizations, options) - ) + env.maybeGet[NormalFunction](f).foreach{ function => + val code = compileFunction(function, optimizations, options, inlinedFunctions) + val strippedCodeForInlining = for { + limit <- potentiallyInlineable.get(f) + if code.map(_.sizeInBytes).sum <= limit + s <- InliningCalculator.codeForInlining(f, code) + } yield s + strippedCodeForInlining match { + case Some(c) => + ErrorReporting.debug("Inlining " + f, function.position) + inlinedFunctions += f -> c + compiledFunctions(f) = Nil + case None => + compiledFunctions(f) = code + optimizedCodeSize += code.map(_.sizeInBytes).sum + } + } } env.allPreallocatables.foreach { @@ -249,14 +269,20 @@ class Assembler(private val rootEnv: Environment) { AssemblerOutput(platform.outputPackager.packageOutput(mem, 0), assembly.toArray, labelMap.toList) } - private def compileFunction(f: NormalFunction, optimizations: Seq[AssemblyOptimization], options: CompilationOptions) :List[AssemblyLine] = { + private def compileFunction(f: NormalFunction, optimizations: Seq[AssemblyOptimization], options: CompilationOptions, inlinedFunctions: Map[String, List[AssemblyLine]]) :List[AssemblyLine] = { ErrorReporting.debug("Compiling: " + f.name, f.position) - val unoptimized = MlCompiler.compile(CompilationContext(env = f.environment, function = f, extraStackOffset = 0, options = options)).linearize + val unoptimized = + MlCompiler.compile(CompilationContext(env = f.environment, function = f, extraStackOffset = 0, options = options)).linearize.flatMap{ + case AssemblyLine(Opcode.JSR, _, p, true) if inlinedFunctions.contains(p.toString) => + inlinedFunctions(p.toString) + case AssemblyLine(Opcode.JMP, AddrMode.Absolute, p, true) if inlinedFunctions.contains(p.toString) => + inlinedFunctions(p.toString) :+ AssemblyLine.implied(Opcode.RTS) + case x => List(x) + } unoptimizedCodeSize += unoptimized.map(_.sizeInBytes).sum val code = optimizations.foldLeft(unoptimized) { (c, opt) => opt.optimize(f, c, options) } - optimizedCodeSize += code.map(_.sizeInBytes).sum code } diff --git a/src/main/scala/millfork/output/InliningCalculator.scala b/src/main/scala/millfork/output/InliningCalculator.scala new file mode 100644 index 00000000..5fecf611 --- /dev/null +++ b/src/main/scala/millfork/output/InliningCalculator.scala @@ -0,0 +1,75 @@ +package millfork.output + +import millfork.assembly.{AssemblyLine, Opcode, OpcodeClasses} +import millfork.assembly.Opcode._ +import millfork.env._ +import millfork.node._ + +import scala.collection.mutable + +/** + * @author Karol Stasiak + */ +object InliningCalculator { + + private val sizes = Seq(30, 30, 8, 6, 5, 5, 4) + + def getPotentiallyInlineableFunctions(program: Program): Map[String, Int] = { + val callCount = mutable.Map[String, Int]().withDefaultValue(0) + val allFunctions = mutable.Set[String]() + val badFunctions = mutable.Set[String]() + getAllCalledFunctions(program.declarations).foreach{ + case (name, true) => badFunctions += name + case (name, false) => callCount(name) += 1 + } + program.declarations.foreach{ + case f:FunctionDeclarationStatement => + allFunctions += f.name + if (f.inlined) badFunctions += f.name + if (f.address.isDefined) badFunctions += f.name + if (f.interrupt) badFunctions += f.name + if (f.reentrant) badFunctions += f.name + if (f.name == "main") badFunctions += f.name + case _ => + } + allFunctions --= badFunctions + allFunctions.map(f => f -> sizes(callCount(f) min (sizes.size - 1))).toMap + } + + private def getAllCalledFunctions(expressions: List[Node]): List[(String, Boolean)] = expressions.flatMap { + case s: VariableDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.initialValue.toList) + case s: ArrayDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.elements.getOrElse(Nil)) + case s: FunctionDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.statements.getOrElse(Nil)) + case Assignment(VariableExpression(_), expr) => getAllCalledFunctions(expr :: Nil) + case AssemblyStatement(JSR, _, VariableExpression(name), true) => (name -> false) :: Nil + case s: Statement => getAllCalledFunctions(s.getAllExpressions) + case s: VariableExpression => Set( + s.name, + s.name.stripSuffix(".addr"), + s.name.stripSuffix(".hi"), + s.name.stripSuffix(".lo"), + s.name.stripSuffix(".addr.lo"), + s.name.stripSuffix(".addr.hi")).toList.map(_ -> true) + case s: LiteralExpression => Nil + case HalfWordExpression(param, _) => getAllCalledFunctions(param :: Nil) + case SumExpression(xs, _) => getAllCalledFunctions(xs.map(_._2)) + case FunctionCallExpression(name, xs) => (name -> false) :: getAllCalledFunctions(xs) + case IndexedExpression(arr, index) => (arr -> true) :: getAllCalledFunctions(List(index)) + case SeparateBytesExpression(h, l) => getAllCalledFunctions(List(h, l)) + case _ => Nil + } + + private val badOpcodes = + Set(RTI, RTS, JSR, JMP, LABEL, BRK) ++ + OpcodeClasses.ShortBranching ++ + OpcodeClasses.ChangesStack + + def codeForInlining(fname: String, code: List[AssemblyLine]): Option[List[AssemblyLine]] = { + if (code.isEmpty) return None + if (code.last.opcode != RTS) return None + var result = code.init + if (result.head.opcode == LABEL && result.head.parameter == Label(fname).toAddress) result = result.tail + if (result.exists(l => badOpcodes(l.opcode))) return None + Some(result) + } +} diff --git a/src/test/scala/millfork/test/emu/EmuBenchmarkRun.scala b/src/test/scala/millfork/test/emu/EmuBenchmarkRun.scala index e0dea4ca..bf0b663d 100644 --- a/src/test/scala/millfork/test/emu/EmuBenchmarkRun.scala +++ b/src/test/scala/millfork/test/emu/EmuBenchmarkRun.scala @@ -9,12 +9,17 @@ object EmuBenchmarkRun { def apply(source: String)(verifier: MemoryBank => Unit) = { val (Timings(t0, _), m0) = EmuUnoptimizedRun.apply2(source) val (Timings(t1, _), m1) = EmuOptimizedRun.apply2(source) + val (Timings(t2, _), m2) = EmuOptimizedInlinedRun.apply2(source) println(f"Before optimization: $t0%7d") println(f"After optimization: $t1%7d") + println(f"After inlining: $t2%7d") println(f"Gain: ${(100L * (t0 - t1) / t0.toDouble).round}%7d%%") + println(f"Gain with inlining: ${(100L * (t0 - t2) / t0.toDouble).round}%7d%%") println(f"Running unoptimized") verifier(m0) println(f"Running optimized") verifier(m1) + println(f"Running optimized inlined") + verifier(m2) } } diff --git a/src/test/scala/millfork/test/emu/EmuOptimizedInlinedRun.scala b/src/test/scala/millfork/test/emu/EmuOptimizedInlinedRun.scala new file mode 100644 index 00000000..8506cfb0 --- /dev/null +++ b/src/test/scala/millfork/test/emu/EmuOptimizedInlinedRun.scala @@ -0,0 +1,15 @@ +package millfork.test.emu + +import millfork.{Cpu, OptimizationPresets} + +/** + * @author Karol Stasiak + */ +object EmuOptimizedInlinedRun extends EmuRun( + Cpu.StrictMos, + OptimizationPresets.NodeOpt, + OptimizationPresets.AssOpt ++ OptimizationPresets.Good ++ OptimizationPresets.Good ++ OptimizationPresets.Good, + false) + + + diff --git a/src/test/scala/millfork/test/emu/EmuRun.scala b/src/test/scala/millfork/test/emu/EmuRun.scala index 19d79c82..9599931d 100644 --- a/src/test/scala/millfork/test/emu/EmuRun.scala +++ b/src/test/scala/millfork/test/emu/EmuRun.scala @@ -28,6 +28,8 @@ class EmuRun(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization], def emitIllegals = false + def inline = false + private val timingNmos = Array[Int]( 7, 6, 0, 8, 3, 3, 5, 5, 3, 2, 2, 2, 4, 4, 6, 6, 2, 5, 0, 8, 4, 4, 6, 6, 2, 4, 2, 7, 4, 4, 7, 7, @@ -91,6 +93,7 @@ class EmuRun(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization], val options = new CompilationOptions(platform, Map( CompilationFlag.EmitIllegals -> this.emitIllegals, CompilationFlag.DetailedFlowAnalysis -> quantum, + CompilationFlag.InlineFunctions -> this.inline, )) ErrorReporting.hasErrors = false ErrorReporting.verbosity = 999 @@ -107,42 +110,31 @@ class EmuRun(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization], env.collectDeclarations(program, options) val hasOptimizations = assemblyOptimizations.nonEmpty - var optimizedSize = 0L var unoptimizedSize = 0L // print asm env.allPreallocatables.foreach { case f: NormalFunction => val result = MlCompiler.compile(CompilationContext(f.environment, f, 0, options)) val unoptimized = result.linearize - if (hasOptimizations) { - val optimized = assemblyOptimizations.foldLeft(unoptimized) { (c, opt) => - opt.optimize(f, c, options) - } - println("Unoptimized:") - unoptimized.filter(_.isPrintable).foreach(println(_)) - println("Optimized:") - optimized.filter(_.isPrintable).foreach(println(_)) - unoptimizedSize += unoptimized.map(_.sizeInBytes).sum - optimizedSize += optimized.map(_.sizeInBytes).sum - } else { - unoptimized.filter(_.isPrintable).foreach(println(_)) - unoptimizedSize += unoptimized.map(_.sizeInBytes).sum - optimizedSize += unoptimized.map(_.sizeInBytes).sum - } + unoptimizedSize += unoptimized.map(_.sizeInBytes).sum case d: InitializedArray => - println(d.name) - d.contents.foreach(c => println(" !byte " + c)) unoptimizedSize += d.contents.length - optimizedSize += d.contents.length case d: InitializedMemoryVariable => - println(d.name) - 0.until(d.typ.size).foreach(c => println(" !byte " + d.initialValue.subbyte(c))) unoptimizedSize += d.typ.size - optimizedSize += d.typ.size } ErrorReporting.assertNoErrors("Compile failed") + + // compile + val assembler = new Assembler(program, env) + val output = assembler.assemble(callGraph, assemblyOptimizations, options) + println(";;; compiled: -----------------") + output.asm.takeWhile(s => !(s.startsWith(".") && s.contains("= $"))).foreach(println) + println(";;; ---------------------------") + assembler.labelMap.foreach { case (l, addr) => println(f"$l%-15s $$$addr%04x") } + + val optimizedSize = assembler.mem.banks(0).occupied.count(identity).toLong if (unoptimizedSize == optimizedSize) { println(f"Size: $unoptimizedSize%5d B") } else { @@ -151,11 +143,6 @@ class EmuRun(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization], println(f"Gain: ${(100L * (unoptimizedSize - optimizedSize) / unoptimizedSize.toDouble).round}%5d%%") } - // compile - val assembler = new Assembler(env) - assembler.assemble(callGraph, assemblyOptimizations, options) - assembler.labelMap.foreach { case (l, addr) => println(f"$l%-15s $$$addr%04x") } - ErrorReporting.assertNoErrors("Code generation failed") val memoryBank = assembler.mem.banks(0) diff --git a/src/test/scala/millfork/test/emu/EmuSuperOptimizedInliningRun.scala b/src/test/scala/millfork/test/emu/EmuSuperOptimizedInliningRun.scala new file mode 100644 index 00000000..98139190 --- /dev/null +++ b/src/test/scala/millfork/test/emu/EmuSuperOptimizedInliningRun.scala @@ -0,0 +1,18 @@ +package millfork.test.emu + +import millfork.assembly.opt.SuperOptimizer +import millfork.{Cpu, OptimizationPresets} + +/** + * @author Karol Stasiak + */ +object EmuSuperQuantumOptimizedInliningRun extends EmuRun( + Cpu.StrictMos, + OptimizationPresets.NodeOpt, + List(SuperOptimizer), + true) { + override def inline = true +} + + + diff --git a/src/test/scala/millfork/test/emu/EmuSuperoptimizedRun.scala b/src/test/scala/millfork/test/emu/EmuSuperoptimizedRun.scala index 30be3c0d..f39d8c24 100644 --- a/src/test/scala/millfork/test/emu/EmuSuperoptimizedRun.scala +++ b/src/test/scala/millfork/test/emu/EmuSuperoptimizedRun.scala @@ -6,7 +6,6 @@ import millfork.{Cpu, OptimizationPresets} /** * @author Karol Stasiak */ -// TODO : it doesn't work object EmuSuperOptimizedRun extends EmuRun( Cpu.StrictMos, OptimizationPresets.NodeOpt, diff --git a/src/test/scala/millfork/test/emu/EmuUltraBenchmarkRun.scala b/src/test/scala/millfork/test/emu/EmuUltraBenchmarkRun.scala index 008b7dcc..40e0bbe6 100644 --- a/src/test/scala/millfork/test/emu/EmuUltraBenchmarkRun.scala +++ b/src/test/scala/millfork/test/emu/EmuUltraBenchmarkRun.scala @@ -12,15 +12,18 @@ object EmuUltraBenchmarkRun { val (Timings(t2, _), m2) = EmuSuperOptimizedRun.apply2(source) val (Timings(t3, _), m3) = EmuQuantumOptimizedRun.apply2(source) val (Timings(t4, _), m4) = EmuSuperQuantumOptimizedRun.apply2(source) + val (Timings(t5, _), m5) = EmuSuperQuantumOptimizedInliningRun.apply2(source) println(f"Before optimization: $t0%7d") println(f"After optimization: $t1%7d") println(f"After superopt.: $t2%7d") println(f"After quantum: $t3%7d") println(f"After superquantum: $t4%7d") + println(f"After inlining: $t5%7d") println(f"Gain: ${(100L*(t0-t1)/t0.toDouble).round}%7d%%") println(f"Superopt. gain: ${(100L*(t0-t2)/t0.toDouble).round}%7d%%") println(f"Quantum gain: ${(100L*(t0-t3)/t0.toDouble).round}%7d%%") println(f"Super quantum gain: ${(100L*(t0-t4)/t0.toDouble).round}%7d%%") + println(f"SQ+inlining gain: ${(100L*(t0-t5)/t0.toDouble).round}%7d%%") println(f"Running unoptimized") verifier(m0) println(f"Running optimized") @@ -31,5 +34,7 @@ object EmuUltraBenchmarkRun { verifier(m3) println(f"Running superquantum optimized") verifier(m4) + println(f"Running superquantum optimized inlined") + verifier(m5) } }