From c495f54bbb7c469d9de55990b9b5d4501814ccbc Mon Sep 17 00:00:00 2001 From: Irmen de Jong Date: Sun, 18 Aug 2019 02:33:42 +0200 Subject: [PATCH] don't fall-through into nested subroutine --- .../ast/processing/StatementReorderer.kt | 26 ++- compiler/src/prog8/compiler/Main.kt | 2 +- examples/test.p8 | 179 +++++++++--------- 3 files changed, 117 insertions(+), 90 deletions(-) diff --git a/compiler/src/prog8/ast/processing/StatementReorderer.kt b/compiler/src/prog8/ast/processing/StatementReorderer.kt index d30047f4d..02c4a6d2e 100644 --- a/compiler/src/prog8/ast/processing/StatementReorderer.kt +++ b/compiler/src/prog8/ast/processing/StatementReorderer.kt @@ -1,7 +1,6 @@ package prog8.ast.processing -import prog8.ast.Module -import prog8.ast.Program +import prog8.ast.* import prog8.ast.base.DataType import prog8.ast.base.FatalAstException import prog8.ast.base.initvarsSubName @@ -64,7 +63,10 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi private val directivesToMove = setOf("%output", "%launcher", "%zeropage", "%zpreserved", "%address", "%option") + private val addReturns = mutableListOf>() + override fun visit(module: Module) { + addReturns.clear() super.visit(module) val (blocks, other) = module.statements.partition { it is Block } @@ -92,6 +94,13 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi val directives = module.statements.filter {it is Directive && it.directive in directivesToMove} module.statements.removeAll(directives) module.statements.addAll(0, directives) + + for(pos in addReturns) { + println(pos) + val returnStmt = Return(null, pos.first.position) + returnStmt.linkParents(pos.first as Node) + pos.first.statements.add(pos.second, returnStmt) + } } override fun visit(block: Block): Statement { @@ -161,6 +170,19 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi override fun visit(subroutine: Subroutine): Statement { super.visit(subroutine) + val scope = subroutine.definingScope() + if(scope is Subroutine) { + for(stmt in scope.statements.withIndex()) { + if(stmt.index>0 && stmt.value===subroutine) { + val precedingStmt = scope.statements[stmt.index-1] + if(precedingStmt !is Jump && precedingStmt !is Subroutine) { + // insert a return statement before a nested subroutine, to avoid falling trough inside the subroutine + addReturns.add(Pair(scope, stmt.index)) + } + } + } + } + val varDecls = subroutine.statements.filterIsInstance() subroutine.statements.removeAll(varDecls) subroutine.statements.addAll(0, varDecls) diff --git a/compiler/src/prog8/compiler/Main.kt b/compiler/src/prog8/compiler/Main.kt index 304e34c7c..e4a2b0ffc 100644 --- a/compiler/src/prog8/compiler/Main.kt +++ b/compiler/src/prog8/compiler/Main.kt @@ -96,7 +96,7 @@ fun compileProgram(filepath: Path, programAst.checkValid(compilerOptions) // check if final tree is valid programAst.checkRecursion() // check if there are recursive subroutine calls - // printAst(programAst) + printAst(programAst) if(writeAssembly) { // asm generation directly from the Ast, no need for intermediate code diff --git a/examples/test.p8 b/examples/test.p8 index 1363faaee..7579140af 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -6,99 +6,104 @@ main { sub start() { - ubyte[] uba = [10,0,2,8,5,4,3,9] - uword[] uwa = [1000,0,200,8000,50,40000,3,900] - byte[] ba = [-10,0,-2,8,5,4,-3,9,-99] - word[] wa = [-1000,0,-200,8000,50,31111,3,-900] + print_name() - for ubyte ub in uba { - c64scr.print_ub(ub) - c64.CHROUT(',') + sub print_name() { + c64scr.print("irmen\n") } - c64.CHROUT('\n') - for uword uw in uwa { - c64scr.print_uw(uw) - c64.CHROUT(',') - } - c64.CHROUT('\n') - - for byte bb in ba { - c64scr.print_b(bb) - c64.CHROUT(',') - } - c64.CHROUT('\n') - - for word ww in wa { - c64scr.print_w(ww) - c64.CHROUT(',') - } - c64.CHROUT('\n') - c64.CHROUT('\n') - - sort(uba) - sort(uwa) - sort(ba) - sort(wa) - - for ubyte ub2 in uba { - c64scr.print_ub(ub2) - c64.CHROUT(',') - } - c64.CHROUT('\n') - - for uword uw2 in uwa { - c64scr.print_uw(uw2) - c64.CHROUT(',') - } - c64.CHROUT('\n') - - for byte bb2 in ba { - c64scr.print_b(bb2) - c64.CHROUT(',') - } - c64.CHROUT('\n') - - for word ww2 in wa { - c64scr.print_w(ww2) - c64.CHROUT(',') - } - c64.CHROUT('\n') - c64.CHROUT('\n') - - reverse(uba) - reverse(uwa) - reverse(ba) - reverse(wa) - - for ubyte ub3 in uba { - c64scr.print_ub(ub3) - c64.CHROUT(',') - } - c64.CHROUT('\n') - - for uword uw3 in uwa { - c64scr.print_uw(uw3) - c64.CHROUT(',') - } - c64.CHROUT('\n') - - for byte bb3 in ba { - c64scr.print_b(bb3) - c64.CHROUT(',') - } - c64.CHROUT('\n') - - for word ww3 in wa { - c64scr.print_w(ww3) - c64.CHROUT(',') - } - c64.CHROUT('\n') - c64.CHROUT('\n') +; ubyte[] uba = [10,0,2,8,5,4,3,9] +; uword[] uwa = [1000,0,200,8000,50,40000,3,900] +; byte[] ba = [-10,0,-2,8,5,4,-3,9,-99] +; word[] wa = [-1000,0,-200,8000,50,31111,3,-900] +; +; for ubyte ub in uba { +; c64scr.print_ub(ub) +; c64.CHROUT(',') +; } +; c64.CHROUT('\n') +; +; for uword uw in uwa { +; c64scr.print_uw(uw) +; c64.CHROUT(',') +; } +; c64.CHROUT('\n') +; +; for byte bb in ba { +; c64scr.print_b(bb) +; c64.CHROUT(',') +; } +; c64.CHROUT('\n') +; +; for word ww in wa { +; c64scr.print_w(ww) +; c64.CHROUT(',') +; } +; c64.CHROUT('\n') +; c64.CHROUT('\n') +; +; sort(uba) +; sort(uwa) +; sort(ba) +; sort(wa) +; +; for ubyte ub2 in uba { +; c64scr.print_ub(ub2) +; c64.CHROUT(',') +; } +; c64.CHROUT('\n') +; +; for uword uw2 in uwa { +; c64scr.print_uw(uw2) +; c64.CHROUT(',') +; } +; c64.CHROUT('\n') +; +; for byte bb2 in ba { +; c64scr.print_b(bb2) +; c64.CHROUT(',') +; } +; c64.CHROUT('\n') +; +; for word ww2 in wa { +; c64scr.print_w(ww2) +; c64.CHROUT(',') +; } +; c64.CHROUT('\n') +; c64.CHROUT('\n') +; +; reverse(uba) +; reverse(uwa) +; reverse(ba) +; reverse(wa) +; +; for ubyte ub3 in uba { +; c64scr.print_ub(ub3) +; c64.CHROUT(',') +; } +; c64.CHROUT('\n') +; +; for uword uw3 in uwa { +; c64scr.print_uw(uw3) +; c64.CHROUT(',') +; } +; c64.CHROUT('\n') +; +; for byte bb3 in ba { +; c64scr.print_b(bb3) +; c64.CHROUT(',') +; } +; c64.CHROUT('\n') +; +; for word ww3 in wa { +; c64scr.print_w(ww3) +; c64.CHROUT(',') +; } +; c64.CHROUT('\n') +; c64.CHROUT('\n') ; TODO 2 for loops that both define the same loopvar -> double definition -> fix second for -> 'unknown symbol' ???? - ; TODO code runs into a subroutine that follows it instead of inserting an implicit return } }