]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/commitdiff
rrefutil: add trap_remote utility and friends
authorBrian Warner <warner@allmydata.com>
Fri, 27 Feb 2009 07:55:24 +0000 (00:55 -0700)
committerBrian Warner <warner@allmydata.com>
Fri, 27 Feb 2009 07:55:24 +0000 (00:55 -0700)
src/allmydata/test/test_util.py
src/allmydata/util/rrefutil.py

index 246b2b989e6ee0a91f097490a6af8c9e9901266c..6f99611fe22e9ad941d1a76e657feb8d576cca50 100644 (file)
@@ -5,12 +5,13 @@ import os, time
 from StringIO import StringIO
 from twisted.trial import unittest
 from twisted.internet import defer, reactor
-from twisted.python import failure
+from twisted.python.failure import Failure
 
 from allmydata.util import base32, idlib, humanreadable, mathutil, hashutil
 from allmydata.util import assertutil, fileutil, deferredutil, abbreviate
 from allmydata.util import limiter, time_format, pollmixin, cachedir
-from allmydata.util import statistics, dictutil
+from allmydata.util import statistics, dictutil, rrefutil
+from allmydata.util.rrefutil import ServerFailure
 
 class Base32(unittest.TestCase):
     def test_b2a_matches_Pythons(self):
@@ -535,7 +536,7 @@ class DeferredUtilTests(unittest.TestCase):
         self.failUnlessEqual(good, [])
         self.failUnlessEqual(len(bad), 1)
         f = bad[0]
-        self.failUnless(isinstance(f, failure.Failure))
+        self.failUnless(isinstance(f, Failure))
         self.failUnless(f.check(ValueError))
 
 class HashUtilTests(unittest.TestCase):
@@ -1195,3 +1196,79 @@ class DictUtil(unittest.TestCase):
         self.failUnlessEqual(x, "b")
         self.failUnlessEqual(d.items(), [("c", 1), ("a", 3)])
 
+class FakeRemoteReference:
+    def callRemote(self, methname, *args, **kwargs):
+        return defer.maybeDeferred(self.oops)
+    def oops(self):
+        raise IndexError("remote missing key")
+
+class RemoteFailures(unittest.TestCase):
+    def test_check(self):
+        try:
+            raise IndexError("local missing key")
+        except IndexError:
+            localf = Failure()
+        self.failUnlessEqual(localf.check(IndexError, KeyError), IndexError)
+        self.failUnlessEqual(localf.check(ValueError, KeyError), None)
+        self.failUnlessEqual(localf.check(ServerFailure), None)
+
+        frr = FakeRemoteReference()
+        wrr = rrefutil.WrappedRemoteReference(frr)
+        d = wrr.callRemote("oops")
+        def _check(f):
+            self.failUnlessEqual(f.check(IndexError, KeyError), None)
+            self.failUnlessEqual(f.check(ServerFailure, KeyError),
+                                 ServerFailure)
+        d.addErrback(_check)
+        return d
+
+    def test_is_remote(self):
+        try:
+            raise IndexError("local missing key")
+        except IndexError:
+            localf = Failure()
+        self.failIf(rrefutil.is_remote(localf))
+        self.failUnless(rrefutil.is_local(localf))
+
+        frr = FakeRemoteReference()
+        wrr = rrefutil.WrappedRemoteReference(frr)
+        d = wrr.callRemote("oops")
+        def _check(f):
+            self.failUnless(rrefutil.is_remote(f))
+            self.failIf(rrefutil.is_local(f))
+        d.addErrback(_check)
+        return d
+
+    def test_trap(self):
+        try:
+            raise IndexError("local missing key")
+        except IndexError:
+            localf = Failure()
+
+        self.failUnlessRaises(Failure, localf.trap, ValueError, KeyError)
+        self.failUnlessRaises(Failure, localf.trap, ServerFailure)
+        self.failUnlessEqual(localf.trap(IndexError, KeyError), IndexError)
+        self.failUnlessEqual(rrefutil.trap_local(localf, IndexError, KeyError),
+                             IndexError)
+        self.failUnlessRaises(Failure,
+                              rrefutil.trap_remote, localf, ValueError, KeyError)
+
+        frr = FakeRemoteReference()
+        wrr = rrefutil.WrappedRemoteReference(frr)
+        d = wrr.callRemote("oops")
+        def _check(f):
+            self.failUnlessRaises(Failure,
+                                  f.trap, ValueError, KeyError)
+            self.failUnlessRaises(Failure,
+                                  f.trap, IndexError)
+            self.failUnlessEqual(f.trap(ServerFailure), ServerFailure)
+            self.failUnlessRaises(Failure,
+                                  rrefutil.trap_remote, f, ValueError, KeyError)
+            self.failUnlessEqual(rrefutil.trap_remote(f, IndexError, KeyError),
+                                 IndexError)
+            self.failUnlessRaises(Failure,
+                                  rrefutil.trap_local, f, ValueError, KeyError)
+            self.failUnlessRaises(Failure,
+                                  rrefutil.trap_local, f, IndexError)
+        d.addErrback(_check)
+        return d
index aafce8c24f721471846c84fa094137524f4082ab..97ad8fbef7f77fce8921bae9727fca777cfdb231 100644 (file)
@@ -3,9 +3,9 @@ import exceptions
 from foolscap.tokens import Violation
 
 class ServerFailure(exceptions.Exception):
-    # If the server returns a Failure instead of the normal response to a protocol, then this
-    # exception will be raised, with the Failure that the server returned as its .remote_failure
-    # attribute.
+    # If the server returns a Failure instead of the normal response to a
+    # protocol, then this exception will be raised, with the Failure that the
+    # server returned as its .remote_failure attribute.
     def __init__(self, remote_failure):
         self.remote_failure = remote_failure
     def __repr__(self):
@@ -13,11 +13,30 @@ class ServerFailure(exceptions.Exception):
     def __str__(self):
         return str(self.remote_failure)
 
+def is_remote(f):
+    if isinstance(f.value, ServerFailure):
+        return True
+    return False
+
+def is_local(f):
+    return not is_remote(f)
+
+def trap_remote(f, *errorTypes):
+    if is_remote(f):
+        return f.value.remote_failure.trap(*errorTypes)
+    raise f
+
+def trap_local(f, *errorTypes):
+    if is_local(f):
+        return f.trap(*errorTypes)
+    raise f
+
 def _wrap_server_failure(f):
     raise ServerFailure(f)
 
 class WrappedRemoteReference(object):
-    """I intercept any errback from the server and wrap it in a ServerFailure."""
+    """I intercept any errback from the server and wrap it in a
+    ServerFailure."""
 
     def __init__(self, original):
         self.rref = original
@@ -37,8 +56,8 @@ class WrappedRemoteReference(object):
         return self.rref.dontNotifyOnDisconnect(*args, **kwargs)
 
 class VersionedRemoteReference(WrappedRemoteReference):
-    """I wrap a RemoteReference, and add a .version attribute. I also intercept any errback from
-    the server and wrap it in a ServerFailure."""
+    """I wrap a RemoteReference, and add a .version attribute. I also
+    intercept any errback from the server and wrap it in a ServerFailure."""
 
     def __init__(self, original, version):
         WrappedRemoteReference.__init__(self, original)