mirror of
https://github.com/irmen/prog8.git
synced 2024-06-01 06:41:42 +00:00
200 lines
8.3 KiB
Kotlin
package prog8.optimizing
|
|
|
|
import prog8.ast.*
|
|
import prog8.ast.base.ParentSentinel
|
|
import prog8.ast.base.VarDeclType
|
|
import prog8.ast.base.initvarsSubName
|
|
import prog8.ast.expressions.FunctionCall
|
|
import prog8.ast.expressions.IdentifierReference
|
|
import prog8.ast.processing.IAstProcessor
|
|
import prog8.ast.statements.*
|
|
import prog8.compiler.loadAsmIncludeFile
|
|
|
|
|
|
/**
 * Collects cross-reference information about a [Program]:
 * - which modules import which other modules (`%import` directives),
 * - which subroutines are called from where (function call expressions, call statements, jumps, and
 *   subroutine references found in inline assembly and `%asminclude` files),
 * - which statements are used or explicitly marked to be kept (entry points, library blocks/subroutines,
 *   autogenerated declarations), together with their enclosing name scopes.
 *
 * The analysis runs from the constructor (`init` calls [process]). Afterwards the results are available in
 * the public collections below, and are also copied into each [Module] (`imports` / `importedBy`) and into
 * each [Subroutine] found in a module (`calls` / `calledBy`).
 */
class CallGraph(private val program: Program): IAstProcessor {

    /** Module -> modules it `%import`s. Reading an absent key via `getValue` yields an empty list. */
    val modulesImporting = mutableMapOf<Module, List<Module>>().withDefault { mutableListOf() }

    /** Module -> modules that `%import` it. Reading an absent key via `getValue` yields an empty list. */
    val modulesImportedBy = mutableMapOf<Module, List<Module>>().withDefault { mutableListOf() }

    /**
     * Calling scope -> subroutines it calls. Calls in regular code are keyed by the enclosing subroutine
     * (`definingSubroutine()`); calls found in assembly code are keyed by the assembly's `definingScope()`.
     */
    val subroutinesCalling = mutableMapOf<INameScope, List<Subroutine>>().withDefault { mutableListOf() }

    /** Subroutine -> the nodes (calls, call statements, jumps, assembly statements/directives) referring to it. */
    val subroutinesCalledBy = mutableMapOf<Subroutine, List<Node>>().withDefault { mutableListOf() }

    /** Statements that are referenced or must be kept, plus every enclosing statement-level name scope. */
    val usedSymbols = mutableSetOf<IStatement>()

    init {
        // Must stay below the property declarations above: Kotlin runs property initializers and init blocks
        // in declaration order, and process() fills those collections.
        process(program)
    }

    /**
     * Invokes [sub] for every [Subroutine] nested at any depth inside [scope], in statement order.
     * Every nested [INameScope] (including subroutines themselves) is searched recursively, after [sub]
     * has been invoked for it when it is a subroutine.
     */
    fun forAllSubroutines(scope: INameScope, sub: (s: Subroutine) -> Unit) {
        fun findSubs(scope: INameScope) {
            for (stmt in scope.statements) {
                if (stmt is Subroutine)
                    sub(stmt)
                if (stmt is INameScope)
                    findSubs(stmt)
            }
        }
        findSubs(scope)
    }

    /**
     * Visits the whole program, then replaces every module's `imports`/`importedBy` and every subroutine's
     * `calls`/`calledBy` with the freshly collected data. The first module is finally marked as imported by
     * itself so that it is never considered unreferenced.
     */
    override fun process(program: Program) {
        super.process(program)

        program.modules.forEach { module ->
            module.importedBy.clear()
            module.imports.clear()
            module.importedBy.addAll(modulesImportedBy.getValue(module))
            module.imports.addAll(modulesImporting.getValue(module))

            forAllSubroutines(module) { sub ->
                sub.calledBy.clear()
                sub.calls.clear()
                sub.calledBy.addAll(subroutinesCalledBy.getValue(sub))
                sub.calls.addAll(subroutinesCalling.getValue(sub))
            }
        }

        val rootmodule = program.modules.first()
        rootmodule.importedBy.add(rootmodule)       // don't discard root module
    }

    /** Blocks defined in a library module are always marked as used (with their parent scopes). */
    override fun process(block: Block): IStatement {
        if(block.definingModule().isLibraryModule) {
            // make sure the block is not removed
            addNodeAndParentScopes(block)
        }
        return super.process(block)
    }

    /**
     * `%import`: records the import relation in both directions.
     * `%asminclude`: loads the included assembly file and scans it for subroutine references.
     */
    override fun process(directive: Directive): IStatement {
        val thisModule = directive.definingModule()
        if(directive.directive=="%import") {
            val importedModule: Module = program.modules.single { it.name==directive.args[0].name }
            modulesImporting[thisModule] = modulesImporting.getValue(thisModule).plus(importedModule)
            modulesImportedBy[importedModule] = modulesImportedBy.getValue(importedModule).plus(thisModule)
        } else if (directive.directive=="%asminclude") {
            // NOTE(review): assumes earlier AST checks guarantee a string filename argument; `!!` throws otherwise.
            val asm = loadAsmIncludeFile(directive.args[0].str!!, thisModule.source)
            val scope = directive.definingScope()
            scanAssemblyCode(asm, directive, scope)
        }
        return super.process(directive)
    }

    /** Every identifier that resolves to a statement marks that statement (and its parent scopes) as used. */
    override fun process(identifier: IdentifierReference): IExpression {
        // track symbol usage
        val target = identifier.targetStatement(this.program.namespace)
        if(target!=null) {
            addNodeAndParentScopes(target)
        }
        return super.process(identifier)
    }

    /**
     * Adds [stmt] to [usedSymbols], then walks up its parent chain until a [Module] or [ParentSentinel] is
     * reached, adding every node on the way that is both an [INameScope] and an [IStatement].
     */
    private fun addNodeAndParentScopes(stmt: IStatement) {
        usedSymbols.add(stmt)
        var node: Node = stmt
        do {
            if(node is INameScope && node is IStatement) {
                usedSymbols.add(node)
            }
            node = node.parent
        } while (node !is Module && node !is ParentSentinel)
    }

    /**
     * Marks `main.start`, the variable-initialization subroutine ([initvarsSubName]) and every subroutine
     * defined in a library module as used.
     */
    override fun process(subroutine: Subroutine): IStatement {
        if((subroutine.name=="start" && subroutine.definingScope().name=="main")
                || subroutine.name== initvarsSubName || subroutine.definingModule().isLibraryModule) {
            // make sure the entrypoint is mentioned in the used symbols
            addNodeAndParentScopes(subroutine)
        }
        return super.process(subroutine)
    }

    /** Marks autogenerated declarations, and library-module declarations whose type is not VAR, as used. */
    override fun process(decl: VarDecl): IStatement {
        if(decl.autoGenerated || (decl.definingModule().isLibraryModule && decl.type!=VarDeclType.VAR)) {
            // make sure autogenerated vardecls (and non-VAR library decls) are in the used symbols
            addNodeAndParentScopes(decl)
        }
        return super.process(decl)
    }

    /** Records a call made by a function call expression, when both caller and callee subroutines are known. */
    override fun process(functionCall: FunctionCall): IExpression {
        val otherSub = functionCall.target.targetSubroutine(program.namespace)
        if(otherSub!=null) {
            functionCall.definingSubroutine()?.let { thisSub -> recordCall(thisSub, otherSub, functionCall) }
        }
        return super.process(functionCall)
    }

    /** Records a call made by a function call statement, when both caller and callee subroutines are known. */
    override fun process(functionCallStatement: FunctionCallStatement): IStatement {
        val otherSub = functionCallStatement.target.targetSubroutine(program.namespace)
        if(otherSub!=null) {
            functionCallStatement.definingSubroutine()?.let { thisSub ->
                recordCall(thisSub, otherSub, functionCallStatement)
            }
        }
        return super.process(functionCallStatement)
    }

    /** A jump to an identifier that resolves to a subroutine is recorded like a call. */
    override fun process(jump: Jump): IStatement {
        val otherSub = jump.identifier?.targetSubroutine(program.namespace)
        if(otherSub!=null) {
            jump.definingSubroutine()?.let { thisSub -> recordCall(thisSub, otherSub, jump) }
        }
        return super.process(jump)
    }

    /** Scans the inline assembly text for references to subroutines. */
    override fun process(inlineAssembly: InlineAssembly): IStatement {
        // parse inline asm for subroutine calls (jmp, jsr)
        val scope = inlineAssembly.definingScope()
        scanAssemblyCode(inlineAssembly.assembly, inlineAssembly, scope)
        return super.process(inlineAssembly)
    }

    /** Registers that [caller] calls [callee] from [site], updating both directions of the call graph. */
    private fun recordCall(caller: INameScope, callee: Subroutine, site: Node) {
        subroutinesCalling[caller] = subroutinesCalling.getValue(caller).plus(callee)
        subroutinesCalledBy[callee] = subroutinesCalledBy.getValue(callee).plus(site)
    }

    /**
     * Scans [asm] line by line for subroutine references and records them as calls from [scope], with
     * [context] as the call site and as the lookup context for name resolution.
     *
     * - Lines matching `jmp`/`jsr <target>`: the full dotted target is looked up first; if that is not a
     *   subroutine and the target is dotted, the first segment alone is tried.
     * - Any other line of the form `<mnemonic> <operand>`: only the operand's first dotted segment is tried.
     *
     * Operands that do not start with a letter or `_` (literals, `#`, `(` ...) are ignored.
     */
    private fun scanAssemblyCode(asm: String, context: IStatement, scope: INameScope) {
        for (line in asm.lines()) {
            val jumpMatch = ASM_JUMP_RX.matchEntire(line)
            if (jumpMatch != null) {
                val jumptarget = jumpMatch.groups[2]?.value
                if (jumptarget != null && (jumptarget[0].isLetter() || jumptarget[0] == '_')) {
                    val node = program.namespace.lookup(jumptarget.split('.'), context)
                    if (node is Subroutine) {
                        recordCall(scope, node, context)
                    } else if(jumptarget.contains('.')) {
                        // maybe only the first part already refers to a subroutine
                        val node2 = program.namespace.lookup(listOf(jumptarget.substringBefore('.')), context)
                        if (node2 is Subroutine)
                            recordCall(scope, node2, context)
                    }
                }
            } else {
                val refMatch = ASM_REF_RX.matchEntire(line)
                if (refMatch != null) {
                    val target = refMatch.groups[2]?.value
                    if (target != null && (target[0].isLetter() || target[0] == '_')) {
                        val node = program.namespace.lookup(listOf(target.substringBefore('.')), context)
                        if (node is Subroutine)
                            recordCall(scope, node, context)
                    }
                }
            }
        }
    }

    private companion object {
        // Compiled once instead of on every scanAssemblyCode() call. Kept in the companion object (initialized
        // at class load) because instance properties declared after `init` would still be null while the
        // constructor runs process().

        /** Lines containing a `jmp`/`jsr` instruction (case-insensitive); group 2 is the operand. */
        val ASM_JUMP_RX = Regex("""[\-+a-zA-Z0-9_ \t]+(jmp|jsr)[ \t]+(\S+).*""", RegexOption.IGNORE_CASE)

        /** Lines of the form `<prefix><any 3 chars> <operand>`; group 2 is the operand. */
        val ASM_REF_RX = Regex("""[\-+a-zA-Z0-9_ \t]+(...)[ \t]+(\S+).*""", RegexOption.IGNORE_CASE)
    }
}
|