From 89b23ee42577bbd5cb21719156542ce8ae1ea278 Mon Sep 17 00:00:00 2001 From: Karol Stasiak Date: Sat, 10 Mar 2018 21:52:28 +0100 Subject: [PATCH] Expanding macros from within assembly --- doc/lang/assembly.md | 16 +++++ src/main/scala/millfork/parser/MfParser.scala | 72 ++++++++++--------- src/test/scala/millfork/test/MacroSuite.scala | 21 ++++++ 3 files changed, 74 insertions(+), 35 deletions(-) diff --git a/doc/lang/assembly.md b/doc/lang/assembly.md index f499d164..6b44851d 100644 --- a/doc/lang/assembly.md +++ b/doc/lang/assembly.md @@ -52,6 +52,22 @@ but you need to be careful with using absolute vs immediate addressing: Any assembly opcode can be prefixed with `?`, which allows the optimizer change it or elide it if needed. Opcodes without that prefix will be always compiled as written. +You can insert macros into assembly, by prefixing them with `+` and using the same syntax as in Millfork: + + macro void run(byte x) { + output = x + } + + byte output @$c000 + + void main () { + byte a + a = 7 + asm { + + run(a) + } + } + Currently there is no way to insert raw bytes into inline assembly (required for certain optimizations and calling conventions). diff --git a/src/main/scala/millfork/parser/MfParser.scala b/src/main/scala/millfork/parser/MfParser.scala index 9c3ebee2..e720e808 100644 --- a/src/main/scala/millfork/parser/MfParser.scala +++ b/src/main/scala/millfork/parser/MfParser.scala @@ -138,7 +138,7 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o def flags(allowed: String*): P[Set[String]] = StringIn(allowed: _*).!.rep(min = 0, sep = SWS).map(_.toSet).opaque("") - def variableDefinition(implicitlyGlobal: Boolean): P[DeclarationStatement] = for { + def variableDefinition(implicitlyGlobal: Boolean): P[Seq[DeclarationStatement]] = for { p <- position() flags <- flags("const", "static", "volatile", "stack", "register") ~ HWS typ <- identifier ~ SWS @@ -147,13 +147,13 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o initialValue <- ("=" ~/ HWS ~/ mlExpression(1)).? ~ HWS _ <- &(EOL) ~/ "" } yield { - VariableDeclarationStatement(name, typ, + Seq(VariableDeclarationStatement(name, typ, global = implicitlyGlobal || flags("static"), stack = flags("stack"), constant = flags("const"), volatile = flags("volatile"), register = flags("register"), - initialValue, addr).pos(p) + initialValue, addr).pos(p)) } val externFunctionBody: P[Option[List[Statement]]] = P("extern" ~/ PassWith(None)) @@ -227,13 +227,13 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o def arrayContents: P[List[Expression]] = arrayListContents | arrayFileContents | arrayStringContents - def arrayDefinition: P[ArrayDeclarationStatement] = for { + def arrayDefinition: P[Seq[ArrayDeclarationStatement]] = for { p <- position() name <- "array" ~ !letterOrDigit ~/ SWS ~ identifier ~ HWS length <- ("[" ~/ AWS ~/ mlExpression(nonStatementLevel) ~ AWS ~ "]").? ~ HWS addr <- ("@" ~/ HWS ~/ mlExpression(1)).? ~/ HWS contents <- ("=" ~/ HWS ~/ arrayContents).? ~/ HWS - } yield ArrayDeclarationStatement(name, length, addr, contents).pos(p) + } yield Seq(ArrayDeclarationStatement(name, length, addr, contents).pos(p)) def tightMlExpression: P[Expression] = P(mlParenExpr | functionCall | mlIndexedExpression | atom) // TODO @@ -305,14 +305,14 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o params <- HWS ~ "(" ~/ AWS ~/ mlExpression(nonStatementLevel).rep(min = 0, sep = AWS ~ "," ~/ AWS) ~ AWS ~/ ")" ~/ "" } yield FunctionCallExpression(name, params.toList).pos(p) - val expressionStatement: P[ExecutableStatement] = mlExpression(0).map(ExpressionStatement) + val expressionStatement: P[Seq[ExecutableStatement]] = mlExpression(0).map(x => Seq(ExpressionStatement(x))) - val assignmentStatement: P[ExecutableStatement] = + val assignmentStatement: P[Seq[ExecutableStatement]] = (position() ~ mlLhsExpression ~ HWS ~ "=" ~/ HWS ~ mlExpression(1)).map { - case (p, l, r) => Assignment(l, r).pos(p) + case (p, l, r) => Seq(Assignment(l, r).pos(p)) } - def keywordStatement: P[ExecutableStatement] = P( + def keywordStatement: P[Seq[ExecutableStatement]] = P( returnOrDispatchStatement | ifStatement | whileStatement | @@ -323,7 +323,7 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o inlineAssembly | assignmentStatement) - def executableStatement: P[ExecutableStatement] = (position() ~ P(keywordStatement | expressionStatement)).map { case (p, s) => s.pos(p) } + def executableStatement: P[Seq[ExecutableStatement]] = (position() ~ P(keywordStatement | expressionStatement)).map { case (p, s) => s.map(_.pos(p)) } // TODO: label and instruction in one line def asmLabel: P[ExecutableStatement] = (identifier ~ HWS ~ ":" ~/ HWS).map(l => AssemblyStatement(Opcode.LABEL, AddrMode.DoesNotExist, VariableExpression(l), elidable = true)) @@ -381,15 +381,17 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o } } - def asmStatement: P[ExecutableStatement] = (position("assembly statement") ~ P(asmLabel | asmInstruction)).map { case (p, s) => s.pos(p) } // TODO: macros + def asmMacro: P[ExecutableStatement] = ("+" ~/ HWS ~/ functionCall).map(ExpressionStatement) - def statement: P[Statement] = (position() ~ P(keywordStatement | variableDefinition(false) | expressionStatement)).map { case (p, s) => s.pos(p) } + def asmStatement: P[ExecutableStatement] = (position("assembly statement") ~ P(asmLabel | asmMacro | asmInstruction)).map { case (p, s) => s.pos(p) } // TODO: macros + + def statement: P[Seq[Statement]] = (position() ~ P(keywordStatement | variableDefinition(false) | expressionStatement)).map { case (p, s) => s.map(_.pos(p)) } def asmStatements: P[List[ExecutableStatement]] = ("{" ~/ AWS ~/ asmStatement.rep(sep = NoCut(EOL) ~ !"}" ~/ Pass) ~/ AWS ~/ "}" ~/ Pass).map(_.toList) - def statements: P[List[Statement]] = ("{" ~/ AWS ~ statement.rep(sep = NoCut(EOL) ~ !"}" ~/ Pass) ~/ AWS ~/ "}" ~/ Pass).map(_.toList) + def statements: P[List[Statement]] = ("{" ~/ AWS ~ statement.rep(sep = NoCut(EOL) ~ !"}" ~/ Pass) ~/ AWS ~/ "}" ~/ Pass).map(_.flatten.toList) - def executableStatements: P[Seq[ExecutableStatement]] = "{" ~/ AWS ~/ executableStatement.rep(sep = NoCut(EOL) ~ !"}" ~/ Pass) ~/ AWS ~ "}" + def executableStatements: P[Seq[ExecutableStatement]] = ("{" ~/ AWS ~/ executableStatement.rep(sep = NoCut(EOL) ~ !"}" ~/ Pass) ~/ AWS ~ "}").map(_.flatten) def dispatchLabel: P[ReturnDispatchLabel] = ("default" ~ !letterOrDigit ~/ AWS ~/ ("(" ~/ position("default branch range") ~ AWS ~/ mlExpression(nonStatementLevel).rep(min = 0, sep = AWS ~ "," ~/ AWS) ~ AWS ~/ ")" ~/ "").?).map{ @@ -409,31 +411,31 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o parameters <- ("(" ~/ position("dispatch actual parameters") ~ AWS ~/ mlExpression(nonStatementLevel).rep(min = 0, sep = AWS ~ "," ~/ AWS) ~ AWS ~/ ")" ~/ "").? } yield ReturnDispatchBranch(l, f, parameters.map(_._2.toList).getOrElse(Nil)).pos(pos) - def dispatchStatementBody: P[ExecutableStatement] = for { + def dispatchStatementBody: P[Seq[ExecutableStatement]] = for { indexer <- "[" ~/ AWS ~/ mlExpression(nonStatementLevel) ~/ AWS ~/ "]" ~/ AWS _ <- position("dispatch statement body") parameters <- ("(" ~/ position("dispatch parameters") ~ AWS ~/ mlLhsExpression.rep(min = 0, sep = AWS ~ "," ~/ AWS) ~ AWS ~/ ")" ~/ "").? _ <- AWS ~/ position("dispatch statement body") ~/ "{" ~/ AWS branches <- dispatchBranch.rep(sep = EOL ~ !"}" ~/ Pass) _ <- AWS ~/ "}" - } yield ReturnDispatchStatement(indexer, parameters.map(_._2.toList).getOrElse(Nil), branches.toList) + } yield Seq(ReturnDispatchStatement(indexer, parameters.map(_._2.toList).getOrElse(Nil), branches.toList)) - def returnOrDispatchStatement: P[ExecutableStatement] = "return" ~ !letterOrDigit ~/ HWS ~ (dispatchStatementBody | mlExpression(nonStatementLevel).?.map(ReturnStatement)) + def returnOrDispatchStatement: P[Seq[ExecutableStatement]] = "return" ~ !letterOrDigit ~/ HWS ~ (dispatchStatementBody | mlExpression(nonStatementLevel).?.map(ReturnStatement).map(Seq(_))) - def breakStatement: P[ExecutableStatement] = ("break" ~ !letterOrDigit ~/ HWS ~ identifier.?).map(l => BreakStatement(l.getOrElse(""))) + def breakStatement: P[Seq[ExecutableStatement]] = ("break" ~ !letterOrDigit ~/ HWS ~ identifier.?).map(l => Seq(BreakStatement(l.getOrElse("")))) - def continueStatement: P[ExecutableStatement] = ("continue" ~ !letterOrDigit ~/ HWS ~ identifier.?).map(l => ContinueStatement(l.getOrElse(""))) + def continueStatement: P[Seq[ExecutableStatement]] = ("continue" ~ !letterOrDigit ~/ HWS ~ identifier.?).map(l => Seq(ContinueStatement(l.getOrElse("")))) - def ifStatement: P[ExecutableStatement] = for { + def ifStatement: P[Seq[ExecutableStatement]] = for { condition <- "if" ~ !letterOrDigit ~/ HWS ~/ mlExpression(nonStatementLevel) thenBranch <- AWS ~/ executableStatements - elseBranch <- (AWS ~ "else" ~/ AWS ~/ (ifStatement.map(_ :: Nil) | executableStatements)).? - } yield IfStatement(condition, thenBranch.toList, elseBranch.getOrElse(Nil).toList) + elseBranch <- (AWS ~ "else" ~/ AWS ~/ (ifStatement | executableStatements)).? + } yield Seq(IfStatement(condition, thenBranch.toList, elseBranch.getOrElse(Nil).toList)) - def whileStatement: P[ExecutableStatement] = for { + def whileStatement: P[Seq[ExecutableStatement]] = for { condition <- "while" ~ !letterOrDigit ~/ HWS ~/ mlExpression(nonStatementLevel) body <- AWS ~ executableStatements - } yield WhileStatement(condition, body.toList, Nil) + } yield Seq(WhileStatement(condition, body.toList, Nil)) def forDirection: P[ForDirection.Value] = ("parallel" ~ HWS ~ "to").!.map(_ => ForDirection.ParallelTo) | @@ -442,26 +444,26 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o "to".!.map(_ => ForDirection.To) | ("down" ~/ HWS ~/ "to").!.map(_ => ForDirection.DownTo) - def forStatement: P[ExecutableStatement] = for { + def forStatement: P[Seq[ExecutableStatement]] = for { identifier <- "for" ~ SWS ~/ identifier ~/ HWS ~ "," ~/ HWS ~ Pass start <- mlExpression(nonStatementLevel) ~ HWS ~ "," ~/ HWS ~/ Pass direction <- forDirection ~/ HWS ~/ "," ~/ HWS ~/ Pass end <- mlExpression(nonStatementLevel) body <- AWS ~ executableStatements - } yield ForStatement(identifier, start, end, direction, body.toList) + } yield Seq(ForStatement(identifier, start, end, direction, body.toList)) - def inlineAssembly: P[ExecutableStatement] = for { + def inlineAssembly: P[Seq[ExecutableStatement]] = for { condition <- "asm" ~ !letterOrDigit ~/ Pass body <- AWS ~ asmStatements - } yield BlockStatement(body) + } yield body //noinspection MutatorLikeMethodIsParameterless - def doWhileStatement: P[ExecutableStatement] = for { + def doWhileStatement: P[Seq[ExecutableStatement]] = for { body <- "do" ~ !letterOrDigit ~/ AWS ~ executableStatements ~/ AWS condition <- "while" ~ !letterOrDigit ~/ HWS ~/ mlExpression(nonStatementLevel) - } yield DoWhileStatement(body.toList, Nil, condition) + } yield Seq(DoWhileStatement(body.toList, Nil, condition)) - def functionDefinition: P[DeclarationStatement] = for { + def functionDefinition: P[Seq[DeclarationStatement]] = for { p <- position() flags <- flags("asm", "inline", "interrupt", "macro", "noinline", "reentrant", "kernal_interrupt") ~ HWS returnType <- identifier ~ SWS @@ -507,7 +509,7 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o } case None => () } - FunctionDeclarationStatement(name, returnType, params.toList, + Seq(FunctionDeclarationStatement(name, returnType, params.toList, addr, statements, flags("macro"), @@ -515,16 +517,16 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o flags("asm"), flags("interrupt"), flags("kernal_interrupt"), - flags("reentrant")).pos(p) + flags("reentrant")).pos(p)) } - def importStatement: Parser[ImportStatement] = ("import" ~ !letterOrDigit ~/ SWS ~/ identifier).map(ImportStatement) + def importStatement: Parser[Seq[ImportStatement]] = ("import" ~ !letterOrDigit ~/ SWS ~/ identifier).map(x => Seq(ImportStatement(x))) def program: Parser[Program] = for { _ <- Start ~/ AWS ~/ Pass definitions <- (importStatement | arrayDefinition | functionDefinition | variableDefinition(true)).rep(sep = EOL) _ <- AWS ~ End - } yield Program(definitions.toList) + } yield Program(definitions.flatten.toList) } diff --git a/src/test/scala/millfork/test/MacroSuite.scala b/src/test/scala/millfork/test/MacroSuite.scala index 00f2aa91..59fb0787 100644 --- a/src/test/scala/millfork/test/MacroSuite.scala +++ b/src/test/scala/millfork/test/MacroSuite.scala @@ -26,4 +26,25 @@ class MacroSuite extends FunSuite with Matchers { m.readByte(0xc000) should equal(7) } } + + test("Macros in assembly") { + EmuBenchmarkRun( + """ + | macro void run(byte x) { + | output = x + | } + | + | byte output @$c000 + | + | void main () { + | byte a + | a = 7 + | asm { + | + run(a) + | } + | } + """.stripMargin) { m => + m.readByte(0xc000) should equal(7) + } + } }