From 194ce4823943bed3e8074c7ec2f6533b337622d1 Mon Sep 17 00:00:00 2001
From: Brian Warner <warner@lothar.com>
Date: Sun, 3 Dec 2006 22:42:19 -0700
Subject: [PATCH] add download code to vdrive, add system-level test for vdrive
 functionality, refactor DownloadTargets

---
 allmydata/download.py         | 132 +++++++++++++++++++++++++++-------
 allmydata/test/test_system.py |  24 +++++++
 allmydata/test/test_vdrive.py |   1 -
 allmydata/vdrive.py           |  72 +++++++++++++++++--
 4 files changed, 196 insertions(+), 33 deletions(-)

diff --git a/allmydata/download.py b/allmydata/download.py
index ecdaae5d..58b8b4b2 100644
--- a/allmydata/download.py
+++ b/allmydata/download.py
@@ -1,4 +1,6 @@
 
+import os
+from zope.interface import Interface, implements
 from twisted.python import failure, log
 from twisted.internet import defer
 from twisted.application import service
@@ -6,8 +8,6 @@ from twisted.application import service
 from allmydata.util import idlib
 from allmydata import encode
 
-from cStringIO import StringIO
-
 class NotEnoughPeersError(Exception):
     pass
 
@@ -23,13 +23,18 @@ class FileDownloader:
         assert isinstance(verifierid, str)
         self._verifierid = verifierid
 
-    def set_filehandle(self, filehandle):
-        self._filehandle = filehandle
+    def set_download_target(self, target):
+        self._target = target
+        self._target.register_canceller(self._cancel)
+
+    def _cancel(self):
+        pass
 
     def make_decoder(self):
         n = self._shares = 4
         k = self._desired_shares = 2
-        self._decoder = encode.Decoder(self._filehandle, k, n,
+        self._target.open()
+        self._decoder = encode.Decoder(self._target, k, n,
                                        self._verifierid)
 
     def start(self):
@@ -103,43 +108,118 @@ class FileDownloader:
         for peerid, buckets in self.landlords:
             all_buckets.extend(buckets)
         d = self._decoder.start(all_buckets)
+        def _done(res):
+            self._target.close()
+            return self._target.finish()
+        def _fail(res):
+            self._target.fail()
+            return res
+        d.addCallbacks(_done, _fail)
         return d
 
 def netstring(s):
     return "%d:%s," % (len(s), s)
 
+class IDownloadTarget(Interface):
+    def open():
+        """Called before any calls to write() or close()."""
+    def write(data):
+        pass
+    def close():
+        pass
+    def fail():
+        """fail() is called to indicate that the download has failed. No
+        further methods will be invoked on the IDownloadTarget after fail()."""
+    def register_canceller(cb):
+        """The FileDownloader uses this to register a no-argument function
+        that the target can call to cancel the download. Once this canceller
+        is invoked, no further calls to write() or close() will be made."""
+    def finish(self):
+        """When the FileDownloader is done, this finish() function will be
+        called. Whatever it returns will be returned to the invoker of
+        Downloader.download.
+        """
+
+class FileName:
+    implements(IDownloadTarget)
+    def __init__(self, filename):
+        self._filename = filename
+    def open(self):
+        self.f = open(self._filename, "wb")
+        return self.f
+    def write(self, data):
+        self.f.write(data)
+    def close(self):
+        self.f.close()
+    def fail(self):
+        self.f.close()
+        os.unlink(self._filename)
+    def register_canceller(self, cb):
+        pass # we won't use it
+    def finish(self):
+        pass
+
+class Data:
+    implements(IDownloadTarget)
+    def __init__(self):
+        self._data = []
+    def open(self):
+        pass
+    def write(self, data):
+        self._data.append(data)
+    def close(self):
+        self.data = "".join(self._data)
+        del self._data
+    def fail(self):
+        del self._data
+    def register_canceller(self, cb):
+        pass # we won't use it
+    def finish(self):
+        return self.data
+
+class FileHandle:
+    implements(IDownloadTarget)
+    def __init__(self, filehandle):
+        self._filehandle = filehandle
+    def open(self):
+        pass
+    def write(self, data):
+        self._filehandle.write(data)
+    def close(self):
+        # the originator of the filehandle reserves the right to close it
+        pass
+    def fail(self):
+        pass
+    def register_canceller(self, cb):
+        pass
+    def finish(self):
+        pass
+
+
 class Downloader(service.MultiService):
     """I am a service that allows file downloading.
     """
     name = "downloader"
 
-    def download_to_filename(self, verifierid, filename):
-        f = open(filename, "wb")
-        def _done(res):
-            f.close()
-            return res
-        d = self.download_filehandle(verifierid, f)
-        d.addBoth(_done)
-        return d
-
-    def download_to_data(self, verifierid):
-        f = StringIO()
-        d = self.download_filehandle(verifierid, f)
-        def _done(res):
-            return f.getvalue()
-        d.addCallback(_done)
-        return d
-
-    def download_filehandle(self, verifierid, f):
+    def download(self, verifierid, t):
         assert self.parent
         assert self.running
         assert isinstance(verifierid, str)
-        assert f.write
-        assert f.close
+        t = IDownloadTarget(t)
+        assert t.write
+        assert t.close
         dl = FileDownloader(self.parent, verifierid)
-        dl.set_filehandle(f)
+        dl.set_download_target(t)
         dl.make_decoder()
         d = dl.start()
         return d
 
+    # utility functions
+    def download_to_data(self, verifierid):
+        return self.download(verifierid, Data())
+    def download_to_filename(self, verifierid, filename):
+        return self.download(verifierid, FileName(filename))
+    def download_to_filehandle(self, verifierid, filehandle):
+        return self.download(verifierid, FileHandle(filehandle))
+
 
diff --git a/allmydata/test/test_system.py b/allmydata/test/test_system.py
index 5f934fe6..80205a6d 100644
--- a/allmydata/test/test_system.py
+++ b/allmydata/test/test_system.py
@@ -83,3 +83,27 @@ class SystemTest(unittest.TestCase):
         return d
     test_upload_and_download.timeout = 20
 
+    def test_vdrive(self):
+        DATA = "Some data to publish to the virtual drive\n"
+        d = self.set_up_nodes()
+        def _do_publish(res):
+            log.msg("PUBLISHING")
+            v0 = self.clients[0].getServiceNamed("vdrive")
+            d1 = v0.make_directory("/", "subdir1")
+            d1.addCallback(lambda subdir1:
+                           v0.put_file_by_data(subdir1, "data", DATA))
+            return d1
+        d.addCallback(_do_publish)
+        def _publish_done(res):
+            log.msg("publish finished")
+            v1 = self.clients[1].getServiceNamed("vdrive")
+            d1 = v1.get_file_to_data("/subdir1/data")
+            return d1
+        d.addCallback(_publish_done)
+        def _get_done(data):
+            log.msg("get finished")
+            self.failUnlessEqual(data, DATA)
+        d.addCallback(_get_done)
+        return d
+    test_vdrive.timeout = 20
+
diff --git a/allmydata/test/test_vdrive.py b/allmydata/test/test_vdrive.py
index 20ae2102..3610a66b 100644
--- a/allmydata/test/test_vdrive.py
+++ b/allmydata/test/test_vdrive.py
@@ -66,4 +66,3 @@ class Traverse(unittest.TestCase):
                       self.failUnlessEqual(sorted(files),
                                            ["2.a", "2.b", "d2.1"]))
         return d
-
diff --git a/allmydata/vdrive.py b/allmydata/vdrive.py
index 06670a83..d61251bd 100644
--- a/allmydata/vdrive.py
+++ b/allmydata/vdrive.py
@@ -3,7 +3,7 @@
 
 from twisted.application import service
 from twisted.internet import defer
-from allmydata.upload import Data, FileHandle, FileName
+from allmydata import upload, download
 
 class VDrive(service.MultiService):
     name = "vdrive"
@@ -40,6 +40,20 @@ class VDrive(service.MultiService):
         d.addCallback(_check)
         return d
 
+    def get_verifierid_from_parent(self, parent, filename):
+        assert not isinstance(parent, str), "'%s' isn't a directory node" % (parent,)
+        d = parent.callRemote("list")
+        def _find(table):
+            for name,target in table:
+                if name == filename:
+                    assert isinstance(target, str), "Hey, %s isn't a file" % filename
+                    return target
+            else:
+                raise KeyError("no such file '%s' in '%s'" %
+                               (filename, [t[0] for t in table]))
+        d.addCallback(_find)
+        return d
+
     def get_root(self):
         return self.gvd_root
 
@@ -64,10 +78,10 @@ class VDrive(service.MultiService):
         I return a deferred that will fire when the operation is complete.
         """
 
-        u = self.parent.getServiceNamed("uploader")
+        ul = self.parent.getServiceNamed("uploader")
         d = self.dirpath(dir_or_path)
         def _got_dir(dirnode):
-            d1 = u.upload(uploadable)
+            d1 = ul.upload(uploadable)
             d1.addCallback(lambda vid:
                            dirnode.callRemote("add_file", name, vid))
             return d1
@@ -75,14 +89,60 @@ class VDrive(service.MultiService):
         return d
 
     def put_file_by_filename(self, dir_or_path, name, filename):
-        return self.put_file(dir_or_path, name, FileName(filename))
+        return self.put_file(dir_or_path, name, upload.FileName(filename))
     def put_file_by_data(self, dir_or_path, name, data):
-        return self.put_file(dir_or_path, name, Data(data))
+        return self.put_file(dir_or_path, name, upload.Data(data))
     def put_file_by_filehandle(self, dir_or_path, name, filehandle):
-        return self.put_file(dir_or_path, name, FileHandle(filehandle))
+        return self.put_file(dir_or_path, name, upload.FileHandle(filehandle))
 
     def make_directory(self, dir_or_path, name):
         d = self.dirpath(dir_or_path)
         d.addCallback(lambda parent: parent.callRemote("add_directory", name))
         return d
 
+
+    def get_file(self, dir_and_name_or_path, download_target):
+        """Retrieve a file from the virtual drive and put it somewhere.
+
+        The file to be retrieved may either be specified as a (dir, name)
+        tuple or as a full /-delimited pathname. In the former case, 'dir'
+        can be either a DirectoryNode or a pathname.
+
+        The download target must be an IDownloadTarget instance like
+        allmydata.download.Data, .FileName, or .FileHandle .
+        """
+
+        dl = self.parent.getServiceNamed("downloader")
+
+        if isinstance(dir_and_name_or_path, tuple):
+            dir_or_path, name = dir_and_name_or_path
+            d = self.dirpath(dir_or_path)
+            def _got_dir(dirnode):
+                return self.get_verifierid_from_parent(dirnode, name)
+            d.addCallback(_got_dir)
+        else:
+            rslash = dir_and_name_or_path.rfind("/")
+            if rslash == -1:
+                # we're looking for a file in the root directory
+                dir = self.gvd_root
+                name = dir_and_name_or_path
+                d = self.get_verifierid_from_parent(dir, name)
+            else:
+                dirpath = dir_and_name_or_path[:rslash]
+                name = dir_and_name_or_path[rslash+1:]
+                d = self.dirpath(dirpath)
+                d.addCallback(lambda dir:
+                              self.get_verifierid_from_parent(dir, name))
+
+        def _got_verifierid(verifierid):
+            return dl.download(verifierid, download_target)
+        d.addCallback(_got_verifierid)
+        return d
+
+    def get_file_to_filename(self, from_where, filename):
+        return self.get_file(from_where, download.FileName(filename))
+    def get_file_to_data(self, from_where):
+        return self.get_file(from_where, download.Data())
+    def get_file_to_filehandle(self, from_where, filehandle):
+        return self.get_file(from_where, download.FileHandle(filehandle))
+
-- 
2.45.2