diff --git a/compiler/src/prog8/compiler/Main.kt b/compiler/src/prog8/compiler/Main.kt index 77f9acfd2..651399b41 100644 --- a/compiler/src/prog8/compiler/Main.kt +++ b/compiler/src/prog8/compiler/Main.kt @@ -144,7 +144,6 @@ private fun processAst(programAst: Program, errors: ErrorReporter, compilerOptio println("Processing...") programAst.checkIdentifiers(errors) errors.handle() - programAst.makeForeverLoops() programAst.constantFold(errors) errors.handle() programAst.removeNopsFlattenAnonScopes() @@ -194,7 +193,7 @@ private fun writeAssembly(programAst: Program, errors: ErrorReporter, outputDir: programAst.processAstBeforeAsmGeneration(errors) errors.handle() - // printAst(programAst) // TODO + // printAst(programAst) val assembly = CompilationTarget.asmGenerator( programAst, diff --git a/compiler/src/prog8/optimizer/Extensions.kt b/compiler/src/prog8/optimizer/Extensions.kt index e768a950f..304cd71a3 100644 --- a/compiler/src/prog8/optimizer/Extensions.kt +++ b/compiler/src/prog8/optimizer/Extensions.kt @@ -19,16 +19,13 @@ internal fun Program.constantFold(errors: ErrorReporter) { internal fun Program.optimizeStatements(errors: ErrorReporter): Int { - val optimizer = StatementOptimizer2(this, errors) + val optimizer = StatementOptimizer(this, errors) optimizer.visit(this) val optimizationCount = optimizer.applyModifications() - val old_optimizer = StatementOptimizer(this, errors) - old_optimizer.visit(this) - modules.forEach { it.linkParents(this.namespace) } // re-link in final configuration - return optimizationCount + old_optimizer.optimizationsDone + return optimizationCount } internal fun Program.simplifyExpressions() : Int { diff --git a/compiler/src/prog8/optimizer/StatementOptimizer.kt b/compiler/src/prog8/optimizer/StatementOptimizer.kt index d1981c332..a40d22571 100644 --- a/compiler/src/prog8/optimizer/StatementOptimizer.kt +++ b/compiler/src/prog8/optimizer/StatementOptimizer.kt @@ -7,7 +7,6 @@ import prog8.ast.base.* import prog8.ast.expressions.* import prog8.ast.processing.AstWalker import prog8.ast.processing.IAstModification -import prog8.ast.processing.IAstModifyingVisitor import prog8.ast.processing.IAstVisitor import prog8.ast.statements.* import prog8.compiler.target.CompilationTarget @@ -21,8 +20,8 @@ import kotlin.math.floor */ -internal class StatementOptimizer2(private val program: Program, - private val errors: ErrorReporter) : AstWalker() { +internal class StatementOptimizer(private val program: Program, + private val errors: ErrorReporter) : AstWalker() { private val callgraph = CallGraph(program) private val pureBuiltinFunctions = BuiltinFunctions.filter { it.value.pure } @@ -303,6 +302,97 @@ internal class StatementOptimizer2(private val program: Program, return emptyList() } + override fun after(assignment: Assignment, parent: Node): Iterable { + if(assignment.aug_op!=null) + throw FatalAstException("augmented assignments should have been converted to normal assignments before this optimizer: $assignment") + + // remove assignments to self + if(assignment.target isSameAs assignment.value) { + if(assignment.target.isNotMemory(program.namespace)) + return listOf(IAstModification.Remove(assignment, parent)) + } + + val targetIDt = assignment.target.inferType(program, assignment) + if(!targetIDt.isKnown) + throw FatalAstException("can't infer type of assignment target") + + + // optimize binary expressions a bit + val targetDt = targetIDt.typeOrElse(DataType.STRUCT) + val bexpr=assignment.value as? BinaryExpression + if(bexpr!=null) { + val cv = bexpr.right.constValue(program)?.number?.toDouble() + if (cv != null && assignment.target isSameAs bexpr.left) { + // assignments of the form: X = X + // remove assignments that have no effect (such as X=X+0) + // optimize/rewrite some other expressions + val vardeclDt = (assignment.target.identifier?.targetVarDecl(program.namespace))?.type + when (bexpr.operator) { + "+" -> { + if (cv == 0.0) { + return listOf(IAstModification.Remove(assignment, parent)) + } else if (targetDt in IntegerDatatypes && floor(cv) == cv) { + if ((vardeclDt == VarDeclType.MEMORY && cv in 1.0..3.0) || (vardeclDt != VarDeclType.MEMORY && cv in 1.0..8.0)) { + // replace by several INCs (a bit less when dealing with memory targets) + val incs = AnonymousScope(mutableListOf(), assignment.position) + repeat(cv.toInt()) { + incs.statements.add(PostIncrDecr(assignment.target, "++", assignment.position)) + } + return listOf(IAstModification.ReplaceNode(assignment, incs, parent)) + } + } + } + "-" -> { + if (cv == 0.0) { + return listOf(IAstModification.Remove(assignment, parent)) + } else if (targetDt in IntegerDatatypes && floor(cv) == cv) { + if ((vardeclDt == VarDeclType.MEMORY && cv in 1.0..3.0) || (vardeclDt != VarDeclType.MEMORY && cv in 1.0..8.0)) { + // replace by several DECs (a bit less when dealing with memory targets) + val decs = AnonymousScope(mutableListOf(), assignment.position) + repeat(cv.toInt()) { + decs.statements.add(PostIncrDecr(assignment.target, "--", assignment.position)) + } + return listOf(IAstModification.ReplaceNode(assignment, decs, parent)) + } + } + } + "*" -> if (cv == 1.0) return listOf(IAstModification.Remove(assignment, parent)) + "/" -> if (cv == 1.0) return listOf(IAstModification.Remove(assignment, parent)) + "**" -> if (cv == 1.0) return listOf(IAstModification.Remove(assignment, parent)) + "|" -> if (cv == 0.0) return listOf(IAstModification.Remove(assignment, parent)) + "^" -> if (cv == 0.0) return listOf(IAstModification.Remove(assignment, parent)) + "<<" -> { + if (cv == 0.0) + return listOf(IAstModification.Remove(assignment, parent)) + // replace by in-place lsl(...) call + val scope = AnonymousScope(mutableListOf(), assignment.position) + var numshifts = cv.toInt() + while (numshifts > 0) { + scope.statements.add(FunctionCallStatement(IdentifierReference(listOf("lsl"), assignment.position), + mutableListOf(bexpr.left), true, assignment.position)) + numshifts-- + } + return listOf(IAstModification.ReplaceNode(assignment, scope, parent)) + } + ">>" -> { + if (cv == 0.0) + return listOf(IAstModification.Remove(assignment, parent)) + // replace by in-place lsr(...) call + val scope = AnonymousScope(mutableListOf(), assignment.position) + var numshifts = cv.toInt() + while (numshifts > 0) { + scope.statements.add(FunctionCallStatement(IdentifierReference(listOf("lsr"), assignment.position), + mutableListOf(bexpr.left), true, assignment.position)) + numshifts-- + } + return listOf(IAstModification.ReplaceNode(assignment, scope, parent)) + } + } + + } + } + return emptyList() + } private fun deduplicateAssignments(statements: List): MutableList { // removes 'duplicate' assignments that assign the isSameAs target @@ -354,175 +444,3 @@ internal class StatementOptimizer2(private val program: Program, } } - - -// ------------------------------------------------------------ - - -// TODO implement using AstWalker instead of IAstModifyingVisitor -internal class StatementOptimizer(private val program: Program, - private val errors: ErrorReporter) : IAstModifyingVisitor { - var optimizationsDone: Int = 0 - private set - - private val vardeclsToRemove = mutableListOf() - - override fun visit(program: Program) { - super.visit(program) - - for(decl in vardeclsToRemove) { - decl.definingScope().remove(decl) - } - } - - override fun visit(assignment: Assignment): Statement { - if(assignment.aug_op!=null) - throw FatalAstException("augmented assignments should have been converted to normal assignments before this optimizer: $assignment") - - if(assignment.target isSameAs assignment.value) { - if(assignment.target.isNotMemory(program.namespace)) { - optimizationsDone++ - return NopStatement.insteadOf(assignment) - } - } - val targetIDt = assignment.target.inferType(program, assignment) - if(!targetIDt.isKnown) - throw FatalAstException("can't infer type of assignment target") - val targetDt = targetIDt.typeOrElse(DataType.STRUCT) - val bexpr=assignment.value as? BinaryExpression - if(bexpr!=null) { - val cv = bexpr.right.constValue(program)?.number?.toDouble() - if (cv == null) { - if (bexpr.operator == "+" && targetDt != DataType.FLOAT) { - if (bexpr.left isSameAs bexpr.right && assignment.target isSameAs bexpr.left) { - bexpr.operator = "*" - bexpr.right = NumericLiteralValue.optimalInteger(2, assignment.value.position) - optimizationsDone++ - return assignment - } - } - } else { - if (assignment.target isSameAs bexpr.left) { - // remove assignments that have no effect X=X , X+=0, X-=0, X*=1, X/=1, X//=1, A |= 0, A ^= 0, A<<=0, etc etc - // A = A B - val vardeclDt = (assignment.target.identifier?.targetVarDecl(program.namespace))?.type - - when (bexpr.operator) { - "+" -> { - if (cv == 0.0) { - optimizationsDone++ - return NopStatement.insteadOf(assignment) - } else if (targetDt in IntegerDatatypes && floor(cv) == cv) { - if ((vardeclDt == VarDeclType.MEMORY && cv in 1.0..3.0) || (vardeclDt != VarDeclType.MEMORY && cv in 1.0..8.0)) { - // replace by several INCs (a bit less when dealing with memory targets) - val decs = AnonymousScope(mutableListOf(), assignment.position) - repeat(cv.toInt()) { - decs.statements.add(PostIncrDecr(assignment.target, "++", assignment.position)) - } - return decs - } - } - } - "-" -> { - if (cv == 0.0) { - optimizationsDone++ - return NopStatement.insteadOf(assignment) - } else if (targetDt in IntegerDatatypes && floor(cv) == cv) { - if ((vardeclDt == VarDeclType.MEMORY && cv in 1.0..3.0) || (vardeclDt != VarDeclType.MEMORY && cv in 1.0..8.0)) { - // replace by several DECs (a bit less when dealing with memory targets) - val decs = AnonymousScope(mutableListOf(), assignment.position) - repeat(cv.toInt()) { - decs.statements.add(PostIncrDecr(assignment.target, "--", assignment.position)) - } - return decs - } - } - } - "*" -> if (cv == 1.0) { - optimizationsDone++ - return NopStatement.insteadOf(assignment) - } - "/" -> if (cv == 1.0) { - optimizationsDone++ - return NopStatement.insteadOf(assignment) - } - "**" -> if (cv == 1.0) { - optimizationsDone++ - return NopStatement.insteadOf(assignment) - } - "|" -> if (cv == 0.0) { - optimizationsDone++ - return NopStatement.insteadOf(assignment) - } - "^" -> if (cv == 0.0) { - optimizationsDone++ - return NopStatement.insteadOf(assignment) - } - "<<" -> { - if (cv == 0.0) { - optimizationsDone++ - return NopStatement.insteadOf(assignment) - } - if (((targetDt == DataType.UWORD || targetDt == DataType.WORD) && cv > 15.0) || - ((targetDt == DataType.UBYTE || targetDt == DataType.BYTE) && cv > 7.0)) { - assignment.value = NumericLiteralValue.optimalInteger(0, assignment.value.position) - assignment.value.linkParents(assignment) - optimizationsDone++ - } else { - // replace by in-place lsl(...) call - val scope = AnonymousScope(mutableListOf(), assignment.position) - var numshifts = cv.toInt() - while (numshifts > 0) { - scope.statements.add(FunctionCallStatement(IdentifierReference(listOf("lsl"), assignment.position), - mutableListOf(bexpr.left), true, assignment.position)) - numshifts-- - } - optimizationsDone++ - return scope - } - } - ">>" -> { - if (cv == 0.0) { - optimizationsDone++ - return NopStatement.insteadOf(assignment) - } - if ((targetDt == DataType.UWORD && cv > 15.0) || (targetDt == DataType.UBYTE && cv > 7.0)) { - assignment.value = NumericLiteralValue.optimalInteger(0, assignment.value.position) - assignment.value.linkParents(assignment) - optimizationsDone++ - } else { - // replace by in-place lsr(...) call - val scope = AnonymousScope(mutableListOf(), assignment.position) - var numshifts = cv.toInt() - while (numshifts > 0) { - scope.statements.add(FunctionCallStatement(IdentifierReference(listOf("lsr"), assignment.position), - mutableListOf(bexpr.left), true, assignment.position)) - numshifts-- - } - optimizationsDone++ - return scope - } - } - } - } - } - } - - - return super.visit(assignment) - } - - - override fun visit(label: Label): Statement { - // remove duplicate labels - val stmts = label.definingScope().statements - val startIdx = stmts.indexOf(label) - if(startIdx< stmts.lastIndex && stmts[startIdx+1] == label) - return NopStatement.insteadOf(label) - - return super.visit(label) - } -} - - - diff --git a/examples/test.p8 b/examples/test.p8 index 9563a3d0d..6e9418903 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -4,43 +4,31 @@ %zeropage basicsafe + main { - sub start() { - A += 50 + sub jumpsub() { - A += Y + 1 - A -= Y + 1 - A += Y - 1 - A -= Y - 1 - - A += Y + 2 - A -= Y + 2 - A += Y - 2 - A -= Y - 2 - -; ubyte ubb -; byte bb -; uword uww -; word ww -; word ww2 -; -; A = ubb*0 -; Y = ubb*1 -; A = ubb*2 -; Y = ubb*4 -; A = ubb*8 -; Y = ubb*16 -; A = ubb*32 -; Y = ubb*64 -; A = ubb*128 -; Y = ubb+ubb+ubb -; A = ubb+ubb+ubb+ubb -; ww = ww2+ww2 -; ww = ww2+ww2+ww2 -; ww = ww2+ww2+ww2+ww2 + ; goto jumpsub ; TODO fix compiler loop + goto blabla +blabla: + A=99 + return } + + sub start() { + + A <<= 2 + A >>= 2 + A -= 3 + A = A+A + lsl(X) + lsl(Y) + lsl(A) + lsl(@($d020)) + } + }