test_web: add HEAD coverage
authorBrian Warner <warner@allmydata.com>
Tue, 20 May 2008 18:47:43 +0000 (11:47 -0700)
committerBrian Warner <warner@allmydata.com>
Tue, 20 May 2008 18:47:43 +0000 (11:47 -0700)
src/allmydata/test/test_web.py

index b058dfe3c25748731fa4a4509922fdc26ce3f374..4f089d1b438458aa4b66ee3b79ff37fef2f2bedc 100644 (file)
@@ -2,7 +2,7 @@ import re, urllib
 import simplejson
 from twisted.application import service
 from twisted.trial import unittest
-from twisted.internet import defer
+from twisted.internet import defer, reactor
 from twisted.web import client, error, http
 from twisted.python import failure, log
 from allmydata import interfaces, provisioning, uri, webish, upload, download
@@ -91,13 +91,19 @@ class FakeClient(service.MultiService):
     def list_all_helper_statuses(self):
         return []
 
+class HTTPClientHEADFactory(client.HTTPClientFactory):
+    def __init__(self, *args, **kwargs):
+        client.HTTPClientFactory.__init__(self, *args, **kwargs)
+        self.deferred.addCallback(lambda res: self.response_headers)
+
+
 class WebMixin(object):
     def setUp(self):
         self.s = FakeClient()
         self.s.startService()
         self.ws = s = webish.WebishServer("0")
         s.setServiceParent(self.s)
-        port = s.listener._port.getHost().port
+        self.webish_port = port = s.listener._port.getHost().port
         self.webish_url = "http://localhost:%d" % port
 
         l = [ self.s.create_empty_dirnode() for x in range(6) ]
@@ -211,6 +217,13 @@ class WebMixin(object):
         url = self.webish_url + urlpath
         return client.getPage(url, method="GET", followRedirect=followRedirect)
 
+    def HEAD(self, urlpath):
+        # this requires some surgery, because twisted.web.client doesn't want
+        # to give us back the response headers.
+        factory = HTTPClientHEADFactory(urlpath)
+        reactor.connectTCP("localhost", self.webish_port, factory)
+        return factory.deferred
+
     def PUT(self, urlpath, data):
         url = self.webish_url + urlpath
         return client.getPage(url, method="PUT", postdata=data)
@@ -463,6 +476,15 @@ class Web(WebMixin, unittest.TestCase):
         d.addCallback(self.failUnlessIsBarDotTxt)
         return d
 
+    def test_HEAD_FILEURL(self):
+        d = self.HEAD(self.public_url + "/foo/bar.txt")
+        def _got(headers):
+            self.failUnlessEqual(headers["content-length"][0],
+                                 str(len(self.BAR_CONTENTS)))
+            self.failUnlessEqual(headers["content-type"], ["text/plain"])
+        d.addCallback(_got)
+        return d
+
     def test_GET_FILEURL_named(self):
         base = "/file/%s" % urllib.quote(self._bar_txt_uri)
         base2 = "/named/%s" % urllib.quote(self._bar_txt_uri)
@@ -1131,6 +1153,15 @@ class Web(WebMixin, unittest.TestCase):
         d.addCallback(lambda res:
                       self.failUnlessEqual(res, EVEN_NEWER_CONTENTS))
 
+        # and that HEAD computes the size correctly
+        d.addCallback(lambda res:
+                      self.HEAD(self.public_url + "/foo/new.txt"))
+        def _got_headers(headers):
+            self.failUnlessEqual(headers["content-length"][0],
+                                 str(len(EVEN_NEWER_CONTENTS)))
+            self.failUnlessEqual(headers["content-type"], ["text/plain"])
+        d.addCallback(_got_headers)
+
         d.addErrback(self.dump_error)
         return d