mirror of
https://github.com/irmen/prog8.git
synced 2025-01-11 13:29:45 +00:00
tweak if statement handling
This commit is contained in:
parent
9906b58818
commit
97e84d0977
@ -1098,16 +1098,21 @@ class AsmGen(private val program: Program,
|
||||
val booleanCondition = stmt.condition as BinaryExpression
|
||||
|
||||
if (stmt.elsepart.isEmpty()) {
|
||||
val endLabel = makeLabel("if_end")
|
||||
translateComparisonExpressionWithJumpIfFalse(booleanCondition, endLabel)
|
||||
translate(stmt.truepart)
|
||||
out(endLabel)
|
||||
// TODO specialize this
|
||||
// if(stmt.truepart.statements.singleOrNull() is Jump) {
|
||||
// translateCompareAndJumpIfTrue(booleanCondition, stmt.truepart.statements[0] as Jump)
|
||||
// } else {
|
||||
val endLabel = makeLabel("if_end")
|
||||
translateCompareAndJumpIfFalse(booleanCondition, endLabel)
|
||||
translate(stmt.truepart)
|
||||
out(endLabel)
|
||||
// }
|
||||
}
|
||||
else {
|
||||
// both true and else parts
|
||||
val elseLabel = makeLabel("if_else")
|
||||
val endLabel = makeLabel("if_end")
|
||||
translateComparisonExpressionWithJumpIfFalse(booleanCondition, elseLabel)
|
||||
translateCompareAndJumpIfFalse(booleanCondition, elseLabel)
|
||||
translate(stmt.truepart)
|
||||
jmp(endLabel)
|
||||
out(elseLabel)
|
||||
@ -1607,44 +1612,39 @@ $label nop""")
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
private fun translateComparisonExpressionWithJumpIfFalse(expr: BinaryExpression, jumpIfFalseLabel: String) {
|
||||
// This is a helper routine called from if expressions to generate optimized conditional branching code.
|
||||
// First, if it is of the form: <constvalue> <comparison> X , then flip the expression so the constant is always the right operand.
|
||||
|
||||
var left = expr.left
|
||||
var right = expr.right
|
||||
var operator = expr.operator
|
||||
var leftConstVal = left.constValue(program)
|
||||
var rightConstVal = right.constValue(program)
|
||||
|
||||
// make sure the constant value is on the right of the comparison expression
|
||||
if(leftConstVal!=null) {
|
||||
val tmp = left
|
||||
left = right
|
||||
right = tmp
|
||||
val tmp2 = leftConstVal
|
||||
leftConstVal = rightConstVal
|
||||
rightConstVal = tmp2
|
||||
when(expr.operator) {
|
||||
"<" -> operator = ">"
|
||||
"<=" -> operator = ">="
|
||||
">" -> operator = "<"
|
||||
">=" -> operator = "<="
|
||||
}
|
||||
}
|
||||
private fun translateCompareAndJumpIfTrue(expr: BinaryExpression, jump: Jump) {
|
||||
val left = expr.left
|
||||
val right = expr.right
|
||||
val operator = expr.operator
|
||||
val leftConstVal = left.constValue(program)
|
||||
val rightConstVal = right.constValue(program)
|
||||
|
||||
if (rightConstVal!=null && rightConstVal.number == 0.0)
|
||||
jumpIfZeroOrNot(left, operator, jumpIfFalseLabel)
|
||||
testZeroAndJump(left, operator, jump, null)
|
||||
else
|
||||
jumpIfComparison(left, operator, right, jumpIfFalseLabel, leftConstVal, rightConstVal)
|
||||
testNonzeroComparisonAndJump(left, operator, right, jump, null, leftConstVal, rightConstVal)
|
||||
}
|
||||
|
||||
private fun jumpIfZeroOrNot(
|
||||
private fun translateCompareAndJumpIfFalse(expr: BinaryExpression, jumpIfFalseLabel: String) {
|
||||
val left = expr.left
|
||||
val right = expr.right
|
||||
val operator = expr.operator
|
||||
val leftConstVal = left.constValue(program)
|
||||
val rightConstVal = right.constValue(program)
|
||||
|
||||
if (rightConstVal!=null && rightConstVal.number == 0.0)
|
||||
testZeroAndJump(left, operator, null, jumpIfFalseLabel)
|
||||
else
|
||||
testNonzeroComparisonAndJump(left, operator, right, null, jumpIfFalseLabel, leftConstVal, rightConstVal)
|
||||
}
|
||||
|
||||
private fun testZeroAndJump(
|
||||
left: Expression,
|
||||
operator: String,
|
||||
jumpIfFalseLabel: String
|
||||
jumpIfTrue: Jump?,
|
||||
jumpIfFalseLabel: String?
|
||||
) {
|
||||
require(jumpIfTrue!=null || jumpIfFalseLabel!=null)
|
||||
when(val dt = left.inferType(program).getOr(DataType.UNDEFINED)) {
|
||||
DataType.UBYTE, DataType.UWORD -> {
|
||||
if(operator=="<") {
|
||||
@ -1729,16 +1729,20 @@ $label nop""")
|
||||
}
|
||||
}
|
||||
|
||||
private fun jumpIfComparison(
|
||||
private fun testNonzeroComparisonAndJump(
|
||||
left: Expression,
|
||||
operator: String,
|
||||
right: Expression,
|
||||
jumpIfFalseLabel: String,
|
||||
jumpIfTrue: Jump?,
|
||||
jumpIfFalseLabel: String?,
|
||||
leftConstVal: NumericLiteralValue?,
|
||||
rightConstVal: NumericLiteralValue?
|
||||
) {
|
||||
require(jumpIfTrue!=null || jumpIfFalseLabel!=null)
|
||||
val dt = left.inferType(program).getOrElse { throw AssemblyError("unknown dt") }
|
||||
|
||||
jumpIfFalseLabel!! // TODO jump if true... or rewrite everything to use just jump-if-false
|
||||
|
||||
when (operator) {
|
||||
"==" -> {
|
||||
when (dt) {
|
||||
|
@ -1,5 +1,6 @@
|
||||
package prog8.optimizer
|
||||
|
||||
import prog8.ast.IStatementContainer
|
||||
import prog8.ast.Node
|
||||
import prog8.ast.Program
|
||||
import prog8.ast.base.DataType
|
||||
@ -7,7 +8,10 @@ import prog8.ast.base.FatalAstException
|
||||
import prog8.ast.base.IntegerDatatypes
|
||||
import prog8.ast.base.NumericDatatypes
|
||||
import prog8.ast.expressions.*
|
||||
import prog8.ast.statements.AnonymousScope
|
||||
import prog8.ast.statements.Assignment
|
||||
import prog8.ast.statements.IfStatement
|
||||
import prog8.ast.statements.Jump
|
||||
import prog8.ast.walk.AstWalker
|
||||
import prog8.ast.walk.IAstModification
|
||||
import kotlin.math.abs
|
||||
@ -54,6 +58,31 @@ class ExpressionSimplifier(private val program: Program) : AstWalker() {
|
||||
return mods
|
||||
}
|
||||
|
||||
override fun after(ifStatement: IfStatement, parent: Node): Iterable<IAstModification> {
|
||||
val truepart = ifStatement.truepart
|
||||
val elsepart = ifStatement.elsepart
|
||||
if(truepart.isNotEmpty() && elsepart.isNotEmpty()) {
|
||||
if(truepart.statements.singleOrNull() is Jump) {
|
||||
return listOf(
|
||||
IAstModification.InsertAfter(ifStatement, elsepart, parent as IStatementContainer),
|
||||
IAstModification.ReplaceNode(elsepart, AnonymousScope(mutableListOf(), elsepart.position), ifStatement)
|
||||
)
|
||||
}
|
||||
if(elsepart.statements.singleOrNull() is Jump) {
|
||||
val invertedCondition = invertCondition(ifStatement.condition)
|
||||
if(invertedCondition!=null) {
|
||||
return listOf(
|
||||
IAstModification.ReplaceNode(ifStatement.condition, invertedCondition, ifStatement),
|
||||
IAstModification.InsertAfter(ifStatement, truepart, parent as IStatementContainer),
|
||||
IAstModification.ReplaceNode(elsepart, AnonymousScope(mutableListOf(), elsepart.position), ifStatement),
|
||||
IAstModification.ReplaceNode(truepart, elsepart, ifStatement)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
return noModifications
|
||||
}
|
||||
|
||||
override fun after(expr: BinaryExpression, parent: Node): Iterable<IAstModification> {
|
||||
val leftVal = expr.left.constValue(program)
|
||||
val rightVal = expr.right.constValue(program)
|
||||
@ -696,3 +725,24 @@ class ExpressionSimplifier(private val program: Program) : AstWalker() {
|
||||
private data class BinExprWithConstants(val expr: BinaryExpression, val leftVal: NumericLiteralValue?, val rightVal: NumericLiteralValue?)
|
||||
|
||||
}
|
||||
|
||||
|
||||
fun invertCondition(cond: Expression): BinaryExpression? {
|
||||
if(cond is BinaryExpression) {
|
||||
val invertedOperator = invertedComparisonOperator(cond.operator)
|
||||
if (invertedOperator != null)
|
||||
return BinaryExpression(cond.left, invertedOperator, cond.right, cond.position)
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
fun invertedComparisonOperator(operator: String) =
|
||||
when (operator) {
|
||||
"==" -> "!="
|
||||
"!=" -> "=="
|
||||
"<" -> ">="
|
||||
">" -> "<="
|
||||
"<=" -> ">"
|
||||
">=" -> "<"
|
||||
else -> null
|
||||
}
|
||||
|
@ -9,6 +9,7 @@ import prog8.ast.statements.*
|
||||
import prog8.ast.walk.AstWalker
|
||||
import prog8.ast.walk.IAstModification
|
||||
import prog8.compilerinterface.IErrorReporter
|
||||
import prog8.optimizer.invertedComparisonOperator
|
||||
|
||||
|
||||
internal class VariousCleanups(val program: Program, val errors: IErrorReporter): AstWalker() {
|
||||
@ -81,16 +82,7 @@ internal class VariousCleanups(val program: Program, val errors: IErrorReporter)
|
||||
val comparison = expr.expression as? BinaryExpression
|
||||
if (comparison != null) {
|
||||
// NOT COMPARISON ==> inverted COMPARISON
|
||||
val invertedOperator =
|
||||
when (comparison.operator) {
|
||||
"==" -> "!="
|
||||
"!=" -> "=="
|
||||
"<" -> ">="
|
||||
">" -> "<="
|
||||
"<=" -> ">"
|
||||
">=" -> "<"
|
||||
else -> null
|
||||
}
|
||||
val invertedOperator = invertedComparisonOperator(comparison.operator)
|
||||
if (invertedOperator != null) {
|
||||
comparison.operator = invertedOperator
|
||||
return listOf(IAstModification.ReplaceNode(expr, comparison, parent))
|
||||
@ -99,4 +91,25 @@ internal class VariousCleanups(val program: Program, val errors: IErrorReporter)
|
||||
}
|
||||
return noModifications
|
||||
}
|
||||
|
||||
override fun after(expr: BinaryExpression, parent: Node): Iterable<IAstModification> {
|
||||
if(expr.operator in ComparisonOperators) {
|
||||
val leftConstVal = expr.left.constValue(program)
|
||||
val rightConstVal = expr.right.constValue(program)
|
||||
// make sure the constant value is on the right of the comparison expression
|
||||
if(rightConstVal==null && leftConstVal!=null) {
|
||||
val newOperator =
|
||||
when(expr.operator) {
|
||||
"<" -> ">"
|
||||
"<=" -> ">="
|
||||
">" -> "<"
|
||||
">=" -> "<="
|
||||
else -> expr.operator
|
||||
}
|
||||
val replacement = BinaryExpression(expr.right, newOperator, expr.left, expr.position)
|
||||
return listOf(IAstModification.ReplaceNode(expr, replacement, parent))
|
||||
}
|
||||
}
|
||||
return noModifications
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user