fix any/all assembly routine, added asm for min/max/sum/ etc aggregates

removed avg function because of hidden internal overflow issues
This commit is contained in:
Irmen de Jong 2019-08-11 16:01:37 +02:00
parent 2ce6bc5946
commit b44e76db57
13 changed files with 186 additions and 50 deletions

View File

@ -35,7 +35,7 @@ init_system .proc
rts rts
.pend .pend
read_byte_from_address .proc read_byte_from_address .proc
; -- read the byte from the memory address on the top of the stack, return in A (stack remains unchanged) ; -- read the byte from the memory address on the top of the stack, return in A (stack remains unchanged)
lda c64.ESTACK_LO+1,x lda c64.ESTACK_LO+1,x
@ -45,7 +45,7 @@ read_byte_from_address .proc
+ lda $ffff ; modified + lda $ffff ; modified
rts rts
.pend .pend
add_a_to_zpword .proc add_a_to_zpword .proc
; -- add ubyte in A to the uword in c64.SCRATCH_ZPWORD1 ; -- add ubyte in A to the uword in c64.SCRATCH_ZPWORD1
@ -851,11 +851,12 @@ func_all_w .proc
bne + bne +
iny iny
lda (c64.SCRATCH_ZPWORD1),y lda (c64.SCRATCH_ZPWORD1),y
bne + bne ++
lda #0 lda #0
sta c64.ESTACK_LO+1,x sta c64.ESTACK_LO+1,x
rts rts
+ iny + iny
+ iny
_cmp_mod cpy #255 ; modified _cmp_mod cpy #255 ; modified
bne - bne -
lda #1 lda #1

View File

@ -500,7 +500,7 @@ class ReferenceLiteralValue(val type: DataType, // only reference types allo
throw FatalAstException("weird array element $it") throw FatalAstException("weird array element $it")
it it
} else { } else {
num.cast(elementType)!! num.cast(elementType) // TODO this can throw an exception
} }
}.toTypedArray() }.toTypedArray()
return ReferenceLiteralValue(targettype, null, array=castArray, position = position) return ReferenceLiteralValue(targettype, null, array=castArray, position = position)

View File

@ -237,8 +237,7 @@ internal class AstIdentifiersChecker(private val program: Program) : IAstModifyi
val newValue: Expression val newValue: Expression
val lval = returnStmt.value as? NumericLiteralValue val lval = returnStmt.value as? NumericLiteralValue
if(lval!=null) { if(lval!=null) {
val adjusted = lval.cast(subroutine.returntypes.single()) newValue = lval.cast(subroutine.returntypes.single())
newValue = if(adjusted!=null && adjusted !== lval) adjusted else lval
} else { } else {
newValue = returnStmt.value!! newValue = returnStmt.value!!
} }

View File

@ -337,7 +337,7 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
if(dt!=DataType.UWORD) { if(dt!=DataType.UWORD) {
val literaladdr = memread.addressExpression as? NumericLiteralValue val literaladdr = memread.addressExpression as? NumericLiteralValue
if(literaladdr!=null) { if(literaladdr!=null) {
memread.addressExpression = literaladdr.cast(DataType.UWORD)!! memread.addressExpression = literaladdr.cast(DataType.UWORD)
} else { } else {
memread.addressExpression = TypecastExpression(memread.addressExpression, DataType.UWORD, true, memread.addressExpression.position) memread.addressExpression = TypecastExpression(memread.addressExpression, DataType.UWORD, true, memread.addressExpression.position)
memread.addressExpression.parent = memread memread.addressExpression.parent = memread
@ -351,7 +351,7 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
if(dt!=DataType.UWORD) { if(dt!=DataType.UWORD) {
val literaladdr = memwrite.addressExpression as? NumericLiteralValue val literaladdr = memwrite.addressExpression as? NumericLiteralValue
if(literaladdr!=null) { if(literaladdr!=null) {
memwrite.addressExpression = literaladdr.cast(DataType.UWORD)!! memwrite.addressExpression = literaladdr.cast(DataType.UWORD)
} else { } else {
memwrite.addressExpression = TypecastExpression(memwrite.addressExpression, DataType.UWORD, true, memwrite.addressExpression.position) memwrite.addressExpression = TypecastExpression(memwrite.addressExpression, DataType.UWORD, true, memwrite.addressExpression.position)
memwrite.addressExpression.parent = memwrite memwrite.addressExpression.parent = memwrite

View File

@ -57,10 +57,8 @@ internal class VarInitValueAndAddressOfCreator(private val program: Program): IA
addVarDecl(scope, decl.asDefaultValueDecl(null)) addVarDecl(scope, decl.asDefaultValueDecl(null))
val declvalue = decl.value!! val declvalue = decl.value!!
val value = val value =
if(declvalue is NumericLiteralValue) { if(declvalue is NumericLiteralValue)
val converted = declvalue.cast(decl.datatype) declvalue.cast(decl.datatype)
converted ?: declvalue
}
else else
declvalue declvalue
val identifierName = listOf(decl.name) // this was: (scoped name) decl.scopedname.split(".") val identifierName = listOf(decl.name) // this was: (scoped name) decl.scopedname.split(".")

View File

@ -94,7 +94,7 @@ fun compileProgram(filepath: Path,
programAst.checkValid(compilerOptions) // check if final tree is valid programAst.checkValid(compilerOptions) // check if final tree is valid
programAst.checkRecursion() // check if there are recursive subroutine calls programAst.checkRecursion() // check if there are recursive subroutine calls
printAst(programAst) // printAst(programAst)
if(writeAssembly) { if(writeAssembly) {
// asm generation directly from the Ast, no need for intermediate code // asm generation directly from the Ast, no need for intermediate code

View File

@ -81,17 +81,31 @@ internal class BuiltinFunctionsAsmGen(private val program: Program,
asmgen.assignFromEvalResult(secondTarget) asmgen.assignFromEvalResult(secondTarget)
} }
"strlen" -> { "strlen" -> {
val identifierName = asmgen.asmIdentifierName(fcall.arglist[0] as IdentifierReference) outputPushAddressOfIdentifier(fcall.arglist[0])
asmgen.out(""" asmgen.out(" jsr prog8_lib.func_strlen")
lda #<$identifierName }
sta $ESTACK_LO_HEX,x "min", "max", "sum" -> {
lda #>$identifierName outputPushAddressAndLenghtOfArray(fcall.arglist[0])
sta $ESTACK_HI_HEX,x val dt = fcall.arglist.single().inferType(program)!!
dex when(dt) {
jsr prog8_lib.func_strlen DataType.ARRAY_UB, DataType.STR_S, DataType.STR -> asmgen.out(" jsr prog8_lib.func_${functionName}_ub")
""") DataType.ARRAY_B -> asmgen.out(" jsr prog8_lib.func_${functionName}_b")
DataType.ARRAY_UW -> asmgen.out(" jsr prog8_lib.func_${functionName}_uw")
DataType.ARRAY_W -> asmgen.out(" jsr prog8_lib.func_${functionName}_w")
DataType.ARRAY_F -> asmgen.out(" jsr c64flt.func_${functionName}_f")
else -> throw AssemblyError("weird type $dt")
}
}
"any", "all" -> {
outputPushAddressAndLenghtOfArray(fcall.arglist[0])
val dt = fcall.arglist.single().inferType(program)!!
when(dt) {
DataType.ARRAY_B, DataType.ARRAY_UB, DataType.STR_S, DataType.STR -> asmgen.out(" jsr prog8_lib.func_${functionName}_b")
DataType.ARRAY_UW, DataType.ARRAY_W -> asmgen.out(" jsr prog8_lib.func_${functionName}_w")
DataType.ARRAY_F -> asmgen.out(" jsr c64flt.func_${functionName}_f")
else -> throw AssemblyError("weird type $dt")
}
} }
// TODO: any(f), all(f), max(f), min(f), sum(f), avg(f)
"sin", "cos", "tan", "atan", "sin", "cos", "tan", "atan",
"ln", "log2", "sqrt", "rad", "ln", "log2", "sqrt", "rad",
"deg", "round", "floor", "ceil", "deg", "round", "floor", "ceil",
@ -254,6 +268,33 @@ internal class BuiltinFunctionsAsmGen(private val program: Program,
} }
} }
private fun outputPushAddressAndLenghtOfArray(arg: Expression) {
arg as IdentifierReference
val identifierName = asmgen.asmIdentifierName(arg)
val size = arg.targetVarDecl(program.namespace)!!.arraysize!!.size()!!
asmgen.out("""
lda #<$identifierName
sta $ESTACK_LO_HEX,x
lda #>$identifierName
sta $ESTACK_HI_HEX,x
dex
lda #$size
sta $ESTACK_LO_HEX,x
dex
""")
}
private fun outputPushAddressOfIdentifier(arg: Expression) {
val identifierName = asmgen.asmIdentifierName(arg as IdentifierReference)
asmgen.out("""
lda #<$identifierName
sta $ESTACK_LO_HEX,x
lda #>$identifierName
sta $ESTACK_HI_HEX,x
dex
""")
}
private fun translateFunctionArguments(args: MutableList<Expression>, signature: FunctionSignature) { private fun translateFunctionArguments(args: MutableList<Expression>, signature: FunctionSignature) {
args.forEach { args.forEach {
asmgen.translateExpression(it) asmgen.translateExpression(it)

View File

@ -52,7 +52,6 @@ val BuiltinFunctions = mapOf(
"sqrt" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArg(a, p, prg, Math::sqrt) }, "sqrt" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArg(a, p, prg, Math::sqrt) },
"rad" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArg(a, p, prg, Math::toRadians) }, "rad" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArg(a, p, prg, Math::toRadians) },
"deg" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArg(a, p, prg, Math::toDegrees) }, "deg" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArg(a, p, prg, Math::toDegrees) },
"avg" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), DataType.FLOAT) { a, p, _ -> collectionArgNeverConst(a, p) },
"round" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArgOutputWord(a, p, prg, Math::round) }, "round" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArgOutputWord(a, p, prg, Math::round) },
"floor" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArgOutputWord(a, p, prg, Math::floor) }, "floor" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArgOutputWord(a, p, prg, Math::floor) },
"ceil" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArgOutputWord(a, p, prg, Math::ceil) }, "ceil" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArgOutputWord(a, p, prg, Math::ceil) },

View File

@ -9,6 +9,7 @@ import prog8.ast.processing.fixupArrayDatatype
import prog8.ast.statements.* import prog8.ast.statements.*
import prog8.compiler.target.c64.MachineDefinition.FLOAT_MAX_NEGATIVE import prog8.compiler.target.c64.MachineDefinition.FLOAT_MAX_NEGATIVE
import prog8.compiler.target.c64.MachineDefinition.FLOAT_MAX_POSITIVE import prog8.compiler.target.c64.MachineDefinition.FLOAT_MAX_POSITIVE
import prog8.compiler.target.c64.codegen2.AssemblyError
import prog8.functions.BuiltinFunctions import prog8.functions.BuiltinFunctions
import kotlin.math.floor import kotlin.math.floor
@ -174,7 +175,7 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
copy.parent = identifier.parent copy.parent = identifier.parent
copy copy
} }
cval.type in PassByReferenceDatatypes -> TODO("ref type $identifier") cval.type in PassByReferenceDatatypes -> throw AssemblyError("pass-by-reference type should not be considered a constant")
else -> identifier else -> identifier
} }
} catch (ax: AstException) { } catch (ax: AstException) {
@ -209,11 +210,9 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
val possibleDts = arg.second.possibleDatatypes val possibleDts = arg.second.possibleDatatypes
val argConst = arg.first.value.constValue(program) val argConst = arg.first.value.constValue(program)
if(argConst!=null && argConst.type !in possibleDts) { if(argConst!=null && argConst.type !in possibleDts) {
val convertedValue = argConst.cast(possibleDts.first()) val convertedValue = argConst.cast(possibleDts.first()) // TODO can throw exception
if(convertedValue!=null) { functionCall.arglist[arg.first.index] = convertedValue
functionCall.arglist[arg.first.index] = convertedValue optimizationsDone++
optimizationsDone++
}
} }
} }
return return
@ -227,11 +226,9 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
val expectedDt = arg.second.type val expectedDt = arg.second.type
val argConst = arg.first.value.constValue(program) val argConst = arg.first.value.constValue(program)
if(argConst!=null && argConst.type!=expectedDt) { if(argConst!=null && argConst.type!=expectedDt) {
val convertedValue = argConst.cast(expectedDt) val convertedValue = argConst.cast(expectedDt) // TODO can throw exception
if(convertedValue!=null) { functionCall.arglist[arg.first.index] = convertedValue
functionCall.arglist[arg.first.index] = convertedValue optimizationsDone++
optimizationsDone++
}
} }
} }
} }
@ -315,7 +312,7 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
super.visit(expr) super.visit(expr)
if(expr.left is ReferenceLiteralValue || expr.right is ReferenceLiteralValue) if(expr.left is ReferenceLiteralValue || expr.right is ReferenceLiteralValue)
TODO("binexpr with reference litval") throw FatalAstException("binexpr with reference litval instead of numeric")
val leftconst = expr.left.constValue(program) val leftconst = expr.left.constValue(program)
val rightconst = expr.right.constValue(program) val rightconst = expr.right.constValue(program)
@ -547,14 +544,12 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
override fun visit(forLoop: ForLoop): Statement { override fun visit(forLoop: ForLoop): Statement {
fun adjustRangeDt(rangeFrom: NumericLiteralValue, targetDt: DataType, rangeTo: NumericLiteralValue, stepLiteral: NumericLiteralValue?, range: RangeExpr): RangeExpr { fun adjustRangeDt(rangeFrom: NumericLiteralValue, targetDt: DataType, rangeTo: NumericLiteralValue, stepLiteral: NumericLiteralValue?, range: RangeExpr): RangeExpr {
// TODO casts can throw exception
val newFrom = rangeFrom.cast(targetDt) val newFrom = rangeFrom.cast(targetDt)
val newTo = rangeTo.cast(targetDt) val newTo = rangeTo.cast(targetDt)
if (newFrom != null && newTo != null) { val newStep: Expression =
val newStep: Expression = stepLiteral?.cast(targetDt) ?: range.step
if (stepLiteral != null) (stepLiteral.cast(targetDt) ?: stepLiteral) else range.step return RangeExpr(newFrom, newTo, newStep, range.position)
return RangeExpr(newFrom, newTo, newStep, range.position)
}
return range
} }
// adjust the datatype of a range expression in for loops to the loop variable. // adjust the datatype of a range expression in for loops to the loop variable.

View File

@ -43,7 +43,7 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
val literal = tc.expression as? NumericLiteralValue val literal = tc.expression as? NumericLiteralValue
if(literal!=null) { if(literal!=null) {
val newLiteral = literal.cast(tc.type) val newLiteral = literal.cast(tc.type)
if(newLiteral!=null && newLiteral!==literal) { if(newLiteral!==literal) {
optimizationsDone++ optimizationsDone++
return newLiteral return newLiteral
} }

View File

@ -845,10 +845,6 @@ class AstVm(val program: Program) {
val numbers = args.single().array!!.map { it.toDouble() } val numbers = args.single().array!!.map { it.toDouble() }
RuntimeValue(ArrayElementTypes.getValue(args[0].type), numbers.min()) RuntimeValue(ArrayElementTypes.getValue(args[0].type), numbers.min())
} }
"avg" -> {
val numbers = args.single().array!!.map { it.toDouble() }
RuntimeValue(DataType.FLOAT, numbers.average())
}
"sum" -> { "sum" -> {
val sum = args.single().array!!.map { it.toDouble() }.sum() val sum = args.single().array!!.map { it.toDouble() }.sum()
when (args[0].type) { when (args[0].type) {

View File

@ -707,9 +707,6 @@ max(x)
min(x) min(x)
Minimum of the values in the array value x Minimum of the values in the array value x
avg(x)
Average of the values in the array value x
sum(x) sum(x)
Sum of the values in the array value x Sum of the values in the array value x

View File

@ -0,0 +1,110 @@
%import c64lib
%import c64utils
%import c64flt
%zeropage dontuse
main {
sub start() {
ubyte[] ubarr = [100, 0, 99, 199, 22]
byte[] barr = [-100, 0, 99, -122, 22]
uword[] uwarr = [1000, 0, 222, 4444, 999]
word[] warr = [-1000, 0, 999, -4444, 222]
float[] farr = [-1000.1, 0, 999.9, -4444.4, 222.2]
str name = "irmen"
ubyte ub
byte bb
word ww
uword uw
float ff
; LEN/STRLEN
ubyte length = len(name)
if length!=5 c64scr.print("error len1\n")
length = len(uwarr)
if length!=5 c64scr.print("error len2\n")
length=strlen(name)
if length!=5 c64scr.print("error strlen1\n")
name[3] = 0
length=strlen(name)
if length!=3 c64scr.print("error strlen2\n")
; MAX
ub = max(ubarr)
if ub!=199 c64scr.print("error max1\n")
bb = max(barr)
if bb!=99 c64scr.print("error max2\n")
uw = max(uwarr)
if uw!=4444 c64scr.print("error max3\n")
ww = max(warr)
if ww!=999 c64scr.print("error max4\n")
ff = max(farr)
if ff!=999.9 c64scr.print("error max5\n")
; MIN
ub = min(ubarr)
if ub!=0 c64scr.print("error min1\n")
bb = min(barr)
if bb!=-122 c64scr.print("error min2\n")
uw = min(uwarr)
if uw!=0 c64scr.print("error min3\n")
ww = min(warr)
if ww!=-4444 c64scr.print("error min4\n")
ff = min(farr)
if ff!=-4444.4 c64scr.print("error min5\n")
; SUM
uw = sum(ubarr)
if uw!=420 c64scr.print("error sum1\n")
ww = sum(barr)
if ww!=-101 c64scr.print("error sum2\n")
uw = sum(uwarr)
if uw!=6665 c64scr.print("error sum3\n")
ww = sum(warr)
if ww!=-4223 c64scr.print("error sum4\n")
ff = sum(farr)
if ff!=-4222.4 c64scr.print("error sum5\n")
; ANY
ub = any(ubarr)
if ub==0 c64scr.print("error any1\n")
ub = any(barr)
if ub==0 c64scr.print("error any2\n")
ub = any(uwarr)
if ub==0 c64scr.print("error any3\n")
ub = any(warr)
if ub==0 c64scr.print("error any4\n")
ub = any(farr)
if ub==0 c64scr.print("error any5\n")
; ALL
ub = all(ubarr)
if ub==1 c64scr.print("error all1\n")
ub = all(barr)
if ub==1 c64scr.print("error all2\n")
ub = all(uwarr)
if ub==1 c64scr.print("error all3\n")
ub = all(warr)
if ub==1 c64scr.print("error all4\n")
ub = all(farr)
if ub==1 c64scr.print("error all5\n")
ubarr[1]=$40
barr[1]=$40
uwarr[1]=$4000
warr[1]=$4000
farr[1]=1.1
ub = all(ubarr)
if ub==0 c64scr.print("error all6\n")
ub = all(barr)
if ub==0 c64scr.print("error all7\n")
ub = all(uwarr)
if ub==0 c64scr.print("error all8\n")
ub = all(warr)
if ub==0 c64scr.print("error all9\n")
ub = all(farr)
if ub==0 c64scr.print("error all10\n")
c64scr.print("\nyou should see no errors above.")
}
}