Skip to content

Commit

Permalink
Use sendmsg on windows as well
Browse files Browse the repository at this point in the history
  • Loading branch information
bltavares committed Jun 30, 2020
1 parent 2095ec7 commit e7a446f
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 41 deletions.
2 changes: 1 addition & 1 deletion src/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ fn create_on_interfaces(
socket.bind(&SocketAddr::from(multicast_address).into())?;
// Otherwhise we bind to 0.0.0.0
#[cfg(not(any(target_os = "linux", target_os = "android")))]
socket.bind(&SocketAddr::from(SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 5353)).into())?;
socket.bind(&SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), multicast_address.port()).into())?;

Ok(MulticastSocket {
socket,
Expand Down
172 changes: 132 additions & 40 deletions src/win.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,16 @@ use std::time::Duration;
use socket2::{Domain, Protocol, Socket, Type};

use winapi::ctypes::{c_char, c_int};
use winapi::shared::inaddr::*;
use winapi::shared::minwindef::DWORD;
use winapi::shared::minwindef::{INT, LPDWORD};
use winapi::shared::ws2def::LPWSAMSG;
use winapi::shared::ws2def::*;
use winapi::shared::ws2ipdef::*;
use winapi::um::mswsock::{LPFN_WSARECVMSG, WSAID_WSARECVMSG};
use winapi::um::mswsock::{LPFN_WSARECVMSG, LPFN_WSASENDMSG, WSAID_WSARECVMSG, WSAID_WSASENDMSG};
use winapi::um::winsock2 as sock;
use winapi::um::winsock2::{LPWSAOVERLAPPED, LPWSAOVERLAPPED_COMPLETION_ROUTINE, SOCKET};

/// On Windows, unlike all Unix variants, it is improper to bind to the multicast address
///
/// see https://msdn.microsoft.com/en-us/library/windows/desktop/ms737550(v=vs.85).aspx
fn bind_multicast(socket: &Socket, addr: &SocketAddr) -> io::Result<()> {
let addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), addr.port());
socket.bind(&socket2::SockAddr::from(addr))
}

fn last_error() -> io::Error {
io::Error::from_raw_os_error(unsafe { sock::WSAGetLastError() })
}
Expand All @@ -34,14 +27,7 @@ where
T: Copy,
{
let payload = &payload as *const T as *const c_char;
if sock::setsockopt(
socket as usize,
opt,
val,
payload,
mem::size_of::<T>() as c_int,
) == 0
{
if sock::setsockopt(socket as _, opt, val, payload, mem::size_of::<T>() as c_int) == 0 {
Ok(())
} else {
Err(last_error())
Expand All @@ -57,12 +43,12 @@ type WSARecvMsgExtension = unsafe extern "system" fn(
) -> INT;

fn locate_wsarecvmsg(socket: RawSocket) -> io::Result<WSARecvMsgExtension> {
let mut fn_pointer = 0 as usize;
let mut fn_pointer: usize = 0;
let mut byte_len: u32 = 0;

let r = unsafe {
sock::WSAIoctl(
socket as usize,
socket as _,
SIO_GET_EXTENSION_FUNCTION_POINTER,
&mut WSAID_WSARECVMSG as *const _ as *mut _,
mem::size_of_val(&WSAID_WSARECVMSG) as DWORD,
Expand All @@ -77,7 +63,7 @@ fn locate_wsarecvmsg(socket: RawSocket) -> io::Result<WSARecvMsgExtension> {
return Err(io::Error::last_os_error());
}

if mem::size_of::<LPFN_WSARECVMSG>() != byte_len as usize {
if mem::size_of::<LPFN_WSARECVMSG>() != byte_len as _ {
return Err(io::Error::new(
io::ErrorKind::Other,
"Locating fn pointer to WSARecvMsg returned different expected bytes",
Expand All @@ -94,6 +80,53 @@ fn locate_wsarecvmsg(socket: RawSocket) -> io::Result<WSARecvMsgExtension> {
}
}

type WSASendMsgExtension = unsafe extern "system" fn(
s: SOCKET,
lpMsg: LPWSAMSG,
dwFlags: DWORD,
lpNumberOfBytesSent: LPDWORD,
lpOverlapped: LPWSAOVERLAPPED,
lpCompletionRoutine: LPWSAOVERLAPPED_COMPLETION_ROUTINE,
) -> INT;

fn locate_wsasendmsg(socket: RawSocket) -> io::Result<WSASendMsgExtension> {
let mut fn_pointer: usize = 0;
let mut byte_len: u32 = 0;

let r = unsafe {
sock::WSAIoctl(
socket as _,
SIO_GET_EXTENSION_FUNCTION_POINTER,
&mut WSAID_WSASENDMSG as *const _ as *mut _,
mem::size_of_val(&WSAID_WSASENDMSG) as DWORD,
&mut fn_pointer as *const _ as *mut _,
mem::size_of_val(&fn_pointer) as DWORD,
&mut byte_len,
ptr::null_mut(),
None,
)
};
if r != 0 {
return Err(io::Error::last_os_error());
}

if mem::size_of::<LPFN_WSASENDMSG>() != byte_len as _ {
return Err(io::Error::new(
io::ErrorKind::Other,
"Locating fn pointer to WSASendMsg returned different expected bytes",
));
}
let cast_to_fn: LPFN_WSASENDMSG = unsafe { mem::transmute(fn_pointer) };

match cast_to_fn {
None => Err(io::Error::new(
io::ErrorKind::Other,
"WSASendMsg extension not foud",
)),
Some(extension) => Ok(extension),
}
}

fn set_pktinfo(socket: RawSocket, payload: bool) -> io::Result<()> {
unsafe { setsockopt(socket, IPPROTO_IP, IP_PKTINFO, payload as c_int) }
}
Expand Down Expand Up @@ -127,17 +160,21 @@ fn create_on_interfaces(
// enable fetching interface information and locate the extension function
set_pktinfo(socket.as_raw_socket(), true)?;
let wsarecvmsg: WSARecvMsgExtension = locate_wsarecvmsg(socket.as_raw_socket())?;
let wsasendmsg: WSASendMsgExtension = locate_wsasendmsg(socket.as_raw_socket())?;

// Join multicast listeners on every interface passed
for interface in &interfaces {
socket.join_multicast_v4(multicast_address.ip(), &interface)?;
}

bind_multicast(&socket, &multicast_address.into())?;
// On Windows, unlike all Unix variants, it is improper to bind to the multicast address
// see https://msdn.microsoft.com/en-us/library/windows/desktop/ms737550(v=vs.85).aspx
socket.bind(&SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), multicast_address.port()).into())?;

Ok(MulticastSocket {
socket,
wsarecvmsg,
wsasendmsg,
interfaces,
multicast_address,
buffer_size: options.buffer_size,
Expand All @@ -147,6 +184,7 @@ fn create_on_interfaces(
pub struct MulticastSocket {
socket: socket2::Socket,
wsarecvmsg: WSARecvMsgExtension,
wsasendmsg: WSASendMsgExtension,
interfaces: Vec<Ipv4Addr>,
multicast_address: SocketAddrV4,
buffer_size: usize,
Expand Down Expand Up @@ -224,7 +262,7 @@ impl MulticastSocket {
let r = {
unsafe {
(self.wsarecvmsg)(
self.socket.as_raw_socket() as usize,
self.socket.as_raw_socket() as _,
&mut wsa_msg,
&mut read_bytes,
ptr::null_mut(),
Expand Down Expand Up @@ -262,7 +300,7 @@ impl MulticastSocket {
};

Ok(Message {
data: data_buffer[0..read_bytes as usize]
data: data_buffer[0..read_bytes as _]
.iter()
.map(|i| *i as u8)
.collect(),
Expand All @@ -272,21 +310,74 @@ impl MulticastSocket {
}

pub fn send(&self, buf: &[u8], interface: &Interface) -> io::Result<usize> {
let mut pkt_info: IN_PKTINFO = unsafe { mem::zeroed() };
match interface {
Interface::Default => self.socket.set_multicast_if_v4(&Ipv4Addr::UNSPECIFIED)?,
Interface::Ip(address) => self.socket.set_multicast_if_v4(address)?,
Interface::Index(index) => unsafe {
setsockopt(
self.socket.as_raw_socket(),
IPPROTO_IP,
IP_MULTICAST_IF,
interface_index_to_24bit_netorder(*index),
)?
},
Interface::Default => {}
Interface::Ip(address) => {
pkt_info.ipi_addr = IN_ADDR {
S_un: to_s_addr(address),
};
}
Interface::Index(index) => {
pkt_info.ipi_ifindex = *index;
}
};

let mut data = WSABUF {
buf: buf.as_ptr() as *mut _,
len: buf.len() as _,
};

let mut control_buffer = [0; CONTROL_PKTINFO_BUFFER_SIZE];
let hdr = CMSGHDR {
cmsg_len: CONTROL_PKTINFO_BUFFER_SIZE,
cmsg_level: IPPROTO_IP,
cmsg_type: IP_PKTINFO,
};
unsafe {
ptr::copy(
&hdr as *const _ as *const _,
control_buffer.as_mut_ptr(),
CMSG_HEADER_SIZE,
);
ptr::copy(
&pkt_info as *const _ as *const _,
control_buffer.as_mut_ptr().add(CMSG_HEADER_SIZE),
PKTINFO_DATA_SIZE,
)
};
let control = WSABUF {
buf: control_buffer.as_mut_ptr(),
len: control_buffer.len() as _,
};

let destination = socket2::SockAddr::from(self.multicast_address);
let destination_address = destination.as_ptr();
let mut wsa_msg = WSAMSG {
name: destination_address as *mut _,
namelen: destination.len(),
lpBuffers: &mut data,
Control: control,
dwBufferCount: 1,
dwFlags: 0,
};

let mut sent_bytes = 0;
let r = unsafe {
(self.wsasendmsg)(
self.socket.as_raw_socket() as _,
&mut wsa_msg,
0,
&mut sent_bytes,
ptr::null_mut(),
None,
)
};
if r != 0 {
return Err(io::Error::last_os_error());
}

self.socket
.send_to(buf, &SocketAddr::from(self.multicast_address).into())
Ok(sent_bytes as _)
}

pub fn broadcast(&self, buf: &[u8]) -> io::Result<()> {
Expand All @@ -297,9 +388,10 @@ impl MulticastSocket {
}
}

fn interface_index_to_24bit_netorder(index: u32) -> DWORD {
let index = index.to_be();
index & 0x00ff_0000 >> 16 as u8
| index & 0x0000_ff00 >> 8 as u8
| index & 0x0000_00ff >> 0 as u8
fn to_s_addr(addr: &Ipv4Addr) -> in_addr_S_un {
let octets = addr.octets();
let res = u32::from_ne_bytes(octets);
let mut new_addr: in_addr_S_un = unsafe { mem::zeroed() };
unsafe { *(new_addr.S_addr_mut()) = res };
new_addr
}

0 comments on commit e7a446f

Please # to comment.