diff --git a/codeCore/src/prog8/code/ast/AstStatements.kt b/codeCore/src/prog8/code/ast/AstStatements.kt index 16044aff3..39b4fd6d9 100644 --- a/codeCore/src/prog8/code/ast/AstStatements.kt +++ b/codeCore/src/prog8/code/ast/AstStatements.kt @@ -130,17 +130,7 @@ class PtRepeatLoop(position: Position) : PtNode(position) { } -class PtReturn(position: Position) : PtNode(position) { - val hasValue: Boolean - get() = children.any() - val value: PtExpression? - get() { - return if(children.any()) - children.single() as PtExpression - else - null - } -} +class PtReturn(position: Position) : PtNode(position) // children are all expressions sealed interface IPtVariable { diff --git a/codeGenCpu6502/src/prog8/codegen/cpu6502/AsmGen.kt b/codeGenCpu6502/src/prog8/codegen/cpu6502/AsmGen.kt index 421efb6a5..be9507af5 100644 --- a/codeGenCpu6502/src/prog8/codegen/cpu6502/AsmGen.kt +++ b/codeGenCpu6502/src/prog8/codegen/cpu6502/AsmGen.kt @@ -1051,11 +1051,12 @@ $repeatLabel""") } private fun translate(ret: PtReturn) { - ret.value?.let { returnvalue -> + val returnvalue = ret.children.singleOrNull() + if(returnvalue!=null) { val sub = ret.definingSub()!! val returnReg = sub.returnRegister()!! if (sub.returntype?.isNumericOrBool==true) { - assignExpressionToRegister(returnvalue, returnReg.registerOrPair!!) + assignExpressionToRegister(returnvalue as PtExpression, returnReg.registerOrPair!!) } else { // all else take its address and assign that also to AY register pair @@ -1065,6 +1066,9 @@ $repeatLabel""") assignmentAsmGen.assignExpressionToRegister(addrofValue, returnReg.registerOrPair!!, false) } } + else if(ret.children.size>1) { + TODO("multi-value return") + } out(" rts") } diff --git a/codeGenIntermediate/src/prog8/codegen/intermediate/IRCodeGen.kt b/codeGenIntermediate/src/prog8/codegen/intermediate/IRCodeGen.kt index 577433329..142d52c5e 100644 --- a/codeGenIntermediate/src/prog8/codegen/intermediate/IRCodeGen.kt +++ b/codeGenIntermediate/src/prog8/codegen/intermediate/IRCodeGen.kt @@ -1758,10 +1758,14 @@ class IRCodeGen( private fun translate(ret: PtReturn): IRCodeChunks { val result = mutableListOf() - val value = ret.value + if(ret.children.size>1) { + TODO("multi-value return") + } + val value = ret.children.singleOrNull() if(value==null) { addInstr(result, IRInstruction(Opcode.RETURN), null) } else { + value as PtExpression if(value.type.isFloat) { if(value is PtNumber) { addInstr(result, IRInstruction(Opcode.RETURNI, IRDataType.FLOAT, immediateFp = value.number), null) diff --git a/codeOptimizers/src/prog8/optimizer/Inliner.kt b/codeOptimizers/src/prog8/optimizer/Inliner.kt index ab4041f0e..a069bc5c4 100644 --- a/codeOptimizers/src/prog8/optimizer/Inliner.kt +++ b/codeOptimizers/src/prog8/optimizer/Inliner.kt @@ -13,7 +13,7 @@ import prog8.code.core.InternalCompilerException import prog8.code.target.VMTarget -private fun isEmptyReturn(stmt: Statement): Boolean = stmt is Return && stmt.value==null +private fun isEmptyReturn(stmt: Statement): Boolean = stmt is Return && stmt.values.size==0 // inliner potentially enables *ONE LINED* subroutines, wihtout to be inlined. @@ -38,26 +38,21 @@ class Inliner(private val program: Program, private val options: CompilationOpti // subroutine is possible candidate to be inlined subroutine.inline = when (val stmt = subroutine.statements[0]) { - is Return -> { - if (stmt.value is NumericLiteral) + // TODO consider multi-value returns as well + is Return -> stmt.values.isEmpty() || stmt.values.size==1 && + if (stmt.values[0] is NumericLiteral) true - else if (stmt.value == null) + else if (stmt.values[0] is IdentifierReference) { + makeFullyScoped(stmt.values[0] as IdentifierReference) true - else if (stmt.value is IdentifierReference) { - makeFullyScoped(stmt.value as IdentifierReference) - true - } else if (stmt.value!! is IFunctionCall && (stmt.value as IFunctionCall).args.size <= 1 && (stmt.value as IFunctionCall).args.all { it is NumericLiteral || it is IdentifierReference }) { - when (stmt.value) { - is FunctionCallExpression -> { - makeFullyScoped(stmt.value as FunctionCallExpression) - true - } - - else -> false + } else if (stmt.values[0]!! is IFunctionCall && (stmt.values[0] as IFunctionCall).args.size <= 1 && (stmt.values[0] as IFunctionCall).args.all { it is NumericLiteral || it is IdentifierReference }) { + if (stmt.values[0] is FunctionCallExpression) { + makeFullyScoped(stmt.values[0] as FunctionCallExpression) + true } + else false } else false - } is Assignment -> { if (stmt.value.isSimple) { @@ -182,14 +177,19 @@ class Inliner(private val program: Program, private val options: CompilationOpti // note that we don't have to process any args, because we online inline parameterless subroutines. when (val toInline = sub.statements.first()) { is Return -> { - val fcall = toInline.value as? FunctionCallExpression - if(fcall!=null) { - // insert the function call expression as a void function call directly - sub.hasBeenInlined=true - val call = FunctionCallStatement(fcall.target.copy(), fcall.args.map { it.copy() }.toMutableList(), true, fcall.position) - listOf(IAstModification.ReplaceNode(origNode, call, parent)) - } else + // TODO consider multi-value returns as well + if(toInline.values.size!=1) noModifications + else { + val fcall = toInline.values[0] as? FunctionCallExpression + if(fcall!=null) { + // insert the function call expression as a void function call directly + sub.hasBeenInlined=true + val call = FunctionCallStatement(fcall.target.copy(), fcall.args.map { it.copy() }.toMutableList(), true, fcall.position) + listOf(IAstModification.ReplaceNode(origNode, call, parent)) + } else + noModifications + } } else -> { if(origNode !== toInline) { @@ -226,9 +226,10 @@ class Inliner(private val program: Program, private val options: CompilationOpti is Return -> { // is an expression, so we have to have a Return here in the inlined sub // note that we don't have to process any args, because we online inline parameterless subroutines. - if(toInline.value!=null && functionCallExpr!==toInline.value) { + // TODO consider multi-value returns as well + if(toInline.values.size==1 && functionCallExpr!==toInline.values[0]) { sub.hasBeenInlined=true - listOf(IAstModification.ReplaceNode(functionCallExpr, toInline.value!!.copy(), parent)) + listOf(IAstModification.ReplaceNode(functionCallExpr, toInline.values[0].copy(), parent)) } else noModifications diff --git a/compiler/src/prog8/compiler/astprocessing/AstChecker.kt b/compiler/src/prog8/compiler/astprocessing/AstChecker.kt index bd93945c8..ef5d25d61 100644 --- a/compiler/src/prog8/compiler/astprocessing/AstChecker.kt +++ b/compiler/src/prog8/compiler/astprocessing/AstChecker.kt @@ -122,25 +122,25 @@ internal class AstChecker(private val program: Program, throw FatalAstException("cannot use a return with one value in a subroutine that has multiple return values: $returnStmt") } - if(expectedReturnValues.isEmpty() && returnStmt.value!=null) { - errors.err("invalid number of return values", returnStmt.position) + if(returnStmt.values.sizeexpectedReturnValues.size) { + errors.err("too many return values for the subroutine: expected ${expectedReturnValues.size} got ${returnStmt.values.size}", returnStmt.position) } - if(expectedReturnValues.size==1 && returnStmt.value!=null) { - val valueDt = returnStmt.value!!.inferType(program) + for((expectedDt, actual) in expectedReturnValues.zip(returnStmt.values)) { + val valueDt = actual.inferType(program) if(valueDt.isKnown) { - if (expectedReturnValues[0] != valueDt.getOrUndef()) { - if(valueDt.isBool && expectedReturnValues[0].isUnsignedByte) { + if (expectedDt != valueDt.getOrUndef()) { + if(valueDt.isBool && expectedDt.isUnsignedByte) { // if the return value is a bool and the return type is ubyte, allow this. But give a warning. - errors.info("return type of the subroutine should probably be bool instead of ubyte", returnStmt.position) - } else if(valueDt.isIterable && expectedReturnValues[0].isUnsignedWord) { + errors.info("return type of the subroutine should probably be bool instead of ubyte", actual.position) + } else if(valueDt.isIterable && expectedDt.isUnsignedWord) { // you can return a string or array when an uword (pointer) is returned - } else if(valueDt issimpletype BaseDataType.UWORD && expectedReturnValues[0].isString) { + } else if(valueDt issimpletype BaseDataType.UWORD && expectedDt.isString) { // you can return an uword pointer when the return type is a string } else { - errors.err("type $valueDt of return value doesn't match subroutine's return type ${expectedReturnValues[0]}",returnStmt.value!!.position) + errors.err("type $valueDt of return value doesn't match subroutine's return type ${expectedDt}", actual.position) } } } diff --git a/compiler/src/prog8/compiler/astprocessing/BeforeAsmAstChanger.kt b/compiler/src/prog8/compiler/astprocessing/BeforeAsmAstChanger.kt index 8369edc71..14174a656 100644 --- a/compiler/src/prog8/compiler/astprocessing/BeforeAsmAstChanger.kt +++ b/compiler/src/prog8/compiler/astprocessing/BeforeAsmAstChanger.kt @@ -51,14 +51,14 @@ internal class BeforeAsmAstChanger(val program: Program, private val options: Co // and if an assembly block doesn't contain a rts/rti. if (!subroutine.isAsmSubroutine) { if(subroutine.isEmpty()) { - val returnStmt = Return(null, subroutine.position) + val returnStmt = Return(arrayOf(), subroutine.position) mods += IAstModification.InsertLast(returnStmt, subroutine) } else { val last = subroutine.statements.last() if((last !is InlineAssembly || !last.hasReturnOrRts()) && last !is Return) { val lastStatement = subroutine.statements.reversed().firstOrNull { it !is Subroutine } if(lastStatement !is Return) { - val returnStmt = Return(null, subroutine.position) + val returnStmt = Return(arrayOf(), subroutine.position) mods += IAstModification.InsertLast(returnStmt, subroutine) } } @@ -76,7 +76,7 @@ internal class BeforeAsmAstChanger(val program: Program, private val options: Co && prevStmt !is Subroutine && prevStmt !is Return ) { - val returnStmt = Return(null, subroutine.position) + val returnStmt = Return(arrayOf(), subroutine.position) mods += IAstModification.InsertAfter(outerStatements[subroutineStmtIdx - 1], returnStmt, outerScope) } } diff --git a/compiler/src/prog8/compiler/astprocessing/IntermediateAstMaker.kt b/compiler/src/prog8/compiler/astprocessing/IntermediateAstMaker.kt index 1d2ba6b61..fb7d1b8b0 100644 --- a/compiler/src/prog8/compiler/astprocessing/IntermediateAstMaker.kt +++ b/compiler/src/prog8/compiler/astprocessing/IntermediateAstMaker.kt @@ -484,8 +484,7 @@ class IntermediateAstMaker(private val program: Program, private val errors: IEr private fun transform(srcNode: Return): PtReturn { val ret = PtReturn(srcNode.position) - if(srcNode.value!=null) - ret.add(transformExpression(srcNode.value!!)) + srcNode.values.forEach { ret.add(transformExpression(it)) } return ret } diff --git a/compiler/src/prog8/compiler/astprocessing/IntermediateAstPostprocess.kt b/compiler/src/prog8/compiler/astprocessing/IntermediateAstPostprocess.kt index cd1e47ab6..f15eb3d29 100644 --- a/compiler/src/prog8/compiler/astprocessing/IntermediateAstPostprocess.kt +++ b/compiler/src/prog8/compiler/astprocessing/IntermediateAstPostprocess.kt @@ -143,14 +143,17 @@ private fun integrateDefers(subdefers: Map>, program: PtPro // return exits for(ret in returnsToAugment) { - val value = ret.value - if(value==null || notComplex(value)) { + if(ret.children.isEmpty() || ret.children.all { notComplex(it as PtExpression) }) { invokedeferbefore(ret) continue } // complex return value, need to store it before calling the defer block - val (pushCall, popCall) = makePushPopFunctionCalls(value) + if(ret.children.size>1) { + TODO("multi-value return ; defer") + } + + val (pushCall, popCall) = makePushPopFunctionCalls(ret.children[0] as PtExpression) val newRet = PtReturn(ret.position) newRet.add(popCall) val group = PtNodeGroup() diff --git a/compiler/src/prog8/compiler/astprocessing/TypecastsAdder.kt b/compiler/src/prog8/compiler/astprocessing/TypecastsAdder.kt index 3a3b49086..216042332 100644 --- a/compiler/src/prog8/compiler/astprocessing/TypecastsAdder.kt +++ b/compiler/src/prog8/compiler/astprocessing/TypecastsAdder.kt @@ -327,31 +327,44 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val override fun after(returnStmt: Return, parent: Node): Iterable { // add a typecast to the return type if it doesn't match the subroutine's signature // but only if no data loss occurs - val returnValue = returnStmt.value + if (returnStmt.values.isEmpty()) + return noModifications + val subroutine = returnStmt.definingSubroutine!! + if (subroutine.returntypes.size != returnStmt.values.size) + return noModifications + + for((index, pair) in returnStmt.values.zip(subroutine.returntypes).withIndex()) { + val (returnValue, subReturnType) = pair + println("$index $returnValue -> $subReturnType") + } + + // 1 or more return values to check. + val returnValue = returnStmt.values.singleOrNull() if(returnValue!=null) { - val subroutine = returnStmt.definingSubroutine!! - if(subroutine.returntypes.size==1) { - val subReturnType = subroutine.returntypes.first() - val returnDt = returnValue.inferType(program) - if(!(returnDt istype subReturnType) && returnValue is NumericLiteral) { - // see if we might change the returnvalue into the expected type - val castedValue = returnValue.convertTypeKeepValue(subReturnType.base) - if(castedValue.isValid) { - return listOf(IAstModification.ReplaceNode(returnValue, castedValue.valueOrZero(), returnStmt)) - } - } - if (returnDt istype subReturnType or returnDt.isNotAssignableTo(subReturnType)) - return noModifications - if (returnValue is NumericLiteral) { - val cast = returnValue.cast(subReturnType.base, true) - if(cast.isValid) - returnStmt.value = cast.valueOrZero() - } else { - val modifications = mutableListOf() - addTypecastOrCastedValueModification(modifications, returnValue, subReturnType.base, returnStmt) - return modifications + val subReturnType = subroutine.returntypes.single() + val returnDt = returnValue.inferType(program) + if(!(returnDt istype subReturnType) && returnValue is NumericLiteral) { + // see if we might change the returnvalue into the expected type + val castedValue = returnValue.convertTypeKeepValue(subReturnType.base) + if(castedValue.isValid) { + return listOf(IAstModification.ReplaceNode(returnValue, castedValue.valueOrZero(), returnStmt)) } } + if (returnDt istype subReturnType or returnDt.isNotAssignableTo(subReturnType)) + return noModifications + if (returnValue is NumericLiteral) { + val cast = returnValue.cast(subReturnType.base, true) + if(cast.isValid) { + returnStmt.values[0] = cast.valueOrZero() + } + } else { + val modifications = mutableListOf() + addTypecastOrCastedValueModification(modifications, returnValue, subReturnType.base, returnStmt) + return modifications + } + } + else if(returnStmt.values.size>1) { + TODO("multi-value return ; typecast") } return noModifications } diff --git a/compiler/test/ast/TestVariousCompilerAst.kt b/compiler/test/ast/TestVariousCompilerAst.kt index 0b27dbe0d..a529adc32 100644 --- a/compiler/test/ast/TestVariousCompilerAst.kt +++ b/compiler/test/ast/TestVariousCompilerAst.kt @@ -961,7 +961,7 @@ main { val ifscope_return = ifscope.children[2] as PtReturn ifscope_defer.name shouldBe "p8b_main.p8s_test.p8s_prog8_invoke_defers" ifscope_push.name shouldBe "sys.pushw" - (ifscope_return.value as PtFunctionCall).name shouldBe "sys.popw" + (ifscope_return.children.single() as PtFunctionCall).name shouldBe "sys.popw" val ending = sub.children[6] as PtFunctionCall ending.name shouldBe "p8b_main.p8s_test.p8s_prog8_invoke_defers" diff --git a/compilerAst/src/prog8/ast/AstToSourceTextConverter.kt b/compilerAst/src/prog8/ast/AstToSourceTextConverter.kt index e397768fc..a033d98c7 100644 --- a/compilerAst/src/prog8/ast/AstToSourceTextConverter.kt +++ b/compilerAst/src/prog8/ast/AstToSourceTextConverter.kt @@ -399,8 +399,10 @@ class AstToSourceTextConverter(val output: (text: String) -> Unit, val program: } override fun visit(returnStmt: Return) { - output("return ") - returnStmt.value?.accept(this) + if(returnStmt.values.isEmpty()) + output("return") + else + output("return ${returnStmt.values.map{it.accept(this)}.joinToString(", ")}") } override fun visit(arrayIndexedExpression: ArrayIndexedExpression) { diff --git a/compilerAst/src/prog8/ast/antlr/Antlr2Kotlin.kt b/compilerAst/src/prog8/ast/antlr/Antlr2Kotlin.kt index eec024385..ce066c3bc 100644 --- a/compilerAst/src/prog8/ast/antlr/Antlr2Kotlin.kt +++ b/compilerAst/src/prog8/ast/antlr/Antlr2Kotlin.kt @@ -298,7 +298,8 @@ private fun InlineirContext.toAst(): InlineAssembly { } private fun ReturnstmtContext.toAst() : Return { - return Return(expression()?.toAst(), toPosition()) + val values = if(returnvalues()==null || returnvalues().expression().size==0) arrayOf() else returnvalues().expression().map { it.toAst() }.toTypedArray() + return Return(values, toPosition()) } private fun UnconditionaljumpContext.toAst(): Jump { @@ -313,11 +314,11 @@ private fun AliasContext.toAst(): Statement = private fun SubroutineContext.toAst() : Subroutine { // non-asm subroutine - val returntype = sub_return_part()?.datatype()?.toAst() + val returntypes = sub_return_part()?.datatype()?.map { it.toAst() } ?: emptyList() return Subroutine( identifier().text, sub_params()?.toAst()?.toMutableList() ?: mutableListOf(), - if (returntype == null) mutableListOf() else mutableListOf(DataType.forDt(returntype)), + returntypes.map { DataType.forDt(it) }.toMutableList(), emptyList(), emptyList(), emptySet(), diff --git a/compilerAst/src/prog8/ast/statements/AstStatements.kt b/compilerAst/src/prog8/ast/statements/AstStatements.kt index d1d4d2236..e4947c213 100644 --- a/compilerAst/src/prog8/ast/statements/AstStatements.kt +++ b/compilerAst/src/prog8/ast/statements/AstStatements.kt @@ -4,6 +4,7 @@ import prog8.ast.* import prog8.ast.expressions.* import prog8.ast.walk.AstWalker import prog8.ast.walk.IAstVisitor +import prog8.code.ast.PtExpression import prog8.code.core.* import java.util.* @@ -174,25 +175,26 @@ data class Label(override val name: String, override val position: Position) : S override fun toString()= "Label(name=$name, pos=$position)" } -class Return(var value: Expression?, override val position: Position) : Statement() { +class Return(val values: Array, override val position: Position) : Statement() { override lateinit var parent: Node override fun linkParents(parent: Node) { this.parent = parent - value?.linkParents(this) + values.forEach { it.linkParents(this) } } override fun replaceChildNode(node: Node, replacement: Node) { - require(replacement is Expression) - value = replacement - replacement.parent = this + val index = values.indexOf(node) + if(replacement is Expression && index>=0) { + values[index] = replacement + } else throw FatalAstException("invalid replace") } - override fun copy() = Return(value?.copy(), position) - override fun referencesIdentifier(nameInSource: List): Boolean = value?.referencesIdentifier(nameInSource)==true + override fun copy() = Return(values.map { it.copy() }.toTypedArray(), position) + override fun referencesIdentifier(nameInSource: List): Boolean = values.any{ it.referencesIdentifier(nameInSource) } override fun accept(visitor: IAstVisitor) = visitor.visit(this) override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) - override fun toString() = "Return($value, pos=$position)" + override fun toString() = "Return($values, pos=$position)" } class Break(override val position: Position) : Statement() { diff --git a/compilerAst/src/prog8/ast/walk/AstWalker.kt b/compilerAst/src/prog8/ast/walk/AstWalker.kt index 39a2cae59..cc18603cd 100644 --- a/compilerAst/src/prog8/ast/walk/AstWalker.kt +++ b/compilerAst/src/prog8/ast/walk/AstWalker.kt @@ -412,7 +412,7 @@ abstract class AstWalker { fun visit(returnStmt: Return, parent: Node) { track(before(returnStmt, parent), returnStmt, parent) - returnStmt.value?.accept(this, returnStmt) + returnStmt.values.forEach { it -> it.accept(this, returnStmt) } track(after(returnStmt, parent), returnStmt, parent) } diff --git a/compilerAst/src/prog8/ast/walk/IAstVisitor.kt b/compilerAst/src/prog8/ast/walk/IAstVisitor.kt index c1f0609d7..49fad19df 100644 --- a/compilerAst/src/prog8/ast/walk/IAstVisitor.kt +++ b/compilerAst/src/prog8/ast/walk/IAstVisitor.kt @@ -142,7 +142,7 @@ interface IAstVisitor { } fun visit(returnStmt: Return) { - returnStmt.value?.accept(this) + returnStmt.values.forEach { it->it.accept(this) } } fun visit(arrayIndexedExpression: ArrayIndexedExpression) { diff --git a/docs/source/todo.rst b/docs/source/todo.rst index 0525709bc..94c870816 100644 --- a/docs/source/todo.rst +++ b/docs/source/todo.rst @@ -1,6 +1,8 @@ TODO ==== +- implement the TODO multi-value return occurences. + - add paypal donation button as well? - announce prog8 on the 6502.org site? diff --git a/examples/test.p8 b/examples/test.p8 index 1deafb30d..4a8d31467 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -1,16 +1,17 @@ -%import textio %option no_sysinit %zeropage basicsafe main { sub start() { - cx16.r0L, cx16.r1, cx16.r2 = multiasm() - cx16.r0L = multiasm() + cx16.r0,cx16.r1 = single() + cx16.r0 = multi() } - asmsub multiasm() -> ubyte @A, uword @R1 { - %asm {{ - rts - }} + sub single() -> uword { + return 42+cx16.r0L + } + + sub multi() -> uword, uword { + return 42+cx16.r0L, 99 } } diff --git a/parser/src/main/antlr/Prog8ANTLR.g4 b/parser/src/main/antlr/Prog8ANTLR.g4 index 3d7442fc7..12dd8cb0a 100644 --- a/parser/src/main/antlr/Prog8ANTLR.g4 +++ b/parser/src/main/antlr/Prog8ANTLR.g4 @@ -222,7 +222,9 @@ expression_list : expression (',' EOL? expression)* // you can split the expression list over several lines ; -returnstmt : 'return' expression? ; +returnstmt : 'return' returnvalues? ; + +returnvalues: expression (',' expression)* ; breakstmt : 'break'; @@ -264,7 +266,7 @@ subroutine : 'sub' identifier '(' sub_params? ')' sub_return_part? EOL? (statement_block EOL?) ; -sub_return_part : '->' datatype ; +sub_return_part : '->' datatype (',' datatype)* ; statement_block : '{' EOL?