]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/commitdiff
cli: simplify code by using stdlib's httplib module
authorBrian Warner <warner@allmydata.com>
Fri, 12 Oct 2007 05:29:23 +0000 (22:29 -0700)
committerBrian Warner <warner@allmydata.com>
Fri, 12 Oct 2007 05:29:23 +0000 (22:29 -0700)
src/allmydata/scripts/common.py
src/allmydata/scripts/common_http.py [new file with mode: 0644]
src/allmydata/scripts/tahoe_mv.py
src/allmydata/scripts/tahoe_put.py
src/allmydata/scripts/tahoe_rm.py

index 64cceb88e57cab62312766dc066584e0a3f7bb7a..794d9868ba9194c277921e3012888d433a5757ba 100644 (file)
@@ -62,3 +62,4 @@ class NoDefaultBasedirMixin(BasedirMixin):
         if not self.basedirs:
             raise usage.UsageError("--basedir must be provided")
 
+
diff --git a/src/allmydata/scripts/common_http.py b/src/allmydata/scripts/common_http.py
new file mode 100644 (file)
index 0000000..fbeaa59
--- /dev/null
@@ -0,0 +1,60 @@
+
+from cStringIO import StringIO
+import urlparse, httplib
+import allmydata # for __version__
+
+# copied from twisted/web/client.py
+def parse_url(url, defaultPort=None):
+    url = url.strip()
+    parsed = urlparse.urlparse(url)
+    scheme = parsed[0]
+    path = urlparse.urlunparse(('','')+parsed[2:])
+    if defaultPort is None:
+        if scheme == 'https':
+            defaultPort = 443
+        else:
+            defaultPort = 80
+    host, port = parsed[1], defaultPort
+    if ':' in host:
+        host, port = host.split(':')
+        port = int(port)
+    if path == "":
+        path = "/"
+    return scheme, host, port, path
+
+
+def do_http(method, url, body=""):
+    if isinstance(body, str):
+        body = StringIO(body)
+    elif isinstance(body, unicode):
+        raise RuntimeError("do_http body must be a bytestring, not unicode")
+    else:
+        assert body.tell
+        assert body.seek
+        assert body.read
+    scheme, host, port, path = parse_url(url)
+    if scheme == "http":
+        c = httplib.HTTPConnection(host, port)
+    elif scheme == "https":
+        c = httplib.HTTPSConnection(host, port)
+    else:
+        raise ValueError("unknown scheme '%s', need http or https" % scheme)
+    c.putrequest(method, path)
+    c.putheader("Hostname", host)
+    c.putheader("User-Agent", "tahoe_cli/%s" % allmydata.__version__)
+    c.putheader("Connection", "close")
+
+    old = body.tell()
+    body.seek(0, 2)
+    length = body.tell()
+    body.seek(old)
+    c.putheader("Content-Length", str(length))
+    c.endheaders()
+
+    while True:
+        data = body.read(8192)
+        if not data:
+            break
+        c.send(data)
+
+    return c.getresponse()
index 4020c6e8c383972b322c2b0a2268b8ad7f19dbac..2bb49c4e6bbb977be39f61b06b46ffd0dd6080c0 100644 (file)
@@ -1,44 +1,9 @@
 #! /usr/bin/python
 
 import re
-import urllib, httplib
-import urlparse
+import urllib
 import simplejson
-
-# copied from twisted/web/client.py
-def _parse(url, defaultPort=None):
-    url = url.strip()
-    parsed = urlparse.urlparse(url)
-    scheme = parsed[0]
-    path = urlparse.urlunparse(('','')+parsed[2:])
-    if defaultPort is None:
-        if scheme == 'https':
-            defaultPort = 443
-        else:
-            defaultPort = 80
-    host, port = parsed[1], defaultPort
-    if ':' in host:
-        host, port = host.split(':')
-        port = int(port)
-    if path == "":
-        path = "/"
-    return scheme, host, port, path
-
-def do_http(method, url, body=""):
-    scheme, host, port, path = _parse(url)
-    if scheme == "http":
-        c = httplib.HTTPConnection(host, port)
-    elif scheme == "https":
-        c = httplib.HTTPSConnection(host, port)
-    else:
-        raise ValueError("unknown scheme '%s', need http or https" % scheme)
-    c.putrequest(method, path)
-    import allmydata
-    c.putheader("User-Agent", "tahoe_mv/%s" % allmydata.__version__)
-    c.putheader("Content-Length", str(len(body)))
-    c.endheaders()
-    c.send(body)
-    return c.getresponse()
+from allmydata.scripts.common_http import do_http
 
 def mv(nodeurl, root_uri, frompath, topath, stdout, stderr):
     if nodeurl[-1] != "/":
@@ -48,6 +13,9 @@ def mv(nodeurl, root_uri, frompath, topath, stdout, stderr):
 
     nodetype, attrs = simplejson.loads(data)
     uri = attrs.get("rw_uri") or attrs["ro_uri"]
+    # simplejson always returns unicode, but we know that it's really just a
+    # bytestring.
+    uri = str(uri)
 
     put_url = url + topath + "?t=uri"
     resp = do_http("PUT", put_url, uri)
index e2b0feee14aa5b177401e259fdf9adcffe2aefb3..244e5e2af3a1bcf12c69f16043fb95f92fc74604 100644 (file)
@@ -1,8 +1,7 @@
 #!/usr/bin/env python
 
-import re, socket, urllib
-
-NODEURL_RE=re.compile("http://([^:]*)(:([1-9][0-9]*))?")
+import urllib
+from allmydata.scripts.common_http import do_http
 
 def put(nodeurl, root_uri, local_fname, vdrive_fname, verbosity,
         stdout, stderr):
@@ -11,66 +10,21 @@ def put(nodeurl, root_uri, local_fname, vdrive_fname, verbosity,
 
     @return: a Deferred which eventually fires with the exit code
     """
-    mo = NODEURL_RE.match(nodeurl)
-    host = mo.group(1)
-    port = int(mo.group(3))
-
-    url = "/uri/%s/" % urllib.quote(root_uri.replace("/","!"))
+    if nodeurl[-1] != "/":
+        nodeurl += "/"
+    url = nodeurl + "uri/%s/" % urllib.quote(root_uri.replace("/","!"))
     if vdrive_fname:
         url += vdrive_fname
 
     infileobj = open(local_fname, "rb")
-    infileobj.seek(0, 2)
-    infilelen = infileobj.tell()
-    infileobj.seek(0, 0)
-
-    so = socket.socket()
-    so.connect((host, port,))
-
-    CHUNKSIZE=2**16
-    data = "PUT %s HTTP/1.1\r\nConnection: close\r\nContent-Length: %s\r\nHostname: %s\r\n\r\n" % (url, infilelen, host,)
-    while data:
-        try:
-            sent = so.send(data)
-        except Exception, le:
-            print >>stderr, "got socket error: %s" % (le,)
-            return -1
-
-        if sent == len(data):
-            data = infileobj.read(CHUNKSIZE)
-        else:
-            data = data[sent:]
+    resp = do_http("PUT", url, infileobj)
 
-    respbuf = []
-    data = so.recv(CHUNKSIZE)
-    while data:
-        respbuf.append(data)
-        data = so.recv(CHUNKSIZE)
+    if resp.status in (200, 201,):
+        print >>stdout, "%s %s" % (resp.status, resp.reason)
+        return 0
 
-    so.shutdown(socket.SHUT_WR)
-
-    data = so.recv(CHUNKSIZE)
-    while data:
-        respbuf.append(data)
-        data = so.recv(CHUNKSIZE)
-
-    respstr = ''.join(respbuf)
-
-    headerend = respstr.find('\r\n\r\n')
-    if headerend == -1:
-        headerend = len(respstr)
-    header = respstr[:headerend]
-    RESP_RE=re.compile("^HTTP/[0-9]\.[0-9] ([0-9]*) *([A-Za-z_ ]*)")  # This regex is soooo ad hoc...  --Zooko 2007-08-16
-    mo = RESP_RE.match(header)
-    if mo:
-        code = int(mo.group(1))
-        word = mo.group(2)
-
-        if code in (200, 201,):
-            print >>stdout, "%s %s" % (code, word,)
-            return 0
-    
-    print >>stderr, respstr[headerend:]
+    print >>stderr, "error, got %s %s" % (resp.status, resp.reason)
+    print >>stderr, resp.read()
     return 1
 
 def main():
index ee2acfe65f1790ed6586f252d3e182b679169595..abc58542055089d62a5f5f3d91a28a5d3f9996be 100644 (file)
@@ -1,8 +1,7 @@
 #!/usr/bin/env python
 
-import re, socket, urllib
-
-NODEURL_RE=re.compile("http://([^:]*)(:([1-9][0-9]*))?")
+import urllib
+from allmydata.scripts.common_http import do_http
 
 def rm(nodeurl, root_uri, vdrive_pathname, verbosity, stdout, stderr):
     """
@@ -10,51 +9,20 @@ def rm(nodeurl, root_uri, vdrive_pathname, verbosity, stdout, stderr):
 
     @return: a Deferred which eventually fires with the exit code
     """
-    mo = NODEURL_RE.match(nodeurl)
-    host = mo.group(1)
-    port = int(mo.group(3))
-
-    url = "/uri/%s/" % urllib.quote(root_uri.replace("/","!"))
+    if nodeurl[-1] != "/":
+        nodeurl += "/"
+    url = nodeurl + "uri/%s/" % urllib.quote(root_uri.replace("/","!"))
     if vdrive_pathname:
         url += vdrive_pathname
 
-    so = socket.socket()
-    so.connect((host, port,))
-
-    CHUNKSIZE=2**16
-    data = "DELETE %s HTTP/1.1\r\nConnection: close\r\nHostname: %s\r\n\r\n" % (url, host,)
-    sent = so.send(data)
-
-    respbuf = []
-    data = so.recv(CHUNKSIZE)
-    while data:
-        respbuf.append(data)
-        data = so.recv(CHUNKSIZE)
+    resp = do_http("DELETE", url)
 
-    so.shutdown(socket.SHUT_WR)
+    if resp.status in (200,):
+        print >>stdout, "%s %s" % (resp.status, resp.reason)
+        return 0
 
-    data = so.recv(CHUNKSIZE)
-    while data:
-        respbuf.append(data)
-        data = so.recv(CHUNKSIZE)
-
-    respstr = ''.join(respbuf)
-
-    headerend = respstr.find('\r\n\r\n')
-    if headerend == -1:
-        headerend = len(respstr)
-    header = respstr[:headerend]
-    RESP_RE=re.compile("^HTTP/[0-9]\.[0-9] ([0-9]*) *([A-Za-z_ ]*)")  # This regex is soooo ad hoc...  --Zooko 2007-08-16
-    mo = RESP_RE.match(header)
-    if mo:
-        code = int(mo.group(1))
-        word = mo.group(2)
-
-        if code == 200:
-            print >>stdout, "%s %s" % (code, word,)
-            return 0
-    
-    print >>stderr, respstr[headerend:]
+    print >>stderr, "error, got %s %s" % (resp.status, resp.reason)
+    print >>stderr, resp.read()
     return 1
 
 def main():