improve containment check for few values

This commit is contained in:
Irmen de Jong 2024-09-11 03:15:38 +02:00
parent 01c6754928
commit 255c5bfaca
8 changed files with 189 additions and 46 deletions

View File

@ -42,7 +42,18 @@ sealed class PtExpression(val type: DataType, position: Position) : PtNode(posit
else
other.left isSameAs left && other.right isSameAs right
}
is PtContainmentCheck -> other is PtContainmentCheck && other.type==type && other.element isSameAs element && other.iterable isSameAs iterable
is PtContainmentCheck -> {
if(other !is PtContainmentCheck || other.type != type || !(other.needle isSameAs needle))
false
else {
if(haystackHeapVar!=null)
other.haystackHeapVar!=null && other.haystackHeapVar!! isSameAs haystackHeapVar!!
else if(haystackValues!=null)
other.haystackValues!=null && other.haystackValues!! isSameAs haystackValues!!
else
false
}
}
is PtIdentifier -> other is PtIdentifier && other.type==type && other.name==name
is PtIrRegister -> other is PtIrRegister && other.type==type && other.register==register
is PtMemoryByte -> other is PtMemoryByte && other.address isSameAs address
@ -195,10 +206,17 @@ class PtBinaryExpression(val operator: String, type: DataType, position: Positio
class PtContainmentCheck(position: Position): PtExpression(DataType.BOOL, position) {
val element: PtExpression
val needle: PtExpression
get() = children[0] as PtExpression
val iterable: PtIdentifier
get() = children[1] as PtIdentifier
val haystackHeapVar: PtIdentifier?
get() = children[1] as? PtIdentifier
val haystackValues: PtArray?
get() = children[1] as? PtArray
companion object {
val MAX_SIZE_FOR_INLINE_CHECKS_BYTE = 5
val MAX_SIZE_FOR_INLINE_CHECKS_WORD = 4
}
}

View File

@ -1830,8 +1830,57 @@ internal class AssignmentAsmGen(
}
private fun containmentCheckIntoA(containment: PtContainmentCheck) {
val elementDt = containment.element.type
val symbol = asmgen.symbolTable.lookup(containment.iterable.name)!!
val elementDt = containment.needle.type
if(containment.haystackValues!=null) {
val haystack = containment.haystackValues!!.children.map {
if(it is PtBool) it.asInt()
else (it as PtNumber).number
}
when(elementDt) {
in ByteDatatypesWithBoolean -> {
require(haystack.size in 0..PtContainmentCheck.MAX_SIZE_FOR_INLINE_CHECKS_BYTE)
assignExpressionToRegister(containment.needle, RegisterOrPair.A, elementDt == DataType.BYTE)
for(number in haystack) {
val numstr = if(elementDt==DataType.FLOAT) number.toString() else number.toInt().toString()
asmgen.out("""
cmp #$numstr
beq +""")
}
asmgen.out("""
lda #0
beq ++
+ lda #1
+""")
}
in WordDatatypes -> {
require(haystack.size in 0..PtContainmentCheck.MAX_SIZE_FOR_INLINE_CHECKS_WORD)
assignExpressionToRegister(containment.needle, RegisterOrPair.AY, elementDt == DataType.WORD)
val gottemLabel = asmgen.makeLabel("gottem")
val endLabel = asmgen.makeLabel("end")
for(number in haystack) {
val numstr = if(elementDt==DataType.FLOAT) number.toString() else number.toInt().toString()
asmgen.out("""
cmp #<$numstr
bne +
cpy #>$numstr
beq $gottemLabel
+ """)
}
asmgen.out("""
lda #0
beq $endLabel
$gottemLabel lda #1
$endLabel""")
}
DataType.FLOAT -> throw AssemblyError("containmentchecks for floats should always be done on an array variable with subroutine")
else -> throw AssemblyError("weird dt $elementDt")
}
return
}
val symbol = asmgen.symbolTable.lookup(containment.haystackHeapVar!!.name)!!
val symbolName = asmgen.asmVariableName(symbol, containment.definingSub())
val (dt, numElements) = when(symbol) {
is StStaticVariable -> symbol.dt to symbol.length!!
@ -1840,7 +1889,7 @@ internal class AssignmentAsmGen(
}
when(dt) {
DataType.STR -> {
assignExpressionToRegister(containment.element, RegisterOrPair.A, elementDt == DataType.BYTE)
assignExpressionToRegister(containment.needle, RegisterOrPair.A, elementDt == DataType.BYTE)
asmgen.out(" pha") // need to keep the scratch var safe so we have to do it in this order
assignAddressOf(AsmAssignTarget(TargetStorageKind.VARIABLE, asmgen, DataType.UWORD, containment.definingISub(), containment.position,"P8ZP_SCRATCH_W1"), symbolName, null, null)
asmgen.out(" pla")
@ -1848,13 +1897,13 @@ internal class AssignmentAsmGen(
asmgen.out(" jsr prog8_lib.containment_bytearray")
}
DataType.ARRAY_F -> {
assignExpressionToRegister(containment.element, RegisterOrPair.FAC1, true)
assignExpressionToRegister(containment.needle, RegisterOrPair.FAC1, true)
assignAddressOf(AsmAssignTarget(TargetStorageKind.VARIABLE, asmgen, DataType.UWORD, containment.definingISub(), containment.position, "P8ZP_SCRATCH_W1"), symbolName, null, null)
asmgen.out(" ldy #$numElements")
asmgen.out(" jsr floats.containment_floatarray")
}
DataType.ARRAY_B, DataType.ARRAY_UB -> {
assignExpressionToRegister(containment.element, RegisterOrPair.A, elementDt == DataType.BYTE)
DataType.ARRAY_B, DataType.ARRAY_UB, DataType.ARRAY_BOOL -> {
assignExpressionToRegister(containment.needle, RegisterOrPair.A, elementDt == DataType.BYTE)
asmgen.out(" pha") // need to keep the scratch var safe so we have to do it in this order
assignAddressOf(AsmAssignTarget(TargetStorageKind.VARIABLE, asmgen, DataType.UWORD, containment.definingISub(), containment.position, "P8ZP_SCRATCH_W1"), symbolName, null, null)
asmgen.out(" pla")
@ -1862,7 +1911,7 @@ internal class AssignmentAsmGen(
asmgen.out(" jsr prog8_lib.containment_bytearray")
}
DataType.ARRAY_W, DataType.ARRAY_UW -> {
assignExpressionToVariable(containment.element, "P8ZP_SCRATCH_W1", elementDt)
assignExpressionToVariable(containment.needle, "P8ZP_SCRATCH_W1", elementDt)
assignAddressOf(AsmAssignTarget(TargetStorageKind.VARIABLE, asmgen, DataType.UWORD, containment.definingISub(), containment.position, "P8ZP_SCRATCH_W2"), symbolName, null, null)
asmgen.out(" ldy #$numElements")
asmgen.out(" jsr prog8_lib.containment_wordarray")

View File

@ -168,13 +168,43 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) {
}
private fun translate(check: PtContainmentCheck): ExpressionCodeResult {
val elementDt = check.needle.type
val result = mutableListOf<IRCodeChunkBase>()
when(check.iterable.type) {
if(check.haystackValues!=null) {
val haystack = check.haystackValues!!.children.map {
if(it is PtBool) it.asInt()
else (it as PtNumber).number
}
when(elementDt) {
in ByteDatatypesWithBoolean -> {
val elementTr = translateExpression(check.needle)
addToResult(result, elementTr, elementTr.resultReg, -1)
TODO("byte containment check ${check.needle} in ${haystack}")
// TODO should return a proper result here
val resultReg = -1
return ExpressionCodeResult(result, IRDataType.BYTE, resultReg, -1)
}
in WordDatatypes -> {
val elementTr = translateExpression(check.needle)
addToResult(result, elementTr, elementTr.resultReg, -1)
TODO("word containment check ${check.needle} in ${haystack}")
// TODO should return a proper result here
val resultReg = -1
return ExpressionCodeResult(result, IRDataType.BYTE, resultReg, -1)
}
DataType.FLOAT -> throw AssemblyError("containmentchecks for floats should always be done on an array variable with subroutine")
else -> throw AssemblyError("weird dt $elementDt")
}
}
val haystackVar = check.haystackHeapVar!!
when(haystackVar.type) {
DataType.STR -> {
addInstr(result, IRInstruction(Opcode.PREPARECALL, immediate = 2), null)
val elementTr = translateExpression(check.element)
val elementTr = translateExpression(check.needle)
addToResult(result, elementTr, elementTr.resultReg, -1)
val iterableTr = translateExpression(check.iterable)
val iterableTr = translateExpression(haystackVar)
addToResult(result, iterableTr, iterableTr.resultReg, -1)
result += codeGen.makeSyscall(IMSyscall.STRING_CONTAINS, listOf(IRDataType.BYTE to elementTr.resultReg, IRDataType.WORD to iterableTr.resultReg), IRDataType.BYTE to elementTr.resultReg)
addInstr(result, IRInstruction(Opcode.CMPI, IRDataType.BYTE, reg1=elementTr.resultReg, immediate = 0), null)
@ -182,12 +212,12 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) {
}
DataType.ARRAY_UB, DataType.ARRAY_B -> {
addInstr(result, IRInstruction(Opcode.PREPARECALL, immediate = 3), null)
val elementTr = translateExpression(check.element)
val elementTr = translateExpression(check.needle)
addToResult(result, elementTr, elementTr.resultReg, -1)
val iterableTr = translateExpression(check.iterable)
val iterableTr = translateExpression(haystackVar)
addToResult(result, iterableTr, iterableTr.resultReg, -1)
val lengthReg = codeGen.registers.nextFree()
val iterableLength = codeGen.symbolTable.getLength(check.iterable.name)
val iterableLength = codeGen.symbolTable.getLength(haystackVar.name)
addInstr(result, IRInstruction(Opcode.LOAD, IRDataType.BYTE, reg1=lengthReg, immediate = iterableLength!!), null)
result += codeGen.makeSyscall(IMSyscall.BYTEARRAY_CONTAINS, listOf(IRDataType.BYTE to elementTr.resultReg, IRDataType.WORD to iterableTr.resultReg, IRDataType.BYTE to lengthReg), IRDataType.BYTE to elementTr.resultReg)
addInstr(result, IRInstruction(Opcode.CMPI, IRDataType.BYTE, reg1=elementTr.resultReg, immediate = 0), null)
@ -195,12 +225,12 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) {
}
DataType.ARRAY_UW, DataType.ARRAY_W -> {
addInstr(result, IRInstruction(Opcode.PREPARECALL, immediate = 3), null)
val elementTr = translateExpression(check.element)
val elementTr = translateExpression(check.needle)
addToResult(result, elementTr, elementTr.resultReg, -1)
val iterableTr = translateExpression(check.iterable)
val iterableTr = translateExpression(haystackVar)
addToResult(result, iterableTr, iterableTr.resultReg, -1)
val lengthReg = codeGen.registers.nextFree()
val iterableLength = codeGen.symbolTable.getLength(check.iterable.name)
val iterableLength = codeGen.symbolTable.getLength(haystackVar.name)
addInstr(result, IRInstruction(Opcode.LOAD, IRDataType.BYTE, reg1=lengthReg, immediate = iterableLength!!), null)
result += codeGen.makeSyscall(IMSyscall.WORDARRAY_CONTAINS, listOf(IRDataType.WORD to elementTr.resultReg, IRDataType.WORD to iterableTr.resultReg, IRDataType.BYTE to lengthReg), IRDataType.BYTE to elementTr.resultReg)
addInstr(result, IRInstruction(Opcode.CMPI, IRDataType.BYTE, reg1=elementTr.resultReg, immediate = 0), null)
@ -208,19 +238,19 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) {
}
DataType.ARRAY_F -> {
addInstr(result, IRInstruction(Opcode.PREPARECALL, immediate = 3), null)
val elementTr = translateExpression(check.element)
val elementTr = translateExpression(check.needle)
addToResult(result, elementTr, -1, elementTr.resultFpReg)
val iterableTr = translateExpression(check.iterable)
val iterableTr = translateExpression(haystackVar)
addToResult(result, iterableTr, iterableTr.resultReg, -1)
val lengthReg = codeGen.registers.nextFree()
val resultReg = codeGen.registers.nextFree()
val iterableLength = codeGen.symbolTable.getLength(check.iterable.name)
val iterableLength = codeGen.symbolTable.getLength(haystackVar.name)
addInstr(result, IRInstruction(Opcode.LOAD, IRDataType.BYTE, reg1=lengthReg, immediate = iterableLength!!), null)
result += codeGen.makeSyscall(IMSyscall.FLOATARRAY_CONTAINS, listOf(IRDataType.FLOAT to elementTr.resultFpReg, IRDataType.WORD to iterableTr.resultReg, IRDataType.BYTE to lengthReg), IRDataType.BYTE to resultReg)
addInstr(result, IRInstruction(Opcode.CMPI, IRDataType.BYTE, reg1=resultReg, immediate = 0), null)
return ExpressionCodeResult(result, IRDataType.BYTE, resultReg, -1)
}
else -> throw AssemblyError("weird iterable dt ${check.iterable.type} for ${check.iterable.name}")
else -> throw AssemblyError("weird iterable dt ${haystackVar.type} for ${haystackVar.name}")
}
}

View File

@ -623,11 +623,10 @@ class IntermediateAstMaker(private val program: Program, private val errors: IEr
}
when(srcCheck.iterable) {
is IdentifierReference -> {
is IdentifierReference, is ArrayLiteral -> {
val check = PtContainmentCheck(srcCheck.position)
check.add(transformExpression(srcCheck.element))
val iterable = transformExpression(srcCheck.iterable)
check.add(iterable)
check.add(transformExpression(srcCheck.iterable))
return check
}
is RangeExpression -> {

View File

@ -4,19 +4,14 @@ import prog8.ast.IFunctionCall
import prog8.ast.IStatementContainer
import prog8.ast.Node
import prog8.ast.Program
import prog8.ast.expressions.ArrayLiteral
import prog8.ast.expressions.BinaryExpression
import prog8.ast.expressions.IdentifierReference
import prog8.ast.expressions.StringLiteral
import prog8.ast.expressions.*
import prog8.ast.statements.Assignment
import prog8.ast.statements.VarDecl
import prog8.ast.statements.WhenChoice
import prog8.ast.walk.AstWalker
import prog8.ast.walk.IAstModification
import prog8.code.core.DataType
import prog8.code.core.IErrorReporter
import prog8.code.core.NumericDatatypesWithBoolean
import prog8.code.core.SplitWordArrayTypes
import prog8.code.ast.PtContainmentCheck
import prog8.code.core.*
internal class LiteralsToAutoVars(private val program: Program, private val errors: IErrorReporter) : AstWalker() {
@ -54,6 +49,16 @@ internal class LiteralsToAutoVars(private val program: Program, private val erro
}
} else {
val arrayDt = array.guessDatatype(program)
val elementDt = ArrayToElementTypes.getValue(arrayDt.getOr(DataType.UNDEFINED))
val maxSize = when(elementDt) {
in ByteDatatypesWithBoolean -> PtContainmentCheck.MAX_SIZE_FOR_INLINE_CHECKS_BYTE
in WordDatatypes -> PtContainmentCheck.MAX_SIZE_FOR_INLINE_CHECKS_WORD
else -> 0
}
if(parent is ContainmentCheck && array.value.size <= maxSize) {
// keep the array in the containmentcheck inline
return noModifications
}
if(arrayDt.isKnown) {
val parentAssign = parent as? Assignment
val targetDt = parentAssign?.target?.inferType(program) ?: arrayDt

View File

@ -523,21 +523,24 @@ main {
xxValue.right shouldBe NumericLiteral(DataType.UWORD, 10.0, Position.DUMMY)
}
test("multi-comparison replaced by containment check") {
test("multi-comparison with many values replaced by containment check on heap variable") {
val src="""
main {
sub start() {
ubyte @shared source=99
ubyte @shared thingy=42
if source==3 or source==4 or source==99 or source==1
if source==3 or source==4 or source==99 or source==1 or source==2 or source==42
thingy++
}
}"""
compileText(VMTarget(), optimize=true, src, writeAssembly=true) shouldNotBe null
val result = compileText(C64Target(), optimize=true, src, writeAssembly=false)!!
/*
expected result:
ubyte[] auto_heap_var = [1,4,99,3]
ubyte[] auto_heap_var = [1,2,3,4,42,99]
ubyte source
source = 99
ubyte thingy
@ -553,16 +556,43 @@ main {
(containment.iterable as IdentifierReference).nameInSource.single() shouldStartWith "auto_heap_value"
val arrayDecl = stmts[0] as VarDecl
arrayDecl.isArray shouldBe true
arrayDecl.arraysize?.constIndex() shouldBe 4
arrayDecl.arraysize?.constIndex() shouldBe 6
val arrayValue = arrayDecl.value as ArrayLiteral
arrayValue.type shouldBe InferredTypes.InferredType.known(DataType.ARRAY_UB)
arrayValue.value shouldBe listOf(
NumericLiteral.optimalInteger(1, Position.DUMMY),
NumericLiteral.optimalInteger(2, Position.DUMMY),
NumericLiteral.optimalInteger(3, Position.DUMMY),
NumericLiteral.optimalInteger(4, Position.DUMMY),
NumericLiteral.optimalInteger(42, Position.DUMMY),
NumericLiteral.optimalInteger(99, Position.DUMMY))
}
test("multi-comparison with few values replaced by inline containment check") {
val src="""
main {
sub start() {
ubyte @shared source=99
ubyte @shared thingy=42
if source==3 or source==4 or source==99
thingy++
}
}"""
compileText(VMTarget(), optimize=true, src, writeAssembly=true) shouldNotBe null
val result = compileText(C64Target(), optimize=true, src, writeAssembly=false)!!
val stmts = result.compilerAst.entrypoint.statements
stmts.size shouldBe 5
val ifStmt = stmts[4] as IfElse
val containment = ifStmt.condition as ContainmentCheck
(containment.element as IdentifierReference).nameInSource shouldBe listOf("source")
val array = (containment.iterable as ArrayLiteral)
array.value.size shouldBe 3
array.value.map { (it as NumericLiteral).number } shouldBe listOf(3.0, 4.0, 99.0)
}
test("invalid multi-comparison (not all equals) not replaced") {
val src="""
main {

View File

@ -1,13 +1,14 @@
TODO
====
IR: add codegen for containmentcheck literal + test that it actually works (positive and negative)
Fix testgfx2 screen text being uppercase (should be upper+lowercased)
diskio.internal_f_tell gets included in the assembly even though f_tell is never called ??? (when using another routine from diskio...)
IR: Improve codegen for for loops downto 0. (BPL if <=127 etc like 6502 codegen?)
Improve register load order in subroutine call args assignments:
in certain situations, the "wrong" order of evaluation of function call arguments is done which results
in overwriting registers that already got their value, which requires a lot of stack juggling (especially on plain 6502 cpu!)

View File

@ -1,15 +1,26 @@
%import textio
%zeropage basicsafe
%option no_sysinit
main {
sub start() {
ubyte @shared x
bool @shared b = true
ubyte @shared x = 42
uword @shared w = 9999
; if x==13 or x==42 ; why is there shortcut-evaluation here for those simple terms
; cx16.r0L++
if b in [true, false, false, true]
txt.print("yep0\n")
if x==13 or x==42 or x==99 or x==100 ; is this really more efficient as a containment check , when the shortcut-evaluation gets fixed??
cx16.r0L++
if x==13 or x==42 ; why is there shortcut-evaluation here for those simple terms
txt.print("yep1\n")
if x in [13, 42,99,100]
txt.print("yep2\n")
if x==13 or x==42 or x==99 or x==100 ; optimize the containment check to not always use an array + jsr
txt.print("yep3\n")
if w==1313 or w==4242 or w==9999 or w==10101 ; optimize the containment check to not always use an array + jsr
txt.print("yep4\n")
}
}