]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blob - misc/coding_tools/check-miscaptures.py
2e827fb96bd47aff9903e4e85f796b5773a2ab6a
[tahoe-lafs/tahoe-lafs.git] / misc / coding_tools / check-miscaptures.py
1 #! /usr/bin/python
2
3 import os, sys, compiler
4 from compiler.ast import Node, For, While, ListComp, AssName, Name, Lambda, Function
5
6
7 def check_source(source):
8     return check_thing(compiler.parse, source)
9
10 def check_file(path):
11     return check_thing(compiler.parseFile, path)
12
13 def check_thing(parser, thing):
14     try:
15         ast = parser(thing)
16     except SyntaxError, e:
17         return [e]
18     else:
19         results = []
20         check_ast(ast, results)
21         return results
22
23 def check_ast(ast, results):
24     """Check a node outside a loop."""
25     if isinstance(ast, (For, While, ListComp)):
26         check_loop(ast, results)
27     else:
28         for child in ast.getChildNodes():
29             if isinstance(ast, Node):
30                 check_ast(child, results)
31
32 def check_loop(ast, results):
33     """Check a particular outer loop."""
34
35     # List comprehensions have a poorly designed AST of the form
36     # ListComp(exprNode, [ListCompFor(...), ...]), in which the
37     # result expression is outside the ListCompFor node even though
38     # it is logically inside the loop(s).
39     # There may be multiple ListCompFor nodes (in cases such as
40     #   [lambda: (a,b) for a in ... for b in ...]
41     # ), and that case they are not nested in the AST. But these
42     # warts (nonobviously) happen not to matter for our analysis.
43
44     declared = {}  # maps name to lineno of declaration
45     nested = set()
46     collect_declared_and_nested(ast, declared, nested)
47
48     # For each nested function...
49     for funcnode in nested:
50         # Check for captured variables in this function.
51         captured = set()
52         collect_captured(funcnode, declared, captured)
53         for name in captured:
54             # We want to report the outermost capturing function
55             # (since that is where the workaround will need to be
56             # added), and the variable declaration. Just one report
57             # per capturing function per variable will do.
58             results.append(make_result(funcnode, name, declared[name]))
59
60         # Check each node in the function body in case it
61         # contains another 'for' loop.
62         childnodes = funcnode.getChildNodes()[len(funcnode.defaults):]
63         for child in childnodes:
64             check_ast(funcnode, results)
65
66 def collect_declared_and_nested(ast, declared, nested):
67     """
68     Collect the names declared in this 'for' loop, not including
69     names declared in nested functions. Also collect the nodes of
70     functions that are nested one level deep.
71     """
72     if isinstance(ast, AssName):
73         declared[ast.name] = ast.lineno
74     else:
75         childnodes = ast.getChildNodes()
76         if isinstance(ast, (Lambda, Function)):
77             nested.add(ast)
78
79             # The default argument expressions are "outside" the
80             # function, even though they are children of the
81             # Lambda or Function node.
82             childnodes = childnodes[:len(ast.defaults)]
83
84         for child in childnodes:
85             if isinstance(ast, Node):
86                 collect_declared_and_nested(child, declared, nested)
87
88 def collect_captured(ast, declared, captured):
89     """Collect any captured variables that are also in declared."""
90     if isinstance(ast, Name):
91         if ast.name in declared:
92             captured.add(ast.name)
93     else:
94         childnodes = ast.getChildNodes()
95
96         if isinstance(ast, (Lambda, Function)):
97             # Formal parameters of the function are excluded from
98             # captures we care about in subnodes of the function body.
99             declared = declared.copy()
100             for argname in ast.argnames:
101                 if argname in declared:
102                     del declared[argname]
103
104             for child in childnodes[len(ast.defaults):]:
105                 collect_captured(child, declared, captured)
106
107             # The default argument expressions are "outside" the
108             # function, even though they are children of the
109             # Lambda or Function node.
110             childnodes = childnodes[:len(ast.defaults)]
111
112         for child in childnodes:
113             if isinstance(ast, Node):
114                 collect_captured(child, declared, captured)
115
116
117 def make_result(funcnode, var_name, var_lineno):
118     if hasattr(funcnode, 'name'):
119         func_name = 'function %r' % (funcnode.name,)
120     else:
121         func_name = '<lambda>'
122     return (funcnode.lineno, func_name, var_name, var_lineno)
123
124 def report(out, path, results):
125     for r in results:
126         if isinstance(r, SyntaxError):
127             print >>out, path + (" NOT ANALYSED due to syntax error: %s" % r)
128         else:
129             print >>out, path + (":%r %s captures %r declared at line %d" % r)
130
131 def check(sources, out):
132     class Counts:
133         n = 0
134         processed_files = 0
135         suspect_files = 0
136     counts = Counts()
137
138     def _process(path):
139         results = check_file(path)
140         report(out, path, results)
141         counts.n += len(results)
142         counts.processed_files += 1
143         if len(results) > 0:
144             counts.suspect_files += 1
145
146     for source in sources:
147         print >>out, "Checking %s..." % (source,)
148         if os.path.isfile(source):
149             _process(source)
150         else:
151             for (dirpath, dirnames, filenames) in os.walk(source):
152                 for fn in filenames:
153                     (basename, ext) = os.path.splitext(fn)
154                     if ext == '.py':
155                         _process(os.path.join(dirpath, fn))
156
157     print >>out, ("%d suspiciously captured variables in %d out of %d files"
158                   % (counts.n, counts.suspect_files, counts.processed_files))
159     return counts.n
160
161
162 sources = ['src']
163 if len(sys.argv) > 1:
164     sources = sys.argv[1:]
165 if check(sources, sys.stderr) > 0:
166     sys.exit(1)
167
168
169 # TODO: self-tests