vm: simple optimizations for +/-/*/div with constants

This commit is contained in:
Irmen de Jong
2022-04-14 22:42:25 +02:00
parent 0f4a197e34
commit 0f36be0001
6 changed files with 274 additions and 139 deletions
@@ -370,10 +370,10 @@ class CodeGen(internal val program: PtProgram,
return code
}
private val powersOfTwo = (0..16).map { 2.0.pow(it.toDouble()).toInt() }
internal val powersOfTwo = (0..16).map { 2.0.pow(it.toDouble()).toInt() }
internal fun multiplyByConst(dt: VmDataType, reg: Int, factor: Int): VmCodeChunk {
require(factor>=0)
// TODO support floating-point factors
val code = VmCodeChunk()
if(factor==1)
return code
@@ -386,7 +386,7 @@ class CodeGen(internal val program: PtProgram,
// just shift multiple bits
val pow2reg = vmRegisters.nextFree()
code += VmCodeInstruction(Opcode.LOAD, dt, reg1=pow2reg, value=pow2)
code += VmCodeInstruction(Opcode.LSLM, dt, reg1=reg, reg2=reg, reg3=pow2reg)
code += VmCodeInstruction(Opcode.LSLX, dt, reg1=reg, reg2=reg, reg3=pow2reg)
} else {
if (factor == 0) {
code += VmCodeInstruction(Opcode.LOAD, dt, reg1=reg, value=0)
@@ -400,6 +400,34 @@ class CodeGen(internal val program: PtProgram,
return code
}
internal fun divideByConst(dt: VmDataType, reg: Int, factor: Int): VmCodeChunk {
// TODO support floating-point factors
val code = VmCodeChunk()
if(factor==1)
return code
val pow2 = powersOfTwo.indexOf(factor)
if(pow2==1) {
// just shift 1 bit
code += VmCodeInstruction(Opcode.LSR, dt, reg1=reg)
}
else if(pow2>=1) {
// just shift multiple bits
val pow2reg = vmRegisters.nextFree()
code += VmCodeInstruction(Opcode.LOAD, dt, reg1=pow2reg, value=pow2)
code += VmCodeInstruction(Opcode.LSRX, dt, reg1=reg, reg2=reg, reg3=pow2reg)
} else {
if (factor == 0) {
code += VmCodeInstruction(Opcode.LOAD, dt, reg1=reg, value=0)
}
else {
val factorReg = vmRegisters.nextFree()
code += VmCodeInstruction(Opcode.LOAD, dt, reg1=factorReg, value= factor)
code += VmCodeInstruction(Opcode.DIV, dt, reg1=reg, reg2=reg, reg3=factorReg)
}
}
return code
}
private fun translate(ifElse: PtIfElse): VmCodeChunk {
var branch = Opcode.BZ
var condition = ifElse.condition
@@ -264,73 +264,214 @@ internal class ExpressionGen(private val codeGen: CodeGen) {
}
private fun translate(binExpr: PtBinaryExpression, resultRegister: Int): VmCodeChunk {
val vmDt = codeGen.vmType(binExpr.left.type)
val signed = binExpr.left.type in SignedDatatypes
return when(binExpr.operator) {
"+" -> operatorPlus(binExpr, vmDt, resultRegister)
"-" -> operatorMinus(binExpr, vmDt, resultRegister)
"*" -> operatorMultiply(binExpr, vmDt, resultRegister)
"/" -> operatorDivide(binExpr, vmDt, resultRegister)
"%" -> operatorModulo(binExpr, vmDt, resultRegister)
"|", "or" -> operatorOr(binExpr, vmDt, resultRegister)
"&", "and" -> operatorAnd(binExpr, vmDt, resultRegister)
"^", "xor" -> operatorXor(binExpr, vmDt, resultRegister)
"<<" -> operatorShiftLeft(binExpr, vmDt, resultRegister)
">>" -> operatorShiftRight(binExpr, vmDt, resultRegister, signed)
"==" -> operatorEquals(binExpr, vmDt, resultRegister, true)
"!=" -> operatorEquals(binExpr, vmDt, resultRegister, false)
"<" -> operatorLessThan(binExpr, vmDt, resultRegister, signed, false)
">" -> operatorGreaterThan(binExpr, vmDt, resultRegister, signed, false)
"<=" -> operatorLessThan(binExpr, vmDt, resultRegister, signed, true)
">=" -> operatorGreaterThan(binExpr, vmDt, resultRegister, signed, true)
else -> throw AssemblyError("weird operator ${binExpr.operator}")
}
}
private fun operatorGreaterThan(
binExpr: PtBinaryExpression,
vmDt: VmDataType,
resultRegister: Int,
signed: Boolean,
greaterEquals: Boolean
): VmCodeChunk {
val code = VmCodeChunk()
val leftResultReg = codeGen.vmRegisters.nextFree()
val rightResultReg = codeGen.vmRegisters.nextFree()
// TODO: optimized codegen when left or right operand is known 0 or 1 or whatever. But only if this would result in a different opcode such as ADD 1 -> INC, MUL 1 -> NOP
// actually optimizing the code should not be done here but in a tailored code optimizer step.
// multiplyByConst()
val leftCode = translateExpression(binExpr.left, leftResultReg)
val rightCode = translateExpression(binExpr.right, rightResultReg)
code += leftCode
code += rightCode
val vmDt = codeGen.vmType(binExpr.left.type)
val signed = binExpr.left.type in SignedDatatypes
when(binExpr.operator) {
"+" -> {
code += VmCodeInstruction(Opcode.ADD, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
"-" -> {
code += VmCodeInstruction(Opcode.SUB, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
"*" -> {
code += VmCodeInstruction(Opcode.MUL, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
"/" -> {
code += VmCodeInstruction(Opcode.DIV, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
"%" -> {
code += VmCodeInstruction(Opcode.MOD, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
"|", "or" -> {
code += VmCodeInstruction(Opcode.OR, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
"&", "and" -> {
code += VmCodeInstruction(Opcode.AND, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
"^", "xor" -> {
code += VmCodeInstruction(Opcode.XOR, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
"<<" -> {
code += VmCodeInstruction(Opcode.LSLM, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
">>" -> {
val opc = if(signed) Opcode.ASRM else Opcode.LSRM
code += VmCodeInstruction(opc, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
"==" -> {
code += VmCodeInstruction(Opcode.SEQ, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
"!=" -> {
code += VmCodeInstruction(Opcode.SNE, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
"<" -> {
val ins = if(signed) Opcode.SLTS else Opcode.SLT
code += VmCodeInstruction(ins, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
">" -> {
val ins = if(signed) Opcode.SGTS else Opcode.SGT
code += VmCodeInstruction(ins, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
"<=" -> {
val ins = if(signed) Opcode.SLES else Opcode.SLE
code += VmCodeInstruction(ins, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
">=" -> {
val ins = if(signed) Opcode.SGES else Opcode.SGE
code += VmCodeInstruction(ins, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
else -> throw AssemblyError("weird operator ${binExpr.operator}")
code += translateExpression(binExpr.left, leftResultReg)
code += translateExpression(binExpr.right, rightResultReg)
val ins = if(signed) {
if(greaterEquals) Opcode.SGES else Opcode.SGTS
} else {
if(greaterEquals) Opcode.SGE else Opcode.SGT
}
code += VmCodeInstruction(ins, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
return code
}
private fun operatorLessThan(
binExpr: PtBinaryExpression,
vmDt: VmDataType,
resultRegister: Int,
signed: Boolean,
lessEquals: Boolean
): VmCodeChunk {
val code = VmCodeChunk()
val leftResultReg = codeGen.vmRegisters.nextFree()
val rightResultReg = codeGen.vmRegisters.nextFree()
code += translateExpression(binExpr.left, leftResultReg)
code += translateExpression(binExpr.right, rightResultReg)
val ins = if(signed) {
if(lessEquals) Opcode.SLES else Opcode.SLTS
} else {
if(lessEquals) Opcode.SLE else Opcode.SLT
}
code += VmCodeInstruction(ins, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
return code
}
private fun operatorEquals(binExpr: PtBinaryExpression, vmDt: VmDataType, resultRegister: Int, notEquals: Boolean): VmCodeChunk {
val code = VmCodeChunk()
val leftResultReg = codeGen.vmRegisters.nextFree()
val rightResultReg = codeGen.vmRegisters.nextFree()
code += translateExpression(binExpr.left, leftResultReg)
code += translateExpression(binExpr.right, rightResultReg)
val opcode = if(notEquals) Opcode.SNE else Opcode.SEQ
code += VmCodeInstruction(opcode, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
return code
}
private fun operatorShiftRight(binExpr: PtBinaryExpression, vmDt: VmDataType, resultRegister: Int, signed: Boolean): VmCodeChunk {
val code = VmCodeChunk()
val leftResultReg = codeGen.vmRegisters.nextFree()
val rightResultReg = codeGen.vmRegisters.nextFree()
code += translateExpression(binExpr.left, leftResultReg)
code += translateExpression(binExpr.right, rightResultReg)
val opc = if(signed) Opcode.ASRX else Opcode.LSRX
code += VmCodeInstruction(opc, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
return code
}
private fun operatorShiftLeft(binExpr: PtBinaryExpression, vmDt: VmDataType, resultRegister: Int): VmCodeChunk {
val code = VmCodeChunk()
val leftResultReg = codeGen.vmRegisters.nextFree()
val rightResultReg = codeGen.vmRegisters.nextFree()
code += translateExpression(binExpr.left, leftResultReg)
code += translateExpression(binExpr.right, rightResultReg)
code += VmCodeInstruction(Opcode.LSLX, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
return code
}
private fun operatorXor(binExpr: PtBinaryExpression, vmDt: VmDataType, resultRegister: Int): VmCodeChunk {
val code = VmCodeChunk()
val leftResultReg = codeGen.vmRegisters.nextFree()
val rightResultReg = codeGen.vmRegisters.nextFree()
code += translateExpression(binExpr.left, leftResultReg)
code += translateExpression(binExpr.right, rightResultReg)
code += VmCodeInstruction(Opcode.XOR, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
return code
}
private fun operatorAnd(binExpr: PtBinaryExpression, vmDt: VmDataType, resultRegister: Int): VmCodeChunk {
val code = VmCodeChunk()
val leftResultReg = codeGen.vmRegisters.nextFree()
val rightResultReg = codeGen.vmRegisters.nextFree()
code += translateExpression(binExpr.left, leftResultReg)
code += translateExpression(binExpr.right, rightResultReg)
code += VmCodeInstruction(Opcode.AND, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
return code
}
private fun operatorOr(binExpr: PtBinaryExpression, vmDt: VmDataType, resultRegister: Int): VmCodeChunk {
val code = VmCodeChunk()
val leftResultReg = codeGen.vmRegisters.nextFree()
val rightResultReg = codeGen.vmRegisters.nextFree()
code += translateExpression(binExpr.left, leftResultReg)
code += translateExpression(binExpr.right, rightResultReg)
code += VmCodeInstruction(Opcode.OR, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
return code
}
private fun operatorModulo(binExpr: PtBinaryExpression, vmDt: VmDataType, resultRegister: Int): VmCodeChunk {
val code = VmCodeChunk()
val leftResultReg = codeGen.vmRegisters.nextFree()
val rightResultReg = codeGen.vmRegisters.nextFree()
code += translateExpression(binExpr.left, leftResultReg)
code += translateExpression(binExpr.right, rightResultReg)
code += VmCodeInstruction(Opcode.MOD, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
return code
}
private fun operatorDivide(binExpr: PtBinaryExpression, vmDt: VmDataType, resultRegister: Int): VmCodeChunk {
val code = VmCodeChunk()
val constFactorRight = binExpr.right as? PtNumber
if(constFactorRight!=null && constFactorRight.type!=DataType.FLOAT) {
code += translateExpression(binExpr.left, resultRegister)
val factor = constFactorRight.number.toInt()
code += codeGen.divideByConst(vmDt, resultRegister, factor)
} else {
val leftResultReg = codeGen.vmRegisters.nextFree()
val rightResultReg = codeGen.vmRegisters.nextFree()
code += translateExpression(binExpr.left, leftResultReg)
code += translateExpression(binExpr.right, rightResultReg)
code += VmCodeInstruction(Opcode.DIV, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
return code
}
private fun operatorMultiply(binExpr: PtBinaryExpression, vmDt: VmDataType, resultRegister: Int): VmCodeChunk {
val code = VmCodeChunk()
val constFactorLeft = binExpr.left as? PtNumber
val constFactorRight = binExpr.right as? PtNumber
if(constFactorLeft!=null && constFactorLeft.type!=DataType.FLOAT) {
code += translateExpression(binExpr.right, resultRegister)
val factor = constFactorLeft.number.toInt()
code += codeGen.multiplyByConst(vmDt, resultRegister, factor)
} else if(constFactorRight!=null && constFactorRight.type!=DataType.FLOAT) {
code += translateExpression(binExpr.left, resultRegister)
val factor = constFactorRight.number.toInt()
code += codeGen.multiplyByConst(vmDt, resultRegister, factor)
} else {
val leftResultReg = codeGen.vmRegisters.nextFree()
val rightResultReg = codeGen.vmRegisters.nextFree()
code += translateExpression(binExpr.left, leftResultReg)
code += translateExpression(binExpr.right, rightResultReg)
code += VmCodeInstruction(Opcode.MUL, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
return code
}
private fun operatorMinus(binExpr: PtBinaryExpression, vmDt: VmDataType, resultRegister: Int): VmCodeChunk {
val code = VmCodeChunk()
if((binExpr.right as? PtNumber)?.number==1.0) {
code += translateExpression(binExpr.left, resultRegister)
code += VmCodeInstruction(Opcode.DEC, vmDt, reg1=resultRegister)
}
else {
val leftResultReg = codeGen.vmRegisters.nextFree()
val rightResultReg = codeGen.vmRegisters.nextFree()
code += translateExpression(binExpr.left, leftResultReg)
code += translateExpression(binExpr.right, rightResultReg)
code += VmCodeInstruction(Opcode.SUB, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
return code
}
private fun operatorPlus(binExpr: PtBinaryExpression, vmDt: VmDataType, resultRegister: Int): VmCodeChunk {
val code = VmCodeChunk()
if((binExpr.left as? PtNumber)?.number==1.0) {
code += translateExpression(binExpr.right, resultRegister)
code += VmCodeInstruction(Opcode.INC, vmDt, reg1=resultRegister)
}
else if((binExpr.right as? PtNumber)?.number==1.0) {
code += translateExpression(binExpr.left, resultRegister)
code += VmCodeInstruction(Opcode.INC, vmDt, reg1=resultRegister)
}
else {
val leftResultReg = codeGen.vmRegisters.nextFree()
val rightResultReg = codeGen.vmRegisters.nextFree()
code += translateExpression(binExpr.left, leftResultReg)
code += translateExpression(binExpr.right, rightResultReg)
code += VmCodeInstruction(Opcode.ADD, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
return code
}