diff --git a/codeOptimizers/src/prog8/optimizer/UnusedCodeRemover.kt b/codeOptimizers/src/prog8/optimizer/UnusedCodeRemover.kt index d42dadf84..10c301681 100644 --- a/codeOptimizers/src/prog8/optimizer/UnusedCodeRemover.kt +++ b/codeOptimizers/src/prog8/optimizer/UnusedCodeRemover.kt @@ -157,6 +157,12 @@ class UnusedCodeRemover(private val program: Program, return noModifications } + override fun after(assignment: Assignment, parent: Node): Iterable { + if(assignment.target isSameAs assignment.value) + return listOf(IAstModification.Remove(assignment, parent as IStatementContainer)) + return noModifications + } + private fun deduplicateAssignments(statements: List, scope: IStatementContainer): List { // removes 'duplicate' assignments that assign the same target directly after another, unless it is a function call val linesToRemove = mutableListOf() diff --git a/compiler/src/prog8/compiler/astprocessing/AstChecker.kt b/compiler/src/prog8/compiler/astprocessing/AstChecker.kt index 511b0ceab..4fdc2fc84 100644 --- a/compiler/src/prog8/compiler/astprocessing/AstChecker.kt +++ b/compiler/src/prog8/compiler/astprocessing/AstChecker.kt @@ -479,7 +479,9 @@ internal class AstChecker(private val program: Program, errors.err("target datatype is unknown", assignment.target.position) // otherwise, another error about missing symbol is already reported. } else { - errors.err("type of value $valueDt doesn't match target $targetDt", assignment.value.position) + // allow bitwise operations on different types as long as the size is the same + if (!((assignment.value as? BinaryExpression)?.operator in BitwiseOperators && targetDt.isBytes && valueDt.isBytes || targetDt.isWords && valueDt.isWords)) + errors.err("type of value $valueDt doesn't match target $targetDt", assignment.value.position) } } } @@ -1598,8 +1600,11 @@ internal class AstChecker(private val program: Program, else if(sourceDatatype== DataType.FLOAT && targetDatatype in IntegerDatatypes) errors.err("cannot assign float to ${targetDatatype.name.lowercase()}; possible loss of precision. Suggestion: round the value or revert to integer arithmetic", position) else { - if(targetDatatype!=DataType.UWORD && sourceDatatype !in PassByReferenceDatatypes) - errors.err("type of value $sourceDatatype doesn't match target $targetDatatype", position) + if(targetDatatype!=DataType.UWORD && sourceDatatype !in PassByReferenceDatatypes) { + // allow bitwise operations on different types as long as the size is the same + if (!((sourceValue as? BinaryExpression)?.operator in BitwiseOperators && targetDatatype.equalsSize(sourceDatatype))) + errors.err("type of value $sourceDatatype doesn't match target $targetDatatype", position) + } } return false diff --git a/compiler/src/prog8/compiler/astprocessing/BeforeAsmAstChanger.kt b/compiler/src/prog8/compiler/astprocessing/BeforeAsmAstChanger.kt index e1a986b40..d4369edce 100644 --- a/compiler/src/prog8/compiler/astprocessing/BeforeAsmAstChanger.kt +++ b/compiler/src/prog8/compiler/astprocessing/BeforeAsmAstChanger.kt @@ -93,6 +93,11 @@ internal class BeforeAsmAstChanger(val program: Program, ) } } else { + if(binExpr.left isSameAs assignment.target) + return noModifications + val typeCast = binExpr.left as? TypecastExpression + if(typeCast!=null && typeCast.expression isSameAs assignment.target) + return noModifications val sourceDt = binExpr.left.inferType(program).getOrElse { throw AssemblyError("unknown dt") } val (_, left) = binExpr.left.typecastTo(assignment.target.inferType(program).getOrElse { throw AssemblyError( "unknown dt" diff --git a/compiler/src/prog8/compiler/astprocessing/TypecastsAdder.kt b/compiler/src/prog8/compiler/astprocessing/TypecastsAdder.kt index af6f6c8e7..3e5f2dd2d 100644 --- a/compiler/src/prog8/compiler/astprocessing/TypecastsAdder.kt +++ b/compiler/src/prog8/compiler/astprocessing/TypecastsAdder.kt @@ -133,9 +133,16 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val return listOf(IAstModification.ReplaceNode(expr.left, cast, expr)) } if(leftDt istype DataType.WORD && rightDt.oneOf(DataType.UBYTE, DataType.UWORD)) { - // cast left to unsigned - val cast = TypecastExpression(expr.left, rightDt.getOr(DataType.UNDEFINED), true, expr.left.position) - return listOf(IAstModification.ReplaceNode(expr.left, cast, expr)) + // cast left to unsigned word. Cast right to unsigned word if it is ubyte + val mods = mutableListOf() + val cast = TypecastExpression(expr.left, DataType.UWORD, true, expr.left.position) + mods += IAstModification.ReplaceNode(expr.left, cast, expr) + if(rightDt istype DataType.UBYTE) { + mods += IAstModification.ReplaceNode(expr.right, + TypecastExpression(expr.right, DataType.UWORD, true, expr.right.position), + expr) + } + return mods } if(rightDt istype DataType.BYTE && leftDt.oneOf(DataType.UBYTE, DataType.UWORD)) { // cast right to unsigned @@ -143,9 +150,16 @@ class TypecastsAdder(val program: Program, val options: CompilationOptions, val return listOf(IAstModification.ReplaceNode(expr.right, cast, expr)) } if(rightDt istype DataType.WORD && leftDt.oneOf(DataType.UBYTE, DataType.UWORD)) { - // cast right to unsigned - val cast = TypecastExpression(expr.right, leftDt.getOr(DataType.UNDEFINED), true, expr.right.position) - return listOf(IAstModification.ReplaceNode(expr.right, cast, expr)) + // cast right to unsigned word. Cast left to unsigned word if it is ubyte + val mods = mutableListOf() + val cast = TypecastExpression(expr.right, DataType.UWORD, true, expr.right.position) + mods += IAstModification.ReplaceNode(expr.right, cast, expr) + if(leftDt istype DataType.UBYTE) { + mods += IAstModification.ReplaceNode(expr.left, + TypecastExpression(expr.left, DataType.UWORD, true, expr.left.position), + expr) + } + return mods } } diff --git a/examples/test.p8 b/examples/test.p8 index c1ae9a19f..d1bea1db7 100644 --- a/examples/test.p8 +++ b/examples/test.p8 @@ -1,16 +1,35 @@ -%import floats %import textio %zeropage basicsafe main { sub start() { - ubyte[] array = [ $00, $11, $22, $33, $44, $55, $66, $77, $88, $99, $aa, $bb] + ubyte ub = 123 + byte bb = -100 + uword uw = 12345 + word ww = -12345 - ubyte x = 2 - ubyte y = 3 - txt.print_uwhex(mkword(array[9], array[8]), true) - txt.print_uwhex(mkword(array[x*y+y], array[y*x+x]), true) + ub |= 63 ; vm/c64 ok (127) + bb |= 63 ; vm/c64 ok (-65) + uw |= 63 ; vm/c64 ok (12351) + ww |= 63 ; vm/c64 ok (-12289) + + txt.print_ub(ub) + txt.spc() + txt.print_b(bb) + txt.spc() + txt.print_uw(uw) + txt.spc() + txt.print_w(ww) + txt.nl() + + uw |= 16384 ; vm/c64 ok (28735) + ww |= 8192 ; vm/c64 ok (-4097) + + txt.print_uw(uw) + txt.spc() + txt.print_w(ww) + txt.nl() } }