From cf7bea0985a574afa3b940e4d004d94e86f29950 Mon Sep 17 00:00:00 2001 From: Irmen de Jong Date: Wed, 21 May 2025 00:19:50 +0200 Subject: [PATCH] cleanup RTS insertion and ast postprocessing before assembly generation --- compiler/src/prog8/compiler/Compiler.kt | 15 +++++- .../compiler/astprocessing/AstExtensions.kt | 13 ----- .../astprocessing/BeforeAsmAstChanger.kt | 34 +++++++++--- compiler/test/TestBuiltinFunctions.kt | 2 +- compiler/test/TestOptimization.kt | 52 +++++++++---------- compiler/test/TestSubroutines.kt | 10 +++- compiler/test/TestTypecasts.kt | 14 ++--- compiler/test/ast/TestConst.kt | 10 ++-- compiler/test/ast/TestProg8Parser.kt | 14 ++--- compiler/test/ast/TestVariousCompilerAst.kt | 44 ++++------------ compiler/test/codegeneration/TestVariables.kt | 2 +- compilerAst/src/prog8/ast/AstToplevel.kt | 24 ++++++++- 12 files changed, 131 insertions(+), 103 deletions(-) diff --git a/compiler/src/prog8/compiler/Compiler.kt b/compiler/src/prog8/compiler/Compiler.kt index 784e60668..e507e134d 100644 --- a/compiler/src/prog8/compiler/Compiler.kt +++ b/compiler/src/prog8/compiler/Compiler.kt @@ -145,8 +145,6 @@ fun compileProgram(args: CompilerArguments): CompilationResult? { // re-initialize memory areas with final compilationOptions compilationOptions.compTarget.initializeMemoryAreas(compilationOptions) - program.processAstBeforeAsmGeneration(compilationOptions, args.errors) - args.errors.report() if(args.printAst1) { println("\n*********** COMPILER AST *************") @@ -495,8 +493,21 @@ private fun postprocessAst(program: Program, errors: IErrorReporter, compilerOpt program.verifyFunctionArgTypes(errors, compilerOptions) errors.report() program.moveMainBlockAsFirst(compilerOptions.compTarget) + + val fixer = BeforeAsmAstChanger(program, compilerOptions, errors) + fixer.visit(program) + while (errors.noErrors() && fixer.applyModifications() > 0) { + fixer.visit(program) + } + program.checkValid(errors, compilerOptions) // check if final tree is still valid errors.report() + + val cleaner = BeforeAsmTypecastCleaner(program, errors) + cleaner.visit(program) + while (errors.noErrors() && cleaner.applyModifications() > 0) { + cleaner.visit(program) + } } private fun createAssemblyAndAssemble(program: PtProgram, diff --git a/compiler/src/prog8/compiler/astprocessing/AstExtensions.kt b/compiler/src/prog8/compiler/astprocessing/AstExtensions.kt index 8b0b5dd17..6fd7406f4 100644 --- a/compiler/src/prog8/compiler/astprocessing/AstExtensions.kt +++ b/compiler/src/prog8/compiler/astprocessing/AstExtensions.kt @@ -20,19 +20,6 @@ internal fun Program.checkValid(errors: IErrorReporter, compilerOptions: Compila checker.visit(this) } -internal fun Program.processAstBeforeAsmGeneration(compilerOptions: CompilationOptions, errors: IErrorReporter) { - val fixer = BeforeAsmAstChanger(this, compilerOptions, errors) - fixer.visit(this) - while (errors.noErrors() && fixer.applyModifications() > 0) { - fixer.visit(this) - } - val cleaner = BeforeAsmTypecastCleaner(this, errors) - cleaner.visit(this) - while (errors.noErrors() && cleaner.applyModifications() > 0) { - cleaner.visit(this) - } -} - internal fun Program.reorderStatements(errors: IErrorReporter) { val reorder = StatementReorderer(this, errors) reorder.visit(this) diff --git a/compiler/src/prog8/compiler/astprocessing/BeforeAsmAstChanger.kt b/compiler/src/prog8/compiler/astprocessing/BeforeAsmAstChanger.kt index 14174a656..918150571 100644 --- a/compiler/src/prog8/compiler/astprocessing/BeforeAsmAstChanger.kt +++ b/compiler/src/prog8/compiler/astprocessing/BeforeAsmAstChanger.kt @@ -3,6 +3,7 @@ package prog8.compiler.astprocessing import prog8.ast.IStatementContainer import prog8.ast.Node import prog8.ast.Program +import prog8.ast.defaultZero import prog8.ast.expressions.BinaryExpression import prog8.ast.expressions.NumericLiteral import prog8.ast.statements.* @@ -51,15 +52,24 @@ internal class BeforeAsmAstChanger(val program: Program, private val options: Co // and if an assembly block doesn't contain a rts/rti. if (!subroutine.isAsmSubroutine) { if(subroutine.isEmpty()) { - val returnStmt = Return(arrayOf(), subroutine.position) - mods += IAstModification.InsertLast(returnStmt, subroutine) + if(subroutine.returntypes.isNotEmpty()) + errors.err("subroutine is missing a return statement with value(s)", subroutine.position) + else { + val returnStmt = Return(arrayOf(), subroutine.position) + mods += IAstModification.InsertLast(returnStmt, subroutine) + } } else { val last = subroutine.statements.last() if((last !is InlineAssembly || !last.hasReturnOrRts()) && last !is Return) { val lastStatement = subroutine.statements.reversed().firstOrNull { it !is Subroutine } if(lastStatement !is Return) { - val returnStmt = Return(arrayOf(), subroutine.position) - mods += IAstModification.InsertLast(returnStmt, subroutine) + if(subroutine.returntypes.isNotEmpty()) { + // .... we cannot return this as an error, because that also breaks legitimate cases where the return is done from within a nested scope somewhere + // errors.err("subroutine is missing a return statement with value(s)", subroutine.position) + } else { + val returnStmt = Return(arrayOf(), subroutine.position) + mods += IAstModification.InsertLast(returnStmt, subroutine) + } } } } @@ -76,8 +86,20 @@ internal class BeforeAsmAstChanger(val program: Program, private val options: Co && prevStmt !is Subroutine && prevStmt !is Return ) { - val returnStmt = Return(arrayOf(), subroutine.position) - mods += IAstModification.InsertAfter(outerStatements[subroutineStmtIdx - 1], returnStmt, outerScope) + if(!subroutine.inline) { + if(outerScope is Subroutine && outerScope.returntypes.isNotEmpty()) { + if(outerScope.returntypes.size>1 || !outerScope.returntypes[0].isNumericOrBool) { + errors.err("subroutine is missing a return statement to avoid falling through into nested subroutine", outerStatements[subroutineStmtIdx-1].position) + } else { + val zero = defaultZero(outerScope.returntypes[0].base, Position.DUMMY) + val returnStmt = Return(arrayOf(zero), outerStatements[subroutineStmtIdx - 1].position) + mods += IAstModification.InsertAfter(outerStatements[subroutineStmtIdx - 1], returnStmt, outerScope) + } + } else { + val returnStmt = Return(arrayOf(), outerStatements[subroutineStmtIdx - 1].position) + mods += IAstModification.InsertAfter(outerStatements[subroutineStmtIdx - 1], returnStmt, outerScope) + } + } } } diff --git a/compiler/test/TestBuiltinFunctions.kt b/compiler/test/TestBuiltinFunctions.kt index 5d1e087d9..805839ef9 100644 --- a/compiler/test/TestBuiltinFunctions.kt +++ b/compiler/test/TestBuiltinFunctions.kt @@ -77,7 +77,7 @@ main { }""" val result = compileText(Cx16Target(), false, src, outputDir, writeAssembly = false) val statements = result!!.compilerAst.entrypoint.statements - statements.size shouldBe 7 + statements.size shouldBe 8 val a1 = statements[2] as Assignment val a2 = statements[3] as Assignment val a3 = statements[4] as Assignment diff --git a/compiler/test/TestOptimization.kt b/compiler/test/TestOptimization.kt index c61520242..d05e7f469 100644 --- a/compiler/test/TestOptimization.kt +++ b/compiler/test/TestOptimization.kt @@ -225,7 +225,7 @@ other { } """ val result = compileText(C64Target(), optimize=false, src, outputDir, writeAssembly = false)!! - val assignFF = result.compilerAst.entrypoint.statements.last() as Assignment + val assignFF = result.compilerAst.entrypoint.statements.dropLast(1).last() as Assignment assignFF.isAugmentable shouldBe true assignFF.target.identifier!!.nameInSource shouldBe listOf("ff") val value = assignFF.value as BinaryExpression @@ -252,7 +252,7 @@ other { } """ val result = compileText(C64Target(), optimize=true, src, outputDir, writeAssembly=false)!! - result.compilerAst.entrypoint.statements.size shouldBe 7 + result.compilerAst.entrypoint.statements.size shouldBe 8 val alldecls = result.compilerAst.entrypoint.allDefinedSymbols.toList() alldecls.map { it.first } shouldBe listOf("unused_but_shared", "usedvar_only_written", "usedvar") } @@ -276,12 +276,12 @@ other { } }""" val result = compileText(C64Target(), optimize=true, src, outputDir, writeAssembly=false)!! - result.compilerAst.entrypoint.statements.size shouldBe 3 + result.compilerAst.entrypoint.statements.size shouldBe 4 val ifstmt = result.compilerAst.entrypoint.statements[0] as IfElse ifstmt.truepart.statements.size shouldBe 1 (ifstmt.truepart.statements[0] as Assignment).target.identifier!!.nameInSource shouldBe listOf("cx16", "r0") - val func2 = result.compilerAst.entrypoint.statements[2] as Subroutine - func2.statements.size shouldBe 2 + val func2 = result.compilerAst.entrypoint.statements.last() as Subroutine + func2.statements.size shouldBe 3 (func2.statements[0] as Assignment).target.identifier!!.nameInSource shouldBe listOf("cx16", "r0") } @@ -312,7 +312,7 @@ main { } }""" val result = compileText(C64Target(), optimize=true, src, outputDir, writeAssembly=false)!! - result.compilerAst.entrypoint.statements.size shouldBe 0 + result.compilerAst.entrypoint.statements.size shouldBe 1 result.compilerAst.entrypoint.definingScope.statements.size shouldBe 1 } @@ -350,7 +350,7 @@ main { z6 = z1 - 5 */ val statements = result.compilerAst.entrypoint.statements - statements.size shouldBe 12 + statements.size shouldBe 13 val z1decl = statements[0] as VarDecl val z1init = statements[1] as Assignment val z2decl = statements[2] as VarDecl @@ -395,8 +395,8 @@ main { val result = compileText(C64Target(), optimize=true, src, outputDir, writeAssembly=false)!! val stmts = result.compilerAst.entrypoint.statements - stmts.size shouldBe 5 - val assign=stmts.last() as Assignment + stmts.size shouldBe 6 + val assign=stmts[4] as Assignment (assign.target.memoryAddress?.addressExpression as IdentifierReference).nameInSource shouldBe listOf("aa") } @@ -412,8 +412,8 @@ main { """ val result = compileText(C64Target(), optimize=true, src, outputDir, writeAssembly=false)!! val stmts = result.compilerAst.entrypoint.statements - stmts.size shouldBe 5 - val assign=stmts.last() as Assignment + stmts.size shouldBe 6 + val assign=stmts[4] as Assignment (assign.target.memoryAddress?.addressExpression as IdentifierReference).nameInSource shouldBe listOf("aa") } @@ -431,7 +431,7 @@ main { """ val result = compileText(C64Target(), optimize=true, src, outputDir, writeAssembly=false)!! val stmts = result.compilerAst.entrypoint.statements - stmts.size shouldBe 10 + stmts.size shouldBe 11 stmts.filterIsInstance().size shouldBe 5 stmts.filterIsInstance().size shouldBe 5 } @@ -508,7 +508,7 @@ main { xx += 6 */ val stmts = result.compilerAst.entrypoint.statements - stmts.size shouldBe 6 + stmts.size shouldBe 7 stmts.filterIsInstance().size shouldBe 3 stmts.filterIsInstance().size shouldBe 3 } @@ -537,13 +537,13 @@ main { xx += 10 */ val stmts = result.compilerAst.entrypoint.statements - stmts.size shouldBe 7 + stmts.size shouldBe 8 stmts.filterIsInstance().size shouldBe 2 stmts.filterIsInstance().size shouldBe 5 val assignXX1 = stmts[1] as Assignment assignXX1.target.identifier!!.nameInSource shouldBe listOf("xx") assignXX1.value shouldBe NumericLiteral(BaseDataType.UWORD, 20.0, Position.DUMMY) - val assignXX2 = stmts.last() as Assignment + val assignXX2 = stmts[6] as Assignment assignXX2.target.identifier!!.nameInSource shouldBe listOf("xx") val xxValue = assignXX2.value as BinaryExpression xxValue.operator shouldBe "+" @@ -577,7 +577,7 @@ main { thingy++ */ val stmts = result.compilerAst.entrypoint.statements - stmts.size shouldBe 6 + stmts.size shouldBe 7 val ifStmt = stmts[5] as IfElse val containment = ifStmt.condition as ContainmentCheck (containment.element as IdentifierReference).nameInSource shouldBe listOf("source") @@ -612,7 +612,7 @@ main { val result = compileText(C64Target(), optimize=true, src, outputDir, writeAssembly=false)!! val stmts = result.compilerAst.entrypoint.statements - stmts.size shouldBe 5 + stmts.size shouldBe 6 val ifStmt = stmts[4] as IfElse val containment = ifStmt.condition as ContainmentCheck (containment.element as IdentifierReference).nameInSource shouldBe listOf("source") @@ -634,7 +634,7 @@ main { }""" val result = compileText(C64Target(), optimize=true, src, outputDir, writeAssembly=false)!! val stmts = result.compilerAst.entrypoint.statements - stmts.size shouldBe 5 + stmts.size shouldBe 6 val ifStmt = stmts[4] as IfElse ifStmt.condition shouldBe instanceOf() } @@ -652,7 +652,7 @@ main { }""" val result = compileText(C64Target(), optimize=true, src, outputDir, writeAssembly=false)!! val stmts = result.compilerAst.entrypoint.statements - stmts.size shouldBe 5 + stmts.size shouldBe 6 val ifStmt = stmts[4] as IfElse ifStmt.condition shouldBe instanceOf() } @@ -670,7 +670,7 @@ main { }""" val result = compileText(C64Target(), optimize=true, src, outputDir, writeAssembly=false)!! val stmts = result.compilerAst.entrypoint.statements - stmts.size shouldBe 5 + stmts.size shouldBe 6 val ifStmt = stmts[4] as IfElse ifStmt.condition shouldBe instanceOf() } @@ -687,7 +687,7 @@ main { }""" val result = compileText(C64Target(), optimize=true, src, outputDir, writeAssembly=false)!! val stmts = result.compilerAst.entrypoint.statements - stmts.size shouldBe 3 + stmts.size shouldBe 4 } test("repeated assignments to IO register should remain") { @@ -828,7 +828,7 @@ main { val errors = ErrorReporterForTests() val result = compileText(Cx16Target(), true, src, outputDir, writeAssembly = false, errors = errors)!! val st = result.compilerAst.entrypoint.statements - st.size shouldBe 4 + st.size shouldBe 5 val xxConst = st[0] as VarDecl xxConst.type shouldBe VarDeclType.CONST xxConst.name shouldBe "xx" @@ -872,7 +872,7 @@ main { }""" val result = compileText(Cx16Target(), true, src, outputDir, writeAssembly = false)!! val st = result.compilerAst.entrypoint.statements - st.size shouldBe 8 + st.size shouldBe 9 val if1c = (st[4] as IfElse).condition as PrefixExpression val if2c = (st[5] as IfElse).condition as PrefixExpression val if3c = (st[6] as IfElse).condition as PrefixExpression @@ -915,7 +915,7 @@ main { }""" val result = compileText(Cx16Target(), true, src, outputDir, writeAssembly = false)!! val st = result.compilerAst.entrypoint.statements - st.size shouldBe 12 + st.size shouldBe 13 val if1 = st[4] as IfElse val if2 = st[5] as IfElse val if3 = st[6] as IfElse @@ -970,7 +970,7 @@ main { val result = compileText(Cx16Target(), true, src, outputDir, writeAssembly = false)!! val st = result.compilerAst.entrypoint.statements - st.size shouldBe 17 + st.size shouldBe 18 val answerValue = (st[3] as Assignment).value answerValue shouldBe NumericLiteral(BaseDataType.UWORD, 0.0, Position.DUMMY) @@ -1019,7 +1019,7 @@ main { }""" val result = compileText(Cx16Target(), true, src, outputDir, writeAssembly = false)!! val st = result.compilerAst.entrypoint.statements - st.size shouldBe 3 + st.size shouldBe 4 val ifCond1 = (st[0] as IfElse).condition as BinaryExpression val ifCond2 = (st[1] as IfElse).condition as BinaryExpression val ifCond3 = (st[2] as IfElse).condition as BinaryExpression diff --git a/compiler/test/TestSubroutines.kt b/compiler/test/TestSubroutines.kt index ca274c39f..d7e984e3a 100644 --- a/compiler/test/TestSubroutines.kt +++ b/compiler/test/TestSubroutines.kt @@ -54,6 +54,9 @@ class TestSubroutines: FunSpec({ } asmsub asmfunc(str thing @AY) { + %asm {{ + rts + }} } sub func(str thing) { @@ -67,12 +70,12 @@ class TestSubroutines: FunSpec({ val asmfunc = mainBlock.statements.filterIsInstance().single { it.name=="asmfunc"} val func = mainBlock.statements.filterIsInstance().single { it.name=="func"} asmfunc.isAsmSubroutine shouldBe true - asmfunc.statements.isEmpty() shouldBe true + asmfunc.statements.size shouldBe 1 func.isAsmSubroutine shouldBe false withClue("str param for subroutines should be changed into UWORD") { asmfunc.parameters.single().type shouldBe DataType.UWORD func.parameters.single().type shouldBe DataType.UWORD - func.statements.size shouldBe 4 + func.statements.size shouldBe 5 val paramvar = func.statements[0] as VarDecl paramvar.name shouldBe "thing" paramvar.datatype shouldBe DataType.UWORD @@ -174,6 +177,9 @@ class TestSubroutines: FunSpec({ } asmsub asmfunc(ubyte[] thing @AY) { + %asm {{ + rts + }} } sub func(ubyte[] thing) { diff --git a/compiler/test/TestTypecasts.kt b/compiler/test/TestTypecasts.kt index ff6cf0a4e..cf7be51db 100644 --- a/compiler/test/TestTypecasts.kt +++ b/compiler/test/TestTypecasts.kt @@ -54,7 +54,7 @@ class TestTypecasts: FunSpec({ }""" val result = compileText(C64Target(), false, text, outputDir, writeAssembly = false)!! val stmts = result.compilerAst.entrypoint.statements - stmts.size shouldBe 6 + stmts.size shouldBe 7 val expr = (stmts[5] as Assignment).value as BinaryExpression expr.operator shouldBe "and" (expr.left as IdentifierReference).nameInSource shouldBe listOf("bb2") // no cast @@ -157,7 +157,7 @@ main { }""" val result = compileText(C64Target(), false, text, outputDir, writeAssembly = false)!! val stmts = result.compilerAst.entrypoint.statements - stmts.size shouldBe 7 + stmts.size shouldBe 8 val fcall1 = ((stmts[4] as Assignment).value as IFunctionCall) fcall1.args[0] shouldBe NumericLiteral(BaseDataType.BOOL, 1.0, Position.DUMMY) fcall1.args[1] shouldBe NumericLiteral(BaseDataType.BOOL, 0.0, Position.DUMMY) @@ -209,7 +209,7 @@ main { }""" val result = compileText(C64Target(), false, text, outputDir, writeAssembly = false)!! val stmts = result.compilerAst.entrypoint.statements - stmts.size shouldBe 3 + stmts.size shouldBe 4 } test("ubyte to word casts") { @@ -224,7 +224,7 @@ main { val result = compileText(C64Target(), true, src, outputDir, writeAssembly = false)!! val stmts = result.compilerAst.entrypoint.statements - stmts.size shouldBe 4 + stmts.size shouldBe 5 val assign1tc = (stmts[2] as Assignment).value as TypecastExpression val assign2tc = (stmts[3] as Assignment).value as TypecastExpression assign1tc.type shouldBe BaseDataType.WORD @@ -255,7 +255,7 @@ main { }""" val result = compileText(C64Target(), false, text, outputDir, writeAssembly = false)!! val stmts = result.compilerAst.entrypoint.statements - stmts.size shouldBe 8 + stmts.size shouldBe 9 val arg1 = (stmts[2] as IFunctionCall).args.single() val arg2 = (stmts[3] as IFunctionCall).args.single() val arg3 = (stmts[4] as IFunctionCall).args.single() @@ -903,7 +903,7 @@ main { val result = compileText(C64Target(), false, src, outputDir, writeAssembly = false)!! val program = result.compilerAst val st = program.entrypoint.statements - st.size shouldBe 1 + st.size shouldBe 2 val assign = st[0] as Assignment assign.target.inferType(program).getOrUndef().base shouldBe BaseDataType.BYTE val ifexpr = assign.value as IfExpression @@ -928,7 +928,7 @@ main { val result = compileText(C64Target(), false, src, outputDir, writeAssembly = false)!! val program = result.compilerAst val st = program.entrypoint.statements - st.size shouldBe 6 + st.size shouldBe 7 val v1 = (st[2] as Assignment).value as BinaryExpression v1.operator shouldBe "+" (v1.left as IdentifierReference).nameInSource shouldBe listOf("cx16","r0") diff --git a/compiler/test/ast/TestConst.kt b/compiler/test/ast/TestConst.kt index cd8be2178..6b3d9deea 100644 --- a/compiler/test/ast/TestConst.kt +++ b/compiler/test/ast/TestConst.kt @@ -52,7 +52,7 @@ class TestConst: FunSpec({ // cx16.r5s = llw - 1899 // cx16.r7s = llw + 99 val stmts = result.compilerAst.entrypoint.statements - stmts.size shouldBe 9 + stmts.size shouldBe 10 val addR0value = (stmts[4] as Assignment).value val binexpr0 = addR0value as BinaryExpression @@ -109,7 +109,7 @@ class TestConst: FunSpec({ // result++ // result = llw * 18.0 val stmts = result.compilerAst.entrypoint.statements - stmts.size shouldBe 12 + stmts.size shouldBe 13 val mulR0Value = (stmts[3] as Assignment).value val binexpr0 = mulR0Value as BinaryExpression @@ -157,7 +157,7 @@ class TestConst: FunSpec({ // cx16.r3s = llw /2 *10 // cx16.r4s = llw *90 /5 val stmts = result.compilerAst.entrypoint.statements - stmts.size shouldBe 7 + stmts.size shouldBe 8 val mulR0Value = (stmts[2] as Assignment).value val binexpr0 = mulR0Value as BinaryExpression @@ -251,7 +251,7 @@ main { }""" val result = compileText(Cx16Target(), true, src, outputDir, writeAssembly = false)!! val st = result.compilerAst.entrypoint.statements - st.size shouldBe 5 + st.size shouldBe 6 (st[0] as VarDecl).type shouldBe VarDeclType.CONST val assignv1 = (st[2] as Assignment).value val assignv2 = (st[4] as Assignment).value @@ -274,7 +274,7 @@ main { }""" val result = compileText(Cx16Target(), false, src, outputDir, writeAssembly = false)!! val st = result.compilerAst.entrypoint.statements - st.size shouldBe 6 + st.size shouldBe 7 ((st[0] as VarDecl).value as NumericLiteral).number shouldBe 0x2000 ((st[1] as VarDecl).value as NumericLiteral).number shouldBe 0x9e00 ((st[2] as VarDecl).value as NumericLiteral).number shouldBe 0x9e00+2*30 diff --git a/compiler/test/ast/TestProg8Parser.kt b/compiler/test/ast/TestProg8Parser.kt index 3fbc2aaae..50979a02f 100644 --- a/compiler/test/ast/TestProg8Parser.kt +++ b/compiler/test/ast/TestProg8Parser.kt @@ -907,7 +907,7 @@ class TestProg8Parser: FunSpec( { """ val result = compileText(C64Target(), false, text, outputDir, writeAssembly = false)!! val start = result.compilerAst.entrypoint - val containmentChecks = start.statements.takeLast(4) + val containmentChecks = start.statements.takeLast(5) (containmentChecks[0] as IfElse).condition shouldBe instanceOf() (containmentChecks[1] as IfElse).condition shouldBe instanceOf() (containmentChecks[2] as Assignment).value shouldBe instanceOf() @@ -948,7 +948,7 @@ class TestProg8Parser: FunSpec( { """ val result = compileText(C64Target(), false, text, outputDir, writeAssembly = false)!! val stmt = result.compilerAst.entrypoint.statements - stmt.size shouldBe 12 + stmt.size shouldBe 13 val var1 = stmt[0] as VarDecl var1.sharedWithAsm shouldBe true var1.zeropage shouldBe ZeropageWish.REQUIRE_ZEROPAGE @@ -998,7 +998,7 @@ main { ; curly braces without newline sub start () { foo() derp() other() } sub foo() { cx16.r0++ } - asmsub derp() { %asm {{ nop }} %ir {{ load.b r0,1 }} } + asmsub derp() { %asm {{ nop rts }} %ir {{ load.b r0,1 return }} } ; curly braces on next line sub other() @@ -1014,6 +1014,7 @@ main { {{ txa tay + rts }} } @@ -1022,6 +1023,7 @@ main { %ir {{ load.b r0,1 + return }} } }""" @@ -1042,7 +1044,7 @@ main { }""" val result = compileText(VMTarget(), false, src, outputDir, writeAssembly = false)!! val st = result.compilerAst.entrypoint.statements - st.size shouldBe 8 + st.size shouldBe 9 val assigns = st.filterIsInstance() (assigns[0].value as NumericLiteral).number shouldBe 12345 (assigns[1].value as NumericLiteral).number shouldBe 0xffee @@ -1053,11 +1055,11 @@ main { test("oneliner") { val src=""" main { sub start() { cx16.r0++ cx16.r1++ } } - other { asmsub thing() { %asm {{ inx }} } } + other { asmsub thing() { %asm {{ inx rts }} } } """ val result = compileText(VMTarget(), false, src, outputDir, writeAssembly = false)!! val st = result.compilerAst.entrypoint.statements - st.size shouldBe 2 + st.size shouldBe 3 } }) diff --git a/compiler/test/ast/TestVariousCompilerAst.kt b/compiler/test/ast/TestVariousCompilerAst.kt index 7187aa445..e30bcebb9 100644 --- a/compiler/test/ast/TestVariousCompilerAst.kt +++ b/compiler/test/ast/TestVariousCompilerAst.kt @@ -271,28 +271,6 @@ main { errors.errors[0] shouldContain "has result value" errors.errors[1] shouldContain "has result value" } - - test("missing return value is not a syntax error if there's an external goto") { - val src=""" -main { - sub start() { - cx16.r0 = runit1() - runit2() - } - - sub runit1() -> uword { - repeat { - cx16.r0++ - goto runit2 - } - } - - sub runit2() { - cx16.r0++ - } -}""" - compileText(C64Target(), optimize=false, src, outputDir, writeAssembly=false) shouldNotBe null - } } context("variable declarations") { @@ -355,7 +333,7 @@ main { }""" val result = compileText(Cx16Target(), optimize=true, src, outputDir, writeAssembly=false)!! val st = result.compilerAst.entrypoint.statements - st.size shouldBe 12 + st.size shouldBe 13 st[0] shouldBe instanceOf() // x st[2] shouldBe instanceOf() // y st[4] shouldBe instanceOf() // z @@ -429,7 +407,7 @@ main { errors.warnings.all { "dirty variable" in it } shouldBe true val start = result.compilerAst.entrypoint val st = start.statements - st.size shouldBe 9 + st.size shouldBe 10 val assignments = st.filterIsInstance() assignments.size shouldBe 2 assignments[0].target.identifier?.nameInSource shouldBe listOf("locwi") @@ -523,7 +501,7 @@ main { }""" val result = compileText(C64Target(), optimize=false, src, outputDir, writeAssembly=false)!! val stmts = result.compilerAst.entrypoint.statements - stmts.size shouldBe 7 + stmts.size shouldBe 8 val assign1expr = (stmts[3] as Assignment).value as BinaryExpression val assign2expr = (stmts[5] as Assignment).value as BinaryExpression assign1expr.operator shouldBe "<<" @@ -554,7 +532,7 @@ main { }""" val result = compileText(C64Target(), optimize=false, src, outputDir, writeAssembly=false)!! val stmts = result.compilerAst.entrypoint.statements - stmts.size shouldBe 9 + stmts.size shouldBe 10 } test("alternative notation for negative containment check") { @@ -570,7 +548,7 @@ main { """ val result = compileText(C64Target(), optimize=false, src, outputDir, writeAssembly=false)!! val stmts = result.compilerAst.entrypoint.statements - stmts.size shouldBe 4 + stmts.size shouldBe 5 val value1 = (stmts[2] as Assignment).value as PrefixExpression val value2 = (stmts[3] as Assignment).value as PrefixExpression value1.operator shouldBe "not" @@ -764,7 +742,7 @@ main { }""" val result=compileText(VMTarget(), optimize=true, src, outputDir, writeAssembly=false)!! val st = result.compilerAst.entrypoint.statements - st.size shouldBe 11 + st.size shouldBe 12 val ifCond = (st[8] as IfElse).condition as BinaryExpression ifCond.operator shouldBe ">=" @@ -792,7 +770,7 @@ main { val result=compileText(Cx16Target(), optimize=false, src, outputDir, writeAssembly=false)!! val st = result.compilerAst.entrypoint.statements - st.size shouldBe 6 + st.size shouldBe 7 val value = (st[5] as Assignment).value as BinaryExpression value.operator shouldBe "%" } @@ -838,10 +816,9 @@ main { }""" val result = compileText(VMTarget(), optimize=true, src, outputDir, writeAssembly=false)!! val st = result.compilerAst.entrypoint.statements - st.size shouldBe 8 - val assignUbbVal = ((st[5] as Assignment).value as TypecastExpression) - assignUbbVal.type shouldBe BaseDataType.UBYTE - assignUbbVal.expression shouldBe instanceOf() + st.size shouldBe 9 + val assignUbbVal = (st[5] as Assignment).value as IdentifierReference + assignUbbVal.inferType(result.compilerAst) shouldBe InferredTypes.knownFor(BaseDataType.BYTE) val assignVaddr = (st[7] as Assignment).value as FunctionCallExpression assignVaddr.target.nameInSource shouldBe listOf("mkword") val tc = assignVaddr.args[0] as TypecastExpression @@ -970,6 +947,7 @@ main { if cx16.r0==0 return cx16.r0+cx16.r1 defer cx16.r2++ + return 999 } }""" val result = compileText(Cx16Target(), optimize=true, src, outputDir, writeAssembly=true)!! diff --git a/compiler/test/codegeneration/TestVariables.kt b/compiler/test/codegeneration/TestVariables.kt index 203c192b9..61c57f3dd 100644 --- a/compiler/test/codegeneration/TestVariables.kt +++ b/compiler/test/codegeneration/TestVariables.kt @@ -206,7 +206,7 @@ main { }""" val result = compileText(C64Target(), false, src, outputDir, writeAssembly = false)!!.compilerAst val st = result.entrypoint.statements - st.size shouldBe 8 + st.size shouldBe 9 st[0] shouldBe instanceOf() st[1] shouldBe instanceOf() st[2] shouldBe instanceOf() diff --git a/compilerAst/src/prog8/ast/AstToplevel.kt b/compilerAst/src/prog8/ast/AstToplevel.kt index 0ccd45ae8..cb594ba0b 100644 --- a/compilerAst/src/prog8/ast/AstToplevel.kt +++ b/compilerAst/src/prog8/ast/AstToplevel.kt @@ -7,7 +7,9 @@ import prog8.ast.expressions.NumericLiteral import prog8.ast.statements.* import prog8.ast.walk.AstWalker import prog8.ast.walk.IAstVisitor -import prog8.code.core.* +import prog8.code.core.BaseDataType +import prog8.code.core.Encoding +import prog8.code.core.Position import prog8.code.source.SourceCode @@ -135,6 +137,26 @@ interface IStatementContainer { return null } + fun hasReturnStatement(): Boolean { + fun hasReturnStatement(stmt: Statement): Boolean { + when(stmt) { + is AnonymousScope -> return stmt.statements.any { hasReturnStatement(it) } + is ForLoop -> return stmt.body.hasReturnStatement() + is IfElse -> return stmt.truepart.hasReturnStatement() || stmt.elsepart.hasReturnStatement() + is WhileLoop -> return stmt.body.hasReturnStatement() + is RepeatLoop -> return stmt.body.hasReturnStatement() + is UntilLoop -> return stmt.body.hasReturnStatement() + is When -> return stmt.choices.any { it.statements.hasReturnStatement() } + is ConditionalBranch -> return stmt.truepart.hasReturnStatement() || stmt.elsepart.hasReturnStatement() + is UnrollLoop -> return stmt.body.hasReturnStatement() + is Return -> return true + else -> return false + } + } + + return statements.any { hasReturnStatement(it) } + } + val allDefinedSymbols: Sequence> get() { return statements.asSequence().filterIsInstance().map { Pair(it.name, it as Statement) }