diff --git a/src/net.rs b/src/net.rs index f3d1930..95ee93a 100644 --- a/src/net.rs +++ b/src/net.rs @@ -13,6 +13,8 @@ use std::os::windows::prelude::*; use net2::TcpBuilder; use winapi::*; +use winapi::shared::inaddr::{in_addr_S_un, IN_ADDR}; +use winapi::shared::in6addr::{in6_addr_u, IN6_ADDR}; use ws2_32::*; /// A type to represent a buffer in which a socket address will be stored. @@ -478,13 +480,63 @@ fn cvt(i: c_int, size: DWORD) -> io::Result> { } } -fn socket_addr_to_ptrs(addr: &SocketAddr) -> (*const SOCKADDR, c_int) { +/// A type with the same memory layout as `SOCKADDR`. Used in converting Rust level +/// SocketAddr* types into their system representation. The benefit of this specific +/// type over using `SOCKADDR_STORAGE` is that this type is exactly as large as it +/// needs to be and not a lot larger. And it can be initialized cleaner from Rust. +#[repr(C)] +pub(crate) union SocketAddrCRepr { + v4: SOCKADDR_IN, + v6: SOCKADDR_IN6_LH, +} + +impl SocketAddrCRepr { + pub(crate) fn as_ptr(&self) -> *const SOCKADDR { + self as *const _ as *const SOCKADDR + } +} + +fn socket_addr_to_ptrs(addr: &SocketAddr) -> (SocketAddrCRepr, c_int) { match *addr { SocketAddr::V4(ref a) => { - (a as *const _ as *const _, mem::size_of::() as c_int) + let sin_addr = unsafe { + let mut s_un = mem::zeroed::(); + *s_un.S_addr_mut() = u32::from_ne_bytes(a.ip().octets()); + IN_ADDR { S_un: s_un } + }; + + let sockaddr_in = SOCKADDR_IN { + sin_family: AF_INET as ADDRESS_FAMILY, + sin_port: a.port().to_be(), + sin_addr, + sin_zero: [0; 8], + }; + + let sockaddr = SocketAddrCRepr { v4: sockaddr_in }; + (sockaddr, mem::size_of::() as c_int) } SocketAddr::V6(ref a) => { - (a as *const _ as *const _, mem::size_of::() as c_int) + let sin6_addr = unsafe { + let mut u = mem::zeroed::(); + *u.Byte_mut() = a.ip().octets(); + IN6_ADDR { u } + }; + let u = unsafe { + let mut u = mem::zeroed::(); + *u.sin6_scope_id_mut() = a.scope_id(); + u + }; + + let sockaddr_in6 = SOCKADDR_IN6_LH { + sin6_family: AF_INET6 as ADDRESS_FAMILY, + sin6_port: a.port().to_be(), + sin6_addr, + sin6_flowinfo: a.flowinfo(), + u, + }; + + let sockaddr = SocketAddrCRepr { v6: sockaddr_in6 }; + (sockaddr, mem::size_of::() as c_int) } } } @@ -643,7 +695,7 @@ unsafe fn connect_overlapped(socket: SOCKET, let (addr_buf, addr_len) = socket_addr_to_ptrs(addr); let mut bytes_sent: DWORD = 0; - let r = connect_ex(socket, addr_buf, addr_len, + let r = connect_ex(socket, addr_buf.as_ptr(), addr_len, buf.as_ptr() as *mut _, buf.len() as u32, &mut bytes_sent, overlapped); @@ -694,7 +746,7 @@ impl UdpSocketExt for UdpSocket { let mut sent_bytes = 0; let r = WSASendTo(self.as_raw_socket(), &mut buf, 1, &mut sent_bytes, 0, - addr_buf as *const _, addr_len, + addr_buf.as_ptr() as *const _, addr_len, overlapped, None); cvt(r, sent_bytes) }