From 4061258c85da2960f1389dd4b4a8012087b92083 Mon Sep 17 00:00:00 2001
From: david-sarah <david-sarah@jacaranda.org>
Date: Tue, 26 Oct 2010 21:33:02 -0700
Subject: [PATCH] make ResponseCache smarter to avoid memory leaks

Don't record timestamps, use DataSpans to merge entries, and clear the
cache when we see a new seqnum.

refs #1045, #1229
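
A minimal sketch of the intended behavior (illustrative only, not code
from this patch; note that verinfo[0] is the seqnum, so tuple verinfos
are used here):

    from allmydata.mutable.common import ResponseCache

    c = ResponseCache()
    c.add((1,), 1, 0, "olddata")
    c.add((2,), 1, 0, "newdata")    # higher seqnum: flush all old entries
    assert c.read((1,), 1, 0, 7) is None       # seqnum-1 entry is gone
    assert c.read((2,), 1, 0, 7) == "newdata"

    c.add((2,), 1, 7, "-tail")                 # adjacent byterange
    assert c.read((2,), 1, 0, 12) == "newdata-tail"  # merged by DataSpans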

---
 src/allmydata/mutable/common.py    | 90 +++++++++++-----------------
 src/allmydata/mutable/filenode.py  |  4 +-
 src/allmydata/mutable/publish.py   |  3 +-
 src/allmydata/mutable/retrieve.py  |  6 +-
 src/allmydata/mutable/servermap.py |  5 +-
 src/allmydata/test/test_mutable.py | 95 ++++++++++--------------------
 6 files changed, 74 insertions(+), 129 deletions(-)

diff --git a/src/allmydata/mutable/common.py b/src/allmydata/mutable/common.py
index e154b9df..29656be4 100644
--- a/src/allmydata/mutable/common.py
+++ b/src/allmydata/mutable/common.py
@@ -1,6 +1,6 @@
 
 from allmydata.util import idlib
-from allmydata.util.dictutil import DictOfSets
+from allmydata.util.spans import DataSpans
 
 MODE_CHECK = "MODE_CHECK" # query all peers
 MODE_ANYTHING = "MODE_ANYTHING" # one recoverable version
@@ -59,74 +59,52 @@ class UnknownVersionError(Exception):
 class ResponseCache:
     """I cache share data, to reduce the number of round trips used during
     mutable file operations. All of the data in my cache is for a single
-    storage index, but I will keep information on multiple shares (and
-    multiple versions) for that storage index.
+    storage index, but I will keep information on multiple shares for
+    that storage index.
+
+    I maintain a highest-seen sequence number, and flush all entries
+    whenever that number increases. (This does not guarantee that all
+    entries share one sequence number: an add() with an older seqnum is
+    still cached without triggering a flush.)
 
     My cache is indexed by a (verinfo, shnum) tuple.
 
-    My cache entries contain a set of non-overlapping byteranges: (start,
-    data, timestamp) tuples.
+    My cache entries are DataSpans instances, each representing a set of
+    non-overlapping byteranges.
     """
 
     def __init__(self):
-        self.cache = DictOfSets()
+        self.cache = {}
+        self.seqnum = None
 
     def _clear(self):
-        # used by unit tests
-        self.cache = DictOfSets()
-
-    def _does_overlap(self, x_start, x_length, y_start, y_length):
-        if x_start < y_start:
-            x_start, y_start = y_start, x_start
-            x_length, y_length = y_length, x_length
-        x_end = x_start + x_length
-        y_end = y_start + y_length
-        # this just returns a boolean. Eventually we'll want a form that
-        # returns a range.
-        if not x_length:
-            return False
-        if not y_length:
-            return False
-        if x_start >= y_end:
-            return False
-        if y_start >= x_end:
-            return False
-        return True
-
-
-    def _inside(self, x_start, x_length, y_start, y_length):
-        x_end = x_start + x_length
-        y_end = y_start + y_length
-        if x_start < y_start:
-            return False
-        if x_start >= y_end:
-            return False
-        if x_end < y_start:
-            return False
-        if x_end > y_end:
-            return False
-        return True
-
-    def add(self, verinfo, shnum, offset, data, timestamp):
+        # also used by unit tests
+        self.cache = {}
+
+    def add(self, verinfo, shnum, offset, data):
+        seqnum = verinfo[0]
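+        # self.seqnum starts as None; in Python 2, any number (or string)
+        # compares greater than None, so the first add() always flushes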
+        if seqnum > self.seqnum:
+            self._clear()
+            self.seqnum = seqnum
+
         index = (verinfo, shnum)
-        self.cache.add(index, (offset, data, timestamp) )
+        if index in self.cache:
+            self.cache[index].add(offset, data)
+        else:
+            spans = DataSpans()
+            spans.add(offset, data)
+            self.cache[index] = spans
 
     def read(self, verinfo, shnum, offset, length):
         """Try to satisfy a read request from cache.
-        Returns (data, timestamp), or (None, None) if the cache did not hold
-        the requested data.
+        Returns the requested data, or None if the cache does not hold
+        the entire requested span.
         """
 
-        # TODO: join multiple fragments, instead of only returning a hit if
-        # we have a fragment that contains the whole request
+        # TODO: perhaps return a DataSpans object representing the fragments
+        # that we have, instead of only returning a hit if we can satisfy the
+        # whole request from cache.
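+        # DataSpans.get returns the bytes for [offset, offset+length)
+        # only when it holds that entire range; otherwise it returns None.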
 
         index = (verinfo, shnum)
-        for entry in self.cache.get(index, set()):
-            (e_start, e_data, e_timestamp) = entry
-            if self._inside(offset, length, e_start, len(e_data)):
-                want_start = offset - e_start
-                want_end = offset+length - e_start
-                return (e_data[want_start:want_end], e_timestamp)
-        return None, None
-
-
+        if index in self.cache:
+            return self.cache[index].get(offset, length)
+        else:
+            return None
diff --git a/src/allmydata/mutable/filenode.py b/src/allmydata/mutable/filenode.py
index d9cd9274..6c38a856 100644
--- a/src/allmydata/mutable/filenode.py
+++ b/src/allmydata/mutable/filenode.py
@@ -149,8 +149,8 @@ class MutableFileNode:
         self._privkey = privkey
     def _populate_encprivkey(self, encprivkey):
         self._encprivkey = encprivkey
-    def _add_to_cache(self, verinfo, shnum, offset, data, timestamp):
-        self._cache.add(verinfo, shnum, offset, data, timestamp)
+    def _add_to_cache(self, verinfo, shnum, offset, data):
+        self._cache.add(verinfo, shnum, offset, data)
     def _read_from_cache(self, verinfo, shnum, offset, length):
         return self._cache.read(verinfo, shnum, offset, length)
 
diff --git a/src/allmydata/mutable/publish.py b/src/allmydata/mutable/publish.py
index 1b7e050a..2d63c87b 100644
--- a/src/allmydata/mutable/publish.py
+++ b/src/allmydata/mutable/publish.py
@@ -7,12 +7,13 @@ from twisted.internet import defer
 from twisted.python import failure
 from allmydata.interfaces import IPublishStatus
 from allmydata.util import base32, hashutil, mathutil, idlib, log
+from allmydata.util.dictutil import DictOfSets
 from allmydata import hashtree, codec
 from allmydata.storage.server import si_b2a
 from pycryptopp.cipher.aes import AES
 from foolscap.api import eventually, fireEventually
 
-from allmydata.mutable.common import MODE_WRITE, MODE_CHECK, DictOfSets, \
+from allmydata.mutable.common import MODE_WRITE, MODE_CHECK, \
      UncoordinatedWriteError, NotEnoughServersError
 from allmydata.mutable.servermap import ServerMap
 from allmydata.mutable.layout import pack_prefix, pack_share, unpack_header, pack_checkstring, \
diff --git a/src/allmydata/mutable/retrieve.py b/src/allmydata/mutable/retrieve.py
index b4fa1c22..257cc5f3 100644
--- a/src/allmydata/mutable/retrieve.py
+++ b/src/allmydata/mutable/retrieve.py
@@ -7,12 +7,13 @@ from twisted.python import failure
 from foolscap.api import DeadReferenceError, eventually, fireEventually
 from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError
 from allmydata.util import hashutil, idlib, log
+from allmydata.util.dictutil import DictOfSets
 from allmydata import hashtree, codec
 from allmydata.storage.server import si_b2a
 from pycryptopp.cipher.aes import AES
 from pycryptopp.publickey import rsa
 
-from allmydata.mutable.common import DictOfSets, CorruptShareError, UncoordinatedWriteError
+from allmydata.mutable.common import CorruptShareError, UncoordinatedWriteError
 from allmydata.mutable.layout import SIGNED_PREFIX, unpack_share_data
 
 class RetrieveStatus:
@@ -198,8 +199,7 @@ class Retrieve:
         got_from_cache = False
         datavs = []
         for (offset, length) in readv:
-            (data, timestamp) = self._node._read_from_cache(self.verinfo, shnum,
-                                                            offset, length)
+            data = self._node._read_from_cache(self.verinfo, shnum, offset, length)
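+            # data is None when the cache cannot supply the entire span;
+            # in that case datavs stays short and we query the server instead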
             if data is not None:
                 datavs.append(data)
         if len(datavs) == len(readv):
diff --git a/src/allmydata/mutable/servermap.py b/src/allmydata/mutable/servermap.py
index 6478afcb..999691fa 100644
--- a/src/allmydata/mutable/servermap.py
+++ b/src/allmydata/mutable/servermap.py
@@ -6,12 +6,13 @@ from twisted.internet import defer
 from twisted.python import failure
 from foolscap.api import DeadReferenceError, RemoteException, eventually
 from allmydata.util import base32, hashutil, idlib, log
+from allmydata.util.dictutil import DictOfSets
 from allmydata.storage.server import si_b2a
 from allmydata.interfaces import IServermapUpdaterStatus
 from pycryptopp.publickey import rsa
 
 from allmydata.mutable.common import MODE_CHECK, MODE_ANYTHING, MODE_WRITE, MODE_READ, \
-     DictOfSets, CorruptShareError, NeedMoreDataError
+     CorruptShareError, NeedMoreDataError
 from allmydata.mutable.layout import unpack_prefix_and_signature, unpack_header, unpack_share, \
      SIGNED_PREFIX_LENGTH
 
@@ -581,7 +582,7 @@ class ServermapUpdater:
                 verinfo = self._got_results_one_share(shnum, data, peerid, lp)
                 last_verinfo = verinfo
                 last_shnum = shnum
-                self._node._add_to_cache(verinfo, shnum, 0, data, now)
+                self._node._add_to_cache(verinfo, shnum, 0, data)
             except CorruptShareError, e:
                 # log it and give the other shares a chance to be processed
                 f = failure.Failure()
diff --git a/src/allmydata/test/test_mutable.py b/src/allmydata/test/test_mutable.py
index 375de1ff..e4e6eb7f 100644
--- a/src/allmydata/test/test_mutable.py
+++ b/src/allmydata/test/test_mutable.py
@@ -301,16 +301,16 @@ class Filenode(unittest.TestCase, testutil.ShouldFailMixin):
             d.addCallback(lambda res: self.failUnlessEqual(res, "contents"))
             d.addCallback(lambda ign: self.failUnless(isinstance(n._cache, ResponseCache)))
 
-            def _check_cache_size(expected):
-                # The total size of cache entries should not increase on the second download.
+            def _check_cache(expected):
+                # The total size of cache entries should not increase on the second download;
+                # in fact the cache contents should be identical.
                 d2 = n.download_best_version()
-                d2.addCallback(lambda ign: self.failUnlessEqual(len(repr(n._cache.cache)), expected))
+                d2.addCallback(lambda ign: self.failUnlessEqual(repr(n._cache.cache), expected))
                 return d2
-            d.addCallback(lambda ign: _check_cache_size(len(repr(n._cache.cache))))
+            d.addCallback(lambda ign: _check_cache(repr(n._cache.cache)))
             return d
         d.addCallback(_created)
         return d
-    test_response_cache_memory_leak.todo = "This isn't fixed (see #1045)."
 
     def test_create_with_initial_contents_function(self):
         data = "initial contents"
@@ -1717,72 +1717,37 @@ class MultipleVersions(unittest.TestCase, PublishMixin, CheckerMixin):
 
 
 class Utils(unittest.TestCase):
-    def _do_inside(self, c, x_start, x_length, y_start, y_length):
-        # we compare this against sets of integers
-        x = set(range(x_start, x_start+x_length))
-        y = set(range(y_start, y_start+y_length))
-        should_be_inside = x.issubset(y)
-        self.failUnlessEqual(should_be_inside, c._inside(x_start, x_length,
-                                                         y_start, y_length),
-                             str((x_start, x_length, y_start, y_length)))
-
-    def test_cache_inside(self):
-        c = ResponseCache()
-        x_start = 10
-        x_length = 5
-        for y_start in range(8, 17):
-            for y_length in range(8):
-                self._do_inside(c, x_start, x_length, y_start, y_length)
-
-    def _do_overlap(self, c, x_start, x_length, y_start, y_length):
-        # we compare this against sets of integers
-        x = set(range(x_start, x_start+x_length))
-        y = set(range(y_start, y_start+y_length))
-        overlap = bool(x.intersection(y))
-        self.failUnlessEqual(overlap, c._does_overlap(x_start, x_length,
-                                                      y_start, y_length),
-                             str((x_start, x_length, y_start, y_length)))
-
-    def test_cache_overlap(self):
-        c = ResponseCache()
-        x_start = 10
-        x_length = 5
-        for y_start in range(8, 17):
-            for y_length in range(8):
-                self._do_overlap(c, x_start, x_length, y_start, y_length)
-
     def test_cache(self):
         c = ResponseCache()
         # xdata = base62.b2a(os.urandom(100))[:100]
         xdata = "1Ex4mdMaDyOl9YnGBM3I4xaBF97j8OQAg1K3RBR01F2PwTP4HohB3XpACuku8Xj4aTQjqJIR1f36mEj3BCNjXaJmPBEZnnHL0U9l"
         ydata = "4DCUQXvkEPnnr9Lufikq5t21JsnzZKhzxKBhLhrBB6iIcBOWRuT4UweDhjuKJUre8A4wOObJnl3Kiqmlj4vjSLSqUGAkUD87Y3vs"
-        nope = (None, None)
-        c.add("v1", 1, 0, xdata, "time0")
-        c.add("v1", 1, 2000, ydata, "time1")
-        self.failUnlessEqual(c.read("v2", 1, 10, 11), nope)
-        self.failUnlessEqual(c.read("v1", 2, 10, 11), nope)
-        self.failUnlessEqual(c.read("v1", 1, 0, 10), (xdata[:10], "time0"))
-        self.failUnlessEqual(c.read("v1", 1, 90, 10), (xdata[90:], "time0"))
-        self.failUnlessEqual(c.read("v1", 1, 300, 10), nope)
-        self.failUnlessEqual(c.read("v1", 1, 2050, 5), (ydata[50:55], "time1"))
-        self.failUnlessEqual(c.read("v1", 1, 0, 101), nope)
-        self.failUnlessEqual(c.read("v1", 1, 99, 1), (xdata[99:100], "time0"))
-        self.failUnlessEqual(c.read("v1", 1, 100, 1), nope)
-        self.failUnlessEqual(c.read("v1", 1, 1990, 9), nope)
-        self.failUnlessEqual(c.read("v1", 1, 1990, 10), nope)
-        self.failUnlessEqual(c.read("v1", 1, 1990, 11), nope)
-        self.failUnlessEqual(c.read("v1", 1, 1990, 15), nope)
-        self.failUnlessEqual(c.read("v1", 1, 1990, 19), nope)
-        self.failUnlessEqual(c.read("v1", 1, 1990, 20), nope)
-        self.failUnlessEqual(c.read("v1", 1, 1990, 21), nope)
-        self.failUnlessEqual(c.read("v1", 1, 1990, 25), nope)
-        self.failUnlessEqual(c.read("v1", 1, 1999, 25), nope)
-
-        # optional: join fragments
+        c.add("v1", 1, 0, xdata)
+        c.add("v1", 1, 2000, ydata)
+        self.failUnlessEqual(c.read("v2", 1, 10, 11), None)
+        self.failUnlessEqual(c.read("v1", 2, 10, 11), None)
+        self.failUnlessEqual(c.read("v1", 1, 0, 10), xdata[:10])
+        self.failUnlessEqual(c.read("v1", 1, 90, 10), xdata[90:])
+        self.failUnlessEqual(c.read("v1", 1, 300, 10), None)
+        self.failUnlessEqual(c.read("v1", 1, 2050, 5), ydata[50:55])
+        self.failUnlessEqual(c.read("v1", 1, 0, 101), None)
+        self.failUnlessEqual(c.read("v1", 1, 99, 1), xdata[99:100])
+        self.failUnlessEqual(c.read("v1", 1, 100, 1), None)
+        self.failUnlessEqual(c.read("v1", 1, 1990, 9), None)
+        self.failUnlessEqual(c.read("v1", 1, 1990, 10), None)
+        self.failUnlessEqual(c.read("v1", 1, 1990, 11), None)
+        self.failUnlessEqual(c.read("v1", 1, 1990, 15), None)
+        self.failUnlessEqual(c.read("v1", 1, 1990, 19), None)
+        self.failUnlessEqual(c.read("v1", 1, 1990, 20), None)
+        self.failUnlessEqual(c.read("v1", 1, 1990, 21), None)
+        self.failUnlessEqual(c.read("v1", 1, 1990, 25), None)
+        self.failUnlessEqual(c.read("v1", 1, 1999, 25), None)
+
+        # test joining fragments
         c = ResponseCache()
-        c.add("v1", 1, 0, xdata[:10], "time0")
-        c.add("v1", 1, 10, xdata[10:20], "time1")
-        #self.failUnlessEqual(c.read("v1", 1, 0, 20), (xdata[:20], "time0"))
+        c.add("v1", 1, 0, xdata[:10])
+        c.add("v1", 1, 10, xdata[10:20])
+        self.failUnlessEqual(c.read("v1", 1, 0, 20), xdata[:20])
 
 class Exceptions(unittest.TestCase):
     def test_repr(self):
-- 
2.45.2