mirror of
https://github.com/irmen/prog8.git
synced 2024-12-22 18:30:01 +00:00
replace most common subexpressions by a single temp variable
This commit is contained in:
parent
07feb5c925
commit
d1f8ee1e56
@ -31,7 +31,10 @@ sealed class PtNode(val position: Position) {
|
||||
}
|
||||
|
||||
|
||||
class PtNodeGroup : PtNode(Position.DUMMY)
|
||||
sealed interface IPtStatementContainer
|
||||
|
||||
|
||||
class PtNodeGroup : PtNode(Position.DUMMY), IPtStatementContainer
|
||||
|
||||
|
||||
sealed class PtNamedNode(var name: String, position: Position): PtNode(position) {
|
||||
@ -75,7 +78,7 @@ class PtBlock(name: String,
|
||||
val source: SourceCode, // taken from the module the block is defined in.
|
||||
val options: Options,
|
||||
position: Position
|
||||
) : PtNamedNode(name, position) {
|
||||
) : PtNamedNode(name, position), IPtStatementContainer {
|
||||
enum class BlockAlignment {
|
||||
NONE,
|
||||
WORD,
|
||||
|
@ -36,7 +36,14 @@ sealed class PtExpression(val type: DataType, position: Position) : PtNode(posit
|
||||
return arrayIndexExpr!! isSameAs other.arrayIndexExpr!!
|
||||
}
|
||||
is PtArrayIndexer -> other is PtArrayIndexer && other.type==type && other.variable isSameAs variable && other.index isSameAs index && other.splitWords==splitWords
|
||||
is PtBinaryExpression -> other is PtBinaryExpression && other.left isSameAs left && other.right isSameAs right
|
||||
is PtBinaryExpression -> {
|
||||
if(other !is PtBinaryExpression || other.operator!=operator)
|
||||
false
|
||||
else if(operator in AssociativeOperators)
|
||||
(other.left isSameAs left && other.right isSameAs right) || (other.left isSameAs right && other.right isSameAs left)
|
||||
else
|
||||
other.left isSameAs left && other.right isSameAs right
|
||||
}
|
||||
is PtContainmentCheck -> other is PtContainmentCheck && other.type==type && other.element isSameAs element && other.iterable isSameAs iterable
|
||||
is PtIdentifier -> other is PtIdentifier && other.type==type && other.name==name
|
||||
is PtMachineRegister -> other is PtMachineRegister && other.type==type && other.register==register
|
||||
|
@ -23,7 +23,7 @@ class PtSub(
|
||||
val parameters: List<PtSubroutineParameter>,
|
||||
val returntype: DataType?,
|
||||
position: Position
|
||||
) : PtNamedNode(name, position), IPtSubroutine {
|
||||
) : PtNamedNode(name, position), IPtSubroutine, IPtStatementContainer {
|
||||
init {
|
||||
// params and return value should not be str
|
||||
if(parameters.any{ it.type !in NumericDatatypes })
|
||||
|
120
codeCore/src/prog8/code/optimize/Optimizer.kt
Normal file
120
codeCore/src/prog8/code/optimize/Optimizer.kt
Normal file
@ -0,0 +1,120 @@
|
||||
package prog8.code.optimize
|
||||
|
||||
import prog8.code.ast.*
|
||||
import prog8.code.core.*
|
||||
|
||||
|
||||
fun optimizeIntermediateAst(program: PtProgram, options: CompilationOptions, errors: IErrorReporter) {
|
||||
if (!options.optimize)
|
||||
return
|
||||
while(errors.noErrors() && optimizeCommonSubExpressions(program, errors)>0) {
|
||||
// keep rolling
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private fun walkAst(root: PtNode, act: (node: PtNode, depth: Int) -> Boolean) {
|
||||
fun recurse(node: PtNode, depth: Int) {
|
||||
if(act(node, depth))
|
||||
node.children.forEach { recurse(it, depth+1) }
|
||||
}
|
||||
recurse(root, 0)
|
||||
}
|
||||
|
||||
|
||||
private fun optimizeCommonSubExpressions(program: PtProgram, errors: IErrorReporter): Int {
|
||||
|
||||
fun extractableSubExpr(expr: PtExpression): Boolean {
|
||||
return if(expr is PtBinaryExpression)
|
||||
!expr.left.isSimple() || !expr.right.isSimple() || (expr.operator !in LogicalOperators && expr.operator !in BitwiseOperators)
|
||||
else
|
||||
!expr.isSimple()
|
||||
}
|
||||
|
||||
// for each Binaryexpression, recurse to find a common subexpression pair therein.
|
||||
val commons = mutableMapOf<PtBinaryExpression, Pair<PtExpression, PtExpression>>()
|
||||
walkAst(program) { node: PtNode, depth: Int ->
|
||||
if(node is PtBinaryExpression) {
|
||||
val subExpressions = mutableListOf<PtExpression>()
|
||||
walkAst(node.left) { subNode: PtNode, subDepth: Int ->
|
||||
if (subNode is PtExpression) {
|
||||
if(extractableSubExpr(subNode)) subExpressions.add(subNode)
|
||||
true
|
||||
} else false
|
||||
}
|
||||
walkAst(node.right) { subNode: PtNode, subDepth: Int ->
|
||||
if (subNode is PtExpression) {
|
||||
if(extractableSubExpr(subNode)) subExpressions.add(subNode)
|
||||
true
|
||||
} else false
|
||||
}
|
||||
|
||||
outer@for (first in subExpressions) {
|
||||
for (second in subExpressions) {
|
||||
if (first!==second && first isSameAs second) {
|
||||
commons[node] = first to second
|
||||
break@outer // do only 1 replacement at a time per binaryexpression
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
} else true
|
||||
}
|
||||
|
||||
// replace common subexpressions by a temp variable that is assigned only once.
|
||||
commons.forEach { binexpr, (occurrence1, occurrence2) ->
|
||||
val (stmtContainer, stmt) = findContainingStatements(binexpr)
|
||||
val occurrence1idx = occurrence1.parent.children.indexOf(occurrence1)
|
||||
val occurrence2idx = occurrence2.parent.children.indexOf(occurrence2)
|
||||
val containerScopedName = findScopeName(stmtContainer)
|
||||
val tempvarName = "subexprvar_line${binexpr.position.line}_${binexpr.hashCode().toUInt()}"
|
||||
// TODO: some tempvars could be reused, if they are from different lines
|
||||
|
||||
val datatype = occurrence1.type
|
||||
val singleReplacement1 = PtIdentifier("$containerScopedName.$tempvarName", datatype, occurrence1.position)
|
||||
val singleReplacement2 = PtIdentifier("$containerScopedName.$tempvarName", datatype, occurrence2.position)
|
||||
occurrence1.parent.children[occurrence1idx] = singleReplacement1
|
||||
singleReplacement1.parent = occurrence1.parent
|
||||
occurrence2.parent.children[occurrence2idx] = singleReplacement2
|
||||
singleReplacement2.parent = occurrence2.parent
|
||||
|
||||
val tempassign = PtAssignment(binexpr.position).also { assign ->
|
||||
assign.add(PtAssignTarget(binexpr.position).also { tgt->
|
||||
tgt.add(PtIdentifier("$containerScopedName.$tempvarName", datatype, binexpr.position))
|
||||
})
|
||||
assign.add(occurrence1)
|
||||
occurrence1.parent = assign
|
||||
}
|
||||
stmtContainer.children.add(stmtContainer.children.indexOf(stmt), tempassign)
|
||||
tempassign.parent = stmtContainer
|
||||
|
||||
val tempvar = PtVariable(tempvarName, datatype, ZeropageWish.DONTCARE, null, null, binexpr.position)
|
||||
stmtContainer.add(0, tempvar)
|
||||
tempvar.parent = stmtContainer
|
||||
|
||||
errors.info("common subexpressions replaced by a tempvar, maybe simplify the expression manually", binexpr.position)
|
||||
}
|
||||
|
||||
return commons.size
|
||||
}
|
||||
|
||||
|
||||
internal fun findScopeName(node: PtNode): String {
|
||||
var parent=node
|
||||
while(parent !is PtNamedNode)
|
||||
parent = parent.parent
|
||||
return parent.scopedName
|
||||
}
|
||||
|
||||
|
||||
internal fun findContainingStatements(node: PtNode): Pair<PtNode, PtNode> { // returns (parentstatementcontainer, childstatement)
|
||||
var parent = node.parent
|
||||
var child = node
|
||||
while(true) {
|
||||
if(parent is IPtStatementContainer) {
|
||||
return parent to child
|
||||
}
|
||||
child=parent
|
||||
parent=parent.parent
|
||||
}
|
||||
}
|
@ -10,6 +10,7 @@ import prog8.ast.statements.Directive
|
||||
import prog8.code.SymbolTableMaker
|
||||
import prog8.code.ast.PtProgram
|
||||
import prog8.code.core.*
|
||||
import prog8.code.optimize.optimizeIntermediateAst
|
||||
import prog8.code.target.*
|
||||
import prog8.codegen.vm.VmCodeGen
|
||||
import prog8.compiler.astprocessing.*
|
||||
@ -102,7 +103,6 @@ fun compileProgram(args: CompilerArguments): CompilationResult? {
|
||||
compilationOptions,
|
||||
args.errors,
|
||||
BuiltinFunctionsFacade(BuiltinFunctions),
|
||||
compTarget
|
||||
)
|
||||
}
|
||||
postprocessAst(program, args.errors, compilationOptions)
|
||||
@ -124,6 +124,9 @@ fun compileProgram(args: CompilerArguments): CompilationResult? {
|
||||
// printProgram(program)
|
||||
|
||||
val intermediateAst = IntermediateAstMaker(program, args.errors).transform()
|
||||
// printAst(intermediateAst, true) { println(it) }
|
||||
optimizeIntermediateAst(intermediateAst, compilationOptions, args.errors)
|
||||
args.errors.report()
|
||||
|
||||
// println("*********** AST RIGHT BEFORE ASM GENERATION *************")
|
||||
// printAst(intermediateAst, true, ::println)
|
||||
@ -378,13 +381,13 @@ private fun processAst(program: Program, errors: IErrorReporter, compilerOptions
|
||||
errors.report()
|
||||
}
|
||||
|
||||
private fun optimizeAst(program: Program, compilerOptions: CompilationOptions, errors: IErrorReporter, functions: IBuiltinFunctions, compTarget: ICompilationTarget) {
|
||||
val remover = UnusedCodeRemover(program, errors, compTarget)
|
||||
private fun optimizeAst(program: Program, compilerOptions: CompilationOptions, errors: IErrorReporter, functions: IBuiltinFunctions) {
|
||||
val remover = UnusedCodeRemover(program, errors, compilerOptions.compTarget)
|
||||
remover.visit(program)
|
||||
remover.applyModifications()
|
||||
while (true) {
|
||||
// keep optimizing expressions and statements until no more steps remain
|
||||
val optsDone1 = program.simplifyExpressions(errors, compTarget)
|
||||
val optsDone1 = program.simplifyExpressions(errors, compilerOptions.compTarget)
|
||||
val optsDone2 = program.optimizeStatements(errors, functions, compilerOptions)
|
||||
val optsDone3 = program.inlineSubroutines(compilerOptions)
|
||||
program.constantFold(errors, compilerOptions) // because simplified statements and expressions can result in more constants that can be folded away
|
||||
@ -395,7 +398,7 @@ private fun optimizeAst(program: Program, compilerOptions: CompilationOptions, e
|
||||
if (optsDone1 + optsDone2 + optsDone3 == 0)
|
||||
break
|
||||
}
|
||||
val remover2 = UnusedCodeRemover(program, errors, compTarget)
|
||||
val remover2 = UnusedCodeRemover(program, errors, compilerOptions.compTarget)
|
||||
remover2.visit(program)
|
||||
remover2.applyModifications()
|
||||
if(errors.noErrors())
|
||||
|
@ -6,6 +6,7 @@ import io.kotest.matchers.ints.shouldBeGreaterThan
|
||||
import io.kotest.matchers.shouldBe
|
||||
import prog8.code.ast.*
|
||||
import prog8.code.core.DataType
|
||||
import prog8.code.core.Position
|
||||
import prog8.code.target.C64Target
|
||||
import prog8.compiler.astprocessing.IntermediateAstMaker
|
||||
import prog8tests.helpers.ErrorReporterForTests
|
||||
@ -60,4 +61,31 @@ class TestIntermediateAst: FunSpec({
|
||||
fcall.type shouldBe DataType.UBYTE
|
||||
}
|
||||
|
||||
test("isSame on binaryExpressions") {
|
||||
val expr1 = PtBinaryExpression("/", DataType.UBYTE, Position.DUMMY)
|
||||
expr1.add(PtNumber(DataType.UBYTE, 1.0, Position.DUMMY))
|
||||
expr1.add(PtNumber(DataType.UBYTE, 2.0, Position.DUMMY))
|
||||
val expr2 = PtBinaryExpression("/", DataType.UBYTE, Position.DUMMY)
|
||||
expr2.add(PtNumber(DataType.UBYTE, 1.0, Position.DUMMY))
|
||||
expr2.add(PtNumber(DataType.UBYTE, 2.0, Position.DUMMY))
|
||||
(expr1 isSameAs expr2) shouldBe true
|
||||
val expr3 = PtBinaryExpression("/", DataType.UBYTE, Position.DUMMY)
|
||||
expr3.add(PtNumber(DataType.UBYTE, 2.0, Position.DUMMY))
|
||||
expr3.add(PtNumber(DataType.UBYTE, 1.0, Position.DUMMY))
|
||||
(expr1 isSameAs expr3) shouldBe false
|
||||
}
|
||||
|
||||
test("isSame on binaryExpressions with associative operators") {
|
||||
val expr1 = PtBinaryExpression("+", DataType.UBYTE, Position.DUMMY)
|
||||
expr1.add(PtNumber(DataType.UBYTE, 1.0, Position.DUMMY))
|
||||
expr1.add(PtNumber(DataType.UBYTE, 2.0, Position.DUMMY))
|
||||
val expr2 = PtBinaryExpression("+", DataType.UBYTE, Position.DUMMY)
|
||||
expr2.add(PtNumber(DataType.UBYTE, 1.0, Position.DUMMY))
|
||||
expr2.add(PtNumber(DataType.UBYTE, 2.0, Position.DUMMY))
|
||||
(expr1 isSameAs expr2) shouldBe true
|
||||
val expr3 = PtBinaryExpression("+", DataType.UBYTE, Position.DUMMY)
|
||||
expr3.add(PtNumber(DataType.UBYTE, 2.0, Position.DUMMY))
|
||||
expr3.add(PtNumber(DataType.UBYTE, 1.0, Position.DUMMY))
|
||||
(expr1 isSameAs expr3) shouldBe true
|
||||
}
|
||||
})
|
@ -511,5 +511,33 @@ main {
|
||||
val value = (st[5] as Assignment).value as BinaryExpression
|
||||
value.operator shouldBe "%"
|
||||
}
|
||||
|
||||
test("isSame on binary expressions") {
|
||||
val left1 = NumericLiteral.optimalInteger(1, Position.DUMMY)
|
||||
val right1 = NumericLiteral.optimalInteger(2, Position.DUMMY)
|
||||
val expr1 = BinaryExpression(left1, "/", right1, Position.DUMMY)
|
||||
val left2 = NumericLiteral.optimalInteger(1, Position.DUMMY)
|
||||
val right2 = NumericLiteral.optimalInteger(2, Position.DUMMY)
|
||||
val expr2 = BinaryExpression(left2, "/", right2, Position.DUMMY)
|
||||
(expr1 isSameAs expr2) shouldBe true
|
||||
val left3 = NumericLiteral.optimalInteger(2, Position.DUMMY)
|
||||
val right3 = NumericLiteral.optimalInteger(1, Position.DUMMY)
|
||||
val expr3 = BinaryExpression(left3, "/", right3, Position.DUMMY)
|
||||
(expr1 isSameAs expr3) shouldBe false
|
||||
}
|
||||
|
||||
test("isSame on binary expressions with associative operators") {
|
||||
val left1 = NumericLiteral.optimalInteger(1, Position.DUMMY)
|
||||
val right1 = NumericLiteral.optimalInteger(2, Position.DUMMY)
|
||||
val expr1 = BinaryExpression(left1, "+", right1, Position.DUMMY)
|
||||
val left2 = NumericLiteral.optimalInteger(1, Position.DUMMY)
|
||||
val right2 = NumericLiteral.optimalInteger(2, Position.DUMMY)
|
||||
val expr2 = BinaryExpression(left2, "+", right2, Position.DUMMY)
|
||||
(expr1 isSameAs expr2) shouldBe true
|
||||
val left3 = NumericLiteral.optimalInteger(2, Position.DUMMY)
|
||||
val right3 = NumericLiteral.optimalInteger(1, Position.DUMMY)
|
||||
val expr3 = BinaryExpression(left3, "+", right3, Position.DUMMY)
|
||||
(expr1 isSameAs expr3) shouldBe true
|
||||
}
|
||||
})
|
||||
|
||||
|
@ -34,11 +34,14 @@ sealed class Expression: Node {
|
||||
(other is IdentifierReference && other.nameInSource==nameInSource)
|
||||
is PrefixExpression ->
|
||||
(other is PrefixExpression && other.operator==operator && other.expression isSameAs expression)
|
||||
is BinaryExpression ->
|
||||
(other is BinaryExpression && other.operator==operator
|
||||
&& other.left isSameAs left
|
||||
&& other.right isSameAs right
|
||||
&& other.isChainedComparison() == isChainedComparison())
|
||||
is BinaryExpression -> {
|
||||
if(other !is BinaryExpression || other.operator!=operator || other.isChainedComparison()!=isChainedComparison())
|
||||
false
|
||||
else if(operator in AssociativeOperators)
|
||||
(other.left isSameAs left && other.right isSameAs right) || (other.left isSameAs right && other.right isSameAs left)
|
||||
else
|
||||
other.left isSameAs left && other.right isSameAs right
|
||||
}
|
||||
is ArrayIndexedExpression -> {
|
||||
(other is ArrayIndexedExpression && other.arrayvar.nameInSource == arrayvar.nameInSource
|
||||
&& other.indexer isSameAs indexer)
|
||||
|
@ -1,10 +1,18 @@
|
||||
|
||||
TODO
|
||||
====
|
||||
|
||||
PtAst/IR: attempt more complex common subexpression eliminations.
|
||||
for any "top level" PtExpression enumerate all subexpressions and find commons, replace them by a tempvar
|
||||
for walking the ast see walkAst() but it should not recurse into the "top level" PtExpression again
|
||||
- why is the right term of cx16.r0 = (cx16.r1+cx16.r2) + (cx16.r1+cx16.r2) flipped around but the left term isn't?
|
||||
|
||||
- Revert or fix current "desugar chained comparisons" it causes problems with if statements.
|
||||
ubyte @shared n=20
|
||||
ubyte @shared L1=10
|
||||
ubyte @shared L2=100
|
||||
|
||||
if n < L1 {
|
||||
;txt.print("bing")
|
||||
} else {
|
||||
txt.print("boom") ; no longer triggers
|
||||
}
|
||||
|
||||
...
|
||||
|
||||
@ -36,7 +44,6 @@ Compiler:
|
||||
global initialization values are simply a list of LOAD instructions.
|
||||
Variables replaced include all subroutine parameters! So the only variables that remain as variables are arrays and strings.
|
||||
- ir: add more optimizations in IRPeepholeOptimizer
|
||||
- ir: for expressions with array indexes that occur multiple times, can we avoid loading them into new virtualregs everytime and just reuse a single virtualreg as indexer? (this is a form of common subexpression elimination)
|
||||
- ir: the @split arrays are currently also split in _lsb/_msb arrays in the IR, and operations take multiple (byte) instructions that may lead to verbose and slow operation and machine code generation down the line.
|
||||
maybe another representation is needed once actual codegeneration is done from the IR...?
|
||||
- [problematic due to using 64tass:] better support for building library programs, where unused .proc shouldn't be deleted from the assembly?
|
||||
|
@ -1,13 +1,14 @@
|
||||
%import textio
|
||||
%zeropage basicsafe
|
||||
|
||||
; Note: this program can be compiled for multiple target systems.
|
||||
|
||||
main {
|
||||
sub start() {
|
||||
cx16.r0L=1
|
||||
while cx16.r0L < 10 and cx16.r0L>0 {
|
||||
cx16.r0L++
|
||||
}
|
||||
ubyte[10] array1
|
||||
ubyte[10] array2
|
||||
ubyte @shared xx
|
||||
|
||||
cx16.r0 = (cx16.r1+cx16.r2) / (cx16.r2+cx16.r1)
|
||||
cx16.r1 = 4*(cx16.r1+cx16.r2) + 3*(cx16.r1+cx16.r2)
|
||||
cx16.r2 = array1[xx+20]==10 or array2[xx+20]==20 or array1[xx+20]==30 or array2[xx+20]==40
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user