SFTP: further improvements to test coverage.
authordavid-sarah <david-sarah@jacaranda.org>
Wed, 2 Jun 2010 23:44:22 +0000 (16:44 -0700)
committerdavid-sarah <david-sarah@jacaranda.org>
Wed, 2 Jun 2010 23:44:22 +0000 (16:44 -0700)
src/allmydata/frontends/sftpd.py
src/allmydata/test/test_sftp.py

index 14ff9d53e4c1d5924b077d37c84c81f293833841..8c2a38ad3210c8a7741c308d5a72327edea36acb 100644 (file)
@@ -697,33 +697,33 @@ class GeneralSFTPFile(PrefixingLogMixin):
         self.filenode = filenode
         self.metadata = metadata
 
-        if not self.closed:
-            tempfile_maker = EncryptedTemporaryFile
+        assert not self.closed
+        tempfile_maker = EncryptedTemporaryFile
 
-            if (self.flags & FXF_TRUNC) or not filenode:
-                # We're either truncating or creating the file, so we don't need the old contents.
-                self.consumer = OverwriteableFileConsumer(0, tempfile_maker)
-                self.consumer.finish()
+        if (self.flags & FXF_TRUNC) or not filenode:
+            # We're either truncating or creating the file, so we don't need the old contents.
+            self.consumer = OverwriteableFileConsumer(0, tempfile_maker)
+            self.consumer.finish()
+        else:
+            assert IFileNode.providedBy(filenode), filenode
+
+            # TODO: use download interface described in #993 when implemented.
+            if filenode.is_mutable():
+                self.async.addCallback(lambda ign: filenode.download_best_version())
+                def _downloaded(data):
+                    self.consumer = OverwriteableFileConsumer(len(data), tempfile_maker)
+                    self.consumer.write(data)
+                    self.consumer.finish()
+                    return None
+                self.async.addCallback(_downloaded)
             else:
-                assert IFileNode.providedBy(filenode), filenode
-
-                # TODO: use download interface described in #993 when implemented.
-                if filenode.is_mutable():
-                    self.async.addCallback(lambda ign: filenode.download_best_version())
-                    def _downloaded(data):
-                        self.consumer = OverwriteableFileConsumer(len(data), tempfile_maker)
-                        self.consumer.write(data)
-                        self.consumer.finish()
-                        return None
-                    self.async.addCallback(_downloaded)
-                else:
-                    download_size = filenode.get_size()
-                    assert download_size is not None, "download_size is None"
-                    self.consumer = OverwriteableFileConsumer(download_size, tempfile_maker)
-                    def _read(ign):
-                        if noisy: self.log("_read immutable", level=NOISY)
-                        filenode.read(self.consumer, 0, None)
-                    self.async.addCallback(_read)
+                download_size = filenode.get_size()
+                assert download_size is not None, "download_size is None"
+                self.consumer = OverwriteableFileConsumer(download_size, tempfile_maker)
+                def _read(ign):
+                    if noisy: self.log("_read immutable", level=NOISY)
+                    filenode.read(self.consumer, 0, None)
+                self.async.addCallback(_read)
 
         eventually_callback(self.async)(None)
 
@@ -1238,8 +1238,8 @@ class SFTPUserHandler(ConchUser, PrefixingLogMixin):
             d.addCallback(_got_file)
         return d
 
-    def openFile(self, pathstring, flags, attrs):
-        request = ".openFile(%r, %r = %r, %r)" % (pathstring, flags, _repr_flags(flags), attrs)
+    def openFile(self, pathstring, flags, attrs, delay=None):
+        request = ".openFile(%r, %r = %r, %r, delay=%r)" % (pathstring, flags, _repr_flags(flags), attrs, delay)
         self.log(request, level=OPERATIONAL)
 
         # This is used for both reading and writing.
@@ -1311,7 +1311,8 @@ class SFTPUserHandler(ConchUser, PrefixingLogMixin):
         # Note that the permission checks below are for more precise error reporting on
         # the open call; later operations would fail even if we did not make these checks.
 
-        d = self._get_root(path)
+        d = delay or defer.succeed(None)
+        d.addCallback(lambda ign: self._get_root(path))
         def _got_root( (root, path) ):
             if root.is_unknown():
                 raise SFTPError(FX_PERMISSION_DENIED,
index 2679c74a7073ec6c442fbda88ce9e89fee73c55f..06d38659c90c9d60a85ecf3521ae74f211432414 100644 (file)
@@ -3,7 +3,7 @@ import re, struct, traceback, gc, time, calendar
 from stat import S_IFREG, S_IFDIR
 
 from twisted.trial import unittest
-from twisted.internet import defer
+from twisted.internet import defer, reactor
 from twisted.python.failure import Failure
 from twisted.internet.error import ProcessDone, ProcessTerminated
 
@@ -871,6 +871,28 @@ class Handler(GridTestMixin, ShouldFailMixin, unittest.TestCase):
                       self.shouldFail(NoSuchChildError, "rename newexcl while open", "newexcl",
                                       self.root.get, u"newexcl"))
 
+        # it should be possible to rename even before the open has completed
+        def _open_and_rename_race(ign):
+            slow_open = defer.Deferred()
+            reactor.callLater(1, slow_open.callback, None)
+            d2 = self.handler.openFile("new", sftp.FXF_WRITE | sftp.FXF_CREAT, {}, delay=slow_open)
+
+            # deliberate race between openFile and renameFile
+            d3 = self.handler.renameFile("new", "new2")
+            return d2
+        d.addCallback(_open_and_rename_race)
+        def _write_rename_race(wf):
+            d2 = wf.writeChunk(0, "abcd")
+            d2.addCallback(lambda ign: wf.close())
+            return d2
+        d.addCallback(_write_rename_race)
+        d.addCallback(lambda ign: self.root.get(u"new2"))
+        d.addCallback(lambda node: download_to_data(node))
+        d.addCallback(lambda data: self.failUnlessReallyEqual(data, "abcd"))
+        d.addCallback(lambda ign:
+                      self.shouldFail(NoSuchChildError, "rename new while open", "new",
+                                      self.root.get, u"new"))
+
         return d
 
     def test_removeFile(self):