From 38aee94a3eb3c007cc6cb1d7cb5b04b66f5d8f99 Mon Sep 17 00:00:00 2001
From: Jean-Paul Calderone <exarkun@twistedmatrix.com>
Date: Sun, 4 Jan 2015 09:44:56 -0500
Subject: [PATCH] Add the rest of the failure-case tests and a success-case
 test.  Update the implementation to make them pass.

---
 src/allmydata/frontends/auth.py | 77 +++++++++++++++++++++++++++++----
 src/allmydata/test/test_auth.py | 58 ++++++++++++++++++++++++-
 2 files changed, 125 insertions(+), 10 deletions(-)

diff --git a/src/allmydata/frontends/auth.py b/src/allmydata/frontends/auth.py
index 82ef1c65..745adbe8 100644
--- a/src/allmydata/frontends/auth.py
+++ b/src/allmydata/frontends/auth.py
@@ -3,6 +3,9 @@ from zope.interface import implements
 from twisted.web.client import getPage
 from twisted.internet import defer
 from twisted.cred import error, checkers, credentials
+from twisted.conch import error as conch_error
+from twisted.conch.ssh import keys
+
 from allmydata.util import base32
 
 class NeedRootcapLookupScheme(Exception):
@@ -18,7 +21,8 @@ class FTPAvatarID:
 class AccountFileChecker:
     implements(checkers.ICredentialsChecker)
     credentialInterfaces = (credentials.IUsernamePassword,
-                            credentials.IUsernameHashedPassword)
+                            credentials.IUsernameHashedPassword,
+                            credentials.ISSHPrivateKey)
     def __init__(self, client, accountfile):
         self.client = client
         self.passwords = {}
@@ -31,7 +35,7 @@ class AccountFileChecker:
             name, passwd, rest = line.split(None, 2)
             if passwd in ("ssh-dss", "ssh-rsa"):
                 bits = rest.split()
-                keystring = " ".join(bits[-1])
+                keystring = " ".join([passwd] + bits[:-1])
                 rootcap = bits[-1]
                 self.pubkeys[name] = keystring
             else:
@@ -44,12 +48,69 @@ class AccountFileChecker:
             return FTPAvatarID(username, self.rootcaps[username])
         raise error.UnauthorizedLogin
 
-    def requestAvatarId(self, credentials):
-        if credentials.username in self.passwords:
-            d = defer.maybeDeferred(credentials.checkPassword,
-                                    self.passwords[credentials.username])
-            d.addCallback(self._cbPasswordMatch, str(credentials.username))
-            return d
+    def requestAvatarId(self, creds):
+        if credentials.ISSHPrivateKey.providedBy(creds):
+            # Re-using twisted.conch.checkers.SSHPublicKeyChecker here, rather
+            # than re-implementing all of the ISSHPrivateKey checking logic,
+            # would be better.  That would require Twisted 14.1.0 or newer,
+            # though.
+            return self._checkKey(creds)
+        elif credentials.IUsernameHashedPassword.providedBy(creds):
+            return self._checkPassword(creds)
+        elif credentials.IUsernamePassword.providedBy(creds):
+            return self._checkPassword(creds)
+        else:
+            raise NotImplementedError()
+
+    def _checkPassword(self, creds):
+        """
+        Determine whether the password in the given credentials matches the
+        password in the account file.
+
+        Returns a Deferred that fires with the username if the password matches
+        or with an UnauthorizedLogin failure otherwise.
+        """
+        try:
+            correct = self.passwords[creds.username]
+        except KeyError:
+            return defer.fail(error.UnauthorizedLogin())
+
+        d = defer.maybeDeferred(creds.checkPassword, correct)
+        d.addCallback(self._cbPasswordMatch, str(creds.username))
+        return d
+
+    def _allowedKey(self, creds):
+        """
+        Determine whether the public key indicated by the given credentials is
+        one allowed to authenticate the username in those credentials.
+
+        Returns True if so, False otherwise.
+        """
+        return creds.blob == self.pubkeys.get(creds.username)
+
+    def _correctSignature(self, creds):
+        """
+        Determine whether the signature in the given credentials is the correct
+        signature for the data in those credentials.
+
+        Returns True if so, False otherwise.
+        """
+        key = keys.Key.fromString(creds.blob)
+        return key.verify(creds.signature, creds.sigData)
+
+    def _checkKey(self, creds):
+        """
+        Determine whether some key-based credentials correctly authenticates a
+        user.
+
+        Returns a Deferred that fires with the username if so or with an
+        UnauthorizedLogin failure otherwise.
+        """
+        if self._allowedKey(creds):
+            if creds.signature is None:
+                return defer.fail(conch_error.ValidPublicKey())
+            if self._correctSignature(creds):
+                return defer.succeed(creds.username)
         return defer.fail(error.UnauthorizedLogin())
 
 class AccountURLChecker:
diff --git a/src/allmydata/test/test_auth.py b/src/allmydata/test/test_auth.py
index 52705788..46c2fbfb 100644
--- a/src/allmydata/test/test_auth.py
+++ b/src/allmydata/test/test_auth.py
@@ -1,6 +1,7 @@
 from twisted.trial import unittest
 from twisted.python import filepath
 from twisted.cred import error, credentials
+from twisted.conch import error as conch_error
 from twisted.conch.ssh import keys
 
 from allmydata.frontends import auth
@@ -26,8 +27,8 @@ dBSD8940XU3YW+oeq8e+p3yQ2GinHfeJ3BYQyNQLuMAJ
 DUMMY_ACCOUNTS = u"""\
 alice password URI:DIR2:aaaaaaaaaaaaaaaaaaaaaaaaaa:1111111111111111111111111111111111111111111111111111
 bob sekrit URI:DIR2:bbbbbbbbbbbbbbbbbbbbbbbbbb:2222222222222222222222222222222222222222222222222222
-carol %(key)s URI:DIR2:cccccccccccccccccccccccccc:3333333333333333333333333333333333333333333333333333
-""".format(DUMMY_KEY.public().toString("openssh")).encode("ascii")
+carol {key} URI:DIR2:cccccccccccccccccccccccccc:3333333333333333333333333333333333333333333333333333
+""".format(key=DUMMY_KEY.public().toString("openssh")).encode("ascii")
 
 class AccountFileCheckerKeyTests(unittest.TestCase):
     """
@@ -49,6 +50,17 @@ class AccountFileCheckerKeyTests(unittest.TestCase):
         avatarId = self.checker.requestAvatarId(key_credentials)
         return self.assertFailure(avatarId, error.UnauthorizedLogin)
 
+    def test_password_auth_user(self):
+        """
+        AccountFileChecker.requestAvatarId returns a Deferred that fires with
+        UnauthorizedLogin if called with an SSHPrivateKey object for a username
+        only associated with a password in the account file.
+        """
+        key_credentials = credentials.SSHPrivateKey(
+            b"alice", b"md5", None, None, None)
+        avatarId = self.checker.requestAvatarId(key_credentials)
+        return self.assertFailure(avatarId, error.UnauthorizedLogin)
+
     def test_unrecognized_key(self):
         """
         AccountFileChecker.requestAvatarId returns a Deferred that fires with
@@ -63,3 +75,45 @@ ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAAAYQDJGMWlPXh2M3pYzTiamjcBIMqctt4VvLVW2QZgEFc8
             b"carol", b"md5", wrong_key_blob, None, None)
         avatarId = self.checker.requestAvatarId(key_credentials)
         return self.assertFailure(avatarId, error.UnauthorizedLogin)
+
+    def test_missing_signature(self):
+        """
+        AccountFileChecker.requestAvatarId returns a Deferred that fires with
+        ValidPublicKey if called with an SSHPrivateKey object with an
+        authorized key for the indicated user but with no signature.
+        """
+        right_key_blob = DUMMY_KEY.public().toString("openssh")
+        key_credentials = credentials.SSHPrivateKey(
+            b"carol", b"md5", right_key_blob, None, None)
+        avatarId = self.checker.requestAvatarId(key_credentials)
+        return self.assertFailure(avatarId, conch_error.ValidPublicKey)
+
+    def test_wrong_signature(self):
+        """
+        AccountFileChecker.requestAvatarId returns a Deferred that fires with
+        UnauthorizedLogin if called with an SSHPrivateKey object with a public
+        key matching that on the user's line in the account file but with the
+        wrong signature.
+        """
+        right_key_blob = DUMMY_KEY.public().toString("openssh")
+        key_credentials = credentials.SSHPrivateKey(
+            b"carol", b"md5", right_key_blob, b"signed data", b"wrong sig")
+        avatarId = self.checker.requestAvatarId(key_credentials)
+        return self.assertFailure(avatarId, error.UnauthorizedLogin)
+
+    def test_authenticated(self):
+        """
+        AccountFileChecker.requestAvatarId returns a Deferred that fires with
+        the username portion of the account file line that matches the username
+        and key blob portion of the SSHPrivateKey object if that object also
+        has a correct signature.
+        """
+        username = b"carol"
+        signed_data = b"signed data"
+        signature = DUMMY_KEY.sign(signed_data)
+        right_key_blob = DUMMY_KEY.public().toString("openssh")
+        key_credentials = credentials.SSHPrivateKey(
+            username, b"md5", right_key_blob, signed_data, signature)
+        avatarId = self.checker.requestAvatarId(key_credentials)
+        avatarId.addCallback(self.assertEqual, username)
+        return avatarId
-- 
2.45.2