]> git.rkrishnan.org Git - functorrent.git/blobdiff - src/FuncTorrent/Tracker/Udp.hs
refactoring: return type of tracker
[functorrent.git] / src / FuncTorrent / Tracker / Udp.hs
index aedc4a76af9a777e3d543099dd6f9d9cc33f8b2c..aaa99472b44c3a9d72f955afd0f35a5e8c05d534 100644 (file)
@@ -23,8 +23,9 @@ module FuncTorrent.Tracker.Udp
        ) where
 
 import Control.Applicative (liftA2)
        ) where
 
 import Control.Applicative (liftA2)
-import Control.Monad (liftM)
-import Control.Concurrent.MVar (readMVar)
+import Control.Monad (liftM, forever, void)
+import Control.Concurrent (threadDelay)
+import Control.Concurrent.MVar (readMVar, putMVar, isEmptyMVar, swapMVar)
 import Control.Monad.Reader (ReaderT, runReaderT, ask, liftIO)
 import Data.Binary (Binary(..), encode, decode)
 import Data.Binary.Get (Get, isEmpty, getWord32be, getWord64be, getByteString)
 import Control.Monad.Reader (ReaderT, runReaderT, ask, liftIO)
 import Data.Binary (Binary(..), encode, decode)
 import Data.Binary.Get (Get, isEmpty, getWord32be, getWord64be, getByteString)
@@ -36,8 +37,10 @@ import Data.Word (Word16, Word32, Word64)
 import Network.Socket (Socket, Family( AF_INET ), SocketType( Datagram ), defaultProtocol, SockAddr(..), socket, close, getAddrInfo, addrAddress, SockAddr(..))
 import Network.Socket.ByteString (sendTo, recvFrom)
 import System.Random (randomIO)
 import Network.Socket (Socket, Family( AF_INET ), SocketType( Datagram ), defaultProtocol, SockAddr(..), socket, close, getAddrInfo, addrAddress, SockAddr(..))
 import Network.Socket.ByteString (sendTo, recvFrom)
 import System.Random (randomIO)
+import System.Timeout (timeout)
 
 
-import FuncTorrent.Tracker.Types (TrackerEventState(..), TState(..))
+import FuncTorrent.PeerMsgs (Peer(..))
+import FuncTorrent.Tracker.Types (TrackerEventState(..), TState(..), UdpTrackerResponse(..))
 import FuncTorrent.Utils (IP, Port, toIP, toPort, getHostname, getPort)
 import qualified FuncTorrent.FileSystem as FS (MsgChannel, Stats(..), getStats)
 
 import FuncTorrent.Utils (IP, Port, toIP, toPort, getHostname, getPort)
 import qualified FuncTorrent.FileSystem as FS (MsgChannel, Stats(..), getStats)
 
@@ -53,7 +56,7 @@ data UDPRequest = ConnectReq Word32
                 deriving (Show, Eq)
 
 data UDPResponse = ConnectResp Word32 Word64 -- transaction_id connection_id
                 deriving (Show, Eq)
 
 data UDPResponse = ConnectResp Word32 Word64 -- transaction_id connection_id
-                 | AnnounceResp Word32 Word32 Word32 Word32 [(IP, Port)] -- transaction_id interval leechers seeders [(ip, port)]
+                 | AnnounceResp Word32 Word32 Word32 Word32 [Peer] -- transaction_id interval leechers seeders [(ip, port)]
                  | ScrapeResp Integer Integer Integer Integer
                  | ErrorResp Integer String
                  deriving (Show, Eq)
                  | ScrapeResp Integer Integer Integer Integer
                  | ErrorResp Integer String
                  deriving (Show, Eq)
@@ -168,12 +171,7 @@ announceRequest cid infohash peerId up down left port = do
   liftIO $ sendRequest h (toStrict pkt)
   return tidi
 
   liftIO $ sendRequest h (toStrict pkt)
   return tidi
 
-data PeerStats = PeerStats { leechers :: Word32
-                           , seeders :: Word32
-                           , peers :: [(IP, Port)]
-                           } deriving (Show)
-
-announceResponse :: Word32 -> ReaderT UDPTrackerHandle IO PeerStats
+announceResponse :: Word32 -> ReaderT UDPTrackerHandle IO UdpTrackerResponse
 announceResponse tid = do
   h <- ask
   resp <- liftIO $ recvResponse h
 announceResponse tid = do
   h <- ask
   resp <- liftIO $ recvResponse h
@@ -182,12 +180,12 @@ announceResponse tid = do
       if tidr == tid
       then do
         liftIO $ putStrLn "announce succeeded"
       if tidr == tid
       then do
         liftIO $ putStrLn "announce succeeded"
-        return $ PeerStats ls ss xs
+        return $ UdpTrackerResponse ls ss interval xs
       else
       else
-        return $ PeerStats 0 0 []
-    _ -> return $ PeerStats 0 0 []
+        return $ UdpTrackerResponse 0 0 0 []
+    _ -> return $ UdpTrackerResponse 0 0 0 []
 
 
-getIPPortPairs :: Get [(IP, Port)]
+getIPPortPairs :: Get [Peer]
 getIPPortPairs = do
   empty <- isEmpty
   if empty
 getIPPortPairs = do
   empty <- isEmpty
   if empty
@@ -196,7 +194,7 @@ getIPPortPairs = do
     ip <- toIP <$> getByteString 4
     port <- toPort <$> getByteString 2
     ipportpairs <- getIPPortPairs
     ip <- toIP <$> getByteString 4
     port <- toPort <$> getByteString 2
     ipportpairs <- getIPPortPairs
-    return $ (ip, port) : ipportpairs
+    return $ (Peer ip port) : ipportpairs
 
 startSession :: String -> Port -> IO UDPTrackerHandle
 startSession host port = do
 
 startSession :: String -> Port -> IO UDPTrackerHandle
 startSession host port = do
@@ -211,20 +209,29 @@ closeSession :: UDPTrackerHandle -> IO ()
 closeSession (UDPTrackerHandle s _ _) = close s
 
 trackerLoop :: String -> Port -> String -> ByteString -> FS.MsgChannel -> TState -> IO ()
 closeSession (UDPTrackerHandle s _ _) = close s
 
 trackerLoop :: String -> Port -> String -> ByteString -> FS.MsgChannel -> TState -> IO ()
-trackerLoop url sport peerId infohash fschan tstate = do
-  st' <- FS.getStats fschan
-  st <- readMVar st'
-  let up = FS.bytesRead st
-      down = FS.bytesWritten st
-      port = getPort url
-      host = getHostname url
-  putStrLn $ "host = " ++ (show host) ++ " port= " ++ (show port)
+trackerLoop url sport peerId infohash fschan tstate = forever $ do
+  st <- readMVar <$> FS.getStats fschan
+  up <- fmap FS.bytesRead st
+  down <- fmap FS.bytesWritten st
   handle <- startSession host port
   handle <- startSession host port
-  flip runReaderT handle $ do
-    t1 <- connectRequest
-    cid <- connectResponse t1
-    liftIO $ print "connected: connect id"
-    t2 <- announceRequest cid infohash peerId (fromIntegral up) (fromIntegral down) (fromIntegral (left tstate)) (fromIntegral sport)
-    liftIO $ print "waiting for announce response"
-    stats <- announceResponse t2
-    liftIO $ print stats
+  stats <- timeout (15*(10^6)) $ worker handle up down
+  case stats of
+    Nothing -> closeSession handle
+    Just stats' -> do
+      ps <- isEmptyMVar $ connectedPeers tstate
+      if ps
+        then
+        putMVar (connectedPeers tstate) (peers stats')
+        else
+        void $ swapMVar (connectedPeers tstate) (peers stats')
+      threadDelay $ fromIntegral (interval stats') * (10^6)
+      return ()
+  where
+    port = getPort url
+    host = getHostname url
+    worker handle up down = flip runReaderT handle $ do
+      t1 <- connectRequest
+      cid <- connectResponse t1
+      t2 <- announceRequest cid infohash peerId (fromIntegral up) (fromIntegral down) (fromIntegral (left tstate)) (fromIntegral sport)
+      stats <- announceResponse t2
+      return stats