1
0
mirror of https://github.com/KarolS/millfork.git synced 2025-01-15 09:29:49 +00:00

Related to #119:

– Detection of simple byte overflow cases.
– Optimization of 8×8→16 multiplication on 6809.
– Multiplication optimizations on Z80.
This commit is contained in:
Karol Stasiak 2021-08-06 21:01:03 +02:00
parent 7f6a0c6b0d
commit 90e5360bfd
12 changed files with 252 additions and 9 deletions

View File

@ -421,6 +421,7 @@ object Cpu extends Enumeration {
EnableBreakpoints,
UseOptimizationHints,
GenericWarnings,
ByteOverflowWarning,
UselessCodeWarning,
BuggyCodeWarning,
FallbackValueUseWarning,
@ -585,6 +586,7 @@ object CompilationFlag extends Enumeration {
SingleThreaded,
// warning options
GenericWarnings,
ByteOverflowWarning,
UselessCodeWarning,
BuggyCodeWarning,
DeprecationWarning,
@ -603,6 +605,7 @@ object CompilationFlag extends Enumeration {
val allWarnings: Set[CompilationFlag.Value] = Set(
GenericWarnings,
ByteOverflowWarning,
UselessCodeWarning,
BuggyCodeWarning,
DeprecationWarning,

View File

@ -818,6 +818,10 @@ object Main {
c.changeFlag(CompilationFlag.RorWarning, v)
}.description("Whether should warn about the ROR instruction (6502 only). Default: disabled.")
boolean("-Woverflow", "-Wno-overflow").repeatable().action { (c, v) =>
c.changeFlag(CompilationFlag.ByteOverflowWarning, v)
}.description("Whether should warn about byte overflow. Default: enabled.")
boolean("-Wuseless", "-Wno-useless").repeatable().action { (c, v) =>
c.changeFlag(CompilationFlag.UselessCodeWarning, v)
}.description("Whether should warn about code that does nothing. Default: enabled.")

View File

@ -7,6 +7,8 @@ import millfork.error.{ConsoleLogger, Logger}
import millfork.assembly.AbstractCode
import millfork.output.NoAlignment
import scala.collection.mutable.ListBuffer
/**
* @author Karol Stasiak
*/
@ -14,6 +16,33 @@ class AbstractExpressionCompiler[T <: AbstractCode] {
def getExpressionType(ctx: CompilationContext, expr: Expression): Type = AbstractExpressionCompiler.getExpressionType(ctx, expr)
def extractWordExpandedBytes(ctx: CompilationContext, params:List[Expression]): Option[List[Expression]] = {
val result = ListBuffer[Expression]()
for(param <- params) {
if (ctx.env.eval(param).isDefined) return None
AbstractExpressionCompiler.getExpressionType(ctx, param) match {
case t: PlainType if t.size == 1 && !t.isSigned =>
result += param
case t: PlainType if t.size == 2 =>
param match {
case FunctionCallExpression(functionName, List(inner)) =>
AbstractExpressionCompiler.getExpressionType(ctx, inner) match {
case t: PlainType if t.size == 1 && !t.isSigned =>
ctx.env.maybeGet[Type](functionName) match {
case Some(tw: PlainType) if tw.size == 2 =>
result += inner
case _ => return None
}
case _ => return None
}
case _ => return None
}
case _ => return None
}
}
Some(result.toList)
}
def assertAllArithmetic(ctx: CompilationContext,expressions: List[Expression], booleanHint: String = ""): Unit = {
for(e <- expressions) {
val typ = getExpressionType(ctx, e)

View File

@ -150,6 +150,7 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte
case _ =>
}
new OverflowDetector(ctx).detectOverflow(stmt)
stmt match {
case Assignment(ve@VariableExpression(v), arg) if trackableVars(v) =>
cv = search(arg, cv)

View File

@ -7,7 +7,7 @@ import millfork.assembly.m6809.{Absolute, DAccumulatorIndexed, Immediate, Indexe
import millfork.compiler.{AbstractExpressionCompiler, BranchIfFalse, BranchIfTrue, BranchSpec, ComparisonType, CompilationContext, NoBranching}
import millfork.node.{DerefExpression, Expression, FunctionCallExpression, GeneratedConstantExpression, IndexedExpression, LhsExpression, LiteralExpression, M6809Register, SeparateBytesExpression, SumExpression, VariableExpression}
import millfork.assembly.m6809.MOpcode._
import millfork.env.{AssemblyOrMacroParamSignature, BuiltInBooleanType, Constant, ConstantBooleanType, ConstantPointy, ExternFunction, FatBooleanType, FlagBooleanType, FunctionInMemory, FunctionPointerType, KernalInterruptPointerType, Label, M6809RegisterVariable, MacroFunction, MathOperator, MemoryAddressConstant, MemoryVariable, NonFatalCompilationException, NormalFunction, NormalParamSignature, NumericConstant, StackOffsetThing, StackVariable, StackVariablePointy, StructureConstant, Thing, ThingInMemory, Type, Variable, VariableInMemory, VariableLikeThing, VariablePointy, VariableType}
import millfork.env.{AssemblyOrMacroParamSignature, BuiltInBooleanType, Constant, ConstantBooleanType, ConstantPointy, ExternFunction, FatBooleanType, FlagBooleanType, FunctionInMemory, FunctionPointerType, KernalInterruptPointerType, Label, M6809RegisterVariable, MacroFunction, MathOperator, MemoryAddressConstant, MemoryVariable, NonFatalCompilationException, NormalFunction, NormalParamSignature, NumericConstant, PlainType, StackOffsetThing, StackVariable, StackVariablePointy, StructureConstant, Thing, ThingInMemory, Type, Variable, VariableInMemory, VariableLikeThing, VariablePointy, VariableType}
import scala.collection.GenTraversableOnce
@ -292,7 +292,13 @@ object M6809ExpressionCompiler extends AbstractExpressionCompiler[MLine] {
assertSizesForMultiplication(ctx, params, inPlace = false)
getArithmeticParamMaxSize(ctx, params) match {
case 1 => M6809MulDiv.compileByteMultiplication(ctx, params, updateDerefX = false) ++ targetifyB(ctx, target, isSigned = false)
case 2 => M6809MulDiv.compileWordMultiplication(ctx, params, updateDerefX = false) ++ targetifyD(ctx, target)
case 2 =>
extractWordExpandedBytes(ctx, params) match {
case Some(byteParams) if byteParams.size == 2 =>
M6809MulDiv.compileByteMultiplication(ctx, byteParams, updateDerefX = false) ++ targetifyD(ctx, target)
case _ =>
M6809MulDiv.compileWordMultiplication(ctx, params, updateDerefX = false) ++ targetifyD(ctx, target)
}
case 0 => Nil
case _ =>
ctx.log.error("Multiplication of variables larger than 2 bytes is not supported", expr.position)

View File

@ -15,9 +15,9 @@ import scala.collection.mutable.ListBuffer
*/
object M6809MulDiv {
def compileByteMultiplication(ctx: CompilationContext, params: List[Expression], updateDerefX: Boolean): List[MLine] = {
def compileByteMultiplication(ctx: CompilationContext, params: List[Expression], updateDerefX: Boolean, forceMul: Boolean = false): List[MLine] = {
var constant = Constant.One
val variablePart = params.flatMap { p =>
val variablePart = if(forceMul) params else params.flatMap { p =>
ctx.env.eval(p) match {
case Some(c) =>
constant = CompoundConstant(MathOperator.Times, constant, c).quickSimplify

View File

@ -730,11 +730,23 @@ object PseudoregisterBuiltIns {
case (1, 1) => // ok
case _ => ctx.log.fatal("Invalid code path", param2.position)
}
val b = ctx.env.get[Type]("byte")
val w = ctx.env.get[Type]("word")
val reg = ctx.env.get[VariableInMemory]("__reg")
if (!storeInRegLo && param1OrRegister.isDefined) {
(ctx.env.eval(param1OrRegister.get), ctx.env.eval(param2)) match {
case (Some(l), Some(r)) =>
val product = CompoundConstant(MathOperator.Times, l, r).quickSimplify
return List(AssemblyLine.immediate(LDA, product.loByte), AssemblyLine.immediate(LDX, product.hiByte))
case (Some(NumericConstant(2, _)), _) =>
val evalParam2 = MosExpressionCompiler.compile(ctx, param2, Some(b -> RegisterVariable(MosRegister.A, b)), BranchSpec.None)
val label = ctx.nextLabel("sh")
return evalParam2 ++ List(
AssemblyLine.implied(ASL),
AssemblyLine.immediate(LDX, 0),
AssemblyLine.relative(BCC, label),
AssemblyLine.implied(INX),
AssemblyLine.label(label))
case (Some(NumericConstant(c, _)), _) if isPowerOfTwoUpTo15(c)=>
return compileWordShiftOps(left = true, ctx, param2, LiteralExpression(java.lang.Long.bitCount(c - 1), 1))
case (_, Some(NumericConstant(c, _))) if isPowerOfTwoUpTo15(c)=>
@ -742,9 +754,6 @@ object PseudoregisterBuiltIns {
case _ =>
}
}
val b = ctx.env.get[Type]("byte")
val w = ctx.env.get[Type]("word")
val reg = ctx.env.get[VariableInMemory]("__reg")
val load: List[AssemblyLine] = param1OrRegister match {
case Some(param1) =>
val code1 = MosExpressionCompiler.compile(ctx, param1, Some(w -> RegisterVariable(MosRegister.AX, w)), BranchSpec.None)

View File

@ -326,8 +326,14 @@ object Z80Multiply {
case (1, 1) => // ok
case _ => ctx.log.fatal("Invalid code path", l.position)
}
ctx.env.eval(r) match {
case Some(c) =>
(ctx.env.eval(l), ctx.env.eval(r)) match {
case (Some(p), Some(q)) =>
List(ZLine.ldImm16(ZRegister.HL, CompoundConstant(MathOperator.Times, p, q).quickSimplify))
case (Some(NumericConstant(c, _)), _) if isPowerOfTwoUpTo15(c) =>
Z80ExpressionCompiler.compileToHL(ctx, l) ++ List.fill(Integer.numberOfTrailingZeros(c.toInt))(ZLine.registers(ZOpcode.ADD_16, ZRegister.HL, ZRegister.HL))
case (_, Some(NumericConstant(c, _))) if isPowerOfTwoUpTo15(c) =>
Z80ExpressionCompiler.compileToHL(ctx, l) ++ List.fill(Integer.numberOfTrailingZeros(c.toInt))(ZLine.registers(ZOpcode.ADD_16, ZRegister.HL, ZRegister.HL))
case (_, Some(c)) =>
Z80ExpressionCompiler.compileToDE(ctx, l) ++ List(ZLine.ldImm8(ZRegister.A, c)) ++ multiplication16And8(ctx)
case _ =>
val lw = Z80ExpressionCompiler.compileToDE(ctx, l)

View File

@ -1929,6 +1929,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
}
def registerArray(stmt: ArrayDeclarationStatement, options: CompilationOptions): Unit = {
new OverflowDetector(this, options).detectOverflow(stmt)
if (options.flag(CompilationFlag.LUnixRelocatableCode) && stmt.alignment.exists(_.isMultiplePages)) {
log.error("Invalid alignment for LUnix code", stmt.position)
}
@ -2090,6 +2091,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
}
def registerVariable(stmt: VariableDeclarationStatement, options: CompilationOptions, isPointy: Boolean): Unit = {
new OverflowDetector(this, options).detectOverflow(stmt)
val name = stmt.name
val position = stmt.position
if (name == "" || name.contains(".") && !name.contains(".return")) {

View File

@ -0,0 +1,134 @@
package millfork.env
import millfork.{CompilationFlag, CompilationOptions}
import millfork.compiler.{AbstractExpressionCompiler, CompilationContext}
import millfork.error.Logger
import millfork.node._
/**
* @author Karol Stasiak
*/
class OverflowDetector(env: Environment, options: CompilationOptions) {
def this(ctx: CompilationContext) {
this(ctx.env, ctx.options)
}
private def log: Logger = options.log
private def isWord(e: Expression): Boolean =
AbstractExpressionCompiler.getExpressionType(env, log, e) match {
case t: PlainType => t.size == 2
case _ => false
}
private def isWord(typeName: String): Boolean =
env.maybeGet[Thing](typeName) match {
case Some(t: PlainType) => t.size == 2
case _ => false
}
private def isWord(typ: Type): Boolean =
typ match {
case t: PlainType => t.size == 2
case _ => false
}
private def isByte(e: Expression): Boolean =
AbstractExpressionCompiler.getExpressionType(env, log, e) match {
case t: PlainType => t.size == 1
case _ => false
}
def warnConstantOverflow(e: Expression, op: String): Unit = {
if (options.flag(CompilationFlag.ByteOverflowWarning)) {
log.warn(s"Constant byte overflow. Consider wrapping one of the arguments of $op with word( )", e.position)
}
}
def warnDynamicOverflow(e: Expression, op: String): Unit = {
if (options.flag(CompilationFlag.ByteOverflowWarning)) {
log.warn(s"Potential byte overflow. Consider wrapping one of the arguments of $op with word( )", e.position)
}
}
def scanExpression(e: Expression, willBeAssignedToWord: Boolean): Unit = {
if (willBeAssignedToWord) {
e match {
case FunctionCallExpression("<<", List(l, r)) =>
if (isByte(l) && isByte(r)) {
(env.eval(l), env.eval(r)) match {
case (Some(NumericConstant(lc, 1)), Some(NumericConstant(rc, 1))) =>
if (lc >= 0 && rc >= 0 && (lc << rc) > 255) {
warnConstantOverflow(e, "<<")
}
case (_, Some(NumericConstant(0, _))) =>
case _ =>
warnDynamicOverflow(e, "<<")
}
}
case FunctionCallExpression("*", List(l, r)) =>
if (isByte(l) && isByte(r)) {
(env.eval(l), env.eval(r)) match {
case (Some(NumericConstant(lc, 1)), Some(NumericConstant(rc, 1))) =>
if (lc >= 0 && rc >= 0 && (lc * rc) > 255) {
warnConstantOverflow(e, "*")
}
case (_, Some(NumericConstant(0, _))) =>
case (_, Some(NumericConstant(1, _))) =>
case (Some(NumericConstant(0, _)), _) =>
case (Some(NumericConstant(1, _)), _) =>
case _ =>
warnDynamicOverflow(e, "*")
}
}
case FunctionCallExpression("word" | "unsigned16" | "signed16" | "pointer", List(SumExpression(expressions, _))) =>
if (expressions.map(_._2).forall(isByte)) {
}
case _ =>
}
}
e match {
case SumExpression(expressions, decimal) =>
if (willBeAssignedToWord && !decimal && isByte(e)) env.eval(e) match {
case Some(NumericConstant(n, _)) if n < -128 || n > 255 =>
warnConstantOverflow(e, "+")
case _ =>
}
for ((_, e) <- expressions) {
scanExpression(e, willBeAssignedToWord = willBeAssignedToWord)
}
case FunctionCallExpression("word" | "unsigned16" | "signed16" | "pointer", expressions) =>
expressions.foreach(x => scanExpression(x, willBeAssignedToWord = true))
case FunctionCallExpression("|" | "^" | "&" | "not", expressions) =>
expressions.foreach(x => scanExpression(x, willBeAssignedToWord = false))
case FunctionCallExpression(fname, expressions) =>
env.maybeGet[Thing](fname) match {
case Some(f: FunctionInMemory) if f.params.length == expressions.length =>
for ((e, t) <- expressions zip f.params.types) {
scanExpression(e, willBeAssignedToWord = isWord(t))
}
case _ =>
for (e <- expressions) {
scanExpression(e, willBeAssignedToWord = false)
}
}
case _ =>
}
}
def detectOverflow(stmt: Statement): Unit = {
stmt match {
case Assignment(lhs, rhs) =>
if (isWord(lhs)) scanExpression(rhs, willBeAssignedToWord = true)
case v: VariableDeclarationStatement =>
v.initialValue match {
case Some(e) => scanExpression(e, willBeAssignedToWord = isWord(v.typ))
case _ =>
}
case s =>
s.getAllExpressions.foreach(e => scanExpression(e, willBeAssignedToWord = false))
}
}
}

View File

@ -465,4 +465,21 @@ class ByteMathSuite extends FunSuite with Matchers with AppendedClues {
m.readByte(0xc000) should equal(125)
}
}
test("Optimal multiplication detection") {
EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Motorola6809)(
"""
| import zp_reg
| word output @$c000
| noinline void run(byte a, byte b) {
| output = word(a) * b
| }
| void main () {
| run(100, 42)
| }
""".
stripMargin) { m =>
m.readWord(0xc000) should equal(4200)
}
}
}

View File

@ -55,4 +55,36 @@ class WarningSuite extends FunSuite with Matchers {
""".stripMargin) { m =>
}
}
test("Warn about unintended byte overflow") {
EmuUnoptimizedCrossPlatformRun(Cpu.Mos)(
"""
| import zp_reg
| const word screenOffset = (10*40)+5
| noinline void func(byte x, byte y) {
| word screenOffset
| screenOffset = (x*40) + y
| }
| noinline word getNESScreenOffset(byte x, byte y) {
| word temp
| temp = (y << 5) +x
| }
| noinline word getSomeFunc(byte x, byte y, byte z) {
| word temp
| temp = ((x + z) << 2) + (y << 5)
| temp = byte((x + z) << 2) + (y << 5)
| }
|
| noinline byte someFunc(byte x, byte y) {
| return (x*y)-24
| }
| void main() {
| func(0,0)
| getNESScreenOffset(0,0)
| getSomeFunc(0,screenOffset.lo,5)
| someFunc(0,0)
| }
""".stripMargin) { m =>
}
}
}