diff --git a/src/unix.rs b/src/unix.rs index 225c02a..8bce150 100644 --- a/src/unix.rs +++ b/src/unix.rs @@ -48,7 +48,7 @@ fn create_on_interfaces( socket.bind(&SocketAddr::from(multicast_address).into())?; // Otherwhise we bind to 0.0.0.0 #[cfg(not(any(target_os = "linux", target_os = "android")))] - socket.bind(&SocketAddr::from(SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 5353)).into())?; + socket.bind(&SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), multicast_address.port()).into())?; Ok(MulticastSocket { socket, diff --git a/src/win.rs b/src/win.rs index f9414a5..8c3bc58 100644 --- a/src/win.rs +++ b/src/win.rs @@ -8,23 +8,16 @@ use std::time::Duration; use socket2::{Domain, Protocol, Socket, Type}; use winapi::ctypes::{c_char, c_int}; +use winapi::shared::inaddr::*; use winapi::shared::minwindef::DWORD; use winapi::shared::minwindef::{INT, LPDWORD}; use winapi::shared::ws2def::LPWSAMSG; use winapi::shared::ws2def::*; use winapi::shared::ws2ipdef::*; -use winapi::um::mswsock::{LPFN_WSARECVMSG, WSAID_WSARECVMSG}; +use winapi::um::mswsock::{LPFN_WSARECVMSG, LPFN_WSASENDMSG, WSAID_WSARECVMSG, WSAID_WSASENDMSG}; use winapi::um::winsock2 as sock; use winapi::um::winsock2::{LPWSAOVERLAPPED, LPWSAOVERLAPPED_COMPLETION_ROUTINE, SOCKET}; -/// On Windows, unlike all Unix variants, it is improper to bind to the multicast address -/// -/// see https://msdn.microsoft.com/en-us/library/windows/desktop/ms737550(v=vs.85).aspx -fn bind_multicast(socket: &Socket, addr: &SocketAddr) -> io::Result<()> { - let addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), addr.port()); - socket.bind(&socket2::SockAddr::from(addr)) -} - fn last_error() -> io::Error { io::Error::from_raw_os_error(unsafe { sock::WSAGetLastError() }) } @@ -34,14 +27,7 @@ where T: Copy, { let payload = &payload as *const T as *const c_char; - if sock::setsockopt( - socket as usize, - opt, - val, - payload, - mem::size_of::() as c_int, - ) == 0 - { + if sock::setsockopt(socket as _, opt, val, payload, mem::size_of::() as c_int) == 0 { Ok(()) } else { Err(last_error()) @@ -57,12 +43,12 @@ type WSARecvMsgExtension = unsafe extern "system" fn( ) -> INT; fn locate_wsarecvmsg(socket: RawSocket) -> io::Result { - let mut fn_pointer = 0 as usize; + let mut fn_pointer: usize = 0; let mut byte_len: u32 = 0; let r = unsafe { sock::WSAIoctl( - socket as usize, + socket as _, SIO_GET_EXTENSION_FUNCTION_POINTER, &mut WSAID_WSARECVMSG as *const _ as *mut _, mem::size_of_val(&WSAID_WSARECVMSG) as DWORD, @@ -77,7 +63,7 @@ fn locate_wsarecvmsg(socket: RawSocket) -> io::Result { return Err(io::Error::last_os_error()); } - if mem::size_of::() != byte_len as usize { + if mem::size_of::() != byte_len as _ { return Err(io::Error::new( io::ErrorKind::Other, "Locating fn pointer to WSARecvMsg returned different expected bytes", @@ -94,6 +80,53 @@ fn locate_wsarecvmsg(socket: RawSocket) -> io::Result { } } +type WSASendMsgExtension = unsafe extern "system" fn( + s: SOCKET, + lpMsg: LPWSAMSG, + dwFlags: DWORD, + lpNumberOfBytesSent: LPDWORD, + lpOverlapped: LPWSAOVERLAPPED, + lpCompletionRoutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE, +) -> INT; + +fn locate_wsasendmsg(socket: RawSocket) -> io::Result { + let mut fn_pointer: usize = 0; + let mut byte_len: u32 = 0; + + let r = unsafe { + sock::WSAIoctl( + socket as _, + SIO_GET_EXTENSION_FUNCTION_POINTER, + &mut WSAID_WSASENDMSG as *const _ as *mut _, + mem::size_of_val(&WSAID_WSASENDMSG) as DWORD, + &mut fn_pointer as *const _ as *mut _, + mem::size_of_val(&fn_pointer) as DWORD, + &mut byte_len, + ptr::null_mut(), + None, + ) + }; + if r != 0 { + return Err(io::Error::last_os_error()); + } + + if mem::size_of::() != byte_len as _ { + return Err(io::Error::new( + io::ErrorKind::Other, + "Locating fn pointer to WSASendMsg returned different expected bytes", + )); + } + let cast_to_fn: LPFN_WSASENDMSG = unsafe { mem::transmute(fn_pointer) }; + + match cast_to_fn { + None => Err(io::Error::new( + io::ErrorKind::Other, + "WSASendMsg extension not foud", + )), + Some(extension) => Ok(extension), + } +} + fn set_pktinfo(socket: RawSocket, payload: bool) -> io::Result<()> { unsafe { setsockopt(socket, IPPROTO_IP, IP_PKTINFO, payload as c_int) } } @@ -127,17 +160,21 @@ fn create_on_interfaces( // enable fetching interface information and locate the extension function set_pktinfo(socket.as_raw_socket(), true)?; let wsarecvmsg: WSARecvMsgExtension = locate_wsarecvmsg(socket.as_raw_socket())?; + let wsasendmsg: WSASendMsgExtension = locate_wsasendmsg(socket.as_raw_socket())?; // Join multicast listeners on every interface passed for interface in &interfaces { socket.join_multicast_v4(multicast_address.ip(), &interface)?; } - bind_multicast(&socket, &multicast_address.into())?; + // On Windows, unlike all Unix variants, it is improper to bind to the multicast address + // see https://msdn.microsoft.com/en-us/library/windows/desktop/ms737550(v=vs.85).aspx + socket.bind(&SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), multicast_address.port()).into())?; Ok(MulticastSocket { socket, wsarecvmsg, + wsasendmsg, interfaces, multicast_address, buffer_size: options.buffer_size, @@ -147,6 +184,7 @@ fn create_on_interfaces( pub struct MulticastSocket { socket: socket2::Socket, wsarecvmsg: WSARecvMsgExtension, + wsasendmsg: WSASendMsgExtension, interfaces: Vec, multicast_address: SocketAddrV4, buffer_size: usize, @@ -224,7 +262,7 @@ impl MulticastSocket { let r = { unsafe { (self.wsarecvmsg)( - self.socket.as_raw_socket() as usize, + self.socket.as_raw_socket() as _, &mut wsa_msg, &mut read_bytes, ptr::null_mut(), @@ -262,7 +300,7 @@ impl MulticastSocket { }; Ok(Message { - data: data_buffer[0..read_bytes as usize] + data: data_buffer[0..read_bytes as _] .iter() .map(|i| *i as u8) .collect(), @@ -272,21 +310,74 @@ impl MulticastSocket { } pub fn send(&self, buf: &[u8], interface: &Interface) -> io::Result { + let mut pkt_info: IN_PKTINFO = unsafe { mem::zeroed() }; match interface { - Interface::Default => self.socket.set_multicast_if_v4(&Ipv4Addr::UNSPECIFIED)?, - Interface::Ip(address) => self.socket.set_multicast_if_v4(address)?, - Interface::Index(index) => unsafe { - setsockopt( - self.socket.as_raw_socket(), - IPPROTO_IP, - IP_MULTICAST_IF, - interface_index_to_24bit_netorder(*index), - )? - }, + Interface::Default => {} + Interface::Ip(address) => { + pkt_info.ipi_addr = IN_ADDR { + S_un: to_s_addr(address), + }; + } + Interface::Index(index) => { + pkt_info.ipi_ifindex = *index; + } + }; + + let mut data = WSABUF { + buf: buf.as_ptr() as *mut _, + len: buf.len() as _, + }; + + let mut control_buffer = [0; CONTROL_PKTINFO_BUFFER_SIZE]; + let hdr = CMSGHDR { + cmsg_len: CONTROL_PKTINFO_BUFFER_SIZE, + cmsg_level: IPPROTO_IP, + cmsg_type: IP_PKTINFO, + }; + unsafe { + ptr::copy( + &hdr as *const _ as *const _, + control_buffer.as_mut_ptr(), + CMSG_HEADER_SIZE, + ); + ptr::copy( + &pkt_info as *const _ as *const _, + control_buffer.as_mut_ptr().add(CMSG_HEADER_SIZE), + PKTINFO_DATA_SIZE, + ) }; + let control = WSABUF { + buf: control_buffer.as_mut_ptr(), + len: control_buffer.len() as _, + }; + + let destination = socket2::SockAddr::from(self.multicast_address); + let destination_address = destination.as_ptr(); + let mut wsa_msg = WSAMSG { + name: destination_address as *mut _, + namelen: destination.len(), + lpBuffers: &mut data, + Control: control, + dwBufferCount: 1, + dwFlags: 0, + }; + + let mut sent_bytes = 0; + let r = unsafe { + (self.wsasendmsg)( + self.socket.as_raw_socket() as _, + &mut wsa_msg, + 0, + &mut sent_bytes, + ptr::null_mut(), + None, + ) + }; + if r != 0 { + return Err(io::Error::last_os_error()); + } - self.socket - .send_to(buf, &SocketAddr::from(self.multicast_address).into()) + Ok(sent_bytes as _) } pub fn broadcast(&self, buf: &[u8]) -> io::Result<()> { @@ -297,9 +388,10 @@ impl MulticastSocket { } } -fn interface_index_to_24bit_netorder(index: u32) -> DWORD { - let index = index.to_be(); - index & 0x00ff_0000 >> 16 as u8 - | index & 0x0000_ff00 >> 8 as u8 - | index & 0x0000_00ff >> 0 as u8 +fn to_s_addr(addr: &Ipv4Addr) -> in_addr_S_un { + let octets = addr.octets(); + let res = u32::from_ne_bytes(octets); + let mut new_addr: in_addr_S_un = unsafe { mem::zeroed() }; + unsafe { *(new_addr.S_addr_mut()) = res }; + new_addr }