3 import os, sys, compiler
4 from compiler.ast import Node, For, While, ListComp, AssName, Name, Lambda, Function
def check_source(source):
    """Check Python source text for suspicious variable captures.

    The text is parsed with compiler.parse; the return value is whatever
    check_thing() returns for that parse.
    """
    return check_thing(compiler.parse, source)
def check_file(path):
    """Check the Python file at `path` for suspicious variable captures.

    The file is parsed with compiler.parseFile; the return value is
    whatever check_thing() returns for that parse.
    """
    return check_thing(compiler.parseFile, path)
def check_thing(parser, thing):
    """Parse `thing` with `parser` and check the resulting AST.

    parser -- callable that turns `thing` into an AST (callers in this
              file pass compiler.parse or compiler.parseFile).
    thing  -- the argument handed to `parser` (source text or a path).

    Returns a list.  On success it holds the result tuples appended by
    check_ast().  If parsing raises SyntaxError, it holds just that
    exception, so that report() can print it and len() still works on
    the return value.
    """
    try:
        ast = parser(thing)
    except SyntaxError as e:
        # Wrapped in a list: consumers iterate over the results and test
        # each item with isinstance(item, SyntaxError).
        return [e]

    results = []
    check_ast(ast, results)
    return results
def check_ast(ast, results):
    """Check a node outside a loop.

    Loop nodes (For, While, ListComp) are handed to check_loop(); any
    other node is searched recursively for loops among its children.
    Findings are appended to `results`.
    """
    if isinstance(ast, (For, While, ListComp)):
        check_loop(ast, results)
        return

    for child in ast.getChildNodes():
        # Guard on the child being recursed into, not on the parent.
        if isinstance(child, Node):
            check_ast(child, results)
def check_loop(ast, results):
    """Check a particular outer loop.

    Collects the names assigned inside the loop and the functions nested
    directly in it, then appends one make_result() tuple to `results`
    for every (function, captured loop variable) pair.  Each nested
    function's body is afterwards checked for further loops.
    """

    # List comprehensions have a poorly designed AST of the form
    #   ListComp(exprNode, [ListCompFor(...), ...]), in which the
    # result expression is outside the ListCompFor node even though
    # it is logically inside the loop(s).
    # There may be multiple ListCompFor nodes (in cases such as
    #   [lambda: (a,b) for a in ... for b in ...]
    # ), and in that case they are not nested in the AST. But these
    # warts (nonobviously) happen not to matter for our analysis.

    assigned = {}  # maps name to lineno of topmost assignment
    nested = set()  # function/lambda nodes nested one level deep
    collect_assigned_and_nested(ast, assigned, nested)

    # For each nested function...
    for funcnode in nested:
        # Check for captured variables in this function.
        captured = set()
        collect_captured(funcnode, assigned, captured)

        # We want to report the outermost capturing function
        # (since that is where the workaround will need to be
        # added), and the topmost assignment to the variable.
        # Just one report per capturing function per variable
        # will do.
        for name in captured:
            results.append(make_result(funcnode, name, assigned[name]))

        # Check each node in the function body in case it
        # contains another 'for' loop.  The default-argument
        # expressions come first among the children and are skipped:
        # they belong to the enclosing loop, not to the body.
        childnodes = funcnode.getChildNodes()[len(funcnode.defaults):]
        for child in childnodes:
            # Recurse into the body node itself; passing funcnode here
            # would re-walk the whole function, defaults included, once
            # per body child.
            check_ast(child, results)
def collect_assigned_and_nested(ast, assigned, nested):
    """
    Collect the names assigned in this loop, not including names
    assigned in nested functions. Also collect the nodes of functions
    that are nested one level deep.

    `assigned` maps each name to the lowest line number at which it is
    assigned; `nested` is a set that receives the Lambda/Function nodes.
    Both are updated in place.
    """
    if isinstance(ast, AssName):
        # Keep the topmost (smallest lineno) assignment of each name.
        if ast.name not in assigned or assigned[ast.name] > ast.lineno:
            assigned[ast.name] = ast.lineno
        return

    childnodes = ast.getChildNodes()
    if isinstance(ast, (Lambda, Function)):
        nested.add(ast)

        # The default argument expressions are "outside" the
        # function, even though they are children of the
        # Lambda or Function node.  Only those are searched further;
        # the function body is not descended into.
        childnodes = childnodes[:len(ast.defaults)]

    for child in childnodes:
        # Guard on the child being recursed into, not on the parent.
        if isinstance(child, Node):
            collect_assigned_and_nested(child, assigned, nested)
def collect_captured(ast, assigned, captured):
    """Collect any captured variables that are also in assigned.

    Walks the subtree rooted at `ast` and adds to the set `captured`
    every referenced name that is a key of `assigned`.  Inside a nested
    Lambda/Function body, that function's own parameter names are not
    counted, since they shadow the loop variables.
    """
    if isinstance(ast, Name):
        if ast.name in assigned:
            captured.add(ast.name)
        return

    childnodes = ast.getChildNodes()
    if isinstance(ast, (Lambda, Function)):
        # Formal parameters of the function are excluded from
        # captures we care about in subnodes of the function body.
        new_assigned = assigned.copy()
        remove_argnames(ast.argnames, new_assigned)

        # The body must be searched with the reduced mapping; using
        # `assigned` here would report shadowing parameters as captures.
        for child in childnodes[len(ast.defaults):]:
            collect_captured(child, new_assigned, captured)

        # The default argument expressions are "outside" the
        # function, even though they are children of the
        # Lambda or Function node.
        childnodes = childnodes[:len(ast.defaults)]

    for child in childnodes:
        # Guard on the child being recursed into, not on the parent.
        if isinstance(child, Node):
            collect_captured(child, assigned, captured)
def remove_argnames(names, fromset):
    """Delete every argument name in `names` from the mapping `fromset`.

    names   -- iterable of argument names; an element may itself be a
               tuple or list of names (nested to any depth), which is
               processed recursively.
    fromset -- dict keyed by name (the caller passes a copy of the
               name -> lineno mapping); modified in place.

    Names that are not present in `fromset` are ignored.
    """
    for element in names:
        # Test for sequences first: a list is unhashable, so probing the
        # dict with it would raise TypeError.
        if isinstance(element, (tuple, list)):
            remove_argnames(element, fromset)
        elif element in fromset:
            del fromset[element]
def make_result(funcnode, var_name, var_lineno):
    """Build the result tuple for one captured variable.

    Returns (funcnode.lineno, label, var_name, var_lineno), where the
    label is "function '<name>'" when the node has a `name` attribute
    and '<lambda>' otherwise.  The tuple layout matches the format
    string used by report().
    """
    if hasattr(funcnode, 'name'):
        func_name = 'function %r' % (funcnode.name,)
    else:
        func_name = '<lambda>'
    return (funcnode.lineno, func_name, var_name, var_lineno)
def report(out, path, results):
    """Write one line per entry of `results` to the stream `out`.

    out     -- object with a write() method.
    path    -- string used as the prefix of every line.
    results -- iterable whose items are either SyntaxError instances or
               4-tuples (function lineno, function label, variable name,
               assignment lineno) as built by make_result().
    """
    for r in results:
        if isinstance(r, SyntaxError):
            line = path + (" NOT ANALYSED due to syntax error: %s" % r)
        else:
            line = path + (":%r %s captures %r assigned at line %d" % r)
        # Same output as `print >>out, line`, without depending on the
        # print statement.
        out.write(line + "\n")
def check(sources, out):
    """Check every source in `sources` and write a report to `out`.

    sources -- iterable of paths.  A path that is a regular file is
               checked directly; anything else is walked with os.walk().
    out     -- stream that receives progress lines, per-file reports and
               the final summary.

    Returns the total number of results reported, so that the caller
    can compare it against zero.
    """
    # Counters live on an object because the nested function below has
    # to update them and cannot rebind locals of the enclosing scope.
    class Counts:
        n = 0
        processed_files = 0
        suspect_files = 0
    counts = Counts()

    def _process(path):
        results = check_file(path)
        report(out, path, results)
        counts.n += len(results)
        counts.processed_files += 1
        # NOTE(review): a file is presumably "suspect" when it yields
        # at least one result -- confirm.
        if len(results) > 0:
            counts.suspect_files += 1

    for source in sources:
        print >>out, "Checking %s..." % (source,)
        if os.path.isfile(source):
            _process(source)
        else:
            for (dirpath, dirnames, filenames) in os.walk(source):
                for fn in filenames:
                    (basename, ext) = os.path.splitext(fn)
                    # NOTE(review): only '.py' files are assumed to be
                    # of interest when walking a directory -- confirm.
                    if ext == '.py':
                        _process(os.path.join(dirpath, fn))

    print >>out, ("%d suspiciously captured variables in %d out of %d files"
                  % (counts.n, counts.suspect_files, counts.processed_files))
    return counts.n
170 if len(sys.argv) > 1:
171 sources = sys.argv[1:]
172 if check(sources, sys.stderr) > 0: