From 90e5360bfd54ad00490147697c50ea51c8813e64 Mon Sep 17 00:00:00 2001 From: Karol Stasiak Date: Fri, 6 Aug 2021 21:01:03 +0200 Subject: [PATCH] =?UTF-8?q?Related=20to=20#119:=20=E2=80=93=20Detection=20?= =?UTF-8?q?of=20simple=20byte=20overflow=20cases.=20=E2=80=93=20Optimizati?= =?UTF-8?q?on=20of=208=C3=978=E2=86=9216=20multiplication=20on=206809.=20?= =?UTF-8?q?=E2=80=93=20Multiplication=20optimizations=20on=20Z80.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../scala/millfork/CompilationOptions.scala | 3 + src/main/scala/millfork/Main.scala | 4 + .../compiler/AbstractExpressionCompiler.scala | 29 ++++ .../AbstractStatementPreprocessor.scala | 1 + .../m6809/M6809ExpressionCompiler.scala | 10 +- .../millfork/compiler/m6809/M6809MulDiv.scala | 4 +- .../compiler/mos/PseudoregisterBuiltIns.scala | 15 +- .../millfork/compiler/z80/Z80Multiply.scala | 10 +- src/main/scala/millfork/env/Environment.scala | 2 + .../scala/millfork/env/OverflowDetector.scala | 134 ++++++++++++++++++ .../scala/millfork/test/ByteMathSuite.scala | 17 +++ .../scala/millfork/test/WarningSuite.scala | 32 +++++ 12 files changed, 252 insertions(+), 9 deletions(-) create mode 100644 src/main/scala/millfork/env/OverflowDetector.scala diff --git a/src/main/scala/millfork/CompilationOptions.scala b/src/main/scala/millfork/CompilationOptions.scala index ef2c6fba..72ca7290 100644 --- a/src/main/scala/millfork/CompilationOptions.scala +++ b/src/main/scala/millfork/CompilationOptions.scala @@ -421,6 +421,7 @@ object Cpu extends Enumeration { EnableBreakpoints, UseOptimizationHints, GenericWarnings, + ByteOverflowWarning, UselessCodeWarning, BuggyCodeWarning, FallbackValueUseWarning, @@ -585,6 +586,7 @@ object CompilationFlag extends Enumeration { SingleThreaded, // warning options GenericWarnings, + ByteOverflowWarning, UselessCodeWarning, BuggyCodeWarning, DeprecationWarning, @@ -603,6 +605,7 @@ object CompilationFlag extends Enumeration { val allWarnings: Set[CompilationFlag.Value] = Set( GenericWarnings, + ByteOverflowWarning, UselessCodeWarning, BuggyCodeWarning, DeprecationWarning, diff --git a/src/main/scala/millfork/Main.scala b/src/main/scala/millfork/Main.scala index 74ac3b43..4499beda 100644 --- a/src/main/scala/millfork/Main.scala +++ b/src/main/scala/millfork/Main.scala @@ -818,6 +818,10 @@ object Main { c.changeFlag(CompilationFlag.RorWarning, v) }.description("Whether should warn about the ROR instruction (6502 only). Default: disabled.") + boolean("-Woverflow", "-Wno-overflow").repeatable().action { (c, v) => + c.changeFlag(CompilationFlag.ByteOverflowWarning, v) + }.description("Whether should warn about byte overflow. Default: enabled.") + boolean("-Wuseless", "-Wno-useless").repeatable().action { (c, v) => c.changeFlag(CompilationFlag.UselessCodeWarning, v) }.description("Whether should warn about code that does nothing. Default: enabled.") diff --git a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala index 80a14a83..96f7d127 100644 --- a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala @@ -7,6 +7,8 @@ import millfork.error.{ConsoleLogger, Logger} import millfork.assembly.AbstractCode import millfork.output.NoAlignment +import scala.collection.mutable.ListBuffer + /** * @author Karol Stasiak */ @@ -14,6 +16,33 @@ class AbstractExpressionCompiler[T <: AbstractCode] { def getExpressionType(ctx: CompilationContext, expr: Expression): Type = AbstractExpressionCompiler.getExpressionType(ctx, expr) + def extractWordExpandedBytes(ctx: CompilationContext, params:List[Expression]): Option[List[Expression]] = { + val result = ListBuffer[Expression]() + for(param <- params) { + if (ctx.env.eval(param).isDefined) return None + AbstractExpressionCompiler.getExpressionType(ctx, param) match { + case t: PlainType if t.size == 1 && !t.isSigned => + result += param + case t: PlainType if t.size == 2 => + param match { + case FunctionCallExpression(functionName, List(inner)) => + AbstractExpressionCompiler.getExpressionType(ctx, inner) match { + case t: PlainType if t.size == 1 && !t.isSigned => + ctx.env.maybeGet[Type](functionName) match { + case Some(tw: PlainType) if tw.size == 2 => + result += inner + case _ => return None + } + case _ => return None + } + case _ => return None + } + case _ => return None + } + } + Some(result.toList) + } + def assertAllArithmetic(ctx: CompilationContext,expressions: List[Expression], booleanHint: String = ""): Unit = { for(e <- expressions) { val typ = getExpressionType(ctx, e) diff --git a/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala b/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala index eaeccfee..f1610e21 100644 --- a/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala +++ b/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala @@ -150,6 +150,7 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte case _ => } + new OverflowDetector(ctx).detectOverflow(stmt) stmt match { case Assignment(ve@VariableExpression(v), arg) if trackableVars(v) => cv = search(arg, cv) diff --git a/src/main/scala/millfork/compiler/m6809/M6809ExpressionCompiler.scala b/src/main/scala/millfork/compiler/m6809/M6809ExpressionCompiler.scala index c630ff70..4e5123e5 100644 --- a/src/main/scala/millfork/compiler/m6809/M6809ExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/m6809/M6809ExpressionCompiler.scala @@ -7,7 +7,7 @@ import millfork.assembly.m6809.{Absolute, DAccumulatorIndexed, Immediate, Indexe import millfork.compiler.{AbstractExpressionCompiler, BranchIfFalse, BranchIfTrue, BranchSpec, ComparisonType, CompilationContext, NoBranching} import millfork.node.{DerefExpression, Expression, FunctionCallExpression, GeneratedConstantExpression, IndexedExpression, LhsExpression, LiteralExpression, M6809Register, SeparateBytesExpression, SumExpression, VariableExpression} import millfork.assembly.m6809.MOpcode._ -import millfork.env.{AssemblyOrMacroParamSignature, BuiltInBooleanType, Constant, ConstantBooleanType, ConstantPointy, ExternFunction, FatBooleanType, FlagBooleanType, FunctionInMemory, FunctionPointerType, KernalInterruptPointerType, Label, M6809RegisterVariable, MacroFunction, MathOperator, MemoryAddressConstant, MemoryVariable, NonFatalCompilationException, NormalFunction, NormalParamSignature, NumericConstant, StackOffsetThing, StackVariable, StackVariablePointy, StructureConstant, Thing, ThingInMemory, Type, Variable, VariableInMemory, VariableLikeThing, VariablePointy, VariableType} +import millfork.env.{AssemblyOrMacroParamSignature, BuiltInBooleanType, Constant, ConstantBooleanType, ConstantPointy, ExternFunction, FatBooleanType, FlagBooleanType, FunctionInMemory, FunctionPointerType, KernalInterruptPointerType, Label, M6809RegisterVariable, MacroFunction, MathOperator, MemoryAddressConstant, MemoryVariable, NonFatalCompilationException, NormalFunction, NormalParamSignature, NumericConstant, PlainType, StackOffsetThing, StackVariable, StackVariablePointy, StructureConstant, Thing, ThingInMemory, Type, Variable, VariableInMemory, VariableLikeThing, VariablePointy, VariableType} import scala.collection.GenTraversableOnce @@ -292,7 +292,13 @@ object M6809ExpressionCompiler extends AbstractExpressionCompiler[MLine] { assertSizesForMultiplication(ctx, params, inPlace = false) getArithmeticParamMaxSize(ctx, params) match { case 1 => M6809MulDiv.compileByteMultiplication(ctx, params, updateDerefX = false) ++ targetifyB(ctx, target, isSigned = false) - case 2 => M6809MulDiv.compileWordMultiplication(ctx, params, updateDerefX = false) ++ targetifyD(ctx, target) + case 2 => + extractWordExpandedBytes(ctx, params) match { + case Some(byteParams) if byteParams.size == 2 => + M6809MulDiv.compileByteMultiplication(ctx, byteParams, updateDerefX = false) ++ targetifyD(ctx, target) + case _ => + M6809MulDiv.compileWordMultiplication(ctx, params, updateDerefX = false) ++ targetifyD(ctx, target) + } case 0 => Nil case _ => ctx.log.error("Multiplication of variables larger than 2 bytes is not supported", expr.position) diff --git a/src/main/scala/millfork/compiler/m6809/M6809MulDiv.scala b/src/main/scala/millfork/compiler/m6809/M6809MulDiv.scala index afcc7240..f87936dc 100644 --- a/src/main/scala/millfork/compiler/m6809/M6809MulDiv.scala +++ b/src/main/scala/millfork/compiler/m6809/M6809MulDiv.scala @@ -15,9 +15,9 @@ import scala.collection.mutable.ListBuffer */ object M6809MulDiv { - def compileByteMultiplication(ctx: CompilationContext, params: List[Expression], updateDerefX: Boolean): List[MLine] = { + def compileByteMultiplication(ctx: CompilationContext, params: List[Expression], updateDerefX: Boolean, forceMul: Boolean = false): List[MLine] = { var constant = Constant.One - val variablePart = params.flatMap { p => + val variablePart = if(forceMul) params else params.flatMap { p => ctx.env.eval(p) match { case Some(c) => constant = CompoundConstant(MathOperator.Times, constant, c).quickSimplify diff --git a/src/main/scala/millfork/compiler/mos/PseudoregisterBuiltIns.scala b/src/main/scala/millfork/compiler/mos/PseudoregisterBuiltIns.scala index 71e8db8b..c26a7633 100644 --- a/src/main/scala/millfork/compiler/mos/PseudoregisterBuiltIns.scala +++ b/src/main/scala/millfork/compiler/mos/PseudoregisterBuiltIns.scala @@ -730,11 +730,23 @@ object PseudoregisterBuiltIns { case (1, 1) => // ok case _ => ctx.log.fatal("Invalid code path", param2.position) } + val b = ctx.env.get[Type]("byte") + val w = ctx.env.get[Type]("word") + val reg = ctx.env.get[VariableInMemory]("__reg") if (!storeInRegLo && param1OrRegister.isDefined) { (ctx.env.eval(param1OrRegister.get), ctx.env.eval(param2)) match { case (Some(l), Some(r)) => val product = CompoundConstant(MathOperator.Times, l, r).quickSimplify return List(AssemblyLine.immediate(LDA, product.loByte), AssemblyLine.immediate(LDX, product.hiByte)) + case (Some(NumericConstant(2, _)), _) => + val evalParam2 = MosExpressionCompiler.compile(ctx, param2, Some(b -> RegisterVariable(MosRegister.A, b)), BranchSpec.None) + val label = ctx.nextLabel("sh") + return evalParam2 ++ List( + AssemblyLine.implied(ASL), + AssemblyLine.immediate(LDX, 0), + AssemblyLine.relative(BCC, label), + AssemblyLine.implied(INX), + AssemblyLine.label(label)) case (Some(NumericConstant(c, _)), _) if isPowerOfTwoUpTo15(c)=> return compileWordShiftOps(left = true, ctx, param2, LiteralExpression(java.lang.Long.bitCount(c - 1), 1)) case (_, Some(NumericConstant(c, _))) if isPowerOfTwoUpTo15(c)=> @@ -742,9 +754,6 @@ object PseudoregisterBuiltIns { case _ => } } - val b = ctx.env.get[Type]("byte") - val w = ctx.env.get[Type]("word") - val reg = ctx.env.get[VariableInMemory]("__reg") val load: List[AssemblyLine] = param1OrRegister match { case Some(param1) => val code1 = MosExpressionCompiler.compile(ctx, param1, Some(w -> RegisterVariable(MosRegister.AX, w)), BranchSpec.None) diff --git a/src/main/scala/millfork/compiler/z80/Z80Multiply.scala b/src/main/scala/millfork/compiler/z80/Z80Multiply.scala index 917854ce..6e99a9fd 100644 --- a/src/main/scala/millfork/compiler/z80/Z80Multiply.scala +++ b/src/main/scala/millfork/compiler/z80/Z80Multiply.scala @@ -326,8 +326,14 @@ object Z80Multiply { case (1, 1) => // ok case _ => ctx.log.fatal("Invalid code path", l.position) } - ctx.env.eval(r) match { - case Some(c) => + (ctx.env.eval(l), ctx.env.eval(r)) match { + case (Some(p), Some(q)) => + List(ZLine.ldImm16(ZRegister.HL, CompoundConstant(MathOperator.Times, p, q).quickSimplify)) + case (Some(NumericConstant(c, _)), _) if isPowerOfTwoUpTo15(c) => + Z80ExpressionCompiler.compileToHL(ctx, l) ++ List.fill(Integer.numberOfTrailingZeros(c.toInt))(ZLine.registers(ZOpcode.ADD_16, ZRegister.HL, ZRegister.HL)) + case (_, Some(NumericConstant(c, _))) if isPowerOfTwoUpTo15(c) => + Z80ExpressionCompiler.compileToHL(ctx, l) ++ List.fill(Integer.numberOfTrailingZeros(c.toInt))(ZLine.registers(ZOpcode.ADD_16, ZRegister.HL, ZRegister.HL)) + case (_, Some(c)) => Z80ExpressionCompiler.compileToDE(ctx, l) ++ List(ZLine.ldImm8(ZRegister.A, c)) ++ multiplication16And8(ctx) case _ => val lw = Z80ExpressionCompiler.compileToDE(ctx, l) diff --git a/src/main/scala/millfork/env/Environment.scala b/src/main/scala/millfork/env/Environment.scala index 3c2b324f..5b317749 100644 --- a/src/main/scala/millfork/env/Environment.scala +++ b/src/main/scala/millfork/env/Environment.scala @@ -1929,6 +1929,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa } def registerArray(stmt: ArrayDeclarationStatement, options: CompilationOptions): Unit = { + new OverflowDetector(this, options).detectOverflow(stmt) if (options.flag(CompilationFlag.LUnixRelocatableCode) && stmt.alignment.exists(_.isMultiplePages)) { log.error("Invalid alignment for LUnix code", stmt.position) } @@ -2090,6 +2091,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa } def registerVariable(stmt: VariableDeclarationStatement, options: CompilationOptions, isPointy: Boolean): Unit = { + new OverflowDetector(this, options).detectOverflow(stmt) val name = stmt.name val position = stmt.position if (name == "" || name.contains(".") && !name.contains(".return")) { diff --git a/src/main/scala/millfork/env/OverflowDetector.scala b/src/main/scala/millfork/env/OverflowDetector.scala new file mode 100644 index 00000000..b97d7011 --- /dev/null +++ b/src/main/scala/millfork/env/OverflowDetector.scala @@ -0,0 +1,134 @@ +package millfork.env + +import millfork.{CompilationFlag, CompilationOptions} +import millfork.compiler.{AbstractExpressionCompiler, CompilationContext} +import millfork.error.Logger +import millfork.node._ + +/** + * @author Karol Stasiak + */ +class OverflowDetector(env: Environment, options: CompilationOptions) { + + def this(ctx: CompilationContext) { + this(ctx.env, ctx.options) + } + + private def log: Logger = options.log + + private def isWord(e: Expression): Boolean = + AbstractExpressionCompiler.getExpressionType(env, log, e) match { + case t: PlainType => t.size == 2 + case _ => false + } + + private def isWord(typeName: String): Boolean = + env.maybeGet[Thing](typeName) match { + case Some(t: PlainType) => t.size == 2 + case _ => false + } + + private def isWord(typ: Type): Boolean = + typ match { + case t: PlainType => t.size == 2 + case _ => false + } + + private def isByte(e: Expression): Boolean = + AbstractExpressionCompiler.getExpressionType(env, log, e) match { + case t: PlainType => t.size == 1 + case _ => false + } + + def warnConstantOverflow(e: Expression, op: String): Unit = { + if (options.flag(CompilationFlag.ByteOverflowWarning)) { + log.warn(s"Constant byte overflow. Consider wrapping one of the arguments of $op with word( )", e.position) + } + } + + def warnDynamicOverflow(e: Expression, op: String): Unit = { + if (options.flag(CompilationFlag.ByteOverflowWarning)) { + log.warn(s"Potential byte overflow. Consider wrapping one of the arguments of $op with word( )", e.position) + } + } + + def scanExpression(e: Expression, willBeAssignedToWord: Boolean): Unit = { + if (willBeAssignedToWord) { + e match { + case FunctionCallExpression("<<", List(l, r)) => + if (isByte(l) && isByte(r)) { + (env.eval(l), env.eval(r)) match { + case (Some(NumericConstant(lc, 1)), Some(NumericConstant(rc, 1))) => + if (lc >= 0 && rc >= 0 && (lc << rc) > 255) { + warnConstantOverflow(e, "<<") + } + case (_, Some(NumericConstant(0, _))) => + case _ => + warnDynamicOverflow(e, "<<") + } + } + case FunctionCallExpression("*", List(l, r)) => + if (isByte(l) && isByte(r)) { + (env.eval(l), env.eval(r)) match { + case (Some(NumericConstant(lc, 1)), Some(NumericConstant(rc, 1))) => + if (lc >= 0 && rc >= 0 && (lc * rc) > 255) { + warnConstantOverflow(e, "*") + } + case (_, Some(NumericConstant(0, _))) => + case (_, Some(NumericConstant(1, _))) => + case (Some(NumericConstant(0, _)), _) => + case (Some(NumericConstant(1, _)), _) => + case _ => + warnDynamicOverflow(e, "*") + } + } + case FunctionCallExpression("word" | "unsigned16" | "signed16" | "pointer", List(SumExpression(expressions, _))) => + if (expressions.map(_._2).forall(isByte)) { + + } + case _ => + } + } + e match { + case SumExpression(expressions, decimal) => + if (willBeAssignedToWord && !decimal && isByte(e)) env.eval(e) match { + case Some(NumericConstant(n, _)) if n < -128 || n > 255 => + warnConstantOverflow(e, "+") + case _ => + } + for ((_, e) <- expressions) { + scanExpression(e, willBeAssignedToWord = willBeAssignedToWord) + } + case FunctionCallExpression("word" | "unsigned16" | "signed16" | "pointer", expressions) => + expressions.foreach(x => scanExpression(x, willBeAssignedToWord = true)) + case FunctionCallExpression("|" | "^" | "&" | "not", expressions) => + expressions.foreach(x => scanExpression(x, willBeAssignedToWord = false)) + case FunctionCallExpression(fname, expressions) => + env.maybeGet[Thing](fname) match { + case Some(f: FunctionInMemory) if f.params.length == expressions.length => + for ((e, t) <- expressions zip f.params.types) { + scanExpression(e, willBeAssignedToWord = isWord(t)) + } + case _ => + for (e <- expressions) { + scanExpression(e, willBeAssignedToWord = false) + } + } + case _ => + } + } + + def detectOverflow(stmt: Statement): Unit = { + stmt match { + case Assignment(lhs, rhs) => + if (isWord(lhs)) scanExpression(rhs, willBeAssignedToWord = true) + case v: VariableDeclarationStatement => + v.initialValue match { + case Some(e) => scanExpression(e, willBeAssignedToWord = isWord(v.typ)) + case _ => + } + case s => + s.getAllExpressions.foreach(e => scanExpression(e, willBeAssignedToWord = false)) + } + } +} diff --git a/src/test/scala/millfork/test/ByteMathSuite.scala b/src/test/scala/millfork/test/ByteMathSuite.scala index ad7477b4..767ad0d1 100644 --- a/src/test/scala/millfork/test/ByteMathSuite.scala +++ b/src/test/scala/millfork/test/ByteMathSuite.scala @@ -465,4 +465,21 @@ class ByteMathSuite extends FunSuite with Matchers with AppendedClues { m.readByte(0xc000) should equal(125) } } + + test("Optimal multiplication detection") { + EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Motorola6809)( + """ + | import zp_reg + | word output @$c000 + | noinline void run(byte a, byte b) { + | output = word(a) * b + | } + | void main () { + | run(100, 42) + | } + """. + stripMargin) { m => + m.readWord(0xc000) should equal(4200) + } + } } diff --git a/src/test/scala/millfork/test/WarningSuite.scala b/src/test/scala/millfork/test/WarningSuite.scala index 357c8dea..2aea4a48 100644 --- a/src/test/scala/millfork/test/WarningSuite.scala +++ b/src/test/scala/millfork/test/WarningSuite.scala @@ -55,4 +55,36 @@ class WarningSuite extends FunSuite with Matchers { """.stripMargin) { m => } } + + test("Warn about unintended byte overflow") { + EmuUnoptimizedCrossPlatformRun(Cpu.Mos)( + """ + | import zp_reg + | const word screenOffset = (10*40)+5 + | noinline void func(byte x, byte y) { + | word screenOffset + | screenOffset = (x*40) + y + | } + | noinline word getNESScreenOffset(byte x, byte y) { + | word temp + | temp = (y << 5) +x + | } + | noinline word getSomeFunc(byte x, byte y, byte z) { + | word temp + | temp = ((x + z) << 2) + (y << 5) + | temp = byte((x + z) << 2) + (y << 5) + | } + | + | noinline byte someFunc(byte x, byte y) { + | return (x*y)-24 + | } + | void main() { + | func(0,0) + | getNESScreenOffset(0,0) + | getSomeFunc(0,screenOffset.lo,5) + | someFunc(0,0) + | } + """.stripMargin) { m => + } + } }