mirror of
https://github.com/KarolS/millfork.git
synced 2025-04-01 02:31:38 +00:00
Deduplicate more complex code. Better deduplication.
This commit is contained in:
parent
b01c440cf0
commit
2af8304512
@ -20,6 +20,8 @@ where `11111` is a sequential number and `xx` is the type:
|
||||
|
||||
* `cp` – equality comparison for larger types
|
||||
|
||||
* `dd` – labels renamed by code deduplication
|
||||
|
||||
* `de` – decrement for larger types
|
||||
|
||||
* `do` – beginning of a `do-while` statement
|
||||
|
@ -178,7 +178,8 @@ case class ZLine(opcode: ZOpcode.Value, registers: ZRegisters, parameter: Consta
|
||||
import ZRegister._
|
||||
val inherent = opcode match {
|
||||
case BYTE => 1
|
||||
case d if ZOpcodeClasses.NoopDiscards(d) => 0
|
||||
case LABEL => return 0
|
||||
case d if ZOpcodeClasses.NoopDiscards(d) => return 0
|
||||
case JP => registers match {
|
||||
case OneRegister(HL | IX | IY) => 1
|
||||
case _ => 2
|
||||
|
@ -274,7 +274,8 @@ abstract class AbstractAssembler[T <: AbstractCode](private val program: Program
|
||||
val codeAllocators = platform.codeAllocators.mapValues(new VariableAllocator(Nil, _))
|
||||
var justAfterCode = platform.codeAllocators.mapValues(a => a.startAt)
|
||||
|
||||
compiledFunctions.toList.sortBy{case (name, cf) => if (name == "main") 0 -> "" else cf.orderKey}.foreach {
|
||||
val sortedCompilerFunctions = compiledFunctions.toList.sortBy { case (name, cf) => if (name == "main") 0 -> "" else cf.orderKey }
|
||||
sortedCompilerFunctions.foreach {
|
||||
case (_, NormalCompiledFunction(_, _, true, _)) =>
|
||||
// already done before
|
||||
case (name, NormalCompiledFunction(bank, code, false, alignment)) =>
|
||||
@ -282,9 +283,12 @@ abstract class AbstractAssembler[T <: AbstractCode](private val program: Program
|
||||
val index = codeAllocators(bank).allocateBytes(mem.banks(bank), options, size, initialized = true, writeable = false, location = AllocationLocation.High, alignment = alignment)
|
||||
labelMap(name) = index
|
||||
justAfterCode += bank -> outputFunction(bank, code, index, assembly, options)
|
||||
case (_, NonexistentFunction()) =>
|
||||
case _ =>
|
||||
}
|
||||
sortedCompilerFunctions.foreach {
|
||||
case (name, RedirectedFunction(_, target, offset)) =>
|
||||
labelMap(name) = labelMap(target) + offset
|
||||
case _ =>
|
||||
}
|
||||
|
||||
if (options.flag(CompilationFlag.LUnixRelocatableCode)) {
|
||||
|
@ -18,6 +18,14 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila
|
||||
}
|
||||
runStage(compiledFunctions, deduplicateIdenticalFunctions)
|
||||
runStage(compiledFunctions, eliminateTailJumps)
|
||||
runStage(compiledFunctions, eliminateRemainingTrivialTailJumps)
|
||||
fixDoubleRedirects(compiledFunctions)
|
||||
// println(compiledFunctions.map {
|
||||
// case (k, v) => k + " " + (v match {
|
||||
// case _: NormalCompiledFunction[_] => "NormalCompiledFunction"
|
||||
// case _ => v.toString
|
||||
// })
|
||||
// }.mkString(" ; "))
|
||||
}
|
||||
|
||||
def runStage(compiledFunctions: mutable.Map[String, CompiledFunction[T]],
|
||||
@ -32,18 +40,21 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila
|
||||
|
||||
def extractCommonCode(segmentName: String, segContents: Map[String, Either[String, CodeAndAlignment[T]]]): Seq[(String, CompiledFunction[T])] = {
|
||||
var result = ListBuffer[(String, CompiledFunction[T])]()
|
||||
val chunks = segContents.flatMap{
|
||||
val snippets: Seq[(List[T], CodeChunk[T])] = segContents.toSeq.flatMap {
|
||||
case (_, Left(_)) => Nil
|
||||
case (functionName, Right(CodeAndAlignment(code, _))) =>
|
||||
if (options.flag(CompilationFlag.OptimizeForSize)) {
|
||||
getExtractableSnippets(functionName, code)
|
||||
getExtractableSnippets(functionName, code).map(code -> _)
|
||||
} else Nil
|
||||
}.flatMap { chunk =>
|
||||
for {
|
||||
start <- chunk.code.indices
|
||||
end <- start + 1 to chunk.code.length
|
||||
} yield CodeChunk(chunk.functionName, chunk.offset + start, chunk.offset + end)(chunk.code.slice(start, end))
|
||||
}.filter(_.codeSizeInBytes > 3).groupBy(_.code).filter{
|
||||
}
|
||||
val chunks: Seq[CodeChunk[T]] = snippets.flatMap { case (wholeCode, snippet) =>
|
||||
for {
|
||||
start <- snippet.code.indices
|
||||
end <- start + 1 to snippet.code.length
|
||||
} yield wholeCode -> CodeChunk(snippet.functionName, snippet.offset + start, snippet.offset + end)(snippet.code.slice(start, end))
|
||||
}.map(_._2).filter(_.codeSizeInBytes > 3).groupBy { chunk =>
|
||||
renumerateLabels(chunk.code, temporary = true)
|
||||
}.filter {
|
||||
case (code, _) =>
|
||||
if (isBadExtractedCodeHead(code.head)) false
|
||||
else if (isBadExtractedCodeLast(code.last)) false
|
||||
@ -62,7 +73,7 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila
|
||||
x <- set
|
||||
y <- set
|
||||
if x != y
|
||||
} yield x & y).forall(_ == false)).toSeq.map(_.groupBy(_.code).filter(_._2.size >= 2).mapValues(_.toSeq)).filter(_.nonEmpty).map { map =>
|
||||
} yield x & y).forall(_ == false)).toSeq.map(_.groupBy(chunk => renumerateLabels(chunk.code, temporary = true)).filter(_._2.size >= 2).mapValues(_.toSeq)).filter(_.nonEmpty).map { map =>
|
||||
map.foldLeft(0) {
|
||||
(sum, entry) =>
|
||||
val chunkSize = entry._2.head.codeSizeInBytes
|
||||
@ -96,7 +107,8 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila
|
||||
}
|
||||
for((code, instances) <- best._2) {
|
||||
val newName = env.nextLabel("xc")
|
||||
result += newName -> NormalCompiledFunction(segmentName, createLabel(newName) :: tco(code :+ createReturn), hasFixedAddress = false, alignment = NoAlignment)
|
||||
val newCode = createLabel(newName) :: tco(renumerateLabels(instances.head.code, temporary = false) :+ createReturn)
|
||||
result += newName -> NormalCompiledFunction(segmentName, newCode, hasFixedAddress = false, alignment = NoAlignment)
|
||||
for(instance <- instances) {
|
||||
toReplace(instance.functionName)(instance.offset) = newName
|
||||
for (i <- instance.offset + 1 until instance.endOffset) {
|
||||
@ -127,7 +139,11 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila
|
||||
var result = ListBuffer[(String, CompiledFunction[T])]()
|
||||
val identicalFunctions = segContents.flatMap{
|
||||
case (name, code) => code.toOption.map(c => name -> actualCode(name, c.code))
|
||||
}.groupBy(_._2).values.toSeq.map(_.keySet).filter(set => set.size > 1)
|
||||
}.filter{
|
||||
case (_, code) => checkIfLabelsAreInternal(code, code)
|
||||
}.groupBy{
|
||||
case (_, code) => renumerateLabels(code, temporary = true)
|
||||
}.values.toSeq.map(_.keySet).filter(set => set.size > 1)
|
||||
for(set <- identicalFunctions) {
|
||||
val representative = if (set("main")) "main" else set.head
|
||||
options.log.debug(s"Functions [${set.mkString(",")}] are identical")
|
||||
@ -153,7 +169,7 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila
|
||||
var result: String = to
|
||||
val visited = mutable.Set[String]()
|
||||
do {
|
||||
segContents.get(to) match {
|
||||
segContents.get(result) match {
|
||||
case Some(Left(next)) =>
|
||||
if (visited(next)) return None
|
||||
visited += result
|
||||
@ -200,11 +216,67 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila
|
||||
result.toSeq
|
||||
}
|
||||
|
||||
def eliminateRemainingTrivialTailJumps(segmentName: String, segContents: Map[String, Either[String, CodeAndAlignment[T]]]): Seq[(String, CompiledFunction[T])] = {
|
||||
var result = ListBuffer[(String, CompiledFunction[T])]()
|
||||
val fallThroughList = segContents.flatMap {
|
||||
case (name, Right(CodeAndAlignment(code, alignment))) =>
|
||||
if (code.length != 2) None
|
||||
else getJump(code.last)
|
||||
.filter(segContents.contains)
|
||||
.filter(_ != name)
|
||||
.filter(_ != "main")
|
||||
.map(name -> _)
|
||||
case _ => None
|
||||
}
|
||||
val fallthroughPredecessors = fallThroughList.groupBy(_._2).mapValues(_.keySet)
|
||||
fallthroughPredecessors.foreach {
|
||||
case (to, froms) =>
|
||||
for (from <- froms) {
|
||||
options.log.debug(s"Trivial fallthrough from $from to $to")
|
||||
result += from -> RedirectedFunction(segmentName, to, 0)
|
||||
follow(segContents, to) match {
|
||||
case Some(actualTo) =>
|
||||
options.log.trace(s"which physically is $actualTo")
|
||||
val value = result.find(_._1 == actualTo).fold(segContents(actualTo).right.get){
|
||||
case (_, NormalCompiledFunction(_, code, _, alignment)) => CodeAndAlignment(code, alignment)
|
||||
}
|
||||
result += actualTo -> NormalCompiledFunction(segmentName,
|
||||
createLabel(from) :: value.code,
|
||||
hasFixedAddress = false,
|
||||
alignment = value.alignment
|
||||
)
|
||||
case _ =>
|
||||
}
|
||||
}
|
||||
}
|
||||
result.toSeq
|
||||
}
|
||||
|
||||
def fixDoubleRedirects(compiledFunctions: mutable.Map[String, CompiledFunction[T]]): Unit = {
|
||||
var changed = true
|
||||
while (changed) {
|
||||
changed = false
|
||||
val functionNames = compiledFunctions.keys.toSeq
|
||||
for (name <- functionNames) {
|
||||
compiledFunctions(name) match {
|
||||
case RedirectedFunction(_, redirect, offset1) =>
|
||||
compiledFunctions.get(redirect) match {
|
||||
case Some(r: RedirectedFunction[T]) =>
|
||||
compiledFunctions(name) = r.copy(offset = r.offset + offset1)
|
||||
changed = true
|
||||
case _ =>
|
||||
}
|
||||
case _ =>
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def tco(code: List[T]): List[T]
|
||||
|
||||
def isBadExtractedCodeHead(head: T): Boolean
|
||||
|
||||
def isBadExtractedCodeLast(head: T): Boolean
|
||||
def isBadExtractedCodeLast(last: T): Boolean
|
||||
|
||||
def getJump(line: T): Option[String]
|
||||
|
||||
@ -214,6 +286,10 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila
|
||||
|
||||
def createLabel(name: String): T
|
||||
|
||||
def renumerateLabels(code: List[T], temporary: Boolean): List[T]
|
||||
|
||||
def checkIfLabelsAreInternal(snippet: List[T], code: List[T]): Boolean
|
||||
|
||||
def bySegment(compiledFunctions: mutable.Map[String, CompiledFunction[T]]): Map[String, Map[String, Either[String, CodeAndAlignment[T]]]] = {
|
||||
compiledFunctions.flatMap {
|
||||
case (name, NormalCompiledFunction(segment, code, false, alignment)) => Some((segment, name, Right(CodeAndAlignment(code, alignment)))) // TODO
|
||||
|
@ -1,11 +1,13 @@
|
||||
package millfork.output
|
||||
|
||||
import millfork.CompilationOptions
|
||||
import millfork.assembly.mos.{AddrMode, AssemblyLine, Opcode}
|
||||
import millfork.assembly.mos.{AddrMode, AssemblyLine, Opcode, OpcodeClasses}
|
||||
import millfork.env.{Environment, Label, MemoryAddressConstant}
|
||||
import Opcode._
|
||||
import millfork.assembly.mos.AddrMode._
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
/**
|
||||
* @author Karol Stasiak
|
||||
*/
|
||||
@ -43,24 +45,71 @@ class MosDeduplicate(env: Environment, options: CompilationOptions) extends Dedu
|
||||
SED, CLD, SEC, CLC, CLV, SEI, CLI, SEP, REP,
|
||||
HuSAX, SAY, SXY,
|
||||
CLA, CLX, CLY,
|
||||
JMP, BRA, BEQ, BNE, BMI, BCC, BCS, BVC, BVS, LABEL,
|
||||
)
|
||||
|
||||
private val badAddressingModes = Set(Stack, IndexedSY, Relative)
|
||||
private val badAddressingModes = Set(Stack, IndexedSY, AbsoluteIndexedX, Indirect, LongIndirect)
|
||||
|
||||
override def isExtractable(line: AssemblyLine): Boolean =
|
||||
goodOpcodes(line.opcode) && !badAddressingModes(line.addrMode)
|
||||
|
||||
override def isBadExtractedCodeHead(head: AssemblyLine): Boolean = false
|
||||
|
||||
override def isBadExtractedCodeLast(head: AssemblyLine): Boolean = false
|
||||
override def isBadExtractedCodeLast(last: AssemblyLine): Boolean = false
|
||||
|
||||
override def createCall(functionName: String): AssemblyLine = AssemblyLine.absolute(Opcode.JSR, Label(functionName))
|
||||
|
||||
override def createReturn(): AssemblyLine = AssemblyLine.implied(RTS)
|
||||
|
||||
def rtsPrecededByDiscards(xs: List[AssemblyLine]): Option[List[AssemblyLine]] = {
|
||||
xs match {
|
||||
case AssemblyLine(op, _, _, _) :: xs if OpcodeClasses.NoopDiscardsFlags(op) => rtsPrecededByDiscards(xs)
|
||||
case AssemblyLine(RTS, _, _, _) :: xs => Some(xs)
|
||||
case _ => None
|
||||
}
|
||||
}
|
||||
|
||||
override def tco(code: List[AssemblyLine]): List[AssemblyLine] = code match {
|
||||
case (call@AssemblyLine(JSR, Absolute, _, _)) :: AssemblyLine(RTS, _, _, _) :: xs => call.copy(opcode = JMP) :: tco(xs)
|
||||
case (call@AssemblyLine(JSR, Absolute | LongAbsolute, _, _)) :: xs => rtsPrecededByDiscards(xs) match {
|
||||
case Some(rest) => call.copy(opcode = JMP) :: tco(rest)
|
||||
case _ => call :: tco(xs)
|
||||
}
|
||||
case x :: xs => x :: tco(xs)
|
||||
case Nil => Nil
|
||||
}
|
||||
|
||||
override def renumerateLabels(code: List[AssemblyLine], temporary: Boolean): List[AssemblyLine] = {
|
||||
val map = mutable.Map[String, String]()
|
||||
var counter = 0
|
||||
code.foreach{
|
||||
case AssemblyLine(LABEL, _, MemoryAddressConstant(Label(x)), _) if x.startsWith(".") =>
|
||||
map(x) = if (temporary) ".ddtmp__" + counter else env.nextLabel("dd")
|
||||
counter += 1
|
||||
case _ =>
|
||||
}
|
||||
code.map{
|
||||
case l@AssemblyLine(_, _, MemoryAddressConstant(Label(x)), _) if map.contains(x) =>
|
||||
l.copy(parameter = MemoryAddressConstant(Label(map(x))))
|
||||
case l => l
|
||||
}
|
||||
}
|
||||
|
||||
def checkIfLabelsAreInternal(snippet: List[AssemblyLine], wholeCode: List[AssemblyLine]): Boolean = {
|
||||
val myLabels = mutable.Set[String]()
|
||||
val useCount = mutable.Map[String, Int]()
|
||||
snippet.foreach{
|
||||
case AssemblyLine(LABEL, _, MemoryAddressConstant(Label(x)), _) =>
|
||||
myLabels += x
|
||||
case AssemblyLine(_, _, MemoryAddressConstant(Label(x)), _) =>
|
||||
useCount(x) = useCount.getOrElse(x, 0) - 1
|
||||
case _ =>
|
||||
}
|
||||
wholeCode.foreach {
|
||||
case AssemblyLine(op, _, MemoryAddressConstant(Label(x)), _) if op != LABEL && myLabels(x) =>
|
||||
useCount(x) = useCount.getOrElse(x, 0) + 1
|
||||
case _ =>
|
||||
}
|
||||
useCount.values.forall(_ == 0)
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -6,6 +6,8 @@ import millfork.env.{Environment, Label, MemoryAddressConstant}
|
||||
import ZOpcode._
|
||||
import millfork.node.ZRegister.SP
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
/**
|
||||
* @author Karol Stasiak
|
||||
*/
|
||||
@ -39,7 +41,7 @@ class Z80Deduplicate(env: Environment, options: CompilationOptions) extends Dedu
|
||||
IN_IMM, OUT_IMM, IN_C, OUT_C,
|
||||
LD_AHLI, LD_AHLD, LD_HLIA, LD_HLDA,
|
||||
LDH_AC, LDH_AD, LDH_CA, LDH_DA,
|
||||
CALL,
|
||||
CALL, JP, JR, LABEL,
|
||||
) ++ ZOpcodeClasses.AllSingleBit
|
||||
|
||||
private val conditionallyGoodOpcodes = Set(
|
||||
@ -58,7 +60,7 @@ class Z80Deduplicate(env: Environment, options: CompilationOptions) extends Dedu
|
||||
|
||||
override def isBadExtractedCodeHead(head: ZLine): Boolean = false
|
||||
|
||||
override def isBadExtractedCodeLast(head: ZLine): Boolean = head.opcode match {
|
||||
override def isBadExtractedCodeLast(last: ZLine): Boolean = last.opcode match {
|
||||
case EI | DI | IM => true
|
||||
case _ => false
|
||||
}
|
||||
@ -67,9 +69,54 @@ class Z80Deduplicate(env: Environment, options: CompilationOptions) extends Dedu
|
||||
|
||||
override def createReturn(): ZLine = ZLine.implied(RET)
|
||||
|
||||
def retPrecededByDiscards(xs: List[ZLine]): Option[List[ZLine]] = {
|
||||
xs match {
|
||||
case ZLine(op, _, _, _) :: xs if ZOpcodeClasses.NoopDiscards(op) => retPrecededByDiscards(xs)
|
||||
case ZLine(RET, _, _, _) :: xs => Some(xs)
|
||||
case _ => None
|
||||
}
|
||||
}
|
||||
|
||||
override def tco(code: List[ZLine]): List[ZLine] = code match {
|
||||
case (call@ZLine(CALL, _, _, _)) :: ZLine(RET, NoRegisters, _, _) :: xs => call.copy(opcode = JP) :: tco(xs)
|
||||
case (call@ZLine(CALL, _, _, _)) :: xs => retPrecededByDiscards(xs) match {
|
||||
case Some(rest) =>call.copy(opcode = JP) :: tco(rest)
|
||||
case _ => call :: tco(xs)
|
||||
}
|
||||
case x :: xs => x :: tco(xs)
|
||||
case Nil => Nil
|
||||
}
|
||||
|
||||
override def renumerateLabels(code: List[ZLine], temporary: Boolean): List[ZLine] = {
|
||||
val map = mutable.Map[String, String]()
|
||||
var counter = 0
|
||||
code.foreach{
|
||||
case ZLine(LABEL, _, MemoryAddressConstant(Label(x)), _) if x.startsWith(".") =>
|
||||
map(x) = if (temporary) ".ddtmp__" + counter else env.nextLabel("dd")
|
||||
counter += 1
|
||||
case _ =>
|
||||
}
|
||||
code.map{
|
||||
case l@ZLine(_, _, MemoryAddressConstant(Label(x)), _) if map.contains(x) =>
|
||||
l.copy(parameter = MemoryAddressConstant(Label(map(x))))
|
||||
case l => l
|
||||
}
|
||||
}
|
||||
|
||||
def checkIfLabelsAreInternal(snippet: List[ZLine], wholeCode: List[ZLine]): Boolean = {
|
||||
val myLabels = mutable.Set[String]()
|
||||
val useCount = mutable.Map[String, Int]()
|
||||
snippet.foreach{
|
||||
case ZLine(LABEL, _, MemoryAddressConstant(Label(x)), _) =>
|
||||
myLabels += x
|
||||
case ZLine(_, _, MemoryAddressConstant(Label(x)), _) =>
|
||||
useCount(x) = useCount.getOrElse(x, 0) - 1
|
||||
case _ =>
|
||||
}
|
||||
wholeCode.foreach {
|
||||
case ZLine(op, _, MemoryAddressConstant(Label(x)), _) if op != LABEL && myLabels(x) =>
|
||||
useCount(x) = useCount.getOrElse(x, 0) + 1
|
||||
case _ =>
|
||||
}
|
||||
useCount.values.forall(_ == 0)
|
||||
}
|
||||
}
|
||||
|
@ -69,4 +69,36 @@ class DeduplicationSuite extends FunSuite with Matchers {
|
||||
m.readMedium(0xc000) should equal(0x1FB00)
|
||||
}
|
||||
}
|
||||
|
||||
test("Loop subroutine extraction") {
|
||||
EmuSizeOptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Intel8080)(
|
||||
"""
|
||||
| array output [8] @$c000
|
||||
| void main() {
|
||||
| f(2)
|
||||
| g(3)
|
||||
| h(6)
|
||||
| }
|
||||
| noinline void f(byte x) {
|
||||
| byte i
|
||||
| for i,0,until,output.length {
|
||||
| output[i] = x
|
||||
| }
|
||||
| }
|
||||
| noinline void g(byte x) {
|
||||
| byte i
|
||||
| for i,0,until,output.length {
|
||||
| output[i] = x
|
||||
| }
|
||||
| }
|
||||
| noinline void h(byte x) {
|
||||
| byte i
|
||||
| for i,0,until,output.length {
|
||||
| output[i] = x
|
||||
| }
|
||||
| }
|
||||
""".stripMargin) {m =>
|
||||
m.readByte(0xc000) should equal(6)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user