TODOs for division optimizations

This commit is contained in:
Irmen de Jong 2024-07-21 13:35:28 +02:00
parent 2aae1f5e30
commit 0af17cdc33
10 changed files with 140 additions and 83 deletions

View File

@ -1,6 +1,11 @@
package prog8.code.core package prog8.code.core
import kotlin.math.abs import kotlin.math.abs
import kotlin.math.pow
// Shared power-of-two lookup tables.
// 2^1 .. 2^16 as floating point values (note: 2^0 is deliberately not part of this table).
val powersOfTwoFloat = Array(16) { 2.0.pow(it + 1) }
// The negated entries of powersOfTwoFloat, in the same order: -(2^1) .. -(2^16).
val negativePowersOfTwoFloat = Array(powersOfTwoFloat.size) { -powersOfTwoFloat[it] }
// 2^0 .. 2^16 as integers; the index of an entry equals its exponent, so indexOf() yields a shift count.
val powersOfTwoInt = Array(17) { 1 shl it }
fun Number.toHex(): String { fun Number.toHex(): String {
// 0..15 -> "0".."15" // 0..15 -> "0".."15"

View File

@ -933,6 +933,11 @@ internal class AssignmentAsmGen(private val program: PtProgram,
} }
private fun optimizedDivideExpr(expr: PtBinaryExpression, target: AsmAssignTarget): Boolean { private fun optimizedDivideExpr(expr: PtBinaryExpression, target: AsmAssignTarget): Boolean {
val constDivisor = expr.right.asConstInteger()
if(constDivisor in powersOfTwoInt) {
println("TODO optimize: divide ${expr.type} by power-of-2 ${constDivisor} at ${expr.position}") // TODO
}
when(expr.type) { when(expr.type) {
DataType.UBYTE -> { DataType.UBYTE -> {
assignExpressionToRegister(expr.left, RegisterOrPair.A, false) assignExpressionToRegister(expr.left, RegisterOrPair.A, false)

View File

@ -1477,6 +1477,9 @@ $shortcutLabel:""")
asmgen.out(" lda $name | ldy #$value | jsr math.multiply_bytes | sta $name") asmgen.out(" lda $name | ldy #$value | jsr math.multiply_bytes | sta $name")
} }
"/" -> { "/" -> {
if(value in powersOfTwoInt) {
println("TODO optimize: (u)byte division by power-of-2 $value") // TODO
}
if (dt == DataType.UBYTE) if (dt == DataType.UBYTE)
asmgen.out(" lda $name | ldy #$value | jsr math.divmod_ub_asm | sty $name") asmgen.out(" lda $name | ldy #$value | jsr math.divmod_ub_asm | sty $name")
else else
@ -1827,6 +1830,9 @@ $shortcutLabel:""")
"/" -> { "/" -> {
if(value==0) if(value==0)
throw AssemblyError("division by zero") throw AssemblyError("division by zero")
else if (value in powersOfTwoInt) {
println("TODO optimize: (u)word division by power-of-2 $value") // TODO
}
if(dt==DataType.WORD) { if(dt==DataType.WORD) {
asmgen.out(""" asmgen.out("""
lda $lsb lda $lsb

View File

@ -5,7 +5,6 @@ import prog8.code.ast.*
import prog8.code.core.* import prog8.code.core.*
import prog8.intermediate.* import prog8.intermediate.*
import kotlin.io.path.readBytes import kotlin.io.path.readBytes
import kotlin.math.pow
class IRCodeGen( class IRCodeGen(
@ -682,13 +681,11 @@ class IRCodeGen(
return code return code
} }
internal val powersOfTwo = (0..16).map { 2.0.pow(it.toDouble()).toInt() }
internal fun multiplyByConst(dt: IRDataType, reg: Int, factor: Int): IRCodeChunk { internal fun multiplyByConst(dt: IRDataType, reg: Int, factor: Int): IRCodeChunk {
val code = IRCodeChunk(null, null) val code = IRCodeChunk(null, null)
if(factor==1) if(factor==1)
return code return code
val pow2 = powersOfTwo.indexOf(factor) val pow2 = powersOfTwoInt.indexOf(factor)
if(pow2==1) { if(pow2==1) {
// just shift 1 bit // just shift 1 bit
code += IRInstruction(Opcode.LSL, dt, reg1=reg) code += IRInstruction(Opcode.LSL, dt, reg1=reg)
@ -712,7 +709,7 @@ class IRCodeGen(
val code = IRCodeChunk(null, null) val code = IRCodeChunk(null, null)
if(factor==1) if(factor==1)
return code return code
val pow2 = powersOfTwo.indexOf(factor) val pow2 = powersOfTwoInt.indexOf(factor)
if(pow2==1) { if(pow2==1) {
// just shift 1 bit // just shift 1 bit
code += if(knownAddress!=null) code += if(knownAddress!=null)
@ -785,13 +782,13 @@ class IRCodeGen(
val code = IRCodeChunk(null, null) val code = IRCodeChunk(null, null)
if(factor==1) if(factor==1)
return code return code
val pow2 = powersOfTwo.indexOf(factor) val pow2 = powersOfTwoInt.indexOf(factor)
// TODO also try to optimize for signed division by powers of 2
if(pow2==1 && !signed) { if(pow2==1 && !signed) {
code += IRInstruction(Opcode.LSR, dt, reg1=reg) // simple single bit shift code += IRInstruction(Opcode.LSR, dt, reg1=reg) // simple single bit shift
} }
else if(pow2>=1 &&!signed) { else if(pow2>=1 &&!signed) {
// just shift multiple bits // just shift multiple bits (unsigned)
// TODO also try to optimize for signed division by powers of 2
val pow2reg = registers.nextFree() val pow2reg = registers.nextFree()
code += IRInstruction(Opcode.LOAD, dt, reg1=pow2reg, immediate = pow2) code += IRInstruction(Opcode.LOAD, dt, reg1=pow2reg, immediate = pow2)
code += if(signed) code += if(signed)
@ -815,7 +812,8 @@ class IRCodeGen(
val code = IRCodeChunk(null, null) val code = IRCodeChunk(null, null)
if(factor==1) if(factor==1)
return code return code
val pow2 = powersOfTwo.indexOf(factor) val pow2 = powersOfTwoInt.indexOf(factor)
// TODO also try to optimize for signed division by powers of 2
if(pow2==1 && !signed) { if(pow2==1 && !signed) {
// just simple bit shift // just simple bit shift
code += if(knownAddress!=null) code += if(knownAddress!=null)
@ -824,7 +822,7 @@ class IRCodeGen(
IRInstruction(Opcode.LSRM, dt, labelSymbol = symbol) IRInstruction(Opcode.LSRM, dt, labelSymbol = symbol)
} }
else if(pow2>=1 && !signed) { else if(pow2>=1 && !signed) {
// just shift multiple bits // just shift multiple bits (unsigned)
val pow2reg = registers.nextFree() val pow2reg = registers.nextFree()
code += IRInstruction(Opcode.LOAD, dt, reg1=pow2reg, immediate = pow2) code += IRInstruction(Opcode.LOAD, dt, reg1=pow2reg, immediate = pow2)
code += if(signed) { code += if(signed) {

View File

@ -14,9 +14,6 @@ import kotlin.math.log2
import kotlin.math.pow import kotlin.math.pow
class ExpressionSimplifier(private val program: Program, private val options: CompilationOptions, private val errors: IErrorReporter) : AstWalker() { class ExpressionSimplifier(private val program: Program, private val options: CompilationOptions, private val errors: IErrorReporter) : AstWalker() {
private val powersOfTwo = (1..16).map { (2.0).pow(it) }.toSet()
private val negativePowersOfTwo = powersOfTwo.map { -it }.toSet()
override fun after(typecast: TypecastExpression, parent: Node): Iterable<IAstModification> { override fun after(typecast: TypecastExpression, parent: Node): Iterable<IAstModification> {
val mods = mutableListOf<IAstModification>() val mods = mutableListOf<IAstModification>()
@ -686,7 +683,7 @@ class ExpressionSimplifier(private val program: Program, private val options: Co
if(!idt.isKnown) if(!idt.isKnown)
throw FatalAstException("unknown dt") throw FatalAstException("unknown dt")
return NumericLiteral(idt.getOr(DataType.UNDEFINED), 0.0, expr.position) return NumericLiteral(idt.getOr(DataType.UNDEFINED), 0.0, expr.position)
} else if (cv in powersOfTwo) { } else if (cv in powersOfTwoFloat) {
expr.operator = "&" expr.operator = "&"
expr.right = NumericLiteral.optimalInteger(cv!!.toInt()-1, expr.position) expr.right = NumericLiteral.optimalInteger(cv!!.toInt()-1, expr.position)
return null return null
@ -738,13 +735,15 @@ class ExpressionSimplifier(private val program: Program, private val options: Co
else -> return null else -> return null
} }
} }
in powersOfTwo -> { in powersOfTwoFloat -> {
if (leftDt==DataType.UBYTE || leftDt==DataType.UWORD) { if (leftDt==DataType.UBYTE || leftDt==DataType.UWORD) {
// Unsigned number divided by a power of two => shift right // Unsigned number divided by a power of two => shift right
// Signed number can't simply be bitshifted in this case (due to rounding issues for negative values), // Signed number can't simply be bitshifted in this case (due to rounding issues for negative values),
// so we leave that as is and let the code generator deal with it. // so we leave that as is and let the code generator deal with it.
val numshifts = log2(cv).toInt() val numshifts = log2(cv).toInt()
return BinaryExpression(expr.left, ">>", NumericLiteral.optimalInteger(numshifts, expr.position), expr.position) return BinaryExpression(expr.left, ">>", NumericLiteral.optimalInteger(numshifts, expr.position), expr.position)
} else {
println("TODO optimize: divide by power-of-2 $cv at ${expr.position}") // TODO
} }
} }
} }
@ -795,14 +794,14 @@ class ExpressionSimplifier(private val program: Program, private val options: Co
// left // left
return expr2.left return expr2.left
} }
in powersOfTwo -> { in powersOfTwoFloat -> {
if (leftValue.inferType(program).isInteger) { if (leftValue.inferType(program).isInteger) {
// times a power of two => shift left // times a power of two => shift left
val numshifts = log2(cv).toInt() val numshifts = log2(cv).toInt()
return BinaryExpression(expr2.left, "<<", NumericLiteral.optimalInteger(numshifts, expr.position), expr.position) return BinaryExpression(expr2.left, "<<", NumericLiteral.optimalInteger(numshifts, expr.position), expr.position)
} }
} }
in negativePowersOfTwo -> { in negativePowersOfTwoFloat -> {
if (leftValue.inferType(program).isInteger) { if (leftValue.inferType(program).isInteger) {
// times a negative power of two => negate, then shift // times a negative power of two => negate, then shift
val numshifts = log2(-cv).toInt() val numshifts = log2(-cv).toInt()

View File

@ -21,7 +21,8 @@ class StatementOptimizer(private val program: Program,
if(functionCallStatement.target.nameInSource.size==1) { if(functionCallStatement.target.nameInSource.size==1) {
val functionName = functionCallStatement.target.nameInSource[0] val functionName = functionCallStatement.target.nameInSource[0]
if (functionName in functions.purefunctionNames) { if (functionName in functions.purefunctionNames) {
errors.warn("statement has no effect (function return value is discarded)", functionCallStatement.position) if("ignore_unused" !in parent.definingBlock.options())
errors.warn("statement has no effect (function return value is discarded)", functionCallStatement.position)
return listOf(IAstModification.Remove(functionCallStatement, parent as IStatementContainer)) return listOf(IAstModification.Remove(functionCallStatement, parent as IStatementContainer))
} }
} }

View File

@ -149,7 +149,8 @@ class UnusedCodeRemover(private val program: Program,
val declIndex = (parent as IStatementContainer).statements.indexOf(decl) val declIndex = (parent as IStatementContainer).statements.indexOf(decl)
val singleUseIndex = (parent as IStatementContainer).statements.indexOf(singleUse.parent) val singleUseIndex = (parent as IStatementContainer).statements.indexOf(singleUse.parent)
if(declIndex==singleUseIndex-1) { if(declIndex==singleUseIndex-1) {
errors.info("replaced unused variable '${decl.name}' with void call, maybe this can be removed altogether", decl.position) if("ignore_unused" !in decl.definingBlock.options())
errors.info("replaced unused variable '${decl.name}' with void call, maybe this can be removed altogether", decl.position)
val fcall = assignment.value as IFunctionCall val fcall = assignment.value as IFunctionCall
val voidCall = FunctionCallStatement(fcall.target, fcall.args, true, fcall.position) val voidCall = FunctionCallStatement(fcall.target, fcall.args, true, fcall.position)
return listOf( return listOf(

View File

@ -182,7 +182,7 @@ sprites {
} }
sub set_mousepointer_image(uword data, bool compressed) { sub set_mousepointer_image(uword data, bool compressed) {
get_data_ptr(0) ; the mouse cursor is sprite 0 get_data_ptr_internal(0) ; the mouse cursor is sprite 0
if cx16.r1L==0 and cx16.r0==0 if cx16.r1L==0 and cx16.r0==0
return ; mouse cursor not enabled return ; mouse cursor not enabled
ubyte vbank = cx16.r1L ubyte vbank = cx16.r1L

View File

@ -5,9 +5,10 @@ See open issues on github.
Re-generate the skeletons doc files. Re-generate the skeletons doc files.
optimize signed byte/word division by powers of 2 (and shift right?), it's now using divmod routine. (also % ?) optimize byte/word division by powers of 2 (and shift right?), it's now often still using divmod routine. (also % ?)
see inplacemodificationByteVariableWithLiteralval() and inplacemodificationSomeWordWithLiteralval() see the TODOs in inplacemodificationByteVariableWithLiteralval(), inplacemodificationSomeWordWithLiteralval(), optimizedDivideExpr(),
and for IR: see divideByConst() in IRCodeGen and finally in optimizeDivision()
and for IR: see divideByConst() / divideByConstInplace() in IRCodeGen
1 shift right of AX signed word: 1 shift right of AX signed word:
stx P8ZP_SCRATCH_B1 stx P8ZP_SCRATCH_B1

View File

@ -3,68 +3,109 @@
%option no_sysinit %option no_sysinit
%zeropage basicsafe %zeropage basicsafe
main { main {
sub start() { sub start() {
ubyte[10] array signed()
array[10] = 0 unsigned()
; array[-11] = 0 }
; txt.print_ub(ptr[index]) sub signed() {
; txt.nl() txt.print("signed\n")
; ptr[index] = 123 byte @shared bvalue = -88
; txt.print_ub(ptr[index]) word @shared wvalue = -8888
; txt.nl()
txt.print_b(bvalue/2)
txt.nl()
txt.print_w(wvalue/2)
txt.nl()
bvalue /= 2
wvalue /= 2
txt.print_b(bvalue)
txt.nl()
txt.print_w(wvalue)
txt.nl()
bvalue *= 2
wvalue *= 2
txt.print_b(bvalue)
txt.nl()
txt.print_w(wvalue)
txt.nl()
txt.nl()
txt.print_b(bvalue/4)
txt.nl()
txt.print_w(wvalue/4)
txt.nl()
bvalue /= 4
wvalue /= 4
txt.print_b(bvalue)
txt.nl()
txt.print_w(wvalue)
txt.nl()
bvalue *= 4
wvalue *= 4
txt.print_b(bvalue)
txt.nl()
txt.print_w(wvalue)
txt.nl()
txt.nl()
}
sub unsigned() {
txt.print("unsigned\n")
ubyte @shared ubvalue = 88
uword @shared uwvalue = 8888
txt.print_ub(ubvalue/2)
txt.nl()
txt.print_uw(uwvalue/2)
txt.nl()
ubvalue /= 2
uwvalue /= 2
txt.print_ub(ubvalue)
txt.nl()
txt.print_uw(uwvalue)
txt.nl()
ubvalue *= 2
uwvalue *= 2
txt.print_ub(ubvalue)
txt.nl()
txt.print_uw(uwvalue)
txt.nl()
txt.nl()
txt.print_ub(ubvalue/4)
txt.nl()
txt.print_uw(uwvalue/4)
txt.nl()
ubvalue /= 4
uwvalue /= 4
txt.print_ub(ubvalue)
txt.nl()
txt.print_uw(uwvalue)
txt.nl()
ubvalue *= 4
uwvalue *= 4
txt.print_ub(ubvalue)
txt.nl()
txt.print_uw(uwvalue)
txt.nl()
} }
} }
;
;main {
; sub start() {
; signed()
; unsigned()
; }
;
; sub signed() {
; byte @shared bvalue = -100
; word @shared wvalue = -20000
;
; bvalue /= 2 ; TODO should be a simple bit shift?
; wvalue /= 2 ; TODO should be a simple bit shift?
;
; txt.print_b(bvalue)
; txt.nl()
; txt.print_w(wvalue)
; txt.nl()
;
; bvalue *= 2
; wvalue *= 2
;
; txt.print_b(bvalue)
; txt.nl()
; txt.print_w(wvalue)
; txt.nl()
; }
;
; sub unsigned() {
; ubyte @shared ubvalue = 100
; uword @shared uwvalue = 20000
;
; ubvalue /= 2
; uwvalue /= 2
;
; txt.print_ub(ubvalue)
; txt.nl()
; txt.print_uw(uwvalue)
; txt.nl()
;
; ubvalue *= 2
; uwvalue *= 2
;
; txt.print_ub(ubvalue)
; txt.nl()
; txt.print_uw(uwvalue)
; txt.nl()
; }
;}