]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/commitdiff
rrefutil: add check_remote utility function
authorBrian Warner <warner@allmydata.com>
Fri, 27 Feb 2009 07:59:57 +0000 (00:59 -0700)
committerBrian Warner <warner@allmydata.com>
Fri, 27 Feb 2009 07:59:57 +0000 (00:59 -0700)
src/allmydata/test/test_util.py
src/allmydata/util/rrefutil.py

index 6f99611fe22e9ad941d1a76e657feb8d576cca50..ece54181dddbe14d2186bec23a976edb552ec29f 100644 (file)
@@ -1204,13 +1204,21 @@ class FakeRemoteReference:
 
 class RemoteFailures(unittest.TestCase):
     def test_check(self):
+        check_local = rrefutil.check_local
+        check_remote = rrefutil.check_remote
         try:
             raise IndexError("local missing key")
         except IndexError:
             localf = Failure()
+
         self.failUnlessEqual(localf.check(IndexError, KeyError), IndexError)
         self.failUnlessEqual(localf.check(ValueError, KeyError), None)
         self.failUnlessEqual(localf.check(ServerFailure), None)
+        self.failUnlessEqual(check_local(localf, IndexError, KeyError),
+                             IndexError)
+        self.failUnlessEqual(check_local(localf, ValueError, KeyError), None)
+        self.failUnlessEqual(check_remote(localf, IndexError, KeyError), None)
+        self.failUnlessEqual(check_remote(localf, ValueError, KeyError), None)
 
         frr = FakeRemoteReference()
         wrr = rrefutil.WrappedRemoteReference(frr)
@@ -1219,6 +1227,11 @@ class RemoteFailures(unittest.TestCase):
             self.failUnlessEqual(f.check(IndexError, KeyError), None)
             self.failUnlessEqual(f.check(ServerFailure, KeyError),
                                  ServerFailure)
+            self.failUnlessEqual(check_remote(f, IndexError, KeyError),
+                                 IndexError)
+            self.failUnlessEqual(check_remote(f, ValueError, KeyError), None)
+            self.failUnlessEqual(check_local(f, IndexError, KeyError), None)
+            self.failUnlessEqual(check_local(f, ValueError, KeyError), None)
         d.addErrback(_check)
         return d
 
index 97ad8fbef7f77fce8921bae9727fca777cfdb231..1c037cc45d2a1624b11c63ff5e2d2ac75f7b00da 100644 (file)
@@ -21,6 +21,16 @@ def is_remote(f):
 def is_local(f):
     return not is_remote(f)
 
+def check_remote(f, *errorTypes):
+    if is_remote(f):
+        return f.value.remote_failure.check(*errorTypes)
+    return None
+
+def check_local(f, *errorTypes):
+    if is_local(f):
+        return f.check(*errorTypes)
+    return None
+
 def trap_remote(f, *errorTypes):
     if is_remote(f):
         return f.value.remote_failure.trap(*errorTypes)