mirror of
https://github.com/KarolS/millfork.git
synced 2025-01-22 08:32:29 +00:00
Long comparisons; word vs unsigned byte comparison optimization
This commit is contained in:
parent
2548822b8b
commit
e0c3a566b7
@ -122,18 +122,18 @@ Note that currently in cases like `a < f() < b`, `f()` will be evaluated twice!
|
||||
|
||||
* `==`: equality
|
||||
`byte == byte`
|
||||
`word == word`
|
||||
`long == long`
|
||||
`simple word == simple word`
|
||||
`simple long == simple long`
|
||||
|
||||
* `!=`: inequality
|
||||
`byte != byte`
|
||||
`word != word`
|
||||
`long != long`
|
||||
`simple word != simple word`
|
||||
`simple long != simple long`
|
||||
|
||||
* `>`, `<`, `<=`, `>=`: inequality
|
||||
`byte > byte`
|
||||
`word > word`
|
||||
`long > long`
|
||||
`simple word > simple word`
|
||||
`simple long > simple long`
|
||||
|
||||
Currently, `>`, `<`, `<=`, `>=` operators perform unsigned comparison
|
||||
if none of the types of their arguments is signed,
|
||||
|
@ -465,8 +465,8 @@ object BuiltIns {
|
||||
val lva = env.get[VariableInMemory](v.name)
|
||||
(AssemblyLine.variable(ctx, STA, lva, 1),
|
||||
AssemblyLine.variable(ctx, STA, lva, 0),
|
||||
List(AssemblyLine.immediate(STA, rc.hiByte)),
|
||||
List(AssemblyLine.immediate(STA, rc.loByte)))
|
||||
List(AssemblyLine.immediate(STA, rc.hiByte.quickSimplify)),
|
||||
List(AssemblyLine.immediate(STA, rc.loByte.quickSimplify)))
|
||||
case (lv: VariableExpression, None, rv: VariableExpression, None) =>
|
||||
val lva = env.get[VariableInMemory](lv.name)
|
||||
val rva = env.get[VariableInMemory](rv.name)
|
||||
@ -474,9 +474,27 @@ object BuiltIns {
|
||||
AssemblyLine.variable(ctx, STA, lva, 0),
|
||||
AssemblyLine.variable(ctx, STA, rva, 1),
|
||||
AssemblyLine.variable(ctx, STA, rva, 0))
|
||||
case _ =>
|
||||
// TODO comparing expressions
|
||||
ErrorReporting.error("Too complex expressions in comparison", lhs.position)
|
||||
(Nil, Nil, Nil, Nil)
|
||||
}
|
||||
val lType = ExpressionCompiler.getExpressionType(ctx, lhs)
|
||||
val rType = ExpressionCompiler.getExpressionType(ctx, rhs)
|
||||
val compactEqualityComparison = if (ctx.options.flag(CompilationFlag.OptimizeForSpeed)) {
|
||||
None
|
||||
} else if (lType.size == 1 && !lType.isSigned) {
|
||||
Some(staTo(LDA, ll) ++ staTo(EOR, rl) ++ staTo(ORA, rh))
|
||||
} else if (rType.size == 1 && !rType.isSigned) {
|
||||
Some(staTo(LDA, rl) ++ staTo(EOR, ll) ++ staTo(ORA, lh))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
effectiveComparisonType match {
|
||||
case ComparisonType.Equal =>
|
||||
compactEqualityComparison match {
|
||||
case Some(code) => code :+ AssemblyLine.relative(BEQ, Label(x))
|
||||
case None =>
|
||||
val innerLabel = MfCompiler.nextLabel("cp")
|
||||
staTo(LDA, ll) ++
|
||||
staTo(CMP, rl) ++
|
||||
@ -486,14 +504,19 @@ object BuiltIns {
|
||||
List(
|
||||
AssemblyLine.relative(BEQ, Label(x)),
|
||||
AssemblyLine.label(innerLabel))
|
||||
}
|
||||
|
||||
case ComparisonType.NotEqual =>
|
||||
compactEqualityComparison match {
|
||||
case Some(code) => code :+ AssemblyLine.relative(BNE, Label(x))
|
||||
case None =>
|
||||
staTo(LDA, ll) ++
|
||||
staTo(CMP, rl) ++
|
||||
List(AssemblyLine.relative(BNE, Label(x))) ++
|
||||
staTo(LDA, lh) ++
|
||||
staTo(CMP, rh) ++
|
||||
List(AssemblyLine.relative(BNE, Label(x)))
|
||||
}
|
||||
|
||||
case ComparisonType.LessUnsigned =>
|
||||
val innerLabel = MfCompiler.nextLabel("cp")
|
||||
@ -546,6 +569,91 @@ object BuiltIns {
|
||||
}
|
||||
}
|
||||
|
||||
def compileLongComparison(ctx: CompilationContext, compType: ComparisonType.Value, lhs: Expression, rhs: Expression, size:Int, branches: BranchSpec, alreadyFlipped: Boolean = false): List[AssemblyLine] = {
|
||||
val rType = ExpressionCompiler.getExpressionType(ctx, rhs)
|
||||
if (rType.size < size && rType.isSigned) {
|
||||
if (alreadyFlipped) ???
|
||||
else return compileLongComparison(ctx, ComparisonType.flip(compType), rhs, lhs, size, branches, alreadyFlipped = true)
|
||||
}
|
||||
|
||||
val (effectiveComparisonType, label) = branches match {
|
||||
case NoBranching => return Nil
|
||||
case BranchIfTrue(x) => compType -> x
|
||||
case BranchIfFalse(x) => ComparisonType.negate(compType) -> x
|
||||
}
|
||||
|
||||
// TODO: check for carry flag clobbering
|
||||
val l = getLoadForEachByte(ctx, lhs, size)
|
||||
val r = getLoadForEachByte(ctx, rhs, size)
|
||||
|
||||
val mask = (1L << (size * 8)) - 1
|
||||
(ctx.env.eval(lhs), ctx.env.eval(rhs)) match {
|
||||
case (Some(NumericConstant(lc, _)), Some(NumericConstant(rc, _))) =>
|
||||
return if (effectiveComparisonType match {
|
||||
// TODO: those masks are probably wrong
|
||||
case ComparisonType.Equal =>
|
||||
(lc & mask) == (rc & mask) // ??
|
||||
case ComparisonType.NotEqual =>
|
||||
(lc & mask) != (rc & mask) // ??
|
||||
|
||||
case ComparisonType.LessOrEqualUnsigned =>
|
||||
(lc & mask) <= (rc & mask)
|
||||
case ComparisonType.GreaterOrEqualUnsigned =>
|
||||
(lc & mask) >= (rc & mask)
|
||||
case ComparisonType.GreaterUnsigned =>
|
||||
(lc & mask) > (rc & mask)
|
||||
case ComparisonType.LessUnsigned =>
|
||||
(lc & mask) < (rc & mask)
|
||||
|
||||
case ComparisonType.LessOrEqualSigned =>
|
||||
signExtend(lc, mask) <= signExtend(lc, mask)
|
||||
case ComparisonType.GreaterOrEqualSigned =>
|
||||
signExtend(lc, mask) >= signExtend(lc, mask)
|
||||
case ComparisonType.GreaterSigned =>
|
||||
signExtend(lc, mask) > signExtend(lc, mask)
|
||||
case ComparisonType.LessSigned =>
|
||||
signExtend(lc, mask) < signExtend(lc, mask)
|
||||
}) List(AssemblyLine.absolute(JMP, Label(label))) else Nil
|
||||
case _ =>
|
||||
effectiveComparisonType match {
|
||||
case ComparisonType.Equal =>
|
||||
val innerLabel = MfCompiler.nextLabel("cp")
|
||||
val bytewise = l.zip(r).map{
|
||||
case (staL, staR) => staTo(LDA, staL) ++ staTo(CMP, staR)
|
||||
}
|
||||
bytewise.init.flatMap(b => b :+ AssemblyLine.relative(BNE, innerLabel)) ++ bytewise.last ++List(
|
||||
AssemblyLine.relative(BEQ, Label(label)),
|
||||
AssemblyLine.label(innerLabel))
|
||||
case ComparisonType.NotEqual =>
|
||||
l.zip(r).flatMap {
|
||||
case (staL, staR) => staTo(LDA, staL) ++ staTo(CMP, staR) :+ AssemblyLine.relative(BNE, label)
|
||||
}
|
||||
case ComparisonType.LessUnsigned =>
|
||||
val calculateCarry = AssemblyLine.implied(SEC) :: l.zip(r).flatMap{
|
||||
case (staL, staR) => staTo(LDA, staL) ++ staTo(SBC, staR)
|
||||
}
|
||||
calculateCarry ++ List(AssemblyLine.relative(BCC, Label(label)))
|
||||
case ComparisonType.GreaterOrEqualUnsigned =>
|
||||
val calculateCarry = AssemblyLine.implied(SEC) :: l.zip(r).flatMap{
|
||||
case (staL, staR) => staTo(LDA, staL) ++ staTo(SBC, staR)
|
||||
}
|
||||
calculateCarry ++ List(AssemblyLine.relative(BCS, Label(label)))
|
||||
case ComparisonType.GreaterUnsigned | ComparisonType.LessOrEqualUnsigned =>
|
||||
compileLongComparison(ctx, ComparisonType.flip(compType), rhs, lhs, size, branches, alreadyFlipped = true)
|
||||
case _ =>
|
||||
ErrorReporting.error("Long signed comparisons are not yet supported", lhs.position)
|
||||
Nil
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private def signExtend(value: Long, mask: Long): Long = {
|
||||
val masked = value & mask
|
||||
if (masked > mask/2) masked | ~mask
|
||||
else masked
|
||||
}
|
||||
|
||||
def compileInPlaceByteMultiplication(ctx: CompilationContext, v: LhsExpression, addend: Expression): List[AssemblyLine] = {
|
||||
val b = ctx.env.get[Type]("byte")
|
||||
ctx.env.eval(addend) match {
|
||||
@ -1068,6 +1176,47 @@ object BuiltIns {
|
||||
???
|
||||
}
|
||||
}
|
||||
private def getLoadForEachByte(ctx: CompilationContext, expr: Expression, size: Int): List[List[AssemblyLine]] = {
|
||||
val env = ctx.env
|
||||
env.eval(expr) match {
|
||||
case Some(c) =>
|
||||
List.tabulate(size) { i => List(AssemblyLine.immediate(STA, c.subbyte(i))) }
|
||||
case None =>
|
||||
expr match {
|
||||
case v: VariableExpression =>
|
||||
val variable = env.get[Variable](v.name)
|
||||
List.tabulate(size) { i =>
|
||||
if (i < variable.typ.size) {
|
||||
AssemblyLine.variable(ctx, STA, variable, i)
|
||||
} else if (variable.typ.isSigned) {
|
||||
val label = MfCompiler.nextLabel("sx")
|
||||
AssemblyLine.variable(ctx, STA, variable, i) ++ List(
|
||||
AssemblyLine.immediate(ORA, 0x7F),
|
||||
AssemblyLine.relative(BMI, label),
|
||||
AssemblyLine.immediate(STA, 0),
|
||||
AssemblyLine.label(label))
|
||||
} else List(AssemblyLine.immediate(STA, 0))
|
||||
}
|
||||
case expr@IndexedExpression(variable, index) =>
|
||||
List.tabulate(size) { i =>
|
||||
if (i == 0) ExpressionCompiler.compileByteStorage(ctx, Register.A, expr)
|
||||
else List(AssemblyLine.immediate(STA, 0))
|
||||
}
|
||||
case SeparateBytesExpression(h: LhsExpression, l: LhsExpression) =>
|
||||
if (simplicity(ctx.env, h) < 'J' || simplicity(ctx.env, l) < 'J') {
|
||||
// a[b]:c[d] is the most complex expression that doesn't cause the following warning
|
||||
ErrorReporting.warn("Too complex expression given to the `:` operator, generated code might be wrong", ctx.options, expr.position)
|
||||
}
|
||||
List.tabulate(size) { i =>
|
||||
if (i == 0) getStorageForEachByte(ctx, l).head
|
||||
else if (i == 1) ExpressionCompiler.preserveRegisterIfNeeded(ctx, Register.A, getStorageForEachByte(ctx, h).head)
|
||||
else List(AssemblyLine.immediate(STA, 0))
|
||||
}
|
||||
case _ =>
|
||||
???
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def removeTsx(codes: List[List[AssemblyLine]]): List[List[AssemblyLine]] = codes.map {
|
||||
case List(AssemblyLine(TSX, _, _, _), AssemblyLine(op, AbsoluteX, NumericConstant(nn, _), _)) if nn >= 0x100 && nn <= 0x1ff =>
|
||||
|
@ -931,6 +931,7 @@ object ExpressionCompiler {
|
||||
size match {
|
||||
case 1 => BuiltIns.compileByteComparison(ctx, if (signed) ComparisonType.LessSigned else ComparisonType.LessUnsigned, l, r, branches)
|
||||
case 2 => BuiltIns.compileWordComparison(ctx, if (signed) ComparisonType.LessSigned else ComparisonType.LessUnsigned, l, r, branches)
|
||||
case _ => BuiltIns.compileLongComparison(ctx, if (signed) ComparisonType.LessSigned else ComparisonType.LessUnsigned, l, r, size, branches)
|
||||
}
|
||||
}
|
||||
case ">=" =>
|
||||
@ -940,6 +941,7 @@ object ExpressionCompiler {
|
||||
size match {
|
||||
case 1 => BuiltIns.compileByteComparison(ctx, if (signed) ComparisonType.GreaterOrEqualSigned else ComparisonType.GreaterOrEqualUnsigned, l, r, branches)
|
||||
case 2 => BuiltIns.compileWordComparison(ctx, if (signed) ComparisonType.GreaterOrEqualSigned else ComparisonType.GreaterOrEqualUnsigned, l, r, branches)
|
||||
case _ => BuiltIns.compileLongComparison(ctx, if (signed) ComparisonType.GreaterOrEqualSigned else ComparisonType.GreaterOrEqualUnsigned, l, r, size, branches)
|
||||
}
|
||||
}
|
||||
case ">" =>
|
||||
@ -949,6 +951,7 @@ object ExpressionCompiler {
|
||||
size match {
|
||||
case 1 => BuiltIns.compileByteComparison(ctx, if (signed) ComparisonType.GreaterSigned else ComparisonType.GreaterUnsigned, l, r, branches)
|
||||
case 2 => BuiltIns.compileWordComparison(ctx, if (signed) ComparisonType.GreaterSigned else ComparisonType.GreaterUnsigned, l, r, branches)
|
||||
case _ => BuiltIns.compileLongComparison(ctx, if (signed) ComparisonType.GreaterSigned else ComparisonType.GreaterUnsigned, l, r, size, branches)
|
||||
}
|
||||
}
|
||||
case "<=" =>
|
||||
@ -958,6 +961,7 @@ object ExpressionCompiler {
|
||||
size match {
|
||||
case 1 => BuiltIns.compileByteComparison(ctx, if (signed) ComparisonType.LessOrEqualSigned else ComparisonType.LessOrEqualUnsigned, l, r, branches)
|
||||
case 2 => BuiltIns.compileWordComparison(ctx, if (signed) ComparisonType.LessOrEqualSigned else ComparisonType.LessOrEqualUnsigned, l, r, branches)
|
||||
case _ => BuiltIns.compileLongComparison(ctx, if (signed) ComparisonType.LessOrEqualSigned else ComparisonType.LessOrEqualUnsigned, l, r, size, branches)
|
||||
}
|
||||
}
|
||||
case "==" =>
|
||||
@ -966,6 +970,7 @@ object ExpressionCompiler {
|
||||
size match {
|
||||
case 1 => BuiltIns.compileByteComparison(ctx, ComparisonType.Equal, l, r, branches)
|
||||
case 2 => BuiltIns.compileWordComparison(ctx, ComparisonType.Equal, l, r, branches)
|
||||
case _ => BuiltIns.compileLongComparison(ctx, ComparisonType.Equal, l, r, size, branches)
|
||||
}
|
||||
}
|
||||
case "!=" =>
|
||||
@ -973,6 +978,7 @@ object ExpressionCompiler {
|
||||
size match {
|
||||
case 1 => BuiltIns.compileByteComparison(ctx, ComparisonType.NotEqual, l, r, branches)
|
||||
case 2 => BuiltIns.compileWordComparison(ctx, ComparisonType.NotEqual, l, r, branches)
|
||||
case _ => BuiltIns.compileLongComparison(ctx, ComparisonType.NotEqual, l, r, size, branches)
|
||||
}
|
||||
case "+=" =>
|
||||
val (l, r, size) = assertAssignmentLike(ctx, params)
|
||||
|
@ -23,7 +23,7 @@ class ComparisonSuite extends FunSuite with Matchers {
|
||||
| output += 78
|
||||
| }
|
||||
| }
|
||||
""".stripMargin)(_.readWord(0xc000) should equal(6))
|
||||
""".stripMargin)(_.readByte(0xc000) should equal(6))
|
||||
}
|
||||
|
||||
test("Less") {
|
||||
@ -36,7 +36,7 @@ class ComparisonSuite extends FunSuite with Matchers {
|
||||
| output += 1
|
||||
| }
|
||||
| }
|
||||
""".stripMargin)(_.readWord(0xc000) should equal(150))
|
||||
""".stripMargin)(_.readByte(0xc000) should equal(150))
|
||||
}
|
||||
|
||||
test("Compare to zero") {
|
||||
@ -51,7 +51,7 @@ class ComparisonSuite extends FunSuite with Matchers {
|
||||
| output += 1
|
||||
| }
|
||||
| }
|
||||
""".stripMargin)(_.readWord(0xc000) should equal(150))
|
||||
""".stripMargin)(_.readByte(0xc000) should equal(150))
|
||||
}
|
||||
|
||||
test("Carry flag optimization test") {
|
||||
@ -71,7 +71,7 @@ class ComparisonSuite extends FunSuite with Matchers {
|
||||
| byte get(byte x) {
|
||||
| if x >= 6 {return 0} else {return 128}
|
||||
| }
|
||||
""".stripMargin)(_.readWord(0xc000) should equal(4))
|
||||
""".stripMargin)(_.readByte(0xc000) should equal(4))
|
||||
}
|
||||
|
||||
test("Does it even work") {
|
||||
@ -98,7 +98,7 @@ class ComparisonSuite extends FunSuite with Matchers {
|
||||
| if 2222 == 3333 { output -= 1 }
|
||||
| }
|
||||
""".stripMargin
|
||||
EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+')))
|
||||
EmuBenchmarkRun(src)(_.readByte(0xc000) should equal(src.count(_ == '+')))
|
||||
}
|
||||
|
||||
test("Word comparison == and !=") {
|
||||
@ -118,9 +118,10 @@ class ComparisonSuite extends FunSuite with Matchers {
|
||||
| if a != c { output += 1 }
|
||||
| if a != 5 { output += 1 }
|
||||
| if a != 260 { output += 1 }
|
||||
| if a != 0 { output += 1 }
|
||||
| }
|
||||
""".stripMargin
|
||||
EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+')))
|
||||
EmuBenchmarkRun(src)(_.readByte(0xc000) should equal(src.count(_ == '+')))
|
||||
}
|
||||
|
||||
test("Word comparison <=") {
|
||||
@ -141,7 +142,7 @@ class ComparisonSuite extends FunSuite with Matchers {
|
||||
| if a <= c { output += 1 }
|
||||
| }
|
||||
""".stripMargin
|
||||
EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+')))
|
||||
EmuBenchmarkRun(src)(_.readByte(0xc000) should equal(src.count(_ == '+')))
|
||||
}
|
||||
test("Word comparison <") {
|
||||
val src =
|
||||
@ -160,7 +161,7 @@ class ComparisonSuite extends FunSuite with Matchers {
|
||||
| if a < 257 { output += 1 }
|
||||
| }
|
||||
""".stripMargin
|
||||
EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+')))
|
||||
EmuBenchmarkRun(src)(_.readByte(0xc000) should equal(src.count(_ == '+')))
|
||||
}
|
||||
|
||||
|
||||
@ -181,7 +182,7 @@ class ComparisonSuite extends FunSuite with Matchers {
|
||||
| if c > 0 { output += 1 }
|
||||
| }
|
||||
""".stripMargin
|
||||
EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+')))
|
||||
EmuBenchmarkRun(src)(_.readByte(0xc000) should equal(src.count(_ == '+')))
|
||||
}
|
||||
|
||||
test("Word comparison >=") {
|
||||
@ -204,7 +205,7 @@ class ComparisonSuite extends FunSuite with Matchers {
|
||||
| if a >= 0 { output += 1 }
|
||||
| }
|
||||
""".stripMargin
|
||||
EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+')))
|
||||
EmuBenchmarkRun(src)(_.readByte(0xc000) should equal(src.count(_ == '+')))
|
||||
}
|
||||
|
||||
test("Signed comparison >=") {
|
||||
@ -227,7 +228,7 @@ class ComparisonSuite extends FunSuite with Matchers {
|
||||
| if a >= 0 { output += 1 }
|
||||
| }
|
||||
""".stripMargin
|
||||
EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+')))
|
||||
EmuBenchmarkRun(src)(_.readByte(0xc000) should equal(src.count(_ == '+')))
|
||||
}
|
||||
|
||||
test("Signed comparison < and <=") {
|
||||
@ -259,7 +260,7 @@ class ComparisonSuite extends FunSuite with Matchers {
|
||||
| if c <= -1 { output -= 7 }
|
||||
| }
|
||||
""".stripMargin
|
||||
EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+')))
|
||||
EmuBenchmarkRun(src)(_.readByte(0xc000) should equal(src.count(_ == '+')))
|
||||
}
|
||||
|
||||
test("Multiple params for equality") {
|
||||
@ -275,7 +276,7 @@ class ComparisonSuite extends FunSuite with Matchers {
|
||||
| output += 78
|
||||
| }
|
||||
| }
|
||||
""".stripMargin)(_.readWord(0xc000) should equal(6))
|
||||
""".stripMargin)(_.readByte(0xc000) should equal(6))
|
||||
}
|
||||
|
||||
test("Multiple params for inequality") {
|
||||
@ -291,7 +292,7 @@ class ComparisonSuite extends FunSuite with Matchers {
|
||||
| output += 78
|
||||
| }
|
||||
| }
|
||||
""".stripMargin)(_.readWord(0xc000) should equal(6))
|
||||
""".stripMargin)(_.readByte(0xc000) should equal(6))
|
||||
}
|
||||
|
||||
test("Warnings") {
|
||||
@ -310,6 +311,33 @@ class ComparisonSuite extends FunSuite with Matchers {
|
||||
| byte three() {
|
||||
| return 3
|
||||
| }
|
||||
""".stripMargin)(_.readWord(0xc000) should equal(6))
|
||||
""".stripMargin)(_.readByte(0xc000) should equal(6))
|
||||
}
|
||||
|
||||
test("Long comparisons") {
|
||||
val src =
|
||||
"""
|
||||
| byte output @$c000
|
||||
| void main () {
|
||||
| long a
|
||||
| long b
|
||||
| long c
|
||||
| output = 0
|
||||
| a = 1234567
|
||||
| b = 2345678
|
||||
| c = 1234599
|
||||
| if a == a { output += 1 }
|
||||
| if c >= a { output += 1 }
|
||||
| if c != a { output += 1 }
|
||||
| if a <= c { output += 1 }
|
||||
| if a < c { output += 1 }
|
||||
| if b >= a { output += 1 }
|
||||
| if b >= 0 { output += 1 }
|
||||
| if b >= 44564 { output += 1 }
|
||||
| if a >= 335444 { output += 1 }
|
||||
| if c > 335444 { output += 1 }
|
||||
| }
|
||||
""".stripMargin
|
||||
EmuBenchmarkRun(src)(_.readByte(0xc000) should equal(src.count(_ == '+')))
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user