diff --git a/codeGenVirtual/src/prog8/codegen/virtual/CodeGen.kt b/codeGenVirtual/src/prog8/codegen/virtual/CodeGen.kt index 8136373e4..bc01e28a8 100644 --- a/codeGenVirtual/src/prog8/codegen/virtual/CodeGen.kt +++ b/codeGenVirtual/src/prog8/codegen/virtual/CodeGen.kt @@ -75,7 +75,7 @@ class CodeGen(internal val program: PtProgram, is PtNop -> VmCodeChunk() is PtReturn -> translate(node) is PtJump -> translate(node) - is PtWhen -> TODO("when") + is PtWhen -> translate(node) is PtPipe -> expressionEval.translate(node, 0) is PtForLoop -> translate(node) is PtIfElse -> translate(node) @@ -110,6 +110,47 @@ class CodeGen(internal val program: PtProgram, return code } + private fun translate(whenStmt: PtWhen): VmCodeChunk { + if(whenStmt.choices.children.isEmpty()) + return VmCodeChunk() + val code = VmCodeChunk() + val valueReg = vmRegisters.nextFree() + val choiceReg = vmRegisters.nextFree() + val valueDt = vmType(whenStmt.value.type) + code += expressionEval.translateExpression(whenStmt.value, valueReg) + val choices = whenStmt.choices.children.map {it as PtWhenChoice } + val endLabel = createLabelName() + for (choice in choices) { + if(choice.isElse) { + code += translateNode(choice.statements) + } else { + val skipLabel = createLabelName() + val values = choice.values.children.map {it as PtNumber} + if(values.size==1) { + code += VmCodeInstruction(Opcode.LOAD, valueDt, reg1=choiceReg, value=values[0].number.toInt()) + code += VmCodeInstruction(Opcode.BNE, valueDt, reg1=valueReg, reg2=choiceReg, symbol = skipLabel) + code += translateNode(choice.statements) + if(choice.statements.children.last() !is PtReturn) + code += VmCodeInstruction(Opcode.JUMP, symbol = endLabel) + } else { + val matchLabel = createLabelName() + for (value in values) { + code += VmCodeInstruction(Opcode.LOAD, valueDt, reg1=choiceReg, value=value.number.toInt()) + code += VmCodeInstruction(Opcode.BEQ, valueDt, reg1=valueReg, reg2=choiceReg, symbol = matchLabel) + } + code += VmCodeInstruction(Opcode.JUMP, symbol = skipLabel) + code += VmCodeLabel(matchLabel) + code += translateNode(choice.statements) + if(choice.statements.children.last() !is PtReturn) + code += VmCodeInstruction(Opcode.JUMP, symbol = endLabel) + } + code += VmCodeLabel(skipLabel) + } + } + code += VmCodeLabel(endLabel) + return code + } + private fun translate(forLoop: PtForLoop): VmCodeChunk { val loopvar = symbolTable.lookup(forLoop.variable.targetName) as StStaticVariable val iterable = forLoop.iterable diff --git a/codeGenVirtual/src/prog8/codegen/virtual/ExpressionGen.kt b/codeGenVirtual/src/prog8/codegen/virtual/ExpressionGen.kt index 64f696b5d..25a124869 100644 --- a/codeGenVirtual/src/prog8/codegen/virtual/ExpressionGen.kt +++ b/codeGenVirtual/src/prog8/codegen/virtual/ExpressionGen.kt @@ -297,7 +297,7 @@ internal class ExpressionGen(private val codeGen: CodeGen) { } code += VmCodeInstruction(Opcode.CALL, symbol=fcall.functionName) if(!fcall.void && resultRegister!=0) { - // Call convention: result value is in r0, so put it in the required register instead. TODO does this work correctly? + // Call convention: result value is in r0, so put it in the required register instead. code += VmCodeInstruction(Opcode.LOADR, codeGen.vmType(fcall.type), reg1=resultRegister, reg2=0) } return code diff --git a/docs/source/todo.rst b/docs/source/todo.rst index c45887d1b..6b69b9e26 100644 --- a/docs/source/todo.rst +++ b/docs/source/todo.rst @@ -3,7 +3,6 @@ TODO For next release ^^^^^^^^^^^^^^^^ -- vm codegen: When - vm codegen: Pipe expression - vm: support no globals re-init option - vm codegen/assembler: variable memory locations should also be referenced by the variable name instead of just the address, to make the output more human-readable diff --git a/examples/test.p8 b/examples/test.p8 index 3cb2d3971..98ea2a2d5 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -5,8 +5,32 @@ main { + sub calculate(ubyte value) -> uword { + when value { + 1 -> return "one" + 2 -> return "two" + 3 -> return "three" + 4,5,6 -> return "four to six" + else -> return "other" + } + } + sub start() { + txt.print(calculate(0)) + txt.nl() + txt.print(calculate(1)) + txt.nl() + txt.print(calculate(2)) + txt.nl() + txt.print(calculate(3)) + txt.nl() + txt.print(calculate(4)) + txt.nl() + txt.print(calculate(5)) + txt.nl() + txt.print(calculate(50)) + txt.nl() ; a "pixelshader": ; syscall1(8, 0) ; enable lo res creen