From ff16854a114105f6de004b8d9232405608fdc0b7 Mon Sep 17 00:00:00 2001 From: Karol Stasiak Date: Mon, 6 Aug 2018 19:29:09 +0200 Subject: [PATCH] Code deduplication --- CHANGELOG.md | 2 + docs/abi/generated-labels.md | 1 + .../millfork/output/AbstractAssembler.scala | 70 +++-- .../millfork/output/CompiledFunction.scala | 22 ++ .../scala/millfork/output/Deduplicate.scala | 264 ++++++++++++++++++ .../scala/millfork/output/MosAssembler.scala | 3 + .../millfork/output/MosDeduplicate.scala | 66 +++++ .../scala/millfork/output/Z80Assembler.scala | 3 + .../millfork/output/Z80Deduplicate.scala | 75 +++++ .../millfork/test/DeduplicationSuite.scala | 72 +++++ .../millfork/test/emu/EmuOptimizedRun.scala | 22 ++ src/test/scala/millfork/test/emu/EmuRun.scala | 3 + .../EmuSizeOptimizedCrossPlatformRun.scala | 27 ++ .../scala/millfork/test/emu/EmuZ80Run.scala | 3 + 14 files changed, 596 insertions(+), 37 deletions(-) create mode 100644 src/main/scala/millfork/output/CompiledFunction.scala create mode 100644 src/main/scala/millfork/output/Deduplicate.scala create mode 100644 src/main/scala/millfork/output/MosDeduplicate.scala create mode 100644 src/main/scala/millfork/output/Z80Deduplicate.scala create mode 100644 src/test/scala/millfork/test/DeduplicationSuite.scala create mode 100644 src/test/scala/millfork/test/emu/EmuSizeOptimizedCrossPlatformRun.scala diff --git a/CHANGELOG.md b/CHANGELOG.md index dda51a30..ed5860bd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,8 @@ Code that uses a custom platform definitions will cause extra warnings until fix * Optimizations for stack variables. +* Various code deduplication optimizations. + * Fixed emitting constant decimal expressions. * Fixed decimal subtraction. 
diff --git a/docs/abi/generated-labels.md b/docs/abi/generated-labels.md index 59977d9a..07cf993a 100644 --- a/docs/abi/generated-labels.md +++ b/docs/abi/generated-labels.md @@ -58,4 +58,5 @@ where `11111` is a sequential number and `xx` is the type: * `wh` – beginning of a `while` statement +* `xc` – automatically extracted subroutine of commonly repeating code diff --git a/src/main/scala/millfork/output/AbstractAssembler.scala b/src/main/scala/millfork/output/AbstractAssembler.scala index e2447b4a..e7f5f418 100644 --- a/src/main/scala/millfork/output/AbstractAssembler.scala +++ b/src/main/scala/millfork/output/AbstractAssembler.scala @@ -153,6 +153,8 @@ abstract class AbstractAssembler[T <: AbstractCode](private val program: Program def bytePseudoopcode: String + def deduplicate(options: CompilationOptions, compiledFunctions: mutable.Map[String, CompiledFunction[T]]): Unit + def assemble(callGraph: CallGraph, optimizations: Seq[AssemblyOptimization[T]], options: CompilationOptions): AssemblerOutput = { mem.programName = options.outputFileName.getOrElse("MILLFORK") val platform = options.platform @@ -184,7 +186,7 @@ abstract class AbstractAssembler[T <: AbstractCode](private val program: Program env.allocateVariables(None, mem, callGraph, variableAllocators, options, labelMap.put, 3, forZpOnly = true) var inlinedFunctions = Map[String, List[T]]() - val compiledFunctions = mutable.Map[String, List[T]]() + val compiledFunctions = mutable.Map[String, CompiledFunction[T]]() val recommendedCompilationOrder = callGraph.recommendedCompilationOrder val niceFunctionProperties = mutable.Set[(NiceFunctionProperty, String)]() recommendedCompilationOrder.foreach { f => @@ -199,10 +201,10 @@ abstract class AbstractAssembler[T <: AbstractCode](private val program: Program case Some(c) => log.debug("Inlining " + f, function.position) inlinedFunctions += f -> c - compiledFunctions(f) = Nil + compiledFunctions(f) = NonexistentFunction() case None => nonInlineableFunctions += 
function.name - compiledFunctions(f) = code + compiledFunctions(f) = NormalCompiledFunction(function.declaredBank.getOrElse(platform.defaultCodeBank), code, function.address.isDefined) optimizedCodeSize += code.map(_.sizeInBytes).sum if (options.flag(CompilationFlag.InterproceduralOptimization)) { gatherNiceFunctionProperties(niceFunctionProperties, f, code) @@ -211,6 +213,7 @@ abstract class AbstractAssembler[T <: AbstractCode](private val program: Program function.environment.removedThings.foreach(env.removeVariable) } } +// deduplicate(options, compiledFunctions) if (log.traceEnabled) { niceFunctionProperties.toList.groupBy(_._2).mapValues(_.map(_._1).sortBy(_.toString)).toList.sortBy(_._1).foreach{ case (fname, properties) => log.trace(fname.padTo(30, ' ') + properties.mkString(" ")) @@ -253,51 +256,44 @@ abstract class AbstractAssembler[T <: AbstractCode](private val program: Program val bank = f.bank(options) val bank0 = mem.banks(bank) val index = f.address.get.asInstanceOf[NumericConstant].value.toInt - val code = compiledFunctions(f.name) - if (code.nonEmpty) { - labelMap(f.name) = index - val end = outputFunction(bank, code, index, assembly, options) - for(i <- index until end) { - bank0.occupied(index) = true - bank0.initialized(index) = true - bank0.readable(index) = true - } + compiledFunctions(f.name) match { + case NormalCompiledFunction(_, code, _) => + labelMap(f.name) = index + val end = outputFunction(bank, code, index, assembly, options) + for (i <- index until end) { /* flag every emitted byte, not only the first one */ + bank0.occupied(i) = true + bank0.initialized(i) = true + bank0.readable(i) = true + } + case NonexistentFunction() => throw new IllegalStateException() + case RedirectedFunction(_, _, _) => throw new IllegalStateException() } case _ => } val codeAllocators = platform.codeAllocators.mapValues(new VariableAllocator(Nil, _)) var justAfterCode = platform.codeAllocators.mapValues(a => a.startAt) - env.allPreallocatables.foreach { - case f: NormalFunction if 
f.address.isEmpty && f.name == "main" => - val bank = f.bank(options) - val code = compiledFunctions(f.name) - if (code.nonEmpty) { - val size = code.map(_.sizeInBytes).sum - val index = codeAllocators(bank).allocateBytes(mem.banks(bank), options, size, initialized = true, writeable = false, location = AllocationLocation.High) - labelMap(f.name) = index - justAfterCode += bank -> outputFunction(bank, code, index, assembly, options) - } - case _ => - } - env.allPreallocatables.foreach { - case f: NormalFunction if f.address.isEmpty && f.name != "main" => - val bank = f.bank(options) - val bank0 = mem.banks(bank) - val code = compiledFunctions(f.name) - if (code.nonEmpty) { - val size = code.map(_.sizeInBytes).sum - val index = codeAllocators(bank).allocateBytes(bank0, options, size, initialized = true, writeable = false, location = AllocationLocation.High) - labelMap(f.name) = index - justAfterCode += bank -> outputFunction(bank, code, index, assembly, options) - } - case _ => + compiledFunctions.toList.sortBy{case (name, cf) => if (name == "main") 0 -> "" else cf.orderKey}.foreach { + case (_, NormalCompiledFunction(_, _, true)) => + // already done before + case (name, NormalCompiledFunction(bank, code, false)) => + val size = code.map(_.sizeInBytes).sum + val index = codeAllocators(bank).allocateBytes(mem.banks(bank), options, size, initialized = true, writeable = false, location = AllocationLocation.High) + labelMap(name) = index + justAfterCode += bank -> outputFunction(bank, code, index, assembly, options) + case (_, NonexistentFunction()) => + case (name, RedirectedFunction(_, target, offset)) => + labelMap(name) = labelMap(target) + offset } + if (options.flag(CompilationFlag.LUnixRelocatableCode)) { env.allThings.things.foreach { case (_, m@UninitializedMemoryVariable(name, typ, _, _)) if name.endsWith(".addr") || env.maybeGet[Thing](name + ".array").isDefined => - val isUsed = compiledFunctions.values.exists(_.exists(_.parameter.isRelatedTo(m))) + val 
isUsed = compiledFunctions.values.exists{ + case NormalCompiledFunction(_, code, _) => code.exists(_.parameter.isRelatedTo(m)) + case _ => false + } // println(m.name -> isUsed) if (isUsed) { val bank = m.bank(options) diff --git a/src/main/scala/millfork/output/CompiledFunction.scala b/src/main/scala/millfork/output/CompiledFunction.scala new file mode 100644 index 00000000..1ca46185 --- /dev/null +++ b/src/main/scala/millfork/output/CompiledFunction.scala @@ -0,0 +1,22 @@ +package millfork.output + +import millfork.assembly.AbstractCode + +/** + * @author Karol Stasiak + */ +sealed trait CompiledFunction[T <: AbstractCode] { + def orderKey : (Int, String) +} + +case class NormalCompiledFunction[T <: AbstractCode](segment: String, code: List[T], hasFixedAddress: Boolean) extends CompiledFunction[T] { + override def orderKey: (Int, String) = (if (hasFixedAddress) 1 else 2) -> "" +} + +case class RedirectedFunction[T <: AbstractCode](segment: String, redirect: String, offset: Int) extends CompiledFunction[T] { + override def orderKey: (Int, String) = 3 -> redirect +} + +case class NonexistentFunction[T <: AbstractCode]() extends CompiledFunction[T] { + override def orderKey: (Int, String) = 4 -> "" +} diff --git a/src/main/scala/millfork/output/Deduplicate.scala b/src/main/scala/millfork/output/Deduplicate.scala new file mode 100644 index 00000000..7e19ed7c --- /dev/null +++ b/src/main/scala/millfork/output/Deduplicate.scala @@ -0,0 +1,264 @@ +package millfork.output + +import millfork.{CompilationFlag, CompilationOptions} +import millfork.assembly.AbstractCode +import millfork.env.Environment + +import scala.collection.mutable +import scala.collection.mutable.ListBuffer + +/** + * @author Karol Stasiak + */ +abstract class Deduplicate[T <: AbstractCode](env: Environment, options: CompilationOptions) { + + def apply(compiledFunctions: mutable.Map[String, CompiledFunction[T]]): Unit = { + if (options.flag(CompilationFlag.OptimizeForSize)) { + 
runStage(compiledFunctions, extractCommonCode) + } + runStage(compiledFunctions, deduplicateIdenticalFunctions) + runStage(compiledFunctions, eliminateTailJumps) + } + + def runStage(compiledFunctions: mutable.Map[String, CompiledFunction[T]], + function: (String, Map[String, Either[String, List[T]]]) => Seq[(String, CompiledFunction[T])]): Unit = { + bySegment(compiledFunctions).foreach { + case (segmentName, segContents) => + function(segmentName, segContents).foreach { + case (fname, cf) => compiledFunctions(fname) = cf + } + } + } + + def extractCommonCode(segmentName: String, segContents: Map[String, Either[String, List[T]]]): Seq[(String, CompiledFunction[T])] = { + var result = ListBuffer[(String, CompiledFunction[T])]() + val chunks = segContents.flatMap{ + case (_, Left(_)) => Nil + case (functionName, Right(code)) => + if (options.flag(CompilationFlag.OptimizeForSize)) { + getExtractableSnippets(functionName, code) + } else Nil + }.flatMap { chunk => + for { + start <- chunk.code.indices + end <- start + 1 to chunk.code.length + } yield CodeChunk(chunk.functionName, chunk.offset + start, chunk.offset + end)(chunk.code.slice(start, end)) + }.filter(_.codeSizeInBytes > 3).groupBy(_.code).filter{ + case (code, _) => + if (isBadExtractedCodeHead(code.head)) false + else if (isBadExtractedCodeLast(code.last)) false + else true + }.mapValues(_.toSeq).filter { + case (_, instances) => + val chunkSize = instances.head.codeSizeInBytes + val extractedProcedureSize = chunkSize + 1 + val savedInCallers = (chunkSize - 3) * instances.length + val maxPossibleProfit = savedInCallers - extractedProcedureSize + // (instances.length >=2) println(s"Instances: ${instances.length}, max profit: $maxPossibleProfit: $instances") + maxPossibleProfit > 0 && instances.length >= 2 // TODO + }.flatMap(_._2).toSeq + //println(s"Chunks: ${chunks.length} $chunks") + val candidates: Seq[(Int, Map[List[T], Seq[CodeChunk[T]]])] = powerSet(chunks)((set, chunk) => !set.exists(_ & 
chunk)).filter(_.nonEmpty).filter(set => (for { + x <- set + y <- set + if x != y + } yield x & y).forall(_ == false)).toSeq.map(_.groupBy(_.code).filter(_._2.size >= 2).mapValues(_.toSeq)).filter(_.nonEmpty).map { map => + map.foldLeft(0) { + (sum, entry) => + val chunkSize = entry._2.head.codeSizeInBytes + val chunkCount = entry._2.size + val extractedProcedureSize = chunkSize + 1 + val savedInCallers = (chunkSize - 3) * chunkCount + sum + savedInCallers - extractedProcedureSize + } -> map + }.filter { set => + val allChunks = set._2.values.flatten + (for { + x <- allChunks + y <- allChunks + if x != y + } yield x & y).forall(_ == false) + } +// candidates.sortBy(_._1).foreach { +// case (profit, map) => +// if (profit > 0) { +// println(s"Profit: $profit ${map.map { case (_, instances) => s"${instances.length}×${instances.head}" }.mkString(" ; ")}") +// } +// } + if (candidates.nonEmpty) { + val best = candidates.maxBy(_._1) + //println(s"Best extraction candidate: $best") + val allAffectedFunctions = best._2.values.flatten.map(_.functionName).toSet + val toRemove = allAffectedFunctions.map(_ -> mutable.Set[Int]()).toMap + val toReplace = allAffectedFunctions.map(_ -> mutable.Map[Int, String]()).toMap + if (options.log.traceEnabled){ + options.log.debug(s"Extracted ${best._2.size} common code subroutines from ${allAffectedFunctions.size} functions, saving ${best._1} bytes") /* braces needed: a bare $best._1 would print the whole tuple followed by "._1" */ + } + for((code, instances) <- best._2) { + val newName = env.nextLabel("xc") + result += newName -> NormalCompiledFunction(segmentName, createLabel(newName) :: tco(code :+ createReturn), hasFixedAddress = false) + for(instance <- instances) { + toReplace(instance.functionName)(instance.offset) = newName + for (i <- instance.offset + 1 until instance.endOffset) { + toRemove(instance.functionName) += i + } + } + } + for(functionName <- allAffectedFunctions) { + result += functionName -> { + val linesToRemove = toRemove(functionName) + val linesToReplace = toReplace(functionName) + val newCode = 
segContents(functionName).right.get.zipWithIndex.flatMap{ + case (line, i) => + if (linesToRemove(i)) None + else if (linesToReplace.contains(i)) Some(createCall(linesToReplace(i))) + else Some(line) + } + NormalCompiledFunction(segmentName, tco(newCode), hasFixedAddress = false) + } + } + + } + result.toSeq + } + + def deduplicateIdenticalFunctions(segmentName: String, segContents: Map[String, Either[String, List[T]]]): Seq[(String, CompiledFunction[T])] = { + var result = ListBuffer[(String, CompiledFunction[T])]() + val identicalFunctions = segContents.flatMap{ + case (name, code) => code.toOption.map(c => name -> actualCode(name, c)) + }.groupBy(_._2).values.toSeq.map(_.keySet).filter(set => set.size > 1) + for(set <- identicalFunctions) { + val representative = if (set("main")) "main" else set.head + options.log.debug(s"Functions [${set.mkString(",")}] are identical") + for (function <- set) { + if (function != representative) { + result += function -> RedirectedFunction(segmentName, representative, 0) + } else { + segContents(function) match { + case Right(code) => + result += function -> NormalCompiledFunction(segmentName, + set.toList.map(name => createLabel(name)) ++ actualCode(function, code), + hasFixedAddress = false) + case Left(_) => + } + } + } + } + result.toSeq + } + + private def follow(segContents: Map[String, Either[String, List[T]]], to: String): Option[String] = { + var result: String = to + val visited = mutable.Set[String]() + do { + segContents.get(result) match { /* inspect the current hop; looking up `to` again never advances along the redirect chain */ + case Some(Left(next)) => + if (visited(next)) return None + visited += result + result = next + case Some(Right(_)) => + return Some(result) + case _ => return None + } + } while(true) + None + } + + def eliminateTailJumps(segmentName: String, segContents: Map[String, Either[String, List[T]]]): Seq[(String, CompiledFunction[T])] = { + var result = ListBuffer[(String, CompiledFunction[T])]() + val fallThroughList = segContents.flatMap { + case (name, Right(code)) => + if 
(code.isEmpty) None + else getJump(code.last) + .filter(segContents.contains) + .filter(_ != name) + .filter(_ != "main") + .flatMap(to => follow(segContents, to)) + .map(name -> _) + case _ => None + } + val fallthroughPredecessors = fallThroughList.groupBy(_._2).mapValues(_.head._1) // TODO: be smarter than head + fallthroughPredecessors.foreach { + case (to, from) => + options.log.debug(s"Fallthrough from $from to $to") + val init = segContents(from).right.get.init + result += from -> NormalCompiledFunction(segmentName, + init ++ segContents(to).right.get, + hasFixedAddress = false + ) + val initSize = init.map(_.sizeInBytes).sum + if (initSize <= 2) { + result += to -> RedirectedFunction(segmentName, from, initSize) + } else { + result += to -> NonexistentFunction() + } + } + result.toSeq + } + + def tco(code: List[T]): List[T] + + def isBadExtractedCodeHead(head: T): Boolean + + def isBadExtractedCodeLast(head: T): Boolean + + def getJump(line: T): Option[String] + + def createCall(functionName: String): T + + def createReturn(): T + + def createLabel(name: String): T + + def bySegment(compiledFunctions: mutable.Map[String, CompiledFunction[T]]): Map[String, Map[String, Either[String, List[T]]]] = { + compiledFunctions.flatMap { + case (name, NormalCompiledFunction(segment, code, false)) => Some((segment, name, Right(code))) // TODO + case (name, RedirectedFunction(segment, target, 0)) => Some((segment, name, Left(target))) // TODO + case _ => None + }.groupBy(_._1).mapValues(_.map { case (_, name, code) => name -> code }.toMap) + } + + def actualCode(functionName: String, functionCode: List[T]): List[T] + + def isExtractable(line: T): Boolean + + def getExtractableSnippets(functionName: String, code: List[T]): List[CodeChunk[T]] = { + var cursor = 0 + var mutCode = code + val result = mutable.ListBuffer[CodeChunk[T]]() + while (true) { + val (bad, rest1) = mutCode.span(l => !isExtractable(l)) + mutCode = rest1 + cursor += bad.length + val (good, rest2) = 
mutCode.span(l => isExtractable(l)) + mutCode = rest2 + if (good.nonEmpty) { + result += CodeChunk(functionName, cursor, cursor + good.length)(good) + cursor += good.length + } else { + //println(s"Snippets in $functionName: $result") + return result.toList + } + } + null + } + + def powerSet[A](t: Iterable[A])(f: (Set[A], A) => Boolean): Set[Set[A]] = { + @annotation.tailrec + def pwr(t: Iterable[A], ps: Set[Set[A]]): Set[Set[A]] = + if (t.isEmpty) ps + else pwr(t.tail, ps ++ (ps.filter(p => f(p, t.head)) map (_ + t.head))) + pwr(t, Set(Set.empty[A])) + } +} + +case class CodeChunk[T <: AbstractCode](functionName: String, offset: Int, endOffset: Int)(val code: List[T]) { + val codeSizeInBytes: Int = code.map(_.sizeInBytes).sum + + def &(that: CodeChunk[T]): Boolean = + this.functionName == that.functionName && + this.offset <= that.endOffset && + that.offset <= this.endOffset + + override def toString: String = s"$functionName:$offset:${code.map(_.toString.trim).mkString(";")}($codeSizeInBytes bytes)" +} diff --git a/src/main/scala/millfork/output/MosAssembler.scala b/src/main/scala/millfork/output/MosAssembler.scala index eb801474..b063bafb 100644 --- a/src/main/scala/millfork/output/MosAssembler.scala +++ b/src/main/scala/millfork/output/MosAssembler.scala @@ -166,6 +166,9 @@ class MosAssembler(program: Program, } override def bytePseudoopcode: String = "!byte" + + override def deduplicate(options: CompilationOptions, compiledFunctions: mutable.Map[String, CompiledFunction[AssemblyLine]]): Unit = + new MosDeduplicate(rootEnv, options).apply(compiledFunctions) } diff --git a/src/main/scala/millfork/output/MosDeduplicate.scala b/src/main/scala/millfork/output/MosDeduplicate.scala new file mode 100644 index 00000000..3bdbd6e7 --- /dev/null +++ b/src/main/scala/millfork/output/MosDeduplicate.scala @@ -0,0 +1,66 @@ +package millfork.output + +import millfork.CompilationOptions +import millfork.assembly.mos.{AddrMode, AssemblyLine, Opcode} +import 
millfork.env.{Environment, Label, MemoryAddressConstant} +import Opcode._ +import millfork.assembly.mos.AddrMode._ + +/** + * @author Karol Stasiak + */ +class MosDeduplicate(env: Environment, options: CompilationOptions) extends Deduplicate[AssemblyLine](env, options) { + override def getJump(line: AssemblyLine): Option[String] = line match { + case AssemblyLine(Opcode.JMP, Absolute, MemoryAddressConstant(thing), _) => Some(thing.name) + case _ => None + } + + override def createLabel(name: String): AssemblyLine = AssemblyLine.label(name) + + override def actualCode(FunctionName: String, functionCode: List[AssemblyLine]): List[AssemblyLine] = { + functionCode match { + case AssemblyLine(Opcode.LABEL, _, MemoryAddressConstant(Label(FunctionName)), _) :: xs => xs + case xs => xs + } + } + + private val goodOpcodes = Set( + ADC, SBC, CMP, AND, EOR, ORA, + ADC_W, SBC_W, CMP_W, AND_W, EOR_W, ORA_W, + ASL, ROL, LSR, ROR, INC, DEC, + ASL_W, ROL_W, LSR_W, ROR_W, INC_W, DEC_W, + NEG, ASR, + LDA, STA, LDX, STX, LDY, STY, LDZ, STZ, + LDA_W, STA_W, LDX_W, STX_W, LDY_W, STY_W, STZ_W, + TAX, TXA, TAY, TYA, TXY, TYX, TAZ, TZA, XBA, + SLO, SRE, RRA, RLA, ARR, ALR, ANC, SBX, LXA, XAA, DCP, ISC, + CPX, CPY, CPZ, CPX_W, CPY_W, + INX, INY, INZ, INX_W, INY_W, + DEX, DEY, DEZ, DEX_W, DEY_W, + BIT, TRB, TSB, + JSR, + NOP, WAI, STP, + SED, CLD, SEC, CLC, CLV, SEI, CLI, SEP, REP, + HuSAX, SAY, SXY, + CLA, CLX, CLY, + ) + + private val badAddressingModes = Set(Stack, IndexedSY, Relative) + + override def isExtractable(line: AssemblyLine): Boolean = + goodOpcodes(line.opcode) && !badAddressingModes(line.addrMode) + + override def isBadExtractedCodeHead(head: AssemblyLine): Boolean = false + + override def isBadExtractedCodeLast(head: AssemblyLine): Boolean = false + + override def createCall(functionName: String): AssemblyLine = AssemblyLine.absolute(Opcode.JSR, Label(functionName)) + + override def createReturn(): AssemblyLine = AssemblyLine.implied(RTS) + + override def tco(code: 
List[AssemblyLine]): List[AssemblyLine] = code match { + case (call@AssemblyLine(JSR, Absolute, _, _)) :: AssemblyLine(RTS, _, _, _) :: xs => call.copy(opcode = JMP) :: tco(xs) + case x :: xs => x :: tco(xs) + case Nil => Nil + } +} diff --git a/src/main/scala/millfork/output/Z80Assembler.scala b/src/main/scala/millfork/output/Z80Assembler.scala index cf4c3490..46ba876f 100644 --- a/src/main/scala/millfork/output/Z80Assembler.scala +++ b/src/main/scala/millfork/output/Z80Assembler.scala @@ -635,6 +635,9 @@ class Z80Assembler(program: Program, } override def bytePseudoopcode: String = "DB" + + override def deduplicate(options: CompilationOptions, compiledFunctions: mutable.Map[String, CompiledFunction[ZLine]]): Unit = + new Z80Deduplicate(rootEnv, options).apply(compiledFunctions) } object Z80Assembler { diff --git a/src/main/scala/millfork/output/Z80Deduplicate.scala b/src/main/scala/millfork/output/Z80Deduplicate.scala new file mode 100644 index 00000000..c0f2dace --- /dev/null +++ b/src/main/scala/millfork/output/Z80Deduplicate.scala @@ -0,0 +1,75 @@ +package millfork.output + +import millfork.CompilationOptions +import millfork.assembly.z80.{ZOpcode, _} +import millfork.env.{Environment, Label, MemoryAddressConstant} +import ZOpcode._ +import millfork.node.ZRegister.SP + +/** + * @author Karol Stasiak + */ +class Z80Deduplicate(env: Environment, options: CompilationOptions) extends Deduplicate[ZLine](env, options) { + override def getJump(line: ZLine): Option[String] = line match { + case ZLine(JP, NoRegisters, MemoryAddressConstant(thing), _) => Some(thing.name) + case _ => None + } + + override def createLabel(name: String): ZLine = ZLine.label(name) + + override def actualCode(FunctionName: String, functionCode: List[ZLine]): List[ZLine] = { + functionCode match { + case ZLine(LABEL, _, MemoryAddressConstant(Label(FunctionName)), _) :: xs => xs + case xs => xs + } + } + + private val alwaysGoodOpcodes: Set[ZOpcode.Value] = Set( + ADD, ADC, SUB, SBC, XOR, OR, 
AND, CP, + LD, INC, DEC, + DAA, CPL, SCF, CCF, NEG, EX_DE_HL, + RLA, RRA, RLCA, RRCA, + RL, RR, RLC, RRC, SLA, SLL, SRL, SRA, SWAP, + RLD, RRD, + EI, DI, IM, HALT, NOP, + LDI, LDD, LDIR, LDDR, + INI, IND, INIR, INDR, + CPI, CPD, CPIR, CPDR, + OUTI, OUTD, OUTIR, OUTDR, + IN_IMM, OUT_IMM, IN_C, OUT_C, + LD_AHLI, LD_AHLD, LD_HLIA, LD_HLDA, + LDH_AC, LDH_AD, LDH_CA, LDH_DA, + CALL, + ) ++ ZOpcodeClasses.AllSingleBit + + private val conditionallyGoodOpcodes = Set( + LD_16, ADD_16, SBC_16, ADC_16, INC_16, DEC_16, + ) + + override def isExtractable(line: ZLine): Boolean = { + alwaysGoodOpcodes(line.opcode) || + conditionallyGoodOpcodes(line.opcode) && (line.registers match { + case OneRegister(SP) => false + case TwoRegisters(_, SP) => false + case TwoRegisters(SP, _) => false + case _ => true + }) + } + + override def isBadExtractedCodeHead(head: ZLine): Boolean = false + + override def isBadExtractedCodeLast(head: ZLine): Boolean = head.opcode match { + case EI | DI | IM => true + case _ => false + } + + override def createCall(functionName: String): ZLine = ZLine(CALL, NoRegisters, MemoryAddressConstant(Label(functionName)), elidable = false) + + override def createReturn(): ZLine = ZLine.implied(RET) + + override def tco(code: List[ZLine]): List[ZLine] = code match { + case (call@ZLine(CALL, _, _, _)) :: ZLine(RET, NoRegisters, _, _) :: xs => call.copy(opcode = JP) :: tco(xs) + case x :: xs => x :: tco(xs) + case Nil => Nil + } +} diff --git a/src/test/scala/millfork/test/DeduplicationSuite.scala b/src/test/scala/millfork/test/DeduplicationSuite.scala new file mode 100644 index 00000000..137a9721 --- /dev/null +++ b/src/test/scala/millfork/test/DeduplicationSuite.scala @@ -0,0 +1,72 @@ +package millfork.test + +import millfork.Cpu +import millfork.test.emu._ +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class DeduplicationSuite extends FunSuite with Matchers { + + test("Code deduplication") { + EmuOptimizedCmosRun( + """ + | + | 
void main() { + | times2(1) + | shift_left(2) + | nothing(2) + | } + | noinline byte shift_left(byte x) { + | return x << 1 + | } + | noinline byte times2(byte x) { + | x *= 2 + | return x + | } + | noinline void nothing(byte x) { + | } + """.stripMargin) + } + + test("Subroutine extraction") { + EmuSizeOptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Intel8080)( + """ + | int24 output @$c000 + | void main() { + | output.b0 = f(1) + | output.b1 = g(2) + | output.b2 = h(2) + | } + | noinline byte f(byte x) { + | x += 2 + | x |= 4 + | x <<= 1 + | x &= 7 + | x -= 6 + | return x + | } + | noinline byte g(byte x) { + | x += 3 + | x |= 4 + | x <<= 1 + | x &= 7 + | x -= 7 + | return x + | } + | noinline byte h(byte x) { + | x += 5 + | x |= 4 + | x <<= 1 + | x &= 7 + | x -= 5 + | return x + | } + | noinline void nothing(byte x) { + | } + """.stripMargin) {m => + m.readMedium(0xc000) should equal(0x1FB00) + } + } +} diff --git a/src/test/scala/millfork/test/emu/EmuOptimizedRun.scala b/src/test/scala/millfork/test/emu/EmuOptimizedRun.scala index c7b69727..9eea2631 100644 --- a/src/test/scala/millfork/test/emu/EmuOptimizedRun.scala +++ b/src/test/scala/millfork/test/emu/EmuOptimizedRun.scala @@ -20,10 +20,32 @@ object EmuOptimizedRun extends EmuRun( OptimizationPresets.Good ++ OptimizationPresets.Good) +object EmuSizeOptimizedRun extends EmuRun( + Cpu.StrictMos, + OptimizationPresets.NodeOpt, + OptimizationPresets.AssOpt ++ + ZeropageRegisterOptimizations.All ++ + OptimizationPresets.Good ++ + OptimizationPresets.Good ++ + OptimizationPresets.Good ++ LaterOptimizations.Nmos ++ + OptimizationPresets.Good ++ LaterOptimizations.Nmos ++ + ZeropageRegisterOptimizations.All ++ + OptimizationPresets.Good ++ + OptimizationPresets.Good) { + override def optimizeForSize = true +} object EmuOptimizedZ80Run extends EmuZ80Run(Cpu.Z80, OptimizationPresets.NodeOpt, Z80OptimizationPresets.GoodForZ80) +object EmuSizeOptimizedZ80Run extends EmuZ80Run(Cpu.Z80, OptimizationPresets.NodeOpt, 
Z80OptimizationPresets.GoodForZ80) { + override def optimizeForSize = true +} + object EmuOptimizedIntel8080Run extends EmuZ80Run(Cpu.Intel8080, OptimizationPresets.NodeOpt, Z80OptimizationPresets.GoodForIntel8080) +object EmuSizeOptimizedIntel8080Run extends EmuZ80Run(Cpu.Intel8080, OptimizationPresets.NodeOpt, Z80OptimizationPresets.GoodForIntel8080) { + override def optimizeForSize = true +} + object EmuOptimizedSharpRun extends EmuZ80Run(Cpu.Sharp, OptimizationPresets.NodeOpt, Z80OptimizationPresets.GoodForSharp) diff --git a/src/test/scala/millfork/test/emu/EmuRun.scala b/src/test/scala/millfork/test/emu/EmuRun.scala index 9fd25a99..9deec760 100644 --- a/src/test/scala/millfork/test/emu/EmuRun.scala +++ b/src/test/scala/millfork/test/emu/EmuRun.scala @@ -61,6 +61,8 @@ class EmuRun(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization], def blastProcessing = false + def optimizeForSize = false + private val timingNmos = Array[Int]( 7, 6, 0, 8, 3, 3, 5, 5, 3, 2, 2, 2, 4, 4, 6, 6, 2, 5, 0, 8, 4, 4, 6, 6, 2, 4, 2, 7, 4, 4, 7, 7, @@ -133,6 +135,7 @@ class EmuRun(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization], CompilationFlag.EmitEmulation65816Opcodes -> (platform.cpu == millfork.Cpu.Sixteen), CompilationFlag.Emit65CE02Opcodes -> (platform.cpu == millfork.Cpu.CE02), CompilationFlag.EmitHudsonOpcodes -> (platform.cpu == millfork.Cpu.HuC6280), + CompilationFlag.OptimizeForSize -> optimizeForSize, CompilationFlag.OptimizeForSpeed -> blastProcessing, CompilationFlag.OptimizeForSonicSpeed -> blastProcessing // CompilationFlag.CheckIndexOutOfBounds -> true, diff --git a/src/test/scala/millfork/test/emu/EmuSizeOptimizedCrossPlatformRun.scala b/src/test/scala/millfork/test/emu/EmuSizeOptimizedCrossPlatformRun.scala new file mode 100644 index 00000000..53aa4506 --- /dev/null +++ b/src/test/scala/millfork/test/emu/EmuSizeOptimizedCrossPlatformRun.scala @@ -0,0 +1,27 @@ +package millfork.test.emu + +import millfork.Cpu +import 
millfork.output.MemoryBank + +/** + * @author Karol Stasiak + */ +object EmuSizeOptimizedCrossPlatformRun { + def apply(platforms: Cpu.Value*)(source: String)(verifier: MemoryBank => Unit): Unit = { + val (_, mm) = if (platforms.contains(Cpu.Mos)) EmuSizeOptimizedRun.apply2(source) else Timings(-1, -1) -> null + val (_, mz) = if (platforms.contains(Cpu.Z80)) EmuSizeOptimizedZ80Run.apply2(source) else Timings(-1, -1) -> null + val (_, mi) = if (platforms.contains(Cpu.Intel8080)) EmuSizeOptimizedIntel8080Run.apply2(source) else Timings(-1, -1) -> null + if (platforms.contains(Cpu.Mos)) { + println(f"Running 6502") + verifier(mm) + } + if (platforms.contains(Cpu.Z80)) { + println(f"Running Z80") + verifier(mz) + } + if (platforms.contains(Cpu.Intel8080)) { + println(f"Running 8080") + verifier(mi) + } + } +} diff --git a/src/test/scala/millfork/test/emu/EmuZ80Run.scala b/src/test/scala/millfork/test/emu/EmuZ80Run.scala index e9dca1b1..1e23cec7 100644 --- a/src/test/scala/millfork/test/emu/EmuZ80Run.scala +++ b/src/test/scala/millfork/test/emu/EmuZ80Run.scala @@ -24,6 +24,8 @@ import org.scalatest.Matchers class EmuZ80Run(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization], assemblyOptimizations: List[AssemblyOptimization[ZLine]]) extends Matchers { def inline: Boolean = false + def optimizeForSize: Boolean = false + private val TooManyCycles: Long = 1500000 def apply(source: String): MemoryBank = { @@ -38,6 +40,7 @@ class EmuZ80Run(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimizatio val platform = EmuPlatform.get(cpu) val extraFlags = Map( CompilationFlag.InlineFunctions -> this.inline, + CompilationFlag.OptimizeForSize -> this.optimizeForSize, CompilationFlag.EmitIllegals -> (cpu == millfork.Cpu.Z80), CompilationFlag.LenientTextEncoding -> true) val options = CompilationOptions(platform, millfork.Cpu.defaultFlags(cpu).map(_ -> true).toMap ++ extraFlags, None, 0, JobContext(log, new LabelGenerator))