uword == str is now possible (sugar for string.compare)

This commit is contained in:
Irmen de Jong 2023-06-22 00:20:30 +02:00
parent a587482edf
commit 04e4e71f2e
8 changed files with 86 additions and 62 deletions

View File

@ -473,7 +473,7 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) {
addInstr(result, IRInstruction(ins, IRDataType.BYTE, reg1 = resultRegister, reg2 = zeroRegister), null)
return ExpressionCodeResult(result, IRDataType.BYTE, resultRegister, -1)
} else {
if(binExpr.left.type==DataType.STR && binExpr.right.type==DataType.STR) {
if(binExpr.left.type==DataType.STR || binExpr.right.type==DataType.STR) {
throw AssemblyError("str compares should have been replaced with builtin function call to do the compare")
} else {
val leftTr = translateExpression(binExpr.left)
@ -515,7 +515,7 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) {
addInstr(result, IRInstruction(ins, IRDataType.BYTE, reg1 = resultRegister, reg2 = zeroRegister), null)
return ExpressionCodeResult(result, IRDataType.BYTE, resultRegister, -1)
} else {
if(binExpr.left.type==DataType.STR && binExpr.right.type==DataType.STR) {
if(binExpr.left.type==DataType.STR || binExpr.right.type==DataType.STR) {
throw AssemblyError("str compares should have been replaced with builtin function call to do the compare")
} else {
val leftTr = translateExpression(binExpr.left)
@ -554,7 +554,7 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) {
result += IRCodeChunk(label, null)
return ExpressionCodeResult(result, IRDataType.BYTE, resultRegister, -1)
} else {
if(binExpr.left.type==DataType.STR && binExpr.right.type==DataType.STR) {
if(binExpr.left.type==DataType.STR || binExpr.right.type==DataType.STR) {
throw AssemblyError("str compares should have been replaced with builtin function call to do the compare")
} else {
return if(constValue(binExpr.right)==0.0) {

View File

@ -966,6 +966,8 @@ internal class AstChecker(private val program: Program,
// expression with one side BOOL other side (U)BYTE is allowed; bool==byte
} else if((expr.operator == "<<" || expr.operator == ">>") && (leftDt in WordDatatypes && rightDt in ByteDatatypes)) {
// exception allowed: shifting a word by a byte
} else if((leftDt==DataType.UWORD && rightDt==DataType.STR) || (leftDt==DataType.STR && rightDt==DataType.UWORD)) {
// exception allowed: comparing uword (pointer) with string
} else {
errors.err("left and right operands aren't the same type", expr.left.position)
}

View File

@ -36,19 +36,6 @@ internal class BeforeAsmAstChanger(val program: Program,
return noModifications
}
override fun after(expr: BinaryExpression, parent: Node): Iterable<IAstModification> {
if(expr.operator in ComparisonOperators && expr.left.inferType(program) istype DataType.STR && expr.right.inferType(program) istype DataType.STR) {
// replace string comparison expressions with calls to string.compare()
val stringCompare = BuiltinFunctionCall(
IdentifierReference(listOf("prog8_lib_stringcompare"), expr.position),
mutableListOf(expr.left.copy(), expr.right.copy()), expr.position)
val zero = NumericLiteral.optimalInteger(0, expr.position)
val comparison = BinaryExpression(stringCompare, expr.operator, zero, expr.position)
return listOf(IAstModification.ReplaceNode(expr, comparison, parent))
}
return noModifications
}
override fun after(decl: VarDecl, parent: Node): Iterable<IAstModification> {
if (decl.type == VarDeclType.VAR && decl.value != null && decl.datatype in NumericDatatypes)
throw InternalCompilerException("vardecls for variables, with initial numerical value, should have been rewritten as plain vardecl + assignment $decl")

View File

@ -5,6 +5,7 @@ import prog8.ast.expressions.*
import prog8.ast.statements.*
import prog8.ast.walk.AstWalker
import prog8.ast.walk.IAstModification
import prog8.code.core.ComparisonOperators
import prog8.code.core.DataType
import prog8.code.core.IErrorReporter
import prog8.code.core.Position
@ -173,10 +174,32 @@ _after:
}
override fun after(expr: BinaryExpression, parent: Node): Iterable<IAstModification> {
fun isStringComparison(leftDt: InferredTypes.InferredType, rightDt: InferredTypes.InferredType): Boolean =
if(leftDt istype DataType.STR && rightDt istype DataType.STR)
true
else
leftDt istype DataType.UWORD && rightDt istype DataType.STR || leftDt istype DataType.STR && rightDt istype DataType.UWORD
if(expr.operator=="in") {
val containment = ContainmentCheck(expr.left, expr.right, expr.position)
return listOf(IAstModification.ReplaceNode(expr, containment, parent))
}
if(expr.operator in ComparisonOperators) {
val leftDt = expr.left.inferType(program)
val rightDt = expr.right.inferType(program)
if(isStringComparison(leftDt, rightDt)) {
// replace string comparison expressions with calls to string.compare()
val stringCompare = BuiltinFunctionCall(
IdentifierReference(listOf("prog8_lib_stringcompare"), expr.position),
mutableListOf(expr.left.copy(), expr.right.copy()), expr.position)
val zero = NumericLiteral.optimalInteger(0, expr.position)
val comparison = BinaryExpression(stringCompare, expr.operator, zero, expr.position)
return listOf(IAstModification.ReplaceNode(expr, comparison, parent))
}
}
return noModifications
}
}

View File

@ -13,6 +13,7 @@ import prog8.ast.statements.VarDecl
import prog8.code.core.DataType
import prog8.code.core.Position
import prog8.code.target.C64Target
import prog8.code.target.VMTarget
import prog8tests.helpers.ErrorReporterForTests
import prog8tests.helpers.compileText
@ -70,25 +71,42 @@ main {
compileText(C64Target(), false, text, writeAssembly = false) shouldBe null
}
test("simple string comparison still works") {
test("string comparisons") {
val src="""
main {
sub start() {
ubyte @shared value
str thing = "????"
if thing=="name" {
value++
}
if thing!="name" {
value++
}
}
}"""
main {
sub start() {
str name = "name"
uword nameptr = &name
cx16.r0L= name=="foo"
cx16.r1L= name!="foo"
cx16.r2L= name<"foo"
cx16.r3L= name>"foo"
cx16.r0L= nameptr=="foo"
cx16.r1L= nameptr!="foo"
cx16.r2L= nameptr<"foo"
cx16.r3L= nameptr>"foo"
void compare(name, "foo")
void compare(name, "name")
void compare(nameptr, "foo")
void compare(nameptr, "name")
}
sub compare(str s1, str s2) -> ubyte {
if s1==s2
return 42
return 0
}
}"""
val result = compileText(C64Target(), optimize=false, src, writeAssembly=true)!!
val stmts = result.compilerAst.entrypoint.statements
stmts.size shouldBe 6
stmts.size shouldBe 16
val result2 = compileText(VMTarget(), optimize=false, src, writeAssembly=true)!!
val stmts2 = result2.compilerAst.entrypoint.statements
stmts2.size shouldBe 16
}
test("string concatenation and repeats") {

View File

@ -218,6 +218,7 @@ Provides string manipulation routines.
Returns -1, 0 or 1 depending on whether string1 sorts before, equal or after string2.
Note that you can also directly compare strings and string values with each other
using ``==``, ``<`` etcetera (it will use string.compare for you under water automatically).
This even works when dealing with uword (pointer) variables when comparing them to a string type.
``copy (from, to) -> ubyte length``
Copy a string to another, overwriting that one. Returns the length of the string that was copied.

View File

@ -1,6 +1,8 @@
TODO
====
- replace all the string.compare calls in rockrunner with equalites
...

View File

@ -4,38 +4,29 @@
main {
sub start() {
txt.print_ub(danglingelse(32))
txt.spc()
txt.print_ub(danglingelse(99))
txt.spc()
txt.print_ub(danglingelse(1))
txt.spc()
txt.print_ub(danglingelse(100))
txt.nl()
txt.print_ub(danglingelse2(32))
txt.spc()
txt.print_ub(danglingelse2(99))
txt.spc()
txt.print_ub(danglingelse2(1))
txt.spc()
txt.print_ub(danglingelse2(100))
txt.nl()
str name = "name"
uword nameptr = &name
cx16.r0L= name=="foo"
cx16.r1L= name!="foo"
cx16.r2L= name<"foo"
cx16.r3L= name>"foo"
cx16.r0L= nameptr=="foo"
cx16.r1L= nameptr!="foo"
cx16.r2L= nameptr<"foo"
cx16.r3L= nameptr>"foo"
void compare(name, "foo")
void compare(name, "name")
void compare(nameptr, "foo")
void compare(nameptr, "name")
}
sub danglingelse(ubyte bb) -> ubyte {
if bb==32
return 32
else if bb==99
return 99
else
return 0
}
sub danglingelse2(ubyte bb) -> ubyte {
if bb==32
return 32
if bb==99
return 99
sub compare(str s1, str s2) -> ubyte {
if s1==s2
return 42
return 0
}
}