diff --git a/codeCore/src/prog8/code/ast/AstBase.kt b/codeCore/src/prog8/code/ast/AstBase.kt index 244a47cb4..0adcf6a64 100644 --- a/codeCore/src/prog8/code/ast/AstBase.kt +++ b/codeCore/src/prog8/code/ast/AstBase.kt @@ -31,7 +31,10 @@ sealed class PtNode(val position: Position) { } -class PtNodeGroup : PtNode(Position.DUMMY) +sealed interface IPtStatementContainer + + +class PtNodeGroup : PtNode(Position.DUMMY), IPtStatementContainer sealed class PtNamedNode(var name: String, position: Position): PtNode(position) { @@ -75,7 +78,7 @@ class PtBlock(name: String, val source: SourceCode, // taken from the module the block is defined in. val options: Options, position: Position -) : PtNamedNode(name, position) { +) : PtNamedNode(name, position), IPtStatementContainer { enum class BlockAlignment { NONE, WORD, diff --git a/codeCore/src/prog8/code/ast/AstExpressions.kt b/codeCore/src/prog8/code/ast/AstExpressions.kt index 5a9139874..a089eb593 100644 --- a/codeCore/src/prog8/code/ast/AstExpressions.kt +++ b/codeCore/src/prog8/code/ast/AstExpressions.kt @@ -36,7 +36,14 @@ sealed class PtExpression(val type: DataType, position: Position) : PtNode(posit return arrayIndexExpr!! isSameAs other.arrayIndexExpr!! } is PtArrayIndexer -> other is PtArrayIndexer && other.type==type && other.variable isSameAs variable && other.index isSameAs index && other.splitWords==splitWords - is PtBinaryExpression -> other is PtBinaryExpression && other.left isSameAs left && other.right isSameAs right + is PtBinaryExpression -> { + if(other !is PtBinaryExpression || other.operator!=operator) + false + else if(operator in AssociativeOperators) + (other.left isSameAs left && other.right isSameAs right) || (other.left isSameAs right && other.right isSameAs left) + else + other.left isSameAs left && other.right isSameAs right + } is PtContainmentCheck -> other is PtContainmentCheck && other.type==type && other.element isSameAs element && other.iterable isSameAs iterable is PtIdentifier -> other is PtIdentifier && other.type==type && other.name==name is PtMachineRegister -> other is PtMachineRegister && other.type==type && other.register==register diff --git a/codeCore/src/prog8/code/ast/AstStatements.kt b/codeCore/src/prog8/code/ast/AstStatements.kt index ee0570dd7..27fd72cdd 100644 --- a/codeCore/src/prog8/code/ast/AstStatements.kt +++ b/codeCore/src/prog8/code/ast/AstStatements.kt @@ -23,7 +23,7 @@ class PtSub( val parameters: List, val returntype: DataType?, position: Position -) : PtNamedNode(name, position), IPtSubroutine { +) : PtNamedNode(name, position), IPtSubroutine, IPtStatementContainer { init { // params and return value should not be str if(parameters.any{ it.type !in NumericDatatypes }) diff --git a/codeCore/src/prog8/code/optimize/Optimizer.kt b/codeCore/src/prog8/code/optimize/Optimizer.kt new file mode 100644 index 000000000..c316ea5f0 --- /dev/null +++ b/codeCore/src/prog8/code/optimize/Optimizer.kt @@ -0,0 +1,120 @@ +package prog8.code.optimize + +import prog8.code.ast.* +import prog8.code.core.* + + +fun optimizeIntermediateAst(program: PtProgram, options: CompilationOptions, errors: IErrorReporter) { + if (!options.optimize) + return + while(errors.noErrors() && optimizeCommonSubExpressions(program, errors)>0) { + // keep rolling + } +} + + +private fun walkAst(root: PtNode, act: (node: PtNode, depth: Int) -> Boolean) { + fun recurse(node: PtNode, depth: Int) { + if(act(node, depth)) + node.children.forEach { recurse(it, depth+1) } + } + recurse(root, 0) +} + + +private fun optimizeCommonSubExpressions(program: PtProgram, errors: IErrorReporter): Int { + + fun extractableSubExpr(expr: PtExpression): Boolean { + return if(expr is PtBinaryExpression) + !expr.left.isSimple() || !expr.right.isSimple() || (expr.operator !in LogicalOperators && expr.operator !in BitwiseOperators) + else + !expr.isSimple() + } + + // for each Binaryexpression, recurse to find a common subexpression pair therein. + val commons = mutableMapOf>() + walkAst(program) { node: PtNode, depth: Int -> + if(node is PtBinaryExpression) { + val subExpressions = mutableListOf() + walkAst(node.left) { subNode: PtNode, subDepth: Int -> + if (subNode is PtExpression) { + if(extractableSubExpr(subNode)) subExpressions.add(subNode) + true + } else false + } + walkAst(node.right) { subNode: PtNode, subDepth: Int -> + if (subNode is PtExpression) { + if(extractableSubExpr(subNode)) subExpressions.add(subNode) + true + } else false + } + + outer@for (first in subExpressions) { + for (second in subExpressions) { + if (first!==second && first isSameAs second) { + commons[node] = first to second + break@outer // do only 1 replacement at a time per binaryexpression + } + } + } + false + } else true + } + + // replace common subexpressions by a temp variable that is assigned only once. + commons.forEach { binexpr, (occurrence1, occurrence2) -> + val (stmtContainer, stmt) = findContainingStatements(binexpr) + val occurrence1idx = occurrence1.parent.children.indexOf(occurrence1) + val occurrence2idx = occurrence2.parent.children.indexOf(occurrence2) + val containerScopedName = findScopeName(stmtContainer) + val tempvarName = "subexprvar_line${binexpr.position.line}_${binexpr.hashCode().toUInt()}" + // TODO: some tempvars could be reused, if they are from different lines + + val datatype = occurrence1.type + val singleReplacement1 = PtIdentifier("$containerScopedName.$tempvarName", datatype, occurrence1.position) + val singleReplacement2 = PtIdentifier("$containerScopedName.$tempvarName", datatype, occurrence2.position) + occurrence1.parent.children[occurrence1idx] = singleReplacement1 + singleReplacement1.parent = occurrence1.parent + occurrence2.parent.children[occurrence2idx] = singleReplacement2 + singleReplacement2.parent = occurrence2.parent + + val tempassign = PtAssignment(binexpr.position).also { assign -> + assign.add(PtAssignTarget(binexpr.position).also { tgt-> + tgt.add(PtIdentifier("$containerScopedName.$tempvarName", datatype, binexpr.position)) + }) + assign.add(occurrence1) + occurrence1.parent = assign + } + stmtContainer.children.add(stmtContainer.children.indexOf(stmt), tempassign) + tempassign.parent = stmtContainer + + val tempvar = PtVariable(tempvarName, datatype, ZeropageWish.DONTCARE, null, null, binexpr.position) + stmtContainer.add(0, tempvar) + tempvar.parent = stmtContainer + + errors.info("common subexpressions replaced by a tempvar, maybe simplify the expression manually", binexpr.position) + } + + return commons.size +} + + +internal fun findScopeName(node: PtNode): String { + var parent=node + while(parent !is PtNamedNode) + parent = parent.parent + return parent.scopedName +} + + +internal fun findContainingStatements(node: PtNode): Pair { // returns (parentstatementcontainer, childstatement) + var parent = node.parent + var child = node + while(true) { + if(parent is IPtStatementContainer) { + return parent to child + } + child=parent + parent=parent.parent + } +} diff --git a/compiler/src/prog8/compiler/Compiler.kt b/compiler/src/prog8/compiler/Compiler.kt index 02d00025a..4d809d7b6 100644 --- a/compiler/src/prog8/compiler/Compiler.kt +++ b/compiler/src/prog8/compiler/Compiler.kt @@ -10,6 +10,7 @@ import prog8.ast.statements.Directive import prog8.code.SymbolTableMaker import prog8.code.ast.PtProgram import prog8.code.core.* +import prog8.code.optimize.optimizeIntermediateAst import prog8.code.target.* import prog8.codegen.vm.VmCodeGen import prog8.compiler.astprocessing.* @@ -102,7 +103,6 @@ fun compileProgram(args: CompilerArguments): CompilationResult? { compilationOptions, args.errors, BuiltinFunctionsFacade(BuiltinFunctions), - compTarget ) } postprocessAst(program, args.errors, compilationOptions) @@ -124,6 +124,9 @@ fun compileProgram(args: CompilerArguments): CompilationResult? { // printProgram(program) val intermediateAst = IntermediateAstMaker(program, args.errors).transform() +// printAst(intermediateAst, true) { println(it) } + optimizeIntermediateAst(intermediateAst, compilationOptions, args.errors) + args.errors.report() // println("*********** AST RIGHT BEFORE ASM GENERATION *************") // printAst(intermediateAst, true, ::println) @@ -378,13 +381,13 @@ private fun processAst(program: Program, errors: IErrorReporter, compilerOptions errors.report() } -private fun optimizeAst(program: Program, compilerOptions: CompilationOptions, errors: IErrorReporter, functions: IBuiltinFunctions, compTarget: ICompilationTarget) { - val remover = UnusedCodeRemover(program, errors, compTarget) +private fun optimizeAst(program: Program, compilerOptions: CompilationOptions, errors: IErrorReporter, functions: IBuiltinFunctions) { + val remover = UnusedCodeRemover(program, errors, compilerOptions.compTarget) remover.visit(program) remover.applyModifications() while (true) { // keep optimizing expressions and statements until no more steps remain - val optsDone1 = program.simplifyExpressions(errors, compTarget) + val optsDone1 = program.simplifyExpressions(errors, compilerOptions.compTarget) val optsDone2 = program.optimizeStatements(errors, functions, compilerOptions) val optsDone3 = program.inlineSubroutines(compilerOptions) program.constantFold(errors, compilerOptions) // because simplified statements and expressions can result in more constants that can be folded away @@ -395,7 +398,7 @@ private fun optimizeAst(program: Program, compilerOptions: CompilationOptions, e if (optsDone1 + optsDone2 + optsDone3 == 0) break } - val remover2 = UnusedCodeRemover(program, errors, compTarget) + val remover2 = UnusedCodeRemover(program, errors, compilerOptions.compTarget) remover2.visit(program) remover2.applyModifications() if(errors.noErrors()) diff --git a/compiler/test/ast/TestIntermediateAst.kt b/compiler/test/ast/TestIntermediateAst.kt index 15a791022..1677e1eb1 100644 --- a/compiler/test/ast/TestIntermediateAst.kt +++ b/compiler/test/ast/TestIntermediateAst.kt @@ -6,6 +6,7 @@ import io.kotest.matchers.ints.shouldBeGreaterThan import io.kotest.matchers.shouldBe import prog8.code.ast.* import prog8.code.core.DataType +import prog8.code.core.Position import prog8.code.target.C64Target import prog8.compiler.astprocessing.IntermediateAstMaker import prog8tests.helpers.ErrorReporterForTests @@ -60,4 +61,31 @@ class TestIntermediateAst: FunSpec({ fcall.type shouldBe DataType.UBYTE } + test("isSame on binaryExpressions") { + val expr1 = PtBinaryExpression("/", DataType.UBYTE, Position.DUMMY) + expr1.add(PtNumber(DataType.UBYTE, 1.0, Position.DUMMY)) + expr1.add(PtNumber(DataType.UBYTE, 2.0, Position.DUMMY)) + val expr2 = PtBinaryExpression("/", DataType.UBYTE, Position.DUMMY) + expr2.add(PtNumber(DataType.UBYTE, 1.0, Position.DUMMY)) + expr2.add(PtNumber(DataType.UBYTE, 2.0, Position.DUMMY)) + (expr1 isSameAs expr2) shouldBe true + val expr3 = PtBinaryExpression("/", DataType.UBYTE, Position.DUMMY) + expr3.add(PtNumber(DataType.UBYTE, 2.0, Position.DUMMY)) + expr3.add(PtNumber(DataType.UBYTE, 1.0, Position.DUMMY)) + (expr1 isSameAs expr3) shouldBe false + } + + test("isSame on binaryExpressions with associative operators") { + val expr1 = PtBinaryExpression("+", DataType.UBYTE, Position.DUMMY) + expr1.add(PtNumber(DataType.UBYTE, 1.0, Position.DUMMY)) + expr1.add(PtNumber(DataType.UBYTE, 2.0, Position.DUMMY)) + val expr2 = PtBinaryExpression("+", DataType.UBYTE, Position.DUMMY) + expr2.add(PtNumber(DataType.UBYTE, 1.0, Position.DUMMY)) + expr2.add(PtNumber(DataType.UBYTE, 2.0, Position.DUMMY)) + (expr1 isSameAs expr2) shouldBe true + val expr3 = PtBinaryExpression("+", DataType.UBYTE, Position.DUMMY) + expr3.add(PtNumber(DataType.UBYTE, 2.0, Position.DUMMY)) + expr3.add(PtNumber(DataType.UBYTE, 1.0, Position.DUMMY)) + (expr1 isSameAs expr3) shouldBe true + } }) \ No newline at end of file diff --git a/compiler/test/ast/TestVariousCompilerAst.kt b/compiler/test/ast/TestVariousCompilerAst.kt index e9249be71..aa20f15f2 100644 --- a/compiler/test/ast/TestVariousCompilerAst.kt +++ b/compiler/test/ast/TestVariousCompilerAst.kt @@ -511,5 +511,33 @@ main { val value = (st[5] as Assignment).value as BinaryExpression value.operator shouldBe "%" } + + test("isSame on binary expressions") { + val left1 = NumericLiteral.optimalInteger(1, Position.DUMMY) + val right1 = NumericLiteral.optimalInteger(2, Position.DUMMY) + val expr1 = BinaryExpression(left1, "/", right1, Position.DUMMY) + val left2 = NumericLiteral.optimalInteger(1, Position.DUMMY) + val right2 = NumericLiteral.optimalInteger(2, Position.DUMMY) + val expr2 = BinaryExpression(left2, "/", right2, Position.DUMMY) + (expr1 isSameAs expr2) shouldBe true + val left3 = NumericLiteral.optimalInteger(2, Position.DUMMY) + val right3 = NumericLiteral.optimalInteger(1, Position.DUMMY) + val expr3 = BinaryExpression(left3, "/", right3, Position.DUMMY) + (expr1 isSameAs expr3) shouldBe false + } + + test("isSame on binary expressions with associative operators") { + val left1 = NumericLiteral.optimalInteger(1, Position.DUMMY) + val right1 = NumericLiteral.optimalInteger(2, Position.DUMMY) + val expr1 = BinaryExpression(left1, "+", right1, Position.DUMMY) + val left2 = NumericLiteral.optimalInteger(1, Position.DUMMY) + val right2 = NumericLiteral.optimalInteger(2, Position.DUMMY) + val expr2 = BinaryExpression(left2, "+", right2, Position.DUMMY) + (expr1 isSameAs expr2) shouldBe true + val left3 = NumericLiteral.optimalInteger(2, Position.DUMMY) + val right3 = NumericLiteral.optimalInteger(1, Position.DUMMY) + val expr3 = BinaryExpression(left3, "+", right3, Position.DUMMY) + (expr1 isSameAs expr3) shouldBe true + } }) diff --git a/compilerAst/src/prog8/ast/expressions/AstExpressions.kt b/compilerAst/src/prog8/ast/expressions/AstExpressions.kt index cbd0d745d..257151238 100644 --- a/compilerAst/src/prog8/ast/expressions/AstExpressions.kt +++ b/compilerAst/src/prog8/ast/expressions/AstExpressions.kt @@ -34,11 +34,14 @@ sealed class Expression: Node { (other is IdentifierReference && other.nameInSource==nameInSource) is PrefixExpression -> (other is PrefixExpression && other.operator==operator && other.expression isSameAs expression) - is BinaryExpression -> - (other is BinaryExpression && other.operator==operator - && other.left isSameAs left - && other.right isSameAs right - && other.isChainedComparison() == isChainedComparison()) + is BinaryExpression -> { + if(other !is BinaryExpression || other.operator!=operator || other.isChainedComparison()!=isChainedComparison()) + false + else if(operator in AssociativeOperators) + (other.left isSameAs left && other.right isSameAs right) || (other.left isSameAs right && other.right isSameAs left) + else + other.left isSameAs left && other.right isSameAs right + } is ArrayIndexedExpression -> { (other is ArrayIndexedExpression && other.arrayvar.nameInSource == arrayvar.nameInSource && other.indexer isSameAs indexer) diff --git a/docs/source/todo.rst b/docs/source/todo.rst index 0e0c2fa05..0ae3b304b 100644 --- a/docs/source/todo.rst +++ b/docs/source/todo.rst @@ -1,10 +1,18 @@ - TODO ==== -PtAst/IR: attempt more complex common subexpression eliminations. - for any "top level" PtExpression enumerate all subexpressions and find commons, replace them by a tempvar - for walking the ast see walkAst() but it should not recurse into the "top level" PtExpression again +- why is the right term of cx16.r0 = (cx16.r1+cx16.r2) + (cx16.r1+cx16.r2) flipped around but the left term isn't? + +- Revert or fix current "desugar chained comparisons" it causes problems with if statements. + ubyte @shared n=20 + ubyte @shared L1=10 + ubyte @shared L2=100 + + if n < L1 { + ;txt.print("bing") + } else { + txt.print("boom") ; no longer triggers + } ... @@ -36,7 +44,6 @@ Compiler: global initialization values are simply a list of LOAD instructions. Variables replaced include all subroutine parameters! So the only variables that remain as variables are arrays and strings. - ir: add more optimizations in IRPeepholeOptimizer -- ir: for expressions with array indexes that occur multiple times, can we avoid loading them into new virtualregs everytime and just reuse a single virtualreg as indexer? (this is a form of common subexpression elimination) - ir: the @split arrays are currently also split in _lsb/_msb arrays in the IR, and operations take multiple (byte) instructions that may lead to verbose and slow operation and machine code generation down the line. maybe another representation is needed once actual codegeneration is done from the IR...? - [problematic due to using 64tass:] better support for building library programs, where unused .proc shouldn't be deleted from the assembly? diff --git a/examples/test.p8 b/examples/test.p8 index 0d933460f..dfce0ad43 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -1,13 +1,14 @@ %import textio %zeropage basicsafe -; Note: this program can be compiled for multiple target systems. - main { sub start() { - cx16.r0L=1 - while cx16.r0L < 10 and cx16.r0L>0 { - cx16.r0L++ - } + ubyte[10] array1 + ubyte[10] array2 + ubyte @shared xx + + cx16.r0 = (cx16.r1+cx16.r2) / (cx16.r2+cx16.r1) + cx16.r1 = 4*(cx16.r1+cx16.r2) + 3*(cx16.r1+cx16.r2) + cx16.r2 = array1[xx+20]==10 or array2[xx+20]==20 or array1[xx+20]==30 or array2[xx+20]==40 } }