From ed09dd4e9e856cd42f88c646f9d3c8cd39284560 Mon Sep 17 00:00:00 2001 From: Irmen de Jong Date: Wed, 9 Oct 2024 00:54:17 +0200 Subject: [PATCH] improve automatic type conversions for return values, fixes #155 --- .../compiler/astprocessing/AstChecker.kt | 3 +- .../compiler/astprocessing/TypecastsAdder.kt | 7 +++ compiler/test/TestNumericLiteral.kt | 38 +++++++++++++++ compiler/test/TestTypecasts.kt | 30 ++++++++++++ .../prog8/ast/expressions/AstExpressions.kt | 46 +++++++++++++++++++ docs/source/todo.rst | 4 -- 6 files changed, 122 insertions(+), 6 deletions(-) diff --git a/compiler/src/prog8/compiler/astprocessing/AstChecker.kt b/compiler/src/prog8/compiler/astprocessing/AstChecker.kt index c80b92bbe..3f0d8bfc9 100644 --- a/compiler/src/prog8/compiler/astprocessing/AstChecker.kt +++ b/compiler/src/prog8/compiler/astprocessing/AstChecker.kt @@ -139,8 +139,7 @@ internal class AstChecker(private val program: Program, // you can return a string or array when an uword (pointer) is returned } else if(valueDt istype DataType.UWORD && expectedReturnValues[0]==DataType.STR) { // you can return a uword pointer when the return type is a string - } - else { + } else { errors.err("type $valueDt of return value doesn't match subroutine's return type ${expectedReturnValues[0]}",returnStmt.value!!.position) } } diff --git a/compiler/src/prog8/compiler/astprocessing/TypecastsAdder.kt b/compiler/src/prog8/compiler/astprocessing/TypecastsAdder.kt index b9eccc624..9788e6f6a 100644 --- a/compiler/src/prog8/compiler/astprocessing/TypecastsAdder.kt +++ b/compiler/src/prog8/compiler/astprocessing/TypecastsAdder.kt @@ -308,6 +308,13 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val if(subroutine.returntypes.size==1) { val subReturnType = subroutine.returntypes.first() val returnDt = returnValue.inferType(program) + if(returnDt isnot subReturnType && returnValue is NumericLiteral) { + // see if we might change the returnvalue into the expected type + val castedValue = returnValue.convertTypeKeepValue(subReturnType) + if(castedValue.isValid) { + return listOf(IAstModification.ReplaceNode(returnValue, castedValue.valueOrZero(), returnStmt)) + } + } if (returnDt istype subReturnType or returnDt.isNotAssignableTo(subReturnType)) return noModifications if (returnValue is NumericLiteral) { diff --git a/compiler/test/TestNumericLiteral.kt b/compiler/test/TestNumericLiteral.kt index b714cfe21..c381598c4 100644 --- a/compiler/test/TestNumericLiteral.kt +++ b/compiler/test/TestNumericLiteral.kt @@ -10,6 +10,7 @@ import prog8.ast.expressions.ArrayLiteral import prog8.ast.expressions.InferredTypes import prog8.ast.expressions.NumericLiteral import prog8.ast.expressions.StringLiteral +import prog8.ast.statements.AnonymousScope import prog8.code.core.DataType import prog8.code.core.Encoding import prog8.code.core.Position @@ -183,4 +184,41 @@ class TestNumericLiteral: FunSpec({ NumericLiteral.optimalNumeric(-1234.0, Position.DUMMY).type shouldBe DataType.WORD NumericLiteral.optimalNumeric(-1234.0, Position.DUMMY).number shouldBe -1234.0 } + + test("cast can change value") { + fun num(dt: DataType, num: Double): NumericLiteral { + val n = NumericLiteral(dt, num, Position.DUMMY) + n.linkParents(AnonymousScope(mutableListOf(), Position.DUMMY)) + return n + } + val cast1 = num(DataType.UBYTE, 200.0).cast(DataType.BYTE, false) + cast1.isValid shouldBe true + cast1.valueOrZero().number shouldBe -56.0 + val cast2 = num(DataType.BYTE, -50.0).cast(DataType.UBYTE, false) + cast2.isValid shouldBe true + cast2.valueOrZero().number shouldBe 206.0 + val cast3 = num(DataType.UWORD, 55555.0).cast(DataType.WORD, false) + cast3.isValid shouldBe true + cast3.valueOrZero().number shouldBe -9981.0 + val cast4 = num(DataType.WORD, -3333.0).cast(DataType.UWORD, false) + cast4.isValid shouldBe true + cast4.valueOrZero().number shouldBe 62203.0 + } + + test("convert cannot change value") { + fun num(dt: DataType, num: Double): NumericLiteral { + val n = NumericLiteral(dt, num, Position.DUMMY) + n.linkParents(AnonymousScope(mutableListOf(), Position.DUMMY)) + return n + } + num(DataType.UBYTE, 200.0).convertTypeKeepValue(DataType.BYTE).isValid shouldBe false + num(DataType.BYTE, -50.0).convertTypeKeepValue(DataType.UBYTE).isValid shouldBe false + num(DataType.UWORD, 55555.0).convertTypeKeepValue(DataType.WORD).isValid shouldBe false + num(DataType.WORD, -3333.0).convertTypeKeepValue(DataType.UWORD).isValid shouldBe false + + num(DataType.UBYTE, 42.0).convertTypeKeepValue(DataType.BYTE).isValid shouldBe true + num(DataType.BYTE, 42.0).convertTypeKeepValue(DataType.UBYTE).isValid shouldBe true + num(DataType.UWORD, 12345.0).convertTypeKeepValue(DataType.WORD).isValid shouldBe true + num(DataType.WORD, 12345.0).convertTypeKeepValue(DataType.UWORD).isValid shouldBe true + } }) diff --git a/compiler/test/TestTypecasts.kt b/compiler/test/TestTypecasts.kt index b26127f0c..15f819687 100644 --- a/compiler/test/TestTypecasts.kt +++ b/compiler/test/TestTypecasts.kt @@ -857,4 +857,34 @@ main { errors.errors.size shouldBe 2 errors.errors[1] shouldContain "undefined symbol" } + + test("return unsigned values for signed results ok if value fits") { + val src = """ +main { + sub start() { + void foo() + void bar() + void overflow1() + void overflow2() + } + + sub foo() -> byte { + return 42 + } + sub bar() -> word { + return 12345 + } + sub overflow1() -> byte { + return 200 + } + sub overflow2() -> word { + return 44444 + } +}""" + val errors = ErrorReporterForTests() + compileText(C64Target(), false, src, writeAssembly = false, errors = errors) shouldBe null + errors.errors.size shouldBe 2 + errors.errors[0] shouldContain "17:16: type UBYTE of return value doesn't match subroutine's return type BYTE" + errors.errors[1] shouldContain "20:16: type UWORD of return value doesn't match subroutine's return type WORD" + } }) diff --git a/compilerAst/src/prog8/ast/expressions/AstExpressions.kt b/compilerAst/src/prog8/ast/expressions/AstExpressions.kt index 92a96ba49..411c1abe6 100644 --- a/compilerAst/src/prog8/ast/expressions/AstExpressions.kt +++ b/compilerAst/src/prog8/ast/expressions/AstExpressions.kt @@ -612,6 +612,9 @@ class NumericLiteral(val type: DataType, // only numerical types allowed } private fun internalCast(targettype: DataType, implicit: Boolean): ValueAfterCast { + + // NOTE: this MAY convert a value into another when switching from singed to unsigned!!! + if(type==targettype) return ValueAfterCast(true, null, this) if (implicit) { @@ -736,6 +739,49 @@ class NumericLiteral(val type: DataType, // only numerical types allowed } return ValueAfterCast(false, "no cast available from $type to $targettype", null) } + + fun convertTypeKeepValue(targetDt: DataType): ValueAfterCast { + if(type==targetDt) + return ValueAfterCast(true, null, this) + + when(type) { + DataType.UBYTE -> { + when(targetDt) { + DataType.BYTE -> if(number<=127.0) return cast(targetDt, false) + DataType.UWORD, DataType.WORD, DataType.LONG, DataType.FLOAT -> return cast(targetDt, false) + else -> {} + } + } + DataType.BYTE -> { + when(targetDt) { + DataType.UBYTE, DataType.UWORD -> if(number>=0.0) return cast(targetDt, false) + DataType.WORD, DataType.LONG, DataType.FLOAT -> return cast(targetDt, false) + else -> {} + } + } + DataType.UWORD -> { + when(targetDt) { + DataType.UBYTE -> if(number<=255.0) return cast(targetDt, false) + DataType.BYTE -> if(number<=127.0) return cast(targetDt, false) + DataType.WORD -> if(number<=32767.0) return cast(targetDt, false) + DataType.LONG, DataType.FLOAT -> return cast(targetDt, false) + else -> {} + } + } + DataType.WORD -> { + when(targetDt) { + DataType.UBYTE -> if(number in 0.0..255.0) return cast(targetDt, false) + DataType.BYTE -> if(number in -128.0..127.0) return cast(targetDt, false) + DataType.UWORD -> if(number in 0.0..32767.0) return cast(targetDt, false) + DataType.LONG, DataType.FLOAT -> return cast(targetDt, false) + else -> {} + } + } + DataType.LONG, DataType.FLOAT -> return cast(targetDt, false) + else -> {} + } + return ValueAfterCast(false, "no type conversion possible from $type to $targetDt", null) + } } class CharLiteral private constructor(val value: Char, diff --git a/docs/source/todo.rst b/docs/source/todo.rst index 00249f05a..4e686fab9 100644 --- a/docs/source/todo.rst +++ b/docs/source/todo.rst @@ -1,10 +1,6 @@ TODO ==== -Don't allow assigning str to array! -Don't allow assigning array to str! -Don't allow assigning a word to an array or string! - Put palette fade to white / black in. Regenerate skeleton doc files.