From 546c4d0f44d9c4046ba3e932790ca87ad3b3733d Mon Sep 17 00:00:00 2001 From: Karol Stasiak Date: Thu, 18 Apr 2019 16:24:46 +0200 Subject: [PATCH] Unified syntax for indexing --- docs/lang/types.md | 2 +- .../compiler/AbstractExpressionCompiler.scala | 50 +++++++--- .../AbstractStatementPreprocessor.scala | 65 ++++++++++--- src/main/scala/millfork/env/Environment.scala | 4 +- src/main/scala/millfork/node/CallGraph.scala | 6 ++ src/main/scala/millfork/node/Node.scala | 21 +++-- .../millfork/node/opt/UnusedFunctions.scala | 4 +- .../node/opt/UnusedLocalVariables.scala | 2 +- src/main/scala/millfork/parser/MfParser.scala | 93 +++++++++++-------- .../scala/millfork/parser/Preprocessor.scala | 2 +- .../scala/millfork/test/PointerSuite.scala | 45 ++++++++- 11 files changed, 215 insertions(+), 79 deletions(-) diff --git a/docs/lang/types.md b/docs/lang/types.md index db20f2b9..a1923cee 100644 --- a/docs/lang/types.md +++ b/docs/lang/types.md @@ -65,7 +65,7 @@ Examples: p[0] // valid only if the type 't' is of size 1 or 2, accesses the pointed element p[i] // valid only if the type 't' is of size 1, equivalent to 't(p.raw[i])' p->x // valid only if the type 't' has a field called 'x', accesses the field 'x' of the pointed element - p->x.y->z // you can stack it + p->x.y[0]->z[0][6] // you can stack it ## `nullptr` diff --git a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala index d5d5b321..5f22e976 100644 --- a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala @@ -254,23 +254,51 @@ object AbstractExpressionCompiler { case DerefDebuggingExpression(_, 1) => b case DerefDebuggingExpression(_, 2) => w case DerefExpression(_, _, typ) => typ - case IndirectFieldExpression(inner, fieldPath) => - val firstPointerType = getExpressionType(env, log, inner) - fieldPath.foldLeft(firstPointerType) { (currentType, fieldName) => + case IndirectFieldExpression(inner, firstIndices, fieldPath) => + var currentType = getExpressionType(env, log, inner) + var ok = true + for(_ <- firstIndices) { currentType match { case PointerType(_, _, Some(targetType)) => - val tuples = env.getSubvariables(targetType).filter(x => x._1 == "." + fieldName) - if (tuples.isEmpty) { - log.error(s"Type `$targetType` doesn't have field named `$fieldName`", expr.position) - b - } else { - tuples.head._3 - } + currentType = targetType + case x if x.isPointy => + currentType = b case _ => log.error(s"Type `$currentType` is not a pointer type", expr.position) - b + ok = false } } + for ((fieldName, indices) <- fieldPath) { + if (ok) { + currentType match { + case PointerType(_, _, Some(targetType)) => + val tuples = env.getSubvariables(targetType).filter(x => x._1 == "." + fieldName) + if (tuples.isEmpty) { + log.error(s"Type `$targetType` doesn't have field named `$fieldName`", expr.position) + ok = false + } else { + currentType = tuples.head._3 + } + case _ => + log.error(s"Type `$currentType` is not a pointer type", expr.position) + ok = false + } + } + if (ok) { + for (_ <- indices) { + currentType match { + case PointerType(_, _, Some(targetType)) => + currentType = targetType + case x if x.isPointy => + currentType = b + case _ => + log.error(s"Type `$currentType` is not a pointer type", expr.position) + ok = false + } + } + } + } + if (ok) currentType else b case SeparateBytesExpression(hi, lo) => if (getExpressionType(env, log, hi).size > 1) log.error("Hi byte too large", hi.position) if (getExpressionType(env, log, lo).size > 1) log.error("Lo byte too large", lo.position) diff --git a/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala b/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala index b45a1d70..e91ed756 100644 --- a/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala +++ b/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala @@ -277,25 +277,60 @@ abstract class AbstractStatementPreprocessor(ctx: CompilationContext, statements case _ => } expr match { - case IndirectFieldExpression(root, fieldPath) if AbstractExpressionCompiler.getExpressionType(env, env.log, root).isInstanceOf[PointerType] => - fieldPath.foldLeft(root) { (pointer, fieldName) => - AbstractExpressionCompiler.getExpressionType(env, env.log, pointer) match { - case PointerType(_, _, Some(target)) => - val subvariables = env.getSubvariables(target).filter(x => x._1 == "." + fieldName) - if (subvariables.isEmpty) { - ctx.log.error(s"Type `${target.name}` does not contain field `$fieldName`", pointer.position) - LiteralExpression(0, 1) - } else { - DerefExpression(optimizeExpr(pointer, currentVarValues).pos(pos), subvariables.head._2, subvariables.head._3) + case IndirectFieldExpression(root, firstIndices, fieldPath) => + val b = env.get[Type]("byte") + var ok = true + var result = optimizeExpr(root, currentVarValues).pos(pos) + def applyIndex(result: Expression, index: Expression): Expression = { + AbstractExpressionCompiler.getExpressionType(env, env.log, result) match { + case pt@PointerType(_, _, Some(target)) => + env.eval(index) match { + case Some(NumericConstant(0, _)) => //ok + case _ => + env.log.error(s"Type `$pt` can be only indexed with 0") + } + DerefExpression(result, 0, target) + case x if x.isPointy => + env.eval(index) match { + case Some(NumericConstant(n, _)) if n >= 0 && n <= 127 => + DerefExpression(result, n.toInt, b) + case _ => + DerefExpression(SumExpression(List(false -> result, false -> index), decimal = false), 0, b) } case _ => - ctx.log.error("Invalid pointer type on the left-hand side of `->`", pointer.position) - LiteralExpression(0, 1) + ctx.log.error("Not a pointer type on the left-hand side of `[`", pos) + ok = false + result } } - case IndirectFieldExpression(root, fieldPath) => - ctx.log.error("Invalid pointer type on the left-hand side of `->`", pos) - root + + for (index <- firstIndices) { + result = applyIndex(result, index) + } + for ((fieldName, indices) <- fieldPath) { + if (ok) { + result = AbstractExpressionCompiler.getExpressionType(env, env.log, result) match { + case PointerType(_, _, Some(target)) => + val subvariables = env.getSubvariables(target).filter(x => x._1 == "." + fieldName) + if (subvariables.isEmpty) { + ctx.log.error(s"Type `${target.name}` does not contain field `$fieldName`", result.position) + ok = false + LiteralExpression(0, 1) + } else { + DerefExpression(optimizeExpr(result, currentVarValues).pos(pos), subvariables.head._2, subvariables.head._3) + } + case _ => + ctx.log.error("Invalid pointer type on the left-hand side of `->`", result.position) + LiteralExpression(0, 1) + } + } + if (ok) { + for (index <- indices) { + result = applyIndex(result, index) + } + } + } + result case DerefDebuggingExpression(inner, 1) => DerefExpression(optimizeExpr(inner, currentVarValues), 0, env.get[VariableType]("byte")).pos(pos) case DerefDebuggingExpression(inner, 2) => diff --git a/src/main/scala/millfork/env/Environment.scala b/src/main/scala/millfork/env/Environment.scala index 4bf4738b..5a3b4d84 100644 --- a/src/main/scala/millfork/env/Environment.scala +++ b/src/main/scala/millfork/env/Environment.scala @@ -1699,8 +1699,10 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa nameCheck(inner) case DerefExpression(inner, _, _) => nameCheck(inner) - case IndirectFieldExpression(inner, _) => + case IndirectFieldExpression(inner, firstIndices, fields) => nameCheck(inner) + firstIndices.foreach(nameCheck) + fields.foreach(f => f._2.foreach(nameCheck)) case SeparateBytesExpression(h, l) => nameCheck(h) nameCheck(l) diff --git a/src/main/scala/millfork/node/CallGraph.scala b/src/main/scala/millfork/node/CallGraph.scala index 008ec07d..9470f8d6 100644 --- a/src/main/scala/millfork/node/CallGraph.scala +++ b/src/main/scala/millfork/node/CallGraph.scala @@ -65,6 +65,12 @@ abstract class CallGraph(program: Program, log: Logger) { val varName = i.name.stripSuffix(".hi").stripSuffix(".lo").stripSuffix(".addr") everCalledFunctions += varName add(currentFunction, callingFunctions, i.index) + case i: DerefDebuggingExpression => + add(currentFunction, callingFunctions, i.inner) + case IndirectFieldExpression(root, firstIndices, fields) => + add(currentFunction, callingFunctions, root) + firstIndices.foreach(i => add(currentFunction, callingFunctions, i)) + fields.foreach(f => f._2.foreach(i => add(currentFunction, callingFunctions, i))) case _ => () } } diff --git a/src/main/scala/millfork/node/Node.scala b/src/main/scala/millfork/node/Node.scala index b8f28614..a35ff9a4 100644 --- a/src/main/scala/millfork/node/Node.scala +++ b/src/main/scala/millfork/node/Node.scala @@ -202,19 +202,26 @@ case class IndexedExpression(name: String, index: Expression) extends LhsExpress override def getAllIdentifiers: Set[String] = index.getAllIdentifiers + name } -case class IndirectFieldExpression(root: Expression, fields: List[String]) extends LhsExpression { - override def replaceVariable(variable: String, actualParam: Expression): Expression = IndirectFieldExpression(root.replaceVariable(variable, actualParam), fields) +case class IndirectFieldExpression(root: Expression, firstIndices: Seq[Expression], fields: Seq[(String, Seq[Expression])]) extends LhsExpression { + override def replaceVariable(variable: String, actualParam: Expression): Expression = + IndirectFieldExpression( + root.replaceVariable(variable, actualParam), + firstIndices.map(_.replaceVariable(variable, actualParam)), + fields.map{case (f, i) => f -> i.map(_.replaceVariable(variable, actualParam))}) - override def containsVariable(variable: String): Boolean = root.containsVariable(variable) + override def containsVariable(variable: String): Boolean = + root.containsVariable(variable) || + firstIndices.exists(_.containsVariable(variable)) || + fields.exists(_._2.exists(_.containsVariable(variable))) - override def getPointies: Seq[String] = root match { + override def getPointies: Seq[String] = (root match { case VariableExpression(v) => List(v) case _ => root.getPointies - } + }) ++ firstIndices.flatMap(_.getPointies) ++ fields.flatMap(_._2.flatMap(_.getPointies)) - override def isPure: Boolean = root.isPure + override def isPure: Boolean = root.isPure && firstIndices.forall(_.isPure) && fields.forall(_._2.forall(_.isPure)) - override def getAllIdentifiers: Set[String] = root.getAllIdentifiers + override def getAllIdentifiers: Set[String] = root.getAllIdentifiers ++ firstIndices.flatMap(_.getAllIdentifiers) ++ fields.flatMap(_._2.flatMap(_.getAllIdentifiers)) } case class DerefDebuggingExpression(inner: Expression, preferredSize: Int) extends LhsExpression { diff --git a/src/main/scala/millfork/node/opt/UnusedFunctions.scala b/src/main/scala/millfork/node/opt/UnusedFunctions.scala index 87d46d4b..aa1c5345 100644 --- a/src/main/scala/millfork/node/opt/UnusedFunctions.scala +++ b/src/main/scala/millfork/node/opt/UnusedFunctions.scala @@ -92,7 +92,7 @@ object UnusedFunctions extends NodeOptimization { case s: ArrayDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.elements.toList) case s: ArrayContents => getAllCalledFunctions(s.getAllExpressions) case s: FunctionDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.statements.getOrElse(Nil)) - case Assignment(VariableExpression(_), expr) => getAllCalledFunctions(expr :: Nil) + case Assignment(target, expr) => getAllCalledFunctions(target :: expr :: Nil) case s: ReturnDispatchStatement => getAllCalledFunctions(s.getAllExpressions) ++ getAllCalledFunctions(s.branches.map(_.function)) case s: Statement => getAllCalledFunctions(s.getAllExpressions) @@ -115,6 +115,8 @@ object UnusedFunctions extends NodeOptimization { case FunctionCallExpression(name, xs) => name :: getAllCalledFunctions(xs) case IndexedExpression(arr, index) => arr :: getAllCalledFunctions(List(index)) case SeparateBytesExpression(h, l) => getAllCalledFunctions(List(h, l)) + case DerefDebuggingExpression(inner, _) => getAllCalledFunctions(List(inner)) + case IndirectFieldExpression(root, firstIndices, fieldPath) => getAllCalledFunctions(root :: firstIndices ++: fieldPath.flatMap(_._2).toList) case _ => Nil } diff --git a/src/main/scala/millfork/node/opt/UnusedLocalVariables.scala b/src/main/scala/millfork/node/opt/UnusedLocalVariables.scala index 047e77a4..9b649f15 100644 --- a/src/main/scala/millfork/node/opt/UnusedLocalVariables.scala +++ b/src/main/scala/millfork/node/opt/UnusedLocalVariables.scala @@ -55,7 +55,7 @@ object UnusedLocalVariables extends NodeOptimization { case IndexedExpression(arr, index) => arr :: getAllReadVariables(List(index)) case DerefExpression(inner, _, _) => getAllReadVariables(List(inner)) case DerefDebuggingExpression(inner, _) => getAllReadVariables(List(inner)) - case IndirectFieldExpression(inner, _) => getAllReadVariables(List(inner)) + case IndirectFieldExpression(inner, firstIndices, fields) => getAllReadVariables(List(inner) ++ firstIndices ++ fields.flatMap(_._2)) case SeparateBytesExpression(h, l) => getAllReadVariables(List(h, l)) case _ => Nil } diff --git a/src/main/scala/millfork/parser/MfParser.scala b/src/main/scala/millfork/parser/MfParser.scala index 26245b46..2f7e267d 100644 --- a/src/main/scala/millfork/parser/MfParser.scala +++ b/src/main/scala/millfork/parser/MfParser.scala @@ -192,7 +192,7 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri start <- mfExpression(nonStatementLevel, false) ~ HWS ~ "," ~/ HWS ~/ Pass pos <- position("loop direction") direction <- forDirection ~/ HWS ~/ "," ~/ HWS ~/ Pass - end <- mfExpression(nonStatementLevel, false) + end <- mfExpression(nonStatementLevel, false, allowTopLevelIndexing = false) body <- AWS ~ arrayContents } yield { val fixedDirection = direction match { @@ -247,29 +247,29 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri contents <- ("=" ~/ HWS ~/ arrayContents).? ~/ HWS } yield Seq(ArrayDeclarationStatement(name, bank, length, elementType.getOrElse("byte"), addr, contents, alignment).pos(p)) - def tightMfExpression(allowIntelHex: Boolean): P[Expression] = { + def tightMfExpression(allowIntelHex: Boolean, allowTopLevelIndexing: Boolean): P[Expression] = { val a = if (allowIntelHex) atomWithIntel else atom - for { - expression <- mfParenExpr(allowIntelHex) | derefExpression | functionCall(allowIntelHex) | mfIndexedExpression | a - fieldPath <- ("->" ~/ AWS ~/ identifier).rep - } yield if (fieldPath.isEmpty) expression else IndirectFieldExpression(expression, fieldPath.toList) + if (allowTopLevelIndexing) + mfExpressionWrapper[Expression](mfParenExpr(allowIntelHex) | derefExpression | functionCall(allowIntelHex) | a) + else + mfParenExpr(allowIntelHex) | derefExpression | functionCall(allowIntelHex) | a } - def tightMfExpressionButNotCall(allowIntelHex: Boolean): P[Expression] = { + def tightMfExpressionButNotCall(allowIntelHex: Boolean, allowTopLevelIndexing: Boolean): P[Expression] = { val a = if (allowIntelHex) atomWithIntel else atom - for { - expression <- mfParenExpr(allowIntelHex) | derefExpression | mfIndexedExpression | a - fieldPath <- ("->" ~/ AWS ~/ identifier).rep - } yield if (fieldPath.isEmpty) expression else IndirectFieldExpression(expression, fieldPath.toList) + if (allowTopLevelIndexing) + mfExpressionWrapper[Expression](mfParenExpr(allowIntelHex) | derefExpression | a) + else + mfParenExpr(allowIntelHex) | derefExpression | a } - def mfExpression(level: Int, allowIntelHex: Boolean): P[Expression] = { + def mfExpression(level: Int, allowIntelHex: Boolean, allowTopLevelIndexing: Boolean = true): P[Expression] = { val allowedOperators = mfOperatorsDropFlatten(level) def inner: P[SeparatedList[Expression, String]] = { for { - head <- tightMfExpression(allowIntelHex) ~/ HWS - maybeOperator <- StringIn(allowedOperators: _*).!.? + head <- tightMfExpression(allowIntelHex, allowTopLevelIndexing) ~/ HWS + maybeOperator <- (StringIn(allowedOperators: _*).! ~ !CharIn(Seq('/', '=', '-', '+', ':', '>', '<', '\''))).? maybeTail <- maybeOperator.fold[P[Option[List[(String, Expression)]]]](Pass.map(_ => None))(o => (HWS ~/ inner ~/ HWS).map(x2 => Some((o -> x2.head) :: x2.tail))) } yield { maybeTail.fold[SeparatedList[Expression, String]](SeparatedList.of(head))(t => SeparatedList(head, t)) @@ -296,6 +296,13 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri } else { SeparateBytesExpression(p(xs.head, level + 1), p(xs.tail.head._2, level + 1)) } + case List(eq) if level == 0 => + if (xs.size != 2) { + log.error(s"The `$eq` operator can have only two arguments", xs.head.head.position) + LiteralExpression(0, 1) + } else { + FunctionCallExpression(eq, xs.items.map(value => p(value, level + 1))) + } case List(op) => FunctionCallExpression(op, xs.items.map(value => p(value, level + 1))) case _ => @@ -307,25 +314,31 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri inner.map(x => p(x, 0)) } - def mfLhsExpressionSimple: P[LhsExpression] = for { - expression <- mfIndexedExpression | derefExpression | (position() ~ identifier).map{case (p,n) => VariableExpression(n).pos(p)} ~ HWS - fieldPath <- ("->" ~/ AWS ~/ identifier).rep - } yield if (fieldPath.isEmpty) expression else IndirectFieldExpression(expression, fieldPath.toList) + def index: P[Expression] = HWS ~ "[" ~/ AWS ~/ mfExpression(nonStatementLevel, false) ~ AWS ~/ "]" ~/ Pass - def mfLhsExpression: P[LhsExpression] = for { - (p, left) <- position() ~ mfLhsExpressionSimple - rightOpt <- (HWS ~ ":" ~/ HWS ~ mfLhsExpressionSimple).? - } yield rightOpt.fold(left)(right => SeparateBytesExpression(left, right).pos(p)) + def mfExpressionWrapper[T <: Expression](inner: P[T]): P[T] = for { + expr <- inner + firstIndices <- index.rep + fieldPath <- (HWS ~ "->" ~/ AWS ~/ identifier ~/ index.rep).rep + } yield (expr, firstIndices, fieldPath) match { + case (_, Seq(), Seq()) => expr + case (VariableExpression(vname), Seq(i), Seq()) => IndexedExpression(vname, i).asInstanceOf[T] + case _ => IndirectFieldExpression(expr, firstIndices, fieldPath).asInstanceOf[T] + } +// def mfLhsExpression: P[LhsExpression] = for { +// (p, left) <- position() ~ mfLhsExpressionSimple +// rightOpt <- (HWS ~ ":" ~/ HWS ~ mfLhsExpressionSimple).? +// } yield rightOpt.fold(left)(right => SeparateBytesExpression(left, right).pos(p)) + + def mfLhsExpressionSimple: P[LhsExpression] = + mfExpressionWrapper[LhsExpression](derefExpression | (position() ~ identifier).map{case (p,n) => VariableExpression(n).pos(p)} ~ HWS) + + def mfLhsExpression: P[LhsExpression] = + mfExpression(nonStatementLevel, false).filter(_.isInstanceOf[LhsExpression]).map(_.asInstanceOf[LhsExpression]) def mfParenExpr(allowIntelHex: Boolean): P[Expression] = P("(" ~/ AWS ~/ mfExpression(nonStatementLevel, allowIntelHex) ~ AWS ~/ ")") - def mfIndexedExpression: P[IndexedExpression] = for { - p <- position() - array <- identifier - index <- HWS ~ "[" ~/ AWS ~/ mfExpression(nonStatementLevel, false) ~ AWS ~/ "]" - } yield IndexedExpression(array, index).pos(p) - def functionCall(allowIntelHex: Boolean): P[FunctionCallExpression] = for { p <- position() name <- identifier @@ -339,12 +352,15 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri inner <- mfParenExpr(false) } yield DerefDebuggingExpression(inner, yens.length).pos(p) - val expressionStatement: P[Seq[ExecutableStatement]] = mfExpression(0, false).map(x => Seq(ExpressionStatement(x))) - - val assignmentStatement: P[Seq[ExecutableStatement]] = - (position() ~ mfLhsExpression ~ HWS ~ "=" ~/ HWS ~ mfExpression(1, false)).map { - case (p, l, r) => Seq(Assignment(l, r).pos(p)) - } + val expressionStatement: P[Seq[ExecutableStatement]] = mfExpression(0, false).map { + case FunctionCallExpression("=", List(t: LhsExpression, s)) => + Seq(Assignment(t, s).pos(t.position)) + case x@FunctionCallExpression("=", exprs) => + log.error("Invalid left-hand-side of an assignment", x.position) + exprs.map(ExpressionStatement) + case x => + Seq(ExpressionStatement(x).pos(x.position)) + } def keywordStatement: P[Seq[ExecutableStatement]] = P( returnOrDispatchStatement | @@ -355,8 +371,7 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri doWhileStatement | breakStatement | continueStatement | - inlineAssembly | - assignmentStatement) + inlineAssembly) def executableStatement: P[Seq[ExecutableStatement]] = (position() ~ P(keywordStatement | expressionStatement)).map { case (p, s) => s.map(_.pos(p)) } @@ -391,7 +406,7 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri val dispatchBranch: P[ReturnDispatchBranch] = for { pos <- position() l <- dispatchLabel ~/ HWS ~/ "@" ~/ HWS - f <- tightMfExpressionButNotCall(false) ~/ HWS + f <- tightMfExpressionButNotCall(false, allowTopLevelIndexing = false) ~/ HWS parameters <- ("(" ~/ position("dispatch actual parameters") ~ AWS ~/ mfExpression(nonStatementLevel, false).rep(min = 0, sep = AWS ~ "," ~/ AWS) ~ AWS ~/ ")" ~/ "").? } yield ReturnDispatchBranch(l, f, parameters.map(_._2.toList).getOrElse(Nil)).pos(pos) @@ -641,7 +656,7 @@ object MfParser { } val mfOperators = List( - List("+=", "-=", "+'=", "-'=", "^=", "&=", "|=", "*=", "*'=", "<<=", ">>=", "<<'=", ">>'="), + List("+=", "-=", "+'=", "-'=", "^=", "&=", "|=", "*=", "*'=", "<<=", ">>=", "<<'=", ">>'=", "="), List("||", "^^"), List("&&"), List("==", "<=", ">=", "!=", "<", ">"), @@ -683,5 +698,5 @@ object MfParser { val functionFlags: P[Set[String]] = flags_("asm", "inline", "interrupt", "macro", "noinline", "reentrant", "kernal_interrupt") - val InvalidReturnTypes = Set("enum", "alias", "array", "const", "stack", "register", "static", "volatile", "import") + val InvalidReturnTypes = Set("enum", "alias", "array", "const", "stack", "register", "static", "volatile", "import", "struct", "union") } diff --git a/src/main/scala/millfork/parser/Preprocessor.scala b/src/main/scala/millfork/parser/Preprocessor.scala index 0fba7c43..a6d94cb4 100644 --- a/src/main/scala/millfork/parser/Preprocessor.scala +++ b/src/main/scala/millfork/parser/Preprocessor.scala @@ -188,7 +188,7 @@ class PreprocessorParser(options: CompilationOptions) { def inner: P[SeparatedList[Q, String]] = { for { head <- tightMfExpression ~/ HWS - maybeOperator <- StringIn(allowedOperators: _*).!.? + maybeOperator <- (StringIn(allowedOperators: _*).! ~ !CharIn(Seq('-','+','/'))).? maybeTail <- maybeOperator.fold[P[Option[List[(String, Q)]]]](Pass.map(_ => None))(o => (HWS ~/ inner ~/ HWS).map(x2 => Some((o -> x2.head) :: x2.tail))) } yield { maybeTail.fold[SeparatedList[Q, String]](SeparatedList.of(head))(t => SeparatedList(head, t)) diff --git a/src/test/scala/millfork/test/PointerSuite.scala b/src/test/scala/millfork/test/PointerSuite.scala index 39bb5504..16f97a88 100644 --- a/src/test/scala/millfork/test/PointerSuite.scala +++ b/src/test/scala/millfork/test/PointerSuite.scala @@ -2,12 +2,12 @@ package millfork.test import millfork.Cpu import millfork.test.emu.{EmuCrossPlatformBenchmarkRun, EmuUnoptimizedCrossPlatformRun} -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.{AppendedClues, FunSuite, Matchers} /** * @author Karol Stasiak */ -class PointerSuite extends FunSuite with Matchers { +class PointerSuite extends FunSuite with Matchers with AppendedClues { test("Pointers outside zeropage") { EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Sixteen, Cpu.Z80, Cpu.Intel8080, Cpu.Sharp)( @@ -168,4 +168,45 @@ class PointerSuite extends FunSuite with Matchers { } } + + test("Complex pointers") { + // TODO: optimize it when inlined + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( + """ + | array output[3] @$c000 + | struct s { + | pointer p + | } + | s tmp + | pointer.s tmpptr + | pointer.pointer.s get() { + | tmp.p = output.addr + | tmpptr = tmp.pointer + | return tmpptr.pointer + | } + | void main() { + | get()[0]->p[0] = 5 + | } + """.stripMargin) { m => + m.readByte(0xc000) should equal(5) + } + } + + test("Indexing returned pointers") { + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( + """ + | array output[10] @$c000 + | pointer get() = output.addr + | void main() { + | byte i + | for i,0,paralleluntil,10 { + | get()[i] = 42 + | } + | } + """.stripMargin) { m => + for(i <- 0xc000 until 0xc00a) { + m.readByte(i) should equal(42) withClue i + } + } + } }