79 lines
3.7 KiB
Python
79 lines
3.7 KiB
Python
from decimal import Decimal
|
|
import socket
|
|
from typing import Callable
|
|
import logging
|
|
|
|
from connection.connection_handler import ConnectionHandler, ConnectionHandlerFactory
|
|
from connection.ip_address import IpAddress
|
|
from proto.admin_pb2 import CreateInstrumentRequest, CreateInstrumentResponse
|
|
from proto.common_pb2 import Instrument, MessageType
|
|
|
|
from tests.common.mock_expectations import CallExpectationsManager, ResponseExpectation
|
|
|
|
# Module-scoped logger; records are namespaced under this module's import path.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
class AdminTestClient(ConnectionHandler):
    """Test-side client for the admin protocol.

    Sends admin requests through the underlying ``ConnectionHandler`` and
    matches each incoming response to its request by ``request_id``. Each
    response is routed to a mock created through the shared
    ``CallExpectationsManager``, so tests can assert on it.
    """

    def __init__(self, socket_fd: socket.socket, ip_address: IpAddress, on_close: Callable[[], None],
                 call_expectations_manager: CallExpectationsManager) -> None:
        super().__init__(socket_fd, ip_address, on_close)
        # Request ids are handed out by _get_next_request_id, starting at 1.
        self.next_request_id = 1
        # Pending responses: request_id -> callback. Entries are popped when the
        # matching response arrives, so each callback fires at most once.
        self.callbacks: dict[int, Callable] = {}
        self.call_expectations_manager = call_expectations_manager
        self.on_create_instrument_response = self.call_expectations_manager.create_mock("on_create_instrument_response")

    @staticmethod
    def _message_type_name(message_type: int) -> str:
        """Return the ``MessageType`` name for *message_type*, or a numeric fallback.

        ``MessageType.Name`` raises ``ValueError`` for values the enum does not
        define. Logging and error reporting must not themselves fail on such
        values, so the raw number is reported instead.
        """
        try:
            return MessageType.Name(message_type)
        except ValueError:
            return f"<unknown MessageType {message_type}>"

    def on_disconnect(self) -> None:
        """Log the disconnection from the peer at ``self.ip_address``."""
        logger.info("Disconnected from %s", self.ip_address)

    def handle_message(self, message_type: int, raw_message: bytes) -> None:
        """Decode an incoming admin response and dispatch it to its request's callback.

        Args:
            message_type: numeric ``MessageType`` of the incoming frame.
            raw_message: serialized protobuf payload.

        Raises:
            ValueError: if ``message_type`` is not a response this client handles
                (including values not defined in ``MessageType``), or if the
                response's ``request_id`` has no pending callback.
        """
        type_name = self._message_type_name(message_type)
        logger.info("Handling message of type %s", type_name)

        if message_type == MessageType.ADMIN_CREATE_INSTRUMENT_RESPONSE:
            message = CreateInstrumentResponse.FromString(raw_message)
        else:
            raise ValueError(f"Unexpected message type: {type_name}")

        # pop() so a duplicate response for the same id is reported as unknown.
        callback = self.callbacks.pop(message.request_id, None)
        if callback is None:
            raise ValueError(f"Received response with unknown request_id: {message.request_id}")
        callback(message)

    def test_create_instrument(self, instrument: Instrument, tick_size: Decimal,
                               expect_success: bool = True) -> ResponseExpectation[CreateInstrumentResponse]:
        """Send a CreateInstrumentRequest and register an expectation for its response.

        Args:
            instrument: instrument definition to create.
            tick_size: tick size; converted to ``float`` for the request field.
            expect_success: forwarded to ``ResponseExpectation``. When true, this
                method also verifies pending expectations immediately (without
                requiring that no calls remain pending) and asserts that the
                response was received.

        Returns:
            The registered ``ResponseExpectation``, so callers can inspect it or
            verify it later.
        """
        request = CreateInstrumentRequest(instrument=instrument, tick_size=float(tick_size))
        request_id = self._get_next_request_id()
        request.request_id = request_id
        # Register the callback before sending, so a response handled any time
        # after the send can be matched to this request.
        self.callbacks[request_id] = self.on_create_instrument_response
        self.send_message(MessageType.ADMIN_CREATE_INSTRUMENT_REQUEST, request)

        response_expectation = ResponseExpectation(
            self.on_create_instrument_response, CreateInstrumentResponse, request_id, expect_success)
        self.call_expectations_manager.add_expectation(response_expectation)

        if expect_success:
            self.call_expectations_manager.verify_expectations(assert_no_pending_calls=False)
            assert response_expectation.fulfilled, "Expected CreateInstrumentResponse not received"

        return response_expectation

    def send_message(self, message_type: int, message) -> None:
        """Log the outgoing message type, then delegate to ``ConnectionHandler.send_message``.

        An undefined ``message_type`` is logged by number rather than aborting
        the send, so tests can deliberately send unknown types.
        """
        logger.info("Sending message of type %s", self._message_type_name(message_type))
        super().send_message(message_type, message)

    def _get_next_request_id(self) -> int:
        """Return the next unused request id and advance the counter."""
        request_id = self.next_request_id
        self.next_request_id += 1
        return request_id
|
|
|
|
|
|
class AdminClientConnectionHandlerFactory(ConnectionHandlerFactory[AdminTestClient]):
    """Connection-handler factory that produces an ``AdminTestClient`` per connection.

    Every client it creates shares this factory's ``CallExpectationsManager``.
    """

    def __init__(self, call_expectations_manager: CallExpectationsManager) -> None:
        # Shared with every AdminTestClient built by on_new_connection.
        self.call_expectations_manager = call_expectations_manager

    def on_new_connection(self, socket_fd: socket.socket, ip_address: IpAddress,
                          close_callback: Callable[[], None]) -> AdminTestClient:
        """Wrap a newly accepted or established connection in an ``AdminTestClient``."""
        client = AdminTestClient(
            socket_fd,
            ip_address,
            close_callback,
            self.call_expectations_manager,
        )
        return client

    def on_connection_closed(self, connection_handler: AdminTestClient) -> None:
        """Intentionally does nothing when a client connection closes."""
|