From d80bf91010f9f2f8653c55bb902ec83bf1f034a2 Mon Sep 17 00:00:00 2001
From: Ramakrishnan Muthukrishnan <ram@rkrishnan.org>
Date: Sun, 6 Aug 2017 14:34:58 +0530
Subject: [PATCH] quick and dirty initial extended metadata protocol
 implementation

---
 src/FuncTorrent/Bencode.hs  |  5 +++
 src/FuncTorrent/Peer.hs     | 90 ++++++++++++++++++++++++++++++++++---
 src/FuncTorrent/PeerMsgs.hs | 26 ++++++-----
 3 files changed, 104 insertions(+), 17 deletions(-)

diff --git a/src/FuncTorrent/Bencode.hs b/src/FuncTorrent/Bencode.hs
index d4eb58c..ac62211 100644
--- a/src/FuncTorrent/Bencode.hs
+++ b/src/FuncTorrent/Bencode.hs
@@ -23,6 +23,7 @@ module FuncTorrent.Bencode
     , bValToInteger
     , bstrToString
     , decode
+    , decodeWithLeftOvers
     , encode
     ) where
 
@@ -165,6 +166,10 @@ bencVal = Bstr <$> bencStr <|>
 decode :: ByteString -> Either ParseError BVal
 decode = parse bencVal "BVal"
 
+decodeWithLeftOvers :: ByteString -> Either ParseError (BVal, ByteString)
+decodeWithLeftOvers = parse ((,) <$> bencVal <*> (fmap pack leftOvers)) "BVal with LeftOvers"
+  where leftOvers = manyTill anyToken eof
+
 -- Encode BVal into a bencoded ByteString. Inverse of decode
 
 -- TODO: Use builders and lazy byte string to get O(1) concatenation over O(n)
diff --git a/src/FuncTorrent/Peer.hs b/src/FuncTorrent/Peer.hs
index 87e7d23..4396c1c 100644
--- a/src/FuncTorrent/Peer.hs
+++ b/src/FuncTorrent/Peer.hs
@@ -25,14 +25,16 @@ module FuncTorrent.Peer
 
 import Prelude hiding (lookup, concat, replicate, splitAt, take, drop)
 
+import Control.Concurrent.MVar (MVar, readMVar, putMVar, takeMVar)
 import Control.Monad.State
-import Data.ByteString (ByteString, unpack, concat, hGet, hPut, take, drop, empty)
+import Data.ByteString (ByteString, unpack, concat, hGet, hPut, take, drop, empty, singleton)
 import Data.Bits
 import Data.Word (Word8)
-import Data.Map ((!), adjust)
+import Data.Map (Map, (!), adjust, fromList, insert)
 import Network (connectTo, PortID(..))
 import System.IO (Handle, BufferMode(..), hSetBuffering, hClose)
 
+import FuncTorrent.Bencode(BVal(..), encode, decode, decodeWithLeftOvers)
 import FuncTorrent.Metainfo (Metainfo(..))
 import FuncTorrent.PeerMsgs (Peer(..), PeerMsg(..), sendMsg, getMsg, genHandshakeMsg)
 import FuncTorrent.Utils (splitNum, verifyHash)
@@ -46,6 +48,11 @@ data PState = PState { handle :: Handle
                      , heChoking :: Bool
                      , heInterested :: Bool}
 
+data InfoPieceMap = InfoPieceMap { infoLength :: Integer
+                                 , infoMap :: Map Integer (Maybe ByteString) }
+
+newtype InfoState = InfoState (MVar InfoPieceMap)
+
 havePiece :: PieceMap -> Integer -> Bool
 havePiece pm index =
   dlstate (pm ! index) == Have
@@ -56,18 +63,25 @@ connectToPeer (Peer ip port) = do
   hSetBuffering h LineBuffering
   return h
 
+
 doHandshake :: Bool -> Handle -> Peer -> ByteString -> String -> IO ()
 doHandshake True h p infohash peerid = do
   let hs = genHandshakeMsg infohash peerid
   hPut h hs
   putStrLn $ "--> handhake to peer: " ++ show p
-  _ <- hGet h (length (unpack hs))
+  hsMsg <- hGet h (length (unpack hs))
   putStrLn $ "<-- handshake from peer: " ++ show p
   return ()
+  -- if doesPeerSupportExtendedMsg hsMsg
+  --   then
+  --   return doExtendedHandshake h
+  --   else
+  --   return Nothing
 doHandshake False h p infohash peerid = do
   let hs = genHandshakeMsg infohash peerid
   putStrLn "waiting for a handshake"
-  hsMsg <- hGet h (length (unpack hs))
+  -- read 28 bytes. '19' ++ 'BitTorrent Protocol' ++ 8 reserved bytes
+  hsMsg <- hGet h 28
   putStrLn $ "<-- handshake from peer: " ++ show p
   let rxInfoHash = take 20 $ drop 28 hsMsg
   if rxInfoHash /= infohash
@@ -78,7 +92,12 @@ doHandshake False h p infohash peerid = do
     else do
     _ <- hPut h hs
     putStrLn $ "--> handhake to peer: " ++ show p
-    return ()
+    -- if doesPeerSupportExtendedMsg hsMsg
+    --   then do
+    --   doExtendedHandshake h
+    --   else
+    --   return Nothing
+
 
 bitfieldToList :: [Word8] -> [Integer]
 bitfieldToList bs = go bs 0
@@ -270,5 +289,64 @@ downloadPiece h index pieceLength = do
 
    At this point, we have the infodict.
 
--)
+-}
 
+{-
+data InfoPieceMap = { infoLength :: Integer
+                    , infoMap :: Map Integer (Maybe ByteString)
+                    }
+
+newtype InfoState = InfoState (MVar InfoPieceMap)
+
+-}
+
+
+metadataMsgLoop :: Handle -> InfoState -> IO ()
+metadataMsgLoop h (InfoState st) = do
+    infoState <- readMVar st
+    let metadataLen = infoLength infoState
+        -- send the handshake msg
+        metadata = encode (metadataMsg metadataLen)
+    sendMsg h (ExtendedMsg 0 metadata)
+    -- recv return msg from the peer. Will have 'metadata_size'
+    msg <- getMsg h
+    case msg of
+      ExtendedMsg 0 rBs -> do
+        -- decode rBs
+        let (Right (Bdict msgMap)) = decode rBs
+            (Bdict mVal) = msgMap ! "m" -- which is another dict
+            (Bint metadata_msgID) = mVal ! "ut_metadata"
+            (Bint metadata_size) = msgMap ! "metadata_size"
+            -- divide metadata_size into 16384 sized pieces, find number of pieces
+            (q, r) = metadata_size `divMod` 16384
+            -- pNumLengthPairs = zip [0..q-1] (take q (repeat 16384)) ++ (q, r)
+            -- TODO: corner case where infodict size is a multiple of 16384
+            -- and start sending request msg for each.
+        if metadataLen == 0
+          then -- We don't have any piece. Send request msg for all pieces.
+          mapM_ (\n -> do
+                    sendMsg h (ExtendedMsg metadata_msgID (encode (requestMsg n)))
+                    dataOrRejectMsg <- getMsg h
+                    case dataOrRejectMsg of
+                      ExtendedMsg 3 payload -> do
+                        -- bencoded dict followed by XXXXXX
+                        infoState <- takeMVar st
+                        let (Right (Bdict bval, pieceData)) = decodeWithLeftOvers payload
+                            (Bint pieceIndex) = bval ! "piece"
+                            payloadLen = length (unpack pieceData)
+                            infoMapVal = infoMap infoState
+                        putMVar st infoState {
+                          infoMap = insert pieceIndex (Just payload) infoMapVal }
+                )
+          [0..q]
+          else
+          return () -- TODO: reject for now
+      where
+        metadataMsg 0 = Bdict (fromList [("m", Bdict (fromList [("ut_metadata", (Bint 3))]))])
+        metadataMsg l = Bdict (fromList [("m", Bdict (fromList [("ut_metadata", (Bint 3))])),
+                                         ("metadata_size", (Bint l))])
+        requestMsg i = Bdict (fromList [("msg_type", (Bint 0)), ("piece", (Bint i))])
+        rejectmsg i = Bdict (fromList [("msg_type", (Bint 2)), ("piece", (Bint i))])
+
+doesPeerSupportExtendedMsg :: ByteString -> Bool
+doesPeerSupportExtendedMsg bs = take 1 (drop 5 bs) == singleton 0x10
diff --git a/src/FuncTorrent/PeerMsgs.hs b/src/FuncTorrent/PeerMsgs.hs
index d6bbdcf..467ac8d 100644
--- a/src/FuncTorrent/PeerMsgs.hs
+++ b/src/FuncTorrent/PeerMsgs.hs
@@ -64,7 +64,12 @@ data PeerMsg = KeepAliveMsg
              | CancelMsg Integer Integer Integer
              | PortMsg Port
              | ExtendedMsg Integer ByteString
-             deriving (Show)
+  deriving (Show)
+
+data ExtMetadataMsg = Request Integer
+                    | Data Integer Integer
+                    | Reject Integer
+  deriving (Eq, Show)
 
 instance Binary PeerMsg where
   put msg = case msg of
@@ -101,14 +106,13 @@ instance Binary PeerMsg where
              PortMsg p -> do putWord32be 3
                              putWord8 9
                              putWord16be (fromIntegral p)
-             ExtendedHandshakeMsg t b-> do putWord32be msgLen
-                                           putWord8 20
-                                           putWord8 t -- 0 => handshake msg
-                                           -- actual extension msg follows
-                                           mapM_ putWord8 blockList
-                                             where blockList = unpack b
-                                                   blockLen  = length blockList
-
+             ExtendedMsg t b-> do putWord32be (fromIntegral blockLen)
+                                  putWord8 20
+                                  putWord8 (fromIntegral t) -- 0 => handshake msg
+                                  -- actual extension msg follows
+                                  mapM_ putWord8 blockList
+                                    where blockList = unpack b
+                                          blockLen  = length blockList
 
     where putIndexOffsetLength i o l = do
             putWord32be (fromIntegral i)
@@ -152,9 +156,9 @@ genHandshakeMsg :: ByteString -> String -> ByteString
 genHandshakeMsg infoHash peer_id = concat [pstrlen, pstr, reserved1, reserved2, reserved3, infoHash, peerID]
   where pstrlen = singleton 19
         pstr = BC.pack "BitTorrent protocol"
-        reserved1 = BC.replicate 4 '\0'
+        reserved1 = BC.replicate 5 '\0'
         reserved2 = singleton 0x10 -- support extension protocol
-        reserved3 = BC.replicate 3 '\0'
+        reserved3 = BC.replicate 2 '\0'
         peerID = BC.pack peer_id
 
 bsToInt :: ByteString -> Int
-- 
2.45.2