From c07eda15b102fc594a7a447d717d1332e7903017 Mon Sep 17 00:00:00 2001 From: Irmen de Jong Date: Sat, 29 Apr 2023 14:22:04 +0200 Subject: [PATCH] adding min() and max() --- .../src/prog8/code/core/BuiltinFunctions.kt | 8 ++ .../codegen/cpu6502/BuiltinFunctionsAsmGen.kt | 122 ++++++++++++++++++ .../codegen/intermediate/AssignmentGen.kt | 6 +- .../codegen/intermediate/BuiltinFuncGen.kt | 47 ++++++- .../codegen/intermediate/ExpressionGen.kt | 34 ++--- .../prog8/codegen/intermediate/IRCodeGen.kt | 14 -- .../optimizer/ConstantIdentifierReplacer.kt | 33 +++++ .../src/prog8/compiler/BuiltinFunctions.kt | 87 ++++++++++++- docs/source/todo.rst | 11 +- examples/test.p8 | 33 +++-- .../src/prog8/intermediate/IRInstructions.kt | 19 ++- 11 files changed, 360 insertions(+), 54 deletions(-) diff --git a/codeCore/src/prog8/code/core/BuiltinFunctions.kt b/codeCore/src/prog8/code/core/BuiltinFunctions.kt index f3c617d7c..cf9f9a57c 100644 --- a/codeCore/src/prog8/code/core/BuiltinFunctions.kt +++ b/codeCore/src/prog8/code/core/BuiltinFunctions.kt @@ -93,6 +93,14 @@ val BuiltinFunctions: Map = mapOf( "lsb" to FSignature(true, listOf(FParam("value", arrayOf(DataType.UWORD, DataType.WORD))), DataType.UBYTE), "msb" to FSignature(true, listOf(FParam("value", arrayOf(DataType.UWORD, DataType.WORD))), DataType.UBYTE), "mkword" to FSignature(true, listOf(FParam("msb", arrayOf(DataType.UBYTE)), FParam("lsb", arrayOf(DataType.UBYTE))), DataType.UWORD), + "min__byte" to FSignature(true, listOf(FParam("val1", arrayOf(DataType.BYTE)), FParam("val2", arrayOf(DataType.BYTE))), DataType.BYTE), + "min__ubyte" to FSignature(true, listOf(FParam("val1", arrayOf(DataType.UBYTE)), FParam("val2", arrayOf(DataType.UBYTE))), DataType.UBYTE), + "min__word" to FSignature(true, listOf(FParam("val1", arrayOf(DataType.WORD)), FParam("val2", arrayOf(DataType.WORD))), DataType.WORD), + "min__uword" to FSignature(true, listOf(FParam("val1", arrayOf(DataType.UWORD)), FParam("val2", arrayOf(DataType.UWORD))), DataType.UWORD), + "max__byte" to FSignature(true, listOf(FParam("val1", arrayOf(DataType.BYTE)), FParam("val2", arrayOf(DataType.BYTE))), DataType.BYTE), + "max__ubyte" to FSignature(true, listOf(FParam("val1", arrayOf(DataType.UBYTE)), FParam("val2", arrayOf(DataType.UBYTE))), DataType.UBYTE), + "max__word" to FSignature(true, listOf(FParam("val1", arrayOf(DataType.WORD)), FParam("val2", arrayOf(DataType.WORD))), DataType.WORD), + "max__uword" to FSignature(true, listOf(FParam("val1", arrayOf(DataType.UWORD)), FParam("val2", arrayOf(DataType.UWORD))), DataType.UWORD), "peek" to FSignature(true, listOf(FParam("address", arrayOf(DataType.UWORD))), DataType.UBYTE), "peekw" to FSignature(true, listOf(FParam("address", arrayOf(DataType.UWORD))), DataType.UWORD), "poke" to FSignature(false, listOf(FParam("address", arrayOf(DataType.UWORD)), FParam("value", arrayOf(DataType.UBYTE))), null), diff --git a/codeGenCpu6502/src/prog8/codegen/cpu6502/BuiltinFunctionsAsmGen.kt b/codeGenCpu6502/src/prog8/codegen/cpu6502/BuiltinFunctionsAsmGen.kt index e7009ed50..1107a14fe 100644 --- a/codeGenCpu6502/src/prog8/codegen/cpu6502/BuiltinFunctionsAsmGen.kt +++ b/codeGenCpu6502/src/prog8/codegen/cpu6502/BuiltinFunctionsAsmGen.kt @@ -31,6 +31,8 @@ internal class BuiltinFunctionsAsmGen(private val program: PtProgram, "msb" -> funcMsb(fcall, resultToStack, resultRegister) "lsb" -> funcLsb(fcall, resultToStack, resultRegister) "mkword" -> funcMkword(fcall, resultToStack, resultRegister) + "min__byte", "min__ubyte", "min__word", "min__uword" -> funcMin(fcall, resultToStack, resultRegister) + "max__byte", "max__ubyte", "max__word", "max__uword" -> funcMax(fcall, resultToStack, resultRegister) "abs" -> funcAbs(fcall, resultToStack, resultRegister, sscope) "any", "all" -> funcAnyAll(fcall, resultToStack, resultRegister, sscope) "sgn" -> funcSgn(fcall, resultToStack, resultRegister, sscope) @@ -826,6 +828,126 @@ internal class BuiltinFunctionsAsmGen(private val program: PtProgram, } } + private fun funcMin(fcall: PtBuiltinFunctionCall, resultToStack: Boolean, resultRegister: RegisterOrPair?) { + val signed = fcall.type in SignedDatatypes + if(fcall.type in ByteDatatypes) { + asmgen.assignExpressionToVariable(fcall.args[1], "P8ZP_SCRATCH_B1", fcall.type) // right + asmgen.assignExpressionToRegister(fcall.args[0], RegisterOrPair.A) // left + asmgen.out(" cmp P8ZP_SCRATCH_B1") + if(signed) asmgen.out(" bmi +") else asmgen.out(" bcc +") + asmgen.out(""" + lda P8ZP_SCRATCH_B1 ++""") + if(resultToStack) { + asmgen.out(" sta P8ESTACK_LO,x | dex") + } else { + val targetReg = AsmAssignTarget.fromRegisters(resultRegister!!, signed, fcall.position, fcall.definingISub(), asmgen) + asmgen.assignRegister(RegisterOrPair.A, targetReg) + } + } else { + asmgen.assignExpressionToVariable(fcall.args[0], "P8ZP_SCRATCH_W1", fcall.type) // left + asmgen.assignExpressionToVariable(fcall.args[1], "P8ZP_SCRATCH_W2", fcall.type) // right + if(signed) { + asmgen.out(""" + lda P8ZP_SCRATCH_W1 + ldy P8ZP_SCRATCH_W1+1 + cmp P8ZP_SCRATCH_W2 + tya + sbc P8ZP_SCRATCH_W2+1 + bvc + + eor #$80 ++ bpl + + lda P8ZP_SCRATCH_W1 + ldy P8ZP_SCRATCH_W1+1 + jmp ++ ++ lda P8ZP_SCRATCH_W2 + ldy P8ZP_SCRATCH_W2+1 ++""") + } else { + asmgen.out(""" + lda P8ZP_SCRATCH_W1+1 + cmp P8ZP_SCRATCH_W2+1 + bcc ++ + bne + + lda P8ZP_SCRATCH_W1 + cmp P8ZP_SCRATCH_W2 + bcc ++ ++ lda P8ZP_SCRATCH_W2 + ldy P8ZP_SCRATCH_W2+1 + jmp ++ ++ lda P8ZP_SCRATCH_W1 + ldy P8ZP_SCRATCH_W1+1 ++""") + } + if(resultToStack) { + asmgen.out(" sta P8ESTACK_LO,x | sty P8ESTACK_HI,x | dex") + } else { + val targetReg = AsmAssignTarget.fromRegisters(resultRegister!!, signed, fcall.position, fcall.definingISub(), asmgen) + asmgen.assignRegister(RegisterOrPair.AY, targetReg) + } + } + } + + private fun funcMax(fcall: PtBuiltinFunctionCall, resultToStack: Boolean, resultRegister: RegisterOrPair?) { + val signed = fcall.type in SignedDatatypes + if(fcall.type in ByteDatatypes) { + asmgen.assignExpressionToVariable(fcall.args[0], "P8ZP_SCRATCH_B1", fcall.type) // left + asmgen.assignExpressionToRegister(fcall.args[1], RegisterOrPair.A) // right + asmgen.out(" cmp P8ZP_SCRATCH_B1") + if(signed) asmgen.out(" bpl +") else asmgen.out(" bcs +") + asmgen.out(""" + lda P8ZP_SCRATCH_B1 ++""") + if(resultToStack) { + asmgen.out(" sta P8ESTACK_LO,x | dex") + } else { + val targetReg = AsmAssignTarget.fromRegisters(resultRegister!!, signed, fcall.position, fcall.definingISub(), asmgen) + asmgen.assignRegister(RegisterOrPair.A, targetReg) + } + } else { + asmgen.assignExpressionToVariable(fcall.args[0], "P8ZP_SCRATCH_W1", fcall.type) // left + asmgen.assignExpressionToVariable(fcall.args[1], "P8ZP_SCRATCH_W2", fcall.type) // right + if(signed) { + asmgen.out(""" + lda P8ZP_SCRATCH_W1 + ldy P8ZP_SCRATCH_W1+1 + cmp P8ZP_SCRATCH_W2 + tya + sbc P8ZP_SCRATCH_W2+1 + bvc + + eor #$80 ++ bmi + + lda P8ZP_SCRATCH_W1 + ldy P8ZP_SCRATCH_W1+1 + jmp ++ ++ lda P8ZP_SCRATCH_W2 + ldy P8ZP_SCRATCH_W2+1 ++""") + } else { + asmgen.out(""" + lda P8ZP_SCRATCH_W1+1 + cmp P8ZP_SCRATCH_W2+1 + bcc ++ + bne + + lda P8ZP_SCRATCH_W1 + cmp P8ZP_SCRATCH_W2 + bcc ++ ++ lda P8ZP_SCRATCH_W1 + ldy P8ZP_SCRATCH_W1+1 + jmp ++ ++ lda P8ZP_SCRATCH_W2 + ldy P8ZP_SCRATCH_W2+1 ++""") + } + if(resultToStack) { + asmgen.out(" sta P8ESTACK_LO,x | sty P8ESTACK_HI,x | dex") + } else { + val targetReg = AsmAssignTarget.fromRegisters(resultRegister!!, signed, fcall.position, fcall.definingISub(), asmgen) + asmgen.assignRegister(RegisterOrPair.AY, targetReg) + } + } + } + private fun funcMkword(fcall: PtBuiltinFunctionCall, resultToStack: Boolean, resultRegister: RegisterOrPair?) { if(resultToStack) { asmgen.assignExpressionToRegister(fcall.args[0], RegisterOrPair.A) // msb diff --git a/codeGenIntermediate/src/prog8/codegen/intermediate/AssignmentGen.kt b/codeGenIntermediate/src/prog8/codegen/intermediate/AssignmentGen.kt index 53b30902b..8ce38d3b9 100644 --- a/codeGenIntermediate/src/prog8/codegen/intermediate/AssignmentGen.kt +++ b/codeGenIntermediate/src/prog8/codegen/intermediate/AssignmentGen.kt @@ -46,7 +46,7 @@ internal class AssignmentGen(private val codeGen: IRCodeGen, private val express assignment: PtAugmentedAssign ): IRCodeChunks { val value = assignment.value - val vmDt = codeGen.irType(value.type) + val vmDt = irType(value.type) return when(assignment.operator) { "+" -> expressionEval.operatorPlusInplace(address, null, vmDt, value) "-" -> expressionEval.operatorMinusInplace(address, null, vmDt, value) @@ -72,7 +72,7 @@ internal class AssignmentGen(private val codeGen: IRCodeGen, private val express private fun assignVarAugmented(symbol: String, assignment: PtAugmentedAssign): IRCodeChunks { val value = assignment.value - val targetDt = codeGen.irType(assignment.target.type) + val targetDt = irType(assignment.target.type) return when (assignment.operator) { "+=" -> expressionEval.operatorPlusInplace(null, symbol, targetDt, value) "-=" -> expressionEval.operatorMinusInplace(null, symbol, targetDt, value) @@ -161,7 +161,7 @@ internal class AssignmentGen(private val codeGen: IRCodeGen, private val express val targetIdent = assignment.target.identifier val targetMemory = assignment.target.memory val targetArray = assignment.target.array - val vmDt = codeGen.irType(assignment.value.type) + val vmDt = irType(assignment.value.type) val result = mutableListOf() var valueRegister = -1 diff --git a/codeGenIntermediate/src/prog8/codegen/intermediate/BuiltinFuncGen.kt b/codeGenIntermediate/src/prog8/codegen/intermediate/BuiltinFuncGen.kt index 43008e04c..823ebd086 100644 --- a/codeGenIntermediate/src/prog8/codegen/intermediate/BuiltinFuncGen.kt +++ b/codeGenIntermediate/src/prog8/codegen/intermediate/BuiltinFuncGen.kt @@ -4,6 +4,7 @@ import prog8.code.StStaticVariable import prog8.code.ast.* import prog8.code.core.AssemblyError import prog8.code.core.DataType +import prog8.code.core.SignedDatatypes import prog8.intermediate.* @@ -37,6 +38,8 @@ internal class BuiltinFuncGen(private val codeGen: IRCodeGen, private val exprGe "pokew" -> funcPokeW(call) "pokemon" -> ExpressionCodeResult.EMPTY // easter egg function "mkword" -> funcMkword(call) + "min__byte", "min__ubyte", "min__word", "min__uword" -> funcMin(call) + "max__byte", "max__ubyte", "max__word", "max__uword" -> funcMax(call) "sort" -> funcSort(call) "reverse" -> funcReverse(call) "rol" -> funcRolRor(Opcode.ROXL, call) @@ -96,7 +99,7 @@ internal class BuiltinFuncGen(private val codeGen: IRCodeGen, private val exprGe addToResult(result, leftTr, leftTr.resultReg, -1) val rightTr = exprGen.translateExpression(call.args[1]) addToResult(result, rightTr, rightTr.resultReg, -1) - val dt = codeGen.irType(call.args[0].type) + val dt = irType(call.args[0].type) result += IRCodeChunk(null, null).also { it += IRInstruction(Opcode.CMP, dt, reg1=leftTr.resultReg, reg2=rightTr.resultReg) } @@ -199,7 +202,7 @@ internal class BuiltinFuncGen(private val codeGen: IRCodeGen, private val exprGe private fun funcSgn(call: PtBuiltinFunctionCall): ExpressionCodeResult { val result = mutableListOf() - val vmDt = codeGen.irType(call.type) + val vmDt = irType(call.type) val tr = exprGen.translateExpression(call.args.single()) addToResult(result, tr, tr.resultReg, -1) val resultReg = codeGen.registers.nextFree() @@ -317,6 +320,44 @@ internal class BuiltinFuncGen(private val codeGen: IRCodeGen, private val exprGe return ExpressionCodeResult(result, IRDataType.WORD, lsbTr.resultReg, -1) } + private fun funcMin(call: PtBuiltinFunctionCall): ExpressionCodeResult { + val type = irType(call.type) + val result = mutableListOf() + val leftTr = exprGen.translateExpression(call.args[0]) + addToResult(result, leftTr, leftTr.resultReg, -1) + val rightTr = exprGen.translateExpression(call.args[1]) + addToResult(result, rightTr, rightTr.resultReg, -1) + val comparisonOpcode = if(call.type in SignedDatatypes) Opcode.BGTSR else Opcode.BGTR + val after = codeGen.createLabelName() + result += IRCodeChunk(null, null).also { + it += IRInstruction(comparisonOpcode, type, reg1 = rightTr.resultReg, reg2 = leftTr.resultReg, labelSymbol = after) + // right <= left, take right + it += IRInstruction(Opcode.LOADR, type, reg1=leftTr.resultReg, reg2=rightTr.resultReg) + it += IRInstruction(Opcode.JUMP, labelSymbol = after) + } + result += IRCodeChunk(after, null) + return ExpressionCodeResult(result, type, leftTr.resultReg, -1) + } + + private fun funcMax(call: PtBuiltinFunctionCall): ExpressionCodeResult { + val type = irType(call.type) + val result = mutableListOf() + val leftTr = exprGen.translateExpression(call.args[0]) + addToResult(result, leftTr, leftTr.resultReg, -1) + val rightTr = exprGen.translateExpression(call.args[1]) + addToResult(result, rightTr, rightTr.resultReg, -1) + val comparisonOpcode = if(call.type in SignedDatatypes) Opcode.BGTSR else Opcode.BGTR + val after = codeGen.createLabelName() + result += IRCodeChunk(null, null).also { + it += IRInstruction(comparisonOpcode, type, reg1 = leftTr.resultReg, reg2 = rightTr.resultReg, labelSymbol = after) + // right >= left, take right + it += IRInstruction(Opcode.LOADR, type, reg1=leftTr.resultReg, reg2=rightTr.resultReg) + it += IRInstruction(Opcode.JUMP, labelSymbol = after) + } + result += IRCodeChunk(after, null) + return ExpressionCodeResult(result, type, leftTr.resultReg, -1) + } + private fun funcPokeW(call: PtBuiltinFunctionCall): ExpressionCodeResult { val result = mutableListOf() if(codeGen.isZero(call.args[1])) { @@ -455,7 +496,7 @@ internal class BuiltinFuncGen(private val codeGen: IRCodeGen, private val exprGe } private fun funcRolRor(opcode: Opcode, call: PtBuiltinFunctionCall): ExpressionCodeResult { - val vmDt = codeGen.irType(call.args[0].type) + val vmDt = irType(call.args[0].type) val result = mutableListOf() val tr = exprGen.translateExpression(call.args[0]) addToResult(result, tr, tr.resultReg, -1) diff --git a/codeGenIntermediate/src/prog8/codegen/intermediate/ExpressionGen.kt b/codeGenIntermediate/src/prog8/codegen/intermediate/ExpressionGen.kt index b302f61c4..195683f4e 100644 --- a/codeGenIntermediate/src/prog8/codegen/intermediate/ExpressionGen.kt +++ b/codeGenIntermediate/src/prog8/codegen/intermediate/ExpressionGen.kt @@ -28,10 +28,10 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) { fun translateExpression(expr: PtExpression): ExpressionCodeResult { return when (expr) { is PtMachineRegister -> { - ExpressionCodeResult(emptyList(), codeGen.irType(expr.type), expr.register, -1) + ExpressionCodeResult(emptyList(), irType(expr.type), expr.register, -1) } is PtNumber -> { - val vmDt = codeGen.irType(expr.type) + val vmDt = irType(expr.type) val code = IRCodeChunk(null, null) if(vmDt==IRDataType.FLOAT) { val resultFpRegister = codeGen.registers.nextFreeFloat() @@ -45,7 +45,7 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) { } } is PtIdentifier -> { - val vmDt = codeGen.irType(expr.type) + val vmDt = irType(expr.type) val code = IRCodeChunk(null, null) if (expr.type in PassByValueDatatypes) { if(vmDt==IRDataType.FLOAT) { @@ -66,7 +66,7 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) { } } is PtAddressOf -> { - val vmDt = codeGen.irType(expr.type) + val vmDt = irType(expr.type) val symbol = expr.identifier.name // note: LOAD gets you the address of the symbol, whereas LOADM would get you the value stored at that location val code = IRCodeChunk(null, null) @@ -160,7 +160,7 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) { private fun translate(arrayIx: PtArrayIndexer): ExpressionCodeResult { val eltSize = codeGen.program.memsizer.memorySize(arrayIx.type) - val vmDt = codeGen.irType(arrayIx.type) + val vmDt = irType(arrayIx.type) val result = mutableListOf() val arrayVarSymbol = arrayIx.variable.name @@ -210,7 +210,7 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) { val result = mutableListOf() val tr = translateExpression(expr.value) addToResult(result, tr, tr.resultReg, tr.resultFpReg) - val vmDt = codeGen.irType(expr.type) + val vmDt = irType(expr.type) when(expr.operator) { "+" -> { } "-" -> { @@ -326,12 +326,12 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) { else -> throw AssemblyError("weird cast type") } - return ExpressionCodeResult(result, codeGen.irType(cast.type), actualResultReg2, actualResultFpReg2) + return ExpressionCodeResult(result, irType(cast.type), actualResultReg2, actualResultFpReg2) } private fun translate(binExpr: PtBinaryExpression): ExpressionCodeResult { require(!codeGen.options.useNewExprCode) - val vmDt = codeGen.irType(binExpr.left.type) + val vmDt = irType(binExpr.left.type) val signed = binExpr.left.type in SignedDatatypes return when(binExpr.operator) { "+" -> operatorPlus(binExpr, vmDt) @@ -360,7 +360,7 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) { val result = mutableListOf() for ((index, argspec) in fcall.args.zip(callTarget.parameters).withIndex()) { val (arg, param) = argspec - val paramDt = codeGen.irType(param.type) + val paramDt = irType(param.type) val tr = translateExpression(arg) result += tr.chunks if(paramDt==IRDataType.FLOAT) @@ -369,7 +369,7 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) { addInstr(result, IRInstruction(Opcode.SETPARAM, paramDt, reg1 = tr.resultReg, immediate = index), null) } // for ((arg, parameter) in fcall.args.zip(callTarget.parameters)) { -// val paramDt = codeGen.irType(parameter.type) +// val paramDt = irType(parameter.type) // val symbol = "${fcall.name}.${parameter.name}" // if(codeGen.isZero(arg)) { // addInstr(result, IRInstruction(Opcode.STOREZM, paramDt, labelSymbol = symbol), null) @@ -396,15 +396,15 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) { addInstr(result, IRInstruction(Opcode.CALLR, IRDataType.FLOAT, fpReg1=resultFpReg, labelSymbol=fcall.name), null) } else { resultReg = codeGen.registers.nextFree() - addInstr(result, IRInstruction(Opcode.CALLR, codeGen.irType(fcall.type), reg1=resultReg, labelSymbol=fcall.name), null) + addInstr(result, IRInstruction(Opcode.CALLR, irType(fcall.type), reg1=resultReg, labelSymbol=fcall.name), null) } - ExpressionCodeResult(result, codeGen.irType(fcall.type), resultReg, resultFpReg) + ExpressionCodeResult(result, irType(fcall.type), resultReg, resultFpReg) } } is StRomSub -> { val result = mutableListOf() for ((arg, parameter) in fcall.args.zip(callTarget.parameters)) { - val paramDt = codeGen.irType(parameter.type) + val paramDt = irType(parameter.type) val paramRegStr = if(parameter.register.registerOrPair!=null) parameter.register.registerOrPair.toString() else parameter.register.statusflag.toString() if(codeGen.isZero(arg)) { addInstr(result, IRInstruction(Opcode.STOREZCPU, paramDt, labelSymbol = paramRegStr), null) @@ -427,20 +427,20 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) { throw AssemblyError("doesn't support float register result in asm romsub") val returns = callTarget.returns.single() val regStr = if(returns.register.registerOrPair!=null) returns.register.registerOrPair.toString() else returns.register.statusflag.toString() - addInstr(result, IRInstruction(Opcode.LOADCPU, codeGen.irType(fcall.type), reg1=resultReg, labelSymbol = regStr), null) + addInstr(result, IRInstruction(Opcode.LOADCPU, irType(fcall.type), reg1=resultReg, labelSymbol = regStr), null) } else -> { val returnRegister = callTarget.returns.singleOrNull{ it.register.registerOrPair!=null } if(returnRegister!=null) { // we skip the other values returned in the status flags. val regStr = returnRegister.register.registerOrPair.toString() - addInstr(result, IRInstruction(Opcode.LOADCPU, codeGen.irType(fcall.type), reg1=resultReg, labelSymbol = regStr), null) + addInstr(result, IRInstruction(Opcode.LOADCPU, irType(fcall.type), reg1=resultReg, labelSymbol = regStr), null) } else { val firstReturnRegister = callTarget.returns.firstOrNull{ it.register.registerOrPair!=null } if(firstReturnRegister!=null) { // we just take the first register return value and ignore the rest. val regStr = firstReturnRegister.register.registerOrPair.toString() - addInstr(result, IRInstruction(Opcode.LOADCPU, codeGen.irType(fcall.type), reg1=resultReg, labelSymbol = regStr), null) + addInstr(result, IRInstruction(Opcode.LOADCPU, irType(fcall.type), reg1=resultReg, labelSymbol = regStr), null) } else { throw AssemblyError("invalid number of return values from call") } @@ -448,7 +448,7 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) { } } } - return ExpressionCodeResult(result, if(fcall.void) IRDataType.BYTE else codeGen.irType(fcall.type), resultReg, -1) + return ExpressionCodeResult(result, if(fcall.void) IRDataType.BYTE else irType(fcall.type), resultReg, -1) } else -> throw AssemblyError("invalid node type") } diff --git a/codeGenIntermediate/src/prog8/codegen/intermediate/IRCodeGen.kt b/codeGenIntermediate/src/prog8/codegen/intermediate/IRCodeGen.kt index 529e8121d..fa09e63f0 100644 --- a/codeGenIntermediate/src/prog8/codegen/intermediate/IRCodeGen.kt +++ b/codeGenIntermediate/src/prog8/codegen/intermediate/IRCodeGen.kt @@ -1524,20 +1524,6 @@ class IRCodeGen( } } - - internal fun irType(type: DataType): IRDataType { - return when(type) { - DataType.BOOL, - DataType.UBYTE, - DataType.BYTE -> IRDataType.BYTE - DataType.UWORD, - DataType.WORD -> IRDataType.WORD - DataType.FLOAT -> IRDataType.FLOAT - in PassByReferenceDatatypes -> IRDataType.WORD - else -> throw AssemblyError("no IR datatype for $type") - } - } - private var labelSequenceNumber = 0 internal fun createLabelName(): String { labelSequenceNumber++ diff --git a/codeOptimizers/src/prog8/optimizer/ConstantIdentifierReplacer.kt b/codeOptimizers/src/prog8/optimizer/ConstantIdentifierReplacer.kt index 45cd1caaf..7765daac8 100644 --- a/codeOptimizers/src/prog8/optimizer/ConstantIdentifierReplacer.kt +++ b/codeOptimizers/src/prog8/optimizer/ConstantIdentifierReplacer.kt @@ -68,6 +68,39 @@ class VarConstantValueTypeAdjuster(private val program: Program, private val err return noModifications } + + override fun after(functionCallExpr: FunctionCallExpression, parent: Node): Iterable { + // choose specific builtin function for the given types + val func = functionCallExpr.target.nameInSource + if(func==listOf("min") || func==listOf("max")) { + val t1 = functionCallExpr.args[0].inferType(program) + val t2 = functionCallExpr.args[1].inferType(program) + if(t1.isKnown && t2.isKnown) { + val funcName = func[0] + val replaceFunc: String + if(t1.isBytes && t2.isBytes) { + replaceFunc = if(t1.istype(DataType.BYTE) || t2.istype(DataType.BYTE)) + "${funcName}__byte" + else + "${funcName}__ubyte" + } else if(t1.isInteger && t2.isInteger) { + replaceFunc = if(t1.istype(DataType.WORD) || t2.istype(DataType.WORD)) + "${funcName}__word" + else + "${funcName}__uword" + } else if(t1.isNumeric && t2.isNumeric) { + replaceFunc = "${funcName}__float" + } else { + errors.err("expected numeric arguments", functionCallExpr.position) + return noModifications + } + return listOf(IAstModification.SetExpression({functionCallExpr.target = it as IdentifierReference}, + IdentifierReference(listOf(replaceFunc), functionCallExpr.target.position), + functionCallExpr)) + } + } + return noModifications + } } diff --git a/compiler/src/prog8/compiler/BuiltinFunctions.kt b/compiler/src/prog8/compiler/BuiltinFunctions.kt index e22ee39f0..7bb987d8c 100644 --- a/compiler/src/prog8/compiler/BuiltinFunctions.kt +++ b/compiler/src/prog8/compiler/BuiltinFunctions.kt @@ -7,10 +7,7 @@ import prog8.ast.base.SyntaxError import prog8.ast.expressions.* import prog8.ast.statements.VarDecl import prog8.code.core.* -import kotlin.math.abs -import kotlin.math.sign -import kotlin.math.sqrt - +import kotlin.math.* private typealias ConstExpressionCaller = (args: List, position: Position, program: Program) -> NumericLiteral @@ -24,7 +21,15 @@ internal val constEvaluatorsForBuiltinFuncs: Map "all" to { a, p, prg -> collectionArg(a, p, prg, ::builtinAll) }, "lsb" to { a, p, prg -> oneIntArgOutputInt(a, p, prg) { x: Int -> (x and 255).toDouble() } }, "msb" to { a, p, prg -> oneIntArgOutputInt(a, p, prg) { x: Int -> (x ushr 8 and 255).toDouble()} }, - "mkword" to ::builtinMkword + "mkword" to ::builtinMkword, + "min__ubyte" to ::builtinMinUByte, + "min__byte" to ::builtinMinByte, + "min__uword" to ::builtinMinUWord, + "min__word" to ::builtinMinWord, + "max__ubyte" to ::builtinMaxUByte, + "max__byte" to ::builtinMaxByte, + "max__uword" to ::builtinMaxUWord, + "max__word" to ::builtinMaxWord, ) private fun builtinAny(array: List): Double = if(array.any { it!=0.0 }) 1.0 else 0.0 @@ -156,3 +161,75 @@ private fun builtinSgn(args: List, position: Position, program: Prog val constval = args[0].constValue(program) ?: throw NotConstArgumentException() return NumericLiteral(DataType.BYTE, constval.number.sign, position) } + +private fun builtinMinByte(args: List, position: Position, program: Program): NumericLiteral { + if (args.size != 2) + throw SyntaxError("min requires 2 arguments", position) + val val1 = args[0].constValue(program) ?: throw NotConstArgumentException() + val val2 = args[1].constValue(program) ?: throw NotConstArgumentException() + val result = min(val1.number.toInt(), val2.number.toInt()) + return NumericLiteral(DataType.BYTE, result.toDouble(), position) +} + +private fun builtinMinUByte(args: List, position: Position, program: Program): NumericLiteral { + if (args.size != 2) + throw SyntaxError("min requires 2 arguments", position) + val val1 = args[0].constValue(program) ?: throw NotConstArgumentException() + val val2 = args[1].constValue(program) ?: throw NotConstArgumentException() + val result = min(val1.number.toInt(), val2.number.toInt()) + return NumericLiteral(DataType.UBYTE, result.toDouble(), position) +} + +private fun builtinMinWord(args: List, position: Position, program: Program): NumericLiteral { + if (args.size != 2) + throw SyntaxError("min requires 2 arguments", position) + val val1 = args[0].constValue(program) ?: throw NotConstArgumentException() + val val2 = args[1].constValue(program) ?: throw NotConstArgumentException() + val result = min(val1.number.toInt(), val2.number.toInt()) + return NumericLiteral(DataType.WORD, result.toDouble(), position) +} + +private fun builtinMinUWord(args: List, position: Position, program: Program): NumericLiteral { + if (args.size != 2) + throw SyntaxError("min requires 2 arguments", position) + val val1 = args[0].constValue(program) ?: throw NotConstArgumentException() + val val2 = args[1].constValue(program) ?: throw NotConstArgumentException() + val result = min(val1.number.toInt(), val2.number.toInt()) + return NumericLiteral(DataType.UWORD, result.toDouble(), position) +} + +private fun builtinMaxByte(args: List, position: Position, program: Program): NumericLiteral { + if (args.size != 2) + throw SyntaxError("max requires 2 arguments", position) + val val1 = args[0].constValue(program) ?: throw NotConstArgumentException() + val val2 = args[1].constValue(program) ?: throw NotConstArgumentException() + val result = max(val1.number.toInt(), val2.number.toInt()) + return NumericLiteral(DataType.BYTE, result.toDouble(), position) +} + +private fun builtinMaxUByte(args: List, position: Position, program: Program): NumericLiteral { + if (args.size != 2) + throw SyntaxError("max requires 2 arguments", position) + val val1 = args[0].constValue(program) ?: throw NotConstArgumentException() + val val2 = args[1].constValue(program) ?: throw NotConstArgumentException() + val result = max(val1.number.toInt(), val2.number.toInt()) + return NumericLiteral(DataType.UBYTE, result.toDouble(), position) +} + +private fun builtinMaxWord(args: List, position: Position, program: Program): NumericLiteral { + if (args.size != 2) + throw SyntaxError("max requires 2 arguments", position) + val val1 = args[0].constValue(program) ?: throw NotConstArgumentException() + val val2 = args[1].constValue(program) ?: throw NotConstArgumentException() + val result = max(val1.number.toInt(), val2.number.toInt()) + return NumericLiteral(DataType.WORD, result.toDouble(), position) +} + +private fun builtinMaxUWord(args: List, position: Position, program: Program): NumericLiteral { + if (args.size != 2) + throw SyntaxError("max requires 2 arguments", position) + val val1 = args[0].constValue(program) ?: throw NotConstArgumentException() + val val2 = args[1].constValue(program) ?: throw NotConstArgumentException() + val result = max(val1.number.toInt(), val2.number.toInt()) + return NumericLiteral(DataType.UWORD, result.toDouble(), position) +} diff --git a/docs/source/todo.rst b/docs/source/todo.rst index 0342f2af5..f85097abb 100644 --- a/docs/source/todo.rst +++ b/docs/source/todo.rst @@ -1,6 +1,14 @@ TODO ==== +- try to reintroduce builtin functions max/maxw/min/minw that take 2 args and return the largest/smallest of them. + This is a major change because it will likely break existing code that is now using min and max as variable names. + Add "polymorphism" that translates min -> min__ubyte etc etc. + Also add optimization that changes the word variant to byte variant if the operands are bytes. + Add to docs. + +- add polymorphism to other builtin functions as well! Fix docs. + - once 9.0 is stable, upgrade other programs (assem, shell, etc) to it. ... @@ -8,9 +16,6 @@ TODO For 9.0 major changes ^^^^^^^^^^^^^^^^^^^^^ -- try to reintroduce builtin functions max/maxw/min/minw that take 2 args and return the largest/smallest of them. - This is a major change because it will likely break existing code that is now using min and max as variable names. - Also add optimization that changes the word variant to byte variant if the operands are bytes. - 6502 codegen: see if we can let for loops skip the loop if startvar>endvar, without adding a lot of code size/duplicating the loop condition. It is documented behavior to now loop 'around' $00 but it's too easy to forget about! Lot of work because of so many special cases in ForLoopsAsmgen..... diff --git a/examples/test.p8 b/examples/test.p8 index 6f54b5e35..6bceb6414 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -4,16 +4,33 @@ main { sub start() { - txt.print("hello") - ; foobar() - } + ubyte v1 = 11 + ubyte v2 = 88 + byte v1s = 22 + byte v2s = -99 - asmsub foobar() { - %asm {{ - nop - rts + uword w1 = 1111 + uword w2 = 8888 + word w1s = 2222 + word w2s = -9999 - }} + txt.print_uw(min(v1, v2)) + txt.spc() + txt.print_w(min(v1s, v2s)) + txt.spc() + txt.print_uw(max(v1, v2)) + txt.spc() + txt.print_w(max(v1s, v2s)) + txt.nl() + + txt.print_uw(min(w1, w2)) + txt.spc() + txt.print_w(min(w1s, w2s)) + txt.spc() + txt.print_uw(max(w1, w2)) + txt.spc() + txt.print_w(max(w1s, w2s)) + txt.nl() } } diff --git a/intermediate/src/prog8/intermediate/IRInstructions.kt b/intermediate/src/prog8/intermediate/IRInstructions.kt index a2921842e..dcfe1aff5 100644 --- a/intermediate/src/prog8/intermediate/IRInstructions.kt +++ b/intermediate/src/prog8/intermediate/IRInstructions.kt @@ -1,5 +1,8 @@ package prog8.intermediate +import prog8.code.core.AssemblyError +import prog8.code.core.DataType +import prog8.code.core.PassByReferenceDatatypes import prog8.code.core.toHex /* @@ -88,7 +91,7 @@ bger reg1, reg2, address - jump to location in program given by l bgesr reg1, reg2, address - jump to location in program given by location, if reg1 >= reg2 (signed) ble reg1, value, address - jump to location in program given by location, if reg1 <= immediate value (unsigned) bles reg1, value, address - jump to location in program given by location, if reg1 <= immediate value (signed) -( NOTE: there are no bltr/bler instructions because these are equivalent to bgtr/bger with the register operands swapped around.) + ( NOTE: there are no bltr/bler instructions because these are equivalent to bgtr/bger with the register operands swapped around.) sz reg1, reg2 - set reg1=1 if reg2==0, otherwise set reg1=0 snz reg1, reg2 - set reg1=1 if reg2!=0, otherwise set reg1=0 seq reg1, reg2 - set reg1=1 if reg1 == reg2, otherwise set reg1=0 @@ -853,3 +856,17 @@ data class IRInstruction( return result.joinToString("").trimEnd() } } + + +fun irType(type: DataType): IRDataType { + return when(type) { + DataType.BOOL, + DataType.UBYTE, + DataType.BYTE -> IRDataType.BYTE + DataType.UWORD, + DataType.WORD -> IRDataType.WORD + DataType.FLOAT -> IRDataType.FLOAT + in PassByReferenceDatatypes -> IRDataType.WORD + else -> throw AssemblyError("no IR datatype for $type") + } +}