This commit is contained in:
Irmen de Jong 2018-01-07 19:14:21 +01:00
parent 9b68722df3
commit 68c1d2af4c
9 changed files with 330 additions and 70 deletions

View File

@ -13,7 +13,7 @@ from .symbols import FLOAT_MAX_POSITIVE, FLOAT_MAX_NEGATIVE, SourceRef, SymbolTa
class ParseError(Exception): class ParseError(Exception):
def __init__(self, message: str, sourcetext: str, sourceref: SourceRef) -> None: def __init__(self, message: str, sourcetext: Optional[str], sourceref: SourceRef) -> None:
self.sourceref = sourceref self.sourceref = sourceref
self.msg = message self.msg = message
self.sourcetext = sourcetext self.sourcetext = sourcetext

View File

@ -9,6 +9,8 @@ License: GNU GPL 3.0, see LICENSE
import sys import sys
import ply.lex import ply.lex
from .symbols import SourceRef from .symbols import SourceRef
from .parse import ParseError
# token names # token names
@ -218,7 +220,7 @@ def t_NAME(t):
def t_DIRECTIVE(t): def t_DIRECTIVE(t):
r"%[a-z]+" r"%[a-z]+\b"
t.value = t.value[1:] t.value = t.value[1:]
return t return t
@ -284,13 +286,15 @@ def t_PRESERVEREGS(t):
def t_ENDL(t): def t_ENDL(t):
r"\n+" r"\n+"
t.lexer.lineno += len(t.value) t.lexer.lineno += len(t.value)
t.value = "\n"
return t # end of lines are significant to the parser return t # end of lines are significant to the parser
def t_error(t): def t_error(t):
line, col = t.lineno, find_tok_column(t) line, col = t.lineno, find_tok_column(t)
sref = SourceRef("@todo-filename-f1", line, col) filename = getattr(t.lexer, "source_filename", "<unknown-file>")
t.lexer.error_function("{}: Illegal character '{:s}'", sref, t.value[0]) sref = SourceRef(filename, line, col)
t.lexer.error_function("{}: Illegal character '{:s}'", sref, t.value[0], sourceref=sref)
t.lexer.skip(1) t.lexer.skip(1)
@ -300,8 +304,8 @@ def find_tok_column(token):
return token.lexpos - last_cr return token.lexpos - last_cr
def error_function(fmtstring, *args): def error_function(fmtstring, *args, sourceref: SourceRef=None) -> None:
print("ERROR:", fmtstring.format(*args), file=sys.stderr) raise ParseError(fmtstring.format(*args), None, sourceref=sourceref)
lexer = ply.lex.lex() lexer = ply.lex.lex()

View File

@ -475,9 +475,9 @@ class Parser:
if self.result.format == ProgramFormat.PRG and self.result.with_sys and self.result.start_address != 0x0801: if self.result.format == ProgramFormat.PRG and self.result.with_sys and self.result.start_address != 0x0801:
raise self.PError("cannot use non-default 'address' when output format includes basic SYS program") raise self.PError("cannot use non-default 'address' when output format includes basic SYS program")
continue continue
elif directive == "preserve_registers": elif directive == "saveregisters":
if preserve_specified: if preserve_specified:
raise self.PError("can only specify preserve_registers option once") raise self.PError("can only specify saveregisters option once")
preserve_specified = True preserve_specified = True
_, _, optionstr = line.partition(" ") _, _, optionstr = line.partition(" ")
self.result.preserve_registers = optionstr in ("", "true", "yes") self.result.preserve_registers = optionstr in ("", "true", "yes")
@ -649,7 +649,7 @@ class Parser:
elif directive == "breakpoint": elif directive == "breakpoint":
self.cur_block.statements.append(BreakpointStmt(self.sourceref)) self.cur_block.statements.append(BreakpointStmt(self.sourceref))
self.print_warning("breakpoint defined") self.print_warning("breakpoint defined")
elif directive == "preserve_registers": elif directive == "saveregisters":
self.result.preserve_registers = optionstr in ("", "true", "yes") self.result.preserve_registers = optionstr in ("", "true", "yes")
else: else:
raise self.PError("invalid directive") raise self.PError("invalid directive")

153
il65/plycompiler.py Normal file
View File

@ -0,0 +1,153 @@
import os
import sys
import linecache
from typing import Optional, Generator, Tuple, Set
from .plyparser import parse_file, Module, Directive, Block, Subroutine, AstNode
from .parse import ParseError
from .symbols import SourceRef
class PlyParser:
    """Parse a source file into a Module AST and post-process the result.

    After the raw parse, the module's directives are validated, empty or
    anonymous blocks are dropped, and ``import`` directives are resolved by
    parsing the referenced files and appending their blocks to the module.
    Problems found during post-processing are reported on stderr and counted
    in ``parse_errors``.
    """

    def __init__(self):
        self.parse_errors = 0         # number of ParseErrors reported so far
        self.parsing_import = False   # True when this parser handles an imported file

    def parse_file(self, filename: str) -> "Module":
        """Parse *filename* and return its (post-processed) module.

        A ParseError raised by one of the post-processing steps is reported
        via handle_parse_error() and the module is returned as far as it got.
        NOTE(review): the module-level parse_file() call sits outside the
        try block, so an error raised by the lexer/parser itself propagates
        to the caller instead of being reported here.
        """
        print("parsing:", filename)
        module = parse_file(filename)
        try:
            self.check_directives(module)
            self.remove_empty_blocks(module)
            self.process_imports(module)
        except ParseError as x:
            self.handle_parse_error(x)
        return module

    def remove_empty_blocks(self, module: "Module") -> None:
        """Remove blocks/subroutines that are empty, and blocks lacking both name and address.

        A scope counts as empty when it contains nothing but directives and
        none of those is an ``asmbinary`` or ``asminclude`` directive.
        """
        content_directives = {"asmbinary", "asminclude"}
        for scope, parent in self.recurse_scopes(module):
            if isinstance(scope, (Subroutine, Block)):
                if not scope.scope:
                    continue
                if all(isinstance(n, Directive) and n.name not in content_directives
                       for n in scope.scope.nodes):
                    self.print_warning("ignoring empty block or subroutine", scope.sourceref)
                    assert isinstance(parent, (Block, Module))
                    parent.scope.nodes.remove(scope)
                    # Already removed: without this, an empty block that also
                    # lacks a name and address would be removed a second time
                    # below and list.remove() would raise ValueError.
                    continue
            if isinstance(scope, Block):
                if not scope.name and scope.address is None:
                    self.print_warning("ignoring block without name and address", scope.sourceref)
                    assert isinstance(parent, Module)
                    parent.scope.nodes.remove(scope)

    def check_directives(self, module: "Module") -> None:
        """Validate which directives appear at module, block and subroutine level.

        Raises ParseError for a directive that is not allowed at its level,
        for a module name that is imported more than once, and for a
        ``saveregisters`` directive that is not the first node of its scope.
        """
        module_directives = {"output", "zp", "address", "import", "saveregisters"}
        scope_directives = {"asmbinary", "asminclude", "breakpoint", "saveregisters"}
        for node, parent in self.recurse_scopes(module):
            if isinstance(node, Module):
                # check module-level directives
                imports = set()  # type: Set[str]
                for directive in node.scope.filter_nodes(Directive):
                    if directive.name not in module_directives:
                        raise ParseError("invalid directive in module", None, directive.sourceref)
                    if directive.name == "import":
                        if imports & set(directive.args):
                            raise ParseError("duplicate import", None, directive.sourceref)
                        imports |= set(directive.args)
            if isinstance(node, (Block, Subroutine)):
                # check block and subroutine-level directives
                first_node = True
                if not node.scope:
                    continue
                for sub_node in node.scope.nodes:
                    if isinstance(sub_node, Directive):
                        if sub_node.name not in scope_directives:
                            raise ParseError("invalid directive in " + node.__class__.__name__.lower(),
                                             None, sub_node.sourceref)
                        if sub_node.name == "saveregisters" and not first_node:
                            raise ParseError("saveregisters directive should be the first",
                                             None, sub_node.sourceref)
                    first_node = False

    def recurse_scopes(self, module: "Module") -> Generator[Tuple["AstNode", "AstNode"], None, None]:
        """Yield (node, parent_node) tuples for the module, its blocks and their subroutines.

        This is a preorder traversal; the module itself is yielded with
        parent None. The node lists are copied before iterating so that
        callers may remove nodes from a scope while the traversal is running.
        """
        yield module, None
        for block in list(module.scope.filter_nodes(Block)):
            yield block, module
            for subroutine in list(block.scope.filter_nodes(Subroutine)):
                yield subroutine, block

    def process_imports(self, module: "Module") -> None:
        """Parse every imported module and append its blocks to *module*.

        Raises ParseError when an import directive has no arguments or when
        an imported file cannot be found.
        NOTE(review): imports are followed recursively through import_file()
        without any bookkeeping of files already visited, so circular
        imports are not detected here.
        """
        imported = []
        for directive in module.scope.filter_nodes(Directive):
            if directive.name != "import":
                continue
            if len(directive.args) < 1:
                raise ParseError("missing argument(s) for import directive", None, directive.sourceref)
            for arg in directive.args:
                filename = self.find_import_file(arg, directive.sourceref.file)
                if not filename:
                    raise ParseError("imported file not found", None, directive.sourceref)
                imported_module = self.import_file(filename)
                imported_module.scope.parent_scope = module.scope
                imported.append(imported_module)
        # append the imported module's contents (blocks) at the end of the current module
        for imported_module in imported:
            for block in imported_module.scope.filter_nodes(Block):
                module.scope.nodes.append(block)

    def import_file(self, filename: str) -> "Module":
        """Parse *filename* with a fresh parser that is marked as handling an import.

        Errors reported by the sub-parser are added to this parser's
        ``parse_errors`` so that they are not lost.
        """
        sub_parser = PlyParser()
        sub_parser.parsing_import = True    # makes its error output say "(in imported file)"
        imported_module = sub_parser.parse_file(filename)
        self.parse_errors += sub_parser.parse_errors
        return imported_module

    def find_import_file(self, modulename: str, sourcefile: str) -> Optional[str]:
        """Return the first existing file for *modulename*, or None if there is none.

        The name is tried as given, next to *sourcefile*, and in the ``lib``
        directory below the current working directory; after that the same
        three locations are tried again with ``.ill`` appended.
        """
        filename_at_source_location = os.path.join(os.path.dirname(sourcefile), modulename)
        filename_at_libs_location = os.path.join(os.getcwd(), "lib", modulename)
        locations = [modulename, filename_at_source_location, filename_at_libs_location]
        candidates = locations + [location + ".ill" for location in locations]
        for filename in candidates:
            if os.path.isfile(filename):
                return filename
        return None

    def print_warning(self, text: str, sourceref: "SourceRef" = None) -> None:
        """Print a warning, prefixed with *sourceref* when one is given."""
        if sourceref:
            self.print_bold("warning: {}: {:s}".format(sourceref, text))
        else:
            self.print_bold("warning: " + text)

    def print_bold(self, text: str) -> None:
        """Print *text* to stdout, wrapped in bold escape codes when stdout is a terminal."""
        if sys.stdout.isatty():
            print("\x1b[1m" + text + "\x1b[0m", flush=True)
        else:
            print(text)

    def handle_parse_error(self, exc: "ParseError") -> None:
        """Count *exc* and report it on stderr, with the offending source line if available.

        When the exception carries no source text, the line is looked up via
        linecache using the exception's source reference. An exception
        without a source reference (such as a syntax error at end of input)
        is reported with its message only.
        """
        self.parse_errors += 1
        if sys.stderr.isatty():
            print("\x1b[1m", file=sys.stderr)
        if self.parsing_import:
            print("Error (in imported file):", str(exc), file=sys.stderr)
        else:
            print("Error:", str(exc), file=sys.stderr)
        sourceref = exc.sourceref
        if exc.sourcetext is None and sourceref is not None:
            exc.sourcetext = linecache.getline(sourceref.file, sourceref.line).rstrip()
        if exc.sourcetext:
            # remove leading whitespace
            stripped = exc.sourcetext.lstrip()
            num_spaces = len(exc.sourcetext) - len(stripped)
            stripped = stripped.rstrip()
            print(" " + stripped, file=sys.stderr)
            if sourceref is not None and sourceref.column:
                # shift the caret left by the amount of indentation that was stripped
                print(" " + ' ' * (sourceref.column - num_spaces) + '^', file=sys.stderr)
        if sys.stderr.isatty():
            print("\x1b[0m", file=sys.stderr, end="", flush=True)
if __name__ == "__main__":
    # Debug entry point: parse the file named on the command line and print
    # the first part of the resulting module's string form.
    if len(sys.argv) < 2:
        # without this check a missing argument surfaces as a bare IndexError
        raise SystemExit("usage: python -m il65.plycompiler <sourcefile>")
    plyparser = PlyParser()
    m = plyparser.parse_file(sys.argv[1])
    print(str(m)[:400], "...")

View File

@ -8,7 +8,7 @@ License: GNU GPL 3.0, see LICENSE
import attr import attr
from ply.yacc import yacc from ply.yacc import yacc
from typing import Union from typing import Union, Type, Generator
from .symbols import SourceRef from .symbols import SourceRef
from .lexer import tokens, lexer, find_tok_column # get the lexer tokens. required. from .lexer import tokens, lexer, find_tok_column # get the lexer tokens. required.
@ -48,70 +48,121 @@ class AstNode:
tostr(self, 0) tostr(self, 0)
@attr.s(cmp=False) @attr.s(cmp=False, repr=False)
class Module(AstNode):
nodes = attr.ib(type=list)
@attr.s(cmp=False)
class Directive(AstNode): class Directive(AstNode):
name = attr.ib(type=str) name = attr.ib(type=str)
args = attr.ib(type=list, default=attr.Factory(list)) args = attr.ib(type=list, default=attr.Factory(list))
@attr.s(cmp=False) @attr.s(cmp=False, slots=True, repr=False)
class Scope(AstNode): class Scope(AstNode):
nodes = attr.ib(type=list) nodes = attr.ib(type=list)
symbols = attr.ib(init=False)
name = attr.ib(init=False) # will be set by enclosing block, or subroutine etc.
parent_scope = attr.ib(init=False, default=None) # will be wired up later
save_registers = attr.ib(type=bool, default=False, init=False) # XXX will be set later
def __attrs_post_init__(self):
# populate the symbol table for this scope for fast lookups via scope["name"] or scope["dotted.name"]
self.symbols = {}
for node in self.nodes:
if isinstance(node, (Label, VarDef)):
self.symbols[node.name] = node
if isinstance(node, Subroutine):
self.symbols[node.name] = node
if node.scope is not None:
node.scope.parent_scope = self
if isinstance(node, Block):
if node.name:
self.symbols[node.name] = node
node.scope.parent_scope = self
def __getitem__(self, name: str) -> AstNode:
if '.' in name:
# look up the dotted name starting from the topmost scope
scope = self
while scope.parent_scope:
scope = scope.parent_scope
for namepart in name.split('.'):
if isinstance(scope, (Block, Subroutine)):
scope = scope.scope
if not isinstance(scope, Scope):
raise LookupError("undefined symbol: " + name)
scope = scope.symbols.get(namepart, None)
if scope is None:
raise LookupError("undefined symbol: " + name)
return scope
else:
# find the name in nested scope hierarchy
if name in self.symbols:
return self.symbols[name]
if self.parent_scope:
return self.parent_scope[name]
raise LookupError("undefined symbol: " + name)
def filter_nodes(self, nodetype) -> Generator[AstNode, None, None]:
for node in self.nodes:
if isinstance(node, nodetype):
yield node
@attr.s(cmp=False) @attr.s(cmp=False, repr=False)
class Module(AstNode):
name = attr.ib(type=str) # filename
scope = attr.ib(type=Scope)
@attr.s(cmp=False, repr=False)
class Block(AstNode): class Block(AstNode):
scope = attr.ib(type=Scope) scope = attr.ib(type=Scope)
name = attr.ib(type=str, default=None) name = attr.ib(type=str, default=None)
address = attr.ib(type=int, default=None) address = attr.ib(type=int, default=None)
def __attrs_post_init__(self):
self.scope.name = self.name
@attr.s(cmp=False)
@attr.s(cmp=False, repr=False)
class Label(AstNode): class Label(AstNode):
name = attr.ib(type=str) name = attr.ib(type=str)
@attr.s(cmp=False) @attr.s(cmp=False, repr=False)
class Register(AstNode): class Register(AstNode):
name = attr.ib(type=str) name = attr.ib(type=str)
@attr.s(cmp=False) @attr.s(cmp=False, repr=False)
class PreserveRegs(AstNode): class PreserveRegs(AstNode):
registers = attr.ib(type=str) registers = attr.ib(type=str)
@attr.s(cmp=False) @attr.s(cmp=False, repr=False)
class Assignment(AstNode): class Assignment(AstNode):
left = attr.ib() # type: Union[str, TargetRegisters, Dereference] left = attr.ib() # type: Union[str, TargetRegisters, Dereference]
right = attr.ib() right = attr.ib()
@attr.s(cmp=False) @attr.s(cmp=False, repr=False)
class AugAssignment(Assignment): class AugAssignment(Assignment):
operator = attr.ib(type=str) operator = attr.ib(type=str)
@attr.s(cmp=False) @attr.s(cmp=False, repr=False)
class SubCall(AstNode): class SubCall(AstNode):
target = attr.ib() target = attr.ib()
preserve_regs = attr.ib() preserve_regs = attr.ib()
arguments = attr.ib() arguments = attr.ib()
@attr.s(cmp=False) @attr.s(cmp=False, repr=False)
class Return(AstNode): class Return(AstNode):
value_A = attr.ib(default=None) value_A = attr.ib(default=None)
value_X = attr.ib(default=None) value_X = attr.ib(default=None)
value_Y = attr.ib(default=None) value_Y = attr.ib(default=None)
@attr.s(cmp=False) @attr.s(cmp=False, repr=False)
class TargetRegisters(AstNode): class TargetRegisters(AstNode):
registers = attr.ib(type=list) registers = attr.ib(type=list)
@ -119,12 +170,12 @@ class TargetRegisters(AstNode):
self.registers.append(register) self.registers.append(register)
@attr.s(cmp=False) @attr.s(cmp=False, repr=False)
class InlineAssembly(AstNode): class InlineAssembly(AstNode):
assembly = attr.ib(type=str) assembly = attr.ib(type=str)
@attr.s(cmp=False) @attr.s(cmp=False, repr=False)
class VarDef(AstNode): class VarDef(AstNode):
name = attr.ib(type=str) name = attr.ib(type=str)
vartype = attr.ib() vartype = attr.ib()
@ -132,13 +183,13 @@ class VarDef(AstNode):
value = attr.ib(default=None) value = attr.ib(default=None)
@attr.s(cmp=False, slots=True) @attr.s(cmp=False, slots=True, repr=False)
class Datatype(AstNode): class Datatype(AstNode):
name = attr.ib(type=str) name = attr.ib(type=str)
dimension = attr.ib(type=list, default=None) dimension = attr.ib(type=list, default=None)
@attr.s(cmp=False) @attr.s(cmp=False, repr=False)
class Subroutine(AstNode): class Subroutine(AstNode):
name = attr.ib(type=str) name = attr.ib(type=str)
param_spec = attr.ib() param_spec = attr.ib()
@ -149,40 +200,42 @@ class Subroutine(AstNode):
def __attrs_post_init__(self): def __attrs_post_init__(self):
if self.scope is not None and self.address is not None: if self.scope is not None and self.address is not None:
raise ValueError("subroutine must have either a scope or an address, not both") raise ValueError("subroutine must have either a scope or an address, not both")
if self.scope is not None:
self.scope.name = self.name
@attr.s(cmp=False) @attr.s(cmp=False, repr=False)
class Goto(AstNode): class Goto(AstNode):
target = attr.ib() target = attr.ib()
if_stmt = attr.ib(default=None) if_stmt = attr.ib(default=None)
condition = attr.ib(default=None) condition = attr.ib(default=None)
@attr.s(cmp=False) @attr.s(cmp=False, repr=False)
class Dereference(AstNode): class Dereference(AstNode):
location = attr.ib() location = attr.ib()
datatype = attr.ib() datatype = attr.ib()
@attr.s(cmp=False, slots=True) @attr.s(cmp=False, slots=True, repr=False)
class CallTarget(AstNode): class CallTarget(AstNode):
target = attr.ib() target = attr.ib()
address_of = attr.ib(type=bool) address_of = attr.ib(type=bool)
@attr.s(cmp=False, slots=True) @attr.s(cmp=False, slots=True, repr=False)
class CallArgument(AstNode): class CallArgument(AstNode):
value = attr.ib() value = attr.ib()
name = attr.ib(type=str, default=None) name = attr.ib(type=str, default=None)
@attr.s(cmp=False) @attr.s(cmp=False, repr=False)
class UnaryOp(AstNode): class UnaryOp(AstNode):
operator = attr.ib(type=str) operator = attr.ib(type=str)
operand = attr.ib() operand = attr.ib()
@attr.s(cmp=False, slots=True) @attr.s(cmp=False, slots=True, repr=False)
class Expression(AstNode): class Expression(AstNode):
left = attr.ib() left = attr.ib()
operator = attr.ib(type=str) operator = attr.ib(type=str)
@ -195,7 +248,13 @@ def p_start(p):
| module_elements | module_elements
""" """
if p[1]: if p[1]:
p[0] = Module(nodes=p[1], sourceref=_token_sref(p, 1)) scope = Scope(nodes=p[1], sourceref=_token_sref(p, 1))
scope.name = "<" + p.lexer.source_filename + " global scope>"
p[0] = Module(name=p.lexer.source_filename, scope=scope, sourceref=_token_sref(p, 1))
else:
scope = Scope(nodes=[], sourceref=_token_sref(p, 1))
scope.name = "<" + p.lexer.source_filename + " global scope>"
p[0] = Module(name=p.lexer.source_filename, scope=scope, sourceref=SourceRef(lexer.source_filename, 1, 1))
def p_module(p): def p_module(p):
@ -215,6 +274,7 @@ def p_module_elt(p):
| directive | directive
| block | block
""" """
if p[1] != '\n':
p[0] = p[1] p[0] = p[1]
@ -245,6 +305,7 @@ def p_directive_arg(p):
directive_arg : NAME directive_arg : NAME
| INTEGER | INTEGER
| STRING | STRING
| BOOLEAN
""" """
p[0] = p[1] p[0] = p[1]
@ -299,7 +360,10 @@ def p_scope_elements(p):
| scope_elements scope_element | scope_elements scope_element
""" """
if len(p) == 2: if len(p) == 2:
p[0] = [p[1]] p[0] = [] if p[1] in (None, '\n') else [p[1]]
else:
if p[2] in (None, '\n'):
p[0] = p[1]
else: else:
p[0] = p[1] + [p[2]] p[0] = p[1] + [p[2]]
@ -314,7 +378,10 @@ def p_scope_element(p):
| inlineasm | inlineasm
| statement | statement
""" """
if p[1] != '\n':
p[0] = p[1] p[0] = p[1]
else:
p[0] = None
def p_label(p): def p_label(p):
@ -729,17 +796,18 @@ def p_empty(p):
def p_error(p): def p_error(p):
stack_state_str = ' '.join([symbol.type for symbol in parser.symstack][1:])
print('\n[ERROR DEBUG: parser state={:d} stack: {} . {} ]'.format(parser.state, stack_state_str, p))
if p: if p:
sref = SourceRef(p.lexer.source_filename, p.lineno, find_tok_column(p)) sref = SourceRef(p.lexer.source_filename, p.lineno, find_tok_column(p))
p.lexer.error_function("{}: before '{:.20s}' ({})", sref, str(p.value), repr(p)) p.lexer.error_function("syntax error before '{:.20s}'", str(p.value), sourceref=sref)
else: else:
lexer.error_function("{}: at end of input", "@todo-filename3") lexer.error_function("syntax error at end of input", lexer.source_filename, sourceref=None)
def _token_sref(p, token_idx): def _token_sref(p, token_idx):
""" Returns the coordinates for the YaccProduction object 'p' indexed """ Returns the coordinates for the YaccProduction object 'p' indexed
with 'token_idx'. The coordinate includes the 'lineno' and with 'token_idx'. The coordinate includes the 'lineno' and 'column', starting from 1.
'column'. Both follow the lex semantic, starting from 1.
""" """
last_cr = p.lexer.lexdata.rfind('\n', 0, p.lexpos(token_idx)) last_cr = p.lexer.lexdata.rfind('\n', 0, p.lexpos(token_idx))
if last_cr < 0: if last_cr < 0:
@ -772,12 +840,10 @@ class TokenFilter:
parser = yacc(write_tables=True) parser = yacc(write_tables=True)
if __name__ == "__main__": def parse_file(filename: str) -> Module:
import sys lexer.lineno = 1
file = sys.stdin # open(sys.argv[1], "rU") lexer.source_filename = filename
lexer.source_filename = "derp" tfilter = TokenFilter(lexer)
tokenfilter = TokenFilter(lexer) with open(filename, "rU") as inf:
result = parser.parse(input=file.read(), sourcecode = inf.read()
tokenfunc=tokenfilter.token) or Module(None, SourceRef(lexer.source_filename, 1, 1)) return parser.parse(input=sourcecode, tokenfunc=tfilter.token)
print("RESULT:")
result.print_tree()

View File

@ -710,7 +710,8 @@ class AstNode:
try: try:
variables = vars(node).items() variables = vars(node).items()
except TypeError: except TypeError:
variables = {} pass
else:
for name, value in variables: for name, value in variables:
if isinstance(value, AstNode): if isinstance(value, AstNode):
tostr(value, level + 1) tostr(value, level + 1)

View File

@ -59,7 +59,6 @@ _m_with_add stx SCRATCH_ZP1
} }
} }
sub multiply_bytes_addA_16 (byte1: X, byte2: Y, add: A) -> (A?, XY) { sub multiply_bytes_addA_16 (byte1: X, byte2: Y, add: A) -> (A?, XY) {
; ---- multiply 2 bytes and add A, result as word in X/Y (unsigned) ; ---- multiply 2 bytes and add A, result as word in X/Y (unsigned)
%asm { %asm {

View File

@ -400,7 +400,7 @@ Normally, the registers are NOT preserved when calling a subroutine or when a ce
operations are performed. Most calls will be simply a few instructions to load the operations are performed. Most calls will be simply a few instructions to load the
values in the registers and then a JSR or JMP. values in the registers and then a JSR or JMP.
By using the ``%preserve_registers`` directive (globally or in a block) you can tell the By using the ``%saveregisters`` directive (globally or in a block) you can tell the
compiler to preserve all registers. This does generate a lot of extra code that puts compiler to preserve all registers. This does generate a lot of extra code that puts
original values on the stack and gets them off the stack again once the subroutine is done. original values on the stack and gets them off the stack again once the subroutine is done.
In this case however you don't have to worry about A, X and Y losing their original values In this case however you don't have to worry about A, X and Y losing their original values

View File

@ -4,12 +4,17 @@
~ main $4444 { ~ main $4444 {
%saveregisters true, false
const num = 2 const num = 2
var var1 =2 var var1 =2
var .word wvar1 = 2 var .word wvar1 = 2
start: start:
A=math.randbyte() A=math.randbyte()
A += c64.RASTER A += c64.RASTER
A-=c64.TIME_LO A-=c64.TIME_LO
@ -161,6 +166,38 @@ loop :
c64scr.print_word_decimal(1222) c64scr.print_word_decimal(1222)
c64.CHROUT('\n') c64.CHROUT('\n')
%breakpoint
return
sub sub1 () -> () {
%saveregisters off
%breakpoint
label:
return return
} }
sub emptysub () -> () {
%saveregisters on
}
}
~ zzzz {
%saveregisters
%breakpoint
return
}
~ {
;sdfsdf
return
;sdfsdf
}