string can be compared directly (uses strcmp() automatically in asm)

This commit is contained in:
Irmen de Jong 2020-10-17 01:30:49 +02:00
parent 2d3b7eb878
commit 7cb4100419
5 changed files with 131 additions and 42 deletions

View File

@ -1991,7 +1991,7 @@ strcpy .proc
strcmp_mem .proc
; -- compares strings in s1 (AY) and s2 (PZP_SCRATCH_W2).
; -- compares strings in s1 (AY) and s2 (P8ZP_SCRATCH_W2).
; Returns -1,0,1 in A, depeding on the ordering. Clobbers Y.
sta P8ZP_SCRATCH_W1
sty P8ZP_SCRATCH_W1+1

View File

@ -11,7 +11,7 @@ internal class LiteralsToAutoVars(private val program: Program) : AstWalker() {
private val noModifications = emptyList<IAstModification>()
override fun after(string: StringLiteralValue, parent: Node): Iterable<IAstModification> {
if(string.parent !is VarDecl) {
if(string.parent !is VarDecl && string.parent !is WhenChoice) {
// replace the literal string by a identifier reference to a new local vardecl
val vardecl = VarDecl.createAuto(string)
val identifier = IdentifierReference(listOf(vardecl.name), vardecl.position)

View File

@ -898,7 +898,10 @@ class WhenStatement(var condition: Expression,
if(choice.values==null)
result.add(null to choice)
else {
val values = choice.values!!.map { it.constValue(program)?.number?.toInt() }
val values = choice.values!!.map {
val cv = it.constValue(program)
cv?.number?.toInt() ?: it.hashCode() // the hashcode is a nonsensical number but it avoids weird AST validation errors later
}
if(values.contains(null))
result.add(null to choice)
else
@ -924,9 +927,15 @@ class WhenChoice(var values: List<Expression>?, // if null, this is t
}
override fun replaceChildNode(node: Node, replacement: Node) {
require(replacement is AnonymousScope && node===statements)
statements = replacement
replacement.parent = this
val choiceValues = values
if(replacement is AnonymousScope && node===statements) {
statements = replacement
replacement.parent = this
} else if(choiceValues!=null && node in choiceValues) {
throw FatalAstException("cannot replace choice values")
} else {
throw FatalAstException("invalid replacement")
}
}
override fun toString(): String {

View File

@ -90,9 +90,7 @@ internal class ExpressionsAsmGen(private val program: Program, private val asmge
in ByteDatatypes -> translateByteEquals(left, right, leftConstVal, rightConstVal, jumpIfFalseLabel)
in WordDatatypes -> translateWordEquals(left, right, leftConstVal, rightConstVal, jumpIfFalseLabel)
DataType.FLOAT -> translateFloatEquals(left, right, leftConstVal, rightConstVal, jumpIfFalseLabel)
DataType.STR -> {
TODO("strcmp ==")
}
DataType.STR -> translateStringEquals(left as IdentifierReference, right as IdentifierReference, jumpIfFalseLabel)
else -> throw AssemblyError("weird operand datatype")
}
}
@ -134,9 +132,7 @@ internal class ExpressionsAsmGen(private val program: Program, private val asmge
in ByteDatatypes -> translateByteNotEquals(left, right, leftConstVal, rightConstVal, jumpIfFalseLabel)
in WordDatatypes -> translateWordNotEquals(left, right, leftConstVal, rightConstVal, jumpIfFalseLabel)
DataType.FLOAT -> translateFloatNotEquals(left, right, leftConstVal, rightConstVal, jumpIfFalseLabel)
DataType.STR -> {
TODO("strcmp !=")
}
DataType.STR -> translateStringNotEquals(left as IdentifierReference, right as IdentifierReference, jumpIfFalseLabel)
else -> throw AssemblyError("weird operand datatype")
}
}
@ -151,9 +147,7 @@ internal class ExpressionsAsmGen(private val program: Program, private val asmge
translateExpression(right)
asmgen.out(" jsr floats.less_f | inx | lda P8ESTACK_LO,x | beq $jumpIfFalseLabel")
}
DataType.STR -> {
TODO("strcmp <")
}
DataType.STR -> translateStringLess(left as IdentifierReference, right as IdentifierReference, jumpIfFalseLabel)
else -> throw AssemblyError("weird operand datatype")
}
}
@ -168,9 +162,7 @@ internal class ExpressionsAsmGen(private val program: Program, private val asmge
translateExpression(right)
asmgen.out(" jsr floats.lesseq_f | inx | lda P8ESTACK_LO,x | beq $jumpIfFalseLabel")
}
DataType.STR -> {
TODO("strcmp <=")
}
DataType.STR -> translateStringLessOrEqual(left as IdentifierReference, right as IdentifierReference, jumpIfFalseLabel)
else -> throw AssemblyError("weird operand datatype")
}
}
@ -185,9 +177,7 @@ internal class ExpressionsAsmGen(private val program: Program, private val asmge
translateExpression(right)
asmgen.out(" jsr floats.greater_f | inx | lda P8ESTACK_LO,x | beq $jumpIfFalseLabel")
}
DataType.STR -> {
TODO("strcmp >")
}
DataType.STR -> translateStringGreater(left as IdentifierReference, right as IdentifierReference, jumpIfFalseLabel)
else -> throw AssemblyError("weird operand datatype")
}
}
@ -202,9 +192,7 @@ internal class ExpressionsAsmGen(private val program: Program, private val asmge
translateExpression(right)
asmgen.out(" jsr floats.greatereq_f | inx | lda P8ESTACK_LO,x | beq $jumpIfFalseLabel")
}
DataType.STR -> {
TODO("strcmp >=")
}
DataType.STR -> translateStringGreaterOrEqual(left as IdentifierReference, right as IdentifierReference, jumpIfFalseLabel)
else -> throw AssemblyError("weird operand datatype")
}
}
@ -960,6 +948,97 @@ internal class ExpressionsAsmGen(private val program: Program, private val asmge
asmgen.out(" jsr floats.notequal_f | inx | lda P8ESTACK_LO,x | beq $jumpIfFalseLabel")
}
private fun translateStringEquals(left: IdentifierReference, right: IdentifierReference, jumpIfFalseLabel: String) {
val leftNam = asmgen.asmVariableName(left)
val rightNam = asmgen.asmVariableName(right)
asmgen.out("""
lda #<$rightNam
sta P8ZP_SCRATCH_W2
lda #>$rightNam
sta P8ZP_SCRATCH_W2+1
lda #<$leftNam
ldy #>$leftNam
jsr prog8_lib.strcmp_mem
cmp #0
bne $jumpIfFalseLabel""")
}
private fun translateStringNotEquals(left: IdentifierReference, right: IdentifierReference, jumpIfFalseLabel: String) {
val leftNam = asmgen.asmVariableName(left)
val rightNam = asmgen.asmVariableName(right)
asmgen.out("""
lda #<$rightNam
sta P8ZP_SCRATCH_W2
lda #>$rightNam
sta P8ZP_SCRATCH_W2+1
lda #<$leftNam
ldy #>$leftNam
jsr prog8_lib.strcmp_mem
cmp #0
beq $jumpIfFalseLabel""")
}
private fun translateStringLess(left: IdentifierReference, right: IdentifierReference, jumpIfFalseLabel: String) {
val leftNam = asmgen.asmVariableName(left)
val rightNam = asmgen.asmVariableName(right)
asmgen.out("""
lda #<$rightNam
sta P8ZP_SCRATCH_W2
lda #>$rightNam
sta P8ZP_SCRATCH_W2+1
lda #<$leftNam
ldy #>$leftNam
jsr prog8_lib.strcmp_mem
bpl $jumpIfFalseLabel""")
}
private fun translateStringGreater(left: IdentifierReference, right: IdentifierReference, jumpIfFalseLabel: String) {
val leftNam = asmgen.asmVariableName(left)
val rightNam = asmgen.asmVariableName(right)
asmgen.out("""
lda #<$rightNam
sta P8ZP_SCRATCH_W2
lda #>$rightNam
sta P8ZP_SCRATCH_W2+1
lda #<$leftNam
ldy #>$leftNam
jsr prog8_lib.strcmp_mem
beq $jumpIfFalseLabel
bmi $jumpIfFalseLabel""")
}
private fun translateStringLessOrEqual(left: IdentifierReference, right: IdentifierReference, jumpIfFalseLabel: String) {
val leftNam = asmgen.asmVariableName(left)
val rightNam = asmgen.asmVariableName(right)
asmgen.out("""
lda #<$rightNam
sta P8ZP_SCRATCH_W2
lda #>$rightNam
sta P8ZP_SCRATCH_W2+1
lda #<$leftNam
ldy #>$leftNam
jsr prog8_lib.strcmp_mem
beq +
bpl $jumpIfFalseLabel
+""")
}
private fun translateStringGreaterOrEqual(left: IdentifierReference, right: IdentifierReference, jumpIfFalseLabel: String) {
val leftNam = asmgen.asmVariableName(left)
val rightNam = asmgen.asmVariableName(right)
asmgen.out("""
lda #<$rightNam
sta P8ZP_SCRATCH_W2
lda #>$rightNam
sta P8ZP_SCRATCH_W2+1
lda #<$leftNam
ldy #>$leftNam
jsr prog8_lib.strcmp_mem
beq +
bmi $jumpIfFalseLabel
+""")
}
private fun translateFunctionCallResultOntoStack(expression: FunctionCall) {
val functionName = expression.target.nameInSource.last()
val builtinFunc = BuiltinFunctions[functionName]

View File

@ -31,49 +31,50 @@ main {
; txt.chrout('\n')
if hex1==hex2
goto endlab1
txt.print("1 fail ==\n")
else
txt.print("not ==")
txt.print("1 ok not ==\n")
endlab1:
if hex1!=hex2
goto endlab2
txt.print("2 ok !==\n")
else
txt.print("not !=")
txt.print("2 fail not !=\n")
endlab2:
if hex1>=hex2
goto endlab3
txt.print("3 ok >=\n")
else
txt.print("not >=")
txt.print("3 fail not >=\n")
endlab3:
if hex1<=hex2
goto endlab4
txt.print("4 fail <=\n")
else
txt.print("not <=")
txt.print("4 ok not <=\n")
endlab4:
if hex1>hex2
goto endlab5
txt.print("5 ok >\n")
else
txt.print("not >")
txt.print("5 fail not >\n")
endlab5:
if hex1<hex2
goto endlab6
txt.print("5 fail <\n")
else
txt.print("not <")
txt.print("6 ok not <\n")
endlab6:
txt.chrout('\n')
txt.print_ub(hex1==hex2)
txt.chrout('\n')
txt.print(" 0?\n")
txt.print_ub(hex1!=hex2)
txt.chrout('\n')
txt.print(" 1?\n")
txt.print_ub(hex1>hex2)
txt.chrout('\n')
txt.print(" 1?\n")
txt.print_ub(hex1<hex2)
txt.chrout('\n')
txt.print(" 0?\n")
txt.print_ub(hex1>=hex2)
txt.chrout('\n')
txt.print(" 1?\n")
txt.print_ub(hex1<=hex2)
txt.chrout('\n')
txt.print(" 0?\n")
testX()
}