Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 118 additions & 32 deletions src/apify/events/_apify_event_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

import asyncio
import contextlib
import time
from typing import TYPE_CHECKING, Annotated, Self

import websockets.asyncio.client
import websockets.client
import websockets.exceptions
from pydantic import Discriminator, TypeAdapter
from typing_extensions import Unpack, override

Expand All @@ -16,6 +19,7 @@
from apify.log import logger

if TYPE_CHECKING:
from collections.abc import Generator
from types import TracebackType

from crawlee.events._event_manager import EventManagerOptions
Expand Down Expand Up @@ -45,6 +49,17 @@ class ApifyEventManager(EventManager):
with the event system.
"""

# Checked against `websocket.close_code` in `_should_reconnect_after_close`; a match stops the reconnect loop.
_NON_RETRYABLE_CLOSE_CODES = frozenset({1002, 1003, 1007, 1008, 1010})
"""WebSocket close codes for a permanent condition, on which the connection is not re-established.

The platform sends `1008` (policy violation) for an unknown/missing run ID or an exceeded per-run
connection limit. `1002`, `1003`, and `1007` are protocol or data errors, and `1010` a mandatory
extension failure.
"""

# Compared in `_process_platform_messages` against the elapsed `time.monotonic()` time of a connection.
_HEALTHY_CONNECTION_MIN_DURATION = 1.0
"""Seconds a connection must stay open to count as healthy, after which a drop reconnects without backoff."""

def __init__(self, configuration: Configuration, **kwargs: Unpack[EventManagerOptions]) -> None:
"""Initialize a new instance.

Expand Down Expand Up @@ -93,50 +108,121 @@ async def __aexit__(
exc_value: BaseException | None,
exc_traceback: TracebackType | None,
) -> None:
if self._platform_events_websocket:
await self._platform_events_websocket.close()

# Cancel the task before closing the websocket so that the closed connection is not treated as a drop
# and followed by a reconnect attempt.
if self._process_platform_messages_task and not self._process_platform_messages_task.done():
self._process_platform_messages_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._process_platform_messages_task

if self._platform_events_websocket:
await self._platform_events_websocket.close()

await super().__aexit__(exc_type, exc_value, exc_traceback)

def _process_connection_exception(self, exc: Exception) -> Exception | None:
"""Decide whether a failed connection attempt to the platform websocket should be retried.

Before the first successful connection, every error is fatal so that `__aenter__` fails fast. After that,
the default `websockets` behavior decides which errors are transient and retried with exponential backoff.
"""
if self._connected_to_platform_websocket and self._connected_to_platform_websocket.done():
return websockets.asyncio.client.process_exception(exc)
return exc

async def _process_platform_messages(self, ws_url: str) -> None:
# NOTE(review): this span comes from a diff view with indentation and +/- markers stripped. It appears to
# interleave the removed and the added side of one hunk — confirm against the real file before editing.
#
# The `websockets` reconnect iterator only backs off on failed connection *attempts*, not on a connection
# that opens and is then closed. Track our own backoff here so a server that keeps accepting and immediately
# closing is not hammered; it is reset after a healthy connection so a healthy drop reconnects immediately.
backoff_delays: Generator[float] | None = None

try:
# NOTE(review): presumably the removed pre-change line, superseded by the `async for` below — confirm.
async with websockets.asyncio.client.connect(ws_url) as websocket:
# Used as an async iterator, `connect` reconnects with exponential backoff on failed connection attempts.
async for websocket in websockets.asyncio.client.connect(
ws_url, process_exception=self._process_connection_exception
):
self._platform_events_websocket = websocket
# NOTE(review): the next two `if` lines look like the old and the new variant of one condition — confirm.
if self._connected_to_platform_websocket is not None:
if self._connected_to_platform_websocket and not self._connected_to_platform_websocket.done():
self._connected_to_platform_websocket.set_result(True)

# NOTE(review): this inline loop mirrors `_handle_platform_message` (with `continue` where the helper
# uses `return`); presumably the removed version that `_consume_messages` now replaces — confirm.
async for message in websocket:
try:
parsed_message = event_data_adapter.validate_json(message)

if isinstance(parsed_message, DeprecatedEvent):
continue

if isinstance(parsed_message, UnknownEvent):
logger.info(
f'Unknown message received: event_name={parsed_message.name}, '
f'event_data={parsed_message.data}'
)
continue

self.emit(
event=parsed_message.name,
event_data=parsed_message.data
if not isinstance(parsed_message.data, SystemInfoEventData)
else parsed_message.data.to_crawlee_format(self._configuration.dedicated_cpus or 1),
)

if parsed_message.name == Event.MIGRATING:
await self._emit_persist_state_event_rec_task.stop()
self.emit(event=Event.PERSIST_STATE, event_data=EventPersistStateData(is_migrating=True))
except Exception:
logger.exception('Cannot parse Actor event', extra={'raw_message': message})
else:
logger.info('Reconnected to the platform events websocket.')

# Timestamp used below to measure how long this connection stayed open.
connection_opened_at = time.monotonic()
connection_lost = await self._consume_messages(websocket)

# A `False` verdict ends the reconnect loop.
if not self._should_reconnect_after_close(websocket, connection_lost=connection_lost):
break

# Reconnect a healthy connection immediately; back off only on repeated rapid drops.
# The first rapid drop only creates the delay generator; later rapid drops sleep for its next value.
if time.monotonic() - connection_opened_at >= self._HEALTHY_CONNECTION_MIN_DURATION:
backoff_delays = None
elif backoff_delays is None:
backoff_delays = websockets.client.backoff()
else:
await asyncio.sleep(next(backoff_delays))
except Exception:
# Log the failure and, if the connection future is still pending, resolve it with `False`.
logger.exception('Error in websocket connection')
if self._connected_to_platform_websocket is not None and not self._connected_to_platform_websocket.done():
self._connected_to_platform_websocket.set_result(False)

async def _consume_messages(self, websocket: websockets.asyncio.client.ClientConnection) -> bool:
"""Handle platform messages until the connection closes; return whether it was lost vs. closed cleanly."""
try:
async for message in websocket:
await self._handle_platform_message(message)
except websockets.exceptions.ConnectionClosed:
return True
return False

async def _handle_platform_message(self, message: str | bytes) -> None:
    """Validate one raw platform message and re-emit it as a local event.

    Deprecated events are dropped silently and unknown events are only logged. A `MIGRATING` event also
    stops the recurring persist-state task and emits one `PERSIST_STATE` event flagged as migrating.
    Any exception raised along the way is logged together with the raw message and not propagated.

    Args:
        message: The raw websocket payload, validated as JSON by `event_data_adapter`.
    """
    try:
        parsed_message = event_data_adapter.validate_json(message)

        if isinstance(parsed_message, DeprecatedEvent):
            return

        if isinstance(parsed_message, UnknownEvent):
            logger.info(
                f'Unknown message received: event_name={parsed_message.name}, event_data={parsed_message.data}'
            )
            return

        # System info payloads are converted to the crawlee format before being emitted.
        event_data = parsed_message.data
        if isinstance(event_data, SystemInfoEventData):
            event_data = event_data.to_crawlee_format(self._configuration.dedicated_cpus or 1)
        self.emit(event=parsed_message.name, event_data=event_data)

        if parsed_message.name == Event.MIGRATING:
            await self._emit_persist_state_event_rec_task.stop()
            migrating_state = EventPersistStateData(is_migrating=True)
            self.emit(event=Event.PERSIST_STATE, event_data=migrating_state)
    except Exception:
        logger.exception('Cannot parse Actor event', extra={'raw_message': message})

def _should_reconnect_after_close(
    self,
    websocket: websockets.asyncio.client.ClientConnection,
    *,
    connection_lost: bool,
) -> bool:
    """Log how the platform events websocket was closed and decide whether to reconnect.

    Args:
        websocket: The closed connection; its `close_code` and `close_reason` go into the log message.
        connection_lost: Whether the connection was lost rather than closed cleanly. A lost connection is
            logged as a warning, a clean close as info.

    Returns:
        `False` when the close code is one of `_NON_RETRYABLE_CLOSE_CODES`, `True` otherwise.
    """
    if websocket.close_code in self._NON_RETRYABLE_CLOSE_CODES:
        logger.error(
            f'Connection to platform events websocket was closed with a non-retryable code '
            f'(code={websocket.close_code}, reason={websocket.close_reason!r}); not reconnecting.'
        )
        return False

    if not connection_lost:
        logger.info(
            f'Connection to platform events websocket was closed '
            f'(code={websocket.close_code}, reason={websocket.close_reason!r}), reconnecting...'
        )
        return True

    logger.warning(
        f'Connection to platform events websocket was lost '
        f'(code={websocket.close_code}, reason={websocket.close_reason!r}), reconnecting...'
    )
    return True
Loading
Loading