use math.square for optimized X*X calculation (words only).

Added IR SQUARE instruction.
This commit is contained in:
Irmen de Jong 2023-08-14 00:50:40 +02:00
parent 923367296d
commit 2c9e50873c
12 changed files with 143 additions and 351 deletions

View File

@ -79,6 +79,8 @@ val BuiltinFunctions: Map<String, FSignature> = mapOf(
// cmp returns a status in the carry flag, but not a proper return value
"cmp" to FSignature(false, listOf(FParam("value1", IntegerDatatypesNoBool), FParam("value2", NumericDatatypesNoBool)), null),
"prog8_lib_stringcompare" to FSignature(true, listOf(FParam("str1", arrayOf(DataType.STR)), FParam("str2", arrayOf(DataType.STR))), DataType.BYTE),
"prog8_lib_square_byte" to FSignature(true, listOf(FParam("value", arrayOf(DataType.BYTE, DataType.UBYTE))), DataType.UBYTE),
"prog8_lib_square_word" to FSignature(true, listOf(FParam("value", arrayOf(DataType.WORD, DataType.UWORD))), DataType.UWORD),
"abs" to FSignature(true, listOf(FParam("value", NumericDatatypesNoBool)), null),
"abs__byte" to FSignature(true, listOf(FParam("value", arrayOf(DataType.BYTE))), DataType.BYTE),
"abs__word" to FSignature(true, listOf(FParam("value", arrayOf(DataType.WORD))), DataType.WORD),

View File

@ -72,12 +72,31 @@ internal class BuiltinFunctionsAsmGen(private val program: PtProgram,
"cmp" -> funcCmp(fcall)
"callfar" -> funcCallFar(fcall)
"prog8_lib_stringcompare" -> funcStringCompare(fcall)
"prog8_lib_square_byte" -> funcSquare(fcall, DataType.UBYTE)
"prog8_lib_square_word" -> funcSquare(fcall, DataType.UWORD)
else -> throw AssemblyError("missing asmgen for builtin func ${fcall.name}")
}
return BuiltinFunctions.getValue(fcall.name).returnType
}
private fun funcSquare(fcall: PtBuiltinFunctionCall, resultType: DataType) {
// square of word value is faster with dedicated routine, square of byte just use the regular multiplication routine.
when (resultType) {
DataType.UBYTE -> {
asmgen.assignExpressionToRegister(fcall.args[0], RegisterOrPair.A)
asmgen.out(" tay | jsr math.multiply_bytes")
}
DataType.UWORD -> {
asmgen.assignExpressionToRegister(fcall.args[0], RegisterOrPair.AY)
asmgen.out(" jsr math.square")
}
else -> {
throw AssemblyError("optimized square only for integer types")
}
}
}
private fun funcDivmod(fcall: PtBuiltinFunctionCall) {
assignAsmGen.assignExpressionToRegister(fcall.args[0], RegisterOrPair.A, false)
asmgen.saveRegisterStack(CpuRegister.A, false)

View File

@ -452,7 +452,6 @@ internal class ProgramAndVarsGen(
}
asmgen.out("""+
ldx #127 ; init estack ptr (half page)
clv
clc""")
}

View File

@ -45,10 +45,28 @@ internal class BuiltinFuncGen(private val codeGen: IRCodeGen, private val exprGe
"rol2" -> funcRolRor(Opcode.ROL, call)
"ror2" -> funcRolRor(Opcode.ROR, call)
"prog8_lib_stringcompare" -> funcStringCompare(call)
"prog8_lib_square_byte" -> funcSquare(call, IRDataType.BYTE)
"prog8_lib_square_word" -> funcSquare(call, IRDataType.WORD)
else -> throw AssemblyError("missing builtinfunc for ${call.name}")
}
}
private fun funcSquare(call: PtBuiltinFunctionCall, resultType: IRDataType): ExpressionCodeResult {
val result = mutableListOf<IRCodeChunkBase>()
val valueTr = exprGen.translateExpression(call.args[0])
addToResult(result, valueTr, valueTr.resultReg, valueTr.resultFpReg)
return if(resultType==IRDataType.FLOAT) {
val resultFpReg = codeGen.registers.nextFreeFloat()
addInstr(result, IRInstruction(Opcode.SQUARE, resultType, fpReg1 = resultFpReg, fpReg2 = valueTr.resultFpReg), null)
ExpressionCodeResult(result, resultType, -1, resultFpReg)
}
else {
val resultReg = codeGen.registers.nextFree()
addInstr(result, IRInstruction(Opcode.SQUARE, resultType, reg1 = resultReg, reg2 = valueTr.resultReg), null)
ExpressionCodeResult(result, resultType, resultReg, -1)
}
}
private fun funcCallfar(call: PtBuiltinFunctionCall): ExpressionCodeResult {
val result = mutableListOf<IRCodeChunkBase>()
val bankTr = exprGen.translateExpression(call.args[0])

View File

@ -814,9 +814,9 @@ asl_word_AY .proc
square .proc
; -- calculate square root of signed word in AY, result in AY
; routine by Lee Davsion, source: http://6502.org/source/integers/square.htm
; using this routine is about twice as fast as doing a regular multiplication.
; -- calculate square of signed word (actually -255..255) in AY, result in AY
; routine by Lee Davison, source: http://6502.org/source/integers/square.htm
; using this routine is a lot faster as doing a regular multiplication (for words)
;
; Calculates the 16 bit unsigned integer square of the signed 16 bit integer in
; Numberl/Numberh. The result is always in the range 0 to 65025 and is held in

View File

@ -200,6 +200,15 @@ _after:
}
}
if(expr.operator=="*" && expr.inferType(program).isInteger && expr.left isSameAs expr.right) {
// replace squaring with call to builtin function to do this in a more optimized way
val function = if(expr.left.inferType(program).isBytes) "prog8_lib_square_byte" else "prog8_lib_square_word"
val squareCall = BuiltinFunctionCall(
IdentifierReference(listOf(function), expr.position),
mutableListOf(expr.left.copy()), expr.position)
return listOf(IAstModification.ReplaceNode(expr, squareCall, parent))
}
return noModifications
}
}

View File

@ -5,7 +5,7 @@ import io.kotest.core.spec.style.FunSpec
import io.kotest.matchers.ints.shouldBeGreaterThan
import io.kotest.matchers.shouldBe
import prog8.code.ast.*
import prog8.code.core.*
import prog8.code.core.DataType
import prog8.code.target.C64Target
import prog8.compiler.astprocessing.IntermediateAstMaker
import prog8tests.helpers.compileText
@ -26,16 +26,6 @@ class TestIntermediateAst: FunSpec({
}
"""
val target = C64Target()
val options = CompilationOptions(
OutputType.RAW,
CbmPrgLauncherType.NONE,
ZeropageType.DONTUSE,
emptyList(),
floats = false,
noSysInit = true,
compTarget = target,
loadAddress = target.machine.PROGRAM_LOAD_ADDRESS
)
val result = compileText(target, false, text, writeAssembly = false)!!
val ast = IntermediateAstMaker(result.compilerAst).transform()
ast.name shouldBe result.compilerAst.name

View File

@ -184,29 +184,32 @@ class BinaryExpression(var left: Expression, var operator: String, var right: Ex
val leftDt = left.inferType(program)
val rightDt = right.inferType(program)
// fun dynamicBooleanType(): InferredTypes.InferredType {
// // as a special case, an expression yielding a boolean result, adapts the result
// // type to what is required (byte or word), to avoid useless type casting
// return when (parent) {
// is TypecastExpression -> InferredTypes.InferredType.known((parent as TypecastExpression).type)
// is Assignment -> (parent as Assignment).target.inferType(program)
// else -> InferredTypes.InferredType.known(DataType.BOOL) // or UBYTE?
// }
// }
return when (operator) {
"+", "-", "*", "%", "/" -> {
if (!leftDt.isKnown || !rightDt.isKnown)
InferredTypes.unknown()
else {
try {
InferredTypes.knownFor(
val dt = InferredTypes.knownFor(
commonDatatype(
leftDt.getOr(DataType.BYTE),
rightDt.getOr(DataType.BYTE),
null, null
).first
)
if(operator=="*") {
// if both operands are the same, X*X is always positive.
if(left isSameAs right) {
if(dt.istype(DataType.BYTE))
InferredTypes.knownFor(DataType.UBYTE)
else if(dt.istype(DataType.WORD))
InferredTypes.knownFor(DataType.UWORD)
else
dt
} else
dt
} else
dt
} catch (x: FatalAstException) {
InferredTypes.unknown()
}

View File

@ -1,6 +1,8 @@
TODO
====
- check mult and sqrt routines with the benchmarked ones on https://github.com/TobyLobster/sqrt_test / https://github.com/TobyLobster/multiply_test
- is math.square still the fastest after this? (now used for word*word)
- [on branch:] investigate McCarthy evaluation again? this may also reduce code size perhaps for things like if a>4 or a<2 ....
- IR: reduce the number of branch instructions such as BEQ, BEQR, etc (gradually), replace with CMP(I) + status branch instruction
- IR: reduce amount of CMP/CMPI after instructions that set the status bits correctly (LOADs? INC? etc), but only after setting the status bits is verified!

View File

@ -1,339 +1,67 @@
%import textio
%zeropage basicsafe
cbm2 {
sub SETTIM(ubyte a, ubyte b, ubyte c) {
}
sub RDTIM16() -> uword {
return 0
}
}
main {
sub start() {
greater()
greater_signed()
less()
less_signed()
ubyte value
uword wvalue
ubyte other = 99
uword otherw = 99
greatereq()
greatereq_signed()
lesseq()
lesseq_signed()
value=13
wvalue=99
txt.print_ub(value*value)
txt.spc()
txt.print_uw(wvalue*wvalue)
txt.nl()
txt.print("byte multiply..")
cbm.SETTIM(0,0,0)
repeat 100 {
for value in 0 to 255 {
cx16.r0L = value*other
}
sub value(ubyte arg) -> ubyte {
cx16.r0++
return arg
}
txt.print_uw(cbm.RDTIM16())
txt.nl()
sub svalue(byte arg) -> byte {
cx16.r0++
return arg
txt.print("byte squares...")
cbm.SETTIM(0,0,0)
repeat 100 {
for value in 0 to 255 {
cx16.r0L = value*value
}
}
txt.print_uw(cbm.RDTIM16())
txt.nl()
sub greater () {
ubyte b1 = 10
ubyte b2 = 20
ubyte b3 = 10
txt.print(">(u): 101010: ")
ubyte xx
xx = b2>10
txt.print_ub(xx)
txt.spc()
xx = b2>20
txt.print_ub(xx)
txt.spc()
xx = b2>b1
txt.print_ub(xx)
txt.spc()
xx = b3>b1
txt.print_ub(xx)
txt.spc()
xx = b2>value(10)
txt.print_ub(xx)
txt.spc()
xx = b3>value(20)
txt.print_ub(xx)
txt.spc()
txt.print("word multiply..")
cbm.SETTIM(0,0,0)
repeat 50 {
for wvalue in 0 to 255 {
cx16.r0 = wvalue*otherw
}
}
txt.print_uw(cbm.RDTIM16())
txt.nl()
txt.print("word squares...")
cbm.SETTIM(0,0,0)
repeat 50 {
for wvalue in 0 to 255 {
cx16.r0 = wvalue*wvalue
}
}
txt.print_uw(cbm.RDTIM16())
txt.nl()
}
sub greater_signed () {
byte b1 = -20
byte b2 = -10
byte b3 = -20
txt.print(">(s): 101010: ")
ubyte xx
xx = b2 > -20
txt.print_ub(xx)
txt.spc()
xx = b2 > -10
txt.print_ub(xx)
txt.spc()
xx = b2>b1
txt.print_ub(xx)
txt.spc()
xx = b3>b1
txt.print_ub(xx)
txt.spc()
xx = b2>svalue(-20)
txt.print_ub(xx)
txt.spc()
xx = b3>svalue(-10)
txt.print_ub(xx)
txt.spc()
txt.nl()
}
sub less () {
ubyte b1 = 20
ubyte b2 = 10
ubyte b3 = 20
txt.print("<(u): 101010: ")
ubyte xx
xx = b2<20
txt.print_ub(xx)
txt.spc()
xx = b2<10
txt.print_ub(xx)
txt.spc()
xx = b2<b1
txt.print_ub(xx)
txt.spc()
xx = b3<b1
txt.print_ub(xx)
txt.spc()
xx = b2<value(20)
txt.print_ub(xx)
txt.spc()
xx = b2<value(10)
txt.print_ub(xx)
txt.spc()
txt.nl()
}
sub less_signed () {
byte b1 = -10
byte b2 = -20
byte b3 = -10
txt.print("<(s): 101010: ")
ubyte xx
xx = b2 < -10
txt.print_ub(xx)
txt.spc()
xx = b2 < -20
txt.print_ub(xx)
txt.spc()
xx = b2<b1
txt.print_ub(xx)
txt.spc()
xx = b3<b1
txt.print_ub(xx)
txt.spc()
xx = b2<svalue(-10)
txt.print_ub(xx)
txt.spc()
xx = b3<svalue(-20)
txt.print_ub(xx)
txt.spc()
txt.nl()
}
sub greatereq () {
ubyte b1 = 19
ubyte b2 = 20
ubyte b3 = 21
ubyte b4 = 20
txt.print(">=(u): 110110110: ")
ubyte xx
xx = b2>=19
txt.print_ub(xx)
txt.spc()
xx = b2>=20
txt.print_ub(xx)
txt.spc()
xx = b2>=21
txt.print_ub(xx)
txt.spc()
xx = b2>=b1
txt.print_ub(xx)
txt.spc()
xx = b2>=b4
txt.print_ub(xx)
txt.spc()
xx = b2>=b3
txt.print_ub(xx)
txt.spc()
xx = b2>=value(19)
txt.print_ub(xx)
txt.spc()
xx = b2>=value(20)
txt.print_ub(xx)
txt.spc()
xx = b2>=value(21)
txt.print_ub(xx)
txt.spc()
txt.nl()
}
sub greatereq_signed () {
byte b1 = -19
byte b2 = -20
byte b3 = -21
byte b4 = -20
txt.print(">=(s): 011011011: ")
ubyte xx
xx = b2>= -19
txt.print_ub(xx)
txt.spc()
xx = b2>= -20
txt.print_ub(xx)
txt.spc()
xx = b2>= -21
txt.print_ub(xx)
txt.spc()
xx = b2>=b1
txt.print_ub(xx)
txt.spc()
xx = b2>=b4
txt.print_ub(xx)
txt.spc()
xx = b2>=b3
txt.print_ub(xx)
txt.spc()
xx = b2>=value(-19)
txt.print_ub(xx)
txt.spc()
xx = b2>=value(-20)
txt.print_ub(xx)
txt.spc()
xx = b2>=value(-21)
txt.print_ub(xx)
txt.spc()
txt.nl()
}
sub lesseq () {
ubyte b1 = 19
ubyte b2 = 20
ubyte b3 = 21
ubyte b4 = 20
txt.print("<=(u): 011011011: ")
ubyte xx
xx = b2<=19
txt.print_ub(xx)
txt.spc()
xx = b2<=20
txt.print_ub(xx)
txt.spc()
xx = b2<=21
txt.print_ub(xx)
txt.spc()
xx = b2<=b1
txt.print_ub(xx)
txt.spc()
xx = b2<=b4
txt.print_ub(xx)
txt.spc()
xx = b2<=b3
txt.print_ub(xx)
txt.spc()
xx = b2<=value(19)
txt.print_ub(xx)
txt.spc()
xx = b2<=value(20)
txt.print_ub(xx)
txt.spc()
xx = b2<=value(21)
txt.print_ub(xx)
txt.spc()
txt.nl()
}
sub lesseq_signed () {
byte b1 = -19
byte b2 = -20
byte b3 = -21
byte b4 = -20
txt.print("<=(s): 110110110: ")
ubyte xx
xx = b2<= -19
txt.print_ub(xx)
txt.spc()
xx = b2<= -20
txt.print_ub(xx)
txt.spc()
xx = b2<= -21
txt.print_ub(xx)
txt.spc()
xx = b2<=b1
txt.print_ub(xx)
txt.spc()
xx = b2<=b4
txt.print_ub(xx)
txt.spc()
xx = b2<=b3
txt.print_ub(xx)
txt.spc()
xx = b2<=value(-19)
txt.print_ub(xx)
txt.spc()
xx = b2<=value(-20)
txt.print_ub(xx)
txt.spc()
xx = b2<=value(-21)
txt.print_ub(xx)
txt.spc()
txt.nl()
}
}

View File

@ -144,6 +144,7 @@ mod reg1, value - remainder (modulo) of unsigned div
divmodr reg1, reg2 - unsigned division reg1/reg2, storing division and remainder on value stack (so need to be POPped off)
divmod reg1, value - unsigned division reg1/value, storing division and remainder on value stack (so need to be POPped off)
sqrt reg1, reg2 - reg1 is the square root of reg2 (reg2 can be .w or .b, result type in reg1 is always .b) you can also use it with floating point types, fpreg1 and fpreg2 (result is also .f)
square reg1, reg2 - reg1 is the square of reg2 (reg2 can be .w or .b, result type in reg1 is always .b) you can also use it with floating point types, fpreg1 and fpreg2 (result is also .f)
sgn reg1, reg2 - reg1 is the sign of reg2 (0.b, 1.b or -1.b)
cmp reg1, reg2 - set processor status bits C, N, Z according to comparison of reg1 with reg2. (semantics taken from 6502/68000 CMP instruction)
cmpi reg1, value - set processor status bits C, N, Z according to comparison of reg1 with immediate value. (semantics taken from 6502/68000 CMP instruction)
@ -311,6 +312,7 @@ enum class Opcode {
DIVMODR,
DIVMOD,
SQRT,
SQUARE,
SGN,
CMP,
CMPI,
@ -601,6 +603,7 @@ val instructionFormats = mutableMapOf(
Opcode.DIVS to InstructionFormat.from("BW,<>r1,<i | F,<>fr1,<i"),
Opcode.DIVSM to InstructionFormat.from("BW,<r1,<>a | F,<fr1,<>a"),
Opcode.SQRT to InstructionFormat.from("BW,>r1,<r2 | F,>fr1,<fr2"),
Opcode.SQUARE to InstructionFormat.from("BW,>r1,<r2 | F,>fr1,<fr2"),
Opcode.SGN to InstructionFormat.from("BW,>r1,<r2 | F,>fr1,<fr2"),
Opcode.MODR to InstructionFormat.from("BW,<>r1,<r2"),
Opcode.MOD to InstructionFormat.from("BW,<>r1,<i"),

View File

@ -247,6 +247,7 @@ class VirtualMachine(irProgram: IRProgram) {
Opcode.CMP -> InsCMP(ins)
Opcode.CMPI -> InsCMPI(ins)
Opcode.SQRT -> InsSQRT(ins)
Opcode.SQUARE -> InsSQUARE(ins)
Opcode.EXT -> InsEXT(ins)
Opcode.EXTS -> InsEXTS(ins)
Opcode.ANDR -> InsANDR(ins)
@ -1247,6 +1248,24 @@ class VirtualMachine(irProgram: IRProgram) {
nextPc()
}
private fun InsSQUARE(i: IRInstruction) {
when(i.type!!) {
IRDataType.BYTE -> {
val value = registers.getUB(i.reg2!!).toDouble().toInt()
registers.setUB(i.reg1!!, (value*value).toUByte())
}
IRDataType.WORD -> {
val value = registers.getUW(i.reg2!!).toDouble().toInt()
registers.setUW(i.reg1!!, (value*value).toUShort())
}
IRDataType.FLOAT -> {
val value = registers.getFloat(i.fpReg2!!)
registers.setFloat(i.fpReg1!!, value*value)
}
}
nextPc()
}
private fun InsCMP(i: IRInstruction) {
val comparison: Int
when(i.type!!) {