millfork/src/main/scala/millfork/assembly/opt/QuantumFlowAnalyzer.scala

426 lines
18 KiB
Scala

package millfork.assembly.opt
import millfork.assembly.{AssemblyLine, OpcodeClasses}
import millfork.env.{Label, MemoryAddressConstant, NormalFunction, NumericConstant}
import scala.collection.immutable.BitSet
/**
* @author Karol Stasiak
*/
object QCpuStatus {
val InitialStatus = QCpuStatus((for {
c <- Seq(true, false)
v <- Seq(true, false)
n <- Seq(true, false)
z <- Seq(true, false)
} yield QFlagStatus(c = c, d = false, v = v, n = n, z = z) -> QRegStatus(a = QRegStatus.AllValues, x = QRegStatus.AllValues, y = QRegStatus.AllValues, equal = RegEquality.NoEquality)).toMap)
val UnknownStatus = QCpuStatus((for {
c <- Seq(true, false)
v <- Seq(true, false)
n <- Seq(true, false)
z <- Seq(true, false)
} yield QFlagStatus(c = c, d = false, v = v, n = n, z = z) -> QRegStatus(a = QRegStatus.AllValues, x = QRegStatus.AllValues, y = QRegStatus.AllValues, equal = RegEquality.UnknownEquality)).toMap)
def gather(l: List[(QFlagStatus, QRegStatus)]) =
QCpuStatus(l.groupBy(_._1).
map { case (k, vs) => k -> vs.map(_._2).reduce(_ ++ _) }.
filterNot(_._2.isEmpty))
}
case class QCpuStatus(data: Map[QFlagStatus, QRegStatus]) {
def collapse: CpuStatus = {
val registers = data.values.reduce(_ ++ _)
def bitset(b: BitSet): Status[Int] = if (b.size == 1) SingleStatus(b.head) else AnyStatus()
def flag(f: QFlagStatus => Boolean): Status[Boolean] =
if (data.keys.forall(k => f(k))) SingleStatus(true)
else if (data.keys.forall(k => !f(k))) SingleStatus(false)
else AnyStatus()
CpuStatus(
a = bitset(registers.a),
x = bitset(registers.x),
y = bitset(registers.y),
c = flag(_.c),
d = flag(_.d),
v = flag(_.v),
z = flag(_.z),
n = flag(_.n),
)
}
def changeFlagUnconditionally(f: QFlagStatus => QFlagStatus): QCpuStatus = {
QCpuStatus.gather(data.toList.map { case (k, v) => f(k) -> v })
}
def changeFlagsInAnUnknownWay(f: QFlagStatus => QFlagStatus, g: QFlagStatus => QFlagStatus): QCpuStatus = {
QCpuStatus.gather(data.toList.flatMap { case (k, v) => List(f(k) -> v, g(k) -> v) })
}
def changeFlagsInAnUnknownWay(f: QFlagStatus => QFlagStatus, g: QFlagStatus => QFlagStatus, h: QFlagStatus => QFlagStatus): QCpuStatus = {
QCpuStatus.gather(data.toList.flatMap { case (k, v) => List(f(k) -> v, g(k) -> v, h(k) -> v) })
}
def mapRegisters(f: QRegStatus => QRegStatus): QCpuStatus = {
QCpuStatus(data.map { case (k, v) => k -> f(v) })
}
def mapRegisters(f: (QFlagStatus, QRegStatus) => QRegStatus): QCpuStatus = {
QCpuStatus(data.map { case (k, v) => k -> f(k, v) })
}
def flatMap(f: (QFlagStatus, QRegStatus) => List[(QFlagStatus, QRegStatus)]): QCpuStatus = {
QCpuStatus.gather(data.toList.flatMap { case (k, v) => f(k, v) })
}
def changeNZFromA: QCpuStatus = {
QCpuStatus.gather(data.toList.flatMap { case (k, v) =>
List(
k.copy(n = false, z = false) -> v.whereA(i => i.toByte > 0),
k.copy(n = true, z = false) -> v.whereA(i => i.toByte < 0),
k.copy(n = false, z = true) -> v.whereA(i => i.toByte == 0))
})
}
def changeNZFromX: QCpuStatus = {
QCpuStatus.gather(data.toList.flatMap { case (k, v) =>
List(
k.copy(n = false, z = false) -> v.whereX(i => i.toByte > 0),
k.copy(n = true, z = false) -> v.whereX(i => i.toByte < 0),
k.copy(n = false, z = true) -> v.whereX(i => i.toByte == 0))
})
}
def changeNZFromY: QCpuStatus = {
QCpuStatus.gather(data.toList.flatMap { case (k, v) =>
List(
k.copy(n = false, z = false) -> v.whereY(i => i.toByte > 0),
k.copy(n = true, z = false) -> v.whereY(i => i.toByte < 0),
k.copy(n = false, z = true) -> v.whereY(i => i.toByte == 0))
})
}
def ~(that: QCpuStatus): QCpuStatus = QCpuStatus.gather(this.data.toList ++ that.data.toList)
}
object QRegStatus {
val NoValues: BitSet = BitSet.empty
val AllValues: BitSet = BitSet.fromBitMask(Array(-1L, -1L, -1L, -1L))
}
object RegEquality extends Enumeration {
val NoEquality, AX, AY, XY, AXY, UnknownEquality = Value
def or(a: Value, b: Value) = {
(a, b) match {
case (UnknownEquality, _) => b
case (_, UnknownEquality) => a
case (NoEquality, _) => NoEquality
case (_, NoEquality) => NoEquality
case (_, _) if a == b => a
case (AXY, _) => b
case (_, AXY) => a
case _ => NoEquality
}
}
def afterTransfer(a: Value, b: Value) = {
(a, b) match {
case (UnknownEquality, _) => b
case (_, UnknownEquality) => a
case (NoEquality, _) => b
case (_, NoEquality) => a
case (_, _) if a == b => a
case _ => AXY
}
}
}
case class QRegStatus(a: BitSet, x: BitSet, y: BitSet, equal: RegEquality.Value) {
def isEmpty: Boolean = a.isEmpty || x.isEmpty || y.isEmpty
def ++(that: QRegStatus) = QRegStatus(
a = a ++ that.a,
x = x ++ that.x,
y = y ++ that.y,
equal = RegEquality.or(equal, that.equal))
def afterTransfer(transfer: RegEquality.Value): QRegStatus =
copy(equal = RegEquality.afterTransfer(equal, transfer))
def changeA(f: Int => Long): QRegStatus = {
val newA = a.map(i => f(i).toInt & 0xff)
val newEqual = equal match {
case RegEquality.XY => RegEquality.XY
case RegEquality.AXY => RegEquality.XY
case _ => RegEquality.NoEquality
}
QRegStatus(newA, x, y, newEqual)
}
def changeX(f: Int => Long): QRegStatus = {
val newA = a.map(i => f(i).toInt & 0xff)
val newEqual = equal match {
case RegEquality.XY => RegEquality.XY
case RegEquality.AXY => RegEquality.XY
case _ => RegEquality.NoEquality
}
QRegStatus(newA, x, y, newEqual)
}
def changeY(f: Int => Long): QRegStatus = {
val newA = a.map(i => f(i).toInt & 0xff)
val newEqual = equal match {
case RegEquality.XY => RegEquality.XY
case RegEquality.AXY => RegEquality.XY
case _ => RegEquality.NoEquality
}
QRegStatus(newA, x, y, newEqual)
}
def whereA(f: Int => Boolean): QRegStatus =
equal match {
case RegEquality.AXY =>
copy(a = a.filter(f), x = x.filter(f), y = y.filter(f))
case RegEquality.AY =>
copy(a = a.filter(f), y = y.filter(f))
case RegEquality.AX =>
copy(a = a.filter(f), x = x.filter(f))
case _ =>
copy(a = a.filter(f))
}
def whereX(f: Int => Boolean): QRegStatus =
equal match {
case RegEquality.AXY =>
copy(a = a.filter(f), x = x.filter(f), y = y.filter(f))
case RegEquality.XY =>
copy(x = x.filter(f), y = y.filter(f))
case RegEquality.AX =>
copy(a = a.filter(f), x = x.filter(f))
case _ =>
copy(x = x.filter(f))
}
def whereY(f: Int => Boolean): QRegStatus =
equal match {
case RegEquality.AXY =>
copy(a = a.filter(f), x = x.filter(f), y = y.filter(f))
case RegEquality.AY =>
copy(a = a.filter(f), y = y.filter(f))
case RegEquality.XY =>
copy(x = x.filter(f), y = y.filter(f))
case _ =>
copy(y = y.filter(f))
}
}
case class QFlagStatus(c: Boolean, d: Boolean, v: Boolean, z: Boolean, n: Boolean)
object QuantumFlowAnalyzer {
private def loBit(b: Boolean) = if (b) 1 else 0
private def hiBit(b: Boolean) = if (b) 0x80 else 0
//noinspection RedundantNewCaseClass
def analyze(f: NormalFunction, code: List[AssemblyLine]): List[QCpuStatus] = {
val flagArray = Array.fill[QCpuStatus](code.length)(QCpuStatus.UnknownStatus)
val codeArray = code.toArray
var changed = true
while (changed) {
changed = false
var currentStatus: QCpuStatus = if (f.interrupt) QCpuStatus.UnknownStatus else QCpuStatus.UnknownStatus
for (i <- codeArray.indices) {
import millfork.assembly.Opcode._
import millfork.assembly.AddrMode._
if (flagArray(i) != currentStatus) {
changed = true
flagArray(i) = currentStatus
}
codeArray(i) match {
case AssemblyLine(LABEL, _, MemoryAddressConstant(Label(l)), _) =>
val L = l
currentStatus = codeArray.indices.flatMap(j => codeArray(j) match {
case AssemblyLine(_, _, MemoryAddressConstant(Label(L)), _) => Some(flagArray(j))
case _ => None
}).fold(QCpuStatus.UnknownStatus)(_ ~ _)
case AssemblyLine(BCC, _, _, _) =>
currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(c = true))
case AssemblyLine(BCS, _, _, _) =>
currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(c = false))
case AssemblyLine(BVS, _, _, _) =>
currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(v = false))
case AssemblyLine(BVC, _, _, _) =>
currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(v = true))
case AssemblyLine(BMI, _, _, _) =>
currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(n = false))
case AssemblyLine(BPL, _, _, _) =>
currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(n = true))
case AssemblyLine(BEQ, _, _, _) =>
currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(z = false))
case AssemblyLine(BNE, _, _, _) =>
currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(z = true))
case AssemblyLine(SED, _, _, _) =>
currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(d = true))
case AssemblyLine(SEC, _, _, _) =>
currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(c = true))
case AssemblyLine(CLD, _, _, _) =>
currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(d = false))
case AssemblyLine(CLC, _, _, _) =>
currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(c = false))
case AssemblyLine(CLV, _, _, _) =>
currentStatus = currentStatus.changeFlagUnconditionally(f => f.copy(v = false))
case AssemblyLine(JSR, _, _, _) =>
currentStatus = QCpuStatus.InitialStatus
case AssemblyLine(LDX, Immediate, NumericConstant(n, _), _) =>
currentStatus = currentStatus.mapRegisters(r => r.changeX(_ => n)).changeNZFromX
case AssemblyLine(LDY, Immediate, NumericConstant(n, _), _) =>
currentStatus = currentStatus.mapRegisters(r => r.changeY(_ => n)).changeNZFromY
case AssemblyLine(LDA, Immediate, NumericConstant(n, _), _) =>
currentStatus = currentStatus.mapRegisters(r => r.changeA(_ => n)).changeNZFromA
case AssemblyLine(LAX, Immediate, NumericConstant(n, _), _) =>
currentStatus = currentStatus.mapRegisters(r => r.changeA(_ => n).changeX(_ => n).afterTransfer(RegEquality.AX)).changeNZFromA
case AssemblyLine(EOR, Immediate, NumericConstant(n, _), _) =>
currentStatus = currentStatus.mapRegisters(r => r.changeA(_ ^ n)).changeNZFromA
case AssemblyLine(AND, Immediate, NumericConstant(n, _), _) =>
currentStatus = currentStatus.mapRegisters(r => r.changeA(_ & n)).changeNZFromA
case AssemblyLine(ANC, Immediate, NumericConstant(n, _), _) =>
currentStatus = currentStatus.mapRegisters(r => r.changeA(_ & n)).changeNZFromA.changeFlagUnconditionally(f => f.copy(c = f.z))
case AssemblyLine(ORA, Immediate, NumericConstant(n, _), _) =>
currentStatus = currentStatus.mapRegisters(r => r.changeA(_ | n)).changeNZFromA
case AssemblyLine(INX, Implied, _, _) =>
currentStatus = currentStatus.mapRegisters(r => r.changeX(_ + 1)).changeNZFromX
case AssemblyLine(DEX, Implied, _, _) =>
currentStatus = currentStatus.mapRegisters(r => r.changeX(_ - 1)).changeNZFromX
case AssemblyLine(INY, Implied, _, _) =>
currentStatus = currentStatus.mapRegisters(r => r.changeY(_ - 1)).changeNZFromY
case AssemblyLine(DEY, Implied, _, _) =>
currentStatus = currentStatus.mapRegisters(r => r.changeY(_ - 1)).changeNZFromY
case AssemblyLine(TAX, _, _, _) =>
currentStatus = currentStatus.mapRegisters(r => r.copy(x = r.a).afterTransfer(RegEquality.AX)).changeNZFromX
case AssemblyLine(TXA, _, _, _) =>
currentStatus = currentStatus.mapRegisters(r => r.copy(a = r.x).afterTransfer(RegEquality.AX)).changeNZFromA
case AssemblyLine(TAY, _, _, _) =>
currentStatus = currentStatus.mapRegisters(r => r.copy(y = r.a).afterTransfer(RegEquality.AY)).changeNZFromY
case AssemblyLine(TYA, _, _, _) =>
currentStatus = currentStatus.mapRegisters(r => r.copy(a = r.y).afterTransfer(RegEquality.AY)).changeNZFromA
case AssemblyLine(ROL, Implied, _, _) =>
currentStatus = currentStatus.flatMap((f, r) => List(
f.copy(c = true) -> r.whereA(a => (a & 0x80) != 0).changeA(a => a * 2 + loBit(f.c)),
f.copy(c = false) -> r.whereA(a => (a & 0x80) == 0).changeA(a => a * 2 + loBit(f.c)),
)).changeNZFromA
case AssemblyLine(ROR, Implied, _, _) =>
currentStatus = currentStatus.flatMap((f, r) => List(
f.copy(c = true) -> r.whereA(a => (a & 1) != 0).changeA(a => (a >>> 2) & 0x7f | hiBit(f.c)),
f.copy(c = false) -> r.whereA(a => (a & 1) == 0).changeA(a => (a >>> 2) & 0x7f | hiBit(f.c)),
)).changeNZFromA
case AssemblyLine(ASL, Implied, _, _) =>
currentStatus = currentStatus.flatMap((f, r) => List(
f.copy(c = true) -> r.whereA(a => (a & 0x80) != 0).changeA(a => a * 2),
f.copy(c = false) -> r.whereA(a => (a & 0x80) == 0).changeA(a => a * 2),
)).changeNZFromA
case AssemblyLine(LSR, Implied, _, _) =>
currentStatus = currentStatus.flatMap((f, r) => List(
f.copy(c = true) -> r.whereA(a => (a & 1) != 0).changeA(a => (a >>> 2) & 0x7f),
f.copy(c = false) -> r.whereA(a => (a & 1) == 0).changeA(a => (a >>> 2) & 0x7f),
)).changeNZFromA
case AssemblyLine(ALR, Immediate, NumericConstant(n, _), _) =>
currentStatus = currentStatus.flatMap((f, r) => List(
f.copy(c = true) -> r.whereA(a => (a & n & 1) != 0).changeA(a => ((a & n) >>> 2) & 0x7f),
f.copy(c = false) -> r.whereA(a => (a & n & 1) == 0).changeA(a => ((a & n) >>> 2) & 0x7f),
)).changeNZFromA
case AssemblyLine(ADC, Immediate, NumericConstant(nn, _), _) =>
val n = nn & 0xff
currentStatus = currentStatus.flatMap((f, r) =>
if (f.d) {
val regs = r.copy(a = QRegStatus.AllValues).changeA(_.toLong)
List(
f.copy(c = false, v = false) -> regs,
f.copy(c = true, v = false) -> regs,
f.copy(c = false, v = true) -> regs,
f.copy(c = true, v = true) -> regs,
)
} else {
if (f.c) {
val regs = r.changeA(_ + n + 1)
List(
f.copy(c = false, v = false) -> regs.whereA(_ >= n),
f.copy(c = true, v = false) -> regs.whereA(_ < n),
f.copy(c = false, v = true) -> regs.whereA(_ >= n),
f.copy(c = true, v = true) -> regs.whereA(_ < n),
)
} else {
val regs = r.changeA(_ + n)
List(
f.copy(c = false, v = false) -> regs.whereA(_ > n),
f.copy(c = true, v = false) -> regs.whereA(_ <= n),
f.copy(c = false, v = true) -> regs.whereA(_ > n),
f.copy(c = true, v = true) -> regs.whereA(_ <= n),
)
}
}
).changeNZFromA
case AssemblyLine(SBC, Immediate, NumericConstant(n, _), _) =>
currentStatus = currentStatus.flatMap((f, r) =>
if (f.d) {
val regs = r.copy(a = QRegStatus.AllValues).changeA(_.toLong)
// TODO: guess the carry flag correctly
List(
f.copy(c = false, v = false) -> regs,
f.copy(c = true, v = false) -> regs,
f.copy(c = false, v = true) -> regs,
f.copy(c = true, v = true) -> regs,
)
} else {
val regs = if (f.c) r.changeA(_ - n) else r.changeA(_ - n - 1)
List(
f.copy(c = false, v = false) -> regs,
f.copy(c = true, v = false) -> regs,
f.copy(c = false, v = true) -> regs,
f.copy(c = true, v = true) -> regs,
)
}
).changeNZFromA
case AssemblyLine(opcode, addrMode, parameter, _) =>
if (OpcodeClasses.ChangesX(opcode)) currentStatus = currentStatus.mapRegisters(r => r.copy(x = QRegStatus.AllValues))
if (OpcodeClasses.ChangesY(opcode)) currentStatus = currentStatus.mapRegisters(r => r.copy(y = QRegStatus.AllValues))
if (OpcodeClasses.ChangesAAlways(opcode)) currentStatus = currentStatus.mapRegisters(r => r.copy(a = QRegStatus.AllValues))
if (addrMode == Implied && OpcodeClasses.ChangesAIfImplied(opcode)) currentStatus = currentStatus.mapRegisters(r => r.copy(a = QRegStatus.AllValues))
if (OpcodeClasses.ChangesNAndZ(opcode)) currentStatus = currentStatus.changeFlagsInAnUnknownWay(
_.copy(n = false, z = false),
_.copy(n = true, z = false),
_.copy(n = false, z = true))
if (OpcodeClasses.ChangesC(opcode)) currentStatus = currentStatus.changeFlagsInAnUnknownWay(_.copy(c = false), _.copy(c = true))
if (OpcodeClasses.ChangesV(opcode)) currentStatus = currentStatus.changeFlagsInAnUnknownWay(_.copy(v = false), _.copy(v = true))
if (opcode == CMP || opcode == CPX || opcode == CPY) {
if (addrMode == Immediate) parameter match {
case NumericConstant(0, _) => currentStatus = currentStatus.changeFlagUnconditionally(_.copy(c = true))
case _ => ()
}
}
}
}
// flagArray.zip(codeArray).foreach{
// case (fl, y) => if (y.isPrintable) println(f"$fl%-32s $y%-32s")
// }
// println("---------------------")
}
flagArray.toList
}
}