diff --git a/torrent-client/src/client/mod.rs b/torrent-client/src/client/mod.rs index 3684773..169615c 100644 --- a/torrent-client/src/client/mod.rs +++ b/torrent-client/src/client/mod.rs @@ -1,10 +1,12 @@ use crate::file::TorrentFile; +use crate::peer::connection::PeerConnection; use crate::peer::PeerId; use crate::tracker::{AnnounceParameters, RequestMode, TrackerClient, TrackerError}; use rand::seq::SliceRandom; -use std::io::{Read, Write}; +use std::collections::VecDeque; use std::net::TcpStream; -use std::str::from_utf8; +use std::sync::{Arc, Mutex}; +use std::time::Duration; use thiserror::Error; #[derive(Error, Debug)] @@ -15,7 +17,18 @@ pub enum ClientError { type Result = std::result::Result; #[derive(Default, Debug)] -pub struct Config {} +pub struct Config { + connection_numbers: usize, +} + +impl Config { + pub fn new(connection_numbers: usize) -> Self { + if connection_numbers == 0 { + panic!("connection numbers cannot be zero") + } + Self { connection_numbers } + } +} pub struct Client { client_id: PeerId, @@ -41,28 +54,42 @@ impl Client { let mut distribution = self.tracker_client.announce(meta.announce, params)?; let mut rng = rand::thread_rng(); distribution.peers.shuffle(&mut rng); - let peer = distribution.peers.pop().unwrap(); + let peers = Arc::new(Mutex::new(VecDeque::from(distribution.peers))); + let mut handles = vec![]; + for worker_id in 0..self.config.connection_numbers { + let peers_a = peers.clone(); + let client_id = self.client_id.clone(); + let handle = std::thread::spawn(move || { + let mut q = peers_a.lock().unwrap(); + if q.len() == 0 { + println!("thread {worker_id} closes due no peers"); + } + let peer = q.pop_back().unwrap(); + drop(q); + println!("{:#?}", peer); - println!("{:#?}", peer); - let mut message = Vec::new(); - message.push(19u8); - message.extend_from_slice(b"BitTorrent protocol"); - message.extend_from_slice(b"\0\0\0\0\0\0\0\0"); - message.extend_from_slice(meta.info.info_hash.as_slice()); - message.extend_from_slice(self.client_id.as_slice()); - for byte in &message { - print!("\\x{:0x} ", byte); + println!(); + let connection = TcpStream::connect_timeout(&peer.addr, Duration::from_secs(5)); + if connection.is_err() { + println!("timeout "); + return; + } + let connection = connection.unwrap(); + let bt_conn = PeerConnection::handshake( + connection, + &meta.info.info_hash.clone(), + &client_id, + ); + match bt_conn { + Ok(_) => println!("conn ok"), + Err(e) => println!("err {}", e.to_string()), + } + }); + handles.push(handle); + } + while let Some(cur_thread) = handles.pop() { + cur_thread.join().unwrap(); } - println!(); - assert_eq!(message.len(), 68); - let mut connection = TcpStream::connect(peer.addr).unwrap(); - println!("connected"); - let _ = connection.write_all(message.as_slice()).unwrap(); - let mut response: [u8; 68] = [0; 68]; - let read = connection.read(response.as_mut_slice()).unwrap(); - println!("{read}"); - println!("{:?}", from_utf8(response.as_slice()[48..68].as_ref())); - Ok(()) } } diff --git a/torrent-client/src/file/mod.rs b/torrent-client/src/file/mod.rs index 6746798..35429aa 100644 --- a/torrent-client/src/file/mod.rs +++ b/torrent-client/src/file/mod.rs @@ -5,9 +5,9 @@ use thiserror::Error; use url::Url; use bencode::{BencodeEncoder, BencodeError, BencodeList, BencodeString, Value}; -use torrent_client::Sha1; use crate::file::TorrentError::{IntegerOutOfBound, InvalidInfoHash, MissingField}; +use crate::util::Sha1; type Result = std::result::Result; @@ -115,7 +115,6 @@ impl Info { files.push(File::from_bencode(file.try_into()?)?); } } - // TODO: add pieces field Ok(Info { files, name, diff --git a/torrent-client/src/main.rs b/torrent-client/src/main.rs index c8e1fb5..f190f3f 100644 --- a/torrent-client/src/main.rs +++ b/torrent-client/src/main.rs @@ -12,6 +12,7 @@ mod client; mod file; mod peer; mod tracker; +mod util; fn main() { let cli = cli::Args::parse(); @@ -31,7 +32,7 @@ fn main() { let client_id = PeerId::random(); let tracker = Box::new(HttpTracker::new(&client_id).unwrap()); - let client = Client::new(client_id, Config::default(), tracker); + let client = Client::new(client_id, Config::new(25), tracker); let res = client.download(torrent); println!("{res:#?}"); diff --git a/torrent-client/src/peer/connection.rs b/torrent-client/src/peer/connection.rs new file mode 100644 index 0000000..144ee3b --- /dev/null +++ b/torrent-client/src/peer/connection.rs @@ -0,0 +1,157 @@ +use crate::peer::connection::ConnectionError::HandshakeFailed; +use crate::peer::connection::HandshakeMessageError::{ProtocolString, ProtocolStringLen}; +use crate::peer::PeerId; +use crate::util::Sha1; +use bytes::{Buf, BufMut}; +use std::io; +use std::io::{Read, Write}; +use std::net::TcpStream; +use thiserror::Error; + +type Result = std::result::Result; + +static BIT_TORRENT_PROTOCOL_STRING: &[u8; 19] = b"BitTorrent protocol"; + +#[derive(Error, Debug)] +enum HandshakeMessageError { + #[error("Invalid protocol string(pstr) length, expected 19, but got {0}")] + ProtocolStringLen(u8), + #[error("Unexpected protocol string, expected \"BitTorrent protocol\", but got {0}")] + ProtocolString(String), +} + +#[derive(Debug, PartialEq, Clone)] +struct HandshakeMessage { + // need to replace with appropriate structure + extension_bytes: [u8; 8], + info_hash: Sha1, + peer_id: PeerId, +} + +impl HandshakeMessage { + fn to_bytes(&self) -> Box<[u8; 68]> { + let mut res = Box::new([0; 68]); + res[0] = 19u8; + res[1..20].copy_from_slice(BIT_TORRENT_PROTOCOL_STRING.as_slice()); + res[20..28].copy_from_slice(self.extension_bytes.as_slice()); + res[28..48].copy_from_slice(self.info_hash.as_slice()); + res[48..68].copy_from_slice(self.peer_id.as_slice()); + res + } + + fn from_bytes(raw: Box<[u8; 68]>) -> std::result::Result { + let pstr_len = raw[0]; + if pstr_len != 19 { + return Err(ProtocolStringLen(pstr_len)); + } + let pstr: [u8; 19] = raw[1..20].try_into().unwrap(); + if pstr.as_slice() != BIT_TORRENT_PROTOCOL_STRING { + return Err(ProtocolString( + String::from_utf8_lossy(pstr.as_slice()).to_string(), + )); + } + let extension_bytes: [u8; 8] = raw[20..28].try_into().expect("Slice with incorrect length"); + let info_hash: [u8; 20] = raw[28..48].try_into().expect("Slice with incorrect length"); + let peer_id: [u8; 20] = raw[48..68].try_into().expect("Slice with incorrect length"); + + Ok(Self::new(extension_bytes, info_hash, PeerId::new(peer_id))) + } + + pub fn new(extension_bytes: [u8; 8], info_hash: Sha1, peer_id: PeerId) -> Self { + Self { + extension_bytes, + info_hash, + peer_id, + } + } +} + +#[derive(Error, Debug)] +pub enum ConnectionError { + #[error("BitTorrent handshake failed {0}")] + HandshakeFailed(String), + #[error("Error in parsing handshake response {0}")] + HandshakeResponse(#[from] HandshakeMessageError), + #[error(transparent)] + IoKind(#[from] io::Error), +} + +pub struct PeerConnection { + tcp_connection: TcpStream, + peer_id: PeerId, +} + +impl PeerConnection { + pub fn handshake( + mut tcp_connection: TcpStream, + info_hash: &Sha1, + peer_id: &PeerId, + ) -> Result { + let mut bytes = + HandshakeMessage::new([0; 8], info_hash.clone(), peer_id.clone()).to_bytes(); + let _ = tcp_connection.write_all(bytes.as_ref())?; + let read_bytes = tcp_connection.read(bytes.as_mut())?; + if read_bytes != 68 { + return Err(HandshakeFailed(format!("Invalid bytes count received {read_bytes}"))) + } + + let response = HandshakeMessage::from_bytes(bytes)?; + + Ok(Self { + tcp_connection, + peer_id: response.peer_id, + }) + } +} + +#[cfg(test)] +mod tests { + use crate::peer::connection::{HandshakeMessage, BIT_TORRENT_PROTOCOL_STRING}; + use crate::peer::PeerId; + use bytes::{BufMut, BytesMut}; + use rand::RngCore; + + #[test] + fn handshake_message_as_bytes() { + let mut extensions_bytes = [0; 8]; + rand::thread_rng().fill_bytes(&mut extensions_bytes); + let mut info_hash = [0; 20]; + rand::thread_rng().fill_bytes(&mut info_hash); + let peed_id = PeerId::random(); + + let mut bytes = BytesMut::with_capacity(68); + bytes.put_u8(19u8); + bytes.extend_from_slice(BIT_TORRENT_PROTOCOL_STRING); + bytes.extend_from_slice(extensions_bytes.as_slice()); + bytes.extend_from_slice(info_hash.as_slice()); + bytes.extend_from_slice(peed_id.as_ref()); + + let message = HandshakeMessage::new(extensions_bytes, info_hash, peed_id); + let message_bytes = message.to_bytes(); + + assert_eq!(bytes.as_ref(), message_bytes.as_slice()); + } + + #[test] + fn handshake_message_from_bytes() { + let mut extensions_bytes = [0; 8]; + rand::thread_rng().fill_bytes(&mut extensions_bytes); + let mut info_hash = [0; 20]; + rand::thread_rng().fill_bytes(&mut info_hash); + let peed_id = PeerId::random(); + + let mut bytes = BytesMut::with_capacity(68); + bytes.put_u8(19u8); + bytes.extend_from_slice(BIT_TORRENT_PROTOCOL_STRING); + bytes.extend_from_slice(extensions_bytes.as_slice()); + bytes.extend_from_slice(info_hash.as_slice()); + bytes.extend_from_slice(peed_id.as_ref()); + + let message = HandshakeMessage::new(extensions_bytes, info_hash, peed_id); + + let message_from_bytes = + HandshakeMessage::from_bytes(bytes.to_vec().try_into().unwrap()).unwrap(); + + assert_eq!(message_from_bytes, message) + } +} diff --git a/torrent-client/src/peer/mod.rs b/torrent-client/src/peer/mod.rs index 0c03145..bf7e332 100644 --- a/torrent-client/src/peer/mod.rs +++ b/torrent-client/src/peer/mod.rs @@ -1,9 +1,11 @@ +pub mod connection; + use rand::RngCore; use std::borrow::Borrow; use std::net::SocketAddr; use std::ops::Deref; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct PeerId([u8; 20]); #[derive(Debug)] diff --git a/torrent-client/src/tracker/mod.rs b/torrent-client/src/tracker/mod.rs index 9d0539d..30b4298 100644 --- a/torrent-client/src/tracker/mod.rs +++ b/torrent-client/src/tracker/mod.rs @@ -2,6 +2,7 @@ use crate::peer::{Peer, PeerId}; use crate::tracker::TrackerError::{ AnnounceRequestError, InternalError, ResponseFormat, TrackerResponse, UnsupportedProtocol, }; +use crate::util::Sha1; use bencode::{BencodeDict, Value}; use bytes::Buf; use percent_encoding::{percent_encode, NON_ALPHANUMERIC}; @@ -9,7 +10,6 @@ use std::fmt::{Display, Formatter}; use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4}; use std::time::Duration; use thiserror::Error; -use torrent_client::Sha1; use url::Url; type Result = std::result::Result; diff --git a/torrent-client/src/lib.rs b/torrent-client/src/util/mod.rs similarity index 52% rename from torrent-client/src/lib.rs rename to torrent-client/src/util/mod.rs index 96f68ac..3080f6f 100644 --- a/torrent-client/src/lib.rs +++ b/torrent-client/src/util/mod.rs @@ -1 +1,3 @@ pub type Sha1 = [u8; 20]; + +pub struct BitField {}