3 import os, sys, compiler
4 from compiler.ast import Node, For, While, ListComp, AssName, Name, Lambda, Function
def check_source(source):
    """Run the capture check on a string of source code.

    The string is parsed with ``compiler.parse`` and handed to
    ``check_thing``; whatever ``check_thing`` produces is returned
    unchanged.
    """
    parse_source_text = compiler.parse
    return check_thing(parse_source_text, source)
11 return check_thing(compiler.parseFile, path)
13 def check_thing(parser, thing):
16 except SyntaxError, e:
20 check_ast(ast, results)
23 def check_ast(ast, results):
24 """Check a node outside a loop."""
25 if isinstance(ast, (For, While, ListComp)):
26 check_loop(ast, results)
28 for child in ast.getChildNodes():
29 if isinstance(ast, Node):
30 check_ast(child, results)
32 def check_loop(ast, results):
33 """Check a particular outer loop."""
35 # List comprehensions have a poorly designed AST of the form
36 # ListComp(exprNode, [ListCompFor(...), ...]), in which the
37 # result expression is outside the ListCompFor node even though
38 # it is logically inside the loop(s).
39 # There may be multiple ListCompFor nodes (in cases such as
40 # [lambda: (a,b) for a in ... for b in ...]
41 # ), and that case they are not nested in the AST. But these
42 # warts (nonobviously) happen not to matter for our analysis.
44 declared = {} # maps name to lineno of declaration
46 collect_declared_and_nested(ast, declared, nested)
48 # For each nested function...
49 for funcnode in nested:
50 # Check for captured variables in this function.
52 collect_captured(funcnode, declared, captured)
54 # We want to report the outermost capturing function
55 # (since that is where the workaround will need to be
56 # added), and the variable declaration. Just one report
57 # per capturing function per variable will do.
58 results.append(make_result(funcnode, name, declared[name]))
60 # Check each node in the function body in case it
61 # contains another 'for' loop.
62 childnodes = funcnode.getChildNodes()[len(funcnode.defaults):]
63 for child in childnodes:
64 check_ast(funcnode, results)
66 def collect_declared_and_nested(ast, declared, nested):
68 Collect the names declared in this 'for' loop, not including
69 names declared in nested functions. Also collect the nodes of
70 functions that are nested one level deep.
72 if isinstance(ast, AssName):
73 declared[ast.name] = ast.lineno
75 childnodes = ast.getChildNodes()
76 if isinstance(ast, (Lambda, Function)):
79 # The default argument expressions are "outside" the
80 # function, even though they are children of the
81 # Lambda or Function node.
82 childnodes = childnodes[:len(ast.defaults)]
84 for child in childnodes:
85 if isinstance(ast, Node):
86 collect_declared_and_nested(child, declared, nested)
88 def collect_captured(ast, declared, captured):
89 """Collect any captured variables that are also in declared."""
90 if isinstance(ast, Name):
91 if ast.name in declared:
92 captured.add(ast.name)
94 childnodes = ast.getChildNodes()
96 if isinstance(ast, (Lambda, Function)):
97 # Formal parameters of the function are excluded from
98 # captures we care about in subnodes of the function body.
99 declared = declared.copy()
100 for argname in ast.argnames:
101 if argname in declared:
102 del declared[argname]
104 for child in childnodes[len(ast.defaults):]:
105 collect_captured(child, declared, captured)
107 # The default argument expressions are "outside" the
108 # function, even though they are children of the
109 # Lambda or Function node.
110 childnodes = childnodes[:len(ast.defaults)]
112 for child in childnodes:
113 if isinstance(ast, Node):
114 collect_captured(child, declared, captured)
117 def make_result(funcnode, var_name, var_lineno):
118 if hasattr(funcnode, 'name'):
119 func_name = 'function %r' % (funcnode.name,)
121 func_name = '<lambda>'
122 return (funcnode.lineno, func_name, var_name, var_lineno)
124 def report(out, path, results):
126 if isinstance(r, SyntaxError):
127 print >>out, path + (" NOT ANALYSED due to syntax error: %s" % r)
129 print >>out, path + (":%r %s captures %r declared at line %d" % r)
131 def check(sources, out):
139 results = check_file(path)
140 report(out, path, results)
141 counts.n += len(results)
142 counts.processed_files += 1
144 counts.suspect_files += 1
146 for source in sources:
147 print >>out, "Checking %s..." % (source,)
148 if os.path.isfile(source):
151 for (dirpath, dirnames, filenames) in os.walk(source):
153 (basename, ext) = os.path.splitext(fn)
155 _process(os.path.join(dirpath, fn))
157 print >>out, ("%d suspiciously captured variables in %d out of %d files"
158 % (counts.n, counts.suspect_files, counts.processed_files))
163 if len(sys.argv) > 1:
164 sources = sys.argv[1:]
165 if check(sources, sys.stderr) > 0: