From 3767b4bbe758d8f7a187aca2a39c2ac53a770fb8 Mon Sep 17 00:00:00 2001 From: Irmen de Jong Date: Sat, 30 Oct 2021 00:25:34 +0200 Subject: [PATCH] 'Program' is not an ast Node --- compiler/src/prog8/compiler/Compiler.kt | 98 +++++++-------- .../compiler/astprocessing/AstChecker.kt | 2 +- compiler/test/TestCallgraph.kt | 10 +- compiler/test/TestCompilerOnCharLit.kt | 6 +- .../test/TestCompilerOnImportsAndIncludes.kt | 4 +- compiler/test/TestCompilerOnRanges.kt | 12 +- .../TestImportedModulesOrderAndOptions.kt | 10 +- compiler/test/TestOptimization.kt | 8 +- compiler/test/TestScoping.kt | 26 +++- compiler/test/TestSubroutines.kt | 6 +- compilerAst/src/prog8/ast/AstToplevel.kt | 118 ------------------ compilerAst/src/prog8/ast/Program.kt | 112 +++++++++++++++++ compilerAst/src/prog8/ast/walk/AstWalker.kt | 11 +- 13 files changed, 217 insertions(+), 206 deletions(-) create mode 100644 compilerAst/src/prog8/ast/Program.kt diff --git a/compiler/src/prog8/compiler/Compiler.kt b/compiler/src/prog8/compiler/Compiler.kt index 84ec61cbb..7bffcecad 100644 --- a/compiler/src/prog8/compiler/Compiler.kt +++ b/compiler/src/prog8/compiler/Compiler.kt @@ -24,7 +24,7 @@ import kotlin.system.measureTimeMillis class CompilationResult(val success: Boolean, - val programAst: Program, + val program: Program, val programName: String, val compTarget: ICompilationTarget, val importedFiles: List) @@ -39,7 +39,7 @@ fun compileProgram(filepath: Path, outputDir: Path, errors: IErrorReporter = ErrorReporter()): CompilationResult { var programName = "" - lateinit var programAst: Program + lateinit var program: Program lateinit var importedFiles: List val compTarget = @@ -52,31 +52,31 @@ fun compileProgram(filepath: Path, try { val totalTime = measureTimeMillis { // import main module and everything it needs - val (ast, compilationOptions, imported) = parseImports(filepath, errors, compTarget, sourceDirs) + val (programresult, compilationOptions, imported) = parseImports(filepath, errors, compTarget, sourceDirs) compilationOptions.slowCodegenWarnings = slowCodegenWarnings compilationOptions.optimize = optimize - programAst = ast + program = programresult importedFiles = imported - processAst(programAst, errors, compilationOptions) + processAst(program, errors, compilationOptions) if (compilationOptions.optimize) optimizeAst( - programAst, + program, errors, BuiltinFunctionsFacade(BuiltinFunctions), compTarget ) - postprocessAst(programAst, errors, compilationOptions) + postprocessAst(program, errors, compilationOptions) // println("*********** AST BEFORE ASSEMBLYGEN *************") -// printAst(programAst) +// printAst(program) if (writeAssembly) { - val result = writeAssembly(programAst, errors, outputDir, compilationOptions) + val result = writeAssembly(program, errors, outputDir, compilationOptions) when (result) { is WriteAssemblyResult.Ok -> programName = result.filename is WriteAssemblyResult.Fail -> { System.err.println(result.error) - return CompilationResult(false, programAst, programName, compTarget, importedFiles) + return CompilationResult(false, program, programName, compTarget, importedFiles) } } } @@ -84,7 +84,7 @@ fun compileProgram(filepath: Path, System.out.flush() System.err.flush() println("\nTotal compilation+assemble time: ${totalTime / 1000.0} sec.") - return CompilationResult(true, programAst, programName, compTarget, importedFiles) + return CompilationResult(true, program, programName, compTarget, importedFiles) } catch (px: ParseError) { System.err.print("\u001b[91m") // bright red System.err.println("${px.position.toClickableStr()} parse error: ${px.message}".trim()) @@ -155,20 +155,20 @@ fun parseImports(filepath: Path, sourceDirs: List): Triple> { println("Compiler target: ${compTarget.name}. Parsing...") val bf = BuiltinFunctionsFacade(BuiltinFunctions) - val programAst = Program(filepath.nameWithoutExtension, bf, compTarget, compTarget) - bf.program = programAst + val program = Program(filepath.nameWithoutExtension, bf, compTarget, compTarget) + bf.program = program - val importer = ModuleImporter(programAst, compTarget.name, errors, sourceDirs) + val importer = ModuleImporter(program, compTarget.name, errors, sourceDirs) val importedModuleResult = importer.importModule(filepath) importedModuleResult.onFailure { throw it } errors.report() - val importedFiles = programAst.modules.map { it.source } + val importedFiles = program.modules.map { it.source } .filter { it.isFromFilesystem } .map { Path(it.origin) } - val compilerOptions = determineCompilationOptions(programAst, compTarget) + val compilerOptions = determineCompilationOptions(program, compTarget) if (compilerOptions.launcher == LauncherType.BASIC && compilerOptions.output != OutputType.PRG) - throw ParsingFailedError("${programAst.modules.first().position} BASIC launcher requires output type PRG.") + throw ParsingFailedError("${program.modules.first().position} BASIC launcher requires output type PRG.") // depending on the machine and compiler options we may have to include some libraries for(lib in compTarget.machine.importLibs(compilerOptions, compTarget.name)) @@ -178,7 +178,7 @@ fun parseImports(filepath: Path, importer.importLibraryModule("math") importer.importLibraryModule("prog8_lib") errors.report() - return Triple(programAst, compilerOptions, importedFiles) + return Triple(program, compilerOptions, importedFiles) } fun determineCompilationOptions(program: Program, compTarget: ICompilationTarget): CompilationOptions { @@ -241,44 +241,44 @@ fun determineCompilationOptions(program: Program, compTarget: ICompilationTarget ) } -private fun processAst(programAst: Program, errors: IErrorReporter, compilerOptions: CompilationOptions) { +private fun processAst(program: Program, errors: IErrorReporter, compilerOptions: CompilationOptions) { // perform initial syntax checks and processings println("Processing for target ${compilerOptions.compTarget.name}...") - programAst.preprocessAst() - programAst.checkIdentifiers(errors, compilerOptions) + program.preprocessAst() + program.checkIdentifiers(errors, compilerOptions) errors.report() // TODO: turning char literals into UBYTEs via an encoding should really happen in code gen - but for that we'd need DataType.CHAR // NOTE: we will then lose the opportunity to do constant-folding on any expression containing a char literal, but how often will those occur? // Also they might be optimized away eventually in codegen or by the assembler even - programAst.charLiteralsToUByteLiterals(compilerOptions.compTarget) - programAst.constantFold(errors, compilerOptions.compTarget) + program.charLiteralsToUByteLiterals(compilerOptions.compTarget) + program.constantFold(errors, compilerOptions.compTarget) errors.report() - programAst.reorderStatements(errors) + program.reorderStatements(errors) errors.report() - programAst.addTypecasts(errors) + program.addTypecasts(errors) errors.report() - programAst.variousCleanups(programAst, errors) + program.variousCleanups(program, errors) errors.report() - programAst.checkValid(compilerOptions, errors, compilerOptions.compTarget) + program.checkValid(compilerOptions, errors, compilerOptions.compTarget) errors.report() - programAst.checkIdentifiers(errors, compilerOptions) + program.checkIdentifiers(errors, compilerOptions) errors.report() } -private fun optimizeAst(programAst: Program, errors: IErrorReporter, functions: IBuiltinFunctions, compTarget: ICompilationTarget) { +private fun optimizeAst(program: Program, errors: IErrorReporter, functions: IBuiltinFunctions, compTarget: ICompilationTarget) { // optimize the parse tree println("Optimizing...") - val remover = UnusedCodeRemover(programAst, errors, compTarget) - remover.visit(programAst) + val remover = UnusedCodeRemover(program, errors, compTarget) + remover.visit(program) remover.applyModifications() while (true) { // keep optimizing expressions and statements until no more steps remain - val optsDone1 = programAst.simplifyExpressions() - val optsDone2 = programAst.splitBinaryExpressions(compTarget) - val optsDone3 = programAst.optimizeStatements(errors, functions, compTarget) - programAst.constantFold(errors, compTarget) // because simplified statements and expressions can result in more constants that can be folded away + val optsDone1 = program.simplifyExpressions() + val optsDone2 = program.splitBinaryExpressions(compTarget) + val optsDone3 = program.optimizeStatements(errors, functions, compTarget) + program.constantFold(errors, compTarget) // because simplified statements and expressions can result in more constants that can be folded away errors.report() if (optsDone1 + optsDone2 + optsDone3 == 0) break @@ -287,17 +287,17 @@ private fun optimizeAst(programAst: Program, errors: IErrorReporter, functions: errors.report() } -private fun postprocessAst(programAst: Program, errors: IErrorReporter, compilerOptions: CompilationOptions) { - programAst.addTypecasts(errors) +private fun postprocessAst(program: Program, errors: IErrorReporter, compilerOptions: CompilationOptions) { + program.addTypecasts(errors) errors.report() - programAst.variousCleanups(programAst, errors) - programAst.checkValid(compilerOptions, errors, compilerOptions.compTarget) // check if final tree is still valid + program.variousCleanups(program, errors) + program.checkValid(compilerOptions, errors, compilerOptions.compTarget) // check if final tree is still valid errors.report() - val callGraph = CallGraph(programAst) + val callGraph = CallGraph(program) callGraph.checkRecursiveCalls(errors) errors.report() - programAst.verifyFunctionArgTypes() - programAst.moveMainAndStartToFirst() + program.verifyFunctionArgTypes() + program.moveMainAndStartToFirst() } private sealed class WriteAssemblyResult { @@ -305,20 +305,20 @@ private sealed class WriteAssemblyResult { class Fail(val error: String): WriteAssemblyResult() } -private fun writeAssembly(programAst: Program, +private fun writeAssembly(program: Program, errors: IErrorReporter, outputDir: Path, compilerOptions: CompilationOptions ): WriteAssemblyResult { // asm generation directly from the Ast - programAst.processAstBeforeAsmGeneration(errors, compilerOptions.compTarget) + program.processAstBeforeAsmGeneration(errors, compilerOptions.compTarget) errors.report() -// printAst(programAst) +// printAst(program) compilerOptions.compTarget.machine.initializeZeropage(compilerOptions) val assembly = asmGeneratorFor(compilerOptions.compTarget, - programAst, + program, errors, compilerOptions.compTarget.machine.zeropage, compilerOptions, @@ -338,10 +338,10 @@ private fun writeAssembly(programAst: Program, } } -fun printAst(programAst: Program) { +fun printAst(program: Program) { println() - val printer = AstToSourceTextConverter(::print, programAst) - printer.visit(programAst) + val printer = AstToSourceTextConverter(::print, program) + printer.visit(program) println() } diff --git a/compiler/src/prog8/compiler/astprocessing/AstChecker.kt b/compiler/src/prog8/compiler/astprocessing/AstChecker.kt index 2ac7c5050..8b0875953 100644 --- a/compiler/src/prog8/compiler/astprocessing/AstChecker.kt +++ b/compiler/src/prog8/compiler/astprocessing/AstChecker.kt @@ -27,7 +27,7 @@ internal class AstChecker(private val program: Program, if(mainBlocks.size>1) errors.err("more than one 'main' block", mainBlocks[0].position) if(mainBlocks.isEmpty()) - errors.err("there is no 'main' block", program.modules.firstOrNull()?.position ?: program.position) + errors.err("there is no 'main' block", program.modules.firstOrNull()?.position ?: Position.DUMMY) for(mainBlock in mainBlocks) { val startSub = mainBlock.subScope("start") as? Subroutine diff --git a/compiler/test/TestCallgraph.kt b/compiler/test/TestCallgraph.kt index 43c79bbd9..bbe140d75 100644 --- a/compiler/test/TestCallgraph.kt +++ b/compiler/test/TestCallgraph.kt @@ -26,11 +26,11 @@ class TestCallgraph { } """ val result = compileText(C64Target, false, sourcecode).assertSuccess() - val graph = CallGraph(result.programAst) + val graph = CallGraph(result.program) assertEquals(1, graph.imports.size) assertEquals(1, graph.importedBy.size) - val toplevelModule = result.programAst.toplevelModule + val toplevelModule = result.program.toplevelModule val importedModule = graph.imports.getValue(toplevelModule).single() assertEquals("string", importedModule.name) val importedBy = graph.importedBy.getValue(importedModule).single() @@ -45,7 +45,7 @@ class TestCallgraph { assertFalse(sub in graph.calls) assertFalse(sub in graph.calledBy) - if(sub === result.programAst.entrypoint) + if(sub === result.program.entrypoint) assertFalse(graph.unused(sub), "start() should always be marked as used to avoid having it removed") else assertTrue(graph.unused(sub)) @@ -66,11 +66,11 @@ class TestCallgraph { } """ val result = compileText(C64Target, false, sourcecode).assertSuccess() - val graph = CallGraph(result.programAst) + val graph = CallGraph(result.program) assertEquals(1, graph.imports.size) assertEquals(1, graph.importedBy.size) - val toplevelModule = result.programAst.toplevelModule + val toplevelModule = result.program.toplevelModule val importedModule = graph.imports.getValue(toplevelModule).single() assertEquals("string", importedModule.name) val importedBy = graph.importedBy.getValue(importedModule).single() diff --git a/compiler/test/TestCompilerOnCharLit.kt b/compiler/test/TestCompilerOnCharLit.kt index 583d9983a..0b73219d2 100644 --- a/compiler/test/TestCompilerOnCharLit.kt +++ b/compiler/test/TestCompilerOnCharLit.kt @@ -34,7 +34,7 @@ class TestCompilerOnCharLit { } """).assertSuccess() - val program = result.programAst + val program = result.program val startSub = program.entrypoint val funCall = startSub.statements.filterIsInstance()[0] @@ -58,7 +58,7 @@ class TestCompilerOnCharLit { } """).assertSuccess() - val program = result.programAst + val program = result.program val startSub = program.entrypoint val funCall = startSub.statements.filterIsInstance()[0] @@ -93,7 +93,7 @@ class TestCompilerOnCharLit { } """).assertSuccess() - val program = result.programAst + val program = result.program val startSub = program.entrypoint val funCall = startSub.statements.filterIsInstance()[0] diff --git a/compiler/test/TestCompilerOnImportsAndIncludes.kt b/compiler/test/TestCompilerOnImportsAndIncludes.kt index 241cb660e..c2db07589 100644 --- a/compiler/test/TestCompilerOnImportsAndIncludes.kt +++ b/compiler/test/TestCompilerOnImportsAndIncludes.kt @@ -37,7 +37,7 @@ class TestCompilerOnImportsAndIncludes { val result = compileFile(platform, optimize = false, fixturesDir, filepath.name) .assertSuccess() - val program = result.programAst + val program = result.program val startSub = program.entrypoint val strLits = startSub.statements .filterIsInstance() @@ -62,7 +62,7 @@ class TestCompilerOnImportsAndIncludes { val result = compileFile(platform, optimize = false, fixturesDir, filepath.name) .assertSuccess() - val program = result.programAst + val program = result.program val startSub = program.entrypoint val args = startSub.statements .filterIsInstance() diff --git a/compiler/test/TestCompilerOnRanges.kt b/compiler/test/TestCompilerOnRanges.kt index 46b78da46..fce4616d9 100644 --- a/compiler/test/TestCompilerOnRanges.kt +++ b/compiler/test/TestCompilerOnRanges.kt @@ -43,7 +43,7 @@ class TestCompilerOnRanges { } """).assertSuccess() - val program = result.programAst + val program = result.program val startSub = program.entrypoint val decl = startSub .statements.filterIsInstance()[0] @@ -72,7 +72,7 @@ class TestCompilerOnRanges { } """).assertSuccess() - val program = result.programAst + val program = result.program val startSub = program.entrypoint val decl = startSub .statements.filterIsInstance()[0] @@ -154,7 +154,7 @@ class TestCompilerOnRanges { } """).assertSuccess() - val program = result.programAst + val program = result.program val startSub = program.entrypoint val iterable = startSub .statements.filterIsInstance() @@ -185,7 +185,7 @@ class TestCompilerOnRanges { } """).assertSuccess() - val program = result.programAst + val program = result.program val startSub = program.entrypoint val rangeExpr = startSub .statements.filterIsInstance() @@ -212,7 +212,7 @@ class TestCompilerOnRanges { } """).assertSuccess() - val program = result.programAst + val program = result.program val startSub = program.entrypoint val rangeExpr = startSub .statements.filterIsInstance() @@ -256,7 +256,7 @@ class TestCompilerOnRanges { } """).assertSuccess() - val program = result.programAst + val program = result.program val startSub = program.entrypoint val iterable = startSub .statements.filterIsInstance() diff --git a/compiler/test/TestImportedModulesOrderAndOptions.kt b/compiler/test/TestImportedModulesOrderAndOptions.kt index 44a4fc0a5..4b726157e 100644 --- a/compiler/test/TestImportedModulesOrderAndOptions.kt +++ b/compiler/test/TestImportedModulesOrderAndOptions.kt @@ -30,9 +30,9 @@ main { } } """).assertSuccess() - assertTrue(result.programAst.toplevelModule.name.startsWith("on_the_fly_test")) + assertTrue(result.program.toplevelModule.name.startsWith("on_the_fly_test")) - val moduleNames = result.programAst.modules.map { it.name } + val moduleNames = result.program.modules.map { it.name } assertTrue(moduleNames[0].startsWith("on_the_fly_test"), "main module must be first") assertEquals(listOf( "prog8_interned_strings", @@ -44,7 +44,7 @@ main { "prog8_lib" ), moduleNames.drop(1), "module order in parse tree") - assertTrue(result.programAst.toplevelModule.name.startsWith("on_the_fly_test")) + assertTrue(result.program.toplevelModule.name.startsWith("on_the_fly_test")) } @Test @@ -61,8 +61,8 @@ main { } } """).assertSuccess() - assertTrue(result.programAst.toplevelModule.name.startsWith("on_the_fly_test")) - val options = determineCompilationOptions(result.programAst, C64Target) + assertTrue(result.program.toplevelModule.name.startsWith("on_the_fly_test")) + val options = determineCompilationOptions(result.program, C64Target) assertTrue(options.floats) assertEquals(ZeropageType.DONTUSE, options.zeropage) assertTrue(options.noSysInit) diff --git a/compiler/test/TestOptimization.kt b/compiler/test/TestOptimization.kt index b073f0d75..6edc7bcda 100644 --- a/compiler/test/TestOptimization.kt +++ b/compiler/test/TestOptimization.kt @@ -26,10 +26,10 @@ class TestOptimization { } """ val result = compileText(C64Target, true, sourcecode).assertSuccess() - val toplevelModule = result.programAst.toplevelModule + val toplevelModule = result.program.toplevelModule val mainBlock = toplevelModule.statements.single() as Block val startSub = mainBlock.statements.single() as Subroutine - assertSame(result.programAst.entrypoint, startSub) + assertSame(result.program.entrypoint, startSub) assertEquals("start", startSub.name, "only start sub should remain") assertTrue(startSub.statements.single() is Return, "compiler has inserted return in empty subroutines") } @@ -48,11 +48,11 @@ class TestOptimization { } """ val result = compileText(C64Target, true, sourcecode).assertSuccess() - val toplevelModule = result.programAst.toplevelModule + val toplevelModule = result.program.toplevelModule val mainBlock = toplevelModule.statements.single() as Block val startSub = mainBlock.statements[0] as Subroutine val emptySub = mainBlock.statements[1] as Subroutine - assertSame(result.programAst.entrypoint, startSub) + assertSame(result.program.entrypoint, startSub) assertEquals("start", startSub.name) assertEquals("empty", emptySub.name) assertTrue(emptySub.statements.single() is Return, "compiler has inserted return in empty subroutines") diff --git a/compiler/test/TestScoping.kt b/compiler/test/TestScoping.kt index b8cd40cf3..05bf26906 100644 --- a/compiler/test/TestScoping.kt +++ b/compiler/test/TestScoping.kt @@ -2,19 +2,35 @@ package prog8tests import org.junit.jupiter.api.Test import org.junit.jupiter.api.TestInstance +import prog8.ast.GlobalNamespace +import prog8.ast.base.ParentSentinel import prog8.ast.expressions.NumericLiteralValue import prog8.ast.statements.* import prog8.compiler.target.C64Target import prog8tests.helpers.assertSuccess import prog8tests.helpers.compileText -import kotlin.test.assertEquals -import kotlin.test.assertFalse -import kotlin.test.assertTrue +import kotlin.test.* @TestInstance(TestInstance.Lifecycle.PER_CLASS) class TestScoping { + @Test + fun testModulesParentIsGlobalNamespace() { + val src = """ + main { + sub start() { + } + } + """ + + val result = compileText(C64Target, false, src, writeAssembly = false).assertSuccess() + val module = result.program.toplevelModule + assertIs(module.parent) + assertSame(result.program, module.program) + assertIs(module.parent.parent) + } + @Test fun testAnonScopeVarsMovedIntoSubroutineScope() { val src = """ @@ -29,7 +45,7 @@ class TestScoping { """ val result = compileText(C64Target, false, src, writeAssembly = false).assertSuccess() - val module = result.programAst.toplevelModule + val module = result.program.toplevelModule val mainBlock = module.statements.single() as Block val start = mainBlock.statements.single() as Subroutine val repeatbody = start.statements.filterIsInstance().single().body @@ -100,7 +116,7 @@ class TestScoping { """ val result = compileText(C64Target, false, src, writeAssembly = true).assertSuccess() - val module = result.programAst.toplevelModule + val module = result.program.toplevelModule val mainBlock = module.statements.single() as Block val start = mainBlock.statements.single() as Subroutine val labels = start.statements.filterIsInstance