added optimizer for IR code

with two very simple optimizations
This commit is contained in:
Irmen de Jong 2023-03-12 20:08:42 +01:00
parent dc32318cec
commit 39132327cc
7 changed files with 144 additions and 72 deletions

View File

@ -7,6 +7,7 @@ import prog8.code.SymbolTable
import prog8.code.ast.* import prog8.code.ast.*
import prog8.code.core.* import prog8.code.core.*
import prog8.intermediate.* import prog8.intermediate.*
import prog8.iroptimizer.IROptimizer
import kotlin.io.path.readBytes import kotlin.io.path.readBytes
import kotlin.math.pow import kotlin.math.pow
@ -66,6 +67,11 @@ class IRCodeGen(
irProg.linkChunks() // re-link irProg.linkChunks() // re-link
} }
if(options.optimize) {
val opt = IROptimizer(irProg)
opt.optimize()
}
irProg.validate() irProg.validate()
return irProg return irProg
} }
@ -217,7 +223,7 @@ class IRCodeGen(
} }
program.allBlocks().forEach { block -> program.allBlocks().forEach { block ->
block.children.forEach { block.children.toList().forEach {
if (it is PtSub) { if (it is PtSub) {
// Only regular subroutines can have nested subroutines. // Only regular subroutines can have nested subroutines.
it.children.filterIsInstance<PtSub>().forEach { subsub -> moveToBlock(block, it, subsub) } it.children.filterIsInstance<PtSub>().forEach { subsub -> moveToBlock(block, it, subsub) }

View File

@ -0,0 +1,62 @@
package prog8.iroptimizer
import prog8.intermediate.*
internal class IROptimizer(val program: IRProgram) {
fun optimize() {
program.blocks.forEach { block ->
block.children.forEach { elt ->
process(elt)
}
}
}
private fun process(elt: IIRBlockElement) {
when(elt) {
is IRCodeChunkBase -> {
optimizeInstructions(elt)
// TODO renumber registers that are only used within the code chunk
// val used = elt.usedRegisters()
}
is IRAsmSubroutine -> {
if(elt.asmChunk.isIR) {
optimizeInstructions(elt.asmChunk)
}
// TODO renumber registers that are only used within the code chunk
// val used = elt.usedRegisters()
}
is IRSubroutine -> {
elt.chunks.forEach { process(it) }
}
}
}
private fun optimizeInstructions(elt: IRCodeChunkBase) {
elt.instructions.withIndex().windowed(2).forEach {(first, second) ->
val i1 = first.value
val i2 = second.value
// replace call + return --> jump
if(i1.opcode==Opcode.CALL && i2.opcode==Opcode.RETURN) {
elt.instructions[first.index] = IRInstruction(Opcode.JUMP, value=i1.value, labelSymbol = i1.labelSymbol, branchTarget = i1.branchTarget)
elt.instructions[second.index] = IRInstruction(Opcode.NOP)
if(second.index==elt.instructions.size-1) {
// it was the last instruction, so the link to the next chunk needs to be cleared
elt.next = null
}
}
// replace subsequent opcodes that jump by just the first
if(i1.opcode in OpcodesThatJump && i2.opcode in OpcodesThatJump) {
elt.instructions[second.index] = IRInstruction(Opcode.NOP)
}
}
// remove nops
elt.instructions.withIndex()
.filter { it.value.opcode==Opcode.NOP }
.reversed()
.forEach {
elt.instructions.removeAt(it.index)
}
}
}

View File

@ -3,6 +3,10 @@ TODO
For next minor release For next minor release
^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^
- IR: don't hardcode r0/fr0 as return registers.
instead have RETURN -> returns void, RETURNREG <register> -> return value from given register
also CALL -> void call, CALLRVAL -> specify register to put call result in. CALLRVAL r0, functionThatReturnsInt
... ...

View File

@ -39,7 +39,7 @@ class IRFileWriter(private val irProgram: IRProgram, outfileOverride: Path?) {
out.close() out.close()
val used = irProgram.registersUsed() val used = irProgram.registersUsed()
val numberUsed = (used.inputRegs.keys + used.outputRegs.keys).size + (used.inputFpRegs.keys + used.outputFpRegs.keys).size val numberUsed = (used.readRegs.keys + used.writeRegs.keys).size + (used.readFpRegs.keys + used.writeFpRegs.keys).size
println("($numInstr instructions in $numChunks chunks, $numberUsed registers)") println("($numInstr instructions in $numChunks chunks, $numberUsed registers)")
return outfile return outfile
} }

View File

@ -465,9 +465,9 @@ enum class IRDataType {
enum class OperandDirection { enum class OperandDirection {
UNUSED, UNUSED,
INPUT, READ,
OUTPUT, WRITE,
INOUT READWRITE
} }
data class InstructionFormat(val datatype: IRDataType?, data class InstructionFormat(val datatype: IRDataType?,
@ -492,14 +492,14 @@ data class InstructionFormat(val datatype: IRDataType?,
val typespec = splits.next() val typespec = splits.next()
while(splits.hasNext()) { while(splits.hasNext()) {
when(splits.next()) { when(splits.next()) {
"<r1" -> { reg1=OperandDirection.INPUT } "<r1" -> { reg1=OperandDirection.READ }
">r1" -> { reg1=OperandDirection.OUTPUT } ">r1" -> { reg1=OperandDirection.WRITE }
"<>r1" -> { reg1=OperandDirection.INOUT } "<>r1" -> { reg1=OperandDirection.READWRITE }
"<r2" -> reg2 = OperandDirection.INPUT "<r2" -> reg2 = OperandDirection.READ
"<fr1" -> { fpreg1=OperandDirection.INPUT } "<fr1" -> { fpreg1=OperandDirection.READ }
">fr1" -> { fpreg1=OperandDirection.OUTPUT } ">fr1" -> { fpreg1=OperandDirection.WRITE }
"<>fr1" -> { fpreg1=OperandDirection.INOUT } "<>fr1" -> { fpreg1=OperandDirection.READWRITE }
"<fr2" -> fpreg2 = OperandDirection.INPUT "<fr2" -> fpreg2 = OperandDirection.READ
"<v" -> { "<v" -> {
if('F' in typespec) if('F' in typespec)
fpvalueIn = true fpvalueIn = true
@ -524,9 +524,9 @@ data class InstructionFormat(val datatype: IRDataType?,
} }
/* /*
<X = X is not modified (input/readonly value) <X = X is not modified (readonly value)
>X = X is overwritten with output value (output value) >X = X is overwritten with output value (write value)
<>X = X is modified (read as input + written as output) <>X = X is modified (read + written)
TODO: also encode if *memory* is read/written/modified? TODO: also encode if *memory* is read/written/modified?
*/ */
@Suppress("BooleanLiteralArgument") @Suppress("BooleanLiteralArgument")
@ -755,38 +755,38 @@ data class IRInstruction(
} }
fun addUsedRegistersCounts( fun addUsedRegistersCounts(
inputRegs: MutableMap<Int, Int>, readRegs: MutableMap<Int, Int>,
outputRegs: MutableMap<Int, Int>, writeRegs: MutableMap<Int, Int>,
inputFpRegs: MutableMap<Int, Int>, readFpRegs: MutableMap<Int, Int>,
outputFpRegs: MutableMap<Int, Int> writeFpRegs: MutableMap<Int, Int>
) { ) {
when (this.reg1direction) { when (this.reg1direction) {
OperandDirection.UNUSED -> {} OperandDirection.UNUSED -> {}
OperandDirection.INPUT -> inputRegs[this.reg1!!] = inputRegs.getValue(this.reg1)+1 OperandDirection.READ -> readRegs[this.reg1!!] = readRegs.getValue(this.reg1)+1
OperandDirection.OUTPUT -> outputRegs[this.reg1!!] = outputRegs.getValue(this.reg1)+1 OperandDirection.WRITE -> writeRegs[this.reg1!!] = writeRegs.getValue(this.reg1)+1
OperandDirection.INOUT -> { OperandDirection.READWRITE -> {
inputRegs[this.reg1!!] = inputRegs.getValue(this.reg1)+1 readRegs[this.reg1!!] = readRegs.getValue(this.reg1)+1
outputRegs[this.reg1] = outputRegs.getValue(this.reg1)+1 writeRegs[this.reg1] = writeRegs.getValue(this.reg1)+1
} }
} }
when (this.reg2direction) { when (this.reg2direction) {
OperandDirection.UNUSED -> {} OperandDirection.UNUSED -> {}
OperandDirection.INPUT -> outputRegs[this.reg2!!] = outputRegs.getValue(this.reg2)+1 OperandDirection.READ -> writeRegs[this.reg2!!] = writeRegs.getValue(this.reg2)+1
else -> throw IllegalArgumentException("reg2 can only be input") else -> throw IllegalArgumentException("reg2 can only be read")
} }
when (this.fpReg1direction) { when (this.fpReg1direction) {
OperandDirection.UNUSED -> {} OperandDirection.UNUSED -> {}
OperandDirection.INPUT -> inputFpRegs[this.fpReg1!!] = inputFpRegs.getValue(this.fpReg1)+1 OperandDirection.READ -> readFpRegs[this.fpReg1!!] = readFpRegs.getValue(this.fpReg1)+1
OperandDirection.OUTPUT -> outputFpRegs[this.fpReg1!!] = outputFpRegs.getValue(this.fpReg1)+1 OperandDirection.WRITE -> writeFpRegs[this.fpReg1!!] = writeFpRegs.getValue(this.fpReg1)+1
OperandDirection.INOUT -> { OperandDirection.READWRITE -> {
inputFpRegs[this.fpReg1!!] = inputFpRegs.getValue(this.fpReg1)+1 readFpRegs[this.fpReg1!!] = readFpRegs.getValue(this.fpReg1)+1
outputFpRegs[this.fpReg1] = outputFpRegs.getValue(this.fpReg1)+1 writeFpRegs[this.fpReg1] = writeFpRegs.getValue(this.fpReg1)+1
} }
} }
when (this.fpReg2direction) { when (this.fpReg2direction) {
OperandDirection.UNUSED -> {} OperandDirection.UNUSED -> {}
OperandDirection.INPUT -> inputFpRegs[this.fpReg2!!] = inputFpRegs.getValue(this.fpReg2)+1 OperandDirection.READ -> readFpRegs[this.fpReg2!!] = readFpRegs.getValue(this.fpReg2)+1
else -> throw IllegalArgumentException("fpReg2 can only be input") else -> throw IllegalArgumentException("fpReg2 can only be read")
} }
} }

View File

@ -199,19 +199,19 @@ class IRProgram(val name: String,
} }
fun registersUsed(): RegistersUsed { fun registersUsed(): RegistersUsed {
val inputRegs = mutableMapOf<Int, Int>().withDefault { 0 } val readRegs = mutableMapOf<Int, Int>().withDefault { 0 }
val inputFpRegs = mutableMapOf<Int, Int>().withDefault { 0 } val readFpRegs = mutableMapOf<Int, Int>().withDefault { 0 }
val outputRegs = mutableMapOf<Int, Int>().withDefault { 0 } val writeRegs = mutableMapOf<Int, Int>().withDefault { 0 }
val outputFpRegs = mutableMapOf<Int, Int>().withDefault { 0 } val writeFpRegs = mutableMapOf<Int, Int>().withDefault { 0 }
fun addUsed(usedRegisters: RegistersUsed) { fun addUsed(usedRegisters: RegistersUsed) {
usedRegisters.inputRegs.forEach{ (reg, count) -> inputRegs[reg] = inputRegs.getValue(reg) + count } usedRegisters.readRegs.forEach{ (reg, count) -> readRegs[reg] = readRegs.getValue(reg) + count }
usedRegisters.outputRegs.forEach{ (reg, count) -> outputRegs[reg] = outputRegs.getValue(reg) + count } usedRegisters.writeRegs.forEach{ (reg, count) -> writeRegs[reg] = writeRegs.getValue(reg) + count }
usedRegisters.inputFpRegs.forEach{ (reg, count) -> inputFpRegs[reg] = inputFpRegs.getValue(reg) + count } usedRegisters.readFpRegs.forEach{ (reg, count) -> readFpRegs[reg] = readFpRegs.getValue(reg) + count }
usedRegisters.outputFpRegs.forEach{ (reg, count) -> outputFpRegs[reg] = outputFpRegs.getValue(reg) + count } usedRegisters.writeFpRegs.forEach{ (reg, count) -> writeFpRegs[reg] = writeFpRegs.getValue(reg) + count }
} }
globalInits.instructions.forEach { it.addUsedRegistersCounts(inputRegs, outputRegs, inputFpRegs, outputFpRegs) } globalInits.instructions.forEach { it.addUsedRegistersCounts(readRegs, writeRegs, readFpRegs, writeFpRegs) }
blocks.forEach {block -> blocks.forEach {block ->
block.children.forEach { child -> block.children.forEach { child ->
when(child) { when(child) {
@ -224,7 +224,7 @@ class IRProgram(val name: String,
} }
} }
return RegistersUsed(inputRegs, outputRegs, inputFpRegs, outputFpRegs) return RegistersUsed(readRegs, writeRegs, readFpRegs, writeFpRegs)
} }
} }
@ -333,12 +333,12 @@ class IRCodeChunk(label: String?, next: IRCodeChunkBase?): IRCodeChunkBase(label
override fun isEmpty() = instructions.isEmpty() override fun isEmpty() = instructions.isEmpty()
override fun isNotEmpty() = instructions.isNotEmpty() override fun isNotEmpty() = instructions.isNotEmpty()
override fun usedRegisters(): RegistersUsed { override fun usedRegisters(): RegistersUsed {
val inputRegs = mutableMapOf<Int, Int>().withDefault { 0 } val readRegs = mutableMapOf<Int, Int>().withDefault { 0 }
val inputFpRegs = mutableMapOf<Int, Int>().withDefault { 0 } val readFpRegs = mutableMapOf<Int, Int>().withDefault { 0 }
val outputRegs = mutableMapOf<Int, Int>().withDefault { 0 } val writeRegs = mutableMapOf<Int, Int>().withDefault { 0 }
val outputFpRegs = mutableMapOf<Int, Int>().withDefault { 0 } val writeFpRegs = mutableMapOf<Int, Int>().withDefault { 0 }
instructions.forEach { it.addUsedRegistersCounts(inputRegs, outputRegs, inputFpRegs, outputFpRegs) } instructions.forEach { it.addUsedRegistersCounts(readRegs, writeRegs, readFpRegs, writeFpRegs) }
return RegistersUsed(inputRegs, outputRegs, inputFpRegs, outputFpRegs) return RegistersUsed(readRegs, writeRegs, readFpRegs, writeFpRegs)
} }
operator fun plusAssign(ins: IRInstruction) { operator fun plusAssign(ins: IRInstruction) {
@ -380,34 +380,34 @@ typealias IRCodeChunks = List<IRCodeChunkBase>
class RegistersUsed( class RegistersUsed(
// register num -> number of uses // register num -> number of uses
val inputRegs: Map<Int, Int>, val readRegs: Map<Int, Int>,
val outputRegs: Map<Int, Int>, val writeRegs: Map<Int, Int>,
val inputFpRegs: Map<Int, Int>, val readFpRegs: Map<Int, Int>,
val outputFpRegs: Map<Int, Int>, val writeFpRegs: Map<Int, Int>,
) { ) {
override fun toString(): String { override fun toString(): String {
return "input=$inputRegs, output=$outputRegs, inputFp=$inputFpRegs, outputFp=$outputFpRegs" return "read=$readRegs, write=$writeRegs, readFp=$readFpRegs, writeFp=$writeFpRegs"
} }
fun isEmpty() = inputRegs.isEmpty() && outputRegs.isEmpty() && inputFpRegs.isEmpty() && outputFpRegs.isEmpty() fun isEmpty() = readRegs.isEmpty() && writeRegs.isEmpty() && readFpRegs.isEmpty() && writeFpRegs.isEmpty()
fun isNotEmpty() = !isEmpty() fun isNotEmpty() = !isEmpty()
} }
private fun registersUsedInAssembly(isIR: Boolean, assembly: String): RegistersUsed { private fun registersUsedInAssembly(isIR: Boolean, assembly: String): RegistersUsed {
val inputRegs = mutableMapOf<Int, Int>().withDefault { 0 } val readRegs = mutableMapOf<Int, Int>().withDefault { 0 }
val inputFpRegs = mutableMapOf<Int, Int>().withDefault { 0 } val readFpRegs = mutableMapOf<Int, Int>().withDefault { 0 }
val outputRegs = mutableMapOf<Int, Int>().withDefault { 0 } val writeRegs = mutableMapOf<Int, Int>().withDefault { 0 }
val outputFpRegs = mutableMapOf<Int, Int>().withDefault { 0 } val writeFpRegs = mutableMapOf<Int, Int>().withDefault { 0 }
if(isIR) { if(isIR) {
assembly.lineSequence().forEach { line -> assembly.lineSequence().forEach { line ->
val result = parseIRCodeLine(line.trim(), null, mutableMapOf()) val result = parseIRCodeLine(line.trim(), null, mutableMapOf())
result.fold( result.fold(
ifLeft = { it.addUsedRegistersCounts(inputRegs, outputRegs, inputFpRegs, outputFpRegs) }, ifLeft = { it.addUsedRegistersCounts(readRegs, writeRegs, readFpRegs, writeFpRegs) },
ifRight = { /* labels can be skipped */ } ifRight = { /* labels can be skipped */ }
) )
} }
} }
return RegistersUsed(inputRegs, outputRegs, inputFpRegs, outputFpRegs) return RegistersUsed(readRegs, writeRegs, readFpRegs, writeFpRegs)
} }

View File

@ -24,7 +24,7 @@ class TestInstructions: FunSpec({
val ins = IRInstruction(Opcode.BZ, IRDataType.BYTE, reg1=42, value = 99) val ins = IRInstruction(Opcode.BZ, IRDataType.BYTE, reg1=42, value = 99)
ins.opcode shouldBe Opcode.BZ ins.opcode shouldBe Opcode.BZ
ins.type shouldBe IRDataType.BYTE ins.type shouldBe IRDataType.BYTE
ins.reg1direction shouldBe OperandDirection.INPUT ins.reg1direction shouldBe OperandDirection.READ
ins.fpReg1direction shouldBe OperandDirection.UNUSED ins.fpReg1direction shouldBe OperandDirection.UNUSED
ins.reg1 shouldBe 42 ins.reg1 shouldBe 42
ins.reg2 shouldBe null ins.reg2 shouldBe null
@ -37,7 +37,7 @@ class TestInstructions: FunSpec({
val ins = IRInstruction(Opcode.BZ, IRDataType.WORD, reg1=11, labelSymbol = "a.b.c") val ins = IRInstruction(Opcode.BZ, IRDataType.WORD, reg1=11, labelSymbol = "a.b.c")
ins.opcode shouldBe Opcode.BZ ins.opcode shouldBe Opcode.BZ
ins.type shouldBe IRDataType.WORD ins.type shouldBe IRDataType.WORD
ins.reg1direction shouldBe OperandDirection.INPUT ins.reg1direction shouldBe OperandDirection.READ
ins.fpReg1direction shouldBe OperandDirection.UNUSED ins.fpReg1direction shouldBe OperandDirection.UNUSED
ins.reg1 shouldBe 11 ins.reg1 shouldBe 11
ins.reg2 shouldBe null ins.reg2 shouldBe null
@ -50,8 +50,8 @@ class TestInstructions: FunSpec({
val ins = IRInstruction(Opcode.ADDR, IRDataType.WORD, reg1=11, reg2=22) val ins = IRInstruction(Opcode.ADDR, IRDataType.WORD, reg1=11, reg2=22)
ins.opcode shouldBe Opcode.ADDR ins.opcode shouldBe Opcode.ADDR
ins.type shouldBe IRDataType.WORD ins.type shouldBe IRDataType.WORD
ins.reg1direction shouldBe OperandDirection.INOUT ins.reg1direction shouldBe OperandDirection.READWRITE
ins.reg2direction shouldBe OperandDirection.INPUT ins.reg2direction shouldBe OperandDirection.READ
ins.fpReg1direction shouldBe OperandDirection.UNUSED ins.fpReg1direction shouldBe OperandDirection.UNUSED
ins.fpReg2direction shouldBe OperandDirection.UNUSED ins.fpReg2direction shouldBe OperandDirection.UNUSED
ins.reg1 shouldBe 11 ins.reg1 shouldBe 11
@ -63,8 +63,8 @@ class TestInstructions: FunSpec({
val ins2 = IRInstruction(Opcode.SQRT, IRDataType.BYTE, reg1=11, reg2=22) val ins2 = IRInstruction(Opcode.SQRT, IRDataType.BYTE, reg1=11, reg2=22)
ins2.opcode shouldBe Opcode.SQRT ins2.opcode shouldBe Opcode.SQRT
ins2.type shouldBe IRDataType.BYTE ins2.type shouldBe IRDataType.BYTE
ins2.reg1direction shouldBe OperandDirection.OUTPUT ins2.reg1direction shouldBe OperandDirection.WRITE
ins2.reg2direction shouldBe OperandDirection.INPUT ins2.reg2direction shouldBe OperandDirection.READ
ins2.fpReg1direction shouldBe OperandDirection.UNUSED ins2.fpReg1direction shouldBe OperandDirection.UNUSED
ins2.fpReg2direction shouldBe OperandDirection.UNUSED ins2.fpReg2direction shouldBe OperandDirection.UNUSED
ins2.reg1 shouldBe 11 ins2.reg1 shouldBe 11
@ -80,8 +80,8 @@ class TestInstructions: FunSpec({
ins.type shouldBe IRDataType.FLOAT ins.type shouldBe IRDataType.FLOAT
ins.reg1direction shouldBe OperandDirection.UNUSED ins.reg1direction shouldBe OperandDirection.UNUSED
ins.reg2direction shouldBe OperandDirection.UNUSED ins.reg2direction shouldBe OperandDirection.UNUSED
ins.fpReg1direction shouldBe OperandDirection.OUTPUT ins.fpReg1direction shouldBe OperandDirection.WRITE
ins.fpReg2direction shouldBe OperandDirection.INPUT ins.fpReg2direction shouldBe OperandDirection.READ
ins.fpReg1 shouldBe 1 ins.fpReg1 shouldBe 1
ins.fpReg2 shouldBe 2 ins.fpReg2 shouldBe 2
ins.reg1 shouldBe null ins.reg1 shouldBe null
@ -115,8 +115,8 @@ class TestInstructions: FunSpec({
Opcode.values().forEach { Opcode.values().forEach {
val fmt = instructionFormats.getValue(it) val fmt = instructionFormats.getValue(it)
fmt.values.forEach { format -> fmt.values.forEach { format ->
require(format.reg2==OperandDirection.UNUSED || format.reg2==OperandDirection.INPUT) {"reg2 can only be used as input"} require(format.reg2==OperandDirection.UNUSED || format.reg2==OperandDirection.READ) {"reg2 can only be used as input"}
require(format.fpReg2==OperandDirection.UNUSED || format.fpReg2==OperandDirection.INPUT) {"fpReg2 can only be used as input"} require(format.fpReg2==OperandDirection.UNUSED || format.fpReg2==OperandDirection.READ) {"fpReg2 can only be used as input"}
} }
} }
} }