From af98d01053d64236d7f6ccfd7a33441f35d4d9fb Mon Sep 17 00:00:00 2001 From: Irmen de Jong Date: Mon, 27 Jun 2022 21:40:48 +0200 Subject: [PATCH] failed attempt at McCarthy shortcut evaluation --- .../compiler/astprocessing/CodeDesugarer.kt | 15 +- .../astprocessing/StatementReorderer.kt | 154 +++++++++++++----- compilerAst/src/prog8/ast/Program.kt | 4 + docs/source/todo.rst | 3 + examples/test.p8 | 35 +++- 5 files changed, 156 insertions(+), 55 deletions(-) diff --git a/compiler/src/prog8/compiler/astprocessing/CodeDesugarer.kt b/compiler/src/prog8/compiler/astprocessing/CodeDesugarer.kt index 8152697b1..00014eced 100644 --- a/compiler/src/prog8/compiler/astprocessing/CodeDesugarer.kt +++ b/compiler/src/prog8/compiler/astprocessing/CodeDesugarer.kt @@ -27,16 +27,11 @@ internal class CodeDesugarer(val program: Program, // - repeat-forever loops replaced by label+jump. - private fun jumpLabel(label: Label): Jump { - val ident = IdentifierReference(listOf(label.name), label.position) - return Jump(null, ident, null, label.position) - } - override fun before(breakStmt: Break, parent: Node): Iterable { fun jumpAfter(stmt: Statement): Iterable { val label = program.makeLabel("after", breakStmt.position) return listOf( - IAstModification.ReplaceNode(breakStmt, jumpLabel(label), parent), + IAstModification.ReplaceNode(breakStmt, program.jumpLabel(label), parent), IAstModification.InsertAfter(stmt, label, stmt.parent as IStatementContainer) ) } @@ -73,7 +68,7 @@ if not CONDITION loopLabel, untilLoop.body, IfElse(notCondition, - AnonymousScope(mutableListOf(jumpLabel(loopLabel)), pos), + AnonymousScope(mutableListOf(program.jumpLabel(loopLabel)), pos), AnonymousScope(mutableListOf(), pos), pos) ), pos) @@ -97,11 +92,11 @@ _after: val replacement = AnonymousScope(mutableListOf( loopLabel, IfElse(notCondition, - AnonymousScope(mutableListOf(jumpLabel(afterLabel)), pos), + AnonymousScope(mutableListOf(program.jumpLabel(afterLabel)), pos), AnonymousScope(mutableListOf(), pos), pos), whileLoop.body, - jumpLabel(loopLabel), + program.jumpLabel(loopLabel), afterLabel ), pos) return listOf(IAstModification.ReplaceNode(whileLoop, replacement, parent)) @@ -131,7 +126,7 @@ _after: override fun after(repeatLoop: RepeatLoop, parent: Node): Iterable { if(repeatLoop.iterations==null) { val label = program.makeLabel("repeat", repeatLoop.position) - val jump = jumpLabel(label) + val jump = program.jumpLabel(label) return listOf( IAstModification.InsertFirst(label, repeatLoop.body), IAstModification.InsertLast(jump, repeatLoop.body), diff --git a/compiler/src/prog8/compiler/astprocessing/StatementReorderer.kt b/compiler/src/prog8/compiler/astprocessing/StatementReorderer.kt index 3c0b0d2ce..a62341a15 100644 --- a/compiler/src/prog8/compiler/astprocessing/StatementReorderer.kt +++ b/compiler/src/prog8/compiler/astprocessing/StatementReorderer.kt @@ -200,6 +200,17 @@ internal class StatementReorderer(val program: Program, && maySwapOperandOrder(expr)) return listOf(IAstModification.SwapOperands(expr)) + if (expr.operator == "and") { + val leftBinExpr = expr.left as? BinaryExpression + if (leftBinExpr?.operator == "and") + return mcCarthyAndExpression(expr, parent) + } + if (expr.operator == "or") { + val leftBinExpr = expr.left as? BinaryExpression + if (leftBinExpr?.operator == "or") + return mcCarthyOrExpression(expr, parent) + } + // when using a simple bit shift and assigning it to a variable of a different type, // try to make the bit shifting 'wide enough' to fall into the variable's type. // with this, for instance, uword x = 1 << 10 will result in 1024 rather than 0 (the ubyte result). @@ -279,6 +290,7 @@ internal class StatementReorderer(val program: Program, } } } + return noModifications } @@ -320,16 +332,16 @@ internal class StatementReorderer(val program: Program, // rewrite in-place assignment expressions a bit so that the assignment target usually is the leftmost operand val binExpr = assignment.value as? BinaryExpression if(binExpr!=null) { -// if (binExpr.operator == "and") { -// val leftBinExpr = binExpr.left as? BinaryExpression -// if (leftBinExpr?.operator == "and") -// return mcCarthyAndAssignment(binExpr, assignment, parent) -// } -// if (binExpr.operator == "or") { -// val leftBinExpr = binExpr.left as? BinaryExpression -// if (leftBinExpr?.operator == "or") -// return mcCarthyOrAssignment(binExpr, assignment, parent) -// } + if (binExpr.operator == "and") { + val leftBinExpr = binExpr.left as? BinaryExpression + if (leftBinExpr?.operator == "and") + return mcCarthyAndAssignment(binExpr, assignment, parent) + } + if (binExpr.operator == "or") { + val leftBinExpr = binExpr.left as? BinaryExpression + if (leftBinExpr?.operator == "or") + return mcCarthyOrAssignment(binExpr, assignment, parent) + } if(binExpr.left isSameAs assignment.target) { // A = A 5, unchanged @@ -438,35 +450,73 @@ internal class StatementReorderer(val program: Program, } private fun mcCarthyAndAssignment(andExpr: BinaryExpression, assignment: Assignment, assignmentParent: Node): Iterable { - val andTerms = findTerms(andExpr, "and") - if(andTerms.any { it.constValue(program)?.asBooleanValue == false }) { + val terms = findTerms(andExpr, "and") + if(terms.any { it.constValue(program)?.asBooleanValue == false }) { errors.warn("expression is always false", andExpr.position) return listOf(IAstModification.ReplaceNode(andExpr, NumericLiteral.fromBoolean(false, andExpr.position), assignment)) } - // TODO: - // repeat for all terms: - // assign term to target - // add if: if target==0: goto done - // done:. - - return noModifications + val replacement = doShortcutEvaluation( + assignment.target, + assignment.target.inferType(program).getOr(DataType.UNDEFINED), + terms, + "==", + false, + assignment.position + ) + return listOf(IAstModification.ReplaceNode(assignment, replacement, assignmentParent)) } private fun mcCarthyOrAssignment(orExpr: BinaryExpression, assignment: Assignment, assignmentParent: Node): Iterable { - val andTerms = findTerms(orExpr, "or") - if(andTerms.any { it.constValue(program)?.asBooleanValue == true }) { + val terms = findTerms(orExpr, "or") + if(terms.any { it.constValue(program)?.asBooleanValue == true }) { errors.warn("expression is always true", orExpr.position) return listOf(IAstModification.ReplaceNode(orExpr, NumericLiteral.fromBoolean(true, orExpr.position), assignment)) } - // TODO: - // repeat for all terms: - // assign term to target - // add if: if target!=0: goto done - // done:. + val replacement = doShortcutEvaluation( + assignment.target, + assignment.target.inferType(program).getOr(DataType.UNDEFINED), + terms, + "!=", + false, + assignment.position + ) + return listOf(IAstModification.ReplaceNode(assignment, replacement, assignmentParent)) + } - return noModifications + private fun mcCarthyAndExpression(expr: BinaryExpression, parent: Node): Iterable { + val terms = findTerms(expr, "and") + if(terms.any { it.constValue(program)?.asBooleanValue == false }) { + errors.warn("expression is always false", expr.position) + return listOf(IAstModification.ReplaceNode(expr, NumericLiteral.fromBoolean(false, expr.position), parent)) + } + + val (tempvarName, _) = program.getTempVar(DataType.UBYTE) + val assignTarget = AssignTarget(IdentifierReference(tempvarName, expr.position), null, null, position = expr.position) + val replacement = doShortcutEvaluation(assignTarget, DataType.UBYTE, terms, "==", true, expr.position) + val exprStmt = expr.containingStatement + return listOf( + IAstModification.InsertBefore(exprStmt, replacement, exprStmt.definingScope), + IAstModification.ReplaceNode(expr, assignTarget.identifier!!.copy(), parent) + ) + } + + private fun mcCarthyOrExpression(expr: BinaryExpression, parent: Node): Iterable { + val terms = findTerms(expr, "or") + if(terms.any { it.constValue(program)?.asBooleanValue == true }) { + errors.warn("expression is always true", expr.position) + return listOf(IAstModification.ReplaceNode(expr, NumericLiteral.fromBoolean(false, expr.position), parent)) + } + + val (tempvarName, _) = program.getTempVar(DataType.UBYTE) + val assignTarget = AssignTarget(IdentifierReference(tempvarName, expr.position), null, null, position = expr.position) + val replacement = doShortcutEvaluation(assignTarget, DataType.UBYTE, terms, "!=", true, expr.position) + val exprStmt = expr.containingStatement + return listOf( + IAstModification.InsertBefore(exprStmt, replacement, exprStmt.definingScope), + IAstModification.ReplaceNode(expr, assignTarget.identifier!!.copy(), parent) + ) } private fun findTerms(expr: BinaryExpression, operator: String, terms: List = emptyList()): List { @@ -475,16 +525,46 @@ internal class StatementReorderer(val program: Program, val leftBinExpr = expr.left as? BinaryExpression ?: return listOf(expr.left, expr.right) + terms return findTerms(leftBinExpr, operator, listOf(expr.right)+terms) } - - private fun mcCarthyAndExpression(e1: Expression, e2: Expression, e3: Expression, origAndExpression: BinaryExpression, parent: Node): Iterable { - // TODO - println("AND EXPRESSION:") - println(" $e1") - println(" $e2") - println(" $e3") - println(" (orig) $origAndExpression") - val replacement = NumericLiteral.fromBoolean(false, origAndExpression.position) - return listOf(IAstModification.ReplaceNode(origAndExpression, replacement, parent)) + + private fun doShortcutEvaluation( + target: AssignTarget, + targetDt: DataType, + terms: List, + checkZeroOperator: String, + convertToBool: Boolean, + position: Position + ): AnonymousScope { + val replacement = AnonymousScope(mutableListOf(), position) + val doneLabel = program.makeLabel("and", position) + for ((idx, term) in terms.withIndex()) { + val value: Expression = if (term is IFunctionCall + && term.target.nameInSource == listOf("boolean") + && term.args[0].inferType(program).isAssignableTo(targetDt) + ) + term.args[0].copy() + else + term.copy() + val assignTerm = Assignment(target.copy(), value, AssignmentOrigin.OPTIMIZER, position) + replacement.statements.add(assignTerm) + if (idx < terms.size - 1) { + val targetCheck = BinaryExpression(target.toExpression(), checkZeroOperator, + NumericLiteral(DataType.UBYTE, 0.0, position), position) + val jumpDone = program.jumpLabel(doneLabel) + val ifStmt = IfElse(targetCheck, + AnonymousScope(mutableListOf(jumpDone), position), + AnonymousScope(mutableListOf(), Position.DUMMY), + position) + replacement.statements.add(ifStmt) + } else if(idx==terms.size-1) { + if(convertToBool) { + assignTerm.value = BuiltinFunctionCall(IdentifierReference(listOf("boolean"), assignTerm.position), mutableListOf(value), assignTerm.position) + } + if(targetDt !in ByteDatatypes) + assignTerm.value = TypecastExpression(assignTerm.value, targetDt, true, assignTerm.position) + } + } + replacement.statements.add(doneLabel) + return replacement } } diff --git a/compilerAst/src/prog8/ast/Program.kt b/compilerAst/src/prog8/ast/Program.kt index 5974fb5d9..0460d914b 100644 --- a/compilerAst/src/prog8/ast/Program.kt +++ b/compilerAst/src/prog8/ast/Program.kt @@ -155,6 +155,10 @@ class Program(val name: String, return Label(strLabel, position) } + fun jumpLabel(label: Label): Jump { + val ident = IdentifierReference(listOf(label.name), label.position) + return Jump(null, ident, null, label.position) + } } diff --git a/docs/source/todo.rst b/docs/source/todo.rst index d13e9b0c0..e85f899a7 100644 --- a/docs/source/todo.rst +++ b/docs/source/todo.rst @@ -5,6 +5,9 @@ For next release ^^^^^^^^^^^^^^^^ - add McCarthy evaluation to shortcircuit and/or expressions. Both conditional expressions and assignments! StatementReorder.after(assignment). + TODO: boolean expressions. +- add unit tests for all 4 mcarthy shortcut cases. +- can we optimize redundant calls to boolean() away? imageviewer.prg got larger because of them - add some more optimizations in vmPeepholeOptimizer - vm Instruction needs to know what the read-registers/memory are, and what the write-register/memory is. this info is needed for more advanced optimizations and later code generation steps. diff --git a/examples/test.p8 b/examples/test.p8 index a0abf244e..56e0d2b32 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -11,7 +11,7 @@ main { } sub funcFalseWord() -> uword { - txt.print("falseWord() ") + txt.print("falseword() ") return 0 } @@ -55,15 +55,15 @@ main { ubyte ub4 = 44 ubyte ub5 = 55 - ub4 = 42 + ub4 = 0 txt.print("and with bytes: ") - ub5 = ub1 and ub2 and ub3 and ub4 and ub5 ; TODO FIX !! should be True (!=0) + ub5 = ub1 and ub2 and ub3 and ub4 and ub5 txt.print_ub(ub5) txt.nl() - ub4 = 42 + ub4 = 0 txt.print("or with bytes: ") - ub5 = ub1 or ub2 or ub3 or ub4 or ub5 ; TODO FIX!! should be False (0) + ub5 = ub1 or ub2 or ub3 or ub4 or ub5 txt.print_ub(ub5) txt.nl() @@ -86,7 +86,7 @@ main { txt.nl() txt.print("and with false: ") - value = func1(25) and func2(25) and funcFalse() and false and func3(25) and func4(25) + value = func1(25) and func2(25) and funcFalse() and func3(25) and func4(25) txt.print_ub(value) txt.nl() txt.print("and with true: ") @@ -94,11 +94,11 @@ main { txt.print_ub(value) txt.nl() txt.print("or with false: ") - value = func1(25) or func2(25) or funcFalse() or true or func3(25) or func4(25) + value = func1(0) or func2(0) or funcFalse() or func3(25) or func4(25) txt.print_ub(value) txt.nl() txt.print("or with true: ") - value = func1(25) or func2(25) or funcTrue() or func3(25) or func4(25) + value = func1(0) or func2(0) or funcTrue() or func3(25) or func4(25) txt.print_ub(value) txt.nl() txt.print("xor with false: ") @@ -110,6 +110,25 @@ main { txt.print_ub(value) txt.nl() + txt.print("\nif and with false: [nothing]: ") + if func1(25) and func2(25) and funcFalse() and func3(25) and func4(25) + txt.print("failure!") + txt.print("\nif and with true: [ok]: ") + if func1(25) and func2(25) and funcTrue() and func3(25) and func4(25) + txt.print("ok!") + txt.print("\nif or with false: [ok]: ") + if func1(0) or func2(0) or funcFalse() or func3(25) or func4(25) + txt.print("ok!") + txt.print("\nif or with true: [ok]: ") + if func1(0) or func2(0) or funcTrue() or func3(25) or func4(25) + txt.print("ok!") + txt.print("\nif xor with false: [nothing]: ") + if func1(25) xor func2(25) xor funcFalse() xor func3(25) xor func4(25) + txt.print("failure!") + txt.print("\nif xor with true: [ok]: ") + if func1(25) xor func2(25) xor funcTrue() xor func3(25) xor func4(25) + txt.print("ok!") + txt.nl() ; a "pixelshader": ; sys.gfx_enable(0) ; enable lo res screen