diff --git a/codeGeneration/src/prog8/compiler/target/cpu6502/codegen/AsmGen.kt b/codeGeneration/src/prog8/compiler/target/cpu6502/codegen/AsmGen.kt index 7548a5ef4..b096b9f41 100644 --- a/codeGeneration/src/prog8/compiler/target/cpu6502/codegen/AsmGen.kt +++ b/codeGeneration/src/prog8/compiler/target/cpu6502/codegen/AsmGen.kt @@ -1098,16 +1098,21 @@ class AsmGen(private val program: Program, val booleanCondition = stmt.condition as BinaryExpression if (stmt.elsepart.isEmpty()) { - val endLabel = makeLabel("if_end") - translateComparisonExpressionWithJumpIfFalse(booleanCondition, endLabel) - translate(stmt.truepart) - out(endLabel) +// TODO specialize this +// if(stmt.truepart.statements.singleOrNull() is Jump) { +// translateCompareAndJumpIfTrue(booleanCondition, stmt.truepart.statements[0] as Jump) +// } else { + val endLabel = makeLabel("if_end") + translateCompareAndJumpIfFalse(booleanCondition, endLabel) + translate(stmt.truepart) + out(endLabel) +// } } else { // both true and else parts val elseLabel = makeLabel("if_else") val endLabel = makeLabel("if_end") - translateComparisonExpressionWithJumpIfFalse(booleanCondition, elseLabel) + translateCompareAndJumpIfFalse(booleanCondition, elseLabel) translate(stmt.truepart) jmp(endLabel) out(elseLabel) @@ -1607,44 +1612,39 @@ $label nop""") return false } - - private fun translateComparisonExpressionWithJumpIfFalse(expr: BinaryExpression, jumpIfFalseLabel: String) { - // This is a helper routine called from if expressions to generate optimized conditional branching code. - // First, if it is of the form: X , then flip the expression so the constant is always the right operand. - - var left = expr.left - var right = expr.right - var operator = expr.operator - var leftConstVal = left.constValue(program) - var rightConstVal = right.constValue(program) - - // make sure the constant value is on the right of the comparison expression - if(leftConstVal!=null) { - val tmp = left - left = right - right = tmp - val tmp2 = leftConstVal - leftConstVal = rightConstVal - rightConstVal = tmp2 - when(expr.operator) { - "<" -> operator = ">" - "<=" -> operator = ">=" - ">" -> operator = "<" - ">=" -> operator = "<=" - } - } + private fun translateCompareAndJumpIfTrue(expr: BinaryExpression, jump: Jump) { + val left = expr.left + val right = expr.right + val operator = expr.operator + val leftConstVal = left.constValue(program) + val rightConstVal = right.constValue(program) if (rightConstVal!=null && rightConstVal.number == 0.0) - jumpIfZeroOrNot(left, operator, jumpIfFalseLabel) + testZeroAndJump(left, operator, jump, null) else - jumpIfComparison(left, operator, right, jumpIfFalseLabel, leftConstVal, rightConstVal) + testNonzeroComparisonAndJump(left, operator, right, jump, null, leftConstVal, rightConstVal) } - private fun jumpIfZeroOrNot( + private fun translateCompareAndJumpIfFalse(expr: BinaryExpression, jumpIfFalseLabel: String) { + val left = expr.left + val right = expr.right + val operator = expr.operator + val leftConstVal = left.constValue(program) + val rightConstVal = right.constValue(program) + + if (rightConstVal!=null && rightConstVal.number == 0.0) + testZeroAndJump(left, operator, null, jumpIfFalseLabel) + else + testNonzeroComparisonAndJump(left, operator, right, null, jumpIfFalseLabel, leftConstVal, rightConstVal) + } + + private fun testZeroAndJump( left: Expression, operator: String, - jumpIfFalseLabel: String + jumpIfTrue: Jump?, + jumpIfFalseLabel: String? ) { + require(jumpIfTrue!=null || jumpIfFalseLabel!=null) when(val dt = left.inferType(program).getOr(DataType.UNDEFINED)) { DataType.UBYTE, DataType.UWORD -> { if(operator=="<") { @@ -1729,16 +1729,20 @@ $label nop""") } } - private fun jumpIfComparison( + private fun testNonzeroComparisonAndJump( left: Expression, operator: String, right: Expression, - jumpIfFalseLabel: String, + jumpIfTrue: Jump?, + jumpIfFalseLabel: String?, leftConstVal: NumericLiteralValue?, rightConstVal: NumericLiteralValue? ) { + require(jumpIfTrue!=null || jumpIfFalseLabel!=null) val dt = left.inferType(program).getOrElse { throw AssemblyError("unknown dt") } + jumpIfFalseLabel!! // TODO jump if true... or rewrite everything to use just jump-if-false + when (operator) { "==" -> { when (dt) { diff --git a/codeOptimizers/src/prog8/optimizer/ExpressionSimplifier.kt b/codeOptimizers/src/prog8/optimizer/ExpressionSimplifier.kt index 2df64e661..7fc47026c 100644 --- a/codeOptimizers/src/prog8/optimizer/ExpressionSimplifier.kt +++ b/codeOptimizers/src/prog8/optimizer/ExpressionSimplifier.kt @@ -1,5 +1,6 @@ package prog8.optimizer +import prog8.ast.IStatementContainer import prog8.ast.Node import prog8.ast.Program import prog8.ast.base.DataType @@ -7,7 +8,10 @@ import prog8.ast.base.FatalAstException import prog8.ast.base.IntegerDatatypes import prog8.ast.base.NumericDatatypes import prog8.ast.expressions.* +import prog8.ast.statements.AnonymousScope import prog8.ast.statements.Assignment +import prog8.ast.statements.IfStatement +import prog8.ast.statements.Jump import prog8.ast.walk.AstWalker import prog8.ast.walk.IAstModification import kotlin.math.abs @@ -54,6 +58,31 @@ class ExpressionSimplifier(private val program: Program) : AstWalker() { return mods } + override fun after(ifStatement: IfStatement, parent: Node): Iterable { + val truepart = ifStatement.truepart + val elsepart = ifStatement.elsepart + if(truepart.isNotEmpty() && elsepart.isNotEmpty()) { + if(truepart.statements.singleOrNull() is Jump) { + return listOf( + IAstModification.InsertAfter(ifStatement, elsepart, parent as IStatementContainer), + IAstModification.ReplaceNode(elsepart, AnonymousScope(mutableListOf(), elsepart.position), ifStatement) + ) + } + if(elsepart.statements.singleOrNull() is Jump) { + val invertedCondition = invertCondition(ifStatement.condition) + if(invertedCondition!=null) { + return listOf( + IAstModification.ReplaceNode(ifStatement.condition, invertedCondition, ifStatement), + IAstModification.InsertAfter(ifStatement, truepart, parent as IStatementContainer), + IAstModification.ReplaceNode(elsepart, AnonymousScope(mutableListOf(), elsepart.position), ifStatement), + IAstModification.ReplaceNode(truepart, elsepart, ifStatement) + ) + } + } + } + return noModifications + } + override fun after(expr: BinaryExpression, parent: Node): Iterable { val leftVal = expr.left.constValue(program) val rightVal = expr.right.constValue(program) @@ -696,3 +725,24 @@ class ExpressionSimplifier(private val program: Program) : AstWalker() { private data class BinExprWithConstants(val expr: BinaryExpression, val leftVal: NumericLiteralValue?, val rightVal: NumericLiteralValue?) } + + +fun invertCondition(cond: Expression): BinaryExpression? { + if(cond is BinaryExpression) { + val invertedOperator = invertedComparisonOperator(cond.operator) + if (invertedOperator != null) + return BinaryExpression(cond.left, invertedOperator, cond.right, cond.position) + } + return null +} + +fun invertedComparisonOperator(operator: String) = + when (operator) { + "==" -> "!=" + "!=" -> "==" + "<" -> ">=" + ">" -> "<=" + "<=" -> ">" + ">=" -> "<" + else -> null + } diff --git a/compiler/src/prog8/compiler/astprocessing/VariousCleanups.kt b/compiler/src/prog8/compiler/astprocessing/VariousCleanups.kt index a32afa87c..f159a5b20 100644 --- a/compiler/src/prog8/compiler/astprocessing/VariousCleanups.kt +++ b/compiler/src/prog8/compiler/astprocessing/VariousCleanups.kt @@ -9,6 +9,7 @@ import prog8.ast.statements.* import prog8.ast.walk.AstWalker import prog8.ast.walk.IAstModification import prog8.compilerinterface.IErrorReporter +import prog8.optimizer.invertedComparisonOperator internal class VariousCleanups(val program: Program, val errors: IErrorReporter): AstWalker() { @@ -81,16 +82,7 @@ internal class VariousCleanups(val program: Program, val errors: IErrorReporter) val comparison = expr.expression as? BinaryExpression if (comparison != null) { // NOT COMPARISON ==> inverted COMPARISON - val invertedOperator = - when (comparison.operator) { - "==" -> "!=" - "!=" -> "==" - "<" -> ">=" - ">" -> "<=" - "<=" -> ">" - ">=" -> "<" - else -> null - } + val invertedOperator = invertedComparisonOperator(comparison.operator) if (invertedOperator != null) { comparison.operator = invertedOperator return listOf(IAstModification.ReplaceNode(expr, comparison, parent)) @@ -99,4 +91,25 @@ internal class VariousCleanups(val program: Program, val errors: IErrorReporter) } return noModifications } + + override fun after(expr: BinaryExpression, parent: Node): Iterable { + if(expr.operator in ComparisonOperators) { + val leftConstVal = expr.left.constValue(program) + val rightConstVal = expr.right.constValue(program) + // make sure the constant value is on the right of the comparison expression + if(rightConstVal==null && leftConstVal!=null) { + val newOperator = + when(expr.operator) { + "<" -> ">" + "<=" -> ">=" + ">" -> "<" + ">=" -> "<=" + else -> expr.operator + } + val replacement = BinaryExpression(expr.right, newOperator, expr.left, expr.position) + return listOf(IAstModification.ReplaceNode(expr, replacement, parent)) + } + } + return noModifications + } }