]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blobdiff - src/allmydata/util/rrefutil.py
introducer: stop tracking hints for subscribed clients
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / util / rrefutil.py
index 1c037cc45d2a1624b11c63ff5e2d2ac75f7b00da..b991267f1ad1998e1f8994bd4f31b2a23807ee6d 100644 (file)
@@ -1,87 +1,46 @@
-import exceptions
 
-from foolscap.tokens import Violation
+from twisted.internet import address
+from foolscap.api import Violation, RemoteException, DeadReferenceError, \
+     SturdyRef
 
-class ServerFailure(exceptions.Exception):
-    # If the server returns a Failure instead of the normal response to a
-    # protocol, then this exception will be raised, with the Failure that the
-    # server returned as its .remote_failure attribute.
-    def __init__(self, remote_failure):
-        self.remote_failure = remote_failure
-    def __repr__(self):
-        return repr(self.remote_failure)
-    def __str__(self):
-        return str(self.remote_failure)
-
-def is_remote(f):
-    if isinstance(f.value, ServerFailure):
-        return True
-    return False
-
-def is_local(f):
-    return not is_remote(f)
-
-def check_remote(f, *errorTypes):
-    if is_remote(f):
-        return f.value.remote_failure.check(*errorTypes)
-    return None
-
-def check_local(f, *errorTypes):
-    if is_local(f):
-        return f.check(*errorTypes)
-    return None
-
-def trap_remote(f, *errorTypes):
-    if is_remote(f):
-        return f.value.remote_failure.trap(*errorTypes)
-    raise f
-
-def trap_local(f, *errorTypes):
-    if is_local(f):
-        return f.trap(*errorTypes)
-    raise f
-
-def _wrap_server_failure(f):
-    raise ServerFailure(f)
-
-class WrappedRemoteReference(object):
-    """I intercept any errback from the server and wrap it in a
-    ServerFailure."""
-
-    def __init__(self, original):
-        self.rref = original
-
-    def callRemote(self, *args, **kwargs):
-        d = self.rref.callRemote(*args, **kwargs)
-        d.addErrback(_wrap_server_failure)
-        return d
-
-    def callRemoteOnly(self, *args, **kwargs):
-        return self.rref.callRemoteOnly(*args, **kwargs)
-
-    def notifyOnDisconnect(self, *args, **kwargs):
-        return self.rref.notifyOnDisconnect(*args, **kwargs)
-
-    def dontNotifyOnDisconnect(self, *args, **kwargs):
-        return self.rref.dontNotifyOnDisconnect(*args, **kwargs)
-
-class VersionedRemoteReference(WrappedRemoteReference):
-    """I wrap a RemoteReference, and add a .version attribute. I also
-    intercept any errback from the server and wrap it in a ServerFailure."""
-
-    def __init__(self, original, version):
-        WrappedRemoteReference.__init__(self, original)
-        self.version = version
-
-def get_versioned_remote_reference(rref, default):
-    """I return a Deferred that fires with a VersionedRemoteReference"""
+def add_version_to_remote_reference(rref, default):
+    """I try to add a .version attribute to the given RemoteReference. I call
+    the remote get_version() method to learn its version. I'll add the
+    default value if the remote side doesn't appear to have a get_version()
+    method."""
     d = rref.callRemote("get_version")
-    def _no_get_version(f):
-        f.trap(Violation, AttributeError)
-        return default
-    d.addErrback(_no_get_version)
     def _got_version(version):
-        return VersionedRemoteReference(rref, version)
-    d.addCallback(_got_version)
+        rref.version = version
+        return rref
+    def _no_get_version(f):
+        f.trap(Violation, RemoteException)
+        rref.version = default
+        return rref
+    d.addCallbacks(_got_version, _no_get_version)
     return d
 
+def trap_and_discard(f, *errorTypes):
+    f.trap(*errorTypes)
+    pass
+
+def trap_deadref(f):
+    return trap_and_discard(f, DeadReferenceError)
+
+
+def hosts_for_furl(furl, ignore_localhost=True):
+    advertised = []
+    for hint in SturdyRef(furl).locationHints:
+        assert not isinstance(hint, str), hint
+        if hint[0] == "ipv4":
+            host = hint[1]
+            if ignore_localhost and host == "127.0.0.1":
+                continue
+            advertised.append(host)
+    return advertised
+
+def stringify_remote_address(rref):
+    remote = rref.getPeer()
+    if isinstance(remote, address.IPv4Address):
+        return "%s:%d" % (remote.host, remote.port)
+    # loopback is a non-IPv4Address
+    return str(remote)