gtat-tech-career-kickstarte.../solution/tests/common/mock_expectations.py

198 lines
8.9 KiB
Python

from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Generic, TypeVar
from unittest.mock import MagicMock, _Call
import logging
from connection.tcp_connection_manager import TcpConnectionManager
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Type variable for the response type that ResponseExpectation is parameterized with.
T = TypeVar('T')
# Default timeout handed to the connection manager when waiting for network events.
DEFAULT_RESPONSE_TIMEOUT = 0.5 # in seconds
class CallExpectations(ABC):
    """
    Base class for one expected call on a mocked method.

    Subclasses decide which recorded call they care about (`is_expected_call`) and what makes that
    call acceptable (`validate_call`). `verify` combines both and, on success, removes the matched
    call from the mock so that it cannot be matched a second time.
    """

    def __init__(self, mocked_method: MagicMock) -> None:
        """
        :param mocked_method: the mock whose recorded calls are inspected by `verify`.
        """
        self.mocked_method = mocked_method
        # Set to True by `verify` once a matching call passed validation.
        self.fulfilled = False
        # The call that fulfilled this expectation, or None while still pending.
        self.fulfilled_call: _Call | None = None

    @abstractmethod
    def is_expected_call(self, mock_call: _Call) -> bool:
        """
        Check if the given mock call is the expected event.
        Do NOT raise AssertionError in this method. Return False if the call is not expected.
        """

    @abstractmethod
    def validate_call(self, mock_call: _Call) -> None:
        """
        Validate the given mock call.
        Raise AssertionError if the call is not expected.
        """

    def verify(self) -> bool:
        """
        Look for the first recorded call accepted by `is_expected_call` and validate it.

        On success the expectation is marked as fulfilled, the call is stored in `fulfilled_call`
        and its record is removed from the mock.

        :return: True if a matching call was found and validated, False if no recorded call matched.
        :raises AssertionError: propagated from `validate_call` when the matching call is invalid;
            in that case nothing is removed from the mock.
        """
        for index, mock_call in enumerate(self.mocked_method.call_args_list):
            if not self.is_expected_call(mock_call):
                logger.debug(f"Skipping unwanted call: {mock_call}")
                continue
            logger.debug(f"Found expected call: {mock_call}")
            logger.debug("Validating call...")
            self.validate_call(mock_call)
            logger.debug("Call validated successfully")
            self.fulfilled = True
            self.fulfilled_call = mock_call
            # Safe although we are iterating call_args_list: we return right after mutating it.
            self._discard_recorded_call(index)
            return True
        logger.debug(f"Expected call of {self.mocked_method._mock_name} not found")
        return False

    def _discard_recorded_call(self, index: int) -> None:
        """Remove the direct call at position `index` of `call_args_list` from the mock's bookkeeping."""
        mock = self.mocked_method
        mock.call_count -= 1
        mock.call_args_list.pop(index)
        # `mock_calls` also records calls made on child attributes and return values, so its
        # positions do not necessarily line up with `call_args_list`. Direct calls are the entries
        # recorded with an empty name; drop the index-th one of those.
        direct_calls_seen = 0
        for position, recorded in enumerate(mock.mock_calls):
            recorded_name = recorded[0] if len(recorded) == 3 else ""
            if recorded_name != "":
                continue
            if direct_calls_seen == index:
                mock.mock_calls.pop(position)
                break
            direct_calls_seen += 1

    def __repr__(self) -> str:
        """Describe the expectation by class name, mock name and every subclass-specific attribute."""
        hidden = {'mocked_method', 'fulfilled', 'fulfilled_call'}
        mock_label = getattr(self.mocked_method, '_mock_name', self.mocked_method)
        text = f"<{self.__class__.__name__}(mocked_method={mock_label})"
        # Sorted so that the representation does not depend on attribute assignment order.
        extra_attrs = [f"{key}={value!r}" for key, value in sorted(self.__dict__.items()) if key not in hidden]
        if extra_attrs:
            text += ", " + ", ".join(extra_attrs)
        return text + ">"
class ResponseExpectation(Generic[T], CallExpectations):
    """
    Expectation that the mocked method receives exactly one positional argument: a response of
    `response_type` whose `request_id` equals the expected one.

    Validation then checks the response's `error_message`: empty when success is expected,
    non-empty otherwise.
    """

    def __init__(self, mocked_method: MagicMock, response_type: type[T], request_id: int, expect_success: bool = True) -> None:
        """
        :param mocked_method: the mock that is expected to receive the response.
        :param response_type: type the single call argument must be an instance of.
        :param request_id: value the argument's `request_id` attribute must have.
        :param expect_success: whether the response must carry an empty `error_message`.
        """
        super().__init__(mocked_method)
        self.response_type = response_type
        self.request_id = request_id
        self.expect_success = expect_success

    def is_expected_call(self, mock_call: _Call) -> bool:
        """Return True only for a one-argument call carrying a matching response; never raises for a mismatch."""
        positional = mock_call.args
        if len(positional) != 1:
            return False
        candidate = positional[0]
        has_right_shape = isinstance(candidate, self.response_type) and hasattr(candidate, "request_id")
        if not has_right_shape:
            return False
        id_differs = candidate.request_id != self.request_id
        return not id_differs

    def validate_call(self, mock_call: _Call) -> None:
        """Assert that the response's `error_message` agrees with `expect_success`."""
        response = mock_call.args[0]
        assert hasattr(response, "error_message")
        if not self.expect_success:
            assert response.error_message != "", "Unexpected success response"
            return
        assert response.error_message == "", f"Unexpected error: {response.error_message}"

    def get_response(self) -> T:
        """Return the response object of the fulfilling call; asserts that the expectation was fulfilled."""
        matched_call = self.fulfilled_call
        assert matched_call is not None
        return matched_call.args[0]
class CallExpectationsManager:
    """
    Keeps track of call expectations per mocked method and verifies them.

    When a TCP connection manager has been provided through `setup_network`, verification first
    waits for network events so that calls triggered by incoming data can fulfill expectations.
    """

    def __init__(self) -> None:
        # Expectations that have not been matched to a recorded call yet, grouped by mock.
        self.pending_expectations_by_mock: dict[MagicMock, list[CallExpectations]] = defaultdict(list)
        # Expectations that were matched and validated, grouped by mock.
        self.fulfilled_expectations_by_mock: dict[MagicMock, list[CallExpectations]] = defaultdict(list)
        # Stays None until `setup_network` is called; verification then skips waiting for events.
        self.tcp_connection_manager: TcpConnectionManager | None = None
        self.response_timeout: float = DEFAULT_RESPONSE_TIMEOUT

    def create_mock(self, mock_name: str) -> MagicMock:
        """Create a named MagicMock, register it with this manager and return it."""
        mocked_method = MagicMock(name=mock_name)
        self.register_mock(mocked_method)
        return mocked_method

    def register_mock(self, mocked_method: MagicMock) -> None:
        """Make the mock known to the manager so that `verify_no_unexpected_calls` inspects it."""
        self.pending_expectations_by_mock.setdefault(mocked_method, [])

    def add_expectation(self, call_expectation: CallExpectations) -> None:
        """Queue an expectation under the mock it watches (the mock gets registered implicitly)."""
        self.pending_expectations_by_mock[call_expectation.mocked_method].append(call_expectation)

    def get_number_of_pending_calls(self) -> int:
        """Return the total number of expectations that are still pending across all mocks."""
        return sum(len(pending_expectations) for pending_expectations in self.pending_expectations_by_mock.values())

    def setup_network(self, tcp_connection_manager: TcpConnectionManager, response_timeout: float = DEFAULT_RESPONSE_TIMEOUT) -> None:
        """
        Set up the network to be able to await network events when verifying expectations.
        Ideally this would be set in the constructor, but it's not possible due to circular dependencies.
        For this reason, we allow the network to not be set up and expectations still to be checked.
        """
        self.tcp_connection_manager = tcp_connection_manager
        self.response_timeout = response_timeout

    def verify_expectations(self, assert_no_pending_calls: bool = True) -> None:
        """
        Try to fulfill all pending expectations, waiting for network events if the network is set up.

        :param assert_no_pending_calls: when True, fail if any expectation is still pending afterwards.
        :raises AssertionError: if an expectation's validation fails, or if expectations remain
            pending and `assert_no_pending_calls` is True.
        """
        # Check pending expectations before awaiting network events because it may fulfill some expectations and avoid
        # unnecessary blocking network calls. Then check it again after awaiting network events to ensure all expectations
        # are fulfilled.
        self._check_pending_expectations()
        if self.tcp_connection_manager is not None:
            self._await_network_events()
            self._check_pending_expectations()
        if assert_no_pending_calls and self.get_number_of_pending_calls() > 0:
            report = self._make_issues_report()
            logger.error(report)
            raise AssertionError(report)

    def verify_no_unexpected_calls(self, assert_no_unexpected_calls: bool = True) -> None:
        """
        Report every call that is still recorded on a registered mock.

        Fulfilled expectations remove their matched call from the mock, so whatever is left was not
        claimed by any expectation.

        :param assert_no_unexpected_calls: when True, fail if such calls exist; otherwise only log them.
        :raises AssertionError: if unexpected calls exist and `assert_no_unexpected_calls` is True.
        """
        issues = [
            f"Detected unexpected call to {mocked_method._mock_name}: {repr(mock_call)}"
            for mocked_method in self.pending_expectations_by_mock
            for mock_call in mocked_method.call_args_list
        ]
        if not issues:
            return
        report = "\n".join(issues)
        if assert_no_unexpected_calls:
            logger.error(report)
            raise AssertionError(report)
        # The caller opted out of failing; still surface the findings instead of dropping them.
        logger.warning(report)

    def _make_issues_report(self) -> str:
        """Build a multi-line report listing every expectation that has not been fulfilled."""
        unfulfilled = [
            f"{mocked_method._mock_name}: {call_expectation}"
            for mocked_method, call_expectations in self.pending_expectations_by_mock.items()
            for call_expectation in call_expectations
            if not call_expectation.fulfilled
        ]
        return "Expectations not fulfilled:\n" + "\n".join(unfulfilled)

    def _check_pending_expectations(self) -> None:
        """Verify each pending expectation once and move the fulfilled ones to the fulfilled map."""
        for mocked_method, call_expectations in self.pending_expectations_by_mock.items():
            if not call_expectations:
                continue
            logger.info(f"Verifying {len(call_expectations)} expectations for {mocked_method._mock_name}")
            # Iterate over a snapshot because fulfilled expectations are removed from the live list.
            for call_expectation in list(call_expectations):
                if call_expectation.verify():
                    call_expectations.remove(call_expectation)
                    self.fulfilled_expectations_by_mock[mocked_method].append(call_expectation)
                else:
                    logger.debug(f"Expectation still pending for {mocked_method._mock_name}: {call_expectation}")

    def _await_network_events(self) -> None:
        """
        Pump network events until a wait reports none.

        Waits block for up to `response_timeout` while expectations are pending and become
        non-blocking as soon as none are left.

        :raises AssertionError: if the network was not set up, or if validating a newly matched
            call fails while re-checking expectations between waits.
        """
        assert self.tcp_connection_manager is not None, "Network not set up"
        pending_calls_count = self.get_number_of_pending_calls()
        next_timeout = self.response_timeout if pending_calls_count > 0 else 0
        logger.info(f"Awaiting (at least) {pending_calls_count} network events for up to {next_timeout} seconds...")
        total_received_events = 0
        while True:
            received_events = self.tcp_connection_manager.wait_for_events(timeout_in_seconds=next_timeout)
            logger.debug(f"Received {received_events} network events")
            if received_events == 0:
                logger.debug("Stopping network event waiting loop")
                break
            total_received_events += received_events
            # Re-check expectations against the calls recorded so far; the pending count only changes
            # when expectations are verified, so without this it could never drop inside the loop and
            # every wait would keep blocking for the full timeout.
            # NOTE(review): assumes wait_for_events has invoked the mocked callbacks by the time it
            # returns - confirm against TcpConnectionManager.
            self._check_pending_expectations()
            pending_calls_count = self.get_number_of_pending_calls()
            if pending_calls_count == 0:
                logger.debug("Received all pending network events. Next network read attempts will be non-blocking.")
                next_timeout = 0
            else:
                logger.debug(f"Still waiting for {pending_calls_count} pending network events")
        logger.debug(f"Received {total_received_events} network events in total")