diff --git a/compiler/src/prog8/compiler/BeforeAsmGenerationAstChanger.kt b/compiler/src/prog8/compiler/BeforeAsmGenerationAstChanger.kt index ab3de0a48..e40570c5d 100644 --- a/compiler/src/prog8/compiler/BeforeAsmGenerationAstChanger.kt +++ b/compiler/src/prog8/compiler/BeforeAsmGenerationAstChanger.kt @@ -240,9 +240,16 @@ internal class BeforeAsmGenerationAstChanger(val program: Program, private val o val separateLeftExpr = !expr.left.isSimple && expr.left !is IFunctionCall val separateRightExpr = !expr.right.isSimple && expr.right !is IFunctionCall + val leftDt = expr.left.inferType(program) + val rightDt = expr.right.inferType(program) + + if(!leftDt.isInteger || !rightDt.isInteger) { + // we can't reasonably simplify non-integer expressions + return CondExprSimplificationResult(null, null, null, null) + } if(separateLeftExpr) { - val name = getTempVarName(expr.left.inferType(program)) + val name = getTempVarName(leftDt) leftOperandReplacement = IdentifierReference(name, expr.position) leftAssignment = Assignment( AssignTarget(IdentifierReference(name, expr.position), null, null, expr.position), @@ -251,12 +258,11 @@ internal class BeforeAsmGenerationAstChanger(val program: Program, private val o ) } if(separateRightExpr) { - val dt = expr.right.inferType(program) val name = when { - dt.istype(DataType.UBYTE) -> listOf("prog8_lib","retval_interm_ub") - dt.istype(DataType.UWORD) -> listOf("prog8_lib","retval_interm_uw") - dt.istype(DataType.BYTE) -> listOf("prog8_lib","retval_interm_b2") - dt.istype(DataType.WORD) -> listOf("prog8_lib","retval_interm_w2") + rightDt.istype(DataType.UBYTE) -> listOf("prog8_lib","retval_interm_ub") + rightDt.istype(DataType.UWORD) -> listOf("prog8_lib","retval_interm_uw") + rightDt.istype(DataType.BYTE) -> listOf("prog8_lib","retval_interm_b2") + rightDt.istype(DataType.WORD) -> listOf("prog8_lib","retval_interm_w2") else -> throw AssemblyError("invalid dt") } rightOperandReplacement = IdentifierReference(name, expr.position) diff --git a/compiler/test/ModuleImporterTests.kt b/compiler/test/ModuleImporterTests.kt index d2ec30790..8a1dfcd46 100644 --- a/compiler/test/ModuleImporterTests.kt +++ b/compiler/test/ModuleImporterTests.kt @@ -55,7 +55,6 @@ class TestModuleImporter: FunSpec({ error1.file.absolutePath shouldBe "${srcPathAbs.normalize()}" } program.modules.size shouldBe 1 - val error2 = importer.importModule(srcPathAbs).getErrorOrElse { fail("should have import error") } withClue(".file should be normalized") { "${error2.file}" shouldBe "${error2.file.normalize()}" diff --git a/compiler/test/TestAstChecks.kt b/compiler/test/TestAstChecks.kt new file mode 100644 index 000000000..88e7ba782 --- /dev/null +++ b/compiler/test/TestAstChecks.kt @@ -0,0 +1,33 @@ +package prog8tests + +import io.kotest.core.spec.style.FunSpec +import io.kotest.matchers.shouldBe +import io.kotest.matchers.string.shouldContain +import prog8.compiler.target.C64Target +import prog8tests.helpers.ErrorReporterForTests +import prog8tests.helpers.assertSuccess +import prog8tests.helpers.compileText + + +class TestAstChecks: FunSpec({ + + test("conditional expression w/float works even without tempvar to split it") { + val text = """ + %import floats + main { + sub start() { + uword xx + if xx+99.99 == xx+1.234 { + xx++ + } + } + } + """ + val errors = ErrorReporterForTests(keepMessagesAfterReporting = true) + compileText(C64Target, true, text, writeAssembly = true, errors=errors).assertSuccess() + errors.errors.size shouldBe 0 + errors.warnings.size shouldBe 2 + errors.warnings[0] shouldContain "converted to float" + errors.warnings[1] shouldContain "converted to float" + } +}) diff --git a/compiler/test/helpers/ErrorReporterForTests.kt b/compiler/test/helpers/ErrorReporterForTests.kt index 32304e41e..22fcec895 100644 --- a/compiler/test/helpers/ErrorReporterForTests.kt +++ b/compiler/test/helpers/ErrorReporterForTests.kt @@ -3,7 +3,7 @@ package prog8tests.helpers import prog8.ast.base.Position import prog8.compilerinterface.IErrorReporter -internal class ErrorReporterForTests(private val throwExceptionAtReportIfErrors: Boolean=true): IErrorReporter { +internal class ErrorReporterForTests(private val throwExceptionAtReportIfErrors: Boolean=true, private val keepMessagesAfterReporting: Boolean=false): IErrorReporter { val errors = mutableListOf() val warnings = mutableListOf() @@ -23,6 +23,12 @@ internal class ErrorReporterForTests(private val throwExceptionAtReportIfErrors: errors.forEach { println("UNITTEST COMPILATION REPORT: ERROR: $it") } if(throwExceptionAtReportIfErrors) finalizeNumErrors(errors.size, warnings.size) + if(!keepMessagesAfterReporting) { + clear() + } + } + + fun clear() { errors.clear() warnings.clear() } diff --git a/examples/test.p8 b/examples/test.p8 index b06744937..aaf35c887 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -1,65 +1,8 @@ -%import diskio -%import cx16diskio -%import textio -%zeropage basicsafe - main { sub start() { - uword xx - uword yy - - c64.SETMSG(%10000000) - xx = diskio.load(8, "hello", 0) - txt.nl() - yy = diskio.load(8, "hello", $8800) - txt.nl() - c64.SETMSG(0) - - txt.print_uwhex(xx, true) - txt.nl() - txt.print_uwhex(yy, true) - txt.nl() - - c64.SETMSG(%10000000) - xx = diskio.load_raw(8, "hello", $8700) - txt.nl() - c64.SETMSG(0) - txt.print_uwhex(xx, true) - txt.nl() - - txt.print("\ncx16:\n") - - c64.SETMSG(%10000000) - yy = cx16diskio.load(8, "x16edit", 1, $3000) - txt.nl() - c64.SETMSG(0) - txt.print_uwhex(yy, true) - txt.nl() - - c64.SETMSG(%10000000) - xx = cx16diskio.load_raw(8, "x16edit", 1, $3000) - txt.nl() - c64.SETMSG(0) - txt.print_uwhex(xx, true) - txt.nl() - txt.print_uw(cx16diskio.load_size(1, $3000, xx)) - txt.nl() - - c64.SETMSG(%10000000) - xx = cx16diskio.load(8, "x16edit", 4, $a100) - txt.nl() - c64.SETMSG(0) - txt.print_uwhex(xx, true) - txt.nl() - txt.print_uw(cx16diskio.load_size(4, $a100, xx)) - txt.nl() - c64.SETMSG(%10000000) - xx = cx16diskio.load_raw(8, "x16edit", 4, $a100) - txt.nl() - c64.SETMSG(0) - txt.print_uwhex(xx, true) - txt.nl() - txt.print_uw(cx16diskio.load_size(4, $a100, xx)) - txt.nl() + uword xx = 10 + if xx+99 == 1.234 { + xx++ + } } }