From 8b6e89f9a4801907d03a35194f6c7a8f5043b66a Mon Sep 17 00:00:00 2001 From: Karol Stasiak Date: Sun, 2 Feb 2020 23:19:17 +0100 Subject: [PATCH] Various improvements for macros (fixes #39 and pertains to #40) --- docs/abi/inlining.md | 5 +- .../millfork/compiler/MacroExpander.scala | 46 +++++++- src/main/scala/millfork/env/Environment.scala | 10 +- src/main/scala/millfork/node/Node.scala | 76 ++++++++++++- src/test/scala/millfork/test/MacroSuite.scala | 107 +++++++++++++++++- 5 files changed, 236 insertions(+), 8 deletions(-) diff --git a/docs/abi/inlining.md b/docs/abi/inlining.md index 637d7801..13db60a5 100644 --- a/docs/abi/inlining.md +++ b/docs/abi/inlining.md @@ -20,7 +20,7 @@ It implies the following: * in case of `asm` macros, the parameters must be defined as either `const` (compile-time constants) or `ref` (variables) -* in case of non-`asm` macros, the parameters must be variables +* in case of non-`asm` macros, the parameters must be variables; exceptionally, their type may be declared as `void` * macros do not have their own scope (they reuse the scope from their invocations) – exceptions: the parameters and the local labels defined in assembly @@ -28,6 +28,9 @@ It implies the following: When invoking a macro, you need to pass variables as arguments to parameters annotated with `ref` and constants as arguments annotated with `const`. +Invoking a non-`asm` macro requires the types of passed variables to match precisely. No type conversions are performed. +Exception: parameters of type `void` can accept a variable of any type. + You can invoke a macro from assembly, by preceding the invocation with `+` Examples: diff --git a/src/main/scala/millfork/compiler/MacroExpander.scala b/src/main/scala/millfork/compiler/MacroExpander.scala index f15dfcf5..108599a3 100644 --- a/src/main/scala/millfork/compiler/MacroExpander.scala +++ b/src/main/scala/millfork/compiler/MacroExpander.scala @@ -14,6 +14,7 @@ abstract class MacroExpander[T <: AbstractCode] { def prepareAssemblyParams(ctx: CompilationContext, assParams: List[AssemblyParam], params: List[Expression], code: List[ExecutableStatement]): (List[T], List[ExecutableStatement]) def replaceVariable(stmt: Statement, paramName: String, target: Expression): Statement = { + val paramNamePeriod = paramName + "." def f[S <: Expression](e: S) = e.replaceVariable(paramName, target) def fx[S <: Expression](e: S) = e.replaceVariable(paramName, target).asInstanceOf[LhsExpression] @@ -22,7 +23,10 @@ abstract class MacroExpander[T <: AbstractCode] { def gx[S <: ExecutableStatement](s: S) = replaceVariable(s, paramName, target).asInstanceOf[ExecutableStatement] - def h(s: String) = if (s == paramName) target.asInstanceOf[VariableExpression].name else s + def h(s: String): String = + if (s == paramName) target.asInstanceOf[VariableExpression].name + else if (s.startsWith(paramNamePeriod)) target.asInstanceOf[VariableExpression].name + s.stripPrefix(paramName) + else s (stmt match { case RawBytesStatement(contents, be) => RawBytesStatement(contents.replaceVariable(paramName, target), be) @@ -47,6 +51,44 @@ abstract class MacroExpander[T <: AbstractCode] { }).pos(stmt.position) } + def renameVariable(stmt: Statement, paramName: String, target: String): Statement = { + val paramNamePeriod = paramName + "." + def f[S <: Expression](e: S) = e.renameVariable(paramName, target) + + def fx[S <: Expression](e: S) = e.renameVariable(paramName, target).asInstanceOf[LhsExpression] + + def g[S <: Statement](s: S) = renameVariable(s, paramName, target) + + def gx[S <: ExecutableStatement](s: S) = renameVariable(s, paramName, target).asInstanceOf[ExecutableStatement] + + def h(s: String): String = + if (s == paramName) target.asInstanceOf[VariableExpression].name + else if (s.startsWith(paramNamePeriod)) target.asInstanceOf[VariableExpression].name + s.stripPrefix(paramName) + else s + + (stmt match { + case RawBytesStatement(contents, be) => RawBytesStatement(contents.renameVariable(paramName, target), be) + case ExpressionStatement(e) => ExpressionStatement(e.renameVariable(paramName, target)) + case ReturnStatement(e) => ReturnStatement(e.map(f)) + case ReturnDispatchStatement(i, ps, bs) => ReturnDispatchStatement(i.renameVariable(paramName, target), ps.map(fx), bs.map { + case ReturnDispatchBranch(l, fu, pps) => ReturnDispatchBranch(l, f(fu), pps.map(f)) + }) + case WhileStatement(c, b, i, n) => WhileStatement(f(c), b.map(gx), i.map(gx), n) + case DoWhileStatement(b, i, c, n) => DoWhileStatement(b.map(gx), i.map(gx), f(c), n) + case ForStatement(v, start, end, dir, body) => ForStatement(h(v), f(start), f(end), dir, body.map(gx)) + case IfStatement(c, t, e) => IfStatement(f(c), t.map(gx), e.map(gx)) + case s: Z80AssemblyStatement => s.copy(expression = f(s.expression), offsetExpression = s.offsetExpression.map(f)) + case s: MosAssemblyStatement => s.copy(expression = f(s.expression)) + case Assignment(d, s) => Assignment(fx(d), f(s)) + case BreakStatement(s) => if (s == paramName) BreakStatement(target.toString) else stmt + case ContinueStatement(s) => if (s == paramName) ContinueStatement(target.toString) else stmt + case s: EmptyStatement => s.copy(toTypecheck = s.toTypecheck.map(gx)) + case _ => + println(stmt) + ??? + }).pos(stmt.position) + } + def inlineFunction(ctx: CompilationContext, i: MacroFunction, params: List[Expression], position: Option[Position]): (List[T], List[ExecutableStatement]) = { var paramPreparation = List[T]() var actualCode = i.code @@ -62,7 +104,7 @@ abstract class MacroExpander[T <: AbstractCode] { normalParams.foreach(param => i.environment.removeVariable(param.name)) params.zip(normalParams).foreach { case (v@VariableExpression(_), MemoryVariable(paramName, paramType, _)) => - actualCode = actualCode.map(stmt => replaceVariable(stmt, paramName.stripPrefix(i.environment.prefix), v).asInstanceOf[ExecutableStatement]) + actualCode = actualCode.map(stmt => renameVariable(stmt, paramName.stripPrefix(i.environment.prefix), v.name).asInstanceOf[ExecutableStatement]) case (v@IndexedExpression(_, _), MemoryVariable(paramName, paramType, _)) => actualCode = actualCode.map(stmt => replaceVariable(stmt, paramName.stripPrefix(i.environment.prefix), v).asInstanceOf[ExecutableStatement]) case _ => diff --git a/src/main/scala/millfork/env/Environment.scala b/src/main/scala/millfork/env/Environment.scala index e76bf416..dbe44ef1 100644 --- a/src/main/scala/millfork/env/Environment.scala +++ b/src/main/scala/millfork/env/Environment.scala @@ -1971,8 +1971,14 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa function.params match { case NormalParamSignature(params) => function.params.types.zip(actualParams).zip(params).foreach { case ((required, (actual, expr)), m) => - if (!actual.isAssignableTo(required)) { - log.error(s"Invalid value for parameter `${m.name}` of function `$name`", expr.position) + if (function.isInstanceOf[MacroFunction]) { + if (required != VoidType && actual != required) { + log.error(s"Invalid argument type for parameter `${m.name}` of macro function `$name`: required: ${required.name}, actual: ${actual.name}", expr.position) + } + } else { + if (!actual.isAssignableTo(required)) { + log.error(s"Invalid value for parameter `${m.name}` of function `$name`", expr.position) + } } } case AssemblyParamSignature(params) => diff --git a/src/main/scala/millfork/node/Node.scala b/src/main/scala/millfork/node/Node.scala index 9c683968..70af40a7 100644 --- a/src/main/scala/millfork/node/Node.scala +++ b/src/main/scala/millfork/node/Node.scala @@ -34,6 +34,7 @@ object Node { } sealed trait Expression extends Node { + def renameVariable(variable: String, newVariable: String): Expression def replaceVariable(variable: String, actualParam: Expression): Expression def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression def containsVariable(variable: String): Boolean @@ -55,6 +56,7 @@ sealed trait Expression extends Node { } case class ConstantArrayElementExpression(constant: Constant) extends Expression { + override def renameVariable(variable: String, newVariable: String): Expression = this override def replaceVariable(variable: String, actualParam: Expression): Expression = this override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = this override def containsVariable(variable: String): Boolean = false @@ -64,6 +66,7 @@ case class ConstantArrayElementExpression(constant: Constant) extends Expression } case class LiteralExpression(value: Long, requiredSize: Int) extends Expression { + override def renameVariable(variable: String, newVariable: String): Expression = this override def replaceVariable(variable: String, actualParam: Expression): Expression = this override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = this override def containsVariable(variable: String): Boolean = false @@ -73,6 +76,7 @@ case class LiteralExpression(value: Long, requiredSize: Int) extends Expression } case class TextLiteralExpression(characters: List[Expression]) extends Expression { + override def renameVariable(variable: String, newVariable: String): Expression = this override def replaceVariable(variable: String, actualParam: Expression): Expression = this override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = this override def containsVariable(variable: String): Boolean = false @@ -82,6 +86,7 @@ case class TextLiteralExpression(characters: List[Expression]) extends Expressio } case class GeneratedConstantExpression(value: Constant, typ: Type) extends Expression { + override def renameVariable(variable: String, newVariable: String): Expression = this override def replaceVariable(variable: String, actualParam: Expression): Expression = this override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = this override def containsVariable(variable: String): Boolean = false @@ -91,6 +96,7 @@ case class GeneratedConstantExpression(value: Constant, typ: Type) extends Expre } case class BooleanLiteralExpression(value: Boolean) extends Expression { + override def renameVariable(variable: String, newVariable: String): Expression = this override def replaceVariable(variable: String, actualParam: Expression): Expression = this override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = this override def containsVariable(variable: String): Boolean = false @@ -102,6 +108,7 @@ case class BooleanLiteralExpression(value: Boolean) extends Expression { sealed trait LhsExpression extends Expression case object BlackHoleExpression extends LhsExpression { + override def renameVariable(variable: String, newVariable: String): Expression = this override def replaceVariable(variable: String, actualParam: Expression): LhsExpression = this override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = this override def containsVariable(variable: String): Boolean = false @@ -111,6 +118,10 @@ case object BlackHoleExpression extends LhsExpression { } case class SeparateBytesExpression(hi: Expression, lo: Expression) extends LhsExpression { + override def renameVariable(variable: String, newVariable: String): Expression = + SeparateBytesExpression( + hi.renameVariable(variable, newVariable), + lo.renameVariable(variable, newVariable)).pos(position) def replaceVariable(variable: String, actualParam: Expression): Expression = SeparateBytesExpression( hi.replaceVariable(variable, actualParam), @@ -126,6 +137,8 @@ case class SeparateBytesExpression(hi: Expression, lo: Expression) extends LhsEx } case class SumExpression(expressions: List[(Boolean, Expression)], decimal: Boolean) extends Expression { + override def renameVariable(variable: String, newVariable: String): Expression = + SumExpression(expressions.map { case (n, e) => n -> e.renameVariable(variable, newVariable) }, decimal).pos(position) override def replaceVariable(variable: String, actualParam: Expression): Expression = SumExpression(expressions.map { case (n, e) => n -> e.replaceVariable(variable, actualParam) }, decimal).pos(position) override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = @@ -147,6 +160,10 @@ case class SumExpression(expressions: List[(Boolean, Expression)], decimal: Bool } case class FunctionCallExpression(functionName: String, expressions: List[Expression]) extends Expression { + override def renameVariable(variable: String, newVariable: String): Expression = + FunctionCallExpression(functionName, expressions.map { + _.renameVariable(variable, newVariable) + }).pos(position) override def replaceVariable(variable: String, actualParam: Expression): Expression = FunctionCallExpression(functionName, expressions.map { _.replaceVariable(variable, actualParam) @@ -162,6 +179,8 @@ case class FunctionCallExpression(functionName: String, expressions: List[Expres } case class HalfWordExpression(expression: Expression, hiByte: Boolean) extends Expression { + override def renameVariable(variable: String, newVariable: String): Expression = + HalfWordExpression(expression.renameVariable(variable, newVariable), hiByte).pos(position) override def replaceVariable(variable: String, actualParam: Expression): Expression = HalfWordExpression(expression.replaceVariable(variable, actualParam), hiByte).pos(position) override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = @@ -267,8 +286,25 @@ object M6809Register extends Enumeration { //case class Indexing(child: Expression, register: Register.Value) extends Expression case class VariableExpression(name: String) extends LhsExpression { + override def renameVariable(variable: String, newVariable: String): Expression = + if (name == variable) + VariableExpression(newVariable).pos(position) + else if (name.startsWith(variable) && name(variable.length) == '.') + VariableExpression(newVariable + name.stripPrefix(variable)).pos(position) + else this override def replaceVariable(variable: String, actualParam: Expression): Expression = - if (name == variable) actualParam else this + if (name == variable) actualParam + else if (name.startsWith(variable) && name(variable.length) == '.') { + actualParam match { + case VariableExpression(newVariable) => this.renameVariable(variable, newVariable) + case _ => + name.stripPrefix(variable) match { + case ".lo" => FunctionCallExpression("lo", List(this)).pos(position) + case ".hi" => FunctionCallExpression("hi", List(this)).pos(position) + case _ => ??? // TODO + } + } + } else this override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = this override def containsVariable(variable: String): Boolean = name == variable override def getPointies: Seq[String] = if (name.endsWith(".addr.lo")) Seq(name.stripSuffix(".addr.lo")) else Seq.empty @@ -277,6 +313,16 @@ case class VariableExpression(name: String) extends LhsExpression { } case class IndexedExpression(name: String, index: Expression) extends LhsExpression { + override def renameVariable(variable: String, newVariable: String): Expression = { + val newIndex = index.renameVariable(variable, newVariable) + if (name == variable) + IndexedExpression(newVariable, newIndex).pos(position) + else if (name.startsWith(variable) && name(variable.length) == '.') + IndexedExpression(newVariable + name.stripPrefix(variable), newIndex).pos(position) + else + IndexedExpression(name, newIndex).pos(position) + } + override def replaceVariable(variable: String, actualParam: Expression): Expression = if (name == variable) { actualParam match { @@ -295,6 +341,11 @@ case class IndexedExpression(name: String, index: Expression) extends LhsExpress } case class IndirectFieldExpression(root: Expression, firstIndices: Seq[Expression], fields: Seq[(Boolean, String, Seq[Expression])]) extends LhsExpression { + override def renameVariable(variable: String, newVariable: String): Expression = + IndirectFieldExpression( + root.renameVariable(variable, newVariable), + firstIndices.map(_.renameVariable(variable, newVariable)), + fields.map{case (dot, f, i) => (dot, f, i.map(_.renameVariable(variable, newVariable)))}) override def replaceVariable(variable: String, actualParam: Expression): Expression = IndirectFieldExpression( root.replaceVariable(variable, actualParam), @@ -323,8 +374,10 @@ case class IndirectFieldExpression(root: Expression, firstIndices: Seq[Expressio } case class DerefDebuggingExpression(inner: Expression, preferredSize: Int) extends LhsExpression { + override def renameVariable(variable: String, newVariable: String): Expression = DerefDebuggingExpression(inner.renameVariable(variable, newVariable), preferredSize) + override def replaceVariable(variable: String, actualParam: Expression): Expression = DerefDebuggingExpression(inner.replaceVariable(variable, actualParam), preferredSize) - + override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = DerefDebuggingExpression(inner.replaceIndexedExpression(predicate, replacement), preferredSize) @@ -341,6 +394,7 @@ case class DerefDebuggingExpression(inner: Expression, preferredSize: Int) exten } case class DerefExpression(inner: Expression, offset: Int, targetType: Type) extends LhsExpression { + override def renameVariable(variable: String, newVariable: String): Expression = DerefExpression(inner.renameVariable(variable, newVariable), offset, targetType) override def replaceVariable(variable: String, actualParam: Expression): Expression = DerefExpression(inner.replaceVariable(variable, actualParam), offset, targetType) override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = @@ -396,12 +450,16 @@ case class VariableDeclarationStatement(name: String, trait ArrayContents extends Node { def getAllExpressions(bigEndian: Boolean): List[Expression] + def renameVariable(variableToRename: String, newVariable: String): ArrayContents def replaceVariable(variableToReplace: String, expression: Expression): ArrayContents } case class LiteralContents(contents: List[Expression]) extends ArrayContents { override def getAllExpressions(bigEndian: Boolean): List[Expression] = contents + override def renameVariable(variableToRename: String, newVariable: String): ArrayContents = + LiteralContents(contents.map(_.renameVariable(variableToRename, newVariable))) + override def replaceVariable(variable: String, expression: Expression): ArrayContents = LiteralContents(contents.map(_.replaceVariable(variable, expression))) } @@ -409,6 +467,14 @@ case class LiteralContents(contents: List[Expression]) extends ArrayContents { case class ForLoopContents(variable: String, start: Expression, end: Expression, direction: ForDirection.Value, body: ArrayContents) extends ArrayContents { override def getAllExpressions(bigEndian: Boolean): List[Expression] = start :: end :: body.getAllExpressions(bigEndian).map(_.replaceVariable(variable, LiteralExpression(0, 1))) + override def renameVariable(variableToRename: String, newVariable: String): ArrayContents = + if (variableToRename == variable) this else ForLoopContents( + variable, + start.renameVariable(variableToRename, newVariable), + end.renameVariable(variableToRename, newVariable), + direction, + body.renameVariable(variableToRename, newVariable)) + override def replaceVariable(variableToReplace: String, expression: Expression): ArrayContents = if (variableToReplace == variable) this else ForLoopContents( variable, @@ -421,6 +487,9 @@ case class ForLoopContents(variable: String, start: Expression, end: Expression, case class CombinedContents(contents: List[ArrayContents]) extends ArrayContents { override def getAllExpressions(bigEndian: Boolean): List[Expression] = contents.flatMap(_.getAllExpressions(bigEndian)) + override def renameVariable(variableToRename: String, newVariable: String): ArrayContents = + CombinedContents(contents.map(_.renameVariable(variableToRename, newVariable))) + override def replaceVariable(variableToReplace: String, expression: Expression): ArrayContents = CombinedContents(contents.map(_.replaceVariable(variableToReplace, expression))) } @@ -459,6 +528,9 @@ case class ProcessedContents(processor: String, values: ArrayContents) extends A case "struct" => values.getAllExpressions(bigEndian) // not used for emitting actual arrays } + override def renameVariable(variableToRename: String, newVariable: String): ArrayContents = + ProcessedContents(processor, values.renameVariable(variableToRename, newVariable)) + override def replaceVariable(variableToReplace: String, expression: Expression): ArrayContents = ProcessedContents(processor, values.replaceVariable(variableToReplace, expression)) } diff --git a/src/test/scala/millfork/test/MacroSuite.scala b/src/test/scala/millfork/test/MacroSuite.scala index 8a4ea1e0..d1781d22 100644 --- a/src/test/scala/millfork/test/MacroSuite.scala +++ b/src/test/scala/millfork/test/MacroSuite.scala @@ -1,7 +1,7 @@ package millfork.test import millfork.Cpu -import millfork.test.emu.{EmuBenchmarkRun, EmuCrossPlatformBenchmarkRun, EmuUnoptimizedCrossPlatformRun} +import millfork.test.emu.{EmuBenchmarkRun, EmuCrossPlatformBenchmarkRun, EmuUnoptimizedCrossPlatformRun, ShouldNotCompile} import org.scalatest.{FunSuite, Matchers} /** @@ -73,4 +73,109 @@ class MacroSuite extends FunSuite with Matchers { m.readByte(0xc000) should equal(7) } } + + test("Macro parameter type mismatch") { + ShouldNotCompile( + """ + | byte input + | byte output @$c000 + | + |void main() { + | input = $FF + | test_signed_macro(input) + |} + | + |macro void test_signed_macro(sbyte value) { + | if value > 3 { + | output = 1 + | } + |} + """.stripMargin) + } + + test("Macro void parameter") { + EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Motorola6809)( + """ + | byte input + | byte output @$c000 + | + |void main() { + | input = $FF + | test_signed_macro(input) + |} + | + |macro void test_signed_macro(void value) { + | if value > 3 { + | output = 1 + | } + |} + """.stripMargin) { m => + m.readByte(0xc000) should equal(1) + } + } + + test("Some important macro test") { + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80, Cpu.Motorola6809)( + """ + | byte input + | byte output @$c000 + | + |void main() { + | input = $FF + | test_signed_macro(input) + |} + | + |macro void test_signed_macro(void value) { + | if sbyte(value) > 3 { + | output = 1 + | } else { + | output = 3 + | } + |} + """.stripMargin) { m => + m.readByte(0xc000) should equal(3) + } + } + + test("Accessing fields of macro parameters") { + EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Motorola6809)( + """ + |byte output @$c000 + | + |word test = $0380 + | + |void main() { + | test_macro(test) + |} + | + |macro void test_macro(word value) { + | if value.hi > 0 { + | output = 1 + | } + |} + """.stripMargin) { m => + m.readByte(0xc000) should equal(1) + } + } + + test("Accessing fields of macro parameters when using void") { + EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Motorola6809)( + """ + |byte output @$c000 + | + |word test = $0380 + | + |void main() { + | test_macro(test) + |} + | + |macro void test_macro(void value) { + | if value.hi > 0 { + | output = 1 + | } + |} + """.stripMargin) { m => + m.readByte(0xc000) should equal(1) + } + } }