Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions arrow-format/FlightSql.proto
Original file line number Diff line number Diff line change
Expand Up @@ -1550,6 +1550,11 @@ message ActionCreatePreparedStatementResult {
// If the query provided contained parameters, parameter_schema contains the
// schema of the expected parameters. It should be an IPC-encapsulated Schema, as described in Schema.fbs.
bytes parameter_schema = 3;

// When set to true, the query should be executed with CommandPreparedStatementUpdate,
// when set to false, the query should be executed with CommandPreparedStatementQuery.
// If not set, the client can choose how to execute the query.
optional bool is_update = 4;
}

/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ static ArrowFlightJdbcFlightStreamResultSet fromFlightInfo(
final TimeZone timeZone = TimeZone.getDefault();
final QueryState state = new QueryState();

final Meta.Signature signature = ArrowFlightMetaImpl.newSignature(null, null, null);
final Meta.Signature signature = ArrowFlightMetaImpl.newSignature(null, null, null, null);

final AvaticaResultSetMetaData resultSetMetaData =
new AvaticaResultSetMetaData(null, null, signature);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public static ArrowFlightJdbcVectorSchemaRootResultSet fromVectorSchemaRoot(
final TimeZone timeZone = TimeZone.getDefault();
final QueryState state = new QueryState();

final Meta.Signature signature = ArrowFlightMetaImpl.newSignature(null, null, null);
final Meta.Signature signature = ArrowFlightMetaImpl.newSignature(null, null, null, null);

final AvaticaResultSetMetaData resultSetMetaData =
new AvaticaResultSetMetaData(null, null, signature);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ public ArrowFlightMetaImpl(final AvaticaConnection connection) {
}

/** Construct a signature. */
static Signature newSignature(final String sql, Schema resultSetSchema, Schema parameterSchema) {
static Signature newSignature(
final String sql, Schema resultSetSchema, Schema parameterSchema, Boolean isUpdate) {
List<ColumnMetaData> columnMetaData =
resultSetSchema == null
? new ArrayList<>()
Expand All @@ -62,10 +63,17 @@ static Signature newSignature(final String sql, Schema resultSetSchema, Schema p
parameterSchema == null
? new ArrayList<>()
: ConvertUtils.convertArrowFieldsToAvaticaParameters(parameterSchema.getFields());
StatementType statementType =
resultSetSchema == null || resultSetSchema.getFields().isEmpty()
? StatementType.IS_DML
: StatementType.SELECT;
// If the server provided the is_update field, use it to determine the statement type
StatementType statementType;
if (isUpdate != null) {
statementType = isUpdate ? StatementType.IS_DML : StatementType.SELECT;
} else {
// Fall back to the legacy logic: check if the result set schema is empty
statementType =
resultSetSchema == null || resultSetSchema.getFields().isEmpty()
? StatementType.IS_DML
: StatementType.SELECT;
}
return new Signature(
columnMetaData,
sql,
Expand Down Expand Up @@ -178,7 +186,10 @@ private PreparedStatement prepareForHandle(final String query, StatementHandle h
((ArrowFlightConnection) connection).getClientHandler().prepare(query);
handle.signature =
newSignature(
query, preparedStatement.getDataSetSchema(), preparedStatement.getParameterSchema());
query,
preparedStatement.getDataSetSchema(),
preparedStatement.getParameterSchema(),
preparedStatement.isUpdate());
statementHandlePreparedStatementMap.put(new StatementHandleKey(handle), preparedStatement);
return preparedStatement;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,14 @@ public interface PreparedStatement extends AutoCloseable {
*/
Schema getParameterSchema();

/**
* Gets whether this {@link PreparedStatement} is an update statement.
*
* @return {@code true} if this is an update statement, {@code false} if it's a query, or {@code
* null} if the server did not provide this information.
*/
Boolean isUpdate();

void setParameters(VectorSchemaRoot parameters);

@Override
Expand Down Expand Up @@ -456,6 +464,12 @@ public long executeUpdate() {

@Override
public StatementType getType() {
// If the server provided the is_update field, use it to determine the statement type
final Boolean isUpdate = preparedStatement.isUpdate();
if (isUpdate != null) {
return isUpdate ? StatementType.UPDATE : StatementType.SELECT;
}
// Fall back to the legacy logic: check if the result set schema is empty
final Schema schema = preparedStatement.getResultSetSchema();
return schema.getFields().isEmpty() ? StatementType.UPDATE : StatementType.SELECT;
}
Expand All @@ -475,6 +489,11 @@ public void setParameters(VectorSchemaRoot parameters) {
preparedStatement.setParameters(parameters);
}

@Override
public Boolean isUpdate() {
return preparedStatement.isUpdate();
}

@Override
public void close() {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,39 @@ public void testSimpleQueryNoParameterBindingWithExecute() throws SQLException {
}
}

@Test
public void testSimpleQueryNoParameterBindingWithExecuteV2() throws SQLException {
final String query = "SELECT * FROM TEST_V2";
final Schema schema =
new Schema(Collections.singletonList(Field.nullable("", Types.MinorType.INT.getType())));
PRODUCER.addSelectQueryV2(
query,
schema,
Collections.singletonList(
listener -> {
try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
root.allocateNew();
((IntVector) root.getVector(0)).setSafe(0, 123);
root.setRowCount(1);
listener.start(root);
listener.putNext();
} finally {
listener.completed();
}
}));
try (final PreparedStatement preparedStatement = connection.prepareStatement(query)) {
boolean isResultSet = preparedStatement.execute();
assertTrue(isResultSet);
final ResultSet resultSet = preparedStatement.getResultSet();
assertTrue(resultSet.next());
assertEquals(123, resultSet.getInt(1));
assertFalse(resultSet.next());
assertFalse(preparedStatement.getMoreResults());
assertEquals(-1, preparedStatement.getUpdateCount());
}
}

@Test
public void testQueryWithParameterBinding() throws SQLException {
final String query = "Fake query with parameters";
Expand Down Expand Up @@ -203,6 +236,20 @@ public void testUpdateQueryWithExecute() throws SQLException {
}
}

@Test
public void testUpdateQueryWithExecuteV2() throws SQLException {
String query = "Fake update with execute V2";
PRODUCER.addUpdateQueryV2(query, /*updatedRows*/ 99);
try (final PreparedStatement stmt = connection.prepareStatement(query)) {
boolean isResultSet = stmt.execute();
assertFalse(isResultSet);
int updated = stmt.getUpdateCount();
assertEquals(99, updated);
assertFalse(stmt.getMoreResults());
assertEquals(-1, stmt.getUpdateCount());
}
}

@Test
public void testUpdateQueryWithParameters() throws SQLException {
String query = "Fake update with parameters";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ public class ArrowFlightStatementExecuteTest {
private static final String SAMPLE_LARGE_UPDATE_QUERY =
"UPDATE this_large_table SET this_large_field = that_large_field FROM this_large_test WHERE this_large_condition";
private static final long SAMPLE_LARGE_UPDATE_COUNT = Long.MAX_VALUE;
private static final String SAMPLE_QUERY_CMD_V2 = "SELECT * FROM this_test_v2";
private static final String SAMPLE_LARGE_UPDATE_QUERY_V2 =
"UPDATE this_large_table_v2 SET this_large_field = that_large_field FROM this_large_test WHERE this_large_condition";
private static final MockFlightSqlProducer PRODUCER = new MockFlightSqlProducer();

@RegisterExtension
Expand Down Expand Up @@ -96,6 +99,30 @@ public static void setUpBeforeClass() {
}));
PRODUCER.addUpdateQuery(SAMPLE_UPDATE_QUERY, SAMPLE_UPDATE_COUNT);
PRODUCER.addUpdateQuery(SAMPLE_LARGE_UPDATE_QUERY, SAMPLE_LARGE_UPDATE_COUNT);

// V2 queries with is_update field set
PRODUCER.addSelectQueryV2(
SAMPLE_QUERY_CMD_V2,
SAMPLE_QUERY_SCHEMA,
Collections.singletonList(
listener -> {
try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
final VectorSchemaRoot root =
VectorSchemaRoot.create(SAMPLE_QUERY_SCHEMA, allocator)) {
final UInt1Vector vector = (UInt1Vector) root.getVector(VECTOR_NAME);
IntStream.range(0, SAMPLE_QUERY_ROWS)
.forEach(index -> vector.setSafe(index, index));
vector.setValueCount(SAMPLE_QUERY_ROWS);
root.setRowCount(SAMPLE_QUERY_ROWS);
listener.start(root);
listener.putNext();
} catch (final Throwable throwable) {
listener.error(throwable);
} finally {
listener.completed();
}
}));
PRODUCER.addUpdateQueryV2(SAMPLE_LARGE_UPDATE_QUERY_V2, SAMPLE_LARGE_UPDATE_COUNT);
}

@BeforeEach
Expand Down Expand Up @@ -168,4 +195,42 @@ public void testUpdateCountShouldStartOnZero() throws SQLException {
is(allOf(equalTo(statement.getLargeUpdateCount()), equalTo(0L))));
assertThat(statement.getResultSet(), is(nullValue()));
}

@Test
public void testExecuteShouldRunSelectQueryV2() throws SQLException {
assertThat(statement.execute(SAMPLE_QUERY_CMD_V2), is(true));
final Set<Byte> numbers =
IntStream.range(0, SAMPLE_QUERY_ROWS)
.boxed()
.map(Integer::byteValue)
.collect(Collectors.toCollection(HashSet::new));
try (final ResultSet resultSet = statement.getResultSet()) {
final int columnCount = resultSet.getMetaData().getColumnCount();
assertThat(columnCount, is(1));
int rowCount = 0;
for (; resultSet.next(); rowCount++) {
assertThat(numbers.remove(resultSet.getByte(1)), is(true));
}
assertThat(rowCount, is(equalTo(SAMPLE_QUERY_ROWS)));
}
assertThat(numbers, is(Collections.emptySet()));
assertThat(
(long) statement.getUpdateCount(),
is(allOf(equalTo(statement.getLargeUpdateCount()), equalTo(-1L))));
}

@Test
public void testExecuteShouldRunUpdateQueryForLargeUpdateV2() throws SQLException {
assertThat(statement.execute(SAMPLE_LARGE_UPDATE_QUERY_V2), is(false)); // UPDATE query.
final long updateCountSmall = statement.getUpdateCount();
final long updateCountLarge = statement.getLargeUpdateCount();
assertThat(updateCountLarge, is(equalTo(SAMPLE_LARGE_UPDATE_COUNT)));
assertThat(
updateCountSmall,
is(
allOf(
equalTo((long) AvaticaUtils.toSaturatedInt(updateCountLarge)),
not(equalTo(updateCountLarge)))));
assertThat(statement.getResultSet(), is(nullValue()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ public final class MockFlightSqlProducer implements FlightSqlProducer {
private final SqlInfoBuilder sqlInfoBuilder = new SqlInfoBuilder();
private final Map<String, Schema> parameterSchemas = new HashMap<>();
private final Map<String, List<List<Object>>> expectedParameterValues = new HashMap<>();
private final Map<String, Boolean> isUpdateMap = new HashMap<>();

private final Map<String, Integer> actionTypeCounter = new HashMap<>();

Expand Down Expand Up @@ -176,6 +177,32 @@ public void addUpdateQuery(final String sqlCommand, final long updatedRows) {
});
}

/**
* Registers a new {@link StatementType#SELECT} SQL query with is_update field set.
*
* @param sqlCommand the SQL command under which to register the new query.
* @param schema the schema to use for the query result.
* @param resultProviders the result provider for this query.
*/
public void addSelectQueryV2(
final String sqlCommand,
final Schema schema,
final List<Consumer<ServerStreamListener>> resultProviders) {
addSelectQuery(sqlCommand, schema, resultProviders);
isUpdateMap.put(sqlCommand, false);
}

/**
* Registers a new {@link StatementType#UPDATE} SQL query with is_update field set.
*
* @param sqlCommand the SQL command.
* @param updatedRows the number of rows affected.
*/
public void addUpdateQueryV2(final String sqlCommand, final long updatedRows) {
addUpdateQuery(sqlCommand, updatedRows);
isUpdateMap.put(sqlCommand, true);
}

/**
* Adds a catalog query to the results.
*
Expand Down Expand Up @@ -247,6 +274,12 @@ public void createPreparedStatement(
resultBuilder.setParameterSchema(ByteString.copyFrom(outputStream.toByteArray()));
}

// Set is_update field if present
final Boolean isUpdate = isUpdateMap.get(query);
if (isUpdate != null) {
resultBuilder.setIsUpdate(isUpdate);
}

listener.onNext(new Result(pack(resultBuilder.build()).toByteArray()));
} catch (final Throwable t) {
listener.onError(t);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1284,6 +1284,19 @@ public Schema getParameterSchema() {
return parameterSchema;
}

/**
* Returns whether the server indicated this prepared statement is an update query.
*
* @return true if the server indicated this is an update query, false if the server indicated
* this is a select query, or null if the server did not provide this information.
*/
public Boolean isUpdate() {
if (preparedStatementResult.hasIsUpdate()) {
return preparedStatementResult.getIsUpdate();
}
return null;
}

/** Get the schema of the result set (should be identical to {@link #getResultSetSchema()}). */
public SchemaResult fetchSchema(CallOption... options) {
checkOpen();
Expand Down
Loading