diff --git a/docs/lang/operators.md b/docs/lang/operators.md index c651e402..6e39913e 100644 --- a/docs/lang/operators.md +++ b/docs/lang/operators.md @@ -79,7 +79,9 @@ If and only if both `h` and `l` are assignable expressions, then `h:l` is also a `constant byte * byte` `constant word * constant word` `constant long * constant long` -`byte * byte` (zpreg) +`byte * byte` (zpreg) +`word * byte` (zpreg) +`byte * word` (zpreg) There are no division, remainder or modulo operators. diff --git a/include/i80_math.mfk b/include/i80_math.mfk index 07fe0a45..803c9ab1 100644 --- a/include/i80_math.mfk +++ b/include/i80_math.mfk @@ -7,7 +7,7 @@ #if CPUFEATURE_Z80 || CPUFEATURE_GAMEBOY -inline asm byte __mul_u8u8u8() { +noinline asm byte __mul_u8u8u8() { ? LD E,A ? LD A, 0 ? JR __mul_u8u8u8_start @@ -24,7 +24,7 @@ inline asm byte __mul_u8u8u8() { #else -inline asm byte __mul_u8u8u8() { +noinline asm byte __mul_u8u8u8() { ? LD E,A ? LD C, 0 ? JP __mul_u8u8u8_start diff --git a/include/zp_reg.mfk b/include/zp_reg.mfk index bffeeef8..59130f0c 100644 --- a/include/zp_reg.mfk +++ b/include/zp_reg.mfk @@ -3,7 +3,7 @@ #warn zp_reg module should be used only on 6502-compatible targets #endif -inline asm byte __mul_u8u8u8() { +noinline asm byte __mul_u8u8u8() { ? LDA #0 ? JMP __mul_u8u8u8_start __mul_u8u8u8_add: @@ -20,7 +20,7 @@ __mul_u8u8u8_start: #if ZPREG_SIZE >= 3 -asm byte __mul_u16u8u16() { +noinline asm word __mul_u16u8u16() { ? LDA #0 ? TAX ? JMP __mul_u16u8u16_start diff --git a/src/main/scala/millfork/assembly/mos/opt/ZeropageRegisterOptimizations.scala b/src/main/scala/millfork/assembly/mos/opt/ZeropageRegisterOptimizations.scala index 05b5457b..5285fa54 100644 --- a/src/main/scala/millfork/assembly/mos/opt/ZeropageRegisterOptimizations.scala +++ b/src/main/scala/millfork/assembly/mos/opt/ZeropageRegisterOptimizations.scala @@ -3,8 +3,9 @@ package millfork.assembly.mos.opt import millfork.assembly.mos.Opcode._ import millfork.assembly.mos.AddrMode._ import millfork.assembly.AssemblyOptimization -import millfork.assembly.mos.{AssemblyLine, Opcode, State} +import millfork.assembly.mos.{AssemblyLine, AssemblyLine0, Opcode, State} import millfork.DecimalUtils.asDecimal +import millfork.error.FatalErrorReporting /** * @author Karol Stasiak */ @@ -36,6 +37,16 @@ object ZeropageRegisterOptimizations { }, ) + private def compileMultiply[T](multiplicand: Int, add1:List[T], asl: List[T]): List[T] = { + if (multiplicand == 0) FatalErrorReporting.reportFlyingPig("Trying to optimize multiplication by 0 in a wrong way!") + def impl(m: Int): List[List[T]] = { + if (m == 1) Nil + else if (m % 2 == 0) asl :: impl(m / 2) + else add1 :: asl :: impl(m / 2) + } + impl(multiplicand).reverse.flatten + } + val ConstantMultiplication = new RuleBasedAssemblyOptimization("Constant multiplication", needsFlowInfo = FlowInfoRequirement.ForwardFlow, (HasOpcode(STA) & RefersTo("__reg", 0) & MatchAddrMode(0) & MatchParameter(1) & MatchA(4)) ~ @@ -56,16 +67,12 @@ object ZeropageRegisterOptimizations { (Elidable & HasOpcode(STA) & RefersTo("__reg", 0) & MatchAddrMode(0) & MatchParameter(1)) ~ (Linear & Not(RefersToOrUses("__reg", 1)) & DoesntChangeMemoryAt(0, 1)).* ~ (HasOpcode(STA) & RefersTo("__reg", 1) & MatchA(4)) ~ - Where(ctx => { - val constant = ctx.get[Int](4) - (constant & (constant - 1)) == 0 - }) ~ (Elidable & HasOpcode(JSR) & RefersTo("__mul_u8u8u8", 0)) ~~> { (code, ctx) => val constant = ctx.get[Int](4) if (constant == 0) { code.init :+ AssemblyLine.immediate(LDA, 0) } else { - code.init ++ (code.head.copy(opcode = LDA) :: List.fill(Integer.numberOfTrailingZeros(constant))(AssemblyLine.implied(ASL))) + code.init ++ (code.head.copy(opcode = LDA) :: compileMultiply(constant, List(AssemblyLine.implied(CLC), code.head.copy(opcode = ADC)), List(AssemblyLine.implied(ASL)))) } }, @@ -81,7 +88,25 @@ object ZeropageRegisterOptimizations { if (constant == 0) { code.init :+ AssemblyLine.immediate(LDA, 0) } else { - code.init ++ List.fill(Integer.numberOfTrailingZeros(constant))(AssemblyLine.implied(ASL)) + code.init ++ compileMultiply(constant, List(AssemblyLine.implied(CLC), code.init.last.copy(opcode = ADC)), List(AssemblyLine.implied(ASL))) + } + }, + + (Elidable & HasOpcode(STA) & RefersTo("__reg", 2) & MatchAddrMode(0) & MatchParameter(1) & MatchA(4)) ~ + Where(ctx => { + val constant = ctx.get[Int](4) + (constant & (constant - 1)) == 0 + }) ~ + (Linear & Not(RefersToOrUses("__reg", 2)) & DoesntChangeMemoryAt(0, 1)).* ~ + (Elidable & HasOpcode(JSR) & RefersTo("__mul_u16u8u16", 0)) ~~> { (code, ctx) => + val constant = ctx.get[Int](4) + if (constant == 0) { + code.init :+ AssemblyLine.immediate(LDA, 0) + } else { + val loAsl = code.head.copy(opcode = ASL, parameter = (code.head.parameter - 2).quickSimplify) + val hiRol = code.head.copy(opcode = ROL, parameter = (code.head.parameter - 1).quickSimplify) + val shift = List(loAsl, hiRol) + code.init ++ List.fill(Integer.numberOfTrailingZeros(constant))(shift).flatten ++ List(loAsl.copy(opcode = LDA), hiRol.copy(opcode = LDX)) } }, ) diff --git a/src/main/scala/millfork/assembly/z80/opt/AlwaysGoodI80Optimizations.scala b/src/main/scala/millfork/assembly/z80/opt/AlwaysGoodI80Optimizations.scala index 7a4dd225..bd933f3b 100644 --- a/src/main/scala/millfork/assembly/z80/opt/AlwaysGoodI80Optimizations.scala +++ b/src/main/scala/millfork/assembly/z80/opt/AlwaysGoodI80Optimizations.scala @@ -7,6 +7,7 @@ import millfork.env.{CompoundConstant, Constant, MathOperator, NumericConstant} import millfork.node.ZRegister import ZRegister._ import millfork.DecimalUtils._ +import millfork.error.FatalErrorReporting /** * Optimizations valid for Intel8080, Z80, EZ80 and Sharp @@ -1056,6 +1057,17 @@ object AlwaysGoodI80Optimizations { (Elidable & HasOpcode(DAA) & DoesntMatterWhatItDoesWithFlags & DoesntMatterWhatItDoesWith(ZRegister.A)) ~~> (_ => Nil), ) + private def compileMultiply[T](multiplicand: Int, add1:List[T], asl: List[T]): List[T] = { + if (multiplicand == 0) FatalErrorReporting.reportFlyingPig("Trying to optimize multiplication by 0 in a wrong way!") + def impl(m: Int): List[List[T]] = { + if (m == 1) Nil + else if (m % 2 == 0) asl :: impl(m / 2) + else add1 :: asl :: impl(m / 2) + } + + impl(multiplicand).reverse.flatten + } + val ConstantMultiplication = new RuleBasedAssemblyOptimization("Constant multiplication", needsFlowInfo = FlowInfoRequirement.BothFlows, (Elidable & HasOpcode(CALL) @@ -1066,98 +1078,68 @@ object AlwaysGoodI80Optimizations { & DoesntMatterWhatItDoesWithFlags & DoesntMatterWhatItDoesWith(ZRegister.D, ZRegister.E, ZRegister.C)) ~~> { (code, ctx) => val product = ctx.get[Int](4) * ctx.get[Int](5) - List(ZLine.ldImm8(ZRegister.A, product)) + List(ZLine.ldImm8(ZRegister.A, product & 0xff)) + }, + + (Elidable & HasOpcode(CALL) + & IsUnconditional + & RefersTo("__mul_u16u8u16", 0) + & MatchRegister(ZRegister.A, 4) + & MatchRegister(ZRegister.DE, 5) + & DoesntMatterWhatItDoesWithFlags + & DoesntMatterWhatItDoesWith(ZRegister.D, ZRegister.E, ZRegister.C, ZRegister.B, ZRegister.A)) ~~> { (code, ctx) => + val product = ctx.get[Int](4) * ctx.get[Int](5) + List(ZLine.ldImm16(ZRegister.HL, product & 0xffff)) }, (Elidable & HasOpcode(CALL) & IsUnconditional & RefersTo("__mul_u8u8u8", 0) - & (HasRegister(ZRegister.D, 0) | HasRegister(ZRegister.A, 0)) + & MatchRegister(ZRegister.D, 1) & DoesntMatterWhatItDoesWithFlags & DoesntMatterWhatItDoesWith(ZRegister.D, ZRegister.E, ZRegister.C)) ~~> { (code, ctx) => - List(ZLine.ldImm8(ZRegister.A, 0)) + val multiplicand = ctx.get[Int](1) + if (multiplicand == 0) List(ZLine.ldImm8(A, 0)) + else ZLine.ld8(ZRegister.D, ZRegister.A) :: compileMultiply(multiplicand, List(ZLine.register(ADD, ZRegister.D)), List(ZLine.register(ADD, ZRegister.A))) }, (Elidable & HasOpcode(CALL) & IsUnconditional & RefersTo("__mul_u8u8u8", 0) - & HasRegister(ZRegister.D, 1) + & MatchRegister(ZRegister.A, 1) & DoesntMatterWhatItDoesWithFlags & DoesntMatterWhatItDoesWith(ZRegister.D, ZRegister.E, ZRegister.C)) ~~> { (code, ctx) => - Nil - }, - (Elidable & HasOpcode(CALL) - & IsUnconditional - & RefersTo("__mul_u8u8u8", 0) - & HasRegister(ZRegister.D, 2) - & DoesntMatterWhatItDoesWithFlags - & DoesntMatterWhatItDoesWith(ZRegister.D, ZRegister.E, ZRegister.C)) ~~> { (code, ctx) => - List(ZLine.register(ADD, ZRegister.A)) - }, - (Elidable & HasOpcode(CALL) - & IsUnconditional - & RefersTo("__mul_u8u8u8", 0) - & HasRegister(ZRegister.D, 4) - & DoesntMatterWhatItDoesWithFlags - & DoesntMatterWhatItDoesWith(ZRegister.D, ZRegister.E, ZRegister.C)) ~~> { (code, ctx) => - List(ZLine.register(ADD, ZRegister.A), ZLine.register(ADD, ZRegister.A)) - }, - (Elidable & HasOpcode(CALL) - & IsUnconditional - & RefersTo("__mul_u8u8u8", 0) - & HasRegister(ZRegister.D, 8) - & DoesntMatterWhatItDoesWithFlags - & DoesntMatterWhatItDoesWith(ZRegister.D, ZRegister.E, ZRegister.C)) ~~> { (code, ctx) => - List(ZLine.register(ADD, ZRegister.A), ZLine.register(ADD, ZRegister.A), ZLine.register(ADD, ZRegister.A)) - }, - (Elidable & HasOpcode(CALL) - & IsUnconditional - & RefersTo("__mul_u8u8u8", 0) - & HasRegister(ZRegister.D, 16) - & DoesntMatterWhatItDoesWithFlags - & DoesntMatterWhatItDoesWith(ZRegister.D, ZRegister.E, ZRegister.C)) ~~> { (code, ctx) => - List(ZLine.register(ADD, ZRegister.A), ZLine.register(ADD, ZRegister.A), ZLine.register(ADD, ZRegister.A), ZLine.register(ADD, ZRegister.A)) + val multiplicand = ctx.get[Int](1) + if (multiplicand == 0) List(ZLine.ldImm8(A, 0)) + else ZLine.ld8(ZRegister.A, ZRegister.D) :: compileMultiply(multiplicand, List(ZLine.register(ADD, ZRegister.D)), List(ZLine.register(ADD, ZRegister.A))) }, (Elidable & HasOpcode(CALL) & IsUnconditional - & RefersTo("__mul_u8u8u8", 0) - & HasRegister(ZRegister.A, 1) + & RefersTo("__mul_u16u8u16", 0) + & MatchRegister(ZRegister.A, 1) & DoesntMatterWhatItDoesWithFlags - & DoesntMatterWhatItDoesWith(ZRegister.D, ZRegister.E, ZRegister.C)) ~~> { (code, ctx) => - List(ZLine.ld8(ZRegister.A, ZRegister.D)) + & DoesntMatterWhatItDoesWith(ZRegister.D, ZRegister.E, ZRegister.C, ZRegister.B, ZRegister.A)) ~~> { (code, ctx) => + val multiplicand = ctx.get[Int](1) + if (multiplicand == 0) List(ZLine.ldImm16(HL, 0)) + else ZLine.ld8(ZRegister.L, ZRegister.E) :: + ZLine.ld8(ZRegister.H, ZRegister.D) :: + compileMultiply(multiplicand, List(ZLine.registers(ADD_16, ZRegister.HL, ZRegister.DE)), List(ZLine.registers(ADD_16, ZRegister.HL, ZRegister.HL))) }, + (Elidable & HasOpcode(CALL) & IsUnconditional - & RefersTo("__mul_u8u8u8", 0) - & HasRegister(ZRegister.A, 2) + & RefersTo("__mul_u16u8u16", 0) + & MatchRegister(ZRegister.DE, 1) & DoesntMatterWhatItDoesWithFlags - & DoesntMatterWhatItDoesWith(ZRegister.D, ZRegister.E, ZRegister.C)) ~~> { (code, ctx) => - List(ZLine.ld8(ZRegister.A, ZRegister.D), ZLine.register(ADD, ZRegister.A)) - }, - (Elidable & HasOpcode(CALL) - & IsUnconditional - & RefersTo("__mul_u8u8u8", 0) - & HasRegister(ZRegister.A, 4) - & DoesntMatterWhatItDoesWithFlags - & DoesntMatterWhatItDoesWith(ZRegister.D, ZRegister.E, ZRegister.C)) ~~> { (code, ctx) => - List(ZLine.ld8(ZRegister.A, ZRegister.D), ZLine.register(ADD, ZRegister.A), ZLine.register(ADD, ZRegister.A)) - }, - (Elidable & HasOpcode(CALL) - & IsUnconditional - & RefersTo("__mul_u8u8u8", 0) - & HasRegister(ZRegister.A, 8) - & DoesntMatterWhatItDoesWithFlags - & DoesntMatterWhatItDoesWith(ZRegister.D, ZRegister.E, ZRegister.C)) ~~> { (code, ctx) => - List(ZLine.ld8(ZRegister.A, ZRegister.D), ZLine.register(ADD, ZRegister.A), ZLine.register(ADD, ZRegister.A), ZLine.register(ADD, ZRegister.A)) - }, - (Elidable & HasOpcode(CALL) - & IsUnconditional - & RefersTo("__mul_u8u8u8", 0) - & HasRegister(ZRegister.A, 16) - & DoesntMatterWhatItDoesWithFlags - & DoesntMatterWhatItDoesWith(ZRegister.D, ZRegister.E, ZRegister.C)) ~~> { (code, ctx) => - List(ZLine.ld8(ZRegister.A, ZRegister.D), ZLine.register(ADD, ZRegister.A), ZLine.register(ADD, ZRegister.A), ZLine.register(ADD, ZRegister.A), ZLine.register(ADD, ZRegister.A)) + & DoesntMatterWhatItDoesWith(ZRegister.D, ZRegister.E, ZRegister.C, ZRegister.B, ZRegister.A)) ~~> { (code, ctx) => + val multiplicand = ctx.get[Int](1) + if (multiplicand == 0) List(ZLine.ldImm16(HL, 0)) + else ZLine.ld8(ZRegister.L, ZRegister.A) :: + ZLine.ldImm8(ZRegister.H, 0) :: + ZLine.ld8(ZRegister.E, ZRegister.A) :: + ZLine.ldImm8(ZRegister.D, 0) :: + compileMultiply(multiplicand, List(ZLine.registers(ADD_16, ZRegister.HL, ZRegister.DE)), List(ZLine.registers(ADD_16, ZRegister.HL, ZRegister.HL))) }, (Elidable & Is8BitLoad(D, A)) ~ diff --git a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala index a41926fa..406bf327 100644 --- a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala @@ -45,20 +45,39 @@ class AbstractExpressionCompiler[T <: AbstractCode] { params.map { case (_, expr) => getExpressionType(ctx, expr).size}.max } - def assertSizesForMultiplication(ctx: CompilationContext, params: List[Expression]): Unit = { + def assertSizesForMultiplication(ctx: CompilationContext, params: List[Expression], inPlace: Boolean): Unit = { assertAllArithmetic(ctx, params) //noinspection ZeroIndexToHead - val lSize = getExpressionType(ctx, params(0)).size + val lType = getExpressionType(ctx, params(0)) + val lSize = lType.size val rType = getExpressionType(ctx, params(1)) val rSize = rType.size - if (lSize != 1 && lSize != 2) { - ctx.log.fatal("Long multiplication not supported", params.head.position) - } - if (rSize != 1) { - ctx.log.fatal("Long multiplication not supported", params.head.position) - } - if (rType.isSigned) { - ctx.log.fatal("Signed multiplication not supported", params.head.position) + if (inPlace) { + if (lSize != 1 && lSize != 2) { + ctx.log.error("Long multiplication not supported", params.head.position) + } + if (rSize != 1) { + ctx.log.error("Long multiplication not supported", params.head.position) + } + if (lSize == 2 && rType.isSigned) { + ctx.log.error("Signed multiplication not supported", params.head.position) + } + } else { + if (lSize > 2 || rSize > 2 || lSize + rSize > 3) { + ctx.log.error("Signed multiplication not supported", params.head.position) + } + if (lSize == 2 && rType.isSigned) { + ctx.log.error("Signed multiplication not supported", params.head.position) + } + if (rSize == 2 && lType.isSigned) { + ctx.log.error("Signed multiplication not supported", params.head.position) + } + if (lSize + rSize > 2) { + if (params.size != 2) { + ctx.log.error("Cannot multiply more than 2 large numbers at once", params.headOption.flatMap(_.position)) + return + } + } } } @@ -212,11 +231,10 @@ object AbstractExpressionCompiler { case 1 => b case 2 => w } - case FunctionCallExpression("*", params) => b - case FunctionCallExpression("|" | "&" | "^", params) => params.map { e => getExpressionType(env, log, e).size }.max match { + case FunctionCallExpression("*" | "|" | "&" | "^", params) => params.map { e => getExpressionType(env, log, e).size }.max match { case 1 => b case 2 => w - case _ => log.error("Adding values bigger than words", expr.position); w + case _ => log.error("Combining values bigger than words", expr.position); w } case FunctionCallExpression("<<", List(a1, a2)) => if (getExpressionType(env, log, a2).size > 1) log.error("Shift amount too large", a2.position) diff --git a/src/main/scala/millfork/compiler/mos/BuiltIns.scala b/src/main/scala/millfork/compiler/mos/BuiltIns.scala index dd4b8531..cbad7ad7 100644 --- a/src/main/scala/millfork/compiler/mos/BuiltIns.scala +++ b/src/main/scala/millfork/compiler/mos/BuiltIns.scala @@ -876,7 +876,15 @@ object BuiltIns { val constant = constants.map(_._2.get.asInstanceOf[NumericConstant].value).foldLeft(1L)(_ * _).toInt variables.length match { case 0 => List(AssemblyLine.immediate(LDA, constant & 0xff)) - case 1 => compileByteMultiplication(ctx, variables.head._1, constant) + case 1 => + val sim = simplicity(ctx.env, variables.head._1) + if (sim >= 'I') { + compileByteMultiplication(ctx, variables.head._1, constant) + } else { + MosExpressionCompiler.compileToA(ctx, variables.head._1) ++ + List(AssemblyLine.zeropage(STA, ctx.env.get[ThingInMemory]("__reg.b0"))) ++ + compileByteMultiplication(ctx, VariableExpression("__reg.b0"), constant) + } case 2 => if (constant == 1) PseudoregisterBuiltIns.compileByteMultiplication(ctx, Some(variables(0)._1), variables(1)._1, storeInRegLo = false) diff --git a/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala b/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala index ff7877a7..b35e3b35 100644 --- a/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala @@ -841,9 +841,15 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] { case 2 => PseudoregisterBuiltIns.compileWordBitOpsToAX(ctx, params, AND) } case "*" => - zeroExtend = true - assertAllArithmeticBytes("Long multiplication not supported", ctx, params) - BuiltIns.compileByteMultiplication(ctx, params) + assertSizesForMultiplication(ctx, params, inPlace = false) + getArithmeticParamMaxSize(ctx, params) match { + case 1 => + zeroExtend = true + BuiltIns.compileByteMultiplication(ctx, params) + case 2 => + //noinspection ZeroIndexToHead + PseudoregisterBuiltIns.compileWordMultiplication(ctx, Some(params(0)), params(1), storeInRegLo = false) + } case "|" => getArithmeticParamMaxSize(ctx, params) match { case 1 => @@ -1061,7 +1067,7 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] { } } case "*=" => - assertSizesForMultiplication(ctx, params) + assertSizesForMultiplication(ctx, params, inPlace = true) val (l, r, size) = assertArithmeticAssignmentLike(ctx, params) size match { case 1 => diff --git a/src/main/scala/millfork/compiler/mos/PseudoregisterBuiltIns.scala b/src/main/scala/millfork/compiler/mos/PseudoregisterBuiltIns.scala index 043fbd50..e707308b 100644 --- a/src/main/scala/millfork/compiler/mos/PseudoregisterBuiltIns.scala +++ b/src/main/scala/millfork/compiler/mos/PseudoregisterBuiltIns.scala @@ -4,7 +4,7 @@ import millfork.CompilationFlag import millfork.assembly.mos.AddrMode._ import millfork.assembly.mos.Opcode._ import millfork.assembly.mos._ -import millfork.compiler.{BranchSpec, CompilationContext, NoBranching} +import millfork.compiler.{AbstractExpressionCompiler, BranchSpec, CompilationContext, NoBranching} import millfork.env._ import millfork.error.ConsoleLogger import millfork.node._ @@ -385,6 +385,12 @@ object PseudoregisterBuiltIns { ctx.log.error("Variable word multiplication requires the zeropage pseudoregister of size at least 3", param1OrRegister.flatMap(_.position)) return Nil } + (param1OrRegister.fold(2)(e => AbstractExpressionCompiler.getExpressionType(ctx, e).size), + AbstractExpressionCompiler.getExpressionType(ctx, param2).size) match { + case (1, 2) => return compileWordMultiplication(ctx, Some(param2), param1OrRegister.get, storeInRegLo) + case (2 | 1, 1) => // ok + case _ => ctx.log.fatal("Invalid code path", param2.position) + } val b = ctx.env.get[Type]("byte") val w = ctx.env.get[Type]("word") val reg = ctx.env.get[VariableInMemory]("__reg") diff --git a/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala b/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala index 2111e65b..fafdc2e2 100644 --- a/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala @@ -586,8 +586,14 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { case 2 => targetifyHL(ctx, target, ZBuiltIns.compile16BitOperation(ctx, AND, params)) } case "*" => - assertAllArithmeticBytes("Long multiplication not supported", ctx, params) - targetifyA(ctx, target, Z80Multiply.compile8BitMultiply(ctx, params), isSigned = false) + assertSizesForMultiplication(ctx, params, inPlace = false) + getArithmeticParamMaxSize(ctx, params) match { + case 1 => + targetifyA(ctx, target, Z80Multiply.compile8BitMultiply(ctx, params), isSigned = false) + case 2 => + //noinspection ZeroIndexToHead + targetifyHL(ctx, target, Z80Multiply.compile16And8BitMultiplyToHL(ctx, params(0), params(1))) + } case "|" => getArithmeticParamMaxSize(ctx, params) match { case 1 => targetifyA(ctx, target, ZBuiltIns.compile8BitOperation(ctx, OR, params), isSigned = false) @@ -750,7 +756,7 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { case _ => Z80DecimalBuiltIns.compileInPlaceShiftRight(ctx, l, r, size) } case "*=" => - assertSizesForMultiplication(ctx, params) + assertSizesForMultiplication(ctx, params, inPlace = true) val (l, r, size) = assertArithmeticAssignmentLike(ctx, params) size match { case 1 => diff --git a/src/main/scala/millfork/compiler/z80/Z80Multiply.scala b/src/main/scala/millfork/compiler/z80/Z80Multiply.scala index f0a0fb59..6d5bced4 100644 --- a/src/main/scala/millfork/compiler/z80/Z80Multiply.scala +++ b/src/main/scala/millfork/compiler/z80/Z80Multiply.scala @@ -1,7 +1,7 @@ package millfork.compiler.z80 import millfork.assembly.z80._ -import millfork.compiler.CompilationContext +import millfork.compiler.{AbstractExpressionCompiler, CompilationContext} import millfork.env._ import millfork.node.{ConstantArrayElementExpression, Expression, LhsExpression, ZRegister} @@ -19,7 +19,7 @@ object Z80Multiply { } /** - * Compiles A = A * DE + * Compiles HL = A * DE */ private def multiplication16And8(ctx: CompilationContext): List[ZLine] = { List(ZLine(ZOpcode.CALL, NoRegisters, @@ -101,20 +101,33 @@ object Z80Multiply { } /** - * Calculate A = l * r + * Calculate HL = l * r */ - def compile16And8BitInPlaceMultiply(ctx: CompilationContext, l: LhsExpression, r: Expression): List[ZLine] = { + def compile16And8BitMultiplyToHL(ctx: CompilationContext, l: Expression, r: Expression): List[ZLine] = { + (AbstractExpressionCompiler.getExpressionType(ctx, l).size, + AbstractExpressionCompiler.getExpressionType(ctx, r).size) match { + case (1, 2) => return compile16And8BitMultiplyToHL(ctx, r, l) + case (2 | 1, 1) => // ok + case _ => ctx.log.fatal("Invalid code path", l.position) + } ctx.env.eval(r) match { case Some(c) => - Z80ExpressionCompiler.compileToDE(ctx, l) ++ List(ZLine.ldImm8(ZRegister.A, c)) ++ multiplication16And8(ctx) ++ Z80ExpressionCompiler.storeHL(ctx, l, signedSource = false) + Z80ExpressionCompiler.compileToDE(ctx, l) ++ List(ZLine.ldImm8(ZRegister.A, c)) ++ multiplication16And8(ctx) case _ => val lw = Z80ExpressionCompiler.compileToDE(ctx, l) val rb = Z80ExpressionCompiler.compileToA(ctx, r) val loadRegisters = lw ++ Z80ExpressionCompiler.stashDEIfChanged(ctx, rb) - loadRegisters ++ multiplication16And8(ctx) ++ Z80ExpressionCompiler.storeHL(ctx, l, signedSource = false) + loadRegisters ++ multiplication16And8(ctx) } } + /** + * Calculate l = l * r + */ + def compile16And8BitInPlaceMultiply(ctx: CompilationContext, l: LhsExpression, r: Expression): List[ZLine] = { + compile16And8BitMultiplyToHL(ctx, l, r) ++ Z80ExpressionCompiler.storeHL(ctx, l, signedSource = false) + } + /** * Calculate A = count * x */ diff --git a/src/test/scala/millfork/test/ByteMathSuite.scala b/src/test/scala/millfork/test/ByteMathSuite.scala index 57a7a4ba..e4c10484 100644 --- a/src/test/scala/millfork/test/ByteMathSuite.scala +++ b/src/test/scala/millfork/test/ByteMathSuite.scala @@ -183,6 +183,7 @@ class ByteMathSuite extends FunSuite with Matchers { | import zp_reg | byte output1 @$c001 | byte output2 @$c002 + | byte output3 @$c003 | void main () { | calc1() | crash_if_bad() @@ -199,30 +200,36 @@ class ByteMathSuite extends FunSuite with Matchers { | noinline void calc1() { | output1 = five() * four() | output2 = 3 * three() * three() + | output3 = five() * three() | } | | noinline void calc2() { | output2 = 3 * three() * three() | output1 = five() * four() + | output3 = three() * five() | } | | noinline void calc3() { | output2 = 3 * three() * three() | output1 = four() * five() + | output3 = 3 * five() | } | | noinline void crash_if_bad() { | #if ARCH_6502 | if output1 != 20 { asm { lda $bfff }} | if output2 != 27 { asm { lda $bfff }} + | if output3 != 15 { asm { lda $bfff }} | #elseif ARCH_I80 | if output1 != 20 { asm { ld a,($bfff) }} | if output2 != 27 { asm { ld a,($bfff) }} + | if output3 != 15 { asm { ld a,($bfff) }} | #else | #error unsupported architecture | #endif | } """.stripMargin){m => + m.readByte(0xc003) should equal(15) m.readByte(0xc002) should equal(27) m.readByte(0xc001) should equal(20) } diff --git a/src/test/scala/millfork/test/WordMathSuite.scala b/src/test/scala/millfork/test/WordMathSuite.scala index 5c423efe..e43237a1 100644 --- a/src/test/scala/millfork/test/WordMathSuite.scala +++ b/src/test/scala/millfork/test/WordMathSuite.scala @@ -350,6 +350,29 @@ class WordMathSuite extends FunSuite with Matchers { } } + test("Word multiplication optimization") { + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Intel8080, Cpu.Sharp)(""" + | word output @$c000 + | void main () { + | output = alot() + | output *= two() + | output *= four() + | } + | noinline word alot() { + | return 4532 + | } + | inline byte four() { + | return 4 + | } + | inline byte two() { + | return 2 + | } + | import zp_reg + """.stripMargin){ m => + m.readWord(0xc000) should equal(4532 * 8) + } + } + test("In-place word/byte multiplication") { multiplyCase1(0, 0) multiplyCase1(0, 1) @@ -380,4 +403,40 @@ class WordMathSuite extends FunSuite with Matchers { """. stripMargin)(_.readWord(0xc000) should equal(x * y)) } + + test("Not-in-place word/byte multiplication") { + multiplyCase2(0, 0) + multiplyCase2(0, 1) + multiplyCase2(0, 2) + multiplyCase2(0, 5) + multiplyCase2(1, 0) + multiplyCase2(5, 0) + multiplyCase2(7, 0) + multiplyCase2(2, 5) + multiplyCase2(7, 2) + multiplyCase2(100, 2) + multiplyCase2(1000, 2) + multiplyCase2(54, 4) + multiplyCase2(2, 100) + multiplyCase2(500, 50) + multiplyCase2(4, 54) + } + + private def multiplyCase2(x: Int, y: Int): Unit = { + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( + s""" + | import zp_reg + | word output @$$c000 + | word tmp + | noinline void init() { + | tmp = $x + | } + | void main () { + | init() + | output = $y * tmp + | output = tmp * $y + | } + """. + stripMargin)(_.readWord(0xc000) should equal(x * y)) + } }