diff --git a/compiler/src/prog8/compiler/astprocessing/NotExpressionAndIfComparisonExprChanger.kt b/compiler/src/prog8/compiler/astprocessing/NotExpressionAndIfComparisonExprChanger.kt index f17225445..3967dfa84 100644 --- a/compiler/src/prog8/compiler/astprocessing/NotExpressionAndIfComparisonExprChanger.kt +++ b/compiler/src/prog8/compiler/astprocessing/NotExpressionAndIfComparisonExprChanger.kt @@ -70,7 +70,7 @@ internal class NotExpressionAndIfComparisonExprChanger(val program: Program, val // not(not(x)) -> x if((expr.expression as? PrefixExpression)?.operator=="not") - return listOf(IAstModification.ReplaceNode(expr, expr.expression, parent)) + return listOf(IAstModification.ReplaceNode(expr, (expr.expression as PrefixExpression).expression, parent)) // not(~x) -> x!=0 if((expr.expression as? PrefixExpression)?.operator=="~") { val x = (expr.expression as PrefixExpression).expression diff --git a/compiler/test/TestOptimization.kt b/compiler/test/TestOptimization.kt index 251dafb70..c31b0bc7a 100644 --- a/compiler/test/TestOptimization.kt +++ b/compiler/test/TestOptimization.kt @@ -298,13 +298,15 @@ class TestOptimization: FunSpec({ (initY2.value as NumericLiteral).number shouldBe 11.0 } - test("various 'not' operator rewrites even without optimizations on") { + test("various 'not' operator rewrites even without optimizations") { val src = """ main { sub start() { - ubyte a1 - ubyte a2 - a1 = not not a1 ; a1 = a1==0 + bool @shared a1 + bool @shared a2 + a1 = not a1 ; a1 = a1==0 + a1 = not not a1 ; a1 = a1, so removed totally + a1 = not not not a1 ; a1 = a1==0 a1 = not a1 or not a2 ; a1 = a1==0 or a2==0 a1 = not a1 and not a2 ; a1 = a1==0 and a2==0 } @@ -312,15 +314,48 @@ class TestOptimization: FunSpec({ """ val result = compileText(C64Target(), false, src, writeAssembly = true)!! val stmts = result.compilerAst.entrypoint.statements - stmts.size shouldBe 8 + stmts.size shouldBe 9 val value1 = (stmts[4] as Assignment).value as BinaryExpression val value2 = (stmts[5] as Assignment).value as BinaryExpression val value3 = (stmts[6] as Assignment).value as BinaryExpression + val value4 = (stmts[7] as Assignment).value as BinaryExpression value1.operator shouldBe "==" value1.right shouldBe NumericLiteral(DataType.UBYTE, 0.0, Position.DUMMY) - value2.operator shouldBe "or" - value3.operator shouldBe "and" + value2.operator shouldBe "==" + value2.right shouldBe NumericLiteral(DataType.UBYTE, 0.0, Position.DUMMY) + value3.operator shouldBe "or" + value4.operator shouldBe "and" + } + + test("various 'not' operator rewrites with optimizations") { + val src = """ + main { + sub start() { + bool @shared a1 + bool @shared a2 + a1 = not a1 ; a1 = a1==0 + a1 = not not a1 ; a1 = a1, so removed totally + a1 = not not not a1 ; a1 = a1==0 + a1 = not a1 or not a2 ; a1 = a1==0 or a2==0 + a1 = not a1 and not a2 ; a1 = a1==0 and a2==0 + } + } + """ + val result = compileText(C64Target(), true, src, writeAssembly = true)!! + val stmts = result.compilerAst.entrypoint.statements + stmts.size shouldBe 9 + + val value1 = (stmts[4] as Assignment).value as BinaryExpression + val value2 = (stmts[5] as Assignment).value as BinaryExpression + val value3 = (stmts[6] as Assignment).value as BinaryExpression + val value4 = (stmts[7] as Assignment).value as BinaryExpression + value1.operator shouldBe "==" + value1.right shouldBe NumericLiteral(DataType.UBYTE, 0.0, Position.DUMMY) + value2.operator shouldBe "==" + value2.right shouldBe NumericLiteral(DataType.UBYTE, 0.0, Position.DUMMY) + value3.operator shouldBe "or" + value4.operator shouldBe "and" } test("asmgen correctly deals with float typecasting in augmented assignment") { diff --git a/compiler/test/TestTypecasts.kt b/compiler/test/TestTypecasts.kt index 3d2bdcdbc..8eb4970c0 100644 --- a/compiler/test/TestTypecasts.kt +++ b/compiler/test/TestTypecasts.kt @@ -69,7 +69,12 @@ class TestTypecasts: FunSpec({ main { sub ftrue(ubyte arg) -> ubyte { arg++ - return 64 + return 42 + } + + sub btrue(ubyte arg) -> bool { + arg++ + return true } sub start() { @@ -81,30 +86,38 @@ main { bvalue = ub1 xor ub2 xor ub3 xor true bvalue = ub1 xor ub2 xor ub3 xor ftrue(99) bvalue = ub1 and ub2 and ftrue(99) + bvalue = ub1 xor ub2 xor ub3 xor btrue(99) + bvalue = ub1 and ub2 and btrue(99) } }""" val result = compileText(C64Target(), true, text, writeAssembly = true)!! val stmts = result.compilerAst.entrypoint.statements /* - ubyte ub1 + ubyte @shared ub1 ub1 = 1 - ubyte ub2 + ubyte @shared ub2 ub2 = 1 - ubyte ub3 + ubyte @shared ub3 ub3 = 1 ubyte @shared bvalue - bvalue = ub1 xor ub2 xor ub3 xor true - bvalue = (((ub1 xor ub2)xor ub3) xor (ftrue(99)!=0)) + bvalue = (((ub1 xor ub2) xor ub3) xor 1) + bvalue = (((ub1 xor ub2) xor ub3) xor (ftrue(99)!=0)) bvalue = ((ub1 and ub2) and (ftrue(99)!=0)) + bvalue = (((ub1 xor ub2) xor ub3) xor btrue(99)) + bvalue = ((ub1 and ub2) and btrue(99)) return */ - stmts.size shouldBe 11 + stmts.size shouldBe 13 val assignValue1 = (stmts[7] as Assignment).value as BinaryExpression val assignValue2 = (stmts[8] as Assignment).value as BinaryExpression val assignValue3 = (stmts[9] as Assignment).value as BinaryExpression + val assignValue4 = (stmts[10] as Assignment).value as BinaryExpression + val assignValue5 = (stmts[11] as Assignment).value as BinaryExpression assignValue1.operator shouldBe "xor" assignValue2.operator shouldBe "xor" assignValue3.operator shouldBe "and" + assignValue4.operator shouldBe "xor" + assignValue5.operator shouldBe "and" val right2 = assignValue2.right as BinaryExpression val right3 = assignValue3.right as BinaryExpression right2.operator shouldBe "!=" @@ -113,6 +126,8 @@ main { right2.right shouldBe NumericLiteral(DataType.UBYTE, 0.0, Position.DUMMY) right3.left shouldBe instanceOf() right3.right shouldBe NumericLiteral(DataType.UBYTE, 0.0, Position.DUMMY) + assignValue4.right shouldBe instanceOf() + assignValue5.right shouldBe instanceOf() } test("simple logical with byte instead of bool ok with typecasting") { diff --git a/compiler/test/ast/TestVariousCompilerAst.kt b/compiler/test/ast/TestVariousCompilerAst.kt index 9af80d3d4..487d8a370 100644 --- a/compiler/test/ast/TestVariousCompilerAst.kt +++ b/compiler/test/ast/TestVariousCompilerAst.kt @@ -444,11 +444,11 @@ main { main { sub start() { str test = "test" - ubyte insync + bool @shared insync if not insync - insync++ + insync=true if insync not in test - insync++ + insync=true } }""" compileText(VMTarget(), optimize=false, src, writeAssembly=false) shouldNotBe null diff --git a/compiler/test/vm/TestCompilerVirtual.kt b/compiler/test/vm/TestCompilerVirtual.kt index 876ece978..4d8e68d3e 100644 --- a/compiler/test/vm/TestCompilerVirtual.kt +++ b/compiler/test/vm/TestCompilerVirtual.kt @@ -346,6 +346,7 @@ main { sub start() { ubyte[3] values = [1,2,3] func(33 + (22 in values)) ; bool cast to byte + cx16.r0L = 33 + (22 in values) ; bool cast to byte func(values[cx16.r0L] + (22 in values)) ; containment in complex expression } sub func(ubyte arg) { diff --git a/examples/test.p8 b/examples/test.p8 index 1bedafc00..8cb734906 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -1,21 +1,21 @@ +%import textio %zeropage basicsafe %option no_sysinit main { sub start() { - sub1() - } + bool @shared a1 = true + bool @shared a2 = false - sub sub1() { - cx16.r0++ - sub2() - } - sub sub2() { - cx16.r0++ - sub3() - } - sub sub3() { - cx16.r0++ - sys.exit(42) + txt.print_ub(not a1) ; a1 = a1==0 "0" + txt.nl() + txt.print_ub(not not a1) ; a1 = a1 "1" + txt.nl() + txt.print_ub(not not not a1) ; a1 = a1==0 "0" + txt.nl() + txt.print_ub(not a1 or not a2) ; a1 = a1==0 or a2==0 "1" + txt.nl() + txt.print_ub(not a1 and not a2) ; a1 = a1==0 and a2==0 "0" + txt.nl() } }