package prog8.optimizer

import prog8.ast.IFunctionCall
import prog8.ast.IStatementContainer
import prog8.ast.Node
import prog8.ast.Program
import prog8.ast.base.FatalAstException
import prog8.ast.base.UndefinedSymbolError
import prog8.ast.expressions.*
import prog8.ast.statements.*
import prog8.ast.walk.AstWalker
import prog8.ast.walk.IAstModification
import prog8.code.core.*
import prog8.compiler.CallGraph

// NOTE(review): in this copy every walker hook declares its return type as plain `Iterable`.
// Kotlin requires a type argument there (presumably `Iterable<IAstModification>`, matching the
// `listOf(IAstModification...)` values the hooks return). Restore it from upstream before compiling.


/**
 * Fix up the literal value's type to match that of the vardecl
 * (also check range literal operands types before they get expanded into arrays for instance).
 *
 * In addition, this walker:
 *  - when `options.optimize` is set, removes unused user-code numeric variables and turns variables
 *    that are never written (other than by a constant initialization) into constants;
 *  - rewrites calls to the generic builtins `clamp`, `min`, `max`, `abs`, `sqrt` and `divmod` into
 *    their type-specific variants, chosen from the inferred argument type(s).
 *
 * Problems are reported through [errors]; removals/replacements are announced with `errors.info`.
 */
class VarConstantValueTypeAdjuster(
    private val program: Program,
    private val options: CompilationOptions,
    private val errors: IErrorReporter
) : AstWalker() {

    // Built lazily: only the optimizing branch of after(VarDecl) queries variable usages.
    private val callGraph by lazy { CallGraph(program) }

    /**
     * Post-visit hook for variable declarations.
     *
     * 1. For VAR and CONST declarations whose initializer is a compile-time constant of a different
     *    type than the declared one, the constant is cast to the declared type. A FLOAT constant for
     *    an integer-typed declaration is reported as an error instead. BOOL declarations are left alone.
     * 2. When optimizing, a user-code numeric VAR that is not shared with assembly and whose
     *    initializer (if any) is constant may be:
     *    - removed, when it is never used at all;
     *    - replaced by a CONST declaration, when it is never written;
     *    - replaced by a CONST declaration, when its only write is a VARINIT assignment of a constant
     *      value (that assignment is removed too).
     *
     * @return the AST modifications to apply, or [noModifications].
     * @throws FatalAstException when the declaration still lives inside an AnonymousScope.
     */
    override fun after(decl: VarDecl, parent: Node): Iterable {
        // Invariant check: var decls must not be inside an anonymous scope by the time this runs.
        if(decl.parent is AnonymousScope)
            throw FatalAstException("vardecl may no longer occur in anonymousscope")

        try {
            val declConstValue = decl.value?.constValue(program)
            if(declConstValue!=null && (decl.type== VarDeclType.VAR || decl.type==VarDeclType.CONST)
                && declConstValue.type != decl.datatype) {
                // avoid silent float roundings
                if(decl.datatype in IntegerDatatypes && declConstValue.type == DataType.FLOAT) {
                    errors.err("refused truncating of float to avoid loss of precision", decl.value!!.position)
                } else if(decl.datatype!=DataType.BOOL) {
                    // cast the numeric literal to the appropriate datatype of the variable if it's not boolean
                    declConstValue.linkParents(decl)
                    val cast = declConstValue.cast(decl.datatype, true)
                    if (cast.isValid)
                        return listOf(IAstModification.ReplaceNode(decl.value!!, cast.valueOrZero(), decl))
                    // An invalid cast produces neither a modification nor an error at this point.
                }
            }
        } catch (x: UndefinedSymbolError) {
            errors.err(x.message, x.position)
        }

        // replace variables by constants, if possible
        if(options.optimize) {
            // Only plain user-code numeric variables that assembly code cannot see are candidates.
            if (decl.sharedWithAsm || decl.type != VarDeclType.VAR || decl.origin != VarDeclOrigin.USERCODE || decl.datatype !in NumericDatatypes)
                return noModifications
            // A non-constant initializer cannot be folded into a CONST.
            if (decl.value != null && decl.value!!.constValue(program) == null)
                return noModifications

            val usages = callGraph.usages(decl)
            // Any usage that might modify the variable counts as a write; everything else is a read.
            val (writes, reads) = usages
                .partition {
                    it is InlineAssembly                // can't really tell if it's written to or only read, assume the worst
                            || it.parent is AssignTarget
                            || it.parent is ForLoop
                            || it.parent is AddressOf   // taking the address could allow writes through the pointer
                            || (it.parent as? IFunctionCall)?.target?.nameInSource?.singleOrNull() in InplaceModifyingBuiltinFunctions
                }

            // If there is exactly one write, find its Assignment: first try the usage's grandparent
            // (e.g. usage -> AssignTarget -> Assignment), then the usage's direct parent.
            val singleAssignment = writes.singleOrNull()?.parent?.parent as? Assignment ?: writes.singleOrNull()?.parent as? Assignment

            if (singleAssignment == null) {
                if (writes.isEmpty()) {
                    if(reads.isEmpty()) {
                        // variable is never used AT ALL so we just remove it altogether
                        if("ignore_unused" !in decl.definingBlock.options())
                            errors.info("removing unused variable '${decl.name}'", decl.position)
                        return listOf(IAstModification.Remove(decl, parent as IStatementContainer))
                    }
                    val declValue = decl.value?.constValue(program)
                    if (declValue != null) {
                        // variable is never written to, so it can be replaced with a constant, IF the value is a constant
                        errors.info("variable '${decl.name}' is never written to and was replaced by a constant", decl.position)
                        val const = VarDecl(VarDeclType.CONST, decl.origin, decl.datatype, decl.zeropage, decl.arraysize, decl.name, decl.names, declValue, decl.sharedWithAsm, decl.splitArray, decl.position)
                        // The old declaration's initializer is cleared; the new CONST decl carries the value.
                        decl.value = null
                        return listOf(
                            IAstModification.ReplaceNode(decl, const, parent)
                        )
                    }
                }
                // NOTE(review): more than one write (or a single write that is not an Assignment) leaves the variable as-is.
            } else {
                // VARINIT presumably marks the assignment generated from the declaration's initializer - confirm.
                if (singleAssignment.origin == AssignmentOrigin.VARINIT && singleAssignment.value.constValue(program) != null) {
                    if(reads.isEmpty()) {
                        // variable is never used AT ALL so we just remove it altogether, including the single assignment
                        if("ignore_unused" !in decl.definingBlock.options())
                            errors.info("removing unused variable '${decl.name}'", decl.position)
                        return listOf(
                            IAstModification.Remove(decl, parent as IStatementContainer),
                            IAstModification.Remove(singleAssignment, singleAssignment.parent as IStatementContainer)
                        )
                    }
                    // variable only has a single write and it is the initialization value, so it can be replaced with a constant, IF the value is a constant
                    errors.info("variable '${decl.name}' is never written to and was replaced by a constant", decl.position)
                    val const = VarDecl(VarDeclType.CONST, decl.origin, decl.datatype, decl.zeropage, decl.arraysize, decl.name, decl.names, singleAssignment.value, decl.sharedWithAsm, decl.splitArray, decl.position)
                    return listOf(
                        IAstModification.ReplaceNode(decl, const, parent),
                        IAstModification.Remove(singleAssignment, singleAssignment.parent as IStatementContainer)
                    )
                }
            }

            /* TODO: need to check if there are no variable usages between the declaration and the assignment (because these rely on the original initialization value)
            if(writes.size==2) {
                val firstAssignment = writes[0].parent as? Assignment
                val secondAssignment = writes[1].parent as? Assignment
                if(firstAssignment?.origin==AssignmentOrigin.VARINIT && secondAssignment?.value?.constValue(program)!=null) {
                    errors.warn("variable is only assigned once here, consider using this as the initialization value in the declaration instead", secondAssignment.position)
                }
            }
            */
        }
        return noModifications
    }

    /**
     * Post-visit hook for range expressions. It checks that the from, to and step operands are
     * integers and never modifies the AST.
     *
     * A constant operand is rejected when it has a positive fractional part. A non-constant operand
     * is rejected when its inferred type is not integer. For `to`, that check only applies when the
     * type is known.
     */
    override fun after(range: RangeExpression, parent: Node): Iterable {
        val from = range.from.constValue(program)?.number
        val to = range.to.constValue(program)?.number
        val step = range.step.constValue(program)?.number

        // NOTE(review): `x - x.toInt() > 0` only detects POSITIVE fractional values.
        // Kotlin's Double.toInt() truncates toward zero, so e.g. -1.5 - (-1) == -0.5 and no error is
        // reported for negative non-integers (assuming `number` is a Double). Something like
        // `x != kotlin.math.floor(x)` would cover both signs.
        if(from==null) {
            // NOTE(review): unlike the `to` check below, an unknown inferred type is also reported here.
            if(!range.from.inferType(program).isInteger)
                errors.err("range expression from value must be integer", range.from.position)
        } else if(from-from.toInt()>0) {
            errors.err("range expression from value must be integer", range.from.position)
        }

        if(to==null) {
            val toType = range.to.inferType(program)
            // NOTE(review): the type is inferred a second time here although toType already holds it.
            if(toType.isKnown && !range.to.inferType(program).isInteger)
                errors.err("range expression to value must be integer", range.to.position)
        } else if(to-to.toInt()>0) {
            errors.err("range expression to value must be integer", range.to.position)
        }

        if(step==null) {
            if(!range.step.inferType(program).isInteger)
                errors.err("range expression step value must be integer", range.step.position)
        } else if(step-step.toInt()>0) {
            errors.err("range expression step value must be integer", range.step.position)
        }

        return noModifications
    }

    /**
     * Post-visit hook for function call expressions. It rewrites calls to the generic builtins into
     * a type-specific variant, chosen from the inferred argument type(s):
     *  - `clamp` -> `clamp__byte` / `clamp__ubyte` / `clamp__word` / `clamp__uword` (first argument's
     *    type only; any other known type is an error).
     *  - `min`/`max` -> `<name>__byte` / `__ubyte` / `__word` / `__uword` (first two arguments;
     *    floats and non-numeric arguments are errors).
     *  - `abs` -> `abs__byte` / `abs__word` / `abs__float`. For UBYTE/UWORD the call is replaced by
     *    its argument.
     *  - `sqrt` -> `sqrt__ubyte` / `sqrt__uword` / `sqrt__float`; other types are errors.
     *
     * Nothing happens when the relevant argument type(s) cannot be inferred.
     */
    override fun after(functionCallExpr: FunctionCallExpression, parent: Node): Iterable {
        // choose specific builtin function for the given types
        val func = functionCallExpr.target.nameInSource
        if(func==listOf("clamp")) {
            val t1 = functionCallExpr.args[0].inferType(program)
            if(t1.isKnown) {
                val replaceFunc: String
                if(t1.isBytes) {
                    replaceFunc = if(t1.istype(DataType.BYTE)) "clamp__byte" else "clamp__ubyte"
                } else if(t1.isInteger) {
                    replaceFunc = if(t1.istype(DataType.WORD)) "clamp__word" else "clamp__uword"
                } else {
                    // NOTE(review): this branch is taken for any known non-integer type, not just floats.
                    errors.err("clamp builtin not supported for floats, use floats.clamp", functionCallExpr.position)
                    return noModifications
                }
                // Only the call target is swapped; the arguments stay as they are.
                return listOf(IAstModification.SetExpression({functionCallExpr.target = it as IdentifierReference}, IdentifierReference(listOf(replaceFunc), functionCallExpr.target.position), functionCallExpr))
            }
        }
        else if(func==listOf("min") || func==listOf("max")) {
            // NOTE(review): indexes args[0] and args[1] without a size check; presumably the argument
            // count was validated earlier - confirm.
            val t1 = functionCallExpr.args[0].inferType(program)
            val t2 = functionCallExpr.args[1].inferType(program)
            if(t1.isKnown && t2.isKnown) {
                val funcName = func[0]
                val replaceFunc: String
                if(t1.isBytes && t2.isBytes) {
                    // Any signed byte argument selects the signed variant.
                    replaceFunc = if(t1.istype(DataType.BYTE) || t2.istype(DataType.BYTE)) "${funcName}__byte" else "${funcName}__ubyte"
                } else if(t1.isInteger && t2.isInteger) {
                    // NOTE(review): only WORD selects the signed variant, so a BYTE + UWORD mix picks
                    // "__uword". Confirm that is intended for negative byte values.
                    replaceFunc = if(t1.istype(DataType.WORD) || t2.istype(DataType.WORD)) "${funcName}__word" else "${funcName}__uword"
                } else if(t1.isNumeric && t2.isNumeric) {
                    errors.err("min/max not supported for floats", functionCallExpr.position)
                    return noModifications
                } else {
                    errors.err("expected numeric arguments", functionCallExpr.args[0].position)
                    return noModifications
                }
                return listOf(IAstModification.SetExpression({functionCallExpr.target = it as IdentifierReference}, IdentifierReference(listOf(replaceFunc), functionCallExpr.target.position), functionCallExpr))
            }
        }
        else if(func==listOf("abs")) {
            val t1 = functionCallExpr.args[0].inferType(program)
            if(t1.isKnown) {
                val dt = t1.getOrElse { throw InternalCompilerException("invalid dt") }
                val replaceFunc = when(dt) {
                    DataType.BYTE -> "abs__byte"
                    DataType.WORD -> "abs__word"
                    DataType.FLOAT -> "abs__float"
                    DataType.UBYTE, DataType.UWORD -> {
                        // abs() of an unsigned value is the value itself: drop the call entirely.
                        return listOf(IAstModification.ReplaceNode(functionCallExpr, functionCallExpr.args[0], parent))
                    }
                    else -> {
                        errors.err("expected numeric argument", functionCallExpr.args[0].position)
                        return noModifications
                    }
                }
                return listOf(IAstModification.SetExpression({functionCallExpr.target = it as IdentifierReference}, IdentifierReference(listOf(replaceFunc), functionCallExpr.target.position), functionCallExpr))
            }
        }
        else if(func==listOf("sqrt")) {
            val t1 = functionCallExpr.args[0].inferType(program)
            if(t1.isKnown) {
                val dt = t1.getOrElse { throw InternalCompilerException("invalid dt") }
                val replaceFunc = when(dt) {
                    DataType.UBYTE -> "sqrt__ubyte"
                    DataType.UWORD -> "sqrt__uword"
                    DataType.FLOAT -> "sqrt__float"
                    else -> {
                        errors.err("expected unsigned or float numeric argument", functionCallExpr.args[0].position)
                        return noModifications
                    }
                }
                return listOf(IAstModification.SetExpression({functionCallExpr.target = it as IdentifierReference}, IdentifierReference(listOf(replaceFunc), functionCallExpr.target.position), functionCallExpr))
            }
        }
        return noModifications
    }

    /**
     * Post-visit hook for function call statements. It rewrites `divmod` into `divmod__ubyte` or
     * `divmod__uword` when all arguments share one inferred type (UBYTE or UWORD). Mixed or other
     * argument types are reported as errors.
     */
    override fun after(functionCallStatement: FunctionCallStatement, parent: Node): Iterable {
        // choose specific builtin function for the given types
        val func = functionCallStatement.target.nameInSource
        if(func==listOf("divmod")) {
            // Collapse the argument types into a set: exactly one element means "all the same type".
            val argTypes = functionCallStatement.args.map {it.inferType(program)}.toSet()
            if(argTypes.size!=1) {
                // NOTE(review): with zero arguments argTypes is empty and args[0] below would throw;
                // presumably the argument count is validated earlier - confirm.
                errors.err("expected all ubyte or all uword arguments", functionCallStatement.args[0].position)
                return noModifications
            }
            val t1 = argTypes.single()
            if(t1.isKnown) {
                val dt = t1.getOrElse { throw InternalCompilerException("invalid dt") }
                val replaceFunc = when(dt) {
                    DataType.UBYTE -> "divmod__ubyte"
                    DataType.UWORD -> "divmod__uword"
                    else -> {
                        errors.err("expected all ubyte or all uword arguments", functionCallStatement.args[0].position)
                        return noModifications
                    }
                }
                return listOf(IAstModification.SetExpression({functionCallStatement.target = it as IdentifierReference}, IdentifierReference(listOf(replaceFunc), functionCallStatement.target.position), functionCallStatement))
            }
        }
        return noModifications
    }
}


/**
 * Replace all constant identifiers with their actual value,
 * and the array var initializer values and sizes.
 * This is needed because further constant optimizations depend on those.
 *
 * It also:
 *  - folds constant address-of expressions;
 *  - rewrites indexing of a numeric constant (`constpointer[x]`) into a direct memory read or write;
 *  - reports recursive variable declarations.
 */
internal class ConstantIdentifierReplacer(
    private val program: Program,
    private val options: CompilationOptions,
    private val errors: IErrorReporter
) : AstWalker() {

    /**
     * Pre-visit hook: an address-of expression that evaluates to a compile-time constant is replaced
     * by that constant. Identifiers directly inside an AddressOf are not replaced by the
     * identifier hook below.
     */
    override fun before(addressOf: AddressOf, parent: Node): Iterable {
        val constValue = addressOf.constValue(program)
        if(constValue!=null) {
            return listOf(IAstModification.ReplaceNode(addressOf, constValue, parent))
        }
        return noModifications
    }

    /**
     * Post-visit hook for identifiers. An identifier whose constant value is numeric or boolean is
     * replaced by a NumericLiteral. It is skipped when it:
     *  - is an assignment target,
     *  - is a for-loop variable, or
     *  - sits directly inside an AddressOf.
     *
     * When such a numeric constant is indexed like an array, the indexing becomes a direct memory
     * read/write at `constant + index`.
     *
     * @throws InternalCompilerException if a constant turns out to have a pass-by-reference type.
     */
    override fun after(identifier: IdentifierReference, parent: Node): Iterable {
        // replace identifiers that refer to const value, with the value itself
        // if it's a simple type and if it's not a left hand side variable
        if(identifier.parent is AssignTarget)
            return noModifications
        // The loop variable may be the ForLoop's direct child or one level deeper; it must stay an identifier.
        var forloop = identifier.parent as? ForLoop
        if(forloop==null)
            forloop = identifier.parent.parent as? ForLoop
        if(forloop!=null && identifier===forloop.loopVar)
            return noModifications

        // Parsed as: !isKnown || (!isNumeric && !isBool)
        val dt = identifier.inferType(program)
        if(!dt.isKnown || !dt.isNumeric && !dt.isBool)
            return noModifications

        try {
            val cval = identifier.constValue(program) ?: return noModifications
            val arrayIdx = identifier.parent as? ArrayIndexedExpression
            if(arrayIdx!=null && cval.type in NumericDatatypes) {
                // special case when the identifier is used as a pointer var
                // var = constpointer[x] --> var = @(constvalue+x) [directmemoryread]
                // constpointer[x] = var -> @(constvalue+x) [directmemorywrite] = var
                val add = BinaryExpression(NumericLiteral(cval.type, cval.number, identifier.position), "+", arrayIdx.indexer.indexExpr, identifier.position)
                return if(arrayIdx.parent is AssignTarget) {
                    // Indexed expression is the assignment target: replace the whole target.
                    val memwrite = DirectMemoryWrite(add, identifier.position)
                    val assignTarget = AssignTarget(null, null, memwrite, null, false, identifier.position)
                    listOf(IAstModification.ReplaceNode(arrayIdx.parent, assignTarget, arrayIdx.parent.parent))
                } else {
                    val memread = DirectMemoryRead(add, identifier.position)
                    listOf(IAstModification.ReplaceNode(arrayIdx, memread, arrayIdx.parent))
                }
            }
            when (cval.type) {
                in NumericDatatypesWithBoolean -> {
                    if(parent is AddressOf)
                        return noModifications      // cannot replace the identifier INSIDE the addr-of here, let's do it later.
                    return listOf(
                        IAstModification.ReplaceNode(
                            identifier,
                            NumericLiteral(cval.type, cval.number, identifier.position),
                            identifier.parent
                        )
                    )
                }
                in PassByReferenceDatatypes -> throw InternalCompilerException("pass-by-reference type should not be considered a constant")
                else -> return noModifications
            }
        } catch (x: UndefinedSymbolError) {
            errors.err(x.message, x.position)
            return noModifications
        }
    }

    /**
     * Post-visit hook for variable declarations. It:
     *  - reports a declaration whose initializer or array size refers to the variable itself;
     *  - replaces a memory-mapped array's address expression by its constant value, when there is one;
     *  - for CONST/VAR arrays without a size, deduces the size from an array-literal initializer;
     *  - for FLOAT variables, promotes an integer/boolean literal initializer to a FLOAT literal;
     *  - for array variables, replaces a range or single-value initializer by a full array literal
     *    (see createConstArrayInitializerValue).
     *
     * Only one modification is returned per visit.
     */
    override fun after(decl: VarDecl, parent: Node): Iterable {
        // the initializer value can't refer to the variable itself (recursive definition)
        if(decl.value?.referencesIdentifier(listOf(decl.name)) == true || decl.arraysize?.indexExpr?.referencesIdentifier(listOf(decl.name)) == true) {
            errors.err("recursive var declaration", decl.position)
            return noModifications
        }

        if(decl.isArray && decl.type==VarDeclType.MEMORY && decl.value !is IdentifierReference) {
            val memaddr = decl.value?.constValue(program)
            // The identity check skips the replacement when the value already is that constant node.
            if(memaddr!=null && memaddr !== decl.value) {
                return listOf(IAstModification.SetExpression(
                    { decl.value = it }, memaddr, decl
                ))
            }
        }

        if(decl.type==VarDeclType.CONST || decl.type==VarDeclType.VAR) {
            if(decl.isArray){
                val arraysize = decl.arraysize
                if(arraysize==null) {
                    // for arrays that have no size specifier attempt to deduce the size
                    val arrayval = decl.value as? ArrayLiteral
                    if(arrayval!=null) {
                        return listOf(IAstModification.SetExpression(
                            { decl.arraysize = ArrayIndex(it, decl.position) },
                            NumericLiteral.optimalInteger(arrayval.value.size, decl.position),
                            decl
                        ))
                    }
                }
            }

            when(decl.datatype) {
                DataType.FLOAT -> {
                    // vardecl: for scalar float vars, promote constant integer initialization values to floats
                    val litval = decl.value as? NumericLiteral
                    if (litval!=null && litval.type in IntegerDatatypesWithBoolean) {
                        val newValue = NumericLiteral(DataType.FLOAT, litval.number, litval.position)
                        return listOf(IAstModification.ReplaceNode(decl.value!!, newValue, decl))
                    }
                }
                in ArrayDatatypes -> {
                    val replacedArrayInitializer = createConstArrayInitializerValue(decl)
                    if(replacedArrayInitializer!=null)
                        return listOf(IAstModification.ReplaceNode(decl.value!!, replacedArrayInitializer, decl))
                }
                else -> {
                    // nothing to do for this type
                }
            }
        }

        return noModifications
    }

    /**
     * Builds an explicit array-literal initializer for an array declaration. The input is either a
     * constant range expression or a single numeric fill value combined with the declared size.
     * Value-range problems (overflow, size mismatch, non-boolean fill for bool arrays) are reported
     * through [errors].
     *
     * @return the new ArrayLiteral, or null when nothing should be replaced. That covers
     *         memory-mapped arrays, non-constant initializers, an unknown array size, element types
     *         not handled here, and cases where an error was reported.
     */
    private fun createConstArrayInitializerValue(decl: VarDecl): ArrayLiteral? {

        if(decl.type==VarDeclType.MEMORY)
            return null     // memory mapped arrays can never have an initializer value other than the address where they're mapped.

        // convert the initializer range expression from a range or int, to an actual array.
        when(decl.datatype) {
            DataType.ARRAY_UB, DataType.ARRAY_B, DataType.ARRAY_UW, DataType.ARRAY_W, DataType.ARRAY_W_SPLIT, DataType.ARRAY_UW_SPLIT -> {
                val rangeExpr = decl.value as? RangeExpression
                if(rangeExpr!=null) {
                    val constRange = rangeExpr.toConstantIntegerRange()
                    if(constRange?.isEmpty()==true) {
                        // An empty constant range usually means the step's sign doesn't match the direction.
                        if(constRange.first>constRange.last && constRange.step>=0)
                            errors.err("descending range with positive step", decl.value?.position!!)
                        // NOTE(review): source text appears to be MISSING at this point in this copy.
                        // The condition below is cut off right after `constRange.first`. The code that
                        // follows (the per-type overflow checks and the fill-array construction) uses
                        // `fillvalue`, `numericLv` and `size`, none of which is declared anywhere visible
                        // in this branch, and it closes an inner `when` that is never opened. Restore this
                        // span from the upstream file before compiling. The tokens are left exactly as found.
                        else if(constRange.first {
                            if(fillvalue !in 0..255)
                                errors.err("ubyte value overflow", numericLv.position)
                        }
                        DataType.ARRAY_B -> {
                            if(fillvalue !in -128..127)
                                errors.err("byte value overflow", numericLv.position)
                        }
                        DataType.ARRAY_UW -> {
                            if(fillvalue !in 0..65535)
                                errors.err("uword value overflow", numericLv.position)
                        }
                        DataType.ARRAY_W -> {
                            if(fillvalue !in -32768..32767)
                                errors.err("word value overflow", numericLv.position)
                        }
                        else -> {}
                    }
                    // create the array itself, filled with the fillvalue.
                    // Each element gets the array's element type (via ArrayToElementTypes).
                    val array = Array(size) {fillvalue}.map { NumericLiteral(ArrayToElementTypes.getValue(decl.datatype), it.toDouble(), numericLv.position) }.toTypedArray()
                    return ArrayLiteral(InferredTypes.InferredType.known(decl.datatype), array, position = numericLv.position)
                }
            }
            DataType.ARRAY_F -> {
                val rangeExpr = decl.value as? RangeExpression
                if(rangeExpr!=null) {
                    // convert the initializer range expression to an actual array of floats
                    val declArraySize = decl.arraysize?.constIndex()
                    if(declArraySize!=null && declArraySize!=rangeExpr.size())
                        errors.err("range expression size (${rangeExpr.size()}) doesn't match declared array size ($declArraySize)", decl.value?.position!!)
                    val constRange = rangeExpr.toConstantIntegerRange()
                    if(constRange!=null) {
                        return ArrayLiteral(InferredTypes.InferredType.known(DataType.ARRAY_F),
                            constRange.map { NumericLiteral(DataType.FLOAT, it.toDouble(), decl.value!!.position) }.toTypedArray(),
                            position = decl.value!!.position)
                    }
                }

                val numericLv = decl.value as? NumericLiteral
                val size = decl.arraysize?.constIndex() ?: return null
                if(rangeExpr==null && numericLv!=null) {
                    // arraysize initializer is a single int, and we know the array size.
                    val fillvalue = numericLv.number
                    // The limits come from the compilation target's machine definition.
                    if (fillvalue < options.compTarget.machine.FLOAT_MAX_NEGATIVE || fillvalue > options.compTarget.machine.FLOAT_MAX_POSITIVE)
                        errors.err("float value overflow", numericLv.position)      // falls through to `return null` below
                    else {
                        val array = Array(size) {fillvalue}.map { NumericLiteral(DataType.FLOAT, it, numericLv.position) }.toTypedArray()
                        return ArrayLiteral(InferredTypes.InferredType.known(DataType.ARRAY_F), array, position = numericLv.position)
                    }
                }
            }
            DataType.ARRAY_BOOL -> {
                val numericLv = decl.value as? NumericLiteral
                val size = decl.arraysize?.constIndex() ?: return null
                if(numericLv!=null) {
                    // arraysize initializer is a single value, and we know the array size.
                    if(numericLv.type!=DataType.BOOL) {
                        // A byte-typed fill value is tolerated without an error unless strictBool is on,
                        // but no array is built for it either way.
                        if(options.strictBool || numericLv.type !in ByteDatatypes)
                            errors.err("initializer value is not a boolean", numericLv.position)
                        return null
                    }
                    val array = Array(size) {numericLv.number}.map { NumericLiteral(DataType.BOOL, it, numericLv.position) }.toTypedArray()
                    return ArrayLiteral(InferredTypes.InferredType.known(DataType.ARRAY_BOOL), array, position = numericLv.position)
                }
            }
            else -> return null
        }
        return null
    }
}