mirror of
https://github.com/irmen/prog8.git
synced 2024-11-23 07:32:10 +00:00
replaced typecastsAdder with version based on astwalker
This commit is contained in:
parent
a191ec71a4
commit
f265199fbe
@ -33,6 +33,7 @@ internal fun Program.reorderStatements() {
|
||||
internal fun Program.addTypecasts(errors: ErrorReporter) {
|
||||
val caster = TypecastsAdder(this, errors)
|
||||
caster.visit(this)
|
||||
caster.applyModifications()
|
||||
}
|
||||
|
||||
internal fun Module.checkImportedValid() {
|
||||
|
@ -446,7 +446,7 @@ internal class AstChecker(private val program: Program,
|
||||
|
||||
override fun visit(decl: VarDecl) {
|
||||
fun err(msg: String, position: Position?=null) {
|
||||
err(msg, position ?: decl.position)
|
||||
errors.err(msg, position ?: decl.position)
|
||||
}
|
||||
|
||||
// the initializer value can't refer to the variable itself (recursive definition)
|
||||
|
@ -23,7 +23,7 @@ interface IAstModification {
|
||||
}
|
||||
}
|
||||
|
||||
class Replace(val statement: Statement, val replacement: Statement, val parent: Node) : IAstModification {
|
||||
class ReplaceStmt(val statement: Statement, val replacement: Statement, val parent: Node) : IAstModification {
|
||||
override fun perform() {
|
||||
if(parent is INameScope) {
|
||||
val idx = parent.statements.indexOf(statement)
|
||||
@ -34,6 +34,13 @@ interface IAstModification {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class ReplaceExpr(val setter: (newExpr: Expression) -> Unit, val newExpr: Expression, val parent: Node) : IAstModification {
|
||||
override fun perform() {
|
||||
setter(newExpr)
|
||||
newExpr.linkParents(parent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@ -12,7 +12,7 @@ internal class ForeverLoopsMaker: AstWalker() {
|
||||
val numeric = repeatLoop.untilCondition as? NumericLiteralValue
|
||||
if(numeric!=null && numeric.number.toInt() == 0) {
|
||||
val forever = ForeverLoop(repeatLoop.body, repeatLoop.position)
|
||||
return listOf(IAstModification.Replace(repeatLoop, forever, parent))
|
||||
return listOf(IAstModification.ReplaceStmt(repeatLoop, forever, parent))
|
||||
}
|
||||
return emptyList()
|
||||
}
|
||||
@ -21,7 +21,7 @@ internal class ForeverLoopsMaker: AstWalker() {
|
||||
val numeric = whileLoop.condition as? NumericLiteralValue
|
||||
if(numeric!=null && numeric.number.toInt() != 0) {
|
||||
val forever = ForeverLoop(whileLoop.body, whileLoop.position)
|
||||
return listOf(IAstModification.Replace(whileLoop, forever, parent))
|
||||
return listOf(IAstModification.ReplaceStmt(whileLoop, forever, parent))
|
||||
}
|
||||
return emptyList()
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package prog8.ast.processing
|
||||
|
||||
import prog8.ast.IFunctionCall
|
||||
import prog8.ast.INameScope
|
||||
import prog8.ast.Node
|
||||
import prog8.ast.Program
|
||||
import prog8.ast.base.DataType
|
||||
import prog8.ast.base.ErrorReporter
|
||||
@ -11,73 +12,64 @@ import prog8.ast.statements.*
|
||||
import prog8.functions.BuiltinFunctions
|
||||
|
||||
|
||||
internal class TypecastsAdder(private val program: Program,
|
||||
private val errors: ErrorReporter): IAstModifyingVisitor {
|
||||
// Make sure any value assignments get the proper type casts if needed to cast them into the target variable's type.
|
||||
// (this includes function call arguments)
|
||||
class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalker() {
|
||||
/*
|
||||
* Make sure any value assignments get the proper type casts if needed to cast them into the target variable's type.
|
||||
* (this includes function call arguments)
|
||||
*/
|
||||
|
||||
override fun visit(expr: BinaryExpression): Expression {
|
||||
val expr2 = super.visit(expr)
|
||||
if(expr2 !is BinaryExpression)
|
||||
return expr2
|
||||
val leftDt = expr2.left.inferType(program)
|
||||
val rightDt = expr2.right.inferType(program)
|
||||
override fun after(expr: BinaryExpression, parent: Node): Iterable<IAstModification> {
|
||||
val leftDt = expr.left.inferType(program)
|
||||
val rightDt = expr.right.inferType(program)
|
||||
if(leftDt.isKnown && rightDt.isKnown && leftDt!=rightDt) {
|
||||
// determine common datatype and add typecast as required to make left and right equal types
|
||||
val (commonDt, toFix) = BinaryExpression.commonDatatype(leftDt.typeOrElse(DataType.STRUCT), rightDt.typeOrElse(DataType.STRUCT), expr2.left, expr2.right)
|
||||
val (commonDt, toFix) = BinaryExpression.commonDatatype(leftDt.typeOrElse(DataType.STRUCT), rightDt.typeOrElse(DataType.STRUCT), expr.left, expr.right)
|
||||
if(toFix!=null) {
|
||||
when {
|
||||
toFix===expr2.left -> {
|
||||
expr2.left = TypecastExpression(expr2.left, commonDt, true, expr2.left.position)
|
||||
expr2.left.linkParents(expr2)
|
||||
}
|
||||
toFix===expr2.right -> {
|
||||
expr2.right = TypecastExpression(expr2.right, commonDt, true, expr2.right.position)
|
||||
expr2.right.linkParents(expr2)
|
||||
}
|
||||
return when {
|
||||
toFix===expr.left -> listOf(IAstModification.ReplaceExpr(
|
||||
{ newExpr -> expr.left = newExpr },
|
||||
TypecastExpression(expr.left, commonDt, true, expr.left.position),
|
||||
expr))
|
||||
toFix===expr.right -> listOf(IAstModification.ReplaceExpr(
|
||||
{ newExpr -> expr.right = newExpr },
|
||||
TypecastExpression(expr.right, commonDt, true, expr.right.position),
|
||||
expr))
|
||||
else -> throw FatalAstException("confused binary expression side")
|
||||
}
|
||||
}
|
||||
}
|
||||
return expr2
|
||||
return emptyList()
|
||||
}
|
||||
|
||||
override fun visit(assignment: Assignment): Statement {
|
||||
val assg = super.visit(assignment)
|
||||
if(assg !is Assignment)
|
||||
return assg
|
||||
|
||||
override fun after(assignment: Assignment, parent: Node): Iterable<IAstModification> {
|
||||
// see if a typecast is needed to convert the value's type into the proper target type
|
||||
val valueItype = assg.value.inferType(program)
|
||||
val targetItype = assg.target.inferType(program, assg)
|
||||
|
||||
val valueItype = assignment.value.inferType(program)
|
||||
val targetItype = assignment.target.inferType(program, assignment)
|
||||
if(targetItype.isKnown && valueItype.isKnown) {
|
||||
val targettype = targetItype.typeOrElse(DataType.STRUCT)
|
||||
val valuetype = valueItype.typeOrElse(DataType.STRUCT)
|
||||
if (valuetype != targettype) {
|
||||
if (valuetype isAssignableTo targettype) {
|
||||
assg.value = TypecastExpression(assg.value, targettype, true, assg.value.position)
|
||||
assg.value.linkParents(assg)
|
||||
}
|
||||
// if they're not assignable, we'll get a proper error later from the AstChecker
|
||||
if (valuetype isAssignableTo targettype)
|
||||
return listOf(IAstModification.ReplaceExpr(
|
||||
{ newExpr -> assignment.value=newExpr },
|
||||
TypecastExpression(assignment.value, targettype, true, assignment.value.position),
|
||||
assignment))
|
||||
}
|
||||
}
|
||||
return assg
|
||||
return emptyList()
|
||||
}
|
||||
|
||||
override fun visit(functionCallStatement: FunctionCallStatement): Statement {
|
||||
checkFunctionCallArguments(functionCallStatement, functionCallStatement.definingScope())
|
||||
return super.visit(functionCallStatement)
|
||||
override fun after(functionCallStatement: FunctionCallStatement, parent: Node): Iterable<IAstModification> {
|
||||
return afterFunctionCallArgs(functionCallStatement, functionCallStatement.definingScope())
|
||||
}
|
||||
|
||||
override fun visit(functionCall: FunctionCall): Expression {
|
||||
checkFunctionCallArguments(functionCall, functionCall.definingScope())
|
||||
return super.visit(functionCall)
|
||||
override fun after(functionCall: FunctionCall, parent: Node): Iterable<IAstModification> {
|
||||
return afterFunctionCallArgs(functionCall, functionCall.definingScope())
|
||||
}
|
||||
|
||||
private fun checkFunctionCallArguments(call: IFunctionCall, scope: INameScope) {
|
||||
private fun afterFunctionCallArgs(call: IFunctionCall, scope: INameScope): Iterable<IAstModification> {
|
||||
// see if a typecast is needed to convert the arguments into the required parameter's type
|
||||
when(val sub = call.target.targetStatement(scope)) {
|
||||
return when(val sub = call.target.targetStatement(scope)) {
|
||||
is Subroutine -> {
|
||||
for(arg in sub.parameters.zip(call.args.withIndex())) {
|
||||
val argItype = arg.second.value.inferType(program)
|
||||
@ -86,14 +78,15 @@ internal class TypecastsAdder(private val program: Program,
|
||||
val requiredType = arg.first.type
|
||||
if (requiredType != argtype) {
|
||||
if (argtype isAssignableTo requiredType) {
|
||||
val typecasted = TypecastExpression(arg.second.value, requiredType, true, arg.second.value.position)
|
||||
typecasted.linkParents(arg.second.value.parent)
|
||||
call.args[arg.second.index] = typecasted
|
||||
return listOf(IAstModification.ReplaceExpr(
|
||||
{ newExpr -> call.args[arg.second.index] = newExpr },
|
||||
TypecastExpression(arg.second.value, requiredType, true, arg.second.value.position),
|
||||
call as Node))
|
||||
}
|
||||
// if they're not assignable, we'll get a proper error later from the AstChecker
|
||||
}
|
||||
}
|
||||
}
|
||||
emptyList()
|
||||
}
|
||||
is BuiltinFunctionStatementPlaceholder -> {
|
||||
val func = BuiltinFunctions.getValue(sub.name)
|
||||
@ -107,93 +100,103 @@ internal class TypecastsAdder(private val program: Program,
|
||||
continue
|
||||
for (possibleType in arg.first.possibleDatatypes) {
|
||||
if (argtype isAssignableTo possibleType) {
|
||||
val typecasted = TypecastExpression(arg.second.value, possibleType, true, arg.second.value.position)
|
||||
typecasted.linkParents(arg.second.value.parent)
|
||||
call.args[arg.second.index] = typecasted
|
||||
break
|
||||
return listOf(IAstModification.ReplaceExpr(
|
||||
{ newExpr -> call.args[arg.second.index] = newExpr },
|
||||
TypecastExpression(arg.second.value, possibleType, true, arg.second.value.position),
|
||||
call as Node
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
emptyList()
|
||||
}
|
||||
null -> {}
|
||||
null -> emptyList()
|
||||
else -> throw FatalAstException("call to something weird $sub ${call.target}")
|
||||
}
|
||||
}
|
||||
|
||||
override fun visit(typecast: TypecastExpression): Expression {
|
||||
override fun after(typecast: TypecastExpression, parent: Node): Iterable<IAstModification> {
|
||||
// warn about any implicit type casts to Float, because that may not be intended
|
||||
if(typecast.implicit && typecast.type in setOf(DataType.FLOAT, DataType.ARRAY_F)) {
|
||||
errors.warn("byte or word value implicitly converted to float. Suggestion: use explicit cast as float, a float number, or revert to integer arithmetic", typecast.position)
|
||||
}
|
||||
return super.visit(typecast)
|
||||
return emptyList()
|
||||
}
|
||||
|
||||
override fun visit(memread: DirectMemoryRead): Expression {
|
||||
override fun after(memread: DirectMemoryRead, parent: Node): Iterable<IAstModification> {
|
||||
// make sure the memory address is an uword
|
||||
val dt = memread.addressExpression.inferType(program)
|
||||
if(dt.isKnown && dt.typeOrElse(DataType.UWORD)!=DataType.UWORD) {
|
||||
val literaladdr = memread.addressExpression as? NumericLiteralValue
|
||||
if(literaladdr!=null) {
|
||||
memread.addressExpression = literaladdr.cast(DataType.UWORD)
|
||||
} else {
|
||||
memread.addressExpression = TypecastExpression(memread.addressExpression, DataType.UWORD, true, memread.addressExpression.position)
|
||||
memread.addressExpression.parent = memread
|
||||
}
|
||||
val typecast = (memread.addressExpression as? NumericLiteralValue)?.cast(DataType.UWORD)
|
||||
?: TypecastExpression(memread.addressExpression, DataType.UWORD, true, memread.addressExpression.position)
|
||||
return listOf(IAstModification.ReplaceExpr(
|
||||
{ newExpr -> memread.addressExpression = newExpr },
|
||||
typecast,
|
||||
memread
|
||||
))
|
||||
}
|
||||
return super.visit(memread)
|
||||
return emptyList()
|
||||
}
|
||||
|
||||
override fun visit(memwrite: DirectMemoryWrite) {
|
||||
override fun after(memwrite: DirectMemoryWrite, parent: Node): Iterable<IAstModification> {
|
||||
// make sure the memory address is an uword
|
||||
val dt = memwrite.addressExpression.inferType(program)
|
||||
if(dt.isKnown && dt.typeOrElse(DataType.UWORD)!=DataType.UWORD) {
|
||||
val literaladdr = memwrite.addressExpression as? NumericLiteralValue
|
||||
if(literaladdr!=null) {
|
||||
memwrite.addressExpression = literaladdr.cast(DataType.UWORD)
|
||||
} else {
|
||||
memwrite.addressExpression = TypecastExpression(memwrite.addressExpression, DataType.UWORD, true, memwrite.addressExpression.position)
|
||||
memwrite.addressExpression.parent = memwrite
|
||||
}
|
||||
val typecast = (memwrite.addressExpression as? NumericLiteralValue)?.cast(DataType.UWORD)
|
||||
?: TypecastExpression(memwrite.addressExpression, DataType.UWORD, true, memwrite.addressExpression.position)
|
||||
return listOf(IAstModification.ReplaceExpr(
|
||||
{ newExpr -> memwrite.addressExpression = newExpr },
|
||||
typecast,
|
||||
memwrite
|
||||
))
|
||||
}
|
||||
super.visit(memwrite)
|
||||
return emptyList()
|
||||
}
|
||||
|
||||
override fun visit(structLv: StructLiteralValue): Expression {
|
||||
val litval = super.visit(structLv)
|
||||
if(litval !is StructLiteralValue)
|
||||
return litval
|
||||
override fun after(structLv: StructLiteralValue, parent: Node): Iterable<IAstModification> {
|
||||
// assignment of a struct literal value, some member values may need proper typecast
|
||||
|
||||
val decl = litval.parent as? VarDecl
|
||||
fun addTypecastsIfNeeded(struct: StructDecl): Iterable<IAstModification> {
|
||||
val newValues = struct.statements.zip(structLv.values).map { (structMemberDecl, memberValue) ->
|
||||
val memberDt = (structMemberDecl as VarDecl).datatype
|
||||
val valueDt = memberValue.inferType(program)
|
||||
if (valueDt.typeOrElse(memberDt) != memberDt)
|
||||
TypecastExpression(memberValue, memberDt, true, memberValue.position)
|
||||
else
|
||||
memberValue
|
||||
}
|
||||
|
||||
class Replacer(val targetStructLv: StructLiteralValue, val typecastValues: List<Expression>) : IAstModification {
|
||||
override fun perform() {
|
||||
targetStructLv.values = typecastValues
|
||||
typecastValues.forEach { it.linkParents(targetStructLv) }
|
||||
}
|
||||
}
|
||||
|
||||
return if(structLv.values.zip(newValues).any { (v1, v2) -> v1 !== v2})
|
||||
listOf(Replacer(structLv, newValues))
|
||||
else
|
||||
emptyList()
|
||||
}
|
||||
|
||||
val decl = structLv.parent as? VarDecl
|
||||
if(decl != null) {
|
||||
val struct = decl.struct
|
||||
if(struct != null) {
|
||||
addTypecastsIfNeeded(litval, struct)
|
||||
}
|
||||
if(struct != null)
|
||||
return addTypecastsIfNeeded(struct)
|
||||
} else {
|
||||
val assign = litval.parent as? Assignment
|
||||
val assign = structLv.parent as? Assignment
|
||||
if (assign != null) {
|
||||
val decl2 = assign.target.identifier?.targetVarDecl(program.namespace)
|
||||
if(decl2 != null) {
|
||||
val struct = decl2.struct
|
||||
if(struct != null) {
|
||||
addTypecastsIfNeeded(litval, struct)
|
||||
}
|
||||
if(struct != null)
|
||||
return addTypecastsIfNeeded(struct)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return litval
|
||||
}
|
||||
|
||||
private fun addTypecastsIfNeeded(structLv: StructLiteralValue, struct: StructDecl) {
|
||||
structLv.values = struct.statements.zip(structLv.values).map {
|
||||
val memberDt = (it.first as VarDecl).datatype
|
||||
val valueDt = it.second.inferType(program)
|
||||
if (valueDt.typeOrElse(memberDt) != memberDt)
|
||||
TypecastExpression(it.second, memberDt, true, it.second.position)
|
||||
else
|
||||
it.second
|
||||
}
|
||||
return emptyList()
|
||||
}
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ TODO
|
||||
- remove statements after an exit() or return
|
||||
- fix warnings about that unreachable code?
|
||||
|
||||
- aliases for imported symbols for example perhaps '%alias print = c64scr.print'
|
||||
- option to load library files from a directory instead of the embedded ones
|
||||
|
||||
|
||||
|
@ -1,78 +1,19 @@
|
||||
%import c64utils
|
||||
%import c64flt
|
||||
%option enable_floats
|
||||
%zeropage basicsafe
|
||||
|
||||
main {
|
||||
|
||||
struct Color {
|
||||
uword red
|
||||
uword green
|
||||
uword blue
|
||||
}
|
||||
|
||||
sub start() {
|
||||
ubyte x11=44
|
||||
byte bb0=99
|
||||
|
||||
A=x11
|
||||
|
||||
|
||||
while true {
|
||||
A=99
|
||||
}
|
||||
|
||||
repeat {
|
||||
ubyte x1
|
||||
|
||||
x1=A
|
||||
A=x1
|
||||
|
||||
if A==44 {
|
||||
ubyte y1
|
||||
A=y1
|
||||
} else {
|
||||
byte bb1=99
|
||||
bb1 += A
|
||||
bb0=bb1
|
||||
}
|
||||
A=44
|
||||
} until false
|
||||
|
||||
|
||||
c64scr.print("spstart:")
|
||||
print_stackpointer()
|
||||
sub1()
|
||||
c64scr.print("spend:")
|
||||
print_stackpointer()
|
||||
; Color c = [1,2,3] ; TODO fix compiler error
|
||||
Color c = {1,2,3}
|
||||
}
|
||||
|
||||
sub sub1() {
|
||||
c64scr.print("sp1:")
|
||||
print_stackpointer()
|
||||
sub2()
|
||||
}
|
||||
|
||||
sub sub2() {
|
||||
c64scr.print("sp2:")
|
||||
print_stackpointer()
|
||||
exit(33)
|
||||
sub3() ; TODO warning about unreachable code
|
||||
sub3() ; TODO remove statements after a return/exit
|
||||
c64scr.print("sp2:")
|
||||
c64scr.print("sp2:")
|
||||
sub3()
|
||||
|
||||
sub3()
|
||||
sub3()
|
||||
sub3()
|
||||
sub3()
|
||||
sub3()
|
||||
sub3()
|
||||
sub3()
|
||||
sub3()
|
||||
sub3()
|
||||
}
|
||||
|
||||
sub sub3() {
|
||||
c64scr.print("sp3:")
|
||||
print_stackpointer()
|
||||
}
|
||||
|
||||
sub print_stackpointer() {
|
||||
c64scr.print_ub(X) ; prints stack pointer
|
||||
c64.CHROUT('\n')
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user