diff --git a/codeGenIntermediate/src/prog8/codegen/intermediate/ExpressionGen.kt b/codeGenIntermediate/src/prog8/codegen/intermediate/ExpressionGen.kt index 08824e846..298b8318b 100644 --- a/codeGenIntermediate/src/prog8/codegen/intermediate/ExpressionGen.kt +++ b/codeGenIntermediate/src/prog8/codegen/intermediate/ExpressionGen.kt @@ -473,7 +473,7 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) { addInstr(result, IRInstruction(ins, IRDataType.BYTE, reg1 = resultRegister, reg2 = zeroRegister), null) return ExpressionCodeResult(result, IRDataType.BYTE, resultRegister, -1) } else { - if(binExpr.left.type==DataType.STR && binExpr.right.type==DataType.STR) { + if(binExpr.left.type==DataType.STR || binExpr.right.type==DataType.STR) { throw AssemblyError("str compares should have been replaced with builtin function call to do the compare") } else { val leftTr = translateExpression(binExpr.left) @@ -515,7 +515,7 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) { addInstr(result, IRInstruction(ins, IRDataType.BYTE, reg1 = resultRegister, reg2 = zeroRegister), null) return ExpressionCodeResult(result, IRDataType.BYTE, resultRegister, -1) } else { - if(binExpr.left.type==DataType.STR && binExpr.right.type==DataType.STR) { + if(binExpr.left.type==DataType.STR || binExpr.right.type==DataType.STR) { throw AssemblyError("str compares should have been replaced with builtin function call to do the compare") } else { val leftTr = translateExpression(binExpr.left) @@ -554,7 +554,7 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) { result += IRCodeChunk(label, null) return ExpressionCodeResult(result, IRDataType.BYTE, resultRegister, -1) } else { - if(binExpr.left.type==DataType.STR && binExpr.right.type==DataType.STR) { + if(binExpr.left.type==DataType.STR || binExpr.right.type==DataType.STR) { throw AssemblyError("str compares should have been replaced with builtin function call to do the compare") } else { return if(constValue(binExpr.right)==0.0) { diff --git a/compiler/src/prog8/compiler/astprocessing/AstChecker.kt b/compiler/src/prog8/compiler/astprocessing/AstChecker.kt index c78703ca4..7a3a6bd54 100644 --- a/compiler/src/prog8/compiler/astprocessing/AstChecker.kt +++ b/compiler/src/prog8/compiler/astprocessing/AstChecker.kt @@ -966,6 +966,8 @@ internal class AstChecker(private val program: Program, // expression with one side BOOL other side (U)BYTE is allowed; bool==byte } else if((expr.operator == "<<" || expr.operator == ">>") && (leftDt in WordDatatypes && rightDt in ByteDatatypes)) { // exception allowed: shifting a word by a byte + } else if((leftDt==DataType.UWORD && rightDt==DataType.STR) || (leftDt==DataType.STR && rightDt==DataType.UWORD)) { + // exception allowed: comparing uword (pointer) with string } else { errors.err("left and right operands aren't the same type", expr.left.position) } diff --git a/compiler/src/prog8/compiler/astprocessing/BeforeAsmAstChanger.kt b/compiler/src/prog8/compiler/astprocessing/BeforeAsmAstChanger.kt index d4369edce..e2a06d0d3 100644 --- a/compiler/src/prog8/compiler/astprocessing/BeforeAsmAstChanger.kt +++ b/compiler/src/prog8/compiler/astprocessing/BeforeAsmAstChanger.kt @@ -36,19 +36,6 @@ internal class BeforeAsmAstChanger(val program: Program, return noModifications } - override fun after(expr: BinaryExpression, parent: Node): Iterable { - if(expr.operator in ComparisonOperators && expr.left.inferType(program) istype DataType.STR && expr.right.inferType(program) istype DataType.STR) { - // replace string comparison expressions with calls to string.compare() - val stringCompare = BuiltinFunctionCall( - IdentifierReference(listOf("prog8_lib_stringcompare"), expr.position), - mutableListOf(expr.left.copy(), expr.right.copy()), expr.position) - val zero = NumericLiteral.optimalInteger(0, expr.position) - val comparison = BinaryExpression(stringCompare, expr.operator, zero, expr.position) - return listOf(IAstModification.ReplaceNode(expr, comparison, parent)) - } - return noModifications - } - override fun after(decl: VarDecl, parent: Node): Iterable { if (decl.type == VarDeclType.VAR && decl.value != null && decl.datatype in NumericDatatypes) throw InternalCompilerException("vardecls for variables, with initial numerical value, should have been rewritten as plain vardecl + assignment $decl") diff --git a/compiler/src/prog8/compiler/astprocessing/CodeDesugarer.kt b/compiler/src/prog8/compiler/astprocessing/CodeDesugarer.kt index b8d081bc3..7b9a5cb52 100644 --- a/compiler/src/prog8/compiler/astprocessing/CodeDesugarer.kt +++ b/compiler/src/prog8/compiler/astprocessing/CodeDesugarer.kt @@ -5,6 +5,7 @@ import prog8.ast.expressions.* import prog8.ast.statements.* import prog8.ast.walk.AstWalker import prog8.ast.walk.IAstModification +import prog8.code.core.ComparisonOperators import prog8.code.core.DataType import prog8.code.core.IErrorReporter import prog8.code.core.Position @@ -173,10 +174,32 @@ _after: } override fun after(expr: BinaryExpression, parent: Node): Iterable { + fun isStringComparison(leftDt: InferredTypes.InferredType, rightDt: InferredTypes.InferredType): Boolean = + if(leftDt istype DataType.STR && rightDt istype DataType.STR) + true + else + leftDt istype DataType.UWORD && rightDt istype DataType.STR || leftDt istype DataType.STR && rightDt istype DataType.UWORD + if(expr.operator=="in") { val containment = ContainmentCheck(expr.left, expr.right, expr.position) return listOf(IAstModification.ReplaceNode(expr, containment, parent)) } + + if(expr.operator in ComparisonOperators) { + val leftDt = expr.left.inferType(program) + val rightDt = expr.right.inferType(program) + + if(isStringComparison(leftDt, rightDt)) { + // replace string comparison expressions with calls to string.compare() + val stringCompare = BuiltinFunctionCall( + IdentifierReference(listOf("prog8_lib_stringcompare"), expr.position), + mutableListOf(expr.left.copy(), expr.right.copy()), expr.position) + val zero = NumericLiteral.optimalInteger(0, expr.position) + val comparison = BinaryExpression(stringCompare, expr.operator, zero, expr.position) + return listOf(IAstModification.ReplaceNode(expr, comparison, parent)) + } + } + return noModifications } } diff --git a/compiler/test/ast/TestVariousCompilerAst.kt b/compiler/test/ast/TestVariousCompilerAst.kt index ebdea7e7b..a3d0cdaa0 100644 --- a/compiler/test/ast/TestVariousCompilerAst.kt +++ b/compiler/test/ast/TestVariousCompilerAst.kt @@ -13,6 +13,7 @@ import prog8.ast.statements.VarDecl import prog8.code.core.DataType import prog8.code.core.Position import prog8.code.target.C64Target +import prog8.code.target.VMTarget import prog8tests.helpers.ErrorReporterForTests import prog8tests.helpers.compileText @@ -70,25 +71,42 @@ main { compileText(C64Target(), false, text, writeAssembly = false) shouldBe null } - test("simple string comparison still works") { + test("string comparisons") { val src=""" - main { - sub start() { - ubyte @shared value - str thing = "????" - - if thing=="name" { - value++ - } - - if thing!="name" { - value++ - } - } - }""" +main { + + sub start() { + str name = "name" + uword nameptr = &name + + cx16.r0L= name=="foo" + cx16.r1L= name!="foo" + cx16.r2L= name<"foo" + cx16.r3L= name>"foo" + + cx16.r0L= nameptr=="foo" + cx16.r1L= nameptr!="foo" + cx16.r2L= nameptr<"foo" + cx16.r3L= nameptr>"foo" + + void compare(name, "foo") + void compare(name, "name") + void compare(nameptr, "foo") + void compare(nameptr, "name") + } + + sub compare(str s1, str s2) -> ubyte { + if s1==s2 + return 42 + return 0 + } +}""" val result = compileText(C64Target(), optimize=false, src, writeAssembly=true)!! val stmts = result.compilerAst.entrypoint.statements - stmts.size shouldBe 6 + stmts.size shouldBe 16 + val result2 = compileText(VMTarget(), optimize=false, src, writeAssembly=true)!! + val stmts2 = result2.compilerAst.entrypoint.statements + stmts2.size shouldBe 16 } test("string concatenation and repeats") { diff --git a/docs/source/libraries.rst b/docs/source/libraries.rst index 31da82991..4507019d6 100644 --- a/docs/source/libraries.rst +++ b/docs/source/libraries.rst @@ -218,6 +218,7 @@ Provides string manipulation routines. Returns -1, 0 or 1 depending on whether string1 sorts before, equal or after string2. Note that you can also directly compare strings and string values with each other using ``==``, ``<`` etcetera (it will use string.compare for you under water automatically). + This even works when dealing with uword (pointer) variables when comparing them to a string type. ``copy (from, to) -> ubyte length`` Copy a string to another, overwriting that one. Returns the length of the string that was copied. diff --git a/docs/source/todo.rst b/docs/source/todo.rst index 7e04dc0cf..a1fa6e3d2 100644 --- a/docs/source/todo.rst +++ b/docs/source/todo.rst @@ -1,6 +1,8 @@ TODO ==== +- replace all the string.compare calls in rockrunner with equalites + ... diff --git a/examples/test.p8 b/examples/test.p8 index ba7dbbbd1..411179e0e 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -4,38 +4,29 @@ main { sub start() { - txt.print_ub(danglingelse(32)) - txt.spc() - txt.print_ub(danglingelse(99)) - txt.spc() - txt.print_ub(danglingelse(1)) - txt.spc() - txt.print_ub(danglingelse(100)) - txt.nl() - txt.print_ub(danglingelse2(32)) - txt.spc() - txt.print_ub(danglingelse2(99)) - txt.spc() - txt.print_ub(danglingelse2(1)) - txt.spc() - txt.print_ub(danglingelse2(100)) - txt.nl() + str name = "name" + uword nameptr = &name + + cx16.r0L= name=="foo" + cx16.r1L= name!="foo" + cx16.r2L= name<"foo" + cx16.r3L= name>"foo" + + cx16.r0L= nameptr=="foo" + cx16.r1L= nameptr!="foo" + cx16.r2L= nameptr<"foo" + cx16.r3L= nameptr>"foo" + + void compare(name, "foo") + void compare(name, "name") + void compare(nameptr, "foo") + void compare(nameptr, "name") } - sub danglingelse(ubyte bb) -> ubyte { - if bb==32 - return 32 - else if bb==99 - return 99 - else - return 0 - } - - sub danglingelse2(ubyte bb) -> ubyte { - if bb==32 - return 32 - if bb==99 - return 99 + sub compare(str s1, str s2) -> ubyte { + if s1==s2 + return 42 return 0 } } +