diff --git a/codeGenIntermediate/src/prog8/codegen/intermediate/IRCodeGen.kt b/codeGenIntermediate/src/prog8/codegen/intermediate/IRCodeGen.kt index 8d493d323..e0ec683e8 100644 --- a/codeGenIntermediate/src/prog8/codegen/intermediate/IRCodeGen.kt +++ b/codeGenIntermediate/src/prog8/codegen/intermediate/IRCodeGen.kt @@ -7,6 +7,7 @@ import prog8.code.SymbolTable import prog8.code.ast.* import prog8.code.core.* import prog8.intermediate.* +import prog8.iroptimizer.IROptimizer import kotlin.io.path.readBytes import kotlin.math.pow @@ -66,6 +67,11 @@ class IRCodeGen( irProg.linkChunks() // re-link } + if(options.optimize) { + val opt = IROptimizer(irProg) + opt.optimize() + } + irProg.validate() return irProg } @@ -217,7 +223,7 @@ class IRCodeGen( } program.allBlocks().forEach { block -> - block.children.forEach { + block.children.toList().forEach { if (it is PtSub) { // Only regular subroutines can have nested subroutines. it.children.filterIsInstance().forEach { subsub -> moveToBlock(block, it, subsub) } diff --git a/codeGenIntermediate/src/prog8/iroptimizer/IROptimizer.kt b/codeGenIntermediate/src/prog8/iroptimizer/IROptimizer.kt new file mode 100644 index 000000000..7207541aa --- /dev/null +++ b/codeGenIntermediate/src/prog8/iroptimizer/IROptimizer.kt @@ -0,0 +1,62 @@ +package prog8.iroptimizer + +import prog8.intermediate.* + +internal class IROptimizer(val program: IRProgram) { + fun optimize() { + program.blocks.forEach { block -> + block.children.forEach { elt -> + process(elt) + } + } + } + + private fun process(elt: IIRBlockElement) { + when(elt) { + is IRCodeChunkBase -> { + optimizeInstructions(elt) + // TODO renumber registers that are only used within the code chunk + // val used = elt.usedRegisters() + } + is IRAsmSubroutine -> { + if(elt.asmChunk.isIR) { + optimizeInstructions(elt.asmChunk) + } + // TODO renumber registers that are only used within the code chunk + // val used = elt.usedRegisters() + } + is IRSubroutine -> { + elt.chunks.forEach { process(it) } + } + } + } + + private fun optimizeInstructions(elt: IRCodeChunkBase) { + elt.instructions.withIndex().windowed(2).forEach {(first, second) -> + val i1 = first.value + val i2 = second.value + // replace call + return --> jump + if(i1.opcode==Opcode.CALL && i2.opcode==Opcode.RETURN) { + elt.instructions[first.index] = IRInstruction(Opcode.JUMP, value=i1.value, labelSymbol = i1.labelSymbol, branchTarget = i1.branchTarget) + elt.instructions[second.index] = IRInstruction(Opcode.NOP) + if(second.index==elt.instructions.size-1) { + // it was the last instruction, so the link to the next chunk needs to be cleared + elt.next = null + } + } + + // replace subsequent opcodes that jump by just the first + if(i1.opcode in OpcodesThatJump && i2.opcode in OpcodesThatJump) { + elt.instructions[second.index] = IRInstruction(Opcode.NOP) + } + } + + // remove nops + elt.instructions.withIndex() + .filter { it.value.opcode==Opcode.NOP } + .reversed() + .forEach { + elt.instructions.removeAt(it.index) + } + } +} \ No newline at end of file diff --git a/docs/source/todo.rst b/docs/source/todo.rst index 8e0499d74..509b80bec 100644 --- a/docs/source/todo.rst +++ b/docs/source/todo.rst @@ -3,6 +3,10 @@ TODO For next minor release ^^^^^^^^^^^^^^^^^^^^^^ +- IR: don't hardcode r0/fr0 as return registers. + instead have RETURN -> returns void, RETURNREG -> return value from given register + also CALL -> void call, CALLRVAL -> specify register to put call result in. CALLRVAL r0, functionThatReturnsInt + ... diff --git a/intermediate/src/prog8/intermediate/IRFileWriter.kt b/intermediate/src/prog8/intermediate/IRFileWriter.kt index 862b59ea8..3f0c59686 100644 --- a/intermediate/src/prog8/intermediate/IRFileWriter.kt +++ b/intermediate/src/prog8/intermediate/IRFileWriter.kt @@ -39,7 +39,7 @@ class IRFileWriter(private val irProgram: IRProgram, outfileOverride: Path?) { out.close() val used = irProgram.registersUsed() - val numberUsed = (used.inputRegs.keys + used.outputRegs.keys).size + (used.inputFpRegs.keys + used.outputFpRegs.keys).size + val numberUsed = (used.readRegs.keys + used.writeRegs.keys).size + (used.readFpRegs.keys + used.writeFpRegs.keys).size println("($numInstr instructions in $numChunks chunks, $numberUsed registers)") return outfile } diff --git a/intermediate/src/prog8/intermediate/IRInstructions.kt b/intermediate/src/prog8/intermediate/IRInstructions.kt index 2d1538cc4..8aac3686b 100644 --- a/intermediate/src/prog8/intermediate/IRInstructions.kt +++ b/intermediate/src/prog8/intermediate/IRInstructions.kt @@ -465,9 +465,9 @@ enum class IRDataType { enum class OperandDirection { UNUSED, - INPUT, - OUTPUT, - INOUT + READ, + WRITE, + READWRITE } data class InstructionFormat(val datatype: IRDataType?, @@ -492,14 +492,14 @@ data class InstructionFormat(val datatype: IRDataType?, val typespec = splits.next() while(splits.hasNext()) { when(splits.next()) { - " { reg1=OperandDirection.INPUT } - ">r1" -> { reg1=OperandDirection.OUTPUT } - "<>r1" -> { reg1=OperandDirection.INOUT } - " reg2 = OperandDirection.INPUT - " { fpreg1=OperandDirection.INPUT } - ">fr1" -> { fpreg1=OperandDirection.OUTPUT } - "<>fr1" -> { fpreg1=OperandDirection.INOUT } - " fpreg2 = OperandDirection.INPUT + " { reg1=OperandDirection.READ } + ">r1" -> { reg1=OperandDirection.WRITE } + "<>r1" -> { reg1=OperandDirection.READWRITE } + " reg2 = OperandDirection.READ + " { fpreg1=OperandDirection.READ } + ">fr1" -> { fpreg1=OperandDirection.WRITE } + "<>fr1" -> { fpreg1=OperandDirection.READWRITE } + " fpreg2 = OperandDirection.READ " { if('F' in typespec) fpvalueIn = true @@ -524,9 +524,9 @@ data class InstructionFormat(val datatype: IRDataType?, } /* - X = X is overwritten with output value (output value) - <>X = X is modified (read as input + written as output) + X = X is overwritten with output value (write value) + <>X = X is modified (read + written) TODO: also encode if *memory* is read/written/modified? */ @Suppress("BooleanLiteralArgument") @@ -755,38 +755,38 @@ data class IRInstruction( } fun addUsedRegistersCounts( - inputRegs: MutableMap, - outputRegs: MutableMap, - inputFpRegs: MutableMap, - outputFpRegs: MutableMap + readRegs: MutableMap, + writeRegs: MutableMap, + readFpRegs: MutableMap, + writeFpRegs: MutableMap ) { when (this.reg1direction) { OperandDirection.UNUSED -> {} - OperandDirection.INPUT -> inputRegs[this.reg1!!] = inputRegs.getValue(this.reg1)+1 - OperandDirection.OUTPUT -> outputRegs[this.reg1!!] = outputRegs.getValue(this.reg1)+1 - OperandDirection.INOUT -> { - inputRegs[this.reg1!!] = inputRegs.getValue(this.reg1)+1 - outputRegs[this.reg1] = outputRegs.getValue(this.reg1)+1 + OperandDirection.READ -> readRegs[this.reg1!!] = readRegs.getValue(this.reg1)+1 + OperandDirection.WRITE -> writeRegs[this.reg1!!] = writeRegs.getValue(this.reg1)+1 + OperandDirection.READWRITE -> { + readRegs[this.reg1!!] = readRegs.getValue(this.reg1)+1 + writeRegs[this.reg1] = writeRegs.getValue(this.reg1)+1 } } when (this.reg2direction) { OperandDirection.UNUSED -> {} - OperandDirection.INPUT -> outputRegs[this.reg2!!] = outputRegs.getValue(this.reg2)+1 - else -> throw IllegalArgumentException("reg2 can only be input") + OperandDirection.READ -> writeRegs[this.reg2!!] = writeRegs.getValue(this.reg2)+1 + else -> throw IllegalArgumentException("reg2 can only be read") } when (this.fpReg1direction) { OperandDirection.UNUSED -> {} - OperandDirection.INPUT -> inputFpRegs[this.fpReg1!!] = inputFpRegs.getValue(this.fpReg1)+1 - OperandDirection.OUTPUT -> outputFpRegs[this.fpReg1!!] = outputFpRegs.getValue(this.fpReg1)+1 - OperandDirection.INOUT -> { - inputFpRegs[this.fpReg1!!] = inputFpRegs.getValue(this.fpReg1)+1 - outputFpRegs[this.fpReg1] = outputFpRegs.getValue(this.fpReg1)+1 + OperandDirection.READ -> readFpRegs[this.fpReg1!!] = readFpRegs.getValue(this.fpReg1)+1 + OperandDirection.WRITE -> writeFpRegs[this.fpReg1!!] = writeFpRegs.getValue(this.fpReg1)+1 + OperandDirection.READWRITE -> { + readFpRegs[this.fpReg1!!] = readFpRegs.getValue(this.fpReg1)+1 + writeFpRegs[this.fpReg1] = writeFpRegs.getValue(this.fpReg1)+1 } } when (this.fpReg2direction) { OperandDirection.UNUSED -> {} - OperandDirection.INPUT -> inputFpRegs[this.fpReg2!!] = inputFpRegs.getValue(this.fpReg2)+1 - else -> throw IllegalArgumentException("fpReg2 can only be input") + OperandDirection.READ -> readFpRegs[this.fpReg2!!] = readFpRegs.getValue(this.fpReg2)+1 + else -> throw IllegalArgumentException("fpReg2 can only be read") } } diff --git a/intermediate/src/prog8/intermediate/IRProgram.kt b/intermediate/src/prog8/intermediate/IRProgram.kt index cdfa0fbd4..59da3d4ea 100644 --- a/intermediate/src/prog8/intermediate/IRProgram.kt +++ b/intermediate/src/prog8/intermediate/IRProgram.kt @@ -199,19 +199,19 @@ class IRProgram(val name: String, } fun registersUsed(): RegistersUsed { - val inputRegs = mutableMapOf().withDefault { 0 } - val inputFpRegs = mutableMapOf().withDefault { 0 } - val outputRegs = mutableMapOf().withDefault { 0 } - val outputFpRegs = mutableMapOf().withDefault { 0 } + val readRegs = mutableMapOf().withDefault { 0 } + val readFpRegs = mutableMapOf().withDefault { 0 } + val writeRegs = mutableMapOf().withDefault { 0 } + val writeFpRegs = mutableMapOf().withDefault { 0 } fun addUsed(usedRegisters: RegistersUsed) { - usedRegisters.inputRegs.forEach{ (reg, count) -> inputRegs[reg] = inputRegs.getValue(reg) + count } - usedRegisters.outputRegs.forEach{ (reg, count) -> outputRegs[reg] = outputRegs.getValue(reg) + count } - usedRegisters.inputFpRegs.forEach{ (reg, count) -> inputFpRegs[reg] = inputFpRegs.getValue(reg) + count } - usedRegisters.outputFpRegs.forEach{ (reg, count) -> outputFpRegs[reg] = outputFpRegs.getValue(reg) + count } + usedRegisters.readRegs.forEach{ (reg, count) -> readRegs[reg] = readRegs.getValue(reg) + count } + usedRegisters.writeRegs.forEach{ (reg, count) -> writeRegs[reg] = writeRegs.getValue(reg) + count } + usedRegisters.readFpRegs.forEach{ (reg, count) -> readFpRegs[reg] = readFpRegs.getValue(reg) + count } + usedRegisters.writeFpRegs.forEach{ (reg, count) -> writeFpRegs[reg] = writeFpRegs.getValue(reg) + count } } - globalInits.instructions.forEach { it.addUsedRegistersCounts(inputRegs, outputRegs, inputFpRegs, outputFpRegs) } + globalInits.instructions.forEach { it.addUsedRegistersCounts(readRegs, writeRegs, readFpRegs, writeFpRegs) } blocks.forEach {block -> block.children.forEach { child -> when(child) { @@ -224,7 +224,7 @@ class IRProgram(val name: String, } } - return RegistersUsed(inputRegs, outputRegs, inputFpRegs, outputFpRegs) + return RegistersUsed(readRegs, writeRegs, readFpRegs, writeFpRegs) } } @@ -333,12 +333,12 @@ class IRCodeChunk(label: String?, next: IRCodeChunkBase?): IRCodeChunkBase(label override fun isEmpty() = instructions.isEmpty() override fun isNotEmpty() = instructions.isNotEmpty() override fun usedRegisters(): RegistersUsed { - val inputRegs = mutableMapOf().withDefault { 0 } - val inputFpRegs = mutableMapOf().withDefault { 0 } - val outputRegs = mutableMapOf().withDefault { 0 } - val outputFpRegs = mutableMapOf().withDefault { 0 } - instructions.forEach { it.addUsedRegistersCounts(inputRegs, outputRegs, inputFpRegs, outputFpRegs) } - return RegistersUsed(inputRegs, outputRegs, inputFpRegs, outputFpRegs) + val readRegs = mutableMapOf().withDefault { 0 } + val readFpRegs = mutableMapOf().withDefault { 0 } + val writeRegs = mutableMapOf().withDefault { 0 } + val writeFpRegs = mutableMapOf().withDefault { 0 } + instructions.forEach { it.addUsedRegistersCounts(readRegs, writeRegs, readFpRegs, writeFpRegs) } + return RegistersUsed(readRegs, writeRegs, readFpRegs, writeFpRegs) } operator fun plusAssign(ins: IRInstruction) { @@ -380,34 +380,34 @@ typealias IRCodeChunks = List class RegistersUsed( // register num -> number of uses - val inputRegs: Map, - val outputRegs: Map, - val inputFpRegs: Map, - val outputFpRegs: Map, + val readRegs: Map, + val writeRegs: Map, + val readFpRegs: Map, + val writeFpRegs: Map, ) { override fun toString(): String { - return "input=$inputRegs, output=$outputRegs, inputFp=$inputFpRegs, outputFp=$outputFpRegs" + return "read=$readRegs, write=$writeRegs, readFp=$readFpRegs, writeFp=$writeFpRegs" } - fun isEmpty() = inputRegs.isEmpty() && outputRegs.isEmpty() && inputFpRegs.isEmpty() && outputFpRegs.isEmpty() + fun isEmpty() = readRegs.isEmpty() && writeRegs.isEmpty() && readFpRegs.isEmpty() && writeFpRegs.isEmpty() fun isNotEmpty() = !isEmpty() } private fun registersUsedInAssembly(isIR: Boolean, assembly: String): RegistersUsed { - val inputRegs = mutableMapOf().withDefault { 0 } - val inputFpRegs = mutableMapOf().withDefault { 0 } - val outputRegs = mutableMapOf().withDefault { 0 } - val outputFpRegs = mutableMapOf().withDefault { 0 } + val readRegs = mutableMapOf().withDefault { 0 } + val readFpRegs = mutableMapOf().withDefault { 0 } + val writeRegs = mutableMapOf().withDefault { 0 } + val writeFpRegs = mutableMapOf().withDefault { 0 } if(isIR) { assembly.lineSequence().forEach { line -> val result = parseIRCodeLine(line.trim(), null, mutableMapOf()) result.fold( - ifLeft = { it.addUsedRegistersCounts(inputRegs, outputRegs, inputFpRegs, outputFpRegs) }, + ifLeft = { it.addUsedRegistersCounts(readRegs, writeRegs, readFpRegs, writeFpRegs) }, ifRight = { /* labels can be skipped */ } ) } } - return RegistersUsed(inputRegs, outputRegs, inputFpRegs, outputFpRegs) + return RegistersUsed(readRegs, writeRegs, readFpRegs, writeFpRegs) } diff --git a/intermediate/test/TestInstructions.kt b/intermediate/test/TestInstructions.kt index 6f124afa7..b637dc690 100644 --- a/intermediate/test/TestInstructions.kt +++ b/intermediate/test/TestInstructions.kt @@ -24,7 +24,7 @@ class TestInstructions: FunSpec({ val ins = IRInstruction(Opcode.BZ, IRDataType.BYTE, reg1=42, value = 99) ins.opcode shouldBe Opcode.BZ ins.type shouldBe IRDataType.BYTE - ins.reg1direction shouldBe OperandDirection.INPUT + ins.reg1direction shouldBe OperandDirection.READ ins.fpReg1direction shouldBe OperandDirection.UNUSED ins.reg1 shouldBe 42 ins.reg2 shouldBe null @@ -37,7 +37,7 @@ class TestInstructions: FunSpec({ val ins = IRInstruction(Opcode.BZ, IRDataType.WORD, reg1=11, labelSymbol = "a.b.c") ins.opcode shouldBe Opcode.BZ ins.type shouldBe IRDataType.WORD - ins.reg1direction shouldBe OperandDirection.INPUT + ins.reg1direction shouldBe OperandDirection.READ ins.fpReg1direction shouldBe OperandDirection.UNUSED ins.reg1 shouldBe 11 ins.reg2 shouldBe null @@ -50,8 +50,8 @@ class TestInstructions: FunSpec({ val ins = IRInstruction(Opcode.ADDR, IRDataType.WORD, reg1=11, reg2=22) ins.opcode shouldBe Opcode.ADDR ins.type shouldBe IRDataType.WORD - ins.reg1direction shouldBe OperandDirection.INOUT - ins.reg2direction shouldBe OperandDirection.INPUT + ins.reg1direction shouldBe OperandDirection.READWRITE + ins.reg2direction shouldBe OperandDirection.READ ins.fpReg1direction shouldBe OperandDirection.UNUSED ins.fpReg2direction shouldBe OperandDirection.UNUSED ins.reg1 shouldBe 11 @@ -63,8 +63,8 @@ class TestInstructions: FunSpec({ val ins2 = IRInstruction(Opcode.SQRT, IRDataType.BYTE, reg1=11, reg2=22) ins2.opcode shouldBe Opcode.SQRT ins2.type shouldBe IRDataType.BYTE - ins2.reg1direction shouldBe OperandDirection.OUTPUT - ins2.reg2direction shouldBe OperandDirection.INPUT + ins2.reg1direction shouldBe OperandDirection.WRITE + ins2.reg2direction shouldBe OperandDirection.READ ins2.fpReg1direction shouldBe OperandDirection.UNUSED ins2.fpReg2direction shouldBe OperandDirection.UNUSED ins2.reg1 shouldBe 11 @@ -80,8 +80,8 @@ class TestInstructions: FunSpec({ ins.type shouldBe IRDataType.FLOAT ins.reg1direction shouldBe OperandDirection.UNUSED ins.reg2direction shouldBe OperandDirection.UNUSED - ins.fpReg1direction shouldBe OperandDirection.OUTPUT - ins.fpReg2direction shouldBe OperandDirection.INPUT + ins.fpReg1direction shouldBe OperandDirection.WRITE + ins.fpReg2direction shouldBe OperandDirection.READ ins.fpReg1 shouldBe 1 ins.fpReg2 shouldBe 2 ins.reg1 shouldBe null @@ -115,8 +115,8 @@ class TestInstructions: FunSpec({ Opcode.values().forEach { val fmt = instructionFormats.getValue(it) fmt.values.forEach { format -> - require(format.reg2==OperandDirection.UNUSED || format.reg2==OperandDirection.INPUT) {"reg2 can only be used as input"} - require(format.fpReg2==OperandDirection.UNUSED || format.fpReg2==OperandDirection.INPUT) {"fpReg2 can only be used as input"} + require(format.reg2==OperandDirection.UNUSED || format.reg2==OperandDirection.READ) {"reg2 can only be used as input"} + require(format.fpReg2==OperandDirection.UNUSED || format.fpReg2==OperandDirection.READ) {"fpReg2 can only be used as input"} } } }