From 34dec55eb207be36b0c8df71257bb5a84293b8fc Mon Sep 17 00:00:00 2001 From: Irmen de Jong Date: Sat, 29 Sep 2018 19:17:19 +0200 Subject: [PATCH] fix builtin functions over non-const arrays/strings --- compiler/examples/test.p8 | 68 +++++++++++----- compiler/src/prog8/compiler/Compiler.kt | 27 ++++++- .../src/prog8/functions/BuiltinFunctions.kt | 50 ++++++------ .../src/prog8/optimizing/ConstantFolding.kt | 4 +- compiler/src/prog8/stackvm/StackVm.kt | 80 ++++++++++++------- 5 files changed, 155 insertions(+), 74 deletions(-) diff --git a/compiler/examples/test.p8 b/compiler/examples/test.p8 index 5a8db23a7..a1cce409c 100644 --- a/compiler/examples/test.p8 +++ b/compiler/examples/test.p8 @@ -16,31 +16,59 @@ sub start() -> () { byte[2,3] matrix1 = [1,2, 3,4, 5,6] byte[2,3] matrix2 = [1,2, 3,4, 5,6] byte[2,3] matrix3 = [11,22, 33,44, 55,66] - str message = "Calculating Mandelbrot Fractal..." - _vm_gfx_text(5, 5, 7, message) - _vm_gfx_text(5, 5, 7, "Calculating Mandelbrot Fractal...") + byte num1 = 99 + word num2 = 12345 + float num3 = 98.555 - A= len("abcdef") - A= len([4,5,99/X]) - A= max([4,5,99]) - A= min([4,5,99]) - A= avg([4,5,99]) - A= sum([4,5,99]) - A= any([4,5,99]) - A= all([4,5,99]) + num1 = max(msg3) + _vm_write_str(num1) + _vm_write_char($8d) + num1 = min(msg3) + _vm_write_str(num1) + _vm_write_char($8d) - A= len(msg3) + num3 = avg(array3) + _vm_write_str(num3) + _vm_write_char($8d) + num1 = all(array3) + _vm_write_str(num1) + _vm_write_char($8d) + num1 = any(array3) + _vm_write_str(num1) + _vm_write_char($8d) - float xx + num1 = max(array3) + _vm_write_str(num1) + _vm_write_char($8d) + num1 = min(array3) + _vm_write_str(num1) + _vm_write_char($8d) + num2 = max(array5) + _vm_write_str(num2) + _vm_write_char($8d) + num2 = min(array5) + _vm_write_str(num2) + _vm_write_char($8d) + num2 = sum(array3) + _vm_write_str(num2) + _vm_write_char($8d) + num2 = sum(array5) + _vm_write_str(num2) + _vm_write_char($8d) - A= len(array3) - A= max(array3) - A= min(array3) - xx= avg(array3) - A= sum(array3) - A= any(array3) - A= all(array3) + num3 = avg(msg3) + _vm_write_str(num3) + _vm_write_char($8d) + num2 = sum(msg3) + _vm_write_str(num2) + _vm_write_char($8d) + num1 = all(msg3) + _vm_write_str(num1) + _vm_write_char($8d) + num1 = any(msg3) + _vm_write_str(num1) + _vm_write_char($8d) } } diff --git a/compiler/src/prog8/compiler/Compiler.kt b/compiler/src/prog8/compiler/Compiler.kt index d7f850c59..749b66dec 100644 --- a/compiler/src/prog8/compiler/Compiler.kt +++ b/compiler/src/prog8/compiler/Compiler.kt @@ -25,7 +25,7 @@ fun Number.toHex(): String { class HeapValues { - class HeapValue(val type: DataType, val str: String?, val array: IntArray?) { + data class HeapValue(val type: DataType, val str: String?, val array: IntArray?) { override fun equals(other: Any?): Boolean { if (this === other) return true if (javaClass != other?.javaClass) return false @@ -62,6 +62,31 @@ class HeapValues { return heap.size-1 } + fun update(heapId: Int, str: String) { + when(heap[heapId].type){ + DataType.STR, + DataType.STR_P, + DataType.STR_S, + DataType.STR_PS -> { + if(heap[heapId].str!!.length!=str.length) + throw IllegalArgumentException("heap string length mismatch") + heap[heapId] = heap[heapId].copy(str=str) + } + else-> throw IllegalArgumentException("heap data type mismatch") + } + } + + fun update(heapId: Int, array: IntArray) { + when(heap[heapId].type){ + DataType.ARRAY, DataType.ARRAY_W, DataType.MATRIX -> { + if(heap[heapId].array!!.size != array.size) + throw IllegalArgumentException("heap array length mismatch") + heap[heapId] = heap[heapId].copy(array=array) + } + else-> throw IllegalArgumentException("heap data type mismatch") + } + } + fun get(heapId: Int): HeapValue = heap[heapId] fun allStrings() = heap.asSequence().withIndex().filter { it.value.str!=null }.toList() diff --git a/compiler/src/prog8/functions/BuiltinFunctions.kt b/compiler/src/prog8/functions/BuiltinFunctions.kt index a4e055237..34c264a1a 100644 --- a/compiler/src/prog8/functions/BuiltinFunctions.kt +++ b/compiler/src/prog8/functions/BuiltinFunctions.kt @@ -43,6 +43,17 @@ fun builtinFunctionReturnType(function: String, args: List, namespa return DataType.BYTE } } + if(arglist is IdentifierReference) { + val dt = arglist.resultingDatatype(namespace, heap) + return when(dt) { + DataType.BYTE, DataType.WORD, DataType.FLOAT, + DataType.STR, DataType.STR_P, DataType.STR_S, DataType.STR_PS -> dt + DataType.ARRAY -> DataType.BYTE + DataType.ARRAY_W -> DataType.WORD + DataType.MATRIX -> DataType.BYTE + null -> throw FatalAstException("function requires one argument which is an array $function") + } + } throw FatalAstException("function requires one argument which is an array $function") } @@ -53,7 +64,16 @@ fun builtinFunctionReturnType(function: String, args: List, namespa "rndw" -> DataType.WORD "rol", "rol2", "ror", "ror2", "lsl", "lsr", "set_carry", "clear_carry", "set_irqd", "clear_irqd" -> null // no return value so no datatype "abs" -> args.single().resultingDatatype(namespace, heap) - "max", "min" -> datatypeFromListArg(args.single()) + "max", "min" -> { + val dt = datatypeFromListArg(args.single()) + when(dt) { + DataType.BYTE, DataType.WORD, DataType.FLOAT -> dt + DataType.STR, DataType.STR_P, DataType.STR_S, DataType.STR_PS -> DataType.BYTE + DataType.ARRAY -> DataType.BYTE + DataType.ARRAY_W -> DataType.WORD + DataType.MATRIX -> DataType.BYTE + } + } "round", "floor", "ceil" -> integerDatatypeFromArg(args.single()) "sum" -> { val dt=datatypeFromListArg(args.single()) @@ -61,8 +81,8 @@ fun builtinFunctionReturnType(function: String, args: List, namespa DataType.BYTE, DataType.WORD -> DataType.WORD DataType.FLOAT -> DataType.FLOAT DataType.ARRAY, DataType.ARRAY_W -> DataType.WORD - DataType.MATRIX -> DataType.BYTE - else -> throw FatalAstException("cannot sum over type $dt") + DataType.MATRIX -> DataType.WORD + DataType.STR, DataType.STR_P, DataType.STR_S, DataType.STR_PS -> DataType.WORD } } "len" -> { @@ -135,13 +155,7 @@ private fun collectionArgOutputNumber(args: List, position: Positio function: (arg: Collection)->Number): LiteralValue { if(args.size!=1) throw SyntaxError("builtin function requires one non-scalar argument", position) - var iterable = args[0].constValue(namespace, heap) - if(iterable==null) { - if(args[0] !is IdentifierReference) - throw SyntaxError("function over weird argument ${args[0]}", position) - iterable = ((args[0] as IdentifierReference).targetStatement(namespace) as? VarDecl)?.value?.constValue(namespace, heap) - ?: throw SyntaxError("function over weird argument ${args[0]}", position) - } + var iterable = args[0].constValue(namespace, heap) ?: throw NotConstArgumentException() val result = if(iterable.arrayvalue != null) { val constants = iterable.arrayvalue!!.map { it.constValue(namespace, heap)?.asNumericValue } @@ -160,13 +174,7 @@ private fun collectionArgOutputBoolean(args: List, position: Positi function: (arg: Collection)->Boolean): LiteralValue { if(args.size!=1) throw SyntaxError("builtin function requires one non-scalar argument", position) - var iterable = args[0].constValue(namespace, heap) - if(iterable==null) { - if(args[0] !is IdentifierReference) - throw SyntaxError("function over weird argument ${args[0]}", position) - iterable = ((args[0] as IdentifierReference).targetStatement(namespace) as? VarDecl)?.value?.constValue(namespace, heap) - ?: throw SyntaxError("function over weird argument ${args[0]}", position) - } + var iterable = args[0].constValue(namespace, heap) ?: throw NotConstArgumentException() val result = if(iterable.arrayvalue != null) { val constants = iterable.arrayvalue!!.map { it.constValue(namespace, heap)?.asNumericValue } @@ -268,13 +276,7 @@ fun builtinSum(args: List, position: Position, namespace:INameScope fun builtinAvg(args: List, position: Position, namespace:INameScope, heap: HeapValues): LiteralValue { if(args.size!=1) throw SyntaxError("avg requires array/matrix argument", position) - var iterable = args[0].constValue(namespace, heap) - if(iterable==null) { - if(args[0] !is IdentifierReference) - throw SyntaxError("avg over weird argument ${args[0]}", position) - iterable = ((args[0] as IdentifierReference).targetStatement(namespace) as? VarDecl)?.value?.constValue(namespace, heap) - ?: throw SyntaxError("avg over weird argument ${args[0]}", position) - } + var iterable = args[0].constValue(namespace, heap) ?: throw NotConstArgumentException() val result = if(iterable.arrayvalue!=null) { val constants = iterable.arrayvalue!!.map { it.constValue(namespace, heap)?.asNumericValue } diff --git a/compiler/src/prog8/optimizing/ConstantFolding.kt b/compiler/src/prog8/optimizing/ConstantFolding.kt index 4dac81d73..b7b43563f 100644 --- a/compiler/src/prog8/optimizing/ConstantFolding.kt +++ b/compiler/src/prog8/optimizing/ConstantFolding.kt @@ -317,7 +317,7 @@ class ConstantFolding(private val namespace: INameScope, private val heap: HeapV val array = newArray.map { val litval = it as? LiteralValue if(litval==null) { - addError(ExpressionError("array/matrix can contain only constant values", literalValue.position)) + addError(ExpressionError("array/matrix literal can contain only constant values", literalValue.position)) return super.process(literalValue) } if(litval.bytevalue==null && litval.wordvalue==null) { @@ -341,7 +341,7 @@ class ConstantFolding(private val namespace: INameScope, private val heap: HeapV val newValue = LiteralValue(arrayDt, heapId=heapId, position = literalValue.position) return super.process(newValue) } else { - addError(ExpressionError("array/matrix can contain only constant values", literalValue.position)) + addError(ExpressionError("array/matrix literal can contain only constant values", literalValue.position)) } val newValue = LiteralValue(arrayDt, arrayvalue = newArray, position = literalValue.position) diff --git a/compiler/src/prog8/stackvm/StackVm.kt b/compiler/src/prog8/stackvm/StackVm.kt index 06afbab32..8e3ee8230 100644 --- a/compiler/src/prog8/stackvm/StackVm.kt +++ b/compiler/src/prog8/stackvm/StackVm.kt @@ -483,9 +483,11 @@ class StackVm(private var traceOutputFile: String?) { } } Syscall.INPUT_STR -> { - val maxlen = evalstack.pop().integerValue() - val input = readLine()?.substring(0, maxlen) ?: "" - TODO("input_str opcode should put the string in a given heap location (overwriting old string)") + val variable = evalstack.pop() + val value = heap.get(variable.heapId) + val maxlen = value.str!!.length + val input = readLine() ?: "" + heap.update(variable.heapId, input.padEnd(maxlen, '\u0000').substring(0, maxlen)) } Syscall.GFX_PIXEL -> { // plot pixel at (x, y, color) from stack @@ -556,45 +558,69 @@ class StackVm(private var traceOutputFile: String?) { } Syscall.FUNC_MAX -> { val iterable = evalstack.pop() - val dt = - when(iterable.type) { - DataType.STR, DataType.STR_P, DataType.STR_S, DataType.STR_PS, - DataType.ARRAY, DataType.MATRIX -> DataType.BYTE - DataType.ARRAY_W -> DataType.WORD - else -> throw VmExecutionException("uniterable value $iterable") - } - TODO("func_max on array/matrix/string") + val value = heap.get(iterable.heapId) + val resultDt = when(iterable.type) { + DataType.STR, DataType.STR_P, DataType.STR_S, DataType.STR_PS -> DataType.BYTE + DataType.ARRAY, DataType.MATRIX -> DataType.BYTE + DataType.ARRAY_W -> DataType.WORD + else -> throw VmExecutionException("uniterable value $iterable") + } + if(value.str!=null) { + val result = Petscii.encodePetscii(value.str.max().toString(), true)[0] + evalstack.push(Value(DataType.BYTE, result)) + } else { + val result = value.array!!.max() ?: 0 + evalstack.push(Value(resultDt, result)) + } } Syscall.FUNC_MIN -> { val iterable = evalstack.pop() - val dt = - when(iterable.type) { - DataType.STR, DataType.STR_P, DataType.STR_S, DataType.STR_PS, - DataType.ARRAY, DataType.MATRIX -> DataType.BYTE - DataType.ARRAY_W -> DataType.WORD - else -> throw VmExecutionException("uniterable value $iterable") - } - TODO("func_min on array/matrix/string") + val value = heap.get(iterable.heapId) + val resultDt = when(iterable.type) { + DataType.STR, DataType.STR_P, DataType.STR_S, DataType.STR_PS -> DataType.BYTE + DataType.ARRAY, DataType.MATRIX -> DataType.BYTE + DataType.ARRAY_W -> DataType.WORD + else -> throw VmExecutionException("uniterable value $iterable") + } + if(value.str!=null) { + val result = Petscii.encodePetscii(value.str.min().toString(), true)[0] + evalstack.push(Value(DataType.BYTE, result)) + } else { + val result = value.array!!.min() ?: 0 + evalstack.push(Value(resultDt, result)) + } } Syscall.FUNC_AVG -> { val iterable = evalstack.pop() - TODO("func_avg") -// evalstack.push(Value(DataType.FLOAT, array.arrayvalue!!.average())) + val value = heap.get(iterable.heapId) + if(value.str!=null) + evalstack.push(Value(DataType.FLOAT, Petscii.encodePetscii(value.str, true).average())) + else + evalstack.push(Value(DataType.FLOAT, value.array!!.average())) } Syscall.FUNC_SUM -> { val iterable = evalstack.pop() - TODO("func_sum") -// evalstack.push(Value(DataType.WORD, array.arrayvalue!!.sum())) + val value = heap.get(iterable.heapId) + if(value.str!=null) + evalstack.push(Value(DataType.WORD, Petscii.encodePetscii(value.str, true).sum())) + else + evalstack.push(Value(DataType.WORD, value.array!!.sum())) } Syscall.FUNC_ANY -> { val iterable = evalstack.pop() - TODO("func_any") -// evalstack.push(Value(DataType.BYTE, if(array.arrayvalue!!.any{ v -> v != 0}) 1 else 0)) + val value = heap.get(iterable.heapId) + if (value.str != null) + evalstack.push(Value(DataType.BYTE, if (Petscii.encodePetscii(value.str, true).any { c -> c != 0.toShort() }) 1 else 0)) + else + evalstack.push(Value(DataType.BYTE, if (value.array!!.any{v->v!=0}) 1 else 0)) } Syscall.FUNC_ALL -> { val iterable = evalstack.pop() - TODO("func_all") -// evalstack.push(Value(DataType.BYTE, if(array.arrayvalue!!.all{ v -> v != 0}) 1 else 0)) + val value = heap.get(iterable.heapId) + if (value.str != null) + evalstack.push(Value(DataType.BYTE, if (Petscii.encodePetscii(value.str, true).all { c -> c != 0.toShort() }) 1 else 0)) + else + evalstack.push(Value(DataType.BYTE, if (value.array!!.all{v->v!=0}) 1 else 0)) } else -> throw VmExecutionException("unimplemented syscall $syscall") }