diff --git a/codeGeneration/src/prog8/compiler/target/cpu6502/codegen/ExpressionsAsmGen.kt b/codeGeneration/src/prog8/compiler/target/cpu6502/codegen/ExpressionsAsmGen.kt index 50355440e..642f89c62 100644 --- a/codeGeneration/src/prog8/compiler/target/cpu6502/codegen/ExpressionsAsmGen.kt +++ b/codeGeneration/src/prog8/compiler/target/cpu6502/codegen/ExpressionsAsmGen.kt @@ -72,19 +72,17 @@ internal class ExpressionsAsmGen(private val program: Program, private val asmge leftConstVal: NumericLiteralValue?, rightConstVal: NumericLiteralValue? ) { - val dt = left.inferType(program).getOrElse { throw AssemblyError("unknown dt") } - when(dt) { + when(left.inferType(program).getOr(DataType.UNDEFINED)) { DataType.UBYTE, DataType.BYTE -> { asmgen.assignExpressionToRegister(left, RegisterOrPair.A) + if (left is FunctionCall && !left.isSimple) + asmgen.out(" cmp #0") when (operator) { - "==", "!=" -> { - // simple zero (in)equality check - if (left is FunctionCall && !left.isSimple) - asmgen.out(" cmp #0") - if (operator == "==") - asmgen.out(" bne $jumpIfFalseLabel") - else - asmgen.out(" beq $jumpIfFalseLabel") + "==" -> { + asmgen.out(" bne $jumpIfFalseLabel") + } + "!=" -> { + asmgen.out(" beq $jumpIfFalseLabel") } else -> { // TODO optimize byte other operators @@ -94,12 +92,13 @@ internal class ExpressionsAsmGen(private val program: Program, private val asmge } DataType.UWORD, DataType.WORD -> { asmgen.assignExpressionToRegister(left, RegisterOrPair.AY) + asmgen.out(" sty P8ZP_SCRATCH_B1 | ora P8ZP_SCRATCH_B1") when (operator) { "==" -> { - asmgen.out(" sty P8ZP_SCRATCH_B1 | ora P8ZP_SCRATCH_B1 | bne $jumpIfFalseLabel") + asmgen.out(" bne $jumpIfFalseLabel") } "!=" -> { - asmgen.out(" sty P8ZP_SCRATCH_B1 | ora P8ZP_SCRATCH_B1 | beq $jumpIfFalseLabel") + asmgen.out(" beq $jumpIfFalseLabel") } else -> { // TODO optimize word other operators @@ -109,18 +108,29 @@ internal class ExpressionsAsmGen(private val program: Program, private val asmge } DataType.FLOAT -> { asmgen.assignExpressionToRegister(left, RegisterOrPair.FAC1) + asmgen.out(" jsr floats.SIGN") when (operator) { "==" -> { - asmgen.out(" jsr floats.SIGN") // SIGN(fac1) to A, $ff, $0, $1 for negative, zero, positive + // SIGN(fac1) to A, $ff, $0, $1 for negative, zero, positive asmgen.out(" bne $jumpIfFalseLabel") } "!=" -> { - asmgen.out(" jsr floats.SIGN") // SIGN(fac1) to A, $ff, $0, $1 for negative, zero, positive asmgen.out(" beq $jumpIfFalseLabel") } + ">" -> { + asmgen.out(" bmi $jumpIfFalseLabel | beq $jumpIfFalseLabel") + } + "<" -> { + asmgen.out(" bpl $jumpIfFalseLabel") + } + ">=" -> { + asmgen.out(" bmi $jumpIfFalseLabel") + } + "<=" -> { + asmgen.out(" cmp #1 | beq $jumpIfFalseLabel") + } else -> { - // TODO optimize float other operators? - jumpIfComparison(left, operator, right, jumpIfFalseLabel, leftConstVal, rightConstVal) + throw AssemblyError("invalid comparison operator $operator") } } } diff --git a/examples/test.p8 b/examples/test.p8 index d6e49eff1..05b93a13c 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -10,19 +10,73 @@ main { ff=0 - if ff==0 { - txt.print("ff=0\n") + if ff>=0 { + txt.print("ff>=0\n") + } else { + txt.print("error1\n") } - if ff!=0 { - txt.print("ff!=0 (error!)\n") + if ff<=0 { + txt.print("ff<=0\n") + } else { + txt.print("error1\n") } - ff=-0.22 - if ff==0 { - txt.print("ff=0 (error!)\n") + if ff>0 { + txt.print("ff>0 error\n") + } else { + txt.print("ok1\n") } - if ff!=0 { - txt.print("ff!=0\n") + if ff<0 { + txt.print("ff<0 error\n") + } else { + txt.print("ok1\n") } + txt.nl() + + ff=0.22 + if ff>=0 { + txt.print("ff>=0\n") + } else { + txt.print("error2\n") + } + if ff<=0 { + txt.print("ff<=0 error\n") + } else { + txt.print("ok2\n") + } + if ff>0 { + txt.print("ff>0\n") + } else { + txt.print("error2\n") + } + if ff<0 { + txt.print("ff<0 error\n") + } else { + txt.print("ok2\n") + } + txt.nl() + + ff=-1.11 + if ff>=0 { + txt.print("ff>=0 error\n") + } else { + txt.print("ok3\n") + } + if ff<=0 { + txt.print("ff<=0\n") + } else { + txt.print("error3\n") + } + if ff>0 { + txt.print("ff>0 error\n") + } else { + txt.print("ok3\n") + } + if ff<0 { + txt.print("ff<0\n") + } else { + txt.print("error3\n") + } + if xx { ; doesn't use stack...