Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions src/fastcs/attributes/attr_r.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import asyncio
from collections.abc import Awaitable, Callable
from collections.abc import Callable, Coroutine
from typing import Any

from fastcs.attributes.attribute import Attribute, AttributeAccessMode
Expand All @@ -10,11 +10,11 @@
from fastcs.datatypes import DataType, DType_T
from fastcs.logging import logger

AttrIOUpdateCallback = Callable[["AttrR[DType_T, Any]"], Awaitable[None]]
AttrIOUpdateCallback = Callable[["AttrR[DType_T, Any]"], Coroutine[None, None, None]]
"""An AttributeIO callback that takes an AttrR and updates its value"""
AttrUpdateCallback = Callable[[], Awaitable[None]]
AttrUpdateCallback = Callable[[], Coroutine[None, None, None]]
"""A callback to be called periodically to update an attribute"""
AttrOnUpdateCallback = Callable[[DType_T], Awaitable[None]]
AttrOnUpdateCallback = Callable[[DType_T], Coroutine[None, None, None]]
"""A callback to be called when the value of the attribute is updated"""


Expand Down Expand Up @@ -132,12 +132,8 @@ def bind_update_callback(self) -> AttrUpdateCallback:
update_callback = self._update_callback

async def update_attribute():
try:
self.log_event("Update attribute", topic=self)
await update_callback(self)
except Exception:
logger.error("Attribute update loop stopped", attribute=self)
raise
self.log_event("Update attribute", topic=self)
await update_callback(self)

return update_attribute

Expand Down
11 changes: 9 additions & 2 deletions src/fastcs/connections/ip_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,13 @@ async def send_query(self, message: str) -> str:
return response

async def close(self):
if self.__connection is None:
return

async with self._connection as connection:
await connection.close()
self.__connection = None
try:
await connection.close()
except ConnectionResetError:
pass

self.__connection = None
Comment on lines 86 to +92
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Ensure connection state is cleared even when close() fails unexpectedly.

If connection.close() raises anything other than ConnectionResetError, self.__connection is never reset. That can leave a stale connection handle and interfere with reconnect behavior.

Suggested fix
 async def close(self):
-    if self.__connection is None:
+    connection = self.__connection
+    if connection is None:
         return
 
-    async with self._connection as connection:
-        try:
-            await connection.close()
-        except ConnectionResetError:
-            pass
-
-    self.__connection = None
+    try:
+        async with connection:
+            try:
+                await connection.close()
+            except ConnectionResetError:
+                pass
+    finally:
+        self.__connection = None
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
async with self._connection as connection:
await connection.close()
self.__connection = None
try:
await connection.close()
except ConnectionResetError:
pass
self.__connection = None
async with self._connection as connection:
try:
try:
await connection.close()
except ConnectionResetError:
pass
finally:
self.__connection = None
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/fastcs/connections/ip_connection.py` around lines 86 - 92, The current
close sequence only clears self.__connection when connection.close() succeeds or
raises ConnectionResetError; move the clearing into a finally block so
self.__connection is set to None regardless of what exception close() raises.
Concretely, wrap the await connection.close() call in try/except/finally (or use
try/finally) inside the async with self._connection context, catch/handle
specific ConnectionResetError if desired, and always assign self.__connection =
None in the finally so the stale handle is never retained; reference the symbols
self._connection, connection.close(), and self.__connection to locate and update
the code.

36 changes: 4 additions & 32 deletions src/fastcs/control_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@

from IPython.terminal.embed import InteractiveShellEmbed

from fastcs.controllers import BaseController, Controller
from fastcs.controllers import Controller
from fastcs.logging import logger
from fastcs.methods import ScanCallback
from fastcs.tracer import Tracer
from fastcs.transports import ControllerAPI, Transport
from fastcs.transports import Transport

tracer = Tracer()

Expand Down Expand Up @@ -57,15 +57,6 @@ async def _run_initial_coros(self):
async def _start_scan_tasks(self):
self._scan_tasks = {self._loop.create_task(coro()) for coro in self._scan_coros}

for task in self._scan_tasks:
task.add_done_callback(self._scan_done)

def _scan_done(self, task: asyncio.Task):
try:
task.result()
except Exception:
logger.exception("Exception raised in scan task")

def _stop_scan_tasks(self):
for task in self._scan_tasks:
if not task.done():
Expand All @@ -82,9 +73,8 @@ async def serve(self, interactive: bool = True) -> None:
await self._controller.initialise()
self._controller.post_initialise()

self.controller_api = build_controller_api(self._controller)
self._scan_coros, self._initial_coros = (
self.controller_api.get_scan_and_initial_coros()
self.controller_api, self._scan_coros, self._initial_coros = (
self._controller.create_api_and_tasks()
)

context = {
Expand Down Expand Up @@ -168,21 +158,3 @@ async def interactive_shell(

def __del__(self):
self._stop_scan_tasks()


def build_controller_api(controller: Controller) -> ControllerAPI:
return _build_controller_api(controller, [])


def _build_controller_api(controller: BaseController, path: list[str]) -> ControllerAPI:
return ControllerAPI(
path=path,
attributes=controller.attributes,
command_methods=controller.command_methods,
scan_methods=controller.scan_methods,
sub_apis={
name: _build_controller_api(sub_controller, path + [name])
for name, sub_controller in controller.sub_controllers.items()
},
description=controller.description,
)
1 change: 1 addition & 0 deletions src/fastcs/controllers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .base_controller import BaseController as BaseController
from .controller import Controller as Controller
from .controller_api import ControllerAPI as ControllerAPI
from .controller_vector import ControllerVector as ControllerVector
16 changes: 15 additions & 1 deletion src/fastcs/controllers/base_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections import Counter
from collections.abc import Sequence
from copy import deepcopy
from typing import ( # type: ignore
from typing import (
TypeVar,
_GenericAlias, # type: ignore
get_args,
Expand All @@ -12,6 +12,7 @@
)

from fastcs.attributes import AnyAttributeIO, Attribute, AttrR, AttrW, HintedAttribute
from fastcs.controllers.controller_api import ControllerAPI
from fastcs.logging import logger
from fastcs.methods import Command, Method, Scan, UnboundCommand, UnboundScan
from fastcs.tracer import Tracer
Expand Down Expand Up @@ -388,3 +389,16 @@ def add_scan(self, name: str, scan: Scan):
@property
def scan_methods(self) -> dict[str, Scan]:
return self.__scan_methods

def _build_api(self, path: list[str]) -> ControllerAPI:
return ControllerAPI(
path=path,
attributes=self.attributes,
command_methods=self.command_methods,
scan_methods=self.scan_methods,
sub_apis={
name: sub_controller._build_api(path + [name]) # noqa: SLF001
for name, sub_controller in self.sub_controllers.items()
},
description=self.description,
)
109 changes: 108 additions & 1 deletion src/fastcs/controllers/controller.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
import asyncio
from collections import defaultdict
from collections.abc import Sequence

from fastcs.attributes import AnyAttributeIO
from fastcs.attributes.attr_r import AttrR
from fastcs.attributes.attribute_io_ref import AttributeIORef
from fastcs.controllers.base_controller import BaseController
from fastcs.controllers.controller_api import ControllerAPI
from fastcs.logging import logger
from fastcs.methods import ScanCallback
from fastcs.util import ONCE


class Controller(BaseController):
Expand All @@ -13,6 +21,7 @@ def __init__(
ios: Sequence[AnyAttributeIO] | None = None,
) -> None:
super().__init__(description=description, ios=ios)
self._connected = False

def add_sub_controller(self, name: str, sub_controller: BaseController):
if name.isdigit():
Expand All @@ -23,7 +32,105 @@ def add_sub_controller(self, name: str, sub_controller: BaseController):
return super().add_sub_controller(name, sub_controller)

async def connect(self) -> None:
pass
"""Hook to perform initial connection to device

This should set ``_connected`` to ``True`` if the connection was successful to
enable scan tasks.

"""
self._connected = True

async def reconnect(self):
"""Hook to reconnect to device after an error

This should set ``_connected`` to ``True`` if the connection was successful to
enable scan tasks.

If the connection cannot be re-established it should log an error with the
reason. It should not raise an exception.

"""
self._connected = True

async def disconnect(self) -> None:
"""Hook to tidy up resources before stopping the application"""
pass

def create_api_and_tasks(
self,
) -> tuple[ControllerAPI, list[ScanCallback], list[ScanCallback]]:
"""Create api for transports tasks for FastCS backend

Creates a tuple of
- The `ControllerAPI` for this controller
- Initial coroutines to be run once on startup
- Periodic coroutines to run as background tasks

Returns:
tuple[ControllerAPI, list[ScanCallback], list[ScanCallback]]

"""
controller_api = self._build_api([])

scan_dict: dict[float, list[ScanCallback]] = defaultdict(list)
initial_coros: list[ScanCallback] = []

for api in controller_api.walk_api():
for method in api.scan_methods.values():
if method.period is ONCE:
initial_coros.append(method.fn)
else:
scan_dict[method.period].append(method.fn)

for attribute in api.attributes.values():
match attribute:
case AttrR(_io_ref=AttributeIORef(update_period=update_period)):
if update_period is ONCE:
initial_coros.append(attribute.bind_update_callback())
elif update_period is not None:
scan_dict[update_period].append(
attribute.bind_update_callback()
)

periodic_scan_coros: list[ScanCallback] = []
for period, methods in scan_dict.items():
periodic_scan_coros.append(self._create_periodic_scan_coro(period, methods))

return controller_api, periodic_scan_coros, initial_coros

def _create_periodic_scan_coro(
self, period: float, scans: Sequence[ScanCallback]
) -> ScanCallback:
"""Create a coroutine to run scans at a given period

This returns a coroutine that runs scans at a given period. If an exception is
raised in a callback it is caught and the updates for the controller are
paused, waiting for `_connected` to be set back to true via the `reconnect`
method.

Args:
period: The period to run the scans at
scans: A list of `ScanCallback` to run periodically

Returns:
A wrapper `ScanCallback` that runs all of the callbacks at a given period
"""

async def scan_coro() -> None:
while True:
if not self._connected:
await asyncio.sleep(1)
continue

try:
await asyncio.gather(
asyncio.sleep(period), *[scan() for scan in scans]
)
except Exception:
logger.exception("Exception in scan task", period=period)
self._connected = False

await asyncio.sleep(1) # Wait so this message appears last
logger.error("Pausing scan tasks and waiting for reconnect")

return scan_coro
38 changes: 38 additions & 0 deletions src/fastcs/controllers/controller_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from collections.abc import Iterator
from dataclasses import dataclass, field

from fastcs.attributes import Attribute
from fastcs.methods import Command, Scan


@dataclass
class ControllerAPI:
"""Attributes, Methods and sub APIs of a `Controller` to expose in a transport"""

path: list[str] = field(default_factory=list)
"""Path within controller tree (empty if this is the root)"""
attributes: dict[str, Attribute] = field(default_factory=dict)
command_methods: dict[str, Command] = field(default_factory=dict)
scan_methods: dict[str, Scan] = field(default_factory=dict)
sub_apis: dict[str, "ControllerAPI"] = field(default_factory=dict)
"""APIs of the sub controllers of the `Controller` this API was built from"""
description: str | None = None

def walk_api(self) -> Iterator["ControllerAPI"]:
"""Walk through all the nested `ControllerAPI` s of this `ControllerAPI`.

Yields the `ControllerAPI` s from a depth-first traversal of the tree,
including self.

"""
yield self
for api in self.sub_apis.values():
yield from api.walk_api()

def __repr__(self):
return (
f"ControllerAPI("
f"path={self.path}, "
f"sub_apis=[{', '.join(self.sub_apis.keys())}]"
f")"
)
6 changes: 6 additions & 0 deletions src/fastcs/demo/controllers.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ async def cancel_all(self) -> None:
async def connect(self) -> None:
await self.connection.connect(self._settings.ip_settings)

async def reconnect(self):
await self.connection.close()
await self.connection.connect(self._settings.ip_settings)

self._connected = True
Comment on lines +98 to +102
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Violates base class contract: reconnect() must not raise exceptions.

Per the base Controller.reconnect() docstring (src/fastcs/controllers/controller.py:43-53), this method "should not raise an exception" and should "log an error with the reason" if reconnection fails. The current implementation allows exceptions from close() and connect() to propagate, which will crash the caller in control_system.py since it has no exception handling around reconnect().

🐛 Proposed fix: wrap in try-except and log errors
     async def reconnect(self):
-        await self.connection.close()
-        await self.connection.connect(self._settings.ip_settings)
-
-        self._connected = True
+        try:
+            await self.connection.close()
+        except Exception:
+            pass  # Ignore close errors, connection may already be closed
+
+        try:
+            await self.connection.connect(self._settings.ip_settings)
+            self._connected = True
+        except Exception:
+            logger.exception(
+                "Failed to reconnect to %s", self._settings.ip_settings
+            )

This also requires importing logger at the top of the file if not already present:

import logging

logger = logging.getLogger(__name__)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/fastcs/demo/controllers.py` around lines 98 - 102, reconnect currently
calls self.connection.close() and self.connection.connect(...) without handling
exceptions, violating the Controller.reconnect() contract; wrap the
close/connect calls in a try/except that catches Exception, log the failure with
a descriptive message and exception info (use a module logger via
logging.getLogger(__name__)), and ensure self._connected is only set True on
successful connect and left False on failure (do not re-raise).


async def close(self) -> None:
await self.connection.close()

Expand Down
1 change: 0 additions & 1 deletion src/fastcs/transports/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .controller_api import ControllerAPI as ControllerAPI
from .transport import Transport as Transport

try:
Expand Down
Loading