diff --git a/bindings/python/src/table.rs b/bindings/python/src/table.rs index 660cd6be..130c1cad 100644 --- a/bindings/python/src/table.rs +++ b/bindings/python/src/table.rs @@ -30,6 +30,7 @@ use pyo3::types::{ PyDeltaAccess, PyDict, PyList, PySequence, PySlice, PyTime, PyTimeAccess, PyTuple, PyType, PyTzInfo, }; +use pyo3::{Bound, IntoPyObjectExt, Py, PyAny, PyRef, PyRefMut, PyResult, Python}; use pyo3_async_runtimes::tokio::future_into_py; use std::collections::HashMap; use std::sync::Arc; @@ -1887,7 +1888,7 @@ impl ScannerKind { /// Both `LogScanner` and `RecordBatchLogScanner` share the same subscribe interface. macro_rules! with_scanner { ($scanner:expr, $method:ident($($arg:expr),*)) => { - match $scanner { + match $scanner.as_ref() { ScannerKind::Record(s) => s.$method($($arg),*).await, ScannerKind::Batch(s) => s.$method($($arg),*).await, } } } @@ -1901,7 +1902,7 @@ /// LogScanner supports two scanning modes: /// - Record-based scanning via `poll()` - returns individual records /// - Batch-based scanning via `poll_arrow()` / `poll_record_batch()` - returns Arrow batches #[pyclass] pub struct LogScanner { - scanner: ScannerKind, + kind: Arc<ScannerKind>, admin: fcore::client::FlussAdmin, table_info: fcore::metadata::TableInfo, /// The projected Arrow schema to use for empty table creation @@ -1922,7 +1923,7 @@ impl LogScanner { fn subscribe(&self, py: Python, bucket_id: i32, start_offset: i64) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { - with_scanner!(&self.scanner, subscribe(bucket_id, start_offset)) + with_scanner!(&self.kind, subscribe(bucket_id, start_offset)) .map_err(|e| FlussError::from_core_error(&e)) }) }) } @@ -1935,7 +1936,7 @@ fn subscribe_buckets(&self, py: Python, bucket_offsets: HashMap<i32, i64>) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { - with_scanner!(&self.scanner, subscribe_buckets(&bucket_offsets)) + with_scanner!(&self.kind, subscribe_buckets(&bucket_offsets)) .map_err(|e| FlussError::from_core_error(&e)) }) }) } @@ -1957,7 +1958,7 @@ py.detach(|| { 
TOKIO_RUNTIME.block_on(async { with_scanner!( - &self.scanner, + &self.kind, subscribe_partition(partition_id, bucket_id, start_offset) ) .map_err(|e| FlussError::from_core_error(&e)) @@ -1977,7 +1978,7 @@ impl LogScanner { py.detach(|| { TOKIO_RUNTIME.block_on(async { with_scanner!( - &self.scanner, + &self.kind, subscribe_partition_buckets(&partition_bucket_offsets) ) .map_err(|e| FlussError::from_core_error(&e)) @@ -1992,7 +1993,7 @@ fn unsubscribe(&self, py: Python, bucket_id: i32) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { - with_scanner!(&self.scanner, unsubscribe(bucket_id)) + with_scanner!(&self.kind, unsubscribe(bucket_id)) .map_err(|e| FlussError::from_core_error(&e)) }) }) @@ -2006,11 +2007,8 @@ fn unsubscribe_partition(&self, py: Python, partition_id: i64, bucket_id: i32) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { - with_scanner!( - &self.scanner, - unsubscribe_partition(partition_id, bucket_id) - ) - .map_err(|e| FlussError::from_core_error(&e)) + with_scanner!(&self.kind, unsubscribe_partition(partition_id, bucket_id)) + .map_err(|e| FlussError::from_core_error(&e)) }) }) } @@ -2030,7 +2028,7 @@ impl LogScanner { /// - Returns an empty ScanRecords if no records are available /// - When timeout expires, returns an empty ScanRecords (NOT an error) fn poll(&self, py: Python, timeout_ms: i64) -> PyResult<ScanRecords> { - let scanner = self.scanner.as_record()?; + let scanner = self.kind.as_record()?; if timeout_ms < 0 { return Err(FlussError::new_err(format!( @@ -2079,7 +2077,7 @@ impl LogScanner { /// - Returns an empty list if no batches are available /// - When timeout expires, returns an empty list (NOT an error) fn poll_record_batch(&self, py: Python, timeout_ms: i64) -> PyResult<Vec<RecordBatch>> { - let scanner = self.scanner.as_batch()?; + let scanner = self.kind.as_batch()?; if timeout_ms < 0 { return Err(FlussError::new_err(format!( @@ -2114,7 +2112,7 @@ impl LogScanner { /// - Returns an empty 
table (with correct schema) if no records are available /// - When timeout expires, returns an empty table (NOT an error) fn poll_arrow(&self, py: Python, timeout_ms: i64) -> PyResult<Py<PyAny>> { - let scanner = self.scanner.as_batch()?; + let scanner = self.kind.as_batch()?; if timeout_ms < 0 { return Err(FlussError::new_err(format!( @@ -2167,13 +2165,16 @@ impl LogScanner { /// Returns: /// PyArrow Table containing all data from subscribed buckets fn to_arrow(&self, py: Python) -> PyResult<Py<PyAny>> { - let scanner = self.scanner.as_batch()?; - let subscribed = scanner.get_subscribed_buckets(); - if subscribed.is_empty() { - return Err(FlussError::new_err( - "No buckets subscribed. Call subscribe(), subscribe_buckets(), subscribe_partition(), or subscribe_partition_buckets() first.", - )); - } + let subscribed = { + let scanner = self.kind.as_batch()?; + let subs = scanner.get_subscribed_buckets(); + if subs.is_empty() { + return Err(FlussError::new_err( + "No buckets subscribed. Call subscribe(), subscribe_buckets(), subscribe_partition(), or subscribe_partition_buckets() first.", + )); + } + subs.clone() + }; // 2. 
Query latest offsets for all subscribed buckets let stopping_offsets = self.query_latest_offsets(py, &subscribed)?; @@ -2199,6 +2200,171 @@ impl LogScanner { Ok(df) } + fn __aiter__<'py>(slf: PyRef<'py, Self>) -> PyResult<Bound<'py, PyAny>> { + let py = slf.py(); + + match slf.kind.as_ref() { + ScannerKind::Record(_) => { + static RECORD_ASYNC_GEN_FN: PyOnceLock<Py<PyAny>> = PyOnceLock::new(); + let gen_fn = RECORD_ASYNC_GEN_FN.get_or_init(py, || { + let code = pyo3::ffi::c_str!( + r#" +async def _async_scan(scanner, timeout_ms=1000): + while True: + batch = await scanner._async_poll(timeout_ms) + if batch: + for record in batch: + yield record +"# + ); + let globals = pyo3::types::PyDict::new(py); + py.run(code, Some(&globals), None).unwrap(); + globals.get_item("_async_scan").unwrap().unwrap().unbind() + }); + gen_fn.bind(py).call1((slf.into_bound_py_any(py)?,)) + } + ScannerKind::Batch(_) => { + static BATCH_ASYNC_GEN_FN: PyOnceLock<Py<PyAny>> = PyOnceLock::new(); + let gen_fn = BATCH_ASYNC_GEN_FN.get_or_init(py, || { + let code = pyo3::ffi::c_str!( + r#" +async def _async_batch_scan(scanner, timeout_ms=1000): + while True: + batches = await scanner._async_poll_batches(timeout_ms) + if batches: + for rb in batches: + yield rb +"# + ); + let globals = pyo3::types::PyDict::new(py); + py.run(code, Some(&globals), None).unwrap(); + globals + .get_item("_async_batch_scan") + .unwrap() + .unwrap() + .unbind() + }); + gen_fn.bind(py).call1((slf.into_bound_py_any(py)?,)) + } + } + } + + /// Perform a single bounded poll and return a list of ScanRecord objects. + /// + /// This is the async building block used by `__aiter__` to implement + /// `async for`. Each call does exactly one network poll (bounded by + /// `timeout_ms`), converts any results to Python objects, and returns + /// them as a list. An empty list signals a timeout (no data yet), not + /// end-of-stream. 
+ /// + /// Args: + /// timeout_ms: Timeout in milliseconds for the network poll (default: 1000) + /// + /// Returns: + /// Awaitable that resolves to a list of ScanRecord objects + fn _async_poll<'py>( + &self, + py: Python<'py>, + timeout_ms: Option<i64>, + ) -> PyResult<Bound<'py, PyAny>> { + let timeout_ms = timeout_ms.unwrap_or(1000); + if timeout_ms < 0 { + return Err(FlussError::new_err(format!( + "timeout_ms must be non-negative, got: {timeout_ms}" + ))); + } + + let scanner = Arc::clone(&self.kind); + let projected_row_type = self.projected_row_type.clone(); + let timeout = Duration::from_millis(timeout_ms as u64); + + future_into_py(py, async move { + let core_scanner = match scanner.as_ref() { + ScannerKind::Record(s) => s, + ScannerKind::Batch(_) => { + return Err(PyTypeError::new_err( + "This internal method only supports record-based scanners. \ + For batch-based scanners, use 'async for' or 'poll_record_batch' instead.", + )); + } + }; + + let scan_records = core_scanner + .poll(timeout) + .await + .map_err(|e| FlussError::from_core_error(&e))?; + + // Convert to Python list + Python::attach(|py| { + let mut result: Vec<Py<ScanRecord>> = Vec::new(); + for (_, records) in scan_records.into_records_by_buckets() { + for core_record in records { + let scan_record = + ScanRecord::from_core(py, &core_record, &projected_row_type)?; + result.push(Py::new(py, scan_record)?); + } + } + Ok(result) + }) + }) + } + + /// Perform a single bounded poll and return a list of RecordBatch objects. + /// + /// This is the async building block used by `__aiter__` (batch mode) to + /// implement `async for`. Each call does exactly one network poll (bounded + /// by `timeout_ms`), converts any results to Python RecordBatch objects, + /// and returns them as a list. An empty list signals a timeout (no data + /// yet), not end-of-stream. 
+ /// + /// Args: + /// timeout_ms: Timeout in milliseconds for the network poll (default: 1000) + /// + /// Returns: + /// Awaitable that resolves to a list of RecordBatch objects + fn _async_poll_batches<'py>( + &self, + py: Python<'py>, + timeout_ms: Option, + ) -> PyResult> { + let timeout_ms = timeout_ms.unwrap_or(1000); + if timeout_ms < 0 { + return Err(FlussError::new_err(format!( + "timeout_ms must be non-negative, got: {timeout_ms}" + ))); + } + + let scanner = Arc::clone(&self.kind); + let timeout = Duration::from_millis(timeout_ms as u64); + + future_into_py(py, async move { + let core_scanner = match scanner.as_ref() { + ScannerKind::Batch(s) => s, + ScannerKind::Record(_) => { + return Err(PyTypeError::new_err( + "This internal method only supports batch-based scanners. \ + For record-based scanners, use 'async for' or 'poll' instead.", + )); + } + }; + + let scan_batches = core_scanner + .poll(timeout) + .await + .map_err(|e| FlussError::from_core_error(&e))?; + + // Convert to Python list of RecordBatch objects + Python::attach(|py| { + let mut result: Vec> = Vec::new(); + for scan_batch in scan_batches { + let rb = RecordBatch::from_scan_batch(scan_batch); + result.push(Py::new(py, rb)?); + } + Ok(result) + }) + }) + } + fn __repr__(&self) -> String { format!("LogScanner(table={})", self.table_info.table_path) } @@ -2213,7 +2379,7 @@ impl LogScanner { projected_row_type: fcore::metadata::RowType, ) -> Self { Self { - scanner, + kind: Arc::new(scanner), admin, table_info, projected_schema, @@ -2264,7 +2430,7 @@ impl LogScanner { py: Python, subscribed: &[(fcore::metadata::TableBucket, i64)], ) -> PyResult> { - let scanner = self.scanner.as_batch()?; + let scanner = self.kind.as_batch()?; let is_partitioned = scanner.is_partitioned(); let table_path = &self.table_info.table_path; @@ -2367,7 +2533,7 @@ impl LogScanner { py: Python, mut stopping_offsets: HashMap, ) -> PyResult> { - let scanner = self.scanner.as_batch()?; + let scanner = 
self.kind.as_batch()?; let mut all_batches = Vec::new(); while !stopping_offsets.is_empty() { diff --git a/bindings/python/test/test_log_table.py b/bindings/python/test/test_log_table.py index dd1a4d4f..970a516f 100644 --- a/bindings/python/test/test_log_table.py +++ b/bindings/python/test/test_log_table.py @@ -729,6 +729,923 @@ async def test_scan_records_indexing_and_slicing(connection, admin): await admin.drop_table(table_path, ignore_if_not_exists=False) +async def test_async_iterator(connection, admin): + """Test the Python asynchronous iterator loop (`async for`) on LogScanner.""" + table_path = fluss.TablePath("fluss", "py_test_async_iterator") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + + # Write 5 records + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [pa.array(list(range(1, 6)), type=pa.int32()), + pa.array([f"async{i}" for i in range(1, 6)])], + schema=pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]), + ) + ) + await writer.flush() + + scanner = await table.new_scan().create_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + scanner.subscribe_buckets({i: fluss.EARLIEST_OFFSET for i in range(num_buckets)}) + + collected = [] + + # Here is the magical Issue #424 async iterator logic at work: + async def consume_scanner(): + async for record in scanner: + collected.append(record) + if len(collected) == 5: + break + + # We must race the consumption against a timeout so the test doesn't hang if the iterator is broken + await asyncio.wait_for(consume_scanner(), timeout=10.0) + + assert len(collected) == 5, f"Expected 5 records, got {len(collected)}" + + collected.sort(key=lambda r: r.row["id"]) + for i, 
record in enumerate(collected): + assert record.row["id"] == i + 1 + assert record.row["val"] == f"async{i + 1}" + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_async_iterator_break_no_leak(connection, admin): + """Verify that breaking out of `async for` does not leak resources. + + After breaking, the scanner must still be usable for synchronous + `poll()` calls. If the old implementation's tokio::spawn'd task + were still alive, it would hold the Mutex and cause `poll()` to + deadlock or error. + """ + table_path = fluss.TablePath("fluss", "py_test_async_break_leak") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array(list(range(1, 11)), type=pa.int32()), + pa.array([f"v{i}" for i in range(1, 11)]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + scanner = await table.new_scan().create_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + # Phase 1: async for with early break (collect only 3 of 10) + collected_async = [] + + async def consume_and_break(): + async for record in scanner: + collected_async.append(record) + if len(collected_async) >= 3: + break + + await asyncio.wait_for(consume_and_break(), timeout=10.0) + assert len(collected_async) == 3, ( + f"Expected 3 records from async for, got {len(collected_async)}" + ) + + # Phase 2: sync poll() must still work — proves no leaked task / lock. 
+ # With small data and few buckets, _async_poll may have fetched all + # records in one batch. After break, the un-yielded records from that + # batch are lost. So sync poll may return 0 records — the key assertion + # is that poll() completes without deadlock (returns within timeout). + remaining = scanner.poll(2000) + assert remaining is not None, "poll() should return (not deadlock)" + + # If we got records, verify no duplicates + async_ids = {r.row["id"] for r in collected_async} + sync_ids = {r.row["id"] for r in remaining} + assert async_ids.isdisjoint(sync_ids), ( + f"Duplicate IDs between async and sync: {async_ids & sync_ids}" + ) + + # All IDs must be from the original 1-10 range + all_ids = async_ids | sync_ids + assert all_ids.issubset(set(range(1, 11))), ( + f"Unexpected IDs: {all_ids - set(range(1, 11))}" + ) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_async_iterator_multiple_batches(connection, admin): + """Verify async iteration works across multiple network poll cycles. + + _async_poll does a single bounded poll per call. Writing 20 records + to multiple buckets ensures the Python generator must loop through + several _async_poll calls to collect them all. 
+ """ + table_path = fluss.TablePath("fluss", "py_test_async_multi_batch") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + table_descriptor = fluss.TableDescriptor( + schema, bucket_count=3, bucket_keys=["id"] + ) + await admin.create_table( + table_path, table_descriptor, ignore_if_exists=False + ) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + + num_records = 20 + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array(list(range(1, num_records + 1)), type=pa.int32()), + pa.array([f"multi{i}" for i in range(1, num_records + 1)]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + scanner = await table.new_scan().create_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + collected = [] + + async def consume_all(): + async for record in scanner: + collected.append(record) + if len(collected) >= num_records: + break + + await asyncio.wait_for(consume_all(), timeout=15.0) + assert len(collected) == num_records, ( + f"Expected {num_records} records, got {len(collected)}" + ) + + # Verify all IDs are present (order may vary due to bucketing) + ids = sorted(r.row["id"] for r in collected) + assert ids == list(range(1, num_records + 1)) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_batch_async_iterator(connection, admin): + """Test the Python asynchronous iterator loop (`async for`) on a batch LogScanner. + + With our __aiter__ dispatch, a batch-based scanner should yield RecordBatch + objects (not ScanRecord). Each yielded item has .batch (PyArrow RecordBatch), + .bucket, .base_offset, .last_offset. 
+ """ + table_path = fluss.TablePath("fluss", "py_test_batch_async_iter") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array(list(range(1, 7)), type=pa.int32()), + pa.array([f"bv{i}" for i in range(1, 7)]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + batch_scanner = await table.new_scan().create_record_batch_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + batch_scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + collected_batches = [] + total_rows = 0 + + async def consume_batches(): + nonlocal total_rows + async for rb in batch_scanner: + collected_batches.append(rb) + total_rows += rb.batch.num_rows + if total_rows >= 6: + break + + await asyncio.wait_for(consume_batches(), timeout=15.0) + + assert total_rows >= 6, f"Expected >=6 total rows, got {total_rows}" + assert len(collected_batches) > 0 + + # Verify each yielded item is a RecordBatch with expected attributes + for rb in collected_batches: + assert hasattr(rb, "batch"), "RecordBatch should have .batch" + assert hasattr(rb, "bucket"), "RecordBatch should have .bucket" + assert hasattr(rb, "base_offset"), "RecordBatch should have .base_offset" + assert hasattr(rb, "last_offset"), "RecordBatch should have .last_offset" + # .batch should be a PyArrow RecordBatch + arrow_batch = rb.batch + assert isinstance(arrow_batch, pa.RecordBatch), ( + f"Expected PyArrow RecordBatch, got {type(arrow_batch).__name__}" + ) + assert arrow_batch.num_columns == 2 + assert set(arrow_batch.schema.names) == {"id", "val"} + + 
# Verify all 6 IDs are present + all_ids = [] + for rb in collected_batches: + all_ids.extend(rb.batch.column("id").to_pylist()) + assert sorted(all_ids[:6]) == [1, 2, 3, 4, 5, 6] + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_batch_async_iterator_break_no_leak(connection, admin): + """Verify that breaking out of batch `async for` does not leak resources. + + After breaking, the scanner must still be usable for synchronous + poll_record_batch() calls, proving no leaked task or lock. + """ + table_path = fluss.TablePath("fluss", "py_test_batch_async_break") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array(list(range(1, 11)), type=pa.int32()), + pa.array([f"bl{i}" for i in range(1, 11)]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + batch_scanner = await table.new_scan().create_record_batch_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + batch_scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + # Phase 1: async for with early break (collect just 1 batch) + first_batch = None + + async def consume_and_break(): + nonlocal first_batch + async for rb in batch_scanner: + first_batch = rb + break + + await asyncio.wait_for(consume_and_break(), timeout=10.0) + assert first_batch is not None, "Should have received at least 1 batch" + assert first_batch.batch.num_rows > 0 + + # Phase 2: sync poll_record_batch() must still work — proves no leak + remaining = batch_scanner.poll_record_batch(2000) + assert remaining is not None, 
"poll_record_batch() should return (not deadlock)" + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_batch_async_iterator_multiple_batches(connection, admin): + """Verify batch async iteration works across multiple network poll cycles. + + Writing 20 records to 3 buckets ensures the generator must loop through + several _async_poll_batches calls to collect them all. + """ + table_path = fluss.TablePath("fluss", "py_test_batch_async_multi") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + table_descriptor = fluss.TableDescriptor( + schema, bucket_count=3, bucket_keys=["id"] + ) + await admin.create_table( + table_path, table_descriptor, ignore_if_exists=False + ) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + + num_records = 20 + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array(list(range(1, num_records + 1)), type=pa.int32()), + pa.array([f"bm{i}" for i in range(1, num_records + 1)]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + batch_scanner = await table.new_scan().create_record_batch_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + batch_scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + all_ids = [] + + async def consume_all(): + async for rb in batch_scanner: + all_ids.extend(rb.batch.column("id").to_pylist()) + if len(all_ids) >= num_records: + break + + await asyncio.wait_for(consume_all(), timeout=15.0) + assert len(all_ids) >= num_records, ( + f"Expected >={num_records} IDs, got {len(all_ids)}" + ) + assert sorted(all_ids[:num_records]) == list(range(1, num_records + 1)) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def 
test_async_poll_batches_wrong_scanner_type(connection, admin): + """Verify _async_poll_batches raises TypeError on a record scanner.""" + table_path = fluss.TablePath("fluss", "py_test_apb_wrong_type") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + # Create a RECORD scanner (not batch) + record_scanner = await table.new_scan().create_log_scanner() + record_scanner.subscribe(bucket_id=0, start_offset=0) + + import pytest + + with pytest.raises(TypeError): + await record_scanner._async_poll_batches(1000) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_async_poll_on_batch_scanner_raises_type_error( + connection, admin +): + """Verify _async_poll (record method) raises TypeError on a batch scanner. + + This is the inverse: _async_poll is for records only, _async_poll_batches + is for batches only. Calling the wrong one should raise TypeError. 
+ """ + table_path = fluss.TablePath("fluss", "py_test_apoll_batch_err") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + batch_scanner = await table.new_scan().create_record_batch_log_scanner() + batch_scanner.subscribe(bucket_id=0, start_offset=0) + + import pytest + + with pytest.raises(TypeError): + await batch_scanner._async_poll(1000) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_async_poll_batches_negative_timeout(connection, admin): + """Verify _async_poll_batches rejects a negative timeout_ms with an error.""" + table_path = fluss.TablePath("fluss", "py_test_apb_neg_timeout") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + batch_scanner = await table.new_scan().create_record_batch_log_scanner() + batch_scanner.subscribe(bucket_id=0, start_offset=0) + + import pytest + + with pytest.raises(Exception, match="non-negative"): + await batch_scanner._async_poll_batches(-1) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_async_poll_batches_returns_list(connection, admin): + """Verify _async_poll_batches returns a Python list of RecordBatch objects.""" + table_path = fluss.TablePath("fluss", "py_test_apb_returns_list") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = 
table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array([1, 2, 3], type=pa.int32()), + pa.array(["x", "y", "z"]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + batch_scanner = await table.new_scan().create_record_batch_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + batch_scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + # Poll until we get a non-empty result + result = None + deadline = time.monotonic() + 10 + while time.monotonic() < deadline: + result = await batch_scanner._async_poll_batches(2000) + if result: + break + + assert result is not None, "Expected non-None result from _async_poll_batches" + assert isinstance(result, list), ( + f"Expected list, got {type(result).__name__}" + ) + assert len(result) > 0, "Expected non-empty list" + + # Each element must be a RecordBatch with .batch, .bucket, .base_offset, .last_offset + for rb in result: + assert hasattr(rb, "batch"), "RecordBatch should have .batch" + assert hasattr(rb, "bucket"), "RecordBatch should have .bucket" + assert hasattr(rb, "base_offset"), "RecordBatch should have .base_offset" + assert hasattr(rb, "last_offset"), "RecordBatch should have .last_offset" + assert isinstance(rb.batch, pa.RecordBatch) + assert rb.base_offset >= 0 + assert rb.last_offset >= rb.base_offset + + # An empty poll (no new data) should return an empty list, not None + empty_result = await batch_scanner._async_poll_batches(100) + assert isinstance(empty_result, list), ( + f"Empty poll should return list, got {type(empty_result).__name__}" + ) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_batch_async_iterator_record_batch_metadata(connection, admin): + """Verify that RecordBatch objects yielded by async for contain correct metadata. 
+ + Each RecordBatch must have: + - .bucket with a valid bucket_id + - .base_offset >= 0 + - .last_offset = base_offset + num_rows - 1 (for non-empty batches) + - .batch.num_rows > 0 + """ + table_path = fluss.TablePath("fluss", "py_test_batch_async_meta") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array(list(range(1, 6)), type=pa.int32()), + pa.array([f"m{i}" for i in range(1, 6)]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + batch_scanner = await table.new_scan().create_record_batch_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + batch_scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + collected_batches = [] + total_rows = 0 + + async def consume(): + nonlocal total_rows + async for rb in batch_scanner: + collected_batches.append(rb) + total_rows += rb.batch.num_rows + if total_rows >= 5: + break + + await asyncio.wait_for(consume(), timeout=15.0) + assert total_rows >= 5 + + for rb in collected_batches: + assert rb.batch.num_rows > 0, "Yielded batch should not be empty" + assert rb.base_offset >= 0, "base_offset should be non-negative" + expected_last = rb.base_offset + rb.batch.num_rows - 1 + assert rb.last_offset == expected_last, ( + f"last_offset should be {expected_last}, got {rb.last_offset}" + ) + assert rb.bucket.bucket_id >= 0, "bucket_id should be non-negative" + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_batch_async_iterator_to_pandas(connection, admin): + """Verify end-to-end: async for → RecordBatch 
→ .batch.to_pandas().""" + table_path = fluss.TablePath("fluss", "py_test_batch_async_pandas") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("name", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array([10, 20, 30], type=pa.int32()), + pa.array(["alice", "bob", "charlie"]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("name", pa.string())] + ), + ) + ) + await writer.flush() + + batch_scanner = await table.new_scan().create_record_batch_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + batch_scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + all_dfs = [] + total_rows = 0 + + async def consume(): + nonlocal total_rows + async for rb in batch_scanner: + df = rb.batch.to_pandas() + all_dfs.append(df) + total_rows += len(df) + if total_rows >= 3: + break + + await asyncio.wait_for(consume(), timeout=15.0) + assert total_rows >= 3 + + import pandas as pd + combined = pd.concat(all_dfs, ignore_index=True).sort_values("id").reset_index(drop=True) + assert list(combined.columns) == ["id", "name"] + assert combined["id"].tolist()[:3] == [10, 20, 30] + assert combined["name"].tolist()[:3] == ["alice", "bob", "charlie"] + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_batch_async_iterator_projected_columns(connection, admin): + """Verify batch async for respects column projection.""" + table_path = fluss.TablePath("fluss", "py_test_batch_async_proj") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema( + [ + pa.field("col_a", pa.int32()), + pa.field("col_b", pa.string()), + pa.field("col_c", 
pa.int32()), + ] + ) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array([1, 2, 3], type=pa.int32()), + pa.array(["x", "y", "z"]), + pa.array([10, 20, 30], type=pa.int32()), + ], + schema=pa.schema( + [ + pa.field("col_a", pa.int32()), + pa.field("col_b", pa.string()), + pa.field("col_c", pa.int32()), + ] + ), + ) + ) + await writer.flush() + + # Project only col_b and col_c + proj_scanner = ( + await table.new_scan() + .project_by_name(["col_b", "col_c"]) + .create_record_batch_log_scanner() + ) + num_buckets = (await admin.get_table_info(table_path)).num_buckets + proj_scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + all_batches = [] + total_rows = 0 + + async def consume(): + nonlocal total_rows + async for rb in proj_scanner: + all_batches.append(rb) + total_rows += rb.batch.num_rows + if total_rows >= 3: + break + + await asyncio.wait_for(consume(), timeout=15.0) + assert total_rows >= 3 + + # Verify projected schema: only col_b and col_c, no col_a + for rb in all_batches: + assert set(rb.batch.schema.names) == {"col_b", "col_c"}, ( + f"Projected schema should only have col_b and col_c, " + f"got {rb.batch.schema.names}" + ) + assert rb.batch.num_columns == 2 + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_sync_methods_after_batch_async_iteration(connection, admin): + """Verify sync poll_record_batch() works after batch async iteration. + + This proves no lock contention between the batch async and sync paths. 
+ """ + table_path = fluss.TablePath("fluss", "py_test_sync_after_batch_async") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array(list(range(1, 9)), type=pa.int32()), + pa.array([f"sv{i}" for i in range(1, 9)]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + batch_scanner = await table.new_scan().create_record_batch_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + batch_scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + # Step 1: Collect 1 batch via async for then break + first_batch = None + + async def partial_consume(): + nonlocal first_batch + async for rb in batch_scanner: + first_batch = rb + break + + await asyncio.wait_for(partial_consume(), timeout=10.0) + assert first_batch is not None + + # Step 2: Sync poll_record_batch() must work (no deadlock) + sync_batches = batch_scanner.poll_record_batch(2000) + assert sync_batches is not None, "poll_record_batch() should return (not deadlock)" + + # Step 3: Sync poll_arrow() must also work + arrow_table = batch_scanner.poll_arrow(2000) + assert arrow_table is not None, "poll_arrow() should return (not deadlock)" + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_async_poll_negative_timeout(connection, admin): + """Verify _async_poll rejects a negative timeout_ms with an error.""" + table_path = fluss.TablePath("fluss", "py_test_async_poll_neg_timeout") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", 
pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + scanner = await table.new_scan().create_log_scanner() + scanner.subscribe(bucket_id=0, start_offset=0) + + import pytest + + with pytest.raises(Exception, match="non-negative"): + await scanner._async_poll(-1) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_async_poll_returns_list(connection, admin): + """Verify _async_poll returns a Python list of ScanRecord objects.""" + table_path = fluss.TablePath("fluss", "py_test_async_poll_returns_list") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array([1, 2, 3], type=pa.int32()), + pa.array(["x", "y", "z"]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + scanner = await table.new_scan().create_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + # Poll until we get a non-empty result + result = None + deadline = time.monotonic() + 10 + while time.monotonic() < deadline: + result = await scanner._async_poll(2000) + if result: + break + + assert result is not None, "Expected non-None result from _async_poll" + assert isinstance(result, list), ( + f"Expected list, got {type(result).__name__}" + ) + assert len(result) > 0, "Expected non-empty list" + + # Each element must be a ScanRecord with .row, .offset, .timestamp + for record in result: + assert hasattr(record, 
"row"), "ScanRecord should have .row" + assert hasattr(record, "offset"), "ScanRecord should have .offset" + assert hasattr(record, "timestamp"), ( + "ScanRecord should have .timestamp" + ) + assert "id" in record.row + + # An empty poll (no new data) should return an empty list, not None + empty_result = await scanner._async_poll(100) + assert isinstance(empty_result, list), ( + f"Empty poll should return list, got {type(empty_result).__name__}" + ) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_sync_methods_after_async_iteration(connection, admin): + """Verify sync poll() works correctly interleaved with async iteration. + + This proves there is no lock contention between the async and sync + code paths — the removed Mutex would have caused deadlocks here if + the lock were held across the async poll boundary. + """ + table_path = fluss.TablePath( + "fluss", "py_test_sync_after_async" + ) + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array(list(range(1, 9)), type=pa.int32()), + pa.array([f"s{i}" for i in range(1, 9)]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + scanner = await table.new_scan().create_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + # Step 1: Collect 4 records via async for + async_records = [] + + async def partial_consume(): + async for record in scanner: + async_records.append(record) + if len(async_records) >= 4: + break + + await 
asyncio.wait_for(partial_consume(), timeout=10.0) + assert len(async_records) == 4 + + # Step 2: Collect remaining records via sync poll(). + # With small data, _async_poll may have fetched all records in one + # batch. After break, the un-yielded records are lost. The key + # assertion is that poll() works (no deadlock from a held lock). + sync_records = scanner.poll(2000) + assert sync_records is not None, "poll() should return (not deadlock)" + + # Step 3: Verify no duplicates and all IDs are valid + async_ids = {r.row["id"] for r in async_records} + sync_ids = {r.row["id"] for r in sync_records} + assert async_ids.isdisjoint(sync_ids), ( + f"Duplicate IDs: {async_ids & sync_ids}" + ) + all_ids = async_ids | sync_ids + assert all_ids.issubset(set(range(1, 9))), ( + f"Unexpected IDs: {all_ids - set(range(1, 9))}" + ) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -755,3 +1672,4 @@ def _poll_arrow_ids(scanner, expected_count, timeout_s=10): if arrow_table.num_rows > 0: all_ids.extend(arrow_table.column("id").to_pylist()) return all_ids +