1
0
mirror of https://github.com/KarolS/millfork.git synced 2024-07-05 09:28:54 +00:00

Z80: Operators *'=, <<' and <<'=

This commit is contained in:
Karol Stasiak 2018-07-23 15:47:03 +02:00
parent 85243c96a7
commit a34acbf6ce
5 changed files with 298 additions and 22 deletions

View File

@ -172,9 +172,9 @@ An expression of form `a[f()] += b` may call `f` an undefined number of times.
`mutable long <<= byte`
* `<<'=`, `>>'=`: decimal shift in place
`mutable byte <<= constant byte`
`mutable word <<= constant byte`
`mutable long <<= constant byte`
`mutable byte <<'= constant byte`
`mutable word <<'= constant byte`
`mutable long <<'= constant byte`
* `-=`, `-'=`: subtraction in place
`mutable byte -= byte`

View File

@ -249,6 +249,13 @@ case class ZLine(opcode: ZOpcode.Value, registers: ZRegisters, parameter: Consta
case OneRegister(r) => s" (${asAssemblyString(r)})"
}
s" $opcode$ps"
case op@(ADD | SBC | ADC) =>
val os = op.toString
val ps = (registers match {
case OneRegister(r) => s" ${asAssemblyString(r)}"
case OneRegisterOffset(r, o) => s" ${asAssemblyString(r, o)}"
}).stripPrefix(" ")
s" $op A,$ps"
case op =>
val os = op.toString//.stripSuffix("_16")
val ps = registers match {

View File

@ -0,0 +1,228 @@
package millfork.compiler.z80
import millfork.CompilationFlag
import millfork.assembly.z80.ZLine
import millfork.compiler.CompilationContext
import millfork.env.NumericConstant
import millfork.error.ErrorReporting
import millfork.node.{Expression, LhsExpression, ZRegister}
/**
* @author Karol Stasiak
*/
object Z80DecimalBuiltIns {
def compileByteShiftLeft(ctx: CompilationContext, r: Expression): List[ZLine] = {
import millfork.assembly.z80.ZOpcode._
ctx.env.eval(r) match {
case Some(NumericConstant(0, _)) =>
Nil
case Some(NumericConstant(v, _)) if v < 0 =>
ErrorReporting.error("Cannot shift by a negative amount", r.position)
Nil
case Some(NumericConstant(v, _)) =>
List.fill(v.toInt)(List(ZLine.register(ADD, ZRegister.A), ZLine.implied(DAA))).flatten
case _ =>
ErrorReporting.error("Cannot shift by a non-constant amount", r.position)
Nil
}
}
def compileWordShiftLeft(ctx: CompilationContext, r: Expression): List[ZLine] = {
import millfork.assembly.z80.ZOpcode._
import ZRegister._
ctx.env.eval(r) match {
case Some(NumericConstant(0, _)) =>
Nil
case Some(NumericConstant(v, _)) if v < 0 =>
ErrorReporting.error("Cannot shift by a negative amount", r.position)
Nil
case Some(NumericConstant(v, _)) =>
List.fill(v.toInt)(List(
ZLine.ld8(A, L),
ZLine.register(ADD, ZRegister.A),
ZLine.implied(DAA),
ZLine.ld8(L, A),
ZLine.ld8(A, H),
ZLine.register(ADC, ZRegister.A),
ZLine.implied(DAA),
ZLine.ld8(H, A)
)).flatten
case _ =>
ErrorReporting.error("Cannot shift by a non-constant amount", r.position)
Nil
}
}
def compileInPlaceShiftLeft(ctx: CompilationContext, l: LhsExpression, r: Expression, size: Int): List[ZLine] = {
import millfork.assembly.z80.ZOpcode._
ctx.env.eval(r) match {
case Some(NumericConstant(0, _)) =>
Nil
case Some(NumericConstant(v, _)) if v < 0 =>
ErrorReporting.error("Cannot shift by a negative amount", r.position)
Nil
case Some(NumericConstant(v, _)) =>
List.fill(v.toInt)(ZBuiltIns.performLongInPlace(ctx, l, l, ADD, ADC, size, decimal = true)).flatten
case _ =>
ErrorReporting.error("Cannot shift by a non-constant amount", r.position)
Nil
}
}
def compileInPlaceByteMultiplication(ctx: CompilationContext, r: Expression): List[ZLine] = {
val multiplier = ctx.env.eval(r) match {
case Some(NumericConstant(v, _)) =>
if (v.&(0xf0) > 0x90 || v.&(0xf) > 9)
ErrorReporting.error("Invalid decimal constant", r.position)
(v.&(0xf0).>>(4) * 10 + v.&(0xf)).toInt
case _ =>
ErrorReporting.error("Cannot multiply by a non-constant amount", r.position)
return Nil
}
import millfork.assembly.z80.ZOpcode._
import ZRegister._
multiplier match {
case 0 => List(ZLine.ldImm8(A, 0))
case 1 => Nil
case x =>
val add1 = List(ZLine.register(ADD, D), ZLine.implied(DAA), ZLine.ld8(E, A))
val times10 = List(ZLine.register(RL, A), ZLine.register(RL, A), ZLine.register(RL, A), ZLine.register(RL, A), ZLine.imm8(AND, 0xf0))
// TODO: rethink this:
val ways = if (ctx.options.flag(CompilationFlag.OptimizeForSpeed)) waysOptimizedForCycles else waysOptimizedForBytes
ZLine.ld8(D, A) :: ZLine.ld8(E, A) :: ways(x).flatMap {
case 1 => add1
case q if q < 10 => List.fill(q - 1)(List(ZLine.register(ADD, E), ZLine.implied(DAA))).flatten :+ ZLine.ld8(E, A)
case q if q >= 10 => times10 ++ List.fill(q - 10)(List(ZLine.register(ADD, E), ZLine.implied(DAA))).flatten :+ ZLine.ld8(E, A)
}
}
}
private lazy val waysOptimizedForCycles: Map[Int, List[Int]] = Map(
2 -> List(2), 3 -> List(3), 4 -> List(2,2), 5 -> List(2,2,1), 6 -> List(3,2), 7 -> List(3,2,1), 8 -> List(2,2,2), 9 -> List(3,3), 10 -> List(10),
11 -> List(11), 12 -> List(12), 13 -> List(13), 14 -> List(14), 15 -> List(3,5), 16 -> List(2,2,2,2), 17 -> List(2,2,2,2,1), 18 -> List(3,3,2), 19 -> List(3,3,2,1), 20 -> List(2,10),
21 -> List(2,10,1), 22 -> List(11,2), 23 -> List(11,2,1), 24 -> List(12,2), 25 -> List(12,2,1), 26 -> List(2,13), 27 -> List(3,3,3), 28 -> List(2,14), 29 -> List(2,14,1), 30 -> List(3,10),
31 -> List(3,10,1), 32 -> List(2,2,2,2,2), 33 -> List(11,3), 34 -> List(11,3,1), 35 -> List(11,3,1,1), 36 -> List(3,12), 37 -> List(3,12,1), 38 -> List(3,3,2,1,2), 39 -> List(3,13), 40 -> List(2,2,10),
41 -> List(2,2,10,1), 42 -> List(2,10,1,2), 43 -> List(2,10,1,2,1), 44 -> List(11,2,2), 45 -> List(11,2,2,1), 46 -> List(11,2,1,2), 47 -> List(11,2,1,2,1), 48 -> List(12,2,2), 49 -> List(12,2,2,1), 50 -> List(2,2,1,10),
51 -> List(2,2,1,10,1), 52 -> List(2,2,13), 53 -> List(2,2,13,1), 54 -> List(3,3,3,2), 55 -> List(11,5), 56 -> List(11,5,1), 57 -> List(3,3,2,1,3), 58 -> List(2,14,1,2), 59 -> List(2,14,1,2,1), 60 -> List(3,2,10),
61 -> List(3,2,10,1), 62 -> List(3,10,1,2), 63 -> List(2,10,1,3), 64 -> List(2,2,2,2,2,2), 65 -> List(2,2,1,13), 66 -> List(11,3,2), 67 -> List(11,3,2,1), 68 -> List(11,3,1,2), 69 -> List(11,2,1,3), 70 -> List(3,2,1,10),
71 -> List(3,2,1,10,1), 72 -> List(3,12,2), 73 -> List(3,12,2,1), 74 -> List(3,12,1,2), 75 -> List(12,2,1,3), 76 -> List(3,3,2,1,2,2), 77 -> List(3,2,1,11), 78 -> List(3,2,13), 79 -> List(3,2,13,1), 80 -> List(2,2,2,10),
81 -> List(2,2,2,10,1), 82 -> List(2,2,10,1,2), 83 -> List(2,2,10,1,2,1), 84 -> List(2,10,1,2,2), 85 -> List(2,10,1,2,2,1), 86 -> List(2,10,1,2,1,2), 87 -> List(2,10,1,2,1,2,1), 88 -> List(11,2,2,2), 89 -> List(11,2,2,2,1), 90 -> List(3,3,10),
91 -> List(3,3,10,1), 92 -> List(11,2,1,2,2), 93 -> List(3,10,1,3), 94 -> List(3,10,1,3,1), 95 -> List(3,10,1,3,1,1), 96 -> List(12,2,2,2), 97 -> List(12,2,2,2,1), 98 -> List(12,2,2,1,2), 99 -> List(11,3,3),
)
private lazy val waysOptimizedForBytes: Map[Int, List[Int]] = Map(
2 -> List(2), 3 -> List(3), 4 -> List(2,2), 5 -> List(2,2,1), 6 -> List(3,2), 7 -> List(3,2,1), 8 -> List(2,2,2), 9 -> List(3,3), 10 -> List(10),
11 -> List(11), 12 -> List(12), 13 -> List(13), 14 -> List(3,2,1,2), 15 -> List(3,5), 16 -> List(2,2,2,2), 17 -> List(2,2,2,2,1), 18 -> List(3,3,2), 19 -> List(3,3,2,1), 20 -> List(2,10),
21 -> List(2,10,1), 22 -> List(11,2), 23 -> List(11,2,1), 24 -> List(12,2), 25 -> List(12,2,1), 26 -> List(2,13), 27 -> List(3,3,3), 28 -> List(3,2,1,2,2), 29 -> List(3,2,1,2,2,1), 30 -> List(3,10),
31 -> List(3,10,1), 32 -> List(2,2,2,2,2), 33 -> List(11,3), 34 -> List(11,3,1), 35 -> List(11,3,1,1), 36 -> List(3,12), 37 -> List(3,12,1), 38 -> List(3,3,2,1,2), 39 -> List(3,13), 40 -> List(2,2,10),
41 -> List(2,2,10,1), 42 -> List(2,10,1,2), 43 -> List(2,10,1,2,1), 44 -> List(11,2,2), 45 -> List(11,2,2,1), 46 -> List(11,2,1,2), 47 -> List(11,2,1,2,1), 48 -> List(12,2,2), 49 -> List(12,2,2,1), 50 -> List(2,2,1,10),
51 -> List(2,2,1,10,1), 52 -> List(2,2,13), 53 -> List(2,2,13,1), 54 -> List(3,3,3,2), 55 -> List(11,5), 56 -> List(3,2,1,2,2,2), 57 -> List(3,3,2,1,3), 58 -> List(3,2,1,2,2,1,2), 59 -> List(3,2,1,2,2,1,2,1), 60 -> List(3,2,10),
61 -> List(3,2,10,1), 62 -> List(3,10,1,2), 63 -> List(2,10,1,3), 64 -> List(2,2,2,2,2,2), 65 -> List(2,2,2,2,2,2,1), 66 -> List(11,3,2), 67 -> List(11,3,2,1), 68 -> List(11,3,1,2), 69 -> List(11,2,1,3), 70 -> List(3,2,1,10),
71 -> List(3,2,1,10,1), 72 -> List(3,12,2), 73 -> List(3,12,2,1), 74 -> List(3,12,1,2), 75 -> List(12,2,1,3), 76 -> List(3,3,2,1,2,2), 77 -> List(3,2,1,11), 78 -> List(3,2,13), 79 -> List(3,2,13,1), 80 -> List(2,2,2,10),
81 -> List(2,2,2,10,1), 82 -> List(2,2,10,1,2), 83 -> List(2,2,10,1,2,1), 84 -> List(2,10,1,2,2), 85 -> List(2,10,1,2,2,1), 86 -> List(2,10,1,2,1,2), 87 -> List(2,10,1,2,1,2,1), 88 -> List(11,2,2,2), 89 -> List(11,2,2,2,1), 90 -> List(3,3,10),
91 -> List(3,3,10,1), 92 -> List(11,2,1,2,2), 93 -> List(3,10,1,3), 94 -> List(3,10,1,3,1), 95 -> List(3,3,2,1,5), 96 -> List(12,2,2,2), 97 -> List(12,2,2,2,1), 98 -> List(12,2,2,1,2), 99 -> List(11,3,3),
)
private val multiplyCostsCycles = {
Map(
1 -> 12,
2 -> 12,
3 -> 20,
5 -> 36,
10 -> 17,
11 -> 25,
12 -> 33,
13 -> 41,
14 -> 49,
15 -> 57,
16 -> 65,
17 -> 73,
18 -> 81,
19 -> 89
)
}
private val multiplyCostsBytes = {
Map(
1 -> 3,
2 -> 3,
3 -> 5,
5 -> 9,
10 -> 7,
11 -> 9,
12 -> 11,
13 -> 13,
14 -> 15,
15 -> 17,
16 -> 19,
17 -> 21,
18 -> 23,
19 -> 25
)
}
private def findWay(target: Int, costs: Map[Int, Double]): List[Int] = {
def recurse(acc: Int, depthLeft: Int, costAndTrace: (Double, List[Int])): Option[(Double, List[Int])] = {
if (acc == target) return Some(costAndTrace)
if (acc > target) return None
if (depthLeft == 0) return None
val (costSoFar, trace) = costAndTrace
val results = costs.flatMap {
case (key, keyCost) =>
if (key == 1) {
recurse(1 + acc, depthLeft - 1, (costSoFar + keyCost, key :: trace))
} else {
recurse(key.abs * acc, depthLeft - 1, (costSoFar + keyCost, key :: trace))
}
}
if (results.isEmpty) return None
Some(results.minBy(_._1))
}
recurse(1, 9, 0.0 -> Nil).get._2.reverse
}
def main(args: Array[String]): Unit = {
// println((2 to 99).map(ways).map(_.length).max)
println()
println()
val multiplyCostsCycles2 = multiplyCostsCycles.map{case(k,v) => k -> (v + multiplyCostsBytes(k) / 2048.0)}.toMap
val multiplyCostsBytes2 = multiplyCostsBytes.map{case(k,v) => k -> (v + multiplyCostsCycles(k) / 2048.0)}.toMap
val mc = (2 to 99).map{i => i -> findWay(i, multiplyCostsCycles2)}.toMap
val mb = (2 to 99).map{i => i -> findWay(i, multiplyCostsBytes2)}.toMap
for (i <- 2 to 99) {
print(s"$i -> List(${mc(i).mkString(",")}), ")
if (i % 10 == 0) println()
}
println()
println()
for (i <- 2 to 99) {
print(s"$i -> List(${mb(i).mkString(",")}), ")
if (i % 10 == 0) println()
}
println()
println()
for (i <- 2 to 99) {
if (mc(i) != mb(i)) {
println(i)
val c = mc(i)
val b = mb(i)
val cycleCostForC = c.map(multiplyCostsCycles).sum
val cycleCostForB = b.map(multiplyCostsCycles).sum
val byteCostForC = c.map(multiplyCostsBytes).sum
val byteCostForB = b.map(multiplyCostsBytes).sum
println(s"For cycle-optimized addr: $c (cycles: $cycleCostForC); the other would have cycles $cycleCostForB")
println(s"For size-optimized addr: $b (size: $byteCostForB); the other would have size $byteCostForC")
}
}
def weight(i: Int): Double = 300.0 / (i * i)
println("expected byte waste: ", (2 to 99).map(i => (mc(i).map(multiplyCostsBytes).sum - mb(i).map(multiplyCostsBytes).sum) * weight(i)).sum)
println("expected cycles waste: ", (2 to 99).map(i => (mb(i).map(multiplyCostsCycles).sum - mc(i).map(multiplyCostsCycles).sum) * weight(i)).sum)
println((2 to 99).map(mc).map(_.length).max)
println((2 to 99).map(mb).map(_.length).max)
}
}

View File

@ -340,7 +340,7 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] {
case "<<'" =>
assertAllArithmeticBytes("Long shift ops not supported", ctx, params)
val (l, r, 1) = assertArithmeticBinary(ctx, params)
???
targetifyA(target, compileToA(ctx, l) ++ Z80DecimalBuiltIns.compileByteShiftLeft(ctx, r), isSigned = false)
case ">>'" =>
assertAllArithmeticBytes("Long shift ops not supported", ctx, params)
val (l, r, 1) = assertArithmeticBinary(ctx, params)
@ -440,7 +440,21 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] {
}
case "<<'=" =>
val (l, r, size) = assertArithmeticAssignmentLike(ctx, params)
???
size match {
case 1 =>
calculateAddressToAppropriatePointer(ctx, l) match {
case Some((lvo, code)) =>
code ++
(ZLine.ld8(ZRegister.A, lvo) ::
(Z80DecimalBuiltIns.compileByteShiftLeft(ctx, r) :+ ZLine.ld8(lvo, ZRegister.A)))
case None =>
ErrorReporting.error("Invalid left-hand side", l.position)
Nil
}
case 2 =>
compileToHL(ctx, l) ++ Z80DecimalBuiltIns.compileWordShiftLeft(ctx, r) ++ storeHL(ctx, l, signedSource = false)
case _ => Z80DecimalBuiltIns.compileInPlaceShiftLeft(ctx, l, r, size)
}
case ">>'=" =>
val (l, r, size) = assertArithmeticAssignmentLike(ctx, params)
???
@ -451,7 +465,15 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] {
case "*'=" =>
assertAllArithmeticBytes("Long multiplication not supported", ctx, params)
val (l, r, 1) = assertArithmeticAssignmentLike(ctx, params)
???
calculateAddressToAppropriatePointer(ctx, l) match {
case Some((lvo, code)) =>
code ++
(ZLine.ld8(ZRegister.A, lvo) ::
(Z80DecimalBuiltIns.compileInPlaceByteMultiplication(ctx, r) :+ ZLine.ld8(lvo, ZRegister.A)))
case None =>
ErrorReporting.error("Invalid left-hand side", l.position)
Nil
}
case "&=" =>
val (l, r, size) = assertArithmeticAssignmentLike(ctx, params)
size match {

View File

@ -1,7 +1,7 @@
package millfork.test
import millfork.Cpu
import millfork.test.emu.{EmuBenchmarkRun, EmuCrossPlatformBenchmarkRun, EmuUnoptimizedRun}
import millfork.test.emu.{EmuCrossPlatformBenchmarkRun, EmuUnoptimizedCrossPlatformRun, EmuUnoptimizedRun}
import org.scalatest.{FunSuite, Matchers}
/**
@ -101,7 +101,7 @@ class ByteDecimalMathSuite extends FunSuite with Matchers {
}
test("Decimal left shift test") {
val m = EmuUnoptimizedRun(
EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80)(
"""
| byte output @$c000
| void main () {
@ -110,12 +110,13 @@ class ByteDecimalMathSuite extends FunSuite with Matchers {
| output = n <<' 2
| }
| byte nine() { return 9 }
""".stripMargin)
m.readByte(0xc000) should equal(0x36)
""".stripMargin) { m =>
m.readByte(0xc000) should equal(0x36)
}
}
test("Decimal left shift test 2") {
val m = EmuUnoptimizedRun(
EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80)(
"""
| byte output @$c000
| void main () {
@ -123,12 +124,13 @@ class ByteDecimalMathSuite extends FunSuite with Matchers {
| output <<'= 2
| }
| byte nine() { return 9 }
""".stripMargin)
m.readByte(0xc000) should equal(0x36)
""".stripMargin) { m =>
m.readByte(0xc000) should equal(0x36)
}
}
test("Decimal left shift test 3") {
val m = EmuUnoptimizedRun(
EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80)(
"""
| word output @$c000
| void main () {
@ -136,8 +138,23 @@ class ByteDecimalMathSuite extends FunSuite with Matchers {
| output <<'= 2
| }
| byte nine() { return $91 }
""".stripMargin)
m.readWord(0xc000) should equal(0x364)
""".stripMargin) { m =>
m.readWord(0xc000) should equal(0x364)
}
}
test("Decimal left shift test 4") {
EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80)(
"""
| long output @$c000
| void main () {
| output = nine()
| output <<'= 2
| }
| byte nine() { return $91 }
""".stripMargin) { m =>
m.readLong(0xc000) should equal(0x364)
}
}
test("Decimal right shift test") {
@ -180,7 +197,7 @@ class ByteDecimalMathSuite extends FunSuite with Matchers {
m.readWord(0xc000) should equal(0x91)
}
def toDecimal(v: Int) = {
private def toDecimal(v: Int): Int = {
if (v.&(0xf000) > 0x9000 || v.&(0xf00) > 0x900 || v.&(0xf0) > 0x90 || v.&(0xf) > 9)
fail("Invalid decimal value: " + v.toHexString)
v.&(0xf000).>>(12) * 1000 + v.&(0xf00).>>(8) * 100 + v.&(0xf0).>>(4) * 10 + v.&(0xf)
@ -215,8 +232,9 @@ class ByteDecimalMathSuite extends FunSuite with Matchers {
}
test("Decimal byte multiplication comprehensive suite") {
for (i <- List(1, 2, 3, 6, 8, 10, 11, 12, 14, 15, 16, 40, 99) ; j <- 0 to 99) {
val m = EmuUnoptimizedRun(
val numbers = List(0, 1, 2, 3, 6, 8, 10, 11, 12, 14, 15, 16, 20, 40, 73, 81, 82, 98, 99)
for (i <- numbers; j <- numbers) {
EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80)(
"""
| byte output @$c000
| void main () {
@ -225,15 +243,16 @@ class ByteDecimalMathSuite extends FunSuite with Matchers {
| }
| void init() { output = $#i }
| void run () { output *'= $#j }
""".stripMargin.replace("#i", i.toString).replace("#j", j.toString))
toDecimal(m.readWord(0xc000)) should equal((i * j) % 100)
""".stripMargin.replace("#i", i.toString).replace("#j", j.toString)) { m =>
toDecimal(m.readByte(0xc000)) should equal((i * j) % 100)
}
}
}
test("Decimal comparison") {
// CMP#0 shouldn't be elided after a decimal operation.
// Currently no emulator used for testing can catch that.
EmuBenchmarkRun(
EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)(
"""
| byte output @$c000
| void main () {