simplified containment check, only possible on string and arrays (as per the docs)

This commit is contained in:
Irmen de Jong 2022-03-27 15:23:32 +02:00
parent e41d6787bb
commit 3b6e7eccdd
17 changed files with 176 additions and 211 deletions

View File

@ -26,7 +26,7 @@ class PtArrayIndexer(type: DataType, position: Position): PtExpression(type, pos
}
class PtArrayLiteral(type: DataType, position: Position): PtExpression(type, position)
class PtArray(type: DataType, position: Position): PtExpression(type, position)
class PtBuiltinFunctionCall(val name: String, val void: Boolean, type: DataType, position: Position) : PtExpression(type, position) {
@ -56,10 +56,10 @@ class PtBinaryExpression(val operator: String, type: DataType, position: Positio
class PtContainmentCheck(position: Position): PtExpression(DataType.UBYTE, position) {
val element: PtExpression
val element: PtExpression // either a PtIdentifier or PtNumber
get() = children[0] as PtExpression
val iterable: PtIdentifier
get() = children[0] as PtIdentifier
get() = children[1] as PtIdentifier
}

View File

@ -333,15 +333,9 @@ internal class AssignmentAsmGen(private val program: Program,
private fun containmentCheckIntoA(containment: ContainmentCheck) {
val elementDt = containment.element.inferType(program)
val range = containment.iterable as? RangeExpression
if(range!=null) {
val constRange = range.toConstantIntegerRange()
if(constRange!=null)
return containmentCheckIntoA(containment.element, elementDt.getOr(DataType.UNDEFINED), constRange.toList())
throw AssemblyError("non const range containment check not supported")
}
val variable = (containment.iterable as? IdentifierReference)?.targetVarDecl(program)
if(variable!=null) {
?: throw AssemblyError("invalid containment iterable type")
if(elementDt istype DataType.FLOAT)
throw AssemblyError("containment check of floats not supported")
if(variable.origin!=VarDeclOrigin.USERCODE) {
@ -349,19 +343,12 @@ internal class AssignmentAsmGen(private val program: Program,
DataType.STR -> {
require(elementDt.isBytes)
val stringVal = variable.value as StringLiteral
if(stringVal.value.length > ContainmentCheck.max_inlined_string_length) {
// use subroutine
val varname = asmgen.asmVariableName(containment.iterable as IdentifierReference)
assignAddressOf(AsmAssignTarget(TargetStorageKind.VARIABLE, program, asmgen, DataType.UWORD, containment.definingSubroutine, "P8ZP_SCRATCH_W1"), varname)
assignExpressionToRegister(containment.element, RegisterOrPair.A, elementDt istype DataType.BYTE)
asmgen.out(" ldy #${stringVal.value.length}")
asmgen.out(" jsr prog8_lib.containment_bytearray")
return
} else {
// inline cmp table
val encoded = program.encoding.encodeString(stringVal.value, stringVal.encoding)
return containmentCheckIntoA(containment.element, elementDt.getOr(DataType.UNDEFINED), encoded.map { it.toInt() })
}
}
DataType.ARRAY_F -> {
// require(elementDt istype DataType.FLOAT)
@ -371,8 +358,6 @@ internal class AssignmentAsmGen(private val program: Program,
require(elementDt.isInteger)
val arrayVal = variable.value as ArrayLiteral
val dt = elementDt.getOr(DataType.UNDEFINED)
if(arrayVal.value.size > ContainmentCheck.max_inlined_string_length) {
// use subroutine
val varname = asmgen.asmVariableName(containment.iterable as IdentifierReference)
when(dt) {
in ByteDatatypes -> {
@ -390,11 +375,6 @@ internal class AssignmentAsmGen(private val program: Program,
else -> throw AssemblyError("invalid dt")
}
return
} else {
// inline cmp table
val values = arrayVal.value.map { it.constValue(program)!!.number.toInt() }
return containmentCheckIntoA(containment.element, dt, values)
}
}
else -> throw AssemblyError("invalid dt")
}
@ -430,61 +410,6 @@ internal class AssignmentAsmGen(private val program: Program,
else -> throw AssemblyError("invalid dt")
}
}
val stringVal = containment.iterable as? StringLiteral
if(stringVal!=null) {
require(elementDt.isBytes)
if(stringVal.value.length > ContainmentCheck.max_inlined_string_length)
throw AssemblyError("string should have been inlined in if it was this long")
val encoded = program.encoding.encodeString(stringVal.value, stringVal.encoding)
return containmentCheckIntoA(containment.element, elementDt.getOr(DataType.UNDEFINED), encoded.map { it.toInt() })
}
val arrayVal = containment.iterable as? ArrayLiteral
if(arrayVal!=null) {
require(elementDt.isInteger)
if(arrayVal.value.size > ContainmentCheck.max_inlined_string_length)
throw AssemblyError("array should have been inlined in if it was this long")
val values = arrayVal.value.map { it.constValue(program)!!.number.toInt() }
return containmentCheckIntoA(containment.element, elementDt.getOr(DataType.UNDEFINED), values)
}
throw AssemblyError("invalid containment iterable type")
}
private fun containmentCheckIntoA(element: Expression, dt: DataType, values: List<Number>) {
if(values.size<2)
throw AssemblyError("containment check against 0 or 1 values should have been optimized away")
val containsLabel = asmgen.makeLabel("contains")
when(dt) {
in ByteDatatypes -> {
asmgen.assignExpressionToRegister(element, RegisterOrPair.A, dt==DataType.BYTE)
for (value in values) {
asmgen.out(" cmp #$value | beq +")
}
asmgen.out("""
lda #0
beq ++
+ lda #1
+""")
}
in WordDatatypes -> {
asmgen.assignExpressionToRegister(element, RegisterOrPair.AY, dt==DataType.WORD)
for (value in values) {
asmgen.out("""
cmp #<$value
bne +
cpy #>$value
beq $containsLabel
+""")
}
asmgen.out("""
lda #0
beq +
$containsLabel lda #1
+""")
}
else -> throw AssemblyError("invalid dt")
}
}
private fun assignStatusFlagByte(target: AsmAssignTarget, statusflag: Statusflag) {
when(statusflag) {

View File

@ -157,7 +157,7 @@ class AstToXmlConverter(internal val program: PtProgram,
is PtAsmSub -> write(it)
is PtAddressOf -> write(it)
is PtArrayIndexer -> write(it)
is PtArrayLiteral -> write(it)
is PtArray -> write(it)
is PtBinaryExpression -> write(it)
is PtBuiltinFunctionCall -> write(it)
is PtConditionalBranch -> write(it)
@ -212,7 +212,7 @@ class AstToXmlConverter(internal val program: PtProgram,
xml.endElt()
}
private fun write(array: PtArrayLiteral) {
private fun write(array: PtArray) {
xml.elt("array")
xml.attr("type", array.type.name)
xml.startChildren()
@ -271,11 +271,11 @@ class AstToXmlConverter(internal val program: PtProgram,
xml.startChildren()
xml.elt("element")
xml.startChildren()
writeNode(check.element)
writeNode(check.children[0])
xml.endElt()
xml.elt("iterable")
xml.startChildren()
writeNode(check.iterable)
writeNode(check.children[1])
xml.endElt()
xml.endElt()
}

View File

@ -57,7 +57,7 @@ class CodeGen(internal val program: PtProgram,
is PtReturn -> translate(node)
is PtJump -> translate(node)
is PtWhen -> TODO()
is PtPipe -> TODO()
is PtPipe -> expressionEval.translate(node, regUsage.nextFree(), regUsage)
is PtForLoop -> TODO()
is PtIfElse -> translate(node, regUsage)
is PtPostIncrDecr -> translate(node, regUsage)
@ -78,7 +78,7 @@ class CodeGen(internal val program: PtProgram,
is PtTypeCast,
is PtSubroutineParameter,
is PtNumber,
is PtArrayLiteral,
is PtArray,
is PtString -> throw AssemblyError("strings should not occur as separate statement node ${node.position}")
is PtAsmSub -> throw AssemblyError("asmsub not supported on virtual machine target ${node.position}")
is PtInlineAssembly -> throw AssemblyError("inline assembly not supported on virtual machine target ${node.position}")

View File

@ -3,10 +3,7 @@ package prog8.codegen.virtual
import prog8.code.StStaticVariable
import prog8.code.StSub
import prog8.code.ast.*
import prog8.code.core.AssemblyError
import prog8.code.core.DataType
import prog8.code.core.PassByValueDatatypes
import prog8.code.core.SignedDatatypes
import prog8.code.core.*
import prog8.vm.Instruction
import prog8.vm.Opcode
import prog8.vm.VmDataType
@ -57,28 +54,48 @@ internal class ExpressionGen(val codeGen: CodeGen) {
is PtBuiltinFunctionCall -> code += translate(expr, resultRegister, regUsage)
is PtFunctionCall -> code += translate(expr, resultRegister, regUsage)
is PtContainmentCheck -> code += translate(expr, resultRegister, regUsage)
is PtPipe -> TODO()
is PtPipe -> code += translate(expr, resultRegister, regUsage)
is PtRange,
is PtArrayLiteral,
is PtArray,
is PtString -> throw AssemblyError("range/arrayliteral/string should no longer occur as expression")
else -> throw AssemblyError("weird expression")
}
return code
}
private fun translate(check: PtContainmentCheck, resultRegister: Int, regUsage: RegisterUsage): VmCodeChunk {
val iterableIdent = check.iterable
val iterable = codeGen.symbolTable.flat.getValue(iterableIdent.targetName) as StStaticVariable
when(iterable.dt) {
DataType.STR -> println("CONTAINMENT CHECK ${check.element} in string $iterable ${iterable.initialStringValue}")
DataType.ARRAY_UB -> println("CONTAINMENT CHECK ${check.element} in UB-array $iterable ${iterable.initialArrayValue}")
DataType.ARRAY_B -> println("CONTAINMENT CHECK ${check.element} in B-array $iterable ${iterable.initialArrayValue}")
DataType.ARRAY_UW -> println("CONTAINMENT CHECK ${check.element} in UW-array $iterable ${iterable.initialArrayValue}")
DataType.ARRAY_W -> println("CONTAINMENT CHECK ${check.element} in W-array $iterable ${iterable.initialArrayValue}")
DataType.ARRAY_F -> TODO("containment check in float-array")
else -> throw AssemblyError("weird iterable dt ${iterable.dt} for ${iterableIdent.targetName}")
internal fun translate(pipe: PtPipe, resultRegister: Int, regUsage: RegisterUsage): VmCodeChunk {
TODO("Not yet implemented: pipe expression")
}
return VmCodeChunk()
private fun translate(check: PtContainmentCheck, resultRegister: Int, regUsage: RegisterUsage): VmCodeChunk {
val code = VmCodeChunk()
code += translateExpression(check.element, resultRegister, regUsage) // load the element to check in resultRegister
val iterable = codeGen.symbolTable.flat.getValue(check.iterable.targetName) as StStaticVariable
when(iterable.dt) {
DataType.STR -> {
val call = PtFunctionCall(listOf("prog8_lib", "string_contains"), false, DataType.UBYTE, check.position)
call.children.add(check.element)
call.children.add(check.iterable)
code += translate(call, resultRegister, regUsage)
}
DataType.ARRAY_UB, DataType.ARRAY_B -> {
val call = PtFunctionCall(listOf("prog8_lib", "bytearray_contains"), false, DataType.UBYTE, check.position)
call.children.add(check.element)
call.children.add(check.iterable)
call.children.add(PtNumber(DataType.UBYTE, iterable.arraysize!!.toDouble(), iterable.position))
code += translate(call, resultRegister, regUsage)
}
DataType.ARRAY_UW, DataType.ARRAY_W -> {
val call = PtFunctionCall(listOf("prog8_lib", "wordarray_contains"), false, DataType.UBYTE, check.position)
call.children.add(check.element)
call.children.add(check.iterable)
call.children.add(PtNumber(DataType.UBYTE, iterable.arraysize!!.toDouble(), iterable.position))
code += translate(call, resultRegister, regUsage)
}
DataType.ARRAY_F -> TODO("containment check in float-array")
else -> throw AssemblyError("weird iterable dt ${iterable.dt} for ${check.iterable.targetName}")
}
return code
}
private fun translate(arrayIx: PtArrayIndexer, resultRegister: Int, regUsage: RegisterUsage): VmCodeChunk {
@ -287,7 +304,7 @@ internal class ExpressionGen(val codeGen: CodeGen) {
}
code += VmCodeInstruction(Instruction(Opcode.CALL), labelArg=fcall.functionName)
if(!fcall.void && resultRegister!=0) {
// Call convention: result value is in r0, so put it in the required register instead.
// Call convention: result value is in r0, so put it in the required register instead. TODO does this work correctly?
code += VmCodeInstruction(Instruction(Opcode.LOADR, codeGen.vmType(fcall.type), reg1=resultRegister, reg2=0))
}
return code
@ -336,6 +353,13 @@ internal class ExpressionGen(val codeGen: CodeGen) {
code += VmCodeInstruction(Instruction(Opcode.POP, VmDataType.WORD, reg1 = 1))
code += VmCodeInstruction(Instruction(Opcode.POP, VmDataType.BYTE, reg1 = 0))
}
"msb" -> {
code += translateExpression(call.args.single(), resultRegister, regUsage)
code += VmCodeInstruction(Instruction(Opcode.SWAP, VmDataType.BYTE, reg1 = resultRegister, reg2=resultRegister))
}
"lsb" -> {
code += translateExpression(call.args.single(), resultRegister, regUsage)
}
else -> {
// TODO builtin functions...
TODO("builtinfunc ${call.name}")

View File

@ -2,8 +2,39 @@
;
; Written by Irmen de Jong (irmen@razorvine.net) - license: GNU GPL 3.0
%import textio
prog8_lib {
%option force_output
; nothing here for now
sub string_contains(ubyte needle, str haystack) -> ubyte {
txt.print(">>>string elt check: ")
txt.print_ub(needle)
txt.spc()
txt.print_uwhex(haystack, true)
txt.nl()
return 0
}
sub bytearray_contains(ubyte needle, uword haystack_ptr, ubyte num_elements) -> ubyte {
txt.print(">>>bytearray elt check: ")
txt.print_ub(needle)
txt.spc()
txt.print_uwhex(haystack_ptr, true)
txt.spc()
txt.print_ub(num_elements)
txt.nl()
return 0
}
sub wordarray_contains(ubyte needle, uword haystack_ptr, ubyte num_elements) -> ubyte {
txt.print(">>>wordarray elt check: ")
txt.print_ub(needle)
txt.spc()
txt.print_uwhex(haystack_ptr, true)
txt.spc()
txt.print_ub(num_elements)
txt.nl()
return 0
}
}

View File

@ -64,9 +64,15 @@ sub print_b (byte value) {
; TODO use conv module?
}
str hex_digits = "0123456789abcdef"
sub print_ubhex (ubyte value, ubyte prefix) {
; ---- print the ubyte in hex form
; TODO use conv module?
if prefix
chrout('$')
chrout(hex_digits[value>>4])
chrout(hex_digits[value&15])
}
sub print_ubbin (ubyte value, ubyte prefix) {
@ -81,7 +87,8 @@ sub print_uwbin (uword value, ubyte prefix) {
sub print_uwhex (uword value, ubyte prefix) {
; ---- print the uword in hexadecimal form (4 digits)
; TODO use conv module?
print_ubhex(msb(value), true)
print_ubhex(lsb(value), false)
}
sub print_uw0 (uword value) {

View File

@ -83,8 +83,8 @@ fun compileProgram(args: CompilerArguments): CompilationResult? {
importedFiles = imported
processAst(program, args.errors, compilationOptions)
if (compilationOptions.optimize) {
println("*********** AST RIGHT BEFORE OPTIMIZING *************")
printProgram(program)
// println("*********** AST RIGHT BEFORE OPTIMIZING *************")
// printProgram(program)
optimizeAst(
program,

View File

@ -414,8 +414,8 @@ class IntermediateAstMaker(val program: Program) {
return array
}
private fun transform(srcArr: ArrayLiteral): PtArrayLiteral {
val arr = PtArrayLiteral(srcArr.inferType(program).getOrElse { throw FatalAstException("array must know its type") }, srcArr.position)
private fun transform(srcArr: ArrayLiteral): PtArray {
val arr = PtArray(srcArr.inferType(program).getOrElse { throw FatalAstException("array must know its type") }, srcArr.position)
for (elt in srcArr.value)
arr.add(transformExpression(elt))
return arr
@ -440,7 +440,10 @@ class IntermediateAstMaker(val program: Program) {
private fun transform(srcCheck: ContainmentCheck): PtContainmentCheck {
val check = PtContainmentCheck(srcCheck.position)
check.add(transformExpression(srcCheck.element))
check.add(transformExpression(srcCheck.iterable))
if(srcCheck.iterable !is IdentifierReference)
throw FatalAstException("iterable in containmentcheck must always be an identifier (referencing string or array) $srcCheck")
val iterable = transformExpression(srcCheck.iterable)
check.add(iterable)
return check
}

View File

@ -1232,13 +1232,9 @@ internal class AstChecker(private val program: Program,
val iterableDt = containment.iterable.inferType(program)
if(containment.parent is BinaryExpression)
errors.err("containment check is currently not supported in complex expressions", containment.position)
errors.err("containment check is currently not supported inside complex expressions", containment.position)
val range = containment.iterable as? RangeExpression
if(range!=null && range.toConstantIntegerRange()==null)
errors.err("containment check requires a constant integer range", range.position)
if(iterableDt.isIterable) {
if(iterableDt.isIterable && containment.iterable !is RangeExpression) {
val iterableEltDt = ArrayToElementTypes.getValue(iterableDt.getOr(DataType.UNDEFINED))
val invalidDt = if (elementDt.isBytes) {
iterableEltDt !in ByteDatatypes
@ -1250,7 +1246,7 @@ internal class AstChecker(private val program: Program,
if (invalidDt)
errors.err("element datatype doesn't match iterable datatype", containment.position)
} else {
errors.err("value set for containment check must be an iterable type", containment.iterable.position)
errors.err("value set for containment check must be a string or array", containment.iterable.position)
}
super.visit(containment)

View File

@ -7,6 +7,7 @@ import prog8.ast.statements.*
import prog8.ast.walk.AstWalker
import prog8.ast.walk.IAstModification
import prog8.ast.walk.IAstVisitor
import prog8.code.ast.PtIdentifier
import prog8.code.core.*
import prog8.code.target.VMTarget
@ -27,6 +28,14 @@ internal class BeforeAsmAstChanger(val program: Program,
throw InternalCompilerException("do..until should have been converted to jumps")
}
override fun after(containment: ContainmentCheck, parent: Node): Iterable<IAstModification> {
if(containment.element !is IdentifierReference && containment.element !is NumericLiteral)
throw InternalCompilerException("element in containmentcheck should be identifier or constant number")
if(containment.iterable !is IdentifierReference)
throw InternalCompilerException("iterable in containmentcheck should be identifier (referencing string or array)")
return noModifications
}
override fun before(block: Block, parent: Node): Iterable<IAstModification> {
// move all subroutines to the bottom of the block
val subs = block.statements.filterIsInstance<Subroutine>()

View File

@ -4,7 +4,6 @@ import prog8.ast.IFunctionCall
import prog8.ast.Node
import prog8.ast.Program
import prog8.ast.expressions.ArrayLiteral
import prog8.ast.expressions.ContainmentCheck
import prog8.ast.expressions.IdentifierReference
import prog8.ast.expressions.StringLiteral
import prog8.ast.statements.VarDecl
@ -27,11 +26,8 @@ internal class LiteralsToAutoVars(private val program: Program,
errors.err("compilation target doesn't support this text encoding", string.position)
return noModifications
}
if(string.parent !is VarDecl
&& string.parent !is WhenChoice
&& (string.parent !is ContainmentCheck || string.value.length>ContainmentCheck.max_inlined_string_length)) {
if(string.parent !is VarDecl && string.parent !is WhenChoice) {
// replace the literal string by an identifier reference to the interned string
val parentFunc = (string.parent as? IFunctionCall)?.target
if(parentFunc!=null) {
if(parentFunc.nameInSource.size==1 && parentFunc.nameInSource[0]=="memory") {
@ -57,9 +53,6 @@ internal class LiteralsToAutoVars(private val program: Program,
return listOf(IAstModification.ReplaceNode(vardecl.value!!, cast, vardecl))
}
} else {
if(array.parent is ContainmentCheck && array.value.size<ContainmentCheck.max_inlined_string_length)
return noModifications
val arrayDt = array.guessDatatype(program)
if(arrayDt.isKnown) {
// turn the array literal it into an identifier reference

View File

@ -369,18 +369,10 @@ class TestCompilerOnRanges: FunSpec({
xx++
}
if 16 in 10 to 20 step 3 {
xx++
}
if 'b' in "abcdef" {
xx++
}
if 8 in [2,4,6,8] {
xx++
}
if xx in name {
xx++
}
@ -389,10 +381,6 @@ class TestCompilerOnRanges: FunSpec({
xx++
}
if xx in 10 to 20 step 3 {
xx++
}
if xx in "abcdef" {
xx++
}
@ -425,12 +413,10 @@ class TestCompilerOnRanges: FunSpec({
xx = 'm' in name
xx = 5 in values
xx = 16 in 10 to 20 step 3
xx = 'b' in "abcdef"
xx = 8 in [2,4,6,8]
xx = xx in name
xx = xx in values
xx = xx in 10 to 20 step 3
xx = xx in "abcdef"
xx = xx in [2,4,6,8]
xx = ww in [9000,8000,7000]

View File

@ -924,7 +924,7 @@ class TestProg8Parser: FunSpec( {
val errors = ErrorReporterForTests()
compileText(C64Target(), false, text, writeAssembly = false, errors = errors) shouldBe null
errors.errors.size shouldBe 2
errors.errors[0] shouldContain "must be an iterable type"
errors.errors[0] shouldContain "must be a string or array"
errors.errors[1] shouldContain "datatype doesn't match"
}

View File

@ -1009,10 +1009,6 @@ class ContainmentCheck(var element: Expression,
iterable.linkParents(this)
}
companion object {
const val max_inlined_string_length = 16
}
override val isSimple: Boolean = false
override fun copy() = ContainmentCheck(element.copy(), iterable.copy(), position)
override fun constValue(program: Program): NumericLiteral? {
@ -1023,13 +1019,6 @@ class ContainmentCheck(var element: Expression,
val exists = (iterable as ArrayLiteral).value.any { it.constValue(program)==elementConst }
return NumericLiteral.fromBoolean(exists, position)
}
is RangeExpression -> {
val intRange = (iterable as RangeExpression).toConstantIntegerRange()
if(intRange!=null && elementConst.type in IntegerDatatypes) {
val exists = elementConst.number.toInt() in intRange
return NumericLiteral.fromBoolean(exists, position)
}
}
is StringLiteral -> {
if(elementConst.type in ByteDatatypes) {
val stringval = iterable as StringLiteral
@ -1047,11 +1036,6 @@ class ContainmentCheck(var element: Expression,
if(array.value.isEmpty())
return NumericLiteral.fromBoolean(false, position)
}
is RangeExpression -> {
val size = (iterable as RangeExpression).size()
if(size!=null && size==0)
return NumericLiteral.fromBoolean(false, position)
}
is StringLiteral -> {
if((iterable as StringLiteral).value.isEmpty())
return NumericLiteral.fromBoolean(false, position)

View File

@ -3,7 +3,6 @@ TODO
For next release
^^^^^^^^^^^^^^^^
...

View File

@ -8,6 +8,7 @@ main {
txt.print("Welcome to a prog8 pixel shader :-)\n")
ubyte bb = 4
ubyte[] array = [1,2,3,4,5,6]
uword[] warray = [1111,2222,3333]
str tekst = "test"
uword ww = 19
bb = bb in "teststring"
@ -17,6 +18,13 @@ main {
bb = bb in array
bb++
bb = bb in tekst
bb++
bb = ww in warray
bb++
bb = 666 in warray
bb ++
bb = '?' in tekst
bb++
txt.print("bb=")
txt.print_ub(bb)
txt.nl()