don't fall-through into nested subroutine

This commit is contained in:
Irmen de Jong 2019-08-18 02:33:42 +02:00
parent 1cc1f2d91d
commit c495f54bbb
3 changed files with 117 additions and 90 deletions

View File

@ -1,7 +1,6 @@
package prog8.ast.processing
import prog8.ast.Module
import prog8.ast.Program
import prog8.ast.*
import prog8.ast.base.DataType
import prog8.ast.base.FatalAstException
import prog8.ast.base.initvarsSubName
@ -64,7 +63,10 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
private val directivesToMove = setOf("%output", "%launcher", "%zeropage", "%zpreserved", "%address", "%option")
private val addReturns = mutableListOf<Pair<INameScope, Int>>()
override fun visit(module: Module) {
addReturns.clear()
super.visit(module)
val (blocks, other) = module.statements.partition { it is Block }
@ -92,6 +94,13 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
val directives = module.statements.filter {it is Directive && it.directive in directivesToMove}
module.statements.removeAll(directives)
module.statements.addAll(0, directives)
for(pos in addReturns) {
println(pos)
val returnStmt = Return(null, pos.first.position)
returnStmt.linkParents(pos.first as Node)
pos.first.statements.add(pos.second, returnStmt)
}
}
override fun visit(block: Block): Statement {
@ -161,6 +170,19 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
override fun visit(subroutine: Subroutine): Statement {
super.visit(subroutine)
val scope = subroutine.definingScope()
if(scope is Subroutine) {
for(stmt in scope.statements.withIndex()) {
if(stmt.index>0 && stmt.value===subroutine) {
val precedingStmt = scope.statements[stmt.index-1]
if(precedingStmt !is Jump && precedingStmt !is Subroutine) {
// insert a return statement before a nested subroutine, to avoid falling trough inside the subroutine
addReturns.add(Pair(scope, stmt.index))
}
}
}
}
val varDecls = subroutine.statements.filterIsInstance<VarDecl>()
subroutine.statements.removeAll(varDecls)
subroutine.statements.addAll(0, varDecls)

View File

@ -96,7 +96,7 @@ fun compileProgram(filepath: Path,
programAst.checkValid(compilerOptions) // check if final tree is valid
programAst.checkRecursion() // check if there are recursive subroutine calls
// printAst(programAst)
printAst(programAst)
if(writeAssembly) {
// asm generation directly from the Ast, no need for intermediate code

View File

@ -6,99 +6,104 @@ main {
sub start() {
ubyte[] uba = [10,0,2,8,5,4,3,9]
uword[] uwa = [1000,0,200,8000,50,40000,3,900]
byte[] ba = [-10,0,-2,8,5,4,-3,9,-99]
word[] wa = [-1000,0,-200,8000,50,31111,3,-900]
print_name()
for ubyte ub in uba {
c64scr.print_ub(ub)
c64.CHROUT(',')
sub print_name() {
c64scr.print("irmen\n")
}
c64.CHROUT('\n')
for uword uw in uwa {
c64scr.print_uw(uw)
c64.CHROUT(',')
}
c64.CHROUT('\n')
for byte bb in ba {
c64scr.print_b(bb)
c64.CHROUT(',')
}
c64.CHROUT('\n')
for word ww in wa {
c64scr.print_w(ww)
c64.CHROUT(',')
}
c64.CHROUT('\n')
c64.CHROUT('\n')
sort(uba)
sort(uwa)
sort(ba)
sort(wa)
for ubyte ub2 in uba {
c64scr.print_ub(ub2)
c64.CHROUT(',')
}
c64.CHROUT('\n')
for uword uw2 in uwa {
c64scr.print_uw(uw2)
c64.CHROUT(',')
}
c64.CHROUT('\n')
for byte bb2 in ba {
c64scr.print_b(bb2)
c64.CHROUT(',')
}
c64.CHROUT('\n')
for word ww2 in wa {
c64scr.print_w(ww2)
c64.CHROUT(',')
}
c64.CHROUT('\n')
c64.CHROUT('\n')
reverse(uba)
reverse(uwa)
reverse(ba)
reverse(wa)
for ubyte ub3 in uba {
c64scr.print_ub(ub3)
c64.CHROUT(',')
}
c64.CHROUT('\n')
for uword uw3 in uwa {
c64scr.print_uw(uw3)
c64.CHROUT(',')
}
c64.CHROUT('\n')
for byte bb3 in ba {
c64scr.print_b(bb3)
c64.CHROUT(',')
}
c64.CHROUT('\n')
for word ww3 in wa {
c64scr.print_w(ww3)
c64.CHROUT(',')
}
c64.CHROUT('\n')
c64.CHROUT('\n')
; ubyte[] uba = [10,0,2,8,5,4,3,9]
; uword[] uwa = [1000,0,200,8000,50,40000,3,900]
; byte[] ba = [-10,0,-2,8,5,4,-3,9,-99]
; word[] wa = [-1000,0,-200,8000,50,31111,3,-900]
;
; for ubyte ub in uba {
; c64scr.print_ub(ub)
; c64.CHROUT(',')
; }
; c64.CHROUT('\n')
;
; for uword uw in uwa {
; c64scr.print_uw(uw)
; c64.CHROUT(',')
; }
; c64.CHROUT('\n')
;
; for byte bb in ba {
; c64scr.print_b(bb)
; c64.CHROUT(',')
; }
; c64.CHROUT('\n')
;
; for word ww in wa {
; c64scr.print_w(ww)
; c64.CHROUT(',')
; }
; c64.CHROUT('\n')
; c64.CHROUT('\n')
;
; sort(uba)
; sort(uwa)
; sort(ba)
; sort(wa)
;
; for ubyte ub2 in uba {
; c64scr.print_ub(ub2)
; c64.CHROUT(',')
; }
; c64.CHROUT('\n')
;
; for uword uw2 in uwa {
; c64scr.print_uw(uw2)
; c64.CHROUT(',')
; }
; c64.CHROUT('\n')
;
; for byte bb2 in ba {
; c64scr.print_b(bb2)
; c64.CHROUT(',')
; }
; c64.CHROUT('\n')
;
; for word ww2 in wa {
; c64scr.print_w(ww2)
; c64.CHROUT(',')
; }
; c64.CHROUT('\n')
; c64.CHROUT('\n')
;
; reverse(uba)
; reverse(uwa)
; reverse(ba)
; reverse(wa)
;
; for ubyte ub3 in uba {
; c64scr.print_ub(ub3)
; c64.CHROUT(',')
; }
; c64.CHROUT('\n')
;
; for uword uw3 in uwa {
; c64scr.print_uw(uw3)
; c64.CHROUT(',')
; }
; c64.CHROUT('\n')
;
; for byte bb3 in ba {
; c64scr.print_b(bb3)
; c64.CHROUT(',')
; }
; c64.CHROUT('\n')
;
; for word ww3 in wa {
; c64scr.print_w(ww3)
; c64.CHROUT(',')
; }
; c64.CHROUT('\n')
; c64.CHROUT('\n')
; TODO 2 for loops that both define the same loopvar -> double definition -> fix second for -> 'unknown symbol' ????
; TODO code runs into a subroutine that follows it instead of inserting an implicit return
}
}