fix aggregate functions in astvm

This commit is contained in:
Irmen de Jong 2019-07-15 03:57:51 +02:00
parent 78d7849197
commit 55a7a5d9d5
10 changed files with 138 additions and 165 deletions

View File

@ -31,7 +31,7 @@ internal fun Program.checkValid(compilerOptions: CompilationOptions) {
internal fun Program.reorderStatements() {
val initvalueCreator = VarInitValueAndAddressOfCreator(namespace)
val initvalueCreator = VarInitValueAndAddressOfCreator(namespace, heap)
initvalueCreator.visit(this)
val checker = StatementReorderer(this)
@ -52,7 +52,7 @@ internal fun Program.checkRecursion() {
internal fun Program.checkIdentifiers() {
val checker = AstIdentifiersChecker(namespace)
val checker = AstIdentifiersChecker(this)
checker.visit(this)
if(modules.map {it.name}.toSet().size != modules.size) {

View File

@ -514,29 +514,28 @@ class ReferenceLiteralValue(val type: DataType, // only reference types allo
}
fun addToHeap(heap: HeapValues) {
if(heapId==null) {
if (str != null) {
heapId = heap.addString(type, str)
}
else if (array!=null) {
if(array.any {it is AddressOf }) {
val intArrayWithAddressOfs = array.map {
when (it) {
is AddressOf -> IntegerOrAddressOf(null, it)
is NumericLiteralValue -> IntegerOrAddressOf(it.number.toInt(), null)
else -> throw FatalAstException("invalid datatype in array")
}
if (heapId != null) return
if (str != null) {
heapId = heap.addString(type, str)
}
else if (array!=null) {
if(array.any {it is AddressOf }) {
val intArrayWithAddressOfs = array.map {
when (it) {
is AddressOf -> IntegerOrAddressOf(null, it)
is NumericLiteralValue -> IntegerOrAddressOf(it.number.toInt(), null)
else -> throw FatalAstException("invalid datatype in array")
}
heapId = heap.addIntegerArray(type, intArrayWithAddressOfs.toTypedArray())
}
heapId = heap.addIntegerArray(type, intArrayWithAddressOfs.toTypedArray())
} else {
val valuesInArray = array.map { (it as NumericLiteralValue).number }
heapId = if(type== DataType.ARRAY_F) {
val doubleArray = valuesInArray.map { it.toDouble() }.toDoubleArray()
heap.addDoublesArray(doubleArray)
} else {
val valuesInArray = array.map { (it as NumericLiteralValue).number }
heapId = if(type== DataType.ARRAY_F) {
val doubleArray = valuesInArray.map { it.toDouble() }.toDoubleArray()
heap.addDoublesArray(doubleArray)
} else {
val integerArray = valuesInArray.map { it.toInt() }
heap.addIntegerArray(type, integerArray.map { IntegerOrAddressOf(it, null) }.toTypedArray())
}
val integerArray = valuesInArray.map { it.toInt() }
heap.addIntegerArray(type, integerArray.map { IntegerOrAddressOf(it, null) }.toTypedArray())
}
}
}

View File

@ -7,7 +7,7 @@ import prog8.ast.statements.*
import prog8.functions.BuiltinFunctions
internal class AstIdentifiersChecker(private val namespace: INameScope) : IAstModifyingVisitor {
internal class AstIdentifiersChecker(private val program: Program) : IAstModifyingVisitor {
private val checkResult: MutableList<AstException> = mutableListOf()
private var blocks = mutableMapOf<String, Block>()
@ -78,7 +78,7 @@ internal class AstIdentifiersChecker(private val namespace: INameScope) : IAstMo
return result
}
val existing = namespace.lookup(listOf(decl.name), decl)
val existing = program.namespace.lookup(listOf(decl.name), decl)
if (existing != null && existing !== decl)
nameError(decl.name, decl.position, existing)
@ -93,7 +93,7 @@ internal class AstIdentifiersChecker(private val namespace: INameScope) : IAstMo
if (subroutine.parameters.any { it.name in BuiltinFunctions })
checkResult.add(NameError("builtin function name cannot be used as parameter", subroutine.position))
val existing = namespace.lookup(listOf(subroutine.name), subroutine)
val existing = program.namespace.lookup(listOf(subroutine.name), subroutine)
if (existing != null && existing !== subroutine)
nameError(subroutine.name, subroutine.position, existing)
@ -137,7 +137,7 @@ internal class AstIdentifiersChecker(private val namespace: INameScope) : IAstMo
// the builtin functions can't be redefined
checkResult.add(NameError("builtin function cannot be redefined", label.position))
} else {
val existing = namespace.lookup(listOf(label.name), label)
val existing = program.namespace.lookup(listOf(label.name), label)
if (existing != null && existing !== label)
nameError(label.name, label.position, existing)
}
@ -213,21 +213,25 @@ internal class AstIdentifiersChecker(private val namespace: INameScope) : IAstMo
return super.visit(returnStmt)
}
override fun visit(refLiteral: ReferenceLiteralValue): ReferenceLiteralValue {
override fun visit(refLiteral: ReferenceLiteralValue): IExpression {
if(refLiteral.parent !is VarDecl) {
// a referencetype literal value that's not declared as a variable
// we need to introduce an auto-generated variable for this to be able to refer to the value
val declaredType = if(refLiteral.isArray) ArrayElementTypes.getValue(refLiteral.type) else refLiteral.type
val variable = VarDecl.createAuto(refLiteral)
refLiteral.addToHeap(program.heap)
val variable = VarDecl.createAuto(refLiteral, program.heap)
addVarDecl(refLiteral.definingScope(), variable)
// replace the reference literal by a identfier reference
val identifier = IdentifierReference(listOf(variable.name), variable.position)
identifier.parent = refLiteral.parent
// TODO anonymousVariablesFromHeap[variable.name] = Pair(refLiteral, variable)
return identifier
}
return super.visit(refLiteral)
}
override fun visit(addressOf: AddressOf): IExpression {
// register the scoped name of the referenced identifier
val variable= addressOf.identifier.targetVarDecl(namespace) ?: return addressOf
val variable= addressOf.identifier.targetVarDecl(program.namespace) ?: return addressOf
addressOf.scopedname = variable.scopedname
return super.visit(addressOf)
}

View File

@ -104,7 +104,7 @@ interface IAstModifyingVisitor {
return literalValue
}
fun visit(refLiteral: ReferenceLiteralValue): ReferenceLiteralValue {
fun visit(refLiteral: ReferenceLiteralValue): IExpression {
if(refLiteral.array!=null) {
for(av in refLiteral.array.withIndex()) {
val newvalue = av.value.accept(this)

View File

@ -4,9 +4,10 @@ import prog8.ast.*
import prog8.ast.base.*
import prog8.ast.expressions.*
import prog8.ast.statements.*
import prog8.compiler.HeapValues
internal class VarInitValueAndAddressOfCreator(private val namespace: INameScope): IAstModifyingVisitor {
internal class VarInitValueAndAddressOfCreator(private val namespace: INameScope, private val heap: HeapValues): IAstModifyingVisitor {
// For VarDecls that declare an initialization value:
// Replace the vardecl with an assignment (to set the initial value),
// and add a new vardecl with the default constant value of that type (usually zero) to the scope.
@ -95,7 +96,7 @@ internal class VarInitValueAndAddressOfCreator(private val namespace: INameScope
else if(strvalue!=null) {
if(strvalue.isString) {
// add a vardecl so that the autovar can be resolved in later lookups
val variable = VarDecl.createAuto(strvalue)
val variable = VarDecl.createAuto(strvalue, heap)
addVarDecl(strvalue.definingScope(), variable)
// replace the argument with &autovar
val autoHeapvarRef = IdentifierReference(listOf(variable.name), strvalue.position)

View File

@ -163,10 +163,21 @@ class VarDecl(val type: VarDeclType,
get() = value!=null && value !is NumericLiteralValue
companion object {
fun createAuto(refLv: ReferenceLiteralValue): VarDecl {
fun createAuto(refLv: ReferenceLiteralValue, heap: HeapValues): VarDecl {
if(refLv.heapId==null)
throw FatalAstException("can only create autovar for a ref lv that has a heapid $refLv")
val autoVarName = "$autoHeapValuePrefix${refLv.heapId}"
return VarDecl(VarDeclType.VAR, refLv.type, ZeropageWish.NOT_IN_ZEROPAGE, null, autoVarName, null, refLv,
isArray = false, autogeneratedDontRemove = true, position = refLv.position)
return if(refLv.isArray) {
val declaredType = ArrayElementTypes.getValue(refLv.type)
val arraysize = ArrayIndex.forArray(refLv, heap)
VarDecl(VarDeclType.VAR, declaredType, ZeropageWish.NOT_IN_ZEROPAGE, arraysize, autoVarName, null, refLv,
isArray = true, autogeneratedDontRemove = true, position = refLv.position)
} else {
VarDecl(VarDeclType.VAR, refLv.type, ZeropageWish.NOT_IN_ZEROPAGE, null, autoVarName, null, refLv,
isArray = false, autogeneratedDontRemove = true, position = refLv.position)
}
}
}

View File

@ -32,9 +32,9 @@ val BuiltinFunctions = mapOf(
"lsl" to FunctionSignature(false, listOf(BuiltinFunctionParam("item", IntegerDatatypes)), null),
"lsr" to FunctionSignature(false, listOf(BuiltinFunctionParam("item", IntegerDatatypes)), null),
// these few have a return value depending on the argument(s):
"max" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), null) { a, p, prg -> collectionArgOutputNumber(a, p, prg) { it.max()!! }}, // type depends on args
"min" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), null) { a, p, prg -> collectionArgOutputNumber(a, p, prg) { it.min()!! }}, // type depends on args
"sum" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), null) { a, p, prg -> collectionArgOutputNumber(a, p, prg) { it.sum() }}, // type depends on args
"max" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), null) { a, p, _ -> collectionArgNeverConst(a, p) }, // type depends on args
"min" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), null) { a, p, _ -> collectionArgNeverConst(a, p) }, // type depends on args
"sum" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), null) { a, p, _ -> collectionArgNeverConst(a, p) }, // type depends on args
"abs" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", NumericDatatypes)), null, ::builtinAbs), // type depends on argument
"len" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", IterableDatatypes)), null, ::builtinLen), // type is UBYTE or UWORD depending on actual length
// normal functions follow:
@ -56,12 +56,12 @@ val BuiltinFunctions = mapOf(
"sqrt" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArg(a, p, prg, Math::sqrt) },
"rad" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArg(a, p, prg, Math::toRadians) },
"deg" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArg(a, p, prg, Math::toDegrees) },
"avg" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), DataType.FLOAT, ::builtinAvg),
"avg" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), DataType.FLOAT) { a, p, _ -> collectionArgNeverConst(a, p) },
"round" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArgOutputWord(a, p, prg, Math::round) },
"floor" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArgOutputWord(a, p, prg, Math::floor) },
"ceil" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArgOutputWord(a, p, prg, Math::ceil) },
"any" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), DataType.UBYTE) { a, p, prg -> collectionArgOutputBoolean(a, p, prg) { it.any { v -> v != 0.0} }},
"all" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), DataType.UBYTE) { a, p, prg -> collectionArgOutputBoolean(a, p, prg) { it.all { v -> v != 0.0} }},
"any" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), DataType.UBYTE) { a, p, _ -> collectionArgNeverConst(a, p) },
"all" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), DataType.UBYTE) { a, p, _ -> collectionArgNeverConst(a, p) },
"lsb" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.UWORD, DataType.WORD))), DataType.UBYTE) { a, p, prg -> oneIntArgOutputInt(a, p, prg) { x: Int -> x and 255 }},
"msb" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.UWORD, DataType.WORD))), DataType.UBYTE) { a, p, prg -> oneIntArgOutputInt(a, p, prg) { x: Int -> x ushr 8 and 255}},
"mkword" to FunctionSignature(true, listOf(
@ -216,61 +216,12 @@ private fun oneIntArgOutputInt(args: List<IExpression>, position: Position, prog
return numericLiteral(function(integer).toInt(), args[0].position)
}
private fun collectionArgOutputNumber(args: List<IExpression>, position: Position,
program: Program,
function: (arg: Collection<Double>)->Number): NumericLiteralValue {
private fun collectionArgNeverConst(args: List<IExpression>, position: Position): NumericLiteralValue {
if(args.size!=1)
throw SyntaxError("builtin function requires one non-scalar argument", position)
val iterable = args[0].constValue(program) ?: throw NotConstArgumentException()
TODO("collection functions over iterables (array, string) $iterable")
// val result = if(iterable.array != null) {
// val constants = iterable.arrayvalue.map { it.constValue(program)?.asNumericValue }
// if (null in constants)
// throw NotConstArgumentException()
// function(constants.map { it!!.toDouble() }).toDouble()
// } else {
// when(iterable.type) {
// DataType.UBYTE, DataType.UWORD, DataType.FLOAT -> throw SyntaxError("function expects an iterable type", position)
// else -> {
// val heapId = iterable.heapId ?: throw FatalAstException("iterable value should be on the heap")
// val array = program.heap.get(heapId).array ?: throw SyntaxError("function expects an iterable type", position)
// function(array.map {
// if(it.integer!=null)
// it.integer.toDouble()
// else
// throw FatalAstException("cannot perform function over array that contains other values besides constant integers")
// })
// }
// }
// }
// return numericLiteral(result, args[0].position)
}
private fun collectionArgOutputBoolean(args: List<IExpression>, position: Position,
program: Program,
function: (arg: Collection<Double>)->Boolean): NumericLiteralValue {
if(args.size!=1)
throw SyntaxError("builtin function requires one non-scalar argument", position)
val iterable = args[0].constValue(program) ?: throw NotConstArgumentException()
TODO("collection functions over iterables (array, string) $iterable")
// val result = if(iterable.arrayvalue != null) {
// val constants = iterable.arrayvalue.map { it.constValue(program)?.asNumericValue }
// if(null in constants)
// throw NotConstArgumentException()
// function(constants.map { it!!.toDouble() })
// } else {
// val array = program.heap.get(iterable.heapId!!).array ?: throw SyntaxError("function requires array argument", position)
// function(array.map {
// if(it.integer!=null)
// it.integer.toDouble()
// else
// throw FatalAstException("cannot perform function over array that contains other values besides constant integers")
// })
// }
// return LiteralValue.fromBoolean(result, position)
// max/min/sum etc only work on arrays and these are never considered to be const for these functions
throw NotConstArgumentException()
}
private fun builtinAbs(args: List<IExpression>, position: Position, program: Program): NumericLiteralValue {
@ -286,36 +237,6 @@ private fun builtinAbs(args: List<IExpression>, position: Position, program: Pro
}
}
private fun builtinAvg(args: List<IExpression>, position: Position, program: Program): NumericLiteralValue {
if(args.size!=1)
throw SyntaxError("avg requires array argument", position)
val iterable = args[0].constValue(program) ?: throw NotConstArgumentException()
TODO("collection functions over iterables (array, string) $iterable")
// val result = if(iterable.arrayvalue!=null) {
// val constants = iterable.arrayvalue.map { it.constValue(program)?.asNumericValue }
// if (null in constants)
// throw NotConstArgumentException()
// (constants.map { it!!.toDouble() }).average()
// }
// else {
// val heapId = iterable.heapId!!
// val integerarray = program.heap.get(heapId).array
// if(integerarray!=null) {
// if (integerarray.all { it.integer != null }) {
// integerarray.map { it.integer!! }.average()
// } else {
// throw ExpressionError("cannot avg() over array that does not only contain constant numerical values", position)
// }
// } else {
// val doublearray = program.heap.get(heapId).doubleArray
// doublearray?.average() ?: throw SyntaxError("avg requires array argument", position)
// }
// }
// return numericLiteral(result, args[0].position)
}
private fun builtinStrlen(args: List<IExpression>, position: Position, program: Program): NumericLiteralValue {
if (args.size != 1)
throw SyntaxError("strlen requires one argument", position)
@ -323,14 +244,7 @@ private fun builtinStrlen(args: List<IExpression>, position: Position, program:
if(argument.type !in StringDatatypes)
throw SyntaxError("strlen must have string argument", position)
TODO("collection functions over iterables (array, string) $argument")
// val string = argument.strvalue!!
// val zeroIdx = string.indexOf('\u0000')
// return if(zeroIdx>=0)
// LiteralValue.optimalInteger(zeroIdx, position=position)
// else
// LiteralValue.optimalInteger(string.length, position=position)
throw NotConstArgumentException() // this function is not considering the string argument a constant
}
private fun builtinLen(args: List<IExpression>, position: Position, program: Program): NumericLiteralValue {

View File

@ -593,20 +593,22 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
return resultStmt
}
override fun visit(refLiteral: ReferenceLiteralValue): ReferenceLiteralValue {
override fun visit(refLiteral: ReferenceLiteralValue): IExpression {
val litval = super.visit(refLiteral)
if(litval.isString) {
// intern the string; move it into the heap
if(litval.str!!.length !in 1..255)
addError(ExpressionError("string literal length must be between 1 and 255", litval.position))
else {
litval.addToHeap(program.heap) // TODO: we don't know the actual string type yet, STR != STR_S etc...
if(litval is ReferenceLiteralValue) {
if (litval.isString) {
// intern the string; move it into the heap
if (litval.str!!.length !in 1..255)
addError(ExpressionError("string literal length must be between 1 and 255", litval.position))
else {
litval.addToHeap(program.heap) // TODO: we don't know the actual string type yet, STR != STR_S etc...
}
} else if (litval.isArray) {
// first, adjust the array datatype
val litval2 = adjustArrayValDatatype(litval)
litval2.addToHeap(program.heap)
return litval2
}
} else if(litval.isArray) {
// first, adjust the array datatype
val litval2 = adjustArrayValDatatype(litval)
litval2.addToHeap(program.heap)
return litval2
}
return litval
}

View File

@ -834,34 +834,34 @@ class AstVm(val program: Program) {
}
}
"max" -> {
val numbers = args.map { it.numericValue().toDouble() }
RuntimeValue(args[0].type, numbers.max())
val numbers = args.single().array!!.map { it.toDouble() }
RuntimeValue(ArrayElementTypes.getValue(args[0].type), numbers.max())
}
"min" -> {
val numbers = args.map { it.numericValue().toDouble() }
RuntimeValue(args[0].type, numbers.min())
val numbers = args.single().array!!.map { it.toDouble() }
RuntimeValue(ArrayElementTypes.getValue(args[0].type), numbers.min())
}
"avg" -> {
val numbers = args.map { it.numericValue().toDouble() }
val numbers = args.single().array!!.map { it.toDouble() }
RuntimeValue(DataType.FLOAT, numbers.average())
}
"sum" -> {
val sum = args.map { it.numericValue().toDouble() }.sum()
val sum = args.single().array!!.map { it.toDouble() }.sum()
when (args[0].type) {
DataType.UBYTE -> RuntimeValue(DataType.UWORD, sum)
DataType.BYTE -> RuntimeValue(DataType.WORD, sum)
DataType.UWORD -> RuntimeValue(DataType.UWORD, sum)
DataType.WORD -> RuntimeValue(DataType.WORD, sum)
DataType.FLOAT -> RuntimeValue(DataType.FLOAT, sum)
DataType.ARRAY_UB -> RuntimeValue(DataType.UWORD, sum)
DataType.ARRAY_B -> RuntimeValue(DataType.WORD, sum)
DataType.ARRAY_UW -> RuntimeValue(DataType.UWORD, sum)
DataType.ARRAY_W -> RuntimeValue(DataType.WORD, sum)
DataType.ARRAY_F -> RuntimeValue(DataType.FLOAT, sum)
else -> throw VmExecutionException("weird sum type ${args[0]}")
}
}
"any" -> {
val numbers = args.map { it.numericValue().toDouble() }
val numbers = args.single().array!!.map { it.toDouble() }
RuntimeValue(DataType.UBYTE, if (numbers.any { it != 0.0 }) 1 else 0)
}
"all" -> {
val numbers = args.map { it.numericValue().toDouble() }
val numbers = args.single().array!!.map { it.toDouble() }
RuntimeValue(DataType.UBYTE, if (numbers.all { it != 0.0 }) 1 else 0)
}
"swap" ->

View File

@ -1,19 +1,61 @@
%import c64utils
%import c64flt
%zeropage basicsafe
%option enable_floats
~ main {
sub start() {
str naam = "irmen"
byte[] array=[1,2,3,4,5]
ubyte length = len(naam)
c64scr.print(naam)
c64scr.print("irmen")
c64scr.print("irmen2")
c64scr.print("irmen2")
ubyte length2 = len("irmen") ; @todo same string as 'naam'
ubyte length3 = len("zxfdsfsf") ; @todo new string
ubyte length = len(array)
c64scr.print_ub(length)
c64.CHROUT(',')
ubyte length1 = any(array)
c64scr.print_ub(length1)
c64.CHROUT(',')
ubyte length1b = all(array)
c64scr.print_ub(length1b)
c64.CHROUT(',')
ubyte length1c = max(array)
c64scr.print_ub(length1c)
c64.CHROUT(',')
ubyte length1d = min(array)
c64scr.print_ub(length1d)
c64.CHROUT(',')
ubyte xlength = len([1,2,3])
c64scr.print_ub(xlength)
c64.CHROUT('\n')
ubyte xlength1 = any([1,0,3])
c64scr.print_ub(xlength1)
c64.CHROUT(',')
ubyte xlength1b = all([1,0,3])
c64scr.print_ub(xlength1b)
c64.CHROUT(',')
ubyte xlength1c = max([1,2,3])
c64scr.print_ub(xlength1c)
c64.CHROUT(',')
ubyte xlength1d = min([1,2,3])
c64scr.print_ub(xlength1d)
c64.CHROUT('\n')
word s1 = sum(array)
c64scr.print_w(s1)
c64.CHROUT(',')
uword s2 = sum([1,23])
c64scr.print_uw(s2)
c64.CHROUT(',')
float ff1=avg(array)
c64flt.print_f(ff1)
c64.CHROUT(',')
float ff2=avg([1,2,3])
c64flt.print_f(ff2)
c64.CHROUT('\n')
return
}