1
0
mirror of https://github.com/KarolS/millfork.git synced 2024-08-12 11:29:20 +00:00

Decimal multiplication, decimal right shift fixes

This commit is contained in:
Karol Stasiak 2018-01-31 22:26:20 +01:00
parent fdcf3dc8c8
commit 341466b198
6 changed files with 288 additions and 33 deletions

View File

@ -27,7 +27,7 @@ Millfork has different operator precedence compared to most other languages. Fro
You cannot use two different operators at the same precedence level without using parentheses to disambiguate.
This is to prevent confusion about whether `a + b & c << d` means `(a + b) & (c << d)`, `((a + b) & c) << d`, or something else.
The only exceptions are `+` and `-`, and `+'` and `-'`.
They are interpeted as expected: `5 - 3 + 2 == 4` and `5 -' 3 +' 2 == 4`.
They are interpreted as expected: `5 - 3 + 2 == 4` and `5 -' 3 +' 2 == 4`.
Note that you cannot mix `+'` and `-'` with `+` and `-`.
## Argument types
@ -154,7 +154,8 @@ and fail to compile otherwise. This will be changed in the future.
* `*=`: multiplication in place
`mutable byte *= constant byte`
There is no `*'=` operator yet.
* `*'=`: decimal multiplication in place
`mutable byte *'= constant byte`
## Indexing
@ -166,6 +167,6 @@ An expression of form `a[i]`, where `i` is an expression of type `byte`, is:
* when `a` is a pointer variable: an access to the byte in memory at address `a + i`
Those exrpressions are of type `byte`. If `a` is any other kind of expression, `a[i]` is invalid.
Those expressions are of type `byte`. If `a` is any other kind of expression, `a[i]` is invalid.

View File

@ -42,4 +42,6 @@ _lo_nibble_to_hex_lbl:
inline asm void panic() {
JSR _panic
}
}
array __constant8 = [8]

View File

@ -28,7 +28,7 @@ object DecimalBuiltIns {
val addition = BuiltIns.compileAddition(ctx, List.fill(1 << v)(false -> l), decimal = true)
if (rotate) addition.filterNot(_.opcode == CLC) else addition
case _ =>
ErrorReporting.error("Cannot shift by a non-constant amount")
ErrorReporting.error("Cannot shift by a non-constant amount", r.position)
Nil
}
}
@ -40,38 +40,57 @@ object DecimalBuiltIns {
Nil
case Some(NumericConstant(v, _)) =>
List.fill(v.toInt) {
shiftOrRotateAccumulatorRight(ctx, rotate)
shiftOrRotateAccumulatorRight(ctx, rotate, preserveCarry = false)
}.flatten
case _ =>
ErrorReporting.error("Cannot shift by a non-constant amount")
ErrorReporting.error("Cannot shift by a non-constant amount", r.position)
Nil
})
}
private def shiftOrRotateAccumulatorRight(ctx: CompilationContext, rotate: Boolean) = {
val constantLabel = MfCompiler.nextLabel("c8")
private def shiftOrRotateAccumulatorRight(ctx: CompilationContext, rotate: Boolean, preserveCarry: Boolean) = {
val skipHiDigit = MfCompiler.nextLabel("ds")
val skipLoDigit = MfCompiler.nextLabel("ds")
val cmos = ctx.options.flags(CompilationFlag.EmitCmosOpcodes)
val bit = if (cmos) {
AssemblyLine.immediate(BIT, 8)
if (preserveCarry) {
val constantLabel = MfCompiler.nextLabel("c8")
val bit = if (cmos) {
AssemblyLine.immediate(BIT, 8)
} else {
AssemblyLine.absolute(BIT, Label(constantLabel))
}
List(
if (rotate) AssemblyLine.implied(ROR) else AssemblyLine.implied(LSR),
AssemblyLine(LABEL, DoesNotExist, Label(constantLabel).toAddress, elidable = cmos),
AssemblyLine(PHP, Implied, Constant.Zero, elidable = cmos),
AssemblyLine.relative(BPL, skipHiDigit),
AssemblyLine.implied(SEC),
AssemblyLine.immediate(SBC, 0x30),
AssemblyLine.label(skipHiDigit),
bit,
AssemblyLine.relative(BEQ, skipLoDigit),
AssemblyLine.implied(SEC),
AssemblyLine.immediate(SBC, 0x3),
AssemblyLine.label(skipLoDigit),
AssemblyLine.implied(PLP))
} else {
AssemblyLine.absolute(BIT, Label(constantLabel))
val bit = if (cmos) {
AssemblyLine.immediate(BIT, 8)
} else {
AssemblyLine.absolute(BIT, ctx.env.get[ThingInMemory]("__constant8"))
}
List(
if (rotate) AssemblyLine.implied(ROR) else AssemblyLine.implied(LSR),
AssemblyLine.relative(BPL, skipHiDigit),
AssemblyLine.implied(SEC),
AssemblyLine.immediate(SBC, 0x30),
AssemblyLine.label(skipHiDigit),
bit,
AssemblyLine.relative(BEQ, skipLoDigit),
AssemblyLine.implied(SEC),
AssemblyLine.immediate(SBC, 0x3),
AssemblyLine.label(skipLoDigit))
}
List(
if (rotate) AssemblyLine.implied(ROR) else AssemblyLine.implied(LSR),
AssemblyLine(LABEL, DoesNotExist, Label(constantLabel).toAddress, elidable = cmos),
AssemblyLine(PHP, Implied, Constant.Zero, elidable = cmos),
AssemblyLine.relative(BPL, skipHiDigit),
AssemblyLine.implied(SEC),
AssemblyLine.immediate(SBC, 0x30),
AssemblyLine.label(skipHiDigit),
bit,
AssemblyLine.relative(BCS, skipLoDigit),
AssemblyLine.implied(SEC),
AssemblyLine.immediate(SBC, 0x3),
AssemblyLine.label(skipLoDigit),
AssemblyLine.implied(PLP))
}
def compileInPlaceLongShiftLeft(ctx: CompilationContext, l: LhsExpression, r: Expression): List[AssemblyLine] = {
@ -81,7 +100,7 @@ object DecimalBuiltIns {
case Some(NumericConstant(v, _)) =>
List.fill(v.toInt)(BuiltIns.compileInPlaceWordOrLongAddition(ctx, l, l, decimal = true, subtract = false)).flatten
case _ =>
ErrorReporting.error("Cannot shift by a non-constant amount")
ErrorReporting.error("Cannot shift by a non-constant amount", r.position)
Nil
}
}
@ -104,14 +123,194 @@ object DecimalBuiltIns {
List.fill(v.toInt) {
List.tabulate(size) { i =>
BuiltIns.staTo(LDA, targetBytes(size - 1 - i)) ++
shiftOrRotateAccumulatorRight(ctx, i != 0) ++
shiftOrRotateAccumulatorRight(ctx, rotate = i != 0, preserveCarry = i != size - 1) ++
targetBytes(size - 1 - i)
}.flatten
}.flatten
case _ =>
ErrorReporting.error("Cannot shift by a non-constant amount")
ErrorReporting.error("Cannot shift by a non-constant amount", r.position)
Nil
}
}
/**
  * Compiles `l *'= r`: in-place decimal (BCD) multiplication of a byte by a constant.
  *
  * The right-hand side must evaluate to a compile-time constant whose low byte is a valid
  * two-digit BCD value; otherwise an error is reported and no code is emitted.
  * The emitted code is wrapped in SED/CLD and follows the step recipe looked up in
  * `waysForShortAddrModes` or `waysForLongAddrModes`, chosen by the addressing mode of the
  * final store instruction produced for `l`.
  *
  * @param ctx compilation context, used to evaluate `r` and to compile the store to `l`
  * @param l   the byte that is multiplied in place
  * @param r   the multiplier expression
  * @return the generated assembly, or `Nil` if an error was reported
  */
def compileInPlaceByteMultiplication(ctx: CompilationContext, l: LhsExpression, r: Expression): List[AssemblyLine] = {
  val multiplier = ctx.env.eval(r) match {
    case Some(NumericConstant(v, _)) =>
      if ((v & 0xf0) > 0x90 || (v & 0xf) > 9) {
        ErrorReporting.error("Invalid decimal constant", r.position)
        // Bail out like the non-constant case does: a malformed constant such as $9A
        // decodes to 100, which has no entry in the recipe tables, so the lookup
        // further down would throw instead of failing gracefully.
        return Nil
      }
      // Decode the two BCD digits into a plain integer in 0..99.
      (((v & 0xf0) >> 4) * 10 + (v & 0xf)).toInt
    case _ =>
      ErrorReporting.error("Cannot multiply by a non-constant amount", r.position)
      return Nil
  }
  val fullStorage = MfCompiler.compileByteStorage(ctx, Register.A, l)
  val sta = fullStorage.last
  if (sta.opcode != STA) ???
  // Same address computation as the store, but loading the current value instead.
  val fullLoad = fullStorage.init :+ sta.copy(opcode = LDA)
  // The original value is stashed in X, or in Y when the store itself is X-indexed
  // (presumably so that the index register is not clobbered).
  val stashInY = sta.addrMode match {
    case AbsoluteX | AbsoluteIndexedX | ZeroPageX | IndexedX => true
    case _ => false
  }
  val transferToStash = AssemblyLine.implied(if (stashInY) TAY else TAX)
  val transferToAccumulator = AssemblyLine.implied(if (stashInY) TYA else TXA)

  // Four binary shifts move the low decimal digit into the high one: times 10, modulo 100.
  def shiftDigitLeft = List.fill(4)(AssemblyLine.implied(ASL))
  // Adds the stored running value to the accumulator `times` times.
  def addStored(times: Int) =
    List.fill(times)(List(AssemblyLine.implied(CLC), sta.copy(opcode = ADC))).flatten
  // Times (10 - subtractions): multiply by 10, then subtract the stored running value.
  def shiftThenSubtract(subtractions: Int) =
    (shiftDigitLeft ++
      List.fill(subtractions)(List(AssemblyLine.implied(SEC), sta.copy(opcode = SBC))).flatten) :+ sta
  // Adds the original (stashed) value to the stored running value.
  def addOriginal = List(transferToAccumulator, AssemblyLine.implied(CLC), sta.copy(opcode = ADC), sta)

  val execute = multiplier match {
    case 0 => List(AssemblyLine.immediate(LDA, 0), sta)
    case 1 => Nil
    case factor =>
      val ways = sta.addrMode match {
        case Absolute | AbsoluteX | AbsoluteY | AbsoluteIndexedX | Indirect =>
          waysForLongAddrModes
        case _ =>
          waysForShortAddrModes
      }
      ways(factor).flatMap {
        case 1 => addOriginal
        case -7 => shiftThenSubtract(3)
        // 8 and 9 must be matched before the guarded case below,
        // otherwise the `step < 9` guard swallows 8 and its dedicated sequence is never used.
        case 8 => shiftThenSubtract(2)
        case 9 => shiftThenSubtract(1)
        case step if step < 9 => addStored(step - 1) :+ sta
        case step => (shiftDigitLeft ++ addStored(step - 10)) :+ sta
      }
  }
  // The stash transfer is only needed when some step reads the original value back.
  val body =
    if (execute.contains(transferToAccumulator)) fullLoad ++ List(transferToStash) ++ execute
    else fullLoad ++ execute
  AssemblyLine.implied(SED) :: (body :+ AssemblyLine.implied(CLD))
}
// Step recipes for decimal multiplication by each constant from 2 to 99, used by
// compileInPlaceByteMultiplication when the store instruction's addressing mode is not one
// of the modes it treats as "long".
// Each list is applied left to right to a running value that starts as the original byte:
//   1  adds the original value once,
//   -7 multiplies the running value by 7 using the shift-and-subtract sequence,
//   any other n multiplies the running value by n.
// For example, 17 -> List(8,2,1) means ((x * 8) * 2) + x.
private lazy val waysForShortAddrModes: Map[Int, List[Int]] = Map(
2 -> List(2), 3 -> List(3), 4 -> List(2,2), 5 -> List(5), 6 -> List(3,2), 7 -> List(7), 8 -> List(8), 9 -> List(9), 10 -> List(10),
11 -> List(11), 12 -> List(12), 13 -> List(13), 14 -> List(14), 15 -> List(3,5), 16 -> List(8,2), 17 -> List(8,2,1), 18 -> List(2,9), 19 -> List(2,9,1), 20 -> List(2,10),
21 -> List(2,10,1), 22 -> List(11,2), 23 -> List(11,2,1), 24 -> List(12,2), 25 -> List(5,5), 26 -> List(2,13), 27 -> List(3,9), 28 -> List(3,9,1), 29 -> List(3,9,1,1), 30 -> List(3,10),
31 -> List(3,10,1), 32 -> List(8,2,2), 33 -> List(11,3), 34 -> List(11,3,1), 35 -> List(7,5), 36 -> List(2,2,9), 37 -> List(2,2,9,1), 38 -> List(2,9,1,2), 39 -> List(3,13), 40 -> List(2,2,10),
41 -> List(2,2,10,1), 42 -> List(2,10,1,2), 43 -> List(2,10,1,2,1), 44 -> List(11,2,2), 45 -> List(9,5), 46 -> List(11,2,1,2), 47 -> List(11,2,1,2,1), 48 -> List(12,2,2), 49 -> List(7,7), 50 -> List(10,5),
51 -> List(10,5,1), 52 -> List(2,2,13), 53 -> List(2,2,13,1), 54 -> List(3,2,9), 55 -> List(11,5), 56 -> List(8,7), 57 -> List(2,9,1,3), 58 -> List(3,9,1,1,2), 59 -> List(3,9,1,1,2,1), 60 -> List(3,2,10),
61 -> List(3,2,10,1), 62 -> List(3,10,1,2), 63 -> List(7,9), 64 -> List(8,8), 65 -> List(13,5), 66 -> List(11,3,2), 67 -> List(11,3,2,1), 68 -> List(11,3,1,2), 69 -> List(11,2,1,3), 70 -> List(7,10),
71 -> List(7,10,1), 72 -> List(8,9), 73 -> List(8,9,1), 74 -> List(2,2,9,1,2), 75 -> List(3,5,5), 76 -> List(2,9,1,2,2), 77 -> List(11,7), 78 -> List(3,2,13), 79 -> List(3,2,13,1), 80 -> List(8,10),
81 -> List(9,9), 82 -> List(9,9,1), 83 -> List(9,9,1,1), 84 -> List(7,12), 85 -> List(7,12,1), 86 -> List(2,10,1,2,1,2), 87 -> List(3,9,1,1,3), 88 -> List(8,11), 89 -> List(8,11,1), 90 -> List(9,10),
91 -> List(9,10,1), 92 -> List(9,10,1,1), 93 -> List(3,10,1,3), 94 -> List(3,10,1,3,1), 95 -> List(2,9,1,5), 96 -> List(8,12), 97 -> List(8,12,1), 98 -> List(7,14), 99 -> List(11,9),
)
// Step recipes for decimal multiplication by each constant from 2 to 99, used by
// compileInPlaceByteMultiplication when the store instruction uses one of the addressing
// modes it treats as "long" (Absolute, AbsoluteX, AbsoluteY, AbsoluteIndexedX, Indirect).
// The step encoding is the same as in the short-mode table: 1 adds the original value,
// -7 is the shift-and-subtract way of multiplying by 7, and any other n multiplies the
// running value by n. This table differs from the short-mode one only where another
// recipe was cheaper under the long-mode costs (e.g. 7, 15, 25, 35, 49).
private lazy val waysForLongAddrModes: Map[Int, List[Int]] = Map(
2 -> List(2), 3 -> List(3), 4 -> List(2,2), 5 -> List(5), 6 -> List(3,2), 7 -> List(-7), 8 -> List(8), 9 -> List(9), 10 -> List(10),
11 -> List(11), 12 -> List(12), 13 -> List(13), 14 -> List(14), 15 -> List(15), 16 -> List(8,2), 17 -> List(8,2,1), 18 -> List(2,9), 19 -> List(2,9,1), 20 -> List(2,10),
21 -> List(2,10,1), 22 -> List(11,2), 23 -> List(11,2,1), 24 -> List(12,2), 25 -> List(12,2,1), 26 -> List(2,13), 27 -> List(3,9), 28 -> List(3,9,1), 29 -> List(3,9,1,1), 30 -> List(3,10),
31 -> List(3,10,1), 32 -> List(8,2,2), 33 -> List(11,3), 34 -> List(11,3,1), 35 -> List(-7,5), 36 -> List(2,2,9), 37 -> List(2,2,9,1), 38 -> List(2,9,1,2), 39 -> List(3,13), 40 -> List(2,2,10),
41 -> List(2,2,10,1), 42 -> List(2,10,1,2), 43 -> List(2,10,1,2,1), 44 -> List(11,2,2), 45 -> List(9,5), 46 -> List(11,2,1,2), 47 -> List(11,2,1,2,1), 48 -> List(12,2,2), 49 -> List(12,2,2,1), 50 -> List(10,5),
51 -> List(10,5,1), 52 -> List(2,2,13), 53 -> List(2,2,13,1), 54 -> List(3,2,9), 55 -> List(11,5), 56 -> List(8,-7), 57 -> List(2,9,1,3), 58 -> List(3,9,1,1,2), 59 -> List(3,9,1,1,2,1), 60 -> List(3,2,10),
61 -> List(3,2,10,1), 62 -> List(3,10,1,2), 63 -> List(9,-7), 64 -> List(8,8), 65 -> List(13,5), 66 -> List(11,3,2), 67 -> List(11,3,2,1), 68 -> List(11,3,1,2), 69 -> List(11,2,1,3), 70 -> List(-7,10),
71 -> List(-7,10,1), 72 -> List(8,9), 73 -> List(8,9,1), 74 -> List(2,2,9,1,2), 75 -> List(12,2,1,3), 76 -> List(2,9,1,2,2), 77 -> List(11,-7), 78 -> List(3,2,13), 79 -> List(3,2,13,1), 80 -> List(8,10),
81 -> List(9,9), 82 -> List(9,9,1), 83 -> List(9,9,1,1), 84 -> List(12,-7), 85 -> List(12,-7,1), 86 -> List(2,10,1,2,1,2), 87 -> List(3,9,1,1,3), 88 -> List(8,11), 89 -> List(8,11,1), 90 -> List(9,10),
91 -> List(9,10,1), 92 -> List(9,10,1,1), 93 -> List(3,10,1,3), 94 -> List(3,10,1,3,1), 95 -> List(2,9,1,5), 96 -> List(8,12), 97 -> List(8,12,1), 98 -> List(14,-7), 99 -> List(11,9),
)
// The following functions are used to generate the tables above:
/**
  * Builds the cost table consumed by `findWay`: for every supported multiplication step,
  * the estimated size of the code emitted for that step with the given addressing mode.
  *
  * @param addrMode addressing mode of the store instruction the costs are estimated for
  * @return a map from step (see the recipe tables) to its estimated cost
  */
private def multiplyCosts(addrMode: AddrMode.Value) = {
  // Addressing modes that add `hiBytes` to the cost of every instruction carrying an address
  // (presumably because their operand is one byte longer).
  val longModes = Set(Absolute, AbsoluteX, AbsoluteY, AbsoluteIndexedX, Indirect)
  val hiBytes = if (longModes.contains(addrMode)) 1 else 0
  // TODO: make those costs smarter.
  // Ideally, the following features should be considered:
  // * NMOS vs CMOS (for timings)
  // * compiling for speed vs compiling for size
  // Currently, only the size is taken into account.
  //
  // Each entry is (step, base size, number of instructions that grow by `hiBytes`).
  // NOTE(review): the base sizes for 5 and 7 (11 and 14) look smaller than what repeated
  // CLC/ADC would give by the same arithmetic as for 2 and 3 (14 and 20) - confirm before
  // regenerating the recipe tables.
  val sizes = List(
    (1, 6, 2),
    (2, 5, 2),
    (3, 8, 3),
    (5, 11, 5),
    (7, 14, 7),
    (-7, 15, 4), // alternative algorithm of multiplying by 7
    (8, 12, 3),
    (9, 9, 2),
    (10, 6, 1),
    (11, 9, 2),
    (12, 12, 3),
    (13, 15, 4),
    (14, 18, 5),
    (15, 21, 6),
    (16, 24, 7),
    (17, 27, 8),
    (18, 30, 9),
    (19, 33, 10))
  sizes.map { case (step, base, perHiByte) => step -> (base + perHiByte * hiBytes) }.toMap
}
/**
  * Searches for the cheapest sequence of at most six steps that turns a multiplier of 1
  * into `target`.
  *
  * A step of 1 adds one to the running multiplier; any other step multiplies it by the
  * step's absolute value.
  *
  * @param target the multiplier to reach
  * @param costs  cost of each available step, as produced by `multiplyCosts`
  * @return the steps in the order they should be applied
  * @throws NoSuchElementException if `target` cannot be reached within the depth limit
  */
private def findWay(target: Int, costs: Map[Int, Int]): List[Int] = {
  def recurse(acc: Int, depthLeft: Int, costAndTrace: (Int, List[Int])): Option[(Int, List[Int])] =
    if (acc == target) Some(costAndTrace)
    else if (acc > target || depthLeft == 0) None
    else {
      val (costSoFar, trace) = costAndTrace
      // Deliberately a flatMap on the Map itself: the results are pairs, so they are
      // collected into a Map keyed by total cost, which is what decides between
      // equally expensive candidates.
      val candidates = costs.flatMap {
        case (key, keyCost) =>
          val nextAcc = if (key == 1) acc + 1 else key.abs * acc
          recurse(nextAcc, depthLeft - 1, (costSoFar + keyCost, key :: trace))
      }
      if (candidates.isEmpty) None else Some(candidates.minBy(_._1))
    }
  // The trace is built newest-first, so it is reversed into application order.
  recurse(1, 6, 0 -> Nil).get._2.reverse
}
/**
  * Developer utility for maintaining the recipe tables.
  *
  * Prints, in order: every multiplier whose long-mode and short-mode recipes differ, along
  * with the size of each recipe under both cost models; the length of the longest recipe in
  * each table; and freshly computed tables for the short and the long cost model, in the
  * same `n -> List(...)` format as the hard-coded ones.
  */
def main(args: Array[String]): Unit = {
  val shortCosts = multiplyCosts(ZeroPageY)
  val longCosts = multiplyCosts(AbsoluteX)
  val multipliers = 2 to 99

  // Total cost of a recipe under the given cost model.
  def totalSize(way: List[Int], costs: Map[Int, Int]): Int = way.map(costs).sum

  // Prints a regenerated table, ten entries per line.
  def dumpTable(costs: Map[Int, Int]): Unit =
    multipliers.foreach { m =>
      print(s"$m -> List(${findWay(m, costs).mkString(",")}), ")
      if (m % 10 == 0) println()
    }

  for (m <- multipliers if waysForLongAddrModes(m) != waysForShortAddrModes(m)) {
    println(m)
    val longWay = waysForLongAddrModes(m)
    val shortWay = waysForShortAddrModes(m)
    val longCost = totalSize(longWay, longCosts)
    val shortCost = totalSize(shortWay, shortCosts)
    val longCostIfUsedShortWay = totalSize(shortWay, longCosts)
    val shortCostIfUsedLongWay = totalSize(longWay, shortCosts)
    println(s"For long addr: $longWay (size: $longCost); the other would have size $longCostIfUsedShortWay")
    println(s"For short addr: $shortWay (size: $shortCost); the other would have size $shortCostIfUsedLongWay")
  }
  println(multipliers.map(m => waysForLongAddrModes(m).length).max)
  println(multipliers.map(m => waysForShortAddrModes(m).length).max)
  dumpTable(shortCosts)
  println()
  println()
  dumpTable(longCosts)
}
}

View File

@ -50,14 +50,12 @@ object MfCompiler {
val prefix = (if (ctx.function.interrupt) {
if (ctx.options.flag(CompilationFlag.EmitCmosOpcodes)) {
List(
AssemblyLine.implied(SEI),
AssemblyLine.implied(PHA),
AssemblyLine.implied(PHX),
AssemblyLine.implied(PHY),
AssemblyLine.implied(CLD))
} else {
List(
AssemblyLine.implied(SEI),
AssemblyLine.implied(PHA),
AssemblyLine.implied(TXA),
AssemblyLine.implied(PHA),
@ -1046,6 +1044,10 @@ object MfCompiler {
assertAllBytes("Long multiplication not supported", ctx, params)
val (l, r, 1) = assertAssignmentLike(ctx, params)
BuiltIns.compileInPlaceByteMultiplication(ctx, l, r)
case "*'=" =>
assertAllBytes("Long multiplication not supported", ctx, params)
val (l, r, 1) = assertAssignmentLike(ctx, params)
DecimalBuiltIns.compileInPlaceByteMultiplication(ctx, l, r)
case "&=" =>
val (l, r, size) = assertAssignmentLike(ctx, params)
size match {
@ -1443,7 +1445,6 @@ object MfCompiler {
AssemblyLine.implied(PLY),
AssemblyLine.implied(PLX),
AssemblyLine.implied(PLA),
AssemblyLine.implied(CLI),
AssemblyLine.implied(RTI))
} else {
List(
@ -1452,7 +1453,6 @@ object MfCompiler {
AssemblyLine.implied(PLA),
AssemblyLine.implied(TAX),
AssemblyLine.implied(PLA),
AssemblyLine.implied(CLI),
AssemblyLine.implied(RTI))
}
} else {

View File

@ -702,6 +702,9 @@ class Environment(val parent: Option[Environment], val prefix: String) {
case a: ArrayDeclarationStatement => registerArray(a)
case i: ImportStatement => ()
}
if (!things.contains("__constant8")) {
things("__constant8") = InitializedArray("__constant8", None, List(NumericConstant(8, 1)))
}
}
private def checkName[T <: Thing : Manifest](objType: String, name: String, pos: Option[Position]): Unit = {

View File

@ -178,4 +178,54 @@ class ByteDecimalMathSuite extends FunSuite with Matchers {
""".stripMargin)
m.readWord(0xc000) should equal(0x91)
}
/**
  * Decodes the low 16 bits of `v` as four BCD digits and returns their decimal value,
  * e.g. 0x1234 becomes 1234. Fails the running test if any of the four nibbles is above 9.
  */
def toDecimal(v: Int) = {
  // Nibbles from the most significant digit to the least significant one.
  val digits = List(12, 8, 4, 0).map(shift => (v >> shift) & 0xf)
  if (digits.exists(_ > 9))
    fail("Invalid decimal value: " + v.toHexString)
  digits.foldLeft(0)((sum, digit) => sum * 10 + digit)
}
// For every decimal value 0..99: compiles and runs a program that stores the value as a
// BCD literal ($NN) in a byte at $c000 and applies `>>'= 1`, then checks that the byte
// decodes to the integer half of the value.
test("Decimal byte right-shift comprehensive byte suite") {
for (i <- 0 to 99) {
val m = EmuUnoptimizedRun(
"""
| byte output @$c000
| void main () {
| output = $#
| output >>'= 1
| }
""".stripMargin.replace("#", i.toString))
// The '#' placeholder above is replaced with the decimal digits of i, giving a hex literal
// that reads as BCD.
toDecimal(m.readByte(0xc000)) should equal(i/2)
}
}
// Same check as the byte suite, but for a word at $c000 and a hand-picked list of
// decimal values of up to four digits: after `>>'= 1` the word must decode to half the value.
test("Decimal word right-shift comprehensive suite") {
for (i <- List(0, 1, 10, 100, 1000, 2000, 500, 200, 280, 300, 5234, 7723, 7344, 9, 16, 605, 1111, 2222, 3333, 9999, 8888, 8100)) {
val m = EmuUnoptimizedRun(
"""
| word output @$c000
| void main () {
| output = $#
| output >>'= 1
| }
""".stripMargin.replace("#", i.toString))
// The '#' placeholder above is replaced with the decimal digits of i, giving a BCD literal.
toDecimal(m.readWord(0xc000)) should equal(i/2)
}
}
// For a selection of initial values i and every multiplier j in 0..99: stores i as a BCD
// byte at $c000, applies `*'= j` (j written as a BCD literal), and checks that the result
// decodes to (i * j) modulo 100.
test("Decimal byte multiplication comprehensive suite") {
for (i <- List(1, 2, 3, 6, 8, 10, 11, 12, 14, 15, 16, 40, 99) ; j <- 0 to 99) {
val m = EmuUnoptimizedRun(
"""
| byte output @$c000
| void main () {
| init()
| run()
| }
| void init() { output = $#i }
| void run () { output *'= $#j }
""".stripMargin.replace("#i", i.toString).replace("#j", j.toString))
// NOTE(review): `output` is declared as a byte, yet the result is read back as a word;
// this presumably relies on the byte at $c001 being zero - confirm, or read a byte instead.
toDecimal(m.readWord(0xc000)) should equal((i * j) % 100)
}
}
}