From 33716f38810fb4de15200262ee7689404463cbc2 Mon Sep 17 00:00:00 2001 From: Karol Stasiak Date: Mon, 8 Jan 2018 01:18:04 +0100 Subject: [PATCH] Runtime bounds checking --- doc/abi/generated-labels.md | 2 + doc/api/command-line.md | 6 +- .../millfork/assembly/AssemblyLine.scala | 2 +- .../scala/millfork/compiler/MfCompiler.scala | 100 ++++++++++++------ 4 files changed, 76 insertions(+), 34 deletions(-) diff --git a/doc/abi/generated-labels.md b/doc/abi/generated-labels.md index ae4f6be8..8f2c77cd 100644 --- a/doc/abi/generated-labels.md +++ b/doc/abi/generated-labels.md @@ -10,6 +10,8 @@ where `11111` is a sequential number and `xx` is the type: * `an` – logical conjunction short-circuiting +* `bc` – array bounds checking (`-fbounds-checking`) + * `c8` – constant `#8` for `BIT` when immediate addressing is not available * `co` – greater-than comparison diff --git a/doc/api/command-line.md b/doc/api/command-line.md index 5fba321d..86804f80 100644 --- a/doc/api/command-line.md +++ b/doc/api/command-line.md @@ -37,9 +37,11 @@ * `-fjmp-fix`, `-fno-jmp-fix` – Whether should prevent indirect JMP bug on page boundary. `.ini` equivalent: `prevent_jmp_indirect_bug`. -* `-fdecimal-mode`, `-fno-decimal-mode` – Whether should decimal mode be available.` .ini` equivalent: `decimal_mode`. +* `-fdecimal-mode`, `-fno-decimal-mode` – Whether decimal mode should be available. `.ini` equivalent: `decimal_mode`. -* `-fvariable-overlap`, `-fno-variable-overlap` – Whether should variables overlap if their scopes do not intersect. Default: yes. +* `-fvariable-overlap`, `-fno-variable-overlap` – Whether variables should overlap if their scopes do not intersect. Default: yes. + +* `-fbounds-checking`, `-fnobounds-checking` – Whether should insert bounds checking on array access. Default: no. ## Optimization options diff --git a/src/main/scala/millfork/assembly/AssemblyLine.scala b/src/main/scala/millfork/assembly/AssemblyLine.scala index 1b1bafe1..8569de22 100644 --- a/src/main/scala/millfork/assembly/AssemblyLine.scala +++ b/src/main/scala/millfork/assembly/AssemblyLine.scala @@ -178,7 +178,7 @@ object AssemblyLine { def discardYF() = AssemblyLine(DISCARD_YF, AddrMode.DoesNotExist, Constant.Zero) - def immediate(opcode: Opcode.Value, value: Int) = AssemblyLine(opcode, AddrMode.Immediate, NumericConstant(value, 1)) + def immediate(opcode: Opcode.Value, value: Long) = AssemblyLine(opcode, AddrMode.Immediate, NumericConstant(value, 1)) def immediate(opcode: Opcode.Value, value: Constant) = AssemblyLine(opcode, AddrMode.Immediate, value) diff --git a/src/main/scala/millfork/compiler/MfCompiler.scala b/src/main/scala/millfork/compiler/MfCompiler.scala index d38468fe..7602c39c 100644 --- a/src/main/scala/millfork/compiler/MfCompiler.scala +++ b/src/main/scala/millfork/compiler/MfCompiler.scala @@ -98,9 +98,9 @@ object MlCompiler { getExpressionType(ctx, param) b case IndexedExpression(_, _) => b - case SeparateBytesExpression(h, l) => - if (getExpressionType(ctx, h).size > 1) ErrorReporting.error("Hi byte too large", h.position) - if (getExpressionType(ctx, l).size > 1) ErrorReporting.error("Lo byte too large", l.position) + case SeparateBytesExpression(hi, lo) => + if (getExpressionType(ctx, hi).size > 1) ErrorReporting.error("Hi byte too large", hi.position) + if (getExpressionType(ctx, lo).size > 1) ErrorReporting.error("Lo byte too large", lo.position) w case SumExpression(params, _) => b case FunctionCallExpression("not", params) => bool @@ -158,7 +158,7 @@ object MlCompiler { AssemblyLine(LDA, Immediate, expr.hiByte), AssemblyLine(LDY, Immediate, expr.loByte)) case m: VariableInMemory => - val addrMode = if(m.zeropage) ZeroPage else Absolute + val addrMode = if (m.zeropage) ZeroPage else Absolute val addr = m.toAddress m.typ.size match { case 0 => Nil @@ -309,16 +309,16 @@ object MlCompiler { if (register == Register.A) { indexRegister match { case Register.Y => - calculatingIndex ++ List(AssemblyLine.absoluteY(STA, arrayAddr + constIndex)) + calculatingIndex ++ arrayBoundsCheck(ctx, arrayName, Register.Y, indexExpr) ++ List(AssemblyLine.absoluteY(STA, arrayAddr + constIndex)) case Register.X => - calculatingIndex ++ List(AssemblyLine.absoluteX(STA, arrayAddr + constIndex)) + calculatingIndex ++ arrayBoundsCheck(ctx, arrayName, Register.X, indexExpr) ++ List(AssemblyLine.absoluteX(STA, arrayAddr + constIndex)) } } else { indexRegister match { case Register.Y => - calculatingIndex ++ List(AssemblyLine.implied(transferToA), AssemblyLine.absoluteY(STA, arrayAddr + constIndex)) + calculatingIndex ++ arrayBoundsCheck(ctx, arrayName, Register.Y, indexExpr) ++ List(AssemblyLine.implied(transferToA), AssemblyLine.absoluteY(STA, arrayAddr + constIndex)) case Register.X => - calculatingIndex ++ List(AssemblyLine.implied(transferToA), AssemblyLine.absoluteX(STA, arrayAddr + constIndex)) + calculatingIndex ++ arrayBoundsCheck(ctx, arrayName, Register.X, indexExpr) ++ List(AssemblyLine.implied(transferToA), AssemblyLine.absoluteX(STA, arrayAddr + constIndex)) } } } @@ -711,11 +711,11 @@ object MlCompiler { else if (target.typ.isSigned) { val label = nextLabel("sx") AssemblyLine.variable(ctx, STA, target) ++ - List( - AssemblyLine.immediate(ORA, 0x7f), - AssemblyLine.relative(BMI, label), - AssemblyLine.immediate(LDA, 0), - AssemblyLine.label(label)) ++ + List( + AssemblyLine.immediate(ORA, 0x7f), + AssemblyLine.relative(BMI, label), + AssemblyLine.immediate(LDA, 0), + AssemblyLine.label(label)) ++ List.tabulate(target.typ.size - 1)(i => AssemblyLine.variable(ctx, STA, target, i + 1)).flatten } else { AssemblyLine.variable(ctx, STA, target) ++ @@ -735,9 +735,9 @@ object MlCompiler { val calculatingIndex = compile(ctx, variableIndex, Some(b, RegisterVariable(indexRegister, b)), NoBranching) indexRegister match { case Register.Y => - calculatingIndex ++ List(AssemblyLine.absoluteY(load, arrayAddr + constantIndex)) + calculatingIndex ++ arrayBoundsCheck(ctx, arrayName, Register.Y, indexExpr) ++ List(AssemblyLine.absoluteY(load, arrayAddr + constantIndex)) case Register.X => - calculatingIndex ++ List(AssemblyLine.absoluteX(load, arrayAddr + constantIndex)) + calculatingIndex ++ arrayBoundsCheck(ctx, arrayName, Register.X, indexExpr) ++ List(AssemblyLine.absoluteX(load, arrayAddr + constantIndex)) } } @@ -797,7 +797,7 @@ object MlCompiler { } { case (exprType, target) => assertCompatible(exprType, target.typ) target match { - // TODO: some more complex ones may not work correctly + // TODO: some more complex ones may not work correctly case RegisterVariable(Register.A | Register.X | Register.Y, _) => compile(ctx, l, exprTypeAndVariable, branches) case RegisterVariable(Register.AX, _) => compile(ctx, l, Some(b -> RegisterVariable(Register.A, b)), branches) ++ @@ -929,18 +929,18 @@ object MlCompiler { // TODO: signed val (size, signed) = assertComparison(ctx, params) compileTransitiveRelation(ctx, "<=", params, exprTypeAndVariable, branches) { (l, r) => - size match { - case 1 => BuiltIns.compileByteComparison(ctx, if (signed) ComparisonType.LessOrEqualSigned else ComparisonType.LessOrEqualUnsigned, l, r, branches) - case 2 => BuiltIns.compileWordComparison(ctx, if (signed) ComparisonType.LessOrEqualSigned else ComparisonType.LessOrEqualUnsigned, l, r, branches) - } + size match { + case 1 => BuiltIns.compileByteComparison(ctx, if (signed) ComparisonType.LessOrEqualSigned else ComparisonType.LessOrEqualUnsigned, l, r, branches) + case 2 => BuiltIns.compileWordComparison(ctx, if (signed) ComparisonType.LessOrEqualSigned else ComparisonType.LessOrEqualUnsigned, l, r, branches) + } } case "==" => val size = params.map(p => getExpressionType(ctx, p).size).max compileTransitiveRelation(ctx, "==", params, exprTypeAndVariable, branches) { (l, r) => - size match { - case 1 => BuiltIns.compileByteComparison(ctx, ComparisonType.Equal, l, r, branches) - case 2 => BuiltIns.compileWordComparison(ctx, ComparisonType.Equal, l, r, branches) - } + size match { + case 1 => BuiltIns.compileByteComparison(ctx, ComparisonType.Equal, l, r, branches) + case 2 => BuiltIns.compileWordComparison(ctx, ComparisonType.Equal, l, r, branches) + } } case "!=" => val (l, r, size) = assertBinary(ctx, params) @@ -1150,7 +1150,7 @@ object MlCompiler { operator: String, params: List[Expression], exprTypeAndVariable: Option[(Type, Variable)], - branches: BranchSpec)(binary: (Expression, Expression) => List[AssemblyLine]): List[AssemblyLine] ={ + branches: BranchSpec)(binary: (Expression, Expression) => List[AssemblyLine]): List[AssemblyLine] = { params match { case List(l, r) => binary(l, r) case List(_) | Nil => @@ -1158,7 +1158,7 @@ object MlCompiler { case _ => val conjunction = params.init.zip(params.tail).map { case (l, r) => FunctionCallExpression(operator, List(l, r)) - }.reduceLeft((a,b) => FunctionCallExpression("&&", List(a, b))) + }.reduceLeft((a, b) => FunctionCallExpression("&&", List(a, b))) compile(ctx, conjunction, exprTypeAndVariable, branches) } } @@ -1355,7 +1355,7 @@ object MlCompiler { ??? case (_, actualParam) => } - case NormalParamSignature(Nil) => i.code + case NormalParamSignature(Nil) => case NormalParamSignature(normalParams) => ??? } actualCode @@ -1443,11 +1443,11 @@ object MlCompiler { statement match { case AssemblyStatement(o, a, x, e) => val c: Constant = x match { - // TODO: hmmm + // TODO: hmmm case VariableExpression(name) => if (OpcodeClasses.ShortBranching(o) || o == JMP || o == LABEL) { MemoryAddressConstant(Label(name)) - } else{ + } else { env.evalForAsm(x).getOrElse(env.get[ThingInMemory](name, x.position).toAddress) } case _ => @@ -1700,14 +1700,14 @@ object MlCompiler { FunctionCallExpression("<", List(vex, f.end)), f.body :+ increment), )) - case (ForDirection.To | ForDirection.ParallelTo,_,_) => + case (ForDirection.To | ForDirection.ParallelTo, _, _) => compile(ctx, List( Assignment(vex, f.start), WhileStatement( FunctionCallExpression("<=", List(vex, f.end)), f.body :+ increment), )) - case (ForDirection.DownTo,_,_) => + case (ForDirection.DownTo, _, _) => compile(ctx, List( Assignment(vex, f.start), IfStatement( @@ -1734,4 +1734,42 @@ object MlCompiler { private def branchChunk(opcode: Opcode.Value, labelName: String) = { LinearChunk(List(AssemblyLine.relative(opcode, Label(labelName)))) } + + def arrayBoundsCheck(ctx: CompilationContext, arrayName: String, register: Register.Value, index: Expression): List[AssemblyLine] = { + if (!ctx.options.flags(CompilationFlag.CheckIndexOutOfBounds)) return Nil + ctx.env.maybeGet[ConstantThing](arrayName + ".length") match { + case None => Nil + case Some(thing) => thing.value match { + case NumericConstant(arrayLength, _) => + ctx.env.eval(index) match { + case Some(NumericConstant(i, _)) => + if (i >= 0) { + if (i < arrayLength) return Nil + if (i >= arrayLength) return List( + AssemblyLine.implied(PHP), + AssemblyLine.absolute(JSR, ctx.env.get[ThingInMemory]("_panic"))) + } + case _ => + } + if (arrayLength > 0 && arrayLength < 255) { + val label = nextLabel("bc") + val compare = register match { + case Register.A => CMP + case Register.X => CPX + case Register.Y => CPY + } + List( + AssemblyLine.implied(PHP), + AssemblyLine.immediate(compare, arrayLength), + AssemblyLine.relative(BCC, label), + AssemblyLine.absolute(JSR, ctx.env.get[ThingInMemory]("_panic")), + AssemblyLine.label(label), + AssemblyLine.implied(PLP)) + } else { + Nil + } + case _ => Nil + } + } + } }