diff --git a/src/main/scala/millfork/assembly/mos/opt/EmptyParameterStoreRemoval.scala b/src/main/scala/millfork/assembly/mos/opt/EmptyParameterStoreRemoval.scala index 72db811b..5bb3ed40 100644 --- a/src/main/scala/millfork/assembly/mos/opt/EmptyParameterStoreRemoval.scala +++ b/src/main/scala/millfork/assembly/mos/opt/EmptyParameterStoreRemoval.scala @@ -15,7 +15,7 @@ object EmptyParameterStoreRemoval extends AssemblyOptimization[AssemblyLine] { override def name = "Removing pointless stores to foreign variables" private val storeInstructions = Set(STA, STX, STY, SAX, STZ, STA_W, STX_W, STY_W, STZ_W) - private val storeAddrModes = Set(Absolute, ZeroPage, AbsoluteX, AbsoluteY, ZeroPageX, ZeroPageY) + private val storeAddrModes = Set(Absolute, ZeroPage, AbsoluteX, AbsoluteY, ZeroPageX, ZeroPageY, LongAbsolute, LongAbsoluteX) override def optimize(f: NormalFunction, code: List[AssemblyLine], optimizationContext: OptimizationContext): List[AssemblyLine] = { val usedFunctions = code.flatMap { diff --git a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala index c2b6261f..4fd43698 100644 --- a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala @@ -27,13 +27,10 @@ class AbstractExpressionCompiler[T <: AbstractCode] { // TODO } - def callingContext(ctx: CompilationContext, callee: String, v: MemoryVariable): CompilationContext = { + def callingContext(ctx: CompilationContext, callee: String, v: VariableInMemory): CompilationContext = { val result = new Environment(Some(ctx.env), "", ctx.options.platform.cpuFamily, ctx.jobContext) - val isPointy = ctx.env.isKnownPointy(callee, v.name.stripPrefix(callee + '$')) - result.registerVariable(VariableDeclarationStatement( - v.name, v.typ.name, - stack = false, global = false, constant = false, volatile = false, register = false, - initialValue = None, address = None, bank = v.declaredBank, alignment = None), ctx.options, isPointy = isPointy) + val localName = v.name.stripPrefix(callee + '$') + result.addVariable(ctx.options, localName, v, None) ctx.copy(env = result) } diff --git a/src/main/scala/millfork/compiler/mos/MosStatementCompiler.scala b/src/main/scala/millfork/compiler/mos/MosStatementCompiler.scala index 02826758..b1ec16e0 100644 --- a/src/main/scala/millfork/compiler/mos/MosStatementCompiler.scala +++ b/src/main/scala/millfork/compiler/mos/MosStatementCompiler.scala @@ -245,6 +245,7 @@ object MosStatementCompiler extends AbstractStatementCompiler[AssemblyLine] { case 2 => MosExpressionCompiler.compile(ctx, e, someRegisterAX, NoBranching) ++ stackPointerFixBeforeReturn(ctx, preserveA = true, preserveX = true) ++ returnInstructions case _ => + // TODO: is this case ever used? MosExpressionCompiler.compileAssignment(ctx, e, VariableExpression(ctx.function.name + "`return")) ++ stackPointerFixBeforeReturn(ctx) ++ returnInstructions } @@ -268,8 +269,12 @@ object MosStatementCompiler extends AbstractStatementCompiler[AssemblyLine] { returnInstructions } case _ => - MosExpressionCompiler.compileAssignment(ctx, e, VariableExpression(ctx.function.name + ".return")) ++ - stackPointerFixBeforeReturn(ctx) ++ List(AssemblyLine.discardAF(), AssemblyLine.discardXF(), AssemblyLine.discardYF()) ++ returnInstructions + if (ctx.function.hasElidedReturnVariable) { + stackPointerFixBeforeReturn(ctx) ++ List(AssemblyLine.discardAF(), AssemblyLine.discardXF(), AssemblyLine.discardYF()) ++ returnInstructions + } else { + MosExpressionCompiler.compileAssignment(ctx, e, VariableExpression(ctx.function.name + ".return")) ++ + stackPointerFixBeforeReturn(ctx) ++ List(AssemblyLine.discardAF(), AssemblyLine.discardXF(), AssemblyLine.discardYF()) ++ returnInstructions + } } }) -> Nil case s: IfStatement => diff --git a/src/main/scala/millfork/compiler/z80/Z80StatementCompiler.scala b/src/main/scala/millfork/compiler/z80/Z80StatementCompiler.scala index 141c6fb4..fec78bc3 100644 --- a/src/main/scala/millfork/compiler/z80/Z80StatementCompiler.scala +++ b/src/main/scala/millfork/compiler/z80/Z80StatementCompiler.scala @@ -72,8 +72,13 @@ object Z80StatementCompiler extends AbstractStatementCompiler[ZLine] { Z80ExpressionCompiler.compileToDEHL(ctx, e) ++ fixStackOnReturn(ctx) ++ List(ZLine.implied(DISCARD_F), ZLine.implied(DISCARD_A), ZLine.implied(DISCARD_BC), ZLine.implied(RET)) case _ => - Z80ExpressionCompiler.storeLarge(ctx, VariableExpression(ctx.function.name + ".return"), e) ++ fixStackOnReturn(ctx) ++ - List(ZLine.implied(DISCARD_F), ZLine.implied(DISCARD_A), ZLine.implied(DISCARD_HL), ZLine.implied(DISCARD_BC), ZLine.implied(DISCARD_DE), ZLine.implied(RET)) + if (ctx.function.hasElidedReturnVariable) { + fixStackOnReturn(ctx) ++ + List(ZLine.implied(DISCARD_F), ZLine.implied(DISCARD_A), ZLine.implied(DISCARD_HL), ZLine.implied(DISCARD_BC), ZLine.implied(DISCARD_DE), ZLine.implied(RET)) + } else { + Z80ExpressionCompiler.storeLarge(ctx, VariableExpression(ctx.function.name + ".return"), e) ++ fixStackOnReturn(ctx) ++ + List(ZLine.implied(DISCARD_F), ZLine.implied(DISCARD_A), ZLine.implied(DISCARD_HL), ZLine.implied(DISCARD_BC), ZLine.implied(DISCARD_DE), ZLine.implied(RET)) + } } }) -> Nil diff --git a/src/main/scala/millfork/env/Environment.scala b/src/main/scala/millfork/env/Environment.scala index 755c5621..e80dbb43 100644 --- a/src/main/scala/millfork/env/Environment.scala +++ b/src/main/scala/millfork/env/Environment.scala @@ -206,16 +206,63 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa val pointiesUsed: mutable.Map[String, Set[String]] = mutable.Map() val removedThings: mutable.Set[String] = mutable.Set() - def isKnownPointy(callee: String, variable: String): Boolean = { - root.pointiesUsed.get(callee).exists(_.contains(variable)) - } - private def addThing(t: Thing, position: Option[Position]): Unit = { if (assertNotDefined(t.name, position)) { things(t.name.stripPrefix(prefix)) = t } } + def getReturnedVariables(statements: Seq[Statement]): Set[String] = { + statements.flatMap { + case ReturnStatement(Some(VariableExpression(v))) => Set(v) + case ReturnStatement(_) => Set("/none/", "|none|") + case x: CompoundStatement => getReturnedVariables(x.getChildStatements) + case _ => Set.empty[String] + }.toSet + } + + def coerceLocalVariableIntoGlobalVariable(localVarToRelativize: String, concreteGlobalTarget: String): Unit = { + log.trace(s"Coercing $localVarToRelativize to $concreteGlobalTarget") + def removeVariableImpl2(e: Environment, str: String): Unit = { + e.things -= str + e.things -= str + ".addr" + e.things -= str + ".addr.lo" + e.things -= str + ".addr.hi" + e.things -= str + ".pointer" + e.things -= str + ".pointer.lo" + e.things -= str + ".pointer.hi" + e.things -= str + ".rawaddr" + e.things -= str + ".rawaddr.lo" + e.things -= str + ".rawaddr.hi" + e.things -= str.stripPrefix(prefix) + e.things -= str.stripPrefix(prefix) + ".addr" + e.things -= str.stripPrefix(prefix) + ".addr.lo" + e.things -= str.stripPrefix(prefix) + ".addr.hi" + e.things -= str.stripPrefix(prefix) + ".pointer" + e.things -= str.stripPrefix(prefix) + ".pointer.lo" + e.things -= str.stripPrefix(prefix) + ".pointer.hi" + e.things -= str.stripPrefix(prefix) + ".rawaddr" + e.things -= str.stripPrefix(prefix) + ".rawaddr.lo" + e.things -= str.stripPrefix(prefix) + ".rawaddr.hi" + parent.foreach(x => removeVariableImpl2(x,str)) + } + removeVariableImpl2(this, prefix + localVarToRelativize) + val namePrefix = concreteGlobalTarget + '.' + root.things.filter { entry => + entry._1 == concreteGlobalTarget || entry._1.startsWith(namePrefix) + }.foreach { entry => + val name = entry._1 + val thing = entry._2 + val newName = if (name == concreteGlobalTarget) localVarToRelativize else localVarToRelativize + '.' + name.stripPrefix(namePrefix) + val newThing = thing match { + case t: VariableInMemory => RelativeVariable(prefix + newName, t.toAddress, t.typ, t.zeropage, t.declaredBank, t.isVolatile) + case t: ConstantThing => t.copy(name = prefix + newName) + case t => println(t); ??? + } + addThing(newThing, None) + } + } + def removeVariable(str: String): Unit = { log.trace("Removing variable: " + str) removeVariableImpl(str) @@ -883,7 +930,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa val env = new Environment(Some(this), name + "$", cpuFamily, jobContext) stmt.params.foreach(p => env.registerParameter(p, options)) - val params = if (stmt.assembly) { + def params: ParamSignature = if (stmt.assembly) { AssemblyParamSignature(stmt.params.map { pd => val typ = env.get[Type](pd.typ) @@ -902,10 +949,12 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa }) } else { NormalParamSignature(stmt.params.map { pd => - env.get[MemoryVariable](pd.assemblyParamPassingConvention.asInstanceOf[ByVariable].name) + env.get[VariableInMemory](pd.assemblyParamPassingConvention.asInstanceOf[ByVariable].name) }) } - if (resultType.size > Cpu.getMaxSizeReturnableViaRegisters(options.platform.cpu, options)) { + var hasElidedReturnVariable = false + val hasReturnVariable = resultType.size > Cpu.getMaxSizeReturnableViaRegisters(options.platform.cpu, options) + if (hasReturnVariable) { registerVariable(VariableDeclarationStatement(stmt.name + ".return", stmt.resultType, None, global = true, stack = false, constant = false, volatile = false, register = false, None, None, None), options, isPointy = false) } stmt.statements match { @@ -937,6 +986,19 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa case e: ExecutableStatement => Some(e) case _ => None } + if (hasReturnVariable) { + val set = getReturnedVariables(executableStatements) + if (set.size == 1) { + env.maybeGet[Variable](set.head) match { + case Some(v: MemoryVariable) => + if (!v.isVolatile && v.typ == resultType && v.alloc == VariableAllocationMethod.Auto) { + env.coerceLocalVariableIntoGlobalVariable(set.head, stmt.name + ".return") + hasElidedReturnVariable = true + } + case _ => + } + } + } val paramForAutomaticReturn: List[Option[Expression]] = if (stmt.isMacro || stmt.assembly) { Nil } else if (statements.isEmpty) { @@ -982,6 +1044,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa stackVariablesSize, stmt.address.map(a => this.eval(a).getOrElse(errorConstant(s"Address of `${stmt.name}` is not a constant"))), executableStatements ++ paramForAutomaticReturn.map(param => ReturnStatement(param).pos(executableStatements.lastOption.fold(stmt.position)(_.position))), + hasElidedReturnVariable = hasElidedReturnVariable, interrupt = stmt.interrupt, kernalInterrupt = stmt.kernalInterrupt, reentrant = stmt.reentrant, @@ -1290,10 +1353,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa if (stmt.register && stmt.address.isDefined) log.error(s"`$name` cannot by simultaneously at an address and in a register", position) if (stmt.stack) { val v = StackVariable(prefix + name, typ, this.baseStackOffset) - addThing(v, stmt.position) - for((suffix, offset, t) <- getSubvariables(typ)) { - addThing(StackVariable(prefix + name + suffix, t, baseStackOffset + offset), stmt.position) - } + addVariable(options, name, v, stmt.position) baseStackOffset += typ.size } else { val (v, addr) = stmt.address.fold[(VariableInMemory, Constant)]({ @@ -1326,19 +1386,37 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa registerAddressConstant(v, stmt.position, options, Some(typ)) (v, addr) }) - addThing(v, stmt.position) - if (!v.isInstanceOf[MemoryVariable]) { - addThing(ConstantThing(v.name + "`", addr, b), stmt.position) - } - for((suffix, offset, t) <- getSubvariables(typ)) { - val subv = RelativeVariable(prefix + name + suffix, addr + offset, t, zeropage = v.zeropage, declaredBank = stmt.bank, isVolatile = v.isVolatile) - addThing(subv, stmt.position) - registerAddressConstant(subv, stmt.position, options, Some(t)) - } + addVariable(options, name, v, stmt.position) } } } + def addVariable(options: CompilationOptions, localName: String, variable: Variable, position: Option[Position]): Unit = { + variable match { + case v: StackVariable => + addThing(v, position) + for ((suffix, offset, t) <- getSubvariables(v.typ)) { + addThing(StackVariable(prefix + localName + suffix, t, baseStackOffset + offset), position) + } + case v: MemoryVariable => + addThing(v, position) + for ((suffix, offset, t) <- getSubvariables(v.typ)) { + val subv = RelativeVariable(prefix + localName + suffix, v.toAddress + offset, t, zeropage = v.zeropage, declaredBank = v.declaredBank, isVolatile = v.isVolatile) + addThing(subv, position) + registerAddressConstant(subv, position, options, Some(t)) + } + case v: VariableInMemory => + addThing(v, position) + addThing(ConstantThing(v.name + "`", v.toAddress, get[Type]("word")), position) + for ((suffix, offset, t) <- getSubvariables(v.typ)) { + val subv = RelativeVariable(prefix + localName + suffix, v.toAddress + offset, t, zeropage = v.zeropage, declaredBank = v.declaredBank, isVolatile = v.isVolatile) + addThing(subv, position) + registerAddressConstant(subv, position, options, Some(t)) + } + case _ => ??? + } + } + def getSubvariables(typ: Type): List[(String, Int, VariableType)] = { val b = get[VariableType]("byte") val w = get[VariableType]("word") @@ -1650,5 +1728,5 @@ object Environment { "for", "if", "do", "while", "else", "return", "default", "to", "until", "paralleluntil", "parallelto", "downto", "inline", "noinline" ) ++ predefinedFunctions - val invalidFieldNames: Set[String] = Set("addr", "rawaddr", "pointer") + val invalidFieldNames: Set[String] = Set("addr", "rawaddr", "pointer", "return") } diff --git a/src/main/scala/millfork/env/Thing.scala b/src/main/scala/millfork/env/Thing.scala index 662c68f3..dff3b221 100644 --- a/src/main/scala/millfork/env/Thing.scala +++ b/src/main/scala/millfork/env/Thing.scala @@ -350,6 +350,7 @@ case class NormalFunction(name: String, stackVariablesSize: Int, address: Option[Constant], code: List[ExecutableStatement], + hasElidedReturnVariable: Boolean, interrupt: Boolean, kernalInterrupt: Boolean, reentrant: Boolean, @@ -373,7 +374,7 @@ trait ParamSignature { def length: Int } -case class NormalParamSignature(params: List[MemoryVariable]) extends ParamSignature { +case class NormalParamSignature(params: List[VariableInMemory]) extends ParamSignature { override def length: Int = params.length override def types: List[Type] = params.map(_.typ) diff --git a/src/test/scala/millfork/test/NodeOptimizationSuite.scala b/src/test/scala/millfork/test/NodeOptimizationSuite.scala index 13f196a7..e20f3991 100644 --- a/src/test/scala/millfork/test/NodeOptimizationSuite.scala +++ b/src/test/scala/millfork/test/NodeOptimizationSuite.scala @@ -1,6 +1,7 @@ package millfork.test -import millfork.test.emu.EmuNodeOptimizedRun +import millfork.Cpu +import millfork.test.emu.{EmuNodeOptimizedRun, EmuUnoptimizedCrossPlatformRun} import org.scalatest.{FunSuite, Matchers} /** @@ -29,4 +30,27 @@ class NodeOptimizationSuite extends FunSuite with Matchers { | } """.stripMargin) } + + test("Returning one variable") { + EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80)( + """ + | int64 output @$c000 + | void main () { + | int64 tmp + | tmp = f() + | output += tmp + | tmp = g($2000000) + | output += tmp + | } + | noinline int64 f() { + | int64 a + | a = 0 + | a.b3 = 1 + | return a + | } + | noinline int64 g(int64 p) = p + """.stripMargin) { m => + m.readLong(0xc000) should equal (0x3000000) + } + } }