]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/commitdiff
switch to using RemoteException instead of 'wrapped' RemoteReferences. Should fix...
authorBrian Warner <warner@lothar.com>
Fri, 22 May 2009 00:46:32 +0000 (17:46 -0700)
committerBrian Warner <warner@lothar.com>
Fri, 22 May 2009 00:46:32 +0000 (17:46 -0700)
16 files changed:
src/allmydata/immutable/checker.py
src/allmydata/immutable/download.py
src/allmydata/immutable/upload.py
src/allmydata/introducer/client.py
src/allmydata/key_generator.py
src/allmydata/mutable/servermap.py
src/allmydata/node.py
src/allmydata/stats.py
src/allmydata/test/check_memory.py
src/allmydata/test/check_speed.py
src/allmydata/test/no_network.py
src/allmydata/test/test_helper.py
src/allmydata/test/test_introducer.py
src/allmydata/test/test_keygen.py
src/allmydata/test/test_util.py
src/allmydata/util/rrefutil.py

index 88d7b17103f0a619229296b9cdc134bc7ed7859d..c565bc30c00dc15ff6e89e838a932101c2d520e0 100644 (file)
@@ -1,11 +1,11 @@
-from foolscap.api import DeadReferenceError
+from foolscap.api import DeadReferenceError, RemoteException
 from twisted.internet import defer
 from allmydata import hashtree
 from allmydata.check_results import CheckResults
 from allmydata.immutable import download
 from allmydata.uri import CHKFileVerifierURI
 from allmydata.util.assertutil import precondition
-from allmydata.util import base32, deferredutil, dictutil, log, rrefutil
+from allmydata.util import base32, deferredutil, dictutil, log
 from allmydata.util.hashutil import file_renewal_secret_hash, \
      file_cancel_secret_hash, bucket_renewal_secret_hash, \
      bucket_cancel_secret_hash
@@ -38,7 +38,6 @@ class Checker(log.PrefixingLogMixin):
         assert precondition(isinstance(servers, (set, frozenset)), servers)
         for (serverid, serverrref) in servers:
             assert precondition(isinstance(serverid, str))
-            assert precondition(isinstance(serverrref, rrefutil.WrappedRemoteReference), serverrref)
 
         prefix = "%s" % base32.b2a_l(verifycap.storage_index[:8], 60)
         log.PrefixingLogMixin.__init__(self, facility="tahoe.immutable.checker", prefix=prefix)
@@ -84,19 +83,24 @@ class Checker(log.PrefixingLogMixin):
             def _done(res):
                 [(get_success, get_result),
                  (addlease_success, addlease_result)] = res
-                if (not addlease_success and
-                    not rrefutil.check_remote(addlease_result, IndexError)):
-                    # tahoe=1.3.0 raised IndexError on non-existant buckets,
-                    # which we ignore. But report others, including the
-                    # unfortunate internal KeyError bug that <1.3.0 had.
-                    return addlease_result # propagate error
-                return get_result
+                # ignore remote IndexError on the add_lease call. Propagate
+                # local errors and remote non-IndexErrors
+                if addlease_success:
+                    return get_result
+                if not addlease_result.check(RemoteException):
+                    # Propagate local errors
+                    return addlease_result
+                if addlease_result.value.failure.check(IndexError):
+                    # tahoe=1.3.0 raised IndexError on non-existant
+                    # buckets, which we ignore
+                    return get_result
+                # propagate remote errors that aren't IndexError, including
+                # the unfortunate internal KeyError bug that <1.3.0 had.
+                return addlease_result
             dl.addCallback(_done)
             d = dl
 
         def _wrap_results(res):
-            for k in res:
-                res[k] = rrefutil.WrappedRemoteReference(res[k])
             return (res, serverid, True)
 
         def _trap_errs(f):
@@ -133,27 +137,24 @@ class Checker(log.PrefixingLogMixin):
         d = veup.start()
 
         def _errb(f):
-            # Okay, we didn't succeed at fetching and verifying all the
-            # blocks of this share. Now we need to handle different reasons
-            # for failure differently. If the failure isn't one of the
-            # following four classes then it will get re-raised.
-            failtype = f.trap(DeadReferenceError,
-                              rrefutil.ServerFailure,
-                              layout.LayoutInvalid,
-                              layout.RidiculouslyLargeURIExtensionBlock,
-                              layout.ShareVersionIncompatible,
-                              download.BadOrMissingHash,
-                              download.BadURIExtensionHashValue)
+            # We didn't succeed at fetching and verifying all the blocks of
+            # this share. Handle each reason for failure differently.
 
             if f.check(DeadReferenceError):
                 return (False, sharenum, 'disconnect')
-            elif f.check(rrefutil.ServerFailure):
+            elif f.check(RemoteException):
                 return (False, sharenum, 'failure')
             elif f.check(layout.ShareVersionIncompatible):
                 return (False, sharenum, 'incompatible')
-            else:
+            elif f.check(layout.LayoutInvalid,
+                         layout.RidiculouslyLargeURIExtensionBlock,
+                         download.BadOrMissingHash,
+                         download.BadURIExtensionHashValue):
                 return (False, sharenum, 'corrupt')
 
+            # if it wasn't one of those reasons, re-raise the error
+            return f
+
         def _got_ueb(vup):
             self._share_hash_tree = hashtree.IncompleteHashTree(self._verifycap.total_shares)
             self._share_hash_tree.set_hashes({0: vup.share_root_hash})
@@ -238,7 +239,7 @@ class Checker(log.PrefixingLogMixin):
             return dl
 
         def _err(f):
-            f.trap(rrefutil.ServerFailure)
+            f.trap(RemoteException, DeadReferenceError)
             return (set(), serverid, set(), set(), False)
 
         d.addCallbacks(_got_buckets, _err)
index 2d544ad6c20d2c19732e579043715fb2017619e5..326bd572ffa8fe4d315dbdff8c332d5866a03de6 100644 (file)
@@ -3,11 +3,10 @@ from zope.interface import implements
 from twisted.internet import defer
 from twisted.internet.interfaces import IPushProducer, IConsumer
 from twisted.application import service
-from foolscap.api import DeadReferenceError, eventually
+from foolscap.api import DeadReferenceError, RemoteException, eventually
 
 from allmydata.util import base32, deferredutil, hashutil, log, mathutil
 from allmydata.util.assertutil import _assert, precondition
-from allmydata.util.rrefutil import ServerFailure
 from allmydata import codec, hashtree, uri
 from allmydata.interfaces import IDownloadTarget, IDownloader, IFileURI, IVerifierURI, \
      IDownloadStatus, IDownloadResults, IValidatedThingProxy, NotEnoughSharesError, \
@@ -82,11 +81,13 @@ class ValidatedThingObtainer:
         self._log_id = log_id
 
     def _bad(self, f, validatedthingproxy):
-        failtype = f.trap(ServerFailure, IntegrityCheckReject, layout.LayoutInvalid, layout.ShareVersionIncompatible, DeadReferenceError)
+        failtype = f.trap(RemoteException, DeadReferenceError,
+                          IntegrityCheckReject, layout.LayoutInvalid,
+                          layout.ShareVersionIncompatible)
         level = log.WEIRD
         if f.check(DeadReferenceError):
             level = log.UNUSUAL
-        elif f.check(ServerFailure):
+        elif f.check(RemoteException):
             level = log.WEIRD
         else:
             level = log.SCARY
@@ -476,8 +477,10 @@ class BlockDownloader(log.PrefixingLogMixin):
         self.parent.hold_block(self.blocknum, data)
 
     def _got_block_error(self, f):
-        failtype = f.trap(ServerFailure, IntegrityCheckReject, layout.LayoutInvalid, layout.ShareVersionIncompatible)
-        if f.check(ServerFailure):
+        failtype = f.trap(RemoteException, DeadReferenceError,
+                          IntegrityCheckReject,
+                          layout.LayoutInvalid, layout.ShareVersionIncompatible)
+        if f.check(RemoteException, DeadReferenceError):
             level = log.UNUSUAL
         else:
             level = log.WEIRD
index d7a4d7adab0b5e2c4dc72ce9f7d8f1589c2c7376..a8b270ebe5779a6957fb7eebd827683400b3394c 100644 (file)
@@ -14,7 +14,7 @@ from allmydata.storage.server import si_b2a
 from allmydata.immutable import encode
 from allmydata.util import base32, dictutil, idlib, log, mathutil
 from allmydata.util.assertutil import precondition
-from allmydata.util.rrefutil import get_versioned_remote_reference
+from allmydata.util.rrefutil import add_version_to_remote_reference
 from allmydata.interfaces import IUploadable, IUploader, IUploadResults, \
      IEncryptedUploadable, RIEncryptedUploadable, IUploadStatus, \
      NotEnoughSharesError, InsufficientVersionError, NoServersError
@@ -1227,7 +1227,7 @@ class Uploader(service.MultiService, log.PrefixingLogMixin):
                     { },
                     "application-version": "unknown: no get_version()",
                     }
-        d = get_versioned_remote_reference(helper, default)
+        d = add_version_to_remote_reference(helper, default)
         d.addCallback(self._got_versioned_helper)
 
     def _got_versioned_helper(self, helper):
index bd45f6e818ce1a3b9a5db0a413d6a5e3319ba6f3..db09c7eb9c46fe027aec466ca01f9610cf5cc2d9 100644 (file)
@@ -8,7 +8,7 @@ from allmydata.interfaces import InsufficientVersionError
 from allmydata.introducer.interfaces import RIIntroducerSubscriberClient, \
      IIntroducerClient
 from allmydata.util import log, idlib
-from allmydata.util.rrefutil import get_versioned_remote_reference
+from allmydata.util.rrefutil import add_version_to_remote_reference
 from allmydata.introducer.common import make_index
 
 
@@ -78,14 +78,14 @@ class RemoteServiceConnector:
         self.log("got connection to %s, getting versions" % self._nodeid_s)
 
         default = self.VERSION_DEFAULTS.get(self.service_name, {})
-        d = get_versioned_remote_reference(rref, default)
+        d = add_version_to_remote_reference(rref, default)
         d.addCallback(self._got_versioned_service)
 
     def _got_versioned_service(self, rref):
         self.log("connected to %s, version %s" % (self._nodeid_s, rref.version))
 
         self.last_connect_time = time.time()
-        self.remote_host = rref.rref.tracker.broker.transport.getPeer()
+        self.remote_host = rref.tracker.broker.transport.getPeer()
 
         self.rref = rref
 
@@ -156,7 +156,7 @@ class IntroducerClient(service.Service, Referenceable):
                     { },
                     "application-version": "unknown: no get_version()",
                     }
-        d = get_versioned_remote_reference(publisher, default)
+        d = add_version_to_remote_reference(publisher, default)
         d.addCallback(self._got_versioned_introducer)
         d.addErrback(self._got_error)
 
index 7d0ecdb7eb015cb173faa71939953838b2938544..89c8baffbcc095ff0024404eb70db977358a09da 100644 (file)
@@ -80,6 +80,7 @@ class KeyGeneratorService(service.MultiService):
         service.MultiService.__init__(self)
         self.basedir = basedir
         self.tub = Tub(certFile=os.path.join(self.basedir, 'key_generator.pem'))
+        self.tub.setOption("expose-remote-exception-types", False)
         self.tub.setServiceParent(self)
         self.key_generator = KeyGenerator(default_key_size=default_key_size)
         self.key_generator.setServiceParent(self)
index 12cf4ff018c26c5eb4e67489a504ca3d72cf4b0d..592e600edf5c43cf996ba44eec396a6640d8898f 100644 (file)
@@ -4,8 +4,8 @@ from zope.interface import implements
 from itertools import count
 from twisted.internet import defer
 from twisted.python import failure
-from foolscap.api import DeadReferenceError, eventually
-from allmydata.util import base32, hashutil, idlib, log, rrefutil
+from foolscap.api import DeadReferenceError, RemoteException, eventually
+from allmydata.util import base32, hashutil, idlib, log
 from allmydata.storage.server import si_b2a
 from allmydata.interfaces import IServermapUpdaterStatus
 from pycryptopp.publickey import rsa
@@ -546,13 +546,20 @@ class ServermapUpdater:
             def _done(res):
                 [(readv_success, readv_result),
                  (addlease_success, addlease_result)] = res
-                if (not addlease_success and
-                    not rrefutil.check_remote(addlease_result, IndexError)):
-                    # tahoe 1.3.0 raised IndexError on non-existant buckets,
-                    # which we ignore. Unfortunately tahoe <1.3.0 had a bug
-                    # and raised KeyError, which we report.
-                    return addlease_result # propagate error
-                return readv_result
+                # ignore remote IndexError on the add_lease call. Propagate
+                # local errors and remote non-IndexErrors
+                if addlease_success:
+                    return readv_result
+                if not addlease_result.check(RemoteException):
+                    # Propagate local errors
+                    return addlease_result
+                if addlease_result.value.failure.check(IndexError):
+                    # tahoe=1.3.0 raised IndexError on non-existant
+                    # buckets, which we ignore
+                    return readv_result
+                # propagate remote errors that aren't IndexError, including
+                # the unfortunate internal KeyError bug that <1.3.0 had.
+                return addlease_result
             dl.addCallback(_done)
             return dl
         return d
index b582d923a8af4e37384f57625a5d96e49fa10c87..582c590f724edb19fce954719cc91b163c0108ed 100644 (file)
@@ -141,6 +141,7 @@ class Node(service.MultiService):
         self.tub = Tub(certFile=certfile)
         self.tub.setOption("logLocalFailures", True)
         self.tub.setOption("logRemoteFailures", True)
+        self.tub.setOption("expose-remote-exception-types", False)
 
         # see #521 for a discussion of how to pick these timeout values.
         keepalive_timeout_s = self.get_config("node", "timeout.keepalive", "")
index 685aa3d0b56a35103da130377ad88c6f50ce053a..2ae63dbfccbc3da9ff52af2c9e7c65f5040819f0 100644 (file)
@@ -285,6 +285,7 @@ class StatsGathererService(service.MultiService):
         self.tub.setServiceParent(self)
         self.tub.setOption("logLocalFailures", True)
         self.tub.setOption("logRemoteFailures", True)
+        self.tub.setOption("expose-remote-exception-types", False)
 
         self.stats_gatherer = PickleStatsGatherer(self.basedir, verbose)
         self.stats_gatherer.setServiceParent(self)
index bfc6b91a136183be501736ecd0e641b8bf09f9c5..40cfa590ce3fa517f97fc89af1e5bfd5e6288af9 100644 (file)
@@ -74,6 +74,7 @@ class SystemFramework(pollmixin.PollMixin):
         self.sparent.startService()
         self.proc = None
         self.tub = Tub()
+        self.tub.setOption("expose-remote-exception-types", False)
         self.tub.setServiceParent(self.sparent)
         self.mode = mode
         self.failed = False
index f8d0fc1d63b0a9a03ff9f2d0f6a6f83f4d5c219a..d709f04861072308fc9e98beff96f1699dc64e84 100644 (file)
@@ -48,6 +48,7 @@ class SpeedTest:
     def setUp(self):
         self.base_service.startService()
         self.tub = Tub()
+        self.tub.setOption("expose-remote-exception-types", False)
         self.tub.setServiceParent(self.base_service)
         d = self.tub.getReference(self.control_furl)
         def _gotref(rref):
index d417f60a0d5c22cf39875853f348feee3f8a4983..5808b7203f6ed2f8bd36104860b0afcf0016b440 100644 (file)
@@ -17,12 +17,13 @@ import os.path
 import sha
 from twisted.application import service
 from twisted.internet import reactor
-from foolscap.api import Referenceable, fireEventually
+from twisted.python.failure import Failure
+from foolscap.api import Referenceable, fireEventually, RemoteException
 from base64 import b32encode
 from allmydata import uri as tahoe_uri
 from allmydata.client import Client
 from allmydata.storage.server import StorageServer, storage_index_to_dir
-from allmydata.util import fileutil, idlib, hashutil, rrefutil
+from allmydata.util import fileutil, idlib, hashutil
 from allmydata.introducer.client import RemoteServiceConnector
 from allmydata.test.common_web import HTTPClientGETFactory
 
@@ -61,6 +62,9 @@ class LocalWrapper:
             return meth(*args, **kwargs)
         d = fireEventually()
         d.addCallback(lambda res: _call())
+        def _wrap_exception(f):
+            return Failure(RemoteException(f))
+        d.addErrback(_wrap_exception)
         def _return_membrane(res):
             # rather than complete the difficult task of building a
             # fully-general Membrane (which would locate all Referenceable
@@ -88,20 +92,17 @@ class LocalWrapper:
         del self.disconnectors[marker]
 
 def wrap(original, service_name):
-    # The code in immutable.checker insists upon asserting the truth of
-    # isinstance(rref, rrefutil.WrappedRemoteReference). Much of the
-    # upload/download code uses rref.version (which normally comes from
-    # rrefutil.VersionedRemoteReference). To avoid using a network, we want a
-    # LocalWrapper here. Try to satisfy all these constraints at the same
-    # time.
-    local = LocalWrapper(original)
-    wrapped = rrefutil.WrappedRemoteReference(local)
+    # Much of the upload/download code uses rref.version (which normally
+    # comes from rrefutil.add_version_to_remote_reference). To avoid using a
+    # network, we want a LocalWrapper here. Try to satisfy all these
+    # constraints at the same time.
+    wrapper = LocalWrapper(original)
     try:
         version = original.remote_get_version()
     except AttributeError:
         version = RemoteServiceConnector.VERSION_DEFAULTS[service_name]
-    wrapped.version = version
-    return wrapped
+    wrapper.version = version
+    return wrapper
 
 class NoNetworkClient(Client):
 
index e9cd0c3765b0568e50b7c6d64961f18ed36a9af7..ea4486980df23c745e2ea7aa3767dde6df1c2b46 100644 (file)
@@ -95,6 +95,7 @@ class AssistedUpload(unittest.TestCase):
         self.s.startService()
 
         self.tub = t = Tub()
+        t.setOption("expose-remote-exception-types", False)
         t.setServiceParent(self.s)
         self.s.tub = t
         # we never actually use this for network traffic, so it can use a
index 6650bf2700af1f7dd523cc122ea30ad19fbfd554..535fab7796ee02155c7e0a67b7f1a59d31d690f9 100644 (file)
@@ -84,6 +84,7 @@ class SystemTestMixin(ServiceMixin, pollmixin.PollMixin):
         self.central_tub = tub = Tub()
         #tub.setOption("logLocalFailures", True)
         #tub.setOption("logRemoteFailures", True)
+        tub.setOption("expose-remote-exception-types", False)
         tub.setServiceParent(self.parent)
         l = tub.listenOn("tcp:0")
         portnum = l.getPortnum()
@@ -116,6 +117,7 @@ class SystemTest(SystemTestMixin, unittest.TestCase):
             tub = Tub()
             #tub.setOption("logLocalFailures", True)
             #tub.setOption("logRemoteFailures", True)
+            tub.setOption("expose-remote-exception-types", False)
             tub.setServiceParent(self.parent)
             l = tub.listenOn("tcp:0")
             portnum = l.getPortnum()
@@ -249,6 +251,7 @@ class NonV1Server(SystemTestMixin, unittest.TestCase):
         self.introducer_furl = self.central_tub.registerReference(i)
 
         tub = Tub()
+        tub.setOption("expose-remote-exception-types", False)
         tub.setServiceParent(self.parent)
         l = tub.listenOn("tcp:0")
         portnum = l.getPortnum()
index 34282b58b8174d8ae034edbf123e27760ae9ebde..45fbaea2220ce77ff7ee8e4a777f5ef335b98f12 100644 (file)
@@ -22,6 +22,7 @@ class KeyGenService(unittest.TestCase, pollmixin.PollMixin):
         self.parent.startService()
 
         self.tub = t = Tub()
+        t.setOption("expose-remote-exception-types", False)
         t.setServiceParent(self.parent)
         t.listenOn("tcp:0")
         t.setLocationAutomatically()
index 16a63f04178a156c03cf528020246fe5224c520f..2ae2667e77cebfc19883070f61010d7e35313f5a 100644 (file)
@@ -11,8 +11,7 @@ from twisted.python import log
 from allmydata.util import base32, idlib, humanreadable, mathutil, hashutil
 from allmydata.util import assertutil, fileutil, deferredutil, abbreviate
 from allmydata.util import limiter, time_format, pollmixin, cachedir
-from allmydata.util import statistics, dictutil, rrefutil, pipeline
-from allmydata.util.rrefutil import ServerFailure
+from allmydata.util import statistics, dictutil, pipeline
 
 class Base32(unittest.TestCase):
     def test_b2a_matches_Pythons(self):
@@ -1212,96 +1211,6 @@ class DictUtil(unittest.TestCase):
         self.failUnlessEqual(x, "b")
         self.failUnlessEqual(d.items(), [("c", 1), ("a", 3)])
 
-class FakeRemoteReference:
-    def callRemote(self, methname, *args, **kwargs):
-        return defer.maybeDeferred(self.oops)
-    def oops(self):
-        raise IndexError("remote missing key")
-
-class RemoteFailures(unittest.TestCase):
-    def test_check(self):
-        check_local = rrefutil.check_local
-        check_remote = rrefutil.check_remote
-        try:
-            raise IndexError("local missing key")
-        except IndexError:
-            localf = Failure()
-
-        self.failUnlessEqual(localf.check(IndexError, KeyError), IndexError)
-        self.failUnlessEqual(localf.check(ValueError, KeyError), None)
-        self.failUnlessEqual(localf.check(ServerFailure), None)
-        self.failUnlessEqual(check_local(localf, IndexError, KeyError),
-                             IndexError)
-        self.failUnlessEqual(check_local(localf, ValueError, KeyError), None)
-        self.failUnlessEqual(check_remote(localf, IndexError, KeyError), None)
-        self.failUnlessEqual(check_remote(localf, ValueError, KeyError), None)
-
-        frr = FakeRemoteReference()
-        wrr = rrefutil.WrappedRemoteReference(frr)
-        d = wrr.callRemote("oops")
-        def _check(f):
-            self.failUnlessEqual(f.check(IndexError, KeyError), None)
-            self.failUnlessEqual(f.check(ServerFailure, KeyError),
-                                 ServerFailure)
-            self.failUnlessEqual(check_remote(f, IndexError, KeyError),
-                                 IndexError)
-            self.failUnlessEqual(check_remote(f, ValueError, KeyError), None)
-            self.failUnlessEqual(check_local(f, IndexError, KeyError), None)
-            self.failUnlessEqual(check_local(f, ValueError, KeyError), None)
-        d.addErrback(_check)
-        return d
-
-    def test_is_remote(self):
-        try:
-            raise IndexError("local missing key")
-        except IndexError:
-            localf = Failure()
-        self.failIf(rrefutil.is_remote(localf))
-        self.failUnless(rrefutil.is_local(localf))
-
-        frr = FakeRemoteReference()
-        wrr = rrefutil.WrappedRemoteReference(frr)
-        d = wrr.callRemote("oops")
-        def _check(f):
-            self.failUnless(rrefutil.is_remote(f))
-            self.failIf(rrefutil.is_local(f))
-        d.addErrback(_check)
-        return d
-
-    def test_trap(self):
-        try:
-            raise IndexError("local missing key")
-        except IndexError:
-            localf = Failure()
-
-        self.failUnlessRaises(Failure, localf.trap, ValueError, KeyError)
-        self.failUnlessRaises(Failure, localf.trap, ServerFailure)
-        self.failUnlessEqual(localf.trap(IndexError, KeyError), IndexError)
-        self.failUnlessEqual(rrefutil.trap_local(localf, IndexError, KeyError),
-                             IndexError)
-        self.failUnlessRaises(Failure,
-                              rrefutil.trap_remote, localf, ValueError, KeyError)
-
-        frr = FakeRemoteReference()
-        wrr = rrefutil.WrappedRemoteReference(frr)
-        d = wrr.callRemote("oops")
-        def _check(f):
-            self.failUnlessRaises(Failure,
-                                  f.trap, ValueError, KeyError)
-            self.failUnlessRaises(Failure,
-                                  f.trap, IndexError)
-            self.failUnlessEqual(f.trap(ServerFailure), ServerFailure)
-            self.failUnlessRaises(Failure,
-                                  rrefutil.trap_remote, f, ValueError, KeyError)
-            self.failUnlessEqual(rrefutil.trap_remote(f, IndexError, KeyError),
-                                 IndexError)
-            self.failUnlessRaises(Failure,
-                                  rrefutil.trap_local, f, ValueError, KeyError)
-            self.failUnlessRaises(Failure,
-                                  rrefutil.trap_local, f, IndexError)
-        d.addErrback(_check)
-        return d
-
 class Pipeline(unittest.TestCase):
     def pause(self, *args, **kwargs):
         d = defer.Deferred()
@@ -1444,7 +1353,6 @@ class Pipeline(unittest.TestCase):
         self.failUnless(f.check(pipeline.PipelineError))
         f2 = f.value.error
         self.failUnless(f2.check(ValueError))
-        
 
     def test_errors2(self):
         self.calls = []
index 99e10774467a8163d4b3eeada6fa928248c331d7..097e732a1d04a97524777c66133a21d58630236f 100644 (file)
@@ -1,87 +1,19 @@
-import exceptions
 
-from foolscap.api import Violation
+from foolscap.api import Violation, RemoteException
 
-class ServerFailure(exceptions.Exception):
-    # If the server returns a Failure instead of the normal response to a
-    # protocol, then this exception will be raised, with the Failure that the
-    # server returned as its .remote_failure attribute.
-    def __init__(self, remote_failure):
-        self.remote_failure = remote_failure
-    def __repr__(self):
-        return repr(self.remote_failure)
-    def __str__(self):
-        return str(self.remote_failure)
-
-def is_remote(f):
-    if isinstance(f.value, ServerFailure):
-        return True
-    return False
-
-def is_local(f):
-    return not is_remote(f)
-
-def check_remote(f, *errorTypes):
-    if is_remote(f):
-        return f.value.remote_failure.check(*errorTypes)
-    return None
-
-def check_local(f, *errorTypes):
-    if is_local(f):
-        return f.check(*errorTypes)
-    return None
-
-def trap_remote(f, *errorTypes):
-    if is_remote(f):
-        return f.value.remote_failure.trap(*errorTypes)
-    raise f
-
-def trap_local(f, *errorTypes):
-    if is_local(f):
-        return f.trap(*errorTypes)
-    raise f
-
-def _wrap_server_failure(f):
-    raise ServerFailure(f)
-
-class WrappedRemoteReference(object):
-    """I intercept any errback from the server and wrap it in a
-    ServerFailure."""
-
-    def __init__(self, original):
-        self.rref = original
-
-    def callRemote(self, *args, **kwargs):
-        d = self.rref.callRemote(*args, **kwargs)
-        d.addErrback(_wrap_server_failure)
-        return d
-
-    def callRemoteOnly(self, *args, **kwargs):
-        return self.rref.callRemoteOnly(*args, **kwargs)
-
-    def notifyOnDisconnect(self, *args, **kwargs):
-        return self.rref.notifyOnDisconnect(*args, **kwargs)
-
-    def dontNotifyOnDisconnect(self, *args, **kwargs):
-        return self.rref.dontNotifyOnDisconnect(*args, **kwargs)
-
-class VersionedRemoteReference(WrappedRemoteReference):
-    """I wrap a RemoteReference, and add a .version attribute. I also
-    intercept any errback from the server and wrap it in a ServerFailure."""
-
-    def __init__(self, original, version):
-        WrappedRemoteReference.__init__(self, original)
-        self.version = version
-
-def get_versioned_remote_reference(rref, default):
-    """I return a Deferred that fires with a VersionedRemoteReference"""
+def add_version_to_remote_reference(rref, default):
+    """I try to add a .version attribute to the given RemoteReference. I call
+    the remote get_version() method to learn its version. I'll add the
+    default value if the remote side doesn't appear to have a get_version()
+    method."""
     d = rref.callRemote("get_version")
-    def _no_get_version(f):
-        f.trap(Violation, AttributeError)
-        return default
-    d.addErrback(_no_get_version)
     def _got_version(version):
-        return VersionedRemoteReference(rref, version)
-    d.addCallback(_got_version)
+        rref.version = version
+        return rref
+    def _no_get_version(f):
+        f.trap(Violation, RemoteException)
+        rref.version = default
+        return rref
+    d.addCallbacks(_got_version, _no_get_version)
     return d