fixing str compares codegen

This commit is contained in:
Irmen de Jong 2023-03-28 02:14:16 +02:00
parent f28206d989
commit 729209574e
6 changed files with 184 additions and 127 deletions

View File

@ -79,6 +79,7 @@ val BuiltinFunctions: Map<String, FSignature> = mapOf(
"reverse" to FSignature(false, listOf(FParam("array", ArrayDatatypes)), null),
// cmp returns a status in the carry flag, but not a proper return value
"cmp" to FSignature(false, listOf(FParam("value1", IntegerDatatypesNoBool), FParam("value2", NumericDatatypesNoBool)), null),
"prog8_lib_stringcompare" to FSignature(true, listOf(FParam("str1", arrayOf(DataType.STR)), FParam("str2", arrayOf(DataType.STR))), DataType.BYTE),
"abs" to FSignature(true, listOf(FParam("value", IntegerDatatypesNoBool)), DataType.UWORD),
"len" to FSignature(true, listOf(FParam("values", IterableDatatypes)), DataType.UWORD),
// normal functions follow:

View File

@ -70,12 +70,19 @@ internal class BuiltinFunctionsAsmGen(private val program: PtProgram,
"rrestorex" -> funcRrestoreX()
"cmp" -> funcCmp(fcall)
"callfar" -> funcCallFar(fcall)
"prog8_lib_stringcompare" -> funcStringCompare(fcall)
else -> throw AssemblyError("missing asmgen for builtin func ${fcall.name}")
}
return BuiltinFunctions.getValue(fcall.name).returnType
}
private fun funcStringCompare(fcall: PtBuiltinFunctionCall) {
assignAsmGen.assignExpressionToVariable(fcall.args[1], "P8ZP_SCRATCH_W2", DataType.UWORD)
assignAsmGen.assignExpressionToRegister(fcall.args[0], RegisterOrPair.AY, false)
asmgen.out(" jsr prog8_lib.strcmp_mem")
}
private fun funcRsave() {
if (asmgen.isTargetCpu(CpuType.CPU65c02))
asmgen.out("""

View File

@ -41,10 +41,27 @@ internal class BuiltinFuncGen(private val codeGen: IRCodeGen, private val exprGe
"ror" -> funcRolRor(Opcode.ROXR, call)
"rol2" -> funcRolRor(Opcode.ROL, call)
"ror2" -> funcRolRor(Opcode.ROR, call)
"prog8_lib_stringcompare" -> funcStringCompare(call)
else -> throw AssemblyError("missing builtinfunc for ${call.name}")
}
}
private fun funcStringCompare(call: PtBuiltinFunctionCall): ExpressionCodeResult {
/*
loadm.w r65500,string.compare.st1
loadm.w r65501,string.compare.st2
syscall 29
returnreg.b r0
*/
val result = mutableListOf<IRCodeChunkBase>()
val left = exprGen.translateExpression(call.args[0])
val right = exprGen.translateExpression(call.args[1])
addToResult(result, left, 65500, -1)
addToResult(result, right, 65501, -1)
addInstr(result, IRInstruction(Opcode.SYSCALL, value=IMSyscall.COMPARE_STRINGS.number), null)
return ExpressionCodeResult(result, IRDataType.BYTE, 0, -1)
}
private fun funcCmp(call: PtBuiltinFunctionCall): ExpressionCodeResult {
val result = mutableListOf<IRCodeChunkBase>()
val leftTr = exprGen.translateExpression(call.args[0])

View File

@ -477,13 +477,17 @@ private fun transformNewExpressions(program: PtProgram) {
if(expr.type == expr.left.type) {
getExprVar(postfix, expr.type, depth, expr.position, scope)
} else {
if(expr.operator in ComparisonOperators && expr.type == DataType.UBYTE) {
if(expr.operator in ComparisonOperators && expr.type in ByteDatatypes) {
// this is very common and should be dealth with correctly; byte==0, word>42
getExprVar(postfix, expr.left.type, depth, expr.position, scope)
val varType = if(expr.left.type in PassByReferenceDatatypes) DataType.UWORD else expr.left.type
getExprVar(postfix, varType, depth, expr.position, scope)
}
else if(expr.left.type in PassByReferenceDatatypes && expr.type==DataType.UBYTE) {
// this is common and should be dealth with correctly; for instance "name"=="irmen"
getExprVar(postfix, expr.left.type, depth, expr.position, scope)
// this is common and should be dealth with correctly; for instance "name"=="john"
val varType = if (expr.left.type in PassByReferenceDatatypes) DataType.UWORD else expr.left.type
getExprVar(postfix, varType, depth, expr.position, scope)
} else if(expr.left.type equalsSize expr.type) {
getExprVar(postfix, expr.type, depth, expr.position, scope)
} else {
TODO("expression type differs from left operand type! got ${expr.left.type} expected ${expr.type} ${expr.position}")
}

View File

@ -36,6 +36,19 @@ internal class BeforeAsmAstChanger(val program: Program,
return noModifications
}
override fun after(expr: BinaryExpression, parent: Node): Iterable<IAstModification> {
if(expr.operator in ComparisonOperators && expr.left.inferType(program) istype DataType.STR && expr.right.inferType(program) istype DataType.STR) {
// replace string comparison expressions with calls to string.compare()
val stringCompare = BuiltinFunctionCall(
IdentifierReference(listOf("prog8_lib_stringcompare"), expr.position),
mutableListOf(expr.left.copy(), expr.right.copy()), expr.position)
val zero = NumericLiteral.optimalInteger(0, expr.position)
val comparison = BinaryExpression(stringCompare, expr.operator, zero, expr.position)
return listOf(IAstModification.ReplaceNode(expr, comparison, parent))
}
return noModifications
}
override fun after(decl: VarDecl, parent: Node): Iterable<IAstModification> {
if (decl.type == VarDeclType.VAR && decl.value != null && decl.datatype in NumericDatatypes)
throw InternalCompilerException("vardecls for variables, with initial numerical value, should have been rewritten as plain vardecl + assignment $decl")

View File

@ -1,6 +1,7 @@
%zeropage basicsafe
%import textio
%import floats
%import string
main {
sub start() {
@ -9,133 +10,147 @@ main {
word w = -20000
uword uw = 2000
float f = -100
str name="john"
txt.print("all 1: ")
txt.print_ub(b == -100)
txt.print_ub(b != -99)
txt.print_ub(b < -99)
txt.print_ub(b <= -100)
txt.print_ub(b > -101)
txt.print_ub(b >= -100)
txt.print_ub(ub ==20)
txt.print_ub(ub !=19)
txt.print_ub(ub <21)
txt.print_ub(ub <=20)
txt.print_ub(ub>19)
txt.print_ub(ub>=20)
txt.spc()
txt.print_ub(w == -20000)
txt.print_ub(w != -19999)
txt.print_ub(w < -19999)
txt.print_ub(w <= -20000)
txt.print_ub(w > -20001)
txt.print_ub(w >= -20000)
txt.print_ub(uw == 2000)
txt.print_ub(uw != 2001)
txt.print_ub(uw < 2001)
txt.print_ub(uw <= 2000)
txt.print_ub(uw > 1999)
txt.print_ub(uw >= 2000)
txt.spc()
txt.print_ub(f == -100.0)
txt.print_ub(f != -99.0)
txt.print_ub(f < -99.0)
txt.print_ub(f <= -100.0)
txt.print_ub(f > -101.0)
txt.print_ub(f >= -100.0)
txt.nl()
txt.print("all 0: ")
txt.print_ub(b == -99)
txt.print_ub(b != -100)
txt.print_ub(b < -100)
txt.print_ub(b <= -101)
txt.print_ub(b > -100)
txt.print_ub(b >= -99)
txt.print_ub(ub ==21)
txt.print_ub(ub !=20)
txt.print_ub(ub <20)
txt.print_ub(ub <=19)
txt.print_ub(ub>20)
txt.print_ub(ub>=21)
txt.spc()
txt.print_ub(w == -20001)
txt.print_ub(w != -20000)
txt.print_ub(w < -20000)
txt.print_ub(w <= -20001)
txt.print_ub(w > -20000)
txt.print_ub(w >= -19999)
txt.print_ub(uw == 1999)
txt.print_ub(uw != 2000)
txt.print_ub(uw < 2000)
txt.print_ub(uw <= 1999)
txt.print_ub(uw > 2000)
txt.print_ub(uw >= 2001)
txt.spc()
txt.print_ub(f == -99.0)
txt.print_ub(f != -100.0)
txt.print_ub(f < -100.0)
txt.print_ub(f <= -101.0)
txt.print_ub(f > -100.0)
txt.print_ub(f >= -99.0)
txt.nl()
; TODO ALL OF THE ABOVE BUT WITH A VARIABLE INSTEAD OF A CONST VALUE
b = -100
while b <= -20
b++
txt.print_b(b)
txt.print(" -19\n")
b = -100
while b < -20
b++
txt.print_b(b)
txt.print(" -20\n")
ub = 20
while ub <= 200
ub++
txt.print_ub(ub)
txt.print(" 201\n")
ub = 20
while ub < 200
ub++
txt.print_ub(ub)
txt.print(" 200\n")
w = -20000
while w <= -8000 {
w++
if (string.compare(name, "aaa")==0) or (string.compare(name, "john")==0) or (string.compare(name, "bbb")==0) {
txt.print("name1 ok\n")
}
txt.print_w(w)
txt.print(" -7999\n")
w = -20000
while w < -8000 {
w++
if (string.compare(name, "aaa")==0) or (string.compare(name, "zzz")==0) or (string.compare(name, "bbb")==0) {
txt.print("name2 fail!\n")
}
txt.print_w(w)
txt.print(" -8000\n")
uw = 2000
while uw <= 8000 {
uw++
}
txt.print_uw(uw)
txt.print(" 8001\n")
uw = 2000
while uw < 8000 {
uw++
}
txt.print_uw(uw)
txt.print(" 8000\n")
if name=="aaa" or name=="john" or name=="bbb" ; TODO fix this result on C64 target, no newexpr!
txt.print("name1b ok\n")
if name=="aaa" or name=="zzz" or name=="bbb" ; TODO fix this result on C64 target, no newexpr!
txt.print("name2b fail!\n")
f = 0.0
while f<2.2 {
f+=0.1
}
floats.print_f(f)
txt.print(" 2.2\n")
; txt.print("all 1: ")
; txt.print_ub(b == -100)
; txt.print_ub(b != -99)
; txt.print_ub(b < -99)
; txt.print_ub(b <= -100)
; txt.print_ub(b > -101)
; txt.print_ub(b >= -100)
; txt.print_ub(ub ==20)
; txt.print_ub(ub !=19)
; txt.print_ub(ub <21)
; txt.print_ub(ub <=20)
; txt.print_ub(ub>19)
; txt.print_ub(ub>=20)
; txt.spc()
; txt.print_ub(w == -20000)
; txt.print_ub(w != -19999)
; txt.print_ub(w < -19999)
; txt.print_ub(w <= -20000)
; txt.print_ub(w > -20001)
; txt.print_ub(w >= -20000)
; txt.print_ub(uw == 2000)
; txt.print_ub(uw != 2001)
; txt.print_ub(uw < 2001)
; txt.print_ub(uw <= 2000)
; txt.print_ub(uw > 1999)
; txt.print_ub(uw >= 2000)
; txt.spc()
; txt.print_ub(f == -100.0)
; txt.print_ub(f != -99.0)
; txt.print_ub(f < -99.0)
; txt.print_ub(f <= -100.0)
; txt.print_ub(f > -101.0)
; txt.print_ub(f >= -100.0)
; txt.nl()
;
; txt.print("all 0: ")
; txt.print_ub(b == -99)
; txt.print_ub(b != -100)
; txt.print_ub(b < -100)
; txt.print_ub(b <= -101)
; txt.print_ub(b > -100)
; txt.print_ub(b >= -99)
; txt.print_ub(ub ==21)
; txt.print_ub(ub !=20)
; txt.print_ub(ub <20)
; txt.print_ub(ub <=19)
; txt.print_ub(ub>20)
; txt.print_ub(ub>=21)
; txt.spc()
; txt.print_ub(w == -20001)
; txt.print_ub(w != -20000)
; txt.print_ub(w < -20000)
; txt.print_ub(w <= -20001)
; txt.print_ub(w > -20000)
; txt.print_ub(w >= -19999)
; txt.print_ub(uw == 1999)
; txt.print_ub(uw != 2000)
; txt.print_ub(uw < 2000)
; txt.print_ub(uw <= 1999)
; txt.print_ub(uw > 2000)
; txt.print_ub(uw >= 2001)
; txt.spc()
; txt.print_ub(f == -99.0)
; txt.print_ub(f != -100.0)
; txt.print_ub(f < -100.0)
; txt.print_ub(f <= -101.0)
; txt.print_ub(f > -100.0)
; txt.print_ub(f >= -99.0)
; txt.nl()
;
; ; TODO ALL OF THE ABOVE BUT WITH A VARIABLE INSTEAD OF A CONST VALUE
;
;
; b = -100
; while b <= -20
; b++
; txt.print_b(b)
; txt.print(" -19\n")
; b = -100
; while b < -20
; b++
; txt.print_b(b)
; txt.print(" -20\n")
;
; ub = 20
; while ub <= 200
; ub++
; txt.print_ub(ub)
; txt.print(" 201\n")
; ub = 20
; while ub < 200
; ub++
; txt.print_ub(ub)
; txt.print(" 200\n")
;
; w = -20000
; while w <= -8000 {
; w++
; }
; txt.print_w(w)
; txt.print(" -7999\n")
; w = -20000
; while w < -8000 {
; w++
; }
; txt.print_w(w)
; txt.print(" -8000\n")
;
; uw = 2000
; while uw <= 8000 {
; uw++
; }
; txt.print_uw(uw)
; txt.print(" 8001\n")
; uw = 2000
; while uw < 8000 {
; uw++
; }
; txt.print_uw(uw)
; txt.print(" 8000\n")
;
; f = 0.0
; while f<2.2 {
; f+=0.1
; }
; floats.print_f(f)
; txt.print(" 2.2\n")
}
}