From 010b44f23ea1c7521169f52808c19753f06e1a1c Mon Sep 17 00:00:00 2001 From: Karol Stasiak Date: Wed, 5 Jun 2019 18:36:39 +0200 Subject: [PATCH] Unsigned byte division by a constant --- CHANGELOG.md | 2 + docs/abi/generated-labels.md | 2 + docs/lang/operators.md | 9 ++- .../compiler/AbstractExpressionCompiler.scala | 19 +++++- .../millfork/compiler/mos/BuiltIns.scala | 62 ++++++++++++++++++ .../compiler/mos/MosExpressionCompiler.scala | 14 ++++ .../compiler/z80/Z80ExpressionCompiler.scala | 20 ++++++ .../millfork/compiler/z80/Z80Multiply.scala | 64 +++++++++++++++++++ src/main/scala/millfork/env/Constant.scala | 6 ++ src/main/scala/millfork/env/Environment.scala | 4 ++ .../millfork/output/AbstractAssembler.scala | 2 + src/main/scala/millfork/parser/MfParser.scala | 4 +- .../scala/millfork/test/ByteMathSuite.scala | 42 ++++++++++++ .../scala/millfork/test/ConstantSuite.scala | 3 +- 14 files changed, 247 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 60fbc3b0..a11fd391 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,8 @@ * Added structs and unions. +* Added unsigned byte division and modulo by a constant. + * Pointers can now be allocated anywhere. * Pointers can now be typed. diff --git a/docs/abi/generated-labels.md b/docs/abi/generated-labels.md index a13b6ee7..d5608e7b 100644 --- a/docs/abi/generated-labels.md +++ b/docs/abi/generated-labels.md @@ -28,6 +28,8 @@ where `11111` is a sequential number and `xx` is the type: * `ds` – decimal right shift operation +* `dv` – division and modulo operations + * `el` – beginning of the "else" block in an `if` statement * `ew` – end of a `while` statement diff --git a/docs/lang/operators.md b/docs/lang/operators.md index a115da2d..ec4d76c4 100644 --- a/docs/lang/operators.md +++ b/docs/lang/operators.md @@ -52,6 +52,9 @@ In the descriptions below, arguments to the operators are explained as follows: * `constant` means a compile-time constant +* `simpleconstant` means a compile-time constant evaluable at the first compilation pass +(eg. a literal or a combination of literals, not an undefined address) + * `simple` means either: a constant, a non-stack variable, a pointer indexed with a constant, a pointer indexed with a non-stack variable, an array indexed with a constant, an array indexed with a non-stack variable, @@ -91,7 +94,11 @@ TODO `word * byte` (zpreg) `byte * word` (zpreg) -There are no division, remainder or modulo operators. +* `/`, `%%`: unsigned division and unsigned modulo + +`byte / simpleconstant byte` +`constant word / constant word` +`constant long / constant long` ## Bitwise operators diff --git a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala index f9f0554d..226d8410 100644 --- a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala @@ -62,7 +62,7 @@ class AbstractExpressionCompiler[T <: AbstractCode] { } } else { if (lSize > 2 || rSize > 2 || lSize + rSize > 3) { - ctx.log.error("Signed multiplication not supported", params.head.position) + ctx.log.error("Long multiplication not supported", params.head.position) } if (lSize == 2 && rType.isSigned) { ctx.log.error("Signed multiplication not supported", params.head.position) @@ -79,6 +79,21 @@ class AbstractExpressionCompiler[T <: AbstractCode] { } } + def assertSizesForDivision(ctx: CompilationContext, params: List[Expression], inPlace: Boolean): Unit = { + assertAllArithmetic(ctx, params) + //noinspection ZeroIndexToHead + val lType = getExpressionType(ctx, params(0)) + val lSize = lType.size + val rType = getExpressionType(ctx, params(1)) + val rSize = rType.size + if (lSize > 1 || rSize > 1) { + ctx.log.error("Long division not supported", params.head.position) + } + if (lType.isSigned || rType.isSigned) { + ctx.log.error("Signed division not supported", params.head.position) + } + } + def assertAllArithmeticBytes(msg: String, ctx: CompilationContext, params: List[Expression]): Unit = { assertAllArithmetic(ctx, params) if (params.exists { expr => getExpressionType(ctx, expr).size != 1 }) { @@ -323,7 +338,7 @@ object AbstractExpressionCompiler { case 1 => b case 2 => w } - case FunctionCallExpression("*" | "|" | "&" | "^", params) => params.map { e => getExpressionType(env, log, e).size }.max match { + case FunctionCallExpression("*" | "|" | "&" | "^" | "/" | "%%", params) => params.map { e => getExpressionType(env, log, e).size }.max match { case 1 => b case 2 => w case _ => log.error("Combining values bigger than words", expr.position); w diff --git a/src/main/scala/millfork/compiler/mos/BuiltIns.scala b/src/main/scala/millfork/compiler/mos/BuiltIns.scala index 262994fd..0ef713ec 100644 --- a/src/main/scala/millfork/compiler/mos/BuiltIns.scala +++ b/src/main/scala/millfork/compiler/mos/BuiltIns.scala @@ -1009,6 +1009,68 @@ object BuiltIns { } } + def compileUnsignedByteDivision(ctx: CompilationContext, p: Expression, q: Expression, modulo: Boolean): List[AssemblyLine] = { + if (ctx.options.zpRegisterSize < 1) { + ctx.log.error("Byte division requires the zeropage pseudoregister", p.position) + return Nil + } + ctx.env.eval(q) match { + case Some(NumericConstant(qq, _)) => + if (qq < 0) { + ctx.log.error("Unsigned division by negative constant", q.position) + Nil + } else if (qq == 0) { + ctx.log.error("Unsigned division by zero", q.position) + Nil + } else if (qq > 255) { + if (modulo) MosExpressionCompiler.compileToA(ctx, p) + else List(AssemblyLine.immediate(LDA, 0)) + } else { + compileUnsignedByteDivision(ctx, p, qq.toInt, modulo) + } + case Some(_) => + ctx.log.error("Unsigned division by unknown constant", q.position) + Nil + case None => + ctx.log.error("Unsigned division by a variable expression", q.position) + Nil + } + } + def compileUnsignedByteDivision(ctx: CompilationContext, p: Expression, q: Int, modulo: Boolean): List[AssemblyLine] = { + val reg = ctx.env.get[VariableInMemory]("__reg") + val initP = MosExpressionCompiler.compileToA(ctx, p) + val result = ListBuffer[AssemblyLine]() + if (ctx.options.flag(CompilationFlag.EmitCmosOpcodes)) { + result ++= initP + result += AssemblyLine.zeropage(STZ, reg) + } else if (MosExpressionCompiler.changesZpreg(initP, 0)) { + result ++= initP + result += AssemblyLine.implied(PHA) + result += AssemblyLine.immediate(LDA, 0) + result += AssemblyLine.zeropage(STA, reg) + result += AssemblyLine.implied(PLA) + } else { + result += AssemblyLine.immediate(LDA, 0) + result += AssemblyLine.zeropage(STA, reg) + result ++= initP + } + + for (i <- 7.to(0, -1)) { + if ((q << i) <= 255) { + val lbl = ctx.nextLabel("dv") + result += AssemblyLine.immediate(CMP, q << i) + result += AssemblyLine.relative(BCC, lbl) + result += AssemblyLine.immediate(SBC, q << i) + result += AssemblyLine.label(lbl) + result += AssemblyLine.zeropage(ROL, reg) + } + } + if (!modulo) { + result += AssemblyLine.zeropage(LDA, reg) + } + result.toList + } + def compileInPlaceByteAddition(ctx: CompilationContext, v: LhsExpression, addend: Expression, subtract: Boolean, decimal: Boolean): List[AssemblyLine] = { if (decimal && !ctx.options.flag(CompilationFlag.DecimalMode) && ctx.options.zpRegisterSize < 4) { ctx.log.error("Unsupported decimal operation. Consider increasing the size of the zeropage register.", v.position) diff --git a/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala b/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala index f46ff6a9..85e86642 100644 --- a/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala @@ -1308,6 +1308,20 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] { case 2 => BuiltIns.compileInPlaceWordMultiplication(ctx, l, r) } + case "/=" | "%%=" => + assertSizesForDivision(ctx, params, inPlace = true) + val (l, r, size) = assertArithmeticAssignmentLike(ctx, params) + size match { + case 1 => + BuiltIns.compileUnsignedByteDivision(ctx, l, r, f.functionName == "%%=") ++ compileByteStorage(ctx, MosRegister.A, l) + } + case "/" | "%%" => + assertSizesForDivision(ctx, params, inPlace = false) + val (l, r, size) = assertArithmeticBinary(ctx, params) + size match { + case 1 => + BuiltIns.compileUnsignedByteDivision(ctx, l, r, f.functionName == "%%") + } case "*'=" => assertAllArithmeticBytes("Long multiplication not supported", ctx, params) val (l, r, 1) = assertArithmeticAssignmentLike(ctx, params) diff --git a/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala b/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala index 3d4b665d..541162f5 100644 --- a/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala @@ -880,6 +880,26 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { case 2 => Z80Multiply.compile16And8BitInPlaceMultiply(ctx, l, r) } + case "/=" | "%%=" => + assertSizesForDivision(ctx, params, inPlace = true) + val (l, r, size) = assertArithmeticAssignmentLike(ctx, params) + size match { + case 1 => + calculateAddressToAppropriatePointer(ctx, l, forWriting = true) match { + case Some((lvo, code)) => + code ++ (Z80Multiply.compileUnsignedByteDivision(ctx, l, r, f.functionName == "%%=") :+ ZLine.ld8(lvo, ZRegister.A)) + case None => + ctx.log.error("Invalid left-hand side", l.position) + Nil + } + } + case "/" | "%%" => + assertSizesForDivision(ctx, params, inPlace = false) + val (l, r, size) = assertArithmeticAssignmentLike(ctx, params) + size match { + case 1 => + targetifyA(ctx, target, Z80Multiply.compileUnsignedByteDivision(ctx, l, r, f.functionName == "%%"), false) + } case "*'=" => assertAllArithmeticBytes("Long multiplication not supported", ctx, params) val (l, r, 1) = assertArithmeticAssignmentLike(ctx, params) diff --git a/src/main/scala/millfork/compiler/z80/Z80Multiply.scala b/src/main/scala/millfork/compiler/z80/Z80Multiply.scala index 6d5bced4..8b559ae5 100644 --- a/src/main/scala/millfork/compiler/z80/Z80Multiply.scala +++ b/src/main/scala/millfork/compiler/z80/Z80Multiply.scala @@ -1,10 +1,13 @@ package millfork.compiler.z80 +import millfork.CompilationFlag import millfork.assembly.z80._ import millfork.compiler.{AbstractExpressionCompiler, CompilationContext} import millfork.env._ import millfork.node.{ConstantArrayElementExpression, Expression, LhsExpression, ZRegister} +import scala.collection.mutable.ListBuffer + /** * @author Karol Stasiak */ @@ -100,6 +103,67 @@ object Z80Multiply { } } + /** + * Calculate A = p / q or A = p %% q + */ + def compileUnsignedByteDivision(ctx: CompilationContext, p: LhsExpression, q: Expression, modulo: Boolean): List[ZLine] = { + ctx.env.eval(q) match { + case Some(NumericConstant(qq, _)) => + if (qq < 0) { + ctx.log.error("Unsigned division by negative constant", q.position) + Nil + } else if (qq == 0) { + ctx.log.error("Unsigned division by zero", q.position) + Nil + } else if (qq > 255) { + if (modulo) Z80ExpressionCompiler.compileToA(ctx, p) + else List(ZLine.ldImm8(ZRegister.A, 0)) + } else { + compileUnsignedByteDivisionImpl(ctx, p, qq.toInt, modulo) + } + case Some(_) => + ctx.log.error("Unsigned division by unknown constant", q.position) + Nil + case None => + ctx.log.error("Unsigned division by a variable expression", q.position) + Nil + } + } + /** + * Calculate A = p / q or A = p %% q + */ + def compileUnsignedByteDivisionImpl(ctx: CompilationContext, p: LhsExpression, q: Int, modulo: Boolean): List[ZLine] = { + import ZRegister._ + import ZOpcode._ + val result = ListBuffer[ZLine]() + result ++= Z80ExpressionCompiler.compileToA(ctx, p) + result += ZLine.ldImm8(E, 0) + + for (i <- 7.to(0, -1)) { + if ((q << i) <= 255) { + val lbl = ctx.nextLabel("dv") + result += ZLine.imm8(CP, q << i) + result += ZLine.jumpR(ctx, lbl, IfFlagSet(ZFlag.C)) + result += ZLine.imm8(SUB, q << i) + result += ZLine.label(lbl) + result += ZLine.implied(CCF) // TODO: optimize? + if (ctx.options.flag(CompilationFlag.EmitExtended80Opcodes)) { + result += ZLine.register(RL, E) + } else { + result += ZLine.ld8(D, A) + result += ZLine.ld8(A, E) + result += ZLine.implied(RLA) + result += ZLine.ld8(E, A) + result += ZLine.ld8(A, D) + } + } + } + if (!modulo) { + result += ZLine.ld8(A, E) + } + result.toList + } + /** * Calculate HL = l * r */ diff --git a/src/main/scala/millfork/env/Constant.scala b/src/main/scala/millfork/env/Constant.scala index f216bbb8..030a5133 100644 --- a/src/main/scala/millfork/env/Constant.scala +++ b/src/main/scala/millfork/env/Constant.scala @@ -269,6 +269,7 @@ case class SubbyteConstant(base: Constant, index: Int) extends Constant { object MathOperator extends Enumeration { val Plus, Minus, Times, Shl, Shr, Shl9, Shr9, Plus9, DecimalPlus9, DecimalPlus, DecimalMinus, DecimalTimes, DecimalShl, DecimalShl9, DecimalShr, + Divide, Modulo, And, Or, Exor = Value } @@ -282,6 +283,7 @@ case class CompoundConstant(operator: MathOperator.Value, lhs: Constant, rhs: Co Shl | DecimalShl | Shl9 | DecimalShl9 | Shr | DecimalShr | + Divide | Modulo | And | Or | Exor => lhs.isProvablyNonnegative && rhs.isProvablyNonnegative case _ => false } @@ -343,6 +345,8 @@ case class CompoundConstant(operator: MathOperator.Value, lhs: Constant, rhs: Co case MathOperator.Exor => c case MathOperator.Or => c case MathOperator.And => Constant.Zero + case MathOperator.Divide => Constant.Zero + case MathOperator.Modulo => Constant.Zero case _ => CompoundConstant(operator, l, r) } case (c, NumericConstant(0, 1)) => @@ -378,6 +382,8 @@ case class CompoundConstant(operator: MathOperator.Value, lhs: Constant, rhs: Co case MathOperator.Exor => (lv ^ rv) & bitmask case MathOperator.Or => lv | rv case MathOperator.And => lv & rv & bitmask + case MathOperator.Divide if lv >= 0 && rv >= 0 => lv / rv + case MathOperator.Modulo if lv >= 0 && rv >= 0 => lv % rv case MathOperator.DecimalPlus if ls == 1 && rs == 1 => asDecimal(lv & 0xff, rv & 0xff, _ + _) & 0xff case MathOperator.DecimalMinus if ls == 1 && rs == 1 && lv.&(0xff) >= rv.&(0xff) => diff --git a/src/main/scala/millfork/env/Environment.scala b/src/main/scala/millfork/env/Environment.scala index 51497eb2..625ece8f 100644 --- a/src/main/scala/millfork/env/Environment.scala +++ b/src/main/scala/millfork/env/Environment.scala @@ -730,6 +730,10 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa constantOperation(MathOperator.DecimalTimes, params) case "*" => constantOperation(MathOperator.Times, params) + case "/" => + constantOperation(MathOperator.Divide, params) + case "%%" => + constantOperation(MathOperator.Modulo, params) case "&&" | "&" => constantOperation(MathOperator.And, params) case "^" => diff --git a/src/main/scala/millfork/output/AbstractAssembler.scala b/src/main/scala/millfork/output/AbstractAssembler.scala index fbf5b111..7b852370 100644 --- a/src/main/scala/millfork/output/AbstractAssembler.scala +++ b/src/main/scala/millfork/output/AbstractAssembler.scala @@ -145,6 +145,8 @@ abstract class AbstractAssembler[T <: AbstractCode](private val program: Program case MathOperator.And => l & r case MathOperator.Exor => l ^ r case MathOperator.Or => l | r + case MathOperator.Divide => l / r + case MathOperator.Modulo => l % r } } } diff --git a/src/main/scala/millfork/parser/MfParser.scala b/src/main/scala/millfork/parser/MfParser.scala index d0124289..ad8ec9a4 100644 --- a/src/main/scala/millfork/parser/MfParser.scala +++ b/src/main/scala/millfork/parser/MfParser.scala @@ -657,13 +657,13 @@ object MfParser { } val mfOperators = List( - List("+=", "-=", "+'=", "-'=", "^=", "&=", "|=", "*=", "*'=", "<<=", ">>=", "<<'=", ">>'=", "="), + List("+=", "-=", "+'=", "-'=", "^=", "&=", "|=", "*=", "*'=", "<<=", ">>=", "<<'=", ">>'=", "/=", "%%=", "="), List("||", "^^"), List("&&"), List("==", "<=", ">=", "!=", "<", ">"), List(":"), List("+'", "-'", "<<'", ">>'", ">>>>", "+", "-", "&", "|", "^", "<<", ">>"), - List("*'", "*")) + List("*'", "*", "/", "%%")) val mfOperatorsDropFlatten: IndexedSeq[List[String]] = (0 until mfOperators.length).map(i => mfOperators.drop(i).flatten) diff --git a/src/test/scala/millfork/test/ByteMathSuite.scala b/src/test/scala/millfork/test/ByteMathSuite.scala index 6f182726..16ea160b 100644 --- a/src/test/scala/millfork/test/ByteMathSuite.scala +++ b/src/test/scala/millfork/test/ByteMathSuite.scala @@ -288,4 +288,46 @@ class ByteMathSuite extends FunSuite with Matchers with AppendedClues { """. stripMargin)(_.readByte(0xc000) should equal(x * y) withClue s"$x * $y") } + + test("Byte division 1") { + divisionCase1(0, 1) + divisionCase1(1, 1) + divisionCase1(2, 1) + divisionCase1(250, 1) + divisionCase1(0, 3) + divisionCase1(0, 5) + divisionCase1(1, 5) + divisionCase1(6, 5) + divisionCase1(73, 5) + divisionCase1(75, 5) + divisionCase1(42, 11) + } + + private def divisionCase1(x: Int, y: Int): Unit = { + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80, Cpu.Intel8080, Cpu.Sharp, Cpu.Intel8086)( + s""" + | import zp_reg + | byte output_q1 @$$c000 + | byte output_m1 @$$c001 + | byte output_q2 @$$c002 + | byte output_m2 @$$c003 + | void main () { + | byte a + | a = f() + | //output_q1 = a / $y + | //output_m1 = a %% $y + | output_q2 = a + | output_m2 = a + | output_q2 /= $y + | output_m2 %%= $y + | } + | byte f() {return $x} + """. + stripMargin) { m => +// m.readByte(0xc000) should equal(x / y) withClue s"$x / $y" +// m.readByte(0xc001) should equal(x % y) withClue s"$x %% $y" + m.readByte(0xc002) should equal(x / y) withClue s"$x / $y" + m.readByte(0xc003) should equal(x % y) withClue s"$x %% $y" + } + } } diff --git a/src/test/scala/millfork/test/ConstantSuite.scala b/src/test/scala/millfork/test/ConstantSuite.scala index 64e40d59..50b1b958 100644 --- a/src/test/scala/millfork/test/ConstantSuite.scala +++ b/src/test/scala/millfork/test/ConstantSuite.scala @@ -13,7 +13,8 @@ class ConstantSuite extends FunSuite with Matchers { EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Intel8086)( """ | array Sieve[4] - | array __screen[4] + | const byte two = 2 + | array __screen[4] = [4 / two, 4 %% two, 0, 0] | byte vic_mem | void main() { | vic_mem = lo( ((Sieve.addr >> 10) & 8) | ((__screen.addr >> 6) & $f0) )