From a91b767bf4d372b3ec8f7f7e2dbe1b44a43fe20e Mon Sep 17 00:00:00 2001
From: Ramakrishnan Muthukrishnan <ram@rkrishnan.org>
Date: Sun, 12 Jun 2016 10:04:21 +0530
Subject: [PATCH] WIP: UDP tracker, compiles

---
 functorrent.cabal                |  2 +
 src/FuncTorrent/Tracker/Http.hs  | 15 +------
 src/FuncTorrent/Tracker/Types.hs |  5 ---
 src/FuncTorrent/Tracker/Udp.hs   | 70 ++++++++++++++++----------------
 src/FuncTorrent/Utils.hs         | 32 ++++++++++++---
 5 files changed, 64 insertions(+), 60 deletions(-)

diff --git a/functorrent.cabal b/functorrent.cabal
index 4b865b6..3de1ebf 100644
--- a/functorrent.cabal
+++ b/functorrent.cabal
@@ -28,6 +28,7 @@ library
                        FuncTorrent.Tracker,
                        FuncTorrent.Tracker.Http,
                        FuncTorrent.Tracker.Types,
+                       FuncTorrent.Tracker.Udp,
                        FuncTorrent.Utils
 
   other-extensions:    OverloadedStrings
@@ -47,6 +48,7 @@ library
                        network-uri,
                        parsec,
                        QuickCheck,
+                       random,
                        safe,
                        transformers
 
diff --git a/src/FuncTorrent/Tracker/Http.hs b/src/FuncTorrent/Tracker/Http.hs
index 5caefd6..fe1f3e1 100644
--- a/src/FuncTorrent/Tracker/Http.hs
+++ b/src/FuncTorrent/Tracker/Http.hs
@@ -42,8 +42,8 @@ import FuncTorrent.Bencode (BVal(..))
 import qualified FuncTorrent.FileSystem as FS (MsgChannel, Stats(..), getStats)
 import FuncTorrent.Network (sendGetRequest)
 import FuncTorrent.Peer (Peer(..))
-import FuncTorrent.Utils (splitN)
-import FuncTorrent.Tracker.Types(TState(..), TrackerResponse(..), Port, IP)
+import FuncTorrent.Utils (splitN, toIP, toPort, IP, Port)
+import FuncTorrent.Tracker.Types(TState(..), TrackerResponse(..))
 
 
 --- | URL encode hash as per RFC1738
@@ -118,14 +118,3 @@ parseTrackerResponse resp =
 makePeer :: ByteString -> Peer
 makePeer peer = Peer "" (toIP ip') (toPort port')
   where (ip', port') = splitAt 4 peer
-
-toPort :: ByteString -> Port
-toPort = read . ("0x" ++) . unpack . B16.encode
-
-toIP :: ByteString -> IP
-toIP = Data.List.intercalate "." .
-       map (show . toInt . ("0x" ++) . unpack) .
-       splitN 2 . B16.encode
-
-toInt :: String -> Integer
-toInt = read
diff --git a/src/FuncTorrent/Tracker/Types.hs b/src/FuncTorrent/Tracker/Types.hs
index a1fc669..c79fcef 100644
--- a/src/FuncTorrent/Tracker/Types.hs
+++ b/src/FuncTorrent/Tracker/Types.hs
@@ -24,8 +24,6 @@ module FuncTorrent.Tracker.Types
        , TrackerEventState(..)
        , TState(..)
        , TrackerMsg(..)
-       , IP
-       , Port
        ) where
 
 import Data.ByteString (ByteString)
@@ -33,9 +31,6 @@ import Control.Concurrent.MVar (MVar)
 
 import FuncTorrent.Peer (Peer(..))
 
-type IP = String
-type Port = Integer
-
 data TrackerProtocol = Http
                      | Udp
                      | UnknownProtocol
diff --git a/src/FuncTorrent/Tracker/Udp.hs b/src/FuncTorrent/Tracker/Udp.hs
index aa7bfd5..7467ce8 100644
--- a/src/FuncTorrent/Tracker/Udp.hs
+++ b/src/FuncTorrent/Tracker/Udp.hs
@@ -18,23 +18,25 @@
  -}
 
 {-# LANGUAGE OverloadedStrings #-}
-module Functorrent.Tracker.Udp
+module FuncTorrent.Tracker.Udp
        (
        ) where
 
 import Control.Applicative (liftA2)
-import Control.Monad.Error (ErrorT)
-import Control.Monad.Reader (ReaderT, runReaderT, ask)
+import Control.Monad.Reader (ReaderT, runReaderT, ask, liftIO)
 import Data.Binary (Binary(..), encode, decode)
 import Data.Binary.Get (Get, isEmpty, getWord32be, getByteString)
 import Data.Binary.Put (putWord16be, putWord64be, putWord32be, putByteString)
-import Data.ByteString.Char8 as BC
-import Data.ByteString.Lazy (fromStrict)
-import Data.Word (Word32)
-import Network.Socket (Socket, SockAddr, sendTo, recvFrom)
+import Data.ByteString (ByteString)
+import qualified Data.ByteString.Char8 as BC
+import Data.ByteString.Lazy (fromStrict, toStrict)
+import Data.Word (Word32, Word64)
+import Network.Socket (Socket, Family( AF_INET ), SocketType( Datagram ), defaultProtocol, SockAddr(..), socket, inet_addr)
+import Network.Socket.ByteString (sendTo, recvFrom)
 import System.Random (randomIO)
 
-import FuncTorrent.Tracker.Types (TrackerEventState(..), IP, Port)
+import FuncTorrent.Tracker.Types (TrackerEventState(..))
+import FuncTorrent.Utils (IP, Port, toIP, toPort)
 
 -- UDP tracker: http://bittorrent.org/beps/bep_0015.html
 data Action = Connect
@@ -47,7 +49,7 @@ data UDPRequest = ConnectReq Word32
                 | ScrapeReq Integer Integer ByteString
                 deriving (Show, Eq)
 
-data UDPResponse = ConnectResp Integer Integer -- transaction_id connection_id
+data UDPResponse = ConnectResp Word32 Word64 -- transaction_id connection_id
                  | AnnounceResp Integer Integer Integer Integer [(IP, Port)] -- transaction_id interval leechers seeders [(ip, port)]
                  | ScrapeResp Integer Integer Integer Integer
                  | ErrorResp Integer String
@@ -117,31 +119,34 @@ instance Binary UDPResponse where
       3 -> do -- error response
         tid <- fromIntegral <$> getWord32be
         bs  <- getByteString 4
-        return $ ErrorResp tid $ unpack bs
+        return $ ErrorResp tid $ BC.unpack bs
       _ -> error ("unknown response action type: " ++ show a)
 
-sendRequest :: UDPTrackerHandle -> UDPRequest -> IO ()
+sendRequest :: UDPTrackerHandle -> ByteString -> IO ()
 sendRequest h req = do
   n <- sendTo (sock h) req (addr h)
   -- sanity check with n?
   return ()
 
-recvResponse :: UDPTrackerHandle -> ErrorT String IO UDPResponse
+recvResponse :: UDPTrackerHandle -> IO UDPResponse
 recvResponse h = do
-  (bs, nbytes, saddr) <- recvFrom (sock h) 20
-  -- check if nbytes is at least 16 bytes long
+  (bs, saddr) <- recvFrom (sock h) 32
   return $ decode $ fromStrict bs
 
-connectRequest :: ReaderT UDPTrackerHandle IO Integer
+connectRequest :: ReaderT UDPTrackerHandle IO ()
 connectRequest = do
   h <- ask
   let pkt = encode $ ConnectReq (tid h)
-  sendRequest h pkt
+  liftIO $ sendRequest h (toStrict pkt)
 
-connectResponse :: ReaderT UDPTrackerHandle IO Bool
-connectResponse = do
+connectResponse :: Word32 -> ReaderT UDPTrackerHandle IO Bool
+connectResponse itid = do
   h <- ask
-  
+  resp <- liftIO $ recvResponse h
+  -- check if nbytes is at least 16 bytes long
+  case resp of
+    (ConnectResp tid cid) -> return $ tid == itid
+    _                     -> return False
 
 getIPPortPairs :: Get [(IP, Port)]
 getIPPortPairs = do
@@ -154,22 +159,15 @@ getIPPortPairs = do
     ipportpairs <- getIPPortPairs
     return $ (ip, port) : ipportpairs
 
-getResponse :: Socket -> IO UDPResponse
-getResponse s = do
-  -- connect packet is 16 bytes long
-  -- announce packet is atleast 20 bytes long
-  bs <- recv s (16*1024)
-  return $ decode $ fromStrict bs
-
-
-udpTrackerLoop :: PortNumber -> String -> Metainfo -> TState -> IO String
-udpTrackerLoop port peerId m st = do
-  -- h <- connectTo "exodus.desync.com" (PortNumber 6969)
+startSession :: IP -> Port -> IO UDPTrackerHandle
+startSession ip port = do
   s <- socket AF_INET Datagram defaultProtocol
-  hostAddr <- inet_addr "185.37.101.229"
+  hostAddr <- inet_addr ip
   putStrLn "connected to tracker"
-  _ <- sendTo s (toStrict $ encode (ConnectReq 42)) (SockAddrInet 2710 hostAddr)
-  putStrLn "--> sent ConnectReq to tracker"
-  resp <- recv s 16
-  putStrLn "<-- recv ConnectResp from tracker"
-  return $ show resp
+  r <- randomIO
+  return $ UDPTrackerHandle { sock = s
+                            , tid = r
+                            , addr = (SockAddrInet (fromIntegral port) hostAddr) }
+  
+-- closeSession :: UDPTrackerHandle
+
diff --git a/src/FuncTorrent/Utils.hs b/src/FuncTorrent/Utils.hs
index 4d89e83..fe9b423 100644
--- a/src/FuncTorrent/Utils.hs
+++ b/src/FuncTorrent/Utils.hs
@@ -18,12 +18,16 @@
  -}
 
 module FuncTorrent.Utils
-       (createDummyFile,
-        writeFileAtOffset,
-        readFileAtOffset,
-        splitNum,
-        splitN,
-        verifyHash
+       ( createDummyFile
+       , writeFileAtOffset
+       , readFileAtOffset
+       , splitNum
+       , splitN
+       , verifyHash
+       , IP
+       , Port
+       , toIP
+       , toPort
        )
        where
 
@@ -32,10 +36,15 @@ import Prelude hiding (writeFile, take)
 import qualified Crypto.Hash.SHA1 as SHA1 (hash)
 import Control.Exception.Base (IOException, try)
 import Data.ByteString (ByteString, writeFile, hPut, hGet, take)
+import qualified Data.ByteString.Base16 as B16 (encode)
 import qualified Data.ByteString.Char8 as BC
+import Data.List (intercalate)
 import System.IO (Handle, hSeek, SeekMode(..))
 import System.Directory (doesFileExist)
 
+type IP = String
+type Port = Integer
+
 splitN :: Int -> BC.ByteString -> [BC.ByteString]
 splitN n bs | BC.null bs = []
             | otherwise = BC.take n bs : splitN n (BC.drop n bs)
@@ -68,3 +77,14 @@ readFileAtOffset h offset len = do
 verifyHash :: ByteString -> ByteString -> Bool
 verifyHash bs pieceHash =
   take 20 (SHA1.hash bs) == pieceHash
+
+toPort :: ByteString -> Port
+toPort = read . ("0x" ++) . BC.unpack . B16.encode
+
+toIP :: ByteString -> IP
+toIP = Data.List.intercalate "." .
+       map (show . toInt . ("0x" ++) . BC.unpack) .
+       splitN 2 . B16.encode
+
+toInt :: String -> Integer
+toInt = read
-- 
2.45.2