From 8ae9169e5bc03018a0edd160afdc761706a2c912 Mon Sep 17 00:00:00 2001
From: Daira Hopwood <david-sarah@jacaranda.org>
Date: Fri, 26 Apr 2013 00:30:04 +0100
Subject: [PATCH] Lock remote operations on sharesets. fixes #1869

Signed-off-by: Daira Hopwood <david-sarah@jacaranda.org>
---
 src/allmydata/interfaces.py                   | 14 ++++---
 src/allmydata/storage/backends/base.py        | 42 ++++++++++++++++---
 .../storage/backends/cloud/cloud_backend.py   | 15 +++----
 .../storage/backends/disk/disk_backend.py     | 12 +++---
 .../storage/backends/null/null_backend.py     | 12 +++---
 5 files changed, 66 insertions(+), 29 deletions(-)

diff --git a/src/allmydata/interfaces.py b/src/allmydata/interfaces.py
index 72a27dd9..0c2d5e22 100644
--- a/src/allmydata/interfaces.py
+++ b/src/allmydata/interfaces.py
@@ -328,11 +328,15 @@ class IStorageBackend(Interface):
 
     def get_sharesets_for_prefix(prefix):
         """
-        Return an iterable containing IShareSet objects for all storage
-        indices matching the given base-32 prefix, for which this backend
-        holds shares.
-        XXX This will probably need to return a Deferred, but for now it
-        is synchronous.
+        Return a Deferred that fires with an iterable of IShareSet objects
+        for all storage indices matching the given base-32 prefix, for
+        which this backend holds shares.
+
+        A caller will typically perform operations that take locks on some
+        of the sharesets returned by this method. Nothing prevents sharesets
+        matching the prefix from being deleted or added between listing the
+        sharesets and taking any such locks; callers must be able to tolerate
+        this.
         """
 
     def get_shareset(storageindex):
diff --git a/src/allmydata/storage/backends/base.py b/src/allmydata/storage/backends/base.py
index b56266c8..570da2ce 100644
--- a/src/allmydata/storage/backends/base.py
+++ b/src/allmydata/storage/backends/base.py
@@ -1,4 +1,6 @@
 
+from weakref import WeakValueDictionary
+
 from twisted.application import service
 from twisted.internet import defer
 
@@ -11,6 +13,18 @@ from allmydata.storage.leasedb import SHARETYPE_MUTABLE
 class Backend(service.MultiService):
     def __init__(self):
         service.MultiService.__init__(self)
+        self._lock_table = WeakValueDictionary()
+
+    def _get_lock(self, storage_index):
+        # Getting a shareset ensures that a lock exists for that storage_index.
+        # The _lock_table won't let go of an entry while the ShareSet (or any
+        # other object that references the lock) is live, or while the lock is held.
+
+        lock = self._lock_table.get(storage_index, None)
+        if lock is None:
+            lock = defer.DeferredLock()
+            self._lock_table[storage_index] = lock
+        return lock
 
     def must_use_tubid_as_permutation_seed(self):
         # New backends cannot have been around before #466, and so have no backward
@@ -23,11 +37,10 @@ class ShareSet(object):
     This class implements shareset logic that could work for all backends, but
     might be useful to override for efficiency.
     """
-    # TODO: queue operations on a shareset to ensure atomicity for each fully
-    # successful operation (#1869).
 
-    def __init__(self, storage_index):
+    def __init__(self, storage_index, lock):
         self.storage_index = storage_index
+        self.lock = lock
 
     def get_storage_index(self):
         return self.storage_index
@@ -38,9 +51,25 @@ class ShareSet(object):
     def make_bucket_reader(self, account, share):
         return BucketReader(account, share)
 
+    def get_shares(self):
+        return self.lock.run(self._locked_get_shares)
+
+    def get_share(self, shnum):
+        return self.lock.run(self._locked_get_share, shnum)
+
+    def delete_share(self, shnum):
+        return self.lock.run(self._locked_delete_share, shnum)
+
     def testv_and_readv_and_writev(self, write_enabler,
                                    test_and_write_vectors, read_vector,
                                    expiration_time, account):
+        return self.lock.run(self._locked_testv_and_readv_and_writev, write_enabler,
+                             test_and_write_vectors, read_vector,
+                             expiration_time, account)
+
+    def _locked_testv_and_readv_and_writev(self, write_enabler,
+                                           test_and_write_vectors, read_vector,
+                                           expiration_time, account):
         # The implementation here depends on the following helper methods,
         # which must be provided by subclasses:
         #
@@ -52,7 +81,7 @@ class ShareSet(object):
         #     """create a mutable share with the given shnum and write_enabler"""
 
         sharemap = {}
-        d = self.get_shares()
+        d = self._locked_get_shares()
         def _got_shares( (shares, corrupted) ):
             d2 = defer.succeed(None)
             for share in shares:
@@ -150,6 +179,9 @@ class ShareSet(object):
         return d
 
     def readv(self, wanted_shnums, read_vector):
+        return self.lock.run(self._locked_readv, wanted_shnums, read_vector)
+
+    def _locked_readv(self, wanted_shnums, read_vector):
         """
         Read a vector from the numbered shares in this shareset. An empty
         shares list means to return data from all known shares.
@@ -160,7 +192,7 @@ class ShareSet(object):
         """
         shnums = []
         dreads = []
-        d = self.get_shares()
+        d = self._locked_get_shares()
         def _got_shares( (shares, corrupted) ):
             # We ignore corrupted shares.
             for share in shares:
diff --git a/src/allmydata/storage/backends/cloud/cloud_backend.py b/src/allmydata/storage/backends/cloud/cloud_backend.py
index 786c993b..efe4e627 100644
--- a/src/allmydata/storage/backends/cloud/cloud_backend.py
+++ b/src/allmydata/storage/backends/cloud/cloud_backend.py
@@ -80,12 +80,13 @@ class CloudBackend(Backend):
             # XXX we want this to be deterministic, so we return the sharesets sorted
             # by their si_strings, but we shouldn't need to explicitly re-sort them
             # because list_objects returns a sorted list.
-            return [CloudShareSet(si_a2b(s), self._container, self._incomingset) for s in sorted(si_strings)]
+            return [self.get_shareset(si_a2b(s)) for s in sorted(si_strings)]
         d.addCallback(_get_sharesets)
         return d
 
     def get_shareset(self, storage_index):
-        return CloudShareSet(storage_index, self._container, self._incomingset)
+        return CloudShareSet(storage_index, self._get_lock(storage_index),
+                             self._container, self._incomingset)
 
     def fill_in_space_stats(self, stats):
         # TODO: query space usage of container if supported.
@@ -101,8 +102,8 @@ class CloudBackend(Backend):
 class CloudShareSet(ShareSet):
     implements(IShareSet)
 
-    def __init__(self, storage_index, container, incomingset):
-        ShareSet.__init__(self, storage_index)
+    def __init__(self, storage_index, lock, container, incomingset):
+        ShareSet.__init__(self, storage_index, lock)
         self._container = container
         self._incomingset = incomingset
         self._key = get_share_key(storage_index)
@@ -110,7 +111,7 @@ class CloudShareSet(ShareSet):
     def get_overhead(self):
         return 0
 
-    def get_shares(self):
+    def _locked_get_shares(self):
         d = self._container.list_objects(prefix=self._key)
         def _get_shares(res):
             si = self.get_storage_index()
@@ -135,7 +136,7 @@ class CloudShareSet(ShareSet):
         d.addCallback(lambda shares: (shares, set()) )
         return d
 
-    def get_share(self, shnum):
+    def _locked_get_share(self, shnum):
         key = "%s%d" % (self._key, shnum)
         d = self._container.list_objects(prefix=key)
         def _get_share(res):
@@ -146,7 +147,7 @@ class CloudShareSet(ShareSet):
         d.addCallback(_get_share)
         return d
 
-    def delete_share(self, shnum):
+    def _locked_delete_share(self, shnum):
         key = "%s%d" % (self._key, shnum)
         return delete_chunks(self._container, key)
 
diff --git a/src/allmydata/storage/backends/disk/disk_backend.py b/src/allmydata/storage/backends/disk/disk_backend.py
index d507d152..e3d09a53 100644
--- a/src/allmydata/storage/backends/disk/disk_backend.py
+++ b/src/allmydata/storage/backends/disk/disk_backend.py
@@ -80,7 +80,7 @@ class DiskBackend(Backend):
     def get_shareset(self, storage_index):
         sharehomedir = si_si2dir(self._sharedir, storage_index)
         incominghomedir = si_si2dir(self._incomingdir, storage_index)
-        return DiskShareSet(storage_index, sharehomedir, incominghomedir)
+        return DiskShareSet(storage_index, self._get_lock(storage_index), sharehomedir, incominghomedir)
 
     def fill_in_space_stats(self, stats):
         stats['storage_server.reserved_space'] = self._reserved_space
@@ -123,8 +123,8 @@ class DiskBackend(Backend):
 class DiskShareSet(ShareSet):
     implements(IShareSet)
 
-    def __init__(self, storage_index, sharehomedir, incominghomedir=None):
-        ShareSet.__init__(self, storage_index)
+    def __init__(self, storage_index, lock, sharehomedir, incominghomedir=None):
+        ShareSet.__init__(self, storage_index, lock)
         self._sharehomedir = sharehomedir
         self._incominghomedir = incominghomedir
 
@@ -132,7 +132,7 @@ class DiskShareSet(ShareSet):
         return (fileutil.get_used_space(self._sharehomedir) +
                 fileutil.get_used_space(self._incominghomedir))
 
-    def get_shares(self):
+    def _locked_get_shares(self):
         si = self.get_storage_index()
         shares = {}
         corrupted = set()
@@ -149,11 +149,11 @@ class DiskShareSet(ShareSet):
         valid = [shares[shnum] for shnum in sorted(shares.keys())]
         return defer.succeed( (valid, corrupted) )
 
-    def get_share(self, shnum):
+    def _locked_get_share(self, shnum):
         return get_disk_share(os.path.join(self._sharehomedir, str(shnum)),
                               self.get_storage_index(), shnum)
 
-    def delete_share(self, shnum):
+    def _locked_delete_share(self, shnum):
         fileutil.remove(os.path.join(self._sharehomedir, str(shnum)))
         return defer.succeed(None)
 
diff --git a/src/allmydata/storage/backends/null/null_backend.py b/src/allmydata/storage/backends/null/null_backend.py
index dbd5c9e7..b8e80c3d 100644
--- a/src/allmydata/storage/backends/null/null_backend.py
+++ b/src/allmydata/storage/backends/null/null_backend.py
@@ -44,7 +44,7 @@ class NullBackend(Backend):
     def get_shareset(self, storage_index):
         shareset = self._sharesets.get(storage_index, None)
         if shareset is None:
-            shareset = NullShareSet(storage_index)
+            shareset = NullShareSet(storage_index, self._get_lock(storage_index))
             self._sharesets[storage_index] = shareset
         return shareset
 
@@ -55,8 +55,8 @@ class NullBackend(Backend):
 class NullShareSet(ShareSet):
     implements(IShareSet)
 
-    def __init__(self, storage_index):
-        self.storage_index = storage_index
+    def __init__(self, storage_index, lock):
+        ShareSet.__init__(self, storage_index, lock)
         self._incoming_shnums = set()
         self._immutable_shnums = set()
         self._mutable_shnums = set()
@@ -69,7 +69,7 @@ class NullShareSet(ShareSet):
     def get_overhead(self):
         return 0
 
-    def get_shares(self):
+    def _locked_get_shares(self):
         shares = {}
         for shnum in self._immutable_shnums:
             shares[shnum] = ImmutableNullShare(self, shnum)
@@ -78,7 +78,7 @@ class NullShareSet(ShareSet):
         # This backend never has any corrupt shares.
         return defer.succeed( ([shares[shnum] for shnum in sorted(shares.keys())], set()) )
 
-    def get_share(self, shnum):
+    def _locked_get_share(self, shnum):
         if shnum in self._immutable_shnums:
             return defer.succeed(ImmutableNullShare(self, shnum))
         elif shnum in self._mutable_shnums:
@@ -87,7 +87,7 @@ class NullShareSet(ShareSet):
             def _not_found(): raise IndexError("no such share %d" % (shnum,))
             return defer.execute(_not_found)
 
-    def delete_share(self, shnum, include_incoming=False):
+    def _locked_delete_share(self, shnum, include_incoming=False):
         if include_incoming and (shnum in self._incoming_shnums):
             self._incoming_shnums.remove(shnum)
         if shnum in self._immutable_shnums:
-- 
2.45.2