1
0
mirror of https://github.com/KarolS/millfork.git synced 2024-07-07 21:28:59 +00:00

Unified syntax for indexing

This commit is contained in:
Karol Stasiak 2019-04-18 16:24:46 +02:00
parent 85841c6395
commit 546c4d0f44
11 changed files with 215 additions and 79 deletions

View File

@ -65,7 +65,7 @@ Examples:
p[0] // valid only if the type 't' is of size 1 or 2, accesses the pointed element
p[i] // valid only if the type 't' is of size 1, equivalent to 't(p.raw[i])'
p->x // valid only if the type 't' has a field called 'x', accesses the field 'x' of the pointed element
p->x.y->z // you can stack it
p->x.y[0]->z[0][6] // you can stack it
## `nullptr`

View File

@ -254,23 +254,51 @@ object AbstractExpressionCompiler {
case DerefDebuggingExpression(_, 1) => b
case DerefDebuggingExpression(_, 2) => w
case DerefExpression(_, _, typ) => typ
case IndirectFieldExpression(inner, fieldPath) =>
val firstPointerType = getExpressionType(env, log, inner)
fieldPath.foldLeft(firstPointerType) { (currentType, fieldName) =>
case IndirectFieldExpression(inner, firstIndices, fieldPath) =>
var currentType = getExpressionType(env, log, inner)
var ok = true
for(_ <- firstIndices) {
currentType match {
case PointerType(_, _, Some(targetType)) =>
val tuples = env.getSubvariables(targetType).filter(x => x._1 == "." + fieldName)
if (tuples.isEmpty) {
log.error(s"Type `$targetType` doesn't have field named `$fieldName`", expr.position)
b
} else {
tuples.head._3
}
currentType = targetType
case x if x.isPointy =>
currentType = b
case _ =>
log.error(s"Type `$currentType` is not a pointer type", expr.position)
b
ok = false
}
}
for ((fieldName, indices) <- fieldPath) {
if (ok) {
currentType match {
case PointerType(_, _, Some(targetType)) =>
val tuples = env.getSubvariables(targetType).filter(x => x._1 == "." + fieldName)
if (tuples.isEmpty) {
log.error(s"Type `$targetType` doesn't have field named `$fieldName`", expr.position)
ok = false
} else {
currentType = tuples.head._3
}
case _ =>
log.error(s"Type `$currentType` is not a pointer type", expr.position)
ok = false
}
}
if (ok) {
for (_ <- indices) {
currentType match {
case PointerType(_, _, Some(targetType)) =>
currentType = targetType
case x if x.isPointy =>
currentType = b
case _ =>
log.error(s"Type `$currentType` is not a pointer type", expr.position)
ok = false
}
}
}
}
if (ok) currentType else b
case SeparateBytesExpression(hi, lo) =>
if (getExpressionType(env, log, hi).size > 1) log.error("Hi byte too large", hi.position)
if (getExpressionType(env, log, lo).size > 1) log.error("Lo byte too large", lo.position)

View File

@ -277,25 +277,60 @@ abstract class AbstractStatementPreprocessor(ctx: CompilationContext, statements
case _ =>
}
expr match {
case IndirectFieldExpression(root, fieldPath) if AbstractExpressionCompiler.getExpressionType(env, env.log, root).isInstanceOf[PointerType] =>
fieldPath.foldLeft(root) { (pointer, fieldName) =>
AbstractExpressionCompiler.getExpressionType(env, env.log, pointer) match {
case PointerType(_, _, Some(target)) =>
val subvariables = env.getSubvariables(target).filter(x => x._1 == "." + fieldName)
if (subvariables.isEmpty) {
ctx.log.error(s"Type `${target.name}` does not contain field `$fieldName`", pointer.position)
LiteralExpression(0, 1)
} else {
DerefExpression(optimizeExpr(pointer, currentVarValues).pos(pos), subvariables.head._2, subvariables.head._3)
case IndirectFieldExpression(root, firstIndices, fieldPath) =>
val b = env.get[Type]("byte")
var ok = true
var result = optimizeExpr(root, currentVarValues).pos(pos)
def applyIndex(result: Expression, index: Expression): Expression = {
AbstractExpressionCompiler.getExpressionType(env, env.log, result) match {
case pt@PointerType(_, _, Some(target)) =>
env.eval(index) match {
case Some(NumericConstant(0, _)) => //ok
case _ =>
env.log.error(s"Type `$pt` can be only indexed with 0")
}
DerefExpression(result, 0, target)
case x if x.isPointy =>
env.eval(index) match {
case Some(NumericConstant(n, _)) if n >= 0 && n <= 127 =>
DerefExpression(result, n.toInt, b)
case _ =>
DerefExpression(SumExpression(List(false -> result, false -> index), decimal = false), 0, b)
}
case _ =>
ctx.log.error("Invalid pointer type on the left-hand side of `->`", pointer.position)
LiteralExpression(0, 1)
ctx.log.error("Not a pointer type on the left-hand side of `[`", pos)
ok = false
result
}
}
case IndirectFieldExpression(root, fieldPath) =>
ctx.log.error("Invalid pointer type on the left-hand side of `->`", pos)
root
for (index <- firstIndices) {
result = applyIndex(result, index)
}
for ((fieldName, indices) <- fieldPath) {
if (ok) {
result = AbstractExpressionCompiler.getExpressionType(env, env.log, result) match {
case PointerType(_, _, Some(target)) =>
val subvariables = env.getSubvariables(target).filter(x => x._1 == "." + fieldName)
if (subvariables.isEmpty) {
ctx.log.error(s"Type `${target.name}` does not contain field `$fieldName`", result.position)
ok = false
LiteralExpression(0, 1)
} else {
DerefExpression(optimizeExpr(result, currentVarValues).pos(pos), subvariables.head._2, subvariables.head._3)
}
case _ =>
ctx.log.error("Invalid pointer type on the left-hand side of `->`", result.position)
LiteralExpression(0, 1)
}
}
if (ok) {
for (index <- indices) {
result = applyIndex(result, index)
}
}
}
result
case DerefDebuggingExpression(inner, 1) =>
DerefExpression(optimizeExpr(inner, currentVarValues), 0, env.get[VariableType]("byte")).pos(pos)
case DerefDebuggingExpression(inner, 2) =>

View File

@ -1699,8 +1699,10 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
nameCheck(inner)
case DerefExpression(inner, _, _) =>
nameCheck(inner)
case IndirectFieldExpression(inner, _) =>
case IndirectFieldExpression(inner, firstIndices, fields) =>
nameCheck(inner)
firstIndices.foreach(nameCheck)
fields.foreach(f => f._2.foreach(nameCheck))
case SeparateBytesExpression(h, l) =>
nameCheck(h)
nameCheck(l)

View File

@ -65,6 +65,12 @@ abstract class CallGraph(program: Program, log: Logger) {
val varName = i.name.stripSuffix(".hi").stripSuffix(".lo").stripSuffix(".addr")
everCalledFunctions += varName
add(currentFunction, callingFunctions, i.index)
case i: DerefDebuggingExpression =>
add(currentFunction, callingFunctions, i.inner)
case IndirectFieldExpression(root, firstIndices, fields) =>
add(currentFunction, callingFunctions, root)
firstIndices.foreach(i => add(currentFunction, callingFunctions, i))
fields.foreach(f => f._2.foreach(i => add(currentFunction, callingFunctions, i)))
case _ => ()
}
}

View File

@ -202,19 +202,26 @@ case class IndexedExpression(name: String, index: Expression) extends LhsExpress
override def getAllIdentifiers: Set[String] = index.getAllIdentifiers + name
}
case class IndirectFieldExpression(root: Expression, fields: List[String]) extends LhsExpression {
override def replaceVariable(variable: String, actualParam: Expression): Expression = IndirectFieldExpression(root.replaceVariable(variable, actualParam), fields)
case class IndirectFieldExpression(root: Expression, firstIndices: Seq[Expression], fields: Seq[(String, Seq[Expression])]) extends LhsExpression {
override def replaceVariable(variable: String, actualParam: Expression): Expression =
IndirectFieldExpression(
root.replaceVariable(variable, actualParam),
firstIndices.map(_.replaceVariable(variable, actualParam)),
fields.map{case (f, i) => f -> i.map(_.replaceVariable(variable, actualParam))})
override def containsVariable(variable: String): Boolean = root.containsVariable(variable)
override def containsVariable(variable: String): Boolean =
root.containsVariable(variable) ||
firstIndices.exists(_.containsVariable(variable)) ||
fields.exists(_._2.exists(_.containsVariable(variable)))
override def getPointies: Seq[String] = root match {
override def getPointies: Seq[String] = (root match {
case VariableExpression(v) => List(v)
case _ => root.getPointies
}
}) ++ firstIndices.flatMap(_.getPointies) ++ fields.flatMap(_._2.flatMap(_.getPointies))
override def isPure: Boolean = root.isPure
override def isPure: Boolean = root.isPure && firstIndices.forall(_.isPure) && fields.forall(_._2.forall(_.isPure))
override def getAllIdentifiers: Set[String] = root.getAllIdentifiers
override def getAllIdentifiers: Set[String] = root.getAllIdentifiers ++ firstIndices.flatMap(_.getAllIdentifiers) ++ fields.flatMap(_._2.flatMap(_.getAllIdentifiers))
}
case class DerefDebuggingExpression(inner: Expression, preferredSize: Int) extends LhsExpression {

View File

@ -92,7 +92,7 @@ object UnusedFunctions extends NodeOptimization {
case s: ArrayDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.elements.toList)
case s: ArrayContents => getAllCalledFunctions(s.getAllExpressions)
case s: FunctionDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.statements.getOrElse(Nil))
case Assignment(VariableExpression(_), expr) => getAllCalledFunctions(expr :: Nil)
case Assignment(target, expr) => getAllCalledFunctions(target :: expr :: Nil)
case s: ReturnDispatchStatement =>
getAllCalledFunctions(s.getAllExpressions) ++ getAllCalledFunctions(s.branches.map(_.function))
case s: Statement => getAllCalledFunctions(s.getAllExpressions)
@ -115,6 +115,8 @@ object UnusedFunctions extends NodeOptimization {
case FunctionCallExpression(name, xs) => name :: getAllCalledFunctions(xs)
case IndexedExpression(arr, index) => arr :: getAllCalledFunctions(List(index))
case SeparateBytesExpression(h, l) => getAllCalledFunctions(List(h, l))
case DerefDebuggingExpression(inner, _) => getAllCalledFunctions(List(inner))
case IndirectFieldExpression(root, firstIndices, fieldPath) => getAllCalledFunctions(root :: firstIndices ++: fieldPath.flatMap(_._2).toList)
case _ => Nil
}

View File

@ -55,7 +55,7 @@ object UnusedLocalVariables extends NodeOptimization {
case IndexedExpression(arr, index) => arr :: getAllReadVariables(List(index))
case DerefExpression(inner, _, _) => getAllReadVariables(List(inner))
case DerefDebuggingExpression(inner, _) => getAllReadVariables(List(inner))
case IndirectFieldExpression(inner, _) => getAllReadVariables(List(inner))
case IndirectFieldExpression(inner, firstIndices, fields) => getAllReadVariables(List(inner) ++ firstIndices ++ fields.flatMap(_._2))
case SeparateBytesExpression(h, l) => getAllReadVariables(List(h, l))
case _ => Nil
}

View File

@ -192,7 +192,7 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri
start <- mfExpression(nonStatementLevel, false) ~ HWS ~ "," ~/ HWS ~/ Pass
pos <- position("loop direction")
direction <- forDirection ~/ HWS ~/ "," ~/ HWS ~/ Pass
end <- mfExpression(nonStatementLevel, false)
end <- mfExpression(nonStatementLevel, false, allowTopLevelIndexing = false)
body <- AWS ~ arrayContents
} yield {
val fixedDirection = direction match {
@ -247,29 +247,29 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri
contents <- ("=" ~/ HWS ~/ arrayContents).? ~/ HWS
} yield Seq(ArrayDeclarationStatement(name, bank, length, elementType.getOrElse("byte"), addr, contents, alignment).pos(p))
def tightMfExpression(allowIntelHex: Boolean): P[Expression] = {
def tightMfExpression(allowIntelHex: Boolean, allowTopLevelIndexing: Boolean): P[Expression] = {
val a = if (allowIntelHex) atomWithIntel else atom
for {
expression <- mfParenExpr(allowIntelHex) | derefExpression | functionCall(allowIntelHex) | mfIndexedExpression | a
fieldPath <- ("->" ~/ AWS ~/ identifier).rep
} yield if (fieldPath.isEmpty) expression else IndirectFieldExpression(expression, fieldPath.toList)
if (allowTopLevelIndexing)
mfExpressionWrapper[Expression](mfParenExpr(allowIntelHex) | derefExpression | functionCall(allowIntelHex) | a)
else
mfParenExpr(allowIntelHex) | derefExpression | functionCall(allowIntelHex) | a
}
def tightMfExpressionButNotCall(allowIntelHex: Boolean): P[Expression] = {
def tightMfExpressionButNotCall(allowIntelHex: Boolean, allowTopLevelIndexing: Boolean): P[Expression] = {
val a = if (allowIntelHex) atomWithIntel else atom
for {
expression <- mfParenExpr(allowIntelHex) | derefExpression | mfIndexedExpression | a
fieldPath <- ("->" ~/ AWS ~/ identifier).rep
} yield if (fieldPath.isEmpty) expression else IndirectFieldExpression(expression, fieldPath.toList)
if (allowTopLevelIndexing)
mfExpressionWrapper[Expression](mfParenExpr(allowIntelHex) | derefExpression | a)
else
mfParenExpr(allowIntelHex) | derefExpression | a
}
def mfExpression(level: Int, allowIntelHex: Boolean): P[Expression] = {
def mfExpression(level: Int, allowIntelHex: Boolean, allowTopLevelIndexing: Boolean = true): P[Expression] = {
val allowedOperators = mfOperatorsDropFlatten(level)
def inner: P[SeparatedList[Expression, String]] = {
for {
head <- tightMfExpression(allowIntelHex) ~/ HWS
maybeOperator <- StringIn(allowedOperators: _*).!.?
head <- tightMfExpression(allowIntelHex, allowTopLevelIndexing) ~/ HWS
maybeOperator <- (StringIn(allowedOperators: _*).! ~ !CharIn(Seq('/', '=', '-', '+', ':', '>', '<', '\''))).?
maybeTail <- maybeOperator.fold[P[Option[List[(String, Expression)]]]](Pass.map(_ => None))(o => (HWS ~/ inner ~/ HWS).map(x2 => Some((o -> x2.head) :: x2.tail)))
} yield {
maybeTail.fold[SeparatedList[Expression, String]](SeparatedList.of(head))(t => SeparatedList(head, t))
@ -296,6 +296,13 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri
} else {
SeparateBytesExpression(p(xs.head, level + 1), p(xs.tail.head._2, level + 1))
}
case List(eq) if level == 0 =>
if (xs.size != 2) {
log.error(s"The `$eq` operator can have only two arguments", xs.head.head.position)
LiteralExpression(0, 1)
} else {
FunctionCallExpression(eq, xs.items.map(value => p(value, level + 1)))
}
case List(op) =>
FunctionCallExpression(op, xs.items.map(value => p(value, level + 1)))
case _ =>
@ -307,25 +314,31 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri
inner.map(x => p(x, 0))
}
def mfLhsExpressionSimple: P[LhsExpression] = for {
expression <- mfIndexedExpression | derefExpression | (position() ~ identifier).map{case (p,n) => VariableExpression(n).pos(p)} ~ HWS
fieldPath <- ("->" ~/ AWS ~/ identifier).rep
} yield if (fieldPath.isEmpty) expression else IndirectFieldExpression(expression, fieldPath.toList)
def index: P[Expression] = HWS ~ "[" ~/ AWS ~/ mfExpression(nonStatementLevel, false) ~ AWS ~/ "]" ~/ Pass
def mfLhsExpression: P[LhsExpression] = for {
(p, left) <- position() ~ mfLhsExpressionSimple
rightOpt <- (HWS ~ ":" ~/ HWS ~ mfLhsExpressionSimple).?
} yield rightOpt.fold(left)(right => SeparateBytesExpression(left, right).pos(p))
def mfExpressionWrapper[T <: Expression](inner: P[T]): P[T] = for {
expr <- inner
firstIndices <- index.rep
fieldPath <- (HWS ~ "->" ~/ AWS ~/ identifier ~/ index.rep).rep
} yield (expr, firstIndices, fieldPath) match {
case (_, Seq(), Seq()) => expr
case (VariableExpression(vname), Seq(i), Seq()) => IndexedExpression(vname, i).asInstanceOf[T]
case _ => IndirectFieldExpression(expr, firstIndices, fieldPath).asInstanceOf[T]
}
// def mfLhsExpression: P[LhsExpression] = for {
// (p, left) <- position() ~ mfLhsExpressionSimple
// rightOpt <- (HWS ~ ":" ~/ HWS ~ mfLhsExpressionSimple).?
// } yield rightOpt.fold(left)(right => SeparateBytesExpression(left, right).pos(p))
def mfLhsExpressionSimple: P[LhsExpression] =
mfExpressionWrapper[LhsExpression](derefExpression | (position() ~ identifier).map{case (p,n) => VariableExpression(n).pos(p)} ~ HWS)
def mfLhsExpression: P[LhsExpression] =
mfExpression(nonStatementLevel, false).filter(_.isInstanceOf[LhsExpression]).map(_.asInstanceOf[LhsExpression])
def mfParenExpr(allowIntelHex: Boolean): P[Expression] = P("(" ~/ AWS ~/ mfExpression(nonStatementLevel, allowIntelHex) ~ AWS ~/ ")")
def mfIndexedExpression: P[IndexedExpression] = for {
p <- position()
array <- identifier
index <- HWS ~ "[" ~/ AWS ~/ mfExpression(nonStatementLevel, false) ~ AWS ~/ "]"
} yield IndexedExpression(array, index).pos(p)
def functionCall(allowIntelHex: Boolean): P[FunctionCallExpression] = for {
p <- position()
name <- identifier
@ -339,12 +352,15 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri
inner <- mfParenExpr(false)
} yield DerefDebuggingExpression(inner, yens.length).pos(p)
val expressionStatement: P[Seq[ExecutableStatement]] = mfExpression(0, false).map(x => Seq(ExpressionStatement(x)))
val assignmentStatement: P[Seq[ExecutableStatement]] =
(position() ~ mfLhsExpression ~ HWS ~ "=" ~/ HWS ~ mfExpression(1, false)).map {
case (p, l, r) => Seq(Assignment(l, r).pos(p))
}
val expressionStatement: P[Seq[ExecutableStatement]] = mfExpression(0, false).map {
case FunctionCallExpression("=", List(t: LhsExpression, s)) =>
Seq(Assignment(t, s).pos(t.position))
case x@FunctionCallExpression("=", exprs) =>
log.error("Invalid left-hand-side of an assignment", x.position)
exprs.map(ExpressionStatement)
case x =>
Seq(ExpressionStatement(x).pos(x.position))
}
def keywordStatement: P[Seq[ExecutableStatement]] = P(
returnOrDispatchStatement |
@ -355,8 +371,7 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri
doWhileStatement |
breakStatement |
continueStatement |
inlineAssembly |
assignmentStatement)
inlineAssembly)
def executableStatement: P[Seq[ExecutableStatement]] = (position() ~ P(keywordStatement | expressionStatement)).map { case (p, s) => s.map(_.pos(p)) }
@ -391,7 +406,7 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri
val dispatchBranch: P[ReturnDispatchBranch] = for {
pos <- position()
l <- dispatchLabel ~/ HWS ~/ "@" ~/ HWS
f <- tightMfExpressionButNotCall(false) ~/ HWS
f <- tightMfExpressionButNotCall(false, allowTopLevelIndexing = false) ~/ HWS
parameters <- ("(" ~/ position("dispatch actual parameters") ~ AWS ~/ mfExpression(nonStatementLevel, false).rep(min = 0, sep = AWS ~ "," ~/ AWS) ~ AWS ~/ ")" ~/ "").?
} yield ReturnDispatchBranch(l, f, parameters.map(_._2.toList).getOrElse(Nil)).pos(pos)
@ -641,7 +656,7 @@ object MfParser {
}
val mfOperators = List(
List("+=", "-=", "+'=", "-'=", "^=", "&=", "|=", "*=", "*'=", "<<=", ">>=", "<<'=", ">>'="),
List("+=", "-=", "+'=", "-'=", "^=", "&=", "|=", "*=", "*'=", "<<=", ">>=", "<<'=", ">>'=", "="),
List("||", "^^"),
List("&&"),
List("==", "<=", ">=", "!=", "<", ">"),
@ -683,5 +698,5 @@ object MfParser {
val functionFlags: P[Set[String]] = flags_("asm", "inline", "interrupt", "macro", "noinline", "reentrant", "kernal_interrupt")
val InvalidReturnTypes = Set("enum", "alias", "array", "const", "stack", "register", "static", "volatile", "import")
val InvalidReturnTypes = Set("enum", "alias", "array", "const", "stack", "register", "static", "volatile", "import", "struct", "union")
}

View File

@ -188,7 +188,7 @@ class PreprocessorParser(options: CompilationOptions) {
def inner: P[SeparatedList[Q, String]] = {
for {
head <- tightMfExpression ~/ HWS
maybeOperator <- StringIn(allowedOperators: _*).!.?
maybeOperator <- (StringIn(allowedOperators: _*).! ~ !CharIn(Seq('-','+','/'))).?
maybeTail <- maybeOperator.fold[P[Option[List[(String, Q)]]]](Pass.map(_ => None))(o => (HWS ~/ inner ~/ HWS).map(x2 => Some((o -> x2.head) :: x2.tail)))
} yield {
maybeTail.fold[SeparatedList[Q, String]](SeparatedList.of(head))(t => SeparatedList(head, t))

View File

@ -2,12 +2,12 @@ package millfork.test
import millfork.Cpu
import millfork.test.emu.{EmuCrossPlatformBenchmarkRun, EmuUnoptimizedCrossPlatformRun}
import org.scalatest.{FunSuite, Matchers}
import org.scalatest.{AppendedClues, FunSuite, Matchers}
/**
* @author Karol Stasiak
*/
class PointerSuite extends FunSuite with Matchers {
class PointerSuite extends FunSuite with Matchers with AppendedClues {
test("Pointers outside zeropage") {
EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Sixteen, Cpu.Z80, Cpu.Intel8080, Cpu.Sharp)(
@ -168,4 +168,45 @@ class PointerSuite extends FunSuite with Matchers {
}
}
test("Complex pointers") {
// TODO: optimize it when inlined
EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)(
"""
| array output[3] @$c000
| struct s {
| pointer p
| }
| s tmp
| pointer.s tmpptr
| pointer.pointer.s get() {
| tmp.p = output.addr
| tmpptr = tmp.pointer
| return tmpptr.pointer
| }
| void main() {
| get()[0]->p[0] = 5
| }
""".stripMargin) { m =>
m.readByte(0xc000) should equal(5)
}
}
test("Indexing returned pointers") {
EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)(
"""
| array output[10] @$c000
| pointer get() = output.addr
| void main() {
| byte i
| for i,0,paralleluntil,10 {
| get()[i] = 42
| }
| }
""".stripMargin) { m =>
for(i <- 0xc000 until 0xc00a) {
m.readByte(i) should equal(42) withClue i
}
}
}
}