diff --git a/docs/lang/types.md b/docs/lang/types.md index 154ae2f9..e9b18237 100644 --- a/docs/lang/types.md +++ b/docs/lang/types.md @@ -10,6 +10,9 @@ Millfork puts extra limitations on which types can be used in which contexts. * `word` – 2-byte value of undefined signedness, defaulting to unsigned +* `farword` – 4-byte value of undefined signedness, defaulting to unsigned +(the name is an analogy to a future 24-bit type called `farpointer`) + * `long` – 4-byte value of undefined signedness, defaulting to unsigned * `sbyte` – signed 1-byte value @@ -20,6 +23,8 @@ Millfork puts extra limitations on which types can be used in which contexts. and you can index `pointer` variables (not arbitrary `pointer`-typed expressions though, `f()[0]` won't compile) Functions cannot return types longer than 2 bytes. +There's also no reason to make a function return `pointer`, since to dereference it, +you need to put it in a variable first anyway. Numeric types can be converted automatically: diff --git a/src/main/scala/millfork/compiler/ExpressionCompiler.scala b/src/main/scala/millfork/compiler/ExpressionCompiler.scala index 40391462..6ee6440f 100644 --- a/src/main/scala/millfork/compiler/ExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/ExpressionCompiler.scala @@ -20,13 +20,13 @@ object ExpressionCompiler { val bool = env.get[Type]("bool$") val v = env.get[Type]("void") val w = env.get[Type]("word") - val l = env.get[Type]("long") expr match { case LiteralExpression(value, size) => size match { case 1 => b case 2 => w - case 3 | 4 => l + case 3 => env.get[Type]("farword") + case 4 => env.get[Type]("long") } case VariableExpression(name) => env.get[TypedThing](name, expr.position).typ @@ -524,7 +524,8 @@ object ExpressionCompiler { ErrorReporting.error(s"Variable `$target.name` is too small", expr.position) Nil } else { - val copy = List.tabulate(exprType.size)(i => AssemblyLine.variable(ctx, LDA, source, i) ++ AssemblyLine.variable(ctx, STA, target, i)) + val copyFromLo = List.tabulate(exprType.size)(i => AssemblyLine.variable(ctx, LDA, source, i) ++ AssemblyLine.variable(ctx, STA, target, i)) + val copy = if (shouldCopyFromHiToLo(source.toAddress, target.toAddress)) copyFromLo.reverse else copyFromLo val extend = if (exprType.size == target.typ.size) Nil else if (exprType.isSigned) { signExtendA() ++ List.tabulate(target.typ.size - exprType.size)(i => AssemblyLine.variable(ctx, STA, target, i + exprType.size)).flatten } else { @@ -627,7 +628,8 @@ object ExpressionCompiler { ErrorReporting.error(s"Variable `$target.name` is too small", expr.position) Nil } else { - val copy = List.tabulate(exprType.size)(i => List(AssemblyLine.absoluteX(LDA, offset + ctx.extraStackOffset + i), AssemblyLine.absoluteX(STA, target.baseOffset + i))) + val copyFromLo = List.tabulate(exprType.size)(i => List(AssemblyLine.absoluteX(LDA, offset + ctx.extraStackOffset + i), AssemblyLine.absoluteX(STA, target.baseOffset + i))) + val copy = if (shouldCopyFromHiToLo(NumericConstant(source.baseOffset, 2), NumericConstant(target.baseOffset, 2))) copyFromLo.reverse else copyFromLo val extend = if (exprType.size == target.typ.size) Nil else if (exprType.isSigned) { signExtendA() ++ List.tabulate(target.typ.size - exprType.size)(i => AssemblyLine.absoluteX(STA, target.baseOffset + ctx.extraStackOffset + i + exprType.size)) } else { @@ -1466,4 +1468,17 @@ object ExpressionCompiler { AssemblyLine.immediate(LDA, 0), AssemblyLine.label(label)) } + + private def shouldCopyFromHiToLo(srcAddress: Constant, destAddress: Constant): Boolean = (srcAddress, destAddress) match { + case ( + CompoundConstant(MathOperator.Plus, a: MemoryAddressConstant, NumericConstant(s, _)), + CompoundConstant(MathOperator.Plus, b: MemoryAddressConstant, NumericConstant(d, _)) + ) if a == b => s < d + case ( + a: MemoryAddressConstant, + CompoundConstant(MathOperator.Plus, b: MemoryAddressConstant, NumericConstant(d, _)) + ) if a == b => 0 < d + case (NumericConstant(s, _), NumericConstant(d, _)) => s < d + case _ => false + } } diff --git a/src/main/scala/millfork/env/Environment.scala b/src/main/scala/millfork/env/Environment.scala index c8fdfac3..6ec4aa70 100644 --- a/src/main/scala/millfork/env/Environment.scala +++ b/src/main/scala/millfork/env/Environment.scala @@ -242,11 +242,12 @@ class Environment(val parent: Option[Environment], val prefix: String) { addThing(BuiltInBooleanType, None) addThing(BasicPlainType("byte", 1), None) addThing(BasicPlainType("word", 2), None) + addThing(BasicPlainType("farword", 3), None) addThing(BasicPlainType("long", 4), None) addThing(DerivedPlainType("pointer", get[PlainType]("word"), isSigned = false), None) +// addThing(DerivedPlainType("farpointer", get[PlainType]("farword"), isSigned = false), None) addThing(DerivedPlainType("ubyte", get[PlainType]("byte"), isSigned = false), None) addThing(DerivedPlainType("sbyte", get[PlainType]("byte"), isSigned = true), None) - addThing(DerivedPlainType("cent", get[PlainType]("byte"), isSigned = false), None) val trueType = ConstantBooleanType("true$", value = true) val falseType = ConstantBooleanType("false$", value = false) addThing(trueType, None) @@ -599,6 +600,7 @@ class Environment(val parent: Option[Environment], val prefix: String) { def registerParameter(stmt: ParameterDeclaration): Unit = { val typ = get[Type](stmt.typ) val b = get[Type]("byte") + val w = get[Type]("word") val p = get[Type]("pointer") stmt.assemblyParamPassingConvention match { case ByVariable(name) => @@ -606,10 +608,25 @@ class Environment(val parent: Option[Environment], val prefix: String) { val v = UninitializedMemoryVariable(prefix + name, typ, if (zp) VariableAllocationMethod.Zeropage else VariableAllocationMethod.Auto, None) addThing(v, stmt.position) registerAddressConstant(v, stmt.position) - if (typ.size == 2) { - val addr = v.toAddress - addThing(RelativeVariable(v.name + ".hi", addr + 1, b, zeropage = zp, None), stmt.position) - addThing(RelativeVariable(v.name + ".lo", addr, b, zeropage = zp, None), stmt.position) + val addr = v.toAddress + typ.size match { + case 2 => + addThing(RelativeVariable(v.name + ".hi", addr + 1, b, zeropage = zp, None), stmt.position) + addThing(RelativeVariable(v.name + ".lo", addr, b, zeropage = zp, None), stmt.position) + case 3 => + addThing(RelativeVariable(v.name + ".hiword", addr + 1, w, zeropage = zp, None), stmt.position) + addThing(RelativeVariable(v.name + ".loword", addr, w, zeropage = zp, None), stmt.position) + addThing(RelativeVariable(v.name + ".b2", addr + 2, b, zeropage = zp, None), stmt.position) + addThing(RelativeVariable(v.name + ".b1", addr + 1, b, zeropage = zp, None), stmt.position) + addThing(RelativeVariable(v.name + ".b0", addr, b, zeropage = zp, None), stmt.position) + case 4 => + addThing(RelativeVariable(v.name + ".hiword", addr + 2, w, zeropage = zp, None), stmt.position) + addThing(RelativeVariable(v.name + ".loword", addr, w, zeropage = zp, None), stmt.position) + addThing(RelativeVariable(v.name + ".b3", addr + 3, b, zeropage = zp, None), stmt.position) + addThing(RelativeVariable(v.name + ".b2", addr + 2, b, zeropage = zp, None), stmt.position) + addThing(RelativeVariable(v.name + ".b1", addr + 1, b, zeropage = zp, None), stmt.position) + addThing(RelativeVariable(v.name + ".b0", addr, b, zeropage = zp, None), stmt.position) + case _ => } case ByRegister(_) => () case ByConstant(name) => @@ -741,8 +758,9 @@ class Environment(val parent: Option[Environment], val prefix: String) { if (stmt.stack && stmt.global) ErrorReporting.error(s"`$name` is static or global and cannot be on stack", position) } val b = get[Type]("byte") + val w = get[Type]("word") val typ = get[PlainType](stmt.typ) - if (stmt.typ == "pointer") { + if (stmt.typ == "pointer" || stmt.typ == "farpointer") { // if (stmt.constant) { // ErrorReporting.error(s"Pointer `${stmt.name}` cannot be constant") // } @@ -761,9 +779,24 @@ class Environment(val parent: Option[Environment], val prefix: String) { val constantValue: Constant = stmt.initialValue.flatMap(eval).getOrElse(Constant.error(s"`$name` has a non-constant value", position)) if (constantValue.requiredSize > typ.size) ErrorReporting.error(s"`$name` is has an invalid value: not in the range of `$typ`", position) addThing(ConstantThing(prefix + name, constantValue, typ), stmt.position) - if (typ.size >= 2) { - addThing(ConstantThing(prefix + name + ".hi", constantValue.hiByte, b), stmt.position) - addThing(ConstantThing(prefix + name + ".lo", constantValue.loByte, b), stmt.position) + typ.size match { + case 2 => + addThing(ConstantThing(prefix + name + ".hi", constantValue.hiByte, b), stmt.position) + addThing(ConstantThing(prefix + name + ".lo", constantValue.loByte, b), stmt.position) + case 3 => + addThing(ConstantThing(prefix + name + ".hiword", constantValue.subword(1), b), stmt.position) + addThing(ConstantThing(prefix + name + ".loword", constantValue.subword(0), b), stmt.position) + addThing(ConstantThing(prefix + name + ".b2", constantValue.subbyte(2), b), stmt.position) + addThing(ConstantThing(prefix + name + ".b1", constantValue.hiByte, b), stmt.position) + addThing(ConstantThing(prefix + name + ".b0", constantValue.loByte, b), stmt.position) + case 4 => + addThing(ConstantThing(prefix + name + ".hiword", constantValue.subword(2), b), stmt.position) + addThing(ConstantThing(prefix + name + ".loword", constantValue.subword(0), b), stmt.position) + addThing(ConstantThing(prefix + name + ".b3", constantValue.subbyte(3), b), stmt.position) + addThing(ConstantThing(prefix + name + ".b2", constantValue.subbyte(2), b), stmt.position) + addThing(ConstantThing(prefix + name + ".b1", constantValue.hiByte, b), stmt.position) + addThing(ConstantThing(prefix + name + ".b0", constantValue.loByte, b), stmt.position) + case _ => } } else { if (stmt.stack && stmt.global) ErrorReporting.error(s"`$name` is static or global and cannot be on stack", position) @@ -777,9 +810,24 @@ class Environment(val parent: Option[Environment], val prefix: String) { val v = StackVariable(prefix + name, typ, this.baseStackOffset) baseStackOffset += typ.size addThing(v, stmt.position) - if (typ.size == 2) { - addThing(StackVariable(prefix + name + ".lo", b, baseStackOffset), stmt.position) - addThing(StackVariable(prefix + name + ".hi", b, baseStackOffset + 1), stmt.position) + typ.size match { + case 2 => + addThing(StackVariable(prefix + name + ".hi", b, baseStackOffset + 1), stmt.position) + addThing(StackVariable(prefix + name + ".lo", b, baseStackOffset), stmt.position) + case 3 => + addThing(StackVariable(prefix + name + ".hiword", w, baseStackOffset + 1), stmt.position) + addThing(StackVariable(prefix + name + ".loword", w, baseStackOffset), stmt.position) + addThing(StackVariable(prefix + name + ".b2", b, baseStackOffset + 2), stmt.position) + addThing(StackVariable(prefix + name + ".b1", b, baseStackOffset + 1), stmt.position) + addThing(StackVariable(prefix + name + ".b0", b, baseStackOffset), stmt.position) + case 4 => + addThing(StackVariable(prefix + name + ".hiword", w, baseStackOffset + 2), stmt.position) + addThing(StackVariable(prefix + name + ".loword", w, baseStackOffset), stmt.position) + addThing(StackVariable(prefix + name + ".b3", b, baseStackOffset + 3), stmt.position) + addThing(StackVariable(prefix + name + ".b2", b, baseStackOffset + 2), stmt.position) + addThing(StackVariable(prefix + name + ".b1", b, baseStackOffset + 1), stmt.position) + addThing(StackVariable(prefix + name + ".b0", b, baseStackOffset), stmt.position) + case _ => } } else { val (v, addr) = stmt.address.fold[(VariableInMemory, Constant)]({ @@ -815,11 +863,24 @@ class Environment(val parent: Option[Environment], val prefix: String) { if (!v.isInstanceOf[MemoryVariable]) { addThing(ConstantThing(v.name + "`", addr, b), stmt.position) } - if (typ.size == 2) { - addThing(RelativeVariable(prefix + name + ".hi", addr + 1, b, zeropage = v.zeropage, - declaredBank = stmt.bank), stmt.position) - addThing(RelativeVariable(prefix + name + ".lo", addr, b, zeropage = v.zeropage, - declaredBank = stmt.bank), stmt.position) + typ.size match { + case 2 => + addThing(RelativeVariable(prefix + name + ".hi", addr + 1, b, zeropage = v.zeropage, declaredBank = stmt.bank), stmt.position) + addThing(RelativeVariable(prefix + name + ".lo", addr, b, zeropage = v.zeropage, declaredBank = stmt.bank), stmt.position) + case 3 => + addThing(RelativeVariable(prefix + name + ".hiword", addr + 1, w, zeropage = v.zeropage, declaredBank = stmt.bank), stmt.position) + addThing(RelativeVariable(prefix + name + ".loword", addr, w, zeropage = v.zeropage, declaredBank = stmt.bank), stmt.position) + addThing(RelativeVariable(prefix + name + ".b2", addr + 2, b, zeropage = v.zeropage, declaredBank = stmt.bank), stmt.position) + addThing(RelativeVariable(prefix + name + ".b1", addr + 1, b, zeropage = v.zeropage, declaredBank = stmt.bank), stmt.position) + addThing(RelativeVariable(prefix + name + ".b0", addr, b, zeropage = v.zeropage, declaredBank = stmt.bank), stmt.position) + case 4 => + addThing(RelativeVariable(prefix + name + ".hiword", addr + 2, w, zeropage = v.zeropage, declaredBank = stmt.bank), stmt.position) + addThing(RelativeVariable(prefix + name + ".loword", addr, w, zeropage = v.zeropage, declaredBank = stmt.bank), stmt.position) + addThing(RelativeVariable(prefix + name + ".b3", addr + 3, b, zeropage = v.zeropage, declaredBank = stmt.bank), stmt.position) + addThing(RelativeVariable(prefix + name + ".b2", addr + 2, b, zeropage = v.zeropage, declaredBank = stmt.bank), stmt.position) + addThing(RelativeVariable(prefix + name + ".b1", addr + 1, b, zeropage = v.zeropage, declaredBank = stmt.bank), stmt.position) + addThing(RelativeVariable(prefix + name + ".b0", addr, b, zeropage = v.zeropage, declaredBank = stmt.bank), stmt.position) + case _ => } } } diff --git a/src/main/scala/millfork/parser/MfParser.scala b/src/main/scala/millfork/parser/MfParser.scala index 38bef43d..eaeb22f9 100644 --- a/src/main/scala/millfork/parser/MfParser.scala +++ b/src/main/scala/millfork/parser/MfParser.scala @@ -75,11 +75,12 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o } } - // TODO: 3-byte types - def size(value: Int, wordLiteral: Boolean, longLiteral: Boolean): Int = - if (value > 255 || value < -128 || wordLiteral) - if (value > 0xffff || longLiteral) 4 else 2 - else 1 + def size(value: Int, wordLiteral: Boolean, farwordLiteral: Boolean, longLiteral: Boolean): Int = { + val w = value > 255 || value < -0x80 || wordLiteral + val f = value > 0xffff || value < -0x8000 || farwordLiteral + val l = value > 0xffffff || value < -0x800000 || longLiteral + if (l) 4 else if (f) 3 else if (w) 2 else 1 + } def sign(abs: Int, minus: Boolean): Int = if (minus) -abs else abs @@ -91,7 +92,7 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o } yield { val abs = Integer.parseInt(s, 10) val value = sign(abs, minus.isDefined) - LiteralExpression(value, size(value, s.length > 3, s.length > 5)).pos(p) + LiteralExpression(value, size(value, s.length > 3, s.length > 5, s.length > 7)).pos(p) } val binaryAtom: P[LiteralExpression] = @@ -103,7 +104,7 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o } yield { val abs = Integer.parseInt(s, 2) val value = sign(abs, minus.isDefined) - LiteralExpression(value, size(value, s.length > 8, s.length > 16)).pos(p) + LiteralExpression(value, size(value, s.length > 8, s.length > 16, s.length > 24)).pos(p) } val hexAtom: P[LiteralExpression] = @@ -115,7 +116,7 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o } yield { val abs = Integer.parseInt(s, 16) val value = sign(abs, minus.isDefined) - LiteralExpression(value, size(value, s.length > 2, s.length > 4)).pos(p) + LiteralExpression(value, size(value, s.length > 2, s.length > 4, s.length > 6)).pos(p) } val octalAtom: P[LiteralExpression] = @@ -127,7 +128,7 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o } yield { val abs = Integer.parseInt(s, 8) val value = sign(abs, minus.isDefined) - LiteralExpression(value, size(value, s.length > 3, s.length > 6)).pos(p) + LiteralExpression(value, size(value, s.length > 3, s.length > 6, s.length > 9)).pos(p) } val quaternaryAtom: P[LiteralExpression] = @@ -139,7 +140,7 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o } yield { val abs = Integer.parseInt(s, 4) val value = sign(abs, minus.isDefined) - LiteralExpression(value, size(value, s.length > 4, s.length > 8)).pos(p) + LiteralExpression(value, size(value, s.length > 4, s.length > 8, s.length > 12)).pos(p) } val literalAtom: P[LiteralExpression] = charAtom | binaryAtom | hexAtom | octalAtom | quaternaryAtom | decimalAtom diff --git a/src/test/scala/millfork/test/FarwordTest.scala b/src/test/scala/millfork/test/FarwordTest.scala new file mode 100644 index 00000000..e338bace --- /dev/null +++ b/src/test/scala/millfork/test/FarwordTest.scala @@ -0,0 +1,218 @@ +package millfork.test + +import millfork.test.emu.EmuBenchmarkRun +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class FarwordTest extends FunSuite with Matchers { + + test("Farword assignment") { + EmuBenchmarkRun( + """ + | farword output3 @$c000 + | farword output2 @$c004 + | farword output1 @$c008 + | void main () { + | output3 = $223344 + | output2 = $223344 + | output1 = $223344 + | output2 = $7788 + | output1 = $55 + | } + """.stripMargin) { m => + m.readMedium(0xc000) should equal(0x223344) + m.readMedium(0xc004) should equal(0x7788) + m.readMedium(0xc008) should equal(0x55) + } + } + test("Farword assignment 2") { + EmuBenchmarkRun( + """ + | farword output3 @$c000 + | farword output2 @$c004 + | word output1 @$c008 + | void main () { + | word w + | byte b + | w = $7788 + | b = $55 + | output3 = $23344 + | output2 = $11223344 + | output1 = $11223344 + | output2 = w + | output1 = b + | } + """.stripMargin) { m => + m.readMedium(0xc000) should equal(0x23344) + m.readMedium(0xc004) should equal(0x7788) + m.readMedium(0xc008) should equal(0x55) + } + } + + test("Farword assignment 3") { + EmuBenchmarkRun( + """ + | farword output0 @$c000 + | farword output1 @$c003 + | void main () { + | output0 = $112233 + | output1 = $112233 + | output0.hiword = output0.loword + | output1.loword = output1.hiword + | } + """.stripMargin) { m => + // TODO: this fails right now: + m.readMedium(0xc000) should equal(0x223333) + m.readMedium(0xc003) should equal(0x111122) + } + } + test("Farword addition") { + EmuBenchmarkRun( + """ + | farword output @$c000 + | void main () { + | word w + | farword l + | byte b + | w = $8000 + | b = $8 + | l = $50000 + | output = 0 + | output += l + | output += w + | output += b + | } + """.stripMargin) { m => + m.readMedium(0xc000) should equal(0x58008) + } + } + test("Farword addition 2") { + EmuBenchmarkRun( + """ + | farword output @$c000 + | void main () { + | output = 0 + | output += $50000 + | output += $8000 + | output += $8 + | } + """.stripMargin) { m => + m.readMedium(0xc000) should equal(0x58008) + } + } + test("Farword subtraction") { + EmuBenchmarkRun( + """ + | farword output @$c000 + | void main () { + | word w + | farword l + | byte b + | w = $8000 + | b = $8 + | l = $50000 + | output = $58008 + | output -= l + | output -= w + | output -= b + | } + """.stripMargin) { m => + m.readMedium(0xc000) should equal(0) + } + } + test("Farword subtraction 2") { + EmuBenchmarkRun( + """ + | farword output @$c000 + | void main () { + | output = $58008 + | output -= $50000 + | output -= $8000 + | output -= $8 + | } + """.stripMargin) { m => + m.readMedium(0xc000) should equal(0) + } + } + test("Farword subtraction 3") { + EmuBenchmarkRun( + """ + | farword output @$c000 + | void main () { + | output = $58008 + | output -= w() + | output -= b() + | } + | byte b() { + | return $8 + | } + | word w() { + | return $8000 + | } + """.stripMargin) { m => + m.readMedium(0xc000) should equal(0x50000) + } + } + + test("Farword AND") { + EmuBenchmarkRun( + """ + | farword output @$c000 + | void main () { + | output = $FFFFFF + | output &= w() + | output &= b() + | } + | byte b() { + | return $77 + | } + | word w() { + | return $CCCC + | } + """.stripMargin) { m => + m.readMedium(0xc000) should equal(0x44) + } + } + + test("Farword INC/DEC") { + EmuBenchmarkRun( + """ + | farword output0 @$c000 + | farword output1 @$c004 + | farword output2 @$c008 + | farword output3 @$c00c + | farword output4 @$c010 + | farword output5 @$c014 + | farword output6 @$c018 + | void main () { + | output0 = 0 + | output1 = $FF + | output2 = $FFFF + | output3 = $FF00 + | output4 = $FF00 + | output5 = $10000 + | output6 = 0 + | barrier() + | output0 += 1 + | output1 += 1 + | output2 += 1 + | output3 += 1 + | output4 -= 1 + | output5 -= 1 + | output6 -= 1 + | } + | void barrier() { + | } + """.stripMargin) { m => + m.readMedium(0xc000) should equal(1) + m.readMedium(0xc004) should equal(0x100) + m.readMedium(0xc008) should equal(0x10000) + m.readMedium(0xc00c) should equal(0xff01) + m.readMedium(0xc010) should equal(0xfeff) + m.readMedium(0xc014) should equal(0xffff) + m.readMedium(0xc018) should equal(0xffffff) + } + } +}