diff --git a/codeOptimizers/src/prog8/optimizer/StatementOptimizer.kt b/codeOptimizers/src/prog8/optimizer/StatementOptimizer.kt index a07e68808..78416482a 100644 --- a/codeOptimizers/src/prog8/optimizer/StatementOptimizer.kt +++ b/codeOptimizers/src/prog8/optimizer/StatementOptimizer.kt @@ -391,4 +391,47 @@ class StatementOptimizer(private val program: Program, else noModifications } + + override fun after(whenStmt: When, parent: Node): Iterable { + + fun replaceWithIf(condition: Expression, trueBlock: AnonymousScope, elseBlock: AnonymousScope?): List { + val ifStmt = IfElse(condition, trueBlock, elseBlock ?: AnonymousScope(mutableListOf(), whenStmt.position), whenStmt.position) + errors.warn("for boolean condition a normal if statement is preferred", whenStmt.position) + return listOf(IAstModification.ReplaceNode(whenStmt, ifStmt, parent)) + } + + if(whenStmt.condition.inferType(program).isBool) { + if(whenStmt.choices.all { it.values?.size==1 }) { + if (whenStmt.choices.all { it.values!!.single().constValue(program)!!.number in arrayOf(0.0, 1.0) }) { + // it's a when statement on booleans that can just be replaced by an if or if..else. + if (whenStmt.choices.size == 1) { + return if(whenStmt.choices[0].values!![0].constValue(program)!!.number==1.0) { + replaceWithIf(whenStmt.condition, whenStmt.choices[0].statements, null) + } else { + val notCondition = BinaryExpression(whenStmt.condition, "==", NumericLiteral(DataType.UBYTE, 0.0, whenStmt.condition.position), whenStmt.condition.position) + replaceWithIf(notCondition, whenStmt.choices[0].statements, null) + } + } else if (whenStmt.choices.size == 2) { + var trueBlock: AnonymousScope? = null + var elseBlock: AnonymousScope? = null + if(whenStmt.choices[0].values!![0].constValue(program)!!.number==1.0) { + trueBlock = whenStmt.choices[0].statements + } else { + elseBlock = whenStmt.choices[0].statements + } + if(whenStmt.choices[1].values!![0].constValue(program)!!.number==1.0) { + trueBlock = whenStmt.choices[1].statements + } else { + elseBlock = whenStmt.choices[1].statements + } + if(trueBlock!=null && elseBlock!=null) { + return replaceWithIf(whenStmt.condition, trueBlock, elseBlock) + } + } + } + } + } + return noModifications + } + } diff --git a/compiler/src/prog8/compiler/astprocessing/AstChecker.kt b/compiler/src/prog8/compiler/astprocessing/AstChecker.kt index f914336bb..b86815d5a 100644 --- a/compiler/src/prog8/compiler/astprocessing/AstChecker.kt +++ b/compiler/src/prog8/compiler/astprocessing/AstChecker.kt @@ -1341,8 +1341,14 @@ internal class AstChecker(private val program: Program, constvalue == null -> errors.err("choice value must be a constant", whenChoice.position) constvalue.type !in IntegerDatatypes -> errors.err("choice value must be a byte or word", whenChoice.position) conditionType isnot constvalue.type -> { - if(conditionType.isKnown) - errors.err("choice value datatype differs from condition value", whenChoice.position) + if(conditionType.isKnown) { + if(conditionType.istype(DataType.BOOL)) { + if(constvalue.number!=0.0 && constvalue.number!=1.0) + errors.err("choice value datatype differs from condition value", whenChoice.position) + } else { + errors.err("choice value datatype differs from condition value", whenChoice.position) + } + } } } } diff --git a/compiler/test/ast/TestVariousCompilerAst.kt b/compiler/test/ast/TestVariousCompilerAst.kt index 9b54649c4..0f7a24309 100644 --- a/compiler/test/ast/TestVariousCompilerAst.kt +++ b/compiler/test/ast/TestVariousCompilerAst.kt @@ -334,5 +334,22 @@ main { }""" compileText(VMTarget(), optimize=false, src, writeAssembly=false) shouldNotBe null } + + test("when on booleans") { + val src = """ +main +{ + sub start() + { + bool choiceVariable=true + when choiceVariable { + false -> cx16.r0++ + true -> cx16.r1++ + } + } +}""" + + compileText(VMTarget(), optimize=false, src, writeAssembly=false) shouldNotBe null + } }) diff --git a/examples/test.p8 b/examples/test.p8 index deef4bd23..001a38bcd 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -1,32 +1,16 @@ %import textio %zeropage basicsafe -; (127 instructions in 15 chunks, 47 registers) -; 679 steps - - -main { - -sub start() { - uword i - uword n - - repeat 10 { - txt.chrout('.') - } - txt.nl() - - n=10 - for i in 0 to n step 3 { - txt.print_uw(i) - txt.nl() - } - txt.nl() - - n=0 - for i in 10 downto n step -3 { - txt.print_uw(i) - txt.nl() +main +{ + ; 00f9 + sub start() + { + bool rasterIrqAfterSubs=false + when rasterIrqAfterSubs { + false -> txt.print("false\n") + true -> txt.print("true\n") + } + txt.print("done") } } -} diff --git a/gradle.properties b/gradle.properties index f8846c382..e9d0482de 100644 --- a/gradle.properties +++ b/gradle.properties @@ -5,4 +5,4 @@ org.gradle.daemon=true kotlin.code.style=official javaVersion=11 kotlinVersion=1.9.0 -version=9.1 +version=9.2-SNAPSHOT