diff --git a/compiler/src/prog8/ast/processing/AstChecker.kt b/compiler/src/prog8/ast/processing/AstChecker.kt index a8b06f902..ddf2f837f 100644 --- a/compiler/src/prog8/ast/processing/AstChecker.kt +++ b/compiler/src/prog8/ast/processing/AstChecker.kt @@ -926,14 +926,17 @@ internal class AstChecker(private val program: Program, } override fun visit(whenChoice: WhenChoice) { + val whenStmt = whenChoice.parent as WhenStatement if(whenChoice.value!=null) { + val conditionType = whenStmt.condition.inferType(program) val constvalue = whenChoice.value.constValue(program) - if (constvalue == null) - checkResult.add(SyntaxError("value of a when choice must be a constant", whenChoice.position)) - else if (constvalue.type !in IntegerDatatypes) - checkResult.add(SyntaxError("value of a when choice must be a byte or word", whenChoice.position)) + when { + constvalue == null -> checkResult.add(SyntaxError("choice value must be a constant", whenChoice.position)) + constvalue.type !in IntegerDatatypes -> checkResult.add(SyntaxError("choice value must be a byte or word", whenChoice.position)) + constvalue.type != conditionType -> checkResult.add(SyntaxError("choice value datatype differs from condition value", whenChoice.position)) + } } else { - if(whenChoice !== (whenChoice.parent as WhenStatement).choices.last()) + if(whenChoice !== whenStmt.choices.last()) checkResult.add(SyntaxError("else choice must be the last one", whenChoice.position)) } super.visit(whenChoice) diff --git a/compiler/src/prog8/compiler/Compiler.kt b/compiler/src/prog8/compiler/Compiler.kt index 99a9584b5..7e9f679c8 100644 --- a/compiler/src/prog8/compiler/Compiler.kt +++ b/compiler/src/prog8/compiler/Compiler.kt @@ -2095,22 +2095,20 @@ internal class Compiler(private val program: Program) { val endOfWhenLabel = makeLabel(whenstmt, "when_end") val choiceLabels = mutableListOf() - var previousValue = 0 for(choice in whenstmt.choiceValues(program)) { val choiceVal = choice.first if(choiceVal==null) { // the else clause translate(choice.second.statements) } else { - val subtract = choiceVal-previousValue - previousValue = choiceVal + val rval = RuntimeValue(conditionDt!!, choiceVal) if (conditionDt in ByteDatatypes) { prog.instr(Opcode.DUP_B) - prog.instr(opcodeCompare(conditionDt!!), RuntimeValue(conditionDt, subtract)) + prog.instr(opcodeCompare(conditionDt), rval) } else { prog.instr(Opcode.DUP_W) - prog.instr(opcodeCompare(conditionDt!!), RuntimeValue(conditionDt, subtract)) + prog.instr(opcodeCompare(conditionDt), rval) } val choiceLabel = makeLabel(whenstmt, "choice_$choiceVal") choiceLabels.add(choiceLabel) @@ -2120,12 +2118,12 @@ internal class Compiler(private val program: Program) { prog.instr(Opcode.JUMP, callLabel = endOfWhenLabel) for(choice in whenstmt.choices.zip(choiceLabels)) { - // TODO the various code blocks here, don't forget to jump to the end label at their eind prog.label(choice.second) - prog.instr(Opcode.NOP) + translate(choice.first.statements) prog.instr(Opcode.JUMP, callLabel = endOfWhenLabel) } + prog.removeLastInstruction() // remove the last jump, that can fall through to here prog.label(endOfWhenLabel) if (conditionDt in ByteDatatypes) prog.instr(Opcode.DISCARD_BYTE) diff --git a/compiler/src/prog8/vm/RuntimeValue.kt b/compiler/src/prog8/vm/RuntimeValue.kt index 5cc39e23f..15aeeb26c 100644 --- a/compiler/src/prog8/vm/RuntimeValue.kt +++ b/compiler/src/prog8/vm/RuntimeValue.kt @@ -53,27 +53,37 @@ open class RuntimeValue(val type: DataType, num: Number?=null, val str: String?= init { when(type) { DataType.UBYTE -> { - byteval = (num!!.toInt() and 255).toShort() + val inum = num!!.toInt() + if(inum !in 0 .. 255) + throw IllegalArgumentException("invalid value for ubyte: $inum") + byteval = inum.toShort() wordval = null floatval = null asBoolean = byteval != 0.toShort() } DataType.BYTE -> { - val v = num!!.toInt() and 255 - byteval = (if(v<128) v else v-256).toShort() + val inum = num!!.toInt() + if(inum !in -128 .. 127) + throw IllegalArgumentException("invalid value for byte: $inum") + byteval = inum.toShort() wordval = null floatval = null asBoolean = byteval != 0.toShort() } DataType.UWORD -> { - wordval = num!!.toInt() and 65535 + val inum = num!!.toInt() + if(inum !in 0 .. 65536) + throw IllegalArgumentException("invalid value for uword: $inum") + wordval = inum byteval = null floatval = null asBoolean = wordval != 0 } DataType.WORD -> { - val v = num!!.toInt() and 65535 - wordval = if(v<32768) v else v - 65536 + val inum = num!!.toInt() + if(inum !in -32768 .. 32767) + throw IllegalArgumentException("invalid value for word: $inum") + wordval = inum byteval = null floatval = null asBoolean = wordval != 0 diff --git a/examples/test.p8 b/examples/test.p8 index 3a76bfd17..c781fe555 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -4,31 +4,43 @@ ~ main { sub start() { - A=100 - Y=22 - uword uw = (A as uword)*Y + ubyte aa = 100 + ubyte yy = 22 + uword uw = (aa as uword)*yy c64scr.print("stack (255?): ") c64scr.print_ub(X) c64.CHROUT('\n') + aa=30 + yy=2 + + c64scr.print_ub(aa+yy) + c64scr.print("?: ") + check(aa, yy) + aa+=9 + c64scr.print_ub(aa+yy) + c64scr.print("?: ") + check(aa, yy) + c64scr.print_uw(uw) c64scr.print("?: ") when uw { 12345 -> c64scr.print("12345") 12346 -> c64scr.print("12346") 2200 -> c64scr.print("2200") + 2202 -> c64scr.print("2202") 12347 -> c64scr.print("12347") - else -> c64scr.print("else") + else -> c64scr.print("not in table") } c64.CHROUT('\n') - A=30 - Y=2 + c64scr.print("stack (255?): ") + c64scr.print_ub(X) + } - c64scr.print_ub(A+Y) - c64scr.print("?: ") - when A+Y { + sub check(ubyte a, ubyte y) { + when a+y { 10 -> { c64scr.print("ten") } @@ -44,16 +56,10 @@ 56 -> { ; should be optimized away } - 57243 -> { - ; should be optimized away - } else -> { - c64scr.print("!??!\n") + c64scr.print("not in table") } } c64.CHROUT('\n') - - c64scr.print("stack (255?): ") - c64scr.print_ub(X) } }