git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blob - src/allmydata/test/test_util.py
One fix for bug #1154: webapi GETs with a 'Range' header broke new-downloader.
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / test / test_util.py
1
def foo(): return None # must stay on line 2 of the module: HumanReadable.test_repr expects "test_util.py:2"
3
4 import os, time, sys
5 from StringIO import StringIO
6 from twisted.trial import unittest
7 from twisted.internet import defer, reactor
8 from twisted.python.failure import Failure
9 from twisted.python import log
10 from hashlib import md5
11
12 from allmydata.util import base32, idlib, humanreadable, mathutil, hashutil
13 from allmydata.util import assertutil, fileutil, deferredutil, abbreviate
14 from allmydata.util import limiter, time_format, pollmixin, cachedir
15 from allmydata.util import statistics, dictutil, pipeline
16 from allmydata.util import log as tahoe_log
17 from allmydata.util.spans import Spans, overlap, DataSpans
18
class Base32(unittest.TestCase):
    """Tests for the encode/decode helpers in allmydata.util.base32."""

    def test_b2a_matches_Pythons(self):
        # b2a() must agree with the stdlib encoder once the trailing '='
        # padding is removed and the result is lowercased.
        from base64 import b32encode
        raw = "\x12\x34\x45\x67\x89\x0a\xbc\xde\xf0"
        expected = b32encode(raw).rstrip('=').lower()
        self.failUnlessEqual(base32.b2a(raw), expected)

    def test_b2a(self):
        encoded = base32.b2a("\x12\x34")
        self.failUnlessEqual(encoded, "ci2a")

    def test_b2a_or_none(self):
        # None passes straight through; anything else is encoded.
        encode = base32.b2a_or_none
        self.failUnlessEqual(encode(None), None)
        self.failUnlessEqual(encode("\x12\x34"), "ci2a")

    def test_a2b(self):
        decode = base32.a2b
        self.failUnlessEqual(decode("ci2a"), "\x12\x34")
        # "b0gus" is expected to be rejected with an AssertionError
        self.failUnlessRaises(AssertionError, decode, "b0gus")
36
class IDLib(unittest.TestCase):
    """Tests for allmydata.util.idlib."""

    def test_nodeid_b2a(self):
        # twenty zero bytes are expected to render as thirty-two 'a's
        zero_nodeid = "\x00"*20
        rendered = idlib.nodeid_b2a(zero_nodeid)
        self.failUnlessEqual(rendered, "a"*32)
40
class NoArgumentException(Exception):
    """An exception whose constructor accepts no arguments at all."""

    def __init__(self):
        # deliberately does nothing (the base-class __init__ is not called)
        return
44
45 class HumanReadable(unittest.TestCase):
46     def test_repr(self):
47         hr = humanreadable.hr
48         self.failUnlessEqual(hr(foo), "<foo() at test_util.py:2>")
49         self.failUnlessEqual(hr(self.test_repr),
50                              "<bound method HumanReadable.test_repr of <allmydata.test.test_util.HumanReadable testMethod=test_repr>>")
51         self.failUnlessEqual(hr(1L), "1")
52         self.failUnlessEqual(hr(10**40),
53                              "100000000000000000...000000000000000000")
54         self.failUnlessEqual(hr(self), "<allmydata.test.test_util.HumanReadable testMethod=test_repr>")
55         self.failUnlessEqual(hr([1,2]), "[1, 2]")
56         self.failUnlessEqual(hr({1:2}), "{1:2}")
57         try:
58             raise ValueError
59         except Exception, e:
60             self.failUnless(
61                 hr(e) == "<ValueError: ()>" # python-2.4
62                 or hr(e) == "ValueError()") # python-2.5
63         try:
64             raise ValueError("oops")
65         except Exception, e:
66             self.failUnless(
67                 hr(e) == "<ValueError: 'oops'>" # python-2.4
68                 or hr(e) == "ValueError('oops',)") # python-2.5
69         try:
70             raise NoArgumentException
71         except Exception, e:
72             self.failUnless(
73                 hr(e) == "<NoArgumentException>" # python-2.4
74                 or hr(e) == "NoArgumentException()") # python-2.5
75
76
class MyList(list):
    """A plain list subclass that adds no behavior of its own."""
79
class Math(unittest.TestCase):
    """Tests for the integer and averaging helpers in allmydata.util.mathutil."""

    def test_div_ceil(self):
        div_ceil = mathutil.div_ceil
        # (n, d, expected)
        cases = [(0, 1, 0), (0, 2, 0), (0, 3, 0),
                 (1, 3, 1), (2, 3, 1), (3, 3, 1),
                 (4, 3, 2), (5, 3, 2), (6, 3, 2),
                 (7, 3, 3)]
        for (n, d, expected) in cases:
            self.failUnlessEqual(div_ceil(n, d), expected)

    def test_next_multiple(self):
        next_multiple = mathutil.next_multiple
        # (n, k, expected)
        cases = [(5, 1, 5), (5, 2, 6), (5, 3, 6),
                 (5, 4, 8), (5, 5, 5), (5, 6, 6),
                 (32, 1, 32), (32, 2, 32), (32, 3, 33),
                 (32, 4, 32), (32, 5, 35), (32, 6, 36),
                 (32, 7, 35), (32, 8, 32), (32, 9, 36),
                 (32, 10, 40), (32, 11, 33), (32, 12, 36),
                 (32, 13, 39), (32, 14, 42), (32, 15, 45),
                 (32, 16, 32), (32, 17, 34), (32, 18, 36),
                 (32, 589, 589)]
        for (n, k, expected) in cases:
            self.failUnlessEqual(next_multiple(n, k), expected)

    def test_pad_size(self):
        pad_size = mathutil.pad_size
        # (n, k, expected)
        cases = [(0, 4, 0), (1, 4, 3), (2, 4, 2),
                 (3, 4, 1), (4, 4, 0), (5, 4, 3)]
        for (n, k, expected) in cases:
            self.failUnlessEqual(pad_size(n, k), expected)

    def test_is_power_of_k(self):
        is_power = mathutil.is_power_of_k

        powers_of_2 = (1, 2, 4, 8, 16, 32, 64)
        for i in range(1, 100):
            if i not in powers_of_2:
                self.failIf(is_power(i, 2), "but %d is *not* a power of 2" % i)
            else:
                self.failUnless(is_power(i, 2), "but %d *is* a power of 2" % i)

        powers_of_3 = (1, 3, 9, 27, 81)
        for i in range(1, 100):
            if i not in powers_of_3:
                self.failIf(is_power(i, 3), "but %d is *not* a power of 3" % i)
            else:
                self.failUnless(is_power(i, 3), "but %d *is* a power of 3" % i)

    def test_next_power_of_k(self):
        next_power = mathutil.next_power_of_k

        # Each plan is (k, exact cases, ranged cases): an exact case is
        # (n, expected); a ranged case is (start, stop, expected) covering
        # range(start, stop).
        plans = [
            (2,
             [(0, 1), (1, 1), (2, 2), (3, 4), (4, 4)],
             [(5, 8, 8), (9, 16, 16), (17, 32, 32), (33, 64, 64),
              (65, 100, 128)]),
            (3,
             [(0, 1), (1, 1), (2, 3), (3, 3)],
             [(4, 9, 9), (10, 27, 27), (28, 81, 81), (82, 200, 243)]),
            ]
        for (k, exact, ranged) in plans:
            for (n, expected) in exact:
                self.failUnlessEqual(next_power(n, k), expected)
            for (start, stop, expected) in ranged:
                for i in range(start, stop):
                    self.failUnlessEqual(next_power(i, k), expected, "%d" % i)

    def test_ave(self):
        average = mathutil.ave
        self.failUnlessEqual(average([1,2,3]), 2)
        self.failUnlessEqual(average([0,0,0,4]), 1)
        self.failUnlessAlmostEqual(average([0.0, 1.0, 1.0]), .666666666666)

    def test_round_sigfigs(self):
        rounded = mathutil.round_sigfigs(22.0/3, 4)
        self.failUnlessEqual(rounded, 7.3330000000000002)
175
176 class Statistics(unittest.TestCase):
177     def should_assert(self, msg, func, *args, **kwargs):
178         try:
179             func(*args, **kwargs)
180             self.fail(msg)
181         except AssertionError:
182             pass
183
184     def failUnlessListEqual(self, a, b, msg = None):
185         self.failUnlessEqual(len(a), len(b))
186         for i in range(len(a)):
187             self.failUnlessEqual(a[i], b[i], msg)
188
189     def failUnlessListAlmostEqual(self, a, b, places = 7, msg = None):
190         self.failUnlessEqual(len(a), len(b))
191         for i in range(len(a)):
192             self.failUnlessAlmostEqual(a[i], b[i], places, msg)
193
194     def test_binomial_coeff(self):
195         f = statistics.binomial_coeff
196         self.failUnlessEqual(f(20, 0), 1)
197         self.failUnlessEqual(f(20, 1), 20)
198         self.failUnlessEqual(f(20, 2), 190)
199         self.failUnlessEqual(f(20, 8), f(20, 12))
200         self.should_assert("Should assert if n < k", f, 2, 3)
201
202     def test_binomial_distribution_pmf(self):
203         f = statistics.binomial_distribution_pmf
204
205         pmf_comp = f(2, .1)
206         pmf_stat = [0.81, 0.18, 0.01]
207         self.failUnlessListAlmostEqual(pmf_comp, pmf_stat)
208
209         # Summing across a PMF should give the total probability 1
210         self.failUnlessAlmostEqual(sum(pmf_comp), 1)
211         self.should_assert("Should assert if not 0<=p<=1", f, 1, -1)
212         self.should_assert("Should assert if n < 1", f, 0, .1)
213
214         out = StringIO()
215         statistics.print_pmf(pmf_comp, out=out)
216         lines = out.getvalue().splitlines()
217         self.failUnlessEqual(lines[0], "i=0: 0.81")
218         self.failUnlessEqual(lines[1], "i=1: 0.18")
219         self.failUnlessEqual(lines[2], "i=2: 0.01")
220
221     def test_survival_pmf(self):
222         f = statistics.survival_pmf
223         # Cross-check binomial-distribution method against convolution
224         # method.
225         p_list = [.9999] * 100 + [.99] * 50 + [.8] * 20
226         pmf1 = statistics.survival_pmf_via_conv(p_list)
227         pmf2 = statistics.survival_pmf_via_bd(p_list)
228         self.failUnlessListAlmostEqual(pmf1, pmf2)
229         self.failUnlessTrue(statistics.valid_pmf(pmf1))
230         self.should_assert("Should assert if p_i > 1", f, [1.1]);
231         self.should_assert("Should assert if p_i < 0", f, [-.1]);
232
233     def test_repair_count_pmf(self):
234         survival_pmf = statistics.binomial_distribution_pmf(5, .9)
235         repair_pmf = statistics.repair_count_pmf(survival_pmf, 3)
236         # repair_pmf[0] == sum(survival_pmf[0,1,2,5])
237         # repair_pmf[1] == survival_pmf[4]
238         # repair_pmf[2] = survival_pmf[3]
239         self.failUnlessListAlmostEqual(repair_pmf,
240                                        [0.00001 + 0.00045 + 0.0081 + 0.59049,
241                                         .32805,
242                                         .0729,
243                                         0, 0, 0])
244
245     def test_repair_cost(self):
246         survival_pmf = statistics.binomial_distribution_pmf(5, .9)
247         bwcost = statistics.bandwidth_cost_function
248         cost = statistics.mean_repair_cost(bwcost, 1000,
249                                            survival_pmf, 3, ul_dl_ratio=1.0)
250         self.failUnlessAlmostEqual(cost, 558.90)
251         cost = statistics.mean_repair_cost(bwcost, 1000,
252                                            survival_pmf, 3, ul_dl_ratio=8.0)
253         self.failUnlessAlmostEqual(cost, 1664.55)
254
255         # I haven't manually checked the math beyond here -warner
256         cost = statistics.eternal_repair_cost(bwcost, 1000,
257                                               survival_pmf, 3,
258                                               discount_rate=0, ul_dl_ratio=1.0)
259         self.failUnlessAlmostEqual(cost, 65292.056074766246)
260         cost = statistics.eternal_repair_cost(bwcost, 1000,
261                                               survival_pmf, 3,
262                                               discount_rate=0.05,
263                                               ul_dl_ratio=1.0)
264         self.failUnlessAlmostEqual(cost, 9133.6097158191551)
265
266     def test_convolve(self):
267         f = statistics.convolve
268         v1 = [ 1, 2, 3 ]
269         v2 = [ 4, 5, 6 ]
270         v3 = [ 7, 8 ]
271         v1v2result = [ 4, 13, 28, 27, 18 ]
272         # Convolution is commutative
273         r1 = f(v1, v2)
274         r2 = f(v2, v1)
275         self.failUnlessListEqual(r1, r2, "Convolution should be commutative")
276         self.failUnlessListEqual(r1, v1v2result, "Didn't match known result")
277         # Convolution is associative
278         r1 = f(f(v1, v2), v3)
279         r2 = f(v1, f(v2, v3))
280         self.failUnlessListEqual(r1, r2, "Convolution should be associative")
281         # Convolution is distributive
282         r1 = f(v3, [ a + b for a, b in zip(v1, v2) ])
283         tmp1 = f(v3, v1)
284         tmp2 = f(v3, v2)
285         r2 = [ a + b for a, b in zip(tmp1, tmp2) ]
286         self.failUnlessListEqual(r1, r2, "Convolution should be distributive")
287         # Convolution is scalar multiplication associative
288         tmp1 = f(v1, v2)
289         r1 = [ a * 4 for a in tmp1 ]
290         tmp2 = [ a * 4 for a in v1 ]
291         r2 = f(tmp2, v2)
292         self.failUnlessListEqual(r1, r2, "Convolution should be scalar multiplication associative")
293
294     def test_find_k(self):
295         f = statistics.find_k
296         g = statistics.pr_file_loss
297         plist = [.9] * 10 + [.8] * 10 # N=20
298         t = .0001
299         k = f(plist, t)
300         self.failUnlessEqual(k, 10)
301         self.failUnless(g(plist, k) < t)
302
303     def test_pr_file_loss(self):
304         f = statistics.pr_file_loss
305         plist = [.5] * 10
306         self.failUnlessEqual(f(plist, 3), .0546875)
307
308     def test_pr_backup_file_loss(self):
309         f = statistics.pr_backup_file_loss
310         plist = [.5] * 10
311         self.failUnlessEqual(f(plist, .5, 3), .02734375)
312
313
314 class Asserts(unittest.TestCase):
315     def should_assert(self, func, *args, **kwargs):
316         try:
317             func(*args, **kwargs)
318         except AssertionError, e:
319             return str(e)
320         except Exception, e:
321             self.fail("assert failed with non-AssertionError: %s" % e)
322         self.fail("assert was not caught")
323
324     def should_not_assert(self, func, *args, **kwargs):
325         try:
326             func(*args, **kwargs)
327         except AssertionError, e:
328             self.fail("assertion fired when it should not have: %s" % e)
329         except Exception, e:
330             self.fail("assertion (which shouldn't have failed) failed with non-AssertionError: %s" % e)
331         return # we're happy
332
333
334     def test_assert(self):
335         f = assertutil._assert
336         self.should_assert(f)
337         self.should_assert(f, False)
338         self.should_not_assert(f, True)
339
340         m = self.should_assert(f, False, "message")
341         self.failUnlessEqual(m, "'message' <type 'str'>", m)
342         m = self.should_assert(f, False, "message1", othermsg=12)
343         self.failUnlessEqual("'message1' <type 'str'>, othermsg: 12 <type 'int'>", m)
344         m = self.should_assert(f, False, othermsg="message2")
345         self.failUnlessEqual("othermsg: 'message2' <type 'str'>", m)
346
347     def test_precondition(self):
348         f = assertutil.precondition
349         self.should_assert(f)
350         self.should_assert(f, False)
351         self.should_not_assert(f, True)
352
353         m = self.should_assert(f, False, "message")
354         self.failUnlessEqual("precondition: 'message' <type 'str'>", m)
355         m = self.should_assert(f, False, "message1", othermsg=12)
356         self.failUnlessEqual("precondition: 'message1' <type 'str'>, othermsg: 12 <type 'int'>", m)
357         m = self.should_assert(f, False, othermsg="message2")
358         self.failUnlessEqual("precondition: othermsg: 'message2' <type 'str'>", m)
359
360     def test_postcondition(self):
361         f = assertutil.postcondition
362         self.should_assert(f)
363         self.should_assert(f, False)
364         self.should_not_assert(f, True)
365
366         m = self.should_assert(f, False, "message")
367         self.failUnlessEqual("postcondition: 'message' <type 'str'>", m)
368         m = self.should_assert(f, False, "message1", othermsg=12)
369         self.failUnlessEqual("postcondition: 'message1' <type 'str'>, othermsg: 12 <type 'int'>", m)
370         m = self.should_assert(f, False, othermsg="message2")
371         self.failUnlessEqual("postcondition: othermsg: 'message2' <type 'str'>", m)
372
373 class FileUtil(unittest.TestCase):
374     def mkdir(self, basedir, path, mode=0777):
375         fn = os.path.join(basedir, path)
376         fileutil.make_dirs(fn, mode)
377
378     def touch(self, basedir, path, mode=None, data="touch\n"):
379         fn = os.path.join(basedir, path)
380         f = open(fn, "w")
381         f.write(data)
382         f.close()
383         if mode is not None:
384             os.chmod(fn, mode)
385
386     def test_rm_dir(self):
387         basedir = "util/FileUtil/test_rm_dir"
388         fileutil.make_dirs(basedir)
389         # create it again to test idempotency
390         fileutil.make_dirs(basedir)
391         d = os.path.join(basedir, "doomed")
392         self.mkdir(d, "a/b")
393         self.touch(d, "a/b/1.txt")
394         self.touch(d, "a/b/2.txt", 0444)
395         self.touch(d, "a/b/3.txt", 0)
396         self.mkdir(d, "a/c")
397         self.touch(d, "a/c/1.txt")
398         self.touch(d, "a/c/2.txt", 0444)
399         self.touch(d, "a/c/3.txt", 0)
400         os.chmod(os.path.join(d, "a/c"), 0444)
401         self.mkdir(d, "a/d")
402         self.touch(d, "a/d/1.txt")
403         self.touch(d, "a/d/2.txt", 0444)
404         self.touch(d, "a/d/3.txt", 0)
405         os.chmod(os.path.join(d, "a/d"), 0)
406
407         fileutil.rm_dir(d)
408         self.failIf(os.path.exists(d))
409         # remove it again to test idempotency
410         fileutil.rm_dir(d)
411
412     def test_remove_if_possible(self):
413         basedir = "util/FileUtil/test_remove_if_possible"
414         fileutil.make_dirs(basedir)
415         self.touch(basedir, "here")
416         fn = os.path.join(basedir, "here")
417         fileutil.remove_if_possible(fn)
418         self.failIf(os.path.exists(fn))
419         fileutil.remove_if_possible(fn) # should be idempotent
420         fileutil.rm_dir(basedir)
421         fileutil.remove_if_possible(fn) # should survive errors
422
423     def test_open_or_create(self):
424         basedir = "util/FileUtil/test_open_or_create"
425         fileutil.make_dirs(basedir)
426         fn = os.path.join(basedir, "here")
427         f = fileutil.open_or_create(fn)
428         f.write("stuff.")
429         f.close()
430         f = fileutil.open_or_create(fn)
431         f.seek(0, 2)
432         f.write("more.")
433         f.close()
434         f = open(fn, "r")
435         data = f.read()
436         f.close()
437         self.failUnlessEqual(data, "stuff.more.")
438
439     def test_NamedTemporaryDirectory(self):
440         basedir = "util/FileUtil/test_NamedTemporaryDirectory"
441         fileutil.make_dirs(basedir)
442         td = fileutil.NamedTemporaryDirectory(dir=basedir)
443         name = td.name
444         self.failUnless(basedir in name)
445         self.failUnless(basedir in repr(td))
446         self.failUnless(os.path.isdir(name))
447         del td
448         # it is conceivable that we need to force gc here, but I'm not sure
449         self.failIf(os.path.isdir(name))
450
451     def test_rename(self):
452         basedir = "util/FileUtil/test_rename"
453         fileutil.make_dirs(basedir)
454         self.touch(basedir, "here")
455         fn = os.path.join(basedir, "here")
456         fn2 = os.path.join(basedir, "there")
457         fileutil.rename(fn, fn2)
458         self.failIf(os.path.exists(fn))
459         self.failUnless(os.path.exists(fn2))
460
461     def test_du(self):
462         basedir = "util/FileUtil/test_du"
463         fileutil.make_dirs(basedir)
464         d = os.path.join(basedir, "space-consuming")
465         self.mkdir(d, "a/b")
466         self.touch(d, "a/b/1.txt", data="a"*10)
467         self.touch(d, "a/b/2.txt", data="b"*11)
468         self.mkdir(d, "a/c")
469         self.touch(d, "a/c/1.txt", data="c"*12)
470         self.touch(d, "a/c/2.txt", data="d"*13)
471
472         used = fileutil.du(basedir)
473         self.failUnlessEqual(10+11+12+13, used)
474
475     def test_abspath_expanduser_unicode(self):
476         self.failUnlessRaises(AssertionError, fileutil.abspath_expanduser_unicode, "bytestring")
477
478         saved_cwd = os.path.normpath(os.getcwdu())
479         abspath_cwd = fileutil.abspath_expanduser_unicode(u".")
480         self.failUnless(isinstance(saved_cwd, unicode), saved_cwd)
481         self.failUnless(isinstance(abspath_cwd, unicode), abspath_cwd)
482         self.failUnlessEqual(abspath_cwd, saved_cwd)
483
484         # adapted from <http://svn.python.org/view/python/branches/release26-maint/Lib/test/test_posixpath.py?view=markup&pathrev=78279#test_abspath>
485
486         self.failUnlessIn(u"foo", fileutil.abspath_expanduser_unicode(u"foo"))
487         self.failIfIn(u"~", fileutil.abspath_expanduser_unicode(u"~"))
488
489         cwds = ['cwd']
490         try:
491             cwds.append(u'\xe7w\xf0'.encode(sys.getfilesystemencoding()
492                                             or 'ascii'))
493         except UnicodeEncodeError:
494             pass # the cwd can't be encoded -- test with ascii cwd only
495
496         for cwd in cwds:
497             try:
498                 os.mkdir(cwd)
499                 os.chdir(cwd)
500                 for upath in (u'', u'fuu', u'f\xf9\xf9', u'/fuu', u'U:\\', u'~'):
501                     uabspath = fileutil.abspath_expanduser_unicode(upath)
502                     self.failUnless(isinstance(uabspath, unicode), uabspath)
503             finally:
504                 os.chdir(saved_cwd)
505
class PollMixinTests(unittest.TestCase):
    """Tests for pollmixin.PollMixin.poll()."""

    def setUp(self):
        self.pm = pollmixin.PollMixin()

    def test_PollMixin_True(self):
        # the check succeeds on the very first poll
        return self.pm.poll(check_f=lambda : True,
                            pollinterval=0.1)

    def test_PollMixin_False_then_True(self):
        # the check fails once, then succeeds
        answers = iter([False, True])
        return self.pm.poll(check_f=answers.next,
                            pollinterval=0.1)

    def test_timeout(self):
        # the check never succeeds, so poll() must errback with TimeoutError
        def _unexpected_success(res):
            self.fail("poll should have failed, not returned %s" % (res,))
        def _expected_failure(f):
            f.trap(pollmixin.TimeoutError)
            return None # success
        polling = self.pm.poll(check_f=lambda: False,
                               pollinterval=0.01,
                               timeout=1)
        polling.addCallbacks(_unexpected_success, _expected_failure)
        return polling
532
class DeferredUtilTests(unittest.TestCase):
    """Tests for gatherResults and DeferredListShouldSucceed in
    allmydata.util.deferredutil."""

    def test_gather_results(self):
        first = defer.Deferred()
        second = defer.Deferred()
        gathered = deferredutil.gatherResults([first, second])
        # one failing input must errback the gathered result
        first.errback(ValueError("BAD"))
        def _unexpected_result(res):
            self.fail("Should have errbacked, not resulted in %s" % (res,))
        def _expected_failure(thef):
            thef.trap(ValueError)
        gathered.addCallbacks(_unexpected_result, _expected_failure)
        return gathered

    def test_success(self):
        first, second = defer.Deferred(), defer.Deferred()
        successes = []
        failures = []
        combined = deferredutil.DeferredListShouldSucceed([first, second])
        combined.addCallbacks(successes.append, failures.append)
        first.callback(1)
        second.callback(2)
        # both results arrive together, in input order
        self.failUnlessEqual(successes, [[1,2]])
        self.failUnlessEqual(failures, [])

    def test_failure(self):
        first, second = defer.Deferred(), defer.Deferred()
        successes = []
        failures = []
        combined = deferredutil.DeferredListShouldSucceed([first, second])
        combined.addCallbacks(successes.append, failures.append)
        first.addErrback(lambda _ignore: None)
        second.addErrback(lambda _ignore: None)
        first.callback(1)
        second.errback(ValueError())
        self.failUnlessEqual(successes, [])
        self.failUnlessEqual(len(failures), 1)
        # the combined deferred errbacks with the original ValueError failure
        failure = failures[0]
        self.failUnless(isinstance(failure, Failure))
        self.failUnless(failure.check(ValueError))
572
class HashUtilTests(unittest.TestCase):
    """Tests for allmydata.util.hashutil: one-shot hashes versus their
    incremental hashers, plus known-answer checks."""

    def test_random_key(self):
        key = hashutil.random_key()
        self.failUnlessEqual(len(key), hashutil.KEYLEN)

    def test_sha256d(self):
        one_shot = hashutil.tagged_hash("tag1", "value")
        hasher = hashutil.tagged_hasher("tag1")
        hasher.update("value")
        # digest() is taken twice: both results must match the one-shot hash
        first_digest = hasher.digest()
        second_digest = hasher.digest()
        self.failUnlessEqual(one_shot, first_digest)
        self.failUnlessEqual(first_digest, second_digest)

    def test_sha256d_truncated(self):
        one_shot = hashutil.tagged_hash("tag1", "value", 16)
        hasher = hashutil.tagged_hasher("tag1", 16)
        hasher.update("value")
        incremental = hasher.digest()
        self.failUnlessEqual(len(one_shot), 16)
        self.failUnlessEqual(len(incremental), 16)
        self.failUnlessEqual(one_shot, incremental)

    def test_chk(self):
        one_shot = hashutil.convergence_hash(3, 10, 1000, "data", "secret")
        hasher = hashutil.convergence_hasher(3, 10, 1000, "secret")
        hasher.update("data")
        self.failUnlessEqual(one_shot, hasher.digest())

    def test_hashers(self):
        # each one-shot function must agree with its incremental hasher
        pairs = [
            (hashutil.block_hash, hashutil.block_hasher),
            (hashutil.uri_extension_hash, hashutil.uri_extension_hasher),
            (hashutil.plaintext_hash, hashutil.plaintext_hasher),
            (hashutil.crypttext_hash, hashutil.crypttext_hasher),
            (hashutil.crypttext_segment_hash, hashutil.crypttext_segment_hasher),
            (hashutil.plaintext_segment_hash, hashutil.plaintext_segment_hasher),
            ]
        for (one_shot_f, make_hasher) in pairs:
            one_shot = one_shot_f("foo")
            hasher = make_hasher()
            hasher.update("foo")
            self.failUnlessEqual(one_shot, hasher.digest())

    def test_constant_time_compare(self):
        same = hashutil.constant_time_compare
        self.failUnless(same("a", "a"))
        self.failUnless(same("ab", "ab"))
        self.failIf(same("a", "b"))
        self.failIf(same("a", "aa"))

    def _testknown(self, hashf, expected_a, *args):
        """Check that base32.b2a(hashf(*args)) equals expected_a."""
        actual_a = base32.b2a(hashf(*args))
        self.failUnlessEqual(actual_a, expected_a)

    def test_known_answers(self):
        # assert backwards compatibility
        zeros = "\x00"*20
        # (hash function, expected base32 text, arguments)
        known = [
            (hashutil.storage_index_hash, "qb5igbhcc5esa6lwqorsy7e6am", ("",)),
            (hashutil.block_hash, "msjr5bh4evuh7fa3zw7uovixfbvlnstr5b65mrerwfnvjxig2jvq", ("",)),
            (hashutil.uri_extension_hash, "wthsu45q7zewac2mnivoaa4ulh5xvbzdmsbuyztq2a5fzxdrnkka", ("",)),
            (hashutil.plaintext_hash, "5lz5hwz3qj3af7n6e3arblw7xzutvnd3p3fjsngqjcb7utf3x3da", ("",)),
            (hashutil.crypttext_hash, "itdj6e4njtkoiavlrmxkvpreosscssklunhwtvxn6ggho4rkqwga", ("",)),
            (hashutil.crypttext_segment_hash, "aovy5aa7jej6ym5ikgwyoi4pxawnoj3wtaludjz7e2nb5xijb7aa", ("",)),
            (hashutil.plaintext_segment_hash, "4fdgf6qruaisyukhqcmoth4t3li6bkolbxvjy4awwcpprdtva7za", ("",)),
            (hashutil.convergence_hash, "3mo6ni7xweplycin6nowynw2we", (3, 10, 100, "", "converge")),
            (hashutil.my_renewal_secret_hash, "ujhr5k5f7ypkp67jkpx6jl4p47pyta7hu5m527cpcgvkafsefm6q", ("",)),
            (hashutil.my_cancel_secret_hash, "rjwzmafe2duixvqy6h47f5wfrokdziry6zhx4smew4cj6iocsfaa", ("",)),
            (hashutil.file_renewal_secret_hash, "hzshk2kf33gzbd5n3a6eszkf6q6o6kixmnag25pniusyaulqjnia", ("", "si")),
            (hashutil.file_cancel_secret_hash, "bfciwvr6w7wcavsngxzxsxxaszj72dej54n4tu2idzp6b74g255q", ("", "si")),
            (hashutil.bucket_renewal_secret_hash, "e7imrzgzaoashsncacvy3oysdd2m5yvtooo4gmj4mjlopsazmvuq", ("", zeros)),
            (hashutil.bucket_cancel_secret_hash, "dvdujeyxeirj6uux6g7xcf4lvesk632aulwkzjar7srildvtqwma", ("", zeros)),
            (hashutil.hmac, "c54ypfi6pevb3nvo6ba42jtglpkry2kbdopqsi7dgrm4r7tw5sra", ("tag", "")),
            (hashutil.mutable_rwcap_key_hash, "6rvn2iqrghii5n4jbbwwqqsnqu", ("iv", "wk")),
            (hashutil.ssk_writekey_hash, "ykpgmdbpgbb6yqz5oluw2q26ye", ("",)),
            (hashutil.ssk_write_enabler_master_hash, "izbfbfkoait4dummruol3gy2bnixrrrslgye6ycmkuyujnenzpia", ("",)),
            (hashutil.ssk_write_enabler_hash, "fuu2dvx7g6gqu5x22vfhtyed7p4pd47y5hgxbqzgrlyvxoev62tq", ("wk", zeros)),
            (hashutil.ssk_pubkey_fingerprint_hash, "3opzw4hhm2sgncjx224qmt5ipqgagn7h5zivnfzqycvgqgmgz35q", ("",)),
            (hashutil.ssk_readkey_hash, "vugid4as6qbqgeq2xczvvcedai", ("",)),
            (hashutil.ssk_readkey_data_hash, "73wsaldnvdzqaf7v4pzbr2ae5a", ("iv", "rk")),
            (hashutil.ssk_storage_index_hash, "j7icz6kigb6hxrej3tv4z7ayym", ("",)),
            ]
        for (hashf, expected_a, args) in known:
            self._testknown(hashf, expected_a, *args)
671
672
class Abbreviate(unittest.TestCase):
    """Tests for the time/size formatting and size parsing helpers in
    allmydata.util.abbreviate."""

    def test_time(self):
        describe = abbreviate.abbreviate_time
        MIN = 60
        HOUR = 60*MIN
        DAY = 24*HOUR
        MONTH = 30*DAY
        YEAR = 365*DAY
        # (seconds, expected text)
        cases = [(None, "unknown"),
                 (0, "0 seconds"),
                 (1, "1 second"),
                 (2, "2 seconds"),
                 (119, "119 seconds"),
                 (2*MIN, "2 minutes"),
                 (60*MIN, "60 minutes"),
                 (179*MIN, "179 minutes"),
                 (180*MIN, "3 hours"),
                 (4*HOUR, "4 hours"),
                 (2*DAY, "2 days"),
                 (2*MONTH, "2 months"),
                 (5*YEAR, "5 years"),
                 ]
        for (seconds, expected) in cases:
            self.failUnlessEqual(describe(seconds), expected)

    def test_space(self):
        # decimal (SI) units: powers of 1000
        tests_si = [(None, "unknown"),
                    (0, "0 B"),
                    (1, "1 B"),
                    (999, "999 B"),
                    (1000, "1000 B"),
                    (1023, "1023 B"),
                    (1024, "1.02 kB"),
                    (20*1000, "20.00 kB"),
                    (1024*1024, "1.05 MB"),
                    (1000*1000, "1.00 MB"),
                    (1000*1000*1000, "1.00 GB"),
                    (1000*1000*1000*1000, "1.00 TB"),
                    (1000*1000*1000*1000*1000, "1.00 PB"),
                    (1234567890123456, "1.23 PB"),
                    ]
        # binary units: powers of 1024
        tests_base1024 = [(None, "unknown"),
                          (0, "0 B"),
                          (1, "1 B"),
                          (999, "999 B"),
                          (1000, "1000 B"),
                          (1023, "1023 B"),
                          (1024, "1.00 kiB"),
                          (20*1024, "20.00 kiB"),
                          (1000*1000, "976.56 kiB"),
                          (1024*1024, "1.00 MiB"),
                          (1024*1024*1024, "1.00 GiB"),
                          (1024*1024*1024*1024, "1.00 TiB"),
                          (1000*1000*1000*1000*1000, "909.49 TiB"),
                          (1024*1024*1024*1024*1024, "1.00 PiB"),
                          (1234567890123456, "1.10 PiB"),
                          ]
        for (use_si, table) in [(True, tests_si), (False, tests_base1024)]:
            for (size, expected) in table:
                rendered = abbreviate.abbreviate_space(size, SI=use_si)
                self.failUnlessEqual(rendered, expected)

        both = abbreviate.abbreviate_space_both(1234567)
        self.failUnlessEqual(both, "(1.23 MB, 1.18 MiB)")

    def test_parse_space(self):
        parse = abbreviate.parse_abbreviated_size
        # (text, expected number of bytes)
        cases = [("", None),
                 (None, None),
                 ("123", 123),
                 ("123B", 123),
                 ("2K", 2000),
                 ("2kb", 2000),
                 ("2KiB", 2048),
                 ("10MB", 10*1000*1000),
                 ("10MiB", 10*1024*1024),
                 ("5G", 5*1000*1000*1000),
                 ("4GiB", 4*1024*1024*1024),
                 ]
        for (text, expected) in cases:
            self.failUnlessEqual(parse(text), expected)
        # an unknown unit is rejected, and the error quotes the input
        error = self.failUnlessRaises(ValueError, parse, "12 cubits")
        self.failUnless("12 cubits" in str(error))
753
class Limiter(unittest.TestCase):
    # Tests for limiter.ConcurrencyLimiter: twenty one-second jobs are
    # queued at once and the test tracks how many run simultaneously.
    timeout = 480 # This takes longer than 240 seconds on Francois's arm box.

    def job(self, i, foo):
        # Fake task: record the call, bump the in-flight counters, and
        # return a Deferred that fires with "done <i>" one second later.
        self.calls.append( (i, foo) )
        self.simultaneous += 1
        if self.simultaneous > self.peak_simultaneous:
            self.peak_simultaneous = self.simultaneous
        finished = defer.Deferred()
        def _retire():
            self.simultaneous -= 1
            finished.callback("done %d" % i)
        reactor.callLater(1.0, _retire)
        return finished

    def bad_job(self, i, foo):
        # Fake task that fails synchronously.
        raise ValueError("bad_job %d" % i)

    def _reset_bookkeeping(self):
        # Fresh counters for one test run.
        self.calls = []
        self.simultaneous = 0
        self.peak_simultaneous = 0

    def _start_twenty_jobs(self, l):
        # Queue job(0..19) on the limiter; returns the list of Deferreds.
        return [l.add(self.job, i, foo=str(i)) for i in range(20)]

    def _check_every_job_ran(self):
        # At most ten jobs overlapped, and each of the twenty was called
        # with its expected arguments.
        self.failUnless(self.peak_simultaneous <= 10)
        self.failUnlessEqual(len(self.calls), 20)
        for i in range(20):
            self.failUnless( (i, str(i)) in self.calls)

    def test_limiter(self):
        self._reset_bookkeeping()
        l = limiter.ConcurrencyLimiter()
        dl = self._start_twenty_jobs(l)
        d = defer.DeferredList(dl, fireOnOneErrback=True)
        def _done(res):
            self.failUnlessEqual(self.simultaneous, 0)
            self._check_every_job_ran()
        d.addCallback(_done)
        return d

    def test_errors(self):
        self._reset_bookkeeping()
        l = limiter.ConcurrencyLimiter()
        dl = self._start_twenty_jobs(l)
        # a 21st job that raises; its Deferred is examined after the
        # twenty good jobs have completed
        d2 = l.add(self.bad_job, 21, "21")
        d = defer.DeferredList(dl, fireOnOneErrback=True)
        def _most_done(res):
            results = []
            for (success, result) in res:
                self.failUnlessEqual(success, True)
                results.append(result)
            results.sort()
            expected_results = sorted(["done %d" % i for i in range(20)])
            self.failUnlessEqual(results, expected_results)
            self._check_every_job_ran()
            def _good(res):
                self.fail("should have failed, not got %s" % (res,))
            def _err(f):
                f.trap(ValueError)
                self.failUnless("bad_job 21" in str(f))
            d2.addCallbacks(_good, _err)
            return d2
        d.addCallback(_most_done)
        def _all_done(res):
            self.failUnlessEqual(self.simultaneous, 0)
            self._check_every_job_ran()
        d.addCallback(_all_done)
        return d
828
class TimeFormat(unittest.TestCase):
    # Tests for allmydata.util.time_format: ISO-8601 formatting/parsing of
    # UTC timestamps, plus duration and date parsing.

    def test_epoch(self):
        return self._help_test_epoch()

    def test_epoch_in_London(self):
        # Europe/London is a particularly troublesome timezone.  Nowadays, its
        # offset from GMT is 0.  But in 1970, its offset from GMT was 1.
        # (Apparently in 1970 Britain had redefined standard time to be GMT+1
        # and stayed in standard time all year round, whereas today
        # Europe/London standard time is GMT and Europe/London Daylight
        # Savings Time is GMT+1.)  The current implementation of
        # time_format.iso_utc_time_to_localseconds() breaks if the timezone is
        # Europe/London.  (As soon as this unit test is done then I'll change
        # that implementation to something that works even in this case...)
        #
        # The epoch checks are re-run with TZ forced to Europe/London; the
        # previous TZ setting (or its absence) is restored afterwards.
        have_tzset = hasattr(time, 'tzset')
        saved_tz = os.environ.get('TZ')
        os.environ['TZ'] = "Europe/London"
        if have_tzset:
            time.tzset()
        try:
            return self._help_test_epoch()
        finally:
            if saved_tz is not None:
                os.environ['TZ'] = saved_tz
            else:
                del os.environ['TZ']
            if have_tzset:
                time.tzset()

    def _help_test_epoch(self):
        # Shared body of the epoch tests; also checks that time.tzname is
        # the same at the end as at the start.
        to_seconds = time_format.iso_utc_time_to_seconds
        tzname_before = time.tzname

        # "T", "_" and " " are all accepted between date and time
        for stamp in ("1970-01-01T00:00:01",
                      "1970-01-01_00:00:01",
                      "1970-01-01 00:00:01"):
            self.failUnlessEqual(to_seconds(stamp), 1.0)

        self.failUnlessEqual(time_format.iso_utc(1.0), "1970-01-01_00:00:01")
        self.failUnlessEqual(time_format.iso_utc(1.0, sep=" "),
                             "1970-01-01 00:00:01")

        # formatting the current time and parsing it back lands in the same
        # whole second
        now = time.time()
        roundtripped = to_seconds(time_format.iso_utc(now))
        self.failUnlessEqual(int(roundtripped), int(now))

        # t= may be a callable that supplies the time
        self.failUnlessEqual(time_format.iso_utc(t=lambda: 1.0),
                             "1970-01-01_00:00:01")
        err = self.failUnlessRaises(ValueError, to_seconds,
                                    "invalid timestring")
        self.failUnless("not a complete ISO8601 timestamp" in str(err))
        self.failUnlessEqual(to_seconds("1970-01-01_00:00:01.500"), 1.5)

        # Look for daylight-savings-related errors.
        in_march = to_seconds("2009-03-20 21:49:02.226536")
        self.failUnlessEqual(in_march, 1237585742.226536)
        self.failUnlessEqual(tzname_before, time.tzname)

    def test_iso_utc(self):
        when = 1266760143.7841301
        self.failUnlessEqual(time_format.iso_utc_date(when), "2010-02-21")
        self.failUnlessEqual(time_format.iso_utc_date(t=lambda: when),
                             "2010-02-21")
        self.failUnlessEqual(time_format.iso_utc(when),
                             "2010-02-21_13:49:03.784130")
        self.failUnlessEqual(time_format.iso_utc(when, sep="-"),
                             "2010-02-21-13:49:03.784130")

    def test_parse_duration(self):
        # Expected values are in seconds; these cases treat a month as 31
        # days and a year as 365 days.
        DAY = 24*60*60
        cases = [("1 day", DAY),
                 ("2 days", 2*DAY),
                 ("3 months", 3*31*DAY),
                 ("4 mo", 4*31*DAY),
                 ("5 years", 5*365*DAY),
                 ]
        parse = time_format.parse_duration
        for (text, expected) in cases:
            self.failUnlessEqual(parse(text), expected)
        err = self.failUnlessRaises(ValueError, parse, "123")
        self.failUnlessIn("no unit (like day, month, or year) in '123'",
                          str(err))

    def test_parse_date(self):
        parsed = time_format.parse_date("2010-02-21")
        self.failUnlessEqual(parsed, 1266710400)
916
class CacheDir(unittest.TestCase):
    # Tests cachedir.CacheDirectoryManager: which cache files survive
    # check() as the test drops its handles and changes the age threshold.

    def test_basic(self):
        basedir = "test_util/CacheDir/test_basic"

        def _expect(present, absent=""):
            # Each character names a file under basedir; assert the files
            # in `absent` are gone, then that those in `present` exist.
            for name in absent:
                absfn = os.path.join(basedir, name)
                self.failIf(os.path.exists(absfn),
                            "%s exists but it shouldn't" % absfn)
            for name in present:
                absfn = os.path.join(basedir, name)
                self.failUnless(os.path.exists(absfn),
                                "%s doesn't exist but it should" % absfn)

        def _write_hi(cachefile):
            # A helper, so no reference to the handle outlives the call.
            f = open(cachefile.get_filename(), "wb")
            f.write("hi")
            f.close()

        cdm = cachedir.CacheDirectoryManager(basedir)
        file_a = cdm.get_file("a")
        file_b = cdm.get_file("b")
        file_c = cdm.get_file("c")
        _write_hi(file_a)
        _write_hi(file_b)
        _write_hi(file_c)
        _expect("abc")

        # all three handles are still held: check() keeps every file
        cdm.check()
        _expect("abc")

        del file_a
        # this file won't be deleted yet, because it isn't old enough
        cdm.check()
        _expect("abc")

        # we change the definition of "old" to make everything old
        cdm.old = -10
        cdm.check()
        _expect("bc", absent="a")

        # back to a one-hour threshold: dropping the handle for "b" does
        # not get it deleted
        cdm.old = 60*60
        del file_b
        cdm.check()
        _expect("bc", absent="a")

        # asking for "b" again and re-checking leaves the files unchanged
        b_again = cdm.get_file("b")
        cdm.check()
        _expect("bc", absent="a")
        del b_again
980
# Running serial number; each EqButNotIs instance takes the next value as
# its hash.
ctr = [0]
class EqButNotIs:
    """Wrapper whose comparisons all delegate to the wrapped value.

    An instance compares like its `x` but is a distinct object, and its hash
    is a per-instance serial number taken from the module-level `ctr`.
    """
    def __init__(self, x):
        serial = ctr[0]
        ctr[0] = serial + 1
        self.x = x
        self.hash = serial
    def __hash__(self):
        return self.hash
    def __repr__(self):
        return "<%s %s>" % (self.__class__.__name__, self.x,)
    # every rich comparison is forwarded to the wrapped value
    def __eq__(self, other):
        return self.x == other
    def __ne__(self, other):
        return self.x != other
    def __lt__(self, other):
        return self.x < other
    def __le__(self, other):
        return self.x <= other
    def __gt__(self, other):
        return self.x > other
    def __ge__(self, other):
        return self.x >= other
1003
class DictUtil(unittest.TestCase):
    # Tests for the dict-like classes and helper functions in
    # allmydata.util.dictutil.

    def _help_test_empty_dict(self, klass):
        # An instance built with no arguments equals one built from {}, and
        # both have length zero.
        d1 = klass()
        d2 = klass({})

        self.failUnless(d1 == d2, "d1: %r, d2: %r" % (d1, d2,))
        self.failUnless(len(d1) == 0)
        self.failUnless(len(d2) == 0)

    def _help_test_nonempty_dict(self, klass):
        # Two instances built from equal data compare equal and report the
        # right length.
        d1 = klass({'a': 1, 'b': "eggs", 3: "spam",})
        d2 = klass({'a': 1, 'b': "eggs", 3: "spam",})

        self.failUnless(d1 == d2)
        self.failUnless(len(d1) == 3, "%s, %s" % (len(d1), d1,))
        self.failUnless(len(d2) == 3)

    def _help_test_eq_but_notis(self, klass):
        # Exercise klass with EqButNotIs objects, which compare equal to
        # plain ints but are distinct objects with distinct hashes.
        d = klass({'a': 3, 'b': EqButNotIs(3), 'c': 3})
        d.pop('b')

        d.clear()
        d['a'] = 3
        d['b'] = EqButNotIs(3)
        d['c'] = 3
        d.pop('b')

        d.clear()
        d['b'] = EqButNotIs(3)
        d['a'] = 3
        d['c'] = 3
        d.pop('b')

        d.clear()
        d['a'] = EqButNotIs(3)
        d['c'] = 3
        d['a'] = 3

        # NOTE: the `is` checks against int literals below are deliberate
        # identity tests; they depend on the interpreter reusing small-int
        # objects.
        d.clear()
        fake3 = EqButNotIs(3)
        fake7 = EqButNotIs(7)
        d[fake3] = fake7
        d[3] = 7
        d[3] = 8
        self.failUnless(any(x is 8 for x in d.itervalues()))
        self.failUnless(any(x is fake7 for x in d.itervalues()))
        # The real 7 should have been ejected by the d[3] = 8.
        self.failUnless(not any(x is 7 for x in d.itervalues()))
        self.failUnless(any(x is fake3 for x in d.iterkeys()))
        self.failUnless(any(x is 3 for x in d.iterkeys()))
        d[fake3] = 8

        # same again, but with the real 3 inserted before the fake one
        d.clear()
        d[3] = 7
        fake3 = EqButNotIs(3)
        fake7 = EqButNotIs(7)
        d[fake3] = fake7
        d[3] = 8
        self.failUnless(any(x is 8 for x in d.itervalues()))
        self.failUnless(any(x is fake7 for x in d.itervalues()))
        # The real 7 should have been ejected by the d[3] = 8.
        self.failUnless(not any(x is 7 for x in d.itervalues()))
        self.failUnless(any(x is fake3 for x in d.iterkeys()))
        self.failUnless(any(x is 3 for x in d.iterkeys()))
        d[fake3] = 8

    def test_all(self):
        # Run each helper against each dict-like class.  (The empty-dict
        # helper used to be skipped, with the eq-but-not-is helper run twice
        # in its place.)
        classes = (dictutil.UtilDict,
                   dictutil.NumDict,
                   dictutil.ValueOrderedDict)
        helpers = (self._help_test_empty_dict,
                   self._help_test_nonempty_dict,
                   self._help_test_eq_but_notis)
        for helper in helpers:
            for klass in classes:
                helper(klass)

    def test_dict_of_sets(self):
        # DictOfSets maps each key to a set; add/discard/union/update edit
        # those sets, and a key disappears once its set is emptied.
        ds = dictutil.DictOfSets()
        ds.add(1, "a")
        ds.add(2, "b")
        ds.add(2, "b")
        ds.add(2, "c")
        self.failUnlessEqual(ds[1], set(["a"]))
        self.failUnlessEqual(ds[2], set(["b", "c"]))
        ds.discard(3, "d") # should not raise an exception
        ds.discard(2, "b")
        self.failUnlessEqual(ds[2], set(["c"]))
        ds.discard(2, "c")
        self.failIf(2 in ds)

        ds.union(1, ["a", "e"])
        ds.union(3, ["f"])
        self.failUnlessEqual(ds[1], set(["a","e"]))
        self.failUnlessEqual(ds[3], set(["f"]))
        ds2 = dictutil.DictOfSets()
        ds2.add(3, "f")
        ds2.add(3, "g")
        ds2.add(4, "h")
        ds.update(ds2)
        self.failUnlessEqual(ds[1], set(["a","e"]))
        self.failUnlessEqual(ds[3], set(["f", "g"]))
        self.failUnlessEqual(ds[4], set(["h"]))

    def test_move(self):
        # move(k, d1, d2) takes key k out of d1 and puts it into d2,
        # overwriting any existing d2[k].
        d1 = {1: "a", 2: "b"}
        d2 = {2: "c", 3: "d"}
        dictutil.move(1, d1, d2)
        self.failUnlessEqual(d1, {2: "b"})
        self.failUnlessEqual(d2, {1: "a", 2: "c", 3: "d"})

        d1 = {1: "a", 2: "b"}
        d2 = {2: "c", 3: "d"}
        dictutil.move(2, d1, d2)
        self.failUnlessEqual(d1, {1: "a"})
        self.failUnlessEqual(d2, {2: "b", 3: "d"})

        # with strict=True a missing key is an error
        d1 = {1: "a", 2: "b"}
        d2 = {2: "c", 3: "d"}
        self.failUnlessRaises(KeyError, dictutil.move, 5, d1, d2, strict=True)

    def test_subtract(self):
        # subtract(d1, d2) keeps the d1 items whose keys are not in d2.
        d1 = {1: "a", 2: "b"}
        d2 = {2: "c", 3: "d"}
        d3 = dictutil.subtract(d1, d2)
        self.failUnlessEqual(d3, {1: "a"})

        d1 = {1: "a", 2: "b"}
        d2 = {2: "c"}
        d3 = dictutil.subtract(d1, d2)
        self.failUnlessEqual(d3, {1: "a"})

    def test_utildict(self):
        d = dictutil.UtilDict({1: "a", 2: "b"})
        d.del_if_present(1)
        d.del_if_present(3)
        self.failUnlessEqual(d, {2: "b"})
        # comparing against a non-dict is expected to raise TypeError
        def eq(a, b):
            return a == b
        self.failUnlessRaises(TypeError, eq, d, "not a dict")

        d = dictutil.UtilDict({1: "b", 2: "a"})
        self.failUnlessEqual(d.items_sorted_by_value(),
                             [(2, "a"), (1, "b")])
        self.failUnlessEqual(d.items_sorted_by_key(),
                             [(1, "b"), (2, "a")])
        self.failUnlessEqual(repr(d), "{1: 'b', 2: 'a'}")
        self.failUnless(1 in d)

        d2 = dictutil.UtilDict({3: "c", 4: "d"})
        self.failUnless(d != d2)
        self.failUnless(d2 > d)
        self.failUnless(d2 >= d)
        self.failUnless(d <= d2)
        self.failUnless(d < d2)
        self.failUnlessEqual(d[1], "b")
        self.failUnlessEqual(sorted(list([k for k in d])), [1,2])

        d3 = d.copy()
        self.failUnlessEqual(d, d3)
        self.failUnless(isinstance(d3, dictutil.UtilDict))

        d4 = d.fromkeys([3,4], "e")
        self.failUnlessEqual(d4, {3: "e", 4: "e"})

        self.failUnlessEqual(d.get(1), "b")
        self.failUnlessEqual(d.get(3), None)
        self.failUnlessEqual(d.get(3, "default"), "default")
        self.failUnlessEqual(sorted(list(d.items())),
                             [(1, "b"), (2, "a")])
        self.failUnlessEqual(sorted(list(d.iteritems())),
                             [(1, "b"), (2, "a")])
        self.failUnlessEqual(sorted(d.keys()), [1, 2])
        self.failUnlessEqual(sorted(d.values()), ["a", "b"])
        # setdefault returns the existing value, or stores and returns the
        # new one
        x = d.setdefault(1, "new")
        self.failUnlessEqual(x, "b")
        self.failUnlessEqual(d[1], "b")
        x = d.setdefault(3, "new")
        self.failUnlessEqual(x, "new")
        self.failUnlessEqual(d[3], "new")
        del d[3]

        # popitem drains the two remaining items, then raises KeyError
        x = d.popitem()
        self.failUnless(x in [(1, "b"), (2, "a")])
        x = d.popitem()
        self.failUnless(x in [(1, "b"), (2, "a")])
        self.failUnlessRaises(KeyError, d.popitem)

    def test_numdict(self):
        d = dictutil.NumDict({"a": 1, "b": 2})

        # add_num(key, amount[, default]): the third argument is the
        # starting value used when the key is absent
        d.add_num("a", 10, 5)
        d.add_num("c", 20, 5)
        d.add_num("d", 30)
        self.failUnlessEqual(d, {"a": 11, "b": 2, "c": 25, "d": 30})

        d.subtract_num("a", 10)
        d.subtract_num("e", 10)
        d.subtract_num("f", 10, 15)
        self.failUnlessEqual(d, {"a": 1, "b": 2, "c": 25, "d": 30,
                                 "e": -10, "f": 5})

        self.failUnlessEqual(d.sum(), sum([1, 2, 25, 30, -10, 5]))

        d = dictutil.NumDict()
        d.inc("a")
        d.inc("a")
        d.inc("b", 5)
        self.failUnlessEqual(d, {"a": 2, "b": 6})
        d.dec("a")
        d.dec("c")
        d.dec("d", 5)
        self.failUnlessEqual(d, {"a": 1, "b": 6, "c": -1, "d": 4})
        self.failUnlessEqual(d.items_sorted_by_key(),
                             [("a", 1), ("b", 6), ("c", -1), ("d", 4)])
        self.failUnlessEqual(d.items_sorted_by_value(),
                             [("c", -1), ("a", 1), ("d", 4), ("b", 6)])
        self.failUnlessEqual(d.item_with_largest_value(), ("b", 6))

        d = dictutil.NumDict({"a": 1, "b": 2})
        self.failUnlessEqual(repr(d), "{'a': 1, 'b': 2}")
        self.failUnless("a" in d)

        d2 = dictutil.NumDict({"c": 3, "d": 4})
        self.failUnless(d != d2)
        self.failUnless(d2 > d)
        self.failUnless(d2 >= d)
        self.failUnless(d <= d2)
        self.failUnless(d < d2)
        self.failUnlessEqual(d["a"], 1)
        self.failUnlessEqual(sorted(list([k for k in d])), ["a","b"])
        # comparing against a non-dict is expected to raise TypeError
        def eq(a, b):
            return a == b
        self.failUnlessRaises(TypeError, eq, d, "not a dict")

        d3 = d.copy()
        self.failUnlessEqual(d, d3)
        self.failUnless(isinstance(d3, dictutil.NumDict))

        d4 = d.fromkeys(["a","b"], 5)
        self.failUnlessEqual(d4, {"a": 5, "b": 5})

        # unlike a plain dict, get() on a missing key yields 0 by default
        self.failUnlessEqual(d.get("a"), 1)
        self.failUnlessEqual(d.get("c"), 0)
        self.failUnlessEqual(d.get("c", 5), 5)
        self.failUnlessEqual(sorted(list(d.items())),
                             [("a", 1), ("b", 2)])
        self.failUnlessEqual(sorted(list(d.iteritems())),
                             [("a", 1), ("b", 2)])
        self.failUnlessEqual(sorted(d.keys()), ["a", "b"])
        self.failUnlessEqual(sorted(d.values()), [1, 2])
        self.failUnless(d.has_key("a"))
        self.failIf(d.has_key("c"))

        x = d.setdefault("c", 3)
        self.failUnlessEqual(x, 3)
        self.failUnlessEqual(d["c"], 3)
        x = d.setdefault("c", 5)
        self.failUnlessEqual(x, 3)
        self.failUnlessEqual(d["c"], 3)
        del d["c"]

        # popitem drains the two remaining items, then raises KeyError
        x = d.popitem()
        self.failUnless(x in [("a", 1), ("b", 2)])
        x = d.popitem()
        self.failUnless(x in [("a", 1), ("b", 2)])
        self.failUnlessRaises(KeyError, d.popitem)

        d.update({"c": 3})
        d.update({"c": 4, "d": 5})
        self.failUnlessEqual(d, {"c": 4, "d": 5})

    def test_del_if_present(self):
        # Module-level del_if_present() removes a key if it is there and
        # silently ignores a missing one.
        d = {1: "a", 2: "b"}
        dictutil.del_if_present(d, 1)
        dictutil.del_if_present(d, 3)
        self.failUnlessEqual(d, {2: "b"})

    def test_valueordereddict(self):
        # ValueOrderedDict lists its items in ascending order of value.
        d = dictutil.ValueOrderedDict()
        d["a"] = 3
        d["b"] = 2
        d["c"] = 1

        self.failUnlessEqual(d, {"a": 3, "b": 2, "c": 1})
        self.failUnlessEqual(d.items(), [("c", 1), ("b", 2), ("a", 3)])
        self.failUnlessEqual(d.values(), [1, 2, 3])
        self.failUnlessEqual(d.keys(), ["c", "b", "a"])
        self.failUnlessEqual(repr(d), "<ValueOrderedDict {c: 1, b: 2, a: 3}>")
        self.failIf(d == {"a": 4})
        self.failUnless(d != {"a": 4})

        x = d.setdefault("d", 0)
        self.failUnlessEqual(x, 0)
        self.failUnlessEqual(d["d"], 0)
        x = d.setdefault("d", -1)
        self.failUnlessEqual(x, 0)
        self.failUnlessEqual(d["d"], 0)

        # remove(key, default, strictkey): a missing key gives the default
        # unless the third argument is True, in which case it raises
        x = d.remove("e", "default", False)
        self.failUnlessEqual(x, "default")
        self.failUnlessRaises(KeyError, d.remove, "e", "default", True)
        x = d.remove("d", 5)
        self.failUnlessEqual(x, 0)

        # __getitem__ accepts the same extra (default, strict) arguments
        x = d.__getitem__("c")
        self.failUnlessEqual(x, 1)
        x = d.__getitem__("e", "default", False)
        self.failUnlessEqual(x, "default")
        self.failUnlessRaises(KeyError, d.__getitem__, "e", "default", True)

        # popitem returns the smallest-valued item first
        self.failUnlessEqual(d.popitem(), ("c", 1))
        self.failUnlessEqual(d.popitem(), ("b", 2))
        self.failUnlessEqual(d.popitem(), ("a", 3))
        self.failUnlessRaises(KeyError, d.popitem)

        d = dictutil.ValueOrderedDict({"a": 3, "b": 2, "c": 1})
        x = d.pop("d", "default", False)
        self.failUnlessEqual(x, "default")
        self.failUnlessRaises(KeyError, d.pop, "d", "default", True)
        x = d.pop("b")
        self.failUnlessEqual(x, 2)
        self.failUnlessEqual(d.items(), [("c", 1), ("a", 3)])

        d = dictutil.ValueOrderedDict({"a": 3, "b": 2, "c": 1})
        x = d.pop_from_list(1) # pop the second item, b/2
        self.failUnlessEqual(x, "b")
        self.failUnlessEqual(d.items(), [("c", 1), ("a", 3)])

    def test_auxdict(self):
        # AuxValueDict stores an optional auxiliary value next to each
        # regular value; plain assignment leaves the aux value as None.
        d = dictutil.AuxValueDict()
        # we put the serialized form in the auxdata
        d.set_with_aux("key", ("filecap", "metadata"), "serialized")

        self.failUnlessEqual(d.keys(), ["key"])
        self.failUnlessEqual(d["key"], ("filecap", "metadata"))
        self.failUnlessEqual(d.get_aux("key"), "serialized")
        def _get_missing(key):
            return d[key]
        self.failUnlessRaises(KeyError, _get_missing, "nonkey")
        self.failUnlessEqual(d.get("nonkey"), None)
        self.failUnlessEqual(d.get("nonkey", "nonvalue"), "nonvalue")
        self.failUnlessEqual(d.get_aux("nonkey"), None)
        self.failUnlessEqual(d.get_aux("nonkey", "nonvalue"), "nonvalue")

        # overwriting the value with plain assignment clears the aux data
        d["key"] = ("filecap2", "metadata2")
        self.failUnlessEqual(d["key"], ("filecap2", "metadata2"))
        self.failUnlessEqual(d.get_aux("key"), None)

        # deleting a key removes its aux data too
        d.set_with_aux("key2", "value2", "aux2")
        self.failUnlessEqual(sorted(d.keys()), ["key", "key2"])
        del d["key2"]
        self.failUnlessEqual(d.keys(), ["key"])
        self.failIf("key2" in d)
        self.failUnlessRaises(KeyError, _get_missing, "key2")
        self.failUnlessEqual(d.get("key2"), None)
        self.failUnlessEqual(d.get_aux("key2"), None)
        d["key2"] = "newvalue2"
        self.failUnlessEqual(d.get("key2"), "newvalue2")
        self.failUnlessEqual(d.get_aux("key2"), None)

        # the constructor accepts a dict, a list of pairs, or keywords
        d = dictutil.AuxValueDict({1:2,3:4})
        self.failUnlessEqual(sorted(d.keys()), [1,3])
        self.failUnlessEqual(d[1], 2)
        self.failUnlessEqual(d.get_aux(1), None)

        d = dictutil.AuxValueDict([ (1,2), (3,4) ])
        self.failUnlessEqual(sorted(d.keys()), [1,3])
        self.failUnlessEqual(d[1], 2)
        self.failUnlessEqual(d.get_aux(1), None)

        d = dictutil.AuxValueDict(one=1, two=2)
        self.failUnlessEqual(sorted(d.keys()), ["one","two"])
        self.failUnlessEqual(d["one"], 1)
        self.failUnlessEqual(d.get_aux("one"), None)
1381
1382 class Pipeline(unittest.TestCase):
1383     def pause(self, *args, **kwargs):
1384         d = defer.Deferred()
1385         self.calls.append( (d, args, kwargs) )
1386         return d
1387
1388     def failUnlessCallsAre(self, expected):
1389         #print self.calls
1390         #print expected
1391         self.failUnlessEqual(len(self.calls), len(expected), self.calls)
1392         for i,c in enumerate(self.calls):
1393             self.failUnlessEqual(c[1:], expected[i], str(i))
1394
1395     def test_basic(self):
1396         self.calls = []
1397         finished = []
1398         p = pipeline.Pipeline(100)
1399
1400         d = p.flush() # fires immediately
1401         d.addCallbacks(finished.append, log.err)
1402         self.failUnlessEqual(len(finished), 1)
1403         finished = []
1404
1405         d = p.add(10, self.pause, "one")
1406         # the call should start right away, and our return Deferred should
1407         # fire right away
1408         d.addCallbacks(finished.append, log.err)
1409         self.failUnlessEqual(len(finished), 1)
1410         self.failUnlessEqual(finished[0], None)
1411         self.failUnlessCallsAre([ ( ("one",) , {} ) ])
1412         self.failUnlessEqual(p.gauge, 10)
1413
1414         # pipeline: [one]
1415
1416         finished = []
1417         d = p.add(20, self.pause, "two", kw=2)
1418         # pipeline: [one, two]
1419
1420         # the call and the Deferred should fire right away
1421         d.addCallbacks(finished.append, log.err)
1422         self.failUnlessEqual(len(finished), 1)
1423         self.failUnlessEqual(finished[0], None)
1424         self.failUnlessCallsAre([ ( ("one",) , {} ),
1425                                   ( ("two",) , {"kw": 2} ),
1426                                   ])
1427         self.failUnlessEqual(p.gauge, 30)
1428
1429         self.calls[0][0].callback("one-result")
1430         # pipeline: [two]
1431         self.failUnlessEqual(p.gauge, 20)
1432
1433         finished = []
1434         d = p.add(90, self.pause, "three", "posarg1")
1435         # pipeline: [two, three]
1436         flushed = []
1437         fd = p.flush()
1438         fd.addCallbacks(flushed.append, log.err)
1439         self.failUnlessEqual(flushed, [])
1440
1441         # the call will be made right away, but the return Deferred will not,
1442         # because the pipeline is now full.
1443         d.addCallbacks(finished.append, log.err)
1444         self.failUnlessEqual(len(finished), 0)
1445         self.failUnlessCallsAre([ ( ("one",) , {} ),
1446                                   ( ("two",) , {"kw": 2} ),
1447                                   ( ("three", "posarg1"), {} ),
1448                                   ])
1449         self.failUnlessEqual(p.gauge, 110)
1450
1451         self.failUnlessRaises(pipeline.SingleFileError, p.add, 10, self.pause)
1452
1453         # retiring either call will unblock the pipeline, causing the #3
1454         # Deferred to fire
1455         self.calls[2][0].callback("three-result")
1456         # pipeline: [two]
1457
1458         self.failUnlessEqual(len(finished), 1)
1459         self.failUnlessEqual(finished[0], None)
1460         self.failUnlessEqual(flushed, [])
1461
1462         # retiring call#2 will finally allow the flush() Deferred to fire
1463         self.calls[1][0].callback("two-result")
1464         self.failUnlessEqual(len(flushed), 1)
1465
1466     def test_errors(self):
1467         self.calls = []
1468         p = pipeline.Pipeline(100)
1469
1470         d1 = p.add(200, self.pause, "one")
1471         d2 = p.flush()
1472
1473         finished = []
1474         d1.addBoth(finished.append)
1475         self.failUnlessEqual(finished, [])
1476
1477         flushed = []
1478         d2.addBoth(flushed.append)
1479         self.failUnlessEqual(flushed, [])
1480
1481         self.calls[0][0].errback(ValueError("oops"))
1482
1483         self.failUnlessEqual(len(finished), 1)
1484         f = finished[0]
1485         self.failUnless(isinstance(f, Failure))
1486         self.failUnless(f.check(pipeline.PipelineError))
1487         self.failUnlessIn("PipelineError", str(f.value))
1488         self.failUnlessIn("ValueError", str(f.value))
1489         r = repr(f.value)
1490         self.failUnless("ValueError" in r, r)
1491         f2 = f.value.error
1492         self.failUnless(f2.check(ValueError))
1493
1494         self.failUnlessEqual(len(flushed), 1)
1495         f = flushed[0]
1496         self.failUnless(isinstance(f, Failure))
1497         self.failUnless(f.check(pipeline.PipelineError))
1498         f2 = f.value.error
1499         self.failUnless(f2.check(ValueError))
1500
1501         # now that the pipeline is in the failed state, any new calls will
1502         # fail immediately
1503
1504         d3 = p.add(20, self.pause, "two")
1505
1506         finished = []
1507         d3.addBoth(finished.append)
1508         self.failUnlessEqual(len(finished), 1)
1509         f = finished[0]
1510         self.failUnless(isinstance(f, Failure))
1511         self.failUnless(f.check(pipeline.PipelineError))
1512         r = repr(f.value)
1513         self.failUnless("ValueError" in r, r)
1514         f2 = f.value.error
1515         self.failUnless(f2.check(ValueError))
1516
1517         d4 = p.flush()
1518         flushed = []
1519         d4.addBoth(flushed.append)
1520         self.failUnlessEqual(len(flushed), 1)
1521         f = flushed[0]
1522         self.failUnless(isinstance(f, Failure))
1523         self.failUnless(f.check(pipeline.PipelineError))
1524         f2 = f.value.error
1525         self.failUnless(f2.check(ValueError))
1526
1527     def test_errors2(self):
1528         self.calls = []
1529         p = pipeline.Pipeline(100)
1530
1531         d1 = p.add(10, self.pause, "one")
1532         d2 = p.add(20, self.pause, "two")
1533         d3 = p.add(30, self.pause, "three")
1534         d4 = p.flush()
1535
1536         # one call fails, then the second one succeeds: make sure
1537         # ExpandableDeferredList tolerates the second one
1538
1539         flushed = []
1540         d4.addBoth(flushed.append)
1541         self.failUnlessEqual(flushed, [])
1542
1543         self.calls[0][0].errback(ValueError("oops"))
1544         self.failUnlessEqual(len(flushed), 1)
1545         f = flushed[0]
1546         self.failUnless(isinstance(f, Failure))
1547         self.failUnless(f.check(pipeline.PipelineError))
1548         f2 = f.value.error
1549         self.failUnless(f2.check(ValueError))
1550
1551         self.calls[1][0].callback("two-result")
1552         self.calls[2][0].errback(ValueError("three-error"))
1553
1554         del d1,d2,d3,d4
1555
class SampleError(Exception):
    """Exception raised on purpose by Log.test_err."""
1558
class Log(unittest.TestCase):
    """Tests for the logging helper imported in this file as tahoe_log."""

    def test_err(self):
        """Log an intentional SampleError failure through tahoe_log.err().

        The logged SampleError is flushed at the end so that it is not
        left behind in the test's logged errors.
        """
        if not hasattr(self, "flushLoggedErrors"):
            # without flushLoggedErrors, we can't get rid of the
            # twisted.log.err that tahoe_log records, so we can't keep this
            # test from [ERROR]ing
            raise unittest.SkipTest("needs flushLoggedErrors from Twisted-2.5.0")
        try:
            raise SampleError("simple sample")
        except SampleError:
            # only the exception raised just above is expected here; a
            # Failure built with no arguments picks up the exception that
            # is currently being handled
            f = Failure()
        tahoe_log.err(format="intentional sample error",
                      failure=f, level=tahoe_log.OPERATIONAL, umid="wO9UoQ")
        self.flushLoggedErrors(SampleError)
1573
1574
class SimpleSpans:
    """Simple and inefficient reference model of util.spans.Spans.

    Every member integer is stored individually in a set. The tests
    compare the behavior of this model against the real (efficient)
    implementation.
    """

    def __init__(self, _span_or_start=None, length=None):
        """Create an empty set, a single span, or a copy.

        SimpleSpans()              -> empty
        SimpleSpans(start, length) -> the integers start..start+length-1
        SimpleSpans(spans)         -> all (start, length) pairs obtained by
                                      iterating over spans
        """
        self._have = set()
        if length is not None:
            self._have.update(range(_span_or_start, _span_or_start + length))
        elif _span_or_start:
            for (start, span_length) in _span_or_start:
                self.add(start, span_length)

    def add(self, start, length):
        """Add start..start+length-1; returns self so calls can be chained."""
        self._have.update(range(start, start + length))
        return self

    def remove(self, start, length):
        """Drop start..start+length-1 (absent members are ignored); returns self."""
        self._have.difference_update(range(start, start + length))
        return self

    def each(self):
        """Return the sorted list of member integers."""
        return sorted(self._have)

    def __iter__(self):
        """Yield (start, length) for each maximal run of consecutive members."""
        run_start = run_end = None
        for i in sorted(self._have):
            if run_start is None:
                run_start = run_end = i
            elif i == run_end + 1:
                run_end = i
            else:
                yield (run_start, run_end - run_start + 1)
                run_start = run_end = i
        if run_start is not None:
            yield (run_start, run_end - run_start + 1)

    def __bool__(self): # this gets us bool()
        return self.len() > 0
    # Python 2 looks up __nonzero__, Python 3 looks up __bool__
    __nonzero__ = __bool__

    def len(self):
        """Return the number of member integers."""
        return len(self._have)

    def __add__(self, other):
        union = self.__class__(self)
        for (start, length) in other:
            union.add(start, length)
        return union

    def __sub__(self, other):
        difference = self.__class__(self)
        for (start, length) in other:
            difference.remove(start, length)
        return difference

    def __iadd__(self, other):
        for (start, length) in other:
            self.add(start, length)
        return self

    def __isub__(self, other):
        for (start, length) in other:
            self.remove(start, length)
        return self

    def __and__(self, other):
        """Return the members common to self and other (other needs each())."""
        common = self.__class__()
        for i in other.each():
            if i in self._have:
                common.add(i, 1)
        return common

    def __contains__(self, span):
        """True if every integer of the (start, length) span is a member.

        The span is unpacked in the body rather than in the signature:
        tuple parameters are not valid syntax on Python 3.
        """
        start, length = span
        return all(i in self._have for i in range(start, start + length))
1657
1658 class ByteSpans(unittest.TestCase):
1659     def test_basic(self):
1660         s = Spans()
1661         self.failUnlessEqual(list(s), [])
1662         self.failIf(s)
1663         self.failIf((0,1) in s)
1664         self.failUnlessEqual(s.len(), 0)
1665
1666         s1 = Spans(3, 4) # 3,4,5,6
1667         self._check1(s1)
1668
1669         s2 = Spans(s1)
1670         self._check1(s2)
1671
1672         s2.add(10,2) # 10,11
1673         self._check1(s1)
1674         self.failUnless((10,1) in s2)
1675         self.failIf((10,1) in s1)
1676         self.failUnlessEqual(list(s2.each()), [3,4,5,6,10,11])
1677         self.failUnlessEqual(s2.len(), 6)
1678
1679         s2.add(15,2).add(20,2)
1680         self.failUnlessEqual(list(s2.each()), [3,4,5,6,10,11,15,16,20,21])
1681         self.failUnlessEqual(s2.len(), 10)
1682
1683         s2.remove(4,3).remove(15,1)
1684         self.failUnlessEqual(list(s2.each()), [3,10,11,16,20,21])
1685         self.failUnlessEqual(s2.len(), 6)
1686
1687         s1 = SimpleSpans(3, 4) # 3 4 5 6
1688         s2 = SimpleSpans(5, 4) # 5 6 7 8
1689         i = s1 & s2
1690         self.failUnlessEqual(list(i.each()), [5, 6])
1691
1692     def _check1(self, s):
1693         self.failUnlessEqual(list(s), [(3,4)])
1694         self.failUnless(s)
1695         self.failUnlessEqual(s.len(), 4)
1696         self.failIf((0,1) in s)
1697         self.failUnless((3,4) in s)
1698         self.failUnless((3,1) in s)
1699         self.failUnless((5,2) in s)
1700         self.failUnless((6,1) in s)
1701         self.failIf((6,2) in s)
1702         self.failIf((7,1) in s)
1703         self.failUnlessEqual(list(s.each()), [3,4,5,6])
1704
1705     def test_math(self):
1706         s1 = Spans(0, 10) # 0,1,2,3,4,5,6,7,8,9
1707         s2 = Spans(5, 3) # 5,6,7
1708         s3 = Spans(8, 4) # 8,9,10,11
1709
1710         s = s1 - s2
1711         self.failUnlessEqual(list(s.each()), [0,1,2,3,4,8,9])
1712         s = s1 - s3
1713         self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7])
1714         s = s2 - s3
1715         self.failUnlessEqual(list(s.each()), [5,6,7])
1716         s = s1 & s2
1717         self.failUnlessEqual(list(s.each()), [5,6,7])
1718         s = s2 & s1
1719         self.failUnlessEqual(list(s.each()), [5,6,7])
1720         s = s1 & s3
1721         self.failUnlessEqual(list(s.each()), [8,9])
1722         s = s3 & s1
1723         self.failUnlessEqual(list(s.each()), [8,9])
1724         s = s2 & s3
1725         self.failUnlessEqual(list(s.each()), [])
1726         s = s3 & s2
1727         self.failUnlessEqual(list(s.each()), [])
1728         s = Spans() & s3
1729         self.failUnlessEqual(list(s.each()), [])
1730         s = s3 & Spans()
1731         self.failUnlessEqual(list(s.each()), [])
1732
1733         s = s1 + s2
1734         self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7,8,9])
1735         s = s1 + s3
1736         self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7,8,9,10,11])
1737         s = s2 + s3
1738         self.failUnlessEqual(list(s.each()), [5,6,7,8,9,10,11])
1739
1740         s = Spans(s1)
1741         s -= s2
1742         self.failUnlessEqual(list(s.each()), [0,1,2,3,4,8,9])
1743         s = Spans(s1)
1744         s -= s3
1745         self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7])
1746         s = Spans(s2)
1747         s -= s3
1748         self.failUnlessEqual(list(s.each()), [5,6,7])
1749
1750         s = Spans(s1)
1751         s += s2
1752         self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7,8,9])
1753         s = Spans(s1)
1754         s += s3
1755         self.failUnlessEqual(list(s.each()), [0,1,2,3,4,5,6,7,8,9,10,11])
1756         s = Spans(s2)
1757         s += s3
1758         self.failUnlessEqual(list(s.each()), [5,6,7,8,9,10,11])
1759
1760     def test_random(self):
1761         # attempt to increase coverage of corner cases by comparing behavior
1762         # of a simple-but-slow model implementation against the
1763         # complex-but-fast actual implementation, in a large number of random
1764         # operations
1765         S1 = SimpleSpans
1766         S2 = Spans
1767         s1 = S1(); s2 = S2()
1768         seed = ""
1769         def _create(subseed):
1770             ns1 = S1(); ns2 = S2()
1771             for i in range(10):
1772                 what = md5(subseed+str(i)).hexdigest()
1773                 start = int(what[2:4], 16)
1774                 length = max(1,int(what[5:6], 16))
1775                 ns1.add(start, length); ns2.add(start, length)
1776             return ns1, ns2
1777
1778         #print
1779         for i in range(1000):
1780             what = md5(seed+str(i)).hexdigest()
1781             op = what[0]
1782             subop = what[1]
1783             start = int(what[2:4], 16)
1784             length = max(1,int(what[5:6], 16))
1785             #print what
1786             if op in "0":
1787                 if subop in "01234":
1788                     s1 = S1(); s2 = S2()
1789                 elif subop in "5678":
1790                     s1 = S1(start, length); s2 = S2(start, length)
1791                 else:
1792                     s1 = S1(s1); s2 = S2(s2)
1793                 #print "s2 = %s" % s2.dump()
1794             elif op in "123":
1795                 #print "s2.add(%d,%d)" % (start, length)
1796                 s1.add(start, length); s2.add(start, length)
1797             elif op in "456":
1798                 #print "s2.remove(%d,%d)" % (start, length)
1799                 s1.remove(start, length); s2.remove(start, length)
1800             elif op in "78":
1801                 ns1, ns2 = _create(what[7:11])
1802                 #print "s2 + %s" % ns2.dump()
1803                 s1 = s1 + ns1; s2 = s2 + ns2
1804             elif op in "9a":
1805                 ns1, ns2 = _create(what[7:11])
1806                 #print "%s - %s" % (s2.dump(), ns2.dump())
1807                 s1 = s1 - ns1; s2 = s2 - ns2
1808             elif op in "bc":
1809                 ns1, ns2 = _create(what[7:11])
1810                 #print "s2 += %s" % ns2.dump()
1811                 s1 += ns1; s2 += ns2
1812             elif op in "de":
1813                 ns1, ns2 = _create(what[7:11])
1814                 #print "%s -= %s" % (s2.dump(), ns2.dump())
1815                 s1 -= ns1; s2 -= ns2
1816             else:
1817                 ns1, ns2 = _create(what[7:11])
1818                 #print "%s &= %s" % (s2.dump(), ns2.dump())
1819                 s1 = s1 & ns1; s2 = s2 & ns2
1820             #print "s2 now %s" % s2.dump()
1821             self.failUnlessEqual(list(s1.each()), list(s2.each()))
1822             self.failUnlessEqual(s1.len(), s2.len())
1823             self.failUnlessEqual(bool(s1), bool(s2))
1824             self.failUnlessEqual(list(s1), list(s2))
1825             for j in range(10):
1826                 what = md5(what[12:14]+str(j)).hexdigest()
1827                 start = int(what[2:4], 16)
1828                 length = max(1, int(what[5:6], 16))
1829                 span = (start, length)
1830                 self.failUnlessEqual(bool(span in s1), bool(span in s2))
1831
1832
1833     # s()
1834     # s(start,length)
1835     # s(s0)
1836     # s.add(start,length) : returns s
1837     # s.remove(start,length)
1838     # s.each() -> list of byte offsets, mostly for testing
1839     # list(s) -> list of (start,length) tuples, one per span
1840     # (start,length) in s -> True if (start..start+length-1) are all members
1841     #  NOT equivalent to x in list(s)
1842     # s.len() -> number of bytes, for testing, bool(), and accounting/limiting
1843     # bool(s)  (__nonzero__)
1844     # s = s1+s2, s1-s2, +=s1, -=s1
1845
1846     def test_overlap(self):
1847         for a in range(20):
1848             for b in range(10):
1849                 for c in range(20):
1850                     for d in range(10):
1851                         self._test_overlap(a,b,c,d)
1852
1853     def _test_overlap(self, a, b, c, d):
1854         s1 = set(range(a,a+b))
1855         s2 = set(range(c,c+d))
1856         #print "---"
1857         #self._show_overlap(s1, "1")
1858         #self._show_overlap(s2, "2")
1859         o = overlap(a,b,c,d)
1860         expected = s1.intersection(s2)
1861         if not expected:
1862             self.failUnlessEqual(o, None)
1863         else:
1864             start,length = o
1865             so = set(range(start,start+length))
1866             #self._show(so, "o")
1867             self.failUnlessEqual(so, expected)
1868
1869     def _show_overlap(self, s, c):
1870         import sys
1871         out = sys.stdout
1872         if s:
1873             for i in range(max(s)):
1874                 if i in s:
1875                     out.write(c)
1876                 else:
1877                     out.write(" ")
1878         out.write("\n")
1879
def extend(s, start, length, fill):
    """Pad s on the right with fill until it is start+length long.

    s is returned unchanged when it is already at least that long;
    otherwise fill must be a single character.
    """
    shortfall = start + length - len(s)
    if shortfall <= 0:
        return s
    assert len(fill) == 1
    return s + shortfall * fill
1885
def replace(s, start, data):
    """Return a copy of s in which data overwrites the characters at start.

    s must already be long enough to hold data at that position.
    """
    end = start + len(data)
    assert len(s) >= end
    return "".join([s[:start], data, s[end:]])
1889
class SimpleDataSpans:
    """Slow, obvious model of DataSpans, used as a reference by the tests.

    Two parallel strings are kept: self.data holds the stored characters
    and self.missing holds one "0"/"1" flag per offset.
    """
    def __init__(self, other=None):
        self.missing = "" # "1" where missing, "0" where found
        self.data = ""
        if other:
            # copy by replaying every chunk of the other object
            for (offset, chunk) in other.get_chunks():
                self.add(offset, chunk)

    def __nonzero__(self): # this gets us bool()
        return self.len()

    def len(self):
        # number of offsets flagged "0", i.e. offsets that hold data
        return self.missing.count("0")

    def _dump(self):
        # offsets that hold data, in increasing order
        found = []
        for (offset, flag) in enumerate(self.missing):
            if flag == "0":
                found.append(offset)
        return found

    def _have(self, start, length):
        flags = self.missing[start:start+length]
        # the window must be non-empty, complete, and all "0"; int() is
        # only reached once the window is known to be non-empty
        return bool(flags) and len(flags) >= length and not int(flags)

    def get_chunks(self):
        # one single-character chunk per stored offset
        for offset in self._dump():
            yield (offset, self.data[offset])

    def get_spans(self):
        spans = []
        for (offset, chunk) in self.get_chunks():
            spans.append((offset, len(chunk)))
        return SimpleSpans(spans)

    def get(self, start, length):
        # None unless the whole range is present
        if not self._have(start, length):
            return None
        return self.data[start:start+length]

    def pop(self, start, length):
        # like get(), but a successful read also removes the range
        found = self.get(start, length)
        if found:
            self.remove(start, length)
        return found

    def remove(self, start, length):
        padded = extend(self.missing, start, length, "1")
        self.missing = replace(padded, start, "1"*length)

    def add(self, start, data):
        size = len(data)
        padded_flags = extend(self.missing, start, size, "1")
        self.missing = replace(padded_flags, start, "0"*size)
        padded_data = extend(self.data, start, size, " ")
        self.data = replace(padded_data, start, data)
1932
1933
1934 class StringSpans(unittest.TestCase):
1935     def do_basic(self, klass):
1936         ds = klass()
1937         self.failUnlessEqual(ds.len(), 0)
1938         self.failUnlessEqual(list(ds._dump()), [])
1939         self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_chunks()]), 0)
1940         s = ds.get_spans()
1941         self.failUnlessEqual(ds.get(0, 4), None)
1942         self.failUnlessEqual(ds.pop(0, 4), None)
1943         ds.remove(0, 4)
1944
1945         ds.add(2, "four")
1946         self.failUnlessEqual(ds.len(), 4)
1947         self.failUnlessEqual(list(ds._dump()), [2,3,4,5])
1948         self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_chunks()]), 4)
1949         s = ds.get_spans()
1950         self.failUnless((2,2) in s)
1951         self.failUnlessEqual(ds.get(0, 4), None)
1952         self.failUnlessEqual(ds.pop(0, 4), None)
1953         self.failUnlessEqual(ds.get(4, 4), None)
1954
1955         ds2 = klass(ds)
1956         self.failUnlessEqual(ds2.len(), 4)
1957         self.failUnlessEqual(list(ds2._dump()), [2,3,4,5])
1958         self.failUnlessEqual(sum([len(d) for (s,d) in ds2.get_chunks()]), 4)
1959         self.failUnlessEqual(ds2.get(0, 4), None)
1960         self.failUnlessEqual(ds2.pop(0, 4), None)
1961         self.failUnlessEqual(ds2.pop(2, 3), "fou")
1962         self.failUnlessEqual(sum([len(d) for (s,d) in ds2.get_chunks()]), 1)
1963         self.failUnlessEqual(ds2.get(2, 3), None)
1964         self.failUnlessEqual(ds2.get(5, 1), "r")
1965         self.failUnlessEqual(ds.get(2, 3), "fou")
1966         self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_chunks()]), 4)
1967
1968         ds.add(0, "23")
1969         self.failUnlessEqual(ds.len(), 6)
1970         self.failUnlessEqual(list(ds._dump()), [0,1,2,3,4,5])
1971         self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_chunks()]), 6)
1972         self.failUnlessEqual(ds.get(0, 4), "23fo")
1973         self.failUnlessEqual(ds.pop(0, 4), "23fo")
1974         self.failUnlessEqual(sum([len(d) for (s,d) in ds.get_chunks()]), 2)
1975         self.failUnlessEqual(ds.get(0, 4), None)
1976         self.failUnlessEqual(ds.pop(0, 4), None)
1977
1978         ds = klass()
1979         ds.add(2, "four")
1980         ds.add(3, "ea")
1981         self.failUnlessEqual(ds.get(2, 4), "fear")
1982
1983     def do_scan(self, klass):
1984         # do a test with gaps and spans of size 1 and 2
1985         #  left=(1,11) * right=(1,11) * gapsize=(1,2)
1986         # 111, 112, 121, 122, 211, 212, 221, 222
1987         #    211
1988         #      121
1989         #         112
1990         #            212
1991         #               222
1992         #                   221
1993         #                      111
1994         #                        122
1995         #  11 1  1 11 11  11  1 1  111
1996         # 0123456789012345678901234567
1997         # abcdefghijklmnopqrstuvwxyz-=
1998         pieces = [(1, "bc"),
1999                   (4, "e"),
2000                   (7, "h"),
2001                   (9, "jk"),
2002                   (12, "mn"),
2003                   (16, "qr"),
2004                   (20, "u"),
2005                   (22, "w"),
2006                   (25, "z-="),
2007                   ]
2008         p_elements = set([1,2,4,7,9,10,12,13,16,17,20,22,25,26,27])
2009         S = "abcdefghijklmnopqrstuvwxyz-="
2010         # TODO: when adding data, add capital letters, to make sure we aren't
2011         # just leaving the old data in place
2012         l = len(S)
2013         def base():
2014             ds = klass()
2015             for start, data in pieces:
2016                 ds.add(start, data)
2017             return ds
2018         def dump(s):
2019             p = set(s._dump())
2020             # wow, this is the first time I've ever wanted ?: in python
2021             # note: this requires python2.5
2022             d = "".join([(S[i] if i in p else " ") for i in range(l)])
2023             assert len(d) == l
2024             return d
2025         DEBUG = False
2026         for start in range(0, l):
2027             for end in range(start+1, l):
2028                 # add [start-end) to the baseline
2029                 which = "%d-%d" % (start, end-1)
2030                 p_added = set(range(start, end))
2031                 b = base()
2032                 if DEBUG:
2033                     print
2034                     print dump(b), which
2035                     add = klass(); add.add(start, S[start:end])
2036                     print dump(add)
2037                 b.add(start, S[start:end])
2038                 if DEBUG:
2039                     print dump(b)
2040                 # check that the new span is there
2041                 d = b.get(start, end-start)
2042                 self.failUnlessEqual(d, S[start:end], which)
2043                 # check that all the original pieces are still there
2044                 for t_start, t_data in pieces:
2045                     t_len = len(t_data)
2046                     self.failUnlessEqual(b.get(t_start, t_len),
2047                                          S[t_start:t_start+t_len],
2048                                          "%s %d+%d" % (which, t_start, t_len))
2049                 # check that a lot of subspans are mostly correct
2050                 for t_start in range(l):
2051                     for t_len in range(1,4):
2052                         d = b.get(t_start, t_len)
2053                         if d is not None:
2054                             which2 = "%s+(%d-%d)" % (which, t_start,
2055                                                      t_start+t_len-1)
2056                             self.failUnlessEqual(d, S[t_start:t_start+t_len],
2057                                                  which2)
2058                         # check that removing a subspan gives the right value
2059                         b2 = klass(b)
2060                         b2.remove(t_start, t_len)
2061                         removed = set(range(t_start, t_start+t_len))
2062                         for i in range(l):
2063                             exp = (((i in p_elements) or (i in p_added))
2064                                    and (i not in removed))
2065                             which2 = "%s-(%d-%d)" % (which, t_start,
2066                                                      t_start+t_len-1)
2067                             self.failUnlessEqual(bool(b2.get(i, 1)), exp,
2068                                                  which2+" %d" % i)
2069
2070     def test_test(self):
2071         self.do_basic(SimpleDataSpans)
2072         self.do_scan(SimpleDataSpans)
2073
2074     def test_basic(self):
2075         self.do_basic(DataSpans)
2076         self.do_scan(DataSpans)
2077
2078     def test_random(self):
2079         # attempt to increase coverage of corner cases by comparing behavior
2080         # of a simple-but-slow model implementation against the
2081         # complex-but-fast actual implementation, in a large number of random
2082         # operations
2083         S1 = SimpleDataSpans
2084         S2 = DataSpans
2085         s1 = S1(); s2 = S2()
2086         seed = ""
2087         def _randstr(length, seed):
2088             created = 0
2089             pieces = []
2090             while created < length:
2091                 piece = md5(seed + str(created)).hexdigest()
2092                 pieces.append(piece)
2093                 created += len(piece)
2094             return "".join(pieces)[:length]
2095         def _create(subseed):
2096             ns1 = S1(); ns2 = S2()
2097             for i in range(10):
2098                 what = md5(subseed+str(i)).hexdigest()
2099                 start = int(what[2:4], 16)
2100                 length = max(1,int(what[5:6], 16))
2101                 ns1.add(start, _randstr(length, what[7:9]));
2102                 ns2.add(start, _randstr(length, what[7:9]))
2103             return ns1, ns2
2104
2105         #print
2106         for i in range(1000):
2107             what = md5(seed+str(i)).hexdigest()
2108             op = what[0]
2109             subop = what[1]
2110             start = int(what[2:4], 16)
2111             length = max(1,int(what[5:6], 16))
2112             #print what
2113             if op in "0":
2114                 if subop in "0123456":
2115                     s1 = S1(); s2 = S2()
2116                 else:
2117                     s1, s2 = _create(what[7:11])
2118                 #print "s2 = %s" % list(s2._dump())
2119             elif op in "123456":
2120                 #print "s2.add(%d,%d)" % (start, length)
2121                 s1.add(start, _randstr(length, what[7:9]));
2122                 s2.add(start, _randstr(length, what[7:9]))
2123             elif op in "789abc":
2124                 #print "s2.remove(%d,%d)" % (start, length)
2125                 s1.remove(start, length); s2.remove(start, length)
2126             else:
2127                 #print "s2.pop(%d,%d)" % (start, length)
2128                 d1 = s1.pop(start, length); d2 = s2.pop(start, length)
2129                 self.failUnlessEqual(d1, d2)
2130             #print "s1 now %s" % list(s1._dump())
2131             #print "s2 now %s" % list(s2._dump())
2132             self.failUnlessEqual(s1.len(), s2.len())
2133             self.failUnlessEqual(list(s1._dump()), list(s2._dump()))
2134             for j in range(100):
2135                 what = md5(what[12:14]+str(j)).hexdigest()
2136                 start = int(what[2:4], 16)
2137                 length = max(1, int(what[5:6], 16))
2138                 d1 = s1.get(start, length); d2 = s2.get(start, length)
2139                 self.failUnlessEqual(d1, d2, "%d+%d" % (start, length))