mirror of
https://github.com/KarolS/millfork.git
synced 2025-01-15 09:29:49 +00:00
Related to #119:
– Detection of simple byte overflow cases. – Optimization of 8×8→16 multiplication on 6809. – Multiplication optimizations on Z80.
This commit is contained in:
parent
7f6a0c6b0d
commit
90e5360bfd
@ -421,6 +421,7 @@ object Cpu extends Enumeration {
|
||||
EnableBreakpoints,
|
||||
UseOptimizationHints,
|
||||
GenericWarnings,
|
||||
ByteOverflowWarning,
|
||||
UselessCodeWarning,
|
||||
BuggyCodeWarning,
|
||||
FallbackValueUseWarning,
|
||||
@ -585,6 +586,7 @@ object CompilationFlag extends Enumeration {
|
||||
SingleThreaded,
|
||||
// warning options
|
||||
GenericWarnings,
|
||||
ByteOverflowWarning,
|
||||
UselessCodeWarning,
|
||||
BuggyCodeWarning,
|
||||
DeprecationWarning,
|
||||
@ -603,6 +605,7 @@ object CompilationFlag extends Enumeration {
|
||||
|
||||
val allWarnings: Set[CompilationFlag.Value] = Set(
|
||||
GenericWarnings,
|
||||
ByteOverflowWarning,
|
||||
UselessCodeWarning,
|
||||
BuggyCodeWarning,
|
||||
DeprecationWarning,
|
||||
|
@ -818,6 +818,10 @@ object Main {
|
||||
c.changeFlag(CompilationFlag.RorWarning, v)
|
||||
}.description("Whether should warn about the ROR instruction (6502 only). Default: disabled.")
|
||||
|
||||
boolean("-Woverflow", "-Wno-overflow").repeatable().action { (c, v) =>
|
||||
c.changeFlag(CompilationFlag.ByteOverflowWarning, v)
|
||||
}.description("Whether should warn about byte overflow. Default: enabled.")
|
||||
|
||||
boolean("-Wuseless", "-Wno-useless").repeatable().action { (c, v) =>
|
||||
c.changeFlag(CompilationFlag.UselessCodeWarning, v)
|
||||
}.description("Whether should warn about code that does nothing. Default: enabled.")
|
||||
|
@ -7,6 +7,8 @@ import millfork.error.{ConsoleLogger, Logger}
|
||||
import millfork.assembly.AbstractCode
|
||||
import millfork.output.NoAlignment
|
||||
|
||||
import scala.collection.mutable.ListBuffer
|
||||
|
||||
/**
|
||||
* @author Karol Stasiak
|
||||
*/
|
||||
@ -14,6 +16,33 @@ class AbstractExpressionCompiler[T <: AbstractCode] {
|
||||
|
||||
def getExpressionType(ctx: CompilationContext, expr: Expression): Type = AbstractExpressionCompiler.getExpressionType(ctx, expr)
|
||||
|
||||
def extractWordExpandedBytes(ctx: CompilationContext, params:List[Expression]): Option[List[Expression]] = {
|
||||
val result = ListBuffer[Expression]()
|
||||
for(param <- params) {
|
||||
if (ctx.env.eval(param).isDefined) return None
|
||||
AbstractExpressionCompiler.getExpressionType(ctx, param) match {
|
||||
case t: PlainType if t.size == 1 && !t.isSigned =>
|
||||
result += param
|
||||
case t: PlainType if t.size == 2 =>
|
||||
param match {
|
||||
case FunctionCallExpression(functionName, List(inner)) =>
|
||||
AbstractExpressionCompiler.getExpressionType(ctx, inner) match {
|
||||
case t: PlainType if t.size == 1 && !t.isSigned =>
|
||||
ctx.env.maybeGet[Type](functionName) match {
|
||||
case Some(tw: PlainType) if tw.size == 2 =>
|
||||
result += inner
|
||||
case _ => return None
|
||||
}
|
||||
case _ => return None
|
||||
}
|
||||
case _ => return None
|
||||
}
|
||||
case _ => return None
|
||||
}
|
||||
}
|
||||
Some(result.toList)
|
||||
}
|
||||
|
||||
def assertAllArithmetic(ctx: CompilationContext,expressions: List[Expression], booleanHint: String = ""): Unit = {
|
||||
for(e <- expressions) {
|
||||
val typ = getExpressionType(ctx, e)
|
||||
|
@ -150,6 +150,7 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte
|
||||
|
||||
case _ =>
|
||||
}
|
||||
new OverflowDetector(ctx).detectOverflow(stmt)
|
||||
stmt match {
|
||||
case Assignment(ve@VariableExpression(v), arg) if trackableVars(v) =>
|
||||
cv = search(arg, cv)
|
||||
|
@ -7,7 +7,7 @@ import millfork.assembly.m6809.{Absolute, DAccumulatorIndexed, Immediate, Indexe
|
||||
import millfork.compiler.{AbstractExpressionCompiler, BranchIfFalse, BranchIfTrue, BranchSpec, ComparisonType, CompilationContext, NoBranching}
|
||||
import millfork.node.{DerefExpression, Expression, FunctionCallExpression, GeneratedConstantExpression, IndexedExpression, LhsExpression, LiteralExpression, M6809Register, SeparateBytesExpression, SumExpression, VariableExpression}
|
||||
import millfork.assembly.m6809.MOpcode._
|
||||
import millfork.env.{AssemblyOrMacroParamSignature, BuiltInBooleanType, Constant, ConstantBooleanType, ConstantPointy, ExternFunction, FatBooleanType, FlagBooleanType, FunctionInMemory, FunctionPointerType, KernalInterruptPointerType, Label, M6809RegisterVariable, MacroFunction, MathOperator, MemoryAddressConstant, MemoryVariable, NonFatalCompilationException, NormalFunction, NormalParamSignature, NumericConstant, StackOffsetThing, StackVariable, StackVariablePointy, StructureConstant, Thing, ThingInMemory, Type, Variable, VariableInMemory, VariableLikeThing, VariablePointy, VariableType}
|
||||
import millfork.env.{AssemblyOrMacroParamSignature, BuiltInBooleanType, Constant, ConstantBooleanType, ConstantPointy, ExternFunction, FatBooleanType, FlagBooleanType, FunctionInMemory, FunctionPointerType, KernalInterruptPointerType, Label, M6809RegisterVariable, MacroFunction, MathOperator, MemoryAddressConstant, MemoryVariable, NonFatalCompilationException, NormalFunction, NormalParamSignature, NumericConstant, PlainType, StackOffsetThing, StackVariable, StackVariablePointy, StructureConstant, Thing, ThingInMemory, Type, Variable, VariableInMemory, VariableLikeThing, VariablePointy, VariableType}
|
||||
|
||||
import scala.collection.GenTraversableOnce
|
||||
|
||||
@ -292,7 +292,13 @@ object M6809ExpressionCompiler extends AbstractExpressionCompiler[MLine] {
|
||||
assertSizesForMultiplication(ctx, params, inPlace = false)
|
||||
getArithmeticParamMaxSize(ctx, params) match {
|
||||
case 1 => M6809MulDiv.compileByteMultiplication(ctx, params, updateDerefX = false) ++ targetifyB(ctx, target, isSigned = false)
|
||||
case 2 => M6809MulDiv.compileWordMultiplication(ctx, params, updateDerefX = false) ++ targetifyD(ctx, target)
|
||||
case 2 =>
|
||||
extractWordExpandedBytes(ctx, params) match {
|
||||
case Some(byteParams) if byteParams.size == 2 =>
|
||||
M6809MulDiv.compileByteMultiplication(ctx, byteParams, updateDerefX = false) ++ targetifyD(ctx, target)
|
||||
case _ =>
|
||||
M6809MulDiv.compileWordMultiplication(ctx, params, updateDerefX = false) ++ targetifyD(ctx, target)
|
||||
}
|
||||
case 0 => Nil
|
||||
case _ =>
|
||||
ctx.log.error("Multiplication of variables larger than 2 bytes is not supported", expr.position)
|
||||
|
@ -15,9 +15,9 @@ import scala.collection.mutable.ListBuffer
|
||||
*/
|
||||
object M6809MulDiv {
|
||||
|
||||
def compileByteMultiplication(ctx: CompilationContext, params: List[Expression], updateDerefX: Boolean): List[MLine] = {
|
||||
def compileByteMultiplication(ctx: CompilationContext, params: List[Expression], updateDerefX: Boolean, forceMul: Boolean = false): List[MLine] = {
|
||||
var constant = Constant.One
|
||||
val variablePart = params.flatMap { p =>
|
||||
val variablePart = if(forceMul) params else params.flatMap { p =>
|
||||
ctx.env.eval(p) match {
|
||||
case Some(c) =>
|
||||
constant = CompoundConstant(MathOperator.Times, constant, c).quickSimplify
|
||||
|
@ -730,11 +730,23 @@ object PseudoregisterBuiltIns {
|
||||
case (1, 1) => // ok
|
||||
case _ => ctx.log.fatal("Invalid code path", param2.position)
|
||||
}
|
||||
val b = ctx.env.get[Type]("byte")
|
||||
val w = ctx.env.get[Type]("word")
|
||||
val reg = ctx.env.get[VariableInMemory]("__reg")
|
||||
if (!storeInRegLo && param1OrRegister.isDefined) {
|
||||
(ctx.env.eval(param1OrRegister.get), ctx.env.eval(param2)) match {
|
||||
case (Some(l), Some(r)) =>
|
||||
val product = CompoundConstant(MathOperator.Times, l, r).quickSimplify
|
||||
return List(AssemblyLine.immediate(LDA, product.loByte), AssemblyLine.immediate(LDX, product.hiByte))
|
||||
case (Some(NumericConstant(2, _)), _) =>
|
||||
val evalParam2 = MosExpressionCompiler.compile(ctx, param2, Some(b -> RegisterVariable(MosRegister.A, b)), BranchSpec.None)
|
||||
val label = ctx.nextLabel("sh")
|
||||
return evalParam2 ++ List(
|
||||
AssemblyLine.implied(ASL),
|
||||
AssemblyLine.immediate(LDX, 0),
|
||||
AssemblyLine.relative(BCC, label),
|
||||
AssemblyLine.implied(INX),
|
||||
AssemblyLine.label(label))
|
||||
case (Some(NumericConstant(c, _)), _) if isPowerOfTwoUpTo15(c)=>
|
||||
return compileWordShiftOps(left = true, ctx, param2, LiteralExpression(java.lang.Long.bitCount(c - 1), 1))
|
||||
case (_, Some(NumericConstant(c, _))) if isPowerOfTwoUpTo15(c)=>
|
||||
@ -742,9 +754,6 @@ object PseudoregisterBuiltIns {
|
||||
case _ =>
|
||||
}
|
||||
}
|
||||
val b = ctx.env.get[Type]("byte")
|
||||
val w = ctx.env.get[Type]("word")
|
||||
val reg = ctx.env.get[VariableInMemory]("__reg")
|
||||
val load: List[AssemblyLine] = param1OrRegister match {
|
||||
case Some(param1) =>
|
||||
val code1 = MosExpressionCompiler.compile(ctx, param1, Some(w -> RegisterVariable(MosRegister.AX, w)), BranchSpec.None)
|
||||
|
@ -326,8 +326,14 @@ object Z80Multiply {
|
||||
case (1, 1) => // ok
|
||||
case _ => ctx.log.fatal("Invalid code path", l.position)
|
||||
}
|
||||
ctx.env.eval(r) match {
|
||||
case Some(c) =>
|
||||
(ctx.env.eval(l), ctx.env.eval(r)) match {
|
||||
case (Some(p), Some(q)) =>
|
||||
List(ZLine.ldImm16(ZRegister.HL, CompoundConstant(MathOperator.Times, p, q).quickSimplify))
|
||||
case (Some(NumericConstant(c, _)), _) if isPowerOfTwoUpTo15(c) =>
|
||||
Z80ExpressionCompiler.compileToHL(ctx, l) ++ List.fill(Integer.numberOfTrailingZeros(c.toInt))(ZLine.registers(ZOpcode.ADD_16, ZRegister.HL, ZRegister.HL))
|
||||
case (_, Some(NumericConstant(c, _))) if isPowerOfTwoUpTo15(c) =>
|
||||
Z80ExpressionCompiler.compileToHL(ctx, l) ++ List.fill(Integer.numberOfTrailingZeros(c.toInt))(ZLine.registers(ZOpcode.ADD_16, ZRegister.HL, ZRegister.HL))
|
||||
case (_, Some(c)) =>
|
||||
Z80ExpressionCompiler.compileToDE(ctx, l) ++ List(ZLine.ldImm8(ZRegister.A, c)) ++ multiplication16And8(ctx)
|
||||
case _ =>
|
||||
val lw = Z80ExpressionCompiler.compileToDE(ctx, l)
|
||||
|
@ -1929,6 +1929,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
|
||||
}
|
||||
|
||||
def registerArray(stmt: ArrayDeclarationStatement, options: CompilationOptions): Unit = {
|
||||
new OverflowDetector(this, options).detectOverflow(stmt)
|
||||
if (options.flag(CompilationFlag.LUnixRelocatableCode) && stmt.alignment.exists(_.isMultiplePages)) {
|
||||
log.error("Invalid alignment for LUnix code", stmt.position)
|
||||
}
|
||||
@ -2090,6 +2091,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
|
||||
}
|
||||
|
||||
def registerVariable(stmt: VariableDeclarationStatement, options: CompilationOptions, isPointy: Boolean): Unit = {
|
||||
new OverflowDetector(this, options).detectOverflow(stmt)
|
||||
val name = stmt.name
|
||||
val position = stmt.position
|
||||
if (name == "" || name.contains(".") && !name.contains(".return")) {
|
||||
|
134
src/main/scala/millfork/env/OverflowDetector.scala
vendored
Normal file
134
src/main/scala/millfork/env/OverflowDetector.scala
vendored
Normal file
@ -0,0 +1,134 @@
|
||||
package millfork.env
|
||||
|
||||
import millfork.{CompilationFlag, CompilationOptions}
|
||||
import millfork.compiler.{AbstractExpressionCompiler, CompilationContext}
|
||||
import millfork.error.Logger
|
||||
import millfork.node._
|
||||
|
||||
/**
|
||||
* @author Karol Stasiak
|
||||
*/
|
||||
class OverflowDetector(env: Environment, options: CompilationOptions) {
|
||||
|
||||
def this(ctx: CompilationContext) {
|
||||
this(ctx.env, ctx.options)
|
||||
}
|
||||
|
||||
private def log: Logger = options.log
|
||||
|
||||
private def isWord(e: Expression): Boolean =
|
||||
AbstractExpressionCompiler.getExpressionType(env, log, e) match {
|
||||
case t: PlainType => t.size == 2
|
||||
case _ => false
|
||||
}
|
||||
|
||||
private def isWord(typeName: String): Boolean =
|
||||
env.maybeGet[Thing](typeName) match {
|
||||
case Some(t: PlainType) => t.size == 2
|
||||
case _ => false
|
||||
}
|
||||
|
||||
private def isWord(typ: Type): Boolean =
|
||||
typ match {
|
||||
case t: PlainType => t.size == 2
|
||||
case _ => false
|
||||
}
|
||||
|
||||
private def isByte(e: Expression): Boolean =
|
||||
AbstractExpressionCompiler.getExpressionType(env, log, e) match {
|
||||
case t: PlainType => t.size == 1
|
||||
case _ => false
|
||||
}
|
||||
|
||||
def warnConstantOverflow(e: Expression, op: String): Unit = {
|
||||
if (options.flag(CompilationFlag.ByteOverflowWarning)) {
|
||||
log.warn(s"Constant byte overflow. Consider wrapping one of the arguments of $op with word( )", e.position)
|
||||
}
|
||||
}
|
||||
|
||||
def warnDynamicOverflow(e: Expression, op: String): Unit = {
|
||||
if (options.flag(CompilationFlag.ByteOverflowWarning)) {
|
||||
log.warn(s"Potential byte overflow. Consider wrapping one of the arguments of $op with word( )", e.position)
|
||||
}
|
||||
}
|
||||
|
||||
def scanExpression(e: Expression, willBeAssignedToWord: Boolean): Unit = {
|
||||
if (willBeAssignedToWord) {
|
||||
e match {
|
||||
case FunctionCallExpression("<<", List(l, r)) =>
|
||||
if (isByte(l) && isByte(r)) {
|
||||
(env.eval(l), env.eval(r)) match {
|
||||
case (Some(NumericConstant(lc, 1)), Some(NumericConstant(rc, 1))) =>
|
||||
if (lc >= 0 && rc >= 0 && (lc << rc) > 255) {
|
||||
warnConstantOverflow(e, "<<")
|
||||
}
|
||||
case (_, Some(NumericConstant(0, _))) =>
|
||||
case _ =>
|
||||
warnDynamicOverflow(e, "<<")
|
||||
}
|
||||
}
|
||||
case FunctionCallExpression("*", List(l, r)) =>
|
||||
if (isByte(l) && isByte(r)) {
|
||||
(env.eval(l), env.eval(r)) match {
|
||||
case (Some(NumericConstant(lc, 1)), Some(NumericConstant(rc, 1))) =>
|
||||
if (lc >= 0 && rc >= 0 && (lc * rc) > 255) {
|
||||
warnConstantOverflow(e, "*")
|
||||
}
|
||||
case (_, Some(NumericConstant(0, _))) =>
|
||||
case (_, Some(NumericConstant(1, _))) =>
|
||||
case (Some(NumericConstant(0, _)), _) =>
|
||||
case (Some(NumericConstant(1, _)), _) =>
|
||||
case _ =>
|
||||
warnDynamicOverflow(e, "*")
|
||||
}
|
||||
}
|
||||
case FunctionCallExpression("word" | "unsigned16" | "signed16" | "pointer", List(SumExpression(expressions, _))) =>
|
||||
if (expressions.map(_._2).forall(isByte)) {
|
||||
|
||||
}
|
||||
case _ =>
|
||||
}
|
||||
}
|
||||
e match {
|
||||
case SumExpression(expressions, decimal) =>
|
||||
if (willBeAssignedToWord && !decimal && isByte(e)) env.eval(e) match {
|
||||
case Some(NumericConstant(n, _)) if n < -128 || n > 255 =>
|
||||
warnConstantOverflow(e, "+")
|
||||
case _ =>
|
||||
}
|
||||
for ((_, e) <- expressions) {
|
||||
scanExpression(e, willBeAssignedToWord = willBeAssignedToWord)
|
||||
}
|
||||
case FunctionCallExpression("word" | "unsigned16" | "signed16" | "pointer", expressions) =>
|
||||
expressions.foreach(x => scanExpression(x, willBeAssignedToWord = true))
|
||||
case FunctionCallExpression("|" | "^" | "&" | "not", expressions) =>
|
||||
expressions.foreach(x => scanExpression(x, willBeAssignedToWord = false))
|
||||
case FunctionCallExpression(fname, expressions) =>
|
||||
env.maybeGet[Thing](fname) match {
|
||||
case Some(f: FunctionInMemory) if f.params.length == expressions.length =>
|
||||
for ((e, t) <- expressions zip f.params.types) {
|
||||
scanExpression(e, willBeAssignedToWord = isWord(t))
|
||||
}
|
||||
case _ =>
|
||||
for (e <- expressions) {
|
||||
scanExpression(e, willBeAssignedToWord = false)
|
||||
}
|
||||
}
|
||||
case _ =>
|
||||
}
|
||||
}
|
||||
|
||||
def detectOverflow(stmt: Statement): Unit = {
|
||||
stmt match {
|
||||
case Assignment(lhs, rhs) =>
|
||||
if (isWord(lhs)) scanExpression(rhs, willBeAssignedToWord = true)
|
||||
case v: VariableDeclarationStatement =>
|
||||
v.initialValue match {
|
||||
case Some(e) => scanExpression(e, willBeAssignedToWord = isWord(v.typ))
|
||||
case _ =>
|
||||
}
|
||||
case s =>
|
||||
s.getAllExpressions.foreach(e => scanExpression(e, willBeAssignedToWord = false))
|
||||
}
|
||||
}
|
||||
}
|
@ -465,4 +465,21 @@ class ByteMathSuite extends FunSuite with Matchers with AppendedClues {
|
||||
m.readByte(0xc000) should equal(125)
|
||||
}
|
||||
}
|
||||
|
||||
test("Optimal multiplication detection") {
|
||||
EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Motorola6809)(
|
||||
"""
|
||||
| import zp_reg
|
||||
| word output @$c000
|
||||
| noinline void run(byte a, byte b) {
|
||||
| output = word(a) * b
|
||||
| }
|
||||
| void main () {
|
||||
| run(100, 42)
|
||||
| }
|
||||
""".
|
||||
stripMargin) { m =>
|
||||
m.readWord(0xc000) should equal(4200)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -55,4 +55,36 @@ class WarningSuite extends FunSuite with Matchers {
|
||||
""".stripMargin) { m =>
|
||||
}
|
||||
}
|
||||
|
||||
test("Warn about unintended byte overflow") {
|
||||
EmuUnoptimizedCrossPlatformRun(Cpu.Mos)(
|
||||
"""
|
||||
| import zp_reg
|
||||
| const word screenOffset = (10*40)+5
|
||||
| noinline void func(byte x, byte y) {
|
||||
| word screenOffset
|
||||
| screenOffset = (x*40) + y
|
||||
| }
|
||||
| noinline word getNESScreenOffset(byte x, byte y) {
|
||||
| word temp
|
||||
| temp = (y << 5) +x
|
||||
| }
|
||||
| noinline word getSomeFunc(byte x, byte y, byte z) {
|
||||
| word temp
|
||||
| temp = ((x + z) << 2) + (y << 5)
|
||||
| temp = byte((x + z) << 2) + (y << 5)
|
||||
| }
|
||||
|
|
||||
| noinline byte someFunc(byte x, byte y) {
|
||||
| return (x*y)-24
|
||||
| }
|
||||
| void main() {
|
||||
| func(0,0)
|
||||
| getNESScreenOffset(0,0)
|
||||
| getSomeFunc(0,screenOffset.lo,5)
|
||||
| someFunc(0,0)
|
||||
| }
|
||||
""".stripMargin) { m =>
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user