From 43781c02d0a2b447d6797a69071e24f553209e6d Mon Sep 17 00:00:00 2001 From: Irmen de Jong Date: Sat, 21 Mar 2020 18:39:36 +0100 Subject: [PATCH] tweaked ast modifications --- compiler/src/prog8/ast/AstToplevel.kt | 17 ++ compiler/src/prog8/ast/base/Base.kt | 1 + .../prog8/ast/expressions/AstExpressions.kt | 81 ++++++++++ .../src/prog8/ast/processing/AstWalker.kt | 21 +-- .../prog8/ast/processing/ForeverLoopsMaker.kt | 4 +- .../prog8/ast/processing/TypecastsAdder.kt | 39 ++--- .../VarInitValueAndAddressOfCreator.kt | 20 ++- .../src/prog8/ast/statements/AstStatements.kt | 149 ++++++++++++++++++ .../prog8/optimizer/ExpressionSimplifier.kt | 79 ++++------ compiler/src/prog8/optimizer/Extensions.kt | 6 +- examples/test.p8 | 7 +- 11 files changed, 320 insertions(+), 104 deletions(-) diff --git a/compiler/src/prog8/ast/AstToplevel.kt b/compiler/src/prog8/ast/AstToplevel.kt index 06bf59f0b..6444e8a60 100644 --- a/compiler/src/prog8/ast/AstToplevel.kt +++ b/compiler/src/prog8/ast/AstToplevel.kt @@ -36,6 +36,8 @@ interface Node { return this throw FatalAstException("scope missing from $this") } + + fun replaceChildNode(node: Node, replacement: Node) } interface IFunctionCall { @@ -226,6 +228,12 @@ class Program(val name: String, val modules: MutableList): Node { it.linkParents(this) } } + + override fun replaceChildNode(node: Node, replacement: Node) { + require(node is Module && replacement is Module) + val idx = modules.indexOf(node) + modules[idx] = replacement + } } class Module(override val name: String, @@ -247,6 +255,11 @@ class Module(override val name: String, } override fun definingScope(): INameScope = program.namespace + override fun replaceChildNode(node: Node, replacement: Node) { + require(node is Statement && replacement is Statement) + val idx = statements.indexOf(node) + statements[idx] = replacement + } override fun toString() = "Module(name=$name, pos=$position, lib=$isLibraryModule)" @@ -266,6 +279,10 @@ class GlobalNamespace(val modules: List): Node, INameScope { modules.forEach { it.linkParents(this) } } + override fun replaceChildNode(node: Node, replacement: Node) { + throw FatalAstException("cannot replace anything in the namespace") + } + override fun lookup(scopedName: List, localContext: Node): Statement? { if (scopedName.size == 1 && scopedName[0] in BuiltinFunctions) { // builtin functions always exist, return a dummy localContext for them diff --git a/compiler/src/prog8/ast/base/Base.kt b/compiler/src/prog8/ast/base/Base.kt index c398d8d15..cfdb9e4af 100644 --- a/compiler/src/prog8/ast/base/Base.kt +++ b/compiler/src/prog8/ast/base/Base.kt @@ -150,6 +150,7 @@ object ParentSentinel : Node { override val position = Position("<>", 0, 0, 0) override var parent: Node = this override fun linkParents(parent: Node) {} + override fun replaceChildNode(node: Node, replacement: Node) {} } data class Position(val file: String, val line: Int, val startCol: Int, val endCol: Int) { diff --git a/compiler/src/prog8/ast/expressions/AstExpressions.kt b/compiler/src/prog8/ast/expressions/AstExpressions.kt index 1757d3b20..2ab2222f6 100644 --- a/compiler/src/prog8/ast/expressions/AstExpressions.kt +++ b/compiler/src/prog8/ast/expressions/AstExpressions.kt @@ -58,6 +58,11 @@ class PrefixExpression(val operator: String, var expression: Expression, overrid expression.linkParents(this) } + override fun replaceChildNode(node: Node, replacement: Node) { + require(node === expression && replacement is Expression) + expression = replacement + } + override fun constValue(program: Program): NumericLiteralValue? = null override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) @@ -100,6 +105,15 @@ class BinaryExpression(var left: Expression, var operator: String, var right: Ex right.linkParents(this) } + override fun replaceChildNode(node: Node, replacement: Node) { + require(replacement is Expression) + when { + node===left -> left = replacement + node===right -> right = replacement + else -> throw FatalAstException("invalid replace") + } + } + override fun toString(): String { return "[$left $operator $right]" } @@ -211,6 +225,11 @@ class ArrayIndexedExpression(var identifier: IdentifierReference, arrayspec.linkParents(this) } + override fun replaceChildNode(node: Node, replacement: Node) { + require(replacement is IdentifierReference && node===identifier) + identifier = replacement + } + override fun constValue(program: Program): NumericLiteralValue? = null override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) @@ -243,6 +262,11 @@ class TypecastExpression(var expression: Expression, var type: DataType, val imp expression.linkParents(this) } + override fun replaceChildNode(node: Node, replacement: Node) { + require(replacement is Expression && node===expression) + expression = replacement + } + override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) @@ -269,6 +293,11 @@ data class AddressOf(var identifier: IdentifierReference, override val position: identifier.parent=this } + override fun replaceChildNode(node: Node, replacement: Node) { + require(replacement is IdentifierReference && node===identifier) + identifier = replacement + } + override fun constValue(program: Program): NumericLiteralValue? = null override fun referencesIdentifiers(vararg name: String) = false override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(DataType.UWORD) @@ -285,6 +314,11 @@ class DirectMemoryRead(var addressExpression: Expression, override val position: this.addressExpression.linkParents(this) } + override fun replaceChildNode(node: Node, replacement: Node) { + require(replacement is Expression && node===addressExpression) + addressExpression = replacement + } + override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent) @@ -338,6 +372,10 @@ class NumericLiteralValue(val type: DataType, // only numerical types allowed this.parent = parent } + override fun replaceChildNode(node: Node, replacement: Node) { + throw FatalAstException("can't replace here") + } + override fun referencesIdentifiers(vararg name: String) = false override fun constValue(program: Program) = this @@ -427,6 +465,10 @@ class StructLiteralValue(var values: List, values.forEach { it.linkParents(this) } } + override fun replaceChildNode(node: Node, replacement: Node) { + throw FatalAstException("can't replace here") + } + override fun constValue(program: Program): NumericLiteralValue? = null override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) @@ -452,6 +494,11 @@ class StringLiteralValue(val value: String, override fun linkParents(parent: Node) { this.parent = parent } + + override fun replaceChildNode(node: Node, replacement: Node) { + throw FatalAstException("can't replace here") + } + override fun referencesIdentifiers(vararg name: String) = false override fun constValue(program: Program): NumericLiteralValue? = null override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) @@ -480,6 +527,13 @@ class ArrayLiteralValue(val type: InferredTypes.InferredType, // inferred be this.parent = parent value.forEach {it.linkParents(this)} } + + override fun replaceChildNode(node: Node, replacement: Node) { + require(replacement is Expression) + val idx = value.indexOf(node) + value[idx] = replacement + } + override fun referencesIdentifiers(vararg name: String) = value.any { it.referencesIdentifiers(*name) } override fun constValue(program: Program): NumericLiteralValue? = null override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) @@ -564,6 +618,16 @@ class RangeExpr(var from: Expression, step.linkParents(this) } + override fun replaceChildNode(node: Node, replacement: Node) { + require(replacement is Expression) + when { + from===node -> from=replacement + to===node -> to=replacement + step===node -> step=replacement + else -> throw FatalAstException("invalid replacement") + } + } + override fun constValue(program: Program): NumericLiteralValue? = null override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) @@ -640,6 +704,10 @@ class RegisterExpr(val register: Register, override val position: Position) : Ex this.parent = parent } + override fun replaceChildNode(node: Node, replacement: Node) { + throw FatalAstException("can't replace here") + } + override fun constValue(program: Program): NumericLiteralValue? = null override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) @@ -669,6 +737,10 @@ data class IdentifierReference(val nameInSource: List, override val posi this.parent = parent } + override fun replaceChildNode(node: Node, replacement: Node) { + throw FatalAstException("can't replace here") + } + override fun constValue(program: Program): NumericLiteralValue? { val node = program.namespace.lookup(nameInSource, this) ?: throw UndefinedSymbolError(this) @@ -725,6 +797,15 @@ class FunctionCall(override var target: IdentifierReference, args.forEach { it.linkParents(this) } } + override fun replaceChildNode(node: Node, replacement: Node) { + if(node===target) + target=replacement as IdentifierReference + else { + val idx = args.indexOf(node) + args[idx] = replacement as Expression + } + } + override fun constValue(program: Program) = constValue(program, true) private fun constValue(program: Program, withDatatypeCheck: Boolean): NumericLiteralValue? { diff --git a/compiler/src/prog8/ast/processing/AstWalker.kt b/compiler/src/prog8/ast/processing/AstWalker.kt index 475ea0185..487a39488 100644 --- a/compiler/src/prog8/ast/processing/AstWalker.kt +++ b/compiler/src/prog8/ast/processing/AstWalker.kt @@ -23,19 +23,7 @@ interface IAstModification { } } - class ReplaceStmt(val statement: Statement, val replacement: Statement, val parent: Node) : IAstModification { - override fun perform() { - if(parent is INameScope) { - val idx = parent.statements.indexOf(statement) - parent.statements[idx] = replacement - replacement.linkParents(parent) - } else { - throw FatalAstException("parent of a replace modification is not an INameScope") - } - } - } - - class ReplaceExpr(val setter: (newExpr: Expression) -> Unit, val newExpr: Expression, val parent: Node) : IAstModification { + class SetExpression(val setter: (newExpr: Expression) -> Unit, val newExpr: Expression, val parent: Node) : IAstModification { override fun perform() { setter(newExpr) newExpr.linkParents(parent) @@ -53,6 +41,13 @@ interface IAstModification { } } } + + class ReplaceNode(val node: Node, val replacement: Node, val parent: Node) : IAstModification { + override fun perform() { + parent.replaceChildNode(node, replacement) + replacement.parent = parent + } + } } diff --git a/compiler/src/prog8/ast/processing/ForeverLoopsMaker.kt b/compiler/src/prog8/ast/processing/ForeverLoopsMaker.kt index e5288c7fd..89ca459e9 100644 --- a/compiler/src/prog8/ast/processing/ForeverLoopsMaker.kt +++ b/compiler/src/prog8/ast/processing/ForeverLoopsMaker.kt @@ -12,7 +12,7 @@ internal class ForeverLoopsMaker: AstWalker() { val numeric = repeatLoop.untilCondition as? NumericLiteralValue if(numeric!=null && numeric.number.toInt() == 0) { val forever = ForeverLoop(repeatLoop.body, repeatLoop.position) - return listOf(IAstModification.ReplaceStmt(repeatLoop, forever, parent)) + return listOf(IAstModification.ReplaceNode(repeatLoop, forever, parent)) } return emptyList() } @@ -21,7 +21,7 @@ internal class ForeverLoopsMaker: AstWalker() { val numeric = whileLoop.condition as? NumericLiteralValue if(numeric!=null && numeric.number.toInt() != 0) { val forever = ForeverLoop(whileLoop.body, whileLoop.position) - return listOf(IAstModification.ReplaceStmt(whileLoop, forever, parent)) + return listOf(IAstModification.ReplaceNode(whileLoop, forever, parent)) } return emptyList() } diff --git a/compiler/src/prog8/ast/processing/TypecastsAdder.kt b/compiler/src/prog8/ast/processing/TypecastsAdder.kt index df955185d..93587226b 100644 --- a/compiler/src/prog8/ast/processing/TypecastsAdder.kt +++ b/compiler/src/prog8/ast/processing/TypecastsAdder.kt @@ -26,14 +26,10 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke val (commonDt, toFix) = BinaryExpression.commonDatatype(leftDt.typeOrElse(DataType.STRUCT), rightDt.typeOrElse(DataType.STRUCT), expr.left, expr.right) if(toFix!=null) { return when { - toFix===expr.left -> listOf(IAstModification.ReplaceExpr( - { newExpr -> expr.left = newExpr }, - TypecastExpression(expr.left, commonDt, true, expr.left.position), - expr)) - toFix===expr.right -> listOf(IAstModification.ReplaceExpr( - { newExpr -> expr.right = newExpr }, - TypecastExpression(expr.right, commonDt, true, expr.right.position), - expr)) + toFix===expr.left -> listOf(IAstModification.ReplaceNode( + expr.left, TypecastExpression(expr.left, commonDt, true, expr.left.position), expr)) + toFix===expr.right -> listOf(IAstModification.ReplaceNode( + expr.right, TypecastExpression(expr.right, commonDt, true, expr.right.position), expr)) else -> throw FatalAstException("confused binary expression side") } } @@ -50,8 +46,8 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke val valuetype = valueItype.typeOrElse(DataType.STRUCT) if (valuetype != targettype) { if (valuetype isAssignableTo targettype) - return listOf(IAstModification.ReplaceExpr( - { newExpr -> assignment.value=newExpr }, + return listOf(IAstModification.ReplaceNode( + assignment.value, TypecastExpression(assignment.value, targettype, true, assignment.value.position), assignment)) } @@ -78,8 +74,8 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke val requiredType = arg.first.type if (requiredType != argtype) { if (argtype isAssignableTo requiredType) { - return listOf(IAstModification.ReplaceExpr( - { newExpr -> call.args[arg.second.index] = newExpr }, + return listOf(IAstModification.ReplaceNode( + call.args[arg.second.index], TypecastExpression(arg.second.value, requiredType, true, arg.second.value.position), call as Node)) } @@ -100,11 +96,10 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke continue for (possibleType in arg.first.possibleDatatypes) { if (argtype isAssignableTo possibleType) { - return listOf(IAstModification.ReplaceExpr( - { newExpr -> call.args[arg.second.index] = newExpr }, + return listOf(IAstModification.ReplaceNode( + call.args[arg.second.index], TypecastExpression(arg.second.value, possibleType, true, arg.second.value.position), - call as Node - )) + call as Node)) } } } @@ -131,11 +126,7 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke if(dt.isKnown && dt.typeOrElse(DataType.UWORD)!=DataType.UWORD) { val typecast = (memread.addressExpression as? NumericLiteralValue)?.cast(DataType.UWORD) ?: TypecastExpression(memread.addressExpression, DataType.UWORD, true, memread.addressExpression.position) - return listOf(IAstModification.ReplaceExpr( - { newExpr -> memread.addressExpression = newExpr }, - typecast, - memread - )) + return listOf(IAstModification.ReplaceNode(memread.addressExpression, typecast, memread)) } return emptyList() } @@ -146,11 +137,7 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke if(dt.isKnown && dt.typeOrElse(DataType.UWORD)!=DataType.UWORD) { val typecast = (memwrite.addressExpression as? NumericLiteralValue)?.cast(DataType.UWORD) ?: TypecastExpression(memwrite.addressExpression, DataType.UWORD, true, memwrite.addressExpression.position) - return listOf(IAstModification.ReplaceExpr( - { newExpr -> memwrite.addressExpression = newExpr }, - typecast, - memwrite - )) + return listOf(IAstModification.ReplaceNode(memwrite.addressExpression, typecast, memwrite)) } return emptyList() } diff --git a/compiler/src/prog8/ast/processing/VarInitValueAndAddressOfCreator.kt b/compiler/src/prog8/ast/processing/VarInitValueAndAddressOfCreator.kt index 7341d891e..3ad6d676d 100644 --- a/compiler/src/prog8/ast/processing/VarInitValueAndAddressOfCreator.kt +++ b/compiler/src/prog8/ast/processing/VarInitValueAndAddressOfCreator.kt @@ -23,17 +23,16 @@ internal class VarInitValueCreator(val program: Program): AstWalker() { // array datatype without initialization value, add list of zeros val arraysize = decl.arraysize!!.size()!! val zero = decl.asDefaultValueDecl(decl).value!! - return listOf(IAstModification.ReplaceExpr( + return listOf(IAstModification.SetExpression( // can't use replaceNode here because value is null { newExpr -> decl.value = newExpr }, ArrayLiteralValue(InferredTypes.InferredType.known(decl.datatype), Array(arraysize) { zero }, decl.position), - decl - )) + decl)) } - if(decl.type == VarDeclType.VAR && decl.value != null && decl.datatype in NumericDatatypes) { - val declvalue = decl.value!! + val declvalue = decl.value + if(decl.type == VarDeclType.VAR && declvalue != null && decl.datatype in NumericDatatypes) { val value = if(declvalue is NumericLiteralValue) declvalue.cast(decl.datatype) @@ -49,8 +48,8 @@ internal class VarInitValueCreator(val program: Program): AstWalker() { val zero = decl.asDefaultValueDecl(decl).value!! return listOf( IAstModification.Insert(decl, initvalue, parent), - IAstModification.ReplaceExpr( - { newExpr -> decl.value = newExpr }, + IAstModification.ReplaceNode( + declvalue, zero, decl ) @@ -100,11 +99,10 @@ internal class VarInitValueCreator(val program: Program): AstWalker() { if(idref!=null) { val variable = idref.targetVarDecl(program.namespace) if(variable!=null && variable.datatype in IterableDatatypes) { - replacements += IAstModification.ReplaceExpr( - { newExpr -> arglist[argparam.first.index] = newExpr }, + replacements += IAstModification.ReplaceNode( + arglist[argparam.first.index], AddressOf(idref, idref.position), - parent - ) + parent) } } } diff --git a/compiler/src/prog8/ast/statements/AstStatements.kt b/compiler/src/prog8/ast/statements/AstStatements.kt index a893f850e..202032609 100644 --- a/compiler/src/prog8/ast/statements/AstStatements.kt +++ b/compiler/src/prog8/ast/statements/AstStatements.kt @@ -48,6 +48,7 @@ class BuiltinFunctionStatementPlaceholder(val name: String, override val positio override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun definingScope(): INameScope = BuiltinFunctionScopePlaceholder + override fun replaceChildNode(node: Node, replacement: Node) {} override val expensiveToInline = false } @@ -67,6 +68,12 @@ class Block(override val name: String, statements.forEach {it.linkParents(this)} } + override fun replaceChildNode(node: Node, replacement: Node) { + require(replacement is Statement) + val idx = statements.indexOf(node) + statements[idx] = replacement + } + override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) @@ -87,6 +94,7 @@ data class Directive(val directive: String, val args: List, overri args.forEach{it.linkParents(this)} } + override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here") override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) @@ -98,6 +106,7 @@ data class DirectiveArg(val str: String?, val name: String?, val int: Int?, over override fun linkParents(parent: Node) { this.parent = parent } + override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here") } data class Label(val name: String, override val position: Position) : Statement() { @@ -108,6 +117,7 @@ data class Label(val name: String, override val position: Position) : Statement( this.parent = parent } + override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here") override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) @@ -126,6 +136,11 @@ open class Return(var value: Expression?, override val position: Position) : Sta value?.linkParents(this) } + override fun replaceChildNode(node: Node, replacement: Node) { + require(replacement is Expression) + value = replacement + } + override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) @@ -142,6 +157,7 @@ class ReturnFromIrq(override val position: Position) : Return(null, position) { override fun toString(): String { return "ReturnFromIrq(pos=$position)" } + override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here") } class Continue(override val position: Position) : Statement() { @@ -152,6 +168,7 @@ class Continue(override val position: Position) : Statement() { this.parent=parent } + override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here") override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) @@ -165,6 +182,7 @@ class Break(override val position: Position) : Statement() { this.parent=parent } + override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here") override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) @@ -243,6 +261,11 @@ class VarDecl(val type: VarDeclType, } } + override fun replaceChildNode(node: Node, replacement: Node) { + require(replacement is Expression && node===value) + value = replacement + } + override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) @@ -298,6 +321,11 @@ class ArrayIndex(var index: Expression, override val position: Position) : Node index.linkParents(this) } + override fun replaceChildNode(node: Node, replacement: Node) { + require(replacement is Expression && node===index) + index = replacement + } + companion object { fun forArray(v: ArrayLiteralValue): ArrayIndex { return ArrayIndex(NumericLiteralValue.optimalNumeric(v.value.size, v.position), v.position) @@ -328,6 +356,14 @@ open class Assignment(var target: AssignTarget, val aug_op : String?, var value: value.linkParents(this) } + override fun replaceChildNode(node: Node, replacement: Node) { + when { + node===target -> target = replacement as AssignTarget + node===value -> value = replacement as Expression + else -> throw FatalAstException("invalid replace") + } + } + override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) @@ -356,6 +392,14 @@ data class AssignTarget(val register: Register?, memoryAddress?.linkParents(this) } + override fun replaceChildNode(node: Node, replacement: Node) { + when { + node===identifier -> identifier = replacement as IdentifierReference + node===arrayindexed -> arrayindexed = replacement as ArrayIndexedExpression + else -> throw FatalAstException("invalid replace") + } + } + fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) fun accept(visitor: IAstVisitor) = visitor.visit(this) fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) @@ -455,6 +499,11 @@ class PostIncrDecr(var target: AssignTarget, val operator: String, override val target.linkParents(this) } + override fun replaceChildNode(node: Node, replacement: Node) { + require(replacement is AssignTarget && node===target) + target = replacement + } + override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) @@ -476,6 +525,7 @@ class Jump(val address: Int?, identifier?.linkParents(this) } + override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here") override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) @@ -499,6 +549,15 @@ class FunctionCallStatement(override var target: IdentifierReference, args.forEach { it.linkParents(this) } } + override fun replaceChildNode(node: Node, replacement: Node) { + if(node===target) + target = replacement as IdentifierReference + else { + val idx = args.indexOf(node) + args[idx] = replacement as Expression + } + } + override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) @@ -516,6 +575,7 @@ class InlineAssembly(val assembly: String, override val position: Position) : St this.parent = parent } + override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here") override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) @@ -542,6 +602,12 @@ class AnonymousScope(override var statements: MutableList, statements.forEach { it.linkParents(this) } } + override fun replaceChildNode(node: Node, replacement: Node) { + require(replacement is Statement) + val idx = statements.indexOf(node) + statements[idx] = replacement + } + override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) @@ -555,6 +621,7 @@ class NopStatement(override val position: Position): Statement() { this.parent = parent } + override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here") override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) @@ -598,6 +665,12 @@ class Subroutine(override val name: String, statements.forEach { it.linkParents(this) } } + override fun replaceChildNode(node: Node, replacement: Node) { + require(replacement is Statement) + val idx = statements.indexOf(node) + statements[idx] = replacement + } + override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) @@ -622,6 +695,10 @@ open class SubroutineParameter(val name: String, override fun linkParents(parent: Node) { this.parent = parent } + + override fun replaceChildNode(node: Node, replacement: Node) { + throw FatalAstException("can't replace anything in a subroutineparameter node") + } } class IfStatement(var condition: Expression, @@ -639,6 +716,15 @@ class IfStatement(var condition: Expression, elsepart.linkParents(this) } + override fun replaceChildNode(node: Node, replacement: Node) { + when { + node===condition -> condition = replacement as Expression + node===truepart -> truepart = replacement as AnonymousScope + node===elsepart -> elsepart = replacement as AnonymousScope + else -> throw FatalAstException("invalid replace") + } + } + override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) @@ -659,6 +745,14 @@ class BranchStatement(var condition: BranchCondition, elsepart.linkParents(this) } + override fun replaceChildNode(node: Node, replacement: Node) { + when { + node===truepart -> truepart = replacement as AnonymousScope + node===elsepart -> elsepart = replacement as AnonymousScope + else -> throw FatalAstException("invalid replace") + } + } + override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) @@ -680,6 +774,15 @@ class ForLoop(val loopRegister: Register?, body.linkParents(this) } + override fun replaceChildNode(node: Node, replacement: Node) { + when { + node===loopVar -> loopVar = replacement as IdentifierReference + node===iterable -> iterable = replacement as Expression + node===body -> body = replacement as AnonymousScope + else -> throw FatalAstException("invalid replace") + } + } + override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) @@ -707,6 +810,14 @@ class WhileLoop(var condition: Expression, body.linkParents(this) } + override fun replaceChildNode(node: Node, replacement: Node) { + when { + node===condition -> condition = replacement as Expression + node===body -> body = replacement as AnonymousScope + else -> throw FatalAstException("invalid replace") + } + } + override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) @@ -721,6 +832,11 @@ class ForeverLoop(var body: AnonymousScope, override val position: Position) : S body.linkParents(this) } + override fun replaceChildNode(node: Node, replacement: Node) { + require(replacement is AnonymousScope && node===body) + body = replacement + } + override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) @@ -738,6 +854,14 @@ class RepeatLoop(var body: AnonymousScope, body.linkParents(this) } + override fun replaceChildNode(node: Node, replacement: Node) { + when { + node===untilCondition -> untilCondition = replacement as Expression + node===body -> body = replacement as AnonymousScope + else -> throw FatalAstException("invalid replace") + } + } + override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this) override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) @@ -755,6 +879,15 @@ class WhenStatement(var condition: Expression, choices.forEach { it.linkParents(this) } } + override fun replaceChildNode(node: Node, replacement: Node) { + if(node===condition) + condition = replacement as Expression + else { + val idx = choices.indexOf(node) + choices[idx] = replacement as WhenChoice + } + } + fun choiceValues(program: Program): List?, WhenChoice>> { // only gives sensible results when the choices are all valid (constant integers) val result = mutableListOf?, WhenChoice>>() @@ -788,6 +921,11 @@ class WhenChoice(var values: List?, // if null, this is t this.parent = parent } + override fun replaceChildNode(node: Node, replacement: Node) { + require(replacement is AnonymousScope && node===statements) + statements = replacement + } + override fun toString(): String { return "Choice($values at $position)" } @@ -810,6 +948,12 @@ class StructDecl(override val name: String, this.statements.forEach { it.linkParents(this) } } + override fun replaceChildNode(node: Node, replacement: Node) { + require(replacement is Statement) + val idx = statements.indexOf(node) + statements[idx] = replacement + } + val numberOfElements: Int get() = this.statements.size @@ -828,6 +972,11 @@ class DirectMemoryWrite(var addressExpression: Expression, override val position this.addressExpression.linkParents(this) } + override fun replaceChildNode(node: Node, replacement: Node) { + require(replacement is Expression && node===addressExpression) + addressExpression = replacement + } + override fun toString(): String { return "DirectMemoryWrite($addressExpression)" } diff --git a/compiler/src/prog8/optimizer/ExpressionSimplifier.kt b/compiler/src/prog8/optimizer/ExpressionSimplifier.kt index 1fb16a03b..aa46e1d36 100644 --- a/compiler/src/prog8/optimizer/ExpressionSimplifier.kt +++ b/compiler/src/prog8/optimizer/ExpressionSimplifier.kt @@ -1,11 +1,13 @@ package prog8.optimizer +import prog8.ast.Node import prog8.ast.Program import prog8.ast.base.* import prog8.ast.expressions.* +import prog8.ast.processing.AstWalker +import prog8.ast.processing.IAstModification import prog8.ast.processing.IAstModifyingVisitor import prog8.ast.statements.Assignment -import prog8.ast.statements.Statement import kotlin.math.abs import kotlin.math.log2 import kotlin.math.pow @@ -17,68 +19,49 @@ import kotlin.math.pow */ -internal class ExpressionSimplifier(private val program: Program) : IAstModifyingVisitor { - var optimizationsDone: Int = 0 - override fun visit(assignment: Assignment): Statement { +class ExpressionSimplifier2(private val program: Program): AstWalker() { + var optimizationsDone = 0 + + override fun after(assignment: Assignment, parent: Node): Iterable { if (assignment.aug_op != null) throw AstException("augmented assignments should have been converted to normal assignments before this optimizer: $assignment") - return super.visit(assignment) + return emptyList() } - override fun visit(memread: DirectMemoryRead): Expression { - // @( &thing ) --> thing - val addrOf = memread.addressExpression as? AddressOf - if(addrOf!=null) - return super.visit(addrOf.identifier) - return super.visit(memread) - } - - override fun visit(typecast: TypecastExpression): Expression { - var tc = typecast + override fun after(typecast: TypecastExpression, parent: Node): Iterable { + val mods = mutableListOf() // try to statically convert a literal value into one of the desired type - val literal = tc.expression as? NumericLiteralValue + val literal = typecast.expression as? NumericLiteralValue if(literal!=null) { - val newLiteral = literal.cast(tc.type) + val newLiteral = literal.cast(typecast.type) if(newLiteral!==literal) { optimizationsDone++ - return newLiteral + mods += IAstModification.ReplaceNode(typecast.expression, newLiteral, typecast) } } - // remove redundant typecasts - while(true) { - val expr = tc.expression - if(expr !is TypecastExpression || expr.type!=tc.type) { - val assignment = typecast.parent as? Assignment - if(assignment!=null) { - val targetDt = assignment.target.inferType(program, assignment) - if(tc.expression.inferType(program)==targetDt) { - optimizationsDone++ - return tc.expression - } - } - - val subTc = tc.expression as? TypecastExpression - if(subTc!=null) { - // if the previous typecast was casting to a 'bigger' type, just ignore that one - // if the previous typecast was casting to a similar type, ignore that one - if(subTc.type largerThan tc.type || subTc.type equalsSize tc.type) { - subTc.type = tc.type - subTc.parent = tc.parent - optimizationsDone++ - return subTc - } - } - - return super.visit(tc) - } - - optimizationsDone++ - tc = expr + // remove redundant nested typecasts: + // if the typecast casts a value to the same type, remove the cast. + // if the typecast contains another typecast, remove the inner typecast. + val subTypecast = typecast.expression as? TypecastExpression + if(subTypecast!=null) { + mods += IAstModification.ReplaceNode(typecast.expression, subTypecast.expression, typecast) + } else { + if(typecast.expression.inferType(program).istype(typecast.type)) + mods += IAstModification.ReplaceNode(typecast, typecast.expression, parent) } + + optimizationsDone += mods.size + return mods } +} + + + +internal class ExpressionSimplifier(private val program: Program) : IAstModifyingVisitor { + var optimizationsDone: Int = 0 override fun visit(expr: PrefixExpression): Expression { if (expr.operator == "+") { diff --git a/compiler/src/prog8/optimizer/Extensions.kt b/compiler/src/prog8/optimizer/Extensions.kt index f14767d27..983c0d160 100644 --- a/compiler/src/prog8/optimizer/Extensions.kt +++ b/compiler/src/prog8/optimizer/Extensions.kt @@ -27,7 +27,11 @@ internal fun Program.optimizeStatements(errors: ErrorReporter): Int { } internal fun Program.simplifyExpressions() : Int { + val opti = ExpressionSimplifier2(this) + opti.visit(this) + opti.applyModifications() + val optimizer = ExpressionSimplifier(this) optimizer.visit(this) - return optimizer.optimizationsDone + return opti.optimizationsDone + optimizer.optimizationsDone } diff --git a/examples/test.p8 b/examples/test.p8 index f98d35ff4..e91cf3c20 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -11,6 +11,9 @@ main { ubyte key=c64.GETIN() ubyte[] zzzz = [1,2,3] + A = 9.0 as ubyte as uword as ubyte as ubyte + A = Y as ubyte + A = @(&bb1) A = len(meuk) A = msb(meuk) ; A = strlen(meuk) @@ -26,6 +29,4 @@ main { c64.CHROUT('\n') c64.CHROUT('\n') } - - - } +}