diff --git a/docs/lang/operators.md b/docs/lang/operators.md index 9656be62..c651e402 100644 --- a/docs/lang/operators.md +++ b/docs/lang/operators.md @@ -238,5 +238,7 @@ but not `word` → `byte` some enum → `word` +* `sizeof`: size of the argument in bytes; the argument can be an expression or a type, +and the result is a constant of either `byte` or `word` type, depending on situation diff --git a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala index 2d026114..a41926fa 100644 --- a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala @@ -2,7 +2,7 @@ package millfork.compiler import millfork.env._ import millfork.node._ -import millfork.error.ConsoleLogger +import millfork.error.{ConsoleLogger, Logger} import millfork.assembly.AbstractCode /** @@ -21,7 +21,7 @@ class AbstractExpressionCompiler[T <: AbstractCode] { } } - def lookupFunction(ctx: CompilationContext, f: FunctionCallExpression): MangledFunction = AbstractExpressionCompiler.lookupFunction(ctx, f) + def lookupFunction(ctx: CompilationContext, f: FunctionCallExpression): MangledFunction = AbstractExpressionCompiler.lookupFunction(ctx.env, ctx.log, f) def assertCompatible(exprType: Type, variableType: Type): Unit = { // TODO @@ -168,8 +168,12 @@ class AbstractExpressionCompiler[T <: AbstractCode] { } object AbstractExpressionCompiler { + @inline def getExpressionType(ctx: CompilationContext, expr: Expression): Type = { - val env = ctx.env + getExpressionType(ctx.env, ctx.log, expr) + } + + def getExpressionType(env: Environment, log: Logger, expr: Expression): Type = { val b = env.get[Type]("byte") val bool = env.get[Type]("bool$") val v = env.get[Type]("void") @@ -187,35 +191,39 @@ object AbstractExpressionCompiler { case VariableExpression(name) => env.get[TypedThing](name, expr.position).typ case HalfWordExpression(param, _) => - getExpressionType(ctx, param) + getExpressionType(env, log, param) b case IndexedExpression(name, _) => env.getPointy(name).elementType case SeparateBytesExpression(hi, lo) => - if (getExpressionType(ctx, hi).size > 1) ctx.log.error("Hi byte too large", hi.position) - if (getExpressionType(ctx, lo).size > 1) ctx.log.error("Lo byte too large", lo.position) + if (getExpressionType(env, log, hi).size > 1) log.error("Hi byte too large", hi.position) + if (getExpressionType(env, log, lo).size > 1) log.error("Lo byte too large", lo.position) w - case SumExpression(params, _) => params.map { case (_, e) => getExpressionType(ctx, e).size }.max match { + case SumExpression(params, _) => params.map { case (_, e) => getExpressionType(env, log, e).size }.max match { case 1 => b case 2 => w - case _ => ctx.log.error("Adding values bigger than words", expr.position); w + case _ => log.error("Adding values bigger than words", expr.position); w } case FunctionCallExpression("nonet", params) => w case FunctionCallExpression("not", params) => bool case FunctionCallExpression("hi", params) => b case FunctionCallExpression("lo", params) => b - case FunctionCallExpression("*", params) => b - case FunctionCallExpression("|" | "&" | "^", params) => params.map { e => getExpressionType(ctx, e).size }.max match { + case FunctionCallExpression("sizeof", params) => env.evalSizeof(params.head).requiredSize match { case 1 => b case 2 => w - case _ => ctx.log.error("Adding values bigger than words", expr.position); w + } + case FunctionCallExpression("*", params) => b + case FunctionCallExpression("|" | "&" | "^", params) => params.map { e => getExpressionType(env, log, e).size }.max match { + case 1 => b + case 2 => w + case _ => log.error("Adding values bigger than words", expr.position); w } case FunctionCallExpression("<<", List(a1, a2)) => - if (getExpressionType(ctx, a2).size > 1) ctx.log.error("Shift amount too large", a2.position) - getExpressionType(ctx, a1) + if (getExpressionType(env, log, a2).size > 1) log.error("Shift amount too large", a2.position) + getExpressionType(env, log, a1) case FunctionCallExpression(">>", List(a1, a2)) => - if (getExpressionType(ctx, a2).size > 1) ctx.log.error("Shift amount too large", a2.position) - getExpressionType(ctx, a1) + if (getExpressionType(env, log, a2).size > 1) log.error("Shift amount too large", a2.position) + getExpressionType(env, log, a1) case FunctionCallExpression("<<'", params) => b case FunctionCallExpression(">>'", params) => b case FunctionCallExpression(">>>>", params) => b @@ -242,11 +250,11 @@ object AbstractExpressionCompiler { case FunctionCallExpression("<<'=", params) => v case FunctionCallExpression(">>'=", params) => v case f@FunctionCallExpression(name, params) => - ctx.env.maybeGet[Type](name) match { + env.maybeGet[Type](name) match { case Some(typ) => typ case None => - lookupFunction(ctx, f).returnType + lookupFunction(env, log, f).returnType } } } @@ -274,9 +282,9 @@ object AbstractExpressionCompiler { } } - def lookupFunction(ctx: CompilationContext, f: FunctionCallExpression): MangledFunction = { - val paramsWithTypes = f.expressions.map(x => getExpressionType(ctx, x) -> x) - ctx.env.lookupFunction(f.functionName, paramsWithTypes).getOrElse( - ctx.log.fatal(s"Cannot find function `${f.functionName}` with given params `${paramsWithTypes.map(_._1).mkString("(", ",", ")")}`", f.position)) + def lookupFunction(env: Environment, log: Logger, f: FunctionCallExpression): MangledFunction = { + val paramsWithTypes = f.expressions.map(x => getExpressionType(env, log, x) -> x) + env.lookupFunction(f.functionName, paramsWithTypes).getOrElse( + log.fatal(s"Cannot find function `${f.functionName}` with given params `${paramsWithTypes.map(_._1).mkString("(", ",", ")")}`", f.position)) } } diff --git a/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala b/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala index 83cbb113..6f81628f 100644 --- a/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala +++ b/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala @@ -222,6 +222,6 @@ object AbstractStatementPreprocessor { "<<", "<<'", ">>", ">>'", ">>>>", "&", "&&", "||", "|", "^", "==", "!=", "<", ">", ">=", "<=", - "not", "hi", "lo", "nonet" + "not", "hi", "lo", "nonet", "sizeof" ) } \ No newline at end of file diff --git a/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala b/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala index c2592b40..e6e122d7 100644 --- a/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala @@ -742,6 +742,17 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] { } else compilation } } + case "sizeof" => + env.eval(expr) match { + case Some(c) => + exprTypeAndVariable match { + case Some((t, v)) => + compileConstant(ctx, c, v) + case _ => + Nil + } + case None => Nil + } case "nonet" => if (params.length != 1) { ctx.log.error("Invalid number of parameters", f.position) diff --git a/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala b/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala index eb7bd42c..b460f4e3 100644 --- a/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala @@ -501,6 +501,9 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { case ZExpressionTarget.DEHL => List(ZLine.ldImm8(ZRegister.H, 0), ZLine.ldImm16(ZRegister.DE, 0)) }) } + case "sizeof" => + ctx.log.fatal("Unreachable branch: 8080 sizeof") + Nil case "nonet" => if (params.length != 1) { ctx.log.error("Invalid number of parameters", f.position) diff --git a/src/main/scala/millfork/env/Environment.scala b/src/main/scala/millfork/env/Environment.scala index 9f8a4fd8..f0bf1e19 100644 --- a/src/main/scala/millfork/env/Environment.scala +++ b/src/main/scala/millfork/env/Environment.scala @@ -4,7 +4,7 @@ import millfork.assembly.BranchingOpcodeMapping import millfork.{env, _} import millfork.assembly.mos.Opcode import millfork.assembly.z80.{IfFlagClear, IfFlagSet, ZFlag} -import millfork.compiler.LabelGenerator +import millfork.compiler.{AbstractExpressionCompiler, LabelGenerator} import millfork.error.Logger import millfork.node._ import millfork.output._ @@ -266,17 +266,18 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa if (things.contains(name)) { val t: Thing = things(name) val clazz = implicitly[Manifest[T]].runtimeClass - if ((t ne null) && clazz.isInstance(t)) { - Some(t.asInstanceOf[T]) - } else { - t match { - case Alias(_, target, deprectated) => - if (deprectated) { - log.warn(s"Alias `$name` is deprecated, use `$target` instead") - } - root.maybeGet[T](target) - case _ => None - } + t match { + case Alias(_, target, deprectated) => + if (deprectated) { + log.warn(s"Alias `$name` is deprecated, use `$target` instead") + } + root.maybeGet[T](target) + case _ => + if ((t ne null) && clazz.isInstance(t)) { + Some(t.asInstanceOf[T]) + } else { + None + } } } else parent.flatMap { _.maybeGet[T](name) @@ -453,6 +454,29 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa } } + def evalSizeof(expr: Expression): Constant = { + val size: Int = expr match { + case VariableExpression(name) => + maybeGet[Thing](name) match { + case None => + log.error(s"`$name` is not defined") + 1 + case Some(thing) => thing match { + case t: Type => t.size + case v: Variable => v.typ.size + case a: InitializedArray => a.elementType.size * a.contents.length + case a: UninitializedArray => a.sizeInBytes + case x => + log.error("Invalid parameter for expr: " + name) + 1 + } + } + case _ => + AbstractExpressionCompiler.getExpressionType(this, log, expr).size + } + NumericConstant(size, Constant.minimumSize(size)) + } + def eval(e: Expression, vars: Map[String, Constant]): Option[Constant] = evalImpl(e, Some(vars)) def eval(e: Expression): Option[Constant] = evalImpl(e, None) @@ -493,6 +517,13 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa } yield hc.asl(8) + lc case FunctionCallExpression(name, params) => name match { + case "sizeof" => + if (params.size == 1) { + Some(evalSizeof(params.head)) + } else { + log.error("Invalid number of parameters for `sizeof`", e.position) + Some(Constant.One) + } case "hi" => if (params.size == 1) { eval(params.head).map(_.hiByte.quickSimplify) @@ -1282,6 +1313,8 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa nameCheck(l) case SumExpression(params, _) => nameCheck(params.map(_._2)) + case FunctionCallExpression("sizeof", List(ve@VariableExpression(e))) => + checkName[Thing]("Type, variable or constant", e, ve.position) case FunctionCallExpression(name, params) => if (name.exists(_.isLetter) && !Environment.predefinedFunctions(name)) { checkName[CallableThing]("Function or type", name, node.position) @@ -1291,5 +1324,5 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa } object Environment { - val predefinedFunctions = Set("not", "hi", "lo", "nonet") + val predefinedFunctions = Set("not", "hi", "lo", "nonet", "sizeof") } diff --git a/src/main/scala/millfork/env/Thing.scala b/src/main/scala/millfork/env/Thing.scala index 3a2aa47e..f0443bd2 100644 --- a/src/main/scala/millfork/env/Thing.scala +++ b/src/main/scala/millfork/env/Thing.scala @@ -199,7 +199,7 @@ trait MfArray extends ThingInMemory with IndexableThing { def elementType: VariableType } -case class UninitializedArray(name: String, sizeInBytes: Int, declaredBank: Option[String], indexType: VariableType, elementType: VariableType, override val alignment: MemoryAlignment) extends MfArray with UninitializedMemory { +case class UninitializedArray(name: String, /* TODO: what if larger elements? */ sizeInBytes: Int, declaredBank: Option[String], indexType: VariableType, elementType: VariableType, override val alignment: MemoryAlignment) extends MfArray with UninitializedMemory { override def toAddress: MemoryAddressConstant = MemoryAddressConstant(this) override def alloc: VariableAllocationMethod.Value = VariableAllocationMethod.Static diff --git a/src/test/scala/millfork/test/SizeofSuite.scala b/src/test/scala/millfork/test/SizeofSuite.scala new file mode 100644 index 00000000..30477018 --- /dev/null +++ b/src/test/scala/millfork/test/SizeofSuite.scala @@ -0,0 +1,35 @@ +package millfork.test +import millfork.Cpu +import millfork.test.emu.{EmuBenchmarkRun, EmuOptimizedCmosRun, EmuOptimizedRun, EmuUnoptimizedCrossPlatformRun} +import org.scalatest.{AppendedClues, FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class SizeofSuite extends FunSuite with Matchers with AppendedClues { + + test("Basic sizeof test") { + EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80)( + """ + | const byte sizeofbyte = sizeof(byte) + | array output [6] @$c000 + | void main () { + | byte a + | word b + | output[0] = sizeofbyte + | output[1] = sizeof(a) + | output[2] = sizeof(word) + | output[3] = sizeof(b) + | output[4] = sizeof(output[1]) + | output[5] = sizeof(long) + | } + """.stripMargin){m => + m.readByte(0xc000) should equal(1) + m.readByte(0xc001) should equal(1) + m.readByte(0xc002) should equal(2) + m.readByte(0xc003) should equal(2) + m.readByte(0xc004) should equal(1) + m.readByte(0xc005) should equal(4) + } + } +}