fix compiler crash when using floats in a comparison expression

This commit is contained in:
Irmen de Jong 2021-12-15 01:24:25 +01:00
parent 890327b381
commit 510bda1b28
5 changed files with 56 additions and 69 deletions

View File

@ -240,9 +240,16 @@ internal class BeforeAsmGenerationAstChanger(val program: Program, private val o
val separateLeftExpr = !expr.left.isSimple && expr.left !is IFunctionCall
val separateRightExpr = !expr.right.isSimple && expr.right !is IFunctionCall
val leftDt = expr.left.inferType(program)
val rightDt = expr.right.inferType(program)
if(!leftDt.isInteger || !rightDt.isInteger) {
// we can't reasonably simplify non-integer expressions
return CondExprSimplificationResult(null, null, null, null)
}
if(separateLeftExpr) {
val name = getTempVarName(expr.left.inferType(program))
val name = getTempVarName(leftDt)
leftOperandReplacement = IdentifierReference(name, expr.position)
leftAssignment = Assignment(
AssignTarget(IdentifierReference(name, expr.position), null, null, expr.position),
@ -251,12 +258,11 @@ internal class BeforeAsmGenerationAstChanger(val program: Program, private val o
)
}
if(separateRightExpr) {
val dt = expr.right.inferType(program)
val name = when {
dt.istype(DataType.UBYTE) -> listOf("prog8_lib","retval_interm_ub")
dt.istype(DataType.UWORD) -> listOf("prog8_lib","retval_interm_uw")
dt.istype(DataType.BYTE) -> listOf("prog8_lib","retval_interm_b2")
dt.istype(DataType.WORD) -> listOf("prog8_lib","retval_interm_w2")
rightDt.istype(DataType.UBYTE) -> listOf("prog8_lib","retval_interm_ub")
rightDt.istype(DataType.UWORD) -> listOf("prog8_lib","retval_interm_uw")
rightDt.istype(DataType.BYTE) -> listOf("prog8_lib","retval_interm_b2")
rightDt.istype(DataType.WORD) -> listOf("prog8_lib","retval_interm_w2")
else -> throw AssemblyError("invalid dt")
}
rightOperandReplacement = IdentifierReference(name, expr.position)

View File

@ -55,7 +55,6 @@ class TestModuleImporter: FunSpec({
error1.file.absolutePath shouldBe "${srcPathAbs.normalize()}"
}
program.modules.size shouldBe 1
val error2 = importer.importModule(srcPathAbs).getErrorOrElse { fail("should have import error") }
withClue(".file should be normalized") {
"${error2.file}" shouldBe "${error2.file.normalize()}"

View File

@ -0,0 +1,33 @@
package prog8tests
import io.kotest.core.spec.style.FunSpec
import io.kotest.matchers.shouldBe
import io.kotest.matchers.string.shouldContain
import prog8.compiler.target.C64Target
import prog8tests.helpers.ErrorReporterForTests
import prog8tests.helpers.assertSuccess
import prog8tests.helpers.compileText
class TestAstChecks: FunSpec({
test("conditional expression w/float works even without tempvar to split it") {
val text = """
%import floats
main {
sub start() {
uword xx
if xx+99.99 == xx+1.234 {
xx++
}
}
}
"""
val errors = ErrorReporterForTests(keepMessagesAfterReporting = true)
compileText(C64Target, true, text, writeAssembly = true, errors=errors).assertSuccess()
errors.errors.size shouldBe 0
errors.warnings.size shouldBe 2
errors.warnings[0] shouldContain "converted to float"
errors.warnings[1] shouldContain "converted to float"
}
})

View File

@ -3,7 +3,7 @@ package prog8tests.helpers
import prog8.ast.base.Position
import prog8.compilerinterface.IErrorReporter
internal class ErrorReporterForTests(private val throwExceptionAtReportIfErrors: Boolean=true): IErrorReporter {
internal class ErrorReporterForTests(private val throwExceptionAtReportIfErrors: Boolean=true, private val keepMessagesAfterReporting: Boolean=false): IErrorReporter {
val errors = mutableListOf<String>()
val warnings = mutableListOf<String>()
@ -23,6 +23,12 @@ internal class ErrorReporterForTests(private val throwExceptionAtReportIfErrors:
errors.forEach { println("UNITTEST COMPILATION REPORT: ERROR: $it") }
if(throwExceptionAtReportIfErrors)
finalizeNumErrors(errors.size, warnings.size)
if(!keepMessagesAfterReporting) {
clear()
}
}
fun clear() {
errors.clear()
warnings.clear()
}

View File

@ -1,65 +1,8 @@
%import diskio
%import cx16diskio
%import textio
%zeropage basicsafe
main {
sub start() {
uword xx
uword yy
c64.SETMSG(%10000000)
xx = diskio.load(8, "hello", 0)
txt.nl()
yy = diskio.load(8, "hello", $8800)
txt.nl()
c64.SETMSG(0)
txt.print_uwhex(xx, true)
txt.nl()
txt.print_uwhex(yy, true)
txt.nl()
c64.SETMSG(%10000000)
xx = diskio.load_raw(8, "hello", $8700)
txt.nl()
c64.SETMSG(0)
txt.print_uwhex(xx, true)
txt.nl()
txt.print("\ncx16:\n")
c64.SETMSG(%10000000)
yy = cx16diskio.load(8, "x16edit", 1, $3000)
txt.nl()
c64.SETMSG(0)
txt.print_uwhex(yy, true)
txt.nl()
c64.SETMSG(%10000000)
xx = cx16diskio.load_raw(8, "x16edit", 1, $3000)
txt.nl()
c64.SETMSG(0)
txt.print_uwhex(xx, true)
txt.nl()
txt.print_uw(cx16diskio.load_size(1, $3000, xx))
txt.nl()
c64.SETMSG(%10000000)
xx = cx16diskio.load(8, "x16edit", 4, $a100)
txt.nl()
c64.SETMSG(0)
txt.print_uwhex(xx, true)
txt.nl()
txt.print_uw(cx16diskio.load_size(4, $a100, xx))
txt.nl()
c64.SETMSG(%10000000)
xx = cx16diskio.load_raw(8, "x16edit", 4, $a100)
txt.nl()
c64.SETMSG(0)
txt.print_uwhex(xx, true)
txt.nl()
txt.print_uw(cx16diskio.load_size(4, $a100, xx))
txt.nl()
uword xx = 10
if xx+99 == 1.234 {
xx++
}
}
}