]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/commitdiff
test/common: add ShouldFailMixin
authorBrian Warner <warner@lothar.com>
Wed, 6 Aug 2008 19:05:52 +0000 (12:05 -0700)
committerBrian Warner <warner@lothar.com>
Wed, 6 Aug 2008 19:05:52 +0000 (12:05 -0700)
src/allmydata/test/common.py

index 8866390c9d80f394797e445ab20140131582863b..be2234cdd5070b6dcbe0b0f3ae6d4cbfb7f038fa 100644 (file)
@@ -790,3 +790,43 @@ class ShareManglingMixin(SystemTestMixin):
                         newdata = newshares.get((i, sharenum))
                         if newdata is not None:
                             open(pathtosharefile, "w").write(newdata)
+
+class ShouldFailMixin:
+    def shouldFail(self, expected_failure, which, substring,
+                   callable, *args, **kwargs):
+        """Assert that a function call raises some exception. This is a
+        Deferred-friendly version of TestCase.assertRaises() .
+
+        Suppose you want to verify the following function:
+
+         def broken(a, b, c):
+             if a < 0:
+                 raise TypeError('a must not be negative')
+             return defer.succeed(b+c)
+
+        You can use:
+            d = self.shouldFail(TypeError, 'test name',
+                                'a must not be negative',
+                                broken, -4, 5, c=12)
+        in your test method. The 'test name' string will be included in the
+        error message, if any, because Deferred chains frequently make it
+        difficult to tell which assertion was tripped.
+
+        The substring= argument, if not None, must appear inside the
+        stringified Failure, or the test will fail.
+        """
+
+        assert substring is None or isinstance(substring, str)
+        d = defer.maybeDeferred(callable, *args, **kwargs)
+        def done(res):
+            if isinstance(res, failure.Failure):
+                res.trap(expected_failure)
+                if substring:
+                    self.failUnless(substring in str(res),
+                                    "substring '%s' not in '%s'"
+                                    % (substring, str(res)))
+            else:
+                self.fail("%s was supposed to raise %s, not get '%s'" %
+                          (which, expected_failure, res))
+        d.addBoth(done)
+        return d