diff --git a/il65/compile.py b/il65/compile.py index 61b4b3a5c..2caac21d8 100644 --- a/il65/compile.py +++ b/il65/compile.py @@ -32,7 +32,7 @@ class PlyParser: module = None try: module = parse_file(filename, self.lexer_error) - self.check_directives(module) + self.check_directives_and_const_defs(module) self.apply_directive_options(module) module.scope.define_builtin_functions() self.process_imports(module) @@ -44,7 +44,7 @@ class PlyParser: self.determine_subroutine_usage(module) self.all_parents_connected(module) cf = ConstantFold(module) - cf.fold_constants() # do some constant-folding + cf.fold_constants() self.semantic_check(module) self.coerce_values(module) self.check_floats_enabled(module) @@ -430,10 +430,13 @@ class PlyParser: namespace, name = name.rsplit(".", maxsplit=2) usages[(namespace, symbol.name)].add(str(asmnode.sourceref)) - def check_directives(self, module: Module) -> None: + def check_directives_and_const_defs(self, module: Module) -> None: imports = set() # type: Set[str] for node in module.all_nodes(): - if isinstance(node, Directive): + if isinstance(node, VarDef): + if node.value is None and node.vartype == VarType.CONST: + raise ParseError("const should be initialized with a compile-time constant value", node.sourceref) + elif isinstance(node, Directive): assert isinstance(node.parent, Scope) if node.parent.level == "module": if node.name not in {"output", "zp", "address", "import", "saveregisters", "noreturn"}: diff --git a/il65/constantfold.py b/il65/constantfold.py index ab9f9c7e7..f9fe11f21 100644 --- a/il65/constantfold.py +++ b/il65/constantfold.py @@ -42,7 +42,10 @@ class ConstantFold: self._constant_folding() def _constant_folding(self) -> None: - for expression in self.module.all_nodes(Expression): + for expression in list(self.module.all_nodes(Expression)): + if expression.parent is None or expression.parent.parent is None: + # stale expression node (was part of an expression that was constant-folded away) + continue if isinstance(expression, LiteralValue): continue try: diff --git a/il65/plyparse.py b/il65/plyparse.py index 16923d629..aa09fdd24 100644 --- a/il65/plyparse.py +++ b/il65/plyparse.py @@ -22,7 +22,7 @@ __all__ = ["ProgramFormat", "ZpOptions", "math_functions", "builtin_functions", "UndefinedSymbolError", "AstNode", "Directive", "Scope", "Block", "Module", "Label", "Expression", "Register", "Subroutine", "LiteralValue", "AddressOf", "SymbolName", "Dereference", "IncrDecr", "ExpressionWithOperator", "Goto", "SubCall", "VarDef", "Return", "Assignment", "AugAssignment", - "InlineAssembly", "AssignmentTargets", + "InlineAssembly", "AssignmentTargets", "BuiltinFunction", "TokenFilter", "parser", "connect_parents", "parse_file", "coerce_constant_value", "datatype_of", "check_symbol_definition"] @@ -108,6 +108,7 @@ class AstNode: self.nodes[idx] = newnode newnode.parent = self oldnode.parent = None + oldnode.nodes = None def add_node(self, newnode: 'AstNode', index: int = None) -> None: assert isinstance(newnode, AstNode) @@ -804,8 +805,10 @@ class VarDef(AstNode): if self.datatype.isarray() and sum(self.size) in (0, 1): print("warning: {}: array/matrix with size 1, use normal byte/word instead".format(self.sourceref)) if self.value is None and (self.datatype.isnumeric() or self.datatype.isarray()): - self.value = LiteralValue(value=0, sourceref=self.sourceref) - self.value.parent = self + if self.vartype != VarType.CONST: + # leave None when it's a const, so we can check for uninitialized consts later and raise error. + self.value = LiteralValue(value=0, sourceref=self.sourceref) + self.value.parent = self # if it's a matrix with interleave, it must be memory mapped if self.datatype == DataType.MATRIX and len(self.size) == 3: if self.vartype != VarType.MEMORY: diff --git a/tests/test_parser.py b/tests/test_parser.py index 03cea620c..0f5e2bf25 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -1,10 +1,9 @@ import math import pytest from il65.plylex import lexer, tokens, find_tok_column, literals, reserved, SourceRef -from il65.plyparse import parser, connect_parents, TokenFilter, Module, Subroutine, Block, IncrDecr, Scope, \ - AstNode, Expression, Assignment, VarDef, Register, ExpressionWithOperator, LiteralValue, Label, SubCall, Dereference,\ - BuiltinFunction, UndefinedSymbolError +from il65.plyparse import * from il65.datatypes import DataType, VarType +from il65.constantfold import ConstantFold def lexer_error(sourceref: SourceRef, fmtstring: str, *args: str) -> None: @@ -450,3 +449,49 @@ def test_const_other_expressions(): assert not e[2].is_compile_constant() with pytest.raises(TypeError): e[2].const_value() + + +def test_vdef_const_folds(): + src = """ +~ { + const cb1 = 123 + const cb2 = cb1 + const cb3 = cb1*3 +} +""" + result = parse_source(src) + if isinstance(result, Module): + result.scope.define_builtin_functions() + vd = list(result.all_nodes(VarDef)) + assert vd[0].name == "cb1" + assert vd[0].vartype == VarType.CONST + assert vd[0].datatype == DataType.BYTE + assert isinstance(vd[0].value, LiteralValue) + assert vd[0].value.value == 123 + assert vd[1].name == "cb2" + assert vd[1].vartype == VarType.CONST + assert vd[1].datatype == DataType.BYTE + assert isinstance(vd[1].value, SymbolName) + assert vd[1].value.name == "cb1" + assert vd[2].name == "cb3" + assert vd[2].vartype == VarType.CONST + assert vd[2].datatype == DataType.BYTE + assert isinstance(vd[2].value, ExpressionWithOperator) + cf = ConstantFold(result) + cf.fold_constants() + vd = list(result.all_nodes(VarDef)) + assert vd[0].name == "cb1" + assert vd[0].vartype == VarType.CONST + assert vd[0].datatype == DataType.BYTE + assert isinstance(vd[0].value, LiteralValue) + assert vd[0].value.value == 123 + assert vd[1].name == "cb2" + assert vd[1].vartype == VarType.CONST + assert vd[1].datatype == DataType.BYTE + assert isinstance(vd[1].value, LiteralValue) + assert vd[1].value.value == 123 + assert vd[2].name == "cb3" + assert vd[2].vartype == VarType.CONST + assert vd[2].datatype == DataType.BYTE + assert isinstance(vd[2].value, LiteralValue) + assert vd[2].value.value == 369 diff --git a/tests/test_vardef.py b/tests/test_vardef.py index 8682e370f..f71a7ed07 100644 --- a/tests/test_vardef.py +++ b/tests/test_vardef.py @@ -18,8 +18,7 @@ def test_creation(): assert v.vartype == VarType.CONST assert v.datatype == DataType.BYTE assert v.size == [1] - assert isinstance(v.value, LiteralValue) - assert v.value.value == 0 + assert v.value is None assert v.zp_address is None v = VarDef(name="v2", vartype="memory", datatype=None, sourceref=sref) assert v.vartype == VarType.MEMORY @@ -88,11 +87,13 @@ def test_const_value(): with pytest.raises(TypeError): v.const_value() v = VarDef(name="v1", vartype="const", datatype=DatatypeNode(name="word", sourceref=sref), sourceref=sref) - assert v.const_value() == 0 + with pytest.raises(ValueError): + v.const_value() v.value = LiteralValue(value=42, sourceref=sref) assert v.const_value() == 42 v = VarDef(name="v1", vartype="const", datatype=DatatypeNode(name="float", sourceref=sref), sourceref=sref) - assert v.const_value() == 0 + with pytest.raises(ValueError): + v.const_value() v.value = LiteralValue(value=42.9988, sourceref=sref) assert v.const_value() == 42.9988 e = ExpressionWithOperator(operator="-", sourceref=sref) diff --git a/todo.ill b/todo.ill index 2b36a3d48..980da2b63 100644 --- a/todo.ill +++ b/todo.ill @@ -1,9 +1,6 @@ ~ main { - var .float flt = -9.87e-21 const .word border = $0099 - const cbyte2 = 2222 - const cbyte2b = cbyte2*3 ; @todo fix scope not found crash var counter = 1 start: @@ -33,7 +30,6 @@ start: block2.start() return 44.123 -sub varssub()->() { sub goodbye ()->() { var xxxxxx ; @todo vars in sub? memory y = $c000 ; @todo memvars in sub? @@ -43,7 +39,7 @@ sub goodbye ()->() { return } -} + sub derp ()->() { const q = 22 A = q *4 ; @todo fix scope not found error