Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Weaviate Java client <img alt='Weaviate logo' src='https://github.com/weaviate/java-client/blob/v6/logo.png' width='200' align='right' />
# Weaviate Java client <img alt='Weaviate logo' src='https://github.com/weaviate/java-client/blob/main/logo.png' width='200' align='right' />

[![Build Status](https://github.com/weaviate/java-client/actions/workflows/.github/workflows/test.yaml/badge.svg?branch=main)](https://github.com/weaviate/java-client/actions/workflows/.github/workflows/test.yaml)

Expand Down
9 changes: 0 additions & 9 deletions src/it/java/io/weaviate/integration/RbacITest.java
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,6 @@ public void test_roles_Lifecycle() throws IOException {
Permission.groups("my-group", GroupType.OIDC, GroupsPermission.Action.READ));
});

requireAtLeast(Weaviate.Version.V132, () -> {
permissions.add(
Permission.aliases("ThingsAlias", myCollection, AliasesPermission.Action.CREATE));
});
requireAtLeast(Weaviate.Version.V133, () -> {
permissions.add(
Permission.groups("my-group", GroupType.OIDC, GroupsPermission.Action.READ));
});

// Act: create role
client.roles.create(nsRole, permissions);

Expand Down
20 changes: 20 additions & 0 deletions src/main/java/io/weaviate/client6/v1/api/Authentication.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,26 @@ public static Authentication resourceOwnerPassword(String username, String passw
};
}

/**
* Authenticate using Resource Owner Password Credentials authorization grant.
*
* @param clientSecret Client secret.
* @param username Resource owner username.
* @param password Resource owner password.
* @param scopes Client scopes.
*
* @return Authentication provider.
* @throws WeaviateOAuthException if an error occurred at any point of the token
* exchange process.
*/
public static Authentication resourceOwnerPasswordCredentials(String clientSecret, String username, String password,
List<String> scopes) {
return transport -> {
OidcConfig oidc = OidcUtils.getConfig(transport).withScopes(scopes).withScopes("offline_access");
return TokenProvider.resourceOwnerPasswordCredentials(oidc, clientSecret, username, password);
};
}

/**
* Authenticate using Client Credentials authorization grant.
*
Expand Down
21 changes: 17 additions & 4 deletions src/main/java/io/weaviate/client6/v1/api/Config.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import io.weaviate.client6.v1.internal.BuildInfo;
import io.weaviate.client6.v1.internal.ObjectBuilder;
import io.weaviate.client6.v1.internal.Proxy;
import io.weaviate.client6.v1.internal.Timeout;
import io.weaviate.client6.v1.internal.TokenProvider;
import io.weaviate.client6.v1.internal.TransportOptions;
Expand All @@ -24,7 +25,8 @@ public record Config(
Map<String, String> headers,
Authentication authentication,
TrustManagerFactory trustManagerFactory,
Timeout timeout) {
Timeout timeout,
Proxy proxy) {

public static Config of(Function<Custom, ObjectBuilder<Config>> fn) {
return fn.apply(new Custom()).build();
Expand All @@ -40,23 +42,24 @@ private Config(Builder<?> builder) {
builder.headers,
builder.authentication,
builder.trustManagerFactory,
builder.timeout);
builder.timeout,
builder.proxy);
}

RestTransportOptions restTransportOptions() {
return restTransportOptions(null);
}

RestTransportOptions restTransportOptions(TokenProvider tokenProvider) {
return new RestTransportOptions(scheme, httpHost, httpPort, headers, tokenProvider, trustManagerFactory, timeout);
return new RestTransportOptions(scheme, httpHost, httpPort, headers, tokenProvider, trustManagerFactory, timeout, proxy);
}

GrpcChannelOptions grpcTransportOptions() {
return grpcTransportOptions(null);
}

GrpcChannelOptions grpcTransportOptions(TokenProvider tokenProvider) {
return new GrpcChannelOptions(scheme, grpcHost, grpcPort, headers, tokenProvider, trustManagerFactory, timeout);
return new GrpcChannelOptions(scheme, grpcHost, grpcPort, headers, tokenProvider, trustManagerFactory, timeout, proxy);
}

private abstract static class Builder<SelfT extends Builder<SelfT>> implements ObjectBuilder<Config> {
Expand All @@ -70,6 +73,7 @@ private abstract static class Builder<SelfT extends Builder<SelfT>> implements O
protected TrustManagerFactory trustManagerFactory;
protected Timeout timeout = new Timeout();
protected Map<String, String> headers = new HashMap<>();
protected Proxy proxy;

/**
* Set URL scheme. Subclasses may increase the visibility of this method to
Expand Down Expand Up @@ -175,6 +179,15 @@ public SelfT timeout(int initSeconds, int querySeconds, int insertSeconds) {
return (SelfT) this;
}

/**
* Set proxy for all requests.
*/
@SuppressWarnings("unchecked")
public SelfT proxy(Proxy proxy) {
this.proxy = proxy;
return (SelfT) this;
}

/**
* Weaviate will use the URL in this header to call Weaviate Embeddings
* Service if an appropriate vectorizer is configured for collection.
Expand Down
9 changes: 6 additions & 3 deletions src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,13 @@ public class WeaviateClient implements AutoCloseable {
public final WeaviateClusterClient cluster;

public WeaviateClient(Config config) {
RestTransportOptions restOpt;
RestTransportOptions restOpt = config.restTransportOptions();
GrpcChannelOptions grpcOpt;
if (config.authentication() == null) {
restOpt = config.restTransportOptions();
grpcOpt = config.grpcTransportOptions();
} else {
TokenProvider tokenProvider;
try (final var noAuthRest = new DefaultRestTransport(config.restTransportOptions())) {
try (final var noAuthRest = new DefaultRestTransport(restOpt)) {
tokenProvider = config.authentication().getTokenProvider(noAuthRest);
} catch (Exception e) {
// Generally exceptions are caught in TokenProvider internals.
Expand Down Expand Up @@ -126,6 +125,10 @@ public WeaviateClient(Config config) {
this.config = config;
}

public Config getConfig() {
return config;
}

/**
* Create {@link WeaviateClientAsync} with identical configurations.
* It is a shorthand for:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public CollectionHandle<Map<String, Object>> use(
return use(CollectionDescriptor.ofMap(collectionName), fn);
}

private <PropertiesT> CollectionHandle<PropertiesT> use(CollectionDescriptor<PropertiesT> collection,
public <PropertiesT> CollectionHandle<PropertiesT> use(CollectionDescriptor<PropertiesT> collection,
Function<CollectionHandleDefaults.Builder, ObjectBuilder<CollectionHandleDefaults>> fn) {
return new CollectionHandle<>(restTransport, grpcTransport, collection, CollectionHandleDefaults.of(fn));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,21 @@ public TaskHandle retry(TaskHandle taskHandle) throws InterruptedException {
@Override
public void close() throws IOException {
boolean closedBefore = closed;
closed = true;

// Update the value atomically to make sure shutdownNow
// does not unnecessarily interrupt this thread.
synchronized (this) {
closed = true;
}

// If we'd been interrupted by shutdownNow, closing would've been
// completed exceptionally prior to that. If that's not the case
// but the current thread is interrupted, then we must propagate
// the interrupt. But first, we should dispose of the services.
if (Thread.interrupted() && !closing.isCompletedExceptionally()) {
shutdownExecutors();
Thread.currentThread().interrupt();
}

log.atDebug()
.addKeyValue("closed_before", closedBefore)
Expand Down Expand Up @@ -409,16 +423,26 @@ private void shutdownNow(Exception e) {
send.cancel(true);
}

if (!closed) {
// Since shutdownNow is never triggered by the "main" thread,
// it may be blocked on trying to add to the queue. While batch
// context is active, we own this thread and may interrupt it.
log.atDebug()
.addKeyValue("thread", Thread::currentThread)
.addKeyValue("closed", closed)
.log("Interrupt parent thread");
parent.interrupt();
// Since shutdownNow is never triggered by the "main" thread,
// it may be blocked on trying to add to the queue. While batch
// context is active, we own this thread and may interrupt it.
// We must be able to guarantee that shutdownNow never interrupts
// an in-progress close and we also don't want to potentially block
// the gRPC thread on which shutdownNow may be executing; we use
// the doubly-checked locking pattern to helps us achieve that.
if (closed) {
return;
}
synchronized (this) {
if (!closed) {
log.atDebug()
.addKeyValue("thread", Thread::currentThread)
.addKeyValue("closed", closed)
.log("Interrupt parent thread");
parent.interrupt();
}
}

}

private void shutdownExecutors() {
Expand Down
15 changes: 15 additions & 0 deletions src/main/java/io/weaviate/client6/v1/internal/Proxy.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package io.weaviate.client6.v1.internal;

import javax.annotation.Nullable;

public record Proxy(
String scheme,
String host,
int port,
@Nullable String username,
@Nullable String password
) {
public Proxy(String host, int port) {
this("http", host, port, null, null);
}
}
18 changes: 18 additions & 0 deletions src/main/java/io/weaviate/client6/v1/internal/TokenProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,24 @@ public static TokenProvider resourceOwnerPassword(OidcConfig oidc, String userna
return background(reuse(null, exchange(oidc, passwordGrant), DEFAULT_EARLY_EXPIRY));
}

/**
* Create a TokenProvider that uses Resource Owner Password Credentials authorization grant.
*
* @param oidc OIDC config.
* @param clientSecret Client secret.
* @param username Resource owner username.
* @param password Resource owner password.
*
* @return Internal TokenProvider implementation.
* @throws WeaviateOAuthException if an error occurred at any point of the token
* exchange process.
*/
public static TokenProvider resourceOwnerPasswordCredentials(OidcConfig oidc, String clientSecret, String username,
String password) {
final var passwordGrant = NimbusTokenProvider.resouceOwnerPasswordCredentials(oidc, clientSecret, username, password);
return background(reuse(null, exchange(oidc, passwordGrant), DEFAULT_EARLY_EXPIRY));
}

/**
* Create a TokenProvider that uses Client Credentials authorization grant.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,18 @@ public abstract class TransportOptions<H> {
protected final H headers;
protected final TrustManagerFactory trustManagerFactory;
protected final Timeout timeout;
protected final Proxy proxy;

protected TransportOptions(String scheme, String host, int port, H headers, TokenProvider tokenProvider,
TrustManagerFactory tmf, Timeout timeout) {
TrustManagerFactory tmf, Timeout timeout, Proxy proxy) {
this.scheme = scheme;
this.host = host;
this.port = port;
this.tokenProvider = tokenProvider;
this.headers = headers;
this.timeout = timeout;
this.trustManagerFactory = tmf;
this.proxy = proxy;
}

public boolean isSecure() {
Expand Down Expand Up @@ -58,6 +60,11 @@ public TrustManagerFactory trustManagerFactory() {
return this.trustManagerFactory;
}

@Nullable
public Proxy proxy() {
return this.proxy;
}

/**
* isWeaviateDomain returns true if the host matches weaviate.io,
* semi.technology, or weaviate.cloud domain.
Expand All @@ -73,4 +80,9 @@ public static boolean isGoogleCloudDomain(String host) {
var lower = host.toLowerCase();
return lower.contains("gcp");
}

@Nullable
public Proxy proxy() {
return this.proxy;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import static java.util.Objects.requireNonNull;

import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.OptionalInt;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
Expand All @@ -12,7 +14,7 @@
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;

import io.grpc.HttpConnectProxiedSocketAddress;
import io.grpc.ManagedChannel;
import io.grpc.StatusRuntimeException;
import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
Expand All @@ -22,12 +24,19 @@
import io.grpc.stub.MetadataUtils;
import io.grpc.stub.StreamObserver;
import io.weaviate.client6.v1.api.WeaviateApiException;
import io.weaviate.client6.v1.internal.Proxy;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateBlockingStub;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateFutureStub;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch.BatchStreamReply;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch.BatchStreamRequest;

import javax.net.ssl.SSLException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;

public final class DefaultGrpcTransport implements GrpcTransport {
/**
* ListenableFuture callbacks are executed
Expand Down Expand Up @@ -92,7 +101,7 @@ public <RequestT, RequestM, ReplyM, ResponseT> CompletableFuture<ResponseT> perf
var method = rpc.methodAsync();
var stub = applyTimeout(futureStub, rpc);
var reply = method.apply(stub, message);
return toCompletableFuture(reply).thenApply(r -> rpc.unmarshal(r));
return toCompletableFuture(reply).thenApply(rpc::unmarshal);
}

/**
Expand Down Expand Up @@ -146,6 +155,27 @@ private static ManagedChannel buildChannel(GrpcChannelOptions transportOptions)
channel.sslContext(sslCtx);
}

if (transportOptions.proxy() != null) {
Proxy proxy = transportOptions.proxy();
if ("http".equals(proxy.scheme()) || "https".equals(proxy.scheme())) {
final SocketAddress proxyAddress = new InetSocketAddress(proxy.host(), proxy.port());
channel.proxyDetector(targetAddress -> {
if (targetAddress instanceof InetSocketAddress) {
HttpConnectProxiedSocketAddress.Builder builder = HttpConnectProxiedSocketAddress.newBuilder()
.setProxyAddress(proxyAddress)
.setTargetAddress((InetSocketAddress) targetAddress);

if (proxy.username() != null && proxy.password() != null) {
builder.setUsername(proxy.username());
builder.setPassword(proxy.password());
}
return builder.build();
}
return null;
});
}
}

channel.intercept(MetadataUtils.newAttachHeadersInterceptor(transportOptions.headers()));
return channel.build();
}
Expand Down
Loading