diff --git a/codeGenCpu6502/src/prog8/codegen/cpu6502/AsmGen.kt b/codeGenCpu6502/src/prog8/codegen/cpu6502/AsmGen.kt index 9c60c2a0e..22f0f01fd 100644 --- a/codeGenCpu6502/src/prog8/codegen/cpu6502/AsmGen.kt +++ b/codeGenCpu6502/src/prog8/codegen/cpu6502/AsmGen.kt @@ -2837,44 +2837,39 @@ $repeatLabel lda $counterVar } } - internal fun translatePipeExpression(source: Expression, segments: Iterable, scope: Node, isStatement: Boolean, pushResultOnEstack: Boolean) { + internal fun translatePipeExpression(source: Expression, segments: List, scope: Node, isStatement: Boolean, pushResultOnEstack: Boolean) { // TODO more efficient code generation to avoid needless assignments to the temp var - TODO("translatePipeExpression") + // the source: an expression (could be anything) producing a value. + // one or more segment expressions, all are a IFunctionCall node, and LACKING the implicit first argument. + // when 'isStatement'=true, the last segment expression should be treated as a funcion call statement (discarding any result value if there is one) - -/* - // the first term: an expression (could be anything) producing a value. val subroutine = scope.definingSubroutine!! - val firstTerm = expressions.first() - var valueDt = firstTerm.inferType(program).getOrElse { throw FatalAstException("invalid dt") } + var valueDt = source.inferType(program).getOrElse { throw FatalAstException("invalid dt") } var valueSource: AsmAssignSource = - if(firstTerm is IFunctionCall) { - val resultReg = returnRegisterOfFunction(firstTerm.target, listOf(valueDt)) - assignExpressionToRegister(firstTerm, resultReg, valueDt in listOf(DataType.BYTE, DataType.WORD, DataType.FLOAT)) + if(source is IFunctionCall) { + val resultReg = returnRegisterOfFunction(source.target, listOf(valueDt)) + assignExpressionToRegister(source, resultReg, valueDt in listOf(DataType.BYTE, DataType.WORD, DataType.FLOAT)) AsmAssignSource(SourceStorageKind.REGISTER, program, this, valueDt, register = resultReg) } else { - AsmAssignSource.fromAstSource(firstTerm, program, this) + AsmAssignSource.fromAstSource(source, program, this) } - // the 2nd to N-1 terms: unary function calls taking a single param and producing a value. + // the segments (except the last one): unary function calls taking a single param and producing a value. // directly assign their argument from the previous call's returnvalue. - expressions.drop(1).dropLast(1).forEach { - valueDt = functioncallAsmGen.translateUnaryFunctionCallWithArgSource(it as IdentifierReference, valueSource, false, subroutine) - val resultReg = returnRegisterOfFunction(it, listOf(valueDt)) + segments.dropLast(1).forEach { + it as IFunctionCall + valueDt = translateUnaryFunctionCallWithArgSource(it.target, valueSource, false, subroutine) + val resultReg = returnRegisterOfFunction(it.target, listOf(valueDt)) valueSource = AsmAssignSource(SourceStorageKind.REGISTER, program, this, valueDt, register = resultReg) } - - // the last term: unary function call taking a single param and optionally producing a result value. + // the last segment: unary function call taking a single param and optionally producing a result value. + val lastCall = segments.last() as IFunctionCall if(isStatement) { - // the last term in the pipe, don't care about return var: - functioncallAsmGen.translateUnaryFunctionCallWithArgSource( - expressions.last() as IdentifierReference, valueSource, true, subroutine) + translateUnaryFunctionCallWithArgSource(lastCall.target, valueSource, true, subroutine) } else { - // the last term in the pipe, regular function call with returnvalue: - valueDt = functioncallAsmGen.translateUnaryFunctionCallWithArgSource( - expressions.last() as IdentifierReference, valueSource, false, subroutine) + valueDt = translateUnaryFunctionCallWithArgSource(lastCall.target, valueSource, false, subroutine) if(pushResultOnEstack) { when (valueDt) { in ByteDatatypes -> { @@ -2890,7 +2885,64 @@ $repeatLabel lda $counterVar } } } -*/ + } + + private fun translateUnaryFunctionCallWithArgSource(target: IdentifierReference, arg: AsmAssignSource, isStatement: Boolean, scope: Subroutine): DataType { + when(val targetStmt = target.targetStatement(program)!!) { + is BuiltinFunctionPlaceholder -> { + return if(isStatement) { + translateBuiltinFunctionCallStatement(targetStmt.name, listOf(arg), scope) + DataType.UNDEFINED + } else { + translateBuiltinFunctionCallExpression(targetStmt.name, listOf(arg), scope) + } + } + is Subroutine -> { + val argDt = targetStmt.parameters.single().type + if(targetStmt.isAsmSubroutine) { + // argument via registers + val argRegister = targetStmt.asmParameterRegisters.single().registerOrPair!! + val assignArgument = AsmAssignment( + arg, + AsmAssignTarget.fromRegisters(argRegister, argDt in SignedDatatypes, scope, program, this), + false, program.memsizer, target.position + ) + translateNormalAssignment(assignArgument) + } else { + val assignArgument: AsmAssignment = + if(functioncallAsmGen.optimizeIntArgsViaRegisters(targetStmt)) { + // argument goes via registers as optimization + val paramReg: RegisterOrPair = when(argDt) { + in ByteDatatypes -> RegisterOrPair.A + in WordDatatypes -> RegisterOrPair.AY + DataType.FLOAT -> RegisterOrPair.FAC1 + else -> throw AssemblyError("invalid dt") + } + AsmAssignment( + arg, + AsmAssignTarget(TargetStorageKind.REGISTER, program, this, argDt, scope, register = paramReg), + false, program.memsizer, target.position + ) + } else { + // arg goes via parameter variable + val argVarName = asmVariableName(targetStmt.scopedName + targetStmt.parameters.single().name) + AsmAssignment( + arg, + AsmAssignTarget(TargetStorageKind.VARIABLE, program, this, argDt, scope, argVarName), + false, program.memsizer, target.position + ) + } + translateNormalAssignment(assignArgument) + } + if(targetStmt.shouldSaveX()) + saveRegisterLocal(CpuRegister.X, scope) + out(" jsr ${asmSymbolName(target)}") + if(targetStmt.shouldSaveX()) + restoreRegisterLocal(CpuRegister.X) + return if(isStatement) DataType.UNDEFINED else targetStmt.returntypes.single() + } + else -> throw AssemblyError("invalid call target") + } } internal fun popCpuStack(dt: DataType, target: VarDecl, scope: Subroutine?) { diff --git a/codeGenCpu6502/src/prog8/codegen/cpu6502/FunctionCallAsmGen.kt b/codeGenCpu6502/src/prog8/codegen/cpu6502/FunctionCallAsmGen.kt index 028b2a602..a4628c57d 100644 --- a/codeGenCpu6502/src/prog8/codegen/cpu6502/FunctionCallAsmGen.kt +++ b/codeGenCpu6502/src/prog8/codegen/cpu6502/FunctionCallAsmGen.kt @@ -129,64 +129,6 @@ internal class FunctionCallAsmGen(private val program: Program, private val asmg // remember: dealing with the X register and/or dealing with return values is the responsibility of the caller } - internal fun translateUnaryFunctionCallWithArgSource(target: IdentifierReference, arg: AsmAssignSource, isStatement: Boolean, scope: Subroutine): DataType { - when(val targetStmt = target.targetStatement(program)!!) { - is BuiltinFunctionPlaceholder -> { - return if(isStatement) { - asmgen.translateBuiltinFunctionCallStatement(targetStmt.name, listOf(arg), scope) - DataType.UNDEFINED - } else { - asmgen.translateBuiltinFunctionCallExpression(targetStmt.name, listOf(arg), scope) - } - } - is Subroutine -> { - val argDt = targetStmt.parameters.single().type - if(targetStmt.isAsmSubroutine) { - // argument via registers - val argRegister = targetStmt.asmParameterRegisters.single().registerOrPair!! - val assignArgument = AsmAssignment( - arg, - AsmAssignTarget.fromRegisters(argRegister, argDt in SignedDatatypes, scope, program, asmgen), - false, program.memsizer, target.position - ) - asmgen.translateNormalAssignment(assignArgument) - } else { - val assignArgument: AsmAssignment = - if(optimizeIntArgsViaRegisters(targetStmt)) { - // argument goes via registers as optimization - val paramReg: RegisterOrPair = when(argDt) { - in ByteDatatypes -> RegisterOrPair.A - in WordDatatypes -> RegisterOrPair.AY - DataType.FLOAT -> RegisterOrPair.FAC1 - else -> throw AssemblyError("invalid dt") - } - AsmAssignment( - arg, - AsmAssignTarget(TargetStorageKind.REGISTER, program, asmgen, argDt, scope, register = paramReg), - false, program.memsizer, target.position - ) - } else { - // arg goes via parameter variable - val argVarName = asmgen.asmVariableName(targetStmt.scopedName + targetStmt.parameters.single().name) - AsmAssignment( - arg, - AsmAssignTarget(TargetStorageKind.VARIABLE, program, asmgen, argDt, scope, argVarName), - false, program.memsizer, target.position - ) - } - asmgen.translateNormalAssignment(assignArgument) - } - if(targetStmt.shouldSaveX()) - asmgen.saveRegisterLocal(CpuRegister.X, scope) - asmgen.out(" jsr ${asmgen.asmSymbolName(target)}") - if(targetStmt.shouldSaveX()) - asmgen.restoreRegisterLocal(CpuRegister.X) - return if(isStatement) DataType.UNDEFINED else targetStmt.returntypes.single() - } - else -> throw AssemblyError("invalid call target") - } - } - private fun argumentsViaRegisters(sub: Subroutine, call: IFunctionCall) { if(sub.parameters.size==1) { argumentViaRegister(sub, IndexedValue(0, sub.parameters.single()), call.args[0]) diff --git a/codeGenTargets/src/prog8/codegen/target/Encoder.kt b/codeGenTargets/src/prog8/codegen/target/Encoder.kt index 22314b680..5f2aac98a 100644 --- a/codeGenTargets/src/prog8/codegen/target/Encoder.kt +++ b/codeGenTargets/src/prog8/codegen/target/Encoder.kt @@ -1,12 +1,12 @@ package prog8.codegen.target import com.github.michaelbull.result.fold -import prog8.ast.base.FatalAstException import prog8.codegen.target.cbm.AtasciiEncoding import prog8.codegen.target.cbm.IsoEncoding import prog8.codegen.target.cbm.PetsciiEncoding import prog8.compilerinterface.Encoding import prog8.compilerinterface.IStringEncoding +import prog8.compilerinterface.InternalCompilerException internal object Encoder: IStringEncoding { override fun encodeString(str: String, encoding: Encoding): List { @@ -15,7 +15,7 @@ internal object Encoder: IStringEncoding { Encoding.SCREENCODES -> PetsciiEncoding.encodeScreencode(str, true) Encoding.ISO -> IsoEncoding.encode(str) Encoding.ATASCII -> AtasciiEncoding.encode(str) - else -> throw FatalAstException("unsupported encoding $encoding") + else -> throw InternalCompilerException("unsupported encoding $encoding") } return coded.fold( failure = { throw it }, @@ -28,7 +28,7 @@ internal object Encoder: IStringEncoding { Encoding.SCREENCODES -> PetsciiEncoding.decodeScreencode(bytes, true) Encoding.ISO -> IsoEncoding.decode(bytes) Encoding.ATASCII -> AtasciiEncoding.decode(bytes) - else -> throw FatalAstException("unsupported encoding $encoding") + else -> throw InternalCompilerException("unsupported encoding $encoding") } return decoded.fold( failure = { throw it }, diff --git a/codeOptimizers/src/prog8/optimizer/ConstantIdentifierReplacer.kt b/codeOptimizers/src/prog8/optimizer/ConstantIdentifierReplacer.kt index cec2cc194..6dc325c53 100644 --- a/codeOptimizers/src/prog8/optimizer/ConstantIdentifierReplacer.kt +++ b/codeOptimizers/src/prog8/optimizer/ConstantIdentifierReplacer.kt @@ -9,6 +9,7 @@ import prog8.ast.walk.AstWalker import prog8.ast.walk.IAstModification import prog8.compilerinterface.ICompilationTarget import prog8.compilerinterface.IErrorReporter +import prog8.compilerinterface.InternalCompilerException // Fix up the literal value's type to match that of the vardecl // (also check range literal operands types before they get expanded into arrays for instance) @@ -97,7 +98,7 @@ internal class ConstantIdentifierReplacer(private val program: Program, private identifier.parent ) ) - in PassByReferenceDatatypes -> throw FatalAstException("pass-by-reference type should not be considered a constant") + in PassByReferenceDatatypes -> throw InternalCompilerException("pass-by-reference type should not be considered a constant") else -> noModifications } } catch (x: UndefinedSymbolError) { diff --git a/codeOptimizers/src/prog8/optimizer/ExpressionSimplifier.kt b/codeOptimizers/src/prog8/optimizer/ExpressionSimplifier.kt index 2082f2374..6e3bdb7af 100644 --- a/codeOptimizers/src/prog8/optimizer/ExpressionSimplifier.kt +++ b/codeOptimizers/src/prog8/optimizer/ExpressionSimplifier.kt @@ -1,8 +1,6 @@ package prog8.optimizer -import prog8.ast.IStatementContainer -import prog8.ast.Node -import prog8.ast.Program +import prog8.ast.* import prog8.ast.base.DataType import prog8.ast.base.FatalAstException import prog8.ast.base.IntegerDatatypes @@ -333,44 +331,26 @@ class ExpressionSimplifier(private val program: Program, private val errors: IEr return noModifications } - override fun after(pipeExpr: PipeExpression, parent: Node): Iterable { - require(pipeExpr.segments.isNotEmpty()) - val segments = pipeExpr.segments - if(segments.size==1 && segments[0].isSimple) { - // just replace with a normal function call - val funcname = segments[1].target - val arg = segments[0] - val call = FunctionCallExpression(funcname.copy(), mutableListOf(arg), arg.position) - return listOf(IAstModification.ReplaceNode(pipeExpr, call, parent)) - } - val firstValue = pipeExpr.source - if(firstValue.isSimple) { - val funcname = pipeExpr.segments[0].target - val first = FunctionCallExpression(funcname.copy(), mutableListOf(firstValue), firstValue.position) - val newSegments = mutableListOf(first) - newSegments.addAll(pipeExpr.segments.drop(1)) - return listOf(IAstModification.ReplaceNode(pipeExpr, PipeExpression(first, newSegments, pipeExpr.position), parent)) - } - return noModifications - } + override fun after(pipeExpr: PipeExpression, parent: Node) = processPipe(pipeExpr, parent) + override fun after(pipe: Pipe, parent: Node) = processPipe(pipe, parent) - override fun after(pipe: Pipe, parent: Node): Iterable { - require(pipe.segments.isNotEmpty()) - val segments = pipe.segments - if(segments.size==1 && segments[0].isSimple) { - // just replace with a normal function call - val funcname = segments[1].target - val arg = segments[0] - val call = FunctionCallExpression(funcname.copy(), mutableListOf(arg), arg.position) - return listOf(IAstModification.ReplaceNode(pipe, call, parent)) - } - val firstValue = pipe.source - if(firstValue.isSimple) { - val funcname = pipe.segments[0].target - val first = FunctionCallExpression(funcname.copy(), mutableListOf(firstValue), firstValue.position) - val newSegments = mutableListOf(first) - newSegments.addAll(pipe.segments.drop(1)) - return listOf(IAstModification.ReplaceNode(pipe, Pipe(first, newSegments, pipe.position), parent)) + private fun processPipe(pipe: IPipe, parent: Node): Iterable { + if(pipe.source.isSimple) { + val segments = pipe.segments + if(segments.size==1) { + // replace the whole pipe with a normal function call + val funcname = (segments[0] as IFunctionCall).target + val call = if(pipe is Pipe) + FunctionCallStatement(funcname, mutableListOf(pipe.source), true, pipe.position) + else + FunctionCallExpression(funcname, mutableListOf(pipe.source), pipe.position) + return listOf(IAstModification.ReplaceNode(pipe as Node, call, parent)) + } else if(segments.size>1) { + // replace source+firstsegment by firstsegment(source) call as the new source + val firstSegment = segments.removeAt(0) as IFunctionCall + val call = FunctionCallExpression(firstSegment.target, mutableListOf(pipe.source), pipe.position) + return listOf(IAstModification.ReplaceNode(pipe.source, call, pipe as Node)) + } } return noModifications } diff --git a/compiler/src/prog8/compiler/Compiler.kt b/compiler/src/prog8/compiler/Compiler.kt index f4dbaa630..c53f22d2d 100644 --- a/compiler/src/prog8/compiler/Compiler.kt +++ b/compiler/src/prog8/compiler/Compiler.kt @@ -344,8 +344,8 @@ private fun postprocessAst(program: Program, errors: IErrorReporter, compilerOpt program.variousCleanups(errors, compilerOptions) val callGraph = CallGraph(program) callGraph.checkRecursiveCalls(errors) + program.verifyFunctionArgTypes(errors) errors.report() - program.verifyFunctionArgTypes() program.moveMainAndStartToFirst() program.checkValid(errors, compilerOptions) // check if final tree is still valid errors.report() diff --git a/compiler/src/prog8/compiler/ErrorReporter.kt b/compiler/src/prog8/compiler/ErrorReporter.kt index 2dc864133..5b78b95d5 100644 --- a/compiler/src/prog8/compiler/ErrorReporter.kt +++ b/compiler/src/prog8/compiler/ErrorReporter.kt @@ -29,10 +29,10 @@ internal class ErrorReporter: IErrorReporter { MessageSeverity.ERROR -> System.err } when(it.severity) { - MessageSeverity.ERROR -> printer.print("\u001b[91m") // bright red - MessageSeverity.WARNING -> printer.print("\u001b[93m") // bright yellow + MessageSeverity.ERROR -> printer.print("\u001b[91mERROR\u001B[0m ") // bright red + MessageSeverity.WARNING -> printer.print("\u001b[93mWARN\u001B[0m ") // bright yellow } - val msg = "${it.severity} ${it.position.toClickableStr()} ${it.message}".trim() + val msg = "${it.position.toClickableStr()} ${it.message}".trim() if(msg !in alreadyReportedMessages) { printer.println(msg) alreadyReportedMessages.add(msg) @@ -41,7 +41,6 @@ internal class ErrorReporter: IErrorReporter { MessageSeverity.ERROR -> numErrors++ } } - printer.print("\u001b[0m") // reset color } System.out.flush() System.err.flush() diff --git a/compiler/src/prog8/compiler/astprocessing/AstChecker.kt b/compiler/src/prog8/compiler/astprocessing/AstChecker.kt index fc7f074e7..461b02789 100644 --- a/compiler/src/prog8/compiler/astprocessing/AstChecker.kt +++ b/compiler/src/prog8/compiler/astprocessing/AstChecker.kt @@ -546,7 +546,7 @@ internal class AstChecker(private val program: Program, return } if(decl.value is RangeExpression) - throw FatalAstException("range expressions in vardecls should have been converted into array values during constFolding $decl") + throw InternalCompilerException("range expressions in vardecls should have been converted into array values during constFolding $decl") } when(decl.type) { @@ -555,7 +555,7 @@ internal class AstChecker(private val program: Program, null -> { // a vardecl without an initial value, don't bother with it } - is RangeExpression -> throw FatalAstException("range expression should have been converted to a true array value") + is RangeExpression -> throw InternalCompilerException("range expression should have been converted to a true array value") is StringLiteral -> { checkValueTypeAndRangeString(decl.datatype, decl.value as StringLiteral) } @@ -963,7 +963,7 @@ internal class AstChecker(private val program: Program, val error = VerifyFunctionArgTypes.checkTypes(functionCallExpr, program) if(error!=null) - errors.err(error, functionCallExpr.position) + errors.err(error.first, error.second) // check the functions that return multiple returnvalues. val stmt = functionCallExpr.target.targetStatement(program) @@ -992,7 +992,14 @@ internal class AstChecker(private val program: Program, } } else if(targetStatement is BuiltinFunctionPlaceholder) { - if(builtinFunctionReturnType(targetStatement.name, functionCallExpr.args, program).isUnknown) { + val args = if(functionCallExpr.parent is IPipe) { + // pipe segment, add implicit first argument + val firstArgDt = BuiltinFunctions.getValue(targetStatement.name).parameters.first().possibleDatatypes.first() + listOf(defaultZero(firstArgDt, functionCallExpr.position)) + functionCallExpr.args + } else { + functionCallExpr.args + } + if(builtinFunctionReturnType(targetStatement.name, args, program).isUnknown) { if(functionCallExpr.parent is Expression || functionCallExpr.parent is Assignment) errors.err("function doesn't return a value", functionCallExpr.position) } @@ -1043,9 +1050,8 @@ internal class AstChecker(private val program: Program, } val error = VerifyFunctionArgTypes.checkTypes(functionCallStatement, program) - if(error!=null) { - errors.err(error, functionCallStatement.args.firstOrNull()?.position ?: functionCallStatement.position) - } + if(error!=null) + errors.err(error.first, error.second) super.visit(functionCallStatement) } @@ -1112,6 +1118,19 @@ internal class AstChecker(private val program: Program, } } + override fun visit(pipe: PipeExpression) = process(pipe) + + override fun visit(pipe: Pipe) = process(pipe) + + private fun process(pipe: IPipe) { + if(pipe.source in pipe.segments) + throw InternalCompilerException("pipe source and segments should all be different nodes") + if (pipe.segments.isEmpty()) + throw FatalAstException("pipe is missing one or more expressions") + if(pipe.segments.any { it !is IFunctionCall }) + throw FatalAstException("pipe segments can only be function calls") + } + override fun visit(postIncrDecr: PostIncrDecr) { if(postIncrDecr.target.identifier != null) { val targetName = postIncrDecr.target.identifier!!.nameInSource @@ -1243,94 +1262,6 @@ internal class AstChecker(private val program: Program, super.visit(containment) } - override fun visit(pipe: PipeExpression) { - processPipe(pipe.source, pipe.segments, pipe) - if(errors.noErrors()) { - val last = pipe.segments.last().target - when (val target = last.targetStatement(program)!!) { - is BuiltinFunctionPlaceholder -> { - if (!BuiltinFunctions.getValue(target.name).hasReturn) - errors.err("invalid pipe expression; last term doesn't return a value", last.position) - } - is Subroutine -> { - if (target.returntypes.isEmpty()) - errors.err("invalid pipe expression; last term doesn't return a value", last.position) - else if (target.returntypes.size != 1) - errors.err("invalid pipe expression; last term doesn't return a single value", last.position) - } - else -> errors.err("invalid pipe expression; last term doesn't return a value", last.position) - } - super.visit(pipe) - } - } - - override fun visit(pipe: Pipe) { - processPipe(pipe.source, pipe.segments, pipe) - if(errors.noErrors()) { - super.visit(pipe) - } - } - - private fun processPipe(source: Expression, segments: List, scope: Node) { - // first expression is just any expression producing a value - // all other expressions should be the name of a unary function that returns a single value - // the last expression should be the name of a unary function whose return value we don't care about. - if (segments.isEmpty()) { - errors.err("pipe is missing one or more expressions", scope.position) - } else { - // invalid size and other issues will be handled by the ast checker later. - var valueDt = source.inferType(program).getOrElse { - throw FatalAstException("invalid dt") - } - - for(funccall in segments) { - val target = funccall.target.targetStatement(program) - if(target!=null) { - when (target) { - is BuiltinFunctionPlaceholder -> { - val func = BuiltinFunctions.getValue(target.name) - if(func.parameters.size!=1) - errors.err("can only use unary function", funccall.position) - else if(!func.hasReturn && funccall !== segments.last()) - errors.err("function must return a single value", funccall.position) - - val paramDts = func.parameters.firstOrNull()?.possibleDatatypes - if(paramDts!=null && !paramDts.any { valueDt isAssignableTo it }) - errors.err("pipe value datatype $valueDt incompatible with function argument ${paramDts.toList()}", funccall.position) - - if(errors.noErrors()) { - // type can depend on the argument(s) of the function. For now, we only deal with unary functions, - // so we know there must be a single argument. Take its type from the previous expression in the pipe chain. - val zero = defaultZero(valueDt, funccall.position) - valueDt = builtinFunctionReturnType(func.name, listOf(zero), program).getOrElse { DataType.UNDEFINED } - } - } - is Subroutine -> { - if(target.parameters.size!=1) - errors.err("can only use unary function", funccall.position) - else if(target.returntypes.size!=1 && funccall !== segments.last()) - errors.err("function must return a single value", funccall.position) - - val paramDt = target.parameters.firstOrNull()?.type - if(paramDt!=null && !(valueDt isAssignableTo paramDt)) - errors.err("pipe value datatype $valueDt incompatible with function argument $paramDt", funccall.position) - - if(target.returntypes.isNotEmpty()) - valueDt = target.returntypes.single() - } - is VarDecl -> { - if(!(valueDt isAssignableTo target.datatype)) - errors.err("final pipe value datatype can't be stored in pipe ending variable", funccall.position) - } - else -> { - throw FatalAstException("weird function") - } - } - } - } - } - } - private fun checkFunctionOrLabelExists(target: IdentifierReference, statement: Statement): Statement? { when (val targetStatement = target.targetStatement(program)) { is Label, is Subroutine, is BuiltinFunctionPlaceholder -> return targetStatement diff --git a/compiler/src/prog8/compiler/astprocessing/AstExtensions.kt b/compiler/src/prog8/compiler/astprocessing/AstExtensions.kt index 27f455edc..4784923a1 100644 --- a/compiler/src/prog8/compiler/astprocessing/AstExtensions.kt +++ b/compiler/src/prog8/compiler/astprocessing/AstExtensions.kt @@ -76,8 +76,8 @@ fun Program.desugaring(errors: IErrorReporter): Int { return desugar.applyModifications() } -internal fun Program.verifyFunctionArgTypes() { - val fixer = VerifyFunctionArgTypes(this) +internal fun Program.verifyFunctionArgTypes(errors: IErrorReporter) { + val fixer = VerifyFunctionArgTypes(this, errors) fixer.visit(this) } diff --git a/compiler/src/prog8/compiler/astprocessing/AstIdentifiersChecker.kt b/compiler/src/prog8/compiler/astprocessing/AstIdentifiersChecker.kt index 064ac1e5b..a05e47dc9 100644 --- a/compiler/src/prog8/compiler/astprocessing/AstIdentifiersChecker.kt +++ b/compiler/src/prog8/compiler/astprocessing/AstIdentifiersChecker.kt @@ -1,6 +1,7 @@ package prog8.compiler.astprocessing import prog8.ast.IFunctionCall +import prog8.ast.IPipe import prog8.ast.Node import prog8.ast.Program import prog8.ast.base.Position @@ -135,14 +136,24 @@ internal class AstIdentifiersChecker(private val errors: IErrorReporter, private fun visitFunctionCall(call: IFunctionCall) { when (val target = call.target.targetStatement(program)) { is Subroutine -> { - if(call.args.size != target.parameters.size) { + // if the call is part of a Pipe, the number of arguments in the call should be 1 less than the number of parameters + val expectedNumberOfArgs = if(call.parent is IPipe) + target.parameters.size-1 + else + target.parameters.size + if(call.args.size != expectedNumberOfArgs) { val pos = (if(call.args.any()) call.args[0] else (call as Node)).position errors.err("invalid number of arguments", pos) } } is BuiltinFunctionPlaceholder -> { val func = BuiltinFunctions.getValue(target.name) - if(call.args.size != func.parameters.size) { + // if the call is part of a Pipe, the number of arguments in the call should be 1 less than the number of parameters + val expectedNumberOfArgs = if(call.parent is IPipe) + func.parameters.size-1 + else + func.parameters.size + if(call.args.size != expectedNumberOfArgs) { val pos = (if(call.args.any()) call.args[0] else (call as Node)).position errors.err("invalid number of arguments", pos) } diff --git a/compiler/src/prog8/compiler/astprocessing/AstPreprocessor.kt b/compiler/src/prog8/compiler/astprocessing/AstPreprocessor.kt index be6fb552c..2da76e0a4 100644 --- a/compiler/src/prog8/compiler/astprocessing/AstPreprocessor.kt +++ b/compiler/src/prog8/compiler/astprocessing/AstPreprocessor.kt @@ -1,15 +1,18 @@ package prog8.compiler.astprocessing -import prog8.ast.* +import prog8.ast.IPipe +import prog8.ast.Node +import prog8.ast.Program import prog8.ast.base.* import prog8.ast.expressions.* +import prog8.ast.getTempRegisterName import prog8.ast.statements.* import prog8.ast.walk.AstWalker import prog8.ast.walk.IAstModification -import prog8.compilerinterface.BuiltinFunctions import prog8.compilerinterface.Encoding import prog8.compilerinterface.ICompilationTarget import prog8.compilerinterface.IErrorReporter +import prog8.compilerinterface.InternalCompilerException class AstPreprocessor(val program: Program, val errors: IErrorReporter, val compTarget: ICompilationTarget) : AstWalker() { @@ -112,43 +115,67 @@ class AstPreprocessor(val program: Program, val errors: IErrorReporter, val comp return noModifications } - override fun after(pipe: Pipe, parent: Node): Iterable { + override fun before(pipe: Pipe, parent: Node): Iterable { + if(pipe.source is PipeExpression) { + // correct Antlr parse tree quirk: turn nested pipe into single flat pipe + val psrc = pipe.source as PipeExpression + val newSource = psrc.source + val newSegments = psrc.segments + newSegments += pipe.segments.single() + return listOf(IAstModification.ReplaceNode(pipe as Node, Pipe(newSource, newSegments, pipe.position), parent)) + } + return process(pipe, parent) } - override fun after(pipeExpr: PipeExpression, parent: Node): Iterable { + override fun before(pipeExpr: PipeExpression, parent: Node): Iterable { + if(pipeExpr.source is PipeExpression) { + // correct Antlr parse tree quirk; turn nested pipe into single flat pipe + val psrc = pipeExpr.source as PipeExpression + val newSource = psrc.source + val newSegments = psrc.segments + newSegments += pipeExpr.segments.single() + return listOf(IAstModification.ReplaceNode(pipeExpr as Node, PipeExpression(newSource, newSegments, pipeExpr.position), parent)) + } + return process(pipeExpr, parent) } private fun process(pipe: IPipe, parent: Node): Iterable { + if(pipe.source is IPipe) + throw InternalCompilerException("pipe source should have been adjusted to be a normal expression") + + return noModifications + +// TODO don't use artifical inserted args, fix the places that check for arg numbers instead. // add the "missing" first argument to each function call in the pipe segments // so that all function call related checks just pass // might have to remove it again when entering code generation pass, or just replace it there // with the proper output value of the previous pipe segment. - return pipe.segments.map { - val firstArgDt = when (val target = it.target.targetStatement(program)) { - is Subroutine -> target.parameters.first().type - is BuiltinFunctionPlaceholder -> BuiltinFunctions.getValue(target.name).parameters.first().possibleDatatypes.first() - else -> DataType.UNDEFINED - } - val dummyFirstArg = when (firstArgDt) { - in IntegerDatatypes -> { - IdentifierReference( - getTempRegisterName(InferredTypes.InferredType.known(firstArgDt)), - pipe.position - ) - } - DataType.FLOAT -> { - val (name, _) = program.getTempVar(DataType.FLOAT) - IdentifierReference(name, pipe.position) - } - else -> throw FatalAstException("weird dt") - } - IAstModification.SetExpression( - { newexpr -> it.args.add(0, newexpr) }, - dummyFirstArg, parent - ) - } +// val mutations = mutableListOf() +// var valueDt = pipe.source.inferType(program).getOrElse { throw FatalAstException("invalid dt") } +// pipe.segments.forEach { call-> +// val dummyFirstArg = when (valueDt) { +// DataType.UBYTE -> FunctionCallExpression(IdentifierReference(listOf("rnd"), pipe.position), mutableListOf(), pipe.position) +// DataType.UWORD -> FunctionCallExpression(IdentifierReference(listOf("rndw"), pipe.position), mutableListOf(), pipe.position) +// DataType.BYTE, DataType.WORD -> IdentifierReference( +// getTempRegisterName(InferredTypes.InferredType.known(valueDt)), +// pipe.position +// ) // there's no builtin function we can abuse that returns a signed byte or word type // TODO maybe use a typecasted expression around rnd? +// DataType.FLOAT -> FunctionCallExpression(IdentifierReference(listOf("rndf"), pipe.position), mutableListOf(), pipe.position) +// else -> throw FatalAstException("invalid dt") +// } +// +// mutations += IAstModification.SetExpression( +// { newexpr -> call.args.add(0, newexpr) }, +// dummyFirstArg, parent +// ) +// +// if(call!==pipe.segments.last()) +// valueDt = call.inferType(program).getOrElse { throw FatalAstException("invalid dt") } +// } +// return mutations + } } diff --git a/compiler/src/prog8/compiler/astprocessing/BeforeAsmAstChanger.kt b/compiler/src/prog8/compiler/astprocessing/BeforeAsmAstChanger.kt index af7a001f2..bd44b2146 100644 --- a/compiler/src/prog8/compiler/astprocessing/BeforeAsmAstChanger.kt +++ b/compiler/src/prog8/compiler/astprocessing/BeforeAsmAstChanger.kt @@ -16,15 +16,15 @@ internal class BeforeAsmAstChanger(val program: Program, ) : AstWalker() { override fun before(breakStmt: Break, parent: Node): Iterable { - throw FatalAstException("break should have been replaced by goto $breakStmt") + throw InternalCompilerException("break should have been replaced by goto $breakStmt") } override fun before(whileLoop: WhileLoop, parent: Node): Iterable { - throw FatalAstException("while should have been converted to jumps") + throw InternalCompilerException("while should have been converted to jumps") } override fun before(untilLoop: UntilLoop, parent: Node): Iterable { - throw FatalAstException("do..until should have been converted to jumps") + throw InternalCompilerException("do..until should have been converted to jumps") } override fun before(block: Block, parent: Node): Iterable { @@ -51,7 +51,7 @@ internal class BeforeAsmAstChanger(val program: Program, override fun after(decl: VarDecl, parent: Node): Iterable { if(!options.dontReinitGlobals) { if (decl.type == VarDeclType.VAR && decl.value != null && decl.datatype in NumericDatatypes) - throw FatalAstException("vardecls for variables, with initial numerical value, should have been rewritten as plain vardecl + assignment $decl") + throw InternalCompilerException("vardecls for variables, with initial numerical value, should have been rewritten as plain vardecl + assignment $decl") } return noModifications @@ -183,7 +183,7 @@ internal class BeforeAsmAstChanger(val program: Program, if((binExpr.left as? NumericLiteral)?.number==0.0 && (binExpr.right as? NumericLiteral)?.number!=0.0) - throw FatalAstException("0==X should have been swapped to if X==0") + throw InternalCompilerException("0==X should have been swapped to if X==0") // simplify the conditional expression, introduce simple assignments if required. // NOTE: sometimes this increases code size because additional stores/loads are generated for the diff --git a/compiler/src/prog8/compiler/astprocessing/VerifyFunctionArgTypes.kt b/compiler/src/prog8/compiler/astprocessing/VerifyFunctionArgTypes.kt index c6e0d4721..f0055f92d 100644 --- a/compiler/src/prog8/compiler/astprocessing/VerifyFunctionArgTypes.kt +++ b/compiler/src/prog8/compiler/astprocessing/VerifyFunctionArgTypes.kt @@ -1,28 +1,35 @@ package prog8.compiler.astprocessing import prog8.ast.IFunctionCall +import prog8.ast.IPipe +import prog8.ast.Node import prog8.ast.Program import prog8.ast.base.DataType +import prog8.ast.base.FatalAstException +import prog8.ast.base.Position +import prog8.ast.base.defaultZero import prog8.ast.expressions.Expression import prog8.ast.expressions.FunctionCallExpression +import prog8.ast.expressions.PipeExpression import prog8.ast.expressions.TypecastExpression import prog8.ast.statements.* import prog8.ast.walk.IAstVisitor import prog8.compilerinterface.BuiltinFunctions -import prog8.compilerinterface.InternalCompilerException +import prog8.compilerinterface.IErrorReporter +import prog8.compilerinterface.builtinFunctionReturnType -internal class VerifyFunctionArgTypes(val program: Program) : IAstVisitor { +internal class VerifyFunctionArgTypes(val program: Program, val errors: IErrorReporter) : IAstVisitor { override fun visit(functionCallExpr: FunctionCallExpression) { val error = checkTypes(functionCallExpr as IFunctionCall, program) if(error!=null) - throw InternalCompilerException(error) + errors.err(error.first, error.second) } override fun visit(functionCallStatement: FunctionCallStatement) { val error = checkTypes(functionCallStatement as IFunctionCall, program) - if (error!=null) - throw InternalCompilerException(error) + if(error!=null) + errors.err(error.first, error.second) } companion object { @@ -41,22 +48,27 @@ internal class VerifyFunctionArgTypes(val program: Program) : IAstVisitor { return false } - fun checkTypes(call: IFunctionCall, program: Program): String? { + fun checkTypes(call: IFunctionCall, program: Program): Pair? { val argITypes = call.args.map { it.inferType(program) } val firstUnknownDt = argITypes.indexOfFirst { it.isUnknown } if(firstUnknownDt>=0) - return "argument ${firstUnknownDt+1} invalid argument type" + return Pair("argument ${firstUnknownDt+1} invalid argument type", call.args[firstUnknownDt].position) val argtypes = argITypes.map { it.getOr(DataType.UNDEFINED) } val target = call.target.targetStatement(program) + val isPartOfPipeSegments = (call.parent as? IPipe)?.segments?.contains(call as Node) == true if (target is Subroutine) { - if(call.args.size != target.parameters.size) - return "invalid number of arguments (#1)" // TODO how does this relate to the same error in AstIdentifiersChecker - val paramtypes = target.parameters.map { it.type } - val mismatch = argtypes.zip(paramtypes).indexOfFirst { !argTypeCompatible(it.first, it.second) } + val consideredParamTypes = if(isPartOfPipeSegments) { + target.parameters.drop(1).map { it.type } // skip first one (the implicit first arg), this is checked elsewhere + } else { + target.parameters.map { it.type } + } + if(argtypes.size != consideredParamTypes.size) + return Pair("invalid number of arguments", call.position) + val mismatch = argtypes.zip(consideredParamTypes).indexOfFirst { !argTypeCompatible(it.first, it.second) } if(mismatch>=0) { val actual = argtypes[mismatch].toString() - val expected = paramtypes[mismatch].toString() - return "argument ${mismatch + 1} type mismatch, was: $actual expected: $expected" + val expected = consideredParamTypes[mismatch].toString() + return Pair("argument ${mismatch + 1} type mismatch, was: $actual expected: $expected", call.args[mismatch].position) } if(target.isAsmSubroutine) { if(target.asmReturnvaluesRegisters.size>1) { @@ -70,7 +82,7 @@ internal class VerifyFunctionArgTypes(val program: Program) : IAstVisitor { else parent if (checkParent !is Assignment && checkParent !is VarDecl) { - return "can't use subroutine call that returns multiple return values here" + return Pair("can't use subroutine call that returns multiple return values here", call.position) } } } @@ -78,19 +90,23 @@ internal class VerifyFunctionArgTypes(val program: Program) : IAstVisitor { } else if (target is BuiltinFunctionPlaceholder) { val func = BuiltinFunctions.getValue(target.name) - if(call.args.size != func.parameters.size) - return "invalid number of arguments (#2)" // TODO how does this relate to the same error in AstIdentifiersChecker - val paramtypes = func.parameters.map { it.possibleDatatypes } - argtypes.zip(paramtypes).forEachIndexed { index, pair -> + val consideredParamTypes = if(isPartOfPipeSegments) { + func.parameters.drop(1).map { it.possibleDatatypes } // skip first one (the implicit first arg), this is checked elsewhere + } else { + func.parameters.map { it.possibleDatatypes } + } + if(argtypes.size != consideredParamTypes.size) + return Pair("invalid number of arguments", call.position) + argtypes.zip(consideredParamTypes).forEachIndexed { index, pair -> val anyCompatible = pair.second.any { argTypeCompatible(pair.first, it) } if (!anyCompatible) { val actual = pair.first.toString() return if(pair.second.size==1) { val expected = pair.second[0].toString() - "argument ${index + 1} type mismatch, was: $actual expected: $expected" + Pair("argument ${index + 1} type mismatch, was: $actual expected: $expected", call.args[index].position) } else { val expected = pair.second.toList().toString() - "argument ${index + 1} type mismatch, was: $actual expected one of: $expected" + Pair("argument ${index + 1} type mismatch, was: $actual expected one of: $expected", call.args[index].position) } } } @@ -99,4 +115,91 @@ internal class VerifyFunctionArgTypes(val program: Program) : IAstVisitor { return null } } + + override fun visit(pipe: PipeExpression) { + processPipe(pipe.source, pipe.segments, pipe) + if(errors.noErrors()) { + val last = (pipe.segments.last() as IFunctionCall).target + when (val target = last.targetStatement(program)!!) { + is BuiltinFunctionPlaceholder -> { + if (!BuiltinFunctions.getValue(target.name).hasReturn) + errors.err("invalid pipe expression; last term doesn't return a value", last.position) + } + is Subroutine -> { + if (target.returntypes.isEmpty()) + errors.err("invalid pipe expression; last term doesn't return a value", last.position) + else if (target.returntypes.size != 1) + errors.err("invalid pipe expression; last term doesn't return a single value", last.position) + } + else -> errors.err("invalid pipe expression; last term doesn't return a value", last.position) + } + super.visit(pipe) + } + } + + override fun visit(pipe: Pipe) { + processPipe(pipe.source, pipe.segments, pipe) + if(errors.noErrors()) { + super.visit(pipe) + } + } + + private fun processPipe(source: Expression, segments: List, scope: Node) { + + val sourceArg = (source as? IFunctionCall)?.args?.firstOrNull() + if(sourceArg!=null && segments.any { (it as IFunctionCall).args.firstOrNull() === sourceArg}) + throw FatalAstException("some pipe segment first arg is replicated from the source functioncall arg") + + // invalid size and other issues will be handled by the ast checker later. + var valueDt = source.inferType(program).getOrElse { + throw FatalAstException("invalid dt") + } + + for(funccall in segments) { + val target = (funccall as IFunctionCall).target.targetStatement(program) + if(target!=null) { + when (target) { + is BuiltinFunctionPlaceholder -> { + val func = BuiltinFunctions.getValue(target.name) + if(func.parameters.size!=1) + errors.err("can only use unary function", funccall.position) + else if(!func.hasReturn && funccall !== segments.last()) + errors.err("function must return a single value", funccall.position) + + val paramDts = func.parameters.firstOrNull()?.possibleDatatypes + if(paramDts!=null && !paramDts.any { valueDt isAssignableTo it }) + errors.err("pipe value datatype $valueDt incompatible with function argument ${paramDts.toList()}", funccall.position) + + if(errors.noErrors()) { + // type can depend on the argument(s) of the function. For now, we only deal with unary functions, + // so we know there must be a single argument. Take its type from the previous expression in the pipe chain. + val zero = defaultZero(valueDt, funccall.position) + valueDt = builtinFunctionReturnType(func.name, listOf(zero), program).getOrElse { DataType.UNDEFINED } + } + } + is Subroutine -> { + if(target.parameters.size!=1) + errors.err("can only use unary function", funccall.position) + else if(target.returntypes.size!=1 && funccall !== segments.last()) + errors.err("function must return a single value", funccall.position) + + val paramDt = target.parameters.firstOrNull()?.type + if(paramDt!=null && !(valueDt isAssignableTo paramDt)) + errors.err("pipe value datatype $valueDt incompatible with function argument $paramDt", funccall.position) + + if(target.returntypes.isNotEmpty()) + valueDt = target.returntypes.single() + } + is VarDecl -> { + if(!(valueDt isAssignableTo target.datatype)) + errors.err("final pipe value datatype can't be stored in pipe ending variable", funccall.position) + } + else -> { + throw FatalAstException("weird function") + } + } + } + } + } + } diff --git a/compiler/test/ModuleImporterTests.kt b/compiler/test/ModuleImporterTests.kt index 55bfa8b31..109fbad0b 100644 --- a/compiler/test/ModuleImporterTests.kt +++ b/compiler/test/ModuleImporterTests.kt @@ -204,14 +204,14 @@ class TestModuleImporter: FunSpec({ val result = importer.importLibraryModule(filenameNoExt) withClue(count[n] + " call / NO .p8 extension") { result shouldBe null } withClue(count[n] + " call / NO .p8 extension") { errors.noErrors() shouldBe false } - errors.errors.single() shouldContain "0:0) no module found with name i_do_not_exist" + errors.errors.single() shouldContain "0:0: no module found with name i_do_not_exist" errors.report() program.modules.size shouldBe 1 val result2 = importer.importLibraryModule(filenameWithExt) withClue(count[n] + " call / with .p8 extension") { result2 shouldBe null } withClue(count[n] + " call / with .p8 extension") { importer.errors.noErrors() shouldBe false } - errors.errors.single() shouldContain "0:0) no module found with name i_do_not_exist.p8" + errors.errors.single() shouldContain "0:0: no module found with name i_do_not_exist.p8" errors.report() program.modules.size shouldBe 1 } diff --git a/compiler/test/TestAstChecks.kt b/compiler/test/TestAstChecks.kt index 46a99b19c..76a2f2211 100644 --- a/compiler/test/TestAstChecks.kt +++ b/compiler/test/TestAstChecks.kt @@ -53,8 +53,8 @@ class TestAstChecks: FunSpec({ compileText(C64Target(), true, text, writeAssembly = true, errors=errors).assertFailure() errors.errors.size shouldBe 2 errors.warnings.size shouldBe 0 - errors.errors[0] shouldContain ":7:28) assignment value is invalid" - errors.errors[1] shouldContain ":8:28) assignment value is invalid" + errors.errors[0] shouldContain ":7:28: assignment value is invalid" + errors.errors[1] shouldContain ":8:28: assignment value is invalid" } test("can't do str or array expression without using address-of") { diff --git a/compiler/test/TestCompilerOnRanges.kt b/compiler/test/TestCompilerOnRanges.kt index 120791b2b..c8b8e09af 100644 --- a/compiler/test/TestCompilerOnRanges.kt +++ b/compiler/test/TestCompilerOnRanges.kt @@ -226,8 +226,8 @@ class TestCompilerOnRanges: FunSpec({ } """, errors, false).assertFailure() errors.errors.size shouldBe 2 - errors.errors[0] shouldContain ".p8:5:30) range expression from value must be integer" - errors.errors[1] shouldContain ".p8:5:45) range expression to value must be integer" + errors.errors[0] shouldContain ".p8:5:30: range expression from value must be integer" + errors.errors[1] shouldContain ".p8:5:45: range expression to value must be integer" } test("testForLoopWithIterable_str") { diff --git a/compiler/test/TestPipes.kt b/compiler/test/TestPipes.kt index b34715652..f59a790db 100644 --- a/compiler/test/TestPipes.kt +++ b/compiler/test/TestPipes.kt @@ -4,20 +4,122 @@ import io.kotest.core.spec.style.FunSpec import io.kotest.matchers.shouldBe import io.kotest.matchers.string.shouldContain import io.kotest.matchers.types.instanceOf +import prog8.ast.IFunctionCall +import prog8.ast.Program +import prog8.ast.base.DataType +import prog8.ast.base.Position import prog8.ast.expressions.* import prog8.ast.statements.Assignment import prog8.ast.statements.FunctionCallStatement import prog8.ast.statements.Pipe +import prog8.ast.statements.VarDecl import prog8.codegen.target.C64Target -import prog8tests.helpers.ErrorReporterForTests -import prog8tests.helpers.assertFailure -import prog8tests.helpers.assertSuccess -import prog8tests.helpers.compileText +import prog8.compiler.astprocessing.AstPreprocessor +import prog8.parser.Prog8Parser.parseModule +import prog8.parser.SourceCode +import prog8tests.helpers.* class TestPipes: FunSpec({ - test("correct pipe statements") { + test("pipe expression parse tree after preprocessing") { + val text = """ + main { + sub start() { + uword xx = 9999 |> func1() |> func2() + |> func1() |> func2() + |> func1() + } + sub func1(uword arg) -> uword { + return arg+1111 + } + sub func2(uword arg) -> uword { + return arg+2222 + } + } + """ + val src = SourceCode.Text(text) + val module = parseModule(src) + val errors = ErrorReporterForTests() + val program = Program("test", DummyFunctions, DummyMemsizer, DummyStringEncoder) + program.addModule(module) + val preprocess = AstPreprocessor(program, errors, C64Target()) + preprocess.visit(program) + errors.errors.size shouldBe 0 + preprocess.applyModifications() + + program.entrypoint.statements.size shouldBe 1 + val pipe = (program.entrypoint.statements.single() as VarDecl).value as PipeExpression + pipe.source shouldBe NumericLiteral(DataType.UWORD, 9999.0, Position.DUMMY) + pipe.segments.size shouldBe 5 + var call = pipe.segments[0] as IFunctionCall + call.target.nameInSource shouldBe listOf("func1") + call.args.size shouldBe 0 + call = pipe.segments[1] as IFunctionCall + call.target.nameInSource shouldBe listOf("func2") + call.args.size shouldBe 0 + call = pipe.segments[2] as IFunctionCall + call.target.nameInSource shouldBe listOf("func1") + call.args.size shouldBe 0 + call = pipe.segments[3] as IFunctionCall + call.target.nameInSource shouldBe listOf("func2") + call.args.size shouldBe 0 + call = pipe.segments[4] as IFunctionCall + call.target.nameInSource shouldBe listOf("func1") + call.args.size shouldBe 0 + } + + test("pipe statement parse tree after preprocessing") { + val text = """ + main { + sub start() { + 9999 |> func1() |> func2() + |> func1() |> func2() + |> func3() + } + sub func1(uword arg) -> uword { + return arg+1111 + } + sub func2(uword arg) -> uword { + return arg+2222 + } + sub func3(uword arg) { + ; nothing + } + } + """ + val src = SourceCode.Text(text) + val module = parseModule(src) + val errors = ErrorReporterForTests() + val program = Program("test", DummyFunctions, DummyMemsizer, DummyStringEncoder) + program.addModule(module) + val preprocess = AstPreprocessor(program, errors, C64Target()) + preprocess.visit(program) + errors.errors.size shouldBe 0 + preprocess.applyModifications() + + program.entrypoint.statements.size shouldBe 1 + val pipe = program.entrypoint.statements.single() as Pipe + pipe.source shouldBe NumericLiteral(DataType.UWORD, 9999.0, Position.DUMMY) + pipe.segments.size shouldBe 5 + var call = pipe.segments[0] as IFunctionCall + call.target.nameInSource shouldBe listOf("func1") + call.args.size shouldBe 0 + call = pipe.segments[1] as IFunctionCall + call.target.nameInSource shouldBe listOf("func2") + call.args.size shouldBe 0 + call = pipe.segments[2] as IFunctionCall + call.target.nameInSource shouldBe listOf("func1") + call.args.size shouldBe 0 + call = pipe.segments[3] as IFunctionCall + call.target.nameInSource shouldBe listOf("func2") + call.args.size shouldBe 0 + call = pipe.segments[4] as IFunctionCall + call.target.nameInSource shouldBe listOf("func3") + call.args.size shouldBe 0 + } + + test("correct pipe statements (no opt)") { val text = """ %import floats %import textio @@ -31,7 +133,6 @@ class TestPipes: FunSpec({ 9999 |> addword() |> txt.print_uw() - ; these are optimized into just the function calls: 9999 |> abs() |> txt.print_uw() 9999 |> txt.print_uw() 99 |> abs() |> txt.print_ub() @@ -46,20 +147,76 @@ class TestPipes: FunSpec({ } } """ - val result = compileText(C64Target(), true, text, writeAssembly = true).assertSuccess() + val result = compileText(C64Target(), optimize = false, text, writeAssembly = true).assertSuccess() val stmts = result.program.entrypoint.statements stmts.size shouldBe 7 val pipef = stmts[0] as Pipe pipef.source shouldBe instanceOf() pipef.segments.size shouldBe 2 - pipef.segments[0] shouldBe instanceOf() - pipef.segments[1] shouldBe instanceOf() + var call = pipef.segments[0] as IFunctionCall + call.target.nameInSource shouldBe listOf("addfloat") + call = pipef.segments[1] as IFunctionCall + call.target.nameInSource shouldBe listOf("floats", "print_f") val pipew = stmts[1] as Pipe pipef.source shouldBe instanceOf() pipew.segments.size shouldBe 2 - pipew.segments[0] shouldBe instanceOf() - pipew.segments[1] shouldBe instanceOf() + call = pipew.segments[0] as IFunctionCall + call.target.nameInSource shouldBe listOf("addword") + call = pipew.segments[1] as IFunctionCall + call.target.nameInSource shouldBe listOf("txt", "print_uw") + + stmts[2] shouldBe instanceOf() + stmts[3] shouldBe instanceOf() + stmts[4] shouldBe instanceOf() + stmts[5] shouldBe instanceOf() + } + + test("correct pipe statements (with opt)") { + val text = """ + %import floats + %import textio + + main { + sub start() { + + 1.234 |> addfloat() + |> floats.print_f() + + 9999 |> addword() + |> txt.print_uw() + + ; these should be optimized into just the function calls: + 9999 |> abs() |> txt.print_uw() + 9999 |> txt.print_uw() + 99 |> abs() |> txt.print_ub() + 99 |> txt.print_ub() + } + + sub addfloat(float fl) -> float { + return fl+2.22 + } + sub addword(uword ww) -> uword { + return ww+2222 + } + } + """ + val result = compileText(C64Target(), optimize = true, text, writeAssembly = true).assertSuccess() + val stmts = result.program.entrypoint.statements + stmts.size shouldBe 7 + val pipef = stmts[0] as Pipe + pipef.source shouldBe instanceOf() + (pipef.source as IFunctionCall).target.nameInSource shouldBe listOf("addfloat") + pipef.segments.size shouldBe 1 + val callf = pipef.segments[0] as IFunctionCall + callf.target.nameInSource shouldBe listOf("floats", "print_f") + + val pipew = stmts[1] as Pipe + pipef.source shouldBe instanceOf() + (pipew.source as IFunctionCall).target.nameInSource shouldBe listOf("addword") + pipew.segments.size shouldBe 1 + val callw = pipew.segments[0] as IFunctionCall + callw.target.nameInSource shouldBe listOf("txt", "print_uw") var stmt = stmts[2] as FunctionCallStatement stmt.target.nameInSource shouldBe listOf("txt", "print_uw") @@ -78,8 +235,8 @@ class TestPipes: FunSpec({ main { sub start() { - 1.234 |> addfloat - |> addword |> addword + 1.234 |> addfloat() + |> addword() |> addword() } sub addfloat(float fl) -> float { @@ -96,21 +253,21 @@ class TestPipes: FunSpec({ errors.errors[0] shouldContain "incompatible" } - test("correct pipe expressions") { + test("correct pipe expressions (no opt)") { val text = """ %import floats %import textio main { sub start() { - float @shared fl = 1.234 |> addfloat - |> addfloat + float @shared fl = 1.234 |> addfloat() + |> addfloat() - uword @shared ww = 9999 |> addword - |> addword + uword @shared ww = 9999 |> addword() + |> addword() - ubyte @shared cc = 30 |> sin8u |> cos8u ; will be optimized away into a const number - cc = cc |> sin8u |> cos8u + ubyte @shared cc = 30 |> sin8u() |> cos8u() + cc = cc |> sin8u() |> cos8u() } sub addfloat(float fl) -> float { @@ -119,27 +276,86 @@ class TestPipes: FunSpec({ sub addword(uword ww) -> uword { return ww+2222 } - sub addbyte(ubyte bb) -> ubyte { - return bb+1 - } } """ - val result = compileText(C64Target(), true, text, writeAssembly = true).assertSuccess() + val result = compileText(C64Target(), optimize = false, text, writeAssembly = true).assertSuccess() val stmts = result.program.entrypoint.statements stmts.size shouldBe 8 val assignf = stmts[1] as Assignment val pipef = assignf.value as PipeExpression pipef.source shouldBe instanceOf() pipef.segments.size shouldBe 2 - pipef.segments[0] shouldBe instanceOf() - pipef.segments[1] shouldBe instanceOf() + var call = pipef.segments[0] as IFunctionCall + call.target.nameInSource shouldBe listOf("addfloat") + call = pipef.segments[1] as IFunctionCall + call.target.nameInSource shouldBe listOf("addfloat") + val assignw = stmts[3] as Assignment val pipew = assignw.value as PipeExpression pipew.source shouldBe instanceOf() pipew.segments.size shouldBe 2 + call = pipew.segments[0] as IFunctionCall + call.target.nameInSource shouldBe listOf("addword") + call = pipew.segments[1] as IFunctionCall + call.target.nameInSource shouldBe listOf("addword") + + var assigncc = stmts[5] as Assignment + val value = assigncc.value as PipeExpression + value.source shouldBe NumericLiteral(DataType.UBYTE, 30.0, Position.DUMMY) + value.segments.size shouldBe 2 + call = value.segments[0] as IFunctionCall + call.target.nameInSource shouldBe listOf("sin8u") + call = value.segments[1] as IFunctionCall + call.target.nameInSource shouldBe listOf("cos8u") + + assigncc = stmts[6] as Assignment + val pipecc = assigncc.value as PipeExpression + pipecc.source shouldBe instanceOf() + pipecc.segments.size shouldBe 2 + pipecc.segments[0] shouldBe instanceOf() + pipecc.segments[1] shouldBe instanceOf() + } + + test("correct pipe expressions (with opt)") { + val text = """ + %import floats + %import textio + + main { + sub start() { + float @shared fl = 1.234 |> addfloat() + |> addfloat() + + uword @shared ww = 9999 |> addword() + |> addword() + + ubyte @shared cc = 30 |> sin8u() |> cos8u() ; will be optimized away into a const number + cc = cc |> sin8u() |> cos8u() + } + + sub addfloat(float fl) -> float { + return fl+2.22 + } + sub addword(uword ww) -> uword { + return ww+2222 + } + } + """ + val result = compileText(C64Target(), optimize = true, text, writeAssembly = true).assertSuccess() + val stmts = result.program.entrypoint.statements + stmts.size shouldBe 8 + val assignf = stmts[1] as Assignment + val pipef = assignf.value as PipeExpression + pipef.source shouldBe instanceOf() + pipef.segments.size shouldBe 1 + pipef.segments[0] shouldBe instanceOf() + + val assignw = stmts[3] as Assignment + val pipew = assignw.value as PipeExpression + pipew.source shouldBe instanceOf() + pipew.segments.size shouldBe 1 pipew.segments[0] shouldBe instanceOf() - pipew.segments[1] shouldBe instanceOf() var assigncc = stmts[5] as Assignment val value = assigncc.value as NumericLiteral @@ -147,43 +363,12 @@ class TestPipes: FunSpec({ assigncc = stmts[6] as Assignment val pipecc = assigncc.value as PipeExpression - pipecc.source shouldBe instanceOf() - pipecc.segments.size shouldBe 2 + pipecc.source shouldBe instanceOf() + (pipecc.source as BuiltinFunctionCall).target.nameInSource shouldBe listOf("sin8u") + + pipecc.segments.size shouldBe 1 pipecc.segments[0] shouldBe instanceOf() - pipecc.segments[1] shouldBe instanceOf() - } - - test("correct pipe expressions with variables at end") { - val text = """ - %import textio - - main { - sub start() { - uword @shared ww - ubyte @shared cc - - 9999 |> addword |> addword |> ww - 30 |> sin8u |> cos8u |> cc ; will be optimized away into a const number - } - sub addword(uword ww) -> uword { - return ww+2222 - } - } - """ - val result = compileText(C64Target(), true, text, writeAssembly = true).assertSuccess() - val stmts = result.program.entrypoint.statements - stmts.size shouldBe 7 - - val assignw = stmts[4] as Assignment - val pipew = assignw.value as PipeExpression - pipew.source shouldBe instanceOf() - pipew.segments.size shouldBe 2 - pipew.segments[0] shouldBe instanceOf() - pipew.segments[1] shouldBe instanceOf() - - val assigncc = stmts[5] as Assignment - val value = assigncc.value as NumericLiteral - value.number shouldBe 190.0 + (pipecc.segments[0] as BuiltinFunctionCall).target.nameInSource shouldBe listOf("cos8u") } test("incorrect type in pipe expression") { @@ -192,8 +377,8 @@ class TestPipes: FunSpec({ main { sub start() { - uword result = 1.234 |> addfloat - |> addword |> addword + uword result = 1.234 |> addfloat() + |> addword() |> addword() } sub addfloat(float fl) -> float { @@ -218,25 +403,23 @@ class TestPipes: FunSpec({ sub start() { uword ww = 9999 ubyte bb = 99 - ww |> abs |> txt.print_uw - bb |> abs |> txt.print_ub + ww |> abs() |> txt.print_uw() + bb |> abs() |> txt.print_ub() } } """ val result = compileText(C64Target(), true, text, writeAssembly = true).assertSuccess() val stmts = result.program.entrypoint.statements stmts.size shouldBe 7 - val pipef = stmts[4] as Pipe - pipef.source shouldBe instanceOf() - pipef.segments.size shouldBe 2 - pipef.segments[0] shouldBe instanceOf() - pipef.segments[1] shouldBe instanceOf() + val pipeww = stmts[4] as Pipe + pipeww.source shouldBe instanceOf() + pipeww.segments.size shouldBe 1 + pipeww.segments[0] shouldBe instanceOf() - val pipew = stmts[5] as Pipe - pipew.source shouldBe instanceOf() - pipew.segments.size shouldBe 2 - pipew.segments[0] shouldBe instanceOf() - pipew.segments[1] shouldBe instanceOf() + val pipebb = stmts[5] as Pipe + pipebb.source shouldBe instanceOf() + pipebb.segments.size shouldBe 1 + pipebb.segments[0] shouldBe instanceOf() } test("pipe statement with type errors") { @@ -246,13 +429,13 @@ class TestPipes: FunSpec({ main { sub start() { uword ww = 9999 - 9999 |> abs |> txt.print_ub - ww |> abs |> txt.print_ub + 9999 |> abs() |> txt.print_ub() + ww |> abs() |> txt.print_ub() } } """ val errors = ErrorReporterForTests() - compileText(C64Target(), true, text, writeAssembly = true, errors=errors).assertFailure() + compileText(C64Target(), optimize = false, text, writeAssembly = true, errors=errors).assertFailure() errors.errors.size shouldBe 2 errors.errors[0] shouldContain "UWORD incompatible" errors.errors[1] shouldContain "UWORD incompatible" diff --git a/compiler/test/TestSubroutines.kt b/compiler/test/TestSubroutines.kt index 6b37dbf4d..111bad03e 100644 --- a/compiler/test/TestSubroutines.kt +++ b/compiler/test/TestSubroutines.kt @@ -286,8 +286,8 @@ class TestSubroutines: FunSpec({ val errors = ErrorReporterForTests() compileText(C64Target(), false, text, writeAssembly = false, errors=errors).assertFailure() errors.errors.size shouldBe 2 - errors.errors[0] shouldContain "7:25) invalid number of arguments" - errors.errors[1] shouldContain "9:25) invalid number of arguments" + errors.errors[0] shouldContain "7:25: invalid number of arguments" + errors.errors[1] shouldContain "9:25: invalid number of arguments" } test("invalid number of args check on asm subroutine") { @@ -307,8 +307,8 @@ class TestSubroutines: FunSpec({ val errors = ErrorReporterForTests() compileText(C64Target(), false, text, writeAssembly = false, errors=errors).assertFailure() errors.errors.size shouldBe 2 - errors.errors[0] shouldContain "7:25) invalid number of arguments" - errors.errors[1] shouldContain "9:25) invalid number of arguments" + errors.errors[0] shouldContain "7:25: invalid number of arguments" + errors.errors[1] shouldContain "9:25: invalid number of arguments" } test("invalid number of args check on call to label and builtin func") { diff --git a/compilerAst/src/prog8/ast/AstToSourceTextConverter.kt b/compilerAst/src/prog8/ast/AstToSourceTextConverter.kt index 25b80ca5c..21e21b08f 100644 --- a/compilerAst/src/prog8/ast/AstToSourceTextConverter.kt +++ b/compilerAst/src/prog8/ast/AstToSourceTextConverter.kt @@ -478,7 +478,7 @@ class AstToSourceTextConverter(val output: (text: String) -> Unit, val program: printPipe(pipe.source, pipe.segments) } - private fun printPipe(source: Expression, segments: Iterable) { + private fun printPipe(source: Expression, segments: Iterable) { source.accept(this) segments.first().accept(this) outputln("") diff --git a/compilerAst/src/prog8/ast/AstToplevel.kt b/compilerAst/src/prog8/ast/AstToplevel.kt index 557d8eac0..08fd7e926 100644 --- a/compilerAst/src/prog8/ast/AstToplevel.kt +++ b/compilerAst/src/prog8/ast/AstToplevel.kt @@ -22,7 +22,7 @@ interface IFunctionCall { interface IPipe { var source: Expression - val segments: MutableList + val segments: MutableList // are all function calls val position: Position var parent: Node // will be linked correctly later (late init) } diff --git a/compilerAst/src/prog8/ast/base/Base.kt b/compilerAst/src/prog8/ast/base/Base.kt index b285f1cde..a8fd66c2d 100644 --- a/compilerAst/src/prog8/ast/base/Base.kt +++ b/compilerAst/src/prog8/ast/base/Base.kt @@ -193,8 +193,8 @@ object ParentSentinel : Node { data class Position(val file: String, val line: Int, val startCol: Int, val endCol: Int) { override fun toString(): String = "[$file: line $line col ${startCol+1}-${endCol+1}]" fun toClickableStr(): String { - val path = (Path("") / file).absolute().normalize() - return "($path:$line:$startCol)" + val path = Path(file).absolute().normalize() + return "file://$path:$line:$startCol:" } companion object { diff --git a/compilerAst/src/prog8/ast/expressions/AstExpressions.kt b/compilerAst/src/prog8/ast/expressions/AstExpressions.kt index 954f235a5..1f75449e8 100644 --- a/compilerAst/src/prog8/ast/expressions/AstExpressions.kt +++ b/compilerAst/src/prog8/ast/expressions/AstExpressions.kt @@ -939,6 +939,14 @@ class FunctionCallExpression(override var target: IdentifierReference, // lenghts of arrays and strings are constants that are determined at compile time! if(target.nameInSource.size>1) return null + + // If the function call is part of a Pipe segments, the number of args will be 1 less than the number of parameters required + // because of the implicit first argument. We don't know this first argument here. Assume it is not a constant, + // which means that this function call cannot be a constant either. + val pipeParentSegments = (parent as? IPipe)?.segments ?: emptyList() + if(this in pipeParentSegments) + return null + val resultValue: NumericLiteral? = program.builtinFunctions.constValue(target.nameInSource[0], args, position) if(withDatatypeCheck) { val resultDt = this.inferType(program) @@ -1075,7 +1083,7 @@ class ContainmentCheck(var element: Expression, } class PipeExpression(override var source: Expression, - override val segments: MutableList, + override val segments: MutableList, // are all function calls override val position: Position): Expression(), IPipe { override lateinit var parent: Node @@ -1099,8 +1107,9 @@ class PipeExpression(override var source: Expression, if(node===source) { source = replacement } else { + require(replacement is IFunctionCall) val idx = segments.indexOf(node) - segments[idx] = replacement as FunctionCallExpression + segments[idx] = replacement } } } diff --git a/compilerAst/src/prog8/ast/statements/AstStatements.kt b/compilerAst/src/prog8/ast/statements/AstStatements.kt index cb72d4467..fe5780877 100644 --- a/compilerAst/src/prog8/ast/statements/AstStatements.kt +++ b/compilerAst/src/prog8/ast/statements/AstStatements.kt @@ -1023,7 +1023,7 @@ class DirectMemoryWrite(var addressExpression: Expression, override val position class Pipe(override var source: Expression, - override val segments: MutableList, + override val segments: MutableList, // are all function calls override val position: Position): Statement(), IPipe { override lateinit var parent: Node @@ -1043,8 +1043,9 @@ class Pipe(override var source: Expression, if(node===source) { source = replacement } else { + require(replacement is IFunctionCall) val idx = segments.indexOf(node) - segments[idx] = replacement as FunctionCallExpression + segments[idx] = replacement } } } diff --git a/docs/source/syntaxreference.rst b/docs/source/syntaxreference.rst index e5e69da96..842b2b31c 100644 --- a/docs/source/syntaxreference.rst +++ b/docs/source/syntaxreference.rst @@ -528,31 +528,26 @@ containment check: ``in`` pipe: ``|>`` Used as an alternative to nesting function calls. The pipe operator is used to 'pipe' the value - into the next function. It only works on unary functions (taking a single argument), because you just - specify the *name* of the function that the value has to be piped into. The resulting value will be piped - into the next function as long as there are chained pipes. + into the next function. You write a pipe as a sequence of function calls. You don't write + the arguments to the functions though: the value of one segment in the pipe, will be used as the argument + for the next function call in the sequence. + + *note:* It only works on unary functions (taking a single argument) for now. + For example, this: ``txt.print_uw(add_bonus(determine_score(get_player(1))))`` can be rewritten as:: get_player(1) - |> determine_score - |> add_bonus - |> txt.print_uw + |> determine_score() + |> add_bonus() + |> txt.print_uw() - It also works for expressions that return a value, for example ``uword score = add_bonus(determine_score(get_player(1)))`` :: + A pipe can also be written as an expression that returns a value, for example ``uword score = add_bonus(determine_score(get_player(1)))`` :: uword score = get_player(1) - |> determine_score - |> add_bonus + |> determine_score() + |> add_bonus() - Finally, if you like the left-to-right flow, it's possible to use the name of a variable as the last term. This just means that the pipe's resulting value is - stored in that variable (it's just another way of writing an assignment). So the above can also be written as:: - - uword score - get_player(1) - |> determine_score - |> add_bonus - |> score address of: ``&`` This is a prefix operator that can be applied to a string or array variable or literal value. diff --git a/docs/source/todo.rst b/docs/source/todo.rst index bd1368383..444552718 100644 --- a/docs/source/todo.rst +++ b/docs/source/todo.rst @@ -23,11 +23,10 @@ Future Things and Ideas ^^^^^^^^^^^^^^^^^^^^^^^ Compiler: +- pipe operator: allow non-unary function calls in the pipe that specify the other argument(s) in the calls. - writeAssembly(): make it possible to actually get rid of the VarDecl nodes by fixing the rest of the code mentioned there. - make everything an expression? (get rid of Statements. Statements are expressions with void return types?). - allow "xxx" * constexpr (where constexpr is not a number literal), now gives expression error not same type -- for the pipe operator: recognise a placeholder (``?`` or ``%`` or ``_``) in a non-unary function call to allow non-unary functions in the chain; ``4 |> mkword(?, $44) |> print_uw`` - OR: change pipe syntax and require function call, but always have implicit first argument added. - for the pipe operator: make it 100% syntactic sugar so there's no need for asm codegen like translatePipeExpression - make it possible to inline non-asmsub routines that just contain a single statement (return, functioncall, assignment) but this requires all identifiers in the inlined expression to be changed to fully scoped names. diff --git a/examples/test.p8 b/examples/test.p8 index 636f20447..c074a7e1b 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -1,20 +1,27 @@ %import textio +%import floats +%import test_stack +%zeropage basicsafe main { sub start() { - ubyte xx = 30 - ubyte cc - - cc=0 - cc = 30 |> sin8u |> cos8u |> cc - txt.print_ub(cc) - txt.nl() - cc=0 - cc = xx |> sin8u |> cos8u |> cc - txt.print_ub(cc) - txt.nl() - - repeat { - } + get_player(1) + |> determine_score() + |> add_bonus() + |> txt.print_uw() } + + sub get_player(ubyte xx) -> ubyte { + return xx+33 + } + + sub determine_score() -> ubyte { + return 33 + } + + sub add_bonus(ubyte qq) { + qq++ + } + + }