diff --git a/compiler/src/prog8/ast/base/Extensions.kt b/compiler/src/prog8/ast/base/Extensions.kt index 09013b01f..1e09478a1 100644 --- a/compiler/src/prog8/ast/base/Extensions.kt +++ b/compiler/src/prog8/ast/base/Extensions.kt @@ -33,6 +33,7 @@ internal fun Program.reorderStatements() { internal fun Program.addTypecasts(errors: ErrorReporter) { val caster = TypecastsAdder(this, errors) caster.visit(this) + caster.applyModifications() } internal fun Module.checkImportedValid() { diff --git a/compiler/src/prog8/ast/processing/AstChecker.kt b/compiler/src/prog8/ast/processing/AstChecker.kt index ac314a032..09038ef33 100644 --- a/compiler/src/prog8/ast/processing/AstChecker.kt +++ b/compiler/src/prog8/ast/processing/AstChecker.kt @@ -446,7 +446,7 @@ internal class AstChecker(private val program: Program, override fun visit(decl: VarDecl) { fun err(msg: String, position: Position?=null) { - err(msg, position ?: decl.position) + errors.err(msg, position ?: decl.position) } // the initializer value can't refer to the variable itself (recursive definition) diff --git a/compiler/src/prog8/ast/processing/AstWalker.kt b/compiler/src/prog8/ast/processing/AstWalker.kt index 6413e2386..4eb127a54 100644 --- a/compiler/src/prog8/ast/processing/AstWalker.kt +++ b/compiler/src/prog8/ast/processing/AstWalker.kt @@ -23,7 +23,7 @@ interface IAstModification { } } - class Replace(val statement: Statement, val replacement: Statement, val parent: Node) : IAstModification { + class ReplaceStmt(val statement: Statement, val replacement: Statement, val parent: Node) : IAstModification { override fun perform() { if(parent is INameScope) { val idx = parent.statements.indexOf(statement) @@ -34,6 +34,13 @@ interface IAstModification { } } } + + class ReplaceExpr(val setter: (newExpr: Expression) -> Unit, val newExpr: Expression, val parent: Node) : IAstModification { + override fun perform() { + setter(newExpr) + newExpr.linkParents(parent) + } + } } diff --git a/compiler/src/prog8/ast/processing/ForeverLoopsMaker.kt b/compiler/src/prog8/ast/processing/ForeverLoopsMaker.kt index d8b9938e7..e5288c7fd 100644 --- a/compiler/src/prog8/ast/processing/ForeverLoopsMaker.kt +++ b/compiler/src/prog8/ast/processing/ForeverLoopsMaker.kt @@ -12,7 +12,7 @@ internal class ForeverLoopsMaker: AstWalker() { val numeric = repeatLoop.untilCondition as? NumericLiteralValue if(numeric!=null && numeric.number.toInt() == 0) { val forever = ForeverLoop(repeatLoop.body, repeatLoop.position) - return listOf(IAstModification.Replace(repeatLoop, forever, parent)) + return listOf(IAstModification.ReplaceStmt(repeatLoop, forever, parent)) } return emptyList() } @@ -21,7 +21,7 @@ internal class ForeverLoopsMaker: AstWalker() { val numeric = whileLoop.condition as? NumericLiteralValue if(numeric!=null && numeric.number.toInt() != 0) { val forever = ForeverLoop(whileLoop.body, whileLoop.position) - return listOf(IAstModification.Replace(whileLoop, forever, parent)) + return listOf(IAstModification.ReplaceStmt(whileLoop, forever, parent)) } return emptyList() } diff --git a/compiler/src/prog8/ast/processing/TypecastsAdder.kt b/compiler/src/prog8/ast/processing/TypecastsAdder.kt index b1bbd6606..33c376c49 100644 --- a/compiler/src/prog8/ast/processing/TypecastsAdder.kt +++ b/compiler/src/prog8/ast/processing/TypecastsAdder.kt @@ -2,6 +2,7 @@ package prog8.ast.processing import prog8.ast.IFunctionCall import prog8.ast.INameScope +import prog8.ast.Node import prog8.ast.Program import prog8.ast.base.DataType import prog8.ast.base.ErrorReporter @@ -11,73 +12,64 @@ import prog8.ast.statements.* import prog8.functions.BuiltinFunctions -internal class TypecastsAdder(private val program: Program, - private val errors: ErrorReporter): IAstModifyingVisitor { - // Make sure any value assignments get the proper type casts if needed to cast them into the target variable's type. - // (this includes function call arguments) +class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalker() { + /* + * Make sure any value assignments get the proper type casts if needed to cast them into the target variable's type. + * (this includes function call arguments) + */ - override fun visit(expr: BinaryExpression): Expression { - val expr2 = super.visit(expr) - if(expr2 !is BinaryExpression) - return expr2 - val leftDt = expr2.left.inferType(program) - val rightDt = expr2.right.inferType(program) + override fun after(expr: BinaryExpression, parent: Node): Iterable { + val leftDt = expr.left.inferType(program) + val rightDt = expr.right.inferType(program) if(leftDt.isKnown && rightDt.isKnown && leftDt!=rightDt) { // determine common datatype and add typecast as required to make left and right equal types - val (commonDt, toFix) = BinaryExpression.commonDatatype(leftDt.typeOrElse(DataType.STRUCT), rightDt.typeOrElse(DataType.STRUCT), expr2.left, expr2.right) + val (commonDt, toFix) = BinaryExpression.commonDatatype(leftDt.typeOrElse(DataType.STRUCT), rightDt.typeOrElse(DataType.STRUCT), expr.left, expr.right) if(toFix!=null) { - when { - toFix===expr2.left -> { - expr2.left = TypecastExpression(expr2.left, commonDt, true, expr2.left.position) - expr2.left.linkParents(expr2) - } - toFix===expr2.right -> { - expr2.right = TypecastExpression(expr2.right, commonDt, true, expr2.right.position) - expr2.right.linkParents(expr2) - } + return when { + toFix===expr.left -> listOf(IAstModification.ReplaceExpr( + { newExpr -> expr.left = newExpr }, + TypecastExpression(expr.left, commonDt, true, expr.left.position), + expr)) + toFix===expr.right -> listOf(IAstModification.ReplaceExpr( + { newExpr -> expr.right = newExpr }, + TypecastExpression(expr.right, commonDt, true, expr.right.position), + expr)) else -> throw FatalAstException("confused binary expression side") } } } - return expr2 + return emptyList() } - override fun visit(assignment: Assignment): Statement { - val assg = super.visit(assignment) - if(assg !is Assignment) - return assg - + override fun after(assignment: Assignment, parent: Node): Iterable { // see if a typecast is needed to convert the value's type into the proper target type - val valueItype = assg.value.inferType(program) - val targetItype = assg.target.inferType(program, assg) - + val valueItype = assignment.value.inferType(program) + val targetItype = assignment.target.inferType(program, assignment) if(targetItype.isKnown && valueItype.isKnown) { val targettype = targetItype.typeOrElse(DataType.STRUCT) val valuetype = valueItype.typeOrElse(DataType.STRUCT) if (valuetype != targettype) { - if (valuetype isAssignableTo targettype) { - assg.value = TypecastExpression(assg.value, targettype, true, assg.value.position) - assg.value.linkParents(assg) - } - // if they're not assignable, we'll get a proper error later from the AstChecker + if (valuetype isAssignableTo targettype) + return listOf(IAstModification.ReplaceExpr( + { newExpr -> assignment.value=newExpr }, + TypecastExpression(assignment.value, targettype, true, assignment.value.position), + assignment)) } } - return assg + return emptyList() } - override fun visit(functionCallStatement: FunctionCallStatement): Statement { - checkFunctionCallArguments(functionCallStatement, functionCallStatement.definingScope()) - return super.visit(functionCallStatement) + override fun after(functionCallStatement: FunctionCallStatement, parent: Node): Iterable { + return afterFunctionCallArgs(functionCallStatement, functionCallStatement.definingScope()) } - override fun visit(functionCall: FunctionCall): Expression { - checkFunctionCallArguments(functionCall, functionCall.definingScope()) - return super.visit(functionCall) + override fun after(functionCall: FunctionCall, parent: Node): Iterable { + return afterFunctionCallArgs(functionCall, functionCall.definingScope()) } - private fun checkFunctionCallArguments(call: IFunctionCall, scope: INameScope) { + private fun afterFunctionCallArgs(call: IFunctionCall, scope: INameScope): Iterable { // see if a typecast is needed to convert the arguments into the required parameter's type - when(val sub = call.target.targetStatement(scope)) { + return when(val sub = call.target.targetStatement(scope)) { is Subroutine -> { for(arg in sub.parameters.zip(call.args.withIndex())) { val argItype = arg.second.value.inferType(program) @@ -86,14 +78,15 @@ internal class TypecastsAdder(private val program: Program, val requiredType = arg.first.type if (requiredType != argtype) { if (argtype isAssignableTo requiredType) { - val typecasted = TypecastExpression(arg.second.value, requiredType, true, arg.second.value.position) - typecasted.linkParents(arg.second.value.parent) - call.args[arg.second.index] = typecasted + return listOf(IAstModification.ReplaceExpr( + { newExpr -> call.args[arg.second.index] = newExpr }, + TypecastExpression(arg.second.value, requiredType, true, arg.second.value.position), + call as Node)) } - // if they're not assignable, we'll get a proper error later from the AstChecker } } } + emptyList() } is BuiltinFunctionStatementPlaceholder -> { val func = BuiltinFunctions.getValue(sub.name) @@ -107,93 +100,103 @@ internal class TypecastsAdder(private val program: Program, continue for (possibleType in arg.first.possibleDatatypes) { if (argtype isAssignableTo possibleType) { - val typecasted = TypecastExpression(arg.second.value, possibleType, true, arg.second.value.position) - typecasted.linkParents(arg.second.value.parent) - call.args[arg.second.index] = typecasted - break + return listOf(IAstModification.ReplaceExpr( + { newExpr -> call.args[arg.second.index] = newExpr }, + TypecastExpression(arg.second.value, possibleType, true, arg.second.value.position), + call as Node + )) } } } } } + emptyList() } - null -> {} + null -> emptyList() else -> throw FatalAstException("call to something weird $sub ${call.target}") } } - override fun visit(typecast: TypecastExpression): Expression { + override fun after(typecast: TypecastExpression, parent: Node): Iterable { // warn about any implicit type casts to Float, because that may not be intended if(typecast.implicit && typecast.type in setOf(DataType.FLOAT, DataType.ARRAY_F)) { errors.warn("byte or word value implicitly converted to float. Suggestion: use explicit cast as float, a float number, or revert to integer arithmetic", typecast.position) } - return super.visit(typecast) + return emptyList() } - override fun visit(memread: DirectMemoryRead): Expression { + override fun after(memread: DirectMemoryRead, parent: Node): Iterable { // make sure the memory address is an uword val dt = memread.addressExpression.inferType(program) if(dt.isKnown && dt.typeOrElse(DataType.UWORD)!=DataType.UWORD) { - val literaladdr = memread.addressExpression as? NumericLiteralValue - if(literaladdr!=null) { - memread.addressExpression = literaladdr.cast(DataType.UWORD) - } else { - memread.addressExpression = TypecastExpression(memread.addressExpression, DataType.UWORD, true, memread.addressExpression.position) - memread.addressExpression.parent = memread - } + val typecast = (memread.addressExpression as? NumericLiteralValue)?.cast(DataType.UWORD) + ?: TypecastExpression(memread.addressExpression, DataType.UWORD, true, memread.addressExpression.position) + return listOf(IAstModification.ReplaceExpr( + { newExpr -> memread.addressExpression = newExpr }, + typecast, + memread + )) } - return super.visit(memread) + return emptyList() } - override fun visit(memwrite: DirectMemoryWrite) { + override fun after(memwrite: DirectMemoryWrite, parent: Node): Iterable { + // make sure the memory address is an uword val dt = memwrite.addressExpression.inferType(program) if(dt.isKnown && dt.typeOrElse(DataType.UWORD)!=DataType.UWORD) { - val literaladdr = memwrite.addressExpression as? NumericLiteralValue - if(literaladdr!=null) { - memwrite.addressExpression = literaladdr.cast(DataType.UWORD) - } else { - memwrite.addressExpression = TypecastExpression(memwrite.addressExpression, DataType.UWORD, true, memwrite.addressExpression.position) - memwrite.addressExpression.parent = memwrite - } + val typecast = (memwrite.addressExpression as? NumericLiteralValue)?.cast(DataType.UWORD) + ?: TypecastExpression(memwrite.addressExpression, DataType.UWORD, true, memwrite.addressExpression.position) + return listOf(IAstModification.ReplaceExpr( + { newExpr -> memwrite.addressExpression = newExpr }, + typecast, + memwrite + )) } - super.visit(memwrite) + return emptyList() } - override fun visit(structLv: StructLiteralValue): Expression { - val litval = super.visit(structLv) - if(litval !is StructLiteralValue) - return litval + override fun after(structLv: StructLiteralValue, parent: Node): Iterable { + // assignment of a struct literal value, some member values may need proper typecast - val decl = litval.parent as? VarDecl + fun addTypecastsIfNeeded(struct: StructDecl): Iterable { + val newValues = struct.statements.zip(structLv.values).map { (structMemberDecl, memberValue) -> + val memberDt = (structMemberDecl as VarDecl).datatype + val valueDt = memberValue.inferType(program) + if (valueDt.typeOrElse(memberDt) != memberDt) + TypecastExpression(memberValue, memberDt, true, memberValue.position) + else + memberValue + } + + class Replacer(val targetStructLv: StructLiteralValue, val typecastValues: List) : IAstModification { + override fun perform() { + targetStructLv.values = typecastValues + typecastValues.forEach { it.linkParents(targetStructLv) } + } + } + + return if(structLv.values.zip(newValues).any { (v1, v2) -> v1 !== v2}) + listOf(Replacer(structLv, newValues)) + else + emptyList() + } + + val decl = structLv.parent as? VarDecl if(decl != null) { val struct = decl.struct - if(struct != null) { - addTypecastsIfNeeded(litval, struct) - } + if(struct != null) + return addTypecastsIfNeeded(struct) } else { - val assign = litval.parent as? Assignment + val assign = structLv.parent as? Assignment if (assign != null) { val decl2 = assign.target.identifier?.targetVarDecl(program.namespace) if(decl2 != null) { val struct = decl2.struct - if(struct != null) { - addTypecastsIfNeeded(litval, struct) - } + if(struct != null) + return addTypecastsIfNeeded(struct) } } } - - return litval - } - - private fun addTypecastsIfNeeded(structLv: StructLiteralValue, struct: StructDecl) { - structLv.values = struct.statements.zip(structLv.values).map { - val memberDt = (it.first as VarDecl).datatype - val valueDt = it.second.inferType(program) - if (valueDt.typeOrElse(memberDt) != memberDt) - TypecastExpression(it.second, memberDt, true, it.second.position) - else - it.second - } + return emptyList() } } diff --git a/docs/source/todo.rst b/docs/source/todo.rst index c146d23a5..4fec6cf3a 100644 --- a/docs/source/todo.rst +++ b/docs/source/todo.rst @@ -5,6 +5,7 @@ TODO - remove statements after an exit() or return - fix warnings about that unreachable code? +- aliases for imported symbols for example perhaps '%alias print = c64scr.print' - option to load library files from a directory instead of the embedded ones diff --git a/examples/test.p8 b/examples/test.p8 index 08fa1c930..cfce52a0d 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -1,78 +1,19 @@ %import c64utils +%import c64flt +%option enable_floats %zeropage basicsafe main { + struct Color { + uword red + uword green + uword blue + } + sub start() { - ubyte x11=44 - byte bb0=99 - - A=x11 - - - while true { - A=99 - } - - repeat { - ubyte x1 - - x1=A - A=x1 - - if A==44 { - ubyte y1 - A=y1 - } else { - byte bb1=99 - bb1 += A - bb0=bb1 - } - A=44 - } until false - - - c64scr.print("spstart:") - print_stackpointer() - sub1() - c64scr.print("spend:") - print_stackpointer() + ; Color c = [1,2,3] ; TODO fix compiler error + Color c = {1,2,3} } - sub sub1() { - c64scr.print("sp1:") - print_stackpointer() - sub2() - } - - sub sub2() { - c64scr.print("sp2:") - print_stackpointer() - exit(33) - sub3() ; TODO warning about unreachable code - sub3() ; TODO remove statements after a return/exit - c64scr.print("sp2:") - c64scr.print("sp2:") - sub3() - - sub3() - sub3() - sub3() - sub3() - sub3() - sub3() - sub3() - sub3() - sub3() - } - - sub sub3() { - c64scr.print("sp3:") - print_stackpointer() - } - - sub print_stackpointer() { - c64scr.print_ub(X) ; prints stack pointer - c64.CHROUT('\n') - } }