diff --git a/Cargo.toml b/Cargo.toml index d7a27e4..3516604 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,7 @@ features = ['reuseport'] [target.'cfg(windows)'.dependencies.winapi] version = '0.3.9' -features = ['mswsock'] +features = ['mswsock', 'iphlpapi'] [target.'cfg(not(windows))'.dependencies.nix] # Needs https://github.com/nix-rust/nix/pull/1265 which is unreleased as of today diff --git a/src/win.rs b/src/win.rs index 06b5ff1..a838b40 100644 --- a/src/win.rs +++ b/src/win.rs @@ -1,8 +1,12 @@ +use std::collections::{HashMap, HashSet}; +use std::ffi::CStr; use std::io; +use std::iter::FromIterator; use std::mem; use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; use std::os::windows::prelude::*; use std::ptr; +use std::str::FromStr; use socket2::{Domain, Protocol, Socket, Type}; @@ -13,6 +17,7 @@ use winapi::shared::minwindef::{INT, LPDWORD}; use winapi::shared::ws2def::LPWSAMSG; use winapi::shared::ws2def::*; use winapi::shared::ws2ipdef::*; +use winapi::um::iptypes; use winapi::um::mswsock::{LPFN_WSARECVMSG, LPFN_WSASENDMSG, WSAID_WSARECVMSG, WSAID_WSASENDMSG}; use winapi::um::winsock2 as sock; use winapi::um::winsock2::{LPWSAOVERLAPPED, LPWSAOVERLAPPED_COMPLETION_ROUTINE, SOCKET}; @@ -154,6 +159,8 @@ fn create_on_interfaces( // see https://msdn.microsoft.com/en-us/library/windows/desktop/ms737550(v=vs.85).aspx socket.bind(&SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), multicast_address.port()).into())?; + let interfaces = build_address_table(HashSet::from_iter(interfaces))?; + Ok(MulticastSocket { socket, wsarecvmsg, @@ -164,11 +171,58 @@ fn create_on_interfaces( }) } +/// Defines a allocation size for the buffer +/// That seems like a pretty good number for most cases +/// If things break, we can allocate the buffer a vec and try to double on error +const MAX_AMOUNT_OF_INTERFACES: usize = 16; + +fn build_address_table(interfaces: HashSet) -> io::Result> { + let mut buffer = [0; mem::size_of::() * MAX_AMOUNT_OF_INTERFACES]; + let mut adapter_info = buffer.as_mut_ptr() as iptypes::PIP_ADAPTER_INFO; + let mut size = buffer.len() as u32; + let r = unsafe { winapi::um::iphlpapi::GetAdaptersInfo(adapter_info, &mut size) }; + + if r != 0 { + return Err(io::Error::last_os_error()); + } + + let mut table = HashMap::with_capacity(interfaces.len()); + loop { + if adapter_info.is_null() { + break; + } + + let current: iptypes::IP_ADAPTER_INFO = unsafe { *adapter_info }; + let ip_address = + unsafe { CStr::from_ptr(current.IpAddressList.IpAddress.String.as_ptr()) }.to_str(); + let ip_address = match ip_address { + Ok(i) => Ipv4Addr::from_str(&i), + _ => { + continue; + } + }; + let ip_address = match ip_address { + Ok(i) => i, + _ => { + continue; + } + }; + + if interfaces.contains(&ip_address) { + table.insert(current.Index, ip_address); + } + + adapter_info = current.Next; + } + + Ok(table) +} + pub struct MulticastSocket { socket: socket2::Socket, wsarecvmsg: WSARecvMsgExtension, wsasendmsg: WSASendMsgExtension, - interfaces: Vec, + interfaces: HashMap, multicast_address: SocketAddrV4, buffer_size: usize, } @@ -293,17 +347,20 @@ impl MulticastSocket { } pub fn send(&self, buf: &[u8], interface: &Interface) -> io::Result { - let mut pkt_info: IN_PKTINFO = unsafe { mem::zeroed() }; - match interface { - Interface::Default => {} - Interface::Ip(address) => { - pkt_info.ipi_addr = IN_ADDR { + let pkt_info = match interface { + Interface::Default => None, + Interface::Ip(address) => Some(IN_PKTINFO { + ipi_addr: IN_ADDR { S_un: to_s_addr(address), - }; - } - Interface::Index(index) => { - pkt_info.ipi_ifindex = *index; - } + }, + ipi_ifindex: 0, + }), + Interface::Index(index) => self.interfaces.get(index).map(|address| IN_PKTINFO { + ipi_addr: IN_ADDR { + S_un: to_s_addr(address), + }, + ipi_ifindex: *index, + }), }; let mut data = WSABUF { @@ -311,27 +368,34 @@ impl MulticastSocket { len: buf.len() as _, }; - let mut control_buffer = [0; CONTROL_PKTINFO_BUFFER_SIZE]; - let hdr = CMSGHDR { - cmsg_len: CONTROL_PKTINFO_BUFFER_SIZE, - cmsg_level: IPPROTO_IP, - cmsg_type: IP_PKTINFO, - }; - unsafe { - ptr::copy( - &hdr as *const _ as *const _, - control_buffer.as_mut_ptr(), - CMSG_HEADER_SIZE, - ); - ptr::copy( - &pkt_info as *const _ as *const _, - control_buffer.as_mut_ptr().add(CMSG_HEADER_SIZE), - PKTINFO_DATA_SIZE, - ) - }; - let control = WSABUF { - buf: control_buffer.as_mut_ptr(), - len: control_buffer.len() as _, + let control = if let Some(pkt_info) = pkt_info { + let mut control_buffer = [0; CONTROL_PKTINFO_BUFFER_SIZE]; + let hdr = CMSGHDR { + cmsg_len: CONTROL_PKTINFO_BUFFER_SIZE, + cmsg_level: IPPROTO_IP, + cmsg_type: IP_PKTINFO, + }; + unsafe { + ptr::copy( + &hdr as *const _ as *const _, + control_buffer.as_mut_ptr(), + CMSG_HEADER_SIZE, + ); + ptr::copy( + &pkt_info as *const _ as *const _, + control_buffer.as_mut_ptr().add(CMSG_HEADER_SIZE), + PKTINFO_DATA_SIZE, + ) + }; + WSABUF { + buf: control_buffer.as_mut_ptr(), + len: control_buffer.len() as _, + } + } else { + WSABUF { + buf: [].as_mut_ptr(), + len: 0, + } }; let destination = socket2::SockAddr::from(self.multicast_address); @@ -364,7 +428,7 @@ impl MulticastSocket { } pub fn broadcast(&self, buf: &[u8]) -> io::Result<()> { - for interface in &self.interfaces { + for interface in self.interfaces.values() { self.send(buf, &Interface::Ip(*interface))?; } Ok(())