gtat-tech-career-kickstarte.../solution/tests/test_client/admin_test_client.py

79 lines
3.7 KiB
Python

from decimal import Decimal
import socket
from typing import Callable
import logging
from connection.connection_handler import ConnectionHandler, ConnectionHandlerFactory
from connection.ip_address import IpAddress
from proto.admin_pb2 import CreateInstrumentRequest, CreateInstrumentResponse
from proto.common_pb2 import Instrument, MessageType
from tests.common.mock_expectations import CallExpectationsManager, ResponseExpectation
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class AdminTestClient(ConnectionHandler):
    """Test-side admin client: sends admin requests and routes responses to mocks.

    Every outgoing request is stamped with a fresh ``request_id``. The callback
    stored under that id in ``callbacks`` is popped and invoked when a response
    carrying the same id is handled.
    """

    def __init__(self, socket_fd: socket.socket, ip_address: IpAddress, on_close: Callable[[], None],
                 call_expectations_manager: CallExpectationsManager) -> None:
        """Set up request bookkeeping and the response mock.

        Args:
            socket_fd: Socket forwarded to the ``ConnectionHandler`` base class.
            ip_address: Peer address forwarded to the base class.
            on_close: Close callback forwarded to the base class.
            call_expectations_manager: Manager used to create the response mock
                and to register/verify response expectations.
        """
        super().__init__(socket_fd, ip_address, on_close)
        self.next_request_id = 1
        # Maps request_id -> callable invoked with the matching response.
        self.callbacks: dict[int, Callable] = {}
        self.call_expectations_manager = call_expectations_manager
        self.on_create_instrument_response = self.call_expectations_manager.create_mock(
            "on_create_instrument_response")

    @staticmethod
    def _message_type_name(message_type: int) -> str:
        """Return the enum name for *message_type*, or a placeholder if it is undefined.

        ``MessageType.Name`` raises ``ValueError`` for numbers that are not
        defined in the enum; logging and error formatting must not fail on that.
        """
        try:
            return MessageType.Name(message_type)
        except ValueError:
            return f"UNKNOWN({message_type})"

    def on_disconnect(self) -> None:
        """Log that the connection was dropped."""
        # NOTE(review): ``self.ip_address`` is presumably set by the base class
        # constructor -- confirm against ConnectionHandler.
        logger.info(f"Disconnected from {self.ip_address}")

    def handle_message(self, message_type: int, raw_message: bytes) -> None:
        """Parse an incoming message and dispatch it to the registered callback.

        Args:
            message_type: Numeric ``MessageType`` of the incoming message.
            raw_message: Serialized protobuf payload.

        Raises:
            ValueError: If the message type is not
                ``ADMIN_CREATE_INSTRUMENT_RESPONSE``, or if no callback is
                registered for the response's ``request_id``.
        """
        type_name = self._message_type_name(message_type)
        logger.info(f"Handling message of type {type_name}")
        if message_type != MessageType.ADMIN_CREATE_INSTRUMENT_RESPONSE:
            raise ValueError(f"Unexpected message type: {type_name}")
        message = CreateInstrumentResponse.FromString(raw_message)
        # Each callback is one-shot: it is removed as soon as its response arrives.
        callback = self.callbacks.pop(message.request_id, None)
        # Compare against None explicitly; a registered callable must be invoked
        # regardless of its truthiness.
        if callback is None:
            raise ValueError(f"Received response with unknown request_id: {message.request_id}")
        callback(message)

    def test_create_instrument(self, instrument: Instrument, tick_size: Decimal,
                               expect_success: bool = True) -> ResponseExpectation[CreateInstrumentResponse]:
        """Send a ``CreateInstrumentRequest`` and register an expectation for its response.

        Args:
            instrument: Instrument placed in the request.
            tick_size: Tick size; converted to ``float`` for the request field.
            expect_success: Passed to the ``ResponseExpectation``. When true, the
                expectations are verified immediately and the response is
                asserted to have been received.

        Returns:
            The ``ResponseExpectation`` registered for this request.
        """
        request = CreateInstrumentRequest(instrument=instrument, tick_size=float(tick_size))
        request_id = self._get_next_request_id()
        request.request_id = request_id
        # Register the callback before sending so the response can be routed.
        self.callbacks[request_id] = self.on_create_instrument_response
        self.send_message(MessageType.ADMIN_CREATE_INSTRUMENT_REQUEST, request)
        response_expectation = ResponseExpectation(
            self.on_create_instrument_response, CreateInstrumentResponse, request_id, expect_success)
        self.call_expectations_manager.add_expectation(response_expectation)
        if expect_success:
            # NOTE(review): presumably verify_expectations waits for/processes the
            # pending response -- confirm against CallExpectationsManager.
            self.call_expectations_manager.verify_expectations(assert_no_pending_calls=False)
            assert response_expectation.fulfilled, "Expected CreateInstrumentResponse not received"
        return response_expectation

    def send_message(self, message_type: int, message) -> None:
        """Log the outgoing message type, then delegate to the base class."""
        type_name = self._message_type_name(MessageType.ValueType(message_type))
        logger.info(f"Sending message of type {type_name}")
        super().send_message(message_type, message)

    def _get_next_request_id(self) -> int:
        """Return the current request id and advance the counter by one."""
        request_id = self.next_request_id
        self.next_request_id += 1
        return request_id
class AdminClientConnectionHandlerFactory(ConnectionHandlerFactory[AdminTestClient]):
    """Connection-handler factory that produces ``AdminTestClient`` instances."""

    def __init__(self, call_expectations_manager: CallExpectationsManager) -> None:
        """Keep the expectations manager that is handed to every created client."""
        self.call_expectations_manager = call_expectations_manager

    def on_new_connection(
        self,
        socket_fd: socket.socket,
        ip_address: IpAddress,
        close_callback: Callable[[], None],
    ) -> AdminTestClient:
        """Build an ``AdminTestClient`` for a new connection.

        The socket, address and close callback are passed through unchanged,
        together with this factory's expectations manager.
        """
        client = AdminTestClient(
            socket_fd,
            ip_address,
            close_callback,
            self.call_expectations_manager,
        )
        return client

    def on_connection_closed(self, connection_handler: AdminTestClient) -> None:
        """Take no action when a connection closes."""
        return None