mirror of
https://github.com/irmen/prog8.git
synced 2024-07-05 22:29:04 +00:00
fixed builtin functions no longer const-folding over arrays
This commit is contained in:
parent
59f8b91e25
commit
d4a17dfad1
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,4 +1,5 @@
|
||||
.idea/workspace.xml
|
||||
.idea/discord.xml
|
||||
/build/
|
||||
/dist/
|
||||
/output/
|
||||
|
@ -1 +1 @@
|
||||
1.53-dev
|
||||
1.52
|
||||
|
@ -24,6 +24,14 @@ object InferredTypes {
|
||||
return false
|
||||
return isVoid==other.isVoid && datatype==other.datatype
|
||||
}
|
||||
|
||||
override fun toString(): String {
|
||||
return when {
|
||||
datatype!=null -> datatype.toString()
|
||||
isVoid -> "<void>"
|
||||
else -> "<unkonwn>"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private val unknownInstance = InferredType.unknown()
|
||||
|
@ -694,7 +694,7 @@ internal class AstChecker(private val program: Program,
|
||||
|
||||
super.visit(array)
|
||||
|
||||
if(array.heapId==null)
|
||||
if(array.heapId==null && array.parent !is FunctionCall)
|
||||
throw FatalAstException("array should have been moved to heap at ${array.position}")
|
||||
}
|
||||
|
||||
@ -1207,7 +1207,7 @@ internal class AstChecker(private val program: Program,
|
||||
private fun checkArrayValues(value: ArrayLiteralValue, type: DataType): Boolean {
|
||||
if(value.heapId==null) {
|
||||
// hmm weird, array literal that hasn't been moved to the heap yet?
|
||||
val array = value.value.map { it.constValue(program)!! }
|
||||
val array = value.value.mapNotNull { it.constValue(program) }
|
||||
val correct: Boolean
|
||||
when(type) {
|
||||
DataType.ARRAY_UB -> {
|
||||
|
@ -256,7 +256,7 @@ internal class AstIdentifiersChecker(private val program: Program) : IAstModifyi
|
||||
val vardecl = array.parent as? VarDecl
|
||||
return if (vardecl!=null) {
|
||||
fixupArrayDatatype(array, vardecl, program.heap)
|
||||
} else {
|
||||
} else if(array.heapId!=null) {
|
||||
// fix the datatype of the array (also on the heap) to the 'biggest' datatype in the array
|
||||
// (we don't know the desired datatype here exactly so we guess)
|
||||
val datatype = determineArrayDt(array.value)
|
||||
@ -264,7 +264,8 @@ internal class AstIdentifiersChecker(private val program: Program) : IAstModifyi
|
||||
litval2.parent = array.parent
|
||||
// finally, replace the literal array by a identifier reference.
|
||||
makeIdentifierFromRefLv(litval2)
|
||||
}
|
||||
} else
|
||||
array
|
||||
}
|
||||
return array
|
||||
}
|
||||
@ -384,6 +385,29 @@ internal class AstIdentifiersChecker(private val program: Program) : IAstModifyi
|
||||
|
||||
}
|
||||
|
||||
internal fun fixupArrayDatatype(array: ArrayLiteralValue, program: Program): ArrayLiteralValue {
|
||||
val dts = array.value.map {it.inferType(program).typeOrElse(DataType.STRUCT)}.toSet()
|
||||
if(dts.any { it !in NumericDatatypes }) {
|
||||
return array
|
||||
}
|
||||
val dt = when {
|
||||
DataType.FLOAT in dts -> DataType.ARRAY_F
|
||||
DataType.WORD in dts -> DataType.ARRAY_W
|
||||
DataType.UWORD in dts -> DataType.ARRAY_UW
|
||||
DataType.BYTE in dts -> DataType.ARRAY_B
|
||||
else -> DataType.ARRAY_UB
|
||||
}
|
||||
if(dt==array.type)
|
||||
return array
|
||||
|
||||
// convert values and array type
|
||||
val elementType = ArrayElementTypes.getValue(dt)
|
||||
val values = array.value.map { (it as NumericLiteralValue).cast(elementType) as Expression}.toTypedArray()
|
||||
val array2 = ArrayLiteralValue(dt, values, array.heapId, array.position)
|
||||
array2.linkParents(array.parent)
|
||||
return array2
|
||||
}
|
||||
|
||||
internal fun fixupArrayDatatype(array: ArrayLiteralValue, vardecl: VarDecl, heap: HeapValues): ArrayLiteralValue {
|
||||
if(array.heapId!=null) {
|
||||
val arrayDt = array.type
|
||||
|
@ -28,9 +28,9 @@ val BuiltinFunctions = mapOf(
|
||||
"lsl" to FunctionSignature(false, listOf(BuiltinFunctionParam("item", IntegerDatatypes)), null),
|
||||
"lsr" to FunctionSignature(false, listOf(BuiltinFunctionParam("item", IntegerDatatypes)), null),
|
||||
// these few have a return value depending on the argument(s):
|
||||
"max" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), null) { a, p, _ -> collectionArgNeverConst(a, p) }, // type depends on args
|
||||
"min" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), null) { a, p, _ -> collectionArgNeverConst(a, p) }, // type depends on args
|
||||
"sum" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), null) { a, p, _ -> collectionArgNeverConst(a, p) }, // type depends on args
|
||||
"max" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), null) { a, p, prg -> collectionArg(a, p, prg, ::builtinMax) }, // type depends on args
|
||||
"min" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), null) { a, p, prg -> collectionArg(a, p, prg, ::builtinMin) }, // type depends on args
|
||||
"sum" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), null) { a, p, prg -> collectionArg(a, p, prg, ::builtinSum) }, // type depends on args
|
||||
"abs" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", NumericDatatypes)), null, ::builtinAbs), // type depends on argument
|
||||
"len" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", IterableDatatypes)), null, ::builtinLen), // type is UBYTE or UWORD depending on actual length
|
||||
// normal functions follow:
|
||||
@ -56,8 +56,8 @@ val BuiltinFunctions = mapOf(
|
||||
"round" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArgOutputWord(a, p, prg, Math::round) },
|
||||
"floor" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArgOutputWord(a, p, prg, Math::floor) },
|
||||
"ceil" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArgOutputWord(a, p, prg, Math::ceil) },
|
||||
"any" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), DataType.UBYTE) { a, p, _ -> collectionArgNeverConst(a, p) },
|
||||
"all" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), DataType.UBYTE) { a, p, _ -> collectionArgNeverConst(a, p) },
|
||||
"any" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), DataType.UBYTE) { a, p, prg -> collectionArg(a, p, prg, ::builtinAny) },
|
||||
"all" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), DataType.UBYTE) { a, p, prg -> collectionArg(a, p, prg, ::builtinAll) },
|
||||
"lsb" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.UWORD, DataType.WORD))), DataType.UBYTE) { a, p, prg -> oneIntArgOutputInt(a, p, prg) { x: Int -> x and 255 }},
|
||||
"msb" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.UWORD, DataType.WORD))), DataType.UBYTE) { a, p, prg -> oneIntArgOutputInt(a, p, prg) { x: Int -> x ushr 8 and 255}},
|
||||
"mkword" to FunctionSignature(true, listOf(
|
||||
@ -112,20 +112,30 @@ val BuiltinFunctions = mapOf(
|
||||
null)
|
||||
)
|
||||
|
||||
fun builtinMax(array: List<Number>): Number = array.maxBy { it.toDouble() }!!
|
||||
|
||||
fun builtinMin(array: List<Number>): Number = array.minBy { it.toDouble() }!!
|
||||
|
||||
fun builtinSum(array: List<Number>): Number = array.sumByDouble { it.toDouble() }
|
||||
|
||||
fun builtinAny(array: List<Number>): Number = if(array.any { it.toDouble()!=0.0 }) 1 else 0
|
||||
|
||||
fun builtinAll(array: List<Number>): Number = if(array.all { it.toDouble()!=0.0 }) 1 else 0
|
||||
|
||||
|
||||
fun builtinFunctionReturnType(function: String, args: List<Expression>, program: Program): InferredTypes.InferredType {
|
||||
|
||||
fun datatypeFromIterableArg(arglist: Expression): DataType {
|
||||
if(arglist is ArrayLiteralValue) {
|
||||
if(arglist.type== DataType.ARRAY_UB || arglist.type== DataType.ARRAY_UW || arglist.type== DataType.ARRAY_F) {
|
||||
val dt = arglist.value.map {it.inferType(program)}
|
||||
if(dt.any { !(it istype DataType.UBYTE) && !(it istype DataType.UWORD) && !(it istype DataType.FLOAT)}) {
|
||||
throw FatalAstException("fuction $function only accepts arraysize of numeric values")
|
||||
}
|
||||
if(dt.any { it istype DataType.FLOAT }) return DataType.FLOAT
|
||||
if(dt.any { it istype DataType.UWORD }) return DataType.UWORD
|
||||
return DataType.UBYTE
|
||||
val dt = arglist.value.map {it.inferType(program).typeOrElse(DataType.STRUCT)}.toSet()
|
||||
if(dt.any { it !in NumericDatatypes }) {
|
||||
throw FatalAstException("fuction $function only accepts array of numeric values")
|
||||
}
|
||||
if(DataType.FLOAT in dt) return DataType.FLOAT
|
||||
if(DataType.UWORD in dt) return DataType.UWORD
|
||||
if(DataType.WORD in dt) return DataType.WORD
|
||||
if(DataType.BYTE in dt) return DataType.BYTE
|
||||
return DataType.UBYTE
|
||||
}
|
||||
if(arglist is IdentifierReference) {
|
||||
val idt = arglist.inferType(program)
|
||||
@ -215,12 +225,16 @@ private fun oneIntArgOutputInt(args: List<Expression>, position: Position, progr
|
||||
return numericLiteral(function(integer).toInt(), args[0].position)
|
||||
}
|
||||
|
||||
private fun collectionArgNeverConst(args: List<Expression>, position: Position): NumericLiteralValue {
|
||||
private fun collectionArg(args: List<Expression>, position: Position, program: Program, function: (arg: List<Number>)->Number): NumericLiteralValue {
|
||||
if(args.size!=1)
|
||||
throw SyntaxError("builtin function requires one non-scalar argument", position)
|
||||
|
||||
// max/min/sum etc only work on arrays and these are never considered to be const for these functions
|
||||
throw NotConstArgumentException()
|
||||
val array= args[0] as? ArrayLiteralValue ?: throw NotConstArgumentException()
|
||||
val constElements = array.value.map{it.constValue(program)?.number}
|
||||
if(constElements.contains(null))
|
||||
throw NotConstArgumentException()
|
||||
|
||||
return NumericLiteralValue.optimalNumeric(function(constElements.mapNotNull { it }), args[0].position)
|
||||
}
|
||||
|
||||
private fun builtinAbs(args: List<Expression>, position: Position, program: Program): NumericLiteralValue {
|
||||
@ -258,6 +272,8 @@ private fun builtinLen(args: List<Expression>, position: Position, program: Prog
|
||||
var arraySize = directMemVar?.arraysize?.size()
|
||||
if(arraySize != null)
|
||||
return NumericLiteralValue.optimalInteger(arraySize, position)
|
||||
if(args[0] is ArrayLiteralValue)
|
||||
return NumericLiteralValue.optimalInteger((args[0] as ArrayLiteralValue).value.size, position)
|
||||
if(args[0] !is IdentifierReference)
|
||||
throw SyntaxError("len argument should be an identifier, but is ${args[0]}", position)
|
||||
val target = (args[0] as IdentifierReference).targetVarDecl(program.namespace)!!
|
||||
|
@ -39,9 +39,11 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
|
||||
if(decl.isArray){
|
||||
if(decl.arraysize==null) {
|
||||
// for arrays that have no size specifier (or a non-constant one) attempt to deduce the size
|
||||
val arrayval = (decl.value as ArrayLiteralValue).value
|
||||
decl.arraysize = ArrayIndex(NumericLiteralValue.optimalInteger(arrayval.size, decl.position), decl.position)
|
||||
optimizationsDone++
|
||||
val arrayval = decl.value as? ArrayLiteralValue
|
||||
if(arrayval!=null) {
|
||||
decl.arraysize = ArrayIndex(NumericLiteralValue.optimalInteger(arrayval.value.size, decl.position), decl.position)
|
||||
optimizationsDone++
|
||||
}
|
||||
}
|
||||
else if(decl.arraysize?.size()==null) {
|
||||
val size = decl.arraysize!!.index.accept(this)
|
||||
@ -183,9 +185,9 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
|
||||
}
|
||||
|
||||
override fun visit(functionCall: FunctionCall): Expression {
|
||||
super.visit(functionCall)
|
||||
typeCastConstArguments(functionCall)
|
||||
return try {
|
||||
super.visit(functionCall)
|
||||
typeCastConstArguments(functionCall)
|
||||
functionCall.constValue(program) ?: functionCall
|
||||
} catch (ax: AstException) {
|
||||
addError(ax)
|
||||
@ -600,8 +602,11 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
|
||||
val array = super.visit(arrayLiteral)
|
||||
if(array is ArrayLiteralValue) {
|
||||
val vardecl = array.parent as? VarDecl
|
||||
if (vardecl!=null) {
|
||||
return fixupArrayDatatype(array, vardecl, program.heap)
|
||||
return if (vardecl!=null) {
|
||||
fixupArrayDatatype(array, vardecl, program.heap)
|
||||
} else {
|
||||
// it's not an array associated with a vardecl, attempt to guess the data type from the array values
|
||||
fixupArrayDatatype(array, program)
|
||||
}
|
||||
}
|
||||
return array
|
||||
|
@ -79,7 +79,6 @@ main {
|
||||
word persp
|
||||
byte sx
|
||||
byte sy
|
||||
ubyte color
|
||||
|
||||
for i in 0 to len(xcoor)-1 {
|
||||
rz = rotatedz[i]
|
||||
|
106
examples/test.p8
106
examples/test.p8
@ -6,78 +6,50 @@
|
||||
main {
|
||||
|
||||
sub start() {
|
||||
byte ub = 100
|
||||
byte ub2
|
||||
word uw = 22222
|
||||
word uw2
|
||||
ubyte ubarr = max([10,0,2,8,5,4,3,9])
|
||||
uword uwarr = max([1000,0,200,8000,50,40000,3,900])
|
||||
byte barr = max([-10,0,-2,8,5,4,-3,9])
|
||||
word warr = max([-1000,0,-200,8000,50,2000,3,-900])
|
||||
float flarr = max([-2.2, 1.1, 3.3, 0.0])
|
||||
|
||||
ub = -100
|
||||
c64scr.print_b(ub >> 1)
|
||||
c64.CHROUT('\n')
|
||||
c64scr.print_b(ub >> 2)
|
||||
c64.CHROUT('\n')
|
||||
c64scr.print_b(ub >> 7)
|
||||
c64.CHROUT('\n')
|
||||
c64scr.print_b(ub >> 8)
|
||||
c64.CHROUT('\n')
|
||||
c64scr.print_b(ub >> 9)
|
||||
c64.CHROUT('\n')
|
||||
c64scr.print_b(ub >> 16)
|
||||
c64.CHROUT('\n')
|
||||
c64scr.print_b(ub >> 26)
|
||||
c64.CHROUT('\n')
|
||||
c64.CHROUT('\n')
|
||||
ubarr = min([10,0,2,8,5,4,3,9])
|
||||
uwarr = min([1000,0,200,8000,50,40000,3,900])
|
||||
barr = min([-10,0,-2,8,5,4,-3,9])
|
||||
warr = min([-1000,0,-200,8000,50,2000,3,-900])
|
||||
flarr = min([-2.2, 1.1, 3.3, 0.0])
|
||||
|
||||
ub = 100
|
||||
c64scr.print_b(ub >> 1)
|
||||
c64.CHROUT('\n')
|
||||
c64scr.print_b(ub >> 2)
|
||||
c64.CHROUT('\n')
|
||||
c64scr.print_b(ub >> 7)
|
||||
c64.CHROUT('\n')
|
||||
c64scr.print_b(ub >> 8)
|
||||
c64.CHROUT('\n')
|
||||
c64scr.print_b(ub >> 9)
|
||||
c64.CHROUT('\n')
|
||||
c64scr.print_b(ub >> 16)
|
||||
c64.CHROUT('\n')
|
||||
c64scr.print_b(ub >> 26)
|
||||
c64.CHROUT('\n')
|
||||
c64.CHROUT('\n')
|
||||
uwarr = sum([10,0,2,8,5,4,3,9])
|
||||
uwarr = sum([1000,0,200,8000,50,40000,3,900])
|
||||
warr = sum([-10,0,-2,8,5,4,-3,9])
|
||||
warr = sum([-1000,0,-200,8000,50,2000,3,-900])
|
||||
flarr = sum([-2.2, 1.1, 3.3, 0.0])
|
||||
|
||||
ubarr = any([10,0,2,8,5,4,3,9])
|
||||
ubarr = any([1000,0,200,8000,50,40000,3,900])
|
||||
ubarr = any([-10,0,-2,8,5,4,-3,9])
|
||||
ubarr = any([-1000,0,-200,8000,50,2000,3,-900])
|
||||
ubarr = any([-2.2, 1.1, 3.3, 0.0])
|
||||
|
||||
ubarr = all([10,0,2,8,5,4,3,9])
|
||||
ubarr = all([1000,0,200,8000,50,40000,3,900])
|
||||
ubarr = all([-10,0,-2,8,5,4,-3,9])
|
||||
ubarr = all([-1000,0,-200,8000,50,2000,3,-900])
|
||||
ubarr = all([-2.2, 1.1, 3.3, 0.0])
|
||||
|
||||
ubarr = len([10,0,2,8,5,4,3,9])
|
||||
A = len([1000,0,200,8000,50,40000,3,900])
|
||||
ubarr = len([-10,0,-2,8,5,4,-3,9])
|
||||
A = len([-1000,0,-200,8000,50,2000,3,-900])
|
||||
ubarr = len([-2.2, 1.1, 3.3, 0.0])
|
||||
|
||||
|
||||
uw = -22222
|
||||
c64scr.print_w(uw >> 1)
|
||||
c64.CHROUT('\n')
|
||||
c64scr.print_w(uw >> 7)
|
||||
c64.CHROUT('\n')
|
||||
c64scr.print_w(uw >> 8)
|
||||
c64.CHROUT('\n')
|
||||
c64scr.print_w(uw >> 9)
|
||||
c64.CHROUT('\n')
|
||||
c64scr.print_w(uw >> 15)
|
||||
c64.CHROUT('\n')
|
||||
c64scr.print_w(uw >> 16)
|
||||
c64.CHROUT('\n')
|
||||
c64scr.print_w(uw >> 26)
|
||||
c64.CHROUT('\n')
|
||||
c64.CHROUT('\n')
|
||||
; ubyte[] uba = sort([10,0,2,8,5,4,3,9])
|
||||
; uword[] uwa = sort([1000,0,200,8000,50,40000,3,900])
|
||||
; byte[] ba = sort([-10,0,-2,8,5,4,-3,9])
|
||||
; word[] wa = sort([-1000,0,-200,8000,50,40000,3,-900])
|
||||
; float[] fla = sort([-2.2, 1.1, 3.3, 0.0])
|
||||
|
||||
|
||||
uw = 22222
|
||||
c64scr.print_w(uw >> 1)
|
||||
c64.CHROUT('\n')
|
||||
c64scr.print_w(uw >> 7)
|
||||
c64.CHROUT('\n')
|
||||
c64scr.print_w(uw >> 8)
|
||||
c64.CHROUT('\n')
|
||||
c64scr.print_w(uw >> 9)
|
||||
c64.CHROUT('\n')
|
||||
c64scr.print_w(uw >> 15)
|
||||
c64.CHROUT('\n')
|
||||
c64scr.print_w(uw >> 16)
|
||||
c64.CHROUT('\n')
|
||||
c64scr.print_w(uw >> 26)
|
||||
c64.CHROUT('\n')
|
||||
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
#!/usr/bin/env sh
|
||||
|
||||
rm *.jar *.asm *.prg *.vm.txt *.vice-mon-list
|
||||
rm -f *.jar *.asm *.prg *.vm.txt *.vice-mon-list
|
||||
rm -rf build out
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user