mirror of
https://github.com/irmen/prog8.git
synced 2024-12-24 16:29:21 +00:00
compiler
This commit is contained in:
parent
67f1941766
commit
3ea0723c3e
247
il65/compile.py
247
il65/compile.py
@ -14,7 +14,7 @@ import attr
|
|||||||
from .plyparse import parse_file, ParseError, Module, Directive, Block, Subroutine, Scope, VarDef, LiteralValue, \
|
from .plyparse import parse_file, ParseError, Module, Directive, Block, Subroutine, Scope, VarDef, LiteralValue, \
|
||||||
SubCall, Goto, Return, Assignment, InlineAssembly, Register, Expression, ProgramFormat, ZpOptions,\
|
SubCall, Goto, Return, Assignment, InlineAssembly, Register, Expression, ProgramFormat, ZpOptions,\
|
||||||
SymbolName, Dereference, AddressOf, IncrDecr, AstNode, datatype_of, coerce_constant_value, \
|
SymbolName, Dereference, AddressOf, IncrDecr, AstNode, datatype_of, coerce_constant_value, \
|
||||||
check_symbol_definition, UndefinedSymbolError, process_expression
|
check_symbol_definition, UndefinedSymbolError, process_expression, Label
|
||||||
from .plylex import SourceRef, print_bold
|
from .plylex import SourceRef, print_bold
|
||||||
from .datatypes import DataType, VarType
|
from .datatypes import DataType, VarType
|
||||||
|
|
||||||
@ -39,13 +39,12 @@ class PlyParser:
|
|||||||
self.create_multiassigns(module)
|
self.create_multiassigns(module)
|
||||||
self.check_and_merge_zeropages(module)
|
self.check_and_merge_zeropages(module)
|
||||||
self.process_all_expressions(module)
|
self.process_all_expressions(module)
|
||||||
return module # XXX
|
if not self.imported_module:
|
||||||
# if not self.parsing_import:
|
# the following shall only be done on the main module after all imports have been done:
|
||||||
# # these shall only be done on the main module after all imports have been done:
|
self.apply_directive_options(module)
|
||||||
# self.apply_directive_options(module)
|
self.determine_subroutine_usage(module)
|
||||||
# self.determine_subroutine_usage(module)
|
self.semantic_check(module)
|
||||||
# self.semantic_check(module)
|
self.allocate_zeropage_vars(module)
|
||||||
# self.allocate_zeropage_vars(module)
|
|
||||||
except ParseError as x:
|
except ParseError as x:
|
||||||
self.handle_parse_error(x)
|
self.handle_parse_error(x)
|
||||||
if self.parse_errors:
|
if self.parse_errors:
|
||||||
@ -58,7 +57,7 @@ class PlyParser:
|
|||||||
print_bold("ERROR: {}: {}".format(sourceref, fmtstring.format(*args)))
|
print_bold("ERROR: {}: {}".format(sourceref, fmtstring.format(*args)))
|
||||||
|
|
||||||
def _check_last_statement_is_return(self, last_stmt: AstNode) -> None:
|
def _check_last_statement_is_return(self, last_stmt: AstNode) -> None:
|
||||||
if isinstance(last_stmt, Subroutine):
|
if isinstance(last_stmt, (Subroutine, Return, Goto)):
|
||||||
return
|
return
|
||||||
if isinstance(last_stmt, Directive) and last_stmt.name == "noreturn":
|
if isinstance(last_stmt, Directive) and last_stmt.name == "noreturn":
|
||||||
return
|
return
|
||||||
@ -69,52 +68,47 @@ class PlyParser:
|
|||||||
continue
|
continue
|
||||||
if "jmp " in line or "jmp\t" in line or "rts" in line or "rti" in line:
|
if "jmp " in line or "jmp\t" in line or "rts" in line or "rti" in line:
|
||||||
return
|
return
|
||||||
|
print(last_stmt)
|
||||||
raise ParseError("last statement in a block/subroutine must be a return or goto, "
|
raise ParseError("last statement in a block/subroutine must be a return or goto, "
|
||||||
"(or %noreturn directive to silence this error)", last_stmt.sourceref)
|
"(or %noreturn directive to silence this error)", last_stmt.sourceref)
|
||||||
|
|
||||||
# def semantic_check(self, module: Module) -> None:
|
def semantic_check(self, module: Module) -> None:
|
||||||
# # perform semantic analysis / checks on the syntactic parse tree we have so far
|
# perform semantic analysis / checks on the syntactic parse tree we have so far
|
||||||
# # (note: symbol names have already been checked to exist when we start this)
|
# (note: symbol names have already been checked to exist when we start this)
|
||||||
# for node, parent in module.all_nodes():
|
previous_stmt = None
|
||||||
# previous_stmt = None
|
for node in module.all_nodes():
|
||||||
# if isinstance(node, SubCall):
|
if isinstance(node, Scope):
|
||||||
# if isinstance(node.target, SymbolName):
|
previous_stmt = None
|
||||||
# subdef = block.scope.lookup(stmt.target.target.name)
|
if node.nodes and isinstance(node.parent, (Block, Subroutine)):
|
||||||
# self.check_subroutine_arguments(stmt, subdef)
|
self._check_last_statement_is_return(node.nodes[-1])
|
||||||
# if isinstance(stmt, Subroutine):
|
elif isinstance(node, SubCall):
|
||||||
# # the previous statement (if any) must be a Goto or Return
|
if isinstance(node.target, SymbolName):
|
||||||
# if previous_stmt and not isinstance(previous_stmt, (Goto, Return, VarDef, Subroutine)):
|
subdef = node.my_scope().lookup(node.target.name)
|
||||||
# raise ParseError("statement preceding subroutine must be a goto or return or another subroutine", stmt.sourceref)
|
self.check_subroutine_arguments(node, subdef) # type: ignore
|
||||||
# if isinstance(previous_stmt, Subroutine):
|
elif isinstance(node, Subroutine):
|
||||||
# # the statement after a subroutine can not be some random executable instruction because it could not be reached
|
# the previous statement (if any) must be a Goto or Return
|
||||||
# if not isinstance(stmt, (Subroutine, Label, Directive, InlineAssembly, VarDef)):
|
if previous_stmt and not isinstance(previous_stmt, (Goto, Return, VarDef, Subroutine)):
|
||||||
# raise ParseError("statement following a subroutine can't be runnable code, "
|
raise ParseError("statement preceding subroutine must be a goto or return or another subroutine", node.sourceref)
|
||||||
# "at least use a label first", stmt.sourceref)
|
elif isinstance(node, IncrDecr):
|
||||||
# previous_stmt = stmt
|
if isinstance(node.target, SymbolName):
|
||||||
# if isinstance(stmt, IncrDecr):
|
symdef = node.my_scope().lookup(node.target.name)
|
||||||
# if isinstance(stmt.target, SymbolName):
|
if isinstance(symdef, VarDef) and symdef.vartype == VarType.CONST:
|
||||||
# symdef = block.scope.lookup(stmt.target.name)
|
raise ParseError("cannot modify a constant", node.sourceref)
|
||||||
# if isinstance(symdef, VarDef) and symdef.vartype == VarType.CONST:
|
previous_stmt = node
|
||||||
# raise ParseError("cannot modify a constant", stmt.sourceref)
|
|
||||||
#
|
|
||||||
# if parent and block.name != "ZP" and not isinstance(stmt, (Return, Goto)):
|
|
||||||
# self._check_last_statement_is_return(stmt)
|
|
||||||
|
|
||||||
def check_subroutine_arguments(self, call: SubCall, subdef: Subroutine) -> None:
|
def check_subroutine_arguments(self, call: SubCall, subdef: Subroutine) -> None:
|
||||||
# @todo must be moved to expression processing, or, restructure whole AST tree walking to make it easier to walk over everything
|
if len(call.arguments.nodes) != len(subdef.param_spec):
|
||||||
if len(call.arguments) != len(subdef.param_spec):
|
|
||||||
raise ParseError("invalid number of arguments ({:d}, required: {:d})"
|
raise ParseError("invalid number of arguments ({:d}, required: {:d})"
|
||||||
.format(len(call.arguments), len(subdef.param_spec)), call.sourceref)
|
.format(len(call.arguments.nodes), len(subdef.param_spec)), call.sourceref)
|
||||||
for arg, param in zip(call.arguments, subdef.param_spec):
|
for arg, param in zip(call.arguments.nodes, subdef.param_spec):
|
||||||
if arg.name and arg.name != param[0]:
|
if arg.name and arg.name != param[0]:
|
||||||
raise ParseError("parameter name mismatch", arg.sourceref)
|
raise ParseError("parameter name mismatch", arg.sourceref)
|
||||||
|
|
||||||
def check_and_merge_zeropages(self, module: Module) -> None:
|
def check_and_merge_zeropages(self, module: Module) -> None:
|
||||||
# merge all ZP blocks into one
|
# merge all ZP blocks into one
|
||||||
# XXX done: converted to new nodes
|
|
||||||
zeropage = None
|
zeropage = None
|
||||||
for block in module.all_nodes([Block]):
|
for block in module.all_nodes(Block):
|
||||||
if block.name == "ZP":
|
if block.name == "ZP": # type: ignore
|
||||||
if zeropage:
|
if zeropage:
|
||||||
# merge other ZP block into first ZP block
|
# merge other ZP block into first ZP block
|
||||||
for node in block.nodes:
|
for node in block.nodes:
|
||||||
@ -149,12 +143,12 @@ class PlyParser:
|
|||||||
raise ParseError(str(x), vardef.sourceref)
|
raise ParseError(str(x), vardef.sourceref)
|
||||||
|
|
||||||
def check_all_symbolnames(self, module: Module) -> None:
|
def check_all_symbolnames(self, module: Module) -> None:
|
||||||
for node in module.all_nodes([SymbolName]):
|
for node in module.all_nodes(SymbolName):
|
||||||
check_symbol_definition(node.name, node.my_scope(), node.sourceref)
|
check_symbol_definition(node.name, node.my_scope(), node.sourceref) # type: ignore
|
||||||
|
|
||||||
def process_all_expressions(self, module: Module) -> None:
|
def process_all_expressions(self, module: Module) -> None:
|
||||||
# process/simplify all expressions (constant folding etc)
|
# process/simplify all expressions (constant folding etc)
|
||||||
encountered_blocks = set()
|
encountered_blocks = set() # type: Set[Block]
|
||||||
for node in module.all_nodes():
|
for node in module.all_nodes():
|
||||||
if isinstance(node, Block):
|
if isinstance(node, Block):
|
||||||
parentname = (node.parent.name + ".") if node.parent else ""
|
parentname = (node.parent.name + ".") if node.parent else ""
|
||||||
@ -174,15 +168,19 @@ class PlyParser:
|
|||||||
elif isinstance(node, Assignment):
|
elif isinstance(node, Assignment):
|
||||||
lvalue_types = set(datatype_of(lv, node.my_scope()) for lv in node.left.nodes)
|
lvalue_types = set(datatype_of(lv, node.my_scope()) for lv in node.left.nodes)
|
||||||
if len(lvalue_types) == 1:
|
if len(lvalue_types) == 1:
|
||||||
_, node.right = coerce_constant_value(lvalue_types.pop(), node.right, node.sourceref)
|
_, newright = coerce_constant_value(lvalue_types.pop(), node.right, node.sourceref)
|
||||||
|
if isinstance(newright, (LiteralValue, Expression)):
|
||||||
|
node.right = newright
|
||||||
|
else:
|
||||||
|
raise TypeError("invalid coerced constant type", newright)
|
||||||
else:
|
else:
|
||||||
for lv_dt in lvalue_types:
|
for lv_dt in lvalue_types:
|
||||||
coerce_constant_value(lv_dt, node.right, node.sourceref)
|
coerce_constant_value(lv_dt, node.right, node.sourceref)
|
||||||
|
|
||||||
|
@no_type_check
|
||||||
def create_multiassigns(self, module: Module) -> None:
|
def create_multiassigns(self, module: Module) -> None:
|
||||||
# create multi-assign statements from nested assignments (A=B=C=5),
|
# create multi-assign statements from nested assignments (A=B=C=5),
|
||||||
# and optimize TargetRegisters down to single Register if it's just one register.
|
# and optimize TargetRegisters down to single Register if it's just one register.
|
||||||
# XXX done: converted to new nodes
|
|
||||||
def reduce_right(assign: Assignment) -> Assignment:
|
def reduce_right(assign: Assignment) -> Assignment:
|
||||||
if isinstance(assign.right, Assignment):
|
if isinstance(assign.right, Assignment):
|
||||||
right = reduce_right(assign.right)
|
right = reduce_right(assign.right)
|
||||||
@ -190,11 +188,12 @@ class PlyParser:
|
|||||||
assign.right = right.right
|
assign.right = right.right
|
||||||
return assign
|
return assign
|
||||||
|
|
||||||
for node in module.all_nodes([Assignment]):
|
for node in module.all_nodes(Assignment):
|
||||||
if isinstance(node.right, Assignment):
|
if isinstance(node.right, Assignment):
|
||||||
multi = reduce_right(node)
|
multi = reduce_right(node)
|
||||||
assert multi is node and len(multi.left) > 1 and not isinstance(multi.right, Assignment)
|
assert multi is node and len(multi.left) > 1 and not isinstance(multi.right, Assignment)
|
||||||
|
|
||||||
|
@no_type_check
|
||||||
def apply_directive_options(self, module: Module) -> None:
|
def apply_directive_options(self, module: Module) -> None:
|
||||||
def set_save_registers(scope: Scope, save_dir: Directive) -> None:
|
def set_save_registers(scope: Scope, save_dir: Directive) -> None:
|
||||||
if not scope:
|
if not scope:
|
||||||
@ -211,81 +210,77 @@ class PlyParser:
|
|||||||
else:
|
else:
|
||||||
scope.save_registers = True
|
scope.save_registers = True
|
||||||
|
|
||||||
for block, parent in module.all_scopes():
|
for directive in module.all_nodes(Directive):
|
||||||
if isinstance(block, Module):
|
node = directive.my_scope().parent
|
||||||
|
if isinstance(node, Module):
|
||||||
# process the module's directives
|
# process the module's directives
|
||||||
for directive in block.scope.filter_nodes(Directive):
|
if directive.name == "output":
|
||||||
if directive.name == "output":
|
if len(directive.args) != 1 or not isinstance(directive.args[0], str):
|
||||||
if len(directive.args) != 1 or not isinstance(directive.args[0], str):
|
raise ParseError("expected one str directive argument", directive.sourceref)
|
||||||
raise ParseError("expected one str directive argument", directive.sourceref)
|
if directive.args[0] == "raw":
|
||||||
if directive.args[0] == "raw":
|
node.format = ProgramFormat.RAW
|
||||||
block.format = ProgramFormat.RAW
|
node.address = 0xc000
|
||||||
block.address = 0xc000
|
elif directive.args[0] == "prg":
|
||||||
elif directive.args[0] == "prg":
|
node.format = ProgramFormat.PRG
|
||||||
block.format = ProgramFormat.PRG
|
node.address = 0xc000
|
||||||
block.address = 0xc000
|
elif directive.args[0] == "basic":
|
||||||
elif directive.args[0] == "basic":
|
node.format = ProgramFormat.BASIC
|
||||||
block.format = ProgramFormat.BASIC
|
node.address = 0x0801
|
||||||
block.address = 0x0801
|
|
||||||
else:
|
|
||||||
raise ParseError("invalid directive args", directive.sourceref)
|
|
||||||
elif directive.name == "address":
|
|
||||||
if len(directive.args) != 1 or type(directive.args[0]) is not int:
|
|
||||||
raise ParseError("expected one integer directive argument", directive.sourceref)
|
|
||||||
if block.format == ProgramFormat.BASIC:
|
|
||||||
raise ParseError("basic cannot have a custom load address", directive.sourceref)
|
|
||||||
block.address = directive.args[0]
|
|
||||||
attr.validate(block)
|
|
||||||
elif directive.name in "import":
|
|
||||||
pass # is processed earlier
|
|
||||||
elif directive.name == "zp":
|
|
||||||
if len(directive.args) not in (1, 2) or set(directive.args) - {"clobber", "restore"}:
|
|
||||||
raise ParseError("invalid directive args", directive.sourceref)
|
|
||||||
if "clobber" in directive.args and "restore" in directive.args:
|
|
||||||
module.zp_options = ZpOptions.CLOBBER_RESTORE
|
|
||||||
elif "clobber" in directive.args:
|
|
||||||
module.zp_options = ZpOptions.CLOBBER
|
|
||||||
elif "restore" in directive.args:
|
|
||||||
raise ParseError("invalid directive args", directive.sourceref)
|
|
||||||
elif directive.name == "saveregisters":
|
|
||||||
set_save_registers(block.scope, directive)
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(directive.name)
|
raise ParseError("invalid directive args", directive.sourceref)
|
||||||
elif isinstance(block, Block):
|
elif directive.name == "address":
|
||||||
|
if len(directive.args) != 1 or type(directive.args[0]) is not int:
|
||||||
|
raise ParseError("expected one integer directive argument", directive.sourceref)
|
||||||
|
if node.format == ProgramFormat.BASIC:
|
||||||
|
raise ParseError("basic cannot have a custom load address", directive.sourceref)
|
||||||
|
node.address = directive.args[0]
|
||||||
|
attr.validate(node)
|
||||||
|
elif directive.name in "import":
|
||||||
|
pass # is processed earlier
|
||||||
|
elif directive.name == "zp":
|
||||||
|
if len(directive.args) not in (1, 2) or set(directive.args) - {"clobber", "restore"}:
|
||||||
|
raise ParseError("invalid directive args", directive.sourceref)
|
||||||
|
if "clobber" in directive.args and "restore" in directive.args:
|
||||||
|
module.zp_options = ZpOptions.CLOBBER_RESTORE
|
||||||
|
elif "clobber" in directive.args:
|
||||||
|
module.zp_options = ZpOptions.CLOBBER
|
||||||
|
elif "restore" in directive.args:
|
||||||
|
raise ParseError("invalid directive args", directive.sourceref)
|
||||||
|
elif directive.name == "saveregisters":
|
||||||
|
set_save_registers(directive.my_scope(), directive)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(directive.name)
|
||||||
|
elif isinstance(node, Block):
|
||||||
# process the block's directives
|
# process the block's directives
|
||||||
for directive in block.scope.filter_nodes(Directive):
|
if directive.name == "saveregisters":
|
||||||
if directive.name == "saveregisters":
|
set_save_registers(directive.my_scope(), directive)
|
||||||
set_save_registers(block.scope, directive)
|
elif directive.name in ("breakpoint", "asmbinary", "asminclude", "noreturn"):
|
||||||
elif directive.name in ("breakpoint", "asmbinary", "asminclude", "noreturn"):
|
continue
|
||||||
continue
|
else:
|
||||||
else:
|
raise NotImplementedError(directive.name)
|
||||||
raise NotImplementedError(directive.name)
|
elif isinstance(node, Subroutine):
|
||||||
elif isinstance(block, Subroutine):
|
# process the sub's directives
|
||||||
if block.scope:
|
if directive.name == "saveregisters":
|
||||||
# process the sub's directives
|
set_save_registers(directive.my_scope(), directive)
|
||||||
for directive in block.scope.filter_nodes(Directive):
|
elif directive.name in ("breakpoint", "asmbinary", "asminclude", "noreturn"):
|
||||||
if directive.name == "saveregisters":
|
continue
|
||||||
set_save_registers(block.scope, directive)
|
else:
|
||||||
elif directive.name in ("breakpoint", "asmbinary", "asminclude", "noreturn"):
|
raise NotImplementedError(directive.name)
|
||||||
continue
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(directive.name)
|
|
||||||
|
|
||||||
@no_type_check
|
@no_type_check
|
||||||
def determine_subroutine_usage(self, module: Module) -> None:
|
def determine_subroutine_usage(self, module: Module) -> None:
|
||||||
module.subroutine_usage.clear()
|
module.subroutine_usage.clear()
|
||||||
for block, parent in module.all_scopes():
|
for node in module.all_nodes():
|
||||||
for node in block.nodes:
|
if isinstance(node, InlineAssembly):
|
||||||
if isinstance(node, InlineAssembly):
|
self._get_subroutine_usages_from_asm(module.subroutine_usage, node, node.my_scope())
|
||||||
self._get_subroutine_usages_from_asm(module.subroutine_usage, node, block.scope)
|
elif isinstance(node, SubCall):
|
||||||
elif isinstance(node, SubCall):
|
self._get_subroutine_usages_from_subcall(module.subroutine_usage, node, node.my_scope())
|
||||||
self._get_subroutine_usages_from_subcall(module.subroutine_usage, node, block.scope)
|
elif isinstance(node, Goto):
|
||||||
elif isinstance(node, Goto):
|
self._get_subroutine_usages_from_goto(module.subroutine_usage, node, node.my_scope())
|
||||||
self._get_subroutine_usages_from_goto(module.subroutine_usage, node, block.scope)
|
elif isinstance(node, Return):
|
||||||
elif isinstance(node, Return):
|
self._get_subroutine_usages_from_return(module.subroutine_usage, node, node.my_scope())
|
||||||
self._get_subroutine_usages_from_return(module.subroutine_usage, node, block.scope)
|
elif isinstance(node, Assignment):
|
||||||
elif isinstance(node, Assignment):
|
self._get_subroutine_usages_from_assignment(module.subroutine_usage, node, node.my_scope())
|
||||||
self._get_subroutine_usages_from_assignment(module.subroutine_usage, node, block.scope)
|
|
||||||
print("----------SUBROUTINES IN USE-------------") # XXX
|
print("----------SUBROUTINES IN USE-------------") # XXX
|
||||||
import pprint
|
import pprint
|
||||||
pprint.pprint(module.subroutine_usage) # XXX
|
pprint.pprint(module.subroutine_usage) # XXX
|
||||||
@ -293,10 +288,9 @@ class PlyParser:
|
|||||||
|
|
||||||
def _get_subroutine_usages_from_subcall(self, usages: Dict[Tuple[str, str], Set[str]],
|
def _get_subroutine_usages_from_subcall(self, usages: Dict[Tuple[str, str], Set[str]],
|
||||||
subcall: SubCall, parent_scope: Scope) -> None:
|
subcall: SubCall, parent_scope: Scope) -> None:
|
||||||
target = subcall.target.target
|
if isinstance(subcall.target, SymbolName):
|
||||||
if isinstance(target, SymbolName):
|
usages[(parent_scope.name, subcall.target.name)].add(str(subcall.sourceref))
|
||||||
usages[(parent_scope.name, target.name)].add(str(subcall.sourceref))
|
for arg in subcall.arguments.nodes:
|
||||||
for arg in subcall.arguments:
|
|
||||||
self._get_subroutine_usages_from_expression(usages, arg.value, parent_scope)
|
self._get_subroutine_usages_from_expression(usages, arg.value, parent_scope)
|
||||||
|
|
||||||
def _get_subroutine_usages_from_expression(self, usages: Dict[Tuple[str, str], Set[str]],
|
def _get_subroutine_usages_from_expression(self, usages: Dict[Tuple[str, str], Set[str]],
|
||||||
@ -324,6 +318,7 @@ class PlyParser:
|
|||||||
else:
|
else:
|
||||||
raise TypeError("unknown expr type to scan for sub usages", expr, expr.sourceref)
|
raise TypeError("unknown expr type to scan for sub usages", expr, expr.sourceref)
|
||||||
|
|
||||||
|
@no_type_check
|
||||||
def _get_subroutine_usages_from_goto(self, usages: Dict[Tuple[str, str], Set[str]],
|
def _get_subroutine_usages_from_goto(self, usages: Dict[Tuple[str, str], Set[str]],
|
||||||
goto: Goto, parent_scope: Scope) -> None:
|
goto: Goto, parent_scope: Scope) -> None:
|
||||||
target = goto.target.target
|
target = goto.target.target
|
||||||
@ -369,7 +364,6 @@ class PlyParser:
|
|||||||
usages[(namespace, symbol.name)].add(str(asmnode.sourceref))
|
usages[(namespace, symbol.name)].add(str(asmnode.sourceref))
|
||||||
|
|
||||||
def check_directives(self, module: Module) -> None:
|
def check_directives(self, module: Module) -> None:
|
||||||
# XXX done: converted to new nodes
|
|
||||||
imports = set() # type: Set[str]
|
imports = set() # type: Set[str]
|
||||||
for node in module.all_nodes():
|
for node in module.all_nodes():
|
||||||
if isinstance(node, Directive):
|
if isinstance(node, Directive):
|
||||||
@ -391,13 +385,12 @@ class PlyParser:
|
|||||||
|
|
||||||
def process_imports(self, module: Module) -> None:
|
def process_imports(self, module: Module) -> None:
|
||||||
# (recursively) imports the modules
|
# (recursively) imports the modules
|
||||||
# XXX done: converted to new nodes
|
|
||||||
imported = []
|
imported = []
|
||||||
for directive in module.all_nodes([Directive]):
|
for directive in module.all_nodes(Directive):
|
||||||
if directive.name == "import":
|
if directive.name == "import": # type: ignore
|
||||||
if len(directive.args) < 1:
|
if len(directive.args) < 1: # type: ignore
|
||||||
raise ParseError("missing argument(s) for import directive", directive.sourceref)
|
raise ParseError("missing argument(s) for import directive", directive.sourceref)
|
||||||
for arg in directive.args:
|
for arg in directive.args: # type: ignore
|
||||||
filename = self.find_import_file(arg, directive.sourceref.file)
|
filename = self.find_import_file(arg, directive.sourceref.file)
|
||||||
if not filename:
|
if not filename:
|
||||||
raise ParseError("imported file not found", directive.sourceref)
|
raise ParseError("imported file not found", directive.sourceref)
|
||||||
|
@ -75,17 +75,14 @@ class AstNode:
|
|||||||
scope = scope.parent
|
scope = scope.parent
|
||||||
raise LookupError("no scope found in node ancestry")
|
raise LookupError("no scope found in node ancestry")
|
||||||
|
|
||||||
def all_nodes(self, nodetypes: Sequence['AstNode']=None) -> Generator['AstNode', None, None]:
|
def all_nodes(self, *nodetypes: type) -> Generator['AstNode', None, None]:
|
||||||
if nodetypes is None:
|
nodetypes = nodetypes or (AstNode, )
|
||||||
nodett = AstNode
|
|
||||||
else:
|
|
||||||
nodett = tuple(nodetypes) # type: ignore
|
|
||||||
for node in self.nodes:
|
for node in self.nodes:
|
||||||
if isinstance(node, nodett): # type: ignore
|
if isinstance(node, nodetypes): # type: ignore
|
||||||
yield node
|
yield node
|
||||||
for node in self.nodes:
|
for node in self.nodes:
|
||||||
if isinstance(node, AstNode):
|
if isinstance(node, AstNode):
|
||||||
yield from node.all_nodes(nodetypes)
|
yield from node.all_nodes(*nodetypes)
|
||||||
|
|
||||||
def remove_node(self, node: 'AstNode') -> None:
|
def remove_node(self, node: 'AstNode') -> None:
|
||||||
self.nodes.remove(node)
|
self.nodes.remove(node)
|
||||||
@ -597,15 +594,15 @@ class Return(AstNode):
|
|||||||
# one, two or three subnodes: value_A, value_X, value_Y (all three Expression)
|
# one, two or three subnodes: value_A, value_X, value_Y (all three Expression)
|
||||||
@property
|
@property
|
||||||
def value_A(self) -> Expression:
|
def value_A(self) -> Expression:
|
||||||
return self.nodes[0] # type: ignore
|
return self.nodes[0] if self.nodes else None # type: ignore
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def value_X(self) -> Expression:
|
def value_X(self) -> Expression:
|
||||||
return self.nodes[0] # type: ignore
|
return self.nodes[0] if self.nodes else None # type: ignore
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def value_Y(self) -> Expression:
|
def value_Y(self) -> Expression:
|
||||||
return self.nodes[0] # type: ignore
|
return self.nodes[0] if self.nodes else None # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@attr.s(cmp=False, slots=True, repr=False)
|
@attr.s(cmp=False, slots=True, repr=False)
|
||||||
|
@ -102,7 +102,7 @@ def test_char_to_bytevalue():
|
|||||||
|
|
||||||
def test_coerce_value():
|
def test_coerce_value():
|
||||||
def lv(v) -> LiteralValue:
|
def lv(v) -> LiteralValue:
|
||||||
return LiteralValue(value=v, sourceref=SourceRef("test", 1, 1))
|
return LiteralValue(value=v, sourceref=SourceRef("test", 1, 1)) # type: ignore
|
||||||
assert coerce_constant_value(datatypes.DataType.BYTE, lv(0)) == (False, lv(0))
|
assert coerce_constant_value(datatypes.DataType.BYTE, lv(0)) == (False, lv(0))
|
||||||
assert coerce_constant_value(datatypes.DataType.BYTE, lv(255)) == (False, lv(255))
|
assert coerce_constant_value(datatypes.DataType.BYTE, lv(255)) == (False, lv(255))
|
||||||
assert coerce_constant_value(datatypes.DataType.BYTE, lv('@')) == (True, lv(64))
|
assert coerce_constant_value(datatypes.DataType.BYTE, lv('@')) == (True, lv(64))
|
||||||
|
@ -137,7 +137,7 @@ def test_parser():
|
|||||||
assert sub2.lineref == "src l. 19"
|
assert sub2.lineref == "src l. 19"
|
||||||
all_nodes = list(result.all_nodes())
|
all_nodes = list(result.all_nodes())
|
||||||
assert len(all_nodes) == 12
|
assert len(all_nodes) == 12
|
||||||
all_nodes = list(result.all_nodes([Subroutine]))
|
all_nodes = list(result.all_nodes(Subroutine))
|
||||||
assert len(all_nodes) == 1
|
assert len(all_nodes) == 1
|
||||||
assert isinstance(all_nodes[0], Subroutine)
|
assert isinstance(all_nodes[0], Subroutine)
|
||||||
assert isinstance(all_nodes[0].parent, Scope)
|
assert isinstance(all_nodes[0].parent, Scope)
|
||||||
|
Loading…
Reference in New Issue
Block a user