From 1b3e635936a34eee2e106d44ca33c1394d4f9068 Mon Sep 17 00:00:00 2001
From: Brian Warner <warner@allmydata.com>
Date: Fri, 27 Feb 2009 00:55:24 -0700
Subject: [PATCH] rrefutil: add trap_remote utility and friends

---
 src/allmydata/test/test_util.py | 83 +++++++++++++++++++++++++++++++--
 src/allmydata/util/rrefutil.py  | 31 +++++++++---
 2 files changed, 105 insertions(+), 9 deletions(-)

diff --git a/src/allmydata/test/test_util.py b/src/allmydata/test/test_util.py
index 246b2b98..6f99611f 100644
--- a/src/allmydata/test/test_util.py
+++ b/src/allmydata/test/test_util.py
@@ -5,12 +5,13 @@ import os, time
 from StringIO import StringIO
 from twisted.trial import unittest
 from twisted.internet import defer, reactor
-from twisted.python import failure
+from twisted.python.failure import Failure
 
 from allmydata.util import base32, idlib, humanreadable, mathutil, hashutil
 from allmydata.util import assertutil, fileutil, deferredutil, abbreviate
 from allmydata.util import limiter, time_format, pollmixin, cachedir
-from allmydata.util import statistics, dictutil
+from allmydata.util import statistics, dictutil, rrefutil
+from allmydata.util.rrefutil import ServerFailure
 
 class Base32(unittest.TestCase):
     def test_b2a_matches_Pythons(self):
@@ -535,7 +536,7 @@ class DeferredUtilTests(unittest.TestCase):
         self.failUnlessEqual(good, [])
         self.failUnlessEqual(len(bad), 1)
         f = bad[0]
-        self.failUnless(isinstance(f, failure.Failure))
+        self.failUnless(isinstance(f, Failure))
         self.failUnless(f.check(ValueError))
 
 class HashUtilTests(unittest.TestCase):
@@ -1195,3 +1196,79 @@ class DictUtil(unittest.TestCase):
         self.failUnlessEqual(x, "b")
         self.failUnlessEqual(d.items(), [("c", 1), ("a", 3)])
 
+class FakeRemoteReference:
+    def callRemote(self, methname, *args, **kwargs):
+        return defer.maybeDeferred(self.oops)
+    def oops(self):
+        raise IndexError("remote missing key")
+
+class RemoteFailures(unittest.TestCase):
+    def test_check(self):
+        try:
+            raise IndexError("local missing key")
+        except IndexError:
+            localf = Failure()
+        self.failUnlessEqual(localf.check(IndexError, KeyError), IndexError)
+        self.failUnlessEqual(localf.check(ValueError, KeyError), None)
+        self.failUnlessEqual(localf.check(ServerFailure), None)
+
+        frr = FakeRemoteReference()
+        wrr = rrefutil.WrappedRemoteReference(frr)
+        d = wrr.callRemote("oops")
+        def _check(f):
+            self.failUnlessEqual(f.check(IndexError, KeyError), None)
+            self.failUnlessEqual(f.check(ServerFailure, KeyError),
+                                 ServerFailure)
+        d.addErrback(_check)
+        return d
+
+    def test_is_remote(self):
+        try:
+            raise IndexError("local missing key")
+        except IndexError:
+            localf = Failure()
+        self.failIf(rrefutil.is_remote(localf))
+        self.failUnless(rrefutil.is_local(localf))
+
+        frr = FakeRemoteReference()
+        wrr = rrefutil.WrappedRemoteReference(frr)
+        d = wrr.callRemote("oops")
+        def _check(f):
+            self.failUnless(rrefutil.is_remote(f))
+            self.failIf(rrefutil.is_local(f))
+        d.addErrback(_check)
+        return d
+
+    def test_trap(self):
+        try:
+            raise IndexError("local missing key")
+        except IndexError:
+            localf = Failure()
+
+        self.failUnlessRaises(Failure, localf.trap, ValueError, KeyError)
+        self.failUnlessRaises(Failure, localf.trap, ServerFailure)
+        self.failUnlessEqual(localf.trap(IndexError, KeyError), IndexError)
+        self.failUnlessEqual(rrefutil.trap_local(localf, IndexError, KeyError),
+                             IndexError)
+        self.failUnlessRaises(Failure,
+                              rrefutil.trap_remote, localf, ValueError, KeyError)
+
+        frr = FakeRemoteReference()
+        wrr = rrefutil.WrappedRemoteReference(frr)
+        d = wrr.callRemote("oops")
+        def _check(f):
+            self.failUnlessRaises(Failure,
+                                  f.trap, ValueError, KeyError)
+            self.failUnlessRaises(Failure,
+                                  f.trap, IndexError)
+            self.failUnlessEqual(f.trap(ServerFailure), ServerFailure)
+            self.failUnlessRaises(Failure,
+                                  rrefutil.trap_remote, f, ValueError, KeyError)
+            self.failUnlessEqual(rrefutil.trap_remote(f, IndexError, KeyError),
+                                 IndexError)
+            self.failUnlessRaises(Failure,
+                                  rrefutil.trap_local, f, ValueError, KeyError)
+            self.failUnlessRaises(Failure,
+                                  rrefutil.trap_local, f, IndexError)
+        d.addErrback(_check)
+        return d
diff --git a/src/allmydata/util/rrefutil.py b/src/allmydata/util/rrefutil.py
index aafce8c2..97ad8fbe 100644
--- a/src/allmydata/util/rrefutil.py
+++ b/src/allmydata/util/rrefutil.py
@@ -3,9 +3,9 @@ import exceptions
 from foolscap.tokens import Violation
 
 class ServerFailure(exceptions.Exception):
-    # If the server returns a Failure instead of the normal response to a protocol, then this
-    # exception will be raised, with the Failure that the server returned as its .remote_failure
-    # attribute.
+    # If the server returns a Failure instead of the normal response to a
+    # protocol, then this exception will be raised, with the Failure that the
+    # server returned as its .remote_failure attribute.
     def __init__(self, remote_failure):
         self.remote_failure = remote_failure
     def __repr__(self):
@@ -13,11 +13,30 @@ class ServerFailure(exceptions.Exception):
     def __str__(self):
         return str(self.remote_failure)
 
+def is_remote(f):
+    if isinstance(f.value, ServerFailure):
+        return True
+    return False
+
+def is_local(f):
+    return not is_remote(f)
+
+def trap_remote(f, *errorTypes):
+    if is_remote(f):
+        return f.value.remote_failure.trap(*errorTypes)
+    raise f
+
+def trap_local(f, *errorTypes):
+    if is_local(f):
+        return f.trap(*errorTypes)
+    raise f
+
 def _wrap_server_failure(f):
     raise ServerFailure(f)
 
 class WrappedRemoteReference(object):
-    """I intercept any errback from the server and wrap it in a ServerFailure."""
+    """I intercept any errback from the server and wrap it in a
+    ServerFailure."""
 
     def __init__(self, original):
         self.rref = original
@@ -37,8 +56,8 @@ class WrappedRemoteReference(object):
         return self.rref.dontNotifyOnDisconnect(*args, **kwargs)
 
 class VersionedRemoteReference(WrappedRemoteReference):
-    """I wrap a RemoteReference, and add a .version attribute. I also intercept any errback from
-    the server and wrap it in a ServerFailure."""
+    """I wrap a RemoteReference, and add a .version attribute. I also
+    intercept any errback from the server and wrap it in a ServerFailure."""
 
     def __init__(self, original, version):
         WrappedRemoteReference.__init__(self, original)
-- 
2.45.2