optimize division by powers of 2 more thoroughly (convert it into bit shifts)

This commit is contained in:
Irmen de Jong
2024-07-21 20:42:48 +02:00
parent 0af17cdc33
commit 3681d6ee1c
7 changed files with 104 additions and 89 deletions

View File

@@ -3,7 +3,7 @@ package prog8.code.core
import kotlin.math.abs import kotlin.math.abs
import kotlin.math.pow import kotlin.math.pow
val powersOfTwoFloat = (1..16).map { (2.0).pow(it) }.toTypedArray() val powersOfTwoFloat = (0..16).map { (2.0).pow(it) }.toTypedArray()
val negativePowersOfTwoFloat = powersOfTwoFloat.map { -it }.toTypedArray() val negativePowersOfTwoFloat = powersOfTwoFloat.map { -it }.toTypedArray()
val powersOfTwoInt = (0..16).map { 2.0.pow(it).toInt() }.toTypedArray() val powersOfTwoInt = (0..16).map { 2.0.pow(it).toInt() }.toTypedArray()

View File

@@ -933,11 +933,7 @@ internal class AssignmentAsmGen(private val program: PtProgram,
} }
private fun optimizedDivideExpr(expr: PtBinaryExpression, target: AsmAssignTarget): Boolean { private fun optimizedDivideExpr(expr: PtBinaryExpression, target: AsmAssignTarget): Boolean {
val constDivisor = expr.right.asConstInteger() // replacing division by shifting is done in an optimizer step.
if(constDivisor in powersOfTwoInt) {
println("TODO optimize: divide ${expr.type} by power-of-2 ${constDivisor} at ${expr.position}") // TODO
}
when(expr.type) { when(expr.type) {
DataType.UBYTE -> { DataType.UBYTE -> {
assignExpressionToRegister(expr.left, RegisterOrPair.A, false) assignExpressionToRegister(expr.left, RegisterOrPair.A, false)

View File

@@ -1477,9 +1477,7 @@ $shortcutLabel:""")
asmgen.out(" lda $name | ldy #$value | jsr math.multiply_bytes | sta $name") asmgen.out(" lda $name | ldy #$value | jsr math.multiply_bytes | sta $name")
} }
"/" -> { "/" -> {
if(value in powersOfTwoInt) { // replacing division by shifting is done in an optimizer step.
println("TODO optimize: (u)byte division by power-of-2 $value") // TODO
}
if (dt == DataType.UBYTE) if (dt == DataType.UBYTE)
asmgen.out(" lda $name | ldy #$value | jsr math.divmod_ub_asm | sty $name") asmgen.out(" lda $name | ldy #$value | jsr math.divmod_ub_asm | sty $name")
else else
@@ -1828,36 +1826,36 @@ $shortcutLabel:""")
} }
} }
"/" -> { "/" -> {
if(value==0) // replacing division by shifting is done in an optimizer step.
if(value==0) {
throw AssemblyError("division by zero") throw AssemblyError("division by zero")
else if (value in powersOfTwoInt) { } else {
println("TODO optimize: (u)word division by power-of-2 $value") // TODO if(dt==DataType.WORD) {
} asmgen.out("""
if(dt==DataType.WORD) { lda $lsb
asmgen.out(""" ldy $msb
lda $lsb sta P8ZP_SCRATCH_W1
ldy $msb sty P8ZP_SCRATCH_W1+1
sta P8ZP_SCRATCH_W1 lda #<$value
sty P8ZP_SCRATCH_W1+1 ldy #>$value
lda #<$value jsr math.divmod_w_asm
ldy #>$value sta $lsb
jsr math.divmod_w_asm sty $msb
sta $lsb """)
sty $msb }
""") else {
} asmgen.out("""
else { lda $lsb
asmgen.out(""" ldy $msb
lda $lsb sta P8ZP_SCRATCH_W1
ldy $msb sty P8ZP_SCRATCH_W1+1
sta P8ZP_SCRATCH_W1 lda #<$value
sty P8ZP_SCRATCH_W1+1 ldy #>$value
lda #<$value jsr math.divmod_uw_asm
ldy #>$value sta $lsb
jsr math.divmod_uw_asm sty $msb
sta $lsb """)
sty $msb }
""")
} }
} }
"%" -> { "%" -> {

View File

@@ -783,19 +783,31 @@ class IRCodeGen(
if(factor==1) if(factor==1)
return code return code
val pow2 = powersOfTwoInt.indexOf(factor) val pow2 = powersOfTwoInt.indexOf(factor)
// TODO also try to optimize for signed division by powers of 2 if(pow2>=0) {
if(pow2==1 && !signed) { if(signed) {
code += IRInstruction(Opcode.LSR, dt, reg1=reg) // simple single bit shift if(pow2==1) {
} // simple single bit shift (signed)
else if(pow2>=1 &&!signed) { code += IRInstruction(Opcode.ASR, dt, reg1=reg)
// just shift multiple bits (unsigned) } else {
val pow2reg = registers.nextFree() // just shift multiple bits (signed)
code += IRInstruction(Opcode.LOAD, dt, reg1=pow2reg, immediate = pow2) val pow2reg = registers.nextFree()
code += if(signed) code += IRInstruction(Opcode.LOAD, dt, reg1=pow2reg, immediate = pow2)
IRInstruction(Opcode.ASRN, dt, reg1=reg, reg2=pow2reg) code += IRInstruction(Opcode.ASRN, dt, reg1=reg, reg2=pow2reg)
else }
IRInstruction(Opcode.LSRN, dt, reg1=reg, reg2=pow2reg) } else {
if(pow2==1) {
// simple single bit shift (unsigned)
code += IRInstruction(Opcode.LSR, dt, reg1=reg)
} else {
// just shift multiple bits (unsigned)
val pow2reg = registers.nextFree()
code += IRInstruction(Opcode.LOAD, dt, reg1 = pow2reg, immediate = pow2)
code += IRInstruction(Opcode.LSRN, dt, reg1 = reg, reg2 = pow2reg)
}
}
return code
} else { } else {
// regular div
code += if (factor == 0) { code += if (factor == 0) {
IRInstruction(Opcode.LOAD, dt, reg1=reg, immediate = 0xffff) IRInstruction(Opcode.LOAD, dt, reg1=reg, immediate = 0xffff)
} else { } else {
@@ -804,8 +816,8 @@ class IRCodeGen(
else else
IRInstruction(Opcode.DIV, dt, reg1=reg, immediate = factor) IRInstruction(Opcode.DIV, dt, reg1=reg, immediate = factor)
} }
return code
} }
return code
} }
internal fun divideByConstInplace(dt: IRDataType, knownAddress: Int?, symbol: String?, factor: Int, signed: Boolean): IRCodeChunk { internal fun divideByConstInplace(dt: IRDataType, knownAddress: Int?, symbol: String?, factor: Int, signed: Boolean): IRCodeChunk {
@@ -813,31 +825,47 @@ class IRCodeGen(
if(factor==1) if(factor==1)
return code return code
val pow2 = powersOfTwoInt.indexOf(factor) val pow2 = powersOfTwoInt.indexOf(factor)
// TODO also try to optimize for signed division by powers of 2 if(pow2>=0) {
if(pow2==1 && !signed) { // can do bit shift instead of division
// just simple bit shift if(signed) {
code += if(knownAddress!=null) if(pow2==1) {
IRInstruction(Opcode.LSRM, dt, address = knownAddress) // just simple bit shift (signed)
else code += if (knownAddress != null)
IRInstruction(Opcode.LSRM, dt, labelSymbol = symbol) IRInstruction(Opcode.ASRM, dt, address = knownAddress)
else
IRInstruction(Opcode.ASRM, dt, labelSymbol = symbol)
} else {
// just shift multiple bits (signed)
val pow2reg = registers.nextFree()
code += IRInstruction(Opcode.LOAD, dt, reg1 = pow2reg, immediate = pow2)
code += if (knownAddress != null)
IRInstruction(Opcode.ASRNM, dt, reg1 = pow2reg, address = knownAddress)
else
IRInstruction(Opcode.ASRNM, dt, reg1 = pow2reg, labelSymbol = symbol)
}
} else {
if(pow2==1) {
// just simple bit shift (unsigned)
code += if(knownAddress!=null)
IRInstruction(Opcode.LSRM, dt, address = knownAddress)
else
IRInstruction(Opcode.LSRM, dt, labelSymbol = symbol)
}
else {
// just shift multiple bits (unsigned)
val pow2reg = registers.nextFree()
code += IRInstruction(Opcode.LOAD, dt, reg1=pow2reg, immediate = pow2)
code += if(knownAddress!=null)
IRInstruction(Opcode.LSRNM, dt, reg1 = pow2reg, address = knownAddress)
else
IRInstruction(Opcode.LSRNM, dt, reg1 = pow2reg, labelSymbol = symbol)
}
}
return code
} }
else if(pow2>=1 && !signed) { else
// just shift multiple bits (unsigned) {
val pow2reg = registers.nextFree() // regular div
code += IRInstruction(Opcode.LOAD, dt, reg1=pow2reg, immediate = pow2)
code += if(signed) {
if(knownAddress!=null)
IRInstruction(Opcode.ASRNM, dt, reg1 = pow2reg, address = knownAddress)
else
IRInstruction(Opcode.ASRNM, dt, reg1 = pow2reg, labelSymbol = symbol)
}
else {
if(knownAddress!=null)
IRInstruction(Opcode.LSRNM, dt, reg1 = pow2reg, address = knownAddress)
else
IRInstruction(Opcode.LSRNM, dt, reg1 = pow2reg, labelSymbol = symbol)
}
} else {
if (factor == 0) { if (factor == 0) {
val reg = registers.nextFree() val reg = registers.nextFree()
code += IRInstruction(Opcode.LOAD, dt, reg1=reg, immediate = 0xffff) code += IRInstruction(Opcode.LOAD, dt, reg1=reg, immediate = 0xffff)
@@ -862,8 +890,8 @@ class IRCodeGen(
IRInstruction(Opcode.DIVM, dt, reg1 = factorReg, labelSymbol = symbol) IRInstruction(Opcode.DIVM, dt, reg1 = factorReg, labelSymbol = symbol)
} }
} }
return code
} }
return code
} }
private fun translate(ifElse: PtIfElse): IRCodeChunks { private fun translate(ifElse: PtIfElse): IRCodeChunks {

View File

@@ -708,6 +708,7 @@ class ExpressionSimplifier(private val program: Program, private val options: Co
return null return null
val leftDt = leftIDt.getOr(DataType.UNDEFINED) val leftDt = leftIDt.getOr(DataType.UNDEFINED)
when (cv) { when (cv) {
0.0 -> return null // fall through to regular float division to properly deal with division by zero
-1.0 -> { -1.0 -> {
// '/' -> -left // '/' -> -left
if (expr.operator == "/") { if (expr.operator == "/") {
@@ -736,14 +737,10 @@ class ExpressionSimplifier(private val program: Program, private val options: Co
} }
} }
in powersOfTwoFloat -> { in powersOfTwoFloat -> {
if (leftDt==DataType.UBYTE || leftDt==DataType.UWORD) { val numshifts = powersOfTwoFloat.indexOf(cv)
// Unsigned number divided by a power of two => shift right if (leftDt in IntegerDatatypes) {
// Signed number can't simply be bitshifted in this case (due to rounding issues for negative values), // division by a power of two => shift right (signed and unsigned)
// so we leave that as is and let the code generator deal with it.
val numshifts = log2(cv).toInt()
return BinaryExpression(expr.left, ">>", NumericLiteral.optimalInteger(numshifts, expr.position), expr.position) return BinaryExpression(expr.left, ">>", NumericLiteral.optimalInteger(numshifts, expr.position), expr.position)
} else {
println("TODO optimize: divide by power-of-2 $cv at ${expr.position}") // TODO
} }
} }
} }

View File

@@ -5,11 +5,7 @@ See open issues on github.
Re-generate the skeletons doc files. Re-generate the skeletons doc files.
optimize byte/word division by powers of 2 (and shift right?), it's now often still using divmod routine. (also % ?) optimize signed word bit shifting?:
see the TODOs in inplacemodificationByteVariableWithLiteralval(), inplacemodificationSomeWordWithLiteralval(), optimizedDivideExpr(),
and finally in optimizeDivision()
and for IR: see divideByConst() / divideByConstInplace() in IRCodeGen
1 shift right of AX signed word: 1 shift right of AX signed word:
stx P8ZP_SCRATCH_B1 stx P8ZP_SCRATCH_B1
cpx #$80 cpx #$80

View File

@@ -61,7 +61,7 @@ main {
} }
sub unsigned() { sub unsigned() {
txt.print("unsigned\n") txt.print("\nunsigned\n")
ubyte @shared ubvalue = 88 ubyte @shared ubvalue = 88
uword @shared uwvalue = 8888 uword @shared uwvalue = 8888