optimize node renames

This commit is contained in:
Irmen de Jong 2023-01-22 17:36:15 +01:00
parent 6ee270d9d8
commit 4403e4ed62
2 changed files with 69 additions and 80 deletions

View File

@ -24,10 +24,9 @@ class IRCodeGen(
internal val registers = RegisterPool() internal val registers = RegisterPool()
fun generate(): IRProgram { fun generate(): IRProgram {
flattenLabelNames() makeAllNodenamesScoped()
flattenNestedSubroutines() moveAllNestedSubroutinesToBlockScope()
verifyNameScoping(program, symbolTable)
// TODO: validateNames(program)
val irProg = IRProgram(program.name, IRSymbolTable(symbolTable), options, program.encoding) val irProg = IRProgram(program.name, IRSymbolTable(symbolTable), options, program.encoding)
@ -73,23 +72,37 @@ class IRCodeGen(
return irProg return irProg
} }
private fun validateNames(node: PtNode) { private fun verifyNameScoping(program: PtProgram, symbolTable: SymbolTable) {
when(node) { fun verifyPtNode(node: PtNode) {
is PtBuiltinFunctionCall -> require('.' in node.name) { "node $node name is not scoped: ${node.name}"} when (node) {
is PtFunctionCall -> require('.' in node.name) { "node $node name is not scoped: ${node.name}"} is PtBuiltinFunctionCall -> require('.' !in node.name) { "builtin function call name should not be scoped: ${node.name}" }
is PtIdentifier -> require('.' in node.name) { "node $node name is not scoped: ${node.name}"} is PtFunctionCall -> require('.' in node.name) { "node $node name is not scoped: ${node.name}" }
is PtAsmSub -> require('.' in node.name) { "node $node name is not scoped: ${node.name}"} is PtIdentifier -> require('.' in node.name) { "node $node name is not scoped: ${node.name}" }
is PtBlock -> require('.' !in node.name) { "block name should not be scoped: ${node.name}"} is PtAsmSub -> require('.' in node.name) { "node $node name is not scoped: ${node.name}" }
is PtConstant -> require('.' in node.name) { "node $node name is not scoped: ${node.name}"} is PtBlock -> require('.' !in node.name) { "block name should not be scoped: ${node.name}" }
is PtLabel -> require('.' in node.name) { "node $node name is not scoped: ${node.name}"} is PtConstant -> require('.' in node.name) { "node $node name is not scoped: ${node.name}" }
is PtMemMapped -> require('.' in node.name) { "node $node name is not scoped: ${node.name}"} is PtLabel -> require('.' in node.name) { "node $node name is not scoped: ${node.name}" }
is PtSub -> require('.' in node.name) { "node $node name is not scoped: ${node.name}"} is PtMemMapped -> require('.' in node.name) { "node $node name is not scoped: ${node.name}" }
is PtVariable -> require('.' in node.name) { "node $node name is not scoped: ${node.name}"} is PtSub -> require('.' in node.name) { "node $node name is not scoped: ${node.name}" }
is PtProgram -> require('.' !in node.name) { "program name should not be scoped: ${node.name}"} is PtVariable -> require('.' in node.name) { "node $node name is not scoped: ${node.name}" }
is PtSubroutineParameter -> require('.' in node.name) { "node $node name is not scoped: ${node.name}"} is PtProgram -> require('.' !in node.name) { "program name should not be scoped: ${node.name}" }
else -> { /* node has no name */ } is PtSubroutineParameter -> require('.' in node.name) { "node $node name is not scoped: ${node.name}" }
else -> { /* node has no name */
}
}
node.children.forEach { verifyPtNode(it) }
} }
node.children.forEach{ validateNames(it) }
fun verifyStNode(node: StNode) {
require('.' !in node.name) { "st node name should not be scoped: ${node.name}"}
node.children.forEach {
require(it.key==it.value.name)
verifyStNode(it.value)
}
}
verifyPtNode(program)
verifyStNode(symbolTable)
} }
private fun ensureFirstChunkLabels(irProg: IRProgram) { private fun ensureFirstChunkLabels(irProg: IRProgram) {
@ -176,79 +189,47 @@ class IRCodeGen(
} }
} }
private fun flattenLabelNames() { private fun makeAllNodenamesScoped() {
val renameLabels = mutableListOf<Pair<PtNode, PtLabel>>() val renames = mutableListOf<Pair<PtNamedNode, String>>()
fun recurse(node: PtNode) {
fun flattenRecurse(node: PtNode) {
node.children.forEach { node.children.forEach {
if (it is PtLabel) if(it is PtNamedNode)
renameLabels += Pair(it.parent, it) renames.add(it to it.scopedName)
else recurse(it)
flattenRecurse(it)
} }
} }
recurse(program)
flattenRecurse(program) renames.forEach { it.first.name = it.second }
renameLabels.forEach { (_, label) -> label.name = label.scopedName }
} }
private fun flattenNestedSubroutines() { private fun moveAllNestedSubroutinesToBlockScope() {
// this moves all nested subroutines up to the block scope. val movedSubs = mutableListOf<Pair<PtBlock, PtSub>>()
// also changes the name to be the fully scoped one, so it becomes unique at the top level. val removedSubs = mutableListOf<Pair<PtSub, PtSub>>()
val flattenedSubs = mutableListOf<Pair<PtBlock, PtSub>>()
val flattenedAsmSubs = mutableListOf<Pair<PtBlock, PtAsmSub>>()
val removalsSubs = mutableListOf<Pair<PtSub, PtSub>>()
val removalsAsmSubs = mutableListOf<Pair<PtSub, PtAsmSub>>()
val renameSubs = mutableListOf<Pair<PtBlock, PtSub>>()
val renameAsmSubs = mutableListOf<Pair<PtBlock, PtAsmSub>>()
fun flattenNestedAsmSub(block: PtBlock, parentSub: PtSub, asmsub: PtAsmSub) { fun moveToBlock(block: PtBlock, parent: PtSub, asmsub: PtAsmSub) {
val flattened = PtAsmSub(asmsub.scopedName, block.add(asmsub)
asmsub.address, parent.children.remove(asmsub)
asmsub.clobbers,
asmsub.parameters,
asmsub.returnTypes,
asmsub.retvalRegisters,
asmsub.inline,
asmsub.position)
asmsub.children.forEach { flattened.add(it) }
flattenedAsmSubs += Pair(block, flattened)
removalsAsmSubs += Pair(parentSub, asmsub)
} }
fun flattenNestedSub(block: PtBlock, parentSub: PtSub, sub: PtSub) { fun moveToBlock(block: PtBlock, parent: PtSub, sub: PtSub) {
sub.children.filterIsInstance<PtSub>().forEach { subsub->flattenNestedSub(block, sub, subsub) } sub.children.filterIsInstance<PtSub>().forEach { subsub -> moveToBlock(block, sub, subsub) }
sub.children.filterIsInstance<PtAsmSub>().forEach { asmsubsub->flattenNestedAsmSub(block, sub, asmsubsub) } sub.children.filterIsInstance<PtAsmSub>().forEach { asmsubsub -> moveToBlock(block, sub, asmsubsub) }
val flattened = PtSub(sub.scopedName, movedSubs += Pair(block, sub)
sub.parameters, removedSubs += Pair(parent, sub)
sub.returntype,
sub.inline,
sub.position)
sub.children.forEach { if(it !is PtSub) flattened.add(it) }
flattenedSubs += Pair(block, flattened)
removalsSubs += Pair(parentSub, sub)
} }
program.allBlocks().forEach { block -> program.allBlocks().forEach { block ->
block.children.forEach { block.children.forEach {
if(it is PtSub) { if (it is PtSub) {
// Only regular subroutines can have nested subroutines. // Only regular subroutines can have nested subroutines.
it.children.filterIsInstance<PtSub>().forEach { subsub->flattenNestedSub(block, it, subsub)} it.children.filterIsInstance<PtSub>().forEach { subsub -> moveToBlock(block, it, subsub) }
it.children.filterIsInstance<PtAsmSub>().forEach { asmsubsub->flattenNestedAsmSub(block, it, asmsubsub)} it.children.filterIsInstance<PtAsmSub>().forEach { asmsubsub -> moveToBlock(block, it, asmsubsub) }
renameSubs += Pair(block, it)
} }
if(it is PtAsmSub)
renameAsmSubs += Pair(block, it)
} }
} }
removalsSubs.forEach { (parent, sub) -> parent.children.remove(sub) } removedSubs.forEach { (parent, sub) -> parent.children.remove(sub) }
removalsAsmSubs.forEach { (parent, asmsub) -> parent.children.remove(asmsub) } movedSubs.forEach { (block, sub) -> block.add(sub) }
flattenedSubs.forEach { (block, sub) -> block.add(sub) }
flattenedAsmSubs.forEach { (block, asmsub) -> block.add(asmsub) }
renameSubs.forEach { (_, sub) -> sub.name = sub.scopedName }
renameAsmSubs.forEach { (_, sub) -> sub.name = sub.scopedName }
} }
internal fun translateNode(node: PtNode): IRCodeChunks { internal fun translateNode(node: PtNode): IRCodeChunks {
@ -463,7 +444,8 @@ class IRCodeGen(
} }
is PtIdentifier -> { is PtIdentifier -> {
val iterableVar = symbolTable.lookup(iterable.name) as StStaticVariable val iterableVar = symbolTable.lookup(iterable.name) as StStaticVariable
val loopvarSymbol = loopvar.scopedName require(forLoop.variable.name == loopvar.scopedName)
val loopvarSymbol = forLoop.variable.name
val indexReg = registers.nextFree() val indexReg = registers.nextFree()
val tmpReg = registers.nextFree() val tmpReg = registers.nextFree()
val loopLabel = createLabelName() val loopLabel = createLabelName()
@ -526,7 +508,8 @@ class IRCodeGen(
throw AssemblyError("step 0") throw AssemblyError("step 0")
val indexReg = registers.nextFree() val indexReg = registers.nextFree()
val endvalueReg = registers.nextFree() val endvalueReg = registers.nextFree()
val loopvarSymbol = loopvar.scopedName require(forLoop.variable.name == loopvar.scopedName)
val loopvarSymbol = forLoop.variable.name
val loopvarDt = when(loopvar) { val loopvarDt = when(loopvar) {
is StMemVar -> loopvar.dt is StMemVar -> loopvar.dt
is StStaticVariable -> loopvar.dt is StStaticVariable -> loopvar.dt
@ -557,7 +540,8 @@ class IRCodeGen(
private fun translateForInConstantRange(forLoop: PtForLoop, loopvar: StNode): IRCodeChunks { private fun translateForInConstantRange(forLoop: PtForLoop, loopvar: StNode): IRCodeChunks {
val loopLabel = createLabelName() val loopLabel = createLabelName()
val loopvarSymbol = loopvar.scopedName require(forLoop.variable.name == loopvar.scopedName)
val loopvarSymbol = forLoop.variable.name
val indexReg = registers.nextFree() val indexReg = registers.nextFree()
val loopvarDt = when(loopvar) { val loopvarDt = when(loopvar) {
is StMemVar -> loopvar.dt is StMemVar -> loopvar.dt
@ -1244,7 +1228,7 @@ class IRCodeGen(
private fun translate(parameters: List<PtSubroutineParameter>) = private fun translate(parameters: List<PtSubroutineParameter>) =
parameters.map { parameters.map {
val flattenedName = it.definingSub()!!.scopedName + "." + it.name val flattenedName = it.definingSub()!!.name + "." + it.name
val orig = symbolTable.flat.getValue(flattenedName) as StStaticVariable val orig = symbolTable.flat.getValue(flattenedName) as StStaticVariable
IRSubroutine.IRParam(flattenedName, orig.dt) IRSubroutine.IRParam(flattenedName, orig.dt)
} }

View File

@ -12,6 +12,11 @@ main {
str name="irmen" str name="irmen"
sub start() { sub start() {
for cx16.r0 in 0 to 10 {
cx16.r1++
}
txt.print("= 10 ") txt.print("= 10 ")
txt.print_ub(zpvar) txt.print_ub(zpvar)
txt.nl() txt.nl()