diff --git a/compiler/src/prog8/compiler/BeforeAsmGenerationAstChanger.kt b/compiler/src/prog8/compiler/BeforeAsmGenerationAstChanger.kt index e69b425ac..571c5a4c4 100644 --- a/compiler/src/prog8/compiler/BeforeAsmGenerationAstChanger.kt +++ b/compiler/src/prog8/compiler/BeforeAsmGenerationAstChanger.kt @@ -31,6 +31,14 @@ internal class BeforeAsmGenerationAstChanger(val program: Program, private val o varsList.add(decl.name to decl) } + override fun before(block: Block, parent: Node): Iterable { + // move all subroutines to the bottom of the block + val subs = block.statements.filterIsInstance() + block.statements.removeAll(subs) + block.statements.addAll(subs) + return noModifications + } + override fun after(decl: VarDecl, parent: Node): Iterable { if(decl.type==VarDeclType.VAR && decl.value != null && decl.datatype in NumericDatatypes) throw FatalAstException("vardecls for variables, with initial numerical value, should have been rewritten as plain vardecl + assignment $decl") @@ -116,12 +124,14 @@ internal class BeforeAsmGenerationAstChanger(val program: Program, private val o val outerScope = subroutine.definingScope val outerStatements = outerScope.statements val subroutineStmtIdx = outerStatements.indexOf(subroutine) - if (subroutineStmtIdx > 0 - && outerStatements[subroutineStmtIdx - 1] !is Jump - && outerStatements[subroutineStmtIdx - 1] !is Subroutine - && outerStatements[subroutineStmtIdx - 1] !is Return - && outerScope !is Block) { - mods += IAstModification.InsertAfter(outerStatements[subroutineStmtIdx - 1], returnStmt, outerScope) + if (subroutineStmtIdx > 0) { + val prevStmt = outerStatements[subroutineStmtIdx-1] + if(outerScope !is Block + && (prevStmt !is Jump || prevStmt.isGosub) + && prevStmt !is Subroutine + && prevStmt !is Return) { + mods += IAstModification.InsertAfter(outerStatements[subroutineStmtIdx - 1], returnStmt, outerScope) + } } return mods } diff --git a/compiler/test/TestSubroutines.kt b/compiler/test/TestSubroutines.kt index d7082b36b..4c1bb72ad 100644 --- a/compiler/test/TestSubroutines.kt +++ b/compiler/test/TestSubroutines.kt @@ -325,4 +325,24 @@ class TestSubroutines: FunSpec({ errors.errors[0] shouldContain "cannot use arguments" errors.errors[1] shouldContain "invalid number of arguments" } + + test("fallthrough prevented") { + val text = """ + main { + sub start() { + func(1) + + sub func(ubyte a) { + a++ + } + } + } + """ + val result = compileText(C64Target, false, text, writeAssembly = true).assertSuccess() + val stmts = result.program.entrypoint.statements + + stmts.last() shouldBe instanceOf() + stmts.dropLast(1).last() shouldBe instanceOf() // this prevents the fallthrough + stmts.dropLast(2).last() shouldBe instanceOf() + } }) diff --git a/examples/test.p8 b/examples/test.p8 index a3f515bd4..12b29ea06 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -1,30 +1,8 @@ %import textio %zeropage basicsafe -%option no_sysinit main { sub start() { - ubyte @shared bb - - if bb + sin8u(bb) > 100-bb { - bb++ - } - - while bb + sin8u(bb) > 100-bb { - bb++ - } - - do { - bb++ - } until bb + sin8u(bb) > 100-bb - - const ubyte EN_TYPE=2 - uword eRef = $c000 - ubyte chance = rnd() % 100 - - if eRef[EN_TYPE] and chance < (eRef[EN_TYPE] << 1) { - bb++ - } } }