diff --git a/docs/abi/generated-labels.md b/docs/abi/generated-labels.md index 186635c5..0ef173e8 100644 --- a/docs/abi/generated-labels.md +++ b/docs/abi/generated-labels.md @@ -20,6 +20,8 @@ where `11111` is a sequential number and `xx` is the type: * `cp` – equality comparison for larger types +* `dd` – labels renamed by code deduplication + * `de` – decrement for larger types * `do` – beginning of a `do-while` statement diff --git a/src/main/scala/millfork/assembly/z80/ZLine.scala b/src/main/scala/millfork/assembly/z80/ZLine.scala index 8ebbdcd2..1e729e4e 100644 --- a/src/main/scala/millfork/assembly/z80/ZLine.scala +++ b/src/main/scala/millfork/assembly/z80/ZLine.scala @@ -178,7 +178,8 @@ case class ZLine(opcode: ZOpcode.Value, registers: ZRegisters, parameter: Consta import ZRegister._ val inherent = opcode match { case BYTE => 1 - case d if ZOpcodeClasses.NoopDiscards(d) => 0 + case LABEL => return 0 + case d if ZOpcodeClasses.NoopDiscards(d) => return 0 case JP => registers match { case OneRegister(HL | IX | IY) => 1 case _ => 2 diff --git a/src/main/scala/millfork/output/AbstractAssembler.scala b/src/main/scala/millfork/output/AbstractAssembler.scala index 6d01cbbc..57c71cc4 100644 --- a/src/main/scala/millfork/output/AbstractAssembler.scala +++ b/src/main/scala/millfork/output/AbstractAssembler.scala @@ -274,7 +274,8 @@ abstract class AbstractAssembler[T <: AbstractCode](private val program: Program val codeAllocators = platform.codeAllocators.mapValues(new VariableAllocator(Nil, _)) var justAfterCode = platform.codeAllocators.mapValues(a => a.startAt) - compiledFunctions.toList.sortBy{case (name, cf) => if (name == "main") 0 -> "" else cf.orderKey}.foreach { + val sortedCompilerFunctions = compiledFunctions.toList.sortBy { case (name, cf) => if (name == "main") 0 -> "" else cf.orderKey } + sortedCompilerFunctions.foreach { case (_, NormalCompiledFunction(_, _, true, _)) => // already done before case (name, NormalCompiledFunction(bank, code, false, alignment)) => @@ -282,9 +283,12 @@ abstract class AbstractAssembler[T <: AbstractCode](private val program: Program val index = codeAllocators(bank).allocateBytes(mem.banks(bank), options, size, initialized = true, writeable = false, location = AllocationLocation.High, alignment = alignment) labelMap(name) = index justAfterCode += bank -> outputFunction(bank, code, index, assembly, options) - case (_, NonexistentFunction()) => + case _ => + } + sortedCompilerFunctions.foreach { case (name, RedirectedFunction(_, target, offset)) => labelMap(name) = labelMap(target) + offset + case _ => } if (options.flag(CompilationFlag.LUnixRelocatableCode)) { diff --git a/src/main/scala/millfork/output/Deduplicate.scala b/src/main/scala/millfork/output/Deduplicate.scala index f8b9b976..d530585e 100644 --- a/src/main/scala/millfork/output/Deduplicate.scala +++ b/src/main/scala/millfork/output/Deduplicate.scala @@ -18,6 +18,14 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila } runStage(compiledFunctions, deduplicateIdenticalFunctions) runStage(compiledFunctions, eliminateTailJumps) + runStage(compiledFunctions, eliminateRemainingTrivialTailJumps) + fixDoubleRedirects(compiledFunctions) +// println(compiledFunctions.map { +// case (k, v) => k + " " + (v match { +// case _: NormalCompiledFunction[_] => "NormalCompiledFunction" +// case _ => v.toString +// }) +// }.mkString(" ; ")) } def runStage(compiledFunctions: mutable.Map[String, CompiledFunction[T]], @@ -32,18 +40,21 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila def extractCommonCode(segmentName: String, segContents: Map[String, Either[String, CodeAndAlignment[T]]]): Seq[(String, CompiledFunction[T])] = { var result = ListBuffer[(String, CompiledFunction[T])]() - val chunks = segContents.flatMap{ + val snippets: Seq[(List[T], CodeChunk[T])] = segContents.toSeq.flatMap { case (_, Left(_)) => Nil case (functionName, Right(CodeAndAlignment(code, _))) => if (options.flag(CompilationFlag.OptimizeForSize)) { - getExtractableSnippets(functionName, code) + getExtractableSnippets(functionName, code).map(code -> _) } else Nil - }.flatMap { chunk => - for { - start <- chunk.code.indices - end <- start + 1 to chunk.code.length - } yield CodeChunk(chunk.functionName, chunk.offset + start, chunk.offset + end)(chunk.code.slice(start, end)) - }.filter(_.codeSizeInBytes > 3).groupBy(_.code).filter{ + } + val chunks: Seq[CodeChunk[T]] = snippets.flatMap { case (wholeCode, snippet) => + for { + start <- snippet.code.indices + end <- start + 1 to snippet.code.length + } yield wholeCode -> CodeChunk(snippet.functionName, snippet.offset + start, snippet.offset + end)(snippet.code.slice(start, end)) + }.map(_._2).filter(_.codeSizeInBytes > 3).groupBy { chunk => + renumerateLabels(chunk.code, temporary = true) + }.filter { case (code, _) => if (isBadExtractedCodeHead(code.head)) false else if (isBadExtractedCodeLast(code.last)) false @@ -62,7 +73,7 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila x <- set y <- set if x != y - } yield x & y).forall(_ == false)).toSeq.map(_.groupBy(_.code).filter(_._2.size >= 2).mapValues(_.toSeq)).filter(_.nonEmpty).map { map => + } yield x & y).forall(_ == false)).toSeq.map(_.groupBy(chunk => renumerateLabels(chunk.code, temporary = true)).filter(_._2.size >= 2).mapValues(_.toSeq)).filter(_.nonEmpty).map { map => map.foldLeft(0) { (sum, entry) => val chunkSize = entry._2.head.codeSizeInBytes @@ -96,7 +107,8 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila } for((code, instances) <- best._2) { val newName = env.nextLabel("xc") - result += newName -> NormalCompiledFunction(segmentName, createLabel(newName) :: tco(code :+ createReturn), hasFixedAddress = false, alignment = NoAlignment) + val newCode = createLabel(newName) :: tco(renumerateLabels(instances.head.code, temporary = false) :+ createReturn) + result += newName -> NormalCompiledFunction(segmentName, newCode, hasFixedAddress = false, alignment = NoAlignment) for(instance <- instances) { toReplace(instance.functionName)(instance.offset) = newName for (i <- instance.offset + 1 until instance.endOffset) { @@ -127,7 +139,11 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila var result = ListBuffer[(String, CompiledFunction[T])]() val identicalFunctions = segContents.flatMap{ case (name, code) => code.toOption.map(c => name -> actualCode(name, c.code)) - }.groupBy(_._2).values.toSeq.map(_.keySet).filter(set => set.size > 1) + }.filter{ + case (_, code) => checkIfLabelsAreInternal(code, code) + }.groupBy{ + case (_, code) => renumerateLabels(code, temporary = true) + }.values.toSeq.map(_.keySet).filter(set => set.size > 1) for(set <- identicalFunctions) { val representative = if (set("main")) "main" else set.head options.log.debug(s"Functions [${set.mkString(",")}] are identical") @@ -153,7 +169,7 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila var result: String = to val visited = mutable.Set[String]() do { - segContents.get(to) match { + segContents.get(result) match { case Some(Left(next)) => if (visited(next)) return None visited += result @@ -200,11 +216,67 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila result.toSeq } + def eliminateRemainingTrivialTailJumps(segmentName: String, segContents: Map[String, Either[String, CodeAndAlignment[T]]]): Seq[(String, CompiledFunction[T])] = { + var result = ListBuffer[(String, CompiledFunction[T])]() + val fallThroughList = segContents.flatMap { + case (name, Right(CodeAndAlignment(code, alignment))) => + if (code.length != 2) None + else getJump(code.last) + .filter(segContents.contains) + .filter(_ != name) + .filter(_ != "main") + .map(name -> _) + case _ => None + } + val fallthroughPredecessors = fallThroughList.groupBy(_._2).mapValues(_.keySet) + fallthroughPredecessors.foreach { + case (to, froms) => + for (from <- froms) { + options.log.debug(s"Trivial fallthrough from $from to $to") + result += from -> RedirectedFunction(segmentName, to, 0) + follow(segContents, to) match { + case Some(actualTo) => + options.log.trace(s"which physically is $actualTo") + val value = result.find(_._1 == actualTo).fold(segContents(actualTo).right.get){ + case (_, NormalCompiledFunction(_, code, _, alignment)) => CodeAndAlignment(code, alignment) + } + result += actualTo -> NormalCompiledFunction(segmentName, + createLabel(from) :: value.code, + hasFixedAddress = false, + alignment = value.alignment + ) + case _ => + } + } + } + result.toSeq + } + + def fixDoubleRedirects(compiledFunctions: mutable.Map[String, CompiledFunction[T]]): Unit = { + var changed = true + while (changed) { + changed = false + val functionNames = compiledFunctions.keys.toSeq + for (name <- functionNames) { + compiledFunctions(name) match { + case RedirectedFunction(_, redirect, offset1) => + compiledFunctions.get(redirect) match { + case Some(r: RedirectedFunction[T]) => + compiledFunctions(name) = r.copy(offset = r.offset + offset1) + changed = true + case _ => + } + case _ => + } + } + } + } + def tco(code: List[T]): List[T] def isBadExtractedCodeHead(head: T): Boolean - def isBadExtractedCodeLast(head: T): Boolean + def isBadExtractedCodeLast(last: T): Boolean def getJump(line: T): Option[String] @@ -214,6 +286,10 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila def createLabel(name: String): T + def renumerateLabels(code: List[T], temporary: Boolean): List[T] + + def checkIfLabelsAreInternal(snippet: List[T], code: List[T]): Boolean + def bySegment(compiledFunctions: mutable.Map[String, CompiledFunction[T]]): Map[String, Map[String, Either[String, CodeAndAlignment[T]]]] = { compiledFunctions.flatMap { case (name, NormalCompiledFunction(segment, code, false, alignment)) => Some((segment, name, Right(CodeAndAlignment(code, alignment)))) // TODO diff --git a/src/main/scala/millfork/output/MosDeduplicate.scala b/src/main/scala/millfork/output/MosDeduplicate.scala index 3bdbd6e7..e920e123 100644 --- a/src/main/scala/millfork/output/MosDeduplicate.scala +++ b/src/main/scala/millfork/output/MosDeduplicate.scala @@ -1,11 +1,13 @@ package millfork.output import millfork.CompilationOptions -import millfork.assembly.mos.{AddrMode, AssemblyLine, Opcode} +import millfork.assembly.mos.{AddrMode, AssemblyLine, Opcode, OpcodeClasses} import millfork.env.{Environment, Label, MemoryAddressConstant} import Opcode._ import millfork.assembly.mos.AddrMode._ +import scala.collection.mutable + /** * @author Karol Stasiak */ @@ -43,24 +45,71 @@ class MosDeduplicate(env: Environment, options: CompilationOptions) extends Dedu SED, CLD, SEC, CLC, CLV, SEI, CLI, SEP, REP, HuSAX, SAY, SXY, CLA, CLX, CLY, + JMP, BRA, BEQ, BNE, BMI, BCC, BCS, BVC, BVS, LABEL, ) - private val badAddressingModes = Set(Stack, IndexedSY, Relative) + private val badAddressingModes = Set(Stack, IndexedSY, AbsoluteIndexedX, Indirect, LongIndirect) override def isExtractable(line: AssemblyLine): Boolean = goodOpcodes(line.opcode) && !badAddressingModes(line.addrMode) override def isBadExtractedCodeHead(head: AssemblyLine): Boolean = false - override def isBadExtractedCodeLast(head: AssemblyLine): Boolean = false + override def isBadExtractedCodeLast(last: AssemblyLine): Boolean = false override def createCall(functionName: String): AssemblyLine = AssemblyLine.absolute(Opcode.JSR, Label(functionName)) override def createReturn(): AssemblyLine = AssemblyLine.implied(RTS) + def rtsPrecededByDiscards(xs: List[AssemblyLine]): Option[List[AssemblyLine]] = { + xs match { + case AssemblyLine(op, _, _, _) :: xs if OpcodeClasses.NoopDiscardsFlags(op) => rtsPrecededByDiscards(xs) + case AssemblyLine(RTS, _, _, _) :: xs => Some(xs) + case _ => None + } + } + override def tco(code: List[AssemblyLine]): List[AssemblyLine] = code match { - case (call@AssemblyLine(JSR, Absolute, _, _)) :: AssemblyLine(RTS, _, _, _) :: xs => call.copy(opcode = JMP) :: tco(xs) + case (call@AssemblyLine(JSR, Absolute | LongAbsolute, _, _)) :: xs => rtsPrecededByDiscards(xs) match { + case Some(rest) => call.copy(opcode = JMP) :: tco(rest) + case _ => call :: tco(xs) + } case x :: xs => x :: tco(xs) case Nil => Nil } + + override def renumerateLabels(code: List[AssemblyLine], temporary: Boolean): List[AssemblyLine] = { + val map = mutable.Map[String, String]() + var counter = 0 + code.foreach{ + case AssemblyLine(LABEL, _, MemoryAddressConstant(Label(x)), _) if x.startsWith(".") => + map(x) = if (temporary) ".ddtmp__" + counter else env.nextLabel("dd") + counter += 1 + case _ => + } + code.map{ + case l@AssemblyLine(_, _, MemoryAddressConstant(Label(x)), _) if map.contains(x) => + l.copy(parameter = MemoryAddressConstant(Label(map(x)))) + case l => l + } + } + + def checkIfLabelsAreInternal(snippet: List[AssemblyLine], wholeCode: List[AssemblyLine]): Boolean = { + val myLabels = mutable.Set[String]() + val useCount = mutable.Map[String, Int]() + snippet.foreach{ + case AssemblyLine(LABEL, _, MemoryAddressConstant(Label(x)), _) => + myLabels += x + case AssemblyLine(_, _, MemoryAddressConstant(Label(x)), _) => + useCount(x) = useCount.getOrElse(x, 0) - 1 + case _ => + } + wholeCode.foreach { + case AssemblyLine(op, _, MemoryAddressConstant(Label(x)), _) if op != LABEL && myLabels(x) => + useCount(x) = useCount.getOrElse(x, 0) + 1 + case _ => + } + useCount.values.forall(_ == 0) + } + } diff --git a/src/main/scala/millfork/output/Z80Deduplicate.scala b/src/main/scala/millfork/output/Z80Deduplicate.scala index c0f2dace..06a875c5 100644 --- a/src/main/scala/millfork/output/Z80Deduplicate.scala +++ b/src/main/scala/millfork/output/Z80Deduplicate.scala @@ -6,6 +6,8 @@ import millfork.env.{Environment, Label, MemoryAddressConstant} import ZOpcode._ import millfork.node.ZRegister.SP +import scala.collection.mutable + /** * @author Karol Stasiak */ @@ -39,7 +41,7 @@ class Z80Deduplicate(env: Environment, options: CompilationOptions) extends Dedu IN_IMM, OUT_IMM, IN_C, OUT_C, LD_AHLI, LD_AHLD, LD_HLIA, LD_HLDA, LDH_AC, LDH_AD, LDH_CA, LDH_DA, - CALL, + CALL, JP, JR, LABEL, ) ++ ZOpcodeClasses.AllSingleBit private val conditionallyGoodOpcodes = Set( @@ -58,7 +60,7 @@ class Z80Deduplicate(env: Environment, options: CompilationOptions) extends Dedu override def isBadExtractedCodeHead(head: ZLine): Boolean = false - override def isBadExtractedCodeLast(head: ZLine): Boolean = head.opcode match { + override def isBadExtractedCodeLast(last: ZLine): Boolean = last.opcode match { case EI | DI | IM => true case _ => false } @@ -67,9 +69,54 @@ class Z80Deduplicate(env: Environment, options: CompilationOptions) extends Dedu override def createReturn(): ZLine = ZLine.implied(RET) + def retPrecededByDiscards(xs: List[ZLine]): Option[List[ZLine]] = { + xs match { + case ZLine(op, _, _, _) :: xs if ZOpcodeClasses.NoopDiscards(op) => retPrecededByDiscards(xs) + case ZLine(RET, _, _, _) :: xs => Some(xs) + case _ => None + } + } + override def tco(code: List[ZLine]): List[ZLine] = code match { - case (call@ZLine(CALL, _, _, _)) :: ZLine(RET, NoRegisters, _, _) :: xs => call.copy(opcode = JP) :: tco(xs) + case (call@ZLine(CALL, _, _, _)) :: xs => retPrecededByDiscards(xs) match { + case Some(rest) =>call.copy(opcode = JP) :: tco(rest) + case _ => call :: tco(xs) + } case x :: xs => x :: tco(xs) case Nil => Nil } + + override def renumerateLabels(code: List[ZLine], temporary: Boolean): List[ZLine] = { + val map = mutable.Map[String, String]() + var counter = 0 + code.foreach{ + case ZLine(LABEL, _, MemoryAddressConstant(Label(x)), _) if x.startsWith(".") => + map(x) = if (temporary) ".ddtmp__" + counter else env.nextLabel("dd") + counter += 1 + case _ => + } + code.map{ + case l@ZLine(_, _, MemoryAddressConstant(Label(x)), _) if map.contains(x) => + l.copy(parameter = MemoryAddressConstant(Label(map(x)))) + case l => l + } + } + + def checkIfLabelsAreInternal(snippet: List[ZLine], wholeCode: List[ZLine]): Boolean = { + val myLabels = mutable.Set[String]() + val useCount = mutable.Map[String, Int]() + snippet.foreach{ + case ZLine(LABEL, _, MemoryAddressConstant(Label(x)), _) => + myLabels += x + case ZLine(_, _, MemoryAddressConstant(Label(x)), _) => + useCount(x) = useCount.getOrElse(x, 0) - 1 + case _ => + } + wholeCode.foreach { + case ZLine(op, _, MemoryAddressConstant(Label(x)), _) if op != LABEL && myLabels(x) => + useCount(x) = useCount.getOrElse(x, 0) + 1 + case _ => + } + useCount.values.forall(_ == 0) + } } diff --git a/src/test/scala/millfork/test/DeduplicationSuite.scala b/src/test/scala/millfork/test/DeduplicationSuite.scala index 137a9721..b97bf7fd 100644 --- a/src/test/scala/millfork/test/DeduplicationSuite.scala +++ b/src/test/scala/millfork/test/DeduplicationSuite.scala @@ -69,4 +69,36 @@ class DeduplicationSuite extends FunSuite with Matchers { m.readMedium(0xc000) should equal(0x1FB00) } } + + test("Loop subroutine extraction") { + EmuSizeOptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Intel8080)( + """ + | array output [8] @$c000 + | void main() { + | f(2) + | g(3) + | h(6) + | } + | noinline void f(byte x) { + | byte i + | for i,0,until,output.length { + | output[i] = x + | } + | } + | noinline void g(byte x) { + | byte i + | for i,0,until,output.length { + | output[i] = x + | } + | } + | noinline void h(byte x) { + | byte i + | for i,0,until,output.length { + | output[i] = x + | } + | } + """.stripMargin) {m => + m.readByte(0xc000) should equal(6) + } + } }