test_util: add more coverage for assertutil.py
authorBrian Warner <warner@lothar.com>
Sun, 8 Apr 2007 20:02:13 +0000 (13:02 -0700)
committerBrian Warner <warner@lothar.com>
Sun, 8 Apr 2007 20:02:13 +0000 (13:02 -0700)
src/allmydata/test/test_util.py

index 94536e9646c81deb8eb9ad9cc75ee7b88fefa970..1c14457cbafd589f7d4086605970cf65a26ad5b5 100644 (file)
@@ -4,6 +4,7 @@ def foo(): pass # keep the line number constant
 from twisted.trial import unittest
 
 from allmydata.util import bencode, idlib, humanreadable, mathutil
+from allmydata.util import assertutil
 
 
 class IDLib(unittest.TestCase):
@@ -206,3 +207,66 @@ class Math(unittest.TestCase):
         self.failUnlessEqual(f([0,0,0,4]), 1)
         self.failUnlessAlmostEqual(f([0.0, 1.0, 1.0]), .666666666666)
 
+
+class Asserts(unittest.TestCase):
+    def should_assert(self, func, *args, **kwargs):
+        try:
+            func(*args, **kwargs)
+        except AssertionError, e:
+            return str(e)
+        except Exception, e:
+            self.fail("assert failed with non-AssertionError: %s" % e)
+        self.fail("assert was not caught")
+
+    def should_not_assert(self, func, *args, **kwargs):
+        if "re" in kwargs:
+            regexp = kwargs["re"]
+            del kwargs["re"]
+        try:
+            func(*args, **kwargs)
+        except AssertionError, e:
+            self.fail("assertion fired when it should not have: %s" % e)
+        except Exception, e:
+            self.fail("assertion (which shouldn't have failed) failed with non-AssertionError: %s" % e)
+        return # we're happy
+
+
+    def test_assert(self):
+        f = assertutil._assert
+        self.should_assert(f)
+        self.should_assert(f, False)
+        self.should_not_assert(f, True)
+
+        m = self.should_assert(f, False, "message")
+        self.failUnlessEqual(m, "'message' <type 'str'>", m)
+        m = self.should_assert(f, False, "message1", othermsg=12)
+        self.failUnlessEqual("'message1' <type 'str'>, othermsg: 12 <type 'int'>", m)
+        m = self.should_assert(f, False, othermsg="message2")
+        self.failUnlessEqual("othermsg: 'message2' <type 'str'>", m)
+
+    def test_precondition(self):
+        f = assertutil.precondition
+        self.should_assert(f)
+        self.should_assert(f, False)
+        self.should_not_assert(f, True)
+
+        m = self.should_assert(f, False, "message")
+        self.failUnlessEqual("precondition: 'message' <type 'str'>", m)
+        m = self.should_assert(f, False, "message1", othermsg=12)
+        self.failUnlessEqual("precondition: 'message1' <type 'str'>, othermsg: 12 <type 'int'>", m)
+        m = self.should_assert(f, False, othermsg="message2")
+        self.failUnlessEqual("precondition: othermsg: 'message2' <type 'str'>", m)
+
+    def test_postcondition(self):
+        f = assertutil.postcondition
+        self.should_assert(f)
+        self.should_assert(f, False)
+        self.should_not_assert(f, True)
+
+        m = self.should_assert(f, False, "message")
+        self.failUnlessEqual("postcondition: 'message' <type 'str'>", m)
+        m = self.should_assert(f, False, "message1", othermsg=12)
+        self.failUnlessEqual("postcondition: 'message1' <type 'str'>, othermsg: 12 <type 'int'>", m)
+        m = self.should_assert(f, False, othermsg="message2")
+        self.failUnlessEqual("postcondition: othermsg: 'message2' <type 'str'>", m)
+