Skip to content

Commit

Permalink
Merge pull request #5 from bltavares/fix-windows
Browse files Browse the repository at this point in the history
Fix windows packet creation
  • Loading branch information
bltavares authored Jul 7, 2020
2 parents 4f1a75a + d6b4990 commit 89065dd
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 34 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ features = ['reuseport']

[target.'cfg(windows)'.dependencies.winapi]
version = '0.3.9'
features = ['mswsock']
features = ['mswsock', 'iphlpapi']

[target.'cfg(not(windows))'.dependencies.nix]
# Needs https://github.com/nix-rust/nix/pull/1265 which is unreleased as of today
Expand Down
130 changes: 97 additions & 33 deletions src/win.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
use std::collections::{HashMap, HashSet};
use std::ffi::CStr;
use std::io;
use std::iter::FromIterator;
use std::mem;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::os::windows::prelude::*;
use std::ptr;
use std::str::FromStr;

use socket2::{Domain, Protocol, Socket, Type};

Expand All @@ -13,6 +17,7 @@ use winapi::shared::minwindef::{INT, LPDWORD};
use winapi::shared::ws2def::LPWSAMSG;
use winapi::shared::ws2def::*;
use winapi::shared::ws2ipdef::*;
use winapi::um::iptypes;
use winapi::um::mswsock::{LPFN_WSARECVMSG, LPFN_WSASENDMSG, WSAID_WSARECVMSG, WSAID_WSASENDMSG};
use winapi::um::winsock2 as sock;
use winapi::um::winsock2::{LPWSAOVERLAPPED, LPWSAOVERLAPPED_COMPLETION_ROUTINE, SOCKET};
Expand Down Expand Up @@ -154,6 +159,8 @@ fn create_on_interfaces(
// see https://msdn.microsoft.com/en-us/library/windows/desktop/ms737550(v=vs.85).aspx
socket.bind(&SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), multicast_address.port()).into())?;

let interfaces = build_address_table(HashSet::from_iter(interfaces))?;

Ok(MulticastSocket {
socket,
wsarecvmsg,
Expand All @@ -164,11 +171,58 @@ fn create_on_interfaces(
})
}

/// Defines a allocation size for the buffer
/// That seems like a pretty good number for most cases
/// If things break, we can allocate the buffer a vec and try to double on error
const MAX_AMOUNT_OF_INTERFACES: usize = 16;

fn build_address_table(interfaces: HashSet<Ipv4Addr>) -> io::Result<HashMap<u32, Ipv4Addr>> {
let mut buffer = [0; mem::size_of::<iptypes::IP_ADAPTER_INFO>() * MAX_AMOUNT_OF_INTERFACES];
let mut adapter_info = buffer.as_mut_ptr() as iptypes::PIP_ADAPTER_INFO;
let mut size = buffer.len() as u32;
let r = unsafe { winapi::um::iphlpapi::GetAdaptersInfo(adapter_info, &mut size) };

if r != 0 {
return Err(io::Error::last_os_error());
}

let mut table = HashMap::with_capacity(interfaces.len());
loop {
if adapter_info.is_null() {
break;
}

let current: iptypes::IP_ADAPTER_INFO = unsafe { *adapter_info };
let ip_address =
unsafe { CStr::from_ptr(current.IpAddressList.IpAddress.String.as_ptr()) }.to_str();
let ip_address = match ip_address {
Ok(i) => Ipv4Addr::from_str(&i),
_ => {
continue;
}
};
let ip_address = match ip_address {
Ok(i) => i,
_ => {
continue;
}
};

if interfaces.contains(&ip_address) {
table.insert(current.Index, ip_address);
}

adapter_info = current.Next;
}

Ok(table)
}

pub struct MulticastSocket {
socket: socket2::Socket,
wsarecvmsg: WSARecvMsgExtension,
wsasendmsg: WSASendMsgExtension,
interfaces: Vec<Ipv4Addr>,
interfaces: HashMap<u32, Ipv4Addr>,
multicast_address: SocketAddrV4,
buffer_size: usize,
}
Expand Down Expand Up @@ -293,45 +347,55 @@ impl MulticastSocket {
}

pub fn send(&self, buf: &[u8], interface: &Interface) -> io::Result<usize> {
let mut pkt_info: IN_PKTINFO = unsafe { mem::zeroed() };
match interface {
Interface::Default => {}
Interface::Ip(address) => {
pkt_info.ipi_addr = IN_ADDR {
let pkt_info = match interface {
Interface::Default => None,
Interface::Ip(address) => Some(IN_PKTINFO {
ipi_addr: IN_ADDR {
S_un: to_s_addr(address),
};
}
Interface::Index(index) => {
pkt_info.ipi_ifindex = *index;
}
},
ipi_ifindex: 0,
}),
Interface::Index(index) => self.interfaces.get(index).map(|address| IN_PKTINFO {
ipi_addr: IN_ADDR {
S_un: to_s_addr(address),
},
ipi_ifindex: *index,
}),
};

let mut data = WSABUF {
buf: buf.as_ptr() as *mut _,
len: buf.len() as _,
};

let mut control_buffer = [0; CONTROL_PKTINFO_BUFFER_SIZE];
let hdr = CMSGHDR {
cmsg_len: CONTROL_PKTINFO_BUFFER_SIZE,
cmsg_level: IPPROTO_IP,
cmsg_type: IP_PKTINFO,
};
unsafe {
ptr::copy(
&hdr as *const _ as *const _,
control_buffer.as_mut_ptr(),
CMSG_HEADER_SIZE,
);
ptr::copy(
&pkt_info as *const _ as *const _,
control_buffer.as_mut_ptr().add(CMSG_HEADER_SIZE),
PKTINFO_DATA_SIZE,
)
};
let control = WSABUF {
buf: control_buffer.as_mut_ptr(),
len: control_buffer.len() as _,
let control = if let Some(pkt_info) = pkt_info {
let mut control_buffer = [0; CONTROL_PKTINFO_BUFFER_SIZE];
let hdr = CMSGHDR {
cmsg_len: CONTROL_PKTINFO_BUFFER_SIZE,
cmsg_level: IPPROTO_IP,
cmsg_type: IP_PKTINFO,
};
unsafe {
ptr::copy(
&hdr as *const _ as *const _,
control_buffer.as_mut_ptr(),
CMSG_HEADER_SIZE,
);
ptr::copy(
&pkt_info as *const _ as *const _,
control_buffer.as_mut_ptr().add(CMSG_HEADER_SIZE),
PKTINFO_DATA_SIZE,
)
};
WSABUF {
buf: control_buffer.as_mut_ptr(),
len: control_buffer.len() as _,
}
} else {
WSABUF {
buf: [].as_mut_ptr(),
len: 0,
}
};

let destination = socket2::SockAddr::from(self.multicast_address);
Expand Down Expand Up @@ -364,7 +428,7 @@ impl MulticastSocket {
}

pub fn broadcast(&self, buf: &[u8]) -> io::Result<()> {
for interface in &self.interfaces {
for interface in self.interfaces.values() {
self.send(buf, &Interface::Ip(*interface))?;
}
Ok(())
Expand Down

0 comments on commit 89065dd

Please # to comment.