Skip to content

Commit 56cbea8

Browse files
authored
fix(rab): run async background boundary refresh on detached session (#17441)
When AuthorizedSession.request() makes an API call, it runs inside a temporary aiohttp ClientSession block. If our background Regional Access Boundary (RAB) refresh worker naively shares this exact same session, a fast primary call (like an instant 401/403 or a quick CRM check) will exit its block and close the active socket mid-flight. This causes the background worker to silently fail with "RuntimeError: Session is closed" and forces the RAB manager into a 15-minute cooldown. This commit resolves the race condition and ensures safe connection lifecycle management: - Shifted the cloning block to run synchronously inside start_refresh, capturing a fresh, independent ClientSession before the foreground thread can close the source transport. - Added a _clone() method to async Request adapters (both modern and legacy) to copy proxy settings and trace configurations while enforcing connector limits. - Prevented resource leaks on task creation failures by capturing exceptions in start_refresh and closing the cloned session synchronously. - Refactored the close wrapper to inspect and await generic awaitables (such as asyncio.Future) returned by custom or third-party transports. - Aligned exception behaviors by raising a wrapped TransportError directly when calling a closed instance of the legacy aiohttp_requests adapter. - Ensured the cloned transport is cleanly closed in a finally block after the background lookup settles.
1 parent b50cf1a commit 56cbea8

9 files changed

Lines changed: 1025 additions & 4 deletions

File tree

packages/google-auth/google/auth/_regional_access_boundary_utils.py

Lines changed: 85 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from google.auth import _helpers
2828
from google.auth import environment_vars
2929

30-
if TYPE_CHECKING:
30+
if TYPE_CHECKING: # pragma: NO COVER
3131
import google.auth.credentials
3232
import google.auth.transport
3333

@@ -455,6 +455,61 @@ def start_refresh(self, credentials, request, rab_manager):
455455
self._worker.start()
456456

457457

458+
def _prepare_async_lookup_callable(request):
459+
"""Unwraps a request callable, clones the transport, and returns the new callable.
460+
461+
Args:
462+
request: The original request callable (e.g. functools.partial or raw Request).
463+
464+
Returns:
465+
Tuple[Callable, Any, bool]: A tuple containing the new lookup callable, the
466+
underlying request object, and a boolean indicating if it was cloned.
467+
"""
468+
is_partial = isinstance(request, functools.partial)
469+
base_callable = request.func if is_partial else request
470+
471+
if not hasattr(base_callable, "_clone"):
472+
return request, base_callable, False
473+
474+
cloned_callable = base_callable._clone()
475+
is_cloned = cloned_callable is not base_callable
476+
477+
if is_partial:
478+
new_request = functools.partial(
479+
cloned_callable, *request.args, **request.keywords
480+
)
481+
else:
482+
new_request = cloned_callable
483+
484+
return new_request, cloned_callable, is_cloned
485+
486+
487+
async def _close_cloned_request(lookup_request, is_cloned):
488+
"""Safely closes the underlying cloned request transport, if applicable.
489+
490+
Args:
491+
lookup_request (Any): The request object/transport to close.
492+
is_cloned (bool): Whether the request was actually cloned.
493+
"""
494+
if not is_cloned or not hasattr(lookup_request, "close"):
495+
return
496+
497+
is_async = False
498+
try:
499+
maybe_coro = lookup_request.close()
500+
if is_async := inspect.isawaitable(maybe_coro):
501+
await maybe_coro
502+
except Exception as e:
503+
if _helpers.is_logging_enabled(_LOGGER):
504+
adapter_type = " asynchronous " if is_async else " "
505+
_LOGGER.warning(
506+
"Failed to cleanly close cloned%srequest transport: %s",
507+
adapter_type,
508+
e,
509+
exc_info=True,
510+
)
511+
512+
458513
class _AsyncRegionalAccessBoundaryRefreshManager(object):
459514
"""Manages a task for background refreshing of the Regional Access Boundary in async flows."""
460515

@@ -491,11 +546,28 @@ def start_refresh(self, credentials, request, rab_manager):
491546
# A refresh is already in progress.
492547
return
493548

549+
try:
550+
(
551+
lookup_callable,
552+
lookup_request,
553+
is_cloned,
554+
) = _prepare_async_lookup_callable(request)
555+
except Exception as e:
556+
if _helpers.is_logging_enabled(_LOGGER):
557+
_LOGGER.warning(
558+
"Synchronous cloning of request for Regional Access Boundary lookup failed: %s",
559+
e,
560+
exc_info=True,
561+
)
562+
rab_manager.process_regional_access_boundary_info(None)
563+
return
564+
494565
async def _worker():
495566
try:
496-
# credentials._lookup_regional_access_boundary should be async in the async creds class
497567
regional_access_boundary_info = (
498-
await credentials._lookup_regional_access_boundary(request)
568+
await credentials._lookup_regional_access_boundary(
569+
lookup_callable
570+
)
499571
)
500572
except Exception as e:
501573
if _helpers.is_logging_enabled(_LOGGER):
@@ -505,6 +577,8 @@ async def _worker():
505577
exc_info=True,
506578
)
507579
regional_access_boundary_info = None
580+
finally:
581+
await _close_cloned_request(lookup_request, is_cloned)
508582

509583
rab_manager.process_regional_access_boundary_info(
510584
regional_access_boundary_info
@@ -514,7 +588,15 @@ async def _worker():
514588
try:
515589
self._worker_task = asyncio.create_task(coro)
516590
except Exception:
591+
# Clean up cloned request if task creation fails
517592
coro.close()
593+
try:
594+
asyncio.get_running_loop().create_task(
595+
_close_cloned_request(lookup_request, is_cloned)
596+
)
597+
except RuntimeError:
598+
pass
599+
rab_manager.process_regional_access_boundary_info(None)
518600
raise
519601

520602

packages/google-auth/google/auth/aio/transport/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,3 +142,13 @@ async def close(self) -> None:
142142
Close the underlying session.
143143
"""
144144
raise NotImplementedError("close must be implemented.")
145+
146+
def _clone(self) -> "Request":
147+
"""Creates a copy of this request adapter.
148+
149+
The base implementation returns `self` (an identical shared instance).
150+
Transport adapters that maintain internal connection pools or stateful
151+
sessions must override this method to return an independent, detached
152+
adapter instance.
153+
"""
154+
return self

packages/google-auth/google/auth/aio/transport/aiohttp.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
else:
3737
try:
3838
from aiohttp import ClientTimeout
39-
except (ImportError, AttributeError):
39+
except (ImportError, AttributeError): # pragma: NO COVER
4040
ClientTimeout = None
4141

4242
_LOGGER = logging.getLogger(__name__)
@@ -203,3 +203,83 @@ async def close(self) -> None:
203203
if not self._closed and self._session:
204204
await self._session.close()
205205
self._closed = True
206+
207+
def _clone(self) -> "Request":
208+
"""Creates an independent copy of this request adapter.
209+
210+
Clones the connection settings, trace configurations, and session defaults
211+
(headers, cookies, basic auth, and timeouts).
212+
213+
Only standard `aiohttp.TCPConnector` and `aiohttp.UnixConnector` connectors
214+
are supported. The DNS resolver is not copied to avoid closing shared resolver
215+
resources.
216+
217+
Returns:
218+
google.auth.aio.transport.aiohttp.Request: A new request adapter.
219+
220+
Raises:
221+
google.auth.exceptions.TransportError: If the transport is closed, or if the
222+
session uses an unsupported connector.
223+
"""
224+
if self._closed:
225+
raise exceptions.TransportError("Cannot clone a closed transport.")
226+
227+
if not self._session:
228+
new_session = aiohttp.ClientSession(
229+
auto_decompress=False,
230+
trust_env=True,
231+
)
232+
return Request(session=new_session)
233+
234+
session_kwargs: dict = {
235+
"auto_decompress": False,
236+
"trust_env": getattr(self._session, "_trust_env", True),
237+
}
238+
239+
# Copy underlying connection pool settings (SSL context, IP bindings, limits).
240+
orig_connector = getattr(self._session, "_connector", None)
241+
if orig_connector and not orig_connector.closed:
242+
if isinstance(orig_connector, aiohttp.TCPConnector):
243+
# We explicitly do not copy the resolver. The connector
244+
# owns the resolver, and closing the cloned session would
245+
# close the shared resolver, breaking the original session.
246+
session_kwargs["connector"] = aiohttp.TCPConnector(
247+
ssl=getattr(orig_connector, "_ssl", None), # type: ignore
248+
limit=getattr(orig_connector, "_limit", 100),
249+
limit_per_host=getattr(orig_connector, "_limit_per_host", 0),
250+
force_close=getattr(orig_connector, "_force_close", False),
251+
local_addr=getattr(orig_connector, "_local_addr", None),
252+
)
253+
elif getattr(aiohttp, "UnixConnector", None) and isinstance(
254+
orig_connector, getattr(aiohttp, "UnixConnector")
255+
):
256+
path = getattr(orig_connector, "_path", None)
257+
if path:
258+
session_kwargs["connector"] = aiohttp.UnixConnector(
259+
path=path,
260+
limit=getattr(orig_connector, "_limit", 100),
261+
force_close=getattr(orig_connector, "_force_close", False),
262+
)
263+
else:
264+
raise exceptions.TransportError(
265+
f"Unsupported connector type for cloning: {type(orig_connector)}"
266+
)
267+
268+
# Preserve distributed tracing configurations.
269+
trace_configs = getattr(self._session, "_trace_configs", None)
270+
if trace_configs:
271+
session_kwargs["trace_configs"] = list(trace_configs)
272+
273+
# Copy session-level defaults (headers, cookies, auth, timeout).
274+
for attr_name, kwarg_name in [
275+
("_default_headers", "headers"),
276+
("_cookie_jar", "cookie_jar"),
277+
("_default_auth", "auth"),
278+
("_timeout", "timeout"),
279+
("_json_serialize", "json_serialize"),
280+
]:
281+
val = getattr(self._session, attr_name, None)
282+
if val is not None:
283+
session_kwargs[kwarg_name] = val
284+
285+
return Request(session=aiohttp.ClientSession(**session_kwargs)) # type: ignore

packages/google-auth/google/auth/transport/_aiohttp_requests.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def __init__(self, session=None):
148148
"Client sessions with auto_decompress=True are not supported."
149149
)
150150
self.session = session
151+
self._closed = False
151152

152153
async def __call__(
153154
self,
@@ -183,6 +184,9 @@ async def __call__(
183184
"""
184185

185186
try:
187+
if getattr(self, "_closed", False):
188+
raise exceptions.TransportError("session is closed.")
189+
186190
if self.session is None: # pragma: NO COVER
187191
self.session = aiohttp.ClientSession(
188192
auto_decompress=False
@@ -202,6 +206,92 @@ async def __call__(
202206
new_exc = exceptions.TransportError(caught_exc)
203207
raise new_exc from caught_exc
204208

209+
def _clone(self):
210+
"""Creates an independent copy of this request adapter.
211+
212+
Clones the connection settings, trace configurations, and session defaults
213+
(headers, cookies, basic auth, and timeouts).
214+
215+
Only standard `aiohttp.TCPConnector` and `aiohttp.UnixConnector` connectors
216+
are supported. The DNS resolver is not copied to avoid closing shared resolver
217+
resources.
218+
219+
Returns:
220+
google.auth.transport._aiohttp_requests.Request: A new request adapter.
221+
222+
Raises:
223+
google.auth.exceptions.TransportError: If the transport is closed, or if the
224+
session uses an unsupported connector.
225+
"""
226+
if getattr(self, "_closed", False):
227+
raise exceptions.TransportError("Cannot clone a closed transport.")
228+
229+
if not self.session:
230+
new_session = aiohttp.ClientSession(
231+
auto_decompress=False,
232+
trust_env=True,
233+
)
234+
return Request(session=new_session)
235+
236+
session_kwargs: dict = {
237+
"auto_decompress": False,
238+
"trust_env": getattr(self.session, "_trust_env", True),
239+
}
240+
241+
# Copy underlying connection pool settings (SSL context, IP bindings, limits).
242+
orig_connector = getattr(self.session, "_connector", None)
243+
if orig_connector and not getattr(orig_connector, "closed", True):
244+
if isinstance(orig_connector, aiohttp.TCPConnector):
245+
# We explicitly do not copy the resolver. The connector
246+
# owns the resolver, and closing the cloned session would
247+
# close the shared resolver, breaking the original session.
248+
session_kwargs["connector"] = aiohttp.TCPConnector(
249+
ssl=getattr(orig_connector, "_ssl", None), # type: ignore
250+
limit=getattr(orig_connector, "_limit", 100),
251+
limit_per_host=getattr(orig_connector, "_limit_per_host", 0),
252+
force_close=getattr(orig_connector, "_force_close", False),
253+
local_addr=getattr(orig_connector, "_local_addr", None),
254+
)
255+
elif getattr(aiohttp, "UnixConnector", None) and isinstance(
256+
orig_connector, getattr(aiohttp, "UnixConnector")
257+
):
258+
path = getattr(orig_connector, "_path", None)
259+
if path:
260+
session_kwargs["connector"] = aiohttp.UnixConnector(
261+
path=path,
262+
limit=getattr(orig_connector, "_limit", 100),
263+
force_close=getattr(orig_connector, "_force_close", False),
264+
)
265+
else:
266+
raise exceptions.TransportError(
267+
f"Unsupported connector type for cloning: {type(orig_connector)}"
268+
)
269+
270+
# Preserve distributed tracing configurations.
271+
trace_configs = getattr(self.session, "_trace_configs", None)
272+
if trace_configs:
273+
session_kwargs["trace_configs"] = list(trace_configs)
274+
275+
# Copy session-level defaults (headers, cookies, auth, timeout).
276+
for attr_name, kwarg_name in [
277+
("_default_headers", "headers"),
278+
("_cookie_jar", "cookie_jar"),
279+
("_default_auth", "auth"),
280+
("_timeout", "timeout"),
281+
("_json_serialize", "json_serialize"),
282+
]:
283+
val = getattr(self.session, attr_name, None)
284+
if val is not None:
285+
session_kwargs[kwarg_name] = val
286+
287+
return Request(session=aiohttp.ClientSession(**session_kwargs)) # type: ignore
288+
289+
async def close(self):
290+
"""Cleanly release the underlying aiohttp ClientSession resources."""
291+
if not getattr(self, "_closed", False) and self.session:
292+
await self.session.close()
293+
self._closed = True
294+
205295

206296
class AuthorizedSession(aiohttp.ClientSession):
207297
"""This is an async implementation of the Authorized Session class. We utilize an

0 commit comments

Comments
 (0)