1
0
mirror of https://github.com/KarolS/millfork.git synced 2024-06-09 16:29:34 +00:00

Decimal shifts

This commit is contained in:
Karol Stasiak 2017-12-27 22:26:30 +01:00
parent 9193a4f035
commit 566631fc5e
3 changed files with 181 additions and 24 deletions

View File

@ -1,4 +1,5 @@
package millfork.compiler
import millfork.{CompilationFlag, CompilationOptions}
import millfork.assembly._
import millfork.env._
@ -24,7 +25,7 @@ object DecimalBuiltIns {
case Some(NumericConstant(0, _)) =>
Nil
case Some(NumericConstant(v, _)) =>
val addition = BuiltIns.compileAddition(ctx, List.fill(1<<v)(false -> l), decimal = true)
val addition = BuiltIns.compileAddition(ctx, List.fill(1 << v)(false -> l), decimal = true)
if (rotate) addition.filterNot(_.opcode == CLC) else addition
case _ =>
ErrorReporting.error("Cannot shift by a non-constant amount")
@ -34,32 +35,78 @@ object DecimalBuiltIns {
def compileByteShiftRight(ctx: CompilationContext, l: Expression, r: Expression, rotate: Boolean): List[AssemblyLine] = {
val b = ctx.env.get[Type]("byte")
MlCompiler.compile(ctx, l, Some((b, RegisterVariable(Register.A, b))), BranchSpec.None) ++ (ctx.env.eval(r) match {
case Some(NumericConstant(0, _)) =>
Nil
case Some(NumericConstant(v, _)) =>
List.fill(v.toInt) {
shiftOrRotateAccumulatorRight(ctx, rotate)
}.flatten
case _ =>
ErrorReporting.error("Cannot shift by a non-constant amount")
Nil
})
}
private def shiftOrRotateAccumulatorRight(ctx: CompilationContext, rotate: Boolean) = {
val constantLabel = MlCompiler.nextLabel("c8")
val skipHiDigit = MlCompiler.nextLabel("ds")
val skipLoDigit = MlCompiler.nextLabel("ds")
val bit = if (ctx.options.flags(CompilationFlag.EmitCmosOpcodes)) {
AssemblyLine.immediate(BIT, 8)
} else {
AssemblyLine.absolute(BIT, Label(constantLabel))
}
List(
if (rotate) AssemblyLine.implied(ROR) else AssemblyLine.implied(LSR),
AssemblyLine.label(constantLabel),
AssemblyLine.implied(PHP),
AssemblyLine.relative(BPL, skipHiDigit),
AssemblyLine.implied(SEC),
AssemblyLine.immediate(SBC, 0x30),
AssemblyLine.label(skipHiDigit),
bit,
AssemblyLine.relative(BCS, skipLoDigit),
AssemblyLine.implied(SEC),
AssemblyLine.immediate(SBC, 0x3),
AssemblyLine.label(skipLoDigit),
AssemblyLine.implied(PLP))
}
def compileInPlaceLongShiftLeft(ctx: CompilationContext, l: LhsExpression, r: Expression): List[AssemblyLine] = {
ctx.env.eval(r) match {
case Some(NumericConstant(0, _)) =>
Nil
case Some(NumericConstant(1, _)) =>
val constantLabel = MlCompiler.nextLabel("c8")
val skipHiDigit = MlCompiler.nextLabel("ds")
val skipLoDigit = MlCompiler.nextLabel("ds")
val bit = if (ctx.options.flags(CompilationFlag.EmitCmosOpcodes)) {
AssemblyLine.immediate(BIT, 8)
} else {
AssemblyLine.absolute(BIT, Label(constantLabel))
}
case Some(NumericConstant(v, _)) =>
List.fill(v.toInt)(BuiltIns.compileInPlaceWordOrLongAddition(ctx, l, l, decimal = true, subtract = false)).flatten
case _ =>
ErrorReporting.error("Cannot shift by a non-constant amount")
Nil
}
}
def compileInPlaceLongShiftRight(ctx: CompilationContext, l: LhsExpression, r: Expression): List[AssemblyLine] = {
val targetBytes: List[List[AssemblyLine]] = l match {
case v: VariableExpression =>
val variable = ctx.env.get[Variable](v.name)
List.tabulate(variable.typ.size) { i => AssemblyLine.variable(ctx, STA, variable, i) }
case SeparateBytesExpression(h: VariableExpression, l: VariableExpression) =>
val lv = ctx.env.get[Variable](l.name)
val hv = ctx.env.get[Variable](h.name)
List(
if (rotate) AssemblyLine.implied(ROR) else AssemblyLine.implied(LSR),
AssemblyLine.label(constantLabel),
AssemblyLine.implied(PHP),
AssemblyLine.relative(BPL, skipHiDigit),
AssemblyLine.implied(SEC),
AssemblyLine.immediate(SBC, 0x30),
AssemblyLine.label(skipHiDigit),
bit,
AssemblyLine.relative(BPL, skipLoDigit),
AssemblyLine.implied(SEC),
AssemblyLine.immediate(SBC, 0x3),
AssemblyLine.label(skipLoDigit),
AssemblyLine.implied(PLP))
AssemblyLine.variable(ctx, STA, lv),
AssemblyLine.variable(ctx, STA, hv))
}
ctx.env.eval(r) match {
case Some(NumericConstant(v, _)) =>
val size = targetBytes.length
List.fill(v.toInt) {
List.tabulate(size) { i =>
BuiltIns.staTo(LDA, targetBytes(size - 1 - i)) ++
shiftOrRotateAccumulatorRight(ctx, i != 0) ++
targetBytes(size - 1 - i)
}.flatten
}.flatten
case _ =>
ErrorReporting.error("Cannot shift by a non-constant amount")
Nil

View File

@ -892,6 +892,14 @@ object MlCompiler {
assertAllBytes("Long shift ops not supported", ctx, params)
val (l, r, 1) = assertBinary(ctx, params)
BuiltIns.compileShiftOps(LSR, ctx, l, r)
case "<<'" =>
assertAllBytes("Long shift ops not supported", ctx, params)
val (l, r, 1) = assertBinary(ctx, params)
DecimalBuiltIns.compileByteShiftLeft(ctx, l, r, rotate = false)
case ">>'" =>
assertAllBytes("Long shift ops not supported", ctx, params)
val (l, r, 1) = assertBinary(ctx, params)
DecimalBuiltIns.compileByteShiftRight(ctx, l, r, rotate = false)
case "<" =>
// TODO: signed
val (l, r, size, signed) = assertComparison(ctx, params)
@ -1018,6 +1026,28 @@ object MlCompiler {
BuiltIns.compileInPlaceWordOrLongShiftOps(ctx, v, r, aslRatherThanLsr = false)
}
}
case "<<'=" =>
val (l, r, size) = assertAssignmentLike(ctx, params)
size match {
case 1 =>
DecimalBuiltIns.compileByteShiftLeft(ctx, l, r, rotate = false) ++ compileByteStorage(ctx, Register.A, l)
case i if i >= 2 =>
l match {
case v: LhsExpression =>
DecimalBuiltIns.compileInPlaceLongShiftLeft(ctx, v, r)
}
}
case ">>'=" =>
val (l, r, size) = assertAssignmentLike(ctx, params)
size match {
case 1 =>
DecimalBuiltIns.compileByteShiftRight(ctx, l, r, rotate = false) ++ compileByteStorage(ctx, Register.A, l)
case i if i >= 2 =>
l match {
case v: LhsExpression =>
DecimalBuiltIns.compileInPlaceLongShiftRight(ctx, v, r)
}
}
case "*=" =>
assertAllBytes("Long multiplication not supported", ctx, params)
val (l, r, 1) = assertAssignmentLike(ctx, params)

View File

@ -1,6 +1,6 @@
package millfork.test
import millfork.test.emu.EmuBenchmarkRun
import millfork.test.emu.{EmuBenchmarkRun, EmuUnoptimizedRun}
import org.scalatest.{FunSuite, Matchers}
/**
@ -98,4 +98,84 @@ class ByteDecimalMathSuite extends FunSuite with Matchers {
| byte addDecimalTwice(byte a, byte b) { return (a + b) +' (a + b) }
""".stripMargin)(_.readByte(0xc000) should equal(0x36))
}
test("Decimal left shift test") {
val m = EmuUnoptimizedRun(
"""
| byte output @$c000
| void main () {
| byte n
| n = nine()
| output = n <<' 2
| }
| byte nine() { return 9 }
""".stripMargin)
m.readByte(0xc000) should equal(0x36)
}
test("Decimal left shift test 2") {
val m = EmuUnoptimizedRun(
"""
| byte output @$c000
| void main () {
| output = nine()
| output <<'= 2
| }
| byte nine() { return 9 }
""".stripMargin)
m.readByte(0xc000) should equal(0x36)
}
test("Decimal left shift test 3") {
val m = EmuUnoptimizedRun(
"""
| word output @$c000
| void main () {
| output = nine()
| output <<'= 2
| }
| byte nine() { return $91 }
""".stripMargin)
m.readWord(0xc000) should equal(0x364)
}
test("Decimal right shift test") {
val m = EmuUnoptimizedRun(
"""
| byte output @$c000
| void main () {
| byte n
| n = thirty_six()
| output = n >>' 2
| }
| byte thirty_six() { return $36 }
""".stripMargin)
m.readByte(0xc000) should equal(9)
}
test("Decimal right shift test 2") {
val m = EmuUnoptimizedRun(
"""
| byte output @$c000
| void main () {
| output = thirty_six()
| output >>'= 2
| }
| byte thirty_six() { return $36 }
""".stripMargin)
m.readByte(0xc000) should equal(9)
}
test("Decimal right shift test 3") {
val m = EmuUnoptimizedRun(
"""
| word output @$c000
| void main () {
| output = thirty_six()
| output >>'= 2
| }
| word thirty_six() { return $364 }
""".stripMargin)
m.readWord(0xc000) should equal(0x91)
}
}