mirror of
https://github.com/KarolS/millfork.git
synced 2025-01-09 13:31:32 +00:00
Restrict subroutine extraction to make exponential runtime less likely
This commit is contained in:
parent
3e0dad4cb0
commit
bcb2e362b2
16
src/main/scala/millfork/Prehashed.scala
Normal file
16
src/main/scala/millfork/Prehashed.scala
Normal file
@ -0,0 +1,16 @@
|
||||
package millfork
|
||||
|
||||
/**
|
||||
* @author Karol Stasiak
|
||||
*/
|
||||
class Prehashed[T](override val hashCode: Int, val value: T) {
|
||||
override def equals(obj: Any): Boolean = {
|
||||
if (!obj.isInstanceOf[Prehashed[_]]) return false
|
||||
if (obj.hashCode() != this.hashCode) return false
|
||||
obj.asInstanceOf[Prehashed[_]].value == value
|
||||
}
|
||||
}
|
||||
|
||||
object Prehashed {
|
||||
def apply[T](value: T): Prehashed[T] = new Prehashed(value.hashCode(), value)
|
||||
}
|
@ -1,6 +1,6 @@
|
||||
package millfork.output
|
||||
|
||||
import millfork.{CompilationFlag, CompilationOptions}
|
||||
import millfork.{CompilationFlag, CompilationOptions, Prehashed}
|
||||
import millfork.assembly.AbstractCode
|
||||
import millfork.env.Environment
|
||||
|
||||
@ -60,33 +60,44 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila
|
||||
case (functionName, Right(CodeAndAlignment(code, _))) =>
|
||||
getExtractableSnippets(functionName, code).filter(_.codeSizeInBytes.>=(minSnippetSize)).map(code -> _)
|
||||
}
|
||||
val chunks: Seq[CodeChunk[T]] = snippets.flatMap { case (wholeCode, snippet) =>
|
||||
val chunksWithThresholds: Seq[(CodeChunk[T], Int)] = snippets.flatMap { case (wholeCode, snippet) =>
|
||||
for {
|
||||
start <- snippet.code.indices
|
||||
end <- start + 1 to snippet.code.length
|
||||
} yield wholeCode -> CodeChunk(snippet.functionName, snippet.offset + start, snippet.offset + end)(snippet.code.slice(start, end))
|
||||
}.map(_._2).filter(_.codeSizeInBytes >= minSnippetSize).groupBy { chunk =>
|
||||
renumerateLabels(chunk.code, temporary = true)
|
||||
chunk.renumerateLabels(this, temporary = true)
|
||||
}.filter {
|
||||
case (code, _) =>
|
||||
if (isBadExtractedCodeHead(code.head)) false
|
||||
else if (isBadExtractedCodeLast(code.last)) false
|
||||
if (isBadExtractedCodeHead(code.value.head)) false
|
||||
else if (isBadExtractedCodeLast(code.value.last)) false
|
||||
else true
|
||||
}.mapValues(_.toSeq).filter {
|
||||
case (_, instances) =>
|
||||
}.mapValues(_.toSeq).flatMap {
|
||||
case v@(_, instances) =>
|
||||
val chunkSize = instances.head.codeSizeInBytes
|
||||
val extractedProcedureSize = chunkSize + 1
|
||||
val savedInCallers = (chunkSize - 3) * instances.length
|
||||
val maxPossibleProfit = savedInCallers - extractedProcedureSize
|
||||
// (instances.length >=2) println(s"Instances: ${instances.length}, max profit: $maxPossibleProfit: $instances")
|
||||
maxPossibleProfit > 0 && instances.length >= 2 // TODO
|
||||
}.flatMap(_._2).toSeq
|
||||
//println(s"Chunks: ${chunks.length} $chunks")
|
||||
if (maxPossibleProfit > 0 && instances.length >= 2) {
|
||||
println(s"Instances: ${instances.length}, max profit: $maxPossibleProfit: $instances")
|
||||
instances.map(_ -> maxPossibleProfit)
|
||||
} else Nil // TODO
|
||||
}.toSeq
|
||||
if (chunksWithThresholds.isEmpty) return Nil
|
||||
var chunks = chunksWithThresholds.map(_._1)
|
||||
var threshold = 1
|
||||
while (chunks.length > 64) {
|
||||
threshold = chunksWithThresholds.filter(c => c._2 >= threshold).map(_._2).min + 1
|
||||
chunks = chunksWithThresholds.filter(c => c._2 >= threshold).map(_._1)
|
||||
}
|
||||
println(s"Requiring $threshold profit trimmed the chunk candidate list from ${chunksWithThresholds.size} to ${chunks.size}")
|
||||
if (chunks.length < 20) println(s"Chunks: ${chunks.length} $chunks") else println(s"Chunks: ${chunks.length}")
|
||||
val candidates: Seq[(Int, Map[List[T], Seq[CodeChunk[T]]])] = powerSet(chunks)((set, chunk) => !set.exists(_ & chunk)).filter(_.nonEmpty).filter(set => (for {
|
||||
x <- set
|
||||
y <- set
|
||||
if x != y
|
||||
} yield x & y).forall(_ == false)).toSeq.map(_.groupBy(chunk => renumerateLabels(chunk.code, temporary = true)).filter(_._2.size >= 2).mapValues(_.toSeq).view.force).filter(_.nonEmpty).map { map =>
|
||||
} yield x & y).forall(_ == false)).toSeq.map(_.groupBy(chunk => chunk.renumerateLabels(this, temporary = true).value).filter(_._2.size >= 2).mapValues(_.toSeq).view.force).filter(_.nonEmpty).map { map =>
|
||||
map.foldLeft(0) {
|
||||
(sum, entry) =>
|
||||
val chunkSize = entry._2.head.codeSizeInBytes
|
||||
@ -121,7 +132,7 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila
|
||||
}
|
||||
for((code, instances) <- best._2) {
|
||||
val newName = env.nextLabel("xc")
|
||||
val newCode = createLabel(newName) :: tco(renumerateLabels(instances.head.code, temporary = false) :+ createReturn)
|
||||
val newCode = createLabel(newName) :: tco(instances.head.renumerateLabels(this, temporary = false).value :+ createReturn)
|
||||
result += newName -> NormalCompiledFunction(segmentName, newCode, hasFixedAddress = false, alignment = NoAlignment)
|
||||
for(instance <- instances) {
|
||||
toReplace(instance.functionName)(instance.offset) = newName
|
||||
@ -301,7 +312,7 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila
|
||||
|
||||
def createLabel(name: String): T
|
||||
|
||||
def renumerateLabels(code: List[T], temporary: Boolean): List[T]
|
||||
def renumerateLabels(code: List[T], temporary: Boolean): Prehashed[List[T]]
|
||||
|
||||
def removePositionInfo(line: T): T
|
||||
|
||||
@ -346,15 +357,38 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila
|
||||
@annotation.tailrec
|
||||
def pwr(t: Iterable[A], ps: Set[Set[A]]): Set[Set[A]] =
|
||||
if (t.isEmpty) ps
|
||||
else pwr(t.tail, ps ++ (ps.filter(p => f(p, t.head)) map (_ + t.head)))
|
||||
pwr(t, Set(Set.empty[A]))
|
||||
else {
|
||||
println(s"Powerset size so far: ${ps.size} Remaining chunks: ${t.size}")
|
||||
pwr(t.tail, ps ++ (ps.filter(p => f(p, t.head)) map (_ + t.head)))
|
||||
}
|
||||
val ps = pwr(t, Set(Set.empty[A]))
|
||||
println("Powerset size: "+ ps.size)
|
||||
ps
|
||||
}
|
||||
}
|
||||
|
||||
case class CodeAndAlignment[T <: AbstractCode](code: List[T], alignment: MemoryAlignment)
|
||||
|
||||
case class CodeChunk[T <: AbstractCode](functionName: String, offset: Int, endOffset: Int)(val code: List[T]) {
|
||||
val codeSizeInBytes: Int = code.map(_.sizeInBytes).sum
|
||||
|
||||
private var renumerated: Prehashed[List[T]] = _
|
||||
private var lastTemporary: Boolean = false
|
||||
private var codeSizeMeasured = -1
|
||||
|
||||
def renumerateLabels(deduplicate: Deduplicate[T], temporary: Boolean): Prehashed[List[T]] = {
|
||||
if ((renumerated eq null) || lastTemporary != temporary) {
|
||||
renumerated = deduplicate.renumerateLabels(code, temporary = temporary)
|
||||
lastTemporary = temporary
|
||||
}
|
||||
renumerated
|
||||
}
|
||||
|
||||
def codeSizeInBytes: Int = {
|
||||
if (codeSizeMeasured < 0) {
|
||||
codeSizeMeasured = code.map(_.sizeInBytes).sum
|
||||
}
|
||||
codeSizeMeasured
|
||||
}
|
||||
|
||||
def &(that: CodeChunk[T]): Boolean =
|
||||
this.functionName == that.functionName &&
|
||||
|
@ -1,6 +1,6 @@
|
||||
package millfork.output
|
||||
|
||||
import millfork.CompilationOptions
|
||||
import millfork.{CompilationOptions, Prehashed}
|
||||
import millfork.assembly.mos._
|
||||
import millfork.env.{Environment, Label, MemoryAddressConstant}
|
||||
import Opcode._
|
||||
@ -78,7 +78,7 @@ class MosDeduplicate(env: Environment, options: CompilationOptions) extends Dedu
|
||||
case Nil => Nil
|
||||
}
|
||||
|
||||
override def renumerateLabels(code: List[AssemblyLine], temporary: Boolean): List[AssemblyLine] = {
|
||||
override def renumerateLabels(code: List[AssemblyLine], temporary: Boolean): Prehashed[List[AssemblyLine]] = {
|
||||
val map = mutable.Map[String, String]()
|
||||
var counter = 0
|
||||
code.foreach{
|
||||
@ -87,11 +87,11 @@ class MosDeduplicate(env: Environment, options: CompilationOptions) extends Dedu
|
||||
counter += 1
|
||||
case _ =>
|
||||
}
|
||||
code.map{
|
||||
Prehashed(code.map{
|
||||
case l@AssemblyLine0(_, _, MemoryAddressConstant(Label(x))) if map.contains(x) =>
|
||||
l.copy(parameter = MemoryAddressConstant(Label(map(x))))
|
||||
case l => l
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
def checkIfLabelsAreInternal(snippet: List[AssemblyLine], wholeCode: List[AssemblyLine]): Boolean = {
|
||||
|
@ -1,6 +1,6 @@
|
||||
package millfork.output
|
||||
|
||||
import millfork.CompilationOptions
|
||||
import millfork.{CompilationOptions, Prehashed}
|
||||
import millfork.assembly.z80.{ZOpcode, _}
|
||||
import millfork.env.{Environment, Label, MemoryAddressConstant}
|
||||
import ZOpcode._
|
||||
@ -87,7 +87,7 @@ class Z80Deduplicate(env: Environment, options: CompilationOptions) extends Dedu
|
||||
case Nil => Nil
|
||||
}
|
||||
|
||||
override def renumerateLabels(code: List[ZLine], temporary: Boolean): List[ZLine] = {
|
||||
override def renumerateLabels(code: List[ZLine], temporary: Boolean): Prehashed[List[ZLine]] = {
|
||||
val map = mutable.Map[String, String]()
|
||||
var counter = 0
|
||||
code.foreach{
|
||||
@ -96,11 +96,11 @@ class Z80Deduplicate(env: Environment, options: CompilationOptions) extends Dedu
|
||||
counter += 1
|
||||
case _ =>
|
||||
}
|
||||
code.map{
|
||||
Prehashed(code.map{
|
||||
case l@ZLine0(_, _, MemoryAddressConstant(Label(x))) if map.contains(x) =>
|
||||
l.copy(parameter = MemoryAddressConstant(Label(map(x))))
|
||||
case l => l
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
def checkIfLabelsAreInternal(snippet: List[ZLine], wholeCode: List[ZLine]): Boolean = {
|
||||
|
Loading…
Reference in New Issue
Block a user