diff --git a/doc/lang/operators.md b/doc/lang/operators.md index da34df92..43703cbc 100644 --- a/doc/lang/operators.md +++ b/doc/lang/operators.md @@ -66,7 +66,8 @@ If and only if both `h` and `l` are assignable expressions, then `h:l` is also a * `+`, `-`: `byte + byte` `constant word + constant word` -`constant long + constant long` +`constant long + constant long` +`word + word` (zpreg) * `*`: multiplication; the size of the result is the same as the size of the arguments `byte * constant byte` @@ -82,7 +83,8 @@ There are no division, remainder or modulo operators. * `|`, `^`, `&`: OR, EXOR and AND `byte | byte` `constant word | constant word` -`constant long | constant long` +`constant long | constant long` +`word | word` (zpreg) * `<<`, `>>`: bit shifting; shifting pads the result with zeroes `byte << constant byte` @@ -103,7 +105,8 @@ These operators work using the decimal arithmetic and will not work on Ricoh CPU * `+'`, `-'`: decimal addition/subtraction `byte +' byte` `constant word +' constant word` -`constant long +' constant long` +`constant long +' constant long` +`word +' word` (zpreg) * `*'`: decimal multiplication `constant *' constant` diff --git a/src/main/scala/millfork/compiler/BuiltIns.scala b/src/main/scala/millfork/compiler/BuiltIns.scala index e7b3e4e6..85b86447 100644 --- a/src/main/scala/millfork/compiler/BuiltIns.scala +++ b/src/main/scala/millfork/compiler/BuiltIns.scala @@ -182,7 +182,7 @@ object BuiltIns { val firstParamCompiled = ExpressionCompiler.compile(ctx, l, Some(b -> RegisterVariable(Register.A, b)), NoBranching) ctx.env.eval(r) match { case Some(NumericConstant(0, _)) => - Nil + ExpressionCompiler.compile(ctx, l, None, NoBranching) case Some(NumericConstant(v, _)) if v > 0 => firstParamCompiled ++ List.fill(v.toInt)(AssemblyLine.implied(opcode)) case _ => @@ -216,6 +216,7 @@ object BuiltIns { } } + @deprecated def compileNonetLeftShift(ctx: CompilationContext, lhs: Expression, rhs: Expression): List[AssemblyLine] = { val label = MfCompiler.nextLabel("sh") compileShiftOps(ASL, ctx, lhs, rhs) ++ List( @@ -506,9 +507,9 @@ object BuiltIns { val b = ctx.env.get[Type]("byte") ctx.env.eval(addend) match { case Some(NumericConstant(0, _)) => - AssemblyLine.immediate(LDA, 0) :: ExpressionCompiler.compileByteStorage(ctx, Register.A, v) + ExpressionCompiler.compile(ctx, v, None, NoBranching) ++ (AssemblyLine.immediate(LDA, 0) :: ExpressionCompiler.compileByteStorage(ctx, Register.A, v)) case Some(NumericConstant(1, _)) => - Nil + ExpressionCompiler.compile(ctx, v, None, NoBranching) case Some(NumericConstant(x, _)) => compileByteMultiplication(ctx, v, x.toInt) ++ ExpressionCompiler.compileByteStorage(ctx, Register.A, v) case _ => @@ -575,7 +576,7 @@ object BuiltIns { case _ => false } env.eval(addend) match { - case Some(NumericConstant(0, _)) => Nil + case Some(NumericConstant(0, _)) => ExpressionCompiler.compile(ctx, v, None, NoBranching) case Some(NumericConstant(1, _)) if lhsIsDirectlyIncrementable && !decimal => if (subtract) { simpleOperation(DEC, ctx, v, IndexChoice.RequireX, preserveA = false, commutative = true) } else { @@ -653,7 +654,7 @@ object BuiltIns { val (calculateRhs, addendByteRead0): (List[AssemblyLine], List[List[AssemblyLine]]) = env.eval(addend) match { case Some(NumericConstant(0, _)) => - return Nil + return ExpressionCompiler.compile(ctx, lhs, None, NoBranching) case Some(NumericConstant(1, _)) if canUseIncDec && !subtract => if (ctx.options.flags(CompilationFlag.Emit65CE02Opcodes)) { targetBytes match { @@ -968,7 +969,7 @@ object BuiltIns { case (EOR, Some(NumericConstant(0, _))) | (ORA, Some(NumericConstant(0, _))) | (AND, Some(NumericConstant(AllOnes, _))) => - Nil + ExpressionCompiler.compile(ctx, lhs, None, NoBranching) case _ => val buffer = mutable.ListBuffer[AssemblyLine]() buffer ++= calculateRhs diff --git a/src/main/scala/millfork/compiler/ExpressionCompiler.scala b/src/main/scala/millfork/compiler/ExpressionCompiler.scala index fa9a0c95..53481f28 100644 --- a/src/main/scala/millfork/compiler/ExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/ExpressionCompiler.scala @@ -38,15 +38,21 @@ object ExpressionCompiler { if (getExpressionType(ctx, hi).size > 1) ErrorReporting.error("Hi byte too large", hi.position) if (getExpressionType(ctx, lo).size > 1) ErrorReporting.error("Lo byte too large", lo.position) w - case SumExpression(params, _) => b + case SumExpression(params, _) => params.map { case (_, e) => getExpressionType(ctx, e).size }.max match { + case 1 => b + case 2 => w + case _ => ErrorReporting.error("Adding values bigger than words", expr.position); w + } case FunctionCallExpression("nonet", params) => w case FunctionCallExpression("not", params) => bool - case FunctionCallExpression("hi", params) => bool - case FunctionCallExpression("lo", params) => bool + case FunctionCallExpression("hi", params) => b + case FunctionCallExpression("lo", params) => b case FunctionCallExpression("*", params) => b - case FunctionCallExpression("|", params) => b - case FunctionCallExpression("&", params) => b - case FunctionCallExpression("^", params) => b + case FunctionCallExpression("|" | "&" | "^", params) => params.map { e => getExpressionType(ctx, e).size }.max match { + case 1 => b + case 2 => w + case _ => ErrorReporting.error("Adding values bigger than words", expr.position); w + } case FunctionCallExpression("<<", List(a1, a2)) => if (getExpressionType(ctx, a2).size > 1) ErrorReporting.error("Shift amount too large", a2.position) getExpressionType(ctx, a1) @@ -633,7 +639,7 @@ object ExpressionCompiler { val (variableIndex, constantIndex) = env.evalVariableAndConstantSubParts(indexExpr) val variableIndexSize = variableIndex.map(v => getExpressionType(ctx, v).size).getOrElse(0) val totalIndexSize = getExpressionType(ctx, indexExpr).size - exprTypeAndVariable.fold(noop) { case (exprType, target) => + exprTypeAndVariable.fold(compile(ctx, indexExpr, None, BranchSpec.None)) { case (exprType, target) => val register = target match { case RegisterVariable(r, _) => r @@ -734,10 +740,16 @@ object ExpressionCompiler { } exprTypeAndVariable.map(x => compileConstant(ctx, value.quickSimplify, x._2)).getOrElse(Nil) } else { - assertAllBytesForSum("Long addition not supported", ctx, params) - val calculate = BuiltIns.compileAddition(ctx, params, decimal = decimal) - val store = expressionStorageFromAX(ctx, exprTypeAndVariable, expr.position) - calculate ++ store + getSumSize(ctx, params) match { + case 1 => + val calculate = BuiltIns.compileAddition(ctx, params, decimal = decimal) + val store = expressionStorageFromAX(ctx, exprTypeAndVariable, expr.position) + calculate ++ store + case 2 => + val calculate = PseudoregisterBuiltIns.compileWordAdditionToAX(ctx, params, decimal = decimal) + val store = expressionStorageFromAX(ctx, exprTypeAndVariable, expr.position) + calculate ++ store + } } case SeparateBytesExpression(h, l) => exprTypeAndVariable.fold { @@ -830,7 +842,7 @@ object ExpressionCompiler { AssemblyLine.relative(BCC, label), AssemblyLine.implied(INX), AssemblyLine.label(label) - ) ++ expressionStorageFromAX(ctx, exprTypeAndVariable, expr.position) + ) } case "&&" => assertBool(ctx, params, 2) @@ -860,17 +872,23 @@ object ExpressionCompiler { } case "^^" => ??? case "&" => - assertAllBytes("Long bit ops not supported", ctx, params) - BuiltIns.compileBitOps(AND, ctx, params) + getParamMaxSize(ctx, params) match { + case 1 => BuiltIns.compileBitOps(AND, ctx, params) + case 2 => PseudoregisterBuiltIns.compileWordBitOpsToAX(ctx, params, AND) + } case "*" => assertAllBytes("Long multiplication not supported", ctx, params) BuiltIns.compileByteMultiplication(ctx, params) case "|" => - assertAllBytes("Long bit ops not supported", ctx, params) - BuiltIns.compileBitOps(ORA, ctx, params) + getParamMaxSize(ctx, params) match { + case 1 => BuiltIns.compileBitOps(ORA, ctx, params) + case 2 => PseudoregisterBuiltIns.compileWordBitOpsToAX(ctx, params, ORA) + } case "^" => - assertAllBytes("Long bit ops not supported", ctx, params) - BuiltIns.compileBitOps(EOR, ctx, params) + getParamMaxSize(ctx, params) match { + case 1 => BuiltIns.compileBitOps(EOR, ctx, params) + case 2 => PseudoregisterBuiltIns.compileWordBitOpsToAX(ctx, params, EOR) + } case ">>>>" => val (l, r, 2) = assertBinary(ctx, params) l match { @@ -1336,10 +1354,12 @@ object ExpressionCompiler { } } - private def assertAllBytesForSum(msg: String, ctx: CompilationContext, params: List[(Boolean, Expression)]): Unit = { - if (params.exists { case (_, expr) => getExpressionType(ctx, expr).size != 1 }) { - ErrorReporting.fatal(msg, params.head._2.position) - } + private def getParamMaxSize(ctx: CompilationContext, params: List[Expression]): Int = { + params.map { case expr => getExpressionType(ctx, expr).size}.max + } + + private def getSumSize(ctx: CompilationContext, params: List[(Boolean, Expression)]): Int = { + params.map { case (_, expr) => getExpressionType(ctx, expr).size}.max } private def assertAllBytes(msg: String, ctx: CompilationContext, params: List[Expression]): Unit = { diff --git a/src/main/scala/millfork/compiler/PseudoregisterBuiltIns.scala b/src/main/scala/millfork/compiler/PseudoregisterBuiltIns.scala index 6bf0648f..10836eaa 100644 --- a/src/main/scala/millfork/compiler/PseudoregisterBuiltIns.scala +++ b/src/main/scala/millfork/compiler/PseudoregisterBuiltIns.scala @@ -14,6 +14,150 @@ import millfork.assembly.AddrMode._ */ object PseudoregisterBuiltIns { + def compileWordAdditionToAX(ctx: CompilationContext, params: List[(Boolean, Expression)], decimal: Boolean): List[AssemblyLine] = { + if (!ctx.options.flag(CompilationFlag.ZeropagePseudoregister)) { + ErrorReporting.error("Word addition or subtraction requires the zeropage pseudoregister", params.headOption.flatMap(_._2.position)) + return Nil + } + if (params.isEmpty) { + return List(AssemblyLine.immediate(LDA, 0), AssemblyLine.immediate(LDX, 0)) + } + val b = ctx.env.get[Type]("byte") + val w = ctx.env.get[Type]("word") + val reg = ctx.env.get[VariableInMemory]("__reg") + val head = params.head match { + case (false, e) => ExpressionCompiler.compile(ctx, e, Some(w -> reg), BranchSpec.None) + case (true, e) => ??? + } + params.tail.foldLeft[List[AssemblyLine]](head){case (code, (sub, param)) => code ++ addToReg(ctx, param, sub, decimal)} ++ List( + AssemblyLine.zeropage(LDA, reg), + AssemblyLine.zeropage(LDX, reg, 1), + ) + } + + def addToReg(ctx: CompilationContext, r: Expression, subtract: Boolean, decimal: Boolean): List[AssemblyLine] = { + if (!ctx.options.flag(CompilationFlag.ZeropagePseudoregister)) { + ErrorReporting.error("Word addition or subtraction requires the zeropage pseudoregister", r.position) + return Nil + } + val b = ctx.env.get[Type]("byte") + val w = ctx.env.get[Type]("word") + val reg = ctx.env.get[VariableInMemory]("__reg") + // TODO: smarter on 65816 + val compileRight = ExpressionCompiler.compile(ctx, r, Some(w -> reg), BranchSpec.None) + val op = if (subtract) SBC else ADC + val prepareCarry = AssemblyLine.implied(if (subtract) SEC else CLC) + compileRight match { + + case List( + AssemblyLine(LDA, Immediate, NumericConstant(0, _), _), + AssemblyLine(STA, ZeroPage, _, _), + AssemblyLine(LDX | LDA, Immediate, NumericConstant(0, _), _), + AssemblyLine(STA | STX, ZeroPage, _, _)) => Nil + + case List( + l@AssemblyLine(LDA, _, _, _), + AssemblyLine(STA, ZeroPage, _, _), + h@AssemblyLine(LDX | LDA, addrMode, _, _), + AssemblyLine(STA | STX, ZeroPage, _, _)) if addrMode != ZeroPageY => BuiltIns.wrapInSedCldIfNeeded(decimal, + List(prepareCarry, + AssemblyLine.zeropage(LDA, reg), + l.copy(opcode = op), + AssemblyLine.zeropage(STA, reg), + AssemblyLine.zeropage(LDA, reg, 1), + h.copy(opcode = op), + AssemblyLine.zeropage(STA, reg, 1), + AssemblyLine.zeropage(LDA, reg))) + + case _ => BuiltIns.wrapInSedCldIfNeeded(decimal, + List( + AssemblyLine.zeropage(LDA, reg, 1), + AssemblyLine.implied(PHA), + AssemblyLine.zeropage(LDA, reg), + AssemblyLine.implied(PHA)) ++ compileRight ++ List( + prepareCarry, + AssemblyLine.implied(PLA), + AssemblyLine.zeropage(op, reg), + AssemblyLine.zeropage(STA, reg), + AssemblyLine.implied(PLA), + AssemblyLine.zeropage(op, reg, 1), + AssemblyLine.zeropage(STA, reg, 1))) + } + } + + + def compileWordBitOpsToAX(ctx: CompilationContext, params: List[Expression], op: Opcode.Value): List[AssemblyLine] = { + if (!ctx.options.flag(CompilationFlag.ZeropagePseudoregister)) { + ErrorReporting.error("Word bit operation requires the zeropage pseudoregister", params.headOption.flatMap(_.position)) + return Nil + } + if (params.isEmpty) { + return List(AssemblyLine.immediate(LDA, 0), AssemblyLine.immediate(LDX, 0)) + } + val b = ctx.env.get[Type]("byte") + val w = ctx.env.get[Type]("word") + val reg = ctx.env.get[VariableInMemory]("__reg") + val head = ExpressionCompiler.compile(ctx, params.head, Some(w -> reg), BranchSpec.None) + params.tail.foldLeft[List[AssemblyLine]](head){case (code, param) => code ++ bitOpReg(ctx, param, op)} ++ List( + AssemblyLine.zeropage(LDA, reg), + AssemblyLine.zeropage(LDX, reg, 1), + ) + } + + def bitOpReg(ctx: CompilationContext, r: Expression, op: Opcode.Value): List[AssemblyLine] = { + if (!ctx.options.flag(CompilationFlag.ZeropagePseudoregister)) { + ErrorReporting.error("Word bit operation requires the zeropage pseudoregister", r.position) + return Nil + } + val b = ctx.env.get[Type]("byte") + val w = ctx.env.get[Type]("word") + val reg = ctx.env.get[VariableInMemory]("__reg") + // TODO: smarter on 65816 + val compileRight = ExpressionCompiler.compile(ctx, r, Some(w -> reg), BranchSpec.None) + compileRight match { + case List( + AssemblyLine(LDA, Immediate, NumericConstant(0, _), _), + AssemblyLine(STA, ZeroPage, _, _), + AssemblyLine(LDX | LDA, Immediate, NumericConstant(0, _), _), + AssemblyLine(STA | STX, ZeroPage, _, _)) + if op != AND => Nil + + case List( + AssemblyLine(LDA, Immediate, NumericConstant(0xff, _), _), + AssemblyLine(STA, ZeroPage, _, _), + AssemblyLine(LDX | LDA, Immediate, NumericConstant(0xff, _), _), + AssemblyLine(STA | STX, ZeroPage, _, _)) + if op == AND => Nil + + case List( + l@AssemblyLine(LDA, _, _, _), + AssemblyLine(STA, ZeroPage, _, _), + h@AssemblyLine(LDX | LDA, addrMode, _, _), + AssemblyLine(STA | STX, ZeroPage, _, _)) if addrMode != ZeroPageY => + List( + AssemblyLine.zeropage(LDA, reg), + l.copy(opcode = op), + AssemblyLine.zeropage(STA, reg), + AssemblyLine.zeropage(LDA, reg, 1), + h.copy(opcode = op), + AssemblyLine.zeropage(STA, reg, 1), + AssemblyLine.zeropage(LDA, reg)) + + case _ => + List( + AssemblyLine.zeropage(LDA, reg, 1), + AssemblyLine.implied(PHA), + AssemblyLine.zeropage(LDA, reg), + AssemblyLine.implied(PHA)) ++ compileRight ++ List( + AssemblyLine.implied(PLA), + AssemblyLine.zeropage(op, reg), + AssemblyLine.zeropage(STA, reg), + AssemblyLine.implied(PLA), + AssemblyLine.zeropage(op, reg, 1), + AssemblyLine.zeropage(STA, reg, 1)) + } + } + def compileWordShiftOps(left: Boolean, ctx: CompilationContext, l: Expression, r: Expression): List[AssemblyLine] = { if (!ctx.options.flag(CompilationFlag.ZeropagePseudoregister)) { ErrorReporting.error("Word shifting requires the zeropage pseudoregister", l.position) diff --git a/src/test/scala/millfork/test/WordMathSuite.scala b/src/test/scala/millfork/test/WordMathSuite.scala index f8e774d2..6375d9fc 100644 --- a/src/test/scala/millfork/test/WordMathSuite.scala +++ b/src/test/scala/millfork/test/WordMathSuite.scala @@ -132,4 +132,36 @@ class WordMathSuite extends FunSuite with Matchers { m.readByte(0xc006) should equal(0) } } + + test("Word addition 2") { + EmuBenchmarkRun(""" + | word output @$c000 + | void main () { + | word v + | v = w($482) + | output = v + w($482) - 3 + | } + | noinline word w(word w) { + | return w + | } + """.stripMargin){ m => + m.readWord(0xc000) should equal(0x901) + } + } + + test("Word bit ops 2") { + EmuBenchmarkRun(""" + | word output @$c000 + | void main () { + | word v + | v = w($692) + | output = (v & w($ca2)) | 3 + | } + | noinline word w(word w) { + | return w + | } + """.stripMargin){ m => + m.readWord(0xc000) should equal(0x483) + } + } }