From 2ddd7237e3615d4a55460ca86de22a669c90232c Mon Sep 17 00:00:00 2001
From: Ramakrishnan Muthukrishnan <ram@rkrishnan.org>
Date: Sun, 19 Jun 2016 12:00:19 +0530
Subject: [PATCH] Tracker/Udp: refactor the trackerloop, add timeouts

---
 src/FuncTorrent/PeerMsgs.hs     |  8 +++++
 src/FuncTorrent/Tracker.hs      |  1 +
 src/FuncTorrent/Tracker/Http.hs | 11 +++---
 src/FuncTorrent/Tracker/Udp.hs  | 62 ++++++++++++++++++++-------------
 4 files changed, 51 insertions(+), 31 deletions(-)

diff --git a/src/FuncTorrent/PeerMsgs.hs b/src/FuncTorrent/PeerMsgs.hs
index 79c41c2..cfefce1 100644
--- a/src/FuncTorrent/PeerMsgs.hs
+++ b/src/FuncTorrent/PeerMsgs.hs
@@ -24,6 +24,7 @@ module FuncTorrent.PeerMsgs
         sendMsg,
         getMsg,
         Peer(..),
+        makePeer,
         PeerMsg(..)
        ) where
 
@@ -32,6 +33,7 @@ import Prelude hiding (lookup, concat, replicate, splitAt, take)
 import System.IO (Handle)
 import Data.ByteString (ByteString, pack, unpack, concat, hGet, hPut, singleton)
 import Data.ByteString.Lazy (fromStrict, fromChunks, toStrict)
+import Data.ByteString.Char8 as BC (splitAt)
 import qualified Data.ByteString.Char8 as BC (replicate, pack)
 import Control.Monad (replicateM, liftM)
 import Control.Applicative (liftA3)
@@ -40,6 +42,8 @@ import Data.Binary (Binary(..), decode, encode)
 import Data.Binary.Put (putWord32be, putWord16be, putWord8)
 import Data.Binary.Get (getWord32be, getWord16be, getWord8, runGet)
 
+import FuncTorrent.Utils (toIP, toPort)
+
 -- | Peer is a PeerID, IP address, port tuple
 data Peer = Peer ID IP Port
           deriving (Show, Eq)
@@ -142,3 +146,7 @@ genHandshakeMsg infoHash peer_id = concat [pstrlen, pstr, reserved, infoHash, pe
 
 bsToInt :: ByteString -> Int
 bsToInt x = fromIntegral (runGet getWord32be (fromChunks (return x)))
+
+makePeer :: ByteString -> Peer
+makePeer peer = Peer "" (toIP ip') (toPort port')
+  where (ip', port') = splitAt 4 peer
diff --git a/src/FuncTorrent/Tracker.hs b/src/FuncTorrent/Tracker.hs
index 9873fe1..8f6a5cc 100644
--- a/src/FuncTorrent/Tracker.hs
+++ b/src/FuncTorrent/Tracker.hs
@@ -61,6 +61,7 @@ runTracker msgChannel fsChan infohash port peerId announceList sz = do
       return ()
     Udp -> do
       _ <- forkIO $ UT.trackerLoop turl (fromIntegral port) peerId infohash fsChan initialTState
+      runStateT (msgHandler msgChannel) initialTState
       return ()
     _ ->
       error "Tracker Protocol unimplemented"
diff --git a/src/FuncTorrent/Tracker/Http.hs b/src/FuncTorrent/Tracker/Http.hs
index abb4b32..f911b2f 100644
--- a/src/FuncTorrent/Tracker/Http.hs
+++ b/src/FuncTorrent/Tracker/Http.hs
@@ -41,8 +41,8 @@ import qualified FuncTorrent.Bencode as Benc
 import FuncTorrent.Bencode (BVal(..))
 import qualified FuncTorrent.FileSystem as FS (MsgChannel, Stats(..), getStats)
 import FuncTorrent.Network (sendGetRequest)
-import FuncTorrent.Peer (Peer(..))
-import FuncTorrent.Utils (splitN, toIP, toPort, IP, Port)
+import FuncTorrent.PeerMsgs (Peer(..), makePeer)
+import FuncTorrent.Utils (splitN, IP, Port)
 import FuncTorrent.Tracker.Types(TState(..), TrackerResponse(..))
 
 
@@ -74,12 +74,12 @@ mkArgs port peer_id up down left' infoHash =
    ("event", "started")]
 
 trackerLoop :: String -> PortNumber -> String -> ByteString -> FS.MsgChannel -> TState -> IO ()
-trackerLoop url port peerId infohash fschan tstate = forever $ do
+trackerLoop url sport peerId infohash fschan tstate = forever $ do
   st' <- FS.getStats fschan
   st <- readMVar st'
   let up = FS.bytesRead st
       down = FS.bytesWritten st
-  resp <- sendGetRequest url $ mkArgs port peerId up down (left tstate) infohash
+  resp <- sendGetRequest url $ mkArgs sport peerId up down (left tstate) infohash
   case Benc.decode resp of
     Left e ->
       return () -- $ pack (show e)
@@ -113,6 +113,3 @@ parseTrackerResponse resp =
     where
       (Bdict body) = resp
 
-makePeer :: ByteString -> Peer
-makePeer peer = Peer "" (toIP ip') (toPort port')
-  where (ip', port') = splitAt 4 peer
diff --git a/src/FuncTorrent/Tracker/Udp.hs b/src/FuncTorrent/Tracker/Udp.hs
index 37979c4..fe4d9e0 100644
--- a/src/FuncTorrent/Tracker/Udp.hs
+++ b/src/FuncTorrent/Tracker/Udp.hs
@@ -23,8 +23,9 @@ module FuncTorrent.Tracker.Udp
        ) where
 
 import Control.Applicative (liftA2)
-import Control.Monad (liftM)
-import Control.Concurrent.MVar (readMVar)
+import Control.Monad (liftM, forever, void)
+import Control.Concurrent (threadDelay)
+import Control.Concurrent.MVar (readMVar, putMVar, isEmptyMVar, swapMVar)
 import Control.Monad.Reader (ReaderT, runReaderT, ask, liftIO)
 import Data.Binary (Binary(..), encode, decode)
 import Data.Binary.Get (Get, isEmpty, getWord32be, getWord64be, getByteString)
@@ -36,7 +37,9 @@ import Data.Word (Word16, Word32, Word64)
 import Network.Socket (Socket, Family( AF_INET ), SocketType( Datagram ), defaultProtocol, SockAddr(..), socket, close, getAddrInfo, addrAddress, SockAddr(..))
 import Network.Socket.ByteString (sendTo, recvFrom)
 import System.Random (randomIO)
+import System.Timeout (timeout)
 
+import FuncTorrent.Peer (Peer(..))
 import FuncTorrent.Tracker.Types (TrackerEventState(..), TState(..))
 import FuncTorrent.Utils (IP, Port, toIP, toPort, getHostname, getPort)
 import qualified FuncTorrent.FileSystem as FS (MsgChannel, Stats(..), getStats)
@@ -53,7 +56,7 @@ data UDPRequest = ConnectReq Word32
                 deriving (Show, Eq)
 
 data UDPResponse = ConnectResp Word32 Word64 -- transaction_id connection_id
-                 | AnnounceResp Word32 Word32 Word32 Word32 [(IP, Port)] -- transaction_id interval leechers seeders [(ip, port)]
+                 | AnnounceResp Word32 Word32 Word32 Word32 [Peer] -- transaction_id interval leechers seeders [(ip, port)]
                  | ScrapeResp Integer Integer Integer Integer
                  | ErrorResp Integer String
                  deriving (Show, Eq)
@@ -170,7 +173,8 @@ announceRequest cid infohash peerId up down left port = do
 
 data PeerStats = PeerStats { leechers :: Word32
                            , seeders :: Word32
-                           , peers :: [(IP, Port)]
+                           , interval :: Word32
+                           , peers :: [Peer]
                            } deriving (Show)
 
 announceResponse :: Word32 -> ReaderT UDPTrackerHandle IO PeerStats
@@ -182,12 +186,12 @@ announceResponse tid = do
       if tidr == tid
       then do
         liftIO $ putStrLn "announce succeeded"
-        return $ PeerStats ls ss xs
+        return $ PeerStats ls ss interval xs
       else
-        return $ PeerStats 0 0 []
-    _ -> return $ PeerStats 0 0 []
+        return $ PeerStats 0 0 0 []
+    _ -> return $ PeerStats 0 0 0 []
 
-getIPPortPairs :: Get [(IP, Port)]
+getIPPortPairs :: Get [Peer]
 getIPPortPairs = do
   empty <- isEmpty
   if empty
@@ -196,7 +200,7 @@ getIPPortPairs = do
     ip <- toIP <$> getByteString 4
     port <- toPort <$> getByteString 2
     ipportpairs <- getIPPortPairs
-    return $ (ip, port) : ipportpairs
+    return $ (Peer "" ip port) : ipportpairs
 
 startSession :: String -> Port -> IO UDPTrackerHandle
 startSession host port = do
@@ -211,19 +215,29 @@ closeSession :: UDPTrackerHandle -> IO ()
 closeSession (UDPTrackerHandle s _ _) = close s
 
 trackerLoop :: String -> Port -> String -> ByteString -> FS.MsgChannel -> TState -> IO ()
-trackerLoop url sport peerId infohash fschan tstate = do
-  st' <- FS.getStats fschan
-  st <- readMVar st'
-  let up = FS.bytesRead st
-      down = FS.bytesWritten st
-      port = getPort url
-      host = getHostname url
-  putStrLn $ "host = " ++ (show host) ++ " port= " ++ (show port)
+trackerLoop url sport peerId infohash fschan tstate = forever $ do
+  st <- fmap readMVar $ FS.getStats fschan
+  up <- fmap FS.bytesRead st
+  down <- fmap FS.bytesWritten st
   handle <- startSession host port
-  flip runReaderT handle $ do
-    t1 <- connectRequest
-    cid <- connectResponse t1
-    t2 <- announceRequest cid infohash peerId (fromIntegral up) (fromIntegral down) (fromIntegral (left tstate)) (fromIntegral sport)
-    stats <- announceResponse t2
-    liftIO $ print stats
---    _ <- threadDelay $
+  stats <- timeout (15*(10^6)) $ worker handle up down
+  case stats of
+    Nothing -> closeSession handle
+    Just stats' -> do
+      ps <- isEmptyMVar $ connectedPeers tstate
+      if ps
+        then
+        putMVar (connectedPeers tstate) (peers stats')
+        else
+        void $ swapMVar (connectedPeers tstate) (peers stats')
+      threadDelay $ fromIntegral (interval stats') * (10^6)
+      return ()
+  where
+    port = getPort url
+    host = getHostname url
+    worker handle up down = flip runReaderT handle $ do
+      t1 <- connectRequest
+      cid <- connectResponse t1
+      t2 <- announceRequest cid infohash peerId (fromIntegral up) (fromIntegral down) (fromIntegral (left tstate)) (fromIntegral sport)
+      stats <- announceResponse t2
+      return stats
-- 
2.45.2