1
0
mirror of https://github.com/KarolS/millfork.git synced 2025-01-22 08:32:29 +00:00

Long comparisons; word vs unsigned byte comparison optimization

This commit is contained in:
Karol Stasiak 2018-03-16 13:19:54 +01:00
parent 2548822b8b
commit e0c3a566b7
4 changed files with 221 additions and 38 deletions

View File

@ -122,18 +122,18 @@ Note that currently in cases like `a < f() < b`, `f()` will be evaluated twice!
* `==`: equality
`byte == byte`
`word == word`
`long == long`
`simple word == simple word`
`simple long == simple long`
* `!=`: inequality
`byte != byte`
`word != word`
`long != long`
`simple word != simple word`
`simple long != simple long`
* `>`, `<`, `<=`, `>=`: inequality
`byte > byte`
`word > word`
`long > long`
`simple word > simple word`
`simple long > simple long`
Currently, `>`, `<`, `<=`, `>=` operators perform unsigned comparison
if none of the types of their arguments is signed,

View File

@ -465,8 +465,8 @@ object BuiltIns {
val lva = env.get[VariableInMemory](v.name)
(AssemblyLine.variable(ctx, STA, lva, 1),
AssemblyLine.variable(ctx, STA, lva, 0),
List(AssemblyLine.immediate(STA, rc.hiByte)),
List(AssemblyLine.immediate(STA, rc.loByte)))
List(AssemblyLine.immediate(STA, rc.hiByte.quickSimplify)),
List(AssemblyLine.immediate(STA, rc.loByte.quickSimplify)))
case (lv: VariableExpression, None, rv: VariableExpression, None) =>
val lva = env.get[VariableInMemory](lv.name)
val rva = env.get[VariableInMemory](rv.name)
@ -474,9 +474,27 @@ object BuiltIns {
AssemblyLine.variable(ctx, STA, lva, 0),
AssemblyLine.variable(ctx, STA, rva, 1),
AssemblyLine.variable(ctx, STA, rva, 0))
case _ =>
// TODO comparing expressions
ErrorReporting.error("Too complex expressions in comparison", lhs.position)
(Nil, Nil, Nil, Nil)
}
val lType = ExpressionCompiler.getExpressionType(ctx, lhs)
val rType = ExpressionCompiler.getExpressionType(ctx, rhs)
val compactEqualityComparison = if (ctx.options.flag(CompilationFlag.OptimizeForSpeed)) {
None
} else if (lType.size == 1 && !lType.isSigned) {
Some(staTo(LDA, ll) ++ staTo(EOR, rl) ++ staTo(ORA, rh))
} else if (rType.size == 1 && !rType.isSigned) {
Some(staTo(LDA, rl) ++ staTo(EOR, ll) ++ staTo(ORA, lh))
} else {
None
}
effectiveComparisonType match {
case ComparisonType.Equal =>
compactEqualityComparison match {
case Some(code) => code :+ AssemblyLine.relative(BEQ, Label(x))
case None =>
val innerLabel = MfCompiler.nextLabel("cp")
staTo(LDA, ll) ++
staTo(CMP, rl) ++
@ -486,14 +504,19 @@ object BuiltIns {
List(
AssemblyLine.relative(BEQ, Label(x)),
AssemblyLine.label(innerLabel))
}
case ComparisonType.NotEqual =>
compactEqualityComparison match {
case Some(code) => code :+ AssemblyLine.relative(BNE, Label(x))
case None =>
staTo(LDA, ll) ++
staTo(CMP, rl) ++
List(AssemblyLine.relative(BNE, Label(x))) ++
staTo(LDA, lh) ++
staTo(CMP, rh) ++
List(AssemblyLine.relative(BNE, Label(x)))
}
case ComparisonType.LessUnsigned =>
val innerLabel = MfCompiler.nextLabel("cp")
@ -546,6 +569,91 @@ object BuiltIns {
}
}
def compileLongComparison(ctx: CompilationContext, compType: ComparisonType.Value, lhs: Expression, rhs: Expression, size:Int, branches: BranchSpec, alreadyFlipped: Boolean = false): List[AssemblyLine] = {
val rType = ExpressionCompiler.getExpressionType(ctx, rhs)
if (rType.size < size && rType.isSigned) {
if (alreadyFlipped) ???
else return compileLongComparison(ctx, ComparisonType.flip(compType), rhs, lhs, size, branches, alreadyFlipped = true)
}
val (effectiveComparisonType, label) = branches match {
case NoBranching => return Nil
case BranchIfTrue(x) => compType -> x
case BranchIfFalse(x) => ComparisonType.negate(compType) -> x
}
// TODO: check for carry flag clobbering
val l = getLoadForEachByte(ctx, lhs, size)
val r = getLoadForEachByte(ctx, rhs, size)
val mask = (1L << (size * 8)) - 1
(ctx.env.eval(lhs), ctx.env.eval(rhs)) match {
case (Some(NumericConstant(lc, _)), Some(NumericConstant(rc, _))) =>
return if (effectiveComparisonType match {
// TODO: those masks are probably wrong
case ComparisonType.Equal =>
(lc & mask) == (rc & mask) // ??
case ComparisonType.NotEqual =>
(lc & mask) != (rc & mask) // ??
case ComparisonType.LessOrEqualUnsigned =>
(lc & mask) <= (rc & mask)
case ComparisonType.GreaterOrEqualUnsigned =>
(lc & mask) >= (rc & mask)
case ComparisonType.GreaterUnsigned =>
(lc & mask) > (rc & mask)
case ComparisonType.LessUnsigned =>
(lc & mask) < (rc & mask)
case ComparisonType.LessOrEqualSigned =>
signExtend(lc, mask) <= signExtend(lc, mask)
case ComparisonType.GreaterOrEqualSigned =>
signExtend(lc, mask) >= signExtend(lc, mask)
case ComparisonType.GreaterSigned =>
signExtend(lc, mask) > signExtend(lc, mask)
case ComparisonType.LessSigned =>
signExtend(lc, mask) < signExtend(lc, mask)
}) List(AssemblyLine.absolute(JMP, Label(label))) else Nil
case _ =>
effectiveComparisonType match {
case ComparisonType.Equal =>
val innerLabel = MfCompiler.nextLabel("cp")
val bytewise = l.zip(r).map{
case (staL, staR) => staTo(LDA, staL) ++ staTo(CMP, staR)
}
bytewise.init.flatMap(b => b :+ AssemblyLine.relative(BNE, innerLabel)) ++ bytewise.last ++List(
AssemblyLine.relative(BEQ, Label(label)),
AssemblyLine.label(innerLabel))
case ComparisonType.NotEqual =>
l.zip(r).flatMap {
case (staL, staR) => staTo(LDA, staL) ++ staTo(CMP, staR) :+ AssemblyLine.relative(BNE, label)
}
case ComparisonType.LessUnsigned =>
val calculateCarry = AssemblyLine.implied(SEC) :: l.zip(r).flatMap{
case (staL, staR) => staTo(LDA, staL) ++ staTo(SBC, staR)
}
calculateCarry ++ List(AssemblyLine.relative(BCC, Label(label)))
case ComparisonType.GreaterOrEqualUnsigned =>
val calculateCarry = AssemblyLine.implied(SEC) :: l.zip(r).flatMap{
case (staL, staR) => staTo(LDA, staL) ++ staTo(SBC, staR)
}
calculateCarry ++ List(AssemblyLine.relative(BCS, Label(label)))
case ComparisonType.GreaterUnsigned | ComparisonType.LessOrEqualUnsigned =>
compileLongComparison(ctx, ComparisonType.flip(compType), rhs, lhs, size, branches, alreadyFlipped = true)
case _ =>
ErrorReporting.error("Long signed comparisons are not yet supported", lhs.position)
Nil
}
}
}
private def signExtend(value: Long, mask: Long): Long = {
val masked = value & mask
if (masked > mask/2) masked | ~mask
else masked
}
def compileInPlaceByteMultiplication(ctx: CompilationContext, v: LhsExpression, addend: Expression): List[AssemblyLine] = {
val b = ctx.env.get[Type]("byte")
ctx.env.eval(addend) match {
@ -1068,6 +1176,47 @@ object BuiltIns {
???
}
}
private def getLoadForEachByte(ctx: CompilationContext, expr: Expression, size: Int): List[List[AssemblyLine]] = {
val env = ctx.env
env.eval(expr) match {
case Some(c) =>
List.tabulate(size) { i => List(AssemblyLine.immediate(STA, c.subbyte(i))) }
case None =>
expr match {
case v: VariableExpression =>
val variable = env.get[Variable](v.name)
List.tabulate(size) { i =>
if (i < variable.typ.size) {
AssemblyLine.variable(ctx, STA, variable, i)
} else if (variable.typ.isSigned) {
val label = MfCompiler.nextLabel("sx")
AssemblyLine.variable(ctx, STA, variable, i) ++ List(
AssemblyLine.immediate(ORA, 0x7F),
AssemblyLine.relative(BMI, label),
AssemblyLine.immediate(STA, 0),
AssemblyLine.label(label))
} else List(AssemblyLine.immediate(STA, 0))
}
case expr@IndexedExpression(variable, index) =>
List.tabulate(size) { i =>
if (i == 0) ExpressionCompiler.compileByteStorage(ctx, Register.A, expr)
else List(AssemblyLine.immediate(STA, 0))
}
case SeparateBytesExpression(h: LhsExpression, l: LhsExpression) =>
if (simplicity(ctx.env, h) < 'J' || simplicity(ctx.env, l) < 'J') {
// a[b]:c[d] is the most complex expression that doesn't cause the following warning
ErrorReporting.warn("Too complex expression given to the `:` operator, generated code might be wrong", ctx.options, expr.position)
}
List.tabulate(size) { i =>
if (i == 0) getStorageForEachByte(ctx, l).head
else if (i == 1) ExpressionCompiler.preserveRegisterIfNeeded(ctx, Register.A, getStorageForEachByte(ctx, h).head)
else List(AssemblyLine.immediate(STA, 0))
}
case _ =>
???
}
}
}
private def removeTsx(codes: List[List[AssemblyLine]]): List[List[AssemblyLine]] = codes.map {
case List(AssemblyLine(TSX, _, _, _), AssemblyLine(op, AbsoluteX, NumericConstant(nn, _), _)) if nn >= 0x100 && nn <= 0x1ff =>

View File

@ -931,6 +931,7 @@ object ExpressionCompiler {
size match {
case 1 => BuiltIns.compileByteComparison(ctx, if (signed) ComparisonType.LessSigned else ComparisonType.LessUnsigned, l, r, branches)
case 2 => BuiltIns.compileWordComparison(ctx, if (signed) ComparisonType.LessSigned else ComparisonType.LessUnsigned, l, r, branches)
case _ => BuiltIns.compileLongComparison(ctx, if (signed) ComparisonType.LessSigned else ComparisonType.LessUnsigned, l, r, size, branches)
}
}
case ">=" =>
@ -940,6 +941,7 @@ object ExpressionCompiler {
size match {
case 1 => BuiltIns.compileByteComparison(ctx, if (signed) ComparisonType.GreaterOrEqualSigned else ComparisonType.GreaterOrEqualUnsigned, l, r, branches)
case 2 => BuiltIns.compileWordComparison(ctx, if (signed) ComparisonType.GreaterOrEqualSigned else ComparisonType.GreaterOrEqualUnsigned, l, r, branches)
case _ => BuiltIns.compileLongComparison(ctx, if (signed) ComparisonType.GreaterOrEqualSigned else ComparisonType.GreaterOrEqualUnsigned, l, r, size, branches)
}
}
case ">" =>
@ -949,6 +951,7 @@ object ExpressionCompiler {
size match {
case 1 => BuiltIns.compileByteComparison(ctx, if (signed) ComparisonType.GreaterSigned else ComparisonType.GreaterUnsigned, l, r, branches)
case 2 => BuiltIns.compileWordComparison(ctx, if (signed) ComparisonType.GreaterSigned else ComparisonType.GreaterUnsigned, l, r, branches)
case _ => BuiltIns.compileLongComparison(ctx, if (signed) ComparisonType.GreaterSigned else ComparisonType.GreaterUnsigned, l, r, size, branches)
}
}
case "<=" =>
@ -958,6 +961,7 @@ object ExpressionCompiler {
size match {
case 1 => BuiltIns.compileByteComparison(ctx, if (signed) ComparisonType.LessOrEqualSigned else ComparisonType.LessOrEqualUnsigned, l, r, branches)
case 2 => BuiltIns.compileWordComparison(ctx, if (signed) ComparisonType.LessOrEqualSigned else ComparisonType.LessOrEqualUnsigned, l, r, branches)
case _ => BuiltIns.compileLongComparison(ctx, if (signed) ComparisonType.LessOrEqualSigned else ComparisonType.LessOrEqualUnsigned, l, r, size, branches)
}
}
case "==" =>
@ -966,6 +970,7 @@ object ExpressionCompiler {
size match {
case 1 => BuiltIns.compileByteComparison(ctx, ComparisonType.Equal, l, r, branches)
case 2 => BuiltIns.compileWordComparison(ctx, ComparisonType.Equal, l, r, branches)
case _ => BuiltIns.compileLongComparison(ctx, ComparisonType.Equal, l, r, size, branches)
}
}
case "!=" =>
@ -973,6 +978,7 @@ object ExpressionCompiler {
size match {
case 1 => BuiltIns.compileByteComparison(ctx, ComparisonType.NotEqual, l, r, branches)
case 2 => BuiltIns.compileWordComparison(ctx, ComparisonType.NotEqual, l, r, branches)
case _ => BuiltIns.compileLongComparison(ctx, ComparisonType.NotEqual, l, r, size, branches)
}
case "+=" =>
val (l, r, size) = assertAssignmentLike(ctx, params)

View File

@ -23,7 +23,7 @@ class ComparisonSuite extends FunSuite with Matchers {
| output += 78
| }
| }
""".stripMargin)(_.readWord(0xc000) should equal(6))
""".stripMargin)(_.readByte(0xc000) should equal(6))
}
test("Less") {
@ -36,7 +36,7 @@ class ComparisonSuite extends FunSuite with Matchers {
| output += 1
| }
| }
""".stripMargin)(_.readWord(0xc000) should equal(150))
""".stripMargin)(_.readByte(0xc000) should equal(150))
}
test("Compare to zero") {
@ -51,7 +51,7 @@ class ComparisonSuite extends FunSuite with Matchers {
| output += 1
| }
| }
""".stripMargin)(_.readWord(0xc000) should equal(150))
""".stripMargin)(_.readByte(0xc000) should equal(150))
}
test("Carry flag optimization test") {
@ -71,7 +71,7 @@ class ComparisonSuite extends FunSuite with Matchers {
| byte get(byte x) {
| if x >= 6 {return 0} else {return 128}
| }
""".stripMargin)(_.readWord(0xc000) should equal(4))
""".stripMargin)(_.readByte(0xc000) should equal(4))
}
test("Does it even work") {
@ -98,7 +98,7 @@ class ComparisonSuite extends FunSuite with Matchers {
| if 2222 == 3333 { output -= 1 }
| }
""".stripMargin
EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+')))
EmuBenchmarkRun(src)(_.readByte(0xc000) should equal(src.count(_ == '+')))
}
test("Word comparison == and !=") {
@ -118,9 +118,10 @@ class ComparisonSuite extends FunSuite with Matchers {
| if a != c { output += 1 }
| if a != 5 { output += 1 }
| if a != 260 { output += 1 }
| if a != 0 { output += 1 }
| }
""".stripMargin
EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+')))
EmuBenchmarkRun(src)(_.readByte(0xc000) should equal(src.count(_ == '+')))
}
test("Word comparison <=") {
@ -141,7 +142,7 @@ class ComparisonSuite extends FunSuite with Matchers {
| if a <= c { output += 1 }
| }
""".stripMargin
EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+')))
EmuBenchmarkRun(src)(_.readByte(0xc000) should equal(src.count(_ == '+')))
}
test("Word comparison <") {
val src =
@ -160,7 +161,7 @@ class ComparisonSuite extends FunSuite with Matchers {
| if a < 257 { output += 1 }
| }
""".stripMargin
EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+')))
EmuBenchmarkRun(src)(_.readByte(0xc000) should equal(src.count(_ == '+')))
}
@ -181,7 +182,7 @@ class ComparisonSuite extends FunSuite with Matchers {
| if c > 0 { output += 1 }
| }
""".stripMargin
EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+')))
EmuBenchmarkRun(src)(_.readByte(0xc000) should equal(src.count(_ == '+')))
}
test("Word comparison >=") {
@ -204,7 +205,7 @@ class ComparisonSuite extends FunSuite with Matchers {
| if a >= 0 { output += 1 }
| }
""".stripMargin
EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+')))
EmuBenchmarkRun(src)(_.readByte(0xc000) should equal(src.count(_ == '+')))
}
test("Signed comparison >=") {
@ -227,7 +228,7 @@ class ComparisonSuite extends FunSuite with Matchers {
| if a >= 0 { output += 1 }
| }
""".stripMargin
EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+')))
EmuBenchmarkRun(src)(_.readByte(0xc000) should equal(src.count(_ == '+')))
}
test("Signed comparison < and <=") {
@ -259,7 +260,7 @@ class ComparisonSuite extends FunSuite with Matchers {
| if c <= -1 { output -= 7 }
| }
""".stripMargin
EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+')))
EmuBenchmarkRun(src)(_.readByte(0xc000) should equal(src.count(_ == '+')))
}
test("Multiple params for equality") {
@ -275,7 +276,7 @@ class ComparisonSuite extends FunSuite with Matchers {
| output += 78
| }
| }
""".stripMargin)(_.readWord(0xc000) should equal(6))
""".stripMargin)(_.readByte(0xc000) should equal(6))
}
test("Multiple params for inequality") {
@ -291,7 +292,7 @@ class ComparisonSuite extends FunSuite with Matchers {
| output += 78
| }
| }
""".stripMargin)(_.readWord(0xc000) should equal(6))
""".stripMargin)(_.readByte(0xc000) should equal(6))
}
test("Warnings") {
@ -310,6 +311,33 @@ class ComparisonSuite extends FunSuite with Matchers {
| byte three() {
| return 3
| }
""".stripMargin)(_.readWord(0xc000) should equal(6))
""".stripMargin)(_.readByte(0xc000) should equal(6))
}
test("Long comparisons") {
val src =
"""
| byte output @$c000
| void main () {
| long a
| long b
| long c
| output = 0
| a = 1234567
| b = 2345678
| c = 1234599
| if a == a { output += 1 }
| if c >= a { output += 1 }
| if c != a { output += 1 }
| if a <= c { output += 1 }
| if a < c { output += 1 }
| if b >= a { output += 1 }
| if b >= 0 { output += 1 }
| if b >= 44564 { output += 1 }
| if a >= 335444 { output += 1 }
| if c > 335444 { output += 1 }
| }
""".stripMargin
EmuBenchmarkRun(src)(_.readByte(0xc000) should equal(src.count(_ == '+')))
}
}