diff --git a/src/main/scala/millfork/compiler/BuiltIns.scala b/src/main/scala/millfork/compiler/BuiltIns.scala index 452962d6..01830b26 100644 --- a/src/main/scala/millfork/compiler/BuiltIns.scala +++ b/src/main/scala/millfork/compiler/BuiltIns.scala @@ -49,6 +49,7 @@ object ComparisonType extends Enumeration { /** * @author Karol Stasiak */ +//noinspection RedundantDefaultArgument object BuiltIns { object IndexChoice extends Enumeration { @@ -67,7 +68,7 @@ object BuiltIns { def ldTo(op: Opcode.Value, l: List[AssemblyLine]): List[AssemblyLine] = l.map(x => if (x.opcode == LDA || x.opcode == LDX || x.opcode == LDY) x.copy(opcode = op) else x) - def simpleOperation(opcode: Opcode.Value, ctx: CompilationContext, source: Expression, indexChoice: IndexChoice.Value, preserveA: Boolean, commutative: Boolean): List[AssemblyLine] = { + def simpleOperation(opcode: Opcode.Value, ctx: CompilationContext, source: Expression, indexChoice: IndexChoice.Value, preserveA: Boolean, commutative: Boolean, decimal: Boolean = false): List[AssemblyLine] = { val env = ctx.env val parts: (List[AssemblyLine], List[AssemblyLine]) = env.eval(source).fold { val b = env.get[Type]("byte") @@ -98,13 +99,13 @@ object BuiltIns { } calculateIndex -> List(AssemblyLine.absoluteY(opcode, baseAddress)) } - case f: FunctionCallExpression if commutative => + case _: FunctionCallExpression | _:SumExpression if commutative => // TODO: is it ok? - return List(AssemblyLine.implied(PHA)) ++ MlCompiler.compile(ctx.addStack(1), f, Some(b -> RegisterVariable(Register.A, b)), NoBranching) ++ List( + return List(AssemblyLine.implied(PHA)) ++ MlCompiler.compile(ctx.addStack(1), source, Some(b -> RegisterVariable(Register.A, b)), NoBranching) ++ wrapInSedCldIfNeeded(decimal, List( AssemblyLine.implied(TSX), AssemblyLine.absoluteX(opcode, 0x101), AssemblyLine.implied(INX), - AssemblyLine.implied(TXS)) + AssemblyLine.implied(TXS))) case _ => ErrorReporting.error("Right-hand-side expression is too complex", source.position) return Nil @@ -117,7 +118,7 @@ object BuiltIns { Nil -> List(AssemblyLine.immediate(opcode, const)) } val preparations = parts._1 - val finalRead = parts._2 + val finalRead = wrapInSedCldIfNeeded(decimal, parts._2) if (preserveA && AssemblyLine.treatment(preparations, State.A) != Treatment.Unchanged) { AssemblyLine.implied(PHA) :: (preparations ++ (AssemblyLine.implied(PLA) :: finalRead)) } else { @@ -163,6 +164,7 @@ object BuiltIns { val h = normalizedParams.head val firstParamCompiled = MlCompiler.compile(ctx, h._2, Some(b -> RegisterVariable(Register.A, b)), NoBranching) val firstParamSignCompiled = if (h._1) { + // TODO: check if decimal subtraction works correctly here List(AssemblyLine.immediate(EOR, 0xff), AssemblyLine.implied(SEC), AssemblyLine.immediate(ADC, 0)) } else { Nil @@ -170,13 +172,12 @@ object BuiltIns { val remainingParamsCompiled = normalizedParams.tail.flatMap { p => if (p._1) { - insertBeforeLast(AssemblyLine.implied(SEC), simpleOperation(SBC, ctx, p._2, IndexChoice.PreferY, preserveA = true, commutative = false)) + insertBeforeLast(AssemblyLine.implied(SEC), simpleOperation(SBC, ctx, p._2, IndexChoice.PreferY, preserveA = true, commutative = false, decimal = decimal)) } else { - insertBeforeLast(AssemblyLine.implied(CLC), simpleOperation(ADC, ctx, p._2, IndexChoice.PreferY, preserveA = true, commutative = true)) + insertBeforeLast(AssemblyLine.implied(CLC), simpleOperation(ADC, ctx, p._2, IndexChoice.PreferY, preserveA = true, commutative = true, decimal = decimal)) } } - - wrapInSedCldIfNeeded(decimal, firstParamCompiled ++ firstParamSignCompiled ++ remainingParamsCompiled) + firstParamCompiled ++ firstParamSignCompiled ++ remainingParamsCompiled } def compileBitOps(opcode: Opcode.Value, ctx: CompilationContext, params: List[Expression]): List[AssemblyLine] = { @@ -521,7 +522,7 @@ object BuiltIns { def compileByteMultiplication(ctx: CompilationContext, v: Expression, c: Int): List[AssemblyLine] = { val result = ListBuffer[AssemblyLine]() // TODO: optimise - val addingCode = simpleOperation(ADC, ctx, v, IndexChoice.PreferY, preserveA = false, commutative = false) + val addingCode = simpleOperation(ADC, ctx, v, IndexChoice.PreferY, preserveA = false, commutative = false, decimal = false) val adc = addingCode.last val indexing = addingCode.init result ++= indexing @@ -577,12 +578,12 @@ object BuiltIns { case _ => val loadLhs = MlCompiler.compile(ctx, v, Some(b -> RegisterVariable(Register.A, b)), NoBranching) val modifyLhs = if (subtract) { - insertBeforeLast(AssemblyLine.implied(SEC), simpleOperation(SBC, ctx, addend, IndexChoice.PreferY, preserveA = true, commutative = false)) + insertBeforeLast(AssemblyLine.implied(SEC), simpleOperation(SBC, ctx, addend, IndexChoice.PreferY, preserveA = true, commutative = false, decimal = decimal)) } else { - insertBeforeLast(AssemblyLine.implied(CLC), simpleOperation(ADC, ctx, addend, IndexChoice.PreferY, preserveA = true, commutative = true)) + insertBeforeLast(AssemblyLine.implied(CLC), simpleOperation(ADC, ctx, addend, IndexChoice.PreferY, preserveA = true, commutative = true, decimal = decimal)) } val storeLhs = MlCompiler.compileByteStorage(ctx, Register.A, v) - wrapInSedCldIfNeeded(decimal, loadLhs ++ modifyLhs ++ storeLhs) + loadLhs ++ modifyLhs ++ storeLhs } } @@ -718,14 +719,14 @@ object BuiltIns { buffer += AssemblyLine.implied(TAX) } } - buffer ++= staTo(ADC, targetBytes(i)) + buffer ++= wrapInSedCldIfNeeded(decimal, staTo(ADC, targetBytes(i))) buffer ++= targetBytes(i) } } for (i <- 0 until calculateRhs.count(a => a.opcode == PHA) - calculateRhs.count(a => a.opcode == PLA)) { buffer += AssemblyLine.implied(PLA) } - wrapInSedCldIfNeeded(decimal, buffer.toList) + buffer.toList } def compileInPlaceByteBitOp(ctx: CompilationContext, v: LhsExpression, param: Expression, operation: Opcode.Value): List[AssemblyLine] = { diff --git a/src/test/scala/millfork/test/ByteDecimalMathSuite.scala b/src/test/scala/millfork/test/ByteDecimalMathSuite.scala index d0d8afae..968015d0 100644 --- a/src/test/scala/millfork/test/ByteDecimalMathSuite.scala +++ b/src/test/scala/millfork/test/ByteDecimalMathSuite.scala @@ -65,4 +65,37 @@ class ByteDecimalMathSuite extends FunSuite with Matchers { | byte one() { return 1 } """.stripMargin)(_.readByte(0xc001) should equal(0x40)) } + + test("Flag switching test") { + EmuBenchmarkRun( + """ + | byte output @$c000 + | void main () { + | output = addDecimal(9, 9) + addDecimal(9, 9) + | } + | byte addDecimal(byte a, byte b) { return a +' b } + """.stripMargin)(_.readByte(0xc000) should equal(0x30)) + } + + test("Flag switching test 2") { + EmuBenchmarkRun( + """ + | byte output @$c000 + | void main () { + | output = addDecimalTwice(9, 9) + | } + | byte addDecimalTwice(byte a, byte b) { return (a +' b) + (a +' b) } + """.stripMargin)(_.readByte(0xc000) should equal(0x30)) + } + + test("Flag switching test 3") { + EmuBenchmarkRun( + """ + | byte output @$c000 + | void main () { + | output = addDecimalTwice($c, $c) + | } + | byte addDecimalTwice(byte a, byte b) { return (a + b) +' (a + b) } + """.stripMargin)(_.readByte(0xc000) should equal(0x36)) + } }