diff --git a/codeCore/src/prog8/code/ast/AstExpressions.kt b/codeCore/src/prog8/code/ast/AstExpressions.kt index f69a24ccf..3eef2cdeb 100644 --- a/codeCore/src/prog8/code/ast/AstExpressions.kt +++ b/codeCore/src/prog8/code/ast/AstExpressions.kt @@ -106,6 +106,7 @@ sealed class PtExpression(val type: DataType, position: Position) : PtNode(posit is PtRange -> true is PtString -> true is PtTypeCast -> value.isSimple() + is PtIfExpression -> condition.isSimple() && truevalue.isSimple() && falsevalue.isSimple() } } @@ -206,6 +207,16 @@ class PtBinaryExpression(val operator: String, type: DataType, position: Positio } +class PtIfExpression(type: DataType, position: Position): PtExpression(type, position) { + val condition: PtExpression + get() = children[0] as PtExpression + val truevalue: PtExpression + get() = children[1] as PtExpression + val falsevalue: PtExpression + get() = children[2] as PtExpression +} + + class PtContainmentCheck(position: Position): PtExpression(DataType.BOOL, position) { val needle: PtExpression get() = children[0] as PtExpression diff --git a/codeCore/src/prog8/code/ast/AstPrinter.kt b/codeCore/src/prog8/code/ast/AstPrinter.kt index 993e0a743..7daa2694a 100644 --- a/codeCore/src/prog8/code/ast/AstPrinter.kt +++ b/codeCore/src/prog8/code/ast/AstPrinter.kt @@ -155,6 +155,7 @@ fun printAst(root: PtNode, skipLibraries: Boolean, output: (text: String) -> Uni "->" } is PtDefer -> "" + is PtIfExpression -> "" } } diff --git a/codeGenCpu6502/src/prog8/codegen/cpu6502/assignment/AssignmentAsmGen.kt b/codeGenCpu6502/src/prog8/codegen/cpu6502/assignment/AssignmentAsmGen.kt index bfc1185b3..48cfd5479 100644 --- a/codeGenCpu6502/src/prog8/codegen/cpu6502/assignment/AssignmentAsmGen.kt +++ b/codeGenCpu6502/src/prog8/codegen/cpu6502/assignment/AssignmentAsmGen.kt @@ -639,10 +639,52 @@ internal class AssignmentAsmGen( throw AssemblyError("Expression is too complex to translate into assembly. Split it up into several separate statements, introduce a temporary variable, or otherwise rewrite it. Location: $pos") } } + is PtIfExpression -> assignIfExpression(assign.target, value) else -> throw AssemblyError("weird assignment value type $value") } } + private fun assignIfExpression(target: AsmAssignTarget, expr: PtIfExpression) { + // TODO dont store condition as expression result but just use the flags, like a normal PtIfElse translation does + require(target.datatype==expr.type) + val falseLabel = asmgen.makeLabel("ifexpr_false") + val endLabel = asmgen.makeLabel("ifexpr_end") + assignExpressionToRegister(expr.condition, RegisterOrPair.A, false) + asmgen.out(" beq $falseLabel") + when(expr.type) { + in ByteDatatypesWithBoolean -> { + assignExpressionToRegister(expr.truevalue, RegisterOrPair.A, false) + assignRegisterByte(target, CpuRegister.A, false, false) + asmgen.jmp(endLabel) + asmgen.out(falseLabel) + assignExpressionToRegister(expr.falsevalue, RegisterOrPair.A, false) + assignRegisterByte(target, CpuRegister.A, false, false) + asmgen.out(endLabel) + } + in WordDatatypes -> { + assignExpressionToRegister(expr.truevalue, RegisterOrPair.AY, false) + assignRegisterpairWord(target, RegisterOrPair.AY) + asmgen.jmp(endLabel) + asmgen.out(falseLabel) + assignExpressionToRegister(expr.falsevalue, RegisterOrPair.AY, false) + assignRegisterpairWord(target, RegisterOrPair.AY) + asmgen.out(endLabel) + } + DataType.FLOAT -> { + val trueSrc = AsmAssignSource.fromAstSource(expr.truevalue, program, asmgen) + val assignTrue = AsmAssignment(trueSrc, target, program.memsizer, expr.position) + translateNormalAssignment(assignTrue, expr.definingISub()) + asmgen.jmp(endLabel) + asmgen.out(falseLabel) + val falseSrc = AsmAssignSource.fromAstSource(expr.falsevalue, program, asmgen) + val assignFalse = AsmAssignment(falseSrc, target, program.memsizer, expr.position) + translateNormalAssignment(assignFalse, expr.definingISub()) + asmgen.out(endLabel) + } + else -> throw AssemblyError("weird dt") + } + } + private fun assignPrefixedExpressionToArrayElt(assign: AsmAssignment, scope: IPtSubroutine?) { require(assign.source.expression is PtPrefix) if(assign.source.datatype==DataType.FLOAT) { diff --git a/codeGenIntermediate/src/prog8/codegen/intermediate/ExpressionGen.kt b/codeGenIntermediate/src/prog8/codegen/intermediate/ExpressionGen.kt index ced6d781f..05c6865f8 100644 --- a/codeGenIntermediate/src/prog8/codegen/intermediate/ExpressionGen.kt +++ b/codeGenIntermediate/src/prog8/codegen/intermediate/ExpressionGen.kt @@ -79,6 +79,7 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) { is PtPrefix -> translate(expr) is PtArrayIndexer -> translate(expr) is PtBinaryExpression -> translate(expr) + is PtIfExpression -> translate(expr) is PtBuiltinFunctionCall -> codeGen.translateBuiltinFunc(expr) is PtFunctionCall -> translate(expr) is PtContainmentCheck -> translate(expr) @@ -89,6 +90,36 @@ internal class ExpressionGen(private val codeGen: IRCodeGen) { } } + private fun translate(ifExpr: PtIfExpression): ExpressionCodeResult { + // TODO dont store condition as expression result but just use the flags, like a normal PtIfElse translation does + val condTr = translateExpression(ifExpr.condition) + val trueTr = translateExpression(ifExpr.truevalue) + val falseTr = translateExpression(ifExpr.falsevalue) + val irDt = irType(ifExpr.type) + val result = mutableListOf() + val falseLabel = codeGen.createLabelName() + val endLabel = codeGen.createLabelName() + + addToResult(result, condTr, condTr.resultReg, -1) + addInstr(result, IRInstruction(Opcode.BSTEQ, labelSymbol = falseLabel), null) + + if (irDt != IRDataType.FLOAT) { + addToResult(result, trueTr, trueTr.resultReg, -1) + addInstr(result, IRInstruction(Opcode.JUMP, labelSymbol = endLabel), null) + result += IRCodeChunk(falseLabel, null) + addToResult(result, falseTr, trueTr.resultReg, -1) + result += IRCodeChunk(endLabel, null) + return ExpressionCodeResult(result, irDt, trueTr.resultReg, -1) + } else { + addToResult(result, trueTr, -1, trueTr.resultFpReg) + addInstr(result, IRInstruction(Opcode.JUMP, labelSymbol = endLabel), null) + result += IRCodeChunk(falseLabel, null) + addToResult(result, falseTr, -1, trueTr.resultFpReg) + result += IRCodeChunk(endLabel, null) + return ExpressionCodeResult(result, irDt, -1, trueTr.resultFpReg) + } + } + private fun translate(expr: PtAddressOf): ExpressionCodeResult { val vmDt = irType(expr.type) val symbol = expr.identifier.name diff --git a/compiler/src/prog8/compiler/astprocessing/AstChecker.kt b/compiler/src/prog8/compiler/astprocessing/AstChecker.kt index ddac64fa7..9fa453dfb 100644 --- a/compiler/src/prog8/compiler/astprocessing/AstChecker.kt +++ b/compiler/src/prog8/compiler/astprocessing/AstChecker.kt @@ -662,7 +662,7 @@ internal class AstChecker(private val program: Program, if (targetDatatype.isKnown) { val sourceDatatype = assignment.value.inferType(program) if (sourceDatatype.isUnknown) { - if (assignment.value !is BinaryExpression && assignment.value !is PrefixExpression && assignment.value !is ContainmentCheck) + if (assignment.value !is BinaryExpression && assignment.value !is PrefixExpression && assignment.value !is ContainmentCheck && assignment.value !is IfExpression) errors.err("invalid assignment value, maybe forgot '&' (address-of)", assignment.value.position) } else { checkAssignmentCompatible(assignTarget, targetDatatype.getOr(DataType.UNDEFINED), @@ -684,6 +684,20 @@ internal class AstChecker(private val program: Program, super.visit(addressOf) } + override fun visit(ifExpr: IfExpression) { + if(!ifExpr.condition.inferType(program).isBool) + errors.err("condition should be a boolean", ifExpr.condition.position) + + val trueDt = ifExpr.truevalue.inferType(program) + val falseDt = ifExpr.falsevalue.inferType(program) + if(trueDt.isUnknown || falseDt.isUnknown) { + errors.err("invalid value type(s)", ifExpr.position) + } else if(trueDt!=falseDt) { + errors.err("both values should be the same type", ifExpr.truevalue.position) + } + super.visit(ifExpr) + } + override fun visit(decl: VarDecl) { if(decl.names.size>1) throw InternalCompilerException("vardecls with multiple names should have been converted into individual vardecls") diff --git a/compiler/src/prog8/compiler/astprocessing/IntermediateAstMaker.kt b/compiler/src/prog8/compiler/astprocessing/IntermediateAstMaker.kt index ed9b18cba..9c6796575 100644 --- a/compiler/src/prog8/compiler/astprocessing/IntermediateAstMaker.kt +++ b/compiler/src/prog8/compiler/astprocessing/IntermediateAstMaker.kt @@ -86,9 +86,19 @@ class IntermediateAstMaker(private val program: Program, private val errors: IEr is RangeExpression -> transform(expr) is StringLiteral -> transform(expr) is TypecastExpression -> transform(expr) + is IfExpression -> transform(expr) } } + private fun transform(ifExpr: IfExpression): PtIfExpression { + val type = ifExpr.inferType(program).getOrElse { throw FatalAstException("unknown dt") } + val ifexpr = PtIfExpression(type, ifExpr.position) + ifexpr.add(transformExpression(ifExpr.condition)) + ifexpr.add(transformExpression(ifExpr.truevalue)) + ifexpr.add(transformExpression(ifExpr.falsevalue)) + return ifexpr + } + private fun transform(srcDefer: Defer): PtDefer { val defer = PtDefer(srcDefer.position) srcDefer.scope.statements.forEach { diff --git a/compilerAst/src/prog8/ast/antlr/Antlr2Kotlin.kt b/compilerAst/src/prog8/ast/antlr/Antlr2Kotlin.kt index f1213478f..01b3299b0 100644 --- a/compilerAst/src/prog8/ast/antlr/Antlr2Kotlin.kt +++ b/compilerAst/src/prog8/ast/antlr/Antlr2Kotlin.kt @@ -548,6 +548,12 @@ private fun ExpressionContext.toAst(insideParentheses: Boolean=false) : Expressi AddressOf(array.scoped_identifier().toAst(), array.arrayindex().toAst(), toPosition()) } + if(if_expression()!=null) { + val ifex = if_expression() + val (condition, truevalue, falsevalue) = ifex.expression() + return IfExpression(condition.toAst(), truevalue.toAst(), falsevalue.toAst(), toPosition()) + } + throw FatalAstException(text) } diff --git a/compilerAst/src/prog8/ast/expressions/AstExpressions.kt b/compilerAst/src/prog8/ast/expressions/AstExpressions.kt index 411c1abe6..ae0be1c75 100644 --- a/compilerAst/src/prog8/ast/expressions/AstExpressions.kt +++ b/compilerAst/src/prog8/ast/expressions/AstExpressions.kt @@ -1382,6 +1382,49 @@ class BuiltinFunctionCall(override var target: IdentifierReference, override fun inferType(program: Program) = program.builtinFunctions.returnType(name) } +class IfExpression(var condition: Expression, var truevalue: Expression, var falsevalue: Expression, override val position: Position) : Expression() { + + override lateinit var parent: Node + + override fun linkParents(parent: Node) { + this.parent = parent + condition.linkParents(this) + truevalue.linkParents(this) + falsevalue.linkParents(this) + } + + override val isSimple: Boolean = condition.isSimple && truevalue.isSimple && falsevalue.isSimple + + override fun toString() = "IfExpr(cond=$condition, true=$truevalue, false=$falsevalue, pos=$position)" + override fun accept(visitor: IAstVisitor) = visitor.visit(this) + override fun accept(visitor: AstWalker, parent: Node) = visitor.visit(this, parent) + override fun referencesIdentifier(nameInSource: List): Boolean = condition.referencesIdentifier(nameInSource) || truevalue.referencesIdentifier(nameInSource) || falsevalue.referencesIdentifier(nameInSource) + override fun inferType(program: Program): InferredTypes.InferredType { + val t1 = truevalue.inferType(program) + val t2 = falsevalue.inferType(program) + return if(t1==t2) t1 else InferredTypes.InferredType.unknown() + } + + override fun copy(): Expression = IfExpression(condition.copy(), truevalue.copy(), falsevalue.copy(), position) + + override fun constValue(program: Program): NumericLiteral? { + val cond = condition.constValue(program) + if(cond!=null) { + return if (cond.asBooleanValue) truevalue.constValue(program) else falsevalue.constValue(program) + } + return null + } + + override fun replaceChildNode(node: Node, replacement: Node) { + if(replacement !is Expression) + throw throw FatalAstException("invalid replace") + if(node===condition) condition=replacement + else if(node===truevalue) truevalue=replacement + else if(node===falsevalue) falsevalue=replacement + else throw FatalAstException("invalid replace") + } +} + fun invertCondition(cond: Expression, program: Program): Expression { if(cond is BinaryExpression) { val invertedOperator = invertedComparisonOperator(cond.operator) diff --git a/compilerAst/src/prog8/ast/walk/AstWalker.kt b/compilerAst/src/prog8/ast/walk/AstWalker.kt index edf35de4e..5bf713224 100644 --- a/compilerAst/src/prog8/ast/walk/AstWalker.kt +++ b/compilerAst/src/prog8/ast/walk/AstWalker.kt @@ -109,6 +109,7 @@ abstract class AstWalker { open fun before(directive: Directive, parent: Node): Iterable = noModifications open fun before(expr: BinaryExpression, parent: Node): Iterable = noModifications open fun before(expr: PrefixExpression, parent: Node): Iterable = noModifications + open fun before(ifExpr: IfExpression, parent: Node): Iterable = noModifications open fun before(forLoop: ForLoop, parent: Node): Iterable = noModifications open fun before(repeatLoop: RepeatLoop, parent: Node): Iterable = noModifications open fun before(unrollLoop: UnrollLoop, parent: Node): Iterable = noModifications @@ -154,6 +155,7 @@ abstract class AstWalker { open fun after(directive: Directive, parent: Node): Iterable = noModifications open fun after(expr: BinaryExpression, parent: Node): Iterable = noModifications open fun after(expr: PrefixExpression, parent: Node): Iterable = noModifications + open fun after(ifExpr: IfExpression, parent: Node): Iterable = noModifications open fun after(forLoop: ForLoop, parent: Node): Iterable = noModifications open fun after(repeatLoop: RepeatLoop, parent: Node): Iterable = noModifications open fun after(unrollLoop: UnrollLoop, parent: Node): Iterable = noModifications @@ -475,6 +477,14 @@ abstract class AstWalker { track(after(addressOf, parent), addressOf, parent) } + fun visit(ifExpr: IfExpression, parent: Node) { + track(before(ifExpr, parent), ifExpr, parent) + ifExpr.condition.accept(this, ifExpr) + ifExpr.truevalue.accept(this, ifExpr) + ifExpr.falsevalue.accept(this, ifExpr) + track(after(ifExpr, parent), ifExpr, parent) + } + fun visit(inlineAssembly: InlineAssembly, parent: Node) { track(before(inlineAssembly, parent), inlineAssembly, parent) track(after(inlineAssembly, parent), inlineAssembly, parent) diff --git a/compilerAst/src/prog8/ast/walk/IAstVisitor.kt b/compilerAst/src/prog8/ast/walk/IAstVisitor.kt index 568b1c71e..86cbfe1ea 100644 --- a/compilerAst/src/prog8/ast/walk/IAstVisitor.kt +++ b/compilerAst/src/prog8/ast/walk/IAstVisitor.kt @@ -87,6 +87,12 @@ interface IAstVisitor { range.step.accept(this) } + fun visit(ifExpr: IfExpression) { + ifExpr.condition.accept(this) + ifExpr.truevalue.accept(this) + ifExpr.falsevalue.accept(this) + } + fun visit(label: Label) { } diff --git a/docs/source/todo.rst b/docs/source/todo.rst index 17fdab331..abcf3e01b 100644 --- a/docs/source/todo.rst +++ b/docs/source/todo.rst @@ -1,10 +1,14 @@ TODO ==== -- what if you use defer in a loop! (zig: defer in a loop is executed at the end of each iteration) -> now: wrong code is generated if the error msg is removed +- defers that haven't been reached yet should not be executed (how will we do this? some kind of runtime support needed? refcount or bitmask, not a boolean var per defer that would be wasteful) - unit test for defer - describe defer in the manual +- unit test for ifexpression +- describe ifexpression in the manual +- Optimize the IfExpression code generation to be more like regular if-else code. (both 6502 and IR) + Improve register load order in subroutine call args assignments: in certain situations, the "wrong" order of evaluation of function call arguments is done which results diff --git a/examples/test.p8 b/examples/test.p8 index da58d3d9c..fa80cb93d 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -7,63 +7,33 @@ main { sub start() { - for cx16.r0L in 0 to 10 { - defer txt.print("end!!\n") - } - txt.print("old way:\n") - void oldway() - txt.print("\nnew way:\n") - void newway() - } - sub oldway() -> bool { + ubyte @shared c=99 + if c>100 + cx16.r0L++ + cx16.r0L = if (c>100) 2 else (3) + txt.print_ub(if (c>100) 2 else 3) + txt.nl() + txt.print_ub(if (c<100) 6 else 7) + txt.nl() + + float @shared fl=99.99 + floats.print(if (c>100) 2.22 else 3.33) + txt.nl() + floats.print(if (c<100) 6.66 else 7.77) + txt.nl() + uword res1 = allocate(111) - if res1==0 - return false - - uword res2 = allocate(222) - if res2==0 { - deallocate(res1) - return false - } - - if not process1(res1, res2) { - deallocate(res1) - deallocate(res2) - return false - } - if not process2(res1, res2) { - deallocate(res1) - deallocate(res2) - return false - } - - deallocate(res1) - deallocate(res2) - return true - } - - sub newway() -> bool { - uword res1 = allocate(111) - if res1==0 - return false - defer { - deallocate(res1) - } - + defer deallocate(res1) uword res2 = allocate(222) if res2==0 - return false - defer { - deallocate(res2) - } + return + defer deallocate(res2) if not process1(res1, res2) - return false + return if not process2(res1, res2) - return false - - return true + return } sub allocate(uword arg) -> uword { diff --git a/parser/antlr/Prog8ANTLR.g4 b/parser/antlr/Prog8ANTLR.g4 index a08784ad7..ac0b64be7 100644 --- a/parser/antlr/Prog8ANTLR.g4 +++ b/parser/antlr/Prog8ANTLR.g4 @@ -197,6 +197,7 @@ expression : | directmemory | addressof | expression typecast + | if_expression ; arrayindexed: @@ -210,7 +211,6 @@ directmemory : '@' '(' expression ')'; addressof : ADDRESS_OF (scoped_identifier | arrayindexed) ; - functioncall : scoped_identifier '(' expression_list? ')' ; functioncall_stmt : VOID? scoped_identifier '(' expression_list? ')' ; @@ -299,6 +299,7 @@ if_stmt : 'if' expression EOL? (statement | statement_block) EOL? else_part? ; else_part : 'else' EOL? (statement | statement_block) ; // statement is constrained later +if_expression : 'if' expression EOL? expression EOL? 'else' EOL? expression ; branch_stmt : branchcondition EOL? (statement | statement_block) EOL? else_part? ;