From 729209574e4140466eebb6dc6f3c5ccecb038b9e Mon Sep 17 00:00:00 2001 From: Irmen de Jong Date: Tue, 28 Mar 2023 02:14:16 +0200 Subject: [PATCH] fixing str compares codegen --- .../src/prog8/code/core/BuiltinFunctions.kt | 1 + .../codegen/cpu6502/BuiltinFunctionsAsmGen.kt | 7 + .../codegen/intermediate/BuiltinFuncGen.kt | 17 ++ compiler/src/prog8/compiler/Compiler.kt | 12 +- .../astprocessing/BeforeAsmAstChanger.kt | 13 + examples/test.p8 | 261 +++++++++--------- 6 files changed, 184 insertions(+), 127 deletions(-) diff --git a/codeCore/src/prog8/code/core/BuiltinFunctions.kt b/codeCore/src/prog8/code/core/BuiltinFunctions.kt index a19a3b219..d189bc43a 100644 --- a/codeCore/src/prog8/code/core/BuiltinFunctions.kt +++ b/codeCore/src/prog8/code/core/BuiltinFunctions.kt @@ -79,6 +79,7 @@ val BuiltinFunctions: Map = mapOf( "reverse" to FSignature(false, listOf(FParam("array", ArrayDatatypes)), null), // cmp returns a status in the carry flag, but not a proper return value "cmp" to FSignature(false, listOf(FParam("value1", IntegerDatatypesNoBool), FParam("value2", NumericDatatypesNoBool)), null), + "prog8_lib_stringcompare" to FSignature(true, listOf(FParam("str1", arrayOf(DataType.STR)), FParam("str2", arrayOf(DataType.STR))), DataType.BYTE), "abs" to FSignature(true, listOf(FParam("value", IntegerDatatypesNoBool)), DataType.UWORD), "len" to FSignature(true, listOf(FParam("values", IterableDatatypes)), DataType.UWORD), // normal functions follow: diff --git a/codeGenCpu6502/src/prog8/codegen/cpu6502/BuiltinFunctionsAsmGen.kt b/codeGenCpu6502/src/prog8/codegen/cpu6502/BuiltinFunctionsAsmGen.kt index e49456fce..aba769dd2 100644 --- a/codeGenCpu6502/src/prog8/codegen/cpu6502/BuiltinFunctionsAsmGen.kt +++ b/codeGenCpu6502/src/prog8/codegen/cpu6502/BuiltinFunctionsAsmGen.kt @@ -70,12 +70,19 @@ internal class BuiltinFunctionsAsmGen(private val program: PtProgram, "rrestorex" -> funcRrestoreX() "cmp" -> funcCmp(fcall) "callfar" -> funcCallFar(fcall) + "prog8_lib_stringcompare" -> funcStringCompare(fcall) else -> throw AssemblyError("missing asmgen for builtin func ${fcall.name}") } return BuiltinFunctions.getValue(fcall.name).returnType } + private fun funcStringCompare(fcall: PtBuiltinFunctionCall) { + assignAsmGen.assignExpressionToVariable(fcall.args[1], "P8ZP_SCRATCH_W2", DataType.UWORD) + assignAsmGen.assignExpressionToRegister(fcall.args[0], RegisterOrPair.AY, false) + asmgen.out(" jsr prog8_lib.strcmp_mem") + } + private fun funcRsave() { if (asmgen.isTargetCpu(CpuType.CPU65c02)) asmgen.out(""" diff --git a/codeGenIntermediate/src/prog8/codegen/intermediate/BuiltinFuncGen.kt b/codeGenIntermediate/src/prog8/codegen/intermediate/BuiltinFuncGen.kt index f2b5ba232..898ecc834 100644 --- a/codeGenIntermediate/src/prog8/codegen/intermediate/BuiltinFuncGen.kt +++ b/codeGenIntermediate/src/prog8/codegen/intermediate/BuiltinFuncGen.kt @@ -41,10 +41,27 @@ internal class BuiltinFuncGen(private val codeGen: IRCodeGen, private val exprGe "ror" -> funcRolRor(Opcode.ROXR, call) "rol2" -> funcRolRor(Opcode.ROL, call) "ror2" -> funcRolRor(Opcode.ROR, call) + "prog8_lib_stringcompare" -> funcStringCompare(call) else -> throw AssemblyError("missing builtinfunc for ${call.name}") } } + private fun funcStringCompare(call: PtBuiltinFunctionCall): ExpressionCodeResult { +/* + loadm.w r65500,string.compare.st1 + loadm.w r65501,string.compare.st2 + syscall 29 + returnreg.b r0 + */ + val result = mutableListOf() + val left = exprGen.translateExpression(call.args[0]) + val right = exprGen.translateExpression(call.args[1]) + addToResult(result, left, 65500, -1) + addToResult(result, right, 65501, -1) + addInstr(result, IRInstruction(Opcode.SYSCALL, value=IMSyscall.COMPARE_STRINGS.number), null) + return ExpressionCodeResult(result, IRDataType.BYTE, 0, -1) + } + private fun funcCmp(call: PtBuiltinFunctionCall): ExpressionCodeResult { val result = mutableListOf() val leftTr = exprGen.translateExpression(call.args[0]) diff --git a/compiler/src/prog8/compiler/Compiler.kt b/compiler/src/prog8/compiler/Compiler.kt index 82b2fd847..bb58ae311 100644 --- a/compiler/src/prog8/compiler/Compiler.kt +++ b/compiler/src/prog8/compiler/Compiler.kt @@ -477,13 +477,17 @@ private fun transformNewExpressions(program: PtProgram) { if(expr.type == expr.left.type) { getExprVar(postfix, expr.type, depth, expr.position, scope) } else { - if(expr.operator in ComparisonOperators && expr.type == DataType.UBYTE) { + if(expr.operator in ComparisonOperators && expr.type in ByteDatatypes) { // this is very common and should be dealth with correctly; byte==0, word>42 - getExprVar(postfix, expr.left.type, depth, expr.position, scope) + val varType = if(expr.left.type in PassByReferenceDatatypes) DataType.UWORD else expr.left.type + getExprVar(postfix, varType, depth, expr.position, scope) } else if(expr.left.type in PassByReferenceDatatypes && expr.type==DataType.UBYTE) { - // this is common and should be dealth with correctly; for instance "name"=="irmen" - getExprVar(postfix, expr.left.type, depth, expr.position, scope) + // this is common and should be dealth with correctly; for instance "name"=="john" + val varType = if (expr.left.type in PassByReferenceDatatypes) DataType.UWORD else expr.left.type + getExprVar(postfix, varType, depth, expr.position, scope) + } else if(expr.left.type equalsSize expr.type) { + getExprVar(postfix, expr.type, depth, expr.position, scope) } else { TODO("expression type differs from left operand type! got ${expr.left.type} expected ${expr.type} ${expr.position}") } diff --git a/compiler/src/prog8/compiler/astprocessing/BeforeAsmAstChanger.kt b/compiler/src/prog8/compiler/astprocessing/BeforeAsmAstChanger.kt index bcbc62e23..79dca48b1 100644 --- a/compiler/src/prog8/compiler/astprocessing/BeforeAsmAstChanger.kt +++ b/compiler/src/prog8/compiler/astprocessing/BeforeAsmAstChanger.kt @@ -36,6 +36,19 @@ internal class BeforeAsmAstChanger(val program: Program, return noModifications } + override fun after(expr: BinaryExpression, parent: Node): Iterable { + if(expr.operator in ComparisonOperators && expr.left.inferType(program) istype DataType.STR && expr.right.inferType(program) istype DataType.STR) { + // replace string comparison expressions with calls to string.compare() + val stringCompare = BuiltinFunctionCall( + IdentifierReference(listOf("prog8_lib_stringcompare"), expr.position), + mutableListOf(expr.left.copy(), expr.right.copy()), expr.position) + val zero = NumericLiteral.optimalInteger(0, expr.position) + val comparison = BinaryExpression(stringCompare, expr.operator, zero, expr.position) + return listOf(IAstModification.ReplaceNode(expr, comparison, parent)) + } + return noModifications + } + override fun after(decl: VarDecl, parent: Node): Iterable { if (decl.type == VarDeclType.VAR && decl.value != null && decl.datatype in NumericDatatypes) throw InternalCompilerException("vardecls for variables, with initial numerical value, should have been rewritten as plain vardecl + assignment $decl") diff --git a/examples/test.p8 b/examples/test.p8 index 37302c67a..ac1286195 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -1,6 +1,7 @@ %zeropage basicsafe %import textio %import floats +%import string main { sub start() { @@ -9,133 +10,147 @@ main { word w = -20000 uword uw = 2000 float f = -100 + str name="john" - txt.print("all 1: ") - txt.print_ub(b == -100) - txt.print_ub(b != -99) - txt.print_ub(b < -99) - txt.print_ub(b <= -100) - txt.print_ub(b > -101) - txt.print_ub(b >= -100) - txt.print_ub(ub ==20) - txt.print_ub(ub !=19) - txt.print_ub(ub <21) - txt.print_ub(ub <=20) - txt.print_ub(ub>19) - txt.print_ub(ub>=20) - txt.spc() - txt.print_ub(w == -20000) - txt.print_ub(w != -19999) - txt.print_ub(w < -19999) - txt.print_ub(w <= -20000) - txt.print_ub(w > -20001) - txt.print_ub(w >= -20000) - txt.print_ub(uw == 2000) - txt.print_ub(uw != 2001) - txt.print_ub(uw < 2001) - txt.print_ub(uw <= 2000) - txt.print_ub(uw > 1999) - txt.print_ub(uw >= 2000) - txt.spc() - txt.print_ub(f == -100.0) - txt.print_ub(f != -99.0) - txt.print_ub(f < -99.0) - txt.print_ub(f <= -100.0) - txt.print_ub(f > -101.0) - txt.print_ub(f >= -100.0) - txt.nl() - - txt.print("all 0: ") - txt.print_ub(b == -99) - txt.print_ub(b != -100) - txt.print_ub(b < -100) - txt.print_ub(b <= -101) - txt.print_ub(b > -100) - txt.print_ub(b >= -99) - txt.print_ub(ub ==21) - txt.print_ub(ub !=20) - txt.print_ub(ub <20) - txt.print_ub(ub <=19) - txt.print_ub(ub>20) - txt.print_ub(ub>=21) - txt.spc() - txt.print_ub(w == -20001) - txt.print_ub(w != -20000) - txt.print_ub(w < -20000) - txt.print_ub(w <= -20001) - txt.print_ub(w > -20000) - txt.print_ub(w >= -19999) - txt.print_ub(uw == 1999) - txt.print_ub(uw != 2000) - txt.print_ub(uw < 2000) - txt.print_ub(uw <= 1999) - txt.print_ub(uw > 2000) - txt.print_ub(uw >= 2001) - txt.spc() - txt.print_ub(f == -99.0) - txt.print_ub(f != -100.0) - txt.print_ub(f < -100.0) - txt.print_ub(f <= -101.0) - txt.print_ub(f > -100.0) - txt.print_ub(f >= -99.0) - txt.nl() - - ; TODO ALL OF THE ABOVE BUT WITH A VARIABLE INSTEAD OF A CONST VALUE - - - b = -100 - while b <= -20 - b++ - txt.print_b(b) - txt.print(" -19\n") - b = -100 - while b < -20 - b++ - txt.print_b(b) - txt.print(" -20\n") - - ub = 20 - while ub <= 200 - ub++ - txt.print_ub(ub) - txt.print(" 201\n") - ub = 20 - while ub < 200 - ub++ - txt.print_ub(ub) - txt.print(" 200\n") - - w = -20000 - while w <= -8000 { - w++ + if (string.compare(name, "aaa")==0) or (string.compare(name, "john")==0) or (string.compare(name, "bbb")==0) { + txt.print("name1 ok\n") } - txt.print_w(w) - txt.print(" -7999\n") - w = -20000 - while w < -8000 { - w++ + if (string.compare(name, "aaa")==0) or (string.compare(name, "zzz")==0) or (string.compare(name, "bbb")==0) { + txt.print("name2 fail!\n") } - txt.print_w(w) - txt.print(" -8000\n") - uw = 2000 - while uw <= 8000 { - uw++ - } - txt.print_uw(uw) - txt.print(" 8001\n") - uw = 2000 - while uw < 8000 { - uw++ - } - txt.print_uw(uw) - txt.print(" 8000\n") + if name=="aaa" or name=="john" or name=="bbb" ; TODO fix this result on C64 target, no newexpr! + txt.print("name1b ok\n") + if name=="aaa" or name=="zzz" or name=="bbb" ; TODO fix this result on C64 target, no newexpr! + txt.print("name2b fail!\n") - f = 0.0 - while f<2.2 { - f+=0.1 - } - floats.print_f(f) - txt.print(" 2.2\n") + +; txt.print("all 1: ") +; txt.print_ub(b == -100) +; txt.print_ub(b != -99) +; txt.print_ub(b < -99) +; txt.print_ub(b <= -100) +; txt.print_ub(b > -101) +; txt.print_ub(b >= -100) +; txt.print_ub(ub ==20) +; txt.print_ub(ub !=19) +; txt.print_ub(ub <21) +; txt.print_ub(ub <=20) +; txt.print_ub(ub>19) +; txt.print_ub(ub>=20) +; txt.spc() +; txt.print_ub(w == -20000) +; txt.print_ub(w != -19999) +; txt.print_ub(w < -19999) +; txt.print_ub(w <= -20000) +; txt.print_ub(w > -20001) +; txt.print_ub(w >= -20000) +; txt.print_ub(uw == 2000) +; txt.print_ub(uw != 2001) +; txt.print_ub(uw < 2001) +; txt.print_ub(uw <= 2000) +; txt.print_ub(uw > 1999) +; txt.print_ub(uw >= 2000) +; txt.spc() +; txt.print_ub(f == -100.0) +; txt.print_ub(f != -99.0) +; txt.print_ub(f < -99.0) +; txt.print_ub(f <= -100.0) +; txt.print_ub(f > -101.0) +; txt.print_ub(f >= -100.0) +; txt.nl() +; +; txt.print("all 0: ") +; txt.print_ub(b == -99) +; txt.print_ub(b != -100) +; txt.print_ub(b < -100) +; txt.print_ub(b <= -101) +; txt.print_ub(b > -100) +; txt.print_ub(b >= -99) +; txt.print_ub(ub ==21) +; txt.print_ub(ub !=20) +; txt.print_ub(ub <20) +; txt.print_ub(ub <=19) +; txt.print_ub(ub>20) +; txt.print_ub(ub>=21) +; txt.spc() +; txt.print_ub(w == -20001) +; txt.print_ub(w != -20000) +; txt.print_ub(w < -20000) +; txt.print_ub(w <= -20001) +; txt.print_ub(w > -20000) +; txt.print_ub(w >= -19999) +; txt.print_ub(uw == 1999) +; txt.print_ub(uw != 2000) +; txt.print_ub(uw < 2000) +; txt.print_ub(uw <= 1999) +; txt.print_ub(uw > 2000) +; txt.print_ub(uw >= 2001) +; txt.spc() +; txt.print_ub(f == -99.0) +; txt.print_ub(f != -100.0) +; txt.print_ub(f < -100.0) +; txt.print_ub(f <= -101.0) +; txt.print_ub(f > -100.0) +; txt.print_ub(f >= -99.0) +; txt.nl() +; +; ; TODO ALL OF THE ABOVE BUT WITH A VARIABLE INSTEAD OF A CONST VALUE +; +; +; b = -100 +; while b <= -20 +; b++ +; txt.print_b(b) +; txt.print(" -19\n") +; b = -100 +; while b < -20 +; b++ +; txt.print_b(b) +; txt.print(" -20\n") +; +; ub = 20 +; while ub <= 200 +; ub++ +; txt.print_ub(ub) +; txt.print(" 201\n") +; ub = 20 +; while ub < 200 +; ub++ +; txt.print_ub(ub) +; txt.print(" 200\n") +; +; w = -20000 +; while w <= -8000 { +; w++ +; } +; txt.print_w(w) +; txt.print(" -7999\n") +; w = -20000 +; while w < -8000 { +; w++ +; } +; txt.print_w(w) +; txt.print(" -8000\n") +; +; uw = 2000 +; while uw <= 8000 { +; uw++ +; } +; txt.print_uw(uw) +; txt.print(" 8001\n") +; uw = 2000 +; while uw < 8000 { +; uw++ +; } +; txt.print_uw(uw) +; txt.print(" 8000\n") +; +; f = 0.0 +; while f<2.2 { +; f+=0.1 +; } +; floats.print_f(f) +; txt.print(" 2.2\n") } }