From ed54cf680a1e8d7af9fd03d37d8cc2b2a0be6db7 Mon Sep 17 00:00:00 2001 From: Irmen de Jong Date: Mon, 6 Apr 2020 14:31:02 +0200 Subject: [PATCH] fixed ast parent link bug in AstWalker, rewrote StatementReorderer using new API, `when` labels are sorted. --- compiler/src/prog8/ast/base/Extensions.kt | 5 +- .../src/prog8/ast/processing/AstChecker.kt | 2 +- .../src/prog8/ast/processing/AstWalker.kt | 2 +- .../ast/processing/StatementReorderer.kt | 234 +++++++----------- examples/test.p8 | 25 ++ 5 files changed, 120 insertions(+), 148 deletions(-) diff --git a/compiler/src/prog8/ast/base/Extensions.kt b/compiler/src/prog8/ast/base/Extensions.kt index 09febce79..762ab5564 100644 --- a/compiler/src/prog8/ast/base/Extensions.kt +++ b/compiler/src/prog8/ast/base/Extensions.kt @@ -24,8 +24,9 @@ internal fun Program.reorderStatements() { initvalueCreator.visit(this) initvalueCreator.applyModifications() - val checker = StatementReorderer(this) - checker.visit(this) + val reorder = StatementReorderer(this) + reorder.visit(this) + reorder.applyModifications() } internal fun Program.addTypecasts(errors: ErrorReporter) { diff --git a/compiler/src/prog8/ast/processing/AstChecker.kt b/compiler/src/prog8/ast/processing/AstChecker.kt index 1fa20c3b4..cf36243cd 100644 --- a/compiler/src/prog8/ast/processing/AstChecker.kt +++ b/compiler/src/prog8/ast/processing/AstChecker.kt @@ -22,7 +22,7 @@ internal class AstChecker(private val program: Program, if(mainBlocks.size>1) errors.err("more than one 'main' block", mainBlocks[0].position) if(mainBlocks.isEmpty()) - errors.err("there is no 'main' block", program.position) + errors.err("there is no 'main' block", program.modules.firstOrNull()?.position ?: program.position) for(mainBlock in mainBlocks) { val startSub = mainBlock.subScopes()["start"] as? Subroutine diff --git a/compiler/src/prog8/ast/processing/AstWalker.kt b/compiler/src/prog8/ast/processing/AstWalker.kt index b476ade9a..af0b01aa1 100644 --- a/compiler/src/prog8/ast/processing/AstWalker.kt +++ b/compiler/src/prog8/ast/processing/AstWalker.kt @@ -56,7 +56,7 @@ interface IAstModification { class ReplaceNode(val node: Node, val replacement: Node, val parent: Node) : IAstModification { override fun perform() { parent.replaceChildNode(node, replacement) - replacement.parent = parent + replacement.linkParents(parent) } } diff --git a/compiler/src/prog8/ast/processing/StatementReorderer.kt b/compiler/src/prog8/ast/processing/StatementReorderer.kt index 43ec41a57..0ec88568c 100644 --- a/compiler/src/prog8/ast/processing/StatementReorderer.kt +++ b/compiler/src/prog8/ast/processing/StatementReorderer.kt @@ -9,187 +9,133 @@ import prog8.ast.expressions.* import prog8.ast.statements.* -internal class StatementReorderer(private val program: Program): IAstModifyingVisitor { +internal class StatementReorderer(val program: Program) : AstWalker() { // Reorders the statements in a way the compiler needs. // - 'main' block must be the very first statement UNLESS it has an address set. - // - blocks are ordered by address, where blocks without address are put at the end. - // - in every scope: - // -- the directives '%output', '%launcher', '%zeropage', '%zpreserved', '%address' and '%option' will come first. - // -- all vardecls then follow. - // -- the remaining statements then follow in their original order. - // - // - the 'start' subroutine in the 'main' block will be moved to the top immediately following the directives. - // - all other subroutines will be moved to the end of their block. + // - library blocks are put last. + // - blocks are ordered by address, where blocks without address are placed last. + // - in every scope, most directives and vardecls are moved to the top. + // - the 'start' subroutine is moved to the top. + // - (syntax desugaring) a vardecl with a non-const initializer value is split into a regular vardecl and an assignment statement. + // - (syntax desugaring) augmented assignment is turned into regular assignment. + // - (syntax desugaring) struct value assignment is expanded into several struct member assignments. // - sorts the choices in when statement. - // - a vardecl with a non-const initializer value is split into a regular vardecl and an assignment statement. + private val directivesToMove = setOf("%output", "%launcher", "%zeropage", "%zpreserved", "%address", "%option") - private val addVardecls = mutableMapOf>() - - override fun visit(module: Module) { - addVardecls.clear() - super.visit(module) - + override fun after(module: Module, parent: Node): Iterable { val (blocks, other) = module.statements.partition { it is Block } module.statements = other.asSequence().plus(blocks.sortedBy { (it as Block).address ?: Int.MAX_VALUE }).toMutableList() - // make sure user-defined blocks come BEFORE library blocks, and move the "main" block to the top of everything - val nonLibraryBlocks = module.statements.withIndex() - .filter { it.value is Block && !(it.value as Block).isInLibrary } - .map { it.index to it.value } - .reversed() - for(nonLibBlock in nonLibraryBlocks) - module.statements.removeAt(nonLibBlock.first) - for(nonLibBlock in nonLibraryBlocks) - module.statements.add(0, nonLibBlock.second) - val mainBlock = module.statements.singleOrNull { it is Block && it.name=="main" } - if(mainBlock!=null && (mainBlock as Block).address==null) { - module.remove(mainBlock) + val mainBlock = module.statements.filterIsInstance().firstOrNull { it.name=="main" } + if(mainBlock!=null && mainBlock.address==null) { + module.statements.remove(mainBlock) module.statements.add(0, mainBlock) } - val varDecls = module.statements.filterIsInstance() - module.statements.removeAll(varDecls) - module.statements.addAll(0, varDecls) - - val directives = module.statements.filter {it is Directive && it.directive in directivesToMove} - module.statements.removeAll(directives) - module.statements.addAll(0, directives) - - for((where, decls) in addVardecls) { - where.statements.addAll(0, decls) - decls.forEach { it.linkParents(where as Node) } - } + reorderVardeclsAndDirectives(module.statements) + return emptyList() } - override fun visit(block: Block): Statement { + private fun reorderVardeclsAndDirectives(statements: MutableList) { + val varDecls = statements.filterIsInstance() + statements.removeAll(varDecls) + statements.addAll(0, varDecls) - val subroutines = block.statements.filterIsInstance() - var numSubroutinesAtEnd = 0 - // move all subroutines to the end of the block - for (subroutine in subroutines) { - if(subroutine.name!="start" || block.name!="main") { - block.remove(subroutine) - block.statements.add(subroutine) - } - numSubroutinesAtEnd++ + val directives = statements.filterIsInstance().filter {it.directive in directivesToMove} + statements.removeAll(directives) + statements.addAll(0, directives) + } + + override fun before(block: Block, parent: Node): Iterable { + parent as Module + if(block.isInLibrary) { + return listOf( + IAstModification.Remove(block, parent), + IAstModification.InsertAfter(parent.statements.last(), block, parent) + ) } - // move the "start" subroutine to the top - if(block.name=="main") { - block.statements.singleOrNull { it is Subroutine && it.name == "start" } ?.let { - block.remove(it) - block.statements.add(0, it) - numSubroutinesAtEnd-- + + reorderVardeclsAndDirectives(block.statements) + return emptyList() + } + + override fun before(subroutine: Subroutine, parent: Node): Iterable { + if(subroutine.name=="start" && parent is Block) { + if(parent.statements.filterIsInstance().first().name!="start") { + return listOf( + IAstModification.Remove(subroutine, parent), + IAstModification.InsertFirst(subroutine, parent) + ) } } - - val varDecls = block.statements.filterIsInstance() - block.statements.removeAll(varDecls) - block.statements.addAll(0, varDecls) - val directives = block.statements.filter {it is Directive && it.directive in directivesToMove} - block.statements.removeAll(directives) - block.statements.addAll(0, directives) - block.linkParents(block.parent) - - return super.visit(block) + reorderVardeclsAndDirectives(subroutine.statements) + return emptyList() } - override fun visit(subroutine: Subroutine): Statement { - super.visit(subroutine) - - val varDecls = subroutine.statements.filterIsInstance() - subroutine.statements.removeAll(varDecls) - subroutine.statements.addAll(0, varDecls) - val directives = subroutine.statements.filter {it is Directive && it.directive in directivesToMove} - subroutine.statements.removeAll(directives) - subroutine.statements.addAll(0, directives) - - return subroutine - } - - - private fun addVarDecl(scope: INameScope, variable: VarDecl): VarDecl { - if(scope !in addVardecls) - addVardecls[scope] = mutableListOf() - val declList = addVardecls.getValue(scope) - val existing = declList.singleOrNull { it.name==variable.name } - return if(existing!=null) { - existing - } else { - declList.add(variable) - variable - } - } - - override fun visit(decl: VarDecl): Statement { + override fun after(decl: VarDecl, parent: Node): Iterable { val declValue = decl.value if(declValue!=null && decl.type== VarDeclType.VAR && decl.datatype in NumericDatatypes) { val declConstValue = declValue.constValue(program) if(declConstValue==null) { // move the vardecl (without value) to the scope and replace this with a regular assignment + decl.value = null val target = AssignTarget(null, IdentifierReference(listOf(decl.name), decl.position), null, null, decl.position) val assign = Assignment(target, null, declValue, decl.position) - assign.linkParents(decl.parent) - decl.value = null - addVarDecl(decl.definingScope(), decl) - return assign + return listOf( + IAstModification.ReplaceNode(decl, assign, parent), + IAstModification.InsertFirst(decl, decl.definingScope() as Node) + ) } } - return super.visit(decl) + return emptyList() } - override fun visit(assignment: Assignment): Statement { - val assg = super.visit(assignment) - if(assg !is Assignment) - return assg + override fun after(whenStatement: WhenStatement, parent: Node): Iterable { + val choices = whenStatement.choiceValues(program).sortedBy { + it.first?.first() ?: Int.MAX_VALUE + } + whenStatement.choices.clear() + choices.mapTo(whenStatement.choices) { it.second } + return emptyList() + } - // see if a typecast is needed to convert the value's type into the proper target type - val valueItype = assg.value.inferType(program) - val targetItype = assg.target.inferType(program, assg) + override fun before(assignment: Assignment, parent: Node): Iterable { + if(assignment.aug_op!=null) { + // TODO instead of desugaring augmented assignments, instead just keep them and use them for possibly more efficient code generation ? + // this also means that we should actually reverse this stuff below: A = A + 5 ---> A += 5 + val leftOperand: Expression = + when { + assignment.target.register != null -> RegisterExpr(assignment.target.register!!, assignment.target.position) + assignment.target.identifier != null -> assignment.target.identifier!! + assignment.target.arrayindexed != null -> assignment.target.arrayindexed!! + assignment.target.memoryAddress != null -> DirectMemoryRead(assignment.target.memoryAddress!!.addressExpression, assignment.value.position) + else -> throw FatalAstException("strange assignment") + } - if(targetItype.isKnown && valueItype.isKnown) { - val targettype = targetItype.typeOrElse(DataType.STRUCT) - val valuetype = valueItype.typeOrElse(DataType.STRUCT) + val expression = BinaryExpression(leftOperand, assignment.aug_op.substringBeforeLast('='), assignment.value, assignment.position) + val convertedAssignment = Assignment(assignment.target, null, expression, assignment.position) + return listOf(IAstModification.ReplaceNode(assignment, convertedAssignment, parent)) + } - // struct assignments will be flattened (if it's not a struct literal) - if (valuetype == DataType.STRUCT && targettype == DataType.STRUCT) { - val assignments = if (assg.value is StructLiteralValue) { - flattenStructAssignmentFromStructLiteral(assg, program) // 'structvar = { ..... } ' - } else { - flattenStructAssignmentFromIdentifier(assg, program) // 'structvar1 = structvar2' - } - return if (assignments.isEmpty()) { - // something went wrong (probably incompatible struct types) - // we'll get an error later from the AstChecker - assg - } else { - val scope = AnonymousScope(assignments.toMutableList(), assg.position) - scope.linkParents(assg.parent) - scope - } + val valueType = assignment.value.inferType(program) + val targetType = assignment.target.inferType(program, assignment) + if(valueType.istype(DataType.STRUCT) && targetType.istype(DataType.STRUCT)) { + val assignments = if (assignment.value is StructLiteralValue) { + flattenStructAssignmentFromStructLiteral(assignment, program) // 'structvar = { ..... } ' + } else { + flattenStructAssignmentFromIdentifier(assignment, program) // 'structvar1 = structvar2' + } + if(assignments.isNotEmpty()) { + val modifications = mutableListOf() + assignments.reversed().mapTo(modifications) { IAstModification.InsertAfter(assignment, it, parent) } + modifications.add(IAstModification.Remove(assignment, parent)) + return modifications } } - if(assg.aug_op!=null) { - // transform augmented assg into normal assg so we have one case less to deal with later - val newTarget: Expression = - when { - assg.target.register != null -> RegisterExpr(assg.target.register!!, assg.target.position) - assg.target.identifier != null -> assg.target.identifier!! - assg.target.arrayindexed != null -> assg.target.arrayindexed!! - assg.target.memoryAddress != null -> DirectMemoryRead(assg.target.memoryAddress!!.addressExpression, assg.value.position) - else -> throw FatalAstException("strange assg") - } - - val expression = BinaryExpression(newTarget, assg.aug_op.substringBeforeLast('='), assg.value, assg.position) - expression.linkParents(assg.parent) - val convertedAssignment = Assignment(assg.target, null, expression, assg.position) - convertedAssignment.linkParents(assg.parent) - return super.visit(convertedAssignment) - } - - return assg + return emptyList() } private fun flattenStructAssignmentFromStructLiteral(structAssignment: Assignment, program: Program): List { diff --git a/examples/test.p8 b/examples/test.p8 index 8a150e20c..60f498175 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -3,6 +3,21 @@ %import c64flt %zeropage basicsafe +foobar { + %option force_output + + ubyte xx + + sub derp() { + byte yy=cos8(A) + + if A==0 { + ; ubyte qq=cos8(A) + A=54 + } + } + +} main { sub start() { @@ -16,6 +31,16 @@ main { c64flt.print_f(floats[0]) c64flt.print_f(floats[1]) + foobar.derp() + when A { + 100 -> Y=4 + 101 -> Y=5 + 1 -> Y=66 + 10 -> Y=77 + else -> Y=9 + } + + A+=99 } }