From 6d17e5307c52f6c5cb332120c281baa1ab6cfed9 Mon Sep 17 00:00:00 2001 From: Irmen de Jong Date: Thu, 27 Aug 2020 19:06:27 +0200 Subject: [PATCH] fixed typecasting of const arguments once again --- compiler/src/prog8/ast/base/Base.kt | 3 +- .../prog8/ast/expressions/AstExpressions.kt | 68 ++++++++++--------- .../src/prog8/ast/processing/AstChecker.kt | 6 +- .../prog8/ast/processing/TypecastsAdder.kt | 34 +++++----- .../prog8/ast/processing/VariousCleanups.kt | 5 +- .../optimizer/ConstantFoldingOptimizer.kt | 38 ++++++----- .../prog8/optimizer/ExpressionSimplifier.kt | 6 +- examples/arithmetic/div.p8 | 10 +-- examples/arithmetic/minus.p8 | 20 +++--- examples/arithmetic/mult.p8 | 12 ++-- examples/arithmetic/plus.p8 | 16 +++-- examples/arithmetic/remainder.p8 | 1 - 12 files changed, 122 insertions(+), 97 deletions(-) diff --git a/compiler/src/prog8/ast/base/Base.kt b/compiler/src/prog8/ast/base/Base.kt index acb5a4b9a..46c4e60e7 100644 --- a/compiler/src/prog8/ast/base/Base.kt +++ b/compiler/src/prog8/ast/base/Base.kt @@ -21,10 +21,9 @@ enum class DataType { STRUCT; // pass by reference /** - * is the type assignable to the given other type? + * is the type assignable to the given other type (perhaps via a typecast) without loss of precision? */ infix fun isAssignableTo(targetType: DataType) = - // what types are assignable to others, perhaps via a typecast, without loss of precision? when(this) { UBYTE -> targetType in setOf(UBYTE, WORD, UWORD, FLOAT) BYTE -> targetType in setOf(BYTE, WORD, FLOAT) diff --git a/compiler/src/prog8/ast/expressions/AstExpressions.kt b/compiler/src/prog8/ast/expressions/AstExpressions.kt index cdcac3f25..4803b9321 100644 --- a/compiler/src/prog8/ast/expressions/AstExpressions.kt +++ b/compiler/src/prog8/ast/expressions/AstExpressions.kt @@ -295,9 +295,11 @@ class TypecastExpression(var expression: Expression, var type: DataType, val imp override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(type) override fun constValue(program: Program): NumericLiteralValue? { val cv = expression.constValue(program) ?: return null - return cv.castNoCheck(type) - // val value = RuntimeValue(cv.type, cv.asNumericValue!!).cast(type) - // return LiteralValue.fromNumber(value.numericValue(), value.type, position).cast(type) + val cast = cv.cast(type) + return if(cast.isValid) + cast.valueOrZero() + else + null } override fun toString(): String { @@ -416,62 +418,66 @@ class NumericLiteralValue(val type: DataType, // only numerical types allowed operator fun compareTo(other: NumericLiteralValue): Int = number.toDouble().compareTo(other.number.toDouble()) - fun castNoCheck(targettype: DataType): NumericLiteralValue { + class CastValue(val isValid: Boolean, private val value: NumericLiteralValue?) { + fun valueOrZero() = if(isValid) value!! else NumericLiteralValue(DataType.UBYTE, 0, Position.DUMMY) + } + + fun cast(targettype: DataType): CastValue { if(type==targettype) - return this + return CastValue(true, this) val numval = number.toDouble() when(type) { DataType.UBYTE -> { if(targettype== DataType.BYTE && numval <= 127) - return NumericLiteralValue(targettype, number.toShort(), position) + return CastValue(true, NumericLiteralValue(targettype, number.toShort(), position)) if(targettype== DataType.WORD || targettype== DataType.UWORD) - return NumericLiteralValue(targettype, number.toInt(), position) + return CastValue(true, NumericLiteralValue(targettype, number.toInt(), position)) if(targettype== DataType.FLOAT) - return NumericLiteralValue(targettype, number.toDouble(), position) + return CastValue(true, NumericLiteralValue(targettype, number.toDouble(), position)) } DataType.BYTE -> { if(targettype== DataType.UBYTE && numval >= 0) - return NumericLiteralValue(targettype, number.toShort(), position) + return CastValue(true, NumericLiteralValue(targettype, number.toShort(), position)) if(targettype== DataType.UWORD && numval >= 0) - return NumericLiteralValue(targettype, number.toInt(), position) + return CastValue(true, NumericLiteralValue(targettype, number.toInt(), position)) if(targettype== DataType.WORD) - return NumericLiteralValue(targettype, number.toInt(), position) + return CastValue(true, NumericLiteralValue(targettype, number.toInt(), position)) if(targettype== DataType.FLOAT) - return NumericLiteralValue(targettype, number.toDouble(), position) + return CastValue(true, NumericLiteralValue(targettype, number.toDouble(), position)) } DataType.UWORD -> { if(targettype== DataType.BYTE && numval <= 127) - return NumericLiteralValue(targettype, number.toShort(), position) + return CastValue(true, NumericLiteralValue(targettype, number.toShort(), position)) if(targettype== DataType.UBYTE && numval <= 255) - return NumericLiteralValue(targettype, number.toShort(), position) + return CastValue(true, NumericLiteralValue(targettype, number.toShort(), position)) if(targettype== DataType.WORD && numval <= 32767) - return NumericLiteralValue(targettype, number.toInt(), position) + return CastValue(true, NumericLiteralValue(targettype, number.toInt(), position)) if(targettype== DataType.FLOAT) - return NumericLiteralValue(targettype, number.toDouble(), position) + return CastValue(true, NumericLiteralValue(targettype, number.toDouble(), position)) } DataType.WORD -> { if(targettype== DataType.BYTE && numval >= -128 && numval <=127) - return NumericLiteralValue(targettype, number.toShort(), position) + return CastValue(true, NumericLiteralValue(targettype, number.toShort(), position)) if(targettype== DataType.UBYTE && numval >= 0 && numval <= 255) - return NumericLiteralValue(targettype, number.toShort(), position) + return CastValue(true, NumericLiteralValue(targettype, number.toShort(), position)) if(targettype== DataType.UWORD && numval >=0) - return NumericLiteralValue(targettype, number.toInt(), position) + return CastValue(true, NumericLiteralValue(targettype, number.toInt(), position)) if(targettype== DataType.FLOAT) - return NumericLiteralValue(targettype, number.toDouble(), position) + return CastValue(true, NumericLiteralValue(targettype, number.toDouble(), position)) } DataType.FLOAT -> { if (targettype == DataType.BYTE && numval >= -128 && numval <=127) - return NumericLiteralValue(targettype, number.toShort(), position) + return CastValue(true, NumericLiteralValue(targettype, number.toShort(), position)) if (targettype == DataType.UBYTE && numval >=0 && numval <= 255) - return NumericLiteralValue(targettype, number.toShort(), position) + return CastValue(true, NumericLiteralValue(targettype, number.toShort(), position)) if (targettype == DataType.WORD && numval >= -32768 && numval <= 32767) - return NumericLiteralValue(targettype, number.toInt(), position) + return CastValue(true, NumericLiteralValue(targettype, number.toInt(), position)) if (targettype == DataType.UWORD && numval >=0 && numval <= 65535) - return NumericLiteralValue(targettype, number.toInt(), position) + return CastValue(true, NumericLiteralValue(targettype, number.toInt(), position)) } else -> {} } - throw ExpressionError("can't cast $type into $targettype", position) + return CastValue(false, null) } } @@ -581,14 +587,14 @@ class ArrayLiteralValue(val type: InferredTypes.InferredType, // inferred be if(num==null) { // an array of UWORDs could possibly also contain AddressOfs, other stuff can't be casted if (elementType != DataType.UWORD || it !is AddressOf) - return null + return null // can't cast a value of the array, abort it } else { - try { - num.castNoCheck(elementType) - } catch(x: ExpressionError) { - return null - } + val cast = num.cast(elementType) + if(cast.isValid) + cast.valueOrZero() + else + return null // can't cast a value of the array, abort } }.toTypedArray() return ArrayLiteralValue(InferredTypes.InferredType.known(targettype), castArray, position = position) diff --git a/compiler/src/prog8/ast/processing/AstChecker.kt b/compiler/src/prog8/ast/processing/AstChecker.kt index e907101d3..1466a8f4a 100644 --- a/compiler/src/prog8/ast/processing/AstChecker.kt +++ b/compiler/src/prog8/ast/processing/AstChecker.kt @@ -1215,7 +1215,11 @@ internal class AstChecker(private val program: Program, is AddressOf -> it.identifier.heapId(program.namespace) is TypecastExpression -> { val constVal = it.expression.constValue(program) - constVal?.castNoCheck(it.type)?.number?.toInt() ?: -9999999 + val cast = constVal?.cast(it.type) + if(cast==null || !cast.isValid) + -9999999 + else + cast.valueOrZero().number.toInt() } else -> -9999999 } diff --git a/compiler/src/prog8/ast/processing/TypecastsAdder.kt b/compiler/src/prog8/ast/processing/TypecastsAdder.kt index de3745f06..4641cacd1 100644 --- a/compiler/src/prog8/ast/processing/TypecastsAdder.kt +++ b/compiler/src/prog8/ast/processing/TypecastsAdder.kt @@ -51,8 +51,13 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke TypecastExpression(assignment.value, targettype, true, assignment.value.position), assignment)) } else { - fun castLiteral(cvalue: NumericLiteralValue): List = - listOf(IAstModification.ReplaceNode(cvalue, cvalue.castNoCheck(targettype), cvalue.parent)) + fun castLiteral(cvalue: NumericLiteralValue): List { + val cast = cvalue.cast(targettype) + return if(cast.isValid) + listOf(IAstModification.ReplaceNode(cvalue, cast.valueOrZero(), cvalue.parent)) + else + emptyList() + } val cvalue = assignment.value.constValue(program) if(cvalue!=null) { val number = cvalue.number.toDouble() @@ -109,17 +114,12 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke AddressOf(arg.second.value as IdentifierReference, arg.second.value.position), call as Node) } else if(arg.second.value is NumericLiteralValue) { - if(argtype.isAssignableTo(requiredType)) { - try { - val castedValue = (arg.second.value as NumericLiteralValue).castNoCheck(requiredType) - modifications += IAstModification.ReplaceNode( - call.args[arg.second.index], - castedValue, - call as Node) - } catch (x: ExpressionError) { - // cast failed - } - } + val cast = (arg.second.value as NumericLiteralValue).cast(requiredType) + if(cast.isValid) + modifications += IAstModification.ReplaceNode( + call.args[arg.second.index], + cast.valueOrZero(), + call as Node) } } } @@ -162,7 +162,7 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke // make sure the memory address is an uword val dt = memread.addressExpression.inferType(program) if(dt.isKnown && dt.typeOrElse(DataType.UWORD)!=DataType.UWORD) { - val typecast = (memread.addressExpression as? NumericLiteralValue)?.castNoCheck(DataType.UWORD) + val typecast = (memread.addressExpression as? NumericLiteralValue)?.cast(DataType.UWORD)?.valueOrZero() ?: TypecastExpression(memread.addressExpression, DataType.UWORD, true, memread.addressExpression.position) return listOf(IAstModification.ReplaceNode(memread.addressExpression, typecast, memread)) } @@ -173,7 +173,7 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke // make sure the memory address is an uword val dt = memwrite.addressExpression.inferType(program) if(dt.isKnown && dt.typeOrElse(DataType.UWORD)!=DataType.UWORD) { - val typecast = (memwrite.addressExpression as? NumericLiteralValue)?.castNoCheck(DataType.UWORD) + val typecast = (memwrite.addressExpression as? NumericLiteralValue)?.cast(DataType.UWORD)?.valueOrZero() ?: TypecastExpression(memwrite.addressExpression, DataType.UWORD, true, memwrite.addressExpression.position) return listOf(IAstModification.ReplaceNode(memwrite.addressExpression, typecast, memwrite)) } @@ -190,7 +190,9 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke if (returnValue.inferType(program).istype(subReturnType)) return noModifications if (returnValue is NumericLiteralValue) { - returnStmt.value = returnValue.castNoCheck(subroutine.returntypes.single()) + val cast = returnValue.cast(subroutine.returntypes.single()) + if(cast.isValid) + returnStmt.value = cast.valueOrZero() } else { return listOf(IAstModification.ReplaceNode( returnValue, diff --git a/compiler/src/prog8/ast/processing/VariousCleanups.kt b/compiler/src/prog8/ast/processing/VariousCleanups.kt index 83155a6cd..2af019a3e 100644 --- a/compiler/src/prog8/ast/processing/VariousCleanups.kt +++ b/compiler/src/prog8/ast/processing/VariousCleanups.kt @@ -34,8 +34,9 @@ internal class VariousCleanups: AstWalker() { override fun before(typecast: TypecastExpression, parent: Node): Iterable { if(typecast.expression is NumericLiteralValue) { - val value = (typecast.expression as NumericLiteralValue).castNoCheck(typecast.type) - return listOf(IAstModification.ReplaceNode(typecast, value, parent)) + val value = (typecast.expression as NumericLiteralValue).cast(typecast.type) + if(value.isValid) + return listOf(IAstModification.ReplaceNode(typecast, value.valueOrZero(), parent)) } return noModifications diff --git a/compiler/src/prog8/optimizer/ConstantFoldingOptimizer.kt b/compiler/src/prog8/optimizer/ConstantFoldingOptimizer.kt index a5147fbe5..da478dfd7 100644 --- a/compiler/src/prog8/optimizer/ConstantFoldingOptimizer.kt +++ b/compiler/src/prog8/optimizer/ConstantFoldingOptimizer.kt @@ -171,7 +171,9 @@ internal class ConstantIdentifierReplacer(private val program: Program, private if(declValue!=null && decl.type==VarDeclType.VAR && declValue is NumericLiteralValue && !declValue.inferType(program).istype(decl.datatype)) { // cast the numeric literal to the appropriate datatype of the variable - return listOf(IAstModification.ReplaceNode(decl.value!!, declValue.castNoCheck(decl.datatype), decl)) + val cast = declValue.cast(decl.datatype) + if(cast.isValid) + return listOf(IAstModification.ReplaceNode(decl.value!!, cast.valueOrZero(), decl)) } return noModifications @@ -317,20 +319,23 @@ internal class ConstantFoldingOptimizer(private val program: Program) : AstWalke override fun after(forLoop: ForLoop, parent: Node): Iterable { fun adjustRangeDt(rangeFrom: NumericLiteralValue, targetDt: DataType, rangeTo: NumericLiteralValue, stepLiteral: NumericLiteralValue?, range: RangeExpr): RangeExpr { - val newFrom: NumericLiteralValue - val newTo: NumericLiteralValue - try { - newFrom = rangeFrom.castNoCheck(targetDt) - newTo = rangeTo.castNoCheck(targetDt) - } catch (x: ExpressionError) { + val fromCast = rangeFrom.cast(targetDt) + val toCast = rangeTo.cast(targetDt) + if(!fromCast.isValid || !toCast.isValid) return range - } - val newStep: Expression = try { - stepLiteral?.castNoCheck(targetDt)?: range.step - } catch(ee: ExpressionError) { - range.step - } - return RangeExpr(newFrom, newTo, newStep, range.position) + + val newStep = + if(stepLiteral!=null) { + val stepCast = stepLiteral.cast(targetDt) + if(stepCast.isValid) + stepCast.valueOrZero() + else + range.step + } else { + range.step + } + + return RangeExpr(fromCast.valueOrZero(), toCast.valueOrZero(), newStep, range.position) } // adjust the datatype of a range expression in for loops to the loop variable. @@ -381,8 +386,9 @@ internal class ConstantFoldingOptimizer(private val program: Program) : AstWalke if(decl.type== VarDeclType.CONST && numval!=null) { val valueDt = numval.inferType(program) if(!valueDt.istype(decl.datatype)) { - val adjustedVal = numval.castNoCheck(decl.datatype) - return listOf(IAstModification.ReplaceNode(numval, adjustedVal, decl)) + val cast = numval.cast(decl.datatype) + if(cast.isValid) + return listOf(IAstModification.ReplaceNode(numval, cast.valueOrZero(), decl)) } } return noModifications diff --git a/compiler/src/prog8/optimizer/ExpressionSimplifier.kt b/compiler/src/prog8/optimizer/ExpressionSimplifier.kt index f3b8e9aab..bbe7a03cf 100644 --- a/compiler/src/prog8/optimizer/ExpressionSimplifier.kt +++ b/compiler/src/prog8/optimizer/ExpressionSimplifier.kt @@ -29,9 +29,9 @@ internal class ExpressionSimplifier(private val program: Program) : AstWalker() // try to statically convert a literal value into one of the desired type val literal = typecast.expression as? NumericLiteralValue if (literal != null) { - val newLiteral = literal.castNoCheck(typecast.type) - if (newLiteral !== literal) - mods += IAstModification.ReplaceNode(typecast.expression, newLiteral, typecast) + val newLiteral = literal.cast(typecast.type) + if (newLiteral.isValid && newLiteral.valueOrZero() !== literal) + mods += IAstModification.ReplaceNode(typecast.expression, newLiteral.valueOrZero(), typecast) } // remove redundant nested typecasts diff --git a/examples/arithmetic/div.p8 b/examples/arithmetic/div.p8 index 7bbd8e059..1827b4023 100644 --- a/examples/arithmetic/div.p8 +++ b/examples/arithmetic/div.p8 @@ -2,6 +2,8 @@ %import c64textio %zeropage basicsafe +; TODO implement DIV asm generation + main { sub start() { @@ -9,16 +11,16 @@ main { div_ubyte(100, 6, 16) div_ubyte(255, 2, 127) - div_byte(0, 1, 0) ; TODO fix type error - div_byte(100, -6, -16) ; TODO fix type error - div_byte(127, -2, -63) ; TODO fix type error + div_byte(0, 1, 0) + div_byte(100, -6, -16) + div_byte(127, -2, -63) div_uword(0,1,0) div_uword(40000,500,80) div_uword(43211,2,21605) div_word(0,1,0) - div_word(-20000,500,-40) ; TODO fix type error + div_word(-20000,500,-40) div_word(-2222,2,-1111) div_float(0,1,0) diff --git a/examples/arithmetic/minus.p8 b/examples/arithmetic/minus.p8 index 658303625..fcc560410 100644 --- a/examples/arithmetic/minus.p8 +++ b/examples/arithmetic/minus.p8 @@ -2,6 +2,8 @@ %import c64textio %zeropage basicsafe +; TODO implement float MINUS asm generation + main { sub start() { @@ -10,11 +12,11 @@ main { minus_ubyte(200, 100, 100) minus_ubyte(100, 200, 156) - minus_byte(0, 0, 0) ; TODO fix type error - minus_byte(100, 100, 0) ; TODO fix type error - minus_byte(50, -50, 100) ; TODO fix type error - minus_byte(0, -30, 30) ; TODO fix type error - minus_byte(-30, 0, -30) ; TODO fix type error + minus_byte(0, 0, 0) + minus_byte(100, 100, 0) + minus_byte(50, -50, 100) + minus_byte(0, -30, 30) + minus_byte(-30, 0, -30) minus_uword(0,0,0) minus_uword(50000,0, 50000) @@ -22,10 +24,10 @@ main { minus_uword(20000,50000,35536) minus_word(0,0,0) - minus_word(1000,1000,0) ; TODO fix type error - minus_word(-1000,1000,-2000) ; TODO fix type error - minus_word(1000,500,500) ; TODO fix type error - minus_word(0,-3333,3333) ; TODO fix type error + minus_word(1000,1000,0) + minus_word(-1000,1000,-2000) + minus_word(1000,500,500) + minus_word(0,-3333,3333) minus_word(-3333,0,-3333) minus_float(0,0,0) diff --git a/examples/arithmetic/mult.p8 b/examples/arithmetic/mult.p8 index 829c9e194..211f4b145 100644 --- a/examples/arithmetic/mult.p8 +++ b/examples/arithmetic/mult.p8 @@ -2,6 +2,8 @@ %import c64textio %zeropage basicsafe +; TODO implement MUL asm generation + main { sub start() { @@ -9,17 +11,17 @@ main { mul_ubyte(20, 1, 20) mul_ubyte(20, 10, 200) - mul_byte(0, 0, 0) ; TODO fix type error - mul_byte(10, 10, 100) ; TODO fix type error - mul_byte(5, -5, -25) ; TODO fix type error - mul_byte(0, -30, 0) ; TODO fix type error + mul_byte(0, 0, 0) + mul_byte(10, 10, 100) + mul_byte(5, -5, -25) + mul_byte(0, -30, 0) mul_uword(0,0,0) mul_uword(50000,1, 50000) mul_uword(500,100,50000) mul_word(0,0,0) - mul_word(-10,1000,-10000) ; TODO fix type error + mul_word(-10,1000,-10000) mul_word(1,-3333,-3333) mul_float(0,0,0) diff --git a/examples/arithmetic/plus.p8 b/examples/arithmetic/plus.p8 index e2c552d47..15fb0d4a1 100644 --- a/examples/arithmetic/plus.p8 +++ b/examples/arithmetic/plus.p8 @@ -2,6 +2,8 @@ %import c64textio %zeropage basicsafe +; TODO implement float PLUS asm generation + main { sub start() { @@ -9,19 +11,19 @@ main { plus_ubyte(0, 200, 200) plus_ubyte(100, 200, 44) - plus_byte(0, 0, 0) ; TODO fix type error - plus_byte(-100, 100, 0) ; TODO fix type error - plus_byte(-50, 100, 50) ; TODO fix type error - plus_byte(0, -30, -30) ; TODO fix type error - plus_byte(-30, 0, -30) ; TODO fix type error + plus_byte(0, 0, 0) + plus_byte(-100, 100, 0) + plus_byte(-50, 100, 50) + plus_byte(0, -30, -30) + plus_byte(-30, 0, -30) plus_uword(0,0,0) plus_uword(0,50000,50000) plus_uword(50000,20000,4464) plus_word(0,0,0) - plus_word(-1000,1000,0) ; TODO fix type error - plus_word(-500,1000,500) ; TODO fix type error + plus_word(-1000,1000,0) + plus_word(-500,1000,500) plus_word(0,-3333,-3333) plus_word(-3333,0,-3333) diff --git a/examples/arithmetic/remainder.p8 b/examples/arithmetic/remainder.p8 index 9bb039e2c..a8c5d69e5 100644 --- a/examples/arithmetic/remainder.p8 +++ b/examples/arithmetic/remainder.p8 @@ -3,7 +3,6 @@ ; TODO implement REMAINDER asmgeneration - main { sub start() {