removed a problematic bitshift replacement in the const evaluator

This commit is contained in:
Irmen de Jong 2024-01-21 23:05:51 +01:00
parent 87c46ba730
commit a8c09d6144
3 changed files with 79 additions and 45 deletions

View File

@ -251,48 +251,6 @@ class ConstantFoldingOptimizer(private val program: Program, private val errors:
}
}
if(rightconst!=null && (expr.operator=="<<" || expr.operator==">>")) {
val dt = expr.left.inferType(program)
if(dt.isBytes && rightconst.number>=8) {
if(dt.istype(DataType.UBYTE)) {
val zeroUB = NumericLiteral(DataType.UBYTE, 0.0, expr.position)
modifications.add(IAstModification.ReplaceNode(expr, zeroUB, parent))
} else {
if(leftconst!=null) {
val zeroB = NumericLiteral(DataType.BYTE, 0.0, expr.position)
val minusoneB = NumericLiteral(DataType.BYTE, -1.0, expr.position)
if(leftconst.number<0.0) {
if(expr.operator=="<<")
modifications.add(IAstModification.ReplaceNode(expr, zeroB, parent))
else
modifications.add(IAstModification.ReplaceNode(expr, minusoneB, parent))
} else {
modifications.add(IAstModification.ReplaceNode(expr, zeroB, parent))
}
}
}
}
else if(dt.isWords && rightconst.number>=16) {
if(dt.istype(DataType.UWORD)) {
val zeroUW = NumericLiteral(DataType.UWORD, 0.0, expr.position)
modifications.add(IAstModification.ReplaceNode(expr, zeroUW, parent))
} else {
if(leftconst!=null) {
val zeroW = NumericLiteral(DataType.WORD, 0.0, expr.position)
val minusoneW = NumericLiteral(DataType.WORD, -1.0, expr.position)
if(leftconst.number<0.0) {
if(expr.operator=="<<")
modifications.add(IAstModification.ReplaceNode(expr, zeroW, parent))
else
modifications.add(IAstModification.ReplaceNode(expr, minusoneW, parent))
} else {
modifications.add(IAstModification.ReplaceNode(expr, zeroW, parent))
}
}
}
}
}
return modifications
}

View File

@ -817,4 +817,59 @@ main {
if5.condition shouldBe instanceOf<BinaryExpression>()
if6.condition shouldBe instanceOf<BinaryExpression>()
}
test("funky bitshifts") {
val src="""
main {
sub start() {
const uword one = 1
const uword two = 2
uword @shared answer = one * two >> 8
funcw(one * two >> 8)
const uword uw1 = 99
const uword uw2 = 22
uword @shared answer2 = uw1 * uw2 >> 8 ; optimized into msb(uw1*uw2) as uword
funcw(uw1 * uw2 >> 8)
uword @shared uw3 = 99
uword @shared uw4 = 22
uword @shared answer3 = uw3 * uw4 >> 8 ; optimized into msb(uw1*uw2) as uword
funcw(uw3 * uw4 >> 8)
}
sub funcw(uword ww) {
cx16.r0++
}
}"""
val result = compileText(Cx16Target(), true, src, writeAssembly = false)!!
val st = result.compilerAst.entrypoint.statements
st.size shouldBe 17
val answerValue = (st[3] as Assignment).value
answerValue shouldBe NumericLiteral(DataType.UWORD, 0.0, Position.DUMMY)
val funcarg1 = (st[4] as FunctionCallStatement).args.single()
funcarg1 shouldBe NumericLiteral(DataType.UWORD, 0.0, Position.DUMMY)
val answer2Value = (st[8] as Assignment).value
answer2Value shouldBe NumericLiteral(DataType.UWORD, 8.0, Position.DUMMY)
val funcarg2 = (st[9] as FunctionCallStatement).args.single()
funcarg2 shouldBe NumericLiteral(DataType.UWORD, 8.0, Position.DUMMY)
val answer3ValueTc = (st[15] as Assignment).value as TypecastExpression
answer3ValueTc.type shouldBe DataType.UWORD
val answer3Value = answer3ValueTc.expression as FunctionCallExpression
answer3Value.target.nameInSource shouldBe listOf("msb")
answer3Value.args.single() shouldBe instanceOf<BinaryExpression>()
val funcarg3tc = (st[16] as FunctionCallStatement).args.single() as TypecastExpression
funcarg3tc.type shouldBe DataType.UWORD
val funcarg3 = funcarg3tc.expression as FunctionCallExpression
funcarg3.target.nameInSource shouldBe listOf("msb")
funcarg3.args.single() shouldBe instanceOf<BinaryExpression>()
}
})

View File

@ -1,9 +1,30 @@
%import textio
%zeropage basicsafe
%option no_sysinit
main {
sub start() {
ubyte[10] uba = [1,2,3]
bool[10] bba = [true, false, true]
const uword one = 1
const uword two = 2
uword @shared answer = one * two >> 8
txt.print_uw(answer)
txt.spc()
txt.print_uw(one * two >> 8)
txt.nl()
const uword uw1 = 99
const uword uw2 = 22
uword @shared answer2 = uw1 * uw2 >> 8
txt.print_uw(answer2)
txt.spc()
txt.print_uw(uw1 * uw2 >> 8)
txt.nl()
uword @shared uw3 = 99
uword @shared uw4 = 22
uword @shared answer3 = uw3 * uw4 >> 8
txt.print_uw(answer3)
txt.spc()
txt.print_uw(uw3 * uw4 >> 8)
txt.nl()
}
}