From a68c9364b8f2392ce4dd36711a38bbae439e227f Mon Sep 17 00:00:00 2001
From: Brian Warner <warner@lothar.com>
Date: Wed, 6 Aug 2008 12:05:52 -0700
Subject: [PATCH] test/common: add ShouldFailMixin

---
 src/allmydata/test/common.py | 40 ++++++++++++++++++++++++++++++++++++
 1 file changed, 40 insertions(+)

diff --git a/src/allmydata/test/common.py b/src/allmydata/test/common.py
index 8866390c..be2234cd 100644
--- a/src/allmydata/test/common.py
+++ b/src/allmydata/test/common.py
@@ -790,3 +790,43 @@ class ShareManglingMixin(SystemTestMixin):
                         newdata = newshares.get((i, sharenum))
                         if newdata is not None:
                             open(pathtosharefile, "w").write(newdata)
+
+class ShouldFailMixin:
+    def shouldFail(self, expected_failure, which, substring,
+                   callable, *args, **kwargs):
+        """Assert that a function call raises some exception. This is a
+        Deferred-friendly version of TestCase.assertRaises() .
+
+        Suppose you want to verify the following function:
+
+         def broken(a, b, c):
+             if a < 0:
+                 raise TypeError('a must not be negative')
+             return defer.succeed(b+c)
+
+        You can use:
+            d = self.shouldFail(TypeError, 'test name',
+                                'a must not be negative',
+                                broken, -4, 5, c=12)
+        in your test method. The 'test name' string will be included in the
+        error message, if any, because Deferred chains frequently make it
+        difficult to tell which assertion was tripped.
+
+        The substring= argument, if not None, must appear inside the
+        stringified Failure, or the test will fail.
+        """
+
+        assert substring is None or isinstance(substring, str)
+        d = defer.maybeDeferred(callable, *args, **kwargs)
+        def done(res):
+            if isinstance(res, failure.Failure):
+                res.trap(expected_failure)
+                if substring:
+                    self.failUnless(substring in str(res),
+                                    "substring '%s' not in '%s'"
+                                    % (substring, str(res)))
+            else:
+                self.fail("%s was supposed to raise %s, not get '%s'" %
+                          (which, expected_failure, res))
+        d.addBoth(done)
+        return d
-- 
2.45.2