replaced typecastsAdder with version based on astwalker

This commit is contained in:
Irmen de Jong 2020-03-20 22:28:18 +01:00
parent a191ec71a4
commit f265199fbe
7 changed files with 123 additions and 170 deletions

View File

@ -33,6 +33,7 @@ internal fun Program.reorderStatements() {
internal fun Program.addTypecasts(errors: ErrorReporter) { internal fun Program.addTypecasts(errors: ErrorReporter) {
val caster = TypecastsAdder(this, errors) val caster = TypecastsAdder(this, errors)
caster.visit(this) caster.visit(this)
caster.applyModifications()
} }
internal fun Module.checkImportedValid() { internal fun Module.checkImportedValid() {

View File

@ -446,7 +446,7 @@ internal class AstChecker(private val program: Program,
override fun visit(decl: VarDecl) { override fun visit(decl: VarDecl) {
fun err(msg: String, position: Position?=null) { fun err(msg: String, position: Position?=null) {
err(msg, position ?: decl.position) errors.err(msg, position ?: decl.position)
} }
// the initializer value can't refer to the variable itself (recursive definition) // the initializer value can't refer to the variable itself (recursive definition)

View File

@ -23,7 +23,7 @@ interface IAstModification {
} }
} }
class Replace(val statement: Statement, val replacement: Statement, val parent: Node) : IAstModification { class ReplaceStmt(val statement: Statement, val replacement: Statement, val parent: Node) : IAstModification {
override fun perform() { override fun perform() {
if(parent is INameScope) { if(parent is INameScope) {
val idx = parent.statements.indexOf(statement) val idx = parent.statements.indexOf(statement)
@ -34,6 +34,13 @@ interface IAstModification {
} }
} }
} }
class ReplaceExpr(val setter: (newExpr: Expression) -> Unit, val newExpr: Expression, val parent: Node) : IAstModification {
override fun perform() {
setter(newExpr)
newExpr.linkParents(parent)
}
}
} }

View File

@ -12,7 +12,7 @@ internal class ForeverLoopsMaker: AstWalker() {
val numeric = repeatLoop.untilCondition as? NumericLiteralValue val numeric = repeatLoop.untilCondition as? NumericLiteralValue
if(numeric!=null && numeric.number.toInt() == 0) { if(numeric!=null && numeric.number.toInt() == 0) {
val forever = ForeverLoop(repeatLoop.body, repeatLoop.position) val forever = ForeverLoop(repeatLoop.body, repeatLoop.position)
return listOf(IAstModification.Replace(repeatLoop, forever, parent)) return listOf(IAstModification.ReplaceStmt(repeatLoop, forever, parent))
} }
return emptyList() return emptyList()
} }
@ -21,7 +21,7 @@ internal class ForeverLoopsMaker: AstWalker() {
val numeric = whileLoop.condition as? NumericLiteralValue val numeric = whileLoop.condition as? NumericLiteralValue
if(numeric!=null && numeric.number.toInt() != 0) { if(numeric!=null && numeric.number.toInt() != 0) {
val forever = ForeverLoop(whileLoop.body, whileLoop.position) val forever = ForeverLoop(whileLoop.body, whileLoop.position)
return listOf(IAstModification.Replace(whileLoop, forever, parent)) return listOf(IAstModification.ReplaceStmt(whileLoop, forever, parent))
} }
return emptyList() return emptyList()
} }

View File

@ -2,6 +2,7 @@ package prog8.ast.processing
import prog8.ast.IFunctionCall import prog8.ast.IFunctionCall
import prog8.ast.INameScope import prog8.ast.INameScope
import prog8.ast.Node
import prog8.ast.Program import prog8.ast.Program
import prog8.ast.base.DataType import prog8.ast.base.DataType
import prog8.ast.base.ErrorReporter import prog8.ast.base.ErrorReporter
@ -11,73 +12,64 @@ import prog8.ast.statements.*
import prog8.functions.BuiltinFunctions import prog8.functions.BuiltinFunctions
internal class TypecastsAdder(private val program: Program, class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalker() {
private val errors: ErrorReporter): IAstModifyingVisitor { /*
// Make sure any value assignments get the proper type casts if needed to cast them into the target variable's type. * Make sure any value assignments get the proper type casts if needed to cast them into the target variable's type.
// (this includes function call arguments) * (this includes function call arguments)
*/
override fun visit(expr: BinaryExpression): Expression { override fun after(expr: BinaryExpression, parent: Node): Iterable<IAstModification> {
val expr2 = super.visit(expr) val leftDt = expr.left.inferType(program)
if(expr2 !is BinaryExpression) val rightDt = expr.right.inferType(program)
return expr2
val leftDt = expr2.left.inferType(program)
val rightDt = expr2.right.inferType(program)
if(leftDt.isKnown && rightDt.isKnown && leftDt!=rightDt) { if(leftDt.isKnown && rightDt.isKnown && leftDt!=rightDt) {
// determine common datatype and add typecast as required to make left and right equal types // determine common datatype and add typecast as required to make left and right equal types
val (commonDt, toFix) = BinaryExpression.commonDatatype(leftDt.typeOrElse(DataType.STRUCT), rightDt.typeOrElse(DataType.STRUCT), expr2.left, expr2.right) val (commonDt, toFix) = BinaryExpression.commonDatatype(leftDt.typeOrElse(DataType.STRUCT), rightDt.typeOrElse(DataType.STRUCT), expr.left, expr.right)
if(toFix!=null) { if(toFix!=null) {
when { return when {
toFix===expr2.left -> { toFix===expr.left -> listOf(IAstModification.ReplaceExpr(
expr2.left = TypecastExpression(expr2.left, commonDt, true, expr2.left.position) { newExpr -> expr.left = newExpr },
expr2.left.linkParents(expr2) TypecastExpression(expr.left, commonDt, true, expr.left.position),
} expr))
toFix===expr2.right -> { toFix===expr.right -> listOf(IAstModification.ReplaceExpr(
expr2.right = TypecastExpression(expr2.right, commonDt, true, expr2.right.position) { newExpr -> expr.right = newExpr },
expr2.right.linkParents(expr2) TypecastExpression(expr.right, commonDt, true, expr.right.position),
} expr))
else -> throw FatalAstException("confused binary expression side") else -> throw FatalAstException("confused binary expression side")
} }
} }
} }
return expr2 return emptyList()
} }
override fun visit(assignment: Assignment): Statement { override fun after(assignment: Assignment, parent: Node): Iterable<IAstModification> {
val assg = super.visit(assignment)
if(assg !is Assignment)
return assg
// see if a typecast is needed to convert the value's type into the proper target type // see if a typecast is needed to convert the value's type into the proper target type
val valueItype = assg.value.inferType(program) val valueItype = assignment.value.inferType(program)
val targetItype = assg.target.inferType(program, assg) val targetItype = assignment.target.inferType(program, assignment)
if(targetItype.isKnown && valueItype.isKnown) { if(targetItype.isKnown && valueItype.isKnown) {
val targettype = targetItype.typeOrElse(DataType.STRUCT) val targettype = targetItype.typeOrElse(DataType.STRUCT)
val valuetype = valueItype.typeOrElse(DataType.STRUCT) val valuetype = valueItype.typeOrElse(DataType.STRUCT)
if (valuetype != targettype) { if (valuetype != targettype) {
if (valuetype isAssignableTo targettype) { if (valuetype isAssignableTo targettype)
assg.value = TypecastExpression(assg.value, targettype, true, assg.value.position) return listOf(IAstModification.ReplaceExpr(
assg.value.linkParents(assg) { newExpr -> assignment.value=newExpr },
} TypecastExpression(assignment.value, targettype, true, assignment.value.position),
// if they're not assignable, we'll get a proper error later from the AstChecker assignment))
} }
} }
return assg return emptyList()
} }
override fun visit(functionCallStatement: FunctionCallStatement): Statement { override fun after(functionCallStatement: FunctionCallStatement, parent: Node): Iterable<IAstModification> {
checkFunctionCallArguments(functionCallStatement, functionCallStatement.definingScope()) return afterFunctionCallArgs(functionCallStatement, functionCallStatement.definingScope())
return super.visit(functionCallStatement)
} }
override fun visit(functionCall: FunctionCall): Expression { override fun after(functionCall: FunctionCall, parent: Node): Iterable<IAstModification> {
checkFunctionCallArguments(functionCall, functionCall.definingScope()) return afterFunctionCallArgs(functionCall, functionCall.definingScope())
return super.visit(functionCall)
} }
private fun checkFunctionCallArguments(call: IFunctionCall, scope: INameScope) { private fun afterFunctionCallArgs(call: IFunctionCall, scope: INameScope): Iterable<IAstModification> {
// see if a typecast is needed to convert the arguments into the required parameter's type // see if a typecast is needed to convert the arguments into the required parameter's type
when(val sub = call.target.targetStatement(scope)) { return when(val sub = call.target.targetStatement(scope)) {
is Subroutine -> { is Subroutine -> {
for(arg in sub.parameters.zip(call.args.withIndex())) { for(arg in sub.parameters.zip(call.args.withIndex())) {
val argItype = arg.second.value.inferType(program) val argItype = arg.second.value.inferType(program)
@ -86,14 +78,15 @@ internal class TypecastsAdder(private val program: Program,
val requiredType = arg.first.type val requiredType = arg.first.type
if (requiredType != argtype) { if (requiredType != argtype) {
if (argtype isAssignableTo requiredType) { if (argtype isAssignableTo requiredType) {
val typecasted = TypecastExpression(arg.second.value, requiredType, true, arg.second.value.position) return listOf(IAstModification.ReplaceExpr(
typecasted.linkParents(arg.second.value.parent) { newExpr -> call.args[arg.second.index] = newExpr },
call.args[arg.second.index] = typecasted TypecastExpression(arg.second.value, requiredType, true, arg.second.value.position),
call as Node))
} }
// if they're not assignable, we'll get a proper error later from the AstChecker
} }
} }
} }
emptyList()
} }
is BuiltinFunctionStatementPlaceholder -> { is BuiltinFunctionStatementPlaceholder -> {
val func = BuiltinFunctions.getValue(sub.name) val func = BuiltinFunctions.getValue(sub.name)
@ -107,93 +100,103 @@ internal class TypecastsAdder(private val program: Program,
continue continue
for (possibleType in arg.first.possibleDatatypes) { for (possibleType in arg.first.possibleDatatypes) {
if (argtype isAssignableTo possibleType) { if (argtype isAssignableTo possibleType) {
val typecasted = TypecastExpression(arg.second.value, possibleType, true, arg.second.value.position) return listOf(IAstModification.ReplaceExpr(
typecasted.linkParents(arg.second.value.parent) { newExpr -> call.args[arg.second.index] = newExpr },
call.args[arg.second.index] = typecasted TypecastExpression(arg.second.value, possibleType, true, arg.second.value.position),
break call as Node
))
} }
} }
} }
} }
} }
emptyList()
} }
null -> {} null -> emptyList()
else -> throw FatalAstException("call to something weird $sub ${call.target}") else -> throw FatalAstException("call to something weird $sub ${call.target}")
} }
} }
override fun visit(typecast: TypecastExpression): Expression { override fun after(typecast: TypecastExpression, parent: Node): Iterable<IAstModification> {
// warn about any implicit type casts to Float, because that may not be intended // warn about any implicit type casts to Float, because that may not be intended
if(typecast.implicit && typecast.type in setOf(DataType.FLOAT, DataType.ARRAY_F)) { if(typecast.implicit && typecast.type in setOf(DataType.FLOAT, DataType.ARRAY_F)) {
errors.warn("byte or word value implicitly converted to float. Suggestion: use explicit cast as float, a float number, or revert to integer arithmetic", typecast.position) errors.warn("byte or word value implicitly converted to float. Suggestion: use explicit cast as float, a float number, or revert to integer arithmetic", typecast.position)
} }
return super.visit(typecast) return emptyList()
} }
override fun visit(memread: DirectMemoryRead): Expression { override fun after(memread: DirectMemoryRead, parent: Node): Iterable<IAstModification> {
// make sure the memory address is an uword // make sure the memory address is an uword
val dt = memread.addressExpression.inferType(program) val dt = memread.addressExpression.inferType(program)
if(dt.isKnown && dt.typeOrElse(DataType.UWORD)!=DataType.UWORD) { if(dt.isKnown && dt.typeOrElse(DataType.UWORD)!=DataType.UWORD) {
val literaladdr = memread.addressExpression as? NumericLiteralValue val typecast = (memread.addressExpression as? NumericLiteralValue)?.cast(DataType.UWORD)
if(literaladdr!=null) { ?: TypecastExpression(memread.addressExpression, DataType.UWORD, true, memread.addressExpression.position)
memread.addressExpression = literaladdr.cast(DataType.UWORD) return listOf(IAstModification.ReplaceExpr(
} else { { newExpr -> memread.addressExpression = newExpr },
memread.addressExpression = TypecastExpression(memread.addressExpression, DataType.UWORD, true, memread.addressExpression.position) typecast,
memread.addressExpression.parent = memread memread
} ))
} }
return super.visit(memread) return emptyList()
} }
override fun visit(memwrite: DirectMemoryWrite) { override fun after(memwrite: DirectMemoryWrite, parent: Node): Iterable<IAstModification> {
// make sure the memory address is an uword
val dt = memwrite.addressExpression.inferType(program) val dt = memwrite.addressExpression.inferType(program)
if(dt.isKnown && dt.typeOrElse(DataType.UWORD)!=DataType.UWORD) { if(dt.isKnown && dt.typeOrElse(DataType.UWORD)!=DataType.UWORD) {
val literaladdr = memwrite.addressExpression as? NumericLiteralValue val typecast = (memwrite.addressExpression as? NumericLiteralValue)?.cast(DataType.UWORD)
if(literaladdr!=null) { ?: TypecastExpression(memwrite.addressExpression, DataType.UWORD, true, memwrite.addressExpression.position)
memwrite.addressExpression = literaladdr.cast(DataType.UWORD) return listOf(IAstModification.ReplaceExpr(
} else { { newExpr -> memwrite.addressExpression = newExpr },
memwrite.addressExpression = TypecastExpression(memwrite.addressExpression, DataType.UWORD, true, memwrite.addressExpression.position) typecast,
memwrite.addressExpression.parent = memwrite memwrite
} ))
} }
super.visit(memwrite) return emptyList()
} }
override fun visit(structLv: StructLiteralValue): Expression { override fun after(structLv: StructLiteralValue, parent: Node): Iterable<IAstModification> {
val litval = super.visit(structLv) // assignment of a struct literal value, some member values may need proper typecast
if(litval !is StructLiteralValue)
return litval
val decl = litval.parent as? VarDecl fun addTypecastsIfNeeded(struct: StructDecl): Iterable<IAstModification> {
val newValues = struct.statements.zip(structLv.values).map { (structMemberDecl, memberValue) ->
val memberDt = (structMemberDecl as VarDecl).datatype
val valueDt = memberValue.inferType(program)
if (valueDt.typeOrElse(memberDt) != memberDt)
TypecastExpression(memberValue, memberDt, true, memberValue.position)
else
memberValue
}
class Replacer(val targetStructLv: StructLiteralValue, val typecastValues: List<Expression>) : IAstModification {
override fun perform() {
targetStructLv.values = typecastValues
typecastValues.forEach { it.linkParents(targetStructLv) }
}
}
return if(structLv.values.zip(newValues).any { (v1, v2) -> v1 !== v2})
listOf(Replacer(structLv, newValues))
else
emptyList()
}
val decl = structLv.parent as? VarDecl
if(decl != null) { if(decl != null) {
val struct = decl.struct val struct = decl.struct
if(struct != null) { if(struct != null)
addTypecastsIfNeeded(litval, struct) return addTypecastsIfNeeded(struct)
}
} else { } else {
val assign = litval.parent as? Assignment val assign = structLv.parent as? Assignment
if (assign != null) { if (assign != null) {
val decl2 = assign.target.identifier?.targetVarDecl(program.namespace) val decl2 = assign.target.identifier?.targetVarDecl(program.namespace)
if(decl2 != null) { if(decl2 != null) {
val struct = decl2.struct val struct = decl2.struct
if(struct != null) { if(struct != null)
addTypecastsIfNeeded(litval, struct) return addTypecastsIfNeeded(struct)
}
} }
} }
} }
return emptyList()
return litval
}
private fun addTypecastsIfNeeded(structLv: StructLiteralValue, struct: StructDecl) {
structLv.values = struct.statements.zip(structLv.values).map {
val memberDt = (it.first as VarDecl).datatype
val valueDt = it.second.inferType(program)
if (valueDt.typeOrElse(memberDt) != memberDt)
TypecastExpression(it.second, memberDt, true, it.second.position)
else
it.second
}
} }
} }

View File

@ -5,6 +5,7 @@ TODO
- remove statements after an exit() or return - remove statements after an exit() or return
- fix warnings about that unreachable code? - fix warnings about that unreachable code?
- aliases for imported symbols for example perhaps '%alias print = c64scr.print'
- option to load library files from a directory instead of the embedded ones - option to load library files from a directory instead of the embedded ones

View File

@ -1,78 +1,19 @@
%import c64utils %import c64utils
%import c64flt
%option enable_floats
%zeropage basicsafe %zeropage basicsafe
main { main {
struct Color {
uword red
uword green
uword blue
}
sub start() { sub start() {
ubyte x11=44 ; Color c = [1,2,3] ; TODO fix compiler error
byte bb0=99 Color c = {1,2,3}
A=x11
while true {
A=99
}
repeat {
ubyte x1
x1=A
A=x1
if A==44 {
ubyte y1
A=y1
} else {
byte bb1=99
bb1 += A
bb0=bb1
}
A=44
} until false
c64scr.print("spstart:")
print_stackpointer()
sub1()
c64scr.print("spend:")
print_stackpointer()
} }
sub sub1() {
c64scr.print("sp1:")
print_stackpointer()
sub2()
}
sub sub2() {
c64scr.print("sp2:")
print_stackpointer()
exit(33)
sub3() ; TODO warning about unreachable code
sub3() ; TODO remove statements after a return/exit
c64scr.print("sp2:")
c64scr.print("sp2:")
sub3()
sub3()
sub3()
sub3()
sub3()
sub3()
sub3()
sub3()
sub3()
sub3()
}
sub sub3() {
c64scr.print("sp3:")
print_stackpointer()
}
sub print_stackpointer() {
c64scr.print_ub(X) ; prints stack pointer
c64.CHROUT('\n')
}
} }