diff --git a/src/main/scala/millfork/env/Environment.scala b/src/main/scala/millfork/env/Environment.scala index f79d524a..daac42cb 100644 --- a/src/main/scala/millfork/env/Environment.scala +++ b/src/main/scala/millfork/env/Environment.scala @@ -11,6 +11,7 @@ import millfork.output._ import org.apache.commons.lang3.StringUtils import scala.collection.mutable +import scala.collection.mutable.ListBuffer /** @@ -749,6 +750,45 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa } } + def registerStruct(stmt: StructDefinitionStatement): Unit = { + stmt.fields.foreach{ f => + if (Environment.invalidFieldNames.contains(f._2)) { + log.error(s"Invalid field name: `${f._2}`", stmt.position) + } + } + addThing(StructType(stmt.name, stmt.fields), stmt.position) + } + + def getTypeSize(name: String, path: Set[String]): Int = { + if (path.contains(name)) return -1 + val t = get[Type](name) + t match { + case s: StructType => + if (s.mutableSize >= 0) s.mutableSize + else { + val newPath = path + name + var sum = 0 + for( (fieldType, _) <- s.fields) { + val fieldSize = getTypeSize(fieldType, newPath) + if (fieldSize < 0) return -1 + sum += fieldSize + } + s.mutableSize = sum + if (sum > 0xff) { + log.error(s"Struct `$name` is larger than 255 bytes") + } + val b = get[Type]("byte") + var offset = 0 + for( (fieldType, fieldName) <- s.fields) { + addThing(ConstantThing(s"$name.$fieldName.offset", NumericConstant(offset, 1), b), None) + offset += getTypeSize(fieldType, newPath) + } + sum + } + case _ => t.size + } + } + def collectPointies(stmts: Seq[Statement]): Set[String] = { val pointies: mutable.Set[String] = new mutable.HashSet() pointies ++= stmts.flatMap(_.getAllPointies) @@ -941,7 +981,9 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa registerAddressConstant(v, stmt.position, options) val addr = v.toAddress for((suffix, offset, t) <- getSubvariables(typ)) { - addThing(RelativeVariable(v.name + suffix, addr + offset, t, zeropage = zp, None, isVolatile = v.isVolatile), stmt.position) + val subv = RelativeVariable(v.name + suffix, addr + offset, t, zeropage = zp, None, isVolatile = v.isVolatile) + addThing(subv, stmt.position) + registerAddressConstant(subv, stmt.position, options) } case ByMosRegister(_) => () case ByZRegister(_) => () @@ -1218,7 +1260,9 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa addThing(ConstantThing(v.name + "`", addr, b), stmt.position) } for((suffix, offset, t) <- getSubvariables(typ)) { - addThing(RelativeVariable(prefix + name + suffix, addr + offset, t, zeropage = v.zeropage, declaredBank = stmt.bank, isVolatile = v.isVolatile), stmt.position) + val subv = RelativeVariable(prefix + name + suffix, addr + offset, t, zeropage = v.zeropage, declaredBank = stmt.bank, isVolatile = v.isVolatile) + addThing(subv, stmt.position) + registerAddressConstant(subv, stmt.position, options) } } } @@ -1231,7 +1275,11 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa return (".lo", 0, b) :: (".hi", 1, b) :: (".loword", 0, w) :: + (".loword.lo", 0, b) :: + (".loword.hi", 1, b) :: (".b2b3", 2, w) :: + (".b2b3.lo", 2, b) :: + (".b2b3.hi", 3, b) :: List.tabulate(typ.size) { i => (".b" + i, i, b) } } typ match { @@ -1241,20 +1289,46 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa (".hi", 1, b)) case 3 => List( (".loword", 0, w), + (".loword.lo", 0, b), + (".loword.hi", 1, b), (".hiword", 1, w), + (".hiword.lo", 1, b), + (".hiword.hi", 2, b), (".b0", 0, b), (".b1", 1, b), (".b2", 2, b)) case 4 => List( (".loword", 0, w), (".hiword", 2, w), + (".loword.lo", 0, b), + (".loword.hi", 1, b), + (".hiword.lo", 2, b), + (".hiword.hi", 3, b), (".b0", 0, b), (".b1", 1, b), (".b2", 2, b), (".b3", 3, b)) - case sz if sz > 4 => (".lo", 0, b) :: (".loword", 0, w) :: List.tabulate(sz){ i => (".b" + i, i, b) } + case sz if sz > 4 => + (".lo", 0, b) :: + (".loword", 0, w) :: + (".loword.lo", 0, b) :: + (".loword.hi", 1, b) :: + List.tabulate(sz){ i => (".b" + i, i, b) } case _ => Nil } + case s: StructType => + val builder = new ListBuffer[(String, Int, VariableType)] + var offset = 0 + for((typeName, fieldName) <- s.fields) { + val typ = get[VariableType](typeName) + val suffix = "." + fieldName + builder += ((suffix, offset, typ)) + builder ++= getSubvariables(typ).map { + case (innerSuffix, innerOffset, innerType) => (suffix + innerSuffix, offset + innerOffset, innerType) + } + offset += typ.size + } + builder.toList case _ => Nil } } @@ -1309,6 +1383,23 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa aliasesToAdd.foreach(a => things += a.name -> a) } + def fixStructSizes(): Unit = { + val allStructTypes = things.values.flatMap { + case StructType(name, _) => Some(name) + case _ => None + } + var iterations = allStructTypes.size + while (iterations >= 0) { + var ok = true + for (t <- allStructTypes) { + if (getTypeSize(t, Set()) < 0) ok = false + } + if (ok) return + iterations -= 1 + } + log.error("Cycles in struct definitions found") + } + def collectDeclarations(program: Program, options: CompilationOptions): Unit = { val b = get[VariableType]("byte") val v = get[Type]("void") @@ -1323,6 +1414,11 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa case e: EnumDefinitionStatement => registerEnum(e) case _ => } + program.declarations.foreach { + case s: StructDefinitionStatement => registerStruct(s) + case _ => + } + fixStructSizes() val pointies = collectPointies(program.declarations) pointiesUsed("") = pointies program.declarations.foreach { @@ -1457,4 +1553,5 @@ object Environment { "for", "if", "do", "while", "else", "return", "default", "to", "until", "paralleluntil", "parallelto", "downto", "inline", "noinline" ) ++ predefinedFunctions + val invalidFieldNames: Set[String] = Set("addr") } diff --git a/src/main/scala/millfork/env/Thing.scala b/src/main/scala/millfork/env/Thing.scala index fef90220..eaec0d0f 100644 --- a/src/main/scala/millfork/env/Thing.scala +++ b/src/main/scala/millfork/env/Thing.scala @@ -76,6 +76,12 @@ case class EnumType(name: String, count: Option[Int]) extends VariableType { override def isSigned: Boolean = false } +case class StructType(name: String, fields: List[(String, String)]) extends VariableType { + override def size: Int = mutableSize + var mutableSize: Int = -1 + override def isSigned: Boolean = false +} + sealed trait BooleanType extends Type { def size = 0 diff --git a/src/main/scala/millfork/node/Node.scala b/src/main/scala/millfork/node/Node.scala index 361965e9..5763a090 100644 --- a/src/main/scala/millfork/node/Node.scala +++ b/src/main/scala/millfork/node/Node.scala @@ -300,6 +300,10 @@ case class EnumDefinitionStatement(name: String, variants: List[(String, Option[ override def getAllExpressions: List[Expression] = variants.flatMap(_._2) } +case class StructDefinitionStatement(name: String, fields: List[(String, String)]) extends DeclarationStatement { + override def getAllExpressions: List[Expression] = Nil +} + case class ArrayDeclarationStatement(name: String, bank: Option[String], length: Option[Expression], diff --git a/src/main/scala/millfork/parser/MfParser.scala b/src/main/scala/millfork/parser/MfParser.scala index 282b9de4..f30167a7 100644 --- a/src/main/scala/millfork/parser/MfParser.scala +++ b/src/main/scala/millfork/parser/MfParser.scala @@ -472,9 +472,25 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri variants <- enumVariants ~/ Pass } yield Seq(EnumDefinitionStatement(name, variants).pos(p)) + val structField: P[(String, String)] = for { + typ <- identifier ~/ HWS + name <- identifier ~ HWS + } yield typ -> name + + val structFields: P[List[(String, String)]] = + ("{" ~/ AWS ~ structField.rep(sep = NoCut(EOLOrComma) ~ !"}" ~/ Pass) ~/ AWS ~/ "}" ~/ Pass).map(_.toList) + + val structDefinition: P[Seq[StructDefinitionStatement]] = for { + p <- position() + _ <- "struct" ~ !letterOrDigit ~/ SWS ~ position("struct name") + name <- identifier ~/ HWS + _ <- position("struct defintion block") + fields <- structFields ~/ Pass + } yield Seq(StructDefinitionStatement(name, fields).pos(p)) + val program: Parser[Program] = for { _ <- Start ~/ AWS ~/ Pass - definitions <- (importStatement | arrayDefinition | aliasDefinition | enumDefinition | functionDefinition | globalVariableDefinition).rep(sep = EOL) + definitions <- (importStatement | arrayDefinition | aliasDefinition | enumDefinition | structDefinition | functionDefinition | globalVariableDefinition).rep(sep = EOL) _ <- AWS ~ End } yield Program(definitions.flatten.toList) diff --git a/src/test/scala/millfork/test/StructSuite.scala b/src/test/scala/millfork/test/StructSuite.scala new file mode 100644 index 00000000..dd1c43ec --- /dev/null +++ b/src/test/scala/millfork/test/StructSuite.scala @@ -0,0 +1,71 @@ +package millfork.test + +import millfork.Cpu +import millfork.test.emu.EmuUnoptimizedCrossPlatformRun +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class StructSuite extends FunSuite with Matchers { + + test("Basic struct support") { + // TODO: 8080 has broken stack operations, fix and uncomment! + EmuUnoptimizedCrossPlatformRun(Cpu.StrictMos, Cpu.Z80/*, Cpu.Intel8080*/)(""" + | struct point { + | byte x + | byte y + | byte z + | byte colour + | } + | point output @$c000 + | void main () { + | stack point p + | p = f() + | output = p + | } + | point f() { + | point x + | x.x = 77 + | x.y = 88 + | x.z = 99 + | x.colour = 14 + | return x + | } + """.stripMargin) { m => + m.readByte(0xc000) should equal(77) + m.readByte(0xc001) should equal(88) + m.readByte(0xc002) should equal(99) + m.readByte(0xc003) should equal(14) + } + } + + test("Nested structs") { + EmuUnoptimizedCrossPlatformRun(Cpu.StrictMos, Cpu.Intel8080)(""" + | struct inner { word x, word y } + | struct s { + | word w + | byte b + | pointer p + | inner i + | } + | s output @$c000 + | void main () { + | output.b = 5 + | output.w.hi = output.b + | output.p = output.w.addr + | output.p[0] = 6 + | output.i.x.lo = 55 + | output.i.x.hi = s.p.offset + | output.i.y = 777 + | } + """.stripMargin) { m => + m.readWord(0xc000) should equal(0x506) + m.readByte(0xc002) should equal(5) + m.readWord(0xc003) should equal(0xc000) + m.readByte(0xc005) should equal(55) + m.readByte(0xc006) should equal(3) + m.readWord(0xc007) should equal(777) + } + } +}