3 import os, sys, compiler
4 from compiler.ast import Node, For, While, ListComp, AssName, Name, Lambda, Function
def check_source(source):
    """Analyse Python source code supplied as a string.

    The text is handed to check_thing() together with compiler.parse,
    and check_thing()'s result is returned unchanged.
    """
    parse_text = compiler.parse
    return check_thing(parse_text, source)
def check_file(path):
    """Analyse the Python source file at `path`.

    The path is handed to check_thing() together with compiler.parseFile,
    and check_thing()'s result is returned unchanged.
    """
    return check_thing(compiler.parseFile, path)
def check_thing(parser, thing):
    """Parse `thing` with `parser` and analyse the resulting AST.

    `parser` is called with `thing` as its only argument; the callers in
    this file pass compiler.parse or compiler.parseFile.

    Returns a list.  When parsing raises SyntaxError the list holds just
    that exception (report() prints it as "NOT ANALYSED"); otherwise it
    holds whatever check_ast() collected.
    """
    try:
        ast = parser(thing)
    except SyntaxError as e:
        # Hand the error back as a result instead of propagating it, so
        # the caller can report it alongside ordinary findings.
        return [e]

    results = []
    check_ast(ast, results)
    return results
def check_ast(ast, results):
    """Check a node outside a loop.

    A For, While or ListComp node is handed to check_loop(), which
    appends its findings to `results`.  Any other node is searched
    recursively through its child nodes.
    """
    if isinstance(ast, (For, While, ListComp)):
        check_loop(ast, results)
        return

    for child in ast.getChildNodes():
        # Guard on the child, not the parent: the parent has already had
        # getChildNodes() called on it, so testing it tells us nothing.
        if isinstance(child, Node):
            check_ast(child, results)
def check_loop(ast, results):
    """Check a particular outer loop.

    Gathers the names declared in the loop and the functions nested in
    it, then appends one make_result() tuple to `results` for every
    declared name that a nested function captures.  Each nested
    function's body is afterwards checked for further loops.
    """

    # List comprehensions have a poorly designed AST of the form
    # ListComp(exprNode, [ListCompFor(...), ...]), in which the
    # result expression is outside the ListCompFor node even though
    # it is logically inside the loop(s).
    # There may be multiple ListCompFor nodes (in cases such as
    #   [lambda: (a,b) for a in ... for b in ...]
    # ), and in that case they are not nested in the AST. But these
    # warts (nonobviously) happen not to matter for our analysis.

    declared = {}  # maps name to lineno of declaration
    nested = set()
    collect_declared_and_nested(ast, declared, nested)

    # For each nested function...
    for funcnode in nested:
        # Check for captured variables in this function.
        captured = set()
        collect_captured(funcnode, declared, captured)
        for name in captured:
            # We want to report the outermost capturing function
            # (since that is where the workaround will need to be
            # added), and the variable declaration. Just one report
            # per capturing function per variable will do.
            results.append(make_result(funcnode, name, declared[name]))

        # Check each node in the function body in case it
        # contains another 'for' loop.  The slice skips the default
        # argument expressions, which come first among the children.
        childnodes = funcnode.getChildNodes()[len(funcnode.defaults):]
        for child in childnodes:
            # Recurse into the body node itself.  Passing funcnode here
            # would re-walk every child, defaults included, once per
            # body node.
            check_ast(child, results)
def collect_declared_and_nested(ast, declared, nested):
    """
    Collect the names declared in this 'for' loop, not including
    names declared in nested functions. Also collect the nodes of
    functions that are nested one level deep.

    `declared` is a dict that receives name -> lineno entries;
    `nested` receives the Lambda and Function nodes that are found.
    """
    if isinstance(ast, AssName):
        declared[ast.name] = ast.lineno
        return

    childnodes = ast.getChildNodes()
    if isinstance(ast, (Lambda, Function)):
        nested.add(ast)

        # The default argument expressions are "outside" the
        # function, even though they are children of the
        # Lambda or Function node.  Only they are searched further,
        # which keeps the collection one level deep.
        childnodes = childnodes[:len(ast.defaults)]

    for child in childnodes:
        # Guard on the child being visited, not on the parent.
        if isinstance(child, Node):
            collect_declared_and_nested(child, declared, nested)
def collect_captured(ast, declared, captured):
    """Collect any captured variables that are also in declared.

    Every Name node under `ast` whose name is a key of `declared` has
    that name added to the set `captured`.  Inside a Lambda or Function
    body, the function's own formal parameters are not counted.
    """
    if isinstance(ast, Name):
        if ast.name in declared:
            captured.add(ast.name)
        return

    childnodes = ast.getChildNodes()

    if isinstance(ast, (Lambda, Function)):
        # Formal parameters of the function are excluded from
        # captures we care about in subnodes of the function body.
        new_declared = declared.copy()
        remove_argnames(ast.argnames, new_declared)

        for child in childnodes[len(ast.defaults):]:
            # Search the body against the reduced mapping; using the
            # full `declared` here would report parameters as captures.
            collect_captured(child, new_declared, captured)

        # The default argument expressions are "outside" the
        # function, even though they are children of the
        # Lambda or Function node.
        childnodes = childnodes[:len(ast.defaults)]

    for child in childnodes:
        # Guard on the child being visited, not on the parent.
        if isinstance(child, Node):
            collect_captured(child, declared, captured)
def remove_argnames(names, fromset):
    """Delete every name in `names` from the mapping `fromset`, in place.

    `names` may contain nested tuples or lists of names; these are
    flattened recursively.  Names that are not keys of `fromset` are
    ignored.  Returns None.
    """
    for element in names:
        # Test for a nested sequence first: a list is unhashable, so a
        # membership test against a dict would raise TypeError.
        if isinstance(element, (tuple, list)):
            remove_argnames(element, fromset)
        elif element in fromset:
            del fromset[element]
def make_result(funcnode, var_name, var_lineno):
    """Build the result tuple for one captured variable.

    Returns (funcnode.lineno, label, var_name, var_lineno).  The label is
    'function <repr of name>' when funcnode has a `name` attribute and
    '<lambda>' when it does not.
    """
    if not hasattr(funcnode, 'name'):
        label = '<lambda>'
    else:
        label = 'function %r' % (funcnode.name,)
    return (funcnode.lineno, label, var_name, var_lineno)
def report(out, path, results):
    """Write one line per entry of `results` to the stream `out`.

    Each entry is either a SyntaxError, reported as "NOT ANALYSED", or a
    4-tuple as built by make_result(): (function lineno, function label,
    variable name, declaration lineno).  Every line starts with `path`.
    """
    for r in results:
        if isinstance(r, SyntaxError):
            line = path + (" NOT ANALYSED due to syntax error: %s" % r)
        else:
            line = path + (":%r %s captures %r declared at line %d" % r)
        # out.write() emits the same text as the `print >>out` statement
        # without depending on Python-2-only syntax.
        out.write(line + "\n")
137 def check(sources, out):
145 results = check_file(path)
146 report(out, path, results)
147 counts.n += len(results)
148 counts.processed_files += 1
150 counts.suspect_files += 1
152 for source in sources:
153 print >>out, "Checking %s..." % (source,)
154 if os.path.isfile(source):
157 for (dirpath, dirnames, filenames) in os.walk(source):
159 (basename, ext) = os.path.splitext(fn)
161 _process(os.path.join(dirpath, fn))
163 print >>out, ("%d suspiciously captured variables in %d out of %d files"
164 % (counts.n, counts.suspect_files, counts.processed_files))
169 if len(sys.argv) > 1:
170 sources = sys.argv[1:]
171 if check(sources, sys.stderr) > 0: