]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/commitdiff
directories: keep track of your position as you decode netstring after netstring...
authorZooko O'Whielacronx <zooko@zooko.com>
Sun, 5 Jul 2009 02:51:09 +0000 (19:51 -0700)
committerZooko O'Whielacronx <zooko@zooko.com>
Sun, 5 Jul 2009 02:51:09 +0000 (19:51 -0700)
This makes decoding linear in the number of netstrings instead of O(N^2).

src/allmydata/dirnode.py
src/allmydata/test/test_netstring.py
src/allmydata/util/netstring.py

index 731431992b25ac029c8278a2998eba68c106f677..8f72709b0f2e7bd45a8a7c5d8938d53050fca7d5 100644 (file)
@@ -204,15 +204,16 @@ class NewDirectoryNode:
         # rocap, rwcap, metadata), in which the name,rocap,metadata are in
         # cleartext. The 'name' is UTF-8 encoded. The rwcap is formatted as:
         # pack("16ss32s", iv, AES(H(writekey+iv), plaintextrwcap), mac)
-        assert isinstance(data, str)
+        assert isinstance(data, str), (repr(data), type(data))
         # an empty directory is serialized as an empty string
         if data == "":
             return {}
         writeable = not self.is_readonly()
         children = {}
-        while len(data) > 0:
-            entry, data = split_netstring(data, 1, True)
-            name, rocap, rwcapdata, metadata_s = split_netstring(entry, 4)
+        position = 0
+        while position < len(data):
+            entries, position = split_netstring(data, 1, position)
+            (name, rocap, rwcapdata, metadata_s), subpos = split_netstring(entries[0], 4)
             name = name.decode("utf-8")
             rwcap = None
             if writeable:
index 5c8199a6f34ae806f4293453fe5b11a9a0b12f96..8fecdc49fa815dda3f334cb37cc08718caa9cc0b 100644 (file)
@@ -5,37 +5,32 @@ from allmydata.util.netstring import netstring, split_netstring
 class Netstring(unittest.TestCase):
     def test_split(self):
         a = netstring("hello") + netstring("world")
-        self.failUnlessEqual(split_netstring(a, 2), ("hello", "world"))
-        self.failUnlessEqual(split_netstring(a, 2, False), ("hello", "world"))
-        self.failUnlessEqual(split_netstring(a, 2, True),
-                             ("hello", "world", ""))
+        self.failUnlessEqual(split_netstring(a, 2), (["hello", "world"], len(a)))
+        self.failUnlessEqual(split_netstring(a, 2, required_trailer=""), (["hello", "world"], len(a)))
         self.failUnlessRaises(ValueError, split_netstring, a, 3)
-        self.failUnlessRaises(ValueError, split_netstring, a+" extra", 2)
-        self.failUnlessRaises(ValueError, split_netstring, a+" extra", 2, False)
+        self.failUnlessRaises(ValueError, split_netstring, a+" extra", 2, required_trailer="")
+        self.failUnlessEqual(split_netstring(a+" extra", 2), (["hello", "world"], len(a)))
         self.failUnlessEqual(split_netstring(a+"++", 2, required_trailer="++"),
-                             ("hello", "world"))
+                             (["hello", "world"], len(a)+2))
         self.failUnlessRaises(ValueError,
                               split_netstring, a+"+", 2, required_trailer="not")
 
     def test_extra(self):
         a = netstring("hello")
-        self.failUnlessEqual(split_netstring(a, 1, True), ("hello", ""))
+        self.failUnlessEqual(split_netstring(a, 1), (["hello"], len(a)))
         b = netstring("hello") + "extra stuff"
-        self.failUnlessEqual(split_netstring(b, 1, True),
-                             ("hello", "extra stuff"))
+        self.failUnlessEqual(split_netstring(b, 1),
+                             (["hello"], len(a)))
 
     def test_nested(self):
         a = netstring("hello") + netstring("world") + "extra stuff"
         b = netstring("a") + netstring("is") + netstring(a) + netstring(".")
-        top = split_netstring(b, 4)
+        (top, pos) = split_netstring(b, 4)
         self.failUnlessEqual(len(top), 4)
         self.failUnlessEqual(top[0], "a")
         self.failUnlessEqual(top[1], "is")
         self.failUnlessEqual(top[2], a)
         self.failUnlessEqual(top[3], ".")
-        self.failUnlessRaises(ValueError, split_netstring, a, 2)
-        self.failUnlessRaises(ValueError, split_netstring, a, 2, False)
-        bottom = split_netstring(a, 2, True)
-        self.failUnlessEqual(bottom, ("hello", "world", "extra stuff"))
-
-
+        self.failUnlessRaises(ValueError, split_netstring, a, 2, required_trailer="")
+        bottom = split_netstring(a, 2)
+        self.failUnlessEqual(bottom, (["hello", "world"], len(netstring("hello")+netstring("world"))))
index a1fe8cb981983396ff20cbdc5a52e5f6938c3c23..73f15658cd1315f88425128478e480572977dcca 100644 (file)
@@ -5,34 +5,34 @@ def netstring(s):
     return "%d:%s," % (len(s), s,)
 
 def split_netstring(data, numstrings,
-                    allow_leftover=False,
-                    required_trailer=""):
-    """like string.split(), but extracts netstrings. If allow_leftover=False,
-    I return numstrings elements, and throw ValueError if there was leftover
-    data that does not exactly equal 'required_trailer'. If
-    allow_leftover=True, required_trailer must be empty, and I return
-    numstrings+1 elements, in which the last element is the leftover data
-    (possibly an empty string)"""
-
-    assert not (allow_leftover and required_trailer)
+                    position=0,
+                    required_trailer=None):
+    """like string.split(), but extracts netstrings. Ignore all bytes of data
+    before the 'position' byte. Return a tuple of (list of elements (numstrings
+    in length), new position index). The new position index points to the first
+    byte which was not consumed (the 'required_trailer', if any, counts as
+    consumed).  If 'required_trailer' is not None, throw ValueError if leftover
+    data does not exactly equal 'required_trailer'."""
 
+    assert type(position) in (int, long), (repr(position), type(position))
     elements = []
     assert numstrings >= 0
-    while data:
-        colon = data.index(":")
-        length = int(data[:colon])
+    while position < len(data):
+        colon = data.index(":", position)
+        length = int(data[position:colon])
         string = data[colon+1:colon+1+length]
-        assert len(string) == length
+        assert len(string) == length, (len(string), length)
         elements.append(string)
-        assert data[colon+1+length] == ","
-        data = data[colon+1+length+1:]
+        position = colon+1+length
+        assert data[position] == ",", position
+        position += 1
         if len(elements) == numstrings:
             break
     if len(elements) < numstrings:
         raise ValueError("ran out of netstrings")
-    if allow_leftover:
-        return tuple(elements + [data])
-    if data != required_trailer:
-        raise ValueError("leftover data in netstrings")
-    return tuple(elements)
-
+    if required_trailer is not None:
+        if ((len(data) - position) != len(required_trailer)) or (data[position:] != required_trailer):
+            raise ValueError("leftover data in netstrings")
+        return (elements, position + len(required_trailer))
+    else:
+        return (elements, position)