From: Brian Warner Date: Fri, 26 Sep 2008 16:57:54 +0000 (-0700) Subject: netstring: add required_trailer= argument X-Git-Url: https://git.rkrishnan.org/somewhere?a=commitdiff_plain;h=98c8e25709579c17c6b37be181dee059cc2a016d;p=tahoe-lafs%2Ftahoe-lafs.git netstring: add required_trailer= argument --- diff --git a/src/allmydata/test/test_netstring.py b/src/allmydata/test/test_netstring.py index 3923a0e1..5c8199a6 100644 --- a/src/allmydata/test/test_netstring.py +++ b/src/allmydata/test/test_netstring.py @@ -12,6 +12,10 @@ class Netstring(unittest.TestCase): self.failUnlessRaises(ValueError, split_netstring, a, 3) self.failUnlessRaises(ValueError, split_netstring, a+" extra", 2) self.failUnlessRaises(ValueError, split_netstring, a+" extra", 2, False) + self.failUnlessEqual(split_netstring(a+"++", 2, required_trailer="++"), + ("hello", "world")) + self.failUnlessRaises(ValueError, + split_netstring, a+"+", 2, required_trailer="not") def test_extra(self): a = netstring("hello") diff --git a/src/allmydata/util/netstring.py b/src/allmydata/util/netstring.py index 70a14e01..a1fe8cb9 100644 --- a/src/allmydata/util/netstring.py +++ b/src/allmydata/util/netstring.py @@ -4,11 +4,18 @@ def netstring(s): assert isinstance(s, str), s # no unicode here return "%d:%s," % (len(s), s,) -def split_netstring(data, numstrings, allow_leftover=False): +def split_netstring(data, numstrings, + allow_leftover=False, + required_trailer=""): """like string.split(), but extracts netstrings. If allow_leftover=False, - returns numstrings elements, and throws ValueError if there was leftover - data. If allow_leftover=True, returns numstrings+1 elements, in which the - last element is the leftover data (possibly an empty string)""" + I return numstrings elements, and throw ValueError if there was leftover + data that does not exactly equal 'required_trailer'. If + allow_leftover=True, required_trailer must be empty, and I return + numstrings+1 elements, in which the last element is the leftover data + (possibly an empty string)""" + + assert not (allow_leftover and required_trailer) + elements = [] assert numstrings >= 0 while data: @@ -25,7 +32,7 @@ def split_netstring(data, numstrings, allow_leftover=False): raise ValueError("ran out of netstrings") if allow_leftover: return tuple(elements + [data]) - if data: + if data != required_trailer: raise ValueError("leftover data in netstrings") return tuple(elements)