fix certain assignment data type mismatch crash

This commit is contained in:
Irmen de Jong 2024-01-25 21:14:20 +01:00
parent 8cf0b6cf51
commit d4a2031c07
8 changed files with 48 additions and 31 deletions

View File

@ -4,10 +4,7 @@ import prog8.ast.Node
import prog8.ast.Program import prog8.ast.Program
import prog8.ast.expressions.* import prog8.ast.expressions.*
import prog8.ast.maySwapOperandOrder import prog8.ast.maySwapOperandOrder
import prog8.ast.statements.ForLoop import prog8.ast.statements.*
import prog8.ast.statements.RepeatLoop
import prog8.ast.statements.VarDecl
import prog8.ast.statements.VarDeclType
import prog8.ast.walk.AstWalker import prog8.ast.walk.AstWalker
import prog8.ast.walk.IAstModification import prog8.ast.walk.IAstModification
import prog8.code.core.AssociativeOperators import prog8.code.core.AssociativeOperators
@ -35,6 +32,19 @@ class ConstantFoldingOptimizer(private val program: Program, private val errors:
return noModifications return noModifications
} }
override fun after(numLiteral: NumericLiteral, parent: Node): Iterable<IAstModification> {
if(parent is Assignment) {
val iDt = parent.target.inferType(program)
if(iDt.isKnown && !iDt.isBool && !iDt.istype(numLiteral.type)) {
val casted = numLiteral.cast(iDt.getOr(DataType.UNDEFINED))
if(casted.isValid) {
return listOf(IAstModification.ReplaceNode(numLiteral, casted.value!!, parent))
}
}
}
return noModifications
}
override fun after(containment: ContainmentCheck, parent: Node): Iterable<IAstModification> { override fun after(containment: ContainmentCheck, parent: Node): Iterable<IAstModification> {
val result = containment.constValue(program) val result = containment.constValue(program)
if(result!=null) if(result!=null)
@ -312,14 +322,14 @@ class ConstantFoldingOptimizer(private val program: Program, private val errors:
if(stepLiteral!=null) { if(stepLiteral!=null) {
val stepCast = stepLiteral.cast(targetDt) val stepCast = stepLiteral.cast(targetDt)
if(stepCast.isValid) if(stepCast.isValid)
stepCast.valueOrZero() stepCast.value!!
else else
range.step range.step
} else { } else {
range.step range.step
} }
return RangeExpression(fromCast.valueOrZero(), toCast.valueOrZero(), newStep, range.position) return RangeExpression(fromCast.value!!, toCast.value!!, newStep, range.position)
} }
// adjust the datatype of a range expression in for loops to the loop variable. // adjust the datatype of a range expression in for loops to the loop variable.
@ -378,7 +388,7 @@ class ConstantFoldingOptimizer(private val program: Program, private val errors:
if(decl.datatype!=DataType.BOOL || valueDt.isnot(DataType.UBYTE)) { if(decl.datatype!=DataType.BOOL || valueDt.isnot(DataType.UBYTE)) {
val cast = numval.cast(decl.datatype) val cast = numval.cast(decl.datatype)
if (cast.isValid) if (cast.isValid)
return listOf(IAstModification.ReplaceNode(numval, cast.valueOrZero(), decl)) return listOf(IAstModification.ReplaceNode(numval, cast.value!!, decl))
} }
} }
} }

View File

@ -40,7 +40,7 @@ class VarConstantValueTypeAdjuster(
declConstValue.linkParents(decl) declConstValue.linkParents(decl)
val cast = declConstValue.cast(decl.datatype) val cast = declConstValue.cast(decl.datatype)
if (cast.isValid) if (cast.isValid)
return listOf(IAstModification.ReplaceNode(decl.value!!, cast.valueOrZero(), decl)) return listOf(IAstModification.ReplaceNode(decl.value!!, cast.value!!, decl))
} }
} }
} catch (x: UndefinedSymbolError) { } catch (x: UndefinedSymbolError) {

View File

@ -32,8 +32,8 @@ class ExpressionSimplifier(private val program: Program,
val literal = typecast.expression as? NumericLiteral val literal = typecast.expression as? NumericLiteral
if (literal != null) { if (literal != null) {
val newLiteral = literal.cast(typecast.type) val newLiteral = literal.cast(typecast.type)
if (newLiteral.isValid && newLiteral.valueOrZero() !== literal) { if (newLiteral.isValid && newLiteral.value!! !== literal) {
mods += IAstModification.ReplaceNode(typecast, newLiteral.valueOrZero(), parent) mods += IAstModification.ReplaceNode(typecast, newLiteral.value!!, parent)
} }
} }

View File

@ -1687,7 +1687,7 @@ internal class AstChecker(private val program: Program,
if(cast==null || !cast.isValid) if(cast==null || !cast.isValid)
-9999999 -9999999
else else
cast.valueOrZero().number.toInt() cast.value!!.number.toInt()
} }
else -> -9999999 else -> -9999999
} }

View File

@ -207,7 +207,7 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val
fun castLiteral(cvalue2: NumericLiteral): List<IAstModification.ReplaceNode> { fun castLiteral(cvalue2: NumericLiteral): List<IAstModification.ReplaceNode> {
val cast = cvalue2.cast(targettype) val cast = cvalue2.cast(targettype)
return if(cast.isValid) return if(cast.isValid)
listOf(IAstModification.ReplaceNode(assignment.value, cast.valueOrZero(), assignment)) listOf(IAstModification.ReplaceNode(assignment.value, cast.value!!, assignment))
else else
emptyList() emptyList()
} }
@ -314,7 +314,7 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val
val modifications = mutableListOf<IAstModification>() val modifications = mutableListOf<IAstModification>()
val dt = memread.addressExpression.inferType(program) val dt = memread.addressExpression.inferType(program)
if(dt.isKnown && dt.getOr(DataType.UWORD)!=DataType.UWORD) { if(dt.isKnown && dt.getOr(DataType.UWORD)!=DataType.UWORD) {
val castedValue = (memread.addressExpression as? NumericLiteral)?.cast(DataType.UWORD)?.valueOrZero() val castedValue = (memread.addressExpression as? NumericLiteral)?.cast(DataType.UWORD)?.value
if(castedValue!=null) if(castedValue!=null)
modifications += IAstModification.ReplaceNode(memread.addressExpression, castedValue, memread) modifications += IAstModification.ReplaceNode(memread.addressExpression, castedValue, memread)
else else
@ -328,7 +328,7 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val
val modifications = mutableListOf<IAstModification>() val modifications = mutableListOf<IAstModification>()
val dt = memwrite.addressExpression.inferType(program) val dt = memwrite.addressExpression.inferType(program)
if(dt.isKnown && dt.getOr(DataType.UWORD)!=DataType.UWORD) { if(dt.isKnown && dt.getOr(DataType.UWORD)!=DataType.UWORD) {
val castedValue = (memwrite.addressExpression as? NumericLiteral)?.cast(DataType.UWORD)?.valueOrZero() val castedValue = (memwrite.addressExpression as? NumericLiteral)?.cast(DataType.UWORD)?.value
if(castedValue!=null) if(castedValue!=null)
modifications += IAstModification.ReplaceNode(memwrite.addressExpression, castedValue, memwrite) modifications += IAstModification.ReplaceNode(memwrite.addressExpression, castedValue, memwrite)
else else
@ -349,9 +349,9 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val
if (returnDt istype subReturnType or returnDt.isNotAssignableTo(subReturnType)) if (returnDt istype subReturnType or returnDt.isNotAssignableTo(subReturnType))
return noModifications return noModifications
if (returnValue is NumericLiteral) { if (returnValue is NumericLiteral) {
val cast = returnValue.cast(subroutine.returntypes.single()) val cast = returnValue.cast(subReturnType)
if(cast.isValid) if(cast.isValid)
returnStmt.value = cast.valueOrZero() returnStmt.value = cast.value
} else { } else {
val modifications = mutableListOf<IAstModification>() val modifications = mutableListOf<IAstModification>()
addTypecastOrCastedValueModification(modifications, returnValue, subReturnType, returnStmt) addTypecastOrCastedValueModification(modifications, returnValue, subReturnType, returnStmt)
@ -402,9 +402,9 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val
val castedValue = expressionToCast.cast(requiredType) val castedValue = expressionToCast.cast(requiredType)
if (castedValue.isValid) { if (castedValue.isValid) {
val signOriginal = sign(expressionToCast.number) val signOriginal = sign(expressionToCast.number)
val signCasted = sign(castedValue.valueOrZero().number) val signCasted = sign(castedValue.value!!.number)
if(signOriginal==signCasted) { if(signOriginal==signCasted) {
modifications += IAstModification.ReplaceNode(expressionToCast, castedValue.valueOrZero(), parent) modifications += IAstModification.ReplaceNode(expressionToCast, castedValue.value!!, parent)
} }
return return
} }

View File

@ -50,7 +50,7 @@ internal class VariousCleanups(val program: Program, val errors: IErrorReporter,
if(typecast.expression is NumericLiteral) { if(typecast.expression is NumericLiteral) {
val value = (typecast.expression as NumericLiteral).cast(typecast.type) val value = (typecast.expression as NumericLiteral).cast(typecast.type)
if(value.isValid) if(value.isValid)
return listOf(IAstModification.ReplaceNode(typecast, value.valueOrZero(), parent)) return listOf(IAstModification.ReplaceNode(typecast, value.value!!, parent))
} }
val sourceDt = typecast.expression.inferType(program) val sourceDt = typecast.expression.inferType(program)

View File

@ -1088,4 +1088,16 @@ main {
errors.errors[3] shouldContain "overflow" errors.errors[3] shouldContain "overflow"
} }
test("type fitting of const assignment values") {
val src="""
main {
sub start() {
&ubyte mapped = 8000
mapped = 6144 >> 9
ubyte @shared ubb = 6144 >> 9
bool @shared bb = 6144
}
}"""
compileText(C64Target(), true, src, writeAssembly = true) shouldNotBe null
}
}) })

View File

@ -366,7 +366,7 @@ class TypecastExpression(var expression: Expression, var type: DataType, val imp
val cv = expression.constValue(program) ?: return null val cv = expression.constValue(program) ?: return null
val cast = cv.cast(type) val cast = cv.cast(type)
return if(cast.isValid) { return if(cast.isValid) {
val newval = cast.valueOrZero() val newval = cast.value!!
newval.linkParents(parent) newval.linkParents(parent)
return newval return newval
} }
@ -566,16 +566,11 @@ class NumericLiteral(val type: DataType, // only numerical types allowed
operator fun compareTo(other: NumericLiteral): Int = number.compareTo(other.number) operator fun compareTo(other: NumericLiteral): Int = number.compareTo(other.number)
class ValueAfterCast(val isValid: Boolean, val whyFailed: String?, private val value: NumericLiteral?) { data class ValueAfterCast(val isValid: Boolean, val whyFailed: String?, val value: NumericLiteral?)
fun valueOrZero() = if(isValid) value!! else NumericLiteral(DataType.UBYTE, 0.0, Position.DUMMY)
fun linkParent(parent: Node) {
value?.linkParents(parent)
}
}
fun cast(targettype: DataType): ValueAfterCast { fun cast(targettype: DataType): ValueAfterCast {
val result = internalCast(targettype) val result = internalCast(targettype)
result.linkParent(this.parent) result.value?.linkParents(this.parent)
return result return result
} }
@ -870,7 +865,7 @@ class ArrayLiteral(val type: InferredTypes.InferredType, // inferred because
val castArray = value.map { val castArray = value.map {
val cast = (it as NumericLiteral).cast(elementType) val cast = (it as NumericLiteral).cast(elementType)
if(cast.isValid) if(cast.isValid)
cast.valueOrZero() as Expression cast.value!! as Expression
else else
return null // abort return null // abort
}.toTypedArray() }.toTypedArray()
@ -879,12 +874,12 @@ class ArrayLiteral(val type: InferredTypes.InferredType, // inferred because
else if(elementType in WordDatatypes && value.all { it is NumericLiteral || it is AddressOf || it is IdentifierReference}) { else if(elementType in WordDatatypes && value.all { it is NumericLiteral || it is AddressOf || it is IdentifierReference}) {
val castArray = value.map { val castArray = value.map {
when(it) { when(it) {
is AddressOf -> it as Expression is AddressOf -> it
is IdentifierReference -> it as Expression is IdentifierReference -> it
is NumericLiteral -> { is NumericLiteral -> {
val numcast = it.cast(elementType) val numcast = it.cast(elementType)
if(numcast.isValid) if(numcast.isValid)
numcast.valueOrZero() as Expression numcast.value!!
else else
return null // abort return null // abort
} }