From 76dd4929a61063e85d9b9ba84f21362e209ef6e9 Mon Sep 17 00:00:00 2001 From: Karol Stasiak Date: Fri, 21 Dec 2018 22:36:05 +0100 Subject: [PATCH] 6502: Track which pointers need to be on zeropage --- .../compiler/AbstractExpressionCompiler.scala | 5 ++-- .../compiler/mos/MosExpressionCompiler.scala | 2 +- .../compiler/z80/Z80ExpressionCompiler.scala | 2 +- src/main/scala/millfork/env/Environment.scala | 28 +++++++++++++------ src/main/scala/millfork/node/Node.scala | 20 +++++++++++++ 5 files changed, 45 insertions(+), 12 deletions(-) diff --git a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala index 406bf327..63ec966e 100644 --- a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala @@ -27,12 +27,13 @@ class AbstractExpressionCompiler[T <: AbstractCode] { // TODO } - def callingContext(ctx: CompilationContext, v: MemoryVariable): CompilationContext = { + def callingContext(ctx: CompilationContext, callee: String, v: MemoryVariable): CompilationContext = { val result = new Environment(Some(ctx.env), "", ctx.options.platform.cpuFamily, ctx.jobContext) + val isPointy = ctx.env.isKnownPointy(callee, v.name.stripPrefix(callee + '$')) result.registerVariable(VariableDeclarationStatement( v.name, v.typ.name, stack = false, global = false, constant = false, volatile = false, register = false, - initialValue = None, address = None, bank = v.declaredBank, alignment = None), ctx.options) + initialValue = None, address = None, bank = v.declaredBank, alignment = None), ctx.options, isPointy = isPointy) ctx.copy(env = result) } diff --git a/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala b/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala index b35e3b35..bba6b991 100644 --- a/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala @@ -1166,7 +1166,7 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] { case NormalParamSignature(paramVars) => params.zip(paramVars).flatMap { case (paramExpr, paramVar) => - val callCtx = callingContext(ctx, paramVar) + val callCtx = callingContext(ctx, function.name, paramVar) compileAssignment(callCtx, paramExpr, VariableExpression(paramVar.name)) } ++ List(AssemblyLine.absoluteOrLongAbsolute(JSR, function, ctx.options)) } diff --git a/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala b/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala index b0a9a0f0..e13ce416 100644 --- a/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala @@ -895,7 +895,7 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { case NormalParamSignature(paramVars) => params.zip(paramVars).flatMap { case (paramExpr, paramVar) => - val callCtx = callingContext(ctx, paramVar) + val callCtx = callingContext(ctx, function.name, paramVar) paramVar.typ.size match { case 1 => compileToA(ctx, paramExpr) ++ storeA(callCtx, VariableExpression(paramVar.name), paramVar.typ.isSigned) diff --git a/src/main/scala/millfork/env/Environment.scala b/src/main/scala/millfork/env/Environment.scala index 903c1645..a4e1acac 100644 --- a/src/main/scala/millfork/env/Environment.scala +++ b/src/main/scala/millfork/env/Environment.scala @@ -210,8 +210,8 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa private def addThing(t: Thing, position: Option[Position]): Unit = { if (assertNotDefined(t.name, position)) { - things(t.name.stripPrefix(prefix)) = t - } + things(t.name.stripPrefix(prefix)) = t + } } def removeVariable(str: String): Unit = { @@ -726,7 +726,17 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa } } + def collectPointies(stmts: Seq[Statement]): Set[String] = { + val pointies: mutable.Set[String] = new mutable.HashSet() + pointies ++= stmts.flatMap(_.getAllPointies) + pointies ++ getAliases.filterKeys(pointies).values + log.trace("Collected pointies: " + pointies) + pointies.toSet + } + def registerFunction(stmt: FunctionDeclarationStatement, options: CompilationOptions): Unit = { + val pointies = collectPointies(stmt.statements.getOrElse(Seq.empty)) + pointiesUsed(stmt.name) = pointies val w = get[Type]("word") val name = stmt.name val resultType = get[Type](stmt.resultType) @@ -782,7 +792,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa }) } if (resultType.size > Cpu.getMaxSizeReturnableViaRegisters(options.platform.cpu, options)) { - registerVariable(VariableDeclarationStatement(stmt.name + ".return", stmt.resultType, None, global = true, stack = false, constant = false, volatile = false, register = false, None, None, None), options) + registerVariable(VariableDeclarationStatement(stmt.name + ".return", stmt.resultType, None, global = true, stack = false, constant = false, volatile = false, register = false, None, None, None), options, isPointy = false) } stmt.statements match { case None => @@ -806,7 +816,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa case Some(statements) => statements.foreach { - case v: VariableDeclarationStatement => env.registerVariable(v, options) + case v: VariableDeclarationStatement => env.registerVariable(v, options, pointies(v.name)) case _ => () } val executableStatements = statements.flatMap { @@ -1108,7 +1118,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa } } - def registerVariable(stmt: VariableDeclarationStatement, options: CompilationOptions): Unit = { + def registerVariable(stmt: VariableDeclarationStatement, options: CompilationOptions, isPointy: Boolean): Unit = { val name = stmt.name val position = stmt.position if (stmt.stack && parent.isEmpty) { @@ -1160,7 +1170,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa } else { val (v, addr) = stmt.address.fold[(VariableInMemory, Constant)]({ val alloc = - if (typ.name == "pointer" || typ.name == "__reg$type") VariableAllocationMethod.Zeropage + if (isPointy || typ.name == "__reg$type") VariableAllocationMethod.Zeropage else if (stmt.global) VariableAllocationMethod.Static else if (stmt.register) VariableAllocationMethod.Register else VariableAllocationMethod.Auto @@ -1280,9 +1290,11 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa case e: EnumDefinitionStatement => registerEnum(e) case _ => } + val pointies = collectPointies(program.declarations) + pointiesUsed("") = pointies program.declarations.foreach { case f: FunctionDeclarationStatement => registerFunction(f, options) - case v: VariableDeclarationStatement => registerVariable(v, options) + case v: VariableDeclarationStatement => registerVariable(v, options, pointies(v.name)) case a: ArrayDeclarationStatement => registerArray(a, options) case _ => } @@ -1299,7 +1311,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa register = false, initialValue = None, address = None, - alignment = None), options) + alignment = None), options, isPointy = true) } if (CpuFamily.forType(options.platform.cpu) == CpuFamily.M6502) { if (!things.contains("__constant8")) { diff --git a/src/main/scala/millfork/node/Node.scala b/src/main/scala/millfork/node/Node.scala index 295f26fa..c3d8c3ae 100644 --- a/src/main/scala/millfork/node/Node.scala +++ b/src/main/scala/millfork/node/Node.scala @@ -30,6 +30,7 @@ object Node { sealed trait Expression extends Node { def replaceVariable(variable: String, actualParam: Expression): Expression def containsVariable(variable: String): Boolean + def getPointies: Seq[String] def isPure: Boolean def getAllIdentifiers: Set[String] } @@ -37,6 +38,7 @@ sealed trait Expression extends Node { case class ConstantArrayElementExpression(constant: Constant) extends Expression { override def replaceVariable(variable: String, actualParam: Expression): Expression = this override def containsVariable(variable: String): Boolean = false + override def getPointies: Seq[String] = Seq.empty override def isPure: Boolean = true override def getAllIdentifiers: Set[String] = Set.empty } @@ -44,6 +46,7 @@ case class ConstantArrayElementExpression(constant: Constant) extends Expression case class LiteralExpression(value: Long, requiredSize: Int) extends Expression { override def replaceVariable(variable: String, actualParam: Expression): Expression = this override def containsVariable(variable: String): Boolean = false + override def getPointies: Seq[String] = Seq.empty override def isPure: Boolean = true override def getAllIdentifiers: Set[String] = Set.empty } @@ -51,6 +54,7 @@ case class LiteralExpression(value: Long, requiredSize: Int) extends Expression case class TextLiteralExpression(characters: List[Expression]) extends Expression { override def replaceVariable(variable: String, actualParam: Expression): Expression = this override def containsVariable(variable: String): Boolean = false + override def getPointies: Seq[String] = Seq.empty override def isPure: Boolean = true override def getAllIdentifiers: Set[String] = Set.empty } @@ -58,6 +62,7 @@ case class TextLiteralExpression(characters: List[Expression]) extends Expressio case class GeneratedConstantExpression(value: Constant, typ: Type) extends Expression { override def replaceVariable(variable: String, actualParam: Expression): Expression = this override def containsVariable(variable: String): Boolean = false + override def getPointies: Seq[String] = Seq.empty override def isPure: Boolean = true override def getAllIdentifiers: Set[String] = Set.empty } @@ -65,6 +70,7 @@ case class GeneratedConstantExpression(value: Constant, typ: Type) extends Expre case class BooleanLiteralExpression(value: Boolean) extends Expression { override def replaceVariable(variable: String, actualParam: Expression): Expression = this override def containsVariable(variable: String): Boolean = false + override def getPointies: Seq[String] = Seq.empty override def isPure: Boolean = true override def getAllIdentifiers: Set[String] = Set.empty } @@ -74,6 +80,7 @@ sealed trait LhsExpression extends Expression case object BlackHoleExpression extends LhsExpression { override def replaceVariable(variable: String, actualParam: Expression): LhsExpression = this override def containsVariable(variable: String): Boolean = false + override def getPointies: Seq[String] = Seq.empty override def isPure: Boolean = true override def getAllIdentifiers: Set[String] = Set.empty } @@ -84,6 +91,7 @@ case class SeparateBytesExpression(hi: Expression, lo: Expression) extends LhsEx hi.replaceVariable(variable, actualParam), lo.replaceVariable(variable, actualParam)).pos(position) override def containsVariable(variable: String): Boolean = hi.containsVariable(variable) || lo.containsVariable(variable) + override def getPointies: Seq[String] = hi.getPointies ++ lo.getPointies override def isPure: Boolean = hi.isPure && lo.isPure override def getAllIdentifiers: Set[String] = hi.getAllIdentifiers ++ lo.getAllIdentifiers } @@ -92,6 +100,7 @@ case class SumExpression(expressions: List[(Boolean, Expression)], decimal: Bool override def replaceVariable(variable: String, actualParam: Expression): Expression = SumExpression(expressions.map { case (n, e) => n -> e.replaceVariable(variable, actualParam) }, decimal).pos(position) override def containsVariable(variable: String): Boolean = expressions.exists(_._2.containsVariable(variable)) + override def getPointies: Seq[String] = expressions.flatMap(_._2.getPointies) override def isPure: Boolean = expressions.forall(_._2.isPure) override def getAllIdentifiers: Set[String] = expressions.map(_._2.getAllIdentifiers).fold(Set[String]())(_ ++ _) } @@ -102,6 +111,7 @@ case class FunctionCallExpression(functionName: String, expressions: List[Expres _.replaceVariable(variable, actualParam) }).pos(position) override def containsVariable(variable: String): Boolean = expressions.exists(_.containsVariable(variable)) + override def getPointies: Seq[String] = expressions.flatMap(_.getPointies) override def isPure: Boolean = false // TODO override def getAllIdentifiers: Set[String] = expressions.map(_.getAllIdentifiers).fold(Set[String]())(_ ++ _) + functionName } @@ -110,6 +120,7 @@ case class HalfWordExpression(expression: Expression, hiByte: Boolean) extends E override def replaceVariable(variable: String, actualParam: Expression): Expression = HalfWordExpression(expression.replaceVariable(variable, actualParam), hiByte).pos(position) override def containsVariable(variable: String): Boolean = expression.containsVariable(variable) + override def getPointies: Seq[String] = expression.getPointies override def isPure: Boolean = expression.isPure override def getAllIdentifiers: Set[String] = expression.getAllIdentifiers } @@ -170,6 +181,7 @@ case class VariableExpression(name: String) extends LhsExpression { override def replaceVariable(variable: String, actualParam: Expression): Expression = if (name == variable) actualParam else this override def containsVariable(variable: String): Boolean = name == variable + override def getPointies: Seq[String] = Seq.empty override def isPure: Boolean = true override def getAllIdentifiers: Set[String] = Set(name) } @@ -184,12 +196,14 @@ case class IndexedExpression(name: String, index: Expression) extends LhsExpress } } else IndexedExpression(name, index.replaceVariable(variable, actualParam)).pos(position) override def containsVariable(variable: String): Boolean = name == variable || index.containsVariable(variable) + override def getPointies: Seq[String] = Seq(name) override def isPure: Boolean = index.isPure override def getAllIdentifiers: Set[String] = index.getAllIdentifiers + name } sealed trait Statement extends Node { def getAllExpressions: List[Expression] + def getAllPointies: Seq[String] = getAllExpressions.flatMap(_.getPointies) } sealed trait DeclarationStatement extends Statement { @@ -357,6 +371,12 @@ case class Assignment(destination: LhsExpression, source: Expression) extends Ex case class MosAssemblyStatement(opcode: Opcode.Value, addrMode: AddrMode.Value, expression: Expression, elidability: Elidability.Value) extends ExecutableStatement { override def getAllExpressions: List[Expression] = List(expression) + override def getAllPointies: Seq[String] = addrMode match { + case AddrMode.IndexedY | AddrMode.IndexedX | AddrMode.LongIndexedY | AddrMode.IndexedZ | AddrMode.LongIndexedZ | + AddrMode.ZeroPage | AddrMode.ZeroPageX | AddrMode.ZeroPageY => + expression.getAllIdentifiers.toSeq.map(_.takeWhile(_ != '.')) + case _ => Seq.empty + } } case class Z80AssemblyStatement(opcode: ZOpcode.Value, registers: ZRegisters, offsetExpression: Option[Expression], expression: Expression, elidability: Elidability.Value) extends ExecutableStatement {