fixed builtin functions no longer const-folding over arrays

This commit is contained in:
Irmen de Jong 2019-08-17 20:16:39 +02:00
parent 59f8b91e25
commit d4a17dfad1
10 changed files with 122 additions and 97 deletions

1
.gitignore vendored
View File

@ -1,4 +1,5 @@
.idea/workspace.xml
.idea/discord.xml
/build/
/dist/
/output/

View File

@ -1 +1 @@
1.53-dev
1.52

View File

@ -24,6 +24,14 @@ object InferredTypes {
return false
return isVoid==other.isVoid && datatype==other.datatype
}
override fun toString(): String {
return when {
datatype!=null -> datatype.toString()
isVoid -> "<void>"
else -> "<unkonwn>"
}
}
}
private val unknownInstance = InferredType.unknown()

View File

@ -694,7 +694,7 @@ internal class AstChecker(private val program: Program,
super.visit(array)
if(array.heapId==null)
if(array.heapId==null && array.parent !is FunctionCall)
throw FatalAstException("array should have been moved to heap at ${array.position}")
}
@ -1207,7 +1207,7 @@ internal class AstChecker(private val program: Program,
private fun checkArrayValues(value: ArrayLiteralValue, type: DataType): Boolean {
if(value.heapId==null) {
// hmm weird, array literal that hasn't been moved to the heap yet?
val array = value.value.map { it.constValue(program)!! }
val array = value.value.mapNotNull { it.constValue(program) }
val correct: Boolean
when(type) {
DataType.ARRAY_UB -> {

View File

@ -256,7 +256,7 @@ internal class AstIdentifiersChecker(private val program: Program) : IAstModifyi
val vardecl = array.parent as? VarDecl
return if (vardecl!=null) {
fixupArrayDatatype(array, vardecl, program.heap)
} else {
} else if(array.heapId!=null) {
// fix the datatype of the array (also on the heap) to the 'biggest' datatype in the array
// (we don't know the desired datatype here exactly so we guess)
val datatype = determineArrayDt(array.value)
@ -264,7 +264,8 @@ internal class AstIdentifiersChecker(private val program: Program) : IAstModifyi
litval2.parent = array.parent
// finally, replace the literal array by a identifier reference.
makeIdentifierFromRefLv(litval2)
}
} else
array
}
return array
}
@ -384,6 +385,29 @@ internal class AstIdentifiersChecker(private val program: Program) : IAstModifyi
}
internal fun fixupArrayDatatype(array: ArrayLiteralValue, program: Program): ArrayLiteralValue {
val dts = array.value.map {it.inferType(program).typeOrElse(DataType.STRUCT)}.toSet()
if(dts.any { it !in NumericDatatypes }) {
return array
}
val dt = when {
DataType.FLOAT in dts -> DataType.ARRAY_F
DataType.WORD in dts -> DataType.ARRAY_W
DataType.UWORD in dts -> DataType.ARRAY_UW
DataType.BYTE in dts -> DataType.ARRAY_B
else -> DataType.ARRAY_UB
}
if(dt==array.type)
return array
// convert values and array type
val elementType = ArrayElementTypes.getValue(dt)
val values = array.value.map { (it as NumericLiteralValue).cast(elementType) as Expression}.toTypedArray()
val array2 = ArrayLiteralValue(dt, values, array.heapId, array.position)
array2.linkParents(array.parent)
return array2
}
internal fun fixupArrayDatatype(array: ArrayLiteralValue, vardecl: VarDecl, heap: HeapValues): ArrayLiteralValue {
if(array.heapId!=null) {
val arrayDt = array.type

View File

@ -28,9 +28,9 @@ val BuiltinFunctions = mapOf(
"lsl" to FunctionSignature(false, listOf(BuiltinFunctionParam("item", IntegerDatatypes)), null),
"lsr" to FunctionSignature(false, listOf(BuiltinFunctionParam("item", IntegerDatatypes)), null),
// these few have a return value depending on the argument(s):
"max" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), null) { a, p, _ -> collectionArgNeverConst(a, p) }, // type depends on args
"min" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), null) { a, p, _ -> collectionArgNeverConst(a, p) }, // type depends on args
"sum" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), null) { a, p, _ -> collectionArgNeverConst(a, p) }, // type depends on args
"max" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), null) { a, p, prg -> collectionArg(a, p, prg, ::builtinMax) }, // type depends on args
"min" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), null) { a, p, prg -> collectionArg(a, p, prg, ::builtinMin) }, // type depends on args
"sum" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), null) { a, p, prg -> collectionArg(a, p, prg, ::builtinSum) }, // type depends on args
"abs" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", NumericDatatypes)), null, ::builtinAbs), // type depends on argument
"len" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", IterableDatatypes)), null, ::builtinLen), // type is UBYTE or UWORD depending on actual length
// normal functions follow:
@ -56,8 +56,8 @@ val BuiltinFunctions = mapOf(
"round" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArgOutputWord(a, p, prg, Math::round) },
"floor" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArgOutputWord(a, p, prg, Math::floor) },
"ceil" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArgOutputWord(a, p, prg, Math::ceil) },
"any" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), DataType.UBYTE) { a, p, _ -> collectionArgNeverConst(a, p) },
"all" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), DataType.UBYTE) { a, p, _ -> collectionArgNeverConst(a, p) },
"any" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), DataType.UBYTE) { a, p, prg -> collectionArg(a, p, prg, ::builtinAny) },
"all" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), DataType.UBYTE) { a, p, prg -> collectionArg(a, p, prg, ::builtinAll) },
"lsb" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.UWORD, DataType.WORD))), DataType.UBYTE) { a, p, prg -> oneIntArgOutputInt(a, p, prg) { x: Int -> x and 255 }},
"msb" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.UWORD, DataType.WORD))), DataType.UBYTE) { a, p, prg -> oneIntArgOutputInt(a, p, prg) { x: Int -> x ushr 8 and 255}},
"mkword" to FunctionSignature(true, listOf(
@ -112,20 +112,30 @@ val BuiltinFunctions = mapOf(
null)
)
fun builtinMax(array: List<Number>): Number = array.maxBy { it.toDouble() }!!
fun builtinMin(array: List<Number>): Number = array.minBy { it.toDouble() }!!
fun builtinSum(array: List<Number>): Number = array.sumByDouble { it.toDouble() }
fun builtinAny(array: List<Number>): Number = if(array.any { it.toDouble()!=0.0 }) 1 else 0
fun builtinAll(array: List<Number>): Number = if(array.all { it.toDouble()!=0.0 }) 1 else 0
fun builtinFunctionReturnType(function: String, args: List<Expression>, program: Program): InferredTypes.InferredType {
fun datatypeFromIterableArg(arglist: Expression): DataType {
if(arglist is ArrayLiteralValue) {
if(arglist.type== DataType.ARRAY_UB || arglist.type== DataType.ARRAY_UW || arglist.type== DataType.ARRAY_F) {
val dt = arglist.value.map {it.inferType(program)}
if(dt.any { !(it istype DataType.UBYTE) && !(it istype DataType.UWORD) && !(it istype DataType.FLOAT)}) {
throw FatalAstException("fuction $function only accepts arraysize of numeric values")
}
if(dt.any { it istype DataType.FLOAT }) return DataType.FLOAT
if(dt.any { it istype DataType.UWORD }) return DataType.UWORD
return DataType.UBYTE
val dt = arglist.value.map {it.inferType(program).typeOrElse(DataType.STRUCT)}.toSet()
if(dt.any { it !in NumericDatatypes }) {
throw FatalAstException("fuction $function only accepts array of numeric values")
}
if(DataType.FLOAT in dt) return DataType.FLOAT
if(DataType.UWORD in dt) return DataType.UWORD
if(DataType.WORD in dt) return DataType.WORD
if(DataType.BYTE in dt) return DataType.BYTE
return DataType.UBYTE
}
if(arglist is IdentifierReference) {
val idt = arglist.inferType(program)
@ -215,12 +225,16 @@ private fun oneIntArgOutputInt(args: List<Expression>, position: Position, progr
return numericLiteral(function(integer).toInt(), args[0].position)
}
private fun collectionArgNeverConst(args: List<Expression>, position: Position): NumericLiteralValue {
private fun collectionArg(args: List<Expression>, position: Position, program: Program, function: (arg: List<Number>)->Number): NumericLiteralValue {
if(args.size!=1)
throw SyntaxError("builtin function requires one non-scalar argument", position)
// max/min/sum etc only work on arrays and these are never considered to be const for these functions
throw NotConstArgumentException()
val array= args[0] as? ArrayLiteralValue ?: throw NotConstArgumentException()
val constElements = array.value.map{it.constValue(program)?.number}
if(constElements.contains(null))
throw NotConstArgumentException()
return NumericLiteralValue.optimalNumeric(function(constElements.mapNotNull { it }), args[0].position)
}
private fun builtinAbs(args: List<Expression>, position: Position, program: Program): NumericLiteralValue {
@ -258,6 +272,8 @@ private fun builtinLen(args: List<Expression>, position: Position, program: Prog
var arraySize = directMemVar?.arraysize?.size()
if(arraySize != null)
return NumericLiteralValue.optimalInteger(arraySize, position)
if(args[0] is ArrayLiteralValue)
return NumericLiteralValue.optimalInteger((args[0] as ArrayLiteralValue).value.size, position)
if(args[0] !is IdentifierReference)
throw SyntaxError("len argument should be an identifier, but is ${args[0]}", position)
val target = (args[0] as IdentifierReference).targetVarDecl(program.namespace)!!

View File

@ -39,9 +39,11 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
if(decl.isArray){
if(decl.arraysize==null) {
// for arrays that have no size specifier (or a non-constant one) attempt to deduce the size
val arrayval = (decl.value as ArrayLiteralValue).value
decl.arraysize = ArrayIndex(NumericLiteralValue.optimalInteger(arrayval.size, decl.position), decl.position)
optimizationsDone++
val arrayval = decl.value as? ArrayLiteralValue
if(arrayval!=null) {
decl.arraysize = ArrayIndex(NumericLiteralValue.optimalInteger(arrayval.value.size, decl.position), decl.position)
optimizationsDone++
}
}
else if(decl.arraysize?.size()==null) {
val size = decl.arraysize!!.index.accept(this)
@ -183,9 +185,9 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
}
override fun visit(functionCall: FunctionCall): Expression {
super.visit(functionCall)
typeCastConstArguments(functionCall)
return try {
super.visit(functionCall)
typeCastConstArguments(functionCall)
functionCall.constValue(program) ?: functionCall
} catch (ax: AstException) {
addError(ax)
@ -600,8 +602,11 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
val array = super.visit(arrayLiteral)
if(array is ArrayLiteralValue) {
val vardecl = array.parent as? VarDecl
if (vardecl!=null) {
return fixupArrayDatatype(array, vardecl, program.heap)
return if (vardecl!=null) {
fixupArrayDatatype(array, vardecl, program.heap)
} else {
// it's not an array associated with a vardecl, attempt to guess the data type from the array values
fixupArrayDatatype(array, program)
}
}
return array

View File

@ -79,7 +79,6 @@ main {
word persp
byte sx
byte sy
ubyte color
for i in 0 to len(xcoor)-1 {
rz = rotatedz[i]

View File

@ -6,78 +6,50 @@
main {
sub start() {
byte ub = 100
byte ub2
word uw = 22222
word uw2
ubyte ubarr = max([10,0,2,8,5,4,3,9])
uword uwarr = max([1000,0,200,8000,50,40000,3,900])
byte barr = max([-10,0,-2,8,5,4,-3,9])
word warr = max([-1000,0,-200,8000,50,2000,3,-900])
float flarr = max([-2.2, 1.1, 3.3, 0.0])
ub = -100
c64scr.print_b(ub >> 1)
c64.CHROUT('\n')
c64scr.print_b(ub >> 2)
c64.CHROUT('\n')
c64scr.print_b(ub >> 7)
c64.CHROUT('\n')
c64scr.print_b(ub >> 8)
c64.CHROUT('\n')
c64scr.print_b(ub >> 9)
c64.CHROUT('\n')
c64scr.print_b(ub >> 16)
c64.CHROUT('\n')
c64scr.print_b(ub >> 26)
c64.CHROUT('\n')
c64.CHROUT('\n')
ubarr = min([10,0,2,8,5,4,3,9])
uwarr = min([1000,0,200,8000,50,40000,3,900])
barr = min([-10,0,-2,8,5,4,-3,9])
warr = min([-1000,0,-200,8000,50,2000,3,-900])
flarr = min([-2.2, 1.1, 3.3, 0.0])
ub = 100
c64scr.print_b(ub >> 1)
c64.CHROUT('\n')
c64scr.print_b(ub >> 2)
c64.CHROUT('\n')
c64scr.print_b(ub >> 7)
c64.CHROUT('\n')
c64scr.print_b(ub >> 8)
c64.CHROUT('\n')
c64scr.print_b(ub >> 9)
c64.CHROUT('\n')
c64scr.print_b(ub >> 16)
c64.CHROUT('\n')
c64scr.print_b(ub >> 26)
c64.CHROUT('\n')
c64.CHROUT('\n')
uwarr = sum([10,0,2,8,5,4,3,9])
uwarr = sum([1000,0,200,8000,50,40000,3,900])
warr = sum([-10,0,-2,8,5,4,-3,9])
warr = sum([-1000,0,-200,8000,50,2000,3,-900])
flarr = sum([-2.2, 1.1, 3.3, 0.0])
ubarr = any([10,0,2,8,5,4,3,9])
ubarr = any([1000,0,200,8000,50,40000,3,900])
ubarr = any([-10,0,-2,8,5,4,-3,9])
ubarr = any([-1000,0,-200,8000,50,2000,3,-900])
ubarr = any([-2.2, 1.1, 3.3, 0.0])
ubarr = all([10,0,2,8,5,4,3,9])
ubarr = all([1000,0,200,8000,50,40000,3,900])
ubarr = all([-10,0,-2,8,5,4,-3,9])
ubarr = all([-1000,0,-200,8000,50,2000,3,-900])
ubarr = all([-2.2, 1.1, 3.3, 0.0])
ubarr = len([10,0,2,8,5,4,3,9])
A = len([1000,0,200,8000,50,40000,3,900])
ubarr = len([-10,0,-2,8,5,4,-3,9])
A = len([-1000,0,-200,8000,50,2000,3,-900])
ubarr = len([-2.2, 1.1, 3.3, 0.0])
uw = -22222
c64scr.print_w(uw >> 1)
c64.CHROUT('\n')
c64scr.print_w(uw >> 7)
c64.CHROUT('\n')
c64scr.print_w(uw >> 8)
c64.CHROUT('\n')
c64scr.print_w(uw >> 9)
c64.CHROUT('\n')
c64scr.print_w(uw >> 15)
c64.CHROUT('\n')
c64scr.print_w(uw >> 16)
c64.CHROUT('\n')
c64scr.print_w(uw >> 26)
c64.CHROUT('\n')
c64.CHROUT('\n')
; ubyte[] uba = sort([10,0,2,8,5,4,3,9])
; uword[] uwa = sort([1000,0,200,8000,50,40000,3,900])
; byte[] ba = sort([-10,0,-2,8,5,4,-3,9])
; word[] wa = sort([-1000,0,-200,8000,50,40000,3,-900])
; float[] fla = sort([-2.2, 1.1, 3.3, 0.0])
uw = 22222
c64scr.print_w(uw >> 1)
c64.CHROUT('\n')
c64scr.print_w(uw >> 7)
c64.CHROUT('\n')
c64scr.print_w(uw >> 8)
c64.CHROUT('\n')
c64scr.print_w(uw >> 9)
c64.CHROUT('\n')
c64scr.print_w(uw >> 15)
c64.CHROUT('\n')
c64scr.print_w(uw >> 16)
c64.CHROUT('\n')
c64scr.print_w(uw >> 26)
c64.CHROUT('\n')
}
}

View File

@ -1,5 +1,5 @@
#!/usr/bin/env sh
rm *.jar *.asm *.prg *.vm.txt *.vice-mon-list
rm -f *.jar *.asm *.prg *.vm.txt *.vice-mon-list
rm -rf build out