string can be compared directly (uses strcmp() automatically in asm)

This commit is contained in:
Irmen de Jong 2020-10-17 01:30:49 +02:00
parent 2d3b7eb878
commit 7cb4100419
5 changed files with 131 additions and 42 deletions

View File

@ -1991,7 +1991,7 @@ strcpy .proc
strcmp_mem .proc strcmp_mem .proc
; -- compares strings in s1 (AY) and s2 (PZP_SCRATCH_W2). ; -- compares strings in s1 (AY) and s2 (P8ZP_SCRATCH_W2).
; Returns -1,0,1 in A, depeding on the ordering. Clobbers Y. ; Returns -1,0,1 in A, depeding on the ordering. Clobbers Y.
sta P8ZP_SCRATCH_W1 sta P8ZP_SCRATCH_W1
sty P8ZP_SCRATCH_W1+1 sty P8ZP_SCRATCH_W1+1

View File

@ -11,7 +11,7 @@ internal class LiteralsToAutoVars(private val program: Program) : AstWalker() {
private val noModifications = emptyList<IAstModification>() private val noModifications = emptyList<IAstModification>()
override fun after(string: StringLiteralValue, parent: Node): Iterable<IAstModification> { override fun after(string: StringLiteralValue, parent: Node): Iterable<IAstModification> {
if(string.parent !is VarDecl) { if(string.parent !is VarDecl && string.parent !is WhenChoice) {
// replace the literal string by a identifier reference to a new local vardecl // replace the literal string by a identifier reference to a new local vardecl
val vardecl = VarDecl.createAuto(string) val vardecl = VarDecl.createAuto(string)
val identifier = IdentifierReference(listOf(vardecl.name), vardecl.position) val identifier = IdentifierReference(listOf(vardecl.name), vardecl.position)

View File

@ -898,7 +898,10 @@ class WhenStatement(var condition: Expression,
if(choice.values==null) if(choice.values==null)
result.add(null to choice) result.add(null to choice)
else { else {
val values = choice.values!!.map { it.constValue(program)?.number?.toInt() } val values = choice.values!!.map {
val cv = it.constValue(program)
cv?.number?.toInt() ?: it.hashCode() // the hashcode is a nonsensical number but it avoids weird AST validation errors later
}
if(values.contains(null)) if(values.contains(null))
result.add(null to choice) result.add(null to choice)
else else
@ -924,9 +927,15 @@ class WhenChoice(var values: List<Expression>?, // if null, this is t
} }
override fun replaceChildNode(node: Node, replacement: Node) { override fun replaceChildNode(node: Node, replacement: Node) {
require(replacement is AnonymousScope && node===statements) val choiceValues = values
statements = replacement if(replacement is AnonymousScope && node===statements) {
replacement.parent = this statements = replacement
replacement.parent = this
} else if(choiceValues!=null && node in choiceValues) {
throw FatalAstException("cannot replace choice values")
} else {
throw FatalAstException("invalid replacement")
}
} }
override fun toString(): String { override fun toString(): String {

View File

@ -90,9 +90,7 @@ internal class ExpressionsAsmGen(private val program: Program, private val asmge
in ByteDatatypes -> translateByteEquals(left, right, leftConstVal, rightConstVal, jumpIfFalseLabel) in ByteDatatypes -> translateByteEquals(left, right, leftConstVal, rightConstVal, jumpIfFalseLabel)
in WordDatatypes -> translateWordEquals(left, right, leftConstVal, rightConstVal, jumpIfFalseLabel) in WordDatatypes -> translateWordEquals(left, right, leftConstVal, rightConstVal, jumpIfFalseLabel)
DataType.FLOAT -> translateFloatEquals(left, right, leftConstVal, rightConstVal, jumpIfFalseLabel) DataType.FLOAT -> translateFloatEquals(left, right, leftConstVal, rightConstVal, jumpIfFalseLabel)
DataType.STR -> { DataType.STR -> translateStringEquals(left as IdentifierReference, right as IdentifierReference, jumpIfFalseLabel)
TODO("strcmp ==")
}
else -> throw AssemblyError("weird operand datatype") else -> throw AssemblyError("weird operand datatype")
} }
} }
@ -134,9 +132,7 @@ internal class ExpressionsAsmGen(private val program: Program, private val asmge
in ByteDatatypes -> translateByteNotEquals(left, right, leftConstVal, rightConstVal, jumpIfFalseLabel) in ByteDatatypes -> translateByteNotEquals(left, right, leftConstVal, rightConstVal, jumpIfFalseLabel)
in WordDatatypes -> translateWordNotEquals(left, right, leftConstVal, rightConstVal, jumpIfFalseLabel) in WordDatatypes -> translateWordNotEquals(left, right, leftConstVal, rightConstVal, jumpIfFalseLabel)
DataType.FLOAT -> translateFloatNotEquals(left, right, leftConstVal, rightConstVal, jumpIfFalseLabel) DataType.FLOAT -> translateFloatNotEquals(left, right, leftConstVal, rightConstVal, jumpIfFalseLabel)
DataType.STR -> { DataType.STR -> translateStringNotEquals(left as IdentifierReference, right as IdentifierReference, jumpIfFalseLabel)
TODO("strcmp !=")
}
else -> throw AssemblyError("weird operand datatype") else -> throw AssemblyError("weird operand datatype")
} }
} }
@ -151,9 +147,7 @@ internal class ExpressionsAsmGen(private val program: Program, private val asmge
translateExpression(right) translateExpression(right)
asmgen.out(" jsr floats.less_f | inx | lda P8ESTACK_LO,x | beq $jumpIfFalseLabel") asmgen.out(" jsr floats.less_f | inx | lda P8ESTACK_LO,x | beq $jumpIfFalseLabel")
} }
DataType.STR -> { DataType.STR -> translateStringLess(left as IdentifierReference, right as IdentifierReference, jumpIfFalseLabel)
TODO("strcmp <")
}
else -> throw AssemblyError("weird operand datatype") else -> throw AssemblyError("weird operand datatype")
} }
} }
@ -168,9 +162,7 @@ internal class ExpressionsAsmGen(private val program: Program, private val asmge
translateExpression(right) translateExpression(right)
asmgen.out(" jsr floats.lesseq_f | inx | lda P8ESTACK_LO,x | beq $jumpIfFalseLabel") asmgen.out(" jsr floats.lesseq_f | inx | lda P8ESTACK_LO,x | beq $jumpIfFalseLabel")
} }
DataType.STR -> { DataType.STR -> translateStringLessOrEqual(left as IdentifierReference, right as IdentifierReference, jumpIfFalseLabel)
TODO("strcmp <=")
}
else -> throw AssemblyError("weird operand datatype") else -> throw AssemblyError("weird operand datatype")
} }
} }
@ -185,9 +177,7 @@ internal class ExpressionsAsmGen(private val program: Program, private val asmge
translateExpression(right) translateExpression(right)
asmgen.out(" jsr floats.greater_f | inx | lda P8ESTACK_LO,x | beq $jumpIfFalseLabel") asmgen.out(" jsr floats.greater_f | inx | lda P8ESTACK_LO,x | beq $jumpIfFalseLabel")
} }
DataType.STR -> { DataType.STR -> translateStringGreater(left as IdentifierReference, right as IdentifierReference, jumpIfFalseLabel)
TODO("strcmp >")
}
else -> throw AssemblyError("weird operand datatype") else -> throw AssemblyError("weird operand datatype")
} }
} }
@ -202,9 +192,7 @@ internal class ExpressionsAsmGen(private val program: Program, private val asmge
translateExpression(right) translateExpression(right)
asmgen.out(" jsr floats.greatereq_f | inx | lda P8ESTACK_LO,x | beq $jumpIfFalseLabel") asmgen.out(" jsr floats.greatereq_f | inx | lda P8ESTACK_LO,x | beq $jumpIfFalseLabel")
} }
DataType.STR -> { DataType.STR -> translateStringGreaterOrEqual(left as IdentifierReference, right as IdentifierReference, jumpIfFalseLabel)
TODO("strcmp >=")
}
else -> throw AssemblyError("weird operand datatype") else -> throw AssemblyError("weird operand datatype")
} }
} }
@ -960,6 +948,97 @@ internal class ExpressionsAsmGen(private val program: Program, private val asmge
asmgen.out(" jsr floats.notequal_f | inx | lda P8ESTACK_LO,x | beq $jumpIfFalseLabel") asmgen.out(" jsr floats.notequal_f | inx | lda P8ESTACK_LO,x | beq $jumpIfFalseLabel")
} }
private fun translateStringEquals(left: IdentifierReference, right: IdentifierReference, jumpIfFalseLabel: String) {
val leftNam = asmgen.asmVariableName(left)
val rightNam = asmgen.asmVariableName(right)
asmgen.out("""
lda #<$rightNam
sta P8ZP_SCRATCH_W2
lda #>$rightNam
sta P8ZP_SCRATCH_W2+1
lda #<$leftNam
ldy #>$leftNam
jsr prog8_lib.strcmp_mem
cmp #0
bne $jumpIfFalseLabel""")
}
private fun translateStringNotEquals(left: IdentifierReference, right: IdentifierReference, jumpIfFalseLabel: String) {
val leftNam = asmgen.asmVariableName(left)
val rightNam = asmgen.asmVariableName(right)
asmgen.out("""
lda #<$rightNam
sta P8ZP_SCRATCH_W2
lda #>$rightNam
sta P8ZP_SCRATCH_W2+1
lda #<$leftNam
ldy #>$leftNam
jsr prog8_lib.strcmp_mem
cmp #0
beq $jumpIfFalseLabel""")
}
private fun translateStringLess(left: IdentifierReference, right: IdentifierReference, jumpIfFalseLabel: String) {
val leftNam = asmgen.asmVariableName(left)
val rightNam = asmgen.asmVariableName(right)
asmgen.out("""
lda #<$rightNam
sta P8ZP_SCRATCH_W2
lda #>$rightNam
sta P8ZP_SCRATCH_W2+1
lda #<$leftNam
ldy #>$leftNam
jsr prog8_lib.strcmp_mem
bpl $jumpIfFalseLabel""")
}
private fun translateStringGreater(left: IdentifierReference, right: IdentifierReference, jumpIfFalseLabel: String) {
val leftNam = asmgen.asmVariableName(left)
val rightNam = asmgen.asmVariableName(right)
asmgen.out("""
lda #<$rightNam
sta P8ZP_SCRATCH_W2
lda #>$rightNam
sta P8ZP_SCRATCH_W2+1
lda #<$leftNam
ldy #>$leftNam
jsr prog8_lib.strcmp_mem
beq $jumpIfFalseLabel
bmi $jumpIfFalseLabel""")
}
private fun translateStringLessOrEqual(left: IdentifierReference, right: IdentifierReference, jumpIfFalseLabel: String) {
val leftNam = asmgen.asmVariableName(left)
val rightNam = asmgen.asmVariableName(right)
asmgen.out("""
lda #<$rightNam
sta P8ZP_SCRATCH_W2
lda #>$rightNam
sta P8ZP_SCRATCH_W2+1
lda #<$leftNam
ldy #>$leftNam
jsr prog8_lib.strcmp_mem
beq +
bpl $jumpIfFalseLabel
+""")
}
private fun translateStringGreaterOrEqual(left: IdentifierReference, right: IdentifierReference, jumpIfFalseLabel: String) {
val leftNam = asmgen.asmVariableName(left)
val rightNam = asmgen.asmVariableName(right)
asmgen.out("""
lda #<$rightNam
sta P8ZP_SCRATCH_W2
lda #>$rightNam
sta P8ZP_SCRATCH_W2+1
lda #<$leftNam
ldy #>$leftNam
jsr prog8_lib.strcmp_mem
beq +
bmi $jumpIfFalseLabel
+""")
}
private fun translateFunctionCallResultOntoStack(expression: FunctionCall) { private fun translateFunctionCallResultOntoStack(expression: FunctionCall) {
val functionName = expression.target.nameInSource.last() val functionName = expression.target.nameInSource.last()
val builtinFunc = BuiltinFunctions[functionName] val builtinFunc = BuiltinFunctions[functionName]

View File

@ -31,49 +31,50 @@ main {
; txt.chrout('\n') ; txt.chrout('\n')
if hex1==hex2 if hex1==hex2
goto endlab1 txt.print("1 fail ==\n")
else else
txt.print("not ==") txt.print("1 ok not ==\n")
endlab1: endlab1:
if hex1!=hex2 if hex1!=hex2
goto endlab2 txt.print("2 ok !==\n")
else else
txt.print("not !=") txt.print("2 fail not !=\n")
endlab2: endlab2:
if hex1>=hex2 if hex1>=hex2
goto endlab3 txt.print("3 ok >=\n")
else else
txt.print("not >=") txt.print("3 fail not >=\n")
endlab3: endlab3:
if hex1<=hex2 if hex1<=hex2
goto endlab4 txt.print("4 fail <=\n")
else else
txt.print("not <=") txt.print("4 ok not <=\n")
endlab4: endlab4:
if hex1>hex2 if hex1>hex2
goto endlab5 txt.print("5 ok >\n")
else else
txt.print("not >") txt.print("5 fail not >\n")
endlab5: endlab5:
if hex1<hex2 if hex1<hex2
goto endlab6 txt.print("5 fail <\n")
else else
txt.print("not <") txt.print("6 ok not <\n")
endlab6: endlab6:
txt.chrout('\n')
txt.print_ub(hex1==hex2) txt.print_ub(hex1==hex2)
txt.chrout('\n') txt.print(" 0?\n")
txt.print_ub(hex1!=hex2) txt.print_ub(hex1!=hex2)
txt.chrout('\n') txt.print(" 1?\n")
txt.print_ub(hex1>hex2) txt.print_ub(hex1>hex2)
txt.chrout('\n') txt.print(" 1?\n")
txt.print_ub(hex1<hex2) txt.print_ub(hex1<hex2)
txt.chrout('\n') txt.print(" 0?\n")
txt.print_ub(hex1>=hex2) txt.print_ub(hex1>=hex2)
txt.chrout('\n') txt.print(" 1?\n")
txt.print_ub(hex1<=hex2) txt.print_ub(hex1<=hex2)
txt.chrout('\n') txt.print(" 0?\n")
testX() testX()
} }