adding min() and max()

This commit is contained in:
Irmen de Jong 2023-04-29 14:22:04 +02:00
parent 4274296cf3
commit c07eda15b1
11 changed files with 360 additions and 54 deletions

View File

@ -93,6 +93,14 @@ val BuiltinFunctions: Map<String, FSignature> = mapOf(
"lsb" to FSignature(true, listOf(FParam("value", arrayOf(DataType.UWORD, DataType.WORD))), DataType.UBYTE), "lsb" to FSignature(true, listOf(FParam("value", arrayOf(DataType.UWORD, DataType.WORD))), DataType.UBYTE),
"msb" to FSignature(true, listOf(FParam("value", arrayOf(DataType.UWORD, DataType.WORD))), DataType.UBYTE), "msb" to FSignature(true, listOf(FParam("value", arrayOf(DataType.UWORD, DataType.WORD))), DataType.UBYTE),
"mkword" to FSignature(true, listOf(FParam("msb", arrayOf(DataType.UBYTE)), FParam("lsb", arrayOf(DataType.UBYTE))), DataType.UWORD), "mkword" to FSignature(true, listOf(FParam("msb", arrayOf(DataType.UBYTE)), FParam("lsb", arrayOf(DataType.UBYTE))), DataType.UWORD),
"min__byte" to FSignature(true, listOf(FParam("val1", arrayOf(DataType.BYTE)), FParam("val2", arrayOf(DataType.BYTE))), DataType.BYTE),
"min__ubyte" to FSignature(true, listOf(FParam("val1", arrayOf(DataType.UBYTE)), FParam("val2", arrayOf(DataType.UBYTE))), DataType.UBYTE),
"min__word" to FSignature(true, listOf(FParam("val1", arrayOf(DataType.WORD)), FParam("val2", arrayOf(DataType.WORD))), DataType.WORD),
"min__uword" to FSignature(true, listOf(FParam("val1", arrayOf(DataType.UWORD)), FParam("val2", arrayOf(DataType.UWORD))), DataType.UWORD),
"max__byte" to FSignature(true, listOf(FParam("val1", arrayOf(DataType.BYTE)), FParam("val2", arrayOf(DataType.BYTE))), DataType.BYTE),
"max__ubyte" to FSignature(true, listOf(FParam("val1", arrayOf(DataType.UBYTE)), FParam("val2", arrayOf(DataType.UBYTE))), DataType.UBYTE),
"max__word" to FSignature(true, listOf(FParam("val1", arrayOf(DataType.WORD)), FParam("val2", arrayOf(DataType.WORD))), DataType.WORD),
"max__uword" to FSignature(true, listOf(FParam("val1", arrayOf(DataType.UWORD)), FParam("val2", arrayOf(DataType.UWORD))), DataType.UWORD),
"peek" to FSignature(true, listOf(FParam("address", arrayOf(DataType.UWORD))), DataType.UBYTE), "peek" to FSignature(true, listOf(FParam("address", arrayOf(DataType.UWORD))), DataType.UBYTE),
"peekw" to FSignature(true, listOf(FParam("address", arrayOf(DataType.UWORD))), DataType.UWORD), "peekw" to FSignature(true, listOf(FParam("address", arrayOf(DataType.UWORD))), DataType.UWORD),
"poke" to FSignature(false, listOf(FParam("address", arrayOf(DataType.UWORD)), FParam("value", arrayOf(DataType.UBYTE))), null), "poke" to FSignature(false, listOf(FParam("address", arrayOf(DataType.UWORD)), FParam("value", arrayOf(DataType.UBYTE))), null),

View File

@ -31,6 +31,8 @@ internal class BuiltinFunctionsAsmGen(private val program: PtProgram,
"msb" -> funcMsb(fcall, resultToStack, resultRegister) "msb" -> funcMsb(fcall, resultToStack, resultRegister)
"lsb" -> funcLsb(fcall, resultToStack, resultRegister) "lsb" -> funcLsb(fcall, resultToStack, resultRegister)
"mkword" -> funcMkword(fcall, resultToStack, resultRegister) "mkword" -> funcMkword(fcall, resultToStack, resultRegister)
"min__byte", "min__ubyte", "min__word", "min__uword" -> funcMin(fcall, resultToStack, resultRegister)
"max__byte", "max__ubyte", "max__word", "max__uword" -> funcMax(fcall, resultToStack, resultRegister)
"abs" -> funcAbs(fcall, resultToStack, resultRegister, sscope) "abs" -> funcAbs(fcall, resultToStack, resultRegister, sscope)
"any", "all" -> funcAnyAll(fcall, resultToStack, resultRegister, sscope) "any", "all" -> funcAnyAll(fcall, resultToStack, resultRegister, sscope)
"sgn" -> funcSgn(fcall, resultToStack, resultRegister, sscope) "sgn" -> funcSgn(fcall, resultToStack, resultRegister, sscope)
@ -826,6 +828,126 @@ internal class BuiltinFunctionsAsmGen(private val program: PtProgram,
} }
} }
private fun funcMin(fcall: PtBuiltinFunctionCall, resultToStack: Boolean, resultRegister: RegisterOrPair?) {
val signed = fcall.type in SignedDatatypes
if(fcall.type in ByteDatatypes) {
asmgen.assignExpressionToVariable(fcall.args[1], "P8ZP_SCRATCH_B1", fcall.type) // right
asmgen.assignExpressionToRegister(fcall.args[0], RegisterOrPair.A) // left
asmgen.out(" cmp P8ZP_SCRATCH_B1")
if(signed) asmgen.out(" bmi +") else asmgen.out(" bcc +")
asmgen.out("""
lda P8ZP_SCRATCH_B1
+""")
if(resultToStack) {
asmgen.out(" sta P8ESTACK_LO,x | dex")
} else {
val targetReg = AsmAssignTarget.fromRegisters(resultRegister!!, signed, fcall.position, fcall.definingISub(), asmgen)
asmgen.assignRegister(RegisterOrPair.A, targetReg)
}
} else {
asmgen.assignExpressionToVariable(fcall.args[0], "P8ZP_SCRATCH_W1", fcall.type) // left
asmgen.assignExpressionToVariable(fcall.args[1], "P8ZP_SCRATCH_W2", fcall.type) // right
if(signed) {
asmgen.out("""
lda P8ZP_SCRATCH_W1
ldy P8ZP_SCRATCH_W1+1
cmp P8ZP_SCRATCH_W2
tya
sbc P8ZP_SCRATCH_W2+1
bvc +
eor #$80
+ bpl +
lda P8ZP_SCRATCH_W1
ldy P8ZP_SCRATCH_W1+1
jmp ++
+ lda P8ZP_SCRATCH_W2
ldy P8ZP_SCRATCH_W2+1
+""")
} else {
asmgen.out("""
lda P8ZP_SCRATCH_W1+1
cmp P8ZP_SCRATCH_W2+1
bcc ++
bne +
lda P8ZP_SCRATCH_W1
cmp P8ZP_SCRATCH_W2
bcc ++
+ lda P8ZP_SCRATCH_W2
ldy P8ZP_SCRATCH_W2+1
jmp ++
+ lda P8ZP_SCRATCH_W1
ldy P8ZP_SCRATCH_W1+1
+""")
}
if(resultToStack) {
asmgen.out(" sta P8ESTACK_LO,x | sty P8ESTACK_HI,x | dex")
} else {
val targetReg = AsmAssignTarget.fromRegisters(resultRegister!!, signed, fcall.position, fcall.definingISub(), asmgen)
asmgen.assignRegister(RegisterOrPair.AY, targetReg)
}
}
}
private fun funcMax(fcall: PtBuiltinFunctionCall, resultToStack: Boolean, resultRegister: RegisterOrPair?) {
val signed = fcall.type in SignedDatatypes
if(fcall.type in ByteDatatypes) {
asmgen.assignExpressionToVariable(fcall.args[0], "P8ZP_SCRATCH_B1", fcall.type) // left
asmgen.assignExpressionToRegister(fcall.args[1], RegisterOrPair.A) // right
asmgen.out(" cmp P8ZP_SCRATCH_B1")
if(signed) asmgen.out(" bpl +") else asmgen.out(" bcs +")
asmgen.out("""
lda P8ZP_SCRATCH_B1
+""")
if(resultToStack) {
asmgen.out(" sta P8ESTACK_LO,x | dex")
} else {
val targetReg = AsmAssignTarget.fromRegisters(resultRegister!!, signed, fcall.position, fcall.definingISub(), asmgen)
asmgen.assignRegister(RegisterOrPair.A, targetReg)
}
} else {
asmgen.assignExpressionToVariable(fcall.args[0], "P8ZP_SCRATCH_W1", fcall.type) // left
asmgen.assignExpressionToVariable(fcall.args[1], "P8ZP_SCRATCH_W2", fcall.type) // right
if(signed) {
asmgen.out("""
lda P8ZP_SCRATCH_W1
ldy P8ZP_SCRATCH_W1+1
cmp P8ZP_SCRATCH_W2
tya
sbc P8ZP_SCRATCH_W2+1
bvc +
eor #$80
+ bmi +
lda P8ZP_SCRATCH_W1
ldy P8ZP_SCRATCH_W1+1
jmp ++
+ lda P8ZP_SCRATCH_W2
ldy P8ZP_SCRATCH_W2+1
+""")
} else {
asmgen.out("""
lda P8ZP_SCRATCH_W1+1
cmp P8ZP_SCRATCH_W2+1
bcc ++
bne +
lda P8ZP_SCRATCH_W1
cmp P8ZP_SCRATCH_W2
bcc ++
+ lda P8ZP_SCRATCH_W1
ldy P8ZP_SCRATCH_W1+1
jmp ++
+ lda P8ZP_SCRATCH_W2
ldy P8ZP_SCRATCH_W2+1
+""")
}
if(resultToStack) {
asmgen.out(" sta P8ESTACK_LO,x | sty P8ESTACK_HI,x | dex")
} else {
val targetReg = AsmAssignTarget.fromRegisters(resultRegister!!, signed, fcall.position, fcall.definingISub(), asmgen)
asmgen.assignRegister(RegisterOrPair.AY, targetReg)
}
}
}
private fun funcMkword(fcall: PtBuiltinFunctionCall, resultToStack: Boolean, resultRegister: RegisterOrPair?) { private fun funcMkword(fcall: PtBuiltinFunctionCall, resultToStack: Boolean, resultRegister: RegisterOrPair?) {
if(resultToStack) { if(resultToStack) {
asmgen.assignExpressionToRegister(fcall.args[0], RegisterOrPair.A) // msb asmgen.assignExpressionToRegister(fcall.args[0], RegisterOrPair.A) // msb

View File

@ -46,7 +46,7 @@ internal class AssignmentGen(private val codeGen: IRCodeGen, private val express
assignment: PtAugmentedAssign assignment: PtAugmentedAssign
): IRCodeChunks { ): IRCodeChunks {
val value = assignment.value val value = assignment.value
val vmDt = codeGen.irType(value.type) val vmDt = irType(value.type)
return when(assignment.operator) { return when(assignment.operator) {
"+" -> expressionEval.operatorPlusInplace(address, null, vmDt, value) "+" -> expressionEval.operatorPlusInplace(address, null, vmDt, value)
"-" -> expressionEval.operatorMinusInplace(address, null, vmDt, value) "-" -> expressionEval.operatorMinusInplace(address, null, vmDt, value)
@ -72,7 +72,7 @@ internal class AssignmentGen(private val codeGen: IRCodeGen, private val express
private fun assignVarAugmented(symbol: String, assignment: PtAugmentedAssign): IRCodeChunks { private fun assignVarAugmented(symbol: String, assignment: PtAugmentedAssign): IRCodeChunks {
val value = assignment.value val value = assignment.value
val targetDt = codeGen.irType(assignment.target.type) val targetDt = irType(assignment.target.type)
return when (assignment.operator) { return when (assignment.operator) {
"+=" -> expressionEval.operatorPlusInplace(null, symbol, targetDt, value) "+=" -> expressionEval.operatorPlusInplace(null, symbol, targetDt, value)
"-=" -> expressionEval.operatorMinusInplace(null, symbol, targetDt, value) "-=" -> expressionEval.operatorMinusInplace(null, symbol, targetDt, value)
@ -161,7 +161,7 @@ internal class AssignmentGen(private val codeGen: IRCodeGen, private val express
val targetIdent = assignment.target.identifier val targetIdent = assignment.target.identifier
val targetMemory = assignment.target.memory val targetMemory = assignment.target.memory
val targetArray = assignment.target.array val targetArray = assignment.target.array
val vmDt = codeGen.irType(assignment.value.type) val vmDt = irType(assignment.value.type)
val result = mutableListOf<IRCodeChunkBase>() val result = mutableListOf<IRCodeChunkBase>()
var valueRegister = -1 var valueRegister = -1

View File

@ -4,6 +4,7 @@ import prog8.code.StStaticVariable
import prog8.code.ast.* import prog8.code.ast.*
import prog8.code.core.AssemblyError import prog8.code.core.AssemblyError
import prog8.code.core.DataType import prog8.code.core.DataType
import prog8.code.core.SignedDatatypes
import prog8.intermediate.* import prog8.intermediate.*
@ -37,6 +38,8 @@ internal class BuiltinFuncGen(private val codeGen: IRCodeGen, private val exprGe
"pokew" -> funcPokeW(call) "pokew" -> funcPokeW(call)
"pokemon" -> ExpressionCodeResult.EMPTY // easter egg function "pokemon" -> ExpressionCodeResult.EMPTY // easter egg function
"mkword" -> funcMkword(call) "mkword" -> funcMkword(call)
"min__byte", "min__ubyte", "min__word", "min__uword" -> funcMin(call)
"max__byte", "max__ubyte", "max__word", "max__uword" -> funcMax(call)
"sort" -> funcSort(call) "sort" -> funcSort(call)
"reverse" -> funcReverse(call) "reverse" -> funcReverse(call)
"rol" -> funcRolRor(Opcode.ROXL, call) "rol" -> funcRolRor(Opcode.ROXL, call)
@ -96,7 +99,7 @@ internal class BuiltinFuncGen(private val codeGen: IRCodeGen, private val exprGe
addToResult(result, leftTr, leftTr.resultReg, -1) addToResult(result, leftTr, leftTr.resultReg, -1)
val rightTr = exprGen.translateExpression(call.args[1]) val rightTr = exprGen.translateExpression(call.args[1])
addToResult(result, rightTr, rightTr.resultReg, -1) addToResult(result, rightTr, rightTr.resultReg, -1)
val dt = codeGen.irType(call.args[0].type) val dt = irType(call.args[0].type)
result += IRCodeChunk(null, null).also { result += IRCodeChunk(null, null).also {
it += IRInstruction(Opcode.CMP, dt, reg1=leftTr.resultReg, reg2=rightTr.resultReg) it += IRInstruction(Opcode.CMP, dt, reg1=leftTr.resultReg, reg2=rightTr.resultReg)
} }
@ -199,7 +202,7 @@ internal class BuiltinFuncGen(private val codeGen: IRCodeGen, private val exprGe
private fun funcSgn(call: PtBuiltinFunctionCall): ExpressionCodeResult { private fun funcSgn(call: PtBuiltinFunctionCall): ExpressionCodeResult {
val result = mutableListOf<IRCodeChunkBase>() val result = mutableListOf<IRCodeChunkBase>()
val vmDt = codeGen.irType(call.type) val vmDt = irType(call.type)
val tr = exprGen.translateExpression(call.args.single()) val tr = exprGen.translateExpression(call.args.single())
addToResult(result, tr, tr.resultReg, -1) addToResult(result, tr, tr.resultReg, -1)
val resultReg = codeGen.registers.nextFree() val resultReg = codeGen.registers.nextFree()
@ -317,6 +320,44 @@ internal class BuiltinFuncGen(private val codeGen: IRCodeGen, private val exprGe
return ExpressionCodeResult(result, IRDataType.WORD, lsbTr.resultReg, -1) return ExpressionCodeResult(result, IRDataType.WORD, lsbTr.resultReg, -1)
} }
private fun funcMin(call: PtBuiltinFunctionCall): ExpressionCodeResult {
val type = irType(call.type)
val result = mutableListOf<IRCodeChunkBase>()
val leftTr = exprGen.translateExpression(call.args[0])
addToResult(result, leftTr, leftTr.resultReg, -1)
val rightTr = exprGen.translateExpression(call.args[1])
addToResult(result, rightTr, rightTr.resultReg, -1)
val comparisonOpcode = if(call.type in SignedDatatypes) Opcode.BGTSR else Opcode.BGTR
val after = codeGen.createLabelName()
result += IRCodeChunk(null, null).also {
it += IRInstruction(comparisonOpcode, type, reg1 = rightTr.resultReg, reg2 = leftTr.resultReg, labelSymbol = after)
// right <= left, take right
it += IRInstruction(Opcode.LOADR, type, reg1=leftTr.resultReg, reg2=rightTr.resultReg)
it += IRInstruction(Opcode.JUMP, labelSymbol = after)
}
result += IRCodeChunk(after, null)
return ExpressionCodeResult(result, type, leftTr.resultReg, -1)
}
private fun funcMax(call: PtBuiltinFunctionCall): ExpressionCodeResult {
val type = irType(call.type)
val result = mutableListOf<IRCodeChunkBase>()
val leftTr = exprGen.translateExpression(call.args[0])
addToResult(result, leftTr, leftTr.resultReg, -1)
val rightTr = exprGen.translateExpression(call.args[1])
addToResult(result, rightTr, rightTr.resultReg, -1)
val comparisonOpcode = if(call.type in SignedDatatypes) Opcode.BGTSR else Opcode.BGTR
val after = codeGen.createLabelName()
result += IRCodeChunk(null, null).also {
it += IRInstruction(comparisonOpcode, type, reg1 = leftTr.resultReg, reg2 = rightTr.resultReg, labelSymbol = after)
// right >= left, take right
it += IRInstruction(Opcode.LOADR, type, reg1=leftTr.resultReg, reg2=rightTr.resultReg)
it += IRInstruction(Opcode.JUMP, labelSymbol = after)
}
result += IRCodeChunk(after, null)
return ExpressionCodeResult(result, type, leftTr.resultReg, -1)
}
private fun funcPokeW(call: PtBuiltinFunctionCall): ExpressionCodeResult { private fun funcPokeW(call: PtBuiltinFunctionCall): ExpressionCodeResult {
val result = mutableListOf<IRCodeChunkBase>() val result = mutableListOf<IRCodeChunkBase>()
if(codeGen.isZero(call.args[1])) { if(codeGen.isZero(call.args[1])) {
@ -455,7 +496,7 @@ internal class BuiltinFuncGen(private val codeGen: IRCodeGen, private val exprGe
} }
private fun funcRolRor(opcode: Opcode, call: PtBuiltinFunctionCall): ExpressionCodeResult { private fun funcRolRor(opcode: Opcode, call: PtBuiltinFunctionCall): ExpressionCodeResult {
val vmDt = codeGen.irType(call.args[0].type) val vmDt = irType(call.args[0].type)
val result = mutableListOf<IRCodeChunkBase>() val result = mutableListOf<IRCodeChunkBase>()
val tr = exprGen.translateExpression(call.args[0]) val tr = exprGen.translateExpression(call.args[0])
addToResult(result, tr, tr.resultReg, -1) addToResult(result, tr, tr.resultReg, -1)

View File

@ -28,10 +28,10 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) {
fun translateExpression(expr: PtExpression): ExpressionCodeResult { fun translateExpression(expr: PtExpression): ExpressionCodeResult {
return when (expr) { return when (expr) {
is PtMachineRegister -> { is PtMachineRegister -> {
ExpressionCodeResult(emptyList(), codeGen.irType(expr.type), expr.register, -1) ExpressionCodeResult(emptyList(), irType(expr.type), expr.register, -1)
} }
is PtNumber -> { is PtNumber -> {
val vmDt = codeGen.irType(expr.type) val vmDt = irType(expr.type)
val code = IRCodeChunk(null, null) val code = IRCodeChunk(null, null)
if(vmDt==IRDataType.FLOAT) { if(vmDt==IRDataType.FLOAT) {
val resultFpRegister = codeGen.registers.nextFreeFloat() val resultFpRegister = codeGen.registers.nextFreeFloat()
@ -45,7 +45,7 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) {
} }
} }
is PtIdentifier -> { is PtIdentifier -> {
val vmDt = codeGen.irType(expr.type) val vmDt = irType(expr.type)
val code = IRCodeChunk(null, null) val code = IRCodeChunk(null, null)
if (expr.type in PassByValueDatatypes) { if (expr.type in PassByValueDatatypes) {
if(vmDt==IRDataType.FLOAT) { if(vmDt==IRDataType.FLOAT) {
@ -66,7 +66,7 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) {
} }
} }
is PtAddressOf -> { is PtAddressOf -> {
val vmDt = codeGen.irType(expr.type) val vmDt = irType(expr.type)
val symbol = expr.identifier.name val symbol = expr.identifier.name
// note: LOAD <symbol> gets you the address of the symbol, whereas LOADM <symbol> would get you the value stored at that location // note: LOAD <symbol> gets you the address of the symbol, whereas LOADM <symbol> would get you the value stored at that location
val code = IRCodeChunk(null, null) val code = IRCodeChunk(null, null)
@ -160,7 +160,7 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) {
private fun translate(arrayIx: PtArrayIndexer): ExpressionCodeResult { private fun translate(arrayIx: PtArrayIndexer): ExpressionCodeResult {
val eltSize = codeGen.program.memsizer.memorySize(arrayIx.type) val eltSize = codeGen.program.memsizer.memorySize(arrayIx.type)
val vmDt = codeGen.irType(arrayIx.type) val vmDt = irType(arrayIx.type)
val result = mutableListOf<IRCodeChunkBase>() val result = mutableListOf<IRCodeChunkBase>()
val arrayVarSymbol = arrayIx.variable.name val arrayVarSymbol = arrayIx.variable.name
@ -210,7 +210,7 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) {
val result = mutableListOf<IRCodeChunkBase>() val result = mutableListOf<IRCodeChunkBase>()
val tr = translateExpression(expr.value) val tr = translateExpression(expr.value)
addToResult(result, tr, tr.resultReg, tr.resultFpReg) addToResult(result, tr, tr.resultReg, tr.resultFpReg)
val vmDt = codeGen.irType(expr.type) val vmDt = irType(expr.type)
when(expr.operator) { when(expr.operator) {
"+" -> { } "+" -> { }
"-" -> { "-" -> {
@ -326,12 +326,12 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) {
else -> throw AssemblyError("weird cast type") else -> throw AssemblyError("weird cast type")
} }
return ExpressionCodeResult(result, codeGen.irType(cast.type), actualResultReg2, actualResultFpReg2) return ExpressionCodeResult(result, irType(cast.type), actualResultReg2, actualResultFpReg2)
} }
private fun translate(binExpr: PtBinaryExpression): ExpressionCodeResult { private fun translate(binExpr: PtBinaryExpression): ExpressionCodeResult {
require(!codeGen.options.useNewExprCode) require(!codeGen.options.useNewExprCode)
val vmDt = codeGen.irType(binExpr.left.type) val vmDt = irType(binExpr.left.type)
val signed = binExpr.left.type in SignedDatatypes val signed = binExpr.left.type in SignedDatatypes
return when(binExpr.operator) { return when(binExpr.operator) {
"+" -> operatorPlus(binExpr, vmDt) "+" -> operatorPlus(binExpr, vmDt)
@ -360,7 +360,7 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) {
val result = mutableListOf<IRCodeChunkBase>() val result = mutableListOf<IRCodeChunkBase>()
for ((index, argspec) in fcall.args.zip(callTarget.parameters).withIndex()) { for ((index, argspec) in fcall.args.zip(callTarget.parameters).withIndex()) {
val (arg, param) = argspec val (arg, param) = argspec
val paramDt = codeGen.irType(param.type) val paramDt = irType(param.type)
val tr = translateExpression(arg) val tr = translateExpression(arg)
result += tr.chunks result += tr.chunks
if(paramDt==IRDataType.FLOAT) if(paramDt==IRDataType.FLOAT)
@ -369,7 +369,7 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) {
addInstr(result, IRInstruction(Opcode.SETPARAM, paramDt, reg1 = tr.resultReg, immediate = index), null) addInstr(result, IRInstruction(Opcode.SETPARAM, paramDt, reg1 = tr.resultReg, immediate = index), null)
} }
// for ((arg, parameter) in fcall.args.zip(callTarget.parameters)) { // for ((arg, parameter) in fcall.args.zip(callTarget.parameters)) {
// val paramDt = codeGen.irType(parameter.type) // val paramDt = irType(parameter.type)
// val symbol = "${fcall.name}.${parameter.name}" // val symbol = "${fcall.name}.${parameter.name}"
// if(codeGen.isZero(arg)) { // if(codeGen.isZero(arg)) {
// addInstr(result, IRInstruction(Opcode.STOREZM, paramDt, labelSymbol = symbol), null) // addInstr(result, IRInstruction(Opcode.STOREZM, paramDt, labelSymbol = symbol), null)
@ -396,15 +396,15 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) {
addInstr(result, IRInstruction(Opcode.CALLR, IRDataType.FLOAT, fpReg1=resultFpReg, labelSymbol=fcall.name), null) addInstr(result, IRInstruction(Opcode.CALLR, IRDataType.FLOAT, fpReg1=resultFpReg, labelSymbol=fcall.name), null)
} else { } else {
resultReg = codeGen.registers.nextFree() resultReg = codeGen.registers.nextFree()
addInstr(result, IRInstruction(Opcode.CALLR, codeGen.irType(fcall.type), reg1=resultReg, labelSymbol=fcall.name), null) addInstr(result, IRInstruction(Opcode.CALLR, irType(fcall.type), reg1=resultReg, labelSymbol=fcall.name), null)
} }
ExpressionCodeResult(result, codeGen.irType(fcall.type), resultReg, resultFpReg) ExpressionCodeResult(result, irType(fcall.type), resultReg, resultFpReg)
} }
} }
is StRomSub -> { is StRomSub -> {
val result = mutableListOf<IRCodeChunkBase>() val result = mutableListOf<IRCodeChunkBase>()
for ((arg, parameter) in fcall.args.zip(callTarget.parameters)) { for ((arg, parameter) in fcall.args.zip(callTarget.parameters)) {
val paramDt = codeGen.irType(parameter.type) val paramDt = irType(parameter.type)
val paramRegStr = if(parameter.register.registerOrPair!=null) parameter.register.registerOrPair.toString() else parameter.register.statusflag.toString() val paramRegStr = if(parameter.register.registerOrPair!=null) parameter.register.registerOrPair.toString() else parameter.register.statusflag.toString()
if(codeGen.isZero(arg)) { if(codeGen.isZero(arg)) {
addInstr(result, IRInstruction(Opcode.STOREZCPU, paramDt, labelSymbol = paramRegStr), null) addInstr(result, IRInstruction(Opcode.STOREZCPU, paramDt, labelSymbol = paramRegStr), null)
@ -427,20 +427,20 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) {
throw AssemblyError("doesn't support float register result in asm romsub") throw AssemblyError("doesn't support float register result in asm romsub")
val returns = callTarget.returns.single() val returns = callTarget.returns.single()
val regStr = if(returns.register.registerOrPair!=null) returns.register.registerOrPair.toString() else returns.register.statusflag.toString() val regStr = if(returns.register.registerOrPair!=null) returns.register.registerOrPair.toString() else returns.register.statusflag.toString()
addInstr(result, IRInstruction(Opcode.LOADCPU, codeGen.irType(fcall.type), reg1=resultReg, labelSymbol = regStr), null) addInstr(result, IRInstruction(Opcode.LOADCPU, irType(fcall.type), reg1=resultReg, labelSymbol = regStr), null)
} }
else -> { else -> {
val returnRegister = callTarget.returns.singleOrNull{ it.register.registerOrPair!=null } val returnRegister = callTarget.returns.singleOrNull{ it.register.registerOrPair!=null }
if(returnRegister!=null) { if(returnRegister!=null) {
// we skip the other values returned in the status flags. // we skip the other values returned in the status flags.
val regStr = returnRegister.register.registerOrPair.toString() val regStr = returnRegister.register.registerOrPair.toString()
addInstr(result, IRInstruction(Opcode.LOADCPU, codeGen.irType(fcall.type), reg1=resultReg, labelSymbol = regStr), null) addInstr(result, IRInstruction(Opcode.LOADCPU, irType(fcall.type), reg1=resultReg, labelSymbol = regStr), null)
} else { } else {
val firstReturnRegister = callTarget.returns.firstOrNull{ it.register.registerOrPair!=null } val firstReturnRegister = callTarget.returns.firstOrNull{ it.register.registerOrPair!=null }
if(firstReturnRegister!=null) { if(firstReturnRegister!=null) {
// we just take the first register return value and ignore the rest. // we just take the first register return value and ignore the rest.
val regStr = firstReturnRegister.register.registerOrPair.toString() val regStr = firstReturnRegister.register.registerOrPair.toString()
addInstr(result, IRInstruction(Opcode.LOADCPU, codeGen.irType(fcall.type), reg1=resultReg, labelSymbol = regStr), null) addInstr(result, IRInstruction(Opcode.LOADCPU, irType(fcall.type), reg1=resultReg, labelSymbol = regStr), null)
} else { } else {
throw AssemblyError("invalid number of return values from call") throw AssemblyError("invalid number of return values from call")
} }
@ -448,7 +448,7 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) {
} }
} }
} }
return ExpressionCodeResult(result, if(fcall.void) IRDataType.BYTE else codeGen.irType(fcall.type), resultReg, -1) return ExpressionCodeResult(result, if(fcall.void) IRDataType.BYTE else irType(fcall.type), resultReg, -1)
} }
else -> throw AssemblyError("invalid node type") else -> throw AssemblyError("invalid node type")
} }

View File

@ -1524,20 +1524,6 @@ class IRCodeGen(
} }
} }
internal fun irType(type: DataType): IRDataType {
return when(type) {
DataType.BOOL,
DataType.UBYTE,
DataType.BYTE -> IRDataType.BYTE
DataType.UWORD,
DataType.WORD -> IRDataType.WORD
DataType.FLOAT -> IRDataType.FLOAT
in PassByReferenceDatatypes -> IRDataType.WORD
else -> throw AssemblyError("no IR datatype for $type")
}
}
private var labelSequenceNumber = 0 private var labelSequenceNumber = 0
internal fun createLabelName(): String { internal fun createLabelName(): String {
labelSequenceNumber++ labelSequenceNumber++

View File

@ -68,6 +68,39 @@ class VarConstantValueTypeAdjuster(private val program: Program, private val err
return noModifications return noModifications
} }
override fun after(functionCallExpr: FunctionCallExpression, parent: Node): Iterable<IAstModification> {
// choose specific builtin function for the given types
val func = functionCallExpr.target.nameInSource
if(func==listOf("min") || func==listOf("max")) {
val t1 = functionCallExpr.args[0].inferType(program)
val t2 = functionCallExpr.args[1].inferType(program)
if(t1.isKnown && t2.isKnown) {
val funcName = func[0]
val replaceFunc: String
if(t1.isBytes && t2.isBytes) {
replaceFunc = if(t1.istype(DataType.BYTE) || t2.istype(DataType.BYTE))
"${funcName}__byte"
else
"${funcName}__ubyte"
} else if(t1.isInteger && t2.isInteger) {
replaceFunc = if(t1.istype(DataType.WORD) || t2.istype(DataType.WORD))
"${funcName}__word"
else
"${funcName}__uword"
} else if(t1.isNumeric && t2.isNumeric) {
replaceFunc = "${funcName}__float"
} else {
errors.err("expected numeric arguments", functionCallExpr.position)
return noModifications
}
return listOf(IAstModification.SetExpression({functionCallExpr.target = it as IdentifierReference},
IdentifierReference(listOf(replaceFunc), functionCallExpr.target.position),
functionCallExpr))
}
}
return noModifications
}
} }

View File

@ -7,10 +7,7 @@ import prog8.ast.base.SyntaxError
import prog8.ast.expressions.* import prog8.ast.expressions.*
import prog8.ast.statements.VarDecl import prog8.ast.statements.VarDecl
import prog8.code.core.* import prog8.code.core.*
import kotlin.math.abs import kotlin.math.*
import kotlin.math.sign
import kotlin.math.sqrt
private typealias ConstExpressionCaller = (args: List<Expression>, position: Position, program: Program) -> NumericLiteral private typealias ConstExpressionCaller = (args: List<Expression>, position: Position, program: Program) -> NumericLiteral
@ -24,7 +21,15 @@ internal val constEvaluatorsForBuiltinFuncs: Map<String, ConstExpressionCaller>
"all" to { a, p, prg -> collectionArg(a, p, prg, ::builtinAll) }, "all" to { a, p, prg -> collectionArg(a, p, prg, ::builtinAll) },
"lsb" to { a, p, prg -> oneIntArgOutputInt(a, p, prg) { x: Int -> (x and 255).toDouble() } }, "lsb" to { a, p, prg -> oneIntArgOutputInt(a, p, prg) { x: Int -> (x and 255).toDouble() } },
"msb" to { a, p, prg -> oneIntArgOutputInt(a, p, prg) { x: Int -> (x ushr 8 and 255).toDouble()} }, "msb" to { a, p, prg -> oneIntArgOutputInt(a, p, prg) { x: Int -> (x ushr 8 and 255).toDouble()} },
"mkword" to ::builtinMkword "mkword" to ::builtinMkword,
"min__ubyte" to ::builtinMinUByte,
"min__byte" to ::builtinMinByte,
"min__uword" to ::builtinMinUWord,
"min__word" to ::builtinMinWord,
"max__ubyte" to ::builtinMaxUByte,
"max__byte" to ::builtinMaxByte,
"max__uword" to ::builtinMaxUWord,
"max__word" to ::builtinMaxWord,
) )
private fun builtinAny(array: List<Double>): Double = if(array.any { it!=0.0 }) 1.0 else 0.0 private fun builtinAny(array: List<Double>): Double = if(array.any { it!=0.0 }) 1.0 else 0.0
@ -156,3 +161,75 @@ private fun builtinSgn(args: List<Expression>, position: Position, program: Prog
val constval = args[0].constValue(program) ?: throw NotConstArgumentException() val constval = args[0].constValue(program) ?: throw NotConstArgumentException()
return NumericLiteral(DataType.BYTE, constval.number.sign, position) return NumericLiteral(DataType.BYTE, constval.number.sign, position)
} }
private fun builtinMinByte(args: List<Expression>, position: Position, program: Program): NumericLiteral {
if (args.size != 2)
throw SyntaxError("min requires 2 arguments", position)
val val1 = args[0].constValue(program) ?: throw NotConstArgumentException()
val val2 = args[1].constValue(program) ?: throw NotConstArgumentException()
val result = min(val1.number.toInt(), val2.number.toInt())
return NumericLiteral(DataType.BYTE, result.toDouble(), position)
}
private fun builtinMinUByte(args: List<Expression>, position: Position, program: Program): NumericLiteral {
if (args.size != 2)
throw SyntaxError("min requires 2 arguments", position)
val val1 = args[0].constValue(program) ?: throw NotConstArgumentException()
val val2 = args[1].constValue(program) ?: throw NotConstArgumentException()
val result = min(val1.number.toInt(), val2.number.toInt())
return NumericLiteral(DataType.UBYTE, result.toDouble(), position)
}
private fun builtinMinWord(args: List<Expression>, position: Position, program: Program): NumericLiteral {
if (args.size != 2)
throw SyntaxError("min requires 2 arguments", position)
val val1 = args[0].constValue(program) ?: throw NotConstArgumentException()
val val2 = args[1].constValue(program) ?: throw NotConstArgumentException()
val result = min(val1.number.toInt(), val2.number.toInt())
return NumericLiteral(DataType.WORD, result.toDouble(), position)
}
private fun builtinMinUWord(args: List<Expression>, position: Position, program: Program): NumericLiteral {
if (args.size != 2)
throw SyntaxError("min requires 2 arguments", position)
val val1 = args[0].constValue(program) ?: throw NotConstArgumentException()
val val2 = args[1].constValue(program) ?: throw NotConstArgumentException()
val result = min(val1.number.toInt(), val2.number.toInt())
return NumericLiteral(DataType.UWORD, result.toDouble(), position)
}
private fun builtinMaxByte(args: List<Expression>, position: Position, program: Program): NumericLiteral {
if (args.size != 2)
throw SyntaxError("max requires 2 arguments", position)
val val1 = args[0].constValue(program) ?: throw NotConstArgumentException()
val val2 = args[1].constValue(program) ?: throw NotConstArgumentException()
val result = max(val1.number.toInt(), val2.number.toInt())
return NumericLiteral(DataType.BYTE, result.toDouble(), position)
}
private fun builtinMaxUByte(args: List<Expression>, position: Position, program: Program): NumericLiteral {
if (args.size != 2)
throw SyntaxError("max requires 2 arguments", position)
val val1 = args[0].constValue(program) ?: throw NotConstArgumentException()
val val2 = args[1].constValue(program) ?: throw NotConstArgumentException()
val result = max(val1.number.toInt(), val2.number.toInt())
return NumericLiteral(DataType.UBYTE, result.toDouble(), position)
}
private fun builtinMaxWord(args: List<Expression>, position: Position, program: Program): NumericLiteral {
if (args.size != 2)
throw SyntaxError("max requires 2 arguments", position)
val val1 = args[0].constValue(program) ?: throw NotConstArgumentException()
val val2 = args[1].constValue(program) ?: throw NotConstArgumentException()
val result = max(val1.number.toInt(), val2.number.toInt())
return NumericLiteral(DataType.WORD, result.toDouble(), position)
}
private fun builtinMaxUWord(args: List<Expression>, position: Position, program: Program): NumericLiteral {
if (args.size != 2)
throw SyntaxError("max requires 2 arguments", position)
val val1 = args[0].constValue(program) ?: throw NotConstArgumentException()
val val2 = args[1].constValue(program) ?: throw NotConstArgumentException()
val result = max(val1.number.toInt(), val2.number.toInt())
return NumericLiteral(DataType.UWORD, result.toDouble(), position)
}

View File

@ -1,6 +1,14 @@
TODO TODO
==== ====
- try to reintroduce builtin functions max/maxw/min/minw that take 2 args and return the largest/smallest of them.
This is a major change because it will likely break existing code that is now using min and max as variable names.
Add "polymorphism" that translates min -> min__ubyte etc etc.
Also add optimization that changes the word variant to byte variant if the operands are bytes.
Add to docs.
- add polymorphism to other builtin functions as well! Fix docs.
- once 9.0 is stable, upgrade other programs (assem, shell, etc) to it. - once 9.0 is stable, upgrade other programs (assem, shell, etc) to it.
... ...
@ -8,9 +16,6 @@ TODO
For 9.0 major changes For 9.0 major changes
^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^
- try to reintroduce builtin functions max/maxw/min/minw that take 2 args and return the largest/smallest of them.
This is a major change because it will likely break existing code that is now using min and max as variable names.
Also add optimization that changes the word variant to byte variant if the operands are bytes.
- 6502 codegen: see if we can let for loops skip the loop if startvar>endvar, without adding a lot of code size/duplicating the loop condition. - 6502 codegen: see if we can let for loops skip the loop if startvar>endvar, without adding a lot of code size/duplicating the loop condition.
It is documented behavior to now loop 'around' $00 but it's too easy to forget about! It is documented behavior to now loop 'around' $00 but it's too easy to forget about!
Lot of work because of so many special cases in ForLoopsAsmgen..... Lot of work because of so many special cases in ForLoopsAsmgen.....

View File

@ -4,16 +4,33 @@
main { main {
sub start() { sub start() {
txt.print("hello") ubyte v1 = 11
; foobar() ubyte v2 = 88
} byte v1s = 22
byte v2s = -99
asmsub foobar() { uword w1 = 1111
%asm {{ uword w2 = 8888
nop word w1s = 2222
rts word w2s = -9999
}} txt.print_uw(min(v1, v2))
txt.spc()
txt.print_w(min(v1s, v2s))
txt.spc()
txt.print_uw(max(v1, v2))
txt.spc()
txt.print_w(max(v1s, v2s))
txt.nl()
txt.print_uw(min(w1, w2))
txt.spc()
txt.print_w(min(w1s, w2s))
txt.spc()
txt.print_uw(max(w1, w2))
txt.spc()
txt.print_w(max(w1s, w2s))
txt.nl()
} }
} }

View File

@ -1,5 +1,8 @@
package prog8.intermediate package prog8.intermediate
import prog8.code.core.AssemblyError
import prog8.code.core.DataType
import prog8.code.core.PassByReferenceDatatypes
import prog8.code.core.toHex import prog8.code.core.toHex
/* /*
@ -88,7 +91,7 @@ bger reg1, reg2, address - jump to location in program given by l
bgesr reg1, reg2, address - jump to location in program given by location, if reg1 >= reg2 (signed) bgesr reg1, reg2, address - jump to location in program given by location, if reg1 >= reg2 (signed)
ble reg1, value, address - jump to location in program given by location, if reg1 <= immediate value (unsigned) ble reg1, value, address - jump to location in program given by location, if reg1 <= immediate value (unsigned)
bles reg1, value, address - jump to location in program given by location, if reg1 <= immediate value (signed) bles reg1, value, address - jump to location in program given by location, if reg1 <= immediate value (signed)
( NOTE: there are no bltr/bler instructions because these are equivalent to bgtr/bger with the register operands swapped around.) ( NOTE: there are no bltr/bler instructions because these are equivalent to bgtr/bger with the register operands swapped around.)
sz reg1, reg2 - set reg1=1 if reg2==0, otherwise set reg1=0 sz reg1, reg2 - set reg1=1 if reg2==0, otherwise set reg1=0
snz reg1, reg2 - set reg1=1 if reg2!=0, otherwise set reg1=0 snz reg1, reg2 - set reg1=1 if reg2!=0, otherwise set reg1=0
seq reg1, reg2 - set reg1=1 if reg1 == reg2, otherwise set reg1=0 seq reg1, reg2 - set reg1=1 if reg1 == reg2, otherwise set reg1=0
@ -853,3 +856,17 @@ data class IRInstruction(
return result.joinToString("").trimEnd() return result.joinToString("").trimEnd()
} }
} }
fun irType(type: DataType): IRDataType {
return when(type) {
DataType.BOOL,
DataType.UBYTE,
DataType.BYTE -> IRDataType.BYTE
DataType.UWORD,
DataType.WORD -> IRDataType.WORD
DataType.FLOAT -> IRDataType.FLOAT
in PassByReferenceDatatypes -> IRDataType.WORD
else -> throw AssemblyError("no IR datatype for $type")
}
}