From 2f63d9b522f6d95c8006cd8f9731754365b5ad81 Mon Sep 17 00:00:00 2001
From: Brian Warner <warner@lothar.com>
Date: Sat, 26 Feb 2011 19:11:38 -0700
Subject: [PATCH] immutable/upload.py: reduce use of get_serverid()

---
 src/allmydata/immutable/upload.py   | 80 ++++++++++++++---------------
 src/allmydata/test/test_upload.py   |  6 ++-
 src/allmydata/util/happinessutil.py |  2 +-
 3 files changed, 45 insertions(+), 43 deletions(-)

diff --git a/src/allmydata/immutable/upload.py b/src/allmydata/immutable/upload.py
index a41fc3b8..aee3e240 100644
--- a/src/allmydata/immutable/upload.py
+++ b/src/allmydata/immutable/upload.py
@@ -69,21 +69,18 @@ def pretty_print_shnum_to_servers(s):
     return ', '.join([ "sh%s: %s" % (k, '+'.join([idlib.shortnodeid_b2a(x) for x in v])) for k, v in s.iteritems() ])
 
 class ServerTracker:
-    def __init__(self, serverid, storage_server,
+    def __init__(self, server,
                  sharesize, blocksize, num_segments, num_share_hashes,
                  storage_index,
                  bucket_renewal_secret, bucket_cancel_secret):
-        precondition(isinstance(serverid, str), serverid)
-        precondition(len(serverid) == 20, serverid)
-        self.serverid = serverid
-        self._storageserver = storage_server # to an RIStorageServer
+        self._server = server
         self.buckets = {} # k: shareid, v: IRemoteBucketWriter
         self.sharesize = sharesize
 
         wbp = layout.make_write_bucket_proxy(None, sharesize,
                                              blocksize, num_segments,
                                              num_share_hashes,
-                                             EXTENSION_SIZE, serverid)
+                                             EXTENSION_SIZE, server.get_serverid())
         self.wbp_class = wbp.__class__ # to create more of them
         self.allocated_size = wbp.get_allocated_size()
         self.blocksize = blocksize
@@ -96,23 +93,28 @@ class ServerTracker:
 
     def __repr__(self):
         return ("<ServerTracker for server %s and SI %s>"
-                % (idlib.shortnodeid_b2a(self.serverid),
-                   si_b2a(self.storage_index)[:5]))
+                % (self._server.name(), si_b2a(self.storage_index)[:5]))
+
+    def get_serverid(self):
+        return self._server.get_serverid()
+    def name(self):
+        return self._server.name()
 
     def query(self, sharenums):
-        d = self._storageserver.callRemote("allocate_buckets",
-                                           self.storage_index,
-                                           self.renew_secret,
-                                           self.cancel_secret,
-                                           sharenums,
-                                           self.allocated_size,
-                                           canary=Referenceable())
+        rref = self._server.get_rref()
+        d = rref.callRemote("allocate_buckets",
+                            self.storage_index,
+                            self.renew_secret,
+                            self.cancel_secret,
+                            sharenums,
+                            self.allocated_size,
+                            canary=Referenceable())
         d.addCallback(self._got_reply)
         return d
 
     def ask_about_existing_shares(self):
-        return self._storageserver.callRemote("get_buckets",
-                                              self.storage_index)
+        rref = self._server.get_rref()
+        return rref.callRemote("get_buckets", self.storage_index)
 
     def _got_reply(self, (alreadygot, buckets)):
         #log.msg("%s._got_reply(%s)" % (self, (alreadygot, buckets)))
@@ -123,7 +125,7 @@ class ServerTracker:
                                 self.num_segments,
                                 self.num_share_hashes,
                                 EXTENSION_SIZE,
-                                self.serverid)
+                                self._server.get_serverid())
             b[sharenum] = bp
         self.buckets.update(b)
         return (alreadygot, set(b.keys()))
@@ -209,8 +211,7 @@ class Tahoe2ServerSelector(log.PrefixingLogMixin):
                                              num_share_hashes, EXTENSION_SIZE,
                                              None)
         allocated_size = wbp.get_allocated_size()
-        all_servers = [(s.get_serverid(), s.get_rref())
-                       for s in storage_broker.get_servers_for_psi(storage_index)]
+        all_servers = storage_broker.get_servers_for_psi(storage_index)
         if not all_servers:
             raise NoServersError("client gave us zero servers")
 
@@ -219,8 +220,8 @@ class Tahoe2ServerSelector(log.PrefixingLogMixin):
         # field) from getting large shares (for files larger than about
         # 12GiB). See #439 for details.
         def _get_maxsize(server):
-            (serverid, conn) = server
-            v1 = conn.version["http://allmydata.org/tahoe/protocols/storage/v1"]
+            v0 = server.get_rref().version
+            v1 = v0["http://allmydata.org/tahoe/protocols/storage/v1"]
             return v1["maximum-immutable-share-size"]
         writable_servers = [server for server in all_servers
                             if _get_maxsize(server) >= allocated_size]
@@ -237,11 +238,11 @@ class Tahoe2ServerSelector(log.PrefixingLogMixin):
                                                      storage_index)
         def _make_trackers(servers):
             trackers = []
-            for (serverid, conn) in servers:
-                seed = serverid
+            for s in servers:
+                seed = s.get_lease_seed()
                 renew = bucket_renewal_secret_hash(file_renewal_secret, seed)
                 cancel = bucket_cancel_secret_hash(file_cancel_secret, seed)
-                st = ServerTracker(serverid, conn,
+                st = ServerTracker(s,
                                    share_size, block_size,
                                    num_segments, num_share_hashes,
                                    storage_index,
@@ -268,27 +269,26 @@ class Tahoe2ServerSelector(log.PrefixingLogMixin):
         for tracker in readonly_trackers:
             assert isinstance(tracker, ServerTracker)
             d = tracker.ask_about_existing_shares()
-            d.addBoth(self._handle_existing_response, tracker.serverid)
+            d.addBoth(self._handle_existing_response, tracker)
             ds.append(d)
             self.num_servers_contacted += 1
             self.query_count += 1
             self.log("asking server %s for any existing shares" %
-                     (idlib.shortnodeid_b2a(tracker.serverid),),
-                    level=log.NOISY)
+                     (tracker.name(),), level=log.NOISY)
         dl = defer.DeferredList(ds)
         dl.addCallback(lambda ign: self._loop())
         return dl
 
 
-    def _handle_existing_response(self, res, serverid):
+    def _handle_existing_response(self, res, tracker):
         """
         I handle responses to the queries sent by
         Tahoe2ServerSelector._existing_shares.
         """
+        serverid = tracker.get_serverid()
         if isinstance(res, failure.Failure):
             self.log("%s got error during existing shares check: %s"
-                    % (idlib.shortnodeid_b2a(serverid), res),
-                    level=log.UNUSUAL)
+                    % (tracker.name(), res), level=log.UNUSUAL)
             self.error_count += 1
             self.bad_query_count += 1
         else:
@@ -296,7 +296,7 @@ class Tahoe2ServerSelector(log.PrefixingLogMixin):
             if buckets:
                 self.serverids_with_shares.add(serverid)
             self.log("response to get_buckets() from server %s: alreadygot=%s"
-                    % (idlib.shortnodeid_b2a(serverid), tuple(sorted(buckets))),
+                    % (tracker.name(), tuple(sorted(buckets))),
                     level=log.NOISY)
             for bucket in buckets:
                 self.preexisting_shares.setdefault(bucket, set()).add(serverid)
@@ -404,7 +404,7 @@ class Tahoe2ServerSelector(log.PrefixingLogMixin):
             if self._status:
                 self._status.set_status("Contacting Servers [%s] (first query),"
                                         " %d shares left.."
-                                        % (idlib.shortnodeid_b2a(tracker.serverid),
+                                        % (tracker.name(),
                                            len(self.homeless_shares)))
             d = tracker.query(shares_to_ask)
             d.addBoth(self._got_response, tracker, shares_to_ask,
@@ -425,7 +425,7 @@ class Tahoe2ServerSelector(log.PrefixingLogMixin):
             if self._status:
                 self._status.set_status("Contacting Servers [%s] (second query),"
                                         " %d shares left.."
-                                        % (idlib.shortnodeid_b2a(tracker.serverid),
+                                        % (tracker.name(),
                                            len(self.homeless_shares)))
             d = tracker.query(shares_to_ask)
             d.addBoth(self._got_response, tracker, shares_to_ask,
@@ -486,12 +486,12 @@ class Tahoe2ServerSelector(log.PrefixingLogMixin):
         else:
             (alreadygot, allocated) = res
             self.log("response to allocate_buckets() from server %s: alreadygot=%s, allocated=%s"
-                    % (idlib.shortnodeid_b2a(tracker.serverid),
+                    % (tracker.name(),
                        tuple(sorted(alreadygot)), tuple(sorted(allocated))),
                     level=log.NOISY)
             progress = False
             for s in alreadygot:
-                self.preexisting_shares.setdefault(s, set()).add(tracker.serverid)
+                self.preexisting_shares.setdefault(s, set()).add(tracker.get_serverid())
                 if s in self.homeless_shares:
                     self.homeless_shares.remove(s)
                     progress = True
@@ -505,7 +505,7 @@ class Tahoe2ServerSelector(log.PrefixingLogMixin):
                 progress = True
 
             if allocated or alreadygot:
-                self.serverids_with_shares.add(tracker.serverid)
+                self.serverids_with_shares.add(tracker.get_serverid())
 
             not_yet_present = set(shares_to_ask) - set(alreadygot)
             still_homeless = not_yet_present - set(allocated)
@@ -948,14 +948,14 @@ class CHKUploader:
             buckets.update(tracker.buckets)
             for shnum in tracker.buckets:
                 self._server_trackers[shnum] = tracker
-                servermap.setdefault(shnum, set()).add(tracker.serverid)
+                servermap.setdefault(shnum, set()).add(tracker.get_serverid())
         assert len(buckets) == sum([len(tracker.buckets)
                                     for tracker in upload_trackers]), \
             "%s (%s) != %s (%s)" % (
                 len(buckets),
                 buckets,
                 sum([len(tracker.buckets) for tracker in upload_trackers]),
-                [(t.buckets, t.serverid) for t in upload_trackers]
+                [(t.buckets, t.get_serverid()) for t in upload_trackers]
                 )
         encoder.set_shareholders(buckets, servermap)
 
@@ -964,7 +964,7 @@ class CHKUploader:
         r = self._results
         for shnum in self._encoder.get_shares_placed():
             server_tracker = self._server_trackers[shnum]
-            serverid = server_tracker.serverid
+            serverid = server_tracker.get_serverid()
             r.sharemap.add(shnum, serverid)
             r.servermap.add(serverid, shnum)
         r.pushed_shares = len(self._encoder.get_shares_placed())
diff --git a/src/allmydata/test/test_upload.py b/src/allmydata/test/test_upload.py
index 17c110c8..f1fefd4e 100644
--- a/src/allmydata/test/test_upload.py
+++ b/src/allmydata/test/test_upload.py
@@ -727,8 +727,10 @@ def is_happy_enough(servertoshnums, h, k):
 
 class FakeServerTracker:
     def __init__(self, serverid, buckets):
-        self.serverid = serverid
+        self._serverid = serverid
         self.buckets = buckets
+    def get_serverid(self):
+        return self._serverid
 
 class EncodingParameters(GridTestMixin, unittest.TestCase, SetDEPMixin,
     ShouldFailMixin):
@@ -789,7 +791,7 @@ class EncodingParameters(GridTestMixin, unittest.TestCase, SetDEPMixin,
             for tracker in upload_trackers:
                 buckets.update(tracker.buckets)
                 for bucket in tracker.buckets:
-                    servermap.setdefault(bucket, set()).add(tracker.serverid)
+                    servermap.setdefault(bucket, set()).add(tracker.get_serverid())
             encoder.set_shareholders(buckets, servermap)
             d = encoder.start()
             return d
diff --git a/src/allmydata/util/happinessutil.py b/src/allmydata/util/happinessutil.py
index 8c7d391c..33ba5673 100644
--- a/src/allmydata/util/happinessutil.py
+++ b/src/allmydata/util/happinessutil.py
@@ -74,7 +74,7 @@ def merge_servers(servermap, upload_trackers=None):
 
     for tracker in upload_trackers:
         for shnum in tracker.buckets:
-            servermap.setdefault(shnum, set()).add(tracker.serverid)
+            servermap.setdefault(shnum, set()).add(tracker.get_serverid())
     return servermap
 
 def servers_of_happiness(sharemap):
-- 
2.45.2