From 3b7ceb390009453aef66209f06687f0dad28bf52 Mon Sep 17 00:00:00 2001 From: Karol Stasiak Date: Sun, 1 Dec 2019 02:39:05 +0100 Subject: [PATCH] Make subroutine extraction usable --- .../scala/millfork/output/Deduplicate.scala | 43 ++++++++++++------- .../millfork/output/MosDeduplicate.scala | 8 +++- .../millfork/output/Z80Deduplicate.scala | 5 ++- 3 files changed, 38 insertions(+), 18 deletions(-) diff --git a/src/main/scala/millfork/output/Deduplicate.scala b/src/main/scala/millfork/output/Deduplicate.scala index f799ce99..759d225c 100644 --- a/src/main/scala/millfork/output/Deduplicate.scala +++ b/src/main/scala/millfork/output/Deduplicate.scala @@ -14,7 +14,7 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila def apply(compiledFunctions: mutable.Map[String, CompiledFunction[T]]): Unit = { if (options.flag(CompilationFlag.SubroutineExtraction)) { - runStage(compiledFunctions, extractCommonCode) + while(runStage(compiledFunctions, extractCommonCode)){} } if (options.flag(CompilationFlag.FunctionDeduplication)) { runStage(compiledFunctions, deduplicateIdenticalFunctions) @@ -39,13 +39,17 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila def removeChains(map: Map[String, String]): Map[String, String] = map.filterNot{case (_, to) => map.contains(to)} def runStage(compiledFunctions: mutable.Map[String, CompiledFunction[T]], - function: (String, Map[String, Either[String, CodeAndAlignment[T]]]) => Seq[(String, CompiledFunction[T])]): Unit = { + function: (String, Map[String, Either[String, CodeAndAlignment[T]]]) => Seq[(String, CompiledFunction[T])]): Boolean = { + var progress = false bySegment(compiledFunctions).foreach { case (segmentName, segContents) => - function(segmentName, segContents).foreach { + val segmentDelta = function(segmentName, segContents) + progress |= segmentDelta.nonEmpty + segmentDelta.foreach { case (fname, cf) => compiledFunctions(fname) = cf } } + progress } def extractCommonCode(segmentName: String, segContents: Map[String, Either[String, CodeAndAlignment[T]]]): Seq[(String, CompiledFunction[T])] = { @@ -58,7 +62,8 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila val snippets: Seq[(List[T], CodeChunk[T])] = segContents.toSeq.flatMap { case (_, Left(_)) => Nil case (functionName, Right(CodeAndAlignment(code, _))) => - getExtractableSnippets(functionName, code).filter(_.codeSizeInBytes.>=(minSnippetSize)).map(code -> _) + if (functionName.startsWith(".xc")) Nil + else getExtractableSnippets(functionName, code).filter(_.codeSizeInBytes.>=(minSnippetSize)).map(code -> _) } val chunksWithThresholds: Seq[(CodeChunk[T], Int)] = snippets.flatMap { case (wholeCode, snippet) => for { @@ -87,17 +92,18 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila if (chunksWithThresholds.isEmpty) return Nil var chunks = chunksWithThresholds.map(_._1) var threshold = 1 - while (chunks.length > 64) { + while (chunks.length > 100 || chunks.length > 20 && !allChunksDisjoint(chunks)) { + env.log.trace(s"Current threshold: $threshold current chunk count ${chunks.size}") threshold = chunksWithThresholds.filter(c => c._2 >= threshold).map(_._2).min + 1 chunks = chunksWithThresholds.filter(c => c._2 >= threshold).map(_._1) } env.log.debug(s"Requiring $threshold profit trimmed the chunk candidate list from ${chunksWithThresholds.size} to ${chunks.size}") + if (env.log.traceEnabled) { + chunksWithThresholds.filter(c => c._2 >= threshold).foreach(c => env.log.trace(c.toString)) + } if (chunks.length < 20) env.log.debug(s"Chunks: ${chunks.length} $chunks") else env.log.debug(s"Chunks: ${chunks.length}") - val candidates: Seq[(Int, Map[List[T], Seq[CodeChunk[T]]])] = powerSet(chunks)((set, chunk) => !set.exists(_ & chunk)).filter(_.nonEmpty).filter(set => (for { - x <- set - y <- set - if x != y - } yield x & y).forall(_ == false)).toSeq.map(_.groupBy(chunk => chunk.renumerateLabels(this, temporary = true).value).filter(_._2.size >= 2).mapValues(_.toSeq).view.force).filter(_.nonEmpty).map { map => + val powerset = if (allChunksDisjoint(chunks)) Seq(chunks) else powerSet(chunks)((set, chunk) => !set.exists(_ & chunk)).filter(_.nonEmpty).filter(set => allChunksDisjoint(set)).toSeq + val candidates: Seq[(Int, Map[List[T], Seq[CodeChunk[T]]])] = powerset.map(_.groupBy(chunk => chunk.renumerateLabels(this, temporary = true).value).filter(_._2.size >= 2).mapValues(_.toSeq).view.force).filter(_.nonEmpty).map { map => map.foldLeft(0) { (sum, entry) => val chunkSize = entry._2.head.codeSizeInBytes @@ -108,11 +114,7 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila } -> map }.filter { set => val allChunks = set._2.values.flatten - (for { - x <- allChunks - y <- allChunks - if x != y - } yield x & y).forall(_ == false) + allChunksDisjoint(allChunks) } // candidates.sortBy(_._1).foreach { // case (profit, map) => @@ -367,6 +369,15 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila env.log.trace("Powerset size: " + ps.size) ps } + + def allChunksDisjoint(t: Iterable[CodeChunk[T]]): Boolean = { + for (x <- t) { + for (y <- t) { + if ((x & y) && x != y) return false + } + } + true + } } case class CodeAndAlignment[T <: AbstractCode](code: List[T], alignment: MemoryAlignment) @@ -385,6 +396,7 @@ case class CodeChunk[T <: AbstractCode](functionName: String, offset: Int, endOf renumerated } + @inline def codeSizeInBytes: Int = { if (codeSizeMeasured < 0) { codeSizeMeasured = code.map(_.sizeInBytes).sum @@ -392,6 +404,7 @@ case class CodeChunk[T <: AbstractCode](functionName: String, offset: Int, endOf codeSizeMeasured } + @inline def &(that: CodeChunk[T]): Boolean = this.functionName == that.functionName && this.offset <= that.endOffset && diff --git a/src/main/scala/millfork/output/MosDeduplicate.scala b/src/main/scala/millfork/output/MosDeduplicate.scala index e77028fa..8e7ecfef 100644 --- a/src/main/scala/millfork/output/MosDeduplicate.scala +++ b/src/main/scala/millfork/output/MosDeduplicate.scala @@ -2,7 +2,7 @@ package millfork.output import millfork.{CompilationOptions, Prehashed} import millfork.assembly.mos._ -import millfork.env.{Environment, Label, MemoryAddressConstant} +import millfork.env.{Environment, Label, MemoryAddressConstant, StructureConstant} import Opcode._ import millfork.assembly.mos.AddrMode._ @@ -51,7 +51,11 @@ class MosDeduplicate(env: Environment, options: CompilationOptions) extends Dedu private val badAddressingModes = Set(Stack, IndexedSY, AbsoluteIndexedX, Indirect, LongIndirect) override def isExtractable(line: AssemblyLine): Boolean = - line.elidable && goodOpcodes(line.opcode) && !badAddressingModes(line.addrMode) + line.elidable && goodOpcodes(line.opcode) && !badAddressingModes(line.addrMode) && (line.parameter match { + case MemoryAddressConstant(Label(x)) => !x.startsWith(".") + case StructureConstant(_, List(_, MemoryAddressConstant(Label(x)))) => !x.startsWith(".") + case _ => true + }) override def isBadExtractedCodeHead(head: AssemblyLine): Boolean = false diff --git a/src/main/scala/millfork/output/Z80Deduplicate.scala b/src/main/scala/millfork/output/Z80Deduplicate.scala index 274f6905..be4a8989 100644 --- a/src/main/scala/millfork/output/Z80Deduplicate.scala +++ b/src/main/scala/millfork/output/Z80Deduplicate.scala @@ -56,7 +56,10 @@ class Z80Deduplicate(env: Environment, options: CompilationOptions) extends Dedu case TwoRegisters(_, SP) => false case TwoRegisters(SP, _) => false case _ => true - })) + })) && (line.parameter match { + case MemoryAddressConstant(Label(x)) => !x.startsWith(".") + case _ => true + }) } override def isBadExtractedCodeHead(head: ZLine): Boolean = false