198 lines
8.9 KiB
Python
198 lines
8.9 KiB
Python
from abc import ABC, abstractmethod
|
|
from collections import defaultdict
|
|
from typing import Generic, TypeVar
|
|
from unittest.mock import MagicMock, _Call
|
|
import logging
|
|
|
|
from connection.tcp_connection_manager import TcpConnectionManager
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Type parameter for the response object that ResponseExpectation hands back from get_response().
T = TypeVar("T")

# How long, in seconds, to block on each network read while expectations are still pending.
DEFAULT_RESPONSE_TIMEOUT = 0.5
|
|
|
|
|
|
class CallExpectations(ABC):
    """
    Base class for an expectation that a mocked method receives a matching call.

    Subclasses decide which recorded calls are candidates (``is_expected_call``) and how a
    candidate is validated (``validate_call``). When ``verify`` finds a matching call, the call
    is consumed: it is removed from the mock's ``call_count``, ``call_args_list`` and
    ``mock_calls`` so it cannot satisfy another expectation or be reported as unexpected later.
    """

    def __init__(self, mocked_method: MagicMock) -> None:
        self.mocked_method = mocked_method
        # Becomes True once verify() has found and validated a matching call.
        self.fulfilled = False
        # The consumed call, as taken from mocked_method.call_args_list.
        self.fulfilled_call: _Call | None = None

    @abstractmethod
    def is_expected_call(self, mock_call: _Call) -> bool:
        """
        Check if the given mock call is the expected event.

        Do NOT raise AssertionError in this method. Return False if the call is not expected.
        """

    @abstractmethod
    def validate_call(self, mock_call: _Call) -> None:
        """
        Validate the given mock call.

        Raise AssertionError if the call is not expected.
        """

    def verify(self) -> bool:
        """
        Find the first recorded call matching this expectation, validate it and consume it.

        Returns:
            True if a matching call was found, validated and removed from the mock's records;
            False if no recorded call matches.

        Raises:
            AssertionError: propagated from ``validate_call`` when the matching call is invalid.
        """
        for index, mock_call in enumerate(self.mocked_method.call_args_list):
            if not self.is_expected_call(mock_call):
                logger.debug(f"Skipping unwanted call: {mock_call}")
                continue

            logger.debug(f"Found expected call: {mock_call}")
            logger.debug("Validating call...")
            self.validate_call(mock_call)
            logger.debug("Call validated successfully")
            self.fulfilled = True
            self.fulfilled_call = mock_call

            # Consume the call so it is neither matched again nor reported as unexpected.
            # We return right away, so mutating the list being enumerated is safe.
            self.mocked_method.call_count -= 1
            self.mocked_method.call_args_list.pop(index)
            self._remove_direct_call_from_mock_calls(index)
            return True

        logger.debug(f"Expected call of {self.mocked_method._mock_name} not found")
        return False

    def _remove_direct_call_from_mock_calls(self, call_index: int) -> None:
        """
        Remove the ``call_index``-th direct call of the mock from its ``mock_calls`` record.

        ``mock_calls`` also records calls made to child attributes, magic methods and return
        values, so an index into ``call_args_list`` cannot be reused for it. Direct calls are
        the entries whose name (first tuple item) is the empty string. NOTE(review): magic
        method calls are possibly also produced by hashing the mock when it is used as a
        dictionary key, which is another reason the indices diverge.
        """
        recorded_calls = self.mocked_method.mock_calls
        direct_calls_seen = 0
        for position, recorded_call in enumerate(recorded_calls):
            name = recorded_call[0] if len(recorded_call) == 3 else ""
            if name != "":
                continue
            if direct_calls_seen == call_index:
                recorded_calls.pop(position)
                return
            direct_calls_seen += 1

    def __repr__(self) -> str:
        """Render the class name, the mock name and any subclass-specific attributes."""
        mock_name = getattr(self.mocked_method, '_mock_name', self.mocked_method)
        fields = [f"mocked_method={mock_name}"]
        fields.extend(
            f"{key}={value!r}"
            for key, value in sorted(self.__dict__.items())
            if key not in {'mocked_method', 'fulfilled', 'fulfilled_call'}
        )
        # All fields go inside the parentheses: <Cls(mocked_method=x, a=1)>
        return f"<{self.__class__.__name__}({', '.join(fields)})>"
|
|
|
|
|
|
class ResponseExpectation(Generic[T], CallExpectations):
    """
    Expectation that a mocked callback receives one response object for a given request.

    A call matches when it has exactly one positional argument, that argument is an instance
    of ``response_type``, and its ``request_id`` equals the expected one. Validation then
    inspects ``error_message``: it must be empty when success is expected, non-empty otherwise.
    """

    def __init__(self, mocked_method: MagicMock, response_type: type[T], request_id: int, expect_success: bool = True) -> None:
        super().__init__(mocked_method)
        # Class the single positional argument of the call must be an instance of.
        self.response_type = response_type
        # Identifier the response's ``request_id`` attribute must equal.
        self.request_id = request_id
        # Whether the response is expected to carry an empty ``error_message``.
        self.expect_success = expect_success

    def is_expected_call(self, mock_call: _Call) -> bool:
        """Return True when the call carries a single response for ``self.request_id``."""
        arguments = mock_call.args
        if len(arguments) != 1:
            return False
        (response,) = arguments
        if not isinstance(response, self.response_type) or not hasattr(response, "request_id"):
            return False
        if response.request_id != self.request_id:
            return False
        return True

    def validate_call(self, mock_call: _Call) -> None:
        """Assert that the response's ``error_message`` agrees with ``expect_success``."""
        response = mock_call.args[0]
        assert hasattr(response, "error_message")
        if not self.expect_success:
            assert response.error_message != "", "Unexpected success response"
        else:
            assert response.error_message == "", f"Unexpected error: {response.error_message}"

    def get_response(self) -> T:
        """Return the response object from the fulfilled call; the expectation must be fulfilled."""
        fulfilled = self.fulfilled_call
        assert fulfilled is not None
        return fulfilled.args[0]
|
|
|
|
|
|
class CallExpectationsManager:
    """
    Track expectations on mocked methods and verify them, optionally pumping network events.

    Expectations are grouped by the mock they target. ``verify_expectations`` moves every
    expectation whose ``verify()`` succeeds from the pending map to the fulfilled map. When a
    ``TcpConnectionManager`` has been attached with ``setup_network``, it also waits for network
    events so asynchronous responses can reach the mocks before the final check.
    ``verify_no_unexpected_calls`` reports calls still recorded on any registered mock; calls
    consumed by fulfilled expectations have already been removed from the mock by ``verify()``.
    """

    def __init__(self) -> None:
        # Expectations not yet fulfilled, keyed by the mock they target. Registered mocks
        # appear here (possibly with an empty list) so unexpected calls on them can be detected.
        self.pending_expectations_by_mock: dict[MagicMock, list[CallExpectations]] = defaultdict(list)
        # Expectations that have been fulfilled, keyed by the mock they target.
        self.fulfilled_expectations_by_mock: dict[MagicMock, list[CallExpectations]] = defaultdict(list)
        self.tcp_connection_manager: TcpConnectionManager | None = None
        self.response_timeout: float = DEFAULT_RESPONSE_TIMEOUT

    def create_mock(self, mock_name: str) -> MagicMock:
        """Create a ``MagicMock`` named ``mock_name``, register it and return it."""
        mocked_method = MagicMock(name=mock_name)
        self.register_mock(mocked_method)
        return mocked_method

    def register_mock(self, mocked_method: MagicMock) -> None:
        """Register a mock so that ``verify_no_unexpected_calls`` inspects it even without expectations."""
        self.pending_expectations_by_mock.setdefault(mocked_method, [])

    def add_expectation(self, call_expectation: CallExpectations) -> None:
        """Queue ``call_expectation`` as pending for the mock it targets."""
        self.pending_expectations_by_mock[call_expectation.mocked_method].append(call_expectation)

    def get_number_of_pending_calls(self) -> int:
        """Return how many expectations are still pending across all mocks."""
        return sum(len(pending_expectations) for pending_expectations in self.pending_expectations_by_mock.values())

    def setup_network(self, tcp_connection_manager: TcpConnectionManager, response_timeout: float = DEFAULT_RESPONSE_TIMEOUT) -> None:
        """
        Set up the network to be able to await network events when verifying expectations.

        Ideally this would be set in the constructor, but it's not possible due to circular dependencies.
        For this reason, we allow the network to not be set up and expectations still to be checked.
        """
        self.tcp_connection_manager = tcp_connection_manager
        self.response_timeout = response_timeout

    def verify_expectations(self, assert_no_pending_calls: bool = True) -> None:
        """
        Check pending expectations, awaiting network events first when the network is set up.

        Args:
            assert_no_pending_calls: when True and some expectation is still pending, log a
                report of the pending expectations and fail with AssertionError.

        Raises:
            AssertionError: on pending expectations (see above), or propagated from an
                expectation's ``validate_call``.
        """
        # Check pending expectations before awaiting network events because it may fulfill some expectations and avoid
        # unnecessary blocking network calls. Then check it again after awaiting network events to ensure all expectations
        # are fulfilled.
        self._check_pending_expectations()
        if self.tcp_connection_manager is not None:
            self._await_network_events()
            self._check_pending_expectations()

        if assert_no_pending_calls:
            pending_calls_count = self.get_number_of_pending_calls()
            if pending_calls_count > 0:
                issues_report = self._make_issues_report()
                logger.error(issues_report)
                assert pending_calls_count == 0, issues_report

    def verify_no_unexpected_calls(self, assert_no_unexpected_calls: bool = True) -> None:
        """
        Report every call still recorded on a registered mock.

        Calls consumed by fulfilled expectations have been removed from the mocks, so anything
        left is unexpected. The findings are always logged; when ``assert_no_unexpected_calls``
        is True they also fail with AssertionError.
        """
        issues = [
            f"Detected unexpected call to {mocked_method._mock_name}: {repr(mock_call)}"
            for mocked_method in self.pending_expectations_by_mock
            for mock_call in mocked_method.call_args_list
        ]
        if not issues:
            return
        issues_report = "\n".join(issues)
        if assert_no_unexpected_calls:
            logger.error(issues_report)
            assert not issues, issues_report
        else:
            logger.warning(issues_report)

    def _make_issues_report(self) -> str:
        """Build a multi-line report listing every expectation that is still not fulfilled."""
        lines = [
            f"{mocked_method._mock_name}: {call_expectation}"
            for mocked_method, call_expectations in self.pending_expectations_by_mock.items()
            for call_expectation in call_expectations
            if not call_expectation.fulfilled
        ]
        return "Expectations not fulfilled:\n" + "\n".join(lines)

    def _check_pending_expectations(self) -> None:
        """
        Run ``verify()`` on every pending expectation and move the fulfilled ones.

        Raises:
            AssertionError: propagated from an expectation whose matching call fails validation.
        """
        for mocked_method, call_expectations in self.pending_expectations_by_mock.items():
            if not call_expectations:
                continue
            logger.info(f"Verifying {len(call_expectations)} expectations for {mocked_method._mock_name}")
            # Iterate over a snapshot because fulfilled expectations are removed from the list.
            for call_expectation in list(call_expectations):
                if call_expectation.verify():
                    call_expectations.remove(call_expectation)
                    self.fulfilled_expectations_by_mock[mocked_method].append(call_expectation)
                else:
                    logger.debug(f"Expectation still pending for {mocked_method._mock_name}: {call_expectation}")

    def _await_network_events(self) -> None:
        """
        Read network events until a read reports no events.

        While expectations are pending, each read blocks for up to ``response_timeout`` seconds.
        Pending expectations are re-checked after every batch of events. Once all of them are
        fulfilled, later reads use a zero timeout and only drain events that have already arrived.
        """
        assert self.tcp_connection_manager is not None, "Network not set up"
        pending_calls_count = self.get_number_of_pending_calls()
        next_timeout = self.response_timeout if pending_calls_count > 0 else 0
        logger.info(f"Awaiting (at least) {pending_calls_count} network events for up to {next_timeout} seconds...")

        total_received_events = 0
        while True:
            received_events = self.tcp_connection_manager.wait_for_events(timeout_in_seconds=next_timeout)
            logger.debug(f"Received {received_events} network events")
            if received_events == 0:
                logger.debug("Stopping network event waiting loop")
                break
            total_received_events += received_events

            # The events may have triggered the expected mock calls. Consume them now; otherwise
            # the pending count never drops here and every read keeps blocking for the full timeout.
            self._check_pending_expectations()
            pending_calls_count = self.get_number_of_pending_calls()
            if pending_calls_count == 0:
                logger.debug("Received all pending network events. Next network read attempts will be non-blocking.")
                next_timeout = 0
            else:
                logger.debug(f"Still waiting for {pending_calls_count} pending network events")
        logger.debug(f"Received {total_received_events} network events in total")
|