]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/commitdiff
download status: refactor into a separate object, so we don't need to keep the Downlo...
authorBrian Warner <warner@allmydata.com>
Wed, 13 Feb 2008 02:01:03 +0000 (19:01 -0700)
committerBrian Warner <warner@allmydata.com>
Wed, 13 Feb 2008 02:01:03 +0000 (19:01 -0700)
src/allmydata/download.py

index 5229e6a60cdbb62ed59f8141e5c9215b44f7830c..aa54f0f1f0cddae1fb037086bd668c89084516b1 100644 (file)
@@ -29,7 +29,8 @@ class DownloadStopped(Exception):
     pass
 
 class Output:
-    def __init__(self, downloadable, key, total_length, log_parent):
+    def __init__(self, downloadable, key, total_length, log_parent,
+                 download_status):
         self.downloadable = downloadable
         self._decryptor = AES(key)
         self._crypttext_hasher = hashutil.crypttext_hasher()
@@ -41,9 +42,8 @@ class Output:
         self._crypttext_hash_tree = None
         self._opened = False
         self._log_parent = log_parent
-
-    def get_progress(self):
-        return float(self.length) / self.total_length
+        self._status = download_status
+        self._status.set_progress(0.0)
 
     def log(self, *args, **kwargs):
         if "parent" not in kwargs:
@@ -58,6 +58,7 @@ class Output:
 
     def write_segment(self, crypttext):
         self.length += len(crypttext)
+        self._status.set_progress( float(self.length) / self.total_length )
 
         # memory footprint: 'crypttext' is the only segment_size usage
         # outstanding. While we decrypt it into 'plaintext', we hit
@@ -324,10 +325,55 @@ class SegmentDownloader:
     def bucket_failed(self, vbucket):
         self.parent.bucket_failed(vbucket)
 
+class DownloadStatus:
+    implements(IDownloadStatus)
+
+    def __init__(self):
+        self.storage_index = None
+        self.size = None
+        self.helper = False
+        self.status = "Not started"
+        self.progress = 0.0
+        self.paused = False
+        self.stopped = False
+
+    def get_storage_index(self):
+        return self.storage_index
+    def get_size(self):
+        return self.size
+    def using_helper(self):
+        return self.helper
+    def get_status(self):
+        status = self.status
+        if self.paused:
+            status += " (output paused)"
+        if self.stopped:
+            status += " (output stopped)"
+        return status
+    def get_progress(self):
+        return self.progress
+
+    def set_storage_index(self, si):
+        self.storage_index = si
+    def set_size(self, size):
+        self.size = size
+    def set_helper(self, helper):
+        self.helper = helper
+    def set_status(self, status):
+        self.status = status
+    def set_paused(self, paused):
+        self.paused = paused
+    def set_stopped(self, stopped):
+        self.stopped = stopped
+    def set_progress(self, value):
+        self.progress = value
+
+
 class FileDownloader:
-    implements(IPushProducer, IDownloadStatus)
+    implements(IPushProducer)
     check_crypttext_hash = True
     check_plaintext_hash = True
+    _status = None
 
     def __init__(self, client, u, downloadable):
         self._client = client
@@ -341,12 +387,14 @@ class FileDownloader:
 
         self.init_logging()
 
-        self._status = "Starting"
+        self._status = s = DownloadStatus()
+        s.set_status("Starting")
 
         if IConsumer.providedBy(downloadable):
             downloadable.registerProducer(self, True)
         self._downloadable = downloadable
-        self._output = Output(downloadable, u.key, self._size, self._log_number)
+        self._output = Output(downloadable, u.key, self._size, self._log_number,
+                              self._status)
         self._paused = False
         self._stopped = False
 
@@ -381,16 +429,22 @@ class FileDownloader:
         if self._paused:
             return
         self._paused = defer.Deferred()
+        if self._status:
+            self._status.set_paused(True)
 
     def resumeProducing(self):
         if self._paused:
             p = self._paused
             self._paused = None
             eventually(p.callback, None)
+            if self._status:
+                self._status.set_paused(False)
 
     def stopProducing(self):
         self.log("Download.stopProducing")
         self._stopped = True
+        if self._status:
+            self._status.set_stopped(True)
 
     def start(self):
         self.log("starting download")
@@ -406,13 +460,15 @@ class FileDownloader:
         # once we know that, we can download blocks from everybody
         d.addCallback(self._download_all_segments)
         def _finished(res):
-            self._status = "Finished"
+            if self._status:
+                self._status.set_status("Finished")
             if IConsumer.providedBy(self._downloadable):
                 self._downloadable.unregisterProducer()
             return res
         d.addBoth(_finished)
         def _failed(why):
-            self._status = "Failed"
+            if self._status:
+                self._status.set_status("Failed")
             self._output.fail(why)
             return why
         d.addErrback(_failed)
@@ -428,14 +484,18 @@ class FileDownloader:
             dl.append(d)
         self._responses_received = 0
         self._queries_sent = len(dl)
-        self._status = "Locating Shares (%d/%d)" % (self._responses_received,
-                                                    self._queries_sent)
+        if self._status:
+            self._status.set_status("Locating Shares (%d/%d)" %
+                                    (self._responses_received,
+                                     self._queries_sent))
         return defer.DeferredList(dl)
 
     def _got_response(self, buckets):
         self._responses_received += 1
-        self._status = "Locating Shares (%d/%d)" % (self._responses_received,
-                                                    self._queries_sent)
+        if self._status:
+            self._status.set_status("Locating Shares (%d/%d)" %
+                                    (self._responses_received,
+                                     self._queries_sent))
         for sharenum, bucket in buckets.iteritems():
             b = storage.ReadBucketProxy(bucket)
             self.add_share_bucket(sharenum, b)
@@ -477,7 +537,8 @@ class FileDownloader:
         # all are supposed to be identical. We compute the hash of the data
         # that comes back, and compare it against the version in our URI. If
         # they don't match, ignore their data and try someone else.
-        self._status = "Obtaining URI Extension"
+        if self._status:
+            self._status.set_status("Obtaining URI Extension")
 
         def _validate(proposal, bucket):
             h = hashutil.uri_extension_hash(proposal)
@@ -537,7 +598,8 @@ class FileDownloader:
         self._share_hashtree.set_hashes({0: self._roothash})
 
     def _get_hashtrees(self, res):
-        self._status = "Retrieving Hash Trees"
+        if self._status:
+            self._status.set_status("Retrieving Hash Trees")
         d = self._get_plaintext_hashtrees()
         d.addCallback(self._get_crypttext_hashtrees)
         d.addCallback(self._setup_hashtrees)
@@ -653,8 +715,9 @@ class FileDownloader:
         return res
 
     def _download_segment(self, res, segnum):
-        self._status = "Downloading segment %d of %d" % (segnum,
-                                                         self._total_segments)
+        if self._status:
+            self._status.set_status("Downloading segment %d of %d" %
+                                    (segnum, self._total_segments))
         self.log("downloading seg#%d of %d (%d%%)"
                  % (segnum, self._total_segments,
                     100.0 * segnum / self._total_segments))
@@ -730,21 +793,9 @@ class FileDownloader:
                 got=self._output.length, expected=self._size)
         return self._output.finish()
 
-    def get_storage_index(self):
-        return self._storage_index
-    def get_size(self):
-        return self._size
-    def using_helper(self):
-        return False
-    def get_status(self):
-        status = self._status
-        if self._paused:
-            status += " (output paused)"
-        if self._stopped:
-            status += " (output stopped)"
-        return status
-    def get_progress(self):
-        return self._output.get_progress()
+    def get_download_status(self):
+        return self._status
+
 
 class LiteralDownloader:
     implements(IDownloadStatus)
@@ -753,24 +804,22 @@ class LiteralDownloader:
         self._uri = IFileURI(u)
         assert isinstance(self._uri, uri.LiteralFileURI)
         self._downloadable = downloadable
+        self._status = s = DownloadStatus()
+        s.set_storage_index(None)
+        s.set_helper(False)
+        s.set_status("Done")
+        s.set_progress(1.0)
 
     def start(self):
         data = self._uri.data
+        self._status.set_size(len(data))
         self._downloadable.open(len(data))
         self._downloadable.write(data)
         self._downloadable.close()
         return defer.maybeDeferred(self._downloadable.finish)
 
-    def get_storage_index(self):
-        return None
-    def get_size(self):
-        return len(self._uri.data)
-    def using_helper(self):
-        return False
-    def get_status(self):
-        return "Done"
-    def get_progress(self):
-        return 1.0
+    def get_download_status(self):
+        return self._status
 
 class FileName:
     implements(IDownloadTarget)
@@ -858,7 +907,7 @@ class Downloader(service.MultiService):
             dl = FileDownloader(self.parent, u, t)
         else:
             raise RuntimeError("I don't know how to download a %s" % u)
-        self._all_downloads[dl] = None
+        self._all_downloads[dl.get_download_status()] = None
         d = dl.start()
         return d