diff --git a/compiler/src/prog8/compiler/astprocessing/AstIdentifiersChecker.kt b/compiler/src/prog8/compiler/astprocessing/AstIdentifiersChecker.kt index a05e47dc9..ebfb02bbf 100644 --- a/compiler/src/prog8/compiler/astprocessing/AstIdentifiersChecker.kt +++ b/compiler/src/prog8/compiler/astprocessing/AstIdentifiersChecker.kt @@ -10,7 +10,6 @@ import prog8.ast.expressions.StringLiteral import prog8.ast.statements.* import prog8.ast.walk.IAstVisitor import prog8.compilerinterface.BuiltinFunctions -import prog8.compilerinterface.Encoding import prog8.compilerinterface.ICompilationTarget import prog8.compilerinterface.IErrorReporter @@ -134,28 +133,32 @@ internal class AstIdentifiersChecker(private val errors: IErrorReporter, override fun visit(functionCallStatement: FunctionCallStatement) = visitFunctionCall(functionCallStatement) private fun visitFunctionCall(call: IFunctionCall) { + val isPartOfPipeSegments = (call.parent as? IPipe)?.segments?.contains(call as Node) == true + val errormessageAboutArgs = if(isPartOfPipeSegments) "invalid number of arguments in piped call" else "invalid number of arguments" when (val target = call.target.targetStatement(program)) { is Subroutine -> { // if the call is part of a Pipe, the number of arguments in the call should be 1 less than the number of parameters - val expectedNumberOfArgs = if(call.parent is IPipe) - target.parameters.size-1 - else + val expectedNumberOfArgs: Int = if(isPartOfPipeSegments) { + target.parameters.size - 1 + } else { target.parameters.size + } if(call.args.size != expectedNumberOfArgs) { val pos = (if(call.args.any()) call.args[0] else (call as Node)).position - errors.err("invalid number of arguments", pos) + errors.err(errormessageAboutArgs, pos) } } is BuiltinFunctionPlaceholder -> { val func = BuiltinFunctions.getValue(target.name) // if the call is part of a Pipe, the number of arguments in the call should be 1 less than the number of parameters - val expectedNumberOfArgs = if(call.parent is IPipe) + val expectedNumberOfArgs: Int = if(isPartOfPipeSegments) { func.parameters.size-1 - else + } else { func.parameters.size + } if(call.args.size != expectedNumberOfArgs) { val pos = (if(call.args.any()) call.args[0] else (call as Node)).position - errors.err("invalid number of arguments", pos) + errors.err(errormessageAboutArgs, pos) } if(func.name=="memory") { val name = call.args[0] as? StringLiteral diff --git a/compiler/src/prog8/compiler/astprocessing/VerifyFunctionArgTypes.kt b/compiler/src/prog8/compiler/astprocessing/VerifyFunctionArgTypes.kt index f0055f92d..3e70b7895 100644 --- a/compiler/src/prog8/compiler/astprocessing/VerifyFunctionArgTypes.kt +++ b/compiler/src/prog8/compiler/astprocessing/VerifyFunctionArgTypes.kt @@ -56,14 +56,15 @@ internal class VerifyFunctionArgTypes(val program: Program, val errors: IErrorRe val argtypes = argITypes.map { it.getOr(DataType.UNDEFINED) } val target = call.target.targetStatement(program) val isPartOfPipeSegments = (call.parent as? IPipe)?.segments?.contains(call as Node) == true + val errormessageAboutArgs = if(isPartOfPipeSegments) "invalid number of arguments in piped call" else "invalid number of arguments" if (target is Subroutine) { - val consideredParamTypes = if(isPartOfPipeSegments) { + val consideredParamTypes: List = if(isPartOfPipeSegments) { target.parameters.drop(1).map { it.type } // skip first one (the implicit first arg), this is checked elsewhere } else { target.parameters.map { it.type } } if(argtypes.size != consideredParamTypes.size) - return Pair("invalid number of arguments", call.position) + return Pair(errormessageAboutArgs, call.position) val mismatch = argtypes.zip(consideredParamTypes).indexOfFirst { !argTypeCompatible(it.first, it.second) } if(mismatch>=0) { val actual = argtypes[mismatch].toString() @@ -90,13 +91,13 @@ internal class VerifyFunctionArgTypes(val program: Program, val errors: IErrorRe } else if (target is BuiltinFunctionPlaceholder) { val func = BuiltinFunctions.getValue(target.name) - val consideredParamTypes = if(isPartOfPipeSegments) { + val consideredParamTypes: List> = if(isPartOfPipeSegments) { func.parameters.drop(1).map { it.possibleDatatypes } // skip first one (the implicit first arg), this is checked elsewhere } else { func.parameters.map { it.possibleDatatypes } } if(argtypes.size != consideredParamTypes.size) - return Pair("invalid number of arguments", call.position) + return Pair(errormessageAboutArgs, call.position) argtypes.zip(consideredParamTypes).forEachIndexed { index, pair -> val anyCompatible = pair.second.any { argTypeCompatible(pair.first, it) } if (!anyCompatible) { diff --git a/compiler/test/TestPipes.kt b/compiler/test/TestPipes.kt index f59a790db..b15e70805 100644 --- a/compiler/test/TestPipes.kt +++ b/compiler/test/TestPipes.kt @@ -36,8 +36,7 @@ class TestPipes: FunSpec({ sub func2(uword arg) -> uword { return arg+2222 } - } - """ + }""" val src = SourceCode.Text(text) val module = parseModule(src) val errors = ErrorReporterForTests() @@ -86,8 +85,7 @@ class TestPipes: FunSpec({ sub func3(uword arg) { ; nothing } - } - """ + }""" val src = SourceCode.Text(text) val module = parseModule(src) val errors = ErrorReporterForTests() @@ -130,7 +128,7 @@ class TestPipes: FunSpec({ 1.234 |> addfloat() |> floats.print_f() - 9999 |> addword() + startvalue(99) |> addword() |> txt.print_uw() 9999 |> abs() |> txt.print_uw() @@ -139,14 +137,16 @@ class TestPipes: FunSpec({ 99 |> txt.print_ub() } + sub startvalue(ubyte arg) -> uword { + return arg+9999 + } sub addfloat(float fl) -> float { return fl+2.22 } sub addword(uword ww) -> uword { return ww+2222 } - } - """ + }""" val result = compileText(C64Target(), optimize = false, text, writeAssembly = true).assertSuccess() val stmts = result.program.entrypoint.statements stmts.size shouldBe 7 @@ -183,7 +183,7 @@ class TestPipes: FunSpec({ 1.234 |> addfloat() |> floats.print_f() - 9999 |> addword() + startvalue(99) |> addword() |> txt.print_uw() ; these should be optimized into just the function calls: @@ -193,14 +193,16 @@ class TestPipes: FunSpec({ 99 |> txt.print_ub() } + sub startvalue(ubyte arg) -> uword { + return arg+9999 + } sub addfloat(float fl) -> float { return fl+2.22 } sub addword(uword ww) -> uword { return ww+2222 } - } - """ + }""" val result = compileText(C64Target(), optimize = true, text, writeAssembly = true).assertSuccess() val stmts = result.program.entrypoint.statements stmts.size shouldBe 7 @@ -213,9 +215,9 @@ class TestPipes: FunSpec({ val pipew = stmts[1] as Pipe pipef.source shouldBe instanceOf() - (pipew.source as IFunctionCall).target.nameInSource shouldBe listOf("addword") - pipew.segments.size shouldBe 1 - val callw = pipew.segments[0] as IFunctionCall + (pipew.source as IFunctionCall).target.nameInSource shouldBe listOf("startvalue") + pipew.segments.size shouldBe 2 + val callw = pipew.segments[1] as IFunctionCall callw.target.nameInSource shouldBe listOf("txt", "print_uw") var stmt = stmts[2] as FunctionCallStatement @@ -245,8 +247,7 @@ class TestPipes: FunSpec({ sub addword(uword ww) -> uword { return ww+2222 } - } - """ + }""" val errors = ErrorReporterForTests() compileText(C64Target(), false, text, errors=errors).assertFailure() errors.errors.size shouldBe 1 @@ -263,21 +264,23 @@ class TestPipes: FunSpec({ float @shared fl = 1.234 |> addfloat() |> addfloat() - uword @shared ww = 9999 |> addword() + uword @shared ww = startvalue(99) |> addword() |> addword() ubyte @shared cc = 30 |> sin8u() |> cos8u() cc = cc |> sin8u() |> cos8u() } + sub startvalue(ubyte arg) -> uword { + return arg+9999 + } sub addfloat(float fl) -> float { return fl+2.22 } sub addword(uword ww) -> uword { return ww+2222 } - } - """ + }""" val result = compileText(C64Target(), optimize = false, text, writeAssembly = true).assertSuccess() val stmts = result.program.entrypoint.statements stmts.size shouldBe 8 @@ -293,7 +296,7 @@ class TestPipes: FunSpec({ val assignw = stmts[3] as Assignment val pipew = assignw.value as PipeExpression - pipew.source shouldBe instanceOf() + pipew.source shouldBe instanceOf() pipew.segments.size shouldBe 2 call = pipew.segments[0] as IFunctionCall call.target.nameInSource shouldBe listOf("addword") @@ -327,21 +330,24 @@ class TestPipes: FunSpec({ float @shared fl = 1.234 |> addfloat() |> addfloat() - uword @shared ww = 9999 |> addword() + uword @shared ww = startvalue(99) |> addword() |> addword() ubyte @shared cc = 30 |> sin8u() |> cos8u() ; will be optimized away into a const number cc = cc |> sin8u() |> cos8u() } + sub startvalue(ubyte arg) -> uword { + return arg+9999 + } sub addfloat(float fl) -> float { return fl+2.22 } sub addword(uword ww) -> uword { return ww+2222 } - } - """ + } + """ val result = compileText(C64Target(), optimize = true, text, writeAssembly = true).assertSuccess() val stmts = result.program.entrypoint.statements stmts.size shouldBe 8 @@ -354,8 +360,9 @@ class TestPipes: FunSpec({ val assignw = stmts[3] as Assignment val pipew = assignw.value as PipeExpression pipew.source shouldBe instanceOf() - pipew.segments.size shouldBe 1 + pipew.segments.size shouldBe 2 pipew.segments[0] shouldBe instanceOf() + pipew.segments[1] shouldBe instanceOf() var assigncc = stmts[5] as Assignment val value = assigncc.value as NumericLiteral @@ -440,4 +447,29 @@ class TestPipes: FunSpec({ errors.errors[0] shouldContain "UWORD incompatible" errors.errors[1] shouldContain "UWORD incompatible" } + + test("pipe detects invalid number of args") { + val text = """ + main { + sub start() { + uword ww = startvalue() |> addword() + |> addword() + + ubyte cc = 30 |> sin8u(99) |> cos8u(22) + } + + sub startvalue(ubyte arg) -> uword { + return arg+9999 + } + sub addword(uword ww) -> uword { + return ww+2222 + } + }""" + val errors = ErrorReporterForTests() + compileText(C64Target(), optimize = false, text, writeAssembly = false, errors=errors).assertFailure() + errors.errors.size shouldBe 3 + errors.errors[0] shouldContain ":4:32: invalid number of arguments" + errors.errors[1] shouldContain ":7:44: invalid number of arguments" + errors.errors[2] shouldContain ":7:57: invalid number of arguments" + } }) diff --git a/examples/test.p8 b/examples/test.p8 index c074a7e1b..964a882b2 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -15,12 +15,12 @@ main { return xx+33 } - sub determine_score() -> ubyte { - return 33 + sub determine_score(ubyte zz) -> ubyte { + return zz+22 } - sub add_bonus(ubyte qq) { - qq++ + sub add_bonus(ubyte qq) -> ubyte { + return qq+1 }