]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blobdiff - src/allmydata/frontends/sftpd.py
Improve SFTP error handling and remove use of IFinishableConsumer. fixes #1525
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / frontends / sftpd.py
index 4e16d6ab344bc3beba2fa47048d7db00c1a3dbcb..f6e11cedf0022e64ba8f9c8120acac1e48782ddc 100644 (file)
@@ -22,7 +22,7 @@ from twisted.python.failure import Failure
 from twisted.internet.interfaces import ITransport
 
 from twisted.internet import defer
-from twisted.internet.interfaces import IFinishableConsumer
+from twisted.internet.interfaces import IConsumer
 from foolscap.api import eventually
 from allmydata.util import deferredutil
 
@@ -294,7 +294,7 @@ def _direntry_for(filenode_or_parent, childname, filenode=None):
 
 
 class OverwriteableFileConsumer(PrefixingLogMixin):
-    implements(IFinishableConsumer)
+    implements(IConsumer)
     """I act both as a consumer for the download of the original file contents, and as a
     wrapper for a temporary file that records the downloaded data and any overwrites.
     I use a priority queue to keep track of which regions of the file have been overwritten
@@ -321,12 +321,9 @@ class OverwriteableFileConsumer(PrefixingLogMixin):
         self.milestones = []  # empty heap of (offset, d)
         self.overwrites = []  # empty heap of (start, end)
         self.is_closed = False
-        self.done = self.when_reached(download_size)  # adds a milestone
-        self.is_done = False
-        def _signal_done(ign):
-            if noisy: self.log("DONE", level=NOISY)
-            self.is_done = True
-        self.done.addCallback(_signal_done)
+
+        self.done = defer.Deferred()
+        self.done_status = None  # None -> not complete, Failure -> download failed, str -> download succeeded
         self.producer = None
 
     def get_file(self):
@@ -349,7 +346,7 @@ class OverwriteableFileConsumer(PrefixingLogMixin):
             self.download_size = size
 
         if self.downloaded >= self.download_size:
-            self.finish()
+            self.download_done("size changed")
 
     def registerProducer(self, p, streaming):
         if noisy: self.log(".registerProducer(%r, streaming=%r)" % (p, streaming), level=NOISY)
@@ -362,7 +359,7 @@ class OverwriteableFileConsumer(PrefixingLogMixin):
             p.resumeProducing()
         else:
             def _iterate():
-                if not self.is_done:
+                if self.done_status is None:
                     p.resumeProducing()
                     eventually(_iterate)
             _iterate()
@@ -429,13 +426,17 @@ class OverwriteableFileConsumer(PrefixingLogMixin):
                 return
             if noisy: self.log("MILESTONE %r %r" % (next, d), level=NOISY)
             heapq.heappop(self.milestones)
-            eventually(d.callback, None)
+            eventually_callback(d)("reached")
 
         if milestone >= self.download_size:
-            self.finish()
+            self.download_done("reached download size")
 
     def overwrite(self, offset, data):
         if noisy: self.log(".overwrite(%r, <data of length %r>)" % (offset, len(data)), level=NOISY)
+        if self.is_closed:
+            self.log("overwrite called on a closed OverwriteableFileConsumer", level=WEIRD)
+            raise SFTPError(FX_BAD_MESSAGE, "cannot write to a closed file handle")
+
         if offset > self.current_size:
             # Normally writing at an offset beyond the current end-of-file
             # would leave a hole that appears filled with zeroes. However, an
@@ -463,6 +464,9 @@ class OverwriteableFileConsumer(PrefixingLogMixin):
         The caller must perform no more overwrites until the Deferred has fired."""
 
         if noisy: self.log(".read(%r, %r), current_size = %r" % (offset, length, self.current_size), level=NOISY)
+        if self.is_closed:
+            self.log("read called on a closed OverwriteableFileConsumer", level=WEIRD)
+            raise SFTPError(FX_BAD_MESSAGE, "cannot read from a closed file handle")
 
         # Note that the overwrite method is synchronous. When a write request is processed
         # (e.g. a writeChunk request on the async queue of GeneralSFTPFile), overwrite will
@@ -478,48 +482,68 @@ class OverwriteableFileConsumer(PrefixingLogMixin):
             if noisy: self.log("truncating read to %r bytes" % (length,), level=NOISY)
 
         needed = min(offset + length, self.download_size)
-        d = self.when_reached(needed)
-        def _reached(ign):
+
+        # If we fail to reach the needed number of bytes, the read request will fail.
+        d = self.when_reached_or_failed(needed)
+        def _reached_in_read(res):
             # It is not necessarily the case that self.downloaded >= needed, because
             # the file might have been truncated (thus truncating the download) and
             # then extended.
 
             _assert(self.current_size >= offset + length,
                     current_size=self.current_size, offset=offset, length=length)
-            if noisy: self.log("self.f = %r" % (self.f,), level=NOISY)
+            if noisy: self.log("_reached_in_read(%r), self.f = %r" % (res, self.f,), level=NOISY)
             self.f.seek(offset)
             return self.f.read(length)
-        d.addCallback(_reached)
+        d.addCallback(_reached_in_read)
         return d
 
-    def when_reached(self, index):
-        if noisy: self.log(".when_reached(%r)" % (index,), level=NOISY)
-        if index <= self.downloaded:  # already reached
-            if noisy: self.log("already reached %r" % (index,), level=NOISY)
-            return defer.succeed(None)
+    def when_reached_or_failed(self, index):
+        if noisy: self.log(".when_reached_or_failed(%r)" % (index,), level=NOISY)
+        def _reached(res):
+            if noisy: self.log("reached %r with result %r" % (index, res), level=NOISY)
+            return res
+
+        if self.done_status is not None:
+            return defer.execute(_reached, self.done_status)
+        if index <= self.downloaded:  # already reached successfully
+            if noisy: self.log("already reached %r successfully" % (index,), level=NOISY)
+            return defer.succeed("already reached successfully")
         d = defer.Deferred()
-        def _reached(ign):
-            if noisy: self.log("reached %r" % (index,), level=NOISY)
-            return ign
         d.addCallback(_reached)
         heapq.heappush(self.milestones, (index, d))
         return d
 
     def when_done(self):
-        return self.done
+        d = defer.Deferred()
+        self.done.addCallback(lambda ign: eventually_callback(d)(self.done_status))
+        return d
 
-    def finish(self):
-        """Called by the producer when it has finished producing, or when we have
-        received enough bytes, or as a result of a close. Defined by IFinishableConsumer."""
+    def download_done(self, res):
+        _assert(isinstance(res, (str, Failure)), res=res)
+        # Only the first call to download_done counts, but we log subsequent calls
+        # (multiple calls are normal).
+        if self.done_status is not None:
+            self.log("IGNORING extra call to download_done with result %r; previous result was %r"
+                     % (res, self.done_status), level=OPERATIONAL)
+            return
+
+        self.log("DONE with result %r" % (res,), level=OPERATIONAL)
+
+        # We avoid errbacking self.done so that we are not left with an 'Unhandled error in Deferred'
+        # in case when_done() is never called. Instead we stash the failure in self.done_status,
+        # from where the callback added in when_done() can retrieve it.
+        self.done_status = res
+        eventually_callback(self.done)(None)
 
         while len(self.milestones) > 0:
             (next, d) = self.milestones[0]
-            if noisy: self.log("MILESTONE FINISH %r %r" % (next, d), level=NOISY)
+            if noisy: self.log("MILESTONE FINISH %r %r %r" % (next, d, res), level=NOISY)
             heapq.heappop(self.milestones)
             # The callback means that the milestone has been reached if
             # it is ever going to be. Note that the file may have been
             # truncated to before the milestone.
-            eventually(d.callback, None)
+            eventually_callback(d)(res)
 
     def close(self):
         if not self.is_closed:
@@ -528,10 +552,14 @@ class OverwriteableFileConsumer(PrefixingLogMixin):
                 self.f.close()
             except Exception, e:
                 self.log("suppressed %r from close of temporary file %r" % (e, self.f), level=WEIRD)
-        self.finish()
+        self.download_done("closed")
+        return self.done_status
 
     def unregisterProducer(self):
-        pass
+        # This will happen just before our client calls download_done, which will tell
+        # us the outcome of the download; we don't know the outcome at this point.
+        self.producer = None
+        self.log("producer unregistered", level=NOISY)
 
 
 SIZE_THRESHOLD = 1000
@@ -577,9 +605,9 @@ class ShortReadOnlySFTPFile(PrefixingLogMixin):
             # i.e. we respond with an EOF error iff offset is already at EOF.
 
             if offset >= len(data):
-                eventually(d.errback, SFTPError(FX_EOF, "read at or past end of file"))
+                eventually_errback(d)(Failure(SFTPError(FX_EOF, "read at or past end of file")))
             else:
-                eventually(d.callback, data[offset:offset+length])  # truncated if offset+length > len(data)
+                eventually_callback(d)(data[offset:offset+length])  # truncated if offset+length > len(data)
             return data
         self.async.addCallbacks(_read, eventually_errback(d))
         d.addBoth(_convert_error, request)
@@ -672,7 +700,7 @@ class GeneralSFTPFile(PrefixingLogMixin):
         if (self.flags & FXF_TRUNC) or not filenode:
             # We're either truncating or creating the file, so we don't need the old contents.
             self.consumer = OverwriteableFileConsumer(0, tempfile_maker)
-            self.consumer.finish()
+            self.consumer.download_done("download not needed")
         else:
             self.async.addCallback(lambda ignored: filenode.get_best_readable_version())
 
@@ -683,10 +711,16 @@ class GeneralSFTPFile(PrefixingLogMixin):
 
                 self.consumer = OverwriteableFileConsumer(download_size, tempfile_maker)
 
-                version.read(self.consumer, 0, None)
+                d = version.read(self.consumer, 0, None)
+                def _finished(res):
+                    if not isinstance(res, Failure):
+                        res = "download finished"
+                    self.consumer.download_done(res)
+                d.addBoth(_finished)
+                # It is correct to drop d here.
             self.async.addCallback(_read)
 
-        eventually(self.async.callback, None)
+        eventually_callback(self.async)(None)
 
         if noisy: self.log("open done", level=NOISY)
         return self
@@ -739,7 +773,7 @@ class GeneralSFTPFile(PrefixingLogMixin):
         def _read(ign):
             if noisy: self.log("_read in readChunk(%r, %r)" % (offset, length), level=NOISY)
             d2 = self.consumer.read(offset, length)
-            d2.addCallbacks(eventually_callback(d), eventually_errback(d))
+            d2.addBoth(eventually_callback(d))
             # It is correct to drop d2 here.
             return None
         self.async.addCallbacks(_read, eventually_errback(d))
@@ -783,6 +817,24 @@ class GeneralSFTPFile(PrefixingLogMixin):
         # don't addErrback to self.async, just allow subsequent async ops to fail.
         return defer.succeed(None)
 
+    def _do_close(self, res, d=None):
+        if noisy: self.log("_do_close(%r)" % (res,), level=NOISY)
+        status = None
+        if self.consumer:
+            status = self.consumer.close()
+
+        # We must close_notify before re-firing self.async.
+        if self.close_notify:
+            self.close_notify(self.userpath, self.parent, self.childname, self)
+
+        if not isinstance(res, Failure) and isinstance(status, Failure):
+            res = status
+
+        if d:
+            eventually_callback(d)(res)
+        elif isinstance(res, Failure):
+            self.log("suppressing %r" % (res,), level=OPERATIONAL)
+
     def close(self):
         request = ".close()"
         self.log(request, level=OPERATIONAL)
@@ -794,10 +846,14 @@ class GeneralSFTPFile(PrefixingLogMixin):
         self.closed = True
 
         if not (self.flags & (FXF_WRITE | FXF_CREAT)):
-            def _readonly_close():
-                if self.consumer:
-                    self.consumer.close()
-            return defer.execute(_readonly_close)
+            # We never fail a close of a handle opened only for reading, even if the file
+            # failed to download. (We could not do so deterministically, because it would
+            # depend on whether we reached the point of failure before abandoning the
+            # download.) Any reads that depended on file content that could not be downloaded
+            # will have failed. It is important that we don't close the consumer until
+            # previous read operations have completed.
+            self.async.addBoth(self._do_close)
+            return defer.succeed(None)
 
         # We must capture the abandoned, parent, and childname variables synchronously
         # at the close call. This is needed by the correctness arguments in the comments
@@ -811,20 +867,11 @@ class GeneralSFTPFile(PrefixingLogMixin):
         # it is correct to optimize out the commit if it is False at the close call.
         has_changed = self.has_changed
 
-        def _committed(res):
-            if noisy: self.log("_committed(%r)" % (res,), level=NOISY)
-
-            self.consumer.close()
-
-            # We must close_notify before re-firing self.async.
-            if self.close_notify:
-                self.close_notify(self.userpath, self.parent, self.childname, self)
-            return res
-
-        def _close(ign):
+        def _commit(ign):
             d2 = self.consumer.when_done()
             if self.filenode and self.filenode.is_mutable():
-                self.log("update mutable file %r childname=%r metadata=%r" % (self.filenode, childname, self.metadata), level=OPERATIONAL)
+                self.log("update mutable file %r childname=%r metadata=%r"
+                         % (self.filenode, childname, self.metadata), level=OPERATIONAL)
                 if self.metadata.get('no-write', False) and not self.filenode.is_readonly():
                     _assert(parent and childname, parent=parent, childname=childname, metadata=self.metadata)
                     d2.addCallback(lambda ign: parent.set_metadata_for(childname, self.metadata))
@@ -836,22 +883,19 @@ class GeneralSFTPFile(PrefixingLogMixin):
                     u = FileHandle(self.consumer.get_file(), self.convergence)
                     return parent.add_file(childname, u, metadata=self.metadata)
                 d2.addCallback(_add_file)
-
-            d2.addBoth(_committed)
             return d2
 
-        d = defer.Deferred()
-
         # If the file has been abandoned, we don't want the close operation to get "stuck",
-        # even if self.async fails to re-fire. Doing the close independently of self.async
-        # in that case ensures that dropping an ssh connection is sufficient to abandon
+        # even if self.async fails to re-fire. Completing the close independently of self.async
+        # in that case should ensure that dropping an ssh connection is sufficient to abandon
         # any heisenfiles that were not explicitly closed in that connection.
         if abandoned or not has_changed:
-            d.addCallback(_committed)
+            d = defer.succeed(None)
+            self.async.addBoth(self._do_close)
         else:
-            self.async.addCallback(_close)
-
-        self.async.addCallbacks(eventually_callback(d), eventually_errback(d))
+            d = defer.Deferred()
+            self.async.addCallback(_commit)
+            self.async.addBoth(self._do_close, d)
         d.addBoth(_convert_error, request)
         return d
 
@@ -873,7 +917,7 @@ class GeneralSFTPFile(PrefixingLogMixin):
 
             # self.filenode might be None, but that's ok.
             attrs = _populate_attrs(self.filenode, self.metadata, size=self.consumer.get_current_size())
-            eventually(d.callback, attrs)
+            eventually_callback(d)(attrs)
             return None
         self.async.addCallbacks(_get, eventually_errback(d))
         d.addBoth(_convert_error, request)
@@ -911,7 +955,7 @@ class GeneralSFTPFile(PrefixingLogMixin):
                 # TODO: should we refuse to truncate a file opened with FXF_APPEND?
                 # <http://allmydata.org/trac/tahoe-lafs/ticket/1037#comment:20>
                 self.consumer.set_current_size(size)
-            eventually(d.callback, None)
+            eventually_callback(d)(None)
             return None
         self.async.addCallbacks(_set, eventually_errback(d))
         d.addBoth(_convert_error, request)