]> git.rkrishnan.org Git - functorrent.git/blobdiff - src/FuncTorrent/Tracker/Udp.hs
refactoring: return type of tracker
[functorrent.git] / src / FuncTorrent / Tracker / Udp.hs
index 7467ce8cd5a6d707eea50d0d4a68043f6ea62c34..aaa99472b44c3a9d72f955afd0f35a5e8c05d534 100644 (file)
 
 {-# LANGUAGE OverloadedStrings #-}
 module FuncTorrent.Tracker.Udp
 
 {-# LANGUAGE OverloadedStrings #-}
 module FuncTorrent.Tracker.Udp
-       (
+       (trackerLoop
        ) where
 
 import Control.Applicative (liftA2)
        ) where
 
 import Control.Applicative (liftA2)
+import Control.Monad (liftM, forever, void)
+import Control.Concurrent (threadDelay)
+import Control.Concurrent.MVar (readMVar, putMVar, isEmptyMVar, swapMVar)
 import Control.Monad.Reader (ReaderT, runReaderT, ask, liftIO)
 import Data.Binary (Binary(..), encode, decode)
 import Control.Monad.Reader (ReaderT, runReaderT, ask, liftIO)
 import Data.Binary (Binary(..), encode, decode)
-import Data.Binary.Get (Get, isEmpty, getWord32be, getByteString)
+import Data.Binary.Get (Get, isEmpty, getWord32be, getWord64be, getByteString)
 import Data.Binary.Put (putWord16be, putWord64be, putWord32be, putByteString)
 import Data.ByteString (ByteString)
 import qualified Data.ByteString.Char8 as BC
 import Data.ByteString.Lazy (fromStrict, toStrict)
 import Data.Binary.Put (putWord16be, putWord64be, putWord32be, putByteString)
 import Data.ByteString (ByteString)
 import qualified Data.ByteString.Char8 as BC
 import Data.ByteString.Lazy (fromStrict, toStrict)
-import Data.Word (Word32, Word64)
-import Network.Socket (Socket, Family( AF_INET ), SocketType( Datagram ), defaultProtocol, SockAddr(..), socket, inet_addr)
+import Data.Word (Word16, Word32, Word64)
+import Network.Socket (Socket, Family( AF_INET ), SocketType( Datagram ), defaultProtocol, SockAddr(..), socket, close, getAddrInfo, addrAddress, SockAddr(..))
 import Network.Socket.ByteString (sendTo, recvFrom)
 import System.Random (randomIO)
 import Network.Socket.ByteString (sendTo, recvFrom)
 import System.Random (randomIO)
+import System.Timeout (timeout)
 
 
-import FuncTorrent.Tracker.Types (TrackerEventState(..))
-import FuncTorrent.Utils (IP, Port, toIP, toPort)
+import FuncTorrent.PeerMsgs (Peer(..))
+import FuncTorrent.Tracker.Types (TrackerEventState(..), TState(..), UdpTrackerResponse(..))
+import FuncTorrent.Utils (IP, Port, toIP, toPort, getHostname, getPort)
+import qualified FuncTorrent.FileSystem as FS (MsgChannel, Stats(..), getStats)
 
 -- UDP tracker: http://bittorrent.org/beps/bep_0015.html
 data Action = Connect
 
 -- UDP tracker: http://bittorrent.org/beps/bep_0015.html
 data Action = Connect
@@ -45,12 +51,12 @@ data Action = Connect
             deriving (Show, Eq)
 
 data UDPRequest = ConnectReq Word32
             deriving (Show, Eq)
 
 data UDPRequest = ConnectReq Word32
-                | AnnounceReq Integer Integer ByteString String Integer Integer Integer TrackerEventState Integer
+                | AnnounceReq Word64 Word32 ByteString String Word64 Word64 Word64 TrackerEventState Word16
                 | ScrapeReq Integer Integer ByteString
                 deriving (Show, Eq)
 
 data UDPResponse = ConnectResp Word32 Word64 -- transaction_id connection_id
                 | ScrapeReq Integer Integer ByteString
                 deriving (Show, Eq)
 
 data UDPResponse = ConnectResp Word32 Word64 -- transaction_id connection_id
-                 | AnnounceResp Integer Integer Integer Integer [(IP, Port)] -- transaction_id interval leechers seeders [(ip, port)]
+                 | AnnounceResp Word32 Word32 Word32 Word32 [Peer] -- transaction_id interval leechers seeders [(ip, port)]
                  | ScrapeResp Integer Integer Integer Integer
                  | ErrorResp Integer String
                  deriving (Show, Eq)
                  | ScrapeResp Integer Integer Integer Integer
                  | ErrorResp Integer String
                  deriving (Show, Eq)
@@ -74,6 +80,7 @@ eventToInteger :: TrackerEventState -> Integer
 eventToInteger None = 0
 eventToInteger Completed = 1
 eventToInteger Started = 2
 eventToInteger None = 0
 eventToInteger Completed = 1
 eventToInteger Started = 2
+eventToInteger Stopped = 3
 
 instance Binary UDPRequest where
   put (ConnectReq transId) = do
 
 instance Binary UDPRequest where
   put (ConnectReq transId) = do
@@ -89,10 +96,10 @@ instance Binary UDPRequest where
     putWord64be (fromIntegral down)
     putWord64be (fromIntegral left)
     putWord64be (fromIntegral up)
     putWord64be (fromIntegral down)
     putWord64be (fromIntegral left)
     putWord64be (fromIntegral up)
-    putWord32be $ fromIntegral (eventToInteger None)
+    putWord32be $ fromIntegral (eventToInteger event)
     putWord32be 0
     putWord32be 0
-    -- key is optional, we will not send it for now
-    putWord32be $ fromIntegral (-1)
+    putWord32be 0
+    putWord32be 10
     putWord16be $ fromIntegral port
   put (ScrapeReq _ _ _) = undefined
   get = undefined
     putWord16be $ fromIntegral port
   put (ScrapeReq _ _ _) = undefined
   get = undefined
@@ -102,14 +109,14 @@ instance Binary UDPResponse where
   get = do
     a <- getWord32be -- action
     case a of
   get = do
     a <- getWord32be -- action
     case a of
-      0 -> liftA2 ConnectResp (fromIntegral <$> getWord32be) (fromIntegral <$> getWord32be)
+      0 -> liftA2 ConnectResp (fromIntegral <$> getWord32be) (fromIntegral <$> getWord64be)
       1 -> do
         tid <- fromIntegral <$> getWord32be
         interval' <- fromIntegral <$> getWord32be
       1 -> do
         tid <- fromIntegral <$> getWord32be
         interval' <- fromIntegral <$> getWord32be
-        _ <- getWord32be -- leechers
-        _ <- getWord32be -- seeders
+        l <- getWord32be -- leechers
+        s <- getWord32be -- seeders
         ipportpairs <- getIPPortPairs -- [(ip, port)]
         ipportpairs <- getIPPortPairs -- [(ip, port)]
-        return $ AnnounceResp tid interval' 0 0 ipportpairs
+        return $ AnnounceResp tid interval' l s ipportpairs
       2 -> do
         tid <- fromIntegral <$> getWord32be
         _ <- getWord32be
       2 -> do
         tid <- fromIntegral <$> getWord32be
         _ <- getWord32be
@@ -130,44 +137,101 @@ sendRequest h req = do
 
 recvResponse :: UDPTrackerHandle -> IO UDPResponse
 recvResponse h = do
 
 recvResponse :: UDPTrackerHandle -> IO UDPResponse
 recvResponse h = do
-  (bs, saddr) <- recvFrom (sock h) 32
+  (bs, saddr) <- recvFrom (sock h) (16*1024)
   return $ decode $ fromStrict bs
 
   return $ decode $ fromStrict bs
 
-connectRequest :: ReaderT UDPTrackerHandle IO ()
+connectRequest :: ReaderT UDPTrackerHandle IO Word32
 connectRequest = do
   h <- ask
 connectRequest = do
   h <- ask
-  let pkt = encode $ ConnectReq (tid h)
+  tidi <- liftIO randomIO
+  let pkt = encode $ ConnectReq tidi
   liftIO $ sendRequest h (toStrict pkt)
   liftIO $ sendRequest h (toStrict pkt)
+  return tidi
 
 
-connectResponse :: Word32 -> ReaderT UDPTrackerHandle IO Bool
-connectResponse itid = do
+connectResponse :: Word32 -> ReaderT UDPTrackerHandle IO Word64
+connectResponse tid = do
   h <- ask
   resp <- liftIO $ recvResponse h
   -- check if nbytes is at least 16 bytes long
   case resp of
   h <- ask
   resp <- liftIO $ recvResponse h
   -- check if nbytes is at least 16 bytes long
   case resp of
-    (ConnectResp tid cid) -> return $ tid == itid
-    _                     -> return False
+    (ConnectResp tidr cid) ->
+      if tidr == tid
+      then do
+        liftIO $ putStrLn "connect succeeded"
+        return cid
+      else
+        return 0
+    _                      -> return 0
+
+announceRequest :: Word64 -> ByteString -> String -> Word64 -> Word64 -> Word64 -> Word16 -> ReaderT UDPTrackerHandle IO Word32
+announceRequest cid infohash peerId up down left port = do
+  h <- ask
+  tidi <- liftIO randomIO
+  let pkt = encode $ AnnounceReq cid tidi infohash peerId down left up None port
+  liftIO $ sendRequest h (toStrict pkt)
+  return tidi
 
 
-getIPPortPairs :: Get [(IP, Port)]
+announceResponse :: Word32 -> ReaderT UDPTrackerHandle IO UdpTrackerResponse
+announceResponse tid = do
+  h <- ask
+  resp <- liftIO $ recvResponse h
+  case resp of
+    (AnnounceResp tidr interval ss ls xs) ->
+      if tidr == tid
+      then do
+        liftIO $ putStrLn "announce succeeded"
+        return $ UdpTrackerResponse ls ss interval xs
+      else
+        return $ UdpTrackerResponse 0 0 0 []
+    _ -> return $ UdpTrackerResponse 0 0 0 []
+
+getIPPortPairs :: Get [Peer]
 getIPPortPairs = do
   empty <- isEmpty
   if empty
     then return []
     else do
 getIPPortPairs = do
   empty <- isEmpty
   if empty
     then return []
     else do
-    ip <- toIP <$> getByteString 6
+    ip <- toIP <$> getByteString 4
     port <- toPort <$> getByteString 2
     ipportpairs <- getIPPortPairs
     port <- toPort <$> getByteString 2
     ipportpairs <- getIPPortPairs
-    return $ (ip, port) : ipportpairs
+    return $ (Peer ip port) : ipportpairs
 
 
-startSession :: IP -> Port -> IO UDPTrackerHandle
-startSession ip port = do
+startSession :: String -> Port -> IO UDPTrackerHandle
+startSession host port = do
   s <- socket AF_INET Datagram defaultProtocol
   s <- socket AF_INET Datagram defaultProtocol
-  hostAddr <- inet_addr ip
+  addrinfos <- getAddrInfo Nothing (Just host) (Just (show port))
+  let (SockAddrInet p ip) = addrAddress $ head addrinfos
   putStrLn "connected to tracker"
   putStrLn "connected to tracker"
-  r <- randomIO
-  return $ UDPTrackerHandle { sock = s
-                            , tid = r
-                            , addr = (SockAddrInet (fromIntegral port) hostAddr) }
+  return UDPTrackerHandle { sock = s
+                          , addr = (SockAddrInet (fromIntegral port) ip) }
   
   
--- closeSession :: UDPTrackerHandle
-
+closeSession :: UDPTrackerHandle -> IO ()
+closeSession (UDPTrackerHandle s _ _) = close s
+
+trackerLoop :: String -> Port -> String -> ByteString -> FS.MsgChannel -> TState -> IO ()
+trackerLoop url sport peerId infohash fschan tstate = forever $ do
+  st <- readMVar <$> FS.getStats fschan
+  up <- fmap FS.bytesRead st
+  down <- fmap FS.bytesWritten st
+  handle <- startSession host port
+  stats <- timeout (15*(10^6)) $ worker handle up down
+  case stats of
+    Nothing -> closeSession handle
+    Just stats' -> do
+      ps <- isEmptyMVar $ connectedPeers tstate
+      if ps
+        then
+        putMVar (connectedPeers tstate) (peers stats')
+        else
+        void $ swapMVar (connectedPeers tstate) (peers stats')
+      threadDelay $ fromIntegral (interval stats') * (10^6)
+      return ()
+  where
+    port = getPort url
+    host = getHostname url
+    worker handle up down = flip runReaderT handle $ do
+      t1 <- connectRequest
+      cid <- connectResponse t1
+      t2 <- announceRequest cid infohash peerId (fromIntegral up) (fromIntegral down) (fromIntegral (left tstate)) (fromIntegral sport)
+      stats <- announceResponse t2
+      return stats