partly optimize assignments so that simple increments and decrements can be done via separate statements (postincrdecr)

This commit is contained in:
Irmen de Jong 2020-07-26 18:56:51 +02:00
parent 8a3189123a
commit d32a970101
7 changed files with 104 additions and 200 deletions

View File

@ -287,9 +287,16 @@ class AstToSourceCode(val output: (text: String) -> Unit, val program: Program):
}
override fun visit(assignment: Assignment) {
assignment.target.accept(this)
output(" = ")
assignment.value.accept(this)
val binExpr = assignment.value as? BinaryExpression
if(binExpr!=null && binExpr.left isSameAs assignment.target) {
assignment.target.accept(this)
output(" ${binExpr.operator}= ")
binExpr.right.accept(this)
} else {
assignment.target.accept(this)
output(" = ")
assignment.value.accept(this)
}
}
override fun visit(postIncrDecr: PostIncrDecr) {

View File

@ -5,7 +5,6 @@ import prog8.ast.Program
import prog8.ast.processing.*
import prog8.compiler.CompilationOptions
import prog8.compiler.BeforeAsmGenerationAstChanger
import prog8.optimizer.AssignmentTransformer
internal fun Program.checkValid(compilerOptions: CompilationOptions, errors: ErrorReporter) {
@ -36,17 +35,6 @@ internal fun Program.verifyFunctionArgTypes() {
fixer.visit(this)
}
internal fun Program.transformAssignments(errors: ErrorReporter) {
val transform = AssignmentTransformer(this, errors)
transform.visit(this)
while(transform.optimizationsDone>0 && errors.isEmpty()) {
transform.applyModifications()
transform.optimizationsDone = 0
transform.visit(this)
}
transform.applyModifications()
}
internal fun Module.checkImportedValid() {
val imr = ImportedModuleDirectiveRemover()
imr.visit(this, this.parent)

View File

@ -175,8 +175,6 @@ private fun optimizeAst(programAst: Program, errors: ErrorReporter) {
}
private fun postprocessAst(programAst: Program, errors: ErrorReporter, compilerOptions: CompilationOptions) {
programAst.transformAssignments(errors)
errors.handle()
programAst.addTypecasts(errors)
errors.handle()
programAst.variousCleanups()

View File

@ -1,152 +0,0 @@
package prog8.optimizer
import prog8.ast.Node
import prog8.ast.Program
import prog8.ast.base.ErrorReporter
import prog8.ast.expressions.BinaryExpression
import prog8.ast.processing.AstWalker
import prog8.ast.processing.IAstModification
import prog8.ast.statements.Assignment
import prog8.ast.statements.PostIncrDecr
// TODO integrate this in the StatementOptimizer
internal class AssignmentTransformer(val program: Program, val errors: ErrorReporter) : AstWalker() {
var optimizationsDone: Int = 0
private val noModifications = emptyList<IAstModification>()
override fun before(assignment: Assignment, parent: Node): Iterable<IAstModification> {
if(assignment.target isSameAs assignment.value) {
TODO("remove assignment to self")
}
return noModifications
// TODO add these optimizations back:
/*
if(assignment.aug_op == "+=") {
val binExpr = assignment.value as? BinaryExpression
if (binExpr != null) {
val leftnum = binExpr.left.constValue(program)?.number?.toDouble()
val rightnum = binExpr.right.constValue(program)?.number?.toDouble()
if(binExpr.operator == "+") {
when {
leftnum == 1.0 -> {
optimizationsDone++
return listOf(IAstModification.SwapOperands(binExpr))
}
leftnum == 2.0 -> {
optimizationsDone++
return listOf(IAstModification.SwapOperands(binExpr))
}
rightnum == 1.0 -> {
// x += y + 1 -> x += y , x++
return listOf(
IAstModification.ReplaceNode(assignment.value, binExpr.left, assignment),
IAstModification.InsertAfter(assignment, PostIncrDecr(assignment.target, "++", assignment.position), parent)
)
}
rightnum == 2.0 -> {
// x += y + 2 -> x += y , x++, x++
return listOf(
IAstModification.ReplaceNode(assignment.value, binExpr.left, assignment),
IAstModification.InsertAfter(assignment, PostIncrDecr(assignment.target, "++", assignment.position), parent),
IAstModification.InsertAfter(assignment, PostIncrDecr(assignment.target, "++", assignment.position), parent)
)
}
}
} else if(binExpr.operator == "-") {
when {
leftnum == 1.0 -> {
optimizationsDone++
return listOf(IAstModification.SwapOperands(binExpr))
}
leftnum == 2.0 -> {
optimizationsDone++
return listOf(IAstModification.SwapOperands(binExpr))
}
rightnum == 1.0 -> {
// x += y - 1 -> x += y , x--
return listOf(
IAstModification.ReplaceNode(assignment.value, binExpr.left, assignment),
IAstModification.InsertAfter(assignment, PostIncrDecr(assignment.target, "--", assignment.position), parent)
)
}
rightnum == 2.0 -> {
// x += y - 2 -> x += y , x--, x--
return listOf(
IAstModification.ReplaceNode(assignment.value, binExpr.left, assignment),
IAstModification.InsertAfter(assignment, PostIncrDecr(assignment.target, "--", assignment.position), parent),
IAstModification.InsertAfter(assignment, PostIncrDecr(assignment.target, "--", assignment.position), parent)
)
}
}
}
}
} else if(assignment.aug_op == "-=") {
val binExpr = assignment.value as? BinaryExpression
if (binExpr != null) {
val leftnum = binExpr.left.constValue(program)?.number?.toDouble()
val rightnum = binExpr.right.constValue(program)?.number?.toDouble()
if(binExpr.operator == "+") {
when {
leftnum == 1.0 -> {
optimizationsDone++
return listOf(IAstModification.SwapOperands(binExpr))
}
leftnum == 2.0 -> {
optimizationsDone++
return listOf(IAstModification.SwapOperands(binExpr))
}
rightnum == 1.0 -> {
// x -= y + 1 -> x -= y , x--
return listOf(
IAstModification.ReplaceNode(assignment.value, binExpr.left, assignment),
IAstModification.InsertAfter(assignment, PostIncrDecr(assignment.target, "--", assignment.position), parent)
)
}
rightnum == 2.0 -> {
// x -= y + 2 -> x -= y , x--, x--
return listOf(
IAstModification.ReplaceNode(assignment.value, binExpr.left, assignment),
IAstModification.InsertAfter(assignment, PostIncrDecr(assignment.target, "--", assignment.position), parent),
IAstModification.InsertAfter(assignment, PostIncrDecr(assignment.target, "--", assignment.position), parent)
)
}
}
} else if(binExpr.operator == "-") {
when {
leftnum == 1.0 -> {
optimizationsDone++
return listOf(IAstModification.SwapOperands(binExpr))
}
leftnum == 2.0 -> {
optimizationsDone++
return listOf(IAstModification.SwapOperands(binExpr))
}
rightnum == 1.0 -> {
// x -= y - 1 -> x -= y , x++
return listOf(
IAstModification.ReplaceNode(assignment.value, binExpr.left, assignment),
IAstModification.InsertAfter(assignment, PostIncrDecr(assignment.target, "++", assignment.position), parent)
)
}
rightnum == 2.0 -> {
// x -= y - 2 -> x -= y , x++, x++
return listOf(
IAstModification.ReplaceNode(assignment.value, binExpr.left, assignment),
IAstModification.InsertAfter(assignment, PostIncrDecr(assignment.target, "++", assignment.position), parent),
IAstModification.InsertAfter(assignment, PostIncrDecr(assignment.target, "++", assignment.position), parent)
)
}
}
}
}
}
return noModifications
*/
}
}

View File

@ -84,6 +84,10 @@ internal class ExpressionSimplifier(private val program: Program) : AstWalker()
}
override fun after(expr: BinaryExpression, parent: Node): Iterable<IAstModification> {
// TODO: (A +/- B) +/- C ==> A +/- ( B +/- C)
// TODO: (A * / B) * / C ==> A * / ( B * / C)
val leftVal = expr.left.constValue(program)
val rightVal = expr.right.constValue(program)

View File

@ -324,11 +324,67 @@ internal class StatementOptimizer(private val program: Program,
return noModifications
}
override fun before(assignment: Assignment, parent: Node): Iterable<IAstModification> {
val binExpr = assignment.value as? BinaryExpression
if(binExpr!=null) {
if(binExpr.left isSameAs assignment.target) {
val rExpr = binExpr.right as? BinaryExpression
if(rExpr!=null) {
val op1 = binExpr.operator
val op2 = rExpr.operator
if(rExpr.left is NumericLiteralValue && op2 in setOf("+", "*", "&", "|")) {
// associative operator, make sure the constant numeric value is second (right)
return listOf(IAstModification.SwapOperands(rExpr))
}
val rNum = (rExpr.right as? NumericLiteralValue)?.number
if(rNum!=null) {
if (op1 == "+" || op1 == "-") {
if (op2 == "+") {
// A = A +/- B + N
val expr2 = BinaryExpression(binExpr.left, binExpr.operator, rExpr.left, binExpr.position)
val addConstant = Assignment(
assignment.target,
BinaryExpression(binExpr.left, "+", rExpr.right, rExpr.position),
assignment.position
)
return listOf(
IAstModification.ReplaceNode(binExpr, expr2, binExpr.parent),
IAstModification.InsertAfter(assignment, addConstant, parent))
} else if (op2 == "-") {
// A = A +/- B - N
val expr2 = BinaryExpression(binExpr.left, binExpr.operator, rExpr.left, binExpr.position)
val subConstant = Assignment(
assignment.target,
BinaryExpression(binExpr.left, "-", rExpr.right, rExpr.position),
assignment.position
)
return listOf(
IAstModification.ReplaceNode(binExpr, expr2, binExpr.parent),
IAstModification.InsertAfter(assignment, subConstant, parent))
}
}
}
}
}
if(binExpr.right isSameAs assignment.target) {
if(binExpr.operator in setOf("+", "*", "&", "|")) {
// associative operator, swap the operands so that the assignment target is first (left)
return listOf(IAstModification.SwapOperands(binExpr))
}
}
}
return noModifications
}
override fun after(assignment: Assignment, parent: Node): Iterable<IAstModification> {
// remove assignments to self
if(assignment.target isSameAs assignment.value) {
if(assignment.target.isNotMemory(program.namespace))
return listOf(IAstModification.Remove(assignment, parent))
// remove assignment to self
return listOf(IAstModification.Remove(assignment, parent))
}
val targetIDt = assignment.target.inferType(program, assignment)

View File

@ -9,38 +9,41 @@ main {
sub start() {
repeat 10 {
c64.CHROUT('*')
}
c64.CHROUT('\n')
ubyte wv
ubyte wv2
ubyte ub = 9
repeat ub {
c64.CHROUT('*')
}
c64.CHROUT('\n')
wv *= wv2
repeat 320 {
c64.CHROUT('+')
}
c64.CHROUT('\n')
wv += 10
wv += 20
wv += 30
uword uw = 320
repeat uw {
c64.CHROUT('-')
}
c64.CHROUT('\n')
wv += 1 + wv2
wv += 2 + wv2
wv += 3 + wv2
ub = 7
repeat ub+2 {
c64.CHROUT('*')
}
c64.CHROUT('\n')
wv += wv2 + 1
wv += wv2 + 2
wv += wv2 + 3
uw = 318
repeat uw+2 {
c64.CHROUT('*')
}
c64.CHROUT('\n')
wv = wv + 1 + wv2
wv = wv + 2 + wv2
wv = wv + 3 + wv2
wv = 1 + wv2 + wv
wv = 2 + wv2 + wv
wv = 3 + wv2 + wv
wv = wv + wv2 + 1
wv = wv + wv2 + 2
wv = wv + wv2 + 3
wv = wv2 + 1 + wv
wv = wv2 + 2 + wv
wv = wv2 + 3 + wv
wv = wv2 + wv + 1
wv = wv2 + wv + 2
wv = wv2 + wv + 3
}
}