diff --git a/codeOptimizers/src/prog8/optimizer/ConstantFoldingOptimizer.kt b/codeOptimizers/src/prog8/optimizer/ConstantFoldingOptimizer.kt index 3262c8567..3721dd7dd 100644 --- a/codeOptimizers/src/prog8/optimizer/ConstantFoldingOptimizer.kt +++ b/codeOptimizers/src/prog8/optimizer/ConstantFoldingOptimizer.kt @@ -4,10 +4,7 @@ import prog8.ast.Node import prog8.ast.Program import prog8.ast.expressions.* import prog8.ast.maySwapOperandOrder -import prog8.ast.statements.ForLoop -import prog8.ast.statements.RepeatLoop -import prog8.ast.statements.VarDecl -import prog8.ast.statements.VarDeclType +import prog8.ast.statements.* import prog8.ast.walk.AstWalker import prog8.ast.walk.IAstModification import prog8.code.core.AssociativeOperators @@ -35,6 +32,19 @@ class ConstantFoldingOptimizer(private val program: Program, private val errors: return noModifications } + override fun after(numLiteral: NumericLiteral, parent: Node): Iterable { + if(parent is Assignment) { + val iDt = parent.target.inferType(program) + if(iDt.isKnown && !iDt.isBool && !iDt.istype(numLiteral.type)) { + val casted = numLiteral.cast(iDt.getOr(DataType.UNDEFINED)) + if(casted.isValid) { + return listOf(IAstModification.ReplaceNode(numLiteral, casted.value!!, parent)) + } + } + } + return noModifications + } + override fun after(containment: ContainmentCheck, parent: Node): Iterable { val result = containment.constValue(program) if(result!=null) @@ -312,14 +322,14 @@ class ConstantFoldingOptimizer(private val program: Program, private val errors: if(stepLiteral!=null) { val stepCast = stepLiteral.cast(targetDt) if(stepCast.isValid) - stepCast.valueOrZero() + stepCast.value!! else range.step } else { range.step } - return RangeExpression(fromCast.valueOrZero(), toCast.valueOrZero(), newStep, range.position) + return RangeExpression(fromCast.value!!, toCast.value!!, newStep, range.position) } // adjust the datatype of a range expression in for loops to the loop variable. @@ -378,7 +388,7 @@ class ConstantFoldingOptimizer(private val program: Program, private val errors: if(decl.datatype!=DataType.BOOL || valueDt.isnot(DataType.UBYTE)) { val cast = numval.cast(decl.datatype) if (cast.isValid) - return listOf(IAstModification.ReplaceNode(numval, cast.valueOrZero(), decl)) + return listOf(IAstModification.ReplaceNode(numval, cast.value!!, decl)) } } } diff --git a/codeOptimizers/src/prog8/optimizer/ConstantIdentifierReplacer.kt b/codeOptimizers/src/prog8/optimizer/ConstantIdentifierReplacer.kt index 1487b8dec..b13597521 100644 --- a/codeOptimizers/src/prog8/optimizer/ConstantIdentifierReplacer.kt +++ b/codeOptimizers/src/prog8/optimizer/ConstantIdentifierReplacer.kt @@ -40,7 +40,7 @@ class VarConstantValueTypeAdjuster( declConstValue.linkParents(decl) val cast = declConstValue.cast(decl.datatype) if (cast.isValid) - return listOf(IAstModification.ReplaceNode(decl.value!!, cast.valueOrZero(), decl)) + return listOf(IAstModification.ReplaceNode(decl.value!!, cast.value!!, decl)) } } } catch (x: UndefinedSymbolError) { diff --git a/codeOptimizers/src/prog8/optimizer/ExpressionSimplifier.kt b/codeOptimizers/src/prog8/optimizer/ExpressionSimplifier.kt index b864cf1e4..0f3433304 100644 --- a/codeOptimizers/src/prog8/optimizer/ExpressionSimplifier.kt +++ b/codeOptimizers/src/prog8/optimizer/ExpressionSimplifier.kt @@ -32,8 +32,8 @@ class ExpressionSimplifier(private val program: Program, val literal = typecast.expression as? NumericLiteral if (literal != null) { val newLiteral = literal.cast(typecast.type) - if (newLiteral.isValid && newLiteral.valueOrZero() !== literal) { - mods += IAstModification.ReplaceNode(typecast, newLiteral.valueOrZero(), parent) + if (newLiteral.isValid && newLiteral.value!! !== literal) { + mods += IAstModification.ReplaceNode(typecast, newLiteral.value!!, parent) } } diff --git a/compiler/src/prog8/compiler/astprocessing/AstChecker.kt b/compiler/src/prog8/compiler/astprocessing/AstChecker.kt index 0c72892f2..f8b60c713 100644 --- a/compiler/src/prog8/compiler/astprocessing/AstChecker.kt +++ b/compiler/src/prog8/compiler/astprocessing/AstChecker.kt @@ -1687,7 +1687,7 @@ internal class AstChecker(private val program: Program, if(cast==null || !cast.isValid) -9999999 else - cast.valueOrZero().number.toInt() + cast.value!!.number.toInt() } else -> -9999999 } diff --git a/compiler/src/prog8/compiler/astprocessing/TypecastsAdder.kt b/compiler/src/prog8/compiler/astprocessing/TypecastsAdder.kt index 051476e9b..a16f7e47c 100644 --- a/compiler/src/prog8/compiler/astprocessing/TypecastsAdder.kt +++ b/compiler/src/prog8/compiler/astprocessing/TypecastsAdder.kt @@ -207,7 +207,7 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val fun castLiteral(cvalue2: NumericLiteral): List { val cast = cvalue2.cast(targettype) return if(cast.isValid) - listOf(IAstModification.ReplaceNode(assignment.value, cast.valueOrZero(), assignment)) + listOf(IAstModification.ReplaceNode(assignment.value, cast.value!!, assignment)) else emptyList() } @@ -314,7 +314,7 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val val modifications = mutableListOf() val dt = memread.addressExpression.inferType(program) if(dt.isKnown && dt.getOr(DataType.UWORD)!=DataType.UWORD) { - val castedValue = (memread.addressExpression as? NumericLiteral)?.cast(DataType.UWORD)?.valueOrZero() + val castedValue = (memread.addressExpression as? NumericLiteral)?.cast(DataType.UWORD)?.value if(castedValue!=null) modifications += IAstModification.ReplaceNode(memread.addressExpression, castedValue, memread) else @@ -328,7 +328,7 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val val modifications = mutableListOf() val dt = memwrite.addressExpression.inferType(program) if(dt.isKnown && dt.getOr(DataType.UWORD)!=DataType.UWORD) { - val castedValue = (memwrite.addressExpression as? NumericLiteral)?.cast(DataType.UWORD)?.valueOrZero() + val castedValue = (memwrite.addressExpression as? NumericLiteral)?.cast(DataType.UWORD)?.value if(castedValue!=null) modifications += IAstModification.ReplaceNode(memwrite.addressExpression, castedValue, memwrite) else @@ -349,9 +349,9 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val if (returnDt istype subReturnType or returnDt.isNotAssignableTo(subReturnType)) return noModifications if (returnValue is NumericLiteral) { - val cast = returnValue.cast(subroutine.returntypes.single()) + val cast = returnValue.cast(subReturnType) if(cast.isValid) - returnStmt.value = cast.valueOrZero() + returnStmt.value = cast.value } else { val modifications = mutableListOf() addTypecastOrCastedValueModification(modifications, returnValue, subReturnType, returnStmt) @@ -402,9 +402,9 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val val castedValue = expressionToCast.cast(requiredType) if (castedValue.isValid) { val signOriginal = sign(expressionToCast.number) - val signCasted = sign(castedValue.valueOrZero().number) + val signCasted = sign(castedValue.value!!.number) if(signOriginal==signCasted) { - modifications += IAstModification.ReplaceNode(expressionToCast, castedValue.valueOrZero(), parent) + modifications += IAstModification.ReplaceNode(expressionToCast, castedValue.value!!, parent) } return } diff --git a/compiler/src/prog8/compiler/astprocessing/VariousCleanups.kt b/compiler/src/prog8/compiler/astprocessing/VariousCleanups.kt index 21d7be9c9..9137372ce 100644 --- a/compiler/src/prog8/compiler/astprocessing/VariousCleanups.kt +++ b/compiler/src/prog8/compiler/astprocessing/VariousCleanups.kt @@ -50,7 +50,7 @@ internal class VariousCleanups(val program: Program, val errors: IErrorReporter, if(typecast.expression is NumericLiteral) { val value = (typecast.expression as NumericLiteral).cast(typecast.type) if(value.isValid) - return listOf(IAstModification.ReplaceNode(typecast, value.valueOrZero(), parent)) + return listOf(IAstModification.ReplaceNode(typecast, value.value!!, parent)) } val sourceDt = typecast.expression.inferType(program) diff --git a/compiler/test/TestTypecasts.kt b/compiler/test/TestTypecasts.kt index 62de9b67e..4d01eaec1 100644 --- a/compiler/test/TestTypecasts.kt +++ b/compiler/test/TestTypecasts.kt @@ -1088,4 +1088,16 @@ main { errors.errors[3] shouldContain "overflow" } + test("type fitting of const assignment values") { + val src=""" +main { + sub start() { + &ubyte mapped = 8000 + mapped = 6144 >> 9 + ubyte @shared ubb = 6144 >> 9 + bool @shared bb = 6144 + } +}""" + compileText(C64Target(), true, src, writeAssembly = true) shouldNotBe null + } }) diff --git a/compilerAst/src/prog8/ast/expressions/AstExpressions.kt b/compilerAst/src/prog8/ast/expressions/AstExpressions.kt index b2475312b..80993cb9a 100644 --- a/compilerAst/src/prog8/ast/expressions/AstExpressions.kt +++ b/compilerAst/src/prog8/ast/expressions/AstExpressions.kt @@ -366,7 +366,7 @@ class TypecastExpression(var expression: Expression, var type: DataType, val imp val cv = expression.constValue(program) ?: return null val cast = cv.cast(type) return if(cast.isValid) { - val newval = cast.valueOrZero() + val newval = cast.value!! newval.linkParents(parent) return newval } @@ -566,16 +566,11 @@ class NumericLiteral(val type: DataType, // only numerical types allowed operator fun compareTo(other: NumericLiteral): Int = number.compareTo(other.number) - class ValueAfterCast(val isValid: Boolean, val whyFailed: String?, private val value: NumericLiteral?) { - fun valueOrZero() = if(isValid) value!! else NumericLiteral(DataType.UBYTE, 0.0, Position.DUMMY) - fun linkParent(parent: Node) { - value?.linkParents(parent) - } - } + data class ValueAfterCast(val isValid: Boolean, val whyFailed: String?, val value: NumericLiteral?) fun cast(targettype: DataType): ValueAfterCast { val result = internalCast(targettype) - result.linkParent(this.parent) + result.value?.linkParents(this.parent) return result } @@ -870,7 +865,7 @@ class ArrayLiteral(val type: InferredTypes.InferredType, // inferred because val castArray = value.map { val cast = (it as NumericLiteral).cast(elementType) if(cast.isValid) - cast.valueOrZero() as Expression + cast.value!! as Expression else return null // abort }.toTypedArray() @@ -879,12 +874,12 @@ class ArrayLiteral(val type: InferredTypes.InferredType, // inferred because else if(elementType in WordDatatypes && value.all { it is NumericLiteral || it is AddressOf || it is IdentifierReference}) { val castArray = value.map { when(it) { - is AddressOf -> it as Expression - is IdentifierReference -> it as Expression + is AddressOf -> it + is IdentifierReference -> it is NumericLiteral -> { val numcast = it.cast(elementType) if(numcast.isValid) - numcast.valueOrZero() as Expression + numcast.value!! else return null // abort }