diff --git a/compiler/src/prog8/compiler/astprocessing/TypecastsAdder.kt b/compiler/src/prog8/compiler/astprocessing/TypecastsAdder.kt index 83a06efa7..74288107f 100644 --- a/compiler/src/prog8/compiler/astprocessing/TypecastsAdder.kt +++ b/compiler/src/prog8/compiler/astprocessing/TypecastsAdder.kt @@ -375,6 +375,21 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val return noModifications } + override fun after(whenChoice: WhenChoice, parent: Node): Iterable { + if((parent as When).condition.inferType(program).isWords) { + val values = whenChoice.values + values?.toTypedArray()?.withIndex()?.forEach { (index, value) -> + val num = value.constValue(program) + if(num!=null && num.type in ByteDatatypes) { + val wordNum = NumericLiteral(if(num.type==DataType.UBYTE) DataType.UWORD else DataType.WORD, num.number, num.position) + wordNum.parent = num.parent + values[index] = wordNum + } + } + } + return noModifications + } + private fun addTypecastOrCastedValueModification( modifications: MutableList, expressionToCast: Expression, diff --git a/compiler/test/TestTypecasts.kt b/compiler/test/TestTypecasts.kt index 5455fdf0b..4db763268 100644 --- a/compiler/test/TestTypecasts.kt +++ b/compiler/test/TestTypecasts.kt @@ -1001,4 +1001,19 @@ main { compileText(VMTarget(), false, text, writeAssembly = true) shouldNotBe null compileText(VMTarget(), true, text, writeAssembly = true) shouldNotBe null } + + test("byte when choices silently converted to word for convenience") { + var text=""" +main { + sub start() { + uword z = 3 + when z { + 1-> z++ + 2-> z++ + else -> z++ + } + } +}""" + compileText(C64Target(), false, text, writeAssembly = false) shouldNotBe null + } }) diff --git a/compiler/test/ast/TestProg8Parser.kt b/compiler/test/ast/TestProg8Parser.kt index 984d19b70..ec6d4de6f 100644 --- a/compiler/test/ast/TestProg8Parser.kt +++ b/compiler/test/ast/TestProg8Parser.kt @@ -311,7 +311,7 @@ class TestProg8Parser: FunSpec( { sub start() { ubyte foo = 42 ubyte bar - when (foo) { + when foo { 23 -> bar = 'x' 42 -> bar = 'y' else -> bar = 'z'