mirror of
https://github.com/KarolS/millfork.git
synced 2025-01-10 20:29:35 +00:00
Pointers to fields of array elements
This commit is contained in:
parent
b873030b29
commit
4b25ce2d8c
@ -246,6 +246,8 @@ Note that you cannot access a whole array element if it's bigger than 2 bytes, b
|
||||
a[2].b0 // ok
|
||||
a[2].loword // ok
|
||||
a[2].pointer // ok
|
||||
a[2].addr // ok
|
||||
a[2].b0.addr // ok, equal to the above on little-endian targets
|
||||
|
||||
## Built-in functions
|
||||
|
||||
|
@ -273,7 +273,16 @@ object AbstractExpressionCompiler {
|
||||
case DerefDebuggingExpression(_, 2) => w
|
||||
case DerefExpression(_, _, typ) => typ
|
||||
case IndirectFieldExpression(inner, firstIndices, fieldPath) =>
|
||||
var currentType = getExpressionType(env, log, inner)
|
||||
var currentType = inner match {
|
||||
case VariableExpression(arrName) =>
|
||||
env.maybeGet[Thing](arrName + ".array") match {
|
||||
case Some(a: MfArray) =>
|
||||
env.get[Type]("pointer." + a.elementType)
|
||||
case _ =>
|
||||
getExpressionType(env, log, inner)
|
||||
}
|
||||
case _ => getExpressionType(env, log, inner)
|
||||
}
|
||||
var ok = true
|
||||
for(_ <- firstIndices) {
|
||||
currentType match {
|
||||
@ -289,24 +298,36 @@ object AbstractExpressionCompiler {
|
||||
for ((dot, fieldName, indices) <- fieldPath) {
|
||||
if (dot && ok) {
|
||||
fieldName match {
|
||||
case "addr" => env.get[Type]("pointer")
|
||||
case "pointer" => env.get[Type]("pointer." + currentType.name)
|
||||
case "addr.hi" => b
|
||||
case "addr.lo" => b
|
||||
case "pointer.hi" => b
|
||||
case "pointer.lo" => b
|
||||
case "addr" => currentType = env.get[Type]("pointer")
|
||||
case "pointer" => currentType = env.get[Type]("pointer." + currentType.name)
|
||||
case "addr.hi" => currentType = b
|
||||
case "addr.lo" => currentType = b
|
||||
case "pointer.hi" => currentType = b
|
||||
case "pointer.lo" => currentType = b
|
||||
case _ =>
|
||||
log.error(s"Unexpected subfield `$fieldName`", expr.position)
|
||||
ok = false
|
||||
}
|
||||
} else if (ok) {
|
||||
val (actualFieldName, pointerWrap): (String, Int) = getActualFieldNameAndPointerWrap(fieldName)
|
||||
currentType match {
|
||||
case PointerType(_, _, Some(targetType)) =>
|
||||
val tuples = env.getSubvariables(targetType).filter(x => x._1 == "." + fieldName)
|
||||
val tuples = env.getSubvariables(targetType).filter(x => x._1 == "." + actualFieldName)
|
||||
if (tuples.isEmpty) {
|
||||
log.error(s"Type `$targetType` doesn't have field named `$fieldName`", expr.position)
|
||||
log.error(s"Type `$targetType` doesn't have field named `$actualFieldName`", expr.position)
|
||||
ok = false
|
||||
} else {
|
||||
currentType = tuples.head._3
|
||||
pointerWrap match {
|
||||
case 0 =>
|
||||
currentType = tuples.head._3
|
||||
case 1 =>
|
||||
currentType = env.get[Type]("pointer." + tuples.head._3)
|
||||
case 2 =>
|
||||
currentType = env.get[Type]("pointer")
|
||||
case 10 | 11 =>
|
||||
currentType = b
|
||||
case _ => throw new IllegalStateException
|
||||
}
|
||||
}
|
||||
case _ =>
|
||||
log.error(s"Type `$currentType` is not a pointer type", expr.position)
|
||||
@ -417,6 +438,24 @@ object AbstractExpressionCompiler {
|
||||
t
|
||||
}
|
||||
|
||||
def getActualFieldNameAndPointerWrap(fieldName: String): (String, Int) = {
|
||||
if (fieldName.endsWith(".pointer")) {
|
||||
fieldName.stripSuffix(".pointer") -> 1
|
||||
} else if (fieldName.endsWith(".addr")) {
|
||||
fieldName.stripSuffix(".addr") -> 2
|
||||
} else if (fieldName.endsWith(".addr.hi")) {
|
||||
fieldName.stripSuffix(".addr.hi") -> 11
|
||||
} else if (fieldName.endsWith(".pointer.hi")) {
|
||||
fieldName.stripSuffix(".pointer.hi") -> 11
|
||||
} else if (fieldName.endsWith(".addr.lo")) {
|
||||
fieldName.stripSuffix(".addr.lo") -> 10
|
||||
} else if (fieldName.endsWith(".pointer.lo")) {
|
||||
fieldName.stripSuffix(".pointer.lo") -> 10
|
||||
} else {
|
||||
fieldName -> 0
|
||||
}
|
||||
}
|
||||
|
||||
def checkIndexType(ctx: CompilationContext, pointy: Pointy, index: Expression): Unit = {
|
||||
val indexerType = getExpressionType(ctx, index)
|
||||
if (!indexerType.isAssignableTo(pointy.indexType)) {
|
||||
|
@ -366,10 +366,10 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte
|
||||
if (dot && ok) {
|
||||
val pointer = result match {
|
||||
case DerefExpression(inner, 0, _) =>
|
||||
inner
|
||||
optimizeExpr(inner, currentVarValues).pos(pos)
|
||||
case DerefExpression(inner, offset, targetType) =>
|
||||
("pointer." + targetType.name) <| SumExpression(List(
|
||||
false -> ("pointer" <| inner),
|
||||
false -> ("pointer" <| optimizeExpr(inner, currentVarValues).pos(pos)),
|
||||
false -> LiteralExpression(offset, 2)
|
||||
), decimal = false)
|
||||
case IndexedExpression(name, index) =>
|
||||
@ -391,16 +391,44 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte
|
||||
ok = false
|
||||
}
|
||||
} else if (ok) {
|
||||
val (actualFieldName, pointerWrap): (String, Int) = AbstractExpressionCompiler.getActualFieldNameAndPointerWrap(fieldName)
|
||||
val currentResultType = AbstractExpressionCompiler.getExpressionType(env, env.log, result)
|
||||
result = currentResultType match {
|
||||
case PointerType(_, _, Some(target)) =>
|
||||
val subvariables = env.getSubvariables(target).filter(x => x._1 == "." + fieldName)
|
||||
val subvariables = env.getSubvariables(target).filter(x => x._1 == "." + actualFieldName)
|
||||
if (subvariables.isEmpty) {
|
||||
ctx.log.error(s"Type `${target.name}` does not contain field `$fieldName`", result.position)
|
||||
ctx.log.error(s"Type `${target.name}` does not contain field `$actualFieldName`", result.position)
|
||||
ok = false
|
||||
LiteralExpression(0, 1)
|
||||
} else {
|
||||
DerefExpression(optimizeExpr(result, currentVarValues).pos(pos), subvariables.head._2, subvariables.head._3)
|
||||
val inner = optimizeExpr(result, currentVarValues).pos(pos)
|
||||
val fieldOffset = subvariables.head._2
|
||||
val fieldType = subvariables.head._3
|
||||
pointerWrap match {
|
||||
case 0 =>
|
||||
DerefExpression(inner, fieldOffset, fieldType)
|
||||
case 1 =>
|
||||
("pointer." + fieldType.name) <| SumExpression(List(
|
||||
false -> ("pointer" <| inner),
|
||||
false -> LiteralExpression(fieldOffset, 2)
|
||||
), decimal = false)
|
||||
case 2 =>
|
||||
SumExpression(List(
|
||||
false -> ("pointer" <| inner),
|
||||
false -> LiteralExpression(fieldOffset, 2)
|
||||
), decimal = false)
|
||||
case 10 =>
|
||||
"lo" <| SumExpression(List(
|
||||
false -> ("pointer" <| inner),
|
||||
false -> LiteralExpression(fieldOffset, 2)
|
||||
), decimal = false)
|
||||
case 11 =>
|
||||
"hi" <| SumExpression(List(
|
||||
false -> ("pointer" <| inner),
|
||||
false -> LiteralExpression(fieldOffset, 2)
|
||||
), decimal = false)
|
||||
case _ => throw new IllegalStateException
|
||||
}
|
||||
}
|
||||
case _ =>
|
||||
ctx.log.error("Invalid pointer type on the left-hand side of `->`", result.position)
|
||||
|
@ -335,6 +335,8 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri
|
||||
e match {
|
||||
case (".", "pointer", _) => Seq(e)
|
||||
case (".", f, _) if f.startsWith("pointer.") => Seq(e)
|
||||
case (".", "addr", _) => Seq(e)
|
||||
case (".", f, _) if f.startsWith("addr.") => Seq(e)
|
||||
case (".", f, i) => Seq((".", "pointer", Nil), ("->", f, i))
|
||||
case _ => Seq(e)
|
||||
}
|
||||
|
@ -429,6 +429,33 @@ class ArraySuite extends FunSuite with Matchers {
|
||||
}
|
||||
}
|
||||
|
||||
test("Pointers to array elements") {
|
||||
EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Intel8080, Cpu.Z80)(
|
||||
"""
|
||||
| struct coord { byte x, byte y }
|
||||
|
|
||||
| array(coord) c = [coord(1,2),coord(3,4)]
|
||||
| array(byte) z = "hello!"
|
||||
|
|
||||
| word output @$c000
|
||||
|
|
||||
| void main () {
|
||||
| output = 354
|
||||
| output += word(c[0].pointer)
|
||||
| output -= c[0].addr
|
||||
| output += word(c[0].x.pointer)
|
||||
| output -= word(c[0].x.addr)
|
||||
| output += word(pointer.coord(c.addr)->x.addr)
|
||||
| output -= word(pointer.coord(c.addr)->x.pointer)
|
||||
| output += word(z[0].pointer)
|
||||
| output -= z[0].addr
|
||||
| }
|
||||
|
|
||||
""".stripMargin){ m =>
|
||||
m.readWord(0xc000) should equal(354)
|
||||
}
|
||||
}
|
||||
|
||||
test("Invalid array things that will become valid in the future") {
|
||||
ShouldNotCompile(
|
||||
"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user