diff --git a/codeGeneration/src/prog8/codegen/target/cpu6502/codegen/AsmGen.kt b/codeGeneration/src/prog8/codegen/target/cpu6502/codegen/AsmGen.kt index d82e3a42a..75197850b 100644 --- a/codeGeneration/src/prog8/codegen/target/cpu6502/codegen/AsmGen.kt +++ b/codeGeneration/src/prog8/codegen/target/cpu6502/codegen/AsmGen.kt @@ -616,6 +616,16 @@ class AsmGen(private val program: Program, fun asmSymbolName(name: Iterable) = fixNameSymbols(name.joinToString(".")) fun asmVariableName(name: Iterable) = fixNameSymbols(name.joinToString(".")) + fun getTempVarName(dt: DataType): List { + return when(dt) { + DataType.UBYTE -> listOf("cx16", "r9L") + DataType.BYTE -> listOf("cx16", "r9sL") + DataType.UWORD -> listOf("cx16", "r9") + DataType.WORD -> listOf("cx16", "r9s") + DataType.FLOAT -> listOf("floats", "tempvar_swap_float") // defined in floats.p8 + else -> throw FatalAstException("invalid dt $dt") + } + } internal fun loadByteFromPointerIntoA(pointervar: IdentifierReference): String { // returns the source name of the zero page pointervar if it's already in the ZP, @@ -840,6 +850,7 @@ class AsmGen(private val program: Program, is RepeatLoop -> translate(stmt) is When -> translate(stmt) is AnonymousScope -> translate(stmt) + is Pipe -> translate(stmt) is BuiltinFunctionPlaceholder -> throw AssemblyError("builtin function should not have placeholder anymore") is UntilLoop -> throw AssemblyError("do..until should have been converted to jumps") is WhileLoop -> throw AssemblyError("while should have been converted to jumps") @@ -1614,6 +1625,53 @@ $label nop""") assemblyLines.add(assembly) } + private fun translate(pipe: Pipe) { + + // TODO more efficient code generation to avoid needless assignments to the temp var + + var valueDt = pipe.valueDatatype(program) + var valueVar = getTempVarName(valueDt) + val subroutine = pipe.definingSubroutine + assignExpressionToVariable(pipe.expressions.first(), valueVar.joinToString("."), valueDt, subroutine) + pipe.expressions.drop(1).dropLast(1).forEach { + val callName = it as IdentifierReference + val args = mutableListOf(IdentifierReference(valueVar, it.position)) + val call = FunctionCallExpr(callName, args,it.position) + call.linkParents(pipe) + valueDt = call.inferType(program).getOrElse { throw AssemblyError("invalid dt") } + valueVar = getTempVarName(valueDt) + assignExpressionToVariable(call, valueVar.joinToString("."), valueDt, subroutine) + } + // the last term in the pipe: + val callName = pipe.expressions.last() as IdentifierReference + val callTarget = callName.targetStatement(program)!! + when (callTarget) { + is BuiltinFunctionPlaceholder -> { + val args = mutableListOf(IdentifierReference(valueVar, callName.position)) + val call = FunctionCallStatement(callName, args, true, callName.position) + call.linkParents(pipe) + translate(call) + } + is Subroutine -> { + if(callTarget.isAsmSubroutine) { + val args = mutableListOf(IdentifierReference(valueVar, callName.position)) + val call = FunctionCallStatement(callName, args, true, callName.position) + call.linkParents(pipe) + translate(call) + } else { + // have to use GoSub and manual parameter assignment, because no codegen for FunctionCallStmt here + val param = callTarget.parameters.single() + val paramName = callTarget.scopedName.joinToString(".") + ".${param.name}" + val tempvar = IdentifierReference(valueVar, callName.position) + tempvar.linkParents(pipe) + assignExpressionToVariable(tempvar, paramName, param.type, subroutine) + out(" jsr ${asmSymbolName(callName)}") + } + } + else -> throw AssemblyError("invalid call target") + } + } + internal fun signExtendAYlsb(valueDt: DataType) { // sign extend signed byte in A to full word in AY when(valueDt) { diff --git a/codeOptimizers/src/prog8/optimizer/ExpressionSimplifier.kt b/codeOptimizers/src/prog8/optimizer/ExpressionSimplifier.kt index 915f8d8e4..401830445 100644 --- a/codeOptimizers/src/prog8/optimizer/ExpressionSimplifier.kt +++ b/codeOptimizers/src/prog8/optimizer/ExpressionSimplifier.kt @@ -8,12 +8,10 @@ import prog8.ast.base.FatalAstException import prog8.ast.base.IntegerDatatypes import prog8.ast.base.NumericDatatypes import prog8.ast.expressions.* -import prog8.ast.statements.AnonymousScope -import prog8.ast.statements.Assignment -import prog8.ast.statements.IfElse -import prog8.ast.statements.Jump +import prog8.ast.statements.* import prog8.ast.walk.AstWalker import prog8.ast.walk.IAstModification +import prog8.compilerinterface.IErrorReporter import kotlin.math.abs import kotlin.math.log2 import kotlin.math.pow @@ -26,7 +24,7 @@ import kotlin.math.pow */ -class ExpressionSimplifier(private val program: Program) : AstWalker() { +class ExpressionSimplifier(private val program: Program, private val errors: IErrorReporter) : AstWalker() { private val powersOfTwo = (1..16).map { (2.0).pow(it) }.toSet() private val negativePowersOfTwo = powersOfTwo.map { -it }.toSet() @@ -363,6 +361,24 @@ class ExpressionSimplifier(private val program: Program) : AstWalker() { return noModifications } + override fun after(pipe: Pipe, parent: Node): Iterable { + val firstValue = pipe.expressions.first() + if(firstValue.isSimple) { + val funcname = pipe.expressions[1] as IdentifierReference + val first = FunctionCallExpr(funcname.copy(), mutableListOf(firstValue), firstValue.position) + val newExprs = mutableListOf(first) + newExprs.addAll(pipe.expressions.drop(2)) + return listOf(IAstModification.ReplaceNode(pipe, Pipe(newExprs, pipe.position), parent)) + } + val singleExpr = pipe.expressions.singleOrNull() + if(singleExpr!=null) { + val callExpr = singleExpr as FunctionCallExpr + val call = FunctionCallStatement(callExpr.target, callExpr.args, true, callExpr.position) + return listOf(IAstModification.ReplaceNode(pipe, call, parent)) + } + return noModifications + } + private fun determineY(x: Expression, subBinExpr: BinaryExpression): Expression? { return when { subBinExpr.left isSameAs x -> subBinExpr.right diff --git a/codeOptimizers/src/prog8/optimizer/Extensions.kt b/codeOptimizers/src/prog8/optimizer/Extensions.kt index 08a2b46e2..f01d53598 100644 --- a/codeOptimizers/src/prog8/optimizer/Extensions.kt +++ b/codeOptimizers/src/prog8/optimizer/Extensions.kt @@ -57,8 +57,8 @@ fun Program.optimizeStatements(errors: IErrorReporter, return optimizationCount } -fun Program.simplifyExpressions() : Int { - val opti = ExpressionSimplifier(this) +fun Program.simplifyExpressions(errors: IErrorReporter) : Int { + val opti = ExpressionSimplifier(this, errors) opti.visit(this) return opti.applyModifications() } diff --git a/compiler/res/version.txt b/compiler/res/version.txt index 38abeb202..61228f793 100644 --- a/compiler/res/version.txt +++ b/compiler/res/version.txt @@ -1 +1 @@ -7.6 +7.7-dev diff --git a/compiler/src/prog8/compiler/Compiler.kt b/compiler/src/prog8/compiler/Compiler.kt index 74ac00ebb..b8ffdfe51 100644 --- a/compiler/src/prog8/compiler/Compiler.kt +++ b/compiler/src/prog8/compiler/Compiler.kt @@ -282,7 +282,7 @@ private fun processAst(program: Program, errors: IErrorReporter, compilerOptions errors.report() program.addTypecasts(errors, compilerOptions) errors.report() - program.variousCleanups(program, errors) + program.variousCleanups(program, errors, compilerOptions) errors.report() program.checkValid(errors, compilerOptions) errors.report() @@ -300,7 +300,7 @@ private fun optimizeAst(program: Program, compilerOptions: CompilationOptions, e while (true) { // keep optimizing expressions and statements until no more steps remain - val optsDone1 = program.simplifyExpressions() + val optsDone1 = program.simplifyExpressions(errors) val optsDone2 = program.splitBinaryExpressions(compilerOptions, compTarget) val optsDone3 = program.optimizeStatements(errors, functions, compTarget) program.constantFold(errors, compTarget) // because simplified statements and expressions can result in more constants that can be folded away @@ -315,7 +315,7 @@ private fun postprocessAst(program: Program, errors: IErrorReporter, compilerOpt program.desugaring(errors) program.addTypecasts(errors, compilerOptions) errors.report() - program.variousCleanups(program, errors) + program.variousCleanups(program, errors, compilerOptions) val callGraph = CallGraph(program) callGraph.checkRecursiveCalls(errors) errors.report() diff --git a/compiler/src/prog8/compiler/astprocessing/AstChecker.kt b/compiler/src/prog8/compiler/astprocessing/AstChecker.kt index 910e1b2d9..12be72981 100644 --- a/compiler/src/prog8/compiler/astprocessing/AstChecker.kt +++ b/compiler/src/prog8/compiler/astprocessing/AstChecker.kt @@ -1,12 +1,10 @@ package prog8.compiler.astprocessing -import prog8.ast.INameScope -import prog8.ast.IStatementContainer -import prog8.ast.Module -import prog8.ast.Program +import prog8.ast.* import prog8.ast.base.* import prog8.ast.expressions.* import prog8.ast.statements.* +import prog8.ast.walk.IAstModification import prog8.ast.walk.IAstVisitor import prog8.compilerinterface.* import java.io.CharConversionException @@ -1232,6 +1230,64 @@ internal class AstChecker(private val program: Program, super.visit(containment) } + override fun visit(pipe: Pipe) { + // first expression is just any expression producing a value + // all other expressions should be the name of a unary function that returns a single value + // the last expression should be the name of a unary function whose return value we don't care about. + if (pipe.expressions.size < 2) { + errors.err("pipe is missing one or more expressions", pipe.position) + } else { + // invalid size and other issues will be handled by the ast checker later. + var valueDt = pipe.expressions[0].inferType(program).getOrElse { throw FatalAstException("invalid dt") } + + for(expr in pipe.expressions.drop(1)) { // just keep the first expression value as-is + val functionName = expr as? IdentifierReference + val function = functionName?.targetStatement(program) + if(functionName!=null && function!=null) { + when (function) { + is BuiltinFunctionPlaceholder -> { + val func = BuiltinFunctions.getValue(function.name) + if(func.parameters.size!=1) + errors.err("can only use unary function", expr.position) + else if(func.known_returntype==null && expr !== pipe.expressions.last()) + errors.err("function must return a single value", expr.position) + + val paramDts = func.parameters.firstOrNull()?.possibleDatatypes + if(paramDts!=null && !paramDts.any { valueDt isAssignableTo it }) + errors.err("pipe value datatype $valueDt incompatible withfunction argument ${paramDts.toList()}", functionName.position) + + valueDt = func.known_returntype!! + } + is Subroutine -> { + if(function.parameters.size!=1) + errors.err("can only use unary function", expr.position) + else if(function.returntypes.size!=1 && expr !== pipe.expressions.last()) + errors.err("function must return a single value", expr.position) + + val paramDt = function.parameters.firstOrNull()?.type + if(paramDt!=null && !(valueDt isAssignableTo paramDt)) + errors.err("pipe value datatype $valueDt incompatible with function argument $paramDt", functionName.position) + + if(function.returntypes.isNotEmpty()) + valueDt = function.returntypes.single() + } + else -> { + throw FatalAstException("weird function") + } + } + } else { + if(expr is IFunctionCall) + errors.err("use only the name of the function, not a call", expr.position) + else + errors.err("can only use unary function", expr.position) + } + } + } + + return super.visit(pipe) + } + + private fun checkFunctionOrLabelExists(target: IdentifierReference, statement: Statement): Statement? { when (val targetStatement = target.targetStatement(program)) { is Label, is Subroutine, is BuiltinFunctionPlaceholder -> return targetStatement diff --git a/compiler/src/prog8/compiler/astprocessing/AstExtensions.kt b/compiler/src/prog8/compiler/astprocessing/AstExtensions.kt index 3e3755164..8e2eab4cf 100644 --- a/compiler/src/prog8/compiler/astprocessing/AstExtensions.kt +++ b/compiler/src/prog8/compiler/astprocessing/AstExtensions.kt @@ -96,8 +96,8 @@ internal fun Program.checkIdentifiers(errors: IErrorReporter, program: Program, } } -internal fun Program.variousCleanups(program: Program, errors: IErrorReporter) { - val process = VariousCleanups(program, errors) +internal fun Program.variousCleanups(program: Program, errors: IErrorReporter, options: CompilationOptions) { + val process = VariousCleanups(program, errors, options) process.visit(this) if(errors.noErrors()) process.applyModifications() diff --git a/compiler/src/prog8/compiler/astprocessing/StatementReorderer.kt b/compiler/src/prog8/compiler/astprocessing/StatementReorderer.kt index 36a7ef0ee..b001f2b1b 100644 --- a/compiler/src/prog8/compiler/astprocessing/StatementReorderer.kt +++ b/compiler/src/prog8/compiler/astprocessing/StatementReorderer.kt @@ -382,118 +382,130 @@ internal class StatementReorderer(val program: Program, override fun after(functionCallStatement: FunctionCallStatement, parent: Node): Iterable { val function = functionCallStatement.target.targetStatement(program)!! checkUnusedReturnValues(functionCallStatement, function, program, errors) - if(function is Subroutine) { - if(function.inline) - return noModifications - return if(function.isAsmSubroutine) - replaceCallAsmSubStatementWithGosub(function, functionCallStatement, parent) - else - replaceCallSubStatementWithGosub(function, functionCallStatement, parent) - } - return noModifications + return replaceCallByGosub(functionCallStatement, parent, program, options) + } +} + + +internal fun replaceCallByGosub(functionCallStatement: FunctionCallStatement, + parent: Node, + program: Program, + options: CompilationOptions): Iterable { + val function = functionCallStatement.target.targetStatement(program)!! + if(function is Subroutine) { + if(function.inline) + return emptyList() + return if(function.isAsmSubroutine) + replaceCallAsmSubStatementWithGosub(function, functionCallStatement, parent, options) + else + replaceCallSubStatementWithGosub(function, functionCallStatement, parent, program) + } + return emptyList() +} + +private fun replaceCallSubStatementWithGosub(function: Subroutine, call: FunctionCallStatement, parent: Node, program: Program): Iterable { + val noModifications = emptyList() + + if(function.parameters.isEmpty()) { + // 0 params -> just GoSub + return listOf(IAstModification.ReplaceNode(call, GoSub(null, call.target, null, call.position), parent)) } - private fun replaceCallSubStatementWithGosub(function: Subroutine, call: FunctionCallStatement, parent: Node): Iterable { - if(function.parameters.isEmpty()) { - // 0 params -> just GoSub - return listOf(IAstModification.ReplaceNode(call, GoSub(null, call.target, null, call.position), parent)) + if(function.parameters.size==1) { + if(function.parameters[0].type in IntegerDatatypes) { + // optimization: 1 integer param is passed via register(s) directly, not by assignment to param variable + return noModifications } + } + else if(function.parameters.size==2) { + if(function.parameters[0].type in ByteDatatypes && function.parameters[1].type in ByteDatatypes) { + // optimization: 2 simple byte param is passed via 2 registers directly, not by assignment to param variables + return noModifications + } + } - if(function.parameters.size==1) { - if(function.parameters[0].type in IntegerDatatypes) { - // optimization: 1 integer param is passed via register(s) directly, not by assignment to param variable - return noModifications - } - } - else if(function.parameters.size==2) { - if(function.parameters[0].type in ByteDatatypes && function.parameters[1].type in ByteDatatypes) { - // optimization: 2 simple byte param is passed via 2 registers directly, not by assignment to param variables - return noModifications + val assignParams = + function.parameters.zip(call.args).map { + var argumentValue = it.second + val paramIdentifier = IdentifierReference(function.scopedName + it.first.name, argumentValue.position) + val argDt = argumentValue.inferType(program).getOrElse { throw FatalAstException("invalid dt") } + if(argDt in ArrayDatatypes) { + // pass the address of the array instead + if(argumentValue is IdentifierReference) + argumentValue = AddressOf(argumentValue, argumentValue.position) } + Assignment(AssignTarget(paramIdentifier, null, null, argumentValue.position), argumentValue, argumentValue.position) } + val scope = AnonymousScope(assignParams.toMutableList(), call.position) + scope.statements += GoSub(null, call.target, null, call.position) + return listOf(IAstModification.ReplaceNode(call, scope, parent)) +} - val assignParams = - function.parameters.zip(call.args).map { - var argumentValue = it.second - val paramIdentifier = IdentifierReference(function.scopedName + it.first.name, argumentValue.position) - val argDt = argumentValue.inferType(program).getOrElse { throw FatalAstException("invalid dt") } - if(argDt in ArrayDatatypes) { - // pass the address of the array instead - if(argumentValue is IdentifierReference) - argumentValue = AddressOf(argumentValue, argumentValue.position) - } - Assignment(AssignTarget(paramIdentifier, null, null, argumentValue.position), argumentValue, argumentValue.position) - } - val scope = AnonymousScope(assignParams.toMutableList(), call.position) +private fun replaceCallAsmSubStatementWithGosub(function: Subroutine, call: FunctionCallStatement, parent: Node, options: CompilationOptions): Iterable { + val noModifications = emptyList() + + if(function.parameters.isEmpty()) { + // 0 params -> just GoSub + val scope = AnonymousScope(mutableListOf(), call.position) + if(function.shouldSaveX()) { + scope.statements += FunctionCallStatement(IdentifierReference(listOf("rsavex"), call.position), mutableListOf(), true, call.position) + } scope.statements += GoSub(null, call.target, null, call.position) + if(function.shouldSaveX()) { + scope.statements += FunctionCallStatement(IdentifierReference(listOf("rrestorex"), call.position), mutableListOf(), true, call.position) + } + return listOf(IAstModification.ReplaceNode(call, scope, parent)) + } else if(!options.compTarget.asmsubArgsHaveRegisterClobberRisk(call.args, function.asmParameterRegisters)) { + // No register clobber risk, let the asmgen assign values to the registers directly. + // this is more efficient than first evaluating them to the stack. + // As complex expressions will be flagged as a clobber-risk, these will be simplified below. + return noModifications + } else { + // clobber risk; evaluate the arguments on the CPU stack first (in reverse order)... + val argOrder = options.compTarget.asmsubArgsEvalOrder(function) + val scope = AnonymousScope(mutableListOf(), call.position) + if(function.shouldSaveX()) { + scope.statements += FunctionCallStatement(IdentifierReference(listOf("rsavex"), call.position), mutableListOf(), true, call.position) + } + argOrder.reversed().forEach { + val arg = call.args[it] + val param = function.parameters[it] + scope.statements += pushCall(arg, param.type, arg.position) + } + // ... and pop them off again into the registers. + argOrder.forEach { + val param = function.parameters[it] + val targetName = function.scopedName + param.name + scope.statements += popCall(targetName, param.type, call.position) + } + scope.statements += GoSub(null, call.target, null, call.position) + if(function.shouldSaveX()) { + scope.statements += FunctionCallStatement(IdentifierReference(listOf("rrestorex"), call.position), mutableListOf(), true, call.position) + } return listOf(IAstModification.ReplaceNode(call, scope, parent)) } - - private fun replaceCallAsmSubStatementWithGosub(function: Subroutine, call: FunctionCallStatement, parent: Node): Iterable { - if(function.parameters.isEmpty()) { - // 0 params -> just GoSub - val scope = AnonymousScope(mutableListOf(), call.position) - if(function.shouldSaveX()) { - scope.statements += FunctionCallStatement(IdentifierReference(listOf("rsavex"), call.position), mutableListOf(), true, call.position) - } - scope.statements += GoSub(null, call.target, null, call.position) - if(function.shouldSaveX()) { - scope.statements += FunctionCallStatement(IdentifierReference(listOf("rrestorex"), call.position), mutableListOf(), true, call.position) - } - return listOf(IAstModification.ReplaceNode(call, scope, parent)) - } else if(!options.compTarget.asmsubArgsHaveRegisterClobberRisk(call.args, function.asmParameterRegisters)) { - // No register clobber risk, let the asmgen assign values to the registers directly. - // this is more efficient than first evaluating them to the stack. - // As complex expressions will be flagged as a clobber-risk, these will be simplified below. - return noModifications - } else { - // clobber risk; evaluate the arguments on the CPU stack first (in reverse order)... - val argOrder = options.compTarget.asmsubArgsEvalOrder(function) - val scope = AnonymousScope(mutableListOf(), call.position) - if(function.shouldSaveX()) { - scope.statements += FunctionCallStatement(IdentifierReference(listOf("rsavex"), call.position), mutableListOf(), true, call.position) - } - argOrder.reversed().forEach { - val arg = call.args[it] - val param = function.parameters[it] - scope.statements += pushCall(arg, param.type, arg.position) - } - // ... and pop them off again into the registers. - argOrder.forEach { - val param = function.parameters[it] - val targetName = function.scopedName + param.name - scope.statements += popCall(targetName, param.type, call.position) - } - scope.statements += GoSub(null, call.target, null, call.position) - if(function.shouldSaveX()) { - scope.statements += FunctionCallStatement(IdentifierReference(listOf("rrestorex"), call.position), mutableListOf(), true, call.position) - } - return listOf(IAstModification.ReplaceNode(call, scope, parent)) - } - } - - private fun popCall(targetName: List, dt: DataType, position: Position): FunctionCallStatement { - return FunctionCallStatement( - IdentifierReference(listOf(if(dt in ByteDatatypes) "pop" else "popw"), position), - mutableListOf(IdentifierReference(targetName, position)), - true, position - ) - } - - private fun pushCall(value: Expression, dt: DataType, position: Position): FunctionCallStatement { - val pushvalue = when(dt) { - DataType.UBYTE, DataType.UWORD -> value - in PassByReferenceDatatypes -> value - DataType.BYTE -> TypecastExpression(value, DataType.UBYTE, true, position) - DataType.WORD -> TypecastExpression(value, DataType.UWORD, true, position) - else -> throw FatalAstException("invalid dt $dt $value") - } - - return FunctionCallStatement( - IdentifierReference(listOf(if(dt in ByteDatatypes) "push" else "pushw"), position), - mutableListOf(pushvalue), - true, position - ) - } - +} + +private fun popCall(targetName: List, dt: DataType, position: Position): FunctionCallStatement { + return FunctionCallStatement( + IdentifierReference(listOf(if(dt in ByteDatatypes) "pop" else "popw"), position), + mutableListOf(IdentifierReference(targetName, position)), + true, position + ) +} + +private fun pushCall(value: Expression, dt: DataType, position: Position): FunctionCallStatement { + val pushvalue = when(dt) { + DataType.UBYTE, DataType.UWORD -> value + in PassByReferenceDatatypes -> value + DataType.BYTE -> TypecastExpression(value, DataType.UBYTE, true, position) + DataType.WORD -> TypecastExpression(value, DataType.UWORD, true, position) + else -> throw FatalAstException("invalid dt $dt $value") + } + + return FunctionCallStatement( + IdentifierReference(listOf(if(dt in ByteDatatypes) "push" else "pushw"), position), + mutableListOf(pushvalue), + true, position + ) } diff --git a/compiler/src/prog8/compiler/astprocessing/VariousCleanups.kt b/compiler/src/prog8/compiler/astprocessing/VariousCleanups.kt index 1b44a3abf..7feb57ee5 100644 --- a/compiler/src/prog8/compiler/astprocessing/VariousCleanups.kt +++ b/compiler/src/prog8/compiler/astprocessing/VariousCleanups.kt @@ -10,10 +10,11 @@ import prog8.ast.expressions.* import prog8.ast.statements.* import prog8.ast.walk.AstWalker import prog8.ast.walk.IAstModification +import prog8.compilerinterface.CompilationOptions import prog8.compilerinterface.IErrorReporter -internal class VariousCleanups(val program: Program, val errors: IErrorReporter): AstWalker() { +internal class VariousCleanups(val program: Program, val errors: IErrorReporter, val options: CompilationOptions): AstWalker() { override fun before(nop: Nop, parent: Node): Iterable { return listOf(IAstModification.Remove(nop, parent as IStatementContainer)) @@ -198,5 +199,9 @@ internal class VariousCleanups(val program: Program, val errors: IErrorReporter) } return noModifications } + + override fun after(functionCallStatement: FunctionCallStatement, parent: Node): Iterable { + return replaceCallByGosub(functionCallStatement, parent, program, options) + } } diff --git a/compiler/test/TestPipes.kt b/compiler/test/TestPipes.kt new file mode 100644 index 000000000..dfa2a8aac --- /dev/null +++ b/compiler/test/TestPipes.kt @@ -0,0 +1,82 @@ +package prog8tests + +import io.kotest.core.spec.style.FunSpec +import io.kotest.matchers.shouldBe +import io.kotest.matchers.string.shouldContain +import io.kotest.matchers.types.instanceOf +import prog8.ast.expressions.FunctionCallExpr +import prog8.ast.expressions.IdentifierReference +import prog8.ast.statements.Pipe +import prog8.codegen.target.C64Target +import prog8tests.helpers.ErrorReporterForTests +import prog8tests.helpers.assertFailure +import prog8tests.helpers.assertSuccess +import prog8tests.helpers.compileText + + +class TestPipes: FunSpec({ + + test("correct pipes") { + val text = """ + %import floats + %import textio + + main { + sub start() { + + 1.234 |> addfloat + |> floats.print_f + + 9999 |> addword + |> txt.print_uw + + } + + sub addfloat(float fl) -> float { + return fl+2.22 + } + sub addword(uword ww) -> uword { + return ww+2222 + } + } + """ + val result = compileText(C64Target, true, text, writeAssembly = true).assertSuccess() + val stmts = result.program.entrypoint.statements + stmts.size shouldBe 3 + val pipef = stmts[0] as Pipe + pipef.expressions.size shouldBe 2 + pipef.expressions[0] shouldBe instanceOf() + pipef.expressions[1] shouldBe instanceOf() + + val pipew = stmts[1] as Pipe + pipew.expressions.size shouldBe 2 + pipew.expressions[0] shouldBe instanceOf() + pipew.expressions[1] shouldBe instanceOf() + } + + test("incorrect type in pipe") { + val text = """ + %option enable_floats + + main { + sub start() { + + 1.234 |> addfloat + |> addword |> addword + } + + sub addfloat(float fl) -> float { + return fl+2.22 + } + sub addword(uword ww) -> uword { + return ww+2222 + } + } + """ + val errors = ErrorReporterForTests() + compileText(C64Target, false, text, errors=errors).assertFailure() + errors.errors.size shouldBe 1 + errors.errors[0] shouldContain "incompatible" + } + +}) diff --git a/compilerAst/src/prog8/ast/AstToSourceTextConverter.kt b/compilerAst/src/prog8/ast/AstToSourceTextConverter.kt index f308556ff..f1274512b 100644 --- a/compilerAst/src/prog8/ast/AstToSourceTextConverter.kt +++ b/compilerAst/src/prog8/ast/AstToSourceTextConverter.kt @@ -457,4 +457,16 @@ class AstToSourceTextConverter(val output: (text: String) -> Unit, val program: whenChoice.statements.accept(this) outputln("") } + + override fun visit(pipe: Pipe) { + pipe.expressions.first().accept(this) + outputln("") + scopelevel++ + pipe.expressions.drop(1).forEach { + outputi("|> ") + it.accept(this) + outputln("") + } + scopelevel-- + } } diff --git a/compilerAst/src/prog8/ast/antlr/Antlr2Kotlin.kt b/compilerAst/src/prog8/ast/antlr/Antlr2Kotlin.kt index 142684ce3..0195febd9 100644 --- a/compilerAst/src/prog8/ast/antlr/Antlr2Kotlin.kt +++ b/compilerAst/src/prog8/ast/antlr/Antlr2Kotlin.kt @@ -182,6 +182,9 @@ private fun Prog8ANTLRParser.StatementContext.toAst() : Statement { val whenstmt = whenstmt()?.toAst() if(whenstmt!=null) return whenstmt + val pipestmt = pipestmt()?.toAst() + if(pipestmt!=null) return pipestmt + throw FatalAstException("unprocessed source text (are we missing ast conversion rules for parser elements?): $text") } @@ -612,3 +615,8 @@ private fun Prog8ANTLRParser.VardeclContext.toAst(): VarDecl { toPosition() ) } + +private fun Prog8ANTLRParser.PipestmtContext.toAst(): Pipe { + val expressions = expression().map { it.toAst() } + return Pipe(expressions.toMutableList(), toPosition()) +} diff --git a/compilerAst/src/prog8/ast/statements/AstStatements.kt b/compilerAst/src/prog8/ast/statements/AstStatements.kt index c6b0bae3d..9f2db3c6c 100644 --- a/compilerAst/src/prog8/ast/statements/AstStatements.kt +++ b/compilerAst/src/prog8/ast/statements/AstStatements.kt @@ -1014,3 +1014,27 @@ class DirectMemoryWrite(var addressExpression: Expression, override val position fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) override fun copy() = DirectMemoryWrite(addressExpression.copy(), position) } + + +class Pipe(val expressions: MutableList, override val position: Position): Statement() { + override lateinit var parent: Node + + override fun linkParents(parent: Node) { + this.parent = parent + expressions.forEach { it.linkParents(this) } + } + + override fun copy() = Pipe(expressions.map { it.copy() }.toMutableList(), position) + override fun accept(visitor: IAstVisitor) = visitor.visit(this) + override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) + + override fun replaceChildNode(node: Node, replacement: Node) { + require(node is Expression) + require(replacement is Expression) + val idx = expressions.indexOf(node) + expressions[idx] = replacement + } + + fun valueDatatype(program: Program): DataType = + expressions.first().inferType(program).getOrElse { throw FatalAstException("invalid dt") } +} \ No newline at end of file diff --git a/compilerAst/src/prog8/ast/walk/AstWalker.kt b/compilerAst/src/prog8/ast/walk/AstWalker.kt index fc8c23cb1..987d988d8 100644 --- a/compilerAst/src/prog8/ast/walk/AstWalker.kt +++ b/compilerAst/src/prog8/ast/walk/AstWalker.kt @@ -120,6 +120,7 @@ abstract class AstWalker { open fun before(whenChoice: WhenChoice, parent: Node): Iterable = noModifications open fun before(whenStmt: When, parent: Node): Iterable = noModifications open fun before(whileLoop: WhileLoop, parent: Node): Iterable = noModifications + open fun before(pipe: Pipe, parent: Node): Iterable = noModifications open fun after(addressOf: AddressOf, parent: Node): Iterable = noModifications open fun after(array: ArrayLiteralValue, parent: Node): Iterable = noModifications @@ -163,6 +164,7 @@ abstract class AstWalker { open fun after(whenChoice: WhenChoice, parent: Node): Iterable = noModifications open fun after(whenStmt: When, parent: Node): Iterable = noModifications open fun after(whileLoop: WhileLoop, parent: Node): Iterable = noModifications + open fun after(pipe: Pipe, parent: Node): Iterable = noModifications protected val modifications = mutableListOf>() @@ -457,5 +459,11 @@ abstract class AstWalker { whenChoice.statements.accept(this, whenChoice) track(after(whenChoice, parent), whenChoice, parent) } + + fun visit(pipe: Pipe, parent: Node) { + track(before(pipe, parent), pipe, parent) + pipe.expressions.forEach { it.accept(this, pipe) } + track(after(pipe, parent), pipe, parent) + } } diff --git a/compilerAst/src/prog8/ast/walk/IAstVisitor.kt b/compilerAst/src/prog8/ast/walk/IAstVisitor.kt index 2f847865b..d4db1402a 100644 --- a/compilerAst/src/prog8/ast/walk/IAstVisitor.kt +++ b/compilerAst/src/prog8/ast/walk/IAstVisitor.kt @@ -181,4 +181,8 @@ interface IAstVisitor { whenChoice.values?.forEach { it.accept(this) } whenChoice.statements.accept(this) } + + fun visit(pipe: Pipe) { + pipe.expressions.forEach { it.accept(this) } + } } diff --git a/compilerInterfaces/src/prog8/compilerinterface/CallGraph.kt b/compilerInterfaces/src/prog8/compilerinterface/CallGraph.kt index 05adfe6b4..3ff0e824a 100644 --- a/compilerInterfaces/src/prog8/compilerinterface/CallGraph.kt +++ b/compilerInterfaces/src/prog8/compilerinterface/CallGraph.kt @@ -5,9 +5,7 @@ import prog8.ast.Node import prog8.ast.Program import prog8.ast.base.Position import prog8.ast.base.VarDeclType -import prog8.ast.expressions.AddressOf -import prog8.ast.expressions.FunctionCallExpr -import prog8.ast.expressions.IdentifierReference +import prog8.ast.expressions.* import prog8.ast.statements.* import prog8.ast.walk.IAstVisitor @@ -122,6 +120,21 @@ class CallGraph(private val program: Program) : IAstVisitor { allAssemblyNodes.add(inlineAssembly) } + override fun visit(pipe: Pipe) { + pipe.expressions.forEach { + if(it is IdentifierReference){ + val otherSub = it.targetSubroutine(program) + if(otherSub!=null) { + pipe.definingSubroutine?.let { thisSub -> + calls[thisSub] = calls.getValue(thisSub) + otherSub + calledBy[otherSub] = calledBy.getValue(otherSub) + pipe + } + } + } + } + super.visit(pipe) + } + fun checkRecursiveCalls(errors: IErrorReporter) { val cycles = recursionCycles() if(cycles.any()) { diff --git a/docs/source/index.rst b/docs/source/index.rst index aee9df5d5..9788c6f91 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -58,6 +58,7 @@ Language features - Conditional branches to map directly on processor branch instructions - ``when`` statement to avoid if-else chains - ``in`` expression for concise and efficient multi-value/containment test +- pipe operator ``|>`` to rewrite nested function call expressions in a more readable chained form - Nested subroutines can access variables from outer scopes to avoids the overhead to pass everything via parameters - Variable data types include signed and unsigned bytes and words, arrays, strings. - Floating point math also supported if the target system provides floating point library routines (C64 and Cx16 both do). diff --git a/docs/source/programming.rst b/docs/source/programming.rst index 8add75e38..24f88173e 100644 --- a/docs/source/programming.rst +++ b/docs/source/programming.rst @@ -685,6 +685,9 @@ The arguments in parentheses after the function name, should match the parameter If you want to ignore a return value of a subroutine, you should prefix the call with the ``void`` keyword. Otherwise the compiler will issue a warning about discarding a result value. +Deeply nested function calls can be rewritten as a chain using the *pipe operator* ``|>`` as long as they +are unary functions (taking a single argument). + .. note:: **Order of evaluation:** diff --git a/docs/source/syntaxreference.rst b/docs/source/syntaxreference.rst index f1de15eb3..6d74cc919 100644 --- a/docs/source/syntaxreference.rst +++ b/docs/source/syntaxreference.rst @@ -520,6 +520,25 @@ containment check: ``in`` txt.print("email address seems ok") } +pipe: ``|>`` + Used as an alternative to nesting function calls. The pipe operator is used to 'pipe' the value + into the next function. It only works on unary functions (taking a single argument), because you just + specify the *name* of the function that the value has to be piped into. The resulting value will be piped + into the next function as long as there are chained pipes. + For example, this: ``txt.print_uw(add_bonus(determine_score(get_player(1))))`` + can be rewritten as:: + + get_player(1) + |> determine_score + |> add_bonus + |> txt.print_uw + + or even:: + + 1 |> get_player + |> determine_score + |> add_bonus + |> txt.print_uw address of: ``&`` This is a prefix operator that can be applied to a string or array variable or literal value. diff --git a/docs/source/todo.rst b/docs/source/todo.rst index f8e7d0094..cdeb9f5fb 100644 --- a/docs/source/todo.rst +++ b/docs/source/todo.rst @@ -3,7 +3,10 @@ TODO For next compiler release (7.7) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -... +- optimize codegen of pipe operator to avoid needless assigns to temp var +- copying floats around: do it with a subroutine rather than 5 lda/sta pairs . + is slower but floats are very slow already anyway and this should take a lot less program size. + Need help with ^^^^^^^^^^^^^^ @@ -20,7 +23,6 @@ Blocked by an official Commander-x16 r39 release Future Things and Ideas ^^^^^^^^^^^^^^^^^^^^^^^ -- pipe operator ``|>`` - can we promise a left-to-right function call argument evaluation? without sacrificing performance - make it possible to use cpu opcodes such as 'nop' as variable names by prefixing all asm vars with something such as ``v_`` then we can get rid of the instruction lists in the machinedefinitions as well? diff --git a/examples/test.p8 b/examples/test.p8 index f9021c6b2..70844469c 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -1,33 +1,49 @@ %import textio -%import floats %zeropage basicsafe main { sub start() { - %asm {{ - lda #float5_111 - jsr floats.MOVFM - lda #float5_122 - jsr floats.FADD - jsr floats.FOUT - sta $7e - sty $7f - ldy #0 -_loop - lda ($7e),y - beq _done - jsr c64.CHROUT - iny - bne _loop -_done - rts - -float5_111 .byte $81, $0e, $14, $7a, $e1 ; float 1.11 -float5_122 .byte $81, $1c, $28, $f5, $c2 ; float 1.22 - - }} + 9 + 3 + |> add_one |> times_two + |> txt.print_uw } + sub add_one(ubyte input) -> ubyte { + return input+1 + } + + sub times_two(ubyte input) -> uword { + return input*$0002 + } } + + +;main { +; sub start() { +; %asm {{ +; lda #float5_111 +; jsr floats.MOVFM +; lda #float5_122 +; jsr floats.FADD +; jsr floats.FOUT +; sta $7e +; sty $7f +; ldy #0 +;_loop +; lda ($7e),y +; beq _done +; jsr c64.CHROUT +; iny +; bne _loop +;_done +; rts +; +;float5_111 .byte $81, $0e, $14, $7a, $e1 ; float 1.11 +;float5_122 .byte $81, $1c, $28, $f5, $c2 ; float 1.22 +; +; }} +; } +; +;} diff --git a/parser/antlr/Prog8ANTLR.g4 b/parser/antlr/Prog8ANTLR.g4 index 3d123afdb..d450a5c0b 100644 --- a/parser/antlr/Prog8ANTLR.g4 +++ b/parser/antlr/Prog8ANTLR.g4 @@ -57,6 +57,10 @@ ARRAYSIG : '[]' ; +PIPE : + '|>' + ; + cpuregister: 'A' | 'X' | 'Y'; register: 'A' | 'X' | 'Y' | 'AX' | 'AY' | 'XY' | 'Pc' | 'Pz' | 'Pn' | 'Pv' | 'R0' | 'R1' | 'R2' | 'R3' | 'R4' | 'R5' | 'R6' | 'R7' | 'R8' | 'R9' | 'R10' | 'R11' | 'R12' | 'R13' | 'R14' | 'R15'; @@ -97,6 +101,7 @@ statement : | repeatloop | whenstmt | breakstmt + | pipestmt | labeldef ; @@ -296,3 +301,4 @@ whenstmt: 'when' expression '{' EOL (when_choice | EOL) * '}' EOL? ; when_choice: (expression_list | 'else' ) '->' (statement | statement_block ) ; +pipestmt: expression EOL? PIPE expression ( EOL? PIPE expression)* ; diff --git a/syntax-files/IDEA/Prog8.xml b/syntax-files/IDEA/Prog8.xml index c47bc3ee1..88b263886 100644 --- a/syntax-files/IDEA/Prog8.xml +++ b/syntax-files/IDEA/Prog8.xml @@ -14,7 +14,7 @@ - + diff --git a/syntax-files/NotepadPlusPlus/Prog8.xml b/syntax-files/NotepadPlusPlus/Prog8.xml index c39aa7d40..dabe7d03f 100644 --- a/syntax-files/NotepadPlusPlus/Prog8.xml +++ b/syntax-files/NotepadPlusPlus/Prog8.xml @@ -27,8 +27,8 @@ void const str byte ubyte word uword float zp shared %address %asm %asmbinary %asminclude %breakpoint %import %launcher %option %output %zeropage %zpreserved inline sub asmsub romsub clobbers asm if when else if_cc if_cs if_eq if_mi if_neg if_nz if_pl if_pos if_vc if_vs if_z for in step do while repeat break return goto - abs acos all any asin atan avg callfar callrom ceil cmp cos cos16 cos16u cos8 cos8u cosr8 cosr8u cosr16 cosr16u deg floor len ln log2 lsb lsl lsr max memory min mkword msb peek peekw poke pokew push pushw pop popw rsave rsavex rrestore rrestorex rad reverse rnd rndf rndw rol rol2 ror ror2 round sgn sin sin16 sin16u sin8 sin8u sinr8 sinr8u sinr16 sinr16u sizeof sort sqrt sqrt16 sum swap tan - true false not and or xor as to downto + abs acos all any asin atan avg callfar callrom ceil cmp cos cos16 cos16u cos8 cos8u cosr8 cosr8u cosr16 cosr16u deg floor len ln log2 lsb lsl lsr max memory min mkword msb peek peekw poke pokew push pushw pop popw rsave rsavex rrestore rrestorex rad reverse rnd rndf rndw rol rol2 ror ror2 round sgn sin sin16 sin16u sin8 sin8u sinr8 sinr8u sinr16 sinr16u sizeof sort sqrt sqrt16 sum swap tan + true false not and or xor as to downto |> diff --git a/syntax-files/Vim/prog8.vim b/syntax-files/Vim/prog8.vim index 57afc410a..0def9f109 100644 --- a/syntax-files/Vim/prog8.vim +++ b/syntax-files/Vim/prog8.vim @@ -24,6 +24,7 @@ syn match prog8Function "\(\<\(asm\)\?sub\>\s\+\)\@16<=\<\w\+\>" syn match prog8Function "\(romsub\s\+$\x\+\s\+=\s\+\)\@16<=\<\w\+\>" syn keyword prog8Statement break goto return asmsub sub inline +syn match prog8Statement "|>" syn match prog8Statement "\<\(asm\|rom\)\?sub\>" syn keyword prog8Conditional if else when syn keyword prog8Conditional if_cs if_cc if_vs if_vc if_eq if_z if_ne if_nz