Macro improvements:

– allow local constants in macros
– allow untyped macro parameters with void
– treat the name of a function as a pointer to it
– add this.function local alias (#118)
This commit is contained in:
Karol Stasiak 2021-11-12 02:10:07 +01:00
parent c9313e5dbe
commit f676e74e38
12 changed files with 274 additions and 42 deletions

View File

@ -14,6 +14,8 @@ It implies the following:
* cannot contain variable or array declarations
* but can contain scalar constant declarations; the constants are scoped to the particular macro invocation
* can be `asm` - in this case, they should **not** end with a return instruction
* do not have an address
@ -35,7 +37,13 @@ It implies the following:
* `call` parameters exceptionally can have their type declared as `void`;
such parameters accept expressions of any type, including `void`, however, you cannot assign from those expressions
* macros do not have their own scope (they reuse the scope from their invocations) exceptions: the parameters and the local labels defined in assembly
* macros do not have their own scope (they reuse the scope from their invocations) exceptions:
* the parameters
* the local labels defined in assembly
* the local constants
* control-flow statements (`break`, `continue`, `return`, `goto`, `label`) are run as if places in the caller function

View File

@ -25,3 +25,5 @@
* `byte segment.N.bank` the value of `segment_N_bank` from the platform definition
* `byte segment.N.fill` the value of `segment_N_fill` from the platform definition
* `this.function` the alias of the current function (in macros, it resolves to the actual non-macro function that called the macro)

View File

@ -256,6 +256,10 @@ object AbstractExpressionCompiler {
def getExpressionTypeLoosely(ctx: CompilationContext, expr: Expression): Type = {
getExpressionTypeImpl(ctx.env, ctx.log, expr, loosely = true)
}
@inline
def getExpressionTypeForMacro(ctx: CompilationContext, expr: Expression): Type = {
getExpressionTypeImpl(ctx.env, ctx.log, expr, loosely = false, failWithVoid = true)
}
@inline
def getExpressionType(env: Environment, log: Logger, expr: Expression): Type = getExpressionTypeImpl(env, log, expr, loosely = false)
@ -263,7 +267,7 @@ object AbstractExpressionCompiler {
@inline
def getExpressionTypeLoosely(env: Environment, log: Logger, expr: Expression): Type = getExpressionTypeImpl(env, log, expr, loosely = true)
def getExpressionTypeImpl(env: Environment, log: Logger, expr: Expression, loosely: Boolean): Type = {
def getExpressionTypeImpl(env: Environment, log: Logger, expr: Expression, loosely: Boolean, failWithVoid: Boolean = false): Type = {
if (expr.typeCache ne null) return expr.typeCache
val b = env.get[Type]("byte")
val bool = env.get[Type]("bool$")
@ -351,6 +355,11 @@ object AbstractExpressionCompiler {
b
}
}
} else if (failWithVoid) {
env.maybeGet[TypedThing](name) match {
case Some(t) => t.typ
case None => VoidType
}
} else {
env.get[TypedThing](name, expr.position).typ
}

View File

@ -4,9 +4,12 @@ import millfork.assembly.AbstractCode
import millfork.assembly.m6809.MOpcode
import millfork.assembly.mos._
import millfork.assembly.z80.ZOpcode
import millfork.env
import millfork.env._
import millfork.node._
import scala.collection.mutable
/**
* @author Karol Stasiak
*/
@ -104,6 +107,7 @@ abstract class MacroExpander[T <: AbstractCode] {
def inlineFunction(ctx: CompilationContext, i: MacroFunction, actualParams: List[Expression], position: Option[Position]): (List[T], List[ExecutableStatement]) = {
var paramPreparation = List[T]()
var actualCode = i.code
var actualConstants = i.constants
i.params match {
case AssemblyOrMacroParamSignature(params) =>
params.foreach{ param =>
@ -133,6 +137,7 @@ abstract class MacroExpander[T <: AbstractCode] {
ctx.log.error("Const parameters to macro functions have to be constants", expr.position)
}
actualCode = actualCode.map(stmt => replaceVariableX(stmt, paramVariable.name.stripPrefix(i.environment.prefix), expr))
actualConstants = actualConstants.map(_.replaceVariableInInitialValue(paramVariable.name.stripPrefix(i.environment.prefix), expr))
case (expr, AssemblyOrMacroParam(paramType, paramVariable, AssemblyParameterPassingBehaviour.Eval)) =>
val castParam = FunctionCallExpression(paramType.name, List(expr))
actualCode = actualCode.map(stmt => replaceVariableX(stmt, paramVariable.name.stripPrefix(i.environment.prefix), castParam))
@ -142,6 +147,21 @@ abstract class MacroExpander[T <: AbstractCode] {
}
}
var flattenedConstants = mutable.MutableList[VariableDeclarationStatement]()
while(actualConstants.nonEmpty) {
val constant = actualConstants.head
flattenedConstants += constant
actualConstants = actualConstants.tail.map(_.replaceVariableInInitialValue(constant.name.stripPrefix(i.environment.prefix), constant.initialValue.get))
}
for (constant <- flattenedConstants) {
val valueExpr = constant.initialValue.get
ctx.env.eval(valueExpr) match {
case Some(c) =>
actualCode = actualCode.map(stmt => replaceVariableX(stmt, constant.name.stripPrefix(i.environment.prefix), valueExpr))
case None =>
ctx.log.error("Not a constant", constant.position)
}
}
// fix local labels:
// TODO: do it even if the labels are in an inline assembly block inside a Millfork function
val localLabels = actualCode.flatMap {

View File

@ -191,7 +191,7 @@ object MosStatementCompiler extends AbstractStatementCompiler[AssemblyLine] {
case Some(_) =>
params.flatMap(p => MosExpressionCompiler.compile(ctx, p, None, NoBranching))-> Nil
case None =>
env.lookupFunction(name, params.map(p => MosExpressionCompiler.getExpressionType(ctx, p) -> p)) match {
env.lookupFunction(name, params.map(p => AbstractExpressionCompiler.getExpressionTypeForMacro(ctx, p) -> p)) match {
case Some(i: MacroFunction) =>
val (paramPreparation, inlinedStatements) = MosMacroExpander.inlineFunction(ctx, i, params, e.position)
paramPreparation ++ compile(ctx.withInlinedEnv(i.environment, ctx.nextLabel("en")), inlinedStatements)._1 -> Nil

View File

@ -337,11 +337,11 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
t.asInstanceOf[T]
} else {
t match {
case Alias(_, target, deprectated) =>
case Alias(_, target, deprectated, local) =>
if (deprectated && options.flag(CompilationFlag.DeprecationWarning)) {
log.warn(s"Alias `$name` is deprecated, use `$target` instead", position)
}
root.get[T](target)
if (local) get[T](target) else root.get[T](target)
case _ => throw IdentifierHasWrongTypeOfThingException(clazz, name, position)
}
}
@ -365,11 +365,11 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
val t: Thing = things(name)
val clazz = implicitly[Manifest[T]].runtimeClass
t match {
case Alias(_, target, deprectated) =>
case Alias(_, target, deprectated, local) =>
if (deprectated && options.flag(CompilationFlag.DeprecationWarning)) {
log.warn(s"Alias `$name` is deprecated, use `$target` instead")
}
root.maybeGet[T](target)
if (local) maybeGet[T](target) else root.maybeGet[T](target)
case _ =>
if ((t ne null) && clazz.isInstance(t)) {
Some(t.asInstanceOf[T])
@ -720,7 +720,10 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
if (name.startsWith(".")) return Some(MemoryAddressConstant(Label(prefix + name)))
vv match {
case Some(m) if m.contains(name) => Some(m(name))
case _ => maybeGet[ConstantThing](name).map(_.value)
case _ => maybeGet[ConstantLikeThing](name).map {
case x: ConstantThing => x.value
case x: FunctionInMemory => x.toAddress
}
}
case IndexedExpression(arrName, index) =>
getPointy(arrName) match {
@ -1309,6 +1312,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
if (hasReturnVariable) {
registerVariable(VariableDeclarationStatement(stmt.name + ".return", stmt.resultType, None, global = true, stack = false, constant = false, volatile = false, register = false, None, None, Set.empty, None), options, isPointy = false)
}
val constants = mutable.MutableList[VariableDeclarationStatement]()
stmt.statements match {
case None =>
stmt.address match {
@ -1331,10 +1335,17 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
}
case Some(statements) =>
statements.foreach {
case v: VariableDeclarationStatement => env.registerVariable(v, options, pointies(v.name))
case a: ArrayDeclarationStatement => env.registerArray(a, options)
case _ => ()
if (stmt.isMacro) {
statements.foreach {
case v: VariableDeclarationStatement => constants += v
case _ => ()
}
} else {
statements.foreach {
case v: VariableDeclarationStatement => env.registerVariable(v, options, pointies(v.name))
case a: ArrayDeclarationStatement => env.registerArray(a, options)
case _ => ()
}
}
def scanForLabels(statement: Statement): Unit = statement match {
case c: CompoundStatement => c.getChildStatements.foreach(scanForLabels)
@ -1427,6 +1438,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
params.asInstanceOf[AssemblyOrMacroParamSignature],
stmt.assembly,
env,
constants.toList,
executableStatements
)
addThing(mangled, stmt.position)
@ -1461,6 +1473,11 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
}
}
}
if (!stmt.isMacro) {
val alias = Alias("this.function", name, local = true)
env.addThing("this.function", alias, None)
env.expandAlias(alias)
}
}
private def getFunctionPointerType(f: FunctionInMemory) = f.params.types match {
@ -2458,19 +2475,28 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
private def expandAliases(): Unit = {
val aliasesToAdd = mutable.ListBuffer[Alias]()
things.values.foreach{
case Alias(aliasName, target, deprecated) =>
val prefix = target + "."
things.foreach{
case (thingName, thing) =>
if (thingName.startsWith(prefix)) {
aliasesToAdd += Alias(aliasName + "." + thingName.stripPrefix(prefix), thingName, deprecated)
}
}
case a:Alias => aliasesToAdd ++= expandAliasImpl(a)
case _ => ()
}
aliasesToAdd.foreach(a => things += a.name -> a)
}
private def expandAliasImpl(a: Alias): Seq[Alias] = {
val aliasesToAdd = mutable.ListBuffer[Alias]()
val prefix = a.target + "."
root.things.foreach {
case (thingName, thing) =>
if (thingName.startsWith(prefix)) {
aliasesToAdd += Alias(a.name + "." + thingName.stripPrefix(prefix), thingName, a.deprecated, a.local)
}
}
aliasesToAdd
}
private def expandAlias(a: Alias): Unit = {
expandAliasImpl(a).foreach(a => things += a.name -> a)
}
def fixStructAlignments(): Unit = {
val allStructTypes: Iterable[CompoundVariableType] = things.values.flatMap {
case s@StructType(name, _, _) => Some(s)
@ -2644,7 +2670,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
}
if (!things.contains("memory_barrier")) {
things("memory_barrier") = MacroFunction("memory_barrier", v, AssemblyOrMacroParamSignature(Nil), isInAssembly = true, this, CpuFamily.forType(options.platform.cpu) match {
things("memory_barrier") = MacroFunction("memory_barrier", v, AssemblyOrMacroParamSignature(Nil), isInAssembly = true, this, Nil, CpuFamily.forType(options.platform.cpu) match {
case CpuFamily.M6502 => List(MosAssemblyStatement(Opcode.CHANGED_MEM, AddrMode.DoesNotExist, LiteralExpression(0, 1), Elidability.Fixed))
case CpuFamily.I80 => List(Z80AssemblyStatement(ZOpcode.CHANGED_MEM, NoRegisters, None, LiteralExpression(0, 1), Elidability.Fixed))
case CpuFamily.I86 => List(Z80AssemblyStatement(ZOpcode.CHANGED_MEM, NoRegisters, None, LiteralExpression(0, 1), Elidability.Fixed))
@ -2656,7 +2682,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
if (!things.contains("breakpoint")) {
val p = get[VariableType]("pointer")
if (options.flag(CompilationFlag.EnableBreakpoints)) {
things("breakpoint") = MacroFunction("breakpoint", v, AssemblyOrMacroParamSignature(Nil), isInAssembly = true, this, CpuFamily.forType(options.platform.cpu) match {
things("breakpoint") = MacroFunction("breakpoint", v, AssemblyOrMacroParamSignature(Nil), isInAssembly = true, this, Nil, CpuFamily.forType(options.platform.cpu) match {
case CpuFamily.M6502 => List(MosAssemblyStatement(Opcode.CHANGED_MEM, AddrMode.DoesNotExist, VariableExpression("..brk"), Elidability.Fixed))
case CpuFamily.I80 => List(Z80AssemblyStatement(ZOpcode.CHANGED_MEM, NoRegisters, None, VariableExpression("..brk"), Elidability.Fixed))
case CpuFamily.I86 => List(Z80AssemblyStatement(ZOpcode.CHANGED_MEM, NoRegisters, None, VariableExpression("..brk"), Elidability.Fixed))
@ -2664,7 +2690,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
case _ => ???
})
} else {
things("breakpoint") = MacroFunction("breakpoint", v, AssemblyOrMacroParamSignature(Nil), isInAssembly = true, this, Nil)
things("breakpoint") = MacroFunction("breakpoint", v, AssemblyOrMacroParamSignature(Nil), isInAssembly = true, this, Nil, Nil)
}
}
}
@ -2784,7 +2810,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
def getAliases: Map[String, String] = {
things.values.flatMap {
case Alias(a, b, _) => Some(a -> b)
case Alias(a, b, _, _) => Some(a -> b)
case _ => None
}.toMap ++ parent.map(_.getAliases).getOrElse(Map.empty)
}

View File

@ -1,7 +1,7 @@
package millfork.env
import millfork.assembly.BranchingOpcodeMapping
import millfork.{CompilationFlag, CompilationOptions, CpuFamily}
import millfork.{CompilationFlag, CompilationOptions, CpuFamily, env}
import millfork.node._
import millfork.output.{MemoryAlignment, NoAlignment}
@ -10,7 +10,7 @@ sealed trait Thing {
def rootName: String = name
}
case class Alias(name: String, target: String, deprecated: Boolean = false) extends Thing
case class Alias(name: String, target: String, deprecated: Boolean = false, local: Boolean = false) extends Thing
sealed trait CallableThing extends Thing
@ -69,6 +69,18 @@ case object VoidType extends Type {
override def alignment: MemoryAlignment = NoAlignment
}
case class InvalidType(nonce: Long) extends Type {
def size = 0
def alignedSize = 0
def isSigned = false
override def name = "$invalid"
override def alignment: MemoryAlignment = NoAlignment
}
sealed trait PlainType extends VariableType {
override def isCompatible(other: Type): Boolean = this == other || this.isSubtypeOf(other) || other.isSubtypeOf(this)
@ -248,6 +260,10 @@ sealed trait TypedThing extends Thing {
def typ: Type
}
sealed trait ConstantLikeThing extends TypedThing {
}
sealed trait ThingInMemory extends Thing {
def zeropage: Boolean
@ -490,6 +506,8 @@ sealed trait MangledFunction extends CallableThing {
def interrupt: Boolean
def kernalInterrupt: Boolean
def isConstPure: Boolean
def canBePointedTo: Boolean
@ -504,6 +522,8 @@ case class EmptyFunction(name: String,
override def interrupt = false
override def kernalInterrupt: Boolean = false
override def isConstPure = false
override def canBePointedTo: Boolean = false
@ -514,15 +534,18 @@ case class MacroFunction(name: String,
params: AssemblyOrMacroParamSignature,
isInAssembly: Boolean,
environment: Environment,
constants: List[VariableDeclarationStatement],
code: List[ExecutableStatement]) extends MangledFunction {
override def interrupt = false
override def kernalInterrupt: Boolean = false
override def isConstPure = false
override def canBePointedTo: Boolean = false
}
sealed trait FunctionInMemory extends MangledFunction with ThingInMemory {
sealed trait FunctionInMemory extends MangledFunction with ThingInMemory with TypedThing with VariableLikeThing with ConstantLikeThing {
def environment: Environment
override def isFar(compilationOptions: CompilationOptions): Boolean =
@ -538,6 +561,20 @@ sealed trait FunctionInMemory extends MangledFunction with ThingInMemory {
def optimizationHints: Set[String]
override def hasOptimizationHints: Boolean = optimizationHints.nonEmpty
override lazy val typ: Type = {
if (interrupt || kernalInterrupt) {
InvalidType(name.hashCode())
} else {
val paramType = params.types match {
case Nil => VoidType
case List(t) => t
case _ => InvalidType(params.types.hashCode())
}
val typeName = "function." + paramType.name + ".to." + returnType.name
FunctionPointerType(typeName, paramType.name, returnType.name, Some(paramType), Some(returnType))
}
}
}
case class ExternFunction(name: String,
@ -551,6 +588,8 @@ case class ExternFunction(name: String,
override def interrupt = false
override def kernalInterrupt: Boolean = false
override def isConstPure = false
override def zeropage: Boolean = false
@ -582,7 +621,7 @@ case class NormalFunction(name: String,
override def isVolatile: Boolean = false
}
case class ConstantThing(name: String, value: Constant, typ: Type) extends TypedThing with VariableLikeThing with IndexableThing {
case class ConstantThing(name: String, value: Constant, typ: Type) extends TypedThing with VariableLikeThing with IndexableThing with ConstantLikeThing {
def map(f: Constant => Constant) = ConstantThing("", f(value), typ)
}

View File

@ -173,18 +173,31 @@ case class SumExpression(expressions: List[(Boolean, Expression)], decimal: Bool
case class FunctionCallExpression(functionName: String, expressions: List[Expression]) extends Expression {
override def renameVariable(variable: String, newVariable: String): Expression =
FunctionCallExpression(functionName, expressions.map {
FunctionCallExpression(if (functionName == variable) newVariable else functionName, expressions.map {
_.renameVariable(variable, newVariable)
}).pos(position)
override def replaceVariable(variable: String, actualParam: Expression): Expression =
FunctionCallExpression(functionName, expressions.map {
_.replaceVariable(variable, actualParam)
}).pos(position)
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression =
override def replaceVariable(variable: String, actualParam: Expression): Expression = {
if (variable == functionName) {
actualParam match {
case VariableExpression(v) =>
FunctionCallExpression(v, expressions.map {
_.replaceVariable(variable, actualParam)
}).pos(position)
case _ =>
??? // TODO
}
} else {
FunctionCallExpression(functionName, expressions.map {
_.replaceVariable(variable, actualParam)
}).pos(position)
}
}
override def replaceIndexedExpression(predicate: IndexedExpression => Boolean, replacement: IndexedExpression => Expression): Expression =
FunctionCallExpression(functionName, expressions.map {
_.replaceIndexedExpression(predicate, replacement)
}).pos(position)
override def containsVariable(variable: String): Boolean = expressions.exists(_.containsVariable(variable))
override def containsVariable(variable: String): Boolean = variable == functionName || expressions.exists(_.containsVariable(variable))
override def getPointies: Seq[String] = expressions.flatMap(_.getPointies)
override def isPure: Boolean = false // TODO
override def getAllIdentifiers: Set[String] = expressions.map(_.getAllIdentifiers).fold(Set[String]())(_ ++ _) + functionName
@ -477,6 +490,11 @@ case class VariableDeclarationStatement(name: String,
override def getAllExpressions: List[Expression] = List(initialValue, address).flatten
override def withChangedBank(bank: String): BankedDeclarationStatement = copy(bank = Some(bank))
def replaceVariableInInitialValue(variable: String, expression: Expression): VariableDeclarationStatement = initialValue match {
case Some(v) => copy(initialValue = Some(v.replaceVariable(variable, expression)))
case None => this
}
}
trait ArrayContents extends Node {

View File

@ -733,14 +733,14 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri
if (flags("asm")) validateAsmFunctionBody(p, flags, name, statements)
if (flags("macro")) {
statements.flatMap(_.find(_.isInstanceOf[VariableDeclarationStatement])) match {
case Some(s) =>
case Some(s: VariableDeclarationStatement) if !s.constant =>
log.error(s"Macro functions cannot declare variables", s.position)
case None =>
case _ =>
}
statements.flatMap(_.find(_.isInstanceOf[ArrayDeclarationStatement])) match {
case Some(s) =>
log.error(s"Macro functions cannot declare arrays", s.position)
case None =>
case _ =>
}
}
Seq(FunctionDeclarationStatement(name, returnType, params.toList,

View File

@ -0,0 +1,46 @@
package millfork.test
import millfork.Cpu
import millfork.test.emu.EmuUnoptimizedCrossPlatformRun
import org.scalatest.{FunSuite, Matchers}
/**
* @author Karol Stasiak
*/
class AliasSuite extends FunSuite with Matchers {
test("Test aliases to subvariables") {
EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Motorola6809)(
"""
|
|
|word output @$c000
|alias o = output
|
|void main() {
| o.hi = 1
| o.lo = 2
|}
|
|""".stripMargin){ m =>
m.readWord(0xc000) should equal(0x102)
}
}
test("Test aliases to pointers") {
EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Motorola6809)(
"""
|
|
|pointer output @$c000
|alias o = output
|
|void main() {
| o = o.addr
|}
|
|""".stripMargin){ m =>
m.readWord(0xc000) should equal(0xc000)
}
}
}

View File

@ -51,10 +51,10 @@ class FunctionPointerSuite extends FunSuite with Matchers with AppendedClues{
| byte id(byte x) = x
|
| void main() {
| tabulate(output0, zero.pointer)
| tabulate(output1, id.pointer)
| tabulate(output2, double.pointer)
| tabulate(output3, negate.pointer)
| tabulate(output0, zero)
| tabulate(output1, id)
| tabulate(output2, double)
| tabulate(output3, negate)
| }
|
""".stripMargin) { m =>

View File

@ -307,4 +307,68 @@ class MacroSuite extends FunSuite with Matchers with AppendedClues {
|}
|""".stripMargin)
}
test("Should allow passing functions to a macro") {
EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Motorola6809)(
"""
|
|word output1 @$c000
|word output2 @$c002
|macro void f(void callback) {
| output1 = callback.addr
| callback()
|}
|
|void g() {}
|
|void main() {
| f(g)
| output2 = g.addr
|}
|""".stripMargin) {m =>
m.readWord(0xc000) should equal(m.readWord(0xc002))
}
}
test("Constants in macros") {
EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Motorola6809)(
"""
|macro void f(byte const x) {
| const byte y = 2*x
| const byte z = y
| output[i] = z
| i+=1
|}
|
|array (byte) output[55]@$c000
|void main() {
| byte i
| f(1)
| f(2)
| f(3)
| f(4)
|}
|""".stripMargin) { m =>
m.readByte(0xc000) should equal(2)
m.readByte(0xc001) should equal(4)
m.readByte(0xc002) should equal(6)
m.readByte(0xc003) should equal(8)
}
}
test("this.function") {
EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Motorola6809)(
"""
|macro void f(byte const x) {
| output = this.function.addr
|}
|
|pointer output @$c000
|void main() {
| f(1)
|}
|""".stripMargin) { m =>
m.readWord(0xc000) should equal(0x200)
}
}
}