allow containment check in a range expression ("run time" range expression)

This commit is contained in:
Irmen de Jong 2024-01-03 01:17:09 +01:00
parent 517ea82b99
commit 6aed7e429a
9 changed files with 149 additions and 45 deletions

View File

@ -146,6 +146,15 @@ fun compileProgram(args: CompilerArguments): CompilationResult? {
return null return null
} }
ast = intermediateAst ast = intermediateAst
} else {
if(args.printAst1) {
println("\n*********** COMPILER AST *************")
printProgram(program)
println("*********** COMPILER AST END *************\n")
}
if(args.printAst2) {
System.err.println("There is no intermediate Ast available if assembly generation is disabled.")
}
} }
} }

View File

@ -1447,7 +1447,8 @@ internal class AstChecker(private val program: Program,
} }
} }
if(iterableDt.isIterable && containment.iterable !is RangeExpression) { if (iterableDt.isIterable) {
if (containment.iterable !is RangeExpression) {
val iterableEltDt = ArrayToElementTypes.getValue(iterableDt.getOr(DataType.UNDEFINED)) val iterableEltDt = ArrayToElementTypes.getValue(iterableDt.getOr(DataType.UNDEFINED))
val invalidDt = if (elementDt.isBytes) { val invalidDt = if (elementDt.isBytes) {
iterableEltDt !in ByteDatatypes iterableEltDt !in ByteDatatypes
@ -1458,8 +1459,9 @@ internal class AstChecker(private val program: Program,
} }
if (invalidDt) if (invalidDt)
errors.err("element datatype doesn't match iterable datatype", containment.position) errors.err("element datatype doesn't match iterable datatype", containment.position)
}
} else { } else {
errors.err("value set for containment check must be a string or array", containment.iterable.position) errors.err("iterable must be an array, a string, or a range expression", containment.iterable.position)
} }
super.visit(containment) super.visit(containment)

View File

@ -5,8 +5,6 @@ import prog8.ast.Node
import prog8.ast.Program import prog8.ast.Program
import prog8.ast.base.FatalAstException import prog8.ast.base.FatalAstException
import prog8.ast.expressions.BinaryExpression import prog8.ast.expressions.BinaryExpression
import prog8.ast.expressions.ContainmentCheck
import prog8.ast.expressions.IdentifierReference
import prog8.ast.expressions.NumericLiteral import prog8.ast.expressions.NumericLiteral
import prog8.ast.statements.* import prog8.ast.statements.*
import prog8.ast.walk.AstWalker import prog8.ast.walk.AstWalker
@ -28,12 +26,6 @@ internal class BeforeAsmAstChanger(val program: Program, private val options: Co
throw InternalCompilerException("do..until should have been converted to jumps") throw InternalCompilerException("do..until should have been converted to jumps")
} }
override fun after(containment: ContainmentCheck, parent: Node): Iterable<IAstModification> {
if(containment.iterable !is IdentifierReference)
throw InternalCompilerException("iterable in containmentcheck should be identifier (referencing string or array)")
return noModifications
}
override fun after(decl: VarDecl, parent: Node): Iterable<IAstModification> { override fun after(decl: VarDecl, parent: Node): Iterable<IAstModification> {
if (decl.type == VarDeclType.VAR && decl.value != null && decl.datatype in NumericDatatypes) if (decl.type == VarDeclType.VAR && decl.value != null && decl.datatype in NumericDatatypes)
throw InternalCompilerException("vardecls for variables, with initial numerical value, should have been rewritten as plain vardecl + assignment $decl") throw InternalCompilerException("vardecls for variables, with initial numerical value, should have been rewritten as plain vardecl + assignment $decl")

View File

@ -583,15 +583,73 @@ class IntermediateAstMaker(private val program: Program, private val errors: IEr
return call return call
} }
private fun transform(srcCheck: ContainmentCheck): PtContainmentCheck { private fun transform(srcCheck: ContainmentCheck): PtExpression {
fun desugar(range: RangeExpression): PtExpression {
val expr = PtBinaryExpression("and", DataType.UBYTE, srcCheck.position)
val x1 = transformExpression(srcCheck.element)
val x2 = transformExpression(srcCheck.element)
val eltDt = srcCheck.element.inferType(program)
if(eltDt.isInteger) {
val low = PtBinaryExpression("<=", DataType.UBYTE, srcCheck.position)
low.add(transformExpression(range.from))
low.add(x1)
expr.add(low)
val high = PtBinaryExpression("<=", DataType.UBYTE, srcCheck.position)
high.add(x2)
high.add(transformExpression(range.to))
expr.add(high)
} else {
val low = PtBinaryExpression("<=", DataType.UBYTE, srcCheck.position)
val lowFloat = PtTypeCast(DataType.FLOAT, range.from.position)
lowFloat.add(transformExpression(range.from))
low.add(lowFloat)
low.add(x1)
expr.add(low)
val high = PtBinaryExpression("<=", DataType.UBYTE, srcCheck.position)
high.add(x2)
val highFLoat = PtTypeCast(DataType.FLOAT, range.to.position)
highFLoat.add(transformExpression(range.to))
high.add(highFLoat)
expr.add(high)
}
return expr
}
when(srcCheck.iterable) {
is IdentifierReference -> {
val check = PtContainmentCheck(srcCheck.position) val check = PtContainmentCheck(srcCheck.position)
check.add(transformExpression(srcCheck.element)) check.add(transformExpression(srcCheck.element))
if(srcCheck.iterable !is IdentifierReference)
throw FatalAstException("iterable in containmentcheck must always be an identifier (referencing string or array) $srcCheck")
val iterable = transformExpression(srcCheck.iterable) val iterable = transformExpression(srcCheck.iterable)
check.add(iterable) check.add(iterable)
return check return check
} }
is RangeExpression -> {
val range = srcCheck.iterable as RangeExpression
val constRange = range.toConstantIntegerRange()
val constElt = srcCheck.element.constValue(program)?.number
val step = range.step.constValue(program)?.number
if(constElt!=null && constRange!=null) {
return PtNumber(DataType.UBYTE, if(constRange.first<=constElt && constElt<=constRange.last) 1.0 else 0.0, srcCheck.position)
}
else if(step==1.0) {
// x in low to high --> low <=x and x <= high
return desugar(range)
} else if(step==-1.0) {
// x in high downto low -> low <=x and x <= high
val tmp = range.to
range.to = range.from
range.from = tmp
return desugar(range)
} else {
errors.err("cannot use step size different than 1 or -1 in a non constant range containment check", srcCheck.position)
return PtNumber(DataType.BYTE, 0.0, Position.DUMMY)
}
}
else -> throw FatalAstException("iterable in containmentcheck must always be an identifier (referencing string or array) or a range expression $srcCheck")
}
}
private fun transform(memory: DirectMemoryWrite): PtMemoryByte { private fun transform(memory: DirectMemoryWrite): PtMemoryByte {
val mem = PtMemoryByte(memory.position) val mem = PtMemoryByte(memory.position)

View File

@ -404,6 +404,14 @@ class TestCompilerOnRanges: FunSpec({
if ww in wvalues { if ww in wvalues {
xx++ xx++
} }
if xx in 10 to 20 {
xx++
}
if ww in 1000 to 2000 {
xx++
}
} }
}""", writeAssembly = true) shouldNotBe null }""", writeAssembly = true) shouldNotBe null
} }

View File

@ -924,7 +924,7 @@ class TestProg8Parser: FunSpec( {
val errors = ErrorReporterForTests() val errors = ErrorReporterForTests()
compileText(C64Target(), false, text, writeAssembly = false, errors = errors) shouldBe null compileText(C64Target(), false, text, writeAssembly = false, errors = errors) shouldBe null
errors.errors.size shouldBe 2 errors.errors.size shouldBe 2
errors.errors[0] shouldContain "must be a string or array" errors.errors[0] shouldContain "iterable must be"
errors.errors[1] shouldContain "datatype doesn't match" errors.errors[1] shouldContain "datatype doesn't match"
} }

View File

@ -611,10 +611,11 @@ range creation: ``to``, ``downto``
See :ref:`range-expression` for details. See :ref:`range-expression` for details.
containment check: ``in`` containment check: ``in``
Tests if a value is present in a list of values, which can be a string or an array. Tests if a value is present in a list of values, which can be a string, or an array, or a range expression.
The result is a simple boolean ``true`` or ``false``. The result is a simple boolean ``true`` or ``false``.
Consider using this instead of chaining multiple value tests with ``or``, because the Consider using this instead of chaining multiple value tests with ``or``, because the
containment check is more efficient. containment check is more efficient.
Checking N in a range from x to y, is identical to x<=N and N<=y; the actual range of values is never created.
Examples:: Examples::
ubyte cc ubyte cc
@ -622,6 +623,10 @@ containment check: ``in``
txt.print("cc is one of the values") txt.print("cc is one of the values")
} }
if cc in 10 to 20 {
txt.print("10 <= cc and cc <=20")
}
str email_address = "name@test.com" str email_address = "name@test.com"
if '@' in email_address { if '@' in email_address {
txt.print("email address seems ok") txt.print("email address seems ok")

View File

@ -76,10 +76,3 @@ Other language/syntax features to think about
- add (rom/ram)bank support to romsub. A call will then automatically switch banks, use callfar and something else when in banked ram. - add (rom/ram)bank support to romsub. A call will then automatically switch banks, use callfar and something else when in banked ram.
challenges: how to not make this too X16 specific? How does the compiler know what bank to switch (ram/rom)? challenges: how to not make this too X16 specific? How does the compiler know what bank to switch (ram/rom)?
How to make it performant when we want to (i.e. NOT have it use callfar/auto bank switching) ? How to make it performant when we want to (i.e. NOT have it use callfar/auto bank switching) ?
- chained comparisons `10<x<20` , `x==y==z` (desugars to `10<x and x<20`, `x==y and y==z`)
BUT this needs a new AST node type and rewritten parser rules, because otherwise it changes the semantics
of existing expressions such as if x<y==0 ...
- Better idea perhaps is "runtime range objects" the idea would be that "a to b" generates
a new kind of "range" value rather than an array (though you can still use it to initialize an array),
so you could replace if a <= n <= b with: if n in a to b
So this means we should keep the Range Expression alive for much longer.

View File

@ -1,17 +1,54 @@
%import textio %import textio
%import floats
%zeropage basicsafe %zeropage basicsafe
main { main {
sub start() { sub start() {
ubyte @shared n=20 ubyte [] array = 100 to 110
ubyte @shared x=10
if n < x { for cx16.r0L in array {
; nothing here, conditional gets inverted txt.print_ub(cx16.r0L)
} else { txt.spc()
cx16.r0++ }
txt.nl()
ubyte x = 14
if x in 10 to 20 {
txt.print("yep1\n")
}
if x in 20 to 30 {
txt.print("yep2\n")
}
if x in 10 to 20 step 2 {
txt.print("yep1b\n")
}
if x in 20 to 30 step 2 {
txt.print("yep2b\n")
}
if x in 20 to 10 step -2 {
txt.print("yep1c\n")
}
if x in 30 to 20 step -2 {
txt.print("yep2c\n")
}
txt.nl()
ubyte @shared y = 12
if y in 10 to 20 {
txt.print("yep1\n")
}
if y in 20 to 30 {
txt.print("yep2\n")
}
if y in 20 downto 10 {
txt.print("yep1c\n")
}
if y in 30 downto 20 {
txt.print("yep2c\n")
} }
cx16.r0L = n<x == 0
cx16.r1L = not n<x
} }
} }