improved handling of inferredType

This commit is contained in:
Irmen de Jong
2019-08-14 02:25:27 +02:00
parent b64d611e02
commit 47297f7e31
16 changed files with 359 additions and 240 deletions

View File

@@ -484,8 +484,8 @@ internal class Compiler(private val program: Program) {
translatePrefixOperator(expr.operator, expr.expression.inferType(program)) translatePrefixOperator(expr.operator, expr.expression.inferType(program))
} }
is BinaryExpression -> { is BinaryExpression -> {
val leftDt = expr.left.inferType(program)!! val leftDt = expr.left.inferType(program)
val rightDt = expr.right.inferType(program)!! val rightDt = expr.right.inferType(program)
val (commonDt, _) = expr.commonDatatype(leftDt, rightDt, expr.left, expr.right) val (commonDt, _) = expr.commonDatatype(leftDt, rightDt, expr.left, expr.right)
translate(expr.left) translate(expr.left)
if(leftDt!=commonDt) if(leftDt!=commonDt)
@@ -654,7 +654,7 @@ internal class Compiler(private val program: Program) {
// cast type if needed // cast type if needed
if(builtinFuncParams!=null) { if(builtinFuncParams!=null) {
val paramDts = builtinFuncParams[index].possibleDatatypes val paramDts = builtinFuncParams[index].possibleDatatypes
val argDt = arg.inferType(program)!! val argDt = arg.inferType(program)
if(argDt !in paramDts) { if(argDt !in paramDts) {
for(paramDt in paramDts.sorted()) for(paramDt in paramDts.sorted())
if(tryConvertType(argDt, paramDt)) if(tryConvertType(argDt, paramDt))
@@ -827,8 +827,8 @@ internal class Compiler(private val program: Program) {
// swap(x,y) is treated differently, it's not a normal function call // swap(x,y) is treated differently, it's not a normal function call
if (args.size != 2) if (args.size != 2)
throw AstException("swap requires 2 arguments") throw AstException("swap requires 2 arguments")
val dt1 = args[0].inferType(program)!! val dt1 = args[0].inferType(program)
val dt2 = args[1].inferType(program)!! val dt2 = args[1].inferType(program)
if (dt1 != dt2) if (dt1 != dt2)
throw AstException("swap requires 2 args of identical type") throw AstException("swap requires 2 args of identical type")
if (args[0].constValue(program) != null || args[1].constValue(program) != null) if (args[0].constValue(program) != null || args[1].constValue(program) != null)
@@ -861,7 +861,7 @@ internal class Compiler(private val program: Program) {
// (subroutine arguments are not passed via the stack!) // (subroutine arguments are not passed via the stack!)
for (arg in arguments.zip(subroutine.parameters)) { for (arg in arguments.zip(subroutine.parameters)) {
translate(arg.first) translate(arg.first)
convertType(arg.first.inferType(program)!!, arg.second.type) // convert types of arguments to required parameter type convertType(arg.first.inferType(program), arg.second.type) // convert types of arguments to required parameter type
val opcode = opcodePopvar(arg.second.type) val opcode = opcodePopvar(arg.second.type)
prog.instr(opcode, callLabel = subroutine.scopedname + "." + arg.second.name) prog.instr(opcode, callLabel = subroutine.scopedname + "." + arg.second.name)
} }

View File

@@ -30,7 +30,7 @@ sealed class Expression: Node {
abstract fun accept(visitor: IAstModifyingVisitor): Expression abstract fun accept(visitor: IAstModifyingVisitor): Expression
abstract fun accept(visitor: IAstVisitor) abstract fun accept(visitor: IAstVisitor)
abstract fun referencesIdentifiers(vararg name: String): Boolean // todo: remove this and add identifier usage tracking into CallGraph instead abstract fun referencesIdentifiers(vararg name: String): Boolean // todo: remove this and add identifier usage tracking into CallGraph instead
abstract fun inferType(program: Program): DataType? abstract fun inferType(program: Program): InferredTypes.InferredType
infix fun isSameAs(other: Expression): Boolean { infix fun isSameAs(other: Expression): Boolean {
if(this===other) if(this===other)
@@ -68,7 +68,7 @@ class PrefixExpression(val operator: String, var expression: Expression, overrid
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun referencesIdentifiers(vararg name: String) = expression.referencesIdentifiers(*name) override fun referencesIdentifiers(vararg name: String) = expression.referencesIdentifiers(*name)
override fun inferType(program: Program): DataType? = expression.inferType(program) override fun inferType(program: Program): InferredTypes.InferredType = expression.inferType(program)
override fun toString(): String { override fun toString(): String {
return "Prefix($operator $expression)" return "Prefix($operator $expression)"
@@ -94,15 +94,22 @@ class BinaryExpression(var left: Expression, var operator: String, var right: Ex
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun referencesIdentifiers(vararg name: String) = left.referencesIdentifiers(*name) || right.referencesIdentifiers(*name) override fun referencesIdentifiers(vararg name: String) = left.referencesIdentifiers(*name) || right.referencesIdentifiers(*name)
override fun inferType(program: Program): DataType? { override fun inferType(program: Program): InferredTypes.InferredType {
val leftDt = left.inferType(program) val leftDt = left.inferType(program)
val rightDt = right.inferType(program) val rightDt = right.inferType(program)
return when (operator) { return when (operator) {
"+", "-", "*", "**", "%", "/" -> if (leftDt == null || rightDt == null) null else { "+", "-", "*", "**", "%", "/" -> {
try { if (!leftDt.isKnown || !rightDt.isKnown)
commonDatatype(leftDt, rightDt, null, null).first InferredTypes.unknown()
} catch (x: FatalAstException) { else {
null try {
InferredTypes.knownFor(commonDatatype(
leftDt.typeOrElse(DataType.BYTE),
rightDt.typeOrElse(DataType.BYTE),
null, null).first)
} catch (x: FatalAstException) {
InferredTypes.unknown()
}
} }
} }
"&" -> leftDt "&" -> leftDt
@@ -111,7 +118,7 @@ class BinaryExpression(var left: Expression, var operator: String, var right: Ex
"and", "or", "xor", "and", "or", "xor",
"<", ">", "<", ">",
"<=", ">=", "<=", ">=",
"==", "!=" -> DataType.UBYTE "==", "!=" -> InferredTypes.knownFor(DataType.UBYTE)
"<<", ">>" -> leftDt "<<", ">>" -> leftDt
else -> throw FatalAstException("resulting datatype check for invalid operator $operator") else -> throw FatalAstException("resulting datatype check for invalid operator $operator")
} }
@@ -191,17 +198,16 @@ class ArrayIndexedExpression(var identifier: IdentifierReference,
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun referencesIdentifiers(vararg name: String) = identifier.referencesIdentifiers(*name) override fun referencesIdentifiers(vararg name: String) = identifier.referencesIdentifiers(*name)
override fun inferType(program: Program): DataType? { override fun inferType(program: Program): InferredTypes.InferredType {
val target = identifier.targetStatement(program.namespace) val target = identifier.targetStatement(program.namespace)
if (target is VarDecl) { if (target is VarDecl) {
return when (target.datatype) { return when (target.datatype) {
in NumericDatatypes -> null in StringDatatypes -> InferredTypes.knownFor(DataType.UBYTE)
in StringDatatypes -> DataType.UBYTE in ArrayDatatypes -> InferredTypes.knownFor(ArrayElementTypes.getValue(target.datatype))
in ArrayDatatypes -> ArrayElementTypes[target.datatype] else -> InferredTypes.unknown()
else -> throw FatalAstException("invalid dt")
} }
} }
return null return InferredTypes.unknown()
} }
override fun toString(): String { override fun toString(): String {
@@ -220,7 +226,7 @@ class TypecastExpression(var expression: Expression, var type: DataType, val imp
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun referencesIdentifiers(vararg name: String) = expression.referencesIdentifiers(*name) override fun referencesIdentifiers(vararg name: String) = expression.referencesIdentifiers(*name)
override fun inferType(program: Program): DataType? = type override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(type)
override fun constValue(program: Program): NumericLiteralValue? { override fun constValue(program: Program): NumericLiteralValue? {
val cv = expression.constValue(program) ?: return null val cv = expression.constValue(program) ?: return null
return cv.cast(type) return cv.cast(type)
@@ -243,7 +249,7 @@ data class AddressOf(var identifier: IdentifierReference, override val position:
override fun constValue(program: Program): NumericLiteralValue? = null override fun constValue(program: Program): NumericLiteralValue? = null
override fun referencesIdentifiers(vararg name: String) = false override fun referencesIdentifiers(vararg name: String) = false
override fun inferType(program: Program) = DataType.UWORD override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(DataType.UWORD)
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
} }
@@ -259,7 +265,7 @@ class DirectMemoryRead(var addressExpression: Expression, override val position:
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun referencesIdentifiers(vararg name: String) = false override fun referencesIdentifiers(vararg name: String) = false
override fun inferType(program: Program): DataType? = DataType.UBYTE override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(DataType.UBYTE)
override fun constValue(program: Program): NumericLiteralValue? = null override fun constValue(program: Program): NumericLiteralValue? = null
override fun toString(): String { override fun toString(): String {
@@ -315,7 +321,7 @@ class NumericLiteralValue(val type: DataType, // only numerical types allowed
override fun toString(): String = "NumericLiteral(${type.name}:$number)" override fun toString(): String = "NumericLiteral(${type.name}:$number)"
override fun inferType(program: Program) = type override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(type)
override fun hashCode(): Int = Objects.hash(type, number) override fun hashCode(): Int = Objects.hash(type, number)
@@ -399,7 +405,7 @@ class StructLiteralValue(var values: List<Expression>,
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun referencesIdentifiers(vararg name: String) = values.any { it.referencesIdentifiers(*name) } override fun referencesIdentifiers(vararg name: String) = values.any { it.referencesIdentifiers(*name) }
override fun inferType(program: Program) = DataType.STRUCT override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(DataType.STRUCT)
override fun toString(): String { override fun toString(): String {
return "struct{ ${values.joinToString(", ")} }" return "struct{ ${values.joinToString(", ")} }"
@@ -423,7 +429,7 @@ class StringLiteralValue(val type: DataType, // only string types
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun toString(): String = "'${escape(value)}'" override fun toString(): String = "'${escape(value)}'"
override fun inferType(program: Program) = type override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(type)
operator fun compareTo(other: StringLiteralValue): Int = value.compareTo(other.value) operator fun compareTo(other: StringLiteralValue): Int = value.compareTo(other.value)
override fun hashCode(): Int = Objects.hash(value, type) override fun hashCode(): Int = Objects.hash(value, type)
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean {
@@ -458,7 +464,7 @@ class ArrayLiteralValue(val type: DataType, // only array types
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun toString(): String = "$value" override fun toString(): String = "$value"
override fun inferType(program: Program) = type override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(type)
operator fun compareTo(other: ArrayLiteralValue): Int = throw ExpressionError("cannot order compare arrays", position) operator fun compareTo(other: ArrayLiteralValue): Int = throw ExpressionError("cannot order compare arrays", position)
override fun hashCode(): Int = Objects.hash(value, type) override fun hashCode(): Int = Objects.hash(value, type)
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean {
@@ -538,18 +544,18 @@ class RangeExpr(var from: Expression,
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun referencesIdentifiers(vararg name: String): Boolean = from.referencesIdentifiers(*name) || to.referencesIdentifiers(*name) override fun referencesIdentifiers(vararg name: String): Boolean = from.referencesIdentifiers(*name) || to.referencesIdentifiers(*name)
override fun inferType(program: Program): DataType? { override fun inferType(program: Program): InferredTypes.InferredType {
val fromDt=from.inferType(program) val fromDt=from.inferType(program)
val toDt=to.inferType(program) val toDt=to.inferType(program)
return when { return when {
fromDt==null || toDt==null -> null !fromDt.isKnown || !toDt.isKnown -> InferredTypes.unknown()
fromDt== DataType.UBYTE && toDt== DataType.UBYTE -> DataType.ARRAY_UB fromDt istype DataType.UBYTE && toDt istype DataType.UBYTE -> InferredTypes.knownFor(DataType.ARRAY_UB)
fromDt== DataType.UWORD && toDt== DataType.UWORD -> DataType.ARRAY_UW fromDt istype DataType.UWORD && toDt istype DataType.UWORD -> InferredTypes.knownFor(DataType.ARRAY_UW)
fromDt== DataType.STR && toDt== DataType.STR -> DataType.STR fromDt istype DataType.STR && toDt istype DataType.STR -> InferredTypes.knownFor(DataType.STR)
fromDt== DataType.STR_S && toDt== DataType.STR_S -> DataType.STR_S fromDt istype DataType.STR_S && toDt istype DataType.STR_S -> InferredTypes.knownFor(DataType.STR_S)
fromDt== DataType.WORD || toDt== DataType.WORD -> DataType.ARRAY_W fromDt istype DataType.WORD || toDt istype DataType.WORD -> InferredTypes.knownFor(DataType.ARRAY_W)
fromDt== DataType.BYTE || toDt== DataType.BYTE -> DataType.ARRAY_B fromDt istype DataType.BYTE || toDt istype DataType.BYTE -> InferredTypes.knownFor(DataType.ARRAY_B)
else -> DataType.ARRAY_UB else -> InferredTypes.knownFor(DataType.ARRAY_UB)
} }
} }
override fun toString(): String { override fun toString(): String {
@@ -583,17 +589,21 @@ class RangeExpr(var from: Expression,
toVal = toLv.number.toInt() toVal = toLv.number.toInt()
} }
val stepVal = (step as? NumericLiteralValue)?.number?.toInt() ?: 1 val stepVal = (step as? NumericLiteralValue)?.number?.toInt() ?: 1
return when { return makeRange(fromVal, toVal, stepVal)
fromVal <= toVal -> when { }
stepVal <= 0 -> IntRange.EMPTY }
stepVal == 1 -> fromVal..toVal
else -> fromVal..toVal step stepVal internal fun makeRange(fromVal: Int, toVal: Int, stepVal: Int): IntProgression {
} return when {
else -> when { fromVal <= toVal -> when {
stepVal >= 0 -> IntRange.EMPTY stepVal <= 0 -> IntRange.EMPTY
stepVal == -1 -> fromVal downTo toVal stepVal == 1 -> fromVal..toVal
else -> fromVal downTo toVal step abs(stepVal) else -> fromVal..toVal step stepVal
} }
else -> when {
stepVal >= 0 -> IntRange.EMPTY
stepVal == -1 -> fromVal downTo toVal
else -> fromVal downTo toVal step abs(stepVal)
} }
} }
} }
@@ -613,7 +623,7 @@ class RegisterExpr(val register: Register, override val position: Position) : Ex
return "RegisterExpr(register=$register, pos=$position)" return "RegisterExpr(register=$register, pos=$position)"
} }
override fun inferType(program: Program) = DataType.UBYTE override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(DataType.UBYTE)
} }
data class IdentifierReference(val nameInSource: List<String>, override val position: Position) : Expression() { data class IdentifierReference(val nameInSource: List<String>, override val position: Position) : Expression() {
@@ -652,10 +662,10 @@ data class IdentifierReference(val nameInSource: List<String>, override val posi
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun referencesIdentifiers(vararg name: String): Boolean = nameInSource.last() in name override fun referencesIdentifiers(vararg name: String): Boolean = nameInSource.last() in name
override fun inferType(program: Program): DataType? { override fun inferType(program: Program): InferredTypes.InferredType {
val targetStmt = targetStatement(program.namespace) val targetStmt = targetStatement(program.namespace)
if(targetStmt is VarDecl) { if(targetStmt is VarDecl) {
return targetStmt.datatype return InferredTypes.knownFor(targetStmt.datatype)
} else { } else {
throw FatalAstException("cannot get datatype from identifier reference ${this}, pos=$position") throw FatalAstException("cannot get datatype from identifier reference ${this}, pos=$position")
} }
@@ -705,7 +715,7 @@ class FunctionCall(override var target: IdentifierReference,
if(withDatatypeCheck) { if(withDatatypeCheck) {
val resultDt = this.inferType(program) val resultDt = this.inferType(program)
if(resultValue==null || resultDt == resultValue.type) if(resultValue==null || resultDt istype resultValue.type)
return resultValue return resultValue
throw FatalAstException("evaluated const expression result value doesn't match expected datatype $resultDt, pos=$position") throw FatalAstException("evaluated const expression result value doesn't match expected datatype $resultDt, pos=$position")
} else { } else {
@@ -726,27 +736,27 @@ class FunctionCall(override var target: IdentifierReference,
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun referencesIdentifiers(vararg name: String): Boolean = target.referencesIdentifiers(*name) || arglist.any{it.referencesIdentifiers(*name)} override fun referencesIdentifiers(vararg name: String): Boolean = target.referencesIdentifiers(*name) || arglist.any{it.referencesIdentifiers(*name)}
override fun inferType(program: Program): DataType? { override fun inferType(program: Program): InferredTypes.InferredType {
val constVal = constValue(program ,false) val constVal = constValue(program ,false)
if(constVal!=null) if(constVal!=null)
return constVal.type return InferredTypes.knownFor(constVal.type)
val stmt = target.targetStatement(program.namespace) ?: return null val stmt = target.targetStatement(program.namespace) ?: return InferredTypes.unknown()
when (stmt) { when (stmt) {
is BuiltinFunctionStatementPlaceholder -> { is BuiltinFunctionStatementPlaceholder -> {
if(target.nameInSource[0] == "set_carry" || target.nameInSource[0]=="set_irqd" || if(target.nameInSource[0] == "set_carry" || target.nameInSource[0]=="set_irqd" ||
target.nameInSource[0] == "clear_carry" || target.nameInSource[0]=="clear_irqd") { target.nameInSource[0] == "clear_carry" || target.nameInSource[0]=="clear_irqd") {
return null // these have no return value return InferredTypes.void() // these have no return value
} }
return builtinFunctionReturnType(target.nameInSource[0], this.arglist, program) return builtinFunctionReturnType(target.nameInSource[0], this.arglist, program)
} }
is Subroutine -> { is Subroutine -> {
if(stmt.returntypes.isEmpty()) if(stmt.returntypes.isEmpty())
return null // no return value return InferredTypes.void() // no return value
if(stmt.returntypes.size==1) if(stmt.returntypes.size==1)
return stmt.returntypes[0] return InferredTypes.knownFor(stmt.returntypes[0])
return null // has multiple return types... so not a single resulting datatype possible return InferredTypes.unknown() // has multiple return types... so not a single resulting datatype possible
} }
else -> return null else -> return InferredTypes.unknown()
} }
} }
} }

View File

@@ -0,0 +1,50 @@
package prog8.ast.expressions
import prog8.ast.base.DataType
object InferredTypes {
class InferredType private constructor(val isUnknown: Boolean, val isVoid: Boolean, private var datatype: DataType?) {
init {
if(datatype!=null && (isUnknown || isVoid))
throw IllegalArgumentException("invalid combination of args")
}
val isKnown = datatype!=null
fun typeOrElse(alternative: DataType) = if(isUnknown || isVoid) alternative else datatype!!
infix fun istype(type: DataType): Boolean = if(isUnknown || isVoid) false else this.datatype==type
companion object {
fun unknown() = InferredType(isUnknown = true, isVoid = false, datatype = null)
fun void() = InferredType(isUnknown = false, isVoid = true, datatype = null)
fun known(type: DataType) = InferredType(isUnknown = false, isVoid = false, datatype = type)
}
override fun equals(other: Any?): Boolean {
if(other !is InferredType)
return false
return isVoid==other.isVoid && datatype==other.datatype
}
}
private val unknownInstance = InferredType.unknown()
private val voidInstance = InferredType.void()
private val knownInstances = mapOf(
DataType.BYTE to InferredType.known(DataType.BYTE),
DataType.BYTE to InferredType.known(DataType.BYTE),
DataType.UWORD to InferredType.known(DataType.UWORD),
DataType.WORD to InferredType.known(DataType.WORD),
DataType.FLOAT to InferredType.known(DataType.FLOAT),
DataType.STR to InferredType.known(DataType.STR),
DataType.STR_S to InferredType.known(DataType.STR_S),
DataType.ARRAY_UB to InferredType.known(DataType.ARRAY_UB),
DataType.ARRAY_B to InferredType.known(DataType.ARRAY_B),
DataType.ARRAY_UW to InferredType.known(DataType.ARRAY_UW),
DataType.ARRAY_W to InferredType.known(DataType.ARRAY_W),
DataType.ARRAY_F to InferredType.known(DataType.ARRAY_F),
DataType.STRUCT to InferredType.known(DataType.STRUCT)
)
fun void() = voidInstance
fun unknown() = unknownInstance
fun knownFor(type: DataType) = knownInstances.getValue(type)
}

View File

@@ -97,14 +97,18 @@ internal class AstChecker(private val program: Program,
} }
if(expectedReturnValues.size==1 && returnStmt.value!=null) { if(expectedReturnValues.size==1 && returnStmt.value!=null) {
val valueDt = returnStmt.value!!.inferType(program) val valueDt = returnStmt.value!!.inferType(program)
if(expectedReturnValues[0]!=valueDt) if(!valueDt.isKnown) {
checkResult.add(ExpressionError("type $valueDt of return value doesn't match subroutine's return type", returnStmt.value!!.position)) checkResult.add(ExpressionError("return value type mismatch", returnStmt.value!!.position))
} else {
if (expectedReturnValues[0] != valueDt.typeOrElse(DataType.STRUCT))
checkResult.add(ExpressionError("type $valueDt of return value doesn't match subroutine's return type", returnStmt.value!!.position))
}
} }
super.visit(returnStmt) super.visit(returnStmt)
} }
override fun visit(ifStatement: IfStatement) { override fun visit(ifStatement: IfStatement) {
if(ifStatement.condition.inferType(program) !in IntegerDatatypes) if(ifStatement.condition.inferType(program).typeOrElse(DataType.STRUCT) !in IntegerDatatypes)
checkResult.add(ExpressionError("condition value should be an integer type", ifStatement.condition.position)) checkResult.add(ExpressionError("condition value should be an integer type", ifStatement.condition.position))
super.visit(ifStatement) super.visit(ifStatement)
} }
@@ -113,7 +117,7 @@ internal class AstChecker(private val program: Program,
if(forLoop.body.containsNoCodeNorVars()) if(forLoop.body.containsNoCodeNorVars())
printWarning("for loop body is empty", forLoop.position) printWarning("for loop body is empty", forLoop.position)
val iterableDt = forLoop.iterable.inferType(program) val iterableDt = forLoop.iterable.inferType(program).typeOrElse(DataType.BYTE)
if(iterableDt !in IterableDatatypes && forLoop.iterable !is RangeExpr) { if(iterableDt !in IterableDatatypes && forLoop.iterable !is RangeExpr) {
checkResult.add(ExpressionError("can only loop over an iterable type", forLoop.position)) checkResult.add(ExpressionError("can only loop over an iterable type", forLoop.position))
} else { } else {
@@ -328,7 +332,7 @@ internal class AstChecker(private val program: Program,
override fun visit(repeatLoop: RepeatLoop) { override fun visit(repeatLoop: RepeatLoop) {
if(repeatLoop.untilCondition.referencesIdentifiers("A", "X", "Y")) if(repeatLoop.untilCondition.referencesIdentifiers("A", "X", "Y"))
printWarning("using a register in the loop condition is risky (it could get clobbered)", repeatLoop.untilCondition.position) printWarning("using a register in the loop condition is risky (it could get clobbered)", repeatLoop.untilCondition.position)
if(repeatLoop.untilCondition.inferType(program) !in IntegerDatatypes) if(repeatLoop.untilCondition.inferType(program).typeOrElse(DataType.STRUCT) !in IntegerDatatypes)
checkResult.add(ExpressionError("condition value should be an integer type", repeatLoop.untilCondition.position)) checkResult.add(ExpressionError("condition value should be an integer type", repeatLoop.untilCondition.position))
super.visit(repeatLoop) super.visit(repeatLoop)
} }
@@ -336,7 +340,7 @@ internal class AstChecker(private val program: Program,
override fun visit(whileLoop: WhileLoop) { override fun visit(whileLoop: WhileLoop) {
if(whileLoop.condition.referencesIdentifiers("A", "X", "Y")) if(whileLoop.condition.referencesIdentifiers("A", "X", "Y"))
printWarning("using a register in the loop condition is risky (it could get clobbered)", whileLoop.condition.position) printWarning("using a register in the loop condition is risky (it could get clobbered)", whileLoop.condition.position)
if(whileLoop.condition.inferType(program) !in IntegerDatatypes) if(whileLoop.condition.inferType(program).typeOrElse(DataType.STRUCT) !in IntegerDatatypes)
checkResult.add(ExpressionError("condition value should be an integer type", whileLoop.condition.position)) checkResult.add(ExpressionError("condition value should be an integer type", whileLoop.condition.position))
super.visit(whileLoop) super.visit(whileLoop)
} }
@@ -350,7 +354,8 @@ internal class AstChecker(private val program: Program,
if(stmt.returntypes.size>1) if(stmt.returntypes.size>1)
checkResult.add(ExpressionError("It's not possible to store the multiple results of this asmsub call; you should use a small block of custom inline assembly for this.", assignment.value.position)) checkResult.add(ExpressionError("It's not possible to store the multiple results of this asmsub call; you should use a small block of custom inline assembly for this.", assignment.value.position))
else { else {
if(stmt.returntypes.single()!=assignment.target.inferType(program, assignment)) { val idt = assignment.target.inferType(program, assignment)
if(!idt.isKnown || stmt.returntypes.single()!=idt.typeOrElse(DataType.BYTE)) {
checkResult.add(ExpressionError("return type mismatch", assignment.value.position)) checkResult.add(ExpressionError("return type mismatch", assignment.value.position))
} }
} }
@@ -402,7 +407,7 @@ internal class AstChecker(private val program: Program,
} }
} }
} }
val targetDt = assignTarget.inferType(program, assignment) val targetDt = assignTarget.inferType(program, assignment).typeOrElse(DataType.STR)
if(targetDt in StringDatatypes || targetDt in ArrayDatatypes) if(targetDt in StringDatatypes || targetDt in ArrayDatatypes)
checkResult.add(SyntaxError("cannot assign to a string or array type", assignTarget.position)) checkResult.add(SyntaxError("cannot assign to a string or array type", assignTarget.position))
@@ -412,13 +417,13 @@ internal class AstChecker(private val program: Program,
throw FatalAstException("augmented assignment should have been converted into normal assignment") throw FatalAstException("augmented assignment should have been converted into normal assignment")
val targetDatatype = assignTarget.inferType(program, assignment) val targetDatatype = assignTarget.inferType(program, assignment)
if (targetDatatype != null) { if (targetDatatype.isKnown) {
val constVal = assignment.value.constValue(program) val constVal = assignment.value.constValue(program)
if (constVal != null) { if (constVal != null) {
checkValueTypeAndRange(targetDatatype, constVal) checkValueTypeAndRange(targetDatatype.typeOrElse(DataType.BYTE), constVal)
} else { } else {
val sourceDatatype: DataType? = assignment.value.inferType(program) val sourceDatatype = assignment.value.inferType(program)
if (sourceDatatype == null) { if (!sourceDatatype.isKnown) {
if (assignment.value is FunctionCall) { if (assignment.value is FunctionCall) {
val targetStmt = (assignment.value as FunctionCall).target.targetStatement(program.namespace) val targetStmt = (assignment.value as FunctionCall).target.targetStatement(program.namespace)
if (targetStmt != null) if (targetStmt != null)
@@ -426,7 +431,8 @@ internal class AstChecker(private val program: Program,
} else } else
checkResult.add(ExpressionError("assignment value is invalid or has no proper datatype", assignment.value.position)) checkResult.add(ExpressionError("assignment value is invalid or has no proper datatype", assignment.value.position))
} else { } else {
checkAssignmentCompatible(targetDatatype, assignTarget, sourceDatatype, assignment.value, assignment.position) checkAssignmentCompatible(targetDatatype.typeOrElse(DataType.BYTE), assignTarget,
sourceDatatype.typeOrElse(DataType.BYTE), assignment.value, assignment.position)
} }
} }
} }
@@ -701,7 +707,7 @@ internal class AstChecker(private val program: Program,
override fun visit(expr: PrefixExpression) { override fun visit(expr: PrefixExpression) {
if(expr.operator=="-") { if(expr.operator=="-") {
val dt = expr.inferType(program) val dt = expr.inferType(program).typeOrElse(DataType.STRUCT)
if (dt != DataType.BYTE && dt != DataType.WORD && dt != DataType.FLOAT) { if (dt != DataType.BYTE && dt != DataType.WORD && dt != DataType.FLOAT) {
checkResult.add(ExpressionError("can only take negative of a signed number type", expr.position)) checkResult.add(ExpressionError("can only take negative of a signed number type", expr.position))
} }
@@ -710,8 +716,13 @@ internal class AstChecker(private val program: Program,
} }
override fun visit(expr: BinaryExpression) { override fun visit(expr: BinaryExpression) {
val leftDt = expr.left.inferType(program) val leftIDt = expr.left.inferType(program)
val rightDt = expr.right.inferType(program) val rightIDt = expr.right.inferType(program)
if(!leftIDt.isKnown || !rightIDt.isKnown) {
throw FatalAstException("can't determine datatype of both expression operands $expr")
}
val leftDt = leftIDt.typeOrElse(DataType.STRUCT)
val rightDt = rightIDt.typeOrElse(DataType.STRUCT)
when(expr.operator){ when(expr.operator){
"/", "%" -> { "/", "%" -> {
@@ -842,30 +853,30 @@ internal class AstChecker(private val program: Program,
val paramTypesForAddressOf = PassByReferenceDatatypes + DataType.UWORD val paramTypesForAddressOf = PassByReferenceDatatypes + DataType.UWORD
for (arg in args.withIndex().zip(func.parameters)) { for (arg in args.withIndex().zip(func.parameters)) {
val argDt=arg.first.value.inferType(program) val argDt=arg.first.value.inferType(program)
if (argDt != null if (argDt.isKnown
&& !(argDt isAssignableTo arg.second.possibleDatatypes) && !(argDt.typeOrElse(DataType.STRUCT) isAssignableTo arg.second.possibleDatatypes)
&& (argDt != DataType.UWORD || arg.second.possibleDatatypes.intersect(paramTypesForAddressOf).isEmpty())) { && (argDt.typeOrElse(DataType.STRUCT) != DataType.UWORD || arg.second.possibleDatatypes.intersect(paramTypesForAddressOf).isEmpty())) {
checkResult.add(ExpressionError("builtin function '${target.name}' argument ${arg.first.index + 1} has invalid type $argDt, expected ${arg.second.possibleDatatypes}", position)) checkResult.add(ExpressionError("builtin function '${target.name}' argument ${arg.first.index + 1} has invalid type $argDt, expected ${arg.second.possibleDatatypes}", position))
} }
} }
if(target.name=="swap") { if(target.name=="swap") {
// swap() is a bit weird because this one is translated into a operations directly, instead of being a function call // swap() is a bit weird because this one is translated into a operations directly, instead of being a function call
val dt1 = args[0].inferType(program)!! val dt1 = args[0].inferType(program)
val dt2 = args[1].inferType(program)!! val dt2 = args[1].inferType(program)
if (dt1 != dt2) if (dt1 != dt2)
checkResult.add(ExpressionError("swap requires 2 args of identical type", position)) checkResult.add(ExpressionError("swap requires 2 args of identical type", position))
else if (args[0].constValue(program) != null || args[1].constValue(program) != null) else if (args[0].constValue(program) != null || args[1].constValue(program) != null)
checkResult.add(ExpressionError("swap requires 2 variables, not constant value(s)", position)) checkResult.add(ExpressionError("swap requires 2 variables, not constant value(s)", position))
else if(args[0] isSameAs args[1]) else if(args[0] isSameAs args[1])
checkResult.add(ExpressionError("swap should have 2 different args", position)) checkResult.add(ExpressionError("swap should have 2 different args", position))
else if(dt1 !in NumericDatatypes) else if(dt1.typeOrElse(DataType.STRUCT) !in NumericDatatypes)
checkResult.add(ExpressionError("swap requires args of numerical type", position)) checkResult.add(ExpressionError("swap requires args of numerical type", position))
} }
else if(target.name=="all" || target.name=="any") { else if(target.name=="all" || target.name=="any") {
if((args[0] as? AddressOf)?.identifier?.targetVarDecl(program.namespace)?.datatype in StringDatatypes) { if((args[0] as? AddressOf)?.identifier?.targetVarDecl(program.namespace)?.datatype in StringDatatypes) {
checkResult.add(ExpressionError("any/all on a string is useless (is always true unless the string is empty)", position)) checkResult.add(ExpressionError("any/all on a string is useless (is always true unless the string is empty)", position))
} }
if(args[0].inferType(program) in StringDatatypes) { if(args[0].inferType(program).typeOrElse(DataType.STR) in StringDatatypes) {
checkResult.add(ExpressionError("any/all on a string is useless (is always true unless the string is empty)", position)) checkResult.add(ExpressionError("any/all on a string is useless (is always true unless the string is empty)", position))
} }
} }
@@ -875,8 +886,12 @@ internal class AstChecker(private val program: Program,
checkResult.add(SyntaxError("invalid number of arguments", position)) checkResult.add(SyntaxError("invalid number of arguments", position))
else { else {
for (arg in args.withIndex().zip(target.parameters)) { for (arg in args.withIndex().zip(target.parameters)) {
val argDt = arg.first.value.inferType(program) val argIDt = arg.first.value.inferType(program)
if(argDt!=null && !(argDt isAssignableTo arg.second.type)) { if(!argIDt.isKnown) {
throw FatalAstException("can't determine arg dt ${arg.first.value}")
}
val argDt=argIDt.typeOrElse(DataType.STRUCT)
if(!(argDt isAssignableTo arg.second.type)) {
// for asm subroutines having STR param it's okay to provide a UWORD (address value) // for asm subroutines having STR param it's okay to provide a UWORD (address value)
if(!(target.isAsmSubroutine && arg.second.type in StringDatatypes && argDt == DataType.UWORD)) if(!(target.isAsmSubroutine && arg.second.type in StringDatatypes && argDt == DataType.UWORD))
checkResult.add(ExpressionError("subroutine '${target.name}' argument ${arg.first.index + 1} has invalid type $argDt, expected ${arg.second.type}", position)) checkResult.add(ExpressionError("subroutine '${target.name}' argument ${arg.first.index + 1} has invalid type $argDt, expected ${arg.second.type}", position))
@@ -960,7 +975,7 @@ internal class AstChecker(private val program: Program,
checkResult.add(SyntaxError("indexing requires a variable to act upon", arrayIndexedExpression.position)) checkResult.add(SyntaxError("indexing requires a variable to act upon", arrayIndexedExpression.position))
// check index value 0..255 // check index value 0..255
val dtx = arrayIndexedExpression.arrayspec.index.inferType(program) val dtx = arrayIndexedExpression.arrayspec.index.inferType(program).typeOrElse(DataType.STRUCT)
if(dtx!= DataType.UBYTE && dtx!= DataType.BYTE) if(dtx!= DataType.UBYTE && dtx!= DataType.BYTE)
checkResult.add(SyntaxError("array indexing is limited to byte size 0..255", arrayIndexedExpression.position)) checkResult.add(SyntaxError("array indexing is limited to byte size 0..255", arrayIndexedExpression.position))
@@ -968,7 +983,7 @@ internal class AstChecker(private val program: Program,
} }
override fun visit(whenStatement: WhenStatement) { override fun visit(whenStatement: WhenStatement) {
val conditionType = whenStatement.condition.inferType(program) val conditionType = whenStatement.condition.inferType(program).typeOrElse(DataType.STRUCT)
if(conditionType !in IntegerDatatypes) if(conditionType !in IntegerDatatypes)
checkResult.add(SyntaxError("when condition must be an integer value", whenStatement.position)) checkResult.add(SyntaxError("when condition must be an integer value", whenStatement.position))
val choiceValues = whenStatement.choiceValues(program) val choiceValues = whenStatement.choiceValues(program)
@@ -987,12 +1002,14 @@ internal class AstChecker(private val program: Program,
val whenStmt = whenChoice.parent as WhenStatement val whenStmt = whenChoice.parent as WhenStatement
if(whenChoice.values!=null) { if(whenChoice.values!=null) {
val conditionType = whenStmt.condition.inferType(program) val conditionType = whenStmt.condition.inferType(program)
if(!conditionType.isKnown)
throw FatalAstException("can't determine when choice datatype $whenChoice")
val constvalues = whenChoice.values!!.map { it.constValue(program) } val constvalues = whenChoice.values!!.map { it.constValue(program) }
for(constvalue in constvalues) { for(constvalue in constvalues) {
when { when {
constvalue == null -> checkResult.add(SyntaxError("choice value must be a constant", whenChoice.position)) constvalue == null -> checkResult.add(SyntaxError("choice value must be a constant", whenChoice.position))
constvalue.type !in IntegerDatatypes -> checkResult.add(SyntaxError("choice value must be a byte or word", whenChoice.position)) constvalue.type !in IntegerDatatypes -> checkResult.add(SyntaxError("choice value must be a byte or word", whenChoice.position))
constvalue.type != conditionType -> checkResult.add(SyntaxError("choice value datatype differs from condition value", whenChoice.position)) constvalue.type != conditionType.typeOrElse(DataType.STRUCT) -> checkResult.add(SyntaxError("choice value datatype differs from condition value", whenChoice.position))
} }
} }
} else { } else {
@@ -1129,8 +1146,8 @@ internal class AstChecker(private val program: Program,
return err("number of values is not the same as the number of members in the struct") return err("number of values is not the same as the number of members in the struct")
for(elt in value.value.zip(struct.statements)) { for(elt in value.value.zip(struct.statements)) {
val vardecl = elt.second as VarDecl val vardecl = elt.second as VarDecl
val valuetype = elt.first.inferType(program)!! val valuetype = elt.first.inferType(program)
if (!(valuetype isAssignableTo vardecl.datatype)) { if (!valuetype.isKnown || !(valuetype.typeOrElse(DataType.STRUCT) isAssignableTo vardecl.datatype)) {
checkResult.add(ExpressionError("invalid struct member init value type $valuetype, expected ${vardecl.datatype}", elt.first.position)) checkResult.add(ExpressionError("invalid struct member init value type $valuetype, expected ${vardecl.datatype}", elt.first.position))
return false return false
} }

View File

@@ -285,15 +285,16 @@ internal class AstIdentifiersChecker(private val program: Program) : IAstModifyi
} }
private fun determineArrayDt(array: Array<Expression>): DataType { private fun determineArrayDt(array: Array<Expression>): DataType {
val datatypesInArray = array.mapNotNull { it.inferType(program) } val datatypesInArray = array.map { it.inferType(program) }
if(datatypesInArray.isEmpty()) if(datatypesInArray.isEmpty() || datatypesInArray.any { !it.isKnown })
throw IllegalArgumentException("can't determine type of empty array") throw IllegalArgumentException("can't determine type of empty array")
val dts = datatypesInArray.map { it.typeOrElse(DataType.STRUCT) }
return when { return when {
DataType.FLOAT in datatypesInArray -> DataType.ARRAY_F DataType.FLOAT in dts -> DataType.ARRAY_F
DataType.WORD in datatypesInArray -> DataType.ARRAY_W DataType.WORD in dts -> DataType.ARRAY_W
DataType.UWORD in datatypesInArray -> DataType.ARRAY_UW DataType.UWORD in dts -> DataType.ARRAY_UW
DataType.BYTE in datatypesInArray -> DataType.ARRAY_B DataType.BYTE in dts -> DataType.ARRAY_B
DataType.UBYTE in datatypesInArray -> DataType.ARRAY_UB DataType.UBYTE in dts -> DataType.ARRAY_UB
else -> throw IllegalArgumentException("can't determine type of array") else -> throw IllegalArgumentException("can't determine type of array")
} }
} }
@@ -351,13 +352,15 @@ internal class AstIdentifiersChecker(private val program: Program) : IAstModifyi
if(constvalue!=null) { if(constvalue!=null) {
if (expr.operator == "*") { if (expr.operator == "*") {
// repeat a string a number of times // repeat a string a number of times
return StringLiteralValue(string.inferType(program), val idt = string.inferType(program)
return StringLiteralValue(idt.typeOrElse(DataType.STR),
string.value.repeat(constvalue.number.toInt()), null, expr.position) string.value.repeat(constvalue.number.toInt()), null, expr.position)
} }
} }
if(expr.operator == "+" && operand is StringLiteralValue) { if(expr.operator == "+" && operand is StringLiteralValue) {
// concatenate two strings // concatenate two strings
return StringLiteralValue(string.inferType(program), val idt = string.inferType(program)
return StringLiteralValue(idt.typeOrElse(DataType.STR),
"${string.value}${operand.value}", null, expr.position) "${string.value}${operand.value}", null, expr.position)
} }
return expr return expr

View File

@@ -192,9 +192,9 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
return expr2 return expr2
val leftDt = expr2.left.inferType(program) val leftDt = expr2.left.inferType(program)
val rightDt = expr2.right.inferType(program) val rightDt = expr2.right.inferType(program)
if(leftDt!=null && rightDt!=null && leftDt!=rightDt) { if(leftDt.isKnown && rightDt.isKnown && leftDt!=rightDt) {
// determine common datatype and add typecast as required to make left and right equal types // determine common datatype and add typecast as required to make left and right equal types
val (commonDt, toFix) = BinaryExpression.commonDatatype(leftDt, rightDt, expr2.left, expr2.right) val (commonDt, toFix) = BinaryExpression.commonDatatype(leftDt.typeOrElse(DataType.STRUCT), rightDt.typeOrElse(DataType.STRUCT), expr2.left, expr2.right)
if(toFix!=null) { if(toFix!=null) {
when { when {
toFix===expr2.left -> { toFix===expr2.left -> {
@@ -218,32 +218,35 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
return assg return assg
// see if a typecast is needed to convert the value's type into the proper target type // see if a typecast is needed to convert the value's type into the proper target type
val valuetype = assg.value.inferType(program) val valueItype = assg.value.inferType(program)
val targettype = assg.target.inferType(program, assg) val targetItype = assg.target.inferType(program, assg)
if(targettype!=null && valuetype!=null) {
if(valuetype!=targettype) { if(targetItype.isKnown && valueItype.isKnown) {
val targettype = targetItype.typeOrElse(DataType.STRUCT)
val valuetype = valueItype.typeOrElse(DataType.STRUCT)
if (valuetype != targettype) {
if (valuetype isAssignableTo targettype) { if (valuetype isAssignableTo targettype) {
assg.value = TypecastExpression(assg.value, targettype, true, assg.value.position) assg.value = TypecastExpression(assg.value, targettype, true, assg.value.position)
assg.value.linkParents(assg) assg.value.linkParents(assg)
} }
// if they're not assignable, we'll get a proper error later from the AstChecker // if they're not assignable, we'll get a proper error later from the AstChecker
} }
}
// struct assignments will be flattened (if it's not a struct literal) // struct assignments will be flattened (if it's not a struct literal)
if(valuetype==DataType.STRUCT && targettype==DataType.STRUCT) { if (valuetype == DataType.STRUCT && targettype == DataType.STRUCT) {
if(assg.value is StructLiteralValue) if (assg.value is StructLiteralValue)
return assg // do NOT flatten it at this point!! (the compiler will take care if it, later, if needed) return assg // do NOT flatten it at this point!! (the compiler will take care if it, later, if needed)
val assignments = flattenStructAssignmentFromIdentifier(assg, program) // 'structvar1 = structvar2' val assignments = flattenStructAssignmentFromIdentifier(assg, program) // 'structvar1 = structvar2'
return if(assignments.isEmpty()) { return if (assignments.isEmpty()) {
// something went wrong (probably incompatible struct types) // something went wrong (probably incompatible struct types)
// we'll get an error later from the AstChecker // we'll get an error later from the AstChecker
assg assg
} else { } else {
val scope = AnonymousScope(assignments.toMutableList(), assg.position) val scope = AnonymousScope(assignments.toMutableList(), assg.position)
scope.linkParents(assg.parent) scope.linkParents(assg.parent)
scope scope
}
} }
} }
@@ -283,8 +286,9 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
when(val sub = call.target.targetStatement(scope)) { when(val sub = call.target.targetStatement(scope)) {
is Subroutine -> { is Subroutine -> {
for(arg in sub.parameters.zip(call.arglist.withIndex())) { for(arg in sub.parameters.zip(call.arglist.withIndex())) {
val argtype = arg.second.value.inferType(program) val argItype = arg.second.value.inferType(program)
if(argtype!=null) { if(argItype.isKnown) {
val argtype = argItype.typeOrElse(DataType.STRUCT)
val requiredType = arg.first.type val requiredType = arg.first.type
if (requiredType != argtype) { if (requiredType != argtype) {
if (argtype isAssignableTo requiredType) { if (argtype isAssignableTo requiredType) {
@@ -302,8 +306,9 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
if(func.pure) { if(func.pure) {
// non-pure functions don't get automatic typecasts because sometimes they act directly on their parameters // non-pure functions don't get automatic typecasts because sometimes they act directly on their parameters
for (arg in func.parameters.zip(call.arglist.withIndex())) { for (arg in func.parameters.zip(call.arglist.withIndex())) {
val argtype = arg.second.value.inferType(program) val argItype = arg.second.value.inferType(program)
if (argtype != null) { if (argItype.isKnown) {
val argtype = argItype.typeOrElse(DataType.STRUCT)
if (arg.first.possibleDatatypes.any { argtype == it }) if (arg.first.possibleDatatypes.any { argtype == it })
continue continue
for (possibleType in arg.first.possibleDatatypes) { for (possibleType in arg.first.possibleDatatypes) {
@@ -334,7 +339,7 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
override fun visit(memread: DirectMemoryRead): Expression { override fun visit(memread: DirectMemoryRead): Expression {
// make sure the memory address is an uword // make sure the memory address is an uword
val dt = memread.addressExpression.inferType(program) val dt = memread.addressExpression.inferType(program)
if(dt!=DataType.UWORD) { if(dt.isKnown && dt.typeOrElse(DataType.UWORD)!=DataType.UWORD) {
val literaladdr = memread.addressExpression as? NumericLiteralValue val literaladdr = memread.addressExpression as? NumericLiteralValue
if(literaladdr!=null) { if(literaladdr!=null) {
memread.addressExpression = literaladdr.cast(DataType.UWORD) memread.addressExpression = literaladdr.cast(DataType.UWORD)
@@ -348,7 +353,7 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
override fun visit(memwrite: DirectMemoryWrite) { override fun visit(memwrite: DirectMemoryWrite) {
val dt = memwrite.addressExpression.inferType(program) val dt = memwrite.addressExpression.inferType(program)
if(dt!=DataType.UWORD) { if(dt.isKnown && dt.typeOrElse(DataType.UWORD)!=DataType.UWORD) {
val literaladdr = memwrite.addressExpression as? NumericLiteralValue val literaladdr = memwrite.addressExpression as? NumericLiteralValue
if(literaladdr!=null) { if(literaladdr!=null) {
memwrite.addressExpression = literaladdr.cast(DataType.UWORD) memwrite.addressExpression = literaladdr.cast(DataType.UWORD)
@@ -391,7 +396,7 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
structLv.values = struct.statements.zip(structLv.values).map { structLv.values = struct.statements.zip(structLv.values).map {
val memberDt = (it.first as VarDecl).datatype val memberDt = (it.first as VarDecl).datatype
val valueDt = it.second.inferType(program) val valueDt = it.second.inferType(program)
if (valueDt != memberDt) if (valueDt.typeOrElse(memberDt) != memberDt)
TypecastExpression(it.second, memberDt, true, it.second.position) TypecastExpression(it.second, memberDt, true, it.second.position)
else else
it.second it.second

View File

@@ -135,7 +135,7 @@ internal class VarInitValueAndAddressOfCreator(private val program: Program): IA
for(arg in args.withIndex().zip(signature.parameters)) { for(arg in args.withIndex().zip(signature.parameters)) {
val argvalue = arg.first.value val argvalue = arg.first.value
val argDt = argvalue.inferType(program) val argDt = argvalue.inferType(program)
if(argDt in PassByReferenceDatatypes && DataType.UWORD in arg.second.possibleDatatypes) { if(argDt.typeOrElse(DataType.UBYTE) in PassByReferenceDatatypes && DataType.UWORD in arg.second.possibleDatatypes) {
if(argvalue !is IdentifierReference) if(argvalue !is IdentifierReference)
throw CompilerException("pass-by-reference parameter isn't an identifier? $argvalue") throw CompilerException("pass-by-reference parameter isn't an identifier? $argvalue")
val addrOf = AddressOf(argvalue, argvalue.position) val addrOf = AddressOf(argvalue, argvalue.position)

View File

@@ -368,25 +368,23 @@ data class AssignTarget(val register: Register?,
} }
} }
fun inferType(program: Program, stmt: Statement): DataType? { fun inferType(program: Program, stmt: Statement): InferredTypes.InferredType {
if(register!=null) if(register!=null)
return DataType.UBYTE return InferredTypes.knownFor(DataType.UBYTE)
if(identifier!=null) { if(identifier!=null) {
val symbol = program.namespace.lookup(identifier!!.nameInSource, stmt) ?: return null val symbol = program.namespace.lookup(identifier!!.nameInSource, stmt) ?: return InferredTypes.unknown()
if (symbol is VarDecl) return symbol.datatype if (symbol is VarDecl) return InferredTypes.knownFor(symbol.datatype)
} }
if(arrayindexed!=null) { if(arrayindexed!=null) {
val dt = arrayindexed!!.inferType(program) return arrayindexed!!.inferType(program)
if(dt!=null)
return dt
} }
if(memoryAddress!=null) if(memoryAddress!=null)
return DataType.UBYTE return InferredTypes.knownFor(DataType.UBYTE)
return null return InferredTypes.unknown()
} }
infix fun isSameAs(value: Expression): Boolean { infix fun isSameAs(value: Expression): Boolean {

View File

@@ -663,7 +663,10 @@ internal class AsmGen2(val program: Program,
} }
fun translateSubroutineArgument(parameter: IndexedValue<SubroutineParameter>, value: Expression, sub: Subroutine) { fun translateSubroutineArgument(parameter: IndexedValue<SubroutineParameter>, value: Expression, sub: Subroutine) {
val sourceDt = value.inferType(program)!! val sourceIDt = value.inferType(program)
if(!sourceIDt.isKnown)
throw AssemblyError("arg type unknown")
val sourceDt = sourceIDt.typeOrElse(DataType.STRUCT)
if(!argumentTypeCompatible(sourceDt, parameter.value.type)) if(!argumentTypeCompatible(sourceDt, parameter.value.type))
throw AssemblyError("argument type incompatible") throw AssemblyError("argument type incompatible")
if(sub.asmParameterRegisters.isEmpty()) { if(sub.asmParameterRegisters.isEmpty()) {
@@ -849,7 +852,7 @@ internal class AsmGen2(val program: Program,
private fun translate(stmt: IfStatement) { private fun translate(stmt: IfStatement) {
translateExpression(stmt.condition) translateExpression(stmt.condition)
translateTestStack(stmt.condition.inferType(program)!!) translateTestStack(stmt.condition.inferType(program).typeOrElse(DataType.STRUCT))
val elseLabel = makeLabel("if_else") val elseLabel = makeLabel("if_else")
val endLabel = makeLabel("if_end") val endLabel = makeLabel("if_end")
out(" beq $elseLabel") out(" beq $elseLabel")
@@ -877,7 +880,10 @@ internal class AsmGen2(val program: Program,
out(whileLabel) out(whileLabel)
// TODO optimize for the simple cases, can we avoid stack use? // TODO optimize for the simple cases, can we avoid stack use?
translateExpression(stmt.condition) translateExpression(stmt.condition)
if(stmt.condition.inferType(program) in ByteDatatypes) { val conditionDt = stmt.condition.inferType(program)
if(!conditionDt.isKnown)
throw AssemblyError("unknown condition dt")
if(conditionDt.typeOrElse(DataType.BYTE) in ByteDatatypes) {
out(" inx | lda $ESTACK_LO_HEX,x | beq $endLabel") out(" inx | lda $ESTACK_LO_HEX,x | beq $endLabel")
} else { } else {
out(""" out("""
@@ -904,7 +910,10 @@ internal class AsmGen2(val program: Program,
// TODO optimize this for the simple cases, can we avoid stack use? // TODO optimize this for the simple cases, can we avoid stack use?
translate(stmt.body) translate(stmt.body)
translateExpression(stmt.untilCondition) translateExpression(stmt.untilCondition)
if(stmt.untilCondition.inferType(program) in ByteDatatypes) { val conditionDt = stmt.untilCondition.inferType(program)
if(!conditionDt.isKnown)
throw AssemblyError("unknown condition dt")
if(conditionDt.typeOrElse(DataType.BYTE) in ByteDatatypes) {
out(" inx | lda $ESTACK_LO_HEX,x | beq $repeatLabel") out(" inx | lda $ESTACK_LO_HEX,x | beq $repeatLabel")
} else { } else {
out(""" out("""
@@ -924,8 +933,10 @@ internal class AsmGen2(val program: Program,
translateExpression(stmt.condition) translateExpression(stmt.condition)
val endLabel = makeLabel("choice_end") val endLabel = makeLabel("choice_end")
val choiceBlocks = mutableListOf<Pair<String, AnonymousScope>>() val choiceBlocks = mutableListOf<Pair<String, AnonymousScope>>()
val conditionDt = stmt.condition.inferType(program)!! val conditionDt = stmt.condition.inferType(program)
if(conditionDt in ByteDatatypes) if(!conditionDt.isKnown)
throw AssemblyError("unknown condition dt")
if(conditionDt.typeOrElse(DataType.BYTE) in ByteDatatypes)
out(" inx | lda $ESTACK_LO_HEX,x") out(" inx | lda $ESTACK_LO_HEX,x")
else else
out(" inx | lda $ESTACK_LO_HEX,x | ldy $ESTACK_HI_HEX,x") out(" inx | lda $ESTACK_LO_HEX,x | ldy $ESTACK_HI_HEX,x")
@@ -939,7 +950,7 @@ internal class AsmGen2(val program: Program,
choiceBlocks.add(Pair(choiceLabel, choice.statements)) choiceBlocks.add(Pair(choiceLabel, choice.statements))
for (cv in choice.values!!) { for (cv in choice.values!!) {
val value = (cv as NumericLiteralValue).number.toInt() val value = (cv as NumericLiteralValue).number.toInt()
if(conditionDt in ByteDatatypes) { if(conditionDt.typeOrElse(DataType.BYTE) in ByteDatatypes) {
out(" cmp #${value.toHex()} | beq $choiceLabel") out(" cmp #${value.toHex()} | beq $choiceLabel")
} else { } else {
out(""" out("""
@@ -1026,20 +1037,22 @@ internal class AsmGen2(val program: Program,
} }
private fun translate(stmt: ForLoop) { private fun translate(stmt: ForLoop) {
val iterableDt = stmt.iterable.inferType(program)!! val iterableDt = stmt.iterable.inferType(program)
if(!iterableDt.isKnown)
throw AssemblyError("can't determine iterable dt")
when(stmt.iterable) { when(stmt.iterable) {
is RangeExpr -> { is RangeExpr -> {
val range = (stmt.iterable as RangeExpr).toConstantIntegerRange() val range = (stmt.iterable as RangeExpr).toConstantIntegerRange()
if(range==null) { if(range==null) {
translateForOverNonconstRange(stmt, iterableDt, stmt.iterable as RangeExpr) translateForOverNonconstRange(stmt, iterableDt.typeOrElse(DataType.STRUCT), stmt.iterable as RangeExpr)
} else { } else {
if (range.isEmpty()) if (range.isEmpty())
throw AssemblyError("empty range") throw AssemblyError("empty range")
translateForOverConstRange(stmt, iterableDt, range) translateForOverConstRange(stmt, iterableDt.typeOrElse(DataType.STRUCT), range)
} }
} }
is IdentifierReference -> { is IdentifierReference -> {
translateForOverIterableVar(stmt, iterableDt, stmt.iterable as IdentifierReference) translateForOverIterableVar(stmt, iterableDt.typeOrElse(DataType.STRUCT), stmt.iterable as IdentifierReference)
} }
else -> throw AssemblyError("can't iterate over ${stmt.iterable}") else -> throw AssemblyError("can't iterate over ${stmt.iterable}")
} }
@@ -1524,7 +1537,7 @@ $endLabel""")
} }
targetIdent!=null -> { targetIdent!=null -> {
val what = asmIdentifierName(targetIdent) val what = asmIdentifierName(targetIdent)
val dt = stmt.target.inferType(program, stmt) val dt = stmt.target.inferType(program, stmt).typeOrElse(DataType.STRUCT)
when (dt) { when (dt) {
in ByteDatatypes -> out(if (incr) " inc $what" else " dec $what") in ByteDatatypes -> out(if (incr) " inc $what" else " dec $what")
in WordDatatypes -> { in WordDatatypes -> {
@@ -1562,7 +1575,7 @@ $endLabel""")
targetArrayIdx!=null -> { targetArrayIdx!=null -> {
val index = targetArrayIdx.arrayspec.index val index = targetArrayIdx.arrayspec.index
val what = asmIdentifierName(targetArrayIdx.identifier) val what = asmIdentifierName(targetArrayIdx.identifier)
val arrayDt = targetArrayIdx.identifier.inferType(program)!! val arrayDt = targetArrayIdx.identifier.inferType(program).typeOrElse(DataType.STRUCT)
val elementDt = ArrayElementTypes.getValue(arrayDt) val elementDt = ArrayElementTypes.getValue(arrayDt)
when(index) { when(index) {
is NumericLiteralValue -> { is NumericLiteralValue -> {
@@ -1676,7 +1689,7 @@ $endLabel""")
assignFromRegister(assign.target, (assign.value as RegisterExpr).register) assignFromRegister(assign.target, (assign.value as RegisterExpr).register)
} }
is IdentifierReference -> { is IdentifierReference -> {
val type = assign.target.inferType(program, assign)!! val type = assign.target.inferType(program, assign).typeOrElse(DataType.STRUCT)
when(type) { when(type) {
DataType.UBYTE, DataType.BYTE -> assignFromByteVariable(assign.target, assign.value as IdentifierReference) DataType.UBYTE, DataType.BYTE -> assignFromByteVariable(assign.target, assign.value as IdentifierReference)
DataType.UWORD, DataType.WORD -> assignFromWordVariable(assign.target, assign.value as IdentifierReference) DataType.UWORD, DataType.WORD -> assignFromWordVariable(assign.target, assign.value as IdentifierReference)
@@ -1743,8 +1756,9 @@ $endLabel""")
val cast = assign.value as TypecastExpression val cast = assign.value as TypecastExpression
val sourceType = cast.expression.inferType(program) val sourceType = cast.expression.inferType(program)
val targetType = assign.target.inferType(program, assign) val targetType = assign.target.inferType(program, assign)
if((sourceType in ByteDatatypes && targetType in ByteDatatypes) || if(sourceType.isKnown && targetType.isKnown &&
(sourceType in WordDatatypes && targetType in WordDatatypes)) { (sourceType.typeOrElse(DataType.STRUCT) in ByteDatatypes && targetType.typeOrElse(DataType.STRUCT) in ByteDatatypes) ||
(sourceType.typeOrElse(DataType.STRUCT) in WordDatatypes && targetType.typeOrElse(DataType.STRUCT) in WordDatatypes)) {
// no need for a type cast // no need for a type cast
assign.value = cast.expression assign.value = cast.expression
translate(assign) translate(assign)
@@ -1787,7 +1801,7 @@ $endLabel""")
private fun translateExpression(expr: TypecastExpression) { private fun translateExpression(expr: TypecastExpression) {
translateExpression(expr.expression) translateExpression(expr.expression)
when(expr.expression.inferType(program)!!) { when(expr.expression.inferType(program).typeOrElse(DataType.STRUCT)) {
DataType.UBYTE -> { DataType.UBYTE -> {
when(expr.type) { when(expr.type) {
DataType.UBYTE, DataType.BYTE -> {} DataType.UBYTE, DataType.BYTE -> {}
@@ -1959,7 +1973,7 @@ $endLabel""")
private fun translateExpression(expr: IdentifierReference) { private fun translateExpression(expr: IdentifierReference) {
val varname = asmIdentifierName(expr) val varname = asmIdentifierName(expr)
when(expr.inferType(program)!!) { when(expr.inferType(program).typeOrElse(DataType.STRUCT)) {
DataType.UBYTE, DataType.BYTE -> { DataType.UBYTE, DataType.BYTE -> {
out(" lda $varname | sta $ESTACK_LO_HEX,x | dex") out(" lda $varname | sta $ESTACK_LO_HEX,x | dex")
} }
@@ -1979,9 +1993,13 @@ $endLabel""")
private val powerOfTwos = setOf(0,1,2,4,8,16,32,64,128,256) private val powerOfTwos = setOf(0,1,2,4,8,16,32,64,128,256)
private fun translateExpression(expr: BinaryExpression) { private fun translateExpression(expr: BinaryExpression) {
val leftDt = expr.left.inferType(program)!! val leftIDt = expr.left.inferType(program)
val rightDt = expr.right.inferType(program)!! val rightIDt = expr.right.inferType(program)
if(!leftIDt.isKnown || !rightIDt.isKnown)
throw AssemblyError("can't infer type of both expression operands")
val leftDt = leftIDt.typeOrElse(DataType.STRUCT)
val rightDt = rightIDt.typeOrElse(DataType.STRUCT)
// see if we can apply some optimized routines // see if we can apply some optimized routines
when(expr.operator) { when(expr.operator) {
">>" -> { ">>" -> {
@@ -2075,7 +2093,7 @@ $endLabel""")
private fun translateExpression(expr: PrefixExpression) { private fun translateExpression(expr: PrefixExpression) {
translateExpression(expr.expression) translateExpression(expr.expression)
val type = expr.inferType(program) val type = expr.inferType(program).typeOrElse(DataType.STRUCT)
when(expr.operator) { when(expr.operator) {
"+" -> {} "+" -> {}
"-" -> { "-" -> {
@@ -2208,7 +2226,7 @@ $endLabel""")
} }
targetIdent!=null -> { targetIdent!=null -> {
val targetName = asmIdentifierName(targetIdent) val targetName = asmIdentifierName(targetIdent)
val targetDt = targetIdent.inferType(program)!! val targetDt = targetIdent.inferType(program).typeOrElse(DataType.STRUCT)
when(targetDt) { when(targetDt) {
DataType.UBYTE, DataType.BYTE -> { DataType.UBYTE, DataType.BYTE -> {
out(" inx | lda $ESTACK_LO_HEX,x | sta $targetName") out(" inx | lda $ESTACK_LO_HEX,x | sta $targetName")
@@ -2305,10 +2323,10 @@ $endLabel""")
targetArrayIdx!=null -> { targetArrayIdx!=null -> {
val index = targetArrayIdx.arrayspec.index val index = targetArrayIdx.arrayspec.index
val targetName = asmIdentifierName(targetArrayIdx.identifier) val targetName = asmIdentifierName(targetArrayIdx.identifier)
val arrayDt = targetArrayIdx.identifier.inferType(program)!!
out(" lda $sourceName | sta $ESTACK_LO_HEX,x | lda $sourceName+1 | sta $ESTACK_HI_HEX,x | dex") out(" lda $sourceName | sta $ESTACK_LO_HEX,x | lda $sourceName+1 | sta $ESTACK_HI_HEX,x | dex")
translateExpression(index) translateExpression(index)
out(" inx | lda $ESTACK_LO_HEX,x") out(" inx | lda $ESTACK_LO_HEX,x")
val arrayDt = targetArrayIdx.identifier.inferType(program).typeOrElse(DataType.STRUCT)
popAndWriteArrayvalueWithIndexA(arrayDt, targetName) popAndWriteArrayvalueWithIndexA(arrayDt, targetName)
} }
else -> TODO("assign wordvar to $target") else -> TODO("assign wordvar to $target")
@@ -2364,7 +2382,7 @@ $endLabel""")
targetArrayIdx!=null -> { targetArrayIdx!=null -> {
val index = targetArrayIdx.arrayspec.index val index = targetArrayIdx.arrayspec.index
val targetName = asmIdentifierName(targetArrayIdx.identifier) val targetName = asmIdentifierName(targetArrayIdx.identifier)
val arrayDt = targetArrayIdx.identifier.inferType(program)!! val arrayDt = targetArrayIdx.identifier.inferType(program).typeOrElse(DataType.STRUCT)
out(" lda $sourceName | sta $ESTACK_LO_HEX,x | dex") out(" lda $sourceName | sta $ESTACK_LO_HEX,x | dex")
translateExpression(index) translateExpression(index)
out(" inx | lda $ESTACK_LO_HEX,x") out(" inx | lda $ESTACK_LO_HEX,x")

View File

@@ -43,7 +43,7 @@ internal class BuiltinFunctionsAsmGen(private val program: Program,
when(functionName) { when(functionName) {
"msb" -> { "msb" -> {
val arg = fcall.arglist.single() val arg = fcall.arglist.single()
if(arg.inferType(program) !in WordDatatypes) if(arg.inferType(program).typeOrElse(DataType.STRUCT) !in WordDatatypes)
throw AssemblyError("msb required word argument") throw AssemblyError("msb required word argument")
if(arg is NumericLiteralValue) if(arg is NumericLiteralValue)
throw AssemblyError("should have been const-folded") throw AssemblyError("should have been const-folded")
@@ -61,8 +61,8 @@ internal class BuiltinFunctionsAsmGen(private val program: Program,
} }
"abs" -> { "abs" -> {
translateFunctionArguments(fcall.arglist, func) translateFunctionArguments(fcall.arglist, func)
val dt = fcall.arglist.single().inferType(program)!! val dt = fcall.arglist.single().inferType(program)
when (dt) { when (dt.typeOrElse(DataType.STRUCT)) {
in ByteDatatypes -> asmgen.out(" jsr prog8_lib.abs_b") in ByteDatatypes -> asmgen.out(" jsr prog8_lib.abs_b")
in WordDatatypes -> asmgen.out(" jsr prog8_lib.abs_w") in WordDatatypes -> asmgen.out(" jsr prog8_lib.abs_w")
DataType.FLOAT -> asmgen.out(" jsr c64flt.abs_f") DataType.FLOAT -> asmgen.out(" jsr c64flt.abs_f")
@@ -86,8 +86,8 @@ internal class BuiltinFunctionsAsmGen(private val program: Program,
} }
"min", "max", "sum" -> { "min", "max", "sum" -> {
outputPushAddressAndLenghtOfArray(fcall.arglist[0]) outputPushAddressAndLenghtOfArray(fcall.arglist[0])
val dt = fcall.arglist.single().inferType(program)!! val dt = fcall.arglist.single().inferType(program)
when(dt) { when(dt.typeOrElse(DataType.STRUCT)) {
DataType.ARRAY_UB, DataType.STR_S, DataType.STR -> asmgen.out(" jsr prog8_lib.func_${functionName}_ub") DataType.ARRAY_UB, DataType.STR_S, DataType.STR -> asmgen.out(" jsr prog8_lib.func_${functionName}_ub")
DataType.ARRAY_B -> asmgen.out(" jsr prog8_lib.func_${functionName}_b") DataType.ARRAY_B -> asmgen.out(" jsr prog8_lib.func_${functionName}_b")
DataType.ARRAY_UW -> asmgen.out(" jsr prog8_lib.func_${functionName}_uw") DataType.ARRAY_UW -> asmgen.out(" jsr prog8_lib.func_${functionName}_uw")
@@ -98,8 +98,8 @@ internal class BuiltinFunctionsAsmGen(private val program: Program,
} }
"any", "all" -> { "any", "all" -> {
outputPushAddressAndLenghtOfArray(fcall.arglist[0]) outputPushAddressAndLenghtOfArray(fcall.arglist[0])
val dt = fcall.arglist.single().inferType(program)!! val dt = fcall.arglist.single().inferType(program)
when(dt) { when(dt.typeOrElse(DataType.STRUCT)) {
DataType.ARRAY_B, DataType.ARRAY_UB, DataType.STR_S, DataType.STR -> asmgen.out(" jsr prog8_lib.func_${functionName}_b") DataType.ARRAY_B, DataType.ARRAY_UB, DataType.STR_S, DataType.STR -> asmgen.out(" jsr prog8_lib.func_${functionName}_b")
DataType.ARRAY_UW, DataType.ARRAY_W -> asmgen.out(" jsr prog8_lib.func_${functionName}_w") DataType.ARRAY_UW, DataType.ARRAY_W -> asmgen.out(" jsr prog8_lib.func_${functionName}_w")
DataType.ARRAY_F -> asmgen.out(" jsr c64flt.func_${functionName}_f") DataType.ARRAY_F -> asmgen.out(" jsr c64flt.func_${functionName}_f")
@@ -134,8 +134,8 @@ internal class BuiltinFunctionsAsmGen(private val program: Program,
"lsl" -> { "lsl" -> {
// in-place // in-place
val what = fcall.arglist.single() val what = fcall.arglist.single()
val dt = what.inferType(program)!! val dt = what.inferType(program)
when(dt) { when(dt.typeOrElse(DataType.STRUCT)) {
in ByteDatatypes -> { in ByteDatatypes -> {
when(what) { when(what) {
is RegisterExpr -> { is RegisterExpr -> {
@@ -168,8 +168,8 @@ internal class BuiltinFunctionsAsmGen(private val program: Program,
"lsr" -> { "lsr" -> {
// in-place // in-place
val what = fcall.arglist.single() val what = fcall.arglist.single()
val dt = what.inferType(program)!! val dt = what.inferType(program)
when(dt) { when(dt.typeOrElse(DataType.STRUCT)) {
DataType.UBYTE -> { DataType.UBYTE -> {
when(what) { when(what) {
is RegisterExpr -> { is RegisterExpr -> {
@@ -208,8 +208,8 @@ internal class BuiltinFunctionsAsmGen(private val program: Program,
"rol" -> { "rol" -> {
// in-place // in-place
val what = fcall.arglist.single() val what = fcall.arglist.single()
val dt = what.inferType(program)!! val dt = what.inferType(program)
when(dt) { when(dt.typeOrElse(DataType.STRUCT)) {
DataType.UBYTE -> { DataType.UBYTE -> {
TODO("rol ubyte") TODO("rol ubyte")
} }
@@ -222,8 +222,8 @@ internal class BuiltinFunctionsAsmGen(private val program: Program,
"rol2" -> { "rol2" -> {
// in-place // in-place
val what = fcall.arglist.single() val what = fcall.arglist.single()
val dt = what.inferType(program)!! val dt = what.inferType(program)
when(dt) { when(dt.typeOrElse(DataType.STRUCT)) {
DataType.UBYTE -> { DataType.UBYTE -> {
TODO("rol2 ubyte") TODO("rol2 ubyte")
} }
@@ -236,8 +236,8 @@ internal class BuiltinFunctionsAsmGen(private val program: Program,
"ror" -> { "ror" -> {
// in-place // in-place
val what = fcall.arglist.single() val what = fcall.arglist.single()
val dt = what.inferType(program)!! val dt = what.inferType(program)
when(dt) { when(dt.typeOrElse(DataType.STRUCT)) {
DataType.UBYTE -> { DataType.UBYTE -> {
TODO("ror ubyte") TODO("ror ubyte")
} }
@@ -250,8 +250,8 @@ internal class BuiltinFunctionsAsmGen(private val program: Program,
"ror2" -> { "ror2" -> {
// in-place // in-place
val what = fcall.arglist.single() val what = fcall.arglist.single()
val dt = what.inferType(program)!! val dt = what.inferType(program)
when(dt) { when(dt.typeOrElse(DataType.STRUCT)) {
DataType.UBYTE -> { DataType.UBYTE -> {
TODO("ror2 ubyte") TODO("ror2 ubyte")
} }

View File

@@ -112,25 +112,29 @@ val BuiltinFunctions = mapOf(
) )
fun builtinFunctionReturnType(function: String, args: List<Expression>, program: Program): DataType? { fun builtinFunctionReturnType(function: String, args: List<Expression>, program: Program): InferredTypes.InferredType {
fun datatypeFromIterableArg(arglist: Expression): DataType { fun datatypeFromIterableArg(arglist: Expression): DataType {
if(arglist is ArrayLiteralValue) { if(arglist is ArrayLiteralValue) {
if(arglist.type== DataType.ARRAY_UB || arglist.type== DataType.ARRAY_UW || arglist.type== DataType.ARRAY_F) { if(arglist.type== DataType.ARRAY_UB || arglist.type== DataType.ARRAY_UW || arglist.type== DataType.ARRAY_F) {
val dt = arglist.value.map {it.inferType(program)} val dt = arglist.value.map {it.inferType(program)}
if(dt.any { it!= DataType.UBYTE && it!= DataType.UWORD && it!= DataType.FLOAT}) { if(dt.any { !(it istype DataType.UBYTE) && !(it istype DataType.UWORD) && !(it istype DataType.FLOAT)}) {
throw FatalAstException("fuction $function only accepts arraysize of numeric values") throw FatalAstException("fuction $function only accepts arraysize of numeric values")
} }
if(dt.any { it== DataType.FLOAT }) return DataType.FLOAT if(dt.any { it istype DataType.FLOAT }) return DataType.FLOAT
if(dt.any { it== DataType.UWORD }) return DataType.UWORD if(dt.any { it istype DataType.UWORD }) return DataType.UWORD
return DataType.UBYTE return DataType.UBYTE
} }
} }
if(arglist is IdentifierReference) { if(arglist is IdentifierReference) {
return when(val dt = arglist.inferType(program)) { val idt = arglist.inferType(program)
in NumericDatatypes -> dt!! if(!idt.isKnown)
in StringDatatypes -> dt!! throw FatalAstException("couldn't determine type of iterable $arglist")
in ArrayDatatypes -> ArrayElementTypes.getValue(dt!!) val dt = idt.typeOrElse(DataType.STRUCT)
return when(dt) {
in NumericDatatypes -> dt
in StringDatatypes -> dt
in ArrayDatatypes -> ArrayElementTypes.getValue(dt)
else -> throw FatalAstException("function '$function' requires one argument which is an iterable") else -> throw FatalAstException("function '$function' requires one argument which is an iterable")
} }
} }
@@ -139,43 +143,43 @@ fun builtinFunctionReturnType(function: String, args: List<Expression>, program:
val func = BuiltinFunctions.getValue(function) val func = BuiltinFunctions.getValue(function)
if(func.returntype!=null) if(func.returntype!=null)
return func.returntype return InferredTypes.knownFor(func.returntype)
// function has return values, but the return type depends on the arguments // function has return values, but the return type depends on the arguments
return when (function) { return when (function) {
"abs" -> { "abs" -> {
val dt = args.single().inferType(program) val dt = args.single().inferType(program)
if(dt in NumericDatatypes) if(dt.typeOrElse(DataType.STRUCT) in NumericDatatypes)
return dt return dt
else else
throw FatalAstException("weird datatype passed to abs $dt") throw FatalAstException("weird datatype passed to abs $dt")
} }
"max", "min" -> { "max", "min" -> {
when(val dt = datatypeFromIterableArg(args.single())) { when(val dt = datatypeFromIterableArg(args.single())) {
in NumericDatatypes -> dt in NumericDatatypes -> InferredTypes.knownFor(dt)
in StringDatatypes -> DataType.UBYTE in StringDatatypes -> InferredTypes.knownFor(DataType.UBYTE)
in ArrayDatatypes -> ArrayElementTypes.getValue(dt) in ArrayDatatypes -> InferredTypes.knownFor(ArrayElementTypes.getValue(dt))
else -> null else -> InferredTypes.unknown()
} }
} }
"sum" -> { "sum" -> {
when(datatypeFromIterableArg(args.single())) { when(datatypeFromIterableArg(args.single())) {
DataType.UBYTE, DataType.UWORD -> DataType.UWORD DataType.UBYTE, DataType.UWORD -> InferredTypes.knownFor(DataType.UWORD)
DataType.BYTE, DataType.WORD -> DataType.WORD DataType.BYTE, DataType.WORD -> InferredTypes.knownFor(DataType.WORD)
DataType.FLOAT -> DataType.FLOAT DataType.FLOAT -> InferredTypes.knownFor(DataType.FLOAT)
DataType.ARRAY_UB, DataType.ARRAY_UW -> DataType.UWORD DataType.ARRAY_UB, DataType.ARRAY_UW -> InferredTypes.knownFor(DataType.UWORD)
DataType.ARRAY_B, DataType.ARRAY_W -> DataType.WORD DataType.ARRAY_B, DataType.ARRAY_W -> InferredTypes.knownFor(DataType.WORD)
DataType.ARRAY_F -> DataType.FLOAT DataType.ARRAY_F -> InferredTypes.knownFor(DataType.FLOAT)
in StringDatatypes -> DataType.UWORD in StringDatatypes -> InferredTypes.knownFor(DataType.UWORD)
else -> null else -> InferredTypes.unknown()
} }
} }
"len" -> { "len" -> {
// a length can be >255 so in that case, the result is an UWORD instead of an UBYTE // a length can be >255 so in that case, the result is an UWORD instead of an UBYTE
// but to avoid a lot of code duplication we simply assume UWORD in all cases for now // but to avoid a lot of code duplication we simply assume UWORD in all cases for now
return DataType.UWORD return InferredTypes.knownFor(DataType.UWORD)
} }
else -> return null else -> return InferredTypes.unknown()
} }
} }

View File

@@ -79,7 +79,7 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
errors.add(ExpressionError("range expression size doesn't match declared array size", decl.value?.position!!)) errors.add(ExpressionError("range expression size doesn't match declared array size", decl.value?.position!!))
val constRange = rangeExpr.toConstantIntegerRange() val constRange = rangeExpr.toConstantIntegerRange()
if(constRange!=null) { if(constRange!=null) {
val eltType = rangeExpr.inferType(program)!! val eltType = rangeExpr.inferType(program).typeOrElse(DataType.UBYTE)
if(eltType in ByteDatatypes) { if(eltType in ByteDatatypes) {
decl.value = ArrayLiteralValue(decl.datatype, decl.value = ArrayLiteralValue(decl.datatype,
constRange.map { NumericLiteralValue(eltType, it.toShort(), decl.value!!.position) }.toTypedArray(), constRange.map { NumericLiteralValue(eltType, it.toShort(), decl.value!!.position) }.toTypedArray(),
@@ -612,7 +612,10 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
val lv = assignment.value as? NumericLiteralValue val lv = assignment.value as? NumericLiteralValue
if(lv!=null) { if(lv!=null) {
// see if we can promote/convert a literal value to the required datatype // see if we can promote/convert a literal value to the required datatype
when(assignment.target.inferType(program, assignment)) { val idt = assignment.target.inferType(program, assignment)
if(!idt.isKnown)
return assignment
when(idt.typeOrElse(DataType.STRUCT)) {
DataType.UWORD -> { DataType.UWORD -> {
// we can convert to UWORD: any UBYTE, BYTE/WORD that are >=0, FLOAT that's an integer 0..65535, // we can convert to UWORD: any UBYTE, BYTE/WORD that are >=0, FLOAT that's an integer 0..65535,
if(lv.type== DataType.UBYTE) if(lv.type== DataType.UBYTE)

View File

@@ -1,10 +1,7 @@
package prog8.optimizer package prog8.optimizer
import prog8.ast.Program import prog8.ast.Program
import prog8.ast.base.AstException import prog8.ast.base.*
import prog8.ast.base.DataType
import prog8.ast.base.IntegerDatatypes
import prog8.ast.base.NumericDatatypes
import prog8.ast.expressions.* import prog8.ast.expressions.*
import prog8.ast.processing.IAstModifyingVisitor import prog8.ast.processing.IAstModifyingVisitor
import prog8.ast.statements.Assignment import prog8.ast.statements.Assignment
@@ -136,9 +133,14 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
val constTrue = NumericLiteralValue.fromBoolean(true, expr.position) val constTrue = NumericLiteralValue.fromBoolean(true, expr.position)
val constFalse = NumericLiteralValue.fromBoolean(false, expr.position) val constFalse = NumericLiteralValue.fromBoolean(false, expr.position)
val leftDt = expr.left.inferType(program) val leftIDt = expr.left.inferType(program)
val rightDt = expr.right.inferType(program) val rightIDt = expr.right.inferType(program)
if (leftDt != null && rightDt != null && leftDt != rightDt) { if(!leftIDt.isKnown || !rightIDt.isKnown)
throw FatalAstException("can't determine datatype of both expression operands $expr")
val leftDt = leftIDt.typeOrElse(DataType.STRUCT)
val rightDt = rightIDt.typeOrElse(DataType.STRUCT)
if (leftDt != rightDt) {
// try to convert a datatype into the other (where ddd // try to convert a datatype into the other (where ddd
if (adjustDatatypes(expr, leftVal, leftDt, rightVal, rightDt)) { if (adjustDatatypes(expr, leftVal, leftDt, rightVal, rightDt)) {
optimizationsDone++ optimizationsDone++
@@ -226,7 +228,7 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
val x = expr.right val x = expr.right
val y = determineY(x, leftBinExpr) val y = determineY(x, leftBinExpr)
if(y!=null) { if(y!=null) {
val yPlus1 = BinaryExpression(y, "+", NumericLiteralValue(leftDt!!, 1, y.position), y.position) val yPlus1 = BinaryExpression(y, "+", NumericLiteralValue(leftDt, 1, y.position), y.position)
return BinaryExpression(x, "*", yPlus1, x.position) return BinaryExpression(x, "*", yPlus1, x.position)
} }
} else { } else {
@@ -235,7 +237,7 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
val x = expr.right val x = expr.right
val y = determineY(x, leftBinExpr) val y = determineY(x, leftBinExpr)
if(y!=null) { if(y!=null) {
val yMinus1 = BinaryExpression(y, "-", NumericLiteralValue(leftDt!!, 1, y.position), y.position) val yMinus1 = BinaryExpression(y, "-", NumericLiteralValue(leftDt, 1, y.position), y.position)
return BinaryExpression(x, "*", yMinus1, x.position) return BinaryExpression(x, "*", yMinus1, x.position)
} }
} }
@@ -590,7 +592,7 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
"%" -> { "%" -> {
if (cv == 1.0) { if (cv == 1.0) {
optimizationsDone++ optimizationsDone++
return NumericLiteralValue(expr.inferType(program)!!, 0, expr.position) return NumericLiteralValue(expr.inferType(program).typeOrElse(DataType.STRUCT), 0, expr.position)
} else if (cv == 2.0) { } else if (cv == 2.0) {
optimizationsDone++ optimizationsDone++
expr.operator = "&" expr.operator = "&"
@@ -613,7 +615,10 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
// right value is a constant, see if we can optimize // right value is a constant, see if we can optimize
val rightConst: NumericLiteralValue = rightVal val rightConst: NumericLiteralValue = rightVal
val cv = rightConst.number.toDouble() val cv = rightConst.number.toDouble()
val leftDt = expr.left.inferType(program) val leftIDt = expr.left.inferType(program)
if(!leftIDt.isKnown)
return expr
val leftDt = leftIDt.typeOrElse(DataType.STRUCT)
when(cv) { when(cv) {
-1.0 -> { -1.0 -> {
// '/' -> -left // '/' -> -left
@@ -701,7 +706,7 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
return expr.left return expr.left
} }
2.0, 4.0, 8.0, 16.0, 32.0, 64.0, 128.0, 256.0, 512.0, 1024.0, 2048.0, 4096.0, 8192.0, 16384.0, 32768.0, 65536.0 -> { 2.0, 4.0, 8.0, 16.0, 32.0, 64.0, 128.0, 256.0, 512.0, 1024.0, 2048.0, 4096.0, 8192.0, 16384.0, 32768.0, 65536.0 -> {
if(leftValue.inferType(program) in IntegerDatatypes) { if(leftValue.inferType(program).typeOrElse(DataType.STRUCT) in IntegerDatatypes) {
// times a power of two => shift left // times a power of two => shift left
optimizationsDone++ optimizationsDone++
val numshifts = log2(cv).toInt() val numshifts = log2(cv).toInt()
@@ -709,7 +714,7 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
} }
} }
-2.0, -4.0, -8.0, -16.0, -32.0, -64.0, -128.0, -256.0, -512.0, -1024.0, -2048.0, -4096.0, -8192.0, -16384.0, -32768.0, -65536.0 -> { -2.0, -4.0, -8.0, -16.0, -32.0, -64.0, -128.0, -256.0, -512.0, -1024.0, -2048.0, -4096.0, -8192.0, -16384.0, -32768.0, -65536.0 -> {
if(leftValue.inferType(program) in IntegerDatatypes) { if(leftValue.inferType(program).typeOrElse(DataType.STRUCT) in IntegerDatatypes) {
// times a negative power of two => negate, then shift left // times a negative power of two => negate, then shift left
optimizationsDone++ optimizationsDone++
val numshifts = log2(-cv).toInt() val numshifts = log2(-cv).toInt()

View File

@@ -10,6 +10,7 @@ import prog8.ast.processing.IAstModifyingVisitor
import prog8.ast.processing.IAstVisitor import prog8.ast.processing.IAstVisitor
import prog8.ast.statements.* import prog8.ast.statements.*
import prog8.compiler.target.c64.Petscii import prog8.compiler.target.c64.Petscii
import prog8.compiler.target.c64.codegen2.AssemblyError
import prog8.functions.BuiltinFunctions import prog8.functions.BuiltinFunctions
import kotlin.math.floor import kotlin.math.floor
@@ -422,7 +423,10 @@ internal class StatementOptimizer(private val program: Program) : IAstModifyingV
return NopStatement.insteadOf(assignment) return NopStatement.insteadOf(assignment)
} }
} }
val targetDt = assignment.target.inferType(program, assignment) val targetIDt = assignment.target.inferType(program, assignment)
if(!targetIDt.isKnown)
throw AssemblyError("can't infer type of assignment target")
val targetDt = targetIDt.typeOrElse(DataType.STRUCT)
val bexpr=assignment.value as? BinaryExpression val bexpr=assignment.value as? BinaryExpression
if(bexpr!=null) { if(bexpr!=null) {
val cv = bexpr.right.constValue(program)?.number?.toDouble() val cv = bexpr.right.constValue(program)?.number?.toDouble()

View File

@@ -412,9 +412,11 @@ class AstVm(val program: Program) {
stmt.target.arrayindexed != null -> { stmt.target.arrayindexed != null -> {
val arrayvar = stmt.target.arrayindexed!!.identifier.targetVarDecl(program.namespace)!! val arrayvar = stmt.target.arrayindexed!!.identifier.targetVarDecl(program.namespace)!!
val arrayvalue = runtimeVariables.get(arrayvar.definingScope(), arrayvar.name) val arrayvalue = runtimeVariables.get(arrayvar.definingScope(), arrayvar.name)
val elementType = stmt.target.arrayindexed!!.inferType(program)!!
val index = evaluate(stmt.target.arrayindexed!!.arrayspec.index, evalCtx).integerValue() val index = evaluate(stmt.target.arrayindexed!!.arrayspec.index, evalCtx).integerValue()
var value = RuntimeValue(elementType, arrayvalue.array!![index].toInt()) val elementType = stmt.target.arrayindexed!!.inferType(program)
if(!elementType.isKnown)
throw VmExecutionException("unknown/void elt type")
var value = RuntimeValue(elementType.typeOrElse(DataType.BYTE), arrayvalue.array!![index].toInt())
value = when { value = when {
stmt.operator == "++" -> value.inc() stmt.operator == "++" -> value.inc()
stmt.operator == "--" -> value.dec() stmt.operator == "--" -> value.dec()
@@ -472,7 +474,8 @@ class AstVm(val program: Program) {
loopvarDt = DataType.UBYTE loopvarDt = DataType.UBYTE
loopvar = IdentifierReference(listOf(stmt.loopRegister.name), stmt.position) loopvar = IdentifierReference(listOf(stmt.loopRegister.name), stmt.position)
} else { } else {
loopvarDt = stmt.loopVar!!.inferType(program)!! val dt = stmt.loopVar!!.inferType(program)
loopvarDt = dt.typeOrElse(DataType.UBYTE)
loopvar = stmt.loopVar!! loopvar = stmt.loopVar!!
} }
val iterator = iterable.iterator() val iterator = iterable.iterator()
@@ -619,8 +622,10 @@ class AstVm(val program: Program) {
else { else {
val address = (vardecl.value as NumericLiteralValue).number.toInt() val address = (vardecl.value as NumericLiteralValue).number.toInt()
val index = evaluate(targetArrayIndexed.arrayspec.index, evalCtx).integerValue() val index = evaluate(targetArrayIndexed.arrayspec.index, evalCtx).integerValue()
val elementType = targetArrayIndexed.inferType(program)!! val elementType = targetArrayIndexed.inferType(program)
when(elementType) { if(!elementType.isKnown)
throw VmExecutionException("unknown/void array elt type $targetArrayIndexed")
when(elementType.typeOrElse(DataType.UBYTE)) {
DataType.UBYTE -> mem.setUByte(address+index, value.byteval!!) DataType.UBYTE -> mem.setUByte(address+index, value.byteval!!)
DataType.BYTE -> mem.setSByte(address+index, value.byteval!!) DataType.BYTE -> mem.setSByte(address+index, value.byteval!!)
DataType.UWORD -> mem.setUWord(address+index*2, value.wordval!!) DataType.UWORD -> mem.setUWord(address+index*2, value.wordval!!)

View File

@@ -12,7 +12,6 @@ import prog8.ast.statements.Subroutine
import prog8.ast.statements.VarDecl import prog8.ast.statements.VarDecl
import prog8.vm.RuntimeValue import prog8.vm.RuntimeValue
import prog8.vm.RuntimeValueRange import prog8.vm.RuntimeValueRange
import kotlin.math.abs
typealias BuiltinfunctionCaller = (name: String, args: List<RuntimeValue>, flags: StatusFlags) -> RuntimeValue? typealias BuiltinfunctionCaller = (name: String, args: List<RuntimeValue>, flags: StatusFlags) -> RuntimeValue?
@@ -147,24 +146,22 @@ fun evaluate(expr: Expression, ctx: EvalContext): RuntimeValue {
} }
is RangeExpr -> { is RangeExpr -> {
val cRange = expr.toConstantIntegerRange() val cRange = expr.toConstantIntegerRange()
if(cRange!=null) if(cRange!=null) {
return RuntimeValueRange(expr.inferType(ctx.program)!!, cRange) val dt = expr.inferType(ctx.program)
if(dt.isKnown)
return RuntimeValueRange(dt.typeOrElse(DataType.UBYTE), cRange)
else
throw VmExecutionException("couldn't determine datatype")
}
val fromVal = evaluate(expr.from, ctx).integerValue() val fromVal = evaluate(expr.from, ctx).integerValue()
val toVal = evaluate(expr.to, ctx).integerValue() val toVal = evaluate(expr.to, ctx).integerValue()
val stepVal = evaluate(expr.step, ctx).integerValue() val stepVal = evaluate(expr.step, ctx).integerValue()
val range = when { val range = makeRange(fromVal, toVal, stepVal)
fromVal <= toVal -> when { val dt = expr.inferType(ctx.program)
stepVal <= 0 -> IntRange.EMPTY if(dt.isKnown)
stepVal == 1 -> fromVal..toVal return RuntimeValueRange(dt.typeOrElse(DataType.UBYTE), range)
else -> fromVal..toVal step stepVal else
} throw VmExecutionException("couldn't determine datatype")
else -> when {
stepVal >= 0 -> IntRange.EMPTY
stepVal == -1 -> fromVal downTo toVal
else -> fromVal downTo toVal step abs(stepVal)
}
}
return RuntimeValueRange(expr.inferType(ctx.program)!!, range)
} }
else -> { else -> {
throw VmExecutionException("unimplemented expression node $expr") throw VmExecutionException("unimplemented expression node $expr")