from StringIO import StringIO
from twisted.trial import unittest
from twisted.internet import defer, reactor
-from twisted.python import failure
+from twisted.python.failure import Failure
from allmydata.util import base32, idlib, humanreadable, mathutil, hashutil
from allmydata.util import assertutil, fileutil, deferredutil, abbreviate
from allmydata.util import limiter, time_format, pollmixin, cachedir
-from allmydata.util import statistics, dictutil
+from allmydata.util import statistics, dictutil, rrefutil
+from allmydata.util.rrefutil import ServerFailure
class Base32(unittest.TestCase):
def test_b2a_matches_Pythons(self):
self.failUnlessEqual(good, [])
self.failUnlessEqual(len(bad), 1)
f = bad[0]
- self.failUnless(isinstance(f, failure.Failure))
+ self.failUnless(isinstance(f, Failure))
self.failUnless(f.check(ValueError))
class HashUtilTests(unittest.TestCase):
self.failUnlessEqual(x, "b")
self.failUnlessEqual(d.items(), [("c", 1), ("a", 3)])
+class FakeRemoteReference:
+ def callRemote(self, methname, *args, **kwargs):
+ return defer.maybeDeferred(self.oops)
+ def oops(self):
+ raise IndexError("remote missing key")
+
+class RemoteFailures(unittest.TestCase):
+ def test_check(self):
+ try:
+ raise IndexError("local missing key")
+ except IndexError:
+ localf = Failure()
+ self.failUnlessEqual(localf.check(IndexError, KeyError), IndexError)
+ self.failUnlessEqual(localf.check(ValueError, KeyError), None)
+ self.failUnlessEqual(localf.check(ServerFailure), None)
+
+ frr = FakeRemoteReference()
+ wrr = rrefutil.WrappedRemoteReference(frr)
+ d = wrr.callRemote("oops")
+ def _check(f):
+ self.failUnlessEqual(f.check(IndexError, KeyError), None)
+ self.failUnlessEqual(f.check(ServerFailure, KeyError),
+ ServerFailure)
+ d.addErrback(_check)
+ return d
+
+ def test_is_remote(self):
+ try:
+ raise IndexError("local missing key")
+ except IndexError:
+ localf = Failure()
+ self.failIf(rrefutil.is_remote(localf))
+ self.failUnless(rrefutil.is_local(localf))
+
+ frr = FakeRemoteReference()
+ wrr = rrefutil.WrappedRemoteReference(frr)
+ d = wrr.callRemote("oops")
+ def _check(f):
+ self.failUnless(rrefutil.is_remote(f))
+ self.failIf(rrefutil.is_local(f))
+ d.addErrback(_check)
+ return d
+
+ def test_trap(self):
+ try:
+ raise IndexError("local missing key")
+ except IndexError:
+ localf = Failure()
+
+ self.failUnlessRaises(Failure, localf.trap, ValueError, KeyError)
+ self.failUnlessRaises(Failure, localf.trap, ServerFailure)
+ self.failUnlessEqual(localf.trap(IndexError, KeyError), IndexError)
+ self.failUnlessEqual(rrefutil.trap_local(localf, IndexError, KeyError),
+ IndexError)
+ self.failUnlessRaises(Failure,
+ rrefutil.trap_remote, localf, ValueError, KeyError)
+
+ frr = FakeRemoteReference()
+ wrr = rrefutil.WrappedRemoteReference(frr)
+ d = wrr.callRemote("oops")
+ def _check(f):
+ self.failUnlessRaises(Failure,
+ f.trap, ValueError, KeyError)
+ self.failUnlessRaises(Failure,
+ f.trap, IndexError)
+ self.failUnlessEqual(f.trap(ServerFailure), ServerFailure)
+ self.failUnlessRaises(Failure,
+ rrefutil.trap_remote, f, ValueError, KeyError)
+ self.failUnlessEqual(rrefutil.trap_remote(f, IndexError, KeyError),
+ IndexError)
+ self.failUnlessRaises(Failure,
+ rrefutil.trap_local, f, ValueError, KeyError)
+ self.failUnlessRaises(Failure,
+ rrefutil.trap_local, f, IndexError)
+ d.addErrback(_check)
+ return d
from foolscap.tokens import Violation
class ServerFailure(exceptions.Exception):
- # If the server returns a Failure instead of the normal response to a protocol, then this
- # exception will be raised, with the Failure that the server returned as its .remote_failure
- # attribute.
+ # If the server returns a Failure instead of the normal response to a
+ # protocol, then this exception will be raised, with the Failure that the
+ # server returned as its .remote_failure attribute.
def __init__(self, remote_failure):
self.remote_failure = remote_failure
def __repr__(self):
def __str__(self):
return str(self.remote_failure)
+def is_remote(f):
+ if isinstance(f.value, ServerFailure):
+ return True
+ return False
+
+def is_local(f):
+ return not is_remote(f)
+
+def trap_remote(f, *errorTypes):
+ if is_remote(f):
+ return f.value.remote_failure.trap(*errorTypes)
+ raise f
+
+def trap_local(f, *errorTypes):
+ if is_local(f):
+ return f.trap(*errorTypes)
+ raise f
+
def _wrap_server_failure(f):
raise ServerFailure(f)
class WrappedRemoteReference(object):
- """I intercept any errback from the server and wrap it in a ServerFailure."""
+ """I intercept any errback from the server and wrap it in a
+ ServerFailure."""
def __init__(self, original):
self.rref = original
return self.rref.dontNotifyOnDisconnect(*args, **kwargs)
class VersionedRemoteReference(WrappedRemoteReference):
- """I wrap a RemoteReference, and add a .version attribute. I also intercept any errback from
- the server and wrap it in a ServerFailure."""
+ """I wrap a RemoteReference, and add a .version attribute. I also
+ intercept any errback from the server and wrap it in a ServerFailure."""
def __init__(self, original, version):
WrappedRemoteReference.__init__(self, original)