1
0
mirror of https://github.com/KarolS/millfork.git synced 2025-01-06 09:33:22 +00:00

Automatic function inlining; test suite changes

This commit is contained in:
Karol Stasiak 2017-12-20 02:50:52 +01:00
parent e78bd0e41a
commit 5c2832f4f3
12 changed files with 244 additions and 40 deletions

View File

@ -75,7 +75,7 @@ object CompilationFlag extends Enumeration {
// compilation options:
EmitIllegals, EmitCmosOpcodes, DecimalMode, ReadOnlyArrays, PreventJmpIndirectBug,
// optimization options:
DetailedFlowAnalysis, DangerousOptimizations,
DetailedFlowAnalysis, DangerousOptimizations, InlineFunctions,
// memory allocation options
VariableOverlap,
// warning options

View File

@ -96,7 +96,7 @@ object Main {
}
// compile
val assembler = new Assembler(env)
val assembler = new Assembler(program, env)
val result = assembler.assemble(callGraph, assemblyOptimizations, options)
ErrorReporting.assertNoErrors("Codegen failed")
ErrorReporting.debug(f"Unoptimized code size: ${assembler.unoptimizedCodeSize}%5d B")
@ -214,6 +214,9 @@ object Main {
}.description("Optimize code even more.")
if (i > 3) f.hidden()
}
flag("--inline").action { c =>
c.changeFlag(CompilationFlag.InlineFunctions, true)
}.description("Inline functions automatically (experimental).")
flag("--detailed-flow").action { c =>
c.changeFlag(CompilationFlag.DetailedFlowAnalysis, true)
}.description("Use detailed flow analysis (experimental).")

View File

@ -87,14 +87,16 @@ object OpcodeClasses {
val ConcernsXAlways = ReadsXAlways | ChangesX
val ConcernsYAlways = ReadsYAlways | ChangesY
val ConcernsStack = Set(
val ChangesStack = Set(
PHA, PLA, PHP, PLP,
PHX, PLX, PHY, PLY,
TXS, TSX,
TXS,
JSR, RTS, RTI,
TAS, LAS,
)
val ConcernsStack = ChangesStack + TSX
val ChangesNAndZ = Set(
ADC, AND, ASL, BIT, CMP, CPX, CPY, DEC, DEX, DEY, EOR, INC, INX, INY, LDA,
LDX, LDY, LSR, ORA, PLP, ROL, ROR, SBC, TAX, TAY, TXA, TYA,

View File

@ -0,0 +1,69 @@
package millfork.compiler
import millfork.{CompilationFlag, CompilationOptions}
import millfork.assembly._
import millfork.env._
import millfork.node._
import millfork.assembly.Opcode._
import millfork.assembly.AddrMode._
import millfork.error.ErrorReporting
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import millfork.assembly.AssemblyLine
import millfork.env.{NumericConstant, RegisterVariable, Type}
import millfork.error.ErrorReporting
import millfork.node.{Expression, Register}
/**
  * Code generators for shifting and rotating bytes in decimal mode.
  *
  * @author Karol Stasiak
  */
object DecimalBuiltIns {

  /**
    * Generates code for a decimal left shift (or rotate) of a byte by a constant amount.
    *
    * The shift is expressed as a decimal addition of `2^amount` copies of `l`.
    *
    * @param ctx    compilation context used to evaluate `r` and to compile the addition
    * @param l      the expression being shifted
    * @param r      the shift amount; must evaluate to a compile-time constant
    * @param rotate if true, every CLC line is removed from the generated addition
    * @return the generated code; `Nil` for a zero shift or after an error has been reported
    */
  def compileByteShiftLeft(ctx: CompilationContext, l: Expression, r: Expression, rotate: Boolean): List[AssemblyLine] = {
    ctx.env.eval(r) match {
      case Some(NumericConstant(0, _)) =>
        Nil
      case Some(NumericConstant(v, _)) if v < 0 || v > 30 =>
        // `1 << v` only uses the low five bits of the shift count, so amounts outside 0..30
        // would silently produce a wrong (or non-positive) number of operands.
        ErrorReporting.error("Cannot shift by an amount outside the range 0-30")
        Nil
      case Some(NumericConstant(v, _)) =>
        val addition = BuiltIns.compileAddition(ctx, List.fill(1<<v)(false -> l), decimal = true)
        if (rotate) addition.filterNot(_.opcode == CLC) else addition
      case _ =>
        ErrorReporting.error("Cannot shift by a non-constant amount")
        Nil
    }
  }

  /**
    * Generates code for a decimal right shift (or rotate) of a byte by a constant amount.
    *
    * Only the amounts 0 and 1 are supported. The generated sequence performs a binary
    * LSR/ROR and then corrects each decimal digit: 0x30 is subtracted when bit 7 of the
    * shifted value is set, and 0x3 is subtracted when bit 3 is set. The processor flags
    * produced by the shift are saved with PHP and restored with PLP at the end.
    *
    * NOTE(review): `l` is not referenced by the generated code; presumably the caller has
    * already loaded the operand into the accumulator - confirm against the call site.
    *
    * @param ctx    compilation context used to evaluate `r` and to read the compilation options
    * @param l      the expression being shifted (currently unused, see note above)
    * @param r      the shift amount; must evaluate to the constant 0 or 1
    * @param rotate if true, ROR is emitted instead of LSR
    * @return the generated code; `Nil` for a zero shift or after an error has been reported
    */
  def compileByteShiftRight(ctx: CompilationContext, l: Expression, r: Expression, rotate: Boolean): List[AssemblyLine] = {
    ctx.env.eval(r) match {
      case Some(NumericConstant(0, _)) =>
        Nil
      case Some(NumericConstant(1, _)) =>
        val constantLabel = MlCompiler.nextLabel("c8")
        val skipHiDigit = MlCompiler.nextLabel("ds")
        val skipLoDigit = MlCompiler.nextLabel("ds")
        // Without CMOS opcodes there is no BIT immediate, so the test reads the byte stored
        // at `constantLabel`, i.e. the PHP instruction emitted below.
        // Presumably this relies on the PHP opcode byte being $08 - TODO confirm.
        val bit = if (ctx.options.flags(CompilationFlag.EmitCmosOpcodes)) {
          AssemblyLine.immediate(BIT, 8)
        } else {
          AssemblyLine.absolute(BIT, Label(constantLabel))
        }
        List(
          if (rotate) AssemblyLine.implied(ROR) else AssemblyLine.implied(LSR),
          AssemblyLine.label(constantLabel),
          AssemblyLine.implied(PHP),
          // N still reflects bit 7 of the shifted value here.
          AssemblyLine.relative(BPL, skipHiDigit),
          AssemblyLine.implied(SEC),
          AssemblyLine.immediate(SBC, 0x30),
          AssemblyLine.label(skipHiDigit),
          bit,
          // BIT reports "accumulator AND 8" through the Z flag, not through N,
          // so the low-digit correction has to be skipped with BEQ.
          AssemblyLine.relative(BEQ, skipLoDigit),
          AssemblyLine.implied(SEC),
          AssemblyLine.immediate(SBC, 0x3),
          AssemblyLine.label(skipLoDigit),
          AssemblyLine.implied(PLP))
      case Some(NumericConstant(_, _)) =>
        // The amount is constant, just not one that this generator supports.
        ErrorReporting.error("Cannot shift right by a constant amount other than 0 or 1")
        Nil
      case _ =>
        ErrorReporting.error("Cannot shift by a non-constant amount")
        Nil
    }
  }
}

View File

@ -5,7 +5,7 @@ import millfork.assembly.{AddrMode, AssemblyLine, Opcode}
import millfork.compiler.{CompilationContext, MlCompiler}
import millfork.env._
import millfork.error.ErrorReporting
import millfork.node.CallGraph
import millfork.node.{CallGraph, Program}
import millfork.{CompilationFlag, CompilationOptions, Tarjan}
import scala.collection.mutable
@ -16,7 +16,7 @@ import scala.collection.mutable
case class AssemblerOutput(code: Array[Byte], asm: Array[String], labels: List[(String, Int)])
class Assembler(private val rootEnv: Environment) {
class Assembler(private val program: Program, private val rootEnv: Environment) {
var env = rootEnv.allThings
var unoptimizedCodeSize = 0
@ -143,11 +143,31 @@ class Assembler(private val rootEnv: Environment) {
val assembly = mutable.ArrayBuffer[String]()
val potentiallyInlineable: Map[String, Int] =
if (options.flags(CompilationFlag.InlineFunctions))
InliningCalculator.getPotentiallyInlineableFunctions(program)
else Map()
var inlinedFunctions = Map[String, List[AssemblyLine]]()
val compiledFunctions = mutable.Map[String, List[AssemblyLine]]()
callGraph.recommendedCompilationOrder.foreach{ f =>
env.maybeGet[NormalFunction](f).foreach( function =>
compiledFunctions(f) = compileFunction(function, optimizations, options)
)
env.maybeGet[NormalFunction](f).foreach{ function =>
val code = compileFunction(function, optimizations, options, inlinedFunctions)
val strippedCodeForInlining = for {
limit <- potentiallyInlineable.get(f)
if code.map(_.sizeInBytes).sum <= limit
s <- InliningCalculator.codeForInlining(f, code)
} yield s
strippedCodeForInlining match {
case Some(c) =>
ErrorReporting.debug("Inlining " + f, function.position)
inlinedFunctions += f -> c
compiledFunctions(f) = Nil
case None =>
compiledFunctions(f) = code
optimizedCodeSize += code.map(_.sizeInBytes).sum
}
}
}
env.allPreallocatables.foreach {
@ -249,14 +269,20 @@ class Assembler(private val rootEnv: Environment) {
AssemblerOutput(platform.outputPackager.packageOutput(mem, 0), assembly.toArray, labelMap.toList)
}
private def compileFunction(f: NormalFunction, optimizations: Seq[AssemblyOptimization], options: CompilationOptions) :List[AssemblyLine] = {
private def compileFunction(f: NormalFunction, optimizations: Seq[AssemblyOptimization], options: CompilationOptions, inlinedFunctions: Map[String, List[AssemblyLine]]) :List[AssemblyLine] = {
ErrorReporting.debug("Compiling: " + f.name, f.position)
val unoptimized = MlCompiler.compile(CompilationContext(env = f.environment, function = f, extraStackOffset = 0, options = options)).linearize
val unoptimized =
MlCompiler.compile(CompilationContext(env = f.environment, function = f, extraStackOffset = 0, options = options)).linearize.flatMap{
case AssemblyLine(Opcode.JSR, _, p, true) if inlinedFunctions.contains(p.toString) =>
inlinedFunctions(p.toString)
case AssemblyLine(Opcode.JMP, AddrMode.Absolute, p, true) if inlinedFunctions.contains(p.toString) =>
inlinedFunctions(p.toString) :+ AssemblyLine.implied(Opcode.RTS)
case x => List(x)
}
unoptimizedCodeSize += unoptimized.map(_.sizeInBytes).sum
val code = optimizations.foldLeft(unoptimized) { (c, opt) =>
opt.optimize(f, c, options)
}
optimizedCodeSize += code.map(_.sizeInBytes).sum
code
}

View File

@ -0,0 +1,75 @@
package millfork.output
import millfork.assembly.{AssemblyLine, Opcode, OpcodeClasses}
import millfork.assembly.Opcode._
import millfork.env._
import millfork.node._
import scala.collection.mutable
/**
  * Decides which functions may be inlined automatically and prepares their code for inlining.
  *
  * @author Karol Stasiak
  */
object InliningCalculator {

  // Maximum code size in bytes allowed for inlining, indexed by the number of direct calls
  // to the function; functions called more often than the table covers use the last entry.
  private val sizes = Seq(30, 30, 8, 6, 5, 5, 4)

  /**
    * Finds the functions that are candidates for automatic inlining.
    *
    * A declared function is rejected if it is marked inlined, has an explicit address,
    * is an interrupt handler, is reentrant, is called `main`, or if its name is referenced
    * anywhere other than as the target of a direct call.
    *
    * @param program the program whose declarations are scanned
    * @return a map from candidate function name to its size limit in bytes
    */
  def getPotentiallyInlineableFunctions(program: Program): Map[String, Int] = {
    val callCount = mutable.Map[String, Int]().withDefaultValue(0)
    val allFunctions = mutable.Set[String]()
    val badFunctions = mutable.Set[String]()
    getAllCalledFunctions(program.declarations).foreach {
      case (name, true) => badFunctions += name
      case (name, false) => callCount(name) += 1
    }
    program.declarations.foreach {
      case f: FunctionDeclarationStatement =>
        allFunctions += f.name
        val disqualified =
          f.inlined || f.address.isDefined || f.interrupt || f.reentrant || f.name == "main"
        if (disqualified) badFunctions += f.name
      case _ =>
    }
    allFunctions --= badFunctions
    allFunctions.map(f => f -> sizes(callCount(f) min (sizes.size - 1))).toMap
  }

  /**
    * Collects every name referenced by the given nodes.
    *
    * @return pairs of (name, flag); the flag is false for a direct call (a function call
    *         expression or an assembly JSR) and true for any other kind of reference
    *         (a variable expression, including its `.addr`/`.hi`/`.lo` forms, or an indexed array)
    */
  private def getAllCalledFunctions(expressions: List[Node]): List[(String, Boolean)] = expressions.flatMap {
    case s: VariableDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.initialValue.toList)
    case s: ArrayDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.elements.getOrElse(Nil))
    case s: FunctionDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.statements.getOrElse(Nil))
    // The target of a plain variable assignment is not counted as a reference.
    case Assignment(VariableExpression(_), expr) => getAllCalledFunctions(expr :: Nil)
    case AssemblyStatement(JSR, _, VariableExpression(name), true) => (name -> false) :: Nil
    case s: Statement => getAllCalledFunctions(s.getAllExpressions)
    case s: VariableExpression => Set(
      s.name,
      s.name.stripSuffix(".addr"),
      s.name.stripSuffix(".hi"),
      s.name.stripSuffix(".lo"),
      s.name.stripSuffix(".addr.lo"),
      s.name.stripSuffix(".addr.hi")).toList.map(_ -> true)
    case _: LiteralExpression => Nil
    case HalfWordExpression(param, _) => getAllCalledFunctions(param :: Nil)
    case SumExpression(xs, _) => getAllCalledFunctions(xs.map(_._2))
    case FunctionCallExpression(name, xs) => (name -> false) :: getAllCalledFunctions(xs)
    case IndexedExpression(arr, index) => (arr -> true) :: getAllCalledFunctions(List(index))
    case SeparateBytesExpression(h, l) => getAllCalledFunctions(List(h, l))
    case _ => Nil
  }

  // Opcodes that make a function body unsuitable for being pasted at its call sites.
  private val badOpcodes =
    Set(RTI, RTS, JSR, JMP, LABEL, BRK) ++
      OpcodeClasses.ShortBranching ++
      OpcodeClasses.ChangesStack

  /**
    * Strips compiled function code down to a body that can replace a call to the function.
    *
    * The code must end with RTS. That final RTS is dropped, and so is a leading label
    * equal to the function's own label. If any remaining line uses one of `badOpcodes`,
    * the function cannot be inlined.
    *
    * @param fname name of the function, used to recognise its entry label
    * @param code  the compiled code of the function
    * @return the body to inline (possibly empty), or `None` if the code is not inlineable
    */
  def codeForInlining(fname: String, code: List[AssemblyLine]): Option[List[AssemblyLine]] = {
    code.lastOption match {
      case Some(last) if last.opcode == RTS =>
        val withoutReturn = code.init
        // `withoutReturn` is empty when the function consists of a lone RTS,
        // so the first line has to be inspected without calling `head`.
        val body = withoutReturn.headOption match {
          case Some(first) if first.opcode == LABEL && first.parameter == Label(fname).toAddress =>
            withoutReturn.tail
          case _ =>
            withoutReturn
        }
        if (body.exists(l => badOpcodes(l.opcode))) None else Some(body)
      case _ =>
        // Empty code, or code that does not end with RTS.
        None
    }
  }
}

View File

@ -9,12 +9,17 @@ object EmuBenchmarkRun {
/**
  * Compiles and runs `source` with EmuUnoptimizedRun, EmuOptimizedRun and
  * EmuOptimizedInlinedRun, prints the timing of each run and the gain relative
  * to the unoptimized run, then applies `verifier` to the memory of each run in turn.
  *
  * @param source   the program source passed to each run's `apply2`
  * @param verifier check applied to the resulting memory bank of every run
  */
def apply(source: String)(verifier: MemoryBank => Unit) = {
// Only the first component of each Timings result is used.
val (Timings(t0, _), m0) = EmuUnoptimizedRun.apply2(source)
val (Timings(t1, _), m1) = EmuOptimizedRun.apply2(source)
val (Timings(t2, _), m2) = EmuOptimizedInlinedRun.apply2(source)
println(f"Before optimization: $t0%7d")
println(f"After optimization: $t1%7d")
println(f"After inlining: $t2%7d")
// Gains are percentages of the unoptimized timing t0, rounded to the nearest integer.
println(f"Gain: ${(100L * (t0 - t1) / t0.toDouble).round}%7d%%")
println(f"Gain with inlining: ${(100L * (t0 - t2) / t0.toDouble).round}%7d%%")
// The verifier is run against every variant, in the same order as the runs above.
println(f"Running unoptimized")
verifier(m0)
println(f"Running optimized")
verifier(m1)
println(f"Running optimized inlined")
verifier(m2)
}
}

View File

@ -0,0 +1,15 @@
package millfork.test.emu
import millfork.{Cpu, OptimizationPresets}
/**
  * Emulator run that uses the same optimization pipeline as written below
  * (node optimizations plus AssOpt followed by three rounds of Good)
  * with automatic function inlining switched on.
  *
  * @author Karol Stasiak
  */
object EmuOptimizedInlinedRun extends EmuRun(
  Cpu.StrictMos,
  OptimizationPresets.NodeOpt,
  OptimizationPresets.AssOpt ++ OptimizationPresets.Good ++ OptimizationPresets.Good ++ OptimizationPresets.Good,
  false) {
  // EmuRun.inline defaults to false; without this override the "inlined" run
  // would compile without CompilationFlag.InlineFunctions.
  override def inline = true
}

View File

@ -28,6 +28,8 @@ class EmuRun(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization],
def emitIllegals = false
def inline = false
private val timingNmos = Array[Int](
7, 6, 0, 8, 3, 3, 5, 5, 3, 2, 2, 2, 4, 4, 6, 6,
2, 5, 0, 8, 4, 4, 6, 6, 2, 4, 2, 7, 4, 4, 7, 7,
@ -91,6 +93,7 @@ class EmuRun(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization],
val options = new CompilationOptions(platform, Map(
CompilationFlag.EmitIllegals -> this.emitIllegals,
CompilationFlag.DetailedFlowAnalysis -> quantum,
CompilationFlag.InlineFunctions -> this.inline,
))
ErrorReporting.hasErrors = false
ErrorReporting.verbosity = 999
@ -107,42 +110,31 @@ class EmuRun(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization],
env.collectDeclarations(program, options)
val hasOptimizations = assemblyOptimizations.nonEmpty
var optimizedSize = 0L
var unoptimizedSize = 0L
// print asm
env.allPreallocatables.foreach {
case f: NormalFunction =>
val result = MlCompiler.compile(CompilationContext(f.environment, f, 0, options))
val unoptimized = result.linearize
if (hasOptimizations) {
val optimized = assemblyOptimizations.foldLeft(unoptimized) { (c, opt) =>
opt.optimize(f, c, options)
}
println("Unoptimized:")
unoptimized.filter(_.isPrintable).foreach(println(_))
println("Optimized:")
optimized.filter(_.isPrintable).foreach(println(_))
unoptimizedSize += unoptimized.map(_.sizeInBytes).sum
optimizedSize += optimized.map(_.sizeInBytes).sum
} else {
unoptimized.filter(_.isPrintable).foreach(println(_))
unoptimizedSize += unoptimized.map(_.sizeInBytes).sum
optimizedSize += unoptimized.map(_.sizeInBytes).sum
}
unoptimizedSize += unoptimized.map(_.sizeInBytes).sum
case d: InitializedArray =>
println(d.name)
d.contents.foreach(c => println(" !byte " + c))
unoptimizedSize += d.contents.length
optimizedSize += d.contents.length
case d: InitializedMemoryVariable =>
println(d.name)
0.until(d.typ.size).foreach(c => println(" !byte " + d.initialValue.subbyte(c)))
unoptimizedSize += d.typ.size
optimizedSize += d.typ.size
}
ErrorReporting.assertNoErrors("Compile failed")
// compile
val assembler = new Assembler(program, env)
val output = assembler.assemble(callGraph, assemblyOptimizations, options)
println(";;; compiled: -----------------")
output.asm.takeWhile(s => !(s.startsWith(".") && s.contains("= $"))).foreach(println)
println(";;; ---------------------------")
assembler.labelMap.foreach { case (l, addr) => println(f"$l%-15s $$$addr%04x") }
val optimizedSize = assembler.mem.banks(0).occupied.count(identity).toLong
if (unoptimizedSize == optimizedSize) {
println(f"Size: $unoptimizedSize%5d B")
} else {
@ -151,11 +143,6 @@ class EmuRun(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization],
println(f"Gain: ${(100L * (unoptimizedSize - optimizedSize) / unoptimizedSize.toDouble).round}%5d%%")
}
// compile
val assembler = new Assembler(env)
assembler.assemble(callGraph, assemblyOptimizations, options)
assembler.labelMap.foreach { case (l, addr) => println(f"$l%-15s $$$addr%04x") }
ErrorReporting.assertNoErrors("Code generation failed")
val memoryBank = assembler.mem.banks(0)

View File

@ -0,0 +1,18 @@
package millfork.test.emu
import millfork.assembly.opt.SuperOptimizer
import millfork.{Cpu, OptimizationPresets}
/**
  * Emulator run for a strict MOS CPU that applies the node optimizations and
  * the SuperOptimizer as its only assembly optimization, with automatic
  * function inlining enabled.
  *
  * NOTE(review): the final constructor argument is presumably EmuRun's
  * detailed-flow ("quantum") switch - confirm against EmuRun's signature.
  *
  * @author Karol Stasiak
  */
object EmuSuperQuantumOptimizedInliningRun extends EmuRun(
  Cpu.StrictMos,
  OptimizationPresets.NodeOpt,
  SuperOptimizer :: Nil,
  true) {

  // Turns on CompilationFlag.InlineFunctions for this run.
  override def inline: Boolean = true
}

View File

@ -6,7 +6,6 @@ import millfork.{Cpu, OptimizationPresets}
/**
* @author Karol Stasiak
*/
// TODO : it doesn't work
object EmuSuperOptimizedRun extends EmuRun(
Cpu.StrictMos,
OptimizationPresets.NodeOpt,

View File

@ -12,15 +12,18 @@ object EmuUltraBenchmarkRun {
val (Timings(t2, _), m2) = EmuSuperOptimizedRun.apply2(source)
val (Timings(t3, _), m3) = EmuQuantumOptimizedRun.apply2(source)
val (Timings(t4, _), m4) = EmuSuperQuantumOptimizedRun.apply2(source)
val (Timings(t5, _), m5) = EmuSuperQuantumOptimizedInliningRun.apply2(source)
println(f"Before optimization: $t0%7d")
println(f"After optimization: $t1%7d")
println(f"After superopt.: $t2%7d")
println(f"After quantum: $t3%7d")
println(f"After superquantum: $t4%7d")
println(f"After inlining: $t5%7d")
println(f"Gain: ${(100L*(t0-t1)/t0.toDouble).round}%7d%%")
println(f"Superopt. gain: ${(100L*(t0-t2)/t0.toDouble).round}%7d%%")
println(f"Quantum gain: ${(100L*(t0-t3)/t0.toDouble).round}%7d%%")
println(f"Super quantum gain: ${(100L*(t0-t4)/t0.toDouble).round}%7d%%")
println(f"SQ+inlining gain: ${(100L*(t0-t5)/t0.toDouble).round}%7d%%")
println(f"Running unoptimized")
verifier(m0)
println(f"Running optimized")
@ -31,5 +34,7 @@ object EmuUltraBenchmarkRun {
verifier(m3)
println(f"Running superquantum optimized")
verifier(m4)
println(f"Running superquantum optimized inlined")
verifier(m5)
}
}