diff --git a/compiler/src/prog8/compiler/astprocessing/AstChecker.kt b/compiler/src/prog8/compiler/astprocessing/AstChecker.kt index 05736a482..910e1b2d9 100644 --- a/compiler/src/prog8/compiler/astprocessing/AstChecker.kt +++ b/compiler/src/prog8/compiler/astprocessing/AstChecker.kt @@ -1208,12 +1208,26 @@ internal class AstChecker(private val program: Program, } override fun visit(containment: ContainmentCheck) { - if(!containment.iterable.inferType(program).isIterable) - errors.err("value set for containment check must be an iterable type", containment.iterable.position) + val elementDt = containment.element.inferType(program) + val iterableDt = containment.iterable.inferType(program) + if(containment.parent is BinaryExpression) errors.err("containment check is currently not supported in complex expressions", containment.position) - // TODO check that iterable contains the same types as the element that is searched + if(iterableDt.isIterable) { + val iterableEltDt = ArrayToElementTypes.getValue(iterableDt.getOr(DataType.UNDEFINED)) + val invalidDt = if (elementDt.isBytes) { + iterableEltDt !in ByteDatatypes + } else if (elementDt.isWords) { + iterableEltDt !in WordDatatypes + } else { + false + } + if (invalidDt) + errors.err("element datatype doesn't match iterable datatype", containment.position) + } else { + errors.err("value set for containment check must be an iterable type", containment.iterable.position) + } super.visit(containment) } diff --git a/compiler/src/prog8/compiler/astprocessing/CodeDesugarer.kt b/compiler/src/prog8/compiler/astprocessing/CodeDesugarer.kt index 24ee31655..6ec61c430 100644 --- a/compiler/src/prog8/compiler/astprocessing/CodeDesugarer.kt +++ b/compiler/src/prog8/compiler/astprocessing/CodeDesugarer.kt @@ -134,11 +134,4 @@ _after: } return noModifications } - - override fun after(expr: BinaryExpression, parent: Node): Iterable { - if(expr.operator=="in") { - println("IN-TEST: $expr\n in: $parent") - } - return noModifications - } } diff --git a/compiler/test/ast/TestProg8Parser.kt b/compiler/test/ast/TestProg8Parser.kt index 0a32c757a..8bfe26dfc 100644 --- a/compiler/test/ast/TestProg8Parser.kt +++ b/compiler/test/ast/TestProg8Parser.kt @@ -800,4 +800,50 @@ class TestProg8Parser: FunSpec( { val encodedletter = Petscii.encodePetscii("A", true).getOrElse { fail("petscii error") }.single() letter.value shouldBe NumericLiteralValue(DataType.UBYTE, encodedletter.toDouble(), Position.DUMMY) } + + test("`in` containment checks") { + val text=""" + main { + sub start() { + str string = "hello" + ubyte[] array = [1,2,3,4] + + ubyte cc + if cc in [' ', '@', 0] { + } + + if cc in "email" { + } + + cc = 99 in array + cc = '@' in string + } + } + """ + val result = compileText(C64Target, false, text, writeAssembly = false).assertSuccess() + val start = result.program.entrypoint + val containmentChecks = start.statements.takeLast(4) + (containmentChecks[0] as IfElse).condition shouldBe instanceOf() + (containmentChecks[1] as IfElse).condition shouldBe instanceOf() + (containmentChecks[2] as Assignment).value shouldBe instanceOf() + (containmentChecks[3] as Assignment).value shouldBe instanceOf() + } + + test("invalid `in` containment checks") { + val text=""" + main { + sub start() { + ubyte cc + ubyte[] array = [1,2,3] + cc = 99 in 12345 + cc = 9999 in array + } + } + """ + val errors = ErrorReporterForTests() + compileText(C64Target, false, text, writeAssembly = false, errors = errors).assertFailure() + errors.errors.size shouldBe 2 + errors.errors[0] shouldContain "must be an iterable type" + errors.errors[1] shouldContain "datatype doesn't match" + } }) diff --git a/docs/source/todo.rst b/docs/source/todo.rst index 97feb33f9..f8e7d0094 100644 --- a/docs/source/todo.rst +++ b/docs/source/todo.rst @@ -3,8 +3,7 @@ TODO For next compiler release (7.7) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -- add unit tests for correctly parsing the "in" operator - +... Need help with ^^^^^^^^^^^^^^