diff --git a/compiler/src/prog8/ast/antlr/Antr2Kotlin.kt b/compiler/src/prog8/ast/antlr/Antr2Kotlin.kt index 501d4444e..831f17114 100644 --- a/compiler/src/prog8/ast/antlr/Antr2Kotlin.kt +++ b/compiler/src/prog8/ast/antlr/Antr2Kotlin.kt @@ -429,7 +429,7 @@ private fun prog8Parser.ExpressionContext.toAst() : Expression { else -> throw FatalAstException("invalid datatype for numeric literal") } litval.floatliteral()!=null -> NumericLiteralValue(DataType.FLOAT, litval.floatliteral().toAst(), litval.toPosition()) - litval.stringliteral()!=null -> StringLiteralValue(DataType.STR, unescape(litval.stringliteral().text, litval.toPosition()), position = litval.toPosition()) + litval.stringliteral()!=null -> StringLiteralValue(DataType.STR, unescape(litval.stringliteral().text, litval.toPosition()), null, litval.toPosition()) litval.charliteral()!=null -> { try { NumericLiteralValue(DataType.UBYTE, Petscii.encodePetscii(unescape(litval.charliteral().text, litval.toPosition()), true)[0], litval.toPosition()) diff --git a/compiler/src/prog8/ast/expressions/AstExpressions.kt b/compiler/src/prog8/ast/expressions/AstExpressions.kt index cda512169..ae6842aaa 100644 --- a/compiler/src/prog8/ast/expressions/AstExpressions.kt +++ b/compiler/src/prog8/ast/expressions/AstExpressions.kt @@ -414,13 +414,10 @@ class StructLiteralValue(var values: List, class StringLiteralValue(val type: DataType, // only string types val value: String, - initHeapId: Int? =null, + var heapId: Int?, override val position: Position) : Expression() { override lateinit var parent: Node - var heapId = initHeapId - private set - override fun linkParents(parent: Node) { this.parent = parent } @@ -441,8 +438,10 @@ class StringLiteralValue(val type: DataType, // only string types fun addToHeap(heap: HeapValues) { if (heapId != null) return - else - heapId = heap.addString(type, value) + else { + val encodedStr = Petscii.encodePetscii(value, true) + heapId = heap.addIntegerArray(DataType.ARRAY_UB, encodedStr.map { IntegerOrAddressOf(it.toInt(), null)}.toTypedArray()) + } } } diff --git a/compiler/src/prog8/ast/processing/AstChecker.kt b/compiler/src/prog8/ast/processing/AstChecker.kt index 60188e3c5..7c101f7c1 100644 --- a/compiler/src/prog8/ast/processing/AstChecker.kt +++ b/compiler/src/prog8/ast/processing/AstChecker.kt @@ -965,8 +965,7 @@ internal class AstChecker(private val program: Program, } else if(target.datatype in StringDatatypes) { if(target.value is StringLiteralValue) { // check string lengths for non-memory mapped strings - val heapId = (target.value as StringLiteralValue).heapId!! - val stringLen = program.heap.get(heapId).str!!.length + val stringLen = (target.value as StringLiteralValue).value.length val index = (arrayIndexedExpression.arrayspec.index as? NumericLiteralValue)?.number?.toInt() if (index != null && (index < 0 || index >= stringLen)) checkResult.add(ExpressionError("index out of bounds", arrayIndexedExpression.arrayspec.position)) @@ -1231,7 +1230,7 @@ internal class AstChecker(private val program: Program, return correct } - val array = program.heap.get(value.heapId!!) + val array = program.heap.get(value.heapId!!) // TODO use value.array directly? if(array.type !in ArrayDatatypes || (array.array==null && array.doubleArray==null)) throw FatalAstException("should have an array in the heapvar $array") diff --git a/compiler/src/prog8/compiler/Compiler.kt b/compiler/src/prog8/compiler/Compiler.kt index 3107d8d53..6aaf7cf45 100644 --- a/compiler/src/prog8/compiler/Compiler.kt +++ b/compiler/src/prog8/compiler/Compiler.kt @@ -69,4 +69,3 @@ fun loadAsmIncludeFile(filename: String, source: Path): String { internal fun tryGetEmbeddedResource(name: String): InputStream? { return object{}.javaClass.getResourceAsStream("/prog8lib/$name") } - diff --git a/compiler/src/prog8/compiler/HeapValues.kt b/compiler/src/prog8/compiler/HeapValues.kt index 3be033079..e64538470 100644 --- a/compiler/src/prog8/compiler/HeapValues.kt +++ b/compiler/src/prog8/compiler/HeapValues.kt @@ -2,19 +2,23 @@ package prog8.compiler import prog8.ast.base.ArrayDatatypes import prog8.ast.base.DataType -import prog8.ast.base.StringDatatypes import java.util.* +/** + * The 'heapvalues' is the collection of variables that are allocated globally. + * Arrays and strings belong here. + * They get assigned a heapId to be able to retrieve them later. + */ class HeapValues { - data class HeapValue(val type: DataType, val str: String?, val array: Array?, val doubleArray: DoubleArray?) { + data class HeapValue(val type: DataType, val array: Array?, val doubleArray: DoubleArray?) { override fun equals(other: Any?): Boolean { if (this === other) return true if (javaClass != other?.javaClass) return false other as HeapValue - return type==other.type && str==other.str && Arrays.equals(array, other.array) && Arrays.equals(doubleArray, other.doubleArray) + return type==other.type && Arrays.equals(array, other.array) && Arrays.equals(doubleArray, other.doubleArray) } - override fun hashCode(): Int = Objects.hash(str, array, doubleArray) + override fun hashCode(): Int = Objects.hash(type, array, doubleArray) } private val heap = mutableMapOf() @@ -22,49 +26,28 @@ class HeapValues { fun size(): Int = heap.size - fun addString(type: DataType, str: String): Int { - if (str.length > 255) - throw IllegalArgumentException("string length must be 0-255") - - // strings are 'interned' and shared if they're the isSameAs - val value = HeapValue(type, str, null, null) - - val existing = heap.filter { it.value==value }.map { it.key }.firstOrNull() - if(existing!=null) - return existing - val newId = heapId++ - heap[newId] = value - return newId - } - fun addIntegerArray(type: DataType, array: Array): Int { // arrays are never shared, don't check for existing if(type !in ArrayDatatypes) - throw CompilerException("wrong array type") + throw CompilerException("wrong array type $type") val newId = heapId++ - heap[newId] = HeapValue(type, null, array, null) + heap[newId] = HeapValue(type, array, null) return newId } fun addDoublesArray(darray: DoubleArray): Int { // arrays are never shared, don't check for existing val newId = heapId++ - heap[newId] = HeapValue(DataType.ARRAY_F, null, null, darray) + heap[newId] = HeapValue(DataType.ARRAY_F, null, darray) return newId } - fun updateString(heapId: Int, str: String) { - val oldVal = heap[heapId] ?: throw IllegalArgumentException("heapId not found in heap") - if(oldVal.type in StringDatatypes) { - if (oldVal.str!!.length != str.length) - throw IllegalArgumentException("heap string length mismatch") - heap[heapId] = oldVal.copy(str = str) - } - else throw IllegalArgumentException("heap data type mismatch") - } - fun get(heapId: Int): HeapValue { return heap[heapId] ?: throw IllegalArgumentException("heapId $heapId not found in heap") } + + fun remove(heapId: Int) { + heap.remove(heapId) + } } diff --git a/compiler/src/prog8/compiler/Main.kt b/compiler/src/prog8/compiler/Main.kt index 304e34c7c..e4a2b0ffc 100644 --- a/compiler/src/prog8/compiler/Main.kt +++ b/compiler/src/prog8/compiler/Main.kt @@ -96,7 +96,7 @@ fun compileProgram(filepath: Path, programAst.checkValid(compilerOptions) // check if final tree is valid programAst.checkRecursion() // check if there are recursive subroutine calls - // printAst(programAst) + printAst(programAst) if(writeAssembly) { // asm generation directly from the Ast, no need for intermediate code diff --git a/compiler/src/prog8/optimizer/StatementOptimizer.kt b/compiler/src/prog8/optimizer/StatementOptimizer.kt index ff5c275af..f6aac9c1a 100644 --- a/compiler/src/prog8/optimizer/StatementOptimizer.kt +++ b/compiler/src/prog8/optimizer/StatementOptimizer.kt @@ -27,10 +27,15 @@ internal class StatementOptimizer(private val program: Program) : IAstModifyingV private val pureBuiltinFunctions = BuiltinFunctions.filter { it.value.pure } private val callgraph = CallGraph(program) + private val vardeclsToRemove = mutableListOf() override fun visit(program: Program) { removeUnusedCode(callgraph) super.visit(program) + + for(decl in vardeclsToRemove) { + decl.definingScope().remove(decl) + } } private fun removeUnusedCode(callgraph: CallGraph) { @@ -165,24 +170,33 @@ internal class StatementOptimizer(private val program: Program) : IAstModifyingV if(functionCallStatement.target.nameInSource==listOf("c64scr", "print") || functionCallStatement.target.nameInSource==listOf("c64scr", "print_p")) { // printing a literal string of just 2 or 1 characters is replaced by directly outputting those characters - val stringVar = functionCallStatement.arglist.single() as? IdentifierReference + val arg = functionCallStatement.arglist.single() + val stringVar: IdentifierReference? + if(arg is AddressOf) { + stringVar = arg.identifier + } else { + stringVar = arg as? IdentifierReference + } if(stringVar!=null) { - val heapId = stringVar.heapId(program.namespace) - val string = program.heap.get(heapId).str!! - if(string.length==1) { - val petscii = Petscii.encodePetscii(string, true)[0] + val vardecl = stringVar.targetVarDecl(program.namespace)!! + val string = vardecl.value!! as StringLiteralValue + val encodedString = Petscii.encodePetscii(string.value, true) + if(string.value.length==1) { functionCallStatement.arglist.clear() - functionCallStatement.arglist.add(NumericLiteralValue.optimalInteger(petscii.toInt(), functionCallStatement.position)) + functionCallStatement.arglist.add(NumericLiteralValue.optimalInteger(encodedString[0].toInt(), functionCallStatement.position)) functionCallStatement.target = IdentifierReference(listOf("c64", "CHROUT"), functionCallStatement.target.position) + vardeclsToRemove.add(vardecl) + program.heap.remove(string.heapId!!) optimizationsDone++ return functionCallStatement - } else if(string.length==2) { - val petscii = Petscii.encodePetscii(string, true) + } else if(string.value.length==2) { val scope = AnonymousScope(mutableListOf(), functionCallStatement.position) scope.statements.add(FunctionCallStatement(IdentifierReference(listOf("c64", "CHROUT"), functionCallStatement.target.position), - mutableListOf(NumericLiteralValue.optimalInteger(petscii[0].toInt(), functionCallStatement.position)), functionCallStatement.position)) + mutableListOf(NumericLiteralValue.optimalInteger(encodedString[0].toInt(), functionCallStatement.position)), functionCallStatement.position)) scope.statements.add(FunctionCallStatement(IdentifierReference(listOf("c64", "CHROUT"), functionCallStatement.target.position), - mutableListOf(NumericLiteralValue.optimalInteger(petscii[1].toInt(), functionCallStatement.position)), functionCallStatement.position)) + mutableListOf(NumericLiteralValue.optimalInteger(encodedString[1].toInt(), functionCallStatement.position)), functionCallStatement.position)) + vardeclsToRemove.add(vardecl) + program.heap.remove(string.heapId!!) optimizationsDone++ return scope } diff --git a/compiler/src/prog8/vm/astvm/AstVm.kt b/compiler/src/prog8/vm/astvm/AstVm.kt index 37ce114d1..3947389e2 100644 --- a/compiler/src/prog8/vm/astvm/AstVm.kt +++ b/compiler/src/prog8/vm/astvm/AstVm.kt @@ -662,8 +662,8 @@ class AstVm(val program: Program) { "c64scr.print" -> { // if the argument is an UWORD, consider it to be the "address" of the string (=heapId) if (args[0].wordval != null) { - val str = program.heap.get(args[0].wordval!!).str!! - dialog.canvas.printText(str, true) + val encodedStr = program.heap.get(args[0].wordval!!).array!!.map { it.integer!!.toShort() } + dialog.canvas.printText(encodedStr) } else throw VmExecutionException("print non-heap string") } @@ -738,10 +738,11 @@ class AstVm(val program: Program) { } val inputStr = input.joinToString("") val heapId = args[0].wordval!! - val origStr = program.heap.get(heapId).str!! - val paddedStr=inputStr.padEnd(origStr.length+1, '\u0000').substring(0, origStr.length) - program.heap.updateString(heapId, paddedStr) - result = RuntimeValueNumeric(DataType.UBYTE, paddedStr.indexOf('\u0000')) + val origStrLength = program.heap.get(heapId).array!!.size + val encodedStr = Petscii.encodePetscii(inputStr, true).take(origStrLength).toMutableList() + while(encodedStr.size { dialog.canvas.printText(args[0].floatval.toString(), false) @@ -761,8 +762,8 @@ class AstVm(val program: Program) { } "c64utils.str2uword" -> { val heapId = args[0].wordval!! - val argString = program.heap.get(heapId).str!! - val numericpart = argString.takeWhile { it.isDigit() } + val argString = program.heap.get(heapId).array!!.map { it.integer!!.toChar() } + val numericpart = argString.takeWhile { it.isDigit() }.toString() result = RuntimeValueNumeric(DataType.UWORD, numericpart.toInt() and 65535) } else -> TODO("syscall ${sub.scopedname} $sub") diff --git a/compiler/src/prog8/vm/astvm/ScreenDialog.kt b/compiler/src/prog8/vm/astvm/ScreenDialog.kt index f07b70ef0..79546b10d 100644 --- a/compiler/src/prog8/vm/astvm/ScreenDialog.kt +++ b/compiler/src/prog8/vm/astvm/ScreenDialog.kt @@ -75,6 +75,10 @@ class BitmapScreenPanel : KeyListener, JPanel() { } } + fun printText(text: Iterable) { + text.forEach { printPetscii(it, false) } + } + fun printPetscii(char: Short, inverseVideo: Boolean=false) { if(char==13.toShort() || char==141.toShort()) { cursorX=0 diff --git a/compiler/test/LiteralValueTests.kt b/compiler/test/LiteralValueTests.kt index 89cfab9e3..acd5b6eb1 100644 --- a/compiler/test/LiteralValueTests.kt +++ b/compiler/test/LiteralValueTests.kt @@ -83,8 +83,8 @@ class TestParserNumericLiteralValue { @Test fun testEqualsRef() { - assertTrue(StringLiteralValue(DataType.STR, "hello", position = dummyPos) == StringLiteralValue(DataType.STR, "hello", position = dummyPos)) - assertFalse(StringLiteralValue(DataType.STR, "hello", position = dummyPos) == StringLiteralValue(DataType.STR, "bye", position = dummyPos)) + assertTrue(StringLiteralValue(DataType.STR, "hello", null, dummyPos) == StringLiteralValue(DataType.STR, "hello", null, dummyPos)) + assertFalse(StringLiteralValue(DataType.STR, "hello", null, dummyPos) == StringLiteralValue(DataType.STR, "bye", null, dummyPos)) val lvOne = NumericLiteralValue(DataType.UBYTE, 1, dummyPos) val lvTwo = NumericLiteralValue(DataType.UBYTE, 2, dummyPos) @@ -93,9 +93,9 @@ class TestParserNumericLiteralValue { val lvTwoR = NumericLiteralValue(DataType.UBYTE, 2, dummyPos) val lvThreeR = NumericLiteralValue(DataType.UBYTE, 3, dummyPos) val lvFour= NumericLiteralValue(DataType.UBYTE, 4, dummyPos) - val lv1 = ArrayLiteralValue(DataType.ARRAY_UB, arrayOf(lvOne, lvTwo, lvThree), position = dummyPos) - val lv2 = ArrayLiteralValue(DataType.ARRAY_UB, arrayOf(lvOneR, lvTwoR, lvThreeR), position = dummyPos) - val lv3 = ArrayLiteralValue(DataType.ARRAY_UB, arrayOf(lvOneR, lvTwoR, lvFour), position = dummyPos) + val lv1 = ArrayLiteralValue(DataType.ARRAY_UB, arrayOf(lvOne, lvTwo, lvThree), null, dummyPos) + val lv2 = ArrayLiteralValue(DataType.ARRAY_UB, arrayOf(lvOneR, lvTwoR, lvThreeR), null, dummyPos) + val lv3 = ArrayLiteralValue(DataType.ARRAY_UB, arrayOf(lvOneR, lvTwoR, lvFour), null, dummyPos) assertEquals(lv1, lv2) assertNotEquals(lv1, lv3) } diff --git a/compiler/test/UnitTests.kt b/compiler/test/UnitTests.kt index 54c3f8b99..b9f4d8009 100644 --- a/compiler/test/UnitTests.kt +++ b/compiler/test/UnitTests.kt @@ -371,8 +371,8 @@ class TestPetscii { assertTrue(ten <= ten) assertFalse(ten < ten) - val abc = StringLiteralValue(DataType.STR, "abc", position = Position("", 0, 0, 0)) - val abd = StringLiteralValue(DataType.STR, "abd", position = Position("", 0, 0, 0)) + val abc = StringLiteralValue(DataType.STR, "abc", null, Position("", 0, 0, 0)) + val abd = StringLiteralValue(DataType.STR, "abd", null, Position("", 0, 0, 0)) assertEquals(abc, abc) assertTrue(abc!=abd) assertFalse(abc!=abc) diff --git a/examples/test.p8 b/examples/test.p8 index 8b2ebbd84..e7e83e124 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -12,80 +12,94 @@ main { str_s strs2 = "test" sub start() { - str str1x = "irmen" - str str2x = "test" - str_s strs1x = "irmen" - str_s strs2x = "test" - c64scr.print(str1) - c64.CHROUT('\n') - c64scr.print(str2) - c64.CHROUT('\n') - c64scr.print(str1x) - c64.CHROUT('\n') - c64scr.print(str2x) - c64.CHROUT('\n') + str foo1 = "1\n" + str foo2 = "12\n" - str1[0]='a' - str2[0]='a' - str1x[0]='a' - str2x[0]='a' - strs1x[0]='a' - strs2x[0]='a' - strs1[0]='a' - strs2[0]='a' + c64scr.print(foo1) + c64scr.print(foo2) + c64scr.print("\n") + c64scr.print("1\n") + c64scr.print("12\n") + c64scr.print("\n") + c64scr.print("1\n") + c64scr.print("12\n") - ; @TODO fix AstVm handling of strings (they're not modified right now) NOTE: array's seem to work fine - c64scr.print(str1) - c64.CHROUT('\n') - c64scr.print(str2) - c64.CHROUT('\n') - c64scr.print(str1x) - c64.CHROUT('\n') - c64scr.print(str2x) - c64.CHROUT('\n') - - - byte[] barr = [1,2,3] - word[] warr = [1000,2000,3000] - float[] farr = [1.1, 2.2, 3.3] - - byte bb - word ww - float ff - for bb in barr { - c64scr.print_b(bb) - c64.CHROUT(',') - } - c64.CHROUT('\n') - for ww in warr { - c64scr.print_w(ww) - c64.CHROUT(',') - } - c64.CHROUT('\n') - for bb in 0 to len(farr)-1 { - c64flt.print_f(farr[bb]) - c64.CHROUT(',') - } - c64.CHROUT('\n') - - barr[0] = 99 - warr[0] = 99 - farr[0] = 99.9 - for bb in barr { - c64scr.print_b(bb) - c64.CHROUT(',') - } - c64.CHROUT('\n') - for ww in warr { - c64scr.print_w(ww) - c64.CHROUT(',') - } - c64.CHROUT('\n') - for bb in 0 to len(farr)-1 { - c64flt.print_f(farr[bb]) - c64.CHROUT(',') - } +; str str1x = "irmen" +; str str2x = "test" +; str_s strs1x = "irmen" +; str_s strs2x = "test" +; +; c64scr.print("yoooooo") +; c64scr.print(str1) +; c64.CHROUT('\n') +; c64scr.print(str2) +; c64.CHROUT('\n') +; c64scr.print(str1x) +; c64.CHROUT('\n') +; c64scr.print(str2x) +; c64.CHROUT('\n') +; +; str1[0]='a' +; str2[0]='a' +; str1x[0]='a' +; str2x[0]='a' +; strs1x[0]='a' +; strs2x[0]='a' +; strs1[0]='a' +; strs2[0]='a' +; +; ; @TODO fix AstVm handling of strings (they're not modified right now) NOTE: array's seem to work fine +; c64scr.print(str1) +; c64.CHROUT('\n') +; c64scr.print(str2) +; c64.CHROUT('\n') +; c64scr.print(str1x) +; c64.CHROUT('\n') +; c64scr.print(str2x) +; c64.CHROUT('\n') +; +; +; byte[] barr = [1,2,3] +; word[] warr = [1000,2000,3000] +; float[] farr = [1.1, 2.2, 3.3] +; +; byte bb +; word ww +; float ff +; for bb in barr { +; c64scr.print_b(bb) +; c64.CHROUT(',') +; } +; c64.CHROUT('\n') +; for ww in warr { +; c64scr.print_w(ww) +; c64.CHROUT(',') +; } +; c64.CHROUT('\n') +; for bb in 0 to len(farr)-1 { +; c64flt.print_f(farr[bb]) +; c64.CHROUT(',') +; } +; c64.CHROUT('\n') +; +; barr[0] = 99 +; warr[0] = 99 +; farr[0] = 99.9 +; for bb in barr { +; c64scr.print_b(bb) +; c64.CHROUT(',') +; } +; c64.CHROUT('\n') +; for ww in warr { +; c64scr.print_w(ww) +; c64.CHROUT(',') +; } +; c64.CHROUT('\n') +; for bb in 0 to len(farr)-1 { +; c64flt.print_f(farr[bb]) +; c64.CHROUT(',') +; } } }