1
0
mirror of https://github.com/KarolS/millfork.git synced 2025-01-26 20:33:02 +00:00

Various improvements for macros (fixes #39 and pertains to #40)

This commit is contained in:
Karol Stasiak 2020-02-02 23:19:17 +01:00
parent 5cb4717de6
commit 8b6e89f9a4
5 changed files with 236 additions and 8 deletions

View File

@ -20,7 +20,7 @@ It implies the following:
* in case of `asm` macros, the parameters must be defined as either `const` (compile-time constants) or `ref` (variables)
* in case of non-`asm` macros, the parameters must be variables
* in case of non-`asm` macros, the parameters must be variables; exceptionally, their type may be declared as `void`
* macros do not have their own scope (they reuse the scope from their invocations) exceptions: the parameters and the local labels defined in assembly
@ -28,6 +28,9 @@ It implies the following:
When invoking a macro, you need to pass variables as arguments to parameters annotated with `ref` and constants as arguments annotated with `const`.
Invoking a non-`asm` macro requires the types of passed variables to match precisely. No type conversions are performed.
Exception: parameters of type `void` can accept a variable of any type.
You can invoke a macro from assembly, by preceding the invocation with `+`
Examples:

View File

@ -14,6 +14,7 @@ abstract class MacroExpander[T <: AbstractCode] {
def prepareAssemblyParams(ctx: CompilationContext, assParams: List[AssemblyParam], params: List[Expression], code: List[ExecutableStatement]): (List[T], List[ExecutableStatement])
def replaceVariable(stmt: Statement, paramName: String, target: Expression): Statement = {
val paramNamePeriod = paramName + "."
def f[S <: Expression](e: S) = e.replaceVariable(paramName, target)
def fx[S <: Expression](e: S) = e.replaceVariable(paramName, target).asInstanceOf[LhsExpression]
@ -22,7 +23,10 @@ abstract class MacroExpander[T <: AbstractCode] {
def gx[S <: ExecutableStatement](s: S) = replaceVariable(s, paramName, target).asInstanceOf[ExecutableStatement]
def h(s: String) = if (s == paramName) target.asInstanceOf[VariableExpression].name else s
def h(s: String): String =
if (s == paramName) target.asInstanceOf[VariableExpression].name
else if (s.startsWith(paramNamePeriod)) target.asInstanceOf[VariableExpression].name + s.stripPrefix(paramName)
else s
(stmt match {
case RawBytesStatement(contents, be) => RawBytesStatement(contents.replaceVariable(paramName, target), be)
@ -47,6 +51,44 @@ abstract class MacroExpander[T <: AbstractCode] {
}).pos(stmt.position)
}
def renameVariable(stmt: Statement, paramName: String, target: String): Statement = {
val paramNamePeriod = paramName + "."
def f[S <: Expression](e: S) = e.renameVariable(paramName, target)
def fx[S <: Expression](e: S) = e.renameVariable(paramName, target).asInstanceOf[LhsExpression]
def g[S <: Statement](s: S) = renameVariable(s, paramName, target)
def gx[S <: ExecutableStatement](s: S) = renameVariable(s, paramName, target).asInstanceOf[ExecutableStatement]
def h(s: String): String =
if (s == paramName) target.asInstanceOf[VariableExpression].name
else if (s.startsWith(paramNamePeriod)) target.asInstanceOf[VariableExpression].name + s.stripPrefix(paramName)
else s
(stmt match {
case RawBytesStatement(contents, be) => RawBytesStatement(contents.renameVariable(paramName, target), be)
case ExpressionStatement(e) => ExpressionStatement(e.renameVariable(paramName, target))
case ReturnStatement(e) => ReturnStatement(e.map(f))
case ReturnDispatchStatement(i, ps, bs) => ReturnDispatchStatement(i.renameVariable(paramName, target), ps.map(fx), bs.map {
case ReturnDispatchBranch(l, fu, pps) => ReturnDispatchBranch(l, f(fu), pps.map(f))
})
case WhileStatement(c, b, i, n) => WhileStatement(f(c), b.map(gx), i.map(gx), n)
case DoWhileStatement(b, i, c, n) => DoWhileStatement(b.map(gx), i.map(gx), f(c), n)
case ForStatement(v, start, end, dir, body) => ForStatement(h(v), f(start), f(end), dir, body.map(gx))
case IfStatement(c, t, e) => IfStatement(f(c), t.map(gx), e.map(gx))
case s: Z80AssemblyStatement => s.copy(expression = f(s.expression), offsetExpression = s.offsetExpression.map(f))
case s: MosAssemblyStatement => s.copy(expression = f(s.expression))
case Assignment(d, s) => Assignment(fx(d), f(s))
case BreakStatement(s) => if (s == paramName) BreakStatement(target.toString) else stmt
case ContinueStatement(s) => if (s == paramName) ContinueStatement(target.toString) else stmt
case s: EmptyStatement => s.copy(toTypecheck = s.toTypecheck.map(gx))
case _ =>
println(stmt)
???
}).pos(stmt.position)
}
def inlineFunction(ctx: CompilationContext, i: MacroFunction, params: List[Expression], position: Option[Position]): (List[T], List[ExecutableStatement]) = {
var paramPreparation = List[T]()
var actualCode = i.code
@ -62,7 +104,7 @@ abstract class MacroExpander[T <: AbstractCode] {
normalParams.foreach(param => i.environment.removeVariable(param.name))
params.zip(normalParams).foreach {
case (v@VariableExpression(_), MemoryVariable(paramName, paramType, _)) =>
actualCode = actualCode.map(stmt => replaceVariable(stmt, paramName.stripPrefix(i.environment.prefix), v).asInstanceOf[ExecutableStatement])
actualCode = actualCode.map(stmt => renameVariable(stmt, paramName.stripPrefix(i.environment.prefix), v.name).asInstanceOf[ExecutableStatement])
case (v@IndexedExpression(_, _), MemoryVariable(paramName, paramType, _)) =>
actualCode = actualCode.map(stmt => replaceVariable(stmt, paramName.stripPrefix(i.environment.prefix), v).asInstanceOf[ExecutableStatement])
case _ =>

View File

@ -1971,8 +1971,14 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
function.params match {
case NormalParamSignature(params) =>
function.params.types.zip(actualParams).zip(params).foreach { case ((required, (actual, expr)), m) =>
if (!actual.isAssignableTo(required)) {
log.error(s"Invalid value for parameter `${m.name}` of function `$name`", expr.position)
if (function.isInstanceOf[MacroFunction]) {
if (required != VoidType && actual != required) {
log.error(s"Invalid argument type for parameter `${m.name}` of macro function `$name`: required: ${required.name}, actual: ${actual.name}", expr.position)
}
} else {
if (!actual.isAssignableTo(required)) {
log.error(s"Invalid value for parameter `${m.name}` of function `$name`", expr.position)
}
}
}
case AssemblyParamSignature(params) =>

View File

@ -34,6 +34,7 @@ object Node {
}
sealed trait Expression extends Node {
def renameVariable(variable: String, newVariable: String): Expression
def replaceVariable(variable: String, actualParam: Expression): Expression
def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression
def containsVariable(variable: String): Boolean
@ -55,6 +56,7 @@ sealed trait Expression extends Node {
}
case class ConstantArrayElementExpression(constant: Constant) extends Expression {
override def renameVariable(variable: String, newVariable: String): Expression = this
override def replaceVariable(variable: String, actualParam: Expression): Expression = this
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = this
override def containsVariable(variable: String): Boolean = false
@ -64,6 +66,7 @@ case class ConstantArrayElementExpression(constant: Constant) extends Expression
}
case class LiteralExpression(value: Long, requiredSize: Int) extends Expression {
override def renameVariable(variable: String, newVariable: String): Expression = this
override def replaceVariable(variable: String, actualParam: Expression): Expression = this
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = this
override def containsVariable(variable: String): Boolean = false
@ -73,6 +76,7 @@ case class LiteralExpression(value: Long, requiredSize: Int) extends Expression
}
case class TextLiteralExpression(characters: List[Expression]) extends Expression {
override def renameVariable(variable: String, newVariable: String): Expression = this
override def replaceVariable(variable: String, actualParam: Expression): Expression = this
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = this
override def containsVariable(variable: String): Boolean = false
@ -82,6 +86,7 @@ case class TextLiteralExpression(characters: List[Expression]) extends Expressio
}
case class GeneratedConstantExpression(value: Constant, typ: Type) extends Expression {
override def renameVariable(variable: String, newVariable: String): Expression = this
override def replaceVariable(variable: String, actualParam: Expression): Expression = this
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = this
override def containsVariable(variable: String): Boolean = false
@ -91,6 +96,7 @@ case class GeneratedConstantExpression(value: Constant, typ: Type) extends Expre
}
case class BooleanLiteralExpression(value: Boolean) extends Expression {
override def renameVariable(variable: String, newVariable: String): Expression = this
override def replaceVariable(variable: String, actualParam: Expression): Expression = this
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = this
override def containsVariable(variable: String): Boolean = false
@ -102,6 +108,7 @@ case class BooleanLiteralExpression(value: Boolean) extends Expression {
sealed trait LhsExpression extends Expression
case object BlackHoleExpression extends LhsExpression {
override def renameVariable(variable: String, newVariable: String): Expression = this
override def replaceVariable(variable: String, actualParam: Expression): LhsExpression = this
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = this
override def containsVariable(variable: String): Boolean = false
@ -111,6 +118,10 @@ case object BlackHoleExpression extends LhsExpression {
}
case class SeparateBytesExpression(hi: Expression, lo: Expression) extends LhsExpression {
override def renameVariable(variable: String, newVariable: String): Expression =
SeparateBytesExpression(
hi.renameVariable(variable, newVariable),
lo.renameVariable(variable, newVariable)).pos(position)
def replaceVariable(variable: String, actualParam: Expression): Expression =
SeparateBytesExpression(
hi.replaceVariable(variable, actualParam),
@ -126,6 +137,8 @@ case class SeparateBytesExpression(hi: Expression, lo: Expression) extends LhsEx
}
case class SumExpression(expressions: List[(Boolean, Expression)], decimal: Boolean) extends Expression {
override def renameVariable(variable: String, newVariable: String): Expression =
SumExpression(expressions.map { case (n, e) => n -> e.renameVariable(variable, newVariable) }, decimal).pos(position)
override def replaceVariable(variable: String, actualParam: Expression): Expression =
SumExpression(expressions.map { case (n, e) => n -> e.replaceVariable(variable, actualParam) }, decimal).pos(position)
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression =
@ -147,6 +160,10 @@ case class SumExpression(expressions: List[(Boolean, Expression)], decimal: Bool
}
case class FunctionCallExpression(functionName: String, expressions: List[Expression]) extends Expression {
override def renameVariable(variable: String, newVariable: String): Expression =
FunctionCallExpression(functionName, expressions.map {
_.renameVariable(variable, newVariable)
}).pos(position)
override def replaceVariable(variable: String, actualParam: Expression): Expression =
FunctionCallExpression(functionName, expressions.map {
_.replaceVariable(variable, actualParam)
@ -162,6 +179,8 @@ case class FunctionCallExpression(functionName: String, expressions: List[Expres
}
case class HalfWordExpression(expression: Expression, hiByte: Boolean) extends Expression {
override def renameVariable(variable: String, newVariable: String): Expression =
HalfWordExpression(expression.renameVariable(variable, newVariable), hiByte).pos(position)
override def replaceVariable(variable: String, actualParam: Expression): Expression =
HalfWordExpression(expression.replaceVariable(variable, actualParam), hiByte).pos(position)
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression =
@ -267,8 +286,25 @@ object M6809Register extends Enumeration {
//case class Indexing(child: Expression, register: Register.Value) extends Expression
case class VariableExpression(name: String) extends LhsExpression {
override def renameVariable(variable: String, newVariable: String): Expression =
if (name == variable)
VariableExpression(newVariable).pos(position)
else if (name.startsWith(variable) && name(variable.length) == '.')
VariableExpression(newVariable + name.stripPrefix(variable)).pos(position)
else this
override def replaceVariable(variable: String, actualParam: Expression): Expression =
if (name == variable) actualParam else this
if (name == variable) actualParam
else if (name.startsWith(variable) && name(variable.length) == '.') {
actualParam match {
case VariableExpression(newVariable) => this.renameVariable(variable, newVariable)
case _ =>
name.stripPrefix(variable) match {
case ".lo" => FunctionCallExpression("lo", List(this)).pos(position)
case ".hi" => FunctionCallExpression("hi", List(this)).pos(position)
case _ => ??? // TODO
}
}
} else this
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression = this
override def containsVariable(variable: String): Boolean = name == variable
override def getPointies: Seq[String] = if (name.endsWith(".addr.lo")) Seq(name.stripSuffix(".addr.lo")) else Seq.empty
@ -277,6 +313,16 @@ case class VariableExpression(name: String) extends LhsExpression {
}
case class IndexedExpression(name: String, index: Expression) extends LhsExpression {
override def renameVariable(variable: String, newVariable: String): Expression = {
val newIndex = index.renameVariable(variable, newVariable)
if (name == variable)
IndexedExpression(newVariable, newIndex).pos(position)
else if (name.startsWith(variable) && name(variable.length) == '.')
IndexedExpression(newVariable + name.stripPrefix(variable), newIndex).pos(position)
else
IndexedExpression(name, newIndex).pos(position)
}
override def replaceVariable(variable: String, actualParam: Expression): Expression =
if (name == variable) {
actualParam match {
@ -295,6 +341,11 @@ case class IndexedExpression(name: String, index: Expression) extends LhsExpress
}
case class IndirectFieldExpression(root: Expression, firstIndices: Seq[Expression], fields: Seq[(Boolean, String, Seq[Expression])]) extends LhsExpression {
override def renameVariable(variable: String, newVariable: String): Expression =
IndirectFieldExpression(
root.renameVariable(variable, newVariable),
firstIndices.map(_.renameVariable(variable, newVariable)),
fields.map{case (dot, f, i) => (dot, f, i.map(_.renameVariable(variable, newVariable)))})
override def replaceVariable(variable: String, actualParam: Expression): Expression =
IndirectFieldExpression(
root.replaceVariable(variable, actualParam),
@ -323,8 +374,10 @@ case class IndirectFieldExpression(root: Expression, firstIndices: Seq[Expressio
}
case class DerefDebuggingExpression(inner: Expression, preferredSize: Int) extends LhsExpression {
override def renameVariable(variable: String, newVariable: String): Expression = DerefDebuggingExpression(inner.renameVariable(variable, newVariable), preferredSize)
override def replaceVariable(variable: String, actualParam: Expression): Expression = DerefDebuggingExpression(inner.replaceVariable(variable, actualParam), preferredSize)
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression =
DerefDebuggingExpression(inner.replaceIndexedExpression(predicate, replacement), preferredSize)
@ -341,6 +394,7 @@ case class DerefDebuggingExpression(inner: Expression, preferredSize: Int) exten
}
case class DerefExpression(inner: Expression, offset: Int, targetType: Type) extends LhsExpression {
override def renameVariable(variable: String, newVariable: String): Expression = DerefExpression(inner.renameVariable(variable, newVariable), offset, targetType)
override def replaceVariable(variable: String, actualParam: Expression): Expression = DerefExpression(inner.replaceVariable(variable, actualParam), offset, targetType)
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression =
@ -396,12 +450,16 @@ case class VariableDeclarationStatement(name: String,
trait ArrayContents extends Node {
def getAllExpressions(bigEndian: Boolean): List[Expression]
def renameVariable(variableToRename: String, newVariable: String): ArrayContents
def replaceVariable(variableToReplace: String, expression: Expression): ArrayContents
}
case class LiteralContents(contents: List[Expression]) extends ArrayContents {
override def getAllExpressions(bigEndian: Boolean): List[Expression] = contents
override def renameVariable(variableToRename: String, newVariable: String): ArrayContents =
LiteralContents(contents.map(_.renameVariable(variableToRename, newVariable)))
override def replaceVariable(variable: String, expression: Expression): ArrayContents =
LiteralContents(contents.map(_.replaceVariable(variable, expression)))
}
@ -409,6 +467,14 @@ case class LiteralContents(contents: List[Expression]) extends ArrayContents {
case class ForLoopContents(variable: String, start: Expression, end: Expression, direction: ForDirection.Value, body: ArrayContents) extends ArrayContents {
override def getAllExpressions(bigEndian: Boolean): List[Expression] = start :: end :: body.getAllExpressions(bigEndian).map(_.replaceVariable(variable, LiteralExpression(0, 1)))
override def renameVariable(variableToRename: String, newVariable: String): ArrayContents =
if (variableToRename == variable) this else ForLoopContents(
variable,
start.renameVariable(variableToRename, newVariable),
end.renameVariable(variableToRename, newVariable),
direction,
body.renameVariable(variableToRename, newVariable))
override def replaceVariable(variableToReplace: String, expression: Expression): ArrayContents =
if (variableToReplace == variable) this else ForLoopContents(
variable,
@ -421,6 +487,9 @@ case class ForLoopContents(variable: String, start: Expression, end: Expression,
case class CombinedContents(contents: List[ArrayContents]) extends ArrayContents {
override def getAllExpressions(bigEndian: Boolean): List[Expression] = contents.flatMap(_.getAllExpressions(bigEndian))
override def renameVariable(variableToRename: String, newVariable: String): ArrayContents =
CombinedContents(contents.map(_.renameVariable(variableToRename, newVariable)))
override def replaceVariable(variableToReplace: String, expression: Expression): ArrayContents =
CombinedContents(contents.map(_.replaceVariable(variableToReplace, expression)))
}
@ -459,6 +528,9 @@ case class ProcessedContents(processor: String, values: ArrayContents) extends A
case "struct" => values.getAllExpressions(bigEndian) // not used for emitting actual arrays
}
override def renameVariable(variableToRename: String, newVariable: String): ArrayContents =
ProcessedContents(processor, values.renameVariable(variableToRename, newVariable))
override def replaceVariable(variableToReplace: String, expression: Expression): ArrayContents =
ProcessedContents(processor, values.replaceVariable(variableToReplace, expression))
}

View File

@ -1,7 +1,7 @@
package millfork.test
import millfork.Cpu
import millfork.test.emu.{EmuBenchmarkRun, EmuCrossPlatformBenchmarkRun, EmuUnoptimizedCrossPlatformRun}
import millfork.test.emu.{EmuBenchmarkRun, EmuCrossPlatformBenchmarkRun, EmuUnoptimizedCrossPlatformRun, ShouldNotCompile}
import org.scalatest.{FunSuite, Matchers}
/**
@ -73,4 +73,109 @@ class MacroSuite extends FunSuite with Matchers {
m.readByte(0xc000) should equal(7)
}
}
test("Macro parameter type mismatch") {
ShouldNotCompile(
"""
| byte input
| byte output @$c000
|
|void main() {
| input = $FF
| test_signed_macro(input)
|}
|
|macro void test_signed_macro(sbyte value) {
| if value > 3 {
| output = 1
| }
|}
""".stripMargin)
}
test("Macro void parameter") {
EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Motorola6809)(
"""
| byte input
| byte output @$c000
|
|void main() {
| input = $FF
| test_signed_macro(input)
|}
|
|macro void test_signed_macro(void value) {
| if value > 3 {
| output = 1
| }
|}
""".stripMargin) { m =>
m.readByte(0xc000) should equal(1)
}
}
test("Some important macro test") {
EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80, Cpu.Motorola6809)(
"""
| byte input
| byte output @$c000
|
|void main() {
| input = $FF
| test_signed_macro(input)
|}
|
|macro void test_signed_macro(void value) {
| if sbyte(value) > 3 {
| output = 1
| } else {
| output = 3
| }
|}
""".stripMargin) { m =>
m.readByte(0xc000) should equal(3)
}
}
test("Accessing fields of macro parameters") {
EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Motorola6809)(
"""
|byte output @$c000
|
|word test = $0380
|
|void main() {
| test_macro(test)
|}
|
|macro void test_macro(word value) {
| if value.hi > 0 {
| output = 1
| }
|}
""".stripMargin) { m =>
m.readByte(0xc000) should equal(1)
}
}
test("Accessing fields of macro parameters when using void") {
EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Motorola6809)(
"""
|byte output @$c000
|
|word test = $0380
|
|void main() {
| test_macro(test)
|}
|
|macro void test_macro(void value) {
| if value.hi > 0 {
| output = 1
| }
|}
""".stripMargin) { m =>
m.readByte(0xc000) should equal(1)
}
}
}