comparisonjump tweak

This commit is contained in:
Irmen de Jong 2021-10-30 23:15:18 +02:00
parent 03ac9b6956
commit 61fa3bc77c
3 changed files with 132 additions and 77 deletions

View File

@ -982,7 +982,7 @@ class AsmGen(private val program: Program,
}
private fun translate(stmt: IfStatement) {
checkBooleanExpression(stmt.condition) // we require the condition to be of the form 'x <comparison> <value>'
requireComparisonExpression(stmt.condition) // IfStatement: condition must be of form 'x <comparison> <value>'
val booleanCondition = stmt.condition as BinaryExpression
// DISABLED FOR NOW:
@ -1008,9 +1008,9 @@ class AsmGen(private val program: Program,
}
}
private fun checkBooleanExpression(condition: Expression) {
private fun requireComparisonExpression(condition: Expression) {
if(condition !is BinaryExpression || condition.operator !in comparisonOperators)
throw AssemblyError("expected boolean expression $condition")
throw AssemblyError("expected boolean comparison expression $condition")
}
private fun translate(stmt: RepeatLoop) {
@ -1161,7 +1161,7 @@ $repeatLabel lda $counterVar
}
private fun translate(stmt: WhileLoop) {
checkBooleanExpression(stmt.condition) // we require the condition to be of the form 'x <comparison> <value>'
requireComparisonExpression(stmt.condition) // WhileLoop: condition must be of form 'x <comparison> <value>'
val booleanCondition = stmt.condition as BinaryExpression
val whileLabel = makeLabel("while")
val endLabel = makeLabel("whileend")
@ -1175,7 +1175,7 @@ $repeatLabel lda $counterVar
}
private fun translate(stmt: UntilLoop) {
checkBooleanExpression(stmt.condition) // we require the condition to be of the form 'x <comparison> <value>'
requireComparisonExpression(stmt.condition) // UntilLoop: condition must be of form 'x <comparison> <value>'
val booleanCondition = stmt.condition as BinaryExpression
val repeatLabel = makeLabel("repeat")
val endLabel = makeLabel("repeatend")

View File

@ -33,13 +33,16 @@ internal class ExpressionsAsmGen(private val program: Program, private val asmge
}
internal fun translateComparisonExpressionWithJumpIfFalse(expr: BinaryExpression, jumpIfFalseLabel: String) {
// This is a helper routine called from while, do-util, and if expressions to generate optimized conditional branching code.
// This is a helper routine called from while, do-util, and if expressions to generate optimized conditional branching code.
// First, if it is of the form: <constvalue> <comparison> X , then flip the expression so the constant is always the right operand.
var left = expr.left
var right = expr.right
var operator = expr.operator
var leftConstVal = left.constValue(program)
var rightConstVal = right.constValue(program)
// make sure the constant value is on the right of the comparison expression
if(leftConstVal!=null) {
val tmp = left
left = right
@ -55,32 +58,90 @@ internal class ExpressionsAsmGen(private val program: Program, private val asmge
}
}
val idt = left.inferType(program)
if(!idt.isKnown)
throw AssemblyError("unknown dt")
val dt = idt.getOr(DataType.UNDEFINED)
when (operator) {
"==" -> {
// if the left operand is an expression, and the right is 0, we can just evaluate that expression,
// and use the result value directly to determine the boolean result. Shortcut only for integers.
if(rightConstVal?.number?.toDouble() == 0.0) {
if(dt in ByteDatatypes) {
asmgen.assignExpressionToRegister(left, RegisterOrPair.A)
if(left is FunctionCall && !left.isSimple)
if (rightConstVal!=null && rightConstVal.number.toDouble() == 0.0)
jumpIfZeroOrNot(left, operator, right, jumpIfFalseLabel, leftConstVal, rightConstVal)
else
jumpIfComparison(left, operator, right, jumpIfFalseLabel, leftConstVal, rightConstVal)
}
private fun jumpIfZeroOrNot(
left: Expression,
operator: String,
right: Expression,
jumpIfFalseLabel: String,
leftConstVal: NumericLiteralValue?,
rightConstVal: NumericLiteralValue?
) {
val dt = left.inferType(program).getOrElse { throw AssemblyError("unknown dt") }
when(dt) {
DataType.UBYTE, DataType.BYTE -> {
asmgen.assignExpressionToRegister(left, RegisterOrPair.A)
when (operator) {
"==", "!=" -> {
// simple zero (in)equality check
if (left is FunctionCall && !left.isSimple)
asmgen.out(" cmp #0")
asmgen.out(" bne $jumpIfFalseLabel")
return
if (operator == "==")
asmgen.out(" bne $jumpIfFalseLabel")
else
asmgen.out(" beq $jumpIfFalseLabel")
}
else if(dt in WordDatatypes) {
asmgen.assignExpressionToRegister(left, RegisterOrPair.AY)
asmgen.out("""
sty P8ZP_SCRATCH_B1
ora P8ZP_SCRATCH_B1
bne $jumpIfFalseLabel""")
return
else -> {
// TODO optimize byte other operators
jumpIfComparison(left, operator, right, jumpIfFalseLabel, leftConstVal, rightConstVal)
}
}
}
DataType.UWORD, DataType.WORD -> {
asmgen.assignExpressionToRegister(left, RegisterOrPair.AY)
when (operator) {
"==" -> {
asmgen.out(" sty P8ZP_SCRATCH_B1 | ora P8ZP_SCRATCH_B1 | bne $jumpIfFalseLabel")
}
"!=" -> {
asmgen.out(" sty P8ZP_SCRATCH_B1 | ora P8ZP_SCRATCH_B1 | beq $jumpIfFalseLabel")
}
else -> {
// TODO optimize word other operators
jumpIfComparison(left, operator, right, jumpIfFalseLabel, leftConstVal, rightConstVal)
}
}
}
DataType.FLOAT -> {
asmgen.assignExpressionToRegister(left, RegisterOrPair.FAC1)
when (operator) {
"==" -> {
asmgen.out(" jsr floats.SIGN") // SIGN(fac1) to A, $ff, $0, $1 for negative, zero, positive
asmgen.out(" bne $jumpIfFalseLabel")
}
"!=" -> {
asmgen.out(" jsr floats.SIGN") // SIGN(fac1) to A, $ff, $0, $1 for negative, zero, positive
asmgen.out(" beq $jumpIfFalseLabel")
}
else -> {
// TODO optimize float other operators?
jumpIfComparison(left, operator, right, jumpIfFalseLabel, leftConstVal, rightConstVal)
}
}
}
else -> {
throw AssemblyError("invalid dt")
}
}
}
private fun jumpIfComparison(
left: Expression,
operator: String,
right: Expression,
jumpIfFalseLabel: String,
leftConstVal: NumericLiteralValue?,
rightConstVal: NumericLiteralValue?
) {
val dt = left.inferType(program).getOrElse { throw AssemblyError("unknown dt") }
when (operator) {
"==" -> {
when (dt) {
in ByteDatatypes -> translateByteEqualsJump(left, right, leftConstVal, rightConstVal, jumpIfFalseLabel)
in WordDatatypes -> translateWordEqualsJump(left, right, leftConstVal, rightConstVal, jumpIfFalseLabel)
@ -90,26 +151,6 @@ internal class ExpressionsAsmGen(private val program: Program, private val asmge
}
}
"!=" -> {
// if the left operand is an expression, and the right is 0, we can just evaluate that expression,
// and use the result value directly to determine the boolean result. Shortcut only for integers.
if(rightConstVal?.number?.toDouble() == 0.0) {
if(dt in ByteDatatypes) {
asmgen.assignExpressionToRegister(left, RegisterOrPair.A)
if(left is FunctionCall && !left.isSimple)
asmgen.out(" cmp #0")
asmgen.out(" beq $jumpIfFalseLabel")
return
}
else if(dt in WordDatatypes) {
asmgen.assignExpressionToRegister(left, RegisterOrPair.AY)
asmgen.out("""
sty P8ZP_SCRATCH_B1
ora P8ZP_SCRATCH_B1
beq $jumpIfFalseLabel""")
return
}
}
when (dt) {
in ByteDatatypes -> translateByteNotEqualsJump(left, right, leftConstVal, rightConstVal, jumpIfFalseLabel)
in WordDatatypes -> translateWordNotEqualsJump(left, right, leftConstVal, rightConstVal, jumpIfFalseLabel)

View File

@ -1,41 +1,55 @@
%import string
%import floats
%import textio
%zeropage basicsafe
main {
sub start() {
uword[] values = [1111,2222,3333,4444]
ubyte xx
float ff
@($2000) = 'a'
@($2001) = 'b'
@($2002) = 'c'
@($2003) = 0
ff=0
asmfunc([999,888,777])
asmfunc(values)
asmfunc($2000)
txt.nl()
func([999,888,777])
func(values)
func($2000)
}
if ff==0 {
txt.print("ff=0\n")
}
if ff!=0 {
txt.print("ff!=0 (error!)\n")
}
ff=-0.22
if ff==0 {
txt.print("ff=0 (error!)\n")
}
if ff!=0 {
txt.print("ff!=0\n")
}
asmsub asmfunc(uword[] thing @AY) {
%asm {{
sta func.thing
sty func.thing+1
jmp func
}}
}
sub func(uword[] thing) {
uword t2 = thing as uword
ubyte length = string.length(thing)
txt.print_uwhex(thing, true)
txt.nl()
txt.print_ub(length)
txt.nl()
txt.print(thing)
txt.nl()
if xx { ; doesn't use stack...
xx++
}
xx = xx+1 ; doesn't use stack...
if 8<xx {
}
if xx+1 { ; TODO why does this use stack?
xx++
}
xx = xx & %0001 ; doesn't use stack...
if xx & %0001 { ; TODO why does this use stack?
xx--
}
do {
xx++
} until xx+1
while xx+1 {
xx++
}
}
}