From 768266d5792b91eec936792ced801a5274c3c2ae Mon Sep 17 00:00:00 2001 From: Jared Yu Date: Sun, 8 Mar 2026 20:51:18 -0700 Subject: [PATCH 1/8] feat: add async 'for' loop support to LogScanner (#424) --- bindings/python/fluss/__init__.pyi | 8 +- bindings/python/src/table.rs | 154 ++++++++++++++++++++++--- bindings/python/test/test_log_table.py | 49 ++++++++ 3 files changed, 188 insertions(+), 23 deletions(-) diff --git a/bindings/python/fluss/__init__.pyi b/bindings/python/fluss/__init__.pyi index 417ac9b2..2534f638 100644 --- a/bindings/python/fluss/__init__.pyi +++ b/bindings/python/fluss/__init__.pyi @@ -125,7 +125,9 @@ class ScanRecords: def __getitem__(self, index: slice) -> List[ScanRecord]: ... @overload def __getitem__(self, bucket: TableBucket) -> List[ScanRecord]: ... - def __getitem__(self, key: Union[int, slice, TableBucket]) -> Union[ScanRecord, List[ScanRecord]]: ... + def __getitem__( + self, key: Union[int, slice, TableBucket] + ) -> Union[ScanRecord, List[ScanRecord]]: ... def __contains__(self, bucket: TableBucket) -> bool: ... def __iter__(self) -> Iterator[ScanRecord]: ... def __str__(self) -> str: ... @@ -369,7 +371,6 @@ class FlussAdmin: ... def __repr__(self) -> str: ... - class DatabaseDescriptor: """Descriptor for a Fluss database (comment and custom properties).""" @@ -383,7 +384,6 @@ class DatabaseDescriptor: def get_custom_properties(self) -> Dict[str, str]: ... def __repr__(self) -> str: ... - class DatabaseInfo: """Information about a Fluss database.""" @@ -604,7 +604,6 @@ class UpsertWriter: ... def __repr__(self) -> str: ... - class WriteResultHandle: """Handle for a pending write (append/upsert/delete). Ignore for fire-and-forget, or await handle.wait() for ack.""" @@ -613,7 +612,6 @@ class WriteResultHandle: ... def __repr__(self) -> str: ... - class Lookuper: """Lookuper for performing primary key lookups on a Fluss table.""" diff --git a/bindings/python/src/table.rs b/bindings/python/src/table.rs index 660cd6be..b68492e1 100644 --- a/bindings/python/src/table.rs +++ b/bindings/python/src/table.rs @@ -30,6 +30,9 @@ use pyo3::types::{ PyDeltaAccess, PyDict, PyList, PySequence, PySlice, PyTime, PyTimeAccess, PyTuple, PyType, PyTzInfo, }; +use pyo3::{ + Bound, IntoPyObjectExt, Py, PyAny, PyClassInitializer, PyErr, PyRef, PyRefMut, PyResult, Python, +}; use pyo3_async_runtimes::tokio::future_into_py; use std::collections::HashMap; use std::sync::Arc; @@ -1863,6 +1866,13 @@ enum ScannerKind { Batch(fcore::client::RecordBatchLogScanner), } +/// The internal state of the scanner, protected by a Tokio Mutex for async cross-thread sharing +struct ScannerState { + kind: ScannerKind, + /// A buffer to hold records polled from the network before yielding them one-by-one to Python + pending_records: std::collections::VecDeque>, +} + impl ScannerKind { fn as_record(&self) -> PyResult<&fcore::client::LogScanner> { match self { @@ -1901,7 +1911,7 @@ macro_rules! with_scanner { /// - Batch-based scanning via `poll_arrow()` / `poll_record_batch()` - returns Arrow batches #[pyclass] pub struct LogScanner { - scanner: ScannerKind, + state: Arc>, admin: fcore::client::FlussAdmin, table_info: fcore::metadata::TableInfo, /// The projected Arrow schema to use for empty table creation @@ -1922,7 +1932,8 @@ impl LogScanner { fn subscribe(&self, py: Python, bucket_id: i32, start_offset: i64) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { - with_scanner!(&self.scanner, subscribe(bucket_id, start_offset)) + let state = self.state.lock().await; + with_scanner!(&state.kind, subscribe(bucket_id, start_offset)) .map_err(|e| FlussError::from_core_error(&e)) }) }) @@ -1935,7 +1946,8 @@ impl LogScanner { fn subscribe_buckets(&self, py: Python, bucket_offsets: HashMap) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { - with_scanner!(&self.scanner, subscribe_buckets(&bucket_offsets)) + let state = self.state.lock().await; + with_scanner!(&state.kind, subscribe_buckets(&bucket_offsets)) .map_err(|e| FlussError::from_core_error(&e)) }) }) @@ -1956,8 +1968,9 @@ impl LogScanner { ) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { + let state = self.state.lock().await; with_scanner!( - &self.scanner, + &state.kind, subscribe_partition(partition_id, bucket_id, start_offset) ) .map_err(|e| FlussError::from_core_error(&e)) @@ -1976,8 +1989,9 @@ impl LogScanner { ) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { + let state = self.state.lock().await; with_scanner!( - &self.scanner, + &state.kind, subscribe_partition_buckets(&partition_bucket_offsets) ) .map_err(|e| FlussError::from_core_error(&e)) @@ -1992,7 +2006,8 @@ impl LogScanner { fn unsubscribe(&self, py: Python, bucket_id: i32) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { - with_scanner!(&self.scanner, unsubscribe(bucket_id)) + let state = self.state.lock().await; + with_scanner!(&state.kind, unsubscribe(bucket_id)) .map_err(|e| FlussError::from_core_error(&e)) }) }) @@ -2006,11 +2021,9 @@ impl LogScanner { fn unsubscribe_partition(&self, py: Python, partition_id: i64, bucket_id: i32) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { - with_scanner!( - &self.scanner, - unsubscribe_partition(partition_id, bucket_id) - ) - .map_err(|e| FlussError::from_core_error(&e)) + let state = self.state.lock().await; + with_scanner!(&state.kind, unsubscribe_partition(partition_id, bucket_id)) + .map_err(|e| FlussError::from_core_error(&e)) }) }) } @@ -2030,7 +2043,10 @@ impl LogScanner { /// - Returns an empty ScanRecords if no records are available /// - When timeout expires, returns an empty ScanRecords (NOT an error) fn poll(&self, py: Python, timeout_ms: i64) -> PyResult { - let scanner = self.scanner.as_record()?; + let scanner_ref = + unsafe { &*(&self.state as *const std::sync::Arc>) }; + let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await }); + let scanner = lock.kind.as_record()?; if timeout_ms < 0 { return Err(FlussError::new_err(format!( @@ -2079,7 +2095,10 @@ impl LogScanner { /// - Returns an empty list if no batches are available /// - When timeout expires, returns an empty list (NOT an error) fn poll_record_batch(&self, py: Python, timeout_ms: i64) -> PyResult> { - let scanner = self.scanner.as_batch()?; + let scanner_ref = + unsafe { &*(&self.state as *const std::sync::Arc>) }; + let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await }); + let scanner = lock.kind.as_batch()?; if timeout_ms < 0 { return Err(FlussError::new_err(format!( @@ -2114,7 +2133,10 @@ impl LogScanner { /// - Returns an empty table (with correct schema) if no records are available /// - When timeout expires, returns an empty table (NOT an error) fn poll_arrow(&self, py: Python, timeout_ms: i64) -> PyResult> { - let scanner = self.scanner.as_batch()?; + let scanner_ref = + unsafe { &*(&self.state as *const std::sync::Arc>) }; + let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await }); + let scanner = lock.kind.as_batch()?; if timeout_ms < 0 { return Err(FlussError::new_err(format!( @@ -2167,7 +2189,10 @@ impl LogScanner { /// Returns: /// PyArrow Table containing all data from subscribed buckets fn to_arrow(&self, py: Python) -> PyResult> { - let scanner = self.scanner.as_batch()?; + let scanner_ref = + unsafe { &*(&self.state as *const std::sync::Arc>) }; + let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await }); + let scanner = lock.kind.as_batch()?; let subscribed = scanner.get_subscribed_buckets(); if subscribed.is_empty() { return Err(FlussError::new_err( @@ -2199,6 +2224,90 @@ impl LogScanner { Ok(df) } + fn __aiter__<'py>(slf: PyRef<'py, Self>) -> PyResult> { + let py = slf.py(); + let code = pyo3::ffi::c_str!( + r#" +async def _adapter(obj): + while True: + try: + yield await obj.__anext__() + except StopAsyncIteration: + break +"# + ); + let globals = pyo3::types::PyDict::new(py); + py.run(code, Some(&globals), None)?; + let adapter = globals.get_item("_adapter")?.unwrap(); + // Return adapt(self) + adapter.call1((slf.into_bound_py_any(py)?,)) + } + + fn __anext__<'py>(slf: PyRefMut<'py, Self>) -> PyResult>> { + let state_arc = slf.state.clone(); + let projected_row_type = slf.projected_row_type.clone(); + let py = slf.py(); + + let future = future_into_py(py, async move { + let mut state = state_arc.lock().await; + + // 1. If we already have buffered records, pop and return immediately + if let Some(record) = state.pending_records.pop_front() { + return Ok(record.into_any()); + } + + // 2. Buffer is empty, we must poll the network for the next batch + // The underlying kind must be a Record-based scanner. + let scanner = match state.kind.as_record() { + Ok(s) => s, + Err(_) => { + return Err(pyo3::exceptions::PyStopAsyncIteration::new_err( + "Stream Ended", + )); + } + }; + + // Poll with a reasonable internal timeout before unblocking the event loop + let timeout = core::time::Duration::from_millis(5000); + + let mut current_records = scanner + .poll(timeout) + .await + .map_err(|e| FlussError::from_core_error(&e))?; + + // If it's a real timeout with zero records, loop or throw StopAsyncIteration? + // Since it's a streaming log, we can yield None or block. Blocking requires a loop in the future. + while current_records.is_empty() { + current_records = scanner + .poll(timeout) + .await + .map_err(|e| FlussError::from_core_error(&e))?; + } + + // Now we have records. + Python::attach(|py| { + for (_, records) in current_records.into_records_by_buckets() { + for core_record in records { + let scan_record = + ScanRecord::from_core(py, &core_record, &projected_row_type)?; + state.pending_records.push_back(Py::new(py, scan_record)?); + } + } + + // Pop the very first one to return right now + if let Some(record) = state.pending_records.pop_front() { + Ok(record.into_any()) + } else { + Err(pyo3::exceptions::PyStopAsyncIteration::new_err( + "Stream Ended", + )) + } + }) + })?; + + Ok(Some(future)) + } + fn __repr__(&self) -> String { format!("LogScanner(table={})", self.table_info.table_path) } @@ -2213,7 +2322,10 @@ impl LogScanner { projected_row_type: fcore::metadata::RowType, ) -> Self { Self { - scanner, + state: std::sync::Arc::new(tokio::sync::Mutex::new(ScannerState { + kind: scanner, + pending_records: std::collections::VecDeque::new(), + })), admin, table_info, projected_schema, @@ -2264,7 +2376,10 @@ impl LogScanner { py: Python, subscribed: &[(fcore::metadata::TableBucket, i64)], ) -> PyResult> { - let scanner = self.scanner.as_batch()?; + let scanner_ref = + unsafe { &*(&self.state as *const std::sync::Arc>) }; + let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await }); + let scanner = lock.kind.as_batch()?; let is_partitioned = scanner.is_partitioned(); let table_path = &self.table_info.table_path; @@ -2367,7 +2482,10 @@ impl LogScanner { py: Python, mut stopping_offsets: HashMap, ) -> PyResult> { - let scanner = self.scanner.as_batch()?; + let scanner_ref = + unsafe { &*(&self.state as *const std::sync::Arc>) }; + let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await }); + let scanner = lock.kind.as_batch()?; let mut all_batches = Vec::new(); while !stopping_offsets.is_empty() { diff --git a/bindings/python/test/test_log_table.py b/bindings/python/test/test_log_table.py index dd1a4d4f..2f9588b0 100644 --- a/bindings/python/test/test_log_table.py +++ b/bindings/python/test/test_log_table.py @@ -729,6 +729,55 @@ async def test_scan_records_indexing_and_slicing(connection, admin): await admin.drop_table(table_path, ignore_if_not_exists=False) +async def test_async_iterator(connection, admin): + """Test the Python asynchronous iterator loop (`async for`) on LogScanner.""" + table_path = fluss.TablePath("fluss", "py_test_async_iterator") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + + # Write 5 records + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [pa.array(list(range(1, 6)), type=pa.int32()), + pa.array([f"async{i}" for i in range(1, 6)])], + schema=pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]), + ) + ) + await writer.flush() + + scanner = await table.new_scan().create_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + scanner.subscribe_buckets({i: fluss.EARLIEST_OFFSET for i in range(num_buckets)}) + + collected = [] + + # Here is the magical Issue #424 async iterator logic at work: + async def consume_scanner(): + async for record in scanner: + collected.append(record) + if len(collected) == 5: + break + + # We must race the consumption against a timeout so the test doesn't hang if the iterator is broken + await asyncio.wait_for(consume_scanner(), timeout=10.0) + + assert len(collected) == 5, f"Expected 5 records, got {len(collected)}" + + collected.sort(key=lambda r: r.row["id"]) + for i, record in enumerate(collected): + assert record.row["id"] == i + 1 + assert record.row["val"] == f"async{i + 1}" + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- From 0e01b8b7a453a691f49d5382cb308b71fa83d650 Mon Sep 17 00:00:00 2001 From: Jared Yu Date: Sun, 8 Mar 2026 21:22:53 -0700 Subject: [PATCH 2/8] chore: revert formatting changes to __init__.pyi --- bindings/python/fluss/__init__.pyi | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/bindings/python/fluss/__init__.pyi b/bindings/python/fluss/__init__.pyi index 2534f638..417ac9b2 100644 --- a/bindings/python/fluss/__init__.pyi +++ b/bindings/python/fluss/__init__.pyi @@ -125,9 +125,7 @@ class ScanRecords: def __getitem__(self, index: slice) -> List[ScanRecord]: ... @overload def __getitem__(self, bucket: TableBucket) -> List[ScanRecord]: ... - def __getitem__( - self, key: Union[int, slice, TableBucket] - ) -> Union[ScanRecord, List[ScanRecord]]: ... + def __getitem__(self, key: Union[int, slice, TableBucket]) -> Union[ScanRecord, List[ScanRecord]]: ... def __contains__(self, bucket: TableBucket) -> bool: ... def __iter__(self) -> Iterator[ScanRecord]: ... def __str__(self) -> str: ... @@ -371,6 +369,7 @@ class FlussAdmin: ... def __repr__(self) -> str: ... + class DatabaseDescriptor: """Descriptor for a Fluss database (comment and custom properties).""" @@ -384,6 +383,7 @@ class DatabaseDescriptor: def get_custom_properties(self) -> Dict[str, str]: ... def __repr__(self) -> str: ... + class DatabaseInfo: """Information about a Fluss database.""" @@ -604,6 +604,7 @@ class UpsertWriter: ... def __repr__(self) -> str: ... + class WriteResultHandle: """Handle for a pending write (append/upsert/delete). Ignore for fire-and-forget, or await handle.wait() for ack.""" @@ -612,6 +613,7 @@ class WriteResultHandle: ... def __repr__(self) -> str: ... + class Lookuper: """Lookuper for performing primary key lookups on a Fluss table.""" From 3aa067b2fbda018c5b19ce972c4c03aafa01e1a4 Mon Sep 17 00:00:00 2001 From: Jared Yu Date: Mon, 9 Mar 2026 07:12:28 -0700 Subject: [PATCH 3/8] fix: remove unused PyClassInitializer and PyErr imports --- bindings/python/src/table.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bindings/python/src/table.rs b/bindings/python/src/table.rs index b68492e1..10f77463 100644 --- a/bindings/python/src/table.rs +++ b/bindings/python/src/table.rs @@ -31,7 +31,7 @@ use pyo3::types::{ PyTzInfo, }; use pyo3::{ - Bound, IntoPyObjectExt, Py, PyAny, PyClassInitializer, PyErr, PyRef, PyRefMut, PyResult, Python, + Bound, IntoPyObjectExt, Py, PyAny, PyRef, PyRefMut, PyResult, Python, }; use pyo3_async_runtimes::tokio::future_into_py; use std::collections::HashMap; From 1065665717ec19494e17a13de1415c8606bf9e70 Mon Sep 17 00:00:00 2001 From: Jared Yu Date: Mon, 9 Mar 2026 19:37:43 -0700 Subject: [PATCH 4/8] style: apply cargo fmt --- bindings/python/src/table.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/bindings/python/src/table.rs b/bindings/python/src/table.rs index 10f77463..1d66a6e3 100644 --- a/bindings/python/src/table.rs +++ b/bindings/python/src/table.rs @@ -30,9 +30,7 @@ use pyo3::types::{ PyDeltaAccess, PyDict, PyList, PySequence, PySlice, PyTime, PyTimeAccess, PyTuple, PyType, PyTzInfo, }; -use pyo3::{ - Bound, IntoPyObjectExt, Py, PyAny, PyRef, PyRefMut, PyResult, Python, -}; +use pyo3::{Bound, IntoPyObjectExt, Py, PyAny, PyRef, PyRefMut, PyResult, Python}; use pyo3_async_runtimes::tokio::future_into_py; use std::collections::HashMap; use std::sync::Arc; From 195ec7cfc1fbefa50e56ab1c760b345fdea4eaca Mon Sep 17 00:00:00 2001 From: Jared Yu Date: Tue, 10 Mar 2026 10:34:43 -0700 Subject: [PATCH 5/8] refactor: release scanner lock earlier by cloning subscribed buckets within a local scope in `to_arrow` --- bindings/python/src/table.rs | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/bindings/python/src/table.rs b/bindings/python/src/table.rs index 1d66a6e3..28ccbc76 100644 --- a/bindings/python/src/table.rs +++ b/bindings/python/src/table.rs @@ -2187,16 +2187,20 @@ impl LogScanner { /// Returns: /// PyArrow Table containing all data from subscribed buckets fn to_arrow(&self, py: Python) -> PyResult> { - let scanner_ref = - unsafe { &*(&self.state as *const std::sync::Arc>) }; - let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await }); - let scanner = lock.kind.as_batch()?; - let subscribed = scanner.get_subscribed_buckets(); - if subscribed.is_empty() { - return Err(FlussError::new_err( - "No buckets subscribed. Call subscribe(), subscribe_buckets(), subscribe_partition(), or subscribe_partition_buckets() first.", - )); - } + let subscribed = { + let scanner_ref = unsafe { + &*(&self.state as *const std::sync::Arc>) + }; + let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await }); + let scanner = lock.kind.as_batch()?; + let subs = scanner.get_subscribed_buckets(); + if subs.is_empty() { + return Err(FlussError::new_err( + "No buckets subscribed. Call subscribe(), subscribe_buckets(), subscribe_partition(), or subscribe_partition_buckets() first.", + )); + } + subs.clone() + }; // 2. Query latest offsets for all subscribed buckets let stopping_offsets = self.query_latest_offsets(py, &subscribed)?; From 4ad2fd86fa710f3cd66f6e840ab2eef02168df6c Mon Sep 17 00:00:00 2001 From: Jared Yu Date: Wed, 11 Mar 2026 16:45:51 -0700 Subject: [PATCH 6/8] refactor: Remove Mutex and utilize __aiter__ with _async_poll(timeout_ms) instead --- bindings/python/src/table.rs | 180 +++++-------- bindings/python/test/test_log_table.py | 341 +++++++++++++++++++++++++ 2 files changed, 412 insertions(+), 109 deletions(-) diff --git a/bindings/python/src/table.rs b/bindings/python/src/table.rs index 28ccbc76..64c06d30 100644 --- a/bindings/python/src/table.rs +++ b/bindings/python/src/table.rs @@ -1864,13 +1864,6 @@ enum ScannerKind { Batch(fcore::client::RecordBatchLogScanner), } -/// The internal state of the scanner, protected by a Tokio Mutex for async cross-thread sharing -struct ScannerState { - kind: ScannerKind, - /// A buffer to hold records polled from the network before yielding them one-by-one to Python - pending_records: std::collections::VecDeque>, -} - impl ScannerKind { fn as_record(&self) -> PyResult<&fcore::client::LogScanner> { match self { @@ -1895,7 +1888,7 @@ impl ScannerKind { /// Both `LogScanner` and `RecordBatchLogScanner` share the same subscribe interface. macro_rules! with_scanner { ($scanner:expr, $method:ident($($arg:expr),*)) => { - match $scanner { + match $scanner.as_ref() { ScannerKind::Record(s) => s.$method($($arg),*).await, ScannerKind::Batch(s) => s.$method($($arg),*).await, } @@ -1909,7 +1902,7 @@ macro_rules! with_scanner { /// - Batch-based scanning via `poll_arrow()` / `poll_record_batch()` - returns Arrow batches #[pyclass] pub struct LogScanner { - state: Arc>, + kind: Arc, admin: fcore::client::FlussAdmin, table_info: fcore::metadata::TableInfo, /// The projected Arrow schema to use for empty table creation @@ -1930,8 +1923,7 @@ impl LogScanner { fn subscribe(&self, py: Python, bucket_id: i32, start_offset: i64) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { - let state = self.state.lock().await; - with_scanner!(&state.kind, subscribe(bucket_id, start_offset)) + with_scanner!(&self.kind, subscribe(bucket_id, start_offset)) .map_err(|e| FlussError::from_core_error(&e)) }) }) @@ -1944,8 +1936,7 @@ impl LogScanner { fn subscribe_buckets(&self, py: Python, bucket_offsets: HashMap) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { - let state = self.state.lock().await; - with_scanner!(&state.kind, subscribe_buckets(&bucket_offsets)) + with_scanner!(&self.kind, subscribe_buckets(&bucket_offsets)) .map_err(|e| FlussError::from_core_error(&e)) }) }) @@ -1966,9 +1957,8 @@ impl LogScanner { ) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { - let state = self.state.lock().await; with_scanner!( - &state.kind, + &self.kind, subscribe_partition(partition_id, bucket_id, start_offset) ) .map_err(|e| FlussError::from_core_error(&e)) @@ -1987,9 +1977,8 @@ impl LogScanner { ) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { - let state = self.state.lock().await; with_scanner!( - &state.kind, + &self.kind, subscribe_partition_buckets(&partition_bucket_offsets) ) .map_err(|e| FlussError::from_core_error(&e)) @@ -2004,8 +1993,7 @@ impl LogScanner { fn unsubscribe(&self, py: Python, bucket_id: i32) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { - let state = self.state.lock().await; - with_scanner!(&state.kind, unsubscribe(bucket_id)) + with_scanner!(&self.kind, unsubscribe(bucket_id)) .map_err(|e| FlussError::from_core_error(&e)) }) }) @@ -2019,8 +2007,7 @@ impl LogScanner { fn unsubscribe_partition(&self, py: Python, partition_id: i64, bucket_id: i32) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { - let state = self.state.lock().await; - with_scanner!(&state.kind, unsubscribe_partition(partition_id, bucket_id)) + with_scanner!(&self.kind, unsubscribe_partition(partition_id, bucket_id)) .map_err(|e| FlussError::from_core_error(&e)) }) }) @@ -2041,10 +2028,7 @@ impl LogScanner { /// - Returns an empty ScanRecords if no records are available /// - When timeout expires, returns an empty ScanRecords (NOT an error) fn poll(&self, py: Python, timeout_ms: i64) -> PyResult { - let scanner_ref = - unsafe { &*(&self.state as *const std::sync::Arc>) }; - let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await }); - let scanner = lock.kind.as_record()?; + let scanner = self.kind.as_record()?; if timeout_ms < 0 { return Err(FlussError::new_err(format!( @@ -2093,10 +2077,7 @@ impl LogScanner { /// - Returns an empty list if no batches are available /// - When timeout expires, returns an empty list (NOT an error) fn poll_record_batch(&self, py: Python, timeout_ms: i64) -> PyResult> { - let scanner_ref = - unsafe { &*(&self.state as *const std::sync::Arc>) }; - let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await }); - let scanner = lock.kind.as_batch()?; + let scanner = self.kind.as_batch()?; if timeout_ms < 0 { return Err(FlussError::new_err(format!( @@ -2131,10 +2112,7 @@ impl LogScanner { /// - Returns an empty table (with correct schema) if no records are available /// - When timeout expires, returns an empty table (NOT an error) fn poll_arrow(&self, py: Python, timeout_ms: i64) -> PyResult> { - let scanner_ref = - unsafe { &*(&self.state as *const std::sync::Arc>) }; - let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await }); - let scanner = lock.kind.as_batch()?; + let scanner = self.kind.as_batch()?; if timeout_ms < 0 { return Err(FlussError::new_err(format!( @@ -2188,11 +2166,7 @@ impl LogScanner { /// PyArrow Table containing all data from subscribed buckets fn to_arrow(&self, py: Python) -> PyResult> { let subscribed = { - let scanner_ref = unsafe { - &*(&self.state as *const std::sync::Arc>) - }; - let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await }); - let scanner = lock.kind.as_batch()?; + let scanner = self.kind.as_batch()?; let subs = scanner.get_subscribed_buckets(); if subs.is_empty() { return Err(FlussError::new_err( @@ -2227,87 +2201,84 @@ impl LogScanner { } fn __aiter__<'py>(slf: PyRef<'py, Self>) -> PyResult> { + static ASYNC_GEN_FN: PyOnceLock> = PyOnceLock::new(); let py = slf.py(); - let code = pyo3::ffi::c_str!( - r#" -async def _adapter(obj): + let gen_fn = ASYNC_GEN_FN.get_or_init(py, || { + let code = pyo3::ffi::c_str!( + r#" +async def _async_scan(scanner, timeout_ms=1000): while True: - try: - yield await obj.__anext__() - except StopAsyncIteration: - break + batch = await scanner._async_poll(timeout_ms) + if batch: + for record in batch: + yield record "# - ); - let globals = pyo3::types::PyDict::new(py); - py.run(code, Some(&globals), None)?; - let adapter = globals.get_item("_adapter")?.unwrap(); - // Return adapt(self) - adapter.call1((slf.into_bound_py_any(py)?,)) + ); + let globals = pyo3::types::PyDict::new(py); + py.run(code, Some(&globals), None).unwrap(); + globals.get_item("_async_scan").unwrap().unwrap().unbind() + }); + gen_fn.bind(py).call1((slf.into_bound_py_any(py)?,)) } - fn __anext__<'py>(slf: PyRefMut<'py, Self>) -> PyResult>> { - let state_arc = slf.state.clone(); - let projected_row_type = slf.projected_row_type.clone(); - let py = slf.py(); - - let future = future_into_py(py, async move { - let mut state = state_arc.lock().await; + /// Perform a single bounded poll and return a list of ScanRecord objects. + /// + /// This is the async building block used by `__aiter__` to implement + /// `async for`. Each call does exactly one network poll (bounded by + /// `timeout_ms`), converts any results to Python objects, and returns + /// them as a list. An empty list signals a timeout (no data yet), not + /// end-of-stream. + /// + /// Args: + /// timeout_ms: Timeout in milliseconds for the network poll (default: 1000) + /// + /// Returns: + /// Awaitable that resolves to a list of ScanRecord objects + fn _async_poll<'py>( + &self, + py: Python<'py>, + timeout_ms: Option, + ) -> PyResult> { + let timeout_ms = timeout_ms.unwrap_or(1000); + if timeout_ms < 0 { + return Err(FlussError::new_err(format!( + "timeout_ms must be non-negative, got: {timeout_ms}" + ))); + } - // 1. If we already have buffered records, pop and return immediately - if let Some(record) = state.pending_records.pop_front() { - return Ok(record.into_any()); - } + let scanner = Arc::clone(&self.kind); + let projected_row_type = self.projected_row_type.clone(); + let timeout = Duration::from_millis(timeout_ms as u64); - // 2. Buffer is empty, we must poll the network for the next batch - // The underlying kind must be a Record-based scanner. - let scanner = match state.kind.as_record() { - Ok(s) => s, - Err(_) => { - return Err(pyo3::exceptions::PyStopAsyncIteration::new_err( - "Stream Ended", + future_into_py(py, async move { + let core_scanner = match scanner.as_ref() { + ScannerKind::Record(s) => s, + ScannerKind::Batch(_) => { + return Err(PyTypeError::new_err( + "Async iteration is only supported for record scanners; \ + use create_log_scanner() instead.", )); } }; - // Poll with a reasonable internal timeout before unblocking the event loop - let timeout = core::time::Duration::from_millis(5000); - - let mut current_records = scanner + let scan_records = core_scanner .poll(timeout) .await .map_err(|e| FlussError::from_core_error(&e))?; - // If it's a real timeout with zero records, loop or throw StopAsyncIteration? - // Since it's a streaming log, we can yield None or block. Blocking requires a loop in the future. - while current_records.is_empty() { - current_records = scanner - .poll(timeout) - .await - .map_err(|e| FlussError::from_core_error(&e))?; - } - - // Now we have records. + // Convert to Python list Python::attach(|py| { - for (_, records) in current_records.into_records_by_buckets() { + let mut result: Vec> = Vec::new(); + for (_, records) in scan_records.into_records_by_buckets() { for core_record in records { let scan_record = ScanRecord::from_core(py, &core_record, &projected_row_type)?; - state.pending_records.push_back(Py::new(py, scan_record)?); + result.push(Py::new(py, scan_record)?); } } - - // Pop the very first one to return right now - if let Some(record) = state.pending_records.pop_front() { - Ok(record.into_any()) - } else { - Err(pyo3::exceptions::PyStopAsyncIteration::new_err( - "Stream Ended", - )) - } + Ok(result) }) - })?; - - Ok(Some(future)) + }) } fn __repr__(&self) -> String { @@ -2324,10 +2295,7 @@ impl LogScanner { projected_row_type: fcore::metadata::RowType, ) -> Self { Self { - state: std::sync::Arc::new(tokio::sync::Mutex::new(ScannerState { - kind: scanner, - pending_records: std::collections::VecDeque::new(), - })), + kind: Arc::new(scanner), admin, table_info, projected_schema, @@ -2378,10 +2346,7 @@ impl LogScanner { py: Python, subscribed: &[(fcore::metadata::TableBucket, i64)], ) -> PyResult> { - let scanner_ref = - unsafe { &*(&self.state as *const std::sync::Arc>) }; - let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await }); - let scanner = lock.kind.as_batch()?; + let scanner = self.kind.as_batch()?; let is_partitioned = scanner.is_partitioned(); let table_path = &self.table_info.table_path; @@ -2484,10 +2449,7 @@ impl LogScanner { py: Python, mut stopping_offsets: HashMap, ) -> PyResult> { - let scanner_ref = - unsafe { &*(&self.state as *const std::sync::Arc>) }; - let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await }); - let scanner = lock.kind.as_batch()?; + let scanner = self.kind.as_batch()?; let mut all_batches = Vec::new(); while !stopping_offsets.is_empty() { diff --git a/bindings/python/test/test_log_table.py b/bindings/python/test/test_log_table.py index 2f9588b0..8cf43fb4 100644 --- a/bindings/python/test/test_log_table.py +++ b/bindings/python/test/test_log_table.py @@ -778,6 +778,347 @@ async def consume_scanner(): await admin.drop_table(table_path, ignore_if_not_exists=False) +async def test_async_iterator_break_no_leak(connection, admin): + """Verify that breaking out of `async for` does not leak resources. + + After breaking, the scanner must still be usable for synchronous + `poll()` calls. If the old implementation's tokio::spawn'd task + were still alive, it would hold the Mutex and cause `poll()` to + deadlock or error. + """ + table_path = fluss.TablePath("fluss", "py_test_async_break_leak") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array(list(range(1, 11)), type=pa.int32()), + pa.array([f"v{i}" for i in range(1, 11)]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + scanner = await table.new_scan().create_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + # Phase 1: async for with early break (collect only 3 of 10) + collected_async = [] + + async def consume_and_break(): + async for record in scanner: + collected_async.append(record) + if len(collected_async) >= 3: + break + + await asyncio.wait_for(consume_and_break(), timeout=10.0) + assert len(collected_async) == 3, ( + f"Expected 3 records from async for, got {len(collected_async)}" + ) + + # Phase 2: sync poll() must still work — proves no leaked task / lock. + # With small data and few buckets, _async_poll may have fetched all + # records in one batch. After break, the un-yielded records from that + # batch are lost. So sync poll may return 0 records — the key assertion + # is that poll() completes without deadlock (returns within timeout). + remaining = scanner.poll(2000) + assert remaining is not None, "poll() should return (not deadlock)" + + # If we got records, verify no duplicates + async_ids = {r.row["id"] for r in collected_async} + sync_ids = {r.row["id"] for r in remaining} + assert async_ids.isdisjoint(sync_ids), ( + f"Duplicate IDs between async and sync: {async_ids & sync_ids}" + ) + + # All IDs must be from the original 1-10 range + all_ids = async_ids | sync_ids + assert all_ids.issubset(set(range(1, 11))), ( + f"Unexpected IDs: {all_ids - set(range(1, 11))}" + ) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_async_iterator_multiple_batches(connection, admin): + """Verify async iteration works across multiple network poll cycles. + + _async_poll does a single bounded poll per call. Writing 20 records + to multiple buckets ensures the Python generator must loop through + several _async_poll calls to collect them all. + """ + table_path = fluss.TablePath("fluss", "py_test_async_multi_batch") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + table_descriptor = fluss.TableDescriptor( + schema, bucket_count=3, bucket_keys=["id"] + ) + await admin.create_table( + table_path, table_descriptor, ignore_if_exists=False + ) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + + num_records = 20 + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array(list(range(1, num_records + 1)), type=pa.int32()), + pa.array([f"multi{i}" for i in range(1, num_records + 1)]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + scanner = await table.new_scan().create_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + collected = [] + + async def consume_all(): + async for record in scanner: + collected.append(record) + if len(collected) >= num_records: + break + + await asyncio.wait_for(consume_all(), timeout=15.0) + assert len(collected) == num_records, ( + f"Expected {num_records} records, got {len(collected)}" + ) + + # Verify all IDs are present (order may vary due to bucketing) + ids = sorted(r.row["id"] for r in collected) + assert ids == list(range(1, num_records + 1)) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_async_iterator_batch_scanner_raises_type_error( + connection, admin +): + """Verify that using `async for` on a batch scanner raises TypeError.""" + table_path = fluss.TablePath("fluss", "py_test_async_batch_error") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + + # Write some data so there's something to iterate + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array([1, 2, 3], type=pa.int32()), + pa.array(["a", "b", "c"]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + # Create a BATCH scanner (not a record scanner) + batch_scanner = await table.new_scan().create_record_batch_log_scanner() + batch_scanner.subscribe(bucket_id=0, start_offset=0) + + # Attempting async for on a batch scanner must raise TypeError + import pytest + + with pytest.raises(TypeError): + + async def try_iterate(): + async for _ in batch_scanner: + pass + + await asyncio.wait_for(try_iterate(), timeout=5.0) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_async_poll_negative_timeout(connection, admin): + """Verify _async_poll rejects a negative timeout_ms with an error.""" + table_path = fluss.TablePath("fluss", "py_test_async_poll_neg_timeout") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + scanner = await table.new_scan().create_log_scanner() + scanner.subscribe(bucket_id=0, start_offset=0) + + import pytest + + with pytest.raises(Exception, match="non-negative"): + await scanner._async_poll(-1) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_async_poll_returns_list(connection, admin): + """Verify _async_poll returns a Python list of ScanRecord objects.""" + table_path = fluss.TablePath("fluss", "py_test_async_poll_returns_list") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array([1, 2, 3], type=pa.int32()), + pa.array(["x", "y", "z"]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + scanner = await table.new_scan().create_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + # Poll until we get a non-empty result + result = None + deadline = time.monotonic() + 10 + while time.monotonic() < deadline: + result = await scanner._async_poll(2000) + if result: + break + + assert result is not None, "Expected non-None result from _async_poll" + assert isinstance(result, list), ( + f"Expected list, got {type(result).__name__}" + ) + assert len(result) > 0, "Expected non-empty list" + + # Each element must be a ScanRecord with .row, .offset, .timestamp + for record in result: + assert hasattr(record, "row"), "ScanRecord should have .row" + assert hasattr(record, "offset"), "ScanRecord should have .offset" + assert hasattr(record, "timestamp"), ( + "ScanRecord should have .timestamp" + ) + assert "id" in record.row + + # An empty poll (no new data) should return an empty list, not None + empty_result = await scanner._async_poll(100) + assert isinstance(empty_result, list), ( + f"Empty poll should return list, got {type(empty_result).__name__}" + ) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_sync_methods_after_async_iteration(connection, admin): + """Verify sync poll() works correctly interleaved with async iteration. + + This proves there is no lock contention between the async and sync + code paths — the removed Mutex would have caused deadlocks here if + the lock were held across the async poll boundary. + """ + table_path = fluss.TablePath( + "fluss", "py_test_sync_after_async" + ) + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array(list(range(1, 9)), type=pa.int32()), + pa.array([f"s{i}" for i in range(1, 9)]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + scanner = await table.new_scan().create_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + # Step 1: Collect 4 records via async for + async_records = [] + + async def partial_consume(): + async for record in scanner: + async_records.append(record) + if len(async_records) >= 4: + break + + await asyncio.wait_for(partial_consume(), timeout=10.0) + assert len(async_records) == 4 + + # Step 2: Collect remaining records via sync poll(). + # With small data, _async_poll may have fetched all records in one + # batch. After break, the un-yielded records are lost. The key + # assertion is that poll() works (no deadlock from a held lock). + sync_records = scanner.poll(2000) + assert sync_records is not None, "poll() should return (not deadlock)" + + # Step 3: Verify no duplicates and all IDs are valid + async_ids = {r.row["id"] for r in async_records} + sync_ids = {r.row["id"] for r in sync_records} + assert async_ids.isdisjoint(sync_ids), ( + f"Duplicate IDs: {async_ids & sync_ids}" + ) + all_ids = async_ids | sync_ids + assert all_ids.issubset(set(range(1, 9))), ( + f"Unexpected IDs: {all_ids - set(range(1, 9))}" + ) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- From 08eef133d39b0bccf076f5909042423833ff7250 Mon Sep 17 00:00:00 2001 From: Jared Yu Date: Thu, 12 Mar 2026 22:31:03 -0700 Subject: [PATCH 7/8] feat: add create_record_batch_log_scanner() --- bindings/python/src/table.rs | 104 ++++- bindings/python/test/test_log_table.py | 556 ++++++++++++++++++++++++- 2 files changed, 636 insertions(+), 24 deletions(-) diff --git a/bindings/python/src/table.rs b/bindings/python/src/table.rs index 64c06d30..1dddddbd 100644 --- a/bindings/python/src/table.rs +++ b/bindings/python/src/table.rs @@ -2201,11 +2201,14 @@ impl LogScanner { } fn __aiter__<'py>(slf: PyRef<'py, Self>) -> PyResult> { - static ASYNC_GEN_FN: PyOnceLock> = PyOnceLock::new(); let py = slf.py(); - let gen_fn = ASYNC_GEN_FN.get_or_init(py, || { - let code = pyo3::ffi::c_str!( - r#" + + match slf.kind.as_ref() { + ScannerKind::Record(_) => { + static RECORD_ASYNC_GEN_FN: PyOnceLock> = PyOnceLock::new(); + let gen_fn = RECORD_ASYNC_GEN_FN.get_or_init(py, || { + let code = pyo3::ffi::c_str!( + r#" async def _async_scan(scanner, timeout_ms=1000): while True: batch = await scanner._async_poll(timeout_ms) @@ -2213,12 +2216,37 @@ async def _async_scan(scanner, timeout_ms=1000): for record in batch: yield record "# - ); - let globals = pyo3::types::PyDict::new(py); - py.run(code, Some(&globals), None).unwrap(); - globals.get_item("_async_scan").unwrap().unwrap().unbind() - }); - gen_fn.bind(py).call1((slf.into_bound_py_any(py)?,)) + ); + let globals = pyo3::types::PyDict::new(py); + py.run(code, Some(&globals), None).unwrap(); + globals.get_item("_async_scan").unwrap().unwrap().unbind() + }); + gen_fn.bind(py).call1((slf.into_bound_py_any(py)?,)) + } + ScannerKind::Batch(_) => { + static BATCH_ASYNC_GEN_FN: PyOnceLock> = PyOnceLock::new(); + let gen_fn = BATCH_ASYNC_GEN_FN.get_or_init(py, || { + let code = pyo3::ffi::c_str!( + r#" +async def _async_batch_scan(scanner, timeout_ms=1000): + while True: + batches = await scanner._async_poll_batches(timeout_ms) + if batches: + for rb in batches: + yield rb +"# + ); + let globals = pyo3::types::PyDict::new(py); + py.run(code, Some(&globals), None).unwrap(); + globals + .get_item("_async_batch_scan") + .unwrap() + .unwrap() + .unbind() + }); + gen_fn.bind(py).call1((slf.into_bound_py_any(py)?,)) + } + } } /// Perform a single bounded poll and return a list of ScanRecord objects. @@ -2281,6 +2309,62 @@ async def _async_scan(scanner, timeout_ms=1000): }) } + /// Perform a single bounded poll and return a list of RecordBatch objects. + /// + /// This is the async building block used by `__aiter__` (batch mode) to + /// implement `async for`. Each call does exactly one network poll (bounded + /// by `timeout_ms`), converts any results to Python RecordBatch objects, + /// and returns them as a list. An empty list signals a timeout (no data + /// yet), not end-of-stream. + /// + /// Args: + /// timeout_ms: Timeout in milliseconds for the network poll (default: 1000) + /// + /// Returns: + /// Awaitable that resolves to a list of RecordBatch objects + fn _async_poll_batches<'py>( + &self, + py: Python<'py>, + timeout_ms: Option, + ) -> PyResult> { + let timeout_ms = timeout_ms.unwrap_or(1000); + if timeout_ms < 0 { + return Err(FlussError::new_err(format!( + "timeout_ms must be non-negative, got: {timeout_ms}" + ))); + } + + let scanner = Arc::clone(&self.kind); + let timeout = Duration::from_millis(timeout_ms as u64); + + future_into_py(py, async move { + let core_scanner = match scanner.as_ref() { + ScannerKind::Batch(s) => s, + ScannerKind::Record(_) => { + return Err(PyTypeError::new_err( + "Batch async iteration is only supported for batch scanners; \ + use create_record_batch_log_scanner() instead.", + )); + } + }; + + let scan_batches = core_scanner + .poll(timeout) + .await + .map_err(|e| FlussError::from_core_error(&e))?; + + // Convert to Python list of RecordBatch objects + Python::attach(|py| { + let mut result: Vec> = Vec::new(); + for scan_batch in scan_batches { + let rb = RecordBatch::from_scan_batch(scan_batch); + result.push(Py::new(py, rb)?); + } + Ok(result) + }) + }) + } + fn __repr__(&self) -> String { format!("LogScanner(table={})", self.table_info.table_path) } diff --git a/bindings/python/test/test_log_table.py b/bindings/python/test/test_log_table.py index 8cf43fb4..970a516f 100644 --- a/bindings/python/test/test_log_table.py +++ b/bindings/python/test/test_log_table.py @@ -916,11 +916,14 @@ async def consume_all(): await admin.drop_table(table_path, ignore_if_not_exists=False) -async def test_async_iterator_batch_scanner_raises_type_error( - connection, admin -): - """Verify that using `async for` on a batch scanner raises TypeError.""" - table_path = fluss.TablePath("fluss", "py_test_async_batch_error") +async def test_batch_async_iterator(connection, admin): + """Test the Python asynchronous iterator loop (`async for`) on a batch LogScanner. + + With our __aiter__ dispatch, a batch-based scanner should yield RecordBatch + objects (not ScanRecord). Each yielded item has .batch (PyArrow RecordBatch), + .bucket, .base_offset, .last_offset. + """ + table_path = fluss.TablePath("fluss", "py_test_batch_async_iter") await admin.drop_table(table_path, ignore_if_not_exists=True) schema = fluss.Schema( @@ -929,14 +932,148 @@ async def test_async_iterator_batch_scanner_raises_type_error( await admin.create_table(table_path, fluss.TableDescriptor(schema)) table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array(list(range(1, 7)), type=pa.int32()), + pa.array([f"bv{i}" for i in range(1, 7)]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + batch_scanner = await table.new_scan().create_record_batch_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + batch_scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + collected_batches = [] + total_rows = 0 + + async def consume_batches(): + nonlocal total_rows + async for rb in batch_scanner: + collected_batches.append(rb) + total_rows += rb.batch.num_rows + if total_rows >= 6: + break + + await asyncio.wait_for(consume_batches(), timeout=15.0) + + assert total_rows >= 6, f"Expected >=6 total rows, got {total_rows}" + assert len(collected_batches) > 0 + + # Verify each yielded item is a RecordBatch with expected attributes + for rb in collected_batches: + assert hasattr(rb, "batch"), "RecordBatch should have .batch" + assert hasattr(rb, "bucket"), "RecordBatch should have .bucket" + assert hasattr(rb, "base_offset"), "RecordBatch should have .base_offset" + assert hasattr(rb, "last_offset"), "RecordBatch should have .last_offset" + # .batch should be a PyArrow RecordBatch + arrow_batch = rb.batch + assert isinstance(arrow_batch, pa.RecordBatch), ( + f"Expected PyArrow RecordBatch, got {type(arrow_batch).__name__}" + ) + assert arrow_batch.num_columns == 2 + assert set(arrow_batch.schema.names) == {"id", "val"} + + # Verify all 6 IDs are present + all_ids = [] + for rb in collected_batches: + all_ids.extend(rb.batch.column("id").to_pylist()) + assert sorted(all_ids[:6]) == [1, 2, 3, 4, 5, 6] + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_batch_async_iterator_break_no_leak(connection, admin): + """Verify that breaking out of batch `async for` does not leak resources. + + After breaking, the scanner must still be usable for synchronous + poll_record_batch() calls, proving no leaked task or lock. + """ + table_path = fluss.TablePath("fluss", "py_test_batch_async_break") + await admin.drop_table(table_path, ignore_if_not_exists=True) - # Write some data so there's something to iterate + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) writer = table.new_append().create_writer() writer.write_arrow_batch( pa.RecordBatch.from_arrays( [ - pa.array([1, 2, 3], type=pa.int32()), - pa.array(["a", "b", "c"]), + pa.array(list(range(1, 11)), type=pa.int32()), + pa.array([f"bl{i}" for i in range(1, 11)]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + batch_scanner = await table.new_scan().create_record_batch_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + batch_scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + # Phase 1: async for with early break (collect just 1 batch) + first_batch = None + + async def consume_and_break(): + nonlocal first_batch + async for rb in batch_scanner: + first_batch = rb + break + + await asyncio.wait_for(consume_and_break(), timeout=10.0) + assert first_batch is not None, "Should have received at least 1 batch" + assert first_batch.batch.num_rows > 0 + + # Phase 2: sync poll_record_batch() must still work — proves no leak + remaining = batch_scanner.poll_record_batch(2000) + assert remaining is not None, "poll_record_batch() should return (not deadlock)" + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_batch_async_iterator_multiple_batches(connection, admin): + """Verify batch async iteration works across multiple network poll cycles. + + Writing 20 records to 3 buckets ensures the generator must loop through + several _async_poll_batches calls to collect them all. + """ + table_path = fluss.TablePath("fluss", "py_test_batch_async_multi") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + table_descriptor = fluss.TableDescriptor( + schema, bucket_count=3, bucket_keys=["id"] + ) + await admin.create_table( + table_path, table_descriptor, ignore_if_exists=False + ) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + + num_records = 20 + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array(list(range(1, num_records + 1)), type=pa.int32()), + pa.array([f"bm{i}" for i in range(1, num_records + 1)]), ], schema=pa.schema( [pa.field("id", pa.int32()), pa.field("val", pa.string())] @@ -945,20 +1082,410 @@ async def test_async_iterator_batch_scanner_raises_type_error( ) await writer.flush() - # Create a BATCH scanner (not a record scanner) + batch_scanner = await table.new_scan().create_record_batch_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + batch_scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + all_ids = [] + + async def consume_all(): + async for rb in batch_scanner: + all_ids.extend(rb.batch.column("id").to_pylist()) + if len(all_ids) >= num_records: + break + + await asyncio.wait_for(consume_all(), timeout=15.0) + assert len(all_ids) >= num_records, ( + f"Expected >={num_records} IDs, got {len(all_ids)}" + ) + assert sorted(all_ids[:num_records]) == list(range(1, num_records + 1)) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_async_poll_batches_wrong_scanner_type(connection, admin): + """Verify _async_poll_batches raises TypeError on a record scanner.""" + table_path = fluss.TablePath("fluss", "py_test_apb_wrong_type") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + # Create a RECORD scanner (not batch) + record_scanner = await table.new_scan().create_log_scanner() + record_scanner.subscribe(bucket_id=0, start_offset=0) + + import pytest + + with pytest.raises(TypeError): + await record_scanner._async_poll_batches(1000) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_async_poll_on_batch_scanner_raises_type_error( + connection, admin +): + """Verify _async_poll (record method) raises TypeError on a batch scanner. + + This is the inverse: _async_poll is for records only, _async_poll_batches + is for batches only. Calling the wrong one should raise TypeError. + """ + table_path = fluss.TablePath("fluss", "py_test_apoll_batch_err") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) batch_scanner = await table.new_scan().create_record_batch_log_scanner() batch_scanner.subscribe(bucket_id=0, start_offset=0) - # Attempting async for on a batch scanner must raise TypeError import pytest with pytest.raises(TypeError): + await batch_scanner._async_poll(1000) - async def try_iterate(): - async for _ in batch_scanner: - pass + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_async_poll_batches_negative_timeout(connection, admin): + """Verify _async_poll_batches rejects a negative timeout_ms with an error.""" + table_path = fluss.TablePath("fluss", "py_test_apb_neg_timeout") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + batch_scanner = await table.new_scan().create_record_batch_log_scanner() + batch_scanner.subscribe(bucket_id=0, start_offset=0) + + import pytest - await asyncio.wait_for(try_iterate(), timeout=5.0) + with pytest.raises(Exception, match="non-negative"): + await batch_scanner._async_poll_batches(-1) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_async_poll_batches_returns_list(connection, admin): + """Verify _async_poll_batches returns a Python list of RecordBatch objects.""" + table_path = fluss.TablePath("fluss", "py_test_apb_returns_list") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array([1, 2, 3], type=pa.int32()), + pa.array(["x", "y", "z"]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + batch_scanner = await table.new_scan().create_record_batch_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + batch_scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + # Poll until we get a non-empty result + result = None + deadline = time.monotonic() + 10 + while time.monotonic() < deadline: + result = await batch_scanner._async_poll_batches(2000) + if result: + break + + assert result is not None, "Expected non-None result from _async_poll_batches" + assert isinstance(result, list), ( + f"Expected list, got {type(result).__name__}" + ) + assert len(result) > 0, "Expected non-empty list" + + # Each element must be a RecordBatch with .batch, .bucket, .base_offset, .last_offset + for rb in result: + assert hasattr(rb, "batch"), "RecordBatch should have .batch" + assert hasattr(rb, "bucket"), "RecordBatch should have .bucket" + assert hasattr(rb, "base_offset"), "RecordBatch should have .base_offset" + assert hasattr(rb, "last_offset"), "RecordBatch should have .last_offset" + assert isinstance(rb.batch, pa.RecordBatch) + assert rb.base_offset >= 0 + assert rb.last_offset >= rb.base_offset + + # An empty poll (no new data) should return an empty list, not None + empty_result = await batch_scanner._async_poll_batches(100) + assert isinstance(empty_result, list), ( + f"Empty poll should return list, got {type(empty_result).__name__}" + ) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_batch_async_iterator_record_batch_metadata(connection, admin): + """Verify that RecordBatch objects yielded by async for contain correct metadata. + + Each RecordBatch must have: + - .bucket with a valid bucket_id + - .base_offset >= 0 + - .last_offset = base_offset + num_rows - 1 (for non-empty batches) + - .batch.num_rows > 0 + """ + table_path = fluss.TablePath("fluss", "py_test_batch_async_meta") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array(list(range(1, 6)), type=pa.int32()), + pa.array([f"m{i}" for i in range(1, 6)]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + batch_scanner = await table.new_scan().create_record_batch_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + batch_scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + collected_batches = [] + total_rows = 0 + + async def consume(): + nonlocal total_rows + async for rb in batch_scanner: + collected_batches.append(rb) + total_rows += rb.batch.num_rows + if total_rows >= 5: + break + + await asyncio.wait_for(consume(), timeout=15.0) + assert total_rows >= 5 + + for rb in collected_batches: + assert rb.batch.num_rows > 0, "Yielded batch should not be empty" + assert rb.base_offset >= 0, "base_offset should be non-negative" + expected_last = rb.base_offset + rb.batch.num_rows - 1 + assert rb.last_offset == expected_last, ( + f"last_offset should be {expected_last}, got {rb.last_offset}" + ) + assert rb.bucket.bucket_id >= 0, "bucket_id should be non-negative" + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_batch_async_iterator_to_pandas(connection, admin): + """Verify end-to-end: async for → RecordBatch → .batch.to_pandas().""" + table_path = fluss.TablePath("fluss", "py_test_batch_async_pandas") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("name", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array([10, 20, 30], type=pa.int32()), + pa.array(["alice", "bob", "charlie"]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("name", pa.string())] + ), + ) + ) + await writer.flush() + + batch_scanner = await table.new_scan().create_record_batch_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + batch_scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + all_dfs = [] + total_rows = 0 + + async def consume(): + nonlocal total_rows + async for rb in batch_scanner: + df = rb.batch.to_pandas() + all_dfs.append(df) + total_rows += len(df) + if total_rows >= 3: + break + + await asyncio.wait_for(consume(), timeout=15.0) + assert total_rows >= 3 + + import pandas as pd + combined = pd.concat(all_dfs, ignore_index=True).sort_values("id").reset_index(drop=True) + assert list(combined.columns) == ["id", "name"] + assert combined["id"].tolist()[:3] == [10, 20, 30] + assert combined["name"].tolist()[:3] == ["alice", "bob", "charlie"] + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_batch_async_iterator_projected_columns(connection, admin): + """Verify batch async for respects column projection.""" + table_path = fluss.TablePath("fluss", "py_test_batch_async_proj") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema( + [ + pa.field("col_a", pa.int32()), + pa.field("col_b", pa.string()), + pa.field("col_c", pa.int32()), + ] + ) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array([1, 2, 3], type=pa.int32()), + pa.array(["x", "y", "z"]), + pa.array([10, 20, 30], type=pa.int32()), + ], + schema=pa.schema( + [ + pa.field("col_a", pa.int32()), + pa.field("col_b", pa.string()), + pa.field("col_c", pa.int32()), + ] + ), + ) + ) + await writer.flush() + + # Project only col_b and col_c + proj_scanner = ( + await table.new_scan() + .project_by_name(["col_b", "col_c"]) + .create_record_batch_log_scanner() + ) + num_buckets = (await admin.get_table_info(table_path)).num_buckets + proj_scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + all_batches = [] + total_rows = 0 + + async def consume(): + nonlocal total_rows + async for rb in proj_scanner: + all_batches.append(rb) + total_rows += rb.batch.num_rows + if total_rows >= 3: + break + + await asyncio.wait_for(consume(), timeout=15.0) + assert total_rows >= 3 + + # Verify projected schema: only col_b and col_c, no col_a + for rb in all_batches: + assert set(rb.batch.schema.names) == {"col_b", "col_c"}, ( + f"Projected schema should only have col_b and col_c, " + f"got {rb.batch.schema.names}" + ) + assert rb.batch.num_columns == 2 + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_sync_methods_after_batch_async_iteration(connection, admin): + """Verify sync poll_record_batch() works after batch async iteration. + + This proves no lock contention between the batch async and sync paths. + """ + table_path = fluss.TablePath("fluss", "py_test_sync_after_batch_async") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array(list(range(1, 9)), type=pa.int32()), + pa.array([f"sv{i}" for i in range(1, 9)]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + batch_scanner = await table.new_scan().create_record_batch_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + batch_scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + # Step 1: Collect 1 batch via async for then break + first_batch = None + + async def partial_consume(): + nonlocal first_batch + async for rb in batch_scanner: + first_batch = rb + break + + await asyncio.wait_for(partial_consume(), timeout=10.0) + assert first_batch is not None + + # Step 2: Sync poll_record_batch() must work (no deadlock) + sync_batches = batch_scanner.poll_record_batch(2000) + assert sync_batches is not None, "poll_record_batch() should return (not deadlock)" + + # Step 3: Sync poll_arrow() must also work + arrow_table = batch_scanner.poll_arrow(2000) + assert arrow_table is not None, "poll_arrow() should return (not deadlock)" await admin.drop_table(table_path, ignore_if_not_exists=False) @@ -1145,3 +1672,4 @@ def _poll_arrow_ids(scanner, expected_count, timeout_s=10): if arrow_table.num_rows > 0: all_ids.extend(arrow_table.column("id").to_pylist()) return all_ids + From 68426a073090a8982f0e440612def37c295a8ba2 Mon Sep 17 00:00:00 2001 From: Jared Yu Date: Thu, 12 Mar 2026 22:38:03 -0700 Subject: [PATCH 8/8] chore: update error message for _async_poll and _async_poll_batches so that they match when talking about Record vs Batch --- bindings/python/src/table.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bindings/python/src/table.rs b/bindings/python/src/table.rs index 1dddddbd..130c1cad 100644 --- a/bindings/python/src/table.rs +++ b/bindings/python/src/table.rs @@ -2283,8 +2283,8 @@ async def _async_batch_scan(scanner, timeout_ms=1000): ScannerKind::Record(s) => s, ScannerKind::Batch(_) => { return Err(PyTypeError::new_err( - "Async iteration is only supported for record scanners; \ - use create_log_scanner() instead.", + "This internal method only supports record-based scanners. \ + For batch-based scanners, use 'async for' or 'poll_record_batch' instead.", )); } }; @@ -2342,8 +2342,8 @@ async def _async_batch_scan(scanner, timeout_ms=1000): ScannerKind::Batch(s) => s, ScannerKind::Record(_) => { return Err(PyTypeError::new_err( - "Batch async iteration is only supported for batch scanners; \ - use create_record_batch_log_scanner() instead.", + "This internal method only supports batch-based scanners. \ + For record-based scanners, use 'async for' or 'poll' instead.", )); } };