vm: simple optimizations for +/-/*/div with constants

This commit is contained in:
Irmen de Jong 2022-04-14 22:42:25 +02:00
parent 0f4a197e34
commit 0f36be0001
6 changed files with 274 additions and 139 deletions

View File

@ -370,10 +370,10 @@ class CodeGen(internal val program: PtProgram,
return code
}
private val powersOfTwo = (0..16).map { 2.0.pow(it.toDouble()).toInt() }
internal val powersOfTwo = (0..16).map { 2.0.pow(it.toDouble()).toInt() }
internal fun multiplyByConst(dt: VmDataType, reg: Int, factor: Int): VmCodeChunk {
require(factor>=0)
// TODO support floating-point factors
val code = VmCodeChunk()
if(factor==1)
return code
@ -386,7 +386,7 @@ class CodeGen(internal val program: PtProgram,
// just shift multiple bits
val pow2reg = vmRegisters.nextFree()
code += VmCodeInstruction(Opcode.LOAD, dt, reg1=pow2reg, value=pow2)
code += VmCodeInstruction(Opcode.LSLM, dt, reg1=reg, reg2=reg, reg3=pow2reg)
code += VmCodeInstruction(Opcode.LSLX, dt, reg1=reg, reg2=reg, reg3=pow2reg)
} else {
if (factor == 0) {
code += VmCodeInstruction(Opcode.LOAD, dt, reg1=reg, value=0)
@ -400,6 +400,34 @@ class CodeGen(internal val program: PtProgram,
return code
}
internal fun divideByConst(dt: VmDataType, reg: Int, factor: Int): VmCodeChunk {
// TODO support floating-point factors
val code = VmCodeChunk()
if(factor==1)
return code
val pow2 = powersOfTwo.indexOf(factor)
if(pow2==1) {
// just shift 1 bit
code += VmCodeInstruction(Opcode.LSR, dt, reg1=reg)
}
else if(pow2>=1) {
// just shift multiple bits
val pow2reg = vmRegisters.nextFree()
code += VmCodeInstruction(Opcode.LOAD, dt, reg1=pow2reg, value=pow2)
code += VmCodeInstruction(Opcode.LSRX, dt, reg1=reg, reg2=reg, reg3=pow2reg)
} else {
if (factor == 0) {
code += VmCodeInstruction(Opcode.LOAD, dt, reg1=reg, value=0)
}
else {
val factorReg = vmRegisters.nextFree()
code += VmCodeInstruction(Opcode.LOAD, dt, reg1=factorReg, value= factor)
code += VmCodeInstruction(Opcode.DIV, dt, reg1=reg, reg2=reg, reg3=factorReg)
}
}
return code
}
private fun translate(ifElse: PtIfElse): VmCodeChunk {
var branch = Opcode.BZ
var condition = ifElse.condition

View File

@ -264,73 +264,214 @@ internal class ExpressionGen(private val codeGen: CodeGen) {
}
private fun translate(binExpr: PtBinaryExpression, resultRegister: Int): VmCodeChunk {
val vmDt = codeGen.vmType(binExpr.left.type)
val signed = binExpr.left.type in SignedDatatypes
return when(binExpr.operator) {
"+" -> operatorPlus(binExpr, vmDt, resultRegister)
"-" -> operatorMinus(binExpr, vmDt, resultRegister)
"*" -> operatorMultiply(binExpr, vmDt, resultRegister)
"/" -> operatorDivide(binExpr, vmDt, resultRegister)
"%" -> operatorModulo(binExpr, vmDt, resultRegister)
"|", "or" -> operatorOr(binExpr, vmDt, resultRegister)
"&", "and" -> operatorAnd(binExpr, vmDt, resultRegister)
"^", "xor" -> operatorXor(binExpr, vmDt, resultRegister)
"<<" -> operatorShiftLeft(binExpr, vmDt, resultRegister)
">>" -> operatorShiftRight(binExpr, vmDt, resultRegister, signed)
"==" -> operatorEquals(binExpr, vmDt, resultRegister, true)
"!=" -> operatorEquals(binExpr, vmDt, resultRegister, false)
"<" -> operatorLessThan(binExpr, vmDt, resultRegister, signed, false)
">" -> operatorGreaterThan(binExpr, vmDt, resultRegister, signed, false)
"<=" -> operatorLessThan(binExpr, vmDt, resultRegister, signed, true)
">=" -> operatorGreaterThan(binExpr, vmDt, resultRegister, signed, true)
else -> throw AssemblyError("weird operator ${binExpr.operator}")
}
}
private fun operatorGreaterThan(
binExpr: PtBinaryExpression,
vmDt: VmDataType,
resultRegister: Int,
signed: Boolean,
greaterEquals: Boolean
): VmCodeChunk {
val code = VmCodeChunk()
val leftResultReg = codeGen.vmRegisters.nextFree()
val rightResultReg = codeGen.vmRegisters.nextFree()
// TODO: optimized codegen when left or right operand is known 0 or 1 or whatever. But only if this would result in a different opcode such as ADD 1 -> INC, MUL 1 -> NOP
// actually optimizing the code should not be done here but in a tailored code optimizer step.
// multiplyByConst()
val leftCode = translateExpression(binExpr.left, leftResultReg)
val rightCode = translateExpression(binExpr.right, rightResultReg)
code += leftCode
code += rightCode
val vmDt = codeGen.vmType(binExpr.left.type)
val signed = binExpr.left.type in SignedDatatypes
when(binExpr.operator) {
"+" -> {
code += VmCodeInstruction(Opcode.ADD, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
"-" -> {
code += VmCodeInstruction(Opcode.SUB, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
"*" -> {
code += VmCodeInstruction(Opcode.MUL, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
"/" -> {
code += VmCodeInstruction(Opcode.DIV, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
"%" -> {
code += VmCodeInstruction(Opcode.MOD, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
"|", "or" -> {
code += VmCodeInstruction(Opcode.OR, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
"&", "and" -> {
code += VmCodeInstruction(Opcode.AND, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
"^", "xor" -> {
code += VmCodeInstruction(Opcode.XOR, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
"<<" -> {
code += VmCodeInstruction(Opcode.LSLM, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
">>" -> {
val opc = if(signed) Opcode.ASRM else Opcode.LSRM
code += VmCodeInstruction(opc, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
"==" -> {
code += VmCodeInstruction(Opcode.SEQ, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
"!=" -> {
code += VmCodeInstruction(Opcode.SNE, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
"<" -> {
val ins = if(signed) Opcode.SLTS else Opcode.SLT
code += VmCodeInstruction(ins, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
">" -> {
val ins = if(signed) Opcode.SGTS else Opcode.SGT
code += VmCodeInstruction(ins, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
"<=" -> {
val ins = if(signed) Opcode.SLES else Opcode.SLE
code += VmCodeInstruction(ins, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
">=" -> {
val ins = if(signed) Opcode.SGES else Opcode.SGE
code += VmCodeInstruction(ins, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
else -> throw AssemblyError("weird operator ${binExpr.operator}")
code += translateExpression(binExpr.left, leftResultReg)
code += translateExpression(binExpr.right, rightResultReg)
val ins = if(signed) {
if(greaterEquals) Opcode.SGES else Opcode.SGTS
} else {
if(greaterEquals) Opcode.SGE else Opcode.SGT
}
code += VmCodeInstruction(ins, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
return code
}
private fun operatorLessThan(
binExpr: PtBinaryExpression,
vmDt: VmDataType,
resultRegister: Int,
signed: Boolean,
lessEquals: Boolean
): VmCodeChunk {
val code = VmCodeChunk()
val leftResultReg = codeGen.vmRegisters.nextFree()
val rightResultReg = codeGen.vmRegisters.nextFree()
code += translateExpression(binExpr.left, leftResultReg)
code += translateExpression(binExpr.right, rightResultReg)
val ins = if(signed) {
if(lessEquals) Opcode.SLES else Opcode.SLTS
} else {
if(lessEquals) Opcode.SLE else Opcode.SLT
}
code += VmCodeInstruction(ins, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
return code
}
private fun operatorEquals(binExpr: PtBinaryExpression, vmDt: VmDataType, resultRegister: Int, notEquals: Boolean): VmCodeChunk {
val code = VmCodeChunk()
val leftResultReg = codeGen.vmRegisters.nextFree()
val rightResultReg = codeGen.vmRegisters.nextFree()
code += translateExpression(binExpr.left, leftResultReg)
code += translateExpression(binExpr.right, rightResultReg)
val opcode = if(notEquals) Opcode.SNE else Opcode.SEQ
code += VmCodeInstruction(opcode, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
return code
}
private fun operatorShiftRight(binExpr: PtBinaryExpression, vmDt: VmDataType, resultRegister: Int, signed: Boolean): VmCodeChunk {
val code = VmCodeChunk()
val leftResultReg = codeGen.vmRegisters.nextFree()
val rightResultReg = codeGen.vmRegisters.nextFree()
code += translateExpression(binExpr.left, leftResultReg)
code += translateExpression(binExpr.right, rightResultReg)
val opc = if(signed) Opcode.ASRX else Opcode.LSRX
code += VmCodeInstruction(opc, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
return code
}
private fun operatorShiftLeft(binExpr: PtBinaryExpression, vmDt: VmDataType, resultRegister: Int): VmCodeChunk {
val code = VmCodeChunk()
val leftResultReg = codeGen.vmRegisters.nextFree()
val rightResultReg = codeGen.vmRegisters.nextFree()
code += translateExpression(binExpr.left, leftResultReg)
code += translateExpression(binExpr.right, rightResultReg)
code += VmCodeInstruction(Opcode.LSLX, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
return code
}
private fun operatorXor(binExpr: PtBinaryExpression, vmDt: VmDataType, resultRegister: Int): VmCodeChunk {
val code = VmCodeChunk()
val leftResultReg = codeGen.vmRegisters.nextFree()
val rightResultReg = codeGen.vmRegisters.nextFree()
code += translateExpression(binExpr.left, leftResultReg)
code += translateExpression(binExpr.right, rightResultReg)
code += VmCodeInstruction(Opcode.XOR, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
return code
}
private fun operatorAnd(binExpr: PtBinaryExpression, vmDt: VmDataType, resultRegister: Int): VmCodeChunk {
val code = VmCodeChunk()
val leftResultReg = codeGen.vmRegisters.nextFree()
val rightResultReg = codeGen.vmRegisters.nextFree()
code += translateExpression(binExpr.left, leftResultReg)
code += translateExpression(binExpr.right, rightResultReg)
code += VmCodeInstruction(Opcode.AND, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
return code
}
private fun operatorOr(binExpr: PtBinaryExpression, vmDt: VmDataType, resultRegister: Int): VmCodeChunk {
val code = VmCodeChunk()
val leftResultReg = codeGen.vmRegisters.nextFree()
val rightResultReg = codeGen.vmRegisters.nextFree()
code += translateExpression(binExpr.left, leftResultReg)
code += translateExpression(binExpr.right, rightResultReg)
code += VmCodeInstruction(Opcode.OR, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
return code
}
private fun operatorModulo(binExpr: PtBinaryExpression, vmDt: VmDataType, resultRegister: Int): VmCodeChunk {
val code = VmCodeChunk()
val leftResultReg = codeGen.vmRegisters.nextFree()
val rightResultReg = codeGen.vmRegisters.nextFree()
code += translateExpression(binExpr.left, leftResultReg)
code += translateExpression(binExpr.right, rightResultReg)
code += VmCodeInstruction(Opcode.MOD, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
return code
}
private fun operatorDivide(binExpr: PtBinaryExpression, vmDt: VmDataType, resultRegister: Int): VmCodeChunk {
val code = VmCodeChunk()
val constFactorRight = binExpr.right as? PtNumber
if(constFactorRight!=null && constFactorRight.type!=DataType.FLOAT) {
code += translateExpression(binExpr.left, resultRegister)
val factor = constFactorRight.number.toInt()
code += codeGen.divideByConst(vmDt, resultRegister, factor)
} else {
val leftResultReg = codeGen.vmRegisters.nextFree()
val rightResultReg = codeGen.vmRegisters.nextFree()
code += translateExpression(binExpr.left, leftResultReg)
code += translateExpression(binExpr.right, rightResultReg)
code += VmCodeInstruction(Opcode.DIV, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
return code
}
private fun operatorMultiply(binExpr: PtBinaryExpression, vmDt: VmDataType, resultRegister: Int): VmCodeChunk {
val code = VmCodeChunk()
val constFactorLeft = binExpr.left as? PtNumber
val constFactorRight = binExpr.right as? PtNumber
if(constFactorLeft!=null && constFactorLeft.type!=DataType.FLOAT) {
code += translateExpression(binExpr.right, resultRegister)
val factor = constFactorLeft.number.toInt()
code += codeGen.multiplyByConst(vmDt, resultRegister, factor)
} else if(constFactorRight!=null && constFactorRight.type!=DataType.FLOAT) {
code += translateExpression(binExpr.left, resultRegister)
val factor = constFactorRight.number.toInt()
code += codeGen.multiplyByConst(vmDt, resultRegister, factor)
} else {
val leftResultReg = codeGen.vmRegisters.nextFree()
val rightResultReg = codeGen.vmRegisters.nextFree()
code += translateExpression(binExpr.left, leftResultReg)
code += translateExpression(binExpr.right, rightResultReg)
code += VmCodeInstruction(Opcode.MUL, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
return code
}
private fun operatorMinus(binExpr: PtBinaryExpression, vmDt: VmDataType, resultRegister: Int): VmCodeChunk {
val code = VmCodeChunk()
if((binExpr.right as? PtNumber)?.number==1.0) {
code += translateExpression(binExpr.left, resultRegister)
code += VmCodeInstruction(Opcode.DEC, vmDt, reg1=resultRegister)
}
else {
val leftResultReg = codeGen.vmRegisters.nextFree()
val rightResultReg = codeGen.vmRegisters.nextFree()
code += translateExpression(binExpr.left, leftResultReg)
code += translateExpression(binExpr.right, rightResultReg)
code += VmCodeInstruction(Opcode.SUB, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
return code
}
private fun operatorPlus(binExpr: PtBinaryExpression, vmDt: VmDataType, resultRegister: Int): VmCodeChunk {
val code = VmCodeChunk()
if((binExpr.left as? PtNumber)?.number==1.0) {
code += translateExpression(binExpr.right, resultRegister)
code += VmCodeInstruction(Opcode.INC, vmDt, reg1=resultRegister)
}
else if((binExpr.right as? PtNumber)?.number==1.0) {
code += translateExpression(binExpr.left, resultRegister)
code += VmCodeInstruction(Opcode.INC, vmDt, reg1=resultRegister)
}
else {
val leftResultReg = codeGen.vmRegisters.nextFree()
val rightResultReg = codeGen.vmRegisters.nextFree()
code += translateExpression(binExpr.left, leftResultReg)
code += translateExpression(binExpr.right, rightResultReg)
code += VmCodeInstruction(Opcode.ADD, vmDt, reg1=resultRegister, reg2=leftResultReg, reg3=rightResultReg)
}
return code
}

View File

@ -6,68 +6,34 @@
main {
sub start() {
uword qq = 999 as ubyte |> abs() |> abs()
txt.print_uw(qq)
txt.nl()
; uword other = $fe4a
; uword value = $ea31
; uword[] warray = [$aa44, $bb55, $cc66]
; ubyte upperb = msb(value)
; ubyte lowerb = lsb(value)
; txt.print_ubhex(upperb, true)
; txt.print_ubhex(lowerb, false)
; uword ww = 100
; uword vv
; vv = ww+1
; txt.print_uw(vv)
; txt.nl()
; value = mkword(upperb, lowerb)
; txt.print_uwhex(value, true)
; txt.nl()
; upperb = msb(warray[1])
; lowerb = lsb(warray[1])
; txt.print_ubhex(upperb, true)
; txt.print_ubhex(lowerb, false)
; txt.nl()
; ubyte index=1
; upperb = msb(warray[index])
; lowerb = lsb(warray[index])
; txt.print_ubhex(upperb, true)
; txt.print_ubhex(lowerb, false)
; txt.nl()
; swap(other, value)
; txt.print_uwhex(value,true)
; txt.nl()
; txt.nl()
;
; pokew($1000, $ab98)
; txt.print_ubhex(@($1000),true)
; txt.print_ubhex(@($1001),false)
; txt.nl()
; txt.print_uwhex(peekw($1000),true)
; txt.nl()
; swap(@($1000), @($1001))
; txt.print_uwhex(peekw($1000),true)
; txt.nl()
; swap(warray[0], warray[1])
; txt.print_uwhex(warray[1],true)
; vv = ww * 8
; txt.print_uw(vv)
; txt.nl()
; ; a "pixelshader":
; void syscall1(8, 0) ; enable lo res creen
; ubyte shifter
;
; ; pokemon(1,0)
;
; repeat {
; uword xx
; uword yy = 0
; repeat 240 {
; xx = 0
; repeat 320 {
; syscall3(10, xx, yy, xx*yy + shifter) ; plot pixel
; xx++
; }
; yy++
; }
; shifter+=4
; }
; a "pixelshader":
void syscall1(8, 0) ; enable lo res creen
ubyte shifter
; pokemon(1,0)
repeat {
uword xx
uword yy = 0
repeat 240 {
xx = 0
repeat 320 {
syscall3(10, xx, yy, xx*yy + shifter) ; plot pixel
xx++
}
yy++
}
shifter+=4
}
}
}

View File

@ -1,6 +1,6 @@
#!/usr/bin/env sh
rm -f *.bin *.xex *.jar *.asm *.prg *.vm.txt *.vice-mon-list *.list a.out imgui.ini
rm -f *.bin *.xex *.jar *.asm *.prg *.vm.txt *.vice-mon-list *.list *.p8virt a.out imgui.ini
rm -rf build out
rm -rf compiler/build codeGenCpu6502/build codeGenExperimental/build codeOptimizers/build compilerAst/build dbusCompilerService/build httpCompilerService/build parser/build

View File

@ -120,9 +120,9 @@ All have type b or w.
and reg1, reg2, reg3 - reg1 = reg2 bitwise and reg3
or reg1, reg2, reg3 - reg1 = reg2 bitwise or reg3
xor reg1, reg2, reg3 - reg1 = reg2 bitwise xor reg3
lsrm reg1, reg2, reg3 - reg1 = multi-shift reg2 right by reg3 bits + set Carry to shifted bit
asrm reg1, reg2, reg3 - reg1 = multi-shift reg2 right by reg3 bits (signed) + set Carry to shifted bit
lslm reg1, reg2, reg3 - reg1 = multi-shift reg2 left by reg3 bits + set Carry to shifted bit
lsrx reg1, reg2, reg3 - reg1 = multi-shift reg2 right by reg3 bits + set Carry to shifted bit
asrx reg1, reg2, reg3 - reg1 = multi-shift reg2 right by reg3 bits (signed) + set Carry to shifted bit
lslx reg1, reg2, reg3 - reg1 = multi-shift reg2 left by reg3 bits + set Carry to shifted bit
lsr reg1 - shift reg1 right by 1 bits + set Carry to shifted bit
asr reg1 - shift reg1 right by 1 bits (signed) + set Carry to shifted bit
lsl reg1 - shift reg1 left by 1 bits + set Carry to shifted bit
@ -213,9 +213,9 @@ enum class Opcode {
AND,
OR,
XOR,
ASRM,
LSRM,
LSLM,
ASRX,
LSRX,
LSLX,
ASR,
LSR,
LSL,
@ -382,9 +382,9 @@ val instructionFormats = mutableMapOf(
Opcode.AND to InstructionFormat(BW, true, true, true, false),
Opcode.OR to InstructionFormat(BW, true, true, true, false),
Opcode.XOR to InstructionFormat(BW, true, true, true, false),
Opcode.ASRM to InstructionFormat(BW, true, true, true, false),
Opcode.LSRM to InstructionFormat(BW, true, true, true, false),
Opcode.LSLM to InstructionFormat(BW, true, true, true, false),
Opcode.ASRX to InstructionFormat(BW, true, true, true, false),
Opcode.LSRX to InstructionFormat(BW, true, true, true, false),
Opcode.LSLX to InstructionFormat(BW, true, true, true, false),
Opcode.ASR to InstructionFormat(BW, true, false, false, false),
Opcode.LSR to InstructionFormat(BW, true, false, false, false),
Opcode.LSL to InstructionFormat(BW, true, false, false, false),

View File

@ -144,9 +144,9 @@ class VirtualMachine(val memory: Memory, program: List<Instruction>) {
Opcode.AND -> InsAND(ins)
Opcode.OR -> InsOR(ins)
Opcode.XOR -> InsXOR(ins)
Opcode.ASRM -> InsASRM(ins)
Opcode.LSRM -> InsLSRM(ins)
Opcode.LSLM -> InsLSLM(ins)
Opcode.ASRX -> InsASRM(ins)
Opcode.LSRX -> InsLSRM(ins)
Opcode.LSLX -> InsLSLM(ins)
Opcode.ASR -> InsASR(ins)
Opcode.LSR -> InsLSR(ins)
Opcode.LSL -> InsLSL(ins)