diff --git a/src/main/scala/millfork/Prehashed.scala b/src/main/scala/millfork/Prehashed.scala new file mode 100644 index 00000000..89dc1dcf --- /dev/null +++ b/src/main/scala/millfork/Prehashed.scala @@ -0,0 +1,16 @@ +package millfork + +/** + * @author Karol Stasiak + */ +class Prehashed[T](override val hashCode: Int, val value: T) { + override def equals(obj: Any): Boolean = { + if (!obj.isInstanceOf[Prehashed[_]]) return false + if (obj.hashCode() != this.hashCode) return false + obj.asInstanceOf[Prehashed[_]].value == value + } +} + +object Prehashed { + def apply[T](value: T): Prehashed[T] = new Prehashed(value.hashCode(), value) +} diff --git a/src/main/scala/millfork/output/Deduplicate.scala b/src/main/scala/millfork/output/Deduplicate.scala index 92ad819d..8642e001 100644 --- a/src/main/scala/millfork/output/Deduplicate.scala +++ b/src/main/scala/millfork/output/Deduplicate.scala @@ -1,6 +1,6 @@ package millfork.output -import millfork.{CompilationFlag, CompilationOptions} +import millfork.{CompilationFlag, CompilationOptions, Prehashed} import millfork.assembly.AbstractCode import millfork.env.Environment @@ -60,33 +60,44 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila case (functionName, Right(CodeAndAlignment(code, _))) => getExtractableSnippets(functionName, code).filter(_.codeSizeInBytes.>=(minSnippetSize)).map(code -> _) } - val chunks: Seq[CodeChunk[T]] = snippets.flatMap { case (wholeCode, snippet) => + val chunksWithThresholds: Seq[(CodeChunk[T], Int)] = snippets.flatMap { case (wholeCode, snippet) => for { start <- snippet.code.indices end <- start + 1 to snippet.code.length } yield wholeCode -> CodeChunk(snippet.functionName, snippet.offset + start, snippet.offset + end)(snippet.code.slice(start, end)) }.map(_._2).filter(_.codeSizeInBytes >= minSnippetSize).groupBy { chunk => - renumerateLabels(chunk.code, temporary = true) + chunk.renumerateLabels(this, temporary = true) }.filter { case (code, _) => - if (isBadExtractedCodeHead(code.head)) false - else if (isBadExtractedCodeLast(code.last)) false + if (isBadExtractedCodeHead(code.value.head)) false + else if (isBadExtractedCodeLast(code.value.last)) false else true - }.mapValues(_.toSeq).filter { - case (_, instances) => + }.mapValues(_.toSeq).flatMap { + case v@(_, instances) => val chunkSize = instances.head.codeSizeInBytes val extractedProcedureSize = chunkSize + 1 val savedInCallers = (chunkSize - 3) * instances.length val maxPossibleProfit = savedInCallers - extractedProcedureSize // (instances.length >=2) println(s"Instances: ${instances.length}, max profit: $maxPossibleProfit: $instances") - maxPossibleProfit > 0 && instances.length >= 2 // TODO - }.flatMap(_._2).toSeq - //println(s"Chunks: ${chunks.length} $chunks") + if (maxPossibleProfit > 0 && instances.length >= 2) { + println(s"Instances: ${instances.length}, max profit: $maxPossibleProfit: $instances") + instances.map(_ -> maxPossibleProfit) + } else Nil // TODO + }.toSeq + if (chunksWithThresholds.isEmpty) return Nil + var chunks = chunksWithThresholds.map(_._1) + var threshold = 1 + while (chunks.length > 64) { + threshold = chunksWithThresholds.filter(c => c._2 >= threshold).map(_._2).min + 1 + chunks = chunksWithThresholds.filter(c => c._2 >= threshold).map(_._1) + } + println(s"Requiring $threshold profit trimmed the chunk candidate list from ${chunksWithThresholds.size} to ${chunks.size}") + if (chunks.length < 20) println(s"Chunks: ${chunks.length} $chunks") else println(s"Chunks: ${chunks.length}") val candidates: Seq[(Int, Map[List[T], Seq[CodeChunk[T]]])] = powerSet(chunks)((set, chunk) => !set.exists(_ & chunk)).filter(_.nonEmpty).filter(set => (for { x <- set y <- set if x != y - } yield x & y).forall(_ == false)).toSeq.map(_.groupBy(chunk => renumerateLabels(chunk.code, temporary = true)).filter(_._2.size >= 2).mapValues(_.toSeq).view.force).filter(_.nonEmpty).map { map => + } yield x & y).forall(_ == false)).toSeq.map(_.groupBy(chunk => chunk.renumerateLabels(this, temporary = true).value).filter(_._2.size >= 2).mapValues(_.toSeq).view.force).filter(_.nonEmpty).map { map => map.foldLeft(0) { (sum, entry) => val chunkSize = entry._2.head.codeSizeInBytes @@ -121,7 +132,7 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila } for((code, instances) <- best._2) { val newName = env.nextLabel("xc") - val newCode = createLabel(newName) :: tco(renumerateLabels(instances.head.code, temporary = false) :+ createReturn) + val newCode = createLabel(newName) :: tco(instances.head.renumerateLabels(this, temporary = false).value :+ createReturn) result += newName -> NormalCompiledFunction(segmentName, newCode, hasFixedAddress = false, alignment = NoAlignment) for(instance <- instances) { toReplace(instance.functionName)(instance.offset) = newName @@ -301,7 +312,7 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila def createLabel(name: String): T - def renumerateLabels(code: List[T], temporary: Boolean): List[T] + def renumerateLabels(code: List[T], temporary: Boolean): Prehashed[List[T]] def removePositionInfo(line: T): T @@ -346,15 +357,38 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila @annotation.tailrec def pwr(t: Iterable[A], ps: Set[Set[A]]): Set[Set[A]] = if (t.isEmpty) ps - else pwr(t.tail, ps ++ (ps.filter(p => f(p, t.head)) map (_ + t.head))) - pwr(t, Set(Set.empty[A])) + else { + println(s"Powerset size so far: ${ps.size} Remaining chunks: ${t.size}") + pwr(t.tail, ps ++ (ps.filter(p => f(p, t.head)) map (_ + t.head))) + } + val ps = pwr(t, Set(Set.empty[A])) + println("Powerset size: "+ ps.size) + ps } } case class CodeAndAlignment[T <: AbstractCode](code: List[T], alignment: MemoryAlignment) case class CodeChunk[T <: AbstractCode](functionName: String, offset: Int, endOffset: Int)(val code: List[T]) { - val codeSizeInBytes: Int = code.map(_.sizeInBytes).sum + + private var renumerated: Prehashed[List[T]] = _ + private var lastTemporary: Boolean = false + private var codeSizeMeasured = -1 + + def renumerateLabels(deduplicate: Deduplicate[T], temporary: Boolean): Prehashed[List[T]] = { + if ((renumerated eq null) || lastTemporary != temporary) { + renumerated = deduplicate.renumerateLabels(code, temporary = temporary) + lastTemporary = temporary + } + renumerated + } + + def codeSizeInBytes: Int = { + if (codeSizeMeasured < 0) { + codeSizeMeasured = code.map(_.sizeInBytes).sum + } + codeSizeMeasured + } def &(that: CodeChunk[T]): Boolean = this.functionName == that.functionName && diff --git a/src/main/scala/millfork/output/MosDeduplicate.scala b/src/main/scala/millfork/output/MosDeduplicate.scala index cec21510..e77028fa 100644 --- a/src/main/scala/millfork/output/MosDeduplicate.scala +++ b/src/main/scala/millfork/output/MosDeduplicate.scala @@ -1,6 +1,6 @@ package millfork.output -import millfork.CompilationOptions +import millfork.{CompilationOptions, Prehashed} import millfork.assembly.mos._ import millfork.env.{Environment, Label, MemoryAddressConstant} import Opcode._ @@ -78,7 +78,7 @@ class MosDeduplicate(env: Environment, options: CompilationOptions) extends Dedu case Nil => Nil } - override def renumerateLabels(code: List[AssemblyLine], temporary: Boolean): List[AssemblyLine] = { + override def renumerateLabels(code: List[AssemblyLine], temporary: Boolean): Prehashed[List[AssemblyLine]] = { val map = mutable.Map[String, String]() var counter = 0 code.foreach{ @@ -87,11 +87,11 @@ class MosDeduplicate(env: Environment, options: CompilationOptions) extends Dedu counter += 1 case _ => } - code.map{ + Prehashed(code.map{ case l@AssemblyLine0(_, _, MemoryAddressConstant(Label(x))) if map.contains(x) => l.copy(parameter = MemoryAddressConstant(Label(map(x)))) case l => l - } + }) } def checkIfLabelsAreInternal(snippet: List[AssemblyLine], wholeCode: List[AssemblyLine]): Boolean = { diff --git a/src/main/scala/millfork/output/Z80Deduplicate.scala b/src/main/scala/millfork/output/Z80Deduplicate.scala index adb1cc8f..274f6905 100644 --- a/src/main/scala/millfork/output/Z80Deduplicate.scala +++ b/src/main/scala/millfork/output/Z80Deduplicate.scala @@ -1,6 +1,6 @@ package millfork.output -import millfork.CompilationOptions +import millfork.{CompilationOptions, Prehashed} import millfork.assembly.z80.{ZOpcode, _} import millfork.env.{Environment, Label, MemoryAddressConstant} import ZOpcode._ @@ -87,7 +87,7 @@ class Z80Deduplicate(env: Environment, options: CompilationOptions) extends Dedu case Nil => Nil } - override def renumerateLabels(code: List[ZLine], temporary: Boolean): List[ZLine] = { + override def renumerateLabels(code: List[ZLine], temporary: Boolean): Prehashed[List[ZLine]] = { val map = mutable.Map[String, String]() var counter = 0 code.foreach{ @@ -96,11 +96,11 @@ class Z80Deduplicate(env: Environment, options: CompilationOptions) extends Dedu counter += 1 case _ => } - code.map{ + Prehashed(code.map{ case l@ZLine0(_, _, MemoryAddressConstant(Label(x))) if map.contains(x) => l.copy(parameter = MemoryAddressConstant(Label(map(x)))) case l => l - } + }) } def checkIfLabelsAreInternal(snippet: List[ZLine], wholeCode: List[ZLine]): Boolean = {