]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/commitdiff
mutable: implement MutableFileNode.modify, plus tests
authorBrian Warner <warner@allmydata.com>
Fri, 18 Apr 2008 02:12:42 +0000 (19:12 -0700)
committerBrian Warner <warner@allmydata.com>
Fri, 18 Apr 2008 02:12:42 +0000 (19:12 -0700)
src/allmydata/mutable/node.py
src/allmydata/test/test_mutable.py

index 0c1ffc25b35c0345a62ad933a2804074bfa6fe76..ae9a54a32dc9474ca457d386ee1783c9c99f999b 100644 (file)
@@ -1,11 +1,12 @@
 
-import weakref
+import weakref, random
 from twisted.application import service
 
 from zope.interface import implements
-from twisted.internet import defer
+from twisted.internet import defer, reactor
 from allmydata.interfaces import IMutableFileNode, IMutableFileURI
 from allmydata.util import hashutil
+from allmydata.util.assertutil import precondition
 from allmydata.uri import WriteableSSKFileURI
 from allmydata.encode import NotEnoughSharesError
 from pycryptopp.publickey import rsa
@@ -13,11 +14,33 @@ from pycryptopp.cipher.aes import AES
 
 from publish import Publish
 from common import MODE_READ, MODE_WRITE, UnrecoverableFileError, \
-     ResponseCache
+     ResponseCache, UncoordinatedWriteError
 from servermap import ServerMap, ServermapUpdater
 from retrieve import Retrieve
 
 
+class BackoffAgent:
+    # these parameters are copied from foolscap.reconnector, which gets them
+    # from twisted.internet.protocol.ReconnectingClientFactory
+    initialDelay = 1.0
+    factor = 2.7182818284590451 # (math.e)
+    jitter = 0.11962656492 # molar Planck constant times c, Joule meter/mole
+    maxRetries = 4
+
+    def __init__(self):
+        self._delay = self.initialDelay
+        self._count = 0
+    def delay(self, node, f):
+        self._count += 1
+        if self._count == 4:
+            return f
+        self._delay = self._delay * self.factor
+        self._delay = random.normalvariate(self._delay,
+                                           self._delay * self.jitter)
+        d = defer.Deferred()
+        reactor.callLater(self._delay, d.callback, None)
+        return d
+
 # use client.create_mutable_file() to make one of these
 
 class MutableFileNode:
@@ -253,14 +276,13 @@ class MutableFileNode:
         return d
     def _try_once_to_download_best_version(self, servermap, mode):
         d = self._update_servermap(servermap, mode)
-        def _updated(ignored):
-            goal = servermap.best_recoverable_version()
-            if not goal:
-                raise UnrecoverableFileError("no recoverable versions")
-            return self._try_once_to_download_version(servermap, goal)
-        d.addCallback(_updated)
+        d.addCallback(self._once_updated_download_best_version, servermap)
         return d
-
+    def _once_updated_download_best_version(self, ignored, servermap):
+        goal = servermap.best_recoverable_version()
+        if not goal:
+            raise UnrecoverableFileError("no recoverable versions")
+        return self._try_once_to_download_version(servermap, goal)
 
     def overwrite(self, new_contents):
         return self._do_serialized(self._overwrite, new_contents)
@@ -271,7 +293,7 @@ class MutableFileNode:
         return d
 
 
-    def modify(self, modifier, *args, **kwargs):
+    def modify(self, modifier, backoffer=None):
         """I use a modifier callback to apply a change to the mutable file.
         I implement the following pseudocode::
 
@@ -283,7 +305,8 @@ class MutableFileNode:
            if new == old: break
            try:
              publish(new)
-           except UncoordinatedWriteError:
+           except UncoordinatedWriteError, e:
+             backoffer(e)
              continue
            break
          release_mutable_filenode_lock()
@@ -295,8 +318,49 @@ class MutableFileNode:
 
         Note that the modifier is required to run synchronously, and must not
         invoke any methods on this MutableFileNode instance.
+
+        The backoff-er is a callable that is responsible for inserting a
+        random delay between subsequent attempts, to help competing updates
+        from colliding forever. It is also allowed to give up after a while.
+        The backoffer is given two arguments: this MutableFileNode, and the
+        Failure object that contains the UncoordinatedWriteError. It should
+        return a Deferred that will fire when the next attempt should be
+        made, or return the Failure if the loop should give up. If
+        backoffer=None, a default one is provided which will perform
+        exponential backoff, and give up after 4 tries. Note that the
+        backoffer should not invoke any methods on this MutableFileNode
+        instance, and it needs to be highly conscious of deadlock issues.
         """
-        NotImplementedError
+        return self._do_serialized(self._modify, modifier, backoffer)
+    def _modify(self, modifier, backoffer):
+        servermap = ServerMap()
+        if backoffer is None:
+            backoffer = BackoffAgent().delay
+        return self._modify_and_retry(servermap, modifier, backoffer)
+    def _modify_and_retry(self, servermap, modifier, backoffer):
+        d = self._modify_once(servermap, modifier)
+        def _retry(f):
+            f.trap(UncoordinatedWriteError)
+            d2 = defer.maybeDeferred(backoffer, self, f)
+            d2.addCallback(lambda ignored:
+                           self._modify_and_retry(servermap, modifier,
+                                                  backoffer))
+            return d2
+        d.addErrback(_retry)
+        return d
+    def _modify_once(self, servermap, modifier):
+        d = self._update_servermap(servermap, MODE_WRITE)
+        d.addCallback(self._once_updated_download_best_version, servermap)
+        def _apply(old_contents):
+            new_contents = modifier(old_contents)
+            if new_contents is None or new_contents == old_contents:
+                # no changes need to be made
+                return
+            precondition(isinstance(new_contents, str),
+                         "Modifier function must return a string or None")
+            return self._upload(new_contents, servermap)
+        d.addCallback(_apply)
+        return d
 
     def get_servermap(self, mode):
         return self._do_serialized(self._get_servermap, mode)
index 7b5626c5e2358fb1f303b90e67be109f9fb1c98c..a3623bf8c0eddd5ed134754e610c59f4ef7ea39c 100644 (file)
@@ -14,9 +14,10 @@ from foolscap.eventual import eventually, fireEventually
 from foolscap.logging import log
 import sha
 
-from allmydata.mutable.node import MutableFileNode
+from allmydata.mutable.node import MutableFileNode, BackoffAgent
 from allmydata.mutable.common import DictOfSets, ResponseCache, \
-     MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_READ, UnrecoverableFileError
+     MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_READ, \
+     UnrecoverableFileError, UncoordinatedWriteError
 from allmydata.mutable.retrieve import Retrieve
 from allmydata.mutable.publish import Publish
 from allmydata.mutable.servermap import ServerMap, ServermapUpdater
@@ -254,6 +255,27 @@ class Filenode(unittest.TestCase):
         d.addCallback(_created)
         return d
 
+    def test_serialize(self):
+        n = MutableFileNode(self.client)
+        calls = []
+        def _callback(*args, **kwargs):
+            self.failUnlessEqual(args, (4,) )
+            self.failUnlessEqual(kwargs, {"foo": 5})
+            calls.append(1)
+            return 6
+        d = n._do_serialized(_callback, 4, foo=5)
+        def _check_callback(res):
+            self.failUnlessEqual(res, 6)
+            self.failUnlessEqual(calls, [1])
+        d.addCallback(_check_callback)
+
+        def _errback():
+            raise ValueError("heya")
+        d.addCallback(lambda res:
+                      self.shouldFail(ValueError, "_check_errback", "heya",
+                                      n._do_serialized, _errback))
+        return d
+
     def test_upload_and_download(self):
         d = self.client.create_mutable_file()
         def _created(n):
@@ -296,6 +318,147 @@ class Filenode(unittest.TestCase):
         d.addCallback(_created)
         return d
 
+    def failUnlessCurrentSeqnumIs(self, n, expected_seqnum):
+        d = n.get_servermap(MODE_READ)
+        d.addCallback(lambda servermap: servermap.best_recoverable_version())
+        d.addCallback(lambda verinfo:
+                      self.failUnlessEqual(verinfo[0], expected_seqnum))
+        return d
+
+    def shouldFail(self, expected_failure, which, substring,
+                   callable, *args, **kwargs):
+        assert substring is None or isinstance(substring, str)
+        d = defer.maybeDeferred(callable, *args, **kwargs)
+        def done(res):
+            if isinstance(res, failure.Failure):
+                res.trap(expected_failure)
+                if substring:
+                    self.failUnless(substring in str(res),
+                                    "substring '%s' not in '%s'"
+                                    % (substring, str(res)))
+            else:
+                self.fail("%s was supposed to raise %s, not get '%s'" %
+                          (which, expected_failure, res))
+        d.addBoth(done)
+        return d
+
+    def test_modify(self):
+        def _modifier(old_contents):
+            return old_contents + "line2"
+        def _non_modifier(old_contents):
+            return old_contents
+        def _none_modifier(old_contents):
+            return None
+        def _error_modifier(old_contents):
+            raise ValueError
+        calls = []
+        def _ucw_error_modifier(old_contents):
+            # simulate an UncoordinatedWriteError once
+            calls.append(1)
+            if len(calls) <= 1:
+                raise UncoordinatedWriteError("simulated")
+            return old_contents + "line3"
+
+        d = self.client.create_mutable_file("line1")
+        def _created(n):
+            d = n.modify(_modifier)
+            d.addCallback(lambda res: n.download_best_version())
+            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
+            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))
+
+            d.addCallback(lambda res: n.modify(_non_modifier))
+            d.addCallback(lambda res: n.download_best_version())
+            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
+            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))
+
+            d.addCallback(lambda res: n.modify(_none_modifier))
+            d.addCallback(lambda res: n.download_best_version())
+            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
+            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))
+
+            d.addCallback(lambda res:
+                          self.shouldFail(ValueError, "error_modifier", None,
+                                          n.modify, _error_modifier))
+            d.addCallback(lambda res: n.download_best_version())
+            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
+            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))
+
+            d.addCallback(lambda res: n.modify(_ucw_error_modifier))
+            d.addCallback(lambda res: self.failUnlessEqual(len(calls), 2))
+            d.addCallback(lambda res: n.download_best_version())
+            d.addCallback(lambda res: self.failUnlessEqual(res,
+                                                           "line1line2line3"))
+            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3))
+
+            return d
+        d.addCallback(_created)
+        return d
+
+    def test_modify_backoffer(self):
+        def _modifier(old_contents):
+            return old_contents + "line2"
+        calls = []
+        def _ucw_error_modifier(old_contents):
+            # simulate an UncoordinatedWriteError once
+            calls.append(1)
+            if len(calls) <= 1:
+                raise UncoordinatedWriteError("simulated")
+            return old_contents + "line3"
+        def _always_ucw_error_modifier(old_contents):
+            raise UncoordinatedWriteError("simulated")
+        def _backoff_stopper(node, f):
+            return f
+        def _backoff_pauser(node, f):
+            d = defer.Deferred()
+            reactor.callLater(0.5, d.callback, None)
+            return d
+
+        # the give-up-er will hit its maximum retry count quickly
+        giveuper = BackoffAgent()
+        giveuper._delay = 0.1
+        giveuper.factor = 1
+
+        d = self.client.create_mutable_file("line1")
+        def _created(n):
+            d = n.modify(_modifier)
+            d.addCallback(lambda res: n.download_best_version())
+            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
+            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))
+
+            d.addCallback(lambda res:
+                          self.shouldFail(UncoordinatedWriteError,
+                                          "_backoff_stopper", None,
+                                          n.modify, _ucw_error_modifier,
+                                          _backoff_stopper))
+            d.addCallback(lambda res: n.download_best_version())
+            d.addCallback(lambda res: self.failUnlessEqual(res, "line1line2"))
+            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 2))
+
+            def _reset_ucw_error_modifier(res):
+                calls[:] = []
+                return res
+            d.addCallback(_reset_ucw_error_modifier)
+            d.addCallback(lambda res: n.modify(_ucw_error_modifier,
+                                               _backoff_pauser))
+            d.addCallback(lambda res: n.download_best_version())
+            d.addCallback(lambda res: self.failUnlessEqual(res,
+                                                           "line1line2line3"))
+            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3))
+
+            d.addCallback(lambda res:
+                          self.shouldFail(UncoordinatedWriteError,
+                                          "giveuper", None,
+                                          n.modify, _always_ucw_error_modifier,
+                                          giveuper.delay))
+            d.addCallback(lambda res: n.download_best_version())
+            d.addCallback(lambda res: self.failUnlessEqual(res,
+                                                           "line1line2line3"))
+            d.addCallback(lambda res: self.failUnlessCurrentSeqnumIs(n, 3))
+
+            return d
+        d.addCallback(_created)
+        return d
+
     def test_upload_and_download_full_size_keys(self):
         self.client.mutable_file_node_class = MutableFileNode
         d = self.client.create_mutable_file()