diff --git a/src/tickit/adapters/epics.py b/src/tickit/adapters/epics.py index 2846be684..62e9180fd 100644 --- a/src/tickit/adapters/epics.py +++ b/src/tickit/adapters/epics.py @@ -3,7 +3,7 @@ import logging from abc import abstractmethod from dataclasses import dataclass -from typing import Any, Callable, Dict, Set +from typing import Any, Awaitable, Callable, Dict, Optional, Set from softioc import asyncio_dispatcher, builder, softioc @@ -17,6 +17,8 @@ #: Ids of all adapters currently registered but not ready. _REGISTERED_ADAPTER_IDS: Set[int] = set() +_REGISTERED_IOC_BACKGROUND_TASKS: Set[Awaitable[None]] = set() + #: Iterator of unique IDs for new adapters _ID_COUNTER: itertools.count = itertools.count() @@ -36,6 +38,10 @@ def register_adapter() -> int: return adapter_id +def register_background_task(task: Awaitable[None]) -> None: + _REGISTERED_IOC_BACKGROUND_TASKS.add(task) + + def notify_adapter_ready(adapter_id: int) -> None: """Notify the builder that a particular adapter has made all the records it needs. @@ -64,6 +70,13 @@ def _build_and_run_ioc() -> None: event_loop = asyncio.get_event_loop() dispatcher = asyncio_dispatcher.AsyncioDispatcher(event_loop) softioc.iocInit(dispatcher) + + async def run_background_tasks() -> None: + if len(_REGISTERED_IOC_BACKGROUND_TASKS) > 0: + await asyncio.wait(_REGISTERED_IOC_BACKGROUND_TASKS) + + dispatcher(run_background_tasks) + # dbl directly prints out all record names, so we have to check # the log level in order to only do it in DEBUG. if LOGGER.level <= logging.DEBUG: @@ -90,9 +103,92 @@ class OutputRecord: class EpicsAdapter: """An adapter interface for the EpicsIo.""" - interrupt_records: Dict[InputRecord, Callable[[], Any]] = {} + interrupt_records: Dict[InputRecord, Callable[[], Any]] interrupt: RaiseInterrupt + def __init__(self) -> None: + self.interrupt_records = {} + + def float_rbv( + self, + name: str, + getter: Callable[[], float], + setter: Callable[[float], None], + rbv_name: Optional[str] = None, + precision: int = 2, + ): + rbv_name = rbv_name or f"{name}_RBV" + builder.aOut( + name, + initial_value=getter(), + on_update=self.interrupting_callback(setter), + PREC=precision, + ) + rbv = builder.aIn( + rbv_name, + initial_value=getter(), + PREC=precision, + ) + self.link_input_on_interrupt(rbv, getter) + + def float_ro( + self, + name: str, + getter: Callable[[], float], + precision: int = 2, + ): + self.link_input_on_interrupt( + builder.aIn(name, PREC=precision), + getter, + ) + + def int_rbv( + self, + name: str, + getter: Callable[[], int], + setter: Callable[[int], None], + rbv_name: Optional[str] = None, + ): + rbv_name = rbv_name or f"{name}_RBV" + builder.mbbOut( + name, + initial_value=getter(), + on_update=self.interrupting_callback(setter), + ) + rbv = builder.mbbIn(rbv_name, initial_value=getter()) + self.link_input_on_interrupt(rbv, getter) + + def bool_rbv( + self, + name: str, + getter: Callable[[], bool], + setter: Callable[[bool], None], + rbv_name: Optional[str] = None, + ): + rbv_name = rbv_name or f"{name}_RBV" + builder.boolOut( + name, + initial_value=getter(), + on_update=self.interrupting_callback(setter), + ) + rbv = builder.boolIn(rbv_name, initial_value=getter()) + self.link_input_on_interrupt(rbv, getter) + + def bool_ro(self, name: str, getter: Callable[[], float]): + self.link_input_on_interrupt( + builder.boolIn(name), + getter, + ) + + def interrupting_callback( + self, action: Callable[[Any], None] + ) -> Callable[[Any], Awaitable[None]]: + async def callback(value: Any) -> None: + action(value) + await self.interrupt() + + return callback + def link_input_on_interrupt( self, record: InputRecord, getter: Callable[[], Any] ) -> None: @@ -115,3 +211,11 @@ def after_update(self) -> None: current_value = getter() record.set(current_value) print(f"Record {record.name} updated to : {current_value}") + + def polling_interrupt(self, interval: float) -> None: + async def polling_task() -> None: + while True: + await asyncio.sleep(interval) + await self.interrupt() + + register_background_task(polling_task())