diff --git a/compiler/src/prog8/ast/AstToplevel.kt b/compiler/src/prog8/ast/AstToplevel.kt index 9599b0220..6f70598fb 100644 --- a/compiler/src/prog8/ast/AstToplevel.kt +++ b/compiler/src/prog8/ast/AstToplevel.kt @@ -5,7 +5,7 @@ import prog8.ast.expressions.Expression import prog8.ast.expressions.IdentifierReference import prog8.ast.processing.IAstModifyingVisitor import prog8.ast.processing.IAstVisitor -import prog8.ast.processing.IGenericAstModifyingVisitor +import prog8.ast.processing.AstWalker import prog8.ast.statements.* import prog8.functions.BuiltinFunctions import java.nio.file.Path @@ -252,7 +252,7 @@ class Module(override val name: String, fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) fun accept(visitor: IAstVisitor) = visitor.visit(this) - fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) = visitor.visit(this, parent) + fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) } diff --git a/compiler/src/prog8/ast/base/Extensions.kt b/compiler/src/prog8/ast/base/Extensions.kt index 7eb32ddf1..b71970578 100644 --- a/compiler/src/prog8/ast/base/Extensions.kt +++ b/compiler/src/prog8/ast/base/Extensions.kt @@ -8,13 +8,7 @@ import prog8.optimizer.FlattenAnonymousScopesAndRemoveNops // the name of the subroutine that should be called for every block to initialize its variables -internal const val initvarsSubName="prog8_init_vars" - - -internal fun Program.removeNopsFlattenAnonScopes() { - val flattener = FlattenAnonymousScopesAndRemoveNops() - flattener.visit(this) -} +internal const val initvarsSubName = "prog8_init_vars" internal fun Program.checkValid(compilerOptions: CompilationOptions, errors: ErrorReporter) { @@ -22,13 +16,11 @@ internal fun Program.checkValid(compilerOptions: CompilationOptions, errors: Err checker.visit(this) } - internal fun Program.anonscopeVarsCleanup(errors: ErrorReporter) { - val mover = AnonymousScopeVarsCleanup(errors) + val mover = MoveAnonScopeVarsToSubroutine(errors) mover.visit(this) } - internal fun Program.reorderStatements() { val initvalueCreator = VarInitValueAndAddressOfCreator(this) initvalueCreator.visit(this) @@ -42,9 +34,10 @@ internal fun Program.addTypecasts(errors: ErrorReporter) { caster.visit(this) } -internal fun Module.checkImportedValid(errors: ErrorReporter) { - val checker = ImportedModuleDirectiveRemover(errors) - checker.visit(this) +internal fun Module.checkImportedValid() { + val imr = ImportedModuleDirectiveRemover() + imr.visit(this, this.parent) + imr.applyModifications() } internal fun Program.checkRecursion(errors: ErrorReporter) { @@ -53,18 +46,22 @@ internal fun Program.checkRecursion(errors: ErrorReporter) { checker.processMessages(name) } - internal fun Program.checkIdentifiers(errors: ErrorReporter) { val checker = AstIdentifiersChecker(this, errors) checker.visit(this) - if(modules.map {it.name}.toSet().size != modules.size) { + if (modules.map { it.name }.toSet().size != modules.size) { throw FatalAstException("modules should all be unique") } } - internal fun Program.makeForeverLoops() { val checker = MakeForeverLoops() checker.visit(this) + checker.applyModifications() +} + +internal fun Program.removeNopsFlattenAnonScopes() { + val flattener = FlattenAnonymousScopesAndRemoveNops() + flattener.visit(this) } diff --git a/compiler/src/prog8/ast/expressions/AstExpressions.kt b/compiler/src/prog8/ast/expressions/AstExpressions.kt index 22a873c50..96d0b0129 100644 --- a/compiler/src/prog8/ast/expressions/AstExpressions.kt +++ b/compiler/src/prog8/ast/expressions/AstExpressions.kt @@ -5,7 +5,7 @@ import prog8.ast.antlr.escape import prog8.ast.base.* import prog8.ast.processing.IAstModifyingVisitor import prog8.ast.processing.IAstVisitor -import prog8.ast.processing.IGenericAstModifyingVisitor +import prog8.ast.processing.AstWalker import prog8.ast.statements.* import prog8.compiler.target.CompilationTarget import prog8.functions.BuiltinFunctions @@ -22,7 +22,7 @@ sealed class Expression: Node { abstract fun constValue(program: Program): NumericLiteralValue? abstract fun accept(visitor: IAstModifyingVisitor): Expression abstract fun accept(visitor: IAstVisitor) - abstract fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) + abstract fun accept(visitor: AstWalker, parent: Node) abstract fun referencesIdentifiers(vararg name: String): Boolean // todo: remove this and add identifier usage tracking into CallGraph instead abstract fun inferType(program: Program): InferredTypes.InferredType @@ -61,7 +61,7 @@ class PrefixExpression(val operator: String, var expression: Expression, overrid override fun constValue(program: Program): NumericLiteralValue? = null override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node)= visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) override fun referencesIdentifiers(vararg name: String) = expression.referencesIdentifiers(*name) override fun inferType(program: Program): InferredTypes.InferredType { @@ -109,7 +109,7 @@ class BinaryExpression(var left: Expression, var operator: String, var right: Ex override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node)= visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) override fun referencesIdentifiers(vararg name: String) = left.referencesIdentifiers(*name) || right.referencesIdentifiers(*name) override fun inferType(program: Program): InferredTypes.InferredType { @@ -214,7 +214,7 @@ class ArrayIndexedExpression(var identifier: IdentifierReference, override fun constValue(program: Program): NumericLiteralValue? = null override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node)= visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) override fun referencesIdentifiers(vararg name: String) = identifier.referencesIdentifiers(*name) @@ -245,7 +245,7 @@ class TypecastExpression(var expression: Expression, var type: DataType, val imp override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node)= visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) override fun referencesIdentifiers(vararg name: String) = expression.referencesIdentifiers(*name) override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(type) @@ -274,7 +274,7 @@ data class AddressOf(var identifier: IdentifierReference, override val position: override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(DataType.UWORD) override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node)= visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) } class DirectMemoryRead(var addressExpression: Expression, override val position: Position) : Expression(), IAssignable { @@ -287,7 +287,7 @@ class DirectMemoryRead(var addressExpression: Expression, override val position: override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node)= visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) override fun referencesIdentifiers(vararg name: String) = false override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(DataType.UBYTE) @@ -343,7 +343,7 @@ class NumericLiteralValue(val type: DataType, // only numerical types allowed override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node)= visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) override fun toString(): String = "NumericLiteral(${type.name}:$number)" @@ -430,7 +430,7 @@ class StructLiteralValue(var values: List, override fun constValue(program: Program): NumericLiteralValue? = null override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node)= visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) override fun referencesIdentifiers(vararg name: String) = values.any { it.referencesIdentifiers(*name) } override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(DataType.STRUCT) @@ -456,7 +456,7 @@ class StringLiteralValue(val value: String, override fun constValue(program: Program): NumericLiteralValue? = null override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node)= visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) override fun toString(): String = "'${escape(value)}'" override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(DataType.STR) @@ -484,7 +484,7 @@ class ArrayLiteralValue(val type: InferredTypes.InferredType, // inferred be override fun constValue(program: Program): NumericLiteralValue? = null override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node)= visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) override fun toString(): String = "$value" override fun inferType(program: Program): InferredTypes.InferredType = if(type.isUnknown) type else guessDatatype(program) @@ -567,7 +567,7 @@ class RangeExpr(var from: Expression, override fun constValue(program: Program): NumericLiteralValue? = null override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node)= visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) override fun referencesIdentifiers(vararg name: String): Boolean = from.referencesIdentifiers(*name) || to.referencesIdentifiers(*name) override fun inferType(program: Program): InferredTypes.InferredType { @@ -643,7 +643,7 @@ class RegisterExpr(val register: Register, override val position: Position) : Ex override fun constValue(program: Program): NumericLiteralValue? = null override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node)= visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) override fun referencesIdentifiers(vararg name: String): Boolean = register.name in name override fun toString(): String { @@ -687,7 +687,7 @@ data class IdentifierReference(val nameInSource: List, override val posi override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node)= visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) override fun referencesIdentifiers(vararg name: String): Boolean = nameInSource.last() in name @@ -763,7 +763,7 @@ class FunctionCall(override var target: IdentifierReference, override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node)= visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) override fun referencesIdentifiers(vararg name: String): Boolean = target.referencesIdentifiers(*name) || args.any{it.referencesIdentifiers(*name)} diff --git a/compiler/src/prog8/ast/processing/AstWalker.kt b/compiler/src/prog8/ast/processing/AstWalker.kt new file mode 100644 index 000000000..8017f6481 --- /dev/null +++ b/compiler/src/prog8/ast/processing/AstWalker.kt @@ -0,0 +1,407 @@ +package prog8.ast.processing + +import prog8.ast.INameScope +import prog8.ast.Module +import prog8.ast.Node +import prog8.ast.Program +import prog8.ast.base.FatalAstException +import prog8.ast.expressions.* +import prog8.ast.statements.* + + +abstract class AstModification(val node: Node) { + abstract fun perform() + + class Remove(node: Node, val parent: Node) : AstModification(node) { + override fun perform() { + if(parent is INameScope) { + if (!parent.statements.remove(node)) + throw FatalAstException("attempt to remove non-existing node $node") + } else { + throw FatalAstException("parent of a remove modification is not an INameScope") + } + } + } + + class Replace(statement: Statement, val replacement: Statement, val parent: Node) : AstModification(statement) { + override fun perform() { + if(parent is INameScope) { + val idx = parent.statements.indexOf(node) + parent.statements[idx] = replacement + replacement.linkParents(parent) + } else { + throw FatalAstException("parent of a replace modification is not an INameScope") + } + } + } +} + + +abstract class AstWalker { + open fun before(addressOf: AddressOf, parent: Node): Iterable = emptyList() + open fun before(array: ArrayLiteralValue, parent: Node): Iterable = emptyList() + open fun before(arrayIndexedExpression: ArrayIndexedExpression, parent: Node): Iterable = emptyList() + open fun before(assignTarget: AssignTarget, parent: Node): Iterable = emptyList() + open fun before(assignment: Assignment, parent: Node): Iterable = emptyList() + open fun before(block: Block, parent: Node): Iterable = emptyList() + open fun before(branchStatement: BranchStatement, parent: Node): Iterable = emptyList() + open fun before(breakStmt: Break, parent: Node): Iterable = emptyList() + open fun before(builtinFunctionStatementPlaceholder: BuiltinFunctionStatementPlaceholder, parent: Node): Iterable = emptyList() + open fun before(contStmt: Continue, parent: Node): Iterable = emptyList() + open fun before(decl: VarDecl, parent: Node): Iterable = emptyList() + open fun before(directive: Directive, parent: Node): Iterable = emptyList() + open fun before(expr: BinaryExpression, parent: Node): Iterable = emptyList() + open fun before(expr: PrefixExpression, parent: Node): Iterable = emptyList() + open fun before(forLoop: ForLoop, parent: Node): Iterable = emptyList() + open fun before(foreverLoop: ForeverLoop, parent: Node): Iterable = emptyList() + open fun before(functionCall: FunctionCall, parent: Node): Iterable = emptyList() + open fun before(functionCallStatement: FunctionCallStatement, parent: Node): Iterable = emptyList() + open fun before(identifier: IdentifierReference, parent: Node): Iterable = emptyList() + open fun before(ifStatement: IfStatement, parent: Node): Iterable = emptyList() + open fun before(inlineAssembly: InlineAssembly, parent: Node): Iterable = emptyList() + open fun before(jump: Jump, parent: Node): Iterable = emptyList() + open fun before(label: Label, parent: Node): Iterable = emptyList() + open fun before(memread: DirectMemoryRead, parent: Node): Iterable = emptyList() + open fun before(memwrite: DirectMemoryWrite, parent: Node): Iterable = emptyList() + open fun before(module: Module, parent: Node): Iterable = emptyList() + open fun before(nopStatement: NopStatement, parent: Node): Iterable = emptyList() + open fun before(numLiteral: NumericLiteralValue, parent: Node): Iterable = emptyList() + open fun before(postIncrDecr: PostIncrDecr, parent: Node): Iterable = emptyList() + open fun before(program: Program, parent: Node): Iterable = emptyList() + open fun before(range: RangeExpr, parent: Node): Iterable = emptyList() + open fun before(registerExpr: RegisterExpr, parent: Node): Iterable = emptyList() + open fun before(repeatLoop: RepeatLoop, parent: Node): Iterable = emptyList() + open fun before(returnStmt: Return, parent: Node): Iterable = emptyList() + open fun before(scope: AnonymousScope, parent: Node): Iterable = emptyList() + open fun before(string: StringLiteralValue, parent: Node): Iterable = emptyList() + open fun before(structDecl: StructDecl, parent: Node): Iterable = emptyList() + open fun before(structLv: StructLiteralValue, parent: Node): Iterable = emptyList() + open fun before(subroutine: Subroutine, parent: Node): Iterable = emptyList() + open fun before(typecast: TypecastExpression, parent: Node): Iterable = emptyList() + open fun before(whenChoice: WhenChoice, parent: Node): Iterable = emptyList() + open fun before(whenStatement: WhenStatement, parent: Node): Iterable = emptyList() + open fun before(whileLoop: WhileLoop, parent: Node): Iterable = emptyList() + + open fun after(addressOf: AddressOf, parent: Node): Iterable = emptyList() + open fun after(array: ArrayLiteralValue, parent: Node): Iterable = emptyList() + open fun after(arrayIndexedExpression: ArrayIndexedExpression, parent: Node): Iterable = emptyList() + open fun after(assignTarget: AssignTarget, parent: Node): Iterable = emptyList() + open fun after(assignment: Assignment, parent: Node): Iterable = emptyList() + open fun after(block: Block, parent: Node): Iterable = emptyList() + open fun after(branchStatement: BranchStatement, parent: Node): Iterable = emptyList() + open fun after(breakStmt: Break, parent: Node): Iterable = emptyList() + open fun after(builtinFunctionStatementPlaceholder: BuiltinFunctionStatementPlaceholder, parent: Node): Iterable = emptyList() + open fun after(contStmt: Continue, parent: Node): Iterable = emptyList() + open fun after(decl: VarDecl, parent: Node): Iterable = emptyList() + open fun after(directive: Directive, parent: Node): Iterable = emptyList() + open fun after(expr: BinaryExpression, parent: Node): Iterable = emptyList() + open fun after(expr: PrefixExpression, parent: Node): Iterable = emptyList() + open fun after(forLoop: ForLoop, parent: Node): Iterable = emptyList() + open fun after(foreverLoop: ForeverLoop, parent: Node): Iterable = emptyList() + open fun after(functionCall: FunctionCall, parent: Node): Iterable = emptyList() + open fun after(functionCallStatement: FunctionCallStatement, parent: Node): Iterable = emptyList() + open fun after(identifier: IdentifierReference, parent: Node): Iterable = emptyList() + open fun after(ifStatement: IfStatement, parent: Node): Iterable = emptyList() + open fun after(inlineAssembly: InlineAssembly, parent: Node): Iterable = emptyList() + open fun after(jump: Jump, parent: Node): Iterable = emptyList() + open fun after(label: Label, parent: Node): Iterable = emptyList() + open fun after(memread: DirectMemoryRead, parent: Node): Iterable = emptyList() + open fun after(memwrite: DirectMemoryWrite, parent: Node): Iterable = emptyList() + open fun after(module: Module, parent: Node): Iterable = emptyList() + open fun after(nopStatement: NopStatement, parent: Node): Iterable = emptyList() + open fun after(numLiteral: NumericLiteralValue, parent: Node): Iterable = emptyList() + open fun after(postIncrDecr: PostIncrDecr, parent: Node): Iterable = emptyList() + open fun after(program: Program, parent: Node): Iterable = emptyList() + open fun after(range: RangeExpr, parent: Node): Iterable = emptyList() + open fun after(registerExpr: RegisterExpr, parent: Node): Iterable = emptyList() + open fun after(repeatLoop: RepeatLoop, parent: Node): Iterable = emptyList() + open fun after(returnStmt: Return, parent: Node): Iterable = emptyList() + open fun after(scope: AnonymousScope, parent: Node): Iterable = emptyList() + open fun after(string: StringLiteralValue, parent: Node): Iterable = emptyList() + open fun after(structDecl: StructDecl, parent: Node): Iterable = emptyList() + open fun after(structLv: StructLiteralValue, parent: Node): Iterable = emptyList() + open fun after(subroutine: Subroutine, parent: Node): Iterable = emptyList() + open fun after(typecast: TypecastExpression, parent: Node): Iterable = emptyList() + open fun after(whenChoice: WhenChoice, parent: Node): Iterable = emptyList() + open fun after(whenStatement: WhenStatement, parent: Node): Iterable = emptyList() + open fun after(whileLoop: WhileLoop, parent: Node): Iterable = emptyList() + + private val modifications = mutableListOf>() + + private fun track(mods: Iterable, node: Node, parent: Node) { + for (it in mods) modifications += Triple(it, node, parent) + } + + fun applyModifications() { + modifications.forEach { + it.first.perform() + } + } + + fun visit(program: Program) { + track(before(program, program), program, program) + program.modules.forEach { it.accept(this, program) } + track(after(program, program), program, program) + } + + fun visit(module: Module, parent: Node) { + track(before(module, parent), module, parent) + module.statements.forEach{ it.accept(this, module) } + track(after(module, parent), module, parent) + } + + fun visit(expr: PrefixExpression, parent: Node) { + track(before(expr, parent), expr, parent) + expr.expression.accept(this, expr) + track(after(expr, parent), expr, parent) + } + + fun visit(expr: BinaryExpression, parent: Node) { + track(before(expr, parent), expr, parent) + expr.left.accept(this, expr) + expr.right.accept(this, expr) + track(after(expr, parent), expr, parent) + } + + fun visit(directive: Directive, parent: Node) { + track(before(directive, parent), directive, parent) + track(after(directive, parent), directive, parent) + } + + fun visit(block: Block, parent: Node) { + track(before(block, parent), block, parent) + block.statements.forEach { it.accept(this, block) } + track(after(block, parent), block, parent) + } + + fun visit(decl: VarDecl, parent: Node) { + track(before(decl, parent), decl, parent) + decl.value?.accept(this, decl) + decl.arraysize?.accept(this, decl) + track(after(decl, parent), decl, parent) + } + + fun visit(subroutine: Subroutine, parent: Node) { + track(before(subroutine, parent), subroutine, parent) + subroutine.statements.forEach { it.accept(this, subroutine) } + track(after(subroutine, parent), subroutine, parent) + } + + fun visit(functionCall: FunctionCall, parent: Node) { + track(before(functionCall, parent), functionCall, parent) + functionCall.target.accept(this, functionCall) + functionCall.args.forEach { it.accept(this, functionCall) } + track(after(functionCall, parent), functionCall, parent) + } + + fun visit(functionCallStatement: FunctionCallStatement, parent: Node) { + track(before(functionCallStatement, parent), functionCallStatement, parent) + functionCallStatement.target.accept(this, functionCallStatement) + functionCallStatement.args.forEach { it.accept(this, functionCallStatement) } + track(after(functionCallStatement, parent), functionCallStatement, parent) + } + + fun visit(identifier: IdentifierReference, parent: Node) { + track(before(identifier, parent), identifier, parent) + track(after(identifier, parent), identifier, parent) + } + + fun visit(jump: Jump, parent: Node) { + track(before(jump, parent), jump, parent) + jump.identifier?.accept(this, jump) + track(after(jump, parent), jump, parent) + } + + fun visit(ifStatement: IfStatement, parent: Node) { + track(before(ifStatement, parent), ifStatement, parent) + ifStatement.condition.accept(this, ifStatement) + ifStatement.truepart.accept(this, ifStatement) + ifStatement.elsepart.accept(this, ifStatement) + track(after(ifStatement, parent), ifStatement, parent) + } + + fun visit(branchStatement: BranchStatement, parent: Node) { + track(before(branchStatement, parent), branchStatement, parent) + branchStatement.truepart.accept(this, branchStatement) + branchStatement.elsepart.accept(this, branchStatement) + track(after(branchStatement, parent), branchStatement, parent) + } + + fun visit(range: RangeExpr, parent: Node) { + track(before(range, parent), range, parent) + range.from.accept(this, range) + range.to.accept(this, range) + range.step.accept(this, range) + track(after(range, parent), range, parent) + } + + fun visit(label: Label, parent: Node) { + track(before(label, parent), label, parent) + track(after(label, parent), label, parent) + } + + fun visit(numLiteral: NumericLiteralValue, parent: Node) { + track(before(numLiteral, parent), numLiteral, parent) + track(after(numLiteral, parent), numLiteral, parent) + } + + fun visit(string: StringLiteralValue, parent: Node) { + track(before(string, parent), string, parent) + track(after(string, parent), string, parent) + } + + fun visit(array: ArrayLiteralValue, parent: Node) { + track(before(array, parent), array, parent) + array.value.forEach { v->v.accept(this, array) } + track(after(array, parent), array, parent) + } + + fun visit(assignment: Assignment, parent: Node) { + track(before(assignment, parent), assignment, parent) + assignment.target.accept(this, assignment) + assignment.value.accept(this, assignment) + track(after(assignment, parent), assignment, parent) + } + + fun visit(postIncrDecr: PostIncrDecr, parent: Node) { + track(before(postIncrDecr, parent), postIncrDecr, parent) + postIncrDecr.target.accept(this, postIncrDecr) + track(after(postIncrDecr, parent), postIncrDecr, parent) + } + + fun visit(contStmt: Continue, parent: Node) { + track(before(contStmt, parent), contStmt, parent) + track(after(contStmt, parent), contStmt, parent) + } + + fun visit(breakStmt: Break, parent: Node) { + track(before(breakStmt, parent), breakStmt, parent) + track(after(breakStmt, parent), breakStmt, parent) + } + + fun visit(forLoop: ForLoop, parent: Node) { + track(before(forLoop, parent), forLoop, parent) + forLoop.loopVar?.accept(this, forLoop) + forLoop.iterable.accept(this, forLoop) + forLoop.body.accept(this, forLoop) + track(after(forLoop, parent), forLoop, parent) + } + + fun visit(whileLoop: WhileLoop, parent: Node) { + track(before(whileLoop, parent), whileLoop, parent) + whileLoop.condition.accept(this, whileLoop) + whileLoop.body.accept(this, whileLoop) + track(after(whileLoop, parent), whileLoop, parent) + } + + fun visit(foreverLoop: ForeverLoop, parent: Node) { + track(before(foreverLoop, parent), foreverLoop, parent) + foreverLoop.body.accept(this, foreverLoop) + track(after(foreverLoop, parent), foreverLoop, parent) + } + + fun visit(repeatLoop: RepeatLoop, parent: Node) { + track(before(repeatLoop, parent), repeatLoop, parent) + repeatLoop.untilCondition.accept(this, repeatLoop) + repeatLoop.body.accept(this, repeatLoop) + track(after(repeatLoop, parent), repeatLoop, parent) + } + + fun visit(returnStmt: Return, parent: Node) { + track(before(returnStmt, parent), returnStmt, parent) + returnStmt.value?.accept(this, returnStmt) + track(after(returnStmt, parent), returnStmt, parent) + } + + fun visit(arrayIndexedExpression: ArrayIndexedExpression, parent: Node) { + track(before(arrayIndexedExpression, parent), arrayIndexedExpression, parent) + arrayIndexedExpression.identifier.accept(this, arrayIndexedExpression) + arrayIndexedExpression.arrayspec.accept(this, arrayIndexedExpression) + track(after(arrayIndexedExpression, parent), arrayIndexedExpression, parent) + } + + fun visit(assignTarget: AssignTarget, parent: Node) { + track(before(assignTarget, parent), assignTarget, parent) + assignTarget.arrayindexed?.accept(this, assignTarget) + assignTarget.identifier?.accept(this, assignTarget) + assignTarget.memoryAddress?.accept(this, assignTarget) + track(after(assignTarget, parent), assignTarget, parent) + } + + fun visit(scope: AnonymousScope, parent: Node) { + track(before(scope, parent), scope, parent) + scope.statements.forEach { it.accept(this, scope) } + track(after(scope, parent), scope, parent) + } + + fun visit(typecast: TypecastExpression, parent: Node) { + track(before(typecast, parent), typecast, parent) + typecast.expression.accept(this, typecast) + track(after(typecast, parent), typecast, parent) + } + + fun visit(memread: DirectMemoryRead, parent: Node) { + track(before(memread, parent), memread, parent) + memread.addressExpression.accept(this, memread) + track(after(memread, parent), memread, parent) + } + + fun visit(memwrite: DirectMemoryWrite, parent: Node) { + track(before(memwrite, parent), memwrite, parent) + memwrite.addressExpression.accept(this, memwrite) + track(after(memwrite, parent), memwrite, parent) + } + + fun visit(addressOf: AddressOf, parent: Node) { + track(before(addressOf, parent), addressOf, parent) + addressOf.identifier.accept(this, addressOf) + track(after(addressOf, parent), addressOf, parent) + } + + fun visit(inlineAssembly: InlineAssembly, parent: Node) { + track(before(inlineAssembly, parent), inlineAssembly, parent) + track(after(inlineAssembly, parent), inlineAssembly, parent) + } + + fun visit(registerExpr: RegisterExpr, parent: Node) { + track(before(registerExpr, parent), registerExpr, parent) + track(after(registerExpr, parent), registerExpr, parent) + } + + fun visit(builtinFunctionStatementPlaceholder: BuiltinFunctionStatementPlaceholder, parent: Node) { + track(before(builtinFunctionStatementPlaceholder, parent), builtinFunctionStatementPlaceholder, parent) + track(after(builtinFunctionStatementPlaceholder, parent), builtinFunctionStatementPlaceholder, parent) + } + + fun visit(nopStatement: NopStatement, parent: Node) { + track(before(nopStatement, parent), nopStatement, parent) + track(after(nopStatement, parent), nopStatement, parent) + } + + fun visit(whenStatement: WhenStatement, parent: Node) { + track(before(whenStatement, parent), whenStatement, parent) + whenStatement.condition.accept(this, whenStatement) + whenStatement.choices.forEach { it.accept(this, whenStatement) } + track(after(whenStatement, parent), whenStatement, parent) + } + + fun visit(whenChoice: WhenChoice, parent: Node) { + track(before(whenChoice, parent), whenChoice, parent) + whenChoice.values?.forEach { it.accept(this, whenChoice) } + whenChoice.statements.accept(this, whenChoice) + track(after(whenChoice, parent), whenChoice, parent) + } + + fun visit(structDecl: StructDecl, parent: Node) { + track(before(structDecl, parent), structDecl, parent) + structDecl.statements.forEach { it.accept(this, structDecl) } + track(after(structDecl, parent), structDecl, parent) + } + + fun visit(structLv: StructLiteralValue, parent: Node) { + track(before(structLv, parent), structLv, parent) + structLv.values.forEach { it.accept(this, structLv) } + track(after(structLv, parent), structLv, parent) + } +} + diff --git a/compiler/src/prog8/ast/processing/IGenericAstModifyingVisitor.kt b/compiler/src/prog8/ast/processing/IGenericAstModifyingVisitor.kt deleted file mode 100644 index 0ee48c53b..000000000 --- a/compiler/src/prog8/ast/processing/IGenericAstModifyingVisitor.kt +++ /dev/null @@ -1,373 +0,0 @@ -package prog8.ast.processing - -import prog8.ast.Module -import prog8.ast.Node -import prog8.ast.Program -import prog8.ast.expressions.* -import prog8.ast.statements.* - - -typealias AstModification = (node: Node, parent: Node) -> Unit - - -interface IGenericAstModifyingVisitor { - - fun before(addressOf: AddressOf, parent: Node): List = emptyList() - fun before(array: ArrayLiteralValue, parent: Node): List = emptyList() - fun before(arrayIndexedExpression: ArrayIndexedExpression, parent: Node): List = emptyList() - fun before(assignTarget: AssignTarget, parent: Node): List = emptyList() - fun before(assignment: Assignment, parent: Node): List = emptyList() - fun before(block: Block, parent: Node): List = emptyList() - fun before(branchStatement: BranchStatement, parent: Node): List = emptyList() - fun before(breakStmt: Break, parent: Node): List = emptyList() - fun before(builtinFunctionStatementPlaceholder: BuiltinFunctionStatementPlaceholder, parent: Node): List = emptyList() - fun before(contStmt: Continue, parent: Node): List = emptyList() - fun before(decl: VarDecl, parent: Node): List = emptyList() - fun before(directive: Directive, parent: Node): List = emptyList() - fun before(expr: BinaryExpression, parent: Node): List = emptyList() - fun before(expr: PrefixExpression, parent: Node): List = emptyList() - fun before(forLoop: ForLoop, parent: Node): List = emptyList() - fun before(foreverLoop: ForeverLoop, parent: Node): List = emptyList() - fun before(functionCall: FunctionCall, parent: Node): List = emptyList() - fun before(functionCallStatement: FunctionCallStatement, parent: Node): List = emptyList() - fun before(identifier: IdentifierReference, parent: Node): List = emptyList() - fun before(ifStatement: IfStatement, parent: Node): List = emptyList() - fun before(inlineAssembly: InlineAssembly, parent: Node): List = emptyList() - fun before(jump: Jump, parent: Node): List = emptyList() - fun before(label: Label, parent: Node): List = emptyList() - fun before(memread: DirectMemoryRead, parent: Node): List = emptyList() - fun before(memwrite: DirectMemoryWrite, parent: Node): List = emptyList() - fun before(module: Module, parent: Node): List = emptyList() - fun before(nopStatement: NopStatement, parent: Node): List = emptyList() - fun before(numLiteral: NumericLiteralValue, parent: Node): List = emptyList() - fun before(postIncrDecr: PostIncrDecr, parent: Node): List = emptyList() - fun before(program: Program, parent: Node): List = emptyList() - fun before(range: RangeExpr, parent: Node): List = emptyList() - fun before(registerExpr: RegisterExpr, parent: Node): List = emptyList() - fun before(repeatLoop: RepeatLoop, parent: Node): List = emptyList() - fun before(returnStmt: Return, parent: Node): List = emptyList() - fun before(scope: AnonymousScope, parent: Node): List = emptyList() - fun before(string: StringLiteralValue, parent: Node): List = emptyList() - fun before(structDecl: StructDecl, parent: Node): List = emptyList() - fun before(structLv: StructLiteralValue, parent: Node): List = emptyList() - fun before(subroutine: Subroutine, parent: Node): List = emptyList() - fun before(typecast: TypecastExpression, parent: Node): List = emptyList() - fun before(whenChoice: WhenChoice, parent: Node): List = emptyList() - fun before(whenStatement: WhenStatement, parent: Node): List = emptyList() - fun before(whileLoop: WhileLoop, parent: Node): List = emptyList() - - fun after(addressOf: AddressOf, parent: Node): List = emptyList() - fun after(array: ArrayLiteralValue, parent: Node): List = emptyList() - fun after(arrayIndexedExpression: ArrayIndexedExpression, parent: Node): List = emptyList() - fun after(assignTarget: AssignTarget, parent: Node): List = emptyList() - fun after(assignment: Assignment, parent: Node): List = emptyList() - fun after(block: Block, parent: Node): List = emptyList() - fun after(branchStatement: BranchStatement, parent: Node): List = emptyList() - fun after(breakStmt: Break, parent: Node): List = emptyList() - fun after(builtinFunctionStatementPlaceholder: BuiltinFunctionStatementPlaceholder, parent: Node): List = emptyList() - fun after(contStmt: Continue, parent: Node): List = emptyList() - fun after(decl: VarDecl, parent: Node): List = emptyList() - fun after(directive: Directive, parent: Node): List = emptyList() - fun after(expr: BinaryExpression, parent: Node): List = emptyList() - fun after(expr: PrefixExpression, parent: Node): List = emptyList() - fun after(forLoop: ForLoop, parent: Node): List = emptyList() - fun after(foreverLoop: ForeverLoop, parent: Node): List = emptyList() - fun after(functionCall: FunctionCall, parent: Node): List = emptyList() - fun after(functionCallStatement: FunctionCallStatement, parent: Node): List = emptyList() - fun after(identifier: IdentifierReference, parent: Node): List = emptyList() - fun after(ifStatement: IfStatement, parent: Node): List = emptyList() - fun after(inlineAssembly: InlineAssembly, parent: Node): List = emptyList() - fun after(jump: Jump, parent: Node): List = emptyList() - fun after(label: Label, parent: Node): List = emptyList() - fun after(memread: DirectMemoryRead, parent: Node): List = emptyList() - fun after(memwrite: DirectMemoryWrite, parent: Node): List = emptyList() - fun after(module: Module, parent: Node): List = emptyList() - fun after(nopStatement: NopStatement, parent: Node): List = emptyList() - fun after(numLiteral: NumericLiteralValue, parent: Node): List = emptyList() - fun after(postIncrDecr: PostIncrDecr, parent: Node): List = emptyList() - fun after(program: Program, parent: Node): List = emptyList() - fun after(range: RangeExpr, parent: Node): List = emptyList() - fun after(registerExpr: RegisterExpr, parent: Node): List = emptyList() - fun after(repeatLoop: RepeatLoop, parent: Node): List = emptyList() - fun after(returnStmt: Return, parent: Node): List = emptyList() - fun after(scope: AnonymousScope, parent: Node): List = emptyList() - fun after(string: StringLiteralValue, parent: Node): List = emptyList() - fun after(structDecl: StructDecl, parent: Node): List = emptyList() - fun after(structLv: StructLiteralValue, parent: Node): List = emptyList() - fun after(subroutine: Subroutine, parent: Node): List = emptyList() - fun after(typecast: TypecastExpression, parent: Node): List = emptyList() - fun after(whenChoice: WhenChoice, parent: Node): List = emptyList() - fun after(whenStatement: WhenStatement, parent: Node): List = emptyList() - fun after(whileLoop: WhileLoop, parent: Node): List = emptyList() - - private fun applyModifications(mods: List, node: Node, parent: Node) { - mods.forEach { it.invoke(node, parent) } - } - - fun visit(program: Program) { - applyModifications(before(program, program), program, program) - program.modules.forEach { it.accept(this, program) } - applyModifications(after(program, program), program, program) - } - - fun visit(module: Module, parent: Node) { - applyModifications(before(module, parent), module, parent) - module.statements.forEach{ it.accept(this, module) } - applyModifications(after(module, parent), module, parent) - } - - fun visit(expr: PrefixExpression, parent: Node) { - applyModifications(before(expr, parent), expr, parent) - expr.expression.accept(this, expr) - applyModifications(after(expr, parent), expr, parent) - } - - fun visit(expr: BinaryExpression, parent: Node) { - applyModifications(before(expr, parent), expr, parent) - expr.left.accept(this, expr) - expr.right.accept(this, expr) - applyModifications(after(expr, parent), expr, parent) - } - - fun visit(directive: Directive, parent: Node) { - applyModifications(before(directive, parent), directive, parent) - applyModifications(after(directive, parent), directive, parent) - } - - fun visit(block: Block, parent: Node) { - applyModifications(before(block, parent), block, parent) - block.statements.forEach { it.accept(this, block) } - applyModifications(after(block, parent), block, parent) - } - - fun visit(decl: VarDecl, parent: Node) { - applyModifications(before(decl, parent), decl, parent) - decl.value?.accept(this, decl) - decl.arraysize?.accept(this, decl) - applyModifications(after(decl, parent), decl, parent) - } - - fun visit(subroutine: Subroutine, parent: Node) { - applyModifications(before(subroutine, parent), subroutine, parent) - subroutine.statements.forEach { it.accept(this, subroutine) } - applyModifications(after(subroutine, parent), subroutine, parent) - } - - fun visit(functionCall: FunctionCall, parent: Node) { - applyModifications(before(functionCall, parent), functionCall, parent) - functionCall.target.accept(this, functionCall) - functionCall.args.forEach { it.accept(this, functionCall) } - applyModifications(after(functionCall, parent), functionCall, parent) - } - - fun visit(functionCallStatement: FunctionCallStatement, parent: Node) { - applyModifications(before(functionCallStatement, parent), functionCallStatement, parent) - functionCallStatement.target.accept(this, functionCallStatement) - functionCallStatement.args.forEach { it.accept(this, functionCallStatement) } - applyModifications(after(functionCallStatement, parent), functionCallStatement, parent) - } - - fun visit(identifier: IdentifierReference, parent: Node) { - applyModifications(before(identifier, parent), identifier, parent) - applyModifications(after(identifier, parent), identifier, parent) - } - - fun visit(jump: Jump, parent: Node) { - applyModifications(before(jump, parent), jump, parent) - jump.identifier?.accept(this, jump) - applyModifications(after(jump, parent), jump, parent) - } - - fun visit(ifStatement: IfStatement, parent: Node) { - applyModifications(before(ifStatement, parent), ifStatement, parent) - ifStatement.condition.accept(this, ifStatement) - ifStatement.truepart.accept(this, ifStatement) - ifStatement.elsepart.accept(this, ifStatement) - applyModifications(after(ifStatement, parent), ifStatement, parent) - } - - fun visit(branchStatement: BranchStatement, parent: Node) { - applyModifications(before(branchStatement, parent), branchStatement, parent) - branchStatement.truepart.accept(this, branchStatement) - branchStatement.elsepart.accept(this, branchStatement) - applyModifications(after(branchStatement, parent), branchStatement, parent) - } - - fun visit(range: RangeExpr, parent: Node) { - applyModifications(before(range, parent), range, parent) - range.from.accept(this, range) - range.to.accept(this, range) - range.step.accept(this, range) - applyModifications(after(range, parent), range, parent) - } - - fun visit(label: Label, parent: Node) { - applyModifications(before(label, parent), label, parent) - applyModifications(after(label, parent), label, parent) - } - - fun visit(numLiteral: NumericLiteralValue, parent: Node) { - applyModifications(before(numLiteral, parent), numLiteral, parent) - applyModifications(after(numLiteral, parent), numLiteral, parent) - } - - fun visit(string: StringLiteralValue, parent: Node) { - applyModifications(before(string, parent), string, parent) - applyModifications(after(string, parent), string, parent) - } - - fun visit(array: ArrayLiteralValue, parent: Node) { - applyModifications(before(array, parent), array, parent) - array.value.forEach { v->v.accept(this, array) } - applyModifications(after(array, parent), array, parent) - } - - fun visit(assignment: Assignment, parent: Node) { - applyModifications(before(assignment, parent), assignment, parent) - assignment.target.accept(this, assignment) - assignment.value.accept(this, assignment) - applyModifications(after(assignment, parent), assignment, parent) - } - - fun visit(postIncrDecr: PostIncrDecr, parent: Node) { - applyModifications(before(postIncrDecr, parent), postIncrDecr, parent) - postIncrDecr.target.accept(this, postIncrDecr) - applyModifications(after(postIncrDecr, parent), postIncrDecr, parent) - } - - fun visit(contStmt: Continue, parent: Node) { - applyModifications(before(contStmt, parent), contStmt, parent) - applyModifications(after(contStmt, parent), contStmt, parent) - } - - fun visit(breakStmt: Break, parent: Node) { - applyModifications(before(breakStmt, parent), breakStmt, parent) - applyModifications(after(breakStmt, parent), breakStmt, parent) - } - - fun visit(forLoop: ForLoop, parent: Node) { - applyModifications(before(forLoop, parent), forLoop, parent) - forLoop.loopVar?.accept(this, forLoop) - forLoop.iterable.accept(this, forLoop) - forLoop.body.accept(this, forLoop) - applyModifications(after(forLoop, parent), forLoop, parent) - } - - fun visit(whileLoop: WhileLoop, parent: Node) { - applyModifications(before(whileLoop, parent), whileLoop, parent) - whileLoop.condition.accept(this, whileLoop) - whileLoop.body.accept(this, whileLoop) - applyModifications(after(whileLoop, parent), whileLoop, parent) - } - - fun visit(foreverLoop: ForeverLoop, parent: Node) { - applyModifications(before(foreverLoop, parent), foreverLoop, parent) - foreverLoop.body.accept(this, foreverLoop) - applyModifications(after(foreverLoop, parent), foreverLoop, parent) - } - - fun visit(repeatLoop: RepeatLoop, parent: Node) { - applyModifications(before(repeatLoop, parent), repeatLoop, parent) - repeatLoop.untilCondition.accept(this, repeatLoop) - repeatLoop.body.accept(this, repeatLoop) - applyModifications(after(repeatLoop, parent), repeatLoop, parent) - } - - fun visit(returnStmt: Return, parent: Node) { - applyModifications(before(returnStmt, parent), returnStmt, parent) - returnStmt.value?.accept(this, returnStmt) - applyModifications(after(returnStmt, parent), returnStmt, parent) - } - - fun visit(arrayIndexedExpression: ArrayIndexedExpression, parent: Node) { - applyModifications(before(arrayIndexedExpression, parent), arrayIndexedExpression, parent) - arrayIndexedExpression.identifier.accept(this, arrayIndexedExpression) - arrayIndexedExpression.arrayspec.accept(this, arrayIndexedExpression) - applyModifications(after(arrayIndexedExpression, parent), arrayIndexedExpression, parent) - } - - fun visit(assignTarget: AssignTarget, parent: Node) { - applyModifications(before(assignTarget, parent), assignTarget, parent) - assignTarget.arrayindexed?.accept(this, assignTarget) - assignTarget.identifier?.accept(this, assignTarget) - assignTarget.memoryAddress?.accept(this, assignTarget) - applyModifications(after(assignTarget, parent), assignTarget, parent) - } - - fun visit(scope: AnonymousScope, parent: Node) { - applyModifications(before(scope, parent), scope, parent) - scope.statements.forEach { it.accept(this, scope) } - applyModifications(after(scope, parent), scope, parent) - } - - fun visit(typecast: TypecastExpression, parent: Node) { - applyModifications(before(typecast, parent), typecast, parent) - typecast.expression.accept(this, typecast) - applyModifications(after(typecast, parent), typecast, parent) - } - - fun visit(memread: DirectMemoryRead, parent: Node) { - applyModifications(before(memread, parent), memread, parent) - memread.addressExpression.accept(this, memread) - applyModifications(after(memread, parent), memread, parent) - } - - fun visit(memwrite: DirectMemoryWrite, parent: Node) { - applyModifications(before(memwrite, parent), memwrite, parent) - memwrite.addressExpression.accept(this, memwrite) - applyModifications(after(memwrite, parent), memwrite, parent) - } - - fun visit(addressOf: AddressOf, parent: Node) { - applyModifications(before(addressOf, parent), addressOf, parent) - addressOf.identifier.accept(this, addressOf) - applyModifications(after(addressOf, parent), addressOf, parent) - } - - fun visit(inlineAssembly: InlineAssembly, parent: Node) { - applyModifications(before(inlineAssembly, parent), inlineAssembly, parent) - applyModifications(after(inlineAssembly, parent), inlineAssembly, parent) - } - - fun visit(registerExpr: RegisterExpr, parent: Node) { - applyModifications(before(registerExpr, parent), registerExpr, parent) - applyModifications(after(registerExpr, parent), registerExpr, parent) - } - - fun visit(builtinFunctionStatementPlaceholder: BuiltinFunctionStatementPlaceholder, parent: Node) { - applyModifications(before(builtinFunctionStatementPlaceholder, parent), builtinFunctionStatementPlaceholder, parent) - applyModifications(after(builtinFunctionStatementPlaceholder, parent), builtinFunctionStatementPlaceholder, parent) - } - - fun visit(nopStatement: NopStatement, parent: Node) { - applyModifications(before(nopStatement, parent), nopStatement, parent) - applyModifications(after(nopStatement, parent), nopStatement, parent) - } - - fun visit(whenStatement: WhenStatement, parent: Node) { - applyModifications(before(whenStatement, parent), whenStatement, parent) - whenStatement.condition.accept(this, whenStatement) - whenStatement.choices.forEach { it.accept(this, whenStatement) } - applyModifications(after(whenStatement, parent), whenStatement, parent) - } - - fun visit(whenChoice: WhenChoice, parent: Node) { - applyModifications(before(whenChoice, parent), whenChoice, parent) - whenChoice.values?.forEach { it.accept(this, whenChoice) } - whenChoice.statements.accept(this, whenChoice) - applyModifications(after(whenChoice, parent), whenChoice, parent) - } - - fun visit(structDecl: StructDecl, parent: Node) { - applyModifications(before(structDecl, parent), structDecl, parent) - structDecl.statements.forEach { it.accept(this, structDecl) } - applyModifications(after(structDecl, parent), structDecl, parent) - } - - fun visit(structLv: StructLiteralValue, parent: Node) { - applyModifications(before(structLv, parent), structLv, parent) - structLv.values.forEach { it.accept(this, structLv) } - applyModifications(after(structLv, parent), structLv, parent) - } -} - diff --git a/compiler/src/prog8/ast/processing/ImportedModuleDirectiveRemover.kt b/compiler/src/prog8/ast/processing/ImportedModuleDirectiveRemover.kt index 4fe3b3e1b..8dd4563f5 100644 --- a/compiler/src/prog8/ast/processing/ImportedModuleDirectiveRemover.kt +++ b/compiler/src/prog8/ast/processing/ImportedModuleDirectiveRemover.kt @@ -1,29 +1,20 @@ package prog8.ast.processing -import prog8.ast.Module -import prog8.ast.base.ErrorReporter +import prog8.ast.Node import prog8.ast.statements.Directive -import prog8.ast.statements.Statement -internal class ImportedModuleDirectiveRemover(private val errors: ErrorReporter) : IAstModifyingVisitor { + +internal class ImportedModuleDirectiveRemover: AstWalker() { /** * Most global directives don't apply for imported modules, so remove them */ - override fun visit(module: Module) { - super.visit(module) - val newStatements : MutableList = mutableListOf() - val moduleLevelDirectives = listOf("%output", "%launcher", "%zeropage", "%zpreserved", "%address") - for (sourceStmt in module.statements) { - val stmt = sourceStmt.accept(this) - if(stmt is Directive && stmt.parent is Module) { - if(stmt.directive in moduleLevelDirectives) { - errors.warn("ignoring module directive because it was imported", stmt.position) - continue - } - } - newStatements.add(stmt) + private val moduleLevelDirectives = listOf("%output", "%launcher", "%zeropage", "%zpreserved", "%address") + + override fun before(directive: Directive, parent: Node): Iterable { + if(directive.directive in moduleLevelDirectives) { + return listOf(AstModification.Remove(directive, parent)) } - module.statements = newStatements + return emptyList() } } diff --git a/compiler/src/prog8/ast/processing/MakeForeverLoops.kt b/compiler/src/prog8/ast/processing/MakeForeverLoops.kt index f7143aa7c..e771cce4d 100644 --- a/compiler/src/prog8/ast/processing/MakeForeverLoops.kt +++ b/compiler/src/prog8/ast/processing/MakeForeverLoops.kt @@ -1,24 +1,28 @@ package prog8.ast.processing +import prog8.ast.Node import prog8.ast.expressions.NumericLiteralValue import prog8.ast.statements.ForeverLoop import prog8.ast.statements.RepeatLoop -import prog8.ast.statements.Statement import prog8.ast.statements.WhileLoop -internal class MakeForeverLoops : IAstModifyingVisitor { - override fun visit(whileLoop: WhileLoop): Statement { - val numeric = whileLoop.condition as? NumericLiteralValue - if(numeric!=null && numeric.number.toInt() != 0) { - return ForeverLoop(whileLoop.body, whileLoop.position) + +internal class MakeForeverLoops: AstWalker() { + override fun before(repeatLoop: RepeatLoop, parent: Node): Iterable { + val numeric = repeatLoop.untilCondition as? NumericLiteralValue + if(numeric!=null && numeric.number.toInt() == 0) { + val forever = ForeverLoop(repeatLoop.body, repeatLoop.position) + return listOf(AstModification.Replace(repeatLoop, forever, parent)) } - return super.visit(whileLoop) + return emptyList() } - override fun visit(repeatLoop: RepeatLoop): Statement { - val numeric = repeatLoop.untilCondition as? NumericLiteralValue - if(numeric!=null && numeric.number.toInt() == 0) - return ForeverLoop(repeatLoop.body, repeatLoop.position) - return super.visit(repeatLoop) + override fun before(whileLoop: WhileLoop, parent: Node): Iterable { + val numeric = whileLoop.condition as? NumericLiteralValue + if(numeric!=null && numeric.number.toInt() != 0) { + val forever = ForeverLoop(whileLoop.body, whileLoop.position) + return listOf(AstModification.Replace(whileLoop, forever, parent)) + } + return emptyList() } } diff --git a/compiler/src/prog8/ast/processing/AnonymousScopeVarsCleanup.kt b/compiler/src/prog8/ast/processing/MoveAnonScopeVarsToSubroutine.kt similarity index 93% rename from compiler/src/prog8/ast/processing/AnonymousScopeVarsCleanup.kt rename to compiler/src/prog8/ast/processing/MoveAnonScopeVarsToSubroutine.kt index d04926bac..9e9fadb4e 100644 --- a/compiler/src/prog8/ast/processing/AnonymousScopeVarsCleanup.kt +++ b/compiler/src/prog8/ast/processing/MoveAnonScopeVarsToSubroutine.kt @@ -6,7 +6,7 @@ import prog8.ast.statements.AnonymousScope import prog8.ast.statements.Statement import prog8.ast.statements.VarDecl -class AnonymousScopeVarsCleanup(private val errors: ErrorReporter): IAstModifyingVisitor { +class MoveAnonScopeVarsToSubroutine(private val errors: ErrorReporter): IAstModifyingVisitor { private val varsToMove: MutableMap> = mutableMapOf() override fun visit(program: Program) { diff --git a/compiler/src/prog8/ast/processing/StatementReorderer.kt b/compiler/src/prog8/ast/processing/StatementReorderer.kt index 7b435443e..eebe867ab 100644 --- a/compiler/src/prog8/ast/processing/StatementReorderer.kt +++ b/compiler/src/prog8/ast/processing/StatementReorderer.kt @@ -8,45 +8,6 @@ import prog8.ast.expressions.* import prog8.ast.statements.* -private fun flattenStructAssignmentFromIdentifier(structAssignment: Assignment, program: Program): List { - val identifier = structAssignment.target.identifier!! - val identifierName = identifier.nameInSource.single() - val targetVar = identifier.targetVarDecl(program.namespace)!! - val struct = targetVar.struct!! - when (structAssignment.value) { - is IdentifierReference -> { - val sourceVar = (structAssignment.value as IdentifierReference).targetVarDecl(program.namespace)!! - if (sourceVar.struct == null) - throw FatalAstException("can only assign arrays or structs to structs") - // struct memberwise copy - val sourceStruct = sourceVar.struct!! - if(sourceStruct!==targetVar.struct) { - // structs are not the same in assignment - return listOf() // error will be printed elsewhere - } - return struct.statements.zip(sourceStruct.statements).map { member -> - val targetDecl = member.first as VarDecl - val sourceDecl = member.second as VarDecl - if(targetDecl.name != sourceDecl.name) - throw FatalAstException("struct member mismatch") - val mangled = mangledStructMemberName(identifierName, targetDecl.name) - val idref = IdentifierReference(listOf(mangled), structAssignment.position) - val sourcemangled = mangledStructMemberName(sourceVar.name, sourceDecl.name) - val sourceIdref = IdentifierReference(listOf(sourcemangled), structAssignment.position) - val assign = Assignment(AssignTarget(null, idref, null, null, structAssignment.position), - null, sourceIdref, member.second.position) - assign.linkParents(structAssignment) - assign - } - } - is StructLiteralValue -> { - throw IllegalArgumentException("not going to flatten a structLv assignment here") - } - else -> throw FatalAstException("strange struct value") - } -} - - internal class StatementReorderer(private val program: Program): IAstModifyingVisitor { // Reorders the statements in a way the compiler needs. // - 'main' block must be the very first statement UNLESS it has an address set. @@ -254,4 +215,44 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi return assg } + + + private fun flattenStructAssignmentFromIdentifier(structAssignment: Assignment, program: Program): List { + val identifier = structAssignment.target.identifier!! + val identifierName = identifier.nameInSource.single() + val targetVar = identifier.targetVarDecl(program.namespace)!! + val struct = targetVar.struct!! + when (structAssignment.value) { + is IdentifierReference -> { + val sourceVar = (structAssignment.value as IdentifierReference).targetVarDecl(program.namespace)!! + if (sourceVar.struct == null) + throw FatalAstException("can only assign arrays or structs to structs") + // struct memberwise copy + val sourceStruct = sourceVar.struct!! + if(sourceStruct!==targetVar.struct) { + // structs are not the same in assignment + return listOf() // error will be printed elsewhere + } + return struct.statements.zip(sourceStruct.statements).map { member -> + val targetDecl = member.first as VarDecl + val sourceDecl = member.second as VarDecl + if(targetDecl.name != sourceDecl.name) + throw FatalAstException("struct member mismatch") + val mangled = mangledStructMemberName(identifierName, targetDecl.name) + val idref = IdentifierReference(listOf(mangled), structAssignment.position) + val sourcemangled = mangledStructMemberName(sourceVar.name, sourceDecl.name) + val sourceIdref = IdentifierReference(listOf(sourcemangled), structAssignment.position) + val assign = Assignment(AssignTarget(null, idref, null, null, structAssignment.position), + null, sourceIdref, member.second.position) + assign.linkParents(structAssignment) + assign + } + } + is StructLiteralValue -> { + throw IllegalArgumentException("not going to flatten a structLv assignment here") + } + else -> throw FatalAstException("strange struct value") + } + } + } diff --git a/compiler/src/prog8/ast/statements/AstStatements.kt b/compiler/src/prog8/ast/statements/AstStatements.kt index 63e6fd3f0..e3d5a0920 100644 --- a/compiler/src/prog8/ast/statements/AstStatements.kt +++ b/compiler/src/prog8/ast/statements/AstStatements.kt @@ -5,13 +5,13 @@ import prog8.ast.base.* import prog8.ast.expressions.* import prog8.ast.processing.IAstModifyingVisitor import prog8.ast.processing.IAstVisitor -import prog8.ast.processing.IGenericAstModifyingVisitor +import prog8.ast.processing.AstWalker sealed class Statement : Node { abstract fun accept(visitor: IAstModifyingVisitor) : Statement abstract fun accept(visitor: IAstVisitor) - abstract fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) + abstract fun accept(visitor: AstWalker, parent: Node) fun makeScopedName(name: String): String { // easy way out is to always return the full scoped name. @@ -46,7 +46,7 @@ class BuiltinFunctionStatementPlaceholder(val name: String, override val positio override fun linkParents(parent: Node) {} override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) = visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun definingScope(): INameScope = BuiltinFunctionScopePlaceholder override val expensiveToInline = false } @@ -69,7 +69,7 @@ class Block(override val name: String, override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) = visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun toString(): String { return "Block(name=$name, address=$address, ${statements.size} statements)" @@ -89,7 +89,7 @@ data class Directive(val directive: String, val args: List, overri override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) = visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) } data class DirectiveArg(val str: String?, val name: String?, val int: Int?, override val position: Position) : Node { @@ -110,7 +110,7 @@ data class Label(val name: String, override val position: Position) : Statement( override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) = visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun toString(): String { return "Label(name=$name, pos=$position)" @@ -128,7 +128,7 @@ open class Return(var value: Expression?, override val position: Position) : Sta override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) = visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun toString(): String { return "Return($value, pos=$position)" @@ -154,7 +154,7 @@ class Continue(override val position: Position) : Statement() { override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) = visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) } class Break(override val position: Position) : Statement() { @@ -167,7 +167,7 @@ class Break(override val position: Position) : Statement() { override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) = visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) } @@ -245,7 +245,7 @@ class VarDecl(val type: VarDeclType, override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) = visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) val scopedname: String by lazy { makeScopedName(name) } @@ -308,7 +308,7 @@ class ArrayIndex(var index: Expression, override val position: Position) : Node index = index.accept(visitor) } fun accept(visitor: IAstVisitor) = index.accept(visitor) - fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) = index.accept(visitor, parent) + fun accept(visitor: AstWalker, parent: Node) = index.accept(visitor, parent) override fun toString(): String { return("ArrayIndex($index, pos=$position)") @@ -330,7 +330,7 @@ open class Assignment(var target: AssignTarget, val aug_op : String?, var value: override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) = visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun toString(): String { return("Assignment(augop: $aug_op, target: $target, value: $value, pos=$position)") @@ -358,7 +358,7 @@ data class AssignTarget(val register: Register?, fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) fun accept(visitor: IAstVisitor) = visitor.visit(this) - fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) = visitor.visit(this, parent) + fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) companion object { fun fromExpr(expr: Expression): AssignTarget { @@ -457,7 +457,7 @@ class PostIncrDecr(var target: AssignTarget, val operator: String, override val override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) = visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun toString(): String { return "PostIncrDecr(op: $operator, target: $target, pos=$position)" @@ -478,7 +478,7 @@ class Jump(val address: Int?, override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) = visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun toString(): String { return "Jump(addr: $address, identifier: $identifier, label: $generatedLabel; pos=$position)" @@ -501,7 +501,7 @@ class FunctionCallStatement(override var target: IdentifierReference, override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) = visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun toString(): String { return "FunctionCallStatement(target=$target, pos=$position)" @@ -518,7 +518,7 @@ class InlineAssembly(val assembly: String, override val position: Position) : St override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) = visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) } class AnonymousScope(override var statements: MutableList, @@ -544,7 +544,7 @@ class AnonymousScope(override var statements: MutableList, override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) = visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) } class NopStatement(override val position: Position): Statement() { @@ -557,7 +557,7 @@ class NopStatement(override val position: Position): Statement() { override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) = visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) companion object { fun insteadOf(stmt: Statement): NopStatement { @@ -600,7 +600,7 @@ class Subroutine(override val name: String, override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) = visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun toString(): String { return "Subroutine(name=$name, parameters=$parameters, returntypes=$returntypes, ${statements.size} statements, address=$asmAddress)" @@ -641,7 +641,7 @@ class IfStatement(var condition: Expression, override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) = visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) } @@ -661,7 +661,7 @@ class BranchStatement(var condition: BranchCondition, override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) = visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) } @@ -682,7 +682,7 @@ class ForLoop(val loopRegister: Register?, override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) = visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun toString(): String { return "ForLoop(loopVar: $loopVar, loopReg: $loopRegister, iterable: $iterable, pos=$position)" @@ -709,7 +709,7 @@ class WhileLoop(var condition: Expression, override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) = visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) } class ForeverLoop(var body: AnonymousScope, override val position: Position) : Statement() { @@ -723,7 +723,7 @@ class ForeverLoop(var body: AnonymousScope, override val position: Position) : S override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) = visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) } class RepeatLoop(var body: AnonymousScope, @@ -740,7 +740,7 @@ class RepeatLoop(var body: AnonymousScope, override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) = visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) } class WhenStatement(var condition: Expression, @@ -774,7 +774,7 @@ class WhenStatement(var condition: Expression, override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) = visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) } class WhenChoice(var values: List?, // if null, this is the 'else' part @@ -794,7 +794,7 @@ class WhenChoice(var values: List?, // if null, this is t fun accept(visitor: IAstVisitor) = visitor.visit(this) fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) - fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) = visitor.visit(this, parent) + fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) } @@ -815,7 +815,7 @@ class StructDecl(override val name: String, override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) - override fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) = visitor.visit(this, parent) + override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) fun nameOfFirstMember() = (statements.first() as VarDecl).name } @@ -834,5 +834,5 @@ class DirectMemoryWrite(var addressExpression: Expression, override val position fun accept(visitor: IAstVisitor) = visitor.visit(this) fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) - fun accept(visitor: IGenericAstModifyingVisitor, parent: Node) = visitor.visit(this, parent) + fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) } diff --git a/compiler/src/prog8/optimizer/FlattenAnonymousScopesAndRemoveNops.kt b/compiler/src/prog8/optimizer/FlattenAnonymousScopesAndRemoveNops.kt new file mode 100644 index 000000000..d78931ec3 --- /dev/null +++ b/compiler/src/prog8/optimizer/FlattenAnonymousScopesAndRemoveNops.kt @@ -0,0 +1,46 @@ +package prog8.optimizer + +import prog8.ast.INameScope +import prog8.ast.Node +import prog8.ast.Program +import prog8.ast.processing.IAstVisitor +import prog8.ast.statements.AnonymousScope +import prog8.ast.statements.NopStatement +import prog8.ast.statements.Statement + +internal class FlattenAnonymousScopesAndRemoveNops: IAstVisitor { + private var scopesToFlatten = mutableListOf() + private val nopStatements = mutableListOf() + + override fun visit(program: Program) { + super.visit(program) + for(scope in scopesToFlatten.reversed()) { + val namescope = scope.parent as INameScope + val idx = namescope.statements.indexOf(scope as Statement) + if(idx>=0) { + val nop = NopStatement.insteadOf(namescope.statements[idx]) + nop.parent = namescope as Node + namescope.statements[idx] = nop + namescope.statements.addAll(idx, scope.statements) + scope.statements.forEach { it.parent = namescope } + visit(nop) + } + } + + this.nopStatements.forEach { + it.definingScope().remove(it) + } + } + + override fun visit(scope: AnonymousScope) { + if(scope.parent is INameScope) { + scopesToFlatten.add(scope) // get rid of the anonymous scope + } + + return super.visit(scope) + } + + override fun visit(nopStatement: NopStatement) { + nopStatements.add(nopStatement) + } +} diff --git a/compiler/src/prog8/optimizer/StatementOptimizer.kt b/compiler/src/prog8/optimizer/StatementOptimizer.kt index 21d5aa80d..a546bcbc2 100644 --- a/compiler/src/prog8/optimizer/StatementOptimizer.kt +++ b/compiler/src/prog8/optimizer/StatementOptimizer.kt @@ -2,12 +2,10 @@ package prog8.optimizer import prog8.ast.INameScope import prog8.ast.Module -import prog8.ast.Node import prog8.ast.Program import prog8.ast.base.* import prog8.ast.expressions.* import prog8.ast.processing.IAstModifyingVisitor -import prog8.ast.processing.IAstVisitor import prog8.ast.statements.* import prog8.compiler.target.CompilationTarget import prog8.functions.BuiltinFunctions @@ -582,39 +580,3 @@ internal class StatementOptimizer(private val program: Program, -internal class FlattenAnonymousScopesAndRemoveNops: IAstVisitor { - private var scopesToFlatten = mutableListOf() - private val nopStatements = mutableListOf() - - override fun visit(program: Program) { - super.visit(program) - for(scope in scopesToFlatten.reversed()) { - val namescope = scope.parent as INameScope - val idx = namescope.statements.indexOf(scope as Statement) - if(idx>=0) { - val nop = NopStatement.insteadOf(namescope.statements[idx]) - nop.parent = namescope as Node - namescope.statements[idx] = nop - namescope.statements.addAll(idx, scope.statements) - scope.statements.forEach { it.parent = namescope } - visit(nop) - } - } - - this.nopStatements.forEach { - it.definingScope().remove(it) - } - } - - override fun visit(scope: AnonymousScope) { - if(scope.parent is INameScope) { - scopesToFlatten.add(scope) // get rid of the anonymous scope - } - - return super.visit(scope) - } - - override fun visit(nopStatement: NopStatement) { - nopStatements.add(nopStatement) - } -} diff --git a/compiler/src/prog8/parser/ModuleParsing.kt b/compiler/src/prog8/parser/ModuleParsing.kt index 60c4c21f6..8e596d314 100644 --- a/compiler/src/prog8/parser/ModuleParsing.kt +++ b/compiler/src/prog8/parser/ModuleParsing.kt @@ -140,7 +140,7 @@ internal class ModuleImporter(private val errors: ErrorReporter) { importModule(program, modulePath) } - importedModule.checkImportedValid(errors) + importedModule.checkImportedValid() return importedModule } diff --git a/examples/test.p8 b/examples/test.p8 index 646ed0de0..b1ba127d7 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -5,6 +5,15 @@ main { sub start() { + while true { + A=99 + } + + repeat { + A=44 + } until false + + c64scr.print("spstart:") print_stackpointer() sub1()