From a8be94de6b421669b0072789ce9f776b7a95ca02 Mon Sep 17 00:00:00 2001 From: Irmen de Jong Date: Fri, 15 Dec 2023 22:05:57 +0100 Subject: [PATCH] better error message when attempting to cast a float to integer --- codeCore/src/prog8/code/ast/AstExpressions.kt | 8 +- .../optimizer/ConstantIdentifierReplacer.kt | 2 +- .../compiler/astprocessing/AstChecker.kt | 10 +- compiler/test/TestNumericLiteral.kt | 8 +- compiler/test/TestOptimization.kt | 2 +- compiler/test/TestPtNumber.kt | 8 +- compiler/test/TestTypecasts.kt | 20 ++-- .../prog8/ast/expressions/AstExpressions.kt | 100 +++++++++--------- docs/source/syntaxreference.rst | 2 +- docs/source/todo.rst | 2 + examples/test.p8 | 52 ++------- 11 files changed, 90 insertions(+), 124 deletions(-) diff --git a/codeCore/src/prog8/code/ast/AstExpressions.kt b/codeCore/src/prog8/code/ast/AstExpressions.kt index bb82072e7..5a9139874 100644 --- a/codeCore/src/prog8/code/ast/AstExpressions.kt +++ b/codeCore/src/prog8/code/ast/AstExpressions.kt @@ -3,7 +3,7 @@ package prog8.code.ast import prog8.code.core.* import java.util.* import kotlin.math.abs -import kotlin.math.round +import kotlin.math.truncate sealed class PtExpression(val type: DataType, position: Position) : PtNode(position) { @@ -227,9 +227,9 @@ class PtNumber(type: DataType, val number: Double, position: Position) : PtExpre if(type==DataType.BOOL) throw IllegalArgumentException("bool should have become ubyte @$position") if(type!=DataType.FLOAT) { - val rounded = round(number) - if (rounded != number) - throw IllegalArgumentException("refused rounding of float to avoid loss of precision @$position") + val trunc = truncate(number) + if (trunc != number) + throw IllegalArgumentException("refused truncating of float to avoid loss of precision @$position") } } diff --git a/codeOptimizers/src/prog8/optimizer/ConstantIdentifierReplacer.kt b/codeOptimizers/src/prog8/optimizer/ConstantIdentifierReplacer.kt index 51709be3e..e77860032 100644 --- a/codeOptimizers/src/prog8/optimizer/ConstantIdentifierReplacer.kt +++ b/codeOptimizers/src/prog8/optimizer/ConstantIdentifierReplacer.kt @@ -25,7 +25,7 @@ class VarConstantValueTypeAdjuster(private val program: Program, private val err && declConstValue.type != decl.datatype) { // avoid silent float roundings if(decl.datatype in IntegerDatatypes && declConstValue.type == DataType.FLOAT) { - errors.err("refused rounding of float to avoid loss of precision", decl.value!!.position) + errors.err("refused truncating of float to avoid loss of precision", decl.value!!.position) } else { // cast the numeric literal to the appropriate datatype of the variable declConstValue.linkParents(decl) diff --git a/compiler/src/prog8/compiler/astprocessing/AstChecker.kt b/compiler/src/prog8/compiler/astprocessing/AstChecker.kt index f240bd38d..c66170669 100644 --- a/compiler/src/prog8/compiler/astprocessing/AstChecker.kt +++ b/compiler/src/prog8/compiler/astprocessing/AstChecker.kt @@ -652,7 +652,7 @@ internal class AstChecker(private val program: Program, } else -> { if(decl.type==VarDeclType.CONST) { - err("const declaration needs a compile-time constant initializer value, or range") + err("const declaration needs a compile-time constant initializer value") super.visit(decl) return } @@ -1069,8 +1069,12 @@ internal class AstChecker(private val program: Program, if(!typecast.expression.inferType(program).isKnown) errors.err("this expression doesn't return a value", typecast.expression.position) - if(typecast.expression is NumericLiteral) - errors.err("can't cast the value to the requested target type", typecast.expression.position) + if(typecast.expression is NumericLiteral) { + val castResult = (typecast.expression as NumericLiteral).cast(typecast.type) + if(castResult.isValid) + throw FatalAstException("cast should have been performed in const eval already") + errors.err(castResult.whyFailed!!, typecast.expression.position) + } super.visit(typecast) } diff --git a/compiler/test/TestNumericLiteral.kt b/compiler/test/TestNumericLiteral.kt index 0838ae244..ca556480b 100644 --- a/compiler/test/TestNumericLiteral.kt +++ b/compiler/test/TestNumericLiteral.kt @@ -35,16 +35,16 @@ class TestNumericLiteral: FunSpec({ sameValueAndType(NumericLiteral(DataType.UWORD, 12345.0, dummyPos), NumericLiteral(DataType.UWORD, 12345.0, dummyPos)) shouldBe true } - test("test rounding") { + test("test truncating") { shouldThrow { NumericLiteral(DataType.BYTE, -2.345, dummyPos) - }.message shouldContain "refused rounding" + }.message shouldContain "refused truncating" shouldThrow { NumericLiteral(DataType.BYTE, -2.6, dummyPos) - }.message shouldContain "refused rounding" + }.message shouldContain "refused truncating" shouldThrow { NumericLiteral(DataType.UWORD, 2222.345, dummyPos) - }.message shouldContain "refused rounding" + }.message shouldContain "refused truncating" NumericLiteral(DataType.UBYTE, 2.0, dummyPos).number shouldBe 2.0 NumericLiteral(DataType.BYTE, -2.0, dummyPos).number shouldBe -2.0 NumericLiteral(DataType.UWORD, 2222.0, dummyPos).number shouldBe 2222.0 diff --git a/compiler/test/TestOptimization.kt b/compiler/test/TestOptimization.kt index 5ea36d162..0491a047c 100644 --- a/compiler/test/TestOptimization.kt +++ b/compiler/test/TestOptimization.kt @@ -541,7 +541,7 @@ class TestOptimization: FunSpec({ val errors = ErrorReporterForTests() compileText(C64Target(), optimize=false, src, writeAssembly=false, errors = errors) shouldBe null errors.errors.size shouldBe 1 - errors.errors[0] shouldContain "can't cast" + errors.errors[0] shouldContain "no cast" } test("test augmented expression asmgen") { diff --git a/compiler/test/TestPtNumber.kt b/compiler/test/TestPtNumber.kt index aa43910df..fd95cad6b 100644 --- a/compiler/test/TestPtNumber.kt +++ b/compiler/test/TestPtNumber.kt @@ -33,16 +33,16 @@ class TestPtNumber: FunSpec({ sameValueAndType(PtNumber(DataType.UWORD, 12345.0, dummyPos), PtNumber(DataType.UWORD, 12345.0, dummyPos)) shouldBe true } - test("test rounding") { + test("test truncating") { shouldThrow { PtNumber(DataType.BYTE, -2.345, dummyPos) - }.message shouldContain "refused rounding" + }.message shouldContain "refused truncating" shouldThrow { PtNumber(DataType.BYTE, -2.6, dummyPos) - }.message shouldContain "refused rounding" + }.message shouldContain "refused truncating" shouldThrow { PtNumber(DataType.UWORD, 2222.345, dummyPos) - }.message shouldContain "refused rounding" + }.message shouldContain "refused truncating" PtNumber(DataType.UBYTE, 2.0, dummyPos).number shouldBe 2.0 PtNumber(DataType.BYTE, -2.0, dummyPos).number shouldBe -2.0 PtNumber(DataType.UWORD, 2222.0, dummyPos).number shouldBe 2222.0 diff --git a/compiler/test/TestTypecasts.kt b/compiler/test/TestTypecasts.kt index cce7ad43d..b48a8af67 100644 --- a/compiler/test/TestTypecasts.kt +++ b/compiler/test/TestTypecasts.kt @@ -762,11 +762,11 @@ main { val errors = ErrorReporterForTests() compileText(C64Target(), false, text, writeAssembly = true, errors=errors) shouldBe null errors.errors.size shouldBe 2 - errors.errors[0] shouldContain "can't cast" - errors.errors[1] shouldContain "can't cast" + errors.errors[0] shouldContain "no cast" + errors.errors[1] shouldContain "no cast" } - test("refuse to round float literal 1") { + test("refuse to truncate float literal 1") { val text = """ %option enable_floats main { @@ -778,11 +778,11 @@ main { val errors = ErrorReporterForTests() compileText(C64Target(), false, text, errors=errors) shouldBe null errors.errors.size shouldBe 2 - errors.errors[0] shouldContain "can't cast" - errors.errors[1] shouldContain "can't cast" + errors.errors[0] shouldContain "refused" + errors.errors[1] shouldContain "refused" } - test("refuse to round float literal 2") { + test("refuse to truncate float literal 2") { val text = """ %option enable_floats main { @@ -798,7 +798,7 @@ main { errors.errors[0] shouldContain "in-place makes no sense" } - test("refuse to round float literal 3") { + test("refuse to truncate float literal 3") { val text = """ %option enable_floats main { @@ -811,8 +811,8 @@ main { val errors = ErrorReporterForTests() compileText(C64Target(), false, text, errors=errors) shouldBe null errors.errors.size shouldBe 2 - errors.errors[0] shouldContain "can't cast" - errors.errors[1] shouldContain "can't cast" + errors.errors[0] shouldContain "refused" + errors.errors[1] shouldContain "refused" } test("correct implicit casts of signed number comparison and logical expressions") { @@ -1063,7 +1063,7 @@ main { val errors=ErrorReporterForTests() compileText(C64Target(), false, src, writeAssembly = false, errors=errors) shouldBe null errors.errors.size shouldBe 5 - errors.errors[0] shouldContain "can't cast" + errors.errors[0] shouldContain "no cast" errors.errors[1] shouldContain "overflow" errors.errors[2] shouldContain "LONG doesn't match" errors.errors[3] shouldContain "out of range" diff --git a/compilerAst/src/prog8/ast/expressions/AstExpressions.kt b/compilerAst/src/prog8/ast/expressions/AstExpressions.kt index 5504a47f8..622995531 100644 --- a/compilerAst/src/prog8/ast/expressions/AstExpressions.kt +++ b/compilerAst/src/prog8/ast/expressions/AstExpressions.kt @@ -13,7 +13,7 @@ import prog8.code.core.* import java.util.* import kotlin.math.abs import kotlin.math.floor -import kotlin.math.round +import kotlin.math.truncate sealed class Expression: Node { @@ -447,10 +447,10 @@ class NumericLiteral(val type: DataType, // only numerical types allowed if(type==DataType.FLOAT) numbervalue else { - val rounded = round(numbervalue) - if(rounded != numbervalue) - throw ExpressionError("refused rounding of float to avoid loss of precision", position) - rounded + val trunc = truncate(numbervalue) + if(trunc != numbervalue) + throw ExpressionError("refused truncating of float to avoid loss of precision", position) + trunc } } @@ -540,7 +540,7 @@ class NumericLiteral(val type: DataType, // only numerical types allowed operator fun compareTo(other: NumericLiteral): Int = number.compareTo(other.number) - class CastValue(val isValid: Boolean, private val value: NumericLiteral?) { + class CastValue(val isValid: Boolean, val whyFailed: String?, private val value: NumericLiteral?) { fun valueOrZero() = if(isValid) value!! else NumericLiteral(DataType.UBYTE, 0.0, Position.DUMMY) fun linkParent(parent: Node) { value?.linkParents(parent) @@ -555,122 +555,122 @@ class NumericLiteral(val type: DataType, // only numerical types allowed private fun internalCast(targettype: DataType): CastValue { if(type==targettype) - return CastValue(true, this) + return CastValue(true, null, this) when(type) { DataType.UBYTE -> { if(targettype== DataType.BYTE && number <= 127) - return CastValue(true, NumericLiteral(targettype, number, position)) + return CastValue(true, null, NumericLiteral(targettype, number, position)) if(targettype== DataType.WORD || targettype== DataType.UWORD) - return CastValue(true, NumericLiteral(targettype, number, position)) + return CastValue(true, null, NumericLiteral(targettype, number, position)) if(targettype== DataType.FLOAT) - return CastValue(true, NumericLiteral(targettype, number, position)) + return CastValue(true, null, NumericLiteral(targettype, number, position)) if(targettype==DataType.LONG) - return CastValue(true, NumericLiteral(targettype, number, position)) + return CastValue(true, null, NumericLiteral(targettype, number, position)) if(targettype==DataType.BOOL) - return CastValue(true, fromBoolean(number!=0.0, position)) + return CastValue(true, null, fromBoolean(number!=0.0, position)) } DataType.BYTE -> { if(targettype== DataType.UBYTE) { if(number in -128.0..0.0) - return CastValue(true, NumericLiteral(targettype, number.toInt().toUByte().toDouble(), position)) + return CastValue(true, null, NumericLiteral(targettype, number.toInt().toUByte().toDouble(), position)) else if(number in 0.0..255.0) - return CastValue(true, NumericLiteral(targettype, number, position)) + return CastValue(true, null, NumericLiteral(targettype, number, position)) } if(targettype== DataType.UWORD) { if(number in -32768.0..0.0) - return CastValue(true, NumericLiteral(targettype, number.toInt().toUShort().toDouble(), position)) + return CastValue(true, null, NumericLiteral(targettype, number.toInt().toUShort().toDouble(), position)) else if(number in 0.0..65535.0) - return CastValue(true, NumericLiteral(targettype, number, position)) + return CastValue(true, null, NumericLiteral(targettype, number, position)) } if(targettype== DataType.WORD) - return CastValue(true, NumericLiteral(targettype, number, position)) + return CastValue(true, null, NumericLiteral(targettype, number, position)) if(targettype== DataType.FLOAT) - return CastValue(true, NumericLiteral(targettype, number, position)) + return CastValue(true, null, NumericLiteral(targettype, number, position)) if(targettype==DataType.BOOL) - return CastValue(true, fromBoolean(number!=0.0, position)) + return CastValue(true, null, fromBoolean(number!=0.0, position)) if(targettype==DataType.LONG) - return CastValue(true, NumericLiteral(targettype, number, position)) + return CastValue(true, null, NumericLiteral(targettype, number, position)) } DataType.UWORD -> { if(targettype== DataType.BYTE && number <= 127) - return CastValue(true, NumericLiteral(targettype, number, position)) + return CastValue(true, null, NumericLiteral(targettype, number, position)) if(targettype== DataType.UBYTE && number <= 255) - return CastValue(true, NumericLiteral(targettype, number, position)) + return CastValue(true, null, NumericLiteral(targettype, number, position)) if(targettype== DataType.WORD && number <= 32767) - return CastValue(true, NumericLiteral(targettype, number, position)) + return CastValue(true, null, NumericLiteral(targettype, number, position)) if(targettype== DataType.FLOAT) - return CastValue(true, NumericLiteral(targettype, number, position)) + return CastValue(true, null, NumericLiteral(targettype, number, position)) if(targettype==DataType.BOOL) - return CastValue(true, fromBoolean(number!=0.0, position)) + return CastValue(true, null, fromBoolean(number!=0.0, position)) if(targettype==DataType.LONG) - return CastValue(true, NumericLiteral(targettype, number, position)) + return CastValue(true, null, NumericLiteral(targettype, number, position)) } DataType.WORD -> { if(targettype== DataType.BYTE && number >= -128 && number <=127) - return CastValue(true, NumericLiteral(targettype, number, position)) + return CastValue(true, null, NumericLiteral(targettype, number, position)) if(targettype== DataType.UBYTE) { if(number in -128.0..0.0) - return CastValue(true, NumericLiteral(targettype, number.toInt().toUByte().toDouble(), position)) + return CastValue(true, null, NumericLiteral(targettype, number.toInt().toUByte().toDouble(), position)) else if(number in 0.0..255.0) - return CastValue(true, NumericLiteral(targettype, number, position)) + return CastValue(true, null, NumericLiteral(targettype, number, position)) } if(targettype== DataType.UWORD) { if(number in -32768.0 .. 0.0) - return CastValue(true, NumericLiteral(targettype, number.toInt().toUShort().toDouble(), position)) + return CastValue(true, null, NumericLiteral(targettype, number.toInt().toUShort().toDouble(), position)) else if(number in 0.0..65535.0) - return CastValue(true, NumericLiteral(targettype, number, position)) + return CastValue(true, null, NumericLiteral(targettype, number, position)) } if(targettype== DataType.FLOAT) - return CastValue(true, NumericLiteral(targettype, number, position)) + return CastValue(true, null, NumericLiteral(targettype, number, position)) if(targettype==DataType.BOOL) - return CastValue(true, fromBoolean(number!=0.0, position)) + return CastValue(true, null, fromBoolean(number!=0.0, position)) if(targettype==DataType.LONG) - return CastValue(true, NumericLiteral(targettype, number, position)) + return CastValue(true, null, NumericLiteral(targettype, number, position)) } DataType.FLOAT -> { try { if (targettype == DataType.BYTE && number >= -128 && number <= 127) - return CastValue(true, NumericLiteral(targettype, number, position)) + return CastValue(true, null, NumericLiteral(targettype, number, position)) if (targettype == DataType.UBYTE && number >= 0 && number <= 255) - return CastValue(true, NumericLiteral(targettype, number, position)) + return CastValue(true, null, NumericLiteral(targettype, number, position)) if (targettype == DataType.WORD && number >= -32768 && number <= 32767) - return CastValue(true, NumericLiteral(targettype, number, position)) + return CastValue(true, null, NumericLiteral(targettype, number, position)) if (targettype == DataType.UWORD && number >= 0 && number <= 65535) - return CastValue(true, NumericLiteral(targettype, number, position)) + return CastValue(true, null, NumericLiteral(targettype, number, position)) if(targettype==DataType.LONG && number >=0 && number <= 2147483647) - return CastValue(true, NumericLiteral(targettype, number, position)) + return CastValue(true, null, NumericLiteral(targettype, number, position)) if(targettype==DataType.BOOL) - return CastValue(true, fromBoolean(number!=0.0, position)) + return CastValue(true, null, fromBoolean(number!=0.0, position)) } catch (x: ExpressionError) { - return CastValue(false, null) + return CastValue(false, x.message,null) } } DataType.BOOL -> { - return CastValue(true, NumericLiteral(targettype, number, position)) + return CastValue(true, null, NumericLiteral(targettype, number, position)) } DataType.LONG -> { try { if (targettype == DataType.BYTE && number >= -128 && number <= 127) - return CastValue(true, NumericLiteral(targettype, number, position)) + return CastValue(true, null, NumericLiteral(targettype, number, position)) if (targettype == DataType.UBYTE && number >= 0 && number <= 255) - return CastValue(true, NumericLiteral(targettype, number, position)) + return CastValue(true, null, NumericLiteral(targettype, number, position)) if (targettype == DataType.WORD && number >= -32768 && number <= 32767) - return CastValue(true, NumericLiteral(targettype, number, position)) + return CastValue(true, null, NumericLiteral(targettype, number, position)) if (targettype == DataType.UWORD && number >= 0 && number <= 65535) - return CastValue(true, NumericLiteral(targettype, number, position)) + return CastValue(true, null, NumericLiteral(targettype, number, position)) if(targettype==DataType.BOOL) - return CastValue(true, fromBoolean(number!=0.0, position)) + return CastValue(true, null, fromBoolean(number!=0.0, position)) if(targettype== DataType.FLOAT) - return CastValue(true, NumericLiteral(targettype, number, position)) + return CastValue(true, null, NumericLiteral(targettype, number, position)) } catch (x: ExpressionError) { - return CastValue(false, null) + return CastValue(false, x.message, null) } } else -> { throw FatalAstException("type cast of weird type $type") } } - return CastValue(false, null) + return CastValue(false, "no cast available between these types", null) } } diff --git a/docs/source/syntaxreference.rst b/docs/source/syntaxreference.rst index e674b1787..6cc5ce62d 100644 --- a/docs/source/syntaxreference.rst +++ b/docs/source/syntaxreference.rst @@ -419,7 +419,7 @@ For instance ``%1001_0001`` is a valid binary number and ``3_000_000.99`` is a v Data type conversion ^^^^^^^^^^^^^^^^^^^^ Many type conversions are possible by just writing ``as `` at the end of an expression, -for example ``ubyte ub = floatvalue as ubyte`` will convert the floating point value to an unsigned byte. +for example ``word ww = bytevalue as word`` will convert the byte value to a signed word. Memory mapped variables diff --git a/docs/source/todo.rst b/docs/source/todo.rst index 9fbd81d4c..77f6eeffb 100644 --- a/docs/source/todo.rst +++ b/docs/source/todo.rst @@ -4,6 +4,8 @@ TODO - merge branch optimize-st for some optimizations regarding SymbolTable use +- fix that pesky unit test that puts temp files in the compiler directory + - [on branch: call-pointers] allow calling a subroutine via a pointer variable (indirect JSR, optimized form of callfar()) modify programs (shell, paint) that now use callfar diff --git a/examples/test.p8 b/examples/test.p8 index 40ad3985c..f06fb8a49 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -1,53 +1,13 @@ %import textio -%import string +%import floats %zeropage basicsafe main { sub start() { - bool @shared statusc = test_carry_set() - bool @shared statusv = test_v_set() - bool @shared statusz = test_z_set() - bool @shared statusn = test_n_set() + const uword vera_freq = (136.5811 / 0.3725290298461914) as uword + const uword v_f_10 = (136.5811 / 0.3725290298461914 + 0.5) as uword - if test_carry_set() { - txt.print("set!\n") - } - if test_v_set() { - txt.print("set!\n") - } - if test_z_set() { - txt.print("set!\n") - } - if test_n_set() { - txt.print("set!\n") - } + txt.print_uw(v_f_10) + txt.print_uw(vera_freq) } - - asmsub test_carry_set() -> bool @Pc { - %asm {{ - sec - rts - }} - } - - asmsub test_v_set() -> bool @Pv { - %asm {{ - sec - rts - }} - } - - asmsub test_z_set() -> bool @Pz { - %asm {{ - sec - rts - }} - } - - asmsub test_n_set() -> bool @Pn { - %asm {{ - sec - rts - }} - } -} + }