diff --git a/contract-tests/client_entity.py b/contract-tests/client_entity.py index 6bf957d8..ee594650 100644 --- a/contract-tests/client_entity.py +++ b/contract-tests/client_entity.py @@ -6,6 +6,7 @@ import requests from big_segment_store_fixture import BigSegmentStoreFixture +from flag_change_listener import ListenerRegistry from hook import PostingHook from ldclient import * @@ -158,6 +159,7 @@ def __init__(self, tag, config): config = Config(**opts) self.client = client.LDClient(config, start_wait / 1000.0) + self.listeners = ListenerRegistry(self.client.flag_tracker) def is_initializing(self) -> bool: return self.client.is_initialized() @@ -282,7 +284,26 @@ def fn(payload) -> Result: result = migrator.write(params["key"], Context.from_dict(params["context"]), Stage.from_str(params["defaultStage"]), params["payload"]) return {"result": result.authoritative.value if result.authoritative.is_success() else result.authoritative.error} + def register_flag_change_listener(self, params: dict): + self.listeners.register_flag_change_listener( + listener_id=params['listenerId'], + callback_uri=params['callbackUri'], + ) + + def register_flag_value_change_listener(self, params: dict): + self.listeners.register_flag_value_change_listener( + listener_id=params['listenerId'], + flag_key=params['flagKey'], + context=Context.from_dict(params['context']), + default_value=params['defaultValue'], + callback_uri=params['callbackUri'], + ) + + def unregister_listener(self, params: dict) -> bool: + return self.listeners.unregister(params['listenerId']) + def close(self): + self.listeners.close_all() self.client.close() self.log.info('Test ended') diff --git a/contract-tests/flag_change_listener.py b/contract-tests/flag_change_listener.py new file mode 100644 index 00000000..23b15404 --- /dev/null +++ b/contract-tests/flag_change_listener.py @@ -0,0 +1,90 @@ +import logging +import threading +from typing import Callable, Dict + +import requests + +from ldclient.context import Context +from ldclient.interfaces import FlagChange, FlagTracker, FlagValueChange + +log = logging.getLogger('testservice') + + +class ListenerRegistry: + """Manages all active flag change listener registrations for a single SDK client entity.""" + + def __init__(self, tracker: FlagTracker): + self._tracker = tracker + self._lock = threading.Lock() + # Maps listener_id -> (sdk_listener callable, cleanup function) + self._listeners: Dict[str, Callable] = {} + + def register_flag_change_listener(self, listener_id: str, callback_uri: str): + """Register a general flag change listener that fires on any flag configuration change.""" + def on_flag_change(flag_change: FlagChange): + payload = { + 'listenerId': listener_id, + 'flagKey': flag_change.key, + } + try: + requests.post(callback_uri, json=payload) + except Exception as e: + log.warning('Failed to post flag change notification: %s', e) + + with self._lock: + # If a listener with this ID already exists, unregister the old one first + if listener_id in self._listeners: + self._tracker.remove_listener(self._listeners[listener_id]) + + self._tracker.add_listener(on_flag_change) + self._listeners[listener_id] = on_flag_change + + def register_flag_value_change_listener( + self, + listener_id: str, + flag_key: str, + context: Context, + default_value, + callback_uri: str, + ): + """Register a flag value change listener that fires when the evaluated value changes.""" + def on_value_change(change: FlagValueChange): + payload = { + 'listenerId': listener_id, + 'flagKey': change.key, + 'oldValue': change.old_value, + 'newValue': change.new_value, + } + try: + requests.post(callback_uri, json=payload) + except Exception as e: + log.warning('Failed to post flag value change notification: %s', e) + + # add_flag_value_change_listener returns the underlying listener + # that must be passed to remove_listener to unsubscribe + with self._lock: + if listener_id in self._listeners: + self._tracker.remove_listener(self._listeners[listener_id]) + + underlying_listener = self._tracker.add_flag_value_change_listener(flag_key, context, on_value_change) + self._listeners[listener_id] = underlying_listener + + def unregister(self, listener_id: str) -> bool: + """Unregister a previously registered listener. Returns False if not found.""" + with self._lock: + listener = self._listeners.pop(listener_id, None) + + if listener is None: + return False + + self._tracker.remove_listener(listener) + return True + + def close_all(self): + """Unregister all listeners. Called when the SDK client entity shuts down.""" + with self._lock: + listeners = dict(self._listeners) + self._listeners.clear() + + for listener in listeners.values(): + self._tracker.remove_listener(listener) diff --git a/contract-tests/service.py b/contract-tests/service.py index 699dec07..7b023bcf 100644 --- a/contract-tests/service.py +++ b/contract-tests/service.py @@ -82,6 +82,8 @@ def status(): 'persistent-data-store-redis', 'persistent-data-store-dynamodb', 'persistent-data-store-consul', + 'flag-change-listeners', + 'flag-value-change-listeners', ] } return json.dumps(body), 200, {'Content-type': 'application/json'} @@ -150,6 +152,13 @@ def post_client_command(id): response = client.migration_variation(sub_params) elif command == "migrationOperation": response = client.migration_operation(sub_params) + elif command == "registerFlagChangeListener": + client.register_flag_change_listener(sub_params) + elif command == "registerFlagValueChangeListener": + client.register_flag_value_change_listener(sub_params) + elif command == "unregisterListener": + if not client.unregister_listener(sub_params): + return 'no listener with id "%s"' % sub_params['listenerId'], 400 else: return '', 400 diff --git a/ldclient/testing/impl/datasystem/test_fdv2_datasystem.py b/ldclient/testing/impl/datasystem/test_fdv2_datasystem.py index 2a71f58a..c90bcc83 100644 --- a/ldclient/testing/impl/datasystem/test_fdv2_datasystem.py +++ b/ldclient/testing/impl/datasystem/test_fdv2_datasystem.py @@ -266,14 +266,10 @@ def test_fdv2_falls_back_to_fdv1_on_polling_success_with_header(): changed = Event() changes: List[FlagChange] = [] - count = 0 def listener(flag_change: FlagChange): - nonlocal count - count += 1 changes.append(flag_change) - if count >= 2: - changed.set() + changed.set() set_on_ready = Event() fdv2 = FDv2(Config(sdk_key="dummy"), data_system_config) @@ -282,11 +278,11 @@ def listener(flag_change: FlagChange): assert set_on_ready.wait(1), "Data system did not become ready in time" - # Trigger a flag update in FDv1 + # Update flag in FDv1 data source to verify it's being used td_fdv1.update(td_fdv1.flag("fdv1-fallback-flag").on(False)) - assert changed.wait(1), "Flag change listener was not called in time" + assert changed.wait(2), "Flag change listener was not called in time" - # Verify FDv1 is active + # Verify we got flag changes from FDv1 assert len(changes) > 0 assert any(c.key == "fdv1-fallback-flag" for c in changes)