diff --git a/DeprecatedStackVm/src/compiler/Compiler.kt b/DeprecatedStackVm/src/compiler/Compiler.kt index 9f5216e90..24a29ca2d 100644 --- a/DeprecatedStackVm/src/compiler/Compiler.kt +++ b/DeprecatedStackVm/src/compiler/Compiler.kt @@ -484,8 +484,8 @@ internal class Compiler(private val program: Program) { translatePrefixOperator(expr.operator, expr.expression.inferType(program)) } is BinaryExpression -> { - val leftDt = expr.left.inferType(program)!! - val rightDt = expr.right.inferType(program)!! + val leftDt = expr.left.inferType(program) + val rightDt = expr.right.inferType(program) val (commonDt, _) = expr.commonDatatype(leftDt, rightDt, expr.left, expr.right) translate(expr.left) if(leftDt!=commonDt) @@ -654,7 +654,7 @@ internal class Compiler(private val program: Program) { // cast type if needed if(builtinFuncParams!=null) { val paramDts = builtinFuncParams[index].possibleDatatypes - val argDt = arg.inferType(program)!! + val argDt = arg.inferType(program) if(argDt !in paramDts) { for(paramDt in paramDts.sorted()) if(tryConvertType(argDt, paramDt)) @@ -827,8 +827,8 @@ internal class Compiler(private val program: Program) { // swap(x,y) is treated differently, it's not a normal function call if (args.size != 2) throw AstException("swap requires 2 arguments") - val dt1 = args[0].inferType(program)!! - val dt2 = args[1].inferType(program)!! + val dt1 = args[0].inferType(program) + val dt2 = args[1].inferType(program) if (dt1 != dt2) throw AstException("swap requires 2 args of identical type") if (args[0].constValue(program) != null || args[1].constValue(program) != null) @@ -861,7 +861,7 @@ internal class Compiler(private val program: Program) { // (subroutine arguments are not passed via the stack!) for (arg in arguments.zip(subroutine.parameters)) { translate(arg.first) - convertType(arg.first.inferType(program)!!, arg.second.type) // convert types of arguments to required parameter type + convertType(arg.first.inferType(program), arg.second.type) // convert types of arguments to required parameter type val opcode = opcodePopvar(arg.second.type) prog.instr(opcode, callLabel = subroutine.scopedname + "." + arg.second.name) } diff --git a/compiler/src/prog8/ast/expressions/AstExpressions.kt b/compiler/src/prog8/ast/expressions/AstExpressions.kt index c932172af..cda512169 100644 --- a/compiler/src/prog8/ast/expressions/AstExpressions.kt +++ b/compiler/src/prog8/ast/expressions/AstExpressions.kt @@ -30,7 +30,7 @@ sealed class Expression: Node { abstract fun accept(visitor: IAstModifyingVisitor): Expression abstract fun accept(visitor: IAstVisitor) abstract fun referencesIdentifiers(vararg name: String): Boolean // todo: remove this and add identifier usage tracking into CallGraph instead - abstract fun inferType(program: Program): DataType? + abstract fun inferType(program: Program): InferredTypes.InferredType infix fun isSameAs(other: Expression): Boolean { if(this===other) @@ -68,7 +68,7 @@ class PrefixExpression(val operator: String, var expression: Expression, overrid override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun referencesIdentifiers(vararg name: String) = expression.referencesIdentifiers(*name) - override fun inferType(program: Program): DataType? = expression.inferType(program) + override fun inferType(program: Program): InferredTypes.InferredType = expression.inferType(program) override fun toString(): String { return "Prefix($operator $expression)" @@ -94,15 +94,22 @@ class BinaryExpression(var left: Expression, var operator: String, var right: Ex override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun referencesIdentifiers(vararg name: String) = left.referencesIdentifiers(*name) || right.referencesIdentifiers(*name) - override fun inferType(program: Program): DataType? { + override fun inferType(program: Program): InferredTypes.InferredType { val leftDt = left.inferType(program) val rightDt = right.inferType(program) return when (operator) { - "+", "-", "*", "**", "%", "/" -> if (leftDt == null || rightDt == null) null else { - try { - commonDatatype(leftDt, rightDt, null, null).first - } catch (x: FatalAstException) { - null + "+", "-", "*", "**", "%", "/" -> { + if (!leftDt.isKnown || !rightDt.isKnown) + InferredTypes.unknown() + else { + try { + InferredTypes.knownFor(commonDatatype( + leftDt.typeOrElse(DataType.BYTE), + rightDt.typeOrElse(DataType.BYTE), + null, null).first) + } catch (x: FatalAstException) { + InferredTypes.unknown() + } } } "&" -> leftDt @@ -111,7 +118,7 @@ class BinaryExpression(var left: Expression, var operator: String, var right: Ex "and", "or", "xor", "<", ">", "<=", ">=", - "==", "!=" -> DataType.UBYTE + "==", "!=" -> InferredTypes.knownFor(DataType.UBYTE) "<<", ">>" -> leftDt else -> throw FatalAstException("resulting datatype check for invalid operator $operator") } @@ -191,17 +198,16 @@ class ArrayIndexedExpression(var identifier: IdentifierReference, override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun referencesIdentifiers(vararg name: String) = identifier.referencesIdentifiers(*name) - override fun inferType(program: Program): DataType? { + override fun inferType(program: Program): InferredTypes.InferredType { val target = identifier.targetStatement(program.namespace) if (target is VarDecl) { return when (target.datatype) { - in NumericDatatypes -> null - in StringDatatypes -> DataType.UBYTE - in ArrayDatatypes -> ArrayElementTypes[target.datatype] - else -> throw FatalAstException("invalid dt") + in StringDatatypes -> InferredTypes.knownFor(DataType.UBYTE) + in ArrayDatatypes -> InferredTypes.knownFor(ArrayElementTypes.getValue(target.datatype)) + else -> InferredTypes.unknown() } } - return null + return InferredTypes.unknown() } override fun toString(): String { @@ -220,7 +226,7 @@ class TypecastExpression(var expression: Expression, var type: DataType, val imp override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun referencesIdentifiers(vararg name: String) = expression.referencesIdentifiers(*name) - override fun inferType(program: Program): DataType? = type + override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(type) override fun constValue(program: Program): NumericLiteralValue? { val cv = expression.constValue(program) ?: return null return cv.cast(type) @@ -243,7 +249,7 @@ data class AddressOf(var identifier: IdentifierReference, override val position: override fun constValue(program: Program): NumericLiteralValue? = null override fun referencesIdentifiers(vararg name: String) = false - override fun inferType(program: Program) = DataType.UWORD + override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(DataType.UWORD) override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) } @@ -259,7 +265,7 @@ class DirectMemoryRead(var addressExpression: Expression, override val position: override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun referencesIdentifiers(vararg name: String) = false - override fun inferType(program: Program): DataType? = DataType.UBYTE + override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(DataType.UBYTE) override fun constValue(program: Program): NumericLiteralValue? = null override fun toString(): String { @@ -315,7 +321,7 @@ class NumericLiteralValue(val type: DataType, // only numerical types allowed override fun toString(): String = "NumericLiteral(${type.name}:$number)" - override fun inferType(program: Program) = type + override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(type) override fun hashCode(): Int = Objects.hash(type, number) @@ -399,7 +405,7 @@ class StructLiteralValue(var values: List, override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun referencesIdentifiers(vararg name: String) = values.any { it.referencesIdentifiers(*name) } - override fun inferType(program: Program) = DataType.STRUCT + override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(DataType.STRUCT) override fun toString(): String { return "struct{ ${values.joinToString(", ")} }" @@ -423,7 +429,7 @@ class StringLiteralValue(val type: DataType, // only string types override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun toString(): String = "'${escape(value)}'" - override fun inferType(program: Program) = type + override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(type) operator fun compareTo(other: StringLiteralValue): Int = value.compareTo(other.value) override fun hashCode(): Int = Objects.hash(value, type) override fun equals(other: Any?): Boolean { @@ -458,7 +464,7 @@ class ArrayLiteralValue(val type: DataType, // only array types override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun toString(): String = "$value" - override fun inferType(program: Program) = type + override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(type) operator fun compareTo(other: ArrayLiteralValue): Int = throw ExpressionError("cannot order compare arrays", position) override fun hashCode(): Int = Objects.hash(value, type) override fun equals(other: Any?): Boolean { @@ -538,18 +544,18 @@ class RangeExpr(var from: Expression, override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun referencesIdentifiers(vararg name: String): Boolean = from.referencesIdentifiers(*name) || to.referencesIdentifiers(*name) - override fun inferType(program: Program): DataType? { + override fun inferType(program: Program): InferredTypes.InferredType { val fromDt=from.inferType(program) val toDt=to.inferType(program) return when { - fromDt==null || toDt==null -> null - fromDt== DataType.UBYTE && toDt== DataType.UBYTE -> DataType.ARRAY_UB - fromDt== DataType.UWORD && toDt== DataType.UWORD -> DataType.ARRAY_UW - fromDt== DataType.STR && toDt== DataType.STR -> DataType.STR - fromDt== DataType.STR_S && toDt== DataType.STR_S -> DataType.STR_S - fromDt== DataType.WORD || toDt== DataType.WORD -> DataType.ARRAY_W - fromDt== DataType.BYTE || toDt== DataType.BYTE -> DataType.ARRAY_B - else -> DataType.ARRAY_UB + !fromDt.isKnown || !toDt.isKnown -> InferredTypes.unknown() + fromDt istype DataType.UBYTE && toDt istype DataType.UBYTE -> InferredTypes.knownFor(DataType.ARRAY_UB) + fromDt istype DataType.UWORD && toDt istype DataType.UWORD -> InferredTypes.knownFor(DataType.ARRAY_UW) + fromDt istype DataType.STR && toDt istype DataType.STR -> InferredTypes.knownFor(DataType.STR) + fromDt istype DataType.STR_S && toDt istype DataType.STR_S -> InferredTypes.knownFor(DataType.STR_S) + fromDt istype DataType.WORD || toDt istype DataType.WORD -> InferredTypes.knownFor(DataType.ARRAY_W) + fromDt istype DataType.BYTE || toDt istype DataType.BYTE -> InferredTypes.knownFor(DataType.ARRAY_B) + else -> InferredTypes.knownFor(DataType.ARRAY_UB) } } override fun toString(): String { @@ -583,17 +589,21 @@ class RangeExpr(var from: Expression, toVal = toLv.number.toInt() } val stepVal = (step as? NumericLiteralValue)?.number?.toInt() ?: 1 - return when { - fromVal <= toVal -> when { - stepVal <= 0 -> IntRange.EMPTY - stepVal == 1 -> fromVal..toVal - else -> fromVal..toVal step stepVal - } - else -> when { - stepVal >= 0 -> IntRange.EMPTY - stepVal == -1 -> fromVal downTo toVal - else -> fromVal downTo toVal step abs(stepVal) - } + return makeRange(fromVal, toVal, stepVal) + } +} + +internal fun makeRange(fromVal: Int, toVal: Int, stepVal: Int): IntProgression { + return when { + fromVal <= toVal -> when { + stepVal <= 0 -> IntRange.EMPTY + stepVal == 1 -> fromVal..toVal + else -> fromVal..toVal step stepVal + } + else -> when { + stepVal >= 0 -> IntRange.EMPTY + stepVal == -1 -> fromVal downTo toVal + else -> fromVal downTo toVal step abs(stepVal) } } } @@ -613,7 +623,7 @@ class RegisterExpr(val register: Register, override val position: Position) : Ex return "RegisterExpr(register=$register, pos=$position)" } - override fun inferType(program: Program) = DataType.UBYTE + override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(DataType.UBYTE) } data class IdentifierReference(val nameInSource: List, override val position: Position) : Expression() { @@ -652,10 +662,10 @@ data class IdentifierReference(val nameInSource: List, override val posi override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun referencesIdentifiers(vararg name: String): Boolean = nameInSource.last() in name - override fun inferType(program: Program): DataType? { + override fun inferType(program: Program): InferredTypes.InferredType { val targetStmt = targetStatement(program.namespace) if(targetStmt is VarDecl) { - return targetStmt.datatype + return InferredTypes.knownFor(targetStmt.datatype) } else { throw FatalAstException("cannot get datatype from identifier reference ${this}, pos=$position") } @@ -705,7 +715,7 @@ class FunctionCall(override var target: IdentifierReference, if(withDatatypeCheck) { val resultDt = this.inferType(program) - if(resultValue==null || resultDt == resultValue.type) + if(resultValue==null || resultDt istype resultValue.type) return resultValue throw FatalAstException("evaluated const expression result value doesn't match expected datatype $resultDt, pos=$position") } else { @@ -726,27 +736,27 @@ class FunctionCall(override var target: IdentifierReference, override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun referencesIdentifiers(vararg name: String): Boolean = target.referencesIdentifiers(*name) || arglist.any{it.referencesIdentifiers(*name)} - override fun inferType(program: Program): DataType? { + override fun inferType(program: Program): InferredTypes.InferredType { val constVal = constValue(program ,false) if(constVal!=null) - return constVal.type - val stmt = target.targetStatement(program.namespace) ?: return null + return InferredTypes.knownFor(constVal.type) + val stmt = target.targetStatement(program.namespace) ?: return InferredTypes.unknown() when (stmt) { is BuiltinFunctionStatementPlaceholder -> { if(target.nameInSource[0] == "set_carry" || target.nameInSource[0]=="set_irqd" || target.nameInSource[0] == "clear_carry" || target.nameInSource[0]=="clear_irqd") { - return null // these have no return value + return InferredTypes.void() // these have no return value } return builtinFunctionReturnType(target.nameInSource[0], this.arglist, program) } is Subroutine -> { if(stmt.returntypes.isEmpty()) - return null // no return value + return InferredTypes.void() // no return value if(stmt.returntypes.size==1) - return stmt.returntypes[0] - return null // has multiple return types... so not a single resulting datatype possible + return InferredTypes.knownFor(stmt.returntypes[0]) + return InferredTypes.unknown() // has multiple return types... so not a single resulting datatype possible } - else -> return null + else -> return InferredTypes.unknown() } } } diff --git a/compiler/src/prog8/ast/expressions/InferredTypes.kt b/compiler/src/prog8/ast/expressions/InferredTypes.kt new file mode 100644 index 000000000..023879566 --- /dev/null +++ b/compiler/src/prog8/ast/expressions/InferredTypes.kt @@ -0,0 +1,50 @@ +package prog8.ast.expressions + +import prog8.ast.base.DataType + +object InferredTypes { + class InferredType private constructor(val isUnknown: Boolean, val isVoid: Boolean, private var datatype: DataType?) { + init { + if(datatype!=null && (isUnknown || isVoid)) + throw IllegalArgumentException("invalid combination of args") + } + + val isKnown = datatype!=null + fun typeOrElse(alternative: DataType) = if(isUnknown || isVoid) alternative else datatype!! + infix fun istype(type: DataType): Boolean = if(isUnknown || isVoid) false else this.datatype==type + + companion object { + fun unknown() = InferredType(isUnknown = true, isVoid = false, datatype = null) + fun void() = InferredType(isUnknown = false, isVoid = true, datatype = null) + fun known(type: DataType) = InferredType(isUnknown = false, isVoid = false, datatype = type) + } + + override fun equals(other: Any?): Boolean { + if(other !is InferredType) + return false + return isVoid==other.isVoid && datatype==other.datatype + } + } + + private val unknownInstance = InferredType.unknown() + private val voidInstance = InferredType.void() + private val knownInstances = mapOf( + DataType.BYTE to InferredType.known(DataType.BYTE), + DataType.BYTE to InferredType.known(DataType.BYTE), + DataType.UWORD to InferredType.known(DataType.UWORD), + DataType.WORD to InferredType.known(DataType.WORD), + DataType.FLOAT to InferredType.known(DataType.FLOAT), + DataType.STR to InferredType.known(DataType.STR), + DataType.STR_S to InferredType.known(DataType.STR_S), + DataType.ARRAY_UB to InferredType.known(DataType.ARRAY_UB), + DataType.ARRAY_B to InferredType.known(DataType.ARRAY_B), + DataType.ARRAY_UW to InferredType.known(DataType.ARRAY_UW), + DataType.ARRAY_W to InferredType.known(DataType.ARRAY_W), + DataType.ARRAY_F to InferredType.known(DataType.ARRAY_F), + DataType.STRUCT to InferredType.known(DataType.STRUCT) + ) + + fun void() = voidInstance + fun unknown() = unknownInstance + fun knownFor(type: DataType) = knownInstances.getValue(type) +} diff --git a/compiler/src/prog8/ast/processing/AstChecker.kt b/compiler/src/prog8/ast/processing/AstChecker.kt index 99173d554..222277bfd 100644 --- a/compiler/src/prog8/ast/processing/AstChecker.kt +++ b/compiler/src/prog8/ast/processing/AstChecker.kt @@ -97,14 +97,18 @@ internal class AstChecker(private val program: Program, } if(expectedReturnValues.size==1 && returnStmt.value!=null) { val valueDt = returnStmt.value!!.inferType(program) - if(expectedReturnValues[0]!=valueDt) - checkResult.add(ExpressionError("type $valueDt of return value doesn't match subroutine's return type", returnStmt.value!!.position)) + if(!valueDt.isKnown) { + checkResult.add(ExpressionError("return value type mismatch", returnStmt.value!!.position)) + } else { + if (expectedReturnValues[0] != valueDt.typeOrElse(DataType.STRUCT)) + checkResult.add(ExpressionError("type $valueDt of return value doesn't match subroutine's return type", returnStmt.value!!.position)) + } } super.visit(returnStmt) } override fun visit(ifStatement: IfStatement) { - if(ifStatement.condition.inferType(program) !in IntegerDatatypes) + if(ifStatement.condition.inferType(program).typeOrElse(DataType.STRUCT) !in IntegerDatatypes) checkResult.add(ExpressionError("condition value should be an integer type", ifStatement.condition.position)) super.visit(ifStatement) } @@ -113,7 +117,7 @@ internal class AstChecker(private val program: Program, if(forLoop.body.containsNoCodeNorVars()) printWarning("for loop body is empty", forLoop.position) - val iterableDt = forLoop.iterable.inferType(program) + val iterableDt = forLoop.iterable.inferType(program).typeOrElse(DataType.BYTE) if(iterableDt !in IterableDatatypes && forLoop.iterable !is RangeExpr) { checkResult.add(ExpressionError("can only loop over an iterable type", forLoop.position)) } else { @@ -328,7 +332,7 @@ internal class AstChecker(private val program: Program, override fun visit(repeatLoop: RepeatLoop) { if(repeatLoop.untilCondition.referencesIdentifiers("A", "X", "Y")) printWarning("using a register in the loop condition is risky (it could get clobbered)", repeatLoop.untilCondition.position) - if(repeatLoop.untilCondition.inferType(program) !in IntegerDatatypes) + if(repeatLoop.untilCondition.inferType(program).typeOrElse(DataType.STRUCT) !in IntegerDatatypes) checkResult.add(ExpressionError("condition value should be an integer type", repeatLoop.untilCondition.position)) super.visit(repeatLoop) } @@ -336,7 +340,7 @@ internal class AstChecker(private val program: Program, override fun visit(whileLoop: WhileLoop) { if(whileLoop.condition.referencesIdentifiers("A", "X", "Y")) printWarning("using a register in the loop condition is risky (it could get clobbered)", whileLoop.condition.position) - if(whileLoop.condition.inferType(program) !in IntegerDatatypes) + if(whileLoop.condition.inferType(program).typeOrElse(DataType.STRUCT) !in IntegerDatatypes) checkResult.add(ExpressionError("condition value should be an integer type", whileLoop.condition.position)) super.visit(whileLoop) } @@ -350,7 +354,8 @@ internal class AstChecker(private val program: Program, if(stmt.returntypes.size>1) checkResult.add(ExpressionError("It's not possible to store the multiple results of this asmsub call; you should use a small block of custom inline assembly for this.", assignment.value.position)) else { - if(stmt.returntypes.single()!=assignment.target.inferType(program, assignment)) { + val idt = assignment.target.inferType(program, assignment) + if(!idt.isKnown || stmt.returntypes.single()!=idt.typeOrElse(DataType.BYTE)) { checkResult.add(ExpressionError("return type mismatch", assignment.value.position)) } } @@ -402,7 +407,7 @@ internal class AstChecker(private val program: Program, } } } - val targetDt = assignTarget.inferType(program, assignment) + val targetDt = assignTarget.inferType(program, assignment).typeOrElse(DataType.STR) if(targetDt in StringDatatypes || targetDt in ArrayDatatypes) checkResult.add(SyntaxError("cannot assign to a string or array type", assignTarget.position)) @@ -412,13 +417,13 @@ internal class AstChecker(private val program: Program, throw FatalAstException("augmented assignment should have been converted into normal assignment") val targetDatatype = assignTarget.inferType(program, assignment) - if (targetDatatype != null) { + if (targetDatatype.isKnown) { val constVal = assignment.value.constValue(program) if (constVal != null) { - checkValueTypeAndRange(targetDatatype, constVal) + checkValueTypeAndRange(targetDatatype.typeOrElse(DataType.BYTE), constVal) } else { - val sourceDatatype: DataType? = assignment.value.inferType(program) - if (sourceDatatype == null) { + val sourceDatatype = assignment.value.inferType(program) + if (!sourceDatatype.isKnown) { if (assignment.value is FunctionCall) { val targetStmt = (assignment.value as FunctionCall).target.targetStatement(program.namespace) if (targetStmt != null) @@ -426,7 +431,8 @@ internal class AstChecker(private val program: Program, } else checkResult.add(ExpressionError("assignment value is invalid or has no proper datatype", assignment.value.position)) } else { - checkAssignmentCompatible(targetDatatype, assignTarget, sourceDatatype, assignment.value, assignment.position) + checkAssignmentCompatible(targetDatatype.typeOrElse(DataType.BYTE), assignTarget, + sourceDatatype.typeOrElse(DataType.BYTE), assignment.value, assignment.position) } } } @@ -701,7 +707,7 @@ internal class AstChecker(private val program: Program, override fun visit(expr: PrefixExpression) { if(expr.operator=="-") { - val dt = expr.inferType(program) + val dt = expr.inferType(program).typeOrElse(DataType.STRUCT) if (dt != DataType.BYTE && dt != DataType.WORD && dt != DataType.FLOAT) { checkResult.add(ExpressionError("can only take negative of a signed number type", expr.position)) } @@ -710,8 +716,13 @@ internal class AstChecker(private val program: Program, } override fun visit(expr: BinaryExpression) { - val leftDt = expr.left.inferType(program) - val rightDt = expr.right.inferType(program) + val leftIDt = expr.left.inferType(program) + val rightIDt = expr.right.inferType(program) + if(!leftIDt.isKnown || !rightIDt.isKnown) { + throw FatalAstException("can't determine datatype of both expression operands $expr") + } + val leftDt = leftIDt.typeOrElse(DataType.STRUCT) + val rightDt = rightIDt.typeOrElse(DataType.STRUCT) when(expr.operator){ "/", "%" -> { @@ -842,30 +853,30 @@ internal class AstChecker(private val program: Program, val paramTypesForAddressOf = PassByReferenceDatatypes + DataType.UWORD for (arg in args.withIndex().zip(func.parameters)) { val argDt=arg.first.value.inferType(program) - if (argDt != null - && !(argDt isAssignableTo arg.second.possibleDatatypes) - && (argDt != DataType.UWORD || arg.second.possibleDatatypes.intersect(paramTypesForAddressOf).isEmpty())) { + if (argDt.isKnown + && !(argDt.typeOrElse(DataType.STRUCT) isAssignableTo arg.second.possibleDatatypes) + && (argDt.typeOrElse(DataType.STRUCT) != DataType.UWORD || arg.second.possibleDatatypes.intersect(paramTypesForAddressOf).isEmpty())) { checkResult.add(ExpressionError("builtin function '${target.name}' argument ${arg.first.index + 1} has invalid type $argDt, expected ${arg.second.possibleDatatypes}", position)) } } if(target.name=="swap") { // swap() is a bit weird because this one is translated into a operations directly, instead of being a function call - val dt1 = args[0].inferType(program)!! - val dt2 = args[1].inferType(program)!! + val dt1 = args[0].inferType(program) + val dt2 = args[1].inferType(program) if (dt1 != dt2) checkResult.add(ExpressionError("swap requires 2 args of identical type", position)) else if (args[0].constValue(program) != null || args[1].constValue(program) != null) checkResult.add(ExpressionError("swap requires 2 variables, not constant value(s)", position)) else if(args[0] isSameAs args[1]) checkResult.add(ExpressionError("swap should have 2 different args", position)) - else if(dt1 !in NumericDatatypes) + else if(dt1.typeOrElse(DataType.STRUCT) !in NumericDatatypes) checkResult.add(ExpressionError("swap requires args of numerical type", position)) } else if(target.name=="all" || target.name=="any") { if((args[0] as? AddressOf)?.identifier?.targetVarDecl(program.namespace)?.datatype in StringDatatypes) { checkResult.add(ExpressionError("any/all on a string is useless (is always true unless the string is empty)", position)) } - if(args[0].inferType(program) in StringDatatypes) { + if(args[0].inferType(program).typeOrElse(DataType.STR) in StringDatatypes) { checkResult.add(ExpressionError("any/all on a string is useless (is always true unless the string is empty)", position)) } } @@ -875,8 +886,12 @@ internal class AstChecker(private val program: Program, checkResult.add(SyntaxError("invalid number of arguments", position)) else { for (arg in args.withIndex().zip(target.parameters)) { - val argDt = arg.first.value.inferType(program) - if(argDt!=null && !(argDt isAssignableTo arg.second.type)) { + val argIDt = arg.first.value.inferType(program) + if(!argIDt.isKnown) { + throw FatalAstException("can't determine arg dt ${arg.first.value}") + } + val argDt=argIDt.typeOrElse(DataType.STRUCT) + if(!(argDt isAssignableTo arg.second.type)) { // for asm subroutines having STR param it's okay to provide a UWORD (address value) if(!(target.isAsmSubroutine && arg.second.type in StringDatatypes && argDt == DataType.UWORD)) checkResult.add(ExpressionError("subroutine '${target.name}' argument ${arg.first.index + 1} has invalid type $argDt, expected ${arg.second.type}", position)) @@ -960,7 +975,7 @@ internal class AstChecker(private val program: Program, checkResult.add(SyntaxError("indexing requires a variable to act upon", arrayIndexedExpression.position)) // check index value 0..255 - val dtx = arrayIndexedExpression.arrayspec.index.inferType(program) + val dtx = arrayIndexedExpression.arrayspec.index.inferType(program).typeOrElse(DataType.STRUCT) if(dtx!= DataType.UBYTE && dtx!= DataType.BYTE) checkResult.add(SyntaxError("array indexing is limited to byte size 0..255", arrayIndexedExpression.position)) @@ -968,7 +983,7 @@ internal class AstChecker(private val program: Program, } override fun visit(whenStatement: WhenStatement) { - val conditionType = whenStatement.condition.inferType(program) + val conditionType = whenStatement.condition.inferType(program).typeOrElse(DataType.STRUCT) if(conditionType !in IntegerDatatypes) checkResult.add(SyntaxError("when condition must be an integer value", whenStatement.position)) val choiceValues = whenStatement.choiceValues(program) @@ -987,12 +1002,14 @@ internal class AstChecker(private val program: Program, val whenStmt = whenChoice.parent as WhenStatement if(whenChoice.values!=null) { val conditionType = whenStmt.condition.inferType(program) + if(!conditionType.isKnown) + throw FatalAstException("can't determine when choice datatype $whenChoice") val constvalues = whenChoice.values!!.map { it.constValue(program) } for(constvalue in constvalues) { when { constvalue == null -> checkResult.add(SyntaxError("choice value must be a constant", whenChoice.position)) constvalue.type !in IntegerDatatypes -> checkResult.add(SyntaxError("choice value must be a byte or word", whenChoice.position)) - constvalue.type != conditionType -> checkResult.add(SyntaxError("choice value datatype differs from condition value", whenChoice.position)) + constvalue.type != conditionType.typeOrElse(DataType.STRUCT) -> checkResult.add(SyntaxError("choice value datatype differs from condition value", whenChoice.position)) } } } else { @@ -1129,8 +1146,8 @@ internal class AstChecker(private val program: Program, return err("number of values is not the same as the number of members in the struct") for(elt in value.value.zip(struct.statements)) { val vardecl = elt.second as VarDecl - val valuetype = elt.first.inferType(program)!! - if (!(valuetype isAssignableTo vardecl.datatype)) { + val valuetype = elt.first.inferType(program) + if (!valuetype.isKnown || !(valuetype.typeOrElse(DataType.STRUCT) isAssignableTo vardecl.datatype)) { checkResult.add(ExpressionError("invalid struct member init value type $valuetype, expected ${vardecl.datatype}", elt.first.position)) return false } diff --git a/compiler/src/prog8/ast/processing/AstIdentifiersChecker.kt b/compiler/src/prog8/ast/processing/AstIdentifiersChecker.kt index c32ced96f..dd671899c 100644 --- a/compiler/src/prog8/ast/processing/AstIdentifiersChecker.kt +++ b/compiler/src/prog8/ast/processing/AstIdentifiersChecker.kt @@ -285,15 +285,16 @@ internal class AstIdentifiersChecker(private val program: Program) : IAstModifyi } private fun determineArrayDt(array: Array): DataType { - val datatypesInArray = array.mapNotNull { it.inferType(program) } - if(datatypesInArray.isEmpty()) + val datatypesInArray = array.map { it.inferType(program) } + if(datatypesInArray.isEmpty() || datatypesInArray.any { !it.isKnown }) throw IllegalArgumentException("can't determine type of empty array") + val dts = datatypesInArray.map { it.typeOrElse(DataType.STRUCT) } return when { - DataType.FLOAT in datatypesInArray -> DataType.ARRAY_F - DataType.WORD in datatypesInArray -> DataType.ARRAY_W - DataType.UWORD in datatypesInArray -> DataType.ARRAY_UW - DataType.BYTE in datatypesInArray -> DataType.ARRAY_B - DataType.UBYTE in datatypesInArray -> DataType.ARRAY_UB + DataType.FLOAT in dts -> DataType.ARRAY_F + DataType.WORD in dts -> DataType.ARRAY_W + DataType.UWORD in dts -> DataType.ARRAY_UW + DataType.BYTE in dts -> DataType.ARRAY_B + DataType.UBYTE in dts -> DataType.ARRAY_UB else -> throw IllegalArgumentException("can't determine type of array") } } @@ -351,13 +352,15 @@ internal class AstIdentifiersChecker(private val program: Program) : IAstModifyi if(constvalue!=null) { if (expr.operator == "*") { // repeat a string a number of times - return StringLiteralValue(string.inferType(program), + val idt = string.inferType(program) + return StringLiteralValue(idt.typeOrElse(DataType.STR), string.value.repeat(constvalue.number.toInt()), null, expr.position) } } if(expr.operator == "+" && operand is StringLiteralValue) { // concatenate two strings - return StringLiteralValue(string.inferType(program), + val idt = string.inferType(program) + return StringLiteralValue(idt.typeOrElse(DataType.STR), "${string.value}${operand.value}", null, expr.position) } return expr diff --git a/compiler/src/prog8/ast/processing/StatementReorderer.kt b/compiler/src/prog8/ast/processing/StatementReorderer.kt index 499eb8bed..c4a04fc4e 100644 --- a/compiler/src/prog8/ast/processing/StatementReorderer.kt +++ b/compiler/src/prog8/ast/processing/StatementReorderer.kt @@ -192,9 +192,9 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi return expr2 val leftDt = expr2.left.inferType(program) val rightDt = expr2.right.inferType(program) - if(leftDt!=null && rightDt!=null && leftDt!=rightDt) { + if(leftDt.isKnown && rightDt.isKnown && leftDt!=rightDt) { // determine common datatype and add typecast as required to make left and right equal types - val (commonDt, toFix) = BinaryExpression.commonDatatype(leftDt, rightDt, expr2.left, expr2.right) + val (commonDt, toFix) = BinaryExpression.commonDatatype(leftDt.typeOrElse(DataType.STRUCT), rightDt.typeOrElse(DataType.STRUCT), expr2.left, expr2.right) if(toFix!=null) { when { toFix===expr2.left -> { @@ -218,32 +218,35 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi return assg // see if a typecast is needed to convert the value's type into the proper target type - val valuetype = assg.value.inferType(program) - val targettype = assg.target.inferType(program, assg) - if(targettype!=null && valuetype!=null) { - if(valuetype!=targettype) { + val valueItype = assg.value.inferType(program) + val targetItype = assg.target.inferType(program, assg) + + if(targetItype.isKnown && valueItype.isKnown) { + val targettype = targetItype.typeOrElse(DataType.STRUCT) + val valuetype = valueItype.typeOrElse(DataType.STRUCT) + if (valuetype != targettype) { if (valuetype isAssignableTo targettype) { assg.value = TypecastExpression(assg.value, targettype, true, assg.value.position) assg.value.linkParents(assg) } // if they're not assignable, we'll get a proper error later from the AstChecker } - } - // struct assignments will be flattened (if it's not a struct literal) - if(valuetype==DataType.STRUCT && targettype==DataType.STRUCT) { - if(assg.value is StructLiteralValue) - return assg // do NOT flatten it at this point!! (the compiler will take care if it, later, if needed) + // struct assignments will be flattened (if it's not a struct literal) + if (valuetype == DataType.STRUCT && targettype == DataType.STRUCT) { + if (assg.value is StructLiteralValue) + return assg // do NOT flatten it at this point!! (the compiler will take care if it, later, if needed) - val assignments = flattenStructAssignmentFromIdentifier(assg, program) // 'structvar1 = structvar2' - return if(assignments.isEmpty()) { - // something went wrong (probably incompatible struct types) - // we'll get an error later from the AstChecker - assg - } else { - val scope = AnonymousScope(assignments.toMutableList(), assg.position) - scope.linkParents(assg.parent) - scope + val assignments = flattenStructAssignmentFromIdentifier(assg, program) // 'structvar1 = structvar2' + return if (assignments.isEmpty()) { + // something went wrong (probably incompatible struct types) + // we'll get an error later from the AstChecker + assg + } else { + val scope = AnonymousScope(assignments.toMutableList(), assg.position) + scope.linkParents(assg.parent) + scope + } } } @@ -283,8 +286,9 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi when(val sub = call.target.targetStatement(scope)) { is Subroutine -> { for(arg in sub.parameters.zip(call.arglist.withIndex())) { - val argtype = arg.second.value.inferType(program) - if(argtype!=null) { + val argItype = arg.second.value.inferType(program) + if(argItype.isKnown) { + val argtype = argItype.typeOrElse(DataType.STRUCT) val requiredType = arg.first.type if (requiredType != argtype) { if (argtype isAssignableTo requiredType) { @@ -302,8 +306,9 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi if(func.pure) { // non-pure functions don't get automatic typecasts because sometimes they act directly on their parameters for (arg in func.parameters.zip(call.arglist.withIndex())) { - val argtype = arg.second.value.inferType(program) - if (argtype != null) { + val argItype = arg.second.value.inferType(program) + if (argItype.isKnown) { + val argtype = argItype.typeOrElse(DataType.STRUCT) if (arg.first.possibleDatatypes.any { argtype == it }) continue for (possibleType in arg.first.possibleDatatypes) { @@ -334,7 +339,7 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi override fun visit(memread: DirectMemoryRead): Expression { // make sure the memory address is an uword val dt = memread.addressExpression.inferType(program) - if(dt!=DataType.UWORD) { + if(dt.isKnown && dt.typeOrElse(DataType.UWORD)!=DataType.UWORD) { val literaladdr = memread.addressExpression as? NumericLiteralValue if(literaladdr!=null) { memread.addressExpression = literaladdr.cast(DataType.UWORD) @@ -348,7 +353,7 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi override fun visit(memwrite: DirectMemoryWrite) { val dt = memwrite.addressExpression.inferType(program) - if(dt!=DataType.UWORD) { + if(dt.isKnown && dt.typeOrElse(DataType.UWORD)!=DataType.UWORD) { val literaladdr = memwrite.addressExpression as? NumericLiteralValue if(literaladdr!=null) { memwrite.addressExpression = literaladdr.cast(DataType.UWORD) @@ -391,7 +396,7 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi structLv.values = struct.statements.zip(structLv.values).map { val memberDt = (it.first as VarDecl).datatype val valueDt = it.second.inferType(program) - if (valueDt != memberDt) + if (valueDt.typeOrElse(memberDt) != memberDt) TypecastExpression(it.second, memberDt, true, it.second.position) else it.second diff --git a/compiler/src/prog8/ast/processing/VarInitValueAndAddressOfCreator.kt b/compiler/src/prog8/ast/processing/VarInitValueAndAddressOfCreator.kt index 9bd8137a3..0c1161375 100644 --- a/compiler/src/prog8/ast/processing/VarInitValueAndAddressOfCreator.kt +++ b/compiler/src/prog8/ast/processing/VarInitValueAndAddressOfCreator.kt @@ -135,7 +135,7 @@ internal class VarInitValueAndAddressOfCreator(private val program: Program): IA for(arg in args.withIndex().zip(signature.parameters)) { val argvalue = arg.first.value val argDt = argvalue.inferType(program) - if(argDt in PassByReferenceDatatypes && DataType.UWORD in arg.second.possibleDatatypes) { + if(argDt.typeOrElse(DataType.UBYTE) in PassByReferenceDatatypes && DataType.UWORD in arg.second.possibleDatatypes) { if(argvalue !is IdentifierReference) throw CompilerException("pass-by-reference parameter isn't an identifier? $argvalue") val addrOf = AddressOf(argvalue, argvalue.position) diff --git a/compiler/src/prog8/ast/statements/AstStatements.kt b/compiler/src/prog8/ast/statements/AstStatements.kt index d045841f0..4499b9b3c 100644 --- a/compiler/src/prog8/ast/statements/AstStatements.kt +++ b/compiler/src/prog8/ast/statements/AstStatements.kt @@ -368,25 +368,23 @@ data class AssignTarget(val register: Register?, } } - fun inferType(program: Program, stmt: Statement): DataType? { + fun inferType(program: Program, stmt: Statement): InferredTypes.InferredType { if(register!=null) - return DataType.UBYTE + return InferredTypes.knownFor(DataType.UBYTE) if(identifier!=null) { - val symbol = program.namespace.lookup(identifier!!.nameInSource, stmt) ?: return null - if (symbol is VarDecl) return symbol.datatype + val symbol = program.namespace.lookup(identifier!!.nameInSource, stmt) ?: return InferredTypes.unknown() + if (symbol is VarDecl) return InferredTypes.knownFor(symbol.datatype) } if(arrayindexed!=null) { - val dt = arrayindexed!!.inferType(program) - if(dt!=null) - return dt + return arrayindexed!!.inferType(program) } if(memoryAddress!=null) - return DataType.UBYTE + return InferredTypes.knownFor(DataType.UBYTE) - return null + return InferredTypes.unknown() } infix fun isSameAs(value: Expression): Boolean { diff --git a/compiler/src/prog8/compiler/target/c64/codegen2/AsmGen2.kt b/compiler/src/prog8/compiler/target/c64/codegen2/AsmGen2.kt index 415d67ee6..0b549dd42 100644 --- a/compiler/src/prog8/compiler/target/c64/codegen2/AsmGen2.kt +++ b/compiler/src/prog8/compiler/target/c64/codegen2/AsmGen2.kt @@ -663,7 +663,10 @@ internal class AsmGen2(val program: Program, } fun translateSubroutineArgument(parameter: IndexedValue, value: Expression, sub: Subroutine) { - val sourceDt = value.inferType(program)!! + val sourceIDt = value.inferType(program) + if(!sourceIDt.isKnown) + throw AssemblyError("arg type unknown") + val sourceDt = sourceIDt.typeOrElse(DataType.STRUCT) if(!argumentTypeCompatible(sourceDt, parameter.value.type)) throw AssemblyError("argument type incompatible") if(sub.asmParameterRegisters.isEmpty()) { @@ -849,7 +852,7 @@ internal class AsmGen2(val program: Program, private fun translate(stmt: IfStatement) { translateExpression(stmt.condition) - translateTestStack(stmt.condition.inferType(program)!!) + translateTestStack(stmt.condition.inferType(program).typeOrElse(DataType.STRUCT)) val elseLabel = makeLabel("if_else") val endLabel = makeLabel("if_end") out(" beq $elseLabel") @@ -877,7 +880,10 @@ internal class AsmGen2(val program: Program, out(whileLabel) // TODO optimize for the simple cases, can we avoid stack use? translateExpression(stmt.condition) - if(stmt.condition.inferType(program) in ByteDatatypes) { + val conditionDt = stmt.condition.inferType(program) + if(!conditionDt.isKnown) + throw AssemblyError("unknown condition dt") + if(conditionDt.typeOrElse(DataType.BYTE) in ByteDatatypes) { out(" inx | lda $ESTACK_LO_HEX,x | beq $endLabel") } else { out(""" @@ -904,7 +910,10 @@ internal class AsmGen2(val program: Program, // TODO optimize this for the simple cases, can we avoid stack use? translate(stmt.body) translateExpression(stmt.untilCondition) - if(stmt.untilCondition.inferType(program) in ByteDatatypes) { + val conditionDt = stmt.untilCondition.inferType(program) + if(!conditionDt.isKnown) + throw AssemblyError("unknown condition dt") + if(conditionDt.typeOrElse(DataType.BYTE) in ByteDatatypes) { out(" inx | lda $ESTACK_LO_HEX,x | beq $repeatLabel") } else { out(""" @@ -924,8 +933,10 @@ internal class AsmGen2(val program: Program, translateExpression(stmt.condition) val endLabel = makeLabel("choice_end") val choiceBlocks = mutableListOf>() - val conditionDt = stmt.condition.inferType(program)!! - if(conditionDt in ByteDatatypes) + val conditionDt = stmt.condition.inferType(program) + if(!conditionDt.isKnown) + throw AssemblyError("unknown condition dt") + if(conditionDt.typeOrElse(DataType.BYTE) in ByteDatatypes) out(" inx | lda $ESTACK_LO_HEX,x") else out(" inx | lda $ESTACK_LO_HEX,x | ldy $ESTACK_HI_HEX,x") @@ -939,7 +950,7 @@ internal class AsmGen2(val program: Program, choiceBlocks.add(Pair(choiceLabel, choice.statements)) for (cv in choice.values!!) { val value = (cv as NumericLiteralValue).number.toInt() - if(conditionDt in ByteDatatypes) { + if(conditionDt.typeOrElse(DataType.BYTE) in ByteDatatypes) { out(" cmp #${value.toHex()} | beq $choiceLabel") } else { out(""" @@ -1026,20 +1037,22 @@ internal class AsmGen2(val program: Program, } private fun translate(stmt: ForLoop) { - val iterableDt = stmt.iterable.inferType(program)!! + val iterableDt = stmt.iterable.inferType(program) + if(!iterableDt.isKnown) + throw AssemblyError("can't determine iterable dt") when(stmt.iterable) { is RangeExpr -> { val range = (stmt.iterable as RangeExpr).toConstantIntegerRange() if(range==null) { - translateForOverNonconstRange(stmt, iterableDt, stmt.iterable as RangeExpr) + translateForOverNonconstRange(stmt, iterableDt.typeOrElse(DataType.STRUCT), stmt.iterable as RangeExpr) } else { if (range.isEmpty()) throw AssemblyError("empty range") - translateForOverConstRange(stmt, iterableDt, range) + translateForOverConstRange(stmt, iterableDt.typeOrElse(DataType.STRUCT), range) } } is IdentifierReference -> { - translateForOverIterableVar(stmt, iterableDt, stmt.iterable as IdentifierReference) + translateForOverIterableVar(stmt, iterableDt.typeOrElse(DataType.STRUCT), stmt.iterable as IdentifierReference) } else -> throw AssemblyError("can't iterate over ${stmt.iterable}") } @@ -1524,7 +1537,7 @@ $endLabel""") } targetIdent!=null -> { val what = asmIdentifierName(targetIdent) - val dt = stmt.target.inferType(program, stmt) + val dt = stmt.target.inferType(program, stmt).typeOrElse(DataType.STRUCT) when (dt) { in ByteDatatypes -> out(if (incr) " inc $what" else " dec $what") in WordDatatypes -> { @@ -1562,7 +1575,7 @@ $endLabel""") targetArrayIdx!=null -> { val index = targetArrayIdx.arrayspec.index val what = asmIdentifierName(targetArrayIdx.identifier) - val arrayDt = targetArrayIdx.identifier.inferType(program)!! + val arrayDt = targetArrayIdx.identifier.inferType(program).typeOrElse(DataType.STRUCT) val elementDt = ArrayElementTypes.getValue(arrayDt) when(index) { is NumericLiteralValue -> { @@ -1676,7 +1689,7 @@ $endLabel""") assignFromRegister(assign.target, (assign.value as RegisterExpr).register) } is IdentifierReference -> { - val type = assign.target.inferType(program, assign)!! + val type = assign.target.inferType(program, assign).typeOrElse(DataType.STRUCT) when(type) { DataType.UBYTE, DataType.BYTE -> assignFromByteVariable(assign.target, assign.value as IdentifierReference) DataType.UWORD, DataType.WORD -> assignFromWordVariable(assign.target, assign.value as IdentifierReference) @@ -1743,8 +1756,9 @@ $endLabel""") val cast = assign.value as TypecastExpression val sourceType = cast.expression.inferType(program) val targetType = assign.target.inferType(program, assign) - if((sourceType in ByteDatatypes && targetType in ByteDatatypes) || - (sourceType in WordDatatypes && targetType in WordDatatypes)) { + if(sourceType.isKnown && targetType.isKnown && + (sourceType.typeOrElse(DataType.STRUCT) in ByteDatatypes && targetType.typeOrElse(DataType.STRUCT) in ByteDatatypes) || + (sourceType.typeOrElse(DataType.STRUCT) in WordDatatypes && targetType.typeOrElse(DataType.STRUCT) in WordDatatypes)) { // no need for a type cast assign.value = cast.expression translate(assign) @@ -1787,7 +1801,7 @@ $endLabel""") private fun translateExpression(expr: TypecastExpression) { translateExpression(expr.expression) - when(expr.expression.inferType(program)!!) { + when(expr.expression.inferType(program).typeOrElse(DataType.STRUCT)) { DataType.UBYTE -> { when(expr.type) { DataType.UBYTE, DataType.BYTE -> {} @@ -1959,7 +1973,7 @@ $endLabel""") private fun translateExpression(expr: IdentifierReference) { val varname = asmIdentifierName(expr) - when(expr.inferType(program)!!) { + when(expr.inferType(program).typeOrElse(DataType.STRUCT)) { DataType.UBYTE, DataType.BYTE -> { out(" lda $varname | sta $ESTACK_LO_HEX,x | dex") } @@ -1979,9 +1993,13 @@ $endLabel""") private val powerOfTwos = setOf(0,1,2,4,8,16,32,64,128,256) private fun translateExpression(expr: BinaryExpression) { - val leftDt = expr.left.inferType(program)!! - val rightDt = expr.right.inferType(program)!! + val leftIDt = expr.left.inferType(program) + val rightIDt = expr.right.inferType(program) + if(!leftIDt.isKnown || !rightIDt.isKnown) + throw AssemblyError("can't infer type of both expression operands") + val leftDt = leftIDt.typeOrElse(DataType.STRUCT) + val rightDt = rightIDt.typeOrElse(DataType.STRUCT) // see if we can apply some optimized routines when(expr.operator) { ">>" -> { @@ -2075,7 +2093,7 @@ $endLabel""") private fun translateExpression(expr: PrefixExpression) { translateExpression(expr.expression) - val type = expr.inferType(program) + val type = expr.inferType(program).typeOrElse(DataType.STRUCT) when(expr.operator) { "+" -> {} "-" -> { @@ -2208,7 +2226,7 @@ $endLabel""") } targetIdent!=null -> { val targetName = asmIdentifierName(targetIdent) - val targetDt = targetIdent.inferType(program)!! + val targetDt = targetIdent.inferType(program).typeOrElse(DataType.STRUCT) when(targetDt) { DataType.UBYTE, DataType.BYTE -> { out(" inx | lda $ESTACK_LO_HEX,x | sta $targetName") @@ -2305,10 +2323,10 @@ $endLabel""") targetArrayIdx!=null -> { val index = targetArrayIdx.arrayspec.index val targetName = asmIdentifierName(targetArrayIdx.identifier) - val arrayDt = targetArrayIdx.identifier.inferType(program)!! out(" lda $sourceName | sta $ESTACK_LO_HEX,x | lda $sourceName+1 | sta $ESTACK_HI_HEX,x | dex") translateExpression(index) out(" inx | lda $ESTACK_LO_HEX,x") + val arrayDt = targetArrayIdx.identifier.inferType(program).typeOrElse(DataType.STRUCT) popAndWriteArrayvalueWithIndexA(arrayDt, targetName) } else -> TODO("assign wordvar to $target") @@ -2364,7 +2382,7 @@ $endLabel""") targetArrayIdx!=null -> { val index = targetArrayIdx.arrayspec.index val targetName = asmIdentifierName(targetArrayIdx.identifier) - val arrayDt = targetArrayIdx.identifier.inferType(program)!! + val arrayDt = targetArrayIdx.identifier.inferType(program).typeOrElse(DataType.STRUCT) out(" lda $sourceName | sta $ESTACK_LO_HEX,x | dex") translateExpression(index) out(" inx | lda $ESTACK_LO_HEX,x") diff --git a/compiler/src/prog8/compiler/target/c64/codegen2/BuiltinFunctionsAsmGen.kt b/compiler/src/prog8/compiler/target/c64/codegen2/BuiltinFunctionsAsmGen.kt index d8395d235..5d030c590 100644 --- a/compiler/src/prog8/compiler/target/c64/codegen2/BuiltinFunctionsAsmGen.kt +++ b/compiler/src/prog8/compiler/target/c64/codegen2/BuiltinFunctionsAsmGen.kt @@ -43,7 +43,7 @@ internal class BuiltinFunctionsAsmGen(private val program: Program, when(functionName) { "msb" -> { val arg = fcall.arglist.single() - if(arg.inferType(program) !in WordDatatypes) + if(arg.inferType(program).typeOrElse(DataType.STRUCT) !in WordDatatypes) throw AssemblyError("msb required word argument") if(arg is NumericLiteralValue) throw AssemblyError("should have been const-folded") @@ -61,8 +61,8 @@ internal class BuiltinFunctionsAsmGen(private val program: Program, } "abs" -> { translateFunctionArguments(fcall.arglist, func) - val dt = fcall.arglist.single().inferType(program)!! - when (dt) { + val dt = fcall.arglist.single().inferType(program) + when (dt.typeOrElse(DataType.STRUCT)) { in ByteDatatypes -> asmgen.out(" jsr prog8_lib.abs_b") in WordDatatypes -> asmgen.out(" jsr prog8_lib.abs_w") DataType.FLOAT -> asmgen.out(" jsr c64flt.abs_f") @@ -86,8 +86,8 @@ internal class BuiltinFunctionsAsmGen(private val program: Program, } "min", "max", "sum" -> { outputPushAddressAndLenghtOfArray(fcall.arglist[0]) - val dt = fcall.arglist.single().inferType(program)!! - when(dt) { + val dt = fcall.arglist.single().inferType(program) + when(dt.typeOrElse(DataType.STRUCT)) { DataType.ARRAY_UB, DataType.STR_S, DataType.STR -> asmgen.out(" jsr prog8_lib.func_${functionName}_ub") DataType.ARRAY_B -> asmgen.out(" jsr prog8_lib.func_${functionName}_b") DataType.ARRAY_UW -> asmgen.out(" jsr prog8_lib.func_${functionName}_uw") @@ -98,8 +98,8 @@ internal class BuiltinFunctionsAsmGen(private val program: Program, } "any", "all" -> { outputPushAddressAndLenghtOfArray(fcall.arglist[0]) - val dt = fcall.arglist.single().inferType(program)!! - when(dt) { + val dt = fcall.arglist.single().inferType(program) + when(dt.typeOrElse(DataType.STRUCT)) { DataType.ARRAY_B, DataType.ARRAY_UB, DataType.STR_S, DataType.STR -> asmgen.out(" jsr prog8_lib.func_${functionName}_b") DataType.ARRAY_UW, DataType.ARRAY_W -> asmgen.out(" jsr prog8_lib.func_${functionName}_w") DataType.ARRAY_F -> asmgen.out(" jsr c64flt.func_${functionName}_f") @@ -134,8 +134,8 @@ internal class BuiltinFunctionsAsmGen(private val program: Program, "lsl" -> { // in-place val what = fcall.arglist.single() - val dt = what.inferType(program)!! - when(dt) { + val dt = what.inferType(program) + when(dt.typeOrElse(DataType.STRUCT)) { in ByteDatatypes -> { when(what) { is RegisterExpr -> { @@ -168,8 +168,8 @@ internal class BuiltinFunctionsAsmGen(private val program: Program, "lsr" -> { // in-place val what = fcall.arglist.single() - val dt = what.inferType(program)!! - when(dt) { + val dt = what.inferType(program) + when(dt.typeOrElse(DataType.STRUCT)) { DataType.UBYTE -> { when(what) { is RegisterExpr -> { @@ -208,8 +208,8 @@ internal class BuiltinFunctionsAsmGen(private val program: Program, "rol" -> { // in-place val what = fcall.arglist.single() - val dt = what.inferType(program)!! - when(dt) { + val dt = what.inferType(program) + when(dt.typeOrElse(DataType.STRUCT)) { DataType.UBYTE -> { TODO("rol ubyte") } @@ -222,8 +222,8 @@ internal class BuiltinFunctionsAsmGen(private val program: Program, "rol2" -> { // in-place val what = fcall.arglist.single() - val dt = what.inferType(program)!! - when(dt) { + val dt = what.inferType(program) + when(dt.typeOrElse(DataType.STRUCT)) { DataType.UBYTE -> { TODO("rol2 ubyte") } @@ -236,8 +236,8 @@ internal class BuiltinFunctionsAsmGen(private val program: Program, "ror" -> { // in-place val what = fcall.arglist.single() - val dt = what.inferType(program)!! - when(dt) { + val dt = what.inferType(program) + when(dt.typeOrElse(DataType.STRUCT)) { DataType.UBYTE -> { TODO("ror ubyte") } @@ -250,8 +250,8 @@ internal class BuiltinFunctionsAsmGen(private val program: Program, "ror2" -> { // in-place val what = fcall.arglist.single() - val dt = what.inferType(program)!! - when(dt) { + val dt = what.inferType(program) + when(dt.typeOrElse(DataType.STRUCT)) { DataType.UBYTE -> { TODO("ror2 ubyte") } diff --git a/compiler/src/prog8/functions/BuiltinFunctions.kt b/compiler/src/prog8/functions/BuiltinFunctions.kt index 753e05141..e863e59cb 100644 --- a/compiler/src/prog8/functions/BuiltinFunctions.kt +++ b/compiler/src/prog8/functions/BuiltinFunctions.kt @@ -112,25 +112,29 @@ val BuiltinFunctions = mapOf( ) -fun builtinFunctionReturnType(function: String, args: List, program: Program): DataType? { +fun builtinFunctionReturnType(function: String, args: List, program: Program): InferredTypes.InferredType { fun datatypeFromIterableArg(arglist: Expression): DataType { if(arglist is ArrayLiteralValue) { if(arglist.type== DataType.ARRAY_UB || arglist.type== DataType.ARRAY_UW || arglist.type== DataType.ARRAY_F) { val dt = arglist.value.map {it.inferType(program)} - if(dt.any { it!= DataType.UBYTE && it!= DataType.UWORD && it!= DataType.FLOAT}) { + if(dt.any { !(it istype DataType.UBYTE) && !(it istype DataType.UWORD) && !(it istype DataType.FLOAT)}) { throw FatalAstException("fuction $function only accepts arraysize of numeric values") } - if(dt.any { it== DataType.FLOAT }) return DataType.FLOAT - if(dt.any { it== DataType.UWORD }) return DataType.UWORD + if(dt.any { it istype DataType.FLOAT }) return DataType.FLOAT + if(dt.any { it istype DataType.UWORD }) return DataType.UWORD return DataType.UBYTE } } if(arglist is IdentifierReference) { - return when(val dt = arglist.inferType(program)) { - in NumericDatatypes -> dt!! - in StringDatatypes -> dt!! - in ArrayDatatypes -> ArrayElementTypes.getValue(dt!!) + val idt = arglist.inferType(program) + if(!idt.isKnown) + throw FatalAstException("couldn't determine type of iterable $arglist") + val dt = idt.typeOrElse(DataType.STRUCT) + return when(dt) { + in NumericDatatypes -> dt + in StringDatatypes -> dt + in ArrayDatatypes -> ArrayElementTypes.getValue(dt) else -> throw FatalAstException("function '$function' requires one argument which is an iterable") } } @@ -139,43 +143,43 @@ fun builtinFunctionReturnType(function: String, args: List, program: val func = BuiltinFunctions.getValue(function) if(func.returntype!=null) - return func.returntype + return InferredTypes.knownFor(func.returntype) // function has return values, but the return type depends on the arguments return when (function) { "abs" -> { val dt = args.single().inferType(program) - if(dt in NumericDatatypes) + if(dt.typeOrElse(DataType.STRUCT) in NumericDatatypes) return dt else throw FatalAstException("weird datatype passed to abs $dt") } "max", "min" -> { when(val dt = datatypeFromIterableArg(args.single())) { - in NumericDatatypes -> dt - in StringDatatypes -> DataType.UBYTE - in ArrayDatatypes -> ArrayElementTypes.getValue(dt) - else -> null + in NumericDatatypes -> InferredTypes.knownFor(dt) + in StringDatatypes -> InferredTypes.knownFor(DataType.UBYTE) + in ArrayDatatypes -> InferredTypes.knownFor(ArrayElementTypes.getValue(dt)) + else -> InferredTypes.unknown() } } "sum" -> { when(datatypeFromIterableArg(args.single())) { - DataType.UBYTE, DataType.UWORD -> DataType.UWORD - DataType.BYTE, DataType.WORD -> DataType.WORD - DataType.FLOAT -> DataType.FLOAT - DataType.ARRAY_UB, DataType.ARRAY_UW -> DataType.UWORD - DataType.ARRAY_B, DataType.ARRAY_W -> DataType.WORD - DataType.ARRAY_F -> DataType.FLOAT - in StringDatatypes -> DataType.UWORD - else -> null + DataType.UBYTE, DataType.UWORD -> InferredTypes.knownFor(DataType.UWORD) + DataType.BYTE, DataType.WORD -> InferredTypes.knownFor(DataType.WORD) + DataType.FLOAT -> InferredTypes.knownFor(DataType.FLOAT) + DataType.ARRAY_UB, DataType.ARRAY_UW -> InferredTypes.knownFor(DataType.UWORD) + DataType.ARRAY_B, DataType.ARRAY_W -> InferredTypes.knownFor(DataType.WORD) + DataType.ARRAY_F -> InferredTypes.knownFor(DataType.FLOAT) + in StringDatatypes -> InferredTypes.knownFor(DataType.UWORD) + else -> InferredTypes.unknown() } } "len" -> { // a length can be >255 so in that case, the result is an UWORD instead of an UBYTE // but to avoid a lot of code duplication we simply assume UWORD in all cases for now - return DataType.UWORD + return InferredTypes.knownFor(DataType.UWORD) } - else -> return null + else -> return InferredTypes.unknown() } } diff --git a/compiler/src/prog8/optimizer/ConstantFolding.kt b/compiler/src/prog8/optimizer/ConstantFolding.kt index d26452e1d..0b73f571f 100644 --- a/compiler/src/prog8/optimizer/ConstantFolding.kt +++ b/compiler/src/prog8/optimizer/ConstantFolding.kt @@ -79,7 +79,7 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor { errors.add(ExpressionError("range expression size doesn't match declared array size", decl.value?.position!!)) val constRange = rangeExpr.toConstantIntegerRange() if(constRange!=null) { - val eltType = rangeExpr.inferType(program)!! + val eltType = rangeExpr.inferType(program).typeOrElse(DataType.UBYTE) if(eltType in ByteDatatypes) { decl.value = ArrayLiteralValue(decl.datatype, constRange.map { NumericLiteralValue(eltType, it.toShort(), decl.value!!.position) }.toTypedArray(), @@ -612,7 +612,10 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor { val lv = assignment.value as? NumericLiteralValue if(lv!=null) { // see if we can promote/convert a literal value to the required datatype - when(assignment.target.inferType(program, assignment)) { + val idt = assignment.target.inferType(program, assignment) + if(!idt.isKnown) + return assignment + when(idt.typeOrElse(DataType.STRUCT)) { DataType.UWORD -> { // we can convert to UWORD: any UBYTE, BYTE/WORD that are >=0, FLOAT that's an integer 0..65535, if(lv.type== DataType.UBYTE) diff --git a/compiler/src/prog8/optimizer/SimplifyExpressions.kt b/compiler/src/prog8/optimizer/SimplifyExpressions.kt index 0b0e9c474..6ef3d0ff7 100644 --- a/compiler/src/prog8/optimizer/SimplifyExpressions.kt +++ b/compiler/src/prog8/optimizer/SimplifyExpressions.kt @@ -1,10 +1,7 @@ package prog8.optimizer import prog8.ast.Program -import prog8.ast.base.AstException -import prog8.ast.base.DataType -import prog8.ast.base.IntegerDatatypes -import prog8.ast.base.NumericDatatypes +import prog8.ast.base.* import prog8.ast.expressions.* import prog8.ast.processing.IAstModifyingVisitor import prog8.ast.statements.Assignment @@ -136,9 +133,14 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying val constTrue = NumericLiteralValue.fromBoolean(true, expr.position) val constFalse = NumericLiteralValue.fromBoolean(false, expr.position) - val leftDt = expr.left.inferType(program) - val rightDt = expr.right.inferType(program) - if (leftDt != null && rightDt != null && leftDt != rightDt) { + val leftIDt = expr.left.inferType(program) + val rightIDt = expr.right.inferType(program) + if(!leftIDt.isKnown || !rightIDt.isKnown) + throw FatalAstException("can't determine datatype of both expression operands $expr") + + val leftDt = leftIDt.typeOrElse(DataType.STRUCT) + val rightDt = rightIDt.typeOrElse(DataType.STRUCT) + if (leftDt != rightDt) { // try to convert a datatype into the other (where ddd if (adjustDatatypes(expr, leftVal, leftDt, rightVal, rightDt)) { optimizationsDone++ @@ -226,7 +228,7 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying val x = expr.right val y = determineY(x, leftBinExpr) if(y!=null) { - val yPlus1 = BinaryExpression(y, "+", NumericLiteralValue(leftDt!!, 1, y.position), y.position) + val yPlus1 = BinaryExpression(y, "+", NumericLiteralValue(leftDt, 1, y.position), y.position) return BinaryExpression(x, "*", yPlus1, x.position) } } else { @@ -235,7 +237,7 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying val x = expr.right val y = determineY(x, leftBinExpr) if(y!=null) { - val yMinus1 = BinaryExpression(y, "-", NumericLiteralValue(leftDt!!, 1, y.position), y.position) + val yMinus1 = BinaryExpression(y, "-", NumericLiteralValue(leftDt, 1, y.position), y.position) return BinaryExpression(x, "*", yMinus1, x.position) } } @@ -590,7 +592,7 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying "%" -> { if (cv == 1.0) { optimizationsDone++ - return NumericLiteralValue(expr.inferType(program)!!, 0, expr.position) + return NumericLiteralValue(expr.inferType(program).typeOrElse(DataType.STRUCT), 0, expr.position) } else if (cv == 2.0) { optimizationsDone++ expr.operator = "&" @@ -613,7 +615,10 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying // right value is a constant, see if we can optimize val rightConst: NumericLiteralValue = rightVal val cv = rightConst.number.toDouble() - val leftDt = expr.left.inferType(program) + val leftIDt = expr.left.inferType(program) + if(!leftIDt.isKnown) + return expr + val leftDt = leftIDt.typeOrElse(DataType.STRUCT) when(cv) { -1.0 -> { // '/' -> -left @@ -701,7 +706,7 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying return expr.left } 2.0, 4.0, 8.0, 16.0, 32.0, 64.0, 128.0, 256.0, 512.0, 1024.0, 2048.0, 4096.0, 8192.0, 16384.0, 32768.0, 65536.0 -> { - if(leftValue.inferType(program) in IntegerDatatypes) { + if(leftValue.inferType(program).typeOrElse(DataType.STRUCT) in IntegerDatatypes) { // times a power of two => shift left optimizationsDone++ val numshifts = log2(cv).toInt() @@ -709,7 +714,7 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying } } -2.0, -4.0, -8.0, -16.0, -32.0, -64.0, -128.0, -256.0, -512.0, -1024.0, -2048.0, -4096.0, -8192.0, -16384.0, -32768.0, -65536.0 -> { - if(leftValue.inferType(program) in IntegerDatatypes) { + if(leftValue.inferType(program).typeOrElse(DataType.STRUCT) in IntegerDatatypes) { // times a negative power of two => negate, then shift left optimizationsDone++ val numshifts = log2(-cv).toInt() diff --git a/compiler/src/prog8/optimizer/StatementOptimizer.kt b/compiler/src/prog8/optimizer/StatementOptimizer.kt index b69ae3726..b43675418 100644 --- a/compiler/src/prog8/optimizer/StatementOptimizer.kt +++ b/compiler/src/prog8/optimizer/StatementOptimizer.kt @@ -10,6 +10,7 @@ import prog8.ast.processing.IAstModifyingVisitor import prog8.ast.processing.IAstVisitor import prog8.ast.statements.* import prog8.compiler.target.c64.Petscii +import prog8.compiler.target.c64.codegen2.AssemblyError import prog8.functions.BuiltinFunctions import kotlin.math.floor @@ -422,7 +423,10 @@ internal class StatementOptimizer(private val program: Program) : IAstModifyingV return NopStatement.insteadOf(assignment) } } - val targetDt = assignment.target.inferType(program, assignment) + val targetIDt = assignment.target.inferType(program, assignment) + if(!targetIDt.isKnown) + throw AssemblyError("can't infer type of assignment target") + val targetDt = targetIDt.typeOrElse(DataType.STRUCT) val bexpr=assignment.value as? BinaryExpression if(bexpr!=null) { val cv = bexpr.right.constValue(program)?.number?.toDouble() diff --git a/compiler/src/prog8/vm/astvm/AstVm.kt b/compiler/src/prog8/vm/astvm/AstVm.kt index b931e40d7..7075412f7 100644 --- a/compiler/src/prog8/vm/astvm/AstVm.kt +++ b/compiler/src/prog8/vm/astvm/AstVm.kt @@ -412,9 +412,11 @@ class AstVm(val program: Program) { stmt.target.arrayindexed != null -> { val arrayvar = stmt.target.arrayindexed!!.identifier.targetVarDecl(program.namespace)!! val arrayvalue = runtimeVariables.get(arrayvar.definingScope(), arrayvar.name) - val elementType = stmt.target.arrayindexed!!.inferType(program)!! val index = evaluate(stmt.target.arrayindexed!!.arrayspec.index, evalCtx).integerValue() - var value = RuntimeValue(elementType, arrayvalue.array!![index].toInt()) + val elementType = stmt.target.arrayindexed!!.inferType(program) + if(!elementType.isKnown) + throw VmExecutionException("unknown/void elt type") + var value = RuntimeValue(elementType.typeOrElse(DataType.BYTE), arrayvalue.array!![index].toInt()) value = when { stmt.operator == "++" -> value.inc() stmt.operator == "--" -> value.dec() @@ -472,7 +474,8 @@ class AstVm(val program: Program) { loopvarDt = DataType.UBYTE loopvar = IdentifierReference(listOf(stmt.loopRegister.name), stmt.position) } else { - loopvarDt = stmt.loopVar!!.inferType(program)!! + val dt = stmt.loopVar!!.inferType(program) + loopvarDt = dt.typeOrElse(DataType.UBYTE) loopvar = stmt.loopVar!! } val iterator = iterable.iterator() @@ -619,8 +622,10 @@ class AstVm(val program: Program) { else { val address = (vardecl.value as NumericLiteralValue).number.toInt() val index = evaluate(targetArrayIndexed.arrayspec.index, evalCtx).integerValue() - val elementType = targetArrayIndexed.inferType(program)!! - when(elementType) { + val elementType = targetArrayIndexed.inferType(program) + if(!elementType.isKnown) + throw VmExecutionException("unknown/void array elt type $targetArrayIndexed") + when(elementType.typeOrElse(DataType.UBYTE)) { DataType.UBYTE -> mem.setUByte(address+index, value.byteval!!) DataType.BYTE -> mem.setSByte(address+index, value.byteval!!) DataType.UWORD -> mem.setUWord(address+index*2, value.wordval!!) diff --git a/compiler/src/prog8/vm/astvm/Expressions.kt b/compiler/src/prog8/vm/astvm/Expressions.kt index 84647ffee..7de4c5de1 100644 --- a/compiler/src/prog8/vm/astvm/Expressions.kt +++ b/compiler/src/prog8/vm/astvm/Expressions.kt @@ -12,7 +12,6 @@ import prog8.ast.statements.Subroutine import prog8.ast.statements.VarDecl import prog8.vm.RuntimeValue import prog8.vm.RuntimeValueRange -import kotlin.math.abs typealias BuiltinfunctionCaller = (name: String, args: List, flags: StatusFlags) -> RuntimeValue? @@ -147,24 +146,22 @@ fun evaluate(expr: Expression, ctx: EvalContext): RuntimeValue { } is RangeExpr -> { val cRange = expr.toConstantIntegerRange() - if(cRange!=null) - return RuntimeValueRange(expr.inferType(ctx.program)!!, cRange) + if(cRange!=null) { + val dt = expr.inferType(ctx.program) + if(dt.isKnown) + return RuntimeValueRange(dt.typeOrElse(DataType.UBYTE), cRange) + else + throw VmExecutionException("couldn't determine datatype") + } val fromVal = evaluate(expr.from, ctx).integerValue() val toVal = evaluate(expr.to, ctx).integerValue() val stepVal = evaluate(expr.step, ctx).integerValue() - val range = when { - fromVal <= toVal -> when { - stepVal <= 0 -> IntRange.EMPTY - stepVal == 1 -> fromVal..toVal - else -> fromVal..toVal step stepVal - } - else -> when { - stepVal >= 0 -> IntRange.EMPTY - stepVal == -1 -> fromVal downTo toVal - else -> fromVal downTo toVal step abs(stepVal) - } - } - return RuntimeValueRange(expr.inferType(ctx.program)!!, range) + val range = makeRange(fromVal, toVal, stepVal) + val dt = expr.inferType(ctx.program) + if(dt.isKnown) + return RuntimeValueRange(dt.typeOrElse(DataType.UBYTE), range) + else + throw VmExecutionException("couldn't determine datatype") } else -> { throw VmExecutionException("unimplemented expression node $expr")