]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/commitdiff
SFTP: implement execCommand to interoperate with clients that issue a 'df -P -k ...
authordavid-sarah <david-sarah@jacaranda.org>
Sun, 16 May 2010 01:27:54 +0000 (18:27 -0700)
committerdavid-sarah <david-sarah@jacaranda.org>
Sun, 16 May 2010 01:27:54 +0000 (18:27 -0700)
src/allmydata/frontends/sftpd.py
src/allmydata/test/test_sftp.py

index 826f8192a8b17fa54630fcb733615477a1a1aaed..5407ebad272433ece64a62e97beded4b22610497 100644 (file)
@@ -16,6 +16,7 @@ from twisted.conch.interfaces import ISFTPServer, ISFTPFile, IConchUser, ISessio
 from twisted.conch.avatar import ConchUser
 from twisted.conch.openssh_compat import primes
 from twisted.cred import portal
+from twisted.internet.error import ProcessDone, ProcessTerminated
 
 from twisted.internet import defer
 from twisted.internet.interfaces import IFinishableConsumer
@@ -786,6 +787,7 @@ class GeneralSFTPFile(PrefixingLogMixin):
         self.async.addCallbacks(_resize, eventually_errback(d))
         return d
 
+
 class StoppableList:
     def __init__(self, items):
         self.items = items
@@ -796,17 +798,63 @@ class StoppableList:
         pass
 
 
-class SFTPHandler(PrefixingLogMixin):
-    implements(ISFTPServer)
-    def __init__(self, user):
+class Reason:
+    def __init__(self, value):
+        self.value = value
+
+
+class SFTPUserHandler(ConchUser, PrefixingLogMixin):
+    implements(ISFTPServer, ISession)
+    def __init__(self, check_abort, client, rootnode, username):
+        ConchUser.__init__(self)
         PrefixingLogMixin.__init__(self, facility="tahoe.sftp")
-        if noisy: self.log(".__init__(%r)" % (user,), level=NOISY)
+        if noisy: self.log(".__init__(%r, %r, %r, %r)" %
+                           (check_abort, client, rootnode, username), level=NOISY)
+
+        self.channelLookup["session"] = session.SSHSession
+        self.subsystemLookup["sftp"] = FileTransferServer
+
+        self.client = client
+        self.root = rootnode
+        self.username = username
+        self.convergence = client.convergence
+        self.logged_out = False
+
+    def logout(self):
+        self.logged_out = True
+
+    def check_abort(self):
+        return self.logged_out
+
+    # ISession
+    # This is needed because some clients may try to issue a 'df' command.
+
+    def getPty(self, terminal, windowSize, attrs):
+        self.log(".getPty(%r, %r, %r)" % (terminal, windowSize, attrs), level=OPERATIONAL)
 
-        self.check_abort = user.check_abort
-        self.client = user.client
-        self.root = user.root
-        self.username = user.username
-        self.convergence = user.convergence
+    def openShell(self, protocol):
+        self.log(".openShell(%r)" % (protocol,), level=OPERATIONAL)
+        raise NotImplementedError
+
+    def execCommand(self, protocol, cmd):
+        self.log(".execCommand(%r, %r)" % (protocol, cmd), level=OPERATIONAL)
+        if cmd == "df -P -k /":
+            protocol.write("Filesystem         1024-blocks      Used Available Capacity Mounted on\n"
+                           "tahoe                628318530 314159265 314159265      50% /\n")
+            protocol.processEnded(Reason(ProcessDone(None)))
+        else:
+            protocol.processEnded(Reason(ProcessTerminated(exitCode=1)))
+
+    def windowChanged(self, newWindowSize):
+        self.log(".windowChanged(%r)" % (newWindowSize,), level=OPERATIONAL)
+
+    def eofReceived(self):
+        self.log(".eofReceived()", level=OPERATIONAL)
+
+    def closed(self):
+        self.log(".closed()", level=OPERATIONAL)
+
+    # ISFTPServer
 
     def gotVersion(self, otherVersion, extData):
         self.log(".gotVersion(%r, %r)" % (otherVersion, extData), level=OPERATIONAL)
@@ -1262,14 +1310,9 @@ class Dispatcher:
     def requestAvatar(self, avatarID, mind, interface):
         assert interface == IConchUser
         rootnode = self.client.create_node_from_uri(avatarID.rootcap)
-        convergence = self.client.convergence
-        logged_out = {'flag': False}
-        def check_abort():
-            return logged_out['flag']
-        def logout():
-            logged_out['flag'] = True
-        s = SFTPUser(check_abort, self.client, rootnode, avatarID.username, convergence)
-        return (interface, s, logout)
+        handler = SFTPUserHandler(self.client, rootnode, avatarID.username)
+        return (interface, handler, handler.logout)
+
 
 class SFTPServer(service.MultiService):
     def __init__(self, client, accountfile, accounturl,
@@ -1287,7 +1330,7 @@ class SFTPServer(service.MultiService):
             p.registerChecker(c)
         if not accountfile and not accounturl:
             # we could leave this anonymous, with just the /uri/CAP form
-            raise NeedRootcapLookupScheme("must provide some translation")
+            raise NeedRootcapLookupScheme("must provide an account file or URL")
 
         pubkey = keys.Key.fromFile(pubkey_file)
         privkey = keys.Key.fromFile(privkey_file)
index 7e31fcce937fe524b883881c4e0dba7a053ddbcb..de0561a99c79c4c225cdc15c79d65be79ee19cfd 100644 (file)
@@ -5,6 +5,7 @@ from stat import S_IFREG, S_IFDIR
 from twisted.trial import unittest
 from twisted.internet import defer
 from twisted.python.failure import Failure
+from twisted.internet.error import ProcessDone, ProcessTerminated
 
 sftp = None
 sftpd = None
@@ -89,18 +90,15 @@ class Handler(GridTestMixin, ShouldFailMixin, unittest.TestCase):
         self.basedir = "sftp/" + basedir
         self.set_up_grid(num_clients=num_clients, num_servers=num_servers)
 
-        def check_abort():
-            pass
+        self.check_abort = lambda: False
         self.client = self.g.clients[0]
         self.username = "alice"
-        self.convergence = "convergence"
 
         d = self.client.create_dirnode()
         def _created_root(node):
             self.root = node
             self.root_uri = node.get_uri()
-            self.user = sftpd.SFTPUser(check_abort, self.client, self.root, self.username, self.convergence)
-            self.handler = sftpd.SFTPHandler(self.user)
+            self.handler = sftpd.SFTPUserHandler(self.check_abort, self.client, self.root, self.username)
         d.addCallback(_created_root)
         return d
 
@@ -915,3 +913,31 @@ class Handler(GridTestMixin, ShouldFailMixin, unittest.TestCase):
             self.shouldFailWithSFTPError(sftp.FX_PERMISSION_DENIED, "makeDirectory small",
                                          self.handler.makeDirectory, "small", {}))
         return d
+
+    def test_execCommand(self):
+        class FakeProtocol:
+            def __init__(self):
+                self.output = ""
+                self.reason = None
+            def write(self, data):
+                self.output += data
+            def processEnded(self, reason):
+                self.reason = reason
+
+        protocol_ok = FakeProtocol()
+        protocol_error = FakeProtocol()
+
+        d = self._set_up("execCommand")
+
+        d.addCallback(lambda ign: self.handler.execCommand(protocol_ok, "df -P -k /"))
+        d.addCallback(lambda ign: self.failUnlessIn("1024-blocks", protocol_ok.output))
+        d.addCallback(lambda ign: self.failUnless(isinstance(protocol_ok.reason.value, ProcessDone)))
+        d.addCallback(lambda ign: self.handler.eofReceived())
+        d.addCallback(lambda ign: self.handler.closed())
+
+        d.addCallback(lambda ign: self.handler.execCommand(protocol_error, "error"))
+        d.addCallback(lambda ign: self.failUnlessEqual(protocol_error.output, ""))
+        d.addCallback(lambda ign: self.failUnless(isinstance(protocol_error.reason.value, ProcessTerminated)))
+        d.addCallback(lambda ign: self.failUnlessEqual(protocol_error.reason.value.exitCode, 1))
+
+        return d
\ No newline at end of file