266 lines
11 KiB
Python
266 lines
11 KiB
Python
"""Generic component orchestrator for system tests.
|
|
|
|
Reads deployment_config.json and testing_dependencies.json to determine which
|
|
components to start for a given protocol, resolves their startup order via
|
|
topological sort on connectTo references, assigns dynamic ports, and manages
|
|
process lifecycles.
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import socket
|
|
from collections import deque
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from connection.ip_address import IpAddress
|
|
from tests.common.process_manager import PerformanceStats, ProcessManager
|
|
from tests.common import constants
|
|
|
|
logger = logging.getLogger(__name__)

# JSON mapping each protocol to the protocols it requires; it lives in the
# ``continuous_deployment`` directory under this file's grandparent directory.
COMPONENT_DEPENDENCIES_FILE = Path(__file__).parent.parent.joinpath(
    "continuous_deployment", "testing_dependencies.json"
)
|
|
|
|
|
|
def _allocate_ports(count: int) -> list[int]:
|
|
"""Reserve `count` ephemeral ports from the OS, minimising collision risk."""
|
|
sockets: list[socket.socket] = []
|
|
ports: list[int] = []
|
|
try:
|
|
for _ in range(count):
|
|
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
|
s.bind(("localhost", 0))
|
|
ports.append(s.getsockname()[1])
|
|
sockets.append(s)
|
|
finally:
|
|
for s in sockets:
|
|
s.close()
|
|
return ports
|
|
|
|
|
|
@dataclass
class ComponentInstance:
    """One component selected from the deployment config, plus its runtime handles."""

    # Component name; other components reference it by this key in ``connectTo``.
    name: str
    # Passed to ProcessManager.start_process to launch the component.
    package_name: str
    # Protocols this component implements (only iterated and membership-tested).
    protocols: list[str]
    # The component's ``config`` section from the deployment config. It serves as
    # the template for the runtime config and is read for ``connectTo`` targets.
    original_config: dict[str, Any]
    # Taken from ``authRequired`` in the deployment config.
    auth_required: bool
    # Listen address; None until the orchestrator starts the component.
    address: IpAddress | None = None
    # Handle used to stop the process; None until the component is started.
    process_manager: ProcessManager | None = None
|
|
|
|
|
|
class ComponentOrchestrator:
    """Manages the lifecycle of all components needed to test a specific protocol.

    Typical usage inside a pytest fixture::

        orchestrator = ComponentOrchestrator(venv_path, deployment_config, users_data_file)
        orchestrator.start_for_protocol("order_book")
        server_addr = orchestrator.get_server_address("order_book")
        yield
        orchestrator.stop_all()

    Components are started in dependency order (``connectTo`` targets first)
    and stopped in reverse start order by :meth:`stop_all`.
    """

    def __init__(self, venv_path: Path, deployment_config: dict[str, Any],
                 data_file_path: Path) -> None:
        """Create an orchestrator. No processes are started here.

        Args:
            venv_path: Passed to every ``ProcessManager`` that launches a component.
            deployment_config: Parsed deployment configuration. Its ``"components"``
                list holds entries with ``name``, ``packageName``, ``protocols``,
                ``config`` and optionally ``authRequired`` keys.
            data_file_path: Injected into every component's runtime config as
                ``dataFilePath``.
        """
        self.venv_path = venv_path
        self.deployment_config = deployment_config
        self.data_file_path = data_file_path
        # Maps protocol -> protocols it requires (see _resolve_required_protocols).
        self.component_dependencies = self._load_component_dependencies()
        # Components in start order; stop_all() walks this list in reverse.
        self._running_components: list[ComponentInstance] = []
        self._protocol_to_address: dict[str, IpAddress] = {}

    @staticmethod
    def _load_component_dependencies() -> dict[str, Any]:
        """Load the protocol dependency map from ``COMPONENT_DEPENDENCIES_FILE``."""
        # JSON is UTF-8 by specification; don't depend on the platform locale.
        with open(COMPONENT_DEPENDENCIES_FILE, encoding="utf-8") as f:
            return json.load(f)

    def start_for_protocol(self, protocol: str, *, startup_timeout_s: float = 2) -> None:
        """Resolve dependencies and start every component needed to test *protocol*.

        Args:
            protocol: The protocol under test.
            startup_timeout_s: Per-component readiness timeout, forwarded to
                ``ProcessManager.wait_until_server_is_ready``.

        Raises:
            ValueError: A required protocol has no implementing component, a
                ``connectTo`` target is not defined, or the ``connectTo`` graph
                contains a cycle.
            AssertionError: A component process failed to start.
        """
        required_protocols = self._resolve_required_protocols(protocol)
        logger.info("Protocol '%s' requires protocols: %s", protocol, required_protocols)

        components = self._select_components(required_protocols)
        ordered = self._topological_sort(components)
        logger.info("Component startup order: %s", [c.name for c in ordered])

        # Ports are allocated up front so each component's config can point
        # at the ports of the components it connects to.
        ports = _allocate_ports(len(ordered))
        port_assignments: dict[str, int] = {
            comp.name: port for comp, port in zip(ordered, ports)
        }

        for comp in ordered:
            self._start_component(comp, port_assignments, startup_timeout_s)

    def get_server_address(self, protocol: str) -> IpAddress:
        """Return the listen address of the component implementing *protocol*.

        Raises:
            ValueError: No running component implements *protocol*.
        """
        try:
            return self._protocol_to_address[protocol]
        except KeyError:
            raise ValueError(
                f"No running component implements protocol '{protocol}'"
            ) from None

    def is_auth_required(self, protocol: str) -> bool:
        """Return whether the component implementing *protocol* requires authentication.

        The first running component (in start order) that lists *protocol* wins.

        Raises:
            ValueError: No running component implements *protocol*.
        """
        for comp in self._running_components:
            if protocol in comp.protocols:
                return comp.auth_required
        raise ValueError(f"No running component implements protocol '{protocol}'")

    def stop_all(self) -> dict[str, PerformanceStats]:
        """Stop all running components in reverse startup order.

        Errors while stopping one component are logged and do not prevent the
        remaining components from being stopped. The running-component and
        protocol-address registries are cleared afterwards.

        Returns:
            A mapping of component name to performance statistics, for the
            components that reported any.
        """
        perf_stats: dict[str, PerformanceStats] = {}
        for comp in reversed(self._running_components):
            pm = comp.process_manager
            if pm is None:
                continue
            try:
                pm.stop_process()
                if pm.performance_stats:
                    perf_stats[comp.name] = pm.performance_stats
            except Exception:
                # Best effort: keep tearing down the remaining components.
                logger.exception("Error stopping component '%s'", comp.name)
        self._running_components.clear()
        self._protocol_to_address.clear()
        return perf_stats

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _start_component(self, comp: ComponentInstance, port_assignments: dict[str, int],
                         startup_timeout_s: float) -> None:
        """Launch *comp*, register it as running, and wait until it is reachable."""
        config = self._build_runtime_config(comp, port_assignments)
        address = IpAddress(host="localhost", port=port_assignments[comp.name])

        pm = ProcessManager(self.venv_path)
        # Keep the side-effecting call outside ``assert``: under ``python -O``
        # assert statements are stripped and the process would never start.
        started = pm.start_process(comp.package_name, config)
        if not started:
            raise AssertionError(
                f"Failed to start component '{comp.name}' (binary: {comp.package_name})"
            )

        # Register before waiting for readiness, so stop_all() still stops the
        # process if the readiness wait raises.
        comp.address = address
        comp.process_manager = pm
        self._running_components.append(comp)

        startup_ms = pm.wait_until_server_is_ready(address, timeout_in_seconds=startup_timeout_s)
        logger.info(
            "Component '%s' ready in %.2f ms on port %s", comp.name, startup_ms, address.port
        )

        for p in comp.protocols:
            self._protocol_to_address[p] = address

    def _resolve_required_protocols(self, protocol: str) -> set[str]:
        """Transitively collect all *required* protocols for *protocol* (including itself)."""
        result: set[str] = {protocol}
        queue: deque[str] = deque([protocol])
        while queue:
            current = queue.popleft()
            for dep in self.component_dependencies.get(current, []):
                if dep not in result:
                    result.add(dep)
                    queue.append(dep)
        return result

    def _select_components(self, required_protocols: set[str]) -> list[ComponentInstance]:
        """Select the components needed to serve *required_protocols*.

        Selection includes every component implementing at least one required
        protocol; if several implement the same protocol, all are selected. It
        also includes the transitive closure of their ``connectTo`` targets. A
        single component may cover several protocols. The result keeps the
        order of ``deployment_config["components"]``.

        Raises:
            ValueError: A ``connectTo`` target is not defined in the deployment
                config, or a required protocol is implemented by no component.
        """
        all_cfgs = self.deployment_config["components"]
        comp_cfg_by_name: dict[str, dict[str, Any]] = {cfg["name"]: cfg for cfg in all_cfgs}

        selected_names: set[str] = {
            cfg["name"] for cfg in all_cfgs
            if not required_protocols.isdisjoint(cfg["protocols"])
        }

        # Breadth-first closure over connectTo. Every queued name is already
        # known to exist, because targets are validated before being queued.
        queue: deque[str] = deque(selected_names)
        while queue:
            name = queue.popleft()
            for target in comp_cfg_by_name[name]["config"].get("connectTo", {}):
                if target in selected_names:
                    continue
                if target not in comp_cfg_by_name:
                    raise ValueError(f"Component '{target}' not defined in deployment config")
                selected_names.add(target)
                queue.append(target)

        components: list[ComponentInstance] = [
            ComponentInstance(
                name=cfg["name"],
                package_name=cfg["packageName"],
                # A list, matching ComponentInstance.protocols' annotation.
                protocols=list(cfg["protocols"]),
                original_config=cfg["config"],
                auth_required=cfg.get("authRequired", True),
            )
            for cfg in all_cfgs
            if cfg["name"] in selected_names
        ]

        covered = {p for comp in components for p in comp.protocols}
        missing = required_protocols - covered
        if missing:
            raise ValueError(
                f"Protocols {missing} are required but not implemented by "
                f"any component in deployment_config.json"
            )
        return components

    def _topological_sort(self, components: list[ComponentInstance]) -> list[ComponentInstance]:
        """Order components so each one's ``connectTo`` targets start before it.

        Uses Kahn's algorithm. Targets outside *components* are ignored.

        Raises:
            ValueError: The ``connectTo`` graph among *components* has a cycle.
        """
        by_name = {c.name: c for c in components}

        # deps[A] = {B, ...}: A connects to (depends on) B.
        deps: dict[str, set[str]] = {c.name: set() for c in components}
        for comp in components:
            for target in comp.original_config.get("connectTo", {}):
                if target in deps:
                    deps[comp.name].add(target)

        # Reverse edges, so each pop touches only its dependents instead of
        # rescanning every node. Built in ``deps`` order to keep the output
        # order deterministic.
        dependents: dict[str, list[str]] = {name: [] for name in deps}
        for name, targets in deps.items():
            for target in targets:
                dependents[target].append(name)

        # Number of each component's dependencies not yet placed in the result.
        remaining = {name: len(targets) for name, targets in deps.items()}
        queue: deque[str] = deque(name for name, n in remaining.items() if n == 0)
        result: list[ComponentInstance] = []

        while queue:
            name = queue.popleft()
            result.append(by_name[name])
            for dependent in dependents[name]:
                remaining[dependent] -= 1
                if remaining[dependent] == 0:
                    queue.append(dependent)

        if len(result) != len(components):
            raise ValueError("Circular dependency detected among components")
        return result

    def _build_runtime_config(
        self, comp: ComponentInstance, port_assignments: dict[str, int]
    ) -> dict[str, Any]:
        """Clone the component's config with dynamic ports and the test log dir.

        Overrides ``logDirectory``, ``listenOn`` and ``dataFilePath``. It also
        points every ``connectTo`` entry that has an assigned port at
        ``localhost:<port>``; entries without an assigned port are kept as-is.
        """
        # Shallow copy is enough: top-level keys below are replaced, never
        # mutated in place, so ``comp.original_config`` is left untouched.
        config: dict[str, Any] = dict(comp.original_config)
        config["logDirectory"] = constants.LOG_DIRECTORY
        config["listenOn"] = {
            "host": "localhost",
            "port": port_assignments[comp.name],
        }
        config["dataFilePath"] = str(self.data_file_path)

        if "connectTo" in config:
            config["connectTo"] = {
                target_name: (
                    {"host": "localhost", "port": port_assignments[target_name]}
                    if target_name in port_assignments
                    else target_cfg
                )
                for target_name, target_cfg in config["connectTo"].items()
            }

        return config