From a999c23014bb83ead00726bcfacb4613f297b9f1 Mon Sep 17 00:00:00 2001 From: Irmen de Jong Date: Sat, 27 Jun 2020 16:51:25 +0200 Subject: [PATCH] simple subroutine inlining added --- compiler/src/prog8/ast/AstToplevel.kt | 1 + compiler/src/prog8/ast/base/Extensions.kt | 6 +++ .../ast/processing/AstVariousTransforms.kt | 17 +++--- .../prog8/ast/processing/SubroutineInliner.kt | 39 ++++++++++++++ .../src/prog8/ast/statements/AstStatements.kt | 32 +++++++++-- compiler/src/prog8/compiler/Main.kt | 1 + compiler/src/prog8/optimizer/CallGraph.kt | 44 +++++++-------- .../src/prog8/optimizer/StatementOptimizer.kt | 1 - examples/test.p8 | 53 ++++++++----------- 9 files changed, 130 insertions(+), 64 deletions(-) create mode 100644 compiler/src/prog8/ast/processing/SubroutineInliner.kt diff --git a/compiler/src/prog8/ast/AstToplevel.kt b/compiler/src/prog8/ast/AstToplevel.kt index 2d8e91340..4596c7c9d 100644 --- a/compiler/src/prog8/ast/AstToplevel.kt +++ b/compiler/src/prog8/ast/AstToplevel.kt @@ -155,6 +155,7 @@ interface INameScope { } fun containsCodeOrVars() = statements.any { it !is Directive || it.directive == "%asminclude" || it.directive == "%asm"} + fun containsNoVars() = statements.all { it !is VarDecl } fun containsNoCodeNorVars() = !containsCodeOrVars() fun remove(stmt: Statement) { diff --git a/compiler/src/prog8/ast/base/Extensions.kt b/compiler/src/prog8/ast/base/Extensions.kt index a85aadd3c..6f96b403f 100644 --- a/compiler/src/prog8/ast/base/Extensions.kt +++ b/compiler/src/prog8/ast/base/Extensions.kt @@ -26,6 +26,12 @@ internal fun Program.reorderStatements() { reorder.applyModifications() } +internal fun Program.inlineSubroutines() { + val reorder = SubroutineInliner(this) + reorder.visit(this) + reorder.applyModifications() +} + internal fun Program.addTypecasts(errors: ErrorReporter) { val caster = TypecastsAdder(this, errors) caster.visit(this) diff --git a/compiler/src/prog8/ast/processing/AstVariousTransforms.kt b/compiler/src/prog8/ast/processing/AstVariousTransforms.kt index 0eaa771a9..b5760aac8 100644 --- a/compiler/src/prog8/ast/processing/AstVariousTransforms.kt +++ b/compiler/src/prog8/ast/processing/AstVariousTransforms.kt @@ -83,13 +83,16 @@ internal class AstVariousTransforms(private val program: Program) : AstWalker() val symbolsInSub = subroutine.allDefinedSymbols() val namesInSub = symbolsInSub.map{ it.first }.toSet() if(subroutine.asmAddress==null) { - if(subroutine.asmParameterRegisters.isEmpty()) { - return subroutine.parameters - .filter { it.name !in namesInSub } - .map { - val vardecl = ParameterVarDecl(it.name, it.type, subroutine.position) - IAstModification.InsertFirst(vardecl, subroutine) - } + if(subroutine.asmParameterRegisters.isEmpty() && subroutine.parameters.isNotEmpty()) { + val vars = subroutine.statements.filterIsInstance().map { it.name }.toSet() + if(!vars.containsAll(subroutine.parameters.map{it.name})) { + return subroutine.parameters + .filter { it.name !in namesInSub } + .map { + val vardecl = ParameterVarDecl(it.name, it.type, subroutine.position) + IAstModification.InsertFirst(vardecl, subroutine) + } + } } } diff --git a/compiler/src/prog8/ast/processing/SubroutineInliner.kt b/compiler/src/prog8/ast/processing/SubroutineInliner.kt new file mode 100644 index 000000000..ce6555338 --- /dev/null +++ b/compiler/src/prog8/ast/processing/SubroutineInliner.kt @@ -0,0 +1,39 @@ +package prog8.ast.processing + +import prog8.ast.Node +import prog8.ast.Program +import prog8.ast.statements.* +import prog8.optimizer.CallGraph + + +internal class SubroutineInliner(private val program: Program) : AstWalker() { + private val noModifications = emptyList() + private val callgraph = CallGraph(program) + + override fun after(subroutine: Subroutine, parent: Node): Iterable { + + if(!subroutine.isAsmSubroutine && callgraph.calledBy[subroutine]!=null && subroutine.containsCodeOrVars()) { + + // TODO for now, inlined subroutines can't have parameters or local variables - improve this + if(subroutine.parameters.isEmpty() && subroutine.containsNoVars()) { + if (subroutine.countStatements() <= 5) { + if (callgraph.calledBy.getValue(subroutine).size == 1 || !subroutine.statements.any { it.expensiveToInline }) + return inline(subroutine) + } + } + } + return noModifications + } + + private fun inline(subroutine: Subroutine): Iterable { + val calls = callgraph.calledBy.getValue(subroutine) + return calls.map { + call -> IAstModification.ReplaceNode( + call, + AnonymousScope(subroutine.statements, call.position), + call.parent + ) + }.plus(IAstModification.Remove(subroutine, subroutine.parent)) + } + +} diff --git a/compiler/src/prog8/ast/statements/AstStatements.kt b/compiler/src/prog8/ast/statements/AstStatements.kt index 22e81c2e9..c838bea00 100644 --- a/compiler/src/prog8/ast/statements/AstStatements.kt +++ b/compiler/src/prog8/ast/statements/AstStatements.kt @@ -340,7 +340,7 @@ class ArrayIndex(var index: Expression, override val position: Position) : Node open class Assignment(var target: AssignTarget, var aug_op : String?, var value: Expression, override val position: Position) : Statement() { override lateinit var parent: Node override val expensiveToInline - get() = value !is NumericLiteralValue + get() = value is BinaryExpression override fun linkParents(parent: Node) { this.parent = parent @@ -668,8 +668,8 @@ class Subroutine(override val name: String, get() = statements.any { it.expensiveToInline } override lateinit var parent: Node - val calledBy = mutableListOf() - val calls = mutableSetOf() + val calledBy = mutableListOf() // TODO remove, use callgraph only + val calls = mutableSetOf() // TODO remove, use callgraph only val scopedname: String by lazy { makeScopedName(name) } @@ -700,6 +700,32 @@ class Subroutine(override val name: String, .filter { it is InlineAssembly } .map { (it as InlineAssembly).assembly } .count { " rti" in it || "\trti" in it || " rts" in it || "\trts" in it || " jmp" in it || "\tjmp" in it } + + fun countStatements(): Int { + class StatementCounter: IAstVisitor { + var count = 0 + + override fun visit(block: Block) { + count += block.statements.size + super.visit(block) + } + + override fun visit(subroutine: Subroutine) { + count += subroutine.statements.size + super.visit(subroutine) + } + + override fun visit(scope: AnonymousScope) { + count += scope.statements.size + super.visit(scope) + } + } + + // the (recursive) number of statements + val counter = StatementCounter() + counter.visit(this) + return counter.count + } } diff --git a/compiler/src/prog8/compiler/Main.kt b/compiler/src/prog8/compiler/Main.kt index e3622ac6e..a471e7042 100644 --- a/compiler/src/prog8/compiler/Main.kt +++ b/compiler/src/prog8/compiler/Main.kt @@ -150,6 +150,7 @@ private fun processAst(programAst: Program, errors: ErrorReporter, compilerOptio programAst.reorderStatements() programAst.addTypecasts(errors) errors.handle() + programAst.inlineSubroutines() programAst.checkValid(compilerOptions, errors) errors.handle() programAst.checkIdentifiers(errors) diff --git a/compiler/src/prog8/optimizer/CallGraph.kt b/compiler/src/prog8/optimizer/CallGraph.kt index 8dcaad9dd..adaa56513 100644 --- a/compiler/src/prog8/optimizer/CallGraph.kt +++ b/compiler/src/prog8/optimizer/CallGraph.kt @@ -24,10 +24,10 @@ private val asmRefRx = Regex("""[\-+a-zA-Z0-9_ \t]+(...)[ \t]+(\S+).*""", RegexO class CallGraph(private val program: Program) : IAstVisitor { - val modulesImporting = mutableMapOf>().withDefault { mutableListOf() } - val modulesImportedBy = mutableMapOf>().withDefault { mutableListOf() } - val subroutinesCalling = mutableMapOf>().withDefault { mutableListOf() } - val subroutinesCalledBy = mutableMapOf>().withDefault { mutableListOf() } + val imports = mutableMapOf>().withDefault { mutableListOf() } + val importedBy = mutableMapOf>().withDefault { mutableListOf() } + val calls = mutableMapOf>().withDefault { mutableListOf() } + val calledBy = mutableMapOf>().withDefault { mutableListOf() } // TODO add dataflow graph: what statements use what variables - can be used to eliminate unused vars val usedSymbols = mutableSetOf() @@ -55,15 +55,15 @@ class CallGraph(private val program: Program) : IAstVisitor { it.importedBy.clear() it.imports.clear() - it.importedBy.addAll(modulesImportedBy.getValue(it)) - it.imports.addAll(modulesImporting.getValue(it)) + it.importedBy.addAll(importedBy.getValue(it)) + it.imports.addAll(imports.getValue(it)) forAllSubroutines(it) { sub -> sub.calledBy.clear() sub.calls.clear() - sub.calledBy.addAll(subroutinesCalledBy.getValue(sub)) - sub.calls.addAll(subroutinesCalling.getValue(sub)) + sub.calledBy.addAll(calledBy.getValue(sub)) + sub.calls.addAll(calls.getValue(sub)) } } @@ -85,8 +85,8 @@ class CallGraph(private val program: Program) : IAstVisitor { val thisModule = directive.definingModule() if (directive.directive == "%import") { val importedModule: Module = program.modules.single { it.name == directive.args[0].name } - modulesImporting[thisModule] = modulesImporting.getValue(thisModule).plus(importedModule) - modulesImportedBy[importedModule] = modulesImportedBy.getValue(importedModule).plus(thisModule) + imports[thisModule] = imports.getValue(thisModule).plus(importedModule) + importedBy[importedModule] = importedBy.getValue(importedModule).plus(thisModule) } else if (directive.directive == "%asminclude") { val asm = loadAsmIncludeFile(directive.args[0].str!!, thisModule.source) val scope = directive.definingScope() @@ -141,8 +141,8 @@ class CallGraph(private val program: Program) : IAstVisitor { val otherSub = functionCall.target.targetSubroutine(program.namespace) if (otherSub != null) { functionCall.definingSubroutine()?.let { thisSub -> - subroutinesCalling[thisSub] = subroutinesCalling.getValue(thisSub).plus(otherSub) - subroutinesCalledBy[otherSub] = subroutinesCalledBy.getValue(otherSub).plus(functionCall) + calls[thisSub] = calls.getValue(thisSub).plus(otherSub) + calledBy[otherSub] = calledBy.getValue(otherSub).plus(functionCall) } } super.visit(functionCall) @@ -152,8 +152,8 @@ class CallGraph(private val program: Program) : IAstVisitor { val otherSub = functionCallStatement.target.targetSubroutine(program.namespace) if (otherSub != null) { functionCallStatement.definingSubroutine()?.let { thisSub -> - subroutinesCalling[thisSub] = subroutinesCalling.getValue(thisSub).plus(otherSub) - subroutinesCalledBy[otherSub] = subroutinesCalledBy.getValue(otherSub).plus(functionCallStatement) + calls[thisSub] = calls.getValue(thisSub).plus(otherSub) + calledBy[otherSub] = calledBy.getValue(otherSub).plus(functionCallStatement) } } super.visit(functionCallStatement) @@ -163,8 +163,8 @@ class CallGraph(private val program: Program) : IAstVisitor { val otherSub = jump.identifier?.targetSubroutine(program.namespace) if (otherSub != null) { jump.definingSubroutine()?.let { thisSub -> - subroutinesCalling[thisSub] = subroutinesCalling.getValue(thisSub).plus(otherSub) - subroutinesCalledBy[otherSub] = subroutinesCalledBy.getValue(otherSub).plus(jump) + calls[thisSub] = calls.getValue(thisSub).plus(otherSub) + calledBy[otherSub] = calledBy.getValue(otherSub).plus(jump) } } super.visit(jump) @@ -190,14 +190,14 @@ class CallGraph(private val program: Program) : IAstVisitor { if (jumptarget != null && (jumptarget[0].isLetter() || jumptarget[0] == '_')) { val node = program.namespace.lookup(jumptarget.split('.'), context) if (node is Subroutine) { - subroutinesCalling[scope] = subroutinesCalling.getValue(scope).plus(node) - subroutinesCalledBy[node] = subroutinesCalledBy.getValue(node).plus(context) + calls[scope] = calls.getValue(scope).plus(node) + calledBy[node] = calledBy.getValue(node).plus(context) } else if (jumptarget.contains('.')) { // maybe only the first part already refers to a subroutine val node2 = program.namespace.lookup(listOf(jumptarget.substringBefore('.')), context) if (node2 is Subroutine) { - subroutinesCalling[scope] = subroutinesCalling.getValue(scope).plus(node2) - subroutinesCalledBy[node2] = subroutinesCalledBy.getValue(node2).plus(context) + calls[scope] = calls.getValue(scope).plus(node2) + calledBy[node2] = calledBy.getValue(node2).plus(context) } } } @@ -209,8 +209,8 @@ class CallGraph(private val program: Program) : IAstVisitor { if (target.contains('.')) { val node = program.namespace.lookup(listOf(target.substringBefore('.')), context) if (node is Subroutine) { - subroutinesCalling[scope] = subroutinesCalling.getValue(scope).plus(node) - subroutinesCalledBy[node] = subroutinesCalledBy.getValue(node).plus(context) + calls[scope] = calls.getValue(scope).plus(node) + calledBy[node] = calledBy.getValue(node).plus(context) } } } diff --git a/compiler/src/prog8/optimizer/StatementOptimizer.kt b/compiler/src/prog8/optimizer/StatementOptimizer.kt index 9d4fc1691..6b2840af8 100644 --- a/compiler/src/prog8/optimizer/StatementOptimizer.kt +++ b/compiler/src/prog8/optimizer/StatementOptimizer.kt @@ -16,7 +16,6 @@ import kotlin.math.floor /* TODO: remove unreachable code after return and exit() - TODO: proper inlining of tiny subroutines (at first, restrict to subs without parameters and variables in them, and build it up from there: correctly renaming/relocating all variables in them and refs to those as well) */ diff --git a/examples/test.p8 b/examples/test.p8 index e0c236b00..2e1fb5831 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -7,38 +7,29 @@ main { + sub ss(ubyte qq) { + Y=qq + } + + sub tiny() { + Y++ + } + + sub tiny2() { + for A in 10 to 20 { + ubyte xx = A + } + } + sub start() { - float[] fa=[1.1111,2.2222,3.3333,4.4444] - ubyte[] uba = [1,2,3,4] - word[] uwa = [1111,2222,3333,4444] - ubyte ii = 1 - ubyte jj = 3 - - float f1 = 1.123456 - float f2 = 2.223344 - - swap(f1, f2) - - swap(fa[0], fa[1]) - swap(uba[0], uba[1]) - swap(uwa[0], uwa[1]) - - ubyte i1 - ubyte i2 - swap(fa[i1], fa[i2]) - swap(uba[i1], uba[i2]) - swap(uwa[i1], uwa[i2]) - - c64flt.print_f(f1) - c64.CHROUT('\n') - c64flt.print_f(f2) - c64.CHROUT('\n') - - swap(f1,f2) - c64flt.print_f(f1) - c64.CHROUT('\n') - c64flt.print_f(f2) - c64.CHROUT('\n') + uword zomg=2 + A=lsb(zomg) + ss(100) + ss(101) + tiny() + tiny() + tiny2() + tiny2() } }