]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/blobdiff - src/allmydata/mutable/retrieve.py
IServer refactoring: pass IServer instances around, instead of peerids
[tahoe-lafs/tahoe-lafs.git] / src / allmydata / mutable / retrieve.py
index 0e507704669a3367c9e259754b6249b3e5fedd39..0845d29086912838c66ae1e240b769875e75d849 100644 (file)
@@ -27,7 +27,7 @@ class RetrieveStatus:
         self.timings["decode"] = 0.0
         self.timings["decrypt"] = 0.0
         self.timings["cumulative_verify"] = 0.0
-        self.problems = {}
+        self._problems = {}
         self.active = True
         self.storage_index = None
         self.helper = False
@@ -56,11 +56,14 @@ class RetrieveStatus:
         return self.active
     def get_counter(self):
         return self.counter
-
-    def add_fetch_timing(self, peerid, elapsed):
-        if peerid not in self.timings["fetch_per_server"]:
-            self.timings["fetch_per_server"][peerid] = []
-        self.timings["fetch_per_server"][peerid].append(elapsed)
+    def get_problems(self):
+        return self._problems
+
+    def add_fetch_timing(self, server, elapsed):
+        serverid = server.get_serverid()
+        if serverid not in self.timings["fetch_per_server"]:
+            self.timings["fetch_per_server"][serverid] = []
+        self.timings["fetch_per_server"][serverid].append(elapsed)
     def accumulate_decode_time(self, elapsed):
         self.timings["decode"] += elapsed
     def accumulate_decrypt_time(self, elapsed):
@@ -79,6 +82,9 @@ class RetrieveStatus:
         self.progress = value
     def set_active(self, value):
         self.active = value
+    def add_problem(self, server, f):
+        serverid = server.get_serverid()
+        self._problems[serverid] = f
 
 class Marker:
     pass
@@ -91,16 +97,16 @@ class Retrieve:
     # will use a single ServerMap instance.
     implements(IPushProducer)
 
-    def __init__(self, filenode, servermap, verinfo, fetch_privkey=False,
-                 verify=False):
+    def __init__(self, filenode, storage_broker, servermap, verinfo,
+                 fetch_privkey=False, verify=False):
         self._node = filenode
         assert self._node.get_pubkey()
+        self._storage_broker = storage_broker
         self._storage_index = filenode.get_storage_index()
         assert self._node.get_readkey()
         self._last_failure = None
         prefix = si_b2a(self._storage_index)[:5]
         self._log_number = log.msg("Retrieve(%s): starting" % prefix)
-        self._outstanding_queries = {} # maps (peerid,shnum) to start_time
         self._running = True
         self._decoding = False
         self._bad_shares = set()
@@ -239,11 +245,11 @@ class Retrieve:
         self.log("starting download")
         self._started_fetching = time.time()
         # The download process beyond this is a state machine.
-        # _add_active_peers will select the peers that we want to use
+        # _add_active_servers will select the servers that we want to use
         # for the download, and then attempt to start downloading. After
         # each segment, it will check for doneness, reacting to broken
-        # peers and corrupt shares as necessary. If it runs out of good
-        # peers before downloading all of the segments, _done_deferred
+        # servers and corrupt shares as necessary. If it runs out of good
+        # servers before downloading all of the segments, _done_deferred
         # will errback.  Otherwise, it will eventually callback with the
         # contents of the mutable file.
         self.loop()
@@ -251,7 +257,7 @@ class Retrieve:
 
     def loop(self):
         d = fireEventually(None) # avoid #237 recursion limit problem
-        d.addCallback(lambda ign: self._activate_enough_peers())
+        d.addCallback(lambda ign: self._activate_enough_servers())
         d.addCallback(lambda ign: self._download_current_segment())
         # when we're done, _download_current_segment will call _done. If we
         # aren't, it will call loop() again.
@@ -277,20 +283,19 @@ class Retrieve:
         shares = versionmap[self.verinfo]
         # this sharemap is consumed as we decide to send requests
         self.remaining_sharemap = DictOfSets()
-        for (shnum, peerid, timestamp) in shares:
-            self.remaining_sharemap.add(shnum, peerid)
+        for (shnum, server, timestamp) in shares:
+            self.remaining_sharemap.add(shnum, server)
             # If the servermap update fetched anything, it fetched at least 1
             # KiB, so we ask for that much.
             # TODO: Change the cache methods to allow us to fetch all of the
             # data that they have, then change this method to do that.
             any_cache = self._node._read_from_cache(self.verinfo, shnum,
                                                     0, 1000)
-            ss = self.servermap.connections[peerid]
-            reader = MDMFSlotReadProxy(ss,
+            reader = MDMFSlotReadProxy(server.get_rref(),
                                        self._storage_index,
                                        shnum,
                                        any_cache)
-            reader.peerid = peerid
+            reader.server = server
             self.readers[shnum] = reader
         assert len(self.remaining_sharemap) >= k
 
@@ -436,7 +441,7 @@ class Retrieve:
 
         self._current_segment = self._start_segment
 
-    def _activate_enough_peers(self):
+    def _activate_enough_servers(self):
         """
         I populate self._active_readers with enough active readers to
         retrieve the contents of this mutable file. I am called before
@@ -445,9 +450,9 @@ class Retrieve:
         """
         # TODO: It would be cool to investigate other heuristics for
         # reader selection. For instance, the cost (in time the user
-        # spends waiting for their file) of selecting a really slow peer
+        # spends waiting for their file) of selecting a really slow server
         # that happens to have a primary share is probably more than
-        # selecting a really fast peer that doesn't have a primary
+        # selecting a really fast server that doesn't have a primary
         # share. Maybe the servermap could be extended to provide this
         # information; it could keep track of latency information while
         # it gathers more important data, and then this routine could
@@ -485,7 +490,7 @@ class Retrieve:
         else:
             new_shnums = []
 
-        self.log("adding %d new peers to the active list" % len(new_shnums))
+        self.log("adding %d new servers to the active list" % len(new_shnums))
         for shnum in new_shnums:
             reader = self.readers[shnum]
             self._active_readers.append(reader)
@@ -496,7 +501,7 @@ class Retrieve:
             # segment decoding, then we'll take more drastic measures.
             if self._need_privkey and not self._node.is_readonly():
                 d = reader.get_encprivkey()
-                d.addCallback(self._try_to_validate_privkey, reader)
+                d.addCallback(self._try_to_validate_privkey, reader, reader.server)
                 # XXX: don't just drop the Deferred. We need error-reporting
                 # but not flow-control here.
         assert len(self._active_readers) >= self._required_shares
@@ -528,7 +533,7 @@ class Retrieve:
 
     def _remove_reader(self, reader):
         """
-        At various points, we will wish to remove a peer from
+        At various points, we will wish to remove a server from
         consideration and/or use. These include, but are not necessarily
         limited to:
 
@@ -556,18 +561,18 @@ class Retrieve:
         self._active_readers.remove(reader)
         # TODO: self.readers.remove(reader)?
         for shnum in list(self.remaining_sharemap.keys()):
-            self.remaining_sharemap.discard(shnum, reader.peerid)
+            self.remaining_sharemap.discard(shnum, reader.server)
 
 
-    def _mark_bad_share(self, reader, f):
+    def _mark_bad_share(self, server, shnum, reader, f):
         """
-        I mark the (peerid, shnum) encapsulated by my reader argument as
-        a bad share, which means that it will not be used anywhere else.
+        I mark the given (server, shnum) as a bad share, which means that it
+        will not be used anywhere else.
 
         There are several reasons to want to mark something as a bad
         share. These include:
 
-            - A connection error to the peer.
+            - A connection error to the server.
             - A mismatched prefix (that is, a prefix that does not match
               our local conception of the version information string).
             - A failing block hash, salt hash, share hash, or other
@@ -576,21 +581,18 @@ class Retrieve:
         This method will ensure that readers that we wish to mark bad
         (for these reasons or other reasons) are not used for the rest
         of the download. Additionally, it will attempt to tell the
-        remote peer (with no guarantee of success) that its share is
+        remote server (with no guarantee of success) that its share is
         corrupt.
         """
         self.log("marking share %d on server %s as bad" % \
-                 (reader.shnum, reader))
+                 (shnum, server.get_name()))
         prefix = self.verinfo[-2]
-        self.servermap.mark_bad_share(reader.peerid,
-                                      reader.shnum,
-                                      prefix)
+        self.servermap.mark_bad_share(server, shnum, prefix)
         self._remove_reader(reader)
-        self._bad_shares.add((reader.peerid, reader.shnum, f))
-        self._status.problems[reader.peerid] = f
+        self._bad_shares.add((server, shnum, f))
+        self._status.add_problem(server, f)
         self._last_failure = f
-        self.notify_server_corruption(reader.peerid, reader.shnum,
-                                      str(f.value))
+        self.notify_server_corruption(server, shnum, str(f.value))
 
 
     def _download_current_segment(self):
@@ -632,7 +634,7 @@ class Retrieve:
             d = reader.get_block_and_salt(segnum)
             d2 = self._get_needed_hashes(reader, segnum)
             dl = defer.DeferredList([d, d2], consumeErrors=True)
-            dl.addCallback(self._validate_block, segnum, reader, started)
+            dl.addCallback(self._validate_block, segnum, reader, reader.server, started)
             dl.addErrback(self._validation_or_decoding_failed, [reader])
             ds.append(dl)
         dl = defer.DeferredList(ds)
@@ -722,20 +724,20 @@ class Retrieve:
         I am called when a block or a salt fails to correctly validate, or when
         the decryption or decoding operation fails for some reason.  I react to
         this failure by notifying the remote server of corruption, and then
-        removing the remote peer from further activity.
+        removing the remote server from further activity.
         """
         assert isinstance(readers, list)
         bad_shnums = [reader.shnum for reader in readers]
 
-        self.log("validation or decoding failed on share(s) %s, peer(s) %s "
+        self.log("validation or decoding failed on share(s) %s, server(s) %s "
                  ", segment %d: %s" % \
                  (bad_shnums, readers, self._current_segment, str(f)))
         for reader in readers:
-            self._mark_bad_share(reader, f)
+            self._mark_bad_share(reader.server, reader.shnum, reader, f)
         return
 
 
-    def _validate_block(self, results, segnum, reader, started):
+    def _validate_block(self, results, segnum, reader, server, started):
         """
         I validate a block from one share on a remote server.
         """
@@ -744,7 +746,7 @@ class Retrieve:
         self.log("validating share %d for segment %d" % (reader.shnum,
                                                              segnum))
         elapsed = time.time() - started
-        self._status.add_fetch_timing(reader.peerid, elapsed)
+        self._status.add_fetch_timing(server, elapsed)
         self._set_current_status("validating blocks")
         # Did we fail to fetch either of the things that we were
         # supposed to? Fail if so.
@@ -757,7 +759,7 @@ class Retrieve:
             assert isinstance(results[0][1], failure.Failure)
 
             f = results[0][1]
-            raise CorruptShareError(reader.peerid,
+            raise CorruptShareError(server,
                                     reader.shnum,
                                     "Connection error: %s" % str(f))
 
@@ -777,7 +779,7 @@ class Retrieve:
                 bht.set_hashes(blockhashes)
             except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
                     IndexError), e:
-                raise CorruptShareError(reader.peerid,
+                raise CorruptShareError(server,
                                         reader.shnum,
                                         "block hash tree failure: %s" % e)
 
@@ -791,7 +793,7 @@ class Retrieve:
            bht.set_hashes(leaves={segnum: blockhash})
         except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
                 IndexError), e:
-            raise CorruptShareError(reader.peerid,
+            raise CorruptShareError(server,
                                     reader.shnum,
                                     "block hash tree failure: %s" % e)
 
@@ -812,7 +814,7 @@ class Retrieve:
                                             leaves={reader.shnum: bht[0]})
             except (hashtree.BadHashError, hashtree.NotEnoughHashesError, \
                     IndexError), e:
-                raise CorruptShareError(reader.peerid,
+                raise CorruptShareError(server,
                                         reader.shnum,
                                         "corrupt hashes: %s" % e)
 
@@ -931,13 +933,13 @@ class Retrieve:
         return plaintext
 
 
-    def notify_server_corruption(self, peerid, shnum, reason):
-        ss = self.servermap.connections[peerid]
-        ss.callRemoteOnly("advise_corrupt_share",
-                          "mutable", self._storage_index, shnum, reason)
+    def notify_server_corruption(self, server, shnum, reason):
+        rref = server.get_rref()
+        rref.callRemoteOnly("advise_corrupt_share",
+                            "mutable", self._storage_index, shnum, reason)
 
 
-    def _try_to_validate_privkey(self, enc_privkey, reader):
+    def _try_to_validate_privkey(self, enc_privkey, reader, server):
         alleged_privkey_s = self._node._decrypt_privkey(enc_privkey)
         alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s)
         if alleged_writekey != self._node.get_writekey():
@@ -945,13 +947,13 @@ class Retrieve:
                      (reader, reader.shnum),
                      level=log.WEIRD, umid="YIw4tA")
             if self._verify:
-                self.servermap.mark_bad_share(reader.peerid, reader.shnum,
+                self.servermap.mark_bad_share(server, reader.shnum,
                                               self.verinfo[-2])
-                e = CorruptShareError(reader.peerid,
+                e = CorruptShareError(server,
                                       reader.shnum,
                                       "invalid privkey")
                 f = failure.Failure(e)
-                self._bad_shares.add((reader.peerid, reader.shnum, f))
+                self._bad_shares.add((server, reader.shnum, f))
             return
 
         # it's good
@@ -986,7 +988,7 @@ class Retrieve:
         self._node._populate_total_shares(N)
 
         if self._verify:
-            ret = list(self._bad_shares)
+            ret = self._bad_shares
             self.log("done verifying, found %d bad shares" % len(ret))
         else:
             # TODO: upload status here?
@@ -997,14 +999,14 @@ class Retrieve:
 
     def _raise_notenoughshareserror(self):
         """
-        I am called by _activate_enough_peers when there are not enough
-        active peers left to complete the download. After making some
+        I am called by _activate_enough_servers when there are not enough
+        active servers left to complete the download. After making some
         useful logging statements, I throw an exception to that effect
         to the caller of this Retrieve object through
         self._done_deferred.
         """
 
-        format = ("ran out of peers: "
+        format = ("ran out of servers: "
                   "have %(have)d of %(total)d segments "
                   "found %(bad)d bad shares "
                   "encoding %(k)d-of-%(n)d")