diff --git a/compiler/src/prog8/compiler/Compiler.kt b/compiler/src/prog8/compiler/Compiler.kt index 956bcf0e4..b9710d20c 100644 --- a/compiler/src/prog8/compiler/Compiler.kt +++ b/compiler/src/prog8/compiler/Compiler.kt @@ -91,7 +91,7 @@ fun compileProgram(filepath: Path, importedFiles = imported processAst(programAst, errors, compilationOptions) if (compilationOptions.optimize) - optimizeAst(programAst, errors) + optimizeAst(programAst, errors, BuiltinFunctionsFacade(BuiltinFunctions)) postprocessAst(programAst, errors, compilationOptions) // printAst(programAst) @@ -134,6 +134,8 @@ private class BuiltinFunctionsFacade(functions: Map): IBuilt lateinit var program: Program override val names = functions.keys + override val purefunctionNames = functions.filter { it.value.pure }.map { it.key }.toSet() + override fun constValue(name: String, args: List, position: Position): NumericLiteralValue? { val func = BuiltinFunctions[name] if(func!=null) { @@ -252,14 +254,14 @@ private fun processAst(programAst: Program, errors: ErrorReporter, compilerOptio errors.handle() } -private fun optimizeAst(programAst: Program, errors: ErrorReporter) { +private fun optimizeAst(programAst: Program, errors: ErrorReporter, functions: IBuiltinFunctions) { // optimize the parse tree println("Optimizing...") while (true) { // keep optimizing expressions and statements until no more steps remain val optsDone1 = programAst.simplifyExpressions() val optsDone2 = programAst.splitBinaryExpressions() - val optsDone3 = programAst.optimizeStatements(errors) + val optsDone3 = programAst.optimizeStatements(errors, functions) programAst.constantFold(errors) // because simplified statements and expressions can result in more constants that can be folded away errors.handle() if (optsDone1 + optsDone2 + optsDone3 == 0) diff --git a/compiler/src/prog8/optimizer/Extensions.kt b/compiler/src/prog8/optimizer/Extensions.kt index 9adb188a7..b4e0f1a8e 100644 --- a/compiler/src/prog8/optimizer/Extensions.kt +++ b/compiler/src/prog8/optimizer/Extensions.kt @@ -1,5 +1,6 @@ package prog8.optimizer +import prog8.ast.IBuiltinFunctions import prog8.ast.Program import prog8.compiler.ErrorReporter @@ -38,8 +39,8 @@ internal fun Program.constantFold(errors: ErrorReporter) { } -internal fun Program.optimizeStatements(errors: ErrorReporter): Int { - val optimizer = StatementOptimizer(this, errors) +internal fun Program.optimizeStatements(errors: ErrorReporter, functions: IBuiltinFunctions): Int { + val optimizer = StatementOptimizer(this, errors, functions) optimizer.visit(this) val optimizationCount = optimizer.applyModifications() diff --git a/compiler/src/prog8/optimizer/StatementOptimizer.kt b/compiler/src/prog8/optimizer/StatementOptimizer.kt index ba62fe856..988e047d2 100644 --- a/compiler/src/prog8/optimizer/StatementOptimizer.kt +++ b/compiler/src/prog8/optimizer/StatementOptimizer.kt @@ -1,5 +1,6 @@ package prog8.optimizer +import prog8.ast.IBuiltinFunctions import prog8.ast.INameScope import prog8.ast.Node import prog8.ast.Program @@ -10,18 +11,17 @@ import prog8.ast.walk.AstWalker import prog8.ast.walk.IAstModification import prog8.ast.walk.IAstVisitor import prog8.compiler.ErrorReporter -import prog8.compiler.functions.BuiltinFunctions import prog8.compiler.target.CompilationTarget import kotlin.math.floor internal class StatementOptimizer(private val program: Program, - private val errors: ErrorReporter + private val errors: ErrorReporter, + private val functions: IBuiltinFunctions ) : AstWalker() { private val noModifications = emptyList() private val callgraph = CallGraph(program) - private val pureBuiltinFunctions = BuiltinFunctions.filter { it.value.pure } override fun after(block: Block, parent: Node): Iterable { if("force_output" !in block.options()) { @@ -72,9 +72,9 @@ internal class StatementOptimizer(private val program: Program, } override fun after(functionCallStatement: FunctionCallStatement, parent: Node): Iterable { - if(functionCallStatement.target.nameInSource.size==1 && functionCallStatement.target.nameInSource[0] in BuiltinFunctions) { + if(functionCallStatement.target.nameInSource.size==1 && functionCallStatement.target.nameInSource[0] in functions.names) { val functionName = functionCallStatement.target.nameInSource[0] - if (functionName in pureBuiltinFunctions) { + if (functionName in functions.purefunctionNames) { errors.warn("statement has no effect (function return value is discarded)", functionCallStatement.position) return listOf(IAstModification.Remove(functionCallStatement, functionCallStatement.definingScope())) } diff --git a/compiler/test/UnitTests.kt b/compiler/test/UnitTests.kt index e9b745bda..d7b8beefc 100644 --- a/compiler/test/UnitTests.kt +++ b/compiler/test/UnitTests.kt @@ -408,6 +408,7 @@ class TestPetscii { class TestMemory { private class DummyFunctions: IBuiltinFunctions { override val names: Set = emptySet() + override val purefunctionNames: Set = emptySet() override fun constValue(name: String, args: List, position: Position): NumericLiteralValue? = null override fun returnType(name: String, args: MutableList) = InferredTypes.InferredType.unknown() } diff --git a/compilerAst/src/prog8/ast/AstToplevel.kt b/compilerAst/src/prog8/ast/AstToplevel.kt index cb05704b3..5c1132834 100644 --- a/compilerAst/src/prog8/ast/AstToplevel.kt +++ b/compilerAst/src/prog8/ast/AstToplevel.kt @@ -246,6 +246,7 @@ interface IAssignable { interface IBuiltinFunctions { val names: Set + val purefunctionNames: Set fun constValue(name: String, args: List, position: Position): NumericLiteralValue? fun returnType(name: String, args: MutableList): InferredTypes.InferredType }