1
0
mirror of https://github.com/KarolS/millfork.git synced 2024-08-12 11:29:20 +00:00

Add sizeof operator

This commit is contained in:
Karol Stasiak 2018-12-16 15:43:17 +01:00
parent badd7ef1d8
commit cd8697552c
8 changed files with 128 additions and 36 deletions

View File

@ -238,5 +238,7 @@ but not
`word``byte`
some enum → `word`
* `sizeof`: size of the argument in bytes; the argument can be an expression or a type,
and the result is a constant of either `byte` or `word` type, depending on situation

View File

@ -2,7 +2,7 @@ package millfork.compiler
import millfork.env._
import millfork.node._
import millfork.error.ConsoleLogger
import millfork.error.{ConsoleLogger, Logger}
import millfork.assembly.AbstractCode
/**
@ -21,7 +21,7 @@ class AbstractExpressionCompiler[T <: AbstractCode] {
}
}
def lookupFunction(ctx: CompilationContext, f: FunctionCallExpression): MangledFunction = AbstractExpressionCompiler.lookupFunction(ctx, f)
def lookupFunction(ctx: CompilationContext, f: FunctionCallExpression): MangledFunction = AbstractExpressionCompiler.lookupFunction(ctx.env, ctx.log, f)
def assertCompatible(exprType: Type, variableType: Type): Unit = {
// TODO
@ -168,8 +168,12 @@ class AbstractExpressionCompiler[T <: AbstractCode] {
}
object AbstractExpressionCompiler {
@inline
def getExpressionType(ctx: CompilationContext, expr: Expression): Type = {
val env = ctx.env
getExpressionType(ctx.env, ctx.log, expr)
}
def getExpressionType(env: Environment, log: Logger, expr: Expression): Type = {
val b = env.get[Type]("byte")
val bool = env.get[Type]("bool$")
val v = env.get[Type]("void")
@ -187,35 +191,39 @@ object AbstractExpressionCompiler {
case VariableExpression(name) =>
env.get[TypedThing](name, expr.position).typ
case HalfWordExpression(param, _) =>
getExpressionType(ctx, param)
getExpressionType(env, log, param)
b
case IndexedExpression(name, _) =>
env.getPointy(name).elementType
case SeparateBytesExpression(hi, lo) =>
if (getExpressionType(ctx, hi).size > 1) ctx.log.error("Hi byte too large", hi.position)
if (getExpressionType(ctx, lo).size > 1) ctx.log.error("Lo byte too large", lo.position)
if (getExpressionType(env, log, hi).size > 1) log.error("Hi byte too large", hi.position)
if (getExpressionType(env, log, lo).size > 1) log.error("Lo byte too large", lo.position)
w
case SumExpression(params, _) => params.map { case (_, e) => getExpressionType(ctx, e).size }.max match {
case SumExpression(params, _) => params.map { case (_, e) => getExpressionType(env, log, e).size }.max match {
case 1 => b
case 2 => w
case _ => ctx.log.error("Adding values bigger than words", expr.position); w
case _ => log.error("Adding values bigger than words", expr.position); w
}
case FunctionCallExpression("nonet", params) => w
case FunctionCallExpression("not", params) => bool
case FunctionCallExpression("hi", params) => b
case FunctionCallExpression("lo", params) => b
case FunctionCallExpression("*", params) => b
case FunctionCallExpression("|" | "&" | "^", params) => params.map { e => getExpressionType(ctx, e).size }.max match {
case FunctionCallExpression("sizeof", params) => env.evalSizeof(params.head).requiredSize match {
case 1 => b
case 2 => w
case _ => ctx.log.error("Adding values bigger than words", expr.position); w
}
case FunctionCallExpression("*", params) => b
case FunctionCallExpression("|" | "&" | "^", params) => params.map { e => getExpressionType(env, log, e).size }.max match {
case 1 => b
case 2 => w
case _ => log.error("Adding values bigger than words", expr.position); w
}
case FunctionCallExpression("<<", List(a1, a2)) =>
if (getExpressionType(ctx, a2).size > 1) ctx.log.error("Shift amount too large", a2.position)
getExpressionType(ctx, a1)
if (getExpressionType(env, log, a2).size > 1) log.error("Shift amount too large", a2.position)
getExpressionType(env, log, a1)
case FunctionCallExpression(">>", List(a1, a2)) =>
if (getExpressionType(ctx, a2).size > 1) ctx.log.error("Shift amount too large", a2.position)
getExpressionType(ctx, a1)
if (getExpressionType(env, log, a2).size > 1) log.error("Shift amount too large", a2.position)
getExpressionType(env, log, a1)
case FunctionCallExpression("<<'", params) => b
case FunctionCallExpression(">>'", params) => b
case FunctionCallExpression(">>>>", params) => b
@ -242,11 +250,11 @@ object AbstractExpressionCompiler {
case FunctionCallExpression("<<'=", params) => v
case FunctionCallExpression(">>'=", params) => v
case f@FunctionCallExpression(name, params) =>
ctx.env.maybeGet[Type](name) match {
env.maybeGet[Type](name) match {
case Some(typ) =>
typ
case None =>
lookupFunction(ctx, f).returnType
lookupFunction(env, log, f).returnType
}
}
}
@ -274,9 +282,9 @@ object AbstractExpressionCompiler {
}
}
def lookupFunction(ctx: CompilationContext, f: FunctionCallExpression): MangledFunction = {
val paramsWithTypes = f.expressions.map(x => getExpressionType(ctx, x) -> x)
ctx.env.lookupFunction(f.functionName, paramsWithTypes).getOrElse(
ctx.log.fatal(s"Cannot find function `${f.functionName}` with given params `${paramsWithTypes.map(_._1).mkString("(", ",", ")")}`", f.position))
def lookupFunction(env: Environment, log: Logger, f: FunctionCallExpression): MangledFunction = {
val paramsWithTypes = f.expressions.map(x => getExpressionType(env, log, x) -> x)
env.lookupFunction(f.functionName, paramsWithTypes).getOrElse(
log.fatal(s"Cannot find function `${f.functionName}` with given params `${paramsWithTypes.map(_._1).mkString("(", ",", ")")}`", f.position))
}
}

View File

@ -222,6 +222,6 @@ object AbstractStatementPreprocessor {
"<<", "<<'", ">>", ">>'", ">>>>",
"&", "&&", "||", "|", "^",
"==", "!=", "<", ">", ">=", "<=",
"not", "hi", "lo", "nonet"
"not", "hi", "lo", "nonet", "sizeof"
)
}

View File

@ -742,6 +742,17 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
} else compilation
}
}
case "sizeof" =>
env.eval(expr) match {
case Some(c) =>
exprTypeAndVariable match {
case Some((t, v)) =>
compileConstant(ctx, c, v)
case _ =>
Nil
}
case None => Nil
}
case "nonet" =>
if (params.length != 1) {
ctx.log.error("Invalid number of parameters", f.position)

View File

@ -501,6 +501,9 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] {
case ZExpressionTarget.DEHL => List(ZLine.ldImm8(ZRegister.H, 0), ZLine.ldImm16(ZRegister.DE, 0))
})
}
case "sizeof" =>
ctx.log.fatal("Unreachable branch: 8080 sizeof")
Nil
case "nonet" =>
if (params.length != 1) {
ctx.log.error("Invalid number of parameters", f.position)

View File

@ -4,7 +4,7 @@ import millfork.assembly.BranchingOpcodeMapping
import millfork.{env, _}
import millfork.assembly.mos.Opcode
import millfork.assembly.z80.{IfFlagClear, IfFlagSet, ZFlag}
import millfork.compiler.LabelGenerator
import millfork.compiler.{AbstractExpressionCompiler, LabelGenerator}
import millfork.error.Logger
import millfork.node._
import millfork.output._
@ -266,17 +266,18 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
if (things.contains(name)) {
val t: Thing = things(name)
val clazz = implicitly[Manifest[T]].runtimeClass
if ((t ne null) && clazz.isInstance(t)) {
Some(t.asInstanceOf[T])
} else {
t match {
case Alias(_, target, deprectated) =>
if (deprectated) {
log.warn(s"Alias `$name` is deprecated, use `$target` instead")
}
root.maybeGet[T](target)
case _ => None
}
t match {
case Alias(_, target, deprectated) =>
if (deprectated) {
log.warn(s"Alias `$name` is deprecated, use `$target` instead")
}
root.maybeGet[T](target)
case _ =>
if ((t ne null) && clazz.isInstance(t)) {
Some(t.asInstanceOf[T])
} else {
None
}
}
} else parent.flatMap {
_.maybeGet[T](name)
@ -453,6 +454,29 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
}
}
def evalSizeof(expr: Expression): Constant = {
val size: Int = expr match {
case VariableExpression(name) =>
maybeGet[Thing](name) match {
case None =>
log.error(s"`$name` is not defined")
1
case Some(thing) => thing match {
case t: Type => t.size
case v: Variable => v.typ.size
case a: InitializedArray => a.elementType.size * a.contents.length
case a: UninitializedArray => a.sizeInBytes
case x =>
log.error("Invalid parameter for expr: " + name)
1
}
}
case _ =>
AbstractExpressionCompiler.getExpressionType(this, log, expr).size
}
NumericConstant(size, Constant.minimumSize(size))
}
def eval(e: Expression, vars: Map[String, Constant]): Option[Constant] = evalImpl(e, Some(vars))
def eval(e: Expression): Option[Constant] = evalImpl(e, None)
@ -493,6 +517,13 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
} yield hc.asl(8) + lc
case FunctionCallExpression(name, params) =>
name match {
case "sizeof" =>
if (params.size == 1) {
Some(evalSizeof(params.head))
} else {
log.error("Invalid number of parameters for `sizeof`", e.position)
Some(Constant.One)
}
case "hi" =>
if (params.size == 1) {
eval(params.head).map(_.hiByte.quickSimplify)
@ -1282,6 +1313,8 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
nameCheck(l)
case SumExpression(params, _) =>
nameCheck(params.map(_._2))
case FunctionCallExpression("sizeof", List(ve@VariableExpression(e))) =>
checkName[Thing]("Type, variable or constant", e, ve.position)
case FunctionCallExpression(name, params) =>
if (name.exists(_.isLetter) && !Environment.predefinedFunctions(name)) {
checkName[CallableThing]("Function or type", name, node.position)
@ -1291,5 +1324,5 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
}
object Environment {
val predefinedFunctions = Set("not", "hi", "lo", "nonet")
val predefinedFunctions = Set("not", "hi", "lo", "nonet", "sizeof")
}

View File

@ -199,7 +199,7 @@ trait MfArray extends ThingInMemory with IndexableThing {
def elementType: VariableType
}
case class UninitializedArray(name: String, sizeInBytes: Int, declaredBank: Option[String], indexType: VariableType, elementType: VariableType, override val alignment: MemoryAlignment) extends MfArray with UninitializedMemory {
case class UninitializedArray(name: String, /* TODO: what if larger elements? */ sizeInBytes: Int, declaredBank: Option[String], indexType: VariableType, elementType: VariableType, override val alignment: MemoryAlignment) extends MfArray with UninitializedMemory {
override def toAddress: MemoryAddressConstant = MemoryAddressConstant(this)
override def alloc: VariableAllocationMethod.Value = VariableAllocationMethod.Static

View File

@ -0,0 +1,35 @@
package millfork.test
import millfork.Cpu
import millfork.test.emu.{EmuBenchmarkRun, EmuOptimizedCmosRun, EmuOptimizedRun, EmuUnoptimizedCrossPlatformRun}
import org.scalatest.{AppendedClues, FunSuite, Matchers}
/**
* @author Karol Stasiak
*/
class SizeofSuite extends FunSuite with Matchers with AppendedClues {
test("Basic sizeof test") {
EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80)(
"""
| const byte sizeofbyte = sizeof(byte)
| array output [6] @$c000
| void main () {
| byte a
| word b
| output[0] = sizeofbyte
| output[1] = sizeof(a)
| output[2] = sizeof(word)
| output[3] = sizeof(b)
| output[4] = sizeof(output[1])
| output[5] = sizeof(long)
| }
""".stripMargin){m =>
m.readByte(0xc000) should equal(1)
m.readByte(0xc001) should equal(1)
m.readByte(0xc002) should equal(2)
m.readByte(0xc003) should equal(2)
m.readByte(0xc004) should equal(1)
m.readByte(0xc005) should equal(4)
}
}
}