1
0
mirror of https://github.com/KarolS/millfork.git synced 2024-06-09 16:29:34 +00:00

Better allocator

This commit is contained in:
Karol Stasiak 2018-01-04 22:58:00 +01:00
parent 76122a2dd7
commit 92999ec490
6 changed files with 164 additions and 88 deletions

View File

@ -7,7 +7,7 @@ import millfork.assembly.Opcode
import millfork.compiler._
import millfork.error.ErrorReporting
import millfork.node._
import millfork.output.VariableAllocator
import millfork.output.{CompiledMemory, MemoryBank, VariableAllocator}
import scala.collection.mutable
@ -68,7 +68,7 @@ class Environment(val parent: Option[Environment], val prefix: String) {
case _ => Nil
}.toList
def allocateVariables(nf: Option[NormalFunction], callGraph: CallGraph, allocator: VariableAllocator, options: CompilationOptions, onEachVariable: (String, Int) => Unit): Unit = {
def allocateVariables(nf: Option[NormalFunction], mem: MemoryBank, callGraph: CallGraph, allocator: VariableAllocator, options: CompilationOptions, onEachVariable: (String, Int) => Unit): Unit = {
val b = get[Type]("byte")
val p = get[Type]("pointer")
val params = nf.fold(List[String]()) { f =>
@ -99,7 +99,7 @@ class Environment(val parent: Option[Environment], val prefix: String) {
m.sizeInBytes match {
case 2 =>
val addr =
allocator.allocatePointer(callGraph, vertex)
allocator.allocatePointer(mem, callGraph, vertex)
onEachVariable(m.name, addr)
List(
ConstantThing(m.name.stripPrefix(prefix) + "`", NumericConstant(addr, 2), p)
@ -110,13 +110,13 @@ class Environment(val parent: Option[Environment], val prefix: String) {
case 0 => Nil
case 2 =>
val addr =
allocator.allocateBytes(callGraph, vertex, options, 2)
allocator.allocateBytes(mem, callGraph, vertex, options, 2, initialized = false, writeable = true)
onEachVariable(m.name, addr)
List(
ConstantThing(m.name.stripPrefix(prefix) + "`", NumericConstant(addr, 2), p)
)
case count =>
val addr = allocator.allocateBytes(callGraph, vertex, options, count)
val addr = allocator.allocateBytes(mem, callGraph, vertex, options, count, initialized = false, writeable = true)
onEachVariable(m.name, addr)
List(
ConstantThing(m.name.stripPrefix(prefix) + "`", NumericConstant(addr, 2), p)
@ -124,7 +124,7 @@ class Environment(val parent: Option[Environment], val prefix: String) {
}
}
case f: NormalFunction =>
f.environment.allocateVariables(Some(f), callGraph, allocator, options, onEachVariable)
f.environment.allocateVariables(Some(f), mem, callGraph, allocator, options, onEachVariable)
Nil
case _ => Nil
}.toList

View File

@ -25,33 +25,34 @@ class Assembler(private val program: Program, private val rootEnv: Environment)
val mem = new CompiledMemory
val labelMap = mutable.Map[String, Int]()
val bytesToWriteLater = mutable.ListBuffer[(Int, Constant)]()
val wordsToWriteLater = mutable.ListBuffer[(Int, Constant)]()
val bytesToWriteLater = mutable.ListBuffer[(Int, Int, Constant)]()
val wordsToWriteLater = mutable.ListBuffer[(Int, Int, Constant)]()
def writeByte(bank: Int, addr: Int, value: Byte): Unit = {
if (mem.banks(bank).occupied(addr)) ErrorReporting.fatal("Overlapping objects")
mem.banks(bank).occupied(addr) = true
mem.banks(bank).initialized(addr) = true
mem.banks(bank).readable(addr) = true
mem.banks(bank).output(addr) = value.toByte
}
def writeByte(bank: Int, addr: Int, value: Constant): Unit = {
if (mem.banks(bank).occupied(addr)) ErrorReporting.fatal("Overlapping objects")
mem.banks(bank).occupied(addr) = true
mem.banks(bank).initialized(addr) = true
mem.banks(bank).readable(addr) = true
value match {
case NumericConstant(x, _) =>
if (x > 0xffff) ErrorReporting.error("Byte overflow")
mem.banks(0).output(addr) = x.toByte
case _ =>
bytesToWriteLater += addr -> value
bytesToWriteLater += ((bank, addr, value))
}
}
def writeWord(bank: Int, addr: Int, value: Constant): Unit = {
if (mem.banks(bank).occupied(addr)) ErrorReporting.fatal("Overlapping objects")
mem.banks(bank).occupied(addr) = true
mem.banks(bank).occupied(addr + 1) = true
mem.banks(bank).initialized(addr) = true
mem.banks(bank).initialized(addr + 1) = true
mem.banks(bank).readable(addr) = true
mem.banks(bank).readable(addr + 1) = true
value match {
@ -60,7 +61,7 @@ class Assembler(private val program: Program, private val rootEnv: Environment)
mem.banks(bank).output(addr) = x.toByte
mem.banks(bank).output(addr + 1) = (x >> 8).toByte
case _ =>
wordsToWriteLater += addr -> value
wordsToWriteLater += ((bank, addr, value))
}
}
@ -170,6 +171,8 @@ class Assembler(private val program: Program, private val rootEnv: Environment)
}
}
val bank0 = mem.banks(0)
env.allPreallocatables.foreach {
case InitializedArray(name, Some(NumericConstant(address, _)), items) =>
var index = address.toInt
@ -178,28 +181,40 @@ class Assembler(private val program: Program, private val rootEnv: Environment)
for (item <- items) {
writeByte(0, index, item)
assembly.append(" !byte " + item)
mem.banks(0).writeable(index) = true
bank0.occupied(index) = true
bank0.initialized(index) = true
bank0.writeable(index) = true
bank0.readable(index) = true
index += 1
}
initializedVariablesSize += items.length
case InitializedArray(name, Some(_), items) => ???
case f: NormalFunction if f.address.isDefined =>
var index = f.address.get.asInstanceOf[NumericConstant].value.toInt
val index = f.address.get.asInstanceOf[NumericConstant].value.toInt
val code = compiledFunctions(f.name)
if (code.nonEmpty) {
labelMap(f.name) = index
index = outputFunction(code, index, assembly, options)
val end = outputFunction(code, index, assembly, options)
for(i <- index until end) {
bank0.occupied(index) = true
bank0.initialized(index) = true
bank0.readable(index) = true
}
}
case _ =>
}
var index = platform.org
var justAfterCode = platform.org
val allocator = platform.allocator
allocator.notifyAboutEndOfCode(platform.org)
env.allPreallocatables.foreach {
case f: NormalFunction if f.address.isEmpty && f.name == "main" =>
val code = compiledFunctions(f.name)
if (code.nonEmpty) {
val size = code.map(_.sizeInBytes).sum
val index = allocator.allocateBytes(bank0, options, size, initialized = true, writeable = false)
labelMap(f.name) = index
index = outputFunction(code, index, assembly, options)
justAfterCode = outputFunction(code, index, assembly, options)
}
case _ =>
}
@ -207,27 +222,28 @@ class Assembler(private val program: Program, private val rootEnv: Environment)
case f: NormalFunction if f.address.isEmpty && f.name != "main" =>
val code = compiledFunctions(f.name)
if (code.nonEmpty) {
val size = code.map(_.sizeInBytes).sum
val index = allocator.allocateBytes(bank0, options, size, initialized = true, writeable = false)
labelMap(f.name) = index
index = outputFunction(code, index, assembly, options)
justAfterCode = outputFunction(code, index, assembly, options)
}
case _ =>
}
env.allPreallocatables.foreach {
case InitializedArray(name, None, items) =>
var index = allocator.allocateBytes(bank0, options, items.size, initialized = true, writeable = true)
labelMap(name) = index
assembly.append("* = $" + index.toHexString)
assembly.append(name)
for (item <- items) {
writeByte(0, index, item)
assembly.append(" !byte " + item)
mem.banks(0).writeable(index) = true
index += 1
}
initializedVariablesSize += items.length
justAfterCode = index
case m@InitializedMemoryVariable(name, None, typ, value) =>
if (options.flags(CompilationFlag.PreventJmpIndirectBug) && (index & 0xff) + typ.size > 0x100) {
index = (index & 0xffff00) + 0x100
}
var index = allocator.allocateBytes(bank0, options, typ.size, initialized = true, writeable = true)
labelMap(name) = index
val altName = m.name.stripPrefix(env.prefix) + "`"
env.things += altName -> ConstantThing(altName, NumericConstant(index, 2), env.get[Type]("pointer"))
@ -236,37 +252,34 @@ class Assembler(private val program: Program, private val rootEnv: Environment)
for (i <- 0 until typ.size) {
writeByte(0, index, value.subbyte(i))
assembly.append(" !byte " + value.subbyte(i).quickSimplify)
mem.banks(0).writeable(index) = true
index += 1
}
initializedVariablesSize += typ.size
justAfterCode = index
case _ =>
}
val allocator = platform.allocator
allocator.notifyAboutEndOfCode(index)
allocator.onEachByte = { addr =>
mem.banks(0).readable(addr) = true
mem.banks(0).writeable(addr) = true
}
env.allocateVariables(None, callGraph, allocator, options, labelMap.put)
allocator.notifyAboutEndOfCode(justAfterCode)
env.allocateVariables(None, bank0, callGraph, allocator, options, labelMap.put)
env = rootEnv.allThings
for ((addr, b) <- bytesToWriteLater) {
for ((bank, addr, b) <- bytesToWriteLater) {
val value = deepConstResolve(b)
mem.banks(0).output(addr) = value.toByte
mem.banks(bank).output(addr) = value.toByte
}
for ((addr, b) <- wordsToWriteLater) {
for ((bank, addr, b) <- wordsToWriteLater) {
val value = deepConstResolve(b)
mem.banks(0).output(addr) = value.toByte
mem.banks(0).output(addr + 1) = value.>>>(8).toByte
mem.banks(bank).output(addr) = value.toByte
mem.banks(bank).output(addr + 1) = value.>>>(8).toByte
}
val start = mem.banks(0).occupied.indexOf(true)
val end = mem.banks(0).occupied.lastIndexOf(true)
val length = end - start + 1
mem.banks(0).start = start
mem.banks(0).end = end
for (bank <- mem.banks.keys) {
val start = mem.banks(bank).initialized.indexOf(true)
val end = mem.banks(bank).initialized.lastIndexOf(true)
val length = end - start + 1
mem.banks(bank).start = start
mem.banks(bank).end = end
}
labelMap.toList.sorted.foreach { case (l, v) =>
assembly += f"$l%-30s = $$$v%04X"

View File

@ -22,6 +22,7 @@ class MemoryBank {
val output = Array.fill[Byte](1 << 16)(0)
val occupied = Array.fill(1 << 16)(false)
val initialized = Array.fill(1 << 16)(false)
val readable = Array.fill(1 << 16)(false)
val writeable = Array.fill(1 << 16)(false)
var start: Int = 0

View File

@ -11,41 +11,39 @@ import scala.collection.mutable
*/
sealed trait ByteAllocator {
protected def startAt: Int
protected def endBefore: Int
def notifyAboutEndOfCode(org: Int): Unit
def allocateBytes(count: Int, options: CompilationOptions): Int
def findFreeBytes(mem: MemoryBank, count: Int, options: CompilationOptions): Int = {
var lastFree = startAt
var counter = 0
val occupied = mem.occupied
for(i <- startAt until endBefore) {
if (occupied(i) || counter == 0 && count == 2 && i.&(0xff) == 0xff && options.flags(CompilationFlag.PreventJmpIndirectBug)) {
counter = 0
} else {
if (counter == 0) {
lastFree = i
}
counter += 1
if (counter == count) {
return lastFree
}
}
}
ErrorReporting.fatal("Out of high memory")
}
}
class UpwardByteAllocator(startAt: Int, endBefore: Int) extends ByteAllocator {
private var nextByte = startAt
def allocateBytes(count: Int, options: CompilationOptions): Int = {
if (count == 2 && (nextByte & 0xff) == 0xff && options.flag(CompilationFlag.PreventJmpIndirectBug)) nextByte += 1
val t = nextByte
nextByte += count
if (nextByte > endBefore) {
ErrorReporting.fatal("Out of high memory")
}
t
}
class UpwardByteAllocator(val startAt: Int, val endBefore: Int) extends ByteAllocator {
def notifyAboutEndOfCode(org: Int): Unit = ()
}
class AfterCodeByteAllocator(endBefore: Int) extends ByteAllocator {
var nextByte = 0x200
def allocateBytes(count: Int, options: CompilationOptions): Int = {
if (count == 2 && (nextByte & 0xff) == 0xff && options.flag(CompilationFlag.PreventJmpIndirectBug)) nextByte += 1
val t = nextByte
nextByte += count
if (nextByte > endBefore) {
ErrorReporting.fatal("Out of high memory")
}
t
}
def notifyAboutEndOfCode(org: Int): Unit = nextByte = org
class AfterCodeByteAllocator(val endBefore: Int) extends ByteAllocator {
var startAt = 0x200
def notifyAboutEndOfCode(org: Int): Unit = startAt = org
}
class VariableAllocator(private var pointers: List[Int], private val bytes: ByteAllocator) {
@ -53,30 +51,38 @@ class VariableAllocator(private var pointers: List[Int], private val bytes: Byte
private var pointerMap = mutable.Map[Int, Set[VariableVertex]]()
private var variableMap = mutable.Map[Int, mutable.Map[Int, Set[VariableVertex]]]()
var onEachByte: (Int => Unit) = _
def allocatePointer(callGraph: CallGraph, p: VariableVertex): Int = {
def allocatePointer(mem: MemoryBank, callGraph: CallGraph, p: VariableVertex): Int = {
// TODO: search for free zeropage locations
pointerMap.foreach { case (addr, alreadyThere) =>
if (alreadyThere.forall(q => callGraph.canOverlap(p, q))) {
pointerMap(addr) += p
return addr
}
}
pointers match {
case Nil =>
ErrorReporting.fatal("Out of zero-page memory")
case next :: rest =>
pointers = rest
onEachByte(next)
onEachByte(next + 1)
pointerMap(next) = Set(p)
next
}
def pickFreePointer(): Int =
pointers match {
case Nil =>
ErrorReporting.fatal("Out of zero-page memory")
case next :: rest =>
if (mem.occupied(next) || mem.occupied(next + 1)) {
pointers = rest
pickFreePointer()
} else {
pointers = rest
mem.readable(next) = true
mem.readable(next + 1) = true
mem.occupied(next) = true
mem.occupied(next + 1) = true
mem.writeable(next) = true
mem.writeable(next + 1) = true
pointerMap(next) = Set(p)
next
}
}
pickFreePointer()
}
def allocateByte(callGraph: CallGraph, p: VariableVertex, options: CompilationOptions): Int = allocateBytes(callGraph, p, options, 1)
def allocateBytes(callGraph: CallGraph, p: VariableVertex, options: CompilationOptions, count: Int): Int = {
def allocateBytes(mem: MemoryBank, callGraph: CallGraph, p: VariableVertex, options: CompilationOptions, count: Int, initialized: Boolean, writeable: Boolean): Int = {
if (!variableMap.contains(count)) {
variableMap(count) = mutable.Map()
}
@ -86,11 +92,23 @@ class VariableAllocator(private var pointers: List[Int], private val bytes: Byte
return a
}
}
val addr = bytes.allocateBytes(count, options)
(addr to (addr + count)).foreach(onEachByte)
val addr = allocateBytes(mem, options, count, initialized, writeable)
variableMap(count)(addr) = Set(p)
addr
}
def allocateBytes(mem: MemoryBank, options: CompilationOptions, count: Int, initialized: Boolean, writeable: Boolean): Int = {
val addr = bytes.findFreeBytes(mem, count, options)
ErrorReporting.trace(s"allocating $count bytes at $$${addr.toHexString}")
(addr until (addr + count)).foreach { i =>
if (mem.occupied(i)) ErrorReporting.fatal("Overlapping objects")
mem.readable(i) = true
mem.occupied(i) = true
mem.initialized(i) = initialized
mem.writeable(i) = writeable
}
addr
}
def notifyAboutEndOfCode(org: Int): Unit = bytes.notifyAboutEndOfCode(org)
}

View File

@ -16,6 +16,50 @@ class BasicSymonTest extends FunSuite with Matchers {
""".stripMargin)
}
test("Allocation test") {
val src =
"""
byte output @$c000
void main () {
function()
}
array thing @$20F = [1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]
void function() {
output = 0
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
}
"""
EmuUnoptimizedRun(src).readByte(0xc000) should equal(src.count(_ == '+'))
}
test("Byte assignment") {
EmuUnoptimizedRun(
"""

View File

@ -134,7 +134,7 @@ class EmuRun(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization],
println(";;; ---------------------------")
assembler.labelMap.foreach { case (l, addr) => println(f"$l%-15s $$$addr%04x") }
val optimizedSize = assembler.mem.banks(0).occupied.count(identity).toLong
val optimizedSize = assembler.mem.banks(0).initialized.count(identity).toLong
if (unoptimizedSize == optimizedSize) {
println(f"Size: $unoptimizedSize%5d B")
} else {