1
0
mirror of https://github.com/KarolS/millfork.git synced 2024-06-25 19:29:49 +00:00

Deduplicate more complex code. Better deduplication.

This commit is contained in:
Karol Stasiak 2018-08-08 01:53:47 +02:00
parent b01c440cf0
commit 2af8304512
7 changed files with 234 additions and 23 deletions

View File

@ -20,6 +20,8 @@ where `11111` is a sequential number and `xx` is the type:
* `cp` equality comparison for larger types
* `dd` labels renamed by code deduplication
* `de` decrement for larger types
* `do` beginning of a `do-while` statement

View File

@ -178,7 +178,8 @@ case class ZLine(opcode: ZOpcode.Value, registers: ZRegisters, parameter: Consta
import ZRegister._
val inherent = opcode match {
case BYTE => 1
case d if ZOpcodeClasses.NoopDiscards(d) => 0
case LABEL => return 0
case d if ZOpcodeClasses.NoopDiscards(d) => return 0
case JP => registers match {
case OneRegister(HL | IX | IY) => 1
case _ => 2

View File

@ -274,7 +274,8 @@ abstract class AbstractAssembler[T <: AbstractCode](private val program: Program
val codeAllocators = platform.codeAllocators.mapValues(new VariableAllocator(Nil, _))
var justAfterCode = platform.codeAllocators.mapValues(a => a.startAt)
compiledFunctions.toList.sortBy{case (name, cf) => if (name == "main") 0 -> "" else cf.orderKey}.foreach {
val sortedCompilerFunctions = compiledFunctions.toList.sortBy { case (name, cf) => if (name == "main") 0 -> "" else cf.orderKey }
sortedCompilerFunctions.foreach {
case (_, NormalCompiledFunction(_, _, true, _)) =>
// already done before
case (name, NormalCompiledFunction(bank, code, false, alignment)) =>
@ -282,9 +283,12 @@ abstract class AbstractAssembler[T <: AbstractCode](private val program: Program
val index = codeAllocators(bank).allocateBytes(mem.banks(bank), options, size, initialized = true, writeable = false, location = AllocationLocation.High, alignment = alignment)
labelMap(name) = index
justAfterCode += bank -> outputFunction(bank, code, index, assembly, options)
case (_, NonexistentFunction()) =>
case _ =>
}
sortedCompilerFunctions.foreach {
case (name, RedirectedFunction(_, target, offset)) =>
labelMap(name) = labelMap(target) + offset
case _ =>
}
if (options.flag(CompilationFlag.LUnixRelocatableCode)) {

View File

@ -18,6 +18,14 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila
}
runStage(compiledFunctions, deduplicateIdenticalFunctions)
runStage(compiledFunctions, eliminateTailJumps)
runStage(compiledFunctions, eliminateRemainingTrivialTailJumps)
fixDoubleRedirects(compiledFunctions)
// println(compiledFunctions.map {
// case (k, v) => k + " " + (v match {
// case _: NormalCompiledFunction[_] => "NormalCompiledFunction"
// case _ => v.toString
// })
// }.mkString(" ; "))
}
def runStage(compiledFunctions: mutable.Map[String, CompiledFunction[T]],
@ -32,18 +40,21 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila
def extractCommonCode(segmentName: String, segContents: Map[String, Either[String, CodeAndAlignment[T]]]): Seq[(String, CompiledFunction[T])] = {
var result = ListBuffer[(String, CompiledFunction[T])]()
val chunks = segContents.flatMap{
val snippets: Seq[(List[T], CodeChunk[T])] = segContents.toSeq.flatMap {
case (_, Left(_)) => Nil
case (functionName, Right(CodeAndAlignment(code, _))) =>
if (options.flag(CompilationFlag.OptimizeForSize)) {
getExtractableSnippets(functionName, code)
getExtractableSnippets(functionName, code).map(code -> _)
} else Nil
}.flatMap { chunk =>
for {
start <- chunk.code.indices
end <- start + 1 to chunk.code.length
} yield CodeChunk(chunk.functionName, chunk.offset + start, chunk.offset + end)(chunk.code.slice(start, end))
}.filter(_.codeSizeInBytes > 3).groupBy(_.code).filter{
}
val chunks: Seq[CodeChunk[T]] = snippets.flatMap { case (wholeCode, snippet) =>
for {
start <- snippet.code.indices
end <- start + 1 to snippet.code.length
} yield wholeCode -> CodeChunk(snippet.functionName, snippet.offset + start, snippet.offset + end)(snippet.code.slice(start, end))
}.map(_._2).filter(_.codeSizeInBytes > 3).groupBy { chunk =>
renumerateLabels(chunk.code, temporary = true)
}.filter {
case (code, _) =>
if (isBadExtractedCodeHead(code.head)) false
else if (isBadExtractedCodeLast(code.last)) false
@ -62,7 +73,7 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila
x <- set
y <- set
if x != y
} yield x & y).forall(_ == false)).toSeq.map(_.groupBy(_.code).filter(_._2.size >= 2).mapValues(_.toSeq)).filter(_.nonEmpty).map { map =>
} yield x & y).forall(_ == false)).toSeq.map(_.groupBy(chunk => renumerateLabels(chunk.code, temporary = true)).filter(_._2.size >= 2).mapValues(_.toSeq)).filter(_.nonEmpty).map { map =>
map.foldLeft(0) {
(sum, entry) =>
val chunkSize = entry._2.head.codeSizeInBytes
@ -96,7 +107,8 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila
}
for((code, instances) <- best._2) {
val newName = env.nextLabel("xc")
result += newName -> NormalCompiledFunction(segmentName, createLabel(newName) :: tco(code :+ createReturn), hasFixedAddress = false, alignment = NoAlignment)
val newCode = createLabel(newName) :: tco(renumerateLabels(instances.head.code, temporary = false) :+ createReturn)
result += newName -> NormalCompiledFunction(segmentName, newCode, hasFixedAddress = false, alignment = NoAlignment)
for(instance <- instances) {
toReplace(instance.functionName)(instance.offset) = newName
for (i <- instance.offset + 1 until instance.endOffset) {
@ -127,7 +139,11 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila
var result = ListBuffer[(String, CompiledFunction[T])]()
val identicalFunctions = segContents.flatMap{
case (name, code) => code.toOption.map(c => name -> actualCode(name, c.code))
}.groupBy(_._2).values.toSeq.map(_.keySet).filter(set => set.size > 1)
}.filter{
case (_, code) => checkIfLabelsAreInternal(code, code)
}.groupBy{
case (_, code) => renumerateLabels(code, temporary = true)
}.values.toSeq.map(_.keySet).filter(set => set.size > 1)
for(set <- identicalFunctions) {
val representative = if (set("main")) "main" else set.head
options.log.debug(s"Functions [${set.mkString(",")}] are identical")
@ -153,7 +169,7 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila
var result: String = to
val visited = mutable.Set[String]()
do {
segContents.get(to) match {
segContents.get(result) match {
case Some(Left(next)) =>
if (visited(next)) return None
visited += result
@ -200,11 +216,67 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila
result.toSeq
}
/**
  * Replaces functions that consist of exactly two lines, the last of which is a jump
  * (as reported by `getJump`) to another function of the same segment, with redirects.
  *
  * For every such function `from` jumping to `to` this emits:
  *  - `from -> RedirectedFunction(segmentName, to, 0)`, and
  *  - if `follow` resolves `to` to a function `actualTo`, a new version of `actualTo`
  *    whose code is prefixed with a label named `from`.
  *
  * Functions jumping to themselves or to `main` are left alone.
  *
  * @param segmentName name of the segment being processed
  * @param segContents functions of the segment: `Left` is a redirect to another function,
  *                    `Right` is actual code with its alignment
  * @return the replacement entries, in the order they were produced
  */
def eliminateRemainingTrivialTailJumps(segmentName: String, segContents: Map[String, Either[String, CodeAndAlignment[T]]]): Seq[(String, CompiledFunction[T])] = {
  val result = ListBuffer[(String, CompiledFunction[T])]()
  val fallThroughList = segContents.flatMap {
    case (name, Right(CodeAndAlignment(code, _))) =>
      if (code.length != 2) None
      else getJump(code.last)
        .filter(segContents.contains)
        .filter(_ != name)
        .filter(_ != "main")
        .map(name -> _)
    case _ => None
  }
  val fallthroughPredecessors = fallThroughList.groupBy(_._2).mapValues(_.keySet)
  fallthroughPredecessors.foreach {
    case (to, froms) =>
      for (from <- froms) {
        options.log.debug(s"Trivial fallthrough from $from to $to")
        result += from -> RedirectedFunction(segmentName, to, 0)
        follow(segContents, to) match {
          case Some(actualTo) =>
            options.log.trace(s"which physically is $actualTo")
            // Start from the MOST RECENT rewritten version of the target, so that labels
            // prepended for earlier predecessors are kept. Taking the first entry instead
            // would lose every label but the first one once a third predecessor arrives.
            // Entries that are not NormalCompiledFunction are skipped rather than
            // failing with a MatchError.
            val value = result.reverseIterator.collectFirst {
              case (name, NormalCompiledFunction(_, code, _, alignment)) if name == actualTo =>
                CodeAndAlignment(code, alignment)
            }.getOrElse(segContents(actualTo).right.get)
            result += actualTo -> NormalCompiledFunction(segmentName,
              createLabel(from) :: value.code,
              hasFixedAddress = false,
              alignment = value.alignment
            )
          case _ =>
        }
      }
  }
  result.toSeq
}
/**
  * Shortens chains of redirects in place.
  *
  * Whenever a redirected function points at another redirected function, it is replaced
  * by a copy of that second redirect with both offsets added together. Passes are repeated
  * until a whole pass makes no replacement.
  *
  * @param compiledFunctions the function table, modified in place
  */
def fixDoubleRedirects(compiledFunctions: mutable.Map[String, CompiledFunction[T]]): Unit = {
  var changed = false
  do {
    changed = false
    // Iterate over a snapshot of the names, since entries are overwritten during the pass.
    compiledFunctions.keys.toSeq.foreach { name =>
      compiledFunctions(name) match {
        case RedirectedFunction(_, intermediate, ownOffset) =>
          compiledFunctions.get(intermediate) match {
            case Some(next: RedirectedFunction[T]) =>
              compiledFunctions(name) = next.copy(offset = next.offset + ownOffset)
              changed = true
            case _ =>
          }
        case _ =>
      }
    }
  } while (changed)
}
def tco(code: List[T]): List[T]
def isBadExtractedCodeHead(head: T): Boolean
def isBadExtractedCodeLast(head: T): Boolean
def isBadExtractedCodeLast(last: T): Boolean
def getJump(line: T): Option[String]
@ -214,6 +286,10 @@ abstract class Deduplicate[T <: AbstractCode](env: Environment, options: Compila
def createLabel(name: String): T
/**
  * Returns a copy of `code` with labels renamed consistently at their definitions and uses.
  *
  * @param code      the code to process
  * @param temporary selects the naming scheme; callers in this class pass `true` when the
  *                  result is only used as a grouping key and `false` when it becomes
  *                  the code of a new function
  */
def renumerateLabels(code: List[T], temporary: Boolean): List[T]
/**
  * Checks how labels are shared between `snippet` and the code around it.
  * NOTE(review): judging by the platform implementations, `true` means the snippet
  * neither refers to labels defined outside it nor has its labels referenced from
  * the rest of `code` -- confirm against each implementation.
  *
  * @param snippet the candidate code fragment
  * @param code    the whole code the fragment is considered against
  */
def checkIfLabelsAreInternal(snippet: List[T], code: List[T]): Boolean
def bySegment(compiledFunctions: mutable.Map[String, CompiledFunction[T]]): Map[String, Map[String, Either[String, CodeAndAlignment[T]]]] = {
compiledFunctions.flatMap {
case (name, NormalCompiledFunction(segment, code, false, alignment)) => Some((segment, name, Right(CodeAndAlignment(code, alignment)))) // TODO

View File

@ -1,11 +1,13 @@
package millfork.output
import millfork.CompilationOptions
import millfork.assembly.mos.{AddrMode, AssemblyLine, Opcode}
import millfork.assembly.mos.{AddrMode, AssemblyLine, Opcode, OpcodeClasses}
import millfork.env.{Environment, Label, MemoryAddressConstant}
import Opcode._
import millfork.assembly.mos.AddrMode._
import scala.collection.mutable
/**
* @author Karol Stasiak
*/
@ -43,24 +45,71 @@ class MosDeduplicate(env: Environment, options: CompilationOptions) extends Dedu
SED, CLD, SEC, CLC, CLV, SEI, CLI, SEP, REP,
HuSAX, SAY, SXY,
CLA, CLX, CLY,
JMP, BRA, BEQ, BNE, BMI, BCC, BCS, BVC, BVS, LABEL,
)
private val badAddressingModes = Set(Stack, IndexedSY, Relative)
private val badAddressingModes = Set(Stack, IndexedSY, AbsoluteIndexedX, Indirect, LongIndirect)
/**
  * A line is extractable when its opcode is in `goodOpcodes` and its addressing mode
  * is not in `badAddressingModes`.
  */
override def isExtractable(line: AssemblyLine): Boolean =
goodOpcodes(line.opcode) && !badAddressingModes(line.addrMode)
override def isBadExtractedCodeHead(head: AssemblyLine): Boolean = false
override def isBadExtractedCodeLast(head: AssemblyLine): Boolean = false
override def isBadExtractedCodeLast(last: AssemblyLine): Boolean = false
/** Builds the call instruction: an absolute-addressed `JSR` to the label `functionName`. */
override def createCall(functionName: String): AssemblyLine = AssemblyLine.absolute(Opcode.JSR, Label(functionName))
/** Builds the return instruction: an implied-mode `RTS`. */
override def createReturn(): AssemblyLine = AssemblyLine.implied(RTS)
/**
  * Skips the leading lines whose opcode is in `OpcodeClasses.NoopDiscardsFlags` and then
  * expects an `RTS`.
  *
  * @param xs the lines to inspect
  * @return the lines following the `RTS`, or `None` if the first line that is not skipped
  *         is not an `RTS` (or there is no such line)
  */
def rtsPrecededByDiscards(xs: List[AssemblyLine]): Option[List[AssemblyLine]] =
  xs.dropWhile(line => OpcodeClasses.NoopDiscardsFlags(line.opcode)) match {
    case AssemblyLine(RTS, _, _, _) :: afterRts => Some(afterRts)
    case _ => None
  }
/**
  * Tail call optimization: turns a `JSR` that is followed by an `RTS` into a `JMP`
  * and drops the `RTS`.
  *
  * An absolute `JSR` immediately followed by `RTS` is handled directly. An absolute or
  * long-absolute `JSR` is also converted when `rtsPrecededByDiscards` accepts the lines
  * after it; in that case the skipped lines and the `RTS` are removed. All other lines
  * are kept unchanged.
  */
override def tco(code: List[AssemblyLine]): List[AssemblyLine] = code match {
  case Nil =>
    Nil
  case (jsr@AssemblyLine(JSR, Absolute, _, _)) :: AssemblyLine(RTS, _, _, _) :: rest =>
    jsr.copy(opcode = JMP) :: tco(rest)
  case (jsr@AssemblyLine(JSR, Absolute | LongAbsolute, _, _)) :: rest =>
    rtsPrecededByDiscards(rest).fold(jsr :: tco(rest)) { afterRts =>
      jsr.copy(opcode = JMP) :: tco(afterRts)
    }
  case line :: rest =>
    line :: tco(rest)
}
/**
  * Renames every label that is defined in `code` by a `LABEL` line and whose name starts
  * with a dot, both at its definition and wherever a line's parameter refers to it.
  *
  * @param code      the code to process
  * @param temporary if true, the n-th such label (counting from 0, in order of definition)
  *                  is renamed to `.ddtmp__n`; otherwise each one gets a fresh name from
  *                  `env.nextLabel("dd")`
  * @return the code with the labels renamed; other lines are returned as they are
  */
override def renumerateLabels(code: List[AssemblyLine], temporary: Boolean): List[AssemblyLine] = {
  val localLabels = code.collect {
    case AssemblyLine(LABEL, _, MemoryAddressConstant(Label(name)), _) if name.startsWith(".") => name
  }
  // Built in definition order; if a name is defined twice, the later entry wins.
  val renames = localLabels.zipWithIndex.map { case (name, index) =>
    name -> (if (temporary) ".ddtmp__" + index else env.nextLabel("dd"))
  }.toMap
  code.map {
    case line@AssemblyLine(_, _, MemoryAddressConstant(Label(name)), _) if renames.contains(name) =>
      line.copy(parameter = MemoryAddressConstant(Label(renames(name))))
    case line => line
  }
}
/**
  * Compares label references inside `snippet` with label references in `wholeCode`.
  *
  * Every non-`LABEL` line of the snippet whose parameter is a label counts as -1 for that
  * label. Every non-`LABEL` line of `wholeCode` that refers to a label defined in the
  * snippet counts as +1. The result is true only if every counted label balances to zero.
  * Consequently a snippet referring to a label it does not define yields false, and so
  * does a snippet whose own label is referenced more often in `wholeCode` than inside
  * the snippet.
  *
  * @param snippet   the candidate fragment
  * @param wholeCode the code the fragment is checked against
  */
def checkIfLabelsAreInternal(snippet: List[AssemblyLine], wholeCode: List[AssemblyLine]): Boolean = {
  val definedInSnippet = snippet.collect {
    case AssemblyLine(LABEL, _, MemoryAddressConstant(Label(name)), _) => name
  }.toSet
  val balance = mutable.Map[String, Int]().withDefaultValue(0)
  snippet.collect {
    case AssemblyLine(op, _, MemoryAddressConstant(Label(name)), _) if op != LABEL => name
  }.foreach(name => balance(name) -= 1)
  wholeCode.collect {
    case AssemblyLine(op, _, MemoryAddressConstant(Label(name)), _) if op != LABEL && definedInSnippet(name) => name
  }.foreach(name => balance(name) += 1)
  balance.values.forall(_ == 0)
}
}

View File

@ -6,6 +6,8 @@ import millfork.env.{Environment, Label, MemoryAddressConstant}
import ZOpcode._
import millfork.node.ZRegister.SP
import scala.collection.mutable
/**
* @author Karol Stasiak
*/
@ -39,7 +41,7 @@ class Z80Deduplicate(env: Environment, options: CompilationOptions) extends Dedu
IN_IMM, OUT_IMM, IN_C, OUT_C,
LD_AHLI, LD_AHLD, LD_HLIA, LD_HLDA,
LDH_AC, LDH_AD, LDH_CA, LDH_DA,
CALL,
CALL, JP, JR, LABEL,
) ++ ZOpcodeClasses.AllSingleBit
private val conditionallyGoodOpcodes = Set(
@ -58,7 +60,7 @@ class Z80Deduplicate(env: Environment, options: CompilationOptions) extends Dedu
override def isBadExtractedCodeHead(head: ZLine): Boolean = false
override def isBadExtractedCodeLast(head: ZLine): Boolean = head.opcode match {
override def isBadExtractedCodeLast(last: ZLine): Boolean = last.opcode match {
case EI | DI | IM => true
case _ => false
}
@ -67,9 +69,54 @@ class Z80Deduplicate(env: Environment, options: CompilationOptions) extends Dedu
/** Builds the return instruction: an implied `RET`. */
override def createReturn(): ZLine = ZLine.implied(RET)
/**
  * Skips the leading lines whose opcode is in `ZOpcodeClasses.NoopDiscards` and then
  * expects a plain `RET`.
  *
  * @param xs the lines to inspect
  * @return the lines following the `RET`, or `None` if the first line that is not skipped
  *         is not a `RET` with `NoRegisters`
  */
def retPrecededByDiscards(xs: List[ZLine]): Option[List[ZLine]] = {
  xs match {
    case ZLine(op, _, _, _) :: tail if ZOpcodeClasses.NoopDiscards(op) => retPrecededByDiscards(tail)
    // Only a RET with NoRegisters may be folded into a tail jump. A RET carrying anything
    // else in its register field (presumably a condition -- confirm) does not always
    // return, so replacing CALL + RET with JP would change control flow.
    case ZLine(RET, NoRegisters, _, _) :: tail => Some(tail)
    case _ => None
  }
}
/**
  * Tail call optimization: turns a `CALL` that is followed by a `RET` into a `JP`
  * and drops the `RET`.
  *
  * A `CALL` immediately followed by a `RET` with `NoRegisters` is handled directly.
  * A `CALL` is also converted when `retPrecededByDiscards` accepts the lines after it;
  * in that case the skipped lines and the `RET` are removed. All other lines are kept
  * unchanged.
  */
override def tco(code: List[ZLine]): List[ZLine] = code match {
  case Nil =>
    Nil
  case (callLine@ZLine(CALL, _, _, _)) :: ZLine(RET, NoRegisters, _, _) :: rest =>
    callLine.copy(opcode = JP) :: tco(rest)
  case (callLine@ZLine(CALL, _, _, _)) :: rest =>
    retPrecededByDiscards(rest).fold(callLine :: tco(rest)) { afterRet =>
      callLine.copy(opcode = JP) :: tco(afterRet)
    }
  case line :: rest =>
    line :: tco(rest)
}
/**
  * Renames every label that is defined in `code` by a `LABEL` line and whose name starts
  * with a dot, both at its definition and wherever a line's parameter refers to it.
  *
  * @param code      the code to process
  * @param temporary if true, the n-th such label (counting from 0, in order of definition)
  *                  is renamed to `.ddtmp__n`; otherwise each one gets a fresh name from
  *                  `env.nextLabel("dd")`
  * @return the code with the labels renamed; other lines are returned as they are
  */
override def renumerateLabels(code: List[ZLine], temporary: Boolean): List[ZLine] = {
  val localLabels = code.collect {
    case ZLine(LABEL, _, MemoryAddressConstant(Label(name)), _) if name.startsWith(".") => name
  }
  // Built in definition order; if a name is defined twice, the later entry wins.
  val renames = localLabels.zipWithIndex.map { case (name, index) =>
    name -> (if (temporary) ".ddtmp__" + index else env.nextLabel("dd"))
  }.toMap
  code.map {
    case line@ZLine(_, _, MemoryAddressConstant(Label(name)), _) if renames.contains(name) =>
      line.copy(parameter = MemoryAddressConstant(Label(renames(name))))
    case line => line
  }
}
/**
  * Compares label references inside `snippet` with label references in `wholeCode`.
  *
  * Every non-`LABEL` line of the snippet whose parameter is a label counts as -1 for that
  * label. Every non-`LABEL` line of `wholeCode` that refers to a label defined in the
  * snippet counts as +1. The result is true only if every counted label balances to zero.
  * Consequently a snippet referring to a label it does not define yields false, and so
  * does a snippet whose own label is referenced more often in `wholeCode` than inside
  * the snippet.
  *
  * @param snippet   the candidate fragment
  * @param wholeCode the code the fragment is checked against
  */
def checkIfLabelsAreInternal(snippet: List[ZLine], wholeCode: List[ZLine]): Boolean = {
  val definedInSnippet = snippet.collect {
    case ZLine(LABEL, _, MemoryAddressConstant(Label(name)), _) => name
  }.toSet
  val balance = mutable.Map[String, Int]().withDefaultValue(0)
  snippet.collect {
    case ZLine(op, _, MemoryAddressConstant(Label(name)), _) if op != LABEL => name
  }.foreach(name => balance(name) -= 1)
  wholeCode.collect {
    case ZLine(op, _, MemoryAddressConstant(Label(name)), _) if op != LABEL && definedInSnippet(name) => name
  }.foreach(name => balance(name) += 1)
  balance.values.forall(_ == 0)
}
}

View File

@ -69,4 +69,36 @@ class DeduplicationSuite extends FunSuite with Matchers {
m.readMedium(0xc000) should equal(0x1FB00)
}
}
// Three identical loop-containing functions are compiled with size optimization;
// whatever the deduplicator does with them, the program must still fill the whole array.
test("Loop subroutine extraction") {
  EmuSizeOptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Intel8080)(
    """
| array output [8] @$c000
| void main() {
| f(2)
| g(3)
| h(6)
| }
| noinline void f(byte x) {
| byte i
| for i,0,until,output.length {
| output[i] = x
| }
| }
| noinline void g(byte x) {
| byte i
| for i,0,until,output.length {
| output[i] = x
| }
| }
| noinline void h(byte x) {
| byte i
| for i,0,until,output.length {
| output[i] = x
| }
| }
""".stripMargin) { m =>
    // h(6) runs last, so every one of the 8 elements must hold 6. Checking only the
    // first element would not notice a loop that stopped after a single iteration.
    for (i <- 0 until 8) {
      m.readByte(0xc000 + i) should equal(6)
    }
  }
}
}