netstring: add required_trailer= argument
authorBrian Warner <warner@allmydata.com>
Fri, 26 Sep 2008 16:57:54 +0000 (09:57 -0700)
committerBrian Warner <warner@allmydata.com>
Fri, 26 Sep 2008 16:57:54 +0000 (09:57 -0700)
src/allmydata/test/test_netstring.py
src/allmydata/util/netstring.py

index 3923a0e19d11c19462c7cfe066a87779ebac4cfd..5c8199a6f34ae806f4293453fe5b11a9a0b12f96 100644 (file)
@@ -12,6 +12,10 @@ class Netstring(unittest.TestCase):
         self.failUnlessRaises(ValueError, split_netstring, a, 3)
         self.failUnlessRaises(ValueError, split_netstring, a+" extra", 2)
         self.failUnlessRaises(ValueError, split_netstring, a+" extra", 2, False)
+        self.failUnlessEqual(split_netstring(a+"++", 2, required_trailer="++"),
+                             ("hello", "world"))
+        self.failUnlessRaises(ValueError,
+                              split_netstring, a+"+", 2, required_trailer="not")
 
     def test_extra(self):
         a = netstring("hello")
index 70a14e01b970c9c62aeb5247946ba206b5b9997d..a1fe8cb981983396ff20cbdc5a52e5f6938c3c23 100644 (file)
@@ -4,11 +4,18 @@ def netstring(s):
     assert isinstance(s, str), s # no unicode here
     return "%d:%s," % (len(s), s,)
 
-def split_netstring(data, numstrings, allow_leftover=False):
+def split_netstring(data, numstrings,
+                    allow_leftover=False,
+                    required_trailer=""):
     """like string.split(), but extracts netstrings. If allow_leftover=False,
-    returns numstrings elements, and throws ValueError if there was leftover
-    data. If allow_leftover=True, returns numstrings+1 elements, in which the
-    last element is the leftover data (possibly an empty string)"""
+    I return numstrings elements, and throw ValueError if there was leftover
+    data that does not exactly equal 'required_trailer'. If
+    allow_leftover=True, required_trailer must be empty, and I return
+    numstrings+1 elements, in which the last element is the leftover data
+    (possibly an empty string)"""
+
+    assert not (allow_leftover and required_trailer)
+
     elements = []
     assert numstrings >= 0
     while data:
@@ -25,7 +32,7 @@ def split_netstring(data, numstrings, allow_leftover=False):
         raise ValueError("ran out of netstrings")
     if allow_leftover:
         return tuple(elements + [data])
-    if data:
+    if data != required_trailer:
         raise ValueError("leftover data in netstrings")
     return tuple(elements)