diff --git a/pyproject.toml b/pyproject.toml index 43b15ac..24497b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ dependencies = [ "pyyaml >= 6.0", "pyvisa >= 1.13.0", "pyvisa-py >= 0.7.0", - "dataclass-mage >= 0.25.1", + "pydantic_yaml >= 1.6.0", ] [project.optional-dependencies] diff --git a/src/team1k/client.py b/src/team1k/client.py index 8342326..b9a48d4 100644 --- a/src/team1k/client.py +++ b/src/team1k/client.py @@ -264,7 +264,7 @@ class Client: pbar = _make_pbar(total=n, desc="Capturing", unit="frame") for i in iterator: - raw = recv_bytes_or_msg(self._sock, frame_bytes) + raw = recv_bytes_or_msg(self._sock) if not isinstance(raw, bytes): raise ValueError(f"Expected frame data but got message: {raw}") frames[i] = np.frombuffer(raw, dtype=dtype).reshape(ny, nx) diff --git a/src/team1k/config.py b/src/team1k/config.py index 8d9b816..4f3ee78 100644 --- a/src/team1k/config.py +++ b/src/team1k/config.py @@ -171,7 +171,7 @@ def load_config(path: str | Path | None = None) -> Team1kConfig: config = parse_yaml_raw_as(Team1kConfig, f.read()) logger.info("Config loaded: detector=%s:%d, prefix=%s", - config.detector_ip, config.register_port, config.pv_prefix) + config.detector.ip, config.detector.register_port, config.server.pv_prefix) logger.debug("Full config:\n%s", to_yaml_str(config)) return config diff --git a/src/team1k/detector/data_port.py b/src/team1k/detector/data_port.py index 6887017..546eecd 100644 --- a/src/team1k/detector/data_port.py +++ b/src/team1k/detector/data_port.py @@ -51,6 +51,7 @@ class DataPort: # Try large buffer first, fall back to smaller try: self._socket = UDPSocket(detector_ip, port, recv_buffer_size=buffer_size) + assert self._socket._sock is not None, "Failed to create UDP socket" actual = self._socket._sock.getsockopt( __import__('socket').SOL_SOCKET, __import__('socket').SO_RCVBUF ) diff --git a/src/team1k/detector/udp_transport.py b/src/team1k/detector/udp_transport.py index 0162e4c..e041957 100644 --- a/src/team1k/detector/udp_transport.py +++ b/src/team1k/detector/udp_transport.py @@ -52,6 +52,7 @@ class UDPSocket: @property def fileno(self) -> int: """File descriptor for select().""" + assert self._sock is not None, "Socket is closed" return self._sock.fileno() def send(self, data: bytes | bytearray | memoryview) -> int: @@ -64,6 +65,7 @@ class UDPSocket: Raises: OSError: On send failure. """ + assert self._sock is not None, "Socket is closed" nsent = self._sock.sendto(data, self._dest_addr) if nsent != len(data): raise OSError(f"Sent {nsent} of {len(data)} bytes") @@ -80,6 +82,7 @@ class UDPSocket: Returns: Received bytes, or None on timeout. """ + assert self._sock is not None, "Socket is closed" ready, _, _ = select.select([self._sock], [], [], timeout_sec) if not ready: return None @@ -97,6 +100,7 @@ class UDPSocket: Returns: Number of bytes received, 0 on timeout. """ + assert self._sock is not None, "Socket is closed" ready, _, _ = select.select([self._sock], [], [], timeout_sec) if not ready: return 0 @@ -104,6 +108,7 @@ class UDPSocket: def clear_buffer(self) -> None: """Drain all pending data from the socket (non-blocking).""" + assert self._sock is not None, "Socket is closed" logger.debug("Clearing socket buffer...") self._sock.setblocking(False) try: diff --git a/src/team1k/peripherals/power_base.py b/src/team1k/peripherals/power_base.py index 63b5519..20d955d 100644 --- a/src/team1k/peripherals/power_base.py +++ b/src/team1k/peripherals/power_base.py @@ -10,6 +10,11 @@ import threading import dataclasses import pyvisa +import pyvisa.resources + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from team1k.config import PowerSupplyChannelConfig logger = logging.getLogger(__name__) @@ -20,7 +25,7 @@ class PowerSupplyBase: Args: port: Device port (e.g., "/dev/DetectorPowerSupply" or "COM3"). - settings: Channel settings: {ch_num: (voltage, current, ovp)}. + channels: Channel settings: {ch_num: PowerSupplyChannelConfig}. voltage_step: Voltage step size for ramping (default 0.2V). name: Human-readable name for logging. """ @@ -31,20 +36,20 @@ class PowerSupplyBase: current: float def __init__(self, port: str, - settings: dict[int, tuple[float, float, float]], + channels: dict[int, PowerSupplyChannelConfig], voltage_step: float = 0.2, name: str = "Power Supply"): self._port = f"ASRL{port}::INSTR" - self._settings = settings + self._channels = channels self._voltage_step = voltage_step self.name = name self._is_on = False self._lock = threading.Lock() - def _open_resource(self) -> pyvisa.Resource: + def _open_resource(self) -> pyvisa.resources.MessageBasedResource: """Open the VISA resource.""" rm = pyvisa.ResourceManager("@py") - return rm.open_resource(self._port) + return rm.open_resource(self._port) # pyright: ignore[reportReturnType] def initialize(self) -> None: """ @@ -57,11 +62,11 @@ class PowerSupplyBase: inst = self._open_resource() inst.write("OUTP:GEN 0 \n") - for ch, (voltage, current, ovp) in self._settings.items(): + for ch, config in self._channels.items(): inst.write(f"INST:NSEL {ch} \n") inst.write(f"OUTP:SEL 0 \n") - inst.write(f"APPL {voltage}, {current} \n") - inst.write(f"SOUR:VOLT:PROT {ovp} \n") + inst.write(f"APPL {config.voltage}, {config.current} \n") + inst.write(f"SOUR:VOLT:PROT {config.ovp} \n") inst.write(f"SOUR:VOLT:STEP {self._voltage_step} \n") volt = float(inst.query("SOUR:VOLT? \n")) @@ -85,8 +90,8 @@ class PowerSupplyBase: inst = self._open_resource() - for ch, (voltage, _, _) in self._settings.items(): - output_on = voltage > 0 + for ch, config in self._channels.items(): + output_on = config.voltage > 0 inst.write(f"INST:NSEL {ch} \n") inst.write(f"OUTP:SEL {1 if output_on else 0} \n") @@ -106,7 +111,7 @@ class PowerSupplyBase: inst = self._open_resource() inst.write("OUTP:GEN 0 \n") - for ch in self._settings: + for ch in self._channels.keys(): inst.write(f"INST:NSEL {ch} \n") inst.write(f"OUTP:SEL 0 \n") @@ -130,8 +135,8 @@ class PowerSupplyBase: try: inst = self._open_resource() - for ch, (voltage, _, _) in self._settings.items(): - if voltage == 0: + for ch, config in self._channels.items(): + if config.voltage == 0: continue inst.write(f"INST:NSEL {ch} \n") volt = float(inst.query("MEAS:VOLT? \n")) diff --git a/src/team1k/peripherals/power_supply.py b/src/team1k/peripherals/power_supply.py index fb4cbef..0f54d2b 100644 --- a/src/team1k/peripherals/power_supply.py +++ b/src/team1k/peripherals/power_supply.py @@ -26,7 +26,7 @@ class DetectorPowerSupply(PowerSupplyBase): if self._enabled: super().__init__( port=config.port, - settings=config.settings, + channels=config.channels, voltage_step=config.voltage_step, name="Detector Power Supply", ) diff --git a/src/team1k/peripherals/tec.py b/src/team1k/peripherals/tec.py index 30e2e09..06b9681 100644 --- a/src/team1k/peripherals/tec.py +++ b/src/team1k/peripherals/tec.py @@ -26,7 +26,7 @@ class TECController(PowerSupplyBase): if self._enabled: super().__init__( port=config.port, - settings=config.settings, + channels=config.channels, voltage_step=config.voltage_step, name="TEC Power Supply", ) @@ -53,7 +53,7 @@ class TECController(PowerSupplyBase): self._is_on = False logger.info("TEC: power_off (no-op, not configured)") - def get_status(self) -> dict[int, dict[str, float]]: + def get_status(self) -> dict[int, PowerSupplyBase.ChannelStatus]: if self._enabled: return super().get_status() return {} diff --git a/src/team1k/server.py b/src/team1k/server.py index 8cd9726..7b277ce 100644 --- a/src/team1k/server.py +++ b/src/team1k/server.py @@ -30,7 +30,7 @@ import argparse import threading from typing import Any -from .config import Team1kConfig, load_config +from .config import Team1kConfig, TriggerMode, TriggerPolarity, load_config from .state import DetectorState from .detector.registers import RegisterInterface @@ -67,10 +67,10 @@ class Team1kServer: # Cached parameter values (avoids register reads) self._params = { - "exposure_mode": config.exposure_mode, - "trigger_mode": config.trigger_mode, - "trigger_polarity": config.trigger_polarity, - "integration_time": config.integration_time_ms, + "exposure_mode": config.defaults.exposure_mode, + "trigger_mode": config.defaults.trigger_mode, + "trigger_polarity": config.defaults.trigger_polarity, + "integration_time": config.defaults.integration_time_ms, "frame_rate": 0.0, "frame_count": 0, } @@ -82,25 +82,25 @@ class Team1kServer: # Acquisition subprocess self.acquisition = AcquisitionProcess( - config.detector_ip, config.data_port, + config.detector.ip, config.detector.data_port, ring_name="team1k_frames", num_ring_slots=32, chip_config=self.chip_config, ) # PVA interface - self.pva = PVAInterface(self, prefix=config.pv_prefix) + self.pva = PVAInterface(self, prefix=config.server.pv_prefix) # PVA streamer (created after PVA setup) self._pva_streamer: PVAStreamer | None = None # TCP client server - self.tcp_server = TCPClientServer(self, port=config.client_port) + self.tcp_server = TCPClientServer(self, port=config.server.client_port) # Peripherals - self.bellow_stage = BellowStage(config.bellow_stage) - self.detector_power = DetectorPowerSupply(config.detector_power) - self.tec = TECController(config.tec) + self.bellow_stage = BellowStage(config.peripherals.bellow_stage) + self.detector_power = DetectorPowerSupply(config.peripherals.detector_power) + self.tec = TECController(config.peripherals.tec) @property def state(self) -> DetectorState: @@ -129,7 +129,7 @@ class Team1kServer: pass self.registers = RegisterInterface( - self.config.detector_ip, self.config.register_port, + self.config.detector.ip, self.config.detector.register_port, ) self.commands = DetectorCommands(self.registers) self.adc = ADCController(self.registers) @@ -149,6 +149,10 @@ class Team1kServer: if not self._connect_detector(): self.state = DetectorState.ERROR return False + + assert self.registers is not None + assert self.commands is not None + assert self.adc is not None try: # Firmware version @@ -162,32 +166,30 @@ class Team1kServer: configure_chip(self.registers, self.chip_config) # Apply config defaults - self.commands.set_exposure_mode(self.config.exposure_mode) + self.commands.set_exposure_mode(self.config.defaults.exposure_mode) self.commands.set_trigger_mode( - external=self.config.trigger_external, - polarity=not self.config.trigger_polarity_rising, + external=self.config.defaults.trigger_mode == TriggerMode.EXTERNAL, + polarity=self.config.defaults.trigger_polarity == TriggerPolarity.FALLING_EDGE, ) - self.commands.set_integration_time(self.config.integration_time_ms) - self.adc.set_clock_freq(self.config.adc_clock_frequency_mhz) - self.commands.set_adc_data_delay(self.config.adc_data_delay) + self.commands.set_integration_time(self.config.defaults.integration_time_ms) + self.adc.set_clock_freq(self.config.adc.clock_frequency_mhz) + self.commands.set_adc_data_delay(self.config.adc.data_delay) self.commands.set_adc_data_averaging(0) self.commands.enable_fpga_test_data(False) # ADC order registers - self.registers.write_register(30, self.config.adc_order_7to0) - self.registers.write_register(31, self.config.adc_order_15to8) + self.registers.write_register(30, self.config.adc.order_7to0) + self.registers.write_register(31, self.config.adc.order_15to8) # Digital signal registers - self.registers.write_register(27, self.config.digital_polarity) - self.registers.write_register(28, self.config.digital_order_7to0) - self.registers.write_register(29, self.config.digital_order_15to8) + self.registers.write_register(27, self.config.digital.polarity) + self.registers.write_register(28, self.config.digital.order_7to0) + self.registers.write_register(29, self.config.digital.order_15to8) # Update cached params - self._params["exposure_mode"] = self.config.exposure_mode - self._params["trigger_mode"] = ( - 1 if self.config.trigger_external else 0 - ) - self._params["integration_time"] = self.config.integration_time_ms + self._params["exposure_mode"] = self.config.defaults.exposure_mode + self._params["trigger_mode"] = self.config.defaults.trigger_mode == TriggerMode.EXTERNAL + self._params["integration_time"] = self.config.defaults.integration_time_ms self.state = DetectorState.IDLE logger.info("Detector initialized") @@ -222,11 +224,11 @@ class Team1kServer: elif name == "trigger_mode" or name == "trigger_polarity": if name == "trigger_mode": - mode = int(value) - polarity = self.config.trigger_polarity + mode = TriggerMode(int(value)) + polarity = self.config.defaults.trigger_polarity else: - mode = self.config.trigger_mode - polarity = int(value) + mode = self.config.defaults.trigger_mode + polarity = TriggerPolarity(int(value)) self.commands.set_trigger_mode(external=bool(mode), polarity=bool(polarity)) self._params["trigger_mode"] = mode self._params["trigger_polarity"] = polarity @@ -349,7 +351,7 @@ class Team1kServer: def _auto_reconnect_loop(self) -> None: """Background thread: auto-reconnect when in ERROR state.""" - interval = self.config.reconnect_interval + interval = self.config.detector.reconnect_interval while not self._shutdown_event.is_set(): if self.state == DetectorState.ERROR: logger.info("Auto-reconnect: attempting...") @@ -433,7 +435,7 @@ class Team1kServer: ).start() # Auto-reconnect thread (if enabled) - if self.config.auto_reconnect: + if self.config.detector.auto_reconnect: threading.Thread( target=self._auto_reconnect_loop, daemon=True, name="team1k-reconnect", @@ -506,11 +508,11 @@ def main(): # Apply CLI overrides if args.detector_ip: - config.detector_ip = args.detector_ip + config.detector.ip = args.detector_ip if args.pv_prefix: - config.pv_prefix = args.pv_prefix + config.server.pv_prefix = args.pv_prefix if args.client_port: - config.client_port = args.client_port + config.server.client_port = args.client_port server = Team1kServer(config) diff --git a/src/team1k/tcp_protocol.py b/src/team1k/tcp_protocol.py index 780f0e0..2956827 100644 --- a/src/team1k/tcp_protocol.py +++ b/src/team1k/tcp_protocol.py @@ -12,7 +12,6 @@ Message types: import json import struct import socket -from typing import Union # Length prefix formats _JSON_HEADER = struct.Struct("!I") # 4-byte unsigned int (max ~4 GB) @@ -26,7 +25,7 @@ def send_msg(sock: socket.socket, obj: dict) -> None: def recv_msg(sock: socket.socket) -> dict: """Receive a length-prefixed JSON message. Returns parsed dict.""" data = recv_bytes_or_msg(sock) - if isinstance(data, bytes): + if not isinstance(data, dict): raise ValueError("Expected JSON message but got binary data") return data @@ -35,7 +34,7 @@ def send_bytes(sock: socket.socket, data: bytes | memoryview) -> None: sock.sendall(b"D" + _DATA_HEADER.pack(len(data)) + data) -def recv_bytes_or_msg(sock: socket.socket) -> Union[bytes | dict]: +def recv_bytes_or_msg(sock: socket.socket) -> bytes | dict: """Receive length-prefixed binary data (reads the 8-byte header first).""" type_byte = _recv_exact(sock, 1) if type_byte == b"J": diff --git a/src/team1k/tcp_server.py b/src/team1k/tcp_server.py index 3a7482d..142708d 100644 --- a/src/team1k/tcp_server.py +++ b/src/team1k/tcp_server.py @@ -9,12 +9,13 @@ Only one capture can run at a time (capture lock). import socket import logging import threading -from typing import TYPE_CHECKING, Any +from dataclasses import asdict from .tcp_protocol import send_msg, recv_msg from .capture import BufferedCapture from .state import DetectorState +from typing import TYPE_CHECKING if TYPE_CHECKING: from .server import Team1kServer @@ -113,7 +114,7 @@ class TCPClientServer(threading.Thread): "acquiring": self._server.state == DetectorState.ACQUIRING, "power_on": self._server.detector_power.is_on, "tec_on": self._server.tec.is_on, - "bellow_inserted": abs(self._server.bellow_stage.position - self._server.config.bellow_stage.inserted_position_um) < 100, + "bellow_inserted": abs(self._server.bellow_stage.position - self._server.config.peripherals.bellow_stage.inserted_position_um) < 100, }) elif cmd == "capture": @@ -183,12 +184,14 @@ class TCPClientServer(threading.Thread): elif cmd == "register_read": addr = int(msg["address"]) + assert self._server.registers is not None, "Registers not configured" value = self._server.registers.read_register(addr) send_msg(sock, {"ok": True, "value": value}) elif cmd == "register_write": addr = int(msg["address"]) value = int(msg["value"]) + assert self._server.registers is not None, "Registers not configured" self._server.registers.write_register(addr, value) send_msg(sock, {"ok": True}) @@ -225,14 +228,14 @@ class TCPClientServer(threading.Thread): # Convert int keys to strings for JSON send_msg(sock, { "ok": True, - "channels": {str(k): dict(v) for k, v in status.items()}, + "channels": {str(k): asdict(v) for k, v in status.items()}, }) elif cmd == "tec_status": status = self._server.tec.get_status() send_msg(sock, { "ok": True, - "channels": {str(k): dict(v) for k, v in status.items()}, + "channels": {str(k): asdict(v) for k, v in status.items()}, }) elif cmd == "bellow_status": @@ -242,7 +245,7 @@ class TCPClientServer(threading.Thread): "ok": True, "position_um": position, "is_moving": is_moving, - "bellow_inserted": abs(position - self._server.config.bellow_stage.inserted_position_um) < 100, + "bellow_inserted": abs(position - self._server.config.peripherals.bellow_stage.inserted_position_um) < 100, }) else: