optimize various simple cases for '**' (pow) like 2**x => bitshift

This commit is contained in:
Irmen de Jong 2020-12-10 22:37:12 +01:00
parent 1d299c56e0
commit 83ceb0fde9
3 changed files with 43 additions and 6 deletions

View File

@ -7,6 +7,7 @@ import prog8.ast.expressions.*
import prog8.ast.processing.AstWalker import prog8.ast.processing.AstWalker
import prog8.ast.processing.IAstModification import prog8.ast.processing.IAstModification
import prog8.ast.statements.* import prog8.ast.statements.*
import kotlin.math.pow
internal class ConstantFoldingOptimizer(private val program: Program) : AstWalker() { internal class ConstantFoldingOptimizer(private val program: Program) : AstWalker() {
@ -97,6 +98,43 @@ internal class ConstantFoldingOptimizer(private val program: Program) : AstWalke
override fun after(expr: BinaryExpression, parent: Node): Iterable<IAstModification> { override fun after(expr: BinaryExpression, parent: Node): Iterable<IAstModification> {
val leftconst = expr.left.constValue(program) val leftconst = expr.left.constValue(program)
val rightconst = expr.right.constValue(program) val rightconst = expr.right.constValue(program)
val modifications = mutableListOf<IAstModification>()
if(expr.operator == "**" && leftconst!=null) {
// optimize various simple cases of ** :
// optimize away 1 ** x into just 1 and 0 ** x into just 0
// optimize 2 ** x into (1<<x) if both operands are integer.
val leftDt = leftconst.inferType(program).typeOrElse(DataType.STRUCT)
when (leftconst.number.toDouble()) {
0.0 -> {
val value = NumericLiteralValue(leftDt, 0, expr.position)
modifications += IAstModification.ReplaceNode(expr, value, parent)
}
1.0 -> {
val value = NumericLiteralValue(leftDt, 1, expr.position)
modifications += IAstModification.ReplaceNode(expr, value, parent)
}
2.0 -> {
if(rightconst!=null) {
val value = NumericLiteralValue(leftDt, 2.0.pow(rightconst.number.toDouble()), expr.position)
modifications += IAstModification.ReplaceNode(expr, value, parent)
} else {
val rightDt = expr.right.inferType(program).typeOrElse(DataType.STRUCT)
if(leftDt in IntegerDatatypes && rightDt in IntegerDatatypes) {
val targetDt =
when (parent) {
is Assignment -> parent.target.inferType(program).typeOrElse(DataType.STRUCT)
is VarDecl -> parent.datatype
else -> leftDt
}
val one = NumericLiteralValue(targetDt, 1, expr.position)
val shift = BinaryExpression(one, "<<", expr.right, expr.position)
modifications += IAstModification.ReplaceNode(expr, shift, parent)
}
}
}
}
}
val subExpr: BinaryExpression? = when { val subExpr: BinaryExpression? = when {
leftconst!=null -> expr.right as? BinaryExpression leftconst!=null -> expr.right as? BinaryExpression
@ -111,7 +149,8 @@ internal class ConstantFoldingOptimizer(private val program: Program) : AstWalke
val change = groupTwoConstsTogether(expr, subExpr, val change = groupTwoConstsTogether(expr, subExpr,
leftconst != null, rightconst != null, leftconst != null, rightconst != null,
subleftconst != null, subrightconst != null) subleftconst != null, subrightconst != null)
return change?.let { listOf(it) } ?: noModifications if(change!=null)
modifications += change
} }
} }
@ -119,10 +158,10 @@ internal class ConstantFoldingOptimizer(private val program: Program) : AstWalke
if(leftconst != null && rightconst != null) { if(leftconst != null && rightconst != null) {
val evaluator = ConstExprEvaluator() val evaluator = ConstExprEvaluator()
val result = evaluator.evaluate(leftconst, expr.operator, rightconst) val result = evaluator.evaluate(leftconst, expr.operator, rightconst)
return listOf(IAstModification.ReplaceNode(expr, result, parent)) modifications += IAstModification.ReplaceNode(expr, result, parent)
} }
return noModifications return modifications
} }
override fun after(array: ArrayLiteralValue, parent: Node): Iterable<IAstModification> { override fun after(array: ArrayLiteralValue, parent: Node): Iterable<IAstModification> {

View File

@ -2,8 +2,6 @@
TODO TODO
==== ====
- optimize away 1 ** x into just 1 and 0 ** x into just 0
- optimize 2 ** x into (1<<x) if x is an integer. where 1 is in the type of the assign target if possible
- add minv(a,b) and maxv(a,b) functions to determine the max or min of 2 values - add minv(a,b) and maxv(a,b) functions to determine the max or min of 2 values
- add progend() builtin function that returns the last address of the program in memory + 1 (to be able to stick dynamic data after the program easily) - add progend() builtin function that returns the last address of the program in memory + 1 (to be able to stick dynamic data after the program easily)
- see if we can group some errors together for instance the (now single) errors about unidentified symbols - see if we can group some errors together for instance the (now single) errors about unidentified symbols

View File

@ -75,7 +75,7 @@ main {
uword width = mkword(buffer[4], buffer[3]) uword width = mkword(buffer[4], buffer[3])
uword height = mkword(buffer[6], buffer[5]) uword height = mkword(buffer[6], buffer[5])
ubyte bpp = buffer[7] ubyte bpp = buffer[7]
uword num_colors = $0001 << bpp uword num_colors = 2 ** bpp
ubyte flags = buffer[8] ubyte flags = buffer[8]
ubyte compression = flags & %00000011 ubyte compression = flags & %00000011
ubyte palette_format = (flags & %00000100) >> 2 ubyte palette_format = (flags & %00000100) >> 2