diff --git a/src/main/scala/millfork/env/Environment.scala b/src/main/scala/millfork/env/Environment.scala index 0f9b9b28..97ee8bfd 100644 --- a/src/main/scala/millfork/env/Environment.scala +++ b/src/main/scala/millfork/env/Environment.scala @@ -1083,8 +1083,8 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa def collectPointies(stmts: Seq[Statement]): Set[String] = { val pointies: mutable.Set[String] = new mutable.HashSet() - pointies ++= stmts.flatMap(_.getAllPointies) - pointies ++ getAliases.filterKeys(pointies).values + pointies ++= stmts.flatMap(_.getAllPointies) + pointies ++= getAliases.filterKeys(pointies).values log.trace("Collected pointies: " + pointies) pointies.toSet } @@ -1133,7 +1133,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa } else { new Environment(Some(this), name + "$", cpuFamily, options) } - stmt.params.foreach(p => env.registerParameter(p, options)) + stmt.params.foreach(p => env.registerParameter(p, options, pointies)) def params: ParamSignature = if (stmt.assembly) { AssemblyParamSignature(stmt.params.map { pd => @@ -1423,14 +1423,14 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa } } - def registerParameter(stmt: ParameterDeclaration, options: CompilationOptions): Unit = { + def registerParameter(stmt: ParameterDeclaration, options: CompilationOptions, pointies: Set[String]): Unit = { val typ = get[Type](stmt.typ) val b = get[Type]("byte") val w = get[Type]("word") val p = get[Type]("pointer") stmt.assemblyParamPassingConvention match { case ByVariable(name) => - val zp = typ.isPointy // TODO + val zp = typ.isPointy || pointies(name) // TODO val v = UninitializedMemoryVariable(prefix + name, typ, if (zp) VariableAllocationMethod.Zeropage else VariableAllocationMethod.Auto, None, defaultVariableAlignment(options, 2), isVolatile = false) addThing(v, stmt.position) registerAddressConstant(v, stmt.position, options, Some(typ)) diff --git a/src/main/scala/millfork/node/Node.scala b/src/main/scala/millfork/node/Node.scala index 920bfc2b..55e58a28 100644 --- a/src/main/scala/millfork/node/Node.scala +++ b/src/main/scala/millfork/node/Node.scala @@ -612,6 +612,18 @@ case class FunctionDeclarationStatement(name: String, override def getAllExpressions: List[Expression] = address.toList ++ statements.getOrElse(Nil).flatMap(_.getAllExpressions) override def withChangedBank(bank: String): BankedDeclarationStatement = copy(bank = Some(bank)) + + override def getAllPointies: Seq[String] = statements match { + case None => Seq.empty + case Some(stmts) => + val locals = stmts.flatMap{ + case s:VariableDeclarationStatement => Some(s.name) + case s:ArrayDeclarationStatement => Some(s.name) + case _ => None + }.toSet + val pointies = stmts.flatMap(_.getAllPointies).toSet + (pointies -- locals).toSeq + } } sealed trait ExecutableStatement extends Statement diff --git a/src/test/scala/millfork/test/PointerSuite.scala b/src/test/scala/millfork/test/PointerSuite.scala index 28c1a23a..b7163bf7 100644 --- a/src/test/scala/millfork/test/PointerSuite.scala +++ b/src/test/scala/millfork/test/PointerSuite.scala @@ -1,7 +1,7 @@ package millfork.test import millfork.Cpu -import millfork.test.emu.{EmuCrossPlatformBenchmarkRun, EmuUnoptimizedCrossPlatformRun, ShouldNotCompile} +import millfork.test.emu.{EmuCrossPlatformBenchmarkRun, EmuUnoptimizedCrossPlatformRun, EmuUnoptimizedRun, ShouldNotCompile} import org.scalatest.{AppendedClues, FunSuite, Matchers} /** @@ -427,4 +427,30 @@ class PointerSuite extends FunSuite with Matchers with AppendedClues { m.readWord(0xc100) should equal(0x400) } } + + test("Pointers should remain at zero page if used as pointers in assembly") { + val m = EmuUnoptimizedRun( + """ + | word output @$c000 + | pointer p + | volatile pointer q1 + | volatile pointer q2 + | volatile pointer q3 + | volatile pointer q4 + | volatile pointer q5 + | volatile pointer q6 + | volatile pointer q7 + | volatile pointer q8 + | array arr [250] @$0 + | void main () { + | asm { + | lda (p),y + | } + | output = p.addr + | q1 = q2 + q3 + q4 + q5 + q6 + q7 + q8 + | } + |""".stripMargin) + m.readWord(0xc000) should be <(256) + + } }