diff --git a/compiler/src/prog8/compiler/BeforeAsmGenerationAstChanger.kt b/compiler/src/prog8/compiler/BeforeAsmGenerationAstChanger.kt index 5761b5591..f0816cfc5 100644 --- a/compiler/src/prog8/compiler/BeforeAsmGenerationAstChanger.kt +++ b/compiler/src/prog8/compiler/BeforeAsmGenerationAstChanger.kt @@ -120,13 +120,13 @@ internal class BeforeAsmGenerationAstChanger(val program: Program, val errors: I // and if an assembly block doesn't contain a rts/rti, and some other situations. val mods = mutableListOf() val returnStmt = Return(null, subroutine.position) - if (subroutine.asmAddress == null - && !subroutine.inline - && subroutine.statements.isNotEmpty() - && subroutine.amountOfRtsInAsm() == 0 + if (subroutine.asmAddress == null && !subroutine.inline) { + if(subroutine.statements.isEmpty() || + (subroutine.amountOfRtsInAsm() == 0 && subroutine.statements.lastOrNull { it !is VarDecl } !is Return - && subroutine.statements.last() !is Subroutine) { - mods += IAstModification.InsertLast(returnStmt, subroutine) + && subroutine.statements.last() !is Subroutine)) { + mods += IAstModification.InsertLast(returnStmt, subroutine) + } } // precede a subroutine with a return to avoid falling through into the subroutine from code above it diff --git a/compiler/src/prog8/optimizer/UnusedCodeRemover.kt b/compiler/src/prog8/optimizer/UnusedCodeRemover.kt index 1e362606e..45dbec5b3 100644 --- a/compiler/src/prog8/optimizer/UnusedCodeRemover.kt +++ b/compiler/src/prog8/optimizer/UnusedCodeRemover.kt @@ -81,20 +81,20 @@ internal class UnusedCodeRemover(private val program: Program, val forceOutput = "force_output" in subroutine.definingBlock.options() if (subroutine !== program.entrypoint && !forceOutput && !subroutine.inline && !subroutine.isAsmSubroutine) { if(callgraph.unused(subroutine)) { + if(subroutine.containsNoCodeNorVars) { + if(!subroutine.definingModule.isLibrary) + errors.warn("removing empty subroutine '${subroutine.name}'", subroutine.position) + val removals = mutableListOf(IAstModification.Remove(subroutine, subroutine.definingScope)) + callgraph.calledBy[subroutine]?.let { + for(node in it) + removals.add(IAstModification.Remove(node, node.definingScope)) + } + return removals + } if(!subroutine.definingModule.isLibrary) errors.warn("removing unused subroutine '${subroutine.name}'", subroutine.position) return listOf(IAstModification.Remove(subroutine, subroutine.definingScope)) } - if(subroutine.containsNoCodeNorVars) { - if(!subroutine.definingModule.isLibrary) - errors.warn("removing empty subroutine '${subroutine.name}'", subroutine.position) - val removals = mutableListOf(IAstModification.Remove(subroutine, subroutine.definingScope)) - callgraph.calledBy[subroutine]?.let { - for(node in it) - removals.add(IAstModification.Remove(node, node.definingScope)) - } - return removals - } } val removeDoubleAssignments = deduplicateAssignments(subroutine.statements) diff --git a/compiler/test/TestCallgraph.kt b/compiler/test/TestCallgraph.kt new file mode 100644 index 000000000..16e9bf6a1 --- /dev/null +++ b/compiler/test/TestCallgraph.kt @@ -0,0 +1,91 @@ +package prog8tests + +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.TestInstance +import prog8.ast.statements.Block +import prog8.ast.statements.Subroutine +import prog8.compiler.target.C64Target +import prog8.optimizer.CallGraph +import prog8tests.helpers.assertSuccess +import prog8tests.helpers.compileText +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +class TestCallgraph { + @Test + fun testGraphForEmptySubs() { + val sourcecode = """ + %import string + main { + sub start() { + } + sub empty() { + } + } + """ + val result = compileText(C64Target, false, sourcecode).assertSuccess() + val graph = CallGraph(result.programAst) + + assertEquals(1, graph.imports.size) + assertEquals(1, graph.importedBy.size) + val toplevelModule = result.programAst.toplevelModule + val importedModule = graph.imports.getValue(toplevelModule).single() + assertEquals("string", importedModule.name) + val importedBy = graph.importedBy.getValue(importedModule).single() + assertTrue(importedBy.name.startsWith("on_the_fly_test")) + + assertFalse(graph.unused(toplevelModule)) + assertFalse(graph.unused(importedModule)) + + val mainBlock = toplevelModule.statements.filterIsInstance().single() + for(stmt in mainBlock.statements) { + val sub = stmt as Subroutine + assertFalse(sub in graph.calls) + assertFalse(sub in graph.calledBy) + + if(sub === result.programAst.entrypoint) + assertFalse(graph.unused(sub), "start() should always be marked as used to avoid having it removed") + else + assertTrue(graph.unused(sub)) + } + } + + @Test + fun testGraphForEmptyButReferencedSub() { + val sourcecode = """ + %import string + main { + sub start() { + uword xx = &empty + xx++ + } + sub empty() { + } + } + """ + val result = compileText(C64Target, false, sourcecode).assertSuccess() + val graph = CallGraph(result.programAst) + + assertEquals(1, graph.imports.size) + assertEquals(1, graph.importedBy.size) + val toplevelModule = result.programAst.toplevelModule + val importedModule = graph.imports.getValue(toplevelModule).single() + assertEquals("string", importedModule.name) + val importedBy = graph.importedBy.getValue(importedModule).single() + assertTrue(importedBy.name.startsWith("on_the_fly_test")) + + assertFalse(graph.unused(toplevelModule)) + assertFalse(graph.unused(importedModule)) + + val mainBlock = toplevelModule.statements.filterIsInstance().single() + val startSub = mainBlock.statements.filterIsInstance().single{it.name=="start"} + val emptySub = mainBlock.statements.filterIsInstance().single{it.name=="empty"} + + assertTrue(startSub in graph.calls, "start 'calls' (references) empty") + assertFalse(emptySub in graph.calls, "empty doesn't call anything") + assertTrue(emptySub in graph.calledBy, "empty gets 'called'") + assertFalse(startSub in graph.calledBy, "start doesn't get called (except as entrypoint ofc.)") + } +} diff --git a/compiler/test/TestOptimization.kt b/compiler/test/TestOptimization.kt new file mode 100644 index 000000000..f3df0da7d --- /dev/null +++ b/compiler/test/TestOptimization.kt @@ -0,0 +1,59 @@ +package prog8tests + +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.TestInstance +import prog8.ast.statements.Block +import prog8.ast.statements.Subroutine +import prog8.compiler.target.C64Target +import prog8tests.helpers.assertSuccess +import prog8tests.helpers.compileText +import kotlin.test.assertEquals +import kotlin.test.assertSame + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +class TestOptimization { + @Test + fun testRemoveEmptySubroutineExceptStart() { + val sourcecode = """ + main { + sub start() { + } + sub empty() { + ; going to be removed + } + } + """ + val result = compileText(C64Target, true, sourcecode).assertSuccess() + val toplevelModule = result.programAst.toplevelModule + val mainBlock = toplevelModule.statements.single() as Block + assertEquals(1, mainBlock.statements.size) + val startSub = mainBlock.statements[0] as Subroutine + assertSame(result.programAst.entrypoint, startSub) + assertEquals("start", startSub.name) + assertEquals(0, startSub.statements.size) + } + + @Test + fun testDontRemoveEmptySubroutineIfItsReferenced() { + val sourcecode = """ + main { + sub start() { + uword xx = &empty + xx++ + } + sub empty() { + ; should not be removed + } + } + """ + val result = compileText(C64Target, true, sourcecode).assertSuccess() + val toplevelModule = result.programAst.toplevelModule + val mainBlock = toplevelModule.statements.single() as Block + val startSub = mainBlock.statements[0] as Subroutine + val emptySub = mainBlock.statements[1] as Subroutine + assertSame(result.programAst.entrypoint, startSub) + assertEquals("start", startSub.name) + assertEquals("empty", emptySub.name) + assertEquals(0, emptySub.statements.size) + } +} diff --git a/examples/test.p8 b/examples/test.p8 index cfa7a4362..f5ba855fa 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -1,18 +1,10 @@ -%import textio - main { sub start() { - ubyte xx + uword address = &irq + ; cx16.set_irq(&irq, false) + address++ + } - when xx { - 2 -> { - } - 3 -> { - } - 50 -> { - } - else -> { - } - } + sub irq() { } }