From 3aac33b54f8f7c3e49ba9b252597cd3b5ffb785d Mon Sep 17 00:00:00 2001 From: Karol Stasiak Date: Fri, 26 Jul 2019 19:02:32 +0200 Subject: [PATCH] Add the bool type. Few boolean-related bugfixes. --- CHANGELOG.md | 2 + docs/abi/generated-labels.md | 2 + docs/lang/types.md | 2 + .../compiler/AbstractExpressionCompiler.scala | 16 ++- .../compiler/AbstractStatementCompiler.scala | 16 ++- .../AbstractStatementPreprocessor.scala | 3 +- .../compiler/mos/MosExpressionCompiler.scala | 100 +++++++++++++++++- .../compiler/mos/MosStatementCompiler.scala | 10 +- .../compiler/z80/Z80ExpressionCompiler.scala | 80 ++++++++++++++ .../compiler/z80/Z80StatementCompiler.scala | 29 +++-- src/main/scala/millfork/env/Constant.scala | 2 +- src/main/scala/millfork/env/Environment.scala | 1 + src/main/scala/millfork/env/Thing.scala | 16 +++ .../scala/millfork/test/BooleanSuite.scala | 53 +++++++++- .../millfork/test/emu/SymonTestRam.scala | 2 +- 15 files changed, 313 insertions(+), 21 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4306edcd..9291e1bf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ * Added goto. +* Added `bool` type. + * Added arrays of elements of size greater than byte. * Improved passing of register parameters to assembly functions. diff --git a/docs/abi/generated-labels.md b/docs/abi/generated-labels.md index 54a3b05d..b2726e05 100644 --- a/docs/abi/generated-labels.md +++ b/docs/abi/generated-labels.md @@ -14,6 +14,8 @@ where `11111` is a sequential number and `xx` is the type: * `bc` – array bounds checking (`-fbounds-checking`) +* `bo` – boolean type conversions + * `c8` – constant `#8` for `BIT` when immediate addressing is not available * `co` – greater-than comparison diff --git a/docs/lang/types.md b/docs/lang/types.md index 857b4049..2908ed85 100644 --- a/docs/lang/types.md +++ b/docs/lang/types.md @@ -79,6 +79,8 @@ Its actual value is defined using the feature `NULLPTR`, by default it's 0. TODO +* `bool` – a 1-byte boolean value + ## Special types * `void` – a unit type containing no information, can be only used as a return type for a function. diff --git a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala index 3882f3c7..312480ea 100644 --- a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala @@ -139,8 +139,12 @@ class AbstractExpressionCompiler[T <: AbstractCode] { return } params.foreach { param => - if (!getExpressionType(ctx, param).isInstanceOf[BooleanType]) - ctx.log.fatal("Parameter should be boolean", param.position) + getExpressionType(ctx, param) match { + case _: BooleanType => + case FatBooleanType => + case _=> + ctx.log.fatal("Parameter should be boolean", param.position) + } } } @@ -150,8 +154,12 @@ class AbstractExpressionCompiler[T <: AbstractCode] { return } params.foreach { param => - if (!getExpressionType(ctx, param).isInstanceOf[BooleanType]) - ctx.log.fatal("Parameter should be boolean", param.position) + getExpressionType(ctx, param) match { + case _: BooleanType => + case FatBooleanType => + case _=> + ctx.log.fatal("Parameter should be boolean", param.position) + } } } diff --git a/src/main/scala/millfork/compiler/AbstractStatementCompiler.scala b/src/main/scala/millfork/compiler/AbstractStatementCompiler.scala index 8cef90f7..fa37c60c 100644 --- a/src/main/scala/millfork/compiler/AbstractStatementCompiler.scala +++ b/src/main/scala/millfork/compiler/AbstractStatementCompiler.scala @@ -48,7 +48,8 @@ abstract class AbstractStatementCompiler[T <: AbstractCode] { case ConstantBooleanType(_, true) => List(labelChunk(start), bodyBlock, labelChunk(inc), incrementBlock, jmpChunk(start), labelChunk(end)).flatten case ConstantBooleanType(_, false) => Nil - case FlagBooleanType(_, jumpIfTrue, jumpIfFalse) => + case _:FlagBooleanType | FatBooleanType => + val (jumpIfTrue, jumpIfFalse) = getJumpIfTrueAndFalse(ctx, condType) if (largeBodyBlock) { val conditionBlock = compileExpressionForBranching(ctx, s.condition, NoBranching) List(labelChunk(start), conditionBlock, branchChunk(jumpIfTrue, middle), jmpChunk(end), labelChunk(middle), bodyBlock, labelChunk(inc), incrementBlock, jmpChunk(start), labelChunk(end)).flatten @@ -86,7 +87,8 @@ abstract class AbstractStatementCompiler[T <: AbstractCode] { List(labelChunk(start), bodyBlock, labelChunk(inc), incrementBlock, jmpChunk(start), labelChunk(end)).flatten case ConstantBooleanType(_, false) => List(bodyBlock, labelChunk(inc), incrementBlock, labelChunk(end)).flatten - case FlagBooleanType(_, jumpIfTrue, jumpIfFalse) => + case _:FlagBooleanType | FatBooleanType => + val (jumpIfTrue, jumpIfFalse) = getJumpIfTrueAndFalse(ctx, condType) val conditionBlock = compileExpressionForBranching(ctx, s.condition, NoBranching) if (largeBodyBlock) { List(labelChunk(start), bodyBlock, labelChunk(inc), incrementBlock, conditionBlock, branchChunk(jumpIfFalse, end), jmpChunk(start), labelChunk(end)).flatten @@ -417,6 +419,13 @@ abstract class AbstractStatementCompiler[T <: AbstractCode] { } } + private def getJumpIfTrueAndFalse(ctx: CompilationContext, condType: Type): (BranchingOpcodeMapping, BranchingOpcodeMapping) = condType match { + case FlagBooleanType(_, jumpIfTrue, jmpIfFalse) => jumpIfTrue -> jmpIfFalse + case FatBooleanType => + val cz = ctx.env.get[FlagBooleanType]("clear_zero") + cz.jumpIfTrue -> cz.jumpIfFalse + } + def compileIfStatement(ctx: CompilationContext, s: IfStatement): (List[T], List[T]) = { val condType = AbstractExpressionCompiler.getExpressionType(ctx, s.condition) val (thenBlock, extra1) = compile(ctx, s.thenBranch) @@ -428,7 +437,8 @@ abstract class AbstractStatementCompiler[T <: AbstractCode] { compileExpressionForBranching(ctx, s.condition, NoBranching) ++ thenBlock case ConstantBooleanType(_, false) => compileExpressionForBranching(ctx, s.condition, NoBranching) ++ elseBlock - case FlagBooleanType(_, jumpIfTrue, jumpIfFalse) => + case _:FlagBooleanType | FatBooleanType => + val (jumpIfTrue, jumpIfFalse) = getJumpIfTrueAndFalse(ctx, condType) (s.thenBranch, s.elseBranch) match { case (Nil, Nil) => compileExpressionForBranching(ctx, s.condition, NoBranching) diff --git a/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala b/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala index e1c21069..bc37e2d8 100644 --- a/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala +++ b/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala @@ -214,7 +214,8 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte search(arg, cv - v) case FunctionCallExpression(name, params) if hiddenEffectFreeFunctions(name) || env.maybeGet[Type](name).isDefined => - params.map(p => search(p, cv)).reduce(commonVV) + if (params.isEmpty) cv // to handle compilation errors + else params.map(p => search(p, cv)).reduce(commonVV) case FunctionCallExpression(_, _) => cv -- nonreentrantVars case SumExpression(params, _) => params.map(p => search(p._2, cv)).reduce(commonVV) case HalfWordExpression(arg, _) => search(arg, cv) diff --git a/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala b/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala index 59631a97..c8c17a93 100644 --- a/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala @@ -1,6 +1,6 @@ package millfork.compiler.mos -import millfork.CompilationFlag +import millfork.{CompilationFlag, env} import millfork.assembly.Elidability import millfork.assembly.mos.AddrMode._ import millfork.assembly.mos.Opcode._ @@ -418,6 +418,83 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] { } val noop: List[AssemblyLine] = Nil + def areNZFlagsBasedOnA(code: List[AssemblyLine]): Boolean = { + for (line <- code.reverse) { + line.opcode match { + case TXA | LDA | ADC | EOR | SBC | AND | ORA | TYA | TZA | PLA | LAX => return true + case CMP => return line.addrMode == Immediate && line.parameter.isProvablyZero + case STA | STX | STY | STZ | SAX | + STA_W | STX_W | STY_W | STZ_W | + CLD | SED | CLV | SEI | CLI | SEC | CLC | + NOP => () + case ASL | LSR | ROL | ROR | INC | DEC => return line.addrMode == Implied + case _ => return false + } + } + false + } + + def compileToFatBooleanInA(ctx: CompilationContext, expr: Expression): List[AssemblyLine] = { + val env = ctx.env + val sourceType = AbstractExpressionCompiler.getExpressionType(ctx, expr) + sourceType match { + case FatBooleanType | _:ConstantBooleanType => + compileToA(ctx, expr) + case _: FlagBooleanType | BuiltInBooleanType => + val label = env.nextLabel("bo") + val condition = compile(ctx, expr, None, BranchIfFalse(label)) + if (condition.isEmpty) { + ??? + } + val conditionWithoutJump = condition.init + val hasOnlyOneJump = !conditionWithoutJump.exists(_.refersTo(label)) + // TODO: helper functions to convert flags to booleans, to make code smaller + if (hasOnlyOneJump) { + condition.last.opcode match { + case BCC => + // our bool is in the carry flag + // 3 bytes 4 cycles + return conditionWithoutJump ++ List(AssemblyLine.immediate(LDA, 0), AssemblyLine.implied(ROL)) + case BCS if !ctx.options.flag(CompilationFlag.OptimizeForSpeed) => + // our bool is in the carry flag, negated + // 5 bytes 6 cycles + return conditionWithoutJump ++ List(AssemblyLine.immediate(LDA, 0), AssemblyLine.implied(ROL), AssemblyLine.immediate(EOR, 1)) + case BPL if areNZFlagsBasedOnA(conditionWithoutJump) => + // our bool is in the N flag and the 7th bit of A + // 4 bytes 6 cycles + return conditionWithoutJump ++ List(AssemblyLine.implied(ASL), AssemblyLine.immediate(LDA, 0), AssemblyLine.implied(ROL)) + case BMI if areNZFlagsBasedOnA(conditionWithoutJump) && ctx.options.flag(CompilationFlag.OptimizeForSize)=> + // our bool is in the N flag and the 7th bit of A, negated + // 6 bytes 8 cycles + return conditionWithoutJump ++ List(AssemblyLine.implied(ASL), AssemblyLine.immediate(LDA, 0), AssemblyLine.implied(ROL), AssemblyLine.immediate(EOR, 1)) + case _ => + } + } + if (hasOnlyOneJump) { + condition.last.opcode match { + case BCC | BCS | BVC | BVS => + // 7 bytes; for true: 5 cycles, for false 6 cycles + return conditionWithoutJump ++ List( + AssemblyLine.immediate(LDA, 0), + condition.last, + AssemblyLine.immediate(LDA, 1), + AssemblyLine.label(label)) + case _ => () + } + } + val skip = env.nextLabel("bo") + // at most 9 bytes; for true: 7 cycles, for false 5 cycles + condition ++ List( + AssemblyLine.immediate(LDA, 1), + AssemblyLine.absolute(JMP, Label(skip)), + AssemblyLine.label(label), + AssemblyLine.immediate(LDA, 0), + AssemblyLine.label(skip)) + case _ => + println(sourceType) + ??? + } + } def compileToA(ctx: CompilationContext, expr: Expression): List[AssemblyLine] = { val env = ctx.env @@ -523,6 +600,20 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] { case _ => } val b = env.get[Type]("byte") + val exprType = AbstractExpressionCompiler.getExpressionType(ctx, expr) + if (branches != NoBranching) { + (exprType, branches) match { + case (FatBooleanType, _) => + return compile(ctx, FunctionCallExpression("!=", List(expr, LiteralExpression(0, 1))), exprTypeAndVariable, branches) + case (ConstantBooleanType(_, false), BranchIfTrue(_)) | (ConstantBooleanType(_, true), BranchIfFalse(_))=> + return compile(ctx, expr, exprTypeAndVariable, NoBranching) + case (ConstantBooleanType(_, true), BranchIfTrue(x)) => + return compile(ctx, expr, exprTypeAndVariable, NoBranching) :+ AssemblyLine.absolute(JMP, Label(x)) + case (ConstantBooleanType(_, false), BranchIfFalse(x)) => + return compile(ctx, expr, exprTypeAndVariable, NoBranching) :+ AssemblyLine.absolute(JMP, Label(x)) + case _ => () + } + } val w = env.get[Type]("word") expr match { case HalfWordExpression(expression, _) => ??? // TODO @@ -1653,6 +1744,7 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] { AssemblyLine.implied(PLA)) } case (t, v: VariableInMemory) => t.size match { + case 0 => ??? case 1 => v.typ.size match { case 1 => AssemblyLine.variable(ctx, STA, v) @@ -1876,7 +1968,11 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] { case VariableExpression(name) => val v = env.get[Variable](name, target.position) // TODO check v.typ - compile(ctx, source, Some((sourceType, v)), NoBranching) + if (v.typ == FatBooleanType) { + compileToFatBooleanInA(ctx, source) ++ compileByteStorage(ctx, MosRegister.A, target) + } else { + compile(ctx, source, Some((sourceType, v)), NoBranching) + } case SeparateBytesExpression(h: LhsExpression, l: LhsExpression) => compile(ctx, source, Some(w, RegisterVariable(MosRegister.AX, w)), NoBranching) ++ compileByteStorage(ctx, MosRegister.A, l) ++ compileByteStorage(ctx, MosRegister.X, h) diff --git a/src/main/scala/millfork/compiler/mos/MosStatementCompiler.scala b/src/main/scala/millfork/compiler/mos/MosStatementCompiler.scala index 1ca84b64..37cca242 100644 --- a/src/main/scala/millfork/compiler/mos/MosStatementCompiler.scala +++ b/src/main/scala/millfork/compiler/mos/MosStatementCompiler.scala @@ -25,7 +25,11 @@ object MosStatementCompiler extends AbstractStatementCompiler[AssemblyLine] { def compileExpressionForBranching(ctx: CompilationContext, expr: Expression, branching: BranchSpec): List[AssemblyLine] = { val b = ctx.env.get[Type]("byte") - MosExpressionCompiler.compile(ctx, expr, Some(b, RegisterVariable(MosRegister.A, b)), branching) + val prepareA = MosExpressionCompiler.compile(ctx, expr, Some(b, RegisterVariable(MosRegister.A, b)), branching) + if (AbstractExpressionCompiler.getExpressionType(ctx, expr) == FatBooleanType) { + if (MosExpressionCompiler.areNZFlagsBasedOnA(prepareA)) prepareA + else prepareA :+ AssemblyLine.immediate(CMP, 0) + } else prepareA } override def replaceLabel(ctx: CompilationContext, line: AssemblyLine, from: String, to: String): AssemblyLine = line.parameter match { @@ -249,6 +253,10 @@ object MosStatementCompiler extends AbstractStatementCompiler[AssemblyLine] { MosExpressionCompiler.compileAssignment(ctx, e, VariableExpression(ctx.function.name + "`return")) ++ stackPointerFixBeforeReturn(ctx) ++ returnInstructions } + case FatBooleanType => + MosExpressionCompiler.compileToFatBooleanInA(ctx, e) ++ + stackPointerFixBeforeReturn(ctx, preserveA = true) ++ + List(AssemblyLine.discardXF(), AssemblyLine.discardYF()) ++ returnInstructions case _ => AbstractExpressionCompiler.checkAssignmentType(ctx, e, m.returnType) m.returnType.size match { diff --git a/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala b/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala index e74de7c4..47e9fc45 100644 --- a/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala @@ -21,6 +21,72 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { def compileToA(ctx: CompilationContext, expression: Expression): List[ZLine] = compile(ctx, expression, ZExpressionTarget.A) + def compileToFatBooleanInA(ctx: CompilationContext, expression: Expression): List[ZLine] = { + val sourceType = AbstractExpressionCompiler.getExpressionType(ctx, expression) + sourceType match { + case FatBooleanType | _: ConstantBooleanType => compileToA(ctx, expression) + case BuiltInBooleanType | _: FlagBooleanType => + // TODO optimize if using CARRY + // TODO: helper functions to convert flags to booleans, to make code smaller + val label = ctx.env.nextLabel("bo") + val condition = Z80ExpressionCompiler.compile(ctx, expression, ZExpressionTarget.NOTHING, BranchIfFalse(label)) + val conditionWithoutJump = condition.init + val hasOnlyOneJump = !conditionWithoutJump.exists(_.refersTo(label)) + if (hasOnlyOneJump && (condition.last.opcode == JP || condition.last.opcode == JR)) { + import ZRegister._ + condition.last.registers match { + case IfFlagClear(ZFlag.C) => + // our bool is in the carry flag + return conditionWithoutJump ++ List(ZLine.ldImm8(A, 0), ZLine.implied(RLA)) + case IfFlagSet(ZFlag.C) => + // our bool is in the carry flag, negated + return conditionWithoutJump ++ List(ZLine.ldImm8(A, 0), ZLine.implied(CCF), ZLine.implied(RLA)) + case IfFlagClear(ZFlag.S) if areSZFlagsBasedOnA(conditionWithoutJump) => + // our bool is in the sign flag and the 7th bit of A + return conditionWithoutJump ++ List(ZLine.implied(RLCA), ZLine.imm8(AND, 1)) + case IfFlagSet(ZFlag.S) if areSZFlagsBasedOnA(conditionWithoutJump) => + // our bool is in the sign flag and the 7th bit of A, negated + return conditionWithoutJump ++ List(ZLine.implied(RLCA), ZLine.imm8(XOR, 1), ZLine.imm8(AND, 1)) + case _ => + // TODO: helper functions to convert flags to booleans, to make code smaller + } + } + if (ctx.options.flag(CompilationFlag.OptimizeForSpeed)) { + val skip = ctx.env.nextLabel("bo") + condition ++ List( + ZLine.ldImm8(ZRegister.A, 1), + ZLine.jumpR(ctx, skip), + ZLine.label(label), + ZLine.ldImm8(ZRegister.A, 0), + ZLine.label(skip) + ) + } else { + conditionWithoutJump ++ List( + ZLine.ldImm8(ZRegister.A, 0), + condition.last, + ZLine.register(INC, ZRegister.A), + ZLine.label(label) + ) + } + case _ => + println(sourceType) + ??? + } + } + + def areSZFlagsBasedOnA(code: List[ZLine]): Boolean = { + for (line <- code.reverse) { + line.opcode match { + case ADD | SUB | SBC | ADC | XOR | OR | AND => return true + case CP => return line.registers == OneRegister(ZRegister.IMM_8) && line.parameter.isProvablyZero + case LD | LD_16 | NOP | POP => () + case RR | RL | SLA | SLL | RRC | RLC | SRA | SRL => return line.registers == OneRegister(ZRegister.A) + case _ => return false + } + } + false + } + def compile8BitTo(ctx: CompilationContext, expression: Expression, register: ZRegister.Value): List[ZLine] = { if (ZRegister.A == register) compileToA(ctx, expression) else { val toA = compileToA(ctx, expression) @@ -239,6 +305,20 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { val env = ctx.env val b = env.get[Type]("byte") val w = env.get[Type]("word") + val exprType = AbstractExpressionCompiler.getExpressionType(ctx, expression) + if (branches != NoBranching) { + (exprType, branches) match { + case (FatBooleanType, _) => + return compile(ctx, FunctionCallExpression("!=", List(expression, LiteralExpression(0, 1))), target, branches) + case (ConstantBooleanType(_, false), BranchIfTrue(_)) | (ConstantBooleanType(_, true), BranchIfFalse(_))=> + return compile(ctx, expression, target, NoBranching) + case (ConstantBooleanType(_, true), BranchIfTrue(x)) => + return compile(ctx, expression, target, NoBranching) :+ ZLine.jump(x) + case (ConstantBooleanType(_, false), BranchIfFalse(x)) => + return compile(ctx, expression, target, NoBranching) :+ ZLine.jump(x) + case _ => () + } + } env.eval(expression) match { case Some(const) => target match { diff --git a/src/main/scala/millfork/compiler/z80/Z80StatementCompiler.scala b/src/main/scala/millfork/compiler/z80/Z80StatementCompiler.scala index 66c07d40..6569edad 100644 --- a/src/main/scala/millfork/compiler/z80/Z80StatementCompiler.scala +++ b/src/main/scala/millfork/compiler/z80/Z80StatementCompiler.scala @@ -16,7 +16,6 @@ object Z80StatementCompiler extends AbstractStatementCompiler[ZLine] { def compile(ctx: CompilationContext, statement: ExecutableStatement): (List[ZLine], List[ZLine])= { - ctx.log.trace(statement.toString) val options = ctx.options val env = ctx.env val ret = Z80Compiler.restoreRegistersAndReturn(ctx) @@ -53,6 +52,9 @@ object Z80StatementCompiler extends AbstractStatementCompiler[ZLine] { Z80ExpressionCompiler.compileToHL(ctx, e) ++ fixStackOnReturn(ctx) ++ List(ZLine.implied(DISCARD_A), ZLine.implied(DISCARD_BC)) ++ ret } + case FatBooleanType => + Z80ExpressionCompiler.compileToFatBooleanInA(ctx, e) ++ fixStackOnReturn(ctx) ++ + List(ZLine.implied(DISCARD_F), ZLine.implied(DISCARD_HL), ZLine.implied(DISCARD_BC), ZLine.implied(DISCARD_DE), ZLine.implied(RET)) case t => AbstractExpressionCompiler.checkAssignmentType(ctx, e, ctx.function.returnType) t.size match { @@ -97,9 +99,17 @@ object Z80StatementCompiler extends AbstractStatementCompiler[ZLine] { AbstractExpressionCompiler.checkAssignmentType(ctx, source, targetType) (sourceType.size match { case 0 => - ctx.log.error("Cannot assign a void expression", statement.position) - Z80ExpressionCompiler.compile(ctx, source, ZExpressionTarget.NOTHING, BranchSpec.None) ++ - Z80ExpressionCompiler.compile(ctx, destination, ZExpressionTarget.NOTHING, BranchSpec.None) + sourceType match { + case _:ConstantBooleanType => + Z80ExpressionCompiler.compileToA(ctx, source) ++ Z80ExpressionCompiler.storeA(ctx, destination, signedSource = false) + case _:BooleanType => + // TODO: optimize + Z80ExpressionCompiler.compileToFatBooleanInA(ctx, source) ++ Z80ExpressionCompiler.storeA(ctx, destination, signedSource = false) + case _ => + ctx.log.error("Cannot assign a void expression", statement.position) + Z80ExpressionCompiler.compile(ctx, source, ZExpressionTarget.NOTHING, BranchSpec.None) ++ + Z80ExpressionCompiler.compile(ctx, destination, ZExpressionTarget.NOTHING, BranchSpec.None) + } case 1 => Z80ExpressionCompiler.compileToA(ctx, source) ++ Z80ExpressionCompiler.storeA(ctx, destination, sourceType.isSigned) case 2 => ctx.env.eval(source) match { @@ -311,8 +321,15 @@ object Z80StatementCompiler extends AbstractStatementCompiler[ZLine] { def areBlocksLarge(blocks: List[ZLine]*): Boolean = false - override def compileExpressionForBranching(ctx: CompilationContext, expr: Expression, branching: BranchSpec): List[ZLine] = - Z80ExpressionCompiler.compile(ctx, expr, ZExpressionTarget.NOTHING, branching) + override def compileExpressionForBranching(ctx: CompilationContext, expr: Expression, branching: BranchSpec): List[ZLine] = { + if (AbstractExpressionCompiler.getExpressionType(ctx, expr) == FatBooleanType) { + val prepareA = Z80ExpressionCompiler.compile(ctx, expr, ZExpressionTarget.A, branching) + if (Z80ExpressionCompiler.areSZFlagsBasedOnA(prepareA)) prepareA + else prepareA :+ ZLine.register(OR, ZRegister.A) + } else { + Z80ExpressionCompiler.compile(ctx, expr, ZExpressionTarget.NOTHING, branching) + } + } override def replaceLabel(ctx: CompilationContext, line: ZLine, from: String, to: String): ZLine = line.parameter match { case MemoryAddressConstant(Label(l)) if l == from => line.copy(parameter = MemoryAddressConstant(Label(to))) diff --git a/src/main/scala/millfork/env/Constant.scala b/src/main/scala/millfork/env/Constant.scala index 0447f8d5..12673a8b 100644 --- a/src/main/scala/millfork/env/Constant.scala +++ b/src/main/scala/millfork/env/Constant.scala @@ -269,7 +269,7 @@ case class SubbyteConstant(base: Constant, index: Int) extends Constant { override def quickSimplify: Constant = { val simplified = base.quickSimplify simplified match { - case NumericConstant(x, size) => if (index >= size) { + case NumericConstant(x, size) => if (index != 0 && index >= size) { Constant.Zero } else { NumericConstant((x >> (index * 8)) & 0xff, 1) diff --git a/src/main/scala/millfork/env/Environment.scala b/src/main/scala/millfork/env/Environment.scala index 5ae2053d..029eef59 100644 --- a/src/main/scala/millfork/env/Environment.scala +++ b/src/main/scala/millfork/env/Environment.scala @@ -417,6 +417,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa addThing(falseType, None) addThing(ConstantThing("true", NumericConstant(1, 0), trueType), None) addThing(ConstantThing("false", NumericConstant(0, 0), falseType), None) + addThing(FatBooleanType, None) val nullptrValue = options.features.getOrElse("NULLPTR", 0L) val nullptrConstant = NumericConstant(nullptrValue, 2) addThing(ConstantThing("nullptr", nullptrConstant, NullType), None) diff --git a/src/main/scala/millfork/env/Thing.scala b/src/main/scala/millfork/env/Thing.scala index 01007e17..c9c605c1 100644 --- a/src/main/scala/millfork/env/Thing.scala +++ b/src/main/scala/millfork/env/Thing.scala @@ -118,10 +118,26 @@ case class UnionType(name: String, fields: List[(String, String)]) extends Compo override def isSigned: Boolean = false } +case object FatBooleanType extends VariableType { + override def size: Int = 1 + + override def isSigned: Boolean = false + + override def name: String = "bool" + + override def isPointy: Boolean = false + + override def isSubtypeOf(other: Type): Boolean = this == other + + override def isAssignableTo(targetType: Type): Boolean = this == targetType +} + sealed trait BooleanType extends Type { def size = 0 def isSigned = false + + override def isAssignableTo(targetType: Type): Boolean = isCompatible(targetType) || targetType == FatBooleanType } case class ConstantBooleanType(name: String, value: Boolean) extends BooleanType diff --git a/src/test/scala/millfork/test/BooleanSuite.scala b/src/test/scala/millfork/test/BooleanSuite.scala index 50fcdb59..bdacc196 100644 --- a/src/test/scala/millfork/test/BooleanSuite.scala +++ b/src/test/scala/millfork/test/BooleanSuite.scala @@ -17,8 +17,10 @@ class BooleanSuite extends FunSuite with Matchers { | void main () { | byte a | a = 5 - | if not(a < 3) {output = 4} - | if not(a > 3) {output = 3} + | if not(a < 3) { output = $84 } + | if not(a > 3) { output = $03 } + | if not(true) { output = $05 } + | if not(false) { output &= $7f } | } """.stripMargin)(_.readByte(0xc000) should equal(4)) @@ -89,4 +91,51 @@ class BooleanSuite extends FunSuite with Matchers { """.stripMargin EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80, Cpu.Intel8080, Cpu.Sharp, Cpu.Intel8086)(code)(_.readByte(0xc000) should equal(code.sliding(4).count(_ == "pass"))) } + + test("Fat boolean") { + val code =""" + | byte output @$c000 + | const pointer outside = $bfff + | void main () { + | output = 1 + | bool x + | x = true + | memory_barrier() + | if x { pass() } + | memory_barrier() + | x = false + | memory_barrier() + | if x { fail(1) } + | + | if isZero(0) { pass() } + | if isZero(1) { fail(2) } + | if isLarge(10) { pass() } + | if isLarge(1) { fail(3) } + | + | x = id(2) == 2 + | if x { pass() } + | x = id(2) != 2 + | if x { fail(4) } + | + | if always() { pass() } + | if never() { fail(5) } + | + | x = always() + | x = not(x) + | if x { fail(6) } + | if not(x) { pass() } + | if x && true { fail(7) } + | if x || true { pass() } + | } + | inline void pass() { output += 1 } + | noinline void fail(byte i) { outside[0] = i } + | noinline byte id(byte x) = x + | bool isZero(byte x) = x == 0 + | bool isLarge(byte x) = x >= 10 + | bool always() = true + | bool never() = false + | + """.stripMargin + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)(code)(_.readByte(0xc000) should equal(code.sliding(4).count(_ == "pass"))) + } } diff --git a/src/test/scala/millfork/test/emu/SymonTestRam.scala b/src/test/scala/millfork/test/emu/SymonTestRam.scala index 2e9e6c5a..27620585 100644 --- a/src/test/scala/millfork/test/emu/SymonTestRam.scala +++ b/src/test/scala/millfork/test/emu/SymonTestRam.scala @@ -24,7 +24,7 @@ class SymonTestRam(mem: MemoryBank) extends Device(0x0000, 0xffff, "RAM") { override def write(i: Int, i1: Int): Unit = { if (!mem.writeable(i)) { - throw new RuntimeException(s"Can't write to $$${i.toHexString}") + throw new RuntimeException(s"Can't write $$${i1.&(0xff).toHexString} to $$${i.toHexString}") } mem.output(i) = i1.toByte }