trying to fix arithmetic and funcion calls

This commit is contained in:
Irmen de Jong
2019-08-03 01:51:12 +02:00
parent e9c357a885
commit 4718f09cb7
11 changed files with 163 additions and 185 deletions

View File

@@ -97,14 +97,13 @@ class BinaryExpression(var left: Expression, var operator: String, var right: Ex
val leftDt = left.inferType(program) val leftDt = left.inferType(program)
val rightDt = right.inferType(program) val rightDt = right.inferType(program)
return when (operator) { return when (operator) {
"+", "-", "*", "**", "%" -> if (leftDt == null || rightDt == null) null else { "+", "-", "*", "**", "%", "/" -> if (leftDt == null || rightDt == null) null else {
try { try {
arithmeticOpDt(leftDt, rightDt) commonDatatype(leftDt, rightDt, null, null).first
} catch (x: FatalAstException) { } catch (x: FatalAstException) {
null null
} }
} }
"/" -> if (leftDt == null || rightDt == null) null else divisionOpDt(leftDt, rightDt)
"&" -> leftDt "&" -> leftDt
"|" -> leftDt "|" -> leftDt
"^" -> leftDt "^" -> leftDt
@@ -118,132 +117,61 @@ class BinaryExpression(var left: Expression, var operator: String, var right: Ex
} }
companion object { companion object {
fun divisionOpDt(leftDt: DataType, rightDt: DataType): DataType { fun commonDatatype(leftDt: DataType, rightDt: DataType,
left: Expression?, right: Expression?): Pair<DataType, Expression?> {
// byte + byte -> byte
// byte + word -> word
// word + byte -> word
// word + word -> word
// a combination with a float will be float (but give a warning about this!)
return when (leftDt) { return when (leftDt) {
DataType.UBYTE -> when (rightDt) { DataType.UBYTE -> {
DataType.UBYTE, DataType.UWORD -> DataType.UBYTE when (rightDt) {
DataType.BYTE, DataType.WORD -> DataType.WORD DataType.UBYTE -> Pair(DataType.UBYTE, null)
DataType.FLOAT -> DataType.BYTE DataType.BYTE -> Pair(DataType.BYTE, left)
else -> throw FatalAstException("arithmetic operation on incompatible datatypes: $leftDt and $rightDt") DataType.UWORD -> Pair(DataType.UWORD, left)
DataType.WORD -> Pair(DataType.WORD, left)
DataType.FLOAT -> Pair(DataType.FLOAT, left)
else -> Pair(leftDt, null) // non-numeric datatype
}
} }
DataType.BYTE -> when (rightDt) { DataType.BYTE -> {
in NumericDatatypes -> DataType.BYTE when (rightDt) {
else -> throw FatalAstException("arithmetic operation on incompatible datatypes: $leftDt and $rightDt") DataType.UBYTE -> Pair(DataType.BYTE, right)
DataType.BYTE -> Pair(DataType.BYTE, null)
DataType.UWORD -> Pair(DataType.WORD, left)
DataType.WORD -> Pair(DataType.WORD, left)
DataType.FLOAT -> Pair(DataType.FLOAT, left)
else -> Pair(leftDt, null) // non-numeric datatype
}
} }
DataType.UWORD -> when (rightDt) { DataType.UWORD -> {
DataType.UBYTE, DataType.UWORD -> DataType.UWORD when (rightDt) {
DataType.BYTE, DataType.WORD -> DataType.WORD DataType.UBYTE -> Pair(DataType.UWORD, right)
DataType.FLOAT -> DataType.FLOAT DataType.BYTE -> Pair(DataType.WORD, right)
else -> throw FatalAstException("arithmetic operation on incompatible datatypes: $leftDt and $rightDt") DataType.UWORD -> Pair(DataType.UWORD, null)
DataType.WORD -> Pair(DataType.WORD, left)
DataType.FLOAT -> Pair(DataType.FLOAT, left)
else -> Pair(leftDt, null) // non-numeric datatype
}
} }
DataType.WORD -> when (rightDt) { DataType.WORD -> {
in NumericDatatypes -> DataType.WORD when (rightDt) {
else -> throw FatalAstException("arithmetic operation on incompatible datatypes: $leftDt and $rightDt") DataType.UBYTE -> Pair(DataType.WORD, right)
DataType.BYTE -> Pair(DataType.WORD, right)
DataType.UWORD -> Pair(DataType.WORD, right)
DataType.WORD -> Pair(DataType.WORD, null)
DataType.FLOAT -> Pair(DataType.FLOAT, left)
else -> Pair(leftDt, null) // non-numeric datatype
}
} }
DataType.FLOAT -> when (rightDt) { DataType.FLOAT -> {
in NumericDatatypes -> DataType.FLOAT Pair(DataType.FLOAT, right)
else -> throw FatalAstException("arithmetic operation on incompatible datatypes: $leftDt and $rightDt")
} }
else -> throw FatalAstException("arithmetic operation on incompatible datatypes: $leftDt and $rightDt") else -> Pair(leftDt, null) // non-numeric datatype
} }
} }
fun arithmeticOpDt(leftDt: DataType, rightDt: DataType): DataType {
return when (leftDt) {
DataType.UBYTE -> when (rightDt) {
DataType.UBYTE -> DataType.UBYTE
DataType.BYTE -> DataType.BYTE
DataType.UWORD -> DataType.UWORD
DataType.WORD -> DataType.WORD
DataType.FLOAT -> DataType.FLOAT
else -> throw FatalAstException("arithmetic operation on incompatible datatypes: $leftDt and $rightDt")
}
DataType.BYTE -> when (rightDt) {
in ByteDatatypes -> DataType.BYTE
in WordDatatypes -> DataType.WORD
DataType.FLOAT -> DataType.FLOAT
else -> throw FatalAstException("arithmetic operation on incompatible datatypes: $leftDt and $rightDt")
}
DataType.UWORD -> when (rightDt) {
DataType.UBYTE, DataType.UWORD -> DataType.UWORD
DataType.BYTE, DataType.WORD -> DataType.WORD
DataType.FLOAT -> DataType.FLOAT
else -> throw FatalAstException("arithmetic operation on incompatible datatypes: $leftDt and $rightDt")
}
DataType.WORD -> when (rightDt) {
in IntegerDatatypes -> DataType.WORD
DataType.FLOAT -> DataType.FLOAT
else -> throw FatalAstException("arithmetic operation on incompatible datatypes: $leftDt and $rightDt")
}
DataType.FLOAT -> when (rightDt) {
in NumericDatatypes -> DataType.FLOAT
else -> throw FatalAstException("arithmetic operation on incompatible datatypes: $leftDt and $rightDt")
}
else -> throw FatalAstException("arithmetic operation on incompatible datatypes: $leftDt and $rightDt")
}
}
}
fun commonDatatype(leftDt: DataType, rightDt: DataType,
left: Expression, right: Expression): Pair<DataType, Expression?> {
// byte + byte -> byte
// byte + word -> word
// word + byte -> word
// word + word -> word
// a combination with a float will be float (but give a warning about this!)
if(this.operator=="/") {
// division is a bit weird, don't cast the operands
val commondt = divisionOpDt(leftDt, rightDt)
return Pair(commondt, null)
}
return when (leftDt) {
DataType.UBYTE -> {
when (rightDt) {
DataType.UBYTE -> Pair(DataType.UBYTE, null)
DataType.BYTE -> Pair(DataType.BYTE, left)
DataType.UWORD -> Pair(DataType.UWORD, left)
DataType.WORD -> Pair(DataType.WORD, left)
DataType.FLOAT -> Pair(DataType.FLOAT, left)
else -> Pair(leftDt, null) // non-numeric datatype
}
}
DataType.BYTE -> {
when (rightDt) {
DataType.UBYTE -> Pair(DataType.BYTE, right)
DataType.BYTE -> Pair(DataType.BYTE, null)
DataType.UWORD -> Pair(DataType.WORD, left)
DataType.WORD -> Pair(DataType.WORD, left)
DataType.FLOAT -> Pair(DataType.FLOAT, left)
else -> Pair(leftDt, null) // non-numeric datatype
}
}
DataType.UWORD -> {
when (rightDt) {
DataType.UBYTE -> Pair(DataType.UWORD, right)
DataType.BYTE -> Pair(DataType.UWORD, right)
DataType.UWORD -> Pair(DataType.UWORD, null)
DataType.WORD -> Pair(DataType.WORD, left)
DataType.FLOAT -> Pair(DataType.FLOAT, left)
else -> Pair(leftDt, null) // non-numeric datatype
}
}
DataType.WORD -> {
when (rightDt) {
DataType.UBYTE -> Pair(DataType.WORD, right)
DataType.BYTE -> Pair(DataType.WORD, right)
DataType.UWORD -> Pair(DataType.WORD, right)
DataType.WORD -> Pair(DataType.WORD, null)
DataType.FLOAT -> Pair(DataType.FLOAT, left)
else -> Pair(leftDt, null) // non-numeric datatype
}
}
DataType.FLOAT -> {
Pair(DataType.FLOAT, right)
}
else -> Pair(leftDt, null) // non-numeric datatype
}
} }
} }

View File

@@ -762,6 +762,8 @@ internal class AstChecker(private val program: Program,
checkResult.add(ExpressionError("left operand is not numeric", expr.left.position)) checkResult.add(ExpressionError("left operand is not numeric", expr.left.position))
if(rightDt!in NumericDatatypes) if(rightDt!in NumericDatatypes)
checkResult.add(ExpressionError("right operand is not numeric", expr.right.position)) checkResult.add(ExpressionError("right operand is not numeric", expr.right.position))
if(leftDt!=rightDt)
checkResult.add(ExpressionError("left and right operands aren't the same type", expr.left.position))
super.visit(expr) super.visit(expr)
} }

View File

@@ -187,26 +187,29 @@ internal class StatementReorderer(private val program: Program): IAstModifyingVi
} }
override fun visit(expr: BinaryExpression): Expression { override fun visit(expr: BinaryExpression): Expression {
val leftDt = expr.left.inferType(program) val expr2 = super.visit(expr)
val rightDt = expr.right.inferType(program) if(expr2 !is BinaryExpression)
return expr2
val leftDt = expr2.left.inferType(program)
val rightDt = expr2.right.inferType(program)
if(leftDt!=null && rightDt!=null && leftDt!=rightDt) { if(leftDt!=null && rightDt!=null && leftDt!=rightDt) {
// determine common datatype and add typecast as required to make left and right equal types // determine common datatype and add typecast as required to make left and right equal types
val (commonDt, toFix) = expr.commonDatatype(leftDt, rightDt, expr.left, expr.right) val (commonDt, toFix) = BinaryExpression.commonDatatype(leftDt, rightDt, expr2.left, expr2.right)
if(toFix!=null) { if(toFix!=null) {
when { when {
toFix===expr.left -> { toFix===expr2.left -> {
expr.left = TypecastExpression(expr.left, commonDt, true, expr.left.position) expr2.left = TypecastExpression(expr2.left, commonDt, true, expr2.left.position)
expr.left.linkParents(expr) expr2.left.linkParents(expr2)
} }
toFix===expr.right -> { toFix===expr2.right -> {
expr.right = TypecastExpression(expr.right, commonDt, true, expr.right.position) expr2.right = TypecastExpression(expr2.right, commonDt, true, expr2.right.position)
expr.right.linkParents(expr) expr2.right.linkParents(expr2)
} }
else -> throw FatalAstException("confused binary expression side") else -> throw FatalAstException("confused binary expression side")
} }
} }
} }
return super.visit(expr) return expr2
} }
override fun visit(assignment: Assignment): Statement { override fun visit(assignment: Assignment): Statement {

View File

@@ -8,6 +8,7 @@ import prog8.ast.processing.IAstModifyingVisitor
import prog8.ast.statements.AnonymousScope import prog8.ast.statements.AnonymousScope
import prog8.ast.statements.Statement import prog8.ast.statements.Statement
import prog8.ast.statements.VarDecl import prog8.ast.statements.VarDecl
import java.util.*
class AnonymousScopeVarsCleanup(val program: Program): IAstModifyingVisitor { class AnonymousScopeVarsCleanup(val program: Program): IAstModifyingVisitor {
companion object { companion object {
@@ -32,8 +33,12 @@ class AnonymousScopeVarsCleanup(val program: Program): IAstModifyingVisitor {
private val varsToMove: MutableMap<AnonymousScope, List<VarDecl>> = mutableMapOf() private val varsToMove: MutableMap<AnonymousScope, List<VarDecl>> = mutableMapOf()
private val currentAnonScope: Stack<AnonymousScope> = Stack()
override fun visit(scope: AnonymousScope): Statement { override fun visit(scope: AnonymousScope): Statement {
currentAnonScope.push(scope)
val scope2 = super.visit(scope) as AnonymousScope val scope2 = super.visit(scope) as AnonymousScope
currentAnonScope.pop()
val vardecls = scope2.statements.filterIsInstance<VarDecl>() val vardecls = scope2.statements.filterIsInstance<VarDecl>()
varsToMove[scope2] = vardecls varsToMove[scope2] = vardecls
return scope2 return scope2
@@ -43,26 +48,35 @@ class AnonymousScopeVarsCleanup(val program: Program): IAstModifyingVisitor {
override fun visit(decl: VarDecl): Statement { override fun visit(decl: VarDecl): Statement {
val decl2 = super.visit(decl) as VarDecl val decl2 = super.visit(decl) as VarDecl
val scope = decl2.definingScope() if(currentAnonScope.isEmpty())
if(scope is AnonymousScope) { return decl2
return decl2.withPrefixedName(nameprefix(scope)) return decl2.withPrefixedName(nameprefix(currentAnonScope.peek()))
}
return decl2
} }
override fun visit(identifier: IdentifierReference): Expression { override fun visit(identifier: IdentifierReference): Expression {
val ident = super.visit(identifier) val ident = super.visit(identifier)
if(ident !is IdentifierReference) if(ident !is IdentifierReference)
return ident return ident
if(currentAnonScope.isEmpty())
val scope = ident.definingScope() as? AnonymousScope ?: return ident return ident
val vardecl = ident.targetVarDecl(program.namespace) val vardecl = ident.targetVarDecl(program.namespace)
return if(vardecl!=null && vardecl.definingScope() == ident.definingScope()) { return if(vardecl!=null && vardecl.definingScope() === ident.definingScope()) {
// prefix the variable name reference that is defined inside the anon scope // prefix the variable name reference that is defined inside the anon scope
ident.withPrefixedName(nameprefix(scope)) ident.withPrefixedName(nameprefix(currentAnonScope.peek()))
} else { } else {
ident ident
} }
} }
/*
; @todo FIX Symbol lookup over anon scopes
; sub start() {
; for ubyte i in 0 to 10 {
; word rz = 4
; if rz >= 1 {
; word persp = rz+1
; }
; }
; }
*/
} }

View File

@@ -1823,6 +1823,9 @@ $endLabel""")
} }
} }
// TODO: use optimized routines such as mul_10
private fun translateBinaryOperatorBytes(operator: String, types: DataType) { private fun translateBinaryOperatorBytes(operator: String, types: DataType) {
when(operator) { when(operator) {
"**" -> throw AssemblyError("** operator requires floats") "**" -> throw AssemblyError("** operator requires floats")

View File

@@ -2,6 +2,8 @@ package prog8.compiler.target.c64.codegen2
import prog8.ast.IFunctionCall import prog8.ast.IFunctionCall
import prog8.ast.Program import prog8.ast.Program
import prog8.ast.base.ByteDatatypes
import prog8.ast.base.DataType
import prog8.ast.base.WordDatatypes import prog8.ast.base.WordDatatypes
import prog8.ast.expressions.Expression import prog8.ast.expressions.Expression
import prog8.ast.expressions.FunctionCall import prog8.ast.expressions.FunctionCall
@@ -56,6 +58,24 @@ internal class BuiltinFunctionsAsmGen(private val program: Program,
translateFunctionArguments(fcall.arglist) translateFunctionArguments(fcall.arglist)
asmgen.out(" inx | lda $ESTACK_LO_HEX,x | sta $ESTACK_HI_PLUS1_HEX,x") asmgen.out(" inx | lda $ESTACK_LO_HEX,x | sta $ESTACK_HI_PLUS1_HEX,x")
} }
"abs" -> {
translateFunctionArguments(fcall.arglist)
val dt = fcall.arglist.single().inferType(program)!!
when (dt) {
in ByteDatatypes -> asmgen.out(" jsr prog8_lib.abs_b")
in WordDatatypes -> asmgen.out(" jsr prog8_lib.abs_w")
DataType.FLOAT -> asmgen.out(" jsr c64flt.abs_f")
else -> throw AssemblyError("weird type")
}
}
// TODO: any(f), all(f), max(f), min(f), sum(f)
"sin", "cos", "tan", "atan",
"ln", "log2", "sqrt", "rad",
"deg", "round", "floor", "ceil",
"rdnf" -> {
translateFunctionArguments(fcall.arglist)
asmgen.out(" jsr c64flt.func_$functionName")
}
else -> { else -> {
translateFunctionArguments(fcall.arglist) translateFunctionArguments(fcall.arglist)
asmgen.out(" jsr prog8_lib.func_$functionName") asmgen.out(" jsr prog8_lib.func_$functionName")

View File

@@ -145,12 +145,11 @@ fun builtinFunctionReturnType(function: String, args: List<Expression>, program:
return when (function) { return when (function) {
"abs" -> { "abs" -> {
when(val dt = args.single().inferType(program)) { val dt = args.single().inferType(program)
in ByteDatatypes -> DataType.UBYTE if(dt in NumericDatatypes)
in WordDatatypes -> DataType.UWORD return dt
DataType.FLOAT -> DataType.FLOAT else
else -> throw FatalAstException("weird datatype passed to abs $dt") throw FatalAstException("weird datatype passed to abs $dt")
}
} }
"max", "min" -> { "max", "min" -> {
when(val dt = datatypeFromIterableArg(args.single())) { when(val dt = datatypeFromIterableArg(args.single())) {

View File

@@ -107,8 +107,6 @@ open class RuntimeValue(val type: DataType, num: Number?=null, val str: String?=
asBoolean = floatval != 0.0 asBoolean = floatval != 0.0
} }
else -> { else -> {
if(heapId==null)
throw IllegalArgumentException("for non-numeric types, a heapId should be given")
byteval = null byteval = null
wordval = null wordval = null
floatval = null floatval = null
@@ -628,7 +626,7 @@ open class RuntimeValue(val type: DataType, num: Number?=null, val str: String?=
} }
class RuntimeValueRange(type: DataType, val range: IntProgression): RuntimeValue(type, 0) { class RuntimeValueRange(type: DataType, val range: IntProgression): RuntimeValue(type, array=range.toList().toTypedArray()) {
override fun iterator(): Iterator<Number> { override fun iterator(): Iterator<Number> {
return range.iterator() return range.iterator()
} }

View File

@@ -675,23 +675,23 @@ class AstVm(val program: Program) {
dialog.canvas.printText(args[0].wordval!!.toString(), true) dialog.canvas.printText(args[0].wordval!!.toString(), true)
} }
"c64scr.print_ubhex" -> { "c64scr.print_ubhex" -> {
val prefix = if (args[0].asBoolean) "$" else "" val number = args[0].byteval!!
val number = args[1].byteval!! val prefix = if (args[1].asBoolean) "$" else ""
dialog.canvas.printText("$prefix${number.toString(16).padStart(2, '0')}", true) dialog.canvas.printText("$prefix${number.toString(16).padStart(2, '0')}", true)
} }
"c64scr.print_uwhex" -> { "c64scr.print_uwhex" -> {
val prefix = if (args[0].asBoolean) "$" else "" val number = args[0].wordval!!
val number = args[1].wordval!! val prefix = if (args[1].asBoolean) "$" else ""
dialog.canvas.printText("$prefix${number.toString(16).padStart(4, '0')}", true) dialog.canvas.printText("$prefix${number.toString(16).padStart(4, '0')}", true)
} }
"c64scr.print_uwbin" -> { "c64scr.print_uwbin" -> {
val prefix = if (args[0].asBoolean) "%" else "" val number = args[0].wordval!!
val number = args[1].wordval!! val prefix = if (args[1].asBoolean) "%" else ""
dialog.canvas.printText("$prefix${number.toString(2).padStart(16, '0')}", true) dialog.canvas.printText("$prefix${number.toString(2).padStart(16, '0')}", true)
} }
"c64scr.print_ubbin" -> { "c64scr.print_ubbin" -> {
val prefix = if (args[0].asBoolean) "%" else "" val number = args[0].byteval!!
val number = args[1].byteval!! val prefix = if (args[1].asBoolean) "%" else ""
dialog.canvas.printText("$prefix${number.toString(2).padStart(8, '0')}", true) dialog.canvas.printText("$prefix${number.toString(2).padStart(8, '0')}", true)
} }
"c64scr.clear_screenchars" -> { "c64scr.clear_screenchars" -> {

View File

@@ -387,7 +387,7 @@ arithmetic: ``+`` ``-`` ``*`` ``/`` ``**`` ``%``
``+``, ``-``, ``*``, ``/`` are the familiar arithmetic operations. ``+``, ``-``, ``*``, ``/`` are the familiar arithmetic operations.
``/`` is division (will result in integer division when using on integer operands, and a floating point division when at least one of the operands is a float) ``/`` is division (will result in integer division when using on integer operands, and a floating point division when at least one of the operands is a float)
``**`` is the power operator: ``3 ** 5`` is equal to 3*3*3*3*3 and is 243. (it only works on floating point variables) ``**`` is the power operator: ``3 ** 5`` is equal to 3*3*3*3*3 and is 243. (it only works on floating point variables)
``%`` is the remainder operator: ``25 % 7`` is 4. Be careful: without a space, %10 will be parsed as the binary number 2 ``%`` is the remainder operator: ``25 % 7`` is 4. Be careful: without a space, %10 will be parsed as the binary number 2.
Remainder is only supported on integer operands (not floats). Remainder is only supported on integer operands (not floats).
bitwise arithmetic: ``&`` ``|`` ``^`` ``~`` ``<<`` ``>>`` bitwise arithmetic: ``&`` ``|`` ``^`` ``~`` ``<<`` ``>>``

View File

@@ -1,36 +1,47 @@
%import c64lib
%import c64utils
%import c64flt
%zeropage basicsafe %zeropage basicsafe
main { main {
sub start() { sub start() {
c64.CHROUT('\n') ; float fl = 123.4567
%asm {{ ; c64flt.print_f(round(fl))
stx $0410 ; c64.CHROUT('\n')
}} ; c64flt.print_f(round(fl))
c64.CHRIN() ; c64.CHROUT('\n')
%asm {{ ; c64flt.print_f(round(fl))
stx $0411 ; c64.CHROUT('\n')
}} ; c64flt.print_f(ceil(fl))
print_notes(80,35) ; c64.CHROUT('\n')
%asm {{ ; c64flt.print_f(ceil(fl))
stx $0412 ; c64.CHROUT('\n')
}} ; c64flt.print_f(ceil(fl))
return ; c64.CHROUT('\n')
} ; c64flt.print_f(floor(fl))
; c64.CHROUT('\n')
; c64flt.print_f(floor(fl))
; c64.CHROUT('\n')
; c64flt.print_f(floor(fl))
; c64.CHROUT('\n')
; @($040a)=X
; return
sub print_notes(ubyte n1, ubyte n2) { while true {
c64scr.print_ub(n1/2) float clock_seconds = ((mkword(c64.TIME_LO, c64.TIME_MID) as float) + (c64.TIME_HI as float)*65536.0) / 60
c64.CHROUT('\n') float hours = floor(clock_seconds / 3600)
c64scr.print_ub(n1/3) clock_seconds -= hours*3600
c64.CHROUT('\n') float minutes = floor(clock_seconds / 60)
c64scr.print_ub(n1/4) clock_seconds = floor(clock_seconds - minutes * 60.0)
c64.CHROUT('\n')
c64.CHROUT('\n') c64scr.print("system time in ti$ is ")
c64scr.print_ub(n2/2) c64flt.print_f(hours)
c64.CHROUT('\n') c64.CHROUT(':')
c64scr.print_ub(n2/3) c64flt.print_f(minutes)
c64.CHROUT('\n') c64.CHROUT(':')
c64scr.print_ub(n2/4) c64flt.print_f(clock_seconds)
c64.CHROUT('\n') c64.CHROUT('\n')
}
} }
} }