diff --git a/codeCore/src/prog8/code/core/Enumerations.kt b/codeCore/src/prog8/code/core/Enumerations.kt index ecfb278ee..1cda51515 100644 --- a/codeCore/src/prog8/code/core/Enumerations.kt +++ b/codeCore/src/prog8/code/core/Enumerations.kt @@ -23,10 +23,10 @@ enum class DataType { when(this) { BOOL -> targetType.oneOf(BOOL, BYTE, UBYTE, WORD, UWORD, FLOAT) UBYTE -> targetType.oneOf(UBYTE, WORD, UWORD, FLOAT, BOOL) - BYTE -> targetType.oneOf(BYTE, WORD, FLOAT, BOOL) - UWORD -> targetType.oneOf(UWORD, FLOAT, BOOL) - WORD -> targetType.oneOf(WORD, FLOAT, BOOL) - FLOAT -> targetType.oneOf(FLOAT, BOOL) + BYTE -> targetType.oneOf(BYTE, WORD, FLOAT) + UWORD -> targetType.oneOf(UWORD, FLOAT) + WORD -> targetType.oneOf(WORD, FLOAT) + FLOAT -> targetType.oneOf(FLOAT) STR -> targetType.oneOf(STR, UWORD) in ArrayDatatypes -> targetType == this else -> false @@ -115,7 +115,9 @@ enum class BranchCondition { val ByteDatatypes = arrayOf(DataType.UBYTE, DataType.BYTE, DataType.BOOL) val WordDatatypes = arrayOf(DataType.UWORD, DataType.WORD) val IntegerDatatypes = arrayOf(DataType.UBYTE, DataType.BYTE, DataType.UWORD, DataType.WORD, DataType.BOOL) +val IntegerDatatypesNoBool = arrayOf(DataType.UBYTE, DataType.BYTE, DataType.UWORD, DataType.WORD) val NumericDatatypes = arrayOf(DataType.UBYTE, DataType.BYTE, DataType.UWORD, DataType.WORD, DataType.FLOAT, DataType.BOOL) +val NumericDatatypesNoBool = arrayOf(DataType.UBYTE, DataType.BYTE, DataType.UWORD, DataType.WORD, DataType.FLOAT) val SignedDatatypes = arrayOf(DataType.BYTE, DataType.WORD, DataType.FLOAT) val ArrayDatatypes = arrayOf(DataType.ARRAY_UB, DataType.ARRAY_B, DataType.ARRAY_UW, DataType.ARRAY_W, DataType.ARRAY_F, DataType.ARRAY_BOOL) val StringlyDatatypes = arrayOf(DataType.STR, DataType.ARRAY_UB, DataType.ARRAY_B, DataType.UWORD) diff --git a/compiler/src/prog8/compiler/astprocessing/AstChecker.kt b/compiler/src/prog8/compiler/astprocessing/AstChecker.kt index e8bafe04c..8520a9947 100644 --- a/compiler/src/prog8/compiler/astprocessing/AstChecker.kt +++ b/compiler/src/prog8/compiler/astprocessing/AstChecker.kt @@ -99,8 +99,9 @@ internal class AstChecker(private val program: Program, } override fun visit(ifElse: IfElse) { - if(!ifElse.condition.inferType(program).isInteger) - errors.err("condition value should be an integer type", ifElse.condition.position) + val dt = ifElse.condition.inferType(program) + if(!dt.isInteger && !dt.istype(DataType.BOOL)) + errors.err("condition value should be an integer type or bool", ifElse.condition.position) super.visit(ifElse) } @@ -435,14 +436,16 @@ internal class AstChecker(private val program: Program, } override fun visit(untilLoop: UntilLoop) { - if(!untilLoop.condition.inferType(program).isInteger) - errors.err("condition value should be an integer type", untilLoop.condition.position) + val dt = untilLoop.condition.inferType(program) + if(!dt.isInteger && !dt.istype(DataType.BOOL)) + errors.err("condition value should be an integer type or bool", untilLoop.condition.position) super.visit(untilLoop) } override fun visit(whileLoop: WhileLoop) { - if(!whileLoop.condition.inferType(program).isInteger) - errors.err("condition value should be an integer type", whileLoop.condition.position) + val dt = whileLoop.condition.inferType(program) + if(!dt.isInteger && !dt.istype(DataType.BOOL)) + errors.err("condition value should be an integer type or bool", whileLoop.condition.position) super.visit(whileLoop) } @@ -891,9 +894,9 @@ internal class AstChecker(private val program: Program, "in" -> throw FatalAstException("in expression should have been replaced by containmentcheck") } - if(leftDt !in NumericDatatypes && leftDt != DataType.STR) + if(leftDt !in NumericDatatypes && leftDt != DataType.STR && leftDt != DataType.BOOL) errors.err("left operand is not numeric or str", expr.left.position) - if(rightDt!in NumericDatatypes && rightDt != DataType.STR) + if(rightDt!in NumericDatatypes && rightDt != DataType.STR && rightDt != DataType.BOOL) errors.err("right operand is not numeric or str", expr.right.position) if(leftDt!=rightDt) { if(leftDt==DataType.STR && rightDt in IntegerDatatypes && expr.operator=="*") { diff --git a/compiler/src/prog8/compiler/astprocessing/BoolRemover.kt b/compiler/src/prog8/compiler/astprocessing/BoolRemover.kt index 311c7e109..e7f132c82 100644 --- a/compiler/src/prog8/compiler/astprocessing/BoolRemover.kt +++ b/compiler/src/prog8/compiler/astprocessing/BoolRemover.kt @@ -93,6 +93,8 @@ internal class BoolRemover(val program: Program) : AstWalker() { fun isBoolean(expr: Expression): Boolean { return if(expr.inferType(program) istype DataType.BOOL) true + else if(expr is NumericLiteral && expr.type in IntegerDatatypes && (expr.number==0.0 || expr.number==1.0)) + true else if(expr is BinaryExpression && expr.operator in ComparisonOperators + LogicalOperators) true else if(expr is PrefixExpression && expr.operator == "not") diff --git a/compiler/src/prog8/compiler/astprocessing/TypecastsAdder.kt b/compiler/src/prog8/compiler/astprocessing/TypecastsAdder.kt index d84caaa8d..03654479a 100644 --- a/compiler/src/prog8/compiler/astprocessing/TypecastsAdder.kt +++ b/compiler/src/prog8/compiler/astprocessing/TypecastsAdder.kt @@ -43,20 +43,24 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val override fun after(expr: BinaryExpression, parent: Node): Iterable { val leftDt = expr.left.inferType(program) val rightDt = expr.right.inferType(program) + val leftCv = expr.left.constValue(program) + val rightCv = expr.right.constValue(program) + if(leftDt.isKnown && rightDt.isKnown && leftDt!=rightDt) { - // convert bool type to byte - if(leftDt istype DataType.BOOL && rightDt.isBytes) { - return listOf(IAstModification.ReplaceNode(expr.left, - TypecastExpression(expr.left, rightDt.getOr(DataType.UNDEFINED),true, expr.left.position), expr)) - } else if(leftDt.isBytes && rightDt istype DataType.BOOL) { - return listOf(IAstModification.ReplaceNode(expr.right, - TypecastExpression(expr.right, leftDt.getOr(DataType.UNDEFINED),true, expr.right.position), expr)) + // convert bool type to byte if needed + if(leftDt istype DataType.BOOL && rightDt.isBytes && !rightDt.istype(DataType.BOOL)) { + if(rightCv==null || (rightCv.number!=1.0 && rightCv.number!=0.0)) + return listOf(IAstModification.ReplaceNode(expr.left, + TypecastExpression(expr.left, rightDt.getOr(DataType.UNDEFINED),true, expr.left.position), expr)) + } else if(leftDt.isBytes && !leftDt.istype(DataType.BOOL) && rightDt istype DataType.BOOL) { + if(leftCv==null || (leftCv.number!=1.0 && leftCv.number!=0.0)) + return listOf(IAstModification.ReplaceNode(expr.right, + TypecastExpression(expr.right, leftDt.getOr(DataType.UNDEFINED),true, expr.right.position), expr)) } // convert a negative operand for bitwise operator to the 2's complement positive number instead if(expr.operator in BitwiseOperators && leftDt.isInteger && rightDt.isInteger) { - val leftCv = expr.left.constValue(program) if(leftCv!=null && leftCv.number<0) { val value = if(rightDt.isBytes) 256+leftCv.number else 65536+leftCv.number return listOf(IAstModification.ReplaceNode( @@ -64,7 +68,6 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val NumericLiteral(rightDt.getOr(DataType.UNDEFINED), value, expr.left.position), expr)) } - val rightCv = expr.right.constValue(program) if(rightCv!=null && rightCv.number<0) { val value = if(leftDt.isBytes) 256+rightCv.number else 65536+rightCv.number return listOf(IAstModification.ReplaceNode( @@ -203,6 +206,9 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val arg, cast.valueOrZero(), call as Node) + } else if(requiredType==DataType.BOOL && argtype!=DataType.BOOL) { + // cast to bool + addTypecastOrCastedValueModification(modifications, arg, requiredType, call as Node) } } } diff --git a/compiler/test/TestBuiltinFunctions.kt b/compiler/test/TestBuiltinFunctions.kt index 760cbf884..f59699599 100644 --- a/compiler/test/TestBuiltinFunctions.kt +++ b/compiler/test/TestBuiltinFunctions.kt @@ -4,9 +4,8 @@ import io.kotest.core.spec.style.FunSpec import io.kotest.matchers.shouldBe import io.kotest.matchers.shouldNotBe import prog8.code.core.DataType -import prog8.code.core.NumericDatatypes +import prog8.code.core.NumericDatatypesNoBool import prog8.code.core.RegisterOrPair -import prog8.code.target.C64Target import prog8.code.target.Cx16Target import prog8.compiler.BuiltinFunctions import prog8tests.helpers.compileText @@ -31,7 +30,7 @@ class TestBuiltinFunctions: FunSpec({ func.name shouldBe "sgn" func.parameters.size shouldBe 1 func.parameters[0].name shouldBe "value" - func.parameters[0].possibleDatatypes shouldBe NumericDatatypes + func.parameters[0].possibleDatatypes shouldBe NumericDatatypesNoBool func.pure shouldBe true func.returnType shouldBe DataType.BYTE diff --git a/compiler/test/TestTypecasts.kt b/compiler/test/TestTypecasts.kt index dc5417d7f..40d5e549b 100644 --- a/compiler/test/TestTypecasts.kt +++ b/compiler/test/TestTypecasts.kt @@ -20,6 +20,51 @@ import prog8tests.helpers.compileText class TestTypecasts: FunSpec({ + + test("integer args for builtin funcs") { + val text=""" + %import floats + main { + sub start() { + float fl + floats.print_f(abs(fl)) + } + }""" + val errors = ErrorReporterForTests() + val result = compileText(VMTarget(), false, text, writeAssembly = false, errors=errors) + result shouldBe null + errors.errors.size shouldBe 1 + errors.errors[0] shouldContain "type mismatch, was: FLOAT expected one of: [UBYTE, BYTE, UWORD, WORD]" + } + + test("not casting bool operands to logical operators") { + val text=""" + %import textio + %zeropage basicsafe + + main { + sub start() { + bool bb2=true + bool @shared bb = bb2 and true + } + }""" + val result = compileText(C64Target(), false, text, writeAssembly = false)!! + val stmts = result.program.entrypoint.statements + stmts.size shouldBe 4 + val expr = (stmts[3] as Assignment).value as BinaryExpression + expr.operator shouldBe "and" + expr.right shouldBe NumericLiteral(DataType.UBYTE, 1.0, Position.DUMMY) + (expr.left as IdentifierReference).nameInSource shouldBe listOf("bb2") // no cast + + val result2 = compileText(C64Target(), true, text, writeAssembly = true)!! + val stmts2 = result2.program.entrypoint.statements + stmts2.size shouldBe 6 + val expr2 = (stmts2[4] as Assignment).value as BinaryExpression + expr2.operator shouldBe "&" + expr2.right shouldBe NumericLiteral(DataType.UBYTE, 1.0, Position.DUMMY) + (expr2.left as IdentifierReference).nameInSource shouldBe listOf("bb") + } + test("bool expressions with functioncalls") { val text=""" main { @@ -50,7 +95,7 @@ main { ubyte ub3 ub3 = 1 ubyte @shared bvalue - bvalue = (((ub1^ub2)^ub3)^(1!=0)) + bvalue = (((ub1^ub2)^ub3)^1) bvalue = (((ub1^ub2)^ub3)^(ftrue(99)!=0)) bvalue = ((ub1&ub2)&(ftrue(99)!=0)) return @@ -62,14 +107,12 @@ main { assignValue1.operator shouldBe "^" assignValue2.operator shouldBe "^" assignValue3.operator shouldBe "&" - val right1 = assignValue1.right as BinaryExpression + val right1 = assignValue1.right as NumericLiteral val right2 = assignValue2.right as BinaryExpression val right3 = assignValue3.right as BinaryExpression - right1.operator shouldBe "!=" + right1.number shouldBe 1.0 right2.operator shouldBe "!=" right3.operator shouldBe "!=" - right1.left shouldBe NumericLiteral(DataType.UBYTE, 1.0, Position.DUMMY) - right1.right shouldBe NumericLiteral(DataType.UBYTE, 0.0, Position.DUMMY) right2.left shouldBe instanceOf() right2.right shouldBe NumericLiteral(DataType.UBYTE, 0.0, Position.DUMMY) right3.left shouldBe instanceOf() diff --git a/compilerAst/src/prog8/compiler/BuiltinFunctions.kt b/compilerAst/src/prog8/compiler/BuiltinFunctions.kt index 2d64af559..d8139549e 100644 --- a/compilerAst/src/prog8/compiler/BuiltinFunctions.kt +++ b/compilerAst/src/prog8/compiler/BuiltinFunctions.kt @@ -94,12 +94,12 @@ private val functionSignatures: List = listOf( FSignature("sort" , false, listOf(FParam("array", ArrayDatatypes)), null), FSignature("reverse" , false, listOf(FParam("array", ArrayDatatypes)), null), // cmp returns a status in the carry flag, but not a proper return value - FSignature("cmp" , false, listOf(FParam("value1", IntegerDatatypes), FParam("value2", NumericDatatypes)), null), - FSignature("abs" , true, listOf(FParam("value", IntegerDatatypes)), DataType.UWORD, ::builtinAbs), + FSignature("cmp" , false, listOf(FParam("value1", IntegerDatatypesNoBool), FParam("value2", NumericDatatypesNoBool)), null), + FSignature("abs" , true, listOf(FParam("value", IntegerDatatypesNoBool)), DataType.UWORD, ::builtinAbs), FSignature("len" , true, listOf(FParam("values", IterableDatatypes)), DataType.UWORD, ::builtinLen), // normal functions follow: FSignature("sizeof" , true, listOf(FParam("object", DataType.values())), DataType.UBYTE, ::builtinSizeof), - FSignature("sgn" , true, listOf(FParam("value", NumericDatatypes)), DataType.BYTE, ::builtinSgn ), + FSignature("sgn" , true, listOf(FParam("value", NumericDatatypesNoBool)), DataType.BYTE, ::builtinSgn ), FSignature("sqrt16" , true, listOf(FParam("value", arrayOf(DataType.UWORD))), DataType.UBYTE) { a, p, prg -> oneIntArgOutputInt(a, p, prg) { sqrt(it.toDouble()) } }, FSignature("any" , true, listOf(FParam("values", ArrayDatatypes)), DataType.UBYTE) { a, p, prg -> collectionArg(a, p, prg, ::builtinAny) }, FSignature("all" , true, listOf(FParam("values", ArrayDatatypes)), DataType.UBYTE) { a, p, prg -> collectionArg(a, p, prg, ::builtinAll) }, @@ -178,7 +178,7 @@ private fun builtinAbs(args: List, position: Position, program: Prog val constval = args[0].constValue(program) ?: throw NotConstArgumentException() return when (constval.type) { - in IntegerDatatypes -> NumericLiteral.optimalInteger(abs(constval.number.toInt()), args[0].position) + in IntegerDatatypesNoBool -> NumericLiteral.optimalInteger(abs(constval.number.toInt()), args[0].position) else -> throw SyntaxError("abs requires one integer argument", position) } } diff --git a/docs/source/todo.rst b/docs/source/todo.rst index 11e3168d1..a26a8c6c3 100644 --- a/docs/source/todo.rst +++ b/docs/source/todo.rst @@ -3,8 +3,7 @@ TODO For next release ^^^^^^^^^^^^^^^^ -- fix compiler crash (abs(fl)) ; WHY IS THIS GETTING A BOOLEAN CAST??? -- vm why is bb = bb2 and true generating so large code? +- bool @shared bb = bb2 and true should not add typecast around bb2 ... diff --git a/examples/test.p8 b/examples/test.p8 index 01f06abc5..076b48991 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -1,13 +1,10 @@ %import textio -%import floats %zeropage basicsafe - main { sub start() { - float fl - fl = -3.14 - floats.print_f(abs(fl)) ; WHY IS THIS GETTING A BOOLEAN CAST??? - txt.nl() + bool bb2=true + bool @shared bb = bb2 and true + txt.print_ub(bb) } }