'Program' is not an ast Node

This commit is contained in:
Irmen de Jong 2021-10-30 00:25:34 +02:00
parent d7d2eefa4f
commit 3767b4bbe7
13 changed files with 217 additions and 206 deletions

View File

@ -24,7 +24,7 @@ import kotlin.system.measureTimeMillis
class CompilationResult(val success: Boolean, class CompilationResult(val success: Boolean,
val programAst: Program, val program: Program,
val programName: String, val programName: String,
val compTarget: ICompilationTarget, val compTarget: ICompilationTarget,
val importedFiles: List<Path>) val importedFiles: List<Path>)
@ -39,7 +39,7 @@ fun compileProgram(filepath: Path,
outputDir: Path, outputDir: Path,
errors: IErrorReporter = ErrorReporter()): CompilationResult { errors: IErrorReporter = ErrorReporter()): CompilationResult {
var programName = "" var programName = ""
lateinit var programAst: Program lateinit var program: Program
lateinit var importedFiles: List<Path> lateinit var importedFiles: List<Path>
val compTarget = val compTarget =
@ -52,31 +52,31 @@ fun compileProgram(filepath: Path,
try { try {
val totalTime = measureTimeMillis { val totalTime = measureTimeMillis {
// import main module and everything it needs // import main module and everything it needs
val (ast, compilationOptions, imported) = parseImports(filepath, errors, compTarget, sourceDirs) val (programresult, compilationOptions, imported) = parseImports(filepath, errors, compTarget, sourceDirs)
compilationOptions.slowCodegenWarnings = slowCodegenWarnings compilationOptions.slowCodegenWarnings = slowCodegenWarnings
compilationOptions.optimize = optimize compilationOptions.optimize = optimize
programAst = ast program = programresult
importedFiles = imported importedFiles = imported
processAst(programAst, errors, compilationOptions) processAst(program, errors, compilationOptions)
if (compilationOptions.optimize) if (compilationOptions.optimize)
optimizeAst( optimizeAst(
programAst, program,
errors, errors,
BuiltinFunctionsFacade(BuiltinFunctions), BuiltinFunctionsFacade(BuiltinFunctions),
compTarget compTarget
) )
postprocessAst(programAst, errors, compilationOptions) postprocessAst(program, errors, compilationOptions)
// println("*********** AST BEFORE ASSEMBLYGEN *************") // println("*********** AST BEFORE ASSEMBLYGEN *************")
// printAst(programAst) // printAst(program)
if (writeAssembly) { if (writeAssembly) {
val result = writeAssembly(programAst, errors, outputDir, compilationOptions) val result = writeAssembly(program, errors, outputDir, compilationOptions)
when (result) { when (result) {
is WriteAssemblyResult.Ok -> programName = result.filename is WriteAssemblyResult.Ok -> programName = result.filename
is WriteAssemblyResult.Fail -> { is WriteAssemblyResult.Fail -> {
System.err.println(result.error) System.err.println(result.error)
return CompilationResult(false, programAst, programName, compTarget, importedFiles) return CompilationResult(false, program, programName, compTarget, importedFiles)
} }
} }
} }
@ -84,7 +84,7 @@ fun compileProgram(filepath: Path,
System.out.flush() System.out.flush()
System.err.flush() System.err.flush()
println("\nTotal compilation+assemble time: ${totalTime / 1000.0} sec.") println("\nTotal compilation+assemble time: ${totalTime / 1000.0} sec.")
return CompilationResult(true, programAst, programName, compTarget, importedFiles) return CompilationResult(true, program, programName, compTarget, importedFiles)
} catch (px: ParseError) { } catch (px: ParseError) {
System.err.print("\u001b[91m") // bright red System.err.print("\u001b[91m") // bright red
System.err.println("${px.position.toClickableStr()} parse error: ${px.message}".trim()) System.err.println("${px.position.toClickableStr()} parse error: ${px.message}".trim())
@ -155,20 +155,20 @@ fun parseImports(filepath: Path,
sourceDirs: List<String>): Triple<Program, CompilationOptions, List<Path>> { sourceDirs: List<String>): Triple<Program, CompilationOptions, List<Path>> {
println("Compiler target: ${compTarget.name}. Parsing...") println("Compiler target: ${compTarget.name}. Parsing...")
val bf = BuiltinFunctionsFacade(BuiltinFunctions) val bf = BuiltinFunctionsFacade(BuiltinFunctions)
val programAst = Program(filepath.nameWithoutExtension, bf, compTarget, compTarget) val program = Program(filepath.nameWithoutExtension, bf, compTarget, compTarget)
bf.program = programAst bf.program = program
val importer = ModuleImporter(programAst, compTarget.name, errors, sourceDirs) val importer = ModuleImporter(program, compTarget.name, errors, sourceDirs)
val importedModuleResult = importer.importModule(filepath) val importedModuleResult = importer.importModule(filepath)
importedModuleResult.onFailure { throw it } importedModuleResult.onFailure { throw it }
errors.report() errors.report()
val importedFiles = programAst.modules.map { it.source } val importedFiles = program.modules.map { it.source }
.filter { it.isFromFilesystem } .filter { it.isFromFilesystem }
.map { Path(it.origin) } .map { Path(it.origin) }
val compilerOptions = determineCompilationOptions(programAst, compTarget) val compilerOptions = determineCompilationOptions(program, compTarget)
if (compilerOptions.launcher == LauncherType.BASIC && compilerOptions.output != OutputType.PRG) if (compilerOptions.launcher == LauncherType.BASIC && compilerOptions.output != OutputType.PRG)
throw ParsingFailedError("${programAst.modules.first().position} BASIC launcher requires output type PRG.") throw ParsingFailedError("${program.modules.first().position} BASIC launcher requires output type PRG.")
// depending on the machine and compiler options we may have to include some libraries // depending on the machine and compiler options we may have to include some libraries
for(lib in compTarget.machine.importLibs(compilerOptions, compTarget.name)) for(lib in compTarget.machine.importLibs(compilerOptions, compTarget.name))
@ -178,7 +178,7 @@ fun parseImports(filepath: Path,
importer.importLibraryModule("math") importer.importLibraryModule("math")
importer.importLibraryModule("prog8_lib") importer.importLibraryModule("prog8_lib")
errors.report() errors.report()
return Triple(programAst, compilerOptions, importedFiles) return Triple(program, compilerOptions, importedFiles)
} }
fun determineCompilationOptions(program: Program, compTarget: ICompilationTarget): CompilationOptions { fun determineCompilationOptions(program: Program, compTarget: ICompilationTarget): CompilationOptions {
@ -241,44 +241,44 @@ fun determineCompilationOptions(program: Program, compTarget: ICompilationTarget
) )
} }
private fun processAst(programAst: Program, errors: IErrorReporter, compilerOptions: CompilationOptions) { private fun processAst(program: Program, errors: IErrorReporter, compilerOptions: CompilationOptions) {
// perform initial syntax checks and processings // perform initial syntax checks and processings
println("Processing for target ${compilerOptions.compTarget.name}...") println("Processing for target ${compilerOptions.compTarget.name}...")
programAst.preprocessAst() program.preprocessAst()
programAst.checkIdentifiers(errors, compilerOptions) program.checkIdentifiers(errors, compilerOptions)
errors.report() errors.report()
// TODO: turning char literals into UBYTEs via an encoding should really happen in code gen - but for that we'd need DataType.CHAR // TODO: turning char literals into UBYTEs via an encoding should really happen in code gen - but for that we'd need DataType.CHAR
// NOTE: we will then lose the opportunity to do constant-folding on any expression containing a char literal, but how often will those occur? // NOTE: we will then lose the opportunity to do constant-folding on any expression containing a char literal, but how often will those occur?
// Also they might be optimized away eventually in codegen or by the assembler even // Also they might be optimized away eventually in codegen or by the assembler even
programAst.charLiteralsToUByteLiterals(compilerOptions.compTarget) program.charLiteralsToUByteLiterals(compilerOptions.compTarget)
programAst.constantFold(errors, compilerOptions.compTarget) program.constantFold(errors, compilerOptions.compTarget)
errors.report() errors.report()
programAst.reorderStatements(errors) program.reorderStatements(errors)
errors.report() errors.report()
programAst.addTypecasts(errors) program.addTypecasts(errors)
errors.report() errors.report()
programAst.variousCleanups(programAst, errors) program.variousCleanups(program, errors)
errors.report() errors.report()
programAst.checkValid(compilerOptions, errors, compilerOptions.compTarget) program.checkValid(compilerOptions, errors, compilerOptions.compTarget)
errors.report() errors.report()
programAst.checkIdentifiers(errors, compilerOptions) program.checkIdentifiers(errors, compilerOptions)
errors.report() errors.report()
} }
private fun optimizeAst(programAst: Program, errors: IErrorReporter, functions: IBuiltinFunctions, compTarget: ICompilationTarget) { private fun optimizeAst(program: Program, errors: IErrorReporter, functions: IBuiltinFunctions, compTarget: ICompilationTarget) {
// optimize the parse tree // optimize the parse tree
println("Optimizing...") println("Optimizing...")
val remover = UnusedCodeRemover(programAst, errors, compTarget) val remover = UnusedCodeRemover(program, errors, compTarget)
remover.visit(programAst) remover.visit(program)
remover.applyModifications() remover.applyModifications()
while (true) { while (true) {
// keep optimizing expressions and statements until no more steps remain // keep optimizing expressions and statements until no more steps remain
val optsDone1 = programAst.simplifyExpressions() val optsDone1 = program.simplifyExpressions()
val optsDone2 = programAst.splitBinaryExpressions(compTarget) val optsDone2 = program.splitBinaryExpressions(compTarget)
val optsDone3 = programAst.optimizeStatements(errors, functions, compTarget) val optsDone3 = program.optimizeStatements(errors, functions, compTarget)
programAst.constantFold(errors, compTarget) // because simplified statements and expressions can result in more constants that can be folded away program.constantFold(errors, compTarget) // because simplified statements and expressions can result in more constants that can be folded away
errors.report() errors.report()
if (optsDone1 + optsDone2 + optsDone3 == 0) if (optsDone1 + optsDone2 + optsDone3 == 0)
break break
@ -287,17 +287,17 @@ private fun optimizeAst(programAst: Program, errors: IErrorReporter, functions:
errors.report() errors.report()
} }
private fun postprocessAst(programAst: Program, errors: IErrorReporter, compilerOptions: CompilationOptions) { private fun postprocessAst(program: Program, errors: IErrorReporter, compilerOptions: CompilationOptions) {
programAst.addTypecasts(errors) program.addTypecasts(errors)
errors.report() errors.report()
programAst.variousCleanups(programAst, errors) program.variousCleanups(program, errors)
programAst.checkValid(compilerOptions, errors, compilerOptions.compTarget) // check if final tree is still valid program.checkValid(compilerOptions, errors, compilerOptions.compTarget) // check if final tree is still valid
errors.report() errors.report()
val callGraph = CallGraph(programAst) val callGraph = CallGraph(program)
callGraph.checkRecursiveCalls(errors) callGraph.checkRecursiveCalls(errors)
errors.report() errors.report()
programAst.verifyFunctionArgTypes() program.verifyFunctionArgTypes()
programAst.moveMainAndStartToFirst() program.moveMainAndStartToFirst()
} }
private sealed class WriteAssemblyResult { private sealed class WriteAssemblyResult {
@ -305,20 +305,20 @@ private sealed class WriteAssemblyResult {
class Fail(val error: String): WriteAssemblyResult() class Fail(val error: String): WriteAssemblyResult()
} }
private fun writeAssembly(programAst: Program, private fun writeAssembly(program: Program,
errors: IErrorReporter, errors: IErrorReporter,
outputDir: Path, outputDir: Path,
compilerOptions: CompilationOptions compilerOptions: CompilationOptions
): WriteAssemblyResult { ): WriteAssemblyResult {
// asm generation directly from the Ast // asm generation directly from the Ast
programAst.processAstBeforeAsmGeneration(errors, compilerOptions.compTarget) program.processAstBeforeAsmGeneration(errors, compilerOptions.compTarget)
errors.report() errors.report()
// printAst(programAst) // printAst(program)
compilerOptions.compTarget.machine.initializeZeropage(compilerOptions) compilerOptions.compTarget.machine.initializeZeropage(compilerOptions)
val assembly = asmGeneratorFor(compilerOptions.compTarget, val assembly = asmGeneratorFor(compilerOptions.compTarget,
programAst, program,
errors, errors,
compilerOptions.compTarget.machine.zeropage, compilerOptions.compTarget.machine.zeropage,
compilerOptions, compilerOptions,
@ -338,10 +338,10 @@ private fun writeAssembly(programAst: Program,
} }
} }
fun printAst(programAst: Program) { fun printAst(program: Program) {
println() println()
val printer = AstToSourceTextConverter(::print, programAst) val printer = AstToSourceTextConverter(::print, program)
printer.visit(programAst) printer.visit(program)
println() println()
} }

View File

@ -27,7 +27,7 @@ internal class AstChecker(private val program: Program,
if(mainBlocks.size>1) if(mainBlocks.size>1)
errors.err("more than one 'main' block", mainBlocks[0].position) errors.err("more than one 'main' block", mainBlocks[0].position)
if(mainBlocks.isEmpty()) if(mainBlocks.isEmpty())
errors.err("there is no 'main' block", program.modules.firstOrNull()?.position ?: program.position) errors.err("there is no 'main' block", program.modules.firstOrNull()?.position ?: Position.DUMMY)
for(mainBlock in mainBlocks) { for(mainBlock in mainBlocks) {
val startSub = mainBlock.subScope("start") as? Subroutine val startSub = mainBlock.subScope("start") as? Subroutine

View File

@ -26,11 +26,11 @@ class TestCallgraph {
} }
""" """
val result = compileText(C64Target, false, sourcecode).assertSuccess() val result = compileText(C64Target, false, sourcecode).assertSuccess()
val graph = CallGraph(result.programAst) val graph = CallGraph(result.program)
assertEquals(1, graph.imports.size) assertEquals(1, graph.imports.size)
assertEquals(1, graph.importedBy.size) assertEquals(1, graph.importedBy.size)
val toplevelModule = result.programAst.toplevelModule val toplevelModule = result.program.toplevelModule
val importedModule = graph.imports.getValue(toplevelModule).single() val importedModule = graph.imports.getValue(toplevelModule).single()
assertEquals("string", importedModule.name) assertEquals("string", importedModule.name)
val importedBy = graph.importedBy.getValue(importedModule).single() val importedBy = graph.importedBy.getValue(importedModule).single()
@ -45,7 +45,7 @@ class TestCallgraph {
assertFalse(sub in graph.calls) assertFalse(sub in graph.calls)
assertFalse(sub in graph.calledBy) assertFalse(sub in graph.calledBy)
if(sub === result.programAst.entrypoint) if(sub === result.program.entrypoint)
assertFalse(graph.unused(sub), "start() should always be marked as used to avoid having it removed") assertFalse(graph.unused(sub), "start() should always be marked as used to avoid having it removed")
else else
assertTrue(graph.unused(sub)) assertTrue(graph.unused(sub))
@ -66,11 +66,11 @@ class TestCallgraph {
} }
""" """
val result = compileText(C64Target, false, sourcecode).assertSuccess() val result = compileText(C64Target, false, sourcecode).assertSuccess()
val graph = CallGraph(result.programAst) val graph = CallGraph(result.program)
assertEquals(1, graph.imports.size) assertEquals(1, graph.imports.size)
assertEquals(1, graph.importedBy.size) assertEquals(1, graph.importedBy.size)
val toplevelModule = result.programAst.toplevelModule val toplevelModule = result.program.toplevelModule
val importedModule = graph.imports.getValue(toplevelModule).single() val importedModule = graph.imports.getValue(toplevelModule).single()
assertEquals("string", importedModule.name) assertEquals("string", importedModule.name)
val importedBy = graph.importedBy.getValue(importedModule).single() val importedBy = graph.importedBy.getValue(importedModule).single()

View File

@ -34,7 +34,7 @@ class TestCompilerOnCharLit {
} }
""").assertSuccess() """).assertSuccess()
val program = result.programAst val program = result.program
val startSub = program.entrypoint val startSub = program.entrypoint
val funCall = startSub.statements.filterIsInstance<IFunctionCall>()[0] val funCall = startSub.statements.filterIsInstance<IFunctionCall>()[0]
@ -58,7 +58,7 @@ class TestCompilerOnCharLit {
} }
""").assertSuccess() """).assertSuccess()
val program = result.programAst val program = result.program
val startSub = program.entrypoint val startSub = program.entrypoint
val funCall = startSub.statements.filterIsInstance<IFunctionCall>()[0] val funCall = startSub.statements.filterIsInstance<IFunctionCall>()[0]
@ -93,7 +93,7 @@ class TestCompilerOnCharLit {
} }
""").assertSuccess() """).assertSuccess()
val program = result.programAst val program = result.program
val startSub = program.entrypoint val startSub = program.entrypoint
val funCall = startSub.statements.filterIsInstance<IFunctionCall>()[0] val funCall = startSub.statements.filterIsInstance<IFunctionCall>()[0]

View File

@ -37,7 +37,7 @@ class TestCompilerOnImportsAndIncludes {
val result = compileFile(platform, optimize = false, fixturesDir, filepath.name) val result = compileFile(platform, optimize = false, fixturesDir, filepath.name)
.assertSuccess() .assertSuccess()
val program = result.programAst val program = result.program
val startSub = program.entrypoint val startSub = program.entrypoint
val strLits = startSub.statements val strLits = startSub.statements
.filterIsInstance<FunctionCallStatement>() .filterIsInstance<FunctionCallStatement>()
@ -62,7 +62,7 @@ class TestCompilerOnImportsAndIncludes {
val result = compileFile(platform, optimize = false, fixturesDir, filepath.name) val result = compileFile(platform, optimize = false, fixturesDir, filepath.name)
.assertSuccess() .assertSuccess()
val program = result.programAst val program = result.program
val startSub = program.entrypoint val startSub = program.entrypoint
val args = startSub.statements val args = startSub.statements
.filterIsInstance<FunctionCallStatement>() .filterIsInstance<FunctionCallStatement>()

View File

@ -43,7 +43,7 @@ class TestCompilerOnRanges {
} }
""").assertSuccess() """).assertSuccess()
val program = result.programAst val program = result.program
val startSub = program.entrypoint val startSub = program.entrypoint
val decl = startSub val decl = startSub
.statements.filterIsInstance<VarDecl>()[0] .statements.filterIsInstance<VarDecl>()[0]
@ -72,7 +72,7 @@ class TestCompilerOnRanges {
} }
""").assertSuccess() """).assertSuccess()
val program = result.programAst val program = result.program
val startSub = program.entrypoint val startSub = program.entrypoint
val decl = startSub val decl = startSub
.statements.filterIsInstance<VarDecl>()[0] .statements.filterIsInstance<VarDecl>()[0]
@ -154,7 +154,7 @@ class TestCompilerOnRanges {
} }
""").assertSuccess() """).assertSuccess()
val program = result.programAst val program = result.program
val startSub = program.entrypoint val startSub = program.entrypoint
val iterable = startSub val iterable = startSub
.statements.filterIsInstance<ForLoop>() .statements.filterIsInstance<ForLoop>()
@ -185,7 +185,7 @@ class TestCompilerOnRanges {
} }
""").assertSuccess() """).assertSuccess()
val program = result.programAst val program = result.program
val startSub = program.entrypoint val startSub = program.entrypoint
val rangeExpr = startSub val rangeExpr = startSub
.statements.filterIsInstance<ForLoop>() .statements.filterIsInstance<ForLoop>()
@ -212,7 +212,7 @@ class TestCompilerOnRanges {
} }
""").assertSuccess() """).assertSuccess()
val program = result.programAst val program = result.program
val startSub = program.entrypoint val startSub = program.entrypoint
val rangeExpr = startSub val rangeExpr = startSub
.statements.filterIsInstance<ForLoop>() .statements.filterIsInstance<ForLoop>()
@ -256,7 +256,7 @@ class TestCompilerOnRanges {
} }
""").assertSuccess() """).assertSuccess()
val program = result.programAst val program = result.program
val startSub = program.entrypoint val startSub = program.entrypoint
val iterable = startSub val iterable = startSub
.statements.filterIsInstance<ForLoop>() .statements.filterIsInstance<ForLoop>()

View File

@ -30,9 +30,9 @@ main {
} }
} }
""").assertSuccess() """).assertSuccess()
assertTrue(result.programAst.toplevelModule.name.startsWith("on_the_fly_test")) assertTrue(result.program.toplevelModule.name.startsWith("on_the_fly_test"))
val moduleNames = result.programAst.modules.map { it.name } val moduleNames = result.program.modules.map { it.name }
assertTrue(moduleNames[0].startsWith("on_the_fly_test"), "main module must be first") assertTrue(moduleNames[0].startsWith("on_the_fly_test"), "main module must be first")
assertEquals(listOf( assertEquals(listOf(
"prog8_interned_strings", "prog8_interned_strings",
@ -44,7 +44,7 @@ main {
"prog8_lib" "prog8_lib"
), moduleNames.drop(1), "module order in parse tree") ), moduleNames.drop(1), "module order in parse tree")
assertTrue(result.programAst.toplevelModule.name.startsWith("on_the_fly_test")) assertTrue(result.program.toplevelModule.name.startsWith("on_the_fly_test"))
} }
@Test @Test
@ -61,8 +61,8 @@ main {
} }
} }
""").assertSuccess() """).assertSuccess()
assertTrue(result.programAst.toplevelModule.name.startsWith("on_the_fly_test")) assertTrue(result.program.toplevelModule.name.startsWith("on_the_fly_test"))
val options = determineCompilationOptions(result.programAst, C64Target) val options = determineCompilationOptions(result.program, C64Target)
assertTrue(options.floats) assertTrue(options.floats)
assertEquals(ZeropageType.DONTUSE, options.zeropage) assertEquals(ZeropageType.DONTUSE, options.zeropage)
assertTrue(options.noSysInit) assertTrue(options.noSysInit)

View File

@ -26,10 +26,10 @@ class TestOptimization {
} }
""" """
val result = compileText(C64Target, true, sourcecode).assertSuccess() val result = compileText(C64Target, true, sourcecode).assertSuccess()
val toplevelModule = result.programAst.toplevelModule val toplevelModule = result.program.toplevelModule
val mainBlock = toplevelModule.statements.single() as Block val mainBlock = toplevelModule.statements.single() as Block
val startSub = mainBlock.statements.single() as Subroutine val startSub = mainBlock.statements.single() as Subroutine
assertSame(result.programAst.entrypoint, startSub) assertSame(result.program.entrypoint, startSub)
assertEquals("start", startSub.name, "only start sub should remain") assertEquals("start", startSub.name, "only start sub should remain")
assertTrue(startSub.statements.single() is Return, "compiler has inserted return in empty subroutines") assertTrue(startSub.statements.single() is Return, "compiler has inserted return in empty subroutines")
} }
@ -48,11 +48,11 @@ class TestOptimization {
} }
""" """
val result = compileText(C64Target, true, sourcecode).assertSuccess() val result = compileText(C64Target, true, sourcecode).assertSuccess()
val toplevelModule = result.programAst.toplevelModule val toplevelModule = result.program.toplevelModule
val mainBlock = toplevelModule.statements.single() as Block val mainBlock = toplevelModule.statements.single() as Block
val startSub = mainBlock.statements[0] as Subroutine val startSub = mainBlock.statements[0] as Subroutine
val emptySub = mainBlock.statements[1] as Subroutine val emptySub = mainBlock.statements[1] as Subroutine
assertSame(result.programAst.entrypoint, startSub) assertSame(result.program.entrypoint, startSub)
assertEquals("start", startSub.name) assertEquals("start", startSub.name)
assertEquals("empty", emptySub.name) assertEquals("empty", emptySub.name)
assertTrue(emptySub.statements.single() is Return, "compiler has inserted return in empty subroutines") assertTrue(emptySub.statements.single() is Return, "compiler has inserted return in empty subroutines")

View File

@ -2,19 +2,35 @@ package prog8tests
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import org.junit.jupiter.api.TestInstance import org.junit.jupiter.api.TestInstance
import prog8.ast.GlobalNamespace
import prog8.ast.base.ParentSentinel
import prog8.ast.expressions.NumericLiteralValue import prog8.ast.expressions.NumericLiteralValue
import prog8.ast.statements.* import prog8.ast.statements.*
import prog8.compiler.target.C64Target import prog8.compiler.target.C64Target
import prog8tests.helpers.assertSuccess import prog8tests.helpers.assertSuccess
import prog8tests.helpers.compileText import prog8tests.helpers.compileText
import kotlin.test.assertEquals import kotlin.test.*
import kotlin.test.assertFalse
import kotlin.test.assertTrue
@TestInstance(TestInstance.Lifecycle.PER_CLASS) @TestInstance(TestInstance.Lifecycle.PER_CLASS)
class TestScoping { class TestScoping {
@Test
fun testModulesParentIsGlobalNamespace() {
val src = """
main {
sub start() {
}
}
"""
val result = compileText(C64Target, false, src, writeAssembly = false).assertSuccess()
val module = result.program.toplevelModule
assertIs<GlobalNamespace>(module.parent)
assertSame(result.program, module.program)
assertIs<ParentSentinel>(module.parent.parent)
}
@Test @Test
fun testAnonScopeVarsMovedIntoSubroutineScope() { fun testAnonScopeVarsMovedIntoSubroutineScope() {
val src = """ val src = """
@ -29,7 +45,7 @@ class TestScoping {
""" """
val result = compileText(C64Target, false, src, writeAssembly = false).assertSuccess() val result = compileText(C64Target, false, src, writeAssembly = false).assertSuccess()
val module = result.programAst.toplevelModule val module = result.program.toplevelModule
val mainBlock = module.statements.single() as Block val mainBlock = module.statements.single() as Block
val start = mainBlock.statements.single() as Subroutine val start = mainBlock.statements.single() as Subroutine
val repeatbody = start.statements.filterIsInstance<RepeatLoop>().single().body val repeatbody = start.statements.filterIsInstance<RepeatLoop>().single().body
@ -100,7 +116,7 @@ class TestScoping {
""" """
val result = compileText(C64Target, false, src, writeAssembly = true).assertSuccess() val result = compileText(C64Target, false, src, writeAssembly = true).assertSuccess()
val module = result.programAst.toplevelModule val module = result.program.toplevelModule
val mainBlock = module.statements.single() as Block val mainBlock = module.statements.single() as Block
val start = mainBlock.statements.single() as Subroutine val start = mainBlock.statements.single() as Subroutine
val labels = start.statements.filterIsInstance<Label>() val labels = start.statements.filterIsInstance<Label>()

View File

@ -46,7 +46,7 @@ class TestSubroutines {
} }
""" """
val result = compileText(C64Target, false, text, writeAssembly = false).assertSuccess() val result = compileText(C64Target, false, text, writeAssembly = false).assertSuccess()
val module = result.programAst.toplevelModule val module = result.program.toplevelModule
val mainBlock = module.statements.single() as Block val mainBlock = module.statements.single() as Block
val asmfunc = mainBlock.statements.filterIsInstance<Subroutine>().single { it.name=="asmfunc"} val asmfunc = mainBlock.statements.filterIsInstance<Subroutine>().single { it.name=="asmfunc"}
val func = mainBlock.statements.filterIsInstance<Subroutine>().single { it.name=="func"} val func = mainBlock.statements.filterIsInstance<Subroutine>().single { it.name=="func"}
@ -94,7 +94,7 @@ class TestSubroutines {
} }
""" """
val result = compileText(C64Target, false, text, writeAssembly = true).assertSuccess() val result = compileText(C64Target, false, text, writeAssembly = true).assertSuccess()
val module = result.programAst.toplevelModule val module = result.program.toplevelModule
val mainBlock = module.statements.single() as Block val mainBlock = module.statements.single() as Block
val asmfunc = mainBlock.statements.filterIsInstance<Subroutine>().single { it.name=="asmfunc"} val asmfunc = mainBlock.statements.filterIsInstance<Subroutine>().single { it.name=="asmfunc"}
val func = mainBlock.statements.filterIsInstance<Subroutine>().single { it.name=="func"} val func = mainBlock.statements.filterIsInstance<Subroutine>().single { it.name=="func"}
@ -170,7 +170,7 @@ class TestSubroutines {
""" """
val result = compileText(C64Target, false, text, writeAssembly = false).assertSuccess() val result = compileText(C64Target, false, text, writeAssembly = false).assertSuccess()
val module = result.programAst.toplevelModule val module = result.program.toplevelModule
val mainBlock = module.statements.single() as Block val mainBlock = module.statements.single() as Block
val asmfunc = mainBlock.statements.filterIsInstance<Subroutine>().single { it.name=="asmfunc"} val asmfunc = mainBlock.statements.filterIsInstance<Subroutine>().single { it.name=="asmfunc"}
val func = mainBlock.statements.filterIsInstance<Subroutine>().single { it.name=="func"} val func = mainBlock.statements.filterIsInstance<Subroutine>().single { it.name=="func"}

View File

@ -3,12 +3,9 @@ package prog8.ast
import prog8.ast.base.* import prog8.ast.base.*
import prog8.ast.expressions.Expression import prog8.ast.expressions.Expression
import prog8.ast.expressions.IdentifierReference import prog8.ast.expressions.IdentifierReference
import prog8.ast.expressions.StringLiteralValue
import prog8.ast.statements.* import prog8.ast.statements.*
import prog8.ast.walk.AstWalker import prog8.ast.walk.AstWalker
import prog8.ast.walk.IAstVisitor import prog8.ast.walk.IAstVisitor
import prog8.compilerinterface.IMemSizer
import prog8.compilerinterface.IStringEncoding
import prog8.parser.SourceCode import prog8.parser.SourceCode
const val internedStringsModuleName = "prog8_interned_strings" const val internedStringsModuleName = "prog8_interned_strings"
@ -284,121 +281,6 @@ interface Node {
} }
/*********** Everything starts from here, the Program; zero or more modules *************/
class Program(val name: String,
val builtinFunctions: IBuiltinFunctions,
val memsizer: IMemSizer,
val encoding: IStringEncoding
): Node {
private val _modules = mutableListOf<Module>()
val modules: List<Module> = _modules
val namespace: GlobalNamespace = GlobalNamespace(modules)
init {
// insert a container module for all interned strings later
val internedStringsModule = Module(mutableListOf(), Position.DUMMY, SourceCode.Generated(internedStringsModuleName))
val block = Block(internedStringsModuleName, null, mutableListOf(), true, Position.DUMMY)
internedStringsModule.statements.add(block)
_modules.add(0, internedStringsModule)
internedStringsModule.linkParents(namespace)
internedStringsModule.program = this
}
fun addModule(module: Module): Program {
require(null == _modules.firstOrNull { it.name == module.name })
{ "module '${module.name}' already present" }
_modules.add(module)
module.linkIntoProgram(this)
return this
}
fun removeModule(module: Module) = _modules.remove(module)
fun moveModuleToFront(module: Module): Program {
require(_modules.contains(module))
{ "Not a module of this program: '${module.name}'"}
_modules.remove(module)
_modules.add(0, module)
return this
}
val allBlocks: List<Block>
get() = modules.flatMap { it.statements.filterIsInstance<Block>() }
val entrypoint: Subroutine
get() {
val mainBlocks = allBlocks.filter { it.name == "main" }
return when (mainBlocks.size) {
0 -> throw FatalAstException("no 'main' block")
1 -> mainBlocks[0].subScope("start") as Subroutine
else -> throw FatalAstException("more than one 'main' block")
}
}
val toplevelModule: Module
get() = modules.first { it.name!=internedStringsModuleName }
val definedLoadAddress: Int
get() = toplevelModule.loadAddress
var actualLoadAddress: Int = 0
private val internedStringsUnique = mutableMapOf<Pair<String, Boolean>, List<String>>()
fun internString(string: StringLiteralValue): List<String> {
// Move a string literal into the internal, deduplicated, string pool
// replace it with a variable declaration that points to the entry in the pool.
if(string.parent is VarDecl) {
// deduplication can only be performed safely for known-const strings (=string literals OUTSIDE OF A VARDECL)!
throw FatalAstException("cannot intern a string literal that's part of a vardecl")
}
fun getScopedName(string: StringLiteralValue): List<String> {
val internedStringsBlock = modules
.first { it.name == internedStringsModuleName }.statements
.first { it is Block && it.name == internedStringsModuleName } as Block
val varName = "string_${internedStringsBlock.statements.size}"
val decl = VarDecl(
VarDeclType.VAR, DataType.STR, ZeropageWish.NOT_IN_ZEROPAGE, null, varName, string,
isArray = false, autogeneratedDontRemove = true, sharedWithAsm = false, position = string.position
)
internedStringsBlock.statements.add(decl)
decl.linkParents(internedStringsBlock)
return listOf(internedStringsModuleName, decl.name)
}
val key = Pair(string.value, string.altEncoding)
val existing = internedStringsUnique[key]
if (existing != null)
return existing
val scopedName = getScopedName(string)
internedStringsUnique[key] = scopedName
return scopedName
}
override val position: Position = Position.DUMMY
override var parent: Node
get() = throw FatalAstException("program has no parent")
set(_) = throw FatalAstException("can't set parent of program")
override fun linkParents(parent: Node) {
modules.forEach {
it.linkParents(namespace)
}
}
override fun replaceChildNode(node: Node, replacement: Node) {
require(node is Module && replacement is Module)
val idx = _modules.indexOfFirst { it===node }
_modules[idx] = replacement
replacement.linkIntoProgram(this)
}
}
open class Module(final override var statements: MutableList<Statement>, open class Module(final override var statements: MutableList<Statement>,
final override val position: Position, final override val position: Position,
val source: SourceCode) : Node, INameScope { val source: SourceCode) : Node, INameScope {

View File

@ -0,0 +1,112 @@
package prog8.ast
import prog8.ast.base.DataType
import prog8.ast.base.FatalAstException
import prog8.ast.base.Position
import prog8.ast.base.VarDeclType
import prog8.ast.expressions.StringLiteralValue
import prog8.ast.statements.Block
import prog8.ast.statements.Subroutine
import prog8.ast.statements.VarDecl
import prog8.ast.statements.ZeropageWish
import prog8.compilerinterface.IMemSizer
import prog8.compilerinterface.IStringEncoding
import prog8.parser.SourceCode
/*********** Everything starts from here, the Program; zero or more modules *************/
class Program(val name: String,
val builtinFunctions: IBuiltinFunctions,
val memsizer: IMemSizer,
val encoding: IStringEncoding
) {
private val _modules = mutableListOf<Module>()
val modules: List<Module> = _modules
val namespace: GlobalNamespace = GlobalNamespace(modules)
init {
// insert a container module for all interned strings later
val internedStringsModule =
Module(mutableListOf(), Position.DUMMY, SourceCode.Generated(internedStringsModuleName))
val block = Block(internedStringsModuleName, null, mutableListOf(), true, Position.DUMMY)
internedStringsModule.statements.add(block)
_modules.add(0, internedStringsModule)
internedStringsModule.linkParents(namespace)
internedStringsModule.program = this
}
fun addModule(module: Module): Program {
require(null == _modules.firstOrNull { it.name == module.name })
{ "module '${module.name}' already present" }
_modules.add(module)
module.linkIntoProgram(this)
return this
}
fun removeModule(module: Module) = _modules.remove(module)
fun moveModuleToFront(module: Module): Program {
require(_modules.contains(module))
{ "Not a module of this program: '${module.name}'"}
_modules.remove(module)
_modules.add(0, module)
return this
}
val allBlocks: List<Block>
get() = modules.flatMap { it.statements.filterIsInstance<Block>() }
val entrypoint: Subroutine
get() {
val mainBlocks = allBlocks.filter { it.name == "main" }
return when (mainBlocks.size) {
0 -> throw FatalAstException("no 'main' block")
1 -> mainBlocks[0].subScope("start") as Subroutine
else -> throw FatalAstException("more than one 'main' block")
}
}
val toplevelModule: Module
get() = modules.first { it.name!= internedStringsModuleName }
val definedLoadAddress: Int
get() = toplevelModule.loadAddress
var actualLoadAddress: Int = 0
private val internedStringsUnique = mutableMapOf<Pair<String, Boolean>, List<String>>()
fun internString(string: StringLiteralValue): List<String> {
// Move a string literal into the internal, deduplicated, string pool
// replace it with a variable declaration that points to the entry in the pool.
if(string.parent is VarDecl) {
// deduplication can only be performed safely for known-const strings (=string literals OUTSIDE OF A VARDECL)!
throw FatalAstException("cannot intern a string literal that's part of a vardecl")
}
fun getScopedName(string: StringLiteralValue): List<String> {
val internedStringsBlock = modules
.first { it.name == internedStringsModuleName }.statements
.first { it is Block && it.name == internedStringsModuleName } as Block
val varName = "string_${internedStringsBlock.statements.size}"
val decl = VarDecl(
VarDeclType.VAR, DataType.STR, ZeropageWish.NOT_IN_ZEROPAGE, null, varName, string,
isArray = false, autogeneratedDontRemove = true, sharedWithAsm = false, position = string.position
)
internedStringsBlock.statements.add(decl)
decl.linkParents(internedStringsBlock)
return listOf(internedStringsModuleName, decl.name)
}
val key = Pair(string.value, string.altEncoding)
val existing = internedStringsUnique[key]
if (existing != null)
return existing
val scopedName = getScopedName(string)
internedStringsUnique[key] = scopedName
return scopedName
}
}

View File

@ -2,6 +2,7 @@ package prog8.ast.walk
import prog8.ast.* import prog8.ast.*
import prog8.ast.base.FatalAstException import prog8.ast.base.FatalAstException
import prog8.ast.base.ParentSentinel
import prog8.ast.expressions.* import prog8.ast.expressions.*
import prog8.ast.statements.* import prog8.ast.statements.*
@ -105,7 +106,7 @@ abstract class AstWalker {
open fun before(nopStatement: NopStatement, parent: Node): Iterable<IAstModification> = noModifications open fun before(nopStatement: NopStatement, parent: Node): Iterable<IAstModification> = noModifications
open fun before(numLiteral: NumericLiteralValue, parent: Node): Iterable<IAstModification> = noModifications open fun before(numLiteral: NumericLiteralValue, parent: Node): Iterable<IAstModification> = noModifications
open fun before(postIncrDecr: PostIncrDecr, parent: Node): Iterable<IAstModification> = noModifications open fun before(postIncrDecr: PostIncrDecr, parent: Node): Iterable<IAstModification> = noModifications
open fun before(program: Program, parent: Node): Iterable<IAstModification> = noModifications open fun before(program: Program): Iterable<IAstModification> = noModifications
open fun before(range: RangeExpr, parent: Node): Iterable<IAstModification> = noModifications open fun before(range: RangeExpr, parent: Node): Iterable<IAstModification> = noModifications
open fun before(untilLoop: UntilLoop, parent: Node): Iterable<IAstModification> = noModifications open fun before(untilLoop: UntilLoop, parent: Node): Iterable<IAstModification> = noModifications
open fun before(returnStmt: Return, parent: Node): Iterable<IAstModification> = noModifications open fun before(returnStmt: Return, parent: Node): Iterable<IAstModification> = noModifications
@ -146,7 +147,7 @@ abstract class AstWalker {
open fun after(nopStatement: NopStatement, parent: Node): Iterable<IAstModification> = noModifications open fun after(nopStatement: NopStatement, parent: Node): Iterable<IAstModification> = noModifications
open fun after(numLiteral: NumericLiteralValue, parent: Node): Iterable<IAstModification> = noModifications open fun after(numLiteral: NumericLiteralValue, parent: Node): Iterable<IAstModification> = noModifications
open fun after(postIncrDecr: PostIncrDecr, parent: Node): Iterable<IAstModification> = noModifications open fun after(postIncrDecr: PostIncrDecr, parent: Node): Iterable<IAstModification> = noModifications
open fun after(program: Program, parent: Node): Iterable<IAstModification> = noModifications open fun after(program: Program): Iterable<IAstModification> = noModifications
open fun after(range: RangeExpr, parent: Node): Iterable<IAstModification> = noModifications open fun after(range: RangeExpr, parent: Node): Iterable<IAstModification> = noModifications
open fun after(untilLoop: UntilLoop, parent: Node): Iterable<IAstModification> = noModifications open fun after(untilLoop: UntilLoop, parent: Node): Iterable<IAstModification> = noModifications
open fun after(returnStmt: Return, parent: Node): Iterable<IAstModification> = noModifications open fun after(returnStmt: Return, parent: Node): Iterable<IAstModification> = noModifications
@ -196,9 +197,9 @@ abstract class AstWalker {
} }
fun visit(program: Program) { fun visit(program: Program) {
track(before(program, program), program, program) track(before(program), ParentSentinel, program.namespace)
program.modules.forEach { it.accept(this, program) } program.modules.forEach { it.accept(this, program.namespace) }
track(after(program, program), program, program) track(after(program), ParentSentinel, program.namespace)
} }
fun visit(module: Module, parent: Node) { fun visit(module: Module, parent: Node) {