From: Brian Warner Date: Fri, 27 Feb 2009 07:59:57 +0000 (-0700) Subject: rrefutil: add check_remote utility function X-Git-Url: https://git.rkrishnan.org/tahoe_css?a=commitdiff_plain;h=8c3013c4f70ef59174c08488f50ab786d43ec68a;p=tahoe-lafs%2Ftahoe-lafs.git rrefutil: add check_remote utility function --- diff --git a/src/allmydata/test/test_util.py b/src/allmydata/test/test_util.py index 6f99611f..ece54181 100644 --- a/src/allmydata/test/test_util.py +++ b/src/allmydata/test/test_util.py @@ -1204,13 +1204,21 @@ class FakeRemoteReference: class RemoteFailures(unittest.TestCase): def test_check(self): + check_local = rrefutil.check_local + check_remote = rrefutil.check_remote try: raise IndexError("local missing key") except IndexError: localf = Failure() + self.failUnlessEqual(localf.check(IndexError, KeyError), IndexError) self.failUnlessEqual(localf.check(ValueError, KeyError), None) self.failUnlessEqual(localf.check(ServerFailure), None) + self.failUnlessEqual(check_local(localf, IndexError, KeyError), + IndexError) + self.failUnlessEqual(check_local(localf, ValueError, KeyError), None) + self.failUnlessEqual(check_remote(localf, IndexError, KeyError), None) + self.failUnlessEqual(check_remote(localf, ValueError, KeyError), None) frr = FakeRemoteReference() wrr = rrefutil.WrappedRemoteReference(frr) @@ -1219,6 +1227,11 @@ class RemoteFailures(unittest.TestCase): self.failUnlessEqual(f.check(IndexError, KeyError), None) self.failUnlessEqual(f.check(ServerFailure, KeyError), ServerFailure) + self.failUnlessEqual(check_remote(f, IndexError, KeyError), + IndexError) + self.failUnlessEqual(check_remote(f, ValueError, KeyError), None) + self.failUnlessEqual(check_local(f, IndexError, KeyError), None) + self.failUnlessEqual(check_local(f, ValueError, KeyError), None) d.addErrback(_check) return d diff --git a/src/allmydata/util/rrefutil.py b/src/allmydata/util/rrefutil.py index 97ad8fbe..1c037cc4 100644 --- a/src/allmydata/util/rrefutil.py +++ b/src/allmydata/util/rrefutil.py @@ -21,6 +21,16 @@ def is_remote(f): def is_local(f): return not is_remote(f) +def check_remote(f, *errorTypes): + if is_remote(f): + return f.value.remote_failure.check(*errorTypes) + return None + +def check_local(f, *errorTypes): + if is_local(f): + return f.check(*errorTypes) + return None + def trap_remote(f, *errorTypes): if is_remote(f): return f.value.remote_failure.trap(*errorTypes)