tweaked ast modifications

This commit is contained in:
Irmen de Jong 2020-03-21 18:39:36 +01:00
parent 824f06e17f
commit 43781c02d0
11 changed files with 320 additions and 104 deletions

View File

@ -36,6 +36,8 @@ interface Node {
return this
throw FatalAstException("scope missing from $this")
}
fun replaceChildNode(node: Node, replacement: Node)
}
interface IFunctionCall {
@ -226,6 +228,12 @@ class Program(val name: String, val modules: MutableList<Module>): Node {
it.linkParents(this)
}
}
override fun replaceChildNode(node: Node, replacement: Node) {
require(node is Module && replacement is Module)
val idx = modules.indexOf(node)
modules[idx] = replacement
}
}
class Module(override val name: String,
@ -247,6 +255,11 @@ class Module(override val name: String,
}
override fun definingScope(): INameScope = program.namespace
override fun replaceChildNode(node: Node, replacement: Node) {
require(node is Statement && replacement is Statement)
val idx = statements.indexOf(node)
statements[idx] = replacement
}
override fun toString() = "Module(name=$name, pos=$position, lib=$isLibraryModule)"
@ -266,6 +279,10 @@ class GlobalNamespace(val modules: List<Module>): Node, INameScope {
modules.forEach { it.linkParents(this) }
}
override fun replaceChildNode(node: Node, replacement: Node) {
throw FatalAstException("cannot replace anything in the namespace")
}
override fun lookup(scopedName: List<String>, localContext: Node): Statement? {
if (scopedName.size == 1 && scopedName[0] in BuiltinFunctions) {
// builtin functions always exist, return a dummy localContext for them

View File

@ -150,6 +150,7 @@ object ParentSentinel : Node {
override val position = Position("<<sentinel>>", 0, 0, 0)
override var parent: Node = this
override fun linkParents(parent: Node) {}
override fun replaceChildNode(node: Node, replacement: Node) {}
}
data class Position(val file: String, val line: Int, val startCol: Int, val endCol: Int) {

View File

@ -58,6 +58,11 @@ class PrefixExpression(val operator: String, var expression: Expression, overrid
expression.linkParents(this)
}
override fun replaceChildNode(node: Node, replacement: Node) {
require(node === expression && replacement is Expression)
expression = replacement
}
override fun constValue(program: Program): NumericLiteralValue? = null
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
@ -100,6 +105,15 @@ class BinaryExpression(var left: Expression, var operator: String, var right: Ex
right.linkParents(this)
}
override fun replaceChildNode(node: Node, replacement: Node) {
require(replacement is Expression)
when {
node===left -> left = replacement
node===right -> right = replacement
else -> throw FatalAstException("invalid replace")
}
}
override fun toString(): String {
return "[$left $operator $right]"
}
@ -211,6 +225,11 @@ class ArrayIndexedExpression(var identifier: IdentifierReference,
arrayspec.linkParents(this)
}
override fun replaceChildNode(node: Node, replacement: Node) {
require(replacement is IdentifierReference && node===identifier)
identifier = replacement
}
override fun constValue(program: Program): NumericLiteralValue? = null
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
@ -243,6 +262,11 @@ class TypecastExpression(var expression: Expression, var type: DataType, val imp
expression.linkParents(this)
}
override fun replaceChildNode(node: Node, replacement: Node) {
require(replacement is Expression && node===expression)
expression = replacement
}
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent)
@ -269,6 +293,11 @@ data class AddressOf(var identifier: IdentifierReference, override val position:
identifier.parent=this
}
override fun replaceChildNode(node: Node, replacement: Node) {
require(replacement is IdentifierReference && node===identifier)
identifier = replacement
}
override fun constValue(program: Program): NumericLiteralValue? = null
override fun referencesIdentifiers(vararg name: String) = false
override fun inferType(program: Program): InferredTypes.InferredType = InferredTypes.knownFor(DataType.UWORD)
@ -285,6 +314,11 @@ class DirectMemoryRead(var addressExpression: Expression, override val position:
this.addressExpression.linkParents(this)
}
override fun replaceChildNode(node: Node, replacement: Node) {
require(replacement is Expression && node===addressExpression)
addressExpression = replacement
}
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node)= visitor.visit(this, parent)
@ -338,6 +372,10 @@ class NumericLiteralValue(val type: DataType, // only numerical types allowed
this.parent = parent
}
override fun replaceChildNode(node: Node, replacement: Node) {
throw FatalAstException("can't replace here")
}
override fun referencesIdentifiers(vararg name: String) = false
override fun constValue(program: Program) = this
@ -427,6 +465,10 @@ class StructLiteralValue(var values: List<Expression>,
values.forEach { it.linkParents(this) }
}
override fun replaceChildNode(node: Node, replacement: Node) {
throw FatalAstException("can't replace here")
}
override fun constValue(program: Program): NumericLiteralValue? = null
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
@ -452,6 +494,11 @@ class StringLiteralValue(val value: String,
override fun linkParents(parent: Node) {
this.parent = parent
}
override fun replaceChildNode(node: Node, replacement: Node) {
throw FatalAstException("can't replace here")
}
override fun referencesIdentifiers(vararg name: String) = false
override fun constValue(program: Program): NumericLiteralValue? = null
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
@ -480,6 +527,13 @@ class ArrayLiteralValue(val type: InferredTypes.InferredType, // inferred be
this.parent = parent
value.forEach {it.linkParents(this)}
}
override fun replaceChildNode(node: Node, replacement: Node) {
require(replacement is Expression)
val idx = value.indexOf(node)
value[idx] = replacement
}
override fun referencesIdentifiers(vararg name: String) = value.any { it.referencesIdentifiers(*name) }
override fun constValue(program: Program): NumericLiteralValue? = null
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
@ -564,6 +618,16 @@ class RangeExpr(var from: Expression,
step.linkParents(this)
}
override fun replaceChildNode(node: Node, replacement: Node) {
require(replacement is Expression)
when {
from===node -> from=replacement
to===node -> to=replacement
step===node -> step=replacement
else -> throw FatalAstException("invalid replacement")
}
}
override fun constValue(program: Program): NumericLiteralValue? = null
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
@ -640,6 +704,10 @@ class RegisterExpr(val register: Register, override val position: Position) : Ex
this.parent = parent
}
override fun replaceChildNode(node: Node, replacement: Node) {
throw FatalAstException("can't replace here")
}
override fun constValue(program: Program): NumericLiteralValue? = null
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
@ -669,6 +737,10 @@ data class IdentifierReference(val nameInSource: List<String>, override val posi
this.parent = parent
}
override fun replaceChildNode(node: Node, replacement: Node) {
throw FatalAstException("can't replace here")
}
override fun constValue(program: Program): NumericLiteralValue? {
val node = program.namespace.lookup(nameInSource, this)
?: throw UndefinedSymbolError(this)
@ -725,6 +797,15 @@ class FunctionCall(override var target: IdentifierReference,
args.forEach { it.linkParents(this) }
}
override fun replaceChildNode(node: Node, replacement: Node) {
if(node===target)
target=replacement as IdentifierReference
else {
val idx = args.indexOf(node)
args[idx] = replacement as Expression
}
}
override fun constValue(program: Program) = constValue(program, true)
private fun constValue(program: Program, withDatatypeCheck: Boolean): NumericLiteralValue? {

View File

@ -23,19 +23,7 @@ interface IAstModification {
}
}
class ReplaceStmt(val statement: Statement, val replacement: Statement, val parent: Node) : IAstModification {
override fun perform() {
if(parent is INameScope) {
val idx = parent.statements.indexOf(statement)
parent.statements[idx] = replacement
replacement.linkParents(parent)
} else {
throw FatalAstException("parent of a replace modification is not an INameScope")
}
}
}
class ReplaceExpr(val setter: (newExpr: Expression) -> Unit, val newExpr: Expression, val parent: Node) : IAstModification {
class SetExpression(val setter: (newExpr: Expression) -> Unit, val newExpr: Expression, val parent: Node) : IAstModification {
override fun perform() {
setter(newExpr)
newExpr.linkParents(parent)
@ -53,6 +41,13 @@ interface IAstModification {
}
}
}
class ReplaceNode(val node: Node, val replacement: Node, val parent: Node) : IAstModification {
override fun perform() {
parent.replaceChildNode(node, replacement)
replacement.parent = parent
}
}
}

View File

@ -12,7 +12,7 @@ internal class ForeverLoopsMaker: AstWalker() {
val numeric = repeatLoop.untilCondition as? NumericLiteralValue
if(numeric!=null && numeric.number.toInt() == 0) {
val forever = ForeverLoop(repeatLoop.body, repeatLoop.position)
return listOf(IAstModification.ReplaceStmt(repeatLoop, forever, parent))
return listOf(IAstModification.ReplaceNode(repeatLoop, forever, parent))
}
return emptyList()
}
@ -21,7 +21,7 @@ internal class ForeverLoopsMaker: AstWalker() {
val numeric = whileLoop.condition as? NumericLiteralValue
if(numeric!=null && numeric.number.toInt() != 0) {
val forever = ForeverLoop(whileLoop.body, whileLoop.position)
return listOf(IAstModification.ReplaceStmt(whileLoop, forever, parent))
return listOf(IAstModification.ReplaceNode(whileLoop, forever, parent))
}
return emptyList()
}

View File

@ -26,14 +26,10 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke
val (commonDt, toFix) = BinaryExpression.commonDatatype(leftDt.typeOrElse(DataType.STRUCT), rightDt.typeOrElse(DataType.STRUCT), expr.left, expr.right)
if(toFix!=null) {
return when {
toFix===expr.left -> listOf(IAstModification.ReplaceExpr(
{ newExpr -> expr.left = newExpr },
TypecastExpression(expr.left, commonDt, true, expr.left.position),
expr))
toFix===expr.right -> listOf(IAstModification.ReplaceExpr(
{ newExpr -> expr.right = newExpr },
TypecastExpression(expr.right, commonDt, true, expr.right.position),
expr))
toFix===expr.left -> listOf(IAstModification.ReplaceNode(
expr.left, TypecastExpression(expr.left, commonDt, true, expr.left.position), expr))
toFix===expr.right -> listOf(IAstModification.ReplaceNode(
expr.right, TypecastExpression(expr.right, commonDt, true, expr.right.position), expr))
else -> throw FatalAstException("confused binary expression side")
}
}
@ -50,8 +46,8 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke
val valuetype = valueItype.typeOrElse(DataType.STRUCT)
if (valuetype != targettype) {
if (valuetype isAssignableTo targettype)
return listOf(IAstModification.ReplaceExpr(
{ newExpr -> assignment.value=newExpr },
return listOf(IAstModification.ReplaceNode(
assignment.value,
TypecastExpression(assignment.value, targettype, true, assignment.value.position),
assignment))
}
@ -78,8 +74,8 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke
val requiredType = arg.first.type
if (requiredType != argtype) {
if (argtype isAssignableTo requiredType) {
return listOf(IAstModification.ReplaceExpr(
{ newExpr -> call.args[arg.second.index] = newExpr },
return listOf(IAstModification.ReplaceNode(
call.args[arg.second.index],
TypecastExpression(arg.second.value, requiredType, true, arg.second.value.position),
call as Node))
}
@ -100,11 +96,10 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke
continue
for (possibleType in arg.first.possibleDatatypes) {
if (argtype isAssignableTo possibleType) {
return listOf(IAstModification.ReplaceExpr(
{ newExpr -> call.args[arg.second.index] = newExpr },
return listOf(IAstModification.ReplaceNode(
call.args[arg.second.index],
TypecastExpression(arg.second.value, possibleType, true, arg.second.value.position),
call as Node
))
call as Node))
}
}
}
@ -131,11 +126,7 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke
if(dt.isKnown && dt.typeOrElse(DataType.UWORD)!=DataType.UWORD) {
val typecast = (memread.addressExpression as? NumericLiteralValue)?.cast(DataType.UWORD)
?: TypecastExpression(memread.addressExpression, DataType.UWORD, true, memread.addressExpression.position)
return listOf(IAstModification.ReplaceExpr(
{ newExpr -> memread.addressExpression = newExpr },
typecast,
memread
))
return listOf(IAstModification.ReplaceNode(memread.addressExpression, typecast, memread))
}
return emptyList()
}
@ -146,11 +137,7 @@ class TypecastsAdder(val program: Program, val errors: ErrorReporter) : AstWalke
if(dt.isKnown && dt.typeOrElse(DataType.UWORD)!=DataType.UWORD) {
val typecast = (memwrite.addressExpression as? NumericLiteralValue)?.cast(DataType.UWORD)
?: TypecastExpression(memwrite.addressExpression, DataType.UWORD, true, memwrite.addressExpression.position)
return listOf(IAstModification.ReplaceExpr(
{ newExpr -> memwrite.addressExpression = newExpr },
typecast,
memwrite
))
return listOf(IAstModification.ReplaceNode(memwrite.addressExpression, typecast, memwrite))
}
return emptyList()
}

View File

@ -23,17 +23,16 @@ internal class VarInitValueCreator(val program: Program): AstWalker() {
// array datatype without initialization value, add list of zeros
val arraysize = decl.arraysize!!.size()!!
val zero = decl.asDefaultValueDecl(decl).value!!
return listOf(IAstModification.ReplaceExpr(
return listOf(IAstModification.SetExpression( // can't use replaceNode here because value is null
{ newExpr -> decl.value = newExpr },
ArrayLiteralValue(InferredTypes.InferredType.known(decl.datatype),
Array(arraysize) { zero },
decl.position),
decl
))
decl))
}
if(decl.type == VarDeclType.VAR && decl.value != null && decl.datatype in NumericDatatypes) {
val declvalue = decl.value!!
val declvalue = decl.value
if(decl.type == VarDeclType.VAR && declvalue != null && decl.datatype in NumericDatatypes) {
val value =
if(declvalue is NumericLiteralValue)
declvalue.cast(decl.datatype)
@ -49,8 +48,8 @@ internal class VarInitValueCreator(val program: Program): AstWalker() {
val zero = decl.asDefaultValueDecl(decl).value!!
return listOf(
IAstModification.Insert(decl, initvalue, parent),
IAstModification.ReplaceExpr(
{ newExpr -> decl.value = newExpr },
IAstModification.ReplaceNode(
declvalue,
zero,
decl
)
@ -100,11 +99,10 @@ internal class VarInitValueCreator(val program: Program): AstWalker() {
if(idref!=null) {
val variable = idref.targetVarDecl(program.namespace)
if(variable!=null && variable.datatype in IterableDatatypes) {
replacements += IAstModification.ReplaceExpr(
{ newExpr -> arglist[argparam.first.index] = newExpr },
replacements += IAstModification.ReplaceNode(
arglist[argparam.first.index],
AddressOf(idref, idref.position),
parent
)
parent)
}
}
}

View File

@ -48,6 +48,7 @@ class BuiltinFunctionStatementPlaceholder(val name: String, override val positio
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
override fun definingScope(): INameScope = BuiltinFunctionScopePlaceholder
override fun replaceChildNode(node: Node, replacement: Node) {}
override val expensiveToInline = false
}
@ -67,6 +68,12 @@ class Block(override val name: String,
statements.forEach {it.linkParents(this)}
}
override fun replaceChildNode(node: Node, replacement: Node) {
require(replacement is Statement)
val idx = statements.indexOf(node)
statements[idx] = replacement
}
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -87,6 +94,7 @@ data class Directive(val directive: String, val args: List<DirectiveArg>, overri
args.forEach{it.linkParents(this)}
}
override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here")
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -98,6 +106,7 @@ data class DirectiveArg(val str: String?, val name: String?, val int: Int?, over
override fun linkParents(parent: Node) {
this.parent = parent
}
override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here")
}
data class Label(val name: String, override val position: Position) : Statement() {
@ -108,6 +117,7 @@ data class Label(val name: String, override val position: Position) : Statement(
this.parent = parent
}
override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here")
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -126,6 +136,11 @@ open class Return(var value: Expression?, override val position: Position) : Sta
value?.linkParents(this)
}
override fun replaceChildNode(node: Node, replacement: Node) {
require(replacement is Expression)
value = replacement
}
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -142,6 +157,7 @@ class ReturnFromIrq(override val position: Position) : Return(null, position) {
override fun toString(): String {
return "ReturnFromIrq(pos=$position)"
}
override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here")
}
class Continue(override val position: Position) : Statement() {
@ -152,6 +168,7 @@ class Continue(override val position: Position) : Statement() {
this.parent=parent
}
override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here")
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -165,6 +182,7 @@ class Break(override val position: Position) : Statement() {
this.parent=parent
}
override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here")
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -243,6 +261,11 @@ class VarDecl(val type: VarDeclType,
}
}
override fun replaceChildNode(node: Node, replacement: Node) {
require(replacement is Expression && node===value)
value = replacement
}
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -298,6 +321,11 @@ class ArrayIndex(var index: Expression, override val position: Position) : Node
index.linkParents(this)
}
override fun replaceChildNode(node: Node, replacement: Node) {
require(replacement is Expression && node===index)
index = replacement
}
companion object {
fun forArray(v: ArrayLiteralValue): ArrayIndex {
return ArrayIndex(NumericLiteralValue.optimalNumeric(v.value.size, v.position), v.position)
@ -328,6 +356,14 @@ open class Assignment(var target: AssignTarget, val aug_op : String?, var value:
value.linkParents(this)
}
override fun replaceChildNode(node: Node, replacement: Node) {
when {
node===target -> target = replacement as AssignTarget
node===value -> value = replacement as Expression
else -> throw FatalAstException("invalid replace")
}
}
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -356,6 +392,14 @@ data class AssignTarget(val register: Register?,
memoryAddress?.linkParents(this)
}
override fun replaceChildNode(node: Node, replacement: Node) {
when {
node===identifier -> identifier = replacement as IdentifierReference
node===arrayindexed -> arrayindexed = replacement as ArrayIndexedExpression
else -> throw FatalAstException("invalid replace")
}
}
fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
fun accept(visitor: IAstVisitor) = visitor.visit(this)
fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -455,6 +499,11 @@ class PostIncrDecr(var target: AssignTarget, val operator: String, override val
target.linkParents(this)
}
override fun replaceChildNode(node: Node, replacement: Node) {
require(replacement is AssignTarget && node===target)
target = replacement
}
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -476,6 +525,7 @@ class Jump(val address: Int?,
identifier?.linkParents(this)
}
override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here")
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -499,6 +549,15 @@ class FunctionCallStatement(override var target: IdentifierReference,
args.forEach { it.linkParents(this) }
}
override fun replaceChildNode(node: Node, replacement: Node) {
if(node===target)
target = replacement as IdentifierReference
else {
val idx = args.indexOf(node)
args[idx] = replacement as Expression
}
}
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -516,6 +575,7 @@ class InlineAssembly(val assembly: String, override val position: Position) : St
this.parent = parent
}
override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here")
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -542,6 +602,12 @@ class AnonymousScope(override var statements: MutableList<Statement>,
statements.forEach { it.linkParents(this) }
}
override fun replaceChildNode(node: Node, replacement: Node) {
require(replacement is Statement)
val idx = statements.indexOf(node)
statements[idx] = replacement
}
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -555,6 +621,7 @@ class NopStatement(override val position: Position): Statement() {
this.parent = parent
}
override fun replaceChildNode(node: Node, replacement: Node) = throw FatalAstException("can't replace here")
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -598,6 +665,12 @@ class Subroutine(override val name: String,
statements.forEach { it.linkParents(this) }
}
override fun replaceChildNode(node: Node, replacement: Node) {
require(replacement is Statement)
val idx = statements.indexOf(node)
statements[idx] = replacement
}
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -622,6 +695,10 @@ open class SubroutineParameter(val name: String,
override fun linkParents(parent: Node) {
this.parent = parent
}
override fun replaceChildNode(node: Node, replacement: Node) {
throw FatalAstException("can't replace anything in a subroutineparameter node")
}
}
class IfStatement(var condition: Expression,
@ -639,6 +716,15 @@ class IfStatement(var condition: Expression,
elsepart.linkParents(this)
}
override fun replaceChildNode(node: Node, replacement: Node) {
when {
node===condition -> condition = replacement as Expression
node===truepart -> truepart = replacement as AnonymousScope
node===elsepart -> elsepart = replacement as AnonymousScope
else -> throw FatalAstException("invalid replace")
}
}
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -659,6 +745,14 @@ class BranchStatement(var condition: BranchCondition,
elsepart.linkParents(this)
}
override fun replaceChildNode(node: Node, replacement: Node) {
when {
node===truepart -> truepart = replacement as AnonymousScope
node===elsepart -> elsepart = replacement as AnonymousScope
else -> throw FatalAstException("invalid replace")
}
}
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -680,6 +774,15 @@ class ForLoop(val loopRegister: Register?,
body.linkParents(this)
}
override fun replaceChildNode(node: Node, replacement: Node) {
when {
node===loopVar -> loopVar = replacement as IdentifierReference
node===iterable -> iterable = replacement as Expression
node===body -> body = replacement as AnonymousScope
else -> throw FatalAstException("invalid replace")
}
}
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -707,6 +810,14 @@ class WhileLoop(var condition: Expression,
body.linkParents(this)
}
override fun replaceChildNode(node: Node, replacement: Node) {
when {
node===condition -> condition = replacement as Expression
node===body -> body = replacement as AnonymousScope
else -> throw FatalAstException("invalid replace")
}
}
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -721,6 +832,11 @@ class ForeverLoop(var body: AnonymousScope, override val position: Position) : S
body.linkParents(this)
}
override fun replaceChildNode(node: Node, replacement: Node) {
require(replacement is AnonymousScope && node===body)
body = replacement
}
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -738,6 +854,14 @@ class RepeatLoop(var body: AnonymousScope,
body.linkParents(this)
}
override fun replaceChildNode(node: Node, replacement: Node) {
when {
node===untilCondition -> untilCondition = replacement as Expression
node===body -> body = replacement as AnonymousScope
else -> throw FatalAstException("invalid replace")
}
}
override fun accept(visitor: IAstModifyingVisitor) = visitor.visit(this)
override fun accept(visitor: IAstVisitor) = visitor.visit(this)
override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent)
@ -755,6 +879,15 @@ class WhenStatement(var condition: Expression,
choices.forEach { it.linkParents(this) }
}
override fun replaceChildNode(node: Node, replacement: Node) {
if(node===condition)
condition = replacement as Expression
else {
val idx = choices.indexOf(node)
choices[idx] = replacement as WhenChoice
}
}
fun choiceValues(program: Program): List<Pair<List<Int>?, WhenChoice>> {
// only gives sensible results when the choices are all valid (constant integers)
val result = mutableListOf<Pair<List<Int>?, WhenChoice>>()
@ -788,6 +921,11 @@ class WhenChoice(var values: List<Expression>?, // if null, this is t
this.parent = parent
}
override fun replaceChildNode(node: Node, replacement: Node) {
require(replacement is AnonymousScope && node===statements)
statements = replacement
}
override fun toString(): String {
return "Choice($values at $position)"
}
@ -810,6 +948,12 @@ class StructDecl(override val name: String,
this.statements.forEach { it.linkParents(this) }
}
override fun replaceChildNode(node: Node, replacement: Node) {
require(replacement is Statement)
val idx = statements.indexOf(node)
statements[idx] = replacement
}
val numberOfElements: Int
get() = this.statements.size
@ -828,6 +972,11 @@ class DirectMemoryWrite(var addressExpression: Expression, override val position
this.addressExpression.linkParents(this)
}
override fun replaceChildNode(node: Node, replacement: Node) {
require(replacement is Expression && node===addressExpression)
addressExpression = replacement
}
override fun toString(): String {
return "DirectMemoryWrite($addressExpression)"
}

View File

@ -1,11 +1,13 @@
package prog8.optimizer
import prog8.ast.Node
import prog8.ast.Program
import prog8.ast.base.*
import prog8.ast.expressions.*
import prog8.ast.processing.AstWalker
import prog8.ast.processing.IAstModification
import prog8.ast.processing.IAstModifyingVisitor
import prog8.ast.statements.Assignment
import prog8.ast.statements.Statement
import kotlin.math.abs
import kotlin.math.log2
import kotlin.math.pow
@ -17,68 +19,49 @@ import kotlin.math.pow
*/
internal class ExpressionSimplifier(private val program: Program) : IAstModifyingVisitor {
var optimizationsDone: Int = 0
override fun visit(assignment: Assignment): Statement {
class ExpressionSimplifier2(private val program: Program): AstWalker() {
var optimizationsDone = 0
override fun after(assignment: Assignment, parent: Node): Iterable<IAstModification> {
if (assignment.aug_op != null)
throw AstException("augmented assignments should have been converted to normal assignments before this optimizer: $assignment")
return super.visit(assignment)
return emptyList()
}
override fun visit(memread: DirectMemoryRead): Expression {
// @( &thing ) --> thing
val addrOf = memread.addressExpression as? AddressOf
if(addrOf!=null)
return super.visit(addrOf.identifier)
return super.visit(memread)
}
override fun visit(typecast: TypecastExpression): Expression {
var tc = typecast
override fun after(typecast: TypecastExpression, parent: Node): Iterable<IAstModification> {
val mods = mutableListOf<IAstModification>()
// try to statically convert a literal value into one of the desired type
val literal = tc.expression as? NumericLiteralValue
val literal = typecast.expression as? NumericLiteralValue
if(literal!=null) {
val newLiteral = literal.cast(tc.type)
val newLiteral = literal.cast(typecast.type)
if(newLiteral!==literal) {
optimizationsDone++
return newLiteral
mods += IAstModification.ReplaceNode(typecast.expression, newLiteral, typecast)
}
}
// remove redundant typecasts
while(true) {
val expr = tc.expression
if(expr !is TypecastExpression || expr.type!=tc.type) {
val assignment = typecast.parent as? Assignment
if(assignment!=null) {
val targetDt = assignment.target.inferType(program, assignment)
if(tc.expression.inferType(program)==targetDt) {
optimizationsDone++
return tc.expression
}
}
val subTc = tc.expression as? TypecastExpression
if(subTc!=null) {
// if the previous typecast was casting to a 'bigger' type, just ignore that one
// if the previous typecast was casting to a similar type, ignore that one
if(subTc.type largerThan tc.type || subTc.type equalsSize tc.type) {
subTc.type = tc.type
subTc.parent = tc.parent
optimizationsDone++
return subTc
}
}
return super.visit(tc)
}
optimizationsDone++
tc = expr
// remove redundant nested typecasts:
// if the typecast casts a value to the same type, remove the cast.
// if the typecast contains another typecast, remove the inner typecast.
val subTypecast = typecast.expression as? TypecastExpression
if(subTypecast!=null) {
mods += IAstModification.ReplaceNode(typecast.expression, subTypecast.expression, typecast)
} else {
if(typecast.expression.inferType(program).istype(typecast.type))
mods += IAstModification.ReplaceNode(typecast, typecast.expression, parent)
}
optimizationsDone += mods.size
return mods
}
}
internal class ExpressionSimplifier(private val program: Program) : IAstModifyingVisitor {
var optimizationsDone: Int = 0
override fun visit(expr: PrefixExpression): Expression {
if (expr.operator == "+") {

View File

@ -27,7 +27,11 @@ internal fun Program.optimizeStatements(errors: ErrorReporter): Int {
}
internal fun Program.simplifyExpressions() : Int {
val opti = ExpressionSimplifier2(this)
opti.visit(this)
opti.applyModifications()
val optimizer = ExpressionSimplifier(this)
optimizer.visit(this)
return optimizer.optimizationsDone
return opti.optimizationsDone + optimizer.optimizationsDone
}

View File

@ -11,6 +11,9 @@ main {
ubyte key=c64.GETIN()
ubyte[] zzzz = [1,2,3]
A = 9.0 as ubyte as uword as ubyte as ubyte
A = Y as ubyte
A = @(&bb1)
A = len(meuk)
A = msb(meuk)
; A = strlen(meuk)
@ -26,6 +29,4 @@ main {
c64.CHROUT('\n')
c64.CHROUT('\n')
}
}
}