From 176ec8ac7d64213850c9d075301188a8dcb88cf2 Mon Sep 17 00:00:00 2001 From: Irmen de Jong Date: Tue, 23 Aug 2022 00:05:57 +0200 Subject: [PATCH] fix 6502 codegen bug: complex comparison expression is evaluated wrong. Fixed by reintroducing splitting of comparison expression in if statements by using a temporary variable and/or register to precompute left/right values. --- compiler/src/prog8/compiler/Compiler.kt | 2 +- .../compiler/astprocessing/AstExtensions.kt | 4 +- ...NotExpressionAndIfComparisonExprChanger.kt | 189 ++++++++++++++++++ .../astprocessing/NotExpressionChanger.kt | 80 -------- docs/source/todo.rst | 2 - examples/cx16/circles.p8 | 4 +- examples/test.p8 | 36 ++-- 7 files changed, 212 insertions(+), 105 deletions(-) create mode 100644 compiler/src/prog8/compiler/astprocessing/NotExpressionAndIfComparisonExprChanger.kt delete mode 100644 compiler/src/prog8/compiler/astprocessing/NotExpressionChanger.kt diff --git a/compiler/src/prog8/compiler/Compiler.kt b/compiler/src/prog8/compiler/Compiler.kt index 8bce24256..16f5c763a 100644 --- a/compiler/src/prog8/compiler/Compiler.kt +++ b/compiler/src/prog8/compiler/Compiler.kt @@ -339,7 +339,7 @@ private fun processAst(program: Program, errors: IErrorReporter, compilerOptions errors.report() program.reorderStatements(errors, compilerOptions) errors.report() - program.changeNotExpression(errors) + program.changeNotExpressionAndIfComparisonExpr(errors, compilerOptions.compTarget) errors.report() program.addTypecasts(errors, compilerOptions) errors.report() diff --git a/compiler/src/prog8/compiler/astprocessing/AstExtensions.kt b/compiler/src/prog8/compiler/astprocessing/AstExtensions.kt index cc2920cdf..e95729e24 100644 --- a/compiler/src/prog8/compiler/astprocessing/AstExtensions.kt +++ b/compiler/src/prog8/compiler/astprocessing/AstExtensions.kt @@ -50,8 +50,8 @@ internal fun Program.reorderStatements(errors: IErrorReporter, options: Compilat } } -internal fun Program.changeNotExpression(errors: IErrorReporter) { - val changer = NotExpressionChanger(this, errors) +internal fun Program.changeNotExpressionAndIfComparisonExpr(errors: IErrorReporter, target: ICompilationTarget) { + val changer = NotExpressionAndIfComparisonExprChanger(this, errors, target) changer.visit(this) while(errors.noErrors() && changer.applyModifications()>0) { changer.visit(this) diff --git a/compiler/src/prog8/compiler/astprocessing/NotExpressionAndIfComparisonExprChanger.kt b/compiler/src/prog8/compiler/astprocessing/NotExpressionAndIfComparisonExprChanger.kt new file mode 100644 index 000000000..0f40d8469 --- /dev/null +++ b/compiler/src/prog8/compiler/astprocessing/NotExpressionAndIfComparisonExprChanger.kt @@ -0,0 +1,189 @@ +package prog8.compiler.astprocessing + +import prog8.ast.* +import prog8.ast.base.FatalAstException +import prog8.ast.expressions.* +import prog8.ast.statements.AssignTarget +import prog8.ast.statements.Assignment +import prog8.ast.statements.AssignmentOrigin +import prog8.ast.statements.IfElse +import prog8.ast.walk.AstWalker +import prog8.ast.walk.IAstModification +import prog8.code.core.* +import prog8.code.target.VMTarget + +internal class NotExpressionAndIfComparisonExprChanger(val program: Program, val errors: IErrorReporter, val compTarget: ICompilationTarget) : AstWalker() { + + override fun before(expr: BinaryExpression, parent: Node): Iterable { + if(expr.operator=="==" || expr.operator=="!=") { + val left = expr.left as? BinaryExpression + if (left != null) { + val rightValue = expr.right.constValue(program) + if (rightValue?.number == 0.0 && rightValue.type in IntegerDatatypes) { + if (left.operator == "==" && expr.operator == "==") { + // (x==something)==0 --> x!=something + left.operator = "!=" + return listOf(IAstModification.ReplaceNode(expr, left, parent)) + } else if (left.operator == "!=" && expr.operator == "==") { + // (x!=something)==0 --> x==something + left.operator = "==" + return listOf(IAstModification.ReplaceNode(expr, left, parent)) + } else if (left.operator == "==" && expr.operator == "!=") { + // (x==something)!=0 --> x==something + left.operator = "==" + return listOf(IAstModification.ReplaceNode(expr, left, parent)) + } else if (left.operator == "!=" && expr.operator == "!=") { + // (x!=something)!=0 --> x!=something + left.operator = "!=" + return listOf(IAstModification.ReplaceNode(expr, left, parent)) + } + } + } + } + return noModifications + } + + override fun after(expr: PrefixExpression, parent: Node): Iterable { + if(expr.operator == "not") { + // not(not(x)) -> x + if((expr.expression as? PrefixExpression)?.operator=="not") + return listOf(IAstModification.ReplaceNode(expr, expr.expression, parent)) + // not(~x) -> x!=0 + if((expr.expression as? PrefixExpression)?.operator=="~") { + val x = (expr.expression as PrefixExpression).expression + val dt = x.inferType(program).getOrElse { throw FatalAstException("invalid dt") } + val notZero = BinaryExpression(x, "!=", NumericLiteral(dt, 0.0, expr.position), expr.position) + return listOf(IAstModification.ReplaceNode(expr, notZero, parent)) + } + val subBinExpr = expr.expression as? BinaryExpression + if(subBinExpr?.operator=="==") { + if(subBinExpr.right.constValue(program)?.number==0.0) { + // not(x==0) -> x!=0 + subBinExpr.operator = "!=" + return listOf(IAstModification.ReplaceNode(expr, subBinExpr, parent)) + } + } else if(subBinExpr?.operator=="!=") { + if(subBinExpr.right.constValue(program)?.number==0.0) { + // not(x!=0) -> x==0 + subBinExpr.operator = "==" + return listOf(IAstModification.ReplaceNode(expr, subBinExpr, parent)) + } + } + + // all other not(x) --> x==0 + // this means that "not" will never occur anywhere again in the ast after this + val replacement = BinaryExpression(expr.expression, "==", NumericLiteral(DataType.UBYTE,0.0, expr.position), expr.position) + return listOf(IAstModification.ReplaceNodeSafe(expr, replacement, parent)) + } + return noModifications + } + + + override fun before(ifElse: IfElse, parent: Node): Iterable { + if(compTarget.name == VMTarget.NAME) // don't apply this optimization for Vm target + return noModifications + + val binExpr = ifElse.condition as? BinaryExpression + if(binExpr==null || binExpr.operator !in ComparisonOperators) + return noModifications + + // Simplify the conditional expression, introduce simple assignments if required. + // This is REQUIRED for correct code generation on 6502 because evaluating certain expressions + // clobber the handful of temporary variables in the zeropage and leaving everything in one + // expression then results in invalid values being compared. VM Codegen doesn't suffer from this. + // NOTE: sometimes this increases code size because additional stores/loads are generated for the + // intermediate variables. We assume these are optimized away from the resulting assembly code later. + val simplify = simplifyConditionalExpression(binExpr) + val modifications = mutableListOf() + if(simplify.rightVarAssignment!=null) { + modifications += IAstModification.ReplaceNode(binExpr.right, simplify.rightOperandReplacement!!, binExpr) + modifications += IAstModification.InsertBefore( + ifElse, + simplify.rightVarAssignment, + parent as IStatementContainer + ) + } + if(simplify.leftVarAssignment!=null) { + modifications += IAstModification.ReplaceNode(binExpr.left, simplify.leftOperandReplacement!!, binExpr) + modifications += IAstModification.InsertBefore( + ifElse, + simplify.leftVarAssignment, + parent as IStatementContainer + ) + } + + return modifications + } + + private class CondExprSimplificationResult( + val leftVarAssignment: Assignment?, + val leftOperandReplacement: Expression?, + val rightVarAssignment: Assignment?, + val rightOperandReplacement: Expression? + ) + + private fun simplifyConditionalExpression(expr: BinaryExpression): CondExprSimplificationResult { + + // TODO: somehow figure out if the expr will result in stack-evaluation STILL after being split off, + // in that case: do *not* split it off but just keep it as it is (otherwise code size increases) + // NOTE: do NOT move this to an earler ast transform phase (such as StatementReorderer or StatementOptimizer) - it WILL result in larger code. + + if(compTarget.name == VMTarget.NAME) // don't apply this optimization for Vm target + return CondExprSimplificationResult(null, null, null, null) + + var leftAssignment: Assignment? = null + var leftOperandReplacement: Expression? = null + var rightAssignment: Assignment? = null + var rightOperandReplacement: Expression? = null + + val separateLeftExpr = !expr.left.isSimple + && expr.left !is IFunctionCall + && expr.left !is ContainmentCheck + val separateRightExpr = !expr.right.isSimple + && expr.right !is IFunctionCall + && expr.right !is ContainmentCheck + val leftDt = expr.left.inferType(program) + val rightDt = expr.right.inferType(program) + + if(!leftDt.isInteger || !rightDt.isInteger) { + // we can't reasonably simplify non-integer expressions + return CondExprSimplificationResult(null, null, null, null) + } + + if(separateLeftExpr) { + val name = getTempRegisterName(leftDt) + leftOperandReplacement = IdentifierReference(name, expr.position) + leftAssignment = Assignment( + AssignTarget(IdentifierReference(name, expr.position), null, null, expr.position), + expr.left.copy(), + AssignmentOrigin.BEFOREASMGEN, expr.position + ) + } + if(separateRightExpr) { + val (tempVarName, _) = program.getTempVar(rightDt.getOrElse { throw FatalAstException("invalid dt") }, true) + rightOperandReplacement = IdentifierReference(tempVarName, expr.position) + rightAssignment = Assignment( + AssignTarget(IdentifierReference(tempVarName, expr.position), null, null, expr.position), + expr.right.copy(), + AssignmentOrigin.BEFOREASMGEN, expr.position + ) + } + return CondExprSimplificationResult( + leftAssignment, leftOperandReplacement, + rightAssignment, rightOperandReplacement + ) + } + + fun getTempRegisterName(dt: InferredTypes.InferredType): List { + return when { + // TODO assume (hope) cx16.r9 isn't used for anything else during the use of this temporary variable... + dt istype DataType.UBYTE -> listOf("cx16", "r9L") + dt istype DataType.BOOL -> listOf("cx16", "r9L") + dt istype DataType.BYTE -> listOf("cx16", "r9sL") + dt istype DataType.UWORD -> listOf("cx16", "r9") + dt istype DataType.WORD -> listOf("cx16", "r9s") + dt.isPassByReference -> listOf("cx16", "r9") + else -> throw FatalAstException("invalid dt $dt") + } + } +} diff --git a/compiler/src/prog8/compiler/astprocessing/NotExpressionChanger.kt b/compiler/src/prog8/compiler/astprocessing/NotExpressionChanger.kt deleted file mode 100644 index de3948702..000000000 --- a/compiler/src/prog8/compiler/astprocessing/NotExpressionChanger.kt +++ /dev/null @@ -1,80 +0,0 @@ -package prog8.compiler.astprocessing - -import prog8.ast.Node -import prog8.ast.Program -import prog8.ast.base.FatalAstException -import prog8.ast.expressions.BinaryExpression -import prog8.ast.expressions.NumericLiteral -import prog8.ast.expressions.PrefixExpression -import prog8.ast.walk.AstWalker -import prog8.ast.walk.IAstModification -import prog8.code.core.DataType -import prog8.code.core.IErrorReporter -import prog8.code.core.IntegerDatatypes - -internal class NotExpressionChanger(val program: Program, val errors: IErrorReporter) : AstWalker() { - - override fun before(expr: BinaryExpression, parent: Node): Iterable { - if(expr.operator=="==" || expr.operator=="!=") { - val left = expr.left as? BinaryExpression - if (left != null) { - val rightValue = expr.right.constValue(program) - if (rightValue?.number == 0.0 && rightValue.type in IntegerDatatypes) { - if (left.operator == "==" && expr.operator == "==") { - // (x==something)==0 --> x!=something - left.operator = "!=" - return listOf(IAstModification.ReplaceNode(expr, left, parent)) - } else if (left.operator == "!=" && expr.operator == "==") { - // (x!=something)==0 --> x==something - left.operator = "==" - return listOf(IAstModification.ReplaceNode(expr, left, parent)) - } else if (left.operator == "==" && expr.operator == "!=") { - // (x==something)!=0 --> x==something - left.operator = "==" - return listOf(IAstModification.ReplaceNode(expr, left, parent)) - } else if (left.operator == "!=" && expr.operator == "!=") { - // (x!=something)!=0 --> x!=something - left.operator = "!=" - return listOf(IAstModification.ReplaceNode(expr, left, parent)) - } - } - } - } - return noModifications - } - - override fun after(expr: PrefixExpression, parent: Node): Iterable { - if(expr.operator == "not") { - // not(not(x)) -> x - if((expr.expression as? PrefixExpression)?.operator=="not") - return listOf(IAstModification.ReplaceNode(expr, expr.expression, parent)) - // not(~x) -> x!=0 - if((expr.expression as? PrefixExpression)?.operator=="~") { - val x = (expr.expression as PrefixExpression).expression - val dt = x.inferType(program).getOrElse { throw FatalAstException("invalid dt") } - val notZero = BinaryExpression(x, "!=", NumericLiteral(dt, 0.0, expr.position), expr.position) - return listOf(IAstModification.ReplaceNode(expr, notZero, parent)) - } - val subBinExpr = expr.expression as? BinaryExpression - if(subBinExpr?.operator=="==") { - if(subBinExpr.right.constValue(program)?.number==0.0) { - // not(x==0) -> x!=0 - subBinExpr.operator = "!=" - return listOf(IAstModification.ReplaceNode(expr, subBinExpr, parent)) - } - } else if(subBinExpr?.operator=="!=") { - if(subBinExpr.right.constValue(program)?.number==0.0) { - // not(x!=0) -> x==0 - subBinExpr.operator = "==" - return listOf(IAstModification.ReplaceNode(expr, subBinExpr, parent)) - } - } - - // all other not(x) --> x==0 - // this means that "not" will never occur anywhere again in the ast after this - val replacement = BinaryExpression(expr.expression, "==", NumericLiteral(DataType.UBYTE,0.0, expr.position), expr.position) - return listOf(IAstModification.ReplaceNodeSafe(expr, replacement, parent)) - } - return noModifications - } -} diff --git a/docs/source/todo.rst b/docs/source/todo.rst index 30378bae2..258367dde 100644 --- a/docs/source/todo.rst +++ b/docs/source/todo.rst @@ -13,8 +13,6 @@ For next release dpad up (and/or controller X btn) = "C" (hold) controller start = F1 (new game) make the 'R' tilemap a backwards R :) -- fix 6502 codegen bug (check that it isn't also in vm codegen): complex comparison expression is evaluated wrong - see circles.p8 line 57; if distance(c) < (radius as uword) + circle_radius[c] - vm: intermediate code: don't flatten everything. Instead, as a new intermediary step, convert the new Ast into *structured* intermediary code. Basically keep the blocks and subroutines structure, including full subroutine signature information, diff --git a/examples/cx16/circles.p8 b/examples/cx16/circles.p8 index 04726eca7..eb3c49adb 100644 --- a/examples/cx16/circles.p8 +++ b/examples/cx16/circles.p8 @@ -54,9 +54,7 @@ main { return true ubyte @zp c for c in 0 to num_circles-1 { - ; TODO FIX THIS IN 6502 CODEGEN: if distance(c) < (radius as uword) + circle_radius[c] - cx16.r15 = (radius as uword) + circle_radius[c] - if distance(c) < cx16.r15 + if distance(c) < (radius as uword) + circle_radius[c] return false } return true diff --git a/examples/test.p8 b/examples/test.p8 index dfe39e4d8..6bf82e2a8 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -5,25 +5,27 @@ main { sub start() { - uword slab1 = memory("slab 1", 2000, 0) - uword slab2 = memory("slab 1", 2000, 0) - uword slab3 = memory("slab other", 2000, 64) + ubyte c + ubyte radius = 1 + ubyte[] circle_radius = [5,10,15,20,25,30] - txt.print_uwhex(slab1, true) - txt.print_uwhex(slab2, true) - txt.print_uwhex(slab3, true) + for c in 0 to len(circle_radius)-1 { + if distance(c) < (radius as uword) + circle_radius[c] + txt.chrout('y') + else + txt.chrout('n') + cx16.r15 = (radius as uword) + circle_radius[c] + if distance(c) < cx16.r15 + txt.chrout('y') + else + txt.chrout('n') + txt.nl() + } + } - ubyte rasterCount = 231 - - if rasterCount >= 230 - txt.print("y1") - - if rasterCount ^ $80 >= 230 - txt.print("y2") - - if (rasterCount ^ $80) >= 230 - txt.print("y3") - + sub distance(ubyte cix) -> uword { + uword sqx = cix+10 + return sqrt16(sqx*sqx) } }