improved handling of inferredType

This commit is contained in:
Irmen de Jong 2019-08-14 02:25:27 +02:00
parent b64d611e02
commit 47297f7e31
16 changed files with 359 additions and 240 deletions

View File

@ -484,8 +484,8 @@ internal class Compiler(private val program: Program) {
translatePrefixOperator(expr.operator, expr.expression.inferType(program))
}
is BinaryExpression -> {
val leftDt = expr.left.inferType(program)!!
val rightDt = expr.right.inferType(program)!!
val leftDt = expr.left.inferType(program)
val rightDt = expr.right.inferType(program)
val (commonDt, _) = expr.commonDatatype(leftDt, rightDt, expr.left, expr.right)
translate(expr.left)
if(leftDt!=commonDt)
@ -654,7 +654,7 @@ internal class Compiler(private val program: Program) {
// cast type if needed
if(builtinFuncParams!=null) {
val paramDts = builtinFuncParams[index].possibleDatatypes
val argDt = arg.inferType(program)!!
val argDt = arg.inferType(program)
if(argDt !in paramDts) {
for(paramDt in paramDts.sorted())
if(tryConvertType(argDt, paramDt))
@ -827,8 +827,8 @@ internal class Compiler(private val program: Program) {
// swap(x,y) is treated differently, it's not a normal function call
if (args.size != 2)
throw AstException("swap requires 2 arguments")
val dt1 = args[0].inferType(program)!!
val dt2 = args[1].inferType(program)!!
val dt1 = args[0].inferType(program)
val dt2 = args[1].inferType(program)
if (dt1 != dt2)
throw AstException("swap requires 2 args of identical type")
if (args[0].constValue(program) != null || args[1].constValue(program) != null)
@ -861,7 +861,7 @@ internal class Compiler(private val program: Program) {
// (subroutine arguments are not passed via the stack!)
for (arg in arguments.zip(subroutine.parameters)) {
translate(arg.first)
convertType(arg.first.inferType(program)!!, arg.second.type) // convert types of arguments to required parameter type
convertType(arg.first.inferType(program), arg.second.type) // convert types of arguments to required parameter type
val opcode = opcodePopvar(arg.second.type)
prog.instr(opcode, callLabel = subroutine.scopedname + "." + arg.second.name)
}

View File

@ -30,7 +30,7 @@ sealed class Expression: Node {
abstract fun accept(visitor: IAstModifyingVisitor): Expression
abstract fun accept(visitor: IAstVisitor)
abstract fun referencesIdentifiers(vararg name: String): Boolean // todo: remove this and add identifier usage tracking into CallGraph instead
abstract fun inferType(program: Program): DataType?
abstract fun inferType(program: Program): InferredTypes.InferredType
infix fun isSameAs(other: Expression): Boolean {
if(this===other)
@ -68,7 +68,7 @@ class PrefixExpression(val operator: String, var expression: Expression, overrid
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun referencesIdentifiers(vararg name: String) = expression.referencesIdentifiers(*name)
override fun inferType(program: Program): DataType? = expression.inferType(program)
override fun inferType(program: Program): InferredTypes.InferredType = expression.inferType(program)
override fun toString(): String {
return "Prefix($operator $expression)"
@ -94,15 +94,22 @@ class BinaryExpression(var left: Expression, var operator: String, var right: Ex
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun referencesIdentifiers(vararg name: String) = left.referencesIdentifiers(*name) || right.referencesIdentifiers(*name)
override fun inferType(program: Program): DataType? {
override fun inferType(program: Program): InferredTypes.InferredType {
val leftDt = left.inferType(program)
val rightDt = right.inferType(program)
return when (operator) {
"+", "-", "*", "**", "%", "/" -> if (leftDt == null || rightDt == null) null else {
try {
commonDatatype(leftDt, rightDt, null, null).first
} catch (x: FatalAstException) {
null
"+", "-", "*", "**", "%", "/" -> {
if (!leftDt.isKnown || !rightDt.isKnown)
InferredTypes.unknown()
else {
try {
InferredTypes.knownFor(commonDatatype(
leftDt.typeOrElse(DataType.BYTE),
rightDt.typeOrElse(DataType.BYTE),
null, null).first)
} catch (x: FatalAstException) {
InferredTypes.unknown()
}
}
}
"&" -> leftDt
@ -111,7 +118,7 @@ class BinaryExpression(var left: Expression, var operator: String, var right: Ex
"and", "or", "xor",
"<", ">",
"<=", ">=",
"==", "!=" -> DataType.UBYTE
"==", "!=" -> InferredTypes.knownFor(DataType.UBYTE)
"<<", ">>" -> leftDt
else -> throw FatalAstException("resulting datatype check for invalid operator $operator")
}
@ -191,17 +198,16 @@ class ArrayIndexedExpression(var identifier: IdentifierReference,
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun referencesIdentifiers(vararg name: String) = identifier.referencesIdentifiers(*name)
override fun inferType(program: Program): DataType? {
override fun inferType(program: Program): InferredTypes.InferredType {
val target = identifier.targetStatement(program.namespace)
if (target is VarDecl) {
return when (target.datatype) {
in NumericDatatypes -> null
in StringDatatypes -> DataType.UBYTE
in ArrayDatatypes -> ArrayElementTypes[target.datatype]
else -> throw FatalAstException("invalid dt")
in StringDatatypes -> InferredTypes.knownFor(DataType.UBYTE)
in ArrayDatatypes -> InferredTypes.knownFor(ArrayElementTypes.getValue(target.datatype))
else -> InferredTypes.unknown()
}
}
return null
return InferredTypes.unknown()
}
override fun toString(): String {
@ -220,7 +226,7 @@ class TypecastExpression(var expression: Expression, var type: DataType, val imp
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun referencesIdentifiers(vararg name: String) = expression.referencesIdentifiers(*name)
override fun inferType(program: Program): DataType? = type
override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(type)
override fun constValue(program: Program): NumericLiteralValue? {
val cv = expression.constValue(program) ?: return null
return cv.cast(type)
@ -243,7 +249,7 @@ data class AddressOf(var identifier: IdentifierReference, override val position:
override fun constValue(program: Program): NumericLiteralValue? = null
override fun referencesIdentifiers(vararg name: String) = false
override fun inferType(program: Program) = DataType.UWORD
override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(DataType.UWORD)
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
}
@ -259,7 +265,7 @@ class DirectMemoryRead(var addressExpression: Expression, override val position:
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun referencesIdentifiers(vararg name: String) = false
override fun inferType(program: Program): DataType? = DataType.UBYTE
override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(DataType.UBYTE)
override fun constValue(program: Program): NumericLiteralValue? = null
override fun toString(): String {
@ -315,7 +321,7 @@ class NumericLiteralValue(val type: DataType, // only numerical types allowed
override fun toString(): String = "NumericLiteral(${type.name}:$number)"
override fun inferType(program: Program) = type
override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(type)
override fun hashCode(): Int = Objects.hash(type, number)
@ -399,7 +405,7 @@ class StructLiteralValue(var values: List<Expression>,
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun referencesIdentifiers(vararg name: String) = values.any { it.referencesIdentifiers(*name) }
override fun inferType(program: Program) = DataType.STRUCT
override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(DataType.STRUCT)
override fun toString(): String {
return "struct{ ${values.joinToString(", ")} }"
@ -423,7 +429,7 @@ class StringLiteralValue(val type: DataType, // only string types
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun toString(): String = "'${escape(value)}'"
override fun inferType(program: Program) = type
override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(type)
operator fun compareTo(other: StringLiteralValue): Int = value.compareTo(other.value)
override fun hashCode(): Int = Objects.hash(value, type)
override fun equals(other: Any?): Boolean {
@ -458,7 +464,7 @@ class ArrayLiteralValue(val type: DataType, // only array types
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun toString(): String = "$value"
override fun inferType(program: Program) = type
override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(type)
operator fun compareTo(other: ArrayLiteralValue): Int = throw ExpressionError("cannot order compare arrays", position)
override fun hashCode(): Int = Objects.hash(value, type)
override fun equals(other: Any?): Boolean {
@ -538,18 +544,18 @@ class RangeExpr(var from: Expression,
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun referencesIdentifiers(vararg name: String): Boolean = from.referencesIdentifiers(*name) || to.referencesIdentifiers(*name)
override fun inferType(program: Program): DataType? {
override fun inferType(program: Program): InferredTypes.InferredType {
val fromDt=from.inferType(program)
val toDt=to.inferType(program)
return when {
fromDt==null || toDt==null -> null
fromDt== DataType.UBYTE && toDt== DataType.UBYTE -> DataType.ARRAY_UB
fromDt== DataType.UWORD && toDt== DataType.UWORD -> DataType.ARRAY_UW
fromDt== DataType.STR && toDt== DataType.STR -> DataType.STR
fromDt== DataType.STR_S && toDt== DataType.STR_S -> DataType.STR_S
fromDt== DataType.WORD || toDt== DataType.WORD -> DataType.ARRAY_W
fromDt== DataType.BYTE || toDt== DataType.BYTE -> DataType.ARRAY_B
else -> DataType.ARRAY_UB
!fromDt.isKnown || !toDt.isKnown -> InferredTypes.unknown()
fromDt istype DataType.UBYTE && toDt istype DataType.UBYTE -> InferredTypes.knownFor(DataType.ARRAY_UB)
fromDt istype DataType.UWORD && toDt istype DataType.UWORD -> InferredTypes.knownFor(DataType.ARRAY_UW)
fromDt istype DataType.STR && toDt istype DataType.STR -> InferredTypes.knownFor(DataType.STR)
fromDt istype DataType.STR_S && toDt istype DataType.STR_S -> InferredTypes.knownFor(DataType.STR_S)
fromDt istype DataType.WORD || toDt istype DataType.WORD -> InferredTypes.knownFor(DataType.ARRAY_W)
fromDt istype DataType.BYTE || toDt istype DataType.BYTE -> InferredTypes.knownFor(DataType.ARRAY_B)
else -> InferredTypes.knownFor(DataType.ARRAY_UB)
}
}
override fun toString(): String {
@ -583,17 +589,21 @@ class RangeExpr(var from: Expression,
toVal = toLv.number.toInt()
}
val stepVal = (step as? NumericLiteralValue)?.number?.toInt() ?: 1
return when {
fromVal <= toVal -> when {
stepVal <= 0 -> IntRange.EMPTY
stepVal == 1 -> fromVal..toVal
else -> fromVal..toVal step stepVal
}
else -> when {
stepVal >= 0 -> IntRange.EMPTY
stepVal == -1 -> fromVal downTo toVal
else -> fromVal downTo toVal step abs(stepVal)
}
return makeRange(fromVal, toVal, stepVal)
}
}
internal fun makeRange(fromVal: Int, toVal: Int, stepVal: Int): IntProgression {
return when {
fromVal <= toVal -> when {
stepVal <= 0 -> IntRange.EMPTY
stepVal == 1 -> fromVal..toVal
else -> fromVal..toVal step stepVal
}
else -> when {
stepVal >= 0 -> IntRange.EMPTY
stepVal == -1 -> fromVal downTo toVal
else -> fromVal downTo toVal step abs(stepVal)
}
}
}
@ -613,7 +623,7 @@ class RegisterExpr(val register: Register, override val position: Position) : Ex
return "RegisterExpr(register=$register, pos=$position)"
}
override fun inferType(program: Program) = DataType.UBYTE
override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(DataType.UBYTE)
}
data class IdentifierReference(val nameInSource: List<String>, override val position: Position) : Expression() {
@ -652,10 +662,10 @@ data class IdentifierReference(val nameInSource: List<String>, override val posi
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun referencesIdentifiers(vararg name: String): Boolean = nameInSource.last() in name
override fun inferType(program: Program): DataType? {
override fun inferType(program: Program): InferredTypes.InferredType {
val targetStmt = targetStatement(program.namespace)
if(targetStmt is VarDecl) {
return targetStmt.datatype
return InferredTypes.knownFor(targetStmt.datatype)
} else {
throw FatalAstException("cannot get datatype from identifier reference ${this}, pos=$position")
}
@ -705,7 +715,7 @@ class FunctionCall(override var target: IdentifierReference,
if(withDatatypeCheck) {
val resultDt = this.inferType(program)
if(resultValue==null || resultDt == resultValue.type)
if(resultValue==null || resultDt istype resultValue.type)
return resultValue
throw FatalAstException("evaluated const expression result value doesn't match expected datatype $resultDt, pos=$position")
} else {
@ -726,27 +736,27 @@ class FunctionCall(override var target: IdentifierReference,
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun referencesIdentifiers(vararg name: String): Boolean = target.referencesIdentifiers(*name) || arglist.any{it.referencesIdentifiers(*name)}
override fun inferType(program: Program): DataType? {
override fun inferType(program: Program): InferredTypes.InferredType {
val constVal = constValue(program ,false)
if(constVal!=null)
return constVal.type
val stmt = target.targetStatement(program.namespace) ?: return null
return InferredTypes.knownFor(constVal.type)
val stmt = target.targetStatement(program.namespace) ?: return InferredTypes.unknown()
when (stmt) {
is BuiltinFunctionStatementPlaceholder -> {
if(target.nameInSource[0] == "set_carry" || target.nameInSource[0]=="set_irqd" ||
target.nameInSource[0] == "clear_carry" || target.nameInSource[0]=="clear_irqd") {
return null // these have no return value
return InferredTypes.void() // these have no return value
}
return builtinFunctionReturnType(target.nameInSource[0], this.arglist, program)
}
is Subroutine -> {
if(stmt.returntypes.isEmpty())
return null // no return value
return InferredTypes.void() // no return value
if(stmt.returntypes.size==1)
return stmt.returntypes[0]
return null // has multiple return types... so not a single resulting datatype possible
return InferredTypes.knownFor(stmt.returntypes[0])
return InferredTypes.unknown() // has multiple return types... so not a single resulting datatype possible
}
else -> return null
else -> return InferredTypes.unknown()
}
}
}

View File

@ -0,0 +1,50 @@
package prog8.ast.expressions
import prog8.ast.base.DataType
object InferredTypes {
class InferredType private constructor(val isUnknown: Boolean, val isVoid: Boolean, private var datatype: DataType?) {
init {
if(datatype!=null && (isUnknown || isVoid))
throw IllegalArgumentException("invalid combination of args")
}
val isKnown = datatype!=null
fun typeOrElse(alternative: DataType) = if(isUnknown || isVoid) alternative else datatype!!
infix fun istype(type: DataType): Boolean = if(isUnknown || isVoid) false else this.datatype==type
companion object {
fun unknown() = InferredType(isUnknown = true, isVoid = false, datatype = null)
fun void() = InferredType(isUnknown = false, isVoid = true, datatype = null)
fun known(type: DataType) = InferredType(isUnknown = false, isVoid = false, datatype = type)
}
override fun equals(other: Any?): Boolean {
if(other !is InferredType)
return false
return isVoid==other.isVoid && datatype==other.datatype
}
}
private val unknownInstance = InferredType.unknown()
private val voidInstance = InferredType.void()
private val knownInstances = mapOf(
DataType.BYTE to InferredType.known(DataType.BYTE),
DataType.BYTE to InferredType.known(DataType.BYTE),
DataType.UWORD to InferredType.known(DataType.UWORD),
DataType.WORD to InferredType.known(DataType.WORD),
DataType.FLOAT to InferredType.known(DataType.FLOAT),
DataType.STR to InferredType.known(DataType.STR),
DataType.STR_S to InferredType.known(DataType.STR_S),
DataType.ARRAY_UB to InferredType.known(DataType.ARRAY_UB),
DataType.ARRAY_B to InferredType.known(DataType.ARRAY_B),
DataType.ARRAY_UW to InferredType.known(DataType.ARRAY_UW),
DataType.ARRAY_W to InferredType.known(DataType.ARRAY_W),
DataType.ARRAY_F to InferredType.known(DataType.ARRAY_F),
DataType.STRUCT to InferredType.known(DataType.STRUCT)
)
fun void() = voidInstance
fun unknown() = unknownInstance
fun knownFor(type: DataType) = knownInstances.getValue(type)
}

View File

@ -97,14 +97,18 @@ internal class AstChecker(private val program: Program,
}
if(expectedReturnValues.size==1 && returnStmt.value!=null) {
val valueDt = returnStmt.value!!.inferType(program)
if(expectedReturnValues[0]!=valueDt)
checkResult.add(ExpressionError("type $valueDt of return value doesn't match subroutine's return type", returnStmt.value!!.position))
if(!valueDt.isKnown) {
checkResult.add(ExpressionError("return value type mismatch", returnStmt.value!!.position))
} else {
if (expectedReturnValues[0] != valueDt.typeOrElse(DataType.STRUCT))
checkResult.add(ExpressionError("type $valueDt of return value doesn't match subroutine's return type", returnStmt.value!!.position))
}
}
super.visit(returnStmt)
}
override fun visit(ifStatement: IfStatement) {
if(ifStatement.condition.inferType(program) !in IntegerDatatypes)
if(ifStatement.condition.inferType(program).typeOrElse(DataType.STRUCT) !in IntegerDatatypes)
checkResult.add(ExpressionError("condition value should be an integer type", ifStatement.condition.position))
super.visit(ifStatement)
}
@ -113,7 +117,7 @@ internal class AstChecker(private val program: Program,
if(forLoop.body.containsNoCodeNorVars())
printWarning("for loop body is empty", forLoop.position)
val iterableDt = forLoop.iterable.inferType(program)
val iterableDt = forLoop.iterable.inferType(program).typeOrElse(DataType.BYTE)
if(iterableDt !in IterableDatatypes && forLoop.iterable !is RangeExpr) {
checkResult.add(ExpressionError("can only loop over an iterable type", forLoop.position))
} else {
@ -328,7 +332,7 @@ internal class AstChecker(private val program: Program,
override fun visit(repeatLoop: RepeatLoop) {
if(repeatLoop.untilCondition.referencesIdentifiers("A", "X", "Y"))
printWarning("using a register in the loop condition is risky (it could get clobbered)", repeatLoop.untilCondition.position)
if(repeatLoop.untilCondition.inferType(program) !in IntegerDatatypes)
if(repeatLoop.untilCondition.inferType(program).typeOrElse(DataType.STRUCT) !in IntegerDatatypes)
checkResult.add(ExpressionError("condition value should be an integer type", repeatLoop.untilCondition.position))
super.visit(repeatLoop)
}
@ -336,7 +340,7 @@ internal class AstChecker(private val program: Program,
override fun visit(whileLoop: WhileLoop) {
if(whileLoop.condition.referencesIdentifiers("A", "X", "Y"))
printWarning("using a register in the loop condition is risky (it could get clobbered)", whileLoop.condition.position)
if(whileLoop.condition.inferType(program) !in IntegerDatatypes)
if(whileLoop.condition.inferType(program).typeOrElse(DataType.STRUCT) !in IntegerDatatypes)
checkResult.add(ExpressionError("condition value should be an integer type", whileLoop.condition.position))
super.visit(whileLoop)
}
@ -350,7 +354,8 @@ internal class AstChecker(private val program: Program,
if(stmt.returntypes.size>1)
checkResult.add(ExpressionError("It's not possible to store the multiple results of this asmsub call; you should use a small block of custom inline assembly for this.", assignment.value.position))
else {
if(stmt.returntypes.single()!=assignment.target.inferType(program, assignment)) {
val idt = assignment.target.inferType(program, assignment)
if(!idt.isKnown || stmt.returntypes.single()!=idt.typeOrElse(DataType.BYTE)) {
checkResult.add(ExpressionError("return type mismatch", assignment.value.position))
}
}
@ -402,7 +407,7 @@ internal class AstChecker(private val program: Program,
}
}
}
val targetDt = assignTarget.inferType(program, assignment)
val targetDt = assignTarget.inferType(program, assignment).typeOrElse(DataType.STR)
if(targetDt in StringDatatypes || targetDt in ArrayDatatypes)
checkResult.add(SyntaxError("cannot assign to a string or array type", assignTarget.position))
@ -412,13 +417,13 @@ internal class AstChecker(private val program: Program,
throw FatalAstException("augmented assignment should have been converted into normal assignment")
val targetDatatype = assignTarget.inferType(program, assignment)
if (targetDatatype != null) {
if (targetDatatype.isKnown) {
val constVal = assignment.value.constValue(program)
if (constVal != null) {
checkValueTypeAndRange(targetDatatype, constVal)
checkValueTypeAndRange(targetDatatype.typeOrElse(DataType.BYTE), constVal)
} else {
val sourceDatatype: DataType? = assignment.value.inferType(program)
if (sourceDatatype == null) {
val sourceDatatype = assignment.value.inferType(program)
if (!sourceDatatype.isKnown) {
if (assignment.value is FunctionCall) {
val targetStmt = (assignment.value as FunctionCall).target.targetStatement(program.namespace)
if (targetStmt != null)
@ -426,7 +431,8 @@ internal class AstChecker(private val program: Program,
} else
checkResult.add(ExpressionError("assignment value is invalid or has no proper datatype", assignment.value.position))
} else {
checkAssignmentCompatible(targetDatatype, assignTarget, sourceDatatype, assignment.value, assignment.position)
checkAssignmentCompatible(targetDatatype.typeOrElse(DataType.BYTE), assignTarget,
sourceDatatype.typeOrElse(DataType.BYTE), assignment.value, assignment.position)
}
}
}
@ -701,7 +707,7 @@ internal class AstChecker(private val program: Program,
override fun visit(expr: PrefixExpression) {
if(expr.operator=="-") {
val dt = expr.inferType(program)
val dt = expr.inferType(program).typeOrElse(DataType.STRUCT)
if (dt != DataType.BYTE && dt != DataType.WORD && dt != DataType.FLOAT) {
checkResult.add(ExpressionError("can only take negative of a signed number type", expr.position))
}
@ -710,8 +716,13 @@ internal class AstChecker(private val program: Program,
}
override fun visit(expr: BinaryExpression) {
val leftDt = expr.left.inferType(program)
val rightDt = expr.right.inferType(program)
val leftIDt = expr.left.inferType(program)
val rightIDt = expr.right.inferType(program)
if(!leftIDt.isKnown || !rightIDt.isKnown) {
throw FatalAstException("can't determine datatype of both expression operands $expr")
}
val leftDt = leftIDt.typeOrElse(DataType.STRUCT)
val rightDt = rightIDt.typeOrElse(DataType.STRUCT)
when(expr.operator){
"/", "%" -> {
@ -842,30 +853,30 @@ internal class AstChecker(private val program: Program,
val paramTypesForAddressOf = PassByReferenceDatatypes + DataType.UWORD
for (arg in args.withIndex().zip(func.parameters)) {
val argDt=arg.first.value.inferType(program)
if (argDt != null
&& !(argDt isAssignableTo arg.second.possibleDatatypes)
&& (argDt != DataType.UWORD || arg.second.possibleDatatypes.intersect(paramTypesForAddressOf).isEmpty())) {
if (argDt.isKnown
&& !(argDt.typeOrElse(DataType.STRUCT) isAssignableTo arg.second.possibleDatatypes)
&& (argDt.typeOrElse(DataType.STRUCT) != DataType.UWORD || arg.second.possibleDatatypes.intersect(paramTypesForAddressOf).isEmpty())) {
checkResult.add(ExpressionError("builtin function '${target.name}' argument ${arg.first.index + 1} has invalid type $argDt, expected ${arg.second.possibleDatatypes}", position))
}
}
if(target.name=="swap") {
// swap() is a bit weird because this one is translated into a operations directly, instead of being a function call
val dt1 = args[0].inferType(program)!!
val dt2 = args[1].inferType(program)!!
val dt1 = args[0].inferType(program)
val dt2 = args[1].inferType(program)
if (dt1 != dt2)
checkResult.add(ExpressionError("swap requires 2 args of identical type", position))
else if (args[0].constValue(program) != null || args[1].constValue(program) != null)
checkResult.add(ExpressionError("swap requires 2 variables, not constant value(s)", position))
else if(args[0] isSameAs args[1])
checkResult.add(ExpressionError("swap should have 2 different args", position))
else if(dt1 !in NumericDatatypes)
else if(dt1.typeOrElse(DataType.STRUCT) !in NumericDatatypes)
checkResult.add(ExpressionError("swap requires args of numerical type", position))
}
else if(target.name=="all" || target.name=="any") {
if((args[0] as? AddressOf)?.identifier?.targetVarDecl(program.namespace)?.datatype in StringDatatypes) {
checkResult.add(ExpressionError("any/all on a string is useless (is always true unless the string is empty)", position))
}
if(args[0].inferType(program) in StringDatatypes) {
if(args[0].inferType(program).typeOrElse(DataType.STR) in StringDatatypes) {
checkResult.add(ExpressionError("any/all on a string is useless (is always true unless the string is empty)", position))
}
}
@ -875,8 +886,12 @@ internal class AstChecker(private val program: Program,
checkResult.add(SyntaxError("invalid number of arguments", position))
else {
for (arg in args.withIndex().zip(target.parameters)) {
val argDt = arg.first.value.inferType(program)
if(argDt!=null && !(argDt isAssignableTo arg.second.type)) {
val argIDt = arg.first.value.inferType(program)
if(!argIDt.isKnown) {
throw FatalAstException("can't determine arg dt ${arg.first.value}")
}
val argDt=argIDt.typeOrElse(DataType.STRUCT)
if(!(argDt isAssignableTo arg.second.type)) {
// for asm subroutines having STR param it's okay to provide a UWORD (address value)
if(!(target.isAsmSubroutine && arg.second.type in StringDatatypes && argDt == DataType.UWORD))
checkResult.add(ExpressionError("subroutine '${target.name}' argument ${arg.first.index + 1} has invalid type $argDt, expected ${arg.second.type}", position))
@ -960,7 +975,7 @@ internal class AstChecker(private val program: Program,
checkResult.add(SyntaxError("indexing requires a variable to act upon", arrayIndexedExpression.position))
// check index value 0..255
val dtx = arrayIndexedExpression.arrayspec.index.inferType(program)
val dtx = arrayIndexedExpression.arrayspec.index.inferType(program).typeOrElse(DataType.STRUCT)
if(dtx!= DataType.UBYTE && dtx!= DataType.BYTE)
checkResult.add(SyntaxError("array indexing is limited to byte size 0..255", arrayIndexedExpression.position))
@ -968,7 +983,7 @@ internal class AstChecker(private val program: Program,
}
override fun visit(whenStatement: WhenStatement) {
val conditionType = whenStatement.condition.inferType(program)
val conditionType = whenStatement.condition.inferType(program).typeOrElse(DataType.STRUCT)
if(conditionType !in IntegerDatatypes)
checkResult.add(SyntaxError("when condition must be an integer value", whenStatement.position))
val choiceValues = whenStatement.choiceValues(program)
@ -987,12 +1002,14 @@ internal class AstChecker(private val program: Program,
val whenStmt = whenChoice.parent as WhenStatement
if(whenChoice.values!=null) {
val conditionType = whenStmt.condition.inferType(program)
if(!conditionType.isKnown)
throw FatalAstException("can't determine when choice datatype $whenChoice")
val constvalues = whenChoice.values!!.map { it.constValue(program) }
for(constvalue in constvalues) {
when {
constvalue == null -> checkResult.add(SyntaxError("choice value must be a constant", whenChoice.position))
constvalue.type !in IntegerDatatypes -> checkResult.add(SyntaxError("choice value must be a byte or word", whenChoice.position))
constvalue.type != conditionType -> checkResult.add(SyntaxError("choice value datatype differs from condition value", whenChoice.position))
constvalue.type != conditionType.typeOrElse(DataType.STRUCT) -> checkResult.add(SyntaxError("choice value datatype differs from condition value", whenChoice.position))
}
}
} else {
@ -1129,8 +1146,8 @@ internal class AstChecker(private val program: Program,
return err("number of values is not the same as the number of members in the struct")
for(elt in value.value.zip(struct.statements)) {
val vardecl = elt.second as VarDecl
val valuetype = elt.first.inferType(program)!!
if (!(valuetype isAssignableTo vardecl.datatype)) {
val valuetype = elt.first.inferType(program)
if (!valuetype.isKnown || !(valuetype.typeOrElse(DataType.STRUCT) isAssignableTo vardecl.datatype)) {
checkResult.add(ExpressionError("invalid struct member init value type $valuetype, expected ${vardecl.datatype}", elt.first.position))
return false
}

View File

@ -285,15 +285,16 @@ internal class AstIdentifiersChecker(private val program: Program) : IAstModifyi
}
private fun determineArrayDt(array: Array<Expression>): DataType {
val datatypesInArray = array.mapNotNull { it.inferType(program) }
if(datatypesInArray.isEmpty())
val datatypesInArray = array.map { it.inferType(program) }
if(datatypesInArray.isEmpty() || datatypesInArray.any { !it.isKnown })
throw IllegalArgumentException("can't determine type of empty array")
val dts = datatypesInArray.map { it.typeOrElse(DataType.STRUCT) }
return when {
DataType.FLOAT in datatypesInArray -> DataType.ARRAY_F
DataType.WORD in datatypesInArray -> DataType.ARRAY_W
DataType.UWORD in datatypesInArray -> DataType.ARRAY_UW
DataType.BYTE in datatypesInArray -> DataType.ARRAY_B
DataType.UBYTE in datatypesInArray -> DataType.ARRAY_UB
DataType.FLOAT in dts -> DataType.ARRAY_F
DataType.WORD in dts -> DataType.ARRAY_W
DataType.UWORD in dts -> DataType.ARRAY_UW
DataType.BYTE in dts -> DataType.ARRAY_B
DataType.UBYTE in dts -> DataType.ARRAY_UB
else -> throw IllegalArgumentException("can't determine type of array")
}
}
@ -351,13 +352,15 @@ internal class AstIdentifiersChecker(private val program: Program) : IAstModifyi
if(constvalue!=null) {
if (expr.operator == "*") {
// repeat a string a number of times
return StringLiteralValue(string.inferType(program),
val idt = string.inferType(program)
return StringLiteralValue(idt.typeOrElse(DataType.STR),
string.value.repeat(constvalue.number.toInt()), null, expr.position)
}
}
if(expr.operator == "+" && operand is StringLiteralValue) {
// concatenate two strings
return StringLiteralValue(string.inferType(program),
val idt = string.inferType(program)
return StringLiteralValue(idt.typeOrElse(DataType.STR),
"${string.value}${operand.value}", null, expr.position)
}
return expr

View File

@ -192,9 +192,9 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
return expr2
val leftDt = expr2.left.inferType(program)
val rightDt = expr2.right.inferType(program)
if(leftDt!=null && rightDt!=null && leftDt!=rightDt) {
if(leftDt.isKnown && rightDt.isKnown && leftDt!=rightDt) {
// determine common datatype and add typecast as required to make left and right equal types
val (commonDt, toFix) = BinaryExpression.commonDatatype(leftDt, rightDt, expr2.left, expr2.right)
val (commonDt, toFix) = BinaryExpression.commonDatatype(leftDt.typeOrElse(DataType.STRUCT), rightDt.typeOrElse(DataType.STRUCT), expr2.left, expr2.right)
if(toFix!=null) {
when {
toFix===expr2.left -> {
@ -218,32 +218,35 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
return assg
// see if a typecast is needed to convert the value's type into the proper target type
val valuetype = assg.value.inferType(program)
val targettype = assg.target.inferType(program, assg)
if(targettype!=null && valuetype!=null) {
if(valuetype!=targettype) {
val valueItype = assg.value.inferType(program)
val targetItype = assg.target.inferType(program, assg)
if(targetItype.isKnown && valueItype.isKnown) {
val targettype = targetItype.typeOrElse(DataType.STRUCT)
val valuetype = valueItype.typeOrElse(DataType.STRUCT)
if (valuetype != targettype) {
if (valuetype isAssignableTo targettype) {
assg.value = TypecastExpression(assg.value, targettype, true, assg.value.position)
assg.value.linkParents(assg)
}
// if they're not assignable, we'll get a proper error later from the AstChecker
}
}
// struct assignments will be flattened (if it's not a struct literal)
if(valuetype==DataType.STRUCT && targettype==DataType.STRUCT) {
if(assg.value is StructLiteralValue)
return assg // do NOT flatten it at this point!! (the compiler will take care if it, later, if needed)
// struct assignments will be flattened (if it's not a struct literal)
if (valuetype == DataType.STRUCT && targettype == DataType.STRUCT) {
if (assg.value is StructLiteralValue)
return assg // do NOT flatten it at this point!! (the compiler will take care if it, later, if needed)
val assignments = flattenStructAssignmentFromIdentifier(assg, program) // 'structvar1 = structvar2'
return if(assignments.isEmpty()) {
// something went wrong (probably incompatible struct types)
// we'll get an error later from the AstChecker
assg
} else {
val scope = AnonymousScope(assignments.toMutableList(), assg.position)
scope.linkParents(assg.parent)
scope
val assignments = flattenStructAssignmentFromIdentifier(assg, program) // 'structvar1 = structvar2'
return if (assignments.isEmpty()) {
// something went wrong (probably incompatible struct types)
// we'll get an error later from the AstChecker
assg
} else {
val scope = AnonymousScope(assignments.toMutableList(), assg.position)
scope.linkParents(assg.parent)
scope
}
}
}
@ -283,8 +286,9 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
when(val sub = call.target.targetStatement(scope)) {
is Subroutine -> {
for(arg in sub.parameters.zip(call.arglist.withIndex())) {
val argtype = arg.second.value.inferType(program)
if(argtype!=null) {
val argItype = arg.second.value.inferType(program)
if(argItype.isKnown) {
val argtype = argItype.typeOrElse(DataType.STRUCT)
val requiredType = arg.first.type
if (requiredType != argtype) {
if (argtype isAssignableTo requiredType) {
@ -302,8 +306,9 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
if(func.pure) {
// non-pure functions don't get automatic typecasts because sometimes they act directly on their parameters
for (arg in func.parameters.zip(call.arglist.withIndex())) {
val argtype = arg.second.value.inferType(program)
if (argtype != null) {
val argItype = arg.second.value.inferType(program)
if (argItype.isKnown) {
val argtype = argItype.typeOrElse(DataType.STRUCT)
if (arg.first.possibleDatatypes.any { argtype == it })
continue
for (possibleType in arg.first.possibleDatatypes) {
@ -334,7 +339,7 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
override fun visit(memread: DirectMemoryRead): Expression {
// make sure the memory address is an uword
val dt = memread.addressExpression.inferType(program)
if(dt!=DataType.UWORD) {
if(dt.isKnown && dt.typeOrElse(DataType.UWORD)!=DataType.UWORD) {
val literaladdr = memread.addressExpression as? NumericLiteralValue
if(literaladdr!=null) {
memread.addressExpression = literaladdr.cast(DataType.UWORD)
@ -348,7 +353,7 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
override fun visit(memwrite: DirectMemoryWrite) {
val dt = memwrite.addressExpression.inferType(program)
if(dt!=DataType.UWORD) {
if(dt.isKnown && dt.typeOrElse(DataType.UWORD)!=DataType.UWORD) {
val literaladdr = memwrite.addressExpression as? NumericLiteralValue
if(literaladdr!=null) {
memwrite.addressExpression = literaladdr.cast(DataType.UWORD)
@ -391,7 +396,7 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
structLv.values = struct.statements.zip(structLv.values).map {
val memberDt = (it.first as VarDecl).datatype
val valueDt = it.second.inferType(program)
if (valueDt != memberDt)
if (valueDt.typeOrElse(memberDt) != memberDt)
TypecastExpression(it.second, memberDt, true, it.second.position)
else
it.second

View File

@ -135,7 +135,7 @@ internal class VarInitValueAndAddressOfCreator(private val program: Program): IA
for(arg in args.withIndex().zip(signature.parameters)) {
val argvalue = arg.first.value
val argDt = argvalue.inferType(program)
if(argDt in PassByReferenceDatatypes && DataType.UWORD in arg.second.possibleDatatypes) {
if(argDt.typeOrElse(DataType.UBYTE) in PassByReferenceDatatypes && DataType.UWORD in arg.second.possibleDatatypes) {
if(argvalue !is IdentifierReference)
throw CompilerException("pass-by-reference parameter isn't an identifier? $argvalue")
val addrOf = AddressOf(argvalue, argvalue.position)

View File

@ -368,25 +368,23 @@ data class AssignTarget(val register: Register?,
}
}
fun inferType(program: Program, stmt: Statement): DataType? {
fun inferType(program: Program, stmt: Statement): InferredTypes.InferredType {
if(register!=null)
return DataType.UBYTE
return InferredTypes.knownFor(DataType.UBYTE)
if(identifier!=null) {
val symbol = program.namespace.lookup(identifier!!.nameInSource, stmt) ?: return null
if (symbol is VarDecl) return symbol.datatype
val symbol = program.namespace.lookup(identifier!!.nameInSource, stmt) ?: return InferredTypes.unknown()
if (symbol is VarDecl) return InferredTypes.knownFor(symbol.datatype)
}
if(arrayindexed!=null) {
val dt = arrayindexed!!.inferType(program)
if(dt!=null)
return dt
return arrayindexed!!.inferType(program)
}
if(memoryAddress!=null)
return DataType.UBYTE
return InferredTypes.knownFor(DataType.UBYTE)
return null
return InferredTypes.unknown()
}
infix fun isSameAs(value: Expression): Boolean {

View File

@ -663,7 +663,10 @@ internal class AsmGen2(val program: Program,
}
fun translateSubroutineArgument(parameter: IndexedValue<SubroutineParameter>, value: Expression, sub: Subroutine) {
val sourceDt = value.inferType(program)!!
val sourceIDt = value.inferType(program)
if(!sourceIDt.isKnown)
throw AssemblyError("arg type unknown")
val sourceDt = sourceIDt.typeOrElse(DataType.STRUCT)
if(!argumentTypeCompatible(sourceDt, parameter.value.type))
throw AssemblyError("argument type incompatible")
if(sub.asmParameterRegisters.isEmpty()) {
@ -849,7 +852,7 @@ internal class AsmGen2(val program: Program,
private fun translate(stmt: IfStatement) {
translateExpression(stmt.condition)
translateTestStack(stmt.condition.inferType(program)!!)
translateTestStack(stmt.condition.inferType(program).typeOrElse(DataType.STRUCT))
val elseLabel = makeLabel("if_else")
val endLabel = makeLabel("if_end")
out(" beq $elseLabel")
@ -877,7 +880,10 @@ internal class AsmGen2(val program: Program,
out(whileLabel)
// TODO optimize for the simple cases, can we avoid stack use?
translateExpression(stmt.condition)
if(stmt.condition.inferType(program) in ByteDatatypes) {
val conditionDt = stmt.condition.inferType(program)
if(!conditionDt.isKnown)
throw AssemblyError("unknown condition dt")
if(conditionDt.typeOrElse(DataType.BYTE) in ByteDatatypes) {
out(" inx | lda $ESTACK_LO_HEX,x | beq $endLabel")
} else {
out("""
@ -904,7 +910,10 @@ internal class AsmGen2(val program: Program,
// TODO optimize this for the simple cases, can we avoid stack use?
translate(stmt.body)
translateExpression(stmt.untilCondition)
if(stmt.untilCondition.inferType(program) in ByteDatatypes) {
val conditionDt = stmt.untilCondition.inferType(program)
if(!conditionDt.isKnown)
throw AssemblyError("unknown condition dt")
if(conditionDt.typeOrElse(DataType.BYTE) in ByteDatatypes) {
out(" inx | lda $ESTACK_LO_HEX,x | beq $repeatLabel")
} else {
out("""
@ -924,8 +933,10 @@ internal class AsmGen2(val program: Program,
translateExpression(stmt.condition)
val endLabel = makeLabel("choice_end")
val choiceBlocks = mutableListOf<Pair<String, AnonymousScope>>()
val conditionDt = stmt.condition.inferType(program)!!
if(conditionDt in ByteDatatypes)
val conditionDt = stmt.condition.inferType(program)
if(!conditionDt.isKnown)
throw AssemblyError("unknown condition dt")
if(conditionDt.typeOrElse(DataType.BYTE) in ByteDatatypes)
out(" inx | lda $ESTACK_LO_HEX,x")
else
out(" inx | lda $ESTACK_LO_HEX,x | ldy $ESTACK_HI_HEX,x")
@ -939,7 +950,7 @@ internal class AsmGen2(val program: Program,
choiceBlocks.add(Pair(choiceLabel, choice.statements))
for (cv in choice.values!!) {
val value = (cv as NumericLiteralValue).number.toInt()
if(conditionDt in ByteDatatypes) {
if(conditionDt.typeOrElse(DataType.BYTE) in ByteDatatypes) {
out(" cmp #${value.toHex()} | beq $choiceLabel")
} else {
out("""
@ -1026,20 +1037,22 @@ internal class AsmGen2(val program: Program,
}
private fun translate(stmt: ForLoop) {
val iterableDt = stmt.iterable.inferType(program)!!
val iterableDt = stmt.iterable.inferType(program)
if(!iterableDt.isKnown)
throw AssemblyError("can't determine iterable dt")
when(stmt.iterable) {
is RangeExpr -> {
val range = (stmt.iterable as RangeExpr).toConstantIntegerRange()
if(range==null) {
translateForOverNonconstRange(stmt, iterableDt, stmt.iterable as RangeExpr)
translateForOverNonconstRange(stmt, iterableDt.typeOrElse(DataType.STRUCT), stmt.iterable as RangeExpr)
} else {
if (range.isEmpty())
throw AssemblyError("empty range")
translateForOverConstRange(stmt, iterableDt, range)
translateForOverConstRange(stmt, iterableDt.typeOrElse(DataType.STRUCT), range)
}
}
is IdentifierReference -> {
translateForOverIterableVar(stmt, iterableDt, stmt.iterable as IdentifierReference)
translateForOverIterableVar(stmt, iterableDt.typeOrElse(DataType.STRUCT), stmt.iterable as IdentifierReference)
}
else -> throw AssemblyError("can't iterate over ${stmt.iterable}")
}
@ -1524,7 +1537,7 @@ $endLabel""")
}
targetIdent!=null -> {
val what = asmIdentifierName(targetIdent)
val dt = stmt.target.inferType(program, stmt)
val dt = stmt.target.inferType(program, stmt).typeOrElse(DataType.STRUCT)
when (dt) {
in ByteDatatypes -> out(if (incr) " inc $what" else " dec $what")
in WordDatatypes -> {
@ -1562,7 +1575,7 @@ $endLabel""")
targetArrayIdx!=null -> {
val index = targetArrayIdx.arrayspec.index
val what = asmIdentifierName(targetArrayIdx.identifier)
val arrayDt = targetArrayIdx.identifier.inferType(program)!!
val arrayDt = targetArrayIdx.identifier.inferType(program).typeOrElse(DataType.STRUCT)
val elementDt = ArrayElementTypes.getValue(arrayDt)
when(index) {
is NumericLiteralValue -> {
@ -1676,7 +1689,7 @@ $endLabel""")
assignFromRegister(assign.target, (assign.value as RegisterExpr).register)
}
is IdentifierReference -> {
val type = assign.target.inferType(program, assign)!!
val type = assign.target.inferType(program, assign).typeOrElse(DataType.STRUCT)
when(type) {
DataType.UBYTE, DataType.BYTE -> assignFromByteVariable(assign.target, assign.value as IdentifierReference)
DataType.UWORD, DataType.WORD -> assignFromWordVariable(assign.target, assign.value as IdentifierReference)
@ -1743,8 +1756,9 @@ $endLabel""")
val cast = assign.value as TypecastExpression
val sourceType = cast.expression.inferType(program)
val targetType = assign.target.inferType(program, assign)
if((sourceType in ByteDatatypes && targetType in ByteDatatypes) ||
(sourceType in WordDatatypes && targetType in WordDatatypes)) {
if(sourceType.isKnown && targetType.isKnown &&
(sourceType.typeOrElse(DataType.STRUCT) in ByteDatatypes && targetType.typeOrElse(DataType.STRUCT) in ByteDatatypes) ||
(sourceType.typeOrElse(DataType.STRUCT) in WordDatatypes && targetType.typeOrElse(DataType.STRUCT) in WordDatatypes)) {
// no need for a type cast
assign.value = cast.expression
translate(assign)
@ -1787,7 +1801,7 @@ $endLabel""")
private fun translateExpression(expr: TypecastExpression) {
translateExpression(expr.expression)
when(expr.expression.inferType(program)!!) {
when(expr.expression.inferType(program).typeOrElse(DataType.STRUCT)) {
DataType.UBYTE -> {
when(expr.type) {
DataType.UBYTE, DataType.BYTE -> {}
@ -1959,7 +1973,7 @@ $endLabel""")
private fun translateExpression(expr: IdentifierReference) {
val varname = asmIdentifierName(expr)
when(expr.inferType(program)!!) {
when(expr.inferType(program).typeOrElse(DataType.STRUCT)) {
DataType.UBYTE, DataType.BYTE -> {
out(" lda $varname | sta $ESTACK_LO_HEX,x | dex")
}
@ -1979,9 +1993,13 @@ $endLabel""")
private val powerOfTwos = setOf(0,1,2,4,8,16,32,64,128,256)
private fun translateExpression(expr: BinaryExpression) {
val leftDt = expr.left.inferType(program)!!
val rightDt = expr.right.inferType(program)!!
val leftIDt = expr.left.inferType(program)
val rightIDt = expr.right.inferType(program)
if(!leftIDt.isKnown || !rightIDt.isKnown)
throw AssemblyError("can't infer type of both expression operands")
val leftDt = leftIDt.typeOrElse(DataType.STRUCT)
val rightDt = rightIDt.typeOrElse(DataType.STRUCT)
// see if we can apply some optimized routines
when(expr.operator) {
">>" -> {
@ -2075,7 +2093,7 @@ $endLabel""")
private fun translateExpression(expr: PrefixExpression) {
translateExpression(expr.expression)
val type = expr.inferType(program)
val type = expr.inferType(program).typeOrElse(DataType.STRUCT)
when(expr.operator) {
"+" -> {}
"-" -> {
@ -2208,7 +2226,7 @@ $endLabel""")
}
targetIdent!=null -> {
val targetName = asmIdentifierName(targetIdent)
val targetDt = targetIdent.inferType(program)!!
val targetDt = targetIdent.inferType(program).typeOrElse(DataType.STRUCT)
when(targetDt) {
DataType.UBYTE, DataType.BYTE -> {
out(" inx | lda $ESTACK_LO_HEX,x | sta $targetName")
@ -2305,10 +2323,10 @@ $endLabel""")
targetArrayIdx!=null -> {
val index = targetArrayIdx.arrayspec.index
val targetName = asmIdentifierName(targetArrayIdx.identifier)
val arrayDt = targetArrayIdx.identifier.inferType(program)!!
out(" lda $sourceName | sta $ESTACK_LO_HEX,x | lda $sourceName+1 | sta $ESTACK_HI_HEX,x | dex")
translateExpression(index)
out(" inx | lda $ESTACK_LO_HEX,x")
val arrayDt = targetArrayIdx.identifier.inferType(program).typeOrElse(DataType.STRUCT)
popAndWriteArrayvalueWithIndexA(arrayDt, targetName)
}
else -> TODO("assign wordvar to $target")
@ -2364,7 +2382,7 @@ $endLabel""")
targetArrayIdx!=null -> {
val index = targetArrayIdx.arrayspec.index
val targetName = asmIdentifierName(targetArrayIdx.identifier)
val arrayDt = targetArrayIdx.identifier.inferType(program)!!
val arrayDt = targetArrayIdx.identifier.inferType(program).typeOrElse(DataType.STRUCT)
out(" lda $sourceName | sta $ESTACK_LO_HEX,x | dex")
translateExpression(index)
out(" inx | lda $ESTACK_LO_HEX,x")

View File

@ -43,7 +43,7 @@ internal class BuiltinFunctionsAsmGen(private val program: Program,
when(functionName) {
"msb" -> {
val arg = fcall.arglist.single()
if(arg.inferType(program) !in WordDatatypes)
if(arg.inferType(program).typeOrElse(DataType.STRUCT) !in WordDatatypes)
throw AssemblyError("msb required word argument")
if(arg is NumericLiteralValue)
throw AssemblyError("should have been const-folded")
@ -61,8 +61,8 @@ internal class BuiltinFunctionsAsmGen(private val program: Program,
}
"abs" -> {
translateFunctionArguments(fcall.arglist, func)
val dt = fcall.arglist.single().inferType(program)!!
when (dt) {
val dt = fcall.arglist.single().inferType(program)
when (dt.typeOrElse(DataType.STRUCT)) {
in ByteDatatypes -> asmgen.out(" jsr prog8_lib.abs_b")
in WordDatatypes -> asmgen.out(" jsr prog8_lib.abs_w")
DataType.FLOAT -> asmgen.out(" jsr c64flt.abs_f")
@ -86,8 +86,8 @@ internal class BuiltinFunctionsAsmGen(private val program: Program,
}
"min", "max", "sum" -> {
outputPushAddressAndLenghtOfArray(fcall.arglist[0])
val dt = fcall.arglist.single().inferType(program)!!
when(dt) {
val dt = fcall.arglist.single().inferType(program)
when(dt.typeOrElse(DataType.STRUCT)) {
DataType.ARRAY_UB, DataType.STR_S, DataType.STR -> asmgen.out(" jsr prog8_lib.func_${functionName}_ub")
DataType.ARRAY_B -> asmgen.out(" jsr prog8_lib.func_${functionName}_b")
DataType.ARRAY_UW -> asmgen.out(" jsr prog8_lib.func_${functionName}_uw")
@ -98,8 +98,8 @@ internal class BuiltinFunctionsAsmGen(private val program: Program,
}
"any", "all" -> {
outputPushAddressAndLenghtOfArray(fcall.arglist[0])
val dt = fcall.arglist.single().inferType(program)!!
when(dt) {
val dt = fcall.arglist.single().inferType(program)
when(dt.typeOrElse(DataType.STRUCT)) {
DataType.ARRAY_B, DataType.ARRAY_UB, DataType.STR_S, DataType.STR -> asmgen.out(" jsr prog8_lib.func_${functionName}_b")
DataType.ARRAY_UW, DataType.ARRAY_W -> asmgen.out(" jsr prog8_lib.func_${functionName}_w")
DataType.ARRAY_F -> asmgen.out(" jsr c64flt.func_${functionName}_f")
@ -134,8 +134,8 @@ internal class BuiltinFunctionsAsmGen(private val program: Program,
"lsl" -> {
// in-place
val what = fcall.arglist.single()
val dt = what.inferType(program)!!
when(dt) {
val dt = what.inferType(program)
when(dt.typeOrElse(DataType.STRUCT)) {
in ByteDatatypes -> {
when(what) {
is RegisterExpr -> {
@ -168,8 +168,8 @@ internal class BuiltinFunctionsAsmGen(private val program: Program,
"lsr" -> {
// in-place
val what = fcall.arglist.single()
val dt = what.inferType(program)!!
when(dt) {
val dt = what.inferType(program)
when(dt.typeOrElse(DataType.STRUCT)) {
DataType.UBYTE -> {
when(what) {
is RegisterExpr -> {
@ -208,8 +208,8 @@ internal class BuiltinFunctionsAsmGen(private val program: Program,
"rol" -> {
// in-place
val what = fcall.arglist.single()
val dt = what.inferType(program)!!
when(dt) {
val dt = what.inferType(program)
when(dt.typeOrElse(DataType.STRUCT)) {
DataType.UBYTE -> {
TODO("rol ubyte")
}
@ -222,8 +222,8 @@ internal class BuiltinFunctionsAsmGen(private val program: Program,
"rol2" -> {
// in-place
val what = fcall.arglist.single()
val dt = what.inferType(program)!!
when(dt) {
val dt = what.inferType(program)
when(dt.typeOrElse(DataType.STRUCT)) {
DataType.UBYTE -> {
TODO("rol2 ubyte")
}
@ -236,8 +236,8 @@ internal class BuiltinFunctionsAsmGen(private val program: Program,
"ror" -> {
// in-place
val what = fcall.arglist.single()
val dt = what.inferType(program)!!
when(dt) {
val dt = what.inferType(program)
when(dt.typeOrElse(DataType.STRUCT)) {
DataType.UBYTE -> {
TODO("ror ubyte")
}
@ -250,8 +250,8 @@ internal class BuiltinFunctionsAsmGen(private val program: Program,
"ror2" -> {
// in-place
val what = fcall.arglist.single()
val dt = what.inferType(program)!!
when(dt) {
val dt = what.inferType(program)
when(dt.typeOrElse(DataType.STRUCT)) {
DataType.UBYTE -> {
TODO("ror2 ubyte")
}

View File

@ -112,25 +112,29 @@ val BuiltinFunctions = mapOf(
)
fun builtinFunctionReturnType(function: String, args: List<Expression>, program: Program): DataType? {
fun builtinFunctionReturnType(function: String, args: List<Expression>, program: Program): InferredTypes.InferredType {
fun datatypeFromIterableArg(arglist: Expression): DataType {
if(arglist is ArrayLiteralValue) {
if(arglist.type== DataType.ARRAY_UB || arglist.type== DataType.ARRAY_UW || arglist.type== DataType.ARRAY_F) {
val dt = arglist.value.map {it.inferType(program)}
if(dt.any { it!= DataType.UBYTE && it!= DataType.UWORD && it!= DataType.FLOAT}) {
if(dt.any { !(it istype DataType.UBYTE) && !(it istype DataType.UWORD) && !(it istype DataType.FLOAT)}) {
throw FatalAstException("fuction $function only accepts arraysize of numeric values")
}
if(dt.any { it== DataType.FLOAT }) return DataType.FLOAT
if(dt.any { it== DataType.UWORD }) return DataType.UWORD
if(dt.any { it istype DataType.FLOAT }) return DataType.FLOAT
if(dt.any { it istype DataType.UWORD }) return DataType.UWORD
return DataType.UBYTE
}
}
if(arglist is IdentifierReference) {
return when(val dt = arglist.inferType(program)) {
in NumericDatatypes -> dt!!
in StringDatatypes -> dt!!
in ArrayDatatypes -> ArrayElementTypes.getValue(dt!!)
val idt = arglist.inferType(program)
if(!idt.isKnown)
throw FatalAstException("couldn't determine type of iterable $arglist")
val dt = idt.typeOrElse(DataType.STRUCT)
return when(dt) {
in NumericDatatypes -> dt
in StringDatatypes -> dt
in ArrayDatatypes -> ArrayElementTypes.getValue(dt)
else -> throw FatalAstException("function '$function' requires one argument which is an iterable")
}
}
@ -139,43 +143,43 @@ fun builtinFunctionReturnType(function: String, args: List<Expression>, program:
val func = BuiltinFunctions.getValue(function)
if(func.returntype!=null)
return func.returntype
return InferredTypes.knownFor(func.returntype)
// function has return values, but the return type depends on the arguments
return when (function) {
"abs" -> {
val dt = args.single().inferType(program)
if(dt in NumericDatatypes)
if(dt.typeOrElse(DataType.STRUCT) in NumericDatatypes)
return dt
else
throw FatalAstException("weird datatype passed to abs $dt")
}
"max", "min" -> {
when(val dt = datatypeFromIterableArg(args.single())) {
in NumericDatatypes -> dt
in StringDatatypes -> DataType.UBYTE
in ArrayDatatypes -> ArrayElementTypes.getValue(dt)
else -> null
in NumericDatatypes -> InferredTypes.knownFor(dt)
in StringDatatypes -> InferredTypes.knownFor(DataType.UBYTE)
in ArrayDatatypes -> InferredTypes.knownFor(ArrayElementTypes.getValue(dt))
else -> InferredTypes.unknown()
}
}
"sum" -> {
when(datatypeFromIterableArg(args.single())) {
DataType.UBYTE, DataType.UWORD -> DataType.UWORD
DataType.BYTE, DataType.WORD -> DataType.WORD
DataType.FLOAT -> DataType.FLOAT
DataType.ARRAY_UB, DataType.ARRAY_UW -> DataType.UWORD
DataType.ARRAY_B, DataType.ARRAY_W -> DataType.WORD
DataType.ARRAY_F -> DataType.FLOAT
in StringDatatypes -> DataType.UWORD
else -> null
DataType.UBYTE, DataType.UWORD -> InferredTypes.knownFor(DataType.UWORD)
DataType.BYTE, DataType.WORD -> InferredTypes.knownFor(DataType.WORD)
DataType.FLOAT -> InferredTypes.knownFor(DataType.FLOAT)
DataType.ARRAY_UB, DataType.ARRAY_UW -> InferredTypes.knownFor(DataType.UWORD)
DataType.ARRAY_B, DataType.ARRAY_W -> InferredTypes.knownFor(DataType.WORD)
DataType.ARRAY_F -> InferredTypes.knownFor(DataType.FLOAT)
in StringDatatypes -> InferredTypes.knownFor(DataType.UWORD)
else -> InferredTypes.unknown()
}
}
"len" -> {
// a length can be >255 so in that case, the result is an UWORD instead of an UBYTE
// but to avoid a lot of code duplication we simply assume UWORD in all cases for now
return DataType.UWORD
return InferredTypes.knownFor(DataType.UWORD)
}
else -> return null
else -> return InferredTypes.unknown()
}
}

View File

@ -79,7 +79,7 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
errors.add(ExpressionError("range expression size doesn't match declared array size", decl.value?.position!!))
val constRange = rangeExpr.toConstantIntegerRange()
if(constRange!=null) {
val eltType = rangeExpr.inferType(program)!!
val eltType = rangeExpr.inferType(program).typeOrElse(DataType.UBYTE)
if(eltType in ByteDatatypes) {
decl.value = ArrayLiteralValue(decl.datatype,
constRange.map { NumericLiteralValue(eltType, it.toShort(), decl.value!!.position) }.toTypedArray(),
@ -612,7 +612,10 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
val lv = assignment.value as? NumericLiteralValue
if(lv!=null) {
// see if we can promote/convert a literal value to the required datatype
when(assignment.target.inferType(program, assignment)) {
val idt = assignment.target.inferType(program, assignment)
if(!idt.isKnown)
return assignment
when(idt.typeOrElse(DataType.STRUCT)) {
DataType.UWORD -> {
// we can convert to UWORD: any UBYTE, BYTE/WORD that are >=0, FLOAT that's an integer 0..65535,
if(lv.type== DataType.UBYTE)

View File

@ -1,10 +1,7 @@
package prog8.optimizer
import prog8.ast.Program
import prog8.ast.base.AstException
import prog8.ast.base.DataType
import prog8.ast.base.IntegerDatatypes
import prog8.ast.base.NumericDatatypes
import prog8.ast.base.*
import prog8.ast.expressions.*
import prog8.ast.processing.IAstModifyingVisitor
import prog8.ast.statements.Assignment
@ -136,9 +133,14 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
val constTrue = NumericLiteralValue.fromBoolean(true, expr.position)
val constFalse = NumericLiteralValue.fromBoolean(false, expr.position)
val leftDt = expr.left.inferType(program)
val rightDt = expr.right.inferType(program)
if (leftDt != null && rightDt != null && leftDt != rightDt) {
val leftIDt = expr.left.inferType(program)
val rightIDt = expr.right.inferType(program)
if(!leftIDt.isKnown || !rightIDt.isKnown)
throw FatalAstException("can't determine datatype of both expression operands $expr")
val leftDt = leftIDt.typeOrElse(DataType.STRUCT)
val rightDt = rightIDt.typeOrElse(DataType.STRUCT)
if (leftDt != rightDt) {
// try to convert a datatype into the other (where ddd
if (adjustDatatypes(expr, leftVal, leftDt, rightVal, rightDt)) {
optimizationsDone++
@ -226,7 +228,7 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
val x = expr.right
val y = determineY(x, leftBinExpr)
if(y!=null) {
val yPlus1 = BinaryExpression(y, "+", NumericLiteralValue(leftDt!!, 1, y.position), y.position)
val yPlus1 = BinaryExpression(y, "+", NumericLiteralValue(leftDt, 1, y.position), y.position)
return BinaryExpression(x, "*", yPlus1, x.position)
}
} else {
@ -235,7 +237,7 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
val x = expr.right
val y = determineY(x, leftBinExpr)
if(y!=null) {
val yMinus1 = BinaryExpression(y, "-", NumericLiteralValue(leftDt!!, 1, y.position), y.position)
val yMinus1 = BinaryExpression(y, "-", NumericLiteralValue(leftDt, 1, y.position), y.position)
return BinaryExpression(x, "*", yMinus1, x.position)
}
}
@ -590,7 +592,7 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
"%" -> {
if (cv == 1.0) {
optimizationsDone++
return NumericLiteralValue(expr.inferType(program)!!, 0, expr.position)
return NumericLiteralValue(expr.inferType(program).typeOrElse(DataType.STRUCT), 0, expr.position)
} else if (cv == 2.0) {
optimizationsDone++
expr.operator = "&"
@ -613,7 +615,10 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
// right value is a constant, see if we can optimize
val rightConst: NumericLiteralValue = rightVal
val cv = rightConst.number.toDouble()
val leftDt = expr.left.inferType(program)
val leftIDt = expr.left.inferType(program)
if(!leftIDt.isKnown)
return expr
val leftDt = leftIDt.typeOrElse(DataType.STRUCT)
when(cv) {
-1.0 -> {
// '/' -> -left
@ -701,7 +706,7 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
return expr.left
}
2.0, 4.0, 8.0, 16.0, 32.0, 64.0, 128.0, 256.0, 512.0, 1024.0, 2048.0, 4096.0, 8192.0, 16384.0, 32768.0, 65536.0 -> {
if(leftValue.inferType(program) in IntegerDatatypes) {
if(leftValue.inferType(program).typeOrElse(DataType.STRUCT) in IntegerDatatypes) {
// times a power of two => shift left
optimizationsDone++
val numshifts = log2(cv).toInt()
@ -709,7 +714,7 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
}
}
-2.0, -4.0, -8.0, -16.0, -32.0, -64.0, -128.0, -256.0, -512.0, -1024.0, -2048.0, -4096.0, -8192.0, -16384.0, -32768.0, -65536.0 -> {
if(leftValue.inferType(program) in IntegerDatatypes) {
if(leftValue.inferType(program).typeOrElse(DataType.STRUCT) in IntegerDatatypes) {
// times a negative power of two => negate, then shift left
optimizationsDone++
val numshifts = log2(-cv).toInt()

View File

@ -10,6 +10,7 @@ import prog8.ast.processing.IAstModifyingVisitor
import prog8.ast.processing.IAstVisitor
import prog8.ast.statements.*
import prog8.compiler.target.c64.Petscii
import prog8.compiler.target.c64.codegen2.AssemblyError
import prog8.functions.BuiltinFunctions
import kotlin.math.floor
@ -422,7 +423,10 @@ internal class StatementOptimizer(private val program: Program) : IAstModifyingV
return NopStatement.insteadOf(assignment)
}
}
val targetDt = assignment.target.inferType(program, assignment)
val targetIDt = assignment.target.inferType(program, assignment)
if(!targetIDt.isKnown)
throw AssemblyError("can't infer type of assignment target")
val targetDt = targetIDt.typeOrElse(DataType.STRUCT)
val bexpr=assignment.value as? BinaryExpression
if(bexpr!=null) {
val cv = bexpr.right.constValue(program)?.number?.toDouble()

View File

@ -412,9 +412,11 @@ class AstVm(val program: Program) {
stmt.target.arrayindexed != null -> {
val arrayvar = stmt.target.arrayindexed!!.identifier.targetVarDecl(program.namespace)!!
val arrayvalue = runtimeVariables.get(arrayvar.definingScope(), arrayvar.name)
val elementType = stmt.target.arrayindexed!!.inferType(program)!!
val index = evaluate(stmt.target.arrayindexed!!.arrayspec.index, evalCtx).integerValue()
var value = RuntimeValue(elementType, arrayvalue.array!![index].toInt())
val elementType = stmt.target.arrayindexed!!.inferType(program)
if(!elementType.isKnown)
throw VmExecutionException("unknown/void elt type")
var value = RuntimeValue(elementType.typeOrElse(DataType.BYTE), arrayvalue.array!![index].toInt())
value = when {
stmt.operator == "++" -> value.inc()
stmt.operator == "--" -> value.dec()
@ -472,7 +474,8 @@ class AstVm(val program: Program) {
loopvarDt = DataType.UBYTE
loopvar = IdentifierReference(listOf(stmt.loopRegister.name), stmt.position)
} else {
loopvarDt = stmt.loopVar!!.inferType(program)!!
val dt = stmt.loopVar!!.inferType(program)
loopvarDt = dt.typeOrElse(DataType.UBYTE)
loopvar = stmt.loopVar!!
}
val iterator = iterable.iterator()
@ -619,8 +622,10 @@ class AstVm(val program: Program) {
else {
val address = (vardecl.value as NumericLiteralValue).number.toInt()
val index = evaluate(targetArrayIndexed.arrayspec.index, evalCtx).integerValue()
val elementType = targetArrayIndexed.inferType(program)!!
when(elementType) {
val elementType = targetArrayIndexed.inferType(program)
if(!elementType.isKnown)
throw VmExecutionException("unknown/void array elt type $targetArrayIndexed")
when(elementType.typeOrElse(DataType.UBYTE)) {
DataType.UBYTE -> mem.setUByte(address+index, value.byteval!!)
DataType.BYTE -> mem.setSByte(address+index, value.byteval!!)
DataType.UWORD -> mem.setUWord(address+index*2, value.wordval!!)

View File

@ -12,7 +12,6 @@ import prog8.ast.statements.Subroutine
import prog8.ast.statements.VarDecl
import prog8.vm.RuntimeValue
import prog8.vm.RuntimeValueRange
import kotlin.math.abs
typealias BuiltinfunctionCaller = (name: String, args: List<RuntimeValue>, flags: StatusFlags) -> RuntimeValue?
@ -147,24 +146,22 @@ fun evaluate(expr: Expression, ctx: EvalContext): RuntimeValue {
}
is RangeExpr -> {
val cRange = expr.toConstantIntegerRange()
if(cRange!=null)
return RuntimeValueRange(expr.inferType(ctx.program)!!, cRange)
if(cRange!=null) {
val dt = expr.inferType(ctx.program)
if(dt.isKnown)
return RuntimeValueRange(dt.typeOrElse(DataType.UBYTE), cRange)
else
throw VmExecutionException("couldn't determine datatype")
}
val fromVal = evaluate(expr.from, ctx).integerValue()
val toVal = evaluate(expr.to, ctx).integerValue()
val stepVal = evaluate(expr.step, ctx).integerValue()
val range = when {
fromVal <= toVal -> when {
stepVal <= 0 -> IntRange.EMPTY
stepVal == 1 -> fromVal..toVal
else -> fromVal..toVal step stepVal
}
else -> when {
stepVal >= 0 -> IntRange.EMPTY
stepVal == -1 -> fromVal downTo toVal
else -> fromVal downTo toVal step abs(stepVal)
}
}
return RuntimeValueRange(expr.inferType(ctx.program)!!, range)
val range = makeRange(fromVal, toVal, stepVal)
val dt = expr.inferType(ctx.program)
if(dt.isKnown)
return RuntimeValueRange(dt.typeOrElse(DataType.UBYTE), range)
else
throw VmExecutionException("couldn't determine datatype")
}
else -> {
throw VmExecutionException("unimplemented expression node $expr")