more semantic code checks

This commit is contained in:
Irmen de Jong 2018-01-16 01:47:55 +01:00
parent db97be69fe
commit 9b77dcc6b8
7 changed files with 147 additions and 89 deletions

View File

@ -13,7 +13,7 @@ from typing import Optional, Tuple, Set, Dict, List, Any, no_type_check
import attr
from .plyparse import parse_file, ParseError, Module, Directive, Block, Subroutine, Scope, VarDef, LiteralValue, \
SubCall, Goto, Return, Assignment, InlineAssembly, Register, Expression, ProgramFormat, ZpOptions,\
SymbolName, Dereference, AddressOf, IncrDecr, Label, AstNode, datatype_of, coerce_constant_value
SymbolName, Dereference, AddressOf, IncrDecr, Label, AstNode, datatype_of, coerce_constant_value, UndefinedSymbolError
from .plylex import SourceRef, print_bold
from .datatypes import DataType, VarType
@ -36,7 +36,7 @@ class PlyParser:
self.process_imports(module)
self.create_multiassigns(module)
self.check_and_merge_zeropages(module)
self.process_all_expressions(module)
self.process_all_expressions_and_symbolnames(module)
if not self.parsing_import:
# these shall only be done on the main module after all imports have been done:
self.apply_directive_options(module)
@ -54,28 +54,33 @@ class PlyParser:
self.parse_errors += 1
print_bold("ERROR: {}: {}".format(sourceref, fmtstring.format(*args)))
def _check_last_statement_is_return(self, last_stmt: AstNode) -> None:
if isinstance(last_stmt, Subroutine):
return
if isinstance(last_stmt, Directive) and last_stmt.name == "noreturn":
return
if isinstance(last_stmt, InlineAssembly):
for line in reversed(last_stmt.assembly.splitlines()):
line = line.strip()
if line.startswith(";"):
continue
if "jmp " in line or "jmp\t" in line or "rts" in line or "rti" in line:
return
raise ParseError("last statement in a block/subroutine must be a return or goto, "
"(or %noreturn directive to silence this error)", last_stmt.sourceref)
def semantic_check(self, module: Module) -> None:
# perform semantic analysis / checks on the syntactic parse tree we have so far
def check_last_statement_is_return(last_stmt: AstNode) -> None:
if isinstance(last_stmt, Subroutine):
return
if isinstance(last_stmt, Directive) and last_stmt.name == "noreturn":
return
if isinstance(last_stmt, InlineAssembly):
for line in reversed(last_stmt.assembly.splitlines()):
line = line.strip()
if line.startswith(";"):
continue
if "jmp " in line or "jmp\t" in line or "rts" in line or "rti" in line:
return
raise ParseError("last statement in a block/subroutine must be a return or goto, "
"(or %noreturn directive to silence this error)", last_stmt.sourceref)
# (note: symbol names have already been checked to exist when we start this)
for block, parent in module.all_scopes():
assert isinstance(block, (Module, Block, Subroutine))
assert parent is None or isinstance(parent, (Module, Block, Subroutine))
previous_stmt = None
for stmt in block.nodes:
if isinstance(stmt, SubCall):
if isinstance(stmt.target.target, SymbolName):
subdef = block.scope.lookup(stmt.target.target.name)
self.check_subroutine_arguments(stmt, subdef)
if isinstance(stmt, Subroutine):
# the previous statement (if any) must be a Goto or Return
if previous_stmt and not isinstance(previous_stmt, (Goto, Return, VarDef, Subroutine)):
@ -88,11 +93,20 @@ class PlyParser:
previous_stmt = stmt
if isinstance(stmt, IncrDecr):
if isinstance(stmt.target, SymbolName):
symdef = block.scope[stmt.target.name]
symdef = block.scope.lookup(stmt.target.name)
if isinstance(symdef, VarDef) and symdef.vartype == VarType.CONST:
raise ParseError("cannot modify a constant", stmt.sourceref)
if parent and block.name != "ZP" and not isinstance(stmt, (Return, Goto)):
check_last_statement_is_return(stmt)
self._check_last_statement_is_return(stmt)
def check_subroutine_arguments(self, call: SubCall, subdef: Subroutine) -> None:
# @todo must be moved to expression processing, or, restructure whole AST tree walking to make it easier to walk over everything
if len(call.arguments) != len(subdef.param_spec):
raise ParseError("invalid number of arguments ({:d}, required: {:d})"
.format(len(call.arguments), len(subdef.param_spec)), call.sourceref)
for arg, param in zip(call.arguments, subdef.param_spec):
if arg.name and arg.name != param[0]:
raise ParseError("parameter name mismatch", arg.sourceref)
def check_and_merge_zeropages(self, module: Module) -> None:
# merge all ZP blocks into one
@ -133,8 +147,8 @@ class PlyParser:
raise ParseError(str(x), vardef.sourceref)
@no_type_check
def process_all_expressions(self, module: Module) -> None:
# process/simplify all expressions (constant folding etc)
def process_all_expressions_and_symbolnames(self, module: Module) -> None:
# process/simplify all expressions (constant folding etc), and check all symbol names
encountered_blocks = set()
for block, parent in module.all_scopes():
parentname = (parent.name + ".") if parent else ""
@ -144,6 +158,7 @@ class PlyParser:
encountered_blocks.add(blockname)
for node in block.nodes:
try:
node.verify_symbol_names(block.scope)
node.process_expressions(block.scope)
except ParseError:
raise
@ -297,10 +312,10 @@ class PlyParser:
return self._get_subroutine_usages_from_expression(usages, expr.name, parent_scope)
elif isinstance(expr, SymbolName):
try:
symbol = parent_scope[expr.name]
symbol = parent_scope.lookup(expr.name)
if isinstance(symbol, Subroutine):
usages[(parent_scope.name, expr.name)].add(str(expr.sourceref))
except LookupError:
except UndefinedSymbolError:
pass
else:
raise TypeError("unknown expr type to scan for sub usages", expr, expr.sourceref)
@ -338,8 +353,8 @@ class PlyParser:
if name[0] == '$':
continue
try:
symbol = parent_scope[name]
except LookupError:
symbol = parent_scope.lookup(name)
except UndefinedSymbolError:
pass
else:
if isinstance(symbol, Subroutine):
@ -439,6 +454,7 @@ class PlyParser:
print(' ' * (1+exc.sourceref.column) + '^', file=out)
if out.isatty():
print("\x1b[0m", file=out, end="", flush=True)
raise exc # XXX temporary to see where the error occurred
def handle_internal_error(self, exc: Exception, msg: str="") -> None:
out = sys.stdout

View File

@ -18,7 +18,7 @@ def generate_incrdecr(out: Callable, stmt: IncrDecr, scope: Scope) -> None:
assert stmt.operator in ("++", "--")
target = stmt.target # one of Register/SymbolName/Dereference
if isinstance(target, SymbolName):
symdef = scope[target.name]
symdef = scope.lookup(target.name)
if isinstance(symdef, VarDef):
target = symdef
else:

View File

@ -48,6 +48,10 @@ class ExpressionEvaluationError(ParseError):
pass
class UndefinedSymbolError(LookupError):
pass
start = "start"
@ -89,6 +93,11 @@ class AstNode:
# this is implemented in node types that have expression(s) and that should act on this.
pass
def verify_symbol_names(self, scope: 'Scope') -> None:
# check all SymbolNames to see if they exist.
# this is implemented in node types that have expression(s) and that should act on this.
pass
@attr.s(cmp=False, repr=False)
class Directive(AstNode):
@ -105,7 +114,7 @@ class Scope(AstNode):
save_registers = attr.ib(type=bool, default=None, init=False) # None = look in parent scope's setting @todo property that does that
def __attrs_post_init__(self):
# populate the symbol table for this scope for fast lookups via scope["name"] or scope["dotted.name"]
# populate the symbol table for this scope for fast lookups via scope.lookup("name") or scope.lookup("dotted.name")
self.symbols = {}
for node in self.nodes:
assert isinstance(node, AstNode)
@ -129,7 +138,7 @@ class Scope(AstNode):
self.symbols[node.name] = node
node.scope.parent_scope = self
def __getitem__(self, name: str) -> AstNode:
def lookup(self, name: str) -> AstNode:
assert isinstance(name, str)
if '.' in name:
# look up the dotted name starting from the topmost scope
@ -140,18 +149,18 @@ class Scope(AstNode):
if isinstance(scope, (Block, Subroutine)):
scope = scope.scope
if not isinstance(scope, Scope):
raise LookupError("undefined symbol: " + name)
raise UndefinedSymbolError("undefined symbol: " + name)
scope = scope.symbols.get(namepart, None)
if not scope:
raise LookupError("undefined symbol: " + name)
raise UndefinedSymbolError("undefined symbol: " + name)
return scope
else:
# find the name in nested scope hierarchy
if name in self.symbols:
return self.symbols[name]
if self.parent_scope:
return self.parent_scope[name]
raise LookupError("undefined symbol: " + name)
return self.parent_scope.lookup(name)
raise UndefinedSymbolError("undefined symbol: " + name)
def filter_nodes(self, nodetype) -> Generator[AstNode, None, None]:
for node in self.nodes:
@ -325,6 +334,15 @@ class Assignment(AstNode):
def process_expressions(self, scope: Scope) -> None:
self.right = process_expression(self.right, scope, self.right.sourceref)
def verify_symbol_names(self, scope: Scope) -> None:
for lv in self.left:
if isinstance(lv, SymbolName):
check_symbol_definition(lv.name, scope, lv.sourceref)
elif isinstance(lv, Dereference):
if isinstance(lv.location, SymbolName):
check_symbol_definition(lv.location.name, scope, lv.location.sourceref)
# the symbols in the assignment rvalue are checked when its expression is processed.
@attr.s(cmp=False, repr=False)
class AugAssignment(AstNode):
@ -335,6 +353,11 @@ class AugAssignment(AstNode):
def process_expressions(self, scope: Scope) -> None:
self.right = process_expression(self.right, scope, self.right.sourceref)
def verify_symbol_names(self, scope: Scope) -> None:
if isinstance(self.left, SymbolName):
check_symbol_definition(self.left.name, scope, self.left.sourceref)
# the symbols in the assignment rvalue are checked when its expression is processed.
@attr.s(cmp=False, repr=False)
class SubCall(AstNode):
@ -350,6 +373,11 @@ class SubCall(AstNode):
assert isinstance(callarg, CallArgument)
callarg.process_expressions(scope)
def verify_symbol_names(self, scope: Scope) -> None:
if isinstance(self.target.target, SymbolName):
check_symbol_definition(self.target.target.name, scope, self.target.target.sourceref)
# the symbols in the subroutine's arguments are checked when their expression is processed.
@attr.s(cmp=False, repr=False)
class Return(AstNode):
@ -505,6 +533,10 @@ class Goto(AstNode):
if self.condition is not None:
self.condition = process_expression(self.condition, scope, self.condition.sourceref)
def verify_symbol_names(self, scope: Scope) -> None:
if isinstance(self.target.target, SymbolName):
check_symbol_definition(self.target.target.name, scope, self.target.target.sourceref)
@attr.s(cmp=False, repr=False)
class Dereference(AstNode):
@ -525,6 +557,11 @@ class Dereference(AstNode):
raise ParseError("dereference target value must be byte, word, float", self.datatype.sourceref)
self.datatype = self.datatype.to_enum()
def verify_symbol_names(self, scope: Scope) -> None:
print("DEREF", self.location) # XXX not called?????
if isinstance(self.location, SymbolName):
check_symbol_definition(self.location.name, scope, self.location.sourceref)
@attr.s(cmp=False, repr=False)
class LiteralValue(AstNode):
@ -554,6 +591,10 @@ class IncrDecr(AstNode):
if isinstance(self.target, TargetRegisters):
raise ParseError("cannot incr/decr multiple registers at once", self.sourceref)
def verify_symbol_names(self, scope: Scope) -> None:
if isinstance(self.target, SymbolName):
check_symbol_definition(self.target.name, scope, self.target.sourceref)
@attr.s(cmp=False, repr=False)
class SymbolName(AstNode):
@ -621,7 +662,7 @@ def datatype_of(assignmenttarget: AstNode, scope: Scope) -> DataType:
if isinstance(assignmenttarget, (VarDef, Dereference, Register)):
return assignmenttarget.datatype
elif isinstance(assignmenttarget, SymbolName):
symdef = scope[assignmenttarget.name]
symdef = scope.lookup(assignmenttarget.name)
if isinstance(symdef, VarDef):
return symdef.datatype
elif isinstance(assignmenttarget, TargetRegisters):
@ -630,8 +671,8 @@ def datatype_of(assignmenttarget: AstNode, scope: Scope) -> DataType:
raise TypeError("cannot determine datatype", assignmenttarget)
def coerce_constant_value(datatype: DataType, value: Union[int, float, str],
sourceref: SourceRef=None) -> Tuple[bool, Union[int, float, str]]:
def coerce_constant_value(datatype: DataType, value: Any,
sourceref: SourceRef=None) -> Tuple[bool, Any]:
# if we're a BYTE type, and the value is a single character, convert it to the numeric value
def verify_bounds(value: Union[int, float, str]) -> None:
# if the value is out of bounds, raise an overflow exception
@ -684,56 +725,47 @@ def process_constant_expression(expr: Any, sourceref: SourceRef, symbolscope: Sc
elif isinstance(expr, LiteralValue):
return expr.value
elif isinstance(expr, SymbolName):
try:
value = symbolscope[expr.name]
if isinstance(value, VarDef):
if value.vartype == VarType.MEMORY:
raise ExpressionEvaluationError("can't take a memory value, must be a constant", expr.sourceref)
value = value.value
if isinstance(value, Expression):
raise ExpressionEvaluationError("circular reference?", expr.sourceref)
elif isinstance(value, (int, float, str, bool)):
return value
else:
raise ExpressionEvaluationError("constant symbol required, not {}".format(value.__class__.__name__), expr.sourceref)
except LookupError as x:
raise ExpressionEvaluationError(str(x), expr.sourceref) from None
value = check_symbol_definition(expr.name, symbolscope, expr.sourceref)
if isinstance(value, VarDef):
if value.vartype == VarType.MEMORY:
raise ExpressionEvaluationError("can't take a memory value, must be a constant", expr.sourceref)
value = value.value
if isinstance(value, Expression):
raise ExpressionEvaluationError("circular reference?", expr.sourceref)
elif isinstance(value, (int, float, str, bool)):
return value
else:
raise ExpressionEvaluationError("constant symbol required, not {}".format(value.__class__.__name__), expr.sourceref)
elif isinstance(expr, AddressOf):
assert isinstance(expr.name, SymbolName)
try:
value = symbolscope[expr.name.name]
if isinstance(value, VarDef):
if value.vartype == VarType.MEMORY:
return value.value
if value.vartype == VarType.CONST:
raise ExpressionEvaluationError("can't take the address of a constant", expr.name.sourceref)
raise ExpressionEvaluationError("address-of this {} isn't a compile-time constant"
.format(value.__class__.__name__), expr.name.sourceref)
else:
raise ExpressionEvaluationError("constant address required, not {}"
.format(value.__class__.__name__), expr.name.sourceref)
except LookupError as x:
raise ParseError(str(x), expr.sourceref) from None
value = check_symbol_definition(expr.name.name, symbolscope, expr.sourceref)
if isinstance(value, VarDef):
if value.vartype == VarType.MEMORY:
return value.value
if value.vartype == VarType.CONST:
raise ExpressionEvaluationError("can't take the address of a constant", expr.name.sourceref)
raise ExpressionEvaluationError("address-of this {} isn't a compile-time constant"
.format(value.__class__.__name__), expr.name.sourceref)
else:
raise ExpressionEvaluationError("constant address required, not {}"
.format(value.__class__.__name__), expr.name.sourceref)
elif isinstance(expr, SubCall):
if isinstance(expr.target, CallTarget):
target = expr.target.target
if isinstance(target, SymbolName): # 'function(1,2,3)'
funcname = target.name
if funcname in math_functions or funcname in builtin_functions:
if isinstance(expr.target.target, SymbolName):
func_args = []
for a in (process_constant_expression(callarg.value, sourceref, symbolscope) for callarg in expr.arguments):
if isinstance(a, LiteralValue):
func_args.append(a.value)
else:
func_args.append(a)
func = math_functions.get(funcname, builtin_functions.get(funcname))
try:
return func(*func_args)
except Exception as x:
raise ExpressionEvaluationError(str(x), expr.sourceref)
else:
raise ParseError("symbol name required, not {}".format(expr.target.__class__.__name__), expr.sourceref)
func_args = []
for a in (process_constant_expression(callarg.value, sourceref, symbolscope) for callarg in expr.arguments):
if isinstance(a, LiteralValue):
func_args.append(a.value)
else:
func_args.append(a)
func = math_functions.get(funcname, builtin_functions.get(funcname))
try:
return func(*func_args)
except Exception as x:
raise ExpressionEvaluationError(str(x), expr.sourceref)
else:
raise ExpressionEvaluationError("can only use math- or builtin function", expr.sourceref)
elif isinstance(target, Dereference): # '[...](1,2,3)'
@ -766,8 +798,8 @@ def process_constant_expression(expr: Any, sourceref: SourceRef, symbolscope: Sc
expr.left = process_constant_expression(expr.left, left_sourceref, symbolscope)
right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref
expr.right = process_constant_expression(expr.right, right_sourceref, symbolscope)
if isinstance(expr.left, (LiteralValue, SymbolName, int, float, str, bool)):
if isinstance(expr.right, (LiteralValue, SymbolName, int, float, str, bool)):
if isinstance(expr.left, (LiteralValue, int, float, str, bool)):
if isinstance(expr.right, (LiteralValue, int, float, str, bool)):
return expr.evaluate_primitive_constants(symbolscope)
else:
raise ExpressionEvaluationError("constant value required on right, not {}"
@ -777,6 +809,13 @@ def process_constant_expression(expr: Any, sourceref: SourceRef, symbolscope: Sc
.format(expr.left.__class__.__name__), left_sourceref)
def check_symbol_definition(name: str, scope: Scope, sref: SourceRef) -> Any:
try:
return scope.lookup(name)
except UndefinedSymbolError as x:
raise ParseError(str(x), sref)
def process_dynamic_expression(expr: Any, sourceref: SourceRef, symbolscope: Scope) -> Any:
# constant-fold a dynamic expression
if expr is None or isinstance(expr, (int, float, str, bool)):
@ -797,10 +836,14 @@ def process_dynamic_expression(expr: Any, sourceref: SourceRef, symbolscope: Sco
try:
return process_constant_expression(expr, sourceref, symbolscope)
except ExpressionEvaluationError:
if isinstance(expr.target.target, SymbolName):
check_symbol_definition(expr.target.target.name, symbolscope, expr.target.target.sourceref)
return expr
elif isinstance(expr, Register):
return expr
elif isinstance(expr, Dereference):
if isinstance(expr.location, SymbolName):
check_symbol_definition(expr.location.name, symbolscope, expr.location.sourceref)
return expr
elif not isinstance(expr, Expression):
raise ParseError("expression required, not {}".format(expr.__class__.__name__), expr.sourceref)

View File

@ -112,6 +112,7 @@ def test_coerce_value():
assert coerce_constant_value(datatypes.DataType.FLOAT, '@') == (True, 64)
assert coerce_constant_value(datatypes.DataType.BYTE, 5.678) == (True, 5)
assert coerce_constant_value(datatypes.DataType.WORD, 5.678) == (True, 5)
assert coerce_constant_value(datatypes.DataType.WORD, "string") == (False, "string"), "string (address) can be assigned to a word"
assert coerce_constant_value(datatypes.DataType.STRING, "string") == (False, "string")
assert coerce_constant_value(datatypes.DataType.STRING_P, "string") == (False, "string")
assert coerce_constant_value(datatypes.DataType.STRING_S, "string") == (False, "string")
@ -134,7 +135,5 @@ def test_coerce_value():
coerce_constant_value(datatypes.DataType.FLOAT, 1.7014118347e+38)
with pytest.raises(TypeError):
coerce_constant_value(datatypes.DataType.BYTE, "string")
with pytest.raises(TypeError):
coerce_constant_value(datatypes.DataType.WORD, "string")
with pytest.raises(TypeError):
coerce_constant_value(datatypes.DataType.FLOAT, "string")

View File

@ -117,10 +117,10 @@ def test_parser():
assert result.scope.name == "<sourcefile global scope>"
assert result.subroutine_usage == {}
assert result.scope.parent_scope is None
sub = result.scope["block.calculate"]
sub = result.scope.lookup("block.calculate")
assert isinstance(sub, Subroutine)
assert sub.name == "calculate"
block = result.scope["block"]
block = result.scope.lookup("block")
assert isinstance(block, Block)
assert block.name == "block"
assert block.nodes is block.scope.nodes
@ -131,7 +131,7 @@ def test_parser():
assert isinstance(bool_vdef.value.right.value, int)
assert bool_vdef.value.right.value == 1
assert block.address == 49152
sub2 = block.scope["calculate"]
sub2 = block.scope.lookup("calculate")
assert sub2 is sub
assert sub2.lineref == "src l. 19"
all_scopes = list(result.all_scopes())
@ -163,7 +163,7 @@ def test_block_nodes():
test_source_2 = """
~ {
999(1,2)
&zz()
[zz]()
}
"""
@ -184,9 +184,9 @@ def test_parser_2():
assert isinstance(call, SubCall)
assert len(call.arguments) == 0
assert isinstance(call.target, CallTarget)
assert isinstance(call.target.target, SymbolName)
assert call.target.target.name == "zz"
assert call.target.address_of is True
assert isinstance(call.target.target, Dereference)
assert call.target.target.location.name == "zz"
assert call.target.address_of is False
test_source_3 = """

View File

@ -85,6 +85,7 @@ game_over:
c64scr.print_string("\nToo bad! It was: ")
c64scr.print_byte_decimal(secretnumber)
c64.CHROUT('\n')
return goodbye()
goodbye() ; @todo fix subroutine usage tracking, it doesn't register this one
return
return

View File

@ -68,7 +68,6 @@ start:
A+=3
XY+=6
XY+=222
A=222/13 ; @todo warn truncate (in assignment stmt)
XY+=666
return 44