fix certain assignment data type mismatch crash

This commit is contained in:
Irmen de Jong 2024-01-25 21:14:20 +01:00
parent 8cf0b6cf51
commit d4a2031c07
8 changed files with 48 additions and 31 deletions

View File

@ -4,10 +4,7 @@ import prog8.ast.Node
import prog8.ast.Program
import prog8.ast.expressions.*
import prog8.ast.maySwapOperandOrder
import prog8.ast.statements.ForLoop
import prog8.ast.statements.RepeatLoop
import prog8.ast.statements.VarDecl
import prog8.ast.statements.VarDeclType
import prog8.ast.statements.*
import prog8.ast.walk.AstWalker
import prog8.ast.walk.IAstModification
import prog8.code.core.AssociativeOperators
@ -35,6 +32,19 @@ class ConstantFoldingOptimizer(private val program: Program, private val errors:
return noModifications
}
override fun after(numLiteral: NumericLiteral, parent: Node): Iterable<IAstModification> {
if(parent is Assignment) {
val iDt = parent.target.inferType(program)
if(iDt.isKnown && !iDt.isBool && !iDt.istype(numLiteral.type)) {
val casted = numLiteral.cast(iDt.getOr(DataType.UNDEFINED))
if(casted.isValid) {
return listOf(IAstModification.ReplaceNode(numLiteral, casted.value!!, parent))
}
}
}
return noModifications
}
override fun after(containment: ContainmentCheck, parent: Node): Iterable<IAstModification> {
val result = containment.constValue(program)
if(result!=null)
@ -312,14 +322,14 @@ class ConstantFoldingOptimizer(private val program: Program, private val errors:
if(stepLiteral!=null) {
val stepCast = stepLiteral.cast(targetDt)
if(stepCast.isValid)
stepCast.valueOrZero()
stepCast.value!!
else
range.step
} else {
range.step
}
return RangeExpression(fromCast.valueOrZero(), toCast.valueOrZero(), newStep, range.position)
return RangeExpression(fromCast.value!!, toCast.value!!, newStep, range.position)
}
// adjust the datatype of a range expression in for loops to the loop variable.
@ -378,7 +388,7 @@ class ConstantFoldingOptimizer(private val program: Program, private val errors:
if(decl.datatype!=DataType.BOOL || valueDt.isnot(DataType.UBYTE)) {
val cast = numval.cast(decl.datatype)
if (cast.isValid)
return listOf(IAstModification.ReplaceNode(numval, cast.valueOrZero(), decl))
return listOf(IAstModification.ReplaceNode(numval, cast.value!!, decl))
}
}
}

View File

@ -40,7 +40,7 @@ class VarConstantValueTypeAdjuster(
declConstValue.linkParents(decl)
val cast = declConstValue.cast(decl.datatype)
if (cast.isValid)
return listOf(IAstModification.ReplaceNode(decl.value!!, cast.valueOrZero(), decl))
return listOf(IAstModification.ReplaceNode(decl.value!!, cast.value!!, decl))
}
}
} catch (x: UndefinedSymbolError) {

View File

@ -32,8 +32,8 @@ class ExpressionSimplifier(private val program: Program,
val literal = typecast.expression as? NumericLiteral
if (literal != null) {
val newLiteral = literal.cast(typecast.type)
if (newLiteral.isValid && newLiteral.valueOrZero() !== literal) {
mods += IAstModification.ReplaceNode(typecast, newLiteral.valueOrZero(), parent)
if (newLiteral.isValid && newLiteral.value!! !== literal) {
mods += IAstModification.ReplaceNode(typecast, newLiteral.value!!, parent)
}
}

View File

@ -1687,7 +1687,7 @@ internal class AstChecker(private val program: Program,
if(cast==null || !cast.isValid)
-9999999
else
cast.valueOrZero().number.toInt()
cast.value!!.number.toInt()
}
else -> -9999999
}

View File

@ -207,7 +207,7 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val
fun castLiteral(cvalue2: NumericLiteral): List<IAstModification.ReplaceNode> {
val cast = cvalue2.cast(targettype)
return if(cast.isValid)
listOf(IAstModification.ReplaceNode(assignment.value, cast.valueOrZero(), assignment))
listOf(IAstModification.ReplaceNode(assignment.value, cast.value!!, assignment))
else
emptyList()
}
@ -314,7 +314,7 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val
val modifications = mutableListOf<IAstModification>()
val dt = memread.addressExpression.inferType(program)
if(dt.isKnown && dt.getOr(DataType.UWORD)!=DataType.UWORD) {
val castedValue = (memread.addressExpression as? NumericLiteral)?.cast(DataType.UWORD)?.valueOrZero()
val castedValue = (memread.addressExpression as? NumericLiteral)?.cast(DataType.UWORD)?.value
if(castedValue!=null)
modifications += IAstModification.ReplaceNode(memread.addressExpression, castedValue, memread)
else
@ -328,7 +328,7 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val
val modifications = mutableListOf<IAstModification>()
val dt = memwrite.addressExpression.inferType(program)
if(dt.isKnown && dt.getOr(DataType.UWORD)!=DataType.UWORD) {
val castedValue = (memwrite.addressExpression as? NumericLiteral)?.cast(DataType.UWORD)?.valueOrZero()
val castedValue = (memwrite.addressExpression as? NumericLiteral)?.cast(DataType.UWORD)?.value
if(castedValue!=null)
modifications += IAstModification.ReplaceNode(memwrite.addressExpression, castedValue, memwrite)
else
@ -349,9 +349,9 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val
if (returnDt istype subReturnType or returnDt.isNotAssignableTo(subReturnType))
return noModifications
if (returnValue is NumericLiteral) {
val cast = returnValue.cast(subroutine.returntypes.single())
val cast = returnValue.cast(subReturnType)
if(cast.isValid)
returnStmt.value = cast.valueOrZero()
returnStmt.value = cast.value
} else {
val modifications = mutableListOf<IAstModification>()
addTypecastOrCastedValueModification(modifications, returnValue, subReturnType, returnStmt)
@ -402,9 +402,9 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val
val castedValue = expressionToCast.cast(requiredType)
if (castedValue.isValid) {
val signOriginal = sign(expressionToCast.number)
val signCasted = sign(castedValue.valueOrZero().number)
val signCasted = sign(castedValue.value!!.number)
if(signOriginal==signCasted) {
modifications += IAstModification.ReplaceNode(expressionToCast, castedValue.valueOrZero(), parent)
modifications += IAstModification.ReplaceNode(expressionToCast, castedValue.value!!, parent)
}
return
}

View File

@ -50,7 +50,7 @@ internal class VariousCleanups(val program: Program, val errors: IErrorReporter,
if(typecast.expression is NumericLiteral) {
val value = (typecast.expression as NumericLiteral).cast(typecast.type)
if(value.isValid)
return listOf(IAstModification.ReplaceNode(typecast, value.valueOrZero(), parent))
return listOf(IAstModification.ReplaceNode(typecast, value.value!!, parent))
}
val sourceDt = typecast.expression.inferType(program)

View File

@ -1088,4 +1088,16 @@ main {
errors.errors[3] shouldContain "overflow"
}
test("type fitting of const assignment values") {
val src="""
main {
sub start() {
&ubyte mapped = 8000
mapped = 6144 >> 9
ubyte @shared ubb = 6144 >> 9
bool @shared bb = 6144
}
}"""
compileText(C64Target(), true, src, writeAssembly = true) shouldNotBe null
}
})

View File

@ -366,7 +366,7 @@ class TypecastExpression(var expression: Expression, var type: DataType, val imp
val cv = expression.constValue(program) ?: return null
val cast = cv.cast(type)
return if(cast.isValid) {
val newval = cast.valueOrZero()
val newval = cast.value!!
newval.linkParents(parent)
return newval
}
@ -566,16 +566,11 @@ class NumericLiteral(val type: DataType, // only numerical types allowed
operator fun compareTo(other: NumericLiteral): Int = number.compareTo(other.number)
class ValueAfterCast(val isValid: Boolean, val whyFailed: String?, private val value: NumericLiteral?) {
fun valueOrZero() = if(isValid) value!! else NumericLiteral(DataType.UBYTE, 0.0, Position.DUMMY)
fun linkParent(parent: Node) {
value?.linkParents(parent)
}
}
data class ValueAfterCast(val isValid: Boolean, val whyFailed: String?, val value: NumericLiteral?)
fun cast(targettype: DataType): ValueAfterCast {
val result = internalCast(targettype)
result.linkParent(this.parent)
result.value?.linkParents(this.parent)
return result
}
@ -870,7 +865,7 @@ class ArrayLiteral(val type: InferredTypes.InferredType, // inferred because
val castArray = value.map {
val cast = (it as NumericLiteral).cast(elementType)
if(cast.isValid)
cast.valueOrZero() as Expression
cast.value!! as Expression
else
return null // abort
}.toTypedArray()
@ -879,12 +874,12 @@ class ArrayLiteral(val type: InferredTypes.InferredType, // inferred because
else if(elementType in WordDatatypes && value.all { it is NumericLiteral || it is AddressOf || it is IdentifierReference}) {
val castArray = value.map {
when(it) {
is AddressOf -> it as Expression
is IdentifierReference -> it as Expression
is AddressOf -> it
is IdentifierReference -> it
is NumericLiteral -> {
val numcast = it.cast(elementType)
if(numcast.isValid)
numcast.valueOrZero() as Expression
numcast.value!!
else
return null // abort
}