From 6fcb0250b96437123b6ae20d9f20237b6dd6a589 Mon Sep 17 00:00:00 2001 From: Val Date: Sat, 1 Feb 2025 12:41:47 +0100 Subject: [PATCH 01/11] process localhost in parse_address --- lightbug_http/address.mojo | 557 ++++++++++++++++ lightbug_http/connection.mojo | 394 +++++++++++ lightbug_http/io/bytes.mojo | 2 +- lightbug_http/net.mojo | 830 ------------------------ lightbug_http/server.mojo | 2 +- lightbug_http/socket.mojo | 5 +- tests/lightbug_http/test_host_port.mojo | 2 +- 7 files changed, 956 insertions(+), 836 deletions(-) create mode 100644 lightbug_http/address.mojo create mode 100644 lightbug_http/connection.mojo delete mode 100644 lightbug_http/net.mojo diff --git a/lightbug_http/address.mojo b/lightbug_http/address.mojo new file mode 100644 index 00000000..000d3c93 --- /dev/null +++ b/lightbug_http/address.mojo @@ -0,0 +1,557 @@ +from memory import UnsafePointer +from sys.ffi import external_call, OpaquePointer +from lightbug_http.strings import NetworkType, to_string +from lightbug_http._libc import ( + c_int, + c_char, + in_addr, + sockaddr, + sockaddr_in, + socklen_t, + AF_INET, + AF_INET6, + AF_UNSPEC, + SOCK_STREAM, + ntohs, + inet_ntop, + socket, + gai_strerror, + INET_ADDRSTRLEN, + INET6_ADDRSTRLEN, +) +from lightbug_http._logger import logger +from lightbug_http.socket import Socket + +alias MAX_PORT = 65535 +alias MIN_PORT = 0 +alias DEFAULT_IP_PORT = UInt16(0) + +struct AddressConstants: + """Constants used in address parsing.""" + alias LOCALHOST = "localhost" + alias IPV4_LOCALHOST = "127.0.0.1" + alias IPV6_LOCALHOST = "::1" + alias EMPTY = "" + +trait Addr(Stringable, Representable, Writable, EqualityComparableCollectionElement): + alias _type: StringLiteral + + fn __init__(out self): + ... + + fn __init__(out self, ip: String, port: UInt16): + ... + + @always_inline + fn address_family(self) -> Int: + ... + + @always_inline + fn is_v4(self) -> Bool: + ... + + @always_inline + fn is_v6(self) -> Bool: + ... + + @always_inline + fn is_unix(self) -> Bool: + ... + + +trait AnAddrInfo: + fn get_ip_address(self, host: String) raises -> in_addr: + """TODO: Once default functions can be implemented in traits, this should use the functions currently + implemented in the `addrinfo_macos` and `addrinfo_unix` structs. + """ + ... + +@value +struct TCPAddr[network: NetworkType = NetworkType.tcp4](Addr): + alias _type = "TCPAddr" + var ip: String + var port: UInt16 + var zone: String # IPv6 addressing zone + + fn __init__(out self): + self.ip = "127.0.0.1" + self.port = 8000 + self.zone = "" + + fn __init__(out self, ip: String = "127.0.0.1", port: UInt16 = 8000): + self.ip = ip + self.port = port + self.zone = "" + + fn __init__(out self, network: NetworkType, ip: String, port: UInt16, zone: String = ""): + self.ip = ip + self.port = port + self.zone = zone + + @always_inline + fn address_family(self) -> Int: + if network == NetworkType.tcp4: + return AF_INET + elif network == NetworkType.tcp6: + return AF_INET6 + else: + return AF_UNSPEC + + @always_inline + fn is_v4(self) -> Bool: + return network == NetworkType.tcp4 + + @always_inline + fn is_v6(self) -> Bool: + return network == NetworkType.tcp6 + + @always_inline + fn is_unix(self) -> Bool: + return False + + fn __eq__(self, other: Self) -> Bool: + return self.ip == other.ip and self.port == other.port and self.zone == other.zone + + fn __ne__(self, other: Self) -> Bool: + return not self == other + + fn __str__(self) -> String: + if self.zone != "": + return join_host_port(self.ip + "%" + self.zone, str(self.port)) + return join_host_port(self.ip, str(self.port)) + + fn __repr__(self) -> String: + return String.write(self) + + fn write_to[W: Writer, //](self, mut writer: W): + writer.write("TCPAddr(", "ip=", repr(self.ip), ", port=", str(self.port), ", zone=", repr(self.zone), ")") + + +@value +struct UDPAddr[network: NetworkType = NetworkType.udp4](Addr): + alias _type = "UDPAddr" + var ip: String + var port: UInt16 + var zone: String # IPv6 addressing zone + + fn __init__(out self): + self.ip = "127.0.0.1" + self.port = 8000 + self.zone = "" + + fn __init__(out self, ip: String = "127.0.0.1", port: UInt16 = 8000): + self.ip = ip + self.port = port + self.zone = "" + + fn __init__(out self, network: NetworkType, ip: String, port: UInt16): + self.ip = ip + self.port = port + self.zone = "" + + @always_inline + fn address_family(self) -> Int: + if network == NetworkType.udp4: + return AF_INET + elif network == NetworkType.udp6: + return AF_INET6 + else: + return AF_UNSPEC + + @always_inline + fn is_v4(self) -> Bool: + return network == NetworkType.udp4 + + @always_inline + fn is_v6(self) -> Bool: + return network == NetworkType.udp6 + + @always_inline + fn is_unix(self) -> Bool: + return False + + fn __eq__(self, other: Self) -> Bool: + return self.ip == other.ip and self.port == other.port and self.zone == other.zone + + fn __ne__(self, other: Self) -> Bool: + return not self == other + + fn __str__(self) -> String: + if self.zone != "": + return join_host_port(self.ip + "%" + self.zone, str(self.port)) + return join_host_port(self.ip, str(self.port)) + + fn __repr__(self) -> String: + return String.write(self) + + fn write_to[W: Writer, //](self, mut writer: W): + writer.write("UDPAddr(", "ip=", repr(self.ip), ", port=", str(self.port), ", zone=", repr(self.zone), ")") + +@value +@register_passable("trivial") +struct addrinfo_macos(AnAddrInfo): + """ + For MacOS, I had to swap the order of ai_canonname and ai_addr. + https://stackoverflow.com/questions/53575101/calling-getaddrinfo-directly-from-python-ai-addr-is-null-pointer. + """ + + var ai_flags: c_int + var ai_family: c_int + var ai_socktype: c_int + var ai_protocol: c_int + var ai_addrlen: socklen_t + var ai_canonname: UnsafePointer[c_char] + var ai_addr: UnsafePointer[sockaddr] + var ai_next: OpaquePointer + + fn __init__( + out self, + ai_flags: c_int = 0, + ai_family: c_int = 0, + ai_socktype: c_int = 0, + ai_protocol: c_int = 0, + ai_addrlen: socklen_t = 0, + ): + self.ai_flags = ai_flags + self.ai_family = ai_family + self.ai_socktype = ai_socktype + self.ai_protocol = ai_protocol + self.ai_addrlen = ai_addrlen + self.ai_canonname = UnsafePointer[c_char]() + self.ai_addr = UnsafePointer[sockaddr]() + self.ai_next = OpaquePointer() + + fn get_ip_address(self, host: String) raises -> in_addr: + """Returns an IP address based on the host. + This is a MacOS-specific implementation. + + Args: + host: String - The host to get the IP from. + + Returns: + The IP address. + """ + var result = UnsafePointer[Self]() + var hints = Self(ai_flags=0, ai_family=AF_INET, ai_socktype=SOCK_STREAM, ai_protocol=0) + try: + getaddrinfo(host, String(), hints, result) + except e: + logger.error("Failed to get IP address.") + raise e + + if not result[].ai_addr: + freeaddrinfo(result) + raise Error("Failed to get IP address because the response's `ai_addr` was null.") + + var ip = result[].ai_addr.bitcast[sockaddr_in]()[].sin_addr + freeaddrinfo(result) + return ip + + +@value +@register_passable("trivial") +struct addrinfo_unix(AnAddrInfo): + """Standard addrinfo struct for Unix systems. + Overwrites the existing libc `getaddrinfo` function to adhere to the AnAddrInfo trait. + """ + + var ai_flags: c_int + var ai_family: c_int + var ai_socktype: c_int + var ai_protocol: c_int + var ai_addrlen: socklen_t + var ai_addr: UnsafePointer[sockaddr] + var ai_canonname: UnsafePointer[c_char] + var ai_next: OpaquePointer + + fn __init__( + out self, + ai_flags: c_int = 0, + ai_family: c_int = 0, + ai_socktype: c_int = 0, + ai_protocol: c_int = 0, + ai_addrlen: socklen_t = 0, + ): + self.ai_flags = ai_flags + self.ai_family = ai_family + self.ai_socktype = ai_socktype + self.ai_protocol = ai_protocol + self.ai_addrlen = ai_addrlen + self.ai_addr = UnsafePointer[sockaddr]() + self.ai_canonname = UnsafePointer[c_char]() + self.ai_next = OpaquePointer() + + fn get_ip_address(self, host: String) raises -> in_addr: + """Returns an IP address based on the host. + This is a Unix-specific implementation. + + Args: + host: String - The host to get IP from. + + Returns: + The IP address. + """ + var result = UnsafePointer[Self]() + var hints = Self(ai_flags=0, ai_family=AF_INET, ai_socktype=SOCK_STREAM, ai_protocol=0) + try: + getaddrinfo(host, String(), hints, result) + except e: + logger.error("Failed to get IP address.") + raise e + + if not result[].ai_addr: + freeaddrinfo(result) + raise Error("Failed to get IP address because the response's `ai_addr` was null.") + + var ip = result[].ai_addr.bitcast[sockaddr_in]()[].sin_addr + freeaddrinfo(result) + return ip + +fn is_ip_protocol(network: NetworkType) -> Bool: + """Check if the network type is an IP protocol.""" + return network in (NetworkType.ip, NetworkType.ip4, NetworkType.ip6) + +fn is_ipv4(network: NetworkType) -> Bool: + """Check if the network type is IPv4.""" + return network in (NetworkType.tcp4, NetworkType.udp4) + +fn is_ipv6(network: NetworkType) -> Bool: + """Check if the network type is IPv6.""" + return network in (NetworkType.tcp6, NetworkType.udp6) + +fn resolve_localhost(host: String, network: NetworkType) -> String: + """Resolve localhost to the appropriate IP address based on network type.""" + if host != AddressConstants.LOCALHOST: + return host + + if is_ipv4(network): + return AddressConstants.IPV4_LOCALHOST + elif is_ipv6(network): + return AddressConstants.IPV6_LOCALHOST + return host + +fn parse_ipv6_bracketed_address(address: String) raises -> (String, Int): + """Parse an IPv6 address enclosed in brackets. + + Returns: + Tuple of (host, colon_index_offset) + """ + if address[0] != "[": + return address, 0 + + var end_bracket_index = address.find("]") + if end_bracket_index == -1: + raise Error("missing ']' in address") + + if end_bracket_index + 1 == len(address): + raise MissingPortError + + var colon_index = end_bracket_index + 1 + if address[colon_index] != ":": + raise MissingPortError + + return ( + address[1:end_bracket_index], + end_bracket_index + 1 + ) + +fn validate_no_brackets(address: String, start_idx: Int, end_idx: Int = -1) raises: + """Validate that the address segment contains no brackets.""" + if address[start_idx:end_idx].find("[") != -1: + raise Error("unexpected '[' in address") + if address[start_idx:end_idx].find("]") != -1: + raise Error("unexpected ']' in address") + +fn parse_port(port_str: String) raises -> UInt16: + """Parse and validate port number.""" + if port_str == AddressConstants.EMPTY: + raise MissingPortError + + var port = int(port_str) + if port < MIN_PORT or port > MAX_PORT: + raise Error("Port number out of range (0-65535)") + + return UInt16(port) + +fn parse_address(network: NetworkType, address: String) raises -> (String, UInt16): + """Parse an address string into a host and port. + + Args: + network: The network type (tcp, tcp4, tcp6, udp, udp4, udp6, ip, ip4, ip6, unix) + address: The address string + + Returns: + Tuple containing the host and port + """ + # Handle IP protocols separately + if is_ip_protocol(network): + if address.find(":") != -1: + raise Error("IP protocol addresses should not include ports") + + var host = resolve_localhost(address, network) + if host == AddressConstants.EMPTY: + raise Error("missing host") + return host, DEFAULT_IP_PORT + + # Parse regular addresses + var colon_index = address.rfind(":") + if colon_index == -1: + raise MissingPortError + + var host: String + var bracket_offset: Int + + # Handle IPv6 addresses + try: + (host, bracket_offset) = parse_ipv6_bracketed_address(address) + except e: + raise e + + # Validate no unexpected brackets + validate_no_brackets(address, bracket_offset) + + # Parse and validate port + var port = parse_port(address[colon_index + 1:]) + + # Resolve localhost if needed + host = resolve_localhost(host, network) + if host == AddressConstants.EMPTY: + raise Error("missing host") + + return host, port + + + +# TODO: Support IPv6 long form. +fn join_host_port(host: String, port: String) -> String: + if host.find(":") != -1: # must be IPv6 literal + return "[" + host + "]:" + port + return host + ":" + port + + +alias MissingPortError = Error("missing port in address") +alias TooManyColonsError = Error("too many colons in address") + +fn binary_port_to_int(port: UInt16) -> Int: + """Convert a binary port to an integer. + + Args: + port: The binary port. + + Returns: + The port as an integer. + """ + return int(ntohs(port)) + + +fn binary_ip_to_string[address_family: Int32](owned ip_address: UInt32) raises -> String: + """Convert a binary IP address to a string by calling `inet_ntop`. + + Parameters: + address_family: The address family of the IP address. + + Args: + ip_address: The binary IP address. + + Returns: + The IP address as a string. + """ + constrained[int(address_family) in [AF_INET, AF_INET6], "Address family must be either AF_INET or AF_INET6."]() + var ip: String + + @parameter + if address_family == AF_INET: + ip = inet_ntop[address_family, INET_ADDRSTRLEN](ip_address) + else: + ip = inet_ntop[address_family, INET6_ADDRSTRLEN](ip_address) + + return ip + + +fn _getaddrinfo[ + T: AnAddrInfo, hints_origin: MutableOrigin, result_origin: MutableOrigin, // +]( + nodename: UnsafePointer[c_char], + servname: UnsafePointer[c_char], + hints: Pointer[T, hints_origin], + res: Pointer[UnsafePointer[T], result_origin], +) -> c_int: + """Libc POSIX `getaddrinfo` function. + + Args: + nodename: The node name. + servname: The service name. + hints: A Pointer to the hints. + res: A UnsafePointer to the result. + + Returns: + 0 on success, an error code on failure. + + #### C Function + ```c + int getaddrinfo(const char *restrict nodename, const char *restrict servname, const struct addrinfo *restrict hints, struct addrinfo **restrict res) + ``` + + #### Notes: + * Reference: https://man7.org/linux/man-pages/man3/getaddrinfo.3p.html + """ + return external_call[ + "getaddrinfo", + c_int, # FnName, RetType + UnsafePointer[c_char], + UnsafePointer[c_char], + Pointer[T, hints_origin], # Args + Pointer[UnsafePointer[T], result_origin], # Args + ](nodename, servname, hints, res) + + +fn getaddrinfo[ + T: AnAddrInfo, // +](node: String, service: String, mut hints: T, mut res: UnsafePointer[T],) raises: + """Libc POSIX `getaddrinfo` function. + + Args: + node: The node name. + service: The service name. + hints: A Pointer to the hints. + res: A UnsafePointer to the result. + + Raises: + Error: If an error occurs while attempting to receive data from the socket. + EAI_AGAIN: The name could not be resolved at this time. Future attempts may succeed. + EAI_BADFLAGS: The `ai_flags` value was invalid. + EAI_FAIL: A non-recoverable error occurred when attempting to resolve the name. + EAI_FAMILY: The `ai_family` member of the `hints` argument is not supported. + EAI_MEMORY: Out of memory. + EAI_NONAME: The name does not resolve for the supplied parameters. + EAI_SERVICE: The `servname` is not supported for `ai_socktype`. + EAI_SOCKTYPE: The `ai_socktype` is not supported. + EAI_SYSTEM: A system error occurred. `errno` is set in this case. + + #### C Function + ```c + int getaddrinfo(const char *restrict nodename, const char *restrict servname, const struct addrinfo *restrict hints, struct addrinfo **restrict res) + ``` + + #### Notes: + * Reference: https://man7.org/linux/man-pages/man3/getaddrinfo.3p.html + """ + var result = _getaddrinfo( + node.unsafe_ptr(), service.unsafe_ptr(), Pointer.address_of(hints), Pointer.address_of(res) + ) + if result != 0: + # gai_strerror returns a char buffer that we don't know the length of. + # TODO: Perhaps switch to writing bytes once the Writer trait allows writing individual bytes. + var err = gai_strerror(result) + var msg = List[Byte, True]() + var i = 0 + while err[i] != 0: + msg.append(err[i]) + i += 1 + msg.append(0) + raise Error("getaddrinfo: " + String(msg^)) + + +fn freeaddrinfo[T: AnAddrInfo, //](ptr: UnsafePointer[T]): + """Free the memory allocated by `getaddrinfo`.""" + external_call["freeaddrinfo", NoneType, UnsafePointer[T]](ptr) diff --git a/lightbug_http/connection.mojo b/lightbug_http/connection.mojo new file mode 100644 index 00000000..dcf10eb9 --- /dev/null +++ b/lightbug_http/connection.mojo @@ -0,0 +1,394 @@ +from time import sleep +from memory import Span +from sys.info import os_is_macos +from lightbug_http.strings import NetworkType +from lightbug_http.io.bytes import Bytes, bytes +from lightbug_http.io.sync import Duration +from lightbug_http.address import parse_address, TCPAddr, UDPAddr +from lightbug_http._libc import ( + sockaddr, + AF_INET, + SOCK_DGRAM, + SO_REUSEADDR, + socket, + connect, + listen, + accept, + send, + bind, + shutdown, + close, +) +from lightbug_http._logger import logger +from lightbug_http.socket import Socket + + +alias default_buffer_size = 4096 +"""The default buffer size for reading and writing data.""" +alias default_tcp_keep_alive = Duration(15 * 1000 * 1000 * 1000) # 15 seconds +"""The default TCP keep-alive duration.""" + + +trait Connection(Movable): + fn read(self, mut buf: Bytes) raises -> Int: + ... + + fn write(self, buf: Span[Byte]) raises -> Int: + ... + + fn close(mut self) raises: + ... + + fn shutdown(mut self) raises -> None: + ... + + fn teardown(mut self) raises: + ... + + fn local_addr(self) -> TCPAddr: + ... + + fn remote_addr(self) -> TCPAddr: + ... + + +struct NoTLSListener: + """A TCP listener that listens for incoming connections and can accept them.""" + + var socket: Socket[TCPAddr] + + fn __init__(out self, owned socket: Socket[TCPAddr]): + self.socket = socket^ + + fn __init__(out self) raises: + self.socket = Socket[TCPAddr]() + + fn __moveinit__(out self, owned existing: Self): + self.socket = existing.socket^ + + fn accept(self) raises -> TCPConnection: + return TCPConnection(self.socket.accept()) + + fn close(mut self) raises -> None: + return self.socket.close() + + fn shutdown(mut self) raises -> None: + return self.socket.shutdown() + + fn teardown(mut self) raises: + self.socket.teardown() + + fn addr(self) -> TCPAddr: + return self.socket.local_address() + + +struct ListenConfig: + var _keep_alive: Duration + + fn __init__(out self, keep_alive: Duration = default_tcp_keep_alive): + self._keep_alive = keep_alive + + fn listen[network: NetworkType = NetworkType.tcp4](mut self, address: String) raises -> NoTLSListener: + var local = parse_address(network, address) + var addr = TCPAddr(local[0], local[1]) + var socket: Socket[TCPAddr] + try: + socket = Socket[TCPAddr]() + except e: + logger.error(e) + raise Error("ListenConfig.listen: Failed to create listener due to socket creation failure.") + + @parameter + # TODO: do we want to add SO_REUSEPORT on linux? Doesn't work on some systems + if os_is_macos(): + try: + socket.set_socket_option(SO_REUSEADDR, 1) + except e: + logger.warn("ListenConfig.listen: Failed to set socket as reusable", e) + + var bind_success = False + var bind_fail_logged = False + while not bind_success: + try: + socket.bind(addr.ip, addr.port) + bind_success = True + except e: + if not bind_fail_logged: + print("Bind attempt failed: ", e) + print("Retrying. Might take 10-15 seconds.") + bind_fail_logged = True + print(".", end="", flush=True) + + try: + socket.shutdown() + except e: + logger.error("ListenConfig.listen: Failed to shutdown socket:", e) + # TODO: Should shutdown failure be a hard failure? We can still ungracefully close the socket. + sleep(UInt(1)) + + try: + socket.listen(128) + except e: + logger.error(e) + raise Error("ListenConfig.listen: Listen failed on sockfd: " + str(socket.fd)) + + var listener = NoTLSListener(socket^) + var msg = String.write("\nšŸ”„šŸ Lightbug is listening on ", "http://", addr.ip, ":", str(addr.port)) + print(msg) + print("Ready to accept connections...") + + return listener^ + + +struct TCPConnection: + var socket: Socket[TCPAddr] + + fn __init__(out self, owned socket: Socket[TCPAddr]): + self.socket = socket^ + + fn __moveinit__(out self, owned existing: Self): + self.socket = existing.socket^ + + fn read(self, mut buf: Bytes) raises -> Int: + try: + return self.socket.receive(buf) + except e: + if str(e) == "EOF": + raise e + else: + logger.error(e) + raise Error("TCPConnection.read: Failed to read data from connection.") + + fn write(self, buf: Span[Byte]) raises -> Int: + if buf[-1] == 0: + raise Error("TCPConnection.write: Buffer must not be null-terminated.") + + try: + return self.socket.send(buf) + except e: + logger.error("TCPConnection.write: Failed to write data to connection.") + raise e + + fn close(mut self) raises: + self.socket.close() + + fn shutdown(mut self) raises: + self.socket.shutdown() + + fn teardown(mut self) raises: + self.socket.teardown() + + fn is_closed(self) -> Bool: + return self.socket._closed + + # TODO: Switch to property or return ref when trait supports attributes. + fn local_addr(self) -> TCPAddr: + return self.socket.local_address() + + fn remote_addr(self) -> TCPAddr: + return self.socket.remote_address() + + +struct UDPConnection: + var socket: Socket[UDPAddr] + + fn __init__(out self, owned socket: Socket[UDPAddr]): + self.socket = socket^ + + fn __moveinit__(out self, owned existing: Self): + self.socket = existing.socket^ + + fn read_from(mut self, size: Int = default_buffer_size) raises -> (Bytes, String, UInt16): + """Reads data from the underlying file descriptor. + + Args: + size: The size of the buffer to read data into. + + Returns: + The number of bytes read, or an error if one occurred. + + Raises: + Error: If an error occurred while reading data. + """ + return self.socket.receive_from(size) + + fn read_from(mut self, mut dest: Bytes) raises -> (UInt, String, UInt16): + """Reads data from the underlying file descriptor. + + Args: + dest: The buffer to read data into. + + Returns: + The number of bytes read, or an error if one occurred. + + Raises: + Error: If an error occurred while reading data. + """ + return self.socket.receive_from(dest) + + fn write_to(mut self, src: Span[Byte], address: UDPAddr) raises -> Int: + """Writes data to the underlying file descriptor. + + Args: + src: The buffer to read data into. + address: The remote peer address. + + Returns: + The number of bytes written, or an error if one occurred. + + Raises: + Error: If an error occurred while writing data. + """ + return self.socket.send_to(src, address.ip, address.port) + + fn write_to(mut self, src: Span[Byte], host: String, port: UInt16) raises -> Int: + """Writes data to the underlying file descriptor. + + Args: + src: The buffer to read data into. + host: The remote peer address in IPv4 format. + port: The remote peer port. + + Returns: + The number of bytes written, or an error if one occurred. + + Raises: + Error: If an error occurred while writing data. + """ + return self.socket.send_to(src, host, port) + + fn close(mut self) raises: + self.socket.close() + + fn shutdown(mut self) raises: + self.socket.shutdown() + + fn teardown(mut self) raises: + self.socket.teardown() + + fn is_closed(self) -> Bool: + return self.socket._closed + + fn local_addr(self) -> ref [self.socket._local_address] UDPAddr: + return self.socket.local_address() + + fn remote_addr(self) -> ref [self.socket._remote_address] UDPAddr: + return self.socket.remote_address() + +fn create_connection(host: String, port: UInt16) raises -> TCPConnection: + """Connect to a server using a socket. + + Args: + host: The host to connect to. + port: The port to connect on. + + Returns: + The socket file descriptor. + """ + var socket = Socket[TCPAddr]() + try: + socket.connect(host, port) + except e: + logger.error(e) + try: + socket.shutdown() + except e: + logger.error("Failed to shutdown socket: " + str(e)) + raise Error("Failed to establish a connection to the server.") + + return TCPConnection(socket^) + +fn listen_udp(local_address: UDPAddr) raises -> UDPConnection: + """Creates a new UDP listener. + + Args: + local_address: The local address to listen on. + + Returns: + A UDP connection. + + Raises: + Error: If the address is invalid or failed to bind the socket. + """ + var socket = Socket[UDPAddr](socket_type=SOCK_DGRAM) + socket.bind(local_address.ip, local_address.port) + return UDPConnection(socket^) + + +fn listen_udp(local_address: String) raises -> UDPConnection: + """Creates a new UDP listener. + + Args: + local_address: The address to listen on. The format is "host:port". + + Returns: + A UDP connection. + + Raises: + Error: If the address is invalid or failed to bind the socket. + """ + var address = parse_address(NetworkType.udp4, local_address) + return listen_udp(UDPAddr(address[0], address[1])) + + +fn listen_udp(host: String, port: UInt16) raises -> UDPConnection: + """Creates a new UDP listener. + + Args: + host: The address to listen on in ipv4 format. + port: The port number. + + Returns: + A UDP connection. + + Raises: + Error: If the address is invalid or failed to bind the socket. + """ + return listen_udp(UDPAddr(host, port)) + + +fn dial_udp(local_address: UDPAddr) raises -> UDPConnection: + """Connects to the address on the named network. The network must be "udp", "udp4", or "udp6". + + Args: + local_address: The local address. + + Returns: + The UDP connection. + + Raises: + Error: If the network type is not supported or failed to connect to the address. + """ + return UDPConnection(Socket[UDPAddr](local_address=local_address, socket_type=SOCK_DGRAM)) + + +fn dial_udp(network: NetworkType, local_address: String) raises -> UDPConnection: + """Connects to the address on the named network. The network must be "udp", "udp4", or "udp6". + + Args: + local_address: The local address. + + Returns: + The UDP connection. + + Raises: + Error: If the network type is not supported or failed to connect to the address. + """ + var address = parse_address(network, local_address) + return dial_udp(UDPAddr(network, address[0], address[1])) + + +fn dial_udp(host: String, port: UInt16) raises -> UDPConnection: + """Connects to the address on the udp network. + + Args: + host: The host to connect to. + port: The port to connect on. + + Returns: + The UDP connection. + + Raises: + Error: If failed to connect to the address. + """ + return dial_udp(UDPAddr(host, port)) diff --git a/lightbug_http/io/bytes.mojo b/lightbug_http/io/bytes.mojo index 007fbdf4..cbc5b4c3 100644 --- a/lightbug_http/io/bytes.mojo +++ b/lightbug_http/io/bytes.mojo @@ -1,7 +1,7 @@ from utils import StringSlice from memory.span import Span, _SpanIter from lightbug_http.strings import BytesConstant -from lightbug_http.net import default_buffer_size +from lightbug_http.connection import default_buffer_size alias Bytes = List[Byte, True] diff --git a/lightbug_http/net.mojo b/lightbug_http/net.mojo deleted file mode 100644 index 46e392ec..00000000 --- a/lightbug_http/net.mojo +++ /dev/null @@ -1,830 +0,0 @@ -from utils import StaticTuple -from time import sleep, perf_counter_ns -from memory import UnsafePointer, stack_allocation, Span -from sys.info import sizeof, os_is_macos -from sys.ffi import external_call, OpaquePointer -from lightbug_http.strings import NetworkType, to_string -from lightbug_http.io.bytes import Bytes, bytes -from lightbug_http.io.sync import Duration -from lightbug_http._libc import ( - c_void, - c_int, - c_uint, - c_char, - c_ssize_t, - in_addr, - sockaddr, - sockaddr_in, - socklen_t, - AI_PASSIVE, - AF_INET, - AF_INET6, - SOCK_STREAM, - SOCK_DGRAM, - SOL_SOCKET, - SO_REUSEADDR, - SO_REUSEPORT, - SHUT_RDWR, - htons, - ntohs, - ntohl, - inet_pton, - inet_ntop, - socket, - connect, - setsockopt, - listen, - accept, - send, - recv, - bind, - shutdown, - close, - getsockname, - getpeername, - gai_strerror, - INET_ADDRSTRLEN, - INET6_ADDRSTRLEN, -) -from lightbug_http._logger import logger -from lightbug_http.socket import Socket - - -alias default_buffer_size = 4096 -"""The default buffer size for reading and writing data.""" -alias default_tcp_keep_alive = Duration(15 * 1000 * 1000 * 1000) # 15 seconds -"""The default TCP keep-alive duration.""" - - -trait Connection(Movable): - fn read(self, mut buf: Bytes) raises -> Int: - ... - - fn write(self, buf: Span[Byte]) raises -> Int: - ... - - fn close(mut self) raises: - ... - - fn shutdown(mut self) raises -> None: - ... - - fn teardown(mut self) raises: - ... - - fn local_addr(self) -> TCPAddr: - ... - - fn remote_addr(self) -> TCPAddr: - ... - - -trait Addr(Stringable, Representable, Writable, EqualityComparableCollectionElement): - alias _type: StringLiteral - - fn __init__(out self): - ... - - fn __init__(out self, ip: String, port: UInt16): - ... - - fn network(self) -> String: - ... - - -trait AnAddrInfo: - fn get_ip_address(self, host: String) raises -> in_addr: - """TODO: Once default functions can be implemented in traits, this function should use the functions currently - implemented in the `addrinfo_macos` and `addrinfo_unix` structs. - """ - ... - - -struct NoTLSListener: - """A TCP listener that listens for incoming connections and can accept them.""" - - var socket: Socket[TCPAddr] - - fn __init__(out self, owned socket: Socket[TCPAddr]): - self.socket = socket^ - - fn __init__(out self) raises: - self.socket = Socket[TCPAddr]() - - fn __moveinit__(out self, owned existing: Self): - self.socket = existing.socket^ - - fn accept(self) raises -> TCPConnection: - return TCPConnection(self.socket.accept()) - - fn close(mut self) raises -> None: - return self.socket.close() - - fn shutdown(mut self) raises -> None: - return self.socket.shutdown() - - fn teardown(mut self) raises: - self.socket.teardown() - - fn addr(self) -> TCPAddr: - return self.socket.local_address() - - -struct ListenConfig: - var _keep_alive: Duration - - fn __init__(out self, keep_alive: Duration = default_tcp_keep_alive): - self._keep_alive = keep_alive - - fn listen[address_family: Int = AF_INET](mut self, address: String) raises -> NoTLSListener: - constrained[address_family in [AF_INET, AF_INET6], "Address family must be either AF_INET or AF_INET6."]() - var local = parse_address(address) - var addr = TCPAddr(local[0], local[1]) - var socket: Socket[TCPAddr] - try: - socket = Socket[TCPAddr]() - except e: - logger.error(e) - raise Error("ListenConfig.listen: Failed to create listener due to socket creation failure.") - - @parameter - # TODO: do we want to reuse port on linux? currently doesn't work - if os_is_macos(): - try: - socket.set_socket_option(SO_REUSEADDR, 1) - except e: - logger.warn("ListenConfig.listen: Failed to set socket as reusable", e) - - var bind_success = False - var bind_fail_logged = False - while not bind_success: - try: - socket.bind(addr.ip, addr.port) - bind_success = True - except e: - if not bind_fail_logged: - print("Bind attempt failed: ", e) - print("Retrying. Might take 10-15 seconds.") - bind_fail_logged = True - print(".", end="", flush=True) - - try: - socket.shutdown() - except e: - logger.error("ListenConfig.listen: Failed to shutdown socket:", e) - # TODO: Should shutdown failure be a hard failure? We can still ungracefully close the socket. - sleep(UInt(1)) - - try: - socket.listen(128) - except e: - logger.error(e) - raise Error("ListenConfig.listen: Listen failed on sockfd: " + str(socket.fd)) - - var listener = NoTLSListener(socket^) - var msg = String.write("\nšŸ”„šŸ Lightbug is listening on ", "http://", addr.ip, ":", str(addr.port)) - print(msg) - print("Ready to accept connections...") - - return listener^ - - -struct TCPConnection: - var socket: Socket[TCPAddr] - - fn __init__(out self, owned socket: Socket[TCPAddr]): - self.socket = socket^ - - fn __moveinit__(out self, owned existing: Self): - self.socket = existing.socket^ - - fn read(self, mut buf: Bytes) raises -> Int: - try: - return self.socket.receive(buf) - except e: - if str(e) == "EOF": - raise e - else: - logger.error(e) - raise Error("TCPConnection.read: Failed to read data from connection.") - - fn write(self, buf: Span[Byte]) raises -> Int: - if buf[-1] == 0: - raise Error("TCPConnection.write: Buffer must not be null-terminated.") - - try: - return self.socket.send(buf) - except e: - logger.error("TCPConnection.write: Failed to write data to connection.") - raise e - - fn close(mut self) raises: - self.socket.close() - - fn shutdown(mut self) raises: - self.socket.shutdown() - - fn teardown(mut self) raises: - self.socket.teardown() - - fn is_closed(self) -> Bool: - return self.socket._closed - - # TODO: Switch to property or return ref when trait supports attributes. - fn local_addr(self) -> TCPAddr: - return self.socket.local_address() - - fn remote_addr(self) -> TCPAddr: - return self.socket.remote_address() - - -struct UDPConnection: - var socket: Socket[UDPAddr] - - fn __init__(out self, owned socket: Socket[UDPAddr]): - self.socket = socket^ - - fn __moveinit__(out self, owned existing: Self): - self.socket = existing.socket^ - - fn read_from(mut self, size: Int = default_buffer_size) raises -> (Bytes, String, UInt16): - """Reads data from the underlying file descriptor. - - Args: - size: The size of the buffer to read data into. - - Returns: - The number of bytes read, or an error if one occurred. - - Raises: - Error: If an error occurred while reading data. - """ - return self.socket.receive_from(size) - - fn read_from(mut self, mut dest: Bytes) raises -> (UInt, String, UInt16): - """Reads data from the underlying file descriptor. - - Args: - dest: The buffer to read data into. - - Returns: - The number of bytes read, or an error if one occurred. - - Raises: - Error: If an error occurred while reading data. - """ - return self.socket.receive_from(dest) - - fn write_to(mut self, src: Span[Byte], address: UDPAddr) raises -> Int: - """Writes data to the underlying file descriptor. - - Args: - src: The buffer to read data into. - address: The remote peer address. - - Returns: - The number of bytes written, or an error if one occurred. - - Raises: - Error: If an error occurred while writing data. - """ - return self.socket.send_to(src, address.ip, address.port) - - fn write_to(mut self, src: Span[Byte], host: String, port: UInt16) raises -> Int: - """Writes data to the underlying file descriptor. - - Args: - src: The buffer to read data into. - host: The remote peer address in IPv4 format. - port: The remote peer port. - - Returns: - The number of bytes written, or an error if one occurred. - - Raises: - Error: If an error occurred while writing data. - """ - return self.socket.send_to(src, host, port) - - fn close(mut self) raises: - self.socket.close() - - fn shutdown(mut self) raises: - self.socket.shutdown() - - fn teardown(mut self) raises: - self.socket.teardown() - - fn is_closed(self) -> Bool: - return self.socket._closed - - fn local_addr(self) -> ref [self.socket._local_address] UDPAddr: - return self.socket.local_address() - - fn remote_addr(self) -> ref [self.socket._remote_address] UDPAddr: - return self.socket.remote_address() - - -@value -@register_passable("trivial") -struct addrinfo_macos(AnAddrInfo): - """ - For MacOS, I had to swap the order of ai_canonname and ai_addr. - https://stackoverflow.com/questions/53575101/calling-getaddrinfo-directly-from-python-ai-addr-is-null-pointer. - """ - - var ai_flags: c_int - var ai_family: c_int - var ai_socktype: c_int - var ai_protocol: c_int - var ai_addrlen: socklen_t - var ai_canonname: UnsafePointer[c_char] - var ai_addr: UnsafePointer[sockaddr] - var ai_next: OpaquePointer - - fn __init__( - out self, - ai_flags: c_int = 0, - ai_family: c_int = 0, - ai_socktype: c_int = 0, - ai_protocol: c_int = 0, - ai_addrlen: socklen_t = 0, - ): - self.ai_flags = ai_flags - self.ai_family = ai_family - self.ai_socktype = ai_socktype - self.ai_protocol = ai_protocol - self.ai_addrlen = ai_addrlen - self.ai_canonname = UnsafePointer[c_char]() - self.ai_addr = UnsafePointer[sockaddr]() - self.ai_next = OpaquePointer() - - fn get_ip_address(self, host: String) raises -> in_addr: - """Returns an IP address based on the host. - This is a MacOS-specific implementation. - - Args: - host: String - The host to get the IP from. - - Returns: - The IP address. - """ - var result = UnsafePointer[Self]() - var hints = Self(ai_flags=0, ai_family=AF_INET, ai_socktype=SOCK_STREAM, ai_protocol=0) - try: - getaddrinfo(host, String(), hints, result) - except e: - logger.error("Failed to get IP address.") - raise e - - if not result[].ai_addr: - freeaddrinfo(result) - raise Error("Failed to get IP address because the response's `ai_addr` was null.") - - var ip = result[].ai_addr.bitcast[sockaddr_in]()[].sin_addr - freeaddrinfo(result) - return ip - - -@value -@register_passable("trivial") -struct addrinfo_unix(AnAddrInfo): - """Standard addrinfo struct for Unix systems. - Overwrites the existing libc `getaddrinfo` function to adhere to the AnAddrInfo trait. - """ - - var ai_flags: c_int - var ai_family: c_int - var ai_socktype: c_int - var ai_protocol: c_int - var ai_addrlen: socklen_t - var ai_addr: UnsafePointer[sockaddr] - var ai_canonname: UnsafePointer[c_char] - var ai_next: OpaquePointer - - fn __init__( - out self, - ai_flags: c_int = 0, - ai_family: c_int = 0, - ai_socktype: c_int = 0, - ai_protocol: c_int = 0, - ai_addrlen: socklen_t = 0, - ): - self.ai_flags = ai_flags - self.ai_family = ai_family - self.ai_socktype = ai_socktype - self.ai_protocol = ai_protocol - self.ai_addrlen = ai_addrlen - self.ai_addr = UnsafePointer[sockaddr]() - self.ai_canonname = UnsafePointer[c_char]() - self.ai_next = OpaquePointer() - - fn get_ip_address(self, host: String) raises -> in_addr: - """Returns an IP address based on the host. - This is a Unix-specific implementation. - - Args: - host: String - The host to get IP from. - - Returns: - The IP address. - """ - var result = UnsafePointer[Self]() - var hints = Self(ai_flags=0, ai_family=AF_INET, ai_socktype=SOCK_STREAM, ai_protocol=0) - try: - getaddrinfo(host, String(), hints, result) - except e: - logger.error("Failed to get IP address.") - raise e - - if not result[].ai_addr: - freeaddrinfo(result) - raise Error("Failed to get IP address because the response's `ai_addr` was null.") - - var ip = result[].ai_addr.bitcast[sockaddr_in]()[].sin_addr - freeaddrinfo(result) - return ip - - -fn create_connection(host: String, port: UInt16) raises -> TCPConnection: - """Connect to a server using a socket. - - Args: - host: The host to connect to. - port: The port to connect on. - - Returns: - The socket file descriptor. - """ - var socket = Socket[TCPAddr]() - try: - socket.connect(host, port) - except e: - logger.error(e) - try: - socket.shutdown() - except e: - logger.error("Failed to shutdown socket: " + str(e)) - raise Error("Failed to establish a connection to the server.") - - return TCPConnection(socket^) - - -@value -struct TCPAddr(Addr): - alias _type = "TCPAddr" - var ip: String - var port: UInt16 - var zone: String # IPv6 addressing zone - - fn __init__(out self): - self.ip = "127.0.0.1" - self.port = 8000 - self.zone = "" - - fn __init__(out self, ip: String = "127.0.0.1", port: UInt16 = 8000): - self.ip = ip - self.port = port - self.zone = "" - - fn network(self) -> String: - return NetworkType.tcp.value - - fn __eq__(self, other: Self) -> Bool: - return self.ip == other.ip and self.port == other.port and self.zone == other.zone - - fn __ne__(self, other: Self) -> Bool: - return not self == other - - fn __str__(self) -> String: - if self.zone != "": - return join_host_port(self.ip + "%" + self.zone, str(self.port)) - return join_host_port(self.ip, str(self.port)) - - fn __repr__(self) -> String: - return String.write(self) - - fn write_to[W: Writer, //](self, mut writer: W): - writer.write("TCPAddr(", "ip=", repr(self.ip), ", port=", str(self.port), ", zone=", repr(self.zone), ")") - - -@value -struct UDPAddr(Addr): - alias _type = "UDPAddr" - var ip: String - var port: UInt16 - var zone: String # IPv6 addressing zone - - fn __init__(out self): - self.ip = "127.0.0.1" - self.port = 8000 - self.zone = "" - - fn __init__(out self, ip: String = "127.0.0.1", port: UInt16 = 8000): - self.ip = ip - self.port = port - self.zone = "" - - fn network(self) -> String: - return NetworkType.udp.value - - fn __eq__(self, other: Self) -> Bool: - return self.ip == other.ip and self.port == other.port and self.zone == other.zone - - fn __ne__(self, other: Self) -> Bool: - return not self == other - - fn __str__(self) -> String: - if self.zone != "": - return join_host_port(self.ip + "%" + self.zone, str(self.port)) - return join_host_port(self.ip, str(self.port)) - - fn __repr__(self) -> String: - return String.write(self) - - fn write_to[W: Writer, //](self, mut writer: W): - writer.write("UDPAddr(", "ip=", repr(self.ip), ", port=", str(self.port), ", zone=", repr(self.zone), ")") - - -fn listen_udp(local_address: UDPAddr) raises -> UDPConnection: - """Creates a new UDP listener. - - Args: - local_address: The local address to listen on. - - Returns: - A UDP connection. - - Raises: - Error: If the address is invalid or failed to bind the socket. - """ - var socket = Socket[UDPAddr](socket_type=SOCK_DGRAM) - socket.bind(local_address.ip, local_address.port) - return UDPConnection(socket^) - - -fn listen_udp(local_address: String) raises -> UDPConnection: - """Creates a new UDP listener. - - Args: - local_address: The address to listen on. The format is "host:port". - - Returns: - A UDP connection. - - Raises: - Error: If the address is invalid or failed to bind the socket. - """ - var address = parse_address(local_address) - return listen_udp(UDPAddr(address[0], address[1])) - - -fn listen_udp(host: String, port: UInt16) raises -> UDPConnection: - """Creates a new UDP listener. - - Args: - host: The address to listen on in ipv4 format. - port: The port number. - - Returns: - A UDP connection. - - Raises: - Error: If the address is invalid or failed to bind the socket. - """ - return listen_udp(UDPAddr(host, port)) - - -fn dial_udp(local_address: UDPAddr) raises -> UDPConnection: - """Connects to the address on the named network. The network must be "udp", "udp4", or "udp6". - - Args: - local_address: The local address. - - Returns: - The UDP connection. - - Raises: - Error: If the network type is not supported or failed to connect to the address. - """ - return UDPConnection(Socket(local_address=local_address, socket_type=SOCK_DGRAM)) - - -fn dial_udp(local_address: String) raises -> UDPConnection: - """Connects to the address on the named network. The network must be "udp", "udp4", or "udp6". - - Args: - local_address: The local address. - - Returns: - The UDP connection. - - Raises: - Error: If the network type is not supported or failed to connect to the address. - """ - var address = parse_address(local_address) - return dial_udp(UDPAddr(address[0], address[1])) - - -fn dial_udp(host: String, port: UInt16) raises -> UDPConnection: - """Connects to the address on the named network. The network must be "udp", "udp4", or "udp6". - - Args: - host: The host to connect to. - port: The port to connect on. - - Returns: - The UDP connection. - - Raises: - Error: If the network type is not supported or failed to connect to the address. - """ - return dial_udp(UDPAddr(host, port)) - - -# TODO: Support IPv6 long form. -fn join_host_port(host: String, port: String) -> String: - if host.find(":") != -1: # must be IPv6 literal - return "[" + host + "]:" + port - return host + ":" + port - - -alias MissingPortError = Error("missing port in address") -alias TooManyColonsError = Error("too many colons in address") - - -fn parse_address(address: String) raises -> (String, UInt16): - """Parse an address string into a host and port. - - Args: - address: The address string. - - Returns: - A tuple containing the host and port. - """ - var colon_index = address.rfind(":") - if colon_index == -1: - raise MissingPortError - - var host: String = "" - var port: String = "" - var j: Int = 0 - var k: Int = 0 - - if address[0] == "[": - var end_bracket_index = address.find("]") - if end_bracket_index == -1: - raise Error("missing ']' in address") - - if end_bracket_index + 1 == len(address): - raise MissingPortError - elif end_bracket_index + 1 == colon_index: - host = address[1:end_bracket_index] - j = 1 - k = end_bracket_index + 1 - else: - if address[end_bracket_index + 1] == ":": - raise TooManyColonsError - else: - raise MissingPortError - else: - host = address[:colon_index] - if host.find(":") != -1: - raise TooManyColonsError - - if address[j:].find("[") != -1: - raise Error("unexpected '[' in address") - if address[k:].find("]") != -1: - raise Error("unexpected ']' in address") - - port = address[colon_index + 1 :] - if port == "": - raise MissingPortError - if host == "": - raise Error("missing host") - return host, UInt16(int(port)) - - -fn binary_port_to_int(port: UInt16) -> Int: - """Convert a binary port to an integer. - - Args: - port: The binary port. - - Returns: - The port as an integer. - """ - return int(ntohs(port)) - - -fn binary_ip_to_string[address_family: Int32](owned ip_address: UInt32) raises -> String: - """Convert a binary IP address to a string by calling `inet_ntop`. - - Parameters: - address_family: The address family of the IP address. - - Args: - ip_address: The binary IP address. - - Returns: - The IP address as a string. - """ - constrained[int(address_family) in [AF_INET, AF_INET6], "Address family must be either AF_INET or AF_INET6."]() - var ip: String - - @parameter - if address_family == AF_INET: - ip = inet_ntop[address_family, INET_ADDRSTRLEN](ip_address) - else: - ip = inet_ntop[address_family, INET6_ADDRSTRLEN](ip_address) - - return ip - - -fn _getaddrinfo[ - T: AnAddrInfo, hints_origin: MutableOrigin, result_origin: MutableOrigin, // -]( - nodename: UnsafePointer[c_char], - servname: UnsafePointer[c_char], - hints: Pointer[T, hints_origin], - res: Pointer[UnsafePointer[T], result_origin], -) -> c_int: - """Libc POSIX `getaddrinfo` function. - - Args: - nodename: The node name. - servname: The service name. - hints: A Pointer to the hints. - res: A UnsafePointer to the result. - - Returns: - 0 on success, an error code on failure. - - #### C Function - ```c - int getaddrinfo(const char *restrict nodename, const char *restrict servname, const struct addrinfo *restrict hints, struct addrinfo **restrict res) - ``` - - #### Notes: - * Reference: https://man7.org/linux/man-pages/man3/getaddrinfo.3p.html - """ - return external_call[ - "getaddrinfo", - c_int, # FnName, RetType - UnsafePointer[c_char], - UnsafePointer[c_char], - Pointer[T, hints_origin], # Args - Pointer[UnsafePointer[T], result_origin], # Args - ](nodename, servname, hints, res) - - -fn getaddrinfo[ - T: AnAddrInfo, // -](node: String, service: String, mut hints: T, mut res: UnsafePointer[T],) raises: - """Libc POSIX `getaddrinfo` function. - - Args: - node: The node name. - service: The service name. - hints: A Pointer to the hints. - res: A UnsafePointer to the result. - - Raises: - Error: If an error occurs while attempting to receive data from the socket. - EAI_AGAIN: The name could not be resolved at this time. Future attempts may succeed. - EAI_BADFLAGS: The `ai_flags` value was invalid. - EAI_FAIL: A non-recoverable error occurred when attempting to resolve the name. - EAI_FAMILY: The `ai_family` member of the `hints` argument is not supported. - EAI_MEMORY: Out of memory. - EAI_NONAME: The name does not resolve for the supplied parameters. - EAI_SERVICE: The `servname` is not supported for `ai_socktype`. - EAI_SOCKTYPE: The `ai_socktype` is not supported. - EAI_SYSTEM: A system error occurred. `errno` is set in this case. - - #### C Function - ```c - int getaddrinfo(const char *restrict nodename, const char *restrict servname, const struct addrinfo *restrict hints, struct addrinfo **restrict res) - ``` - - #### Notes: - * Reference: https://man7.org/linux/man-pages/man3/getaddrinfo.3p.html - """ - var result = _getaddrinfo( - node.unsafe_ptr(), service.unsafe_ptr(), Pointer.address_of(hints), Pointer.address_of(res) - ) - if result != 0: - # gai_strerror returns a char buffer that we don't know the length of. - # TODO: Perhaps switch to writing bytes once the Writer trait allows writing individual bytes. - var err = gai_strerror(result) - var msg = List[Byte, True]() - var i = 0 - while err[i] != 0: - msg.append(err[i]) - i += 1 - msg.append(0) - raise Error("getaddrinfo: " + String(msg^)) - - -fn freeaddrinfo[T: AnAddrInfo, //](ptr: UnsafePointer[T]): - """Free the memory allocated by `getaddrinfo`.""" - external_call["freeaddrinfo", NoneType, UnsafePointer[T]](ptr) diff --git a/lightbug_http/server.mojo b/lightbug_http/server.mojo index 1832d413..66023aad 100644 --- a/lightbug_http/server.mojo +++ b/lightbug_http/server.mojo @@ -3,7 +3,7 @@ from lightbug_http.io.sync import Duration from lightbug_http.io.bytes import Bytes, bytes, ByteReader from lightbug_http.strings import NetworkType from lightbug_http._logger import logger -from lightbug_http.net import NoTLSListener, default_buffer_size, TCPConnection, ListenConfig +from lightbug_http.connection import NoTLSListener, default_buffer_size, TCPConnection, ListenConfig from lightbug_http.socket import Socket from lightbug_http.http import HTTPRequest, encode from lightbug_http.http.common_response import InternalError diff --git a/lightbug_http/socket.mojo b/lightbug_http/socket.mojo index b794a165..ec97c510 100644 --- a/lightbug_http/socket.mojo +++ b/lightbug_http/socket.mojo @@ -46,14 +46,14 @@ from lightbug_http._libc import ( ) from lightbug_http.io.bytes import Bytes from lightbug_http.strings import NetworkType -from lightbug_http.net import ( +from lightbug_http.address import ( Addr, - default_buffer_size, binary_port_to_int, binary_ip_to_string, addrinfo_macos, addrinfo_unix, ) +from lightbug_http.connection import default_buffer_size from lightbug_http._logger import logger @@ -106,7 +106,6 @@ struct Socket[AddrType: Addr, address_family: Int = AF_INET](Representable, Stri """ self.socket_type = socket_type self.protocol = protocol - self.fd = socket(address_family, socket_type, 0) self._local_address = local_address self._remote_address = remote_address diff --git a/tests/lightbug_http/test_host_port.mojo b/tests/lightbug_http/test_host_port.mojo index 2ad444b3..76623454 100644 --- a/tests/lightbug_http/test_host_port.mojo +++ b/tests/lightbug_http/test_host_port.mojo @@ -1,5 +1,5 @@ import testing -from lightbug_http.net import join_host_port, parse_address, TCPAddr +from lightbug_http.address import join_host_port, parse_address, TCPAddr from lightbug_http.strings import NetworkType From 3f5019d5bd2dc7027add6c13cfe9dcd9e6203d86 Mon Sep 17 00:00:00 2001 From: Val Date: Sat, 1 Feb 2025 14:40:03 +0100 Subject: [PATCH 02/11] write unit tests --- README.md | 4 +- benchmark/bench_server.mojo | 2 +- "lightbug.\360\237\224\245" | 2 +- lightbug_http/_libc.mojo | 2 +- lightbug_http/address.mojo | 42 ++++++----- lightbug_http/server.mojo | 4 +- scripts/bench_server.sh | 2 +- tests/lightbug_http/test_host_port.mojo | 96 +++++++++++++++++++++++-- 8 files changed, 123 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index 4e272b99..e69b87a9 100644 --- a/README.md +++ b/README.md @@ -126,12 +126,12 @@ Once you have a Mojo project set up locally, fn main() raises: var server = Server() var handler = Welcome() - server.listen_and_serve("0.0.0.0:8080", handler) + server.listen_and_serve("localhost:8080", handler) ``` Feel free to change the settings in `listen_and_serve()` to serve on a particular host and port. -Now send a request `0.0.0.0:8080`. You should see some details about the request printed out to the console. +Now send a request `localhost:8080`. You should see some details about the request printed out to the console. Congrats šŸ„³ You're using Lightbug! diff --git a/benchmark/bench_server.mojo b/benchmark/bench_server.mojo index 40804494..7a74846c 100644 --- a/benchmark/bench_server.mojo +++ b/benchmark/bench_server.mojo @@ -6,7 +6,7 @@ def main(): try: var server = Server(tcp_keep_alive=True) var handler = TechEmpowerRouter() - server.listen_and_serve("0.0.0.0:8080", handler) + server.listen_and_serve("localhost:8080", handler) except e: print("Error starting server: " + e.__str__()) return diff --git "a/lightbug.\360\237\224\245" "b/lightbug.\360\237\224\245" index 4cc40e23..195a777f 100644 --- "a/lightbug.\360\237\224\245" +++ "b/lightbug.\360\237\224\245" @@ -3,4 +3,4 @@ from lightbug_http import Welcome, Server fn main() raises: var server = Server() var handler = Welcome() - server.listen_and_serve("0.0.0.0:8080", handler) + server.listen_and_serve("localhost:8080", handler) diff --git a/lightbug_http/_libc.mojo b/lightbug_http/_libc.mojo index 16a30a00..45a078f0 100644 --- a/lightbug_http/_libc.mojo +++ b/lightbug_http/_libc.mojo @@ -574,7 +574,7 @@ fn inet_pton[address_family: Int32](src: UnsafePointer[c_char]) raises -> c_uint * This function is valid for `AF_INET` and `AF_INET6`. """ constrained[ - int(address_family) in [AF_INET, AF_INET6], "Address family must be either INET_ADDRSTRLEN or INET6_ADDRSTRLEN." + int(address_family) in [AF_INET, AF_INET6], "Address family must be either AF_INET or AF_INET6." ]() var ip_buffer: UnsafePointer[c_void] diff --git a/lightbug_http/address.mojo b/lightbug_http/address.mojo index 000d3c93..a5a33ae3 100644 --- a/lightbug_http/address.mojo +++ b/lightbug_http/address.mojo @@ -313,11 +313,11 @@ fn is_ip_protocol(network: NetworkType) -> Bool: fn is_ipv4(network: NetworkType) -> Bool: """Check if the network type is IPv4.""" - return network in (NetworkType.tcp4, NetworkType.udp4) + return network in (NetworkType.tcp4, NetworkType.udp4, NetworkType.ip4) fn is_ipv6(network: NetworkType) -> Bool: """Check if the network type is IPv6.""" - return network in (NetworkType.tcp6, NetworkType.udp6) + return network in (NetworkType.tcp6, NetworkType.udp6, NetworkType.ip6) fn resolve_localhost(host: String, network: NetworkType) -> String: """Resolve localhost to the appropriate IP address based on network type.""" @@ -383,37 +383,44 @@ fn parse_address(network: NetworkType, address: String) raises -> (String, UInt1 Returns: Tuple containing the host and port """ - # Handle IP protocols separately if is_ip_protocol(network): - if address.find(":") != -1: - raise Error("IP protocol addresses should not include ports") - var host = resolve_localhost(address, network) if host == AddressConstants.EMPTY: raise Error("missing host") + + # For IPv6 addresses in IP protocol mode, we need to handle the address as-is + if network == NetworkType.ip6 and host.find(":") != -1: + return host, DEFAULT_IP_PORT + + # For other IP protocols, no colons allowed + if host.find(":") != -1: + raise Error("IP protocol addresses should not include ports") + return host, DEFAULT_IP_PORT - # Parse regular addresses var colon_index = address.rfind(":") if colon_index == -1: raise MissingPortError var host: String - var bracket_offset: Int + var bracket_offset: Int = 0 # Handle IPv6 addresses - try: - (host, bracket_offset) = parse_ipv6_bracketed_address(address) - except e: - raise e - - # Validate no unexpected brackets - validate_no_brackets(address, bracket_offset) + if address[0] == "[": + try: + (host, bracket_offset) = parse_ipv6_bracketed_address(address) + except e: + raise e + + validate_no_brackets(address, bracket_offset) + else: + # For IPv4, simply split at the last colon + host = address[:colon_index] + if host.find(":") != -1: + raise TooManyColonsError - # Parse and validate port var port = parse_port(address[colon_index + 1:]) - # Resolve localhost if needed host = resolve_localhost(host, network) if host == AddressConstants.EMPTY: raise Error("missing host") @@ -421,7 +428,6 @@ fn parse_address(network: NetworkType, address: String) raises -> (String, UInt1 return host, port - # TODO: Support IPv6 long form. fn join_host_port(host: String, port: String) -> String: if host.find(":") != -1: # must be IPv6 literal diff --git a/lightbug_http/server.mojo b/lightbug_http/server.mojo index 66023aad..c693ed57 100644 --- a/lightbug_http/server.mojo +++ b/lightbug_http/server.mojo @@ -91,8 +91,8 @@ struct Server(Movable): address: The address (host:port) to listen on. handler: An object that handles incoming HTTP requests. """ - var net = ListenConfig() - var listener = net.listen(address) + var config = ListenConfig() + var listener = config.listen(address) self.set_address(address) self.serve(listener^, handler) diff --git a/scripts/bench_server.sh b/scripts/bench_server.sh index 50d663d6..ba71acbb 100644 --- a/scripts/bench_server.sh +++ b/scripts/bench_server.sh @@ -9,7 +9,7 @@ echo "running server..." sleep 2 echo "Running benchmark" -wrk -t1 -c1 -d10s http://0.0.0.0:8080/ --header "User-Agent: wrk" +wrk -t1 -c1 -d10s http://localhost:8080/ --header "User-Agent: wrk" kill $! wait $! 2>/dev/null diff --git a/tests/lightbug_http/test_host_port.mojo b/tests/lightbug_http/test_host_port.mojo index 76623454..279ffc5b 100644 --- a/tests/lightbug_http/test_host_port.mojo +++ b/tests/lightbug_http/test_host_port.mojo @@ -4,21 +4,107 @@ from lightbug_http.strings import NetworkType def test_split_host_port(): - # IPv4 - var hp = parse_address("127.0.0.1:8080") + # TCP4 + var hp = parse_address(NetworkType.tcp4, "127.0.0.1:8080") testing.assert_equal(hp[0], "127.0.0.1") testing.assert_equal(hp[1], 8080) - # IPv6 - hp = parse_address("[::1]:8080") + # TCP4 with localhost + hp = parse_address(NetworkType.tcp4, "localhost:8080") + testing.assert_equal(hp[0], "127.0.0.1") + testing.assert_equal(hp[1], 8080) + + # TCP6 + hp = parse_address(NetworkType.tcp6, "[::1]:8080") testing.assert_equal(hp[0], "::1") testing.assert_equal(hp[1], 8080) - # # TODO: IPv6 long form - Not supported yet. + # TCP6 with localhost + hp = parse_address(NetworkType.tcp6, "localhost:8080") + testing.assert_equal(hp[0], "::1") + testing.assert_equal(hp[1], 8080) + + # UDP4 + hp = parse_address(NetworkType.udp4, "192.168.1.1:53") + testing.assert_equal(hp[0], "192.168.1.1") + testing.assert_equal(hp[1], 53) + + # UDP4 with localhost + hp = parse_address(NetworkType.udp4, "localhost:53") + testing.assert_equal(hp[0], "127.0.0.1") + testing.assert_equal(hp[1], 53) + + # UDP6 + hp = parse_address(NetworkType.udp6, "[2001:db8::1]:53") + testing.assert_equal(hp[0], "2001:db8::1") + testing.assert_equal(hp[1], 53) + + # UDP6 with localhost + hp = parse_address(NetworkType.udp6, "localhost:53") + testing.assert_equal(hp[0], "::1") + testing.assert_equal(hp[1], 53) + + # IP4 (no port) + hp = parse_address(NetworkType.ip4, "192.168.1.1") + testing.assert_equal(hp[0], "192.168.1.1") + testing.assert_equal(hp[1], 0) + + # IP4 with localhost + hp = parse_address(NetworkType.ip4, "localhost") + testing.assert_equal(hp[0], "127.0.0.1") + testing.assert_equal(hp[1], 0) + + # IP6 (no port) + hp = parse_address(NetworkType.ip6, "2001:db8::1") + testing.assert_equal(hp[0], "2001:db8::1") + testing.assert_equal(hp[1], 0) + + # IP6 with localhost + hp = parse_address(NetworkType.ip6, "localhost") + testing.assert_equal(hp[0], "::1") + testing.assert_equal(hp[1], 0) + + # TODO: IPv6 long form - Not supported yet. # hp = parse_address("0:0:0:0:0:0:0:1:8080") # testing.assert_equal(hp[0], "0:0:0:0:0:0:0:1") # testing.assert_equal(hp[1], 8080) + # Error cases + # IP protocol with port + try: + _ = parse_address(NetworkType.ip4, "192.168.1.1:80") + testing.assert_false("Should have raised an error for IP protocol with port") + except Error: + testing.assert_true(True) + + # Missing port + try: + _ = parse_address(NetworkType.tcp4, "192.168.1.1") + testing.assert_false("Should have raised MissingPortError") + except MissingPortError: + testing.assert_true(True) + + # Missing port + try: + _ = parse_address(NetworkType.tcp6, "[::1]") + testing.assert_false("Should have raised MissingPortError") + except MissingPortError: + testing.assert_true(True) + + # Port out of range + try: + _ = parse_address(NetworkType.tcp4, "192.168.1.1:70000") + testing.assert_false("Should have raised error for invalid port") + except Error: + testing.assert_true(True) + + # Missing closing bracket + try: + _ = parse_address(NetworkType.tcp6, "[::1:8080") + testing.assert_false("Should have raised error for missing bracket") + except Error: + testing.assert_true(True) + def test_join_host_port(): # IPv4 From c2093b50c9483d04e87dbf2ee48fd6d1cfee40b9 Mon Sep 17 00:00:00 2001 From: Val Date: Sat, 1 Feb 2025 14:55:48 +0100 Subject: [PATCH 03/11] fix import --- lightbug_http/http/response.mojo | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lightbug_http/http/response.mojo b/lightbug_http/http/response.mojo index a138e7e7..6757c9a9 100644 --- a/lightbug_http/http/response.mojo +++ b/lightbug_http/http/response.mojo @@ -1,6 +1,9 @@ +from utils import StringSlice +from collections import Optional from small_time.small_time import now from lightbug_http.uri import URI from lightbug_http.io.bytes import Bytes, bytes, byte, ByteReader, ByteWriter +from lightbug_http.connection import TCPConnection, default_buffer_size from lightbug_http.strings import ( strHttp11, strHttp, @@ -11,9 +14,6 @@ from lightbug_http.strings import ( lineBreak, to_string, ) -from collections import Optional -from utils import StringSlice -from lightbug_http.net import TCPConnection, default_buffer_size struct StatusCode: From 865ea8e6d983141d245db09201208a98f7603818 Mon Sep 17 00:00:00 2001 From: Val Date: Sat, 1 Feb 2025 15:21:18 +0100 Subject: [PATCH 04/11] update README --- README.md | 29 +++++++++++------------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index e69b87a9..23af7fa2 100644 --- a/README.md +++ b/README.md @@ -29,11 +29,10 @@ Lightbug is a simple and sweet HTTP framework for Mojo that builds on best pract This is not production ready yet. We're aiming to keep up with new developments in Mojo, but it might take some time to get to a point when this is safe to use in real-world applications. Lightbug currently has the following features: - - [x] Pure Mojo networking! No dependencies on Python by default - - [x] TCP-based server and client implementation - - [x] Assign your own custom handler to a route - - [x] Craft HTTP requests and responses with built-in primitives - - [x] Everything is fully typed, with no `def` functions used + - [x] Pure Mojo! No Python dependencies. Everything is fully typed, with no `def` functions used + - [x] HTTP Server and Client implementations + - [x] TCP and UDP support + - [x] Cookie support ### Check Out These Mojo Libraries: @@ -200,19 +199,15 @@ from lightbug_http import * from lightbug_http.client import Client fn test_request(mut client: Client) raises -> None: - var uri = URI.parse_raises("http://httpbin.org/status/404") - var headers = Header("Host", "httpbin.org") - + var uri = URI.parse("google.com") + var headers = Headers(Header("Host", "google.com")) var request = HTTPRequest(uri, headers) var response = client.do(request^) # print status code print("Response:", response.status_code) - # print parsed headers (only some are parsed for now) - print("Content-Type:", response.headers["Content-Type"]) - print("Content-Length", response.headers["Content-Length"]) - print("Server:", to_string(response.headers["Server"])) + print(response.headers) print( "Is connection set to connection-close? ", response.connection_close() @@ -252,19 +247,17 @@ Note: as of September, 2024, `PythonServer` and `PythonClient` throw a compilati We're working on support for the following (contributors welcome!): -- [ ] [WebSocket Support](https://github.com/saviorand/lightbug_http/pull/57) + - [ ] [JSON support](https://github.com/saviorand/lightbug_http/issues/4) + - [ ] Complete HTTP/1.x support compliant with RFC 9110/9112 specs (see issues) - [ ] [SSL/HTTPS support](https://github.com/saviorand/lightbug_http/issues/20) - - [ ] UDP support - - [ ] [Better error handling](https://github.com/saviorand/lightbug_http/issues/3), [improved form/multipart and JSON support](https://github.com/saviorand/lightbug_http/issues/4) - [ ] [Multiple simultaneous connections](https://github.com/saviorand/lightbug_http/issues/5), [parallelization and performance optimizations](https://github.com/saviorand/lightbug_http/issues/6) - [ ] [HTTP 2.0/3.0 support](https://github.com/saviorand/lightbug_http/issues/8) - - [ ] [ASGI spec conformance](https://github.com/saviorand/lightbug_http/issues/17) The plan is to get to a feature set similar to Python frameworks like [Starlette](https://github.com/encode/starlette), but with better performance. Our vision is to develop three libraries, with `lightbug_http` (this repo) as a starting point: - - `lightbug_http` - HTTP infrastructure and basic API development - - `lightbug_api` - (coming later in 2024!) Tools to make great APIs fast, with support for OpenAPI spec and domain driven design + - `lightbug_http` - Lightweight and simple HTTP framework, basic networking primitives + - [`lightbug_api`](https://github.com/saviorand/lightbug_api) - Tools to make great APIs fast, with OpenAPI support and automated docs - `lightbug_web` - (release date TBD) Full-stack web framework for Mojo, similar to NextJS or SvelteKit The idea is to get to a point where the entire codebase of a simple modern web application can be written in Mojo. From 1f38c877176c831c087252ad99bc045cd73dc178 Mon Sep 17 00:00:00 2001 From: Val Date: Sat, 1 Feb 2025 16:44:28 +0100 Subject: [PATCH 05/11] fix imports and add udp example --- README.md | 46 +++++++++++++++++++++++---- lightbug_http/client.mojo | 3 +- lightbug_http/pool_manager.mojo | 2 +- tests/integration/udp/udp_client.mojo | 3 +- tests/integration/udp/udp_server.mojo | 3 +- testutils/utils.mojo | 3 +- 6 files changed, 48 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 23af7fa2..b2752d71 100644 --- a/README.md +++ b/README.md @@ -227,16 +227,50 @@ fn main() -> None: Pure Mojo-based client is available by default. This client is also used internally for testing the server. -## Switching between pure Mojo and Python implementations - -By default, Lightbug uses the pure Mojo implementation for networking. To use Python's `socket` library instead, just import the `PythonServer` instead of the `Server` with the following line: +### UDP Support +To get started with UDP, just use the `listen_udp` and `dial_udp` functions, along with `write_to` and `read_from` methods, like below. +On the client: ```mojo -from lightbug_http.python.server import PythonServer +from lightbug_http.connection import dial_udp +from lightbug_http.address import UDPAddr +from utils import StringSlice + +alias test_string = "Hello, lightbug!" + +fn main() raises: + print("Dialing UDP server...") + alias host = "127.0.0.1" + alias port = 12000 + var udp = dial_udp(host, port) + + print("Sending " + str(len(test_string)) + " messages to the server...") + for i in range(len(test_string)): + _ = udp.write_to(str(test_string[i]).as_bytes(), host, port) + + try: + response, _, _ = udp.read_from(16) + print("Response received:", StringSlice(unsafe_from_utf8=response)) + except e: + if str(e) != str("EOF"): + raise e + ``` -You can then use all the regular server commands in the same way as with the default server. -Note: as of September, 2024, `PythonServer` and `PythonClient` throw a compilation error when starting. There's an open [issue](https://github.com/saviorand/lightbug_http/issues/41) to fix this - contributions welcome! +On the server: +```mojo +fn main() raises: + var listener = listen_udp("127.0.0.1", 12000) + + while True: + response, host, port = listener.read_from(16) + var message = StringSlice(unsafe_from_utf8=response) + print("Message received:", message) + + # Response with the same message in uppercase + _ = listener.write_to(String.upper(message).as_bytes(), host, port) + +``` ## Roadmap diff --git a/lightbug_http/client.mojo b/lightbug_http/client.mojo index 08d405fe..59c8a58e 100644 --- a/lightbug_http/client.mojo +++ b/lightbug_http/client.mojo @@ -1,10 +1,9 @@ from collections import Dict from utils import StringSlice from memory import UnsafePointer -from lightbug_http.net import default_buffer_size +from lightbug_http.connection import TCPConnection, default_buffer_size, create_connection from lightbug_http.http import HTTPRequest, HTTPResponse, encode from lightbug_http.header import Headers, HeaderKey -from lightbug_http.net import create_connection, TCPConnection from lightbug_http.io.bytes import Bytes, ByteReader from lightbug_http._logger import logger from lightbug_http.pool_manager import PoolManager, PoolKey diff --git a/lightbug_http/pool_manager.mojo b/lightbug_http/pool_manager.mojo index c34ba0e0..10e7b87f 100644 --- a/lightbug_http/pool_manager.mojo +++ b/lightbug_http/pool_manager.mojo @@ -4,7 +4,7 @@ from builtin.value import StringableCollectionElement from memory import UnsafePointer, bitcast, memcpy from collections import Dict, Optional from collections.dict import RepresentableKeyElement -from lightbug_http.net import create_connection, TCPConnection, Connection +from lightbug_http.connection import create_connection, TCPConnection, Connection from lightbug_http._logger import logger from lightbug_http._owning_list import OwningList from lightbug_http.uri import Scheme diff --git a/tests/integration/udp/udp_client.mojo b/tests/integration/udp/udp_client.mojo index 88811b7e..ae87eba4 100644 --- a/tests/integration/udp/udp_client.mojo +++ b/tests/integration/udp/udp_client.mojo @@ -1,4 +1,5 @@ -from lightbug_http.net import dial_udp, UDPAddr +from lightbug_http.connection import dial_udp +from lightbug_http.address import UDPAddr from utils import StringSlice alias test_string = "Hello, lightbug!" diff --git a/tests/integration/udp/udp_server.mojo b/tests/integration/udp/udp_server.mojo index 5c37deeb..648e3e64 100644 --- a/tests/integration/udp/udp_server.mojo +++ b/tests/integration/udp/udp_server.mojo @@ -1,4 +1,5 @@ -from lightbug_http.net import listen_udp, UDPAddr +from lightbug_http.connection import listen_udp +from lightbug_http.address import UDPAddr from utils import StringSlice diff --git a/testutils/utils.mojo b/testutils/utils.mojo index 685172b9..132fe84d 100644 --- a/testutils/utils.mojo +++ b/testutils/utils.mojo @@ -3,7 +3,8 @@ from lightbug_http.io.bytes import Bytes from lightbug_http.error import ErrorHandler from lightbug_http.uri import URI from lightbug_http.http import HTTPRequest, HTTPResponse -from lightbug_http.net import Listener, Addr, Connection, TCPAddr +from lightbug_http.connection import Listener, Connection +from lightbug_http.address import Addr, TCPAddr from lightbug_http.service import HTTPService, OK from lightbug_http.server import ServerTrait from lightbug_http.client import Client From a2884566593cebb7ccee5b224dbbf2afb83b8727 Mon Sep 17 00:00:00 2001 From: Val Date: Sun, 2 Feb 2025 13:37:23 +0100 Subject: [PATCH 06/11] return uint --- lightbug_http/address.mojo | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lightbug_http/address.mojo b/lightbug_http/address.mojo index a5a33ae3..09c62ba7 100644 --- a/lightbug_http/address.mojo +++ b/lightbug_http/address.mojo @@ -330,14 +330,14 @@ fn resolve_localhost(host: String, network: NetworkType) -> String: return AddressConstants.IPV6_LOCALHOST return host -fn parse_ipv6_bracketed_address(address: String) raises -> (String, Int): +fn parse_ipv6_bracketed_address(address: String) raises -> (String, UInt16): """Parse an IPv6 address enclosed in brackets. Returns: Tuple of (host, colon_index_offset) """ if address[0] != "[": - return address, 0 + return address, UInt16(0) var end_bracket_index = address.find("]") if end_bracket_index == -1: @@ -352,7 +352,7 @@ fn parse_ipv6_bracketed_address(address: String) raises -> (String, Int): return ( address[1:end_bracket_index], - end_bracket_index + 1 + UInt16(end_bracket_index + 1) ) fn validate_no_brackets(address: String, start_idx: Int, end_idx: Int = -1) raises: From 4a5d14ac1b8e14947c306c2e26a4d717f95048b2 Mon Sep 17 00:00:00 2001 From: Val Date: Sun, 2 Feb 2025 14:05:24 +0100 Subject: [PATCH 07/11] move functions to networktype --- lightbug_http/address.mojo | 87 +++++++++++++++++++--- lightbug_http/connection.mojo | 3 +- lightbug_http/http/request.mojo | 13 ++++ lightbug_http/server.mojo | 2 +- lightbug_http/socket.mojo | 2 +- lightbug_http/strings.mojo | 96 ------------------------- tests/integration/test_net.mojo | 0 tests/lightbug_http/test_host_port.mojo | 3 +- 8 files changed, 96 insertions(+), 110 deletions(-) delete mode 100644 tests/integration/test_net.mojo diff --git a/lightbug_http/address.mojo b/lightbug_http/address.mojo index 09c62ba7..61162ffc 100644 --- a/lightbug_http/address.mojo +++ b/lightbug_http/address.mojo @@ -1,6 +1,7 @@ from memory import UnsafePointer +from collections import Optional from sys.ffi import external_call, OpaquePointer -from lightbug_http.strings import NetworkType, to_string +from lightbug_http.strings import to_string from lightbug_http._libc import ( c_int, c_char, @@ -66,6 +67,68 @@ trait AnAddrInfo: """ ... +@value +struct NetworkType(EqualityComparableCollectionElement): + var value: String + + alias empty = NetworkType("") + alias tcp = NetworkType("tcp") + alias tcp4 = NetworkType("tcp4") + alias tcp6 = NetworkType("tcp6") + alias udp = NetworkType("udp") + alias udp4 = NetworkType("udp4") + alias udp6 = NetworkType("udp6") + alias ip = NetworkType("ip") + alias ip4 = NetworkType("ip4") + alias ip6 = NetworkType("ip6") + alias unix = NetworkType("unix") + + alias SUPPORTED_TYPES = [ + Self.tcp, + Self.tcp4, + Self.tcp6, + Self.udp, + Self.udp4, + Self.udp6, + Self.ip, + Self.ip4, + Self.ip6, + ] + alias TCP_TYPES = [ + Self.tcp, + Self.tcp4, + Self.tcp6, + ] + alias UDP_TYPES = [ + Self.udp, + Self.udp4, + Self.udp6, + ] + alias IP_TYPES = [ + Self.ip, + Self.ip4, + Self.ip6, + ] + + fn __eq__(self, other: NetworkType) -> Bool: + return self.value == other.value + + fn __ne__(self, other: NetworkType) -> Bool: + return self.value != other.value + + fn is_ip_protocol(self) -> Bool: + """Check if the network type is an IP protocol.""" + return self in (NetworkType.ip, NetworkType.ip4, NetworkType.ip6) + + fn is_ipv4(self) -> Bool: + """Check if the network type is IPv4.""" + print("self.value:", self.value) + return self in (NetworkType.tcp4, NetworkType.udp4, NetworkType.ip4) + + fn is_ipv6(self) -> Bool: + """Check if the network type is IPv6.""" + return self in (NetworkType.tcp6, NetworkType.udp6, NetworkType.ip6) + @value struct TCPAddr[network: NetworkType = NetworkType.tcp4](Addr): alias _type = "TCPAddr" @@ -324,10 +387,11 @@ fn resolve_localhost(host: String, network: NetworkType) -> String: if host != AddressConstants.LOCALHOST: return host - if is_ipv4(network): + if network.is_ipv4(): return AddressConstants.IPV4_LOCALHOST - elif is_ipv6(network): + elif network.is_ipv6(): return AddressConstants.IPV6_LOCALHOST + return host fn parse_ipv6_bracketed_address(address: String) raises -> (String, UInt16): @@ -355,11 +419,18 @@ fn parse_ipv6_bracketed_address(address: String) raises -> (String, UInt16): UInt16(end_bracket_index + 1) ) -fn validate_no_brackets(address: String, start_idx: Int, end_idx: Int = -1) raises: +fn validate_no_brackets(address: String, start_idx: UInt16, end_idx: Optional[UInt16] = None) raises: """Validate that the address segment contains no brackets.""" - if address[start_idx:end_idx].find("[") != -1: + var segment: String + + if end_idx is None: + segment = address[int(start_idx):] + else: + segment = address[int(start_idx):int(end_idx.value())] + + if segment.find("[") != -1: raise Error("unexpected '[' in address") - if address[start_idx:end_idx].find("]") != -1: + if segment.find("]") != -1: raise Error("unexpected ']' in address") fn parse_port(port_str: String) raises -> UInt16: @@ -383,7 +454,7 @@ fn parse_address(network: NetworkType, address: String) raises -> (String, UInt1 Returns: Tuple containing the host and port """ - if is_ip_protocol(network): + if network.is_ip_protocol(): var host = resolve_localhost(address, network) if host == AddressConstants.EMPTY: raise Error("missing host") @@ -403,7 +474,7 @@ fn parse_address(network: NetworkType, address: String) raises -> (String, UInt1 raise MissingPortError var host: String - var bracket_offset: Int = 0 + var bracket_offset: UInt16 = 0 # Handle IPv6 addresses if address[0] == "[": diff --git a/lightbug_http/connection.mojo b/lightbug_http/connection.mojo index dcf10eb9..ee0fce2c 100644 --- a/lightbug_http/connection.mojo +++ b/lightbug_http/connection.mojo @@ -1,7 +1,7 @@ from time import sleep from memory import Span from sys.info import os_is_macos -from lightbug_http.strings import NetworkType +from lightbug_http.address import NetworkType from lightbug_http.io.bytes import Bytes, bytes from lightbug_http.io.sync import Duration from lightbug_http.address import parse_address, TCPAddr, UDPAddr @@ -51,7 +51,6 @@ trait Connection(Movable): fn remote_addr(self) -> TCPAddr: ... - struct NoTLSListener: """A TCP listener that listens for incoming connections and can accept them.""" diff --git a/lightbug_http/http/request.mojo b/lightbug_http/http/request.mojo index 0e83f409..6b418e6f 100644 --- a/lightbug_http/http/request.mojo +++ b/lightbug_http/http/request.mojo @@ -17,6 +17,19 @@ from lightbug_http.strings import ( ) +@value +struct RequestMethod: + var value: String + + alias get = RequestMethod("GET") + alias post = RequestMethod("POST") + alias put = RequestMethod("PUT") + alias delete = RequestMethod("DELETE") + alias head = RequestMethod("HEAD") + alias patch = RequestMethod("PATCH") + alias options = RequestMethod("OPTIONS") + + @value struct HTTPRequest(Writable, Stringable): var headers: Headers diff --git a/lightbug_http/server.mojo b/lightbug_http/server.mojo index a59207f3..0691643c 100644 --- a/lightbug_http/server.mojo +++ b/lightbug_http/server.mojo @@ -1,7 +1,7 @@ from memory import Span from lightbug_http.io.sync import Duration from lightbug_http.io.bytes import Bytes, bytes, ByteReader -from lightbug_http.strings import NetworkType +from lightbug_http.address import NetworkType from lightbug_http._logger import logger from lightbug_http.connection import NoTLSListener, default_buffer_size, TCPConnection, ListenConfig from lightbug_http.socket import Socket diff --git a/lightbug_http/socket.mojo b/lightbug_http/socket.mojo index ec97c510..b0047635 100644 --- a/lightbug_http/socket.mojo +++ b/lightbug_http/socket.mojo @@ -45,8 +45,8 @@ from lightbug_http._libc import ( ShutdownInvalidArgumentError, ) from lightbug_http.io.bytes import Bytes -from lightbug_http.strings import NetworkType from lightbug_http.address import ( + NetworkType, Addr, binary_port_to_int, binary_ip_to_string, diff --git a/lightbug_http/strings.mojo b/lightbug_http/strings.mojo index c1c5b9de..30e096cc 100644 --- a/lightbug_http/strings.mojo +++ b/lightbug_http/strings.mojo @@ -31,102 +31,6 @@ struct BytesConstant: alias nChar = byte(nChar) -@value -struct NetworkType(EqualityComparableCollectionElement): - var value: String - - alias empty = NetworkType("") - alias tcp = NetworkType("tcp") - alias tcp4 = NetworkType("tcp4") - alias tcp6 = NetworkType("tcp6") - alias udp = NetworkType("udp") - alias udp4 = NetworkType("udp4") - alias udp6 = NetworkType("udp6") - alias ip = NetworkType("ip") - alias ip4 = NetworkType("ip4") - alias ip6 = NetworkType("ip6") - alias unix = NetworkType("unix") - - alias SUPPORTED_TYPES = [ - Self.tcp, - Self.tcp4, - Self.tcp6, - Self.udp, - Self.udp4, - Self.udp6, - Self.ip, - Self.ip4, - Self.ip6, - ] - alias TCP_TYPES = [ - Self.tcp, - Self.tcp4, - Self.tcp6, - ] - alias UDP_TYPES = [ - Self.udp, - Self.udp4, - Self.udp6, - ] - alias IP_TYPES = [ - Self.ip, - Self.ip4, - Self.ip6, - ] - - fn __eq__(self, other: NetworkType) -> Bool: - return self.value == other.value - - fn __ne__(self, other: NetworkType) -> Bool: - return self.value != other.value - - -@value -struct ConnType: - var value: String - - alias empty = ConnType("") - alias http = ConnType("http") - alias websocket = ConnType("websocket") - - -@value -struct RequestMethod: - var value: String - - alias get = RequestMethod("GET") - alias post = RequestMethod("POST") - alias put = RequestMethod("PUT") - alias delete = RequestMethod("DELETE") - alias head = RequestMethod("HEAD") - alias patch = RequestMethod("PATCH") - alias options = RequestMethod("OPTIONS") - - -@value -struct CharSet: - var value: String - - alias utf8 = CharSet("utf-8") - - -@value -struct MediaType: - var value: String - - alias empty = MediaType("") - alias plain = MediaType("text/plain") - alias json = MediaType("application/json") - - -@value -struct Message: - var type: String - - alias empty = Message("") - alias http_start = Message("http.response.start") - - fn to_string[T: Writable](value: T) -> String: return String.write(value) diff --git a/tests/integration/test_net.mojo b/tests/integration/test_net.mojo deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/lightbug_http/test_host_port.mojo b/tests/lightbug_http/test_host_port.mojo index 279ffc5b..5334dfd9 100644 --- a/tests/lightbug_http/test_host_port.mojo +++ b/tests/lightbug_http/test_host_port.mojo @@ -1,6 +1,5 @@ import testing -from lightbug_http.address import join_host_port, parse_address, TCPAddr -from lightbug_http.strings import NetworkType +from lightbug_http.address import TCPAddr, NetworkType, join_host_port, parse_address def test_split_host_port(): From 4b30bcd0643304d90167909e20f8a471f2049b3c Mon Sep 17 00:00:00 2001 From: Val Date: Sun, 2 Feb 2025 14:20:47 +0100 Subject: [PATCH 08/11] update parse ipv6 with byteview --- lightbug_http/address.mojo | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/lightbug_http/address.mojo b/lightbug_http/address.mojo index 61162ffc..88324d24 100644 --- a/lightbug_http/address.mojo +++ b/lightbug_http/address.mojo @@ -1,7 +1,10 @@ -from memory import UnsafePointer +from memory import UnsafePointer, Span from collections import Optional from sys.ffi import external_call, OpaquePointer from lightbug_http.strings import to_string +from lightbug_http.io.bytes import ByteView +from lightbug_http._logger import logger +from lightbug_http.socket import Socket from lightbug_http._libc import ( c_int, c_char, @@ -20,8 +23,6 @@ from lightbug_http._libc import ( INET_ADDRSTRLEN, INET6_ADDRSTRLEN, ) -from lightbug_http._logger import logger -from lightbug_http.socket import Socket alias MAX_PORT = 65535 alias MIN_PORT = 0 @@ -394,16 +395,16 @@ fn resolve_localhost(host: String, network: NetworkType) -> String: return host -fn parse_ipv6_bracketed_address(address: String) raises -> (String, UInt16): +fn parse_ipv6_bracketed_address(address: ByteView[ImmutableAnyOrigin]) raises -> (ByteView[ImmutableAnyOrigin], UInt16): """Parse an IPv6 address enclosed in brackets. Returns: Tuple of (host, colon_index_offset) """ - if address[0] != "[": + if address[0] != Byte(ord("[")): return address, UInt16(0) - var end_bracket_index = address.find("]") + var end_bracket_index = address.find(Byte(ord("]"))) if end_bracket_index == -1: raise Error("missing ']' in address") @@ -411,7 +412,7 @@ fn parse_ipv6_bracketed_address(address: String) raises -> (String, UInt16): raise MissingPortError var colon_index = end_bracket_index + 1 - if address[colon_index] != ":": + if address[colon_index] != Byte(ord(":")): raise MissingPortError return ( From cdf3b06be02f86713af888bb506924cdbfe4a0b4 Mon Sep 17 00:00:00 2001 From: Val Date: Sun, 2 Feb 2025 18:50:20 +0100 Subject: [PATCH 09/11] use ByteView for improved performance --- lightbug_http/address.mojo | 42 +++++++++--------- lightbug_http/io/bytes.mojo | 20 ++++++++- tests/lightbug_http/io/test_bytes.mojo | 3 +- tests/lightbug_http/test_host_port.mojo | 58 ++++++++++++------------- 4 files changed, 71 insertions(+), 52 deletions(-) diff --git a/lightbug_http/address.mojo b/lightbug_http/address.mojo index 88324d24..96727dc4 100644 --- a/lightbug_http/address.mojo +++ b/lightbug_http/address.mojo @@ -383,19 +383,19 @@ fn is_ipv6(network: NetworkType) -> Bool: """Check if the network type is IPv6.""" return network in (NetworkType.tcp6, NetworkType.udp6, NetworkType.ip6) -fn resolve_localhost(host: String, network: NetworkType) -> String: +fn resolve_localhost(host: ByteView[StaticConstantOrigin], network: NetworkType) -> ByteView[StaticConstantOrigin]: """Resolve localhost to the appropriate IP address based on network type.""" - if host != AddressConstants.LOCALHOST: + if host != AddressConstants.LOCALHOST.as_bytes(): return host if network.is_ipv4(): - return AddressConstants.IPV4_LOCALHOST + return AddressConstants.IPV4_LOCALHOST.as_bytes() elif network.is_ipv6(): - return AddressConstants.IPV6_LOCALHOST + return AddressConstants.IPV6_LOCALHOST.as_bytes() return host -fn parse_ipv6_bracketed_address(address: ByteView[ImmutableAnyOrigin]) raises -> (ByteView[ImmutableAnyOrigin], UInt16): +fn parse_ipv6_bracketed_address(address: ByteView[StaticConstantOrigin]) raises -> (ByteView[StaticConstantOrigin], UInt16): """Parse an IPv6 address enclosed in brackets. Returns: @@ -420,32 +420,32 @@ fn parse_ipv6_bracketed_address(address: ByteView[ImmutableAnyOrigin]) raises -> UInt16(end_bracket_index + 1) ) -fn validate_no_brackets(address: String, start_idx: UInt16, end_idx: Optional[UInt16] = None) raises: +fn validate_no_brackets(address: ByteView[StaticConstantOrigin], start_idx: UInt16, end_idx: Optional[UInt16] = None) raises: """Validate that the address segment contains no brackets.""" - var segment: String + var segment: ByteView[StaticConstantOrigin] if end_idx is None: segment = address[int(start_idx):] else: segment = address[int(start_idx):int(end_idx.value())] - if segment.find("[") != -1: + if segment.find(Byte(ord("["))) != -1: raise Error("unexpected '[' in address") - if segment.find("]") != -1: + if segment.find(Byte(ord("]"))) != -1: raise Error("unexpected ']' in address") -fn parse_port(port_str: String) raises -> UInt16: +fn parse_port(port_str: ByteView[StaticConstantOrigin]) raises -> UInt16: """Parse and validate port number.""" - if port_str == AddressConstants.EMPTY: + if port_str == AddressConstants.EMPTY.as_bytes(): raise MissingPortError - var port = int(port_str) + var port = int(str(port_str)) if port < MIN_PORT or port > MAX_PORT: raise Error("Port number out of range (0-65535)") return UInt16(port) -fn parse_address(network: NetworkType, address: String) raises -> (String, UInt16): +fn parse_address(network: NetworkType, address: ByteView[StaticConstantOrigin]) raises -> (ByteView[StaticConstantOrigin], UInt16): """Parse an address string into a host and port. Args: @@ -457,28 +457,28 @@ fn parse_address(network: NetworkType, address: String) raises -> (String, UInt1 """ if network.is_ip_protocol(): var host = resolve_localhost(address, network) - if host == AddressConstants.EMPTY: + if host == AddressConstants.EMPTY.as_bytes(): raise Error("missing host") # For IPv6 addresses in IP protocol mode, we need to handle the address as-is - if network == NetworkType.ip6 and host.find(":") != -1: + if network == NetworkType.ip6 and host.find(Byte(ord(":"))) != -1: return host, DEFAULT_IP_PORT # For other IP protocols, no colons allowed - if host.find(":") != -1: + if host.find(Byte(ord(":"))) != -1: raise Error("IP protocol addresses should not include ports") return host, DEFAULT_IP_PORT - var colon_index = address.rfind(":") + var colon_index = address.rfind(Byte(ord(":"))) if colon_index == -1: raise MissingPortError - var host: String + var host: ByteView[StaticConstantOrigin] var bracket_offset: UInt16 = 0 # Handle IPv6 addresses - if address[0] == "[": + if address[0] == Byte(ord("[")): try: (host, bracket_offset) = parse_ipv6_bracketed_address(address) except e: @@ -488,13 +488,13 @@ fn parse_address(network: NetworkType, address: String) raises -> (String, UInt1 else: # For IPv4, simply split at the last colon host = address[:colon_index] - if host.find(":") != -1: + if host.find(Byte(ord(":"))) != -1: raise TooManyColonsError var port = parse_port(address[colon_index + 1:]) host = resolve_localhost(host, network) - if host == AddressConstants.EMPTY: + if host == AddressConstants.EMPTY.as_bytes(): raise Error("missing host") return host, port diff --git a/lightbug_http/io/bytes.mojo b/lightbug_http/io/bytes.mojo index cbc5b4c3..a80aee86 100644 --- a/lightbug_http/io/bytes.mojo +++ b/lightbug_http/io/bytes.mojo @@ -146,7 +146,7 @@ struct ByteView[origin: Origin](): fn __iter__(self) -> _SpanIter[Byte, origin]: return self._inner.__iter__() - + fn find(self, target: Byte) -> Int: """Finds the index of a byte in a byte span. @@ -162,6 +162,24 @@ struct ByteView[origin: Origin](): return -1 + fn rfind(self, target: Byte) -> Int: + """Finds the index of the last occurrence of a byte in a byte span. + + Args: + target: The byte to find. + + Returns: + The index of the last occurrence of the byte in the span, or -1 if not found. + """ + # Start from the end and work backwards + var i = len(self) - 1 + while i >= 0: + if self[i] == target: + return i + i -= 1 + + return -1 + fn to_bytes(self) -> Bytes: return Bytes(self._inner) diff --git a/tests/lightbug_http/io/test_bytes.mojo b/tests/lightbug_http/io/test_bytes.mojo index 7e847d7b..8affbc89 100644 --- a/tests/lightbug_http/io/test_bytes.mojo +++ b/tests/lightbug_http/io/test_bytes.mojo @@ -1,6 +1,6 @@ import testing from collections import Dict, List -from lightbug_http.io.bytes import Bytes, bytes +from lightbug_http.io.bytes import Bytes, ByteView, bytes fn test_string_literal_to_bytes() raises: @@ -35,3 +35,4 @@ fn test_string_to_bytes() raises: for c in cases.items(): testing.assert_equal(Bytes(c[].key.as_bytes()), c[].value) + diff --git a/tests/lightbug_http/test_host_port.mojo b/tests/lightbug_http/test_host_port.mojo index 5334dfd9..6084170d 100644 --- a/tests/lightbug_http/test_host_port.mojo +++ b/tests/lightbug_http/test_host_port.mojo @@ -4,63 +4,63 @@ from lightbug_http.address import TCPAddr, NetworkType, join_host_port, parse_ad def test_split_host_port(): # TCP4 - var hp = parse_address(NetworkType.tcp4, "127.0.0.1:8080") - testing.assert_equal(hp[0], "127.0.0.1") + var hp = parse_address(NetworkType.tcp4, "127.0.0.1:8080".as_bytes()) + testing.assert_equal(hp[0], "127.0.0.1".as_bytes()) testing.assert_equal(hp[1], 8080) # TCP4 with localhost - hp = parse_address(NetworkType.tcp4, "localhost:8080") - testing.assert_equal(hp[0], "127.0.0.1") + hp = parse_address(NetworkType.tcp4, "localhost:8080".as_bytes()) + testing.assert_equal(hp[0], "127.0.0.1".as_bytes()) testing.assert_equal(hp[1], 8080) # TCP6 - hp = parse_address(NetworkType.tcp6, "[::1]:8080") - testing.assert_equal(hp[0], "::1") + hp = parse_address(NetworkType.tcp6, "[::1]:8080".as_bytes()) + testing.assert_equal(hp[0], "::1".as_bytes()) testing.assert_equal(hp[1], 8080) # TCP6 with localhost - hp = parse_address(NetworkType.tcp6, "localhost:8080") - testing.assert_equal(hp[0], "::1") + hp = parse_address(NetworkType.tcp6, "localhost:8080".as_bytes()) + testing.assert_equal(hp[0], "::1".as_bytes()) testing.assert_equal(hp[1], 8080) # UDP4 - hp = parse_address(NetworkType.udp4, "192.168.1.1:53") - testing.assert_equal(hp[0], "192.168.1.1") + hp = parse_address(NetworkType.udp4, "192.168.1.1:53".as_bytes()) + testing.assert_equal(hp[0], "192.168.1.1".as_bytes()) testing.assert_equal(hp[1], 53) # UDP4 with localhost - hp = parse_address(NetworkType.udp4, "localhost:53") - testing.assert_equal(hp[0], "127.0.0.1") + hp = parse_address(NetworkType.udp4, "localhost:53".as_bytes()) + testing.assert_equal(hp[0], "127.0.0.1".as_bytes()) testing.assert_equal(hp[1], 53) # UDP6 - hp = parse_address(NetworkType.udp6, "[2001:db8::1]:53") - testing.assert_equal(hp[0], "2001:db8::1") + hp = parse_address(NetworkType.udp6, "[2001:db8::1]:53".as_bytes()) + testing.assert_equal(hp[0], "2001:db8::1".as_bytes()) testing.assert_equal(hp[1], 53) # UDP6 with localhost - hp = parse_address(NetworkType.udp6, "localhost:53") - testing.assert_equal(hp[0], "::1") + hp = parse_address(NetworkType.udp6, "localhost:53".as_bytes()) + testing.assert_equal(hp[0], "::1".as_bytes()) testing.assert_equal(hp[1], 53) # IP4 (no port) - hp = parse_address(NetworkType.ip4, "192.168.1.1") - testing.assert_equal(hp[0], "192.168.1.1") + hp = parse_address(NetworkType.ip4, "192.168.1.1".as_bytes()) + testing.assert_equal(hp[0], "192.168.1.1".as_bytes()) testing.assert_equal(hp[1], 0) # IP4 with localhost - hp = parse_address(NetworkType.ip4, "localhost") - testing.assert_equal(hp[0], "127.0.0.1") + hp = parse_address(NetworkType.ip4, "localhost".as_bytes()) + testing.assert_equal(hp[0], "127.0.0.1".as_bytes()) testing.assert_equal(hp[1], 0) # IP6 (no port) - hp = parse_address(NetworkType.ip6, "2001:db8::1") - testing.assert_equal(hp[0], "2001:db8::1") + hp = parse_address(NetworkType.ip6, "2001:db8::1".as_bytes()) + testing.assert_equal(hp[0], "2001:db8::1".as_bytes()) testing.assert_equal(hp[1], 0) # IP6 with localhost - hp = parse_address(NetworkType.ip6, "localhost") - testing.assert_equal(hp[0], "::1") + hp = parse_address(NetworkType.ip6, "localhost".as_bytes()) + testing.assert_equal(hp[0], "::1".as_bytes()) testing.assert_equal(hp[1], 0) # TODO: IPv6 long form - Not supported yet. @@ -71,35 +71,35 @@ def test_split_host_port(): # Error cases # IP protocol with port try: - _ = parse_address(NetworkType.ip4, "192.168.1.1:80") + _ = parse_address(NetworkType.ip4, "192.168.1.1:80".as_bytes()) testing.assert_false("Should have raised an error for IP protocol with port") except Error: testing.assert_true(True) # Missing port try: - _ = parse_address(NetworkType.tcp4, "192.168.1.1") + _ = parse_address(NetworkType.tcp4, "192.168.1.1".as_bytes()) testing.assert_false("Should have raised MissingPortError") except MissingPortError: testing.assert_true(True) # Missing port try: - _ = parse_address(NetworkType.tcp6, "[::1]") + _ = parse_address(NetworkType.tcp6, "[::1]".as_bytes()) testing.assert_false("Should have raised MissingPortError") except MissingPortError: testing.assert_true(True) # Port out of range try: - _ = parse_address(NetworkType.tcp4, "192.168.1.1:70000") + _ = parse_address(NetworkType.tcp4, "192.168.1.1:70000".as_bytes()) testing.assert_false("Should have raised error for invalid port") except Error: testing.assert_true(True) # Missing closing bracket try: - _ = parse_address(NetworkType.tcp6, "[::1:8080") + _ = parse_address(NetworkType.tcp6, "[::1:8080".as_bytes()) testing.assert_false("Should have raised error for missing bracket") except Error: testing.assert_true(True) From 835a725bfcf1e8ba464065c88526b935c259ff38 Mon Sep 17 00:00:00 2001 From: Val Date: Sun, 2 Feb 2025 19:13:34 +0100 Subject: [PATCH 10/11] add network param to udp --- lightbug_http/connection.mojo | 46 +++++++++++++++++------------------ 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/lightbug_http/connection.mojo b/lightbug_http/connection.mojo index ee0fce2c..61920f7c 100644 --- a/lightbug_http/connection.mojo +++ b/lightbug_http/connection.mojo @@ -87,9 +87,9 @@ struct ListenConfig: fn __init__(out self, keep_alive: Duration = default_tcp_keep_alive): self._keep_alive = keep_alive - fn listen[network: NetworkType = NetworkType.tcp4](mut self, address: String) raises -> NoTLSListener: - var local = parse_address(network, address) - var addr = TCPAddr(local[0], local[1]) + fn listen[network: NetworkType = NetworkType.tcp4](mut self, address: StringLiteral) raises -> NoTLSListener: + var local = parse_address(network, address.as_bytes()) + var addr = TCPAddr(str(local[0]), local[1]) var socket: Socket[TCPAddr] try: socket = Socket[TCPAddr]() @@ -188,10 +188,10 @@ struct TCPConnection: return self.socket.remote_address() -struct UDPConnection: - var socket: Socket[UDPAddr] +struct UDPConnection[network: NetworkType]: + var socket: Socket[UDPAddr[network]] - fn __init__(out self, owned socket: Socket[UDPAddr]): + fn __init__(out self, owned socket: Socket[UDPAddr[network]]): self.socket = socket^ fn __moveinit__(out self, owned existing: Self): @@ -268,10 +268,10 @@ struct UDPConnection: fn is_closed(self) -> Bool: return self.socket._closed - fn local_addr(self) -> ref [self.socket._local_address] UDPAddr: + fn local_addr(self) -> ref [self.socket._local_address] UDPAddr[network]: return self.socket.local_address() - fn remote_addr(self) -> ref [self.socket._remote_address] UDPAddr: + fn remote_addr(self) -> ref [self.socket._remote_address] UDPAddr[network]: return self.socket.remote_address() fn create_connection(host: String, port: UInt16) raises -> TCPConnection: @@ -297,7 +297,7 @@ fn create_connection(host: String, port: UInt16) raises -> TCPConnection: return TCPConnection(socket^) -fn listen_udp(local_address: UDPAddr) raises -> UDPConnection: +fn listen_udp[network: NetworkType = NetworkType.udp4](local_address: UDPAddr) raises -> UDPConnection[network]: """Creates a new UDP listener. Args: @@ -309,12 +309,12 @@ fn listen_udp(local_address: UDPAddr) raises -> UDPConnection: Raises: Error: If the address is invalid or failed to bind the socket. """ - var socket = Socket[UDPAddr](socket_type=SOCK_DGRAM) + var socket = Socket[UDPAddr[network]](socket_type=SOCK_DGRAM) socket.bind(local_address.ip, local_address.port) - return UDPConnection(socket^) + return UDPConnection[network](socket^) -fn listen_udp(local_address: String) raises -> UDPConnection: +fn listen_udp[network: NetworkType = NetworkType.udp4](local_address: StringLiteral) raises -> UDPConnection[network]: """Creates a new UDP listener. Args: @@ -326,11 +326,11 @@ fn listen_udp(local_address: String) raises -> UDPConnection: Raises: Error: If the address is invalid or failed to bind the socket. """ - var address = parse_address(NetworkType.udp4, local_address) - return listen_udp(UDPAddr(address[0], address[1])) + var address = parse_address(NetworkType.udp4, local_address.as_bytes()) + return listen_udp[network](UDPAddr[network](str(address[0]), address[1])) -fn listen_udp(host: String, port: UInt16) raises -> UDPConnection: +fn listen_udp[network: NetworkType = NetworkType.udp4](host: String, port: UInt16) raises -> UDPConnection[network]: """Creates a new UDP listener. Args: @@ -343,10 +343,10 @@ fn listen_udp(host: String, port: UInt16) raises -> UDPConnection: Raises: Error: If the address is invalid or failed to bind the socket. """ - return listen_udp(UDPAddr(host, port)) + return listen_udp[network](UDPAddr[network](host, port)) -fn dial_udp(local_address: UDPAddr) raises -> UDPConnection: +fn dial_udp[network: NetworkType = NetworkType.udp4](local_address: UDPAddr[network]) raises -> UDPConnection[network]: """Connects to the address on the named network. The network must be "udp", "udp4", or "udp6". Args: @@ -358,10 +358,10 @@ fn dial_udp(local_address: UDPAddr) raises -> UDPConnection: Raises: Error: If the network type is not supported or failed to connect to the address. """ - return UDPConnection(Socket[UDPAddr](local_address=local_address, socket_type=SOCK_DGRAM)) + return UDPConnection(Socket[UDPAddr[network]](local_address=local_address, socket_type=SOCK_DGRAM)) -fn dial_udp(network: NetworkType, local_address: String) raises -> UDPConnection: +fn dial_udp[network: NetworkType = NetworkType.udp4](local_address: StringLiteral) raises -> UDPConnection[network]: """Connects to the address on the named network. The network must be "udp", "udp4", or "udp6". Args: @@ -373,11 +373,11 @@ fn dial_udp(network: NetworkType, local_address: String) raises -> UDPConnection Raises: Error: If the network type is not supported or failed to connect to the address. """ - var address = parse_address(network, local_address) - return dial_udp(UDPAddr(network, address[0], address[1])) + var address = parse_address(network, local_address.as_bytes()) + return dial_udp[network](UDPAddr[network](str(address[0]), address[1])) -fn dial_udp(host: String, port: UInt16) raises -> UDPConnection: +fn dial_udp[network: NetworkType = NetworkType.udp4](host: String, port: UInt16) raises -> UDPConnection[network]: """Connects to the address on the udp network. Args: @@ -390,4 +390,4 @@ fn dial_udp(host: String, port: UInt16) raises -> UDPConnection: Raises: Error: If failed to connect to the address. """ - return dial_udp(UDPAddr(host, port)) + return dial_udp[network](UDPAddr[network](host, port)) From d6c6828c94b52383d6493adea959b74871ae4104 Mon Sep 17 00:00:00 2001 From: Val Date: Sun, 2 Feb 2025 19:15:20 +0100 Subject: [PATCH 11/11] string to stringliteral in listen --- lightbug_http/address.mojo | 1 - lightbug_http/server.mojo | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/lightbug_http/address.mojo b/lightbug_http/address.mojo index 96727dc4..fc8692af 100644 --- a/lightbug_http/address.mojo +++ b/lightbug_http/address.mojo @@ -123,7 +123,6 @@ struct NetworkType(EqualityComparableCollectionElement): fn is_ipv4(self) -> Bool: """Check if the network type is IPv4.""" - print("self.value:", self.value) return self in (NetworkType.tcp4, NetworkType.udp4, NetworkType.ip4) fn is_ipv6(self) -> Bool: diff --git a/lightbug_http/server.mojo b/lightbug_http/server.mojo index 0691643c..dc290074 100644 --- a/lightbug_http/server.mojo +++ b/lightbug_http/server.mojo @@ -81,7 +81,7 @@ struct Server(Movable): """ return self.max_concurrent_connections - fn listen_and_serve[T: HTTPService](mut self, address: String, mut handler: T) raises: + fn listen_and_serve[T: HTTPService](mut self, address: StringLiteral, mut handler: T) raises: """Listen for incoming connections and serve HTTP requests. Parameters: