replaced typecastsAdder with version based on astwalker

This commit is contained in:
Irmen de Jong 2020-03-20 22:28:18 +01:00
parent a191ec71a4
commit f265199fbe
7 changed files with 123 additions and 170 deletions

View File

@ -33,6 +33,7 @@ internal fun Program.reorderStatements() {
internal fun Program.addTypecasts(errors: ErrorReporter) {
val caster = TypecastsAdder(this, errors)
caster.visit(this)
caster.applyModifications()
}
internal fun Module.checkImportedValid() {

View File

@ -446,7 +446,7 @@ internal class AstChecker(private val program: Program,
override fun visit(decl: VarDecl) {
fun err(msg: String, position: Position?=null) {
err(msg, position ?: decl.position)
errors.err(msg, position ?: decl.position)
}
// the initializer value can't refer to the variable itself (recursive definition)

View File

@ -23,7 +23,7 @@ interface IAstModification {
}
}
class Replace(val statement: Statement, val replacement: Statement, val parent: Node) : IAstModification {
class ReplaceStmt(val statement: Statement, val replacement: Statement, val parent: Node) : IAstModification {
override fun perform() {
if(parent is INameScope) {
val idx = parent.statements.indexOf(statement)
@ -34,6 +34,13 @@ interface IAstModification {
}
}
}
class ReplaceExpr(val setter: (newExpr: Expression) -> Unit, val newExpr: Expression, val parent: Node) : IAstModification {
override fun perform() {
setter(newExpr)
newExpr.linkParents(parent)
}
}
}

View File

@ -12,7 +12,7 @@ internal class ForeverLoopsMaker: AstWalker() {
val numeric = repeatLoop.untilCondition as? NumericLiteralValue
if(numeric!=null && numeric.number.toInt() == 0) {
val forever = ForeverLoop(repeatLoop.body, repeatLoop.position)
return listOf(IAstModification.Replace(repeatLoop, forever, parent))
return listOf(IAstModification.ReplaceStmt(repeatLoop, forever, parent))
}
return emptyList()
}
@ -21,7 +21,7 @@ internal class ForeverLoopsMaker: AstWalker() {
val numeric = whileLoop.condition as? NumericLiteralValue
if(numeric!=null && numeric.number.toInt() != 0) {
val forever = ForeverLoop(whileLoop.body, whileLoop.position)
return listOf(IAstModification.Replace(whileLoop, forever, parent))
return listOf(IAstModification.ReplaceStmt(whileLoop, forever, parent))
}
return emptyList()
}

View File

@ -2,6 +2,7 @@ package prog8.ast.processing
import prog8.ast.IFunctionCall
import prog8.ast.INameScope
import prog8.ast.Node
import prog8.ast.Program
import prog8.ast.base.DataType
import prog8.ast.base.ErrorReporter
@ -11,73 +12,64 @@ import prog8.ast.statements.*
import prog8.functions.BuiltinFunctions
internal class TypecastsAdder(private val program: Program,
private val errors: ErrorReporter): IAstModifyingVisitor {
// Make sure any value assignments get the proper type casts if needed to cast them into the target variable's type.
// (this includes function call arguments)
class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalker() {
/*
* Make sure any value assignments get the proper type casts if needed to cast them into the target variable's type.
* (this includes function call arguments)
*/
override fun visit(expr: BinaryExpression): Expression {
val expr2 = super.visit(expr)
if(expr2 !is BinaryExpression)
return expr2
val leftDt = expr2.left.inferType(program)
val rightDt = expr2.right.inferType(program)
override fun after(expr: BinaryExpression, parent: Node): Iterable<IAstModification> {
val leftDt = expr.left.inferType(program)
val rightDt = expr.right.inferType(program)
if(leftDt.isKnown && rightDt.isKnown && leftDt!=rightDt) {
// determine common datatype and add typecast as required to make left and right equal types
val (commonDt, toFix) = BinaryExpression.commonDatatype(leftDt.typeOrElse(DataType.STRUCT), rightDt.typeOrElse(DataType.STRUCT), expr2.left, expr2.right)
val (commonDt, toFix) = BinaryExpression.commonDatatype(leftDt.typeOrElse(DataType.STRUCT), rightDt.typeOrElse(DataType.STRUCT), expr.left, expr.right)
if(toFix!=null) {
when {
toFix===expr2.left -> {
expr2.left = TypecastExpression(expr2.left, commonDt, true, expr2.left.position)
expr2.left.linkParents(expr2)
}
toFix===expr2.right -> {
expr2.right = TypecastExpression(expr2.right, commonDt, true, expr2.right.position)
expr2.right.linkParents(expr2)
}
return when {
toFix===expr.left -> listOf(IAstModification.ReplaceExpr(
{ newExpr -> expr.left = newExpr },
TypecastExpression(expr.left, commonDt, true, expr.left.position),
expr))
toFix===expr.right -> listOf(IAstModification.ReplaceExpr(
{ newExpr -> expr.right = newExpr },
TypecastExpression(expr.right, commonDt, true, expr.right.position),
expr))
else -> throw FatalAstException("confused binary expression side")
}
}
}
return expr2
return emptyList()
}
override fun visit(assignment: Assignment): Statement {
val assg = super.visit(assignment)
if(assg !is Assignment)
return assg
override fun after(assignment: Assignment, parent: Node): Iterable<IAstModification> {
// see if a typecast is needed to convert the value's type into the proper target type
val valueItype = assg.value.inferType(program)
val targetItype = assg.target.inferType(program, assg)
val valueItype = assignment.value.inferType(program)
val targetItype = assignment.target.inferType(program, assignment)
if(targetItype.isKnown && valueItype.isKnown) {
val targettype = targetItype.typeOrElse(DataType.STRUCT)
val valuetype = valueItype.typeOrElse(DataType.STRUCT)
if (valuetype != targettype) {
if (valuetype isAssignableTo targettype) {
assg.value = TypecastExpression(assg.value, targettype, true, assg.value.position)
assg.value.linkParents(assg)
}
// if they're not assignable, we'll get a proper error later from the AstChecker
if (valuetype isAssignableTo targettype)
return listOf(IAstModification.ReplaceExpr(
{ newExpr -> assignment.value=newExpr },
TypecastExpression(assignment.value, targettype, true, assignment.value.position),
assignment))
}
}
return assg
return emptyList()
}
override fun visit(functionCallStatement: FunctionCallStatement): Statement {
checkFunctionCallArguments(functionCallStatement, functionCallStatement.definingScope())
return super.visit(functionCallStatement)
override fun after(functionCallStatement: FunctionCallStatement, parent: Node): Iterable<IAstModification> {
return afterFunctionCallArgs(functionCallStatement, functionCallStatement.definingScope())
}
override fun visit(functionCall: FunctionCall): Expression {
checkFunctionCallArguments(functionCall, functionCall.definingScope())
return super.visit(functionCall)
override fun after(functionCall: FunctionCall, parent: Node): Iterable<IAstModification> {
return afterFunctionCallArgs(functionCall, functionCall.definingScope())
}
private fun checkFunctionCallArguments(call: IFunctionCall, scope: INameScope) {
private fun afterFunctionCallArgs(call: IFunctionCall, scope: INameScope): Iterable<IAstModification> {
// see if a typecast is needed to convert the arguments into the required parameter's type
when(val sub = call.target.targetStatement(scope)) {
return when(val sub = call.target.targetStatement(scope)) {
is Subroutine -> {
for(arg in sub.parameters.zip(call.args.withIndex())) {
val argItype = arg.second.value.inferType(program)
@ -86,14 +78,15 @@ internal class TypecastsAdder(private val program: Program,
val requiredType = arg.first.type
if (requiredType != argtype) {
if (argtype isAssignableTo requiredType) {
val typecasted = TypecastExpression(arg.second.value, requiredType, true, arg.second.value.position)
typecasted.linkParents(arg.second.value.parent)
call.args[arg.second.index] = typecasted
return listOf(IAstModification.ReplaceExpr(
{ newExpr -> call.args[arg.second.index] = newExpr },
TypecastExpression(arg.second.value, requiredType, true, arg.second.value.position),
call as Node))
}
// if they're not assignable, we'll get a proper error later from the AstChecker
}
}
}
emptyList()
}
is BuiltinFunctionStatementPlaceholder -> {
val func = BuiltinFunctions.getValue(sub.name)
@ -107,93 +100,103 @@ internal class TypecastsAdder(private val program: Program,
continue
for (possibleType in arg.first.possibleDatatypes) {
if (argtype isAssignableTo possibleType) {
val typecasted = TypecastExpression(arg.second.value, possibleType, true, arg.second.value.position)
typecasted.linkParents(arg.second.value.parent)
call.args[arg.second.index] = typecasted
break
return listOf(IAstModification.ReplaceExpr(
{ newExpr -> call.args[arg.second.index] = newExpr },
TypecastExpression(arg.second.value, possibleType, true, arg.second.value.position),
call as Node
))
}
}
}
}
}
emptyList()
}
null -> {}
null -> emptyList()
else -> throw FatalAstException("call to something weird $sub ${call.target}")
}
}
override fun visit(typecast: TypecastExpression): Expression {
override fun after(typecast: TypecastExpression, parent: Node): Iterable<IAstModification> {
// warn about any implicit type casts to Float, because that may not be intended
if(typecast.implicit && typecast.type in setOf(DataType.FLOAT, DataType.ARRAY_F)) {
errors.warn("byte or word value implicitly converted to float. Suggestion: use explicit cast as float, a float number, or revert to integer arithmetic", typecast.position)
}
return super.visit(typecast)
return emptyList()
}
override fun visit(memread: DirectMemoryRead): Expression {
override fun after(memread: DirectMemoryRead, parent: Node): Iterable<IAstModification> {
// make sure the memory address is an uword
val dt = memread.addressExpression.inferType(program)
if(dt.isKnown && dt.typeOrElse(DataType.UWORD)!=DataType.UWORD) {
val literaladdr = memread.addressExpression as? NumericLiteralValue
if(literaladdr!=null) {
memread.addressExpression = literaladdr.cast(DataType.UWORD)
} else {
memread.addressExpression = TypecastExpression(memread.addressExpression, DataType.UWORD, true, memread.addressExpression.position)
memread.addressExpression.parent = memread
}
val typecast = (memread.addressExpression as? NumericLiteralValue)?.cast(DataType.UWORD)
?: TypecastExpression(memread.addressExpression, DataType.UWORD, true, memread.addressExpression.position)
return listOf(IAstModification.ReplaceExpr(
{ newExpr -> memread.addressExpression = newExpr },
typecast,
memread
))
}
return super.visit(memread)
return emptyList()
}
override fun visit(memwrite: DirectMemoryWrite) {
override fun after(memwrite: DirectMemoryWrite, parent: Node): Iterable<IAstModification> {
// make sure the memory address is an uword
val dt = memwrite.addressExpression.inferType(program)
if(dt.isKnown && dt.typeOrElse(DataType.UWORD)!=DataType.UWORD) {
val literaladdr = memwrite.addressExpression as? NumericLiteralValue
if(literaladdr!=null) {
memwrite.addressExpression = literaladdr.cast(DataType.UWORD)
} else {
memwrite.addressExpression = TypecastExpression(memwrite.addressExpression, DataType.UWORD, true, memwrite.addressExpression.position)
memwrite.addressExpression.parent = memwrite
}
val typecast = (memwrite.addressExpression as? NumericLiteralValue)?.cast(DataType.UWORD)
?: TypecastExpression(memwrite.addressExpression, DataType.UWORD, true, memwrite.addressExpression.position)
return listOf(IAstModification.ReplaceExpr(
{ newExpr -> memwrite.addressExpression = newExpr },
typecast,
memwrite
))
}
super.visit(memwrite)
return emptyList()
}
override fun visit(structLv: StructLiteralValue): Expression {
val litval = super.visit(structLv)
if(litval !is StructLiteralValue)
return litval
override fun after(structLv: StructLiteralValue, parent: Node): Iterable<IAstModification> {
// assignment of a struct literal value, some member values may need proper typecast
val decl = litval.parent as? VarDecl
fun addTypecastsIfNeeded(struct: StructDecl): Iterable<IAstModification> {
val newValues = struct.statements.zip(structLv.values).map { (structMemberDecl, memberValue) ->
val memberDt = (structMemberDecl as VarDecl).datatype
val valueDt = memberValue.inferType(program)
if (valueDt.typeOrElse(memberDt) != memberDt)
TypecastExpression(memberValue, memberDt, true, memberValue.position)
else
memberValue
}
class Replacer(val targetStructLv: StructLiteralValue, val typecastValues: List<Expression>) : IAstModification {
override fun perform() {
targetStructLv.values = typecastValues
typecastValues.forEach { it.linkParents(targetStructLv) }
}
}
return if(structLv.values.zip(newValues).any { (v1, v2) -> v1 !== v2})
listOf(Replacer(structLv, newValues))
else
emptyList()
}
val decl = structLv.parent as? VarDecl
if(decl != null) {
val struct = decl.struct
if(struct != null) {
addTypecastsIfNeeded(litval, struct)
}
if(struct != null)
return addTypecastsIfNeeded(struct)
} else {
val assign = litval.parent as? Assignment
val assign = structLv.parent as? Assignment
if (assign != null) {
val decl2 = assign.target.identifier?.targetVarDecl(program.namespace)
if(decl2 != null) {
val struct = decl2.struct
if(struct != null) {
addTypecastsIfNeeded(litval, struct)
}
if(struct != null)
return addTypecastsIfNeeded(struct)
}
}
}
return litval
}
private fun addTypecastsIfNeeded(structLv: StructLiteralValue, struct: StructDecl) {
structLv.values = struct.statements.zip(structLv.values).map {
val memberDt = (it.first as VarDecl).datatype
val valueDt = it.second.inferType(program)
if (valueDt.typeOrElse(memberDt) != memberDt)
TypecastExpression(it.second, memberDt, true, it.second.position)
else
it.second
}
return emptyList()
}
}

View File

@ -5,6 +5,7 @@ TODO
- remove statements after an exit() or return
- fix warnings about that unreachable code?
- aliases for imported symbols for example perhaps '%alias print = c64scr.print'
- option to load library files from a directory instead of the embedded ones

View File

@ -1,78 +1,19 @@
%import c64utils
%import c64flt
%option enable_floats
%zeropage basicsafe
main {
struct Color {
uword red
uword green
uword blue
}
sub start() {
ubyte x11=44
byte bb0=99
A=x11
while true {
A=99
}
repeat {
ubyte x1
x1=A
A=x1
if A==44 {
ubyte y1
A=y1
} else {
byte bb1=99
bb1 += A
bb0=bb1
}
A=44
} until false
c64scr.print("spstart:")
print_stackpointer()
sub1()
c64scr.print("spend:")
print_stackpointer()
; Color c = [1,2,3] ; TODO fix compiler error
Color c = {1,2,3}
}
sub sub1() {
c64scr.print("sp1:")
print_stackpointer()
sub2()
}
sub sub2() {
c64scr.print("sp2:")
print_stackpointer()
exit(33)
sub3() ; TODO warning about unreachable code
sub3() ; TODO remove statements after a return/exit
c64scr.print("sp2:")
c64scr.print("sp2:")
sub3()
sub3()
sub3()
sub3()
sub3()
sub3()
sub3()
sub3()
sub3()
sub3()
}
sub sub3() {
c64scr.print("sp3:")
print_stackpointer()
}
sub print_stackpointer() {
c64scr.print_ub(X) ; prints stack pointer
c64.CHROUT('\n')
}
}