From 8c3013c4f70ef59174c08488f50ab786d43ec68a Mon Sep 17 00:00:00 2001
From: Brian Warner <warner@allmydata.com>
Date: Fri, 27 Feb 2009 00:59:57 -0700
Subject: [PATCH] rrefutil: add check_remote utility function

---
 src/allmydata/test/test_util.py | 13 +++++++++++++
 src/allmydata/util/rrefutil.py  | 10 ++++++++++
 2 files changed, 23 insertions(+)

diff --git a/src/allmydata/test/test_util.py b/src/allmydata/test/test_util.py
index 6f99611f..ece54181 100644
--- a/src/allmydata/test/test_util.py
+++ b/src/allmydata/test/test_util.py
@@ -1204,13 +1204,21 @@ class FakeRemoteReference:
 
 class RemoteFailures(unittest.TestCase):
     def test_check(self):
+        check_local = rrefutil.check_local
+        check_remote = rrefutil.check_remote
         try:
             raise IndexError("local missing key")
         except IndexError:
             localf = Failure()
+
         self.failUnlessEqual(localf.check(IndexError, KeyError), IndexError)
         self.failUnlessEqual(localf.check(ValueError, KeyError), None)
         self.failUnlessEqual(localf.check(ServerFailure), None)
+        self.failUnlessEqual(check_local(localf, IndexError, KeyError),
+                             IndexError)
+        self.failUnlessEqual(check_local(localf, ValueError, KeyError), None)
+        self.failUnlessEqual(check_remote(localf, IndexError, KeyError), None)
+        self.failUnlessEqual(check_remote(localf, ValueError, KeyError), None)
 
         frr = FakeRemoteReference()
         wrr = rrefutil.WrappedRemoteReference(frr)
@@ -1219,6 +1227,11 @@ class RemoteFailures(unittest.TestCase):
             self.failUnlessEqual(f.check(IndexError, KeyError), None)
             self.failUnlessEqual(f.check(ServerFailure, KeyError),
                                  ServerFailure)
+            self.failUnlessEqual(check_remote(f, IndexError, KeyError),
+                                 IndexError)
+            self.failUnlessEqual(check_remote(f, ValueError, KeyError), None)
+            self.failUnlessEqual(check_local(f, IndexError, KeyError), None)
+            self.failUnlessEqual(check_local(f, ValueError, KeyError), None)
         d.addErrback(_check)
         return d
 
diff --git a/src/allmydata/util/rrefutil.py b/src/allmydata/util/rrefutil.py
index 97ad8fbe..1c037cc4 100644
--- a/src/allmydata/util/rrefutil.py
+++ b/src/allmydata/util/rrefutil.py
@@ -21,6 +21,16 @@ def is_remote(f):
 def is_local(f):
     return not is_remote(f)
 
+def check_remote(f, *errorTypes):
+    if is_remote(f):
+        return f.value.remote_failure.check(*errorTypes)
+    return None
+
+def check_local(f, *errorTypes):
+    if is_local(f):
+        return f.check(*errorTypes)
+    return None
+
 def trap_remote(f, *errorTypes):
     if is_remote(f):
         return f.value.remote_failure.trap(*errorTypes)
-- 
2.45.2