1
0
mirror of https://github.com/KarolS/millfork.git synced 2024-12-23 08:29:35 +00:00

Comparison fixes and improvements

This commit is contained in:
Karol Stasiak 2018-06-19 00:00:48 +02:00
parent bf1a3a6677
commit 992ea7090e
7 changed files with 176 additions and 51 deletions

View File

@ -4,7 +4,7 @@
* Preliminary Atari 2600, BBC Micro and LUnix support.
* Added array initialization syntax with `for` (not yet finalized).
* Added array initialization syntax with `for`.
* Added multiple new text codecs.
@ -20,6 +20,8 @@
* Fixed broken `downto` loops.
* Fixed broken comparisons between variables of different sizes.
* Fixed several other bugs.
* Other improvements.

View File

@ -369,10 +369,10 @@ object AssemblyLine {
private val opcodesForNopVariableOperation = Set(STA, SAX, STX, STY, STZ)
private val opcodesForZeroedVariableOperation = Set(ADC, EOR, ORA, AND, SBC, CMP, CPX, CPY)
private val opcodesForZeroedOrSignExtendedVariableOperation = Set(LDA, LDX, LDY)
private val opcodesForZeroedOrSignExtendedVariableOperation = Set(LDA, LDX, LDY, LDZ)
def variable(ctx: CompilationContext, opcode: Opcode.Value, variable: Variable, offset: Int = 0): List[AssemblyLine] =
if (offset > variable.typ.size) {
if (offset >= variable.typ.size) {
if (opcodesForNopVariableOperation(opcode)) {
Nil
} else if (opcodesForZeroedVariableOperation(opcode)) {
@ -381,11 +381,16 @@ object AssemblyLine {
} else if (opcodesForZeroedOrSignExtendedVariableOperation(opcode)) {
if (variable.typ.isSigned) {
val label = MosCompiler.nextLabel("sx")
AssemblyLine.variable(ctx, opcode, variable, variable.typ.size - 1) ++ List(
AssemblyLine.variable(ctx, LDA, variable, variable.typ.size - 1) ++ List(
AssemblyLine.immediate(ORA, 0x7f),
AssemblyLine.relative(BMI, label),
AssemblyLine.immediate(LDA, 0),
AssemblyLine.label(label))
AssemblyLine.label(label)) ++ (opcode match {
case LDA => Nil
case LDX | LAX => List(AssemblyLine.implied(TAX))
case LDY => List(AssemblyLine.implied(TAY))
case LDZ => List(AssemblyLine.implied(TAZ))
})
} else {
List(AssemblyLine.immediate(opcode, 0))
}

View File

@ -1791,6 +1791,62 @@ object AlwaysGoodOptimizations {
(Elidable & HasSourceOfNZ(State.X) & HasOpcode(CPX) & HasImmediate(0) & DoesntMatterWhatItDoesWith(State.C)) ~~> (_.init),
(Elidable & HasSourceOfNZ(State.Y) & HasOpcode(CPY) & HasImmediate(0) & DoesntMatterWhatItDoesWith(State.C)) ~~> (_.init),
(Elidable & HasSourceOfNZ(State.IZ) & HasOpcode(CPZ) & HasImmediate(0) & DoesntMatterWhatItDoesWith(State.C)) ~~> (_.init),
(Elidable & HasA(0) & HasOpcode(CMP)) ~ (HasOpcode(BCC) & DoesntMatterWhatItDoesWith(State.C, State.N, State.A)) ~~> {code =>
List(code.head.copy(opcode = LDA), code(1).copy(opcode = BNE))
},
(Elidable & HasA(0) & HasOpcode(CMP)) ~ (HasOpcode(BCS) & DoesntMatterWhatItDoesWith(State.C, State.N, State.A)) ~~> {code =>
List(code.head.copy(opcode = LDA), code(1).copy(opcode = BEQ))
},
(Elidable & HasX(0) & HasOpcode(CPX)) ~ (HasOpcode(BCC) & DoesntMatterWhatItDoesWith(State.C, State.N, State.X)) ~~> {code =>
List(code.head.copy(opcode = LDX), code(1).copy(opcode = BNE))
},
(Elidable & HasX(0) & HasOpcode(CPX)) ~ (HasOpcode(BCS) & DoesntMatterWhatItDoesWith(State.C, State.N, State.X)) ~~> {code =>
List(code.head.copy(opcode = LDX), code(1).copy(opcode = BEQ))
},
(Elidable & HasY(0) & HasOpcode(CPY)) ~ (HasOpcode(BCC) & DoesntMatterWhatItDoesWith(State.C, State.N, State.Y)) ~~> {code =>
List(code.head.copy(opcode = LDY), code(1).copy(opcode = BNE))
},
(Elidable & HasY(0) & HasOpcode(CPY)) ~ (HasOpcode(BCS) & DoesntMatterWhatItDoesWith(State.C, State.N, State.Y)) ~~> {code =>
List(code.head.copy(opcode = LDY), code(1).copy(opcode = BEQ))
},
(Elidable & HasA(0) & HasOpcode(CMP)) ~ (HasOpcode(BCC) & DoesntMatterWhatItDoesWith(State.C, State.N, State.X)) ~~> {code =>
List(code.head.copy(opcode = LDX), code(1).copy(opcode = BNE))
},
(Elidable & HasA(0) & HasOpcode(CMP)) ~ (HasOpcode(BCS) & DoesntMatterWhatItDoesWith(State.C, State.N, State.X)) ~~> {code =>
List(code.head.copy(opcode = LDX), code(1).copy(opcode = BEQ))
},
(Elidable & HasA(0) & HasOpcode(CMP)) ~ (HasOpcode(BCC) & DoesntMatterWhatItDoesWith(State.C, State.N, State.Y)) ~~> {code =>
List(code.head.copy(opcode = LDY), code(1).copy(opcode = BNE))
},
(Elidable & HasA(0) & HasOpcode(CMP)) ~ (HasOpcode(BCS) & DoesntMatterWhatItDoesWith(State.C, State.N, State.Y)) ~~> {code =>
List(code.head.copy(opcode = LDY), code(1).copy(opcode = BEQ))
},
(Elidable & HasX(0) & HasOpcode(CPX)) ~ (HasOpcode(BCC) & DoesntMatterWhatItDoesWith(State.C, State.N, State.A)) ~~> {code =>
List(code.head.copy(opcode = LDA), code(1).copy(opcode = BNE))
},
(Elidable & HasX(0) & HasOpcode(CPX)) ~ (HasOpcode(BCS) & DoesntMatterWhatItDoesWith(State.C, State.N, State.A)) ~~> {code =>
List(code.head.copy(opcode = LDA), code(1).copy(opcode = BEQ))
},
(Elidable & HasY(0) & HasOpcode(CPY)) ~ (HasOpcode(BCC) & DoesntMatterWhatItDoesWith(State.C, State.N, State.A)) ~~> {code =>
List(code.head.copy(opcode = LDA), code(1).copy(opcode = BNE))
},
(Elidable & HasY(0) & HasOpcode(CPY)) ~ (HasOpcode(BCS) & DoesntMatterWhatItDoesWith(State.C, State.N, State.A)) ~~> {code =>
List(code.head.copy(opcode = LDA), code(1).copy(opcode = BEQ))
},
(Elidable & HasZ(0) & HasOpcode(CPZ)) ~ (HasOpcode(BCC) & DoesntMatterWhatItDoesWith(State.C, State.N, State.IZ)) ~~> {code =>
List(code.head.copy(opcode = LDZ), code(1).copy(opcode = BNE))
},
(Elidable & HasZ(0) & HasOpcode(CPZ)) ~ (HasOpcode(BCS) & DoesntMatterWhatItDoesWith(State.C, State.N, State.IZ)) ~~> {code =>
List(code.head.copy(opcode = LDZ), code(1).copy(opcode = BEQ))
},
)
private def remapZ2N(line: AssemblyLine) = line.opcode match {

View File

@ -107,7 +107,25 @@ abstract class AbstractStatementCompiler[T <: AbstractCode] {
val increment = ExpressionStatement(FunctionCallExpression("+=", List(vex, one)))
val decrement = ExpressionStatement(FunctionCallExpression("-=", List(vex, one)))
val names = Set("", "for", f.variable)
(f.direction, ctx.env.eval(f.start), ctx.env.eval(f.end)) match {
val startEvaluated = ctx.env.eval(f.start)
val endEvaluated = ctx.env.eval(f.end)
ctx.env.maybeGet[Variable](f.variable).foreach{ v=>
startEvaluated.foreach(value => if (!value.quickSimplify.fitsInto(v.typ)) {
ErrorReporting.error(s"Variable `${f.variable}` is too small to hold the initial value in the for loop", f.position)
})
endEvaluated.foreach { value =>
val max = f.direction match {
case ForDirection.To | ForDirection.ParallelTo | ForDirection.DownTo => value
case ForDirection.Until | ForDirection.ParallelUntil => value - 1
case _ => Constant.Zero
}
if (!max.quickSimplify.fitsInto(v.typ)) {
ErrorReporting.error(s"Variable `${f.variable}` is too small to hold the final value in the for loop", f.position)
}
}
}
(f.direction, startEvaluated, endEvaluated) match {
case (ForDirection.Until | ForDirection.ParallelUntil, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, _))) if s == e - 1 =>
val end = nextLabel("of")

View File

@ -32,6 +32,8 @@ object BuiltIns {
def staTo(op: Opcode.Value, l: List[AssemblyLine]): List[AssemblyLine] = l.map(x => if (x.opcode == STA) x.copy(opcode = op) else x)
def cmpTo(op: Opcode.Value, l: List[AssemblyLine]): List[AssemblyLine] = l.map(x => if (x.opcode == CMP) x.copy(opcode = op) else x)
def ldTo(op: Opcode.Value, l: List[AssemblyLine]): List[AssemblyLine] = l.map(x => if (x.opcode == LDA || x.opcode == LDX || x.opcode == LDY) x.copy(opcode = op) else x)
def simpleOperation(opcode: Opcode.Value, ctx: CompilationContext, source: Expression, indexChoice: IndexChoice.Value, preserveA: Boolean, commutative: Boolean, decimal: Boolean = false): List[AssemblyLine] = {
@ -463,17 +465,17 @@ object BuiltIns {
return compileWordComparison(ctx, ComparisonType.flip(compType), rhs, lhs, branches)
case (v: VariableExpression, None, _, Some(rc)) =>
val lva = env.get[VariableInMemory](v.name)
(AssemblyLine.variable(ctx, STA, lva, 1),
AssemblyLine.variable(ctx, STA, lva, 0),
List(AssemblyLine.immediate(STA, rc.hiByte.quickSimplify)),
List(AssemblyLine.immediate(STA, rc.loByte.quickSimplify)))
(AssemblyLine.variable(ctx, CMP, lva, 1),
AssemblyLine.variable(ctx, CMP, lva, 0),
List(AssemblyLine.immediate(CMP, rc.hiByte.quickSimplify)),
List(AssemblyLine.immediate(CMP, rc.loByte.quickSimplify)))
case (lv: VariableExpression, None, rv: VariableExpression, None) =>
val lva = env.get[VariableInMemory](lv.name)
val rva = env.get[VariableInMemory](rv.name)
(AssemblyLine.variable(ctx, STA, lva, 1),
AssemblyLine.variable(ctx, STA, lva, 0),
AssemblyLine.variable(ctx, STA, rva, 1),
AssemblyLine.variable(ctx, STA, rva, 0))
(AssemblyLine.variable(ctx, CMP, lva, 1),
AssemblyLine.variable(ctx, CMP, lva, 0),
AssemblyLine.variable(ctx, CMP, rva, 1),
AssemblyLine.variable(ctx, CMP, rva, 0))
case _ =>
// TODO comparing expressions
ErrorReporting.error("Too complex expressions in comparison", lhs.position)
@ -484,9 +486,9 @@ object BuiltIns {
val compactEqualityComparison = if (ctx.options.flag(CompilationFlag.OptimizeForSpeed)) {
None
} else if (lType.size == 1 && !lType.isSigned) {
Some(staTo(LDA, ll) ++ staTo(EOR, rl) ++ staTo(ORA, rh))
Some(cmpTo(LDA, ll) ++ cmpTo(EOR, rl) ++ cmpTo(ORA, rh))
} else if (rType.size == 1 && !rType.isSigned) {
Some(staTo(LDA, rl) ++ staTo(EOR, ll) ++ staTo(ORA, lh))
Some(cmpTo(LDA, rl) ++ cmpTo(EOR, ll) ++ cmpTo(ORA, lh))
} else {
None
}
@ -496,11 +498,11 @@ object BuiltIns {
case Some(code) => code :+ AssemblyLine.relative(BEQ, Label(x))
case None =>
val innerLabel = MosCompiler.nextLabel("cp")
staTo(LDA, ll) ++
staTo(CMP, rl) ++
cmpTo(LDA, ll) ++
cmpTo(CMP, rl) ++
List(AssemblyLine.relative(BNE, innerLabel)) ++
staTo(LDA, lh) ++
staTo(CMP, rh) ++
cmpTo(LDA, lh) ++
cmpTo(CMP, rh) ++
List(
AssemblyLine.relative(BEQ, Label(x)),
AssemblyLine.label(innerLabel))
@ -510,57 +512,57 @@ object BuiltIns {
compactEqualityComparison match {
case Some(code) => code :+ AssemblyLine.relative(BNE, Label(x))
case None =>
staTo(LDA, ll) ++
staTo(CMP, rl) ++
cmpTo(LDA, ll) ++
cmpTo(CMP, rl) ++
List(AssemblyLine.relative(BNE, Label(x))) ++
staTo(LDA, lh) ++
staTo(CMP, rh) ++
cmpTo(LDA, lh) ++
cmpTo(CMP, rh) ++
List(AssemblyLine.relative(BNE, Label(x)))
}
case ComparisonType.LessUnsigned =>
val innerLabel = MosCompiler.nextLabel("cp")
staTo(LDA, lh) ++
staTo(CMP, rh) ++
cmpTo(LDA, lh) ++
cmpTo(CMP, rh) ++
List(
AssemblyLine.relative(BCC, Label(x)),
AssemblyLine.relative(BNE, innerLabel)) ++
staTo(LDA, ll) ++
staTo(CMP, rl) ++
cmpTo(LDA, ll) ++
cmpTo(CMP, rl) ++
List(
AssemblyLine.relative(BCC, Label(x)),
AssemblyLine.label(innerLabel))
case ComparisonType.LessOrEqualUnsigned =>
val innerLabel = MosCompiler.nextLabel("cp")
staTo(LDA, rh) ++
staTo(CMP, lh) ++
cmpTo(LDA, rh) ++
cmpTo(CMP, lh) ++
List(AssemblyLine.relative(BCC, innerLabel),
AssemblyLine.relative(BNE, x)) ++
staTo(LDA, rl) ++
staTo(CMP, ll) ++
cmpTo(LDA, rl) ++
cmpTo(CMP, ll) ++
List(AssemblyLine.relative(BCS, x),
AssemblyLine.label(innerLabel))
case ComparisonType.GreaterUnsigned =>
val innerLabel = MosCompiler.nextLabel("cp")
staTo(LDA, rh) ++
staTo(CMP, lh) ++
cmpTo(LDA, rh) ++
cmpTo(CMP, lh) ++
List(AssemblyLine.relative(BCC, Label(x)),
AssemblyLine.relative(BNE, innerLabel)) ++
staTo(LDA, rl) ++
staTo(CMP, ll) ++
cmpTo(LDA, rl) ++
cmpTo(CMP, ll) ++
List(AssemblyLine.relative(BCC, Label(x)),
AssemblyLine.label(innerLabel))
case ComparisonType.GreaterOrEqualUnsigned =>
val innerLabel = MosCompiler.nextLabel("cp")
staTo(LDA, lh) ++
staTo(CMP, rh) ++
cmpTo(LDA, lh) ++
cmpTo(CMP, rh) ++
List(AssemblyLine.relative(BCC, innerLabel),
AssemblyLine.relative(BNE, x)) ++
staTo(LDA, ll) ++
staTo(CMP, rl) ++
cmpTo(LDA, ll) ++
cmpTo(CMP, rl) ++
List(AssemblyLine.relative(BCS, x),
AssemblyLine.label(innerLabel))
@ -619,23 +621,23 @@ object BuiltIns {
case ComparisonType.Equal =>
val innerLabel = MosCompiler.nextLabel("cp")
val bytewise = l.zip(r).map{
case (staL, staR) => staTo(LDA, staL) ++ staTo(CMP, staR)
case (cmpL, cmpR) => cmpTo(LDA, cmpL) ++ cmpTo(CMP, cmpR)
}
bytewise.init.flatMap(b => b :+ AssemblyLine.relative(BNE, innerLabel)) ++ bytewise.last ++List(
AssemblyLine.relative(BEQ, Label(label)),
AssemblyLine.label(innerLabel))
case ComparisonType.NotEqual =>
l.zip(r).flatMap {
case (staL, staR) => staTo(LDA, staL) ++ staTo(CMP, staR) :+ AssemblyLine.relative(BNE, label)
case (cmpL, cmpR) => cmpTo(LDA, cmpL) ++ cmpTo(CMP, cmpR) :+ AssemblyLine.relative(BNE, label)
}
case ComparisonType.LessUnsigned =>
val calculateCarry = AssemblyLine.implied(SEC) :: l.zip(r).flatMap{
case (staL, staR) => staTo(LDA, staL) ++ staTo(SBC, staR)
case (cmpL, cmpR) => cmpTo(LDA, cmpL) ++ cmpTo(SBC, cmpR)
}
calculateCarry ++ List(AssemblyLine.relative(BCC, Label(label)))
case ComparisonType.GreaterOrEqualUnsigned =>
val calculateCarry = AssemblyLine.implied(SEC) :: l.zip(r).flatMap{
case (staL, staR) => staTo(LDA, staL) ++ staTo(SBC, staR)
case (cmpL, cmpR) => cmpTo(LDA, cmpL) ++ cmpTo(SBC, cmpR)
}
calculateCarry ++ List(AssemblyLine.relative(BCS, Label(label)))
case ComparisonType.GreaterUnsigned | ComparisonType.LessOrEqualUnsigned =>
@ -1180,27 +1182,27 @@ object BuiltIns {
val env = ctx.env
env.eval(expr) match {
case Some(c) =>
List.tabulate(size) { i => List(AssemblyLine.immediate(STA, c.subbyte(i))) }
List.tabulate(size) { i => List(AssemblyLine.immediate(CMP, c.subbyte(i))) }
case None =>
expr match {
case v: VariableExpression =>
val variable = env.get[Variable](v.name)
List.tabulate(size) { i =>
if (i < variable.typ.size) {
AssemblyLine.variable(ctx, STA, variable, i)
AssemblyLine.variable(ctx, CMP, variable, i)
} else if (variable.typ.isSigned) {
val label = MosCompiler.nextLabel("sx")
AssemblyLine.variable(ctx, STA, variable, variable.typ.size - 1) ++ List(
AssemblyLine.variable(ctx, LDA, variable, variable.typ.size - 1) ++ List(
AssemblyLine.immediate(ORA, 0x7F),
AssemblyLine.relative(BMI, label),
AssemblyLine.immediate(STA, 0),
AssemblyLine.immediate(CMP, 0),
AssemblyLine.label(label))
} else List(AssemblyLine.immediate(STA, 0))
} else List(AssemblyLine.immediate(CMP, 0))
}
case expr@IndexedExpression(variable, index) =>
List.tabulate(size) { i =>
if (i == 0) MosExpressionCompiler.compileByteStorage(ctx, MosRegister.A, expr)
else List(AssemblyLine.immediate(STA, 0))
else List(AssemblyLine.immediate(CMP, 0))
}
case SeparateBytesExpression(h: LhsExpression, l: LhsExpression) =>
if (simplicity(ctx.env, h) < 'J' || simplicity(ctx.env, l) < 'J') {
@ -1210,7 +1212,7 @@ object BuiltIns {
List.tabulate(size) { i =>
if (i == 0) getStorageForEachByte(ctx, l).head
else if (i == 1) MosExpressionCompiler.preserveRegisterIfNeeded(ctx, MosRegister.A, getStorageForEachByte(ctx, h).head)
else List(AssemblyLine.immediate(STA, 0))
else List(AssemblyLine.immediate(CMP, 0))
}
case _ =>
???

View File

@ -73,6 +73,8 @@ sealed trait Constant {
def quickSimplify: Constant = this
def isRelatedTo(v: Thing): Boolean
def fitsInto(typ: Type): Boolean = true // TODO
}
case class AssertByte(c: Constant) extends Constant {
@ -83,6 +85,8 @@ case class AssertByte(c: Constant) extends Constant {
override def isRelatedTo(v: Thing): Boolean = c.isRelatedTo(v)
override def quickSimplify: Constant = AssertByte(c.quickSimplify)
override def fitsInto(typ: Type): Boolean = true
}
case class UnexpandedConstant(name: String, requiredSize: Int) extends Constant {
@ -116,6 +120,26 @@ case class NumericConstant(value: Long, requiredSize: Int) extends Constant {
override def toString: String = if (value > 9) value.formatted("$%X") else value.toString
override def isRelatedTo(v: Thing): Boolean = false
override def fitsInto(typ: Type): Boolean = {
if (typ.isSigned) {
typ.size match {
case 1 => value == value.toByte
case 2 => value == value.toShort
case 3 => value == ((value.toInt << 8) >> 8)
case 4 => value == value.toInt
case _ => true
}
} else {
typ.size match {
case 1 => value == (value & 0xff)
case 2 => value == (value & 0xffff)
case 3 => value == (value & 0xffffff)
case 4 => value == (value & 0xffffffffL)
case _ => true
}
}
}
}
case class MemoryAddressConstant(var thing: ThingInMemory) extends Constant {

View File

@ -340,4 +340,22 @@ class ComparisonSuite extends FunSuite with Matchers {
""".stripMargin
EmuBenchmarkRun(src)(_.readByte(0xc000) should equal(src.count(_ == '+')))
}
test("Mixed type comparison") {
val src =
"""
| byte output @$c000
| byte x @$c002
| byte y @$c003
| void main () {
| word z
| output = 0
| z = $100
| x = 4
| y = 1
| if x < z { output += 1 }
| }
""".stripMargin
EmuBenchmarkRun(src)(_.readByte(0xc000) should equal(1))
}
}