diff --git a/docs/lang/operators.md b/docs/lang/operators.md index 242bca64..9656be62 100644 --- a/docs/lang/operators.md +++ b/docs/lang/operators.md @@ -185,6 +185,7 @@ An expression of form `a[f()] += b` may call `f` an undefined number of times. * `*=`: multiplication in place `mutable byte *= constant byte` `mutable byte *= byte` (zpreg) +`mutable word *= unsigned byte` (zpreg) * `*'=`: decimal multiplication in place `mutable byte *'= constant byte` diff --git a/include/i80_math.mfk b/include/i80_math.mfk index adcd65c2..07fe0a45 100644 --- a/include/i80_math.mfk +++ b/include/i80_math.mfk @@ -49,3 +49,28 @@ inline asm byte __mul_u8u8u8() { } #endif + +inline asm word __mul_u16u8u16() { + ? LD HL,0 + ? LD B,8 +__mul_u16u8u16_loop: + ? ADD HL,HL + ? ADC A,A +#if CPUFEATURE_Z80 || CPUFEATURE_GAMEBOY + ? JR NC,__mul_u16u8u16_skip +#else + ? JP NC,__mul_u16u8u16_skip +#endif + ? ADD HL,DE +__mul_u16u8u16_skip: +#if CPUFEATURE_Z80 + ? DJNZ __mul_u16u8u16_loop +#elseif CPUFEATURE_GAMEBOY + ? DEC B + ? JR NZ,__mul_u16u8u16_loop +#else + ? DEC B + ? JP NZ,__mul_u16u8u16_loop +#endif + ? RET +} diff --git a/include/zp_reg.mfk b/include/zp_reg.mfk index 83ff1b30..bffeeef8 100644 --- a/include/zp_reg.mfk +++ b/include/zp_reg.mfk @@ -17,3 +17,29 @@ __mul_u8u8u8_start: ? BNE __mul_u8u8u8_loop ? RTS } + +#if ZPREG_SIZE >= 3 + +asm byte __mul_u16u8u16() { + ? LDA #0 + ? TAX + ? JMP __mul_u16u8u16_start +__mul_u16u8u16_add: + ? CLC + ? ADC __reg + ? TAY + ? TXA + ? ADC __reg + 1 + ? TAX + ? TYA +__mul_u16u8u16_loop: + ? ASL __reg + ? ROL __reg + 1 +__mul_u16u8u16_start: + ? LSR __reg + 2 + ? BCS __mul_u16u8u16_add + ? BNE __mul_u16u8u16_loop + ? RTS +} + +#endif diff --git a/src/main/scala/millfork/assembly/mos/opt/ZeropageRegisterOptimizations.scala b/src/main/scala/millfork/assembly/mos/opt/ZeropageRegisterOptimizations.scala index 803e8800..05b5457b 100644 --- a/src/main/scala/millfork/assembly/mos/opt/ZeropageRegisterOptimizations.scala +++ b/src/main/scala/millfork/assembly/mos/opt/ZeropageRegisterOptimizations.scala @@ -12,6 +12,7 @@ object ZeropageRegisterOptimizations { val functionsThatUsePseudoregisterAsInput: Map[String, Set[Int]] = Map( "__mul_u8u8u8" -> Set(0, 1), + "__mul_u16u8u16" -> Set(0, 1, 2), "__adc_decimal" -> Set(2, 3), "__sbc_decimal" -> Set(2, 3), "__sub_decimal" -> Set(2, 3)) diff --git a/src/main/scala/millfork/assembly/z80/opt/ReverseFlowAnalyzer.scala b/src/main/scala/millfork/assembly/z80/opt/ReverseFlowAnalyzer.scala index 7b242561..2b580e00 100644 --- a/src/main/scala/millfork/assembly/z80/opt/ReverseFlowAnalyzer.scala +++ b/src/main/scala/millfork/assembly/z80/opt/ReverseFlowAnalyzer.scala @@ -175,11 +175,11 @@ case class CpuImportance(a: Importance = UnknownImportance, object ReverseFlowAnalyzer { - val readsA = Set("__mul_u8u8u8") + val readsA = Set("__mul_u8u8u8", "__mul_u16u8u16") val readsB = Set("") val readsC = Set("") - val readsD = Set("__mul_u8u8u8") - val readsE = Set("") + val readsD = Set("__mul_u8u8u8","__mul_u16u8u16") + val readsE = Set("__mul_u16u8u16") val readsH = Set("") val readsL = Set("") diff --git a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala index 2dfccb9d..2d026114 100644 --- a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala @@ -45,6 +45,23 @@ class AbstractExpressionCompiler[T <: AbstractCode] { params.map { case (_, expr) => getExpressionType(ctx, expr).size}.max } + def assertSizesForMultiplication(ctx: CompilationContext, params: List[Expression]): Unit = { + assertAllArithmetic(ctx, params) + //noinspection ZeroIndexToHead + val lSize = getExpressionType(ctx, params(0)).size + val rType = getExpressionType(ctx, params(1)) + val rSize = rType.size + if (lSize != 1 && lSize != 2) { + ctx.log.fatal("Long multiplication not supported", params.head.position) + } + if (rSize != 1) { + ctx.log.fatal("Long multiplication not supported", params.head.position) + } + if (rType.isSigned) { + ctx.log.fatal("Signed multiplication not supported", params.head.position) + } + } + def assertAllArithmeticBytes(msg: String, ctx: CompilationContext, params: List[Expression]): Unit = { assertAllArithmetic(ctx, params) if (params.exists { expr => getExpressionType(ctx, expr).size != 1 }) { diff --git a/src/main/scala/millfork/compiler/mos/BuiltIns.scala b/src/main/scala/millfork/compiler/mos/BuiltIns.scala index 5e216948..9acc15c6 100644 --- a/src/main/scala/millfork/compiler/mos/BuiltIns.scala +++ b/src/main/scala/millfork/compiler/mos/BuiltIns.scala @@ -831,6 +831,21 @@ object BuiltIns { } } + def compileInPlaceWordMultiplication(ctx: CompilationContext, v: LhsExpression, addend: Expression): List[AssemblyLine] = { + val b = ctx.env.get[Type]("byte") + val w = ctx.env.get[Type]("word") + ctx.env.eval(addend) match { + case Some(NumericConstant(0, _)) => + MosExpressionCompiler.compile(ctx, v, None, NoBranching) ++ MosExpressionCompiler.compileAssignment(ctx, LiteralExpression(0, 2), v) + case Some(NumericConstant(1, _)) => + MosExpressionCompiler.compile(ctx, v, None, NoBranching) + case _ => + // TODO: optimize? + PseudoregisterBuiltIns.compileWordMultiplication(ctx, Some(v), addend, storeInRegLo = true) ++ + MosExpressionCompiler.compileAssignment(ctx, VariableExpression("__reg.loword"), v) + } + } + def compileByteMultiplication(ctx: CompilationContext, v: Expression, c: Int): List[AssemblyLine] = { val result = ListBuffer[AssemblyLine]() // TODO: optimise diff --git a/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala b/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala index 98201994..c2592b40 100644 --- a/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala @@ -1029,9 +1029,14 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] { } } case "*=" => - assertAllArithmeticBytes("Long multiplication not supported", ctx, params) - val (l, r, 1) = assertArithmeticAssignmentLike(ctx, params) - BuiltIns.compileInPlaceByteMultiplication(ctx, l, r) + assertSizesForMultiplication(ctx, params) + val (l, r, size) = assertArithmeticAssignmentLike(ctx, params) + size match { + case 1 => + BuiltIns.compileInPlaceByteMultiplication(ctx, l, r) + case 2 => + BuiltIns.compileInPlaceWordMultiplication(ctx, l, r) + } case "*'=" => assertAllArithmeticBytes("Long multiplication not supported", ctx, params) val (l, r, 1) = assertArithmeticAssignmentLike(ctx, params) diff --git a/src/main/scala/millfork/compiler/mos/PseudoregisterBuiltIns.scala b/src/main/scala/millfork/compiler/mos/PseudoregisterBuiltIns.scala index ae118f24..5d6cc572 100644 --- a/src/main/scala/millfork/compiler/mos/PseudoregisterBuiltIns.scala +++ b/src/main/scala/millfork/compiler/mos/PseudoregisterBuiltIns.scala @@ -326,6 +326,49 @@ object PseudoregisterBuiltIns { load ++ calculate } + def compileWordMultiplication(ctx: CompilationContext, param1OrRegister: Option[Expression], param2: Expression, storeInRegLo: Boolean): List[AssemblyLine] = { + if (ctx.options.zpRegisterSize < 3) { + ctx.log.error("Variable word multiplication requires the zeropage pseudoregister of size at least 3", param1OrRegister.flatMap(_.position)) + return Nil + } + val b = ctx.env.get[Type]("byte") + val w = ctx.env.get[Type]("word") + val reg = ctx.env.get[VariableInMemory]("__reg") + val load: List[AssemblyLine] = param1OrRegister match { + case Some(param1) => + val code1 = MosExpressionCompiler.compile(ctx, param1, Some(w -> RegisterVariable(MosRegister.AX, w)), BranchSpec.None) + val code2 = MosExpressionCompiler.compile(ctx, param2, Some(b -> RegisterVariable(MosRegister.A, b)), BranchSpec.None) + if (!usesRegLo(code2) && !usesRegHi(code2)) { + code1 ++ List(AssemblyLine.zeropage(STA, reg), AssemblyLine.zeropage(STX, reg, 1)) ++ code2 ++ List(AssemblyLine.zeropage(STA, reg, 2)) + } else if (!usesReg2(code1)) { + code2 ++ List(AssemblyLine.zeropage(STA, reg, 2)) ++ code1 ++ List(AssemblyLine.zeropage(STA, reg), AssemblyLine.zeropage(STX, reg, 1)) + } else { + code2 ++ List(AssemblyLine.implied(PHA)) ++ code1 ++ List( + AssemblyLine.zeropage(STA, reg), + AssemblyLine.zeropage(STX, reg, 1), + AssemblyLine.implied(PLA), + AssemblyLine.zeropage(STA, reg, 2) + ) + } + case None => + val code2 = MosExpressionCompiler.compile(ctx, param2, Some(b -> RegisterVariable(MosRegister.A, b)), BranchSpec.None) + if (!usesRegLo(code2) && !usesRegHi(code2)) { + List(AssemblyLine.zeropage(STA, reg), AssemblyLine.zeropage(STX, reg, 1)) ++ code2 ++ List(AssemblyLine.zeropage(STA, reg, 2)) + } else { + List(AssemblyLine.implied(PHA), AssemblyLine.implied(TXA), AssemblyLine.implied(PHA)) ++ code2 ++ List( + AssemblyLine.zeropage(STA, reg, 2), + AssemblyLine.implied(PLA), + AssemblyLine.zeropage(STA, reg, 1), + AssemblyLine.implied(PLA), + AssemblyLine.zeropage(STA, reg) + ) + } + } + val calculate = AssemblyLine.absoluteOrLongAbsolute(JSR, ctx.env.get[FunctionInMemory]("__mul_u16u8u16"), ctx.options) :: + (if (storeInRegLo) List(AssemblyLine.zeropage(STA, reg), AssemblyLine.zeropage(STX, reg, 1)) else Nil) + load ++ calculate + } + private def simplicity(env: Environment, expr: Expression): Char = { val constPart = env.eval(expr) match { case Some(NumericConstant(_, _)) => 'Z' @@ -353,4 +396,10 @@ object PseudoregisterBuiltIns { case AssemblyLine0(_, _, CompoundConstant(MathOperator.Plus, MemoryAddressConstant(th), NumericConstant(1, _))) if th.name == "__reg" => true case _ => false } + + def usesReg2(code: List[AssemblyLine]): Boolean = code.forall{ + case AssemblyLine0(JSR | BSR | TCD | TDC, _, _) => true + case AssemblyLine0(_, _, CompoundConstant(MathOperator.Plus, MemoryAddressConstant(th), NumericConstant(2, _))) if th.name == "__reg" => true + case _ => false + } } diff --git a/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala b/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala index a72711b6..eb7bd42c 100644 --- a/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala @@ -746,9 +746,14 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { case _ => Z80DecimalBuiltIns.compileInPlaceShiftRight(ctx, l, r, size) } case "*=" => - assertAllArithmeticBytes("Long multiplication not supported", ctx, params) - val (l, r, 1) = assertArithmeticAssignmentLike(ctx, params) - Z80Multiply.compile8BitInPlaceMultiply(ctx, l, r) + assertSizesForMultiplication(ctx, params) + val (l, r, size) = assertArithmeticAssignmentLike(ctx, params) + size match { + case 1 => + Z80Multiply.compile8BitInPlaceMultiply(ctx, l, r) + case 2 => + Z80Multiply.compile16And8BitInPlaceMultiply(ctx, l, r) + } case "*'=" => assertAllArithmeticBytes("Long multiplication not supported", ctx, params) val (l, r, 1) = assertArithmeticAssignmentLike(ctx, params) diff --git a/src/main/scala/millfork/compiler/z80/Z80Multiply.scala b/src/main/scala/millfork/compiler/z80/Z80Multiply.scala index c5109935..f0a0fb59 100644 --- a/src/main/scala/millfork/compiler/z80/Z80Multiply.scala +++ b/src/main/scala/millfork/compiler/z80/Z80Multiply.scala @@ -18,6 +18,14 @@ object Z80Multiply { ctx.env.get[ThingInMemory]("__mul_u8u8u8").toAddress)) } + /** + * Compiles A = A * DE + */ + private def multiplication16And8(ctx: CompilationContext): List[ZLine] = { + List(ZLine(ZOpcode.CALL, NoRegisters, + ctx.env.get[ThingInMemory]("__mul_u16u8u16").toAddress)) + } + /** * Calculate A = l * r */ @@ -92,6 +100,21 @@ object Z80Multiply { } } + /** + * Calculate A = l * r + */ + def compile16And8BitInPlaceMultiply(ctx: CompilationContext, l: LhsExpression, r: Expression): List[ZLine] = { + ctx.env.eval(r) match { + case Some(c) => + Z80ExpressionCompiler.compileToDE(ctx, l) ++ List(ZLine.ldImm8(ZRegister.A, c)) ++ multiplication16And8(ctx) ++ Z80ExpressionCompiler.storeHL(ctx, l, signedSource = false) + case _ => + val lw = Z80ExpressionCompiler.compileToDE(ctx, l) + val rb = Z80ExpressionCompiler.compileToA(ctx, r) + val loadRegisters = lw ++ Z80ExpressionCompiler.stashDEIfChanged(ctx, rb) + loadRegisters ++ multiplication16And8(ctx) ++ Z80ExpressionCompiler.storeHL(ctx, l, signedSource = false) + } + } + /** * Calculate A = count * x */ diff --git a/src/main/scala/millfork/node/opt/UnusedFunctions.scala b/src/main/scala/millfork/node/opt/UnusedFunctions.scala index 06a61927..8de59207 100644 --- a/src/main/scala/millfork/node/opt/UnusedFunctions.scala +++ b/src/main/scala/millfork/node/opt/UnusedFunctions.scala @@ -12,7 +12,9 @@ object UnusedFunctions extends NodeOptimization { private val operatorImplementations: List[(String, Int, String)] = List( ("*", 2, "__mul_u8u8u8"), + ("*", 3, "__mul_u16u8u16"), ("*=", 2, "__mul_u8u8u8"), + ("*=", 2, "__mul_u16u8u16"), ("+'", 4, "__adc_decimal"), ("+'=", 4, "__adc_decimal"), ("-'", 4, "__sub_decimal"), diff --git a/src/test/scala/millfork/test/WordMathSuite.scala b/src/test/scala/millfork/test/WordMathSuite.scala index 9f4d8db2..7943199b 100644 --- a/src/test/scala/millfork/test/WordMathSuite.scala +++ b/src/test/scala/millfork/test/WordMathSuite.scala @@ -330,4 +330,54 @@ class WordMathSuite extends FunSuite with Matchers { m.readWord(0xc000) should equal(5) } } + + test("Word multiplication 5") { + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80, Cpu.Intel8080, Cpu.Sharp)(""" + | word output @$c000 + | void main () { + | output = alot() + | output *= five() + | } + | noinline word alot() { + | return 4532 + | } + | noinline byte five() { + | return 5 + | } + | import zp_reg + """.stripMargin){ m => + m.readWord(0xc000) should equal(4532 * 5) + } + } + + test("In-place word/byte multiplication") { + multiplyCase1(0, 0) + multiplyCase1(0, 1) + multiplyCase1(0, 2) + multiplyCase1(0, 5) + multiplyCase1(1, 0) + multiplyCase1(5, 0) + multiplyCase1(7, 0) + multiplyCase1(2, 5) + multiplyCase1(7, 2) + multiplyCase1(100, 2) + multiplyCase1(1000, 2) + multiplyCase1(54, 4) + multiplyCase1(2, 100) + multiplyCase1(500, 50) + multiplyCase1(4, 54) + } + + private def multiplyCase1(x: Int, y: Int): Unit = { + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( + s""" + | import zp_reg + | word output @$$c000 + | void main () { + | output = $x + | output *= $y + | } + """. + stripMargin)(_.readWord(0xc000) should equal(x * y)) + } }