From: Brian Warner Date: Fri, 27 Feb 2009 07:55:24 +0000 (-0700) Subject: rrefutil: add trap_remote utility and friends X-Git-Url: https://git.rkrishnan.org/COPYING.GPL?a=commitdiff_plain;h=1b3e635936a34eee2e106d44ca33c1394d4f9068;p=tahoe-lafs%2Ftahoe-lafs.git rrefutil: add trap_remote utility and friends --- diff --git a/src/allmydata/test/test_util.py b/src/allmydata/test/test_util.py index 246b2b98..6f99611f 100644 --- a/src/allmydata/test/test_util.py +++ b/src/allmydata/test/test_util.py @@ -5,12 +5,13 @@ import os, time from StringIO import StringIO from twisted.trial import unittest from twisted.internet import defer, reactor -from twisted.python import failure +from twisted.python.failure import Failure from allmydata.util import base32, idlib, humanreadable, mathutil, hashutil from allmydata.util import assertutil, fileutil, deferredutil, abbreviate from allmydata.util import limiter, time_format, pollmixin, cachedir -from allmydata.util import statistics, dictutil +from allmydata.util import statistics, dictutil, rrefutil +from allmydata.util.rrefutil import ServerFailure class Base32(unittest.TestCase): def test_b2a_matches_Pythons(self): @@ -535,7 +536,7 @@ class DeferredUtilTests(unittest.TestCase): self.failUnlessEqual(good, []) self.failUnlessEqual(len(bad), 1) f = bad[0] - self.failUnless(isinstance(f, failure.Failure)) + self.failUnless(isinstance(f, Failure)) self.failUnless(f.check(ValueError)) class HashUtilTests(unittest.TestCase): @@ -1195,3 +1196,79 @@ class DictUtil(unittest.TestCase): self.failUnlessEqual(x, "b") self.failUnlessEqual(d.items(), [("c", 1), ("a", 3)]) +class FakeRemoteReference: + def callRemote(self, methname, *args, **kwargs): + return defer.maybeDeferred(self.oops) + def oops(self): + raise IndexError("remote missing key") + +class RemoteFailures(unittest.TestCase): + def test_check(self): + try: + raise IndexError("local missing key") + except IndexError: + localf = Failure() + self.failUnlessEqual(localf.check(IndexError, KeyError), IndexError) + self.failUnlessEqual(localf.check(ValueError, KeyError), None) + self.failUnlessEqual(localf.check(ServerFailure), None) + + frr = FakeRemoteReference() + wrr = rrefutil.WrappedRemoteReference(frr) + d = wrr.callRemote("oops") + def _check(f): + self.failUnlessEqual(f.check(IndexError, KeyError), None) + self.failUnlessEqual(f.check(ServerFailure, KeyError), + ServerFailure) + d.addErrback(_check) + return d + + def test_is_remote(self): + try: + raise IndexError("local missing key") + except IndexError: + localf = Failure() + self.failIf(rrefutil.is_remote(localf)) + self.failUnless(rrefutil.is_local(localf)) + + frr = FakeRemoteReference() + wrr = rrefutil.WrappedRemoteReference(frr) + d = wrr.callRemote("oops") + def _check(f): + self.failUnless(rrefutil.is_remote(f)) + self.failIf(rrefutil.is_local(f)) + d.addErrback(_check) + return d + + def test_trap(self): + try: + raise IndexError("local missing key") + except IndexError: + localf = Failure() + + self.failUnlessRaises(Failure, localf.trap, ValueError, KeyError) + self.failUnlessRaises(Failure, localf.trap, ServerFailure) + self.failUnlessEqual(localf.trap(IndexError, KeyError), IndexError) + self.failUnlessEqual(rrefutil.trap_local(localf, IndexError, KeyError), + IndexError) + self.failUnlessRaises(Failure, + rrefutil.trap_remote, localf, ValueError, KeyError) + + frr = FakeRemoteReference() + wrr = rrefutil.WrappedRemoteReference(frr) + d = wrr.callRemote("oops") + def _check(f): + self.failUnlessRaises(Failure, + f.trap, ValueError, KeyError) + self.failUnlessRaises(Failure, + f.trap, IndexError) + self.failUnlessEqual(f.trap(ServerFailure), ServerFailure) + self.failUnlessRaises(Failure, + rrefutil.trap_remote, f, ValueError, KeyError) + self.failUnlessEqual(rrefutil.trap_remote(f, IndexError, KeyError), + IndexError) + self.failUnlessRaises(Failure, + rrefutil.trap_local, f, ValueError, KeyError) + self.failUnlessRaises(Failure, + rrefutil.trap_local, f, IndexError) + d.addErrback(_check) + return d diff --git a/src/allmydata/util/rrefutil.py b/src/allmydata/util/rrefutil.py index aafce8c2..97ad8fbe 100644 --- a/src/allmydata/util/rrefutil.py +++ b/src/allmydata/util/rrefutil.py @@ -3,9 +3,9 @@ import exceptions from foolscap.tokens import Violation class ServerFailure(exceptions.Exception): - # If the server returns a Failure instead of the normal response to a protocol, then this - # exception will be raised, with the Failure that the server returned as its .remote_failure - # attribute. + # If the server returns a Failure instead of the normal response to a + # protocol, then this exception will be raised, with the Failure that the + # server returned as its .remote_failure attribute. def __init__(self, remote_failure): self.remote_failure = remote_failure def __repr__(self): @@ -13,11 +13,30 @@ class ServerFailure(exceptions.Exception): def __str__(self): return str(self.remote_failure) +def is_remote(f): + if isinstance(f.value, ServerFailure): + return True + return False + +def is_local(f): + return not is_remote(f) + +def trap_remote(f, *errorTypes): + if is_remote(f): + return f.value.remote_failure.trap(*errorTypes) + raise f + +def trap_local(f, *errorTypes): + if is_local(f): + return f.trap(*errorTypes) + raise f + def _wrap_server_failure(f): raise ServerFailure(f) class WrappedRemoteReference(object): - """I intercept any errback from the server and wrap it in a ServerFailure.""" + """I intercept any errback from the server and wrap it in a + ServerFailure.""" def __init__(self, original): self.rref = original @@ -37,8 +56,8 @@ class WrappedRemoteReference(object): return self.rref.dontNotifyOnDisconnect(*args, **kwargs) class VersionedRemoteReference(WrappedRemoteReference): - """I wrap a RemoteReference, and add a .version attribute. I also intercept any errback from - the server and wrap it in a ServerFailure.""" + """I wrap a RemoteReference, and add a .version attribute. I also + intercept any errback from the server and wrap it in a ServerFailure.""" def __init__(self, original, version): WrappedRemoteReference.__init__(self, original)