diff --git a/compiler/src/prog8/optimizer/StatementOptimizer.kt b/compiler/src/prog8/optimizer/StatementOptimizer.kt index 413b6e2fd..41355daa8 100644 --- a/compiler/src/prog8/optimizer/StatementOptimizer.kt +++ b/compiler/src/prog8/optimizer/StatementOptimizer.kt @@ -39,6 +39,41 @@ internal class StatementOptimizer(private val program: Program, return noModifications } + override fun before(functionCall: FunctionCall, parent: Node): Iterable { + // if the first instruction in the called subroutine is a return statement with a simple value, + // remove the jump altogeter and inline the returnvalue directly. + val subroutine = functionCall.target.targetSubroutine(program) + if(subroutine!=null) { + val first = subroutine.statements.asSequence().filterNot { it is VarDecl || it is Directive }.firstOrNull() + if(first is Return && first.value?.isSimple==true) { + val orig = first.value!! + val copy = when(orig) { + is AddressOf -> { + val scoped = scopePrefix(orig.identifier, subroutine) + AddressOf(scoped, orig.position) + } + is DirectMemoryRead -> { + when(val expr = orig.addressExpression) { + is NumericLiteralValue -> DirectMemoryRead(expr.copy(), orig.position) + else -> return noModifications + } + } + is IdentifierReference -> scopePrefix(orig, subroutine) + is NumericLiteralValue -> orig.copy() + is StringLiteralValue -> orig.copy() + else -> return noModifications + } + return listOf(IAstModification.ReplaceNode(functionCall, copy, parent)) + } + } + return noModifications + } + + private fun scopePrefix(variable: IdentifierReference, subroutine: Subroutine): IdentifierReference { + val scoped = subroutine.makeScopedName(variable.nameInSource.last()) + return IdentifierReference(scoped.split('.'), variable.position) + } + override fun after(functionCallStatement: FunctionCallStatement, parent: Node): Iterable { if(functionCallStatement.target.nameInSource.size==1 && functionCallStatement.target.nameInSource[0] in functions.names) { val functionName = functionCallStatement.target.nameInSource[0] @@ -101,19 +136,19 @@ internal class StatementOptimizer(private val program: Program, return noModifications } - override fun before(functionCall: FunctionCall, parent: Node): Iterable { - // if the first instruction in the called subroutine is a return statement with constant value, replace with the constant value - val subroutine = functionCall.target.targetSubroutine(program) - if(subroutine!=null) { - val first = subroutine.statements.asSequence().filterNot { it is VarDecl || it is Directive }.firstOrNull() - if(first is Return && first.value!=null) { - val constval = first.value?.constValue(program) - if(constval!=null) - return listOf(IAstModification.ReplaceNode(functionCall, constval, parent)) - } - } - return noModifications - } +// override fun before(functionCall: FunctionCall, parent: Node): Iterable { +// // if the first instruction in the called subroutine is a return statement with constant value, replace with the constant value +// val subroutine = functionCall.target.targetSubroutine(program) +// if(subroutine!=null) { +// val first = subroutine.statements.asSequence().filterNot { it is VarDecl || it is Directive }.firstOrNull() +// if(first is Return && first.value!=null) { +// val constval = first.value?.constValue(program) +// if(constval!=null) +// return listOf(IAstModification.ReplaceNode(functionCall, constval, parent)) +// } +// } +// return noModifications +// } override fun after(ifStatement: IfStatement, parent: Node): Iterable { // remove empty if statements diff --git a/compilerAst/src/prog8/ast/expressions/AstExpressions.kt b/compilerAst/src/prog8/ast/expressions/AstExpressions.kt index e4ea9097a..1d6d721c2 100644 --- a/compilerAst/src/prog8/ast/expressions/AstExpressions.kt +++ b/compilerAst/src/prog8/ast/expressions/AstExpressions.kt @@ -348,7 +348,7 @@ class DirectMemoryRead(var addressExpression: Expression, override val position: this.addressExpression.linkParents(this) } - override val isSimple = true + override val isSimple = addressExpression is NumericLiteralValue || addressExpression is IdentifierReference override fun replaceChildNode(node: Node, replacement: Node) { require(replacement is Expression && node===addressExpression) @@ -366,8 +366,6 @@ class DirectMemoryRead(var addressExpression: Expression, override val position: override fun toString(): String { return "DirectMemoryRead($addressExpression)" } - - fun copy() = DirectMemoryRead(addressExpression, position) } class NumericLiteralValue(val type: DataType, // only numerical types allowed @@ -376,6 +374,7 @@ class NumericLiteralValue(val type: DataType, // only numerical types allowed override lateinit var parent: Node override val isSimple = true + fun copy() = NumericLiteralValue(type, number, position) companion object { fun fromBoolean(bool: Boolean, position: Position) = @@ -509,6 +508,7 @@ class StringLiteralValue(val value: String, } override val isSimple = true + fun copy() = StringLiteralValue(value, altEncoding, position) override fun replaceChildNode(node: Node, replacement: Node) { throw FatalAstException("can't replace here") diff --git a/examples/test.p8 b/examples/test.p8 index 75adb566b..0475e7eeb 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -1,12 +1,70 @@ -%import palette -%import test_stack -%zeropage basicsafe +%import textio +%zeropage dontuse main { sub start() { - ; TODO inline a subroutine that only contains a direct call to another subroutine - palette.set_all_black() - palette.set_all_white() + uword v + v = test.get_value1() + txt.print_uw(v) + txt.nl() + v = test.get_value2() + txt.print_uw(v) + txt.nl() + v = test.get_value3() + txt.print_uw(v) + txt.nl() + v = test.get_value4() + txt.print_uw(v) + v = test.get_value4() + txt.print_uw(v) + v = test.get_value4() + txt.print_uw(v) + v = test.get_value4() + txt.print_uw(v) + v = test.get_value4() + txt.print_uw(v) + v = test.get_value4() + txt.print_uw(v) + v = test.get_value4() + txt.print_uw(v) + v = test.get_value4() + txt.print_uw(v) + v = test.get_value4() + txt.print_uw(v) + v = test.get_value4() + txt.print_uw(v) + txt.nl() + v = test.get_value5() + txt.print_uw(v) + txt.nl() + v = test.get_value6() + txt.print_uw(v) + txt.nl() + } +} + + +test { + uword[] arr = [1111,2222,3333] + uword value = 9999 + + sub get_value1() -> uword { + return &value + } + sub get_value2() -> uword { + return arr[2] + } + sub get_value3() -> ubyte { + return @($c000) + } + sub get_value4() -> uword { + return value + } + sub get_value5() -> uword { + return $c000 + } + sub get_value6() -> uword { + return "string" } }