From ccda06b0612200c3fc86888162ac05ad56700fc2 Mon Sep 17 00:00:00 2001
From: robk-tahoe <robk-tahoe@allmydata.com>
Date: Thu, 3 Apr 2008 13:01:43 -0700
Subject: [PATCH] key_generator: added a unit test

implemented a unit test of basic KeyGenService functionality,
fixed a bug in the timing of pool refreshes
---
 src/allmydata/key_generator.py    | 10 ++--
 src/allmydata/test/test_keygen.py | 98 +++++++++++++++++++++++++++++++
 2 files changed, 104 insertions(+), 4 deletions(-)
 create mode 100644 src/allmydata/test/test_keygen.py

diff --git a/src/allmydata/key_generator.py b/src/allmydata/key_generator.py
index aa735c4c..35ada63d 100644
--- a/src/allmydata/key_generator.py
+++ b/src/allmydata/key_generator.py
@@ -54,6 +54,7 @@ class KeyGenerator(foolscap.Referenceable):
     def remote_get_rsa_key_pair(self, key_size):
         self.vlog('%s remote_get_key' % (self,))
         if key_size != self.DEFAULT_KEY_SIZE or not self.keypool:
+            self.reset_timer()
             return self.gen_key(key_size)
         else:
             self.reset_timer()
@@ -62,7 +63,7 @@ class KeyGenerator(foolscap.Referenceable):
 class KeyGeneratorService(service.MultiService):
     furl_file = 'key_generator.furl'
 
-    def __init__(self):
+    def __init__(self, display_furl=True):
         service.MultiService.__init__(self)
         self.tub = foolscap.Tub(certFile='key_generator.pem')
         self.tub.setServiceParent(self)
@@ -73,7 +74,7 @@ class KeyGeneratorService(service.MultiService):
         d = self.tub.setLocationAutomatically()
         if portnum is None:
             d.addCallback(self.save_portnum)
-        d.addCallback(self.tub_ready)
+        d.addCallback(self.tub_ready, display_furl)
         d.addErrback(log.err)
 
     def get_portnum(self):
@@ -84,6 +85,7 @@ class KeyGeneratorService(service.MultiService):
         portnum = self.listener.getPortnum()
         file('portnum', 'wb').write('%d\n' % (portnum,))
 
-    def tub_ready(self, junk):
+    def tub_ready(self, junk, display_furl):
         self.keygen_furl = self.tub.registerReference(self.key_generator, furlFile=self.furl_file)
-        print 'key generator at:', self.keygen_furl 
+        if display_furl:
+            print 'key generator at:', self.keygen_furl 
diff --git a/src/allmydata/test/test_keygen.py b/src/allmydata/test/test_keygen.py
new file mode 100644
index 00000000..692681a6
--- /dev/null
+++ b/src/allmydata/test/test_keygen.py
@@ -0,0 +1,98 @@
+
+import os
+from twisted.trial import unittest
+from twisted.application import service
+
+from foolscap import Tub, eventual
+
+from allmydata import key_generator
+from allmydata.util import testutil
+from pycryptopp.publickey import rsa
+
+def flush_but_dont_ignore(res):
+    d = eventual.flushEventualQueue()
+    def _done(ignored):
+        return res
+    d.addCallback(_done)
+    return d
+
+class KeyGenService(unittest.TestCase, testutil.PollMixin):
+    def setUp(self):
+        self.parent = service.MultiService()
+        self.parent.startService()
+
+        self.tub = t = Tub()
+        t.setServiceParent(self.parent)
+        t.listenOn("tcp:0")
+        t.setLocationAutomatically()
+
+    def tearDown(self):
+        d = self.parent.stopService()
+        d.addCallback(eventual.fireEventually)
+        d.addBoth(flush_but_dont_ignore)
+        return d
+
+    def test_key_gen_service(self):
+        def p(junk, msg):
+            #import time
+            #print time.asctime(), msg
+            return junk
+
+        #print 'starting key generator service'
+        kgs = key_generator.KeyGeneratorService(display_furl=False)
+        kgs.key_generator.verbose = True
+        kgs.setServiceParent(self.parent)
+        kgs.key_generator.pool_size = 8
+        keysize = kgs.key_generator.DEFAULT_KEY_SIZE
+
+        def keypool_full():
+            return len(kgs.key_generator.keypool) == kgs.key_generator.pool_size
+
+        # first wait for key gen pool to fill up
+        d = eventual.fireEventually()
+        d.addCallback(p, 'waiting for pool to fill up')
+        d.addCallback(lambda junk: self.poll(keypool_full, timeout=16))
+        
+        d.addCallback(p, 'grabbing a few keys')
+        # grab a few keys, check that pool size shrinks
+        def get_key(junk=None):
+            d = self.tub.getReference(kgs.keygen_furl)
+            d.addCallback(lambda kg: kg.callRemote('get_rsa_key_pair', keysize))
+            return d
+
+        def check_poolsize(junk, size):
+            self.failUnlessEqual(len(kgs.key_generator.keypool), size)
+
+        n_keys_to_waste = 4
+        for i in range(n_keys_to_waste):
+            d.addCallback(get_key)
+        d.addCallback(check_poolsize, kgs.key_generator.pool_size - n_keys_to_waste)
+
+        d.addCallback(p, 'checking a key works')
+        # check that a retrieved key is actually useful
+        d.addCallback(get_key)
+        def check_key_works(keys):
+            verifying_key, signing_key = keys
+            v = rsa.create_verifying_key_from_string(verifying_key)
+            s = rsa.create_signing_key_from_string(signing_key)
+            junk = os.urandom(42)
+            sig = s.sign(junk)
+            self.failUnless(v.verify(junk, sig))
+        d.addCallback(check_key_works)
+
+        d.addCallback(p, 'checking pool exhaustion')
+        # exhaust the pool
+        for i in range(kgs.key_generator.pool_size):
+            d.addCallback(get_key)
+        d.addCallback(check_poolsize, 0)
+
+        # and check it still works (will gen key synchronously on demand)
+        d.addCallback(get_key)
+        d.addCallback(check_key_works)
+        
+        d.addCallback(p, 'checking pool replenishment')
+        # check that the pool will refill
+        timeout = 2*kgs.key_generator.pool_size + kgs.key_generator.pool_refresh_delay
+        d.addCallback(lambda junk: self.poll(keypool_full, timeout=timeout))
+
+        return d
-- 
2.45.2