From bcc75732e97b4ba425733156f1e684fa60e60e04 Mon Sep 17 00:00:00 2001 From: Irmen de Jong Date: Mon, 23 Mar 2020 23:28:05 +0100 Subject: [PATCH] optimize asm jsr+rts into jmp --- .../target/c64/codegen/AsmOptimizer.kt | 137 +++++++----- examples/test.p8 | 207 +----------------- 2 files changed, 82 insertions(+), 262 deletions(-) diff --git a/compiler/src/prog8/compiler/target/c64/codegen/AsmOptimizer.kt b/compiler/src/prog8/compiler/target/c64/codegen/AsmOptimizer.kt index a73393db2..270745b31 100644 --- a/compiler/src/prog8/compiler/target/c64/codegen/AsmOptimizer.kt +++ b/compiler/src/prog8/compiler/target/c64/codegen/AsmOptimizer.kt @@ -13,55 +13,70 @@ fun optimizeAssembly(lines: MutableList): Int { var linesByFour = getLinesBy(lines, 4) - var removeLines = optimizeUselessStackByteWrites(linesByFour) - if(removeLines.isNotEmpty()) { - for (i in removeLines.reversed()) - lines.removeAt(i) + var mods = optimizeUselessStackByteWrites(linesByFour) + if(mods.isNotEmpty()) { + apply(mods, lines) linesByFour = getLinesBy(lines, 4) numberOfOptimizations++ } - removeLines = optimizeIncDec(linesByFour) - if(removeLines.isNotEmpty()) { - for (i in removeLines.reversed()) - lines.removeAt(i) + mods = optimizeIncDec(linesByFour) + if(mods.isNotEmpty()) { + apply(mods, lines) linesByFour = getLinesBy(lines, 4) numberOfOptimizations++ } - removeLines = optimizeCmpSequence(linesByFour) - if(removeLines.isNotEmpty()) { - for (i in removeLines.reversed()) - lines.removeAt(i) + mods = optimizeCmpSequence(linesByFour) + if(mods.isNotEmpty()) { + apply(mods, lines) linesByFour = getLinesBy(lines, 4) numberOfOptimizations++ } - removeLines = optimizeStoreLoadSame(linesByFour) - if(removeLines.isNotEmpty()) { - for (i in removeLines.reversed()) - lines.removeAt(i) + mods = optimizeStoreLoadSame(linesByFour) + if(mods.isNotEmpty()) { + apply(mods, lines) + linesByFour = getLinesBy(lines, 4) + numberOfOptimizations++ + } + + mods= optimizeJsrRts(linesByFour) + if(mods.isNotEmpty()) { + apply(mods, lines) linesByFour = getLinesBy(lines, 4) numberOfOptimizations++ } var linesByFourteen = getLinesBy(lines, 14) - removeLines = optimizeSameAssignments(linesByFourteen) - if(removeLines.isNotEmpty()) { - for (i in removeLines.reversed()) - lines.removeAt(i) + mods = optimizeSameAssignments(linesByFourteen) + if(mods.isNotEmpty()) { + apply(mods, lines) linesByFourteen = getLinesBy(lines, 14) numberOfOptimizations++ } // TODO more assembly optimizations - // TODO optimize jsr + rts -> jmp - return numberOfOptimizations } -fun optimizeCmpSequence(linesByFour: List>>): List { +private class Modification(val lineIndex: Int, val remove: Boolean, val replacement: String?) + +private fun apply(modifications: List, lines: MutableList) { + for (modification in modifications.sortedBy { it.lineIndex }.reversed()) { + if(modification.remove) + lines.removeAt(modification.lineIndex) + else + lines[modification.lineIndex] = modification.replacement!! + } +} + +private fun getLinesBy(lines: MutableList, windowSize: Int) = +// all lines (that aren't empty or comments) in sliding windows of certain size + lines.withIndex().filter { it.value.isNotBlank() && !it.value.trimStart().startsWith(';') }.windowed(windowSize, partialWindows = false) + +private fun optimizeCmpSequence(linesByFour: List>>): List { // the when statement (on bytes) generates a sequence of: // lda $ce01,x // cmp #$20 @@ -70,42 +85,42 @@ fun optimizeCmpSequence(linesByFour: List>>): List() + val mods = mutableListOf() for(lines in linesByFour) { if(lines[0].value.trim()=="lda $ESTACK_LO_PLUS1_HEX,x" && lines[1].value.trim().startsWith("cmp ") && lines[2].value.trim().startsWith("beq ") && lines[3].value.trim()=="lda $ESTACK_LO_PLUS1_HEX,x") { - removeLines.add(lines[3].index) // remove the second lda + mods.add(Modification(lines[3].index, true, null)) // remove the second lda } } - return removeLines + return mods } -fun optimizeUselessStackByteWrites(linesByFour: List>>): List { +private fun optimizeUselessStackByteWrites(linesByFour: List>>): List { // sta on stack, dex, inx, lda from stack -> eliminate this useless stack byte write // this is a lot harder for word values because the instruction sequence varies. - val removeLines = mutableListOf() + val mods = mutableListOf() for(lines in linesByFour) { if(lines[0].value.trim()=="sta $ESTACK_LO_HEX,x" && lines[1].value.trim()=="dex" && lines[2].value.trim()=="inx" && lines[3].value.trim()=="lda $ESTACK_LO_HEX,x") { - removeLines.add(lines[1].index) - removeLines.add(lines[2].index) - removeLines.add(lines[3].index) + mods.add(Modification(lines[1].index, true, null)) + mods.add(Modification(lines[2].index, true, null)) + mods.add(Modification(lines[3].index, true, null)) } } - return removeLines + return mods } -fun optimizeSameAssignments(linesByFourteen: List>>): List { +private fun optimizeSameAssignments(linesByFourteen: List>>): List { // optimize sequential assignments of the isSameAs value to various targets (bytes, words, floats) // the float one is the one that requires 2*7=14 lines of code to check... // @todo a better place to do this is in the Compiler instead and transform the Ast, or the AsmGen, and never even create the inefficient asm in the first place... - val removeLines = mutableListOf() + val mods = mutableListOf() for (pair in linesByFourteen) { val first = pair[0].value.trimStart() val second = pair[1].value.trimStart() @@ -124,8 +139,8 @@ fun optimizeSameAssignments(linesByFourteen: List>>): val fourthvalue = sixth.substring(4) if(firstvalue==thirdvalue && secondvalue==fourthvalue) { // lda/ldy sta/sty twice the isSameAs word --> remove second lda/ldy pair (fifth and sixth lines) - removeLines.add(pair[4].index) - removeLines.add(pair[5].index) + mods.add(Modification(pair[4].index, true, null)) + mods.add(Modification(pair[5].index, true, null)) } } @@ -134,7 +149,7 @@ fun optimizeSameAssignments(linesByFourteen: List>>): val secondvalue = third.substring(4) if(firstvalue==secondvalue) { // lda value / sta ? / lda isSameAs-value / sta ? -> remove second lda (third line) - removeLines.add(pair[2].index) + mods.add(Modification(pair[2].index, true, null)) } } @@ -153,24 +168,20 @@ fun optimizeSameAssignments(linesByFourteen: List>>): if(first.substring(4) == eighth.substring(4) && second.substring(4)==nineth.substring(4)) { // identical float init - removeLines.add(pair[7].index) - removeLines.add(pair[8].index) - removeLines.add(pair[9].index) - removeLines.add(pair[10].index) + mods.add(Modification(pair[7].index, true, null)) + mods.add(Modification(pair[8].index, true, null)) + mods.add(Modification(pair[9].index, true, null)) + mods.add(Modification(pair[10].index, true, null)) } } } } - return removeLines + return mods } -private fun getLinesBy(lines: MutableList, windowSize: Int) = -// all lines (that aren't empty or comments) in sliding windows of certain size - lines.withIndex().filter { it.value.isNotBlank() && !it.value.trimStart().startsWith(';') }.windowed(windowSize, partialWindows = false) - -private fun optimizeStoreLoadSame(linesByFour: List>>): List { +private fun optimizeStoreLoadSame(linesByFour: List>>): List { // sta X + lda X, sty X + ldy X, stx X + ldx X -> the second instruction can be eliminated - val removeLines = mutableListOf() + val mods = mutableListOf() for (pair in linesByFour) { val first = pair[0].value.trimStart() val second = pair[1].value.trimStart() @@ -188,26 +199,40 @@ private fun optimizeStoreLoadSame(linesByFour: List>>) val firstLoc = first.substring(4) val secondLoc = second.substring(4) if (firstLoc == secondLoc) { - removeLines.add(pair[1].index) + mods.add(Modification(pair[1].index, true, null)) } } } - return removeLines + return mods } -private fun optimizeIncDec(linesByTwo: List>>): List { +private fun optimizeIncDec(linesByFour: List>>): List { // sometimes, iny+dey / inx+dex / dey+iny / dex+inx sequences are generated, these can be eliminated. - val removeLines = mutableListOf() - for (pair in linesByTwo) { + val mods = mutableListOf() + for (pair in linesByFour) { val first = pair[0].value val second = pair[1].value if ((" iny" in first || "\tiny" in first) && (" dey" in second || "\tdey" in second) || (" inx" in first || "\tinx" in first) && (" dex" in second || "\tdex" in second) || (" dey" in first || "\tdey" in first) && (" iny" in second || "\tiny" in second) || (" dex" in first || "\tdex" in first) && (" inx" in second || "\tinx" in second)) { - removeLines.add(pair[0].index) - removeLines.add(pair[1].index) + mods.add(Modification(pair[0].index, true, null)) + mods.add(Modification(pair[1].index, true, null)) } } - return removeLines + return mods +} + +private fun optimizeJsrRts(linesByFour: List>>): List { + // jsr Sub + rts -> jmp Sub + val mods = mutableListOf() + for (pair in linesByFour) { + val first = pair[0].value + val second = pair[1].value + if ((" jsr" in first || "\tjsr" in first ) && (" rts" in second || "\trts" in second)) { + mods += Modification(pair[0].index, false, pair[0].value.replace("jsr", "jmp")) + mods += Modification(pair[1].index, true, null) + } + } + return mods } diff --git a/examples/test.p8 b/examples/test.p8 index 22e32fd83..d0fa31ba2 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -6,211 +6,6 @@ main { sub start() { - A = shiftlb0() - A = shiftlb1() - A = shiftlb2() - A = shiftlb3() - A = shiftlb4() - A = shiftlb5() - A = shiftlb6() - A = shiftlb7() - A = shiftlb8() - A = shiftlb9() - - A = shiftrb0() - A = shiftrb1() - A = shiftrb2() - A = shiftrb3() - A = shiftrb4() - A = shiftrb5() - A = shiftrb6() - A = shiftrb7() - A = shiftrb8() - A = shiftrb9() - - uword uw - uw = shiftluw0() - uw = shiftluw1() - uw = shiftluw2() - uw = shiftluw3() - uw = shiftluw4() - uw = shiftluw5() - uw = shiftluw6() - uw = shiftluw7() - uw = shiftluw8() - uw = shiftluw9() - - uw = shiftruw0() - uw = shiftruw1() - uw = shiftruw2() - uw = shiftruw3() - uw = shiftruw4() - uw = shiftruw5() - uw = shiftruw6() - uw = shiftruw7() -; uw = shiftruw8() -; uw = shiftruw9() - } - - sub shiftruw0() -> uword { - uword q = 12345 - return q >> 0 - } - - sub shiftruw1() -> uword { - uword q = 12345 - return q >> 1 - } - - sub shiftruw2() -> uword { - uword q = 12345 - return q >> 2 - } - - sub shiftruw3() -> uword { - uword q = 12345 - return q >> 3 - } - - sub shiftruw4() -> uword { - uword q = 12345 - return q >> 4 - } - - sub shiftruw5() -> uword { - uword q = 12345 - return q >> 5 - } - - sub shiftruw6() -> uword { - uword q = 12345 - return q >> 6 - } - - sub shiftruw7() -> uword { - uword q = 12345 - return q >> 7 - } - -; sub shiftruw8() -> uword { -; uword q = 12345 -; return (q >> 8) as uword ; TODO auto cast return type -; } -; -; sub shiftruw9() -> uword { -; uword q = 12345 -; return (q >> 9) as uword ; TODO auto cast return type -; } - - sub shiftluw0() -> uword { - uword q = 12345 - return q << 0 - } - - sub shiftluw1() -> uword { - uword q = 12345 - return q << 1 - } - - sub shiftluw2() -> uword { - uword q = 12345 - return q << 2 - } - - sub shiftluw3() -> uword { - uword q = 12345 - return q << 3 - } - - sub shiftluw4() -> uword { - uword q = 12345 - return q << 4 - } - - sub shiftluw5() -> uword { - uword q = 12345 - return q << 5 - } - - sub shiftluw6() -> uword { - uword q = 12345 - return q << 6 - } - - sub shiftluw7() -> uword { - uword q = 12345 - return q << 7 - } - - sub shiftluw8() -> uword { - uword q = 12345 - return q << 8 - } - - sub shiftluw9() -> uword { - uword q = 12345 - return q << 9 - } - - sub shiftlb0() -> ubyte { - return Y << 0 - } - sub shiftlb1() -> ubyte { - return Y << 1 - } - sub shiftlb2() -> ubyte { - return Y << 2 - } - sub shiftlb3() -> ubyte { - return Y << 3 - } - sub shiftlb4() -> ubyte { - return Y << 4 - } - sub shiftlb5() -> ubyte { - return Y << 5 - } - sub shiftlb6() -> ubyte { - return Y << 6 - } - sub shiftlb7() -> ubyte { - return Y << 7 - } - sub shiftlb8() -> ubyte { - return Y << 8 - } - sub shiftlb9() -> ubyte { - return Y << 9 - } - - sub shiftrb0() -> ubyte { - return Y >> 0 - } - sub shiftrb1() -> ubyte { - return Y >> 1 - } - sub shiftrb2() -> ubyte { - return Y >> 2 - } - sub shiftrb3() -> ubyte { - return Y >> 3 - } - sub shiftrb4() -> ubyte { - return Y >> 4 - } - sub shiftrb5() -> ubyte { - return Y >> 5 - } - sub shiftrb6() -> ubyte { - return Y >> 6 - } - sub shiftrb7() -> ubyte { - return Y >> 7 - } - sub shiftrb8() -> ubyte { - return Y >> 8 - } - sub shiftrb9() -> ubyte { - return Y >> 9 + c64scr.print("ubyte shift left\n") } }