fix pipe check for number of args

This commit is contained in:
Irmen de Jong 2022-03-02 21:29:09 +01:00
parent fc1c3c6808
commit 38beebe720
4 changed files with 75 additions and 39 deletions

View File

@ -10,7 +10,6 @@ import prog8.ast.expressions.StringLiteral
import prog8.ast.statements.* import prog8.ast.statements.*
import prog8.ast.walk.IAstVisitor import prog8.ast.walk.IAstVisitor
import prog8.compilerinterface.BuiltinFunctions import prog8.compilerinterface.BuiltinFunctions
import prog8.compilerinterface.Encoding
import prog8.compilerinterface.ICompilationTarget import prog8.compilerinterface.ICompilationTarget
import prog8.compilerinterface.IErrorReporter import prog8.compilerinterface.IErrorReporter
@ -134,28 +133,32 @@ internal class AstIdentifiersChecker(private val errors: IErrorReporter,
override fun visit(functionCallStatement: FunctionCallStatement) = visitFunctionCall(functionCallStatement) override fun visit(functionCallStatement: FunctionCallStatement) = visitFunctionCall(functionCallStatement)
private fun visitFunctionCall(call: IFunctionCall) { private fun visitFunctionCall(call: IFunctionCall) {
val isPartOfPipeSegments = (call.parent as? IPipe)?.segments?.contains(call as Node) == true
val errormessageAboutArgs = if(isPartOfPipeSegments) "invalid number of arguments in piped call" else "invalid number of arguments"
when (val target = call.target.targetStatement(program)) { when (val target = call.target.targetStatement(program)) {
is Subroutine -> { is Subroutine -> {
// if the call is part of a Pipe, the number of arguments in the call should be 1 less than the number of parameters // if the call is part of a Pipe, the number of arguments in the call should be 1 less than the number of parameters
val expectedNumberOfArgs = if(call.parent is IPipe) val expectedNumberOfArgs: Int = if(isPartOfPipeSegments) {
target.parameters.size-1 target.parameters.size - 1
else } else {
target.parameters.size target.parameters.size
}
if(call.args.size != expectedNumberOfArgs) { if(call.args.size != expectedNumberOfArgs) {
val pos = (if(call.args.any()) call.args[0] else (call as Node)).position val pos = (if(call.args.any()) call.args[0] else (call as Node)).position
errors.err("invalid number of arguments", pos) errors.err(errormessageAboutArgs, pos)
} }
} }
is BuiltinFunctionPlaceholder -> { is BuiltinFunctionPlaceholder -> {
val func = BuiltinFunctions.getValue(target.name) val func = BuiltinFunctions.getValue(target.name)
// if the call is part of a Pipe, the number of arguments in the call should be 1 less than the number of parameters // if the call is part of a Pipe, the number of arguments in the call should be 1 less than the number of parameters
val expectedNumberOfArgs = if(call.parent is IPipe) val expectedNumberOfArgs: Int = if(isPartOfPipeSegments) {
func.parameters.size-1 func.parameters.size-1
else } else {
func.parameters.size func.parameters.size
}
if(call.args.size != expectedNumberOfArgs) { if(call.args.size != expectedNumberOfArgs) {
val pos = (if(call.args.any()) call.args[0] else (call as Node)).position val pos = (if(call.args.any()) call.args[0] else (call as Node)).position
errors.err("invalid number of arguments", pos) errors.err(errormessageAboutArgs, pos)
} }
if(func.name=="memory") { if(func.name=="memory") {
val name = call.args[0] as? StringLiteral val name = call.args[0] as? StringLiteral

View File

@ -56,14 +56,15 @@ internal class VerifyFunctionArgTypes(val program: Program, val errors: IErrorRe
val argtypes = argITypes.map { it.getOr(DataType.UNDEFINED) } val argtypes = argITypes.map { it.getOr(DataType.UNDEFINED) }
val target = call.target.targetStatement(program) val target = call.target.targetStatement(program)
val isPartOfPipeSegments = (call.parent as? IPipe)?.segments?.contains(call as Node) == true val isPartOfPipeSegments = (call.parent as? IPipe)?.segments?.contains(call as Node) == true
val errormessageAboutArgs = if(isPartOfPipeSegments) "invalid number of arguments in piped call" else "invalid number of arguments"
if (target is Subroutine) { if (target is Subroutine) {
val consideredParamTypes = if(isPartOfPipeSegments) { val consideredParamTypes: List<DataType> = if(isPartOfPipeSegments) {
target.parameters.drop(1).map { it.type } // skip first one (the implicit first arg), this is checked elsewhere target.parameters.drop(1).map { it.type } // skip first one (the implicit first arg), this is checked elsewhere
} else { } else {
target.parameters.map { it.type } target.parameters.map { it.type }
} }
if(argtypes.size != consideredParamTypes.size) if(argtypes.size != consideredParamTypes.size)
return Pair("invalid number of arguments", call.position) return Pair(errormessageAboutArgs, call.position)
val mismatch = argtypes.zip(consideredParamTypes).indexOfFirst { !argTypeCompatible(it.first, it.second) } val mismatch = argtypes.zip(consideredParamTypes).indexOfFirst { !argTypeCompatible(it.first, it.second) }
if(mismatch>=0) { if(mismatch>=0) {
val actual = argtypes[mismatch].toString() val actual = argtypes[mismatch].toString()
@ -90,13 +91,13 @@ internal class VerifyFunctionArgTypes(val program: Program, val errors: IErrorRe
} }
else if (target is BuiltinFunctionPlaceholder) { else if (target is BuiltinFunctionPlaceholder) {
val func = BuiltinFunctions.getValue(target.name) val func = BuiltinFunctions.getValue(target.name)
val consideredParamTypes = if(isPartOfPipeSegments) { val consideredParamTypes: List<Array<DataType>> = if(isPartOfPipeSegments) {
func.parameters.drop(1).map { it.possibleDatatypes } // skip first one (the implicit first arg), this is checked elsewhere func.parameters.drop(1).map { it.possibleDatatypes } // skip first one (the implicit first arg), this is checked elsewhere
} else { } else {
func.parameters.map { it.possibleDatatypes } func.parameters.map { it.possibleDatatypes }
} }
if(argtypes.size != consideredParamTypes.size) if(argtypes.size != consideredParamTypes.size)
return Pair("invalid number of arguments", call.position) return Pair(errormessageAboutArgs, call.position)
argtypes.zip(consideredParamTypes).forEachIndexed { index, pair -> argtypes.zip(consideredParamTypes).forEachIndexed { index, pair ->
val anyCompatible = pair.second.any { argTypeCompatible(pair.first, it) } val anyCompatible = pair.second.any { argTypeCompatible(pair.first, it) }
if (!anyCompatible) { if (!anyCompatible) {

View File

@ -36,8 +36,7 @@ class TestPipes: FunSpec({
sub func2(uword arg) -> uword { sub func2(uword arg) -> uword {
return arg+2222 return arg+2222
} }
} }"""
"""
val src = SourceCode.Text(text) val src = SourceCode.Text(text)
val module = parseModule(src) val module = parseModule(src)
val errors = ErrorReporterForTests() val errors = ErrorReporterForTests()
@ -86,8 +85,7 @@ class TestPipes: FunSpec({
sub func3(uword arg) { sub func3(uword arg) {
; nothing ; nothing
} }
} }"""
"""
val src = SourceCode.Text(text) val src = SourceCode.Text(text)
val module = parseModule(src) val module = parseModule(src)
val errors = ErrorReporterForTests() val errors = ErrorReporterForTests()
@ -130,7 +128,7 @@ class TestPipes: FunSpec({
1.234 |> addfloat() 1.234 |> addfloat()
|> floats.print_f() |> floats.print_f()
9999 |> addword() startvalue(99) |> addword()
|> txt.print_uw() |> txt.print_uw()
9999 |> abs() |> txt.print_uw() 9999 |> abs() |> txt.print_uw()
@ -139,14 +137,16 @@ class TestPipes: FunSpec({
99 |> txt.print_ub() 99 |> txt.print_ub()
} }
sub startvalue(ubyte arg) -> uword {
return arg+9999
}
sub addfloat(float fl) -> float { sub addfloat(float fl) -> float {
return fl+2.22 return fl+2.22
} }
sub addword(uword ww) -> uword { sub addword(uword ww) -> uword {
return ww+2222 return ww+2222
} }
} }"""
"""
val result = compileText(C64Target(), optimize = false, text, writeAssembly = true).assertSuccess() val result = compileText(C64Target(), optimize = false, text, writeAssembly = true).assertSuccess()
val stmts = result.program.entrypoint.statements val stmts = result.program.entrypoint.statements
stmts.size shouldBe 7 stmts.size shouldBe 7
@ -183,7 +183,7 @@ class TestPipes: FunSpec({
1.234 |> addfloat() 1.234 |> addfloat()
|> floats.print_f() |> floats.print_f()
9999 |> addword() startvalue(99) |> addword()
|> txt.print_uw() |> txt.print_uw()
; these should be optimized into just the function calls: ; these should be optimized into just the function calls:
@ -193,14 +193,16 @@ class TestPipes: FunSpec({
99 |> txt.print_ub() 99 |> txt.print_ub()
} }
sub startvalue(ubyte arg) -> uword {
return arg+9999
}
sub addfloat(float fl) -> float { sub addfloat(float fl) -> float {
return fl+2.22 return fl+2.22
} }
sub addword(uword ww) -> uword { sub addword(uword ww) -> uword {
return ww+2222 return ww+2222
} }
} }"""
"""
val result = compileText(C64Target(), optimize = true, text, writeAssembly = true).assertSuccess() val result = compileText(C64Target(), optimize = true, text, writeAssembly = true).assertSuccess()
val stmts = result.program.entrypoint.statements val stmts = result.program.entrypoint.statements
stmts.size shouldBe 7 stmts.size shouldBe 7
@ -213,9 +215,9 @@ class TestPipes: FunSpec({
val pipew = stmts[1] as Pipe val pipew = stmts[1] as Pipe
pipef.source shouldBe instanceOf<FunctionCallExpression>() pipef.source shouldBe instanceOf<FunctionCallExpression>()
(pipew.source as IFunctionCall).target.nameInSource shouldBe listOf("addword") (pipew.source as IFunctionCall).target.nameInSource shouldBe listOf("startvalue")
pipew.segments.size shouldBe 1 pipew.segments.size shouldBe 2
val callw = pipew.segments[0] as IFunctionCall val callw = pipew.segments[1] as IFunctionCall
callw.target.nameInSource shouldBe listOf("txt", "print_uw") callw.target.nameInSource shouldBe listOf("txt", "print_uw")
var stmt = stmts[2] as FunctionCallStatement var stmt = stmts[2] as FunctionCallStatement
@ -245,8 +247,7 @@ class TestPipes: FunSpec({
sub addword(uword ww) -> uword { sub addword(uword ww) -> uword {
return ww+2222 return ww+2222
} }
} }"""
"""
val errors = ErrorReporterForTests() val errors = ErrorReporterForTests()
compileText(C64Target(), false, text, errors=errors).assertFailure() compileText(C64Target(), false, text, errors=errors).assertFailure()
errors.errors.size shouldBe 1 errors.errors.size shouldBe 1
@ -263,21 +264,23 @@ class TestPipes: FunSpec({
float @shared fl = 1.234 |> addfloat() float @shared fl = 1.234 |> addfloat()
|> addfloat() |> addfloat()
uword @shared ww = 9999 |> addword() uword @shared ww = startvalue(99) |> addword()
|> addword() |> addword()
ubyte @shared cc = 30 |> sin8u() |> cos8u() ubyte @shared cc = 30 |> sin8u() |> cos8u()
cc = cc |> sin8u() |> cos8u() cc = cc |> sin8u() |> cos8u()
} }
sub startvalue(ubyte arg) -> uword {
return arg+9999
}
sub addfloat(float fl) -> float { sub addfloat(float fl) -> float {
return fl+2.22 return fl+2.22
} }
sub addword(uword ww) -> uword { sub addword(uword ww) -> uword {
return ww+2222 return ww+2222
} }
} }"""
"""
val result = compileText(C64Target(), optimize = false, text, writeAssembly = true).assertSuccess() val result = compileText(C64Target(), optimize = false, text, writeAssembly = true).assertSuccess()
val stmts = result.program.entrypoint.statements val stmts = result.program.entrypoint.statements
stmts.size shouldBe 8 stmts.size shouldBe 8
@ -293,7 +296,7 @@ class TestPipes: FunSpec({
val assignw = stmts[3] as Assignment val assignw = stmts[3] as Assignment
val pipew = assignw.value as PipeExpression val pipew = assignw.value as PipeExpression
pipew.source shouldBe instanceOf<NumericLiteral>() pipew.source shouldBe instanceOf<IFunctionCall>()
pipew.segments.size shouldBe 2 pipew.segments.size shouldBe 2
call = pipew.segments[0] as IFunctionCall call = pipew.segments[0] as IFunctionCall
call.target.nameInSource shouldBe listOf("addword") call.target.nameInSource shouldBe listOf("addword")
@ -327,21 +330,24 @@ class TestPipes: FunSpec({
float @shared fl = 1.234 |> addfloat() float @shared fl = 1.234 |> addfloat()
|> addfloat() |> addfloat()
uword @shared ww = 9999 |> addword() uword @shared ww = startvalue(99) |> addword()
|> addword() |> addword()
ubyte @shared cc = 30 |> sin8u() |> cos8u() ; will be optimized away into a const number ubyte @shared cc = 30 |> sin8u() |> cos8u() ; will be optimized away into a const number
cc = cc |> sin8u() |> cos8u() cc = cc |> sin8u() |> cos8u()
} }
sub startvalue(ubyte arg) -> uword {
return arg+9999
}
sub addfloat(float fl) -> float { sub addfloat(float fl) -> float {
return fl+2.22 return fl+2.22
} }
sub addword(uword ww) -> uword { sub addword(uword ww) -> uword {
return ww+2222 return ww+2222
} }
} }
""" """
val result = compileText(C64Target(), optimize = true, text, writeAssembly = true).assertSuccess() val result = compileText(C64Target(), optimize = true, text, writeAssembly = true).assertSuccess()
val stmts = result.program.entrypoint.statements val stmts = result.program.entrypoint.statements
stmts.size shouldBe 8 stmts.size shouldBe 8
@ -354,8 +360,9 @@ class TestPipes: FunSpec({
val assignw = stmts[3] as Assignment val assignw = stmts[3] as Assignment
val pipew = assignw.value as PipeExpression val pipew = assignw.value as PipeExpression
pipew.source shouldBe instanceOf<FunctionCallExpression>() pipew.source shouldBe instanceOf<FunctionCallExpression>()
pipew.segments.size shouldBe 1 pipew.segments.size shouldBe 2
pipew.segments[0] shouldBe instanceOf<FunctionCallExpression>() pipew.segments[0] shouldBe instanceOf<FunctionCallExpression>()
pipew.segments[1] shouldBe instanceOf<FunctionCallExpression>()
var assigncc = stmts[5] as Assignment var assigncc = stmts[5] as Assignment
val value = assigncc.value as NumericLiteral val value = assigncc.value as NumericLiteral
@ -440,4 +447,29 @@ class TestPipes: FunSpec({
errors.errors[0] shouldContain "UWORD incompatible" errors.errors[0] shouldContain "UWORD incompatible"
errors.errors[1] shouldContain "UWORD incompatible" errors.errors[1] shouldContain "UWORD incompatible"
} }
test("pipe detects invalid number of args") {
val text = """
main {
sub start() {
uword ww = startvalue() |> addword()
|> addword()
ubyte cc = 30 |> sin8u(99) |> cos8u(22)
}
sub startvalue(ubyte arg) -> uword {
return arg+9999
}
sub addword(uword ww) -> uword {
return ww+2222
}
}"""
val errors = ErrorReporterForTests()
compileText(C64Target(), optimize = false, text, writeAssembly = false, errors=errors).assertFailure()
errors.errors.size shouldBe 3
errors.errors[0] shouldContain ":4:32: invalid number of arguments"
errors.errors[1] shouldContain ":7:44: invalid number of arguments"
errors.errors[2] shouldContain ":7:57: invalid number of arguments"
}
}) })

View File

@ -15,12 +15,12 @@ main {
return xx+33 return xx+33
} }
sub determine_score() -> ubyte { sub determine_score(ubyte zz) -> ubyte {
return 33 return zz+22
} }
sub add_bonus(ubyte qq) { sub add_bonus(ubyte qq) -> ubyte {
qq++ return qq+1
} }