1
0
mirror of https://github.com/KarolS/millfork.git synced 2025-01-10 20:29:35 +00:00

Decimal and binary addition in the same expression should work correctly

This commit is contained in:
Karol Stasiak 2017-12-16 17:55:08 +01:00
parent 810ac4f00e
commit 47e6b41384
2 changed files with 49 additions and 15 deletions

View File

@ -49,6 +49,7 @@ object ComparisonType extends Enumeration {
/**
* @author Karol Stasiak
*/
//noinspection RedundantDefaultArgument
object BuiltIns {
object IndexChoice extends Enumeration {
@ -67,7 +68,7 @@ object BuiltIns {
def ldTo(op: Opcode.Value, l: List[AssemblyLine]): List[AssemblyLine] = l.map(x => if (x.opcode == LDA || x.opcode == LDX || x.opcode == LDY) x.copy(opcode = op) else x)
def simpleOperation(opcode: Opcode.Value, ctx: CompilationContext, source: Expression, indexChoice: IndexChoice.Value, preserveA: Boolean, commutative: Boolean): List[AssemblyLine] = {
def simpleOperation(opcode: Opcode.Value, ctx: CompilationContext, source: Expression, indexChoice: IndexChoice.Value, preserveA: Boolean, commutative: Boolean, decimal: Boolean = false): List[AssemblyLine] = {
val env = ctx.env
val parts: (List[AssemblyLine], List[AssemblyLine]) = env.eval(source).fold {
val b = env.get[Type]("byte")
@ -98,13 +99,13 @@ object BuiltIns {
}
calculateIndex -> List(AssemblyLine.absoluteY(opcode, baseAddress))
}
case f: FunctionCallExpression if commutative =>
case _: FunctionCallExpression | _:SumExpression if commutative =>
// TODO: is it ok?
return List(AssemblyLine.implied(PHA)) ++ MlCompiler.compile(ctx.addStack(1), f, Some(b -> RegisterVariable(Register.A, b)), NoBranching) ++ List(
return List(AssemblyLine.implied(PHA)) ++ MlCompiler.compile(ctx.addStack(1), source, Some(b -> RegisterVariable(Register.A, b)), NoBranching) ++ wrapInSedCldIfNeeded(decimal, List(
AssemblyLine.implied(TSX),
AssemblyLine.absoluteX(opcode, 0x101),
AssemblyLine.implied(INX),
AssemblyLine.implied(TXS))
AssemblyLine.implied(TXS)))
case _ =>
ErrorReporting.error("Right-hand-side expression is too complex", source.position)
return Nil
@ -117,7 +118,7 @@ object BuiltIns {
Nil -> List(AssemblyLine.immediate(opcode, const))
}
val preparations = parts._1
val finalRead = parts._2
val finalRead = wrapInSedCldIfNeeded(decimal, parts._2)
if (preserveA && AssemblyLine.treatment(preparations, State.A) != Treatment.Unchanged) {
AssemblyLine.implied(PHA) :: (preparations ++ (AssemblyLine.implied(PLA) :: finalRead))
} else {
@ -163,6 +164,7 @@ object BuiltIns {
val h = normalizedParams.head
val firstParamCompiled = MlCompiler.compile(ctx, h._2, Some(b -> RegisterVariable(Register.A, b)), NoBranching)
val firstParamSignCompiled = if (h._1) {
// TODO: check if decimal subtraction works correctly here
List(AssemblyLine.immediate(EOR, 0xff), AssemblyLine.implied(SEC), AssemblyLine.immediate(ADC, 0))
} else {
Nil
@ -170,13 +172,12 @@ object BuiltIns {
val remainingParamsCompiled = normalizedParams.tail.flatMap { p =>
if (p._1) {
insertBeforeLast(AssemblyLine.implied(SEC), simpleOperation(SBC, ctx, p._2, IndexChoice.PreferY, preserveA = true, commutative = false))
insertBeforeLast(AssemblyLine.implied(SEC), simpleOperation(SBC, ctx, p._2, IndexChoice.PreferY, preserveA = true, commutative = false, decimal = decimal))
} else {
insertBeforeLast(AssemblyLine.implied(CLC), simpleOperation(ADC, ctx, p._2, IndexChoice.PreferY, preserveA = true, commutative = true))
insertBeforeLast(AssemblyLine.implied(CLC), simpleOperation(ADC, ctx, p._2, IndexChoice.PreferY, preserveA = true, commutative = true, decimal = decimal))
}
}
wrapInSedCldIfNeeded(decimal, firstParamCompiled ++ firstParamSignCompiled ++ remainingParamsCompiled)
firstParamCompiled ++ firstParamSignCompiled ++ remainingParamsCompiled
}
def compileBitOps(opcode: Opcode.Value, ctx: CompilationContext, params: List[Expression]): List[AssemblyLine] = {
@ -521,7 +522,7 @@ object BuiltIns {
def compileByteMultiplication(ctx: CompilationContext, v: Expression, c: Int): List[AssemblyLine] = {
val result = ListBuffer[AssemblyLine]()
// TODO: optimise
val addingCode = simpleOperation(ADC, ctx, v, IndexChoice.PreferY, preserveA = false, commutative = false)
val addingCode = simpleOperation(ADC, ctx, v, IndexChoice.PreferY, preserveA = false, commutative = false, decimal = false)
val adc = addingCode.last
val indexing = addingCode.init
result ++= indexing
@ -577,12 +578,12 @@ object BuiltIns {
case _ =>
val loadLhs = MlCompiler.compile(ctx, v, Some(b -> RegisterVariable(Register.A, b)), NoBranching)
val modifyLhs = if (subtract) {
insertBeforeLast(AssemblyLine.implied(SEC), simpleOperation(SBC, ctx, addend, IndexChoice.PreferY, preserveA = true, commutative = false))
insertBeforeLast(AssemblyLine.implied(SEC), simpleOperation(SBC, ctx, addend, IndexChoice.PreferY, preserveA = true, commutative = false, decimal = decimal))
} else {
insertBeforeLast(AssemblyLine.implied(CLC), simpleOperation(ADC, ctx, addend, IndexChoice.PreferY, preserveA = true, commutative = true))
insertBeforeLast(AssemblyLine.implied(CLC), simpleOperation(ADC, ctx, addend, IndexChoice.PreferY, preserveA = true, commutative = true, decimal = decimal))
}
val storeLhs = MlCompiler.compileByteStorage(ctx, Register.A, v)
wrapInSedCldIfNeeded(decimal, loadLhs ++ modifyLhs ++ storeLhs)
loadLhs ++ modifyLhs ++ storeLhs
}
}
@ -718,14 +719,14 @@ object BuiltIns {
buffer += AssemblyLine.implied(TAX)
}
}
buffer ++= staTo(ADC, targetBytes(i))
buffer ++= wrapInSedCldIfNeeded(decimal, staTo(ADC, targetBytes(i)))
buffer ++= targetBytes(i)
}
}
for (i <- 0 until calculateRhs.count(a => a.opcode == PHA) - calculateRhs.count(a => a.opcode == PLA)) {
buffer += AssemblyLine.implied(PLA)
}
wrapInSedCldIfNeeded(decimal, buffer.toList)
buffer.toList
}
def compileInPlaceByteBitOp(ctx: CompilationContext, v: LhsExpression, param: Expression, operation: Opcode.Value): List[AssemblyLine] = {

View File

@ -65,4 +65,37 @@ class ByteDecimalMathSuite extends FunSuite with Matchers {
| byte one() { return 1 }
""".stripMargin)(_.readByte(0xc001) should equal(0x40))
}
test("Flag switching test") {
EmuBenchmarkRun(
"""
| byte output @$c000
| void main () {
| output = addDecimal(9, 9) + addDecimal(9, 9)
| }
| byte addDecimal(byte a, byte b) { return a +' b }
""".stripMargin)(_.readByte(0xc000) should equal(0x30))
}
test("Flag switching test 2") {
EmuBenchmarkRun(
"""
| byte output @$c000
| void main () {
| output = addDecimalTwice(9, 9)
| }
| byte addDecimalTwice(byte a, byte b) { return (a +' b) + (a +' b) }
""".stripMargin)(_.readByte(0xc000) should equal(0x30))
}
test("Flag switching test 3") {
EmuBenchmarkRun(
"""
| byte output @$c000
| void main () {
| output = addDecimalTwice($c, $c)
| }
| byte addDecimalTwice(byte a, byte b) { return (a + b) +' (a + b) }
""".stripMargin)(_.readByte(0xc000) should equal(0x36))
}
}