]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/commitdiff
switch from rfc 3548 base-32 to z-base-32 except for tubids/nodeids
authorZooko O'Whielacronx <zooko@zooko.com>
Tue, 24 Jul 2007 20:46:06 +0000 (13:46 -0700)
committerZooko O'Whielacronx <zooko@zooko.com>
Tue, 24 Jul 2007 20:46:06 +0000 (13:46 -0700)
src/allmydata/introducer.py
src/allmydata/node.py
src/allmydata/test/test_introducer.py
src/allmydata/test/test_util.py
src/allmydata/util/idlib.py

index 8db66b2dc6169de082878a3ccccf838d8676d4dd..d5c4a306fb9c33f0ba3055629566c5d0a8e259c9 100644 (file)
@@ -1,4 +1,6 @@
 
+from base64 import b32encode, b32decode
+
 import re
 from zope.interface import implements
 from twisted.application import service
@@ -92,9 +94,9 @@ class IntroducerClient(service.Service, Referenceable):
         # them, which may or may not be what we want.
         m = re.match(r'pb://(\w+)@', furl)
         assert m
-        nodeid = idlib.a2b(m.group(1))
+        nodeid = b32decode(m.group(1).upper())
         def _got_peer(rref):
-            self.log(" connected to(%s)" % idlib.b2a(nodeid))
+            self.log(" connected to(%s)" % b32encode(nodeid).lower())
             self.connection_observers.notify(nodeid, rref)
             self.connections[nodeid] = rref
             def _lost():
index 941b4d29364da0fba6bc0ce31a8f404c498700eb..41ef9c6901cd27680961e770a40748ea510ee6ff 100644 (file)
@@ -1,3 +1,5 @@
+from base64 import b32encode, b32decode
+
 import os.path, re
 
 import twisted
@@ -5,7 +7,7 @@ from twisted.python import log
 from twisted.application import service
 from twisted.internet import defer, reactor
 from foolscap import Tub, eventual
-from allmydata.util import idlib, iputil, observer
+from allmydata.util import iputil, observer
 from allmydata.util.assertutil import precondition
 
 
@@ -34,9 +36,9 @@ class Node(service.MultiService):
         self.tub = Tub(certFile=certfile)
         self.tub.setOption("logLocalFailures", True)
         self.tub.setOption("logRemoteFailures", True)
-        self.nodeid = idlib.a2b(self.tub.tubID)
+        self.nodeid = b32encode(self.tub.tubID).lower()
         f = open(os.path.join(self.basedir, self.NODEIDFILE), "w")
-        f.write(idlib.b2a(self.nodeid) + "\n")
+        f.write(b32encode(self.nodeid).lower() + "\n")
         f.close()
         self.short_nodeid = self.tub.tubID[:4] # ready for printing
         assert self.PORTNUMFILE, "Your node.Node subclass must provide PORTNUMFILE"
index 5154f4e734d819ed6a220abc72e31b86ae932f22..bbdc47d71435f0b4e78eb191dcef4300a24c5b07 100644 (file)
@@ -1,3 +1,4 @@
+from base64 import b32encode, b32decode
 
 from twisted.trial import unittest
 from twisted.internet import defer, reactor
@@ -7,7 +8,7 @@ from foolscap import Tub, Referenceable
 from foolscap.eventual import flushEventualQueue
 from twisted.application import service
 from allmydata.introducer import IntroducerClient, Introducer
-from allmydata.util import idlib, testutil
+from allmydata.util import testutil
 
 class MyNode(Referenceable):
     pass
@@ -56,7 +57,7 @@ class TestIntroducer(unittest.TestCase, testutil.PollMixin):
         self.waiting_for_connections = NUMCLIENTS*NUMCLIENTS
         d = self._done_counting = defer.Deferred()
         def _count(nodeid, rref):
-            log.msg("NEW CONNECTION! %s %s" % (idlib.b2a(nodeid), rref))
+            log.msg("NEW CONNECTION! %s %s" % (b32encode(nodeid).lower(), rref))
             self.waiting_for_connections -= 1
             if self.waiting_for_connections == 0:
                 self._done_counting.callback("done!")
@@ -92,7 +93,7 @@ class TestIntroducer(unittest.TestCase, testutil.PollMixin):
             origin_c = clients[0]
             # find a target that is not themselves
             for nodeid,rref in origin_c.connections.items():
-                if idlib.b2a(nodeid) != tubs[origin_c].tubID:
+                if b32encode(nodeid).lower() != tubs[origin_c].tubID:
                     victim = rref
                     break
             log.msg(" disconnecting %s->%s" % (tubs[origin_c].tubID, victim))
@@ -111,7 +112,7 @@ class TestIntroducer(unittest.TestCase, testutil.PollMixin):
             origin_c = clients[0]
             # find a target that *is* themselves
             for nodeid,rref in origin_c.connections.items():
-                if idlib.b2a(nodeid) == tubs[origin_c].tubID:
+                if b32encode(nodeid).lower() == tubs[origin_c].tubID:
                     victim = rref
                     break
             log.msg(" disconnecting %s->%s" % (tubs[origin_c].tubID, victim))
index 0b7138cde4f6989700ecfdf886a2994015f070bb..118dd6a799b8304e6cf9e7c3d18a2b172e3f2590 100644 (file)
@@ -10,18 +10,13 @@ from allmydata.util import assertutil, fileutil
 
 class IDLib(unittest.TestCase):
     def test_b2a(self):
-        self.failUnlessEqual(idlib.b2a("\x12\x34"), "ci2a====")
+        self.failUnlessEqual(idlib.b2a("\x12\x34"), "ne4y")
     def test_b2a_or_none(self):
         self.failUnlessEqual(idlib.b2a_or_none(None), None)
-        self.failUnlessEqual(idlib.b2a_or_none("\x12\x34"), "ci2a====")
+        self.failUnlessEqual(idlib.b2a_or_none("\x12\x34"), "ne4y")
     def test_a2b(self):
-        self.failUnlessEqual(idlib.a2b("ci2a===="), "\x12\x34")
-        self.failUnlessRaises(TypeError, idlib.a2b, "bogus")
-    def test_peerid(self):
-        # these are 160-bit numbers
-        peerid = "\x80" + "\x00" * 19
-        short = idlib.peerid_to_short_string(peerid)
-        self.failUnlessEqual(short, "qaaa")
+        self.failUnlessEqual(idlib.a2b("ne4y"), "\x12\x34")
+        self.failUnlessRaises(AssertionError, idlib.a2b, "b0gus")
 
 class NoArgumentException(Exception):
     def __init__(self):
index 185b3a3bf79a29e45dde065a41fb9f4303ac90d2..0f89a33c3911cd1913aefe788d6f16b7382684cb 100644 (file)
-from base64 import b32encode, b32decode
+# from the Python Standard Library
+import string
 
-def b2a(i):
-    assert isinstance(i, str), "tried to idlib.b2a non-string '%s'" % (i,)
-    return b32encode(i).lower()
+from assertutil import _assert, precondition
 
-def b2a_or_none(i):
-    if i is None:
-        return None
-    return b2a(i)
+z_base_32_alphabet = "ybndrfg8ejkmcpqxot1uwisza345h769" # Zooko's choice, rationale in "DESIGN" doc
+rfc3548_alphabet = "abcdefghijklmnopqrstuvwxyz234567" # RFC3548 standard used by Gnutella, Content-Addressable Web, THEX, Bitzi, Web-Calculus...
+chars = z_base_32_alphabet
 
-def a2b(i):
-    assert isinstance(i, str), "tried to idlib.a2b non-string '%s'" % (i,)
-    try:
-        return b32decode(i.upper())
-    except TypeError:
-        print "b32decode failed on a %s byte string '%s'" % (len(i), i)
-        raise
+vals = ''.join(map(chr, range(32)))
+c2vtranstable = string.maketrans(chars, vals)
+v2ctranstable = string.maketrans(vals, chars)
+identitytranstable = string.maketrans(chars, chars)
 
+def _get_trailing_chars_without_lsbs(N, d):
+    """
+    @return: a list of chars that can legitimately appear in the last place when the least significant N bits are ignored.
+    """
+    s = []
+    if N < 4:
+        s.extend(_get_trailing_chars_without_lsbs(N+1, d=d))
+    i = 0
+    while i < len(chars):
+        if not d.has_key(i):
+            d[i] = None
+            s.append(chars[i])
+        i = i + 2**N
+    return s
+
+def get_trailing_chars_without_lsbs(N):
+    precondition((N >= 0) and (N < 5), "N is required to be > 0 and < len(chars).", N=N)
+    if N == 0:
+        return chars
+    d = {}
+    return ''.join(_get_trailing_chars_without_lsbs(N, d=d))
+
+def b2a(os):
+    """
+    @param os the data to be encoded (a string)
+
+    @return the contents of os in base-32 encoded form
+    """
+    return b2a_l(os, len(os)*8)
+
+def b2a_or_none(os):
+    if os is not None:
+        return b2a(os)
+        
+def b2a_l(os, lengthinbits):
+    """
+    @param os the data to be encoded (a string)
+    @param lengthinbits the number of bits of data in os to be encoded
+
+    b2a_l() will generate a base-32 encoded string big enough to encode lengthinbits bits.  So for
+    example if os is 2 bytes long and lengthinbits is 15, then b2a_l() will generate a 3-character-
+    long base-32 encoded string (since 3 quintets is sufficient to encode 15 bits).  If os is
+    2 bytes long and lengthinbits is 16 (or None), then b2a_l() will generate a 4-character string.
+    Note that b2a_l() does not mask off unused least-significant bits, so for example if os is
+    2 bytes long and lengthinbits is 15, then you must ensure that the unused least-significant bit
+    of os is a zero bit or you will get the wrong result.  This precondition is tested by assertions
+    if assertions are enabled.
+
+    Warning: if you generate a base-32 encoded string with b2a_l(), and then someone else tries to
+    decode it by calling a2b() instead of  a2b_l(), then they will (probably) get a different
+    string than the one you encoded!  So only use b2a_l() when you are sure that the encoding and
+    decoding sides know exactly which lengthinbits to use.  If you do not have a way for the
+    encoder and the decoder to agree upon the lengthinbits, then it is best to use b2a() and
+    a2b().  The only drawback to using b2a() over b2a_l() is that when you have a number of
+    bits to encode that is not a multiple of 8, b2a() can sometimes generate a base-32 encoded
+    string that is one or two characters longer than necessary.
+
+    @return the contents of os in base-32 encoded form
+    """
+    precondition(isinstance(lengthinbits, (int, long,)), "lengthinbits is required to be an integer.", lengthinbits=lengthinbits)
+    precondition((lengthinbits+7)/8 == len(os), "lengthinbits is required to specify a number of bits storable in exactly len(os) octets.", lengthinbits=lengthinbits, lenos=len(os))
+
+    os = map(ord, os)
+
+    numquintets = (lengthinbits+4)/5
+    numoctetsofdata = (lengthinbits+7)/8
+    # print "numoctetsofdata: %s, len(os): %s, lengthinbits: %s, numquintets: %s" % (numoctetsofdata, len(os), lengthinbits, numquintets,)
+    # strip trailing octets that won't be used
+    del os[numoctetsofdata:]
+    # zero out any unused bits in the final octet
+    if lengthinbits % 8 != 0:
+        os[-1] = os[-1] >> (8-(lengthinbits % 8))
+        os[-1] = os[-1] << (8-(lengthinbits % 8))
+    # append zero octets for padding if needed
+    numoctetsneeded = (numquintets*5+7)/8 + 1
+    os.extend([0]*(numoctetsneeded-len(os)))
+
+    quintets = []
+    cutoff = 256
+    num = os[0]
+    i = 0
+    while len(quintets) < numquintets:
+        i = i + 1
+        assert len(os) > i, "len(os): %s, i: %s, len(quintets): %s, numquintets: %s, lengthinbits: %s, numoctetsofdata: %s, numoctetsneeded: %s, os: %s" % (len(os), i, len(quintets), numquintets, lengthinbits, numoctetsofdata, numoctetsneeded, os,)
+        num = num * 256
+        num = num + os[i]
+        if cutoff == 1:
+            cutoff = 256
+            continue
+        cutoff = cutoff * 8
+        quintet = num / cutoff
+        quintets.append(quintet)
+        num = num - (quintet * cutoff)
+
+        cutoff = cutoff / 32
+        quintet = num / cutoff
+        quintets.append(quintet)
+        num = num - (quintet * cutoff)
+
+    if len(quintets) > numquintets:
+        assert len(quintets) == (numquintets+1), "len(quintets): %s, numquintets: %s, quintets: %s" % (len(quintets), numquintets, quintets,)
+        quintets = quintets[:numquintets]
+    res = string.translate(string.join(map(chr, quintets), ''), v2ctranstable)
+    assert could_be_base32_encoded_l(res, lengthinbits), "lengthinbits: %s, res: %s" % (lengthinbits, res,)
+    return res
+
+# b2a() uses the minimal number of quintets sufficient to encode the binary
+# input.  It just so happens that the relation is like this (everything is
+# modulo 40 bits).
+# num_qs = NUM_OS_TO_NUM_QS[num_os]
+NUM_OS_TO_NUM_QS=(0, 2, 4, 5, 7,)
+
+# num_os = NUM_QS_TO_NUM_OS[num_qs], but if not NUM_QS_LEGIT[num_qs] then
+# there is *no* number of octets which would have resulted in this number of
+# quintets, so either the encoded string has been mangled (truncated) or else
+# you were supposed to decode it with a2b_l() (which means you were supposed
+# to know the actual length of the encoded data).
+
+NUM_QS_TO_NUM_OS=(0, 1, 1, 2, 2, 3, 3, 4)
+NUM_QS_LEGIT=(1, 0, 1, 0, 1, 1, 0, 1,)
+NUM_QS_TO_NUM_BITS=tuple(map(lambda x: x*8, NUM_QS_TO_NUM_OS))
+
+# A fast way to determine whether a given string *could* be base-32 encoded data, assuming that the
+# original data had 8K bits for a positive integer K.
+# The boolean value of s8[len(s)%8][ord(s[-1])], where s is the possibly base-32 encoded string
+# tells whether the final character is reasonable.
+def add_check_array(cs, sfmap):
+    checka=[0] * 256
+    for c in cs:
+        checka[ord(c)] = 1
+    sfmap.append(tuple(checka))
+
+def init_s8():
+    s8 = []
+    add_check_array(chars, s8)
+    for lenmod8 in (1, 2, 3, 4, 5, 6, 7,):
+        if NUM_QS_LEGIT[lenmod8]:
+            add_check_array(get_trailing_chars_without_lsbs(4-(NUM_QS_TO_NUM_BITS[lenmod8]%5)), s8)
+        else:
+            add_check_array('', s8)
+    return tuple(s8)
+s8 = init_s8()
+
+# A somewhat fast way to determine whether a given string *could* be base-32 encoded data, given a
+# lengthinbits.
+# The boolean value of s5[lengthinbits%5][ord(s[-1])], where s is the possibly base-32 encoded
+# string tells whether the final character is reasonable.
+def init_s5():
+    s5 = []
+    add_check_array(chars, s5)
+    for lenmod5 in (1, 2, 3, 4,):
+        add_check_array(get_trailing_chars_without_lsbs(4-lenmod5), s5)
+    return tuple(s5)
+s5 = init_s5()
+
+def could_be_base32_encoded(s, s8=s8, tr=string.translate, identitytranstable=identitytranstable, chars=chars):
+    if s == '':
+        return True
+    return s8[len(s)%8][ord(s[-1])] and not tr(s, identitytranstable, chars)
+
+def could_be_base32_encoded_l(s, lengthinbits, s5=s5, tr=string.translate, identitytranstable=identitytranstable, chars=chars):
+    if s == '':
+        return True
+    assert lengthinbits%5 < len(s5), lengthinbits
+    assert ord(s[-1]) < s5[lengthinbits%5]
+    return (((lengthinbits+4)/5) == len(s)) and s5[lengthinbits%5][ord(s[-1])] and not string.translate(s, identitytranstable, chars)
+
+def num_octets_that_encode_to_this_many_quintets(numqs):
+    # Here is a computation that conveniently expresses this:
+    return (numqs*5+3)/8
+
+def a2b(cs):
+    """
+    @param cs the base-32 encoded data (a string)
+    """
+    precondition(could_be_base32_encoded(cs), "cs is required to be possibly base32 encoded data.", cs=cs)
+
+    return a2b_l(cs, num_octets_that_encode_to_this_many_quintets(len(cs))*8)
+
+def a2b_l(cs, lengthinbits):
+    """
+    @param lengthinbits the number of bits of data in encoded into cs
+
+    a2b_l() will return a result big enough to hold lengthinbits bits.  So for example if cs is
+    4 characters long (encoding at least 15 and up to 20 bits) and lengthinbits is 16, then a2b_l()
+    will return a string of length 2 (since 2 bytes is sufficient to store 16 bits).  If cs is 4
+    characters long and lengthinbits is 20, then a2b_l() will return a string of length 3 (since
+    3 bytes is sufficient to store 20 bits).  Note that b2a_l() does not mask off unused least-
+    significant bits, so for example if cs is 4 characters long and lengthinbits is 17, then you
+    must ensure that all three of the unused least-significant bits of cs are zero bits or you will
+    get the wrong result.  This precondition is tested by assertions if assertions are enabled.
+    (Generally you just require the encoder to ensure this consistency property between the least
+    significant zero bits and value of lengthinbits, and reject strings that have a length-in-bits
+    which isn't a multiple of 8 and yet don't have trailing zero bits, as improperly encoded.)
+
+    Please see the warning in the docstring of b2a_l() regarding the use of b2a() versus b2a_l().
+
+    @return the data encoded in cs
+    """
+    precondition(could_be_base32_encoded_l(cs, lengthinbits), "cs is required to be possibly base32 encoded data.", cs=cs, lengthinbits=lengthinbits)
+    if cs == '':
+        return ''
+
+    qs = map(ord, string.translate(cs, c2vtranstable))
+
+    numoctets = (lengthinbits+7)/8
+    numquintetsofdata = (lengthinbits+4)/5
+    # strip trailing quintets that won't be used
+    del qs[numquintetsofdata:]
+    # zero out any unused bits in the final quintet
+    if lengthinbits % 5 != 0:
+        qs[-1] = qs[-1] >> (5-(lengthinbits % 5))
+        qs[-1] = qs[-1] << (5-(lengthinbits % 5))
+    # append zero quintets for padding if needed
+    numquintetsneeded = (numoctets*8+4)/5
+    qs.extend([0]*(numquintetsneeded-len(qs)))
+
+    octets = []
+    pos = 2048
+    num = qs[0] * pos
+    readybits = 5
+    i = 1
+    while len(octets) < numoctets:
+        while pos > 256:
+            pos = pos / 32
+            num = num + (qs[i] * pos)
+            i = i + 1
+        octet = num / 256
+        octets.append(octet)
+        num = num - (octet * 256)
+        num = num * 256
+        pos = pos * 256
+    assert len(octets) == numoctets, "len(octets): %s, numoctets: %s, octets: %s" % (len(octets), numoctets, octets,)
+    res = ''.join(map(chr, octets))
+    precondition(b2a_l(res, lengthinbits) == cs, "cs is required to be the canonical base-32 encoding of some data.", b2a(res), res=res, cs=cs)
+    return res
 
-def peerid_to_short_string(peerid):
-    return b2a(peerid)[:4]