
288 lines
13 KiB
Raw Normal View History

2021-02-07 05:55:56 +00:00
package prog8.compiler.astprocessing
2020-07-03 20:09:44 +00:00
import prog8.ast.IFunctionCall
import prog8.ast.IStatementContainer
import prog8.ast.Node
import prog8.ast.Program
import prog8.ast.base.FatalAstException
import prog8.ast.expressions.*
import prog8.ast.statements.*
2021-02-07 05:49:51 +00:00
import prog8.ast.walk.AstWalker
import prog8.ast.walk.IAstModification
2022-04-10 11:24:17 +00:00
import prog8.code.core.*
2020-07-03 20:09:44 +00:00
2022-01-06 21:45:36 +00:00
/**
 * An [AstWalker] holding a collection of small, independent AST cleanups.
 *
 * Each visitor override below inspects one kind of node and returns a list of
 * [IAstModification]s (or `noModifications`) describing what should change.
 *
 * @property program the program being processed; handed to const-value evaluation and type inference
 * @property errors reporter used by the cleanups to emit diagnostics
 * @property options compilation options; `strictBool` is consulted by the `not` rewrite
 */
internal class VariousCleanups(val program: Program, val errors: IErrorReporter, val options: CompilationOptions): AstWalker() {
2020-07-03 20:09:44 +00:00
/**
 * Runs after a [Block] has been visited.
 *
 * Determines which of the options "splitarrays", "no_symbol_prefixing" and "ignore_unused"
 * are present on the block's defining module but not on the block itself. When at least one
 * such option exists, returns an [IAstModification.InsertFirst] carrying a new `%option`
 * directive for the block; otherwise returns `noModifications`.
 *
 * NOTE(review): this listing looks truncated - closing braces are absent and the second
 * `Directive` argument is a bare lambda (presumably it was derived from `inheritOptions`).
 * Confirm against the complete file.
 */
override fun after(block: Block, parent: Node): Iterable<IAstModification> {
// infix calls chain left to right: (module options intersect the inheritable set) subtract the block's own options
val inheritOptions = block.definingModule.options() intersect setOf("splitarrays", "no_symbol_prefixing", "ignore_unused") subtract block.options()
if(inheritOptions.isNotEmpty()) {
// each directive argument is positioned at the block itself
val directive = Directive("%option",{ DirectiveArg(null, it, null, block.position) }, block.position)
return listOf(IAstModification.InsertFirst(directive, block))
return noModifications
/**
 * Runs after an [AnonymousScope] has been visited.
 *
 * When the parent is an [IStatementContainer], returns a single [ScopeFlatten]
 * modification for this scope and that container.
 *
 * NOTE(review): the `else` branch of the `return if` and the closing brace are not
 * visible in this listing.
 */
override fun after(scope: AnonymousScope, parent: Node): Iterable<IAstModification> {
return if(parent is IStatementContainer)
listOf(ScopeFlatten(scope, parent as IStatementContainer))
2020-07-03 20:09:44 +00:00
/**
 * Modification that copies the statements of [scope] into the container [into],
 * directly behind the position the scope occupies there, and re-parents them.
 */
private class ScopeFlatten(val scope: AnonymousScope, val into: IStatementContainer) : IAstModification {
2020-07-03 20:09:44 +00:00
override fun perform() {
// position of the scope among the container's statements; negative when it is not found
val idx = into.statements.indexOf(scope)
if(idx>=0) {
// insert the scope's statements right after the scope
into.statements.addAll(idx+1, scope.statements)
// the moved statements now have the container as their parent node
scope.statements.forEach { it.parent = into as Node }
// NOTE(review): nothing after this point is visible in this listing (no removal of the scope itself, no closing braces) - confirm against the complete file.
2020-07-03 20:09:44 +00:00
/**
 * Runs after a [TypecastExpression] has been visited and tries to get rid of the cast.
 *
 * The visible replacements are: the cast's constant value, the cast applied directly to a
 * [NumericLiteral] operand, and the bare operand when its inferred type already equals the
 * cast type or (for a cast sitting directly under an [Assignment]) the `targetDt`.
 * Falls through to `noModifications`.
 *
 * NOTE(review): this listing is truncated - closing braces and some guards are missing
 * (for instance, whatever condition protects the first `return`), and the `targetDt`
 * initializer is incomplete. The exact conditions must be confirmed against the complete file.
 */
override fun after(typecast: TypecastExpression, parent: Node): Iterable<IAstModification> {
val constValue = typecast.constValue(program)
// NOTE(review): presumably guarded by a null check on constValue that is not visible here
return listOf(IAstModification.ReplaceNode(typecast, constValue, parent))
2022-02-10 23:21:40 +00:00
if(typecast.expression is NumericLiteral) {
2024-02-04 23:53:18 +00:00
// cast the literal itself, passing along the target type and the cast's `implicit` flag
val value = (typecast.expression as NumericLiteral).cast(typecast.type, typecast.implicit)
2024-02-04 23:53:18 +00:00
return listOf(IAstModification.ReplaceNode(typecast, value.valueOrZero(), parent))
val sourceDt = typecast.expression.inferType(program)
// operand already has the requested type: the cast can be replaced by the operand
if(sourceDt istype typecast.type)
return listOf(IAstModification.ReplaceNode(typecast, typecast.expression, parent))
2021-11-12 22:23:51 +00:00
if(parent is Assignment) {
2024-01-11 23:27:39 +00:00
// NOTE(review): the expression that yields the target datatype is cut off in this listing; only the failure handler remains
val targetDt = { throw FatalAstException("invalid dt ${}") }
2021-11-12 22:23:51 +00:00
if(sourceDt istype targetDt) {
2022-06-29 20:19:44 +00:00
// we can get rid of this typecast because the type is already the target type
2021-11-12 22:23:51 +00:00
return listOf(IAstModification.ReplaceNode(typecast, typecast.expression, parent))
2021-11-12 01:17:37 +00:00
return noModifications
/**
 * Runs after an [Assignment] has been visited and removes redundant assignments.
 *
 * Two cases are visible: an assignment whose value is the same as what it is compared
 * against in the first check (described by the existing comment as "assignment to self"),
 * and an assignment that is immediately followed by another assignment with the same value.
 * In both cases an [IAstModification.Remove] for this assignment is returned; the parent is
 * cast to [IStatementContainer] for that. Otherwise `noModifications` is returned.
 *
 * NOTE(review): several expressions are incomplete in this listing (the left operand of the
 * first `isSameAs`, the body of the `try`, the fallback in the `catch`, and part of the
 * `nextAssign` condition), so the precise conditions cannot be confirmed from here.
 */
override fun after(assignment: Assignment, parent: Node): Iterable<IAstModification> {
if( isSameAs assignment.value) {
// remove assignment to self
return listOf(IAstModification.Remove(assignment, parent as IStatementContainer))
// remove duplicated assignments, but not if it's a memory mapped IO register
// NOTE(review): how `isIO` is computed is not visible; a FatalAstException during that computation is caught
val isIO = try {
} catch (_: FatalAstException) {
if(!isIO) {
// only an Assignment directly following this one is considered
val nextAssign = assignment.nextSibling() as? Assignment
if (nextAssign != null &&, program)) {
if (!nextAssign.isAugmentable && nextAssign.value isSameAs assignment.value && assignment.value !is IFunctionCall) // don't remove function calls even when they're duplicates
// it is the current (earlier) assignment that gets removed, not the following one
return listOf(IAstModification.Remove(assignment, parent as IStatementContainer))
return noModifications
/**
 * Runs after a [PrefixExpression] has been visited and simplifies two forms:
 * a unary `+` is replaced by its operand, and - only when `options.strictBool` is off -
 * `not` applied to an operand whose inferred type reports `isBytes` is replaced by an
 * `== 0` comparison. Otherwise returns `noModifications`.
 *
 * NOTE(review): closing braces and the tail of the `BinaryExpression(...)` call are not
 * visible in this listing.
 */
override fun after(expr: PrefixExpression, parent: Node): Iterable<IAstModification> {
if(expr.operator=="+") {
// +X --> X
return listOf(IAstModification.ReplaceNode(expr, expr.expression, parent))
if(!options.strictBool && expr.operator=="not") {
if(expr.expression.inferType(program).isBytes) {
// not bytevalue --> bytevalue==0
val cmp = BinaryExpression(expr.expression, "==",
// the zero literal gets the operand's inferred type, or DataType.UNDEFINED when that is unknown
NumericLiteral(expr.expression.inferType(program).getOr(DataType.UNDEFINED), 0.0, expr.expression.position),
return listOf(IAstModification.ReplaceNode(expr, cmp, parent))
return noModifications
/**
 * Runs before a [BinaryExpression] is visited and tries to collapse a chain of
 * equality tests on the same operand, joined by `|` or `or`, into one [ContainmentCheck]
 * against an [ArrayLiteral] of the compared numbers.
 *
 * Nested binary expressions (parent is itself a [BinaryExpression]) are skipped, so only the
 * outermost expression of such a chain is rewritten.
 *
 * NOTE(review): closing braces are missing from this listing, and the local helper ends with
 * two consecutive `return` statements, which suggests a conditional line was lost.
 * Confirm against the complete file.
 *
 * @throws FatalAstException when the type of the compared operand cannot be inferred
 */
override fun before(expr: BinaryExpression, parent: Node): Iterable<IAstModification> {
// try to replace a multi-comparison expression (if x==1 | x==2 | x==3 ... ) by a simple containment check.
// but only if the containment check is the top-level expression.
if(parent is BinaryExpression)
return noModifications
2022-07-11 10:22:14 +00:00
if(expr.operator == "|" || expr.operator=="or") {
val leftBinExpr1 = expr.left as? BinaryExpression
val rightBinExpr1 = expr.right as? BinaryExpression
// the rightmost term must look like `needle == <number>`; it fixes which operand is being tested
if(rightBinExpr1?.operator=="==" && rightBinExpr1.right is NumericLiteral && leftBinExpr1!=null) {
val needle = rightBinExpr1.left
// collects every literal the needle is compared with, starting with the rightmost one
val values = mutableListOf(rightBinExpr1.right as NumericLiteral)
// walks down the left-hand side of the chain, adding literals to `values` along the way
fun isMultiComparisonRecurse(expr: BinaryExpression): Boolean {
if(expr.operator=="==") {
// innermost term: must compare the same needle with a numeric literal
if(expr.right is NumericLiteral && expr.left isSameAs needle) {
values.add(expr.right as NumericLiteral)
return true
return false
2022-07-11 10:22:14 +00:00
if(expr.operator!="|" && expr.operator!="or")
return false
val leftBinExpr = expr.left as? BinaryExpression
val rightBinExpr = expr.right as? BinaryExpression
if(leftBinExpr==null || rightBinExpr==null || rightBinExpr.right !is NumericLiteral || !rightBinExpr.left.isSameAs(needle))
return false
values.add(rightBinExpr.right as NumericLiteral)
return false
return isMultiComparisonRecurse(leftBinExpr)
if(isMultiComparisonRecurse(leftBinExpr1)) {
// replace it!
// the array literal gets copies of the collected literals, ordered by numeric value
val valueCopies = values.sortedBy { it.number }.map { it.copy() }
val elementType = needle.inferType(program).getOrElse { throw FatalAstException("invalid needle dt") }
// array datatype looked up from the needle's element type
val arrayType = ElementToArrayTypes.getValue(elementType)
val valuesArray = ArrayLiteral(InferredTypes.InferredType.known(arrayType), valueCopies.toTypedArray(), expr.position)
val containment = ContainmentCheck(needle, valuesArray, expr.position)
return listOf(IAstModification.ReplaceNode(expr, containment, parent))
return noModifications
2021-12-27 14:04:25 +00:00
/**
 * Runs after a [BinaryExpression] has been visited and normalizes comparisons so that a
 * constant operand ends up on the right-hand side.
 *
 * Applies only to operators in `ComparisonOperators`, and only when the left operand has a
 * constant value while the right one has none. The operands are swapped and `<`, `<=`, `>`,
 * `>=` are mirrored; any other operator is kept unchanged. Otherwise returns `noModifications`.
 *
 * NOTE(review): closing braces are missing from this listing.
 */
override fun after(expr: BinaryExpression, parent: Node): Iterable<IAstModification> {
if(expr.operator in ComparisonOperators) {
val leftConstVal = expr.left.constValue(program)
val rightConstVal = expr.right.constValue(program)
// make sure the constant value is on the right of the comparison expression
if(rightConstVal==null && leftConstVal!=null) {
// swapping the operands requires mirroring the relational operator
val newOperator =
when(expr.operator) {
"<" -> ">"
"<=" -> ">="
">" -> "<"
">=" -> "<="
else -> expr.operator
// same position as the original expression, operands exchanged
val replacement = BinaryExpression(expr.right, newOperator, expr.left, expr.position)
return listOf(IAstModification.ReplaceNode(expr, replacement, parent))
return noModifications
/**
 * Runs after a [ContainmentCheck] has been visited and replaces trivial checks.
 *
 * Depending on the kind of `containment.iterable` ([ArrayLiteral], [RangeExpression] or
 * [StringLiteral]), a check is replaced by a constant false value (a UBYTE literal 0, with
 * the warning "condition is always false") or by a single `==` comparison between the
 * element and the only candidate value. Any other iterable kind yields `noModifications`.
 *
 * NOTE(review): this listing is truncated - closing braces are missing, the emptiness guards
 * in front of some `return replaceWithFalse()` lines are not visible, and the diagnostic call
 * in `replaceWithEquals` is fused with a timestamp. The `checkArray(variable: VarDecl)` helper
 * is not called anywhere in the visible part. Confirm against the complete file.
 */
override fun after(containment: ContainmentCheck, parent: Node): Iterable<IAstModification> {
// replace trivial containment checks with just false or a single comparison
2022-02-10 23:21:40 +00:00
// builds `element == value` at the containment's position and swaps it in for the check
fun replaceWithEquals(value: NumericLiteral): Iterable<IAstModification> {
2023-12-28 12:48:01 +00:00"containment could be written as just a single comparison", containment.position)
val equals = BinaryExpression(containment.element, "==", value, containment.position)
return listOf(IAstModification.ReplaceNode(containment, equals, parent))
// warns and swaps the check for a UBYTE literal 0
fun replaceWithFalse(): Iterable<IAstModification> {
errors.warn("condition is always false", containment.position)
2022-07-04 21:42:49 +00:00
return listOf(IAstModification.ReplaceNode(containment, NumericLiteral(DataType.UBYTE, 0.0, containment.position), parent))
// array of expressions: a single element with a constant value becomes an equality test
fun checkArray(array: Array<Expression>): Iterable<IAstModification> {
// NOTE(review): presumably guarded by an emptiness check that is not visible here
return replaceWithFalse()
if(array.size==1) {
val constVal = array[0].constValue(program)
return replaceWithEquals(constVal)
return noModifications
// array variable: without an initializer value the declared array size is examined, otherwise an ArrayLiteral value is delegated to the overload above
fun checkArray(variable: VarDecl): Iterable<IAstModification> {
return if(variable.value==null) {
val arraySpec = variable.arraysize!!
// a size that is not a constant number is treated as a fatal error
val size = arraySpec.indexExpr.constValue(program)?.number?.toInt() ?: throw FatalAstException("no array size")
return if(size==0)
else if(variable.value is ArrayLiteral) {
checkArray((variable.value as ArrayLiteral).value)
else noModifications
2022-02-10 23:21:40 +00:00
// string: a one-character string becomes an equality test against its first encoded value
fun checkString(stringVal: StringLiteral): Iterable<IAstModification> {
// NOTE(review): presumably guarded by an emptiness check that is not visible here
return replaceWithFalse()
if(stringVal.value.length==1) {
// encode with the string's own encoding, then use the first encoded element as a UBYTE literal
val string = program.encoding.encodeString(stringVal.value, stringVal.encoding)
2022-02-10 23:21:40 +00:00
return replaceWithEquals(NumericLiteral(DataType.UBYTE, string[0].toDouble(), stringVal.position))
return noModifications
when(containment.iterable) {
2022-02-10 23:21:40 +00:00
is ArrayLiteral -> {
val array = (containment.iterable as ArrayLiteral).value
return checkArray(array)
is RangeExpression -> {
// only ranges that can be turned into a constant integer range are simplified
val constValues = (containment.iterable as RangeExpression).toConstantIntegerRange()
if(constValues!=null) {
if (constValues.isEmpty())
return replaceWithFalse()
if (constValues.count()==1)
2022-02-10 23:21:40 +00:00
return replaceWithEquals(NumericLiteral.optimalNumeric(constValues.first, containment.position))
2022-02-10 23:21:40 +00:00
is StringLiteral -> {
val stringVal = containment.iterable as StringLiteral
return checkString(stringVal)
else -> {}
return noModifications
/**
 * Runs after a [ConditionalBranch] has been visited: when both its true part and its else
 * part are empty, returns an [IAstModification.Remove] for the branch (the parent is cast to
 * [IStatementContainer]); otherwise returns `noModifications`.
 *
 * NOTE(review): the call that reports "removing empty conditional branch" is fused with a
 * timestamp in this listing, and closing braces are missing.
 */
override fun after(branch: ConditionalBranch, parent: Node): Iterable<IAstModification> {
if(branch.truepart.isEmpty() && branch.elsepart.isEmpty()) {
2023-12-28 12:48:01 +00:00"removing empty conditional branch", branch.position)
return listOf(IAstModification.Remove(branch, parent as IStatementContainer))
return noModifications
/**
 * Runs after an [IfElse] has been visited: when both its true part and its else part are
 * empty, returns an [IAstModification.Remove] for the statement (the parent is cast to
 * [IStatementContainer]); otherwise returns `noModifications`.
 *
 * NOTE(review): the call that reports "removing empty if-else statement" is fused with a
 * timestamp in this listing, and closing braces are missing.
 */
override fun after(ifElse: IfElse, parent: Node): Iterable<IAstModification> {
if(ifElse.truepart.isEmpty() && ifElse.elsepart.isEmpty()) {
2023-12-28 12:48:01 +00:00"removing empty if-else statement", ifElse.position)
return listOf(IAstModification.Remove(ifElse, parent as IStatementContainer))
return noModifications
/**
 * Runs after an [ArrayIndexedExpression] has been visited and rewrites a negative constant
 * index into `arraysize + index`, provided the indexed variable's declaration is found and
 * its array size is a constant.
 *
 * The indexer's `indexExpr` is assigned in place; the returned list is always
 * `noModifications` in the visible code.
 *
 * NOTE(review): closing braces are missing from this listing, so further statements inside
 * the innermost `if` may have been lost as well.
 */
override fun after(arrayIndexedExpression: ArrayIndexedExpression, parent: Node): Iterable<IAstModification> {
val index = arrayIndexedExpression.indexer.constIndex()
if(index!=null && index<0) {
// the declared size is needed to translate the negative index
val target = arrayIndexedExpression.arrayvar.targetVarDecl(program)
val arraysize = target?.arraysize?.constIndex()
if(arraysize!=null) {
// replace the negative index by the normal index
val newIndex = NumericLiteral.optimalNumeric(arraysize+index, arrayIndexedExpression.indexer.position)
arrayIndexedExpression.indexer.indexExpr = newIndex
return noModifications
2020-07-03 20:09:44 +00:00