1
0
mirror of https://github.com/KarolS/millfork.git synced 2025-01-19 19:30:08 +00:00

Enumeration types. Stricter type checks.

This commit is contained in:
Karol Stasiak 2018-07-20 22:46:53 +02:00
parent ff93775cbe
commit c4c1bf00f2
28 changed files with 826 additions and 178 deletions

View File

@ -10,6 +10,8 @@
* Added aliases.
* Added enumeration types.
* Added preprocessor
* Automatic selection of text encoding based on target platform.

View File

@ -43,11 +43,13 @@ Note that you cannot mix `+'` and `-'` with `+` and `-`.
In the descriptions below, arguments to the operators are explained as follows:
* `byte` means any one-byte type
* `enum` means any enumeration type
* `word` means any two-byte type, or a byte expanded to a word
* `byte` means any numeric one-byte type
* `long` means any type longer than two bytes, or a shorter type expanded to such length to match the other argument
* `word` means any numeric two-byte type, or a byte expanded to a word; `pointer` is considered to be numeric
* `long` means any numeric type longer than two bytes, or a shorter type expanded to such length to match the other argument
* `constant` means a compile-time constant
@ -128,11 +130,13 @@ Note you cannot mix those operators, so `a <= b < c` is not valid.
(the current implementation calls it twice, but do not rely on this behaviour).
* `==`: equality
`enum == enum`
`byte == byte`
`simple word == simple word`
`simple long == simple long`
* `!=`: inequality
`enum != enum`
`byte != byte`
`simple word != simple word`
`simple long != simple long`
@ -152,6 +156,7 @@ and fail to compile otherwise. This will be changed in the future.
An expression of form `a[f()] += b` may call `f` an undefined number of times.
* `=`: normal assignment
`mutable enum = enum`
`mutable byte = byte`
`mutable word = word`
`mutable long = long`
@ -189,13 +194,20 @@ While Millfork does not consider indexing an operator, this is a place as good a
An expression of form `a[i]`, where `i` is an expression of type `byte`, is:
* when `a` is an array: an access to the `i`-th element of the array `a`
* when `a` is an array that has numeric index type: an access to the `i`-th element of the array `a`
* when `a` is a pointer variable: an access to the byte in memory at address `a + i`
Those expressions are of type `byte`. If `a` is any other kind of expression, `a[i]` is invalid.
If the zeropage register is enabled, `i` can also be of type `word`.
If the zeropage register is enabled, `i` can also be of type `word`.
An expression of form `a[i]`, where `i` is an expression of a enumeration type, is:
* when `a` is an array that has index type equal to the type of `i`:
an access to the element of the array `a` at the location assigned to the key `i`
* otherwise: a compile error
## Built-in functions
@ -212,6 +224,8 @@ Other kinds of expressions than the above (even `nonet(byte + byte + byte)`) wil
* `hi`, `lo`: most/least significant byte of a word
`hi(word)`
Furthermore, any type that can be assigned to a variable
can be used to convert from one type to another of the same size.

View File

@ -94,6 +94,12 @@ Syntax:
then defaults to `default_code_segment` as defined for the platform if the array has initial values,
or to `default` if it doesn't.
* `<size>`: either a constant number, which then defines the size of the array,
or a name of a plain enumeration type, in which case changes the type of the index to that enumeration
and declares the array size to be equal to the number of variants in that enumeration.
If the size is not specified here, then it's deduced from the `<initial_values>`.
If the declared size and the size deduced from the `<initial_values>` don't match, then an error is raised.
TODO
### Function declarations

View File

@ -41,3 +41,33 @@ TODO
## Special types
* `void` a unit type containing no information, can be only used as a return type for a function.
## Enumerations
Enumeration is a 1-byte type that represents a set of values:
enum <name> { <variants, separated by commas or newlines> }
The first variant has value 0. Every next variant has a value increased by 1 compared to a previous one.
Alternatively, a variant can be given a custom constant value, which will change the sequence.
If there is at least one variant and no variant is given a custom constant value,
then the enumeration is considered _plain_. Plain enumeration types can be used as array keys.
For plain enumerations, a constant `<name>.count` is defined,
equal to the number of variants in the enumeration.
Assigment between numeric types and enumerations is not possible without an explicit type cast:
enum E {}
byte b
E e
e = b // won't compile
b = e // won't compile
b = byte(e) // ok
e = E(b) // ok
Plain enumerations have their variants equal to `byte(0)` to `byte(<name>.count - 1)`.
Tip: You can use an enumeration with no variants as a strongly checked alternative byte type,
as there are no checks no values when converting bytes to enumeration values and vice versa.

View File

@ -12,6 +12,15 @@ class AbstractExpressionCompiler[T <: AbstractCode] {
def getExpressionType(ctx: CompilationContext, expr: Expression): Type = AbstractExpressionCompiler.getExpressionType(ctx, expr)
def assertAllArithmetic(ctx: CompilationContext,expressions: List[Expression]) = {
for(e <- expressions) {
val typ = getExpressionType(ctx, e)
if (!typ.isArithmetic) {
ErrorReporting.error(s"Cannot perform arithmetic operations on type `$typ`", e.position)
}
}
}
def lookupFunction(ctx: CompilationContext, f: FunctionCallExpression): MangledFunction = AbstractExpressionCompiler.lookupFunction(ctx, f)
def assertCompatible(exprType: Type, variableType: Type): Unit = {
@ -24,7 +33,8 @@ class AbstractExpressionCompiler[T <: AbstractCode] {
ctx.copy(env = result)
}
def getParamMaxSize(ctx: CompilationContext, params: List[Expression]): Int = {
def getArithmeticParamMaxSize(ctx: CompilationContext, params: List[Expression]): Int = {
assertAllArithmetic(ctx, params)
params.map(expr => getExpressionType(ctx, expr).size).max
}
@ -32,12 +42,19 @@ class AbstractExpressionCompiler[T <: AbstractCode] {
params.map { case (_, expr) => getExpressionType(ctx, expr).size}.max
}
def assertAllBytes(msg: String, ctx: CompilationContext, params: List[Expression]): Unit = {
def assertAllArithmeticBytes(msg: String, ctx: CompilationContext, params: List[Expression]): Unit = {
assertAllArithmetic(ctx, params)
if (params.exists { expr => getExpressionType(ctx, expr).size != 1 }) {
ErrorReporting.fatal(msg, params.head.position)
}
}
@inline
def assertArithmeticBinary(ctx: CompilationContext, params: List[Expression]): (Expression, Expression, Int) = {
assertAllArithmetic(ctx, params)
assertBinary(ctx, params)
}
def assertBinary(ctx: CompilationContext, params: List[Expression]): (Expression, Expression, Int) = {
if (params.length != 2) {
ErrorReporting.fatal("sfgdgfsd", None)
@ -56,6 +73,11 @@ class AbstractExpressionCompiler[T <: AbstractCode] {
}
}
def assertArithmeticComparison(ctx: CompilationContext, params: List[Expression]): (Int, Boolean) = {
assertAllArithmetic(ctx, params)
assertComparison(ctx, params)
}
def assertBool(ctx: CompilationContext, fname: String, params: List[Expression], expectedParamCount: Int): Unit = {
if (params.length != expectedParamCount) {
ErrorReporting.error("Invalid number of parameters for " + fname, params.headOption.flatMap(_.position))
@ -78,10 +100,11 @@ class AbstractExpressionCompiler[T <: AbstractCode] {
}
}
def assertAssignmentLike(ctx: CompilationContext, params: List[Expression]): (LhsExpression, Expression, Int) = {
def assertArithmeticAssignmentLike(ctx: CompilationContext, params: List[Expression]): (LhsExpression, Expression, Int) = {
if (params.length != 2) {
ErrorReporting.fatal("sfgdgfsd", None)
}
assertAllArithmetic(ctx, params)
(params.head, params(1)) match {
case (l: LhsExpression, r: Expression) =>
val lsize = getExpressionType(ctx, l).size
@ -115,7 +138,8 @@ object AbstractExpressionCompiler {
case HalfWordExpression(param, _) =>
getExpressionType(ctx, param)
b
case IndexedExpression(_, _) => b
case IndexedExpression(name, _) =>
env.getPointy(name).elementType
case SeparateBytesExpression(hi, lo) =>
if (getExpressionType(ctx, hi).size > 1) ErrorReporting.error("Hi byte too large", hi.position)
if (getExpressionType(ctx, lo).size > 1) ErrorReporting.error("Lo byte too large", lo.position)
@ -176,6 +200,29 @@ object AbstractExpressionCompiler {
}
}
def checkIndexType(ctx: CompilationContext, pointy: Pointy, index: Expression): Unit = {
val indexerType = getExpressionType(ctx, index)
if (!indexerType.isAssignableTo(pointy.indexType)) {
ErrorReporting.error(s"Invalid type for index${pointy.name.fold("")(" for `" + _ + "`")}: expected `${pointy.indexType}`, got `$indexerType`", index.position)
}
}
def checkAssignmentTypeAndGetSourceType(ctx: CompilationContext, source: Expression, target: LhsExpression): Type = {
val sourceType = getExpressionType(ctx, source)
val targetType = getExpressionType(ctx, target)
if (!sourceType.isAssignableTo(targetType)) {
ErrorReporting.error(s"Cannot assign `$sourceType` to `$targetType`", target.position.orElse(source.position))
}
sourceType
}
def checkAssignmentType(ctx: CompilationContext, source: Expression, targetType: Type): Unit = {
val sourceType = getExpressionType(ctx, source)
if (!sourceType.isAssignableTo(targetType)) {
ErrorReporting.error(s"Cannot assign `$sourceType` to `$targetType`", source.position)
}
}
def lookupFunction(ctx: CompilationContext, f: FunctionCallExpression): MangledFunction = {
val paramsWithTypes = f.expressions.map(x => getExpressionType(ctx, x) -> x)
ctx.env.lookupFunction(f.functionName, paramsWithTypes).getOrElse(

View File

@ -102,15 +102,17 @@ abstract class AbstractStatementCompiler[T <: AbstractCode] {
def compileForStatement(ctx: CompilationContext, f: ForStatement): List[T] = {
// TODO: check sizes
// TODO: special faster cases
val p = f.position
val vex = VariableExpression(f.variable)
val one = LiteralExpression(1, 1)
val increment = ExpressionStatement(FunctionCallExpression("+=", List(vex, one)))
val decrement = ExpressionStatement(FunctionCallExpression("-=", List(vex, one)))
val one = LiteralExpression(1, 1).pos(p)
val increment = ExpressionStatement(FunctionCallExpression("+=", List(vex, one)).pos(p)).pos(p)
val decrement = ExpressionStatement(FunctionCallExpression("-=", List(vex, one)).pos(p)).pos(p)
val names = Set("", "for", f.variable)
val startEvaluated = ctx.env.eval(f.start)
val endEvaluated = ctx.env.eval(f.end)
ctx.env.maybeGet[Variable](f.variable).foreach{ v=>
val variable = ctx.env.maybeGet[Variable](f.variable)
variable.foreach{ v=>
startEvaluated.foreach(value => if (!value.quickSimplify.fitsInto(v.typ)) {
ErrorReporting.error(s"Variable `${f.variable}` is too small to hold the initial value in the for loop", f.position)
})
@ -129,45 +131,80 @@ abstract class AbstractStatementCompiler[T <: AbstractCode] {
case (ForDirection.Until | ForDirection.ParallelUntil, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, _))) if s == e - 1 =>
val end = nextLabel("of")
compile(ctx.addLabels(names, Label(end), Label(end)), Assignment(vex, f.start) :: f.body) ++ labelChunk(end)
compile(ctx.addLabels(names, Label(end), Label(end)), Assignment(vex, f.start).pos(p) :: f.body) ++ labelChunk(end)
case (ForDirection.Until | ForDirection.ParallelUntil, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, _))) if s >= e =>
Nil
case (ForDirection.To | ForDirection.ParallelTo, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, _))) if s == e =>
val end = nextLabel("of")
compile(ctx.addLabels(names, Label(end), Label(end)), Assignment(vex, f.start) :: f.body) ++ labelChunk(end)
compile(ctx.addLabels(names, Label(end), Label(end)), Assignment(vex, f.start).pos(p) :: f.body) ++ labelChunk(end)
case (ForDirection.To | ForDirection.ParallelTo, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, _))) if s > e =>
Nil
case (ForDirection.ParallelUntil, Some(NumericConstant(0, ssize)), Some(NumericConstant(e, _))) if e > 0 =>
case (ForDirection.Until | ForDirection.ParallelUntil, Some(c), Some(NumericConstant(256, _)))
if variable.map(_.typ.size).contains(1) && c.requiredSize == 1 && c.isProvablyNonnegative =>
// LDX #s
// loop:
// stuff
// INX
// BNE loop
compile(ctx, List(
Assignment(vex, f.start).pos(p),
DoWhileStatement(f.body, List(increment), FunctionCallExpression("!=", List(vex, LiteralExpression(0, 1).pos(p))), names).pos(p)
))
case (ForDirection.ParallelUntil, Some(NumericConstant(0, _)), Some(NumericConstant(e, _))) if e > 0 =>
compile(ctx, List(
Assignment(vex, f.end),
DoWhileStatement(Nil, decrement :: f.body, FunctionCallExpression("!=", List(vex, f.start)), names)
DoWhileStatement(Nil, decrement :: f.body, FunctionCallExpression("!=", List(vex, f.start)), names).pos(p)
))
case (ForDirection.DownTo, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, esize))) if s == e =>
val end = nextLabel("of")
compile(ctx.addLabels(names, Label(end), Label(end)), Assignment(vex, LiteralExpression(s, ssize)) :: f.body) ++ labelChunk(end)
compile(ctx.addLabels(names, Label(end), Label(end)), Assignment(vex, LiteralExpression(s, ssize)).pos(p) :: f.body) ++ labelChunk(end)
case (ForDirection.DownTo, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, esize))) if s < e =>
Nil
case (ForDirection.DownTo, Some(NumericConstant(s, 1)), Some(NumericConstant(0, _))) if s > 0 =>
compile(ctx, List(
Assignment(vex, FunctionCallExpression("lo", List(SumExpression(List(false -> f.start, false -> LiteralExpression(1, 2)), decimal = false)))),
DoWhileStatement(decrement :: f.body, Nil, FunctionCallExpression("!=", List(vex, f.end)), names)
Assignment(
vex,
FunctionCallExpression("lo", List(
SumExpression(List(
false -> f.start,
false -> LiteralExpression(1, 2).pos(p)),
decimal = false
).pos(p)
)).pos(p)
).pos(p),
DoWhileStatement(
decrement :: f.body,
Nil,
FunctionCallExpression("!=", List(vex, f.end)).pos(p), names).pos(p)
))
case (ForDirection.DownTo, Some(NumericConstant(s, ssize)), Some(NumericConstant(0, _))) if s > 0 =>
compile(ctx, List(
Assignment(vex, SumExpression(List(false -> f.start, false -> LiteralExpression(1, 1)), decimal = false)),
DoWhileStatement(decrement :: f.body, Nil, FunctionCallExpression("!=", List(vex, f.end)), names)
Assignment(
vex,
SumExpression(
List(false -> f.start, false -> LiteralExpression(1, 1).pos(p)),
decimal = false
).pos(p)
).pos(p),
DoWhileStatement(
decrement :: f.body,
Nil,
FunctionCallExpression("!=", List(vex, f.end)).pos(p),
names
).pos(p)
))
case (ForDirection.Until | ForDirection.ParallelUntil, _, _) =>
compile(ctx, List(
Assignment(vex, f.start),
Assignment(vex, f.start).pos(p),
WhileStatement(
FunctionCallExpression("<", List(vex, f.end)),
f.body, List(increment), names),
FunctionCallExpression("<", List(vex, f.end)).pos(p),
f.body, List(increment), names).pos(p)
))
// case (ForDirection.To | ForDirection.ParallelTo, _, Some(NumericConstant(n, _))) if n > 0 && n < 255 =>
// compile(ctx, List(
@ -178,28 +215,28 @@ abstract class AbstractStatementCompiler[T <: AbstractCode] {
// ))
case (ForDirection.To | ForDirection.ParallelTo, _, _) =>
compile(ctx, List(
Assignment(vex, f.start),
Assignment(vex, f.start).pos(p),
WhileStatement(
VariableExpression("true"),
VariableExpression("true").pos(p),
f.body,
List(IfStatement(
FunctionCallExpression("==", List(vex, f.end)),
List(BreakStatement(f.variable)),
FunctionCallExpression("==", List(vex, f.end)).pos(p),
List(BreakStatement(f.variable).pos(p)),
List(increment)
)),
names),
names)
))
case (ForDirection.DownTo, _, _) =>
compile(ctx, List(
Assignment(vex, f.start),
Assignment(vex, f.start).pos(p),
IfStatement(
FunctionCallExpression(">=", List(vex, f.end)),
FunctionCallExpression(">=", List(vex, f.end)).pos(p),
List(DoWhileStatement(
f.body,
List(decrement),
FunctionCallExpression("!=", List(vex, f.end)),
FunctionCallExpression("!=", List(vex, f.end)).pos(p),
names
)),
).pos(p)),
Nil)
))
}

View File

@ -50,6 +50,7 @@ object BuiltIns {
Nil -> AssemblyLine.variable(ctx, opcode, v)
case IndexedExpression(arrayName, index) =>
val pointy = env.getPointy(arrayName)
AbstractExpressionCompiler.checkIndexType(ctx, pointy, index)
val (variablePart, constantPart) = env.evalVariableAndConstantSubParts(index)
val indexerSize = variablePart.map(v => getIndexerSize(ctx, v)).getOrElse(1)
val totalIndexSize = getIndexerSize(ctx, index)
@ -57,11 +58,11 @@ object BuiltIns {
case (p: ConstantPointy, _, _, _, None) =>
Nil -> List(AssemblyLine.absolute(opcode, p.value + constantPart))
case (p: ConstantPointy, _, 1, IndexChoice.RequireX | IndexChoice.PreferX, Some(v)) =>
MosExpressionCompiler.compile(ctx, v, Some(b -> RegisterVariable(MosRegister.X, b)), NoBranching) -> List(AssemblyLine.absoluteX(opcode, p.value + constantPart))
MosExpressionCompiler.compile(ctx, v, Some(b -> RegisterVariable(MosRegister.X, pointy.indexType)), NoBranching) -> List(AssemblyLine.absoluteX(opcode, p.value + constantPart))
case (p: ConstantPointy, _, 1, IndexChoice.PreferY, Some(v)) =>
MosExpressionCompiler.compile(ctx, v, Some(b -> RegisterVariable(MosRegister.Y, b)), NoBranching) -> List(AssemblyLine.absoluteY(opcode, p.value + constantPart))
MosExpressionCompiler.compile(ctx, v, Some(b -> RegisterVariable(MosRegister.Y, pointy.indexType)), NoBranching) -> List(AssemblyLine.absoluteY(opcode, p.value + constantPart))
case (p: VariablePointy, 0 | 1, _, IndexChoice.PreferX | IndexChoice.PreferY, _) =>
MosExpressionCompiler.compile(ctx, index, Some(b -> RegisterVariable(MosRegister.Y, b)), NoBranching) -> List(AssemblyLine.indexedY(opcode, p.addr))
MosExpressionCompiler.compile(ctx, index, Some(b -> RegisterVariable(MosRegister.Y, pointy.indexType)), NoBranching) -> List(AssemblyLine.indexedY(opcode, p.addr))
case (p: ConstantPointy, _, 2, IndexChoice.PreferX | IndexChoice.PreferY, Some(v)) =>
MosExpressionCompiler.prepareWordIndexing(ctx, p, index) -> List(AssemblyLine.indexedY(opcode, env.get[VariableInMemory]("__reg")))
case (p: VariablePointy, 2, _, IndexChoice.PreferX | IndexChoice.PreferY, _) =>
@ -157,6 +158,14 @@ object BuiltIns {
case None => expr match {
case VariableExpression(_) => 'V'
case IndexedExpression(_, LiteralExpression(_, _)) => 'K'
case IndexedExpression(_, expr@VariableExpression(v)) =>
env.eval(expr) match {
case Some(_) => 'K'
case None => env.get[Variable](v).typ.size match {
case 1 => 'J'
case _ => 'I'
}
}
case IndexedExpression(_, VariableExpression(v)) if env.get[Variable](v).typ.size == 1 => 'J'
case IndexedExpression(_, _) => 'I'
case _ => 'A'
@ -760,8 +769,8 @@ object BuiltIns {
}
}
private def getIndexerSize(ctx: CompilationContext, indexExpr: Expression) = {
ctx.env.evalVariableAndConstantSubParts(indexExpr)._1.map(v => MosExpressionCompiler.getExpressionType(ctx, v)).size
private def getIndexerSize(ctx: CompilationContext, indexExpr: Expression): Int = {
ctx.env.evalVariableAndConstantSubParts(indexExpr)._1.map(v => MosExpressionCompiler.getExpressionType(ctx, v).size).sum
}
def compileInPlaceWordOrLongAddition(ctx: CompilationContext, lhs: LhsExpression, addend: Expression, subtract: Boolean, decimal: Boolean): List[AssemblyLine] = {

View File

@ -143,14 +143,14 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
val reg = ctx.env.get[VariableInMemory]("__reg")
val compileIndex = compile(ctx, indexExpression, Some(MosExpressionCompiler.getExpressionType(ctx, indexExpression) -> RegisterVariable(MosRegister.YA, w)), BranchSpec.None)
val prepareRegister = pointy match {
case ConstantPointy(addr, _) =>
case ConstantPointy(addr, _, _, _, _) =>
List(
AssemblyLine.implied(CLC),
AssemblyLine.immediate(ADC, addr.hiByte),
AssemblyLine.zeropage(STA, reg, 1),
AssemblyLine.immediate(LDA, addr.loByte),
AssemblyLine.zeropage(STA, reg))
case VariablePointy(addr) =>
case VariablePointy(addr, _, _) =>
List(
AssemblyLine.implied(CLC),
AssemblyLine.zeropage(ADC, addr + 1),
@ -257,7 +257,8 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
case (p: VariablePointy, _, _, 2) =>
wrapWordIndexingStorage(prepareWordIndexing(ctx, p, indexExpr))
case (p: ConstantPointy, Some(v), 2, _) =>
wrapWordIndexingStorage(prepareWordIndexing(ctx, ConstantPointy(p.value + constIndex, if (constIndex.isProvablyZero) p.size else None), v))
val w = env.get[VariableType]("word")
wrapWordIndexingStorage(prepareWordIndexing(ctx, ConstantPointy(p.value + constIndex, None, if (constIndex.isProvablyZero) p.size else None, w, p.elementType), v))
case (p: ConstantPointy, Some(v), 1, _) =>
storeToArrayAtUnknownIndex(v, p.value)
//TODO: should there be a type check or a zeropage check?
@ -499,6 +500,7 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
}
case IndexedExpression(arrayName, indexExpr) =>
val pointy = env.getPointy(arrayName)
AbstractExpressionCompiler.checkIndexType(ctx, pointy, indexExpr)
// TODO: check
val (variableIndex, constantIndex) = env.evalVariableAndConstantSubParts(indexExpr)
val variableIndexSize = variableIndex.map(v => getExpressionType(ctx, v).size).getOrElse(0)
@ -559,7 +561,12 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
case (a: ConstantPointy, Some(v), _, 1) =>
loadFromArrayAtUnknownIndex(v, a.value)
case (a: ConstantPointy, Some(v), _, 2) =>
prepareWordIndexing(ctx, ConstantPointy(a.value + constantIndex, if (constantIndex.isProvablyZero) a.size else None), v) ++ loadFromReg()
prepareWordIndexing(ctx, ConstantPointy(
a.value + constantIndex,
None,
if (constantIndex.isProvablyZero) a.size else None,
env.get[VariableType]("word"),
a.elementType), v) ++ loadFromReg()
case (a: VariablePointy, _, 2, _) =>
prepareWordIndexing(ctx, a, indexExpr) ++ loadFromReg()
case (p:VariablePointy, None, 0 | 1, _) =>
@ -592,6 +599,7 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
}
}
case SumExpression(params, decimal) =>
assertAllArithmetic(ctx, params.map(_._2))
val a = params.map{case (n, p) => env.eval(p).map(n -> _)}
if (a.forall(_.isDefined)) {
val value = a.foldLeft(Constant.Zero){(c, pair) =>
@ -693,7 +701,7 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
ErrorReporting.error("Invalid number of parameters", f.position)
Nil
} else {
assertAllBytes("Nonet argument has to be a byte", ctx, params)
assertAllArithmeticBytes("Nonet argument has to be a byte", ctx, params)
params.head match {
case SumExpression(addends, _) =>
if (addends.exists(a => !a._1)) {
@ -738,7 +746,7 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
}
case "^^" => ???
case "&" =>
getParamMaxSize(ctx, params) match {
getArithmeticParamMaxSize(ctx, params) match {
case 1 =>
zeroExtend = true
BuiltIns.compileBitOps(AND, ctx, params)
@ -746,31 +754,31 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
}
case "*" =>
zeroExtend = true
assertAllBytes("Long multiplication not supported", ctx, params)
assertAllArithmeticBytes("Long multiplication not supported", ctx, params)
BuiltIns.compileByteMultiplication(ctx, params)
case "|" =>
getParamMaxSize(ctx, params) match {
getArithmeticParamMaxSize(ctx, params) match {
case 1 =>
zeroExtend = true
BuiltIns.compileBitOps(ORA, ctx, params)
case 2 => PseudoregisterBuiltIns.compileWordBitOpsToAX(ctx, params, ORA)
}
case "^" =>
getParamMaxSize(ctx, params) match {
getArithmeticParamMaxSize(ctx, params) match {
case 1 =>
zeroExtend = true
BuiltIns.compileBitOps(EOR, ctx, params)
case 2 => PseudoregisterBuiltIns.compileWordBitOpsToAX(ctx, params, EOR)
}
case ">>>>" =>
val (l, r, 2) = assertBinary(ctx, params)
val (l, r, 2) = assertArithmeticBinary(ctx, params)
l match {
case v: LhsExpression =>
zeroExtend = true
BuiltIns.compileNonetOps(ctx, v, r)
}
case "<<" =>
val (l, r, size) = assertBinary(ctx, params)
val (l, r, size) = assertArithmeticBinary(ctx, params)
size match {
case 1 =>
zeroExtend = true
@ -782,7 +790,7 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
Nil
}
case ">>" =>
val (l, r, size) = assertBinary(ctx, params)
val (l, r, size) = assertArithmeticBinary(ctx, params)
size match {
case 1 =>
zeroExtend = true
@ -795,17 +803,17 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
}
case "<<'" =>
zeroExtend = true
assertAllBytes("Long shift ops not supported", ctx, params)
val (l, r, 1) = assertBinary(ctx, params)
assertAllArithmeticBytes("Long shift ops not supported", ctx, params)
val (l, r, 1) = assertArithmeticBinary(ctx, params)
DecimalBuiltIns.compileByteShiftLeft(ctx, l, r, rotate = false)
case ">>'" =>
zeroExtend = true
assertAllBytes("Long shift ops not supported", ctx, params)
val (l, r, 1) = assertBinary(ctx, params)
assertAllArithmeticBytes("Long shift ops not supported", ctx, params)
val (l, r, 1) = assertArithmeticBinary(ctx, params)
DecimalBuiltIns.compileByteShiftRight(ctx, l, r, rotate = false)
case "<" =>
// TODO: signed
val (size, signed) = assertComparison(ctx, params)
val (size, signed) = assertArithmeticComparison(ctx, params)
compileTransitiveRelation(ctx, "<", params, exprTypeAndVariable, branches) { (l, r) =>
size match {
case 1 => BuiltIns.compileByteComparison(ctx, if (signed) ComparisonType.LessSigned else ComparisonType.LessUnsigned, l, r, branches)
@ -815,7 +823,7 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
}
case ">=" =>
// TODO: signed
val (size, signed) = assertComparison(ctx, params)
val (size, signed) = assertArithmeticComparison(ctx, params)
compileTransitiveRelation(ctx, ">=", params, exprTypeAndVariable, branches) { (l, r) =>
size match {
case 1 => BuiltIns.compileByteComparison(ctx, if (signed) ComparisonType.GreaterOrEqualSigned else ComparisonType.GreaterOrEqualUnsigned, l, r, branches)
@ -825,7 +833,7 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
}
case ">" =>
// TODO: signed
val (size, signed) = assertComparison(ctx, params)
val (size, signed) = assertArithmeticComparison(ctx, params)
compileTransitiveRelation(ctx, ">", params, exprTypeAndVariable, branches) { (l, r) =>
size match {
case 1 => BuiltIns.compileByteComparison(ctx, if (signed) ComparisonType.GreaterSigned else ComparisonType.GreaterUnsigned, l, r, branches)
@ -835,7 +843,7 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
}
case "<=" =>
// TODO: signed
val (size, signed) = assertComparison(ctx, params)
val (size, signed) = assertArithmeticComparison(ctx, params)
compileTransitiveRelation(ctx, "<=", params, exprTypeAndVariable, branches) { (l, r) =>
size match {
case 1 => BuiltIns.compileByteComparison(ctx, if (signed) ComparisonType.LessOrEqualSigned else ComparisonType.LessOrEqualUnsigned, l, r, branches)
@ -860,7 +868,7 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
case _ => BuiltIns.compileLongComparison(ctx, ComparisonType.NotEqual, l, r, size, branches)
}
case "+=" =>
val (l, r, size) = assertAssignmentLike(ctx, params)
val (l, r, size) = assertArithmeticAssignmentLike(ctx, params)
size match {
case 1 =>
BuiltIns.compileInPlaceByteAddition(ctx, l, r, subtract = false, decimal = false)
@ -876,7 +884,7 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
}
}
case "-=" =>
val (l, r, size) = assertAssignmentLike(ctx, params)
val (l, r, size) = assertArithmeticAssignmentLike(ctx, params)
size match {
case 1 =>
BuiltIns.compileInPlaceByteAddition(ctx, l, r, subtract = true, decimal = false)
@ -892,7 +900,7 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
}
}
case "+'=" =>
val (l, r, size) = assertAssignmentLike(ctx, params)
val (l, r, size) = assertArithmeticAssignmentLike(ctx, params)
size match {
case 1 =>
BuiltIns.compileInPlaceByteAddition(ctx, l, r, subtract = false, decimal = true)
@ -908,7 +916,7 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
}
}
case "-'=" =>
val (l, r, size) = assertAssignmentLike(ctx, params)
val (l, r, size) = assertArithmeticAssignmentLike(ctx, params)
size match {
case 1 =>
BuiltIns.compileInPlaceByteAddition(ctx, l, r, subtract = true, decimal = true)
@ -924,7 +932,7 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
}
}
case "<<=" =>
val (l, r, size) = assertAssignmentLike(ctx, params)
val (l, r, size) = assertArithmeticAssignmentLike(ctx, params)
size match {
case 1 =>
BuiltIns.compileInPlaceByteShiftOps(ASL, ctx, l, r)
@ -935,7 +943,7 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
}
}
case ">>=" =>
val (l, r, size) = assertAssignmentLike(ctx, params)
val (l, r, size) = assertArithmeticAssignmentLike(ctx, params)
size match {
case 1 =>
BuiltIns.compileInPlaceByteShiftOps(LSR, ctx, l, r)
@ -946,7 +954,7 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
}
}
case "<<'=" =>
val (l, r, size) = assertAssignmentLike(ctx, params)
val (l, r, size) = assertArithmeticAssignmentLike(ctx, params)
size match {
case 1 =>
DecimalBuiltIns.compileByteShiftLeft(ctx, l, r, rotate = false) ++ compileByteStorage(ctx, MosRegister.A, l)
@ -957,7 +965,7 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
}
}
case ">>'=" =>
val (l, r, size) = assertAssignmentLike(ctx, params)
val (l, r, size) = assertArithmeticAssignmentLike(ctx, params)
size match {
case 1 =>
DecimalBuiltIns.compileByteShiftRight(ctx, l, r, rotate = false) ++ compileByteStorage(ctx, MosRegister.A, l)
@ -968,15 +976,15 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
}
}
case "*=" =>
assertAllBytes("Long multiplication not supported", ctx, params)
val (l, r, 1) = assertAssignmentLike(ctx, params)
assertAllArithmeticBytes("Long multiplication not supported", ctx, params)
val (l, r, 1) = assertArithmeticAssignmentLike(ctx, params)
BuiltIns.compileInPlaceByteMultiplication(ctx, l, r)
case "*'=" =>
assertAllBytes("Long multiplication not supported", ctx, params)
val (l, r, 1) = assertAssignmentLike(ctx, params)
assertAllArithmeticBytes("Long multiplication not supported", ctx, params)
val (l, r, 1) = assertArithmeticAssignmentLike(ctx, params)
DecimalBuiltIns.compileInPlaceByteMultiplication(ctx, l, r)
case "&=" =>
val (l, r, size) = assertAssignmentLike(ctx, params)
val (l, r, size) = assertArithmeticAssignmentLike(ctx, params)
size match {
case 1 =>
BuiltIns.compileInPlaceByteBitOp(ctx, l, r, AND)
@ -987,7 +995,7 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
}
}
case "^=" =>
val (l, r, size) = assertAssignmentLike(ctx, params)
val (l, r, size) = assertArithmeticAssignmentLike(ctx, params)
size match {
case 1 =>
BuiltIns.compileInPlaceByteBitOp(ctx, l, r, EOR)
@ -998,7 +1006,7 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
}
}
case "|=" =>
val (l, r, size) = assertAssignmentLike(ctx, params)
val (l, r, size) = assertArithmeticAssignmentLike(ctx, params)
size match {
case 1 =>
BuiltIns.compileInPlaceByteBitOp(ctx, l, r, ORA)
@ -1241,13 +1249,14 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
def compileAssignment(ctx: CompilationContext, source: Expression, target: LhsExpression): List[AssemblyLine] = {
val env = ctx.env
val sourceType = AbstractExpressionCompiler.checkAssignmentTypeAndGetSourceType(ctx, source, target)
val b = env.get[Type]("byte")
val w = env.get[Type]("word")
target match {
case VariableExpression(name) =>
val v = env.get[Variable](name, target.position)
// TODO check v.typ
compile(ctx, source, Some((getExpressionType(ctx, source), v)), NoBranching)
compile(ctx, source, Some((sourceType, v)), NoBranching)
case SeparateBytesExpression(h: LhsExpression, l: LhsExpression) =>
compile(ctx, source, Some(w, RegisterVariable(MosRegister.AX, w)), NoBranching) ++
compileByteStorage(ctx, MosRegister.A, l) ++ compileByteStorage(ctx, MosRegister.X, h)
@ -1261,10 +1270,12 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] {
def arrayBoundsCheck(ctx: CompilationContext, pointy: Pointy, register: MosRegister.Value, index: Expression): List[AssemblyLine] = {
if (!ctx.options.flags(CompilationFlag.CheckIndexOutOfBounds)) return Nil
val arrayLength = pointy match {
val arrayLength:Int = pointy match {
case _: VariablePointy => return Nil
case ConstantPointy(_, None) => return Nil
case ConstantPointy(_, Some(s)) => s
case p: ConstantPointy => p.size match {
case None => return Nil
case Some(s) => s
}
}
ctx.env.eval(index) match {
case Some(NumericConstant(i, _)) =>

View File

@ -121,19 +121,20 @@ object MosReturnDispatch {
else actualMax
}
val b = ctx.env.get[VariableType]("byte")
while (env.parent.isDefined) env = env.parent.get
val label = MosCompiler.nextLabel("di")
val paramArrays = stmt.params.indices.map { ix =>
val a = InitializedArray(label + "$" + ix + ".array", None, (paramMins(ix) to paramMaxes(ix)).map { key =>
map(key)._2.lift(ix).getOrElse(LiteralExpression(0, 1))
}.toList,
ctx.function.declaredBank)
ctx.function.declaredBank, b, b)
env.registerUnnamedArray(a)
a
}
val useJmpaix = ctx.options.flag(CompilationFlag.EmitCmosOpcodes) && !ctx.options.flag(CompilationFlag.LUnixRelocatableCode) && (actualMax - actualMin) <= 127
val b = ctx.env.get[Type]("byte")
import AddrMode._
import millfork.assembly.mos.Opcode._
@ -147,7 +148,7 @@ object MosReturnDispatch {
}
if (useJmpaix) {
val jumpTable = InitializedArray(label + "$jt.array", None, (actualMin to actualMax).flatMap(i => List(lobyte0(map(i)._1), hibyte0(map(i)._1))).toList, ctx.function.declaredBank)
val jumpTable = InitializedArray(label + "$jt.array", None, (actualMin to actualMax).flatMap(i => List(lobyte0(map(i)._1), hibyte0(map(i)._1))).toList, ctx.function.declaredBank, b, b)
env.registerUnnamedArray(jumpTable)
if (copyParams.isEmpty) {
val loadIndex = MosExpressionCompiler.compile(ctx, stmt.indexer, Some(b -> RegisterVariable(MosRegister.A, b)), BranchSpec.None)
@ -162,8 +163,8 @@ object MosReturnDispatch {
}
} else {
val loadIndex = MosExpressionCompiler.compile(ctx, stmt.indexer, Some(b -> RegisterVariable(MosRegister.X, b)), BranchSpec.None)
val jumpTableLo = InitializedArray(label + "$jl.array", None, (actualMin to actualMax).map(i => lobyte1(map(i)._1)).toList, ctx.function.declaredBank)
val jumpTableHi = InitializedArray(label + "$jh.array", None, (actualMin to actualMax).map(i => hibyte1(map(i)._1)).toList, ctx.function.declaredBank)
val jumpTableLo = InitializedArray(label + "$jl.array", None, (actualMin to actualMax).map(i => lobyte1(map(i)._1)).toList, ctx.function.declaredBank, b, b)
val jumpTableHi = InitializedArray(label + "$jh.array", None, (actualMin to actualMax).map(i => hibyte1(map(i)._1)).toList, ctx.function.declaredBank, b, b)
env.registerUnnamedArray(jumpTableLo)
env.registerUnnamedArray(jumpTableHi)
val actualJump = if (ctx.options.flag(CompilationFlag.LUnixRelocatableCode)) {

View File

@ -223,6 +223,7 @@ object MosStatementCompiler extends AbstractStatementCompiler[AssemblyLine] {
MosExpressionCompiler.compile(ctx, e, someRegisterAX, NoBranching) ++ stackPointerFixBeforeReturn(ctx) ++ returnInstructions
}
case _ =>
AbstractExpressionCompiler.checkAssignmentType(ctx, e, m.returnType)
m.returnType.size match {
case 0 =>
ErrorReporting.error("Cannot return anything from a void function", statement.position)

View File

@ -205,7 +205,7 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] {
case code => code ++ loadByteViaHL(target)
}
case SumExpression(params, decimal) =>
getParamMaxSize(ctx, params.map(_._2)) match {
getArithmeticParamMaxSize(ctx, params.map(_._2)) match {
case 1 => targetifyA(target, ZBuiltIns.compile8BitSum(ctx, params, decimal), isSigned = false)
case 2 => targetifyHL(target, ZBuiltIns.compile16BitSum(ctx, params, decimal))
}
@ -302,50 +302,50 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] {
case "^^" => ???
case "&" =>
getParamMaxSize(ctx, params) match {
getArithmeticParamMaxSize(ctx, params) match {
case 1 => targetifyA(target, ZBuiltIns.compile8BitOperation(ctx, AND, params), isSigned = false)
case 2 => targetifyHL(target, ZBuiltIns.compile16BitOperation(ctx, AND, params))
}
case "*" =>
assertAllBytes("Long multiplication not supported", ctx, params)
assertAllArithmeticBytes("Long multiplication not supported", ctx, params)
targetifyA(target, Z80Multiply.compile8BitMultiply(ctx, params), isSigned = false)
case "|" =>
getParamMaxSize(ctx, params) match {
getArithmeticParamMaxSize(ctx, params) match {
case 1 => targetifyA(target, ZBuiltIns.compile8BitOperation(ctx, OR, params), isSigned = false)
case 2 => targetifyHL(target, ZBuiltIns.compile16BitOperation(ctx, OR, params))
}
case "^" =>
getParamMaxSize(ctx, params) match {
getArithmeticParamMaxSize(ctx, params) match {
case 1 => targetifyA(target, ZBuiltIns.compile8BitOperation(ctx, XOR, params), isSigned = false)
case 2 => targetifyHL(target, ZBuiltIns.compile16BitOperation(ctx, XOR, params))
}
case ">>>>" =>
val (l, r, 2) = assertBinary(ctx, params)
val (l, r, 2) = assertArithmeticBinary(ctx, params)
???
case "<<" =>
val (l, r, size) = assertBinary(ctx, params)
val (l, r, size) = assertArithmeticBinary(ctx, params)
size match {
case 1 => targetifyA(target, Z80Shifting.compile8BitShift(ctx, l, r, left = true), isSigned = false)
case 2 => Z80Shifting.compile16BitShift(ctx, l, r, left = true)
case _ => ???
}
case ">>" =>
val (l, r, size) = assertBinary(ctx, params)
val (l, r, size) = assertArithmeticBinary(ctx, params)
size match {
case 1 => targetifyA(target, Z80Shifting.compile8BitShift(ctx, l, r, left = false), isSigned = false)
case 2 => Z80Shifting.compile16BitShift(ctx, l, r, left = false)
case _ => ???
}
case "<<'" =>
assertAllBytes("Long shift ops not supported", ctx, params)
val (l, r, 1) = assertBinary(ctx, params)
assertAllArithmeticBytes("Long shift ops not supported", ctx, params)
val (l, r, 1) = assertArithmeticBinary(ctx, params)
???
case ">>'" =>
assertAllBytes("Long shift ops not supported", ctx, params)
val (l, r, 1) = assertBinary(ctx, params)
assertAllArithmeticBytes("Long shift ops not supported", ctx, params)
val (l, r, 1) = assertArithmeticBinary(ctx, params)
???
case "<" =>
val (size, signed) = assertComparison(ctx, params)
val (size, signed) = assertArithmeticComparison(ctx, params)
compileTransitiveRelation(ctx, "<", params, target, branches) { (l, r) =>
size match {
case 1 => Z80Comparisons.compile8BitComparison(ctx, if (signed) ComparisonType.LessSigned else ComparisonType.LessUnsigned, l, r, branches)
@ -354,7 +354,7 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] {
}
}
case ">=" =>
val (size, signed) = assertComparison(ctx, params)
val (size, signed) = assertArithmeticComparison(ctx, params)
compileTransitiveRelation(ctx, ">=", params, target, branches) { (l, r) =>
size match {
case 1 => Z80Comparisons.compile8BitComparison(ctx, if (signed) ComparisonType.GreaterOrEqualSigned else ComparisonType.GreaterOrEqualUnsigned, l, r, branches)
@ -363,7 +363,7 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] {
}
}
case ">" =>
val (size, signed) = assertComparison(ctx, params)
val (size, signed) = assertArithmeticComparison(ctx, params)
compileTransitiveRelation(ctx, ">", params, target, branches) { (l, r) =>
size match {
case 1 => Z80Comparisons.compile8BitComparison(ctx, if (signed) ComparisonType.GreaterSigned else ComparisonType.GreaterUnsigned, l, r, branches)
@ -372,7 +372,7 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] {
}
}
case "<=" =>
val (size, signed) = assertComparison(ctx, params)
val (size, signed) = assertArithmeticComparison(ctx, params)
compileTransitiveRelation(ctx, "<=", params, target, branches) { (l, r) =>
size match {
case 1 => Z80Comparisons.compile8BitComparison(ctx, if (signed) ComparisonType.LessOrEqualSigned else ComparisonType.LessOrEqualUnsigned, l, r, branches)
@ -399,72 +399,72 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] {
}
}
case "+=" =>
val (l, r, size) = assertAssignmentLike(ctx, params)
val (l, r, size) = assertArithmeticAssignmentLike(ctx, params)
size match {
case 1 => ZBuiltIns.perform8BitInPlace(ctx, l, r, ADD)
case _ => ZBuiltIns.performLongInPlace(ctx, l, r, ADD, ADC, size)
}
case "-=" =>
val (l, r, size) = assertAssignmentLike(ctx, params)
val (l, r, size) = assertArithmeticAssignmentLike(ctx, params)
size match {
case 1 => ZBuiltIns.perform8BitInPlace(ctx, l, r, SUB)
case _ => ZBuiltIns.performLongInPlace(ctx, l, r, SUB, SBC, size)
}
case "+'=" =>
val (l, r, size) = assertAssignmentLike(ctx, params)
val (l, r, size) = assertArithmeticAssignmentLike(ctx, params)
size match {
case 1 => ZBuiltIns.perform8BitInPlace(ctx, l, r, ADD, decimal = true)
case _ => ZBuiltIns.performLongInPlace(ctx, l, r, ADD, ADC, size, decimal = true)
}
case "-'=" =>
val (l, r, size) = assertAssignmentLike(ctx, params)
val (l, r, size) = assertArithmeticAssignmentLike(ctx, params)
size match {
case 1 => ZBuiltIns.perform8BitInPlace(ctx, l, r, SUB, decimal = true)
case _ => ZBuiltIns.performLongInPlace(ctx, l, r, SUB, SBC, size, decimal = true)
}
Nil
case "<<=" =>
val (l, r, size) = assertAssignmentLike(ctx, params)
val (l, r, size) = assertArithmeticAssignmentLike(ctx, params)
size match {
case 1 => Z80Shifting.compile8BitShiftInPlace(ctx, l, r, left = true)
case 2 => Z80Shifting.compile16BitShift(ctx, l, r, left = true) ++ storeHL(ctx, l, signedSource = false)
case _ => ???
}
case ">>=" =>
val (l, r, size) = assertAssignmentLike(ctx, params)
val (l, r, size) = assertArithmeticAssignmentLike(ctx, params)
size match {
case 1 => Z80Shifting.compile8BitShiftInPlace(ctx, l, r, left = false)
case 2 => Z80Shifting.compile16BitShift(ctx, l, r, left = false) ++ storeHL(ctx, l, signedSource = false)
case _ => ???
}
case "<<'=" =>
val (l, r, size) = assertAssignmentLike(ctx, params)
val (l, r, size) = assertArithmeticAssignmentLike(ctx, params)
???
case ">>'=" =>
val (l, r, size) = assertAssignmentLike(ctx, params)
val (l, r, size) = assertArithmeticAssignmentLike(ctx, params)
???
case "*=" =>
assertAllBytes("Long multiplication not supported", ctx, params)
val (l, r, 1) = assertAssignmentLike(ctx, params)
assertAllArithmeticBytes("Long multiplication not supported", ctx, params)
val (l, r, 1) = assertArithmeticAssignmentLike(ctx, params)
Z80Multiply.compile8BitInPlaceMultiply(ctx, l, r)
case "*'=" =>
assertAllBytes("Long multiplication not supported", ctx, params)
val (l, r, 1) = assertAssignmentLike(ctx, params)
assertAllArithmeticBytes("Long multiplication not supported", ctx, params)
val (l, r, 1) = assertArithmeticAssignmentLike(ctx, params)
???
case "&=" =>
val (l, r, size) = assertAssignmentLike(ctx, params)
val (l, r, size) = assertArithmeticAssignmentLike(ctx, params)
size match {
case 1 => ZBuiltIns.perform8BitInPlace(ctx, l, r, AND)
case _ => ZBuiltIns.performLongInPlace(ctx, l, r, AND, AND, size)
}
case "^=" =>
val (l, r, size) = assertAssignmentLike(ctx, params)
val (l, r, size) = assertArithmeticAssignmentLike(ctx, params)
size match {
case 1 => ZBuiltIns.perform8BitInPlace(ctx, l, r, XOR)
case _ => ZBuiltIns.performLongInPlace(ctx, l, r, XOR, XOR, size)
}
case "|=" =>
val (l, r, size) = assertAssignmentLike(ctx, params)
val (l, r, size) = assertArithmeticAssignmentLike(ctx, params)
size match {
case 1 => ZBuiltIns.perform8BitInPlace(ctx, l, r, OR)
case _ => ZBuiltIns.performLongInPlace(ctx, l, r, OR, OR, size)
@ -584,8 +584,10 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] {
def calculateAddressToHL(ctx: CompilationContext, i: IndexedExpression): List[ZLine] = {
val env = ctx.env
env.getPointy(i.name) match {
case ConstantPointy(baseAddr, _) =>
val pointy = env.getPointy(i.name)
AbstractExpressionCompiler.checkIndexType(ctx, pointy, i.index)
pointy match {
case ConstantPointy(baseAddr, _, _, _, _) =>
env.evalVariableAndConstantSubParts(i.index) match {
case (None, offset) => List(ZLine.ldImm16(ZRegister.HL, (baseAddr + offset).quickSimplify))
case (Some(index), offset) =>
@ -593,7 +595,7 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] {
stashBCIfChanged(compileToHL(ctx, index)) ++
List(ZLine.registers(ADD_16, ZRegister.HL, ZRegister.BC))
}
case VariablePointy(varAddr) =>
case VariablePointy(varAddr, _, _) =>
compileToHL(ctx, i.index) ++
loadBCFromHL ++
List(

View File

@ -50,18 +50,20 @@ object Z80StatementCompiler extends AbstractStatementCompiler[ZLine] {
Z80ExpressionCompiler.compileToHL(ctx, e) ++ fixStackOnReturn(ctx) ++
List(ZLine.implied(DISCARD_A), ZLine.implied(DISCARD_BCDEIX), ZLine.implied(RET))
}
case t => t.size match {
case 0 =>
ErrorReporting.error("Cannot return anything from a void function", statement.position)
fixStackOnReturn(ctx) ++
List(ZLine.implied(DISCARD_F), ZLine.implied(DISCARD_A), ZLine.implied(DISCARD_HL), ZLine.implied(DISCARD_BCDEIX), ZLine.implied(RET))
case 1 =>
Z80ExpressionCompiler.compileToA(ctx, e) ++ fixStackOnReturn(ctx) ++
List(ZLine.implied(DISCARD_F), ZLine.implied(DISCARD_HL), ZLine.implied(DISCARD_BCDEIX), ZLine.implied(RET))
case 2 =>
Z80ExpressionCompiler.compileToHL(ctx, e) ++ fixStackOnReturn(ctx) ++
List(ZLine.implied(DISCARD_F), ZLine.implied(DISCARD_A), ZLine.implied(DISCARD_BCDEIX), ZLine.implied(RET))
}
case t =>
AbstractExpressionCompiler.checkAssignmentType(ctx, e, ctx.function.returnType)
t.size match {
case 0 =>
ErrorReporting.error("Cannot return anything from a void function", statement.position)
fixStackOnReturn(ctx) ++
List(ZLine.implied(DISCARD_F), ZLine.implied(DISCARD_A), ZLine.implied(DISCARD_HL), ZLine.implied(DISCARD_BCDEIX), ZLine.implied(RET))
case 1 =>
Z80ExpressionCompiler.compileToA(ctx, e) ++ fixStackOnReturn(ctx) ++
List(ZLine.implied(DISCARD_F), ZLine.implied(DISCARD_HL), ZLine.implied(DISCARD_BCDEIX), ZLine.implied(RET))
case 2 =>
Z80ExpressionCompiler.compileToHL(ctx, e) ++ fixStackOnReturn(ctx) ++
List(ZLine.implied(DISCARD_F), ZLine.implied(DISCARD_A), ZLine.implied(DISCARD_BCDEIX), ZLine.implied(RET))
}
}
case Assignment(destination, source) =>
val sourceType = AbstractExpressionCompiler.getExpressionType(ctx, source)

View File

@ -24,6 +24,7 @@ import millfork.node.Position
sealed trait Constant {
def isProvablyZero: Boolean = false
def isProvably(value: Int): Boolean = false
def isProvablyNonnegative: Boolean = false
def asl(i: Constant): Constant = i match {
case NumericConstant(sa, _) => asl(sa.toInt)
@ -81,6 +82,7 @@ sealed trait Constant {
case class AssertByte(c: Constant) extends Constant {
override def isProvablyZero: Boolean = c.isProvablyZero
override def isProvably(i: Int): Boolean = c.isProvably(i)
override def isProvablyNonnegative: Boolean = c.isProvablyNonnegative
override def requiredSize: Int = 1
@ -105,6 +107,7 @@ case class NumericConstant(value: Long, requiredSize: Int) extends Constant {
}
override def isProvablyZero: Boolean = value == 0
override def isProvably(i: Int): Boolean = value == i
override def isProvablyNonnegative: Boolean = value >= 0
override def isLowestByteAlwaysEqual(i: Int) : Boolean = (value & 0xff) == (i&0xff)
@ -146,6 +149,9 @@ case class NumericConstant(value: Long, requiredSize: Int) extends Constant {
}
case class MemoryAddressConstant(var thing: ThingInMemory) extends Constant {
override def isProvablyNonnegative: Boolean = true
override def requiredSize = 2
override def toString: String = thing.name
@ -168,6 +174,8 @@ case class SubbyteConstant(base: Constant, index: Int) extends Constant {
override def requiredSize = 1
override def isProvablyNonnegative: Boolean = true
override def toString: String = base + (index match {
case 0 => ".lo"
case 1 => ".hi"
@ -185,21 +193,35 @@ object MathOperator extends Enumeration {
}
case class CompoundConstant(operator: MathOperator.Value, lhs: Constant, rhs: Constant) extends Constant {
override def isProvablyNonnegative: Boolean = {
import MathOperator._
operator match {
case Plus | DecimalPlus |
Times | DecimalTimes |
Shl | DecimalShl |
Shl9 | DecimalShl9 |
Shr | DecimalShr |
And | Or | Exor => lhs.isProvablyNonnegative && rhs.isProvablyNonnegative
case _ => false
}
}
override def quickSimplify: Constant = {
val l = lhs.quickSimplify
val r = rhs.quickSimplify
(l, r) match {
case (CompoundConstant(MathOperator.Plus, a, ll@NumericConstant(lv, _)), rr@NumericConstant(rv,_)) if operator == MathOperator.Plus =>
case (CompoundConstant(MathOperator.Plus, a, ll@NumericConstant(lv, _)), rr@NumericConstant(rv, _)) if operator == MathOperator.Plus =>
CompoundConstant(MathOperator.Plus, a, ll + rr).quickSimplify
case (CompoundConstant(MathOperator.Minus, a, ll@NumericConstant(lv, _)), rr@NumericConstant(rv,_)) if operator == MathOperator.Minus =>
case (CompoundConstant(MathOperator.Minus, a, ll@NumericConstant(lv, _)), rr@NumericConstant(rv, _)) if operator == MathOperator.Minus =>
CompoundConstant(MathOperator.Minus, a, ll + rr).quickSimplify
case (CompoundConstant(MathOperator.Plus, a, ll@NumericConstant(lv, _)), rr@NumericConstant(rv,_)) if operator == MathOperator.Minus =>
case (CompoundConstant(MathOperator.Plus, a, ll@NumericConstant(lv, _)), rr@NumericConstant(rv, _)) if operator == MathOperator.Minus =>
if (lv >= rv) {
CompoundConstant(MathOperator.Plus, a, ll - rr).quickSimplify
} else {
CompoundConstant(MathOperator.Minus, a, rr - ll).quickSimplify
}
case (CompoundConstant(MathOperator.Minus, a, ll@NumericConstant(lv, _)), rr@NumericConstant(rv,_)) if operator == MathOperator.Plus =>
case (CompoundConstant(MathOperator.Minus, a, ll@NumericConstant(lv, _)), rr@NumericConstant(rv, _)) if operator == MathOperator.Plus =>
if (lv >= rv) {
CompoundConstant(MathOperator.Minus, a, ll - rr).quickSimplify
} else {
@ -229,7 +251,7 @@ case class CompoundConstant(operator: MathOperator.Value, lhs: Constant, rhs: Co
size = 2
case MathOperator.Times | MathOperator.Shl =>
val mask = (1 << (size * 8)) - 1
if (value != (value & mask)){
if (value != (value & mask)) {
size = ls + rs
}
case _ =>
@ -364,7 +386,16 @@ case class CompoundConstant(operator: MathOperator.Value, lhs: Constant, rhs: Co
}
}
override def requiredSize: Int = lhs.requiredSize max rhs.requiredSize
override def requiredSize: Int = {
import MathOperator._
operator match {
case Plus9 | DecimalPlus9 | Shl9 | DecimalShl9 => 2
case Times | Shl =>
// TODO
lhs.requiredSize max rhs.requiredSize
case _ => lhs.requiredSize max rhs.requiredSize
}
}
override def isRelatedTo(v: Thing): Boolean = lhs.isRelatedTo(v) || rhs.isRelatedTo(v)
}

View File

@ -73,7 +73,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
def getAllFixedAddressObjects: List[(String, Int, Int)] = {
things.values.flatMap {
case RelativeArray(_, NumericConstant(addr, _), size, declaredBank) =>
case RelativeArray(_, NumericConstant(addr, _), size, declaredBank, _, _) =>
List((declaredBank.getOrElse("default"), addr.toInt, size))
case RelativeVariable(_, NumericConstant(addr, _), typ, _, declaredBank) =>
List((declaredBank.getOrElse("default"), addr.toInt, typ.size))
@ -270,14 +270,24 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
InitializedMemoryVariable
UninitializedMemoryVariable
getArrayOrPointer(name) match {
case th@InitializedArray(_, _, cs, _) => ConstantPointy(th.toAddress, Some(cs.length))
case th@UninitializedArray(_, size, _) => ConstantPointy(th.toAddress, Some(size))
case th@RelativeArray(_, _, size, _) => ConstantPointy(th.toAddress, Some(size))
case ConstantThing(_, value, typ) if typ.size <= 2 => ConstantPointy(value, None)
case th:VariableInMemory => VariablePointy(th.toAddress)
case th@InitializedArray(_, _, cs, _, i, e) => ConstantPointy(th.toAddress, Some(name), Some(cs.length), i, e)
case th@UninitializedArray(_, size, _, i, e) => ConstantPointy(th.toAddress, Some(name), Some(size), i, e)
case th@RelativeArray(_, _, size, _, i, e) => ConstantPointy(th.toAddress, Some(name), Some(size), i, e)
case ConstantThing(_, value, typ) if typ.size <= 2 =>
val b = get[VariableType]("byte")
val w = get[VariableType]("word")
// TODO:
ConstantPointy(value, None, None, w, b)
case th:VariableInMemory =>
val b = get[VariableType]("byte")
val w = get[VariableType]("word")
// TODO:
VariablePointy(th.toAddress, w, b)
case _ =>
ErrorReporting.error(s"$name is not a valid pointer or array")
ConstantPointy(Constant.Zero, None)
val b = get[VariableType]("byte")
val w = get[VariableType]("word")
ConstantPointy(Constant.Zero, None, None, w, b)
}
}
@ -506,6 +516,9 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
case "||" | "|" =>
constantOperation(MathOperator.Or, params)
case _ =>
if (params.size == 1) {
return maybeGet[Type](name).flatMap(_ => eval(params.head))
}
None
}
}
@ -580,6 +593,26 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
addThing(Alias(stmt.name, stmt.target), stmt.position)
}
def registerEnum(stmt: EnumDefinitionStatement): Unit = {
val count = if (stmt.variants.nonEmpty && stmt.variants.forall(_._2.isEmpty)) {
val size = stmt.variants.size
addThing(ConstantThing(stmt.name + ".count", NumericConstant(size, 1), get[Type]("byte")), stmt.position)
Some(size)
} else None
val t = EnumType(stmt.name, count)
addThing(t, stmt.position)
var value = Constant.Zero
for((name, optValue) <- stmt.variants) {
optValue match {
case Some(v) =>
value = eval(v).getOrElse(Constant.error(s"Enum constant `${stmt.name}.$name` is not a constant", stmt.position))
case _ =>
}
addThing(ConstantThing(name, value, t), stmt.position)
value += 1
}
}
def registerFunction(stmt: FunctionDeclarationStatement, options: CompilationOptions): Unit = {
val w = get[Type]("word")
val name = stmt.name
@ -799,23 +832,37 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
}
def registerArray(stmt: ArrayDeclarationStatement, options: CompilationOptions): Unit = {
val b = get[Type]("byte")
val b = get[VariableType]("byte")
val w = get[VariableType]("word")
val p = get[Type]("pointer")
stmt.elements match {
case None =>
stmt.length match {
case None => ErrorReporting.error(s"Array `${stmt.name}` without size nor contents", stmt.position)
case Some(l) =>
// array arr[...]
val address = stmt.address.map(a => eval(a).getOrElse(ErrorReporting.fatal(s"Array `${stmt.name}` has non-constant address", stmt.position)))
val lengthConst = eval(l).getOrElse(Constant.error(s"Array `${stmt.name}` has non-constant length", stmt.position))
val (indexType, lengthConst) = l match {
case VariableExpression(name) =>
maybeGet[Type](name) match {
case Some(typ@EnumType(_, Some(count))) =>
typ -> NumericConstant(count, 1)
case Some(typ) =>
ErrorReporting.error(s"Type $name cannot be used as an array index", l.position)
w -> Constant.Zero
case _ => w -> eval(l).getOrElse(Constant.error(s"Array `${stmt.name}` has non-constant length", stmt.position))
}
case _ =>
w -> eval(l).getOrElse(Constant.error(s"Array `${stmt.name}` has non-constant length", stmt.position))
}
lengthConst match {
case NumericConstant(length, _) =>
if (length > 0xffff || length < 0) ErrorReporting.error(s"Array `${stmt.name}` has invalid length", stmt.position)
val array = address match {
case None => UninitializedArray(stmt.name + ".array", length.toInt,
declaredBank = stmt.bank)
declaredBank = stmt.bank, indexType, b)
case Some(aa) => RelativeArray(stmt.name + ".array", aa, length.toInt,
declaredBank = stmt.bank)
declaredBank = stmt.bank, indexType, b)
}
addThing(array, stmt.position)
registerAddressConstant(UninitializedMemoryVariable(stmt.name, p, VariableAllocationMethod.None, stmt.bank), stmt.position, options)
@ -850,20 +897,40 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
}
case Some(contents1) =>
val contents = extractArrayContents(contents1)
stmt.length match {
case None =>
case Some(l) =>
val lengthConst = eval(l).getOrElse(Constant.error(s"Array `${stmt.name}` has non-constant length", stmt.position))
val indexType = stmt.length match {
case None => // array arr = [...]
w
case Some(l) => // array arr[...] = [...]
val (indexTyp, lengthConst) = l match {
case VariableExpression(name) =>
maybeGet[Type](name) match {
case Some(typ@EnumType(_, Some(count))) =>
if (count != contents.size)
ErrorReporting.error(s"Array `${stmt.name}` has actual length different than the number of variants in the enum `${typ.name}`", stmt.position)
typ -> NumericConstant(count, 1)
case Some(typ@EnumType(_, None)) =>
// using a non-enumerable enum for an array index is ok if the array is preïnitialized
typ -> NumericConstant(contents.length, 1)
case Some(_) =>
ErrorReporting.error(s"Type $name cannot be used as an array index", l.position)
w -> Constant.Zero
case _ =>
w -> eval(l).getOrElse(Constant.error(s"Array `${stmt.name}` has non-constant length", stmt.position))
}
case _ =>
w -> eval(l).getOrElse(Constant.error(s"Array `${stmt.name}` has non-constant length", stmt.position))
}
lengthConst match {
case NumericConstant(ll, _) =>
if (ll != contents.length) ErrorReporting.error(s"Array `${stmt.name}` has different declared and actual length", stmt.position)
case _ => ErrorReporting.error(s"Array `${stmt.name}` has weird length", stmt.position)
}
indexTyp
}
val length = contents.length
if (length > 0xffff || length < 0) ErrorReporting.error(s"Array `${stmt.name}` has invalid length", stmt.position)
val address = stmt.address.map(a => eval(a).getOrElse(Constant.error(s"Array `${stmt.name}` has non-constant address", stmt.position)))
val array = InitializedArray(stmt.name + ".array", address, contents, declaredBank = stmt.bank)
val array = InitializedArray(stmt.name + ".array", address, contents, declaredBank = stmt.bank, indexType, b)
addThing(array, stmt.position)
registerAddressConstant(UninitializedMemoryVariable(stmt.name, p, VariableAllocationMethod.None,
declaredBank = stmt.bank), stmt.position, options)
@ -905,7 +972,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
}
val b = get[Type]("byte")
val w = get[Type]("word")
val typ = get[PlainType](stmt.typ)
val typ = get[VariableType](stmt.typ)
if (stmt.typ == "pointer" || stmt.typ == "farpointer") {
// if (stmt.constant) {
// ErrorReporting.error(s"Pointer `${stmt.name}` cannot be constant")
@ -1067,15 +1134,23 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
}
def collectDeclarations(program: Program, options: CompilationOptions): Unit = {
val b = get[VariableType]("byte")
if (options.flag(CompilationFlag.OptimizeForSonicSpeed)) {
addThing(InitializedArray("identity$", None, List.tabulate(256)(n => LiteralExpression(n, 1)), declaredBank = None), None)
addThing(InitializedArray("identity$", None, List.tabulate(256)(n => LiteralExpression(n, 1)), declaredBank = None, b, b), None)
}
program.declarations.foreach {
case a: AliasDefinitionStatement => registerAlias(a)
case _ =>
}
program.declarations.foreach {
case e: EnumDefinitionStatement => registerEnum(e)
case _ =>
}
program.declarations.foreach {
case f: FunctionDeclarationStatement => registerFunction(f, options)
case v: VariableDeclarationStatement => registerVariable(v, options)
case a: ArrayDeclarationStatement => registerArray(a, options)
case a: AliasDefinitionStatement => registerAlias(a)
case i: ImportStatement => ()
case _ =>
}
if (options.zpRegisterSize > 0 && !things.contains("__reg")) {
addThing(BasicPlainType("__reg$type", options.zpRegisterSize), None)
@ -1093,7 +1168,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
}
if (CpuFamily.forType(options.platform.cpu) == CpuFamily.M6502) {
if (!things.contains("__constant8")) {
things("__constant8") = InitializedArray("__constant8", None, List(LiteralExpression(8, 1)), declaredBank = None)
things("__constant8") = InitializedArray("__constant8", None, List(LiteralExpression(8, 1)), declaredBank = None, b, b)
}
}
}

View File

@ -1,6 +1,13 @@
package millfork.env
trait Pointy
trait Pointy {
def name: Option[String]
def indexType: VariableType
def elementType: VariableType
}
case class VariablePointy(addr: Constant) extends Pointy
case class ConstantPointy(value: Constant, size: Option[Int]) extends Pointy
case class VariablePointy(addr: Constant, indexType: VariableType, elementType: VariableType) extends Pointy {
override def name: Option[String] = None
}
case class ConstantPointy(value: Constant, name: Option[String], size: Option[Int], indexType: VariableType, elementType: VariableType) extends Pointy

View File

@ -30,8 +30,12 @@ sealed trait Type extends CallableThing {
override def toString(): String = name
def isAssignableTo(targetType: Type): Boolean = isCompatible(targetType)
def isArithmetic = false
}
sealed trait VariableType extends Type
case object VoidType extends Type {
def size = 0
@ -40,13 +44,16 @@ case object VoidType extends Type {
override def name = "void"
}
sealed trait PlainType extends Type {
sealed trait PlainType extends VariableType {
override def isCompatible(other: Type): Boolean = this == other || this.isSubtypeOf(other) || other.isSubtypeOf(this)
override def isAssignableTo(targetType: Type): Boolean = isCompatible(targetType) || (targetType match {
case BasicPlainType(_, size) => size > this.size // TODO
case DerivedPlainType(_, parent, size) => isAssignableTo(parent)
case _ => false
})
override def isArithmetic = true
}
case class BasicPlainType(name: String, size: Int) extends PlainType {
@ -61,6 +68,12 @@ case class DerivedPlainType(name: String, parent: PlainType, isSigned: Boolean)
override def isSubtypeOf(other: Type): Boolean = parent == other || parent.isSubtypeOf(other)
}
case class EnumType(name: String, count: Option[Int]) extends VariableType {
override def size: Int = 1
override def isSigned: Boolean = false
}
sealed trait BooleanType extends Type {
def size = 0
@ -175,9 +188,12 @@ case class InitializedMemoryVariable(name: String, address: Option[Constant], ty
override def alloc: VariableAllocationMethod.Value = VariableAllocationMethod.Static
}
trait MfArray extends ThingInMemory with IndexableThing
trait MfArray extends ThingInMemory with IndexableThing {
def indexType: VariableType
def elementType: VariableType
}
case class UninitializedArray(name: String, sizeInBytes: Int, declaredBank: Option[String]) extends MfArray with UninitializedMemory {
case class UninitializedArray(name: String, sizeInBytes: Int, declaredBank: Option[String], indexType: VariableType, elementType: VariableType) extends MfArray with UninitializedMemory {
override def toAddress: MemoryAddressConstant = MemoryAddressConstant(this)
override def alloc = VariableAllocationMethod.Static
@ -189,7 +205,7 @@ case class UninitializedArray(name: String, sizeInBytes: Int, declaredBank: Opti
override def zeropage: Boolean = false
}
case class RelativeArray(name: String, address: Constant, sizeInBytes: Int, declaredBank: Option[String]) extends MfArray {
case class RelativeArray(name: String, address: Constant, sizeInBytes: Int, declaredBank: Option[String], indexType: VariableType, elementType: VariableType) extends MfArray {
override def toAddress: Constant = address
override def isFar(compilationOptions: CompilationOptions): Boolean = farFlag.getOrElse(false)
@ -199,7 +215,7 @@ case class RelativeArray(name: String, address: Constant, sizeInBytes: Int, decl
override def zeropage: Boolean = false
}
case class InitializedArray(name: String, address: Option[Constant], contents: List[Expression], declaredBank: Option[String]) extends MfArray with PreallocableThing {
case class InitializedArray(name: String, address: Option[Constant], contents: List[Expression], declaredBank: Option[String], indexType: VariableType, elementType: VariableType) extends MfArray with PreallocableThing {
override def shouldGenerate = true
override def isFar(compilationOptions: CompilationOptions): Boolean = farFlag.getOrElse(false)

View File

@ -9,6 +9,21 @@ object ErrorReporting {
var hasErrors = false
private var sourceLines: Option[IndexedSeq[String]] = None
private def printErrorContext(pos: Option[Position]): Unit = {
if (sourceLines.isDefined && pos.isDefined) {
val line = sourceLines.get.apply(pos.get.line - 1)
val column = pos.get.column - 1
val margin = " "
print(margin)
println(line)
print(margin)
print(" " * column)
println("^")
}
}
def f(position: Option[Position]): String = position.fold("")(p => s"(${p.line}:${p.column}) ")
def info(msg: String, position: Option[Position] = None): Unit = {
@ -40,6 +55,7 @@ object ErrorReporting {
def warn(msg: String, options: CompilationOptions, position: Option[Position] = None): Unit = {
if (verbosity < 0) return
println("WARN: " + f(position) + msg)
printErrorContext(position)
flushOutput()
if (options.flag(CompilationFlag.FatalWarnings)) {
hasErrors = true
@ -49,12 +65,14 @@ object ErrorReporting {
def error(msg: String, position: Option[Position] = None): Unit = {
hasErrors = true
println("ERROR: " + f(position) + msg)
printErrorContext(position)
flushOutput()
}
def fatal(msg: String, position: Option[Position] = None): Nothing = {
hasErrors = true
println("FATAL: " + f(position) + msg)
printErrorContext(position)
flushOutput()
System.exit(1)
throw new RuntimeException(msg)
@ -63,6 +81,7 @@ object ErrorReporting {
def fatalQuit(msg: String, position: Option[Position] = None): Nothing = {
hasErrors = true
println("FATAL: " + f(position) + msg)
printErrorContext(position)
flushOutput()
System.exit(1)
throw new RuntimeException(msg)
@ -75,4 +94,13 @@ object ErrorReporting {
}
}
def clearErrors(): Unit = {
hasErrors = false
sourceLines = None
}
def setSource(source: Option[IndexedSeq[String]]): Unit = {
sourceLines = source
}
}

View File

@ -16,6 +16,12 @@ object Node {
node.position = Some(position)
node
}
def pos(position: Option[Position]): N = {
if (position.isDefined) {
node.position = position
}
node
}
}
}
@ -229,6 +235,10 @@ case class AliasDefinitionStatement(name: String, target: String) extends Declar
override def getAllExpressions: List[Expression] = Nil
}
case class EnumDefinitionStatement(name: String, variants: List[(String, Option[Expression])]) extends DeclarationStatement {
override def getAllExpressions: List[Expression] = variants.flatMap(_._2)
}
case class ArrayDeclarationStatement(name: String,
bank: Option[String],
length: Option[Expression],

View File

@ -247,7 +247,7 @@ abstract class AbstractAssembler[T <: AbstractCode](private val program: Program
}
env.allPreallocatables.foreach {
case thing@InitializedArray(name, Some(NumericConstant(address, _)), items, _) =>
case thing@InitializedArray(name, Some(NumericConstant(address, _)), items, _, _, _) =>
val bank = thing.bank(options)
val bank0 = mem.banks(bank)
var index = address.toInt
@ -271,7 +271,7 @@ abstract class AbstractAssembler[T <: AbstractCode](private val program: Program
}).mkString(", "))
}
initializedVariablesSize += items.length
case thing@InitializedArray(name, Some(_), items, _) => ???
case thing@InitializedArray(name, Some(_), items, _, _, _) => ???
case f: NormalFunction if f.address.isDefined =>
val bank = f.bank(options)
val bank0 = mem.banks(bank)
@ -354,7 +354,7 @@ abstract class AbstractAssembler[T <: AbstractCode](private val program: Program
justAfterCode += "default" -> (index + 1)
}
env.allPreallocatables.foreach {
case thing@InitializedArray(name, None, items, _) =>
case thing@InitializedArray(name, None, items, _, _, _) =>
val bank = thing.bank(options)
val bank0 = mem.banks(bank)
var index = codeAllocators(bank).allocateBytes(bank0, options, items.size, initialized = true, writeable = true, location = AllocationLocation.High)

View File

@ -354,6 +354,7 @@ abstract class MfParser[T](filename: String, input: String, currentDirectory: St
bank <- bankDeclaration
flags <- functionFlags ~ HWS
returnType <- identifier ~ SWS
if !InvalidReturnTypes(returnType)
name <- identifier ~ HWS
params <- "(" ~/ AWS ~/ (if (flags("asm")) asmParamDefinition else paramDefinition).rep(sep = AWS ~ "," ~/ AWS) ~ AWS ~ ")" ~/ AWS
addr <- ("@" ~/ HWS ~/ mfExpression(1)).?.opaque("<address>") ~/ AWS
@ -385,9 +386,25 @@ abstract class MfParser[T](filename: String, input: String, currentDirectory: St
def validateAsmFunctionBody(p: Position, flags: Set[String], name: String, statements: Option[List[Statement]])
val enumVariant: P[(String, Option[Expression])] = for {
name <- identifier ~/ HWS
value <- ("=" ~/ HWS ~/ mfExpression(1)).? ~ HWS
} yield name -> value
val enumVariants: P[List[(String, Option[Expression])]] =
("{" ~/ AWS ~ enumVariant.rep(sep = NoCut(EOLOrComma) ~ !"}" ~/ Pass) ~/ AWS ~/ "}" ~/ Pass).map(_.toList)
val enumDefinition: P[Seq[EnumDefinitionStatement]] = for {
p <- position()
_ <- "enum" ~ !letterOrDigit ~/ SWS ~ position("enum name")
name <- identifier ~/ HWS
_ <- position("enum defintion block")
variants <- enumVariants ~/ Pass
} yield Seq(EnumDefinitionStatement(name, variants).pos(p))
val program: Parser[Program] = for {
_ <- Start ~/ AWS ~/ Pass
definitions <- (importStatement | arrayDefinition | aliasDefinition | functionDefinition | globalVariableDefinition).rep(sep = EOL)
definitions <- (importStatement | arrayDefinition | aliasDefinition | enumDefinition | functionDefinition | globalVariableDefinition).rep(sep = EOL)
_ <- AWS ~ End
} yield Program(definitions.flatten.toList)
@ -405,6 +422,8 @@ object MfParser {
val EOL: P[Unit] = P(HWS ~ ("\r\n" | "\r" | "\n" | comment).opaque("<first line break>") ~ AWS).opaque("<line break>")
val EOLOrComma: P[Unit] = P(HWS ~ ("\r\n" | "\r" | "\n" | "," | comment).opaque("<first line break or comma>") ~ AWS).opaque("<line break or comma>")
val letter: P[String] = P(CharIn("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_").!)
val letterOrDigit: P[Unit] = P(CharIn("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_.$1234567890"))
@ -531,4 +550,6 @@ object MfParser {
val variableFlags: P[Set[String]] = flags_("const", "static", "volatile", "stack", "register")
val functionFlags: P[Set[String]] = flags_("asm", "inline", "interrupt", "macro", "noinline", "reentrant", "kernal_interrupt")
val InvalidReturnTypes = Set("enum", "alias", "array", "const", "stack", "register", "static", "volatile", "import")
}

View File

@ -0,0 +1,155 @@
package millfork.test
import millfork.Cpu
import millfork.test.emu.{EmuUnoptimizedCrossPlatformRun, ShouldNotCompile}
import org.scalatest.{FunSuite, Matchers}
/**
* @author Karol Stasiak
*/
class EnumSuite extends FunSuite with Matchers {
test("Enum basic test") {
EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80)(
"""
| enum ugly {
| a
| b,c,
| d
| }
| void main () {
| byte i
| ugly e
| e = a
| if byte(e) != 0 { crash() }
| i = 1
| if ugly(i) != b { crash() }
| }
| asm void crash() {
| #if ARCH_6502
| sta $bfff
| rts
| #elseif ARCH_Z80
| ld ($bfff),a
| ret
| #else
| #error
| #endif
| }
""".stripMargin){_=>}
}
test("Enum arrays") {
EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80)(
"""
| enum ugly {
| a
| b,c,
| d
| }
| array a1 [ugly]
| array a2 [ugly] = [6,7,8,9]
| void main () {
| if a2[a] != 6 { crash() }
| }
| asm void crash() {
| #if ARCH_6502
| sta $bfff
| rts
| #elseif ARCH_Z80
| ld ($bfff),a
| ret
| #else
| #error
| #endif
| }
""".stripMargin){_=>}
}
test("Enum-byte incompatibility test") {
// ShouldNotCompile(
// """
// | enum ugly { a }
// | void main() {
// | byte b
// | ugly u
// | b = u
// | }
// """.stripMargin)
//
// ShouldNotCompile(
// """
// | enum ugly { a }
// | void main() {
// | byte b
// | ugly u
// | u = b
// | }
// """.stripMargin)
//
// ShouldNotCompile(
// """
// | enum ugly { a }
// | byte main() {
// | byte b
// | ugly u
// | return u
// | }
// """.stripMargin)
//
// ShouldNotCompile(
// """
// | enum ugly { a }
// | ugly main() {
// | byte b
// | ugly u
// | return b
// | }
// """.stripMargin)
//
// ShouldNotCompile(
// """
// | enum ugly { a }
// | byte main() {
// | byte b
// | ugly u
// | return b + u
// | }
// """.stripMargin)
//
// ShouldNotCompile(
// """
// | enum ugly { a }
// | void main() {
// | byte b
// | ugly u
// | if b > u {}
// | }
// """.stripMargin)
//
// ShouldNotCompile(
// """
// | enum ugly { a }
// | array arr[ugly] = []
// | void main() {
// | }
// """.stripMargin)
// ShouldNotCompile(
// """
// | enum ugly { a }
// | array arr[ugly] = [1,2,3]
// | void main() {
// | }
// """.stripMargin)
ShouldNotCompile(
"""
| enum ugly { a }
| array arr[ugly]
| ugly main() {
| return a[0]
| }
""".stripMargin)
}
}

View File

@ -39,8 +39,8 @@ class FarwordTest extends FunSuite with Matchers {
| w = $7788
| b = $55
| output3 = $23344
| output2 = $11223344
| output1 = $11223344
| output2 = $23344
| output1 = $3344
| output2 = w
| output1 = b
| }

View File

@ -2,7 +2,7 @@ package millfork.test
import millfork.Cpu
import millfork.error.ErrorReporting
import millfork.test.emu.{EmuBenchmarkRun, EmuCrossPlatformBenchmarkRun, EmuUnoptimizedCrossPlatformRun, EmuUnoptimizedRun}
import millfork.test.emu._
import org.scalatest.{FunSuite, Matchers}
/**
@ -209,5 +209,58 @@ class ForLoopSuite extends FunSuite with Matchers {
}
}
test("Edge cases - positive") {
EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80)("""
| void main() {
| byte i
| for i,0,until,256 { f() }
| for i,0,paralleluntil,256 { f() }
| for i,0,until,255 { f() }
| for i,0,paralleluntil,255 { f() }
| for i,0,to,255 { f() }
| for i,0,parallelto,255 { f() }
| for i,255,downto,0 { f() }
| }
| void f() { }
""".stripMargin){ m => }
}
test("Edge cases - negative") {
ShouldNotCompile("""
| void main() {
| byte i
| for i,0,until,257 { f() }
| }
| void f() { }
""".stripMargin)
ShouldNotCompile("""
| void main() {
| byte i
| for i,0,paralleluntil,257 { f() }
| }
| void f() { }
""".stripMargin)
ShouldNotCompile("""
| void main() {
| byte i
| for i,0,to,256 { f() }
| }
| void f() { }
""".stripMargin)
ShouldNotCompile("""
| void main() {
| byte i
| for i,0,parallelto,256 { f() }
| }
| void f() { }
""".stripMargin)
ShouldNotCompile("""
| void main() {
| byte i
| for i,256,downto,0 { f() }
| }
| void f() { }
""".stripMargin)
}
}

View File

@ -122,7 +122,7 @@ class IllegalSuite extends FunSuite with Matchers {
| byte main () {
| output = five()
| output <<= 1
| return output
| return output.b0
| }
| byte five () {
| return 5

View File

@ -41,7 +41,7 @@ class LongTest extends FunSuite with Matchers {
| b = $55
| output4 = $11223344
| output2 = $11223344
| output1 = $11223344
| output1 = $3344
| output2 = w
| output1 = b
| }

View File

@ -119,6 +119,7 @@ class EmuRun(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization],
if (!source.contains("_panic")) effectiveSource += "\n void _panic(){while(true){}}"
if (source.contains("import zp_reg"))
effectiveSource += Files.readAllLines(Paths.get("include/zp_reg.mfk"), StandardCharsets.US_ASCII).asScala.mkString("\n", "\n", "")
ErrorReporting.setSource(Some(effectiveSource.lines.toIndexedSeq))
val (preprocessedSource, features) = Preprocessor.preprocessForTest(options, effectiveSource)
val parserF = MosParser("", preprocessedSource, "", options, features)
parserF.toAst match {
@ -167,7 +168,9 @@ class EmuRun(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization],
println(f"Gain: ${(100L * (unoptimizedSize - optimizedSize) / unoptimizedSize.toDouble).round}%5d%%")
}
ErrorReporting.assertNoErrors("Code generation failed")
if (ErrorReporting.hasErrors) {
fail("Code generation failed")
}
val memoryBank = assembler.mem.banks("default")
if (source.contains("return [")) {
@ -175,7 +178,7 @@ class EmuRun(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization],
if (memoryBank.readable(i)) memoryBank.readable(i + 1) = true
}
}
platform.cpu match {
val timings = platform.cpu match {
case millfork.Cpu.Cmos =>
runViaSymon(memoryBank, platform.codeAllocators("default").startAt, CpuBehavior.CMOS_6502)
case millfork.Cpu.Ricoh =>
@ -189,12 +192,14 @@ class EmuRun(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization],
ErrorReporting.trace("No emulation support for " + platform.cpu)
Timings(-1, -1) -> memoryBank
}
ErrorReporting.clearErrors()
timings
case f: Failure[_, _] =>
println(f)
println(f.extra.toString)
println(f.lastParser.toString)
ErrorReporting.error("Syntax error", Some(parserF.lastPosition))
???
fail("syntax error")
}
}

View File

@ -36,6 +36,7 @@ class EmuZ80Run(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimizatio
ErrorReporting.verbosity = 999
var effectiveSource = source
if (!source.contains("_panic")) effectiveSource += "\n void _panic(){while(true){}}"
ErrorReporting.setSource(Some(effectiveSource.lines.toIndexedSeq))
val (preprocessedSource, features) = Preprocessor.preprocessForTest(options, effectiveSource)
val parserF = Z80Parser("", preprocessedSource, "", options, features)
parserF.toAst match {
@ -84,7 +85,9 @@ class EmuZ80Run(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimizatio
println(f"Gain: ${(100L * (unoptimizedSize - optimizedSize) / unoptimizedSize.toDouble).round}%5d%%")
}
ErrorReporting.assertNoErrors("Code generation failed")
if (ErrorReporting.hasErrors) {
fail("Code generation failed")
}
val memoryBank = assembler.mem.banks("default")
(0x1f0 until 0x200).foreach(i => memoryBank.readable(i) = true)
@ -102,7 +105,7 @@ class EmuZ80Run(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimizatio
(0x200 until 0x2000).takeWhile(memoryBank.occupied(_)).map(memoryBank.output).grouped(16).map(_.map(i => f"$i%02x").mkString(" ")).foreach(ErrorReporting.debug(_))
platform.cpu match {
val timings = platform.cpu match {
case millfork.Cpu.Z80 =>
val cpu = new Z80Core(Z80Memory(memoryBank), DummyIO)
cpu.reset()
@ -118,6 +121,8 @@ class EmuZ80Run(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimizatio
case _ =>
Timings(-1, -1) -> memoryBank
}
ErrorReporting.clearErrors()
timings
case f: Failure[_, _] =>
println(f)
println(f.extra.toString)

View File

@ -0,0 +1,80 @@
package millfork.test.emu
import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Paths}
import fastparse.core.Parsed.{Failure, Success}
import millfork.compiler.CompilationContext
import millfork.compiler.mos.MosCompiler
import millfork.env.{Environment, InitializedArray, InitializedMemoryVariable, NormalFunction}
import millfork.error.ErrorReporting
import millfork.node.StandardCallGraph
import millfork.parser.{MosParser, Preprocessor}
import millfork.{CompilationFlag, CompilationOptions, Cpu, CpuFamily}
import org.scalatest.Matchers
import scala.collection.JavaConverters._
object ShouldNotCompile extends Matchers {
def apply(source: String): Unit = {
checkCase(Cpu.Mos, source)
checkCase(Cpu.Z80, source)
}
private def checkCase(cpu: Cpu.Value, source: String) {
Console.out.flush()
Console.err.flush()
println(source)
val platform = EmuPlatform.get(cpu)
val options = CompilationOptions(platform, Map(CompilationFlag.LenientTextEncoding -> true), None, platform.zpRegisterSize)
ErrorReporting.hasErrors = false
ErrorReporting.verbosity = 999
var effectiveSource = source
if (!source.contains("_panic")) effectiveSource += "\n void _panic(){while(true){}}"
if (source.contains("import zp_reg"))
effectiveSource += Files.readAllLines(Paths.get("include/zp_reg.mfk"), StandardCharsets.US_ASCII).asScala.mkString("\n", "\n", "")
ErrorReporting.setSource(Some(effectiveSource.lines.toIndexedSeq))
val (preprocessedSource, features) = Preprocessor.preprocessForTest(options, effectiveSource)
val parserF = MosParser("", preprocessedSource, "", options, features)
parserF.toAst match {
case Success(program, _) =>
ErrorReporting.assertNoErrors("Parse failed")
// prepare
val callGraph = new StandardCallGraph(program)
val cpuFamily = CpuFamily.forType(cpu)
val env = new Environment(None, "", cpuFamily)
env.collectDeclarations(program, options)
var unoptimizedSize = 0L
// print unoptimized asm
env.allPreallocatables.foreach {
case f: NormalFunction =>
val unoptimized = MosCompiler.compile(CompilationContext(f.environment, f, 0, options, Set()))
unoptimizedSize += unoptimized.map(_.sizeInBytes).sum
case d: InitializedArray =>
unoptimizedSize += d.contents.length
case d: InitializedMemoryVariable =>
unoptimizedSize += d.typ.size
}
if (!ErrorReporting.hasErrors) {
val familyName = cpuFamily match {
case CpuFamily.M6502 => "6502"
case CpuFamily.I80 => "Z80"
case _ => "unknown CPU"
}
fail("Failed: Compilation succeeded for " + familyName)
}
ErrorReporting.clearErrors()
case f: Failure[_, _] =>
println(f.extra.toString)
println(f.lastParser.toString)
ErrorReporting.error("Syntax error: " + parserF.lastLabel, Some(parserF.lastPosition))
fail("syntax error")
}
}
}