Skip to content

Commit

Permalink
Replace tcp_sockopts with socket_factory (#10520)
Browse files Browse the repository at this point in the history
Instead of TCPConnector taking a list of sockopts to be applied sockets
created, take a socket_factory callback that allows the caller to
implement socket creation entirely.
  • Loading branch information
TimMenninger committed Mar 10, 2025
1 parent 4399a6c commit 61c7bd8
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 34 deletions.
2 changes: 2 additions & 0 deletions CHANGES/10520.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Added ``tcp_sockopts`` to ``TCPConnector`` to allow specifying custom socket options
-- by :user:`TimMenninger`.
12 changes: 6 additions & 6 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,8 +826,9 @@ class TCPConnector(BaseConnector):
the happy eyeballs algorithm, set to None.
interleave - “First Address Family Count” as defined in RFC 8305
loop - Optional event loop.
tcp_sockopts - List of tuples of sockopts applied to underlying
socket
socket_factory - An aiohappyeyeballs.SocketFactoryType function
that, if supplied, will be used to create sockets
given an aiohappyeyeballs.AddrInfoType.
"""

allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"})
Expand All @@ -849,7 +850,7 @@ def __init__(
timeout_ceil_threshold: float = 5,
happy_eyeballs_delay: Optional[float] = 0.25,
interleave: Optional[int] = None,
tcp_sockopts: Iterable[Tuple[int, int, Union[int, Buffer]]] = [],
socket_factory: Optional[aiohappyeyeballs.SocketFactoryType] = None,
):
super().__init__(
keepalive_timeout=keepalive_timeout,
Expand Down Expand Up @@ -880,7 +881,7 @@ def __init__(
self._happy_eyeballs_delay = happy_eyeballs_delay
self._interleave = interleave
self._resolve_host_tasks: Set["asyncio.Task[List[ResolveResult]]"] = set()
self._tcp_sockopts = tcp_sockopts
self._socket_factory = socket_factory

def _close_immediately(self) -> List[Awaitable[object]]:
for fut in chain.from_iterable(self._throttle_dns_futures.values()):
Expand Down Expand Up @@ -1122,9 +1123,8 @@ async def _wrap_create_connection(
happy_eyeballs_delay=self._happy_eyeballs_delay,
interleave=self._interleave,
loop=self._loop,
socket_factory=self._socket_factory,
)
for sockopt in self._tcp_sockopts:
sock.setsockopt(*sockopt)
connection = await self._loop.create_connection(
*args, **kwargs, sock=sock
)
Expand Down
23 changes: 15 additions & 8 deletions docs/client_advanced.rst
Original file line number Diff line number Diff line change
Expand Up @@ -468,19 +468,26 @@ If your HTTP server uses UNIX domain sockets you can use
session = aiohttp.ClientSession(connector=conn)


Setting socket options
Custom socket creation
^^^^^^^^^^^^^^^^^^^^^^

Socket options passed to the :class:`~aiohttp.TCPConnector` will be passed
to the underlying socket when creating a connection. For example, we may
want to change the conditions under which we consider a connection dead.
The following would change that to 9*7200 = 18 hours::
If the default socket is insufficient for your use case, pass an optional
`socket_factory` to the :class:`~aiohttp.TCPConnector`, which implements
`aiohappyeyeballs.SocketFactoryType`. This will be used to create all
sockets for the lifetime of the class object. For example, we may want to
change the conditions under which we consider a connection dead. The
following would make all sockets respect 9*7200 = 18 hours::

import socket

conn = aiohttp.TCPConnector(tcp_sockopts=[(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True),
(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 7200),
(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 9) ])
def socket_factory(addr_info):
family, type_, proto, _, _, _ = addr_info
sock = socket.socket(family=family, type=type_, proto=proto)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 7200)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 9)
return sock
conn = aiohttp.TCPConnector(socket_factory=socket_factory)


Named pipes in Windows
Expand Down
8 changes: 4 additions & 4 deletions docs/client_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1129,7 +1129,7 @@ is controlled by *force_close* constructor's parameter).
force_close=False, limit=100, limit_per_host=0, \
enable_cleanup_closed=False, timeout_ceil_threshold=5, \
happy_eyeballs_delay=0.25, interleave=None, loop=None, \
tcp_sockopts=[])
socket_factory=None)

Connector for working with *HTTP* and *HTTPS* via *TCP* sockets.

Expand Down Expand Up @@ -1250,9 +1250,9 @@ is controlled by *force_close* constructor's parameter).

.. versionadded:: 3.10

:param list tcp_sockopts: options applied to the socket when a connection is
created. This should be a list of 3-tuples, each a ``(level, optname, value)``.
Each tuple is deconstructed and passed verbatim to ``<socket>.setsockopt``.
:param aiohappyeyeballs.SocketFactoryType socket_factory: This function takes
an ``aiohappyeyeballs.AddrInfoType`` and is used in lieu of ``socket.socket()``
when creating TCP connections.

.. versionadded:: 3.12

Expand Down
53 changes: 37 additions & 16 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3767,27 +3767,48 @@ def test_connect() -> Literal[True]:
assert raw_response_list == [True, True]


async def test_tcp_connector_setsockopts(
async def test_tcp_connector_socket_factory(
loop: asyncio.AbstractEventLoop, start_connection: mock.AsyncMock
) -> None:
"""Check that sockopts get passed to socket"""
conn = aiohttp.TCPConnector(
tcp_sockopts=[(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 2)]
)

with mock.patch.object(
conn._loop, "create_connection", autospec=True, spec_set=True
) as create_connection:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
start_connection.return_value = s
create_connection.return_value = mock.Mock(), mock.Mock()
"""Check that socket factory is called"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
start_connection.return_value = s

req = ClientRequest("GET", URL("https://127.0.0.1:443"), loop=loop)
local_addr = None
socket_factory: Callable[[AddrInfoType], socket.socket] = lambda _: s
happy_eyeballs_delay = 0.123
interleave = 3
conn = aiohttp.TCPConnector(
interleave=interleave,
local_addr=local_addr,
happy_eyeballs_delay=happy_eyeballs_delay,
socket_factory=socket_factory,
)

with mock.patch.object(
conn._loop,
"create_connection",
autospec=True,
spec_set=True,
return_value=(mock.Mock(), mock.Mock()),
) as create_connection:
host = "127.0.0.1"
port = 443
req = ClientRequest("GET", URL(f"https://{host}:{port}"), loop=loop)
with closing(await conn.connect(req, [], ClientTimeout())):
assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT) == 2

await conn.close()
pass
await conn.close()

start_connection.assert_called_with(
addr_infos=[
(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", (host, port))
],
local_addr_infos=local_addr,
happy_eyeballs_delay=happy_eyeballs_delay,
interleave=interleave,
loop=loop,
socket_factory=socket_factory,
)


def test_default_ssl_context_creation_without_ssl() -> None:
Expand Down

0 comments on commit 61c7bd8

Please # to comment.