diff --git a/compiler/src/prog8/CompilerMain.kt b/compiler/src/prog8/CompilerMain.kt index 2f6035c0c..e27d32721 100644 --- a/compiler/src/prog8/CompilerMain.kt +++ b/compiler/src/prog8/CompilerMain.kt @@ -48,6 +48,7 @@ private fun compileMain(args: Array) { var writeVmCode = false var writeAssembly = true var optimize = true + var optimizeInlining = true for (arg in args) { if(arg=="-emu") emulatorToStart = "x64" @@ -59,6 +60,8 @@ private fun compileMain(args: Array) { writeAssembly = false else if(arg=="-noopt") optimize = false + else if(arg=="-nooptinline") + optimizeInlining = false else if(!arg.startsWith("-")) moduleFile = arg else @@ -118,7 +121,7 @@ private fun compileMain(args: Array) { while (true) { // keep optimizing expressions and statements until no more steps remain val optsDone1 = programAst.simplifyExpressions() - val optsDone2 = programAst.optimizeStatements() + val optsDone2 = programAst.optimizeStatements(optimizeInlining) if (optsDone1 + optsDone2 == 0) break } @@ -225,12 +228,13 @@ fun determineCompilationOptions(program: Program): CompilationOptions { private fun usage() { System.err.println("Missing argument(s):") - System.err.println(" [-emu] auto-start the 'x64' C-64 emulator after successful compilation") - System.err.println(" [-emu2] auto-start the 'x64sc' C-64 emulator after successful compilation") - System.err.println(" [-writevm] write intermediate vm code to a file as well") - System.err.println(" [-noasm] don't create assembly code") - System.err.println(" [-vm] launch the prog8 virtual machine instead of the compiler") - System.err.println(" [-noopt] don't perform optimizations") - System.err.println(" modulefile main module file to compile") + System.err.println(" [-emu] auto-start the 'x64' C-64 emulator after successful compilation") + System.err.println(" [-emu2] auto-start the 'x64sc' C-64 emulator after successful compilation") + System.err.println(" [-writevm] write intermediate vm code to a file as well") + System.err.println(" [-noasm] don't create assembly code") + System.err.println(" [-vm] launch the prog8 virtual machine instead of the compiler") + System.err.println(" [-noopt] don't perform any optimizations") + System.err.println(" [-nooptinline] don't perform subroutine inlining optimizations") + System.err.println(" modulefile main module file to compile") exitProcess(1) } diff --git a/compiler/src/prog8/ast/AST.kt b/compiler/src/prog8/ast/AST.kt index af188c938..61eedd084 100644 --- a/compiler/src/prog8/ast/AST.kt +++ b/compiler/src/prog8/ast/AST.kt @@ -191,14 +191,14 @@ interface IAstProcessor { fun process(ifStatement: IfStatement): IStatement { ifStatement.condition = ifStatement.condition.process(this) - ifStatement.truepart = ifStatement.truepart.process(this) - ifStatement.elsepart = ifStatement.elsepart.process(this) + ifStatement.truepart = ifStatement.truepart.process(this) as AnonymousScope + ifStatement.elsepart = ifStatement.elsepart.process(this) as AnonymousScope return ifStatement } fun process(branchStatement: BranchStatement): IStatement { - branchStatement.truepart = branchStatement.truepart.process(this) - branchStatement.elsepart = branchStatement.elsepart.process(this) + branchStatement.truepart = branchStatement.truepart.process(this) as AnonymousScope + branchStatement.elsepart = branchStatement.elsepart.process(this) as AnonymousScope return branchStatement } @@ -240,19 +240,19 @@ interface IAstProcessor { fun process(forLoop: ForLoop): IStatement { forLoop.loopVar?.process(this) forLoop.iterable = forLoop.iterable.process(this) - forLoop.body = forLoop.body.process(this) + forLoop.body = forLoop.body.process(this) as AnonymousScope return forLoop } fun process(whileLoop: WhileLoop): IStatement { whileLoop.condition = whileLoop.condition.process(this) - whileLoop.body = whileLoop.body.process(this) + whileLoop.body = whileLoop.body.process(this) as AnonymousScope return whileLoop } fun process(repeatLoop: RepeatLoop): IStatement { repeatLoop.untilCondition = repeatLoop.untilCondition.process(this) - repeatLoop.body = repeatLoop.body.process(this) + repeatLoop.body = repeatLoop.body.process(this) as AnonymousScope return repeatLoop } @@ -274,7 +274,7 @@ interface IAstProcessor { return assignTarget } - fun process(scope: AnonymousScope): AnonymousScope { + fun process(scope: AnonymousScope): IStatement { scope.statements = scope.statements.asSequence().map { it.process(this) }.toMutableList() return scope } @@ -513,8 +513,8 @@ class Module(override val name: String, val source: Path) : Node, INameScope { override lateinit var parent: Node lateinit var program: Program + val importedBy = mutableListOf() val imports = mutableSetOf() - val importedBy = mutableSetOf() override fun linkParents(parent: Node) { this.parent=parent @@ -1695,9 +1695,10 @@ class Subroutine(override val name: String, val isAsmSubroutine: Boolean, override var statements: MutableList, override val position: Position) : IStatement, INameScope { + var keepAlways: Boolean = false override lateinit var parent: Node - val calledBy = mutableSetOf() + val calledBy = mutableListOf() val calls = mutableSetOf() val scopedname: String by lazy { makeScopedName(name) } diff --git a/compiler/src/prog8/ast/AstIdentifiersChecker.kt b/compiler/src/prog8/ast/AstIdentifiersChecker.kt index 3a2a276b6..0be53576c 100644 --- a/compiler/src/prog8/ast/AstIdentifiersChecker.kt +++ b/compiler/src/prog8/ast/AstIdentifiersChecker.kt @@ -51,8 +51,7 @@ fun Program.checkIdentifiers() { private class AstIdentifiersChecker(private val namespace: INameScope) : IAstProcessor { private val checkResult: MutableList = mutableListOf() - var blocks: MutableMap = mutableMapOf() - private set + private var blocks: MutableMap = mutableMapOf() fun result(): List { return checkResult diff --git a/compiler/src/prog8/ast/StmtReorderer.kt b/compiler/src/prog8/ast/StmtReorderer.kt index 5b6d0c71e..73541344c 100644 --- a/compiler/src/prog8/ast/StmtReorderer.kt +++ b/compiler/src/prog8/ast/StmtReorderer.kt @@ -224,10 +224,10 @@ private class VarInitValueAndAddressOfCreator(private val namespace: INameScope) // Also takes care to insert AddressOf (&) expression where required (string params to a UWORD function param etc). - private val vardeclsToAdd = mutableMapOf>() override fun process(module: Module) { + vardeclsToAdd.clear() super.process(module) // add any new vardecls to the various scopes @@ -254,8 +254,9 @@ private class VarInitValueAndAddressOfCreator(private val namespace: INameScope) } else declvalue + val identifierName = listOf(decl.name) // // TODO this was: (scoped name) decl.scopedname.split(".") return VariableInitializationAssignment( - AssignTarget(null, IdentifierReference(decl.scopedname.split("."), decl.position), null, null, decl.position), + AssignTarget(null, IdentifierReference(identifierName, decl.position), null, null, decl.position), null, value, decl.position diff --git a/compiler/src/prog8/optimizing/CallGraphBuilder.kt b/compiler/src/prog8/optimizing/CallGraph.kt similarity index 86% rename from compiler/src/prog8/optimizing/CallGraphBuilder.kt rename to compiler/src/prog8/optimizing/CallGraph.kt index 382c1b0e5..a2dcd6844 100644 --- a/compiler/src/prog8/optimizing/CallGraphBuilder.kt +++ b/compiler/src/prog8/optimizing/CallGraph.kt @@ -4,12 +4,16 @@ import prog8.ast.* import prog8.compiler.loadAsmIncludeFile -class CallGraphBuilder(private val program: Program): IAstProcessor { +class CallGraph(private val program: Program): IAstProcessor { - private val modulesImporting = mutableMapOf>().withDefault { mutableSetOf() } - private val modulesImportedBy = mutableMapOf>().withDefault { mutableSetOf() } - private val subroutinesCalling = mutableMapOf>().withDefault { mutableSetOf() } - private val subroutinesCalledBy = mutableMapOf>().withDefault { mutableSetOf() } + private val modulesImporting = mutableMapOf>().withDefault { mutableListOf() } + private val modulesImportedBy = mutableMapOf>().withDefault { mutableListOf() } + private val subroutinesCalling = mutableMapOf>().withDefault { mutableListOf() } + private val subroutinesCalledBy = mutableMapOf>().withDefault { mutableListOf() } + + init { + process(program) + } fun forAllSubroutines(scope: INameScope, sub: (s: Subroutine) -> Unit) { fun findSubs(scope: INameScope) { @@ -67,7 +71,7 @@ class CallGraphBuilder(private val program: Program): IAstProcessor { if(otherSub!=null) { functionCall.definingSubroutine()?.let { thisSub -> subroutinesCalling[thisSub] = subroutinesCalling.getValue(thisSub).plus(otherSub) - subroutinesCalledBy[otherSub] = subroutinesCalledBy.getValue(otherSub).plus(thisSub) + subroutinesCalledBy[otherSub] = subroutinesCalledBy.getValue(otherSub).plus(functionCall) } } return super.process(functionCall) @@ -78,7 +82,7 @@ class CallGraphBuilder(private val program: Program): IAstProcessor { if(otherSub!=null) { functionCallStatement.definingSubroutine()?.let { thisSub -> subroutinesCalling[thisSub] = subroutinesCalling.getValue(thisSub).plus(otherSub) - subroutinesCalledBy[otherSub] = subroutinesCalledBy.getValue(otherSub).plus(thisSub) + subroutinesCalledBy[otherSub] = subroutinesCalledBy.getValue(otherSub).plus(functionCallStatement) } } return super.process(functionCallStatement) @@ -89,7 +93,7 @@ class CallGraphBuilder(private val program: Program): IAstProcessor { if(otherSub!=null) { jump.definingSubroutine()?.let { thisSub -> subroutinesCalling[thisSub] = subroutinesCalling.getValue(thisSub).plus(otherSub) - subroutinesCalledBy[otherSub] = subroutinesCalledBy.getValue(otherSub).plus(thisSub) + subroutinesCalledBy[otherSub] = subroutinesCalledBy.getValue(otherSub).plus(jump) } } return super.process(jump) @@ -102,7 +106,7 @@ class CallGraphBuilder(private val program: Program): IAstProcessor { return super.process(inlineAssembly) } - private fun scanAssemblyCode(asm: String, context: Node, scope: INameScope) { + private fun scanAssemblyCode(asm: String, context: IStatement, scope: INameScope) { val asmJumpRx = Regex("""[\-+a-zA-Z0-9_ \t]+(jmp|jsr)[ \t]+(\S+).*""", RegexOption.IGNORE_CASE) val asmRefRx = Regex("""[\-+a-zA-Z0-9_ \t]+(...)[ \t]+(\S+).*""", RegexOption.IGNORE_CASE) asm.lines().forEach { line -> @@ -113,13 +117,13 @@ class CallGraphBuilder(private val program: Program): IAstProcessor { val node = program.namespace.lookup(jumptarget.split('.'), context) if (node is Subroutine) { subroutinesCalling[scope] = subroutinesCalling.getValue(scope).plus(node) - subroutinesCalledBy[node] = subroutinesCalledBy.getValue(node).plus(scope) + subroutinesCalledBy[node] = subroutinesCalledBy.getValue(node).plus(context) } else if(jumptarget.contains('.')) { // maybe only the first part already refers to a subroutine val node2 = program.namespace.lookup(listOf(jumptarget.substringBefore('.')), context) if (node2 is Subroutine) { subroutinesCalling[scope] = subroutinesCalling.getValue(scope).plus(node2) - subroutinesCalledBy[node2] = subroutinesCalledBy.getValue(node2).plus(scope) + subroutinesCalledBy[node2] = subroutinesCalledBy.getValue(node2).plus(context) } } } @@ -131,7 +135,7 @@ class CallGraphBuilder(private val program: Program): IAstProcessor { val node = program.namespace.lookup(listOf(target.substringBefore('.')), context) if (node is Subroutine) { subroutinesCalling[scope] = subroutinesCalling.getValue(scope).plus(node) - subroutinesCalledBy[node] = subroutinesCalledBy.getValue(node).plus(scope) + subroutinesCalledBy[node] = subroutinesCalledBy.getValue(node).plus(context) } } } diff --git a/compiler/src/prog8/optimizing/Extensions.kt b/compiler/src/prog8/optimizing/Extensions.kt index b8fe4649d..abc841d95 100644 --- a/compiler/src/prog8/optimizing/Extensions.kt +++ b/compiler/src/prog8/optimizing/Extensions.kt @@ -1,7 +1,6 @@ package prog8.optimizing -import prog8.ast.AstException -import prog8.ast.Program +import prog8.ast.* import prog8.parser.ParsingFailedError @@ -27,12 +26,16 @@ fun Program.constantFold() { } -fun Program.optimizeStatements(): Int { - val optimizer = StatementOptimizer(this) +fun Program.optimizeStatements(optimizeInlining: Boolean): Int { + val optimizer = StatementOptimizer(this, optimizeInlining) optimizer.process(this) - for(stmt in optimizer.statementsToRemove) { - val scope=stmt.definingScope() - scope.remove(stmt) + for(scope in optimizer.scopesToFlatten.reversed()) { + val namescope = scope.parent as INameScope + val idx = namescope.statements.indexOf(scope as IStatement) + if(idx>=0) { + namescope.statements[idx] = NopStatement(scope.position) + namescope.statements.addAll(idx, scope.statements) + } } modules.forEach { it.linkParents() } // re-link in final configuration diff --git a/compiler/src/prog8/optimizing/StatementOptimizer.kt b/compiler/src/prog8/optimizing/StatementOptimizer.kt index 00727c678..206b03772 100644 --- a/compiler/src/prog8/optimizing/StatementOptimizer.kt +++ b/compiler/src/prog8/optimizing/StatementOptimizer.kt @@ -8,28 +8,88 @@ import kotlin.math.floor /* todo: subroutines with 1 or 2 byte args or 1 word arg can be converted to asm sub calling convention (args in registers) - - todo: implement usage counters for variables (locals and heap), blocks. Remove if count is zero. - - todo inline subroutines that are called exactly once (regardless of their size) - todo inline subroutines that are only called a few times (max 3?) (if < 20 statements) - todo inline all subroutines that are "very small" (0-3 statements) - + todo: implement usage counters for labels, variables (locals and heap), blocks. Remove if count is zero. todo analyse for unreachable code and remove that (f.i. code after goto or return that has no label so can never be jumped to) + print warning about this */ -class StatementOptimizer(private val program: Program) : IAstProcessor { +class StatementOptimizer(private val program: Program, private val optimizeInlining: Boolean) : IAstProcessor { var optimizationsDone: Int = 0 private set - var statementsToRemove = mutableListOf() - private set + var scopesToFlatten = mutableListOf() + private val pureBuiltinFunctions = BuiltinFunctions.filter { it.value.pure } + companion object { + private var generatedLabelSequenceNumber = 0 + } override fun process(program: Program) { - val callgraph = CallGraphBuilder(program) - callgraph.process(program) + val callgraph = CallGraph(program) + removeUnusedCode(callgraph) + if(optimizeInlining) { + inlineSubroutines(callgraph) + } + super.process(program) + } + private fun inlineSubroutines(callgraph: CallGraph) { + val entrypoint = program.entrypoint() + program.modules.forEach { + callgraph.forAllSubroutines(it) { sub -> + if(sub!==entrypoint && !sub.isAsmSubroutine) { + if (sub.statements.size <= 3) { + sub.calledBy.toList().forEach { caller -> inlineSubroutine(sub, caller) } + } else if (sub.calledBy.size==1 && sub.statements.size < 50) { + inlineSubroutine(sub, sub.calledBy[0]) + } else if(sub.calledBy.size<=3 && sub.statements.size < 10) { + sub.calledBy.toList().forEach { caller -> inlineSubroutine(sub, caller) } + } + } + } + } + } + + private fun inlineSubroutine(sub: Subroutine, caller: Node) { + // if the sub is called multiple times from the same scope, we can't inline (would result in duplicate definitions) + // (unless we add a sequence number to all vars/labels and references to them in the inlined code, but I skip that for now) + val scope = caller.definingScope() + if(sub.calledBy.count { it.definingScope()===scope } > 1) + return + if(caller !is IFunctionCall || caller !is IStatement || sub.statements.any { it is Subroutine }) + return + + if(sub.parameters.isEmpty() && sub.returntypes.isEmpty()) { + // sub without params and without return value can be easily inlined + val parent = caller.parent as INameScope + val inlined = AnonymousScope(sub.statements.toMutableList(), caller.position) + parent.statements[parent.statements.indexOf(caller)] = inlined + // replace return statements in the inlined sub by a jump to the end of it + var endlabel = inlined.statements.last() as? Label + if(endlabel==null) { + endlabel = makeLabel("_prog8_auto_sub_end", inlined.statements.last().position) + inlined.statements.add(endlabel) + endlabel.parent = inlined + } + val returns = inlined.statements.withIndex().filter { iv -> iv.value is Return }.map { iv -> Pair(iv.index, iv.value as Return)} + for(returnIdx in returns) { + assert(returnIdx.second.values.isEmpty()) + val jump = Jump(null, IdentifierReference(listOf(endlabel.name), returnIdx.second.position), null, returnIdx.second.position) + inlined.statements[returnIdx.first] = jump + } + inlined.linkParents(caller.parent) + sub.calledBy.remove(caller) // if there are no callers left, the sub will be removed automatically later + optimizationsDone++ + } else { + // TODO inline subroutine that has params or returnvalues or both + } + } + + private fun makeLabel(name: String, position: Position): Label { + generatedLabelSequenceNumber++ + return Label("${name}_$generatedLabelSequenceNumber", position) + } + + private fun removeUnusedCode(callgraph: CallGraph) { // TODO remove unused variables (local and global) // remove all subroutines that aren't called, or are empty @@ -37,12 +97,12 @@ class StatementOptimizer(private val program: Program) : IAstProcessor { val entrypoint = program.entrypoint() program.modules.forEach { callgraph.forAllSubroutines(it) { sub -> - if(sub !== entrypoint && !sub.keepAlways && (sub.calledBy.isEmpty() || (sub.containsNoCodeNorVars() && !sub.isAsmSubroutine))) + if (sub !== entrypoint && !sub.keepAlways && (sub.calledBy.isEmpty() || (sub.containsNoCodeNorVars() && !sub.isAsmSubroutine))) removeSubroutines.add(sub) } } - if(removeSubroutines.isNotEmpty()) { + if (removeSubroutines.isNotEmpty()) { removeSubroutines.forEach { it.definingScope().statements.remove(it) } @@ -55,7 +115,7 @@ class StatementOptimizer(private val program: Program) : IAstProcessor { removeBlocks.add(block) } - if(removeBlocks.isNotEmpty()) { + if (removeBlocks.isNotEmpty()) { removeBlocks.forEach { it.definingScope().statements.remove(it) } } @@ -66,18 +126,16 @@ class StatementOptimizer(private val program: Program) : IAstProcessor { removeModules.add(it) } - if(removeModules.isNotEmpty()) { + if (removeModules.isNotEmpty()) { println("[debug] removing ${removeModules.size} empty/unused modules") program.modules.removeAll(removeModules) } - - super.process(program) } override fun process(block: Block): IStatement { if(block.containsNoCodeNorVars()) { optimizationsDone++ - statementsToRemove.add(block) + return NopStatement(block.position) } return super.process(block) } @@ -88,7 +146,7 @@ class StatementOptimizer(private val program: Program) : IAstProcessor { if(subroutine.asmAddress==null) { if(subroutine.containsNoCodeNorVars()) { optimizationsDone++ - statementsToRemove.add(subroutine) + return NopStatement(subroutine.position) } } @@ -161,8 +219,8 @@ class StatementOptimizer(private val program: Program) : IAstProcessor { val functionName = functionCallStatement.target.nameInSource[0] if (functionName in pureBuiltinFunctions) { printWarning("statement has no effect (function return value is discarded)", functionCallStatement.position) - statementsToRemove.add(functionCallStatement) - return functionCallStatement + optimizationsDone++ + return NopStatement(functionCallStatement.position) } } @@ -238,9 +296,8 @@ class StatementOptimizer(private val program: Program) : IAstProcessor { super.process(ifStatement) if(ifStatement.truepart.containsNoCodeNorVars() && ifStatement.elsepart.containsNoCodeNorVars()) { - statementsToRemove.add(ifStatement) optimizationsDone++ - return ifStatement + return NopStatement(ifStatement.position) } if(ifStatement.truepart.containsNoCodeNorVars() && ifStatement.elsepart.containsCodeOrVars()) { @@ -273,16 +330,14 @@ class StatementOptimizer(private val program: Program) : IAstProcessor { super.process(forLoop) if(forLoop.body.containsNoCodeNorVars()) { // remove empty for loop - statementsToRemove.add(forLoop) optimizationsDone++ - return forLoop + return NopStatement(forLoop.position) } else if(forLoop.body.statements.size==1) { val loopvar = forLoop.body.statements[0] as? VarDecl if(loopvar!=null && loopvar.name==forLoop.loopVar?.nameInSource?.singleOrNull()) { // remove empty for loop - statementsToRemove.add(forLoop) optimizationsDone++ - return forLoop + return NopStatement(forLoop.position) } } @@ -392,6 +447,17 @@ class StatementOptimizer(private val program: Program) : IAstProcessor { return first } } + + // if the jump is to the next statement, remove the jump + val scope = jump.definingScope() + val label = jump.identifier?.targetStatement(scope) + if(label!=null) { + if(scope.statements.indexOf(label) == scope.statements.indexOf(jump)+1) { + optimizationsDone++ + return NopStatement(jump.position) + } + } + return jump } @@ -405,7 +471,7 @@ class StatementOptimizer(private val program: Program) : IAstProcessor { optimizationsDone++ return NopStatement(assignment.position) } - val targetDt = target.determineDatatype(program, assignment)!! + val targetDt = target.determineDatatype(program, assignment) val bexpr=assignment.value as? BinaryExpression if(bexpr!=null) { val cv = bexpr.right.constValue(program)?.asNumericValue?.toDouble() @@ -528,14 +594,29 @@ class StatementOptimizer(private val program: Program) : IAstProcessor { return super.process(assignment) } - override fun process(scope: AnonymousScope): AnonymousScope { + override fun process(scope: AnonymousScope): IStatement { val linesToRemove = deduplicateAssignments(scope.statements) if(linesToRemove.isNotEmpty()) { linesToRemove.reversed().forEach{scope.statements.removeAt(it)} } + + if(scope.parent is INameScope) { + scopesToFlatten.add(scope) // get rid of the anonymous scope + } + return super.process(scope) } + override fun process(label: Label): IStatement { + // remove duplicate labels + val stmts = label.definingScope().statements + val startIdx = stmts.indexOf(label) + if(startIdx<(stmts.size-1) && stmts[startIdx+1] == label) + return NopStatement(label.position) + + return super.process(label) + } + private fun same(target: AssignTarget, value: IExpression): Boolean { return when { target.memoryAddress!=null -> false diff --git a/examples/test.p8 b/examples/test.p8 index aa7ef5c8c..29c3e241f 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -1,28 +1,61 @@ %zeropage basicsafe -%option enable_floats -%import c64flt ~ main { - float[] fa = [1.1,2.2,3.3] - ubyte[] uba = [10,2,3,4] - byte[] ba = [-10,2,3,4] - uword[] uwa = [100,20,30,40] - word[] wa = [-100,20,30,40] - sub start() { - float a - a=avg([1,2,3,4]) - c64flt.print_f(a) - c64.CHROUT('\n') - a=avg([100,200,300,400]) - c64flt.print_f(a) - c64.CHROUT('\n') - a=avg([1.1,2.2,3.3,4.4]) - c64flt.print_f(a) + greeting() + + ubyte square = stuff.function(12) + + c64scr.print_ub(square) c64.CHROUT('\n') + stuff.name() + stuff.name() + stuff.bye() + + abs(4) + abs(4) + abs(4) + abs(4) + abs(4) + foobar() + foobar() + foobar() + foobar() + + + if(false) { + } else { + + } + } + + + sub foobar() { + } + + sub greeting() { + c64scr.print("hello\n") + } +} + + +~ stuff { + + sub function(ubyte v) -> ubyte { + return v*v + } + + sub name() { + c64scr.print("name\n") + } + + sub bye() { + c64scr.print("bye\n") + } + }