
647 lines
32 KiB
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package prog8.optimizer
import prog8.ast.Node
import prog8.ast.Program
import prog8.ast.expressions.*
import prog8.ast.maySwapOperandOrder
import prog8.ast.statements.ForLoop
import prog8.ast.statements.RepeatLoop
import prog8.ast.statements.VarDecl
import prog8.ast.statements.VarDeclType
import prog8.ast.walk.AstWalker
import prog8.ast.walk.IAstModification
import prog8.code.core.AssociativeOperators
import prog8.code.core.DataType
import prog8.code.core.IErrorReporter
import kotlin.math.floor
/**
 * AST walker that replaces expressions whose value is known at compile time by their constant result,
 * and reorders some expressions so that constants end up next to each other and can be folded.
 *
 * @param program the program being optimized; passed to `constValue`/`inferType` lookups
 * @param errors reporter used for warnings raised while folding
 */
class ConstantFoldingOptimizer(private val program: Program, private val errors: IErrorReporter) : AstWalker() {
// Evaluates an operator applied to two constant operands, and function calls with constant arguments.
private val evaluator = ConstExprEvaluator()
/**
 * Simplifies a direct memory read of an address-of expression: `@( &thing )` becomes `thing`.
 *
 * The rewrite is only applied when `thing` is of a byte type, as the original comment requires.
 *
 * @return a single replacement of [memread] by the referenced identifier, or no modifications
 */
override fun before(memread: DirectMemoryRead, parent: Node): Iterable<IAstModification> {
    // @( &thing )  -->  thing  (but only if thing is a byte type!)
    val addrOf = memread.addressExpression as? AddressOf ?: return noModifications
    val target = addrOf.identifier
    // The type guard was missing: without it a non-byte variable would replace the memory read.
    return if (target.inferType(program).isBytes)
        listOf(IAstModification.ReplaceNode(memread, target, parent))
    else
        noModifications
}
/**
 * Replaces a containment check by its constant value when that value can be determined.
 *
 * @return a single replacement of [containment], or no modifications when it is not constant
 */
override fun after(containment: ContainmentCheck, parent: Node): Iterable<IAstModification> {
    // constValue yields null for non-constant expressions; bail out rather than replacing with null.
    val result = containment.constValue(program) ?: return noModifications
    return listOf(IAstModification.ReplaceNode(containment, result, parent))
}
/**
 * Replaces a prefix expression by its constant value when that value can be determined.
 *
 * @return a single replacement of [expr], or no modifications when it is not constant
 */
override fun after(expr: PrefixExpression, parent: Node): Iterable<IAstModification> {
    val constValue = expr.constValue(program) ?: return noModifications
    return listOf(IAstModification.ReplaceNode(expr, constValue, parent))
}
/**
 * Try to constfold a binary expression.
 * Compile-time constant sub expressions will be evaluated on the spot.
 * For instance, "9 * (4 + 2)" will be optimized into the integer literal 54.
 *
 * More complex stuff: reordering to group constants:
 * If one of our operands is a Constant,
 *   and the other operand is a Binary expression,
 *   and one of ITS operands is a Constant,
 *   and ITS other operand is NOT a Constant,
 *   then it may be possible to rewrite the expression to group the two Constants together,
 *   to allow them to be const-folded away.
 *
 * Examples include:
 *   (X / c1) * c2  ->  X * (c2/c1)    (floating point only)
 *   (X + c1) - c2  ->  X + (c1-c2)
 */
/**
 * Constant-folds and regroups a binary expression.
 *
 * Steps, in order (several of them return immediately with their own modification list):
 *  1. string literal concatenation (`+`) and repetition (`*`),
 *  2. moving a constant across `==`  (`X + C1 == C2`  ->  `X == C2 - C1`),
 *  3. regrouping of float constants via [groupTwoFloatConstsTogether],
 *  4. evaluation when both operands are constant,
 *  5. `X - -C` / `X + -C` sign normalisation,
 *  6. regrouping `(X op C2) op C` so the two constants sit together,
 *  7. regrouping `(X op C1) +/- (Y op C2)`,
 *  8. shifts by at least the operand's bit width.
 *
 * @return the modifications to apply; may be empty
 */
override fun after(expr: BinaryExpression, parent: Node): Iterable<IAstModification> {
    val modifications = mutableListOf<IAstModification>()
    val leftconst = expr.left.constValue(program)
    val rightconst = expr.right.constValue(program)

    if(expr.left.inferType(program) istype DataType.STR) {
        if(expr.operator=="+" && expr.left is StringLiteral && expr.right is StringLiteral) {
            // concatenate two strings.
            val leftString = expr.left as StringLiteral
            val rightString = expr.right as StringLiteral
            val concatenated = if(leftString.encoding==rightString.encoding) {
                leftString.value + rightString.value
            } else {
                // NOTE(review): the wrapping call was lost from the source (only its first argument,
                // with a trailing comma, survived). The encoded byte sequences are presumably decoded
                // back into a string using the left operand's encoding -- confirm against the encoder API.
                program.encoding.decodeString(
                    program.encoding.encodeString(leftString.value, leftString.encoding) + program.encoding.encodeString(rightString.value, rightString.encoding),
                    leftString.encoding)
            }
            val concatStr = StringLiteral(concatenated, leftString.encoding, expr.position)
            return listOf(IAstModification.ReplaceNode(expr, concatStr, parent))
        }
        else if(expr.operator=="*" && rightconst!=null && expr.left is StringLiteral) {
            // multiply a string.
            val part = expr.left as StringLiteral
            val repeated = part.value.repeat(rightconst.number.toInt())
            // Only warn when the result really is empty (the warning used to be unconditional).
            if(repeated.isEmpty())
                errors.warn("resulting string has length zero", part.position)
            val newStr = StringLiteral(repeated, part.encoding, expr.position)
            return listOf(IAstModification.ReplaceNode(expr, newStr, parent))
        }
    }

    if(expr.operator=="==" && rightconst!=null) {
        val leftExpr = expr.left as? BinaryExpression
        // only do this shuffling when the LHS is not a constant itself (otherwise problematic nested replacements)
        // NOTE(review): the condition below holds only when the LHS *does* fold to a constant, which reads
        // as the opposite of the comment above. Left unchanged here -- confirm the intended polarity.
        if(leftExpr?.constValue(program) != null) {
            val leftRightConst = leftExpr.right.constValue(program)
            if(leftRightConst!=null) {
                when (leftExpr.operator) {
                    "+" -> {
                        // X + C1 == C2  -->  X == C2 - C1
                        val newRightConst = NumericLiteral(rightconst.type, rightconst.number - leftRightConst.number, rightconst.position)
                        return listOf(
                            IAstModification.ReplaceNode(leftExpr, leftExpr.left, expr),
                            IAstModification.ReplaceNode(expr.right, newRightConst, expr)
                        )
                    }
                    "-" -> {
                        // X - C1 == C2  -->  X == C2 + C1
                        val newRightConst = NumericLiteral(rightconst.type, rightconst.number + leftRightConst.number, rightconst.position)
                        return listOf(
                            IAstModification.ReplaceNode(leftExpr, leftExpr.left, expr),
                            IAstModification.ReplaceNode(expr.right, newRightConst, expr)
                        )
                    }
                }
            }
        }
    }

    if(expr.inferType(program) istype DataType.FLOAT) {
        val subExpr: BinaryExpression? = when {
            leftconst != null -> expr.right as? BinaryExpression
            rightconst != null -> expr.left as? BinaryExpression
            else -> null
        }
        if (subExpr != null) {
            val subleftconst = subExpr.left.constValue(program)
            val subrightconst = subExpr.right.constValue(program)
            // exactly one of the sub-expression's operands must be constant
            if ((subleftconst != null && subrightconst == null) || (subleftconst == null && subrightconst != null)) {
                // try reordering.
                val change = groupTwoFloatConstsTogether(
                    expr, subExpr,
                    leftconst != null, rightconst != null,
                    subleftconst != null, subrightconst != null
                )
                if (change != null)
                    modifications += change
            }
        }
    }

    // const fold when both operands are a const.
    // if in a chained comparison, that one has to be desugared first though.
    if(leftconst != null && rightconst != null) {
        if((expr.parent as? BinaryExpression)?.isChainedComparison()!=true) {
            val result = evaluator.evaluate(leftconst, expr.operator, rightconst)
            modifications += IAstModification.ReplaceNode(expr, result, parent)
        }
    }

    if(leftconst==null && rightconst!=null && rightconst.number<0.0) {
        if (expr.operator == "-") {
            // X - -1 ---> X + 1
            val posNumber = NumericLiteral.optimalNumeric(-rightconst.number, rightconst.position)
            val plusExpr = BinaryExpression(expr.left, "+", posNumber, expr.position)
            return listOf(IAstModification.ReplaceNode(expr, plusExpr, parent))
        }
        else if (expr.operator == "+") {
            // X + -1 ---> X - 1
            val posNumber = NumericLiteral.optimalNumeric(-rightconst.number, rightconst.position)
            val minusExpr = BinaryExpression(expr.left, "-", posNumber, expr.position)
            return listOf(IAstModification.ReplaceNode(expr, minusExpr, parent))
        }
    }

    val leftBinExpr = expr.left as? BinaryExpression
    val rightBinExpr = expr.right as? BinaryExpression

    if(leftBinExpr!=null && rightconst!=null) {
        if(expr.operator=="+" || expr.operator=="-") {
            if(leftBinExpr.operator in listOf("+", "-")) {
                val c2 = leftBinExpr.right.constValue(program)
                if(c2!=null) {
                    // (X + C2) +/- rightConst  -->  X + (C2 +/- rightConst)
                    // (X - C2) +/- rightConst  -->  X - (C2 +/- rightConst)   mind the flipped right operator
                    val operator = if(leftBinExpr.operator=="+") expr.operator else if(expr.operator=="-") "+" else "-"
                    val constants = BinaryExpression(c2, operator, rightconst, c2.position)
                    val newExpr = BinaryExpression(leftBinExpr.left, leftBinExpr.operator, constants, expr.position)
                    return listOf(IAstModification.ReplaceNode(expr, newExpr, parent))
                }
            }
        }
        else if(expr.operator=="*") {
            val c2 = leftBinExpr.right.constValue(program)
            if(c2!=null) {
                if(leftBinExpr.operator=="*") {
                    // (X * C2) * rightConst  -->  X * (rightConst*C2)
                    val constants = BinaryExpression(rightconst, "*", c2, c2.position)
                    val newExpr = BinaryExpression(leftBinExpr.left, "*", constants, expr.position)
                    return listOf(IAstModification.ReplaceNode(expr, newExpr, parent))
                } else if (leftBinExpr.operator=="/") {
                    if(expr.inferType(program).istype(DataType.FLOAT)) {
                        // (X / C2) * rightConst  -->  X * (rightConst/C2)   only valid for floating point
                        val constants = BinaryExpression(rightconst, "/", c2, c2.position)
                        val newExpr = BinaryExpression(leftBinExpr.left, "*", constants, expr.position)
                        return listOf(IAstModification.ReplaceNode(expr, newExpr, parent))
                    }
                }
            }
        }
        else if(expr.operator=="/") {
            val c2 = leftBinExpr.right.constValue(program)
            if(c2!=null && leftBinExpr.operator=="/") {
                // (X / C1) / C2  -->  X / (C1*C2)
                // NOTE: do not optimize (X * C1) / C2 for integers, --> X * (C1/C2) because this causes precision loss on integers
                val constants = BinaryExpression(c2, "*", rightconst, c2.position)
                val newExpr = BinaryExpression(leftBinExpr.left, "/", constants, expr.position)
                return listOf(IAstModification.ReplaceNode(expr, newExpr, parent))
            }
        }
    }

    if(expr.operator=="+" || expr.operator=="-") {
        if(leftBinExpr!=null && rightBinExpr!=null) {
            val c1 = leftBinExpr.right.constValue(program)
            val c2 = rightBinExpr.right.constValue(program)
            if(leftBinExpr.operator=="+" && rightBinExpr.operator=="+") {
                if (c1 != null && c2 != null) {
                    // (X + C1) <plusmin> (Y + C2)  =>  (X <plusmin> Y) + (C1 <plusmin> C2)
                    val c3 = evaluator.evaluate(c1, expr.operator, c2)
                    val xwithy = BinaryExpression(leftBinExpr.left, expr.operator, rightBinExpr.left, expr.position)
                    val newExpr = BinaryExpression(xwithy, "+", c3, expr.position)
                    modifications += IAstModification.ReplaceNode(expr, newExpr, parent)
                }
            }
            else if(leftBinExpr.operator=="-" && rightBinExpr.operator=="-") {
                if (c1 != null && c2 != null) {
                    // (X - C1) <plusmin> (Y - C2)  =>  (X <plusmin> Y) - (C1 <plusmin> C2)
                    val c3 = evaluator.evaluate(c1, expr.operator, c2)
                    val xwithy = BinaryExpression(leftBinExpr.left, expr.operator, rightBinExpr.left, expr.position)
                    val newExpr = BinaryExpression(xwithy, "-", c3, expr.position)
                    modifications += IAstModification.ReplaceNode(expr, newExpr, parent)
                }
            }
            else if(leftBinExpr.operator=="*" && rightBinExpr.operator=="*") {
                if (c1 != null && c2 != null && c1==c2) {
                    // (X * C) <plusmin> (Y * C)  =>  (X <plusmin> Y) * C
                    val xwithy = BinaryExpression(leftBinExpr.left, expr.operator, rightBinExpr.left, expr.position)
                    val newExpr = BinaryExpression(xwithy, "*", c1, expr.position)
                    modifications += IAstModification.ReplaceNode(expr, newExpr, parent)
                }
            }
        }
    }

    if(rightconst!=null && (expr.operator=="<<" || expr.operator==">>")) {
        val dt = expr.left.inferType(program)

        // Queues the replacement for a shift by at least the operand's bit width.
        // Unsigned operands always become 0. Signed operands are only folded when the left side is
        // a known constant: a negative value shifted right becomes -1, every other case becomes 0.
        // (Previously both the 0 and the -1 replacement were queued for the same node.)
        fun foldOversizedShift(unsignedDt: DataType, signedDt: DataType) {
            val replacement = when {
                dt.istype(unsignedDt) -> NumericLiteral(unsignedDt, 0.0, expr.position)
                leftconst == null -> null
                leftconst.number < 0.0 && expr.operator == ">>" -> NumericLiteral(signedDt, -1.0, expr.position)
                else -> NumericLiteral(signedDt, 0.0, expr.position)
            }
            if(replacement != null)
                modifications.add(IAstModification.ReplaceNode(expr, replacement, parent))
        }

        if(dt.isBytes && rightconst.number>=8)
            foldOversizedShift(DataType.UBYTE, DataType.BYTE)
        else if(dt.isWords && rightconst.number>=16)
            foldOversizedShift(DataType.UWORD, DataType.WORD)
    }

    return modifications
}
/**
 * Recalculates the datatype of an array literal after constant folding.
 *
 * If the literal is the direct child of a variable declaration, it is cast to that declaration's
 * datatype; otherwise the type is guessed from the elements.
 *
 * @return a single replacement of [array] by the cast array, or no modifications
 */
override fun after(array: ArrayLiteral, parent: Node): Iterable<IAstModification> {
    // because constant folding can result in arrays that are now suddenly capable
    // of telling the type of all their elements (for instance, when they contained -2 which
    // was a prefix expression earlier), we recalculate the array's datatype.
    // NOTE(review): the guard on this early return was lost from the source, leaving an unconditional
    // return that made everything below unreachable. It presumably skipped arrays whose type is
    // already known -- confirm that `array.type` is the right property.
    if(array.type.isKnown)
        return noModifications

    // if the array literalvalue is inside an array vardecl, take the type from that
    // otherwise infer it from the elements of the array
    val vardeclType = (array.parent as? VarDecl)?.datatype
    if(vardeclType!=null) {
        val newArray = array.cast(vardeclType)
        if (newArray != null && newArray != array)
            return listOf(IAstModification.ReplaceNode(array, newArray, parent))
    } else {
        val arrayDt = array.guessDatatype(program)
        if (arrayDt.isKnown) {
            val newArray = array.cast(arrayDt.getOr(DataType.UNDEFINED))
            if (newArray != null && newArray != array)
                return listOf(IAstModification.ReplaceNode(array, newArray, parent))
        }
    }

    return noModifications
}
/**
 * Replaces a function call expression by a constant when one can be computed.
 *
 * The call's own `constValue` is tried first; if that yields nothing, the constant expression
 * evaluator is asked to evaluate the call.
 *
 * @return a single replacement of [functionCallExpr], or no modifications
 */
override fun after(functionCallExpr: FunctionCallExpression, parent: Node): Iterable<IAstModification> {
    val folded = functionCallExpr.constValue(program)
        ?: evaluator.evaluate(functionCallExpr, program)
        ?: return noModifications   // neither route produced a constant
    return listOf(IAstModification.ReplaceNode(functionCallExpr, folded, parent))
}
/**
 * Replaces a builtin function call by its constant value when that value can be determined.
 *
 * @return a single replacement of [bfc], or no modifications when it is not constant
 */
override fun after(bfc: BuiltinFunctionCall, parent: Node): Iterable<IAstModification> {
    val constvalue = bfc.constValue(program) ?: return noModifications
    return listOf(IAstModification.ReplaceNode(bfc, constvalue, parent))
}
/**
 * Adjusts the datatype of a literal range in a for loop to the datatype of the loop variable.
 *
 * Only ranges whose `from` and `to` are numeric literals are considered, and only for
 * byte and word loop variables. Nothing is changed when a bound cannot be cast to the target type.
 *
 * @return a single replacement of the loop's iterable, or no modifications
 */
override fun after(forLoop: ForLoop, parent: Node): Iterable<IAstModification> {

    // Returns a copy of [range] with its bounds (and step, when it casts cleanly) converted to
    // [targetDt], or null when either bound cannot be represented in that type.
    fun adjustRangeDt(rangeFrom: NumericLiteral, targetDt: DataType, rangeTo: NumericLiteral, stepLiteral: NumericLiteral?, range: RangeExpression): RangeExpression? {
        val fromCast = rangeFrom.cast(targetDt)
        val toCast = rangeTo.cast(targetDt)
        if(!fromCast.isValid || !toCast.isValid)
            return null

        // NOTE(review): both branch values were lost from the source; keeping the original step
        // whenever there is no literal step or it does not cast cleanly -- confirm.
        val newStep =
            if(stepLiteral!=null) {
                val stepCast = stepLiteral.cast(targetDt)
                if(stepCast.isValid) stepCast.valueOrZero() else range.step
            } else {
                range.step
            }

        return RangeExpression(fromCast.valueOrZero(), toCast.valueOrZero(), newStep, range.position)
    }

    // adjust the datatype of a range expression in for loops to the loop variable.
    val iterableRange = forLoop.iterable as? RangeExpression ?: return noModifications
    val rangeFrom = iterableRange.from as? NumericLiteral
    val rangeTo = iterableRange.to as? NumericLiteral
    if(rangeFrom==null || rangeTo==null) return noModifications

    val loopvar = forLoop.loopVar.targetVarDecl(program) ?: return noModifications
    val stepLiteral = iterableRange.step as? NumericLiteral
    val loopDt = loopvar.datatype
    when(loopDt) {
        DataType.UBYTE, DataType.BYTE, DataType.UWORD, DataType.WORD -> {
            if(rangeFrom.type != loopDt) {
                // attempt to translate the iterable into values of the loop variable's type
                val newIter = adjustRangeDt(rangeFrom, loopDt, rangeTo, stepLiteral, iterableRange)
                // adjustRangeDt returns null when the bounds don't fit; leave the range alone then.
                if(newIter != null)
                    return listOf(IAstModification.ReplaceNode(forLoop.iterable, newIter, forLoop))
            }
        }
        else -> { /* nothing for floats, these are not allowed in for loops and will give an error elsewhere */ }
    }

    return noModifications
}
/**
 * Casts the numeric literal value of a `const` declaration to the declared datatype
 * when the literal's inferred type differs from it.
 *
 * The cast is skipped for a BOOL declaration whose value is a UBYTE, and when the cast is not valid.
 *
 * @return a single replacement of the literal value, or no modifications
 */
override fun after(decl: VarDecl, parent: Node): Iterable<IAstModification> {
    val numval = decl.value as? NumericLiteral
    if(decl.type == VarDeclType.CONST && numval != null) {
        val valueDt = numval.inferType(program)
        val typeDiffers = valueDt isnot decl.datatype
        if(typeDiffers && (decl.datatype != DataType.BOOL || valueDt.isnot(DataType.UBYTE))) {
            val cast = numval.cast(decl.datatype)
            if(cast.isValid)
                return listOf(IAstModification.ReplaceNode(numval, cast.valueOrZero(), decl))
        }
    }
    return noModifications
}
/**
 * Truncates a fractional literal iteration count of a repeat loop to an integer literal.
 *
 * The loop node is updated in place; no modification objects are returned.
 */
override fun after(repeatLoop: RepeatLoop, parent: Node): Iterable<IAstModification> {
    val count = (repeatLoop.iterations as? NumericLiteral)?.number
    // floor(count) != count  <=>  the count has a fractional part
    if(count != null && floor(count) != count) {
        // NOTE(review): the new literal's parent link is not set here -- confirm whether the
        // surrounding framework re-links parents after direct assignments like this one.
        repeatLoop.iterations = NumericLiteral.optimalInteger(count.toInt(), repeatLoop.position)
    }
    return noModifications
}
/**
 * In-place modification that rewires the operands (and optionally the operator) of a binary
 * expression and one of its binary sub-expressions.
 *
 * Every `new*` argument that is null leaves the corresponding field untouched.
 * Updates are applied in this order: operator, outer left, outer right, inner left, inner right.
 */
private class ShuffleOperands(val expr: BinaryExpression,
                              val exprOperator: String?,
                              val subExpr: BinaryExpression,
                              val newExprLeft: Expression?,
                              val newExprRight: Expression?,
                              val newSubexprLeft: Expression?,
                              val newSubexprRight: Expression?
): IAstModification {
    override fun perform() {
        exprOperator?.let { expr.operator = it }
        newExprLeft?.let { expr.left = it }
        newExprRight?.let { expr.right = it }
        newSubexprLeft?.let { subExpr.left = it }
        newSubexprRight?.let { subExpr.right = it }
    }
}
/**
 * Tries to rewrite `expr` (which has one constant operand and one binary sub-expression operand,
 * the latter having exactly one constant operand itself) so that the two constants end up in the
 * same sub-expression, where they can later be folded.
 *
 * @param expr the outer expression
 * @param subExpr the non-constant operand of [expr]
 * @param leftIsConst true when the constant operand of [expr] is its left one
 * @param rightIsConst true when the constant operand of [expr] is its right one (not read here)
 * @param subleftIsConst true when the constant operand of [subExpr] is its left one
 * @param subrightIsConst true when the constant operand of [subExpr] is its right one (not read here)
 * @return the modification to apply, or null when no supported regrouping matches
 */
private fun groupTwoFloatConstsTogether(expr: BinaryExpression,
                                        subExpr: BinaryExpression,
                                        leftIsConst: Boolean,
                                        rightIsConst: Boolean,
                                        subleftIsConst: Boolean,
                                        subrightIsConst: Boolean): IAstModification?
{
    // TODO: this implements only a small set of possible reorderings at this time, we could perhaps add more

    // Builds `(a innerOp b) outerOp v` and returns the modification replacing expr by it.
    fun regroup(a: Expression, innerOp: String, b: Expression, outerOp: String, v: Expression): IAstModification =
        IAstModification.ReplaceNode(
            expr,
            BinaryExpression(BinaryExpression(a, innerOp, b, subExpr.position), outerOp, v, expr.position),
            expr.parent)

    if(expr.operator==subExpr.operator) {
        // both operators are the same.

        // If associative, we can simply shuffle the const operands around to optimize.
        if(expr.operator in AssociativeOperators && maySwapOperandOrder(expr)) {
            return if(leftIsConst) {
                if(subleftIsConst)
                    ShuffleOperands(expr, null, subExpr, subExpr.right, null, null, expr.left)
                else
                    ShuffleOperands(expr, null, subExpr, subExpr.left, null, expr.left, null)
            } else {
                if(subleftIsConst)
                    ShuffleOperands(expr, null, subExpr, null, subExpr.right, null, expr.right)
                else
                    ShuffleOperands(expr, null, subExpr, null, subExpr.left, expr.right, null)
            }
        }

        // If - or /, we sometimes must reorder more, and flip operators (- -> +, / -> *)
        if(expr.operator=="-" || expr.operator=="/") {
            val flipped = if (expr.operator == "-") "+" else "*"
            return if(leftIsConst) {
                if (subleftIsConst) {
                    // C1 op (C2 op V)  ->  V flipped (C1 op C2)
                    ShuffleOperands(expr, flipped, subExpr, subExpr.right, null, expr.left, subExpr.left)
                } else {
                    // C1 op (V op C2)  ->  (C1 flipped C2) op V
                    regroup(expr.left, flipped, subExpr.right, expr.operator, subExpr.left)
                }
            } else {
                if(subleftIsConst) {
                    // (C1 op V) op C2  ->  (C1 op C2) op V
                    ShuffleOperands(expr, null, subExpr, null, subExpr.right, null, expr.right)
                } else {
                    // (V op C1) op C2  ->  V op (C2 flipped C1)
                    IAstModification.ReplaceNode(
                        expr,
                        BinaryExpression(
                            subExpr.left, expr.operator,
                            BinaryExpression(expr.right, flipped, subExpr.right, subExpr.position),
                            expr.position),
                        expr.parent)
                }
            }
        }
        return null
    }

    if(expr.operator=="/" && subExpr.operator=="*") {
        return if(leftIsConst) {
            if(subleftIsConst)
                regroup(expr.left, "/", subExpr.left, "/", subExpr.right)     // C1/(C2*V) -> (C1/C2)/V
            else
                regroup(expr.left, "/", subExpr.right, "/", subExpr.left)     // C1/(V*C2) -> (C1/C2)/V
        } else {
            if(subleftIsConst)
                regroup(subExpr.left, "/", expr.right, "*", subExpr.right)    // (C1*V)/C2 -> (C1/C2)*V
            else
                regroup(subExpr.right, "/", expr.right, "*", subExpr.left)    // (V*C1)/C2 -> (C1/C2)*V
        }
    }
    else if(expr.operator=="*" && subExpr.operator=="/" && subExpr.inferType(program).istype(DataType.FLOAT)) {
        // division optimizations only valid for floats
        return if(leftIsConst) {
            if(subleftIsConst)
                regroup(expr.left, "*", subExpr.left, "/", subExpr.right)     // C1*(C2/V) -> (C1*C2)/V
            else
                regroup(expr.left, "/", subExpr.right, "*", subExpr.left)     // C1*(V/C2) -> (C1/C2)*V
        } else {
            if(subleftIsConst)
                regroup(subExpr.left, "*", expr.right, "/", subExpr.right)    // (C1/V)*C2 -> (C1*C2)/V
            else
                regroup(expr.right, "/", subExpr.right, "*", subExpr.left)    // (V/C1)*C2 -> (C2/C1)*V
        }
    }
    else if(expr.operator=="+" && subExpr.operator=="-") {
        return if(leftIsConst) {
            if(subleftIsConst)
                regroup(expr.left, "+", subExpr.left, "-", subExpr.right)     // c1+(c2-v) -> (c1+c2)-v
            else
                regroup(expr.left, "-", subExpr.right, "+", subExpr.left)     // c1+(v-c2) -> (c1-c2)+v
        } else {
            if(subleftIsConst)
                regroup(subExpr.left, "+", expr.right, "-", subExpr.right)    // (c1-v)+c2 -> (c1+c2)-v
            else
                regroup(expr.right, "-", subExpr.right, "+", subExpr.left)    // (v-c1)+c2 -> (c2-c1)+v
        }
    }
    else if(expr.operator=="-" && subExpr.operator=="+") {
        return if(leftIsConst) {
            if(subleftIsConst)
                regroup(expr.left, "-", subExpr.left, "-", subExpr.right)     // c1-(c2+v) -> (c1-c2)-v
            else
                regroup(expr.left, "-", subExpr.right, "-", subExpr.left)     // c1-(v+c2) -> (c1-c2)-v
        } else {
            if(subleftIsConst)
                regroup(subExpr.left, "-", expr.right, "+", subExpr.right)    // (c1+v)-c2 -> (c1-c2)+v
            else
                regroup(subExpr.right, "-", expr.right, "+", subExpr.left)    // (v+c1)-c2 -> (c1-c2)+v
        }
    }

    return null
}