From e394fe15c3689d7ed98c8b80b2d10d1da79ff0c9 Mon Sep 17 00:00:00 2001 From: Karol Stasiak Date: Tue, 25 Jun 2019 00:45:49 +0200 Subject: [PATCH] Add struct literals --- docs/lang/types.md | 8 +++ src/main/scala/millfork/env/Constant.scala | 28 +++++++++ src/main/scala/millfork/env/Environment.scala | 37 ++++++++++- src/main/scala/millfork/env/Thing.scala | 2 + .../millfork/output/AbstractAssembler.scala | 7 +++ .../scala/millfork/test/StructSuite.scala | 62 +++++++++++++++++++ 6 files changed, 141 insertions(+), 3 deletions(-) diff --git a/docs/lang/types.md b/docs/lang/types.md index 72969a67..3bfcd6b7 100644 --- a/docs/lang/types.md +++ b/docs/lang/types.md @@ -147,6 +147,12 @@ Offsets are available as `structname.fieldname.offset`: // alternatively: ptr = p.y.addr +You can create constant expressions of struct types using so-called struct constructors, e.g.: + + point(5,6) + +All arguments to the constructor must be constant. + ## Unions union { } @@ -163,3 +169,5 @@ start at the same point in memory and therefore overlap each other. if u.w == 0 { ok() } Offset constants are also available, but they're obviously all zero. + +Unions currently do not have an equivalent of struct constructors. This may be improved on in the future. diff --git a/src/main/scala/millfork/env/Constant.scala b/src/main/scala/millfork/env/Constant.scala index e34ba6bf..89252bf8 100644 --- a/src/main/scala/millfork/env/Constant.scala +++ b/src/main/scala/millfork/env/Constant.scala @@ -124,6 +124,34 @@ case class AssertByte(c: Constant) extends Constant { override def toIntelString: String = c.toIntelString } +case class StructureConstant(typ: StructType, fields: List[Constant]) extends Constant { + override def toIntelString: String = typ.name + fields.map(_.toIntelString).mkString("(",",",")") + + override def toString: String = typ.name + fields.map(_.toString).mkString("(",",",")") + + override def requiredSize: Int = typ.size + + override def isRelatedTo(v: Thing): Boolean = fields.exists(_.isRelatedTo(v)) + + override def refersTo(name: String): Boolean = typ.name == name || fields.exists(_.refersTo(name)) + + override def loByte: Constant = subbyte(0) + + override def hiByte: Constant = subbyte(1) + + override def subbyte(index: Int): Constant = { + var offset = 0 + for ((fv, (ft, _)) <- fields.zip(typ.mutableFieldsWithTypes)) { + val fs = ft.size + if (index < offset + fs) { + return fv.subbyte(index - offset) + } + offset += fs + } + Constant.Zero + } +} + case class UnexpandedConstant(name: String, requiredSize: Int) extends Constant { override def isRelatedTo(v: Thing): Boolean = false diff --git a/src/main/scala/millfork/env/Environment.scala b/src/main/scala/millfork/env/Environment.scala index 7bdd075b..3581c2de 100644 --- a/src/main/scala/millfork/env/Environment.scala +++ b/src/main/scala/millfork/env/Environment.scala @@ -552,6 +552,12 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa case Some(ff: MangledFunction) => if (ff.returnType.isSigned) Some(e) -> Constant.Zero else variable -> constant + case Some(t: StructType) => + // dunno what to do + variable -> constant + case Some(t: UnionType) => + // dunno what to do + None -> Constant.Zero case Some(t: Type) => if (t.isSigned) Some(e) -> Constant.Zero else variable -> constant @@ -745,10 +751,19 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa case "||" | "|" => constantOperation(MathOperator.Or, params) case _ => - if (params.size == 1) { - return maybeGet[Type](name).flatMap(_ => eval(params.head)) + maybeGet[Type](name) match { + case Some(t: StructType) => + if (params.size == t.fields.size) { + sequence(params.map(eval)).map(fields => StructureConstant(t, fields)) + } else None + case Some(_: UnionType) => + None + case Some(_) => + if (params.size == 1) { + eval(params.head) + } else None + case _ => None } - None } } }.map(_.quickSimplify) @@ -1699,6 +1714,21 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa log.error("Cycles in struct definitions found") } + def fixStructFields(): Unit = { + things.values.foreach { + case st@StructType(_, fields) => + st.mutableFieldsWithTypes = fields.map { + case (tn, name) => get[Type](tn) -> name + } + case ut@UnionType(_, fields) => + ut.mutableFieldsWithTypes = fields.map { + case (tn, name) => get[Type](tn) -> name + } + case _ => () + } + + } + def collectDeclarations(program: Program, options: CompilationOptions): Unit = { val b = get[VariableType]("byte") val v = get[Type]("void") @@ -1719,6 +1749,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa case _ => } fixStructSizes() + fixStructFields() val pointies = collectPointies(program.declarations) pointiesUsed("") = pointies program.declarations.foreach { diff --git a/src/main/scala/millfork/env/Thing.scala b/src/main/scala/millfork/env/Thing.scala index 7c7b2686..39b86951 100644 --- a/src/main/scala/millfork/env/Thing.scala +++ b/src/main/scala/millfork/env/Thing.scala @@ -107,12 +107,14 @@ sealed trait CompoundVariableType extends VariableType case class StructType(name: String, fields: List[(String, String)]) extends CompoundVariableType { override def size: Int = mutableSize var mutableSize: Int = -1 + var mutableFieldsWithTypes: List[(Type, String)] = Nil override def isSigned: Boolean = false } case class UnionType(name: String, fields: List[(String, String)]) extends CompoundVariableType { override def size: Int = mutableSize var mutableSize: Int = -1 + var mutableFieldsWithTypes: List[(Type, String)] = Nil override def isSigned: Boolean = false } diff --git a/src/main/scala/millfork/output/AbstractAssembler.scala b/src/main/scala/millfork/output/AbstractAssembler.scala index e9ca5961..6d7f5b01 100644 --- a/src/main/scala/millfork/output/AbstractAssembler.scala +++ b/src/main/scala/millfork/output/AbstractAssembler.scala @@ -124,6 +124,13 @@ abstract class AbstractAssembler[T <: AbstractCode](private val program: Program if (labelMap.contains(name)) labelMap(name)._2 else ??? case SubbyteConstant(cc, i) => deepConstResolve(cc).>>>(i * 8).&(0xff) + case s: StructureConstant => + s.typ.size match { + case 0 => 0 + case 1 => deepConstResolve(s.subbyte(0)) + case 2 => deepConstResolve(s.subword(0)) + case _ => ??? + } case CompoundConstant(operator, lc, rc) => val l = deepConstResolve(lc) val r = deepConstResolve(rc) diff --git a/src/test/scala/millfork/test/StructSuite.scala b/src/test/scala/millfork/test/StructSuite.scala index 9b434c4f..cb3b4b88 100644 --- a/src/test/scala/millfork/test/StructSuite.scala +++ b/src/test/scala/millfork/test/StructSuite.scala @@ -107,4 +107,66 @@ class StructSuite extends FunSuite with Matchers { m.readByte(0xc400) should equal(1) } } + + test("Struct literals") { + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80, Cpu.Intel8086)(""" + | struct point { byte x, byte y } + | const point origin = point(1,2) + | noinline point move_right(point p) { + | p.x += 1 + | return p + | } + | byte output @$c000 + | void main () { + | point p + | p = move_right(origin) + | p = move_right(point(1,2)) + | output = p.x + | } + """.stripMargin) { m => + m.readByte(0xc000) should equal(2) + } + } + + test("Struct literals 2") { + val code = """ + | struct point { word x, word y } + | const point origin = point(6, 8) + | noinline point move_right(point p) { + | p.x += 1 + | return p + | } + | noinline point move_up(point p) { + | p.y += 1 + | return p + | } + | word outputX @$c000 + | word outputY @$c002 + | void main () { + | point p + | p = point(0,0) + | p = move_up(point(0,0)) + | p = origin + | p = move_up(p) // ↑ + | p = move_right(p) // → + | p = move_right(p) // → + | p = move_up(p) // ↑ + | p = move_right(p) // → + | p = move_right(p) // → + | p = move_up(p) // ↑ + | p = move_up(p) // ↑ + | p = move_up(p) // ↑ + | p = move_right(p) // → + | p = move_up(p) // ↑ + | p = move_up(p) // ↑ + | p = move_up(p) // ↑ + | outputX = p.x + | outputY = p.y + | } + """.stripMargin + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80, Cpu.Intel8086)(code){ m => + m.readWord(0xc000) should equal(code.count(_ == '→') + 6) + m.readWord(0xc002) should equal(code.count(_ == '↑') + 8) + } + } }