Merge branch 'master' into c128target

# Conflicts:
#	examples/test.p8
This commit is contained in:
Irmen de Jong
2021-12-28 13:57:27 +01:00
33 changed files with 937 additions and 606 deletions

View File

@@ -760,7 +760,7 @@ class AsmGen(private val program: Program,
outputSourceLine(stmt) outputSourceLine(stmt)
when(stmt) { when(stmt) {
is VarDecl -> translate(stmt) is VarDecl -> translate(stmt)
is NopStatement -> {} is Nop -> {}
is Directive -> translate(stmt) is Directive -> translate(stmt)
is Return -> translate(stmt) is Return -> translate(stmt)
is Subroutine -> translateSubroutine(stmt) is Subroutine -> translateSubroutine(stmt)
@@ -775,24 +775,21 @@ class AsmGen(private val program: Program,
} }
} }
is Assignment -> assignmentAsmGen.translate(stmt) is Assignment -> assignmentAsmGen.translate(stmt)
is Jump -> translate(stmt) is Jump -> jmp(getJumpTarget(stmt))
is GoSub -> translate(stmt)
is PostIncrDecr -> postincrdecrAsmGen.translate(stmt) is PostIncrDecr -> postincrdecrAsmGen.translate(stmt)
is Label -> translate(stmt) is Label -> translate(stmt)
is BranchStatement -> translate(stmt) is Branch -> translate(stmt)
is IfStatement -> translate(stmt) is IfElse -> translate(stmt)
is ForLoop -> forloopsAsmGen.translate(stmt) is ForLoop -> forloopsAsmGen.translate(stmt)
is Break -> {
if(loopEndLabels.isEmpty())
throw AssemblyError("break statement out of context ${stmt.position}")
jmp(loopEndLabels.peek())
}
is WhileLoop -> translate(stmt)
is RepeatLoop -> translate(stmt) is RepeatLoop -> translate(stmt)
is UntilLoop -> translate(stmt) is When -> translate(stmt)
is WhenStatement -> translate(stmt)
is BuiltinFunctionStatementPlaceholder -> throw AssemblyError("builtin function should not have placeholder anymore?")
is AnonymousScope -> translate(stmt) is AnonymousScope -> translate(stmt)
is BuiltinFunctionPlaceholder -> throw AssemblyError("builtin function should not have placeholder anymore")
is UntilLoop -> throw AssemblyError("do..until should have been desugared to jumps")
is WhileLoop -> throw AssemblyError("while should have been desugared to jumps")
is Block -> throw AssemblyError("block should have been handled elsewhere") is Block -> throw AssemblyError("block should have been handled elsewhere")
is Break -> throw AssemblyError("break should have been replaced by goto")
else -> throw AssemblyError("missing asm translation for $stmt") else -> throw AssemblyError("missing asm translation for $stmt")
} }
} }
@@ -893,11 +890,11 @@ class AsmGen(private val program: Program,
internal fun translateExpression(expression: Expression) = internal fun translateExpression(expression: Expression) =
expressionsAsmGen.translateExpression(expression) expressionsAsmGen.translateExpression(expression)
internal fun translateBuiltinFunctionCallExpression(functionCall: FunctionCall, signature: FSignature, resultToStack: Boolean, resultRegister: RegisterOrPair?) = internal fun translateBuiltinFunctionCallExpression(functionCallExpr: FunctionCallExpr, signature: FSignature, resultToStack: Boolean, resultRegister: RegisterOrPair?) =
builtinFunctionsAsmGen.translateFunctioncallExpression(functionCall, signature, resultToStack, resultRegister) builtinFunctionsAsmGen.translateFunctioncallExpression(functionCallExpr, signature, resultToStack, resultRegister)
internal fun translateFunctionCall(functionCall: FunctionCall, isExpression: Boolean) = internal fun translateFunctionCall(functionCallExpr: FunctionCallExpr, isExpression: Boolean) =
functioncallAsmGen.translateFunctionCall(functionCall, isExpression) functioncallAsmGen.translateFunctionCall(functionCallExpr, isExpression)
internal fun saveXbeforeCall(functionCall: IFunctionCall) = internal fun saveXbeforeCall(functionCall: IFunctionCall) =
functioncallAsmGen.saveXbeforeCall(functionCall) functioncallAsmGen.saveXbeforeCall(functionCall)
@@ -1097,21 +1094,26 @@ class AsmGen(private val program: Program,
} }
} }
private fun translate(stmt: IfStatement) { private fun translate(stmt: IfElse) {
requireComparisonExpression(stmt.condition) // IfStatement: condition must be of form 'x <comparison> <value>' requireComparisonExpression(stmt.condition) // IfStatement: condition must be of form 'x <comparison> <value>'
val booleanCondition = stmt.condition as BinaryExpression val booleanCondition = stmt.condition as BinaryExpression
if (stmt.elsepart.isEmpty()) { if (stmt.elsepart.isEmpty()) {
val jump = stmt.truepart.statements.singleOrNull()
if(jump is Jump) {
translateCompareAndJumpIfTrue(booleanCondition, jump)
} else {
val endLabel = makeLabel("if_end") val endLabel = makeLabel("if_end")
translateComparisonExpressionWithJumpIfFalse(booleanCondition, endLabel) translateCompareAndJumpIfFalse(booleanCondition, endLabel)
translate(stmt.truepart) translate(stmt.truepart)
out(endLabel) out(endLabel)
} }
}
else { else {
// both true and else parts // both true and else parts
val elseLabel = makeLabel("if_else") val elseLabel = makeLabel("if_else")
val endLabel = makeLabel("if_end") val endLabel = makeLabel("if_end")
translateComparisonExpressionWithJumpIfFalse(booleanCondition, elseLabel) translateCompareAndJumpIfFalse(booleanCondition, elseLabel)
translate(stmt.truepart) translate(stmt.truepart)
jmp(endLabel) jmp(endLabel)
out(elseLabel) out(elseLabel)
@@ -1272,36 +1274,7 @@ $repeatLabel lda $counterVar
return counterVar return counterVar
} }
private fun translate(stmt: WhileLoop) { private fun translate(stmt: When) {
requireComparisonExpression(stmt.condition) // WhileLoop: condition must be of form 'x <comparison> <value>'
val booleanCondition = stmt.condition as BinaryExpression
val whileLabel = makeLabel("while")
val endLabel = makeLabel("whileend")
loopEndLabels.push(endLabel)
out(whileLabel)
translateComparisonExpressionWithJumpIfFalse(booleanCondition, endLabel)
translate(stmt.body)
jmp(whileLabel)
out(endLabel)
loopEndLabels.pop()
}
private fun translate(stmt: UntilLoop) {
requireComparisonExpression(stmt.condition) // UntilLoop: condition must be of form 'x <comparison> <value>'
val booleanCondition = stmt.condition as BinaryExpression
val repeatLabel = makeLabel("repeat")
val endLabel = makeLabel("repeatend")
loopEndLabels.push(endLabel)
out(repeatLabel)
translate(stmt.body)
translateComparisonExpressionWithJumpIfFalse(booleanCondition, repeatLabel)
out(endLabel)
loopEndLabels.pop()
}
private fun translate(stmt: WhenStatement) {
val endLabel = makeLabel("choice_end") val endLabel = makeLabel("choice_end")
val choiceBlocks = mutableListOf<Pair<String, AnonymousScope>>() val choiceBlocks = mutableListOf<Pair<String, AnonymousScope>>()
val conditionDt = stmt.condition.inferType(program) val conditionDt = stmt.condition.inferType(program)
@@ -1355,45 +1328,23 @@ $repeatLabel lda $counterVar
scope.statements.forEach{ translate(it) } scope.statements.forEach{ translate(it) }
} }
private fun translate(stmt: BranchStatement) { private fun translate(stmt: Branch) {
if(stmt.truepart.isEmpty() && stmt.elsepart.isNotEmpty()) if(stmt.truepart.isEmpty() && stmt.elsepart.isNotEmpty())
throw AssemblyError("only else part contains code, shoud have been switched already") throw AssemblyError("only else part contains code, shoud have been switched already")
val jump = stmt.truepart.statements.first() as? Jump val jump = stmt.truepart.statements.first() as? Jump
if(jump!=null && !jump.isGosub) { if(jump!=null) {
// branch with only a jump (goto) // branch with only a jump (goto)
val instruction = branchInstruction(stmt.condition, false) val instruction = branchInstruction(stmt.condition, false)
out(" $instruction ${getJumpTarget(jump)}") out(" $instruction ${getJumpTarget(jump)}")
translate(stmt.elsepart) translate(stmt.elsepart)
} else { } else {
val truePartIsJustBreak = stmt.truepart.statements.firstOrNull() is Break
val elsePartIsJustBreak = stmt.elsepart.statements.firstOrNull() is Break
if(stmt.elsepart.isEmpty()) { if(stmt.elsepart.isEmpty()) {
if(truePartIsJustBreak) {
// branch with just a break (jump out of loop)
val instruction = branchInstruction(stmt.condition, false)
val loopEndLabel = loopEndLabels.peek()
out(" $instruction $loopEndLabel")
} else {
val instruction = branchInstruction(stmt.condition, true) val instruction = branchInstruction(stmt.condition, true)
val elseLabel = makeLabel("branch_else") val elseLabel = makeLabel("branch_else")
out(" $instruction $elseLabel") out(" $instruction $elseLabel")
translate(stmt.truepart) translate(stmt.truepart)
out(elseLabel) out(elseLabel)
}
}
else if(truePartIsJustBreak) {
// branch with just a break (jump out of loop)
val instruction = branchInstruction(stmt.condition, false)
val loopEndLabel = loopEndLabels.peek()
out(" $instruction $loopEndLabel")
translate(stmt.elsepart)
} else if(elsePartIsJustBreak) {
// branch with just a break (jump out of loop) but true/false inverted
val instruction = branchInstruction(stmt.condition, true)
val loopEndLabel = loopEndLabels.peek()
out(" $instruction $loopEndLabel")
translate(stmt.truepart)
} else { } else {
val instruction = branchInstruction(stmt.condition, true) val instruction = branchInstruction(stmt.condition, true)
val elseLabel = makeLabel("branch_else") val elseLabel = makeLabel("branch_else")
@@ -1450,22 +1401,17 @@ $label nop""")
} }
} }
private fun translate(jump: Jump) { private fun translate(gosub: GoSub) {
if(jump.isGosub) { val tgt = gosub.identifier!!.targetSubroutine(program)
jump as GoSub
val tgt = jump.identifier!!.targetSubroutine(program)
if(tgt!=null && tgt.isAsmSubroutine) { if(tgt!=null && tgt.isAsmSubroutine) {
// no need to rescue X , this has been taken care of already // no need to rescue X , this has been taken care of already
out(" jsr ${getJumpTarget(jump)}") out(" jsr ${getJumpTarget(gosub)}")
} else { } else {
saveXbeforeCall(jump) saveXbeforeCall(gosub)
out(" jsr ${getJumpTarget(jump)}") out(" jsr ${getJumpTarget(gosub)}")
restoreXafterCall(jump) restoreXafterCall(gosub)
} }
} }
else
jmp(getJumpTarget(jump))
}
private fun getJumpTarget(jump: Jump): String { private fun getJumpTarget(jump: Jump): String {
val ident = jump.identifier val ident = jump.identifier
@@ -1479,6 +1425,18 @@ $label nop""")
} }
} }
private fun getJumpTarget(gosub: GoSub): String {
val ident = gosub.identifier
val label = gosub.generatedLabel
val addr = gosub.address
return when {
ident!=null -> asmSymbolName(ident)
label!=null -> label
addr!=null -> addr.toHex()
else -> "????"
}
}
private fun translate(ret: Return, withRts: Boolean=true) { private fun translate(ret: Return, withRts: Boolean=true) {
ret.value?.let { returnvalue -> ret.value?.let { returnvalue ->
val sub = ret.definingSubroutine!! val sub = ret.definingSubroutine!!
@@ -1662,41 +1620,46 @@ $label nop""")
return false return false
} }
private fun translateCompareAndJumpIfTrue(expr: BinaryExpression, jump: Jump) {
if(expr.operator !in ComparisonOperators)
throw AssemblyError("must be comparison expression")
// invert the comparison, so we can reuse the JumpIfFalse code generation routines
val invertedComparisonOperator = invertedComparisonOperator(expr.operator)
?: throw AssemblyError("can't invert comparison $expr")
private fun translateComparisonExpressionWithJumpIfFalse(expr: BinaryExpression, jumpIfFalseLabel: String) { val left = expr.left
// This is a helper routine called from while, do-util, and if expressions to generate optimized conditional branching code. val right = expr.right
// First, if it is of the form: <constvalue> <comparison> X , then flip the expression so the constant is always the right operand. val rightConstVal = right.constValue(program)
var left = expr.left val label = when {
var right = expr.right jump.generatedLabel!=null -> jump.generatedLabel!!
var operator = expr.operator jump.identifier!=null -> asmSymbolName(jump.identifier!!)
var leftConstVal = left.constValue(program) jump.address!=null -> jump.address!!.toHex()
var rightConstVal = right.constValue(program) else -> throw AssemblyError("weird jump")
}
// make sure the constant value is on the right of the comparison expression if (rightConstVal!=null && rightConstVal.number == 0.0)
if(leftConstVal!=null) { testZeroAndJump(left, invertedComparisonOperator, label)
val tmp = left else {
left = right val leftConstVal = left.constValue(program)
right = tmp testNonzeroComparisonAndJump(left, invertedComparisonOperator, right, label, leftConstVal, rightConstVal)
val tmp2 = leftConstVal
leftConstVal = rightConstVal
rightConstVal = tmp2
when(expr.operator) {
"<" -> operator = ">"
"<=" -> operator = ">="
">" -> operator = "<"
">=" -> operator = "<="
} }
} }
private fun translateCompareAndJumpIfFalse(expr: BinaryExpression, jumpIfFalseLabel: String) {
val left = expr.left
val right = expr.right
val operator = expr.operator
val leftConstVal = left.constValue(program)
val rightConstVal = right.constValue(program)
if (rightConstVal!=null && rightConstVal.number == 0.0) if (rightConstVal!=null && rightConstVal.number == 0.0)
jumpIfZeroOrNot(left, operator, jumpIfFalseLabel) testZeroAndJump(left, operator, jumpIfFalseLabel)
else else
jumpIfComparison(left, operator, right, jumpIfFalseLabel, leftConstVal, rightConstVal) testNonzeroComparisonAndJump(left, operator, right, jumpIfFalseLabel, leftConstVal, rightConstVal)
} }
private fun jumpIfZeroOrNot( private fun testZeroAndJump(
left: Expression, left: Expression,
operator: String, operator: String,
jumpIfFalseLabel: String jumpIfFalseLabel: String
@@ -1711,7 +1674,7 @@ $label nop""")
} }
if(dt==DataType.UBYTE) { if(dt==DataType.UBYTE) {
assignExpressionToRegister(left, RegisterOrPair.A, false) assignExpressionToRegister(left, RegisterOrPair.A, false)
if (left is FunctionCall && !left.isSimple) if (left is FunctionCallExpr && !left.isSimple)
out(" cmp #0") out(" cmp #0")
} else { } else {
assignExpressionToRegister(left, RegisterOrPair.AY, false) assignExpressionToRegister(left, RegisterOrPair.AY, false)
@@ -1727,7 +1690,7 @@ $label nop""")
} }
DataType.BYTE -> { DataType.BYTE -> {
assignExpressionToRegister(left, RegisterOrPair.A, true) assignExpressionToRegister(left, RegisterOrPair.A, true)
if (left is FunctionCall && !left.isSimple) if (left is FunctionCallExpr && !left.isSimple)
out(" cmp #0") out(" cmp #0")
when (operator) { when (operator) {
"==" -> out(" bne $jumpIfFalseLabel") "==" -> out(" bne $jumpIfFalseLabel")
@@ -1785,7 +1748,7 @@ $label nop""")
} }
} }
private fun jumpIfComparison( private fun testNonzeroComparisonAndJump(
left: Expression, left: Expression,
operator: String, operator: String,
right: Expression, right: Expression,

View File

@@ -44,7 +44,7 @@ internal fun asmsub6502ArgsHaveRegisterClobberRisk(args: List<Expression>,
it.registerOrPair in listOf(RegisterOrPair.Y, RegisterOrPair.AY, RegisterOrPair.XY) it.registerOrPair in listOf(RegisterOrPair.Y, RegisterOrPair.AY, RegisterOrPair.XY)
} }
} }
is FunctionCall -> { is FunctionCallExpr -> {
if (expr.target.nameInSource == listOf("lsb") || expr.target.nameInSource == listOf("msb")) if (expr.target.nameInSource == listOf("lsb") || expr.target.nameInSource == listOf("msb"))
return isClobberRisk(expr.args[0]) return isClobberRisk(expr.args[0])
if (expr.target.nameInSource == listOf("mkword")) if (expr.target.nameInSource == listOf("mkword"))

View File

@@ -19,7 +19,7 @@ import prog8.compilerinterface.subroutineFloatEvalResultVar2
internal class BuiltinFunctionsAsmGen(private val program: Program, private val asmgen: AsmGen, private val assignAsmGen: AssignmentAsmGen) { internal class BuiltinFunctionsAsmGen(private val program: Program, private val asmgen: AsmGen, private val assignAsmGen: AssignmentAsmGen) {
internal fun translateFunctioncallExpression(fcall: FunctionCall, func: FSignature, resultToStack: Boolean, resultRegister: RegisterOrPair?) { internal fun translateFunctioncallExpression(fcall: FunctionCallExpr, func: FSignature, resultToStack: Boolean, resultRegister: RegisterOrPair?) {
translateFunctioncall(fcall, func, discardResult = false, resultToStack = resultToStack, resultRegister = resultRegister) translateFunctioncall(fcall, func, discardResult = false, resultToStack = resultToStack, resultRegister = resultRegister)
} }
@@ -414,7 +414,7 @@ internal class BuiltinFunctionsAsmGen(private val program: Program, private val
} }
private fun funcMemory(fcall: IFunctionCall, discardResult: Boolean, resultToStack: Boolean, resultRegister: RegisterOrPair?) { private fun funcMemory(fcall: IFunctionCall, discardResult: Boolean, resultToStack: Boolean, resultRegister: RegisterOrPair?) {
if(discardResult || fcall !is FunctionCall) if(discardResult || fcall !is FunctionCallExpr)
throw AssemblyError("should not discard result of memory allocation at $fcall") throw AssemblyError("should not discard result of memory allocation at $fcall")
val nameRef = fcall.args[0] as IdentifierReference val nameRef = fcall.args[0] as IdentifierReference
val name = (nameRef.targetVarDecl(program)!!.value as StringLiteralValue).value val name = (nameRef.targetVarDecl(program)!!.value as StringLiteralValue).value

View File

@@ -3,7 +3,7 @@ package prog8.compiler.target.cpu6502.codegen
import prog8.ast.Program import prog8.ast.Program
import prog8.ast.base.* import prog8.ast.base.*
import prog8.ast.expressions.* import prog8.ast.expressions.*
import prog8.ast.statements.BuiltinFunctionStatementPlaceholder import prog8.ast.statements.BuiltinFunctionPlaceholder
import prog8.ast.statements.Subroutine import prog8.ast.statements.Subroutine
import prog8.ast.toHex import prog8.ast.toHex
import prog8.compiler.target.AssemblyError import prog8.compiler.target.AssemblyError
@@ -36,18 +36,18 @@ internal class ExpressionsAsmGen(private val program: Program, private val asmge
is DirectMemoryRead -> asmgen.translateDirectMemReadExpressionToRegAorStack(expression, true) is DirectMemoryRead -> asmgen.translateDirectMemReadExpressionToRegAorStack(expression, true)
is NumericLiteralValue -> translateExpression(expression) is NumericLiteralValue -> translateExpression(expression)
is IdentifierReference -> translateExpression(expression) is IdentifierReference -> translateExpression(expression)
is FunctionCall -> translateFunctionCallResultOntoStack(expression) is FunctionCallExpr -> translateFunctionCallResultOntoStack(expression)
is ArrayLiteralValue, is StringLiteralValue -> throw AssemblyError("no asm gen for string/array literal value assignment - should have been replaced by a variable") is ArrayLiteralValue, is StringLiteralValue -> throw AssemblyError("no asm gen for string/array literal value assignment - should have been replaced by a variable")
is RangeExpr -> throw AssemblyError("range expression should have been changed into array values") is RangeExpr -> throw AssemblyError("range expression should have been changed into array values")
is CharLiteral -> throw AssemblyError("charliteral should have been replaced by ubyte using certain encoding") is CharLiteral -> throw AssemblyError("charliteral should have been replaced by ubyte using certain encoding")
} }
} }
private fun translateFunctionCallResultOntoStack(call: FunctionCall) { private fun translateFunctionCallResultOntoStack(call: FunctionCallExpr) {
// only for use in nested expression evaluation // only for use in nested expression evaluation
val sub = call.target.targetStatement(program) val sub = call.target.targetStatement(program)
if(sub is BuiltinFunctionStatementPlaceholder) { if(sub is BuiltinFunctionPlaceholder) {
val builtinFunc = BuiltinFunctions.getValue(sub.name) val builtinFunc = BuiltinFunctions.getValue(sub.name)
asmgen.translateBuiltinFunctionCallExpression(call, builtinFunc, true, null) asmgen.translateBuiltinFunctionCallExpression(call, builtinFunc, true, null)
} else { } else {

View File

@@ -167,7 +167,7 @@ internal class AsmAssignSource(val kind: SourceStorageKind,
val dt = value.inferType(program).getOrElse { throw AssemblyError("unknown dt") } val dt = value.inferType(program).getOrElse { throw AssemblyError("unknown dt") }
AsmAssignSource(SourceStorageKind.ARRAY, program, asmgen, dt, array = value) AsmAssignSource(SourceStorageKind.ARRAY, program, asmgen, dt, array = value)
} }
is FunctionCall -> { is FunctionCallExpr -> {
when (val sub = value.target.targetStatement(program)) { when (val sub = value.target.targetStatement(program)) {
is Subroutine -> { is Subroutine -> {
val returnType = sub.returntypes.zip(sub.asmReturnvaluesRegisters).firstOrNull { rr -> rr.second.registerOrPair != null || rr.second.statusflag!=null }?.first val returnType = sub.returntypes.zip(sub.asmReturnvaluesRegisters).firstOrNull { rr -> rr.second.registerOrPair != null || rr.second.statusflag!=null }?.first
@@ -175,7 +175,7 @@ internal class AsmAssignSource(val kind: SourceStorageKind,
AsmAssignSource(SourceStorageKind.EXPRESSION, program, asmgen, returnType, expression = value) AsmAssignSource(SourceStorageKind.EXPRESSION, program, asmgen, returnType, expression = value)
} }
is BuiltinFunctionStatementPlaceholder -> { is BuiltinFunctionPlaceholder -> {
val returnType = value.inferType(program) val returnType = value.inferType(program)
AsmAssignSource(SourceStorageKind.EXPRESSION, program, asmgen, returnType.getOrElse { throw AssemblyError("unknown dt") }, expression = value) AsmAssignSource(SourceStorageKind.EXPRESSION, program, asmgen, returnType.getOrElse { throw AssemblyError("unknown dt") }, expression = value)
} }

View File

@@ -159,7 +159,7 @@ internal class AssignmentAsmGen(private val program: Program, private val asmgen
is ArrayIndexedExpression -> throw AssemblyError("source kind should have been array") is ArrayIndexedExpression -> throw AssemblyError("source kind should have been array")
is DirectMemoryRead -> throw AssemblyError("source kind should have been memory") is DirectMemoryRead -> throw AssemblyError("source kind should have been memory")
is TypecastExpression -> assignTypeCastedValue(assign.target, value.type, value.expression, value) is TypecastExpression -> assignTypeCastedValue(assign.target, value.type, value.expression, value)
is FunctionCall -> { is FunctionCallExpr -> {
when (val sub = value.target.targetStatement(program)) { when (val sub = value.target.targetStatement(program)) {
is Subroutine -> { is Subroutine -> {
asmgen.saveXbeforeCall(value) asmgen.saveXbeforeCall(value)
@@ -215,7 +215,7 @@ internal class AssignmentAsmGen(private val program: Program, private val asmgen
} }
} }
} }
is BuiltinFunctionStatementPlaceholder -> { is BuiltinFunctionPlaceholder -> {
val signature = BuiltinFunctions.getValue(sub.name) val signature = BuiltinFunctions.getValue(sub.name)
asmgen.translateBuiltinFunctionCallExpression(value, signature, false, assign.target.register) asmgen.translateBuiltinFunctionCallExpression(value, signature, false, assign.target.register)
if(assign.target.register==null) { if(assign.target.register==null) {
@@ -474,7 +474,7 @@ internal class AssignmentAsmGen(private val program: Program, private val asmgen
} }
private fun assignCastViaLsbFunc(value: Expression, target: AsmAssignTarget) { private fun assignCastViaLsbFunc(value: Expression, target: AsmAssignTarget) {
val lsb = FunctionCall(IdentifierReference(listOf("lsb"), value.position), mutableListOf(value), value.position) val lsb = FunctionCallExpr(IdentifierReference(listOf("lsb"), value.position), mutableListOf(value), value.position)
lsb.linkParents(value.parent) lsb.linkParents(value.parent)
val src = AsmAssignSource(SourceStorageKind.EXPRESSION, program, asmgen, DataType.UBYTE, expression = lsb) val src = AsmAssignSource(SourceStorageKind.EXPRESSION, program, asmgen, DataType.UBYTE, expression = lsb)
val assign = AsmAssignment(src, target, false, program.memsizer, value.position) val assign = AsmAssignment(src, target, false, program.memsizer, value.position)

View File

@@ -303,11 +303,11 @@ class ConstantFoldingOptimizer(private val program: Program) : AstWalker() {
return noModifications return noModifications
} }
override fun after(functionCall: FunctionCall, parent: Node): Iterable<IAstModification> { override fun after(functionCallExpr: FunctionCallExpr, parent: Node): Iterable<IAstModification> {
// the args of a fuction are constfolded via recursion already. // the args of a fuction are constfolded via recursion already.
val constvalue = functionCall.constValue(program) val constvalue = functionCallExpr.constValue(program)
return if(constvalue!=null) return if(constvalue!=null)
listOf(IAstModification.ReplaceNode(functionCall, constvalue, parent)) listOf(IAstModification.ReplaceNode(functionCallExpr, constvalue, parent))
else else
noModifications noModifications
} }

View File

@@ -1,5 +1,6 @@
package prog8.optimizer package prog8.optimizer
import prog8.ast.IStatementContainer
import prog8.ast.Node import prog8.ast.Node
import prog8.ast.Program import prog8.ast.Program
import prog8.ast.base.DataType import prog8.ast.base.DataType
@@ -7,7 +8,10 @@ import prog8.ast.base.FatalAstException
import prog8.ast.base.IntegerDatatypes import prog8.ast.base.IntegerDatatypes
import prog8.ast.base.NumericDatatypes import prog8.ast.base.NumericDatatypes
import prog8.ast.expressions.* import prog8.ast.expressions.*
import prog8.ast.statements.AnonymousScope
import prog8.ast.statements.Assignment import prog8.ast.statements.Assignment
import prog8.ast.statements.IfElse
import prog8.ast.statements.Jump
import prog8.ast.walk.AstWalker import prog8.ast.walk.AstWalker
import prog8.ast.walk.IAstModification import prog8.ast.walk.IAstModification
import kotlin.math.abs import kotlin.math.abs
@@ -54,35 +58,26 @@ class ExpressionSimplifier(private val program: Program) : AstWalker() {
return mods return mods
} }
override fun before(expr: PrefixExpression, parent: Node): Iterable<IAstModification> { override fun after(ifElse: IfElse, parent: Node): Iterable<IAstModification> {
if (expr.operator == "+") { val truepart = ifElse.truepart
// +X --> X val elsepart = ifElse.elsepart
return listOf(IAstModification.ReplaceNode(expr, expr.expression, parent)) if(truepart.isNotEmpty() && elsepart.isNotEmpty()) {
} else if (expr.operator == "not") { if(truepart.statements.singleOrNull() is Jump) {
when(expr.expression) { return listOf(
is PrefixExpression -> { IAstModification.InsertAfter(ifElse, elsepart, parent as IStatementContainer),
// NOT(NOT(...)) -> ... IAstModification.ReplaceNode(elsepart, AnonymousScope(mutableListOf(), elsepart.position), ifElse)
val pe = expr.expression as PrefixExpression )
if(pe.operator == "not")
return listOf(IAstModification.ReplaceNode(expr, pe.expression, parent))
} }
is BinaryExpression -> { if(elsepart.statements.singleOrNull() is Jump) {
// NOT (xxxx) -> invert the xxxx val invertedCondition = invertCondition(ifElse.condition)
val be = expr.expression as BinaryExpression if(invertedCondition!=null) {
val newExpr = when (be.operator) { return listOf(
"<" -> BinaryExpression(be.left, ">=", be.right, be.position) IAstModification.ReplaceNode(ifElse.condition, invertedCondition, ifElse),
">" -> BinaryExpression(be.left, "<=", be.right, be.position) IAstModification.InsertAfter(ifElse, truepart, parent as IStatementContainer),
"<=" -> BinaryExpression(be.left, ">", be.right, be.position) IAstModification.ReplaceNode(elsepart, AnonymousScope(mutableListOf(), elsepart.position), ifElse),
">=" -> BinaryExpression(be.left, "<", be.right, be.position) IAstModification.ReplaceNode(truepart, elsepart, ifElse)
"==" -> BinaryExpression(be.left, "!=", be.right, be.position) )
"!=" -> BinaryExpression(be.left, "==", be.right, be.position)
else -> null
} }
if (newExpr != null)
return listOf(IAstModification.ReplaceNode(expr, newExpr, parent))
}
else -> return noModifications
} }
} }
return noModifications return noModifications
@@ -216,7 +211,7 @@ class ExpressionSimplifier(private val program: Program) : AstWalker() {
// signedw < 0 --> msb(signedw) & $80 // signedw < 0 --> msb(signedw) & $80
return listOf(IAstModification.ReplaceNode( return listOf(IAstModification.ReplaceNode(
expr, expr,
BinaryExpression(FunctionCall(IdentifierReference(listOf("msb"), expr.position), BinaryExpression(FunctionCallExpr(IdentifierReference(listOf("msb"), expr.position),
mutableListOf(expr.left), mutableListOf(expr.left),
expr.position expr.position
), "&", NumericLiteralValue.optimalInteger(0x80, expr.position), expr.position), ), "&", NumericLiteralValue.optimalInteger(0x80, expr.position), expr.position),
@@ -306,31 +301,31 @@ class ExpressionSimplifier(private val program: Program) : AstWalker() {
return noModifications return noModifications
} }
override fun after(functionCall: FunctionCall, parent: Node): Iterable<IAstModification> { override fun after(functionCallExpr: FunctionCallExpr, parent: Node): Iterable<IAstModification> {
if(functionCall.target.nameInSource == listOf("lsb")) { if(functionCallExpr.target.nameInSource == listOf("lsb")) {
val arg = functionCall.args[0] val arg = functionCallExpr.args[0]
if(arg is TypecastExpression) { if(arg is TypecastExpression) {
val valueDt = arg.expression.inferType(program) val valueDt = arg.expression.inferType(program)
if (valueDt istype DataType.BYTE || valueDt istype DataType.UBYTE) { if (valueDt istype DataType.BYTE || valueDt istype DataType.UBYTE) {
// useless lsb() of byte value that was typecasted to word // useless lsb() of byte value that was typecasted to word
return listOf(IAstModification.ReplaceNode(functionCall, arg.expression, parent)) return listOf(IAstModification.ReplaceNode(functionCallExpr, arg.expression, parent))
} }
} else { } else {
val argDt = arg.inferType(program) val argDt = arg.inferType(program)
if (argDt istype DataType.BYTE || argDt istype DataType.UBYTE) { if (argDt istype DataType.BYTE || argDt istype DataType.UBYTE) {
// useless lsb() of byte value // useless lsb() of byte value
return listOf(IAstModification.ReplaceNode(functionCall, arg, parent)) return listOf(IAstModification.ReplaceNode(functionCallExpr, arg, parent))
} }
} }
} }
else if(functionCall.target.nameInSource == listOf("msb")) { else if(functionCallExpr.target.nameInSource == listOf("msb")) {
val arg = functionCall.args[0] val arg = functionCallExpr.args[0]
if(arg is TypecastExpression) { if(arg is TypecastExpression) {
val valueDt = arg.expression.inferType(program) val valueDt = arg.expression.inferType(program)
if (valueDt istype DataType.BYTE || valueDt istype DataType.UBYTE) { if (valueDt istype DataType.BYTE || valueDt istype DataType.UBYTE) {
// useless msb() of byte value that was typecasted to word, replace with 0 // useless msb() of byte value that was typecasted to word, replace with 0
return listOf(IAstModification.ReplaceNode( return listOf(IAstModification.ReplaceNode(
functionCall, functionCallExpr,
NumericLiteralValue(valueDt.getOr(DataType.UBYTE), 0.0, arg.expression.position), NumericLiteralValue(valueDt.getOr(DataType.UBYTE), 0.0, arg.expression.position),
parent)) parent))
} }
@@ -339,7 +334,7 @@ class ExpressionSimplifier(private val program: Program) : AstWalker() {
if (argDt istype DataType.BYTE || argDt istype DataType.UBYTE) { if (argDt istype DataType.BYTE || argDt istype DataType.UBYTE) {
// useless msb() of byte value, replace with 0 // useless msb() of byte value, replace with 0
return listOf(IAstModification.ReplaceNode( return listOf(IAstModification.ReplaceNode(
functionCall, functionCallExpr,
NumericLiteralValue(argDt.getOr(DataType.UBYTE), 0.0, arg.position), NumericLiteralValue(argDt.getOr(DataType.UBYTE), 0.0, arg.position),
parent)) parent))
} }
@@ -460,7 +455,7 @@ class ExpressionSimplifier(private val program: Program) : AstWalker() {
} }
0.5 -> { 0.5 -> {
// sqrt(left) // sqrt(left)
return FunctionCall(IdentifierReference(listOf("sqrt"), expr.position), mutableListOf(expr.left), expr.position) return FunctionCallExpr(IdentifierReference(listOf("sqrt"), expr.position), mutableListOf(expr.left), expr.position)
} }
1.0 -> { 1.0 -> {
// left // left
@@ -653,12 +648,12 @@ class ExpressionSimplifier(private val program: Program) : AstWalker() {
if (amount >= 16) { if (amount >= 16) {
return NumericLiteralValue(targetDt, 0.0, expr.position) return NumericLiteralValue(targetDt, 0.0, expr.position)
} else if (amount >= 8) { } else if (amount >= 8) {
val lsb = FunctionCall(IdentifierReference(listOf("lsb"), expr.position), mutableListOf(expr.left), expr.position) val lsb = FunctionCallExpr(IdentifierReference(listOf("lsb"), expr.position), mutableListOf(expr.left), expr.position)
if (amount == 8) { if (amount == 8) {
return FunctionCall(IdentifierReference(listOf("mkword"), expr.position), mutableListOf(lsb, NumericLiteralValue.optimalInteger(0, expr.position)), expr.position) return FunctionCallExpr(IdentifierReference(listOf("mkword"), expr.position), mutableListOf(lsb, NumericLiteralValue.optimalInteger(0, expr.position)), expr.position)
} }
val shifted = BinaryExpression(lsb, "<<", NumericLiteralValue.optimalInteger(amount - 8, expr.position), expr.position) val shifted = BinaryExpression(lsb, "<<", NumericLiteralValue.optimalInteger(amount - 8, expr.position), expr.position)
return FunctionCall(IdentifierReference(listOf("mkword"), expr.position), mutableListOf(shifted, NumericLiteralValue.optimalInteger(0, expr.position)), expr.position) return FunctionCallExpr(IdentifierReference(listOf("mkword"), expr.position), mutableListOf(shifted, NumericLiteralValue.optimalInteger(0, expr.position)), expr.position)
} }
} }
else -> { else -> {
@@ -695,11 +690,11 @@ class ExpressionSimplifier(private val program: Program) : AstWalker() {
return NumericLiteralValue.optimalInteger(0, expr.position) return NumericLiteralValue.optimalInteger(0, expr.position)
} }
else if (amount >= 8) { else if (amount >= 8) {
val msb = FunctionCall(IdentifierReference(listOf("msb"), expr.position), mutableListOf(expr.left), expr.position) val msb = FunctionCallExpr(IdentifierReference(listOf("msb"), expr.position), mutableListOf(expr.left), expr.position)
if (amount == 8) { if (amount == 8) {
// mkword(0, msb(v)) // mkword(0, msb(v))
val zero = NumericLiteralValue(DataType.UBYTE, 0.0, expr.position) val zero = NumericLiteralValue(DataType.UBYTE, 0.0, expr.position)
return FunctionCall(IdentifierReference(listOf("mkword"), expr.position), mutableListOf(zero, msb), expr.position) return FunctionCallExpr(IdentifierReference(listOf("mkword"), expr.position), mutableListOf(zero, msb), expr.position)
} }
return TypecastExpression(BinaryExpression(msb, ">>", NumericLiteralValue.optimalInteger(amount - 8, expr.position), expr.position), DataType.UWORD, true, expr.position) return TypecastExpression(BinaryExpression(msb, ">>", NumericLiteralValue.optimalInteger(amount - 8, expr.position), expr.position), DataType.UWORD, true, expr.position)
} }

View File

@@ -6,7 +6,6 @@ import prog8.ast.expressions.*
import prog8.ast.statements.* import prog8.ast.statements.*
import prog8.ast.walk.AstWalker import prog8.ast.walk.AstWalker
import prog8.ast.walk.IAstModification import prog8.ast.walk.IAstModification
import prog8.ast.walk.IAstVisitor
import prog8.compilerinterface.ICompilationTarget import prog8.compilerinterface.ICompilationTarget
import prog8.compilerinterface.IErrorReporter import prog8.compilerinterface.IErrorReporter
import prog8.compilerinterface.size import prog8.compilerinterface.size
@@ -19,7 +18,7 @@ class StatementOptimizer(private val program: Program,
private val compTarget: ICompilationTarget private val compTarget: ICompilationTarget
) : AstWalker() { ) : AstWalker() {
override fun before(functionCall: FunctionCall, parent: Node): Iterable<IAstModification> { override fun before(functionCallExpr: FunctionCallExpr, parent: Node): Iterable<IAstModification> {
// if the first instruction in the called subroutine is a return statement with a simple value, // if the first instruction in the called subroutine is a return statement with a simple value,
// remove the jump altogeter and inline the returnvalue directly. // remove the jump altogeter and inline the returnvalue directly.
@@ -28,7 +27,7 @@ class StatementOptimizer(private val program: Program,
return IdentifierReference(target.scopedName, variable.position) return IdentifierReference(target.scopedName, variable.position)
} }
val subroutine = functionCall.target.targetSubroutine(program) val subroutine = functionCallExpr.target.targetSubroutine(program)
if(subroutine!=null) { if(subroutine!=null) {
val first = subroutine.statements.asSequence().filterNot { it is VarDecl || it is Directive }.firstOrNull() val first = subroutine.statements.asSequence().filterNot { it is VarDecl || it is Directive }.firstOrNull()
if(first is Return && first.value?.isSimple==true) { if(first is Return && first.value?.isSimple==true) {
@@ -48,7 +47,7 @@ class StatementOptimizer(private val program: Program,
is StringLiteralValue -> orig.copy() is StringLiteralValue -> orig.copy()
else -> return noModifications else -> return noModifications
} }
return listOf(IAstModification.ReplaceNode(functionCall, copy, parent)) return listOf(IAstModification.ReplaceNode(functionCallExpr, copy, parent))
} }
} }
return noModifications return noModifications
@@ -57,7 +56,7 @@ class StatementOptimizer(private val program: Program,
override fun after(functionCallStatement: FunctionCallStatement, parent: Node): Iterable<IAstModification> { override fun after(functionCallStatement: FunctionCallStatement, parent: Node): Iterable<IAstModification> {
if(functionCallStatement.target.targetStatement(program) is BuiltinFunctionStatementPlaceholder) { if(functionCallStatement.target.targetStatement(program) is BuiltinFunctionPlaceholder) {
val functionName = functionCallStatement.target.nameInSource[0] val functionName = functionCallStatement.target.nameInSource[0]
if (functionName in functions.purefunctionNames) { if (functionName in functions.purefunctionNames) {
errors.warn("statement has no effect (function return value is discarded)", functionCallStatement.position) errors.warn("statement has no effect (function return value is discarded)", functionCallStatement.position)
@@ -134,33 +133,33 @@ class StatementOptimizer(private val program: Program,
return noModifications return noModifications
} }
override fun after(ifStatement: IfStatement, parent: Node): Iterable<IAstModification> { override fun after(ifElse: IfElse, parent: Node): Iterable<IAstModification> {
// remove empty if statements // remove empty if statements
if(ifStatement.truepart.isEmpty() && ifStatement.elsepart.isEmpty()) if(ifElse.truepart.isEmpty() && ifElse.elsepart.isEmpty())
return listOf(IAstModification.Remove(ifStatement, parent as IStatementContainer)) return listOf(IAstModification.Remove(ifElse, parent as IStatementContainer))
// empty true part? switch with the else part // empty true part? switch with the else part
if(ifStatement.truepart.isEmpty() && ifStatement.elsepart.isNotEmpty()) { if(ifElse.truepart.isEmpty() && ifElse.elsepart.isNotEmpty()) {
val invertedCondition = PrefixExpression("not", ifStatement.condition, ifStatement.condition.position) val invertedCondition = PrefixExpression("not", ifElse.condition, ifElse.condition.position)
val emptyscope = AnonymousScope(mutableListOf(), ifStatement.elsepart.position) val emptyscope = AnonymousScope(mutableListOf(), ifElse.elsepart.position)
val truepart = AnonymousScope(ifStatement.elsepart.statements, ifStatement.truepart.position) val truepart = AnonymousScope(ifElse.elsepart.statements, ifElse.truepart.position)
return listOf( return listOf(
IAstModification.ReplaceNode(ifStatement.condition, invertedCondition, ifStatement), IAstModification.ReplaceNode(ifElse.condition, invertedCondition, ifElse),
IAstModification.ReplaceNode(ifStatement.truepart, truepart, ifStatement), IAstModification.ReplaceNode(ifElse.truepart, truepart, ifElse),
IAstModification.ReplaceNode(ifStatement.elsepart, emptyscope, ifStatement) IAstModification.ReplaceNode(ifElse.elsepart, emptyscope, ifElse)
) )
} }
val constvalue = ifStatement.condition.constValue(program) val constvalue = ifElse.condition.constValue(program)
if(constvalue!=null) { if(constvalue!=null) {
return if(constvalue.asBooleanValue){ return if(constvalue.asBooleanValue){
// always true -> keep only if-part // always true -> keep only if-part
errors.warn("condition is always true", ifStatement.position) errors.warn("condition is always true", ifElse.position)
listOf(IAstModification.ReplaceNode(ifStatement, ifStatement.truepart, parent)) listOf(IAstModification.ReplaceNode(ifElse, ifElse.truepart, parent))
} else { } else {
// always false -> keep only else-part // always false -> keep only else-part
errors.warn("condition is always false", ifStatement.position) errors.warn("condition is always false", ifElse.position)
listOf(IAstModification.ReplaceNode(ifStatement, ifStatement.elsepart, parent)) listOf(IAstModification.ReplaceNode(ifElse, ifElse.elsepart, parent))
} }
} }
@@ -228,15 +227,14 @@ class StatementOptimizer(private val program: Program,
override fun before(untilLoop: UntilLoop, parent: Node): Iterable<IAstModification> { override fun before(untilLoop: UntilLoop, parent: Node): Iterable<IAstModification> {
val constvalue = untilLoop.condition.constValue(program) val constvalue = untilLoop.condition.constValue(program)
if(constvalue!=null) { if(constvalue!=null) {
if(constvalue.asBooleanValue) { return if(constvalue.asBooleanValue) {
// always true -> keep only the statement block (if there are no break statements) // always true -> keep only the statement block
errors.warn("condition is always true", untilLoop.condition.position) errors.warn("condition is always true", untilLoop.condition.position)
if(!hasBreak(untilLoop.body)) listOf(IAstModification.ReplaceNode(untilLoop, untilLoop.body, parent))
return listOf(IAstModification.ReplaceNode(untilLoop, untilLoop.body, parent))
} else { } else {
// always false // always false
val forever = RepeatLoop(null, untilLoop.body, untilLoop.position) val forever = RepeatLoop(null, untilLoop.body, untilLoop.position)
return listOf(IAstModification.ReplaceNode(untilLoop, forever, parent)) listOf(IAstModification.ReplaceNode(untilLoop, forever, parent))
} }
} }
return noModifications return noModifications
@@ -279,24 +277,25 @@ class StatementOptimizer(private val program: Program,
} }
override fun after(jump: Jump, parent: Node): Iterable<IAstModification> { override fun after(jump: Jump, parent: Node): Iterable<IAstModification> {
if(jump.isGosub) {
// if the next statement is return with no returnvalue, change into a regular jump if there are no parameters as well.
val subroutineParams = jump.identifier?.targetSubroutine(program)?.parameters
if(subroutineParams!=null && subroutineParams.isEmpty()) {
val returnstmt = jump.nextSibling() as? Return
if(returnstmt!=null && returnstmt.value==null) {
return listOf(
IAstModification.Remove(returnstmt, parent as IStatementContainer),
IAstModification.ReplaceNode(jump, Jump(jump.address, jump.identifier, jump.generatedLabel, jump.position), parent)
)
}
}
} else {
// if the jump is to the next statement, remove the jump // if the jump is to the next statement, remove the jump
val scope = jump.parent as IStatementContainer val scope = jump.parent as IStatementContainer
val label = jump.identifier?.targetStatement(program) val label = jump.identifier?.targetStatement(program)
if (label != null && scope.statements.indexOf(label) == scope.statements.indexOf(jump) + 1) if (label != null && scope.statements.indexOf(label) == scope.statements.indexOf(jump) + 1)
return listOf(IAstModification.Remove(jump, scope)) return listOf(IAstModification.Remove(jump, scope))
return noModifications
}
override fun after(gosub: GoSub, parent: Node): Iterable<IAstModification> {
// if the next statement is return with no returnvalue, change into a regular jump if there are no parameters as well.
val subroutineParams = gosub.identifier?.targetSubroutine(program)?.parameters
if(subroutineParams!=null && subroutineParams.isEmpty()) {
val returnstmt = gosub.nextSibling() as? Return
if(returnstmt!=null && returnstmt.value==null) {
return listOf(
IAstModification.Remove(returnstmt, parent as IStatementContainer),
IAstModification.ReplaceNode(gosub, Jump(gosub.address, gosub.identifier, gosub.generatedLabel, gosub.position), parent)
)
}
} }
return noModifications return noModifications
} }
@@ -474,25 +473,4 @@ class StatementOptimizer(private val program: Program,
return noModifications return noModifications
} }
private fun hasBreak(scope: IStatementContainer): Boolean {
class Searcher: IAstVisitor
{
var count=0
override fun visit(breakStmt: Break) {
count++
}
}
val s=Searcher()
for(stmt in scope.statements) {
stmt.accept(s)
if(s.count>0)
return true
}
return s.count > 0
}
} }

View File

@@ -33,7 +33,6 @@ class UnusedCodeRemover(private val program: Program,
} }
override fun before(jump: Jump, parent: Node): Iterable<IAstModification> { override fun before(jump: Jump, parent: Node): Iterable<IAstModification> {
if(!jump.isGosub)
reportUnreachable(jump) reportUnreachable(jump)
return emptyList() return emptyList()
} }
@@ -235,7 +234,7 @@ class UnusedCodeRemover(private val program: Program,
is PrefixExpression, is PrefixExpression,
is BinaryExpression, is BinaryExpression,
is TypecastExpression, is TypecastExpression,
is FunctionCall -> { /* don't remove */ } is FunctionCallExpr -> { /* don't remove */ }
else -> linesToRemove.add(assign1) else -> linesToRemove.add(assign1)
} }
} }

View File

@@ -31,6 +31,18 @@ internal class BeforeAsmGenerationAstChanger(val program: Program, private val o
varsList.add(decl.name to decl) varsList.add(decl.name to decl)
} }
override fun before(breakStmt: Break, parent: Node): Iterable<IAstModification> {
throw FatalAstException("break should have been replaced by goto $breakStmt")
}
override fun before(whileLoop: WhileLoop, parent: Node): Iterable<IAstModification> {
throw FatalAstException("while should have been desugared to jumps")
}
override fun before(untilLoop: UntilLoop, parent: Node): Iterable<IAstModification> {
throw FatalAstException("do..until should have been desugared to jumps")
}
override fun before(block: Block, parent: Node): Iterable<IAstModification> { override fun before(block: Block, parent: Node): Iterable<IAstModification> {
// move all subroutines to the bottom of the block // move all subroutines to the bottom of the block
val subs = block.statements.filterIsInstance<Subroutine>() val subs = block.statements.filterIsInstance<Subroutine>()
@@ -127,7 +139,7 @@ internal class BeforeAsmGenerationAstChanger(val program: Program, private val o
if (subroutineStmtIdx > 0) { if (subroutineStmtIdx > 0) {
val prevStmt = outerStatements[subroutineStmtIdx-1] val prevStmt = outerStatements[subroutineStmtIdx-1]
if(outerScope !is Block if(outerScope !is Block
&& (prevStmt !is Jump || prevStmt.isGosub) && (prevStmt !is Jump)
&& prevStmt !is Subroutine && prevStmt !is Subroutine
&& prevStmt !is Return) { && prevStmt !is Return) {
mods += IAstModification.InsertAfter(outerStatements[subroutineStmtIdx - 1], returnStmt, outerScope) mods += IAstModification.InsertAfter(outerStatements[subroutineStmtIdx - 1], returnStmt, outerScope)
@@ -181,20 +193,19 @@ internal class BeforeAsmGenerationAstChanger(val program: Program, private val o
return noModifications return noModifications
} }
@Suppress("DuplicatedCode") override fun after(ifElse: IfElse, parent: Node): Iterable<IAstModification> {
override fun after(ifStatement: IfStatement, parent: Node): Iterable<IAstModification> { val prefixExpr = ifElse.condition as? PrefixExpression
val prefixExpr = ifStatement.condition as? PrefixExpression
if(prefixExpr!=null && prefixExpr.operator=="not") { if(prefixExpr!=null && prefixExpr.operator=="not") {
// if not x -> if x==0 // if not x -> if x==0
val booleanExpr = BinaryExpression(prefixExpr.expression, "==", NumericLiteralValue.optimalInteger(0, ifStatement.condition.position), ifStatement.condition.position) val booleanExpr = BinaryExpression(prefixExpr.expression, "==", NumericLiteralValue.optimalInteger(0, ifElse.condition.position), ifElse.condition.position)
return listOf(IAstModification.ReplaceNode(ifStatement.condition, booleanExpr, ifStatement)) return listOf(IAstModification.ReplaceNode(ifElse.condition, booleanExpr, ifElse))
} }
val binExpr = ifStatement.condition as? BinaryExpression val binExpr = ifElse.condition as? BinaryExpression
if(binExpr==null || binExpr.operator !in ComparisonOperators) { if(binExpr==null || binExpr.operator !in ComparisonOperators) {
// if x -> if x!=0, if x+5 -> if x+5 != 0 // if x -> if x!=0, if x+5 -> if x+5 != 0
val booleanExpr = BinaryExpression(ifStatement.condition, "!=", NumericLiteralValue.optimalInteger(0, ifStatement.condition.position), ifStatement.condition.position) val booleanExpr = BinaryExpression(ifElse.condition, "!=", NumericLiteralValue.optimalInteger(0, ifElse.condition.position), ifElse.condition.position)
return listOf(IAstModification.ReplaceNode(ifStatement.condition, booleanExpr, ifStatement)) return listOf(IAstModification.ReplaceNode(ifElse.condition, booleanExpr, ifElse))
} }
if((binExpr.left as? NumericLiteralValue)?.number==0.0 && if((binExpr.left as? NumericLiteralValue)?.number==0.0 &&
@@ -208,11 +219,11 @@ internal class BeforeAsmGenerationAstChanger(val program: Program, private val o
val modifications = mutableListOf<IAstModification>() val modifications = mutableListOf<IAstModification>()
if(simplify.rightVarAssignment!=null) { if(simplify.rightVarAssignment!=null) {
modifications += IAstModification.ReplaceNode(binExpr.right, simplify.rightOperandReplacement!!, binExpr) modifications += IAstModification.ReplaceNode(binExpr.right, simplify.rightOperandReplacement!!, binExpr)
modifications += IAstModification.InsertBefore(ifStatement, simplify.rightVarAssignment, parent as IStatementContainer) modifications += IAstModification.InsertBefore(ifElse, simplify.rightVarAssignment, parent as IStatementContainer)
} }
if(simplify.leftVarAssignment!=null) { if(simplify.leftVarAssignment!=null) {
modifications += IAstModification.ReplaceNode(binExpr.left, simplify.leftOperandReplacement!!, binExpr) modifications += IAstModification.ReplaceNode(binExpr.left, simplify.leftOperandReplacement!!, binExpr)
modifications += IAstModification.InsertBefore(ifStatement, simplify.leftVarAssignment, parent as IStatementContainer) modifications += IAstModification.InsertBefore(ifElse, simplify.leftVarAssignment, parent as IStatementContainer)
} }
return modifications return modifications
@@ -278,74 +289,6 @@ internal class BeforeAsmGenerationAstChanger(val program: Program, private val o
) )
} }
@Suppress("DuplicatedCode")
override fun after(untilLoop: UntilLoop, parent: Node): Iterable<IAstModification> {
val prefixExpr = untilLoop.condition as? PrefixExpression
if(prefixExpr!=null && prefixExpr.operator=="not") {
// until not x -> until x==0
val booleanExpr = BinaryExpression(prefixExpr.expression, "==", NumericLiteralValue.optimalInteger(0, untilLoop.condition.position), untilLoop.condition.position)
return listOf(IAstModification.ReplaceNode(untilLoop.condition, booleanExpr, untilLoop))
}
val binExpr = untilLoop.condition as? BinaryExpression
if(binExpr==null || binExpr.operator !in ComparisonOperators) {
// until x -> until x!=0, until x+5 -> until x+5 != 0
val booleanExpr = BinaryExpression(untilLoop.condition, "!=", NumericLiteralValue.optimalInteger(0, untilLoop.condition.position), untilLoop.condition.position)
return listOf(IAstModification.ReplaceNode(untilLoop.condition, booleanExpr, untilLoop))
}
if((binExpr.left as? NumericLiteralValue)?.number==0.0 &&
(binExpr.right as? NumericLiteralValue)?.number!=0.0)
throw FatalAstException("0==X should have been swapped to if X==0")
// simplify the conditional expression, introduce simple assignments if required.
// NOTE: sometimes this increases code size because additional stores/loads are generated for the
// intermediate variables. We assume these are optimized away from the resulting assembly code later.
val simplify = simplifyConditionalExpression(binExpr)
val modifications = mutableListOf<IAstModification>()
if(simplify.rightVarAssignment!=null) {
modifications += IAstModification.ReplaceNode(binExpr.right, simplify.rightOperandReplacement!!, binExpr)
modifications += IAstModification.InsertLast(simplify.rightVarAssignment, untilLoop.body)
}
if(simplify.leftVarAssignment!=null) {
modifications += IAstModification.ReplaceNode(binExpr.left, simplify.leftOperandReplacement!!, binExpr)
modifications += IAstModification.InsertLast(simplify.leftVarAssignment, untilLoop.body)
}
return modifications
}
@Suppress("DuplicatedCode")
override fun after(whileLoop: WhileLoop, parent: Node): Iterable<IAstModification> {
val prefixExpr = whileLoop.condition as? PrefixExpression
if(prefixExpr!=null && prefixExpr.operator=="not") {
// while not x -> while x==0
val booleanExpr = BinaryExpression(prefixExpr.expression, "==", NumericLiteralValue.optimalInteger(0, whileLoop.condition.position), whileLoop.condition.position)
return listOf(IAstModification.ReplaceNode(whileLoop.condition, booleanExpr, whileLoop))
}
val binExpr = whileLoop.condition as? BinaryExpression
if(binExpr==null || binExpr.operator !in ComparisonOperators) {
// while x -> while x!=0, while x+5 -> while x+5 != 0
val booleanExpr = BinaryExpression(whileLoop.condition, "!=", NumericLiteralValue.optimalInteger(0, whileLoop.condition.position), whileLoop.condition.position)
return listOf(IAstModification.ReplaceNode(whileLoop.condition, booleanExpr, whileLoop))
}
if((binExpr.left as? NumericLiteralValue)?.number==0.0 &&
(binExpr.right as? NumericLiteralValue)?.number!=0.0)
throw FatalAstException("0==X should have been swapped to if X==0")
// TODO simplify the conditional expression, introduce simple assignments if required.
// make sure to evaluate it only once, but also right at entry of the while loop
// NOTE: sometimes this increases code size because additional stores/loads are generated for the
// intermediate variables. We assume these are optimized away from the resulting assembly code later.
// NOTE: this is nasty for a while-statement as the condition occurs at the top of the loop
// so the expression needs to be evaluated also before the loop is entered...
// but I don't want to duplicate the expression.
// val simplify = simplifyConditionalExpression(binExpr)
return noModifications
}
override fun after(functionCallStatement: FunctionCallStatement, parent: Node): Iterable<IAstModification> { override fun after(functionCallStatement: FunctionCallStatement, parent: Node): Iterable<IAstModification> {
if(functionCallStatement.target.nameInSource==listOf("cmp")) { if(functionCallStatement.target.nameInSource==listOf("cmp")) {
// if the datatype of the arguments of cmp() are different, cast the byte one to word. // if the datatype of the arguments of cmp() are different, cast the byte one to word.
@@ -398,12 +341,12 @@ internal class BeforeAsmGenerationAstChanger(val program: Program, private val o
complexArrayIndexedExpressions.add(arrayIndexedExpression) complexArrayIndexedExpressions.add(arrayIndexedExpression)
} }
override fun visit(branchStatement: BranchStatement) {} override fun visit(branch: Branch) {}
override fun visit(forLoop: ForLoop) {} override fun visit(forLoop: ForLoop) {}
override fun visit(ifStatement: IfStatement) { override fun visit(ifElse: IfElse) {
ifStatement.condition.accept(this) ifElse.condition.accept(this)
} }
override fun visit(untilLoop: UntilLoop) { override fun visit(untilLoop: UntilLoop) {

View File

@@ -267,6 +267,8 @@ private fun processAst(program: Program, errors: IErrorReporter, compilerOptions
program.charLiteralsToUByteLiterals(compilerOptions.compTarget) program.charLiteralsToUByteLiterals(compilerOptions.compTarget)
program.constantFold(errors, compilerOptions.compTarget) program.constantFold(errors, compilerOptions.compTarget)
errors.report() errors.report()
program.desugaring(errors)
errors.report()
program.reorderStatements(errors, compilerOptions) program.reorderStatements(errors, compilerOptions)
errors.report() errors.report()
program.addTypecasts(errors, compilerOptions) program.addTypecasts(errors, compilerOptions)
@@ -297,21 +299,21 @@ private fun optimizeAst(program: Program, compilerOptions: CompilationOptions, e
if (optsDone1 + optsDone2 + optsDone3 == 0) if (optsDone1 + optsDone2 + optsDone3 == 0)
break break
} }
errors.report() errors.report()
} }
private fun postprocessAst(program: Program, errors: IErrorReporter, compilerOptions: CompilationOptions) { private fun postprocessAst(program: Program, errors: IErrorReporter, compilerOptions: CompilationOptions) {
program.desugaring(errors)
program.addTypecasts(errors, compilerOptions) program.addTypecasts(errors, compilerOptions)
errors.report() errors.report()
program.variousCleanups(program, errors) program.variousCleanups(program, errors)
program.checkValid(errors, compilerOptions) // check if final tree is still valid
errors.report()
val callGraph = CallGraph(program) val callGraph = CallGraph(program)
callGraph.checkRecursiveCalls(errors) callGraph.checkRecursiveCalls(errors)
errors.report() errors.report()
program.verifyFunctionArgTypes() program.verifyFunctionArgTypes()
program.moveMainAndStartToFirst() program.moveMainAndStartToFirst()
program.checkValid(errors, compilerOptions) // check if final tree is still valid
errors.report()
} }
private sealed class WriteAssemblyResult { private sealed class WriteAssemblyResult {

View File

@@ -80,10 +80,10 @@ internal class AstChecker(private val program: Program,
super.visit(returnStmt) super.visit(returnStmt)
} }
override fun visit(ifStatement: IfStatement) { override fun visit(ifElse: IfElse) {
if(!ifStatement.condition.inferType(program).isInteger) if(!ifElse.condition.inferType(program).isInteger)
errors.err("condition value should be an integer type", ifStatement.condition.position) errors.err("condition value should be an integer type", ifElse.condition.position)
super.visit(ifStatement) super.visit(ifElse)
} }
override fun visit(forLoop: ForLoop) { override fun visit(forLoop: ForLoop) {
@@ -165,10 +165,10 @@ internal class AstChecker(private val program: Program,
if(ident!=null) { if(ident!=null) {
val targetStatement = checkFunctionOrLabelExists(ident, jump) val targetStatement = checkFunctionOrLabelExists(ident, jump)
if(targetStatement!=null) { if(targetStatement!=null) {
if(targetStatement is BuiltinFunctionStatementPlaceholder) if(targetStatement is BuiltinFunctionPlaceholder)
errors.err("can't jump to a builtin function", jump.position) errors.err("can't jump to a builtin function", jump.position)
} }
if(!jump.isGosub && targetStatement is Subroutine && targetStatement.parameters.any()) { if(targetStatement is Subroutine && targetStatement.parameters.any()) {
errors.err("can't jump to a subroutine that takes parameters", jump.position) errors.err("can't jump to a subroutine that takes parameters", jump.position)
} }
} }
@@ -179,6 +179,22 @@ internal class AstChecker(private val program: Program,
super.visit(jump) super.visit(jump)
} }
override fun visit(gosub: GoSub) {
val ident = gosub.identifier
if(ident!=null) {
val targetStatement = checkFunctionOrLabelExists(ident, gosub)
if(targetStatement!=null) {
if(targetStatement is BuiltinFunctionPlaceholder)
errors.err("can't gosub to a builtin function", gosub.position)
}
}
val addr = gosub.address
if(addr!=null && addr > 65535u)
errors.err("gosub address must be valid integer 0..\$ffff", gosub.position)
super.visit(gosub)
}
override fun visit(block: Block) { override fun visit(block: Block) {
val addr = block.address val addr = block.address
if(addr!=null && addr>65535u) { if(addr!=null && addr>65535u) {
@@ -193,7 +209,7 @@ internal class AstChecker(private val program: Program,
is VarDecl, is VarDecl,
is InlineAssembly, is InlineAssembly,
is IStatementContainer, is IStatementContainer,
is NopStatement -> true is Nop -> true
is Assignment -> { is Assignment -> {
val target = statement.target.identifier!!.targetStatement(program) val target = statement.target.identifier!!.targetStatement(program)
target === statement.previousSibling() // an initializer assignment is okay target === statement.previousSibling() // an initializer assignment is okay
@@ -226,7 +242,6 @@ internal class AstChecker(private val program: Program,
count++ count++
} }
override fun visit(jump: Jump) { override fun visit(jump: Jump) {
if(!jump.isGosub)
count++ count++
} }
@@ -488,7 +503,7 @@ internal class AstChecker(private val program: Program,
} else { } else {
val sourceDatatype = assignment.value.inferType(program) val sourceDatatype = assignment.value.inferType(program)
if (sourceDatatype.isUnknown) { if (sourceDatatype.isUnknown) {
if (assignment.value !is FunctionCall) if (assignment.value !is FunctionCallExpr)
errors.err("assignment value is invalid or has no proper datatype", assignment.value.position) errors.err("assignment value is invalid or has no proper datatype", assignment.value.position)
} else { } else {
checkAssignmentCompatible(targetDatatype.getOr(DataType.BYTE), checkAssignmentCompatible(targetDatatype.getOr(DataType.BYTE),
@@ -920,28 +935,28 @@ internal class AstChecker(private val program: Program,
} }
} }
override fun visit(functionCall: FunctionCall) { override fun visit(functionCallExpr: FunctionCallExpr) {
// this function call is (part of) an expression, which should be in a statement somewhere. // this function call is (part of) an expression, which should be in a statement somewhere.
val stmtOfExpression = findParentNode<Statement>(functionCall) val stmtOfExpression = findParentNode<Statement>(functionCallExpr)
?: throw FatalAstException("cannot determine statement scope of function call expression at ${functionCall.position}") ?: throw FatalAstException("cannot determine statement scope of function call expression at ${functionCallExpr.position}")
val targetStatement = checkFunctionOrLabelExists(functionCall.target, stmtOfExpression) val targetStatement = checkFunctionOrLabelExists(functionCallExpr.target, stmtOfExpression)
if(targetStatement!=null) if(targetStatement!=null)
checkFunctionCall(targetStatement, functionCall.args, functionCall.position) checkFunctionCall(targetStatement, functionCallExpr.args, functionCallExpr.position)
// warn about sgn(unsigned) this is likely a mistake // warn about sgn(unsigned) this is likely a mistake
if(functionCall.target.nameInSource.last()=="sgn") { if(functionCallExpr.target.nameInSource.last()=="sgn") {
val sgnArgType = functionCall.args.first().inferType(program) val sgnArgType = functionCallExpr.args.first().inferType(program)
if(sgnArgType istype DataType.UBYTE || sgnArgType istype DataType.UWORD) if(sgnArgType istype DataType.UBYTE || sgnArgType istype DataType.UWORD)
errors.warn("sgn() of unsigned type is always 0 or 1, this is perhaps not what was intended", functionCall.args.first().position) errors.warn("sgn() of unsigned type is always 0 or 1, this is perhaps not what was intended", functionCallExpr.args.first().position)
} }
val error = VerifyFunctionArgTypes.checkTypes(functionCall, program) val error = VerifyFunctionArgTypes.checkTypes(functionCallExpr, program)
if(error!=null) if(error!=null)
errors.err(error, functionCall.position) errors.err(error, functionCallExpr.position)
// check the functions that return multiple returnvalues. // check the functions that return multiple returnvalues.
val stmt = functionCall.target.targetStatement(program) val stmt = functionCallExpr.target.targetStatement(program)
if (stmt is Subroutine) { if (stmt is Subroutine) {
if (stmt.returntypes.size > 1) { if (stmt.returntypes.size > 1) {
// Currently, it's only possible to handle ONE (or zero) return values from a subroutine. // Currently, it's only possible to handle ONE (or zero) return values from a subroutine.
@@ -954,7 +969,7 @@ internal class AstChecker(private val program: Program,
// dealing with the status bit as just being that, the status bit after the call. // dealing with the status bit as just being that, the status bit after the call.
val (returnRegisters, _) = stmt.asmReturnvaluesRegisters.partition { rr -> rr.registerOrPair != null } val (returnRegisters, _) = stmt.asmReturnvaluesRegisters.partition { rr -> rr.registerOrPair != null }
if (returnRegisters.size>1) { if (returnRegisters.size>1) {
errors.err("It's not possible to store the multiple result values of this asmsub call; you should use a small block of custom inline assembly for this.", functionCall.position) errors.err("It's not possible to store the multiple result values of this asmsub call; you should use a small block of custom inline assembly for this.", functionCallExpr.position)
} }
} }
} }
@@ -962,18 +977,18 @@ internal class AstChecker(private val program: Program,
// functions that don't return a value, can't be used in an expression or assignment // functions that don't return a value, can't be used in an expression or assignment
if(targetStatement is Subroutine) { if(targetStatement is Subroutine) {
if(targetStatement.returntypes.isEmpty()) { if(targetStatement.returntypes.isEmpty()) {
if(functionCall.parent is Expression || functionCall.parent is Assignment) if(functionCallExpr.parent is Expression || functionCallExpr.parent is Assignment)
errors.err("subroutine doesn't return a value", functionCall.position) errors.err("subroutine doesn't return a value", functionCallExpr.position)
} }
} }
else if(targetStatement is BuiltinFunctionStatementPlaceholder) { else if(targetStatement is BuiltinFunctionPlaceholder) {
if(builtinFunctionReturnType(targetStatement.name, functionCall.args, program).isUnknown) { if(builtinFunctionReturnType(targetStatement.name, functionCallExpr.args, program).isUnknown) {
if(functionCall.parent is Expression || functionCall.parent is Assignment) if(functionCallExpr.parent is Expression || functionCallExpr.parent is Assignment)
errors.err("function doesn't return a value", functionCall.position) errors.err("function doesn't return a value", functionCallExpr.position)
} }
} }
super.visit(functionCall) super.visit(functionCallExpr)
} }
override fun visit(functionCallStatement: FunctionCallStatement) { override fun visit(functionCallStatement: FunctionCallStatement) {
@@ -1029,7 +1044,7 @@ internal class AstChecker(private val program: Program,
if(target is Label && args.isNotEmpty()) if(target is Label && args.isNotEmpty())
errors.err("cannot use arguments when calling a label", position) errors.err("cannot use arguments when calling a label", position)
if(target is BuiltinFunctionStatementPlaceholder) { if(target is BuiltinFunctionPlaceholder) {
if(target.name=="swap") { if(target.name=="swap") {
// swap() is a bit weird because this one is translated into an operations directly, instead of being a function call // swap() is a bit weird because this one is translated into an operations directly, instead of being a function call
val dt1 = args[0].inferType(program) val dt1 = args[0].inferType(program)
@@ -1065,8 +1080,8 @@ internal class AstChecker(private val program: Program,
var ident: IdentifierReference? = null var ident: IdentifierReference? = null
if(arg.value is IdentifierReference) if(arg.value is IdentifierReference)
ident = arg.value as IdentifierReference ident = arg.value as IdentifierReference
else if(arg.value is FunctionCall) { else if(arg.value is FunctionCallExpr) {
val fcall = arg.value as FunctionCall val fcall = arg.value as FunctionCallExpr
if(fcall.target.nameInSource == listOf("lsb") || fcall.target.nameInSource == listOf("msb")) if(fcall.target.nameInSource == listOf("lsb") || fcall.target.nameInSource == listOf("msb"))
ident = fcall.args[0] as? IdentifierReference ident = fcall.args[0] as? IdentifierReference
} }
@@ -1149,11 +1164,11 @@ internal class AstChecker(private val program: Program,
super.visit(arrayIndexedExpression) super.visit(arrayIndexedExpression)
} }
override fun visit(whenStatement: WhenStatement) { override fun visit(whenStmt: When) {
if(!whenStatement.condition.inferType(program).isInteger) if(!whenStmt.condition.inferType(program).isInteger)
errors.err("when condition must be an integer value", whenStatement.position) errors.err("when condition must be an integer value", whenStmt.position)
val tally = mutableSetOf<Int>() val tally = mutableSetOf<Int>()
for((choices, choiceNode) in whenStatement.choiceValues(program)) { for((choices, choiceNode) in whenStmt.choiceValues(program)) {
if(choices!=null) { if(choices!=null) {
for (c in choices) { for (c in choices) {
if(c in tally) if(c in tally)
@@ -1164,14 +1179,14 @@ internal class AstChecker(private val program: Program,
} }
} }
if(whenStatement.choices.isEmpty()) if(whenStmt.choices.isEmpty())
errors.err("empty when statement", whenStatement.position) errors.err("empty when statement", whenStmt.position)
super.visit(whenStatement) super.visit(whenStmt)
} }
override fun visit(whenChoice: WhenChoice) { override fun visit(whenChoice: WhenChoice) {
val whenStmt = whenChoice.parent as WhenStatement val whenStmt = whenChoice.parent as When
if(whenChoice.values!=null) { if(whenChoice.values!=null) {
val conditionType = whenStmt.condition.inferType(program) val conditionType = whenStmt.condition.inferType(program)
if(!conditionType.isKnown) if(!conditionType.isKnown)
@@ -1193,7 +1208,7 @@ internal class AstChecker(private val program: Program,
private fun checkFunctionOrLabelExists(target: IdentifierReference, statement: Statement): Statement? { private fun checkFunctionOrLabelExists(target: IdentifierReference, statement: Statement): Statement? {
when (val targetStatement = target.targetStatement(program)) { when (val targetStatement = target.targetStatement(program)) {
is Label, is Subroutine, is BuiltinFunctionStatementPlaceholder -> return targetStatement is Label, is Subroutine, is BuiltinFunctionPlaceholder -> return targetStatement
null -> errors.err("undefined function or subroutine: ${target.nameInSource.joinToString(".")}", statement.position) null -> errors.err("undefined function or subroutine: ${target.nameInSource.joinToString(".")}", statement.position)
else -> errors.err("cannot call that: ${target.nameInSource.joinToString(".")}", statement.position) else -> errors.err("cannot call that: ${target.nameInSource.joinToString(".")}", statement.position)
} }
@@ -1417,7 +1432,7 @@ internal fun checkUnusedReturnValues(call: FunctionCallStatement, target: Statem
errors.warn("result value of subroutine call is discarded (use void?)", call.position) errors.warn("result value of subroutine call is discarded (use void?)", call.position)
else else
errors.warn("result values of subroutine call are discarded (use void?)", call.position) errors.warn("result values of subroutine call are discarded (use void?)", call.position)
} else if (target is BuiltinFunctionStatementPlaceholder) { } else if (target is BuiltinFunctionPlaceholder) {
val rt = builtinFunctionReturnType(target.name, call.args, program) val rt = builtinFunctionReturnType(target.name, call.args, program)
if (rt.isKnown) if (rt.isKnown)
errors.warn("result value of a function call is discarded (use void?)", call.position) errors.warn("result value of a function call is discarded (use void?)", call.position)

View File

@@ -16,6 +16,8 @@ import prog8.compilerinterface.IStringEncoding
internal fun Program.checkValid(errors: IErrorReporter, compilerOptions: CompilationOptions) { internal fun Program.checkValid(errors: IErrorReporter, compilerOptions: CompilationOptions) {
val parentChecker = ParentNodeChecker()
parentChecker.visit(this)
val checker = AstChecker(this, errors, compilerOptions) val checker = AstChecker(this, errors, compilerOptions)
checker.visit(this) checker.visit(this)
} }
@@ -59,6 +61,12 @@ internal fun Program.addTypecasts(errors: IErrorReporter, options: CompilationOp
caster.applyModifications() caster.applyModifications()
} }
fun Program.desugaring(errors: IErrorReporter): Int {
val desugar = CodeDesugarer(this, errors)
desugar.visit(this)
return desugar.applyModifications()
}
internal fun Program.verifyFunctionArgTypes() { internal fun Program.verifyFunctionArgTypes() {
val fixer = VerifyFunctionArgTypes(this) val fixer = VerifyFunctionArgTypes(this)
fixer.visit(this) fixer.visit(this)

View File

@@ -3,9 +3,8 @@ package prog8.compiler.astprocessing
import prog8.ast.IFunctionCall import prog8.ast.IFunctionCall
import prog8.ast.Node import prog8.ast.Node
import prog8.ast.Program import prog8.ast.Program
import prog8.ast.base.FatalAstException
import prog8.ast.base.Position import prog8.ast.base.Position
import prog8.ast.expressions.FunctionCall import prog8.ast.expressions.FunctionCallExpr
import prog8.ast.expressions.StringLiteralValue import prog8.ast.expressions.StringLiteralValue
import prog8.ast.statements.* import prog8.ast.statements.*
import prog8.ast.walk.IAstVisitor import prog8.ast.walk.IAstVisitor
@@ -129,7 +128,7 @@ internal class AstIdentifiersChecker(private val errors: IErrorReporter,
super.visit(string) super.visit(string)
} }
override fun visit(functionCall: FunctionCall) = visitFunctionCall(functionCall) override fun visit(functionCallExpr: FunctionCallExpr) = visitFunctionCall(functionCallExpr)
override fun visit(functionCallStatement: FunctionCallStatement) = visitFunctionCall(functionCallStatement) override fun visit(functionCallStatement: FunctionCallStatement) = visitFunctionCall(functionCallStatement)
private fun visitFunctionCall(call: IFunctionCall) { private fun visitFunctionCall(call: IFunctionCall) {
@@ -140,7 +139,7 @@ internal class AstIdentifiersChecker(private val errors: IErrorReporter,
errors.err("invalid number of arguments", pos) errors.err("invalid number of arguments", pos)
} }
} }
is BuiltinFunctionStatementPlaceholder -> { is BuiltinFunctionPlaceholder -> {
val func = BuiltinFunctions.getValue(target.name) val func = BuiltinFunctions.getValue(target.name)
if(call.args.size != func.parameters.size) { if(call.args.size != func.parameters.size) {
val pos = (if(call.args.any()) call.args[0] else (call as Node)).position val pos = (if(call.args.any()) call.args[0] else (call as Node)).position

View File

@@ -0,0 +1,138 @@
package prog8.compiler.astprocessing
import prog8.ast.IFunctionCall
import prog8.ast.IStatementContainer
import prog8.ast.Node
import prog8.ast.Program
import prog8.ast.base.ParentSentinel
import prog8.ast.base.Position
import prog8.ast.expressions.DirectMemoryRead
import prog8.ast.expressions.FunctionCallExpr
import prog8.ast.expressions.IdentifierReference
import prog8.ast.expressions.PrefixExpression
import prog8.ast.statements.*
import prog8.ast.walk.AstWalker
import prog8.ast.walk.IAstModification
import prog8.compilerinterface.*
internal class CodeDesugarer(val program: Program, private val errors: IErrorReporter) : AstWalker() {
// Some more code shuffling to simplify the Ast that the codegenerator has to process.
// Several changes have already been done by the StatementReorderer !
// But the ones here are simpler and are repeated once again after all optimization steps
// have been performed (because those could re-introduce nodes that have to be desugared)
//
// List of modifications:
// - replace 'break' statements by a goto + generated after label.
private var generatedLabelSequenceNumber: Int = 0
private val generatedLabelPrefix = "prog8_label_"
private fun makeLabel(postfix: String, position: Position): Label {
generatedLabelSequenceNumber++
return Label("${generatedLabelPrefix}${generatedLabelSequenceNumber}_$postfix", position)
}
private fun jumpLabel(label: Label): Jump {
val ident = IdentifierReference(listOf(label.name), label.position)
return Jump(null, ident, null, label.position)
}
override fun before(breakStmt: Break, parent: Node): Iterable<IAstModification> {
fun jumpAfter(stmt: Statement): Iterable<IAstModification> {
val label = makeLabel("after", breakStmt.position)
return listOf(
IAstModification.ReplaceNode(breakStmt, jumpLabel(label), parent),
IAstModification.InsertAfter(stmt, label, stmt.parent as IStatementContainer)
)
}
var partof = parent
while(true) {
when (partof) {
is Subroutine, is Block, is ParentSentinel -> {
errors.err("break in wrong scope", breakStmt.position)
return noModifications
}
is ForLoop,
is RepeatLoop,
is UntilLoop,
is WhileLoop -> return jumpAfter(partof as Statement)
else -> partof = partof.parent
}
}
}
override fun after(untilLoop: UntilLoop, parent: Node): Iterable<IAstModification> {
/*
do { STUFF } until CONDITION
===>
_loop:
STUFF
if not CONDITION
goto _loop
*/
val pos = untilLoop.position
val loopLabel = makeLabel("untilloop", pos)
val notCondition = PrefixExpression("not", untilLoop.condition, pos)
val replacement = AnonymousScope(mutableListOf(
loopLabel,
untilLoop.body,
IfElse(notCondition,
AnonymousScope(mutableListOf(jumpLabel(loopLabel)), pos),
AnonymousScope(mutableListOf(), pos),
pos)
), pos)
return listOf(IAstModification.ReplaceNode(untilLoop, replacement, parent))
}
override fun after(whileLoop: WhileLoop, parent: Node): Iterable<IAstModification> {
/*
while CONDITION { STUFF }
==>
_whileloop:
if NOT CONDITION goto _after
STUFF
goto _whileloop
_after:
*/
val pos = whileLoop.position
val loopLabel = makeLabel("whileloop", pos)
val afterLabel = makeLabel("afterwhile", pos)
val notCondition = PrefixExpression("not", whileLoop.condition, pos)
val replacement = AnonymousScope(mutableListOf(
loopLabel,
IfElse(notCondition,
AnonymousScope(mutableListOf(jumpLabel(afterLabel)), pos),
AnonymousScope(mutableListOf(), pos),
pos),
whileLoop.body,
jumpLabel(loopLabel),
afterLabel
), pos)
return listOf(IAstModification.ReplaceNode(whileLoop, replacement, parent))
}
override fun before(functionCallStatement: FunctionCallStatement, parent: Node) =
before(functionCallStatement as IFunctionCall, parent, functionCallStatement.position)
override fun before(functionCallExpr: FunctionCallExpr, parent: Node) =
before(functionCallExpr as IFunctionCall, parent, functionCallExpr.position)
private fun before(functionCall: IFunctionCall, parent: Node, position: Position): Iterable<IAstModification> {
if(functionCall.target.nameInSource==listOf("peek")) {
// peek(a) is synonymous with @(a)
val memread = DirectMemoryRead(functionCall.args.single(), position)
return listOf(IAstModification.ReplaceNode(functionCall as Node, memread, parent))
}
if(functionCall.target.nameInSource==listOf("poke")) {
// poke(a, v) is synonymous with @(a) = v
val tgt = AssignTarget(null, null, DirectMemoryWrite(functionCall.args[0], position), position)
val assign = Assignment(tgt, functionCall.args[1], position)
return listOf(IAstModification.ReplaceNode(functionCall as Node, assign, parent))
}
return noModifications
}
}

View File

@@ -0,0 +1,247 @@
package prog8.compiler.astprocessing
import prog8.ast.Module
import prog8.ast.Node
import prog8.ast.base.FatalAstException
import prog8.ast.expressions.*
import prog8.ast.statements.*
import prog8.ast.walk.AstWalker
import prog8.ast.walk.IAstModification
internal class ParentNodeChecker: AstWalker() {
override fun before(addressOf: AddressOf, parent: Node): Iterable<IAstModification> {
if(addressOf.parent!==parent)
throw FatalAstException("parent node mismatch at $addressOf")
return noModifications
}
override fun before(array: ArrayLiteralValue, parent: Node): Iterable<IAstModification> {
if(array.parent!==parent)
throw FatalAstException("parent node mismatch at $array")
return noModifications
}
override fun before(arrayIndexedExpression: ArrayIndexedExpression, parent: Node): Iterable<IAstModification> {
if(arrayIndexedExpression.parent!==parent)
throw FatalAstException("parent node mismatch at $arrayIndexedExpression")
return noModifications
}
override fun before(assignTarget: AssignTarget, parent: Node): Iterable<IAstModification> {
if(assignTarget.parent!==parent)
throw FatalAstException("parent node mismatch at $assignTarget")
return noModifications
}
override fun before(assignment: Assignment, parent: Node): Iterable<IAstModification> {
if(assignment.parent!==parent)
throw FatalAstException("parent node mismatch at $assignment")
return noModifications
}
override fun before(block: Block, parent: Node): Iterable<IAstModification> {
if(block.parent!==parent)
throw FatalAstException("parent node mismatch at $block")
return noModifications
}
override fun before(branch: Branch, parent: Node): Iterable<IAstModification> {
if(branch.parent!==parent)
throw FatalAstException("parent node mismatch at $branch")
return noModifications
}
override fun before(breakStmt: Break, parent: Node): Iterable<IAstModification> {
if(breakStmt.parent!==parent)
throw FatalAstException("parent node mismatch at $breakStmt")
return noModifications
}
override fun before(decl: VarDecl, parent: Node): Iterable<IAstModification> {
if(decl.parent!==parent)
throw FatalAstException("parent node mismatch at $decl")
return noModifications
}
override fun before(directive: Directive, parent: Node): Iterable<IAstModification> {
if(directive.parent!==parent)
throw FatalAstException("parent node mismatch at $directive")
return noModifications
}
override fun before(expr: BinaryExpression, parent: Node): Iterable<IAstModification> {
if(expr.parent!==parent)
throw FatalAstException("parent node mismatch at $expr")
return noModifications
}
override fun before(expr: PrefixExpression, parent: Node): Iterable<IAstModification> {
if(expr.parent!==parent)
throw FatalAstException("parent node mismatch at $expr")
return noModifications
}
override fun before(forLoop: ForLoop, parent: Node): Iterable<IAstModification> {
if(forLoop.parent!==parent)
throw FatalAstException("parent node mismatch at $forLoop")
return noModifications
}
override fun before(repeatLoop: RepeatLoop, parent: Node): Iterable<IAstModification> {
if(repeatLoop.parent!==parent)
throw FatalAstException("parent node mismatch at $repeatLoop")
return noModifications
}
override fun before(identifier: IdentifierReference, parent: Node): Iterable<IAstModification> {
if(identifier.parent!==parent)
throw FatalAstException("parent node mismatch at $identifier")
return noModifications
}
override fun before(ifElse: IfElse, parent: Node): Iterable<IAstModification> {
if(ifElse.parent!==parent)
throw FatalAstException("parent node mismatch at $ifElse")
return noModifications
}
override fun before(inlineAssembly: InlineAssembly, parent: Node): Iterable<IAstModification> {
if(inlineAssembly.parent!==parent)
throw FatalAstException("parent node mismatch at $inlineAssembly")
return noModifications
}
override fun before(jump: Jump, parent: Node): Iterable<IAstModification> {
if(jump.parent!==parent)
throw FatalAstException("parent node mismatch at $jump")
return noModifications
}
override fun before(gosub: GoSub, parent: Node): Iterable<IAstModification> {
if(gosub.parent!==parent)
throw FatalAstException("parent node mismatch at $gosub")
return noModifications
}
override fun before(label: Label, parent: Node): Iterable<IAstModification> {
if(label.parent!==parent)
throw FatalAstException("parent node mismatch at $label")
return noModifications
}
override fun before(memread: DirectMemoryRead, parent: Node): Iterable<IAstModification> {
if(memread.parent!==parent)
throw FatalAstException("parent node mismatch at $memread")
return noModifications
}
override fun before(memwrite: DirectMemoryWrite, parent: Node): Iterable<IAstModification> {
if(memwrite.parent!==parent)
throw FatalAstException("parent node mismatch at $memwrite")
return noModifications
}
override fun before(module: Module, parent: Node): Iterable<IAstModification> {
if(module.parent!==parent)
throw FatalAstException("parent node mismatch at $module")
return noModifications
}
override fun before(numLiteral: NumericLiteralValue, parent: Node): Iterable<IAstModification> {
if(numLiteral.parent!==parent)
throw FatalAstException("parent node mismatch at $numLiteral")
return noModifications
}
override fun before(postIncrDecr: PostIncrDecr, parent: Node): Iterable<IAstModification> {
if(postIncrDecr.parent!==parent)
throw FatalAstException("parent node mismatch at $postIncrDecr")
return noModifications
}
override fun before(range: RangeExpr, parent: Node): Iterable<IAstModification> {
if(range.parent!==parent)
throw FatalAstException("parent node mismatch at $range")
return noModifications
}
override fun before(untilLoop: UntilLoop, parent: Node): Iterable<IAstModification> {
if(untilLoop.parent!==parent)
throw FatalAstException("parent node mismatch at $untilLoop")
return noModifications
}
override fun before(returnStmt: Return, parent: Node): Iterable<IAstModification> {
if(returnStmt.parent!==parent)
throw FatalAstException("parent node mismatch at $returnStmt")
return noModifications
}
override fun before(char: CharLiteral, parent: Node): Iterable<IAstModification> {
if(char.parent!==parent)
throw FatalAstException("parent node mismatch at $char")
return noModifications
}
override fun before(string: StringLiteralValue, parent: Node): Iterable<IAstModification> {
if(string.parent!==parent)
throw FatalAstException("parent node mismatch at $string")
return noModifications
}
override fun before(subroutine: Subroutine, parent: Node): Iterable<IAstModification> {
if(subroutine.parent!==parent)
throw FatalAstException("parent node mismatch at $subroutine")
return noModifications
}
override fun before(typecast: TypecastExpression, parent: Node): Iterable<IAstModification> {
if(typecast.parent!==parent)
throw FatalAstException("parent node mismatch at $typecast")
return noModifications
}
override fun before(whenChoice: WhenChoice, parent: Node): Iterable<IAstModification> {
if(whenChoice.parent!==parent)
throw FatalAstException("parent node mismatch at $whenChoice")
return noModifications
}
override fun before(`when`: When, parent: Node): Iterable<IAstModification> {
if(`when`.parent!==parent)
throw FatalAstException("parent node mismatch at $`when`")
return noModifications
}
override fun before(whileLoop: WhileLoop, parent: Node): Iterable<IAstModification> {
if(whileLoop.parent!==parent)
throw FatalAstException("parent node mismatch at $whileLoop")
return noModifications
}
override fun before(functionCallExpr: FunctionCallExpr, parent: Node): Iterable<IAstModification> {
if(functionCallExpr.parent!==parent)
throw FatalAstException("parent node mismatch at $functionCallExpr")
return noModifications
}
override fun before(functionCallStatement: FunctionCallStatement, parent: Node): Iterable<IAstModification> {
if(functionCallStatement.parent!==parent)
throw FatalAstException("parent node mismatch at $functionCallStatement")
return noModifications
}
override fun before(nop: Nop, parent: Node): Iterable<IAstModification> {
if(nop.parent!==parent)
throw FatalAstException("parent node mismatch at $nop")
return noModifications
}
override fun before(scope: AnonymousScope, parent: Node): Iterable<IAstModification> {
if(scope.parent!==parent)
throw FatalAstException("parent node mismatch at $scope")
return noModifications
}
}

View File

@@ -8,10 +8,8 @@ import prog8.ast.walk.AstWalker
import prog8.ast.walk.IAstModification import prog8.ast.walk.IAstModification
import prog8.compilerinterface.BuiltinFunctions import prog8.compilerinterface.BuiltinFunctions
import prog8.compilerinterface.CompilationOptions import prog8.compilerinterface.CompilationOptions
import prog8.compilerinterface.ICompilationTarget
import prog8.compilerinterface.IErrorReporter import prog8.compilerinterface.IErrorReporter
internal class StatementReorderer(val program: Program, internal class StatementReorderer(val program: Program,
val errors: IErrorReporter, val errors: IErrorReporter,
private val options: CompilationOptions) : AstWalker() { private val options: CompilationOptions) : AstWalker() {
@@ -61,6 +59,14 @@ internal class StatementReorderer(val program: Program,
return noModifications return noModifications
} }
val nextStmt = decl.nextSibling() val nextStmt = decl.nextSibling()
val nextAssign = nextStmt as? Assignment
if(nextAssign!=null && !nextAssign.isAugmentable) {
val target = nextAssign.target.identifier?.targetStatement(program)
if(target === decl) {
// an initializer assignment for a vardecl is already here
return noModifications
}
}
val nextFor = nextStmt as? ForLoop val nextFor = nextStmt as? ForLoop
val hasNextForWithThisLoopvar = nextFor?.loopVar?.nameInSource==listOf(decl.name) val hasNextForWithThisLoopvar = nextFor?.loopVar?.nameInSource==listOf(decl.name)
if (!hasNextForWithThisLoopvar) { if (!hasNextForWithThisLoopvar) {
@@ -204,7 +210,7 @@ internal class StatementReorderer(val program: Program,
return listOf(IAstModification.ReplaceNode(expr.left, cast, expr)) return listOf(IAstModification.ReplaceNode(expr.left, cast, expr))
} }
} }
is BuiltinFunctionStatementPlaceholder -> { is BuiltinFunctionPlaceholder -> {
val func = BuiltinFunctions.getValue(callee.name) val func = BuiltinFunctions.getValue(callee.name)
val paramTypes = func.parameters[argnum].possibleDatatypes val paramTypes = func.parameters[argnum].possibleDatatypes
for(type in paramTypes) { for(type in paramTypes) {
@@ -256,19 +262,19 @@ internal class StatementReorderer(val program: Program,
return noModifications return noModifications
} }
override fun after(whenStatement: WhenStatement, parent: Node): Iterable<IAstModification> { override fun after(`when`: When, parent: Node): Iterable<IAstModification> {
val lastChoiceValues = whenStatement.choices.lastOrNull()?.values val lastChoiceValues = `when`.choices.lastOrNull()?.values
if(lastChoiceValues?.isNotEmpty()==true) { if(lastChoiceValues?.isNotEmpty()==true) {
val elseChoice = whenStatement.choices.indexOfFirst { it.values==null || it.values?.isEmpty()==true } val elseChoice = `when`.choices.indexOfFirst { it.values==null || it.values?.isEmpty()==true }
if(elseChoice>=0) if(elseChoice>=0)
errors.err("else choice must be the last one", whenStatement.choices[elseChoice].position) errors.err("else choice must be the last one", `when`.choices[elseChoice].position)
} }
val choices = whenStatement.choiceValues(program).sortedBy { val choices = `when`.choiceValues(program).sortedBy {
it.first?.first() ?: Int.MAX_VALUE it.first?.first() ?: Int.MAX_VALUE
} }
whenStatement.choices.clear() `when`.choices.clear()
choices.mapTo(whenStatement.choices) { it.second } choices.mapTo(`when`.choices) { it.second }
return noModifications return noModifications
} }

View File

@@ -135,8 +135,8 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val
return afterFunctionCallArgs(functionCallStatement) return afterFunctionCallArgs(functionCallStatement)
} }
override fun after(functionCall: FunctionCall, parent: Node): Iterable<IAstModification> { override fun after(functionCallExpr: FunctionCallExpr, parent: Node): Iterable<IAstModification> {
return afterFunctionCallArgs(functionCall) return afterFunctionCallArgs(functionCallExpr)
} }
private fun afterFunctionCallArgs(call: IFunctionCall): Iterable<IAstModification> { private fun afterFunctionCallArgs(call: IFunctionCall): Iterable<IAstModification> {
@@ -180,7 +180,7 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val
} }
} }
} }
is BuiltinFunctionStatementPlaceholder -> { is BuiltinFunctionPlaceholder -> {
val func = BuiltinFunctions.getValue(sub.name) val func = BuiltinFunctions.getValue(sub.name)
func.parameters.zip(call.args).forEachIndexed { index, pair -> func.parameters.zip(call.args).forEachIndexed { index, pair ->
val argItype = pair.second.inferType(program) val argItype = pair.second.inferType(program)

View File

@@ -1,11 +1,9 @@
package prog8.compiler.astprocessing package prog8.compiler.astprocessing
import prog8.ast.IFunctionCall
import prog8.ast.IStatementContainer import prog8.ast.IStatementContainer
import prog8.ast.Node import prog8.ast.Node
import prog8.ast.Program import prog8.ast.Program
import prog8.ast.base.FatalAstException import prog8.ast.base.FatalAstException
import prog8.ast.base.Position
import prog8.ast.expressions.* import prog8.ast.expressions.*
import prog8.ast.statements.* import prog8.ast.statements.*
import prog8.ast.walk.AstWalker import prog8.ast.walk.AstWalker
@@ -15,18 +13,18 @@ import prog8.compilerinterface.IErrorReporter
internal class VariousCleanups(val program: Program, val errors: IErrorReporter): AstWalker() { internal class VariousCleanups(val program: Program, val errors: IErrorReporter): AstWalker() {
override fun before(nopStatement: NopStatement, parent: Node): Iterable<IAstModification> { override fun before(nop: Nop, parent: Node): Iterable<IAstModification> {
return listOf(IAstModification.Remove(nopStatement, parent as IStatementContainer)) return listOf(IAstModification.Remove(nop, parent as IStatementContainer))
} }
override fun before(scope: AnonymousScope, parent: Node): Iterable<IAstModification> { override fun after(scope: AnonymousScope, parent: Node): Iterable<IAstModification> {
return if(parent is IStatementContainer) return if(parent is IStatementContainer)
listOf(ScopeFlatten(scope, parent as IStatementContainer)) listOf(ScopeFlatten(scope, parent as IStatementContainer))
else else
noModifications noModifications
} }
class ScopeFlatten(val scope: AnonymousScope, val into: IStatementContainer) : IAstModification { private class ScopeFlatten(val scope: AnonymousScope, val into: IStatementContainer) : IAstModification {
override fun perform() { override fun perform() {
val idx = into.statements.indexOf(scope) val idx = into.statements.indexOf(scope)
if(idx>=0) { if(idx>=0) {
@@ -37,33 +35,7 @@ internal class VariousCleanups(val program: Program, val errors: IErrorReporter)
} }
} }
override fun before(functionCallStatement: FunctionCallStatement, parent: Node): Iterable<IAstModification> {
return before(functionCallStatement as IFunctionCall, parent, functionCallStatement.position)
}
override fun before(functionCall: FunctionCall, parent: Node): Iterable<IAstModification> {
return before(functionCall as IFunctionCall, parent, functionCall.position)
}
private fun before(functionCall: IFunctionCall, parent: Node, position: Position): Iterable<IAstModification> {
if(functionCall.target.nameInSource==listOf("peek")) {
// peek(a) is synonymous with @(a)
val memread = DirectMemoryRead(functionCall.args.single(), position)
return listOf(IAstModification.ReplaceNode(functionCall as Node, memread, parent))
}
if(functionCall.target.nameInSource==listOf("poke")) {
// poke(a, v) is synonymous with @(a) = v
val tgt = AssignTarget(null, null, DirectMemoryWrite(functionCall.args[0], position), position)
val assign = Assignment(tgt, functionCall.args[1], position)
return listOf(IAstModification.ReplaceNode(functionCall as Node, assign, parent))
}
return noModifications
}
override fun after(typecast: TypecastExpression, parent: Node): Iterable<IAstModification> { override fun after(typecast: TypecastExpression, parent: Node): Iterable<IAstModification> {
if(typecast.parent!==parent)
throw FatalAstException("parent node mismatch at $typecast")
if(typecast.expression is NumericLiteralValue) { if(typecast.expression is NumericLiteralValue) {
val value = (typecast.expression as NumericLiteralValue).cast(typecast.type) val value = (typecast.expression as NumericLiteralValue).cast(typecast.type)
if(value.isValid) if(value.isValid)
@@ -85,16 +57,7 @@ internal class VariousCleanups(val program: Program, val errors: IErrorReporter)
return noModifications return noModifications
} }
override fun after(subroutine: Subroutine, parent: Node): Iterable<IAstModification> {
if(subroutine.parent!==parent)
throw FatalAstException("parent node mismatch at $subroutine")
return noModifications
}
override fun after(assignment: Assignment, parent: Node): Iterable<IAstModification> { override fun after(assignment: Assignment, parent: Node): Iterable<IAstModification> {
if(assignment.parent!==parent)
throw FatalAstException("parent node mismatch at $assignment")
val nextAssign = assignment.nextSibling() as? Assignment val nextAssign = assignment.nextSibling() as? Assignment
if(nextAssign!=null && nextAssign.target.isSameAs(assignment.target, program)) { if(nextAssign!=null && nextAssign.target.isSameAs(assignment.target, program)) {
if(nextAssign.value isSameAs assignment.value) if(nextAssign.value isSameAs assignment.value)
@@ -104,33 +67,48 @@ internal class VariousCleanups(val program: Program, val errors: IErrorReporter)
return noModifications return noModifications
} }
override fun after(assignTarget: AssignTarget, parent: Node): Iterable<IAstModification> { override fun after(expr: PrefixExpression, parent: Node): Iterable<IAstModification> {
if(assignTarget.parent!==parent) if(expr.operator=="+") {
throw FatalAstException("parent node mismatch at $assignTarget") // +X --> X
return listOf(IAstModification.ReplaceNode(expr, expr.expression, parent))
}
if(expr.operator=="not") {
val nestedPrefix = expr.expression as? PrefixExpression
if(nestedPrefix!=null && nestedPrefix.operator=="not") {
// NOT NOT X --> X
return listOf(IAstModification.ReplaceNode(expr, nestedPrefix.expression, parent))
}
val comparison = expr.expression as? BinaryExpression
if (comparison != null) {
// NOT COMPARISON ==> inverted COMPARISON
val invertedOperator = invertedComparisonOperator(comparison.operator)
if (invertedOperator != null) {
comparison.operator = invertedOperator
return listOf(IAstModification.ReplaceNode(expr, comparison, parent))
}
}
}
return noModifications return noModifications
} }
override fun after(decl: VarDecl, parent: Node): Iterable<IAstModification> { override fun after(expr: BinaryExpression, parent: Node): Iterable<IAstModification> {
if(decl.parent!==parent) if(expr.operator in ComparisonOperators) {
throw FatalAstException("parent node mismatch at $decl") val leftConstVal = expr.left.constValue(program)
return noModifications val rightConstVal = expr.right.constValue(program)
// make sure the constant value is on the right of the comparison expression
if(rightConstVal==null && leftConstVal!=null) {
val newOperator =
when(expr.operator) {
"<" -> ">"
"<=" -> ">="
">" -> "<"
">=" -> "<="
else -> expr.operator
} }
val replacement = BinaryExpression(expr.right, newOperator, expr.left, expr.position)
override fun after(scope: AnonymousScope, parent: Node): Iterable<IAstModification> { return listOf(IAstModification.ReplaceNode(expr, replacement, parent))
if(scope.parent!==parent)
throw FatalAstException("parent node mismatch at $scope")
return noModifications
} }
override fun after(returnStmt: Return, parent: Node): Iterable<IAstModification> {
if(returnStmt.parent!==parent)
throw FatalAstException("parent node mismatch at $returnStmt")
return noModifications
} }
override fun after(identifier: IdentifierReference, parent: Node): Iterable<IAstModification> {
if(identifier.parent!==parent)
throw FatalAstException("parent node mismatch at $identifier")
return noModifications return noModifications
} }
} }

View File

@@ -4,7 +4,7 @@ import prog8.ast.IFunctionCall
import prog8.ast.Program import prog8.ast.Program
import prog8.ast.base.DataType import prog8.ast.base.DataType
import prog8.ast.expressions.Expression import prog8.ast.expressions.Expression
import prog8.ast.expressions.FunctionCall import prog8.ast.expressions.FunctionCallExpr
import prog8.ast.expressions.TypecastExpression import prog8.ast.expressions.TypecastExpression
import prog8.ast.statements.* import prog8.ast.statements.*
import prog8.ast.walk.IAstVisitor import prog8.ast.walk.IAstVisitor
@@ -13,8 +13,8 @@ import prog8.compilerinterface.InternalCompilerException
internal class VerifyFunctionArgTypes(val program: Program) : IAstVisitor { internal class VerifyFunctionArgTypes(val program: Program) : IAstVisitor {
override fun visit(functionCall: FunctionCall) { override fun visit(functionCallExpr: FunctionCallExpr) {
val error = checkTypes(functionCall as IFunctionCall, program) val error = checkTypes(functionCallExpr as IFunctionCall, program)
if(error!=null) if(error!=null)
throw InternalCompilerException(error) throw InternalCompilerException(error)
} }
@@ -76,7 +76,7 @@ internal class VerifyFunctionArgTypes(val program: Program) : IAstVisitor {
} }
} }
} }
else if (target is BuiltinFunctionStatementPlaceholder) { else if (target is BuiltinFunctionPlaceholder) {
val func = BuiltinFunctions.getValue(target.name) val func = BuiltinFunctions.getValue(target.name)
if(call.args.size != func.parameters.size) if(call.args.size != func.parameters.size)
return "invalid number of arguments" return "invalid number of arguments"

View File

@@ -444,7 +444,7 @@ class TestOptimization: FunSpec({
}""" }"""
val result = compileText(C64Target, optimize=true, src, writeAssembly=false).assertSuccess() val result = compileText(C64Target, optimize=true, src, writeAssembly=false).assertSuccess()
result.program.entrypoint.statements.size shouldBe 3 result.program.entrypoint.statements.size shouldBe 3
val ifstmt = result.program.entrypoint.statements[0] as IfStatement val ifstmt = result.program.entrypoint.statements[0] as IfElse
ifstmt.truepart.statements.size shouldBe 1 ifstmt.truepart.statements.size shouldBe 1
(ifstmt.truepart.statements[0] as Assignment).target.identifier!!.nameInSource shouldBe listOf("cx16", "r0") (ifstmt.truepart.statements[0] as Assignment).target.identifier!!.nameInSource shouldBe listOf("cx16", "r0")
val func2 = result.program.entrypoint.statements[2] as Subroutine val func2 = result.program.entrypoint.statements[2] as Subroutine
@@ -633,14 +633,13 @@ class TestOptimization: FunSpec({
uword zz uword zz
zz = 60 zz = 60
ubyte xx ubyte xx
xx = 0
xx = sin8u(xx) xx = sin8u(xx)
xx += 6 xx += 6
*/ */
val stmts = result.program.entrypoint.statements val stmts = result.program.entrypoint.statements
stmts.size shouldBe 8 stmts.size shouldBe 7
stmts.filterIsInstance<VarDecl>().size shouldBe 3 stmts.filterIsInstance<VarDecl>().size shouldBe 3
stmts.filterIsInstance<Assignment>().size shouldBe 5 stmts.filterIsInstance<Assignment>().size shouldBe 4
} }
test("only substitue assignments with 0 after a =0 initializer if it is the same variable") { test("only substitue assignments with 0 after a =0 initializer if it is the same variable") {
@@ -666,6 +665,8 @@ class TestOptimization: FunSpec({
yy = 0 yy = 0
xx += 10 xx += 10
*/ */
printProgram(result.program)
val stmts = result.program.entrypoint.statements val stmts = result.program.entrypoint.statements
stmts.size shouldBe 7 stmts.size shouldBe 7
stmts.filterIsInstance<VarDecl>().size shouldBe 2 stmts.filterIsInstance<VarDecl>().size shouldBe 2

View File

@@ -206,14 +206,14 @@ class TestSubroutines: FunSpec({
val block = module.statements.single() as Block val block = module.statements.single() as Block
val thing = block.statements.filterIsInstance<Subroutine>().single {it.name=="thing"} val thing = block.statements.filterIsInstance<Subroutine>().single {it.name=="thing"}
block.name shouldBe "main" block.name shouldBe "main"
thing.statements.size shouldBe 11 // rr paramdecl, xx, xx assign, yy decl, yy init 0, yy assign, other, other assign 0, zz, zz assign, return thing.statements.size shouldBe 10 // rr paramdecl, xx, xx assign, yy decl, yy assign, other, other assign 0, zz, zz assign, return
val xx = thing.statements[1] as VarDecl val xx = thing.statements[1] as VarDecl
withClue("vardecl init values must have been moved to separate assignments") { withClue("vardecl init values must have been moved to separate assignments") {
xx.value shouldBe null xx.value shouldBe null
} }
val assignXX = thing.statements[2] as Assignment val assignXX = thing.statements[2] as Assignment
val assignYY = thing.statements[5] as Assignment val assignYY = thing.statements[4] as Assignment
val assignZZ = thing.statements[9] as Assignment val assignZZ = thing.statements[8] as Assignment
assignXX.target.identifier!!.nameInSource shouldBe listOf("xx") assignXX.target.identifier!!.nameInSource shouldBe listOf("xx")
assignYY.target.identifier!!.nameInSource shouldBe listOf("yy") assignYY.target.identifier!!.nameInSource shouldBe listOf("yy")
assignZZ.target.identifier!!.nameInSource shouldBe listOf("zz") assignZZ.target.identifier!!.nameInSource shouldBe listOf("zz")
@@ -337,6 +337,6 @@ class TestSubroutines: FunSpec({
stmts.last() shouldBe instanceOf<Subroutine>() stmts.last() shouldBe instanceOf<Subroutine>()
stmts.dropLast(1).last() shouldBe instanceOf<Return>() // this prevents the fallthrough stmts.dropLast(1).last() shouldBe instanceOf<Return>() // this prevents the fallthrough
stmts.dropLast(2).last() shouldBe instanceOf<Jump>() stmts.dropLast(2).last() shouldBe instanceOf<GoSub>()
} }
}) })

View File

@@ -19,7 +19,6 @@ import prog8.ast.base.DataType
import prog8.ast.base.Position import prog8.ast.base.Position
import prog8.ast.expressions.* import prog8.ast.expressions.*
import prog8.ast.statements.* import prog8.ast.statements.*
import prog8.compiler.printProgram
import prog8.compiler.target.C64Target import prog8.compiler.target.C64Target
import prog8.compiler.target.cbm.Petscii import prog8.compiler.target.cbm.Petscii
import prog8.parser.ParseError import prog8.parser.ParseError
@@ -334,7 +333,7 @@ class TestProg8Parser: FunSpec( {
assertPositionOf(rhsFoo, mpf, 4, 21, 22) assertPositionOf(rhsFoo, mpf, 4, 21, 22)
val declBar = startSub.statements.filterIsInstance<VarDecl>()[1] val declBar = startSub.statements.filterIsInstance<VarDecl>()[1]
assertPositionOf(declBar, mpf, 5, 9, 13) assertPositionOf(declBar, mpf, 5, 9, 13)
val whenStmt = startSub.statements.filterIsInstance<WhenStatement>()[0] val whenStmt = startSub.statements.filterIsInstance<When>()[0]
assertPositionOf(whenStmt, mpf, 6, 9, 12) assertPositionOf(whenStmt, mpf, 6, 9, 12)
assertPositionOf(whenStmt.choices[0], mpf, 7, 13, 14) assertPositionOf(whenStmt.choices[0], mpf, 7, 13, 14)
assertPositionOf(whenStmt.choices[1], mpf, 8, 13, 14) assertPositionOf(whenStmt.choices[1], mpf, 8, 13, 14)

View File

@@ -195,8 +195,8 @@ class AstToSourceTextConverter(val output: (text: String) -> Unit, val program:
} }
} }
override fun visit(functionCall: FunctionCall) { override fun visit(functionCallExpr: FunctionCallExpr) {
printout(functionCall as IFunctionCall) printout(functionCallExpr as IFunctionCall)
} }
override fun visit(functionCallStatement: FunctionCallStatement) { override fun visit(functionCallStatement: FunctionCallStatement) {
@@ -219,9 +219,6 @@ class AstToSourceTextConverter(val output: (text: String) -> Unit, val program:
} }
override fun visit(jump: Jump) { override fun visit(jump: Jump) {
if(jump.isGosub)
output("gosub ")
else
output("goto ") output("goto ")
when { when {
jump.address!=null -> output(jump.address.toHex()) jump.address!=null -> output(jump.address.toHex())
@@ -230,23 +227,32 @@ class AstToSourceTextConverter(val output: (text: String) -> Unit, val program:
} }
} }
override fun visit(ifStatement: IfStatement) { override fun visit(gosub: GoSub) {
output("if ") output("gosub ")
ifStatement.condition.accept(this) when {
output(" ") gosub.address!=null -> output(gosub.address.toHex())
ifStatement.truepart.accept(this) gosub.generatedLabel!=null -> output(gosub.generatedLabel)
if(ifStatement.elsepart.statements.isNotEmpty()) { gosub.identifier!=null -> gosub.identifier.accept(this)
output(" else ")
ifStatement.elsepart.accept(this)
} }
} }
override fun visit(branchStatement: BranchStatement) { override fun visit(ifElse: IfElse) {
output("if_${branchStatement.condition.toString().lowercase()} ") output("if ")
branchStatement.truepart.accept(this) ifElse.condition.accept(this)
if(branchStatement.elsepart.statements.isNotEmpty()) { output(" ")
ifElse.truepart.accept(this)
if(ifElse.elsepart.statements.isNotEmpty()) {
output(" else ") output(" else ")
branchStatement.elsepart.accept(this) ifElse.elsepart.accept(this)
}
}
override fun visit(branch: Branch) {
output("if_${branch.condition.toString().lowercase()} ")
branch.truepart.accept(this)
if(branch.elsepart.statements.isNotEmpty()) {
output(" else ")
branch.elsepart.accept(this)
} }
} }
@@ -416,12 +422,12 @@ class AstToSourceTextConverter(val output: (text: String) -> Unit, val program:
outputlni("}}") outputlni("}}")
} }
override fun visit(whenStatement: WhenStatement) { override fun visit(`when`: When) {
output("when ") output("when ")
whenStatement.condition.accept(this) `when`.condition.accept(this)
outputln(" {") outputln(" {")
scopelevel++ scopelevel++
whenStatement.choices.forEach { it.accept(this) } `when`.choices.forEach { it.accept(this) }
scopelevel-- scopelevel--
outputlni("}") outputlni("}")
} }

View File

@@ -39,14 +39,14 @@ interface IStatementContainer {
when(it) { when(it) {
is Label -> result.add(it) is Label -> result.add(it)
is IStatementContainer -> find(it) is IStatementContainer -> find(it)
is IfStatement -> { is IfElse -> {
find(it.truepart) find(it.truepart)
find(it.elsepart) find(it.elsepart)
} }
is UntilLoop -> find(it.body) is UntilLoop -> find(it.body)
is RepeatLoop -> find(it.body) is RepeatLoop -> find(it.body)
is WhileLoop -> find(it.body) is WhileLoop -> find(it.body)
is WhenStatement -> it.choices.forEach { choice->find(choice.statements) } is When -> it.choices.forEach { choice->find(choice.statements) }
else -> { /* do nothing */ } else -> { /* do nothing */ }
} }
} }
@@ -87,12 +87,12 @@ interface IStatementContainer {
if(found!=null) if(found!=null)
return found return found
} }
is IfStatement -> { is IfElse -> {
val found = stmt.truepart.searchSymbol(name) ?: stmt.elsepart.searchSymbol(name) val found = stmt.truepart.searchSymbol(name) ?: stmt.elsepart.searchSymbol(name)
if(found!=null) if(found!=null)
return found return found
} }
is BranchStatement -> { is Branch -> {
val found = stmt.truepart.searchSymbol(name) ?: stmt.elsepart.searchSymbol(name) val found = stmt.truepart.searchSymbol(name) ?: stmt.elsepart.searchSymbol(name)
if(found!=null) if(found!=null)
return found return found
@@ -117,7 +117,7 @@ interface IStatementContainer {
if(found!=null) if(found!=null)
return found return found
} }
is WhenStatement -> { is When -> {
stmt.choices.forEach { stmt.choices.forEach {
val found = it.statements.searchSymbol(name) val found = it.statements.searchSymbol(name)
if(found!=null) if(found!=null)

View File

@@ -280,12 +280,12 @@ private fun Prog8ANTLRParser.Functioncall_stmtContext.toAst(): Statement {
FunctionCallStatement(location, expression_list().toAst().toMutableList(), void, toPosition()) FunctionCallStatement(location, expression_list().toAst().toMutableList(), void, toPosition())
} }
private fun Prog8ANTLRParser.FunctioncallContext.toAst(): FunctionCall { private fun Prog8ANTLRParser.FunctioncallContext.toAst(): FunctionCallExpr {
val location = scoped_identifier().toAst() val location = scoped_identifier().toAst()
return if(expression_list() == null) return if(expression_list() == null)
FunctionCall(location, mutableListOf(), toPosition()) FunctionCallExpr(location, mutableListOf(), toPosition())
else else
FunctionCall(location, expression_list().toAst().toMutableList(), toPosition()) FunctionCallExpr(location, expression_list().toAst().toMutableList(), toPosition())
} }
private fun Prog8ANTLRParser.InlineasmContext.toAst(): InlineAssembly { private fun Prog8ANTLRParser.InlineasmContext.toAst(): InlineAssembly {
@@ -516,28 +516,28 @@ private fun Prog8ANTLRParser.BooleanliteralContext.toAst() = when(text) {
private fun Prog8ANTLRParser.ArrayliteralContext.toAst() : Array<Expression> = private fun Prog8ANTLRParser.ArrayliteralContext.toAst() : Array<Expression> =
expression().map { it.toAst() }.toTypedArray() expression().map { it.toAst() }.toTypedArray()
private fun Prog8ANTLRParser.If_stmtContext.toAst(): IfStatement { private fun Prog8ANTLRParser.If_stmtContext.toAst(): IfElse {
val condition = expression().toAst() val condition = expression().toAst()
val trueStatements = statement_block()?.toAst() ?: mutableListOf(statement().toAst()) val trueStatements = statement_block()?.toAst() ?: mutableListOf(statement().toAst())
val elseStatements = else_part()?.toAst() ?: mutableListOf() val elseStatements = else_part()?.toAst() ?: mutableListOf()
val trueScope = AnonymousScope(trueStatements, statement_block()?.toPosition() val trueScope = AnonymousScope(trueStatements, statement_block()?.toPosition()
?: statement().toPosition()) ?: statement().toPosition())
val elseScope = AnonymousScope(elseStatements, else_part()?.toPosition() ?: toPosition()) val elseScope = AnonymousScope(elseStatements, else_part()?.toPosition() ?: toPosition())
return IfStatement(condition, trueScope, elseScope, toPosition()) return IfElse(condition, trueScope, elseScope, toPosition())
} }
private fun Prog8ANTLRParser.Else_partContext.toAst(): MutableList<Statement> { private fun Prog8ANTLRParser.Else_partContext.toAst(): MutableList<Statement> {
return statement_block()?.toAst() ?: mutableListOf(statement().toAst()) return statement_block()?.toAst() ?: mutableListOf(statement().toAst())
} }
private fun Prog8ANTLRParser.Branch_stmtContext.toAst(): BranchStatement { private fun Prog8ANTLRParser.Branch_stmtContext.toAst(): Branch {
val branchcondition = branchcondition().toAst() val branchcondition = branchcondition().toAst()
val trueStatements = statement_block()?.toAst() ?: mutableListOf(statement().toAst()) val trueStatements = statement_block()?.toAst() ?: mutableListOf(statement().toAst())
val elseStatements = else_part()?.toAst() ?: mutableListOf() val elseStatements = else_part()?.toAst() ?: mutableListOf()
val trueScope = AnonymousScope(trueStatements, statement_block()?.toPosition() val trueScope = AnonymousScope(trueStatements, statement_block()?.toPosition()
?: statement().toPosition()) ?: statement().toPosition())
val elseScope = AnonymousScope(elseStatements, else_part()?.toPosition() ?: toPosition()) val elseScope = AnonymousScope(elseStatements, else_part()?.toPosition() ?: toPosition())
return BranchStatement(branchcondition, trueScope, elseScope, toPosition()) return Branch(branchcondition, trueScope, elseScope, toPosition())
} }
private fun Prog8ANTLRParser.BranchconditionContext.toAst() = BranchCondition.valueOf( private fun Prog8ANTLRParser.BranchconditionContext.toAst() = BranchCondition.valueOf(
@@ -581,10 +581,10 @@ private fun Prog8ANTLRParser.UntilloopContext.toAst(): UntilLoop {
return UntilLoop(scope, untilCondition, toPosition()) return UntilLoop(scope, untilCondition, toPosition())
} }
private fun Prog8ANTLRParser.WhenstmtContext.toAst(): WhenStatement { private fun Prog8ANTLRParser.WhenstmtContext.toAst(): When {
val condition = expression().toAst() val condition = expression().toAst()
val choices = this.when_choice()?.map { it.toAst() }?.toMutableList() ?: mutableListOf() val choices = this.when_choice()?.map { it.toAst() }?.toMutableList() ?: mutableListOf()
return WhenStatement(condition, choices, toPosition()) return When(condition, choices, toPosition())
} }
private fun Prog8ANTLRParser.When_choiceContext.toAst(): WhenChoice { private fun Prog8ANTLRParser.When_choiceContext.toAst(): WhenChoice {

View File

@@ -15,6 +15,17 @@ val AugmentAssignmentOperators = setOf("+", "-", "/", "*", "**", "&", "|", "^",
val LogicalOperators = setOf("and", "or", "xor", "not") val LogicalOperators = setOf("and", "or", "xor", "not")
val BitwiseOperators = setOf("&", "|", "^") val BitwiseOperators = setOf("&", "|", "^")
fun invertedComparisonOperator(operator: String) =
when (operator) {
"==" -> "!="
"!=" -> "=="
"<" -> ">="
">" -> "<="
"<=" -> ">"
">=" -> "<"
else -> null
}
sealed class Expression: Node { sealed class Expression: Node {
abstract override fun copy(): Expression abstract override fun copy(): Expression
@@ -55,8 +66,8 @@ sealed class Expression: Node {
is RangeExpr -> { is RangeExpr -> {
(other is RangeExpr && other.from==from && other.to==to && other.step==step) (other is RangeExpr && other.from==from && other.to==to && other.step==step)
} }
is FunctionCall -> { is FunctionCallExpr -> {
(other is FunctionCall && other.target.nameInSource == target.nameInSource (other is FunctionCallExpr && other.target.nameInSource == target.nameInSource
&& other.args.size == args.size && other.args.size == args.size
&& other.args.zip(args).all { it.first isSameAs it.second } ) && other.args.zip(args).all { it.first isSameAs it.second } )
} }
@@ -476,7 +487,12 @@ class NumericLiteralValue(val type: DataType, // only numerical types allowed
} }
override fun referencesIdentifier(nameInSource: List<String>) = false override fun referencesIdentifier(nameInSource: List<String>) = false
override fun constValue(program: Program) = this override fun constValue(program: Program): NumericLiteralValue {
return copy().also {
if(::parent.isInitialized)
it.parent = parent
}
}
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent)
@@ -795,7 +811,7 @@ data class IdentifierReference(val nameInSource: List<String>, override val posi
fun targetStatement(program: Program) = fun targetStatement(program: Program) =
if(nameInSource.size==1 && nameInSource[0] in program.builtinFunctions.names) if(nameInSource.size==1 && nameInSource[0] in program.builtinFunctions.names)
BuiltinFunctionStatementPlaceholder(nameInSource[0], position, parent) BuiltinFunctionPlaceholder(nameInSource[0], position, parent)
else else
definingScope.lookup(nameInSource) definingScope.lookup(nameInSource)
@@ -852,7 +868,7 @@ data class IdentifierReference(val nameInSource: List<String>, override val posi
} }
} }
class FunctionCall(override var target: IdentifierReference, class FunctionCallExpr(override var target: IdentifierReference,
override var args: MutableList<Expression>, override var args: MutableList<Expression>,
override val position: Position) : Expression(), IFunctionCall { override val position: Position) : Expression(), IFunctionCall {
override lateinit var parent: Node override lateinit var parent: Node
@@ -863,7 +879,7 @@ class FunctionCall(override var target: IdentifierReference,
args.forEach { it.linkParents(this) } args.forEach { it.linkParents(this) }
} }
override fun copy() = FunctionCall(target.copy(), args.map { it.copy() }.toMutableList(), position) override fun copy() = FunctionCallExpr(target.copy(), args.map { it.copy() }.toMutableList(), position)
override val isSimple = target.nameInSource.size==1 && (target.nameInSource[0] in arrayOf("msb", "lsb", "peek", "peekw")) override val isSimple = target.nameInSource.size==1 && (target.nameInSource[0] in arrayOf("msb", "lsb", "peek", "peekw"))
override fun replaceChildNode(node: Node, replacement: Node) { override fun replaceChildNode(node: Node, replacement: Node) {
@@ -909,7 +925,7 @@ class FunctionCall(override var target: IdentifierReference,
return InferredTypes.knownFor(constVal.type) return InferredTypes.knownFor(constVal.type)
val stmt = target.targetStatement(program) ?: return InferredTypes.unknown() val stmt = target.targetStatement(program) ?: return InferredTypes.unknown()
when (stmt) { when (stmt) {
is BuiltinFunctionStatementPlaceholder -> { is BuiltinFunctionPlaceholder -> {
if(target.nameInSource[0] == "set_carry" || target.nameInSource[0]=="set_irqd" || if(target.nameInSource[0] == "set_carry" || target.nameInSource[0]=="set_irqd" ||
target.nameInSource[0] == "clear_carry" || target.nameInSource[0]=="clear_irqd") { target.nameInSource[0] == "clear_carry" || target.nameInSource[0]=="clear_irqd") {
return InferredTypes.void() // these have no return value return InferredTypes.void() // these have no return value
@@ -933,3 +949,13 @@ class FunctionCall(override var target: IdentifierReference,
} }
} }
} }
fun invertCondition(cond: Expression): BinaryExpression? {
if(cond is BinaryExpression) {
val invertedOperator = invertedComparisonOperator(cond.operator)
if (invertedOperator != null)
return BinaryExpression(cond.left, invertedOperator, cond.right, cond.position)
}
return null
}

View File

@@ -49,7 +49,7 @@ sealed class Statement : Node {
} }
class BuiltinFunctionStatementPlaceholder(val name: String, override val position: Position, override var parent: Node) : Statement() { class BuiltinFunctionPlaceholder(val name: String, override val position: Position, override var parent: Node) : Statement() {
override fun linkParents(parent: Node) {} override fun linkParents(parent: Node) {}
override fun accept(visitor: IAstVisitor) = throw FatalAstException("should not iterate over this node") override fun accept(visitor: IAstVisitor) = throw FatalAstException("should not iterate over this node")
override fun accept(visitor: AstWalker, parent: Node) = throw FatalAstException("should not iterate over this node") override fun accept(visitor: AstWalker, parent: Node) = throw FatalAstException("should not iterate over this node")
@@ -130,7 +130,7 @@ data class Label(override val name: String, override val position: Position) : S
override fun toString()= "Label(name=$name, pos=$position)" override fun toString()= "Label(name=$name, pos=$position)"
} }
open class Return(var value: Expression?, final override val position: Position) : Statement() { class Return(var value: Expression?, final override val position: Position) : Statement() {
override lateinit var parent: Node override lateinit var parent: Node
override fun linkParents(parent: Node) { override fun linkParents(parent: Node) {
@@ -171,7 +171,7 @@ enum class ZeropageWish {
NOT_IN_ZEROPAGE NOT_IN_ZEROPAGE
} }
open class VarDecl(val type: VarDeclType, class VarDecl(val type: VarDeclType,
private val declaredDatatype: DataType, private val declaredDatatype: DataType,
val zeropage: ZeropageWish, val zeropage: ZeropageWish,
var arraysize: ArrayIndex?, var arraysize: ArrayIndex?,
@@ -299,7 +299,7 @@ class ArrayIndex(var indexExpr: Expression,
override fun copy() = ArrayIndex(indexExpr.copy(), position) override fun copy() = ArrayIndex(indexExpr.copy(), position)
} }
open class Assignment(var target: AssignTarget, var value: Expression, final override val position: Position) : Statement() { class Assignment(var target: AssignTarget, var value: Expression, final override val position: Position) : Statement() {
override lateinit var parent: Node override lateinit var parent: Node
override fun linkParents(parent: Node) { override fun linkParents(parent: Node) {
@@ -499,12 +499,11 @@ class PostIncrDecr(var target: AssignTarget, val operator: String, override val
override fun toString() = "PostIncrDecr(op: $operator, target: $target, pos=$position)" override fun toString() = "PostIncrDecr(op: $operator, target: $target, pos=$position)"
} }
open class Jump(val address: UInt?, class Jump(val address: UInt?,
val identifier: IdentifierReference?, val identifier: IdentifierReference?,
val generatedLabel: String?, // can be used in code generation scenarios val generatedLabel: String?, // can be used in code generation scenarios
override val position: Position) : Statement() { override val position: Position) : Statement() {
override lateinit var parent: Node override lateinit var parent: Node
open val isGosub = false
override fun linkParents(parent: Node) { override fun linkParents(parent: Node) {
this.parent = parent this.parent = parent
@@ -521,11 +520,22 @@ open class Jump(val address: UInt?,
} }
// a GoSub is ONLY created internally for calling subroutines // a GoSub is ONLY created internally for calling subroutines
class GoSub(address: UInt?, identifier: IdentifierReference?, generatedLabel: String?, position: Position) : class GoSub(val address: UInt?,
Jump(address, identifier, generatedLabel, position) { val identifier: IdentifierReference?,
val generatedLabel: String?, // can be used in code generation scenarios
override val position: Position) : Statement() {
override lateinit var parent: Node
override val isGosub = true override fun linkParents(parent: Node) {
this.parent = parent
identifier?.linkParents(this)
}
override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here")
override fun copy() = GoSub(address, identifier?.copy(), generatedLabel, position) override fun copy() = GoSub(address, identifier?.copy(), generatedLabel, position)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
override fun toString() = override fun toString() =
"GoSub(addr: $address, identifier: $identifier, label: $generatedLabel; pos=$position)" "GoSub(addr: $address, identifier: $identifier, label: $generatedLabel; pos=$position)"
} }
@@ -595,7 +605,7 @@ class AnonymousScope(override var statements: MutableList<Statement>,
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
} }
class NopStatement(override val position: Position): Statement() { class Nop(override val position: Position): Statement() {
override lateinit var parent: Node override lateinit var parent: Node
override fun linkParents(parent: Node) { override fun linkParents(parent: Node) {
@@ -603,7 +613,7 @@ class NopStatement(override val position: Position): Statement() {
} }
override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here") override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here")
override fun copy() = NopStatement(position) override fun copy() = Nop(position)
override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
} }
@@ -744,7 +754,7 @@ open class SubroutineParameter(val name: String,
override fun toString() = "Param($type:$name)" override fun toString() = "Param($type:$name)"
} }
class IfStatement(var condition: Expression, class IfElse(var condition: Expression,
var truepart: AnonymousScope, var truepart: AnonymousScope,
var elsepart: AnonymousScope, var elsepart: AnonymousScope,
override val position: Position) : Statement() { override val position: Position) : Statement() {
@@ -774,7 +784,7 @@ class IfStatement(var condition: Expression,
} }
class BranchStatement(var condition: BranchCondition, class Branch(var condition: BranchCondition,
var truepart: AnonymousScope, var truepart: AnonymousScope,
var elsepart: AnonymousScope, var elsepart: AnonymousScope,
override val position: Position) : Statement() { override val position: Position) : Statement() {
@@ -909,7 +919,7 @@ class UntilLoop(var body: AnonymousScope,
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
} }
class WhenStatement(var condition: Expression, class When(var condition: Expression,
var choices: MutableList<WhenChoice>, var choices: MutableList<WhenChoice>,
override val position: Position): Statement() { override val position: Position): Statement() {
override lateinit var parent: Node override lateinit var parent: Node

View File

@@ -85,7 +85,7 @@ abstract class AstWalker {
open fun before(assignTarget: AssignTarget, parent: Node): Iterable<IAstModification> = noModifications open fun before(assignTarget: AssignTarget, parent: Node): Iterable<IAstModification> = noModifications
open fun before(assignment: Assignment, parent: Node): Iterable<IAstModification> = noModifications open fun before(assignment: Assignment, parent: Node): Iterable<IAstModification> = noModifications
open fun before(block: Block, parent: Node): Iterable<IAstModification> = noModifications open fun before(block: Block, parent: Node): Iterable<IAstModification> = noModifications
open fun before(branchStatement: BranchStatement, parent: Node): Iterable<IAstModification> = noModifications open fun before(branch: Branch, parent: Node): Iterable<IAstModification> = noModifications
open fun before(breakStmt: Break, parent: Node): Iterable<IAstModification> = noModifications open fun before(breakStmt: Break, parent: Node): Iterable<IAstModification> = noModifications
open fun before(decl: VarDecl, parent: Node): Iterable<IAstModification> = noModifications open fun before(decl: VarDecl, parent: Node): Iterable<IAstModification> = noModifications
open fun before(directive: Directive, parent: Node): Iterable<IAstModification> = noModifications open fun before(directive: Directive, parent: Node): Iterable<IAstModification> = noModifications
@@ -93,17 +93,18 @@ abstract class AstWalker {
open fun before(expr: PrefixExpression, parent: Node): Iterable<IAstModification> = noModifications open fun before(expr: PrefixExpression, parent: Node): Iterable<IAstModification> = noModifications
open fun before(forLoop: ForLoop, parent: Node): Iterable<IAstModification> = noModifications open fun before(forLoop: ForLoop, parent: Node): Iterable<IAstModification> = noModifications
open fun before(repeatLoop: RepeatLoop, parent: Node): Iterable<IAstModification> = noModifications open fun before(repeatLoop: RepeatLoop, parent: Node): Iterable<IAstModification> = noModifications
open fun before(functionCall: FunctionCall, parent: Node): Iterable<IAstModification> = noModifications open fun before(functionCallExpr: FunctionCallExpr, parent: Node): Iterable<IAstModification> = noModifications
open fun before(functionCallStatement: FunctionCallStatement, parent: Node): Iterable<IAstModification> = noModifications open fun before(functionCallStatement: FunctionCallStatement, parent: Node): Iterable<IAstModification> = noModifications
open fun before(identifier: IdentifierReference, parent: Node): Iterable<IAstModification> = noModifications open fun before(identifier: IdentifierReference, parent: Node): Iterable<IAstModification> = noModifications
open fun before(ifStatement: IfStatement, parent: Node): Iterable<IAstModification> = noModifications open fun before(ifElse: IfElse, parent: Node): Iterable<IAstModification> = noModifications
open fun before(inlineAssembly: InlineAssembly, parent: Node): Iterable<IAstModification> = noModifications open fun before(inlineAssembly: InlineAssembly, parent: Node): Iterable<IAstModification> = noModifications
open fun before(jump: Jump, parent: Node): Iterable<IAstModification> = noModifications open fun before(jump: Jump, parent: Node): Iterable<IAstModification> = noModifications
open fun before(gosub: GoSub, parent: Node): Iterable<IAstModification> = noModifications
open fun before(label: Label, parent: Node): Iterable<IAstModification> = noModifications open fun before(label: Label, parent: Node): Iterable<IAstModification> = noModifications
open fun before(memread: DirectMemoryRead, parent: Node): Iterable<IAstModification> = noModifications open fun before(memread: DirectMemoryRead, parent: Node): Iterable<IAstModification> = noModifications
open fun before(memwrite: DirectMemoryWrite, parent: Node): Iterable<IAstModification> = noModifications open fun before(memwrite: DirectMemoryWrite, parent: Node): Iterable<IAstModification> = noModifications
open fun before(module: Module, parent: Node): Iterable<IAstModification> = noModifications open fun before(module: Module, parent: Node): Iterable<IAstModification> = noModifications
open fun before(nopStatement: NopStatement, parent: Node): Iterable<IAstModification> = noModifications open fun before(nop: Nop, parent: Node): Iterable<IAstModification> = noModifications
open fun before(numLiteral: NumericLiteralValue, parent: Node): Iterable<IAstModification> = noModifications open fun before(numLiteral: NumericLiteralValue, parent: Node): Iterable<IAstModification> = noModifications
open fun before(postIncrDecr: PostIncrDecr, parent: Node): Iterable<IAstModification> = noModifications open fun before(postIncrDecr: PostIncrDecr, parent: Node): Iterable<IAstModification> = noModifications
open fun before(program: Program): Iterable<IAstModification> = noModifications open fun before(program: Program): Iterable<IAstModification> = noModifications
@@ -116,7 +117,7 @@ abstract class AstWalker {
open fun before(subroutine: Subroutine, parent: Node): Iterable<IAstModification> = noModifications open fun before(subroutine: Subroutine, parent: Node): Iterable<IAstModification> = noModifications
open fun before(typecast: TypecastExpression, parent: Node): Iterable<IAstModification> = noModifications open fun before(typecast: TypecastExpression, parent: Node): Iterable<IAstModification> = noModifications
open fun before(whenChoice: WhenChoice, parent: Node): Iterable<IAstModification> = noModifications open fun before(whenChoice: WhenChoice, parent: Node): Iterable<IAstModification> = noModifications
open fun before(whenStatement: WhenStatement, parent: Node): Iterable<IAstModification> = noModifications open fun before(whenStmt: When, parent: Node): Iterable<IAstModification> = noModifications
open fun before(whileLoop: WhileLoop, parent: Node): Iterable<IAstModification> = noModifications open fun before(whileLoop: WhileLoop, parent: Node): Iterable<IAstModification> = noModifications
open fun after(addressOf: AddressOf, parent: Node): Iterable<IAstModification> = noModifications open fun after(addressOf: AddressOf, parent: Node): Iterable<IAstModification> = noModifications
@@ -125,26 +126,27 @@ abstract class AstWalker {
open fun after(assignTarget: AssignTarget, parent: Node): Iterable<IAstModification> = noModifications open fun after(assignTarget: AssignTarget, parent: Node): Iterable<IAstModification> = noModifications
open fun after(assignment: Assignment, parent: Node): Iterable<IAstModification> = noModifications open fun after(assignment: Assignment, parent: Node): Iterable<IAstModification> = noModifications
open fun after(block: Block, parent: Node): Iterable<IAstModification> = noModifications open fun after(block: Block, parent: Node): Iterable<IAstModification> = noModifications
open fun after(branchStatement: BranchStatement, parent: Node): Iterable<IAstModification> = noModifications open fun after(branch: Branch, parent: Node): Iterable<IAstModification> = noModifications
open fun after(breakStmt: Break, parent: Node): Iterable<IAstModification> = noModifications open fun after(breakStmt: Break, parent: Node): Iterable<IAstModification> = noModifications
open fun after(builtinFunctionStatementPlaceholder: BuiltinFunctionStatementPlaceholder, parent: Node): Iterable<IAstModification> = noModifications open fun after(builtinFunctionPlaceholder: BuiltinFunctionPlaceholder, parent: Node): Iterable<IAstModification> = noModifications
open fun after(decl: VarDecl, parent: Node): Iterable<IAstModification> = noModifications open fun after(decl: VarDecl, parent: Node): Iterable<IAstModification> = noModifications
open fun after(directive: Directive, parent: Node): Iterable<IAstModification> = noModifications open fun after(directive: Directive, parent: Node): Iterable<IAstModification> = noModifications
open fun after(expr: BinaryExpression, parent: Node): Iterable<IAstModification> = noModifications open fun after(expr: BinaryExpression, parent: Node): Iterable<IAstModification> = noModifications
open fun after(expr: PrefixExpression, parent: Node): Iterable<IAstModification> = noModifications open fun after(expr: PrefixExpression, parent: Node): Iterable<IAstModification> = noModifications
open fun after(forLoop: ForLoop, parent: Node): Iterable<IAstModification> = noModifications open fun after(forLoop: ForLoop, parent: Node): Iterable<IAstModification> = noModifications
open fun after(repeatLoop: RepeatLoop, parent: Node): Iterable<IAstModification> = noModifications open fun after(repeatLoop: RepeatLoop, parent: Node): Iterable<IAstModification> = noModifications
open fun after(functionCall: FunctionCall, parent: Node): Iterable<IAstModification> = noModifications open fun after(functionCallExpr: FunctionCallExpr, parent: Node): Iterable<IAstModification> = noModifications
open fun after(functionCallStatement: FunctionCallStatement, parent: Node): Iterable<IAstModification> = noModifications open fun after(functionCallStatement: FunctionCallStatement, parent: Node): Iterable<IAstModification> = noModifications
open fun after(identifier: IdentifierReference, parent: Node): Iterable<IAstModification> = noModifications open fun after(identifier: IdentifierReference, parent: Node): Iterable<IAstModification> = noModifications
open fun after(ifStatement: IfStatement, parent: Node): Iterable<IAstModification> = noModifications open fun after(ifElse: IfElse, parent: Node): Iterable<IAstModification> = noModifications
open fun after(inlineAssembly: InlineAssembly, parent: Node): Iterable<IAstModification> = noModifications open fun after(inlineAssembly: InlineAssembly, parent: Node): Iterable<IAstModification> = noModifications
open fun after(jump: Jump, parent: Node): Iterable<IAstModification> = noModifications open fun after(jump: Jump, parent: Node): Iterable<IAstModification> = noModifications
open fun after(gosub: GoSub, parent: Node): Iterable<IAstModification> = noModifications
open fun after(label: Label, parent: Node): Iterable<IAstModification> = noModifications open fun after(label: Label, parent: Node): Iterable<IAstModification> = noModifications
open fun after(memread: DirectMemoryRead, parent: Node): Iterable<IAstModification> = noModifications open fun after(memread: DirectMemoryRead, parent: Node): Iterable<IAstModification> = noModifications
open fun after(memwrite: DirectMemoryWrite, parent: Node): Iterable<IAstModification> = noModifications open fun after(memwrite: DirectMemoryWrite, parent: Node): Iterable<IAstModification> = noModifications
open fun after(module: Module, parent: Node): Iterable<IAstModification> = noModifications open fun after(module: Module, parent: Node): Iterable<IAstModification> = noModifications
open fun after(nopStatement: NopStatement, parent: Node): Iterable<IAstModification> = noModifications open fun after(nop: Nop, parent: Node): Iterable<IAstModification> = noModifications
open fun after(numLiteral: NumericLiteralValue, parent: Node): Iterable<IAstModification> = noModifications open fun after(numLiteral: NumericLiteralValue, parent: Node): Iterable<IAstModification> = noModifications
open fun after(postIncrDecr: PostIncrDecr, parent: Node): Iterable<IAstModification> = noModifications open fun after(postIncrDecr: PostIncrDecr, parent: Node): Iterable<IAstModification> = noModifications
open fun after(program: Program): Iterable<IAstModification> = noModifications open fun after(program: Program): Iterable<IAstModification> = noModifications
@@ -157,7 +159,7 @@ abstract class AstWalker {
open fun after(subroutine: Subroutine, parent: Node): Iterable<IAstModification> = noModifications open fun after(subroutine: Subroutine, parent: Node): Iterable<IAstModification> = noModifications
open fun after(typecast: TypecastExpression, parent: Node): Iterable<IAstModification> = noModifications open fun after(typecast: TypecastExpression, parent: Node): Iterable<IAstModification> = noModifications
open fun after(whenChoice: WhenChoice, parent: Node): Iterable<IAstModification> = noModifications open fun after(whenChoice: WhenChoice, parent: Node): Iterable<IAstModification> = noModifications
open fun after(whenStatement: WhenStatement, parent: Node): Iterable<IAstModification> = noModifications open fun after(whenStmt: When, parent: Node): Iterable<IAstModification> = noModifications
open fun after(whileLoop: WhileLoop, parent: Node): Iterable<IAstModification> = noModifications open fun after(whileLoop: WhileLoop, parent: Node): Iterable<IAstModification> = noModifications
protected val modifications = mutableListOf<Triple<IAstModification, Node, Node>>() protected val modifications = mutableListOf<Triple<IAstModification, Node, Node>>()
@@ -245,11 +247,11 @@ abstract class AstWalker {
track(after(subroutine, parent), subroutine, parent) track(after(subroutine, parent), subroutine, parent)
} }
fun visit(functionCall: FunctionCall, parent: Node) { fun visit(functionCallExpr: FunctionCallExpr, parent: Node) {
track(before(functionCall, parent), functionCall, parent) track(before(functionCallExpr, parent), functionCallExpr, parent)
functionCall.target.accept(this, functionCall) functionCallExpr.target.accept(this, functionCallExpr)
functionCall.args.forEach { it.accept(this, functionCall) } functionCallExpr.args.forEach { it.accept(this, functionCallExpr) }
track(after(functionCall, parent), functionCall, parent) track(after(functionCallExpr, parent), functionCallExpr, parent)
} }
fun visit(functionCallStatement: FunctionCallStatement, parent: Node) { fun visit(functionCallStatement: FunctionCallStatement, parent: Node) {
@@ -270,19 +272,25 @@ abstract class AstWalker {
track(after(jump, parent), jump, parent) track(after(jump, parent), jump, parent)
} }
fun visit(ifStatement: IfStatement, parent: Node) { fun visit(gosub: GoSub, parent: Node) {
track(before(ifStatement, parent), ifStatement, parent) track(before(gosub, parent), gosub, parent)
ifStatement.condition.accept(this, ifStatement) gosub.identifier?.accept(this, gosub)
ifStatement.truepart.accept(this, ifStatement) track(after(gosub, parent), gosub, parent)
ifStatement.elsepart.accept(this, ifStatement)
track(after(ifStatement, parent), ifStatement, parent)
} }
fun visit(branchStatement: BranchStatement, parent: Node) { fun visit(ifElse: IfElse, parent: Node) {
track(before(branchStatement, parent), branchStatement, parent) track(before(ifElse, parent), ifElse, parent)
branchStatement.truepart.accept(this, branchStatement) ifElse.condition.accept(this, ifElse)
branchStatement.elsepart.accept(this, branchStatement) ifElse.truepart.accept(this, ifElse)
track(after(branchStatement, parent), branchStatement, parent) ifElse.elsepart.accept(this, ifElse)
track(after(ifElse, parent), ifElse, parent)
}
fun visit(branch: Branch, parent: Node) {
track(before(branch, parent), branch, parent)
branch.truepart.accept(this, branch)
branch.elsepart.accept(this, branch)
track(after(branch, parent), branch, parent)
} }
fun visit(range: RangeExpr, parent: Node) { fun visit(range: RangeExpr, parent: Node) {
@@ -422,16 +430,16 @@ abstract class AstWalker {
track(after(inlineAssembly, parent), inlineAssembly, parent) track(after(inlineAssembly, parent), inlineAssembly, parent)
} }
fun visit(nopStatement: NopStatement, parent: Node) { fun visit(nop: Nop, parent: Node) {
track(before(nopStatement, parent), nopStatement, parent) track(before(nop, parent), nop, parent)
track(after(nopStatement, parent), nopStatement, parent) track(after(nop, parent), nop, parent)
} }
fun visit(whenStatement: WhenStatement, parent: Node) { fun visit(whenStmt: When, parent: Node) {
track(before(whenStatement, parent), whenStatement, parent) track(before(whenStmt, parent), whenStmt, parent)
whenStatement.condition.accept(this, whenStatement) whenStmt.condition.accept(this, whenStmt)
whenStatement.choices.forEach { it.accept(this, whenStatement) } whenStmt.choices.forEach { it.accept(this, whenStmt) }
track(after(whenStatement, parent), whenStatement, parent) track(after(whenStmt, parent), whenStmt, parent)
} }
fun visit(whenChoice: WhenChoice, parent: Node) { fun visit(whenChoice: WhenChoice, parent: Node) {

View File

@@ -39,9 +39,9 @@ interface IAstVisitor {
subroutine.statements.forEach { it.accept(this) } subroutine.statements.forEach { it.accept(this) }
} }
fun visit(functionCall: FunctionCall) { fun visit(functionCallExpr: FunctionCallExpr) {
functionCall.target.accept(this) functionCallExpr.target.accept(this)
functionCall.args.forEach { it.accept(this) } functionCallExpr.args.forEach { it.accept(this) }
} }
fun visit(functionCallStatement: FunctionCallStatement) { fun visit(functionCallStatement: FunctionCallStatement) {
@@ -56,15 +56,19 @@ interface IAstVisitor {
jump.identifier?.accept(this) jump.identifier?.accept(this)
} }
fun visit(ifStatement: IfStatement) { fun visit(gosub: GoSub) {
ifStatement.condition.accept(this) gosub.identifier?.accept(this)
ifStatement.truepart.accept(this)
ifStatement.elsepart.accept(this)
} }
fun visit(branchStatement: BranchStatement) { fun visit(ifElse: IfElse) {
branchStatement.truepart.accept(this) ifElse.condition.accept(this)
branchStatement.elsepart.accept(this) ifElse.truepart.accept(this)
ifElse.elsepart.accept(this)
}
fun visit(branch: Branch) {
branch.truepart.accept(this)
branch.elsepart.accept(this)
} }
fun visit(range: RangeExpr) { fun visit(range: RangeExpr) {
@@ -160,12 +164,12 @@ interface IAstVisitor {
fun visit(inlineAssembly: InlineAssembly) { fun visit(inlineAssembly: InlineAssembly) {
} }
fun visit(nopStatement: NopStatement) { fun visit(nop: Nop) {
} }
fun visit(whenStatement: WhenStatement) { fun visit(`when`: When) {
whenStatement.condition.accept(this) `when`.condition.accept(this)
whenStatement.choices.forEach { it.accept(this) } `when`.choices.forEach { it.accept(this) }
} }
fun visit(whenChoice: WhenChoice) { fun visit(whenChoice: WhenChoice) {

View File

@@ -6,7 +6,7 @@ import prog8.ast.Program
import prog8.ast.base.Position import prog8.ast.base.Position
import prog8.ast.base.VarDeclType import prog8.ast.base.VarDeclType
import prog8.ast.expressions.AddressOf import prog8.ast.expressions.AddressOf
import prog8.ast.expressions.FunctionCall import prog8.ast.expressions.FunctionCallExpr
import prog8.ast.expressions.IdentifierReference import prog8.ast.expressions.IdentifierReference
import prog8.ast.statements.* import prog8.ast.statements.*
import prog8.ast.walk.IAstVisitor import prog8.ast.walk.IAstVisitor
@@ -59,15 +59,15 @@ class CallGraph(private val program: Program) : IAstVisitor {
super.visit(directive) super.visit(directive)
} }
override fun visit(functionCall: FunctionCall) { override fun visit(functionCallExpr: FunctionCallExpr) {
val otherSub = functionCall.target.targetSubroutine(program) val otherSub = functionCallExpr.target.targetSubroutine(program)
if (otherSub != null) { if (otherSub != null) {
functionCall.definingSubroutine?.let { thisSub -> functionCallExpr.definingSubroutine?.let { thisSub ->
calls[thisSub] = calls.getValue(thisSub) + otherSub calls[thisSub] = calls.getValue(thisSub) + otherSub
calledBy[otherSub] = calledBy.getValue(otherSub) + functionCall calledBy[otherSub] = calledBy.getValue(otherSub) + functionCallExpr
} }
} }
super.visit(functionCall) super.visit(functionCallExpr)
} }
override fun visit(functionCallStatement: FunctionCallStatement) { override fun visit(functionCallStatement: FunctionCallStatement) {
@@ -103,6 +103,17 @@ class CallGraph(private val program: Program) : IAstVisitor {
super.visit(jump) super.visit(jump)
} }
override fun visit(gosub: GoSub) {
val otherSub = gosub.identifier?.targetSubroutine(program)
if (otherSub != null) {
gosub.definingSubroutine?.let { thisSub ->
calls[thisSub] = calls.getValue(thisSub) + otherSub
calledBy[otherSub] = calledBy.getValue(otherSub) + gosub
}
}
super.visit(gosub)
}
override fun visit(identifier: IdentifierReference) { override fun visit(identifier: IdentifierReference) {
allIdentifiersAndTargets[Pair(identifier, identifier.position)] = identifier.targetStatement(program)!! allIdentifiersAndTargets[Pair(identifier, identifier.position)] = identifier.targetStatement(program)!!
} }

View File

@@ -5,7 +5,6 @@ For next compiler release (7.6)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
... ...
Blocked by an official Commander-x16 v39 release Blocked by an official Commander-x16 v39 release
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
- simplify cx16.joystick_get2() once this cx16 rom issue is resolved: https://github.com/commanderx16/x16-rom/issues/203 - simplify cx16.joystick_get2() once this cx16 rom issue is resolved: https://github.com/commanderx16/x16-rom/issues/203
@@ -37,10 +36,6 @@ Future
More code optimization ideas More code optimization ideas
^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
- remove special code generation for while and util expression
by rewriting while and until expressions into if+jump (just consider them syntactic sugar)
but the result should not produce larger code ofcourse!
- while-expression should now also get the simplifyConditionalExpression() treatment
- byte typed expressions should be evaluated in the accumulator where possible, without (temp)var - byte typed expressions should be evaluated in the accumulator where possible, without (temp)var
for instance value = otherbyte >> 1 --> lda otherbite ; lsr a; sta value for instance value = otherbyte >> 1 --> lda otherbite ; lsr a; sta value
- rewrite multiple choice if into when: - rewrite multiple choice if into when: