diff --git a/src/main/scala/millfork/compiler/BuiltIns.scala b/src/main/scala/millfork/compiler/BuiltIns.scala index f6ab2ec1..bee63678 100644 --- a/src/main/scala/millfork/compiler/BuiltIns.scala +++ b/src/main/scala/millfork/compiler/BuiltIns.scala @@ -143,19 +143,7 @@ object BuiltIns { val env = ctx.env val b = env.get[Type]("byte") val sortedParams = params.sortBy { case (subtract, expr) => - val constPart = env.eval(expr) match { - case Some(NumericConstant(_, _)) => "Z" - case Some(_) => "Y" - case None => expr match { - case VariableExpression(_) => "V" - case IndexedExpression(_, LiteralExpression(_, _)) => "K" - case IndexedExpression(_, VariableExpression(_)) => "J" - case IndexedExpression(_, _) => "I" - case _ => "A" - } - } - val subtractPart = if (subtract) "X" else "P" - constPart + subtractPart + simplicity(env, expr) + (if (subtract) "X" else "P") } // TODO: merge constants val normalizedParams = sortedParams @@ -179,6 +167,21 @@ object BuiltIns { firstParamCompiled ++ firstParamSignCompiled ++ remainingParamsCompiled } + private def simplicity(env: Environment, expr: Expression): Char = { + val constPart = env.eval(expr) match { + case Some(NumericConstant(_, _)) => 'Z' + case Some(_) => 'Y' + case None => expr match { + case VariableExpression(_) => 'V' + case IndexedExpression(_, LiteralExpression(_, _)) => 'K' + case IndexedExpression(_, VariableExpression(_)) => 'J' + case IndexedExpression(_, _) => 'I' + case _ => 'A' + } + } + constPart + } + def compileBitOps(opcode: Opcode.Value, ctx: CompilationContext, params: List[Expression]): List[AssemblyLine] = { val b = ctx.env.get[Type]("byte") @@ -264,15 +267,7 @@ object BuiltIns { def compileInPlaceWordOrLongShiftOps(ctx: CompilationContext, lhs: LhsExpression, rhs: Expression, aslRatherThanLsr: Boolean): List[AssemblyLine] = { val env = ctx.env val b = env.get[Type]("byte") - val targetBytes = lhs match { - case v: VariableExpression => - val variable = env.get[Variable](v.name) - List.tabulate(variable.typ.size) { i => AssemblyLine.variable(ctx, STA, variable, i) } - case SeparateBytesExpression(h: VariableExpression, l: VariableExpression) => - List( - AssemblyLine.variable(ctx, STA, env.get[Variable](l.name)), - AssemblyLine.variable(ctx, STA, env.get[Variable](h.name))) - } + val targetBytes = getStorageForEachByte(ctx, lhs) val lo = targetBytes.head val hi = targetBytes.last env.eval(rhs) match { @@ -578,14 +573,21 @@ object BuiltIns { simpleOperation(DEC, ctx, v, IndexChoice.RequireX, preserveA = false, commutative = true) } case _ => - val loadLhs = MlCompiler.compile(ctx, v, Some(b -> RegisterVariable(Register.A, b)), NoBranching) - val modifyLhs = if (subtract) { - insertBeforeLast(AssemblyLine.implied(SEC), simpleOperation(SBC, ctx, addend, IndexChoice.PreferY, preserveA = true, commutative = false, decimal = decimal)) + if (!subtract && simplicity(env, v) > simplicity(env, addend)) { + val loadRhs = MlCompiler.compile(ctx, addend, Some(b -> RegisterVariable(Register.A, b)), NoBranching) + val modifyAcc = insertBeforeLast(AssemblyLine.implied(CLC), simpleOperation(ADC, ctx, v, IndexChoice.PreferY, preserveA = true, commutative = true, decimal = decimal)) + val storeLhs = MlCompiler.compileByteStorage(ctx, Register.A, v) + loadRhs ++ modifyAcc ++ storeLhs } else { - insertBeforeLast(AssemblyLine.implied(CLC), simpleOperation(ADC, ctx, addend, IndexChoice.PreferY, preserveA = true, commutative = true, decimal = decimal)) + val loadLhs = MlCompiler.compile(ctx, v, Some(b -> RegisterVariable(Register.A, b)), NoBranching) + val modifyLhs = if (subtract) { + insertBeforeLast(AssemblyLine.implied(SEC), simpleOperation(SBC, ctx, addend, IndexChoice.PreferY, preserveA = true, commutative = false, decimal = decimal)) + } else { + insertBeforeLast(AssemblyLine.implied(CLC), simpleOperation(ADC, ctx, addend, IndexChoice.PreferY, preserveA = true, commutative = true, decimal = decimal)) + } + val storeLhs = MlCompiler.compileByteStorage(ctx, Register.A, v) + loadLhs ++ modifyLhs ++ storeLhs } - val storeLhs = MlCompiler.compileByteStorage(ctx, Register.A, v) - loadLhs ++ modifyLhs ++ storeLhs } } @@ -596,17 +598,7 @@ object BuiltIns { val env = ctx.env val b = env.get[Type]("byte") val w = env.get[Type]("word") - val targetBytes: List[List[AssemblyLine]] = lhs match { - case v: VariableExpression => - val variable = env.get[Variable](v.name) - List.tabulate(variable.typ.size) { i => AssemblyLine.variable(ctx, STA, variable, i) } - case SeparateBytesExpression(h: VariableExpression, l: VariableExpression) => - val lv = env.get[Variable](l.name) - val hv = env.get[Variable](h.name) - List( - AssemblyLine.variable(ctx, STA, lv), - AssemblyLine.variable(ctx, STA, hv)) - } + val targetBytes: List[List[AssemblyLine]] = getStorageForEachByte(ctx, lhs) val lhsIsStack = targetBytes.head.head.opcode == TSX val targetSize = targetBytes.size val addendType = MlCompiler.getExpressionType(ctx, addend) @@ -775,10 +767,17 @@ object BuiltIns { | (AND, Some(NumericConstant(-1, _))) => Nil case _ => - val loadLhs = MlCompiler.compile(ctx, v, Some(b -> RegisterVariable(Register.A, b)), NoBranching) - val modifyLhs = simpleOperation(operation, ctx, param, IndexChoice.PreferY, preserveA = true, commutative = true) - val storeLhs = MlCompiler.compileByteStorage(ctx, Register.A, v) - loadLhs ++ modifyLhs ++ storeLhs + if (simplicity(env, v) > simplicity(env, param)) { + val loadRhs = MlCompiler.compile(ctx, param, Some(b -> RegisterVariable(Register.A, b)), NoBranching) + val modifyAcc = simpleOperation(operation, ctx, v, IndexChoice.PreferY, preserveA = true, commutative = true) + val storeLhs = MlCompiler.compileByteStorage(ctx, Register.A, v) + loadRhs ++ modifyAcc ++ storeLhs + } else { + val loadLhs = MlCompiler.compile(ctx, v, Some(b -> RegisterVariable(Register.A, b)), NoBranching) + val modifyLhs = simpleOperation(operation, ctx, param, IndexChoice.PreferY, preserveA = true, commutative = true) + val storeLhs = MlCompiler.compileByteStorage(ctx, Register.A, v) + loadLhs ++ modifyLhs ++ storeLhs + } } } @@ -787,19 +786,7 @@ object BuiltIns { val env = ctx.env val b = env.get[Type]("byte") val w = env.get[Type]("word") - val targetBytes: List[List[AssemblyLine]] = lhs match { - case v: VariableExpression => - val variable = env.get[Variable](v.name) - List.tabulate(variable.typ.size) { i => AssemblyLine.variable(ctx, STA, variable, i) } - case SeparateBytesExpression(h: VariableExpression, l: VariableExpression) => - val lv = env.get[Variable](l.name) - val hv = env.get[Variable](h.name) - List( - AssemblyLine.variable(ctx, STA, lv), - AssemblyLine.variable(ctx, STA, hv)) - case _ => - ??? - } + val targetBytes: List[List[AssemblyLine]] = getStorageForEachByte(ctx, lhs) val lo = targetBytes.head val targetSize = targetBytes.size val paramType = MlCompiler.getExpressionType(ctx, param) @@ -866,4 +853,20 @@ object BuiltIns { } + private def getStorageForEachByte(ctx: CompilationContext, lhs: LhsExpression): List[List[AssemblyLine]] = { + val env = ctx.env + lhs match { + case v: VariableExpression => + val variable = env.get[Variable](v.name) + List.tabulate(variable.typ.size) { i => AssemblyLine.variable(ctx, STA, variable, i) } + case IndexedExpression(variable, index) => + List(MlCompiler.compileByteStorage(ctx, Register.A, lhs)) + case SeparateBytesExpression(h: LhsExpression, l: LhsExpression) => + List( + getStorageForEachByte(ctx, l).head, + MlCompiler.preserveRegisterIfNeeded(ctx, Register.A, getStorageForEachByte(ctx, h).head)) + case _ => + ??? + } + } } diff --git a/src/main/scala/millfork/compiler/MfCompiler.scala b/src/main/scala/millfork/compiler/MfCompiler.scala index bdf4923c..d38468fe 100644 --- a/src/main/scala/millfork/compiler/MfCompiler.scala +++ b/src/main/scala/millfork/compiler/MfCompiler.scala @@ -797,13 +797,14 @@ object MlCompiler { } { case (exprType, target) => assertCompatible(exprType, target.typ) target match { + // TODO: some more complex ones may not work correctly case RegisterVariable(Register.A | Register.X | Register.Y, _) => compile(ctx, l, exprTypeAndVariable, branches) case RegisterVariable(Register.AX, _) => compile(ctx, l, Some(b -> RegisterVariable(Register.A, b)), branches) ++ - compile(ctx, h, Some(b -> RegisterVariable(Register.X, b)), branches) + preserveRegisterIfNeeded(ctx, Register.A, compile(ctx, h, Some(b -> RegisterVariable(Register.X, b)), branches)) case RegisterVariable(Register.AY, _) => compile(ctx, l, Some(b -> RegisterVariable(Register.A, b)), branches) ++ - compile(ctx, h, Some(b -> RegisterVariable(Register.Y, b)), branches) + preserveRegisterIfNeeded(ctx, Register.A, compile(ctx, h, Some(b -> RegisterVariable(Register.Y, b)), branches)) case RegisterVariable(Register.XA, _) => compile(ctx, l, Some(b -> RegisterVariable(Register.X, b)), branches) ++ compile(ctx, h, Some(b -> RegisterVariable(Register.A, b)), branches) diff --git a/src/test/scala/millfork/test/SeparateBytesSuite.scala b/src/test/scala/millfork/test/SeparateBytesSuite.scala index b6e2dd1a..80436a87 100644 --- a/src/test/scala/millfork/test/SeparateBytesSuite.scala +++ b/src/test/scala/millfork/test/SeparateBytesSuite.scala @@ -134,4 +134,32 @@ class SeparateBytesSuite extends FunSuite with Matchers { | } """.stripMargin)(_.readWord(0xc000) should equal(0x707)) } + + ignore("Complex separate addition") { + EmuBenchmarkRun(""" + | array hi [25] @$c000 + | array lo [25] @$c080 + | void main () { + | byte i + | hi[0] = 0 + | lo[0] = 0 + | hi[1] = 0 + | lo[1] = 1 + | for i,0,until,lo.length-2 { + | barrier() + | hi[addTwo(i)]:lo[i + (one() << 1)] = hi[i + one()]:lo[1+i] + | barrier() + | hi[addTwo(i)]:lo[i + (one() << 1)] += hi[i]:lo[i] + | barrier() + | } + | } + | byte one() { return 1 } + | byte addTwo(byte x) { return x + 2 } + | void barrier() {} + """.stripMargin){m => + val h = m.readWord(0xc000 + 24) + val l = m.readWord(0xc080 + 24) + (h * 0x100 + l) should equal(46368) + } + } }