diff --git a/src/main/scala/millfork/compiler/z80/Z80Multiply.scala b/src/main/scala/millfork/compiler/z80/Z80Multiply.scala index 80b40575..b4b25c40 100644 --- a/src/main/scala/millfork/compiler/z80/Z80Multiply.scala +++ b/src/main/scala/millfork/compiler/z80/Z80Multiply.scala @@ -47,7 +47,11 @@ object Z80Multiply { false } } - val productOfConstants = CompoundConstant(MathOperator.Times, otherConst, NumericConstant(numericConst & 0xff, 1)).quickSimplify + val productOfConstants = numericConst match { + case 0 => Constant.Zero + case 1 => otherConst + case _ => CompoundConstant(MathOperator.Times, otherConst, NumericConstant(numericConst & 0xff, 1)).quickSimplify + } (filteredParams, otherConst) match { case (Nil, NumericConstant(n, _)) => List(ZLine.ldImm8(ZRegister.A, (numericConst * n).toInt)) case (Nil, _) => List(ZLine.ldImm8(ZRegister.A, productOfConstants)) @@ -156,7 +160,7 @@ object Z80Multiply { } val qb = Z80ExpressionCompiler.compileToA(ctx, q) val load = if (qb.exists(Z80ExpressionCompiler.changesHL)) { - pb ++ Z80ExpressionCompiler.stashHLIfChanged(ctx, qb) + pb ++ Z80ExpressionCompiler.stashHLIfChanged(ctx, qb) ++ List(ZLine.ld8(ZRegister.D, ZRegister.A)) } else if (pb.exists(Z80ExpressionCompiler.changesDE)) { qb ++ List(ZLine.ld8(ZRegister.D, ZRegister.A)) ++ Z80ExpressionCompiler.stashDEIfChanged(ctx, pb) } else { diff --git a/src/main/scala/millfork/env/Constant.scala b/src/main/scala/millfork/env/Constant.scala index a0728ca7..5bb743bf 100644 --- a/src/main/scala/millfork/env/Constant.scala +++ b/src/main/scala/millfork/env/Constant.scala @@ -434,6 +434,16 @@ case class CompoundConstant(operator: MathOperator.Value, lhs: Constant, rhs: Co case MathOperator.And => Constant.Zero case _ => CompoundConstant(operator, l, r) } + case (c, NumericConstant(1, 1)) => + operator match { + case MathOperator.Times => c + case _ => CompoundConstant(operator, l, r) + } + case (NumericConstant(1, 1), c) => + operator match { + case MathOperator.Times => c + case _ => CompoundConstant(operator, l, r) + } case (NumericConstant(lv, ls), NumericConstant(rv, rs)) => var size = ls max rs val bitmask = (1L << (8*size)) - 1 diff --git a/src/test/scala/millfork/test/ByteMathSuite.scala b/src/test/scala/millfork/test/ByteMathSuite.scala index 4786de0c..dd481588 100644 --- a/src/test/scala/millfork/test/ByteMathSuite.scala +++ b/src/test/scala/millfork/test/ByteMathSuite.scala @@ -1,7 +1,7 @@ package millfork.test import millfork.Cpu -import millfork.test.emu.{EmuBenchmarkRun, EmuCrossPlatformBenchmarkRun, EmuUltraBenchmarkRun} +import millfork.test.emu.{EmuBenchmarkRun, EmuCrossPlatformBenchmarkRun, EmuUltraBenchmarkRun, EmuUnoptimizedCrossPlatformRun} import org.scalatest.{AppendedClues, FunSuite, Matchers} /** @@ -422,4 +422,24 @@ class ByteMathSuite extends FunSuite with Matchers with AppendedClues { m.readByte(0xc003) should equal(x % y) withClue s"= $x %% $y" } } + + test("Division bug repro"){ + EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80)( + s""" + | import zp_reg + | byte output_q1 @$$c000, output_m1 @$$c001 + | array zeroes[256] = [for i,0,until,256 [0]] + | void main () { + | byte a + | byte b + | a = 186 + | memory_barrier() + | a /= zeroes[b] | 2 + | output_q1 = a + | } + """. + stripMargin) { m => + m.readByte(0xc000) should equal(93) + } + } }