From a143b1297bfddf4baa522cb579585927899bd90b Mon Sep 17 00:00:00 2001
From: david-sarah <david-sarah@jacaranda.org>
Date: Tue, 25 May 2010 11:42:10 -0700
Subject: [PATCH] SFTP: handle removing a file while it is open. Also some
 simplifications of the logout handling.

---
 src/allmydata/frontends/sftpd.py | 97 ++++++++++++++++++++------------
 src/allmydata/test/test_sftp.py  | 36 +++++++++++-
 2 files changed, 97 insertions(+), 36 deletions(-)

diff --git a/src/allmydata/frontends/sftpd.py b/src/allmydata/frontends/sftpd.py
index b0e9ba2e..b6ccefb9 100644
--- a/src/allmydata/frontends/sftpd.py
+++ b/src/allmydata/frontends/sftpd.py
@@ -336,10 +336,9 @@ class OverwriteableFileConsumer(PrefixingLogMixin):
     This abstraction is mostly independent of SFTP. Consider moving it, if it is found
     useful for other frontends."""
 
-    def __init__(self, check_abort, download_size, tempfile_maker):
+    def __init__(self, download_size, tempfile_maker):
         PrefixingLogMixin.__init__(self, facility="tahoe.sftp")
-        if noisy: self.log(".__init__(%r, %r, %r)" % (check_abort, download_size, tempfile_maker), level=NOISY)
-        self.check_abort = check_abort
+        if noisy: self.log(".__init__(%r, %r)" % (download_size, tempfile_maker), level=NOISY)
         self.download_size = download_size
         self.current_size = download_size
         self.f = tempfile_maker()
@@ -389,9 +388,6 @@ class OverwriteableFileConsumer(PrefixingLogMixin):
         if noisy: self.log(".write(<data of length %r>)" % (len(data),), level=NOISY)
         if self.is_closed:
             return
-        if self.check_abort():
-            self.close()
-            return
 
         if self.downloaded >= self.download_size:
             return
@@ -631,13 +627,12 @@ class GeneralSFTPFile(PrefixingLogMixin):
     file handle, and requests to my OverwriteableFileConsumer. This queue is
     implemented by the callback chain of self.async."""
 
-    def __init__(self, close_notify, check_abort, flags, convergence, parent=None, childname=None, filenode=None, metadata=None):
+    def __init__(self, close_notify, flags, convergence, parent=None, childname=None, filenode=None, metadata=None):
         PrefixingLogMixin.__init__(self, facility="tahoe.sftp")
-        if noisy: self.log(".__init__(%r, %r, %r, <convergence censored>, parent=%r, childname=%r, filenode=%r, metadata=%r)" %
-                           (close_notify, check_abort, flags, parent, childname, filenode, metadata), level=NOISY)
+        if noisy: self.log(".__init__(%r, %r, <convergence censored>, parent=%r, childname=%r, filenode=%r, metadata=%r)" %
+                           (close_notify, flags, parent, childname, filenode, metadata), level=NOISY)
 
         self.close_notify = close_notify
-        self.check_abort = check_abort
         self.flags = flags
         self.convergence = convergence
         self.parent = parent
@@ -648,6 +643,7 @@ class GeneralSFTPFile(PrefixingLogMixin):
         # Creating or truncating the file is a change, but if FXF_EXCL is set, a zero-length file has already been created.
         self.has_changed = (flags & (FXF_CREAT | FXF_TRUNC)) and not (flags & FXF_EXCL)
         self.closed = False
+        self.removed = False
         
         # self.consumer should only be relied on in callbacks for self.async, since it might
         # not be set before then.
@@ -656,7 +652,7 @@ class GeneralSFTPFile(PrefixingLogMixin):
 
         if (flags & FXF_TRUNC) or not filenode:
             # We're either truncating or creating the file, so we don't need the old contents.
-            self.consumer = OverwriteableFileConsumer(self.check_abort, 0, tempfile_maker)
+            self.consumer = OverwriteableFileConsumer(0, tempfile_maker)
             self.consumer.finish()
         else:
             assert IFileNode.providedBy(filenode), filenode
@@ -665,7 +661,7 @@ class GeneralSFTPFile(PrefixingLogMixin):
             if filenode.is_mutable():
                 self.async.addCallback(lambda ign: filenode.download_best_version())
                 def _downloaded(data):
-                    self.consumer = OverwriteableFileConsumer(self.check_abort, len(data), tempfile_maker)
+                    self.consumer = OverwriteableFileConsumer(len(data), tempfile_maker)
                     self.consumer.write(data)
                     self.consumer.finish()
                     return None
@@ -673,7 +669,7 @@ class GeneralSFTPFile(PrefixingLogMixin):
             else:
                 download_size = filenode.get_size()
                 assert download_size is not None, "download_size is None"
-                self.consumer = OverwriteableFileConsumer(self.check_abort, download_size, tempfile_maker)
+                self.consumer = OverwriteableFileConsumer(download_size, tempfile_maker)
                 def _read(ign):
                     if noisy: self.log("_read immutable", level=NOISY)
                     filenode.read(self.consumer, 0, None)
@@ -687,6 +683,11 @@ class GeneralSFTPFile(PrefixingLogMixin):
         self.parent = new_parent
         self.childname = new_childname
 
+    def remove(self):
+        self.log(".remove()", level=OPERATIONAL)
+
+        self.removed = True
+
     def readChunk(self, offset, length):
         request = ".readChunk(%r, %r)" % (offset, length)
         self.log(request, level=OPERATIONAL)
@@ -763,7 +764,7 @@ class GeneralSFTPFile(PrefixingLogMixin):
 
         def _close(ign):
             d2 = defer.succeed(None)
-            if self.has_changed:
+            if self.has_changed and not self.removed:
                 d2.addCallback(lambda ign: self.consumer.when_done())
                 if self.filenode and self.filenode.is_mutable():
                     d2.addCallback(lambda ign: self.consumer.get_current_size())
@@ -785,6 +786,13 @@ class GeneralSFTPFile(PrefixingLogMixin):
 
         def _closed(res):
             self.close_notify(self.parent, self.childname, self)
+
+            # It is possible for there to be a race between adding the file and removing it.
+            if self.removed:
+                self.log("oops, we added %r but must now remove it" % (self.childname,), level=OPERATIONAL)
+                d2 = self.parent.delete(self.childname)
+                d2.addBoth(lambda ign: res)
+                return d2
             return res
         d.addBoth(_closed)
         d.addBoth(_convert_error, request)
@@ -878,7 +886,6 @@ class SFTPUserHandler(ConchUser, PrefixingLogMixin):
         self._root = rootnode
         self._username = username
         self._convergence = client.convergence
-        self._logged_out = False
         self._open_files = {}  # files created by this user handler and still open
 
     def gotVersion(self, otherVersion, extData):
@@ -904,20 +911,36 @@ class SFTPUserHandler(ConchUser, PrefixingLogMixin):
             else:
                 all_open_files[direntry] = (files_to_add, time())
 
-    def _remove_open_files(self, direntry, files_to_remove):
-        if direntry and not self._logged_out:
-            assert direntry in self._open_files, (direntry, self._open_files)
-            assert direntry in all_open_files, (direntry, all_open_files)
-
+    def _remove_any_open_files(self, direntry):
+        removed = False
+        if direntry in self._open_files:
+            for f in self._open_files[direntry]:
+                f.remove()
+            del self._open_files[direntry]
+            removed = True
+
+        if direntry in all_open_files:
+            (files, opentime) = all_open_files[direntry]
+            for f in files:
+                f.remove()
+            del all_open_files[direntry]
+            removed = True
+
+        return removed
+
+    def _close_notify(self, parent, childname, file_to_remove):
+        direntry = self._direntry_for(parent, childname)
+        if direntry in self._open_files:
             old_files = self._open_files[direntry]
-            new_files = [f for f in old_files if f not in files_to_remove]
+            new_files = [f for f in old_files if f is not file_to_remove]
             if len(new_files) > 0:
                 self._open_files[direntry] = new_files
             else:
                 del self._open_files[direntry]
 
+        if direntry in all_open_files:
             (all_old_files, opentime) = all_open_files[direntry]
-            all_new_files = [f for f in all_old_files if f not in files_to_remove]
+            all_new_files = [f for f in all_old_files if f is not file_to_remove]
             if len(all_new_files) > 0:
                 all_open_files[direntry] = (all_new_files, opentime)
             else:
@@ -950,16 +973,12 @@ class SFTPUserHandler(ConchUser, PrefixingLogMixin):
         return None
 
     def logout(self):
-        if not self._logged_out:
-            self._logged_out = True
-            for (direntry, files_at_direntry) in enumerate(self._open_files):
-                self._remove_open_files(direntry, files_at_direntry)
-
-    def _check_abort(self):
-        return self._logged_out
+        self.log(".logout()", level=OPERATIONAL)
 
-    def _close_notify(self, parent, childname, f):
-        self._remove_open_files(self._direntry_for(parent, childname), [f])
+        for (direntry, files_at_direntry) in enumerate(self._open_files):
+            for f in files_at_direntry:
+                f.remove()
+                f.close()
 
     def _make_file(self, flags, parent=None, childname=None, filenode=None, metadata=None):
         if noisy: self.log("._make_file(%r = %r, parent=%r, childname=%r, filenode=%r, metadata=%r" %
@@ -975,7 +994,7 @@ class SFTPUserHandler(ConchUser, PrefixingLogMixin):
             if writing:
                 direntry = self._direntry_for(parent, childname)
 
-            file = GeneralSFTPFile(self._close_notify, self._check_abort, flags, self._convergence,
+            file = GeneralSFTPFile(self._close_notify, flags, self._convergence,
                                    parent=parent, childname=childname, filenode=filenode, metadata=metadata)
             self._add_open_files(direntry, [file])
             return file
@@ -1247,7 +1266,9 @@ class SFTPUserHandler(ConchUser, PrefixingLogMixin):
             # might not be enforced correctly if the type has just changed.
 
             if childname is None:
-                raise SFTPError(FX_NO_SUCH_FILE, "cannot delete an object specified by URI")
+                raise SFTPError(FX_NO_SUCH_FILE, "cannot remove an object specified by URI")
+
+            removed = self._remove_any_open_files(self._direntry_for(parent, childname))
 
             d2 = parent.get(childname)
             def _got_child(child):
@@ -1257,7 +1278,13 @@ class SFTPUserHandler(ConchUser, PrefixingLogMixin):
                 if must_be_file and IDirectoryNode.providedBy(child):
                     raise SFTPError(FX_PERMISSION_DENIED, "rmfile called on a directory")
                 return parent.delete(childname)
-            d2.addCallback(_got_child)
+            def _no_child(err):
+                if removed and err.check(NoSuchChildError):
+                    # suppress NoSuchChildError if an open file was removed
+                    return None
+                else:
+                    return err
+            d2.addCallbacks(_got_child, _no_child)
             return d2
         d.addCallback(_got_parent)
         return d
@@ -1362,7 +1389,7 @@ class SFTPUserHandler(ConchUser, PrefixingLogMixin):
         if "size" in attrs:
             # this would require us to download and re-upload the truncated/extended
             # file contents
-            def _unsupported(): raise SFTPError(FX_OP_UNSUPPORTED, "setAttrs wth size attribute")
+            def _unsupported(): raise SFTPError(FX_OP_UNSUPPORTED, "setAttrs wth size attribute unsupported")
             return defer.execute(_unsupported)
         return defer.succeed(None)
 
diff --git a/src/allmydata/test/test_sftp.py b/src/allmydata/test/test_sftp.py
index ca21c3b5..b3baef2f 100644
--- a/src/allmydata/test/test_sftp.py
+++ b/src/allmydata/test/test_sftp.py
@@ -892,7 +892,7 @@ class Handler(GridTestMixin, ShouldFailMixin, unittest.TestCase):
         # removing a link to an open file should not prevent it from being read
         d.addCallback(lambda ign: self.handler.openFile("small", sftp.FXF_READ, {}))
         def _remove_and_read_small(rf):
-            d2= self.handler.removeFile("small")
+            d2 = self.handler.removeFile("small")
             d2.addCallback(lambda ign:
                            self.shouldFail(NoSuchChildError, "removeFile small", "small",
                                            self.root.get, u"small"))
@@ -902,6 +902,40 @@ class Handler(GridTestMixin, ShouldFailMixin, unittest.TestCase):
             return d2
         d.addCallback(_remove_and_read_small)
 
+        # removing a link to a created file should prevent it from being created
+        d.addCallback(lambda ign: self.handler.openFile("tempfile", sftp.FXF_READ | sftp.FXF_WRITE |
+                                                                    sftp.FXF_CREAT, {}))
+        def _write_remove(rwf):
+            d2 = rwf.writeChunk(0, "0123456789")
+            d2.addCallback(lambda ign: self.handler.removeFile("tempfile"))
+            d2.addCallback(lambda ign: rwf.readChunk(0, 10))
+            d2.addCallback(lambda data: self.failUnlessReallyEqual(data, "0123456789"))
+            d2.addCallback(lambda ign: rwf.close())
+            return d2
+        d.addCallback(_write_remove)
+        d.addCallback(lambda ign:
+                      self.shouldFail(NoSuchChildError, "removeFile tempfile", "tempfile",
+                                      self.root.get, u"tempfile"))
+
+        # ... even if the link is renamed while open
+        d.addCallback(lambda ign: self.handler.openFile("tempfile2", sftp.FXF_READ | sftp.FXF_WRITE |
+                                                                     sftp.FXF_CREAT, {}))
+        def _write_rename_remove(rwf):
+            d2 = rwf.writeChunk(0, "0123456789")
+            d2.addCallback(lambda ign: self.handler.renameFile("tempfile2", "tempfile3"))
+            d2.addCallback(lambda ign: self.handler.removeFile("tempfile3"))
+            d2.addCallback(lambda ign: rwf.readChunk(0, 10))
+            d2.addCallback(lambda data: self.failUnlessReallyEqual(data, "0123456789"))
+            d2.addCallback(lambda ign: rwf.close())
+            return d2
+        d.addCallback(_write_rename_remove)
+        d.addCallback(lambda ign:
+                      self.shouldFail(NoSuchChildError, "removeFile tempfile2", "tempfile2",
+                                      self.root.get, u"tempfile2"))
+        d.addCallback(lambda ign:
+                      self.shouldFail(NoSuchChildError, "removeFile tempfile3", "tempfile3",
+                                      self.root.get, u"tempfile3"))
+
         return d
 
     def test_removeDirectory(self):
-- 
2.45.2