diff --git a/compiler/src/prog8/compiler/BeforeAsmGenerationAstChanger.kt b/compiler/src/prog8/compiler/BeforeAsmGenerationAstChanger.kt index f0816cfc5..71675399a 100644 --- a/compiler/src/prog8/compiler/BeforeAsmGenerationAstChanger.kt +++ b/compiler/src/prog8/compiler/BeforeAsmGenerationAstChanger.kt @@ -74,6 +74,27 @@ internal class BeforeAsmGenerationAstChanger(val program: Program, val errors: I override fun before(subroutine: Subroutine, parent: Node): Iterable { subroutineVariables.clear() addedIfConditionVars.clear() + + if(!subroutine.isAsmSubroutine) { + // change 'str' parameters into 'uword' (just treat it as an address) + // TODO fix [TypecastsAdder] to treat str param vars as uword instead of adding casts/addressof (which is wrong for str *parameters*) + val stringParams = subroutine.parameters.filter { it.type==DataType.STR } + val parameterChanges = stringParams.map { + val uwordParam = SubroutineParameter(it.name, DataType.UWORD, it.position) + IAstModification.ReplaceNode(it, uwordParam, subroutine) + } + + val stringParamNames = stringParams.map { it.name }.toSet() + val varsChanges = subroutine.statements + .filterIsInstance() + .filter { it.autogeneratedDontRemove && it.name in stringParamNames } + .map { + val newvar = VarDecl(it.type, DataType.UWORD, it.zeropage, null, it.name, null, false, true, it.sharedWithAsm, it.position) + IAstModification.ReplaceNode(it, newvar, subroutine) + } + + return parameterChanges + varsChanges + } return noModifications } diff --git a/compiler/src/prog8/compiler/Compiler.kt b/compiler/src/prog8/compiler/Compiler.kt index 95eee168a..229bf4b6e 100644 --- a/compiler/src/prog8/compiler/Compiler.kt +++ b/compiler/src/prog8/compiler/Compiler.kt @@ -346,7 +346,7 @@ private fun writeAssembly(programAst: Program, programAst.processAstBeforeAsmGeneration(errors, compilerOptions.compTarget) errors.report() - // printAst(programAst) + printAst(programAst) // TODO compilerOptions.compTarget.machine.initializeZeropage(compilerOptions) val assembly = asmGeneratorFor(compilerOptions.compTarget, diff --git a/compiler/src/prog8/compiler/astprocessing/AstChecker.kt b/compiler/src/prog8/compiler/astprocessing/AstChecker.kt index 898a3609e..6dd78fcf3 100644 --- a/compiler/src/prog8/compiler/astprocessing/AstChecker.kt +++ b/compiler/src/prog8/compiler/astprocessing/AstChecker.kt @@ -382,10 +382,12 @@ internal class AstChecker(private val program: Program, err("can only use Carry as status flag parameter") } else { - // Pass-by-reference datatypes can not occur as parameters to a subroutine directly + // Non-string Pass-by-reference datatypes can not occur as parameters to a subroutine directly // Instead, their reference (address) should be passed (as an UWORD). - if(subroutine.parameters.any{it.type in PassByReferenceDatatypes }) { - err("Pass-by-reference types (str, array) cannot occur as a parameter type directly. Instead, use an uword to receive their address, or access the variable from the outer scope directly.") + for(p in subroutine.parameters) { + if(p.type in PassByReferenceDatatypes && p.type != DataType.STR) { + err("Non-string pass-by-reference types cannot occur as a parameter type directly. Instead, use an uword to receive their address, or access the variable from the outer scope directly.") + } } } } diff --git a/compiler/test/AsmgenTests.kt b/compiler/test/AsmgenTests.kt index 73d3689d3..2d1f20a1c 100644 --- a/compiler/test/AsmgenTests.kt +++ b/compiler/test/AsmgenTests.kt @@ -63,7 +63,7 @@ locallabel: val assign8 = Assignment(tgt, AddressOf(IdentifierReference(listOf("main","label_outside"), Position.DUMMY), Position.DUMMY), Position.DUMMY) val statements = mutableListOf(varInSub, var2InSub, labelInSub, assign1, assign2, assign3, assign4, assign5, assign6, assign7, assign8) - val subroutine = Subroutine("start", emptyList(), emptyList(), emptyList(), emptyList(), emptySet(), null, false, false, statements, Position.DUMMY) + val subroutine = Subroutine("start", mutableListOf(), emptyList(), emptyList(), emptyList(), emptySet(), null, false, false, statements, Position.DUMMY) val labelInBlock = Label("label_outside", Position.DUMMY) val varInBlock = VarDecl(VarDeclType.VAR, DataType.UWORD, ZeropageWish.DONTCARE, null, "var_outside", null, false, false, false, Position.DUMMY) val block = Block("main", null, mutableListOf(labelInBlock, varInBlock, subroutine), false, Position.DUMMY) diff --git a/compiler/test/TestMemory.kt b/compiler/test/TestMemory.kt index 6906e63f8..21dff8654 100644 --- a/compiler/test/TestMemory.kt +++ b/compiler/test/TestMemory.kt @@ -89,7 +89,7 @@ class TestMemory { val memexpr = IdentifierReference(listOf("address"), Position.DUMMY) val target = AssignTarget(null, null, DirectMemoryWrite(memexpr, Position.DUMMY), Position.DUMMY) val assignment = Assignment(target, NumericLiteralValue.optimalInteger(0, Position.DUMMY), Position.DUMMY) - val subroutine = Subroutine("test", emptyList(), emptyList(), emptyList(), emptyList(), emptySet(), null, false, false, mutableListOf(decl, assignment), Position.DUMMY) + val subroutine = Subroutine("test", mutableListOf(), emptyList(), emptyList(), emptyList(), emptySet(), null, false, false, mutableListOf(decl, assignment), Position.DUMMY) val module = Module(mutableListOf(subroutine), Position.DUMMY, SourceCode.Generated("test")) module.linkIntoProgram(program) return target @@ -107,7 +107,7 @@ class TestMemory { val decl = VarDecl(VarDeclType.VAR, DataType.BYTE, ZeropageWish.DONTCARE, null, "address", null, false, false, false, Position.DUMMY) val target = AssignTarget(IdentifierReference(listOf("address"), Position.DUMMY), null, null, Position.DUMMY) val assignment = Assignment(target, NumericLiteralValue.optimalInteger(0, Position.DUMMY), Position.DUMMY) - val subroutine = Subroutine("test", emptyList(), emptyList(), emptyList(), emptyList(), emptySet(), null, false, false, mutableListOf(decl, assignment), Position.DUMMY) + val subroutine = Subroutine("test", mutableListOf(), emptyList(), emptyList(), emptyList(), emptySet(), null, false, false, mutableListOf(decl, assignment), Position.DUMMY) val module = Module(mutableListOf(subroutine), Position.DUMMY, SourceCode.Generated("test")) val program = Program("test", DummyFunctions, DummyMemsizer) .addModule(module) @@ -121,7 +121,7 @@ class TestMemory { val decl = VarDecl(VarDeclType.MEMORY, DataType.UBYTE, ZeropageWish.DONTCARE, null, "address", NumericLiteralValue.optimalInteger(address, Position.DUMMY), false, false, false, Position.DUMMY) val target = AssignTarget(IdentifierReference(listOf("address"), Position.DUMMY), null, null, Position.DUMMY) val assignment = Assignment(target, NumericLiteralValue.optimalInteger(0, Position.DUMMY), Position.DUMMY) - val subroutine = Subroutine("test", emptyList(), emptyList(), emptyList(), emptyList(), emptySet(), null, false, false, mutableListOf(decl, assignment), Position.DUMMY) + val subroutine = Subroutine("test", mutableListOf(), emptyList(), emptyList(), emptyList(), emptySet(), null, false, false, mutableListOf(decl, assignment), Position.DUMMY) val module = Module(mutableListOf(subroutine), Position.DUMMY, SourceCode.Generated("test")) val program = Program("test", DummyFunctions, DummyMemsizer) .addModule(module) @@ -135,7 +135,7 @@ class TestMemory { val decl = VarDecl(VarDeclType.MEMORY, DataType.UBYTE, ZeropageWish.DONTCARE, null, "address", NumericLiteralValue.optimalInteger(address, Position.DUMMY), false, false, false, Position.DUMMY) val target = AssignTarget(IdentifierReference(listOf("address"), Position.DUMMY), null, null, Position.DUMMY) val assignment = Assignment(target, NumericLiteralValue.optimalInteger(0, Position.DUMMY), Position.DUMMY) - val subroutine = Subroutine("test", emptyList(), emptyList(), emptyList(), emptyList(), emptySet(), null, false, false, mutableListOf(decl, assignment), Position.DUMMY) + val subroutine = Subroutine("test", mutableListOf(), emptyList(), emptyList(), emptyList(), emptySet(), null, false, false, mutableListOf(decl, assignment), Position.DUMMY) val module = Module(mutableListOf(subroutine), Position.DUMMY, SourceCode.Generated("test")) val program = Program("test", DummyFunctions, DummyMemsizer) .addModule(module) @@ -149,7 +149,7 @@ class TestMemory { val arrayindexed = ArrayIndexedExpression(IdentifierReference(listOf("address"), Position.DUMMY), ArrayIndex(NumericLiteralValue.optimalInteger(1, Position.DUMMY), Position.DUMMY), Position.DUMMY) val target = AssignTarget(null, arrayindexed, null, Position.DUMMY) val assignment = Assignment(target, NumericLiteralValue.optimalInteger(0, Position.DUMMY), Position.DUMMY) - val subroutine = Subroutine("test", emptyList(), emptyList(), emptyList(), emptyList(), emptySet(), null, false, false, mutableListOf(decl, assignment), Position.DUMMY) + val subroutine = Subroutine("test", mutableListOf(), emptyList(), emptyList(), emptyList(), emptySet(), null, false, false, mutableListOf(decl, assignment), Position.DUMMY) val module = Module(mutableListOf(subroutine), Position.DUMMY, SourceCode.Generated("test")) val program = Program("test", DummyFunctions, DummyMemsizer) .addModule(module) @@ -164,7 +164,7 @@ class TestMemory { val arrayindexed = ArrayIndexedExpression(IdentifierReference(listOf("address"), Position.DUMMY), ArrayIndex(NumericLiteralValue.optimalInteger(1, Position.DUMMY), Position.DUMMY), Position.DUMMY) val target = AssignTarget(null, arrayindexed, null, Position.DUMMY) val assignment = Assignment(target, NumericLiteralValue.optimalInteger(0, Position.DUMMY), Position.DUMMY) - val subroutine = Subroutine("test", emptyList(), emptyList(), emptyList(), emptyList(), emptySet(), null, false, false, mutableListOf(decl, assignment), Position.DUMMY) + val subroutine = Subroutine("test", mutableListOf(), emptyList(), emptyList(), emptyList(), emptySet(), null, false, false, mutableListOf(decl, assignment), Position.DUMMY) val module = Module(mutableListOf(subroutine), Position.DUMMY, SourceCode.Generated("test")) val program = Program("test", DummyFunctions, DummyMemsizer) .addModule(module) @@ -179,7 +179,7 @@ class TestMemory { val arrayindexed = ArrayIndexedExpression(IdentifierReference(listOf("address"), Position.DUMMY), ArrayIndex(NumericLiteralValue.optimalInteger(1, Position.DUMMY), Position.DUMMY), Position.DUMMY) val target = AssignTarget(null, arrayindexed, null, Position.DUMMY) val assignment = Assignment(target, NumericLiteralValue.optimalInteger(0, Position.DUMMY), Position.DUMMY) - val subroutine = Subroutine("test", emptyList(), emptyList(), emptyList(), emptyList(), emptySet(), null, false, false, mutableListOf(decl, assignment), Position.DUMMY) + val subroutine = Subroutine("test", mutableListOf(), emptyList(), emptyList(), emptyList(), emptySet(), null, false, false, mutableListOf(decl, assignment), Position.DUMMY) val module = Module(mutableListOf(subroutine), Position.DUMMY, SourceCode.Generated("test")) val program = Program("test", DummyFunctions, DummyMemsizer) .addModule(module) diff --git a/compiler/test/TestSubroutines.kt b/compiler/test/TestSubroutines.kt index 0030aa460..5b5430a70 100644 --- a/compiler/test/TestSubroutines.kt +++ b/compiler/test/TestSubroutines.kt @@ -4,8 +4,8 @@ import org.junit.jupiter.api.Disabled import org.junit.jupiter.api.Test import org.junit.jupiter.api.TestInstance import prog8.ast.base.DataType -import prog8.ast.statements.Block -import prog8.ast.statements.Subroutine +import prog8.ast.expressions.IdentifierReference +import prog8.ast.statements.* import prog8.compiler.target.C64Target import prog8tests.helpers.ErrorReporterForTests import prog8tests.helpers.assertFailure @@ -17,31 +17,6 @@ import kotlin.test.* @TestInstance(TestInstance.Lifecycle.PER_CLASS) class TestSubroutines { - @Test - fun stringParameterNotYetAllowed_ButShouldPerhapsBe() { - // note: the *parser* accepts this as it is valid *syntax*, - // however, it's not (yet) valid for the compiler - val text = """ - main { - sub start() { - str zzz ; should give uninitialized error - } - - asmsub asmfunc(str thing @AY) { - } - - sub func(str thing) { - } - } - """ - val errors = ErrorReporterForTests() - compileText(C64Target, false, text, errors, false).assertFailure("currently str type in signature is invalid") // TODO should not be invalid - assertEquals(0, errors.warnings.size) - assertEquals(2, errors.errors.size) - assertContains(errors.errors[0], ".p8:4:20: string var must be initialized with a string literal") - assertContains(errors.errors[1], ".p8:10:16: Pass-by-reference types (str, array) cannot occur as a parameter type directly") - } - @Test fun arrayParameterNotYetAllowed_ButShouldPerhapsBe() { // note: the *parser* accepts this as it is valid *syntax*, @@ -62,11 +37,10 @@ class TestSubroutines { val errors = ErrorReporterForTests() compileText(C64Target, false, text, errors, false).assertFailure("currently array dt in signature is invalid") // TODO should not be invalid? assertEquals(0, errors.warnings.size) - assertContains(errors.errors.single(), ".p8:9:16: Pass-by-reference types (str, array) cannot occur as a parameter type directly") + assertContains(errors.errors.single(), ".p8:9:16: Non-string pass-by-reference types cannot occur as a parameter type directly") } @Test - @Disabled("TODO: allow string parameter in signature") // TODO allow this fun stringParameter() { val text = """ main { @@ -76,17 +50,17 @@ class TestSubroutines { asmfunc("text") asmfunc(text) asmfunc($2000) - asmfunc(12.345) func("text") func(text) func($2000) - func(12.345) } asmsub asmfunc(str thing @AY) { } sub func(str thing) { + uword t2 = thing as uword + asmfunc(thing) } } """ @@ -100,7 +74,69 @@ class TestSubroutines { assertTrue(asmfunc.statements.isEmpty()) assertFalse(func.isAsmSubroutine) assertEquals(DataType.STR, func.parameters.single().type) - assertTrue(func.statements.isEmpty()) + assertEquals(3, func.statements.size) + val paramvar = func.statements[0] as VarDecl + assertEquals("thing", paramvar.name) + assertEquals(DataType.STR, paramvar.datatype) + val t2var = func.statements[1] as VarDecl + assertEquals("t2", t2var.name) + assertTrue(t2var.value is IdentifierReference, "str param in function body should be treated as plain uword") + assertEquals("thing", (t2var.value as IdentifierReference).nameInSource.single()) + val call = func.statements[2] as FunctionCallStatement + assertEquals("asmfunc", call.target.nameInSource.single()) + assertTrue(call.args.single() is IdentifierReference, "str param in function body should be treated as plain uword") + assertEquals("thing", (call.args.single() as IdentifierReference).nameInSource.single()) + } + + @Test + fun stringParameterAsmGen() { + val text = """ + main { + sub start() { + str text = "test" + + asmfunc("text") + asmfunc(text) + asmfunc($2000) + func("text") + func(text) + func($2000) + } + + asmsub asmfunc(str thing @AY) { + } + + sub func(str thing) { + uword t2 = thing as uword + asmfunc(thing) + } + } + """ + val result = compileText(C64Target, false, text, writeAssembly = true).assertSuccess() + val module = result.programAst.toplevelModule + val mainBlock = module.statements.single() as Block + val asmfunc = mainBlock.statements.filterIsInstance().single { it.name=="asmfunc"} + val func = mainBlock.statements.filterIsInstance().single { it.name=="func"} + assertTrue(asmfunc.isAsmSubroutine) + assertEquals(DataType.STR, asmfunc.parameters.single().type) + assertTrue(asmfunc.statements.single() is Return) + assertFalse(func.isAsmSubroutine) + assertEquals(DataType.UWORD, func.parameters.single().type, "asmgen should have changed str to uword type") + assertTrue(asmfunc.statements.last() is Return) + + assertEquals(4, func.statements.size) + assertTrue(func.statements[3] is Return) + val paramvar = func.statements[0] as VarDecl + assertEquals("thing", paramvar.name) + assertEquals(DataType.UWORD, paramvar.datatype, "pre-asmgen should have changed str to uword type") + val t2var = func.statements[1] as VarDecl + assertEquals("t2", t2var.name) + assertTrue(t2var.value is IdentifierReference, "str param in function body should be treated as plain uword") + assertEquals("thing", (t2var.value as IdentifierReference).nameInSource.single()) + val call = func.statements[2] as FunctionCallStatement + assertEquals("asmfunc", call.target.nameInSource.single()) + assertTrue(call.args.single() is IdentifierReference, "str param in function body should be treated as plain uword") + assertEquals("thing", (call.args.single() as IdentifierReference).nameInSource.single()) } @Test diff --git a/compilerAst/src/prog8/ast/antlr/Antlr2Kotlin.kt b/compilerAst/src/prog8/ast/antlr/Antlr2Kotlin.kt index a20714f76..940c04840 100644 --- a/compilerAst/src/prog8/ast/antlr/Antlr2Kotlin.kt +++ b/compilerAst/src/prog8/ast/antlr/Antlr2Kotlin.kt @@ -186,7 +186,7 @@ private fun Prog8ANTLRParser.AsmsubroutineContext.toAst(): Subroutine { val inline = this.inline()!=null val subdecl = asmsub_decl().toAst() val statements = statement_block()?.toAst() ?: mutableListOf() - return Subroutine(subdecl.name, subdecl.parameters, subdecl.returntypes, + return Subroutine(subdecl.name, subdecl.parameters.toMutableList(), subdecl.returntypes, subdecl.asmParameterRegisters, subdecl.asmReturnvaluesRegisters, subdecl.asmClobbers, null, true, inline, statements, toPosition()) } @@ -194,7 +194,7 @@ private fun Prog8ANTLRParser.AsmsubroutineContext.toAst(): Subroutine { private fun Prog8ANTLRParser.RomsubroutineContext.toAst(): Subroutine { val subdecl = asmsub_decl().toAst() val address = integerliteral().toAst().number.toInt() - return Subroutine(subdecl.name, subdecl.parameters, subdecl.returntypes, + return Subroutine(subdecl.name, subdecl.parameters.toMutableList(), subdecl.returntypes, subdecl.asmParameterRegisters, subdecl.asmReturnvaluesRegisters, subdecl.asmClobbers, address, true, inline = false, statements = mutableListOf(), position = toPosition() ) @@ -306,7 +306,7 @@ private fun Prog8ANTLRParser.SubroutineContext.toAst() : Subroutine { val inline = inline()!=null val returntypes = sub_return_part()?.toAst() ?: emptyList() return Subroutine(identifier().text, - sub_params()?.toAst() ?: emptyList(), + sub_params()?.toAst()?.toMutableList() ?: mutableListOf(), returntypes, statement_block()?.toAst() ?: mutableListOf(), inline, diff --git a/compilerAst/src/prog8/ast/statements/AstStatements.kt b/compilerAst/src/prog8/ast/statements/AstStatements.kt index 9b41aebfc..95e8abcf3 100644 --- a/compilerAst/src/prog8/ast/statements/AstStatements.kt +++ b/compilerAst/src/prog8/ast/statements/AstStatements.kt @@ -597,7 +597,7 @@ class AsmGenInfo { // and also the predefined/ROM/register-based subroutines. // (multiple return types can only occur for the latter type) class Subroutine(override val name: String, - val parameters: List, + val parameters: MutableList, val returntypes: List, val asmParameterRegisters: List, val asmReturnvaluesRegisters: List, @@ -608,7 +608,7 @@ class Subroutine(override val name: String, override var statements: MutableList, override val position: Position) : Statement(), INameScope, ISymbolStatement { - constructor(name: String, parameters: List, returntypes: List, statements: MutableList, inline: Boolean, position: Position) + constructor(name: String, parameters: MutableList, returntypes: List, statements: MutableList, inline: Boolean, position: Position) : this(name, parameters, returntypes, emptyList(), determineReturnRegisters(returntypes), emptySet(), null, false, inline, statements, position) companion object { @@ -635,10 +635,19 @@ class Subroutine(override val name: String, } override fun replaceChildNode(node: Node, replacement: Node) { - require(replacement is Statement) - val idx = statements.indexOfFirst { it===node } - statements[idx] = replacement - replacement.parent = this + when(replacement) { + is SubroutineParameter -> { + val idx = parameters.indexOf(node) + parameters[idx] = replacement + replacement.parent = this + } + is Statement -> { + val idx = statements.indexOfFirst { it===node } + statements[idx] = replacement + replacement.parent = this + } + else -> throw FatalAstException("can't replace") + } } override fun accept(visitor: IAstVisitor) = visitor.visit(this) diff --git a/examples/test.p8 b/examples/test.p8 index 13a7e3e7f..f17d99e6c 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -1,17 +1,39 @@ +%import string +%import textio +%zeropage basicsafe + main { sub start() { - str zzz + str text = "variable" + + @($2000) = 'a' + @($2001) = 'b' + @($2002) = 'c' + @($2003) = 0 + asmfunc("text") + asmfunc(text) + asmfunc($2000) func("text") + func(text) + func($2000) } asmsub asmfunc(str thing @AY) { %asm {{ - rts + sta func.thing + sty func.thing+1 + jmp func }} } - - sub func(str thing) { + ; TODO fix asmgen when using 'str' type + sub func(uword thing) { + uword t2 = thing as uword + ubyte length = string.length(thing) + txt.print_ub(length) + txt.nl() + txt.print(thing) + txt.nl() } }