From 422b390c48b0a1ffe789f50d7dfbf464989ddb7b Mon Sep 17 00:00:00 2001 From: Irmen de Jong Date: Fri, 2 Apr 2021 16:56:52 +0200 Subject: [PATCH] fix ast node duplication/reference bug in certain optimizers --- .../compiler/astprocessing/AstChecker.kt | 2 +- .../compiler/astprocessing/VariousCleanups.kt | 88 ++++++++++++++++--- .../src/prog8/optimizer/BinExprSplitter.kt | 8 +- .../src/prog8/optimizer/StatementOptimizer.kt | 7 +- .../prog8/ast/expressions/AstExpressions.kt | 4 + .../src/prog8/ast/statements/AstStatements.kt | 9 +- compilerAst/src/prog8/ast/walk/AstWalker.kt | 2 +- docs/source/todo.rst | 1 + examples/test.p8 | 4 +- 9 files changed, 101 insertions(+), 24 deletions(-) diff --git a/compiler/src/prog8/compiler/astprocessing/AstChecker.kt b/compiler/src/prog8/compiler/astprocessing/AstChecker.kt index 291cfd33a..b0d1018e2 100644 --- a/compiler/src/prog8/compiler/astprocessing/AstChecker.kt +++ b/compiler/src/prog8/compiler/astprocessing/AstChecker.kt @@ -70,7 +70,7 @@ internal class AstChecker(private val program: Program, if(expectedReturnValues.size==1 && returnStmt.value!=null) { val valueDt = returnStmt.value!!.inferType(program) if(!valueDt.isKnown) { - errors.err("return value type mismatch", returnStmt.value!!.position) + errors.err("return value type mismatch or unknown symbol", returnStmt.value!!.position) } else { if (expectedReturnValues[0] != valueDt.typeOrElse(DataType.STRUCT)) errors.err("type $valueDt of return value doesn't match subroutine's return type ${expectedReturnValues[0]}", returnStmt.value!!.position) diff --git a/compiler/src/prog8/compiler/astprocessing/VariousCleanups.kt b/compiler/src/prog8/compiler/astprocessing/VariousCleanups.kt index 8fda7178c..304250048 100644 --- a/compiler/src/prog8/compiler/astprocessing/VariousCleanups.kt +++ b/compiler/src/prog8/compiler/astprocessing/VariousCleanups.kt @@ -4,6 +4,7 @@ import prog8.ast.IFunctionCall import prog8.ast.INameScope import prog8.ast.Node import prog8.ast.Program +import prog8.ast.base.FatalAstException import prog8.ast.base.Position import prog8.ast.expressions.* import prog8.ast.statements.* @@ -57,56 +58,119 @@ internal class VariousCleanups(private val program: Program, val errors: IErrorR private fun before(functionCall: IFunctionCall, parent: Node, position: Position): Iterable { + val modifications = mutableListOf() + if(compilerOptions.optimize) { val sub = functionCall.target.targetSubroutine(program) if(sub!=null && sub.inline && !sub.isAsmSubroutine) - annotateInlinedSubroutineIdentifiers(sub) + modifications.addAll(annotateInlinedSubroutineIdentifiers(sub)) } - if(functionCall.target.nameInSource==listOf("peek")) { // peek(a) is synonymous with @(a) val memread = DirectMemoryRead(functionCall.args.single(), position) - return listOf(IAstModification.ReplaceNode(functionCall as Node, memread, parent)) + modifications.add(IAstModification.ReplaceNode(functionCall as Node, memread, parent)) } if(functionCall.target.nameInSource==listOf("poke")) { // poke(a, v) is synonymous with @(a) = v val tgt = AssignTarget(null, null, DirectMemoryWrite(functionCall.args[0], position), position) val assign = Assignment(tgt, functionCall.args[1], position) - return listOf(IAstModification.ReplaceNode(functionCall as Node, assign, parent)) + modifications.add(IAstModification.ReplaceNode(functionCall as Node, assign, parent)) } + return modifications + } + + override fun after(assignment: Assignment, parent: Node): Iterable { + if(assignment.parent!==parent) + throw FatalAstException("parent node mismatch at $assignment") return noModifications } - private fun annotateInlinedSubroutineIdentifiers(sub: Subroutine) { + override fun after(assignTarget: AssignTarget, parent: Node): Iterable { + if(assignTarget.parent!==parent) + throw FatalAstException("parent node mismatch at $assignTarget") + return noModifications + } + + override fun after(decl: VarDecl, parent: Node): Iterable { + if(decl.parent!==parent) + throw FatalAstException("parent node mismatch at $decl") + return noModifications + } + + override fun after(scope: AnonymousScope, parent: Node): Iterable { + if(scope.parent!==parent) + throw FatalAstException("parent node mismatch at $scope") + return noModifications + } + + override fun after(typecast: TypecastExpression, parent: Node): Iterable { + if(typecast.parent!==parent) + throw FatalAstException("parent node mismatch at $typecast") + return noModifications + } + + override fun after(returnStmt: Return, parent: Node): Iterable { + if(returnStmt.parent!==parent) + throw FatalAstException("parent node mismatch at $returnStmt") + return noModifications + } + + override fun after(identifier: IdentifierReference, parent: Node): Iterable { + if(identifier.parent!==parent) + throw FatalAstException("parent node mismatch at $identifier") + return noModifications + } + + private fun annotateInlinedSubroutineIdentifiers(sub: Subroutine): List { // this adds full name prefixes to all identifiers used in the subroutine, // so that the statements can be inlined (=copied) in the call site and still reference // the correct symbols as seen from the scope of the subroutine. + // TODO warning : "inlining a subroutine with variables, this could result in large code/memory size", identifier.position) + class Annotator: AstWalker() { var numReturns=0 override fun before(identifier: IdentifierReference, parent: Node): Iterable { val stmt = identifier.targetStatement(program)!! - val prefixed = stmt.makeScopedName(identifier.nameInSource.last()).split('.') - val withPrefix = IdentifierReference(prefixed, identifier.position) - return listOf(IAstModification.ReplaceNode(identifier, withPrefix, identifier.parent)) + val subroutine = identifier.definingSubroutine() + return if(stmt is VarDecl && stmt.parent === subroutine) { + val prefixed = stmt.makeScopedName(identifier.nameInSource.last()).replace('.','_') + val withPrefix = IdentifierReference(listOf(prefixed), identifier.position) + listOf(IAstModification.ReplaceNode(identifier, withPrefix, parent)) + } else { + val prefixed = stmt.makeScopedName(identifier.nameInSource.last()).split('.') + val withPrefix = IdentifierReference(prefixed, identifier.position) + listOf(IAstModification.ReplaceNode(identifier, withPrefix, parent)) + } + } + + override fun after(decl: VarDecl, parent: Node): Iterable { + val prefixed = decl.makeScopedName(decl.name).replace('.','_') + val newdecl = VarDecl(decl.type, decl.datatype, decl.zeropage, decl.arraysize, prefixed, decl.struct?.name, decl.value, decl.isArray, decl.autogeneratedDontRemove, decl.position) + return listOf(IAstModification.ReplaceNode(decl, newdecl, parent)) } override fun before(returnStmt: Return, parent: Node): Iterable { numReturns++ if(parent !== sub || sub.indexOfChild(returnStmt) { + return this.modifications.map { it.first }.toList() } } val annotator = Annotator() sub.accept(annotator, sub.parent) - if(annotator.numReturns>1) + if(annotator.numReturns>1) { errors.err("inlined subroutine can only have one return statement", sub.position) - else - annotator.applyModifications() + return noModifications + } + return annotator.theModifications() } } diff --git a/compiler/src/prog8/optimizer/BinExprSplitter.kt b/compiler/src/prog8/optimizer/BinExprSplitter.kt index 423ea23af..f873a0f0c 100644 --- a/compiler/src/prog8/optimizer/BinExprSplitter.kt +++ b/compiler/src/prog8/optimizer/BinExprSplitter.kt @@ -1,5 +1,6 @@ package prog8.optimizer +import prog8.ast.INameScope import prog8.ast.Node import prog8.ast.Program import prog8.ast.expressions.* @@ -58,12 +59,13 @@ X = BinExpr X = LeftExpr return noModifications if(isSimpleExpression(binExpr.right) && !assignment.isAugmentable) { - val firstAssign = Assignment(assignment.target, binExpr.left, binExpr.left.position) + val firstAssign = Assignment(assignment.target.copy(), binExpr.left, binExpr.left.position) val targetExpr = assignment.target.toExpression() val augExpr = BinaryExpression(targetExpr, binExpr.operator, binExpr.right, binExpr.right.position) return listOf( - IAstModification.InsertBefore(assignment, firstAssign, assignment.definingScope()), - IAstModification.ReplaceNode(assignment.value, augExpr, assignment)) + IAstModification.ReplaceNode(binExpr, augExpr, assignment), + IAstModification.InsertBefore(assignment, firstAssign, assignment.parent as INameScope) + ) } } diff --git a/compiler/src/prog8/optimizer/StatementOptimizer.kt b/compiler/src/prog8/optimizer/StatementOptimizer.kt index 5da266424..914c76b75 100644 --- a/compiler/src/prog8/optimizer/StatementOptimizer.kt +++ b/compiler/src/prog8/optimizer/StatementOptimizer.kt @@ -451,10 +451,11 @@ internal class StatementOptimizer(private val program: Program, if (returnDt in IntegerDatatypes) { // first assign to intermediary variable, then return that subsThatNeedReturnVariable.add(Triple(subr, returnDt, returnStmt.position)) - val returnValueIntermediary = IdentifierReference(listOf(retvalName), returnStmt.position) - val tgt = AssignTarget(returnValueIntermediary, null, null, returnStmt.position) + val returnValueIntermediary1 = IdentifierReference(listOf(retvalName), returnStmt.position) + val returnValueIntermediary2 = IdentifierReference(listOf(retvalName), returnStmt.position) + val tgt = AssignTarget(returnValueIntermediary1, null, null, returnStmt.position) val assign = Assignment(tgt, value, returnStmt.position) - val returnReplacement = Return(returnValueIntermediary, returnStmt.position) + val returnReplacement = Return(returnValueIntermediary2, returnStmt.position) return listOf( IAstModification.InsertBefore(returnStmt, assign, parent as INameScope), IAstModification.ReplaceNode(returnStmt, returnReplacement, parent) diff --git a/compilerAst/src/prog8/ast/expressions/AstExpressions.kt b/compilerAst/src/prog8/ast/expressions/AstExpressions.kt index 14f4f397f..eaa3edd87 100644 --- a/compilerAst/src/prog8/ast/expressions/AstExpressions.kt +++ b/compilerAst/src/prog8/ast/expressions/AstExpressions.kt @@ -271,6 +271,8 @@ class ArrayIndexedExpression(var arrayvar: IdentifierReference, override fun toString(): String { return "ArrayIndexed(ident=$arrayvar, arraysize=$indexer; pos=$position)" } + + fun copy() = ArrayIndexedExpression(arrayvar.copy(), indexer.copy(), position) } class TypecastExpression(var expression: Expression, var type: DataType, val implicit: Boolean, override val position: Position) : Expression() { @@ -351,6 +353,8 @@ class DirectMemoryRead(var addressExpression: Expression, override val position: override fun toString(): String { return "DirectMemoryRead($addressExpression)" } + + fun copy() = DirectMemoryRead(addressExpression, position) } class NumericLiteralValue(val type: DataType, // only numerical types allowed diff --git a/compilerAst/src/prog8/ast/statements/AstStatements.kt b/compilerAst/src/prog8/ast/statements/AstStatements.kt index 5ff050a3e..29841a46a 100644 --- a/compilerAst/src/prog8/ast/statements/AstStatements.kt +++ b/compilerAst/src/prog8/ast/statements/AstStatements.kt @@ -312,6 +312,7 @@ class ArrayIndex(var indexExpr: Expression, fun constIndex() = (indexExpr as? NumericLiteralValue)?.number?.toInt() infix fun isSameAs(other: ArrayIndex): Boolean = indexExpr isSameAs other.indexExpr + fun copy() = ArrayIndex(indexExpr, position) } open class Assignment(var target: AssignTarget, var value: Expression, override val position: Position) : Statement() { @@ -429,9 +430,10 @@ data class AssignTarget(var identifier: IdentifierReference?, } fun toExpression(): Expression { + // return a copy of the assignment target but as a source expression. return when { - identifier != null -> identifier!! - arrayindexed != null -> arrayindexed!! + identifier != null -> identifier!!.copy() + arrayindexed != null -> arrayindexed!!.copy() memoryAddress != null -> DirectMemoryRead(memoryAddress.addressExpression, memoryAddress.position) else -> throw FatalAstException("invalid assignmenttarget $this") } @@ -476,6 +478,8 @@ data class AssignTarget(var identifier: IdentifierReference?, } return false } + + fun copy() = AssignTarget(identifier?.copy(), arrayindexed?.copy(), memoryAddress?.copy(), position) } @@ -992,4 +996,5 @@ class DirectMemoryWrite(var addressExpression: Expression, override val position fun accept(visitor: IAstVisitor) = visitor.visit(this) fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) + fun copy() = DirectMemoryWrite(addressExpression, position) } diff --git a/compilerAst/src/prog8/ast/walk/AstWalker.kt b/compilerAst/src/prog8/ast/walk/AstWalker.kt index 08a7d1ecf..c3f050a5a 100644 --- a/compilerAst/src/prog8/ast/walk/AstWalker.kt +++ b/compilerAst/src/prog8/ast/walk/AstWalker.kt @@ -157,7 +157,7 @@ abstract class AstWalker { open fun after(whenStatement: WhenStatement, parent: Node): Iterable = emptyList() open fun after(whileLoop: WhileLoop, parent: Node): Iterable = emptyList() - private val modifications = mutableListOf>() + protected val modifications = mutableListOf>() // private val modificationsReplacedNodes = mutableSetOf>() private fun track(mods: Iterable, node: Node, parent: Node) { diff --git a/docs/source/todo.rst b/docs/source/todo.rst index 2e98e58fe..433ed9b85 100644 --- a/docs/source/todo.rst +++ b/docs/source/todo.rst @@ -3,6 +3,7 @@ TODO ==== - allow inlining of subroutines with vardecls +- allow inlining of subroutines with params - optimize several inner loops in gfx2 - hoist all variable declarations up to the subroutine scope *before* even the constant folding takes place (to avoid undefined symbol errors when referring to a variable from another nested scope in the subroutine) - optimize swap of two memread values with index, using the same pointer expression/variable, like swap(@(ptr+1), @(ptr+2)) diff --git a/examples/test.p8 b/examples/test.p8 index c4858afeb..31611435a 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -26,7 +26,7 @@ main { block3 { ubyte returnvalue=10 - sub thing()->ubyte { + inline sub thing()->ubyte { return returnvalue } } @@ -41,6 +41,6 @@ otherblock { } inline sub othersub() -> ubyte { - return calc(calcparam) + return calc(calcparam)+othervar } }