fix any/all assembly routine, added asm for min/max/sum/ etc aggregates

removed avg function because of hidden internal overflow issues
This commit is contained in:
Irmen de Jong 2019-08-11 16:01:37 +02:00
parent 2ce6bc5946
commit b44e76db57
13 changed files with 186 additions and 50 deletions

View File

@ -35,7 +35,7 @@ init_system .proc
rts
.pend
read_byte_from_address .proc
; -- read the byte from the memory address on the top of the stack, return in A (stack remains unchanged)
lda c64.ESTACK_LO+1,x
@ -45,7 +45,7 @@ read_byte_from_address .proc
+ lda $ffff ; modified
rts
.pend
add_a_to_zpword .proc
; -- add ubyte in A to the uword in c64.SCRATCH_ZPWORD1
@ -851,11 +851,12 @@ func_all_w .proc
bne +
iny
lda (c64.SCRATCH_ZPWORD1),y
bne +
bne ++
lda #0
sta c64.ESTACK_LO+1,x
rts
+ iny
+ iny
_cmp_mod cpy #255 ; modified
bne -
lda #1

View File

@ -500,7 +500,7 @@ class ReferenceLiteralValue(val type: DataType, // only reference types allo
throw FatalAstException("weird array element $it")
it
} else {
num.cast(elementType)!!
num.cast(elementType) // TODO this can throw an exception
}
}.toTypedArray()
return ReferenceLiteralValue(targettype, null, array=castArray, position = position)

View File

@ -237,8 +237,7 @@ internal class AstIdentifiersChecker(private val program: Program) : IAstModifyi
val newValue: Expression
val lval = returnStmt.value as? NumericLiteralValue
if(lval!=null) {
val adjusted = lval.cast(subroutine.returntypes.single())
newValue = if(adjusted!=null && adjusted !== lval) adjusted else lval
newValue = lval.cast(subroutine.returntypes.single())
} else {
newValue = returnStmt.value!!
}

View File

@ -337,7 +337,7 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
if(dt!=DataType.UWORD) {
val literaladdr = memread.addressExpression as? NumericLiteralValue
if(literaladdr!=null) {
memread.addressExpression = literaladdr.cast(DataType.UWORD)!!
memread.addressExpression = literaladdr.cast(DataType.UWORD)
} else {
memread.addressExpression = TypecastExpression(memread.addressExpression, DataType.UWORD, true, memread.addressExpression.position)
memread.addressExpression.parent = memread
@ -351,7 +351,7 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
if(dt!=DataType.UWORD) {
val literaladdr = memwrite.addressExpression as? NumericLiteralValue
if(literaladdr!=null) {
memwrite.addressExpression = literaladdr.cast(DataType.UWORD)!!
memwrite.addressExpression = literaladdr.cast(DataType.UWORD)
} else {
memwrite.addressExpression = TypecastExpression(memwrite.addressExpression, DataType.UWORD, true, memwrite.addressExpression.position)
memwrite.addressExpression.parent = memwrite

View File

@ -57,10 +57,8 @@ internal class VarInitValueAndAddressOfCreator(private val program: Program): IA
addVarDecl(scope, decl.asDefaultValueDecl(null))
val declvalue = decl.value!!
val value =
if(declvalue is NumericLiteralValue) {
val converted = declvalue.cast(decl.datatype)
converted ?: declvalue
}
if(declvalue is NumericLiteralValue)
declvalue.cast(decl.datatype)
else
declvalue
val identifierName = listOf(decl.name) // this was: (scoped name) decl.scopedname.split(".")

View File

@ -94,7 +94,7 @@ fun compileProgram(filepath: Path,
programAst.checkValid(compilerOptions) // check if final tree is valid
programAst.checkRecursion() // check if there are recursive subroutine calls
printAst(programAst)
// printAst(programAst)
if(writeAssembly) {
// asm generation directly from the Ast, no need for intermediate code

View File

@ -81,17 +81,31 @@ internal class BuiltinFunctionsAsmGen(private val program: Program,
asmgen.assignFromEvalResult(secondTarget)
}
"strlen" -> {
val identifierName = asmgen.asmIdentifierName(fcall.arglist[0] as IdentifierReference)
asmgen.out("""
lda #<$identifierName
sta $ESTACK_LO_HEX,x
lda #>$identifierName
sta $ESTACK_HI_HEX,x
dex
jsr prog8_lib.func_strlen
""")
outputPushAddressOfIdentifier(fcall.arglist[0])
asmgen.out(" jsr prog8_lib.func_strlen")
}
"min", "max", "sum" -> {
outputPushAddressAndLenghtOfArray(fcall.arglist[0])
val dt = fcall.arglist.single().inferType(program)!!
when(dt) {
DataType.ARRAY_UB, DataType.STR_S, DataType.STR -> asmgen.out(" jsr prog8_lib.func_${functionName}_ub")
DataType.ARRAY_B -> asmgen.out(" jsr prog8_lib.func_${functionName}_b")
DataType.ARRAY_UW -> asmgen.out(" jsr prog8_lib.func_${functionName}_uw")
DataType.ARRAY_W -> asmgen.out(" jsr prog8_lib.func_${functionName}_w")
DataType.ARRAY_F -> asmgen.out(" jsr c64flt.func_${functionName}_f")
else -> throw AssemblyError("weird type $dt")
}
}
"any", "all" -> {
outputPushAddressAndLenghtOfArray(fcall.arglist[0])
val dt = fcall.arglist.single().inferType(program)!!
when(dt) {
DataType.ARRAY_B, DataType.ARRAY_UB, DataType.STR_S, DataType.STR -> asmgen.out(" jsr prog8_lib.func_${functionName}_b")
DataType.ARRAY_UW, DataType.ARRAY_W -> asmgen.out(" jsr prog8_lib.func_${functionName}_w")
DataType.ARRAY_F -> asmgen.out(" jsr c64flt.func_${functionName}_f")
else -> throw AssemblyError("weird type $dt")
}
}
// TODO: any(f), all(f), max(f), min(f), sum(f), avg(f)
"sin", "cos", "tan", "atan",
"ln", "log2", "sqrt", "rad",
"deg", "round", "floor", "ceil",
@ -254,6 +268,33 @@ internal class BuiltinFunctionsAsmGen(private val program: Program,
}
}
private fun outputPushAddressAndLenghtOfArray(arg: Expression) {
arg as IdentifierReference
val identifierName = asmgen.asmIdentifierName(arg)
val size = arg.targetVarDecl(program.namespace)!!.arraysize!!.size()!!
asmgen.out("""
lda #<$identifierName
sta $ESTACK_LO_HEX,x
lda #>$identifierName
sta $ESTACK_HI_HEX,x
dex
lda #$size
sta $ESTACK_LO_HEX,x
dex
""")
}
private fun outputPushAddressOfIdentifier(arg: Expression) {
val identifierName = asmgen.asmIdentifierName(arg as IdentifierReference)
asmgen.out("""
lda #<$identifierName
sta $ESTACK_LO_HEX,x
lda #>$identifierName
sta $ESTACK_HI_HEX,x
dex
""")
}
private fun translateFunctionArguments(args: MutableList<Expression>, signature: FunctionSignature) {
args.forEach {
asmgen.translateExpression(it)

View File

@ -52,7 +52,6 @@ val BuiltinFunctions = mapOf(
"sqrt" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArg(a, p, prg, Math::sqrt) },
"rad" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArg(a, p, prg, Math::toRadians) },
"deg" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArg(a, p, prg, Math::toDegrees) },
"avg" to FunctionSignature(true, listOf(BuiltinFunctionParam("values", ArrayDatatypes)), DataType.FLOAT) { a, p, _ -> collectionArgNeverConst(a, p) },
"round" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArgOutputWord(a, p, prg, Math::round) },
"floor" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArgOutputWord(a, p, prg, Math::floor) },
"ceil" to FunctionSignature(true, listOf(BuiltinFunctionParam("value", setOf(DataType.FLOAT))), DataType.FLOAT) { a, p, prg -> oneDoubleArgOutputWord(a, p, prg, Math::ceil) },

View File

@ -9,6 +9,7 @@ import prog8.ast.processing.fixupArrayDatatype
import prog8.ast.statements.*
import prog8.compiler.target.c64.MachineDefinition.FLOAT_MAX_NEGATIVE
import prog8.compiler.target.c64.MachineDefinition.FLOAT_MAX_POSITIVE
import prog8.compiler.target.c64.codegen2.AssemblyError
import prog8.functions.BuiltinFunctions
import kotlin.math.floor
@ -174,7 +175,7 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
copy.parent = identifier.parent
copy
}
cval.type in PassByReferenceDatatypes -> TODO("ref type $identifier")
cval.type in PassByReferenceDatatypes -> throw AssemblyError("pass-by-reference type should not be considered a constant")
else -> identifier
}
} catch (ax: AstException) {
@ -209,11 +210,9 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
val possibleDts = arg.second.possibleDatatypes
val argConst = arg.first.value.constValue(program)
if(argConst!=null && argConst.type !in possibleDts) {
val convertedValue = argConst.cast(possibleDts.first())
if(convertedValue!=null) {
functionCall.arglist[arg.first.index] = convertedValue
optimizationsDone++
}
val convertedValue = argConst.cast(possibleDts.first()) // TODO can throw exception
functionCall.arglist[arg.first.index] = convertedValue
optimizationsDone++
}
}
return
@ -227,11 +226,9 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
val expectedDt = arg.second.type
val argConst = arg.first.value.constValue(program)
if(argConst!=null && argConst.type!=expectedDt) {
val convertedValue = argConst.cast(expectedDt)
if(convertedValue!=null) {
functionCall.arglist[arg.first.index] = convertedValue
optimizationsDone++
}
val convertedValue = argConst.cast(expectedDt) // TODO can throw exception
functionCall.arglist[arg.first.index] = convertedValue
optimizationsDone++
}
}
}
@ -315,7 +312,7 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
super.visit(expr)
if(expr.left is ReferenceLiteralValue || expr.right is ReferenceLiteralValue)
TODO("binexpr with reference litval")
throw FatalAstException("binexpr with reference litval instead of numeric")
val leftconst = expr.left.constValue(program)
val rightconst = expr.right.constValue(program)
@ -547,14 +544,12 @@ class ConstantFolding(private val program: Program) : IAstModifyingVisitor {
override fun visit(forLoop: ForLoop): Statement {
fun adjustRangeDt(rangeFrom: NumericLiteralValue, targetDt: DataType, rangeTo: NumericLiteralValue, stepLiteral: NumericLiteralValue?, range: RangeExpr): RangeExpr {
// TODO casts can throw exception
val newFrom = rangeFrom.cast(targetDt)
val newTo = rangeTo.cast(targetDt)
if (newFrom != null && newTo != null) {
val newStep: Expression =
if (stepLiteral != null) (stepLiteral.cast(targetDt) ?: stepLiteral) else range.step
return RangeExpr(newFrom, newTo, newStep, range.position)
}
return range
val newStep: Expression =
stepLiteral?.cast(targetDt) ?: range.step
return RangeExpr(newFrom, newTo, newStep, range.position)
}
// adjust the datatype of a range expression in for loops to the loop variable.

View File

@ -43,7 +43,7 @@ internal class SimplifyExpressions(private val program: Program) : IAstModifying
val literal = tc.expression as? NumericLiteralValue
if(literal!=null) {
val newLiteral = literal.cast(tc.type)
if(newLiteral!=null && newLiteral!==literal) {
if(newLiteral!==literal) {
optimizationsDone++
return newLiteral
}

View File

@ -845,10 +845,6 @@ class AstVm(val program: Program) {
val numbers = args.single().array!!.map { it.toDouble() }
RuntimeValue(ArrayElementTypes.getValue(args[0].type), numbers.min())
}
"avg" -> {
val numbers = args.single().array!!.map { it.toDouble() }
RuntimeValue(DataType.FLOAT, numbers.average())
}
"sum" -> {
val sum = args.single().array!!.map { it.toDouble() }.sum()
when (args[0].type) {

View File

@ -707,9 +707,6 @@ max(x)
min(x)
Minimum of the values in the array value x
avg(x)
Average of the values in the array value x
sum(x)
Sum of the values in the array value x

View File

@ -0,0 +1,110 @@
%import c64lib
%import c64utils
%import c64flt
%zeropage dontuse
main {
sub start() {
ubyte[] ubarr = [100, 0, 99, 199, 22]
byte[] barr = [-100, 0, 99, -122, 22]
uword[] uwarr = [1000, 0, 222, 4444, 999]
word[] warr = [-1000, 0, 999, -4444, 222]
float[] farr = [-1000.1, 0, 999.9, -4444.4, 222.2]
str name = "irmen"
ubyte ub
byte bb
word ww
uword uw
float ff
; LEN/STRLEN
ubyte length = len(name)
if length!=5 c64scr.print("error len1\n")
length = len(uwarr)
if length!=5 c64scr.print("error len2\n")
length=strlen(name)
if length!=5 c64scr.print("error strlen1\n")
name[3] = 0
length=strlen(name)
if length!=3 c64scr.print("error strlen2\n")
; MAX
ub = max(ubarr)
if ub!=199 c64scr.print("error max1\n")
bb = max(barr)
if bb!=99 c64scr.print("error max2\n")
uw = max(uwarr)
if uw!=4444 c64scr.print("error max3\n")
ww = max(warr)
if ww!=999 c64scr.print("error max4\n")
ff = max(farr)
if ff!=999.9 c64scr.print("error max5\n")
; MIN
ub = min(ubarr)
if ub!=0 c64scr.print("error min1\n")
bb = min(barr)
if bb!=-122 c64scr.print("error min2\n")
uw = min(uwarr)
if uw!=0 c64scr.print("error min3\n")
ww = min(warr)
if ww!=-4444 c64scr.print("error min4\n")
ff = min(farr)
if ff!=-4444.4 c64scr.print("error min5\n")
; SUM
uw = sum(ubarr)
if uw!=420 c64scr.print("error sum1\n")
ww = sum(barr)
if ww!=-101 c64scr.print("error sum2\n")
uw = sum(uwarr)
if uw!=6665 c64scr.print("error sum3\n")
ww = sum(warr)
if ww!=-4223 c64scr.print("error sum4\n")
ff = sum(farr)
if ff!=-4222.4 c64scr.print("error sum5\n")
; ANY
ub = any(ubarr)
if ub==0 c64scr.print("error any1\n")
ub = any(barr)
if ub==0 c64scr.print("error any2\n")
ub = any(uwarr)
if ub==0 c64scr.print("error any3\n")
ub = any(warr)
if ub==0 c64scr.print("error any4\n")
ub = any(farr)
if ub==0 c64scr.print("error any5\n")
; ALL
ub = all(ubarr)
if ub==1 c64scr.print("error all1\n")
ub = all(barr)
if ub==1 c64scr.print("error all2\n")
ub = all(uwarr)
if ub==1 c64scr.print("error all3\n")
ub = all(warr)
if ub==1 c64scr.print("error all4\n")
ub = all(farr)
if ub==1 c64scr.print("error all5\n")
ubarr[1]=$40
barr[1]=$40
uwarr[1]=$4000
warr[1]=$4000
farr[1]=1.1
ub = all(ubarr)
if ub==0 c64scr.print("error all6\n")
ub = all(barr)
if ub==0 c64scr.print("error all7\n")
ub = all(uwarr)
if ub==0 c64scr.print("error all8\n")
ub = all(warr)
if ub==0 c64scr.print("error all9\n")
ub = all(farr)
if ub==0 c64scr.print("error all10\n")
c64scr.print("\nyou should see no errors above.")
}
}