From 92999ec490653df4d154aad5f8dd78a61af14361 Mon Sep 17 00:00:00 2001 From: Karol Stasiak Date: Thu, 4 Jan 2018 22:58:00 +0100 Subject: [PATCH] Better allocator --- src/main/scala/millfork/env/Environment.scala | 12 +- .../scala/millfork/output/Assembler.scala | 83 +++++++------ .../millfork/output/CompiledMemory.scala | 1 + .../millfork/output/VariableAllocator.scala | 110 ++++++++++-------- .../scala/millfork/test/BasicSymonTest.scala | 44 +++++++ src/test/scala/millfork/test/emu/EmuRun.scala | 2 +- 6 files changed, 164 insertions(+), 88 deletions(-) diff --git a/src/main/scala/millfork/env/Environment.scala b/src/main/scala/millfork/env/Environment.scala index fb05d295..917d3a13 100644 --- a/src/main/scala/millfork/env/Environment.scala +++ b/src/main/scala/millfork/env/Environment.scala @@ -7,7 +7,7 @@ import millfork.assembly.Opcode import millfork.compiler._ import millfork.error.ErrorReporting import millfork.node._ -import millfork.output.VariableAllocator +import millfork.output.{CompiledMemory, MemoryBank, VariableAllocator} import scala.collection.mutable @@ -68,7 +68,7 @@ class Environment(val parent: Option[Environment], val prefix: String) { case _ => Nil }.toList - def allocateVariables(nf: Option[NormalFunction], callGraph: CallGraph, allocator: VariableAllocator, options: CompilationOptions, onEachVariable: (String, Int) => Unit): Unit = { + def allocateVariables(nf: Option[NormalFunction], mem: MemoryBank, callGraph: CallGraph, allocator: VariableAllocator, options: CompilationOptions, onEachVariable: (String, Int) => Unit): Unit = { val b = get[Type]("byte") val p = get[Type]("pointer") val params = nf.fold(List[String]()) { f => @@ -99,7 +99,7 @@ class Environment(val parent: Option[Environment], val prefix: String) { m.sizeInBytes match { case 2 => val addr = - allocator.allocatePointer(callGraph, vertex) + allocator.allocatePointer(mem, callGraph, vertex) onEachVariable(m.name, addr) List( ConstantThing(m.name.stripPrefix(prefix) + "`", NumericConstant(addr, 2), p) @@ -110,13 +110,13 @@ class Environment(val parent: Option[Environment], val prefix: String) { case 0 => Nil case 2 => val addr = - allocator.allocateBytes(callGraph, vertex, options, 2) + allocator.allocateBytes(mem, callGraph, vertex, options, 2, initialized = false, writeable = true) onEachVariable(m.name, addr) List( ConstantThing(m.name.stripPrefix(prefix) + "`", NumericConstant(addr, 2), p) ) case count => - val addr = allocator.allocateBytes(callGraph, vertex, options, count) + val addr = allocator.allocateBytes(mem, callGraph, vertex, options, count, initialized = false, writeable = true) onEachVariable(m.name, addr) List( ConstantThing(m.name.stripPrefix(prefix) + "`", NumericConstant(addr, 2), p) @@ -124,7 +124,7 @@ class Environment(val parent: Option[Environment], val prefix: String) { } } case f: NormalFunction => - f.environment.allocateVariables(Some(f), callGraph, allocator, options, onEachVariable) + f.environment.allocateVariables(Some(f), mem, callGraph, allocator, options, onEachVariable) Nil case _ => Nil }.toList diff --git a/src/main/scala/millfork/output/Assembler.scala b/src/main/scala/millfork/output/Assembler.scala index 7f141b83..feb1995a 100644 --- a/src/main/scala/millfork/output/Assembler.scala +++ b/src/main/scala/millfork/output/Assembler.scala @@ -25,33 +25,34 @@ class Assembler(private val program: Program, private val rootEnv: Environment) val mem = new CompiledMemory val labelMap = mutable.Map[String, Int]() - val bytesToWriteLater = mutable.ListBuffer[(Int, Constant)]() - val wordsToWriteLater = mutable.ListBuffer[(Int, Constant)]() + val bytesToWriteLater = mutable.ListBuffer[(Int, Int, Constant)]() + val wordsToWriteLater = mutable.ListBuffer[(Int, Int, Constant)]() def writeByte(bank: Int, addr: Int, value: Byte): Unit = { - if (mem.banks(bank).occupied(addr)) ErrorReporting.fatal("Overlapping objects") mem.banks(bank).occupied(addr) = true + mem.banks(bank).initialized(addr) = true mem.banks(bank).readable(addr) = true mem.banks(bank).output(addr) = value.toByte } def writeByte(bank: Int, addr: Int, value: Constant): Unit = { - if (mem.banks(bank).occupied(addr)) ErrorReporting.fatal("Overlapping objects") mem.banks(bank).occupied(addr) = true + mem.banks(bank).initialized(addr) = true mem.banks(bank).readable(addr) = true value match { case NumericConstant(x, _) => if (x > 0xffff) ErrorReporting.error("Byte overflow") mem.banks(0).output(addr) = x.toByte case _ => - bytesToWriteLater += addr -> value + bytesToWriteLater += ((bank, addr, value)) } } def writeWord(bank: Int, addr: Int, value: Constant): Unit = { - if (mem.banks(bank).occupied(addr)) ErrorReporting.fatal("Overlapping objects") mem.banks(bank).occupied(addr) = true mem.banks(bank).occupied(addr + 1) = true + mem.banks(bank).initialized(addr) = true + mem.banks(bank).initialized(addr + 1) = true mem.banks(bank).readable(addr) = true mem.banks(bank).readable(addr + 1) = true value match { @@ -60,7 +61,7 @@ class Assembler(private val program: Program, private val rootEnv: Environment) mem.banks(bank).output(addr) = x.toByte mem.banks(bank).output(addr + 1) = (x >> 8).toByte case _ => - wordsToWriteLater += addr -> value + wordsToWriteLater += ((bank, addr, value)) } } @@ -170,6 +171,8 @@ class Assembler(private val program: Program, private val rootEnv: Environment) } } + val bank0 = mem.banks(0) + env.allPreallocatables.foreach { case InitializedArray(name, Some(NumericConstant(address, _)), items) => var index = address.toInt @@ -178,28 +181,40 @@ class Assembler(private val program: Program, private val rootEnv: Environment) for (item <- items) { writeByte(0, index, item) assembly.append(" !byte " + item) - mem.banks(0).writeable(index) = true + bank0.occupied(index) = true + bank0.initialized(index) = true + bank0.writeable(index) = true + bank0.readable(index) = true index += 1 } initializedVariablesSize += items.length case InitializedArray(name, Some(_), items) => ??? case f: NormalFunction if f.address.isDefined => - var index = f.address.get.asInstanceOf[NumericConstant].value.toInt + val index = f.address.get.asInstanceOf[NumericConstant].value.toInt val code = compiledFunctions(f.name) if (code.nonEmpty) { labelMap(f.name) = index - index = outputFunction(code, index, assembly, options) + val end = outputFunction(code, index, assembly, options) + for(i <- index until end) { + bank0.occupied(index) = true + bank0.initialized(index) = true + bank0.readable(index) = true + } } case _ => } - var index = platform.org + var justAfterCode = platform.org + val allocator = platform.allocator + allocator.notifyAboutEndOfCode(platform.org) env.allPreallocatables.foreach { case f: NormalFunction if f.address.isEmpty && f.name == "main" => val code = compiledFunctions(f.name) if (code.nonEmpty) { + val size = code.map(_.sizeInBytes).sum + val index = allocator.allocateBytes(bank0, options, size, initialized = true, writeable = false) labelMap(f.name) = index - index = outputFunction(code, index, assembly, options) + justAfterCode = outputFunction(code, index, assembly, options) } case _ => } @@ -207,27 +222,28 @@ class Assembler(private val program: Program, private val rootEnv: Environment) case f: NormalFunction if f.address.isEmpty && f.name != "main" => val code = compiledFunctions(f.name) if (code.nonEmpty) { + val size = code.map(_.sizeInBytes).sum + val index = allocator.allocateBytes(bank0, options, size, initialized = true, writeable = false) labelMap(f.name) = index - index = outputFunction(code, index, assembly, options) + justAfterCode = outputFunction(code, index, assembly, options) } case _ => } env.allPreallocatables.foreach { case InitializedArray(name, None, items) => + var index = allocator.allocateBytes(bank0, options, items.size, initialized = true, writeable = true) labelMap(name) = index assembly.append("* = $" + index.toHexString) assembly.append(name) for (item <- items) { writeByte(0, index, item) assembly.append(" !byte " + item) - mem.banks(0).writeable(index) = true index += 1 } initializedVariablesSize += items.length + justAfterCode = index case m@InitializedMemoryVariable(name, None, typ, value) => - if (options.flags(CompilationFlag.PreventJmpIndirectBug) && (index & 0xff) + typ.size > 0x100) { - index = (index & 0xffff00) + 0x100 - } + var index = allocator.allocateBytes(bank0, options, typ.size, initialized = true, writeable = true) labelMap(name) = index val altName = m.name.stripPrefix(env.prefix) + "`" env.things += altName -> ConstantThing(altName, NumericConstant(index, 2), env.get[Type]("pointer")) @@ -236,37 +252,34 @@ class Assembler(private val program: Program, private val rootEnv: Environment) for (i <- 0 until typ.size) { writeByte(0, index, value.subbyte(i)) assembly.append(" !byte " + value.subbyte(i).quickSimplify) - mem.banks(0).writeable(index) = true index += 1 } initializedVariablesSize += typ.size + justAfterCode = index case _ => } - val allocator = platform.allocator - allocator.notifyAboutEndOfCode(index) - allocator.onEachByte = { addr => - mem.banks(0).readable(addr) = true - mem.banks(0).writeable(addr) = true - } - env.allocateVariables(None, callGraph, allocator, options, labelMap.put) + allocator.notifyAboutEndOfCode(justAfterCode) + env.allocateVariables(None, bank0, callGraph, allocator, options, labelMap.put) env = rootEnv.allThings - for ((addr, b) <- bytesToWriteLater) { + for ((bank, addr, b) <- bytesToWriteLater) { val value = deepConstResolve(b) - mem.banks(0).output(addr) = value.toByte + mem.banks(bank).output(addr) = value.toByte } - for ((addr, b) <- wordsToWriteLater) { + for ((bank, addr, b) <- wordsToWriteLater) { val value = deepConstResolve(b) - mem.banks(0).output(addr) = value.toByte - mem.banks(0).output(addr + 1) = value.>>>(8).toByte + mem.banks(bank).output(addr) = value.toByte + mem.banks(bank).output(addr + 1) = value.>>>(8).toByte } - val start = mem.banks(0).occupied.indexOf(true) - val end = mem.banks(0).occupied.lastIndexOf(true) - val length = end - start + 1 - mem.banks(0).start = start - mem.banks(0).end = end + for (bank <- mem.banks.keys) { + val start = mem.banks(bank).initialized.indexOf(true) + val end = mem.banks(bank).initialized.lastIndexOf(true) + val length = end - start + 1 + mem.banks(bank).start = start + mem.banks(bank).end = end + } labelMap.toList.sorted.foreach { case (l, v) => assembly += f"$l%-30s = $$$v%04X" diff --git a/src/main/scala/millfork/output/CompiledMemory.scala b/src/main/scala/millfork/output/CompiledMemory.scala index 82f91768..5877adf4 100644 --- a/src/main/scala/millfork/output/CompiledMemory.scala +++ b/src/main/scala/millfork/output/CompiledMemory.scala @@ -22,6 +22,7 @@ class MemoryBank { val output = Array.fill[Byte](1 << 16)(0) val occupied = Array.fill(1 << 16)(false) + val initialized = Array.fill(1 << 16)(false) val readable = Array.fill(1 << 16)(false) val writeable = Array.fill(1 << 16)(false) var start: Int = 0 diff --git a/src/main/scala/millfork/output/VariableAllocator.scala b/src/main/scala/millfork/output/VariableAllocator.scala index 7263518c..9bb39d69 100644 --- a/src/main/scala/millfork/output/VariableAllocator.scala +++ b/src/main/scala/millfork/output/VariableAllocator.scala @@ -11,41 +11,39 @@ import scala.collection.mutable */ sealed trait ByteAllocator { + protected def startAt: Int + protected def endBefore: Int + def notifyAboutEndOfCode(org: Int): Unit - def allocateBytes(count: Int, options: CompilationOptions): Int + def findFreeBytes(mem: MemoryBank, count: Int, options: CompilationOptions): Int = { + var lastFree = startAt + var counter = 0 + val occupied = mem.occupied + for(i <- startAt until endBefore) { + if (occupied(i) || counter == 0 && count == 2 && i.&(0xff) == 0xff && options.flags(CompilationFlag.PreventJmpIndirectBug)) { + counter = 0 + } else { + if (counter == 0) { + lastFree = i + } + counter += 1 + if (counter == count) { + return lastFree + } + } + } + ErrorReporting.fatal("Out of high memory") + } } -class UpwardByteAllocator(startAt: Int, endBefore: Int) extends ByteAllocator { - private var nextByte = startAt - - def allocateBytes(count: Int, options: CompilationOptions): Int = { - if (count == 2 && (nextByte & 0xff) == 0xff && options.flag(CompilationFlag.PreventJmpIndirectBug)) nextByte += 1 - val t = nextByte - nextByte += count - if (nextByte > endBefore) { - ErrorReporting.fatal("Out of high memory") - } - t - } - +class UpwardByteAllocator(val startAt: Int, val endBefore: Int) extends ByteAllocator { def notifyAboutEndOfCode(org: Int): Unit = () } -class AfterCodeByteAllocator(endBefore: Int) extends ByteAllocator { - var nextByte = 0x200 - - def allocateBytes(count: Int, options: CompilationOptions): Int = { - if (count == 2 && (nextByte & 0xff) == 0xff && options.flag(CompilationFlag.PreventJmpIndirectBug)) nextByte += 1 - val t = nextByte - nextByte += count - if (nextByte > endBefore) { - ErrorReporting.fatal("Out of high memory") - } - t - } - - def notifyAboutEndOfCode(org: Int): Unit = nextByte = org +class AfterCodeByteAllocator(val endBefore: Int) extends ByteAllocator { + var startAt = 0x200 + def notifyAboutEndOfCode(org: Int): Unit = startAt = org } class VariableAllocator(private var pointers: List[Int], private val bytes: ByteAllocator) { @@ -53,30 +51,38 @@ class VariableAllocator(private var pointers: List[Int], private val bytes: Byte private var pointerMap = mutable.Map[Int, Set[VariableVertex]]() private var variableMap = mutable.Map[Int, mutable.Map[Int, Set[VariableVertex]]]() - var onEachByte: (Int => Unit) = _ - - def allocatePointer(callGraph: CallGraph, p: VariableVertex): Int = { + def allocatePointer(mem: MemoryBank, callGraph: CallGraph, p: VariableVertex): Int = { + // TODO: search for free zeropage locations pointerMap.foreach { case (addr, alreadyThere) => if (alreadyThere.forall(q => callGraph.canOverlap(p, q))) { pointerMap(addr) += p return addr } } - pointers match { - case Nil => - ErrorReporting.fatal("Out of zero-page memory") - case next :: rest => - pointers = rest - onEachByte(next) - onEachByte(next + 1) - pointerMap(next) = Set(p) - next - } + def pickFreePointer(): Int = + pointers match { + case Nil => + ErrorReporting.fatal("Out of zero-page memory") + case next :: rest => + if (mem.occupied(next) || mem.occupied(next + 1)) { + pointers = rest + pickFreePointer() + } else { + pointers = rest + mem.readable(next) = true + mem.readable(next + 1) = true + mem.occupied(next) = true + mem.occupied(next + 1) = true + mem.writeable(next) = true + mem.writeable(next + 1) = true + pointerMap(next) = Set(p) + next + } + } + pickFreePointer() } - def allocateByte(callGraph: CallGraph, p: VariableVertex, options: CompilationOptions): Int = allocateBytes(callGraph, p, options, 1) - - def allocateBytes(callGraph: CallGraph, p: VariableVertex, options: CompilationOptions, count: Int): Int = { + def allocateBytes(mem: MemoryBank, callGraph: CallGraph, p: VariableVertex, options: CompilationOptions, count: Int, initialized: Boolean, writeable: Boolean): Int = { if (!variableMap.contains(count)) { variableMap(count) = mutable.Map() } @@ -86,11 +92,23 @@ class VariableAllocator(private var pointers: List[Int], private val bytes: Byte return a } } - val addr = bytes.allocateBytes(count, options) - (addr to (addr + count)).foreach(onEachByte) + val addr = allocateBytes(mem, options, count, initialized, writeable) variableMap(count)(addr) = Set(p) addr } + def allocateBytes(mem: MemoryBank, options: CompilationOptions, count: Int, initialized: Boolean, writeable: Boolean): Int = { + val addr = bytes.findFreeBytes(mem, count, options) + ErrorReporting.trace(s"allocating $count bytes at $$${addr.toHexString}") + (addr until (addr + count)).foreach { i => + if (mem.occupied(i)) ErrorReporting.fatal("Overlapping objects") + mem.readable(i) = true + mem.occupied(i) = true + mem.initialized(i) = initialized + mem.writeable(i) = writeable + } + addr + } + def notifyAboutEndOfCode(org: Int): Unit = bytes.notifyAboutEndOfCode(org) } diff --git a/src/test/scala/millfork/test/BasicSymonTest.scala b/src/test/scala/millfork/test/BasicSymonTest.scala index e34b7441..931ffb62 100644 --- a/src/test/scala/millfork/test/BasicSymonTest.scala +++ b/src/test/scala/millfork/test/BasicSymonTest.scala @@ -16,6 +16,50 @@ class BasicSymonTest extends FunSuite with Matchers { """.stripMargin) } + test("Allocation test") { + val src = + """ + byte output @$c000 + void main () { + function() + } + array thing @$20F = [1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1] + void function() { + output = 0 + output += 1 + output += 1 + output += 1 + output += 1 + output += 1 + output += 1 + output += 1 + output += 1 + output += 1 + output += 1 + output += 1 + output += 1 + output += 1 + output += 1 + output += 1 + output += 1 + output += 1 + output += 1 + output += 1 + output += 1 + output += 1 + output += 1 + output += 1 + output += 1 + output += 1 + output += 1 + output += 1 + output += 1 + output += 1 + } + """ + EmuUnoptimizedRun(src).readByte(0xc000) should equal(src.count(_ == '+')) + } + test("Byte assignment") { EmuUnoptimizedRun( """ diff --git a/src/test/scala/millfork/test/emu/EmuRun.scala b/src/test/scala/millfork/test/emu/EmuRun.scala index 9599931d..fdb9f42a 100644 --- a/src/test/scala/millfork/test/emu/EmuRun.scala +++ b/src/test/scala/millfork/test/emu/EmuRun.scala @@ -134,7 +134,7 @@ class EmuRun(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization], println(";;; ---------------------------") assembler.labelMap.foreach { case (l, addr) => println(f"$l%-15s $$$addr%04x") } - val optimizedSize = assembler.mem.banks(0).occupied.count(identity).toLong + val optimizedSize = assembler.mem.banks(0).initialized.count(identity).toLong if (unoptimizedSize == optimizedSize) { println(f"Size: $unoptimizedSize%5d B") } else {