From b44e76db57b4d676644ee47f389f88a657737e96 Mon Sep 17 00:00:00 2001 From: Irmen de Jong Date: Sun, 11 Aug 2019 16:01:37 +0200 Subject: [PATCH] fix any/all assembly routine, added asm for min/max/sum/ etc aggregates removed avg function because of hidden internal overflow issues --- compiler/res/prog8lib/prog8lib.asm | 7 +- .../prog8/ast/expressions/AstExpressions.kt | 2 +- .../ast/processing/AstIdentifiersChecker.kt | 3 +- .../ast/processing/StatementReorderer.kt | 4 +- .../VarInitValueAndAddressOfCreator.kt | 6 +- compiler/src/prog8/compiler/Main.kt | 2 +- .../c64/codegen2/BuiltinFunctionsAsmGen.kt | 61 ++++++++-- .../src/prog8/functions/BuiltinFunctions.kt | 1 - .../src/prog8/optimizer/ConstantFolding.kt | 31 +++-- .../prog8/optimizer/SimplifyExpressions.kt | 2 +- compiler/src/prog8/vm/astvm/AstVm.kt | 4 - docs/source/programming.rst | 3 - examples/arithmetic/aggregates.p8 | 110 ++++++++++++++++++ 13 files changed, 186 insertions(+), 50 deletions(-) create mode 100644 examples/arithmetic/aggregates.p8 diff --git a/compiler/res/prog8lib/prog8lib.asm b/compiler/res/prog8lib/prog8lib.asm index dc81452a9..f506f18b8 100644 --- a/compiler/res/prog8lib/prog8lib.asm +++ b/compiler/res/prog8lib/prog8lib.asm @@ -35,7 +35,7 @@ init_system .proc rts .pend - + read_byte_from_address .proc ; -- read the byte from the memory address on the top of the stack, return in A (stack remains unchanged) lda c64.ESTACK_LO+1,x @@ -45,7 +45,7 @@ read_byte_from_address .proc + lda $ffff ; modified rts .pend - + add_a_to_zpword .proc ; -- add ubyte in A to the uword in c64.SCRATCH_ZPWORD1 @@ -851,11 +851,12 @@ func_all_w .proc bne + iny lda (c64.SCRATCH_ZPWORD1),y - bne + + bne ++ lda #0 sta c64.ESTACK_LO+1,x rts + iny ++ iny _cmp_mod cpy #255 ; modified bne - lda #1 diff --git a/compiler/src/prog8/ast/expressions/AstExpressions.kt b/compiler/src/prog8/ast/expressions/AstExpressions.kt index 7646eed5c..82d18c349 100644 --- a/compiler/src/prog8/ast/expressions/AstExpressions.kt +++ b/compiler/src/prog8/ast/expressions/AstExpressions.kt @@ -500,7 +500,7 @@ class ReferenceLiteralValue(val type: DataType, // only reference types allo throw FatalAstException("weird array element $it") it } else { - num.cast(elementType)!! + num.cast(elementType) // TODO this can throw an exception } }.toTypedArray() return ReferenceLiteralValue(targettype, null, array=castArray, position = position) diff --git a/compiler/src/prog8/ast/processing/AstIdentifiersChecker.kt b/compiler/src/prog8/ast/processing/AstIdentifiersChecker.kt index 229eec62b..984f21f4f 100644 --- a/compiler/src/prog8/ast/processing/AstIdentifiersChecker.kt +++ b/compiler/src/prog8/ast/processing/AstIdentifiersChecker.kt @@ -237,8 +237,7 @@ internal class AstIdentifiersChecker(private val program: Program) : IAstModifyi val newValue: Expression val lval = returnStmt.value as? NumericLiteralValue if(lval!=null) { - val adjusted = lval.cast(subroutine.returntypes.single()) - newValue = if(adjusted!=null && adjusted !== lval) adjusted else lval + newValue = lval.cast(subroutine.returntypes.single()) } else { newValue = returnStmt.value!! } diff --git a/compiler/src/prog8/ast/processing/StatementReorderer.kt b/compiler/src/prog8/ast/processing/StatementReorderer.kt index 4ed9e1ddc..499eb8bed 100644 --- a/compiler/src/prog8/ast/processing/StatementReorderer.kt +++ b/compiler/src/prog8/ast/processing/StatementReorderer.kt @@ -337,7 +337,7 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi if(dt!=DataType.UWORD) { val literaladdr = memread.addressExpression as? NumericLiteralValue if(literaladdr!=null) { - memread.addressExpression = literaladdr.cast(DataType.UWORD)!! + memread.addressExpression = literaladdr.cast(DataType.UWORD) } else { memread.addressExpression = TypecastExpression(memread.addressExpression, DataType.UWORD, true, memread.addressExpression.position) memread.addressExpression.parent = memread @@ -351,7 +351,7 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi if(dt!=DataType.UWORD) { val literaladdr = memwrite.addressExpression as? NumericLiteralValue if(literaladdr!=null) { - memwrite.addressExpression = literaladdr.cast(DataType.UWORD)!! + memwrite.addressExpression = literaladdr.cast(DataType.UWORD) } else { memwrite.addressExpression = TypecastExpression(memwrite.addressExpression, DataType.UWORD, true, memwrite.addressExpression.position) memwrite.addressExpression.parent = memwrite diff --git a/compiler/src/prog8/ast/processing/VarInitValueAndAddressOfCreator.kt b/compiler/src/prog8/ast/processing/VarInitValueAndAddressOfCreator.kt index f84bd17ef..fc9f91754 100644 --- a/compiler/src/prog8/ast/processing/VarInitValueAndAddressOfCreator.kt +++ b/compiler/src/prog8/ast/processing/VarInitValueAndAddressOfCreator.kt @@ -57,10 +57,8 @@ internal class VarInitValueAndAddressOfCreator(private val program: Program): IA addVarDecl(scope, decl.asDefaultValueDecl(null)) val declvalue = decl.value!! val value = - if(declvalue is NumericLiteralValue) { - val converted = declvalue.cast(decl.datatype) - converted ?: declvalue - } + if(declvalue is NumericLiteralValue) + declvalue.cast(decl.datatype) else declvalue val identifierName = listOf(decl.name) // this was: (scoped name) decl.scopedname.split(".") diff --git a/compiler/src/prog8/compiler/Main.kt b/compiler/src/prog8/compiler/Main.kt index 544033080..1100207a3 100644 --- a/compiler/src/prog8/compiler/Main.kt +++ b/compiler/src/prog8/compiler/Main.kt @@ -94,7 +94,7 @@ fun compileProgram(filepath: Path, programAst.checkValid(compilerOptions) // check if final tree is valid programAst.checkRecursion() // check if there are recursive subroutine calls - printAst(programAst) + // printAst(programAst) if(writeAssembly) { // asm generation directly from the Ast, no need for intermediate code diff --git a/compiler/src/prog8/compiler/target/c64/codegen2/BuiltinFunctionsAsmGen.kt b/compiler/src/prog8/compiler/target/c64/codegen2/BuiltinFunctionsAsmGen.kt index 0c6e7a396..d8395d235 100644 --- a/compiler/src/prog8/compiler/target/c64/codegen2/BuiltinFunctionsAsmGen.kt +++ b/compiler/src/prog8/compiler/target/c64/codegen2/BuiltinFunctionsAsmGen.kt @@ -81,17 +81,31 @@ internal class BuiltinFunctionsAsmGen(private val program: Program, asmgen.assignFromEvalResult(secondTarget) } "strlen" -> { - val identifierName = asmgen.asmIdentifierName(fcall.arglist[0] as IdentifierReference) - asmgen.out(""" - lda #<$identifierName - sta $ESTACK_LO_HEX,x - lda #>$identifierName - sta $ESTACK_HI_HEX,x - dex - jsr prog8_lib.func_strlen - """) + outputPushAddressOfIdentifier(fcall.arglist[0]) + asmgen.out(" jsr prog8_lib.func_strlen") + } + "min", "max", "sum" -> { + outputPushAddressAndLenghtOfArray(fcall.arglist[0]) + val dt = fcall.arglist.single().inferType(program)!! + when(dt) { + DataType.ARRAY_UB, DataType.STR_S, DataType.STR -> asmgen.out(" jsr prog8_lib.func_${functionName}_ub") + DataType.ARRAY_B -> asmgen.out(" jsr prog8_lib.func_${functionName}_b") + DataType.ARRAY_UW -> asmgen.out(" jsr prog8_lib.func_${functionName}_uw") + DataType.ARRAY_W -> asmgen.out(" jsr prog8_lib.func_${functionName}_w") + DataType.ARRAY_F -> asmgen.out(" jsr c64flt.func_${functionName}_f") + else -> throw AssemblyError("weird type $dt") + } + } + "any", "all" -> { + outputPushAddressAndLenghtOfArray(fcall.arglist[0]) + val dt = fcall.arglist.single().inferType(program)!! + when(dt) { + DataType.ARRAY_B, DataType.ARRAY_UB, DataType.STR_S, DataType.STR -> asmgen.out(" jsr prog8_lib.func_${functionName}_b") + DataType.ARRAY_UW, DataType.ARRAY_W -> asmgen.out(" jsr prog8_lib.func_${functionName}_w") + DataType.ARRAY_F -> asmgen.out(" jsr c64flt.func_${functionName}_f") + else -> throw AssemblyError("weird type $dt") + } } - // TODO: any(f), all(f), max(f), min(f), sum(f), avg(f) "sin", "cos", "tan", "atan", "ln", "log2", "sqrt", "rad", "deg", "round", "floor", "ceil", @@ -254,6 +268,33 @@ internal class BuiltinFunctionsAsmGen(private val program: Program, } } + private fun outputPushAddressAndLenghtOfArray(arg: Expression) { + arg as IdentifierReference + val identifierName = asmgen.asmIdentifierName(arg) + val size = arg.targetVarDecl(program.namespace)!!.arraysize!!.size()!! + asmgen.out(""" + lda #<$identifierName + sta $ESTACK_LO_HEX,x + lda #>$identifierName + sta $ESTACK_HI_HEX,x + dex + lda #$size + sta $ESTACK_LO_HEX,x + dex + """) + } + + private fun outputPushAddressOfIdentifier(arg: Expression) { + val identifierName = asmgen.asmIdentifierName(arg as IdentifierReference) + asmgen.out(""" + lda #<$identifierName + sta $ESTACK_LO_HEX,x + lda #>$identifierName + sta $ESTACK_HI_HEX,x + dex + """) + } + private fun translateFunctionArguments(args: MutableList, signature: FunctionSignature) { args.forEach { asmgen.translateExpression(it) diff --git a/compiler/src/prog8/functions/BuiltinFunctions.kt b/compiler/src/prog8/functions/BuiltinFunctions.kt index 4435da5b6..d256188b0 100644 --- a/compiler/src/prog8/functions/BuiltinFunctions.kt +++ b/compiler/src/prog8/functions/BuiltinFunctions.kt @@ -52,7 +52,6 @@ val BuiltinFunctions = mapOf( "sqrt" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArg(a, p, prg, Math::sqrt) }, "rad" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArg(a, p, prg, Math::toRadians) }, "deg" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArg(a, p, prg, Math::toDegrees) }, - "avg" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), DataType.FLOAT) { a, p, _ -> collectionArgNeverConst(a, p) }, "round" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArgOutputWord(a, p, prg, Math::round) }, "floor" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArgOutputWord(a, p, prg, Math::floor) }, "ceil" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArgOutputWord(a, p, prg, Math::ceil) }, diff --git a/compiler/src/prog8/optimizer/ConstantFolding.kt b/compiler/src/prog8/optimizer/ConstantFolding.kt index b675dd605..ed4016785 100644 --- a/compiler/src/prog8/optimizer/ConstantFolding.kt +++ b/compiler/src/prog8/optimizer/ConstantFolding.kt @@ -9,6 +9,7 @@ import prog8.ast.processing.fixupArrayDatatype import prog8.ast.statements.* import prog8.compiler.target.c64.MachineDefinition.FLOAT_MAX_NEGATIVE import prog8.compiler.target.c64.MachineDefinition.FLOAT_MAX_POSITIVE +import prog8.compiler.target.c64.codegen2.AssemblyError import prog8.functions.BuiltinFunctions import kotlin.math.floor @@ -174,7 +175,7 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor { copy.parent = identifier.parent copy } - cval.type in PassByReferenceDatatypes -> TODO("ref type $identifier") + cval.type in PassByReferenceDatatypes -> throw AssemblyError("pass-by-reference type should not be considered a constant") else -> identifier } } catch (ax: AstException) { @@ -209,11 +210,9 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor { val possibleDts = arg.second.possibleDatatypes val argConst = arg.first.value.constValue(program) if(argConst!=null && argConst.type !in possibleDts) { - val convertedValue = argConst.cast(possibleDts.first()) - if(convertedValue!=null) { - functionCall.arglist[arg.first.index] = convertedValue - optimizationsDone++ - } + val convertedValue = argConst.cast(possibleDts.first()) // TODO can throw exception + functionCall.arglist[arg.first.index] = convertedValue + optimizationsDone++ } } return @@ -227,11 +226,9 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor { val expectedDt = arg.second.type val argConst = arg.first.value.constValue(program) if(argConst!=null && argConst.type!=expectedDt) { - val convertedValue = argConst.cast(expectedDt) - if(convertedValue!=null) { - functionCall.arglist[arg.first.index] = convertedValue - optimizationsDone++ - } + val convertedValue = argConst.cast(expectedDt) // TODO can throw exception + functionCall.arglist[arg.first.index] = convertedValue + optimizationsDone++ } } } @@ -315,7 +312,7 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor { super.visit(expr) if(expr.left is ReferenceLiteralValue || expr.right is ReferenceLiteralValue) - TODO("binexpr with reference litval") + throw FatalAstException("binexpr with reference litval instead of numeric") val leftconst = expr.left.constValue(program) val rightconst = expr.right.constValue(program) @@ -547,14 +544,12 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor { override fun visit(forLoop: ForLoop): Statement { fun adjustRangeDt(rangeFrom: NumericLiteralValue, targetDt: DataType, rangeTo: NumericLiteralValue, stepLiteral: NumericLiteralValue?, range: RangeExpr): RangeExpr { + // TODO casts can throw exception val newFrom = rangeFrom.cast(targetDt) val newTo = rangeTo.cast(targetDt) - if (newFrom != null && newTo != null) { - val newStep: Expression = - if (stepLiteral != null) (stepLiteral.cast(targetDt) ?: stepLiteral) else range.step - return RangeExpr(newFrom, newTo, newStep, range.position) - } - return range + val newStep: Expression = + stepLiteral?.cast(targetDt) ?: range.step + return RangeExpr(newFrom, newTo, newStep, range.position) } // adjust the datatype of a range expression in for loops to the loop variable. diff --git a/compiler/src/prog8/optimizer/SimplifyExpressions.kt b/compiler/src/prog8/optimizer/SimplifyExpressions.kt index 5918e49b3..0b0e9c474 100644 --- a/compiler/src/prog8/optimizer/SimplifyExpressions.kt +++ b/compiler/src/prog8/optimizer/SimplifyExpressions.kt @@ -43,7 +43,7 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying val literal = tc.expression as? NumericLiteralValue if(literal!=null) { val newLiteral = literal.cast(tc.type) - if(newLiteral!=null && newLiteral!==literal) { + if(newLiteral!==literal) { optimizationsDone++ return newLiteral } diff --git a/compiler/src/prog8/vm/astvm/AstVm.kt b/compiler/src/prog8/vm/astvm/AstVm.kt index 1c4ae1906..7ae98df5e 100644 --- a/compiler/src/prog8/vm/astvm/AstVm.kt +++ b/compiler/src/prog8/vm/astvm/AstVm.kt @@ -845,10 +845,6 @@ class AstVm(val program: Program) { val numbers = args.single().array!!.map { it.toDouble() } RuntimeValue(ArrayElementTypes.getValue(args[0].type), numbers.min()) } - "avg" -> { - val numbers = args.single().array!!.map { it.toDouble() } - RuntimeValue(DataType.FLOAT, numbers.average()) - } "sum" -> { val sum = args.single().array!!.map { it.toDouble() }.sum() when (args[0].type) { diff --git a/docs/source/programming.rst b/docs/source/programming.rst index 2b6be5fb1..bbeef71b9 100644 --- a/docs/source/programming.rst +++ b/docs/source/programming.rst @@ -707,9 +707,6 @@ max(x) min(x) Minimum of the values in the array value x -avg(x) - Average of the values in the array value x - sum(x) Sum of the values in the array value x diff --git a/examples/arithmetic/aggregates.p8 b/examples/arithmetic/aggregates.p8 new file mode 100644 index 000000000..4d01fffa9 --- /dev/null +++ b/examples/arithmetic/aggregates.p8 @@ -0,0 +1,110 @@ +%import c64lib +%import c64utils +%import c64flt +%zeropage dontuse + +main { + + sub start() { + ubyte[] ubarr = [100, 0, 99, 199, 22] + byte[] barr = [-100, 0, 99, -122, 22] + uword[] uwarr = [1000, 0, 222, 4444, 999] + word[] warr = [-1000, 0, 999, -4444, 222] + float[] farr = [-1000.1, 0, 999.9, -4444.4, 222.2] + str name = "irmen" + ubyte ub + byte bb + word ww + uword uw + float ff + + ; LEN/STRLEN + ubyte length = len(name) + if length!=5 c64scr.print("error len1\n") + length = len(uwarr) + if length!=5 c64scr.print("error len2\n") + length=strlen(name) + if length!=5 c64scr.print("error strlen1\n") + name[3] = 0 + length=strlen(name) + if length!=3 c64scr.print("error strlen2\n") + + ; MAX + ub = max(ubarr) + if ub!=199 c64scr.print("error max1\n") + bb = max(barr) + if bb!=99 c64scr.print("error max2\n") + uw = max(uwarr) + if uw!=4444 c64scr.print("error max3\n") + ww = max(warr) + if ww!=999 c64scr.print("error max4\n") + ff = max(farr) + if ff!=999.9 c64scr.print("error max5\n") + + ; MIN + ub = min(ubarr) + if ub!=0 c64scr.print("error min1\n") + bb = min(barr) + if bb!=-122 c64scr.print("error min2\n") + uw = min(uwarr) + if uw!=0 c64scr.print("error min3\n") + ww = min(warr) + if ww!=-4444 c64scr.print("error min4\n") + ff = min(farr) + if ff!=-4444.4 c64scr.print("error min5\n") + + ; SUM + uw = sum(ubarr) + if uw!=420 c64scr.print("error sum1\n") + ww = sum(barr) + if ww!=-101 c64scr.print("error sum2\n") + uw = sum(uwarr) + if uw!=6665 c64scr.print("error sum3\n") + ww = sum(warr) + if ww!=-4223 c64scr.print("error sum4\n") + ff = sum(farr) + if ff!=-4222.4 c64scr.print("error sum5\n") + + ; ANY + ub = any(ubarr) + if ub==0 c64scr.print("error any1\n") + ub = any(barr) + if ub==0 c64scr.print("error any2\n") + ub = any(uwarr) + if ub==0 c64scr.print("error any3\n") + ub = any(warr) + if ub==0 c64scr.print("error any4\n") + ub = any(farr) + if ub==0 c64scr.print("error any5\n") + + ; ALL + ub = all(ubarr) + if ub==1 c64scr.print("error all1\n") + ub = all(barr) + if ub==1 c64scr.print("error all2\n") + ub = all(uwarr) + if ub==1 c64scr.print("error all3\n") + ub = all(warr) + if ub==1 c64scr.print("error all4\n") + ub = all(farr) + if ub==1 c64scr.print("error all5\n") + ubarr[1]=$40 + barr[1]=$40 + uwarr[1]=$4000 + warr[1]=$4000 + farr[1]=1.1 + ub = all(ubarr) + if ub==0 c64scr.print("error all6\n") + ub = all(barr) + if ub==0 c64scr.print("error all7\n") + ub = all(uwarr) + if ub==0 c64scr.print("error all8\n") + ub = all(warr) + if ub==0 c64scr.print("error all9\n") + ub = all(farr) + if ub==0 c64scr.print("error all10\n") + + + c64scr.print("\nyou should see no errors above.") + } +}