1
0
mirror of https://github.com/KarolS/millfork.git synced 2025-01-26 20:33:02 +00:00

Restrict subroutine extraction to make exponential runtime less likely

This commit is contained in:
Karol Stasiak 2019-06-29 00:07:32 +02:00
parent 3e0dad4cb0
commit bcb2e362b2
4 changed files with 74 additions and 24 deletions

View File

@ -0,0 +1,16 @@
package millfork
/**
* @author Karol Stasiak
*/
class Prehashed[T](override val hashCode: Int, val value: T) {
override def equals(obj: Any): Boolean = {
if (!obj.isInstanceOf[Prehashed[_]]) return false
if (obj.hashCode() != this.hashCode) return false
obj.asInstanceOf[Prehashed[_]].value == value
}
}
object Prehashed {
def apply[T](value: T): Prehashed[T] = new Prehashed(value.hashCode(), value)
}

View File

@ -1,6 +1,6 @@
package millfork.output
import millfork.{CompilationFlag, CompilationOptions}
import millfork.{CompilationFlag, CompilationOptions, Prehashed}
import millfork.assembly.AbstractCode
import millfork.env.Environment
@ -60,33 +60,44 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila
case (functionName, Right(CodeAndAlignment(code, _))) =>
getExtractableSnippets(functionName, code).filter(_.codeSizeInBytes.>=(minSnippetSize)).map(code -> _)
}
val chunks: Seq[CodeChunk[T]] = snippets.flatMap { case (wholeCode, snippet) =>
val chunksWithThresholds: Seq[(CodeChunk[T], Int)] = snippets.flatMap { case (wholeCode, snippet) =>
for {
start <- snippet.code.indices
end <- start + 1 to snippet.code.length
} yield wholeCode -> CodeChunk(snippet.functionName, snippet.offset + start, snippet.offset + end)(snippet.code.slice(start, end))
}.map(_._2).filter(_.codeSizeInBytes >= minSnippetSize).groupBy { chunk =>
renumerateLabels(chunk.code, temporary = true)
chunk.renumerateLabels(this, temporary = true)
}.filter {
case (code, _) =>
if (isBadExtractedCodeHead(code.head)) false
else if (isBadExtractedCodeLast(code.last)) false
if (isBadExtractedCodeHead(code.value.head)) false
else if (isBadExtractedCodeLast(code.value.last)) false
else true
}.mapValues(_.toSeq).filter {
case (_, instances) =>
}.mapValues(_.toSeq).flatMap {
case v@(_, instances) =>
val chunkSize = instances.head.codeSizeInBytes
val extractedProcedureSize = chunkSize + 1
val savedInCallers = (chunkSize - 3) * instances.length
val maxPossibleProfit = savedInCallers - extractedProcedureSize
// (instances.length >=2) println(s"Instances: ${instances.length}, max profit: $maxPossibleProfit: $instances")
maxPossibleProfit > 0 && instances.length >= 2 // TODO
}.flatMap(_._2).toSeq
//println(s"Chunks: ${chunks.length} $chunks")
if (maxPossibleProfit > 0 && instances.length >= 2) {
println(s"Instances: ${instances.length}, max profit: $maxPossibleProfit: $instances")
instances.map(_ -> maxPossibleProfit)
} else Nil // TODO
}.toSeq
if (chunksWithThresholds.isEmpty) return Nil
var chunks = chunksWithThresholds.map(_._1)
var threshold = 1
while (chunks.length > 64) {
threshold = chunksWithThresholds.filter(c => c._2 >= threshold).map(_._2).min + 1
chunks = chunksWithThresholds.filter(c => c._2 >= threshold).map(_._1)
}
println(s"Requiring $threshold profit trimmed the chunk candidate list from ${chunksWithThresholds.size} to ${chunks.size}")
if (chunks.length < 20) println(s"Chunks: ${chunks.length} $chunks") else println(s"Chunks: ${chunks.length}")
val candidates: Seq[(Int, Map[List[T], Seq[CodeChunk[T]]])] = powerSet(chunks)((set, chunk) => !set.exists(_ & chunk)).filter(_.nonEmpty).filter(set => (for {
x <- set
y <- set
if x != y
} yield x & y).forall(_ == false)).toSeq.map(_.groupBy(chunk => renumerateLabels(chunk.code, temporary = true)).filter(_._2.size >= 2).mapValues(_.toSeq).view.force).filter(_.nonEmpty).map { map =>
} yield x & y).forall(_ == false)).toSeq.map(_.groupBy(chunk => chunk.renumerateLabels(this, temporary = true).value).filter(_._2.size >= 2).mapValues(_.toSeq).view.force).filter(_.nonEmpty).map { map =>
map.foldLeft(0) {
(sum, entry) =>
val chunkSize = entry._2.head.codeSizeInBytes
@ -121,7 +132,7 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila
}
for((code, instances) <- best._2) {
val newName = env.nextLabel("xc")
val newCode = createLabel(newName) :: tco(renumerateLabels(instances.head.code, temporary = false) :+ createReturn)
val newCode = createLabel(newName) :: tco(instances.head.renumerateLabels(this, temporary = false).value :+ createReturn)
result += newName -> NormalCompiledFunction(segmentName, newCode, hasFixedAddress = false, alignment = NoAlignment)
for(instance <- instances) {
toReplace(instance.functionName)(instance.offset) = newName
@ -301,7 +312,7 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila
def createLabel(name: String): T
def renumerateLabels(code: List[T], temporary: Boolean): List[T]
def renumerateLabels(code: List[T], temporary: Boolean): Prehashed[List[T]]
def removePositionInfo(line: T): T
@ -346,15 +357,38 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila
@annotation.tailrec
def pwr(t: Iterable[A], ps: Set[Set[A]]): Set[Set[A]] =
if (t.isEmpty) ps
else pwr(t.tail, ps ++ (ps.filter(p => f(p, t.head)) map (_ + t.head)))
pwr(t, Set(Set.empty[A]))
else {
println(s"Powerset size so far: ${ps.size} Remaining chunks: ${t.size}")
pwr(t.tail, ps ++ (ps.filter(p => f(p, t.head)) map (_ + t.head)))
}
val ps = pwr(t, Set(Set.empty[A]))
println("Powerset size: "+ ps.size)
ps
}
}
case class CodeAndAlignment[T <: AbstractCode](code: List[T], alignment: MemoryAlignment)
case class CodeChunk[T <: AbstractCode](functionName: String, offset: Int, endOffset: Int)(val code: List[T]) {
val codeSizeInBytes: Int = code.map(_.sizeInBytes).sum
private var renumerated: Prehashed[List[T]] = _
private var lastTemporary: Boolean = false
private var codeSizeMeasured = -1
def renumerateLabels(deduplicate: Deduplicate[T], temporary: Boolean): Prehashed[List[T]] = {
if ((renumerated eq null) || lastTemporary != temporary) {
renumerated = deduplicate.renumerateLabels(code, temporary = temporary)
lastTemporary = temporary
}
renumerated
}
def codeSizeInBytes: Int = {
if (codeSizeMeasured < 0) {
codeSizeMeasured = code.map(_.sizeInBytes).sum
}
codeSizeMeasured
}
def &(that: CodeChunk[T]): Boolean =
this.functionName == that.functionName &&

View File

@ -1,6 +1,6 @@
package millfork.output
import millfork.CompilationOptions
import millfork.{CompilationOptions, Prehashed}
import millfork.assembly.mos._
import millfork.env.{Environment, Label, MemoryAddressConstant}
import Opcode._
@ -78,7 +78,7 @@ class MosDeduplicate(env: Environment, options: CompilationOptions) extends Dedu
case Nil => Nil
}
override def renumerateLabels(code: List[AssemblyLine], temporary: Boolean): List[AssemblyLine] = {
override def renumerateLabels(code: List[AssemblyLine], temporary: Boolean): Prehashed[List[AssemblyLine]] = {
val map = mutable.Map[String, String]()
var counter = 0
code.foreach{
@ -87,11 +87,11 @@ class MosDeduplicate(env: Environment, options: CompilationOptions) extends Dedu
counter += 1
case _ =>
}
code.map{
Prehashed(code.map{
case l@AssemblyLine0(_, _, MemoryAddressConstant(Label(x))) if map.contains(x) =>
l.copy(parameter = MemoryAddressConstant(Label(map(x))))
case l => l
}
})
}
def checkIfLabelsAreInternal(snippet: List[AssemblyLine], wholeCode: List[AssemblyLine]): Boolean = {

View File

@ -1,6 +1,6 @@
package millfork.output
import millfork.CompilationOptions
import millfork.{CompilationOptions, Prehashed}
import millfork.assembly.z80.{ZOpcode, _}
import millfork.env.{Environment, Label, MemoryAddressConstant}
import ZOpcode._
@ -87,7 +87,7 @@ class Z80Deduplicate(env: Environment, options: CompilationOptions) extends Dedu
case Nil => Nil
}
override def renumerateLabels(code: List[ZLine], temporary: Boolean): List[ZLine] = {
override def renumerateLabels(code: List[ZLine], temporary: Boolean): Prehashed[List[ZLine]] = {
val map = mutable.Map[String, String]()
var counter = 0
code.foreach{
@ -96,11 +96,11 @@ class Z80Deduplicate(env: Environment, options: CompilationOptions) extends Dedu
counter += 1
case _ =>
}
code.map{
Prehashed(code.map{
case l@ZLine0(_, _, MemoryAddressConstant(Label(x))) if map.contains(x) =>
l.copy(parameter = MemoryAddressConstant(Label(map(x))))
case l => l
}
})
}
def checkIfLabelsAreInternal(snippet: List[ZLine], wholeCode: List[ZLine]): Boolean = {