From 2c9e50873c90a3f4c466b856a9b7cc8457da8bd6 Mon Sep 17 00:00:00 2001 From: Irmen de Jong Date: Mon, 14 Aug 2023 00:50:40 +0200 Subject: [PATCH] use math.square for optimized X*X calculation (words only). Added IR SQUARE instruction. --- .../src/prog8/code/core/BuiltinFunctions.kt | 4 +- .../codegen/cpu6502/BuiltinFunctionsAsmGen.kt | 19 + .../codegen/cpu6502/ProgramAndVarsGen.kt | 1 - .../codegen/intermediate/BuiltinFuncGen.kt | 18 + compiler/res/prog8lib/math.asm | 6 +- .../compiler/astprocessing/CodeDesugarer.kt | 9 + compiler/test/ast/TestIntermediateAst.kt | 12 +- .../prog8/ast/expressions/AstExpressions.kt | 25 +- docs/source/todo.rst | 2 + examples/test.p8 | 376 +++--------------- .../src/prog8/intermediate/IRInstructions.kt | 3 + virtualmachine/src/prog8/vm/VirtualMachine.kt | 19 + 12 files changed, 143 insertions(+), 351 deletions(-) diff --git a/codeCore/src/prog8/code/core/BuiltinFunctions.kt b/codeCore/src/prog8/code/core/BuiltinFunctions.kt index 5a69b8024..f2c31d629 100644 --- a/codeCore/src/prog8/code/core/BuiltinFunctions.kt +++ b/codeCore/src/prog8/code/core/BuiltinFunctions.kt @@ -78,7 +78,9 @@ val BuiltinFunctions: Map = mapOf( "reverse" to FSignature(false, listOf(FParam("array", ArrayDatatypes)), null), // cmp returns a status in the carry flag, but not a proper return value "cmp" to FSignature(false, listOf(FParam("value1", IntegerDatatypesNoBool), FParam("value2", NumericDatatypesNoBool)), null), - "prog8_lib_stringcompare" to FSignature(true, listOf(FParam("str1", arrayOf(DataType.STR)), FParam("str2", arrayOf(DataType.STR))), DataType.BYTE), + "prog8_lib_stringcompare" to FSignature(true, listOf(FParam("str1", arrayOf(DataType.STR)), FParam("str2", arrayOf(DataType.STR))), DataType.BYTE), + "prog8_lib_square_byte" to FSignature(true, listOf(FParam("value", arrayOf(DataType.BYTE, DataType.UBYTE))), DataType.UBYTE), + "prog8_lib_square_word" to FSignature(true, listOf(FParam("value", arrayOf(DataType.WORD, DataType.UWORD))), DataType.UWORD), "abs" to FSignature(true, listOf(FParam("value", NumericDatatypesNoBool)), null), "abs__byte" to FSignature(true, listOf(FParam("value", arrayOf(DataType.BYTE))), DataType.BYTE), "abs__word" to FSignature(true, listOf(FParam("value", arrayOf(DataType.WORD))), DataType.WORD), diff --git a/codeGenCpu6502/src/prog8/codegen/cpu6502/BuiltinFunctionsAsmGen.kt b/codeGenCpu6502/src/prog8/codegen/cpu6502/BuiltinFunctionsAsmGen.kt index 3d248e3ed..a33951db0 100644 --- a/codeGenCpu6502/src/prog8/codegen/cpu6502/BuiltinFunctionsAsmGen.kt +++ b/codeGenCpu6502/src/prog8/codegen/cpu6502/BuiltinFunctionsAsmGen.kt @@ -72,12 +72,31 @@ internal class BuiltinFunctionsAsmGen(private val program: PtProgram, "cmp" -> funcCmp(fcall) "callfar" -> funcCallFar(fcall) "prog8_lib_stringcompare" -> funcStringCompare(fcall) + "prog8_lib_square_byte" -> funcSquare(fcall, DataType.UBYTE) + "prog8_lib_square_word" -> funcSquare(fcall, DataType.UWORD) else -> throw AssemblyError("missing asmgen for builtin func ${fcall.name}") } return BuiltinFunctions.getValue(fcall.name).returnType } + private fun funcSquare(fcall: PtBuiltinFunctionCall, resultType: DataType) { + // square of word value is faster with dedicated routine, square of byte just use the regular multiplication routine. + when (resultType) { + DataType.UBYTE -> { + asmgen.assignExpressionToRegister(fcall.args[0], RegisterOrPair.A) + asmgen.out(" tay | jsr math.multiply_bytes") + } + DataType.UWORD -> { + asmgen.assignExpressionToRegister(fcall.args[0], RegisterOrPair.AY) + asmgen.out(" jsr math.square") + } + else -> { + throw AssemblyError("optimized square only for integer types") + } + } + } + private fun funcDivmod(fcall: PtBuiltinFunctionCall) { assignAsmGen.assignExpressionToRegister(fcall.args[0], RegisterOrPair.A, false) asmgen.saveRegisterStack(CpuRegister.A, false) diff --git a/codeGenCpu6502/src/prog8/codegen/cpu6502/ProgramAndVarsGen.kt b/codeGenCpu6502/src/prog8/codegen/cpu6502/ProgramAndVarsGen.kt index 53b7fd29a..0a4e68f35 100644 --- a/codeGenCpu6502/src/prog8/codegen/cpu6502/ProgramAndVarsGen.kt +++ b/codeGenCpu6502/src/prog8/codegen/cpu6502/ProgramAndVarsGen.kt @@ -452,7 +452,6 @@ internal class ProgramAndVarsGen( } asmgen.out("""+ - ldx #127 ; init estack ptr (half page) clv clc""") } diff --git a/codeGenIntermediate/src/prog8/codegen/intermediate/BuiltinFuncGen.kt b/codeGenIntermediate/src/prog8/codegen/intermediate/BuiltinFuncGen.kt index efd9cee58..4023e2483 100644 --- a/codeGenIntermediate/src/prog8/codegen/intermediate/BuiltinFuncGen.kt +++ b/codeGenIntermediate/src/prog8/codegen/intermediate/BuiltinFuncGen.kt @@ -45,10 +45,28 @@ internal class BuiltinFuncGen(private val codeGen: IRCodeGen, private val exprGe "rol2" -> funcRolRor(Opcode.ROL, call) "ror2" -> funcRolRor(Opcode.ROR, call) "prog8_lib_stringcompare" -> funcStringCompare(call) + "prog8_lib_square_byte" -> funcSquare(call, IRDataType.BYTE) + "prog8_lib_square_word" -> funcSquare(call, IRDataType.WORD) else -> throw AssemblyError("missing builtinfunc for ${call.name}") } } + private fun funcSquare(call: PtBuiltinFunctionCall, resultType: IRDataType): ExpressionCodeResult { + val result = mutableListOf() + val valueTr = exprGen.translateExpression(call.args[0]) + addToResult(result, valueTr, valueTr.resultReg, valueTr.resultFpReg) + return if(resultType==IRDataType.FLOAT) { + val resultFpReg = codeGen.registers.nextFreeFloat() + addInstr(result, IRInstruction(Opcode.SQUARE, resultType, fpReg1 = resultFpReg, fpReg2 = valueTr.resultFpReg), null) + ExpressionCodeResult(result, resultType, -1, resultFpReg) + } + else { + val resultReg = codeGen.registers.nextFree() + addInstr(result, IRInstruction(Opcode.SQUARE, resultType, reg1 = resultReg, reg2 = valueTr.resultReg), null) + ExpressionCodeResult(result, resultType, resultReg, -1) + } + } + private fun funcCallfar(call: PtBuiltinFunctionCall): ExpressionCodeResult { val result = mutableListOf() val bankTr = exprGen.translateExpression(call.args[0]) diff --git a/compiler/res/prog8lib/math.asm b/compiler/res/prog8lib/math.asm index abbbbc3bb..8bae268c5 100644 --- a/compiler/res/prog8lib/math.asm +++ b/compiler/res/prog8lib/math.asm @@ -814,9 +814,9 @@ asl_word_AY .proc square .proc -; -- calculate square root of signed word in AY, result in AY -; routine by Lee Davsion, source: http://6502.org/source/integers/square.htm -; using this routine is about twice as fast as doing a regular multiplication. +; -- calculate square of signed word (actually -255..255) in AY, result in AY +; routine by Lee Davison, source: http://6502.org/source/integers/square.htm +; using this routine is a lot faster as doing a regular multiplication (for words) ; ; Calculates the 16 bit unsigned integer square of the signed 16 bit integer in ; Numberl/Numberh. The result is always in the range 0 to 65025 and is held in diff --git a/compiler/src/prog8/compiler/astprocessing/CodeDesugarer.kt b/compiler/src/prog8/compiler/astprocessing/CodeDesugarer.kt index 7b9a5cb52..859236f24 100644 --- a/compiler/src/prog8/compiler/astprocessing/CodeDesugarer.kt +++ b/compiler/src/prog8/compiler/astprocessing/CodeDesugarer.kt @@ -200,6 +200,15 @@ _after: } } + if(expr.operator=="*" && expr.inferType(program).isInteger && expr.left isSameAs expr.right) { + // replace squaring with call to builtin function to do this in a more optimized way + val function = if(expr.left.inferType(program).isBytes) "prog8_lib_square_byte" else "prog8_lib_square_word" + val squareCall = BuiltinFunctionCall( + IdentifierReference(listOf(function), expr.position), + mutableListOf(expr.left.copy()), expr.position) + return listOf(IAstModification.ReplaceNode(expr, squareCall, parent)) + } + return noModifications } } diff --git a/compiler/test/ast/TestIntermediateAst.kt b/compiler/test/ast/TestIntermediateAst.kt index d259fa763..37d4da577 100644 --- a/compiler/test/ast/TestIntermediateAst.kt +++ b/compiler/test/ast/TestIntermediateAst.kt @@ -5,7 +5,7 @@ import io.kotest.core.spec.style.FunSpec import io.kotest.matchers.ints.shouldBeGreaterThan import io.kotest.matchers.shouldBe import prog8.code.ast.* -import prog8.code.core.* +import prog8.code.core.DataType import prog8.code.target.C64Target import prog8.compiler.astprocessing.IntermediateAstMaker import prog8tests.helpers.compileText @@ -26,16 +26,6 @@ class TestIntermediateAst: FunSpec({ } """ val target = C64Target() - val options = CompilationOptions( - OutputType.RAW, - CbmPrgLauncherType.NONE, - ZeropageType.DONTUSE, - emptyList(), - floats = false, - noSysInit = true, - compTarget = target, - loadAddress = target.machine.PROGRAM_LOAD_ADDRESS - ) val result = compileText(target, false, text, writeAssembly = false)!! val ast = IntermediateAstMaker(result.compilerAst).transform() ast.name shouldBe result.compilerAst.name diff --git a/compilerAst/src/prog8/ast/expressions/AstExpressions.kt b/compilerAst/src/prog8/ast/expressions/AstExpressions.kt index 1eeb9eb20..07abe821a 100644 --- a/compilerAst/src/prog8/ast/expressions/AstExpressions.kt +++ b/compilerAst/src/prog8/ast/expressions/AstExpressions.kt @@ -184,29 +184,32 @@ class BinaryExpression(var left: Expression, var operator: String, var right: Ex val leftDt = left.inferType(program) val rightDt = right.inferType(program) -// fun dynamicBooleanType(): InferredTypes.InferredType { -// // as a special case, an expression yielding a boolean result, adapts the result -// // type to what is required (byte or word), to avoid useless type casting -// return when (parent) { -// is TypecastExpression -> InferredTypes.InferredType.known((parent as TypecastExpression).type) -// is Assignment -> (parent as Assignment).target.inferType(program) -// else -> InferredTypes.InferredType.known(DataType.BOOL) // or UBYTE? -// } -// } - return when (operator) { "+", "-", "*", "%", "/" -> { if (!leftDt.isKnown || !rightDt.isKnown) InferredTypes.unknown() else { try { - InferredTypes.knownFor( + val dt = InferredTypes.knownFor( commonDatatype( leftDt.getOr(DataType.BYTE), rightDt.getOr(DataType.BYTE), null, null ).first ) + if(operator=="*") { + // if both operands are the same, X*X is always positive. + if(left isSameAs right) { + if(dt.istype(DataType.BYTE)) + InferredTypes.knownFor(DataType.UBYTE) + else if(dt.istype(DataType.WORD)) + InferredTypes.knownFor(DataType.UWORD) + else + dt + } else + dt + } else + dt } catch (x: FatalAstException) { InferredTypes.unknown() } diff --git a/docs/source/todo.rst b/docs/source/todo.rst index 024b27b62..3d0216b54 100644 --- a/docs/source/todo.rst +++ b/docs/source/todo.rst @@ -1,6 +1,8 @@ TODO ==== +- check mult and sqrt routines with the benchmarked ones on https://github.com/TobyLobster/sqrt_test / https://github.com/TobyLobster/multiply_test +- is math.square still the fastest after this? (now used for word*word) - [on branch:] investigate McCarthy evaluation again? this may also reduce code size perhaps for things like if a>4 or a<2 .... - IR: reduce the number of branch instructions such as BEQ, BEQR, etc (gradually), replace with CMP(I) + status branch instruction - IR: reduce amount of CMP/CMPI after instructions that set the status bits correctly (LOADs? INC? etc), but only after setting the status bits is verified! diff --git a/examples/test.p8 b/examples/test.p8 index 0ce6e11f0..4ea070018 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -1,339 +1,67 @@ %import textio %zeropage basicsafe +cbm2 { + sub SETTIM(ubyte a, ubyte b, ubyte c) { + } + sub RDTIM16() -> uword { + return 0 + } +} + main { sub start() { - greater() - greater_signed() - less() - less_signed() + ubyte value + uword wvalue + ubyte other = 99 + uword otherw = 99 - greatereq() - greatereq_signed() - lesseq() - lesseq_signed() - } - - sub value(ubyte arg) -> ubyte { - cx16.r0++ - return arg - } - - sub svalue(byte arg) -> byte { - cx16.r0++ - return arg - } - - sub greater () { - ubyte b1 = 10 - ubyte b2 = 20 - ubyte b3 = 10 - - txt.print(">(u): 101010: ") - ubyte xx - xx = b2>10 - txt.print_ub(xx) + value=13 + wvalue=99 + txt.print_ub(value*value) txt.spc() + txt.print_uw(wvalue*wvalue) + txt.nl() - xx = b2>20 - txt.print_ub(xx) - txt.spc() - xx = b2>b1 - txt.print_ub(xx) - txt.spc() + txt.print("byte multiply..") + cbm.SETTIM(0,0,0) + repeat 100 { + for value in 0 to 255 { + cx16.r0L = value*other + } + } + txt.print_uw(cbm.RDTIM16()) + txt.nl() - xx = b3>b1 - txt.print_ub(xx) - txt.spc() + txt.print("byte squares...") + cbm.SETTIM(0,0,0) + repeat 100 { + for value in 0 to 255 { + cx16.r0L = value*value + } + } + txt.print_uw(cbm.RDTIM16()) + txt.nl() - xx = b2>value(10) - txt.print_ub(xx) - txt.spc() - xx = b3>value(20) - txt.print_ub(xx) - txt.spc() + txt.print("word multiply..") + cbm.SETTIM(0,0,0) + repeat 50 { + for wvalue in 0 to 255 { + cx16.r0 = wvalue*otherw + } + } + txt.print_uw(cbm.RDTIM16()) + txt.nl() + txt.print("word squares...") + cbm.SETTIM(0,0,0) + repeat 50 { + for wvalue in 0 to 255 { + cx16.r0 = wvalue*wvalue + } + } + txt.print_uw(cbm.RDTIM16()) txt.nl() } - - sub greater_signed () { - byte b1 = -20 - byte b2 = -10 - byte b3 = -20 - - txt.print(">(s): 101010: ") - ubyte xx - xx = b2 > -20 - txt.print_ub(xx) - txt.spc() - - xx = b2 > -10 - txt.print_ub(xx) - txt.spc() - - xx = b2>b1 - txt.print_ub(xx) - txt.spc() - - xx = b3>b1 - txt.print_ub(xx) - txt.spc() - - xx = b2>svalue(-20) - txt.print_ub(xx) - txt.spc() - xx = b3>svalue(-10) - txt.print_ub(xx) - txt.spc() - - txt.nl() - } - - sub less () { - ubyte b1 = 20 - ubyte b2 = 10 - ubyte b3 = 20 - - txt.print("<(u): 101010: ") - ubyte xx - xx = b2<20 - txt.print_ub(xx) - txt.spc() - - xx = b2<10 - txt.print_ub(xx) - txt.spc() - - xx = b2=(u): 110110110: ") - ubyte xx - xx = b2>=19 - txt.print_ub(xx) - txt.spc() - - xx = b2>=20 - txt.print_ub(xx) - txt.spc() - - xx = b2>=21 - txt.print_ub(xx) - txt.spc() - - xx = b2>=b1 - txt.print_ub(xx) - txt.spc() - - xx = b2>=b4 - txt.print_ub(xx) - txt.spc() - - xx = b2>=b3 - txt.print_ub(xx) - txt.spc() - - xx = b2>=value(19) - txt.print_ub(xx) - txt.spc() - xx = b2>=value(20) - txt.print_ub(xx) - txt.spc() - xx = b2>=value(21) - txt.print_ub(xx) - txt.spc() - - txt.nl() - } - - sub greatereq_signed () { - byte b1 = -19 - byte b2 = -20 - byte b3 = -21 - byte b4 = -20 - - txt.print(">=(s): 011011011: ") - ubyte xx - xx = b2>= -19 - txt.print_ub(xx) - txt.spc() - - xx = b2>= -20 - txt.print_ub(xx) - txt.spc() - - xx = b2>= -21 - txt.print_ub(xx) - txt.spc() - - xx = b2>=b1 - txt.print_ub(xx) - txt.spc() - - xx = b2>=b4 - txt.print_ub(xx) - txt.spc() - - xx = b2>=b3 - txt.print_ub(xx) - txt.spc() - - xx = b2>=value(-19) - txt.print_ub(xx) - txt.spc() - xx = b2>=value(-20) - txt.print_ub(xx) - txt.spc() - xx = b2>=value(-21) - txt.print_ub(xx) - txt.spc() - - txt.nl() - } - - sub lesseq () { - ubyte b1 = 19 - ubyte b2 = 20 - ubyte b3 = 21 - ubyte b4 = 20 - - txt.print("<=(u): 011011011: ") - ubyte xx - xx = b2<=19 - txt.print_ub(xx) - txt.spc() - - xx = b2<=20 - txt.print_ub(xx) - txt.spc() - - xx = b2<=21 - txt.print_ub(xx) - txt.spc() - - xx = b2<=b1 - txt.print_ub(xx) - txt.spc() - - xx = b2<=b4 - txt.print_ub(xx) - txt.spc() - - xx = b2<=b3 - txt.print_ub(xx) - txt.spc() - - xx = b2<=value(19) - txt.print_ub(xx) - txt.spc() - xx = b2<=value(20) - txt.print_ub(xx) - txt.spc() - xx = b2<=value(21) - txt.print_ub(xx) - txt.spc() - - txt.nl() - } - - sub lesseq_signed () { - byte b1 = -19 - byte b2 = -20 - byte b3 = -21 - byte b4 = -20 - - txt.print("<=(s): 110110110: ") - ubyte xx - xx = b2<= -19 - txt.print_ub(xx) - txt.spc() - - xx = b2<= -20 - txt.print_ub(xx) - txt.spc() - - xx = b2<= -21 - txt.print_ub(xx) - txt.spc() - - xx = b2<=b1 - txt.print_ub(xx) - txt.spc() - - xx = b2<=b4 - txt.print_ub(xx) - txt.spc() - - xx = b2<=b3 - txt.print_ub(xx) - txt.spc() - - xx = b2<=value(-19) - txt.print_ub(xx) - txt.spc() - xx = b2<=value(-20) - txt.print_ub(xx) - txt.spc() - xx = b2<=value(-21) - txt.print_ub(xx) - txt.spc() - - txt.nl() - } - } diff --git a/intermediate/src/prog8/intermediate/IRInstructions.kt b/intermediate/src/prog8/intermediate/IRInstructions.kt index 6a4313dc4..06c7d6552 100644 --- a/intermediate/src/prog8/intermediate/IRInstructions.kt +++ b/intermediate/src/prog8/intermediate/IRInstructions.kt @@ -144,6 +144,7 @@ mod reg1, value - remainder (modulo) of unsigned div divmodr reg1, reg2 - unsigned division reg1/reg2, storing division and remainder on value stack (so need to be POPped off) divmod reg1, value - unsigned division reg1/value, storing division and remainder on value stack (so need to be POPped off) sqrt reg1, reg2 - reg1 is the square root of reg2 (reg2 can be .w or .b, result type in reg1 is always .b) you can also use it with floating point types, fpreg1 and fpreg2 (result is also .f) +square reg1, reg2 - reg1 is the square of reg2 (reg2 can be .w or .b, result type in reg1 is always .b) you can also use it with floating point types, fpreg1 and fpreg2 (result is also .f) sgn reg1, reg2 - reg1 is the sign of reg2 (0.b, 1.b or -1.b) cmp reg1, reg2 - set processor status bits C, N, Z according to comparison of reg1 with reg2. (semantics taken from 6502/68000 CMP instruction) cmpi reg1, value - set processor status bits C, N, Z according to comparison of reg1 with immediate value. (semantics taken from 6502/68000 CMP instruction) @@ -311,6 +312,7 @@ enum class Opcode { DIVMODR, DIVMOD, SQRT, + SQUARE, SGN, CMP, CMPI, @@ -601,6 +603,7 @@ val instructionFormats = mutableMapOf( Opcode.DIVS to InstructionFormat.from("BW,<>r1,fr1,a | F,a"), Opcode.SQRT to InstructionFormat.from("BW,>r1,fr1,r1,fr1,r1,fr1,r1,r1, InsCMP(ins) Opcode.CMPI -> InsCMPI(ins) Opcode.SQRT -> InsSQRT(ins) + Opcode.SQUARE -> InsSQUARE(ins) Opcode.EXT -> InsEXT(ins) Opcode.EXTS -> InsEXTS(ins) Opcode.ANDR -> InsANDR(ins) @@ -1247,6 +1248,24 @@ class VirtualMachine(irProgram: IRProgram) { nextPc() } + private fun InsSQUARE(i: IRInstruction) { + when(i.type!!) { + IRDataType.BYTE -> { + val value = registers.getUB(i.reg2!!).toDouble().toInt() + registers.setUB(i.reg1!!, (value*value).toUByte()) + } + IRDataType.WORD -> { + val value = registers.getUW(i.reg2!!).toDouble().toInt() + registers.setUW(i.reg1!!, (value*value).toUShort()) + } + IRDataType.FLOAT -> { + val value = registers.getFloat(i.fpReg2!!) + registers.setFloat(i.fpReg1!!, value*value) + } + } + nextPc() + } + private fun InsCMP(i: IRInstruction) { val comparison: Int when(i.type!!) {