1
0
mirror of https://github.com/KarolS/millfork.git synced 2024-08-28 05:29:06 +00:00

Better allocator

This commit is contained in:
Karol Stasiak 2018-01-04 22:58:00 +01:00
parent 76122a2dd7
commit 92999ec490
6 changed files with 164 additions and 88 deletions

View File

@ -7,7 +7,7 @@ import millfork.assembly.Opcode
import millfork.compiler._ import millfork.compiler._
import millfork.error.ErrorReporting import millfork.error.ErrorReporting
import millfork.node._ import millfork.node._
import millfork.output.VariableAllocator import millfork.output.{CompiledMemory, MemoryBank, VariableAllocator}
import scala.collection.mutable import scala.collection.mutable
@ -68,7 +68,7 @@ class Environment(val parent: Option[Environment], val prefix: String) {
case _ => Nil case _ => Nil
}.toList }.toList
def allocateVariables(nf: Option[NormalFunction], callGraph: CallGraph, allocator: VariableAllocator, options: CompilationOptions, onEachVariable: (String, Int) => Unit): Unit = { def allocateVariables(nf: Option[NormalFunction], mem: MemoryBank, callGraph: CallGraph, allocator: VariableAllocator, options: CompilationOptions, onEachVariable: (String, Int) => Unit): Unit = {
val b = get[Type]("byte") val b = get[Type]("byte")
val p = get[Type]("pointer") val p = get[Type]("pointer")
val params = nf.fold(List[String]()) { f => val params = nf.fold(List[String]()) { f =>
@ -99,7 +99,7 @@ class Environment(val parent: Option[Environment], val prefix: String) {
m.sizeInBytes match { m.sizeInBytes match {
case 2 => case 2 =>
val addr = val addr =
allocator.allocatePointer(callGraph, vertex) allocator.allocatePointer(mem, callGraph, vertex)
onEachVariable(m.name, addr) onEachVariable(m.name, addr)
List( List(
ConstantThing(m.name.stripPrefix(prefix) + "`", NumericConstant(addr, 2), p) ConstantThing(m.name.stripPrefix(prefix) + "`", NumericConstant(addr, 2), p)
@ -110,13 +110,13 @@ class Environment(val parent: Option[Environment], val prefix: String) {
case 0 => Nil case 0 => Nil
case 2 => case 2 =>
val addr = val addr =
allocator.allocateBytes(callGraph, vertex, options, 2) allocator.allocateBytes(mem, callGraph, vertex, options, 2, initialized = false, writeable = true)
onEachVariable(m.name, addr) onEachVariable(m.name, addr)
List( List(
ConstantThing(m.name.stripPrefix(prefix) + "`", NumericConstant(addr, 2), p) ConstantThing(m.name.stripPrefix(prefix) + "`", NumericConstant(addr, 2), p)
) )
case count => case count =>
val addr = allocator.allocateBytes(callGraph, vertex, options, count) val addr = allocator.allocateBytes(mem, callGraph, vertex, options, count, initialized = false, writeable = true)
onEachVariable(m.name, addr) onEachVariable(m.name, addr)
List( List(
ConstantThing(m.name.stripPrefix(prefix) + "`", NumericConstant(addr, 2), p) ConstantThing(m.name.stripPrefix(prefix) + "`", NumericConstant(addr, 2), p)
@ -124,7 +124,7 @@ class Environment(val parent: Option[Environment], val prefix: String) {
} }
} }
case f: NormalFunction => case f: NormalFunction =>
f.environment.allocateVariables(Some(f), callGraph, allocator, options, onEachVariable) f.environment.allocateVariables(Some(f), mem, callGraph, allocator, options, onEachVariable)
Nil Nil
case _ => Nil case _ => Nil
}.toList }.toList

View File

@ -25,33 +25,34 @@ class Assembler(private val program: Program, private val rootEnv: Environment)
val mem = new CompiledMemory val mem = new CompiledMemory
val labelMap = mutable.Map[String, Int]() val labelMap = mutable.Map[String, Int]()
val bytesToWriteLater = mutable.ListBuffer[(Int, Constant)]() val bytesToWriteLater = mutable.ListBuffer[(Int, Int, Constant)]()
val wordsToWriteLater = mutable.ListBuffer[(Int, Constant)]() val wordsToWriteLater = mutable.ListBuffer[(Int, Int, Constant)]()
def writeByte(bank: Int, addr: Int, value: Byte): Unit = { def writeByte(bank: Int, addr: Int, value: Byte): Unit = {
if (mem.banks(bank).occupied(addr)) ErrorReporting.fatal("Overlapping objects")
mem.banks(bank).occupied(addr) = true mem.banks(bank).occupied(addr) = true
mem.banks(bank).initialized(addr) = true
mem.banks(bank).readable(addr) = true mem.banks(bank).readable(addr) = true
mem.banks(bank).output(addr) = value.toByte mem.banks(bank).output(addr) = value.toByte
} }
def writeByte(bank: Int, addr: Int, value: Constant): Unit = { def writeByte(bank: Int, addr: Int, value: Constant): Unit = {
if (mem.banks(bank).occupied(addr)) ErrorReporting.fatal("Overlapping objects")
mem.banks(bank).occupied(addr) = true mem.banks(bank).occupied(addr) = true
mem.banks(bank).initialized(addr) = true
mem.banks(bank).readable(addr) = true mem.banks(bank).readable(addr) = true
value match { value match {
case NumericConstant(x, _) => case NumericConstant(x, _) =>
if (x > 0xffff) ErrorReporting.error("Byte overflow") if (x > 0xffff) ErrorReporting.error("Byte overflow")
mem.banks(0).output(addr) = x.toByte mem.banks(0).output(addr) = x.toByte
case _ => case _ =>
bytesToWriteLater += addr -> value bytesToWriteLater += ((bank, addr, value))
} }
} }
def writeWord(bank: Int, addr: Int, value: Constant): Unit = { def writeWord(bank: Int, addr: Int, value: Constant): Unit = {
if (mem.banks(bank).occupied(addr)) ErrorReporting.fatal("Overlapping objects")
mem.banks(bank).occupied(addr) = true mem.banks(bank).occupied(addr) = true
mem.banks(bank).occupied(addr + 1) = true mem.banks(bank).occupied(addr + 1) = true
mem.banks(bank).initialized(addr) = true
mem.banks(bank).initialized(addr + 1) = true
mem.banks(bank).readable(addr) = true mem.banks(bank).readable(addr) = true
mem.banks(bank).readable(addr + 1) = true mem.banks(bank).readable(addr + 1) = true
value match { value match {
@ -60,7 +61,7 @@ class Assembler(private val program: Program, private val rootEnv: Environment)
mem.banks(bank).output(addr) = x.toByte mem.banks(bank).output(addr) = x.toByte
mem.banks(bank).output(addr + 1) = (x >> 8).toByte mem.banks(bank).output(addr + 1) = (x >> 8).toByte
case _ => case _ =>
wordsToWriteLater += addr -> value wordsToWriteLater += ((bank, addr, value))
} }
} }
@ -170,6 +171,8 @@ class Assembler(private val program: Program, private val rootEnv: Environment)
} }
} }
val bank0 = mem.banks(0)
env.allPreallocatables.foreach { env.allPreallocatables.foreach {
case InitializedArray(name, Some(NumericConstant(address, _)), items) => case InitializedArray(name, Some(NumericConstant(address, _)), items) =>
var index = address.toInt var index = address.toInt
@ -178,28 +181,40 @@ class Assembler(private val program: Program, private val rootEnv: Environment)
for (item <- items) { for (item <- items) {
writeByte(0, index, item) writeByte(0, index, item)
assembly.append(" !byte " + item) assembly.append(" !byte " + item)
mem.banks(0).writeable(index) = true bank0.occupied(index) = true
bank0.initialized(index) = true
bank0.writeable(index) = true
bank0.readable(index) = true
index += 1 index += 1
} }
initializedVariablesSize += items.length initializedVariablesSize += items.length
case InitializedArray(name, Some(_), items) => ??? case InitializedArray(name, Some(_), items) => ???
case f: NormalFunction if f.address.isDefined => case f: NormalFunction if f.address.isDefined =>
var index = f.address.get.asInstanceOf[NumericConstant].value.toInt val index = f.address.get.asInstanceOf[NumericConstant].value.toInt
val code = compiledFunctions(f.name) val code = compiledFunctions(f.name)
if (code.nonEmpty) { if (code.nonEmpty) {
labelMap(f.name) = index labelMap(f.name) = index
index = outputFunction(code, index, assembly, options) val end = outputFunction(code, index, assembly, options)
for(i <- index until end) {
bank0.occupied(index) = true
bank0.initialized(index) = true
bank0.readable(index) = true
}
} }
case _ => case _ =>
} }
var index = platform.org var justAfterCode = platform.org
val allocator = platform.allocator
allocator.notifyAboutEndOfCode(platform.org)
env.allPreallocatables.foreach { env.allPreallocatables.foreach {
case f: NormalFunction if f.address.isEmpty && f.name == "main" => case f: NormalFunction if f.address.isEmpty && f.name == "main" =>
val code = compiledFunctions(f.name) val code = compiledFunctions(f.name)
if (code.nonEmpty) { if (code.nonEmpty) {
val size = code.map(_.sizeInBytes).sum
val index = allocator.allocateBytes(bank0, options, size, initialized = true, writeable = false)
labelMap(f.name) = index labelMap(f.name) = index
index = outputFunction(code, index, assembly, options) justAfterCode = outputFunction(code, index, assembly, options)
} }
case _ => case _ =>
} }
@ -207,27 +222,28 @@ class Assembler(private val program: Program, private val rootEnv: Environment)
case f: NormalFunction if f.address.isEmpty && f.name != "main" => case f: NormalFunction if f.address.isEmpty && f.name != "main" =>
val code = compiledFunctions(f.name) val code = compiledFunctions(f.name)
if (code.nonEmpty) { if (code.nonEmpty) {
val size = code.map(_.sizeInBytes).sum
val index = allocator.allocateBytes(bank0, options, size, initialized = true, writeable = false)
labelMap(f.name) = index labelMap(f.name) = index
index = outputFunction(code, index, assembly, options) justAfterCode = outputFunction(code, index, assembly, options)
} }
case _ => case _ =>
} }
env.allPreallocatables.foreach { env.allPreallocatables.foreach {
case InitializedArray(name, None, items) => case InitializedArray(name, None, items) =>
var index = allocator.allocateBytes(bank0, options, items.size, initialized = true, writeable = true)
labelMap(name) = index labelMap(name) = index
assembly.append("* = $" + index.toHexString) assembly.append("* = $" + index.toHexString)
assembly.append(name) assembly.append(name)
for (item <- items) { for (item <- items) {
writeByte(0, index, item) writeByte(0, index, item)
assembly.append(" !byte " + item) assembly.append(" !byte " + item)
mem.banks(0).writeable(index) = true
index += 1 index += 1
} }
initializedVariablesSize += items.length initializedVariablesSize += items.length
justAfterCode = index
case m@InitializedMemoryVariable(name, None, typ, value) => case m@InitializedMemoryVariable(name, None, typ, value) =>
if (options.flags(CompilationFlag.PreventJmpIndirectBug) && (index & 0xff) + typ.size > 0x100) { var index = allocator.allocateBytes(bank0, options, typ.size, initialized = true, writeable = true)
index = (index & 0xffff00) + 0x100
}
labelMap(name) = index labelMap(name) = index
val altName = m.name.stripPrefix(env.prefix) + "`" val altName = m.name.stripPrefix(env.prefix) + "`"
env.things += altName -> ConstantThing(altName, NumericConstant(index, 2), env.get[Type]("pointer")) env.things += altName -> ConstantThing(altName, NumericConstant(index, 2), env.get[Type]("pointer"))
@ -236,37 +252,34 @@ class Assembler(private val program: Program, private val rootEnv: Environment)
for (i <- 0 until typ.size) { for (i <- 0 until typ.size) {
writeByte(0, index, value.subbyte(i)) writeByte(0, index, value.subbyte(i))
assembly.append(" !byte " + value.subbyte(i).quickSimplify) assembly.append(" !byte " + value.subbyte(i).quickSimplify)
mem.banks(0).writeable(index) = true
index += 1 index += 1
} }
initializedVariablesSize += typ.size initializedVariablesSize += typ.size
justAfterCode = index
case _ => case _ =>
} }
val allocator = platform.allocator allocator.notifyAboutEndOfCode(justAfterCode)
allocator.notifyAboutEndOfCode(index) env.allocateVariables(None, bank0, callGraph, allocator, options, labelMap.put)
allocator.onEachByte = { addr =>
mem.banks(0).readable(addr) = true
mem.banks(0).writeable(addr) = true
}
env.allocateVariables(None, callGraph, allocator, options, labelMap.put)
env = rootEnv.allThings env = rootEnv.allThings
for ((addr, b) <- bytesToWriteLater) { for ((bank, addr, b) <- bytesToWriteLater) {
val value = deepConstResolve(b) val value = deepConstResolve(b)
mem.banks(0).output(addr) = value.toByte mem.banks(bank).output(addr) = value.toByte
} }
for ((addr, b) <- wordsToWriteLater) { for ((bank, addr, b) <- wordsToWriteLater) {
val value = deepConstResolve(b) val value = deepConstResolve(b)
mem.banks(0).output(addr) = value.toByte mem.banks(bank).output(addr) = value.toByte
mem.banks(0).output(addr + 1) = value.>>>(8).toByte mem.banks(bank).output(addr + 1) = value.>>>(8).toByte
} }
val start = mem.banks(0).occupied.indexOf(true) for (bank <- mem.banks.keys) {
val end = mem.banks(0).occupied.lastIndexOf(true) val start = mem.banks(bank).initialized.indexOf(true)
val length = end - start + 1 val end = mem.banks(bank).initialized.lastIndexOf(true)
mem.banks(0).start = start val length = end - start + 1
mem.banks(0).end = end mem.banks(bank).start = start
mem.banks(bank).end = end
}
labelMap.toList.sorted.foreach { case (l, v) => labelMap.toList.sorted.foreach { case (l, v) =>
assembly += f"$l%-30s = $$$v%04X" assembly += f"$l%-30s = $$$v%04X"

View File

@ -22,6 +22,7 @@ class MemoryBank {
val output = Array.fill[Byte](1 << 16)(0) val output = Array.fill[Byte](1 << 16)(0)
val occupied = Array.fill(1 << 16)(false) val occupied = Array.fill(1 << 16)(false)
val initialized = Array.fill(1 << 16)(false)
val readable = Array.fill(1 << 16)(false) val readable = Array.fill(1 << 16)(false)
val writeable = Array.fill(1 << 16)(false) val writeable = Array.fill(1 << 16)(false)
var start: Int = 0 var start: Int = 0

View File

@ -11,41 +11,39 @@ import scala.collection.mutable
*/ */
sealed trait ByteAllocator { sealed trait ByteAllocator {
protected def startAt: Int
protected def endBefore: Int
def notifyAboutEndOfCode(org: Int): Unit def notifyAboutEndOfCode(org: Int): Unit
def allocateBytes(count: Int, options: CompilationOptions): Int def findFreeBytes(mem: MemoryBank, count: Int, options: CompilationOptions): Int = {
var lastFree = startAt
var counter = 0
val occupied = mem.occupied
for(i <- startAt until endBefore) {
if (occupied(i) || counter == 0 && count == 2 && i.&(0xff) == 0xff && options.flags(CompilationFlag.PreventJmpIndirectBug)) {
counter = 0
} else {
if (counter == 0) {
lastFree = i
}
counter += 1
if (counter == count) {
return lastFree
}
}
}
ErrorReporting.fatal("Out of high memory")
}
} }
class UpwardByteAllocator(startAt: Int, endBefore: Int) extends ByteAllocator { class UpwardByteAllocator(val startAt: Int, val endBefore: Int) extends ByteAllocator {
private var nextByte = startAt
def allocateBytes(count: Int, options: CompilationOptions): Int = {
if (count == 2 && (nextByte & 0xff) == 0xff && options.flag(CompilationFlag.PreventJmpIndirectBug)) nextByte += 1
val t = nextByte
nextByte += count
if (nextByte > endBefore) {
ErrorReporting.fatal("Out of high memory")
}
t
}
def notifyAboutEndOfCode(org: Int): Unit = () def notifyAboutEndOfCode(org: Int): Unit = ()
} }
class AfterCodeByteAllocator(endBefore: Int) extends ByteAllocator { class AfterCodeByteAllocator(val endBefore: Int) extends ByteAllocator {
var nextByte = 0x200 var startAt = 0x200
def notifyAboutEndOfCode(org: Int): Unit = startAt = org
def allocateBytes(count: Int, options: CompilationOptions): Int = {
if (count == 2 && (nextByte & 0xff) == 0xff && options.flag(CompilationFlag.PreventJmpIndirectBug)) nextByte += 1
val t = nextByte
nextByte += count
if (nextByte > endBefore) {
ErrorReporting.fatal("Out of high memory")
}
t
}
def notifyAboutEndOfCode(org: Int): Unit = nextByte = org
} }
class VariableAllocator(private var pointers: List[Int], private val bytes: ByteAllocator) { class VariableAllocator(private var pointers: List[Int], private val bytes: ByteAllocator) {
@ -53,30 +51,38 @@ class VariableAllocator(private var pointers: List[Int], private val bytes: Byte
private var pointerMap = mutable.Map[Int, Set[VariableVertex]]() private var pointerMap = mutable.Map[Int, Set[VariableVertex]]()
private var variableMap = mutable.Map[Int, mutable.Map[Int, Set[VariableVertex]]]() private var variableMap = mutable.Map[Int, mutable.Map[Int, Set[VariableVertex]]]()
var onEachByte: (Int => Unit) = _ def allocatePointer(mem: MemoryBank, callGraph: CallGraph, p: VariableVertex): Int = {
// TODO: search for free zeropage locations
def allocatePointer(callGraph: CallGraph, p: VariableVertex): Int = {
pointerMap.foreach { case (addr, alreadyThere) => pointerMap.foreach { case (addr, alreadyThere) =>
if (alreadyThere.forall(q => callGraph.canOverlap(p, q))) { if (alreadyThere.forall(q => callGraph.canOverlap(p, q))) {
pointerMap(addr) += p pointerMap(addr) += p
return addr return addr
} }
} }
pointers match { def pickFreePointer(): Int =
case Nil => pointers match {
ErrorReporting.fatal("Out of zero-page memory") case Nil =>
case next :: rest => ErrorReporting.fatal("Out of zero-page memory")
pointers = rest case next :: rest =>
onEachByte(next) if (mem.occupied(next) || mem.occupied(next + 1)) {
onEachByte(next + 1) pointers = rest
pointerMap(next) = Set(p) pickFreePointer()
next } else {
} pointers = rest
mem.readable(next) = true
mem.readable(next + 1) = true
mem.occupied(next) = true
mem.occupied(next + 1) = true
mem.writeable(next) = true
mem.writeable(next + 1) = true
pointerMap(next) = Set(p)
next
}
}
pickFreePointer()
} }
def allocateByte(callGraph: CallGraph, p: VariableVertex, options: CompilationOptions): Int = allocateBytes(callGraph, p, options, 1) def allocateBytes(mem: MemoryBank, callGraph: CallGraph, p: VariableVertex, options: CompilationOptions, count: Int, initialized: Boolean, writeable: Boolean): Int = {
def allocateBytes(callGraph: CallGraph, p: VariableVertex, options: CompilationOptions, count: Int): Int = {
if (!variableMap.contains(count)) { if (!variableMap.contains(count)) {
variableMap(count) = mutable.Map() variableMap(count) = mutable.Map()
} }
@ -86,11 +92,23 @@ class VariableAllocator(private var pointers: List[Int], private val bytes: Byte
return a return a
} }
} }
val addr = bytes.allocateBytes(count, options) val addr = allocateBytes(mem, options, count, initialized, writeable)
(addr to (addr + count)).foreach(onEachByte)
variableMap(count)(addr) = Set(p) variableMap(count)(addr) = Set(p)
addr addr
} }
def allocateBytes(mem: MemoryBank, options: CompilationOptions, count: Int, initialized: Boolean, writeable: Boolean): Int = {
val addr = bytes.findFreeBytes(mem, count, options)
ErrorReporting.trace(s"allocating $count bytes at $$${addr.toHexString}")
(addr until (addr + count)).foreach { i =>
if (mem.occupied(i)) ErrorReporting.fatal("Overlapping objects")
mem.readable(i) = true
mem.occupied(i) = true
mem.initialized(i) = initialized
mem.writeable(i) = writeable
}
addr
}
def notifyAboutEndOfCode(org: Int): Unit = bytes.notifyAboutEndOfCode(org) def notifyAboutEndOfCode(org: Int): Unit = bytes.notifyAboutEndOfCode(org)
} }

View File

@ -16,6 +16,50 @@ class BasicSymonTest extends FunSuite with Matchers {
""".stripMargin) """.stripMargin)
} }
test("Allocation test") {
val src =
"""
byte output @$c000
void main () {
function()
}
array thing @$20F = [1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]
void function() {
output = 0
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
output += 1
}
"""
EmuUnoptimizedRun(src).readByte(0xc000) should equal(src.count(_ == '+'))
}
test("Byte assignment") { test("Byte assignment") {
EmuUnoptimizedRun( EmuUnoptimizedRun(
""" """

View File

@ -134,7 +134,7 @@ class EmuRun(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization],
println(";;; ---------------------------") println(";;; ---------------------------")
assembler.labelMap.foreach { case (l, addr) => println(f"$l%-15s $$$addr%04x") } assembler.labelMap.foreach { case (l, addr) => println(f"$l%-15s $$$addr%04x") }
val optimizedSize = assembler.mem.banks(0).occupied.count(identity).toLong val optimizedSize = assembler.mem.banks(0).initialized.count(identity).toLong
if (unoptimizedSize == optimizedSize) { if (unoptimizedSize == optimizedSize) {
println(f"Size: $unoptimizedSize%5d B") println(f"Size: $unoptimizedSize%5d B")
} else { } else {