import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import TYPE_CHECKING, Generic, TypeVar
from unittest.mock import MagicMock, _Call

if TYPE_CHECKING:
    # Only referenced in (quoted) type annotations, so it is imported for type
    # checkers only; this avoids a runtime dependency on the connection package.
    from connection.tcp_connection_manager import TcpConnectionManager

logger = logging.getLogger(__name__)

T = TypeVar('T')

DEFAULT_RESPONSE_TIMEOUT = 0.5  # in seconds


class CallExpectations(ABC):
    """Expectation that a mocked method receives (at least) one matching call.

    Subclasses decide which recorded calls are candidates (``is_expected_call``)
    and how a candidate is checked (``validate_call``). ``verify`` consumes the
    first matching call so it cannot fulfil a second expectation.
    """

    def __init__(self, mocked_method: MagicMock) -> None:
        self.mocked_method = mocked_method
        self.fulfilled = False
        self.fulfilled_call: _Call | None = None

    @abstractmethod
    def is_expected_call(self, mock_call: _Call) -> bool:
        """
        Check if the given mock call is the expected event.
        Do NOT raise AssertionError in this method. Return False if the call is not expected.
        """

    @abstractmethod
    def validate_call(self, mock_call: _Call) -> None:
        """
        Validate the given mock call. Raise AssertionError if the call is not expected.
        """

    def verify(self) -> bool:
        """Find, validate and consume the first recorded call matching this expectation.

        On a match the call is stored in ``fulfilled_call`` and removed from the
        mock's bookkeeping (``call_count``, ``call_args_list`` and the matching
        ``mock_calls`` entry), so it is neither reused by another expectation nor
        reported as unexpected later.

        Returns:
            True if a matching call was found and validated, False otherwise.

        Raises:
            AssertionError: if the matching call fails ``validate_call``.
        """
        mocked_method = self.mocked_method
        for index, mock_call in enumerate(mocked_method.call_args_list):
            if not self.is_expected_call(mock_call):
                logger.debug(f"Skipping unwanted call: {mock_call}")
                continue
            logger.debug(f"Found expected call: {mock_call}")
            logger.debug("Validating call...")
            self.validate_call(mock_call)
            logger.debug("Call validated successfully")
            self.fulfilled = True
            self.fulfilled_call = mock_call
            mocked_method.call_count -= 1
            # Safe to mutate the list we iterate: we return immediately afterwards.
            mocked_method.call_args_list.pop(index)
            self._discard_direct_mock_call(index)
            return True
        logger.debug(f"Expected call of {mocked_method._mock_name} not found")
        return False

    def _discard_direct_mock_call(self, index: int) -> None:
        """Remove the ``index``-th direct call on the mock from ``mock_calls``.

        ``mock_calls`` also records calls made on the mock's children and return
        value (e.g. ``call().method()``), so its positions do not line up with
        ``call_args_list``. Direct calls on the mock itself are recorded with an
        empty name, in the same order as ``call_args_list``; only those count.
        """
        mock_calls = self.mocked_method.mock_calls
        direct_calls_seen = 0
        for position, entry in enumerate(mock_calls):
            if len(entry) != 3 or entry[0] != "":
                continue  # call on a child / return value, not on the mock itself
            if direct_calls_seen == index:
                mock_calls.pop(position)
                return
            direct_calls_seen += 1

    def __repr__(self) -> str:
        mock_name = getattr(self.mocked_method, '_mock_name', self.mocked_method)
        attrs = [f"mocked_method={mock_name}"]
        # Bookkeeping attributes are omitted; subclass configuration is shown sorted by name.
        attrs.extend(
            f"{key}={value!r}"
            for key, value in sorted(self.__dict__.items())
            if key not in {'mocked_method', 'fulfilled', 'fulfilled_call'}
        )
        return f"<{type(self).__name__}({', '.join(attrs)})>"


class ResponseExpectation(Generic[T], CallExpectations):
    """Expect a single-argument call whose argument is a ``response_type`` for ``request_id``.

    The argument is matched on ``request_id`` and validated on ``error_message``:
    an empty string means success, anything else means failure.
    """

    def __init__(
        self,
        mocked_method: MagicMock,
        response_type: type[T],
        request_id: int,
        expect_success: bool = True,
    ) -> None:
        super().__init__(mocked_method)
        self.response_type = response_type
        self.request_id = request_id
        self.expect_success = expect_success

    def is_expected_call(self, mock_call: _Call) -> bool:
        if len(mock_call.args) != 1:
            return False
        call_arg = mock_call.args[0]
        if not isinstance(call_arg, self.response_type):
            return False
        if not hasattr(call_arg, "request_id"):
            return False
        if call_arg.request_id != self.request_id:
            return False
        return True

    def validate_call(self, mock_call: _Call) -> None:
        call_arg = mock_call.args[0]
        assert hasattr(call_arg, "error_message"), f"Response has no error_message: {call_arg!r}"
        if self.expect_success:
            assert call_arg.error_message == "", f"Unexpected error: {call_arg.error_message}"
        else:
            assert call_arg.error_message != "", "Unexpected success response"

    def get_response(self) -> T:
        """Return the argument of the call that fulfilled this expectation."""
        assert self.fulfilled_call is not None, "Expectation has not been fulfilled"
        return self.fulfilled_call.args[0]


class CallExpectationsManager:
    """Registry of call expectations per mock, with optional network pumping while verifying."""

    def __init__(self) -> None:
        # Expectations still waiting for a matching call / already matched, keyed by mock.
        self.pending_expectations_by_mock: dict[MagicMock, list[CallExpectations]] = defaultdict(list)
        self.fulfilled_expectations_by_mock: dict[MagicMock, list[CallExpectations]] = defaultdict(list)
        self.tcp_connection_manager: "TcpConnectionManager | None" = None
        self.response_timeout: float = DEFAULT_RESPONSE_TIMEOUT

    def create_mock(self, mock_name: str) -> MagicMock:
        """Create a ``MagicMock`` named ``mock_name`` and register it with this manager."""
        mocked_method = MagicMock(name=mock_name)
        self.register_mock(mocked_method)
        return mocked_method

    def register_mock(self, mocked_method: MagicMock) -> None:
        """Track ``mocked_method`` even without expectations, so ``verify_no_unexpected_calls`` checks it."""
        self.pending_expectations_by_mock.setdefault(mocked_method, [])

    def add_expectation(self, call_expectation: CallExpectations) -> None:
        self.pending_expectations_by_mock[call_expectation.mocked_method].append(call_expectation)

    def get_number_of_pending_calls(self) -> int:
        """Return how many registered expectations are still unfulfilled."""
        return sum(len(pending) for pending in self.pending_expectations_by_mock.values())

    def setup_network(
        self,
        tcp_connection_manager: "TcpConnectionManager",
        response_timeout: float = DEFAULT_RESPONSE_TIMEOUT,
    ) -> None:
        """
        Set up the network to be able to await network events when verifying expectations.
        Ideally this would be set in the constructor, but it's not possible due to circular dependencies.
        For this reason, we allow the network to not be set up and expectations still to be checked.
        """
        self.tcp_connection_manager = tcp_connection_manager
        self.response_timeout = response_timeout

    def verify_expectations(self, assert_no_pending_calls: bool = True) -> None:
        """Match recorded calls against pending expectations, pumping the network if set up.

        Raises:
            AssertionError: if ``assert_no_pending_calls`` and some expectation is
                still unfulfilled, or if a matched call fails validation.
        """
        # Check pending expectations before awaiting network events because it may fulfill some expectations and avoid
        # unnecessary blocking network calls. Then check it again after awaiting network events to ensure all expectations
        # are fulfilled.
        self._check_pending_expectations()
        if self.tcp_connection_manager is not None:
            self._await_network_events()
            self._check_pending_expectations()
        if assert_no_pending_calls and self.get_number_of_pending_calls() > 0:
            report = self._make_issues_report()
            logger.error(report)
            raise AssertionError(report)

    def verify_no_unexpected_calls(self, assert_no_unexpected_calls: bool = True) -> None:
        """Fail if any registered mock still holds calls not consumed by an expectation.

        Raises:
            AssertionError: if ``assert_no_unexpected_calls`` and leftover calls exist.
        """
        issues = [
            f"Detected unexpected call to {mocked_method._mock_name}: {mock_call!r}"
            for mocked_method in self.pending_expectations_by_mock
            for mock_call in mocked_method.call_args_list
        ]
        if assert_no_unexpected_calls and issues:
            raise AssertionError("\n".join(issues))

    def _make_issues_report(self) -> str:
        return "Expectations not fulfilled:\n" + "\n".join(
            f"{mocked_method._mock_name}: {call_expectation}"
            for mocked_method, call_expectations in self.pending_expectations_by_mock.items()
            for call_expectation in call_expectations
            if not call_expectation.fulfilled
        )

    def _check_pending_expectations(self) -> None:
        """Move every expectation whose call has been recorded from pending to fulfilled."""
        for mocked_method, call_expectations in self.pending_expectations_by_mock.items():
            if not call_expectations:
                continue
            logger.info(f"Verifying {len(call_expectations)} expectations for {mocked_method._mock_name}")
            # Iterate over a copy because fulfilled expectations are removed in place.
            for call_expectation in call_expectations.copy():
                if call_expectation.verify():
                    call_expectations.remove(call_expectation)
                    self.fulfilled_expectations_by_mock[mocked_method].append(call_expectation)
                else:
                    logger.debug(f"Expectation still pending for {mocked_method._mock_name}: {call_expectation}")

    def _await_network_events(self) -> None:
        """Process network events until a wait returns no events.

        While expectations are pending, each wait may block for up to
        ``response_timeout`` seconds. Expectations are re-checked after every batch
        of events. Once none remain pending, later waits use a zero timeout, which
        presumably drains already-queued events without blocking.
        """
        assert self.tcp_connection_manager is not None, "Network not set up"
        pending_calls_count = self.get_number_of_pending_calls()
        next_timeout = self.response_timeout if pending_calls_count > 0 else 0
        logger.info(f"Awaiting (at least) {pending_calls_count} network events for up to {next_timeout} seconds...")
        total_received_events = 0
        while True:
            received_events = self.tcp_connection_manager.wait_for_events(timeout_in_seconds=next_timeout)
            logger.debug(f"Received {received_events} network events")
            if received_events == 0:
                logger.debug("Stopping network event waiting loop")
                break
            total_received_events += received_events
            # The pending count only drops when expectations are re-checked; without
            # this the loop would always block for one extra full timeout.
            self._check_pending_expectations()
            pending_calls_count = self.get_number_of_pending_calls()
            if pending_calls_count == 0:
                logger.debug("Received all pending network events. Next network read attempts will be non-blocking.")
                next_timeout = 0
            else:
                logger.debug(f"Still waiting for {pending_calls_count} pending network events")
        logger.debug(f"Received {total_received_events} network events in total")