From e69667ce9484fc68c9d5572ce3e472478143473e Mon Sep 17 00:00:00 2001 From: Yuqi Du Date: Mon, 14 Apr 2025 12:08:08 -0700 Subject: [PATCH 01/22] init --- .../jsonapi/service/provider/ModelUsage.java | 121 ++++++++++++++++++ .../provider/ProviderHttpInterceptor.java | 70 ++++++++++ .../service/provider/ProviderType.java | 11 ++ .../reranking/gateway/RerankingEGWClient.java | 3 +- .../operation/NvidiaRerankingProvider.java | 22 +++- .../operation/RerankingProvider.java | 10 +- src/main/proto/embedding_gateway.proto | 40 +++--- .../reranking/NvidiaRerankingClientTest.java | 9 +- .../reranking/RerankingGatewayClientTest.java | 22 ++++ .../reranking/TestRerankingProvider.java | 13 +- 10 files changed, 287 insertions(+), 34 deletions(-) create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelUsage.java create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderHttpInterceptor.java create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderType.java diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelUsage.java b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelUsage.java new file mode 100644 index 0000000000..3b3b934ea0 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelUsage.java @@ -0,0 +1,121 @@ +package io.stargate.sgv2.jsonapi.service.provider; + +import io.stargate.embedding.gateway.EmbeddingGateway; + +/** + * This class is to track the usage at the http request level to the embedding or reranking provider + * model service. 
+ */ +public class ModelUsage { + + public final ProviderType providerType; + public final String provider; + public final String model; + + private int requestBytes = 0; + private int responseBytes = 0; + + private int promptTokens = 0; + private int totalTokens = 0; + + public ModelUsage(ProviderType providerType, String provider, String model) { + this.providerType = providerType; + this.provider = provider; + this.model = model; + } + + public ModelUsage( + ProviderType providerType, + String provider, + String model, + int requestBytes, + int responseBytes, + int promptTokens, + int totalTokens) { + this.providerType = providerType; + this.provider = provider; + this.model = model; + this.requestBytes = requestBytes; + this.responseBytes = responseBytes; + this.promptTokens = promptTokens; + this.totalTokens = totalTokens; + } + + /** Create the ModelUsage from the modelUsage of Embedding Gateway gRPC response. */ + public static ModelUsage fromGrpcResponse(EmbeddingGateway.ModelUsage modelUsage) { + return new ModelUsage( + ProviderType.valueOf(modelUsage.getProviderType()), + modelUsage.getProviderName(), + modelUsage.getModelName(), + modelUsage.getRequestBytes(), + modelUsage.getResponseBytes(), + modelUsage.getPromptTokens(), + modelUsage.getTotalTokens()); + } + + /** + * Parse the request and response bytes from the headers of the intercepted response. Headers are + * added in the {@link ProviderHttpInterceptor} registered by specified providerClient. 
+ */ + public ModelUsage parseSentReceivedBytes(jakarta.ws.rs.core.Response interceptedResp) { + if (interceptedResp.getHeaders().get(ProviderHttpInterceptor.SENT_BYTES_HEADER) != null) { + this.requestBytes = + Integer.parseInt( + interceptedResp.getHeaderString(ProviderHttpInterceptor.SENT_BYTES_HEADER)); + } + if (interceptedResp.getHeaders().get(ProviderHttpInterceptor.RECEIVED_BYTES_HEADER) != null) { + this.responseBytes = + Integer.parseInt( + interceptedResp.getHeaderString(ProviderHttpInterceptor.RECEIVED_BYTES_HEADER)); + } + return this; + } + + public ModelUsage setPromptTokens(int promptTokens) { + this.promptTokens = promptTokens; + return this; + } + + public ModelUsage setTotalTokens(int totalTokens) { + this.totalTokens = totalTokens; + return this; + } + + public int getRequestBytes() { + return requestBytes; + } + + public int getResponseBytes() { + return responseBytes; + } + + public int getPromptTokens() { + return promptTokens; + } + + public int getTotalTokens() { + return totalTokens; + } + + @Override + public String toString() { + return "ModelUsage{" + + "providerType=" + + providerType + + ", provider='" + + provider + + '\'' + + ", model='" + + model + + '\'' + + ", requestBytes=" + + requestBytes + + ", responseBytes=" + + responseBytes + + ", promptTokens=" + + promptTokens + + ", totalTokens=" + + totalTokens + + '}'; + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderHttpInterceptor.java b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderHttpInterceptor.java new file mode 100644 index 0000000000..547eabf7c6 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderHttpInterceptor.java @@ -0,0 +1,70 @@ +package io.stargate.sgv2.jsonapi.service.provider; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.io.ByteStreams; +import com.google.common.io.CountingOutputStream; +import jakarta.ws.rs.client.ClientRequestContext; +import 
jakarta.ws.rs.client.ClientResponseContext; +import jakarta.ws.rs.client.ClientResponseFilter; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This class is to track the usage at the http request level to the embedding or reranking provider + * model service. + * + *

E.G. When a providerClient registered the interceptor + * as @RegisterProvider(ProviderHttpInterceptor.class), the interceptor will intercept the http + * request and response, then add the sent-bytes and received-bytes to the response headers in the + * response context. + * + *

Note, if provider already returned content-length in the response header, then the interceptor + * will reuse it and won't calculate the response size. + */ +public class ProviderHttpInterceptor implements ClientResponseFilter { + + private static final Logger LOGGER = LoggerFactory.getLogger(ProviderHttpInterceptor.class); + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + // Header name to track the sent_bytes to the provider + public static final String SENT_BYTES_HEADER = "sent-bytes"; + // Header name to track the received_bytes from the provider + public static final String RECEIVED_BYTES_HEADER = "received-bytes"; + + @Override + public void filter(ClientRequestContext requestContext, ClientResponseContext responseContext) + throws IOException { + int receivedBytes = 0; + int sentBytes = 0; + + // Parse the request entity stream to measure its size. + if (requestContext.hasEntity()) { + try { + CountingOutputStream cus = new CountingOutputStream(OutputStream.nullOutputStream()); + OBJECT_MAPPER.writeValue(cus, requestContext.getEntity()); + cus.close(); + sentBytes = (int) cus.getCount(); + } catch (Exception e) { + LOGGER.warn("Failed to measure request body size: " + e.getMessage()); + } + } + + // Use the content-length if present, otherwise parse the response entity stream to measure its + // size. + if (responseContext.hasEntity()) { + receivedBytes = responseContext.getLength(); + // if provider does not return content-length in the response header. 
+ if (receivedBytes <= 0) { + // Read the response entity stream to measure its size + InputStream inputStream = responseContext.getEntityStream(); + receivedBytes = (int) ByteStreams.copy(inputStream, OutputStream.nullOutputStream()); + } + } + + responseContext.getHeaders().add(SENT_BYTES_HEADER, String.valueOf(sentBytes)); + responseContext.getHeaders().add(RECEIVED_BYTES_HEADER, String.valueOf(receivedBytes)); + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderType.java b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderType.java new file mode 100644 index 0000000000..8dacd2faeb --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderType.java @@ -0,0 +1,11 @@ +package io.stargate.sgv2.jsonapi.service.provider; + +/** + * Enum representing the type of provider. + * + *

Used to differentiate between embedding and reranking providers. + */ +public enum ProviderType { + EMBEDDING_PROVIDER, + RERANKING_PROVIDER +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/gateway/RerankingEGWClient.java b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/gateway/RerankingEGWClient.java index 2beb42afe1..16e6fabaea 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/gateway/RerankingEGWClient.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/gateway/RerankingEGWClient.java @@ -8,6 +8,7 @@ import io.stargate.sgv2.jsonapi.api.request.RerankingCredentials; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.exception.JsonApiException; +import io.stargate.sgv2.jsonapi.service.provider.ModelUsage; import io.stargate.sgv2.jsonapi.service.reranking.configuration.RerankingProvidersConfig; import io.stargate.sgv2.jsonapi.service.reranking.operation.RerankingProvider; import java.util.*; @@ -115,7 +116,7 @@ public Uni rerank( resp.getRanksList().stream() .map(rank -> new Rank(rank.getIndex(), rank.getScore())) .collect(Collectors.toList()), - new Usage(resp.getUsage().getPromptTokens(), resp.getUsage().getTotalTokens())); + ModelUsage.fromGrpcResponse(resp.getModelUsage())); }); } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/NvidiaRerankingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/NvidiaRerankingProvider.java index a22ef6173d..3d08878ab7 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/NvidiaRerankingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/NvidiaRerankingProvider.java @@ -10,6 +10,9 @@ import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; import 
io.stargate.sgv2.jsonapi.service.embedding.operation.error.RerankingResponseErrorMessageMapper; +import io.stargate.sgv2.jsonapi.service.provider.ModelUsage; +import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; +import io.stargate.sgv2.jsonapi.service.provider.ProviderType; import io.stargate.sgv2.jsonapi.service.reranking.configuration.RerankingProviderResponseValidation; import io.stargate.sgv2.jsonapi.service.reranking.configuration.RerankingProvidersConfig; import jakarta.ws.rs.HeaderParam; @@ -84,11 +87,12 @@ public NvidiaRerankingProvider( @RegisterRestClient @RegisterProvider(RerankingProviderResponseValidation.class) + @RegisterProvider(ProviderHttpInterceptor.class) public interface NvidiaRerankingClient { @POST @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) - Uni rerank( + Uni rerank( @HeaderParam("Authorization") String accessToken, RerankingRequest request); @ClientExceptionMapper @@ -146,7 +150,7 @@ public Uni rerank( "In order to rerank, please provide the reranking API key."); } - Uni response = + Uni response = applyRetry( nvidiaRerankingClient.rerank( HttpConstants.BEARER_PREFIX_FOR_API_KEY + rerankingCredentials.apiKey().get(), @@ -155,13 +159,19 @@ public Uni rerank( return response .onItem() .transform( - resp -> { + interceptedResp -> { + RerankingResponse providerResp = interceptedResp.readEntity(RerankingResponse.class); List ranks = - resp.rankings().stream() + providerResp.rankings().stream() .map(rank -> new Rank(rank.index(), rank.logit())) .toList(); - Usage usage = new Usage(resp.usage().prompt_tokens(), resp.usage().total_tokens()); - return RerankingBatchResponse.of(batchId, ranks, usage); + + ModelUsage modelUsage = + new ModelUsage(ProviderType.RERANKING_PROVIDER, providerId, modelName) + .setPromptTokens(providerResp.usage().prompt_tokens) + .setTotalTokens(providerResp.usage().total_tokens) + .parseSentReceivedBytes(interceptedResp); + return 
RerankingBatchResponse.of(batchId, ranks, modelUsage); }); } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProvider.java index db870d2bca..c22f90a24c 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProvider.java @@ -4,6 +4,7 @@ import io.stargate.sgv2.jsonapi.api.request.RerankingCredentials; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.exception.JsonApiException; +import io.stargate.sgv2.jsonapi.service.provider.ModelUsage; import io.stargate.sgv2.jsonapi.service.reranking.configuration.RerankingProvidersConfig; import java.time.Duration; import java.util.ArrayList; @@ -95,16 +96,15 @@ public abstract Uni rerank( int batchId, String query, List passages, RerankingCredentials rerankingCredentials); /** The response of a batch rerank call. */ - public record RerankingBatchResponse(int batchId, List ranks, Usage usage) { - public static RerankingBatchResponse of(int batchId, List rankings, Usage usage) { - return new RerankingBatchResponse(batchId, rankings, usage); + public record RerankingBatchResponse(int batchId, List ranks, ModelUsage modelUsage) { + public static RerankingBatchResponse of( + int batchId, List rankings, ModelUsage modelUsage) { + return new RerankingBatchResponse(batchId, rankings, modelUsage); } } public record Rank(int index, float score) {} - public record Usage(int prompt_tokens, int total_tokens) {} - /** * Applies a retry mechanism with backoff and jitter to the Uni returned by the rerank() method, * which makes an HTTP request to a third-party service. 
diff --git a/src/main/proto/embedding_gateway.proto b/src/main/proto/embedding_gateway.proto index 518cb79b86..6d47a4281a 100644 --- a/src/main/proto/embedding_gateway.proto +++ b/src/main/proto/embedding_gateway.proto @@ -53,7 +53,7 @@ message ProviderEmbedRequest { // The response message for the embedding gateway gRPC API if successful message EmbeddingResponse { - Usage usage = 1; + ModelUsage modelUsage = 1; repeated FloatEmbedding embeddings = 2; ErrorResponse error = 3; @@ -65,18 +65,6 @@ message EmbeddingResponse { repeated float embedding = 2; } - // The usage statistics for the embedding gateway gRPC API on successful response - message Usage { - string provider_name = 1; - string model_name = 2; - string tenant_id = 3; - int32 prompt_tokens = 4; - int32 total_tokens = 5; - int32 input_bytes = 6; - int32 output_bytes = 7; - int32 call_duration_us = 8; - } - // The error response message for the embedding gateway gRPC API message ErrorResponse { string error_code = 1; @@ -227,7 +215,7 @@ message ProviderRerankingRequest { // The reranking response message for the embedding gateway gRPC API if successful message RerankingResponse { - Usage usage = 1; + ModelUsage modelUsage = 1; repeated Rank ranks = 2; ErrorResponse error = 3; @@ -239,12 +227,6 @@ message RerankingResponse { float score = 2; } - // The usage statistics of reranking for the embedding gateway gRPC API on successful response - message Usage { - int32 prompt_tokens = 1; - int32 total_tokens = 2; - } - message ErrorResponse { string error_code = 1; string error_message = 2; @@ -313,3 +295,21 @@ service RerankingService { rpc Rerank (ProviderRerankingRequest) returns (RerankingResponse) {} rpc GetSupportedRerankingProviders (GetSupportedRerankingProvidersRequest) returns (GetSupportedRerankingProvidersResponse){} } + + + +// Common messages definition shared by both embedding and reranking + +// The usage statistics for the embedding gateway gRPC API on successful response from the provider 
+message ModelUsage { + string provider_type = 1; + string provider_name = 2; + string model_name = 3; + string tenant_id = 4; + int32 prompt_tokens = 5; + int32 total_tokens = 6; + int32 request_bytes = 7; + int32 response_bytes = 8; + int32 call_duration_us = 9; +} + diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/NvidiaRerankingClientTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/NvidiaRerankingClientTest.java index fa4187bae7..d61b43f787 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/NvidiaRerankingClientTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/NvidiaRerankingClientTest.java @@ -10,6 +10,8 @@ import io.smallrye.mutiny.Uni; import io.smallrye.mutiny.helpers.test.UniAssertSubscriber; import io.stargate.sgv2.jsonapi.api.request.RerankingCredentials; +import io.stargate.sgv2.jsonapi.service.provider.ModelUsage; +import io.stargate.sgv2.jsonapi.service.provider.ProviderType; import io.stargate.sgv2.jsonapi.service.reranking.operation.NvidiaRerankingProvider; import io.stargate.sgv2.jsonapi.service.reranking.operation.RerankingProvider; import io.stargate.sgv2.jsonapi.testresource.NoGlobalResourcesTestProfile; @@ -42,7 +44,12 @@ void handleValidResponse() { return Uni.createFrom() .item( new RerankingProvider.RerankingBatchResponse( - 1, ranks, new RerankingProvider.Usage(0, 0))); + 1, + ranks, + new ModelUsage( + ProviderType.RERANKING_PROVIDER, + "nvidia", + "llama-3.2-nv-rerankqa-1b-v2"))); }); final RerankingProvider.RerankingBatchResponse response = diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/RerankingGatewayClientTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/RerankingGatewayClientTest.java index f25d9da763..e95d618ba7 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/RerankingGatewayClientTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/RerankingGatewayClientTest.java @@ 
-14,6 +14,8 @@ import io.stargate.sgv2.jsonapi.api.request.RerankingCredentials; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.exception.JsonApiException; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; +import io.stargate.sgv2.jsonapi.service.provider.ProviderType; import io.stargate.sgv2.jsonapi.service.reranking.gateway.RerankingEGWClient; import io.stargate.sgv2.jsonapi.service.reranking.operation.RerankingProvider; import io.stargate.sgv2.jsonapi.testresource.NoGlobalResourcesTestProfile; @@ -55,6 +57,17 @@ void handleValidResponse() { .build()) .toList(); builder.addAllRanks(ranks); + // mock model usage + builder.setModelUsage( + EmbeddingGateway.ModelUsage.newBuilder() + .setProviderType(ProviderType.RERANKING_PROVIDER.name()) + .setProviderName(ProviderConstants.NVIDIA) + .setModelName("llama-3.2-nv-rerankqa-1b-v2") + .setPromptTokens(10) + .setTotalTokens(20) + .setRequestBytes(100) + .setResponseBytes(200) + .build()); when(rerankService.rerank(any())).thenReturn(Uni.createFrom().item(builder.build())); // Create a RerankEGWClient instance @@ -86,6 +99,15 @@ void handleValidResponse() { assertThat(response.ranks().get(0).score()).isEqualTo(1f); assertThat(response.ranks().get(1).index()).isEqualTo(0); assertThat(response.ranks().get(1).score()).isEqualTo(0.1f); + + assertThat(response.modelUsage()).isNotNull(); + assertThat(response.modelUsage().providerType).isEqualTo(ProviderType.RERANKING_PROVIDER); + assertThat(response.modelUsage().provider).isEqualTo(ProviderConstants.NVIDIA); + assertThat(response.modelUsage().model).isEqualTo("llama-3.2-nv-rerankqa-1b-v2"); + assertThat(response.modelUsage().getPromptTokens()).isEqualTo(10); + assertThat(response.modelUsage().getTotalTokens()).isEqualTo(20); + assertThat(response.modelUsage().getRequestBytes()).isEqualTo(100); + assertThat(response.modelUsage().getResponseBytes()).isEqualTo(200); } @Test diff --git 
a/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/TestRerankingProvider.java b/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/TestRerankingProvider.java index cca8bd2eb1..a9b456e74b 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/TestRerankingProvider.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/TestRerankingProvider.java @@ -2,6 +2,9 @@ import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.RerankingCredentials; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; +import io.stargate.sgv2.jsonapi.service.provider.ModelUsage; +import io.stargate.sgv2.jsonapi.service.provider.ProviderType; import io.stargate.sgv2.jsonapi.service.reranking.configuration.RerankingProvidersConfig; import io.stargate.sgv2.jsonapi.service.reranking.configuration.RerankingProvidersConfigImpl; import io.stargate.sgv2.jsonapi.service.reranking.operation.RerankingProvider; @@ -37,6 +40,14 @@ public Uni rerank( ranks.add(new Rank(i, score)); } ranks.sort((o1, o2) -> Float.compare(o2.score(), o1.score())); // Descending order - return Uni.createFrom().item(RerankingBatchResponse.of(batchId, ranks, new Usage(0, 0))); + return Uni.createFrom() + .item( + RerankingBatchResponse.of( + batchId, + ranks, + new ModelUsage( + ProviderType.RERANKING_PROVIDER, + ProviderConstants.NVIDIA, + "nvidia/llama-3.2-nv-rerankqa-1b-v2"))); } } From b50241a32c47555d2b955a96c4f70219e9c86db5 Mon Sep 17 00:00:00 2001 From: Yuqi Du Date: Mon, 14 Apr 2025 12:17:09 -0700 Subject: [PATCH 02/22] java doc --- .../sgv2/jsonapi/service/provider/ModelUsage.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelUsage.java b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelUsage.java index 3b3b934ea0..25b5a73645 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelUsage.java +++ 
b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelUsage.java @@ -12,10 +12,18 @@ public class ModelUsage { public final String provider; public final String model; + /** The number of bytes sent in the request. */ private int requestBytes = 0; + + /** The number of bytes received in the response. Use content-length if present */ private int responseBytes = 0; + /** The number of tokens in the prompt, will be set if provider returned in the response. */ private int promptTokens = 0; + + /** + * The total number of tokens in the request, will be set if provider returned in the response. + */ private int totalTokens = 0; public ModelUsage(ProviderType providerType, String provider, String model) { From 1c1f1d877faaaa2a9fa4d84749151e7f25bc38e0 Mon Sep 17 00:00:00 2001 From: Aaron Morton Date: Wed, 11 Jun 2025 11:56:49 +1200 Subject: [PATCH 03/22] WIP - code changes, compiles, tests not verified --- .../command/impl/CreateCollectionCommand.java | 10 +- .../model/command/impl/VectorizeConfig.java | 19 +- .../api/request/EmbeddingCredentials.java | 5 +- .../request/EmbeddingCredentialsResolver.java | 2 +- ...aderBasedEmbeddingCredentialsResolver.java | 3 +- .../jsonapi/api/request/RequestContext.java | 18 +- .../api/request/RerankingCredentials.java | 2 +- .../config/constants/RerankingConstants.java | 2 - .../constants/ServiceDescConstants.java | 2 +- .../service/embedding/DataVectorizer.java | 4 +- .../EmbeddingProviderConfigStore.java | 6 +- ...ertyBasedEmbeddingProviderConfigStore.java | 3 +- .../configuration/ProviderConstants.java | 21 -- .../gateway/EmbeddingGatewayClient.java | 103 ++++--- .../AwsBedrockEmbeddingProvider.java | 239 +++++++++------- .../AzureOpenAIEmbeddingProvider.java | 185 ++++++------ .../operation/CohereEmbeddingProvider.java | 220 ++++++++------- .../operation/EmbeddingProvider.java | 155 ++++++---- .../operation/EmbeddingProviderFactory.java | 85 ++++-- ...HuggingFaceDedicatedEmbeddingProvider.java | 192 +++++++------ 
.../HuggingFaceEmbeddingProvider.java | 175 +++++++----- .../operation/JinaAIEmbeddingProvider.java | 184 ++++++------ .../operation/MeteredEmbeddingProvider.java | 31 +- .../operation/MistralEmbeddingProvider.java | 193 +++++++------ .../operation/NvidiaEmbeddingProvider.java | 180 ++++++------ .../operation/OpenAIEmbeddingProvider.java | 191 +++++++------ .../operation/UpstageAIEmbeddingProvider.java | 219 +++++++++------ .../operation/VertexAIEmbeddingProvider.java | 249 ++++++++-------- .../operation/VoyageAIEmbeddingProvider.java | 195 +++++++------ .../RerankingResponseErrorMessageMapper.java | 1 + .../test/CustomITEmbeddingProvider.java | 57 +++- .../operation/embeddings/EmbeddingTask.java | 6 +- .../operation/reranking/RerankingTask.java | 3 +- .../service/provider/ModelInputType.java | 34 +++ .../service/provider/ModelProvider.java | 50 ++++ .../jsonapi/service/provider/ModelType.java | 25 ++ .../jsonapi/service/provider/ModelUsage.java | 265 +++++++++++------- .../service/provider/ProviderBase.java | 179 ++++++++++++ .../provider/ProviderHttpInterceptor.java | 54 ++-- .../service/provider/ProviderType.java | 11 - .../reranking/gateway/RerankingEGWClient.java | 74 ++--- .../operation/NvidiaRerankingProvider.java | 177 ++++++------ .../operation/RerankingProvider.java | 196 +++++++++---- .../operation/RerankingProviderFactory.java | 31 +- .../resolver/VectorizeConfigValidator.java | 8 +- src/main/proto/embedding_gateway.proto | 108 +++---- .../resources/embedding-providers-config.yaml | 1 - .../operation/DataVectorizerTest.java | 18 +- .../operation/EmbeddingGatewayClientTest.java | 9 +- .../EmbeddingProviderErrorMessageTest.java | 7 +- .../operation/OpenAiEmbeddingClientTest.java | 11 +- .../operation/TestEmbeddingProvider.java | 42 ++- .../RerankingGatewayClientTest.java | 33 ++- .../NvidiaRerankingClientTest.java | 20 +- .../RerankingProviderTest.java | 5 +- .../TestRerankingProvider.java | 41 ++- 56 files changed, 2653 insertions(+), 1706 deletions(-) 
delete mode 100644 src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/ProviderConstants.java create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelInputType.java create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelProvider.java create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelType.java create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderBase.java delete mode 100644 src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderType.java rename src/test/java/io/stargate/sgv2/jsonapi/service/reranking/{ => gateway}/RerankingGatewayClientTest.java (82%) rename src/test/java/io/stargate/sgv2/jsonapi/service/reranking/{ => operation}/NvidiaRerankingClientTest.java (71%) rename src/test/java/io/stargate/sgv2/jsonapi/service/reranking/{ => operation}/RerankingProviderTest.java (93%) rename src/test/java/io/stargate/sgv2/jsonapi/service/reranking/{ => operation}/TestRerankingProvider.java (55%) diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/CreateCollectionCommand.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/CreateCollectionCommand.java index 9ad0427695..dc22ef4ad4 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/CreateCollectionCommand.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/CreateCollectionCommand.java @@ -5,7 +5,7 @@ import io.stargate.sgv2.jsonapi.api.model.command.CollectionOnlyCommand; import io.stargate.sgv2.jsonapi.api.model.command.CommandName; import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants; -import io.stargate.sgv2.jsonapi.config.constants.RerankingConstants; +import io.stargate.sgv2.jsonapi.config.constants.ServiceDescConstants; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.service.schema.collections.DocumentPath; import 
io.stargate.sgv2.jsonapi.service.schema.naming.NamingRules; @@ -276,20 +276,20 @@ public record RerankServiceDesc( description = "Registered reranking service provider", type = SchemaType.STRING, implementation = String.class) - @JsonProperty(RerankingConstants.RerankingService.PROVIDER) + @JsonProperty(ServiceDescConstants.PROVIDER) String provider, @Schema( description = "Registered reranking service model", type = SchemaType.STRING, implementation = String.class) - @JsonProperty(RerankingConstants.RerankingService.MODEL_NAME) + @JsonProperty(ServiceDescConstants.MODEL_NAME) String modelName, @Valid @Nullable @Schema( description = "Authentication config for chosen reranking service", type = SchemaType.OBJECT) - @JsonProperty(RerankingConstants.RerankingService.AUTHENTICATION) + @JsonProperty(ServiceDescConstants.AUTHENTICATION) @JsonInclude(JsonInclude.Include.NON_NULL) Map authentication, @Nullable @@ -297,7 +297,7 @@ public record RerankServiceDesc( description = "Optional parameters that match the messageTemplate provided for the reranking provider", type = SchemaType.OBJECT) - @JsonProperty(RerankingConstants.RerankingService.PARAMETERS) + @JsonProperty(ServiceDescConstants.PARAMETERS) @JsonInclude(JsonInclude.Include.NON_NULL) Map parameters) { diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/VectorizeConfig.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/VectorizeConfig.java index e4e5888f9d..fd1b631642 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/VectorizeConfig.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/VectorizeConfig.java @@ -4,7 +4,8 @@ import com.fasterxml.jackson.annotation.JsonProperty; import io.stargate.sgv2.jsonapi.config.constants.VectorConstants; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; +import 
io.stargate.sgv2.jsonapi.service.embedding.operation.HuggingFaceDedicatedEmbeddingProvider; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import jakarta.validation.Valid; import jakarta.validation.constraints.*; import java.util.*; @@ -48,24 +49,30 @@ public VectorizeConfig( String modelName, Map authentication, Map parameters) { + if (provider == null) { throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( "'provider' in required property for 'vector.service' Object value"); } + this.provider = provider; + // HuggingfaceDedicated does not need user to specify model explicitly // If user specifies modelName other than endpoint-defined-model, will error out // By default, huggingfaceDedicated provider use endpoint-defined-model as placeholder - if (ProviderConstants.HUGGINGFACE_DEDICATED.equals(provider)) { + if (ModelProvider.HUGGINGFACE_DEDICATED.apiName().equals(provider)) { if (modelName == null) { - modelName = ProviderConstants.HUGGINGFACE_DEDICATED_DEFINED_MODEL; - } else if (!modelName.equals(ProviderConstants.HUGGINGFACE_DEDICATED_DEFINED_MODEL)) { + modelName = + HuggingFaceDedicatedEmbeddingProvider.HUGGINGFACE_DEDICATED_ENDPOINT_DEFINED_MODEL; + } else if (!modelName.equals( + HuggingFaceDedicatedEmbeddingProvider.HUGGINGFACE_DEDICATED_ENDPOINT_DEFINED_MODEL)) { throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( "'modelName' is not needed for embedding provider %s explicitly, only '%s' is accepted", - ProviderConstants.HUGGINGFACE_DEDICATED, - ProviderConstants.HUGGINGFACE_DEDICATED_DEFINED_MODEL); + ModelProvider.HUGGINGFACE_DEDICATED, + HuggingFaceDedicatedEmbeddingProvider.HUGGINGFACE_DEDICATED_ENDPOINT_DEFINED_MODEL); } } + this.modelName = modelName; if (authentication != null && !authentication.isEmpty()) { Map updatedAuth = new HashMap<>(); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/request/EmbeddingCredentials.java 
b/src/main/java/io/stargate/sgv2/jsonapi/api/request/EmbeddingCredentials.java index 5706f61fce..8b978243e7 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/request/EmbeddingCredentials.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/request/EmbeddingCredentials.java @@ -11,4 +11,7 @@ * @param secretId - Secret Id used for AWS Bedrock embedding service */ public record EmbeddingCredentials( - Optional apiKey, Optional accessId, Optional secretId) {} + String tenantId, + Optional apiKey, + Optional accessId, + Optional secretId) {} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/request/EmbeddingCredentialsResolver.java b/src/main/java/io/stargate/sgv2/jsonapi/api/request/EmbeddingCredentialsResolver.java index bd14ba0b15..fc37f4f795 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/request/EmbeddingCredentialsResolver.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/request/EmbeddingCredentialsResolver.java @@ -5,5 +5,5 @@ /** Functional interface to resolve the embedding api key from the request context. 
*/ @FunctionalInterface public interface EmbeddingCredentialsResolver { - EmbeddingCredentials resolveEmbeddingCredentials(RoutingContext context); + EmbeddingCredentials resolveEmbeddingCredentials(String tenantId, RoutingContext context); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/request/HeaderBasedEmbeddingCredentialsResolver.java b/src/main/java/io/stargate/sgv2/jsonapi/api/request/HeaderBasedEmbeddingCredentialsResolver.java index e9802611b0..50eb7fd596 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/request/HeaderBasedEmbeddingCredentialsResolver.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/request/HeaderBasedEmbeddingCredentialsResolver.java @@ -23,12 +23,13 @@ public HeaderBasedEmbeddingCredentialsResolver( Objects.requireNonNull(secretIdHeaderName, "Secret Id header name cannot be null"); } - public EmbeddingCredentials resolveEmbeddingCredentials(RoutingContext context) { + public EmbeddingCredentials resolveEmbeddingCredentials(String tenantId, RoutingContext context) { HttpServerRequest request = context.request(); String headerValue = request.getHeader(this.tokenHeaderName); String accessId = request.getHeader(this.accessIdHeaderName); String secretId = request.getHeader(this.secretIdHeaderName); return new EmbeddingCredentials( + tenantId, Optional.ofNullable(headerValue), Optional.ofNullable(accessId), Optional.ofNullable(secretId)); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/request/RequestContext.java b/src/main/java/io/stargate/sgv2/jsonapi/api/request/RequestContext.java index 6b2274cccf..659852fcd7 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/request/RequestContext.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/request/RequestContext.java @@ -52,9 +52,14 @@ public RequestContext( Instance tenantResolver, Instance tokenResolver, Instance embeddingCredentialsResolver) { - this.embeddingCredentials = - embeddingCredentialsResolver.get().resolveEmbeddingCredentials(routingContext); + 
this.tenantId = (tenantResolver.get()).resolve(routingContext, securityContext);
+
+    this.embeddingCredentials =
+        embeddingCredentialsResolver
+            .get()
+            .resolveEmbeddingCredentials(tenantId.orElse(""), routingContext);
+
     this.cassandraToken = (tokenResolver.get()).resolve(routingContext, securityContext);
     httpHeaders = new HttpHeaderAccess(routingContext.request().headers());
     requestId = generateRequestId();
@@ -63,11 +68,16 @@ public RequestContext(
         HeaderBasedRerankingKeyResolver.resolveRerankingKey(routingContext);
     this.rerankingCredentials =
         rerankingApiKeyFromHeader
-            .map(apiKey -> new RerankingCredentials(Optional.of(apiKey)))
+            // tenantId is an Optional and may be empty; fall back to "" exactly as the embedding
+            // credentials above do. Optional.get() would throw NoSuchElementException here.
+            .map(apiKey -> new RerankingCredentials(this.tenantId.orElse(""), Optional.of(apiKey)))
             .orElse(
                 this.cassandraToken
-                    .map(cassandraToken -> new RerankingCredentials(Optional.of(cassandraToken)))
-                    .orElse(new RerankingCredentials(Optional.empty())));
+                    .map(
+                        cassandraToken ->
+                            new RerankingCredentials(
+                                this.tenantId.orElse(""), Optional.of(cassandraToken)))
+                    .orElse(new RerankingCredentials(this.tenantId.orElse(""), Optional.empty())));
   }
 
   private static String generateRequestId() {
diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/request/RerankingCredentials.java b/src/main/java/io/stargate/sgv2/jsonapi/api/request/RerankingCredentials.java
index 874075cb32..a588888900 100644
--- a/src/main/java/io/stargate/sgv2/jsonapi/api/request/RerankingCredentials.java
+++ b/src/main/java/io/stargate/sgv2/jsonapi/api/request/RerankingCredentials.java
@@ -8,4 +8,4 @@
  * cassandra token as the reranking api key. Note, both cassandra token and reranking-api-key could
  * be absent in Data API request, although it is invalid for authentication.
*/ -public record RerankingCredentials(Optional apiKey) {} +public record RerankingCredentials(String tenantId, Optional apiKey) {} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/config/constants/RerankingConstants.java b/src/main/java/io/stargate/sgv2/jsonapi/config/constants/RerankingConstants.java index e34f36bc2f..a166409ab3 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/config/constants/RerankingConstants.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/config/constants/RerankingConstants.java @@ -6,6 +6,4 @@ interface CollectionRerankingOptions { String ENABLED = "enabled"; String SERVICE = ServiceDescConstants.SERVICE; } - - interface RerankingService extends ServiceDescConstants {} } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/config/constants/ServiceDescConstants.java b/src/main/java/io/stargate/sgv2/jsonapi/config/constants/ServiceDescConstants.java index a2442489df..5e3536bfd3 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/config/constants/ServiceDescConstants.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/config/constants/ServiceDescConstants.java @@ -1,7 +1,7 @@ package io.stargate.sgv2.jsonapi.config.constants; /** Common service description constants shared between vector and reranking */ -interface ServiceDescConstants { +public interface ServiceDescConstants { String SERVICE = "service"; String PROVIDER = "provider"; String MODEL_NAME = "modelName"; diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java index 113ec69402..a6474fda6a 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java @@ -175,7 +175,7 @@ public Uni vectorize(String vectorizeContent) { List.of(vectorizeContent), embeddingCredentials, EmbeddingProvider.EmbeddingRequestType.INDEX) - .map(EmbeddingProvider.Response::embeddings); + 
.map(EmbeddingProvider.BatchedEmbeddingResponse::embeddings); return vectors .onItem() .transform( @@ -301,7 +301,7 @@ private Uni> vectorizeTexts( return embeddingProvider .vectorize(1, textsToVectorize, embeddingCredentials, requestType) - .map(EmbeddingProvider.Response::embeddings) + .map(EmbeddingProvider.BatchedEmbeddingResponse::embeddings) .onItem() .transform( vectorData -> { diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProviderConfigStore.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProviderConfigStore.java index 884349c6ee..db94a35c4a 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProviderConfigStore.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProviderConfigStore.java @@ -1,6 +1,7 @@ package io.stargate.sgv2.jsonapi.service.embedding.configuration; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import java.util.Map; import java.util.Optional; @@ -27,8 +28,8 @@ public static ServiceConfig provider( public static ServiceConfig custom(Optional> implementationClass) { return new ServiceConfig( - ProviderConstants.CUSTOM, - ProviderConstants.CUSTOM, + ModelProvider.CUSTOM.apiName(), + ModelProvider.CUSTOM.apiName(), null, implementationClass, null, @@ -56,6 +57,7 @@ record RequestProperties( // `maxBatchSize` is the maximum number of documents to be sent in a single request to be // embedding provider int maxBatchSize) { + public static RequestProperties of( int atMostRetries, int initialBackOffMillis, diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/PropertyBasedEmbeddingProviderConfigStore.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/PropertyBasedEmbeddingProviderConfigStore.java index 347bd56c2c..58976c3cc5 100644 --- 
a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/PropertyBasedEmbeddingProviderConfigStore.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/PropertyBasedEmbeddingProviderConfigStore.java @@ -1,6 +1,7 @@ package io.stargate.sgv2.jsonapi.service.embedding.configuration; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import jakarta.enterprise.context.ApplicationScoped; import jakarta.inject.Inject; import java.util.HashMap; @@ -23,7 +24,7 @@ public void saveConfiguration(Optional tenant, ServiceConfig serviceConf public EmbeddingProviderConfigStore.ServiceConfig getConfiguration( Optional tenant, String serviceName) { // already checked if the service exists and enabled in CreateCollectionCommandResolver - if (serviceName.equals(ProviderConstants.CUSTOM)) { + if (serviceName.equals(ModelProvider.CUSTOM.apiName())) { return ServiceConfig.custom(config.custom().clazz()); } if (config.providers().get(serviceName) == null diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/ProviderConstants.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/ProviderConstants.java deleted file mode 100644 index e47bc6738a..0000000000 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/ProviderConstants.java +++ /dev/null @@ -1,21 +0,0 @@ -package io.stargate.sgv2.jsonapi.service.embedding.configuration; - -public final class ProviderConstants { - public static final String OPENAI = "openai"; - public static final String AZURE_OPENAI = "azureOpenAI"; - public static final String HUGGINGFACE = "huggingface"; - public static final String HUGGINGFACE_DEDICATED = "huggingfaceDedicated"; - public static final String HUGGINGFACE_DEDICATED_DEFINED_MODEL = "endpoint-defined-model"; - public static final String VERTEXAI = "vertexai"; - public static final String COHERE = "cohere"; - public 
static final String NVIDIA = "nvidia"; - public static final String UPSTAGE_AI = "upstageAI"; - public static final String VOYAGE_AI = "voyageAI"; - public static final String JINA_AI = "jinaAI"; - public static final String CUSTOM = "custom"; - public static final String MISTRAL = "mistral"; - public static final String BEDROCK = "bedrock"; - - // Private constructor to prevent instantiation - private ProviderConstants() {} -} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/gateway/EmbeddingGatewayClient.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/gateway/EmbeddingGatewayClient.java index 86ccfc4b21..5e0c2e9f72 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/gateway/EmbeddingGatewayClient.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/gateway/EmbeddingGatewayClient.java @@ -10,7 +10,7 @@ import io.stargate.sgv2.jsonapi.exception.JsonApiException; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProvider; -import java.util.Collections; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -35,7 +35,7 @@ public class EmbeddingGatewayClient extends EmbeddingProvider { private EmbeddingProviderConfigStore.RequestProperties requestProperties; - private String provider; + private ModelProvider modelProvider; private int dimension; @@ -61,7 +61,7 @@ public class EmbeddingGatewayClient extends EmbeddingProvider { */ public EmbeddingGatewayClient( EmbeddingProviderConfigStore.RequestProperties requestProperties, - String provider, + ModelProvider modelProvider, int dimension, Optional tenant, Optional authToken, @@ -71,8 +71,11 @@ public EmbeddingGatewayClient( Map vectorizeServiceParameter, Map authentication, String commandName) { + super( + modelProvider, requestProperties, baseUrl, 
modelName, dimension, vectorizeServiceParameter); + this.requestProperties = requestProperties; - this.provider = provider; + this.modelProvider = modelProvider; this.dimension = dimension; this.tenant = tenant; this.authToken = authToken; @@ -84,6 +87,12 @@ public EmbeddingGatewayClient( this.commandName = commandName; } + @Override + protected String errorMessageJsonPtr() { + // not used , this is passing through + return ""; + } + /** * Vectorize the given list of texts * @@ -93,48 +102,52 @@ public EmbeddingGatewayClient( * @return */ @Override - public Uni vectorize( + public Uni vectorize( int batchId, List texts, EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { - Map - grpcVectorizeServiceParameter = new HashMap<>(); + + var gatewayRequestParams = + new HashMap< + String, EmbeddingGateway.ProviderEmbedRequest.EmbeddingRequest.ParameterValue>(); + if (vectorizeServiceParameter != null) { vectorizeServiceParameter.forEach( (key, value) -> { if (value instanceof String) - grpcVectorizeServiceParameter.put( + gatewayRequestParams.put( key, EmbeddingGateway.ProviderEmbedRequest.EmbeddingRequest.ParameterValue.newBuilder() .setStrValue((String) value) .build()); else if (value instanceof Integer) - grpcVectorizeServiceParameter.put( + gatewayRequestParams.put( key, EmbeddingGateway.ProviderEmbedRequest.EmbeddingRequest.ParameterValue.newBuilder() .setIntValue((Integer) value) .build()); else if (value instanceof Float) - grpcVectorizeServiceParameter.put( + gatewayRequestParams.put( key, EmbeddingGateway.ProviderEmbedRequest.EmbeddingRequest.ParameterValue.newBuilder() .setFloatValue((Float) value) .build()); else if (value instanceof Boolean) - grpcVectorizeServiceParameter.put( + gatewayRequestParams.put( key, EmbeddingGateway.ProviderEmbedRequest.EmbeddingRequest.ParameterValue.newBuilder() .setBoolValue((Boolean) value) .build()); }); } - EmbeddingGateway.ProviderEmbedRequest.EmbeddingRequest embeddingRequest = + + var 
gatewayEmbedding = EmbeddingGateway.ProviderEmbedRequest.EmbeddingRequest.newBuilder() .setModelName(modelName) .setDimensions(dimension) .setCommandName(commandName) - .putAllParameters(grpcVectorizeServiceParameter) + .putAllParameters(gatewayRequestParams) .setInputType( embeddingRequestType == EmbeddingRequestType.INDEX ? EmbeddingGateway.ProviderEmbedRequest.EmbeddingRequest.InputType.INDEX @@ -142,58 +155,59 @@ else if (value instanceof Boolean) .addAllInputs(texts) .build(); - final EmbeddingGateway.ProviderEmbedRequest.ProviderContext.Builder builder = + var contextBuilder = EmbeddingGateway.ProviderEmbedRequest.ProviderContext.newBuilder() - .setProviderName(provider) + .setProviderName(modelProvider.apiName()) .setTenantId(tenant.orElse(DEFAULT_TENANT_ID)); - // Add the value of `Token` in the header - builder.putAuthTokens(DATA_API_TOKEN, authToken.orElse("")); - // Add the value of `x-embedding-api-key` in the header - if (embeddingCredentials.apiKey().isPresent()) { - builder.putAuthTokens(EMBEDDING_API_KEY, embeddingCredentials.apiKey().get()); - } - // Add the value of `x-embedding-access-id` in the header - if (embeddingCredentials.accessId().isPresent()) { - builder.putAuthTokens(EMBEDDING_ACCESS_ID, embeddingCredentials.accessId().get()); - } - // Add the value of `x-embedding-secret-id` in the header - if (embeddingCredentials.secretId().isPresent()) { - builder.putAuthTokens(EMBEDDING_SECRET_ID, embeddingCredentials.secretId().get()); - } + + contextBuilder.putAuthTokens(DATA_API_TOKEN, authToken.orElse("")); + embeddingCredentials + .apiKey() + .ifPresent(v -> contextBuilder.putAuthTokens(EMBEDDING_API_KEY, v)); + embeddingCredentials + .accessId() + .ifPresent(v -> contextBuilder.putAuthTokens(EMBEDDING_ACCESS_ID, v)); + embeddingCredentials + .secretId() + .ifPresent(v -> contextBuilder.putAuthTokens(EMBEDDING_SECRET_ID, v)); + // Add the `authentication` (sync service key) in the createCollection command if (authentication != null) { - 
builder.putAllAuthTokens(authentication); + contextBuilder.putAllAuthTokens(authentication); } - EmbeddingGateway.ProviderEmbedRequest.ProviderContext providerContext = builder.build(); - EmbeddingGateway.ProviderEmbedRequest providerEmbedRequest = + var gatewayRequest = EmbeddingGateway.ProviderEmbedRequest.newBuilder() - .setEmbeddingRequest(embeddingRequest) - .setProviderContext(providerContext) + .setEmbeddingRequest(gatewayEmbedding) + .setProviderContext(contextBuilder.build()) .build(); + + // TODO: XXX Why is this error handling here not part of the uni pipeline? Uni embeddingResponse; try { - embeddingResponse = embeddingService.embed(providerEmbedRequest); + embeddingResponse = embeddingService.embed(gatewayRequest); } catch (StatusRuntimeException e) { if (e.getStatus().getCode().equals(Status.Code.DEADLINE_EXCEEDED)) { throw ErrorCodeV1.EMBEDDING_PROVIDER_TIMEOUT.toApiException(e, e.getMessage()); } throw e; } + return embeddingResponse .onItem() .transform( - resp -> { - if (resp.hasError()) { + gatewayResponse -> { + // TODO : move to V2 error + if (gatewayResponse.hasError()) { throw new JsonApiException( - ErrorCodeV1.valueOf(resp.getError().getErrorCode()), - resp.getError().getErrorMessage()); - } - if (resp.getEmbeddingsList() == null) { - return Response.of(batchId, Collections.emptyList()); + ErrorCodeV1.valueOf(gatewayResponse.getError().getErrorCode()), + gatewayResponse.getError().getErrorMessage()); } + // aaron - 10 June 2025 - previous code would silently swallow no data returned + // but grpc will make sure resp.getEmbeddingsList() is never null + final List vectors = - resp.getEmbeddingsList().stream() + gatewayResponse.getEmbeddingsList().stream() .map( data -> { float[] embedding = new float[data.getEmbeddingCount()]; @@ -203,7 +217,8 @@ else if (value instanceof Boolean) return embedding; }) .toList(); - return Response.of(batchId, vectors); + return new BatchedEmbeddingResponse( + batchId, vectors, 
createModelUsage(gatewayResponse.getModelUsage())); }); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AwsBedrockEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AwsBedrockEmbeddingProvider.java index 444eb30086..b2453b55bd 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AwsBedrockEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AwsBedrockEmbeddingProvider.java @@ -9,29 +9,30 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectReader; import com.fasterxml.jackson.databind.ObjectWriter; +import com.google.common.io.CountingOutputStream; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import jakarta.ws.rs.core.Response; import java.io.IOException; +import java.io.OutputStream; import java.util.List; import java.util.Map; -import java.util.concurrent.CompletableFuture; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; import software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeException; -import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; /** Provider implementation for AWS Bedrock. To start we support only Titan embedding models. 
*/ public class AwsBedrockEmbeddingProvider extends EmbeddingProvider { - private static final String providerId = ProviderConstants.BEDROCK; - private static final ObjectWriter ow = new ObjectMapper().writer(); - private static final ObjectReader or = new ObjectMapper().reader(); + private static final ObjectWriter OBJECT_WRITER = new ObjectMapper().writer(); + private static final ObjectReader OBJECT_READER = new ObjectMapper().reader(); public AwsBedrockEmbeddingProvider( EmbeddingProviderConfigStore.RequestProperties requestProperties, @@ -40,6 +41,7 @@ public AwsBedrockEmbeddingProvider( int dimension, Map vectorizeServiceParameters) { super( + ModelProvider.BEDROCK, requestProperties, baseUrl, modelName, @@ -48,125 +50,164 @@ public AwsBedrockEmbeddingProvider( } @Override - public Uni vectorize( + protected String errorMessageJsonPtr() { + // not used in this provider, has custom error handling + return ""; + } + + @Override + public Uni vectorize( int batchId, List texts, EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { + + // the config shoudl mean we only do a batch on 1, sanity checking + if (texts.size() != 1) { + throw new IllegalArgumentException( + "AWS Bedrock embedding provider only supports a single text input per request, but received: " + + texts.size()); + } + + // TODO: move to V2 errors if (embeddingCredentials.accessId().isEmpty() && embeddingCredentials.secretId().isEmpty()) { throw ErrorCodeV1.EMBEDDING_PROVIDER_AUTHENTICATION_KEYS_NOT_PROVIDED.toApiException( "Both '%s' and '%s' are missing in the header for provider '%s'", EMBEDDING_AUTHENTICATION_ACCESS_ID_HEADER_NAME, EMBEDDING_AUTHENTICATION_SECRET_ID_HEADER_NAME, - providerId); + modelProvider().apiName()); } else if (embeddingCredentials.accessId().isEmpty()) { throw ErrorCodeV1.EMBEDDING_PROVIDER_AUTHENTICATION_KEYS_NOT_PROVIDED.toApiException( "'%s' is missing in the header for provider '%s'", - EMBEDDING_AUTHENTICATION_ACCESS_ID_HEADER_NAME, 
providerId); + EMBEDDING_AUTHENTICATION_ACCESS_ID_HEADER_NAME, modelProvider().apiName()); } else if (embeddingCredentials.secretId().isEmpty()) { throw ErrorCodeV1.EMBEDDING_PROVIDER_AUTHENTICATION_KEYS_NOT_PROVIDED.toApiException( "'%s' is missing in the header for provider '%s'", - EMBEDDING_AUTHENTICATION_SECRET_ID_HEADER_NAME, providerId); + EMBEDDING_AUTHENTICATION_SECRET_ID_HEADER_NAME, modelProvider().apiName()); } - AwsBasicCredentials awsCreds = + var awsCreds = AwsBasicCredentials.create( embeddingCredentials.accessId().get(), embeddingCredentials.secretId().get()); - BedrockRuntimeAsyncClient client = + try (var bedrockClient = BedrockRuntimeAsyncClient.builder() .credentialsProvider(StaticCredentialsProvider.create(awsCreds)) .region(Region.of(vectorizeServiceParameters.get("region").toString())) - .build(); - final CompletableFuture invokeModelResponseCompletableFuture = - client.invokeModel( - request -> { - final byte[] inputData; - try { - inputData = ow.writeValueAsBytes(new EmbeddingRequest(texts.get(0), dimension)); - request.body(SdkBytes.fromByteArray(inputData)).modelId(modelName); - } catch (JsonProcessingException e) { - throw ErrorCodeV1.EMBEDDING_REQUEST_ENCODING_ERROR.toApiException(); - } - }); - - final CompletableFuture responseCompletableFuture = - invokeModelResponseCompletableFuture.thenApply( - res -> { - try { - EmbeddingResponse response = - or.readValue(res.body().asInputStream(), EmbeddingResponse.class); - List vectors = List.of(response.embedding); - return Response.of(batchId, vectors); - } catch (IOException e) { - throw ErrorCodeV1.EMBEDDING_RESPONSE_DECODING_ERROR.toApiException(); - } - }); - - return Uni.createFrom() - .completionStage(responseCompletableFuture) - .onFailure(BedrockRuntimeException.class) - .transform( - error -> { - BedrockRuntimeException bedrockRuntimeException = (BedrockRuntimeException) error; - // Status code == 408 and 504 for timeout - if (bedrockRuntimeException.statusCode() - == 
jakarta.ws.rs.core.Response.Status.REQUEST_TIMEOUT.getStatusCode() - || bedrockRuntimeException.statusCode() - == jakarta.ws.rs.core.Response.Status.GATEWAY_TIMEOUT.getStatusCode()) { - return ErrorCodeV1.EMBEDDING_PROVIDER_TIMEOUT.toApiException( - "Provider: %s; HTTP Status: %s; Error Message: %s", - providerId, - bedrockRuntimeException.statusCode(), - bedrockRuntimeException.getMessage()); - } - - // Status code == 429 - if (bedrockRuntimeException.statusCode() - == jakarta.ws.rs.core.Response.Status.TOO_MANY_REQUESTS.getStatusCode()) { - return ErrorCodeV1.EMBEDDING_PROVIDER_RATE_LIMITED.toApiException( - "Provider: %s; HTTP Status: %s; Error Message: %s", - providerId, - bedrockRuntimeException.statusCode(), - bedrockRuntimeException.getMessage()); - } - - // Status code in 4XX other than 429 - if (bedrockRuntimeException.statusCode() > 400 - && bedrockRuntimeException.statusCode() < 500) { - return ErrorCodeV1.EMBEDDING_PROVIDER_CLIENT_ERROR.toApiException( - "Provider: %s; HTTP Status: %s; Error Message: %s", - providerId, - bedrockRuntimeException.statusCode(), - bedrockRuntimeException.getMessage()); - } - - // Status code in 5XX - if (bedrockRuntimeException.statusCode() >= 500) { - return ErrorCodeV1.EMBEDDING_PROVIDER_SERVER_ERROR.toApiException( - "Provider: %s; HTTP Status: %s; Error Message: %s", - providerId, - bedrockRuntimeException.statusCode(), - bedrockRuntimeException.getMessage()); - } - - // All other errors, Should never happen as all errors are covered above - return ErrorCodeV1.EMBEDDING_PROVIDER_UNEXPECTED_RESPONSE.toApiException( - "Provider: %s; HTTP Status: %s; Error Message: %s", - providerId, - bedrockRuntimeException.statusCode(), - bedrockRuntimeException.getMessage()); - }); + .build()) { + + long callStartNano = System.nanoTime(); + + var bytesUsageTracker = new ByteUsageTracker(); + var bedrockFuture = + bedrockClient + .invokeModel( + requestBuilder -> { + try { + var inputData = + OBJECT_WRITER.writeValueAsBytes( + new 
AwsBedrockEmbeddingRequest(texts.getFirst(), dimension)); + bytesUsageTracker.requestBytes = inputData.length; + requestBuilder.body(SdkBytes.fromByteArray(inputData)).modelId(modelName()); + } catch (JsonProcessingException e) { + throw ErrorCodeV1.EMBEDDING_REQUEST_ENCODING_ERROR.toApiException(); + } + }) + .thenApply( + rawResponse -> { + try { + // aws docs say do not need to close the stream + var inputStream = rawResponse.body().asInputStream(); + var bedrockResponse = + OBJECT_READER.readValue(inputStream, AwsBedrockEmbeddingResponse.class); + long callDurationNano = System.nanoTime() - callStartNano; + + try (var countingOut = + new CountingOutputStream(OutputStream.nullOutputStream())) { + inputStream.transferTo(countingOut); + long responseSize = countingOut.getCount(); + bytesUsageTracker.responseBytes = + responseSize > Integer.MAX_VALUE + ? Integer.MAX_VALUE + : (int) responseSize; + } + + var modelUsage = + createModelUsage( + embeddingCredentials.tenantId(), + ModelInputType.fromEmbeddingRequestType(embeddingRequestType), + bedrockResponse.inputTextTokenCount(), + bedrockResponse.inputTextTokenCount(), + bytesUsageTracker.requestBytes, + bytesUsageTracker.responseBytes, + callDurationNano); + + return new BatchedEmbeddingResponse( + batchId, List.of(bedrockResponse.embedding), modelUsage); + + } catch (IOException e) { + throw ErrorCodeV1.EMBEDDING_RESPONSE_DECODING_ERROR.toApiException(); + } + }); + + return Uni.createFrom() + .completionStage(bedrockFuture) + .onFailure(BedrockRuntimeException.class) + .transform(throwable -> mapBedrockException((BedrockRuntimeException) throwable)); + } } - private record EmbeddingRequest( - String inputText, @JsonInclude(value = JsonInclude.Include.NON_DEFAULT) int dimensions) {} + private Throwable mapBedrockException(BedrockRuntimeException bedrockException) { - @JsonIgnoreProperties(ignoreUnknown = true) // ignore possible extra fields without error - private record EmbeddingResponse(float[] embedding, int 
inputTextTokenCount) {} + if (bedrockException.statusCode() == Response.Status.REQUEST_TIMEOUT.getStatusCode() + || bedrockException.statusCode() == Response.Status.GATEWAY_TIMEOUT.getStatusCode()) { + return ErrorCodeV1.EMBEDDING_PROVIDER_TIMEOUT.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), bedrockException.statusCode(), bedrockException.getMessage()); + } - @Override - public int maxBatchSize() { - return requestProperties.maxBatchSize(); + if (bedrockException.statusCode() == Response.Status.TOO_MANY_REQUESTS.getStatusCode()) { + return ErrorCodeV1.EMBEDDING_PROVIDER_RATE_LIMITED.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), bedrockException.statusCode(), bedrockException.getMessage()); + } + + if (bedrockException.statusCode() > 400 && bedrockException.statusCode() < 500) { + return ErrorCodeV1.EMBEDDING_PROVIDER_CLIENT_ERROR.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), bedrockException.statusCode(), bedrockException.getMessage()); + } + + if (bedrockException.statusCode() >= 500) { + return ErrorCodeV1.EMBEDDING_PROVIDER_SERVER_ERROR.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), bedrockException.statusCode(), bedrockException.getMessage()); + } + + // All other errors, Should never happen as all errors are covered above + return ErrorCodeV1.EMBEDDING_PROVIDER_UNEXPECTED_RESPONSE.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), bedrockException.statusCode(), bedrockException.getMessage()); } + + private static class ByteUsageTracker { + int requestBytes = 0; + int responseBytes = 0; + } + + /** + * Request structure of the AWS Bedrock REST service. + * + *

.. + */ + public record AwsBedrockEmbeddingRequest( + String inputText, @JsonInclude(value = JsonInclude.Include.NON_DEFAULT) int dimensions) {} + + /** + * Response structure of the AWS Bedrock REST service. + * + *

.. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + private record AwsBedrockEmbeddingResponse(float[] embedding, int inputTextTokenCount) {} } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AzureOpenAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AzureOpenAIEmbeddingProvider.java index 49ddb043e6..ae6c28d917 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AzureOpenAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AzureOpenAIEmbeddingProvider.java @@ -2,22 +2,21 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.databind.JsonNode; -import io.quarkus.rest.client.reactive.ClientExceptionMapper; import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; -import io.stargate.sgv2.jsonapi.service.embedding.operation.error.EmbeddingProviderErrorMapper; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -31,8 +30,8 @@ * details of REST API being called. 
*/ public class AzureOpenAIEmbeddingProvider extends EmbeddingProvider { - private static final String providerId = ProviderConstants.AZURE_OPENAI; - private final OpenAIEmbeddingProviderClient openAIEmbeddingProviderClient; + + private final AzureOpenAIEmbeddingProviderClient azureClient; public AzureOpenAIEmbeddingProvider( EmbeddingProviderConfigStore.RequestProperties requestProperties, @@ -42,6 +41,7 @@ public AzureOpenAIEmbeddingProvider( Map vectorizeServiceParameters) { // One special case: legacy "ada-002" model does not accept "dimension" parameter super( + ModelProvider.AZURE_OPENAI, requestProperties, baseUrl, modelName, @@ -49,102 +49,121 @@ public AzureOpenAIEmbeddingProvider( vectorizeServiceParameters); String actualUrl = replaceParameters(baseUrl, vectorizeServiceParameters); - openAIEmbeddingProviderClient = + azureClient = QuarkusRestClientBuilder.newBuilder() .baseUri(URI.create(actualUrl)) .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) - .build(OpenAIEmbeddingProviderClient.class); + .build(AzureOpenAIEmbeddingProviderClient.class); } - @RegisterRestClient - @RegisterProvider(EmbeddingProviderResponseValidation.class) - public interface OpenAIEmbeddingProviderClient { - @POST - // no path specified, as it is already included in the baseUri - @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) - Uni embed( - // API keys as "api-key", MS Entra as "Authorization: Bearer [token] - @HeaderParam("api-key") String accessToken, EmbeddingRequest request); - - @ClientExceptionMapper - static RuntimeException mapException(jakarta.ws.rs.core.Response response) { - String errorMessage = getErrorMessage(response); - return EmbeddingProviderErrorMapper.mapToAPIException(providerId, response, errorMessage); - } - - /** - * Extract the error message from the response body. The example response body is: - * - *

-     * {
-     *   "error": {
-     *     "code": "401",
-     *     "message": "Access denied due to invalid subscription key or wrong API endpoint. Make sure to provide a valid key for an active subscription and use a correct regional API endpoint for your resource."
-     *   }
-     * }
-     * 
- * - * @param response The response body as a String. - * @return The error message extracted from the response body. - */ - private static String getErrorMessage(jakarta.ws.rs.core.Response response) { - // Get the whole response body - JsonNode rootNode = response.readEntity(JsonNode.class); - // Log the response body - logger.error( - "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); - // Extract the "message" node from the "error" node - JsonNode messageNode = rootNode.at("/error/message"); - // Return the text of the "message" node, or the whole response body if it is missing - return messageNode.isMissingNode() ? rootNode.toString() : messageNode.toString(); - } - } - - private record EmbeddingRequest( - String[] input, - String model, - @JsonInclude(value = JsonInclude.Include.NON_DEFAULT) int dimensions) {} - - @JsonIgnoreProperties(ignoreUnknown = true) // ignore possible extra fields without error - private record EmbeddingResponse(String object, Data[] data, String model, Usage usage) { - @JsonIgnoreProperties(ignoreUnknown = true) - private record Data(String object, int index, float[] embedding) {} - - @JsonIgnoreProperties(ignoreUnknown = true) - private record Usage(int prompt_tokens, int total_tokens) {} + /** + * The example response body is: + * + *
+   * {
+   *   "error": {
+   *     "code": "401",
+   *     "message": "Access denied due to invalid subscription key or wrong API endpoint. Make sure to provide a valid key for an active subscription and use a correct regional API endpoint for your resource."
+   *   }
+   * }
+   * 
+ */ + @Override + protected String errorMessageJsonPtr() { + return "/error/message"; } @Override - public Uni vectorize( + public Uni vectorize( int batchId, List texts, EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { - checkEmbeddingApiKeyHeader(providerId, embeddingCredentials.apiKey()); - String[] textArray = new String[texts.size()]; - EmbeddingRequest request = new EmbeddingRequest(texts.toArray(textArray), modelName, dimension); - // NOTE: NO "Bearer " prefix with API key for Azure OpenAI - Uni response = - applyRetry( - openAIEmbeddingProviderClient.embed(embeddingCredentials.apiKey().get(), request)); + checkEmbeddingApiKeyHeader(embeddingCredentials.apiKey()); + var azureRequest = + new AzureOpenAIEmbeddingRequest( + texts.toArray(new String[texts.size()]), modelName(), dimension); - return response + // TODO: V2 error + // aaron 8 June 2025 - old code had NO comment to explain what happens if the API key is empty. + // NOTE: NO "Bearer " prefix with API key for Azure + var accessToken = embeddingCredentials.apiKey().get(); + + long callStartNano = System.nanoTime(); + return retryHTTPCall(azureClient.embed(accessToken, azureRequest)) .onItem() .transform( - resp -> { - if (resp.data() == null) { - return Response.of(batchId, Collections.emptyList()); + jakartaResponse -> { + var azureResponse = jakartaResponse.readEntity(AzureOpenAIEmbeddingResponse.class); + long callDurationNano = System.nanoTime() - callStartNano; + + // aaron - 10 June 2025 - previous code would silently swallow no data returned + // and return an empty result. If we made a request we should get a response. 
+ if (azureResponse.data() == null) { + throw new IllegalStateException( + "ModelProvider %s returned empty data for model %s" + .formatted(modelProvider(), modelName())); } - Arrays.sort(resp.data(), (a, b) -> a.index() - b.index()); + + Arrays.sort(azureResponse.data(), (a, b) -> a.index() - b.index()); List vectors = - Arrays.stream(resp.data()).map(data -> data.embedding()).toList(); - return Response.of(batchId, vectors); + Arrays.stream(azureResponse.data()) + .map(AzureOpenAIEmbeddingResponse.Data::embedding) + .toList(); + + var modelUsage = + createModelUsage( + embeddingCredentials.tenantId(), + ModelInputType.fromEmbeddingRequestType(embeddingRequestType), + azureResponse.usage().prompt_tokens(), + azureResponse.usage().total_tokens(), + jakartaResponse, + callDurationNano); + return new BatchedEmbeddingResponse(batchId, vectors, modelUsage); }); } - @Override - public int maxBatchSize() { - return requestProperties.maxBatchSize(); + /** + * REST client interface for the Azure Open AI Embedding Service. + * + *

.. + */ + @RegisterRestClient + @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(ProviderHttpInterceptor.class) + public interface AzureOpenAIEmbeddingProviderClient { + // no path specified, as it is already included in the baseUri + @POST + @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) + Uni embed( + // API keys as "api-key", MS Entra as "Authorization: Bearer [token] + @HeaderParam("api-key") String accessToken, AzureOpenAIEmbeddingRequest request); + } + + /** + * Request structure of the Azure Open AI REST service. + * + *

.. + */ + public record AzureOpenAIEmbeddingRequest( + String[] input, + String model, + @JsonInclude(value = JsonInclude.Include.NON_DEFAULT) int dimensions) {} + + /** + * Response structure of the Azure Open AI REST service. + * + *

.. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + private record AzureOpenAIEmbeddingResponse( + String object, Data[] data, String model, Usage usage) { + + @JsonIgnoreProperties(ignoreUnknown = true) + private record Data(String object, int index, float[] embedding) {} + + @JsonIgnoreProperties(ignoreUnknown = true) + private record Usage(int prompt_tokens, int total_tokens) {} } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/CohereEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/CohereEmbeddingProvider.java index 193678a583..325ba8a803 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/CohereEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/CohereEmbeddingProvider.java @@ -1,23 +1,24 @@ package io.stargate.sgv2.jsonapi.service.embedding.operation; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.JsonNode; -import io.quarkus.rest.client.reactive.ClientExceptionMapper; import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; -import io.stargate.sgv2.jsonapi.service.embedding.operation.error.EmbeddingProviderErrorMapper; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; import jakarta.ws.rs.HeaderParam; import 
jakarta.ws.rs.POST; import jakarta.ws.rs.Path; import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; import java.net.URI; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -30,8 +31,8 @@ * of chosen Cohere model. */ public class CohereEmbeddingProvider extends EmbeddingProvider { - private static final String providerId = ProviderConstants.COHERE; - private final CohereEmbeddingProviderClient cohereEmbeddingProviderClient; + + private final CohereEmbeddingProviderClient cohereClient; public CohereEmbeddingProvider( EmbeddingProviderConfigStore.RequestProperties requestProperties, @@ -39,124 +40,149 @@ public CohereEmbeddingProvider( String modelName, int dimension, Map vectorizeServiceParameters) { - super(requestProperties, baseUrl, modelName, dimension, vectorizeServiceParameters); - - cohereEmbeddingProviderClient = + super( + ModelProvider.COHERE, + requestProperties, + baseUrl, + modelName, + dimension, + vectorizeServiceParameters); + + cohereClient = QuarkusRestClientBuilder.newBuilder() .baseUri(URI.create(baseUrl)) .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(CohereEmbeddingProviderClient.class); } - @RegisterRestClient - @RegisterProvider(EmbeddingProviderResponseValidation.class) - public interface CohereEmbeddingProviderClient { - @POST - @Path("/embed") - @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) - Uni embed( - @HeaderParam("Authorization") String accessToken, EmbeddingRequest request); - - @ClientExceptionMapper - static RuntimeException mapException(jakarta.ws.rs.core.Response response) { - String errorMessage = getErrorMessage(response); - return EmbeddingProviderErrorMapper.mapToAPIException(providerId, response, errorMessage); - } - - /** - * Extract the error message from the response body. The example response body is: - * - *

-     * {
-     *   "message": "invalid api token"
-     * }
-     *
-     * 429 response body:
-     * {
-     *   "data": "string"
-     * }
-     * 
- * - * @param response The response body as a String. - * @return The error message extracted from the response body. - */ - private static String getErrorMessage(jakarta.ws.rs.core.Response response) { - // Get the whole response body - JsonNode rootNode = response.readEntity(JsonNode.class); - // Log the response body - logger.error( - "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); - // Check if the root node contains a "message" field - JsonNode messageNode = rootNode.path("message"); - if (!messageNode.isMissingNode()) { - return messageNode.toString(); - } - // Check if the root node contains a "data" field - JsonNode dataNode = rootNode.path("data"); - if (!dataNode.isMissingNode()) { - return dataNode.toString(); - } - // Return the whole response body if no message or data field is found - return rootNode.toString(); - } + @Override + protected String errorMessageJsonPtr() { + // overriding the function that calls this + return ""; } - private record EmbeddingRequest(String[] texts, String model, String input_type) {} - - // @JsonIgnoreProperties({"id", "texts", "meta", "response_type"}) - @JsonIgnoreProperties(ignoreUnknown = true) // ignore possible extra fields without error - private static class EmbeddingResponse { - - protected EmbeddingResponse() {} - - private List embeddings; + /** + * The example response body is: + * + *
+   * {
+   *   "message": "invalid api token"
+   * }
+   *
+   * 429 response body:
+   * {
+   *   "data": "string"
+   * }
+   */
+  @Override
+  protected String responseErrorMessage(JsonNode rootNode) {
 
-    public List getEmbeddings() {
-      return embeddings;
+    JsonNode messageNode = rootNode.path("message");
+    if (!messageNode.isMissingNode()) {
+      return messageNode.toString();
     }
 
-    public void setEmbeddings(List embeddings) {
-      this.embeddings = embeddings;
+    JsonNode dataNode = rootNode.path("data");
+    if (!dataNode.isMissingNode()) {
+      return dataNode.toString();
     }
-  }
 
-  // Input type to be used for vector search should "search_query"
-  private static final String SEARCH_QUERY = "search_query";
-  private static final String SEARCH_DOCUMENT = "search_document";
+    // Return the whole response body if no message or data field is found
+    return rootNode.toString();
+  }
 
   @Override
-  public Uni vectorize(
+  public Uni vectorize(
       int batchId,
       List texts,
       EmbeddingCredentials embeddingCredentials,
       EmbeddingRequestType embeddingRequestType) {
-    checkEmbeddingApiKeyHeader(providerId, embeddingCredentials.apiKey());
 
-    String[] textArray = new String[texts.size()];
-    String input_type =
-        embeddingRequestType == EmbeddingRequestType.INDEX ? SEARCH_DOCUMENT : SEARCH_QUERY;
-    EmbeddingRequest request =
-        new EmbeddingRequest(texts.toArray(textArray), modelName, input_type);
+    checkEmbeddingApiKeyHeader(embeddingCredentials.apiKey());
+
+    // Input type for vector search should be "search_query"; indexing uses "search_document"
+    var input_type =
+        embeddingRequestType == EmbeddingRequestType.INDEX ? "search_document" : "search_query";
+    var cohereRequest =
+        new CohereEmbeddingRequest(
+            texts.toArray(new String[texts.size()]), modelName(), input_type);
 
-    Uni response =
-        applyRetry(
-            cohereEmbeddingProviderClient.embed(
-                HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(),
-                request));
+    // TODO: V2 error
+    // aaron 8 June 2025 - old code had NO comment to explain what happens if the API key is empty.
+    var accessToken = HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get();
 
-    return response
+    long callStartNano = System.nanoTime();
+
+    return retryHTTPCall(cohereClient.embed(accessToken, cohereRequest))
         .onItem()
         .transform(
-            resp -> {
-              if (resp.getEmbeddings() == null) {
-                return Response.of(batchId, Collections.emptyList());
+            jakartaResponse -> {
+              var cohereResponse = jakartaResponse.readEntity(CohereEmbeddingResponse.class);
+              long callDurationNano = System.nanoTime() - callStartNano;
+
+              // aaron - 10 June 2025 - previous code silently swallowed a missing-data response
+              // and returned an empty result. If we made a request we should get a response.
+              if (cohereResponse.embeddings() == null) {
+                throw new IllegalStateException(
+                    "ModelProvider %s returned empty data for model %s"
+                        .formatted(modelProvider(), modelName()));
               }
-              return Response.of(batchId, resp.getEmbeddings());
+
+              var modelUsage =
+                  createModelUsage(
+                      embeddingCredentials.tenantId(),
+                      ModelInputType.fromEmbeddingRequestType(embeddingRequestType),
+                      cohereResponse.meta().billed_units().input_tokens(),
+                      cohereResponse.meta().billed_units().input_tokens(),
+                      jakartaResponse,
+                      callDurationNano);
+              return new BatchedEmbeddingResponse(
+                  batchId, cohereResponse.embeddings().values(), modelUsage);
             });
   }
 
-  @Override
-  public int maxBatchSize() {
-    return requestProperties.maxBatchSize();
+  /**
+   * REST client interface for the Cohere Embedding Service.
+   *
+   * 

.. + */ + @RegisterRestClient + @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(ProviderHttpInterceptor.class) + public interface CohereEmbeddingProviderClient { + @POST + @Path("/embed") + @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) + Uni embed( + @HeaderParam("Authorization") String accessToken, CohereEmbeddingRequest request); + } + + /** + * Request structure of the Cohere REST service. + * + *

.. + */ + public record CohereEmbeddingRequest(String[] texts, String model, String input_type) {} + + /** + * Response structure of the Cohere REST service. + * + *

aaron - 9 June 2025, change from class to record, check git if this breaks. + * https://docs.cohere.com/reference/embed#response + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public record CohereEmbeddingResponse( + String id, List texts, Embeddings embeddings, Meta meta) { + @JsonIgnoreProperties(ignoreUnknown = true) + public record Embeddings(@JsonProperty("float") List values) {} + + @JsonIgnoreProperties(ignoreUnknown = true) + public record Meta(ApiVersion api_version, BilledUnits billed_units, List warnings) { + @JsonIgnoreProperties(ignoreUnknown = true) + public record ApiVersion(String version, boolean is_experimental) {} + + @JsonIgnoreProperties(ignoreUnknown = true) + public record BilledUnits(int input_tokens) {} + } } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProvider.java index 75e571b566..4bbbf559d7 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProvider.java @@ -2,73 +2,56 @@ import static io.stargate.sgv2.jsonapi.config.constants.HttpConstants.EMBEDDING_AUTHENTICATION_TOKEN_HEADER_NAME; import static io.stargate.sgv2.jsonapi.exception.ErrorCodeV1.EMBEDDING_PROVIDER_API_KEY_MISSING; +import static jakarta.ws.rs.core.Response.Status.Family.CLIENT_ERROR; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.exception.JsonApiException; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import io.stargate.sgv2.jsonapi.service.provider.ModelType; +import io.stargate.sgv2.jsonapi.service.provider.ModelUsage; +import 
io.stargate.sgv2.jsonapi.service.provider.ProviderBase; import io.stargate.sgv2.jsonapi.util.recordable.Recordable; +import jakarta.ws.rs.core.Response; import java.time.Duration; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.concurrent.TimeoutException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * Interface that accepts a list of texts that needs to be vectorized and returns embeddings based - * of chosen model. - */ -public abstract class EmbeddingProvider { - protected static final Logger logger = LoggerFactory.getLogger(EmbeddingProvider.class); +/** TODO */ +public abstract class EmbeddingProvider extends ProviderBase { + + protected static final Logger LOGGER = LoggerFactory.getLogger(EmbeddingProvider.class); + protected final EmbeddingProviderConfigStore.RequestProperties requestProperties; protected final String baseUrl; - protected final String modelName; protected final int dimension; protected final Map vectorizeServiceParameters; - /** Default constructor */ - protected EmbeddingProvider() { - this(null, null, null, 0, null); - } + protected final Duration initialBackOffDuration; + protected final Duration maxBackOffDuration; - /** Constructs an EmbeddingProvider with the specified configuration. */ protected EmbeddingProvider( + ModelProvider modelProvider, EmbeddingProviderConfigStore.RequestProperties requestProperties, String baseUrl, String modelName, int dimension, Map vectorizeServiceParameters) { + super(modelProvider, ModelType.EMBEDDING, modelName); + this.requestProperties = requestProperties; this.baseUrl = baseUrl; - this.modelName = modelName; + this.dimension = dimension; this.vectorizeServiceParameters = vectorizeServiceParameters; - } - /** - * Applies a retry mechanism with backoff and jitter to the Uni returned by the embed() method, - * which makes an HTTP request to a third-party service. - * - * @param The type of the item emitted by the Uni. 
- * @param uni The Uni to which the retry mechanism should be applied. - * @return A Uni that will retry on the specified failures with the configured backoff and jitter. - */ - protected Uni applyRetry(Uni uni) { - return uni.onFailure( - throwable -> - (throwable.getCause() != null - && throwable.getCause() instanceof JsonApiException jae - && jae.getErrorCode() == ErrorCodeV1.EMBEDDING_PROVIDER_TIMEOUT) - || throwable instanceof TimeoutException) - .retry() - .withBackOff( - Duration.ofMillis(requestProperties.initialBackOffMillis()), - Duration.ofMillis(requestProperties.maxBackOffMillis())) - .withJitter(requestProperties.jitter()) - .atMost(requestProperties.atMostRetries()); + this.initialBackOffDuration = Duration.ofMillis(requestProperties.initialBackOffMillis()); + this.maxBackOffDuration = Duration.ofMillis(requestProperties.maxBackOffMillis()); } /** @@ -79,7 +62,7 @@ protected Uni applyRetry(Uni uni) { * @param embeddingRequestType Type of request (INDEX or SEARCH) * @return VectorResponse */ - public abstract Uni vectorize( + public abstract Uni vectorize( int batchId, List texts, EmbeddingCredentials embeddingCredentials, @@ -90,7 +73,9 @@ public abstract Uni vectorize( * * @return */ - public abstract int maxBatchSize(); + public int maxBatchSize() { + return requestProperties.maxBatchSize(); + } /** * Helper method that has logic wrt whether OpenAI (azure or regular) accepts {@code "dimensions"} @@ -128,9 +113,9 @@ protected static boolean acceptsTitanAIDimensions(String modelName) { } /** - * Helper method to replace parameters in a messageTemplate string with values from a map: - * placeholders are of form {@code {parameterName}} and matching value to look for in the map is - * String {@code "parameterName"}. + * Replace parameters in a messageTemplate string with values from a map: placeholders are of form + * {@code {parameterName}} and matching value to look for in the map is String {@code + * "parameterName"}. 
* * @param template Template with placeholders to replace * @param parameters Parameters to replace in the messageTemplate @@ -161,13 +146,82 @@ protected String replaceParameters(String template, Map paramete return baseUrl.toString(); } - /** Helper method to check if the API key is present in the header */ - protected void checkEmbeddingApiKeyHeader(String providerId, Optional apiKey) { + /** Check if the API key is present in the header */ + protected void checkEmbeddingApiKeyHeader(Optional apiKey) { + if (apiKey.isEmpty()) { throw EMBEDDING_PROVIDER_API_KEY_MISSING.toApiException( "header value `%s` is missing for embedding provider: %s", - EMBEDDING_AUTHENTICATION_TOKEN_HEADER_NAME, providerId); + EMBEDDING_AUTHENTICATION_TOKEN_HEADER_NAME, modelProvider().apiName()); + } + } + + @Override + protected Duration initialBackOffDuration() { + return initialBackOffDuration; + } + + @Override + protected Duration maxBackOffDuration() { + return maxBackOffDuration; + } + + @Override + protected double jitter() { + return requestProperties.jitter(); + } + + @Override + protected int atMostRetries() { + return requestProperties.atMostRetries(); + } + + @Override + protected boolean decideRetry(Throwable throwable) { + + var retry = + (throwable.getCause() instanceof JsonApiException jae + && jae.getErrorCode() == ErrorCodeV1.EMBEDDING_PROVIDER_TIMEOUT); + + return retry || super.decideRetry(throwable); + } + + /** Maps an HTTP error response to a V1 JsonApiException chosen by the HTTP status code. */ + @Override + protected RuntimeException mapHTTPError(Response jakartaResponse, String errorMessage) { + + if (jakartaResponse.getStatus() == Response.Status.REQUEST_TIMEOUT.getStatusCode() + || jakartaResponse.getStatus() == Response.Status.GATEWAY_TIMEOUT.getStatusCode()) { + return ErrorCodeV1.EMBEDDING_PROVIDER_TIMEOUT.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), jakartaResponse.getStatus(), errorMessage); + } + + // Status code == 429 + if 
(jakartaResponse.getStatus() == Response.Status.TOO_MANY_REQUESTS.getStatusCode()) { + return ErrorCodeV1.EMBEDDING_PROVIDER_RATE_LIMITED.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), jakartaResponse.getStatus(), errorMessage); + } + + // Status code in 4XX other than 408 and 429 + if (jakartaResponse.getStatusInfo().getFamily() == CLIENT_ERROR) { + return ErrorCodeV1.EMBEDDING_PROVIDER_CLIENT_ERROR.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), jakartaResponse.getStatus(), errorMessage); + } + + // Status code in 5XX other than 504 + if (jakartaResponse.getStatusInfo().getFamily() == Response.Status.Family.SERVER_ERROR) { + return ErrorCodeV1.EMBEDDING_PROVIDER_SERVER_ERROR.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), jakartaResponse.getStatus(), errorMessage); } + + // All other errors, Should never happen as all errors are covered above + return ErrorCodeV1.EMBEDDING_PROVIDER_UNEXPECTED_RESPONSE.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), jakartaResponse.getStatus(), errorMessage); } /** @@ -176,18 +230,23 @@ protected void checkEmbeddingApiKeyHeader(String providerId, Optional ap * @param batchId - Sequence number for the batch to order the vectors. * @param embeddings - Embedding vectors for the given text inputs. 
*/ - public record Response(int batchId, List embeddings) implements Recordable { + public record BatchedEmbeddingResponse( + int batchId, List embeddings, ModelUsage modelUsage) implements Recordable { - public static Response of(int batchId, List embeddings) { - return new Response(batchId, embeddings); + public static BatchedEmbeddingResponse empty(int batchId) { + return new BatchedEmbeddingResponse(batchId, List.of(), ModelUsage.EMPTY); } @Override public DataRecorder recordTo(DataRecorder dataRecorder) { - return dataRecorder.append("batchId", batchId).append("embeddings", embeddings); + return dataRecorder + .append("batchId", batchId) + .append("embeddings", embeddings) + .append("modelUsage", modelUsage); } } + // TODO: remove and use the general ModelInputType enum public enum EmbeddingRequestType { /** This is used when vectorizing data in write operation for indexing */ INDEX, diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderFactory.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderFactory.java index 13c0a31297..0d79cb51cb 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderFactory.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderFactory.java @@ -5,8 +5,8 @@ import io.stargate.sgv2.jsonapi.config.OperationsConfig; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; import io.stargate.sgv2.jsonapi.service.embedding.gateway.EmbeddingGatewayClient; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import jakarta.enterprise.context.ApplicationScoped; import jakarta.enterprise.inject.Instance; import jakarta.inject.Inject; @@ -15,13 +15,15 @@ @ApplicationScoped public class 
EmbeddingProviderFactory { + @Inject Instance embeddingProviderConfigStore; - @Inject OperationsConfig config; + @Inject OperationsConfig operationsConfig; @GrpcClient("embedding") EmbeddingService embeddingService; + @FunctionalInterface interface ProviderConstructor { EmbeddingProvider create( EmbeddingProviderConfigStore.RequestProperties requestProperties, @@ -31,22 +33,22 @@ EmbeddingProvider create( Map vectorizeServiceParameter); } - private static final Map providersMap = + private static final Map EMBEDDING_PROVIDER_CTORS = // alphabetic order Map.ofEntries( - Map.entry(ProviderConstants.AZURE_OPENAI, AzureOpenAIEmbeddingProvider::new), - Map.entry(ProviderConstants.COHERE, CohereEmbeddingProvider::new), - Map.entry(ProviderConstants.HUGGINGFACE, HuggingFaceEmbeddingProvider::new), + Map.entry(ModelProvider.AZURE_OPENAI, AzureOpenAIEmbeddingProvider::new), + Map.entry(ModelProvider.BEDROCK, AwsBedrockEmbeddingProvider::new), + Map.entry(ModelProvider.COHERE, CohereEmbeddingProvider::new), + Map.entry(ModelProvider.HUGGINGFACE, HuggingFaceEmbeddingProvider::new), Map.entry( - ProviderConstants.HUGGINGFACE_DEDICATED, HuggingFaceDedicatedEmbeddingProvider::new), - Map.entry(ProviderConstants.JINA_AI, JinaAIEmbeddingProvider::new), - Map.entry(ProviderConstants.MISTRAL, MistralEmbeddingProvider::new), - Map.entry(ProviderConstants.NVIDIA, NvidiaEmbeddingProvider::new), - Map.entry(ProviderConstants.OPENAI, OpenAIEmbeddingProvider::new), - Map.entry(ProviderConstants.UPSTAGE_AI, UpstageAIEmbeddingProvider::new), - Map.entry(ProviderConstants.VERTEXAI, VertexAIEmbeddingProvider::new), - Map.entry(ProviderConstants.VOYAGE_AI, VoyageAIEmbeddingProvider::new), - Map.entry(ProviderConstants.BEDROCK, AwsBedrockEmbeddingProvider::new)); + ModelProvider.HUGGINGFACE_DEDICATED, HuggingFaceDedicatedEmbeddingProvider::new), + Map.entry(ModelProvider.JINA_AI, JinaAIEmbeddingProvider::new), + Map.entry(ModelProvider.MISTRAL, MistralEmbeddingProvider::new), + 
Map.entry(ModelProvider.NVIDIA, NvidiaEmbeddingProvider::new), + Map.entry(ModelProvider.OPENAI, OpenAIEmbeddingProvider::new), + Map.entry(ModelProvider.UPSTAGE_AI, UpstageAIEmbeddingProvider::new), + Map.entry(ModelProvider.VERTEXAI, VertexAIEmbeddingProvider::new), + Map.entry(ModelProvider.VOYAGE_AI, VoyageAIEmbeddingProvider::new)); public EmbeddingProvider getConfiguration( Optional tenant, @@ -57,13 +59,20 @@ public EmbeddingProvider getConfiguration( Map vectorizeServiceParameters, Map authentication, String commandName) { + if (vectorizeServiceParameters == null) { vectorizeServiceParameters = Map.of(); } + + var modelProvider = + ModelProvider.fromApiName(serviceName) + .orElseThrow( + () -> new IllegalArgumentException("Unknown service provider: " + serviceName)); + return addService( tenant, authToken, - serviceName, + modelProvider, modelName, dimension, vectorizeServiceParameters, @@ -74,22 +83,24 @@ public EmbeddingProvider getConfiguration( private synchronized EmbeddingProvider addService( Optional tenant, Optional authToken, - String serviceName, + ModelProvider modelProvider, String modelName, int dimension, Map vectorizeServiceParameters, Map authentication, String commandName) { - final EmbeddingProviderConfigStore.ServiceConfig configuration = - embeddingProviderConfigStore.get().getConfiguration(tenant, serviceName); - if (config.enableEmbeddingGateway()) { + + final EmbeddingProviderConfigStore.ServiceConfig serviceConfig = + embeddingProviderConfigStore.get().getConfiguration(tenant, modelProvider.apiName()); + + if (operationsConfig.enableEmbeddingGateway()) { return new EmbeddingGatewayClient( - configuration.requestConfiguration(), - configuration.serviceProvider(), + serviceConfig.requestConfiguration(), + modelProvider, dimension, tenant, authToken, - configuration.getBaseUrl(modelName), + serviceConfig.getBaseUrl(modelName), modelName, embeddingService, vectorizeServiceParameters, @@ -97,12 +108,13 @@ private synchronized 
EmbeddingProvider addService( commandName); } - if (configuration.serviceProvider().equals(ProviderConstants.CUSTOM)) { - Optional> clazz = configuration.implementationClass(); - if (!clazz.isPresent()) { + if (serviceConfig.serviceProvider().equals(ModelProvider.CUSTOM.apiName())) { + Optional> clazz = serviceConfig.implementationClass(); + if (clazz.isEmpty()) { throw ErrorCodeV1.VECTORIZE_SERVICE_TYPE_UNAVAILABLE.toApiException( "custom class undefined"); } + try { return (EmbeddingProvider) clazz.get().getConstructor(int.class).newInstance(dimension); } catch (Exception e) { @@ -112,14 +124,25 @@ private synchronized EmbeddingProvider addService( } } - ProviderConstructor ctor = providersMap.get(configuration.serviceProvider()); + // aaron 7 June 2025, the code previously threw this error when the name from the config was not + // found in the code, but this is a serious error that should not happen, it should be more like + // a IllegalState. + var serviceConfigModelProvider = + ModelProvider.fromApiName(serviceConfig.serviceProvider()) + .orElseThrow( + () -> + ErrorCodeV1.VECTORIZE_SERVICE_TYPE_UNAVAILABLE.toApiException( + "unknown service provider '%s'", serviceConfig.serviceProvider())); + + ProviderConstructor ctor = EMBEDDING_PROVIDER_CTORS.get(serviceConfigModelProvider); if (ctor == null) { - throw ErrorCodeV1.VECTORIZE_SERVICE_TYPE_UNAVAILABLE.toApiException( - "unknown service provider '%s'", configuration.serviceProvider()); + throw new IllegalStateException( + "ModelProvider does not have a constructor: " + serviceConfigModelProvider.apiName()); } + return ctor.create( - configuration.requestConfiguration(), - configuration.getBaseUrl(modelName), + serviceConfig.requestConfiguration(), + serviceConfig.getBaseUrl(modelName), modelName, dimension, vectorizeServiceParameters); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceDedicatedEmbeddingProvider.java 
b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceDedicatedEmbeddingProvider.java index 950976545b..4325c53519 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceDedicatedEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceDedicatedEmbeddingProvider.java @@ -1,20 +1,20 @@ package io.stargate.sgv2.jsonapi.service.embedding.operation; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.databind.JsonNode; -import io.quarkus.rest.client.reactive.ClientExceptionMapper; import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; -import io.stargate.sgv2.jsonapi.service.embedding.operation.error.EmbeddingProviderErrorMapper; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.*; import java.util.concurrent.TimeUnit; @@ -23,9 +23,11 @@ import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; public class HuggingFaceDedicatedEmbeddingProvider extends EmbeddingProvider { - private static final String providerId = ProviderConstants.HUGGINGFACE_DEDICATED; - private final 
HuggingFaceDedicatedEmbeddingProviderClient - huggingFaceDedicatedEmbeddingProviderClient; + + public static final String HUGGINGFACE_DEDICATED_ENDPOINT_DEFINED_MODEL = + "endpoint-defined-model"; + + private final HuggingFaceDedicatedEmbeddingProviderClient huggingFaceClient; public HuggingFaceDedicatedEmbeddingProvider( EmbeddingProviderConfigStore.RequestProperties requestProperties, @@ -33,108 +35,130 @@ public HuggingFaceDedicatedEmbeddingProvider( String modelName, int dimension, Map vectorizeServiceParameters) { - super(requestProperties, baseUrl, modelName, dimension, vectorizeServiceParameters); + super( + ModelProvider.HUGGINGFACE_DEDICATED, + requestProperties, + baseUrl, + modelName, + dimension, + vectorizeServiceParameters); // replace placeholders: endPointName, regionName, cloudName String dedicatedApiUrl = replaceParameters(baseUrl, vectorizeServiceParameters); - huggingFaceDedicatedEmbeddingProviderClient = + huggingFaceClient = QuarkusRestClientBuilder.newBuilder() .baseUri(URI.create(dedicatedApiUrl)) .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(HuggingFaceDedicatedEmbeddingProviderClient.class); } - @RegisterRestClient - @RegisterProvider(EmbeddingProviderResponseValidation.class) - public interface HuggingFaceDedicatedEmbeddingProviderClient { - @POST - @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) - Uni embed( - @HeaderParam("Authorization") String accessToken, EmbeddingRequest request); - - @ClientExceptionMapper - static RuntimeException mapException(jakarta.ws.rs.core.Response response) { - String errorMessage = getErrorMessage(response); - return EmbeddingProviderErrorMapper.mapToAPIException(providerId, response, errorMessage); - } - - /** - * Extract the error message from the response body. The example response body is: - * - *

-     * {
-     *   "message": "Batch size error",
-     *   "type": "validation"
-     * }
-     *
-     * {
-     *   "message": "Model is overloaded",
-     *   "type": "overloaded"
-     * }
-     * 
- * - * @param response The response body as a String. - * @return The error message extracted from the response body. - */ - private static String getErrorMessage(jakarta.ws.rs.core.Response response) { - // Get the whole response body - JsonNode rootNode = response.readEntity(JsonNode.class); - // Log the response body - logger.error( - "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); - // Extract the "message" node - JsonNode messageNode = rootNode.path("message"); - // Return the text of the "message" node, or the whole response body if it is missing - return messageNode.isMissingNode() ? rootNode.toString() : messageNode.toString(); - } - } - - // huggingfaceDedicated, Test Embeddings Inference, openAI compatible route - // https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/openai_embed - private record EmbeddingRequest(String[] input) {} - - @JsonIgnoreProperties(ignoreUnknown = true) // ignore possible extra fields without error - private record EmbeddingResponse(String object, Data[] data, String model, Usage usage) { - @JsonIgnoreProperties(ignoreUnknown = true) - private record Data(String object, int index, float[] embedding) {} - - @JsonIgnoreProperties(ignoreUnknown = true) - private record Usage(int prompt_tokens, int total_tokens) {} + /** + * The example response body is: + * + *
+   * {
+   *   "message": "Batch size error",
+   *   "type": "validation"
+   * }
+   *
+   * {
+   *   "message": "Model is overloaded",
+   *   "type": "overloaded"
+   * }
+   */
+  @Override
+  protected String errorMessageJsonPtr() {
+    return "/message"; // JSON Pointer to the "message" member of the error bodies shown above
   }
 
   @Override
-  public Uni vectorize(
+  public Uni<BatchedEmbeddingResponse> vectorize(
       int batchId,
       List<String> texts,
       EmbeddingCredentials embeddingCredentials,
       EmbeddingRequestType embeddingRequestType) {
-    checkEmbeddingApiKeyHeader(providerId, embeddingCredentials.apiKey());
+    // Embeds one batch of texts via the HuggingFace Dedicated endpoint client and returns the
+    // vectors, ordered by the index the provider reports, together with the usage of this call.
+    // An IllegalStateException is thrown from the transform step when the response has no data.
 
-    String[] textArray = new String[texts.size()];
-    EmbeddingRequest request = new EmbeddingRequest(texts.toArray(textArray));
+    checkEmbeddingApiKeyHeader(embeddingCredentials.apiKey());
 
-    Uni response =
-        applyRetry(
-            huggingFaceDedicatedEmbeddingProviderClient.embed(
-                HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(),
-                request));
+    var huggingFaceRequest =
+        new HuggingFaceDedicatedEmbeddingRequest(texts.toArray(String[]::new));
 
-    return response
+    // TODO: V2 error
+    // aaron 8 June 2025 - old code had NO comment to explain what happens if the API key is empty.
+    var accessToken = HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get();
+
+    // The duration spans from just before the call is built to after the entity has been read.
+    long callStartNano = System.nanoTime();
+    return retryHTTPCall(huggingFaceClient.embed(accessToken, huggingFaceRequest))
         .onItem()
         .transform(
-            resp -> {
-              if (resp.data() == null) {
-                return Response.of(batchId, Collections.emptyList());
+            jakartaResponse -> {
+              var huggingFaceResponse =
+                  jakartaResponse.readEntity(HuggingFaceDedicatedEmbeddingResponse.class);
+              long callDurationNano = System.nanoTime() - callStartNano;
+
+              // aaron - 10 June 2025 - previous code would silently swallow no data returned
+              // and return an empty result. If we made a request we should get a response.
+              if (huggingFaceResponse.data() == null) {
+                throw new IllegalStateException(
+                    "ModelProvider %s returned empty data for model %s"
+                        .formatted(modelProvider(), modelName()));
               }
-              Arrays.sort(resp.data(), (a, b) -> a.index() - b.index());
+
+              // Integer.compare cannot overflow, unlike subtracting the two indexes.
+              Arrays.sort(
+                  huggingFaceResponse.data(), (a, b) -> Integer.compare(a.index(), b.index()));
               List<float[]> vectors =
-                  Arrays.stream(resp.data()).map(data -> data.embedding()).toList();
-              return Response.of(batchId, vectors);
+                  Arrays.stream(huggingFaceResponse.data())
+                      .map(HuggingFaceDedicatedEmbeddingResponse.Data::embedding)
+                      .toList();
+
+              // The response record tolerates unknown fields and "usage" may deserialise to null;
+              // count zero tokens then instead of failing the batch with a NullPointerException.
+              var usage = huggingFaceResponse.usage();
+              int promptTokens = usage == null ? 0 : usage.prompt_tokens();
+              int totalTokens = usage == null ? 0 : usage.total_tokens();
+
+              var modelUsage =
+                  createModelUsage(
+                      embeddingCredentials.tenantId(),
+                      ModelInputType.fromEmbeddingRequestType(embeddingRequestType),
+                      promptTokens,
+                      totalTokens,
+                      jakartaResponse,
+                      callDurationNano);
+              return new BatchedEmbeddingResponse(batchId, vectors, modelUsage);
             });
   }
 
-  @Override
-  public int maxBatchSize() {
-    return requestProperties.maxBatchSize();
+  /**
+   * REST client interface for the HuggingFace Dedicated Embedding Service.
+   *
+   * 

.. + */ + @RegisterRestClient + @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(ProviderHttpInterceptor.class) + public interface HuggingFaceDedicatedEmbeddingProviderClient { + @POST + @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) + Uni embed( + @HeaderParam("Authorization") String accessToken, + HuggingFaceDedicatedEmbeddingRequest request); + } + + /** + * Request structure of the HuggingFace Dedicated REST service. + * + *

huggingfaceDedicated, Test Embeddings Inference, openAI compatible route + * https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/openai_embed + */ + public record HuggingFaceDedicatedEmbeddingRequest(String[] input) {} + + /** + * Response structure of the HuggingFace Dedicated REST service. + * + *

.. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + private record HuggingFaceDedicatedEmbeddingResponse( + String object, Data[] data, String model, Usage usage) { + + @JsonIgnoreProperties(ignoreUnknown = true) + private record Data(String object, int index, float[] embedding) {} + + @JsonIgnoreProperties(ignoreUnknown = true) + private record Usage(int prompt_tokens, int total_tokens) {} } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceEmbeddingProvider.java index 35cc64b2a6..4927de0209 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceEmbeddingProvider.java @@ -1,23 +1,20 @@ package io.stargate.sgv2.jsonapi.service.embedding.operation; -import com.fasterxml.jackson.databind.JsonNode; -import io.quarkus.rest.client.reactive.ClientExceptionMapper; import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; -import io.stargate.sgv2.jsonapi.service.embedding.operation.error.EmbeddingProviderErrorMapper; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; import jakarta.ws.rs.Path; import jakarta.ws.rs.PathParam; -import 
jakarta.ws.rs.core.HttpHeaders; -import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.*; import java.net.URI; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -26,8 +23,8 @@ import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; public class HuggingFaceEmbeddingProvider extends EmbeddingProvider { - private static final String providerId = ProviderConstants.HUGGINGFACE; - private final HuggingFaceEmbeddingProviderClient huggingFaceEmbeddingProviderClient; + + private final HuggingFaceEmbeddingProviderClient huggingFaceClient; public HuggingFaceEmbeddingProvider( EmbeddingProviderConfigStore.RequestProperties requestProperties, @@ -35,88 +32,122 @@ public HuggingFaceEmbeddingProvider( String modelName, int dimension, Map vectorizeServiceParameters) { - super(requestProperties, baseUrl, modelName, dimension, vectorizeServiceParameters); + super( + ModelProvider.HUGGINGFACE, + requestProperties, + baseUrl, + modelName, + dimension, + vectorizeServiceParameters); - huggingFaceEmbeddingProviderClient = + huggingFaceClient = QuarkusRestClientBuilder.newBuilder() .baseUri(URI.create(baseUrl)) .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(HuggingFaceEmbeddingProviderClient.class); } - @RegisterRestClient - @RegisterProvider(EmbeddingProviderResponseValidation.class) - public interface HuggingFaceEmbeddingProviderClient { - @POST - @Path("/{modelId}") - @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) - Uni> embed( - @HeaderParam("Authorization") String accessToken, - @PathParam("modelId") String modelId, - EmbeddingRequest request); - - @ClientExceptionMapper - static RuntimeException mapException(jakarta.ws.rs.core.Response response) { - String errorMessage = getErrorMessage(response); - return EmbeddingProviderErrorMapper.mapToAPIException(providerId, response, errorMessage); - } - - /** - * Extracts the 
error message from the response body. The example response body is: - * - *

-     * {
-     *   "error": "Authorization header is correct, but the token seems invalid"
-     * }
-     * 
- * - * @param response The response body as a String. - * @return The error message extracted from the response body, or null if the message is not - * found. - */ - private static String getErrorMessage(jakarta.ws.rs.core.Response response) { - // Get the whole response body - JsonNode rootNode = response.readEntity(JsonNode.class); - // Log the response body - logger.error( - "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); - // Extract the "error" node - JsonNode errorNode = rootNode.path("error"); - // Return the text of the "message" node, or the whole response body if it is missing - return errorNode.isMissingNode() ? rootNode.toString() : errorNode.toString(); - } - } - - private record EmbeddingRequest(List inputs, Options options) { - public record Options(boolean waitForModel) {} + /** + * The example response body is: + * + *
+   * {
+   *   "error": "Authorization header is correct, but the token seems invalid"
+   * }
+   * 
+ */ + @Override + protected String errorMessageJsonPtr() { + return "/error"; } @Override - public Uni vectorize( + public Uni vectorize( int batchId, List texts, EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { - checkEmbeddingApiKeyHeader(providerId, embeddingCredentials.apiKey()); - EmbeddingRequest request = new EmbeddingRequest(texts, new EmbeddingRequest.Options(true)); - return applyRetry( - huggingFaceEmbeddingProviderClient.embed( - HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(), - modelName, - request)) + checkEmbeddingApiKeyHeader(embeddingCredentials.apiKey()); + var huggingFaceRequest = + new HuggingFaceEmbeddingRequest(texts, new HuggingFaceEmbeddingRequest.Options(true)); + + // TODO: V2 error + // aaron 8 June 2025 - old code had NO comment to explain what happens if the API key is empty. + var accessToken = HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(); + + long callStartNano = System.nanoTime(); + return retryHTTPCall(huggingFaceClient.embed(accessToken, modelName(), huggingFaceRequest)) .onItem() .transform( - resp -> { - if (resp == null) { - return Response.of(batchId, Collections.emptyList()); - } - return Response.of(batchId, resp); + jakartaResponse -> { + + // NOTE: Boxing happening here, as the response is a JSON array of arrays of floats. + // should return zero length list if entity is null or empty. + // TODO: how to deserialise without boxing ? 
+ List vectorsBoxed = jakartaResponse.readEntity(new GenericType<>() {}); + long callDurationNano = System.nanoTime() - callStartNano; + + List vectorsUnboxed = + vectorsBoxed.stream() + .map( + vector -> { + if (vector == null) { + return new float[0]; // Handle null vectors + } + float[] unboxed = new float[vector.length]; + for (int i = 0; i < vector.length; i++) { + unboxed[i] = vector[i]; + } + return unboxed; + }) + .toList(); + + // The hugging face API we are calling does not return usage information, there may be + // a + // newer version of the API that does, but for now we will not return usage + // information. + // https://huggingface.co/blog/getting-started-with-embeddings + var modelUsage = + createModelUsage( + embeddingCredentials.tenantId(), + ModelInputType.fromEmbeddingRequestType(embeddingRequestType), + 0, + 0, + jakartaResponse, + callDurationNano); + return new BatchedEmbeddingResponse(batchId, vectorsUnboxed, modelUsage); }); } - @Override - public int maxBatchSize() { - return requestProperties.maxBatchSize(); + /** + * REST client interface for the HuggingFace Embedding Service. + * + *

.. NOTE: the response is just a JSON array of arrays of floats, e.g.: + * + *

+   *   [[-0.123, 0.456, ...], [-0.789, 0.012, ...], ...]
+   * 
+ */ + @RegisterRestClient + @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(ProviderHttpInterceptor.class) + public interface HuggingFaceEmbeddingProviderClient { + @POST + @Path("/{modelId}") + @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) + Uni embed( + @HeaderParam("Authorization") String accessToken, + @PathParam("modelId") String modelId, + HuggingFaceEmbeddingRequest request); + } + + /** + * Request structure of the HuggingFace REST service. + * + *

.. + */ + public record HuggingFaceEmbeddingRequest(List inputs, Options options) { + public record Options(boolean waitForModel) {} } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/JinaAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/JinaAIEmbeddingProvider.java index c45a7d803c..c6b7a33550 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/JinaAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/JinaAIEmbeddingProvider.java @@ -2,20 +2,20 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.databind.JsonNode; -import io.quarkus.rest.client.reactive.ClientExceptionMapper; import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; -import io.stargate.sgv2.jsonapi.service.embedding.operation.error.EmbeddingProviderErrorMapper; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.*; import java.util.concurrent.TimeUnit; @@ -29,8 +29,8 @@ * called. 
*/ public class JinaAIEmbeddingProvider extends EmbeddingProvider { - private static final String providerId = ProviderConstants.JINA_AI; - private final JinaAIEmbeddingProviderClient jinaAIEmbeddingProviderClient; + + private final JinaAIEmbeddingProviderClient jinaClient; public JinaAIEmbeddingProvider( EmbeddingProviderConfigStore.RequestProperties requestProperties, @@ -39,116 +39,132 @@ public JinaAIEmbeddingProvider( int dimension, Map vectorizeServiceParameters) { super( + ModelProvider.JINA_AI, requestProperties, baseUrl, modelName, acceptsJinaAIDimensions(modelName) ? dimension : 0, vectorizeServiceParameters); - jinaAIEmbeddingProviderClient = + jinaClient = QuarkusRestClientBuilder.newBuilder() .baseUri(URI.create(baseUrl)) .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(JinaAIEmbeddingProviderClient.class); } - @RegisterRestClient - @RegisterProvider(EmbeddingProviderResponseValidation.class) - public interface JinaAIEmbeddingProviderClient { - @POST - @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) - Uni embed( - @HeaderParam("Authorization") String accessToken, EmbeddingRequest request); - - @ClientExceptionMapper - static RuntimeException mapException(jakarta.ws.rs.core.Response response) { - String errorMessage = getErrorMessage(response); - return EmbeddingProviderErrorMapper.mapToAPIException(providerId, response, errorMessage); - } - - /** - * Extract the error message from the response body. The example response body is: - * - *

-     * {
-     *    "detail": "ValidationError(model='TextDoc', errors=[{'loc': ('text',), 'msg': 'Single text cannot exceed 8192 tokens. 10454 tokens given.', 'type': 'value_error'}])"
-     * }
-     * 
- * - *
-     *     {"detail":"Failed to authenticate with the provided api key."}
-     * 
- * - * @param response The response body as a String. - * @return The error message extracted from the response body. - */ - private static String getErrorMessage(jakarta.ws.rs.core.Response response) { - // Get the whole response body - JsonNode rootNode = response.readEntity(JsonNode.class); - // Log the response body - logger.error( - "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); - // Extract the "detail" node - JsonNode detailNode = rootNode.path("detail"); - return detailNode.isMissingNode() ? rootNode.toString() : detailNode.toString(); - } - } - - // By default, Jina Text Encoding Format is float - private record EmbeddingRequest( - List input, - String model, - @JsonInclude(value = JsonInclude.Include.NON_DEFAULT) int dimensions, - @JsonInclude(value = JsonInclude.Include.NON_NULL) String task, - @JsonInclude(value = JsonInclude.Include.NON_NULL) Boolean late_chunking) {} - - @JsonIgnoreProperties(ignoreUnknown = true) // ignore possible extra fields without error - private record EmbeddingResponse(String object, Data[] data, String model, Usage usage) { - @JsonIgnoreProperties(ignoreUnknown = true) - private record Data(String object, int index, float[] embedding) {} - - @JsonIgnoreProperties(ignoreUnknown = true) - private record Usage(int prompt_tokens, int total_tokens) {} + /** + * Extract the error message from the response body. The example response body is: + * + *
+   * {
+   *    "detail": "ValidationError(model='TextDoc', errors=[{'loc': ('text',), 'msg': 'Single text cannot exceed 8192 tokens. 10454 tokens given.', 'type': 'value_error'}])"
+   * }
+   * 
+ * + *
+   *     {"detail":"Failed to authenticate with the provided api key."}
+   * 
+ */ + @Override + protected String errorMessageJsonPtr() { + return "/detail"; } @Override - public Uni vectorize( + public Uni vectorize( int batchId, List texts, EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { - checkEmbeddingApiKeyHeader(providerId, embeddingCredentials.apiKey()); - EmbeddingRequest request = - new EmbeddingRequest( + checkEmbeddingApiKeyHeader(embeddingCredentials.apiKey()); + + var jinaRequest = + new JinaEmbeddingRequest( texts, - modelName, + modelName(), dimension, (String) vectorizeServiceParameters.get("task"), (Boolean) vectorizeServiceParameters.get("late_chunking")); - Uni response = - applyRetry( - jinaAIEmbeddingProviderClient.embed( - HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(), - request)); + // TODO: V2 error + // aaron 8 June 2025 - old code had NO comment to explain what happens if the API key is empty. + var accessToken = HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(); - return response + long callStartNano = System.nanoTime(); + return retryHTTPCall(jinaClient.embed(accessToken, jinaRequest)) .onItem() .transform( - resp -> { - if (resp.data() == null) { - return Response.of(batchId, Collections.emptyList()); + jakartaResponse -> { + var jinaResponse = jakartaResponse.readEntity(JinaEmbeddingResponse.class); + long callDurationNano = System.nanoTime() - callStartNano; + + // aaron - 10 June 2025 - previous code would silently swallow no data returned + // and return an empty result. If we made a request we should get a response. 
+ if (jinaResponse.data() == null) { + throw new IllegalStateException( + "ModelProvider %s returned empty data for model %s" + .formatted(modelProvider(), modelName())); } - Arrays.sort(resp.data(), (a, b) -> a.index() - b.index()); + + Arrays.sort(jinaResponse.data(), (a, b) -> a.index() - b.index()); List vectors = - Arrays.stream(resp.data()).map(EmbeddingResponse.Data::embedding).toList(); - return Response.of(batchId, vectors); + Arrays.stream(jinaResponse.data()) + .map(JinaEmbeddingResponse.Data::embedding) + .toList(); + + var modelUsage = + createModelUsage( + embeddingCredentials.tenantId(), + ModelInputType.fromEmbeddingRequestType(embeddingRequestType), + jinaResponse.usage().prompt_tokens(), + jinaResponse.usage().total_tokens(), + jakartaResponse, + callDurationNano); + return new BatchedEmbeddingResponse(batchId, vectors, modelUsage); }); } - @Override - public int maxBatchSize() { - return requestProperties.maxBatchSize(); + /** + * REST client interface for the Jina Embedding Service. + * + *

.. + */ + @RegisterRestClient + @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(ProviderHttpInterceptor.class) + public interface JinaAIEmbeddingProviderClient { + @POST + @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) + Uni embed( + @HeaderParam("Authorization") String accessToken, JinaEmbeddingRequest request); + } + + /** + * Request structure of the Jina REST service. + * + *

By default, Jina Text Encoding Format is float + */ + public record JinaEmbeddingRequest( + List input, + String model, + @JsonInclude(value = JsonInclude.Include.NON_DEFAULT) int dimensions, + @JsonInclude(value = JsonInclude.Include.NON_NULL) String task, + @JsonInclude(value = JsonInclude.Include.NON_NULL) Boolean late_chunking) {} + + /** + * Response structure of the Jina REST service. + * + *

.. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public record JinaEmbeddingResponse(String object, Data[] data, String model, Usage usage) { + + @JsonIgnoreProperties(ignoreUnknown = true) + public record Data(String object, int index, float[] embedding) {} + + @JsonIgnoreProperties(ignoreUnknown = true) + public record Usage(int prompt_tokens, int total_tokens) {} } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MeteredEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MeteredEmbeddingProvider.java index 22a284187b..a720c7187e 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MeteredEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MeteredEmbeddingProvider.java @@ -7,6 +7,7 @@ import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.api.request.RequestContext; import io.stargate.sgv2.jsonapi.api.v1.metrics.JsonApiMetricsConfig; +import io.stargate.sgv2.jsonapi.service.provider.ModelUsage; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -32,6 +33,16 @@ public MeteredEmbeddingProvider( RequestContext dataApiRequestInfo, EmbeddingProvider embeddingProvider, String commandName) { + // aaron 9 June 2025 - we need to remove this "metered" design pattern, for now just pass the + // config through + super( + embeddingProvider.modelProvider(), + embeddingProvider.requestProperties, + embeddingProvider.baseUrl, + embeddingProvider.modelName(), + embeddingProvider.dimension, + embeddingProvider.vectorizeServiceParameters); + this.meterRegistry = meterRegistry; this.jsonApiMetricsConfig = jsonApiMetricsConfig; this.dataApiRequestInfo = dataApiRequestInfo; @@ -39,6 +50,12 @@ public MeteredEmbeddingProvider( this.commandName = commandName; } + @Override + protected String errorMessageJsonPtr() { + // not used we are just passing through + return 
""; + } + /** * Vectorizes a list of texts, adding metrics collection for the duration of the vectorization * call and the size of the input texts. @@ -49,11 +66,12 @@ public MeteredEmbeddingProvider( * @return a {@link Uni} that will provide the list of vectorized texts, as arrays of floats. */ @Override - public Uni vectorize( + public Uni vectorize( int batchId, List texts, EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { + // String bytes metrics for vectorize DistributionSummary ds = DistributionSummary.builder(jsonApiMetricsConfig.vectorizeInputBytesMetrics()) @@ -91,11 +109,18 @@ public Uni vectorize( Collections.sort( vectorizedBatches, (a, b) -> Integer.compare(a.batchId(), b.batchId())); List result = new ArrayList<>(); - for (Response vectorizedBatch : vectorizedBatches) { + + ModelUsage aggregatedModelUsage = null; + for (BatchedEmbeddingResponse vectorizedBatch : vectorizedBatches) { + + aggregatedModelUsage = + aggregatedModelUsage == null + ? 
vectorizedBatch.modelUsage() + : aggregatedModelUsage.merge(vectorizedBatch.modelUsage()); // create the final ordered result result.addAll(vectorizedBatch.embeddings()); } - return Response.of(1, result); + return new BatchedEmbeddingResponse(1, result, aggregatedModelUsage); }) .invoke( () -> diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MistralEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MistralEmbeddingProvider.java index f6af9482ac..6fd5198130 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MistralEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MistralEmbeddingProvider.java @@ -1,20 +1,20 @@ package io.stargate.sgv2.jsonapi.service.embedding.operation; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.databind.JsonNode; -import io.quarkus.rest.client.reactive.ClientExceptionMapper; import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; -import io.stargate.sgv2.jsonapi.service.embedding.operation.error.EmbeddingProviderErrorMapper; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; import 
java.net.URI; import java.util.*; import java.util.concurrent.TimeUnit; @@ -28,8 +28,8 @@ * REST API being called. */ public class MistralEmbeddingProvider extends EmbeddingProvider { - private static final String providerId = ProviderConstants.MISTRAL; - private final MistralEmbeddingProviderClient mistralEmbeddingProviderClient; + + private final MistralEmbeddingProviderClient mistralClient; public MistralEmbeddingProvider( EmbeddingProviderConfigStore.RequestProperties requestProperties, @@ -37,110 +37,129 @@ public MistralEmbeddingProvider( String modelName, int dimension, Map vectorizeServiceParameters) { - super(requestProperties, baseUrl, modelName, dimension, vectorizeServiceParameters); + super( + ModelProvider.MISTRAL, + requestProperties, + baseUrl, + modelName, + dimension, + vectorizeServiceParameters); - mistralEmbeddingProviderClient = + mistralClient = QuarkusRestClientBuilder.newBuilder() .baseUri(URI.create(baseUrl)) .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(MistralEmbeddingProviderClient.class); } - @RegisterRestClient - @RegisterProvider(EmbeddingProviderResponseValidation.class) - public interface MistralEmbeddingProviderClient { - @POST - @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) - Uni embed( - @HeaderParam("Authorization") String accessToken, EmbeddingRequest request); - - @ClientExceptionMapper - static RuntimeException mapException(jakarta.ws.rs.core.Response response) { - String errorMessage = getErrorMessage(response); - return EmbeddingProviderErrorMapper.mapToAPIException(providerId, response, errorMessage); - } - - /** - * Extracts the error message from the response body. The example response body is: - * - *

-     * {
-     *   "message":"Unauthorized",
-     *   "request_id":"1383ed1b472cb85fdfaa9624515d2d0e"
-     * }
-     *
-     * {
-     *   "object":"error",
-     *   "message":"Input is too long. Max length is 8192 got 10970",
-     *   "type":"invalid_request_error",
-     *   "param":null,
-     *   "code":null
-     * }
-     * 
- * - * @param response The response body as a String. - * @return The error message extracted from the response body, or null if the message is not - * found. - */ - private static String getErrorMessage(jakarta.ws.rs.core.Response response) { - // Get the whole response body - JsonNode rootNode = response.readEntity(JsonNode.class); - // Log the response body - logger.info( - String.format( - "Error response from embedding provider '%s': %s", providerId, rootNode.toString())); - // Extract the "message" node from the root node - JsonNode messageNode = rootNode.path("message"); - // Return the text of the "message" node, or the whole response body if it is missing - return messageNode.isMissingNode() ? rootNode.toString() : messageNode.toString(); - } - } - - private record EmbeddingRequest(List input, String model, String encoding_format) {} - - @JsonIgnoreProperties(ignoreUnknown = true) // ignore possible extra fields without error - private record EmbeddingResponse( - String id, String object, Data[] data, String model, Usage usage) { - @JsonIgnoreProperties(ignoreUnknown = true) - private record Data(String object, int index, float[] embedding) {} - - @JsonIgnoreProperties(ignoreUnknown = true) - private record Usage( - int prompt_tokens, int total_tokens, int completion_tokens, int request_count) {} + /** + * The example response body is: + * + *
+   * {
+   *   "message":"Unauthorized",
+   *   "request_id":"1383ed1b472cb85fdfaa9624515d2d0e"
+   * }
+   *
+   * {
+   *   "object":"error",
+   *   "message":"Input is too long. Max length is 8192 got 10970",
+   *   "type":"invalid_request_error",
+   *   "param":null,
+   *   "code":null
+   * }
+   * 
+ */ + @Override + protected String errorMessageJsonPtr() { + return "/message"; } @Override - public Uni vectorize( + public Uni vectorize( int batchId, List texts, EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { - checkEmbeddingApiKeyHeader(providerId, embeddingCredentials.apiKey()); - EmbeddingRequest request = new EmbeddingRequest(texts, modelName, "float"); + checkEmbeddingApiKeyHeader(embeddingCredentials.apiKey()); - Uni response = - applyRetry( - mistralEmbeddingProviderClient.embed( - HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(), - request)); + var mistralRequest = new MistralEmbeddingRequest(texts, modelName(), "float"); + // TODO: V2 error + // aaron 8 June 2025 - old code had NO comment to explain what happens if the API key is empty. + var accessToken = HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(); - return response + long callStartNano = System.nanoTime(); + + return retryHTTPCall(mistralClient.embed(accessToken, mistralRequest)) .onItem() .transform( - resp -> { - if (resp.data() == null) { - return Response.of(batchId, Collections.emptyList()); + jakartaResponse -> { + var mistralResponse = jakartaResponse.readEntity(MistralEmbeddingResponse.class); + long callDurationNano = System.nanoTime() - callStartNano; + + // aaron - 10 June 2025 - previous code would silently swallow no data returned + // and return an empty result. If we made a request we should get a response. 
+ if (mistralResponse.data() == null) { + throw new IllegalStateException( + "ModelProvider %s returned empty data for model %s" + .formatted(modelProvider(), modelName())); } - Arrays.sort(resp.data(), (a, b) -> a.index() - b.index()); + + Arrays.sort(mistralResponse.data(), (a, b) -> a.index() - b.index()); List vectors = - Arrays.stream(resp.data()).map(data -> data.embedding()).toList(); - return Response.of(batchId, vectors); + Arrays.stream(mistralResponse.data()) + .map(MistralEmbeddingResponse.Data::embedding) + .toList(); + + var modelUsage = + createModelUsage( + embeddingCredentials.tenantId(), + ModelInputType.fromEmbeddingRequestType(embeddingRequestType), + mistralResponse.usage().prompt_tokens(), + mistralResponse.usage().total_tokens(), + jakartaResponse, + callDurationNano); + return new BatchedEmbeddingResponse(batchId, vectors, modelUsage); }); } - @Override - public int maxBatchSize() { - return requestProperties.maxBatchSize(); + /** + * REST client interface for the Mistral Embedding Service. + * + *

.. + */ + @RegisterRestClient + @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(ProviderHttpInterceptor.class) + public interface MistralEmbeddingProviderClient { + @POST + @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) + Uni embed( + @HeaderParam("Authorization") String accessToken, MistralEmbeddingRequest request); + } + + /** + * Request structure of the Mistral REST service. + * + *

.. + */ + public record MistralEmbeddingRequest(List input, String model, String encoding_format) {} + + /** + * Response structure of the Mistral REST service. + * + *

.. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + private record MistralEmbeddingResponse( + String id, String object, Data[] data, String model, Usage usage) { + + @JsonIgnoreProperties(ignoreUnknown = true) + private record Data(String object, int index, float[] embedding) {} + + @JsonIgnoreProperties(ignoreUnknown = true) + private record Usage( + int prompt_tokens, int total_tokens, int completion_tokens, int request_count) {} } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NvidiaEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NvidiaEmbeddingProvider.java index 9a96e12d27..36fa1f076c 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NvidiaEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NvidiaEmbeddingProvider.java @@ -1,22 +1,22 @@ package io.stargate.sgv2.jsonapi.service.embedding.operation; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.databind.JsonNode; -import io.quarkus.rest.client.reactive.ClientExceptionMapper; import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; +import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; -import io.stargate.sgv2.jsonapi.service.embedding.operation.error.EmbeddingProviderErrorMapper; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; import jakarta.ws.rs.HeaderParam; import 
jakarta.ws.rs.POST; import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -29,8 +29,8 @@ * of chosen Nvidia model. */ public class NvidiaEmbeddingProvider extends EmbeddingProvider { - private static final String providerId = ProviderConstants.NVIDIA; - private final NvidiaEmbeddingProviderClient nvidiaEmbeddingProviderClient; + + private final NvidiaEmbeddingProviderClient nvidiaClient; public NvidiaEmbeddingProvider( EmbeddingProviderConfigStore.RequestProperties requestProperties, @@ -38,102 +38,120 @@ public NvidiaEmbeddingProvider( String modelName, int dimension, Map vectorizeServiceParameters) { - super(requestProperties, baseUrl, modelName, dimension, vectorizeServiceParameters); - - nvidiaEmbeddingProviderClient = + super( + ModelProvider.NVIDIA, + requestProperties, + baseUrl, + modelName, + dimension, + vectorizeServiceParameters); + + nvidiaClient = QuarkusRestClientBuilder.newBuilder() .baseUri(URI.create(baseUrl)) .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(NvidiaEmbeddingProviderClient.class); } - @RegisterRestClient - @RegisterProvider(EmbeddingProviderResponseValidation.class) - public interface NvidiaEmbeddingProviderClient { - @POST - @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) - Uni embed( - @HeaderParam("Authorization") String accessToken, EmbeddingRequest request); - - @ClientExceptionMapper - static RuntimeException mapException(jakarta.ws.rs.core.Response response) { - String errorMessage = getErrorMessage(response); - return EmbeddingProviderErrorMapper.mapToAPIException(providerId, response, errorMessage); - } - - /** - * Extract the error message from the response body. The example response body is: - * - *

-     * {
-     *   "object": "error",
-     *   "message": "Input length exceeds the maximum token length of the model",
-     *   "detail": {},
-     *   "type": "invalid_request_error"
-     * }
-     * 
- * - * @param response The response body as a String. - * @return The error message extracted from the response body. - */ - private static String getErrorMessage(jakarta.ws.rs.core.Response response) { - // Get the whole response body - JsonNode rootNode = response.readEntity(JsonNode.class); - // Log the response body - logger.error( - "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); - JsonNode messageNode = rootNode.path("message"); - // Return the text of the "message" node, or the whole response body if it is missing - return messageNode.isMissingNode() ? rootNode.toString() : messageNode.toString(); - } - } - - private record EmbeddingRequest(String[] input, String model, String input_type) {} - - @JsonIgnoreProperties(ignoreUnknown = true) // ignore possible extra fields without error - private record EmbeddingResponse(Data[] data, String model, Usage usage) { - @JsonIgnoreProperties(ignoreUnknown = true) - private record Data(int index, float[] embedding) {} - - @JsonIgnoreProperties(ignoreUnknown = true) - private record Usage(int prompt_tokens, int total_tokens) {} + /** + * The example response body is: + * + *
+   * {
+   *   "object": "error",
+   *   "message": "Input length exceeds the maximum token length of the model",
+   *   "detail": {},
+   *   "type": "invalid_request_error"
+   * }
+   * 
+ */ + @Override + protected String errorMessageJsonPtr() { + return "/message"; } - private static final String PASSAGE = "passage"; - private static final String QUERY = "query"; - @Override - public Uni vectorize( + public Uni vectorize( int batchId, List texts, EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { - String[] textArray = new String[texts.size()]; - String input_type = embeddingRequestType == EmbeddingRequestType.INDEX ? PASSAGE : QUERY; - - EmbeddingRequest request = - new EmbeddingRequest(texts.toArray(textArray), modelName, input_type); + var input_type = embeddingRequestType == EmbeddingRequestType.INDEX ? "passage" : "query"; + var nvidiaRequest = + new NvidiaEmbeddingRequest( + texts.toArray(new String[texts.size()]), modelName(), input_type); - Uni response = - applyRetry(nvidiaEmbeddingProviderClient.embed("Bearer ", request)); + // TODO: XXX No token to pass with the nvidia request for now. This will change on main merge + var accessToken = HttpConstants.BEARER_PREFIX_FOR_API_KEY; - return response + long callStartNano = System.nanoTime(); + return retryHTTPCall(nvidiaClient.embed(accessToken, nvidiaRequest)) .onItem() .transform( - resp -> { - if (resp.data() == null) { - return Response.of(batchId, Collections.emptyList()); + jakartaResponse -> { + var nvidiaResponse = jakartaResponse.readEntity(NvidiaEmbeddingResponse.class); + long callDurationNano = System.nanoTime() - callStartNano; + + // aaron - 10 June 2025 - previous code would silently swallow no data returned + // and return an empty result. If we made a request we should get a response. 
+ if (nvidiaResponse.data() == null) { + throw new IllegalStateException( + "ModelProvider %s returned empty data for model %s" + .formatted(modelProvider(), modelName())); } - Arrays.sort(resp.data(), (a, b) -> a.index() - b.index()); + + Arrays.sort(nvidiaResponse.data(), (a, b) -> a.index() - b.index()); List vectors = - Arrays.stream(resp.data()).map(data -> data.embedding()).toList(); - return Response.of(batchId, vectors); + Arrays.stream(nvidiaResponse.data()) + .map(NvidiaEmbeddingResponse.Data::embedding) + .toList(); + + var modelUsage = + createModelUsage( + embeddingCredentials.tenantId(), + ModelInputType.fromEmbeddingRequestType(embeddingRequestType), + nvidiaResponse.usage().prompt_tokens(), + nvidiaResponse.usage().total_tokens(), + jakartaResponse, + callDurationNano); + return new BatchedEmbeddingResponse(batchId, vectors, modelUsage); }); } - @Override - public int maxBatchSize() { - return requestProperties.maxBatchSize(); + /** + * REST client interface for the NVidia Embedding Service. + * + *

.. + */ + @RegisterRestClient + @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(ProviderHttpInterceptor.class) + public interface NvidiaEmbeddingProviderClient { + @POST + @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) + Uni embed( + @HeaderParam("Authorization") String accessToken, NvidiaEmbeddingRequest request); + } + + /** + * Request structure of the Nvidia REST service. + * + *

.. + */ + public record NvidiaEmbeddingRequest(String[] input, String model, String input_type) {} + + /** + * Response structure of the Nvidia REST service. + * + *

.. + */ + @JsonIgnoreProperties(ignoreUnknown = true) // ignore possible extra fields without error + private record NvidiaEmbeddingResponse(Data[] data, String model, Usage usage) { + @JsonIgnoreProperties(ignoreUnknown = true) + private record Data(int index, float[] embedding) {} + + @JsonIgnoreProperties(ignoreUnknown = true) + private record Usage(int prompt_tokens, int total_tokens) {} } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAIEmbeddingProvider.java index a5a5be13cd..b1ad0f0601 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAIEmbeddingProvider.java @@ -2,24 +2,23 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.databind.JsonNode; -import io.quarkus.rest.client.reactive.ClientExceptionMapper; import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; -import io.stargate.sgv2.jsonapi.service.embedding.operation.error.EmbeddingProviderErrorMapper; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; import jakarta.ws.rs.Path; import 
jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -28,8 +27,8 @@ import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; public class OpenAIEmbeddingProvider extends EmbeddingProvider { - private static final String providerId = ProviderConstants.OPENAI; - private final OpenAIEmbeddingProviderClient openAIEmbeddingProviderClient; + + private final OpenAIEmbeddingProviderClient openAIClient; public OpenAIEmbeddingProvider( EmbeddingProviderConfigStore.RequestProperties requestProperties, @@ -39,119 +38,133 @@ public OpenAIEmbeddingProvider( Map vectorizeServiceParameters) { // One special case: legacy "ada-002" model does not accept "dimension" parameter super( + ModelProvider.OPENAI, requestProperties, baseUrl, modelName, acceptsOpenAIDimensions(modelName) ? dimension : 0, vectorizeServiceParameters); - openAIEmbeddingProviderClient = + openAIClient = QuarkusRestClientBuilder.newBuilder() .baseUri(URI.create(baseUrl)) .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(OpenAIEmbeddingProviderClient.class); } + /** + * The example response body is: + * + *

+   * {
+   *   "error": {
+   *     "message": "You exceeded your current quota, please check your plan and billing details. For
+   *                 more information on this error, read the docs:
+   *                 https://platform.openai.com/docs/guides/error-codes/api-errors.",
+   *     "type": "insufficient_quota",
+   *     "param": null,
+   *     "code": "insufficient_quota"
+   *   }
+   * }
+   * 
+ */ + @Override + protected String errorMessageJsonPtr() { + return "/error/message"; + } + + @Override + public Uni vectorize( + int batchId, + List texts, + EmbeddingCredentials embeddingCredentials, + EmbeddingRequestType embeddingRequestType) { + + checkEmbeddingApiKeyHeader(embeddingCredentials.apiKey()); + + var openAiRequest = + new OpenAiEmbeddingRequest(texts.toArray(new String[texts.size()]), modelName(), dimension); + var organizationId = (String) vectorizeServiceParameters.get("organizationId"); + var projectId = (String) vectorizeServiceParameters.get("projectId"); + + // TODO: V2 error + // aaron 8 June 2025 - old code had NO comment to explain what happens if the API key is empty. + var accessToken = HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(); + + final long callStartNano = System.nanoTime(); + return retryHTTPCall(openAIClient.embed(accessToken, organizationId, projectId, openAiRequest)) + .onItem() + .transform( + jakartaResponse -> { + var openAiResponse = jakartaResponse.readEntity(OpenAiEmbeddingResponse.class); + long callDurationNano = System.nanoTime() - callStartNano; + + // aaron - 10 June 2025 - previous code would silently swallow no data returned + // and return an empty result. If we made a request we should get a response. 
+ if (openAiResponse.data() == null) { + throw new IllegalStateException( + "ModelProvider %s returned empty data for model %s" + .formatted(modelProvider(), modelName())); + } + Arrays.sort(openAiResponse.data(), (a, b) -> a.index() - b.index()); + List vectors = + Arrays.stream(openAiResponse.data()) + .map(OpenAiEmbeddingResponse.Data::embedding) + .toList(); + + var modelUsage = + createModelUsage( + embeddingCredentials.tenantId(), + ModelInputType.INPUT_TYPE_UNSPECIFIED, + openAiResponse.usage().prompt_tokens(), + openAiResponse.usage().total_tokens(), + jakartaResponse, + callDurationNano); + + return new BatchedEmbeddingResponse(batchId, vectors, modelUsage); + }); + } + + /** + * REST client interface for the OpenAI Embedding Service. + * + *

.. + */ @RegisterRestClient @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(ProviderHttpInterceptor.class) public interface OpenAIEmbeddingProviderClient { @POST @Path("/embeddings") @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) - Uni embed( + Uni embed( @HeaderParam("Authorization") String accessToken, @HeaderParam("OpenAI-Organization") String organizationId, @HeaderParam("OpenAI-Project") String projectId, - EmbeddingRequest request); - - @ClientExceptionMapper - static RuntimeException mapException(jakarta.ws.rs.core.Response response) { - String errorMessage = getErrorMessage(response); - return EmbeddingProviderErrorMapper.mapToAPIException(providerId, response, errorMessage); - } - - /** - * Extract the error message from the response body. The example response body is: - * - *

-     * {
-     *   "error": {
-     *     "message": "You exceeded your current quota, please check your plan and billing details. For
-     *                 more information on this error, read the docs:
-     *                 https://platform.openai.com/docs/guides/error-codes/api-errors.",
-     *     "type": "insufficient_quota",
-     *     "param": null,
-     *     "code": "insufficient_quota"
-     *   }
-     * }
-     * 
- * - * @param response The response body as a String. - * @return The error message extracted from the response body. - */ - private static String getErrorMessage(jakarta.ws.rs.core.Response response) { - // Get the whole response body - JsonNode rootNode = response.readEntity(JsonNode.class); - // Log the response body - logger.error( - "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); - // Extract the "message" node from the "error" node - JsonNode messageNode = rootNode.at("/error/message"); - // Return the text of the "message" node, or the whole response body if it is missing - return messageNode.isMissingNode() ? rootNode.toString() : messageNode.asText(); - } + OpenAiEmbeddingRequest request); } - private record EmbeddingRequest( + /** + * Request structure of the OpenAI REST service. + * + *

.. + */ + public record OpenAiEmbeddingRequest( String[] input, String model, @JsonInclude(value = JsonInclude.Include.NON_DEFAULT) int dimensions) {} - @JsonIgnoreProperties(ignoreUnknown = true) // ignore possible extra fields without error - private record EmbeddingResponse(String object, Data[] data, String model, Usage usage) { + /** + * Response structure of the OpenAI REST service. + * + *

.. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + private record OpenAiEmbeddingResponse(String object, Data[] data, String model, Usage usage) { @JsonIgnoreProperties(ignoreUnknown = true) private record Data(String object, int index, float[] embedding) {} @JsonIgnoreProperties(ignoreUnknown = true) private record Usage(int prompt_tokens, int total_tokens) {} } - - @Override - public Uni vectorize( - int batchId, - List texts, - EmbeddingCredentials embeddingCredentials, - EmbeddingRequestType embeddingRequestType) { - checkEmbeddingApiKeyHeader(providerId, embeddingCredentials.apiKey()); - String[] textArray = new String[texts.size()]; - EmbeddingRequest request = new EmbeddingRequest(texts.toArray(textArray), modelName, dimension); - String organizationId = (String) vectorizeServiceParameters.get("organizationId"); - String projectId = (String) vectorizeServiceParameters.get("projectId"); - - Uni response = - applyRetry( - openAIEmbeddingProviderClient.embed( - HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(), - organizationId, - projectId, - request)); - - return response - .onItem() - .transform( - resp -> { - if (resp.data() == null) { - return Response.of(batchId, Collections.emptyList()); - } - Arrays.sort(resp.data(), (a, b) -> a.index() - b.index()); - List vectors = - Arrays.stream(resp.data()).map(data -> data.embedding()).toList(); - return Response.of(batchId, vectors); - }); - } - - @Override - public int maxBatchSize() { - return requestProperties.maxBatchSize(); - } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/UpstageAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/UpstageAIEmbeddingProvider.java index 320524c3cd..b374d4a542 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/UpstageAIEmbeddingProvider.java +++ 
b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/UpstageAIEmbeddingProvider.java @@ -2,7 +2,6 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.databind.JsonNode; -import io.quarkus.rest.client.reactive.ClientExceptionMapper; import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; @@ -10,15 +9,16 @@ import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; -import io.stargate.sgv2.jsonapi.service.embedding.operation.error.EmbeddingProviderErrorMapper; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -27,11 +27,12 @@ import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; public class UpstageAIEmbeddingProvider extends EmbeddingProvider { - private static final String providerId = ProviderConstants.UPSTAGE_AI; + private static final String UPSTAGE_MODEL_SUFFIX_QUERY = "-query"; private static final String UPSTAGE_MODEL_SUFFIX_PASSAGE = "-passage"; + private final String modelNamePrefix; - private final UpstageAIEmbeddingProviderClient upstageAIEmbeddingProviderClient; + private final UpstageAIEmbeddingProviderClient 
upstageClient; public UpstageAIEmbeddingProvider( EmbeddingProviderConfigStore.RequestProperties requestProperties, @@ -39,100 +40,85 @@ public UpstageAIEmbeddingProvider( String modelNamePrefix, int dimension, Map vectorizeServiceParameters) { - super(requestProperties, baseUrl, modelNamePrefix, dimension, vectorizeServiceParameters); + super( + ModelProvider.UPSTAGE_AI, + requestProperties, + baseUrl, + modelNamePrefix, + dimension, + vectorizeServiceParameters); this.modelNamePrefix = modelNamePrefix; - upstageAIEmbeddingProviderClient = + upstageClient = QuarkusRestClientBuilder.newBuilder() .baseUri(URI.create(baseUrl)) .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(UpstageAIEmbeddingProviderClient.class); } - @RegisterRestClient - @RegisterProvider(EmbeddingProviderResponseValidation.class) - public interface UpstageAIEmbeddingProviderClient { - @POST - // no path specified, as it is already included in the baseUri - @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) - Uni embed( - @HeaderParam("Authorization") String accessToken, EmbeddingRequest request); - - @ClientExceptionMapper - static RuntimeException mapException(jakarta.ws.rs.core.Response response) { - String errorMessage = getErrorMessage(response); - return EmbeddingProviderErrorMapper.mapToAPIException(providerId, response, errorMessage); - } - - /** - * Extracts the error message from the response body. The example response body is: - * - *

-     * {
-     *   "message": "Unauthorized"
-     * }
-     *
-     * {
-     *   "error": {
-     *     "message": "This model's maximum context length is 4000 tokens. however you requested 10969 tokens. Please reduce your prompt.",
-     *     "type": "invalid_request_error",
-     *     "param": null,
-     *     "code": null
-     *   }
-     * }
-     * 
- * - * @param response The response body as a String. - * @return The error message extracted from the response body, or null if the message is not - * found. - */ - private static String getErrorMessage(jakarta.ws.rs.core.Response response) { - // Get the whole response body - JsonNode rootNode = response.readEntity(JsonNode.class); - // Log the response body - logger.error( - "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); - // Check if the root node contains a "message" field - JsonNode messageNode = rootNode.path("message"); - if (!messageNode.isMissingNode()) { - return messageNode.asText(); - } - // If the "message" field is not found, check for the nested "error" object - JsonNode errorMessageNode = rootNode.at("/error/message"); - if (!errorMessageNode.isMissingNode()) { - return errorMessageNode.asText(); - } - // Return the whole response body if no message is found - return rootNode.toString(); - } + @Override + protected String errorMessageJsonPtr() { + // overriding the function that calls this + return ""; } - // NOTE: "input" is a single String, not array of Constants! - record EmbeddingRequest(String input, String model) {} + /** + * Extracts the error message from the response body. The example response body is: + * + *
+   * {
+   *   "message": "Unauthorized"
+   * }
+   *
+   * {
+   *   "error": {
+   *     "message": "This model's maximum context length is 4000 tokens. however you requested 10969 tokens. Please reduce your prompt.",
+   *     "type": "invalid_request_error",
+   *     "param": null,
+   *     "code": null
+   *   }
+   * }
+   * 
+ */ + @Override + protected String responseErrorMessage(Response jakartaResponse) { - @JsonIgnoreProperties(ignoreUnknown = true) // ignore possible extra fields without error - record EmbeddingResponse(Data[] data, String model, Usage usage) { - @JsonIgnoreProperties(ignoreUnknown = true) - record Data(int index, float[] embedding) {} + JsonNode rootNode = jakartaResponse.readEntity(JsonNode.class); - @JsonIgnoreProperties(ignoreUnknown = true) - record Usage(int prompt_tokens, int total_tokens) {} + // Check if the root node contains a "message" field + JsonNode messageNode = rootNode.path("message"); + if (!messageNode.isMissingNode()) { + return messageNode.asText(); + } + + // If the "message" field is not found, check for the nested "error" object + JsonNode errorMessageNode = rootNode.at("/error/message"); + if (!errorMessageNode.isMissingNode()) { + return errorMessageNode.asText(); + } + // Return the whole response body if no message is found + return rootNode.toString(); } @Override - public Uni vectorize( + public Uni vectorize( int batchId, List texts, EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { - checkEmbeddingApiKeyHeader(providerId, embeddingCredentials.apiKey()); + + checkEmbeddingApiKeyHeader(embeddingCredentials.apiKey()); + // Oddity: Implementation does not support batching, so we only accept "batches" // of 1 String, fail for others if (texts.size() != 1) { + // TODO: This should be IllegalArgumentException + // Temporary fail message: with re-batching will give better information throw ErrorCodeV1.INVALID_VECTORIZE_VALUE_TYPE.toApiException( "UpstageAI only supports vectorization of 1 text at a time, got " + texts.size()); } + // Another oddity: model name used as prefix final String modelName = modelNamePrefix @@ -140,30 +126,83 @@ public Uni vectorize( ? 
UPSTAGE_MODEL_SUFFIX_QUERY : UPSTAGE_MODEL_SUFFIX_PASSAGE); - EmbeddingRequest request = new EmbeddingRequest(texts.get(0), modelName); + var upstageRequest = new UpstageEmbeddingRequest(texts.getFirst(), modelName); - Uni response = - applyRetry( - upstageAIEmbeddingProviderClient.embed( - HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(), - request)); + // TODO: V2 error + // aaron 8 June 2025 - old code had NO comment to explain what happens if the API key is empty. + var accessToken = HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(); - return response + long callStartNano = System.nanoTime(); + return retryHTTPCall(upstageClient.embed(accessToken, upstageRequest)) .onItem() .transform( - resp -> { - if (resp.data() == null) { - return Response.of(batchId, Collections.emptyList()); + jakartaResponse -> { + var upstageResponse = jakartaResponse.readEntity(UpstageEmbeddingResponse.class); + long callDurationNano = System.nanoTime() - callStartNano; + + // aaron - 10 June 2025 - previous code would silently swallow no data returned + // and return an empty result. If we made a request we should get a response. + if (upstageResponse.data() == null) { + throw new IllegalStateException( + "ModelProvider %s returned empty data for model %s" + .formatted(modelProvider(), modelName())); } - Arrays.sort(resp.data(), (a, b) -> a.index() - b.index()); + + // aaron - 11 june 2025 - prev code would sort upstageResponse.data() BUT per above we + // only support a batch size of 1, so no need to sort. 
+ List vectors = - Arrays.stream(resp.data()).map(data -> data.embedding()).toList(); - return Response.of(batchId, vectors); + Arrays.stream(upstageResponse.data()) + .map(UpstageEmbeddingResponse.Data::embedding) + .toList(); + + var modelUsage = + createModelUsage( + embeddingCredentials.tenantId(), + ModelInputType.fromEmbeddingRequestType(embeddingRequestType), + upstageResponse.usage().prompt_tokens(), + upstageResponse.usage().total_tokens(), + jakartaResponse, + callDurationNano); + return new BatchedEmbeddingResponse(batchId, vectors, modelUsage); }); } - @Override - public int maxBatchSize() { - return requestProperties.maxBatchSize(); + /** + * REST client interface for the Upstage Embedding Service. + * + *

.. + */ + @RegisterRestClient + @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(ProviderHttpInterceptor.class) + public interface UpstageAIEmbeddingProviderClient { + @POST + // no path specified, as it is already included in the baseUri + @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) + Uni embed( + @HeaderParam("Authorization") String accessToken, UpstageEmbeddingRequest request); + } + + /** + * Request structure of the Upstage REST service. + * + *

NOTE: "input" is a single String, not array of Constants! + */ + public record UpstageEmbeddingRequest(String input, String model) {} + + /** + * Response structure of the Upstage REST service. + * + *

.. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public record UpstageEmbeddingResponse(Data[] data, String model, Usage usage) { + + @JsonIgnoreProperties(ignoreUnknown = true) + record Data(int index, float[] embedding) {} + + @JsonIgnoreProperties(ignoreUnknown = true) + record Usage(int prompt_tokens, int total_tokens) {} } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VertexAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VertexAIEmbeddingProvider.java index 556d3e0705..84fc377900 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VertexAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VertexAIEmbeddingProvider.java @@ -3,34 +3,34 @@ import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.databind.JsonNode; -import io.quarkus.rest.client.reactive.ClientExceptionMapper; import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; -import io.stargate.sgv2.jsonapi.service.embedding.operation.error.EmbeddingProviderErrorMapper; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; import jakarta.ws.rs.Path; import jakarta.ws.rs.PathParam; import 
jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; import java.net.URI; -import java.util.Collections; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; import org.eclipse.microprofile.rest.client.annotation.ClientHeaderParam; import org.eclipse.microprofile.rest.client.annotation.RegisterProvider; import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; public class VertexAIEmbeddingProvider extends EmbeddingProvider { - private static final String providerId = ProviderConstants.VERTEXAI; - private final VertexAIEmbeddingProviderClient vertexAIEmbeddingProviderClient; + + private final VertexAIEmbeddingProviderClient vertexClient; public VertexAIEmbeddingProvider( EmbeddingProviderConfigStore.RequestProperties requestProperties, @@ -38,157 +38,142 @@ public VertexAIEmbeddingProvider( String modelName, int dimension, Map serviceParameters) { - super(requestProperties, baseUrl, modelName, dimension, serviceParameters); + super( + ModelProvider.VERTEXAI, + requestProperties, + baseUrl, + modelName, + dimension, + serviceParameters); String actualUrl = replaceParameters(baseUrl, serviceParameters); - vertexAIEmbeddingProviderClient = + vertexClient = QuarkusRestClientBuilder.newBuilder() .baseUri(URI.create(actualUrl)) .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(VertexAIEmbeddingProviderClient.class); } - @RegisterRestClient - @RegisterProvider(EmbeddingProviderResponseValidation.class) - public interface VertexAIEmbeddingProviderClient { - @POST - @Path("/{modelId}:predict") - @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) - Uni embed( - @HeaderParam("Authorization") String accessToken, - @PathParam("modelId") String modelId, - EmbeddingRequest request); - - @ClientExceptionMapper - static RuntimeException 
mapException(jakarta.ws.rs.core.Response response) { - String errorMessage = getErrorMessage(response); - return EmbeddingProviderErrorMapper.mapToAPIException(providerId, response, errorMessage); - } - - /** - * TODO: Add customized error message extraction logic here.
- * Extract the error message from the response body. The example response body is: - * - *

-     *
-     * 
- * - * @param response The response body as a String. - * @return The error message extracted from the response body. - */ - private static String getErrorMessage(jakarta.ws.rs.core.Response response) { - // Get the whole response body - JsonNode rootNode = response.readEntity(JsonNode.class); - // Log the response body - logger.error( - "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); - return rootNode.toString(); - } + @Override + protected String errorMessageJsonPtr() { + // overriding the call that needs this. + return null; } - private record EmbeddingRequest(List instances) { - public record Content(String content) {} + @Override + protected String responseErrorMessage(Response jakartaResponse) { + // aaron 9 june 2025 - this is what it did originally, just get the whole response body + + // Get the whole response body + JsonNode rootNode = jakartaResponse.readEntity(JsonNode.class); + return rootNode.toString(); } - @JsonIgnoreProperties(ignoreUnknown = true) // ignore possible extra fields without error - private static class EmbeddingResponse { - public EmbeddingResponse() {} + @Override + public Uni vectorize( + int batchId, + List texts, + EmbeddingCredentials embeddingCredentials, + EmbeddingRequestType embeddingRequestType) { - private List predictions; + checkEmbeddingApiKeyHeader(embeddingCredentials.apiKey()); - @JsonIgnore private Object metadata; + var vertexRequest = + new VertexEmbeddingRequest( + texts.stream().map(VertexEmbeddingRequest.Content::new).toList()); - public List getPredictions() { - return predictions; - } + // TODO: V2 error + // aaron 8 June 2025 - old code had NO comment to explain what happens if the API key is empty. 
+ var accessToken = HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(); - public void setPredictions(List predictions) { - this.predictions = predictions; - } + long callStartNano = System.nanoTime(); + return retryHTTPCall(vertexClient.embed(accessToken, modelName(), vertexRequest)) + .onItem() + .transform( + jakartaResponse -> { + var vertexResponse = jakartaResponse.readEntity(VertexEmbeddingResponse.class); + long callDurationNano = System.nanoTime() - callStartNano; + + // aaron - 10 June 2025 - previous code would silently swallow no data returned + // and return an empty result. If we made a request we should get a response. + if (vertexResponse.predictions() == null) { + throw new IllegalStateException( + "ModelProvider %s returned empty data for model %s" + .formatted(modelProvider(), modelName())); + } - public Object getMetadata() { - return metadata; - } + // token usage is for each of the embeddings , need to sum it up + int total_tokens = 0; + List vectors = new ArrayList<>(vertexResponse.predictions().size()); + for (var prediction : vertexResponse.predictions()) { + vectors.add(prediction.embeddings().values); + total_tokens += prediction.embeddings().statistics().token_count; + } - public void setMetadata(Object metadata) { - this.metadata = metadata; - } + // Docs say the token_count in the response is the "Number of tokens of the input + // text." 
+ // so seems safe to use this as the prompt_tokens and total_tokens + // https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#response_body + var modelUsage = + createModelUsage( + embeddingCredentials.tenantId(), + ModelInputType.fromEmbeddingRequestType(embeddingRequestType), + total_tokens, + total_tokens, + jakartaResponse, + callDurationNano); + return new BatchedEmbeddingResponse(batchId, vectors, modelUsage); + }); + } - @JsonIgnoreProperties(ignoreUnknown = true) - protected static class Prediction { - public Prediction() {} + /** + * REST client interface for the Vertex Embedding Service. + * + *

.. + */ + @RegisterRestClient + @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(ProviderHttpInterceptor.class) + public interface VertexAIEmbeddingProviderClient { - private Embeddings embeddings; + @POST + @Path("/{modelId}:predict") + @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) + Uni embed( + @HeaderParam("Authorization") String accessToken, + @PathParam("modelId") String modelId, + VertexEmbeddingRequest request); + } - public Embeddings getEmbeddings() { - return embeddings; - } + /** + * Request structure of the Vertex REST service. + * + *

.. + */ + private record VertexEmbeddingRequest(List instances) { + public record Content(String content) {} + } - public void setEmbeddings(Embeddings embeddings) { - this.embeddings = embeddings; - } + /** + * Response structure of the Vertex REST service. + * + *

.. aaron - 10 June 2025 - this used to be a class, moved to be a record for consistency + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public record VertexEmbeddingResponse( + List predictions, + // aaron 10 june 2025, could not see metadata in API docs, but it was in the old code. + // https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#response_body + @JsonIgnore Object metadata) { + @JsonIgnoreProperties(ignoreUnknown = true) + public record Prediction(Embeddings embeddings) { @JsonIgnoreProperties(ignoreUnknown = true) - protected static class Embeddings { - public Embeddings() {} - - private float[] values; - - @JsonIgnore private Object statistics; - - public float[] getValues() { - return values; - } + public record Embeddings(float[] values, Statistics statistics) { - public void setValues(float[] values) { - this.values = values; - } - - public Object getStatistics() { - return statistics; - } - - public void setStatistics(Object statistics) { - this.statistics = statistics; - } + @JsonIgnoreProperties(ignoreUnknown = true) + public record Statistics(boolean truncated, int token_count) {} } } } - - @Override - public Uni vectorize( - int batchId, - List texts, - EmbeddingCredentials embeddingCredentials, - EmbeddingRequestType embeddingRequestType) { - checkEmbeddingApiKeyHeader(providerId, embeddingCredentials.apiKey()); - EmbeddingRequest request = - new EmbeddingRequest(texts.stream().map(t -> new EmbeddingRequest.Content(t)).toList()); - - Uni serviceResponse = - applyRetry( - vertexAIEmbeddingProviderClient.embed( - HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(), - modelName, - request)); - - return serviceResponse - .onItem() - .transform( - response -> { - if (response.getPredictions() == null) { - return Response.of(batchId, Collections.emptyList()); - } - List vectors = - response.getPredictions().stream() - .map(prediction -> prediction.getEmbeddings().getValues()) - 
.collect(Collectors.toList()); - return Response.of(batchId, vectors); - }); - } - - @Override - public int maxBatchSize() { - return requestProperties.maxBatchSize(); - } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VoyageAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VoyageAIEmbeddingProvider.java index 017dd00133..5c87851e97 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VoyageAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VoyageAIEmbeddingProvider.java @@ -2,23 +2,22 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.databind.JsonNode; -import io.quarkus.rest.client.reactive.ClientExceptionMapper; import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; -import io.stargate.sgv2.jsonapi.service.embedding.operation.error.EmbeddingProviderErrorMapper; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.Map; import 
java.util.concurrent.TimeUnit; @@ -27,8 +26,9 @@ import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; public class VoyageAIEmbeddingProvider extends EmbeddingProvider { - private static final String providerId = ProviderConstants.VOYAGE_AI; - private final VoyageAIEmbeddingProviderClient voyageAIEmbeddingProviderClient; + + private final VoyageAIEmbeddingProviderClient voyageClient; + private final String requestTypeQuery, requestTypeIndex; private final Boolean autoTruncate; @@ -38,111 +38,142 @@ public VoyageAIEmbeddingProvider( String modelName, int dimension, Map serviceParameters) { - super(requestProperties, baseUrl, modelName, dimension, serviceParameters); + super( + ModelProvider.VOYAGE_AI, + requestProperties, + baseUrl, + modelName, + dimension, + serviceParameters); // use configured input_type if available requestTypeQuery = requestProperties.requestTypeQuery().orElse(null); requestTypeIndex = requestProperties.requestTypeIndex().orElse(null); + Object v = (serviceParameters == null) ? null : serviceParameters.get("autoTruncate"); autoTruncate = (v instanceof Boolean) ? 
(Boolean) v : null; - voyageAIEmbeddingProviderClient = + voyageClient = QuarkusRestClientBuilder.newBuilder() .baseUri(URI.create(baseUrl)) .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(VoyageAIEmbeddingProviderClient.class); } - @RegisterRestClient - @RegisterProvider(EmbeddingProviderResponseValidation.class) - public interface VoyageAIEmbeddingProviderClient { - @POST - // no path specified, as it is already included in the baseUri - @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) - Uni embed( - @HeaderParam("Authorization") String accessToken, EmbeddingRequest request); - - @ClientExceptionMapper - static RuntimeException mapException(jakarta.ws.rs.core.Response response) { - String errorMessage = getErrorMessage(response); - return EmbeddingProviderErrorMapper.mapToAPIException(providerId, response, errorMessage); - } - - /** - * Extract the error message from the response body. The example response body is: - * - *

-     * {"detail":"You have not yet added your payment method in the billing page and will have reduced rate limits of 3 RPM and 10K TPM.  Please add your payment method in the billing page (https://dash.voyageai.com/billing/payment-methods) to unlock our standard rate limits (https://docs.voyageai.com/docs/rate-limits).  Even with payment methods entered, the free tokens (50M tokens per model) will still apply."}
-     *
-     * {"detail":"Provided API key is invalid."}
-     * 
- * - * @param response The response body as a String. - * @return The error message extracted from the response body. - */ - private static String getErrorMessage(jakarta.ws.rs.core.Response response) { - // Get the whole response body - JsonNode rootNode = response.readEntity(JsonNode.class); - // Log the response body - logger.error( - "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); - // Extract the "detail" node - JsonNode detailNode = rootNode.path("detail"); - // Return the text of the "detail" node, or the full response body if it is missing - return detailNode.isMissingNode() ? rootNode.toString() : detailNode.toString(); - } - } - - record EmbeddingRequest( - @JsonInclude(JsonInclude.Include.NON_EMPTY) String input_type, - String[] input, - String model, - @JsonInclude(JsonInclude.Include.NON_NULL) Boolean truncation) {} - - @JsonIgnoreProperties(ignoreUnknown = true) // ignore possible extra fields without error - record EmbeddingResponse(Data[] data, String model, Usage usage) { - @JsonIgnoreProperties(ignoreUnknown = true) - record Data(int index, float[] embedding) {} - - @JsonIgnoreProperties(ignoreUnknown = true) - record Usage(int total_tokens) {} + /** + * Response body with an error will look like below: + * + *
+   * {"detail":"You have not yet added your payment method in the billing page and will have reduced rate limits of 3 RPM and 10K TPM.  Please add your payment method in the billing page (https://dash.voyageai.com/billing/payment-methods) to unlock our standard rate limits (https://docs.voyageai.com/docs/rate-limits).  Even with payment methods entered, the free tokens (50M tokens per model) will still apply."}
+   *
+   * {"detail":"Provided API key is invalid."}
+   * 
+ */ + @Override + protected String errorMessageJsonPtr() { + return "/detail"; } @Override - public Uni vectorize( + public Uni vectorize( int batchId, List texts, EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { - checkEmbeddingApiKeyHeader(providerId, embeddingCredentials.apiKey()); + + checkEmbeddingApiKeyHeader(embeddingCredentials.apiKey()); + + // TODO: remove the requestTypeQuery and requestTypeIndex from config ! + // aaron 8 June 2025 - this looks like the term to sue for query and index is in config, but + // there is + // NOT handling of when this config is not set final String inputType = (embeddingRequestType == EmbeddingRequestType.SEARCH) ? requestTypeQuery : requestTypeIndex; - String[] textArray = new String[texts.size()]; - EmbeddingRequest request = - new EmbeddingRequest(inputType, texts.toArray(textArray), modelName, autoTruncate); - Uni response = - applyRetry( - voyageAIEmbeddingProviderClient.embed( - HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(), - request)); + var voyageRequest = + new VoyageEmbeddingRequest( + inputType, texts.toArray(new String[texts.size()]), modelName(), autoTruncate); + + // TODO: V2 error + // aaron 8 June 2025 - old code had NO comment to explain what happens if the API key is empty. + var accessToken = HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(); - return response + long callStartNano = System.nanoTime(); + return retryHTTPCall(voyageClient.embed(accessToken, voyageRequest)) .onItem() .transform( - resp -> { - if (resp.data() == null) { - return Response.of(batchId, Collections.emptyList()); + jakartaResponse -> { + var voyageResponse = jakartaResponse.readEntity(VoyageEmbeddingResponse.class); + long callDurationNano = System.nanoTime() - callStartNano; + + // aaron - 10 June 2025 - previous code would silently swallow no data returned + // and return an empty result. 
If we made a request we should get a response. + if (voyageResponse.data() == null) { + throw new IllegalStateException( + "ModelProvider %s returned empty data for model %s" + .formatted(modelProvider(), modelName())); } - Arrays.sort(resp.data(), (a, b) -> a.index() - b.index()); + + // TODO: WHY SORT ? + Arrays.sort(voyageResponse.data(), (a, b) -> a.index() - b.index()); + List vectors = - Arrays.stream(resp.data()).map(data -> data.embedding()).toList(); - return Response.of(batchId, vectors); + Arrays.stream(voyageResponse.data()) + .map(VoyageEmbeddingResponse.Data::embedding) + .toList(); + + var modelUsage = + createModelUsage( + embeddingCredentials.tenantId(), + ModelInputType.fromEmbeddingRequestType(embeddingRequestType), + 0, + voyageResponse.usage.total_tokens, + jakartaResponse, + callDurationNano); + return new BatchedEmbeddingResponse(batchId, vectors, modelUsage); }); } - @Override - public int maxBatchSize() { - return requestProperties.maxBatchSize(); + /** + * REST client interface for the Voyage Embedding Service. + * + *

.. + */ + @RegisterRestClient + @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(ProviderHttpInterceptor.class) + public interface VoyageAIEmbeddingProviderClient { + @POST + // no path specified, as it is already included in the baseUri + @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) + Uni embed( + @HeaderParam("Authorization") String accessToken, VoyageEmbeddingRequest request); + } + + /** + * Request structure of the Voyage REST service. + * + *

.. + */ + public record VoyageEmbeddingRequest( + @JsonInclude(JsonInclude.Include.NON_EMPTY) String input_type, + String[] input, + String model, + @JsonInclude(JsonInclude.Include.NON_NULL) Boolean truncation) {} + + /** + * Response structure of the Voyage REST service. + * + *

.. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public record VoyageEmbeddingResponse(Data[] data, String model, Usage usage) { + @JsonIgnoreProperties(ignoreUnknown = true) + record Data(int index, float[] embedding) {} + + @JsonIgnoreProperties(ignoreUnknown = true) + record Usage(int total_tokens) { + // Voyage API does not return prompt_tokens + } } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/error/RerankingResponseErrorMessageMapper.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/error/RerankingResponseErrorMessageMapper.java index fe0f73b792..2e359ca2a6 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/error/RerankingResponseErrorMessageMapper.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/error/RerankingResponseErrorMessageMapper.java @@ -17,6 +17,7 @@ public class RerankingResponseErrorMessageMapper { */ public static RuntimeException mapToAPIException( String providerName, Response response, String message) { + // Status code == 408 and 504 for timeout if (response.getStatus() == Response.Status.REQUEST_TIMEOUT.getStatusCode() || response.getStatus() == Response.Status.GATEWAY_TIMEOUT.getStatusCode()) { diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/test/CustomITEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/test/CustomITEmbeddingProvider.java index 1dba9e52ff..e54f08065f 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/test/CustomITEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/test/CustomITEmbeddingProvider.java @@ -3,10 +3,11 @@ import io.quarkus.runtime.annotations.RegisterForReflection; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; +import 
io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProvider; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import java.util.*; /** * This is a test implementation of the EmbeddingProvider interface. It is used for @@ -32,9 +33,25 @@ public class CustomITEmbeddingProvider extends EmbeddingProvider { private int dimension; public CustomITEmbeddingProvider(int dimension) { + // aaron 9 June 2025 - refactoring , I think none of the super class is used, so passing dummy + // values + super( + ModelProvider.CUSTOM, + new EmbeddingProviderConfigStore.RequestProperties( + 1, 1, 1, 1, 1, Optional.empty(), Optional.empty(), 1), + "", + "", + 1, + Map.of()); + this.dimension = dimension; } + @Override + protected String errorMessageJsonPtr() { + return ""; + } + static { TEST_DATA_DIMENSION_5.put( "ChatGPT integrated sneakers that talk to you", @@ -68,16 +85,30 @@ public CustomITEmbeddingProvider(int dimension) { } @Override - public Uni vectorize( + public Uni vectorize( int batchId, List texts, EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { + List response = new ArrayList<>(texts.size()); - if (texts.size() == 0) return Uni.createFrom().item(Response.of(batchId, response)); - if (!embeddingCredentials.apiKey().isPresent() - || !embeddingCredentials.apiKey().get().equals(TEST_API_KEY)) + if (texts.isEmpty()) { + var modelUsage = + createModelUsage( + embeddingCredentials.tenantId(), + ModelInputType.fromEmbeddingRequestType(embeddingRequestType), + 0, + 0, + 0, + 0, + 0); + return Uni.createFrom().item(new BatchedEmbeddingResponse(batchId, response, modelUsage)); + } + if (embeddingCredentials.apiKey().isEmpty() + || 
!embeddingCredentials.apiKey().get().equals(TEST_API_KEY)) { return Uni.createFrom().failure(new RuntimeException("Invalid API Key")); + } + for (String text : texts) { if (dimension == 5) { if (TEST_DATA_DIMENSION_5.containsKey(text)) { @@ -94,7 +125,17 @@ public Uni vectorize( } } } - return Uni.createFrom().item(Response.of(batchId, response)); + + var modelUsage = + createModelUsage( + embeddingCredentials.tenantId(), + ModelInputType.fromEmbeddingRequestType(embeddingRequestType), + 0, + 0, + 0, + 0, + 0); + return Uni.createFrom().item(new BatchedEmbeddingResponse(batchId, response, modelUsage)); } @Override diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/embeddings/EmbeddingTask.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/embeddings/EmbeddingTask.java index b3a3a5a52e..1d045f9dc9 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/embeddings/EmbeddingTask.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/embeddings/EmbeddingTask.java @@ -108,14 +108,14 @@ public static class EmbeddingResultSupplier implements BaseTask.UniSupplier embeddingTask; protected final CommandContext commandContext; - protected final BaseTask.UniSupplier supplier; + protected final BaseTask.UniSupplier supplier; protected final List actions; private final List vectorizeTexts; EmbeddingResultSupplier( EmbeddingTask embeddingTask, CommandContext commandContext, - BaseTask.UniSupplier supplier, + BaseTask.UniSupplier supplier, List actions, List vectorizeTexts) { this.embeddingTask = embeddingTask; @@ -178,7 +178,7 @@ private EmbeddingTaskResult(List rawVectors, List embeddingTask, CommandContext commandContext, - EmbeddingProvider.Response providerResponse, + EmbeddingProvider.BatchedEmbeddingResponse providerResponse, List actions) { commandContext diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/reranking/RerankingTask.java 
b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/reranking/RerankingTask.java index da3931910e..1642edcdb7 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/reranking/RerankingTask.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/reranking/RerankingTask.java @@ -13,6 +13,7 @@ import io.stargate.sgv2.jsonapi.service.operation.tasks.BaseTask; import io.stargate.sgv2.jsonapi.service.operation.tasks.TaskRetryPolicy; import io.stargate.sgv2.jsonapi.service.projection.DocumentProjector; +import io.stargate.sgv2.jsonapi.service.provider.ModelUsage; import io.stargate.sgv2.jsonapi.service.reranking.operation.RerankingProvider; import io.stargate.sgv2.jsonapi.util.PathMatchLocator; import io.stargate.sgv2.jsonapi.util.recordable.Recordable; @@ -266,7 +267,7 @@ public Uni get() { RerankingTaskResult.create( requestTracing, rerankingProvider, - new RerankingProvider.RerankingResponse(List.of()), + new RerankingProvider.RerankingResponse(List.of(), ModelUsage.EMPTY), unrankedDocs, limit)); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelInputType.java b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelInputType.java new file mode 100644 index 0000000000..cbd60cb008 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelInputType.java @@ -0,0 +1,34 @@ +package io.stargate.sgv2.jsonapi.service.provider; + +import io.stargate.embedding.gateway.EmbeddingGateway; +import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProvider; +import java.util.Optional; + +/** + * If the model usage was for indexing data or searching data + * + *

Keeps in parity with the gRPC proto definition in embedding_gateway.proto + */ +public enum ModelInputType { + INPUT_TYPE_UNSPECIFIED, + INDEX, + SEARCH; + + public static ModelInputType fromEmbeddingRequestType( + EmbeddingProvider.EmbeddingRequestType embeddingRequestType) { + return switch (embeddingRequestType) { + case INDEX -> INDEX; + case SEARCH -> SEARCH; + }; + } + + public static Optional fromEmbeddingGateway( + EmbeddingGateway.ModelUsage.InputType inputType) { + return switch (inputType) { + case INPUT_TYPE_UNSPECIFIED -> Optional.of(INPUT_TYPE_UNSPECIFIED); + case INDEX -> Optional.of(INDEX); + case SEARCH -> Optional.of(SEARCH); + default -> Optional.empty(); + }; + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelProvider.java new file mode 100644 index 0000000000..30b20e2f3b --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelProvider.java @@ -0,0 +1,50 @@ +package io.stargate.sgv2.jsonapi.service.provider; + +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +public enum ModelProvider { + AZURE_OPENAI("azureOpenAI"), + BEDROCK("bedrock"), + COHERE("cohere"), + CUSTOM("custom"), + HUGGINGFACE("huggingface"), + HUGGINGFACE_DEDICATED("huggingfaceDedicated"), + HUGGINGFACE_DEDICATED_DEFINED_MODEL("endpoint-defined-model"), + JINA_AI("jinaAI"), + MISTRAL("mistral"), + NVIDIA("nvidia"), + OPENAI("openai"), + UPSTAGE_AI("upstageAI"), + VERTEXAI("vertexai"), + VOYAGE_AI("voyageAI"); + + private static final Map API_NAME_TO_PROVIDER; + + static { + API_NAME_TO_PROVIDER = new HashMap<>(); + for (ModelProvider provider : ModelProvider.values()) { + API_NAME_TO_PROVIDER.put(provider.apiName(), provider); + } + } + + private final String apiName; + + ModelProvider(String apiName) { + this.apiName = apiName; + } + + public String apiName() { + return apiName; + } + + public static Optional 
fromApiName(String apiName) { + return Optional.ofNullable(API_NAME_TO_PROVIDER.get(apiName)); + } + + @Override + public String toString() { + return apiName; + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelType.java b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelType.java new file mode 100644 index 0000000000..1f4ce1a9d5 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelType.java @@ -0,0 +1,25 @@ +package io.stargate.sgv2.jsonapi.service.provider; + +import io.stargate.embedding.gateway.EmbeddingGateway; +import java.util.Optional; + +/** + * If the model used was an embedding model or a reranking model + * + *

Keeps in parity with the gRPC proto definition in embedding_gateway.proto + */ +public enum ModelType { + MODEL_TYPE_UNSPECIFIED, + EMBEDDING, + RERANKING; + + public static Optional fromEmbeddingGateway( + EmbeddingGateway.ModelUsage.ModelType modelType) { + return switch (modelType) { + case MODEL_TYPE_UNSPECIFIED -> Optional.of(MODEL_TYPE_UNSPECIFIED); + case EMBEDDING -> Optional.of(EMBEDDING); + case RERANKING -> Optional.of(RERANKING); + default -> Optional.empty(); + }; + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelUsage.java b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelUsage.java index 25b5a73645..5d22fda152 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelUsage.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelUsage.java @@ -1,129 +1,206 @@ package io.stargate.sgv2.jsonapi.service.provider; import io.stargate.embedding.gateway.EmbeddingGateway; +import io.stargate.sgv2.jsonapi.util.recordable.Recordable; +import java.util.Objects; + +public final class ModelUsage implements Recordable { + + public static final ModelUsage EMPTY = new ModelUsage(true); + + private final ModelProvider modelProvider; + private final ModelType modelType; + private final String modelName; + private final String tenantId; + private final ModelInputType inputType; + private final int promptTokens; + private final int totalTokens; + private final int requestBytes; + private final int responseBytes; + private final long durationNanos; + private final int batchCount; + private final boolean isEmpty; -/** - * This class is to track the usage at the http request level to the embedding or reranking provider - * model service. - */ -public class ModelUsage { - - public final ProviderType providerType; - public final String provider; - public final String model; - - /** The number of bytes sent in the request.
*/ - private int requestBytes = 0; - - /** The number of bytes received in the response. Use content-length if present */ - private int responseBytes = 0; - - /** The number of tokens in the prompt, will be set if provider returned in the response. */ - private int promptTokens = 0; - - /** - * The total number of tokens in the request, will be set if provider returned in the response. - */ - private int totalTokens = 0; + public ModelUsage( + ModelProvider modelProvider, + ModelType modelType, + String modelName, + String tenantId, + ModelInputType inputType, + int promptTokens, + int totalTokens, + int requestBytes, + int responseBytes, + long durationNanos) { + this( + modelProvider, + modelType, + modelName, + tenantId, + inputType, + promptTokens, + totalTokens, + requestBytes, + responseBytes, + durationNanos, + 1, + false); + } - public ModelUsage(ProviderType providerType, String provider, String model) { - this.providerType = providerType; - this.provider = provider; - this.model = model; + private ModelUsage(boolean isEmpty) { + this(null, null, null, null, null, 0, 0, 0, 0, 0L, 0, isEmpty); /* forward the flag so the EMPTY sentinel, whose object fields are all null, reports isEmpty() == true */ } - public ModelUsage( - ProviderType providerType, - String provider, - String model, + private ModelUsage( + ModelProvider modelProvider, + ModelType modelType, + String modelName, + String tenantId, + ModelInputType inputType, + int promptTokens, + int totalTokens, int requestBytes, int responseBytes, - int promptTokens, - int totalTokens) { - this.providerType = providerType; - this.provider = provider; - this.model = model; - this.requestBytes = requestBytes; - this.responseBytes = responseBytes; + long durationNanos, + int batchCount, + boolean isEmpty) { + this.modelProvider = modelProvider; + this.modelType = modelType; + this.modelName = modelName; + this.tenantId = tenantId; + this.inputType = inputType; this.promptTokens = promptTokens; this.totalTokens = totalTokens; + this.requestBytes = requestBytes; + this.responseBytes = responseBytes; +
this.durationNanos = durationNanos; + this.batchCount = batchCount; + this.isEmpty = isEmpty; } - /** Create the ModelUsage from the modelUsage of Embedding Gateway gRPC response. */ - public static ModelUsage fromGrpcResponse(EmbeddingGateway.ModelUsage modelUsage) { + public static ModelUsage fromEmbeddingGateway(EmbeddingGateway.ModelUsage grpcModelUsage) { return new ModelUsage( - ProviderType.valueOf(modelUsage.getProviderType()), - modelUsage.getProviderName(), - modelUsage.getModelName(), - modelUsage.getRequestBytes(), - modelUsage.getResponseBytes(), - modelUsage.getPromptTokens(), - modelUsage.getTotalTokens()); - } - - /** - * Parse the request and response bytes from the headers of the intercepted response. Headers are - * added in the {@link ProviderHttpInterceptor} registered by specified providerClient. - */ - public ModelUsage parseSentReceivedBytes(jakarta.ws.rs.core.Response interceptedResp) { - if (interceptedResp.getHeaders().get(ProviderHttpInterceptor.SENT_BYTES_HEADER) != null) { - this.requestBytes = - Integer.parseInt( - interceptedResp.getHeaderString(ProviderHttpInterceptor.SENT_BYTES_HEADER)); + ModelProvider.fromApiName(grpcModelUsage.getModelProvider()) + .orElseThrow( + () -> + new IllegalArgumentException( + "Unknown Embedding Gateway modelProvider: " + + grpcModelUsage.getModelProvider())), + ModelType.fromEmbeddingGateway(grpcModelUsage.getModelType()) + .orElseThrow( + () -> + new IllegalArgumentException( + "Unknown Embedding Gateway modelType: " + grpcModelUsage.getModelType())), + grpcModelUsage.getModelName(), + grpcModelUsage.getTenantId(), + ModelInputType.fromEmbeddingGateway(grpcModelUsage.getInputType()) + .orElseThrow( + () -> + new IllegalArgumentException( + "Unknown Embedding Gateway modelInputType: " + + grpcModelUsage.getInputType())), + grpcModelUsage.getPromptTokens(), + grpcModelUsage.getTotalTokens(), + grpcModelUsage.getRequestBytes(), + grpcModelUsage.getResponseBytes(), + 
grpcModelUsage.getCallDurationNanos()); + } + + public ModelUsage merge(ModelUsage other) { + + Objects.requireNonNull(other, "other must not be null"); + if (isEmpty) { /* an empty side contributes nothing; this also covers both sides empty, so the null fields of an empty instance are never dereferenced below */ + return other; + } + if (other.isEmpty) { + return this; } - if (interceptedResp.getHeaders().get(ProviderHttpInterceptor.RECEIVED_BYTES_HEADER) != null) { - this.responseBytes = - Integer.parseInt( - interceptedResp.getHeaderString(ProviderHttpInterceptor.RECEIVED_BYTES_HEADER)); + + if (!this.modelProvider.equals(other.modelProvider) + || !this.modelType.equals(other.modelType) + || !this.modelName.equals(other.modelName) + || !this.tenantId.equals(other.tenantId) + || !this.inputType.equals(other.inputType)) { + throw new IllegalArgumentException("Cannot merge ModelUsage with different properties"); } - return this; + + return new ModelUsage( + this.modelProvider, + this.modelType, + this.modelName, + this.tenantId, + this.inputType, + this.promptTokens + other.promptTokens, + this.totalTokens + other.totalTokens, + this.requestBytes + other.requestBytes, + this.responseBytes + other.responseBytes, + this.durationNanos + other.durationNanos, + this.batchCount + other.batchCount, + false); } - public ModelUsage setPromptTokens(int promptTokens) { - this.promptTokens = promptTokens; - return this; + public boolean isEmpty() { + return this.isEmpty; } - public ModelUsage setTotalTokens(int totalTokens) { - this.totalTokens = totalTokens; - return this; + public ModelProvider modelProvider() { + return modelProvider; } - public int getRequestBytes() { - return requestBytes; + public ModelType modelType() { + return modelType; } - public int getResponseBytes() { - return responseBytes; + public String modelName() { + return modelName; + } + + public String tenantId() { + return tenantId; + } + + public ModelInputType inputType() { + return inputType; } - public int getPromptTokens() { + public int promptTokens() { return promptTokens; } - public int getTotalTokens() { +
public int totalTokens() { return totalTokens; } + public int requestBytes() { + return requestBytes; + } + + public int responseBytes() { + return responseBytes; + } + + public long durationNanos() { + return durationNanos; + } + + public int batchCount() { + return batchCount; + } + @Override - public String toString() { - return "ModelUsage{" - + "providerType=" - + providerType - + ", provider='" - + provider - + '\'' - + ", model='" - + model - + '\'' - + ", requestBytes=" - + requestBytes - + ", responseBytes=" - + responseBytes - + ", promptTokens=" - + promptTokens - + ", totalTokens=" - + totalTokens - + '}'; + public DataRecorder recordTo(DataRecorder dataRecorder) { + return dataRecorder + .append("modelProvider", modelProvider) + .append("modelType", modelType) + .append("modelName", modelName) + .append("tenantId", tenantId) + .append("inputType", inputType) + .append("promptTokens", promptTokens) + .append("totalTokens", totalTokens) + .append("requestBytes", requestBytes) + .append("responseBytes", responseBytes) + .append("durationNanos", durationNanos) + .append("batchCount", batchCount) + .append("isEmpty", isEmpty); } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderBase.java b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderBase.java new file mode 100644 index 0000000000..5b5ad0f8c6 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderBase.java @@ -0,0 +1,179 @@ +package io.stargate.sgv2.jsonapi.service.provider; + +import com.fasterxml.jackson.databind.JsonNode; +import io.smallrye.mutiny.Uni; +import io.stargate.embedding.gateway.EmbeddingGateway; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; +import java.time.Duration; +import java.util.concurrent.TimeoutException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Base for model providers such as embedding and reranking. 
*/ +public abstract class ProviderBase { + protected static final Logger LOGGER = LoggerFactory.getLogger(ProviderBase.class); + + private final ModelProvider modelProvider; + private final ModelType modelType; + private final String modelName; + + protected ProviderBase(ModelProvider modelProvider, ModelType modelType, String modelName) { + this.modelProvider = modelProvider; + this.modelType = modelType; + this.modelName = modelName; + } + + public String modelName() { + return modelName; + } + + public ModelProvider modelProvider() { + return modelProvider; + } + + protected abstract String errorMessageJsonPtr(); + + protected abstract Duration initialBackOffDuration(); + + protected abstract Duration maxBackOffDuration(); + + protected abstract double jitter(); + + protected abstract int atMostRetries(); + + /** + * Applies a retry mechanism with backoff and jitter to the Uni returned by the rerank() method, + * which makes an HTTP request to a third-party service. + * + * @param The type of the item emitted by the Uni. + * @param uni The Uni to which the retry mechanism should be applied. + * @return A Uni that will retry on the specified failures with the configured backoff and jitter. 
+ */ + protected Uni retryHTTPCall(Uni uni) { + + return uni.onItem() + .transform(this::handleHTTPResponse) + .onFailure(this::decideRetry) + .retry() + .withBackOff(initialBackOffDuration(), maxBackOffDuration()) + .withJitter(jitter()) + .atMost(atMostRetries()); + } + + protected boolean decideRetry(Throwable throwable) { + return throwable instanceof TimeoutException; + } + + protected Response handleHTTPResponse(Response jakartaResponse) { + + if (jakartaResponse.getStatus() >= 400) { + var runtimeException = handleHTTPError(jakartaResponse); + if (runtimeException != null) { + throw runtimeException; + } + throw new IllegalStateException( + String.format( + "Unhandled error from model provider, modelProvider: %s, modelName: %s, status: %d, responseBody: %s", + modelProvider(), + modelName(), + jakartaResponse.getStatus(), + jakartaResponse.readEntity(String.class))); + } + return jakartaResponse; + } + + protected RuntimeException handleHTTPError(Response jakartaResponse) { + + var errorMessage = responseErrorMessage(jakartaResponse); + LOGGER.error( + "Error response from model provider, modelProvider: {}, modelName:{}, http.status: {}, error: {}", + modelProvider, + modelName, + jakartaResponse.getStatus(), + errorMessage); + + return mapHTTPError(jakartaResponse, errorMessage); + } + + protected abstract RuntimeException mapHTTPError(Response response, String errorMessage); + + protected String responseErrorMessage(Response jakartaResponse) { + + MediaType contentType = jakartaResponse.getMediaType(); + String raw = jakartaResponse.readEntity(String.class); + + if (contentType == null || !MediaType.APPLICATION_JSON_TYPE.isCompatible(contentType)) { + LOGGER.error( + "Non-JSON error response from model provider, modelProvider:{}, modelName: {}, raw:{}", + modelProvider(), + modelName(), + raw); + return raw; + } + + JsonNode rootNode = null; + try { + rootNode = jakartaResponse.readEntity(JsonNode.class); + } catch (Exception e) { + // If we cannot read 
the response as JsonNode, log the error and return the raw response + LOGGER.error( + "Error parsing error JSON from reranking provider, modelProvider: {}, modelName: {}", + modelProvider, + modelName, + e); + } + + return (rootNode == null) ? raw : responseErrorMessage(rootNode); + } + + protected String responseErrorMessage(JsonNode rootNode) { + + var messageNode = rootNode.at(errorMessageJsonPtr()); + return messageNode.isMissingNode() ? rootNode.toString() : messageNode.toString(); + } + + protected ModelUsage createModelUsage( + String tenantId, + ModelInputType modelInputType, + int promptTokens, + int totalTokens, + int requestBytes, + int responseBytes, + long durationNanos) { + + return new ModelUsage( + modelProvider, + modelType, + modelName, + tenantId, + modelInputType, + promptTokens, + totalTokens, + requestBytes, + responseBytes, + durationNanos); + } + + protected ModelUsage createModelUsage( + String tenantId, + ModelInputType modelInputType, + int promptTokens, + int totalTokens, + Response jakartaResponse, + long durationNanos) { + + return createModelUsage( + tenantId, + modelInputType, + promptTokens, + totalTokens, + ProviderHttpInterceptor.getSentBytes(jakartaResponse), + ProviderHttpInterceptor.getReceivedBytes(jakartaResponse), + durationNanos); + } + + protected ModelUsage createModelUsage(EmbeddingGateway.ModelUsage gatewayModelUsage) { + return ModelUsage.fromEmbeddingGateway(gatewayModelUsage); + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderHttpInterceptor.java b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderHttpInterceptor.java index 547eabf7c6..a1ae54fc27 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderHttpInterceptor.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderHttpInterceptor.java @@ -1,13 +1,12 @@ package io.stargate.sgv2.jsonapi.service.provider; import com.fasterxml.jackson.databind.ObjectMapper; -import 
com.google.common.io.ByteStreams; import com.google.common.io.CountingOutputStream; import jakarta.ws.rs.client.ClientRequestContext; import jakarta.ws.rs.client.ClientResponseContext; import jakarta.ws.rs.client.ClientResponseFilter; +import jakarta.ws.rs.core.Response; import java.io.IOException; -import java.io.InputStream; import java.io.OutputStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -29,26 +28,29 @@ public class ProviderHttpInterceptor implements ClientResponseFilter { private static final Logger LOGGER = LoggerFactory.getLogger(ProviderHttpInterceptor.class); private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); - // Header name to track the sent_bytes to the provider - public static final String SENT_BYTES_HEADER = "sent-bytes"; - // Header name to track the received_bytes from the provider - public static final String RECEIVED_BYTES_HEADER = "received-bytes"; + /** Header name to track the sent_bytes to the provider */ + private static final String SENT_BYTES_HEADER = "sent-bytes"; + + /** Header name to track the received_bytes from the provider */ + private static final String RECEIVED_BYTES_HEADER = "received-bytes"; @Override public void filter(ClientRequestContext requestContext, ClientResponseContext responseContext) throws IOException { - int receivedBytes = 0; - int sentBytes = 0; + + long receivedBytes = 0; + long sentBytes = 0; // Parse the request entity stream to measure its size. 
if (requestContext.hasEntity()) { - try { - CountingOutputStream cus = new CountingOutputStream(OutputStream.nullOutputStream()); + try (var cus = new CountingOutputStream(OutputStream.nullOutputStream())) { OBJECT_MAPPER.writeValue(cus, requestContext.getEntity()); - cus.close(); - sentBytes = (int) cus.getCount(); + sentBytes = cus.getCount(); + } catch (Exception e) { - LOGGER.warn("Failed to measure request body size: " + e.getMessage()); + if (LOGGER.isWarnEnabled()) { + LOGGER.warn("Failed to measure request body size.", e); + } } } @@ -56,15 +58,35 @@ public void filter(ClientRequestContext requestContext, ClientResponseContext re // size. if (responseContext.hasEntity()) { receivedBytes = responseContext.getLength(); + // if provider does not return content-length in the response header. if (receivedBytes <= 0) { - // Read the response entity stream to measure its size - InputStream inputStream = responseContext.getEntityStream(); - receivedBytes = (int) ByteStreams.copy(inputStream, OutputStream.nullOutputStream()); + receivedBytes = + responseContext.getEntityStream().transferTo(OutputStream.nullOutputStream()); } } responseContext.getHeaders().add(SENT_BYTES_HEADER, String.valueOf(sentBytes)); responseContext.getHeaders().add(RECEIVED_BYTES_HEADER, String.valueOf(receivedBytes)); } + + public static int getSentBytes(Response jakartaResponse) { + return getHeaderInt(jakartaResponse, SENT_BYTES_HEADER); + } + + public static int getReceivedBytes(Response jakartaResponse) { + return getHeaderInt(jakartaResponse, RECEIVED_BYTES_HEADER); + } + + private static int getHeaderInt(Response jakartaResponse, String headerName) { + var headerString = jakartaResponse.getHeaderString(headerName); + if (headerString != null && !headerString.isBlank()) { + try { + return Integer.parseInt(headerString); + } catch (NumberFormatException e) { + LOGGER.warn("Failed to parse headerName:{}, headerString:{}", headerName, headerString, e); + } + } + return 0; + } } diff --git 
a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderType.java b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderType.java deleted file mode 100644 index 8dacd2faeb..0000000000 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderType.java +++ /dev/null @@ -1,11 +0,0 @@ -package io.stargate.sgv2.jsonapi.service.provider; - -/** - * Enum representing the type of provider. - * - *

Used to differentiate between embedding and reranking providers. - */ -public enum ProviderType { - EMBEDDING_PROVIDER, - RERANKING_PROVIDER -} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/gateway/RerankingEGWClient.java b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/gateway/RerankingEGWClient.java index 16e6fabaea..a34ea9e289 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/gateway/RerankingEGWClient.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/gateway/RerankingEGWClient.java @@ -8,7 +8,7 @@ import io.stargate.sgv2.jsonapi.api.request.RerankingCredentials; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.exception.JsonApiException; -import io.stargate.sgv2.jsonapi.service.provider.ModelUsage; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import io.stargate.sgv2.jsonapi.service.reranking.configuration.RerankingProvidersConfig; import io.stargate.sgv2.jsonapi.service.reranking.operation.RerankingProvider; import java.util.*; @@ -19,18 +19,13 @@ public class RerankingEGWClient extends RerankingProvider { private static final String DEFAULT_TENANT_ID = "default"; - /** - * This string acts as key of authTokens map, for passing Data API token to EGW in grpc request. - */ + /** Key of authTokens map, for passing Data API token to EGW in grpc request. */ private static final String DATA_API_TOKEN = "DATA_API_TOKEN"; - /** - * This string acts as key of authTokens map, for passing Reranking API key to EGW in grpc - * request. - */ + /** Key in the authTokens map, for passing Reranking API key to EGW in grpc request. 
*/ private static final String RERANKING_API_KEY = "RERANKING_API_KEY"; - private final String provider; + private final ModelProvider modelProvider; private final Optional tenant; private final Optional authToken; private final String modelName; @@ -42,15 +37,16 @@ public RerankingEGWClient( String baseUrl, RerankingProvidersConfig.RerankingProviderConfig.ModelConfig.RequestProperties requestProperties, - String provider, + ModelProvider modelProvider, Optional tenant, Optional authToken, String modelName, RerankingService rerankingGrpcService, Map authentication, String commandName) { - super(baseUrl, modelName, requestProperties); - this.provider = provider; + super(modelProvider, baseUrl, modelName, requestProperties); + + this.modelProvider = modelProvider; this.tenant = tenant; this.authToken = authToken; this.modelName = modelName; @@ -60,41 +56,43 @@ public RerankingEGWClient( } @Override - public Uni rerank( + protected String errorMessageJsonPtr() { + // not used here, we are just passing through. + return ""; + } + + @Override + public Uni rerank( int batchId, String query, List passages, RerankingCredentials rerankingCredentials) { - // Build the reranking provider request in grpc request - final EmbeddingGateway.ProviderRerankingRequest.RerankingRequest rerankingRequest = + var gatewayReranking = EmbeddingGateway.ProviderRerankingRequest.RerankingRequest.newBuilder() .setModelName(modelName) .setQuery(query) .addAllPassages(passages) + // TODO: Why is the command name passed here ? Can it be removed ? 
.setCommandName(commandName) .build(); - // Build the reranking provider context in grpc request var contextBuilder = EmbeddingGateway.ProviderRerankingRequest.ProviderContext.newBuilder() - .setProviderName(provider) + .setProviderName(modelProvider.apiName()) .setTenantId(tenant.orElse(DEFAULT_TENANT_ID)) .putAuthTokens(DATA_API_TOKEN, authToken.orElse("")); + rerankingCredentials + .apiKey() + .ifPresent(v -> contextBuilder.putAuthTokens(RERANKING_API_KEY, v)); - if (rerankingCredentials.apiKey().isPresent()) { - contextBuilder.putAuthTokens(RERANKING_API_KEY, rerankingCredentials.apiKey().get()); - } - final EmbeddingGateway.ProviderRerankingRequest.ProviderContext providerContext = - contextBuilder.build(); - - // Built the Grpc request - final EmbeddingGateway.ProviderRerankingRequest grpcRerankingRequest = + var gatewayRequest = EmbeddingGateway.ProviderRerankingRequest.newBuilder() - .setRerankingRequest(rerankingRequest) - .setProviderContext(providerContext) + .setRerankingRequest(gatewayReranking) + .setProviderContext(contextBuilder.build()) .build(); - Uni grpcRerankingResponse; + // TODO: XXX Why is this error handling here not part of the uni pipeline? 
+ Uni gatewayRerankingUni; try { - grpcRerankingResponse = rerankingGrpcService.rerank(grpcRerankingRequest); + gatewayRerankingUni = rerankingGrpcService.rerank(gatewayRequest); } catch (StatusRuntimeException e) { if (e.getStatus().getCode().equals(Status.Code.DEADLINE_EXCEEDED)) { throw ErrorCodeV1.RERANKING_PROVIDER_TIMEOUT.toApiException(e, e.getMessage()); @@ -102,21 +100,23 @@ public Uni rerank( throw e; } - return grpcRerankingResponse + return gatewayRerankingUni .onItem() .transform( - resp -> { - if (resp.hasError()) { + gatewayResponse -> { + if (gatewayResponse.hasError()) { + // TODO : move to V2 error throw new JsonApiException( - ErrorCodeV1.valueOf(resp.getError().getErrorCode()), - resp.getError().getErrorMessage()); + ErrorCodeV1.valueOf(gatewayResponse.getError().getErrorCode()), + gatewayResponse.getError().getErrorMessage()); } - return RerankingBatchResponse.of( + + return new BatchedRerankingResponse( batchId, - resp.getRanksList().stream() + gatewayResponse.getRanksList().stream() .map(rank -> new Rank(rank.getIndex(), rank.getScore())) .collect(Collectors.toList()), - ModelUsage.fromGrpcResponse(resp.getModelUsage())); + createModelUsage(gatewayResponse.getModelUsage())); }); } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/NvidiaRerankingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/NvidiaRerankingProvider.java index 3d08878ab7..c99e522c13 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/NvidiaRerankingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/NvidiaRerankingProvider.java @@ -1,24 +1,21 @@ package io.stargate.sgv2.jsonapi.service.reranking.operation; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.databind.JsonNode; -import io.quarkus.rest.client.reactive.ClientExceptionMapper; import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; import 
io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.RerankingCredentials; import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; -import io.stargate.sgv2.jsonapi.service.embedding.operation.error.RerankingResponseErrorMessageMapper; -import io.stargate.sgv2.jsonapi.service.provider.ModelUsage; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; -import io.stargate.sgv2.jsonapi.service.provider.ProviderType; import io.stargate.sgv2.jsonapi.service.reranking.configuration.RerankingProviderResponseValidation; import io.stargate.sgv2.jsonapi.service.reranking.configuration.RerankingProvidersConfig; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.*; import java.util.concurrent.TimeUnit; @@ -62,14 +59,17 @@ */ public class NvidiaRerankingProvider extends RerankingProvider { - private static final String providerId = ProviderConstants.NVIDIA; - private final NvidiaRerankingClient nvidiaRerankingClient; - - // Nvidia Reranking Service supports truncate or error when the passage is too long. - // Data API use NONE as default, means the reranking request will error out if there is a query - // and - // passage pair that exceeds allowed token size 8192 - // https://docs.nvidia.com/nim/nemo-retriever/text-reranking/latest/using-reranking.html#token-limits-truncation + private final NvidiaRerankingClient nvidiaClient; + + /** + * Nvidia Reranking Service supports truncation or error behavior when the passage is too long. + * + *

The Data API uses {@code NONE} as the default, which means the reranking request will error + * out if there is a query and passage pair that exceeds the allowed token size of 8192. + * + *

See: + * https://docs.nvidia.com/nim/nemo-retriever/text-reranking/latest/using-reranking.html#token-limits-truncation + */ private static final String TRUNCATE_PASSAGE = "NONE"; public NvidiaRerankingProvider( @@ -77,14 +77,72 @@ public NvidiaRerankingProvider( String modelName, RerankingProvidersConfig.RerankingProviderConfig.ModelConfig.RequestProperties requestProperties) { - super(baseUrl, modelName, requestProperties); - nvidiaRerankingClient = + super(ModelProvider.NVIDIA, baseUrl, modelName, requestProperties); + + nvidiaClient = QuarkusRestClientBuilder.newBuilder() .baseUri(URI.create(baseUrl)) .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(NvidiaRerankingClient.class); } + @Override + protected String errorMessageJsonPtr() { + return "/message"; + } + + @Override + protected Uni rerank( + int batchId, String query, List passages, RerankingCredentials rerankingCredentials) { + + // TODO: Move error to v2 + var accessToken = + rerankingCredentials + .apiKey() + .map(apiKey -> HttpConstants.BEARER_PREFIX_FOR_API_KEY + apiKey) + .orElseThrow( + () -> + ErrorCodeV1.RERANKING_PROVIDER_AUTHENTICATION_KEYS_NOT_PROVIDED.toApiException( + "In order to rerank, please provide the reranking API key.")); + + var nvidiaRequest = + new NvidiaRerankingRequest( + modelName(), + new NvidiaRerankingRequest.TextWrapper(query), + passages.stream().map(NvidiaRerankingRequest.TextWrapper::new).toList(), + TRUNCATE_PASSAGE); + + final long callStartNano = System.nanoTime(); + return retryHTTPCall(nvidiaClient.rerank(accessToken, nvidiaRequest)) + .onItem() + .transform( + jakartaResponse -> { + var nvidiaResponse = jakartaResponse.readEntity(NvidiaRerankingResponse.class); + long callDurationNano = System.nanoTime() - callStartNano; + + // converting from the specific Nvidia response to the generic RerankingBatchResponse + var ranks = + nvidiaResponse.rankings().stream() + .map(rank -> new Rank(rank.index(), rank.logit())) + .toList(); + + var 
modelUsage = + createModelUsage( + rerankingCredentials.tenantId(), + ModelInputType.INPUT_TYPE_UNSPECIFIED, + nvidiaResponse.usage().prompt_tokens, + nvidiaResponse.usage().total_tokens, + jakartaResponse, + callDurationNano); + return new BatchedRerankingResponse(batchId, ranks, modelUsage); + }); + } + + /** + * REST client interface for the Nvidia Reranking Service. + * + *

.. + */ @RegisterRestClient @RegisterProvider(RerankingProviderResponseValidation.class) @RegisterProvider(ProviderHttpInterceptor.class) @@ -92,86 +150,37 @@ public interface NvidiaRerankingClient { @POST @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) - Uni rerank( - @HeaderParam("Authorization") String accessToken, RerankingRequest request); - - @ClientExceptionMapper - static RuntimeException mapException(jakarta.ws.rs.core.Response response) { - String errorMessage = getErrorMessage(response); - return RerankingResponseErrorMessageMapper.mapToAPIException( - providerId, response, errorMessage); - } - - private static String getErrorMessage(jakarta.ws.rs.core.Response response) { - // Get the whole response body - JsonNode rootNode = response.readEntity(JsonNode.class); - // Log the response body - logger.error( - "Error response from reranking provider '{}': {}", providerId, rootNode.toString()); - JsonNode messageNode = rootNode.path("message"); - // Return the text of the "message" node, or the whole response body if it is missing - return messageNode.isMissingNode() ? rootNode.toString() : messageNode.toString(); - } + Uni rerank( + @HeaderParam("Authorization") String accessToken, NvidiaRerankingRequest request); } - /** reranking request to the Nvidia Reranking Service */ - private record RerankingRequest( + /** + * Request structure of the NVIDIA REST service. + * + *

.. + */ + public record NvidiaRerankingRequest( String model, TextWrapper query, List passages, String truncate) { + /** * query and passage string needs to be wrapped in with text key for request to the Nvidia * Reranking Service. E.G. { "text": "which way should i go?" } */ - private record TextWrapper(String text) {} + record TextWrapper(String text) {} } - /** reranking response from the Nvidia reranking Service */ + /** + * Response structure of the NVIDIA REST service. + * + *

.. + */ @JsonIgnoreProperties(ignoreUnknown = true) - private record RerankingResponse(List rankings, Usage usage) { - @JsonIgnoreProperties(ignoreUnknown = true) - private record Ranking(int index, float logit) {} + record NvidiaRerankingResponse(List rankings, NvidiaUsage usage) { @JsonIgnoreProperties(ignoreUnknown = true) - private record Usage(int prompt_tokens, int total_tokens) {} - } - - @Override - public Uni rerank( - int batchId, String query, List passages, RerankingCredentials rerankingCredentials) { - - RerankingRequest request = - new RerankingRequest( - modelName, - new RerankingRequest.TextWrapper(query), - passages.stream().map(RerankingRequest.TextWrapper::new).toList(), - TRUNCATE_PASSAGE); + record NvidiaRanking(int index, float logit) {} - if (rerankingCredentials.apiKey().isEmpty()) { - throw ErrorCodeV1.RERANKING_PROVIDER_AUTHENTICATION_KEYS_NOT_PROVIDED.toApiException( - "In order to rerank, please provide the reranking API key."); - } - - Uni response = - applyRetry( - nvidiaRerankingClient.rerank( - HttpConstants.BEARER_PREFIX_FOR_API_KEY + rerankingCredentials.apiKey().get(), - request)); - - return response - .onItem() - .transform( - interceptedResp -> { - RerankingResponse providerResp = interceptedResp.readEntity(RerankingResponse.class); - List ranks = - providerResp.rankings().stream() - .map(rank -> new Rank(rank.index(), rank.logit())) - .toList(); - - ModelUsage modelUsage = - new ModelUsage(ProviderType.RERANKING_PROVIDER, providerId, modelName) - .setPromptTokens(providerResp.usage().prompt_tokens) - .setTotalTokens(providerResp.usage().total_tokens) - .parseSentReceivedBytes(interceptedResp); - return RerankingBatchResponse.of(batchId, ranks, modelUsage); - }); + @JsonIgnoreProperties(ignoreUnknown = true) + record NvidiaUsage(int prompt_tokens, int total_tokens) {} } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProvider.java 
b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProvider.java index c22f90a24c..f7c04bdf2e 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProvider.java @@ -1,62 +1,73 @@ package io.stargate.sgv2.jsonapi.service.reranking.operation; +import static jakarta.ws.rs.core.Response.Status.Family.CLIENT_ERROR; + import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.RerankingCredentials; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.exception.JsonApiException; -import io.stargate.sgv2.jsonapi.service.provider.ModelUsage; +import io.stargate.sgv2.jsonapi.service.provider.*; import io.stargate.sgv2.jsonapi.service.reranking.configuration.RerankingProvidersConfig; +import jakarta.ws.rs.core.Response; import java.time.Duration; import java.util.ArrayList; import java.util.Comparator; import java.util.List; -import java.util.concurrent.TimeoutException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public abstract class RerankingProvider { - protected static final Logger logger = LoggerFactory.getLogger(RerankingProvider.class); +public abstract class RerankingProvider extends ProviderBase { + + protected static final Logger LOGGER = LoggerFactory.getLogger(RerankingProvider.class); + protected final String baseUrl; - protected final String modelName; protected final RerankingProvidersConfig.RerankingProviderConfig.ModelConfig.RequestProperties requestProperties; + protected final Duration initialBackOffDuration; + + protected final Duration maxBackOffDuration; + protected RerankingProvider( + ModelProvider modelProvider, String baseUrl, String modelName, RerankingProvidersConfig.RerankingProviderConfig.ModelConfig.RequestProperties requestProperties) { + super(modelProvider, ModelType.RERANKING, modelName); + this.baseUrl = baseUrl; - 
this.modelName = modelName; this.requestProperties = requestProperties; - } - public String modelName() { - return modelName; + this.initialBackOffDuration = Duration.ofMillis(requestProperties.initialBackOffMillis()); + this.maxBackOffDuration = Duration.ofMillis(requestProperties.maxBackOffMillis()); } /** - * Gather the results from all batch reranking calls, adjust the indices, so they refer to the - * original passages list, and return a final RerankingResponse as the original order of the - * passages with the reranking score. + * Reranks the texts, batching as needed, and returns a final RerankingResponse as the original + * order of the passages with the reranking score. * - *

E.G. if the original passages list is ["a", "b", "c", "d", "e"] and the micro batch is 2, - * then API will do 3 batch reranking calls: ["a", "b"], ["c", "d"], ["e"]. 3 response will be - * returned: + *

E.G. if the original passages list is ["a", "b", "c", "d", "e"] and the micro + * batch is 2, then API will do 3 batch reranking calls: ["a", "b"], ["c", "d"], ["e"]. + * 3 responses will be returned: * *

    - *
  • batch 0: [{index:1, score:x1}, {index:0, score:x2}] - *
  • batch 1: [{index:0, score:x3}, {index:1, score:x4}] - *
  • batch 2: [{index:0, score:x5}] + *
  • batch 0: [{index:1, score:x1}, {index:0, score:x2}] + *
  • batch 1: [{index:0, score:x3}, {index:1, score:x4}] + *
  • batch 2: [{index:0, score:x5}] *
* - * Then this method will adjust the indices and return the final response: [{index:0, score:x1}, - * {index:1, score:x2}, {index:2, score:x3}, {index:3, score:x4}, {index:4, score:x5}] + * Then this method will adjust the indices and return the final response: + * [{index:0, score:x1}, + * {index:1, score:x2}, {index:2, score:x3}, {index:3, score:x4}, {index:4, score:x5}] */ public Uni rerank( String query, List passages, RerankingCredentials rerankingCredentials) { + + // TODO: what to do if passages is empty? List> passageBatches = createPassageBatches(passages); - List> batchRerankings = new ArrayList<>(); + List> batchRerankings = new ArrayList<>(); + for (int batchId = 0; batchId < passageBatches.size(); batchId++) { batchRerankings.add( rerank(batchId, query, passageBatches.get(batchId), rerankingCredentials)); @@ -65,8 +76,87 @@ public Uni rerank( return Uni.join().all(batchRerankings).andFailFast().map(this::aggregateRanks); } + /** + * Subclasses must implement to do the reranking, after the batching is done. + * + *

... + */ + protected abstract Uni rerank( + int batchId, String query, List passages, RerankingCredentials rerankingCredentials); + + @Override + protected Duration initialBackOffDuration() { + return initialBackOffDuration; + } + + @Override + protected Duration maxBackOffDuration() { + return maxBackOffDuration; + } + + @Override + protected double jitter() { + return requestProperties.jitter(); + } + + @Override + protected int atMostRetries() { + return requestProperties.atMostRetries(); + } + + @Override + protected boolean decideRetry(Throwable throwable) { + + var retry = + (throwable.getCause() instanceof JsonApiException jae + && jae.getErrorCode() == ErrorCodeV1.RERANKING_PROVIDER_TIMEOUT); + + return retry || super.decideRetry(throwable); + } + + @Override + protected RuntimeException mapHTTPError(Response jakartaResponse, String errorMessage) { + + // TODO: move to V2 errors + + if (jakartaResponse.getStatus() == Response.Status.REQUEST_TIMEOUT.getStatusCode() + || jakartaResponse.getStatus() == Response.Status.GATEWAY_TIMEOUT.getStatusCode()) { + + return ErrorCodeV1.RERANKING_PROVIDER_TIMEOUT.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), jakartaResponse.getStatus(), errorMessage); + } + + if (jakartaResponse.getStatus() == Response.Status.TOO_MANY_REQUESTS.getStatusCode()) { + + return ErrorCodeV1.RERANKING_PROVIDER_RATE_LIMITED.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), jakartaResponse.getStatus(), errorMessage); + } + + if (jakartaResponse.getStatusInfo().getFamily() == CLIENT_ERROR) { + + return ErrorCodeV1.RERANKING_PROVIDER_CLIENT_ERROR.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), jakartaResponse.getStatus(), errorMessage); + } + + if (jakartaResponse.getStatusInfo().getFamily() == Response.Status.Family.SERVER_ERROR) { + + return 
ErrorCodeV1.RERANKING_PROVIDER_SERVER_ERROR.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), jakartaResponse.getStatus(), errorMessage); + } + + // All other errors, Should never happen as all errors are covered above + return ErrorCodeV1.RERANKING_PROVIDER_UNEXPECTED_RESPONSE.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), jakartaResponse.getStatus(), errorMessage); + } + /** Create batches of passages to be reranked. */ private List> createPassageBatches(List passages) { + List> batches = new ArrayList<>(); for (int i = 0; i < passages.size(); i += requestProperties.maxBatchSize()) { batches.add( @@ -76,55 +166,47 @@ private List> createPassageBatches(List passages) { } /** Aggregate the ranks from all batched reranking calls. */ - private RerankingResponse aggregateRanks(List batchResponses) { + private RerankingResponse aggregateRanks(List batchResponses) { + List finalRanks = new ArrayList<>(); - for (RerankingBatchResponse batchResponse : batchResponses) { + ModelUsage aggregatedModelUsage = null; + + for (BatchedRerankingResponse batchResponse : batchResponses) { int batchStartIndex = batchResponse.batchId() * requestProperties.maxBatchSize(); + + aggregatedModelUsage = + aggregatedModelUsage == null + ? batchResponse.modelUsage() + : aggregatedModelUsage.merge(batchResponse.modelUsage()); for (Rank rank : batchResponse.ranks()) { finalRanks.add(new Rank(batchStartIndex + rank.index(), rank.score())); } } // This is the original order of all the passages. finalRanks.sort(Comparator.comparingInt(Rank::index)); - return new RerankingResponse(finalRanks); + return new RerankingResponse(finalRanks, aggregatedModelUsage); } - public record RerankingResponse(List ranks) {} - - /** Micro batch rerank method, which will rerank a batch of passages. 
*/ - public abstract Uni rerank( - int batchId, String query, List passages, RerankingCredentials rerankingCredentials); - - /** The response of a batch rerank call. */ - public record RerankingBatchResponse(int batchId, List ranks, ModelUsage modelUsage) { - public static RerankingBatchResponse of( - int batchId, List rankings, ModelUsage modelUsage) { - return new RerankingBatchResponse(batchId, rankings, modelUsage); - } - } + /** + * Unbatched reranking response, returned from the public {@link #rerank(String, List, + * RerankingCredentials)} + * + *

... + */ + public record RerankingResponse(List ranks, ModelUsage modelUsage) {} - public record Rank(int index, float score) {} + /** + * Batched reranking response, returned from the protected {@link #rerank(int, String, List, + * RerankingCredentials)} + * + *

... + */ + public record BatchedRerankingResponse(int batchId, List ranks, ModelUsage modelUsage) {} /** - * Applies a retry mechanism with backoff and jitter to the Uni returned by the rerank() method, - * which makes an HTTP request to a third-party service. + * Individual rank and the index of the input passage. * - * @param The type of the item emitted by the Uni. - * @param uni The Uni to which the retry mechanism should be applied. - * @return A Uni that will retry on the specified failures with the configured backoff and jitter. + *

... */ - protected Uni applyRetry(Uni uni) { - return uni.onFailure( - throwable -> - (throwable.getCause() != null - && throwable.getCause() instanceof JsonApiException jae - && jae.getErrorCode() == ErrorCodeV1.EMBEDDING_PROVIDER_TIMEOUT) - || throwable instanceof TimeoutException) - .retry() - .withBackOff( - Duration.ofMillis(requestProperties.initialBackOffMillis()), - Duration.ofMillis(requestProperties.maxBackOffMillis())) - .withJitter(requestProperties.jitter()) - .atMost(requestProperties.atMostRetries()); - } + public record Rank(int index, float score) {} } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProviderFactory.java b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProviderFactory.java index 308521dc56..5f9be2b547 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProviderFactory.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProviderFactory.java @@ -4,6 +4,7 @@ import io.stargate.embedding.gateway.RerankingService; import io.stargate.sgv2.jsonapi.config.OperationsConfig; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import io.stargate.sgv2.jsonapi.service.reranking.configuration.RerankingProvidersConfig; import io.stargate.sgv2.jsonapi.service.reranking.gateway.RerankingEGWClient; import jakarta.enterprise.context.ApplicationScoped; @@ -27,8 +28,8 @@ RerankingProvider create( requestProperties); } - private static final Map RERANKING_PROVIDER_CONSTRUCTOR_MAP = - Map.ofEntries(Map.entry("nvidia", NvidiaRerankingProvider::new)); + private static final Map RERANKING_PROVIDER_CTORS = + Map.ofEntries(Map.entry(ModelProvider.NVIDIA, NvidiaRerankingProvider::new)); public RerankingProvider getConfiguration( Optional tenant, @@ -37,26 +38,34 @@ public RerankingProvider getConfiguration( String modelName, Map authentication, String 
commandName) { - return addService(tenant, authToken, serviceName, modelName, authentication, commandName); + + var modelProvider = + ModelProvider.fromApiName(serviceName) + .orElseThrow( + () -> + new IllegalArgumentException( + String.format("Unknown reranking service provider '%s'", serviceName))); + return addService(tenant, authToken, modelProvider, modelName, authentication, commandName); } private synchronized RerankingProvider addService( Optional tenant, Optional authToken, - String serviceName, + ModelProvider modelProvider, String modelName, Map authentication, String commandName) { - final RerankingProvidersConfig.RerankingProviderConfig configuration = - rerankingConfig.providers().get(serviceName); - RerankingProviderFactory.ProviderConstructor ctor = - RERANKING_PROVIDER_CONSTRUCTOR_MAP.get(serviceName); + + var rerankingProvderConfig = rerankingConfig.providers().get(modelProvider.apiName()); + + RerankingProviderFactory.ProviderConstructor ctor = RERANKING_PROVIDER_CTORS.get(modelProvider); if (ctor == null) { throw ErrorCodeV1.RERANKING_SERVICE_TYPE_UNAVAILABLE.toApiException( - "unknown service provider '%s'", serviceName); + "unknown service provider '%s'", modelProvider.apiName()); } + var modelConfig = - configuration.models().stream() + rerankingProvderConfig.models().stream() .filter(model -> model.name().equals(modelName)) .findFirst() .orElseThrow( @@ -69,7 +78,7 @@ private synchronized RerankingProvider addService( return new RerankingEGWClient( modelConfig.url(), modelConfig.properties(), - serviceName, + modelProvider, tenant, authToken, modelName, diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/VectorizeConfigValidator.java b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/VectorizeConfigValidator.java index 06776f0f2c..80534c9a2a 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/VectorizeConfigValidator.java +++ 
b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/VectorizeConfigValidator.java @@ -5,7 +5,7 @@ import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.exception.JsonApiException; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import jakarta.enterprise.context.ApplicationScoped; import jakarta.inject.Inject; import java.util.ArrayList; @@ -51,7 +51,7 @@ public VectorizeConfigValidator( */ public Integer validateService(VectorizeConfig userConfig, Integer userVectorDimension) { // Only for internal tests - if (userConfig.provider().equals(ProviderConstants.CUSTOM)) { + if (userConfig.provider().equals(ModelProvider.CUSTOM.apiName())) { return userVectorDimension; } // Check if the service provider exists and is enabled @@ -322,10 +322,10 @@ private Integer validateModelAndDimension( // Find the model configuration by matching the model name // 1. 
huggingfaceDedicated does not require model, but requires dimension - if (userConfig.provider().equals(ProviderConstants.HUGGINGFACE_DEDICATED)) { + if (userConfig.provider().equals(ModelProvider.HUGGINGFACE_DEDICATED.apiName())) { if (userVectorDimension == null) { throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( - "'dimension' is needed for provider %s", ProviderConstants.HUGGINGFACE_DEDICATED); + "'dimension' is needed for provider %s", ModelProvider.HUGGINGFACE_DEDICATED.apiName()); } } diff --git a/src/main/proto/embedding_gateway.proto b/src/main/proto/embedding_gateway.proto index 6d47a4281a..a114a5f319 100644 --- a/src/main/proto/embedding_gateway.proto +++ b/src/main/proto/embedding_gateway.proto @@ -5,28 +5,28 @@ option java_package = "io.stargate.embedding.gateway"; package stargate; // The request message that is sent to embedding gateway gRPC API message ProviderEmbedRequest { - ProviderContext provider_context = 1; - EmbeddingRequest embedding_request = 2; + ProviderContext provider_context = 1; + EmbeddingRequest embedding_request = 2; // The provider context message for the embedding gateway gRPC API message ProviderContext { - string provider_name = 1; - string tenant_id = 2; - map auth_tokens = 3; + string provider_name = 1; + string tenant_id = 2; + map auth_tokens = 3; } // The request message for the embedding gateway gRPC API message EmbeddingRequest { // The model name for the embedding request - string model_name = 1; + string model_name = 1; // The dimensions of the embedding, some providers supports multiple dimensions - optional int32 dimensions = 2; + optional int32 dimensions = 2; // The parameter value, used when provided needs user specified parameters - map parameters = 3; + map parameters = 3; // The input type for the embedding request - InputType input_type = 4; + InputType input_type = 4; // The input data that needs to be vectorized - repeated string inputs = 5; + repeated string inputs = 5; // The command 
contains vectorize string command_name = 6; @@ -34,10 +34,10 @@ message ProviderEmbedRequest { // The parameter value, used when provided needs user specified parameters message ParameterValue { oneof ParameterValueOneOf { - string str_value = 1; - int32 int_value = 2; - float float_value = 3; - bool bool_value = 4; + string str_value = 1; + int32 int_value = 2; + float float_value = 3; + bool bool_value = 4; } } @@ -53,16 +53,16 @@ message ProviderEmbedRequest { // The response message for the embedding gateway gRPC API if successful message EmbeddingResponse { - ModelUsage modelUsage = 1; - repeated FloatEmbedding embeddings = 2; + ModelUsage modelUsage = 1; + repeated FloatEmbedding embeddings = 2; ErrorResponse error = 3; // The embedding response message message FloatEmbedding { // The index of the embedding corresponding to the input - int32 index = 1; + int32 index = 1; // The embedding values - repeated float embedding = 2; + repeated float embedding = 2; } // The error response message for the embedding gateway gRPC API @@ -78,7 +78,7 @@ message GetSupportedProvidersRequest {} // The response message for the get supported providers gRPC API if successful message GetSupportedProvidersResponse { - map supportedProviders = 1; + map supportedProviders = 1; ErrorResponse error = 2; // ProviderConfig message represents configuration for an embedding provider. 
@@ -193,20 +193,20 @@ service EmbeddingService { // The reranking request message that is sent to embedding gateway gRPC API message ProviderRerankingRequest { - ProviderContext provider_context = 1; - RerankingRequest Reranking_request = 2; + ProviderContext provider_context = 1; + RerankingRequest Reranking_request = 2; message ProviderContext { - string provider_name = 1; - string tenant_id = 2; - map auth_tokens = 3; + string provider_name = 1; + string tenant_id = 2; + map auth_tokens = 3; } message RerankingRequest { // The model name for the reranking request - string model_name = 1; + string model_name = 1; // The query text for the reranking request - string query = 2; + string query = 2; // The passages texts for the reranking request - repeated string passages = 3; + repeated string passages = 3; // The command contains reranking string command_name = 4; } @@ -215,16 +215,16 @@ message ProviderRerankingRequest { // The reranking response message for the embedding gateway gRPC API if successful message RerankingResponse { - ModelUsage modelUsage = 1; - repeated Rank ranks = 2; - ErrorResponse error = 3; + ModelUsage modelUsage = 1; + repeated Rank ranks = 2; + ErrorResponse error = 3; // Reranking result for each passage message Rank { // The rank index of the passage - int32 index = 1; + int32 index = 1; // The rank score value of the passage - float score = 2; + float score = 2; } message ErrorResponse { @@ -238,7 +238,7 @@ message GetSupportedRerankingProvidersRequest {} // The response message for the get supported reranking providers gRPC API if successful message GetSupportedRerankingProvidersResponse { - map supportedProviders = 1; + map supportedProviders = 1; ErrorResponse error = 2; // ProviderConfig message represents configuration for an reranking provider. 
@@ -296,20 +296,34 @@ service RerankingService { rpc GetSupportedRerankingProviders (GetSupportedRerankingProvidersRequest) returns (GetSupportedRerankingProvidersResponse){} } - - -// Common messages definition shared by both embedding and reranking - -// The usage statistics for the embedding gateway gRPC API on successful response from the provider +// Common structure for all model usage tracking, is included in response messages message ModelUsage { - string provider_type = 1; - string provider_name = 2; - string model_name = 3; - string tenant_id = 4; - int32 prompt_tokens = 5; - int32 total_tokens = 6; - int32 request_bytes = 7; - int32 response_bytes = 8; - int32 call_duration_us = 9; + string model_provider = 1; + ModelType model_type = 2; + string model_name = 3; + string tenant_id = 4; + InputType input_type = 5; + // tokens sent in the request + int32 prompt_tokens = 6; + // total tokens the request will be billed for + int32 total_tokens = 7; + // number of bytes in the outgoing http request sent to the provider + int32 request_bytes = 8; + // number of bytes in the response received from the provider + int32 response_bytes = 9; + int64 call_duration_nanos = 10; + + // If the model usage was for indexing data or searching data + enum InputType { + INPUT_TYPE_UNSPECIFIED = 0; + INDEX = 1; + SEARCH = 2; + } + + enum ModelType { + MODEL_TYPE_UNSPECIFIED = 0; + EMBEDDING = 1; + RERANKING = 2; + } } diff --git a/src/main/resources/embedding-providers-config.yaml b/src/main/resources/embedding-providers-config.yaml index 10e2c366ec..22da96cd29 100644 --- a/src/main/resources/embedding-providers-config.yaml +++ b/src/main/resources/embedding-providers-config.yaml @@ -272,7 +272,6 @@ stargate: required: false default-value: true help: "If set to false, text that exceeds the token limit causes the request to fail. The default value is true." 
- # OUT OF SCOPE FOR INITIAL PREVIEW cohere: # see https://docs.cohere.com/reference/embed display-name: Cohere diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java index 443e75f412..4a3fc9f05f 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java @@ -20,6 +20,7 @@ import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorConfig; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorizeDefinition; import io.stargate.sgv2.jsonapi.service.embedding.DataVectorizer; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; import io.stargate.sgv2.jsonapi.service.schema.EmbeddingSourceModel; import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction; import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionLexicalConfig; @@ -42,7 +43,7 @@ public class DataVectorizerTest { private TestEmbeddingProvider testEmbeddingProvider = new TestEmbeddingProvider(); private final EmbeddingProvider testService = testEmbeddingProvider; private final EmbeddingCredentials embeddingCredentials = - new EmbeddingCredentials(Optional.empty(), Optional.empty(), Optional.empty()); + new EmbeddingCredentials("test-tenant", Optional.empty(), Optional.empty(), Optional.empty()); private CollectionSchemaObject collectionSettings = null; @@ -196,7 +197,7 @@ public void testWithUnmatchedVectorsNumber() { TestEmbeddingProvider testProvider = new TestEmbeddingProvider() { @Override - public Uni vectorize( + public Uni vectorize( int batchId, List texts, EmbeddingCredentials embeddingCredentials, @@ -205,7 +206,18 @@ public Uni vectorize( texts.forEach(t -> customResponse.add(new float[] {0.5f, 0.5f, 0.5f})); // add additional vector customResponse.add(new float[] {0.5f, 
0.5f, 0.5f}); - return Uni.createFrom().item(Response.of(batchId, customResponse)); + + var modelUsage = + createModelUsage( + embeddingCredentials.tenantId(), + ModelInputType.fromEmbeddingRequestType(embeddingRequestType), + 0, + 0, + 0, + 0, + 0); + return Uni.createFrom() + .item(new BatchedEmbeddingResponse(batchId, customResponse, modelUsage)); } }; List documents = new ArrayList<>(); diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingGatewayClientTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingGatewayClientTest.java index 6b06059b2b..f2175d2b5a 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingGatewayClientTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingGatewayClientTest.java @@ -16,6 +16,7 @@ import io.stargate.sgv2.jsonapi.exception.JsonApiException; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.gateway.EmbeddingGatewayClient; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import io.stargate.sgv2.jsonapi.testresource.NoGlobalResourcesTestProfile; import java.util.Arrays; import java.util.List; @@ -30,7 +31,7 @@ public class EmbeddingGatewayClientTest { public static final String TESTING_COMMAND_NAME = "test_command"; private final EmbeddingCredentials embeddingCredentials = - new EmbeddingCredentials(Optional.empty(), Optional.empty(), Optional.empty()); + new EmbeddingCredentials("test-tenant", Optional.empty(), Optional.empty(), Optional.empty()); // for [data-api#1088] (NPE for VoyageAI provider) @Test @@ -77,7 +78,7 @@ void handleValidResponse() { new EmbeddingGatewayClient( EmbeddingProviderConfigStore.RequestProperties.of( 5, 5, 5, 5, 0.5, Optional.empty(), Optional.empty(), 2048), - "openai", + ModelProvider.OPENAI, 1536, Optional.of("default"), Optional.of("default"), @@ -88,7 
+89,7 @@ void handleValidResponse() { Map.of(), TESTING_COMMAND_NAME); - final EmbeddingProvider.Response response = + final EmbeddingProvider.BatchedEmbeddingResponse response = embeddingGatewayClient .vectorize( 1, @@ -127,7 +128,7 @@ void handleError() { new EmbeddingGatewayClient( EmbeddingProviderConfigStore.RequestProperties.of( 5, 5, 5, 5, 0.5, Optional.empty(), Optional.empty(), 2048), - "openai", + ModelProvider.OPENAI, 1536, Optional.of("default"), Optional.of("default"), diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderErrorMessageTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderErrorMessageTest.java index c8d746eeb9..8266981085 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderErrorMessageTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderErrorMessageTest.java @@ -23,7 +23,8 @@ public class EmbeddingProviderErrorMessageTest { private static final int DEFAULT_DIMENSIONS = 0; private final EmbeddingCredentials embeddingCredentials = - new EmbeddingCredentials(Optional.of("test"), Optional.empty(), Optional.empty()); + new EmbeddingCredentials( + "test-tenant", Optional.of("test"), Optional.empty(), Optional.empty()); @Inject EmbeddingProvidersConfig config; @@ -139,7 +140,7 @@ public void testRetryError() throws Exception { @Test public void testCorrectHeaderAndBody() { - final EmbeddingProvider.Response result = + final EmbeddingProvider.BatchedEmbeddingResponse result = new NvidiaEmbeddingProvider( EmbeddingProviderConfigStore.RequestProperties.of( 2, 100, 3000, 100, 0.5, Optional.empty(), Optional.empty(), 10), @@ -247,7 +248,7 @@ public void testNoJsonResponse() { @Test public void testEmptyJsonResponse() { - final EmbeddingProvider.Response result = + final EmbeddingProvider.BatchedEmbeddingResponse result = new NvidiaEmbeddingProvider( 
EmbeddingProviderConfigStore.RequestProperties.of( 2, 100, 3000, 100, 0.5, Optional.empty(), Optional.empty(), 10), diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAiEmbeddingClientTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAiEmbeddingClientTest.java index 457a1170a9..ee27d1507a 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAiEmbeddingClientTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAiEmbeddingClientTest.java @@ -25,13 +25,14 @@ public class OpenAiEmbeddingClientTest { @Inject EmbeddingProvidersConfig config; private final EmbeddingCredentials embeddingCredentials = - new EmbeddingCredentials(Optional.of("test"), Optional.empty(), Optional.empty()); + new EmbeddingCredentials( + "test-tenant", Optional.of("test"), Optional.empty(), Optional.empty()); @Nested class OpenAiEmbeddingTest { @Test public void happyPath() throws Exception { - final EmbeddingProvider.Response response = + final EmbeddingProvider.BatchedEmbeddingResponse response = new OpenAIEmbeddingProvider( EmbeddingProviderConfigStore.RequestProperties.of( 2, 100, 3000, 100, 0.5, Optional.empty(), Optional.empty(), 10), @@ -49,7 +50,7 @@ public void happyPath() throws Exception { .awaitItem() .getItem(); assertThat(response) - .isInstanceOf(EmbeddingProvider.Response.class) + .isInstanceOf(EmbeddingProvider.BatchedEmbeddingResponse.class) .satisfies( r -> { assertThat(r.embeddings()).isNotNull(); @@ -60,7 +61,7 @@ public void happyPath() throws Exception { @Test public void onlyToken() throws Exception { - final EmbeddingProvider.Response response = + final EmbeddingProvider.BatchedEmbeddingResponse response = new OpenAIEmbeddingProvider( EmbeddingProviderConfigStore.RequestProperties.of( 2, 100, 3000, 100, 0.5, Optional.empty(), Optional.empty(), 10), @@ -78,7 +79,7 @@ public void onlyToken() throws Exception { .awaitItem() .getItem(); 
assertThat(response) - .isInstanceOf(EmbeddingProvider.Response.class) + .isInstanceOf(EmbeddingProvider.BatchedEmbeddingResponse.class) .satisfies( r -> { assertThat(r.embeddings()).isNotNull(); diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/TestEmbeddingProvider.java b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/TestEmbeddingProvider.java index 4a56cb5587..5682e46469 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/TestEmbeddingProvider.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/TestEmbeddingProvider.java @@ -8,6 +8,9 @@ import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorColumnDefinition; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorConfig; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorizeDefinition; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import io.stargate.sgv2.jsonapi.service.schema.EmbeddingSourceModel; import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction; import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionLexicalConfig; @@ -16,16 +19,29 @@ import io.stargate.sgv2.jsonapi.service.schema.collections.IdConfig; import java.util.ArrayList; import java.util.List; +import java.util.Map; +import java.util.Optional; public class TestEmbeddingProvider extends EmbeddingProvider { - private TestConstants testConstants = new TestConstants(); + private final TestConstants TEST_CONSTANTS = new TestConstants(); + + public TestEmbeddingProvider() { + super( + ModelProvider.CUSTOM, + new EmbeddingProviderConfigStore.RequestProperties( + 3, 5, 5000, 5, 0.5, Optional.empty(), Optional.empty(), 100), + "http://mock.com", + "mockModel", + 1024, + Map.of()); + } public CommandContext 
commandContextWithVectorize() { - return testConstants.collectionContext( + return TEST_CONSTANTS.collectionContext( "testCommand", new CollectionSchemaObject( - testConstants.SCHEMA_OBJECT_NAME, + TEST_CONSTANTS.SCHEMA_OBJECT_NAME, null, IdConfig.defaultIdConfig(), VectorConfig.fromColumnDefinitions( @@ -44,7 +60,13 @@ public CommandContext commandContextWithVectorize() { } @Override - public Uni vectorize( + protected String errorMessageJsonPtr() { + // not used in tests + return ""; + } + + @Override + public Uni vectorize( int batchId, List texts, EmbeddingCredentials embeddingCredentials, @@ -55,7 +77,17 @@ public Uni vectorize( if (t.equals("return 1s")) response.add(new float[] {1.0f, 1.0f, 1.0f}); else response.add(new float[] {0.25f, 0.25f, 0.25f}); }); - return Uni.createFrom().item(Response.of(batchId, response)); + + var modelUsage = + createModelUsage( + embeddingCredentials.tenantId(), + ModelInputType.fromEmbeddingRequestType(embeddingRequestType), + 0, + 0, + 0, + 0, + 0); + return Uni.createFrom().item(new BatchedEmbeddingResponse(batchId, response, modelUsage)); } @Override diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/RerankingGatewayClientTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/gateway/RerankingGatewayClientTest.java similarity index 82% rename from src/test/java/io/stargate/sgv2/jsonapi/service/reranking/RerankingGatewayClientTest.java rename to src/test/java/io/stargate/sgv2/jsonapi/service/reranking/gateway/RerankingGatewayClientTest.java index e95d618ba7..90b06c9009 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/RerankingGatewayClientTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/gateway/RerankingGatewayClientTest.java @@ -1,4 +1,4 @@ -package io.stargate.sgv2.jsonapi.service.reranking; +package io.stargate.sgv2.jsonapi.service.reranking.gateway; import static org.assertj.core.api.Assertions.assertThat; import static 
org.mockito.ArgumentMatchers.any; @@ -14,9 +14,8 @@ import io.stargate.sgv2.jsonapi.api.request.RerankingCredentials; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.exception.JsonApiException; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; -import io.stargate.sgv2.jsonapi.service.provider.ProviderType; -import io.stargate.sgv2.jsonapi.service.reranking.gateway.RerankingEGWClient; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import io.stargate.sgv2.jsonapi.service.provider.ModelType; import io.stargate.sgv2.jsonapi.service.reranking.operation.RerankingProvider; import io.stargate.sgv2.jsonapi.testresource.NoGlobalResourcesTestProfile; import java.util.List; @@ -36,7 +35,7 @@ public class RerankingGatewayClientTest { public static final String TESTING_COMMAND_NAME = "test_command"; private static final RerankingCredentials RERANK_CREDENTIALS = - new RerankingCredentials(Optional.of("mocked reranking api key")); + new RerankingCredentials("test-tenant", Optional.of("mocked reranking api key")); @Test void handleValidResponse() { @@ -60,8 +59,8 @@ void handleValidResponse() { // mock model usage builder.setModelUsage( EmbeddingGateway.ModelUsage.newBuilder() - .setProviderType(ProviderType.RERANKING_PROVIDER.name()) - .setProviderName(ProviderConstants.NVIDIA) + .setModelType(EmbeddingGateway.ModelUsage.ModelType.RERANKING) + .setModelProvider(ModelProvider.NVIDIA.apiName()) .setModelName("llama-3.2-nv-rerankqa-1b-v2") .setPromptTokens(10) .setTotalTokens(20) @@ -75,7 +74,7 @@ void handleValidResponse() { new RerankingEGWClient( "https://xxx", null, - "xxx", + ModelProvider.NVIDIA, Optional.of("default"), Optional.of("default"), "xxx", @@ -83,7 +82,7 @@ void handleValidResponse() { Map.of(), TESTING_COMMAND_NAME); - final RerankingProvider.RerankingBatchResponse response = + final RerankingProvider.BatchedRerankingResponse response = rerankEGWClient .rerank(1, "apple", 
List.of("orange", "apple"), RERANK_CREDENTIALS) .subscribe() @@ -101,13 +100,13 @@ void handleValidResponse() { assertThat(response.ranks().get(1).score()).isEqualTo(0.1f); assertThat(response.modelUsage()).isNotNull(); - assertThat(response.modelUsage().providerType).isEqualTo(ProviderType.RERANKING_PROVIDER); - assertThat(response.modelUsage().provider).isEqualTo(ProviderConstants.NVIDIA); - assertThat(response.modelUsage().model).isEqualTo("llama-3.2-nv-rerankqa-1b-v2"); - assertThat(response.modelUsage().getPromptTokens()).isEqualTo(10); - assertThat(response.modelUsage().getTotalTokens()).isEqualTo(20); - assertThat(response.modelUsage().getRequestBytes()).isEqualTo(100); - assertThat(response.modelUsage().getResponseBytes()).isEqualTo(200); + assertThat(response.modelUsage().modelType()).isEqualTo(ModelType.RERANKING); + assertThat(response.modelUsage().modelProvider()).isEqualTo(ModelProvider.NVIDIA); + assertThat(response.modelUsage().modelName()).isEqualTo("llama-3.2-nv-rerankqa-1b-v2"); + assertThat(response.modelUsage().promptTokens()).isEqualTo(10); + assertThat(response.modelUsage().totalTokens()).isEqualTo(20); + assertThat(response.modelUsage().requestBytes()).isEqualTo(100); + assertThat(response.modelUsage().responseBytes()).isEqualTo(200); } @Test @@ -130,7 +129,7 @@ void handleError() { new RerankingEGWClient( "https://xxx", null, - "xxx", + ModelProvider.NVIDIA, Optional.of("default"), Optional.of("default"), "xxx", diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/NvidiaRerankingClientTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/operation/NvidiaRerankingClientTest.java similarity index 71% rename from src/test/java/io/stargate/sgv2/jsonapi/service/reranking/NvidiaRerankingClientTest.java rename to src/test/java/io/stargate/sgv2/jsonapi/service/reranking/operation/NvidiaRerankingClientTest.java index d61b43f787..e9dbba13fc 100644 --- 
a/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/NvidiaRerankingClientTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/operation/NvidiaRerankingClientTest.java @@ -1,4 +1,4 @@ -package io.stargate.sgv2.jsonapi.service.reranking; +package io.stargate.sgv2.jsonapi.service.reranking.operation; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.*; @@ -11,9 +11,6 @@ import io.smallrye.mutiny.helpers.test.UniAssertSubscriber; import io.stargate.sgv2.jsonapi.api.request.RerankingCredentials; import io.stargate.sgv2.jsonapi.service.provider.ModelUsage; -import io.stargate.sgv2.jsonapi.service.provider.ProviderType; -import io.stargate.sgv2.jsonapi.service.reranking.operation.NvidiaRerankingProvider; -import io.stargate.sgv2.jsonapi.service.reranking.operation.RerankingProvider; import io.stargate.sgv2.jsonapi.testresource.NoGlobalResourcesTestProfile; import java.util.List; import java.util.Optional; @@ -28,8 +25,8 @@ @TestProfile(NoGlobalResourcesTestProfile.Impl.class) public class NvidiaRerankingClientTest { - private static final RerankingCredentials RERANK_CREDENTIALS = - new RerankingCredentials(Optional.of("mocked data api token")); + private final RerankingCredentials RERANK_CREDENTIALS = + new RerankingCredentials("test-tenant", Optional.of("mocked data api token")); @Test void handleValidResponse() { @@ -42,17 +39,10 @@ void handleValidResponse() { .mapToObj(i -> new RerankingProvider.Rank(i, i == 0 ? 
0.1f : 1f)) .toList(); return Uni.createFrom() - .item( - new RerankingProvider.RerankingBatchResponse( - 1, - ranks, - new ModelUsage( - ProviderType.RERANKING_PROVIDER, - "nvidia", - "llama-3.2-nv-rerankqa-1b-v2"))); + .item(new RerankingProvider.BatchedRerankingResponse(1, ranks, ModelUsage.EMPTY)); }); - final RerankingProvider.RerankingBatchResponse response = + final RerankingProvider.BatchedRerankingResponse response = nvidiaRerankingProvider .rerank(1, "apple", List.of("orange", "apple"), RERANK_CREDENTIALS) .subscribe() diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/RerankingProviderTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProviderTest.java similarity index 93% rename from src/test/java/io/stargate/sgv2/jsonapi/service/reranking/RerankingProviderTest.java rename to src/test/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProviderTest.java index e1e4a65573..776cb4dc87 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/RerankingProviderTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProviderTest.java @@ -1,4 +1,4 @@ -package io.stargate.sgv2.jsonapi.service.reranking; +package io.stargate.sgv2.jsonapi.service.reranking.operation; import static org.assertj.core.api.Assertions.assertThat; @@ -6,7 +6,6 @@ import io.quarkus.test.junit.TestProfile; import io.smallrye.mutiny.helpers.test.UniAssertSubscriber; import io.stargate.sgv2.jsonapi.api.request.RerankingCredentials; -import io.stargate.sgv2.jsonapi.service.reranking.operation.RerankingProvider; import io.stargate.sgv2.jsonapi.testresource.NoGlobalResourcesTestProfile; import java.util.List; import java.util.Optional; @@ -18,7 +17,7 @@ public class RerankingProviderTest { private static final RerankingCredentials RERANK_CREDENTIALS = - new RerankingCredentials(Optional.of("mocked reranking api key")); + new RerankingCredentials("test-tenant", 
Optional.of("mocked reranking api key")); @Test @SuppressWarnings("unchecked") diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/TestRerankingProvider.java b/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/operation/TestRerankingProvider.java similarity index 55% rename from src/test/java/io/stargate/sgv2/jsonapi/service/reranking/TestRerankingProvider.java rename to src/test/java/io/stargate/sgv2/jsonapi/service/reranking/operation/TestRerankingProvider.java index a9b456e74b..c465da0643 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/TestRerankingProvider.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/operation/TestRerankingProvider.java @@ -1,29 +1,28 @@ -package io.stargate.sgv2.jsonapi.service.reranking; +package io.stargate.sgv2.jsonapi.service.reranking.operation; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.RerankingCredentials; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import io.stargate.sgv2.jsonapi.service.provider.ModelUsage; -import io.stargate.sgv2.jsonapi.service.provider.ProviderType; -import io.stargate.sgv2.jsonapi.service.reranking.configuration.RerankingProvidersConfig; import io.stargate.sgv2.jsonapi.service.reranking.configuration.RerankingProvidersConfigImpl; -import io.stargate.sgv2.jsonapi.service.reranking.operation.RerankingProvider; import java.util.ArrayList; import java.util.List; /** Mock a test reranking provider that returns ranks based on query and passages */ public class TestRerankingProvider extends RerankingProvider { - protected TestRerankingProvider( - String baseUrl, - String modelName, - RerankingProvidersConfig.RerankingProviderConfig.ModelConfig.RequestProperties - requestProperties) { - super(baseUrl, modelName, requestProperties); - } + // TODO: XXX Remove if not needed + // protected TestRerankingProvider( + // 
String baseUrl, + // String modelName, + // RerankingProvidersConfig.RerankingProviderConfig.ModelConfig.RequestProperties + // requestProperties) { + // super(baseUrl, modelName, requestProperties); + // } protected TestRerankingProvider(int maxBatchSize) { super( + ModelProvider.CUSTOM, "mockUrl", "mockModel", new RerankingProvidersConfigImpl.RerankingProviderConfigImpl.ModelConfigImpl @@ -31,7 +30,13 @@ protected TestRerankingProvider(int maxBatchSize) { } @Override - public Uni rerank( + protected String errorMessageJsonPtr() { + // not used in tests + return ""; + } + + @Override + public Uni rerank( int batchId, String query, List passages, RerankingCredentials rerankCredentials) { List ranks = new ArrayList<>(passages.size()); for (int i = 0; i < passages.size(); i++) { @@ -40,14 +45,6 @@ public Uni rerank( ranks.add(new Rank(i, score)); } ranks.sort((o1, o2) -> Float.compare(o2.score(), o1.score())); // Descending order - return Uni.createFrom() - .item( - RerankingBatchResponse.of( - batchId, - ranks, - new ModelUsage( - ProviderType.RERANKING_PROVIDER, - ProviderConstants.NVIDIA, - "nvidia/llama-3.2-nv-rerankqa-1b-v2"))); + return Uni.createFrom().item(new BatchedRerankingResponse(batchId, ranks, ModelUsage.EMPTY)); } } From e1db4433d25c9f795034a19d9c5cbd68b5ac22dd Mon Sep 17 00:00:00 2001 From: Aaron Morton Date: Wed, 11 Jun 2025 13:31:04 +1200 Subject: [PATCH 04/22] WIP - basics working on laptop, checking regressions --- .../jsonapi/api/request/RequestContext.java | 6 ++--- .../operation/MeteredEmbeddingProvider.java | 24 ++++++++++++++----- .../operation/OpenAIEmbeddingProvider.java | 2 +- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/request/RequestContext.java b/src/main/java/io/stargate/sgv2/jsonapi/api/request/RequestContext.java index 659852fcd7..65fc8574f0 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/request/RequestContext.java +++ 
b/src/main/java/io/stargate/sgv2/jsonapi/api/request/RequestContext.java @@ -68,14 +68,14 @@ public RequestContext( HeaderBasedRerankingKeyResolver.resolveRerankingKey(routingContext); this.rerankingCredentials = rerankingApiKeyFromHeader - .map(apiKey -> new RerankingCredentials(this.tenantId.get(), Optional.of(apiKey))) + .map(apiKey -> new RerankingCredentials(this.tenantId.orElse(""), Optional.of(apiKey))) .orElse( this.cassandraToken .map( cassandraToken -> new RerankingCredentials( - this.tenantId.get(), Optional.of(cassandraToken))) - .orElse(new RerankingCredentials(this.tenantId.get(), Optional.empty()))); + this.tenantId.orElse(""), Optional.of(cassandraToken))) + .orElse(new RerankingCredentials(this.tenantId.orElse(""), Optional.empty()))); } private static String generateRequestId() { diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MeteredEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MeteredEmbeddingProvider.java index a720c7187e..ffb2fd1019 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MeteredEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MeteredEmbeddingProvider.java @@ -11,7 +11,11 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; + +import io.stargate.sgv2.jsonapi.util.recordable.PrettyPrintable; import org.apache.commons.lang3.tuple.Pair; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Provides a metered version of an {@link EmbeddingProvider}, adding metrics collection to the @@ -20,17 +24,20 @@ * input texts. 
*/ public class MeteredEmbeddingProvider extends EmbeddingProvider { + private static final Logger LOGGER = LoggerFactory.getLogger(MeteredEmbeddingProvider.class); + + private static final String UNKNOWN_TENANT_ID = "unknown"; + private final MeterRegistry meterRegistry; private final JsonApiMetricsConfig jsonApiMetricsConfig; - private final RequestContext dataApiRequestInfo; - private static final String UNKNOWN_VALUE = "unknown"; + private final RequestContext requestContext; private final EmbeddingProvider embeddingProvider; private final String commandName; public MeteredEmbeddingProvider( MeterRegistry meterRegistry, JsonApiMetricsConfig jsonApiMetricsConfig, - RequestContext dataApiRequestInfo, + RequestContext requestContext, EmbeddingProvider embeddingProvider, String commandName) { // aaron 9 June 2025 - we need to remove this "metered" design pattern, for now just pass the @@ -45,7 +52,7 @@ public MeteredEmbeddingProvider( this.meterRegistry = meterRegistry; this.jsonApiMetricsConfig = jsonApiMetricsConfig; - this.dataApiRequestInfo = dataApiRequestInfo; + this.requestContext = requestContext; this.embeddingProvider = embeddingProvider; this.commandName = commandName; } @@ -120,7 +127,12 @@ public Uni vectorize( // create the final ordered result result.addAll(vectorizedBatch.embeddings()); } - return new BatchedEmbeddingResponse(1, result, aggregatedModelUsage); + var embeddingResponse = new BatchedEmbeddingResponse(1, result, aggregatedModelUsage); + if (LOGGER.isTraceEnabled()){ + LOGGER.trace( + "Vectorize call completed, aggregatedModelUsage: {}", PrettyPrintable.print(aggregatedModelUsage)); + } + return embeddingResponse; }) .invoke( () -> @@ -144,7 +156,7 @@ public int maxBatchSize() { */ private Tags getCustomTags() { Tag commandTag = Tag.of(jsonApiMetricsConfig.command(), commandName); - Tag tenantTag = Tag.of("tenant", dataApiRequestInfo.getTenantId().orElse(UNKNOWN_VALUE)); + Tag tenantTag = Tag.of("tenant", 
requestContext.getTenantId().orElse(UNKNOWN_TENANT_ID)); Tag embeddingProviderTag = Tag.of( jsonApiMetricsConfig.embeddingProvider(), embeddingProvider.getClass().getSimpleName()); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAIEmbeddingProvider.java index b1ad0f0601..b234e58c3e 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAIEmbeddingProvider.java @@ -115,7 +115,7 @@ public Uni vectorize( var modelUsage = createModelUsage( embeddingCredentials.tenantId(), - ModelInputType.INPUT_TYPE_UNSPECIFIED, + ModelInputType.fromEmbeddingRequestType(embeddingRequestType), openAiResponse.usage().prompt_tokens(), openAiResponse.usage().total_tokens(), jakartaResponse, From 120993c203a29667fea69f256c3d444522ff5a54 Mon Sep 17 00:00:00 2001 From: Aaron Morton Date: Sat, 14 Jun 2025 10:29:38 +1200 Subject: [PATCH 05/22] tmp --- .github/workflows/continuous-integration.yaml | 5 +- CHANGELOG.md | 14 +- docker-compose/.env | 2 +- docker-compose/dse.yaml | 2 +- docker-compose/start_dse69.sh | 2 +- pom.xml | 14 +- src/main/docker/Dockerfile-profiling.jvm | 2 +- src/main/docker/Dockerfile.jvm | 2 +- .../sgv2/jsonapi/StargateJsonApi.java | 18 + .../api/model/command/CollectionCommand.java | 1 + .../jsonapi/api/model/command/Command.java | 3 + .../api/model/command/CommandContext.java | 32 +- .../api/model/command/CommandName.java | 2 + .../jsonapi/api/model/command/Filterable.java | 12 +- .../api/model/command/JsonDefinition.java | 23 +- .../jsonapi/api/model/command/Sortable.java | 20 +- .../api/model/command/VectorSortable.java | 8 +- .../CollectionFilterClauseBuilder.java | 37 + .../builders/CollectionSortClauseBuilder.java | 17 + .../command/builders/FilterClauseBuilder.java | 106 ++- .../SortClauseBuilder.java} | 
53 +- .../builders/TableFilterClauseBuilder.java | 5 + .../builders/TableSortClauseBuilder.java | 17 + ...{FilterSpec.java => FilterDefinition.java} | 7 +- .../command/clause/filter/SortDefinition.java | 94 ++ .../filter/ValueComparisonOperator.java | 9 +- .../clause/sort/FindAndRerankSort.java | 20 +- .../FindAndRerankSortClauseDeserializer.java | 31 +- .../model/command/clause/sort/SortClause.java | 17 +- .../command/impl/CountDocumentsCommand.java | 5 +- .../command/impl/CreateTextIndexCommand.java | 71 ++ .../model/command/impl/DeleteManyCommand.java | 6 +- .../model/command/impl/DeleteOneCommand.java | 10 +- .../command/impl/FindAndRerankCommand.java | 44 +- .../api/model/command/impl/FindCommand.java | 8 +- .../impl/FindEmbeddingProvidersCommand.java | 20 +- .../command/impl/FindOneAndDeleteCommand.java | 8 +- .../impl/FindOneAndReplaceCommand.java | 8 +- .../command/impl/FindOneAndUpdateCommand.java | 8 +- .../model/command/impl/FindOneCommand.java | 8 +- .../impl/FindRerankingProvidersCommand.java | 20 +- .../model/command/impl/UpdateManyCommand.java | 4 +- .../model/command/impl/UpdateOneCommand.java | 8 +- .../indexes/TextIndexDefinitionDesc.java | 42 + .../validation/FindOptionsValidation.java | 11 +- .../request/EmbeddingCredentialsSupplier.java | 90 ++ .../jsonapi/api/request/RequestContext.java | 54 +- .../request/RerankingCredentialsResolver.java | 9 - .../jsonapi/api/v1/CollectionResource.java | 62 +- .../sgv2/jsonapi/api/v1/GeneralResource.java | 16 +- .../sgv2/jsonapi/api/v1/KeyspaceResource.java | 15 +- .../api/v1/metrics/JsonApiMetricsConfig.java | 4 - .../jsonapi/api/v1/metrics/MetricsConfig.java | 17 - .../v1/metrics/MicrometerConfiguration.java | 55 -- .../config/CommandLevelLoggingConfig.java | 2 +- .../sgv2/jsonapi/config/DatabaseType.java | 37 + .../sgv2/jsonapi/config/OperationsConfig.java | 52 +- .../config/constants/DocumentConstants.java | 8 + .../config/constants/HttpConstants.java | 3 + .../config/constants/TableDescConstants.java | 
9 + .../config/constants/TableDescDefaults.java | 15 + .../sgv2/jsonapi/exception/ErrorCodeV1.java | 5 +- .../jsonapi/exception/SchemaException.java | 5 +- .../mappers/ThrowableToErrorMapper.java | 7 + .../sgv2/jsonapi/metrics/CommandFeature.java | 36 + .../sgv2/jsonapi/metrics/CommandFeatures.java | 113 +++ .../JsonProcessingMetricsReporter.java | 4 +- .../jsonapi/metrics/MetricsConstants.java | 27 + .../MetricsTenantDeactivationConsumer.java | 59 ++ .../metrics/MicrometerConfiguration.java | 144 +++ .../metrics/TenantRequestMetricsFilter.java | 6 +- .../TenantRequestMetricsTagProvider.java | 6 +- .../service/cqldriver/CQLSessionCache.java | 610 +++++++++---- .../service/cqldriver/CqlCredentials.java | 110 ++- .../cqldriver/CqlCredentialsFactory.java | 76 ++ .../cqldriver/CqlSessionCacheSupplier.java | 76 ++ .../service/cqldriver/CqlSessionFactory.java | 169 ++++ .../CustomTaggingMetricIdGenerator.java | 65 -- .../cqldriver/SchemaChangeListener.java | 63 -- .../TenantAwareCqlSessionBuilder.java | 15 +- .../executor/CollectionIndexUsage.java | 28 +- .../executor/CommandQueryExecutor.java | 10 +- .../cqldriver/executor/QueryExecutor.java | 11 +- .../cqldriver/executor/SchemaCache.java | 325 ++++++- .../cqldriver/executor/SchemaObjectName.java | 5 + ...eCache.java => TableBasedSchemaCache.java} | 10 +- .../cqldriver/executor/VectorConfig.java | 10 + .../override/DefaultSubConditionRelation.java | 170 ++++ .../cqldriver/override/ExtendedSelect.java | 90 ++ .../cqldriver/override/LogicalRelation.java | 36 + .../embedding/DataVectorizerService.java | 13 +- .../EmbeddingApiKeyResolverProvider.java | 27 - .../EmbeddingProviderConfigStore.java | 10 +- .../EmbeddingProvidersConfig.java | 18 +- .../EmbeddingProvidersConfigImpl.java | 9 + .../EmbeddingProvidersConfigProducer.java | 1 + ...ertyBasedEmbeddingProviderConfigStore.java | 1 + .../AwsBedrockEmbeddingProvider.java | 15 +- .../AzureOpenAIEmbeddingProvider.java | 25 +- .../operation/CohereEmbeddingProvider.java | 12 +- 
.../operation/EmbeddingProvider.java | 43 +- .../operation/EmbeddingProviderFactory.java | 55 +- ...HuggingFaceDedicatedEmbeddingProvider.java | 12 +- .../HuggingFaceEmbeddingProvider.java | 12 +- .../operation/JinaAIEmbeddingProvider.java | 14 +- .../operation/MeteredEmbeddingProvider.java | 13 +- .../operation/MistralEmbeddingProvider.java | 12 +- .../operation/NvidiaEmbeddingProvider.java | 12 +- .../operation/OpenAIEmbeddingProvider.java | 14 +- .../operation/UpstageAIEmbeddingProvider.java | 14 +- .../operation/VertexAIEmbeddingProvider.java | 18 +- .../operation/VoyageAIEmbeddingProvider.java | 19 +- .../test/CustomITEmbeddingProvider.java | 32 +- .../service/operation/GenericOperation.java | 1 + .../jsonapi/service/operation/ReadDBTask.java | 6 +- .../service/operation/ReadDBTaskPage.java | 5 +- .../builder/BuiltConditionPredicate.java | 3 +- .../collections/CollectionReadOperation.java | 2 +- .../CreateCollectionOperation.java | 150 ++- .../operation/embeddings/EmbeddingTask.java | 5 +- .../FindEmbeddingProvidersOperation.java | 173 ++-- .../collection/MatchCollectionFilter.java | 46 + .../filters/table/InTableFilter.java | 9 +- .../filters/table/NativeTypeTableFilter.java | 9 +- .../operation/query/DBLogicalExpression.java | 17 +- .../service/operation/query/TableFilter.java | 14 +- .../FindRerankingProvidersOperation.java | 95 +- .../IntermediateCollectionReadTask.java | 20 +- .../operation/reranking/RerankingMetrics.java | 233 +++++ .../operation/reranking/RerankingQuery.java | 2 +- .../operation/reranking/RerankingTask.java | 31 +- .../tables/CreateIndexDBTaskBuilder.java | 14 + .../operation/tables/TableProjection.java | 5 +- .../tables/TableSimilarityFunction.java | 7 +- .../operation/tables/TableWhereCQLClause.java | 122 ++- .../service/operation/tasks/DBTask.java | 1 + .../service/processor/CommandProcessor.java | 202 ++-- .../processor/HybridFieldExpander.java | 62 +- .../processor/MeteredCommandProcessor.java | 274 +++--- 
.../service/provider/ApiModelSupport.java | 43 + ...ddingAndRerankingConfigSourceProvider.java | 39 +- .../service/provider/ModelSupport.java | 33 - .../jsonapi/service/provider/ModelType.java | 2 +- .../service/provider/ProviderBase.java | 40 +- .../RerankingProviderConfigProducer.java | 19 +- .../RerankingProvidersConfig.java | 6 +- .../RerankingProvidersConfigImpl.java | 4 +- .../operation/RerankingProvider.java | 5 +- .../service/resolver/CommandResolver.java | 5 +- .../CreateCollectionCommandResolver.java | 25 +- .../CreateTextIndexCommandResolver.java | 101 ++ .../resolver/DeleteOneCommandResolver.java | 4 +- .../FindAndRerankOperationBuilder.java | 25 +- .../FindCollectionsCommandResolver.java | 9 +- .../service/resolver/FindCommandResolver.java | 2 +- ...FindEmbeddingProvidersCommandResolver.java | 2 +- .../FindOneAndDeleteCommandResolver.java | 6 +- .../FindOneAndReplaceCommandResolver.java | 6 +- .../FindOneAndUpdateCommandResolver.java | 6 +- .../resolver/FindOneCommandResolver.java | 6 +- .../resolver/ListTablesCommandResolver.java | 5 +- .../resolver/TableReadDBOperationBuilder.java | 10 +- .../resolver/UpdateOneCommandResolver.java | 4 +- .../resolver/VectorizeConfigValidator.java | 80 +- .../resolver/matcher/CaptureGroups.java | 2 +- .../matcher/CollectionFilterResolver.java | 39 +- .../sort/TableCqlSortClauseResolver.java | 2 +- .../sort/TableMemorySortClauseResolver.java | 6 +- .../collections/CollectionLexicalConfig.java | 126 ++- .../collections/CollectionRerankDef.java | 97 +- .../collections/CollectionSchemaObject.java | 22 +- .../collections/CollectionSettingsReader.java | 2 +- .../CollectionSettingsV0Reader.java | 30 +- .../CollectionSettingsV1Reader.java | 9 +- .../schema/tables/ApiColumnDefContainer.java | 5 - .../service/schema/tables/ApiIndexType.java | 24 +- .../service/schema/tables/ApiTextIndex.java | 207 +++++ .../service/schema/tables/ApiVectorIndex.java | 2 - .../schema/tables/IndexFactoryFromCql.java | 3 +- 
.../tables/IndexFactoryFromIndexDesc.java | 8 +- .../collections/DocumentShredder.java | 2 +- .../jsonapi/util/recordable/Recordable.java | 4 +- src/main/proto/embedding_gateway.proto | 17 +- src/main/resources/application.conf | 4 +- .../resources/embedding-providers-config.yaml | 5 +- src/main/resources/errors.yaml | 49 +- .../test-embedding-providers-config.yaml | 479 ++++++++++ .../test-reranking-providers-config.yaml | 4 +- .../stargate/sgv2/jsonapi/TestConstants.java | 17 +- .../ObjectMapperConfigurationTest.java | 21 +- .../builders/FilterClauseBuilderTest.java | 448 +++++---- .../SortClauseBuilderTest.java} | 76 +- ...ndAndRerankSortClauseDeserializerTest.java | 90 +- .../impl/FindOneAndDeleteCommandTest.java | 15 +- .../impl/FindOneAndReplaceCommandTest.java | 4 +- .../impl/HybridLimitsDeserializerTest.java | 26 +- .../command/impl/UpdateOneCommandTest.java | 24 +- .../EmbeddingCredentialsSupplierTest.java | 184 ++++ .../AbstractKeyspaceIntegrationTestBase.java | 45 +- .../v1/CollectionResourceIntegrationTest.java | 34 +- ...nBackwardCompatibilityIntegrationTest.java | 205 +++++ .../v1/CreateCollectionIntegrationTest.java | 678 +++----------- ...llectionTooManyIndexesIntegrationTest.java | 18 +- ...ollectionTooManyTablesIntegrationTest.java | 19 +- ...eCollectionWithLexicalIntegrationTest.java | 129 ++- ...ollectionWithRerankingIntegrationTest.java | 171 ++-- .../api/v1/CreateKeyspaceIntegrationTest.java | 109 +-- .../api/v1/DropKeyspaceIntegrationTest.java | 178 ++-- ...EstimatedDocumentCountIntegrationTest.java | 114 +-- ...indAndRerankCollectionIntegrationTest.java | 135 +-- ...CollectionWithLexicalIntegrationTest.java} | 177 +++- ...FindCollectionWithSortIntegrationTest.java | 392 +++----- .../v1/FindCollectionsIntegrationTest.java | 163 +--- ...FindEmbeddingProvidersIntegrationTest.java | 125 ++- .../jsonapi/api/v1/FindIntegrationTest.java | 28 +- .../api/v1/FindKeyspacesIntegrationTest.java | 24 +- .../v1/FindOneAndDeleteIntegrationTest.java | 227 
++--- .../v1/FindOneAndReplaceIntegrationTest.java | 470 +++------- .../v1/FindOneAndUpdateIntegrationTest.java | 11 +- ...indOneAndUpdateNoIndexIntegrationTest.java | 67 +- .../api/v1/FindOneIntegrationTest.java | 20 +- .../FindOneWithProjectionIntegrationTest.java | 121 +-- ...FindRerankingProvidersIntegrationTest.java | 132 +-- .../v1/GeneralResourceIntegrationTest.java | 23 +- .../api/v1/InAndNinIntegrationTest.java | 724 ++++++--------- .../api/v1/IndexingConfigIntegrationTest.java | 309 ++----- .../v1/InsertInCollectionIntegrationTest.java | 17 +- ...ertLexicalInCollectionIntegrationTest.java | 49 + .../v1/KeyspaceResourceIntegrationTest.java | 22 +- .../api/v1/LwtRetryIntegrationTest.java | 84 +- .../api/v1/PaginationIntegrationTest.java | 89 +- .../api/v1/RangeReadIntegrationTest.java | 330 ++++--- .../api/v1/UpdateManyIntegrationTest.java | 503 +++------- .../api/v1/VectorSearchIntegrationTest.java | 864 ++++++------------ .../v1/VectorizeSearchIntegrationTest.java | 436 +++------ .../AbstractTableIntegrationTestBase.java | 2 +- .../v1/tables/AlterTableIntegrationTest.java | 29 +- .../CreateTableIndexIntegrationTest.java | 191 +++- .../v1/tables/CreateTableIntegrationTest.java | 60 ++ .../tables/InsertOneTableIntegrationTest.java | 10 +- .../v1/tables/ListIndexesIntegrationTest.java | 378 +++++--- .../tables/LogicalFilterIntegrationTest.java | 166 ++++ .../api/v1/util/DataApiResponseValidator.java | 25 +- .../v1/util/DataApiTableCommandSender.java | 4 + .../sgv2/jsonapi/fixtures/TestTextUtil.java | 23 + .../tables/TableLogicalRelationTest.java | 199 ++++ .../testdata/LogicalExpressionTestData.java | 93 +- .../testdata/TableWhereCQLClauseTestData.java | 128 +++ .../jsonapi/fixtures/testdata/TestData.java | 4 + .../testdata/WhereAnalyzerTestData.java | 12 +- ...etricsTenantDeactivationConsumerTests.java | 172 ++++ .../metrics/MicrometerConfigurationTests.java | 258 ++++++ .../service/cql/builder/QueryBuilderTest.java | 17 +- 
.../cqldriver/CqlCredentialsFactoryTests.java | 196 ++++ .../CqlSessionCacheSupplierTests.java | 46 + .../cqldriver/CqlSessionCacheTest.java | 320 ------- .../cqldriver/CqlSessionCacheTests.java | 655 +++++++++++++ .../cqldriver/CqlSessionCacheTimingTest.java | 109 --- .../cqldriver/CqlSessionFactoryTests.java | 159 ++++ .../cqldriver/InvalidCqlCredentialsTest.java | 120 --- .../TenantAwareCqlSessionBuilderTest.java | 43 - .../executor/NamespaceCacheTest.java | 10 +- .../cqldriver/executor/SchemaCacheTests.java | 257 ++++++ .../operation/DataVectorizerTest.java | 3 +- .../operation/EmbeddingGatewayClientTest.java | 10 +- .../EmbeddingProviderErrorMessageTest.java | 40 +- .../operation/OpenAiEmbeddingClientTest.java | 32 +- .../operation/TestEmbeddingProvider.java | 46 +- .../CreateCollectionOperationTest.java | 20 +- .../FindCollectionOperationTest.java | 4 +- .../collections/OperationTestBase.java | 2 +- .../reranking/RerankingQueryTests.java | 5 +- .../tables/SelectWhereAnalyzerTest.java | 4 +- .../tables/WriteableTableRowBuilderTest.java | 4 +- .../processor/HybridFieldExpanderTest.java | 112 ++- .../jsonapi/testresource/DseTestResource.java | 54 +- .../testresource/StargateTestResource.java | 4 +- src/test/resources/application.yaml | 5 + src/test/resources/dse.yaml | 2 +- 279 files changed, 12085 insertions(+), 7234 deletions(-) create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/api/model/command/builders/CollectionSortClauseBuilder.java rename src/main/java/io/stargate/sgv2/jsonapi/api/model/command/{deserializers/SortClauseDeserializer.java => builders/SortClauseBuilder.java} (84%) create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/api/model/command/builders/TableSortClauseBuilder.java rename src/main/java/io/stargate/sgv2/jsonapi/api/model/command/clause/filter/{FilterSpec.java => FilterDefinition.java} (91%) create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/api/model/command/clause/filter/SortDefinition.java create mode 100644 
src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/CreateTextIndexCommand.java create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/api/model/command/table/definition/indexes/TextIndexDefinitionDesc.java create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/api/request/EmbeddingCredentialsSupplier.java delete mode 100644 src/main/java/io/stargate/sgv2/jsonapi/api/request/RerankingCredentialsResolver.java delete mode 100644 src/main/java/io/stargate/sgv2/jsonapi/api/v1/metrics/MicrometerConfiguration.java create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/config/DatabaseType.java create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/metrics/CommandFeature.java create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/metrics/CommandFeatures.java rename src/main/java/io/stargate/sgv2/jsonapi/{api/v1 => }/metrics/JsonProcessingMetricsReporter.java (94%) create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/metrics/MetricsConstants.java create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/metrics/MetricsTenantDeactivationConsumer.java create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/metrics/MicrometerConfiguration.java rename src/main/java/io/stargate/sgv2/jsonapi/{api/v1 => }/metrics/TenantRequestMetricsFilter.java (95%) rename src/main/java/io/stargate/sgv2/jsonapi/{api/v1 => }/metrics/TenantRequestMetricsTagProvider.java (92%) create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/CqlCredentialsFactory.java create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/CqlSessionCacheSupplier.java create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/CqlSessionFactory.java delete mode 100644 src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/CustomTaggingMetricIdGenerator.java delete mode 100644 src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/SchemaChangeListener.java rename 
src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/{NamespaceCache.java => TableBasedSchemaCache.java} (93%) create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/override/DefaultSubConditionRelation.java create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/override/ExtendedSelect.java create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/override/LogicalRelation.java delete mode 100644 src/main/java/io/stargate/sgv2/jsonapi/service/embedding/EmbeddingApiKeyResolverProvider.java create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/service/operation/filters/collection/MatchCollectionFilter.java create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/service/operation/reranking/RerankingMetrics.java create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/service/provider/ApiModelSupport.java delete mode 100644 src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelSupport.java create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CreateTextIndexCommandResolver.java create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/service/schema/tables/ApiTextIndex.java create mode 100644 src/main/resources/test-embedding-providers-config.yaml rename src/test/java/io/stargate/sgv2/jsonapi/api/model/command/{deserializers/SortClauseDeserializerTest.java => builders/SortClauseBuilderTest.java} (80%) create mode 100644 src/test/java/io/stargate/sgv2/jsonapi/api/request/EmbeddingCredentialsSupplierTest.java create mode 100644 src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionBackwardCompatibilityIntegrationTest.java rename src/test/java/io/stargate/sgv2/jsonapi/api/v1/{FindCollectionWithLexicalSortIntegrationTest.java => FindCollectionWithLexicalIntegrationTest.java} (70%) create mode 100644 src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/LogicalFilterIntegrationTest.java create mode 100644 
src/test/java/io/stargate/sgv2/jsonapi/fixtures/TestTextUtil.java create mode 100644 src/test/java/io/stargate/sgv2/jsonapi/fixtures/tables/TableLogicalRelationTest.java create mode 100644 src/test/java/io/stargate/sgv2/jsonapi/fixtures/testdata/TableWhereCQLClauseTestData.java create mode 100644 src/test/java/io/stargate/sgv2/jsonapi/metrics/MetricsTenantDeactivationConsumerTests.java create mode 100644 src/test/java/io/stargate/sgv2/jsonapi/metrics/MicrometerConfigurationTests.java create mode 100644 src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/CqlCredentialsFactoryTests.java create mode 100644 src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/CqlSessionCacheSupplierTests.java delete mode 100644 src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/CqlSessionCacheTest.java create mode 100644 src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/CqlSessionCacheTests.java delete mode 100644 src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/CqlSessionCacheTimingTest.java create mode 100644 src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/CqlSessionFactoryTests.java delete mode 100644 src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/InvalidCqlCredentialsTest.java delete mode 100644 src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/TenantAwareCqlSessionBuilderTest.java create mode 100644 src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/SchemaCacheTests.java diff --git a/.github/workflows/continuous-integration.yaml b/.github/workflows/continuous-integration.yaml index adf8f0e21c..9ce8ea3d12 100644 --- a/.github/workflows/continuous-integration.yaml +++ b/.github/workflows/continuous-integration.yaml @@ -176,6 +176,7 @@ jobs: # run the int tests - name: Integration Test - # -DDEFAULT_RERANKING_CONFIG_RESOURCE_OVERRIDE=test-reranking-providers-config.yaml is to override the reranking config to customized one + # -DRERANKING_CONFIG_RESOURCE=test-reranking-providers-config.yaml is to override the 
reranking config to customized one + # -DEMBEDDING_CONFIG_RESOURCE=test-embedding-providers-config.yaml is to override the embedding config to customized one run: | - ./mvnw -B -ntp clean verify -DskipUnitTests -DDEFAULT_RERANKING_CONFIG_RESOURCE_OVERRIDE=test-reranking-providers-config.yaml -Dquarkus.container-image.build=true -Dquarkus.container-image.tag=${{ github.sha }} -Drun-create-index-parallel=true ${{ matrix.profile }} \ No newline at end of file + ./mvnw -B -ntp clean verify -DskipUnitTests -DRERANKING_CONFIG_RESOURCE=test-reranking-providers-config.yaml -DEMBEDDING_CONFIG_RESOURCE=test-embedding-providers-config.yaml -Dquarkus.container-image.build=true -Dquarkus.container-image.tag=${{ github.sha }} -Drun-create-index-parallel=true ${{ matrix.profile }} \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 8f44dd5786..aee0ed8db3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,8 +1,8 @@ # Changelog -## [Unreleased](https://github.com/stargate/data-api/tree/HEAD) +## [v1.0.25](https://github.com/stargate/data-api/tree/v1.0.25) (2025-04-18) -[Full Changelog](https://github.com/stargate/data-api/compare/v1.0.24...HEAD) +[Full Changelog](https://github.com/stargate/data-api/compare/v1.0.24...v1.0.25) **Fixed bugs:** @@ -10,8 +10,18 @@ **Closed issues:** +- Flaky IT: `insertDifferentVectorizeProviders` - Error message varies with multiple providers [\#2016](https://github.com/stargate/data-api/issues/2016) +- Reranking and Lexical feature enabled in production caused backwards compatibility issue for createCollection. 
[\#2013](https://github.com/stargate/data-api/issues/2013) - \(Tables\) Error inserting empty association lists as a map [\#1998](https://github.com/stargate/data-api/issues/1998) +**Merged pull requests:** + +- Fix \#2016: Fix flaky integration test [\#2017](https://github.com/stargate/data-api/pull/2017) ([Hazel-Datastax](https://github.com/Hazel-Datastax)) +- Add IT for \#2013: Add integration tests to ensure backward compatibility on createCollection [\#2015](https://github.com/stargate/data-api/pull/2015) ([Hazel-Datastax](https://github.com/Hazel-Datastax)) +- Fixes \#2013: avoid fail for createCollection wrt legacy collections \(pre-lexical/pre-rerank\) [\#2014](https://github.com/stargate/data-api/pull/2014) ([tatu-at-datastax](https://github.com/tatu-at-datastax)) +- Add simple IT to check that stemming/stop-word config works for $lexical Collection [\#2009](https://github.com/stargate/data-api/pull/2009) ([tatu-at-datastax](https://github.com/tatu-at-datastax)) +- Bumping version for next data-api release [\#2007](https://github.com/stargate/data-api/pull/2007) ([github-actions[bot]](https://github.com/apps/github-actions)) + ## [v1.0.24](https://github.com/stargate/data-api/tree/v1.0.24) (2025-04-10) [Full Changelog](https://github.com/stargate/data-api/compare/v1.0.23...v1.0.24) diff --git a/docker-compose/.env b/docker-compose/.env index dafca8bf9e..1c574c6349 100644 --- a/docker-compose/.env +++ b/docker-compose/.env @@ -3,7 +3,7 @@ LOGLEVEL=INFO REQUESTLOG=false DATAAPITAG=v1 DATAAPIIMAGE=stargateio/jsonapi -DSETAG=6.9.8 +DSETAG=6.9.9 DSEIMAGE="cr.dtsx.io/datastax/dse-server" HCDTAG=1.2.1-early-preview #HCDIMAGE="cr.dtsx.io/datastax/hcd" diff --git a/docker-compose/dse.yaml b/docker-compose/dse.yaml index 29193dcdd0..96cc90a737 100644 --- a/docker-compose/dse.yaml +++ b/docker-compose/dse.yaml @@ -1,4 +1,4 @@ -# DSE Config Version: 6.9.8 +# DSE Config Version: 6.9.9 # Memory limit for DSE In-Memory tables as a fraction of system memory. 
When not set, # the default is 0.2 (20% of system memory). diff --git a/docker-compose/start_dse69.sh b/docker-compose/start_dse69.sh index f2384c9dd4..6796402e4e 100755 --- a/docker-compose/start_dse69.sh +++ b/docker-compose/start_dse69.sh @@ -13,7 +13,7 @@ LOGLEVEL=INFO DATAAPITAG="v1" DATAAPIIMAGE="stargateio/data-api" -DSETAG="6.9.8" +DSETAG="6.9.9" DSEIMAGE="cr.dtsx.io/datastax/dse-server" DSEONLY="false" DSENODES=1 diff --git a/pom.xml b/pom.xml index c3615e4ed0..3308e0a8d2 100644 --- a/pom.xml +++ b/pom.xml @@ -3,9 +3,9 @@ 4.0.0 io.stargate sgv2-jsonapi - 1.0.25-SNAPSHOT + 1.0.26-SNAPSHOT - 2.1.0-BETA-23 + 2.1.0-BETA-25 2.18.3 @@ -27,7 +27,7 @@ datastax/dse-server - 6.9.8 + 6.9.9 ${cassandra.version} dse-${stargate.int-test.cassandra.image-tag}-cluster true @@ -239,6 +239,12 @@ 5.0.0 + + com.github.javafaker + javafaker + 1.0.2 + test + io.quarkus quarkus-junit5 @@ -422,7 +428,7 @@ datastax/dse-server - 6.9.8 + 6.9.9 dse-${stargate.int-test.cassandra.image-tag}-cluster false true diff --git a/src/main/docker/Dockerfile-profiling.jvm b/src/main/docker/Dockerfile-profiling.jvm index 1ecd2046fb..4ba859afad 100644 --- a/src/main/docker/Dockerfile-profiling.jvm +++ b/src/main/docker/Dockerfile-profiling.jvm @@ -3,7 +3,7 @@ # based on less minimal UBI8 OpenJDK 17 image # ### -FROM registry.access.redhat.com/ubi8/openjdk-21:1.21 +FROM registry.access.redhat.com/ubi8/openjdk-21:1.21-2 ENV LANGUAGE='en_US:en' diff --git a/src/main/docker/Dockerfile.jvm b/src/main/docker/Dockerfile.jvm index d0ec377c9a..e00efb9b57 100644 --- a/src/main/docker/Dockerfile.jvm +++ b/src/main/docker/Dockerfile.jvm @@ -79,7 +79,7 @@ ### # see https://catalog.redhat.com/software/containers/ubi8/openjdk-21-runtime/653fd184292263c0a2f14d69?gs&q=openjdk%2021%20ubi -FROM registry.access.redhat.com/ubi8/openjdk-21-runtime:1.21 +FROM registry.access.redhat.com/ubi8/openjdk-21-runtime:1.21-2 ENV LANGUAGE='en_US:en' diff --git a/src/main/java/io/stargate/sgv2/jsonapi/StargateJsonApi.java 
b/src/main/java/io/stargate/sgv2/jsonapi/StargateJsonApi.java index 81005889a2..a7ee65cf4b 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/StargateJsonApi.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/StargateJsonApi.java @@ -902,6 +902,24 @@ } } """), + @ExampleObject( + name = "createTextIndex", + summary = "`createTextIndex` for text columns, in tables api", + value = + """ + { + "createTextIndex": { + "name": "lexical_idx", + "definition": { + "column": "keywords", + "analyzer": "english" + }, + "options" : { + "ifNotExists" : true + } + } + } + """), @ExampleObject( name = "createVectorIndex", summary = "`createVectorIndex` for vector columns, in tables api", diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/CollectionCommand.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/CollectionCommand.java index 0acc4c5003..449a2588fd 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/CollectionCommand.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/CollectionCommand.java @@ -24,6 +24,7 @@ // We have only collection resource that is used for API Tables @JsonSubTypes.Type(value = AlterTableCommand.class), @JsonSubTypes.Type(value = CreateIndexCommand.class), + @JsonSubTypes.Type(value = CreateTextIndexCommand.class), @JsonSubTypes.Type(value = CreateVectorIndexCommand.class), @JsonSubTypes.Type(value = ListIndexesCommand.class), }) diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/Command.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/Command.java index 642820f07d..05a206d205 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/Command.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/Command.java @@ -2,6 +2,7 @@ import com.fasterxml.jackson.annotation.JsonSubTypes; import com.fasterxml.jackson.annotation.JsonTypeInfo; +import io.stargate.sgv2.jsonapi.metrics.CommandFeatures; import 
io.stargate.sgv2.jsonapi.service.resolver.CommandResolver; /** @@ -35,4 +36,6 @@ public interface Command { * publicCommandName -> createCollection */ CommandName commandName(); + + default void addCommandFeatures(CommandFeatures commandFeatures) {} } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/CommandContext.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/CommandContext.java index 110a254662..556403a477 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/CommandContext.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/CommandContext.java @@ -1,14 +1,16 @@ package io.stargate.sgv2.jsonapi.api.model.command; import com.google.common.base.Preconditions; +import io.micrometer.core.instrument.MeterRegistry; import io.stargate.sgv2.jsonapi.api.model.command.impl.FindAndRerankCommand; import io.stargate.sgv2.jsonapi.api.model.command.tracing.DefaultRequestTracing; import io.stargate.sgv2.jsonapi.api.model.command.tracing.RequestTracing; import io.stargate.sgv2.jsonapi.api.request.RequestContext; -import io.stargate.sgv2.jsonapi.api.v1.metrics.JsonProcessingMetricsReporter; import io.stargate.sgv2.jsonapi.config.feature.ApiFeature; import io.stargate.sgv2.jsonapi.config.feature.ApiFeatures; import io.stargate.sgv2.jsonapi.config.feature.FeaturesConfig; +import io.stargate.sgv2.jsonapi.metrics.CommandFeatures; +import io.stargate.sgv2.jsonapi.metrics.JsonProcessingMetricsReporter; import io.stargate.sgv2.jsonapi.service.cqldriver.CQLSessionCache; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.*; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject; @@ -41,6 +43,7 @@ public class CommandContext { private final CommandConfig commandConfig; private final EmbeddingProviderFactory embeddingProviderFactory; private final RerankingProviderFactory rerankingProviderFactory; + private final MeterRegistry meterRegistry; // Request specific private final SchemaT 
schemaObject; @@ -53,6 +56,9 @@ public class CommandContext { // see accessors private FindAndRerankCommand.HybridLimits hybridLimits; + // used to track the features used in the command + private final CommandFeatures commandFeatures; + // created on demand or set via builder, otherwise we need to read from config too early when // running tests, See the {@link Builder#withApiFeatures} // access via {@link CommandContext#apiFeatures()} @@ -68,7 +74,8 @@ private CommandContext( CommandConfig commandConfig, ApiFeatures apiFeatures, EmbeddingProviderFactory embeddingProviderFactory, - RerankingProviderFactory rerankingProviderFactory) { + RerankingProviderFactory rerankingProviderFactory, + MeterRegistry meterRegistry) { this.schemaObject = schemaObject; this.embeddingProvider = embeddingProvider; @@ -82,6 +89,7 @@ private CommandContext( this.rerankingProviderFactory = rerankingProviderFactory; this.apiFeatures = apiFeatures; + this.meterRegistry = meterRegistry; var anyTracing = apiFeatures().isFeatureEnabled(ApiFeature.REQUEST_TRACING) @@ -94,6 +102,8 @@ private CommandContext( requestContext.getTenantId().orElse(""), apiFeatures().isFeatureEnabled(ApiFeature.REQUEST_TRACING_FULL)) : RequestTracing.NO_OP; + + this.commandFeatures = CommandFeatures.create(); } /** See doc comments for {@link CommandContext} */ @@ -156,6 +166,10 @@ public ApiFeatures apiFeatures() { return apiFeatures; } + public CommandFeatures commandFeatures() { + return commandFeatures; + } + public JsonProcessingMetricsReporter jsonProcessingMetricsReporter() { return jsonProcessingMetricsReporter; } @@ -172,6 +186,10 @@ public EmbeddingProviderFactory embeddingProviderFactory() { return embeddingProviderFactory; } + public MeterRegistry meterRegistry() { + return meterRegistry; + } + public boolean isCollectionContext() { return schemaObject().type() == CollectionSchemaObject.TYPE; } @@ -221,6 +239,7 @@ public static class BuilderSupplier { private CommandConfig commandConfig; private 
EmbeddingProviderFactory embeddingProviderFactory; private RerankingProviderFactory rerankingProviderFactory; + private MeterRegistry meterRegistry; BuilderSupplier() {} @@ -252,6 +271,11 @@ public BuilderSupplier withRerankingProviderFactory( return this; } + public BuilderSupplier withMeterRegistry(MeterRegistry meterRegistry) { + this.meterRegistry = meterRegistry; + return this; + } + public Builder getBuilder(SchemaT schemaObject) { Objects.requireNonNull( @@ -260,6 +284,7 @@ public Builder getBuilder(SchemaT schema Objects.requireNonNull(commandConfig, "commandConfig must not be null"); Objects.requireNonNull(embeddingProviderFactory, "embeddingProviderFactory must not be null"); Objects.requireNonNull(rerankingProviderFactory, "rerankingProviderFactory must not be null"); + Objects.requireNonNull(meterRegistry, "meterRegistry must not be null"); // SchemaObject is passed here so the generics gets locked here, makes call chaining easier Objects.requireNonNull(schemaObject, "schemaObject must not be null"); @@ -326,7 +351,8 @@ public CommandContext build() { commandConfig, apiFeatures, embeddingProviderFactory, - rerankingProviderFactory); + rerankingProviderFactory, + meterRegistry); } } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/CommandName.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/CommandName.java index a7134cb9b2..685682d1dc 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/CommandName.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/CommandName.java @@ -19,6 +19,7 @@ public enum CommandName { COUNT_DOCUMENTS(Names.COUNT_DOCUMENTS, CommandType.DML, CommandTarget.COLLECTION), CREATE_COLLECTION(Names.CREATE_COLLECTION, CommandType.DDL, CommandTarget.KEYSPACE), CREATE_INDEX(Names.CREATE_INDEX, CommandType.DDL, CommandTarget.TABLE), + CREATE_TEXT_INDEX(Names.CREATE_TEXT_INDEX, CommandType.DDL, CommandTarget.TABLE), CREATE_VECTOR_INDEX(Names.CREATE_VECTOR_INDEX, 
CommandType.DDL, CommandTarget.TABLE), CREATE_KEYSPACE(Names.CREATE_KEYSPACE, CommandType.DDL, CommandTarget.DATABASE), CREATE_NAMESPACE(Names.CREATE_NAMESPACE, CommandType.DDL, CommandTarget.DATABASE), @@ -107,6 +108,7 @@ public interface Names { String COUNT_DOCUMENTS = "countDocuments"; String CREATE_COLLECTION = "createCollection"; String CREATE_INDEX = "createIndex"; + String CREATE_TEXT_INDEX = "createTextIndex"; String CREATE_VECTOR_INDEX = "createVectorIndex"; String CREATE_KEYSPACE = "createKeyspace"; String CREATE_NAMESPACE = "createNamespace"; diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/Filterable.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/Filterable.java index e0c43ef343..c219007dfb 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/Filterable.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/Filterable.java @@ -1,17 +1,17 @@ package io.stargate.sgv2.jsonapi.api.model.command; import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterClause; -import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterSpec; +import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterDefinition; /* - * All the commands that needs FilterClause will have to implement this. + * All the commands that accept {@code FilterClause} will have to implement this interface. */ public interface Filterable { - /** Accessor for the filter specification in its intermediate for */ - FilterSpec filterSpec(); + /** Accessor for the filter definition in its intermediate JSON form */ + FilterDefinition filterDefinition(); default FilterClause filterClause(CommandContext ctx) { - FilterSpec spec = filterSpec(); - return (spec == null) ? FilterClause.empty() : spec.toFilterClause(ctx); + FilterDefinition def = filterDefinition(); + return (def == null) ? 
FilterClause.empty() : def.build(ctx); } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/JsonDefinition.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/JsonDefinition.java index f63f60a151..ee40ee64c8 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/JsonDefinition.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/JsonDefinition.java @@ -1,17 +1,20 @@ package io.stargate.sgv2.jsonapi.api.model.command; import com.fasterxml.jackson.databind.JsonNode; +import com.google.common.annotations.VisibleForTesting; import java.util.Objects; /** - * Value type enclosing JSON-decoded representation of something to resolve into fully definition, + * Value type enclosing JSON-encoded representation of something to build into full model object, * such as filter or sort clause. Needed to let Quarkus/Jackson decode textual JSON into * intermediate form ({@link JsonNode}) from which actual deserialization and validation can be * deferred until we have context we need * - * @see io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterSpec + * @see io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterDefinition + * @see io.stargate.sgv2.jsonapi.api.model.command.clause.filter.SortDefinition + * @param <T> Type of the value object to be deserialized from the JSON value */ -public abstract class JsonDefinition { +public abstract class JsonDefinition<T> { /** The wrapped JSON value */ private final JsonNode json; @@ -20,9 +23,19 @@ protected JsonDefinition(JsonNode json) { } /** - * @return JSON value that specifies the object + * Method for lazy deserialization of the JSON value into a {@link T} instance, usually a {@code + * Clause} of some type. + * + * @param ctx Context passed to the builder + * @return Fully processed {@link T} instance based on this definition.
*/ - protected JsonNode json() { + public abstract T build(CommandContext ctx); + + /** + * @return JSON value that contains the definition of value object to build + */ + @VisibleForTesting + public JsonNode json() { return json; } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/Sortable.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/Sortable.java index bd5657ee30..a949cd4663 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/Sortable.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/Sortable.java @@ -1,10 +1,26 @@ package io.stargate.sgv2.jsonapi.api.model.command; +import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.SortDefinition; import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause; /* - * All the commands that needs SortClause will have to implement this. + * All the commands that accept {@code SortClause} will have to implement this interface. + * Will delegate most of the work to {@link SortDefinition} which in turn will delegate + * to {@link SortClauseBuilder}. */ public interface Sortable { - SortClause sortClause(); + /** Accessor for the Sort definition in its intermediate JSON form */ + SortDefinition sortDefinition(); + + /** + * Convenience accessor for fully processed SortClause: will convert the intermediate JSON value + * to a {@link SortClause} instance, including all validation. + * + * @param ctx Processing context for the command; used mostly to get the schema object for the + * builder + */ + default SortClause sortClause(CommandContext ctx) { + SortDefinition def = sortDefinition(); + return (def == null) ? 
SortClause.empty() : def.build(ctx); + } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/VectorSortable.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/VectorSortable.java index b9e04c24cc..471c0d3c64 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/VectorSortable.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/VectorSortable.java @@ -1,5 +1,6 @@ package io.stargate.sgv2.jsonapi.api.model.command; +import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause; import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortExpression; import java.util.Optional; @@ -20,10 +21,11 @@ default Optional includeSortVector() { * * @return the vector sort expression if it exists. */ - default Optional vectorSortExpression() { - if (sortClause() != null && sortClause().sortExpressions() != null) { + default Optional vectorSortExpression(CommandContext ctx) { + SortClause sortClause = sortClause(ctx); + if (sortClause.sortExpressions() != null) { var vectorSorts = - sortClause().sortExpressions().stream() + sortClause.sortExpressions().stream() .filter(expression -> expression.vector() != null) .toList(); if (vectorSorts.size() > 1) { diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/builders/CollectionFilterClauseBuilder.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/builders/CollectionFilterClauseBuilder.java index f09092d78b..266d24731d 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/builders/CollectionFilterClauseBuilder.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/builders/CollectionFilterClauseBuilder.java @@ -11,6 +11,8 @@ import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.service.projection.IndexingProjector; import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject; +import 
io.stargate.sgv2.jsonapi.service.schema.collections.DocumentPath; +import io.stargate.sgv2.jsonapi.service.schema.naming.NamingRules; import java.util.List; import java.util.Map; @@ -30,6 +32,41 @@ protected FilterClause validateAndBuild(LogicalExpression rootExpr) { return new FilterClause(validateWithSchema(rootExpr)); } + @Override + protected String validateFilterClausePath(String path) { + if (!NamingRules.FIELD.apply(path)) { + if (path.isEmpty()) { + throw ErrorCodeV1.INVALID_FILTER_EXPRESSION.toApiException( + "filter clause path cannot be empty String"); + } + // 3 special fields with $ prefix, skip here + switch (path) { + case DocumentConstants.Fields.VECTOR_EMBEDDING_FIELD, + DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD -> { + return path; + } + case DocumentConstants.Fields.LEXICAL_CONTENT_FIELD -> { + if (!schema.lexicalConfig().enabled()) { + throw ErrorCodeV1.LEXICAL_NOT_ENABLED_FOR_COLLECTION.toApiException( + "Lexical search is not enabled for collection '%s'", schema.name()); + } + return path; + } + } + throw ErrorCodeV1.INVALID_FILTER_EXPRESSION.toApiException( + "filter clause path ('%s') cannot start with `$`", path); + } + + try { + path = DocumentPath.verifyEncodedPath(path); + } catch (IllegalArgumentException e) { + throw ErrorCodeV1.INVALID_FILTER_EXPRESSION.toApiException( + "filter clause path ('%s') is not a valid path: %s", path, e.getMessage()); + } + + return path; + } + private LogicalExpression validateWithSchema(LogicalExpression rootExpr) { IndexingProjector indexingProjector = schema.indexingProjector(); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/builders/CollectionSortClauseBuilder.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/builders/CollectionSortClauseBuilder.java new file mode 100644 index 0000000000..833616b7b4 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/builders/CollectionSortClauseBuilder.java @@ -0,0 +1,17 @@ +package 
io.stargate.sgv2.jsonapi.api.model.command.builders; + +import com.fasterxml.jackson.databind.node.ObjectNode; +import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause; +import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject; + +/** {@link SortClauseBuilder} to use with Collections. */ +public class CollectionSortClauseBuilder extends SortClauseBuilder { + public CollectionSortClauseBuilder(CollectionSchemaObject collection) { + super(collection); + } + + @Override + public SortClause buildAndValidate(ObjectNode sortNode) { + return defaultBuildAndValidate(sortNode); + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/builders/FilterClauseBuilder.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/builders/FilterClauseBuilder.java index ec13967a74..e2184c4095 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/builders/FilterClauseBuilder.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/builders/FilterClauseBuilder.java @@ -11,8 +11,6 @@ import io.stargate.sgv2.jsonapi.service.cqldriver.executor.SchemaObject; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject; import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject; -import io.stargate.sgv2.jsonapi.service.schema.collections.DocumentPath; -import io.stargate.sgv2.jsonapi.service.schema.naming.NamingRules; import io.stargate.sgv2.jsonapi.service.shredding.collections.DocumentId; import io.stargate.sgv2.jsonapi.service.shredding.collections.JsonExtensionType; import io.stargate.sgv2.jsonapi.util.JsonUtil; @@ -20,7 +18,7 @@ import java.util.*; /** - * Object for converting {@link JsonNode} (from {@link FilterSpec}) into {@link FilterClause}. + * Object for converting {@link JsonNode} (from {@link FilterDefinition}) into {@link FilterClause}. * Process will validate structure of the JSON, and also validate values of the filter operations. * *

TIDY: this class has a lot of string constants for filter operations that we have defined as @@ -112,7 +110,7 @@ private void populateExpression(LogicalExpression logicalExpression, JsonNode no } } else { throw ErrorCodeV1.INVALID_FILTER_EXPRESSION.toApiException( - "Cannot filter on '%s' field using operator '$eq': only '$exists' is supported", + "Cannot filter on '%s' field using operator $eq: only $exists is supported", DocumentConstants.Fields.VECTOR_EMBEDDING_FIELD); } } @@ -136,7 +134,7 @@ private void populateExpression( case DocumentConstants.Fields.VECTOR_EMBEDDING_FIELD, DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD -> throw ErrorCodeV1.INVALID_FILTER_EXPRESSION.toApiException( - "Cannot filter on '%s' field using operator '$eq': only '$exists' is supported", + "Cannot filter on '%s' field using operator $eq: only $exists is supported", entry.getKey()); default -> throw ErrorCodeV1.INVALID_FILTER_EXPRESSION.toApiException( @@ -147,7 +145,18 @@ private void populateExpression( populateExpression(innerLogicalExpression, next); } logicalExpression.addLogicalExpression(innerLogicalExpression); - } else { + } else { // neither Array nor Object, simple implicit "$eq" comparison + switch (entry.getKey()) { + case DocumentConstants.Fields.VECTOR_EMBEDDING_FIELD, + DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD -> + throw ErrorCodeV1.INVALID_FILTER_EXPRESSION.toApiException( + "Cannot filter on '%s' field using operator $eq: only $exists is supported", + entry.getKey()); + case DocumentConstants.Fields.LEXICAL_CONTENT_FIELD -> + throw ErrorCodeV1.INVALID_FILTER_EXPRESSION.toApiException( + "Cannot filter on '%s' field using operator $eq: only $match is supported", + entry.getKey()); + } // the key should match pattern String key = validateFilterClausePath(entry.getKey()); logicalExpression.addComparisonExpressions( @@ -199,39 +208,59 @@ private List createComparisonExpressionList( return comparisonExpressionList; } - // Before validating Filter path, 
check for special cases: - // ($vector/$vectorize and $exist operator) - String entryKey = entry.getKey(); - if ((entryKey.equals(DocumentConstants.Fields.VECTOR_EMBEDDING_FIELD) - && updateField.getKey().equals("$exists")) - || (entryKey.equals(DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD) - && updateField.getKey().equals("$exists"))) { - ; // fine, special cases - } else { - entryKey = validateFilterClausePath(entryKey); + String entryKey = validateFilterClausePath(entry.getKey()); + + // First things first: $lexical field can only be used with $match operator + if (entryKey.equals(DocumentConstants.Fields.LEXICAL_CONTENT_FIELD) + && (operator != ValueComparisonOperator.MATCH)) { + throw ErrorCodeV1.INVALID_FILTER_EXPRESSION.toApiException( + "Cannot filter on '%s' field using operator %s: only $match is supported", + entry.getKey(), operator.getOperator()); } + JsonNode value = updateField.getValue(); Object valueObject = jsonNodeValue(entryKey, value); if (operator == ValueComparisonOperator.GT || operator == ValueComparisonOperator.GTE || operator == ValueComparisonOperator.LT || operator == ValueComparisonOperator.LTE) { - // Note, added 'valueObject instanceof String || valueObject instanceof Boolean', this is to - // unblock some table filter against non-numeric column - // e.g. {"event_date": {"$gt": "2024-09-24"}}, {"is_active": {"$gt": true}}, - // {"name":{"$gt":"Tim"}} - // Also, for collection path, this will allow comparison filter against collection maps - // query_bool_values and query_text_values + // Comparator GT/GTE/LT/LTE can apply to following value types: + // For Tables, Date/String/Boolean/BigDecimal + // For Collections, Date/String/Boolean/BigDecimal and + // DocumentID(Date/String/Boolean/BigDecimal) + // E.G.
+ // {"birthday": {"$gt": {"$date": 1672531200000}}}, {"name": {"$gt": "Tim"}} + // {"is_active": {"$gt": true}}, {"age": {"$gt": 123}} + // {"_id": {"$gt": {"$date": 1672531200000}}}, {"_id": {"$gt": "Tim"}} + // {"_id": {"$gt": true}}, {"_id": {"$gt": 123}} if (!(valueObject instanceof Date || valueObject instanceof String || valueObject instanceof Boolean || valueObject instanceof BigDecimal - || (valueObject instanceof DocumentId && (value.isObject() || value.isNumber())))) { + || (valueObject instanceof DocumentId + && (value.isObject() + || value.isTextual() + || value.isBoolean() + || value.isNumber())))) { throw ErrorCodeV1.INVALID_FILTER_EXPRESSION.toApiException( "%s operator must have `DATE` or `NUMBER` or `TEXT` or `BOOLEAN` value", operator.getOperator()); } + } else if (operator == ValueComparisonOperator.MATCH) { + // $match operator can only be used with String value and for Collections + // only on $lexical field + if (!entryKey.equals(DocumentConstants.Fields.LEXICAL_CONTENT_FIELD)) { + throw ErrorCodeV1.INVALID_FILTER_EXPRESSION.toApiException( + "%s operator can only be used with the '%s' field, not '%s'", + operator.getOperator(), DocumentConstants.Fields.LEXICAL_CONTENT_FIELD, entryKey); + } + if (!(valueObject instanceof String)) { + throw ErrorCodeV1.INVALID_FILTER_EXPRESSION.toApiException( + "%s operator must have `String` value, was `%s`", + operator.getOperator(), JsonUtil.nodeTypeAsString(value)); + } } + ComparisonExpression expression = new ComparisonExpression(entryKey, new ArrayList<>(), null); expression.add(operator, valueObject); comparisonExpressionList.add(expression); @@ -538,27 +567,12 @@ private void flip(LogicalExpression logicalExpression) { } } - private String validateFilterClausePath(String path) { - if (!NamingRules.FIELD.apply(path)) { - if (path.isEmpty()) { - throw ErrorCodeV1.INVALID_FILTER_EXPRESSION.toApiException( - "filter clause path cannot be empty String"); - } - if 
(path.equals(DocumentConstants.Fields.LEXICAL_CONTENT_FIELD)) { - throw ErrorCodeV1.INVALID_FILTER_EXPRESSION.toApiException( - "Cannot filter on lexical content field '%s': only 'sort' clause supported", path); - } - throw ErrorCodeV1.INVALID_FILTER_EXPRESSION.toApiException( - "filter clause path ('%s') cannot start with `$`", path); - } - - try { - path = DocumentPath.verifyEncodedPath(path); - } catch (IllegalArgumentException e) { - throw ErrorCodeV1.INVALID_FILTER_EXPRESSION.toApiException( - "filter clause path ('%s') is not a valid path. " + e.getMessage(), path); - } - - return path; - } + /** + * Method called to enforce the filter clause path to be valid. This method is called for each + * path. + * + * @param path Path to be validated + * @return Path after validation - currently not changed + */ + protected abstract String validateFilterClausePath(String path); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/deserializers/SortClauseDeserializer.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/builders/SortClauseBuilder.java similarity index 84% rename from src/main/java/io/stargate/sgv2/jsonapi/api/model/command/deserializers/SortClauseDeserializer.java rename to src/main/java/io/stargate/sgv2/jsonapi/api/model/command/builders/SortClauseBuilder.java index 2c0a4b290b..293c6934e1 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/deserializers/SortClauseDeserializer.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/builders/SortClauseBuilder.java @@ -1,55 +1,70 @@ -package io.stargate.sgv2.jsonapi.api.model.command.deserializers; +package io.stargate.sgv2.jsonapi.api.model.command.builders; import static io.stargate.sgv2.jsonapi.util.JsonUtil.arrayNodeToVector; -import com.fasterxml.jackson.core.JsonParser; -import com.fasterxml.jackson.databind.DeserializationContext; import com.fasterxml.jackson.databind.JsonNode; -import 
com.fasterxml.jackson.databind.deser.std.StdDeserializer; import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.ObjectNode; import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.EJSONWrapper; +import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.SortDefinition; import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause; import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortExpression; import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; +import io.stargate.sgv2.jsonapi.service.cqldriver.executor.SchemaObject; +import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject; +import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject; import io.stargate.sgv2.jsonapi.service.schema.collections.DocumentPath; import io.stargate.sgv2.jsonapi.service.schema.naming.NamingRules; import io.stargate.sgv2.jsonapi.util.JsonUtil; -import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Objects; -/** {@link StdDeserializer} for the {@link SortClause}. */ -public class SortClauseDeserializer extends StdDeserializer { +/** + * Object for converting {@link JsonNode} (from {@link SortDefinition}) into {@link SortClause}. + * Process will validate structure of the JSON, and also validate values of the sort expressions. + */ +public abstract class SortClauseBuilder { + protected final T schema; - /** No-arg constructor explicitly needed. 
*/ - protected SortClauseDeserializer() { - super(SortClause.class); + protected SortClauseBuilder(T schema) { + this.schema = Objects.requireNonNull(schema); } - /** {@inheritDoc} */ - @Override - public SortClause deserialize(JsonParser parser, DeserializationContext ctxt) throws IOException { - JsonNode node = ctxt.readTree(parser); + public static SortClauseBuilder builderFor(SchemaObject schema) { + return switch (schema) { + case CollectionSchemaObject collection -> new CollectionSortClauseBuilder(collection); + case TableSchemaObject table -> new TableSortClauseBuilder(table); + default -> + throw new UnsupportedOperationException( + String.format( + "Unsupported schema object class for `SortClauseBuilder`: %s", + schema.getClass())); + }; + } - // if missing or null, return null back + public SortClause build(JsonNode node) { + // if missing or null, return "empty" sort clause if (node.isMissingNode() || node.isNull()) { - // TODO should we return empty sort clause instead? - return null; + return SortClause.empty(); } // otherwise, if it's not object throw exception - if (!node.isObject()) { + if (!(node instanceof ObjectNode sortNode)) { throw ErrorCodeV1.INVALID_SORT_CLAUSE.toApiException( "Sort clause must be submitted as json object"); } - ObjectNode sortNode = (ObjectNode) node; + return buildAndValidate(sortNode); + } + + protected abstract SortClause buildAndValidate(ObjectNode sortNode); + protected SortClause defaultBuildAndValidate(ObjectNode sortNode) { // safe to iterate, we know it's an Object Iterator> fieldIter = sortNode.fields(); int totalFields = sortNode.size(); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/builders/TableFilterClauseBuilder.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/builders/TableFilterClauseBuilder.java index aa78298b2c..9bdbea4655 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/builders/TableFilterClauseBuilder.java +++ 
b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/builders/TableFilterClauseBuilder.java @@ -19,4 +19,9 @@ protected boolean isDocId(String path) { protected FilterClause validateAndBuild(LogicalExpression implicitAnd) { return new FilterClause(implicitAnd); } + + @Override + protected String validateFilterClausePath(String path) { + return path; + } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/builders/TableSortClauseBuilder.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/builders/TableSortClauseBuilder.java new file mode 100644 index 0000000000..ba2e94c32c --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/builders/TableSortClauseBuilder.java @@ -0,0 +1,17 @@ +package io.stargate.sgv2.jsonapi.api.model.command.builders; + +import com.fasterxml.jackson.databind.node.ObjectNode; +import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause; +import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject; + +/** {@link SortClauseBuilder} to use with Tables. 
*/ +public class TableSortClauseBuilder extends SortClauseBuilder { + public TableSortClauseBuilder(TableSchemaObject table) { + super(table); + } + + @Override + public SortClause buildAndValidate(ObjectNode sortNode) { + return defaultBuildAndValidate(sortNode); + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/clause/filter/FilterSpec.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/clause/filter/FilterDefinition.java similarity index 91% rename from src/main/java/io/stargate/sgv2/jsonapi/api/model/command/clause/filter/FilterSpec.java rename to src/main/java/io/stargate/sgv2/jsonapi/api/model/command/clause/filter/FilterDefinition.java index 3afecd2a91..3ddd4e3b5a 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/clause/filter/FilterSpec.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/clause/filter/FilterDefinition.java @@ -20,7 +20,7 @@ """ {"name": "Aaron", "country": {"$eq": "NZ"}, "age": {"$gt": 40}} """) -public class FilterSpec extends JsonDefinition { +public class FilterDefinition extends JsonDefinition { /** * Lazily deserialized {@link FilterClause} from the JSON value. We need this due to existing * reliance on specific stateful instances of {@link FilterClause}. @@ -31,7 +31,7 @@ public class FilterSpec extends JsonDefinition { * To deserialize the whole JSON value, need to ensure DELEGATING mode (instead of PROPERTIES). */ @JsonCreator(mode = JsonCreator.Mode.DELEGATING) - public FilterSpec(JsonNode json) { + public FilterDefinition(JsonNode json) { super(json); } @@ -42,7 +42,8 @@ public FilterSpec(JsonNode json) { * @param ctx The command context to resolve the filter clause. * @return The resolved filter clause. 
*/ - public FilterClause toFilterClause(CommandContext ctx) { + @Override + public FilterClause build(CommandContext ctx) { if (filterClause == null) { filterClause = FilterClauseBuilder.builderFor(ctx.schemaObject()) diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/clause/filter/SortDefinition.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/clause/filter/SortDefinition.java new file mode 100644 index 0000000000..267129639d --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/clause/filter/SortDefinition.java @@ -0,0 +1,94 @@ +package io.stargate.sgv2.jsonapi.api.model.command.clause.filter; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.JsonNodeFactory; +import io.stargate.sgv2.jsonapi.api.model.command.CommandContext; +import io.stargate.sgv2.jsonapi.api.model.command.JsonDefinition; +import io.stargate.sgv2.jsonapi.api.model.command.builders.SortClauseBuilder; +import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause; +import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants; +import java.util.List; +import org.eclipse.microprofile.openapi.annotations.enums.SchemaType; +import org.eclipse.microprofile.openapi.annotations.media.Schema; + +/** + * Intermediate lightly-processed container for JSON that specifies a {@link SortClause}, and allows + * for lazy deserialization into a {@link SortClause}. + */ +@Schema( + type = SchemaType.OBJECT, + implementation = SortClause.class, + example = + """ + {"user.age" : -1, "user.name" : 1} + """) +public class SortDefinition extends JsonDefinition { + /** + * Lazily deserialized {@link SortClause} from the JSON value. We need this due to existing + * reliance on specific stateful instances of {@link SortClause}. 
+ */ + private SortClause sortClause; + + /** + * To deserialize the whole JSON value, need to ensure DELEGATING mode (instead of PROPERTIES). + */ + @JsonCreator(mode = JsonCreator.Mode.DELEGATING) + public SortDefinition(JsonNode json) { + super(json); + } + + private SortDefinition(SortClause sortClause) { + // We do not really provide actual JSON value, but need to pass something to the parent + super(JsonNodeFactory.instance.nullNode()); + this.sortClause = sortClause; + } + + /** + * Alternate constructor used to "wrap" already constructed {@link SortClause} into a {@link + * SortDefinition} instance. Used by Find-and-Rerank functionality to pass already resolved {@link + * SortClause}s. + * + * @param sortClause Actual sort clause to be wrapped. + */ + public static SortDefinition wrap(SortClause sortClause) { + return new SortDefinition(sortClause); + } + + /** + * Convert the JSON value to a {@link SortClause} instance and cache it, so further calls will + * return the same instance. + */ + public SortClause build(CommandContext ctx) { + if (sortClause == null) { + sortClause = SortClauseBuilder.builderFor(ctx.schemaObject()).build(json()); + } + return sortClause; + } + + /** + * Helper method unfortunately needed to check if the sort clause is a vector search clause, + * called from context where we do not have access to the {@link SortClause} yet. + * + * @return True if the Sort clause definition contains a vector search clause. + */ + public boolean hasVsearchClause() { + // We will either be wrapping the clause or have a json value: + if (sortClause != null) { + return sortClause.hasVsearchClause(); + } + return json().has(DocumentConstants.Fields.VECTOR_EMBEDDING_FIELD); + } + + /** + * Helper method needed for logging purposes, where caller does not have access to the {@link + * CommandContext}. 
+ */ + public List getSortExpressionPaths() { + // We will either be wrapping the clause or have a json value: + if (sortClause != null) { + return sortClause.sortExpressions().stream().map(expr -> expr.path()).toList(); + } + return json().properties().stream().map(entry -> entry.getKey()).toList(); + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/clause/filter/ValueComparisonOperator.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/clause/filter/ValueComparisonOperator.java index dd1e154899..6d648ebe5c 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/clause/filter/ValueComparisonOperator.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/clause/filter/ValueComparisonOperator.java @@ -1,5 +1,7 @@ package io.stargate.sgv2.jsonapi.api.model.command.clause.filter; +import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; + /** * List of value operator that can be used in Filter clause Have commented the unsupported * operators, will add it as we support them @@ -12,7 +14,8 @@ public enum ValueComparisonOperator implements FilterOperator { GT("$gt"), GTE("$gte"), LT("$lt"), - LTE("$lte"); + LTE("$lte"), + MATCH("$match"); private String operator; @@ -44,6 +47,10 @@ public FilterOperator invert() { return GTE; case LTE: return GT; + case MATCH: + // No way to do "not matches" (not supported by database) + throw ErrorCodeV1.INVALID_FILTER_EXPRESSION.toApiException( + "cannot use $not to invert $match operator"); } return this; } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/clause/sort/FindAndRerankSort.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/clause/sort/FindAndRerankSort.java index 3decd72e82..1c16549506 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/clause/sort/FindAndRerankSort.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/clause/sort/FindAndRerankSort.java @@ -1,6 +1,7 @@ package 
io.stargate.sgv2.jsonapi.api.model.command.clause.sort; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import io.stargate.sgv2.jsonapi.metrics.CommandFeatures; import io.stargate.sgv2.jsonapi.util.recordable.Recordable; import java.util.*; import org.eclipse.microprofile.openapi.annotations.enums.SchemaType; @@ -26,17 +27,20 @@ {"$sort" : {"$hybrid" : {"$vectorize" : "vectorize sort query" , "$lexical": "lexical sort" }}} {"$sort" : {"$hybrid" : {"$vector" : [1,2,3] , "$lexical": "lexical sort" }}} """) -public record FindAndRerankSort(String vectorizeSort, String lexicalSort, float[] vectorSort) +public record FindAndRerankSort( + String vectorizeSort, String lexicalSort, float[] vectorSort, CommandFeatures commandFeatures) implements Recordable { - static final FindAndRerankSort NO_ARG_SORT = new FindAndRerankSort(null, null, null); + public static final FindAndRerankSort NO_ARG_SORT = + new FindAndRerankSort(null, null, null, CommandFeatures.EMPTY); @Override public DataRecorder recordTo(DataRecorder dataRecorder) { return dataRecorder .append("vectorizeSort", vectorizeSort) .append("lexicalSort", lexicalSort) - .append("vectorSort", Arrays.toString(vectorSort)); + .append("vectorSort", Arrays.toString(vectorSort)) + .append("commandFeatures", commandFeatures); } /** @@ -47,14 +51,16 @@ public DataRecorder recordTo(DataRecorder dataRecorder) { */ @Override public boolean equals(Object obj) { - return Objects.equals(vectorizeSort, ((FindAndRerankSort) obj).vectorizeSort) - && Objects.equals(lexicalSort, ((FindAndRerankSort) obj).lexicalSort) - && Arrays.equals(vectorSort, ((FindAndRerankSort) obj).vectorSort); + return (obj instanceof FindAndRerankSort other) + && Objects.equals(vectorizeSort, other.vectorizeSort) + && Objects.equals(lexicalSort, other.lexicalSort) + && Arrays.equals(vectorSort, other.vectorSort) + && Objects.equals(commandFeatures, other.commandFeatures); } /** Override to do a value equality hash on the vector */ @Override 
public int hashCode() { - return Objects.hash(vectorizeSort, lexicalSort, Arrays.hashCode(vectorSort)); + return Objects.hash(vectorizeSort, lexicalSort, Arrays.hashCode(vectorSort), commandFeatures); } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/clause/sort/FindAndRerankSortClauseDeserializer.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/clause/sort/FindAndRerankSortClauseDeserializer.java index 04efab2bb5..6767f47f98 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/clause/sort/FindAndRerankSortClauseDeserializer.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/clause/sort/FindAndRerankSortClauseDeserializer.java @@ -13,6 +13,8 @@ import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.databind.node.TextNode; import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.EJSONWrapper; +import io.stargate.sgv2.jsonapi.metrics.CommandFeature; +import io.stargate.sgv2.jsonapi.metrics.CommandFeatures; import io.stargate.sgv2.jsonapi.util.JsonFieldMatcher; import java.io.IOException; import java.util.*; @@ -29,14 +31,14 @@ public class FindAndRerankSortClauseDeserializer extends StdDeserializer MATCH_HYBRID_FIELD = new JsonFieldMatcher<>(JsonNode.class, List.of(HYBRID_FIELD), List.of()); // user is specifying the sorts individually // can be - // { "sort" : { "$hybrid" : { "vectorize" : "i like cheese", "$lexical" : "cows"} } } + // { "sort" : { "$hybrid" : { "$vectorize" : "i like cheese", "$lexical" : "cows"} } } // { "sort" : { "$hybrid" : { "$vector" : [1, 2, 3], "$lexical" : "cows"} } } // they could also set something to null to skip that read, that is handled by the resolver // later. 
@@ -58,7 +60,7 @@ public FindAndRerankSort deserialize( return switch (deserializationContext.readTree(jsonParser)) { case NullNode ignored -> // this is {"sort" : null} FindAndRerankSort.NO_ARG_SORT; - case ObjectNode objectNode -> deserialise(jsonParser, objectNode); + case ObjectNode objectNode -> deserialize(jsonParser, objectNode); default -> throw new JsonMappingException( jsonParser, "sort clause must be an object or null", jsonParser.currentLocation()); @@ -66,13 +68,13 @@ public FindAndRerankSort deserialize( } /** - * Deserialise the sort clause from an object node. + * Deserialize the sort clause from an object node. * * @param sort The sort clause as an object node, no check for nulls. * @return The {@link FindAndRerankSort} that reflects the request without validation, e.g. * checking that a vectorize or vector sort is provided but not both. */ - private static FindAndRerankSort deserialise(JsonParser jsonParser, ObjectNode sort) + private static FindAndRerankSort deserialize(JsonParser jsonParser, ObjectNode sort) throws JsonMappingException { // { "sort" : { } } @@ -83,9 +85,12 @@ private static FindAndRerankSort deserialise(JsonParser jsonParser, ObjectNode s var hybridMatch = MATCH_HYBRID_FIELD.matchAndThrow(sort, jsonParser, ERROR_CONTEXT); return switch (hybridMatch.matched().get(HYBRID_FIELD)) { - case TextNode textNode -> // using the same text for vectorize and for lexical, no vector - new FindAndRerankSort( - normalizedText(textNode.asText()), normalizedText(textNode.asText()), null); + case TextNode textNode -> { + // using the same text for vectorize and for lexical, no vector + var normalizedText = normalizedText(textNode.asText().trim()); + yield new FindAndRerankSort( + normalizedText, normalizedText, null, CommandFeatures.of(CommandFeature.HYBRID)); + } case ObjectNode objectNode -> deserializeHybridObject(jsonParser, objectNode); case JsonNode node -> throw JsonFieldMatcher.errorForWrongType( @@ -97,6 +102,7 @@ private static 
FindAndRerankSort deserializeHybridObject( JsonParser jsonParser, ObjectNode hybridObject) throws JsonMappingException { var sortMatch = MATCH_SORT_FIELDS.matchAndThrow(hybridObject, jsonParser, ERROR_CONTEXT); + CommandFeatures commandFeatures = CommandFeatures.of(CommandFeature.HYBRID); var vectorizeText = switch (sortMatch.matched().get(VECTOR_EMBEDDING_TEXT_FIELD)) { @@ -109,10 +115,12 @@ private static FindAndRerankSort deserializeHybridObject( case NullNode ignored -> { // explict setting to null is allowed // { "sort" : { "$hybrid" : { "$vectorize" : null, + commandFeatures.addFeature(CommandFeature.VECTORIZE); yield null; } case TextNode textNode -> { // { "sort" : { "$hybrid" : { "$vectorize" : "I like cheese", + commandFeatures.addFeature(CommandFeature.VECTORIZE); yield normalizedText(textNode.asText().trim()); } case JsonNode node -> @@ -135,10 +143,12 @@ private static FindAndRerankSort deserializeHybridObject( case NullNode ignored -> { // explict setting to null is allowed // { "sort" : { "$hybrid" : { "$lexical" : null, + commandFeatures.addFeature(CommandFeature.LEXICAL); yield null; } case TextNode textNode -> { // { "sort" : { "$hybrid" : { "$lexical" : "cheese", + commandFeatures.addFeature(CommandFeature.LEXICAL); yield normalizedText(textNode.asText().trim()); } case JsonNode node -> @@ -161,10 +171,12 @@ private static FindAndRerankSort deserializeHybridObject( case NullNode ignored -> { // explict setting to null is allowed // { "sort" : { "$hybrid" : { "$vector" : null, + commandFeatures.addFeature(CommandFeature.VECTOR); yield null; } case ArrayNode arrayNode -> { // { "sort" : { "$hybrid" : { "vector" : [1,2,3], + commandFeatures.addFeature(CommandFeature.VECTOR); yield arrayNodeToVector(arrayNode); } case ObjectNode objectNode -> { @@ -184,6 +196,7 @@ private static FindAndRerankSort deserializeHybridObject( ArrayNode.class, ObjectNode.class); } + commandFeatures.addFeature(CommandFeature.VECTOR); yield ejson.getVectorValueForBinary(); } 
case JsonNode node -> @@ -197,7 +210,7 @@ private static FindAndRerankSort deserializeHybridObject( ObjectNode.class); }; - return new FindAndRerankSort(vectorizeText, lexicalText, vector); + return new FindAndRerankSort(vectorizeText, lexicalText, vector, commandFeatures); } /** diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/clause/sort/SortClause.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/clause/sort/SortClause.java index e026ec4ca6..9c18c2b8bd 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/clause/sort/SortClause.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/clause/sort/SortClause.java @@ -1,32 +1,23 @@ package io.stargate.sgv2.jsonapi.api.model.command.clause.sort; import com.datastax.oss.driver.api.core.CqlIdentifier; -import com.fasterxml.jackson.databind.annotation.JsonDeserialize; -import io.stargate.sgv2.jsonapi.api.model.command.deserializers.SortClauseDeserializer; import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.service.projection.IndexingProjector; import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject; import jakarta.validation.Valid; +import java.util.Collections; import java.util.List; -import java.util.Map; -import org.eclipse.microprofile.openapi.annotations.enums.SchemaType; -import org.eclipse.microprofile.openapi.annotations.media.Schema; /** * Internal model for the sort clause that can be used in the commands. * * @param sortExpressions Ordered list of sort expressions. 
*/ -@JsonDeserialize(using = SortClauseDeserializer.class) -@Schema( - type = SchemaType.OBJECT, - implementation = Map.class, - example = - """ - {"user.age" : -1, "user.name" : 1} - """) public record SortClause(@Valid List sortExpressions) { + public static SortClause empty() { + return new SortClause(Collections.emptyList()); + } public boolean isEmpty() { return sortExpressions == null || sortExpressions.isEmpty(); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/CountDocumentsCommand.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/CountDocumentsCommand.java index f4f1204419..76a72fa511 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/CountDocumentsCommand.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/CountDocumentsCommand.java @@ -6,7 +6,7 @@ import io.stargate.sgv2.jsonapi.api.model.command.Filterable; import io.stargate.sgv2.jsonapi.api.model.command.NoOptionsCommand; import io.stargate.sgv2.jsonapi.api.model.command.ReadCommand; -import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterSpec; +import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterDefinition; import jakarta.validation.Valid; import org.eclipse.microprofile.openapi.annotations.media.Schema; @@ -14,7 +14,8 @@ description = "Command that returns count of documents in a collection based on the collection.") @JsonTypeName(CommandName.Names.COUNT_DOCUMENTS) -public record CountDocumentsCommand(@Valid @JsonProperty("filter") FilterSpec filterSpec) +public record CountDocumentsCommand( + @Valid @JsonProperty("filter") FilterDefinition filterDefinition) implements ReadCommand, NoOptionsCommand, Filterable { /** {@inheritDoc} */ diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/CreateTextIndexCommand.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/CreateTextIndexCommand.java new file mode 100644 index 0000000000..ff0e0d4170 --- 
/dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/CreateTextIndexCommand.java @@ -0,0 +1,71 @@ +package io.stargate.sgv2.jsonapi.api.model.command.impl; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyOrder; +import com.fasterxml.jackson.annotation.JsonTypeName; +import io.stargate.sgv2.jsonapi.api.model.command.CollectionCommand; +import io.stargate.sgv2.jsonapi.api.model.command.CommandName; +import io.stargate.sgv2.jsonapi.api.model.command.IndexCreationCommand; +import io.stargate.sgv2.jsonapi.api.model.command.table.definition.indexes.TextIndexDefinitionDesc; +import io.stargate.sgv2.jsonapi.config.constants.TableDescConstants; +import io.stargate.sgv2.jsonapi.config.constants.TableDescDefaults; +import io.stargate.sgv2.jsonapi.service.schema.tables.ApiIndexType; +import jakarta.annotation.Nullable; +import jakarta.validation.constraints.NotNull; +import org.eclipse.microprofile.openapi.annotations.enums.SchemaType; +import org.eclipse.microprofile.openapi.annotations.media.Schema; + +@Schema( + description = + "Creates an index on a text column that can be used for lexical filtering and sorting.") +@JsonTypeName(CommandName.Names.CREATE_TEXT_INDEX) +@JsonPropertyOrder({ + TableDescConstants.IndexDesc.NAME, + TableDescConstants.IndexDesc.DEFINITION, + TableDescConstants.IndexDesc.INDEX_TYPE, + TableDescConstants.IndexDesc.OPTIONS +}) +public record CreateTextIndexCommand( + @Schema(description = "Name of the Index to create.") + @JsonProperty(TableDescConstants.IndexDesc.NAME) + String name, + @NotNull + @Schema(description = "Definition of the index to create.", type = SchemaType.OBJECT) + @JsonProperty(TableDescConstants.IndexDesc.DEFINITION) + TextIndexDefinitionDesc definition, + @JsonInclude(JsonInclude.Include.NON_NULL) + @Nullable + @Schema( + description = + "Optional type of the index to create. 
The only supported value is '" + + ApiIndexType.Constants.TEXT + + "'.", + type = SchemaType.STRING, + defaultValue = ApiIndexType.Constants.TEXT) + @JsonProperty(TableDescConstants.IndexDesc.INDEX_TYPE) + String indexType, + @Nullable + @JsonInclude(JsonInclude.Include.NON_NULL) + @Schema(description = "Options for the command.", type = SchemaType.OBJECT) + @JsonProperty(TableDescConstants.IndexDesc.OPTIONS) + CommandOptions options) + implements CollectionCommand, IndexCreationCommand { + + /** Options for the command */ + public record CommandOptions( + @Nullable + @Schema( + description = "True to ignore if index with the same name already exists.", + defaultValue = + TableDescDefaults.CreateTextIndexOptionsDefaults.Constants.IF_NOT_EXISTS, + type = SchemaType.BOOLEAN, + implementation = Boolean.class) + Boolean ifNotExists) {} + + /** {@inheritDoc} */ + @Override + public CommandName commandName() { + return CommandName.CREATE_TEXT_INDEX; + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/DeleteManyCommand.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/DeleteManyCommand.java index 983034d54d..3754002872 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/DeleteManyCommand.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/DeleteManyCommand.java @@ -4,14 +4,14 @@ import com.fasterxml.jackson.annotation.JsonTypeName; import io.stargate.sgv2.jsonapi.api.model.command.*; import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterClause; -import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterSpec; +import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterDefinition; import jakarta.validation.Valid; import org.eclipse.microprofile.openapi.annotations.media.Schema; /** * Representation of the deleteMany API {@link Command}. * - * @param filterSpec {@link FilterClause} used to identify documents. 
+ * @param filterDefinition {@link FilterClause} used to identify documents. */ @Schema( description = @@ -23,7 +23,7 @@ public record DeleteManyCommand( implementation = FilterClause.class) @Valid @JsonProperty("filter") - FilterSpec filterSpec) + FilterDefinition filterDefinition) implements ModifyCommand, NoOptionsCommand, Filterable { /** {@inheritDoc} */ diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/DeleteOneCommand.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/DeleteOneCommand.java index ff82563012..57bca06994 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/DeleteOneCommand.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/DeleteOneCommand.java @@ -4,8 +4,8 @@ import com.fasterxml.jackson.annotation.JsonTypeName; import io.stargate.sgv2.jsonapi.api.model.command.*; import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterClause; -import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterSpec; -import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause; +import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterDefinition; +import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.SortDefinition; import jakarta.validation.Valid; import jakarta.validation.constraints.NotNull; import org.eclipse.microprofile.openapi.annotations.media.Schema; @@ -13,7 +13,7 @@ /** * Representation of the deleteOne API {@link Command}. * - * @param filterSpec {@link FilterClause} used to identify a document. + * @param filterDefinition {@link FilterClause} used to identify a document. 
*/ @Schema(description = "Command that finds a single document and deletes it from a collection") @JsonTypeName(CommandName.Names.DELETE_ONE) @@ -24,8 +24,8 @@ public record DeleteOneCommand( implementation = FilterClause.class) @Valid @JsonProperty("filter") - FilterSpec filterSpec, - @Valid @JsonProperty("sort") SortClause sortClause) + FilterDefinition filterDefinition, + @Valid @JsonProperty("sort") SortDefinition sortDefinition) implements ModifyCommand, NoOptionsCommand, Filterable, Sortable { /** {@inheritDoc} */ diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindAndRerankCommand.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindAndRerankCommand.java index fa9d673e30..7e39f42f4e 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindAndRerankCommand.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindAndRerankCommand.java @@ -13,9 +13,11 @@ import com.fasterxml.jackson.databind.node.NumericNode; import com.fasterxml.jackson.databind.node.ObjectNode; import io.stargate.sgv2.jsonapi.api.model.command.*; -import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterSpec; +import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterDefinition; import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.FindAndRerankSort; import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants; +import io.stargate.sgv2.jsonapi.metrics.CommandFeature; +import io.stargate.sgv2.jsonapi.metrics.CommandFeatures; import io.stargate.sgv2.jsonapi.util.JsonFieldMatcher; import io.stargate.sgv2.jsonapi.util.recordable.Recordable; import jakarta.validation.Valid; @@ -31,11 +33,14 @@ "Finds documents using using vector and lexical sorting, then reranks the results.") @JsonTypeName(CommandName.Names.FIND_AND_RERANK) public record FindAndRerankCommand( - @Valid @JsonProperty("filter") FilterSpec filterSpec, + @Valid @JsonProperty("filter") FilterDefinition 
filterDefinition, @JsonProperty("projection") JsonNode projectionDefinition, @Valid @JsonProperty("sort") FindAndRerankSort sortClause, @Valid @Nullable Options options) implements ReadCommand, Filterable, Projectable, Windowable { + public FindAndRerankCommand { + sortClause = (sortClause == null) ? FindAndRerankSort.NO_ARG_SORT : sortClause; + } // NOTE: is not VectorSortable because it has its own sort clause. @@ -45,6 +50,16 @@ public CommandName commandName() { return CommandName.FIND_AND_RERANK; } + @Override + public void addCommandFeatures(CommandFeatures commandFeatures) { + if (sortClause != null) { + commandFeatures.addAll(sortClause.commandFeatures()); + } + if (options != null && options.hybridLimits() != null) { + commandFeatures.addAll(options.hybridLimits().commandFeatures()); + } + } + public record Options( @Positive(message = "limit should be greater than `0`") @Schema( @@ -94,13 +109,17 @@ public record Options( public record HybridLimits( @JsonProperty(DocumentConstants.Fields.VECTOR_EMBEDDING_FIELD) int vectorLimit, /** ---- */ - @JsonProperty(DocumentConstants.Fields.LEXICAL_CONTENT_FIELD) int lexicalLimit) + @JsonProperty(DocumentConstants.Fields.LEXICAL_CONTENT_FIELD) int lexicalLimit, + CommandFeatures commandFeatures) implements Recordable { - public static final HybridLimits DEFAULT = new HybridLimits(50, 50); + public static final HybridLimits DEFAULT = new HybridLimits(50, 50, CommandFeatures.EMPTY); @Override public DataRecorder recordTo(DataRecorder dataRecorder) { - return dataRecorder.append("vectorLimit", vectorLimit).append("lexicalLimit", lexicalLimit); + return dataRecorder + .append("vectorLimit", vectorLimit) + .append("lexicalLimit", lexicalLimit) + .append("commandFeatures", commandFeatures); } } @@ -128,8 +147,8 @@ public HybridLimits deserialize( JsonParser jsonParser, DeserializationContext deserializationContext) throws IOException { return switch (deserializationContext.readTree(jsonParser)) { - case NumericNode number 
-> deserialise(jsonParser, number); - case ObjectNode object -> deserialise(jsonParser, object); + case NumericNode number -> deserialize(jsonParser, number); + case ObjectNode object -> deserialize(jsonParser, object); case JsonNode node -> throw new JsonMappingException( jsonParser, @@ -138,15 +157,16 @@ public HybridLimits deserialize( }; } - private HybridLimits deserialise(JsonParser jsonParser, NumericNode limitsNumber) + private HybridLimits deserialize(JsonParser jsonParser, NumericNode limitsNumber) throws JsonMappingException { return new HybridLimits( normaliseLimit(jsonParser, limitsNumber, VECTOR_EMBEDDING_FIELD), - normaliseLimit(jsonParser, limitsNumber, LEXICAL_CONTENT_FIELD)); + normaliseLimit(jsonParser, limitsNumber, LEXICAL_CONTENT_FIELD), + CommandFeatures.of(CommandFeature.HYBRID_LIMITS_NUMBER)); } - private HybridLimits deserialise(JsonParser jsonParser, ObjectNode limitsObject) + private HybridLimits deserialize(JsonParser jsonParser, ObjectNode limitsObject) throws JsonMappingException { var limitMatch = MATCH_LIMIT_FIELDS.matchAndThrow(limitsObject, jsonParser, ERROR_CONTEXT); @@ -155,7 +175,9 @@ private HybridLimits deserialise(JsonParser jsonParser, ObjectNode limitsObject) normaliseLimit( jsonParser, limitMatch.matched().get(VECTOR_EMBEDDING_FIELD), VECTOR_EMBEDDING_FIELD), normaliseLimit( - jsonParser, limitMatch.matched().get(LEXICAL_CONTENT_FIELD), LEXICAL_CONTENT_FIELD)); + jsonParser, limitMatch.matched().get(LEXICAL_CONTENT_FIELD), LEXICAL_CONTENT_FIELD), + CommandFeatures.of( + CommandFeature.HYBRID_LIMITS_VECTOR, CommandFeature.HYBRID_LIMITS_LEXICAL)); } private int normaliseLimit(JsonParser jsonParser, NumericNode limitNode, String fieldName) diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindCommand.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindCommand.java index a36332a258..0b940dc649 100644 --- 
a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindCommand.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindCommand.java @@ -5,8 +5,8 @@ import com.fasterxml.jackson.annotation.JsonTypeName; import com.fasterxml.jackson.databind.JsonNode; import io.stargate.sgv2.jsonapi.api.model.command.*; -import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterSpec; -import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause; +import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterDefinition; +import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.SortDefinition; import io.stargate.sgv2.jsonapi.api.model.command.validation.CheckFindOption; import jakarta.validation.Valid; import jakarta.validation.constraints.Positive; @@ -20,9 +20,9 @@ @JsonTypeName(CommandName.Names.FIND) @CheckFindOption public record FindCommand( - @Valid @JsonProperty("filter") FilterSpec filterSpec, + @Valid @JsonProperty("filter") FilterDefinition filterDefinition, @JsonProperty("projection") JsonNode projectionDefinition, - @Valid @JsonProperty("sort") SortClause sortClause, + @Valid @JsonProperty("sort") SortDefinition sortDefinition, @Valid @Nullable Options options) implements ReadCommand, Filterable, Projectable, Sortable, Windowable, VectorSortable { diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindEmbeddingProvidersCommand.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindEmbeddingProvidersCommand.java index 3aafc1c602..c6bfd1931b 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindEmbeddingProvidersCommand.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindEmbeddingProvidersCommand.java @@ -3,12 +3,28 @@ import com.fasterxml.jackson.annotation.JsonTypeName; import io.stargate.sgv2.jsonapi.api.model.command.CommandName; import io.stargate.sgv2.jsonapi.api.model.command.GeneralCommand; -import 
io.stargate.sgv2.jsonapi.api.model.command.NoOptionsCommand; +import jakarta.validation.Valid; +import jakarta.validation.constraints.Pattern; +import javax.annotation.Nullable; +import org.eclipse.microprofile.openapi.annotations.enums.SchemaType; import org.eclipse.microprofile.openapi.annotations.media.Schema; @Schema(description = "Lists the available Embedding Providers for this database.") @JsonTypeName(CommandName.Names.FIND_EMBEDDING_PROVIDERS) -public record FindEmbeddingProvidersCommand() implements GeneralCommand, NoOptionsCommand { +public record FindEmbeddingProvidersCommand( + @Valid @Nullable @Schema(type = SchemaType.OBJECT, implementation = Options.class) + Options options) + implements GeneralCommand { + + public record Options( + @Nullable + @Schema( + description = + "Filter models to include required support status. If omitted the entire Options, only SUPPORTED models are returned, which can be used when creating a new Collection or Table. Available values are SUPPORTED, DEPRECATED, and END_OF_LIFE (case-insensitive). 
Set to null or an empty string to return all models.", + type = SchemaType.STRING, + implementation = String.class) + @Pattern(regexp = "(?i)^(SUPPORTED|DEPRECATED|END_OF_LIFE)?$") + String filterModelStatus) {} /** {@inheritDoc} */ @Override diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindOneAndDeleteCommand.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindOneAndDeleteCommand.java index 064fc76cfd..66f023d23d 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindOneAndDeleteCommand.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindOneAndDeleteCommand.java @@ -4,8 +4,8 @@ import com.fasterxml.jackson.annotation.JsonTypeName; import com.fasterxml.jackson.databind.JsonNode; import io.stargate.sgv2.jsonapi.api.model.command.*; -import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterSpec; -import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause; +import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterDefinition; +import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.SortDefinition; import jakarta.validation.Valid; import org.eclipse.microprofile.openapi.annotations.media.Schema; @@ -14,8 +14,8 @@ "Command that finds a single JSON document from a collection and deletes it. 
The deleted document is returned") @JsonTypeName(CommandName.Names.FIND_ONE_AND_DELETE) public record FindOneAndDeleteCommand( - @Valid @JsonProperty("filter") FilterSpec filterSpec, - @Valid @JsonProperty("sort") SortClause sortClause, + @Valid @JsonProperty("filter") FilterDefinition filterDefinition, + @Valid @JsonProperty("sort") SortDefinition sortDefinition, @JsonProperty("projection") JsonNode projectionDefinition) implements ModifyCommand, Filterable, Projectable, Sortable { diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindOneAndReplaceCommand.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindOneAndReplaceCommand.java index 73e09bd8ee..e73e672cb9 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindOneAndReplaceCommand.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindOneAndReplaceCommand.java @@ -5,8 +5,8 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.node.ObjectNode; import io.stargate.sgv2.jsonapi.api.model.command.*; -import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterSpec; -import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause; +import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterDefinition; +import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.SortDefinition; import jakarta.validation.Valid; import jakarta.validation.constraints.NotNull; import jakarta.validation.constraints.Pattern; @@ -18,8 +18,8 @@ "Command that finds a single JSON document from a collection and replaces it with the replacement document.") @JsonTypeName(CommandName.Names.FIND_ONE_AND_REPLACE) public record FindOneAndReplaceCommand( - @Valid @JsonProperty("filter") FilterSpec filterSpec, - @Valid @JsonProperty("sort") SortClause sortClause, + @Valid @JsonProperty("filter") FilterDefinition filterDefinition, + @Valid @JsonProperty("sort") SortDefinition sortDefinition, 
@JsonProperty("projection") JsonNode projectionDefinition, @NotNull @Valid @JsonProperty("replacement") ObjectNode replacementDocument, @Valid @Nullable Options options) diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindOneAndUpdateCommand.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindOneAndUpdateCommand.java index eb25ea86f8..c7c009b6ce 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindOneAndUpdateCommand.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindOneAndUpdateCommand.java @@ -4,8 +4,8 @@ import com.fasterxml.jackson.annotation.JsonTypeName; import com.fasterxml.jackson.databind.JsonNode; import io.stargate.sgv2.jsonapi.api.model.command.*; -import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterSpec; -import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause; +import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterDefinition; +import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.SortDefinition; import io.stargate.sgv2.jsonapi.api.model.command.clause.update.UpdateClause; import jakarta.validation.Valid; import jakarta.validation.constraints.NotNull; @@ -18,9 +18,9 @@ "Command that finds a single JSON document from a collection and updates the value provided in the update clause.") @JsonTypeName(CommandName.Names.FIND_ONE_AND_UPDATE) public record FindOneAndUpdateCommand( - @Valid @JsonProperty("filter") FilterSpec filterSpec, + @Valid @JsonProperty("filter") FilterDefinition filterDefinition, @JsonProperty("projection") JsonNode projectionDefinition, - @Valid @JsonProperty("sort") SortClause sortClause, + @Valid @JsonProperty("sort") SortDefinition sortDefinition, @NotNull @Valid @JsonProperty("update") UpdateClause updateClause, @Valid @Nullable Options options) implements ReadCommand, Filterable, Projectable, Sortable, Updatable { diff --git 
a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindOneCommand.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindOneCommand.java index 8efc5e1af6..a7b521d371 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindOneCommand.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindOneCommand.java @@ -4,8 +4,8 @@ import com.fasterxml.jackson.annotation.JsonTypeName; import com.fasterxml.jackson.databind.JsonNode; import io.stargate.sgv2.jsonapi.api.model.command.*; -import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterSpec; -import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause; +import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterDefinition; +import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.SortDefinition; import jakarta.validation.Valid; import java.util.Optional; import javax.annotation.Nullable; @@ -15,9 +15,9 @@ @Schema(description = "Command that finds a single JSON document from a collection.") @JsonTypeName(CommandName.Names.FIND_ONE) public record FindOneCommand( - @Valid @JsonProperty("filter") FilterSpec filterSpec, + @Valid @JsonProperty("filter") FilterDefinition filterDefinition, @JsonProperty("projection") JsonNode projectionDefinition, - @Valid @JsonProperty("sort") SortClause sortClause, + @Valid @JsonProperty("sort") SortDefinition sortDefinition, @Valid @Nullable Options options) implements ReadCommand, Filterable, Projectable, Sortable, Windowable, VectorSortable { diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindRerankingProvidersCommand.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindRerankingProvidersCommand.java index 5389bd450c..855273e59f 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindRerankingProvidersCommand.java +++ 
b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindRerankingProvidersCommand.java @@ -3,9 +3,8 @@ import com.fasterxml.jackson.annotation.JsonTypeName; import io.stargate.sgv2.jsonapi.api.model.command.CommandName; import io.stargate.sgv2.jsonapi.api.model.command.GeneralCommand; -import io.stargate.sgv2.jsonapi.service.provider.ModelSupport; import jakarta.validation.Valid; -import java.util.EnumSet; +import jakarta.validation.constraints.Pattern; import javax.annotation.Nullable; import org.eclipse.microprofile.openapi.annotations.enums.SchemaType; import org.eclipse.microprofile.openapi.annotations.media.Schema; @@ -17,16 +16,15 @@ public record FindRerankingProvidersCommand( Options options) implements GeneralCommand { - /** - * By default, if includeModelStatus is not provided, only model in supported status will be - * returned. - */ public record Options( - @Schema( - description = "Use the option to include models as in target support status.", - type = SchemaType.OBJECT, - implementation = ModelSupport.SupportStatus.class) - EnumSet includeModelStatus) {} + @Nullable + @Schema( + description = + "Filter models to include required support status. If omitted the entire Options, only SUPPORTED models are returned, which can be used when creating a new Collection or Table. Available values are SUPPORTED, DEPRECATED, and END_OF_LIFE (case-insensitive). 
Set to null or an empty string to return all models.", + type = SchemaType.STRING, + implementation = String.class) + @Pattern(regexp = "(?i)^(SUPPORTED|DEPRECATED|END_OF_LIFE)?$") + String filterModelStatus) {} /** {@inheritDoc} */ @Override diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/UpdateManyCommand.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/UpdateManyCommand.java index 4be2277c6d..17a0462526 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/UpdateManyCommand.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/UpdateManyCommand.java @@ -7,7 +7,7 @@ import io.stargate.sgv2.jsonapi.api.model.command.Filterable; import io.stargate.sgv2.jsonapi.api.model.command.ReadCommand; import io.stargate.sgv2.jsonapi.api.model.command.Updatable; -import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterSpec; +import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterDefinition; import io.stargate.sgv2.jsonapi.api.model.command.clause.update.UpdateClause; import jakarta.validation.Valid; import jakarta.validation.constraints.NotNull; @@ -20,7 +20,7 @@ "Command that finds documents from a collection and updates it with the values provided in the update clause.") @JsonTypeName(CommandName.Names.UPDATE_MANY) public record UpdateManyCommand( - @Valid @JsonProperty("filter") FilterSpec filterSpec, + @Valid @JsonProperty("filter") FilterDefinition filterDefinition, @NotNull @Valid @JsonProperty("update") UpdateClause updateClause, @Nullable Options options) implements ReadCommand, Filterable, Updatable { diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/UpdateOneCommand.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/UpdateOneCommand.java index 4c89a5d16f..905240e178 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/UpdateOneCommand.java +++ 
b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/UpdateOneCommand.java @@ -3,8 +3,8 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonTypeName; import io.stargate.sgv2.jsonapi.api.model.command.*; -import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterSpec; -import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause; +import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterDefinition; +import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.SortDefinition; import io.stargate.sgv2.jsonapi.api.model.command.clause.update.UpdateClause; import jakarta.validation.Valid; import jakarta.validation.constraints.NotNull; @@ -16,9 +16,9 @@ "Command that finds a single JSON document from a collection and updates the value provided in the update clause.") @JsonTypeName(CommandName.Names.UPDATE_ONE) public record UpdateOneCommand( - @Valid @JsonProperty("filter") FilterSpec filterSpec, + @Valid @JsonProperty("filter") FilterDefinition filterDefinition, @NotNull @Valid @JsonProperty("update") UpdateClause updateClause, - @Valid @JsonProperty("sort") SortClause sortClause, + @Valid @JsonProperty("sort") SortDefinition sortDefinition, @Nullable Options options) implements ReadCommand, Filterable, Sortable, Updatable { diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/table/definition/indexes/TextIndexDefinitionDesc.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/table/definition/indexes/TextIndexDefinitionDesc.java new file mode 100644 index 0000000000..f140f32b22 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/table/definition/indexes/TextIndexDefinitionDesc.java @@ -0,0 +1,42 @@ +package io.stargate.sgv2.jsonapi.api.model.command.table.definition.indexes; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import 
com.fasterxml.jackson.annotation.JsonPropertyOrder; +import com.fasterxml.jackson.databind.JsonNode; +import io.stargate.sgv2.jsonapi.config.constants.TableDescConstants; +import jakarta.annotation.Nullable; +import jakarta.validation.constraints.NotNull; +import org.eclipse.microprofile.openapi.annotations.enums.SchemaType; +import org.eclipse.microprofile.openapi.annotations.media.Schema; + +@JsonPropertyOrder({ + TableDescConstants.IndexDefinitionDesc.COLUMN, + TableDescConstants.IndexDefinitionDesc.OPTIONS +}) +public record TextIndexDefinitionDesc( + @NotNull + @Schema(description = "Required name of the column to index.", required = true) + @JsonProperty(TableDescConstants.IndexDefinitionDesc.COLUMN) + String column, + @JsonInclude(JsonInclude.Include.NON_NULL) + @Nullable + @Schema(description = "Indexing options.", type = SchemaType.OBJECT) + @JsonProperty(TableDescConstants.IndexDefinitionDesc.OPTIONS) + TextIndexDescOptions options) + implements IndexDefinitionDesc { + + /** Options for the text index */ + @JsonPropertyOrder({TableDescConstants.TextIndexDefinitionDescOptions.ANALYZER}) + public record TextIndexDescOptions( + @Nullable + @Schema( + description = + """ +"Optional definition of the analyzer to use for the text index: either String (named analyzer like "english") + or Object specifying analyzer details. + If not specified, the default named analyzer (\"standard\") will be used.
+""") + @JsonProperty(TableDescConstants.TextIndexDefinitionDescOptions.ANALYZER) + JsonNode analyzer) {} +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/validation/FindOptionsValidation.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/validation/FindOptionsValidation.java index b8e82755ba..e4b8ef7fab 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/validation/FindOptionsValidation.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/validation/FindOptionsValidation.java @@ -22,8 +22,9 @@ public boolean isValid(FindCommand value, ConstraintValidatorContext context) { final FindCommand.Options options = value.options(); if (options == null) return true; + var sortSpec = value.sortDefinition(); context.disableDefaultConstraintViolation(); - if (options.skip() != null && value.sortClause() == null) { + if (options.skip() != null && sortSpec == null) { context .buildConstraintViolationWithTemplate("skip options should be used with sort clause") .addPropertyNode("options.skip") @@ -31,9 +32,7 @@ public boolean isValid(FindCommand value, ConstraintValidatorContext context) { return false; } - if (options.skip() != null - && value.sortClause() != null - && value.sortClause().hasVsearchClause()) { + if (options.skip() != null && sortSpec != null && sortSpec.hasVsearchClause()) { context .buildConstraintViolationWithTemplate( "skip options should not be used with vector search") @@ -42,8 +41,8 @@ public boolean isValid(FindCommand value, ConstraintValidatorContext context) { return false; } - if (value.sortClause() != null - && value.sortClause().hasVsearchClause() + if (sortSpec != null + && sortSpec.hasVsearchClause() && options.limit() != null && options.limit() > config.get().maxVectorSearchLimit()) { context diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/request/EmbeddingCredentialsSupplier.java b/src/main/java/io/stargate/sgv2/jsonapi/api/request/EmbeddingCredentialsSupplier.java new 
file mode 100644 index 0000000000..a05460c296 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/request/EmbeddingCredentialsSupplier.java @@ -0,0 +1,90 @@ +package io.stargate.sgv2.jsonapi.api.request; + +import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; +import java.util.Map; +import java.util.Optional; + +/** + * A supplier for creating {@link EmbeddingCredentials} based on the current request context, + * collection authentication configuration, and embedding provider configuration. + * + *

This class centralizes the logic for determining which credentials to use for embedding + * service calls. + */ +public class EmbeddingCredentialsSupplier { + private final String authTokenHeaderName; + private final String embeddingApiKeyHeaderName; + private final String embeddingAccessIdHeaderName; + private final String embeddingsecretIdHeaderName; + private Map authConfigFromCollection; + + public EmbeddingCredentialsSupplier( + String authTokenHeaderName, + String embeddingApiKeyHeaderName, + String embeddingAccessIdHeaderName, + String embeddingsecretIdHeaderName) { + this.authTokenHeaderName = authTokenHeaderName; + this.embeddingApiKeyHeaderName = embeddingApiKeyHeaderName; + this.embeddingAccessIdHeaderName = embeddingAccessIdHeaderName; + this.embeddingsecretIdHeaderName = embeddingsecretIdHeaderName; + } + + /** Sets the authentication configuration defined at the createCollection command. */ + public void withAuthConfigFromCollection(Map authConfigFromCollection) { + this.authConfigFromCollection = authConfigFromCollection; + } + + /** + * Creates an {@link EmbeddingCredentials} instance based on the current request context and + * provider configuration. + * + * @param requestContext The current request context containing HTTP headers. + * @param providerConfig The configuration for the embedding provider. + * @return An instance of {@link EmbeddingCredentials} with the appropriate credentials. 
+ */ + public EmbeddingCredentials create( + RequestContext requestContext, + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig) { + + var embeddingApi = requestContext.getHttpHeaders().getHeader(this.embeddingApiKeyHeaderName); + var accessId = requestContext.getHttpHeaders().getHeader(this.embeddingAccessIdHeaderName); + var secretId = requestContext.getHttpHeaders().getHeader(this.embeddingsecretIdHeaderName); + + // If these four conditions are met, we use the auth token as the embeddingApiKey to pass + // through to Embedding Providers: + + // 1: User did not provide x-embedding-api-key + boolean isEmbeddingApiKeyMissing = (embeddingApi == null); + + // 2: Provider has authTokenPassThroughForNoneAuth set to true + boolean isAuthTokenPassThroughEnabled = + providerConfig != null && providerConfig.authTokenPassThroughForNoneAuth(); + + // 3: Provider supports NONE auth and it is enabled in the config + boolean providerSupportsNoneAuth = false; + if (providerConfig != null) { + var noneAuthConfig = + providerConfig + .supportedAuthentications() + .get(EmbeddingProvidersConfig.EmbeddingProviderConfig.AuthenticationType.NONE); + providerSupportsNoneAuth = noneAuthConfig != null && noneAuthConfig.enabled(); + } + + // 4: Collection supports NONE auth - no "authentication" in "options.vector.service" + boolean collectionSupportsNoneAuth = (authConfigFromCollection == null); + + if (isEmbeddingApiKeyMissing + && isAuthTokenPassThroughEnabled + && providerSupportsNoneAuth + && collectionSupportsNoneAuth) { + var authToken = requestContext.getHttpHeaders().getHeader(this.authTokenHeaderName); + return new EmbeddingCredentials( + Optional.ofNullable(authToken), Optional.empty(), Optional.empty()); + } + + return new EmbeddingCredentials( + Optional.ofNullable(embeddingApi), + Optional.ofNullable(accessId), + Optional.ofNullable(secretId)); + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/request/RequestContext.java
b/src/main/java/io/stargate/sgv2/jsonapi/api/request/RequestContext.java index 65fc8574f0..b59458de5d 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/request/RequestContext.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/request/RequestContext.java @@ -4,11 +4,14 @@ import com.fasterxml.uuid.NoArgGenerator; import io.stargate.sgv2.jsonapi.api.request.tenant.DataApiTenantResolver; import io.stargate.sgv2.jsonapi.api.request.token.DataApiTokenResolver; +import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; import io.vertx.ext.web.RoutingContext; import jakarta.enterprise.context.RequestScoped; import jakarta.enterprise.inject.Instance; import jakarta.inject.Inject; +import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.SecurityContext; +import java.util.List; import java.util.Optional; /** @@ -25,11 +28,13 @@ public class RequestContext { private final Optional tenantId; private final Optional cassandraToken; - private final EmbeddingCredentials embeddingCredentials; + private final EmbeddingCredentialsSupplier embeddingCredentialsSupplier; private final RerankingCredentials rerankingCredentials; private final HttpHeaderAccess httpHeaders; private final String requestId; + private final String userAgent; + /** * Constructor that will be useful in the offline library mode, where only the tenant will be set * and accessed. 
@@ -38,11 +43,12 @@ public class RequestContext { */ public RequestContext(Optional tenantId) { this.tenantId = tenantId; - this.cassandraToken = Optional.empty(); - this.embeddingCredentials = null; - this.rerankingCredentials = null; + cassandraToken = Optional.empty(); + embeddingCredentialsSupplier = null; + rerankingCredentials = null; httpHeaders = null; requestId = generateRequestId(); + userAgent = null; } @Inject @@ -51,6 +57,7 @@ public RequestContext( SecurityContext securityContext, Instance tenantResolver, Instance tokenResolver, +<<<<<<< HEAD Instance embeddingCredentialsResolver) { this.tenantId = (tenantResolver.get()).resolve(routingContext, securityContext); @@ -61,12 +68,27 @@ public RequestContext( .resolveEmbeddingCredentials(tenantId.orElse(""), routingContext); this.cassandraToken = (tokenResolver.get()).resolve(routingContext, securityContext); +======= + HttpConstants httpConstants) { + + tenantId = tenantResolver.get().resolve(routingContext, securityContext); + cassandraToken = tokenResolver.get().resolve(routingContext, securityContext); +>>>>>>> main httpHeaders = new HttpHeaderAccess(routingContext.request().headers()); requestId = generateRequestId(); + userAgent = httpHeaders.getHeader(HttpHeaders.USER_AGENT); + + embeddingCredentialsSupplier = + new EmbeddingCredentialsSupplier( + httpConstants.authToken(), + httpConstants.embeddingApiKey(), + httpConstants.embeddingAccessId(), + httpConstants.embeddingSecretId()); + // if x-reranking-api-key is present, then use it, else use cassandraToken Optional rerankingApiKeyFromHeader = HeaderBasedRerankingKeyResolver.resolveRerankingKey(routingContext); - this.rerankingCredentials = + rerankingCredentials = rerankingApiKeyFromHeader .map(apiKey -> new RerankingCredentials(this.tenantId.orElse(""), Optional.of(apiKey))) .orElse( @@ -87,19 +109,23 @@ public String getRequestId() { } public Optional getTenantId() { - return this.tenantId; + return tenantId; } public Optional getCassandraToken() 
{ - return this.cassandraToken; + return cassandraToken; + } + + public Optional getUserAgent() { + return Optional.ofNullable(userAgent); } - public EmbeddingCredentials getEmbeddingCredentials() { - return this.embeddingCredentials; + public EmbeddingCredentialsSupplier getEmbeddingCredentialsSupplier() { + return embeddingCredentialsSupplier; } public RerankingCredentials getRerankingCredentials() { - return this.rerankingCredentials; + return rerankingCredentials; } public HttpHeaderAccess getHttpHeaders() { @@ -117,6 +143,14 @@ public HttpHeaderAccess(io.vertx.core.MultiMap headers) { this.headers = headers; } + public String getHeader(String headerName) { + return headers.get(headerName); + } + + public List getHeaders(String headerName) { + return headers.getAll(headerName); + } + /** * Accessor for getting value of given header, as {@code Boolean} if (and only if!) value is one * of "true" or "false". Access by name is (and has to be) case-insensitive as per HTTP diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/request/RerankingCredentialsResolver.java b/src/main/java/io/stargate/sgv2/jsonapi/api/request/RerankingCredentialsResolver.java deleted file mode 100644 index a60d6ff250..0000000000 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/request/RerankingCredentialsResolver.java +++ /dev/null @@ -1,9 +0,0 @@ -package io.stargate.sgv2.jsonapi.api.request; - -import io.vertx.ext.web.RoutingContext; - -/** Functional interface to resolve the reranking api key from the request context. 
*/ -@FunctionalInterface -public interface RerankingCredentialsResolver { - RerankingCredentials resolveRerankingCredentials(RoutingContext context); -} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/v1/CollectionResource.java b/src/main/java/io/stargate/sgv2/jsonapi/api/v1/CollectionResource.java index 891457a29c..8e1cfb48a4 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/v1/CollectionResource.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/v1/CollectionResource.java @@ -2,12 +2,14 @@ import static io.stargate.sgv2.jsonapi.config.constants.DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD; +import io.micrometer.core.instrument.MeterRegistry; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.ConfigPreLoader; import io.stargate.sgv2.jsonapi.api.model.command.*; import io.stargate.sgv2.jsonapi.api.model.command.impl.AlterTableCommand; import io.stargate.sgv2.jsonapi.api.model.command.impl.CountDocumentsCommand; import io.stargate.sgv2.jsonapi.api.model.command.impl.CreateIndexCommand; +import io.stargate.sgv2.jsonapi.api.model.command.impl.CreateTextIndexCommand; import io.stargate.sgv2.jsonapi.api.model.command.impl.CreateVectorIndexCommand; import io.stargate.sgv2.jsonapi.api.model.command.impl.DeleteManyCommand; import io.stargate.sgv2.jsonapi.api.model.command.impl.DeleteOneCommand; @@ -23,7 +25,6 @@ import io.stargate.sgv2.jsonapi.api.model.command.impl.UpdateManyCommand; import io.stargate.sgv2.jsonapi.api.model.command.impl.UpdateOneCommand; import io.stargate.sgv2.jsonapi.api.request.RequestContext; -import io.stargate.sgv2.jsonapi.api.v1.metrics.JsonProcessingMetricsReporter; import io.stargate.sgv2.jsonapi.config.constants.OpenApiConstants; import io.stargate.sgv2.jsonapi.config.feature.ApiFeature; import io.stargate.sgv2.jsonapi.config.feature.ApiFeatures; @@ -31,7 +32,8 @@ import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.exception.JsonApiException; import 
io.stargate.sgv2.jsonapi.exception.mappers.ThrowableCommandResultSupplier; -import io.stargate.sgv2.jsonapi.service.cqldriver.CQLSessionCache; +import io.stargate.sgv2.jsonapi.metrics.JsonProcessingMetricsReporter; +import io.stargate.sgv2.jsonapi.service.cqldriver.CqlSessionCacheSupplier; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.SchemaCache; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.SchemaObject; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorColumnDefinition; @@ -71,26 +73,23 @@ public class CollectionResource { public static final String BASE_PATH = GeneralResource.BASE_PATH + "/{keyspace}/{collection}"; - private final MeteredCommandProcessor meteredCommandProcessor; - - @Inject private SchemaCache schemaCache; - - private EmbeddingProviderFactory embeddingProviderFactory; - - @Inject private RequestContext requestContext; - - // need to keep for a little because we have to check the schema type before making the command + // need to keep for a little because we have to check the schema type before making the command // context // TODO remove apiFeatureConfig as a property after cleanup for how we get schema from cache @Inject private FeaturesConfig apiFeatureConfig; + @Inject private RequestContext requestContext; + @Inject private SchemaCache schemaCache; private final CommandContext.BuilderSupplier contextBuilderSupplier; + private final EmbeddingProviderFactory embeddingProviderFactory; + private final MeteredCommandProcessor meteredCommandProcessor; @Inject public CollectionResource( MeteredCommandProcessor meteredCommandProcessor, + MeterRegistry meterRegistry, JsonProcessingMetricsReporter jsonProcessingMetricsReporter, - CQLSessionCache cqlSessionCache, + CqlSessionCacheSupplier sessionCacheSupplier, EmbeddingProviderFactory embeddingProviderFactory, RerankingProviderFactory rerankingProviderFactory) { this.embeddingProviderFactory = embeddingProviderFactory; @@ -99,10 +98,11 @@ public 
CollectionResource( contextBuilderSupplier = CommandContext.builderSupplier() .withJsonProcessingMetricsReporter(jsonProcessingMetricsReporter) - .withCqlSessionCache(cqlSessionCache) + .withCqlSessionCache(sessionCacheSupplier.get()) .withCommandConfig(ConfigPreLoader.getPreLoadOrEmpty()) .withEmbeddingProviderFactory(embeddingProviderFactory) - .withRerankingProviderFactory(rerankingProviderFactory); + .withRerankingProviderFactory(rerankingProviderFactory) + .withMeterRegistry(meterRegistry); } @Operation( @@ -136,6 +136,7 @@ public CollectionResource( // Table Only commands AlterTableCommand.class, CreateIndexCommand.class, + CreateTextIndexCommand.class, CreateVectorIndexCommand.class, ListIndexesCommand.class }), @@ -160,6 +161,7 @@ public CollectionResource( @ExampleObject(ref = "alterTableAddVectorize"), @ExampleObject(ref = "alterTableDropVectorize"), @ExampleObject(ref = "createIndex"), + @ExampleObject(ref = "createTextIndex"), @ExampleObject(ref = "createVectorIndex"), @ExampleObject(ref = "listIndexes"), @ExampleObject(ref = "insertOneTables"), @@ -202,7 +204,6 @@ public Uni> postCommand( return schemaCache .getSchemaObject( requestContext, - requestContext.getTenantId(), keyspace, collection, CommandType.DDL.equals(command.commandName().getCommandType())) @@ -249,18 +250,25 @@ public Uni> postCommand( .getFirstVectorColumnWithVectorizeDefinition() .orElse(null); } - EmbeddingProvider embeddingProvider = - (vectorColDef == null || vectorColDef.vectorizeDefinition() == null) - ? 
null - : embeddingProviderFactory.getConfiguration( - requestContext.getTenantId(), - requestContext.getCassandraToken(), - vectorColDef.vectorizeDefinition().provider(), - vectorColDef.vectorizeDefinition().modelName(), - vectorColDef.vectorSize(), - vectorColDef.vectorizeDefinition().parameters(), - vectorColDef.vectorizeDefinition().authentication(), - command.getClass().getSimpleName()); + + EmbeddingProvider embeddingProvider = null; + + if (vectorColDef != null && vectorColDef.vectorizeDefinition() != null) { + embeddingProvider = + embeddingProviderFactory.getConfiguration( + requestContext.getTenantId(), + requestContext.getCassandraToken(), + vectorColDef.vectorizeDefinition().provider(), + vectorColDef.vectorizeDefinition().modelName(), + vectorColDef.vectorSize(), + vectorColDef.vectorizeDefinition().parameters(), + vectorColDef.vectorizeDefinition().authentication(), + command.getClass().getSimpleName()); + requestContext + .getEmbeddingCredentialsSupplier() + .withAuthConfigFromCollection( + vectorColDef.vectorizeDefinition().authentication()); + } var commandContext = contextBuilderSupplier diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/v1/GeneralResource.java b/src/main/java/io/stargate/sgv2/jsonapi/api/v1/GeneralResource.java index 6bca69714b..033360a601 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/v1/GeneralResource.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/v1/GeneralResource.java @@ -1,5 +1,6 @@ package io.stargate.sgv2.jsonapi.api.v1; +import io.micrometer.core.instrument.MeterRegistry; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.ConfigPreLoader; import io.stargate.sgv2.jsonapi.api.model.command.CommandContext; @@ -7,9 +8,9 @@ import io.stargate.sgv2.jsonapi.api.model.command.GeneralCommand; import io.stargate.sgv2.jsonapi.api.model.command.impl.CreateKeyspaceCommand; import io.stargate.sgv2.jsonapi.api.request.RequestContext; -import 
io.stargate.sgv2.jsonapi.api.v1.metrics.JsonProcessingMetricsReporter; import io.stargate.sgv2.jsonapi.config.constants.OpenApiConstants; -import io.stargate.sgv2.jsonapi.service.cqldriver.CQLSessionCache; +import io.stargate.sgv2.jsonapi.metrics.JsonProcessingMetricsReporter; +import io.stargate.sgv2.jsonapi.service.cqldriver.CqlSessionCacheSupplier; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.DatabaseSchemaObject; import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProviderFactory; import io.stargate.sgv2.jsonapi.service.processor.MeteredCommandProcessor; @@ -41,17 +42,17 @@ public class GeneralResource { public static final String BASE_PATH = "/v1"; - private final MeteredCommandProcessor meteredCommandProcessor; - @Inject private RequestContext requestContext; private final CommandContext.BuilderSupplier contextBuilderSupplier; + private final MeteredCommandProcessor meteredCommandProcessor; @Inject public GeneralResource( MeteredCommandProcessor meteredCommandProcessor, + MeterRegistry meterRegistry, JsonProcessingMetricsReporter jsonProcessingMetricsReporter, - CQLSessionCache cqlSessionCache, + CqlSessionCacheSupplier sessionCacheSupplier, EmbeddingProviderFactory embeddingProviderFactory, RerankingProviderFactory rerankingProviderFactory) { this.meteredCommandProcessor = meteredCommandProcessor; @@ -60,10 +61,11 @@ public GeneralResource( CommandContext.builderSupplier() // old code did not set jsonProcessingMetricsReporter - Aaron Feb 10 .withJsonProcessingMetricsReporter(jsonProcessingMetricsReporter) - .withCqlSessionCache(cqlSessionCache) + .withCqlSessionCache(sessionCacheSupplier.get()) .withCommandConfig(ConfigPreLoader.getPreLoadOrEmpty()) .withEmbeddingProviderFactory(embeddingProviderFactory) - .withRerankingProviderFactory(rerankingProviderFactory); + .withRerankingProviderFactory(rerankingProviderFactory) + .withMeterRegistry(meterRegistry); } // TODO: add example for findEmbeddingProviders diff --git 
a/src/main/java/io/stargate/sgv2/jsonapi/api/v1/KeyspaceResource.java b/src/main/java/io/stargate/sgv2/jsonapi/api/v1/KeyspaceResource.java index dd867b37c9..6369a7b7e6 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/v1/KeyspaceResource.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/v1/KeyspaceResource.java @@ -1,5 +1,6 @@ package io.stargate.sgv2.jsonapi.api.v1; +import io.micrometer.core.instrument.MeterRegistry; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.ConfigPreLoader; import io.stargate.sgv2.jsonapi.api.model.command.CommandContext; @@ -14,12 +15,12 @@ import io.stargate.sgv2.jsonapi.api.model.command.impl.FindCollectionsCommand; import io.stargate.sgv2.jsonapi.api.model.command.impl.ListTablesCommand; import io.stargate.sgv2.jsonapi.api.request.RequestContext; -import io.stargate.sgv2.jsonapi.api.v1.metrics.JsonProcessingMetricsReporter; import io.stargate.sgv2.jsonapi.config.constants.OpenApiConstants; import io.stargate.sgv2.jsonapi.config.feature.ApiFeature; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.exception.mappers.ThrowableCommandResultSupplier; -import io.stargate.sgv2.jsonapi.service.cqldriver.CQLSessionCache; +import io.stargate.sgv2.jsonapi.metrics.JsonProcessingMetricsReporter; +import io.stargate.sgv2.jsonapi.service.cqldriver.CqlSessionCacheSupplier; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.KeyspaceSchemaObject; import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProviderFactory; import io.stargate.sgv2.jsonapi.service.processor.MeteredCommandProcessor; @@ -55,17 +56,18 @@ public class KeyspaceResource { public static final String BASE_PATH = GeneralResource.BASE_PATH + "/{keyspace}"; - private final MeteredCommandProcessor meteredCommandProcessor; @Inject private RequestContext requestContext; private final CommandContext.BuilderSupplier contextBuilderSupplier; + private final MeteredCommandProcessor meteredCommandProcessor; 
@Inject public KeyspaceResource( MeteredCommandProcessor meteredCommandProcessor, + MeterRegistry meterRegistry, JsonProcessingMetricsReporter jsonProcessingMetricsReporter, - CQLSessionCache cqlSessionCache, + CqlSessionCacheSupplier sessionCacheSupplier, EmbeddingProviderFactory embeddingProviderFactory, RerankingProviderFactory rerankingProviderFactory) { this.meteredCommandProcessor = meteredCommandProcessor; @@ -74,10 +76,11 @@ public KeyspaceResource( CommandContext.builderSupplier() // old code did not pass a jsonProcessingMetricsReporter not sure why - Aaron Feb 10 .withJsonProcessingMetricsReporter(jsonProcessingMetricsReporter) - .withCqlSessionCache(cqlSessionCache) + .withCqlSessionCache(sessionCacheSupplier.get()) .withCommandConfig(ConfigPreLoader.getPreLoadOrEmpty()) .withEmbeddingProviderFactory(embeddingProviderFactory) - .withRerankingProviderFactory(rerankingProviderFactory); + .withRerankingProviderFactory(rerankingProviderFactory) + .withMeterRegistry(meterRegistry); } @Operation( diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/v1/metrics/JsonApiMetricsConfig.java b/src/main/java/io/stargate/sgv2/jsonapi/api/v1/metrics/JsonApiMetricsConfig.java index 96ca29bb89..fb77e99834 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/v1/metrics/JsonApiMetricsConfig.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/v1/metrics/JsonApiMetricsConfig.java @@ -70,10 +70,6 @@ public interface JsonApiMetricsConfig { @WithDefault("command.processor.process") String metricsName(); - @NotBlank - @WithDefault("vectorize.call.duration") - String vectorizeCallDurationMetrics(); - @NotBlank @WithDefault("vectorize.input.bytes") String vectorizeInputBytesMetrics(); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/v1/metrics/MetricsConfig.java b/src/main/java/io/stargate/sgv2/jsonapi/api/v1/metrics/MetricsConfig.java index 6c91134cb4..c6cfb900ad 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/v1/metrics/MetricsConfig.java +++ 
b/src/main/java/io/stargate/sgv2/jsonapi/api/v1/metrics/MetricsConfig.java @@ -1,20 +1,3 @@ -/* - * Copyright The Stargate Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - package io.stargate.sgv2.jsonapi.api.v1.metrics; import io.smallrye.config.ConfigMapping; diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/v1/metrics/MicrometerConfiguration.java b/src/main/java/io/stargate/sgv2/jsonapi/api/v1/metrics/MicrometerConfiguration.java deleted file mode 100644 index c258682c59..0000000000 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/v1/metrics/MicrometerConfiguration.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright The Stargate Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * - */ - -package io.stargate.sgv2.jsonapi.api.v1.metrics; - -import io.micrometer.core.instrument.Tag; -import io.micrometer.core.instrument.Tags; -import io.micrometer.core.instrument.config.MeterFilter; -import jakarta.enterprise.inject.Produces; -import jakarta.inject.Singleton; -import java.util.Collection; -import java.util.Map; -import java.util.stream.Collectors; - -/** Configuration of all {@link MeterFilter}s used. */ -public class MicrometerConfiguration { - - /** - * @return Produces meter filter that takes care of the global tags e.g. module tag sgv2-jsonapi - * as MeterFilter commonTag - */ - @Produces - @Singleton - public MeterFilter globalTagsMeterFilter(MetricsConfig config) { - Map globalTags = config.globalTags(); - - // if we have no global tags, use empty - if (null == globalTags || globalTags.isEmpty()) { - return new MeterFilter() {}; - } - - // transform to tags - Collection tags = - globalTags.entrySet().stream() - .map(e -> Tag.of(e.getKey(), e.getValue())) - .collect(Collectors.toList()); - - // return all - return MeterFilter.commonTags(Tags.of(tags)); - } -} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/config/CommandLevelLoggingConfig.java b/src/main/java/io/stargate/sgv2/jsonapi/config/CommandLevelLoggingConfig.java index 3232187161..9657739ad4 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/config/CommandLevelLoggingConfig.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/config/CommandLevelLoggingConfig.java @@ -30,7 +30,7 @@ public interface CommandLevelLoggingConfig { /** * @return If request info logging is enabled. 
*/ - @WithDefault("false") + @WithDefault("true") boolean enabled(); /** diff --git a/src/main/java/io/stargate/sgv2/jsonapi/config/DatabaseType.java b/src/main/java/io/stargate/sgv2/jsonapi/config/DatabaseType.java new file mode 100644 index 0000000000..adb57f9d8c --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/config/DatabaseType.java @@ -0,0 +1,37 @@ +package io.stargate.sgv2.jsonapi.config; + +import io.stargate.sgv2.jsonapi.service.cqldriver.CqlCredentialsFactory; +import java.util.Objects; +import org.eclipse.microprofile.config.spi.Converter; + +/** + * The back end database the API is running against. + * + *

How we manage credentials is a bit different for each database type, see {@link + * CqlCredentialsFactory}. + */ +public enum DatabaseType { + ASTRA, + CASSANDRA, + OFFLINE_WRITER; + + /** + * Constants should only be used where we need a string constant for defaults etc, use the enum + * normally. + */ + public interface Constants { + String ASTRA = "ASTRA"; + String CASSANDRA = "CASSANDRA"; + String OFFLINE_WRITER = "OFFLINE_WRITER"; + } + + /** Used by {@link OperationsConfig.DatabaseConfig#type} */ + public static class DatabaseTypeConverter implements Converter { + @Override + public DatabaseType convert(String value) + throws IllegalArgumentException, NullPointerException { + Objects.requireNonNull(value, "value must not be null"); + return DatabaseType.valueOf(value.toUpperCase()); + } + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/config/OperationsConfig.java b/src/main/java/io/stargate/sgv2/jsonapi/config/OperationsConfig.java index b2dd40b577..65b88e6d5d 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/config/OperationsConfig.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/config/OperationsConfig.java @@ -17,8 +17,6 @@ package io.stargate.sgv2.jsonapi.config; -import static io.stargate.sgv2.jsonapi.service.cqldriver.CQLSessionCache.CASSANDRA; - import com.datastax.oss.driver.api.core.ConsistencyLevel; import io.smallrye.config.ConfigMapping; import io.smallrye.config.WithConverter; @@ -145,6 +143,15 @@ public interface OperationsConfig { @WithDefault("true") boolean tooManyIndexesRollbackEnabled(); + /** + * Optional string that is the case-insensitive user agent string that will be used to identify if + * a request is from an SLA checker. 
Requests from SLA checkers may be treated differently for + * features such as caching sessions.(Empty string is treated as unset / null) + */ + @WithDefault("") + @Nullable + Optional slaUserAgent(); + /** * @return Defines the default page size for count operation, having separate from * `defaultPageSize` config because count will read more keys per page, defaults to 100 @@ -178,30 +185,41 @@ interface LwtConfig { interface DatabaseConfig { - /** Database type can be cassandra or astra. */ - @WithDefault(CASSANDRA) - String type(); + /** + * The type of backend DB to connect to, this drives decisions like using the cassandraEndPoints + */ + @WithDefault(DatabaseType.Constants.CASSANDRA) + @WithConverter(DatabaseType.DatabaseTypeConverter.class) + DatabaseType type(); - /** Username when connecting to cassandra database (when type is cassandra) */ + /** + * Username when connecting to cassandra database (when type is {@link DatabaseType#CASSANDRA}) + * and fixedToken is used + */ @Nullable @WithDefault("cassandra") String userName(); - /** Password when connecting to cassandra database (when type is cassandra) */ + /** + * Password when connecting to cassandra database (when type is {@link DatabaseType#CASSANDRA}) + * and fixedToken is used + */ @Nullable @WithDefault("cassandra") String password(); - /** Fixed Token used for Integration Test authentication */ + /** + * Fixed Token used for Integration Test authentication. 
When set, all tokens must match this + * value and the userName and password from this config are always used for the db credentials + */ Optional fixedToken(); - /** Cassandra contact points (when type is cassandra) */ + /** Cassandra contact points (when type is {@link DatabaseType#CASSANDRA}) */ @Nullable @WithDefault("127.0.0.1") List cassandraEndPoints(); - /** Cassandra contact points (when type is cassandra) */ - @Nullable + /** Cassandra port (when type is {@link DatabaseType#CASSANDRA}) */ @WithDefault("9042") int cassandraPort(); @@ -210,15 +228,22 @@ interface DatabaseConfig { @WithDefault("datacenter1") String localDatacenter(); - /** Time to live for CQLSession in cache in seconds. */ + /** Time to live for CQLSession in cache in seconds, that are not from the slaUserAgent. */ @WithDefault("300") long sessionCacheTtlSeconds(); + /** + * Time to live for CQLSession created because of a request from with the SLA user-agent in + * cache in seconds. + */ + @WithDefault("10") + long slaSessionCacheTtlSeconds(); + /** Maximum number of CQLSessions in cache. */ @WithDefault("50") int sessionCacheMaxSize(); - /** DDL query retry wait in illis. */ + /** DDL query retry wait in millis. */ @WithDefault("1000") int ddlRetryDelayMillis(); @@ -303,7 +328,6 @@ interface ConsistencyConfig { interface OfflineModeConfig { - /** Database type can be cassandra or astra. 
*/ @WithDefault("1000") int maxDocumentInsertCount(); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/config/constants/DocumentConstants.java b/src/main/java/io/stargate/sgv2/jsonapi/config/constants/DocumentConstants.java index f6a00abe5b..80f5e2d1d5 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/config/constants/DocumentConstants.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/config/constants/DocumentConstants.java @@ -44,9 +44,17 @@ interface Columns { */ String DATA_CONTAINS_COLUMN_NAME = "array_contains"; + String QUERY_BOOLEAN_MAP_COLUMN_NAME = "query_bool_values"; + + String QUERY_DOUBLE_MAP_COLUMN_NAME = "query_dbl_values"; + + String QUERY_NULL_MAP_COLUMN_NAME = "query_null_values"; + /** Text map support _id $ne and _id $nin on both atomic value and array element */ String QUERY_TEXT_MAP_COLUMN_NAME = "query_text_values"; + String QUERY_TIMESTAMP_MAP_COLUMN_NAME = "query_timestamp_values"; + /** Physical table column name that stores the vector field. */ String VECTOR_SEARCH_INDEX_COLUMN_NAME = "query_vector_value"; diff --git a/src/main/java/io/stargate/sgv2/jsonapi/config/constants/HttpConstants.java b/src/main/java/io/stargate/sgv2/jsonapi/config/constants/HttpConstants.java index 928f5a3e4c..7928087311 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/config/constants/HttpConstants.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/config/constants/HttpConstants.java @@ -27,6 +27,9 @@ public interface HttpConstants { /** Bearer prefix for the API key. */ String BEARER_PREFIX_FOR_API_KEY = "Bearer "; + @WithDefault(AUTHENTICATION_TOKEN_HEADER_NAME) + String authToken(); + /** * @return Embedding service header name for token. 
*/ diff --git a/src/main/java/io/stargate/sgv2/jsonapi/config/constants/TableDescConstants.java b/src/main/java/io/stargate/sgv2/jsonapi/config/constants/TableDescConstants.java index 9a4546e073..85312a1658 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/config/constants/TableDescConstants.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/config/constants/TableDescConstants.java @@ -46,6 +46,15 @@ interface RegularIndexDefinitionDescOptions { String NORMALIZE = "normalize"; } + interface TextIndexDefinitionDescOptions { + String ANALYZER = "analyzer"; + } + + /** Options for the creating text index via CQL. */ + interface TextIndexCQLOptions { + String OPTION_ANALYZER = "index_analyzer"; + } + interface VectorIndexDefinitionDescOptions { String SOURCE_MODEL = "source_model"; String SIMILARITY_FUNCTION = "similarity_function"; diff --git a/src/main/java/io/stargate/sgv2/jsonapi/config/constants/TableDescDefaults.java b/src/main/java/io/stargate/sgv2/jsonapi/config/constants/TableDescDefaults.java index 203882fa18..e8fd5f1b48 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/config/constants/TableDescDefaults.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/config/constants/TableDescDefaults.java @@ -18,6 +18,21 @@ interface Constants { } } + /** + * Defaults for {@link + * io.stargate.sgv2.jsonapi.api.model.command.impl.CreateTextIndexCommand.CommandOptions}. + */ + interface CreateTextIndexOptionsDefaults { + boolean IF_NOT_EXISTS = false; + + String DEFAULT_NAMED_ANALYZER = "standard"; + + // For use in @Schema decorators + interface Constants { + String IF_NOT_EXISTS = "false"; + } + } + /** * Defaults for {@link * io.stargate.sgv2.jsonapi.api.model.command.impl.CreateVectorIndexCommand.CreateVectorIndexCommandOptions}. 
diff --git a/src/main/java/io/stargate/sgv2/jsonapi/exception/ErrorCodeV1.java b/src/main/java/io/stargate/sgv2/jsonapi/exception/ErrorCodeV1.java index 22c4d50851..6cf7c6df54 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/exception/ErrorCodeV1.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/exception/ErrorCodeV1.java @@ -155,7 +155,6 @@ public enum ErrorCodeV1 { VECTOR_SEARCH_TOO_BIG_VALUE("Vector embedding property '$vector' length too big"), VECTOR_SIZE_MISMATCH("Length of vector parameter different from declared '$vector' dimension"), - VECTORIZE_MODEL_DEPRECATED("Vectorize model is deprecated"), VECTORIZE_FEATURE_NOT_AVAILABLE("Vectorize feature is not available in the environment"), VECTORIZE_SERVICE_NOT_REGISTERED("Vectorize service name provided is not registered : "), VECTORIZE_SERVICE_TYPE_UNAVAILABLE("Vectorize service unavailable : "), @@ -167,9 +166,11 @@ public enum ErrorCodeV1 { LEXICAL_NOT_AVAILABLE_FOR_DATABASE("Lexical search is not available on this database"), LEXICAL_NOT_ENABLED_FOR_COLLECTION("Lexical search is not enabled for the collection"), + LEXICAL_CONTENT_TOO_BIG( + "Lexical content is too big, please use a smaller value for the $lexical field"), HYBRID_FIELD_CONFLICT( - "Conflict between '$hybrid' field and '$vector' and/or '$vectorize' field(s): can only use one or the other(s)"), + "The '$hybrid' field cannot be used with '$lexical', '$vector', or '$vectorize'."), HYBRID_FIELD_UNSUPPORTED_VALUE_TYPE("Unsupported JSON value type for '$hybrid' field"), HYBRID_FIELD_UNKNOWN_SUBFIELDS("Unrecognized sub-field(s) for '$hybrid' Object"), HYBRID_FIELD_UNSUPPORTED_SUBFIELD_VALUE_TYPE( diff --git a/src/main/java/io/stargate/sgv2/jsonapi/exception/SchemaException.java b/src/main/java/io/stargate/sgv2/jsonapi/exception/SchemaException.java index ca11409130..b33553a836 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/exception/SchemaException.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/exception/SchemaException.java @@ -28,6 +28,8 
@@ public enum Code implements ErrorCode { CANNOT_DROP_VECTORIZE_FROM_UNKNOWN_COLUMNS, CANNOT_VECTORIZE_NON_VECTOR_COLUMNS, CANNOT_VECTORIZE_UNKNOWN_COLUMNS, + DEPRECATED_AI_MODEL, + END_OF_LIFE_AI_MODEL, INVALID_FORMAT_FOR_INDEX_CREATION_COLUMN, MISSING_ALTER_TABLE_OPERATIONS, MISSING_DIMENSION_IN_VECTOR_COLUMN, @@ -41,15 +43,16 @@ public enum Code implements ErrorCode { UNKNOWN_VECTOR_METRIC, UNKNOWN_VECTOR_SOURCE_MODEL, UNSUPPORTED_DATA_TYPE_TABLE_CREATION, - UNSUPPORTED_PROVIDER_MODEL, UNSUPPORTED_INDEXING_FOR_DATA_TYPES, UNSUPPORTED_INDEXING_FOR_FROZEN_COLUMN, UNSUPPORTED_INDEX_TYPE, + UNSUPPORTED_JSON_TYPE_FOR_TEXT_INDEX, UNSUPPORTED_LIST_DEFINITION, UNSUPPORTED_MAP_DEFINITION, UNSUPPORTED_SCHEMA_NAME, UNSUPPORTED_SET_DEFINITION, UNSUPPORTED_TEXT_ANALYSIS_FOR_DATA_TYPES, + UNSUPPORTED_TEXT_INDEX_FOR_DATA_TYPES, UNSUPPORTED_VECTOR_DIMENSION, UNSUPPORTED_VECTOR_INDEX_FOR_DATA_TYPES, diff --git a/src/main/java/io/stargate/sgv2/jsonapi/exception/mappers/ThrowableToErrorMapper.java b/src/main/java/io/stargate/sgv2/jsonapi/exception/mappers/ThrowableToErrorMapper.java index 980b96a61a..41686137c6 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/exception/mappers/ThrowableToErrorMapper.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/exception/mappers/ThrowableToErrorMapper.java @@ -193,6 +193,13 @@ private static CommandResult.Error handleQueryValidationException( .toApiException() .getCommandResultError(message, Response.Status.OK); } + // [data-api#2068]: Need to convert Lexical-value-too-big failure to something more meaningful + if (message.contains( + "analyzed size for column query_lexical_value exceeds the cumulative limit for index")) { + return ErrorCodeV1.LEXICAL_CONTENT_TOO_BIG + .toApiException() + .getCommandResultError(Response.Status.OK); + } return ErrorCodeV1.INVALID_QUERY .toApiException() .getCommandResultError(message, Response.Status.OK); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/metrics/CommandFeature.java 
b/src/main/java/io/stargate/sgv2/jsonapi/metrics/CommandFeature.java new file mode 100644 index 0000000000..67ca8d94f2 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/metrics/CommandFeature.java @@ -0,0 +1,36 @@ +package io.stargate.sgv2.jsonapi.metrics; + +import java.util.Objects; + +/** + * Represents distinct features that can be used in the general command. Each feature has an + * associated tag name, which can be used for metrics. + */ +public enum CommandFeature { + /** The usage of $hybrid with String in the command */ + HYBRID("feature.hybrid.string"), + /** The usage of $lexical in the command */ + LEXICAL("feature.lexical"), + /** The usage of $vector in the command */ + VECTOR("feature.vector"), + /** The usage of $vectorize in the command */ + VECTORIZE("feature.vectorize"), + + /** The usage of `hybridLimits` with Number in the command */ + HYBRID_LIMITS_NUMBER("feature.hybrid.limits.number"), + /** The usage of `hybridLimits` Object with $vector in the command */ + HYBRID_LIMITS_VECTOR("feature.hybrid.limits.vector"), + /** The usage of `hybridLimits` Object with $lexical in the command */ + HYBRID_LIMITS_LEXICAL("feature.hybrid.limits.lexical"), + ; + + private final String tagName; + + CommandFeature(String tagName) { + this.tagName = Objects.requireNonNull(tagName); + } + + public String getTagName() { + return tagName; + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/metrics/CommandFeatures.java b/src/main/java/io/stargate/sgv2/jsonapi/metrics/CommandFeatures.java new file mode 100644 index 0000000000..05b7c4f251 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/metrics/CommandFeatures.java @@ -0,0 +1,113 @@ +package io.stargate.sgv2.jsonapi.metrics; + +import io.micrometer.core.instrument.Tag; +import io.micrometer.core.instrument.Tags; +import java.util.EnumSet; +import java.util.Objects; + +/** + * Represents a collection of {@link CommandFeature}s used in a command. 
This class is mutable and + * designed to be used within a {@link io.stargate.sgv2.jsonapi.api.model.command.CommandContext} to + * accumulate features during command processing. Mutation is controlled via specific add methods. + * It uses an {@link EnumSet} internally for efficient storage and operations on commandFeatures. + */ +public final class CommandFeatures { + private final EnumSet commandFeatures; + + /** A instance representing no commandFeatures in use. */ + public static final CommandFeatures EMPTY = + new CommandFeatures(EnumSet.noneOf(CommandFeature.class)); + + /** Private constructor, use factory methods 'of' or 'create' */ + private CommandFeatures(EnumSet commandFeatures) { + this.commandFeatures = commandFeatures; + } + + /** + * Creates a new, mutable {@code CommandFeatures} instance containing no features. + * + * @return A new, empty, mutable {@code CommandFeatures} instance. + */ + public static CommandFeatures create() { + return new CommandFeatures(EnumSet.noneOf(CommandFeature.class)); + } + + /** + * Creates a {@code CommandFeatures} instance from an array of {@link CommandFeature}s. The + * returned instance will be mutable. + * + * @param initialFeatures The initial features to include. If null or empty, an empty instance is + * returned. + * @return A new {@code CommandFeatures} instance containing the specified features. + */ + public static CommandFeatures of(CommandFeature... initialFeatures) { + if (initialFeatures == null || initialFeatures.length == 0) { + return create(); + } + return new CommandFeatures(EnumSet.of(initialFeatures[0], initialFeatures)); + } + + /** Adds the specified feature to this instance. */ + public void addFeature(CommandFeature commandFeature) { + Objects.requireNonNull(commandFeature, "CommandFeature cannot be null"); + commandFeatures.add(commandFeature); + } + + /** + * Adds all features from another {@code CommandFeatures} instance to this instance. Mutates the + * current object. 
+ * + * @param other The other {@code CommandFeatures} instance whose features should be added. If null + * or empty, this instance remains unchanged. + */ + public void addAll(CommandFeatures other) { + if (other != null && !other.isEmpty()) { + commandFeatures.addAll(other.commandFeatures); + } + } + + /** + * Checks if this instance contains any features. + * + * @return {@code true} if no features are present, {@code false} otherwise. + */ + public boolean isEmpty() { + return commandFeatures.isEmpty(); + } + + /** + * Generates Micrometer Tags representing the features in this instance. Each feature is + * represented as a tag with its name and a value of {@code true}. + * + * @return A {@link Tags} object containing the tags for each feature. + */ + public Tags getTags() { + return Tags.of( + commandFeatures.stream() + .map(f -> Tag.of(f.getTagName(), String.valueOf(true))) + .toArray(Tag[]::new)); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + CommandFeatures that = (CommandFeatures) obj; + return Objects.equals(commandFeatures, that.commandFeatures); + } + + @Override + public int hashCode() { + return commandFeatures.hashCode(); + } + + @Override + public String toString() { + // CommandFeatures[features…] + return "CommandFeatures" + commandFeatures.toString(); + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/v1/metrics/JsonProcessingMetricsReporter.java b/src/main/java/io/stargate/sgv2/jsonapi/metrics/JsonProcessingMetricsReporter.java similarity index 94% rename from src/main/java/io/stargate/sgv2/jsonapi/api/v1/metrics/JsonProcessingMetricsReporter.java rename to src/main/java/io/stargate/sgv2/jsonapi/metrics/JsonProcessingMetricsReporter.java index 0e6fe36f4d..6ed7af74c8 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/v1/metrics/JsonProcessingMetricsReporter.java +++ 
b/src/main/java/io/stargate/sgv2/jsonapi/metrics/JsonProcessingMetricsReporter.java @@ -1,10 +1,12 @@ -package io.stargate.sgv2.jsonapi.api.v1.metrics; +package io.stargate.sgv2.jsonapi.metrics; import io.micrometer.core.instrument.DistributionSummary; import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tags; import io.stargate.sgv2.jsonapi.api.request.RequestContext; +import io.stargate.sgv2.jsonapi.api.v1.metrics.JsonApiMetricsConfig; +import io.stargate.sgv2.jsonapi.api.v1.metrics.MetricsConfig; import jakarta.enterprise.context.ApplicationScoped; import jakarta.inject.Inject; diff --git a/src/main/java/io/stargate/sgv2/jsonapi/metrics/MetricsConstants.java b/src/main/java/io/stargate/sgv2/jsonapi/metrics/MetricsConstants.java new file mode 100644 index 0000000000..4e9e3d6ffc --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/metrics/MetricsConstants.java @@ -0,0 +1,27 @@ +package io.stargate.sgv2.jsonapi.metrics; + +/** Defines constants for metric names and tag keys used in the Data API. */ +public interface MetricsConstants { + /** Default value used for tags when the actual value is unknown or unavailable. */ + String UNKNOWN_VALUE = "unknown"; + + /** Defines common tag keys used across various metrics. 
*/ + interface MetricTags { + String KEYSPACE_TAG = "keyspace"; + String RERANKING_PROVIDER_TAG = "reranking.provider"; + String RERANKING_MODEL_TAG = "reranking.model"; + String SESSION_TAG = "session"; + String TENANT_TAG = "tenant"; + String TABLE_TAG = "table"; + } + + /** Defines metric names that used in the DataAPI */ + interface MetricNames { + String HTTP_SERVER_REQUESTS = "http.server.requests"; + String RERANK_ALL_CALL_DURATION_METRIC = "rerank.all.call.duration"; + String RERANK_ALL_PASSAGE_COUNT_METRIC = "rerank.all.passage.count"; + String RERANK_TENANT_CALL_DURATION_METRIC = "rerank.tenant.call.duration"; + String RERANK_TENANT_PASSAGE_COUNT_METRIC = "rerank.tenant.passage.count"; + String VECTORIZE_CALL_DURATION_METRIC = "vectorize.call.duration"; + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/metrics/MetricsTenantDeactivationConsumer.java b/src/main/java/io/stargate/sgv2/jsonapi/metrics/MetricsTenantDeactivationConsumer.java new file mode 100644 index 0000000000..c7c457c0d1 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/metrics/MetricsTenantDeactivationConsumer.java @@ -0,0 +1,59 @@ +package io.stargate.sgv2.jsonapi.metrics; + +import static io.stargate.sgv2.jsonapi.metrics.MetricsConstants.MetricTags.SESSION_TAG; +import static io.stargate.sgv2.jsonapi.metrics.MetricsConstants.MetricTags.TENANT_TAG; + +import com.github.benmanes.caffeine.cache.RemovalCause; +import io.micrometer.core.instrument.Meter; +import io.micrometer.core.instrument.MeterRegistry; +import io.stargate.sgv2.jsonapi.service.cqldriver.CQLSessionCache; +import java.util.*; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A {@link CQLSessionCache.DeactivatedTenantConsumer} responsible for removing tenant-specific + * metrics from the {@link MeterRegistry} when a tenant's session is evicted from the {@link + * CQLSessionCache}. 
+ */ +public class MetricsTenantDeactivationConsumer + implements CQLSessionCache.DeactivatedTenantConsumer { + private static final Logger LOGGER = + LoggerFactory.getLogger(MetricsTenantDeactivationConsumer.class); + private final MeterRegistry meterRegistry; + + public MetricsTenantDeactivationConsumer(MeterRegistry meterRegistry) { + this.meterRegistry = Objects.requireNonNull(meterRegistry, "MeterRegistry cannot be null"); + } + + /** + * Called by {@link CQLSessionCache} when a tenant's session is removed. This method iterates + * through all registered meters in the {@link MeterRegistry} and removes any that are tagged with + * the specified {@code tenantId} using either the {@link MetricsConstants.MetricTags#TENANT_TAG} + * or {@link MetricsConstants.MetricTags#SESSION_TAG} key. + * + * @param tenantId The ID of the tenant whose session was deactivated. This value will be used to + * find metrics with a matching tag. + * @param cause The reason for the removal from the cache. + */ + @Override + public void accept(String tenantId, RemovalCause cause) { + if (tenantId == null) { + LOGGER.warn("Received null tenantId for deactivation"); + return; + } + + for (Meter meter : meterRegistry.getMeters()) { + // Check TENANT_TAG first, if not found, check SESSION_TAG + if (Objects.equals(meter.getId().getTag(TENANT_TAG), tenantId) + || Objects.equals(meter.getId().getTag(SESSION_TAG), tenantId)) { + if (meterRegistry.remove(meter.getId()) == null) { + LOGGER.debug( + "Attempted to remove metric with ID {} for tenant {} but it was not found in the registry during the removal phase.", + meter.getId(), + tenantId); + } + } + } + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/metrics/MicrometerConfiguration.java b/src/main/java/io/stargate/sgv2/jsonapi/metrics/MicrometerConfiguration.java new file mode 100644 index 0000000000..b72e48e355 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/metrics/MicrometerConfiguration.java @@ -0,0 +1,144 @@ +package 
io.stargate.sgv2.jsonapi.metrics; + +import static io.stargate.sgv2.jsonapi.metrics.MetricsConstants.MetricTags.SESSION_TAG; +import static io.stargate.sgv2.jsonapi.metrics.MetricsConstants.MetricTags.TENANT_TAG; + +import io.micrometer.core.instrument.Meter; +import io.micrometer.core.instrument.Tag; +import io.micrometer.core.instrument.Tags; +import io.micrometer.core.instrument.config.MeterFilter; +import io.micrometer.core.instrument.distribution.DistributionStatisticConfig; +import io.smallrye.config.SmallRyeConfig; +import io.stargate.sgv2.jsonapi.api.v1.metrics.MetricsConfig; +import jakarta.enterprise.inject.Produces; +import java.util.Map; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import org.eclipse.microprofile.config.ConfigProvider; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Centralized configuration of Micrometer {@link MeterFilter}s. + * + *

This class provides CDI producer methods for {@link MeterFilter} beans that: + * + *

    + *
  • Apply global tags (e.g., {@code module=sgv2-jsonapi}) to all metrics based on configuration + * provided by {@link MetricsConfig}, see {@link #globalTagsMeterFilter()}. + *
  • Configure distribution statistics percentiles for timer metrics such as HTTP server requests, see + * {@link #configureDistributionStatistics()}.
+ */ +public class MicrometerConfiguration { + private static final Logger LOGGER = LoggerFactory.getLogger(MicrometerConfiguration.class); + + /** + * Produces a meter filter that applies configured global tags (e.g., {@code module=sgv2-jsonapi}) + * to all metrics. Reads tags from {@link MetricsConfig#globalTags()}. + * + * @return A {@link MeterFilter} for applying common tags, or an no-op filter if no global tags + * are configured. + */ + @Produces + @SuppressWarnings("unused") + public MeterFilter globalTagsMeterFilter() { + MetricsConfig metricsConfig = + ConfigProvider.getConfig() + .unwrap(SmallRyeConfig.class) + .getConfigMapping(MetricsConfig.class); + + Map globalTags = metricsConfig.globalTags(); + LOGGER.info("Configuring metrics with common global tags: {}", globalTags); + + // if we have no global tags, use empty (no-op filter) + if (null == globalTags || globalTags.isEmpty()) { + return new MeterFilter() {}; + } + + var tags = + globalTags.entrySet().stream() + .map(e -> Tag.of(e.getKey(), e.getValue())) + .collect(Collectors.toList()); + + // Notes from PR 2003: This producer method globalTagsMeterFilter is typically called only once + // during the application startup phase by the Micrometer extension when it's collecting all the + // MeterFilter beans. It's not called repeatedly during the application's runtime for every + // metric being recorded. So no need the cache. + return MeterFilter.commonTags(Tags.of(tags)); + } + + /** + * Produces a meter filter to configure distribution statistics for timer metrics such as HTTP + * server requests, vectorization duration, and reranking calls. + * + *

Tests in {@link MicrometerConfigurationTests} show what is expected to output for the + * different metric types. + * + *

For all distribution metrics, we suppress the full histogram buckets to reduce overhead, and + * then we configure the percentiles based on the metric type: + * + *

    + *
  • Per tenant metrics (see {@link IsPerTenantPredicate}) have fewer percentiles because + * there will be many more metrics of these types. + *
  • Non per tenant metrics have more percentiles because they are less numerous. + *
+ * + * This is applied to all metrics, including the driver metrics. + * + * @return A {@link MeterFilter} for configuring distribution statistics. + */ + @Produces + @SuppressWarnings("unused") + public MeterFilter configureDistributionStatistics() { + + final double[] allTenantLatencyPercentiles = {0.5, 0.90, 0.95, 0.98, 0.99}; + final double[] perTenantLatencyPercentiles = {0.5, 0.98, 0.99}; + final Predicate isPerTenantPredicate = new IsPerTenantPredicate(); + + return new MeterFilter() { + @Override + public DistributionStatisticConfig configure( + Meter.Id id, DistributionStatisticConfig config) { + + var builder = DistributionStatisticConfig.builder(); + if (isPerTenantPredicate.test(id)) { + builder = builder.percentiles(perTenantLatencyPercentiles); + } else { + builder = builder.percentiles(allTenantLatencyPercentiles); + } + + // make sure we do not publish the histogram buckets for all distribution metrics + // that can be 70 lines long for a single metric, and we don't need them because we have + // calc'd + // the percentiles + builder = builder.percentilesHistogram(false); + + return builder.build().merge(config); + } + }; + } + + static class IsPerTenantPredicate implements Predicate { + + @Override + public boolean test(Meter.Id id) { + + // if the Metric has a "tenant" or "session" tag we assume it is a per-tenant metric + // the API code will use tenant, the Driver uses session + // getTags() iterates over the tags, but there will never be too many for it to be a problem + // and sanity check they are not blank strings + + var tenantTag = id.getTag(TENANT_TAG); + if (tenantTag != null && !tenantTag.isBlank()) { + return true; + } + + var sessionTag = id.getTag(SESSION_TAG); + if (sessionTag != null && !sessionTag.isBlank()) { + return true; + } + return false; + } + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/v1/metrics/TenantRequestMetricsFilter.java 
b/src/main/java/io/stargate/sgv2/jsonapi/metrics/TenantRequestMetricsFilter.java similarity index 95% rename from src/main/java/io/stargate/sgv2/jsonapi/api/v1/metrics/TenantRequestMetricsFilter.java rename to src/main/java/io/stargate/sgv2/jsonapi/metrics/TenantRequestMetricsFilter.java index d5c974a1f4..9178b7b2b1 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/v1/metrics/TenantRequestMetricsFilter.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/metrics/TenantRequestMetricsFilter.java @@ -15,16 +15,18 @@ * */ -package io.stargate.sgv2.jsonapi.api.v1.metrics; +package io.stargate.sgv2.jsonapi.metrics; import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tags; import io.stargate.sgv2.jsonapi.api.request.RequestContext; +import io.stargate.sgv2.jsonapi.api.v1.metrics.MetricsConfig; import jakarta.enterprise.context.ApplicationScoped; import jakarta.inject.Inject; import jakarta.ws.rs.container.ContainerRequestContext; import jakarta.ws.rs.container.ContainerResponseContext; +import jakarta.ws.rs.core.HttpHeaders; import java.util.regex.Pattern; import org.jboss.resteasy.reactive.server.ServerResponseFilter; @@ -113,7 +115,7 @@ public void record( } private String getUserAgentValue(ContainerRequestContext requestContext) { - String headerString = requestContext.getHeaderString("user-agent"); + String headerString = requestContext.getHeaderString(HttpHeaders.USER_AGENT); if (null != headerString && !headerString.isBlank()) { String[] split = USER_AGENT_SPLIT.split(headerString); if (split.length > 0) { diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/v1/metrics/TenantRequestMetricsTagProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/metrics/TenantRequestMetricsTagProvider.java similarity index 92% rename from src/main/java/io/stargate/sgv2/jsonapi/api/v1/metrics/TenantRequestMetricsTagProvider.java rename to 
src/main/java/io/stargate/sgv2/jsonapi/metrics/TenantRequestMetricsTagProvider.java index 677e2adf2b..34fe27deb6 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/v1/metrics/TenantRequestMetricsTagProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/metrics/TenantRequestMetricsTagProvider.java @@ -1,12 +1,14 @@ -package io.stargate.sgv2.jsonapi.api.v1.metrics; +package io.stargate.sgv2.jsonapi.metrics; import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tags; import io.quarkus.micrometer.runtime.HttpServerMetricsTagsContributor; import io.stargate.sgv2.jsonapi.api.request.RequestContext; +import io.stargate.sgv2.jsonapi.api.v1.metrics.MetricsConfig; import io.vertx.core.http.HttpServerRequest; import jakarta.enterprise.context.ApplicationScoped; import jakarta.inject.Inject; +import jakarta.ws.rs.core.HttpHeaders; import java.util.regex.Pattern; /** Tags provider for http request metrics. It provides tenant id and user agent as tags. */ @@ -64,7 +66,7 @@ public Tags contribute(Context context) { } private String getUserAgentValue(HttpServerRequest request) { - String headerString = request.getHeader("user-agent"); + String headerString = request.getHeader(HttpHeaders.USER_AGENT); if (null != headerString && !headerString.isBlank()) { String[] split = USER_AGENT_SPLIT.split(headerString); if (split.length > 0) { diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/CQLSessionCache.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/CQLSessionCache.java index 33e9d197fd..c7fd46a01b 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/CQLSessionCache.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/CQLSessionCache.java @@ -1,280 +1,494 @@ package io.stargate.sgv2.jsonapi.service.cqldriver; +import static io.stargate.sgv2.jsonapi.util.ClassUtils.classSimpleName; + import com.datastax.oss.driver.api.core.CqlSession; -import 
com.datastax.oss.driver.api.core.config.DefaultDriverOption; -import com.datastax.oss.driver.api.core.config.DriverConfigLoader; -import com.github.benmanes.caffeine.cache.Caffeine; -import com.github.benmanes.caffeine.cache.LoadingCache; -import com.github.benmanes.caffeine.cache.RemovalListener; +import com.github.benmanes.caffeine.cache.*; +import com.google.common.annotations.VisibleForTesting; import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.core.instrument.binder.cache.CaffeineCacheMetrics; -import io.quarkus.security.UnauthorizedException; -import io.stargate.sgv2.jsonapi.JsonApiStartUp; import io.stargate.sgv2.jsonapi.api.request.RequestContext; -import io.stargate.sgv2.jsonapi.config.OperationsConfig; -import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; -import io.stargate.sgv2.jsonapi.service.cqldriver.executor.SchemaCache; -import io.stargate.sgv2.jsonapi.service.cqldriver.executor.optvector.SubtypeOnlyFloatVectorToArrayCodec; -import jakarta.enterprise.context.ApplicationScoped; -import jakarta.inject.Inject; -import java.net.InetSocketAddress; +import io.stargate.sgv2.jsonapi.config.DatabaseType; import java.time.Duration; +import java.util.List; import java.util.Objects; -import org.eclipse.microprofile.config.inject.ConfigProperty; +import java.util.Optional; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.function.Function; +import javax.annotation.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** - * CQL session cache to reuse the session for the same tenant and token. The cache is configured to - * expire after CACHE_TTL_SECONDS of inactivity and to have a maximum size of - * CACHE_TTL_SECONDS sessions. + * A cache for managing and reusing {@link CqlSession} instances based on tenant and authentication + * credentials. + * + *

Sessions are cached based on the tenantId and authentication token. So that a single tenant + * may have multiple sessions, but a single session is used for the same tenant and auth token. + * + *

Create instances using the {@link CqlSessionCacheSupplier} class. + * + *

Call {@link #getSession(RequestContext)} and overloads to get a session for the current + * request context. + * + *

The {@link DeactivatedTenantConsumer} interface will be called when a session is removed from + * the cache, so that schema cache and metrics can be updated to remove the tenant. NOTE: this is + * called when the session expires, but a single tenant may have multiple sessions (based on key + * above), so it is not a guarantee that the tenant is not active with another set of credentials. + * If you take action to remove a deactivated tenant, there should be a path for the tenant to be + * reactivated. + * + *

NOTE: There is no method to get the size of the cache because it is not a reliable + * measure, it's only an estimate. We can assume the size feature works. For testing use {@link + * #peekSession(String, String, String)} */ -@ApplicationScoped public class CQLSessionCache { - private static final Logger LOGGER = LoggerFactory.getLogger(JsonApiStartUp.class); - - /** Configuration for the JSON API operations. */ - private final OperationsConfig operationsConfig; + private static final Logger LOGGER = LoggerFactory.getLogger(CQLSessionCache.class); /** * Default tenant to be used when the backend is OSS cassandra and when no tenant is passed in the * request */ - private static final String DEFAULT_TENANT = "default_tenant"; - - /** CQLSession cache. */ - private final LoadingCache sessionCache; - - /** SchemaCache, used for evict collectionSetting cache and namespace cache. */ - @Inject private SchemaCache schemaCache; - - /** Database type Astra */ - public static final String ASTRA = "astra"; - - /** Database type OSS cassandra */ - public static final String CASSANDRA = "cassandra"; + public static final String DEFAULT_TENANT = "default_tenant"; - /** Persistence type SSTable Writer */ - public static final String OFFLINE_WRITER = "offline_writer"; + private final DatabaseType databaseType; + private final Duration cacheTTL; + private final String slaUserAgent; + private final Duration slaUserTTL; - @ConfigProperty(name = "quarkus.application.name") - String APPLICATION_NAME; + private final LoadingCache sessionCache; + private final CqlCredentialsFactory credentialsFactory; + private final SessionFactory sessionFactory; - @Inject - public CQLSessionCache(OperationsConfig operationsConfig, MeterRegistry meterRegistry) { + private List deactivatedTenantConsumers; - LOGGER.info("Initializing CQLSessionCache"); - this.operationsConfig = operationsConfig; - - LoadingCache loadingCache = - Caffeine.newBuilder() - .expireAfterAccess( - 
Duration.ofSeconds(operationsConfig.databaseConfig().sessionCacheTtlSeconds())) - .maximumSize(operationsConfig.databaseConfig().sessionCacheMaxSize()) - // removal listener is invoked after the entry has been removed from the cache. So the - // idea is that we no longer return this session for any lookup as a first step, then - // close the session in the background asynchronously which is a graceful closing of - // channels i.e. any in-flight query will be completed before the session is getting - // closed. - .removalListener( - (RemovalListener) - (sessionCacheKey, session, cause) -> { - if (sessionCacheKey != null) { - if (LOGGER.isTraceEnabled()) { - LOGGER.trace( - "Removing session for tenant : {}", sessionCacheKey.tenantId()); - } - if (this.schemaCache != null && session != null) { - // When a sessionCache entry expires - // Evict all corresponding entire NamespaceCaches for the tenant - // This is to ensure there is no offset for sessionCache and schemaCache - schemaCache.evictNamespaceCacheEntriesForTenant( - sessionCacheKey.tenantId(), session.getMetadata().getKeyspaces()); - } - } - if (session != null) { - session.close(); - } - }) - .recordStats() - .build(this::getNewSession); - this.sessionCache = - CaffeineCacheMetrics.monitor(meterRegistry, loadingCache, "cql_sessions_cache"); - LOGGER.info( - "CQLSessionCache initialized with ttl of {} seconds and max size of {}", - operationsConfig.databaseConfig().sessionCacheTtlSeconds(), - operationsConfig.databaseConfig().sessionCacheMaxSize()); + /** + * Constructs a new instance of the {@link CQLSessionCache}. + * + *

Use this overload in production code, see other for detailed description of the parameters. + */ + public CQLSessionCache( + DatabaseType databaseType, + Duration cacheTTL, + long cacheMaxSize, + String slaUserAgent, + Duration slaUserTTL, + CqlCredentialsFactory credentialsFactory, + SessionFactory sessionFactory, + MeterRegistry meterRegistry, + List deactivatedTenantConsumer) { + this( + databaseType, + cacheTTL, + cacheMaxSize, + slaUserAgent, + slaUserTTL, + credentialsFactory, + sessionFactory, + meterRegistry, + deactivatedTenantConsumer, + false, + null); } /** - * Loader for new CQLSession. + * Constructs a new instance of the {@link CQLSessionCache}. * - * @return CQLSession - * @throws RuntimeException if database type is not supported + *

Use this ctor for testing only. + * + * @param databaseType The type of database being used, + * @param cacheTTL The time-to-live (TTL) duration for cache entries. + * @param cacheMaxSize The maximum size of the cache. + * @param credentialsFactory A factory for creating {@link CqlCredentials} based on authentication + * tokens. + * @param sessionFactory A factory for creating new {@link CqlSession} instances when needed. + * @param meterRegistry The {@link MeterRegistry} for monitoring cache metrics. + * @param deactivatedTenantConsumer A list of consumers to handle tenant deactivation events. + * @param asyncTaskOnCaller If true, asynchronous tasks (e.g., callbacks) will run on the caller + * thread. This is intended for testing purposes only. DO NOT USE in production. + * @param cacheTicker If non-null, this is the ticker used by the cache to decide when to expire + * entries. If null, the default ticker is used. DO NOT USE in production. */ - private CqlSession getNewSession(SessionCacheKey cacheKey) { - - // TODO: WHY IS DriverConfigLoader USED ? 
- DriverConfigLoader loader = - DriverConfigLoader.programmaticBuilder() - .withString(DefaultDriverOption.SESSION_NAME, cacheKey.tenantId) - .build(); - - var databaseConfig = operationsConfig.databaseConfig(); - if (LOGGER.isTraceEnabled()) { - LOGGER.trace( - "Creating new session tenantId={} and databaseType={}", - cacheKey.tenantId(), - databaseConfig.type()); + CQLSessionCache( + DatabaseType databaseType, + Duration cacheTTL, + long cacheMaxSize, + String slaUserAgent, + Duration slaUserTTL, + CqlCredentialsFactory credentialsFactory, + SessionFactory sessionFactory, + MeterRegistry meterRegistry, + List deactivatedTenantConsumer, + boolean asyncTaskOnCaller, + Ticker cacheTicker) { + + this.databaseType = Objects.requireNonNull(databaseType, "databaseType must not be null"); + this.cacheTTL = Objects.requireNonNull(cacheTTL, "cacheTTL must not be null"); + // a null or blank slaUserAgent is stored as null; slaUserTTL is only required when it is set + this.slaUserAgent = slaUserAgent == null || slaUserAgent.isBlank() ? null : slaUserAgent; + if (this.slaUserAgent != null) { + this.slaUserTTL = + Objects.requireNonNull(slaUserTTL, "slaUserTTL must not be null is slaUserAgent is set"); + } else { + this.slaUserTTL = null; } - // there is a lot of common setup regardless of the database type + this.credentialsFactory = + Objects.requireNonNull(credentialsFactory, "credentialsFactory must not be null"); + this.sessionFactory = Objects.requireNonNull(sessionFactory, "sessionFactory must not be null"); + this.deactivatedTenantConsumers = + deactivatedTenantConsumer == null ?
List.of() : List.copyOf(deactivatedTenantConsumer); + + LOGGER.info( + "Initializing CQLSessionCache with cacheTTL={}, cacheMaxSize={}, databaseType={}, slaUserAgent={}, slaUserTTL={}, deactivatedTenantConsumers.count={}", + cacheTTL, + cacheMaxSize, + databaseType, + slaUserAgent, + slaUserTTL, + deactivatedTenantConsumers.size()); + var builder = - new TenantAwareCqlSessionBuilder(cacheKey.tenantId()) - .withLocalDatacenter(operationsConfig.databaseConfig().localDatacenter()) - .withClassLoader(Thread.currentThread().getContextClassLoader()) - .withConfigLoader(loader) - .addSchemaChangeListener(new SchemaChangeListener(schemaCache, cacheKey.tenantId)) - .withApplicationName(APPLICATION_NAME); - cacheKey.credentials().addToSessionBuilder(builder); - - if (databaseConfig.type().equals(CASSANDRA)) { - var seeds = - Objects.requireNonNull(operationsConfig.databaseConfig().cassandraEndPoints()).stream() - .map( - host -> - new InetSocketAddress( - host, operationsConfig.databaseConfig().cassandraPort())) - .toList(); - builder.addContactPoints(seeds); + Caffeine.newBuilder() + .expireAfter(new SessionExpiry()) + .maximumSize(cacheMaxSize) + .removalListener(this::onKeyRemoved) + .recordStats(); + + if (asyncTaskOnCaller) { + LOGGER.warn( + "CQLSessionCache CONFIGURED TO RUN ASYNC TASKS SUCH AS CALLBACKS ON THE CALLER THREAD, DO NOT USE IN PRODUCTION."); + builder = builder.executor(Runnable::run); } + if (cacheTicker != null) { + LOGGER.warn("CQLSessionCache CONFIGURED TO USE A CUSTOM TICKER, DO NOT USE IN PRODUCTION."); + builder = builder.ticker(cacheTicker); + } + LoadingCache loadingCache = + builder.build(this::onLoadSession); - // Add optimized CqlVector codec (see [data-api#1775]) - builder = builder.addTypeCodecs(SubtypeOnlyFloatVectorToArrayCodec.instance()); - - // aaron - this used to have an if / else that threw an exception if the database type was not - // known but we test that when creating the credentials for the cache key so no need to do it - // 
here. - return builder.build(); + this.sessionCache = + CaffeineCacheMetrics.monitor(meterRegistry, loadingCache, "cql_sessions_cache"); } /** - * Get CQLSession from cache. + * Gets or creates a {@link CqlSession} for the provided request context * - * @return CQLSession + * @param requestContext {@link RequestContext} to get the session for. + * @return {@link CqlSession} for this tenant and credentials. */ - public CqlSession getSession(RequestContext dataApiRequestInfo) { + public CqlSession getSession(RequestContext requestContext) { + Objects.requireNonNull(requestContext, "requestContext must not be null"); // Validation happens when creating the credentials and session key return getSession( - dataApiRequestInfo.getTenantId().orElse(""), - dataApiRequestInfo.getCassandraToken().orElse("")); + requestContext.getTenantId().orElse(""), + requestContext.getCassandraToken().orElse(""), + requestContext.getUserAgent().orElse(null)); } - public CqlSession getSession(String tenantId, String authToken) { - String fixedToken = getFixedToken(); - if (fixedToken != null && !authToken.equals(fixedToken)) { - throw new UnauthorizedException(ErrorCodeV1.UNAUTHENTICATED_REQUEST.getMessage()); - } + /** + * Retrieves or creates a {@link CqlSession} for the specified tenant and authentication token. + * + *

If the database type is OFFLINE_WRITER, this method will attempt to retrieve the session + * from the cache without creating a new session if it is not present. For other database types, a + * new session will be created if it is not already cached. + * + * @param tenantId the identifier for the tenant + * @param authToken the authentication token for accessing the session + * @param userAgent Nullable user agent, if matching the configured SLA checker user agent then + * the session will use the TTL for the SLA user. + * @return a {@link CqlSession} associated with the given tenantId and authToken + */ + public CqlSession getSession(String tenantId, String authToken, String userAgent) { - var cacheKey = getSessionCacheKey(tenantId, authToken); + var cacheKey = createCacheKey(tenantId, authToken, userAgent); // TODO: why is this different for OFFLINE ? - if (OFFLINE_WRITER.equals(operationsConfig.databaseConfig().type())) { - return sessionCache.getIfPresent(cacheKey); - } - return sessionCache.get(cacheKey); + var holder = + switch (databaseType) { + case OFFLINE_WRITER -> sessionCache.getIfPresent(cacheKey); + default -> sessionCache.get(cacheKey); + }; + return holder == null ? null : holder.session(); } /** - * Default token which will be used by the integration tests. If this property is set, then the - * token from the request will be compared with this to perform authentication. + * For testing, peek into the cache to see if a session is present for the given tenantId, + * authToken, and userAgent. 
*/ - private String getFixedToken() { - return operationsConfig.databaseConfig().fixedToken().orElse(null); + @VisibleForTesting + protected Optional peekSession(String tenantId, String authToken, String userAgent) { + var cacheKey = createCacheKey(tenantId, authToken, userAgent); + return Optional.ofNullable(sessionCache.getIfPresent(cacheKey)) + .map(SessionValueHolder::session); } - /** - * Build key for CQLSession cache from tenant and token if the database type is AstraDB or from - * tenant, username and password if the database type is OSS cassandra (also, if token is present - * in the request, that will be given priority for the cache key). - * - * @return key for CQLSession cache - */ - private SessionCacheKey getSessionCacheKey(String tenantId, String authToken) { - var databaseConfig = operationsConfig.databaseConfig(); - - // NOTE: this has changed, will create the UsernamePasswordCredentials from the token if that is - // the token - var credentials = - CqlCredentials.create( - getFixedToken(), authToken, databaseConfig.userName(), databaseConfig.password()); - - // Only the OFFLINE_WRITER allows anonymous access, because it is not connecting to an actual - // database - if (credentials.isAnonymous() - && !OFFLINE_WRITER.equals(operationsConfig.databaseConfig().type())) { - throw ErrorCodeV1.SERVER_INTERNAL_ERROR.toApiException( - "Missing/Invalid authentication credentials provided for type: %s", - operationsConfig.databaseConfig().type()); + /** Process a key being removed from the cache for any reason. 
*/ + private void onKeyRemoved( + SessionCacheKey cacheKey, SessionValueHolder sessionHolder, RemovalCause cause) { + + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("onKeyRemoved for sessionHolder={}, cause={}", sessionHolder, cause); + } + + deactivatedTenantConsumers.forEach( + consumer -> { + try { + consumer.accept(cacheKey.tenantId(), cause); + } catch (Exception e) { + LOGGER.warn( + "Error calling deactivated tenant consumer: sessionHolder={}, cause={}, consumer.class={}", + sessionHolder, + cause, + classSimpleName(consumer.getClass()), + e); + } + }); + + // we need to manually close the session, the cache will not close it for us. + if (sessionHolder != null) { + // This will be running on a cache tread, any error will not make it to the user + // So we log it and swallow + try { + sessionHolder.session.close(); + } catch (Exception e) { + LOGGER.error("Error closing CQLSession sessionHolder={}", sessionHolder, e); + } + + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Closed CQL Session sessionHolder={}", sessionHolder); + } + + } else if (LOGGER.isWarnEnabled()) { + LOGGER.warn("CQL Session was null when removing from cache, cacheKey={}", cacheKey); + } + } + + /** Callback to create a new session when one is needed for the cache */ + private SessionValueHolder onLoadSession(SessionCacheKey cacheKey) { + + // factory will do some logging + var holder = + new SessionValueHolder( + sessionFactory.apply(cacheKey.tenantId(), cacheKey.credentials()), cacheKey); + + if (LOGGER.isDebugEnabled()) { + // so we get the identity hash code of the session holder + LOGGER.debug("Loaded CQLSession into cache, holder={}", holder); + } + return holder; + } + + /** Builds the cache key to use for the supplied tenant and authentication token. 
*/ + private SessionCacheKey createCacheKey(String tenantId, String authToken, String userAgent) { + + var credentials = credentialsFactory.apply(authToken); + if (credentials == null) { + // sanity check + throw new IllegalStateException("credentialsFactory returned null"); } - return switch (operationsConfig.databaseConfig().type()) { + // userAgent arg can be null + // slaUserAgent is not lower-cased, the comparison below is case-insensitive + var keyTTL = + slaUserAgent == null || !slaUserAgent.equalsIgnoreCase(userAgent) ? cacheTTL : slaUserTTL; + + return switch (databaseType) { case CASSANDRA, OFFLINE_WRITER -> new SessionCacheKey( - tenantId == null || tenantId.isBlank() ? DEFAULT_TENANT : tenantId, credentials); - case ASTRA -> new SessionCacheKey(tenantId, credentials); - default -> - throw new IllegalStateException( - "Unknown databaseConfig().type(): " + operationsConfig.databaseConfig().type()); + tenantId == null || tenantId.isBlank() ? DEFAULT_TENANT : tenantId, + credentials, + keyTTL, + userAgent); + case ASTRA -> new SessionCacheKey(tenantId, credentials, keyTTL, userAgent); }; } /** - * Get cache size. + * Invalidate all entries and cleanup, for testing when items are invalidated. * - * @return cache size + *

Note: Removal cause will be {@link RemovalCause#EXPLICIT} for items. */ - public long cacheSize() { + @VisibleForTesting + void clearCache() { + LOGGER.info("Manually clearing CQLSession cache"); + sessionCache.invalidateAll(); sessionCache.cleanUp(); - return sessionCache.estimatedSize(); } /** - * Remove CQLSession from cache. - * - * @param cacheKey key for CQLSession cache + * Clean up the cache, for testing when items are invalidated. The cache will try to be lazy, so + * things like evictions may not happen exactly at the TTL, this is a way to force it. */ - public void removeSession(SessionCacheKey cacheKey) { - sessionCache.invalidate(cacheKey); + @VisibleForTesting + void cleanUp() { sessionCache.cleanUp(); - LOGGER.trace("Session removed for tenant : {}", cacheKey.tenantId()); + } + + /** Key for CQLSession cache. */ + static class SessionCacheKey { + + private final String tenantId; + private final CqlCredentials credentials; + private final Duration ttl; + // user agent only added for logging and debugging + @Nullable private final String userAgent; + + /** + * Creates a new instance of {@link SessionCacheKey}. + * + * @param tenantId The identifier for the tenant. + * @param credentials The credentials used for authentication. + * @param ttl The time-to-live (TTL) duration for the cache entry. Note: This is NOT used in the + * value quality of the cache key, it is set so dynamic TTL can be used per key. + * @param userAgent Optional user agent for the request, not used in the equality of the key + * just for logging. 
+ */ + SessionCacheKey( + String tenantId, CqlCredentials credentials, Duration ttl, @Nullable String userAgent) { + if (tenantId == null || tenantId.isBlank()) { + tenantId = ""; + } + this.tenantId = tenantId; + this.credentials = Objects.requireNonNull(credentials, "credentials must not be null"); + this.ttl = Objects.requireNonNull(ttl, "ttl must not be null"); + this.userAgent = userAgent; + } + + public String tenantId() { + return tenantId; + } + + public CqlCredentials credentials() { + return credentials; + } + + public Duration ttl() { + return ttl; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof SessionCacheKey)) { + return false; + } + SessionCacheKey that = (SessionCacheKey) o; + return tenantId.equals(that.tenantId) && credentials.equals(that.credentials); + } + + @Override + public int hashCode() { + return Objects.hash(tenantId, credentials); + } + + @Override + public String toString() { + return new StringBuilder("SessionCacheKey{") + .append("tenantId='") + .append(tenantId) + .append('\'') + .append(", credentials=") + .append(credentials) // creds should make sure they do not include sensitive info + .append(", ttl=") + .append(ttl) + .append(", userAgent='") + .append(userAgent) + .append("'}") + .toString(); + } } /** - * Put CQLSession in cache. - * - * @param sessionCacheKey key for CQLSession cache - * @param cqlSession CQLSession instance + * Holder for the value added to the cache to make it very clear what key was used when it was + * added so we can get the TTL used when it was loaded.
Used for dynamic TTL in the Expiry class */ - public void putSession(SessionCacheKey sessionCacheKey, CqlSession cqlSession) { - sessionCache.put(sessionCacheKey, cqlSession); + record SessionValueHolder(CqlSession session, SessionCacheKey loadingKey) { + + SessionValueHolder { + Objects.requireNonNull(session, "session must not be null"); + Objects.requireNonNull(loadingKey, "loadingKey must not be null"); + } + + /** + * Note that the cache can decide when it wants to actually remove an expired key, so we may be + * closing a session for a tenant at the same time we are opening one. The {@link #toString()} + * includes the identity hash code to help with debugging. + */ + @Override + public String toString() { + return new StringBuilder("SessionValueHolder{") + .append("identityHashCode=") + .append(System.identityHashCode(this)) + .append(", loadingKey=") + .append(loadingKey) + .append('}') + .toString(); + } } /** - * Key for CQLSession cache. + * Dynamic cache TTL for the session cache. * - *

+ *

We use the maximum TTL between either the key that was used to load the session, or the + * current key being used to access it. The TTL is set when the key is created based on the user + * agent coming in. So if an SLA user agent adds it and a non-SLA user agent then uses it, the + * non-SLA user agent TTL will be used if it is longer. * - * @param tenantId optional tenantId, if null converted to empty string - * @param credentials Required, credentials for the session + *

The last user to access the session will set the TTL for the session if their TTL is + * higher. */ - public record SessionCacheKey(String tenantId, CqlCredentials credentials) { + static class SessionExpiry implements Expiry { - public SessionCacheKey { - if (tenantId == null) { - tenantId = ""; + private static final Logger LOGGER = LoggerFactory.getLogger(SessionExpiry.class); + + @Override + public long expireAfterCreate(SessionCacheKey key, SessionValueHolder value, long currentTime) { + return value.loadingKey().ttl().toNanos(); + } + + @Override + public long expireAfterUpdate( + SessionCacheKey key, SessionValueHolder value, long currentTime, long currentDuration) { + return currentDuration; + } + + @Override + public long expireAfterRead( + SessionCacheKey key, SessionValueHolder value, long currentTime, long currentDuration) { + long accessTTL = key.ttl().toNanos(); + long loadTTL = value.loadingKey().ttl().toNanos(); + if (LOGGER.isTraceEnabled()) { + LOGGER.trace( + "expireAfterRead() - key.tenant={}, key.ttl={}, key.identityHashCode={}, value.loadingKey.ttl={}, value.loadingKey.identityHashCode={}", + key.tenantId(), + key.ttl(), + System.identityHashCode(key), + value.loadingKey.ttl(), + System.identityHashCode(value.loadingKey)); } - Objects.requireNonNull(credentials, "credentials must not be null"); + return Math.max(accessTTL, loadTTL); } } + + /** Callback when a tenant is deactivated. */ + @FunctionalInterface + public interface DeactivatedTenantConsumer extends BiConsumer { + void accept(String tenantId, RemovalCause cause); + } + + /** Called to create credentials used with the session and session cache key. */ + @FunctionalInterface + public interface CredentialsFactory extends Function { + CqlCredentials apply(String authToken); + } + + /** Called to create a new session when one is needed.
*/ + @FunctionalInterface + public interface SessionFactory extends BiFunction { + CqlSession apply(String tenantId, CqlCredentials credentials); + } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/CqlCredentials.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/CqlCredentials.java index c5a4f13f6d..08646aec2f 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/CqlCredentials.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/CqlCredentials.java @@ -2,67 +2,38 @@ import com.datastax.oss.driver.api.core.CqlSessionBuilder; import io.quarkus.security.UnauthorizedException; -import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import java.util.Base64; +import java.util.Objects; /** * Interface for what it means to have credentials for the CQL driver * - *

Create instances using the {@link #create(String, String, String, String)} factory method, - * this will return the correct implementation based on the provided tokens. + *

Create instances using the {@link CqlCredentialsFactory} class. * *

NOTE: Implementations should be immutable, and support comparison and hashing because * they are used as part of the Session cache key. The initial ones use records for these reasons. */ public interface CqlCredentials { - String USERNAME_PASSWORD_PREFIX = "Cassandra:"; - /** - * Factory method to create the correct CqlCredentials based on the provided tokens. - * - * @param fixedToken the "fixed token" from configuration, e.g. from - * operationsConfig.databaseConfig().fixedToken() is passed in to make testing easier. - * @param authToken the token provided in the request, e.g. from the Authorization / Token header - * @param fallbackUsername the username to use if the fixedToken is set, this is from config - * usually - * @param fallbackPassword the password to use if the fixedToken is set, this is from config - * usually - * @return the correct CqlCredentials implementation based on the provided tokens + * Prefix for an auth token that is a username and password. See {@link + * UsernamePasswordCredentials#fromToken(String)} for the format. */ - static CqlCredentials create( - String fixedToken, String authToken, String fallbackUsername, String fallbackPassword) { - - // This used to be in CqlSessionCache.getSession(), the fixedToken config is used in testing and - // the API - // checks the provided authToken is the same as the configured fixedToken. 
- if (fixedToken != null && !fixedToken.equals(authToken)) { - throw new UnauthorizedException(ErrorCodeV1.UNAUTHENTICATED_REQUEST.getMessage()); - } - - // Also from CqlSessionCache.getNewSession(), if the fixedToken is set, then we always use the - // configured / fallback username and password - if (fixedToken != null) { - return new UsernamePasswordCredentials(fallbackUsername, fallbackPassword); - } - - return switch (authToken) { - case null -> new AnonymousCredentials(); - case "" -> new AnonymousCredentials(); - case String t when t.startsWith(USERNAME_PASSWORD_PREFIX) -> - UsernamePasswordCredentials.fromToken(t); - default -> new TokenCredentials(authToken); - }; - } + String USERNAME_PASSWORD_TOKEN_PREFIX = "Cassandra:"; /** If the credentials are anonymous, i.e. there is no auth token or username/password. */ default boolean isAnonymous() { return false; } - /** Add the credentials to the provided CqlSessionBuilder so it can login appropriately. */ - void addToSessionBuilder(CqlSessionBuilder builder); + /** Add the credentials to the provided CqlSessionBuilder so it can log in appropriately. */ + CqlSessionBuilder addToSessionBuilder(CqlSessionBuilder builder); + /** + * Credentials when the user has not provided an auth token. + * + *

-- + */ record AnonymousCredentials() implements CqlCredentials { @Override @@ -71,8 +42,14 @@ public boolean isAnonymous() { } @Override - public void addToSessionBuilder(CqlSessionBuilder builder) { - // Do nothing + public CqlSessionBuilder addToSessionBuilder(CqlSessionBuilder builder) { + // do nothing, there is no auth token + return builder; + } + + @Override + public String toString() { + return "AnonymousCredentials{isAnonymous=true}"; } } @@ -83,9 +60,6 @@ public void addToSessionBuilder(CqlSessionBuilder builder) { */ record TokenCredentials(String token) implements CqlCredentials { - /** CQL username to be used when using the auth token as the credentials */ - private static final String USERNAME_TOKEN = "token"; - public TokenCredentials { if (token == null || token.isBlank()) { throw new IllegalArgumentException("token must not be null or blank"); @@ -93,8 +67,21 @@ record TokenCredentials(String token) implements CqlCredentials { } @Override - public void addToSessionBuilder(CqlSessionBuilder builder) { - builder.withAuthCredentials(USERNAME_TOKEN, token); + public CqlSessionBuilder addToSessionBuilder(CqlSessionBuilder builder) { + return builder.withAuthCredentials("token", token); + } + + @Override + public String toString() { + // Don't log the full token, just the first 4 chars + return new StringBuilder("TokenCredentials{") + .append("token='") + .append(token.substring(0, 4)) + .append("...'") + .append(", isAnonymous=") + .append(isAnonymous()) + .append('}') + .toString(); } } @@ -110,20 +97,29 @@ public void addToSessionBuilder(CqlSessionBuilder builder) { record UsernamePasswordCredentials(String userName, String password) implements CqlCredentials { public UsernamePasswordCredentials { - if (userName == null || userName.isBlank()) { - throw new IllegalArgumentException("userName must not be null or blank"); - } - if (password == null || password.isBlank()) { - throw new IllegalArgumentException("password must not be null or blank"); - } + 
// allow empty string, up to DB to validate + Objects.requireNonNull(userName, "userName must not be null"); + Objects.requireNonNull(password, "password must not be null"); + } + + @Override + public CqlSessionBuilder addToSessionBuilder(CqlSessionBuilder builder) { + return builder.withAuthCredentials(userName, password); } @Override - public void addToSessionBuilder(CqlSessionBuilder builder) { - builder.withAuthCredentials(userName, password); + public String toString() { + // Don't include any username or password + return new StringBuilder("UsernamePasswordCredentials{") + .append("userName='REDACTED'") + .append(", password='REDACTED'") + .append(", isAnonymous=") + .append(isAnonymous()) + .append('}') + .toString(); } - public static UsernamePasswordCredentials fromToken(String encodedCredentials) { + static UsernamePasswordCredentials fromToken(String encodedCredentials) { String[] parts = encodedCredentials.split(":"); if (parts.length != 3) { throw new UnauthorizedException( diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/CqlCredentialsFactory.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/CqlCredentialsFactory.java new file mode 100644 index 0000000000..ab6caac1f9 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/CqlCredentialsFactory.java @@ -0,0 +1,76 @@ +package io.stargate.sgv2.jsonapi.service.cqldriver; + +import io.quarkus.security.UnauthorizedException; +import io.stargate.sgv2.jsonapi.config.DatabaseType; +import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; +import java.util.Objects; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Factory to create the CqlCredentials based on the provided tokens. 
*/ +public class CqlCredentialsFactory implements CQLSessionCache.CredentialsFactory { + private static final Logger LOGGER = LoggerFactory.getLogger(CqlCredentialsFactory.class); + + private final DatabaseType databaseType; + private final String fixedToken; + private final String fixedUserName; + private final String fixedPassword; + + /** + * Constructor for the CqlCredentialsFactory. + * + * @param databaseType the type of database, used to determine if anonymous access is allowed. + * @param fixedToken the "fixed token" from configuration, e.g. from + * operationsConfig.databaseConfig().fixedToken() is passed in to make + * testing easier. When not null, this token is used to validate the authToken provided in the + * request. + * @param fixedUserName the username to use if the fixedToken is set, this is from config usually + * @param fixedPassword the password to use if the fixedToken is set, this is from config usually + */ + public CqlCredentialsFactory( + DatabaseType databaseType, String fixedToken, String fixedUserName, String fixedPassword) { + + this.databaseType = Objects.requireNonNull(databaseType, "databaseType must not be null"); + this.fixedToken = fixedToken; + this.fixedUserName = fixedUserName; + this.fixedPassword = fixedPassword; + + if (fixedToken != null && LOGGER.isWarnEnabled()) { + LOGGER.warn("Fixed token is set, all tokens will be validated against this token."); + } + } + + /** Create the CqlCredentials based on the provided authToken. */ + @Override + public CqlCredentials apply(String authToken) { + + if (fixedToken != null) { + // The fixedToken config is used for testing and for the API to verify that + // the provided authToken matches the configured fixedToken. + // (Previously part of CqlSessionCache.getSession()) + if (!fixedToken.equals(authToken)) { + throw new UnauthorizedException(ErrorCodeV1.UNAUTHENTICATED_REQUEST.getMessage()); + } + // If a fixedToken is configured, always use the fallback username and password. 
+ // (Logic originally from CqlSessionCache.getNewSession()) + return new CqlCredentials.UsernamePasswordCredentials(fixedUserName, fixedPassword); + } + + var credentials = + switch (authToken) { + case null -> new CqlCredentials.AnonymousCredentials(); + case "" -> new CqlCredentials.AnonymousCredentials(); + case String t when t.startsWith(CqlCredentials.USERNAME_PASSWORD_TOKEN_PREFIX) -> + CqlCredentials.UsernamePasswordCredentials.fromToken(t); + default -> new CqlCredentials.TokenCredentials(authToken); + }; + + // Only the OFFLINE_WRITER allows anonymous access, because it is not connecting to an actual + // database + if (credentials.isAnonymous() && databaseType != DatabaseType.OFFLINE_WRITER) { + throw ErrorCodeV1.SERVER_INTERNAL_ERROR.toApiException( + "Missing/Invalid authentication credentials provided for type: %s", databaseType); + } + return credentials; + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/CqlSessionCacheSupplier.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/CqlSessionCacheSupplier.java new file mode 100644 index 0000000000..077273479c --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/CqlSessionCacheSupplier.java @@ -0,0 +1,76 @@ +package io.stargate.sgv2.jsonapi.service.cqldriver; + +import io.micrometer.core.instrument.MeterRegistry; +import io.stargate.sgv2.jsonapi.config.OperationsConfig; +import io.stargate.sgv2.jsonapi.metrics.MetricsTenantDeactivationConsumer; +import io.stargate.sgv2.jsonapi.service.cqldriver.executor.SchemaCache; +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; +import java.time.Duration; +import java.util.List; +import java.util.Objects; +import java.util.function.Supplier; +import org.eclipse.microprofile.config.inject.ConfigProperty; + +/** + * Factory for creating a singleton {@link CQLSessionCache} instance that is configured via CDI. + * + *

We use this factory so the cache itself is not a CDI bean. For one, so it does not have all + * that extra overhead, and also because the construction is a bit more complicated. + */ +@ApplicationScoped +public class CqlSessionCacheSupplier implements Supplier { + + private final CQLSessionCache singleton; + + @Inject + public CqlSessionCacheSupplier( + @ConfigProperty(name = "quarkus.application.name") String applicationName, + OperationsConfig operationsConfig, + MeterRegistry meterRegistry, + SchemaCache schemaCache) { + + Objects.requireNonNull(applicationName, "applicationName must not be null"); + Objects.requireNonNull(operationsConfig, "operationsConfig must not be null"); + Objects.requireNonNull(meterRegistry, "meterRegistry must not be null"); + Objects.requireNonNull(schemaCache, "schemaCache must not be null"); + + var dbConfig = operationsConfig.databaseConfig(); + + var credentialsFactory = + new CqlCredentialsFactory( + dbConfig.type(), + dbConfig.fixedToken().orElse(null), + dbConfig.userName(), + dbConfig.password()); + + var sessionFactory = + new CqlSessionFactory( + applicationName, + dbConfig.type(), + dbConfig.localDatacenter(), + dbConfig.cassandraEndPoints(), + dbConfig.cassandraPort(), + List.of(schemaCache.getSchemaChangeListener())); + + singleton = + new CQLSessionCache( + dbConfig.type(), + Duration.ofSeconds(dbConfig.sessionCacheTtlSeconds()), + dbConfig.sessionCacheMaxSize(), + operationsConfig.slaUserAgent().orElse(null), + Duration.ofSeconds(dbConfig.slaSessionCacheTtlSeconds()), + credentialsFactory, + sessionFactory, + meterRegistry, + List.of( + schemaCache.getDeactivatedTenantConsumer(), + new MetricsTenantDeactivationConsumer(meterRegistry))); + } + + /** Gets the singleton instance of the {@link CQLSessionCache}. 
*/ + @Override + public CQLSessionCache get() { + return singleton; + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/CqlSessionFactory.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/CqlSessionFactory.java new file mode 100644 index 0000000000..9baa701918 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/CqlSessionFactory.java @@ -0,0 +1,169 @@ +package io.stargate.sgv2.jsonapi.service.cqldriver; + +import com.datastax.oss.driver.api.core.CqlSession; +import com.datastax.oss.driver.api.core.CqlSessionBuilder; +import com.datastax.oss.driver.api.core.config.DefaultDriverOption; +import com.datastax.oss.driver.api.core.config.DriverConfigLoader; +import com.datastax.oss.driver.api.core.metadata.schema.SchemaChangeListener; +import com.google.common.annotations.VisibleForTesting; +import io.stargate.sgv2.jsonapi.config.DatabaseType; +import io.stargate.sgv2.jsonapi.service.cqldriver.executor.optvector.SubtypeOnlyFloatVectorToArrayCodec; +import java.net.InetSocketAddress; +import java.util.Collection; +import java.util.List; +import java.util.Objects; +import java.util.function.Supplier; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Factory to create {@link CqlSession} instances, normally used with the {@link CQLSessionCache} + * via the {@link CQLSessionCache.SessionFactory} interface. + * + *

Abstracted out to make it easier to test the session cache and creating the session. + */ +public class CqlSessionFactory implements CQLSessionCache.SessionFactory { + + private static final Logger LOGGER = LoggerFactory.getLogger(CqlSessionFactory.class); + + private final String applicationName; + + private final DatabaseType databaseType; + private final String localDatacenter; + private final Collection contactPoints; + private final List schemaChangeListeners; + private final Supplier sessionBuilderSupplier; + + /** + * Constructor for the CqlSessionFactory, normally this overload is used for non-testing code. + * + * @param applicationName the name of the application, set on the CQL session + * @param databaseType the type of database, controls contact points and other settings + * @param localDatacenter the local datacenter for the client connection. + * @param cassandraEndPoints the Cassandra endpoints, only used when the database type is + * CASSANDRA + * @param cassandraPort the Cassandra port, only used when the database type is CASSANDRA + * @param schemaChangeListeners the schema change listeners, these are added to the session to + * listen for schema changes from it. + */ + CqlSessionFactory( + String applicationName, + DatabaseType databaseType, + String localDatacenter, + List cassandraEndPoints, + Integer cassandraPort, + List schemaChangeListeners) { + this( + applicationName, + databaseType, + localDatacenter, + cassandraEndPoints, + cassandraPort, + schemaChangeListeners, + TenantAwareCqlSessionBuilder::new); + } + + /** + * Constructor for the CqlSessionFactory, this overload is for testing so the SessionBuilder can + * be mocked. + * + * @param applicationName the name of the application, set on the CQL session + * @param databaseType the type of database, controls contact points and other settings + * @param localDatacenter the local datacenter for the client connection. 
+ * @param cassandraEndPoints the Cassandra endpoints, only used when the database type is + * CASSANDRA + * @param cassandraPort the Cassandra port, only used when the database type is CASSANDRA + * @param schemaChangeListeners the schema change listeners, these are added to the session to + * listen for schema changes from it. + * @param sessionBuilderSupplier a supplier for creating CqlSessionBuilder instances, so that + * testing can mock the builder for session creation. In prod code use the ctor without this. + */ + @VisibleForTesting + CqlSessionFactory( + String applicationName, + DatabaseType databaseType, + String localDatacenter, + List cassandraEndPoints, + Integer cassandraPort, + List schemaChangeListeners, + Supplier sessionBuilderSupplier) { + + this.applicationName = + Objects.requireNonNull(applicationName, "applicationName must not be null"); + if (applicationName.isBlank()) { + throw new IllegalArgumentException("applicationName must not be blank"); + } + this.databaseType = Objects.requireNonNull(databaseType, "databaseType must not be null"); + this.localDatacenter = + Objects.requireNonNull(localDatacenter, "localDatacenter must not be null"); + + this.schemaChangeListeners = + schemaChangeListeners == null ? 
List.of() : List.copyOf(schemaChangeListeners); + this.sessionBuilderSupplier = + Objects.requireNonNull(sessionBuilderSupplier, "sessionBuilderSupplier must not be null"); + + // these never change, and we do not have them in astra, so we can cache + if (databaseType == DatabaseType.CASSANDRA) { + Objects.requireNonNull(cassandraEndPoints, "cassandraEndPoints must not be null"); + if (cassandraEndPoints.isEmpty()) { + throw new IllegalArgumentException( + "Database type is %s but cassandraEndPoints is empty.".formatted(databaseType)); + } + contactPoints = + cassandraEndPoints.stream() + .map(host -> new InetSocketAddress(host, cassandraPort)) + .toList(); + } else { + contactPoints = List.of(); + } + } + + @Override + public CqlSession apply(String tenantId, CqlCredentials credentials) { + Objects.requireNonNull(credentials, "credentials must not be null"); + + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Creating CQL Session tenantId={}, credentials={}", tenantId, credentials); + } + + // the driver TypedDriverOption is only used with DriverConfigLoader.fromMap() + // The ConfigLoader is held by the session and closed when the session closes, do not close it + // here. + // Setting the session name to the tenantId, this is used by the driver to identify the session, + // used in logging and metrics + var configLoader = + DriverConfigLoader.programmaticBuilder() + .withString(DefaultDriverOption.SESSION_NAME, tenantId == null ? 
"" : tenantId) + .build(); + + var builder = + sessionBuilderSupplier + .get() + .withLocalDatacenter(localDatacenter) + .withClassLoader(Thread.currentThread().getContextClassLoader()) // TODO: EXPLAIN + .withConfigLoader(configLoader) + .withApplicationName(applicationName); + + if (builder instanceof TenantAwareCqlSessionBuilder tenantAwareBuilder) { + tenantAwareBuilder.withTenantId(tenantId); + } + + for (var listener : schemaChangeListeners) { + builder = builder.addSchemaChangeListener(listener); + } + builder = credentials.addToSessionBuilder(builder); + + // for astra it will default to 127.0.0.1 which is routed to the astra proxy + if (databaseType == DatabaseType.CASSANDRA) { + builder = builder.addContactPoints(contactPoints); + } + + // Add optimized CqlVector codec (see [data-api#1775]) + builder = builder.addTypeCodecs(SubtypeOnlyFloatVectorToArrayCodec.instance()); + + // aaron - this used to have an if / else that threw an exception if the database type was not + // known but we test that when creating the credentials for the cache key so no need to do it + // here. 
+ return builder.build(); + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/CustomTaggingMetricIdGenerator.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/CustomTaggingMetricIdGenerator.java deleted file mode 100644 index 6b75cc2d55..0000000000 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/CustomTaggingMetricIdGenerator.java +++ /dev/null @@ -1,65 +0,0 @@ -package io.stargate.sgv2.jsonapi.service.cqldriver; - -import com.datastax.oss.driver.api.core.config.DefaultDriverOption; -import com.datastax.oss.driver.api.core.context.DriverContext; -import com.datastax.oss.driver.api.core.metadata.Node; -import com.datastax.oss.driver.api.core.metrics.NodeMetric; -import com.datastax.oss.driver.api.core.metrics.SessionMetric; -import com.datastax.oss.driver.internal.core.metrics.DefaultMetricId; -import com.datastax.oss.driver.internal.core.metrics.MetricId; -import com.datastax.oss.driver.internal.core.metrics.MetricIdGenerator; -import com.google.common.collect.ImmutableMap; -import edu.umd.cs.findbugs.annotations.NonNull; -import java.util.Objects; - -/** - * Customized Tagging Metric ID Generator for Driver Metric - * - *

Session metric identifiers contain a name starting with "session." and ending with the metric - * path, a tag with the key "session" and the value of the current session name, a tag with the key - * 'tenant' and the value of current tenantId. - * - *

Node metric identifiers contain a name starting with "nodes." and ending with the metric path, - * and 3 tags: a tag with the key "session" and the value of the current session name, a tag with - * the key "node" and the value of the current node endpoint, a tag with the key 'tenant' and the - * value of current tenantId. - */ -public class CustomTaggingMetricIdGenerator implements MetricIdGenerator { - - private final String sessionName; - private final String sessionPrefix; - private final String nodePrefix; - private final String tenantId; - - @SuppressWarnings("unused") - public CustomTaggingMetricIdGenerator(DriverContext context) { - sessionName = context.getSessionName(); - String prefix = - Objects.requireNonNull( - context - .getConfig() - .getDefaultProfile() - .getString(DefaultDriverOption.METRICS_ID_GENERATOR_PREFIX, "")); - sessionPrefix = prefix.isEmpty() ? "session." : prefix + ".session."; - nodePrefix = prefix.isEmpty() ? "nodes." : prefix + ".nodes."; - tenantId = - ((TenantAwareCqlSessionBuilder.TenantAwareDriverContext) context) - .getStartupOptions() - .get(TenantAwareCqlSessionBuilder.TENANT_ID_PROPERTY_KEY); - } - - @NonNull - @Override - public MetricId sessionMetricId(@NonNull SessionMetric metric) { - return new DefaultMetricId( - sessionPrefix + metric.getPath(), ImmutableMap.of("tenant", tenantId)); - } - - @NonNull - @Override - public MetricId nodeMetricId(@NonNull Node node, @NonNull NodeMetric metric) { - return new DefaultMetricId( - nodePrefix + metric.getPath(), - ImmutableMap.of("node", node.getEndPoint().toString(), "tenant", tenantId)); - } -} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/SchemaChangeListener.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/SchemaChangeListener.java deleted file mode 100644 index dc75391d8a..0000000000 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/SchemaChangeListener.java +++ /dev/null @@ -1,63 +0,0 @@ -package 
io.stargate.sgv2.jsonapi.service.cqldriver; - -import com.datastax.oss.driver.api.core.metadata.schema.KeyspaceMetadata; -import com.datastax.oss.driver.api.core.metadata.schema.SchemaChangeListenerBase; -import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata; -import edu.umd.cs.findbugs.annotations.NonNull; -import io.stargate.sgv2.jsonapi.service.cqldriver.executor.SchemaCache; -import java.util.Optional; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class SchemaChangeListener extends SchemaChangeListenerBase { - - private static final Logger LOGGER = - LoggerFactory.getLogger(SchemaChangeListener.class.getName()); - private final SchemaCache schemaCache; - - private final String tenantId; - - public SchemaChangeListener(SchemaCache schemaCache, String tenantId) { - this.schemaCache = schemaCache; - this.tenantId = tenantId; - } - - /** - * Add tableDropped event listener for every cqlSession, drop the corresponding collectionSetting - * cache entry to avoid operations using outdated CollectionSetting This should work for both CQL - * Table drop and Data API deleteCollection - */ - public void onTableDropped(TableMetadata table) { - schemaCache.evictCollectionSettingCacheEntry( - Optional.ofNullable(tenantId), - table.getKeyspace().asInternal(), - table.getName().asInternal()); - } - - /** - * Add keyspaceDropped event listener for every cqlSession, drop the corresponding namespaceCache - * entry This should work for both CQL keyspace drop and Data API dropNamespace - */ - @Override - public void onKeyspaceDropped(@NonNull KeyspaceMetadata keyspace) { - schemaCache.evictNamespaceCacheEntriesForTenant(tenantId, keyspace.getName().asInternal()); - } - - /** When table is created, drop the corresponding collectionSetting cache entry if existed */ - @Override - public void onTableCreated(@NonNull TableMetadata table) { - schemaCache.evictCollectionSettingCacheEntry( - Optional.ofNullable(tenantId), - 
table.getKeyspace().asInternal(), - table.getName().asInternal()); - } - - @Override - public void onTableUpdated(@NonNull TableMetadata current, @NonNull TableMetadata previous) { - // Evict from the cache because things like indexes can change for CQL Tables - schemaCache.evictCollectionSettingCacheEntry( - Optional.ofNullable(tenantId), - current.getKeyspace().asInternal(), - current.getName().asInternal()); - } -} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/TenantAwareCqlSessionBuilder.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/TenantAwareCqlSessionBuilder.java index 49dfe53d49..7e220dbb15 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/TenantAwareCqlSessionBuilder.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/TenantAwareCqlSessionBuilder.java @@ -12,7 +12,10 @@ /** * This is an extension of the {@link CqlSessionBuilder} that allows to pass a tenant ID to the * CQLSession via TenantAwareDriverContext which is an extension of the {@link DefaultDriverContext} - * that adds the tenant ID to the startup options. + * that adds the tenant ID to the startup options. The tenant ID is critical for the cql session and + * it has to be passed and cannot be removed. + * + *

It's linked to issue #2119 */ public class TenantAwareCqlSessionBuilder extends CqlSessionBuilder { /** @@ -22,18 +25,14 @@ public class TenantAwareCqlSessionBuilder extends CqlSessionBuilder { public static final String TENANT_ID_PROPERTY_KEY = "TENANT_ID"; /** Tenant ID that will be passed to the CQLSession via TenantAwareDriverContext */ - private final String tenantId; + private String tenantId; - /** - * Constructor that takes the tenant ID as a parameter - * - * @param tenantId tenant id or database id - */ - public TenantAwareCqlSessionBuilder(String tenantId) { + public TenantAwareCqlSessionBuilder withTenantId(String tenantId) { if (tenantId == null || tenantId.isEmpty()) { throw ErrorCodeV1.SERVER_INTERNAL_ERROR.toApiException("Tenant ID cannot be null or empty"); } this.tenantId = tenantId; + return this; } /** diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/CollectionIndexUsage.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/CollectionIndexUsage.java index 843dd513e2..231f16d84f 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/CollectionIndexUsage.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/CollectionIndexUsage.java @@ -3,6 +3,7 @@ import com.google.common.base.Preconditions; import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tags; +import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants; /** * This class is used to track the usage of indexes in a query. 
It is used to generate metrics for @@ -19,7 +20,8 @@ public class CollectionIndexUsage implements IndexUsage { textIndexTag, timestampIndexTag, nullIndexTag, - vectorIndexTag; + vectorIndexTag, + lexicalIndexTag; /** * This method is used to generate the tags for the index usage @@ -32,13 +34,23 @@ public Tags getTags() { Tag.of("key", String.valueOf(primaryKeyTag)), Tag.of("exist_keys", String.valueOf(existKeysIndexTag)), Tag.of("array_size", String.valueOf(arraySizeIndexTag)), - Tag.of("array_contains", String.valueOf(arrayContainsTag)), - Tag.of("query_bool_values", String.valueOf(booleanIndexTag)), - Tag.of("query_dbl_values", String.valueOf(numberIndexTag)), - Tag.of("query_text_values", String.valueOf(textIndexTag)), - Tag.of("query_timestamp_values", String.valueOf(timestampIndexTag)), - Tag.of("query_null_values", String.valueOf(nullIndexTag)), - Tag.of("query_vector_value", String.valueOf(vectorIndexTag))); + Tag.of( + DocumentConstants.Columns.DATA_CONTAINS_COLUMN_NAME, String.valueOf(arrayContainsTag)), + Tag.of( + DocumentConstants.Columns.QUERY_BOOLEAN_MAP_COLUMN_NAME, + String.valueOf(booleanIndexTag)), + Tag.of( + DocumentConstants.Columns.QUERY_DOUBLE_MAP_COLUMN_NAME, String.valueOf(numberIndexTag)), + Tag.of(DocumentConstants.Columns.QUERY_NULL_MAP_COLUMN_NAME, String.valueOf(nullIndexTag)), + Tag.of(DocumentConstants.Columns.QUERY_TEXT_MAP_COLUMN_NAME, String.valueOf(textIndexTag)), + Tag.of( + DocumentConstants.Columns.QUERY_TIMESTAMP_MAP_COLUMN_NAME, + String.valueOf(timestampIndexTag)), + Tag.of( + DocumentConstants.Columns.VECTOR_SEARCH_INDEX_COLUMN_NAME, + String.valueOf(vectorIndexTag)), + Tag.of( + DocumentConstants.Columns.LEXICAL_INDEX_COLUMN_NAME, String.valueOf(lexicalIndexTag))); } /** diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/CommandQueryExecutor.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/CommandQueryExecutor.java index 3254925c34..808f5b3284 100644 --- 
a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/CommandQueryExecutor.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/CommandQueryExecutor.java @@ -74,7 +74,6 @@ public CommandQueryExecutor( public Uni executeRead(SimpleStatement statement) { Objects.requireNonNull(statement, "statement must not be null"); - statement = withExecutionProfile(statement, QueryType.READ); return executeAndWrap(statement); } @@ -158,7 +157,9 @@ public Uni executeCreateSchema(SimpleStatement statement) { private CqlSession session() { return cqlSessionCache.getSession( - dbRequestContext.tenantId().orElse(""), dbRequestContext.authToken().orElse("")); + dbRequestContext.tenantId().orElse(""), + dbRequestContext.authToken().orElse(""), + dbRequestContext.userAgent().orElse(null)); } private String getExecutionProfile(QueryType queryType) { @@ -182,5 +183,8 @@ public Uni executeAndWrap(SimpleStatement statement) { // Aaron - Feb 3 - temp rename while factoring full RequestContext public record DBRequestContext( - Optional tenantId, Optional authToken, boolean tracingEnabled) {} + Optional tenantId, + Optional authToken, + Optional userAgent, + boolean tracingEnabled) {} } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/QueryExecutor.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/QueryExecutor.java index b35b9b81f0..bb1564b74a 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/QueryExecutor.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/QueryExecutor.java @@ -16,8 +16,6 @@ import io.stargate.sgv2.jsonapi.config.OperationsConfig; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.service.cqldriver.CQLSessionCache; -import jakarta.enterprise.context.ApplicationScoped; -import jakarta.inject.Inject; import java.nio.ByteBuffer; import java.time.Duration; import java.util.Base64; @@ -27,7 +25,13 @@ 
import org.slf4j.Logger; import org.slf4j.LoggerFactory; -@ApplicationScoped +/** + * This is a legacy class from the first versions of the API; it is now created in {@link + * io.stargate.sgv2.jsonapi.service.operation.Operation#execute(RequestContext, QueryExecutor)} for + * backwards compatibility. From there it is passed to the operation and used to execute. + * + *

It is no longer a bean and should not be injected. + */ public class QueryExecutor { private static final Logger logger = LoggerFactory.getLogger(QueryExecutor.class); private final OperationsConfig operationsConfig; @@ -37,7 +41,6 @@ public class QueryExecutor { private final RequestTracing requestTracing; - @Inject public QueryExecutor(CQLSessionCache cqlSessionCache, OperationsConfig operationsConfig) { this(cqlSessionCache, operationsConfig, RequestTracing.NO_OP); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/SchemaCache.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/SchemaCache.java index 6b1c99d780..7debb5aa26 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/SchemaCache.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/SchemaCache.java @@ -1,93 +1,318 @@ package io.stargate.sgv2.jsonapi.service.cqldriver.executor; -import static io.stargate.sgv2.jsonapi.service.cqldriver.CQLSessionCache.CASSANDRA; +import static io.stargate.sgv2.jsonapi.service.cqldriver.CQLSessionCache.DEFAULT_TENANT; -import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.metadata.schema.KeyspaceMetadata; +import com.datastax.oss.driver.api.core.metadata.schema.SchemaChangeListener; +import com.datastax.oss.driver.api.core.metadata.schema.SchemaChangeListenerBase; +import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata; +import com.datastax.oss.driver.api.core.session.Session; import com.fasterxml.jackson.databind.ObjectMapper; -import com.github.benmanes.caffeine.cache.Cache; import com.github.benmanes.caffeine.cache.Caffeine; +import com.github.benmanes.caffeine.cache.LoadingCache; +import com.github.benmanes.caffeine.cache.RemovalCause; +import com.google.common.annotations.VisibleForTesting; +import edu.umd.cs.findbugs.annotations.NonNull; import io.smallrye.mutiny.Uni; import 
io.stargate.sgv2.jsonapi.api.request.RequestContext; +import io.stargate.sgv2.jsonapi.config.DatabaseType; import io.stargate.sgv2.jsonapi.config.OperationsConfig; +import io.stargate.sgv2.jsonapi.service.cqldriver.CQLSessionCache; +import io.stargate.sgv2.jsonapi.service.cqldriver.CqlSessionCacheSupplier; +import io.stargate.sgv2.jsonapi.service.cqldriver.CqlSessionFactory; import jakarta.enterprise.context.ApplicationScoped; import jakarta.inject.Inject; -import java.util.Map; +import java.util.Objects; import java.util.Optional; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; -/** Caches the vector enabled status for all the namespace in schema */ +/** + * Top level entry for caching the keyspaces and tables from the backend db + * + *

IMPORTANT: use {@link #getSchemaChangeListener()} and {@link #getDeactivatedTenantConsumer()} + * to get callbacks to evict the cache when the schema changes or a tenant is deactivated. This + * should be handled in {@link CqlSessionCacheSupplier}. + * + *

TODO: There should be a single level cache of keyspace,table not two levels, it will be easier + * to size and manage https://github.com/stargate/data-api/issues/2070 + */ @ApplicationScoped public class SchemaCache { + private static final Logger LOGGER = LoggerFactory.getLogger(SchemaCache.class); - @Inject QueryExecutor queryExecutor; + private final CqlSessionCacheSupplier sessionCacheSupplier; + private final DatabaseType databaseType; + private final ObjectMapper objectMapper; + private final TableCacheFactory tableCacheFactory; + private final OperationsConfig operationsConfig; - @Inject ObjectMapper objectMapper; + /** caching the keyspaces we know about which then have all the tables / collections under them */ + private final LoadingCache keyspaceCache; - @Inject OperationsConfig operationsConfig; + @Inject + public SchemaCache( + CqlSessionCacheSupplier sessionCacheSupplier, + ObjectMapper objectMapper, + OperationsConfig operationsConfig) { + this(sessionCacheSupplier, objectMapper, operationsConfig, TableBasedSchemaCache::new); + } + + /** + * NOTE: must not use the sessionCacheSupplier in the ctor or because it will create a circular + * calls, because the sessionCacheSupplier calls schema cache to get listeners + */ + @VisibleForTesting + protected SchemaCache( + CqlSessionCacheSupplier sessionCacheSupplier, + ObjectMapper objectMapper, + OperationsConfig operationsConfig, + TableCacheFactory tableCacheFactory) { + + this.sessionCacheSupplier = + Objects.requireNonNull(sessionCacheSupplier, "sessionCacheSupplier must not be null"); + this.objectMapper = Objects.requireNonNull(objectMapper, "objectMapper must not be null"); + this.operationsConfig = operationsConfig; + this.databaseType = + Objects.requireNonNull(operationsConfig, "operationsConfig must not be null") + .databaseConfig() + .type(); + this.tableCacheFactory = + Objects.requireNonNull(tableCacheFactory, "tableCacheFactory must not be null"); + + // TODO: The size of the cache should be 
in configuration. + int cacheSize = 1000; + keyspaceCache = Caffeine.newBuilder().maximumSize(cacheSize).build(this::onLoad); + + LOGGER.info("SchemaCache created with max size {}", cacheSize); + } - // TODO: The size of the cache should be in configuration. - // TODO: set the cache loader when creating the cache - private final Cache schemaCache = - Caffeine.newBuilder().maximumSize(1000).build(); + /** + * Gets a listener to use with the {@link CqlSessionFactory} to remove the schema cache entries + * when the DB sends schema change events. + */ + public SchemaChangeListener getSchemaChangeListener() { + return new SchemaCacheSchemaChangeListener(this); + } + + /** + * Gets a consumer to use with the {@link CQLSessionCache} to remove the schema cache entries when + * a tenant is deactivated. + */ + public CQLSessionCache.DeactivatedTenantConsumer getDeactivatedTenantConsumer() { + return new SchemaCacheDeactivatedTenantConsumer(this); + } + /** Gets or loads the schema object for the given namespace and collection or table name. */ public Uni getSchemaObject( - RequestContext dataApiRequestInfo, - Optional tenant, + RequestContext requestContext, String namespace, String collectionName, boolean forceRefresh) { - // TODO: refactor, this has duplicate code, the only special handling the OSS has is the tenant - // check + Objects.requireNonNull(namespace, "namespace must not be null"); + + // Aaron 1 may 2025 - Based on existing logic I thought we should have a non null tenantId in + // all cases other than + // when using cassandra as a backend. + var resolvedTenantId = + (databaseType == DatabaseType.CASSANDRA) + ? requestContext.getTenantId().orElse(DEFAULT_TENANT) + : requestContext.getTenantId().orElseThrow(); - if (CASSANDRA.equals(operationsConfig.databaseConfig().type())) { - // default_tenant is for oss run - // TODO: move the string to a constant or config, why does this still check the tenant if this - // is for OSS ? 
- final NamespaceCache namespaceCache = - schemaCache.get( - new CacheKey(Optional.of(tenant.orElse("default_tenant")), namespace), - this::addNamespaceCache); - return namespaceCache.getSchemaObject(dataApiRequestInfo, collectionName, forceRefresh); + var tableBasedSchemaCache = + keyspaceCache.get(new KeyspaceCacheKey(resolvedTenantId, namespace)); + Objects.requireNonNull( + tableBasedSchemaCache, "keyspaceCache must not return null tableBasedSchemaCache"); + return tableBasedSchemaCache.getSchemaObject(requestContext, collectionName, forceRefresh); + } + + private TableBasedSchemaCache onLoad(SchemaCache.KeyspaceCacheKey key) { + if (LOGGER.isTraceEnabled()) { + LOGGER.trace("onLoad() - tenantId: {}, keyspace: {}", key.tenantId(), key.keyspace()); } - final NamespaceCache namespaceCache = - schemaCache.get(new CacheKey(tenant, namespace), this::addNamespaceCache); - return namespaceCache.getSchemaObject(dataApiRequestInfo, collectionName, forceRefresh); + // Cannot get a session from the sessionCacheSupplier in the constructor because + // it will create a circular call. So need to wait until now to create the QueryExecutor + // this is OK, only happens when the table is not in the cache + var queryExecutor = new QueryExecutor(sessionCacheSupplier.get(), operationsConfig); + return tableCacheFactory.create(key.keyspace(), queryExecutor, objectMapper); + } + + /** For testing only - peek to see if the schema object is in the cache without loading it. */ + @VisibleForTesting + Optional peekSchemaObject(String tenantId, String keyspaceName, String tableName) { + + var tableBasedSchemaCache = + keyspaceCache.getIfPresent(new KeyspaceCacheKey(tenantId, keyspaceName)); + if (tableBasedSchemaCache != null) { + return tableBasedSchemaCache.peekSchemaObject(tableName); + } + return Optional.empty(); } + ; + + /** Removes the table from the cache if present. 
*/ + void evictTable(String tenantId, String keyspace, String tableName) { + + if (LOGGER.isTraceEnabled()) { + LOGGER.trace( + "evictTable() - tenantId: {}, keyspace: {}, tableName: {}", + tenantId, + keyspace, + tableName); + } - /** Evict collectionSetting Cache entry when there is a drop table event */ - public void evictCollectionSettingCacheEntry( - Optional tenant, String namespace, String collectionName) { - final NamespaceCache namespaceCache = schemaCache.getIfPresent(new CacheKey(tenant, namespace)); - if (namespaceCache != null) { - namespaceCache.evictCollectionSettingCacheEntry(collectionName); + var tableBasedSchemaCache = + keyspaceCache.getIfPresent(new KeyspaceCacheKey(tenantId, keyspace)); + if (tableBasedSchemaCache != null) { + tableBasedSchemaCache.evictCollectionSettingCacheEntry(tableName); } } - private NamespaceCache addNamespaceCache(CacheKey cacheKey) { - return new NamespaceCache(cacheKey.namespace(), queryExecutor, objectMapper); + /** Removes all keyspaces and table entries for the given tenantId from the cache. */ + void evictAllKeyspaces(String tenantId) { + + if (LOGGER.isTraceEnabled()) { + LOGGER.trace("evictAllKeyspaces() - tenantId: {}", tenantId); + } + + keyspaceCache.asMap().keySet().removeIf(key -> key.tenantId().equals(tenantId)); + } + + /** Removes the keyspace from the cache if present. 
*/ + void evictKeyspace(String tenant, String keyspace) { + + if (LOGGER.isTraceEnabled()) { + LOGGER.trace("evictKeyspace() - tenantId: {}, keyspace: {}", tenant, keyspace); + } + + keyspaceCache.invalidate(new KeyspaceCacheKey(tenant, keyspace)); + } + + /** Key for the Keyspace cache, we rely on the record hash and equals */ + record KeyspaceCacheKey(String tenantId, String keyspace) { + + KeyspaceCacheKey { + Objects.requireNonNull(tenantId, "tenantId must not be null"); + Objects.requireNonNull(keyspace, "namespace must not be null"); + } } /** - * When a sessionCache entry expires, evict all corresponding entire NamespaceCaches for the - * tenant This is to ensure there is no offset for sessionCache and schemaCache + * SchemaChangeListener for the schema cache, this is used to evict the cache entries when we get + * messages from the DB that the schema has changed. + * + *

NOTE: This relies on the session name being set correctly, which should happen in {@link + * io.stargate.sgv2.jsonapi.service.cqldriver.CqlSessionFactory}. + * + *

A new schema change listener should be created for each CQL {@link Session} when it is + * created because the listener will first listen for {@link SchemaChangeListener#onSessionReady} + * and get the tenantId from the session name via {@link Session#getName}. + * + *

If the tenantId is not set, null or blank, we log at ERROR rather than throw because the + * callback methods are called on driver async threads and exceptions there are unlikely to be + * passed back in the request response. + * + *

This could be non-static inner, but static to make testing easier so we can pass in the + * cache it is working with. */ - public void evictNamespaceCacheEntriesForTenant( - String tenant, Map keyspaces) { - for (Map.Entry cqlIdentifierKeyspaceMetadataEntry : - keyspaces.entrySet()) { - schemaCache.invalidate( - new CacheKey( - Optional.ofNullable(tenant), - cqlIdentifierKeyspaceMetadataEntry.getKey().asInternal())); + static class SchemaCacheSchemaChangeListener extends SchemaChangeListenerBase { + + private static final Logger LOGGER = + LoggerFactory.getLogger(SchemaCacheSchemaChangeListener.class); + + private final SchemaCache schemaCache; + + private String tenantId = null; + + public SchemaCacheSchemaChangeListener(SchemaCache schemaCache) { + this.schemaCache = Objects.requireNonNull(schemaCache, "schemaCache must not be null"); + } + + private boolean hasTenantId(String context) { + if (tenantId == null || tenantId.isBlank()) { + LOGGER.error( + "SchemaCacheSchemaChangeListener tenantId is null or blank when expected to be set - {}", + context); + return false; + } + return true; + } + + private void evictTable(String context, TableMetadata tableMetadata) { + if (hasTenantId(context)) { + schemaCache.evictTable( + tenantId, + tableMetadata.getKeyspace().asInternal(), + tableMetadata.getName().asInternal()); + } + } + + @Override + public void onSessionReady(@NonNull Session session) { + // This is called when the session is ready, we can get the tenantId from the session name + // and set it in the listener so we can use it in the other methods. 
+ tenantId = session.getName(); + hasTenantId("onSessionReady called but sessionName() is null or blank"); + } + + /** + * When a table is dropped, evict from cache to reduce the size and avoid stale if it is + * re-created + */ + @Override + public void onTableDropped(@NonNull TableMetadata table) { + evictTable("onTableDropped", table); + } + + /** When a table is created, evict from cache to avoid stale if it was re-created */ + @Override + public void onTableCreated(@NonNull TableMetadata table) { + evictTable("onTableCreated", table); + } + + /** When a table is updated, evict from cache to avoid stale entries */ + @Override + public void onTableUpdated(@NonNull TableMetadata current, @NonNull TableMetadata previous) { + // table name can never change + evictTable("onTableUpdated", current); + } + + /** When keyspace dropped, we dont need any more of the tables in the cache */ + @Override + public void onKeyspaceDropped(@NonNull KeyspaceMetadata keyspace) { + if (hasTenantId("onKeyspaceDropped")) { + schemaCache.evictKeyspace(tenantId, keyspace.getName().asInternal()); + } } } - /** Evict corresponding namespaceCache When there is a keyspace drop event */ - public void evictNamespaceCacheEntriesForTenant(String tenant, String keyspace) { - schemaCache.invalidate(new CacheKey(Optional.ofNullable(tenant), keyspace)); + /** + * Listener for use with the {@link CQLSessionCache} to remove the schema cache entries when a + * tenant is deactivated. + */ + private static class SchemaCacheDeactivatedTenantConsumer + implements CQLSessionCache.DeactivatedTenantConsumer { + + private final SchemaCache schemaCache; + + public SchemaCacheDeactivatedTenantConsumer(SchemaCache schemaCache) { + this.schemaCache = Objects.requireNonNull(schemaCache, "schemaCache must not be null"); + } + + @Override + public void accept(String tenantId, RemovalCause cause) { + // the sessions are keyed on the tenantID and the credentials, and one session can work with + // multiple keyspaces. 
So we need to evict all the keyspaces for the tenantId + schemaCache.evictAllKeyspaces(tenantId); + } } - record CacheKey(Optional tenant, String namespace) {} + /** Function to create a new TableBasedSchemaCache, so we can mock when testing */ + @FunctionalInterface + public interface TableCacheFactory { + TableBasedSchemaCache create( + String namespace, QueryExecutor queryExecutor, ObjectMapper objectMapper); + } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/SchemaObjectName.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/SchemaObjectName.java index edec0bb0a5..00de4f348f 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/SchemaObjectName.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/SchemaObjectName.java @@ -34,6 +34,11 @@ public void addToMDC() { MDC.put("collection", table); } + public void removeFromMDC() { + MDC.remove("namespace"); + MDC.remove("collection"); + } + @Override public DataRecorder recordTo(DataRecorder dataRecorder) { return dataRecorder.append("keyspace", keyspace).append("table", table); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/NamespaceCache.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/TableBasedSchemaCache.java similarity index 93% rename from src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/NamespaceCache.java rename to src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/TableBasedSchemaCache.java index 93a8dfe9e6..8cdf5da210 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/NamespaceCache.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/TableBasedSchemaCache.java @@ -10,12 +10,13 @@ import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject; import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionTableMatcher; import 
java.time.Duration; +import java.util.Optional; /** Caches the vector enabled status for the namespace */ // TODO: what is the vector status of a namespace ? vectors are per collection // TODO: clarify the name of this class, it is a cache of the collections/ tables not a cache of // namespaces ?? -public class NamespaceCache { +public class TableBasedSchemaCache { public final String namespace; @@ -33,7 +34,8 @@ public class NamespaceCache { .maximumSize(CACHE_MAX_SIZE) .build(); - public NamespaceCache(String namespace, QueryExecutor queryExecutor, ObjectMapper objectMapper) { + public TableBasedSchemaCache( + String namespace, QueryExecutor queryExecutor, ObjectMapper objectMapper) { this.namespace = namespace; this.queryExecutor = queryExecutor; this.objectMapper = objectMapper; @@ -84,6 +86,10 @@ protected Uni getSchemaObject( } } + Optional peekSchemaObject(String tableName) { + return Optional.ofNullable(schemaObjectCache.getIfPresent(tableName)); + } + private Uni loadSchemaObject(RequestContext requestContext, String collectionName) { return queryExecutor diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/VectorConfig.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/VectorConfig.java index 7cb44440d3..e2f416c854 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/VectorConfig.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/VectorConfig.java @@ -167,6 +167,16 @@ public int hashCode() { return Objects.hash(columnVectorDefinitions, vectorEnabled); } + @Override + public String toString() { + return "VectorConfig[" + + "vectorEnabled=" + + vectorEnabled + + ", columnVectorDefinitions=" + + columnVectorDefinitions + + ']'; + } + @Override public DataRecorder recordTo(DataRecorder dataRecorder) { return dataRecorder diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/override/DefaultSubConditionRelation.java 
b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/override/DefaultSubConditionRelation.java new file mode 100644 index 0000000000..1770f5ae99 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/override/DefaultSubConditionRelation.java @@ -0,0 +1,170 @@ +package io.stargate.sgv2.jsonapi.service.cqldriver.override; + +import com.datastax.oss.driver.api.core.cql.SimpleStatement; +import com.datastax.oss.driver.api.core.cql.SimpleStatementBuilder; +import com.datastax.oss.driver.api.querybuilder.BuildableQuery; +import com.datastax.oss.driver.api.querybuilder.CqlSnippet; +import com.datastax.oss.driver.api.querybuilder.relation.OngoingWhereClause; +import com.datastax.oss.driver.api.querybuilder.relation.Relation; +import edu.umd.cs.findbugs.annotations.CheckReturnValue; +import edu.umd.cs.findbugs.annotations.NonNull; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * This class is to add AND/OR ability to driver where clauses, and it is a temporary override + * solution, should be removed once the driver supports it natively. Note, although AND/OR has been + * implemented in DataStax Astra, it is not yet in Apache Cassandra. So the community Java driver + * does not support this as of now. See Ticket CASSJAVA-47. + */ +public class DefaultSubConditionRelation + implements OngoingWhereClause, BuildableQuery, Relation { + + private final List relations; + private final boolean isSubCondition; + + /** Construct sub-condition relation with empty WHERE clause. 
*/ + public DefaultSubConditionRelation(boolean isSubCondition) { + this.relations = new ArrayList<>(); + this.isSubCondition = isSubCondition; + } + + @NonNull + @Override + public DefaultSubConditionRelation where(@NonNull Relation relation) { + relations.add(relation); + return this; + } + + @NonNull + @Override + public DefaultSubConditionRelation where(@NonNull Iterable additionalRelations) { + for (Relation relation : additionalRelations) { + relations.add(relation); + } + return this; + } + + @NonNull + public DefaultSubConditionRelation withRelations(@NonNull List newRelations) { + relations.addAll(newRelations); + return this; + } + + @NonNull + @Override + public String asCql() { + StringBuilder builder = new StringBuilder(); + + if (isSubCondition) { + builder.append("("); + } + appendWhereClause(builder, relations, isSubCondition); + if (isSubCondition) { + builder.append(")"); + } + + return builder.toString(); + } + + public static void appendWhereClause( + StringBuilder builder, List relations, boolean isSubCondition) { + boolean first = true; + for (int i = 0; i < relations.size(); ++i) { + CqlSnippet snippet = relations.get(i); + if (first && !isSubCondition) { + builder.append(" WHERE "); + } + first = false; + + snippet.appendTo(builder); + + boolean logicalOperatorAdded = false; + LogicalRelation logicalRelation = lookAheadNextRelation(relations, i, LogicalRelation.class); + if (logicalRelation != null) { + builder.append(" "); + logicalRelation.appendTo(builder); + builder.append(" "); + logicalOperatorAdded = true; + ++i; + } + if (!logicalOperatorAdded && i + 1 < relations.size()) { + builder.append(" AND "); + } + } + } + + private static T lookAheadNextRelation( + List relations, int position, Class clazz) { + if (position + 1 >= relations.size()) { + return null; + } + Relation relation = relations.get(position + 1); + if (relation.getClass().isAssignableFrom(clazz)) { + return (T) relation; + } + return null; + } + + @NonNull + @Override 
+ public SimpleStatement build() { + return builder().build(); + } + + @NonNull + @Override + public SimpleStatement build(@NonNull Object... values) { + return builder().addPositionalValues(values).build(); + } + + @NonNull + @Override + public SimpleStatement build(@NonNull Map namedValues) { + SimpleStatementBuilder builder = builder(); + for (Map.Entry entry : namedValues.entrySet()) { + builder.addNamedValue(entry.getKey(), entry.getValue()); + } + return builder.build(); + } + + @Override + public String toString() { + return asCql(); + } + + @Override + public void appendTo(@NonNull StringBuilder builder) { + builder.append(asCql()); + } + + @Override + public boolean isIdempotent() { + for (Relation relation : relations) { + if (!relation.isIdempotent()) { + return false; + } + } + return true; + } + + /** Adds conjunction clause. Next relation is logically joined with AND. */ + public OngoingWhereClause and() { + return where(LogicalRelation.AND); + } + + /** Adds alternative clause. Next relation is logically joined with OR. */ + @NonNull + @CheckReturnValue + public OngoingWhereClause or() { + return where(LogicalRelation.OR); + } + + /** Creates new sub-condition in the WHERE clause, surrounded by parenthesis. 
*/ + @NonNull + public static DefaultSubConditionRelation subCondition() { + return new DefaultSubConditionRelation(true); + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/override/ExtendedSelect.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/override/ExtendedSelect.java new file mode 100644 index 0000000000..b993118b58 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/override/ExtendedSelect.java @@ -0,0 +1,90 @@ +package io.stargate.sgv2.jsonapi.service.cqldriver.override; + +import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.metadata.schema.ClusteringOrder; +import com.datastax.oss.driver.api.querybuilder.BindMarker; +import com.datastax.oss.driver.api.querybuilder.select.SelectFrom; +import com.datastax.oss.driver.internal.querybuilder.CqlHelper; +import com.datastax.oss.driver.internal.querybuilder.select.DefaultSelect; +import edu.umd.cs.findbugs.annotations.NonNull; +import edu.umd.cs.findbugs.annotations.Nullable; +import java.util.Map; + +/** + * This class is to add AND/OR ability to driver DefaultSelect, and it is a temporary override + * solution, should be removed once the driver supports it natively. Note, although AND/OR has been + * implemented in DataStax Astra, it is not yet in Apache Cassandra. So the community Java driver + * does not support this as of now. See Ticket CASSJAVA-47. 
+ */ +public class ExtendedSelect extends DefaultSelect { + + public ExtendedSelect(@Nullable CqlIdentifier keyspace, @NonNull CqlIdentifier table) { + super(keyspace, table); + } + + @NonNull + public static SelectFrom selectFrom( + @Nullable CqlIdentifier keyspace, @NonNull CqlIdentifier table) { + return new ExtendedSelect(keyspace, table); + } + + @NonNull + @Override + public String asCql() { + StringBuilder builder = new StringBuilder(); + + builder.append("SELECT"); + if (isJson()) { + builder.append(" JSON"); + } + if (isDistinct()) { + builder.append(" DISTINCT"); + } + + CqlHelper.append(getSelectors(), builder, " ", ",", null); + + builder.append(" FROM "); + CqlHelper.qualify(getKeyspace(), getTable(), builder); + + // Note, this is the specific override to apply AND/OR to the where clause + DefaultSubConditionRelation.appendWhereClause(builder, getRelations(), false); + + CqlHelper.append(getGroupByClauses(), builder, " GROUP BY ", ",", null); + + boolean first = true; + for (Map.Entry entry : getOrderings().entrySet()) { + if (first) { + builder.append(" ORDER BY "); + first = false; + } else { + builder.append(","); + } + builder.append(entry.getKey().asCql(true)).append(" ").append(entry.getValue().name()); + } + + if (getLimit() != null) { + builder.append(" LIMIT "); + if (getLimit() instanceof BindMarker) { + ((BindMarker) getLimit()).appendTo(builder); + } else { + builder.append(getLimit()); + } + } + + if (getPerPartitionLimit() != null) { + builder.append(" PER PARTITION LIMIT "); + if (getPerPartitionLimit() instanceof BindMarker) { + ((BindMarker) getPerPartitionLimit()).appendTo(builder); + } else { + builder.append(getPerPartitionLimit()); + } + } + + if (allowsFiltering()) { + builder.append(" ALLOW FILTERING"); + } + + return builder.toString(); + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/override/LogicalRelation.java 
b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/override/LogicalRelation.java new file mode 100644 index 0000000000..627a296b68 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/override/LogicalRelation.java @@ -0,0 +1,36 @@ +package io.stargate.sgv2.jsonapi.service.cqldriver.override; + +import com.datastax.oss.driver.api.querybuilder.relation.Relation; +import com.datastax.oss.driver.shaded.guava.common.base.Preconditions; +import edu.umd.cs.findbugs.annotations.NonNull; +import net.jcip.annotations.Immutable; + +/** + * This class is to add AND/OR relation support to java driver, and it is a temporary override + * solution, should be removed once the driver supports it natively. Note, although AND/OR has been + * implemented in DataStax Astra, it is not yet in Apache Cassandra. So the community Java driver + * does not support this as of now. See Ticket CASSJAVA-47. + */ +@Immutable +public class LogicalRelation implements Relation { + public static final LogicalRelation AND = new LogicalRelation("AND"); + public static final LogicalRelation OR = new LogicalRelation("OR"); + + private final String operator; + + public LogicalRelation(@NonNull String operator) { + Preconditions.checkNotNull(operator); + this.operator = operator; + } + + @Override + public void appendTo(@NonNull StringBuilder builder) { + builder.append(operator); + } + + @Override + public boolean isIdempotent() { + return true; + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizerService.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizerService.java index 524a6d8e21..33e3bdd551 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizerService.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizerService.java @@ -86,7 +86,12 @@ public DataVectorizer constructDataVectorizer( return new DataVectorizer( embeddingProvider, 
objectMapper.getNodeFactory(), - commandContext.requestContext().getEmbeddingCredentials(), + commandContext + .requestContext() + .getEmbeddingCredentialsSupplier() + .create( + commandContext.requestContext(), + embeddingProvider == null ? null : embeddingProvider.getProviderConfig()), commandContext.schemaObject()); } @@ -94,7 +99,7 @@ private Uni vectorizeSortClause( DataVectorizer dataVectorizer, CommandContext commandContext, Command command) { if (command instanceof Sortable sortable) { - return dataVectorizer.vectorize(sortable.sortClause()); + return dataVectorizer.vectorize(sortable.sortClause(commandContext)); } return Uni.createFrom().item(true); } @@ -246,13 +251,13 @@ List tasksForVectorizeColumns( private List tasksForSort( Sortable command, CommandContext commandContext) { - var sortClause = command.sortClause(); + var sortClause = command.sortClause(commandContext); // because this is coming off the command may be null or empty if (sortClause == null || sortClause.isEmpty()) { return List.of(); } - var vectorizeSorts = command.sortClause().tableVectorizeSorts(); + var vectorizeSorts = sortClause.tableVectorizeSorts(); if (vectorizeSorts.isEmpty()) { return List.of(); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/EmbeddingApiKeyResolverProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/EmbeddingApiKeyResolverProvider.java deleted file mode 100644 index cbc3c9b33a..0000000000 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/EmbeddingApiKeyResolverProvider.java +++ /dev/null @@ -1,27 +0,0 @@ -package io.stargate.sgv2.jsonapi.service.embedding; - -import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentialsResolver; -import io.stargate.sgv2.jsonapi.api.request.HeaderBasedEmbeddingCredentialsResolver; -import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; -import jakarta.enterprise.context.ApplicationScoped; -import jakarta.inject.Inject; -import jakarta.inject.Singleton; 
-import jakarta.ws.rs.Produces; - -/** - * Simple CDI producer for the {@link EmbeddingCredentialsResolver} to be used in the embedding - * service - */ -@Singleton -public class EmbeddingApiKeyResolverProvider { - @Inject HttpConstants httpConstants; - - @Produces - @ApplicationScoped - EmbeddingCredentialsResolver headerTokenResolver() { - return new HeaderBasedEmbeddingCredentialsResolver( - httpConstants.embeddingApiKey(), - httpConstants.embeddingAccessId(), - httpConstants.embeddingSecretId()); - } -} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProviderConfigStore.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProviderConfigStore.java index db94a35c4a..da7bd13bc4 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProviderConfigStore.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProviderConfigStore.java @@ -38,9 +38,13 @@ public static ServiceConfig custom(Optional> implementationClass) { public String getBaseUrl(String modelName) { if (modelUrlOverrides != null && modelUrlOverrides.get(modelName) == null) { - throw ErrorCodeV1.VECTORIZE_MODEL_DEPRECATED.toApiException( - "Model %s is deprecated, supported models for provider '%s' are %s", - modelName, serviceName, modelUrlOverrides.keySet()); + // modelUrlOverride is a work-around for self-hosted nvidia models with different url. + // This is bad, initial design should have url in model level instead of provider level. + // As best practice, when we deprecate or EOL a model: + // we must mark the status in the configuration, + // instead of removing the whole configuration entry. + throw ErrorCodeV1.VECTORIZE_SERVICE_TYPE_UNAVAILABLE.toApiException( + "unknown model '%s' for service provider '%s'", modelName, serviceProvider); } return modelUrlOverrides != null ? 
modelUrlOverrides.get(modelName).orElse(baseUrl) : baseUrl; } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProvidersConfig.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProvidersConfig.java index 766efc8ec1..ee038a5e77 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProvidersConfig.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProvidersConfig.java @@ -5,6 +5,7 @@ import io.smallrye.config.WithConverter; import io.smallrye.config.WithDefault; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; +import io.stargate.sgv2.jsonapi.service.provider.ApiModelSupport; import jakarta.annotation.Nullable; import jakarta.inject.Inject; import java.util.List; @@ -30,11 +31,14 @@ interface EmbeddingProviderConfig { @JsonProperty Optional url(); + /** a boolean to flag if the Astra token should be passed through to the provider. */ + @JsonProperty + @WithDefault("false") + boolean authTokenPassThroughForNoneAuth(); + /** * A map of supported authentications. HEADER, SHARED_SECRET and NONE are the only techniques * the DataAPI supports (i.e. the key of map can only be HEADER, SHARED_SECRET or NONE). - * - * @return */ @JsonProperty Map supportedAuthentications(); @@ -105,11 +109,17 @@ interface ModelConfig { @JsonProperty String name(); + /** + * apiModelSupport marks the support status of the model and optional message for the + * deprecation, EOL etc. By default, apiModelSupport will be mapped to SUPPORTED and empty + * message if it is not configured in the config source. + */ + @JsonProperty + ApiModelSupport apiModelSupport(); + /** * vectorDimension is not null if the model supports a single dimension value. It will be null * if the model supports different dimensions. A parameter called vectorDimension is included. 
- * - * @return */ @Nullable @JsonProperty diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProvidersConfigImpl.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProvidersConfigImpl.java index c8fb6e84b9..601d5efeea 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProvidersConfigImpl.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProvidersConfigImpl.java @@ -1,6 +1,7 @@ package io.stargate.sgv2.jsonapi.service.embedding.configuration; import io.stargate.embedding.gateway.EmbeddingGateway; +import io.stargate.sgv2.jsonapi.service.provider.ApiModelSupport; import java.util.*; import java.util.stream.Collectors; @@ -12,6 +13,7 @@ public record EmbeddingProviderConfigImpl( String displayName, boolean enabled, Optional url, + boolean authTokenPassThroughForNoneAuth, Map supportedAuthentications, List parameters, RequestProperties properties, @@ -26,6 +28,7 @@ public record TokenConfigImpl(String accepted, String forwarded) implements Toke public record ModelConfigImpl( String name, + ApiModelSupport apiModelSupport, Optional vectorDimension, List parameters, Map properties, @@ -37,6 +40,12 @@ public ModelConfigImpl( List modelParameterList) { this( grpcModelConfig.getName(), + new ApiModelSupport.ApiModelSupportImpl( + ApiModelSupport.SupportStatus.valueOf( + grpcModelConfig.getApiModelSupport().getStatus()), + grpcModelConfig.getApiModelSupport().hasMessage() + ? Optional.of(grpcModelConfig.getApiModelSupport().getMessage()) + : Optional.empty()), grpcModelConfig.hasVectorDimension() ? 
Optional.of(grpcModelConfig.getVectorDimension()) : Optional.empty(), diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProvidersConfigProducer.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProvidersConfigProducer.java index fb7dfa5e66..b0c078c586 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProvidersConfigProducer.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProvidersConfigProducer.java @@ -144,6 +144,7 @@ private EmbeddingProvidersConfig grpcResponseToConfig( grpcProviderConfig.hasUrl() ? Optional.of(grpcProviderConfig.getUrl()) : Optional.empty(), + grpcProviderConfig.getAuthTokenPassThroughForNoneAuth(), supportedAuthenticationsMap, providerParameterList, providerRequestProperties, diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/PropertyBasedEmbeddingProviderConfigStore.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/PropertyBasedEmbeddingProviderConfigStore.java index 58976c3cc5..4d2d31ff1e 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/PropertyBasedEmbeddingProviderConfigStore.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/PropertyBasedEmbeddingProviderConfigStore.java @@ -31,6 +31,7 @@ public EmbeddingProviderConfigStore.ServiceConfig getConfiguration( || !config.providers().get(serviceName).enabled()) { throw ErrorCodeV1.VECTORIZE_SERVICE_TYPE_UNAVAILABLE.toApiException(serviceName); } + final var properties = config.providers().get(serviceName).properties(); Map> modelwiseServiceUrlOverrides = Objects.requireNonNull(config.providers().get(serviceName).models()).stream() diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AwsBedrockEmbeddingProvider.java 
b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AwsBedrockEmbeddingProvider.java index b2453b55bd..54b0956385 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AwsBedrockEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AwsBedrockEmbeddingProvider.java @@ -14,6 +14,7 @@ import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import jakarta.ws.rs.core.Response; @@ -35,17 +36,17 @@ public class AwsBedrockEmbeddingProvider extends EmbeddingProvider { private static final ObjectReader OBJECT_READER = new ObjectMapper().reader(); public AwsBedrockEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties requestProperties, + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig, String baseUrl, - String modelName, + EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig modelConfig, int dimension, Map vectorizeServiceParameters) { super( ModelProvider.BEDROCK, - requestProperties, + providerConfig, baseUrl, - modelName, - acceptsTitanAIDimensions(modelName) ? dimension : 0, + modelConfig, + acceptsTitanAIDimensions(modelConfig.name()) ? 
dimension : 0, vectorizeServiceParameters); } @@ -62,7 +63,9 @@ public Uni vectorize( EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { - // the config shoudl mean we only do a batch on 1, sanity checking + checkEOLModelUsage(); + + // the config should mean we only do a batch on 1, sanity checking if (texts.size() != 1) { throw new IllegalArgumentException( "AWS Bedrock embedding provider only supports a single text input per request, but received: " diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AzureOpenAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AzureOpenAIEmbeddingProvider.java index ae6c28d917..11a5e164ce 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AzureOpenAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AzureOpenAIEmbeddingProvider.java @@ -5,8 +5,8 @@ import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; @@ -15,14 +15,15 @@ import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; import jakarta.ws.rs.core.Response; +import org.eclipse.microprofile.rest.client.annotation.ClientHeaderParam; +import org.eclipse.microprofile.rest.client.annotation.RegisterProvider; +import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; + import 
java.net.URI; import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; -import org.eclipse.microprofile.rest.client.annotation.ClientHeaderParam; -import org.eclipse.microprofile.rest.client.annotation.RegisterProvider; -import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; /** * Implementation of client that talks to Azure-deployed OpenAI embedding provider. See vectorizeServiceParameters) { // One special case: legacy "ada-002" model does not accept "dimension" parameter super( ModelProvider.AZURE_OPENAI, - requestProperties, + providerConfig, baseUrl, - modelName, - acceptsOpenAIDimensions(modelName) ? dimension : 0, + modelConfig, + acceptsOpenAIDimensions(modelConfig.name()) ? dimension : 0, vectorizeServiceParameters); String actualUrl = replaceParameters(baseUrl, vectorizeServiceParameters); azureClient = QuarkusRestClientBuilder.newBuilder() .baseUri(URI.create(actualUrl)) - .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) + .readTimeout(providerConfig.properties().readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(AzureOpenAIEmbeddingProviderClient.class); } @@ -80,6 +81,7 @@ public Uni vectorize( EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { + checkEOLModelUsage(); checkEmbeddingApiKeyHeader(embeddingCredentials.apiKey()); var azureRequest = new AzureOpenAIEmbeddingRequest( @@ -132,7 +134,8 @@ public Uni vectorize( @RegisterRestClient @RegisterProvider(EmbeddingProviderResponseValidation.class) @RegisterProvider(ProviderHttpInterceptor.class) - public interface AzureOpenAIEmbeddingProviderClient { + public interface + AzureOpenAIEmbeddingProviderClient { // no path specified, as it is already included in the baseUri @POST @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/CohereEmbeddingProvider.java 
b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/CohereEmbeddingProvider.java index 325ba8a803..19ab3737f3 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/CohereEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/CohereEmbeddingProvider.java @@ -9,6 +9,7 @@ import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; @@ -35,23 +36,23 @@ public class CohereEmbeddingProvider extends EmbeddingProvider { private final CohereEmbeddingProviderClient cohereClient; public CohereEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties requestProperties, + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig, String baseUrl, - String modelName, + EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig modelConfig, int dimension, Map vectorizeServiceParameters) { super( ModelProvider.COHERE, - requestProperties, + providerConfig, baseUrl, - modelName, + modelConfig, dimension, vectorizeServiceParameters); cohereClient = QuarkusRestClientBuilder.newBuilder() .baseUri(URI.create(baseUrl)) - .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) + .readTimeout(providerConfig.properties().readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(CohereEmbeddingProviderClient.class); } @@ -98,6 +99,7 @@ public Uni vectorize( EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { + checkEOLModelUsage(); 
checkEmbeddingApiKeyHeader(embeddingCredentials.apiKey()); // Input type to be used for vector search should "search_query" diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProvider.java index 4bbbf559d7..5373df4b13 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProvider.java @@ -8,11 +8,8 @@ import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.exception.JsonApiException; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; -import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; -import io.stargate.sgv2.jsonapi.service.provider.ModelType; -import io.stargate.sgv2.jsonapi.service.provider.ModelUsage; -import io.stargate.sgv2.jsonapi.service.provider.ProviderBase; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; +import io.stargate.sgv2.jsonapi.service.provider.*; import io.stargate.sgv2.jsonapi.util.recordable.Recordable; import jakarta.ws.rs.core.Response; import java.time.Duration; @@ -27,31 +24,44 @@ public abstract class EmbeddingProvider extends ProviderBase { protected static final Logger LOGGER = LoggerFactory.getLogger(EmbeddingProvider.class); - protected final EmbeddingProviderConfigStore.RequestProperties requestProperties; + protected final EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig; protected final String baseUrl; + protected final EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig modelConfig; protected final int dimension; protected final Map vectorizeServiceParameters; + protected final Duration initialBackOffDuration; protected final Duration maxBackOffDuration; 
protected EmbeddingProvider( ModelProvider modelProvider, - EmbeddingProviderConfigStore.RequestProperties requestProperties, + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig, String baseUrl, - String modelName, + EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig modelConfig, int dimension, Map vectorizeServiceParameters) { - super(modelProvider, ModelType.EMBEDDING, modelName); + super(modelProvider, ModelType.EMBEDDING); - this.requestProperties = requestProperties; + this.providerConfig = providerConfig; this.baseUrl = baseUrl; - + this.modelConfig = modelConfig; this.dimension = dimension; this.vectorizeServiceParameters = vectorizeServiceParameters; - this.initialBackOffDuration = Duration.ofMillis(requestProperties.initialBackOffMillis()); - this.maxBackOffDuration = Duration.ofMillis(requestProperties.maxBackOffMillis()); + + this.initialBackOffDuration = Duration.ofMillis(providerConfig.properties().initialBackOffMillis()); + this.maxBackOffDuration = Duration.ofMillis(providerConfig.properties().maxBackOffMillis()); + } + + @Override + public String modelName() { + return modelConfig.name(); + } + + @Override + public ApiModelSupport modelSupport() { + return modelConfig.apiModelSupport(); } /** @@ -74,7 +84,7 @@ public abstract Uni vectorize( * @return */ public int maxBatchSize() { - return requestProperties.maxBatchSize(); + return providerConfig.properties().maxBatchSize(); } /** @@ -156,6 +166,7 @@ protected void checkEmbeddingApiKeyHeader(Optional apiKey) { } } + @Override protected Duration initialBackOffDuration() { return initialBackOffDuration; @@ -168,12 +179,12 @@ protected Duration maxBackOffDuration() { @Override protected double jitter() { - return requestProperties.jitter(); + return providerConfig.properties().jitter(); } @Override protected int atMostRetries() { - return requestProperties.atMostRetries(); + return providerConfig.properties().atMostRetries(); } @Override diff --git 
a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderFactory.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderFactory.java index 0d79cb51cb..dd7d8267aa 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderFactory.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderFactory.java @@ -5,6 +5,11 @@ import io.stargate.sgv2.jsonapi.config.OperationsConfig; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; +<<<<<<< HEAD +======= +import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; +>>>>>>> main import io.stargate.sgv2.jsonapi.service.embedding.gateway.EmbeddingGatewayClient; import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import jakarta.enterprise.context.ApplicationScoped; @@ -17,8 +22,13 @@ public class EmbeddingProviderFactory { @Inject Instance embeddingProviderConfigStore; +<<<<<<< HEAD @Inject OperationsConfig operationsConfig; +======= + @Inject EmbeddingProvidersConfig embeddingProvidersConfig; + @Inject OperationsConfig config; +>>>>>>> main @GrpcClient("embedding") EmbeddingService embeddingService; @@ -28,9 +38,10 @@ interface ProviderConstructor { EmbeddingProvider create( EmbeddingProviderConfigStore.RequestProperties requestProperties, String baseUrl, - String modelName, + EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig model, int dimension, - Map vectorizeServiceParameter); + Map vectorizeServiceParameter, + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig); } private static final Map EMBEDDING_PROVIDER_CTORS = @@ -90,10 +101,17 @@ private synchronized EmbeddingProvider addService( Map authentication, String commandName) { +<<<<<<< HEAD 
final EmbeddingProviderConfigStore.ServiceConfig serviceConfig = embeddingProviderConfigStore.get().getConfiguration(tenant, modelProvider.apiName()); if (operationsConfig.enableEmbeddingGateway()) { +======= + final EmbeddingProviderConfigStore.ServiceConfig configuration = + embeddingProviderConfigStore.get().getConfiguration(tenant, serviceName); + + if (config.enableEmbeddingGateway()) { +>>>>>>> main return new EmbeddingGatewayClient( serviceConfig.requestConfiguration(), modelProvider, @@ -108,9 +126,16 @@ private synchronized EmbeddingProvider addService( commandName); } +<<<<<<< HEAD if (serviceConfig.serviceProvider().equals(ModelProvider.CUSTOM.apiName())) { Optional> clazz = serviceConfig.implementationClass(); if (clazz.isEmpty()) { +======= + // CUSTOM is for test only + if (configuration.serviceProvider().equals(ProviderConstants.CUSTOM)) { + Optional> clazz = configuration.implementationClass(); + if (!clazz.isPresent()) { +>>>>>>> main throw ErrorCodeV1.VECTORIZE_SERVICE_TYPE_UNAVAILABLE.toApiException( "custom class undefined"); } @@ -140,11 +165,35 @@ private synchronized EmbeddingProvider addService( "ModelProvider does not have a constructor: " + serviceConfigModelProvider.apiName()); } +<<<<<<< HEAD return ctor.create( serviceConfig.requestConfiguration(), serviceConfig.getBaseUrl(modelName), modelName, +======= + // Get the provider, then get the model. 
+ var providerConfig = embeddingProvidersConfig.providers().get(configuration.serviceProvider()); + if (providerConfig == null) { + throw ErrorCodeV1.VECTORIZE_SERVICE_TYPE_UNAVAILABLE.toApiException( + "unknown service provider '%s'", configuration.serviceProvider()); + } + EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig model = + embeddingProvidersConfig.providers().get(configuration.serviceProvider()).models().stream() + .filter(m -> m.name().equals(modelName)) + .findFirst() + .orElseThrow( + () -> + ErrorCodeV1.VECTORIZE_SERVICE_TYPE_UNAVAILABLE.toApiException( + "unknown model '%s' for service provider '%s'", + modelName, configuration.serviceProvider())); + + return ctor.create( + configuration.requestConfiguration(), + configuration.getBaseUrl(modelName), + model, +>>>>>>> main dimension, - vectorizeServiceParameters); + vectorizeServiceParameters, + providerConfig); } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceDedicatedEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceDedicatedEmbeddingProvider.java index 4325c53519..193754e9fe 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceDedicatedEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceDedicatedEmbeddingProvider.java @@ -7,6 +7,7 @@ import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; @@ -30,16 +31,16 @@ 
public class HuggingFaceDedicatedEmbeddingProvider extends EmbeddingProvider { private final HuggingFaceDedicatedEmbeddingProviderClient huggingFaceClient; public HuggingFaceDedicatedEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties requestProperties, + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig, String baseUrl, - String modelName, + EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig modelConfig, int dimension, Map vectorizeServiceParameters) { super( ModelProvider.HUGGINGFACE_DEDICATED, - requestProperties, + providerConfig, baseUrl, - modelName, + modelConfig, dimension, vectorizeServiceParameters); @@ -48,7 +49,7 @@ public HuggingFaceDedicatedEmbeddingProvider( huggingFaceClient = QuarkusRestClientBuilder.newBuilder() .baseUri(URI.create(dedicatedApiUrl)) - .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) + .readTimeout(providerConfig.properties().readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(HuggingFaceDedicatedEmbeddingProviderClient.class); } @@ -78,6 +79,7 @@ public Uni vectorize( EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { + checkEOLModelUsage(); checkEmbeddingApiKeyHeader(embeddingCredentials.apiKey()); var huggingFaceRequest = diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceEmbeddingProvider.java index 4927de0209..c42733df3a 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceEmbeddingProvider.java @@ -6,6 +6,7 @@ import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import 
io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; @@ -27,23 +28,23 @@ public class HuggingFaceEmbeddingProvider extends EmbeddingProvider { private final HuggingFaceEmbeddingProviderClient huggingFaceClient; public HuggingFaceEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties requestProperties, + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig, String baseUrl, - String modelName, + EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig modelConfig, int dimension, Map vectorizeServiceParameters) { super( ModelProvider.HUGGINGFACE, - requestProperties, + providerConfig, baseUrl, - modelName, + modelConfig, dimension, vectorizeServiceParameters); huggingFaceClient = QuarkusRestClientBuilder.newBuilder() .baseUri(URI.create(baseUrl)) - .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) + .readTimeout(providerConfig.properties().readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(HuggingFaceEmbeddingProviderClient.class); } @@ -68,6 +69,7 @@ public Uni vectorize( EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { + checkEOLModelUsage(); checkEmbeddingApiKeyHeader(embeddingCredentials.apiKey()); var huggingFaceRequest = new HuggingFaceEmbeddingRequest(texts, new HuggingFaceEmbeddingRequest.Options(true)); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/JinaAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/JinaAIEmbeddingProvider.java index c6b7a33550..bf1102130a 100644 --- 
a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/JinaAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/JinaAIEmbeddingProvider.java @@ -8,6 +8,7 @@ import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; @@ -33,23 +34,23 @@ public class JinaAIEmbeddingProvider extends EmbeddingProvider { private final JinaAIEmbeddingProviderClient jinaClient; public JinaAIEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties requestProperties, + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig, String baseUrl, - String modelName, + EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig modelConfig, int dimension, Map vectorizeServiceParameters) { super( ModelProvider.JINA_AI, - requestProperties, + providerConfig, baseUrl, - modelName, - acceptsJinaAIDimensions(modelName) ? dimension : 0, + modelConfig, + acceptsJinaAIDimensions(modelConfig.name()) ? 
dimension : 0, vectorizeServiceParameters); jinaClient = QuarkusRestClientBuilder.newBuilder() .baseUri(URI.create(baseUrl)) - .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) + .readTimeout(providerConfig.properties().readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(JinaAIEmbeddingProviderClient.class); } @@ -78,6 +79,7 @@ public Uni vectorize( EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { + checkEOLModelUsage(); checkEmbeddingApiKeyHeader(embeddingCredentials.apiKey()); var jinaRequest = diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MeteredEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MeteredEmbeddingProvider.java index ffb2fd1019..b1bc13ef07 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MeteredEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MeteredEmbeddingProvider.java @@ -1,5 +1,8 @@ package io.stargate.sgv2.jsonapi.service.embedding.operation; +import static io.stargate.sgv2.jsonapi.metrics.MetricsConstants.MetricTags.TENANT_TAG; +import static io.stargate.sgv2.jsonapi.metrics.MetricsConstants.UNKNOWN_VALUE; + import com.google.common.collect.Lists; import io.micrometer.core.instrument.*; import io.smallrye.mutiny.Multi; @@ -7,7 +10,11 @@ import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.api.request.RequestContext; import io.stargate.sgv2.jsonapi.api.v1.metrics.JsonApiMetricsConfig; +<<<<<<< HEAD import io.stargate.sgv2.jsonapi.service.provider.ModelUsage; +======= +import io.stargate.sgv2.jsonapi.metrics.MetricsConstants; +>>>>>>> main import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -138,7 +145,7 @@ public Uni vectorize( () -> sample.stop( meterRegistry.timer( - jsonApiMetricsConfig.vectorizeCallDurationMetrics(), tags))); + 
MetricsConstants.MetricNames.VECTORIZE_CALL_DURATION_METRIC, tags))); } @Override @@ -156,7 +163,11 @@ public int maxBatchSize() { */ private Tags getCustomTags() { Tag commandTag = Tag.of(jsonApiMetricsConfig.command(), commandName); +<<<<<<< HEAD Tag tenantTag = Tag.of("tenant", requestContext.getTenantId().orElse(UNKNOWN_TENANT_ID)); +======= + Tag tenantTag = Tag.of(TENANT_TAG, requestContext.getTenantId().orElse(UNKNOWN_VALUE)); +>>>>>>> main Tag embeddingProviderTag = Tag.of( jsonApiMetricsConfig.embeddingProvider(), embeddingProvider.getClass().getSimpleName()); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MistralEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MistralEmbeddingProvider.java index 6fd5198130..c5e71f18fe 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MistralEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MistralEmbeddingProvider.java @@ -7,6 +7,7 @@ import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; @@ -32,23 +33,23 @@ public class MistralEmbeddingProvider extends EmbeddingProvider { private final MistralEmbeddingProviderClient mistralClient; public MistralEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties requestProperties, + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig, String baseUrl, - String modelName, + 
EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig modelConfig, int dimension, Map vectorizeServiceParameters) { super( ModelProvider.MISTRAL, - requestProperties, + providerConfig, baseUrl, - modelName, + modelConfig, dimension, vectorizeServiceParameters); mistralClient = QuarkusRestClientBuilder.newBuilder() .baseUri(URI.create(baseUrl)) - .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) + .readTimeout(providerConfig.properties().readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(MistralEmbeddingProviderClient.class); } @@ -82,6 +83,7 @@ public Uni vectorize( EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { + checkEOLModelUsage(); checkEmbeddingApiKeyHeader(embeddingCredentials.apiKey()); var mistralRequest = new MistralEmbeddingRequest(texts, modelName(), "float"); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NvidiaEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NvidiaEmbeddingProvider.java index 36fa1f076c..2e9e377aed 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NvidiaEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NvidiaEmbeddingProvider.java @@ -7,6 +7,7 @@ import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; @@ -33,23 +34,23 @@ public class NvidiaEmbeddingProvider extends EmbeddingProvider { private final NvidiaEmbeddingProviderClient 
nvidiaClient; public NvidiaEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties requestProperties, + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig, String baseUrl, - String modelName, + EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig modelConfig, int dimension, Map vectorizeServiceParameters) { super( ModelProvider.NVIDIA, - requestProperties, + providerConfig, baseUrl, - modelName, + modelConfig, dimension, vectorizeServiceParameters); nvidiaClient = QuarkusRestClientBuilder.newBuilder() .baseUri(URI.create(baseUrl)) - .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) + .readTimeout(providerConfig.properties().readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(NvidiaEmbeddingProviderClient.class); } @@ -77,6 +78,7 @@ public Uni vectorize( EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { + checkEOLModelUsage(); var input_type = embeddingRequestType == EmbeddingRequestType.INDEX ? 
"passage" : "query"; var nvidiaRequest = new NvidiaEmbeddingRequest( diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAIEmbeddingProvider.java index b234e58c3e..28fa557361 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAIEmbeddingProvider.java @@ -8,6 +8,7 @@ import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; @@ -31,24 +32,24 @@ public class OpenAIEmbeddingProvider extends EmbeddingProvider { private final OpenAIEmbeddingProviderClient openAIClient; public OpenAIEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties requestProperties, + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig, String baseUrl, - String modelName, + EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig modelConfig, int dimension, Map vectorizeServiceParameters) { // One special case: legacy "ada-002" model does not accept "dimension" parameter super( ModelProvider.OPENAI, - requestProperties, + providerConfig, baseUrl, - modelName, - acceptsOpenAIDimensions(modelName) ? dimension : 0, + modelConfig, + acceptsOpenAIDimensions(modelConfig.name()) ? 
dimension : 0, vectorizeServiceParameters); openAIClient = QuarkusRestClientBuilder.newBuilder() .baseUri(URI.create(baseUrl)) - .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) + .readTimeout(providerConfig.properties().readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(OpenAIEmbeddingProviderClient.class); } @@ -80,6 +81,7 @@ public Uni vectorize( EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { + checkEOLModelUsage(); checkEmbeddingApiKeyHeader(embeddingCredentials.apiKey()); var openAiRequest = diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/UpstageAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/UpstageAIEmbeddingProvider.java index b374d4a542..2bd65a5e58 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/UpstageAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/UpstageAIEmbeddingProvider.java @@ -9,6 +9,7 @@ import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; @@ -35,24 +36,24 @@ public class UpstageAIEmbeddingProvider extends EmbeddingProvider { private final UpstageAIEmbeddingProviderClient upstageClient; public UpstageAIEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties requestProperties, + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig, String baseUrl, - String modelNamePrefix, + 
EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig modelConfig, int dimension, Map vectorizeServiceParameters) { super( ModelProvider.UPSTAGE_AI, - requestProperties, + providerConfig, baseUrl, - modelNamePrefix, + modelConfig, dimension, vectorizeServiceParameters); - this.modelNamePrefix = modelNamePrefix; + this.modelNamePrefix = modelConfig.name(); upstageClient = QuarkusRestClientBuilder.newBuilder() .baseUri(URI.create(baseUrl)) - .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) + .readTimeout(providerConfig.properties().readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(UpstageAIEmbeddingProviderClient.class); } @@ -107,6 +108,7 @@ public Uni vectorize( EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { + checkEOLModelUsage(); checkEmbeddingApiKeyHeader(embeddingCredentials.apiKey()); // Oddity: Implementation does not support batching, so we only accept "batches" diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VertexAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VertexAIEmbeddingProvider.java index 84fc377900..c7873a794c 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VertexAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VertexAIEmbeddingProvider.java @@ -9,6 +9,7 @@ import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; @@ -33,24 +34,24 
@@ public class VertexAIEmbeddingProvider extends EmbeddingProvider { private final VertexAIEmbeddingProviderClient vertexClient; public VertexAIEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties requestProperties, + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig, String baseUrl, - String modelName, + EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig modelConfig, int dimension, - Map serviceParameters) { + Map vectorizeServiceParameters) { super( ModelProvider.VERTEXAI, - requestProperties, + providerConfig, baseUrl, - modelName, + modelConfig, dimension, - serviceParameters); + vectorizeServiceParameters); - String actualUrl = replaceParameters(baseUrl, serviceParameters); + String actualUrl = replaceParameters(baseUrl, vectorizeServiceParameters); vertexClient = QuarkusRestClientBuilder.newBuilder() .baseUri(URI.create(actualUrl)) - .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) + .readTimeout(providerConfig.properties().readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(VertexAIEmbeddingProviderClient.class); } @@ -76,6 +77,7 @@ public Uni vectorize( EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { + checkEOLModelUsage(); checkEmbeddingApiKeyHeader(embeddingCredentials.apiKey()); var vertexRequest = diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VoyageAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VoyageAIEmbeddingProvider.java index 5c87851e97..c6934af226 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VoyageAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VoyageAIEmbeddingProvider.java @@ -8,6 +8,7 @@ import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import 
io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; @@ -33,22 +34,22 @@ public class VoyageAIEmbeddingProvider extends EmbeddingProvider { private final Boolean autoTruncate; public VoyageAIEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties requestProperties, + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig, String baseUrl, - String modelName, + EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig modelConfig, int dimension, - Map serviceParameters) { + Map vectorizeServiceParameters) { super( ModelProvider.VOYAGE_AI, - requestProperties, + providerConfig, baseUrl, - modelName, + modelConfig, dimension, - serviceParameters); + vectorizeServiceParameters); // use configured input_type if available - requestTypeQuery = requestProperties.requestTypeQuery().orElse(null); - requestTypeIndex = requestProperties.requestTypeIndex().orElse(null); + requestTypeQuery = providerConfig.properties().requestTypeQuery().orElse(null); + requestTypeIndex = providerConfig.properties().requestTypeIndex().orElse(null); - Object v = (serviceParameters == null) ? null : serviceParameters.get("autoTruncate"); + Object v = (vectorizeServiceParameters == null) ? null : vectorizeServiceParameters.get("autoTruncate"); autoTruncate = (v instanceof Boolean) ? 
(Boolean) v : null; diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/test/CustomITEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/test/CustomITEmbeddingProvider.java index e54f08065f..af4a5633f6 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/test/CustomITEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/test/CustomITEmbeddingProvider.java @@ -3,10 +3,16 @@ import io.quarkus.runtime.annotations.RegisterForReflection; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; +<<<<<<< HEAD import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProvider; import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +======= +import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfigImpl; +import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProvider; +import io.stargate.sgv2.jsonapi.service.provider.ApiModelSupport; +>>>>>>> main import java.util.*; /** @@ -30,9 +36,10 @@ public class CustomITEmbeddingProvider extends EmbeddingProvider { public static HashMap TEST_DATA_DIMENSION_5 = new HashMap<>(); public static HashMap TEST_DATA_DIMENSION_6 = new HashMap<>(); - private int dimension; + private final int dimension; public CustomITEmbeddingProvider(int dimension) { +<<<<<<< HEAD // aaron 9 June 2025 - refactoring , I think none of the super class is used, so passing dummy // values super( @@ -44,6 +51,23 @@ public CustomITEmbeddingProvider(int dimension) { 1, Map.of()); +======= + // construct the test modelConfig + super( + null, + null, + new EmbeddingProvidersConfigImpl.EmbeddingProviderConfigImpl.ModelConfigImpl( + "testModel", + new 
ApiModelSupport.ApiModelSupportImpl( + ApiModelSupport.SupportStatus.SUPPORTED, Optional.empty()), + Optional.of(dimension), + List.of(), + Map.of(), + Optional.empty()), + dimension, + Map.of(), + null); +>>>>>>> main this.dimension = dimension; } @@ -91,6 +115,12 @@ public Uni vectorize( EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { +<<<<<<< HEAD +======= + // Check if using an EOL model + checkEOLModelUsage(); + +>>>>>>> main List response = new ArrayList<>(texts.size()); if (texts.isEmpty()) { var modelUsage = diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/GenericOperation.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/GenericOperation.java index eeb6aa4143..652fdb7666 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/GenericOperation.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/GenericOperation.java @@ -97,6 +97,7 @@ protected Multi startMulti(CommandContext commandContext) { new CommandQueryExecutor.DBRequestContext( commandContext.requestContext().getTenantId(), commandContext.requestContext().getCassandraToken(), + commandContext.requestContext().getUserAgent(), commandContext.requestTracing().enabled()), CommandQueryExecutor.QueryTarget.TABLE); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/ReadDBTask.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/ReadDBTask.java index b7c3fa9cbd..31079bb7f1 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/ReadDBTask.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/ReadDBTask.java @@ -15,6 +15,7 @@ import io.stargate.sgv2.jsonapi.api.model.command.table.definition.ColumnsDescContainer; import io.stargate.sgv2.jsonapi.exception.WarningException; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.*; +import io.stargate.sgv2.jsonapi.service.cqldriver.override.ExtendedSelect; import 
io.stargate.sgv2.jsonapi.service.operation.query.*; import io.stargate.sgv2.jsonapi.service.operation.tasks.DBTask; import io.stargate.sgv2.jsonapi.service.operation.tasks.TaskRetryPolicy; @@ -171,7 +172,10 @@ protected SimpleStatement buildReadStatement() { List positionalValues = new ArrayList<>(); - var selectFrom = selectFrom(schemaObject.keyspaceName(), schemaObject.tableName()); + // Note, use ExtendedSelect to support AND/OR in where clause, see details in + // ExtendedSelect.java. + var selectFrom = + ExtendedSelect.selectFrom(schemaObject.keyspaceName(), schemaObject.tableName()); var select = applySelect(selectFrom, positionalValues); // these are options that go on the query builder, such as limit or allow filtering var bindableQuery = applyOptions(select); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/ReadDBTaskPage.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/ReadDBTaskPage.java index ac90d3b0ba..fb42881cc7 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/ReadDBTaskPage.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/ReadDBTaskPage.java @@ -108,11 +108,12 @@ private Accumulator sortVector(float[] sortVector) { return this; } - public Accumulator mayReturnVector(CmdT command) { + public Accumulator mayReturnVector( + CommandContext commandContext, CmdT command) { var includeVector = command.includeSortVector().orElse(false); if (includeVector) { var requestedVector = - command.vectorSortExpression().map(SortExpression::vector).orElse(null); + command.vectorSortExpression(commandContext).map(SortExpression::vector).orElse(null); if (requestedVector != null) { this.includeSortVector = true; this.sortVector = requestedVector; diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/builder/BuiltConditionPredicate.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/builder/BuiltConditionPredicate.java index b283cf256a..d4b48352c1 100644 --- 
a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/builder/BuiltConditionPredicate.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/builder/BuiltConditionPredicate.java @@ -10,7 +10,8 @@ public enum BuiltConditionPredicate { IN("IN"), CONTAINS("CONTAINS"), NOT_CONTAINS("NOT CONTAINS"), - CONTAINS_KEY("CONTAINS KEY"); + CONTAINS_KEY("CONTAINS KEY"), + TEXT_SEARCH(":"); public final String cql; diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/collections/CollectionReadOperation.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/collections/CollectionReadOperation.java index 8c16e55f35..5577ee4167 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/collections/CollectionReadOperation.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/collections/CollectionReadOperation.java @@ -12,9 +12,9 @@ import io.smallrye.mutiny.Multi; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.RequestContext; -import io.stargate.sgv2.jsonapi.api.v1.metrics.JsonProcessingMetricsReporter; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.exception.JsonApiException; +import io.stargate.sgv2.jsonapi.metrics.JsonProcessingMetricsReporter; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.QueryExecutor; import io.stargate.sgv2.jsonapi.service.projection.DocumentProjector; import io.stargate.sgv2.jsonapi.service.shredding.collections.DocumentId; diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/collections/CreateCollectionOperation.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/collections/CreateCollectionOperation.java index 97330c525a..570fe4138c 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/collections/CreateCollectionOperation.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/collections/CreateCollectionOperation.java @@ -22,6 +22,7 @@ import 
io.stargate.sgv2.jsonapi.service.schema.EmbeddingSourceModel; import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction; import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionLexicalConfig; +import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionRerankDef; import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject; import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionTableMatcher; import java.time.Duration; @@ -45,7 +46,8 @@ public record CreateCollectionOperation( boolean tooManyIndexesRollbackEnabled, // if true, deny all indexing option is set and no indexes will be created boolean indexingDenyAll, - CollectionLexicalConfig lexicalConfig) + CollectionLexicalConfig lexicalConfig, + CollectionRerankDef rerankDef) implements Operation { private static final Logger logger = LoggerFactory.getLogger(CreateCollectionOperation.class); @@ -65,7 +67,8 @@ public static CreateCollectionOperation withVectorSearch( int ddlDelayMillis, boolean tooManyIndexesRollbackEnabled, boolean indexingDenyAll, - CollectionLexicalConfig lexicalConfig) { + CollectionLexicalConfig lexicalConfig, + CollectionRerankDef rerankDef) { return new CreateCollectionOperation( commandContext, dbLimitsConfig, @@ -80,7 +83,8 @@ public static CreateCollectionOperation withVectorSearch( ddlDelayMillis, tooManyIndexesRollbackEnabled, indexingDenyAll, - Objects.requireNonNull(lexicalConfig)); + Objects.requireNonNull(lexicalConfig), + Objects.requireNonNull(rerankDef)); } public static CreateCollectionOperation withoutVectorSearch( @@ -93,7 +97,8 @@ public static CreateCollectionOperation withoutVectorSearch( int ddlDelayMillis, boolean tooManyIndexesRollbackEnabled, boolean indexingDenyAll, - CollectionLexicalConfig lexicalConfig) { + CollectionLexicalConfig lexicalConfig, + CollectionRerankDef rerankDef) { return new CreateCollectionOperation( commandContext, dbLimitsConfig, @@ -108,18 +113,19 @@ public static 
CreateCollectionOperation withoutVectorSearch( ddlDelayMillis, tooManyIndexesRollbackEnabled, indexingDenyAll, - Objects.requireNonNull(lexicalConfig)); + Objects.requireNonNull(lexicalConfig), + Objects.requireNonNull(rerankDef)); } @Override public Uni> execute( RequestContext dataApiRequestInfo, QueryExecutor queryExecutor) { logger.info( - "Executing CreateCollectionOperation for {}.{} with property {}", + "Executing CreateCollectionOperation for {}.{} with definition: {}", commandContext.schemaObject().name().keyspace(), name, comment); - // validate Data API collection limit guardrail and get tableMetadata + // validate Data API collection limit guard rail and get tableMetadata Map allKeyspaces = cqlSessionCache.getSession(dataApiRequestInfo).getMetadata().getKeyspaces(); KeyspaceMetadata currKeyspace = @@ -136,10 +142,10 @@ public Uni> execute( // if table doesn't exist, continue to create collection if (tableMetadata == null) { - return executeCollectionCreation(dataApiRequestInfo, queryExecutor, false); + return executeCollectionCreation(dataApiRequestInfo, queryExecutor, lexicalConfig(), false); } - // if table exists, compare existedCollectionSettings and newCollectionSettings - CollectionSchemaObject existedCollectionSettings = + // if table exists, compare existingCollectionSettings and newCollectionSettings + CollectionSchemaObject existingCollectionSettings = CollectionSchemaObject.getCollectionSettings(tableMetadata, objectMapper); // Use the fromNameOrDefault() so if not specified it will default @@ -162,11 +168,60 @@ public Uni> execute( embeddingSourceModel, comment, objectMapper); - // if table exists we have two choices: + // If Collection exists we have a choice: // (1) trying to create with same options -> ok, proceed // (2) trying to create with different options -> error out - if (existedCollectionSettings.equals(newCollectionSettings)) { - return executeCollectionCreation(dataApiRequestInfo, queryExecutor, true); + // but before deciding 
(2), we need to consider one specific backwards-compatibility + // case: that of existing pre-lexical/pre-reranking collection, being re-created + // without definitions for lexical/reranking. Although it would create a new + // Collection with both enabled, it should NOT fail if attempted on an existing + // Collection with pre-lexical/pre-reranking settings but silently succeed. + + boolean settingsAreEqual = existingCollectionSettings.equals(newCollectionSettings); + + if (!settingsAreEqual) { + final var oldLexical = existingCollectionSettings.lexicalConfig(); + final var newLexical = lexicalConfig(); + final var oldReranking = existingCollectionSettings.rerankingConfig(); + final var newReranking = rerankDef(); + + // So: for backwards compatibility reasons we may need to override settings if + // (and only if) the collection was created before lexical and reranking. + // In addition, we need to check that new lexical settings are for defaults + // (difficult to check the same for reranking; for now assume that if lexical + // is default, reranking is also default). + if (oldLexical == CollectionLexicalConfig.configForPreLexical() + && newLexical == CollectionLexicalConfig.configForDefault() + && oldReranking == CollectionRerankDef.configForPreRerankingCollection() + && newReranking == CollectionRerankDef.configForDefault()) { + var originalNewSettings = newCollectionSettings; + newCollectionSettings = + newCollectionSettings.withLexicalAndRerankOverrides( + oldLexical, existingCollectionSettings.rerankingConfig()); + // and now re-check if settings are the same + settingsAreEqual = existingCollectionSettings.equals(newCollectionSettings); + logger.info( + "CreateCollectionOperation for {}.{} with existing legacy lexical/reranking settings, new settings differ. 
Tried to unify, result: {}" + + " Old settings: {}, New settings: {}", + commandContext.schemaObject().name().keyspace(), + name, + settingsAreEqual, + existingCollectionSettings, + originalNewSettings); + } else { + logger.info( + "CreateCollectionOperation for {}.{} with different settings (but not old legacy lexical/reranking settings), cannot unify." + + " Old settings: {}, New settings: {}", + commandContext.schemaObject().name().keyspace(), + name, + existingCollectionSettings, + newCollectionSettings); + } + } + + if (settingsAreEqual) { + return executeCollectionCreation( + dataApiRequestInfo, queryExecutor, newCollectionSettings.lexicalConfig(), true); } return Uni.createFrom() .failure( @@ -179,15 +234,25 @@ public Uni> execute( * * @param dataApiRequestInfo DBRequestContext * @param queryExecutor QueryExecutor instance + * @param lexicalConfig Lexical configuration for the collection * @param collectionExisted boolean that says if collection existed before * @return Uni> */ private Uni> executeCollectionCreation( - RequestContext dataApiRequestInfo, QueryExecutor queryExecutor, boolean collectionExisted) { + RequestContext dataApiRequestInfo, + QueryExecutor queryExecutor, + CollectionLexicalConfig lexicalConfig, + boolean collectionExisted) { final Uni execute = queryExecutor.executeCreateSchemaChange( dataApiRequestInfo, - getCreateTable(commandContext.schemaObject().name().keyspace(), name)); + getCreateTable( + commandContext.schemaObject().name().keyspace(), + name, + vectorSearch, + vectorSize, + comment, + lexicalConfig)); final Uni indexResult = execute .onItem() @@ -201,6 +266,7 @@ private Uni> executeCollectionCreation( getIndexStatements( commandContext.schemaObject().name().keyspace(), name, + lexicalConfig, collectionExisted); Multi indexResultMulti; /* @@ -402,9 +468,15 @@ TableMetadata findTableAndValidateLimits( return null; } - public SimpleStatement getCreateTable(String keyspace, String table) { + public static SimpleStatement 
getCreateTable( + String keyspace, + String table, + boolean vectorSearch, + int vectorSize, + String comment, + CollectionLexicalConfig lexicalConfig) { // The keyspace and table name are quoted to make it case-sensitive - final String lexicalField = lexicalConfig().enabled() ? " query_lexical_value text, " : ""; + final String lexicalField = lexicalConfig.enabled() ? " query_lexical_value text, " : ""; if (vectorSearch) { String createTableWithVector = "CREATE TABLE IF NOT EXISTS \"%s\".\"%s\" (" @@ -428,27 +500,26 @@ public SimpleStatement getCreateTable(String keyspace, String table) { createTableWithVector = createTableWithVector + " WITH comment = '" + comment + "'"; } return SimpleStatement.newInstance(String.format(createTableWithVector, keyspace, table)); - } else { - String createTable = - "CREATE TABLE IF NOT EXISTS \"%s\".\"%s\" (" - + " key tuple," - + " tx_id timeuuid, " - + " doc_json text," - + " exist_keys set," - + " array_size map," - + " array_contains set," - + " query_bool_values map," - + " query_dbl_values map," - + " query_text_values map, " - + " query_timestamp_values map, " - + " query_null_values set, " - + lexicalField - + " PRIMARY KEY (key))"; - if (comment != null) { - createTable = createTable + " WITH comment = '" + comment + "'"; - } - return SimpleStatement.newInstance(String.format(createTable, keyspace, table)); } + String createTable = + "CREATE TABLE IF NOT EXISTS \"%s\".\"%s\" (" + + " key tuple," + + " tx_id timeuuid, " + + " doc_json text," + + " exist_keys set," + + " array_size map," + + " array_contains set," + + " query_bool_values map," + + " query_dbl_values map," + + " query_text_values map, " + + " query_timestamp_values map, " + + " query_null_values set, " + + lexicalField + + " PRIMARY KEY (key))"; + if (comment != null) { + createTable = createTable + " WITH comment = '" + comment + "'"; + } + return SimpleStatement.newInstance(String.format(createTable, keyspace, table)); } /* @@ -456,7 +527,10 @@ public 
SimpleStatement getCreateTable(String keyspace, String table) { * For a new table they are run without IF NOT EXISTS. */ public List getIndexStatements( - String keyspace, String table, boolean collectionExisted) { + String keyspace, + String table, + CollectionLexicalConfig lexicalConfig, + boolean collectionExisted) { List statements = new ArrayList<>(10); String appender = collectionExisted ? "CREATE CUSTOM INDEX IF NOT EXISTS" : "CREATE CUSTOM INDEX"; diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/embeddings/EmbeddingTask.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/embeddings/EmbeddingTask.java index 1d045f9dc9..1d5ca4f19f 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/embeddings/EmbeddingTask.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/embeddings/EmbeddingTask.java @@ -74,7 +74,10 @@ protected EmbeddingTask.EmbeddingResultSupplier buildResultSupplier( embeddingProvider.vectorize( 1, // always use 1, microbatching happens in the provider. 
vectorizeTexts, - commandContext.requestContext().getEmbeddingCredentials(), + commandContext + .requestContext() + .getEmbeddingCredentialsSupplier() + .create(commandContext.requestContext(), embeddingProvider.getProviderConfig()), requestType), embeddingActions, vectorizeTexts); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/embeddings/FindEmbeddingProvidersOperation.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/embeddings/FindEmbeddingProvidersOperation.java index bf563da32f..4b763f03d2 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/embeddings/FindEmbeddingProvidersOperation.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/embeddings/FindEmbeddingProvidersOperation.java @@ -3,12 +3,15 @@ import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.model.command.CommandResult; import io.stargate.sgv2.jsonapi.api.model.command.CommandStatus; +import io.stargate.sgv2.jsonapi.api.model.command.impl.FindEmbeddingProvidersCommand; import io.stargate.sgv2.jsonapi.api.model.command.tracing.RequestTracing; import io.stargate.sgv2.jsonapi.api.request.RequestContext; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.QueryExecutor; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; import io.stargate.sgv2.jsonapi.service.operation.Operation; +import io.stargate.sgv2.jsonapi.service.provider.ApiModelSupport; import java.util.*; +import java.util.function.Predicate; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -16,8 +19,8 @@ * Operation that list all available and enabled vector providers into the {@link * CommandStatus#EXISTING_EMBEDDING_PROVIDERS} command status. 
*/ -public record FindEmbeddingProvidersOperation(EmbeddingProvidersConfig config) - implements Operation { +public record FindEmbeddingProvidersOperation( + FindEmbeddingProvidersCommand command, EmbeddingProvidersConfig config) implements Operation { @Override public Uni> execute( RequestContext dataApiRequestInfo, QueryExecutor queryExecutor) { @@ -30,7 +33,9 @@ public Uni> execute( .collect( Collectors.toMap( Map.Entry::getKey, - entry -> EmbeddingProviderResponse.provider(entry.getValue()))); + entry -> + EmbeddingProviderResponse.toResponse( + entry.getValue(), getSupportStatusPredicate()))); return new Result(embeddingProviders); }); } @@ -49,7 +54,7 @@ public CommandResult get() { } /** - * A simplified representation of a vector provider's configuration for API responses. Excludes + * A simplified representation of an embedding provider's configuration for API responses. Excludes * internal properties (retry, timeout etc.) to focus on data relevant to clients, including URL, * authentication methods, and model customization parameters. * @@ -67,56 +72,91 @@ private record EmbeddingProviderResponse( supportedAuthentication, List parameters, List models) { - private static EmbeddingProviderResponse provider( - EmbeddingProvidersConfig.EmbeddingProviderConfig embeddingProviderConfig) { - ArrayList modelsRemoveProperties = new ArrayList<>(); - for (EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig model : - embeddingProviderConfig.models()) { - ModelConfigResponse returnModel = - ModelConfigResponse.returnModelConfigResponse( - model.name(), model.vectorDimension(), model.parameters()); - modelsRemoveProperties.add(returnModel); + + /** + * Constructs an {@link EmbeddingProviderResponse} from the original provider config. It will + * exclude the internal properties (retry, timeout etc.). + * + * @param sourceEmbeddingProviderConfig the original provider config with all properties. 
+ * @param statusPredicate predicate to filter models based on their support status. + */ + private static EmbeddingProviderResponse toResponse( + EmbeddingProvidersConfig.EmbeddingProviderConfig sourceEmbeddingProviderConfig, + Predicate statusPredicate) { + + // if the providerConfig.models is null or empty, return an EmbeddingProviderResponse with + // empty models. + if (sourceEmbeddingProviderConfig.models() == null + || sourceEmbeddingProviderConfig.models().isEmpty()) { + return new EmbeddingProviderResponse( + sourceEmbeddingProviderConfig.displayName(), + sourceEmbeddingProviderConfig.url(), + sourceEmbeddingProviderConfig.supportedAuthentications(), + sourceEmbeddingProviderConfig.parameters(), + Collections.emptyList()); } + + // include models that with apiModelSupport status that user asked for + var modelsFilteredWithStatus = + sourceEmbeddingProviderConfig.models().stream() + .filter(modelConfig -> statusPredicate.test(modelConfig.apiModelSupport().status())) + .toList(); + + // convert each modelConfig to ModelConfigResponse with internal properties excluded + var modelConfigResponses = + modelsFilteredWithStatus.stream() + .map(ModelConfigResponse::toResponse) + .sorted(Comparator.comparing(ModelConfigResponse::name)) + .toList(); + return new EmbeddingProviderResponse( - embeddingProviderConfig.displayName(), - embeddingProviderConfig.url(), - embeddingProviderConfig.supportedAuthentications(), - embeddingProviderConfig.parameters(), - modelsRemoveProperties); + sourceEmbeddingProviderConfig.displayName(), + sourceEmbeddingProviderConfig.url(), + sourceEmbeddingProviderConfig.supportedAuthentications(), + sourceEmbeddingProviderConfig.parameters(), + modelConfigResponses); } } /** - * Configuration details for a model offered by a vector provider, tailored for external clients. - * Includes the model name and parameters for customization, excluding internal properties (retry, - * timeout etc.). 
+ * Model configuration with internal properties excluded. Only includes the model name and + * parameters for customization, excluding internal properties (retry, timeout etc.). * * @param name Identifier name of the model. + * @param apiModelSupport Support status of the model. * @param vectorDimension vector dimension of the model. * @param parameters Parameters for customizing the model. */ private record ModelConfigResponse( - String name, Optional vectorDimension, List parameters) { - private static ModelConfigResponse returnModelConfigResponse( - String name, - Optional vectorDimension, - List parameters) { - // reconstruct each parameter for lowercase parameter type - ArrayList parametersResponse = new ArrayList<>(); - for (EmbeddingProvidersConfig.EmbeddingProviderConfig.ParameterConfig parameter : - parameters) { - ParameterConfigResponse returnParameter = - ParameterConfigResponse.returnParameterConfigResponse( - parameter.name(), - parameter.type().toString(), - parameter.required(), - parameter.defaultValue(), - parameter.validation(), - parameter.help()); - parametersResponse.add(returnParameter); + String name, + ApiModelSupport apiModelSupport, + Optional vectorDimension, + List parameters) { + + private static ModelConfigResponse toResponse( + EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig sourceModelConfig) { + + // if the sourceModelConfig.parameters is null or empty, return a ModelConfigResponse with + // empty parameters. 
+ if (sourceModelConfig.parameters() == null || sourceModelConfig.parameters().isEmpty()) { + return new ModelConfigResponse( + sourceModelConfig.name(), + sourceModelConfig.apiModelSupport(), + sourceModelConfig.vectorDimension(), + Collections.emptyList()); } - return new ModelConfigResponse(name, vectorDimension, parametersResponse); + // reconstruct each parameter for lowercase parameter type + List parametersResponse = + sourceModelConfig.parameters().stream() + .map(ParameterConfigResponse::toResponse) + .collect(Collectors.toList()); + + return new ModelConfigResponse( + sourceModelConfig.name(), + sourceModelConfig.apiModelSupport(), + sourceModelConfig.vectorDimension(), + parametersResponse); } } @@ -140,20 +180,45 @@ private record ParameterConfigResponse( Optional defaultValue, Map> validation, Optional help) { - private static ParameterConfigResponse returnParameterConfigResponse( - String name, - String type, - boolean required, - Optional defaultValue, - Map> - validation, - Optional help) { - Map> validationMap = new HashMap<>(); - for (Map.Entry> - entry : validation.entrySet()) { - validationMap.put(entry.getKey().toString(), entry.getValue()); - } - return new ParameterConfigResponse(name, type, required, defaultValue, validationMap, help); + + private static ParameterConfigResponse toResponse( + EmbeddingProvidersConfig.EmbeddingProviderConfig.ParameterConfig sourceParameterConfig) { + Map> validationMap = + sourceParameterConfig.validation().entrySet().stream() + .collect(Collectors.toMap(entry -> entry.getKey().toString(), Map.Entry::getValue)); + + return new ParameterConfigResponse( + sourceParameterConfig.name(), + sourceParameterConfig.type().name(), + sourceParameterConfig.required(), + sourceParameterConfig.defaultValue(), + validationMap, + sourceParameterConfig.help()); } } + + /** + * With {@link FindEmbeddingProvidersCommand.Options#filterModelStatus()}, there are the rules to + * filter the models: + * + *
    + *
  • If the command has no options, only SUPPORTED models will be returned. + *
  • If provided with "" empty string or null, all SUPPORTED, DEPRECATED, END_OF_LIFE + * models will be returned. + *
  • If provided with SUPPORTED, DEPRECATED or END_OF_LIFE (case-insensitive), only models + * matching that status will be returned. + *
+ */ + private Predicate getSupportStatusPredicate() { + if (command.options() == null) { + return status -> status == ApiModelSupport.SupportStatus.SUPPORTED; + } + + if (command.options().filterModelStatus() == null + || command.options().filterModelStatus().isBlank()) { + return status -> true; // accept all + } + + return status -> status.name().equalsIgnoreCase(command.options().filterModelStatus()); + } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/filters/collection/MatchCollectionFilter.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/filters/collection/MatchCollectionFilter.java new file mode 100644 index 0000000000..710caef278 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/filters/collection/MatchCollectionFilter.java @@ -0,0 +1,46 @@ +package io.stargate.sgv2.jsonapi.service.operation.filters.collection; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.JsonNodeFactory; +import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants; +import io.stargate.sgv2.jsonapi.service.operation.builder.BuiltCondition; +import io.stargate.sgv2.jsonapi.service.operation.builder.BuiltConditionPredicate; +import io.stargate.sgv2.jsonapi.service.operation.builder.ConditionLHS; +import io.stargate.sgv2.jsonapi.service.operation.builder.JsonTerm; +import java.util.Objects; +import java.util.Optional; + +/** Filter for logical "$lexical" field in Documents. 
*/ +public class MatchCollectionFilter extends CollectionFilter { + private final String value; + + public MatchCollectionFilter(String path, String value) { + super(path); + this.value = Objects.requireNonNull(value, "value must not be null"); + this.collectionIndexUsage.lexicalIndexTag = true; + } + + @Override + public BuiltCondition get() { + return BuiltCondition.of( + ConditionLHS.column(DocumentConstants.Columns.LEXICAL_INDEX_COLUMN_NAME), + BuiltConditionPredicate.TEXT_SEARCH, + new JsonTerm(value)); + } + + protected Optional jsonNodeForNewDocument(JsonNodeFactory nodeFactory) { + return Optional.of(toJsonNode(nodeFactory, value)); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || !(o instanceof MatchCollectionFilter other)) return false; + return Objects.equals(value, other.value); + } + + @Override + public int hashCode() { + return Objects.hash(value); + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/InTableFilter.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/InTableFilter.java index 93c86d526f..70b1a7e51b 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/InTableFilter.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/InTableFilter.java @@ -6,7 +6,6 @@ import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.querybuilder.relation.ColumnRelationBuilder; -import com.datastax.oss.driver.api.querybuilder.relation.OngoingWhereClause; import com.datastax.oss.driver.api.querybuilder.relation.Relation; import com.datastax.oss.driver.api.querybuilder.term.Term; import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.ValueComparisonOperator; @@ -52,10 +51,7 @@ public InTableFilter(Operator operator, String path, List arrayValue) { } @Override - public > StmtT apply( - TableSchemaObject tableSchemaObject, - StmtT ongoingWhereClause, - 
List positionalValues) { + public Relation apply(TableSchemaObject tableSchemaObject, List positionalValues) { List bindMarkers = new ArrayList<>(); @@ -112,8 +108,7 @@ public > StmtT apply( } } - return ongoingWhereClause.where( - applyInOperator(Relation.column(getPathAsCqlIdentifier()), bindMarkers)); + return applyInOperator(Relation.column(getPathAsCqlIdentifier()), bindMarkers); } /** diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/NativeTypeTableFilter.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/NativeTypeTableFilter.java index 54b75e3438..7cd5cafc0d 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/NativeTypeTableFilter.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/NativeTypeTableFilter.java @@ -3,7 +3,6 @@ import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.bindMarker; import static io.stargate.sgv2.jsonapi.exception.ErrorFormatters.*; -import com.datastax.oss.driver.api.querybuilder.relation.OngoingWhereClause; import com.datastax.oss.driver.api.querybuilder.relation.Relation; import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.ValueComparisonOperator; import io.stargate.sgv2.jsonapi.exception.DocumentException; @@ -104,10 +103,7 @@ public BuiltCondition get() { } @Override - public > StmtT apply( - TableSchemaObject tableSchemaObject, - StmtT ongoingWhereClause, - List positionalValues) { + public Relation apply(TableSchemaObject tableSchemaObject, List positionalValues) { try { var codec = @@ -149,8 +145,7 @@ public > StmtT apply( })); } - return ongoingWhereClause.where( - Relation.column(getPathAsCqlIdentifier()).build(operator.predicate.cql, bindMarker())); + return Relation.column(getPathAsCqlIdentifier()).build(operator.predicate.cql, bindMarker()); } public Recordable.DataRecorder recordTo(Recordable.DataRecorder dataRecorder) { diff --git 
a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/query/DBLogicalExpression.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/query/DBLogicalExpression.java index 2da5c70800..2156106506 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/query/DBLogicalExpression.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/query/DBLogicalExpression.java @@ -65,16 +65,23 @@ public void visitAllFilters(Class filterClass, Consumer analyzeRule) { } /** - * Add a sub dbLogicalExpression as subExpression to current caller dbLogicalExpression - * - * @param DBLogicalExpression subExpression - * @return subExpression + * Add a sub dbLogicalExpression as subExpression to current caller dbLogicalExpression. Return + * the passing sub dbLogicalExpression. */ - public DBLogicalExpression addSubExpression(DBLogicalExpression subExpression) { + public DBLogicalExpression addSubExpressionReturnSub(DBLogicalExpression subExpression) { subExpressions.add(Objects.requireNonNull(subExpression, "subExpressions cannot be null")); return subExpression; } + /** + * Add a sub dbLogicalExpression as subExpression to current caller dbLogicalExpression. Return + * the current caller dbLogicalExpression. + */ + public DBLogicalExpression addSubExpressionReturnCurrent(DBLogicalExpression subExpression) { + subExpressions.add(Objects.requireNonNull(subExpression, "subExpressions cannot be null")); + return this; + } + /** * Add a dbFilter to current caller dbLogicalExpression. This new DBFilter will be added in the * dbFilter List, it will be in the relation context of this dbLogicalExpression. 
diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/query/TableFilter.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/query/TableFilter.java index 56bffaf07c..24775a8fb4 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/query/TableFilter.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/query/TableFilter.java @@ -3,7 +3,7 @@ import static io.stargate.sgv2.jsonapi.util.CqlIdentifierUtil.cqlIdentifierFromUserInput; import com.datastax.oss.driver.api.core.CqlIdentifier; -import com.datastax.oss.driver.api.querybuilder.relation.OngoingWhereClause; +import com.datastax.oss.driver.api.querybuilder.relation.Relation; import com.datastax.oss.driver.api.querybuilder.select.Select; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.IndexUsage; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject; @@ -67,19 +67,15 @@ public boolean filterIsSlice() { * java driver. * * @param tableSchemaObject The table the filter is being applied to. - * @param ongoingWhereClause The class from the Java Driver that implements the {@link - * OngoingWhereClause} that is used to build the WHERE in a CQL clause. This is the type of - * the statement the where is being added to such {@link Select} or {@link - * com.datastax.oss.driver.api.querybuilder.update.Update} * @param positionalValues Mutable array of values that are used when the {@link * com.datastax.oss.driver.api.querybuilder.QueryBuilder#bindMarker()} method is used, the * values are added to the select statement using {@link Select#build(Object...)} - * @return The {@link Select} to use to continue building the query. NOTE: the query builder is a - * fluent builder that returns immutable that are used in a chain, see the + * @return The {@link Relation} to use to continue building the query. 
NOTE: the query builder is + * a fluent builder that returns immutable that are used in a chain, see the * https://docs.datastax.com/en/developer/java-driver/4.3/manual/query_builder/index.html */ - public abstract > StmtT apply( - TableSchemaObject tableSchemaObject, StmtT ongoingWhereClause, List positionalValues); + public abstract Relation apply( + TableSchemaObject tableSchemaObject, List positionalValues); /** Default "NO OP" implementation to keep the code above cleaner. */ private static final FilterBehaviour UNSUPPORTED_BEHAVIOUR = diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/reranking/FindRerankingProvidersOperation.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/reranking/FindRerankingProvidersOperation.java index ca9609c283..2f6d3fcaf7 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/reranking/FindRerankingProvidersOperation.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/reranking/FindRerankingProvidersOperation.java @@ -9,10 +9,11 @@ import io.stargate.sgv2.jsonapi.service.cqldriver.executor.DatabaseSchemaObject; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.QueryExecutor; import io.stargate.sgv2.jsonapi.service.operation.Operation; -import io.stargate.sgv2.jsonapi.service.provider.ModelSupport; +import io.stargate.sgv2.jsonapi.service.provider.ApiModelSupport; import io.stargate.sgv2.jsonapi.service.reranking.configuration.RerankingProvidersConfig; import io.stargate.sgv2.jsonapi.service.reranking.configuration.RerankingProvidersConfigImpl; import java.util.*; +import java.util.function.Predicate; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -33,22 +34,12 @@ public Uni> execute( Collectors.toMap( Map.Entry::getKey, entry -> - RerankingProviderResponse.provider( - entry.getValue(), getSupportStatuses()))); + RerankingProviderResponse.toResponse( + entry.getValue(), getSupportStatusPredicate()))); return new 
Result(rerankingProviders); }); } - // By default, if includeModelStatus is not provided in command option, only model in supported - // status will be listed. - private Set getSupportStatuses() { - var includeModelStatus = EnumSet.of(ModelSupport.SupportStatus.SUPPORTED); - if (command.options() != null && command.options().includeModelStatus() != null) { - includeModelStatus = command.options().includeModelStatus(); - } - return includeModelStatus; - } - // simple result wrapper private record Result(Map rerankingProviders) implements Supplier { @@ -70,32 +61,82 @@ private record RerankingProviderResponse( RerankingProvidersConfig.RerankingProviderConfig.AuthenticationConfig> supportedAuthentication, List models) { - private static RerankingProviderResponse provider( - RerankingProvidersConfig.RerankingProviderConfig rerankingProviderConfig, - Set includeModelStatus) { + + /** + * Constructs an {@link RerankingProviderResponse} from the original provider config. + * + * @param sourceRerankingProviderConfig, the original provider config with all properties. + * @param statusPredicate predicate to filter models based on their support status. + */ + private static RerankingProviderResponse toResponse( + RerankingProvidersConfig.RerankingProviderConfig sourceRerankingProviderConfig, + Predicate statusPredicate) { + + // if the providerConfig.models is null or empty, return an EmbeddingProviderResponse with + // empty models. + if (sourceRerankingProviderConfig.models() == null + || sourceRerankingProviderConfig.models().isEmpty()) { + return new RerankingProviderResponse( + sourceRerankingProviderConfig.isDefault(), + sourceRerankingProviderConfig.displayName(), + sourceRerankingProviderConfig.supportedAuthentications(), + Collections.emptyList()); + } + + // include models that with apiModelSupport status that user asked for. + // also exclude internal model properties. 
+ var filteredModels = filteredModels(sourceRerankingProviderConfig.models(), statusPredicate); return new RerankingProviderResponse( - rerankingProviderConfig.isDefault(), - rerankingProviderConfig.displayName(), - rerankingProviderConfig.supportedAuthentications(), - filterModels(rerankingProviderConfig.models(), includeModelStatus)); + sourceRerankingProviderConfig.isDefault(), + sourceRerankingProviderConfig.displayName(), + sourceRerankingProviderConfig.supportedAuthentications(), + filteredModels); } - // exclude internal model properties from findRerankingProviders command - // exclude models that are not in the provided statuses - private static List filterModels( - List models, - Set includeModelStatus) { + /** + * Returns models matched by given model supportStatus Predicate, and exclude internal model + * properties from command response. + */ + private static List + filteredModels( + List models, + Predicate statusPredicate) { return models.stream() - .filter(model -> includeModelStatus.contains(model.modelSupport().status())) + .filter(modelConfig -> statusPredicate.test(modelConfig.apiModelSupport().status())) .map( model -> new RerankingProvidersConfigImpl.RerankingProviderConfigImpl.ModelConfigImpl( - model.name(), model.modelSupport(), model.isDefault(), model.url(), null)) + model.name(), model.apiModelSupport(), model.isDefault(), model.url(), null)) .sorted( Comparator.comparing( RerankingProvidersConfig.RerankingProviderConfig.ModelConfig::name)) .collect(Collectors.toList()); } } + + /** + * With {@link FindRerankingProvidersCommand.Options#filterModelStatus()}, there are the rules to + * filter the models: + * + *
    + *
  • If the command has no options, only SUPPORTED models will be returned. + *
  • If provided with "" empty string or null, all SUPPORTED, DEPRECATED, END_OF_LIFE + * models will be returned. + *
  • If provided with SUPPORTED, DEPRECATED or END_OF_LIFE (case-insensitive), only models + * matching that status will be returned. + *
+ */ + private Predicate getSupportStatusPredicate() { + if (command.options() == null) { + return status -> status == ApiModelSupport.SupportStatus.SUPPORTED; + } + + if (command.options().filterModelStatus() == null + || command.options().filterModelStatus().isBlank()) { + return status -> true; // accept all + } + + return status -> status.name().equalsIgnoreCase(command.options().filterModelStatus()); + } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/reranking/IntermediateCollectionReadTask.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/reranking/IntermediateCollectionReadTask.java index c4fa9fbfa3..bc713832aa 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/reranking/IntermediateCollectionReadTask.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/reranking/IntermediateCollectionReadTask.java @@ -14,21 +14,16 @@ import io.stargate.sgv2.jsonapi.service.resolver.FindCommandResolver; import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject; import io.stargate.sgv2.jsonapi.util.recordable.Recordable; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.function.Supplier; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; public class IntermediateCollectionReadTask extends BaseTask< CollectionSchemaObject, IntermediateCollectionReadTask.IntermediateReadResultSupplier, IntermediateCollectionReadTask.IntermediateReadResults> { - - private static final Logger LOGGER = - LoggerFactory.getLogger(IntermediateCollectionReadTask.class); - private final FindCommandResolver findCommandResolver; private final FindCommand findCommand; private final DeferredVectorize deferredVectorize; @@ -65,15 +60,13 @@ public static EmbeddingTaskBuilder commandContext) { - // If we have a deferred vectroize, we should use it to update the sort clause on the find + // If we have a deferred vectorize, we should use it to update the sort 
clause on the find // command if (deferredVectorize != null) { - findCommand.sortClause().sortExpressions().clear(); + var sortClause = findCommand.sortClause(commandContext); + sortClause.sortExpressions().clear(); // will throw if the deferred value is not complete - findCommand - .sortClause() - .sortExpressions() - .add(SortExpression.vsearch(deferredVectorize.getVector())); + sortClause.sortExpressions().add(SortExpression.vsearch(deferredVectorize.getVector())); } Operation findOperation = @@ -102,11 +95,12 @@ protected RuntimeException maybeHandleException( @Override public DataRecorder recordTo(DataRecorder dataRecorder) { + final var sortDef = findCommand.sortDefinition(); return super.recordTo(dataRecorder) .append("deferredVectorize isNull", deferredVectorize == null) .append( "sortClause.sortExpression.paths", - findCommand.sortClause().sortExpressions().stream().map(SortExpression::path).toList()); + (sortDef == null) ? Collections.emptyList() : sortDef.getSortExpressionPaths()); } // ================================================================================================= diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/reranking/RerankingMetrics.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/reranking/RerankingMetrics.java new file mode 100644 index 0000000000..faf15992f9 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/reranking/RerankingMetrics.java @@ -0,0 +1,233 @@ +package io.stargate.sgv2.jsonapi.service.operation.reranking; + +import static io.stargate.sgv2.jsonapi.metrics.MetricsConstants.MetricNames.*; +import static io.stargate.sgv2.jsonapi.metrics.MetricsConstants.MetricTags.*; +import static io.stargate.sgv2.jsonapi.metrics.MetricsConstants.UNKNOWN_VALUE; +import static io.stargate.sgv2.jsonapi.util.ClassUtils.classSimpleName; + +import io.micrometer.core.instrument.*; +import io.micrometer.core.instrument.Timer; +import 
io.stargate.sgv2.jsonapi.api.request.RequestContext; +import io.stargate.sgv2.jsonapi.metrics.MetricsConstants; +import io.stargate.sgv2.jsonapi.metrics.MicrometerConfiguration; +import io.stargate.sgv2.jsonapi.service.cqldriver.executor.SchemaObject; +import io.stargate.sgv2.jsonapi.service.reranking.operation.RerankingProvider; +import java.util.*; +import java.util.concurrent.TimeUnit; + +/** + * Records metrics related to reranking operations performed within a specific request context. + * + *

This class separates metrics into two categories based on tagging: + * + *

    + *
  • Tenant-Specific Metrics: Tracked per tenant, keyspace and table (e.g., {@code + * rerank.tenant.passage.count}, {@code rerank.tenant.call.duration}). Uses {@code tenant} and + * {@code table} tags (plus keyspace) from the {@link RequestContext} and {@link SchemaObject}. + *
  • Overall Metrics: Aggregated across all tenants, but dimensioned by the reranking + * provider and model name (e.g., {@code rerank.all.passage.count}, {@code + * rerank.all.call.duration}). Uses {@code provider} and {@code model} tags derived from the + * {@link RerankingProvider}. + *
+ * + * The duration timers ({@code *.call.duration}) measure the asynchronous execution phase of the + * reranking call, starting after the initial synchronous setup within the provider and ending when + * the asynchronous operation completes (successfully or with failure). + * + *

Note: Configuration of percentiles and histogram settings for timers is handled externally via + * {@link io.micrometer.core.instrument.config.MeterFilter} beans (see {@link + * MicrometerConfiguration}). + */ +public class RerankingMetrics { + private final MeterRegistry meterRegistry; + private final RerankingProvider rerankingProvider; + private final RequestContext requestContext; + private final SchemaObject schemaObject; + + /** + * Constructs a new RerankingMetrics instance for a specific request context. + * + * @param meterRegistry The MeterRegistry to use for recording metrics. + * @param rerankingProvider The RerankingProvider used in the operation, for tagging metrics by + * provider/model. + * @param requestContext The RequestContext for the current request, used for tenant tagging. + * @param schemaObject The SchemaObject representing the target table, used for table tagging. + */ + public RerankingMetrics( + MeterRegistry meterRegistry, + RerankingProvider rerankingProvider, + RequestContext requestContext, + SchemaObject schemaObject) { + this.meterRegistry = Objects.requireNonNull(meterRegistry, "meterRegistry cannot be null"); + this.rerankingProvider = + Objects.requireNonNull(rerankingProvider, "rerankingProvider cannot be null"); + this.requestContext = Objects.requireNonNull(requestContext, "requestContext cannot be null"); + this.schemaObject = Objects.requireNonNull(schemaObject, "schemaObject cannot be null"); + } + + /** + * Records the number of passages being reranked, updating both the tenant-specific metric and the + * overall metric. + * + *

This involves recording the count against two distinct metrics: + * + *

    + *
  • Tenant-specific: {@value MetricsConstants.MetricNames#RERANK_TENANT_PASSAGE_COUNT_METRIC} + * with tags {@value MetricsConstants.MetricTags#TENANT_TAG}, {@code KEYSPACE_TAG} and {@value + * MetricsConstants.MetricTags#TABLE_TAG}. + *
  • Overall: {@value MetricsConstants.MetricNames#RERANK_ALL_PASSAGE_COUNT_METRIC} with tags + * {@value MetricsConstants.MetricTags#RERANKING_PROVIDER_TAG} and {@value + * MetricsConstants.MetricTags#RERANKING_MODEL_TAG}. + *
+ * + * @param passageCount The number of passages. + */ + public void recordPassageCount(int passageCount) { + // Record the passage count for the specific tenant and table + Tags tenantTags = + new RerankingTagsBuilder() + .withTenant(requestContext.getTenantId().orElse(UNKNOWN_VALUE)) + .withKeyspace(schemaObject.name().keyspace()) + .withTable(schemaObject.name().table()) + .build(); + meterRegistry.summary(RERANK_TENANT_PASSAGE_COUNT_METRIC, tenantTags).record(passageCount); + + // Record the passage count for all tenants, tagged by provider and model + Tags allTags = + new RerankingTagsBuilder() + .withProvider(classSimpleName(rerankingProvider.getClass())) + .withModel(rerankingProvider.modelName()) + .build(); + meterRegistry.summary(RERANK_ALL_PASSAGE_COUNT_METRIC, allTags).record(passageCount); + } + + /** + * Starts a timer sample to measure the duration of the asynchronous reranking network call phase. + * + *

The returned sample should be passed to {@link #recordCallLatency} upon completion of the + * asynchronous operation. + * + * @return A {@link Timer.Sample} instance representing the start time. + */ + public Timer.Sample startCallLatency() { + return Timer.start(meterRegistry); + } + + /** + * Stops the timer sample once, calculating the duration, and then records that exact duration + * against both the tenant-specific and the overall reranking call duration metrics. + * + *

This ensures the identical duration value is recorded for: + * + *

    + *
  • Tenant-specific: {@value MetricsConstants.MetricNames#RERANK_TENANT_CALL_DURATION_METRIC} + * with tags {@value MetricsConstants.MetricTags#TENANT_TAG}, {@code KEYSPACE_TAG} and {@value + * MetricsConstants.MetricTags#TABLE_TAG}. + *
  • Overall: {@value MetricsConstants.MetricNames#RERANK_ALL_CALL_DURATION_METRIC} with tags + * {@value MetricsConstants.MetricTags#RERANKING_PROVIDER_TAG} and {@value + * MetricsConstants.MetricTags#RERANKING_MODEL_TAG}. + *
+ * + * @param sample The {@link Timer.Sample} started by {@link #startCallLatency()}. Must not be + * null. + */ + public void recordCallLatency(Timer.Sample sample) { + Objects.requireNonNull(sample, "sample cannot be null"); + + // --- Tenant-Specific Timer --- + // Build tags for the tenant timer + Tags tenantTags = + new RerankingTagsBuilder() + .withTenant(requestContext.getTenantId().orElse(UNKNOWN_VALUE)) + .withKeyspace(schemaObject.name().keyspace()) + .withTable(schemaObject.name().table()) + .build(); + // Get the tenant timer instance + Timer tenantTimer = meterRegistry.timer(RERANK_TENANT_CALL_DURATION_METRIC, tenantTags); + // Stop the sample against the tenant timer. This records the duration AND returns it. + long durationNanos = sample.stop(tenantTimer); + + // --- Overall Timer --- + // Build tags for the overall timer + Tags allTags = + new RerankingTagsBuilder() + .withProvider(classSimpleName(rerankingProvider.getClass())) + .withModel(rerankingProvider.modelName()) + .build(); + // Get the overall timer instance + Timer allTimer = meterRegistry.timer(RERANK_ALL_CALL_DURATION_METRIC, allTags); + // Manually record the exact same duration (obtained above) to the overall timer. + allTimer.record(durationNanos, TimeUnit.NANOSECONDS); + } + + // --- Static Inner Tag Builder Class --- + + /** + * Builder for creating {@link Tags} specific to reranking metrics. + * + *

Allows flexible combination of common reranking-related tags derived from the context + * provided to the outer {@link RerankingMetrics} instance. + */ + public static class RerankingTagsBuilder { + // Use a Map to store tags and check for duplicates + private final Map tagsMap; + + public RerankingTagsBuilder() { + this.tagsMap = new HashMap<>(); + } + + public RerankingTagsBuilder withTenant(String tenantId) { + putOrThrow(TENANT_TAG, tenantId); + return this; + } + + public RerankingTagsBuilder withKeyspace(String keyspace) { + putOrThrow(KEYSPACE_TAG, keyspace); + return this; + } + + public RerankingTagsBuilder withTable(String table) { + putOrThrow(TABLE_TAG, table); + return this; + } + + public RerankingTagsBuilder withProvider(String provider) { + // TODO: It's not easy to get the provider name as it is described in the config. So we use + // the class name to indicate the provider. Need to replace the class name in the future. + putOrThrow(RERANKING_PROVIDER_TAG, provider); + return this; + } + + public RerankingTagsBuilder withModel(String modelName) { + putOrThrow(RERANKING_MODEL_TAG, modelName); + return this; + } + + /** Helper method to check and put, throwing an exception on duplicate key */ + private void putOrThrow(String key, String value) { + // If the key already exists, that means the caller is trying to set the same tag multiple + // times. Throw IllegalStateException since this is only related to our internal + // implementation. + if (tagsMap.put(key, value) != null) { + throw new IllegalStateException( + String.format( + "Tag key '%s' cannot be set multiple times. Check metrics tags related functions.", + key)); // Use dynamic callerMethodName + } + } + + /** + * Builds the final {@link Tags} object from the added tags. + * + * @return A new {@link Tags} instance containing all tags added via the 'withX' methods. 
+ */ + public Tags build() { + // Convert the map entries back to Tag objects for Micrometer + List tagList = new ArrayList<>(tagsMap.size()); + for (Map.Entry entry : tagsMap.entrySet()) { + tagList.add(Tag.of(entry.getKey(), entry.getValue())); + } + return Tags.of(tagList); + } + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/reranking/RerankingQuery.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/reranking/RerankingQuery.java index 0a256a4413..cc9bad25d9 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/reranking/RerankingQuery.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/reranking/RerankingQuery.java @@ -56,7 +56,7 @@ public static RerankingQuery create(FindAndRerankCommand command) { } var vectorizeQuery = command.sortClause().vectorizeSort(); - // will never be blank, but double checking for safety + // will never be blank, but double-checking for safety if (vectorizeQuery != null && !vectorizeQuery.isBlank()) { return new RerankingQuery(vectorizeQuery, Source.VECTORIZE); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/reranking/RerankingTask.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/reranking/RerankingTask.java index 1642edcdb7..6692aca8b5 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/reranking/RerankingTask.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/reranking/RerankingTask.java @@ -106,14 +106,21 @@ protected RerankingResultSupplier buildResultSupplier(CommandContext co dedupResult.droppedDocuments(), dedupResult.deduplicatedDocuments.size())); + var rerankMetrics = + new RerankingMetrics( + commandContext.meterRegistry(), + rerankingProvider, + commandContext.requestContext(), + commandContext.schemaObject()); + return new RerankingResultSupplier( commandContext.requestTracing(), rerankingProvider, commandContext.requestContext().getRerankingCredentials(), query, - 
passageLocator, dedupResult.deduplicatedDocuments(), - limit); + limit, + rerankMetrics); } @Override @@ -214,21 +221,23 @@ public static class RerankingResultSupplier implements UniSupplier unrankedDocs; private final int limit; + private final RerankingMetrics rerankingMetrics; RerankingResultSupplier( RequestTracing requestTracing, RerankingProvider rerankingProvider, RerankingCredentials credentials, RerankingQuery query, - PathMatchLocator passageLocator, List unrankedDocs, - int limit) { + int limit, + RerankingMetrics rerankingMetrics) { this.requestTracing = requestTracing; this.rerankingProvider = rerankingProvider; this.credentials = credentials; this.query = query; this.unrankedDocs = unrankedDocs; this.limit = limit; + this.rerankingMetrics = rerankingMetrics; } @Override @@ -286,10 +295,20 @@ public Uni get() { "limit", limit, "passages", passages)))); + rerankingMetrics.recordPassageCount(passages.size()); + + // Start the timer + var sample = rerankingMetrics.startCallLatency(); + return rerankingProvider .rerank(query.query(), passages, credentials) - .onItem() - .transform( + // Use .eventually() to execute the provided Runnable when the Uni terminates, + // either successfully (onItem) or with a failure (onFailure). + // This is preferred over .onItemOrFailure() when the side-effect action is identical + // for both outcomes and doesn't need access to the item or the failure reason. 
+ .eventually( + () -> rerankingMetrics.recordCallLatency(sample)) // Stop timer regardless of outcome + .map( rerankingResponse -> RerankingTaskResult.create( requestTracing, rerankingProvider, rerankingResponse, unrankedDocs, limit)); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tables/CreateIndexDBTaskBuilder.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tables/CreateIndexDBTaskBuilder.java index 1520e2f26c..c1f5a60a9e 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tables/CreateIndexDBTaskBuilder.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tables/CreateIndexDBTaskBuilder.java @@ -6,6 +6,7 @@ import io.stargate.sgv2.jsonapi.service.operation.query.CQLOptions; import io.stargate.sgv2.jsonapi.service.operation.tasks.TaskBuilder; import io.stargate.sgv2.jsonapi.service.schema.tables.ApiRegularIndex; +import io.stargate.sgv2.jsonapi.service.schema.tables.ApiTextIndex; import io.stargate.sgv2.jsonapi.service.schema.tables.ApiVectorIndex; import java.util.Objects; @@ -58,6 +59,19 @@ public CreateIndexDBTask build(ApiRegularIndex apiRegularIndex) { buildCqlOptions()); } + public CreateIndexDBTask build(ApiTextIndex apiTextIndex) { + Objects.requireNonNull(apiTextIndex, "apiTextIndex cannot be null"); + checkBuildPreconditions(); + + return new CreateIndexDBTask( + nextPosition(), + schemaObject, + schemaRetryPolicy, + getExceptionHandlerFactory(), + apiTextIndex, + buildCqlOptions()); + } + public CreateIndexDBTask build(ApiVectorIndex apiVectorIndex) { Objects.requireNonNull(apiVectorIndex, "apiVectorIndex cannot be null"); checkBuildPreconditions(); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tables/TableProjection.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tables/TableProjection.java index cf176fe293..3b0f2bd543 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tables/TableProjection.java +++ 
b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tables/TableProjection.java @@ -71,8 +71,9 @@ private TableProjection( * schema. */ public static TableProjection fromDefinition( - ObjectMapper objectMapper, CmdT command, TableSchemaObject table) { + CommandContext ctx, ObjectMapper objectMapper, CmdT command) { + TableSchemaObject table = ctx.schemaObject(); Map columnsByName = new HashMap<>(); // TODO: This can also be cached as part of TableSchemaObject than resolving it for every query. table @@ -120,7 +121,7 @@ public static TableProjection fromDefinition( table, columns, readApiColumns.toColumnsDesc(), - TableSimilarityFunction.from(command, table)); + TableSimilarityFunction.from(ctx, command)); } @Override diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tables/TableSimilarityFunction.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tables/TableSimilarityFunction.java index bc2ca4e458..971b14b724 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tables/TableSimilarityFunction.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tables/TableSimilarityFunction.java @@ -6,6 +6,7 @@ import com.datastax.oss.driver.api.core.data.CqlVector; import com.datastax.oss.driver.api.querybuilder.select.Select; import com.datastax.oss.driver.api.querybuilder.select.Selector; +import io.stargate.sgv2.jsonapi.api.model.command.CommandContext; import io.stargate.sgv2.jsonapi.api.model.command.Projectable; import io.stargate.sgv2.jsonapi.api.model.command.VectorSortable; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject; @@ -22,14 +23,14 @@ public interface TableSimilarityFunction extends Function { String SIMILARITY_SCORE_ALIAS = "similarityScore" + System.currentTimeMillis(); static TableSimilarityFunction from( - CmdT command, TableSchemaObject table) { - + CommandContext ctx, CmdT command) { + final TableSchemaObject table = ctx.schemaObject(); if (!(command 
instanceof VectorSortable)) { return NO_OP; } var vectorSortable = (VectorSortable) command; - var sortExpressionOptional = vectorSortable.vectorSortExpression(); + var sortExpressionOptional = vectorSortable.vectorSortExpression(ctx); if (sortExpressionOptional.isEmpty()) { // nothing to sort on, so nothing to return even if they asked for the similarity score return NO_OP; diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tables/TableWhereCQLClause.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tables/TableWhereCQLClause.java index 4e82caba6a..c218104b04 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tables/TableWhereCQLClause.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tables/TableWhereCQLClause.java @@ -2,17 +2,20 @@ import com.datastax.oss.driver.api.querybuilder.delete.Delete; import com.datastax.oss.driver.api.querybuilder.relation.OngoingWhereClause; +import com.datastax.oss.driver.api.querybuilder.relation.Relation; import com.datastax.oss.driver.api.querybuilder.select.Select; import com.datastax.oss.driver.api.querybuilder.update.Update; import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.LogicalExpression; import io.stargate.sgv2.jsonapi.exception.WithWarnings; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject; +import io.stargate.sgv2.jsonapi.service.cqldriver.override.DefaultSubConditionRelation; import io.stargate.sgv2.jsonapi.service.operation.query.*; +import java.util.ArrayList; import java.util.List; import java.util.Objects; /** - * Builds the WHERE clause in a CQL statment when using the Java Driver Query Builder. + * Builds the WHERE clause in a CQL statement when using the Java Driver Query Builder. * *

TODO: this accepts the {@link LogicalExpression} to build the statement, we want to stop * handing that down to the operations but keeping for now for POC work. @@ -20,9 +23,9 @@ *

NOTE: Using a class so the ctor can be made private to force use fo the static factories that * solve the generic typing needed for the {@link OngoingWhereClause}. * - * @param The type of Query Builder stament that the where clause is being added to, use the - * static factory methods like {@link #forSelect(TableSchemaObject, DBLogicalExpression)} to get - * the correct type. + * @param The type of Query Builder statement that the where clause is being added to, use the + * static factory methods like {@link #forSelect(TableSchemaObject, WithWarnings)} to get the + * correct type. */ public class TableWhereCQLClause> implements WhereCQLClause { @@ -39,11 +42,9 @@ private TableWhereCQLClause( /** * Build an instance to add the where clause to a {@link Select}. * - *

- * - * @param table - * @param dbLogicalExpression - * @return + * @param table the target table schema. + * @param dbLogicalExpression the DB LogicalExpression that contains all the filters and + * logicalOperator */ public static WithWarnings> forSelect( TableSchemaObject table, WithWarnings dbLogicalExpression) { @@ -54,11 +55,9 @@ public static WithWarnings> forSelect( /** * Build an instance to add the where clause to a {@link Update}. * - *

- * - * @param table - * @param dbLogicalExpression - * @return + * @param table the target table schema. + * @param dbLogicalExpression the DB LogicalExpression that contains all the filters and + * logicalOperator */ public static WithWarnings> forUpdate( TableSchemaObject table, WithWarnings dbLogicalExpression) { @@ -69,34 +68,88 @@ public static WithWarnings> forUpdate( /** * Build an instance to add the where clause to a {@link Delete}. * - *

- * - * @param table - * @param dbLogicalExpression - * @return + * @param table the target table schema. + * @param dbLogicalExpression the DB LogicalExpression that contains all the filters and + * logicalOperator */ public static TableWhereCQLClause forDelete( TableSchemaObject table, DBLogicalExpression dbLogicalExpression) { return new TableWhereCQLClause<>(table, dbLogicalExpression); } + /** + * Apply the {@link TableWhereCQLClause} to the {@link OngoingWhereClause}. It will recursively + * apply all the filters and logical relations to build the {@link Relation} and feed to the + * {@link OngoingWhereClause}. + * + * @param tOngoingWhereClause the {@link OngoingWhereClause} to apply the filters and logical + * relations to. + * @param objects the positional values to append to the {@link Relation}. + * @return the {@link OngoingWhereClause} with the filters and logical relations applied. + */ @Override - public DBLogicalExpression getLogicalExpression() { - return dbLogicalExpression; + public T apply(T tOngoingWhereClause, List objects) { + + // If there is no filter in the entire logicalExpression tree + // Just return the ongoingWhereClause without applying anything. + // This could happen when user provides no filter or empty filter in the request. 
+ if (dbLogicalExpression.isEmpty()) { + return tOngoingWhereClause; + } + + return tOngoingWhereClause.where(applyLogicalRelation(dbLogicalExpression, objects)); } - @Override - public T apply(T tOngoingWhereClause, List objects) { - // TODO BUG: this probably breaks order for nested expressions, for now enough to get this - // tested - var tableFilters = - dbLogicalExpression.filters().stream().map(dbFilter -> (TableFilter) dbFilter).toList(); - - // Add the where clause operations - for (TableFilter tableFilter : tableFilters) { - tOngoingWhereClause = tableFilter.apply(tableSchemaObject, tOngoingWhereClause, objects); + /** + * Method to recursively resolve the DBLogicalExpression into regular relation {@link Relation} or + * AND/OR relation {@link DefaultSubConditionRelation} that Driver QueryBuilder expects. + * + * @param currentLogicalExpression currentLogicalExpression. + * @param objects positionalValues to append in order. + * @return {@link Relation} conjunct relation from DBLogicalExpression. + */ + private Relation applyLogicalRelation( + DBLogicalExpression currentLogicalExpression, List objects) { + + // create the default relation to represent the current level of AND/OR, E.G. + // implicit and: {"name": "John"} -> WHERE (name=?) + // implicit and with explicit or: {"$or": [{"name": "John"}, {"age": 30}]} -> WHERE ((name=? 
OR + // age=?)) + // Ideally, we don't need the parenthesis in the root level + // but this is how the driver override works currently, see DefaultSubConditionRelation.java + // We can not remove the root level parenthesis currently to build the logical relation + var relationWhere = DefaultSubConditionRelation.subCondition(); + + List relations = new ArrayList<>(); + // First, add all simple filters + currentLogicalExpression + .filters() + .forEach(filter -> relations.add(((TableFilter) filter).apply(tableSchemaObject, objects))); + + // Then, recursively build relations from sub_levels + currentLogicalExpression + .subExpressions() + .forEach(subExpr -> relations.add(applyLogicalRelation(subExpr, objects))); + + // if current logical operator is AND, construct and() relation + if (currentLogicalExpression.operator() == DBLogicalExpression.DBLogicalOperator.AND + && !relations.isEmpty()) { + relationWhere = relationWhere.where(relations.getFirst()); + for (int i = 1; i < relations.size(); i++) { + relationWhere = relationWhere.and().where(relations.get(i)); + } } - return tOngoingWhereClause; + + // if current logical operator is OR, construct or() relation + if (currentLogicalExpression.operator() == DBLogicalExpression.DBLogicalOperator.OR + && !relations.isEmpty()) { + relationWhere = relationWhere.where(relations.getFirst()); + for (int i = 1; i < relations.size(); i++) { + relationWhere = relationWhere.or().where(relations.get(i)); + } + } + + return relationWhere; } @Override @@ -123,4 +176,9 @@ public boolean selectsSinglePartition(TableSchemaObject tableSchemaObject) { } return true; } + + @Override + public DBLogicalExpression getLogicalExpression() { + return dbLogicalExpression; + } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tasks/DBTask.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tasks/DBTask.java index 359f5e44a7..500a37099d 100644 --- 
a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tasks/DBTask.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/tasks/DBTask.java @@ -179,6 +179,7 @@ protected CommandQueryExecutor getCommandQueryExecutor(CommandContext c new CommandQueryExecutor.DBRequestContext( commandContext.requestContext().getTenantId(), commandContext.requestContext().getCassandraToken(), + commandContext.requestContext().getUserAgent(), commandContext.requestTracing().enabled()), CommandQueryExecutor.QueryTarget.TABLE); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/processor/CommandProcessor.java b/src/main/java/io/stargate/sgv2/jsonapi/service/processor/CommandProcessor.java index 71b4a1257c..2251e9e376 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/processor/CommandProcessor.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/processor/CommandProcessor.java @@ -24,13 +24,9 @@ import org.slf4j.LoggerFactory; /** - * Processes valid document {@link Command} to read, write, schema change, etc. This is a single - * entry to run a Command without worrying how to. - * - *

Called from the API layer which deals with public JSON data formats, the command layer - * translates from the JSON models to internal and back (shredding and de-shredding). - * - *

May provide a thread or resource boundary from calling API layer. + * Processes a {@link Command} by taking it through a series of transformations: expansion, + * vectorization, resolution to an {@link Operation}, and finally execution. It also handles error + * recovery and result post-processing. */ @ApplicationScoped public class CommandProcessor { @@ -49,20 +45,18 @@ public CommandProcessor( } /** - * Processes a single command in a given command context. + * Processes a single command through the full pipeline. * - * @param commandContext {@link CommandContext} - * @param command {@link Command} - * @return Uni emitting the result of the command execution. + * @param commandContext The context for the command execution. + * @param command The command to be processed. * @param Type of the command. - * @param Type of the schema object command operates on. + * @param Type of the schema object the command operates on. + * @return A {@link Uni} emitting the {@link CommandResult} of the command execution. 
*/ public Uni processCommand( CommandContext commandContext, CommandT command) { - var debugMode = commandContext.config().get(DebugModeConfig.class).enabled(); - var errorObjectV2 = commandContext.config().get(OperationsConfig.class).extendError(); - + // Initial tracing before the reactive pipeline starts commandContext .requestTracing() .maybeTrace( @@ -74,84 +68,124 @@ public Uni { - // start by resolving the command, get resolver - return commandResolverService - .resolverForCommand(vectorizedCommand) - - // resolver can be null, not handled in CommandResolverService for now - .flatMap( - resolver -> { - // if we have resolver, resolve operation - Operation operation = - resolver.resolveCommand(commandContext, vectorizedCommand); - return Uni.createFrom().item(operation); - }); + + // Step 1: Expand any hybrid fields in the command (synchronous) and record the command + // features + .invoke( + cmd -> { + HybridFieldExpander.expandHybridField(commandContext, cmd); + cmd.addCommandFeatures(commandContext.commandFeatures()); }) - // execute the operation + // Step 2: Vectorize relevant parts of the command (asynchronous) + .flatMap(cmd -> dataVectorizerService.vectorize(commandContext, cmd)) + + // Step 3: Resolve the vectorized command to a runnable Operation (asynchronous) + .flatMap(cmd -> resolveCommandToOperation(commandContext, cmd)) + + // Step 4: Execute the operation (asynchronous) .flatMap(operation -> operation.execute(commandContext)) - // handle failures here + // Step 5: Handle any failures from the preceding steps .onFailure() - .recoverWithItem( - t -> - switch (t) { - case APIException apiException -> { - // new error object V2 - var errorBuilder = - new APIExceptionCommandErrorBuilder(debugMode, errorObjectV2); - - // yet more mucking about with suppliers everywhere :( - yield (Supplier) - () -> - CommandResult.statusOnlyBuilder( - errorObjectV2, debugMode, commandContext.requestTracing()) - .addCommandResultError( - 
errorBuilder.buildLegacyCommandResultError(apiException)) - .build(); - } - case JsonApiException jsonApiException -> - // old error objects, old comment below - // Note: JsonApiException means that JSON API itself handled the situation - // (created, or wrapped the exception) -- should not be logged (have already - // been logged if necessary) - jsonApiException; - default -> { - // Old error handling below, to be replaced eventually (aaron aug 28 2024) - // But other exception types are unexpected, so log for now - logger.warn( - "Command '{}' failed with exception", - command.getClass().getSimpleName(), - t); - yield new ThrowableCommandResultSupplier(t); - } - }) - - // if we have a non-null item - // call supplier get to map to the command result + .recoverWithItem(throwable -> handleProcessingFailure(commandContext, command, throwable)) + + // Step 6: Transform the successful or recovered item (Supplier) into + // CommandResult .onItem() .ifNotNull() .transform(Supplier::get) - // add possible warning for using a deprecated command - .map( - commandResult -> { - if (command instanceof DeprecatedCommand deprecatedCommand) { - // for the warnings we always want V2 errors and do not want / need debug ? - var errorV2 = - new APIExceptionCommandErrorBuilder(false, true) - .buildCommandErrorV2(deprecatedCommand.getDeprecationWarning()); - commandResult.addWarning(errorV2); - } - return commandResult; + + // Step 7: Perform any final post-processing on the CommandResult (e.g., add warnings) + .map(commandResult -> postProcessCommandResult(command, commandResult)); + } + + /** + * Resolves a {@link Command} to its corresponding {@link Operation}. + * + * @param commandContext The command context. + * @param commandToResolve The command to resolve. + * @param Type of the schema object. + * @return A {@link Uni} emitting the resolved {@link Operation}. 
+ */ + private Uni> resolveCommandToOperation( + CommandContext commandContext, Command commandToResolve) { + return commandResolverService + // Find resolver for command, it handles the case where resolver is null + .resolverForCommand(commandToResolve) + .flatMap( + resolver -> { + // Now the resolver is found, resolve the command to an operation. + // This resolution step itself is synchronous. + Operation operation = + resolver.resolveCommand(commandContext, commandToResolve); + return Uni.createFrom().item(operation); }); } + + /** + * Handles failures that occur during the command processing pipeline. It attempts to convert + * known exceptions (APIException, JsonApiException) into a {@link CommandResult} supplier, and + * logs other unexpected exceptions. + * + * @param commandContext The command context. + * @param originalCommand The initial command that was being processed. + * @param throwable The failure. + * @return A {@link Supplier} of {@link CommandResult} representing the error. 
+ */ + private Supplier handleProcessingFailure( + CommandContext commandContext, Command originalCommand, Throwable throwable) { + var debugMode = commandContext.config().get(DebugModeConfig.class).enabled(); + var errorObjectV2 = commandContext.config().get(OperationsConfig.class).extendError(); + + return switch (throwable) { + case APIException apiException -> { + // new error object V2 + var errorBuilder = new APIExceptionCommandErrorBuilder(debugMode, errorObjectV2); + yield () -> + CommandResult.statusOnlyBuilder( + errorObjectV2, debugMode, commandContext.requestTracing()) + .addCommandResultError(errorBuilder.buildLegacyCommandResultError(apiException)) + .build(); + } + case JsonApiException jsonApiException -> + // old error objects, old comment below + // Note: JsonApiException means that JSON API itself handled the situation + // (created, or wrapped the exception) -- should not be logged (have already + // been logged if necessary) + jsonApiException; + default -> { + // Old error handling below, to be replaced eventually (aaron aug 28 2024) + // But other exception types are unexpected, so log for now + logger.warn( + "Command '{}' failed with exception", + originalCommand.getClass().getSimpleName(), + throwable); + yield new ThrowableCommandResultSupplier(throwable); + } + }; + } + + /** + * Performs post-processing on the {@link CommandResult}, such as adding warnings for deprecated + * commands. + * + * @param originalCommand The initial command that was processed. + * @param commandResult The result of the command execution. + * @return The potentially modified {@link CommandResult}. + */ + private CommandResult postProcessCommandResult( + Command originalCommand, CommandResult commandResult) { + if (originalCommand instanceof DeprecatedCommand deprecatedCommand) { + // (aaron) for the warnings we always want V2 errors and do not want / need debug ? 
+ var errorV2 = + new APIExceptionCommandErrorBuilder(false, true) + .buildCommandErrorV2(deprecatedCommand.getDeprecationWarning()); + commandResult.addWarning(errorV2); + } + return commandResult; + } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/processor/HybridFieldExpander.java b/src/main/java/io/stargate/sgv2/jsonapi/service/processor/HybridFieldExpander.java index 9843551031..4f21595909 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/processor/HybridFieldExpander.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/processor/HybridFieldExpander.java @@ -10,6 +10,7 @@ import io.stargate.sgv2.jsonapi.api.model.command.impl.InsertOneCommand; import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; +import io.stargate.sgv2.jsonapi.metrics.CommandFeature; import io.stargate.sgv2.jsonapi.util.JsonUtil; import java.util.Iterator; @@ -27,12 +28,12 @@ public static void expandHybridField(CommandContext context, Command command) { if (context.isCollectionContext()) { // and just for Insert commands, in particular switch (command) { - case InsertOneCommand cmd -> expandHybridField(0, 1, cmd.document()); + case InsertOneCommand cmd -> expandHybridField(context, 0, 1, cmd.document()); case InsertManyCommand cmd -> { var docs = cmd.documents(); if (docs != null) { for (int i = 0, len = docs.size(); i < len; ++i) { - expandHybridField(i, len, docs.get(i)); + expandHybridField(context, i, len, docs.get(i)); } } } @@ -42,24 +43,46 @@ public static void expandHybridField(CommandContext context, Command command) { } // protected for testing purposes - protected static void expandHybridField(int docIndex, int docCount, JsonNode docNode) { + protected static void expandHybridField( + CommandContext context, int docIndex, int docCount, JsonNode docNode) { final JsonNode hybridField; - if ((docNode instanceof ObjectNode doc) - && (hybridField = 
doc.remove(DocumentConstants.Fields.HYBRID_FIELD)) != null) { - switch (hybridField) { - case NullNode ignored -> addLexicalAndVectorize(doc, hybridField, hybridField); - case TextNode ignored -> addLexicalAndVectorize(doc, hybridField, hybridField); - case ObjectNode ob -> addFromObject(doc, ob, docIndex, docCount); - default -> - throw ErrorCodeV1.HYBRID_FIELD_UNSUPPORTED_VALUE_TYPE.toApiException( - "expected String, Object or `null` but received %s (Document %d of %d)", - JsonUtil.nodeTypeAsString(hybridField), docIndex + 1, docCount); + if (docNode instanceof ObjectNode doc) { + // Check for $hybrid field + if ((hybridField = doc.remove(DocumentConstants.Fields.HYBRID_FIELD)) != null) { + context.commandFeatures().addFeature(CommandFeature.HYBRID); + switch (hybridField) { + // this is {"$hybrid" : null} + case NullNode ignored -> addLexicalAndVectorize(doc, hybridField, hybridField); + case TextNode ignored -> addLexicalAndVectorize(doc, hybridField, hybridField); + case ObjectNode ob -> addFromObject(context, doc, ob, docIndex, docCount); + default -> + throw ErrorCodeV1.HYBRID_FIELD_UNSUPPORTED_VALUE_TYPE.toApiException( + "expected String, Object or `null` but received %s (Document %d of %d)", + JsonUtil.nodeTypeAsString(hybridField), docIndex + 1, docCount); + } + } else { + // No $hybrid field, check other fields and add feature usage to CommandContext + if (doc.has(DocumentConstants.Fields.VECTOR_EMBEDDING_FIELD) + && !doc.get(DocumentConstants.Fields.VECTOR_EMBEDDING_FIELD).isNull()) { + context.commandFeatures().addFeature(CommandFeature.VECTOR); + } + // `$vectorize` and `$vector` can't be used together - the check will be done later (in + // DataVectorizer) + if (doc.has(DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD) + && !doc.get(DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD).isNull()) { + context.commandFeatures().addFeature(CommandFeature.VECTORIZE); + } + if (doc.has(DocumentConstants.Fields.LEXICAL_CONTENT_FIELD) + && 
!doc.get(DocumentConstants.Fields.LEXICAL_CONTENT_FIELD).isNull()) { + context.commandFeatures().addFeature(CommandFeature.LEXICAL); + } } } } - private static void addFromObject(ObjectNode doc, ObjectNode hybrid, int docIndex, int docCount) { + private static void addFromObject( + CommandContext context, ObjectNode doc, ObjectNode hybrid, int docIndex, int docCount) { JsonNode lexical = hybrid.remove(DocumentConstants.Fields.LEXICAL_CONTENT_FIELD); JsonNode vectorize = hybrid.remove(DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD); @@ -76,9 +99,15 @@ private static void addFromObject(ObjectNode doc, ObjectNode hybrid, int docInde lexical = validateSubFieldType( lexical, DocumentConstants.Fields.LEXICAL_CONTENT_FIELD, docIndex, docCount); + if (!lexical.isNull()) { + context.commandFeatures().addFeature(CommandFeature.LEXICAL); + } vectorize = validateSubFieldType( vectorize, DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD, docIndex, docCount); + if (!vectorize.isNull()) { + context.commandFeatures().addFeature(CommandFeature.VECTORIZE); + } addLexicalAndVectorize(doc, lexical, vectorize); } @@ -97,12 +126,13 @@ private static JsonNode validateSubFieldType( } private static void addLexicalAndVectorize(ObjectNode doc, JsonNode lexical, JsonNode vectorize) { - // Important: verify we had no conflict with existing $lexical or $vectorize fields + // Important: verify we had no conflict with existing $lexical or $vector or $vectorize fields // (that is, values from $hybrid would not overwrite existing values) var oldLexical = doc.replace(DocumentConstants.Fields.LEXICAL_CONTENT_FIELD, lexical); + var oldVector = doc.get(DocumentConstants.Fields.VECTOR_EMBEDDING_FIELD); var oldVectorize = doc.replace(DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD, vectorize); - if ((oldLexical != null) || (oldVectorize != null)) { + if ((oldLexical != null) || (oldVector != null) || (oldVectorize != null)) { throw ErrorCodeV1.HYBRID_FIELD_CONFLICT.toApiException(); } } 
diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/processor/MeteredCommandProcessor.java b/src/main/java/io/stargate/sgv2/jsonapi/service/processor/MeteredCommandProcessor.java index e55d7ff675..75baf1673d 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/processor/MeteredCommandProcessor.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/processor/MeteredCommandProcessor.java @@ -1,16 +1,15 @@ package io.stargate.sgv2.jsonapi.service.processor; +import static io.stargate.sgv2.jsonapi.config.constants.ErrorObjectV2Constants.MetricTags.ERROR_CODE; +import static io.stargate.sgv2.jsonapi.config.constants.ErrorObjectV2Constants.MetricTags.EXCEPTION_CLASS; import static io.stargate.sgv2.jsonapi.config.constants.LoggingConstants.*; import com.fasterxml.jackson.core.JacksonException; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectWriter; import io.micrometer.core.instrument.*; -import io.micrometer.core.instrument.config.MeterFilter; -import io.micrometer.core.instrument.distribution.DistributionStatisticConfig; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.model.command.*; -import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortExpression; import io.stargate.sgv2.jsonapi.api.model.command.impl.*; import io.stargate.sgv2.jsonapi.api.v1.metrics.JsonApiMetricsConfig; import io.stargate.sgv2.jsonapi.api.v1.metrics.MetricsConfig; @@ -18,48 +17,47 @@ import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.SchemaObject; import jakarta.enterprise.context.ApplicationScoped; -import jakarta.enterprise.inject.Produces; import jakarta.inject.Inject; -import jakarta.inject.Singleton; import java.util.Collections; -import java.util.List; import java.util.Set; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.slf4j.MDC; +/** + * A processor that wraps the core {@link CommandProcessor} to add 
metrics and command-level logging + * capabilities. + * + *

It measures the execution time of commands, collects various tags (like command name, tenant, + * errors, vector usage), and logs command details based on configuration. + */ @ApplicationScoped public class MeteredCommandProcessor { private static final Logger logger = LoggerFactory.getLogger(MeteredCommandProcessor.class); + // ObjectWriter for serializing CommandLog instances. private static final ObjectWriter OBJECT_WRITER = new ObjectMapper().writer(); + // Constants for tag values private static final String UNKNOWN_VALUE = "unknown"; - private static final String NA = "NA"; + // Core processor dependency private final CommandProcessor commandProcessor; + // Metrics and configuration dependencies private final MeterRegistry meterRegistry; - private final JsonApiMetricsConfig jsonApiMetricsConfig; - private final MetricsConfig.TenantRequestCounterConfig tenantConfig; + private final CommandLevelLoggingConfig commandLevelLoggingConfig; - /** The tag for error being true, created only once. */ + // Pre-computed common tags for efficiency private final Tag errorTrue; - - /** The tag for error being false, created only once. */ private final Tag errorFalse; - - /** The tag for tenant being unknown, created only once. 
*/ private final Tag tenantUnknown; - private final Tag defaultErrorCode; - private final Tag defaultErrorClass; - private final CommandLevelLoggingConfig commandLevelLoggingConfig; @Inject public MeteredCommandProcessor( @@ -71,51 +69,90 @@ public MeteredCommandProcessor( this.commandProcessor = commandProcessor; this.meterRegistry = meterRegistry; this.jsonApiMetricsConfig = jsonApiMetricsConfig; - tenantConfig = metricsConfig.tenantRequestCounter(); + this.tenantConfig = metricsConfig.tenantRequestCounter(); + this.commandLevelLoggingConfig = commandLevelLoggingConfig; + + // Pre-compute common tags for efficiency errorTrue = Tag.of(tenantConfig.errorTag(), "true"); errorFalse = Tag.of(tenantConfig.errorTag(), "false"); tenantUnknown = Tag.of(tenantConfig.tenantTag(), UNKNOWN_VALUE); defaultErrorCode = Tag.of(jsonApiMetricsConfig.errorCode(), NA); defaultErrorClass = Tag.of(jsonApiMetricsConfig.errorClass(), NA); - this.commandLevelLoggingConfig = commandLevelLoggingConfig; } /** - * Processes a single command in a given command context. + * Processes a single command, adding metrics and logging around the core execution. * - * @param commandContext {@link CommandContext} - * @param command {@link Command} + * @param commandContext {@link CommandContext} The context for the command execution, containing + * schema, request info, etc. + * @param command The {@link Command} to be processed. * @param Type of the command. - * @return Uni emitting the result of the command execution. + * @param Type of the schema object. + * @return A Uni emitting the {@link CommandResult} of the command execution. 
*/ public Uni processCommand( CommandContext commandContext, CommandT command) { Timer.Sample sample = Timer.start(meterRegistry); + + // Set up logging context (MDC) // use MDC to populate logs as needed(namespace,collection,tenantId) commandContext.schemaObject().name().addToMDC(); - MDC.put("tenantId", commandContext.requestContext().getTenantId().orElse(UNKNOWN_VALUE)); - // start by resolving the command, get resolver - return commandProcessor - .processCommand(commandContext, command) + + // --- Defer Command Processing (from PR2076) --- + // We wrap the call to `commandProcessor.processCommand` in `Uni.createFrom().deferred()` + // for two main reasons: + // 1. Defensive Programming for Synchronous Failures: + // Ensures that if `commandProcessor.processCommand` itself (or code executed synchronously + // within it before its own reactive chain fully forms and handles errors) throws a + // synchronous exception, this MeteredCommandProcessor can still catch it in its + // `.onFailure()` block for consistent logging (no metrics currently). This acts as a safety + // net. + // 2. Lazy Execution (Benefit of `deferred`): + // The `commandProcessor.processCommand` method, which kicks off potentially significant + // work and returns a Uni, will only be invoked when this resulting Uni is actually + // subscribed to. + return Uni.createFrom() + .deferred(() -> commandProcessor.processCommand(commandContext, command)) .onItem() .invoke( + // Success path handling result -> { + // --- Metrics Recording --- Tags tags = getCustomTags(commandContext, command, result); - // add metrics sample.stop(meterRegistry.timer(jsonApiMetricsConfig.metricsName(), tags)); + // --- Command Level Logging (Success) --- if (isCommandLevelLoggingEnabled(commandContext, result, false)) { logger.info(buildCommandLog(commandContext, command, result)); } }) .onFailure() .invoke( + // Failure path handling. 
+ // This block will only be executed if the Uni returned by + // commandProcessor.processCommand terminates with a failure signal that was not caught + // and recovered by the .recoverWithItem() block inside commandProcessor.processCommand. + // That is, this should not happen unless there are unexpected errors before/in the + // .recoverWithItem() or there are some framework errors throwable -> { + // TODO: Metrics timer (`sample.stop()`) is not called here by design? + + // --- Command Level Logging if (isCommandLevelLoggingEnabled(commandContext, null, true)) { - logger.error(buildCommandLog(commandContext, command, null), throwable); + logger.error( + "Command processing failed. Details: {}", + buildCommandLog(commandContext, command, null), + throwable); } + }) + .eventually( + () -> { + // Cleanup MDC after processing completes (success or failure) to prevent data from + // leaking into the next request handled by the same thread. + commandContext.schemaObject().name().removeFromMDC(); + MDC.remove("tenantId"); }); } @@ -147,54 +184,67 @@ private String buildCommandLog( } /** - * Get outgoing documents count in a command result. + * Counts outgoing documents from a {@link CommandResult}. * - * @param result - * @return + * @param result The command result. + * @return Document count as a String, or "NA" if not applicable or result is null. */ private String getOutgoingDocumentsCount(CommandResult result) { - if (result == null) { - return "NA"; + if (result == null || result.data() == null) { + return NA; } + // Check specific ResponseData types that contain documents if (result.data() instanceof ResponseData.MultiResponseData || result.data() instanceof ResponseData.SingleResponseData) { return String.valueOf(result.data().getResponseDocuments().size()); } - return "NA"; + // Other types don't have outgoing docs in this sense + return NA; } /** - * Get incoming documents count in a command. 
+ * Counts incoming documents for relevant commands (InsertOne, InsertMany). * - * @param command - * @return + * @param command The command being executed. + * @return Document count as a String, or "NA" if not applicable. */ private String getIncomingDocumentsCount(Command command) { - if (command instanceof InsertManyCommand) { - return String.valueOf(((InsertManyCommand) command).documents().size()); - } else if (command instanceof InsertOneCommand) { - return String.valueOf(((InsertOneCommand) command).document() != null ? 1 : 0); - } - return "NA"; + return switch (command) { + case InsertManyCommand insertManyCmd -> + String.valueOf(insertManyCmd.documents() != null ? insertManyCmd.documents().size() : 0); + case InsertOneCommand insertOneCmd -> String.valueOf(insertOneCmd.document() != null ? 1 : 0); + default -> + // Command types without relevant incoming documents + NA; + }; } /** - * @param commandResult - command result - * @param isFailure - Is from the failure flow - * @return true if command level logging is allowed, false otherwise + * Checks if command-level logging should be performed based on configuration and result status. + * + * @param commandContext Command context (used for tenant filtering). + * @param commandResult The result (used for error filtering), can be null for failure path. + * @param isFailure Indicates if this check is being done during the failure handling path. + * @return {@code true} if logging should proceed, {@code false} otherwise. */ private boolean isCommandLevelLoggingEnabled( CommandContext commandContext, CommandResult commandResult, boolean isFailure) { + // Globally disabled? 
if (!commandLevelLoggingConfig.enabled()) { return false; } + + // Check tenant filter (if configured) Set allowedTenants = commandLevelLoggingConfig.enabledTenants().orElse(Collections.singleton(ALL_TENANTS)); if (!allowedTenants.contains(ALL_TENANTS) && !allowedTenants.contains( commandContext.requestContext().getTenantId().orElse(UNKNOWN_VALUE))) { + // Logging disabled for this tenant return false; } + + // Disabled if no errors in command if (!isFailure && commandLevelLoggingConfig.onlyResultsWithErrors() && (commandResult == null @@ -202,109 +252,109 @@ private boolean isCommandLevelLoggingEnabled( || commandResult.errors().isEmpty())) { return false; } + // return true in all other cases return true; } /** - * Generate custom tags based on the command and result. + * Generates metric tags based on the command, context, and result. * - * @param command - request command - * @param result - response command result - * @return + * @param commandContext The command context. + * @param command The executed command. + * @param result The result of the command execution (contains data and errors). + * @param Type of the schema object. + * @return A set of Micrometer {@link Tags}. 
*/ private Tags getCustomTags( CommandContext commandContext, Command command, CommandResult result) { + // --- Basic Tags --- + // Identify the command being executed and the tenant associated with the request Tag commandTag = Tag.of(jsonApiMetricsConfig.command(), command.getClass().getSimpleName()); String tenant = commandContext.requestContext().getTenantId().orElse(UNKNOWN_VALUE); Tag tenantTag = Tag.of(tenantConfig.tenantTag(), tenant); + + // --- Error Tags --- + // Determine if the command resulted in an error and capture details Tag errorTag = errorFalse; Tag errorClassTag = defaultErrorClass; Tag errorCodeTag = defaultErrorCode; - // if error is present, add error tags else use defaults - if (null != result.errors() && !result.errors().isEmpty()) { + if (result != null && null != result.errors() && !result.errors().isEmpty()) { errorTag = errorTrue; - String errorClass = - (String) - result - .errors() - .get(0) - .fieldsForMetricsTag() - .getOrDefault("exceptionClass", UNKNOWN_VALUE); + // Extract details from the first error object's metric fields. + // TODO: Assumption use the first error is representative for metrics? + var metricFields = result.errors().getFirst().fieldsForMetricsTag(); + + // Safely extract error class and code, defaulting to UNKNOWN_VALUE + String errorClass = (String) metricFields.getOrDefault(EXCEPTION_CLASS, UNKNOWN_VALUE); errorClassTag = Tag.of(jsonApiMetricsConfig.errorClass(), errorClass); - String errorCode = - (String) - result.errors().get(0).fieldsForMetricsTag().getOrDefault("errorCode", UNKNOWN_VALUE); + String errorCode = (String) metricFields.getOrDefault(ERROR_CODE, UNKNOWN_VALUE); errorCodeTag = Tag.of(jsonApiMetricsConfig.errorCode(), errorCode); } + // --- Schema Feature Tags --- + // Indicate if the collection/table has vector search enabled in its schema Tag vectorEnabled = - commandContext.schemaObject().vectorConfig().vectorEnabled() - ? 
Tag.of(jsonApiMetricsConfig.vectorEnabled(), "true") - : Tag.of(jsonApiMetricsConfig.vectorEnabled(), "false"); + Tag.of( + jsonApiMetricsConfig.vectorEnabled(), + Boolean.toString(commandContext.schemaObject().vectorConfig().vectorEnabled())); + + // --- Sort Type Tag --- + // Determine the type of sorting used (if any), primarily for FindCommand. + // NOTE: This logic might need refinement or replacement when CommandFeatures is fully + // integrated, especially for FindAndRerankCommand. JsonApiMetricsConfig.SortType sortType = getVectorTypeTag(commandContext, command); Tag sortTypeTag = Tag.of(jsonApiMetricsConfig.sortType(), sortType.name()); - Tags tags = - Tags.of( - commandTag, - tenantTag, - errorTag, - errorClassTag, - errorCodeTag, - vectorEnabled, - sortTypeTag); - return tags; + + // --- Command Feature Usage Tags --- + Tags commandFeatureTags = commandContext.commandFeatures().getTags(); + + // --- Combine All Tags --- + return commandFeatureTags.and( + commandTag, tenantTag, errorTag, errorClassTag, errorCodeTag, vectorEnabled, sortTypeTag); } + /** + * Determines the {@link JsonApiMetricsConfig.SortType} of sorting used in the command based on + * the command's sort clause. Primarily intended for vector-based sorting. + * + * @param commandContext The command context. + * @param command The command being executed. + * @return The type of sorting used (if any). 
+ */ private JsonApiMetricsConfig.SortType getVectorTypeTag( CommandContext commandContext, Command command) { + // Get the count of filter conditions applied int filterCount = 0; if (command instanceof Filterable filterable) { filterCount = filterable.filterClause(commandContext).size(); } + + // Check if the command supports sorting and has a sort clause defined if (command instanceof Sortable sc - && sc.sortClause() != null - && !sc.sortClause().sortExpressions().isEmpty()) { - if (sc.sortClause() != null) { - List sortClause = sc.sortClause().sortExpressions(); - if (sortClause.size() == 1 - && (DocumentConstants.Fields.VECTOR_EMBEDDING_FIELD.equals(sortClause.get(0).path()) - || DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD.equals( - sortClause.get(0).path()))) { + && !sc.sortClause(commandContext).sortExpressions().isEmpty()) { + + var sortExpressions = sc.sortClause(commandContext).sortExpressions(); + + // Check if the only sort expression is for vector similarity ($vector or $vectorize) + if (sortExpressions.size() == 1) { + String sortPath = sortExpressions.getFirst().path(); + if (DocumentConstants.Fields.VECTOR_EMBEDDING_FIELD.equals(sortPath) // $vector + || DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD.equals(sortPath)) // $vectorize + { + // It's a pure vector similarity sort and no filters were applied if (filterCount == 0) { return JsonApiMetricsConfig.SortType.SIMILARITY_SORT; - } else { - return JsonApiMetricsConfig.SortType.SIMILARITY_SORT_WITH_FILTERS; } - } else { - return JsonApiMetricsConfig.SortType.SORT_BY_FIELD; + // Filters were applied alongside the vector sort + return JsonApiMetricsConfig.SortType.SIMILARITY_SORT_WITH_FILTERS; } } + // If more than one sort expression, or the single one isn't $vector/$vectorize + return JsonApiMetricsConfig.SortType.SORT_BY_FIELD; } - return JsonApiMetricsConfig.SortType.NONE; - } - /** Enable histogram buckets for a specific timer */ - private static final String HISTOGRAM_METRICS_NAME 
= "http.server.requests"; - - @Produces - @Singleton - public MeterFilter enableHistogram() { - return new MeterFilter() { - @Override - public DistributionStatisticConfig configure( - Meter.Id id, DistributionStatisticConfig config) { - if (id.getName().startsWith(HISTOGRAM_METRICS_NAME) - || id.getName().startsWith(jsonApiMetricsConfig.vectorizeCallDurationMetrics())) { - - return DistributionStatisticConfig.builder() - .percentiles(0.5, 0.90, 0.95, 0.99) // median and 95th percentile, not aggregable - .percentilesHistogram(true) // histogram buckets (e.g. prometheus histogram_quantile) - .build() - .merge(config); - } - return config; - } - }; + // Default if no sorting is detected or applicable + return JsonApiMetricsConfig.SortType.NONE; } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ApiModelSupport.java b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ApiModelSupport.java new file mode 100644 index 0000000000..85110d3a44 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ApiModelSupport.java @@ -0,0 +1,43 @@ +package io.stargate.sgv2.jsonapi.service.provider; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.smallrye.config.WithDefault; +import java.util.Optional; + +/** + * By default, model is supported and has no message. So if api-model-support is not configured in + * the config source, it will be supported by default. + * + *

If the model is deprecated or EOL, it will be marked in the config source and mapped here. + * + *

If message is not configured in config source, it will be Optional.empty(). + */ +public interface ApiModelSupport { + @JsonProperty + @WithDefault("SUPPORTED") + SupportStatus status(); + + @JsonProperty + @JsonInclude(JsonInclude.Include.NON_EMPTY) + Optional message(); + + /** Enumeration of support status for an embedding or reranking model. */ + enum SupportStatus { + /** The model is supported and can be used when creating new Collections and Tables. */ + SUPPORTED, + /** + * The model is deprecated and may be removed in future versions. Data API supports read and + * write on DEPRECATED model, createCollection and CreateTable are forbidden. + */ + DEPRECATED, + /** + * The model is no longer supported and should not be used. Data API does not support read, + * write, createCollection, createTable for END_OF_LIFE model. + */ + END_OF_LIFE + } + + record ApiModelSupportImpl(ApiModelSupport.SupportStatus status, Optional message) + implements ApiModelSupport {} +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/EmbeddingAndRerankingConfigSourceProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/EmbeddingAndRerankingConfigSourceProvider.java index 5e03673e0a..97c96a931d 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/EmbeddingAndRerankingConfigSourceProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/EmbeddingAndRerankingConfigSourceProvider.java @@ -23,8 +23,9 @@ * EmbeddingAndRerankingConfigSourceProvider#RERANKING_CONFIG_FILE_PATH} variable set, Data * API loads provider config from specified file location. E.G. Astra Data API. *

  • With system property {@link - * EmbeddingAndRerankingConfigSourceProvider#DEFAULT_RERANKING_CONFIG_RESOURCE_OVERRIDE} set, - * it will override the provider config resource. E.G. Data API integration test. + * EmbeddingAndRerankingConfigSourceProvider#SYSTEM_PROPERTY_EMBEDDING_CONFIG_RESOURCE}, + * {@link EmbeddingAndRerankingConfigSourceProvider#SYSTEM_PROPERTY_RERANKING_CONFIG_RESOURCE} + * set, it will override the provider config resource. E.G. Data API integration test. *
  • With none set, Data API loads default provider config from resource folder. E.G. Local * development. * @@ -44,10 +45,15 @@ public class EmbeddingAndRerankingConfigSourceProvider implements ConfigSourcePr private static final String DEFAULT_EMBEDDING_CONFIG_RESOURCE = "embedding-providers-config.yaml"; // Default reranking config resource. private static final String DEFAULT_RERANKING_CONFIG_RESOURCE = "reranking-providers-config.yaml"; - // System property name to override reranking config resource. Could be set by integration test - // resource. - private static final String DEFAULT_RERANKING_CONFIG_RESOURCE_OVERRIDE = - "DEFAULT_RERANKING_CONFIG_RESOURCE_OVERRIDE"; + + // System property name to override embedding config resource. + // Is usually set by integration test resource. + private static final String SYSTEM_PROPERTY_EMBEDDING_CONFIG_RESOURCE = + "EMBEDDING_CONFIG_RESOURCE"; + // System property name to override reranking config resource. + // Is usually set by integration test resource. + private static final String SYSTEM_PROPERTY_RERANKING_CONFIG_RESOURCE = + "RERANKING_CONFIG_RESOURCE"; @Override public Iterable getConfigSources(ClassLoader forClassLoader) { @@ -69,14 +75,23 @@ public Iterable getConfigSources(ClassLoader forClassLoader) { *
  • With env variable {@link * EmbeddingAndRerankingConfigSourceProvider#EMBEDDING_CONFIG_FILE_PATH} set, Data API loads * provider config from specified file location. E.G. Data API astra deployment. + *
  • With system property {@link + * EmbeddingAndRerankingConfigSourceProvider#SYSTEM_PROPERTY_EMBEDDING_CONFIG_RESOURCE} set, + * it indicates Data API is running for integration tests and overrides the default config + * resource. *
  • If the env is not set, use the default config from the resources folder. * */ private ConfigSource getEmbeddingConfigSources(ClassLoader forClassLoader) throws IOException { String filePathFromEnv = System.getenv(EMBEDDING_CONFIG_FILE_PATH); + String resourceOverride = System.getProperty(SYSTEM_PROPERTY_EMBEDDING_CONFIG_RESOURCE); + if (filePathFromEnv != null) { LOGGER.info("Loading embedding config from file path: {}", filePathFromEnv); return loadConfigSourceFromFile(filePathFromEnv); + } else if (resourceOverride != null && !resourceOverride.isBlank()) { + LOGGER.info("Loading embedding config from override resource: {}", resourceOverride); + return loadConfigSourceFromResource(resourceOverride, forClassLoader); } else { LOGGER.info( "Loading embedding config from default resource file : {}", @@ -93,23 +108,23 @@ private ConfigSource getEmbeddingConfigSources(ClassLoader forClassLoader) throw * EmbeddingAndRerankingConfigSourceProvider#RERANKING_CONFIG_FILE_PATH} set, Data API loads * provider config from specified file location. E.G. Data API astra deployment. *
  • With system property {@link - * EmbeddingAndRerankingConfigSourceProvider#DEFAULT_RERANKING_CONFIG_RESOURCE_OVERRIDE} - * set, it indicated Data API is running for integration tests, then override the default - * config resource. + * EmbeddingAndRerankingConfigSourceProvider#SYSTEM_PROPERTY_RERANKING_CONFIG_RESOURCE} set, + * it indicates Data API is running for integration tests and overrides the default config + * resource. *
  • If none is set, use the default config from the resources folder. E.G. Local development * mode. * */ private ConfigSource getRerankingConfigSources(ClassLoader forClassLoader) throws IOException { String filePathFromEnv = System.getenv(RERANKING_CONFIG_FILE_PATH); - String resourceOverride = System.getProperty(DEFAULT_RERANKING_CONFIG_RESOURCE_OVERRIDE); + String resourceOverride = System.getProperty(SYSTEM_PROPERTY_RERANKING_CONFIG_RESOURCE); if (filePathFromEnv != null) { LOGGER.info("Loading reranking config from file path: {}", filePathFromEnv); return loadConfigSourceFromFile(filePathFromEnv); - } else if (resourceOverride != null) { + } else if (resourceOverride != null && !resourceOverride.isBlank()) { LOGGER.info("Loading reranking config from override resource: {}", resourceOverride); - return loadConfigSourceFromResource("test-reranking-providers-config.yaml", forClassLoader); + return loadConfigSourceFromResource(resourceOverride, forClassLoader); } else { LOGGER.info( "Loading reranking config from default resource: {}", DEFAULT_RERANKING_CONFIG_RESOURCE); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelSupport.java b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelSupport.java deleted file mode 100644 index 0845146c6d..0000000000 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelSupport.java +++ /dev/null @@ -1,33 +0,0 @@ -package io.stargate.sgv2.jsonapi.service.provider; - -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonProperty; -import io.smallrye.config.WithDefault; -import java.util.Optional; - -/** - * By default, model is supported and has no message. So if model-support is not configured in the - * config source, it will be supported by default. - * - *

    If the model is deprecated or EOF, it will be marked in the config source and been mapped. - * - *

    If message is not configured in config source, it will be Optional.empty(). - */ -public interface ModelSupport { - @JsonProperty - @WithDefault("SUPPORTED") - SupportStatus status(); - - @JsonProperty - @JsonInclude(JsonInclude.Include.NON_EMPTY) - Optional message(); - - enum SupportStatus { - SUPPORTED, - DEPRECATED, - END_OF_LIFE - } - - record ModelSupportImpl(ModelSupport.SupportStatus status, Optional message) - implements ModelSupport {} -} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelType.java b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelType.java index 1f4ce1a9d5..fc9b348d40 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelType.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelType.java @@ -6,7 +6,7 @@ /** * If the model usage was for indexing data or searching data * - *

    Keeps in parity with the grp proto definition in embedding_gateway.proto + *

    Keeps in parity with the grpc proto definition in embedding_gateway.proto */ public enum ModelType { MODEL_TYPE_UNSPECIFIED, diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderBase.java b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderBase.java index 5b5ad0f8c6..a9aa69110b 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderBase.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderBase.java @@ -3,9 +3,11 @@ import com.fasterxml.jackson.databind.JsonNode; import io.smallrye.mutiny.Uni; import io.stargate.embedding.gateway.EmbeddingGateway; +import io.stargate.sgv2.jsonapi.exception.SchemaException; import jakarta.ws.rs.core.MediaType; import jakarta.ws.rs.core.Response; import java.time.Duration; +import java.util.Map; import java.util.concurrent.TimeoutException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -16,17 +18,17 @@ public abstract class ProviderBase { private final ModelProvider modelProvider; private final ModelType modelType; - private final String modelName; - protected ProviderBase(ModelProvider modelProvider, ModelType modelType, String modelName) { + // TODO: the Embedding and Rerank code *does not* share model configs, but they can & should do, + // so we we cannot pass the model into the base until we refactor the code. + protected ProviderBase(ModelProvider modelProvider, ModelType modelType) { this.modelProvider = modelProvider; this.modelType = modelType; - this.modelName = modelName; } - public String modelName() { - return modelName; - } + public abstract String modelName(); + + public abstract ApiModelSupport modelSupport(); public ModelProvider modelProvider() { return modelProvider; @@ -133,6 +135,32 @@ protected String responseErrorMessage(JsonNode rootNode) { return messageNode.isMissingNode() ? 
rootNode.toString() : messageNode.toString(); } + /** + * Checks if the vectorization will use an END_OF_LIFE model and throws an exception if it is. + * + *

    As part of embedding model deprecation ability, any read and write with vectorization in an + * END_OF_LIFE model will throw an exception. + * + *

    Note, SUPPORTED and DEPRECATED models are still allowed to be used in read and write. + * + *

    This method should be called before any vectorization operation. + */ + protected void checkEOLModelUsage() { + + if (modelSupport().status() == ApiModelSupport.SupportStatus.END_OF_LIFE) { + throw SchemaException.Code.END_OF_LIFE_AI_MODEL.get( + Map.of( + "model", + modelName(), + "modelStatus", + modelSupport().status().name(), + "message", + modelSupport() + .message() + .orElse("The model is no longer supported (reached its end-of-life)."))); + } + } + protected ModelUsage createModelUsage( String tenantId, ModelInputType modelInputType, diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/configuration/RerankingProviderConfigProducer.java b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/configuration/RerankingProviderConfigProducer.java index 1e3695d7ef..fae60bcccb 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/configuration/RerankingProviderConfigProducer.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/configuration/RerankingProviderConfigProducer.java @@ -1,7 +1,5 @@ package io.stargate.sgv2.jsonapi.service.reranking.configuration; -import static io.stargate.sgv2.jsonapi.exception.ErrorFormatters.*; - import io.quarkus.grpc.GrpcClient; import io.quarkus.runtime.Startup; import io.stargate.embedding.gateway.EmbeddingGateway; @@ -9,7 +7,8 @@ import io.stargate.sgv2.jsonapi.config.OperationsConfig; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.exception.ServerException; -import io.stargate.sgv2.jsonapi.service.provider.ModelSupport; +import io.stargate.sgv2.jsonapi.service.provider.ApiModelSupport; +import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionRerankDef; import jakarta.enterprise.context.ApplicationScoped; import jakarta.enterprise.inject.Produces; import java.time.temporal.ChronoUnit; @@ -71,6 +70,11 @@ RerankingProvidersConfig produce( } validateRerankingProvidersConfig(rerankingProvidersConfig); + + // Initialize the 
default reranking provider and model in DefaultRerankingProviderDef as + // Singleton. + CollectionRerankDef.initializeDefaultRerankDef(rerankingProvidersConfig); + return rerankingProvidersConfig; } @@ -192,10 +196,11 @@ private RerankingProvidersConfig.RerankingProviderConfig createRerankingProvider model -> new RerankingProvidersConfigImpl.RerankingProviderConfigImpl.ModelConfigImpl( model.getName(), - new ModelSupport.ModelSupportImpl( - ModelSupport.SupportStatus.valueOf(model.getModelSupport().getStatus()), - model.getModelSupport().hasMessage() - ? Optional.of(model.getModelSupport().getMessage()) + new ApiModelSupport.ApiModelSupportImpl( + ApiModelSupport.SupportStatus.valueOf( + model.getApiModelSupport().getStatus()), + model.getApiModelSupport().hasMessage() + ? Optional.of(model.getApiModelSupport().getMessage()) : Optional.empty()), model.getIsDefault(), model.getUrl(), diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/configuration/RerankingProvidersConfig.java b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/configuration/RerankingProvidersConfig.java index b6bfeffb6f..4978debeff 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/configuration/RerankingProvidersConfig.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/configuration/RerankingProvidersConfig.java @@ -2,7 +2,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import io.smallrye.config.WithDefault; -import io.stargate.sgv2.jsonapi.service.provider.ModelSupport; +import io.stargate.sgv2.jsonapi.service.provider.ApiModelSupport; import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionRerankDef; import java.util.List; import java.util.Map; @@ -56,11 +56,11 @@ interface ModelConfig { String name(); /** - * modelSupport marks the support status of the model and optional message for the + * apiModelSupport marks the support status of the model and optional message for the * deprecation, EOL etc. 
*/ @JsonProperty - ModelSupport modelSupport(); + ApiModelSupport apiModelSupport(); @JsonProperty @WithDefault("false") diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/configuration/RerankingProvidersConfigImpl.java b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/configuration/RerankingProvidersConfigImpl.java index e93fadded8..cc5e0897f9 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/configuration/RerankingProvidersConfigImpl.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/configuration/RerankingProvidersConfigImpl.java @@ -1,6 +1,6 @@ package io.stargate.sgv2.jsonapi.service.reranking.configuration; -import io.stargate.sgv2.jsonapi.service.provider.ModelSupport; +import io.stargate.sgv2.jsonapi.service.provider.ApiModelSupport; import java.util.List; import java.util.Map; @@ -23,7 +23,7 @@ public record TokenConfigImpl(String accepted, String forwarded) implements Toke public record ModelConfigImpl( String name, - ModelSupport modelSupport, + ApiModelSupport apiModelSupport, boolean isDefault, String url, RequestProperties properties) diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProvider.java index f7c04bdf2e..19247ccdc9 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProvider.java @@ -6,6 +6,7 @@ import io.stargate.sgv2.jsonapi.api.request.RerankingCredentials; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.exception.JsonApiException; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; import io.stargate.sgv2.jsonapi.service.provider.*; import io.stargate.sgv2.jsonapi.service.reranking.configuration.RerankingProvidersConfig; import 
jakarta.ws.rs.core.Response; @@ -31,9 +32,7 @@ public abstract class RerankingProvider extends ProviderBase { protected RerankingProvider( ModelProvider modelProvider, String baseUrl, - String modelName, - RerankingProvidersConfig.RerankingProviderConfig.ModelConfig.RequestProperties - requestProperties) { + RerankingProvidersConfig.RerankingProviderConfig.ModelConfig model) { super(modelProvider, ModelType.RERANKING, modelName); this.baseUrl = baseUrl; diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CommandResolver.java b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CommandResolver.java index bf6ce047c1..afe748a2c1 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CommandResolver.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CommandResolver.java @@ -2,6 +2,8 @@ import static io.stargate.sgv2.jsonapi.exception.ErrorFormatters.errFmtJoin; import static io.stargate.sgv2.jsonapi.exception.ErrorFormatters.errVars; +import static io.stargate.sgv2.jsonapi.metrics.MetricsConstants.MetricTags.TENANT_TAG; +import static io.stargate.sgv2.jsonapi.metrics.MetricsConstants.UNKNOWN_VALUE; import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.core.instrument.Tag; @@ -186,9 +188,6 @@ default Operation resolveDatabaseCommand( command.getClass().getSimpleName(), ctx.schemaObject().name()); } - static final String UNKNOWN_VALUE = "unknown"; - static final String TENANT_TAG = "tenant"; - /** * Call to track metrics for the index usage, this method is called after the command is resolved * and we know the filters we want to run. 
diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CreateCollectionCommandResolver.java b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CreateCollectionCommandResolver.java index 34ce51158e..9e212c537b 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CreateCollectionCommandResolver.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CreateCollectionCommandResolver.java @@ -12,7 +12,6 @@ import io.stargate.sgv2.jsonapi.config.feature.ApiFeature; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.exception.JsonApiException; -import io.stargate.sgv2.jsonapi.service.cqldriver.CQLSessionCache; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.KeyspaceSchemaObject; import io.stargate.sgv2.jsonapi.service.operation.Operation; import io.stargate.sgv2.jsonapi.service.operation.collections.CreateCollectionOperation; @@ -29,7 +28,6 @@ public class CreateCollectionCommandResolver implements CommandResolver { private final ObjectMapper objectMapper; - private final CQLSessionCache cqlSessionCache; private final DocumentLimitsConfig documentLimitsConfig; private final DatabaseLimitsConfig dbLimitsConfig; private final OperationsConfig operationsConfig; @@ -39,14 +37,12 @@ public class CreateCollectionCommandResolver implements CommandResolver getCommandClass() { return CreateCollectionCommand.class; @@ -76,7 +68,7 @@ public Operation resolveKeyspaceCommand( if (options == null) { final CollectionLexicalConfig lexicalConfig = lexicalAvailableForDB - ? CollectionLexicalConfig.configForEnabledStandard() + ? 
CollectionLexicalConfig.configForDefault() : CollectionLexicalConfig.configForDisabled(); final CollectionRerankDef rerankDef = CollectionRerankDef.configForNewCollections( @@ -85,14 +77,15 @@ public Operation resolveKeyspaceCommand( ctx, dbLimitsConfig, objectMapper, - cqlSessionCache, + ctx.cqlSessionCache(), name, generateComment( objectMapper, false, false, name, null, null, null, lexicalConfig, rerankDef), operationsConfig.databaseConfig().ddlDelayMillis(), operationsConfig.tooManyIndexesRollbackEnabled(), false, - lexicalConfig); + lexicalConfig, + rerankDef); } boolean hasIndexing = options.indexing() != null; @@ -137,7 +130,7 @@ public Operation resolveKeyspaceCommand( ctx, dbLimitsConfig, objectMapper, - cqlSessionCache, + ctx.cqlSessionCache(), name, vector.dimension(), vector.metric(), @@ -146,19 +139,21 @@ public Operation resolveKeyspaceCommand( operationsConfig.databaseConfig().ddlDelayMillis(), operationsConfig.tooManyIndexesRollbackEnabled(), indexingDenyAll, - lexicalConfig); + lexicalConfig, + rerankDef); } else { return CreateCollectionOperation.withoutVectorSearch( ctx, dbLimitsConfig, objectMapper, - cqlSessionCache, + ctx.cqlSessionCache(), name, comment, operationsConfig.databaseConfig().ddlDelayMillis(), operationsConfig.tooManyIndexesRollbackEnabled(), indexingDenyAll, - lexicalConfig); + lexicalConfig, + rerankDef); } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CreateTextIndexCommandResolver.java b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CreateTextIndexCommandResolver.java new file mode 100644 index 0000000000..2728e1cdad --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/CreateTextIndexCommandResolver.java @@ -0,0 +1,101 @@ +package io.stargate.sgv2.jsonapi.service.resolver; + +import static io.stargate.sgv2.jsonapi.exception.ErrorFormatters.errFmtJoin; +import static io.stargate.sgv2.jsonapi.util.ApiOptionUtils.getOrDefault; + +import 
io.stargate.sgv2.jsonapi.api.model.command.CommandContext; +import io.stargate.sgv2.jsonapi.api.model.command.impl.CreateTextIndexCommand; +import io.stargate.sgv2.jsonapi.config.OperationsConfig; +import io.stargate.sgv2.jsonapi.config.constants.TableDescDefaults; +import io.stargate.sgv2.jsonapi.exception.SchemaException; +import io.stargate.sgv2.jsonapi.service.cqldriver.executor.DefaultDriverExceptionHandler; +import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject; +import io.stargate.sgv2.jsonapi.service.operation.Operation; +import io.stargate.sgv2.jsonapi.service.operation.SchemaDBTask; +import io.stargate.sgv2.jsonapi.service.operation.SchemaDBTaskPage; +import io.stargate.sgv2.jsonapi.service.operation.tables.CreateIndexDBTask; +import io.stargate.sgv2.jsonapi.service.operation.tables.CreateIndexDBTaskBuilder; +import io.stargate.sgv2.jsonapi.service.operation.tables.CreateIndexExceptionHandler; +import io.stargate.sgv2.jsonapi.service.operation.tasks.TaskGroup; +import io.stargate.sgv2.jsonapi.service.operation.tasks.TaskOperation; +import io.stargate.sgv2.jsonapi.service.schema.naming.NamingRules; +import io.stargate.sgv2.jsonapi.service.schema.tables.ApiIndexType; +import io.stargate.sgv2.jsonapi.service.schema.tables.ApiTextIndex; +import jakarta.enterprise.context.ApplicationScoped; +import java.time.Duration; +import java.util.Map; + +/** Resolver for the {@link CreateTextIndexCommand}. */ +@ApplicationScoped +public class CreateTextIndexCommandResolver implements CommandResolver { + + @Override + public Class getCommandClass() { + return CreateTextIndexCommand.class; + } + + @Override + public Operation resolveTableCommand( + CommandContext commandContext, CreateTextIndexCommand command) { + + final var indexName = validateSchemaName(command.name(), NamingRules.INDEX); + + var indexType = + command.indexType() == null + ? 
ApiIndexType.TEXT + : ApiIndexType.fromApiName(command.indexType()); + + if (indexType == null) { + throw SchemaException.Code.UNKNOWN_INDEX_TYPE.get( + Map.of( + "knownTypes", + errFmtJoin(ApiIndexType.values(), ApiIndexType::apiName), + "unknownType", + command.indexType())); + } + + if (indexType != ApiIndexType.TEXT) { + throw SchemaException.Code.UNSUPPORTED_INDEX_TYPE.get( + Map.of( + "supportedTypes", + ApiIndexType.TEXT.apiName(), + "unsupportedType", + command.indexType())); + } + + // TODO: we need a centralised way of creating retry attempt. + CreateIndexDBTaskBuilder taskBuilder = + CreateIndexDBTask.builder(commandContext.schemaObject()) + .withIfNotExists( + getOrDefault( + command.options(), + CreateTextIndexCommand.CommandOptions::ifNotExists, + TableDescDefaults.CreateTextIndexOptionsDefaults.IF_NOT_EXISTS)) + .withSchemaRetryPolicy( + new SchemaDBTask.SchemaRetryPolicy( + commandContext + .config() + .get(OperationsConfig.class) + .databaseConfig() + .ddlRetries(), + Duration.ofMillis( + commandContext + .config() + .get(OperationsConfig.class) + .databaseConfig() + .ddlRetryDelayMillis()))); + + // this will throw APIException if the index is not supported + ApiTextIndex apiIndex = + ApiTextIndex.FROM_DESC_FACTORY.create( + commandContext.schemaObject(), indexName, command.definition()); + taskBuilder.withExceptionHandlerFactory( + DefaultDriverExceptionHandler.Factory.withIdentifier( + CreateIndexExceptionHandler::new, apiIndex.indexName())); + + var taskGroup = new TaskGroup<>(taskBuilder.build(apiIndex)); + + return new TaskOperation<>( + taskGroup, SchemaDBTaskPage.accumulator(CreateIndexDBTask.class, commandContext)); + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/DeleteOneCommandResolver.java b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/DeleteOneCommandResolver.java index 3b19b52c84..353588b17c 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/DeleteOneCommandResolver.java +++ 
b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/DeleteOneCommandResolver.java @@ -66,7 +66,7 @@ public Operation resolveTableCommand( CommandContext commandContext, DeleteOneCommand command) { // Sort clause is not supported for table deleteOne command. - if (command.sortClause() != null && !command.sortClause().isEmpty()) { + if (!command.sortClause(commandContext).isEmpty()) { throw SortException.Code.UNSUPPORTED_SORT_FOR_TABLE_DELETE_COMMAND.get( errVars(commandContext.schemaObject(), map -> {})); } @@ -105,7 +105,7 @@ private FindCollectionOperation getFindOperation( var dbLogicalExpression = collectionFilterResolver.resolve(commandContext, command).target(); - final SortClause sortClause = command.sortClause(); + final SortClause sortClause = command.sortClause(commandContext); if (sortClause != null) { sortClause.validate(commandContext.schemaObject()); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindAndRerankOperationBuilder.java b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindAndRerankOperationBuilder.java index 13ba30e09c..e75e3c2c89 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindAndRerankOperationBuilder.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindAndRerankOperationBuilder.java @@ -7,6 +7,7 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.node.JsonNodeFactory; import io.stargate.sgv2.jsonapi.api.model.command.*; +import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.SortDefinition; import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause; import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortExpression; import io.stargate.sgv2.jsonapi.api.model.command.impl.FindAndRerankCommand; @@ -23,7 +24,7 @@ import io.stargate.sgv2.jsonapi.service.operation.embeddings.EmbeddingTaskGroupBuilder; import io.stargate.sgv2.jsonapi.service.operation.reranking.*; import 
io.stargate.sgv2.jsonapi.service.operation.tasks.*; -import io.stargate.sgv2.jsonapi.service.provider.ModelSupport; +import io.stargate.sgv2.jsonapi.service.provider.ApiModelSupport; import io.stargate.sgv2.jsonapi.service.reranking.operation.RerankingProvider; import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject; import io.stargate.sgv2.jsonapi.service.shredding.Deferrable; @@ -173,15 +174,19 @@ private void checkSupported() { var modelConfig = rerankingProvidersConfig.filterByRerankServiceDef( commandContext.schemaObject().rerankingConfig().rerankServiceDef()); - if (modelConfig.modelSupport().status() == ModelSupport.SupportStatus.END_OF_LIFE) { - throw SchemaException.Code.UNSUPPORTED_PROVIDER_MODEL.get( + // Validate if the model is END_OF_LIFE + if (modelConfig.apiModelSupport().status() == ApiModelSupport.SupportStatus.END_OF_LIFE) { + throw SchemaException.Code.END_OF_LIFE_AI_MODEL.get( Map.of( "model", modelConfig.name(), "modelStatus", - modelConfig.modelSupport().status().name(), + modelConfig.apiModelSupport().status().name(), "message", - modelConfig.modelSupport().message().orElse("The model is not supported."))); + modelConfig + .apiModelSupport() + .message() + .orElse("The model is no longer supported (reached its end-of-life)."))); } } @@ -294,7 +299,10 @@ private IntermediateCollectionReadTask buildBm25Read(DeferredCommandResultAction var bm25SortClause = new SortClause(List.of(SortExpression.bm25Search(bm25SortTerm))); var bm25ReadCommand = new FindCommand( - command.filterSpec(), INCLUDE_ALL_PROJECTION, bm25SortClause, buildFindOptions(false)); + command.filterDefinition(), + INCLUDE_ALL_PROJECTION, + SortDefinition.wrap(bm25SortClause), + buildFindOptions(false)); return new IntermediateCollectionReadTask( 0, @@ -339,7 +347,10 @@ private IntermediateCollectionReadTask buildBm25Read(DeferredCommandResultAction // The intermediate task will set the sort when we give it the deferred vectorize var vectorReadCommand = 
new FindCommand( - command.filterSpec(), INCLUDE_ALL_PROJECTION, sortClause, buildFindOptions(true)); + command.filterDefinition(), + INCLUDE_ALL_PROJECTION, + SortDefinition.wrap(sortClause), + buildFindOptions(true)); var readTask = new IntermediateCollectionReadTask( 1, diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindCollectionsCommandResolver.java b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindCollectionsCommandResolver.java index f0635911a3..dc54ef1b66 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindCollectionsCommandResolver.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindCollectionsCommandResolver.java @@ -3,7 +3,6 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.stargate.sgv2.jsonapi.api.model.command.CommandContext; import io.stargate.sgv2.jsonapi.api.model.command.impl.FindCollectionsCommand; -import io.stargate.sgv2.jsonapi.service.cqldriver.CQLSessionCache; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.KeyspaceSchemaObject; import io.stargate.sgv2.jsonapi.service.operation.Operation; import io.stargate.sgv2.jsonapi.service.operation.collections.FindCollectionsCollectionOperation; @@ -14,13 +13,10 @@ @ApplicationScoped public class FindCollectionsCommandResolver implements CommandResolver { private final ObjectMapper objectMapper; - private final CQLSessionCache cqlSessionCache; @Inject - public FindCollectionsCommandResolver( - ObjectMapper objectMapper, CQLSessionCache cqlSessionCache) { + public FindCollectionsCommandResolver(ObjectMapper objectMapper) { this.objectMapper = objectMapper; - this.cqlSessionCache = cqlSessionCache; } /** {@inheritDoc} */ @@ -35,6 +31,7 @@ public Operation resolveKeyspaceCommand( CommandContext ctx, FindCollectionsCommand command) { boolean explain = command.options() != null ? 
command.options().explain() : false; - return new FindCollectionsCollectionOperation(explain, objectMapper, cqlSessionCache, ctx); + return new FindCollectionsCollectionOperation( + explain, objectMapper, ctx.cqlSessionCache(), ctx); } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindCommandResolver.java b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindCommandResolver.java index f143b4c5c4..8496325d78 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindCommandResolver.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindCommandResolver.java @@ -95,7 +95,7 @@ public Operation resolveCollectionCommand( includeSortVector = options.includeSortVector(); } - final SortClause sortClause = command.sortClause(); + final SortClause sortClause = command.sortClause(commandContext); // collection always uses in memory sorting, so we don't support page state with sort clause // empty sort clause and empty page state are treated as no sort clause and no page state diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindEmbeddingProvidersCommandResolver.java b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindEmbeddingProvidersCommandResolver.java index 43799a8794..791549bbac 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindEmbeddingProvidersCommandResolver.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindEmbeddingProvidersCommandResolver.java @@ -32,6 +32,6 @@ public Operation resolveDatabaseCommand( if (!operationsConfig.vectorizeEnabled()) { throw ErrorCodeV1.VECTORIZE_FEATURE_NOT_AVAILABLE.toApiException(); } - return new FindEmbeddingProvidersOperation(embeddingProvidersConfig); + return new FindEmbeddingProvidersOperation(command, embeddingProvidersConfig); } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindOneAndDeleteCommandResolver.java 
b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindOneAndDeleteCommandResolver.java index c4fe46a53b..d4aeb39687 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindOneAndDeleteCommandResolver.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindOneAndDeleteCommandResolver.java @@ -68,10 +68,8 @@ private FindCollectionOperation getFindOperation( CommandContext commandContext, FindOneAndDeleteCommand command) { var dbLogicalExpression = collectionFilterResolver.resolve(commandContext, command).target(); - final SortClause sortClause = command.sortClause(); - if (sortClause != null) { - sortClause.validate(commandContext.schemaObject()); - } + final SortClause sortClause = command.sortClause(commandContext); + sortClause.validate(commandContext.schemaObject()); float[] vector = SortClauseUtil.resolveVsearch(sortClause); var indexUsage = commandContext.schemaObject().newCollectionIndexUsage(); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindOneAndReplaceCommandResolver.java b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindOneAndReplaceCommandResolver.java index 362d478e5b..3976a93337 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindOneAndReplaceCommandResolver.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindOneAndReplaceCommandResolver.java @@ -102,10 +102,8 @@ private FindCollectionOperation getFindOperation( var dbLogicalExpression = collectionFilterResolver.resolve(commandContext, command).target(); - final SortClause sortClause = command.sortClause(); - if (sortClause != null) { - sortClause.validate(commandContext.schemaObject()); - } + final SortClause sortClause = command.sortClause(commandContext); + sortClause.validate(commandContext.schemaObject()); float[] vector = SortClauseUtil.resolveVsearch(sortClause); var indexUsage = commandContext.schemaObject().newCollectionIndexUsage(); diff --git 
a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindOneAndUpdateCommandResolver.java b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindOneAndUpdateCommandResolver.java index 81885d7d8a..91a8fb6ba0 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindOneAndUpdateCommandResolver.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindOneAndUpdateCommandResolver.java @@ -95,10 +95,8 @@ private FindCollectionOperation getFindOperation( var dbLogicalExpression = collectionFilterResolver.resolve(commandContext, command).target(); - final SortClause sortClause = command.sortClause(); - if (sortClause != null) { - sortClause.validate(commandContext.schemaObject()); - } + final SortClause sortClause = command.sortClause(commandContext); + sortClause.validate(commandContext.schemaObject()); float[] vector = SortClauseUtil.resolveVsearch(sortClause); var indexUsage = commandContext.schemaObject().newCollectionIndexUsage(); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindOneCommandResolver.java b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindOneCommandResolver.java index 23292f8e5c..0ea66cc571 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindOneCommandResolver.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindOneCommandResolver.java @@ -68,10 +68,8 @@ public Operation resolveCollectionCommand( final DBLogicalExpression dbLogicalExpression = collectionFilterResolver.resolve(commandContext, command).target(); - final SortClause sortClause = command.sortClause(); - if (sortClause != null) { - sortClause.validate(commandContext.schemaObject()); - } + final SortClause sortClause = command.sortClause(commandContext); + sortClause.validate(commandContext.schemaObject()); float[] vector = SortClauseUtil.resolveVsearch(sortClause); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/ListTablesCommandResolver.java 
b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/ListTablesCommandResolver.java index a1df4abfd5..f1a70dd1cc 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/ListTablesCommandResolver.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/ListTablesCommandResolver.java @@ -4,7 +4,6 @@ import io.stargate.sgv2.jsonapi.api.model.command.CommandContext; import io.stargate.sgv2.jsonapi.api.model.command.CommandStatus; import io.stargate.sgv2.jsonapi.api.model.command.impl.ListTablesCommand; -import io.stargate.sgv2.jsonapi.service.cqldriver.CQLSessionCache; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.KeyspaceSchemaObject; import io.stargate.sgv2.jsonapi.service.operation.*; import io.stargate.sgv2.jsonapi.service.operation.tables.KeyspaceDriverExceptionHandler; @@ -17,12 +16,10 @@ @ApplicationScoped public class ListTablesCommandResolver implements CommandResolver { private final ObjectMapper objectMapper; - private final CQLSessionCache cqlSessionCache; @Inject - public ListTablesCommandResolver(ObjectMapper objectMapper, CQLSessionCache cqlSessionCache) { + public ListTablesCommandResolver(ObjectMapper objectMapper) { this.objectMapper = objectMapper; - this.cqlSessionCache = cqlSessionCache; } /** {@inheritDoc} */ diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/TableReadDBOperationBuilder.java b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/TableReadDBOperationBuilder.java index 7be0112604..7b1553ba18 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/TableReadDBOperationBuilder.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/TableReadDBOperationBuilder.java @@ -18,8 +18,6 @@ import io.stargate.sgv2.jsonapi.service.resolver.sort.TableMemorySortClauseResolver; import java.util.List; import java.util.Objects; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * Encapsulates resolving a read command into a {@link Operation}, which 
includes building the tasks @@ -30,9 +28,6 @@ */ class TableReadDBOperationBuilder< CmdT extends ReadCommand & Filterable & Projectable & Sortable & Windowable & VectorSortable> { - - private static final Logger LOGGER = LoggerFactory.getLogger(TableReadDBOperationBuilder.class); - private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); private final CommandContext commandContext; @@ -112,8 +107,7 @@ public Operation build() { // the columns the user wants // NOTE: the TableProjection is doing double duty as the select and the operation projection - var projection = - TableProjection.fromDefinition(OBJECT_MAPPER, command, commandContext.schemaObject()); + var projection = TableProjection.fromDefinition(commandContext, OBJECT_MAPPER, command); taskBuilder.withSelect(WithWarnings.of(projection)); taskBuilder.withProjection(projection); @@ -132,7 +126,7 @@ public Operation build() { taskGroup, ReadDBTaskPage.accumulator(commandContext) .singleResponse(singleResponse) - .mayReturnVector(command), + .mayReturnVector(commandContext, command), List.of(orderByWithWarnings.target())); return EmbeddingOperationFactory.createOperation(commandContext, tasksAndDeferrables); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/UpdateOneCommandResolver.java b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/UpdateOneCommandResolver.java index 06d813c67c..3c9668e6c4 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/UpdateOneCommandResolver.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/UpdateOneCommandResolver.java @@ -76,7 +76,7 @@ public Operation resolveTableCommand( CommandContext commandContext, UpdateOneCommand command) { // Sort clause is not supported for table updateOne command. 
- if (command.sortClause() != null && !command.sortClause().isEmpty()) { + if (!command.sortClause(commandContext).isEmpty()) { throw SortException.Code.UNSUPPORTED_SORT_FOR_TABLE_UPDATE_COMMAND.get( errVars(commandContext.schemaObject(), map -> {})); } @@ -140,7 +140,7 @@ private FindCollectionOperation getFindOperation( var dbLogicalExpression = collectionFilterResolver.resolve(commandContext, command).target(); - final SortClause sortClause = command.sortClause(); + final SortClause sortClause = command.sortClause(commandContext); if (sortClause != null) { sortClause.validate(commandContext.schemaObject()); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/VectorizeConfigValidator.java b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/VectorizeConfigValidator.java index 80534c9a2a..4116bd5e03 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/VectorizeConfigValidator.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/VectorizeConfigValidator.java @@ -4,8 +4,14 @@ import io.stargate.sgv2.jsonapi.config.OperationsConfig; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.exception.JsonApiException; +import io.stargate.sgv2.jsonapi.exception.SchemaException; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; +<<<<<<< HEAD import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +======= +import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; +import io.stargate.sgv2.jsonapi.service.provider.ApiModelSupport; +>>>>>>> main import jakarta.enterprise.context.ApplicationScoped; import jakarta.inject.Inject; import java.util.ArrayList; @@ -68,6 +74,11 @@ public Integer validateService(VectorizeConfig userConfig, Integer userVectorDim Integer vectorDimension = validateModelAndDimension(userConfig, providerConfig, userVectorDimension); + // Model must be SUPPORTED to do the schema operation. 
+ // Note, validateService method is only triggered for createCollection/createTable/alterTable. + // So we can add validation here to cut off DEPRECATED/END_OF_LIFE model for schema creation. + checkModelSupportForSchemaCreation(userConfig); + // Validate user-provided parameters against internal expectations validateUserParameters(userConfig, providerConfig); @@ -330,19 +341,7 @@ private Integer validateModelAndDimension( } // 2. other providers do require model - if (userConfig.modelName() == null) { - throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( - "'modelName' is needed for provider %s", userConfig.provider()); - } - EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig model = - providerConfig.models().stream() - .filter(m -> m.name().equals(userConfig.modelName())) - .findFirst() - .orElseThrow( - () -> - ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( - "Model name '%s' for provider '%s' is not supported", - userConfig.modelName(), userConfig.provider())); + var model = getModelConfig(userConfig, providerConfig); // Handle models with a fixed vector dimension if (model.vectorDimension().isPresent() && model.vectorDimension().get() != 0) { @@ -365,6 +364,24 @@ private Integer validateModelAndDimension( .orElse(userVectorDimension); // should not go here } + /** Retrieves the model configuration for the specified provider and model. 
*/ + private EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig getModelConfig( + VectorizeConfig userConfig, EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig) { + + if (userConfig.modelName() == null) { + throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( + "'modelName' is needed for provider %s", userConfig.provider()); + } + return providerConfig.models().stream() + .filter(m -> m.name().equals(userConfig.modelName())) + .findFirst() + .orElseThrow( + () -> + ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( + "Model name '%s' for provider '%s' is not supported", + userConfig.modelName(), userConfig.provider())); + } + /** * Validates the user-provided vector dimension against the dimension parameter's validation * constraints. @@ -411,4 +428,41 @@ private Integer validateRangeDimension( } return userVectorDimension; } + + /** + * Validates the model support for the vectorization service. This method checks if the model is + * SUPPORTED and throw corresponding error if the model's DEPRECATED or END_OF_LIFE. + * + *

    Note, this validation will only happen for schema change, E.G. + * createCollection/createTable/alterTable. When loading schemaObject from existing + * collection/table from DB, this validation will not be called, since the model is already set + * and we don't want to cut off the normal usage for existing collection/table. + */ + private void checkModelSupportForSchemaCreation(VectorizeConfig service) { + + // 1. Check if the service provider exists and is enabled + var providerConfig = getAndValidateProviderConfig(service); + + // 2. other providers do require model + var model = getModelConfig(service, providerConfig); + + // 3. validate model support + if (model.apiModelSupport().status() != ApiModelSupport.SupportStatus.SUPPORTED) { + var errorCode = + model.apiModelSupport().status() == ApiModelSupport.SupportStatus.DEPRECATED + ? SchemaException.Code.DEPRECATED_AI_MODEL + : SchemaException.Code.END_OF_LIFE_AI_MODEL; + throw errorCode.get( + Map.of( + "model", + model.name(), + "modelStatus", + model.apiModelSupport().status().name(), + "message", + model + .apiModelSupport() + .message() + .orElse("The model is %s.".formatted(model.apiModelSupport().status().name())))); + } + } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/matcher/CaptureGroups.java b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/matcher/CaptureGroups.java index 86194a89b1..39961ddf0a 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/matcher/CaptureGroups.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/matcher/CaptureGroups.java @@ -103,7 +103,7 @@ public void consumeAll( // represents the logical // relation in the API query, so we need to create a logical expression for each CaptureGroups DBLogicalExpression subDBLogicalExpression = - currentDbLogicalExpression.addSubExpression( + currentDbLogicalExpression.addSubExpressionReturnSub( new DBLogicalExpression(captureGroups.getLogicalOperator())); 
captureGroups.consumeAll(subDBLogicalExpression, consumer); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/matcher/CollectionFilterResolver.java b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/matcher/CollectionFilterResolver.java index 92a8c5609b..f0ac6c5e1b 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/matcher/CollectionFilterResolver.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/matcher/CollectionFilterResolver.java @@ -43,6 +43,8 @@ public class CollectionFilterResolver private static final Object SIZE_GROUP = new Object(); private static final Object ARRAY_EQUALS = new Object(); private static final Object SUB_DOC_EQUALS = new Object(); + // For $match on $lexical + private static final Object MATCH_GROUP = new Object(); public CollectionFilterResolver(OperationsConfig operationsConfig) { super(operationsConfig); @@ -156,7 +158,11 @@ protected FilterMatchRules buildMatchRules() { .compareValues( "*", EnumSet.of(ValueComparisonOperator.EQ, ValueComparisonOperator.NE), - JsonType.SUB_DOC); + JsonType.SUB_DOC) + .capture(MATCH_GROUP) + .compareValues( + // Should be "$lexical" but validated elsewhere + "*", EnumSet.of(ValueComparisonOperator.MATCH), JsonType.STRING); return matchRules; } @@ -272,6 +278,23 @@ public static DBLogicalExpression findDynamic( idRangeGroup.consumeAllCaptures( expression -> { final DocumentId value = (DocumentId) expression.value(); + // E.G. {"_id":{"$gt":true}} + if (value.value() instanceof Boolean) { + dbLogicalExpression.addFilter( + new BoolCollectionFilter( + DocumentConstants.Fields.DOC_ID, + getMapFilterBaseOperator(expression.operator()), + (Boolean) value.value())); + } + // E.G. {"_id":{"$gt":"apple"}} + if (value.value() instanceof String) { + dbLogicalExpression.addFilter( + new TextCollectionFilter( + DocumentConstants.Fields.DOC_ID, + getMapFilterBaseOperator(expression.operator()), + (String) value.value())); + } + // E.G. 
{"_id":{"$gt":123}} if (value.value() instanceof BigDecimal bdv) { dbLogicalExpression.addFilter( new NumberCollectionFilter( @@ -279,6 +302,7 @@ public static DBLogicalExpression findDynamic( getMapFilterBaseOperator(expression.operator()), bdv)); } + // E.G. {"_id":{"$gt":{"$date":1672531200000}}} if (value.value() instanceof Map) { dbLogicalExpression.addFilter( new DateCollectionFilter( @@ -481,6 +505,19 @@ public static DBLogicalExpression findDynamic( : MapCollectionFilter.Operator.MAP_NOT_EQUALS)); }); }); + + captureGroups + .getGroupIfPresent(MATCH_GROUP) + .ifPresent( + captureGroup -> { + CaptureGroup matchGroup = (CaptureGroup) captureGroup; + matchGroup.consumeAllCaptures( + expression -> { + dbLogicalExpression.addFilter( + new MatchCollectionFilter( + expression.path(), (String) expression.value())); + }); + }); }; currentCaptureGroups.consumeAll(currentDbLogicalExpression, consumer); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/sort/TableCqlSortClauseResolver.java b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/sort/TableCqlSortClauseResolver.java index 23fce014c2..3e95da2c74 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/sort/TableCqlSortClauseResolver.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/sort/TableCqlSortClauseResolver.java @@ -63,7 +63,7 @@ public WithWarnings resolve( Objects.requireNonNull(commandContext, "commandContext is required"); Objects.requireNonNull(command, "command is required"); - var sortClause = command.sortClause(); + var sortClause = command.sortClause(commandContext); if (sortClause == null || sortClause.isEmpty()) { LOGGER.debug("Sort clause is null or empty, no CQL ORDER BY needed."); return WithWarnings.of(OrderByCqlClause.NO_OP); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/sort/TableMemorySortClauseResolver.java b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/sort/TableMemorySortClauseResolver.java index 
52394b7865..bc74002de2 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/sort/TableMemorySortClauseResolver.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/sort/TableMemorySortClauseResolver.java @@ -60,14 +60,14 @@ public WithWarnings resolve( if (orderByCqlClause.fullyCoversCommand()) { // Cql Order by is enough to handle the sort clause, no need for a row sorter // this will also cover where there is no sorting. - LOGGER.debug("No in memory sort needed, using CQL order by"); + LOGGER.debug("No in-memory sort needed, using CQL order by"); return WithWarnings.of(RowSorter.NO_OP); } // Just a sanity check, - var sortClause = command.sortClause(); + var sortClause = command.sortClause(commandContext); if (sortClause == null || sortClause.isEmpty()) { - LOGGER.debug("No in memory sort needed, sort clause was null or empty"); + LOGGER.debug("No in-memory sort needed, sort clause was null or empty"); return WithWarnings.of(RowSorter.NO_OP); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionLexicalConfig.java b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionLexicalConfig.java index 6de40e9f3a..cc41b0132a 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionLexicalConfig.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionLexicalConfig.java @@ -8,7 +8,12 @@ import io.stargate.sgv2.jsonapi.api.model.command.impl.CreateCollectionCommand; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.util.JsonUtil; +import java.util.Arrays; +import java.util.Map; import java.util.Objects; +import java.util.Set; +import java.util.TreeSet; +import java.util.stream.Collectors; /** Validated configuration Object for Lexical (BM-25) indexing configuration for Collections. 
*/ public record CollectionLexicalConfig( @@ -20,6 +25,16 @@ public record CollectionLexicalConfig( private static final JsonNode DEFAULT_NAMED_ANALYZER_NODE = JsonNodeFactory.instance.textNode(DEFAULT_NAMED_ANALYZER); + private static final CollectionLexicalConfig DEFAULT_CONFIG = + new CollectionLexicalConfig(true, DEFAULT_NAMED_ANALYZER_NODE); + + private static final CollectionLexicalConfig MISSING_CONFIG = + new CollectionLexicalConfig(false, null); + + // TreeSet just to retain alphabetic order for error message + private static final Set VALID_ANALYZER_FIELDS = + new TreeSet<>(Arrays.asList("charFilters", "filters", "tokenizer")); + /** * Constructs a lexical configuration with the specified enabled state and analyzer definition. * @@ -48,7 +63,7 @@ public CollectionLexicalConfig(boolean enabled, JsonNode analyzerDefinition) { || (analyzerDefinition.isObject() && analyzerDefinition.isEmpty()); if (!isAcceptableWhenDisabled) { throw new IllegalArgumentException( - "Analyzer definition should be omitted, JSON null, or an empty JSON object {} if if lexical is disabled."); + "Analyzer definition should be omitted, JSON null, or an empty JSON object {} if lexical is disabled."); } } this.analyzerDefinition = null; @@ -67,7 +82,7 @@ public static CollectionLexicalConfig validateAndConstruct( CreateCollectionCommand.Options.LexicalConfigDefinition lexicalConfig) { // Case 1: No lexical body provided - use defaults if available, otherwise disable if (lexicalConfig == null) { - return lexicalAvailableForDB ? configForEnabledStandard() : configForDisabled(); + return lexicalAvailableForDB ? configForDefault() : configForDisabled(); } // Case 2: Validate 'enabled' flag is present @@ -77,24 +92,24 @@ public static CollectionLexicalConfig validateAndConstruct( "'enabled' is required property for 'lexical' Object value"); } + // Following cases mean "analyzer" is not defined: + // 1. No JSON value + // 2. JSON value itself is null (`null`) + // 3. 
JSON value is an empty object (`{}`) + JsonNode analyzerDef = lexicalConfig.analyzerDef(); + final boolean analyzerNotDefined = + (analyzerDef == null) + || analyzerDef.isNull() + || (analyzerDef.isObject() && analyzerDef.isEmpty()); + // Case 3: Lexical is disabled - ensure analyzer is absent, JSON null, or empty object {} if (!enabled) { - if (lexicalConfig.analyzerDef() != null) { - // Define the acceptable states when lexical is disabled: - // 1. The JSON value itself is null (`null`) - // 2. The JSON value is an empty object (`{}`) - boolean isAcceptableWhenDisabled = - lexicalConfig.analyzerDef().isNull() - || (lexicalConfig.analyzerDef().isObject() - && lexicalConfig.analyzerDef().isEmpty()); - - if (!isAcceptableWhenDisabled) { - String nodeType = JsonUtil.nodeTypeAsString(lexicalConfig.analyzerDef()); - throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( - "'lexical' is disabled, but 'lexical.analyzer' property was provided with an unexpected type: %s. " - + "When 'lexical' is disabled, 'lexical.analyzer' must either be omitted, JSON null, or an empty JSON object {}.", - nodeType); - } + if (!analyzerNotDefined) { + String nodeType = JsonUtil.nodeTypeAsString(analyzerDef); + throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( + "'lexical' is disabled, but 'lexical.analyzer' property was provided with an unexpected type: %s. 
" + + "When 'lexical' is disabled, 'lexical.analyzer' must either be omitted or be JSON null, or an empty Object {}.", + nodeType); } return configForDisabled(); } @@ -105,38 +120,79 @@ public static CollectionLexicalConfig validateAndConstruct( } // Case 5: Enabled and analyzer provided - validate and use - JsonNode analyzer = lexicalConfig.analyzerDef(); // Case 5a: missing/null/Empty Object - use default analyzer - if (analyzer == null || analyzer.isNull() || (analyzer.isObject() && analyzer.isEmpty())) { - analyzer = mapper.getNodeFactory().textNode(CollectionLexicalConfig.DEFAULT_NAMED_ANALYZER); - } else if (analyzer.isTextual()) { + if (analyzerNotDefined) { + analyzerDef = + mapper.getNodeFactory().textNode(CollectionLexicalConfig.DEFAULT_NAMED_ANALYZER); + } else if (analyzerDef.isTextual()) { // Case 5b: JSON String - use as-is -- Could/should we try to validate analyzer name? ; - } else if (analyzer.isObject()) { - // Case 5c: JSON Object - use as-is -- TODO? validate analyzer wrt required fields? - ; + } else if (analyzerDef.isObject()) { + // Case 5c: JSON Object - use as-is but first do light validation + Set foundNames = + analyzerDef.properties().stream().map(Map.Entry::getKey).collect(Collectors.toSet()); + // First: check for any invalid (misspelled etc) fields + foundNames.removeAll(VALID_ANALYZER_FIELDS); + if (!foundNames.isEmpty()) { + throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( + "Invalid field%s for 'lexical.analyzer'. Valid fields are: %s, found: %s", + (foundNames.size() == 1 ? 
"" : "s"), VALID_ANALYZER_FIELDS, new TreeSet<>(foundNames)); + } + // Second: check basic data types for allowed fields + for (Map.Entry entry : analyzerDef.properties()) { + JsonNode fieldValue = entry.getValue(); + // Nulls ok for all + if (fieldValue.isNull()) { + continue; + } + String expectedType; + boolean valueOk = + switch (entry.getKey()) { + case "tokenizer" -> { + expectedType = "Object"; + yield fieldValue.isObject(); + } + default -> { + expectedType = "Array"; + yield fieldValue.isArray(); + } + }; + if (!valueOk) { + throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( + "'%s' property of 'lexical.analyzer' must be JSON %s, is: %s", + entry.getKey(), expectedType, JsonUtil.nodeTypeAsString(fieldValue)); + } + } } else { // Otherwise, invalid definition throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( "'analyzer' property of 'lexical' must be either JSON Object or String, is: %s", - JsonUtil.nodeTypeAsString(analyzer)); + JsonUtil.nodeTypeAsString(analyzerDef)); } - return new CollectionLexicalConfig(true, analyzer); + return new CollectionLexicalConfig(true, analyzerDef); } /** - * Accessor for an instance to use for "default enabled" cases, using "standard" analyzer - * configuration: typically used for new collections where lexical search is available. + * Accessor for an instance to use for "lexical disabled" Collections (but not for ones pre-dating + * lexical search feature). 
*/ - public static CollectionLexicalConfig configForEnabledStandard() { - return new CollectionLexicalConfig(true, DEFAULT_NAMED_ANALYZER_NODE); + public static CollectionLexicalConfig configForDisabled() { + return new CollectionLexicalConfig(false, null); } /** - * Accessor for an instance to use for "lexical disabled" cases: either for existing collections - * without lexical config, or envi + * Accessor for a singleton instance used to represent case of default lexical configuration for + * newly created Collections that do not specify lexical configuration. */ - public static CollectionLexicalConfig configForDisabled() { - return new CollectionLexicalConfig(false, null); + public static CollectionLexicalConfig configForDefault() { + return DEFAULT_CONFIG; + } + + /** + * Accessor for a singleton instance used to represent case of missing lexical configuration for + * legacy Collections created before lexical search was available. + */ + public static CollectionLexicalConfig configForPreLexical() { + return MISSING_CONFIG; } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionRerankDef.java b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionRerankDef.java index 7b2c0200e6..80999f7158 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionRerankDef.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionRerankDef.java @@ -10,7 +10,8 @@ import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.exception.JsonApiException; import io.stargate.sgv2.jsonapi.exception.SchemaException; -import io.stargate.sgv2.jsonapi.service.provider.ModelSupport; +import io.stargate.sgv2.jsonapi.service.provider.ApiModelSupport; +import io.stargate.sgv2.jsonapi.service.reranking.configuration.RerankingProviderConfigProducer; import io.stargate.sgv2.jsonapi.service.reranking.configuration.RerankingProvidersConfig; import 
java.util.*; import org.slf4j.Logger; @@ -39,9 +40,25 @@ public class CollectionRerankDef { /** * Singleton instance for disabled reranking configuration. It can be used for disabled reranking - * collections, existing pre-reranking collections, and missing collections. + * collections and missing collections. */ - public static final CollectionRerankDef DISABLED = new CollectionRerankDef(false, null); + private static final CollectionRerankDef DISABLED = new CollectionRerankDef(false, null); + + /** + * Singleton instance for disabled reranking configuration. It is to be used for existing + * pre-reranking collections. + */ + private static final CollectionRerankDef MISSING = new CollectionRerankDef(false, null); + + /** + * Singleton instance for default reranking configuration. It is used for newly created + * collections with default reranking settings. + * + *

    NOTE: this is initialized during startup (via call to {@link #initializeDefaultRerankDef} by + * {@link RerankingProviderConfigProducer}) and cannot unfortunately be made final because + * initialization requires access to other configuration loaded during start up. + */ + private static CollectionRerankDef DEFAULT; private static final Logger LOGGER = LoggerFactory.getLogger(CollectionRerankDef.class); @@ -103,16 +120,12 @@ public RerankServiceDef rerankServiceDef() { } /** - * Creates default reranking configuration for new collections. + * Gets the default reranking configuration for new collections. * *

    When a collection is created without explicit reranking settings, this method provides a * default configuration based on the reranking providers' configuration. It looks for the * provider marked as default and its default model. * - *

    If no default provider is configured in the yaml, reranking will be disabled for new - * collections. Similarly, if the default provider doesn't have a default model, reranking will be - * disabled. - * * @param isRerankingEnabledForAPI * @param rerankingProvidersConfig The configuration for all available reranking providers * @return A default-configured CollectionRerankDef @@ -124,30 +137,44 @@ public static CollectionRerankDef configForNewCollections( if (!isRerankingEnabledForAPI) { return DISABLED; } + if (DEFAULT == null) { + // DEFAULT should have been set during application startup. + throw new IllegalStateException("No default reranking definition found"); + } + return DEFAULT; + } + + /** + * Initializes the DEFAULT reranking definition as Singleton during the application startup. See + * {@link RerankingProviderConfigProducer} as caller and how the configuration is validated to + * guarantee a default provider and model. + */ + public static void initializeDefaultRerankDef(RerankingProvidersConfig rerankingProvidersConfig) { // Find the provider marked as default var defaultProviderEntry = rerankingProvidersConfig.providers().entrySet().stream() .filter(entry -> entry.getValue().isDefault()) .findFirst(); - // If no default provider exists, disable reranking + // There must be a default provider, otherwise it's a config bug. + // It is validated in RerankingProviderConfigProducer.class during startup. if (defaultProviderEntry.isEmpty()) { - LOGGER.debug("No default reranking provider found, disabling reranking for new collections"); - return DISABLED; + throw new IllegalStateException("No default reranking provider found"); } // Extract provider information String defaultProviderName = defaultProviderEntry.get().getKey(); var defaultProviderConfig = defaultProviderEntry.get().getValue(); - // Find the model marked as default for this provider + // Find the model marked as default for this provider. 
// The default provider must have a default model that has SUPPORTED status, otherwise it's - // config bug + // config bug, It is validated in RerankingProviderConfigProducer.class during startup. var defaultModel = defaultProviderConfig.models().stream() .filter(RerankingProvidersConfig.RerankingProviderConfig.ModelConfig::isDefault) .filter( modelConfig -> - modelConfig.modelSupport().status() == ModelSupport.SupportStatus.SUPPORTED) + modelConfig.apiModelSupport().status() + == ApiModelSupport.SupportStatus.SUPPORTED) .findFirst() .orElseThrow( () -> @@ -179,20 +206,37 @@ public static CollectionRerankDef configForNewCollections( null // No parameters for default configuration ); - return new CollectionRerankDef(true, defaultRerankingService); + LOGGER.info( + "InitializeDefaultRerankDef during application startup, default reranking configuration initialized with provider '%s' and model '%s'" + .formatted(defaultProviderName, defaultModel.name())); + DEFAULT = new CollectionRerankDef(true, defaultRerankingService); + } + + public static CollectionRerankDef configForDisabled() { + return DISABLED; } /** - * Factory method for creating a configuration for existing collections that predate reranking - * support. + * Accessor for getting a configuration for existing collections that predate reranking support. * *

    Used for collections created before reranking functionality was available. These collections * need to have reranking explicitly disabled for backward compatibility. * - * @return A singleton CollectionRerankDef instance with reranking disabled + * @return A singleton CollectionRerankDef instance ({@link #MISSING}) with reranking disabled */ public static CollectionRerankDef configForPreRerankingCollection() { - return DISABLED; + return MISSING; + } + + /** + * Accessor for a singleton instance used to represent case of default reranking configuration for + * newly created Collections that do not specify reranking configuration. + * + * @return A singleton CollectionRerankDef instance ({@link #DEFAULT}) initialized during + * application startup. + */ + public static CollectionRerankDef configForDefault() { + return DEFAULT; } /** @@ -377,15 +421,22 @@ private static String validateModel( } var model = rerankModel.get(); - if (model.modelSupport().status() != ModelSupport.SupportStatus.SUPPORTED) { - throw SchemaException.Code.UNSUPPORTED_PROVIDER_MODEL.get( + if (model.apiModelSupport().status() != ApiModelSupport.SupportStatus.SUPPORTED) { + var errorCode = + model.apiModelSupport().status() == ApiModelSupport.SupportStatus.DEPRECATED + ? 
SchemaException.Code.DEPRECATED_AI_MODEL + : SchemaException.Code.END_OF_LIFE_AI_MODEL; + throw errorCode.get( Map.of( "model", model.name(), "modelStatus", - model.modelSupport().status().name(), + model.apiModelSupport().status().name(), "message", - model.modelSupport().message().orElse("The model is not supported."))); + model + .apiModelSupport() + .message() + .orElse("The model is %s.".formatted(model.apiModelSupport().status().name())))); } return modelName; diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSchemaObject.java b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSchemaObject.java index 78a5d43067..30268e073d 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSchemaObject.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSchemaObject.java @@ -46,7 +46,7 @@ public final class CollectionSchemaObject extends TableBasedSchemaObject { VectorConfig.NOT_ENABLED_CONFIG, null, CollectionLexicalConfig.configForDisabled(), - CollectionRerankDef.DISABLED); + CollectionRerankDef.configForDisabled()); private final IdConfig idConfig; private final VectorConfig vectorConfig; @@ -109,6 +109,22 @@ public CollectionSchemaObject withIdType(CollectionIdType idType) { rerankDef); } + /** + * Method for constructing a new CollectionSchemaObject with overrides for Lexical and Rerank + * settings. 
+ */ + public CollectionSchemaObject withLexicalAndRerankOverrides( + CollectionLexicalConfig lexicalOverride, CollectionRerankDef rerankOverride) { + return new CollectionSchemaObject( + name(), + tableMetadata, + idConfig, + vectorConfig, + indexingConfig, + lexicalOverride, + rerankOverride); + } + @Override public VectorConfig vectorConfig() { return vectorConfig; @@ -272,7 +288,7 @@ private static CollectionSchemaObject createCollectionSettings( if (comment == null || comment.isBlank()) { // If no "comment", must assume Legacy (no Lexical) config - CollectionLexicalConfig lexicalConfig = CollectionLexicalConfig.configForDisabled(); + CollectionLexicalConfig lexicalConfig = CollectionLexicalConfig.configForPreLexical(); // If no "comment", must assume Legacy (no Reranking) config CollectionRerankDef rerankingConfig = CollectionRerankDef.configForPreRerankingCollection(); if (vectorEnabled) { @@ -464,7 +480,7 @@ public int hashCode() { @Override public String toString() { - return "CollectionSettings[" + return "CollectionSchemaObject[" + "name=" + name + ", " diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSettingsReader.java b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSettingsReader.java index 6b89a42980..29a552b82d 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSettingsReader.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSettingsReader.java @@ -12,8 +12,8 @@ *

    CommandSettingsDeserializer works for convert the table comment jsonNode into * collectionSettings */ +@Deprecated // not really used effectively, to be removed public interface CollectionSettingsReader { - // TODO: this interface is not used well, see the V0 implementation CollectionSchemaObject readCollectionSettings( JsonNode jsonNode, diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSettingsV0Reader.java b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSettingsV0Reader.java index b1f58640b6..5dbb613938 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSettingsV0Reader.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSettingsV0Reader.java @@ -2,7 +2,6 @@ import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata; import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants; import io.stargate.sgv2.jsonapi.config.constants.TableCommentConstants; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorColumnDefinition; @@ -12,15 +11,16 @@ import java.util.List; /** - * schema_version 0 is before we introduce schema_version into the C* table comment of data api - * collection at this version, table comment only works for indexing options sample: + * schema_version 0 is before we introduced schema_version into the C* table comment of Data API + * collection at this version, table comment only works for indexing options. Sample: + * + *

      * {"indexing":{"deny":["address"]}}
    + * 
    * *

    Note, all collection created in this schema version 0, should have UUID as idType */ -public class CollectionSettingsV0Reader implements CollectionSettingsReader { - - // TODO: Why have function with the same name as the interface method ? +public class CollectionSettingsV0Reader { public CollectionSchemaObject readCollectionSettings( JsonNode commentConfigNode, String keyspaceName, @@ -55,24 +55,8 @@ public CollectionSchemaObject readCollectionSettings( vectorConfig, indexingConfig, // Legacy config, must assume legacy lexical config (disabled) - CollectionLexicalConfig.configForDisabled(), + CollectionLexicalConfig.configForPreLexical(), // Legacy config, must assume legacy reranking config (disabled) CollectionRerankDef.configForPreRerankingCollection()); } - - /** - * schema v0 is obsolete(supported though for backwards compatibility, hard to implement - * readCollectionSettings method based on interface method signature - */ - @Override - public CollectionSchemaObject readCollectionSettings( - JsonNode jsonNode, - String keyspaceName, - String collectionName, - TableMetadata tableMetadata, - ObjectMapper objectMapper) { - // TODO: this is really confusing, why does this implement the interface and not implement the - // one method ? 
- return null; - } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSettingsV1Reader.java b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSettingsV1Reader.java index bc7d756516..bb135102ea 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSettingsV1Reader.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/collections/CollectionSettingsV1Reader.java @@ -15,8 +15,7 @@ * "parameters":{"projectId":"test project"}}} }, "lexical":{"enabled":true,"analyzer":"standard"}, * "rerank":{"enabled":true,"provider":"nvidia","modelName":"nvidia/llama-3.2-nv-rerankqa-1b-v2"}, } */ -public class CollectionSettingsV1Reader implements CollectionSettingsReader { - @Override +public class CollectionSettingsV1Reader { public CollectionSchemaObject readCollectionSettings( JsonNode collectionNode, String keyspaceName, @@ -52,15 +51,15 @@ public CollectionSchemaObject readCollectionSettings( CollectionLexicalConfig lexicalConfig; JsonNode lexicalNode = collectionOptionsNode.path(TableCommentConstants.COLLECTION_LEXICAL_CONFIG_KEY); - if (lexicalNode == null) { - lexicalConfig = CollectionLexicalConfig.configForDisabled(); + if (lexicalNode.isMissingNode()) { + lexicalConfig = CollectionLexicalConfig.configForPreLexical(); } else { boolean enabled = lexicalNode.path("enabled").asBoolean(false); JsonNode analyzerNode = lexicalNode.get("analyzer"); lexicalConfig = new CollectionLexicalConfig(enabled, analyzerNode); } - CollectionRerankDef rerankingConfig = null; + CollectionRerankDef rerankingConfig; JsonNode rerankingNode = collectionOptionsNode.path(TableCommentConstants.COLLECTION_RERANKING_CONFIG_KEY); if (rerankingNode.isMissingNode()) { diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/tables/ApiColumnDefContainer.java b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/tables/ApiColumnDefContainer.java index ffb124dfa3..d8a506002f 100644 --- 
a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/tables/ApiColumnDefContainer.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/tables/ApiColumnDefContainer.java @@ -15,15 +15,10 @@ import java.util.function.Function; import java.util.function.Predicate; import java.util.stream.Collectors; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** A {@link ApiColumnDefContainer} that maintains the order of the columns as they were added. */ public class ApiColumnDefContainer extends LinkedHashMap implements Recordable { - - private static final Logger LOGGER = LoggerFactory.getLogger(ApiColumnDefContainer.class); - private static final ApiColumnDefContainer IMMUTABLE_EMPTY = new ApiColumnDefContainer(0).toUnmodifiable(); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/tables/ApiIndexType.java b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/tables/ApiIndexType.java index 824b31ab5a..d12d2b1e2c 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/tables/ApiIndexType.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/tables/ApiIndexType.java @@ -1,6 +1,7 @@ package io.stargate.sgv2.jsonapi.service.schema.tables; import com.datastax.oss.driver.api.core.metadata.schema.IndexMetadata; +import io.stargate.sgv2.jsonapi.config.constants.TableDescConstants; import io.stargate.sgv2.jsonapi.exception.checked.UnsupportedCqlIndexException; import java.util.HashMap; import java.util.Map; @@ -75,27 +76,38 @@ static ApiIndexType fromCql( ApiColumnDef apiColumnDef, CQLSAIIndex.IndexTarget indexTarget, IndexMetadata indexMetadata) throws UnsupportedCqlIndexException { - // TODO: this needs to be updated to detect an analyzed index as a text index + final ApiDataType columnType = apiColumnDef.type(); + + // Let's start with Text (aka Lexical, or Analyzed) indexes: only for TEXT or ASCII columns + // and with a text analyzer defined in the index options. 
+ switch (columnType.typeName()) { + case ASCII, TEXT -> { + String analyzerDef = + indexMetadata.getOptions().get(TableDescConstants.TextIndexCQLOptions.OPTION_ANALYZER); + if (analyzerDef != null && !analyzerDef.isBlank()) { + return TEXT; + } + } + } // If there is no function on the indexTarget, and the column is a scalar, then it is a regular // index on an int, text, etc - if (indexTarget.indexFunction() == null && apiColumnDef.type().isPrimitive()) { + if (indexTarget.indexFunction() == null && columnType.isPrimitive()) { return REGULAR; } // if the target column is a vector, it can only be a vector index, we will let building the // index check the options. // NOTE: check this before the container check, as a vector type is a container type - if (indexTarget.indexFunction() == null - && apiColumnDef.type().typeName() == ApiTypeName.VECTOR) { + if (indexTarget.indexFunction() == null && columnType.typeName() == ApiTypeName.VECTOR) { return VECTOR; } // if the target column is a collection, it can only be a collection index, // collection indexes must have a function if (indexTarget.indexFunction() != null - && apiColumnDef.type().isContainer() - && apiColumnDef.type().typeName() != ApiTypeName.VECTOR) { + && columnType.isContainer() + && columnType.typeName() != ApiTypeName.VECTOR) { return REGULAR; } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/tables/ApiTextIndex.java b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/tables/ApiTextIndex.java new file mode 100644 index 0000000000..c82d7d32ca --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/tables/ApiTextIndex.java @@ -0,0 +1,207 @@ +package io.stargate.sgv2.jsonapi.service.schema.tables; + +import static io.stargate.sgv2.jsonapi.exception.ErrorFormatters.errFmt; +import static io.stargate.sgv2.jsonapi.exception.ErrorFormatters.errFmtApiColumnDef; +import static io.stargate.sgv2.jsonapi.exception.ErrorFormatters.errVars; +import static 
io.stargate.sgv2.jsonapi.util.CqlIdentifierUtil.cqlIdentifierToJsonKey; + +import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.metadata.schema.IndexMetadata; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.JsonNodeFactory; +import io.stargate.sgv2.jsonapi.api.model.command.table.IndexDesc; +import io.stargate.sgv2.jsonapi.api.model.command.table.definition.indexes.RegularIndexDefinitionDesc; +import io.stargate.sgv2.jsonapi.api.model.command.table.definition.indexes.TextIndexDefinitionDesc; +import io.stargate.sgv2.jsonapi.config.constants.TableDescConstants; +import io.stargate.sgv2.jsonapi.config.constants.TableDescDefaults; +import io.stargate.sgv2.jsonapi.exception.SchemaException; +import io.stargate.sgv2.jsonapi.exception.checked.UnsupportedCqlIndexException; +import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject; +import io.stargate.sgv2.jsonapi.util.JsonUtil; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +/** An index of type {@link ApiIndexType#TEXT} on text column */ +public class ApiTextIndex extends ApiSupportedIndex { + public static final IndexFactoryFromIndexDesc + FROM_DESC_FACTORY = new UserDescFactory(); + public static final IndexFactoryFromCql FROM_CQL_FACTORY = new CqlTypeFactory(); + + private final JsonNode analyzer; + + private ApiTextIndex( + CqlIdentifier indexName, + CqlIdentifier targetColumn, + Map options, + JsonNode analyzer) { + super(ApiIndexType.TEXT, indexName, targetColumn, options, null); + + this.analyzer = Objects.requireNonNull(analyzer, "analyzer must not be null"); + } + + @Override + public IndexDesc indexDesc() { + + final var definitionOptions = new TextIndexDefinitionDesc.TextIndexDescOptions(analyzer); + final var definition = + new TextIndexDefinitionDesc(cqlIdentifierToJsonKey(targetColumn), 
definitionOptions); + + return new IndexDesc<>() { + @Override + public String name() { + return cqlIdentifierToJsonKey(indexName); + } + + @Override + public String indexType() { + return indexType.apiName(); + } + + @Override + public TextIndexDefinitionDesc definition() { + return definition; + } + }; + } + + /** + * Factory to create a new {@link ApiTextIndex} using {@link TextIndexDefinitionDesc} from the + * user request. + */ + private static class UserDescFactory + extends IndexFactoryFromIndexDesc { + @Override + public ApiTextIndex create( + TableSchemaObject tableSchemaObject, String indexName, TextIndexDefinitionDesc indexDesc) { + + Objects.requireNonNull(tableSchemaObject, "tableSchemaObject must not be null"); + Objects.requireNonNull(indexDesc, "indexDesc must not be null"); + + // for now, we are relying on the validation of the request deserializer that these values are + // specified userNameToIdentifier will throw an exception if the values are not specified + var indexIdentifier = userNameToIdentifier(indexName, "indexName"); + var targetIdentifier = userNameToIdentifier(indexDesc.column(), "targetColumn"); + + var apiColumnDef = checkIndexColumnExists(tableSchemaObject, targetIdentifier); + + // we could check if there is an existing index but that is a race condition, we will need to + // catch it if it fails - the resolver needs to set up a custom error mapper + + // Text indexes can only be on text columns + if (apiColumnDef.type().typeName() != ApiTypeName.TEXT) { + throw SchemaException.Code.UNSUPPORTED_TEXT_INDEX_FOR_DATA_TYPES.get( + errVars( + tableSchemaObject, + map -> { + map.put( + "allColumns", + errFmtApiColumnDef(tableSchemaObject.apiTableDef().allColumns())); + map.put("unsupportedColumns", errFmt(apiColumnDef)); + })); + } + + Map indexOptions = new HashMap<>(); + + // The analyzer is optional, if not specified default settings will be used + // (named analyzer "standard" in CQL).
+ // But further, if it is a StringNode, needs to be as a "raw" String, not a JSON String; + // but if ObjectNode, it must be encoded as JSON object. + + JsonNode analyzerDef = indexDesc.options() == null ? null : indexDesc.options().analyzer(); + if (analyzerDef == null) { + analyzerDef = + JsonNodeFactory.instance.textNode( + TableDescDefaults.CreateTextIndexOptionsDefaults.DEFAULT_NAMED_ANALYZER); + } else { + // validate that the analyzer is either a String or an Object + if (!analyzerDef.isTextual() && !analyzerDef.isObject()) { + final String unsupportedType = JsonUtil.nodeTypeAsString(analyzerDef); + throw SchemaException.Code.UNSUPPORTED_JSON_TYPE_FOR_TEXT_INDEX.get( + errVars( + tableSchemaObject, + map -> { + map.put("unsupportedType", unsupportedType); + })); + } + } + indexOptions.put( + TableDescConstants.TextIndexCQLOptions.OPTION_ANALYZER, + analyzerDef.isTextual() ? analyzerDef.textValue() : analyzerDef.toString()); + + return new ApiTextIndex(indexIdentifier, targetIdentifier, indexOptions, analyzerDef); + } + } + + /** + * Factory to create a new {@link ApiTextIndex} using the {@link IndexMetadata} from the driver. 
+ */ + private static class CqlTypeFactory extends IndexFactoryFromCql { + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + @Override + protected ApiIndexDef create( + ApiColumnDef apiColumnDef, CQLSAIIndex.IndexTarget indexTarget, IndexMetadata indexMetadata) + throws UnsupportedCqlIndexException { + + // this is a sanity check, the base will have worked this, but we should verify it here + var apiIndexType = ApiIndexType.fromCql(apiColumnDef, indexTarget, indexMetadata); + if (apiIndexType != ApiIndexType.TEXT) { + throw new IllegalStateException( + "ApiTextIndex factory only supports %s indexes, apiIndexType: %s" + .formatted(ApiIndexType.TEXT, apiIndexType)); + } + + // also, we must not have an index function + if (indexTarget.indexFunction() != null) { + throw new IllegalStateException( + "ApiTextIndex factory must not have index function, indexMetadata.name: " + + indexMetadata.getName()); + } + + String analyzerDefFromCql = + indexMetadata.getOptions().get(TableDescConstants.TextIndexCQLOptions.OPTION_ANALYZER); + + // Heuristics: 3 choices: + // 1. JSON Object (as a String) -- JSON decode + // 2. String (as a String) -- use as-is + // 3. 
null or empty -- failure case (should not happen, but handle explicitly) + + JsonNode analyzerDef; + + if (analyzerDefFromCql == null || analyzerDefFromCql.isBlank()) { + // should never happen, but just in case + throw new IllegalStateException( + "ApiTextIndex definition broken (indexMetadata.name: " + + indexMetadata.getName() + + "), missing '" + + TableDescConstants.TextIndexCQLOptions.OPTION_ANALYZER + + "' JSON; options = " + + indexMetadata.getOptions()); + } else if (analyzerDefFromCql.trim().startsWith("{")) { + try { + analyzerDef = OBJECT_MAPPER.readTree(analyzerDefFromCql); + } catch (IOException e) { + throw new IllegalStateException( + "ApiTextIndex definition broken (indexMetadata.name: " + + indexMetadata.getName() + + "), invalid JSON -- " + + analyzerDefFromCql + + " -- error: " + + e.getMessage()); + } + } else { + // just a string, use as is + analyzerDef = OBJECT_MAPPER.getNodeFactory().textNode(analyzerDefFromCql); + } + + return new ApiTextIndex( + indexMetadata.getName(), + indexTarget.targetColumn(), + indexMetadata.getOptions(), + analyzerDef); + } + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/tables/ApiVectorIndex.java b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/tables/ApiVectorIndex.java index 0fe72e55e6..e07713cb85 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/tables/ApiVectorIndex.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/tables/ApiVectorIndex.java @@ -21,8 +21,6 @@ /** An index of type {@link ApiIndexType#VECTOR} on vector column */ public class ApiVectorIndex extends ApiSupportedIndex { - private static final Logger LOGGER = LoggerFactory.getLogger(ApiVectorIndex.class); - public static final IndexFactoryFromIndexDesc FROM_DESC_FACTORY = new UserDescFactory(); public static final IndexFactoryFromCql FROM_CQL_FACTORY = new CqlTypeFactory(); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/tables/IndexFactoryFromCql.java 
b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/tables/IndexFactoryFromCql.java index 50ba980b3a..460f8d512f 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/tables/IndexFactoryFromCql.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/tables/IndexFactoryFromCql.java @@ -19,7 +19,7 @@ public abstract class IndexFactoryFromCql extends FactoryFromCql { private static final Logger LOGGER = LoggerFactory.getLogger(IndexFactoryFromCql.class); /** - * Analyses the IndexMetadat to call the correct factory to create an {@link ApiIndexDef} + * Analyses the IndexMetadata to call the correct factory to create an {@link ApiIndexDef} * * @param allColumns Container of all columns on the table the index is from. * @param indexMetadata The index metadata from the driver @@ -52,6 +52,7 @@ public static ApiIndexDef create(ApiColumnDefContainer allColumns, IndexMetadata return switch (apiIndexType) { case REGULAR -> ApiRegularIndex.FROM_CQL_FACTORY.create(apiColumnDef, indexTarget, indexMetadata); + case TEXT -> ApiTextIndex.FROM_CQL_FACTORY.create(apiColumnDef, indexTarget, indexMetadata); case VECTOR -> ApiVectorIndex.FROM_CQL_FACTORY.create(apiColumnDef, indexTarget, indexMetadata); default -> diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/tables/IndexFactoryFromIndexDesc.java b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/tables/IndexFactoryFromIndexDesc.java index df1e5e2574..a5be2baeca 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/schema/tables/IndexFactoryFromIndexDesc.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/schema/tables/IndexFactoryFromIndexDesc.java @@ -7,8 +7,6 @@ import io.stargate.sgv2.jsonapi.exception.SchemaException; import io.stargate.sgv2.jsonapi.exception.checked.UnsupportedUserIndexException; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * Base for Factories that can 
create a {@link ApiIndexDef} subclass from the user description in a @@ -20,13 +18,11 @@ public abstract class IndexFactoryFromIndexDesc< ApiT extends ApiIndexDef, DescT extends IndexDefinitionDesc> extends FactoryFromDesc { - private static final Logger LOGGER = LoggerFactory.getLogger(IndexFactoryFromIndexDesc.class); - /** * Called to create an index from the user description in a command * - * @param apiTableDef - * @param name + * @param tableSchemaObject + * @param indexName * @param indexDesc * @return * @throws UnsupportedUserIndexException The factory should throw this if the user description diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/shredding/collections/DocumentShredder.java b/src/main/java/io/stargate/sgv2/jsonapi/service/shredding/collections/DocumentShredder.java index 6bd75db973..777e88bb05 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/shredding/collections/DocumentShredder.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/shredding/collections/DocumentShredder.java @@ -11,10 +11,10 @@ import com.fasterxml.uuid.Generators; import com.fasterxml.uuid.NoArgGenerator; import io.stargate.sgv2.jsonapi.api.model.command.CommandContext; -import io.stargate.sgv2.jsonapi.api.v1.metrics.JsonProcessingMetricsReporter; import io.stargate.sgv2.jsonapi.config.DocumentLimitsConfig; import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; +import io.stargate.sgv2.jsonapi.metrics.JsonProcessingMetricsReporter; import io.stargate.sgv2.jsonapi.service.projection.IndexingProjector; import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionIdType; import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject; diff --git a/src/main/java/io/stargate/sgv2/jsonapi/util/recordable/Recordable.java b/src/main/java/io/stargate/sgv2/jsonapi/util/recordable/Recordable.java index 5ee86787cb..c51ec429a8 100644 --- 
a/src/main/java/io/stargate/sgv2/jsonapi/util/recordable/Recordable.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/util/recordable/Recordable.java @@ -22,7 +22,7 @@ public interface Recordable { /** - * Called for the implementer to record it's data to the {@link DataRecorder}, values should be + * Called for the implementer to record its data to the {@link DataRecorder}, values should be * appended using {@link DataRecorder#append(String, Object)}. Values that implement {@link * Recordable} will be added as a sub object. * @@ -54,7 +54,7 @@ default DataRecorder recordToSubRecorder(DataRecorder dataRecorder) { * .maybeTrace("Parsed JSON Documents", Recordable.copyOf(parsedDocuments)); * * - * @param recordables The collection of Recordable objects to wrap. + * @param values The collection of Recordable objects to wrap. * @return A Recordable object that will record the collection as an array of objects. */ static Recordable copyOf(Collection values) { diff --git a/src/main/proto/embedding_gateway.proto b/src/main/proto/embedding_gateway.proto index a114a5f319..4381c94933 100644 --- a/src/main/proto/embedding_gateway.proto +++ b/src/main/proto/embedding_gateway.proto @@ -91,6 +91,7 @@ message GetSupportedProvidersResponse { repeated ParameterConfig parameters = 5; RequestProperties properties = 6; repeated ModelConfig models = 7; + bool authTokenPassThroughForNoneAuth = 8; message AuthenticationConfig{ @@ -147,13 +148,14 @@ message GetSupportedProvidersResponse { int32 max_batch_size = 9; } - // ModelConfig message represents configuration for a specific model. + // ModelConfig message represents configuration for a specific embedding model. 
message ModelConfig { string name = 1; optional int32 vector_dimension = 2; repeated ParameterConfig parameters = 3; map properties = 4; optional string service_url_override = 5; + ApiModelSupport apiModelSupport = 6; } } @@ -262,16 +264,11 @@ message GetSupportedRerankingProvidersResponse { // ModelConfig message represents configuration for a specific reranking model. message ModelConfig { string name = 1; - ModelSupport modelSupport = 2; + ApiModelSupport apiModelSupport = 2; bool isDefault = 3; string url = 4; RequestProperties properties = 5; - message ModelSupport{ - string status = 1; - optional string message = 2; - } - message RequestProperties { int32 at_most_retries = 1; int32 initial_back_off_millis = 2; @@ -290,6 +287,12 @@ message GetSupportedRerankingProvidersResponse { } } +// The ApiModelSupport message represents the support status of the embedding/reranking model. +message ApiModelSupport{ + string status = 1; + optional string message = 2; +} + // The embedding gateway gPRC API to reranking service RerankingService { rpc Rerank (ProviderRerankingRequest) returns (RerankingResponse) {} diff --git a/src/main/resources/application.conf b/src/main/resources/application.conf index 6826cc80a0..d34493007f 100644 --- a/src/main/resources/application.conf +++ b/src/main/resources/application.conf @@ -19,7 +19,9 @@ datastax-java-driver { } advanced.metrics { id-generator{ - class = io.stargate.sgv2.jsonapi.service.cqldriver.CustomTaggingMetricIdGenerator + # This will add 'session' tag to session and node level metrics with the session name + # that we set to the tenant name. 
(Also adds a node tag for node level metrics) + class = TaggingMetricIdGenerator } factory.class = MicrometerMetricsFactory session { diff --git a/src/main/resources/embedding-providers-config.yaml b/src/main/resources/embedding-providers-config.yaml index 22da96cd29..ebf7a02b1e 100644 --- a/src/main/resources/embedding-providers-config.yaml +++ b/src/main/resources/embedding-providers-config.yaml @@ -157,7 +157,7 @@ stargate: # see https://huggingface.co/blog/getting-started-with-embeddings display-name: Hugging Face - Serverless enabled: true - url: https://api-inference.huggingface.co/pipeline/feature-extraction/ + url: https://router.huggingface.co/hf-inference/models/{modelId}/pipeline/feature-extraction supported-authentications: NONE: enabled: false @@ -301,7 +301,8 @@ stargate: # see https://docs.api.nvidia.com/nim/reference/nvidia-embedding-2b-infer display-name: Nvidia enabled: true - url: https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings + url: https://us-west-2.api-dev.ai.datastax.com/nvidia/v1/embeddings/nv-embed-qa-v4 + auth-token-pass-through-for-none-auth: true supported-authentications: NONE: enabled: true diff --git a/src/main/resources/errors.yaml b/src/main/resources/errors.yaml index 1df65e06bf..ec403901d4 100644 --- a/src/main/resources/errors.yaml +++ b/src/main/resources/errors.yaml @@ -859,14 +859,24 @@ request-errors: Resend the command without dropping the indexed columns. - scope: SCHEMA - code: UNSUPPORTED_PROVIDER_MODEL - title: The model is not supported by Data API + code: DEPRECATED_AI_MODEL + title: Cannot use deprecated model. body: |- - The model ${model} is at ${modelStatus} status. + The command attempted to create or alter a Collection or Table to use an AI Model that has been marked as deprecated. Deprecated models are only supported by the API for existing use and will later be removed. + + The model is: ${model}. It is at ${modelStatus} status.
${message} - DEPRECATED model cannot be used for creation of new Collections or Tables, but can be used for existing Collections or Tables until model status is changed to END_OF_LIFE. - END_OF_LIFE model cannot be used for creation of new Collections or Tables, and can not be used for existing Collections or Tables. + Resend the command using supported model. + + - scope: SCHEMA + code: END_OF_LIFE_AI_MODEL + title: Cannot use end of life model. + body: |- + The command attempted to use an AI Model that has been marked as end of life. End of life models cannot be used in any way. Collections or Tables that use the model must be recreated as data such as embeddings is not transferrable. + + The model is: ${model}. It is at ${modelStatus} status. + ${message} Resend the command using supported model. @@ -877,7 +887,7 @@ request-errors: The command attempted to vectorize columns that are not in the table schema. The table ${keyspace}.${table} defines the columns: ${allColumns}. - The command attempted to drop the unknown columns: ${unknownColumns}. + The command attempted to vectorize the unknown columns: ${unknownColumns}. Resend the command using only columns defined in the table schema. @@ -1082,6 +1092,17 @@ request-errors: Resend the command using a supported index type. + - scope: SCHEMA + code: UNSUPPORTED_JSON_TYPE_FOR_TEXT_INDEX + title: JSON value type is not supported for creating text index + body: |- + The command attempted to create a text index using an unsupported JSON value type. + + The supported JSON value types are: String, Object. + The command used the unsupported JSON value type: ${unsupportedType}. + + Resend the command using a supported JSON value type. + - scope: SCHEMA code: UNSUPPORTED_TEXT_ANALYSIS_FOR_DATA_TYPES title: Analysed text index not supported by data types @@ -1099,11 +1120,25 @@ request-errors: Resend the command using columns that support text analysis. 
+ - scope: SCHEMA + code: UNSUPPORTED_TEXT_INDEX_FOR_DATA_TYPES + title: Text index not supported by data types + body: |- + The command attempted to create a text index on a column that is not a `text` or `ascii` type. + + Text indexes can only be created on columns of type `text` or `ascii`. + + The table ${keyspace}.${table} defines the columns: ${allColumns}. + The command attempted to text index the unsupported columns: ${unsupportedColumns}. + + Resend the command using columns of `text` type. + + - scope: SCHEMA code: UNSUPPORTED_VECTOR_INDEX_FOR_DATA_TYPES title: Vector index not supported by data types body: |- - The command attempted to create an vector index on a column that is not a `vector` type. + The command attempted to create a vector index on a column that is not a `vector` type. Vector indexes can only be created on columns of type `vector`, regular indexes can only be created on primitive data types such as `text` and `int` using the createIndex command. diff --git a/src/main/resources/test-embedding-providers-config.yaml b/src/main/resources/test-embedding-providers-config.yaml new file mode 100644 index 0000000000..6b32693bde --- /dev/null +++ b/src/main/resources/test-embedding-providers-config.yaml @@ -0,0 +1,479 @@ +# custom properties for enabling vectorize method +stargate: + jsonapi: + embedding: + providers: + openai: + #see https://platform.openai.com/docs/api-reference/embeddings/create + display-name: OpenAI + enabled: true + url: https://api.openai.com/v1/ + supported-authentications: + NONE: + enabled: false + HEADER: + enabled: true + tokens: + - accepted: x-embedding-api-key + forwarded: Authorization + SHARED_SECRET: + enabled: false + tokens: + - accepted: providerKey + forwarded: Authorization + parameters: + - name: organizationId + type: string + required: false + help: "Organization ID will be passed as an OpenAI organization" + display-name: "Organization ID" + hint: "Add an (optional) organization ID" + - name: projectId 
+ type: string + required: false + help: "Project ID will be passed as an OpenAI project header" + display-name: "Project ID" + hint: "Add an (optional) project ID" + properties: + max-batch-size: 2048 + models: + - name: text-embedding-3-small + parameters: + - name: vectorDimension + type: number + required: true + default-value: 1536 + validation: + numeric-range: [2, 1536] + help: "Vector dimension to use in the database and when calling OpenAI." + - name: text-embedding-3-large + parameters: + - name: vectorDimension + type: number + required: true + default-value: 3072 + validation: + numeric-range: [256, 3072] + help: "Vector dimension to use in the database and when calling OpenAI." + - name: text-embedding-ada-002 + vector-dimension: 1536 + azureOpenAI: + # see https://learn.microsoft.com/en-us/azure/ai-services/openai/reference + # see https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models + display-name: Azure OpenAI + enabled: true + url: https://{resourceName}.openai.azure.com/openai/deployments/{deploymentId}/embeddings?api-version=2024-02-01 + supported-authentications: + NONE: + enabled: false + HEADER: + enabled: true + tokens: + - accepted: x-embedding-api-key + forwarded: api-key + SHARED_SECRET: + enabled: false + tokens: + - accepted: providerKey + forwarded: api-key + parameters: + - name: "resourceName" + type: string + required: true + help: "Azure OpenAI Service name" + display-name: "Resource name" + - name: "deploymentId" + type: string + required: true + help: "Deployment name" + display-name: "Deployment ID" + properties: + max-input-length: 16 + max-batch-size: 2048 + models: + - name: text-embedding-3-small + parameters: + - name: vectorDimension + type: number + required: true + default-value: 1536 + validation: + numeric-range: [2, 1536] + help: "Vector dimension to use in the database and when calling Azure OpenAI." 
+ - name: text-embedding-3-large + parameters: + - name: vectorDimension + type: number + required: true + # https://github.com/stargate/data-api/issues/1241: Docs claim 3072 is max, + # but using values above 1536 does not seem to work. So at least default + # to what seems like a legal value (but leave max higher in case issue is fixed). + default-value: 1536 + validation: + numeric-range: [256, 3072] + help: "Vector dimension to use in the database and when calling Azure OpenAI." + - name: text-embedding-ada-002 + vector-dimension: 1536 + bedrock: + display-name: Amazon Bedrock + enabled: true + supported-authentications: + NONE: + enabled: false + HEADER: + enabled: true + tokens: + - accepted: x-embedding-access-id + forwarded: aws_access_key_id + - accepted: x-embedding-secret-id + forwarded: aws_secret_access_key + SHARED_SECRET: + enabled: false + tokens: + - accepted: accessId + forwarded: aws_access_key_id + - accepted: secretKey + forwarded: aws_secret_access_key + parameters: + - name: "region" + type: string + required: true + help: "AWS region where the model is hosted." + properties: + max-batch-size: 1 + models: + - name: amazon.titan-embed-text-v1 + vector-dimension: 1536 + - name: amazon.titan-embed-text-v2:0 + parameters: + - name: vectorDimension + type: number + required: false + default-value: 1024 + validation: + options: [256, 512, 1024] + help: "Vector dimension to use in the database and when calling Amazon Bedrock Titan V2 embedding model." 
+ huggingface: + # see https://huggingface.co/blog/getting-started-with-embeddings + display-name: Hugging Face - Serverless + enabled: true + url: https://api-inference.huggingface.co/pipeline/feature-extraction/ + supported-authentications: + NONE: + enabled: false + HEADER: + enabled: true + tokens: + - accepted: x-embedding-api-key + forwarded: Authorization + SHARED_SECRET: + enabled: false + tokens: + - accepted: providerKey + forwarded: Authorization + properties: + max-batch-size: 32 + models: + - name: sentence-transformers/all-MiniLM-L6-v2 + vector-dimension: 384 + - name: intfloat/multilingual-e5-large + vector-dimension: 1024 + - name: intfloat/multilingual-e5-large-instruct + vector-dimension: 1024 + - name: BAAI/bge-small-en-v1.5 + vector-dimension: 384 + - name: BAAI/bge-base-en-v1.5 + vector-dimension: 768 + - name: BAAI/bge-large-en-v1.5 + vector-dimension: 1024 + huggingfaceDedicated: + # see https://huggingface.co/docs/inference-endpoints/en/supported_tasks#sentence-embeddings + display-name: Hugging Face - Dedicated + enabled: true + url: https://{endpointName}.{regionName}.{cloudName}.endpoints.huggingface.cloud/embeddings + supported-authentications: + NONE: + enabled: false + HEADER: + enabled: true + tokens: + - accepted: x-embedding-api-key + forwarded: Authorization + SHARED_SECRET: + enabled: false + tokens: + - accepted: providerKey + forwarded: Authorization + properties: + max-batch-size: 32 + models: + - name: endpoint-defined-model + parameters: + - name: vectorDimension + type: number + required: true + validation: + numeric-range: [2, 3072] + help: "Vector dimension to use in the database, should be the same as the model used by the endpoint." 
+ parameters: + - name: "endpointName" + type: string + required: true + help: "Add the first part of the dedicated endpoint URL" + display-name: "Endpoint name" + hint: "Add endpoint name" + - name: "regionName" + type: string + required: true + help: "Add the second part of the dedicated endpoint URL" + display-name: "Region name" + hint: "Add region name" + - name: "cloudName" + type: string + required: true + help: "Add the third part of the dedicated endpoint URL" + display-name: "Cloud provider where the dedicated endpoint is deployed" + hint: "Add cloud name" + # OUT OF SCOPE FOR INITIAL PREVIEW + vertexai: + # see https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-text-embeddings#get_text_embeddings_for_a_snippet_of_text + display-name: Google Vertex AI + enabled: false + url: "https://us-central1-aiplatform.googleapis.com/v1/projects/{projectId}/locations/us-central1/publishers/google/models" + supported-authentications: + NONE: + enabled: false + HEADER: + enabled: true + tokens: + - accepted: x-embedding-api-key + forwarded: Authorization + SHARED_SECRET: + enabled: false + tokens: + - accepted: providerKey + forwarded: Authorization + parameters: + - name: projectId + type: string + required: true + help: "The Google Cloud Project ID to use when calling" + properties: + task-type-store: RETRIEVAL_DOCUMENT # see https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-text-embeddings#api_changes_to_models_released_on_or_after_august_2023 + task-type-read: QUESTION_ANSWERING + max-input-length: 5 + max-batch-size: 32 + models: + - name: textembedding-gecko@003 + vector-dimension: 768 + parameters: + - name: "autoTruncate" + type: boolean + required: false + default-value: true + help: "If set to false, text that exceeds the token limit causes the request to fail. The default value is true." 
+ # OUT OF SCOPE FOR INITIAL PREVIEW + cohere: + # see https://docs.cohere.com/reference/embed + display-name: Cohere + enabled: false + url: https://api.cohere.ai/v1/ + supported-authentications: + NONE: + enabled: false + HEADER: + enabled: true + tokens: + - accepted: x-embedding-api-key + forwarded: Authorization + SHARED_SECRET: + enabled: false + tokens: + - accepted: providerKey + forwarded: Authorization + properties: + max-batch-size: 32 + models: + - name: embed-english-v3.0 + vector-dimension: 1024 + - name: embed-english-v2.0 + vector-dimension: 4096 + nvidia: + # see https://docs.api.nvidia.com/nim/reference/nvidia-embedding-2b-infer + display-name: Nvidia + enabled: true + url: https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings + supported-authentications: + NONE: + enabled: true + properties: + max-batch-size: 8 + models: + - name: NV-Embed-QA + vector-dimension: 1024 + properties: + max-tokens: 512 + - name: a-deprecated-nvidia-embedding-model + vector-dimension: 1024 + api-model-support: + status: DEPRECATED + message: This model has been deprecated, it will be removed in a future release. It is not supported for new Collections or Tables. + - name: a-EOL-nvidia-embedding-model + vector-dimension: 1024 + api-model-support: + status: END_OF_LIFE + message: This model is at END_OF_LIFE status, it is not supported. 
+ + jinaAI: + #see https://jina.ai/embeddings/#apiform + display-name: Jina AI + enabled: true + url: https://api.jina.ai/v1/embeddings + supported-authentications: + NONE: + enabled: false + HEADER: + enabled: true + tokens: + - accepted: x-embedding-api-key + forwarded: Authorization + SHARED_SECRET: + enabled: false + tokens: + - accepted: providerKey + forwarded: Authorization + properties: + initial-back-off-millis: 1000 + max-back-off-millis: 1000 + max-batch-size: 32 + models: + - name: jina-embeddings-v2-base-en + vector-dimension: 768 + - name: jina-embeddings-v2-base-de + vector-dimension: 768 + - name: jina-embeddings-v2-base-es + vector-dimension: 768 + - name: jina-embeddings-v2-base-code + vector-dimension: 768 + - name: jina-embeddings-v2-base-zh + vector-dimension: 768 + - name: jina-embeddings-v3 + # https://jina.ai/news/jina-embeddings-v3-a-frontier-multilingual-embedding-model/ + parameters: + - name: vectorDimension + type: number + required: true + default-value: 1024 + validation: + numeric-range: [1, 1024] + help: "Vector dimension to use in the database and when calling Jina AI." + - name: task + type: string + required: false + default-value: text-matching + help: "Select the downstream task for which the embeddings will be used. The model will return the optimized embeddings for that task. Available options are: retrieval.passage, retrieval.query, separation, classification, text-matching. For more information, please refer to the Jina AI documentation: https://jina.ai/news/jina-embeddings-v3-a-frontier-multilingual-embedding-model/." + - name: late_chunking + type: boolean + required: false + default-value: false + help: "Apply the late chunking technique to leverage the model's long-context capabilities for generating contextual chunk embeddings. For more information, please refer to the Jina AI documentation: https://jina.ai/news/jina-embeddings-v3-a-frontier-multilingual-embedding-model/." 
+ voyageAI: + # see https://docs.voyageai.com/reference/embeddings-api + # see https://docs.voyageai.com/docs/embeddings + display-name: Voyage AI + enabled: true + url: https://api.voyageai.com/v1/embeddings + supported-authentications: + NONE: + enabled: false + HEADER: + enabled: true + tokens: + - accepted: x-embedding-api-key + forwarded: Authorization + SHARED_SECRET: + enabled: false + tokens: + - accepted: providerKey + forwarded: Authorization + parameters: + - name: "autoTruncate" + type: BOOLEAN + required: false + default-value: true + help: "Whether to truncate the input texts to fit within the context length. Defaults to true." + properties: + max-input-length: 128 + task-type-store: document + task-type-read: query + max-batch-size: 32 + models: + - name: voyage-large-2-instruct + vector-dimension: 1024 + - name: voyage-law-2 + vector-dimension: 1024 + - name: voyage-code-2 + vector-dimension: 1536 + - name: voyage-large-2 + vector-dimension: 1536 + - name: voyage-2 + vector-dimension: 1024 + - name: voyage-finance-2 + vector-dimension: 1024 + - name: voyage-multilingual-2 + vector-dimension: 1024 + mistral: + # see https://docs.mistral.ai/api/#operation/createEmbedding + display-name: Mistral AI + enabled: true + url: https://api.mistral.ai/v1/embeddings + supported-authentications: + NONE: + enabled: false + HEADER: + enabled: true + tokens: + - accepted: x-embedding-api-key + forwarded: Authorization + SHARED_SECRET: + enabled: false + tokens: + - accepted: providerKey + forwarded: Authorization + parameters: + properties: + max-batch-size: 32 + models: + - name: mistral-embed + vector-dimension: 1024 + + # NOTE: UpstageAI has one model for storing and a diff one for reading: this is different + # from everyone else. For now handling this requires explicit handling by actual + # embedding client implementation: model name here is a prefix for the actual model name. + # In addition, implementation only supports 1-entry vectorization. 
+ upstageAI: + # see https://developers.upstage.ai/docs/apis/embeddings + display-name: Upstage + enabled: true + url: https://api.upstage.ai/v1/solar/embeddings + supported-authentications: + NONE: + enabled: false + HEADER: + enabled: true + tokens: + - accepted: x-embedding-api-key + forwarded: Authorization + SHARED_SECRET: + enabled: false + tokens: + - accepted: providerKey + forwarded: Authorization + parameters: + properties: + max-batch-size: 1 + models: + # NOTE: this is where weirdness exists; model name is prefix on which + # either "-query" or "-passage" is appended to get the actual model name + - name: solar-embedding-1-large + vector-dimension: 4096 diff --git a/src/main/resources/test-reranking-providers-config.yaml b/src/main/resources/test-reranking-providers-config.yaml index 2cfd0fb711..7f1350470a 100644 --- a/src/main/resources/test-reranking-providers-config.yaml +++ b/src/main/resources/test-reranking-providers-config.yaml @@ -16,14 +16,14 @@ stargate: properties: max-batch-size: 10 - name: nvidia/a-random-deprecated-model - model-support: + api-model-support: status: DEPRECATED message: This model has been deprecated, it will be removed in a future release. It is not supported for new Collections or Tables. url: https://us-west-2.api-dev.ai.datastax.com/nvidia/v1/ranking properties: max-batch-size: 10 - name: nvidia/a-random-EOL-model - model-support: + api-model-support: status: END_OF_LIFE message: This model is at END_OF_LIFE status, it is not supported. 
url: https://us-west-2.api-dev.ai.datastax.com/nvidia/v1/ranking diff --git a/src/test/java/io/stargate/sgv2/jsonapi/TestConstants.java b/src/test/java/io/stargate/sgv2/jsonapi/TestConstants.java index 497dbc5ed0..d0536c24e9 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/TestConstants.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/TestConstants.java @@ -1,12 +1,15 @@ package io.stargate.sgv2.jsonapi; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import io.micrometer.core.instrument.MeterRegistry; import io.stargate.sgv2.jsonapi.api.model.command.CommandConfig; import io.stargate.sgv2.jsonapi.api.model.command.CommandContext; +import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentialsSupplier; import io.stargate.sgv2.jsonapi.api.request.RequestContext; -import io.stargate.sgv2.jsonapi.api.v1.metrics.JsonProcessingMetricsReporter; import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants; +import io.stargate.sgv2.jsonapi.metrics.JsonProcessingMetricsReporter; import io.stargate.sgv2.jsonapi.service.cqldriver.CQLSessionCache; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.*; import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProvider; @@ -57,7 +60,7 @@ public TestConstants() { IdConfig.defaultIdConfig(), VectorConfig.NOT_ENABLED_CONFIG, null, - CollectionLexicalConfig.configForEnabledStandard(), + CollectionLexicalConfig.configForDefault(), // Use default reranking config - hardcode the value to avoid reading config new CollectionRerankDef( true, @@ -107,6 +110,11 @@ public CommandContext collectionContext( JsonProcessingMetricsReporter metricsReporter, EmbeddingProvider embeddingProvider) { + var requestContext = mock(RequestContext.class); + when(requestContext.getEmbeddingCredentialsSupplier()) + .thenReturn(mock(EmbeddingCredentialsSupplier.class)); + when(requestContext.getTenantId()).thenReturn(Optional.of("test-tenant")); + return CommandContext.builderSupplier() 
.withJsonProcessingMetricsReporter( metricsReporter == null ? mock(JsonProcessingMetricsReporter.class) : metricsReporter) @@ -114,10 +122,11 @@ public CommandContext collectionContext( .withCommandConfig(new CommandConfig()) .withEmbeddingProviderFactory(mock(EmbeddingProviderFactory.class)) .withRerankingProviderFactory(mock(RerankingProviderFactory.class)) + .withMeterRegistry(mock(MeterRegistry.class)) .getBuilder(schema) .withEmbeddingProvider(embeddingProvider) .withCommandName(commandName) - .withRequestContext(new RequestContext(Optional.of("test-tenant"))) + .withRequestContext(requestContext) .build(); } @@ -138,6 +147,7 @@ public CommandContext keyspaceContext( .withCommandConfig(new CommandConfig()) .withEmbeddingProviderFactory(mock(EmbeddingProviderFactory.class)) .withRerankingProviderFactory(mock(RerankingProviderFactory.class)) + .withMeterRegistry(mock(MeterRegistry.class)) .getBuilder(schema) .withCommandName(commandName) .withRequestContext(new RequestContext(Optional.of("test-tenant"))) @@ -151,6 +161,7 @@ public CommandContext databaseContext() { .withCommandConfig(new CommandConfig()) .withEmbeddingProviderFactory(mock(EmbeddingProviderFactory.class)) .withRerankingProviderFactory(mock(RerankingProviderFactory.class)) + .withMeterRegistry(mock(MeterRegistry.class)) .getBuilder(DATABASE_SCHEMA_OBJECT) .withCommandName(TEST_COMMAND_NAME) .withRequestContext(new RequestContext(Optional.of("test-tenant"))) diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/configuration/ObjectMapperConfigurationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/configuration/ObjectMapperConfigurationTest.java index dfdcc0b93d..d125042e99 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/configuration/ObjectMapperConfigurationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/configuration/ObjectMapperConfigurationTest.java @@ -15,7 +15,7 @@ import io.stargate.sgv2.jsonapi.api.model.command.GeneralCommand; import 
io.stargate.sgv2.jsonapi.api.model.command.KeyspaceCommand; import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterClause; -import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterSpec; +import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.FilterDefinition; import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.JsonLiteral; import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.JsonType; import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.ValueComparisonOperation; @@ -58,7 +58,8 @@ class ObjectMapperConfigurationTest { @Inject ObjectMapper objectMapper; @Inject DocumentLimitsConfig documentLimitsConfig; - private TestConstants testConstants = new TestConstants(); + + private final TestConstants testConstants = new TestConstants(); @Nested class UnmatchedOperationCommandHandlerTest { @@ -176,7 +177,7 @@ public void happyPath() throws Exception { .isInstanceOfSatisfying( FindOneCommand.class, findOne -> { - SortClause sortClause = findOne.sortClause(); + SortClause sortClause = findOne.sortClause(testConstants.collectionContext()); assertThat(sortClause).isNotNull(); assertThat(sortClause.sortExpressions()) .contains( @@ -220,7 +221,11 @@ public void sortClauseOptional() throws Exception { assertThat(result) .isInstanceOfSatisfying( FindOneCommand.class, - findOne -> Assertions.assertThat(findOne.sortClause()).isNull()); + findOne -> { + SortClause sc = findOne.sortClause(testConstants.collectionContext()); + Assertions.assertThat(sc).isNotNull(); + Assertions.assertThat(sc.isEmpty()).isTrue(); + }); } @Test @@ -238,7 +243,7 @@ public void filterClauseOptional() throws Exception { assertThat(result) .isInstanceOfSatisfying( FindOneCommand.class, - findOne -> Assertions.assertThat(findOne.filterSpec()).isNull()); + findOne -> Assertions.assertThat(findOne.filterDefinition()).isNull()); } // Only "empty" Options allowed, nothing else @@ -934,7 +939,7 @@ public void happyPath() throws Exception { 
.isInstanceOfSatisfying( FindOneAndUpdateCommand.class, findOneAndUpdateCommand -> { - FilterSpec filterSpec = findOneAndUpdateCommand.filterSpec(); + FilterDefinition filterSpec = findOneAndUpdateCommand.filterDefinition(); assertThat(filterSpec).isNotNull(); final UpdateClause updateClause = findOneAndUpdateCommand.updateClause(); assertThat(updateClause).isNotNull(); @@ -963,7 +968,7 @@ public void findOneAndUpdateWithOptions() throws Exception { .isInstanceOfSatisfying( FindOneAndUpdateCommand.class, findOneAndUpdateCommand -> { - FilterSpec filterSpec = findOneAndUpdateCommand.filterSpec(); + FilterDefinition filterSpec = findOneAndUpdateCommand.filterDefinition(); assertThat(filterSpec).isNotNull(); final UpdateClause updateClause = findOneAndUpdateCommand.updateClause(); assertThat(updateClause).isNotNull(); @@ -1038,7 +1043,7 @@ public void happyPath() throws Exception { .isInstanceOfSatisfying( CountDocumentsCommand.class, countCommand -> { - FilterSpec filterSpec = countCommand.filterSpec(); + FilterDefinition filterSpec = countCommand.filterDefinition(); assertThat(filterSpec).isNotNull(); }); } diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/builders/FilterClauseBuilderTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/builders/FilterClauseBuilderTest.java index 260cf40c99..516ff4579e 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/builders/FilterClauseBuilderTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/builders/FilterClauseBuilderTest.java @@ -35,7 +35,7 @@ public class FilterClauseBuilderTest { private TestConstants testConstants = new TestConstants(); @Nested - class Deserialize { + class BuildWithRegularOperators { @Test public void happyPath() throws Exception { @@ -44,7 +44,7 @@ public void happyPath() throws Exception { {"username": "aaron"} """; - FilterClause filterClause = readFilterClause(json); + FilterClause filterClause = 
readCollectionFilterClause(json); final ComparisonExpression expectedResult = new ComparisonExpression( "username", @@ -86,7 +86,7 @@ private static Stream provideRangeQueries() { @MethodSource("provideRangeQueries") public void testRangeComparisonOperator( String json, ValueComparisonOperator operator, String column) throws Exception { - FilterClause filterClause = readFilterClause(json); + FilterClause filterClause = readCollectionFilterClause(json); assertThat(filterClause).isNotNull(); assertThat(filterClause.logicalExpression().logicalExpressions).hasSize(0); assertThat(filterClause.logicalExpression().comparisonExpressions).hasSize(1); @@ -107,7 +107,7 @@ public void testRangeComparisonOperator( public void mustHandleNull() throws Exception { String json = "null"; - FilterClause filterClause = readFilterClause(json); + FilterClause filterClause = readCollectionFilterClause(json); assertThat(filterClause.isEmpty()).isTrue(); } @@ -119,7 +119,7 @@ public void mustHandleEmpty() throws Exception { {} """; - FilterClause filterClause = readFilterClause(json); + FilterClause filterClause = readCollectionFilterClause(json); assertThat(filterClause.logicalExpression().logicalExpressions).hasSize(0); assertThat(filterClause.logicalExpression().comparisonExpressions).hasSize(0); } @@ -137,7 +137,7 @@ public void mustHandleString() throws Exception { new ValueComparisonOperation( ValueComparisonOperator.EQ, new JsonLiteral("aaron", JsonType.STRING))), null); - FilterClause filterClause = readFilterClause(json); + FilterClause filterClause = readCollectionFilterClause(json); assertThat(filterClause.logicalExpression().logicalExpressions).hasSize(0); assertThat(filterClause.logicalExpression().comparisonExpressions).hasSize(1); assertThat( @@ -161,7 +161,7 @@ public void mustHandleNumber() throws Exception { ValueComparisonOperator.EQ, new JsonLiteral(BigDecimal.valueOf(40), JsonType.NUMBER))), null); - FilterClause filterClause = readFilterClause(json); + FilterClause 
filterClause = readCollectionFilterClause(json); assertThat(filterClause.logicalExpression().logicalExpressions).hasSize(0); assertThat(filterClause.logicalExpression().comparisonExpressions).hasSize(1); assertThat( @@ -184,7 +184,7 @@ public void mustHandleBoolean() throws Exception { new ValueComparisonOperation( ValueComparisonOperator.EQ, new JsonLiteral(true, JsonType.BOOLEAN))), null); - FilterClause filterClause = readFilterClause(json); + FilterClause filterClause = readCollectionFilterClause(json); assertThat(filterClause.logicalExpression().logicalExpressions).hasSize(0); assertThat(filterClause.logicalExpression().comparisonExpressions).hasSize(1); assertThat( @@ -208,7 +208,7 @@ public void mustHandleDate() throws Exception { ValueComparisonOperator.EQ, new JsonLiteral(new Date(1672531200000L), JsonType.DATE))), null); - FilterClause filterClause = readFilterClause(json); + FilterClause filterClause = readCollectionFilterClause(json); assertThat(filterClause.logicalExpression().logicalExpressions).hasSize(0); assertThat(filterClause.logicalExpression().comparisonExpressions).hasSize(1); assertThat( @@ -232,7 +232,7 @@ public void mustHandleDateAndOr() throws Exception { ValueComparisonOperator.EQ, new JsonLiteral(new Date(1672531200000L), JsonType.DATE))), null); - FilterClause filterClause = readFilterClause(json); + FilterClause filterClause = readCollectionFilterClause(json); assertThat(filterClause.logicalExpression().logicalExpressions).hasSize(1); assertThat(filterClause.logicalExpression().getTotalComparisonExpressionCount()).isEqualTo(1); assertThat( @@ -262,7 +262,7 @@ public void mustHandleDateAsEpoch() throws Exception { {"dateType": {"$date": "2023-01-01"}} """; - Throwable throwable = catchThrowable(() -> readFilterClause(json)); + Throwable throwable = catchThrowable(() -> readCollectionFilterClause(json)); assertThat(throwable) .isInstanceOf(JsonApiException.class) .satisfies( @@ -280,7 +280,7 @@ public void mustHandleDateAsEpochAndOr() 
throws Exception { { "$or" : [{"dateType": {"$date": "2023-01-01"}}]} """; - Throwable throwable = catchThrowable(() -> readFilterClause(json)); + Throwable throwable = catchThrowable(() -> readCollectionFilterClause(json)); assertThat(throwable) .isInstanceOf(JsonApiException.class) .satisfies( @@ -295,8 +295,8 @@ public void mustHandleDateAsEpochAndOr() throws Exception { public void mustHandleAll() throws Exception { String json = """ - {"allPath" : {"$all": ["a", "b"]}} - """; + {"allPath" : {"$all": ["a", "b"]}} + """; final ComparisonExpression expectedResult = new ComparisonExpression( "allPath", @@ -305,7 +305,7 @@ public void mustHandleAll() throws Exception { ArrayComparisonOperator.ALL, new JsonLiteral(List.of("a", "b"), JsonType.ARRAY))), null); - FilterClause filterClause = readFilterClause(json); + FilterClause filterClause = readCollectionFilterClause(json); assertThat(filterClause.logicalExpression().logicalExpressions).hasSize(0); assertThat(filterClause.logicalExpression().comparisonExpressions).hasSize(1); assertThat( @@ -327,7 +327,7 @@ public void mustHandleAllAndOr() throws Exception { } ] } - """; + """; final ComparisonExpression expectedResult1 = new ComparisonExpression( "allPath", @@ -343,7 +343,7 @@ public void mustHandleAllAndOr() throws Exception { new ValueComparisonOperation( ValueComparisonOperator.EQ, new JsonLiteral("testAge", JsonType.STRING))), null); - FilterClause filterClause = readFilterClause(json); + FilterClause filterClause = readCollectionFilterClause(json); assertThat(filterClause.logicalExpression().logicalExpressions).hasSize(1); assertThat(filterClause.logicalExpression().getTotalComparisonExpressionCount()).isEqualTo(2); assertThat( @@ -390,7 +390,7 @@ public void mustHandleAllNonArray() throws Exception { """ {"allPath" : {"$all": "abc"}} """; - Throwable throwable = catchThrowable(() -> readFilterClause(json)); + Throwable throwable = catchThrowable(() -> readCollectionFilterClause(json)); assertThat(throwable) 
.isInstanceOf(JsonApiException.class) .satisfies( @@ -405,7 +405,7 @@ public void mustHandleAllNonEmptyArray() throws Exception { """ {"allPath" : {"$all": []}} """; - Throwable throwable = catchThrowable(() -> readFilterClause(json)); + Throwable throwable = catchThrowable(() -> readCollectionFilterClause(json)); assertThat(throwable) .isInstanceOf(JsonApiException.class) .satisfies( @@ -418,8 +418,8 @@ public void mustHandleAllNonEmptyArray() throws Exception { public void mustHandleSize() throws Exception { String json = """ - {"sizePath" : {"$size": 2}} - """; + {"sizePath" : {"$size": 2}} + """; final ComparisonExpression expectedResult = new ComparisonExpression( "sizePath", @@ -428,7 +428,7 @@ public void mustHandleSize() throws Exception { ArrayComparisonOperator.SIZE, new JsonLiteral(new BigDecimal(2), JsonType.NUMBER))), null); - FilterClause filterClause = readFilterClause(json); + FilterClause filterClause = readCollectionFilterClause(json); assertThat(filterClause.logicalExpression().logicalExpressions).hasSize(0); assertThat(filterClause.logicalExpression().comparisonExpressions).hasSize(1); assertThat( @@ -442,8 +442,8 @@ public void mustHandleSize() throws Exception { public void mustHandleIntegerWithTrailingZeroSize() throws Exception { String json = """ - {"sizePath" : {"$size": 0.0}} - """; + {"sizePath" : {"$size": 0.0}} + """; final ComparisonExpression expectedResult = new ComparisonExpression( "sizePath", @@ -452,7 +452,7 @@ public void mustHandleIntegerWithTrailingZeroSize() throws Exception { ArrayComparisonOperator.SIZE, new JsonLiteral(new BigDecimal(0), JsonType.NUMBER))), null); - FilterClause filterClause = readFilterClause(json); + FilterClause filterClause = readCollectionFilterClause(json); assertThat(filterClause.logicalExpression().logicalExpressions).hasSize(0); assertThat(filterClause.logicalExpression().comparisonExpressions).hasSize(1); assertThat( @@ -463,8 +463,8 @@ public void mustHandleIntegerWithTrailingZeroSize() throws 
Exception { String json1 = """ - {"sizePath" : {"$size": 5.0}} - """; + {"sizePath" : {"$size": 5.0}} + """; final ComparisonExpression expectedResult1 = new ComparisonExpression( "sizePath", @@ -473,7 +473,7 @@ public void mustHandleIntegerWithTrailingZeroSize() throws Exception { ArrayComparisonOperator.SIZE, new JsonLiteral(new BigDecimal(5), JsonType.NUMBER))), null); - FilterClause filterClause1 = readFilterClause(json); + FilterClause filterClause1 = readCollectionFilterClause(json); assertThat(filterClause1.logicalExpression().logicalExpressions).hasSize(0); assertThat(filterClause1.logicalExpression().comparisonExpressions).hasSize(1); assertThat( @@ -487,9 +487,9 @@ public void mustHandleIntegerWithTrailingZeroSize() throws Exception { public void mustHandleSizeNonNumber() throws Exception { String json = """ - {"sizePath" : {"$size": "2"}} - """; - Throwable throwable = catchThrowable(() -> readFilterClause(json)); + {"sizePath" : {"$size": "2"}} + """; + Throwable throwable = catchThrowable(() -> readCollectionFilterClause(json)); assertThat(throwable) .isInstanceOf(JsonApiException.class) .satisfies( @@ -503,9 +503,9 @@ public void mustHandleSizeNonNumber() throws Exception { public void mustHandleSizeNonInteger() throws Exception { String json = """ - {"sizePath" : {"$size": "1.1"}} - """; - Throwable throwable = catchThrowable(() -> readFilterClause(json)); + {"sizePath" : {"$size": "1.1"}} + """; + Throwable throwable = catchThrowable(() -> readCollectionFilterClause(json)); assertThat(throwable) .isInstanceOf(JsonApiException.class) .satisfies( @@ -515,9 +515,9 @@ public void mustHandleSizeNonInteger() throws Exception { String json1 = """ - {"sizePath" : {"$size": "5.4"}} - """; - Throwable throwable1 = catchThrowable(() -> readFilterClause(json1)); + {"sizePath" : {"$size": "5.4"}} + """; + Throwable throwable1 = catchThrowable(() -> readCollectionFilterClause(json1)); assertThat(throwable1) .isInstanceOf(JsonApiException.class) .satisfies( @@ 
-530,9 +530,9 @@ public void mustHandleSizeNonInteger() throws Exception { public void mustHandleSizeNegative() throws Exception { String json = """ - {"sizePath" : {"$size": -2}} - """; - Throwable throwable = catchThrowable(() -> readFilterClause(json)); + {"sizePath" : {"$size": -2}} + """; + Throwable throwable = catchThrowable(() -> readCollectionFilterClause(json)); assertThat(throwable) .isInstanceOf(JsonApiException.class) .satisfies( @@ -545,8 +545,8 @@ public void mustHandleSizeNegative() throws Exception { public void mustHandleSubDocEq() throws Exception { String json = """ - {"sub_doc" : {"col": 2}} - """; + {"sub_doc" : {"col": 2}} + """; Map value = new LinkedHashMap<>(); value.put("col", new BigDecimal(2)); final ComparisonExpression expectedResult = @@ -556,7 +556,7 @@ public void mustHandleSubDocEq() throws Exception { new ValueComparisonOperation( ValueComparisonOperator.EQ, new JsonLiteral(value, JsonType.SUB_DOC))), null); - FilterClause filterClause = readFilterClause(json); + FilterClause filterClause = readCollectionFilterClause(json); assertThat(filterClause.logicalExpression().logicalExpressions).hasSize(0); assertThat(filterClause.logicalExpression().comparisonExpressions).hasSize(1); assertThat( @@ -570,8 +570,8 @@ ValueComparisonOperator.EQ, new JsonLiteral(value, JsonType.SUB_DOC))), public void mustHandleArrayNe() throws Exception { String json = """ - {"col" : {"$ne": ["1","2"]}} - """; + {"col" : {"$ne": ["1","2"]}} + """; final ComparisonExpression expectedResult = new ComparisonExpression( "col", @@ -580,7 +580,7 @@ public void mustHandleArrayNe() throws Exception { ValueComparisonOperator.NE, new JsonLiteral(List.of("1", "2"), JsonType.ARRAY))), null); - FilterClause filterClause = readFilterClause(json); + FilterClause filterClause = readCollectionFilterClause(json); assertThat(filterClause.logicalExpression().logicalExpressions).hasSize(0); assertThat(filterClause.logicalExpression().comparisonExpressions).hasSize(1); 
assertThat( @@ -594,8 +594,8 @@ public void mustHandleArrayNe() throws Exception { public void mustHandleArrayEq() throws Exception { String json = """ - {"col" : {"$eq": ["3","4"]}} - """; + {"col" : {"$eq": ["3","4"]}} + """; final ComparisonExpression expectedResult = new ComparisonExpression( "col", @@ -604,7 +604,7 @@ public void mustHandleArrayEq() throws Exception { ValueComparisonOperator.EQ, new JsonLiteral(List.of("3", "4"), JsonType.ARRAY))), null); - FilterClause filterClause = readFilterClause(json); + FilterClause filterClause = readCollectionFilterClause(json); assertThat(filterClause.logicalExpression().logicalExpressions).hasSize(0); assertThat(filterClause.logicalExpression().comparisonExpressions).hasSize(1); assertThat( @@ -618,8 +618,8 @@ public void mustHandleArrayEq() throws Exception { public void mustHandleSubDocNe() throws Exception { String json = """ - {"sub_doc" : {"$ne" : {"col": 2}}} - """; + {"sub_doc" : {"$ne" : {"col": 2}}} + """; Map value = new LinkedHashMap<>(); value.put("col", new BigDecimal(2)); final ComparisonExpression expectedResult = @@ -629,7 +629,7 @@ public void mustHandleSubDocNe() throws Exception { new ValueComparisonOperation( ValueComparisonOperator.NE, new JsonLiteral(value, JsonType.SUB_DOC))), null); - FilterClause filterClause = readFilterClause(json); + FilterClause filterClause = readCollectionFilterClause(json); assertThat(filterClause.logicalExpression().logicalExpressions).hasSize(0); assertThat(filterClause.logicalExpression().comparisonExpressions).hasSize(1); assertThat( @@ -666,7 +666,7 @@ ValueComparisonOperator.NE, new JsonLiteral("Tim", JsonType.STRING))), ElementComparisonOperator.EXISTS, new JsonLiteral(true, JsonType.BOOLEAN))), null); - FilterClause filterClause = readFilterClause(json); + FilterClause filterClause = readCollectionFilterClause(json); assertThat(filterClause.logicalExpression().logicalExpressions).hasSize(0); 
assertThat(filterClause.logicalExpression().comparisonExpressions).hasSize(2); assertThat( @@ -690,8 +690,8 @@ ElementComparisonOperator.EXISTS, new JsonLiteral(true, JsonType.BOOLEAN))), public void mustHandleIdFieldIn() throws Exception { String json = """ - {"_id" : {"$in": ["2", "3"]}} - """; + {"_id" : {"$in": ["2", "3"]}} + """; final ComparisonExpression expectedResult = new ComparisonExpression( "_id", @@ -702,7 +702,7 @@ public void mustHandleIdFieldIn() throws Exception { List.of(DocumentId.fromString("2"), DocumentId.fromString("3")), JsonType.ARRAY))), null); - FilterClause filterClause = readFilterClause(json); + FilterClause filterClause = readCollectionFilterClause(json); assertThat(filterClause.logicalExpression().logicalExpressions).hasSize(0); assertThat(filterClause.logicalExpression().comparisonExpressions).hasSize(1); assertThat( @@ -716,8 +716,8 @@ public void mustHandleIdFieldIn() throws Exception { public void mustHandleNonIdFieldIn() throws Exception { String json = """ - {"name" : {"$in": ["name1", "name2"]}} - """; + {"name" : {"$in": ["name1", "name2"]}} + """; final ComparisonExpression expectedResult = new ComparisonExpression( "name", @@ -726,7 +726,7 @@ public void mustHandleNonIdFieldIn() throws Exception { ValueComparisonOperator.IN, new JsonLiteral(List.of("name1", "name2"), JsonType.ARRAY))), null); - FilterClause filterClause = readFilterClause(json); + FilterClause filterClause = readCollectionFilterClause(json); assertThat(filterClause.logicalExpression().logicalExpressions).hasSize(0); assertThat(filterClause.logicalExpression().comparisonExpressions).hasSize(1); assertThat( @@ -740,15 +740,15 @@ public void mustHandleNonIdFieldIn() throws Exception { public void mustHandleNonIdFieldInAndOr() throws Exception { String json = """ - { - "$and": [ - {"name" : {"$in": ["name1", "name2"]}}, - { - "age": "testAge" - } - ] - } - """; + { + "$and": [ + {"name" : {"$in": ["name1", "name2"]}}, + { + "age": "testAge" + } + ] + } + """; 
final ComparisonExpression expectedResult1 = new ComparisonExpression( "name", @@ -764,7 +764,7 @@ public void mustHandleNonIdFieldInAndOr() throws Exception { new ValueComparisonOperation( ValueComparisonOperator.EQ, new JsonLiteral("testAge", JsonType.STRING))), null); - FilterClause filterClause = readFilterClause(json); + FilterClause filterClause = readCollectionFilterClause(json); assertThat(filterClause.logicalExpression().logicalExpressions).hasSize(1); assertThat(filterClause.logicalExpression().getTotalComparisonExpressionCount()).isEqualTo(2); assertThat( @@ -810,13 +810,13 @@ public void simpleOr() throws Exception { String json = """ - { - "$or":[ - {"name" : "testName"}, - {"age" : "testAge"} - ] - } - """; + { + "$or":[ + {"name" : "testName"}, + {"age" : "testAge"} + ] + } + """; final ComparisonExpression expectedResult1 = new ComparisonExpression( "name", @@ -831,7 +831,7 @@ ValueComparisonOperator.EQ, new JsonLiteral("testName", JsonType.STRING))), new ValueComparisonOperation( ValueComparisonOperator.EQ, new JsonLiteral("testAge", JsonType.STRING))), null); - FilterClause filterClause = readFilterClause(json); + FilterClause filterClause = readCollectionFilterClause(json); assertThat(filterClause.logicalExpression().logicalExpressions).hasSize(1); assertThat(filterClause.logicalExpression().logicalExpressions.get(0).comparisonExpressions) .hasSize(2); @@ -860,13 +860,13 @@ public void simpleAnd() throws Exception { String json = """ - { - "$and":[ - {"name" : "testName"}, - {"age" : "testAge"} - ] - } - """; + { + "$and":[ + {"name" : "testName"}, + {"age" : "testAge"} + ] + } + """; final ComparisonExpression expectedResult1 = new ComparisonExpression( "name", @@ -881,7 +881,7 @@ ValueComparisonOperator.EQ, new JsonLiteral("testName", JsonType.STRING))), new ValueComparisonOperation( ValueComparisonOperator.EQ, new JsonLiteral("testAge", JsonType.STRING))), null); - FilterClause filterClause = readFilterClause(json); + FilterClause filterClause 
= readCollectionFilterClause(json); assertThat(filterClause.logicalExpression().logicalExpressions).hasSize(1); assertThat(filterClause.logicalExpression().logicalExpressions.get(0).comparisonExpressions) .hasSize(2); @@ -910,27 +910,27 @@ public void nestedOrAnd() throws Exception { String json = """ - { - "$and": [ - { - "name": "testName" - }, - { - "age": "testAge" - }, - { - "$or": [ - { - "address": "testAddress" - }, - { - "height": "testHeight" - } - ] - } - ] - } - """; + { + "$and": [ + { + "name": "testName" + }, + { + "age": "testAge" + }, + { + "$or": [ + { + "address": "testAddress" + }, + { + "height": "testHeight" + } + ] + } + ] + } + """; final ComparisonExpression expectedResult1 = new ComparisonExpression( "name", @@ -959,7 +959,7 @@ ValueComparisonOperator.EQ, new JsonLiteral("testAddress", JsonType.STRING))), new ValueComparisonOperation( ValueComparisonOperator.EQ, new JsonLiteral("testHeight", JsonType.STRING))), null); - FilterClause filterClause = readFilterClause(json); + FilterClause filterClause = readCollectionFilterClause(json); assertThat(filterClause.logicalExpression().logicalExpressions).hasSize(1); assertThat(filterClause.logicalExpression().getTotalComparisonExpressionCount()).isEqualTo(4); assertThat(filterClause.logicalExpression().logicalExpressions.get(0).comparisonExpressions) @@ -1020,8 +1020,8 @@ ValueComparisonOperator.EQ, new JsonLiteral("testHeight", JsonType.STRING))), public void mustHandleInArrayNonEmpty() throws Exception { String json = """ - {"_id" : {"$in": []}} - """; + {"_id" : {"$in": []}} + """; final ComparisonExpression expectedResult = new ComparisonExpression( "_id", @@ -1029,7 +1029,7 @@ public void mustHandleInArrayNonEmpty() throws Exception { new ValueComparisonOperation( ValueComparisonOperator.IN, new JsonLiteral(List.of(), JsonType.ARRAY))), null); - FilterClause filterClause = readFilterClause(json); + FilterClause filterClause = readCollectionFilterClause(json); 
assertThat(filterClause.logicalExpression().logicalExpressions).hasSize(0); assertThat(filterClause.logicalExpression().comparisonExpressions).hasSize(1); assertThat( @@ -1043,8 +1043,8 @@ ValueComparisonOperator.IN, new JsonLiteral(List.of(), JsonType.ARRAY))), public void mustHandleNinArrayNonEmpty() throws Exception { String json = """ - {"_id" : {"$nin": []}} - """; + {"_id" : {"$nin": []}} + """; final ComparisonExpression expectedResult = new ComparisonExpression( "_id", @@ -1052,7 +1052,7 @@ public void mustHandleNinArrayNonEmpty() throws Exception { new ValueComparisonOperation( ValueComparisonOperator.NIN, new JsonLiteral(List.of(), JsonType.ARRAY))), null); - FilterClause filterClause = readFilterClause(json); + FilterClause filterClause = readCollectionFilterClause(json); assertThat(filterClause.logicalExpression().logicalExpressions).hasSize(0); assertThat(filterClause.logicalExpression().comparisonExpressions).hasSize(1); assertThat( @@ -1066,9 +1066,9 @@ ValueComparisonOperator.NIN, new JsonLiteral(List.of(), JsonType.ARRAY))), public void mustHandleInArrayOnly() throws Exception { String json = """ - {"_id" : {"$in": "aaa"}} - """; - Throwable throwable = catchThrowable(() -> readFilterClause(json)); + {"_id" : {"$in": "aaa"}} + """; + Throwable throwable = catchThrowable(() -> readCollectionFilterClause(json)); assertThat(throwable) .isInstanceOf(JsonApiException.class) .satisfies( @@ -1081,9 +1081,9 @@ public void mustHandleInArrayOnly() throws Exception { public void mustHandleNinArrayOnly() throws Exception { String json = """ - {"_id" : {"$nin": "random"}} - """; - Throwable throwable = catchThrowable(() -> readFilterClause(json)); + {"_id" : {"$nin": "random"}} + """; + Throwable throwable = catchThrowable(() -> readCollectionFilterClause(json)); assertThat(throwable) .isInstanceOf(JsonApiException.class) .satisfies( @@ -1096,16 +1096,16 @@ public void mustHandleNinArrayOnly() throws Exception { public void mustHandleNinArrayOnlyAnd() throws 
Exception { String json = """ - { - "$and": [ - {"age" : {"$nin": "aaa"}}, - { - "name": "testName" + { + "$and": [ + {"age" : {"$nin": "aaa"}}, + { + "name": "testName" + } + ] } - ] - } - """; - Throwable throwable = catchThrowable(() -> readFilterClause(json)); + """; + Throwable throwable = catchThrowable(() -> readCollectionFilterClause(json)); assertThat(throwable) .isInstanceOf(JsonApiException.class) .satisfies( @@ -1119,9 +1119,9 @@ public void mustHandleInArrayWithBigArray() throws Exception { // String array with 100 unique numbers String json = """ - {"_id" : {"$in": ["0","1","2","3","4","5","6","7","8","9","10","11","12","13","14","15","16","17","18","19","20","21","22","23","24","25","26","27","28","29","30","31","32","33","34","35","36","37","38","39","40","41","42","43","44","45","46","47","48","49","50","51","52","53","54","55","56","57","58","59","60","61","62","63","64","65","66","67","68","69","70","71","72","73","74","75","76","77","78","79","80","81","82","83","84","85","86","87","88","89","90","91","92","93","94","95","96","97","98","99","100"]}} - """; - Throwable throwable = catchThrowable(() -> readFilterClause(json)); + {"_id" : {"$in": ["0","1","2","3","4","5","6","7","8","9","10","11","12","13","14","15","16","17","18","19","20","21","22","23","24","25","26","27","28","29","30","31","32","33","34","35","36","37","38","39","40","41","42","43","44","45","46","47","48","49","50","51","52","53","54","55","56","57","58","59","60","61","62","63","64","65","66","67","68","69","70","71","72","73","74","75","76","77","78","79","80","81","82","83","84","85","86","87","88","89","90","91","92","93","94","95","96","97","98","99","100"]}} + """; + Throwable throwable = catchThrowable(() -> readCollectionFilterClause(json)); assertThat(throwable) .isInstanceOf(JsonApiException.class) .satisfies( @@ -1139,9 +1139,9 @@ public void mustHandleNinArrayWithBigArray() throws Exception { // String array with 100 unique numbers String json = """ - {"_id" : 
{"$nin": ["0","1","2","3","4","5","6","7","8","9","10","11","12","13","14","15","16","17","18","19","20","21","22","23","24","25","26","27","28","29","30","31","32","33","34","35","36","37","38","39","40","41","42","43","44","45","46","47","48","49","50","51","52","53","54","55","56","57","58","59","60","61","62","63","64","65","66","67","68","69","70","71","72","73","74","75","76","77","78","79","80","81","82","83","84","85","86","87","88","89","90","91","92","93","94","95","96","97","98","99","100"]}} - """; - Throwable throwable = catchThrowable(() -> readFilterClause(json)); + {"_id" : {"$nin": ["0","1","2","3","4","5","6","7","8","9","10","11","12","13","14","15","16","17","18","19","20","21","22","23","24","25","26","27","28","29","30","31","32","33","34","35","36","37","38","39","40","41","42","43","44","45","46","47","48","49","50","51","52","53","54","55","56","57","58","59","60","61","62","63","64","65","66","67","68","69","70","71","72","73","74","75","76","77","78","79","80","81","82","83","84","85","86","87","88","89","90","91","92","93","94","95","96","97","98","99","100"]}} + """; + Throwable throwable = catchThrowable(() -> readCollectionFilterClause(json)); assertThat(throwable) .isInstanceOf(JsonApiException.class) .satisfies( @@ -1158,22 +1158,22 @@ public void mustHandleNinArrayWithBigArray() throws Exception { public void multipleIdFilterAndOr() throws Exception { String json = """ - { - "_id": "testID1", - "$or": [ - { - "name": "testName" - }, - { - "age": "testAge" - }, - { - "_id": "testID2" - } - ] - } - """; - Throwable throwable = catchThrowable(() -> readFilterClause(json)); + { + "_id": "testID1", + "$or": [ + { + "name": "testName" + }, + { + "age": "testAge" + }, + { + "_id": "testID2" + } + ] + } + """; + Throwable throwable = catchThrowable(() -> readCollectionFilterClause(json)); assertThat(throwable) .isInstanceOf(JsonApiException.class) .satisfies( @@ -1189,7 +1189,7 @@ public void invalidPathName() throws Exception { """ {"$gt" 
: {"test" : 5}} """; - Throwable throwable = catchThrowable(() -> readFilterClause(json)); + Throwable throwable = catchThrowable(() -> readCollectionFilterClause(json)); assertThat(throwable) .isInstanceOf(JsonApiException.class) @@ -1204,10 +1204,10 @@ public void invalidPathName() throws Exception { public void valid$vectorPathName() throws Exception { String json = """ - {"$vector" : {"$exists": true}} - """; + {"$vector" : {"$exists": true}} + """; - FilterClause filterClause = readFilterClause(json); + FilterClause filterClause = readCollectionFilterClause(json); assertThat(filterClause.logicalExpression().logicalExpressions).hasSize(0); assertThat(filterClause.logicalExpression().comparisonExpressions).hasSize(1); assertThat(filterClause.logicalExpression().comparisonExpressions.get(0).getPath()) @@ -1230,7 +1230,7 @@ public void invalidPathName() throws Exception { {"$exists" : {"$vector": true}} """; - Throwable throwable = catchThrowable(() -> readFilterClause(json)); + Throwable throwable = catchThrowable(() -> readCollectionFilterClause(json)); assertThat(throwable) .isInstanceOf(JsonApiException.class) @@ -1246,7 +1246,7 @@ public void invalidPathNameWithValidOperator() { """ {"$exists" : {"$exists": true}} """; - Throwable throwable = catchThrowable(() -> readFilterClause(json)); + Throwable throwable = catchThrowable(() -> readCollectionFilterClause(json)); assertThat(throwable) .isInstanceOf(JsonApiException.class) @@ -1259,7 +1259,115 @@ public void invalidPathNameWithValidOperator() { } @Nested - class DeserializeWithJsonExtensions { + class BuildWithMatchOperator { + @Test + public void mustHandleMatchOperator() throws Exception { + String json = + """ + {"$lexical": {"$match": "search text"}} + """; + final ComparisonExpression expectedResult = + new ComparisonExpression( + "$lexical", + List.of( + new ValueComparisonOperation( + ValueComparisonOperator.MATCH, + new JsonLiteral("search text", JsonType.STRING))), + null); + FilterClause 
filterClause = readCollectionFilterClause(json); + assertThat(filterClause.logicalExpression().logicalExpressions).hasSize(0); + assertThat(filterClause.logicalExpression().comparisonExpressions).hasSize(1); + assertThat( + filterClause.logicalExpression().comparisonExpressions.get(0).getFilterOperations()) + .isEqualTo(expectedResult.getFilterOperations()); + assertThat(filterClause.logicalExpression().comparisonExpressions.get(0).getPath()) + .isEqualTo(expectedResult.getPath()); + } + + @Test + public void mustFailOnMatchWithNonLexicalField() { + String json = + """ + {"content": {"$match": "search text"}} + """; + Throwable throwable = catchThrowable(() -> readCollectionFilterClause(json)); + assertThat(throwable) + .isInstanceOf(JsonApiException.class) + .satisfies( + t -> { + assertThat(t.getMessage()) + .contains( + "$match operator can only be used with the '$lexical' field, not 'content'"); + }); + } + + @ParameterizedTest + @MethodSource("matchNonStringArgs") + public void mustFailOnMatchWithNonString(String actualType, String jsonSnippet) { + String json = + """ + {"$lexical": {"$match": %s}} + """ + .formatted(jsonSnippet); + Throwable throwable = catchThrowable(() -> readCollectionFilterClause(json)); + assertThat(throwable) + .isInstanceOf(JsonApiException.class) + .satisfies( + t -> { + assertThat(t.getMessage()) + .contains( + "$match operator must have `String` value, was `%s`".formatted(actualType)); + }); + } + + private static Stream matchNonStringArgs() { + return Stream.of( + Arguments.of("Array", "[\"text1\", \"text2\"]"), + Arguments.of("Boolean", "true"), + Arguments.of("Null", "null"), + Arguments.of("Number", "42"), + Arguments.of("Object", "{\"key\": \"value\"}")); + } + + // Verify explicit "$eq" not allowed for $lexical + @Test + public void mustFailOnLexicalWithExplicitEq() { + String json = + """ + {"$lexical": { "$eq": "search text"} } + """; + Throwable throwable = catchThrowable(() -> readCollectionFilterClause(json)); + 
assertThat(throwable) + .isInstanceOf(JsonApiException.class) + .satisfies( + t -> { + assertThat(t.getMessage()) + .contains( + "Cannot filter on '$lexical' field using operator $eq: only $match is supported"); + }); + } + + // Verify short-cut for "$eq" not allowed for $lexical + @Test + public void mustFailOnLexicalWithImplicitEq() { + String json = + """ + {"$lexical": "search text"} + """; + Throwable throwable = catchThrowable(() -> readCollectionFilterClause(json)); + assertThat(throwable) + .isInstanceOf(JsonApiException.class) + .satisfies( + t -> { + assertThat(t.getMessage()) + .contains( + "Cannot filter on '$lexical' field using operator $eq: only $match is supported"); + }); + } + } + + @Nested + class BuildWithJsonExtensions { @Test public void mustHandleObjectIdAsId() throws Exception { final String OBJECT_ID = "5f3e3d1e1e6e6f1e6e6e6f1e"; @@ -1280,7 +1388,7 @@ public void mustHandleObjectIdAsId() throws Exception { objectMapper.getNodeFactory().textNode(OBJECT_ID)), JsonType.DOCUMENT_ID))), null); - FilterClause filterClause = readFilterClause(json); + FilterClause filterClause = readCollectionFilterClause(json); assertThat(filterClause.logicalExpression().logicalExpressions).hasSize(0); assertThat(filterClause.logicalExpression().comparisonExpressions).hasSize(1); assertThat(filterClause.logicalExpression().comparisonExpressions.get(0).getPath()) @@ -1305,7 +1413,7 @@ public void mustHandleObjectIdAsRegularField() throws Exception { new ValueComparisonOperation( ValueComparisonOperator.EQ, new JsonLiteral(OBJECT_ID, JsonType.STRING))), null); - FilterClause filterClause = readFilterClause(json); + FilterClause filterClause = readCollectionFilterClause(json); assertThat(filterClause.logicalExpression().logicalExpressions).hasSize(0); assertThat(filterClause.logicalExpression().comparisonExpressions).hasSize(1); assertThat( @@ -1334,7 +1442,7 @@ public void mustHandleUUIDAsId() throws Exception { JsonExtensionType.UUID, 
objectMapper.getNodeFactory().textNode(UUID)), JsonType.DOCUMENT_ID))), null); - FilterClause filterClause = readFilterClause(json); + FilterClause filterClause = readCollectionFilterClause(json); assertThat(filterClause.logicalExpression().logicalExpressions).hasSize(0); assertThat(filterClause.logicalExpression().comparisonExpressions).hasSize(1); assertThat(filterClause.logicalExpression().comparisonExpressions.get(0).getPath()) @@ -1359,7 +1467,7 @@ public void mustHandleUUIDAsRegularField() throws Exception { new ValueComparisonOperation( ValueComparisonOperator.EQ, new JsonLiteral(UUID, JsonType.STRING))), null); - FilterClause filterClause = readFilterClause(json); + FilterClause filterClause = readCollectionFilterClause(json); assertThat(filterClause.logicalExpression().logicalExpressions).hasSize(0); assertThat(filterClause.logicalExpression().comparisonExpressions).hasSize(1); assertThat( @@ -1376,7 +1484,7 @@ public void mustFailOnBadUUIDAsId() throws Exception { {"_id": {"$uuid": "abc"}} """; - Throwable throwable = catchThrowable(() -> readFilterClause(json)); + Throwable throwable = catchThrowable(() -> readCollectionFilterClause(json)); assertThat(throwable) .isInstanceOf(JsonApiException.class) .satisfies( @@ -1394,7 +1502,7 @@ public void mustFailOnBadObjectIdAsId() throws Exception { {"_id": {"$objectId": "xyz"}} """; - Throwable throwable = catchThrowable(() -> readFilterClause(json)); + Throwable throwable = catchThrowable(() -> readCollectionFilterClause(json)); assertThat(throwable) .isInstanceOf(JsonApiException.class) .satisfies( @@ -1412,7 +1520,7 @@ public void mustFailOnUnknownOperatorAsId() throws Exception { {"_id": {"$GUID": "abc"}} """; - Throwable throwable = catchThrowable(() -> readFilterClause(json)); + Throwable throwable = catchThrowable(() -> readCollectionFilterClause(json)); assertThat(throwable) .isInstanceOf(JsonApiException.class) .satisfies( @@ -1428,7 +1536,7 @@ public void mustFailOnBadUUIDAsField() throws Exception { 
{"field": {"$uuid": "abc"}} """; - Throwable throwable = catchThrowable(() -> readFilterClause(json)); + Throwable throwable = catchThrowable(() -> readCollectionFilterClause(json)); assertThat(throwable) .isInstanceOf(JsonApiException.class) .satisfies( @@ -1440,8 +1548,12 @@ public void mustFailOnBadUUIDAsField() throws Exception { } } - FilterClause readFilterClause(String json) throws IOException { - return FilterClauseBuilder.builderFor(testConstants.COLLECTION_SCHEMA_OBJECT) - .build(operationsConfig, objectMapper.readTree(json)); + FilterClause readCollectionFilterClause(String json) { + try { + return FilterClauseBuilder.builderFor(testConstants.COLLECTION_SCHEMA_OBJECT) + .build(operationsConfig, objectMapper.readTree(json)); + } catch (IOException e) { + throw new RuntimeException(e); + } } } diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/deserializers/SortClauseDeserializerTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/builders/SortClauseBuilderTest.java similarity index 80% rename from src/test/java/io/stargate/sgv2/jsonapi/api/model/command/deserializers/SortClauseDeserializerTest.java rename to src/test/java/io/stargate/sgv2/jsonapi/api/model/command/builders/SortClauseBuilderTest.java index 18b131fec4..42b0288549 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/deserializers/SortClauseDeserializerTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/builders/SortClauseBuilderTest.java @@ -1,24 +1,30 @@ -package io.stargate.sgv2.jsonapi.api.model.command.deserializers; +package io.stargate.sgv2.jsonapi.api.model.command.builders; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.catchThrowable; +import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import io.quarkus.test.junit.QuarkusTest; import io.quarkus.test.junit.TestProfile; +import 
io.stargate.sgv2.jsonapi.TestConstants; import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause; import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortExpression; import io.stargate.sgv2.jsonapi.exception.JsonApiException; +import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject; import io.stargate.sgv2.jsonapi.testresource.NoGlobalResourcesTestProfile; import io.stargate.sgv2.jsonapi.util.Base64Util; import io.stargate.sgv2.jsonapi.util.CqlVectorUtil; import jakarta.inject.Inject; +import java.io.IOException; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @QuarkusTest @TestProfile(NoGlobalResourcesTestProfile.Impl.class) -class SortClauseDeserializerTest { +class SortClauseBuilderTest { + // Needed to create the collection context to pass to the builder + private final TestConstants testConstants = new TestConstants(); @Inject ObjectMapper objectMapper; @@ -35,7 +41,7 @@ public void happyPath() throws Exception { } """; - SortClause sortClause = objectMapper.readValue(json, SortClause.class); + SortClause sortClause = deserializeSortClause(json); assertThat(sortClause).isNotNull(); assertThat(sortClause.sortExpressions()) @@ -54,7 +60,7 @@ public void happyPathWithUnusualChars() throws Exception { } """; - SortClause sortClause = objectMapper.readValue(json, SortClause.class); + SortClause sortClause = deserializeSortClause(json); assertThat(sortClause).isNotNull(); assertThat(sortClause.sortExpressions()) @@ -75,7 +81,7 @@ public void happyPathWithEscapedChars() throws Exception { } """; - SortClause sortClause = objectMapper.readValue(json, SortClause.class); + SortClause sortClause = deserializeSortClause(json); assertThat(sortClause).isNotNull(); assertThat(sortClause.sortExpressions()) @@ -97,7 +103,7 @@ public void happyPathVectorSearch() throws Exception { } """; - SortClause sortClause = objectMapper.readValue(json, SortClause.class); + SortClause sortClause = 
deserializeSortClause(json); assertThat(sortClause).isNotNull(); assertThat(sortClause.sortExpressions()).hasSize(1); @@ -119,7 +125,7 @@ public void vectorSearchBinaryObject() throws Exception { """ .formatted(vectorString); - SortClause sortClause = objectMapper.readValue(json, SortClause.class); + SortClause sortClause = deserializeSortClause(json); assertThat(sortClause).isNotNull(); assertThat(sortClause.sortExpressions()).hasSize(1); @@ -141,7 +147,7 @@ public void binaryVectorSearchTableColumn() throws Exception { """ .formatted(vectorString); - SortClause sortClause = objectMapper.readValue(json, SortClause.class); + SortClause sortClause = deserializeSortClause(json); assertThat(sortClause).isNotNull(); assertThat(sortClause.sortExpressions()).hasSize(1); @@ -159,7 +165,7 @@ public void vectorSearchEmpty() { } """; - Throwable throwable = catchThrowable(() -> objectMapper.readValue(json, SortClause.class)); + Throwable throwable = catchThrowable(() -> deserializeSortClause(json)); assertThat(throwable).isInstanceOf(JsonApiException.class); assertThat(throwable.getMessage()).contains("$vector value can't be empty"); @@ -174,7 +180,7 @@ public void vectorSearchNonArray() { } """; - Throwable throwable = catchThrowable(() -> objectMapper.readValue(json, SortClause.class)); + Throwable throwable = catchThrowable(() -> deserializeSortClause(json)); assertThat(throwable).isInstanceOf(JsonApiException.class); assertThat(throwable.getMessage()).contains("$vector value needs to be array of numbers"); @@ -189,7 +195,7 @@ public void vectorSearchNonArrayObject() { } """; - Throwable throwable = catchThrowable(() -> objectMapper.readValue(json, SortClause.class)); + Throwable throwable = catchThrowable(() -> deserializeSortClause(json)); assertThat(throwable).isInstanceOf(JsonApiException.class); assertThat(throwable.getMessage()) @@ -206,7 +212,7 @@ public void vectorSearchInvalidData() { } """; - Throwable throwable = catchThrowable(() -> 
objectMapper.readValue(json, SortClause.class)); + Throwable throwable = catchThrowable(() -> deserializeSortClause(json)); assertThat(throwable).isInstanceOf(JsonApiException.class); assertThat(throwable.getMessage()).contains("$vector value needs to be array of numbers"); @@ -222,7 +228,7 @@ public void vectorSearchInvalidSortClause() { } """; - Throwable throwable = catchThrowable(() -> objectMapper.readValue(json, SortClause.class)); + Throwable throwable = catchThrowable(() -> deserializeSortClause(json)); assertThat(throwable).isInstanceOf(JsonApiException.class); assertThat(throwable.getMessage()) @@ -238,7 +244,7 @@ public void happyPathVectorizeSearch() throws Exception { } """; - SortClause sortClause = objectMapper.readValue(json, SortClause.class); + SortClause sortClause = deserializeSortClause(json); assertThat(sortClause).isNotNull(); assertThat(sortClause.sortExpressions()).hasSize(1); @@ -255,7 +261,7 @@ public void vectorizeSearchNonText() { } """; - Throwable throwable = catchThrowable(() -> objectMapper.readValue(json, SortClause.class)); + Throwable throwable = catchThrowable(() -> deserializeSortClause(json)); assertThat(throwable).isInstanceOf(JsonApiException.class); assertThat(throwable.getMessage()) @@ -271,7 +277,7 @@ public void vectorizeSearchObject() { } """; - Throwable throwable = catchThrowable(() -> objectMapper.readValue(json, SortClause.class)); + Throwable throwable = catchThrowable(() -> deserializeSortClause(json)); assertThat(throwable).isInstanceOf(JsonApiException.class); assertThat(throwable.getMessage()) @@ -288,7 +294,7 @@ public void vectorizeSearchBlank() { } """; - Throwable throwable = catchThrowable(() -> objectMapper.readValue(json, SortClause.class)); + Throwable throwable = catchThrowable(() -> deserializeSortClause(json)); assertThat(throwable).isInstanceOf(JsonApiException.class); assertThat(throwable.getMessage()) @@ -305,7 +311,7 @@ public void vectorizeSearchWithOtherSort() { } """; - Throwable throwable = 
catchThrowable(() -> objectMapper.readValue(json, SortClause.class)); + Throwable throwable = catchThrowable(() -> deserializeSortClause(json)); assertThat(throwable).isInstanceOf(JsonApiException.class); assertThat(throwable.getMessage()) @@ -319,7 +325,7 @@ public void mustTrimPath() throws Exception { {"some.path " : 1} """; - SortClause sortClause = objectMapper.readValue(json, SortClause.class); + SortClause sortClause = deserializeSortClause(json); assertThat(sortClause).isNotNull(); assertThat(sortClause.sortExpressions()) @@ -331,9 +337,11 @@ public void mustTrimPath() throws Exception { public void mustHandleNull() throws Exception { String json = "null"; - SortClause sortClause = objectMapper.readValue(json, SortClause.class); + SortClause sortClause = deserializeSortClause(json); - assertThat(sortClause).isNull(); + // Note: we will always create non-null sort clause + assertThat(sortClause).isNotNull(); + assertThat(sortClause.isEmpty()).isTrue(); } @Test @@ -343,7 +351,7 @@ public void mustBeObject() { ["primitive"] """; - Throwable throwable = catchThrowable(() -> objectMapper.readValue(json, SortClause.class)); + Throwable throwable = catchThrowable(() -> deserializeSortClause(json)); assertThat(throwable).isInstanceOf(JsonApiException.class); } @@ -355,7 +363,7 @@ public void mustNotContainBlankString() { {" " : 1} """; - Throwable throwable = catchThrowable(() -> objectMapper.readValue(json, SortClause.class)); + Throwable throwable = catchThrowable(() -> deserializeSortClause(json)); assertThat(throwable).isInstanceOf(JsonApiException.class); } @@ -367,7 +375,7 @@ public void mustNotContainEmptyString() { {"": 1} """; - Throwable throwable = catchThrowable(() -> objectMapper.readValue(json, SortClause.class)); + Throwable throwable = catchThrowable(() -> deserializeSortClause(json)); assertThat(throwable).isInstanceOf(JsonApiException.class); } @@ -378,7 +386,7 @@ public void invalidPathNameOperator() { """ {"$gt": 1} """; - Throwable throwable = 
catchThrowable(() -> objectMapper.readValue(json, SortClause.class)); + Throwable throwable = catchThrowable(() -> deserializeSortClause(json)); assertThat(throwable).isInstanceOf(JsonApiException.class); assertThat(throwable) @@ -392,11 +400,10 @@ public void invalidPathNameHybridWithNumber() { Throwable t = catchThrowable( () -> - objectMapper.readValue( + deserializeSortClause( """ {"$hybrid": 1} - """, - SortClause.class)); + """)); assertThat(t).isInstanceOf(JsonApiException.class); assertThat(t) @@ -409,11 +416,10 @@ public void invalidPathNameHybridWithString() { Throwable t = catchThrowable( () -> - objectMapper.readValue( + deserializeSortClause( """ {"$hybrid": "tokens are tasty"} - """, - SortClause.class)); + """)); assertThat(t).isInstanceOf(JsonApiException.class); assertThat(t) @@ -426,7 +432,7 @@ public void invalidEscapeUsage() { """ {"a&b": 1} """; - Throwable throwable = catchThrowable(() -> objectMapper.readValue(json, SortClause.class)); + Throwable throwable = catchThrowable(() -> deserializeSortClause(json)); assertThat(throwable).isInstanceOf(JsonApiException.class); assertThat(throwable) @@ -434,4 +440,10 @@ public void invalidEscapeUsage() { "Invalid sort clause path: sort clause path ('a&b') is not a valid path."); } } + + private SortClause deserializeSortClause(String json) throws IOException { + final JsonNode node = objectMapper.readTree(json); + CollectionSchemaObject schema = testConstants.collectionContext().schemaObject(); + return SortClauseBuilder.builderFor(schema).build(node); + } } diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/clause/sort/FindAndRerankSortClauseDeserializerTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/clause/sort/FindAndRerankSortClauseDeserializerTest.java index c9f97c1189..d9612c878a 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/clause/sort/FindAndRerankSortClauseDeserializerTest.java +++ 
b/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/clause/sort/FindAndRerankSortClauseDeserializerTest.java @@ -1,5 +1,6 @@ package io.stargate.sgv2.jsonapi.api.model.command.clause.sort; +import static io.stargate.sgv2.jsonapi.metrics.CommandFeature.*; import static io.stargate.sgv2.jsonapi.util.Base64Util.encodeAsMimeBase64; import static io.stargate.sgv2.jsonapi.util.CqlVectorUtil.floatsToBytes; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; @@ -11,6 +12,7 @@ import com.fasterxml.jackson.databind.DeserializationContext; import com.fasterxml.jackson.databind.JsonMappingException; import com.fasterxml.jackson.databind.ObjectMapper; +import io.stargate.sgv2.jsonapi.metrics.CommandFeatures; import io.stargate.sgv2.jsonapi.util.recordable.PrettyPrintable; import java.util.stream.Stream; import org.junit.jupiter.api.BeforeEach; @@ -44,15 +46,30 @@ void setUp() { @Test public void testEqualsAndHash() { var value1 = - new FindAndRerankSort("vectorize sort", "lexical sort", new float[] {1.1f, 2.2f, 3.3f}); + new FindAndRerankSort( + "vectorize sort", + "lexical sort", + new float[] {1.1f, 2.2f, 3.3f}, + CommandFeatures.EMPTY); var diffVectorize = - new FindAndRerankSort("vectorize sort 2", "lexical sort", new float[] {1.1f, 2.2f, 3.3f}); + new FindAndRerankSort( + "vectorize sort 2", + "lexical sort", + new float[] {1.1f, 2.2f, 3.3f}, + CommandFeatures.EMPTY); var diffLexical = - new FindAndRerankSort("vectorize sort", "lexical sort 2", new float[] {1.1f, 2.2f, 3.3f}); + new FindAndRerankSort( + "vectorize sort", + "lexical sort 2", + new float[] {1.1f, 2.2f, 3.3f}, + CommandFeatures.EMPTY); var diffVector = new FindAndRerankSort( - "vectorize sort", "lexical sort", new float[] {1.1f, 2.2f, 3.3f, 4.4f}); + "vectorize sort", + "lexical sort", + new float[] {1.1f, 2.2f, 3.3f, 4.4f}, + CommandFeatures.EMPTY); assertThat(value1).as("Object equals self").isEqualTo(value1); assertThat(value1).as("different vectorize 
sort").isNotEqualTo(diffVectorize); @@ -110,12 +127,15 @@ private static Stream validSortsTestCases() { { "$hybrid" : "same for hybrid and lexical" } """, new FindAndRerankSort( - "same for hybrid and lexical", "same for hybrid and lexical", null)), + "same for hybrid and lexical", + "same for hybrid and lexical", + null, + CommandFeatures.of(HYBRID))), Arguments.of( """ { "$hybrid" : "" } """, - new FindAndRerankSort(null, null, null)), + new FindAndRerankSort(null, null, null, CommandFeatures.of(HYBRID))), // ---- // maximum fields, resolver works out the valid combinations Arguments.of( @@ -123,75 +143,107 @@ private static Stream validSortsTestCases() { { "$hybrid" : { "$vectorize" : "vectorize sort", "$lexical" : "lexical sort", "$vector" : [1.1, 2.2, 3.3]} } """, new FindAndRerankSort( - "vectorize sort", "lexical sort", new float[] {1.1f, 2.2f, 3.3f})), + "vectorize sort", + "lexical sort", + new float[] {1.1f, 2.2f, 3.3f}, + CommandFeatures.of(HYBRID, LEXICAL, VECTOR, VECTORIZE))), Arguments.of( """ { "$hybrid" : { "$vectorize" : "vectorize sort", "$lexical" : "lexical sort"} } """, - new FindAndRerankSort("vectorize sort", "lexical sort", null)), + new FindAndRerankSort( + "vectorize sort", + "lexical sort", + null, + CommandFeatures.of(HYBRID, LEXICAL, VECTORIZE))), // ---- // $lexical variations Arguments.of( """ { "$hybrid" : { "$vectorize" : "vectorize sort", "$lexical" : null} } """, - new FindAndRerankSort("vectorize sort", null, null)), + new FindAndRerankSort( + "vectorize sort", null, null, CommandFeatures.of(HYBRID, VECTORIZE, LEXICAL))), Arguments.of( """ { "$hybrid" : { "$vectorize" : "vectorize sort", "$lexical" : ""} } """, - new FindAndRerankSort("vectorize sort", null, null)), + new FindAndRerankSort( + "vectorize sort", null, null, CommandFeatures.of(HYBRID, VECTORIZE, LEXICAL))), Arguments.of( """ { "$hybrid" : { "$vectorize" : "vectorize sort"} } """, - new FindAndRerankSort("vectorize sort", null, null)), + new FindAndRerankSort( + 
"vectorize sort", null, null, CommandFeatures.of(HYBRID, VECTORIZE))), // ---- // $vectorize variations Arguments.of( """ { "$hybrid" : { "$vectorize" : "vectorize sort", "$lexical" : "lexical sort"} } """, - new FindAndRerankSort("vectorize sort", "lexical sort", null)), + new FindAndRerankSort( + "vectorize sort", + "lexical sort", + null, + CommandFeatures.of(HYBRID, LEXICAL, VECTORIZE))), Arguments.of( """ { "$hybrid" : { "$vectorize" : null, "$lexical" : "lexical sort"} } """, - new FindAndRerankSort(null, "lexical sort", null)), + new FindAndRerankSort( + null, "lexical sort", null, CommandFeatures.of(HYBRID, LEXICAL, VECTORIZE))), Arguments.of( """ { "$hybrid" : { "$vectorize" : "", "$lexical" : "lexical sort"} } """, - new FindAndRerankSort(null, "lexical sort", null)), + new FindAndRerankSort( + null, "lexical sort", null, CommandFeatures.of(HYBRID, LEXICAL, VECTORIZE))), Arguments.of( """ { "$hybrid" : {"$lexical" : "lexical sort"} } """, - new FindAndRerankSort(null, "lexical sort", null)), + new FindAndRerankSort(null, "lexical sort", null, CommandFeatures.of(HYBRID, LEXICAL))), // ---- // $vector variations Arguments.of( """ { "$hybrid" : { "$vectorize" : "vectorize", "$lexical" : "lexical", "$vector" : null} } """, - new FindAndRerankSort("vectorize", "lexical", null)), + new FindAndRerankSort( + "vectorize", + "lexical", + null, + CommandFeatures.of(HYBRID, LEXICAL, VECTORIZE, VECTOR))), Arguments.of( """ { "$hybrid" : { "$vectorize" : "vectorize", "$lexical" : "lexical", "$vector" : [0.1, 0.2, 0.3]} } """, - new FindAndRerankSort("vectorize", "lexical", new float[] {0.1f, 0.2f, 0.3f})), + new FindAndRerankSort( + "vectorize", + "lexical", + new float[] {0.1f, 0.2f, 0.3f}, + CommandFeatures.of(HYBRID, LEXICAL, VECTORIZE, VECTOR))), Arguments.of( """ { "$hybrid" : { "$vectorize" : "vectorize", "$lexical" : "lexical", "$vector" : {"$binary": "%s"}} } """ .formatted(emptyVectorBase64), - new FindAndRerankSort("vectorize", "lexical", emptyVector)), + new 
FindAndRerankSort( + "vectorize", + "lexical", + emptyVector, + CommandFeatures.of(HYBRID, LEXICAL, VECTORIZE, VECTOR))), Arguments.of( """ { "$hybrid" : { "$vectorize" : "vectorize", "$lexical" : "lexical", "$vector" : {"$binary": "%s"}} } """ .formatted(vectorBase64), - new FindAndRerankSort("vectorize", "lexical", vector))); + new FindAndRerankSort( + "vectorize", + "lexical", + vector, + CommandFeatures.of(HYBRID, LEXICAL, VECTORIZE, VECTOR)))); } @ParameterizedTest diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindOneAndDeleteCommandTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindOneAndDeleteCommandTest.java index 0c1fb0dbca..c81bc9c9bc 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindOneAndDeleteCommandTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindOneAndDeleteCommandTest.java @@ -6,6 +6,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.quarkus.test.junit.QuarkusTest; import io.quarkus.test.junit.TestProfile; +import io.stargate.sgv2.jsonapi.TestConstants; import io.stargate.sgv2.jsonapi.api.model.command.Command; import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause; import io.stargate.sgv2.jsonapi.testresource.NoGlobalResourcesTestProfile; @@ -21,6 +22,8 @@ public class FindOneAndDeleteCommandTest { @Inject Validator validator; + private final TestConstants testConstants = new TestConstants(); + @Nested class Validation { @Test @@ -40,7 +43,7 @@ public void happyPath() throws Exception { .isInstanceOfSatisfying( FindOneAndDeleteCommand.class, findOneAndDeleteCommand -> { - assertThat(findOneAndDeleteCommand.filterSpec()).isNotNull(); + assertThat(findOneAndDeleteCommand.filterDefinition()).isNotNull(); }); } @@ -62,8 +65,9 @@ public void withSort() throws Exception { .isInstanceOfSatisfying( FindOneAndDeleteCommand.class, findOneAndDeleteCommand -> { - 
assertThat(findOneAndDeleteCommand.filterSpec()).isNotNull(); - final SortClause sortClause = findOneAndDeleteCommand.sortClause(); + assertThat(findOneAndDeleteCommand.filterDefinition()).isNotNull(); + final SortClause sortClause = + findOneAndDeleteCommand.sortClause(testConstants.collectionContext()); assertThat(sortClause).isNotNull(); assertThat(sortClause) .satisfies( @@ -92,8 +96,9 @@ public void sortAndProject() throws Exception { .isInstanceOfSatisfying( FindOneAndDeleteCommand.class, findOneAndDeleteCommand -> { - assertThat(findOneAndDeleteCommand.filterSpec()).isNotNull(); - final SortClause sortClause = findOneAndDeleteCommand.sortClause(); + assertThat(findOneAndDeleteCommand.filterDefinition()).isNotNull(); + final SortClause sortClause = + findOneAndDeleteCommand.sortClause(testConstants.collectionContext()); assertThat(sortClause).isNotNull(); assertThat(sortClause) .satisfies( diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindOneAndReplaceCommandTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindOneAndReplaceCommandTest.java index 72ddd7eed6..67d4b5187a 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindOneAndReplaceCommandTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/impl/FindOneAndReplaceCommandTest.java @@ -43,7 +43,7 @@ public void happyPath() throws Exception { .isInstanceOfSatisfying( FindOneAndReplaceCommand.class, findOneAndReplaceCommand -> { - assertThat(findOneAndReplaceCommand.filterSpec()).isNotNull(); + assertThat(findOneAndReplaceCommand.filterDefinition()).isNotNull(); final JsonNode replacementDocument = findOneAndReplaceCommand.replacementDocument(); assertThat(replacementDocument).isNotNull(); final FindOneAndReplaceCommand.Options options = findOneAndReplaceCommand.options(); @@ -71,7 +71,7 @@ public void withSortAndOptions() throws Exception { .isInstanceOfSatisfying( FindOneAndReplaceCommand.class, 
findOneAndReplaceCommand -> { - assertThat(findOneAndReplaceCommand.filterSpec()).isNotNull(); + assertThat(findOneAndReplaceCommand.filterDefinition()).isNotNull(); final JsonNode replacementDocument = findOneAndReplaceCommand.replacementDocument(); assertThat(replacementDocument).isNotNull(); final FindOneAndReplaceCommand.Options options = findOneAndReplaceCommand.options(); diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/impl/HybridLimitsDeserializerTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/impl/HybridLimitsDeserializerTest.java index 258d609fe1..8598443426 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/impl/HybridLimitsDeserializerTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/impl/HybridLimitsDeserializerTest.java @@ -9,6 +9,8 @@ import com.fasterxml.jackson.databind.DeserializationContext; import com.fasterxml.jackson.databind.JsonMappingException; import com.fasterxml.jackson.databind.ObjectMapper; +import io.stargate.sgv2.jsonapi.metrics.CommandFeature; +import io.stargate.sgv2.jsonapi.metrics.CommandFeatures; import io.stargate.sgv2.jsonapi.util.recordable.PrettyPrintable; import java.util.stream.Stream; import org.junit.jupiter.api.BeforeEach; @@ -40,10 +42,10 @@ void setUp() { @Test public void testEqualsAndHash() { - var value1 = new FindAndRerankCommand.HybridLimits(10, 10); + var value1 = new FindAndRerankCommand.HybridLimits(10, 10, CommandFeatures.EMPTY); - var diffVector = new FindAndRerankCommand.HybridLimits(20, 10); - var diffLexical = new FindAndRerankCommand.HybridLimits(10, 20); + var diffVector = new FindAndRerankCommand.HybridLimits(20, 10, CommandFeatures.EMPTY); + var diffLexical = new FindAndRerankCommand.HybridLimits(10, 20, CommandFeatures.EMPTY); assertThat(value1).as("Object equals self").isEqualTo(value1); assertThat(value1).as("different vector limit").isNotEqualTo(diffVector); @@ -81,24 +83,34 @@ private static Stream 
validLimitsTestCases() { """ 99 """, - new FindAndRerankCommand.HybridLimits(99, 99)), + new FindAndRerankCommand.HybridLimits( + 99, 99, CommandFeatures.of(CommandFeature.HYBRID_LIMITS_NUMBER))), Arguments.of( """ 0 """, - new FindAndRerankCommand.HybridLimits(0, 0)), + new FindAndRerankCommand.HybridLimits( + 0, 0, CommandFeatures.of(CommandFeature.HYBRID_LIMITS_NUMBER))), // ---- // all must tbe provided for the object form Arguments.of( """ { "$vector" : 99, "$lexical" : 99} """, - new FindAndRerankCommand.HybridLimits(99, 99)), + new FindAndRerankCommand.HybridLimits( + 99, + 99, + CommandFeatures.of( + CommandFeature.HYBRID_LIMITS_VECTOR, CommandFeature.HYBRID_LIMITS_LEXICAL))), Arguments.of( """ { "$vector" : 9, "$lexical" : 99} """, - new FindAndRerankCommand.HybridLimits(9, 99))); + new FindAndRerankCommand.HybridLimits( + 9, + 99, + CommandFeatures.of( + CommandFeature.HYBRID_LIMITS_VECTOR, CommandFeature.HYBRID_LIMITS_LEXICAL)))); } @ParameterizedTest diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/impl/UpdateOneCommandTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/impl/UpdateOneCommandTest.java index ed920c3248..bd2e6689fd 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/impl/UpdateOneCommandTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/model/command/impl/UpdateOneCommandTest.java @@ -2,16 +2,17 @@ import static org.assertj.core.api.Assertions.assertThat; +import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import io.quarkus.test.junit.QuarkusTest; import io.quarkus.test.junit.TestProfile; import io.stargate.sgv2.jsonapi.api.model.command.Command; -import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause; import io.stargate.sgv2.jsonapi.api.model.command.clause.update.UpdateClause; import io.stargate.sgv2.jsonapi.testresource.NoGlobalResourcesTestProfile; import jakarta.inject.Inject; import 
jakarta.validation.ConstraintViolation; import jakarta.validation.Validator; +import java.io.IOException; import java.util.Set; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -46,17 +47,14 @@ public void happyPath() throws Exception { .isInstanceOfSatisfying( UpdateOneCommand.class, updateOneCommand -> { - assertThat(updateOneCommand.filterSpec()).isNotNull(); + assertThat(updateOneCommand.filterDefinition()).isNotNull(); final UpdateClause updateClause = updateOneCommand.updateClause(); assertThat(updateClause).isNotNull(); assertThat(updateClause.buildOperations()).hasSize(1); - final SortClause sortClause = updateOneCommand.sortClause(); - assertThat(sortClause).isNotNull(); - assertThat(sortClause.sortExpressions()).hasSize(1); - assertThat(sortClause.sortExpressions().get(0).path()).isEqualTo("username"); - assertThat(sortClause.sortExpressions().get(0).ascending()).isTrue(); - final UpdateOneCommand.Options options = updateOneCommand.options(); - assertThat(options).isNotNull(); + assertThat(updateOneCommand.sortDefinition()).isNotNull(); + assertThat(updateOneCommand.sortDefinition().json()) + .isEqualTo(readTree("{\"username\" : 1}")); + assertThat(updateOneCommand.options()).isNotNull(); }); } @@ -80,4 +78,12 @@ public void noUpdateClause() throws Exception { .contains("must not be null"); } } + + private JsonNode readTree(String json) { + try { + return objectMapper.readTree(json); + } catch (IOException e) { + throw new RuntimeException("Failed to parse JSON", e); + } + } } diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/request/EmbeddingCredentialsSupplierTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/request/EmbeddingCredentialsSupplierTest.java new file mode 100644 index 0000000000..1ad90e4b31 --- /dev/null +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/request/EmbeddingCredentialsSupplierTest.java @@ -0,0 +1,184 @@ +package io.stargate.sgv2.jsonapi.api.request; + +import static 
org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import io.quarkus.test.junit.QuarkusTest; +import io.quarkus.test.junit.TestProfile; +import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; +import io.stargate.sgv2.jsonapi.testresource.NoGlobalResourcesTestProfile; +import jakarta.inject.Inject; +import java.util.Collections; +import java.util.Map; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +@QuarkusTest +@TestProfile(NoGlobalResourcesTestProfile.Impl.class) +public class EmbeddingCredentialsSupplierTest { + + @Inject HttpConstants httpConstants; + + RequestContext requestContext; + RequestContext.HttpHeaderAccess httpHeaderAccess; + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig; + EmbeddingProvidersConfig.EmbeddingProviderConfig.AuthenticationConfig noneAuthConfig; + EmbeddingCredentialsSupplier supplier; + + @BeforeEach + void setUp() { + requestContext = mock(RequestContext.class); + httpHeaderAccess = mock(RequestContext.HttpHeaderAccess.class); + providerConfig = mock(EmbeddingProvidersConfig.EmbeddingProviderConfig.class); + noneAuthConfig = + mock(EmbeddingProvidersConfig.EmbeddingProviderConfig.AuthenticationConfig.class); + + supplier = + new EmbeddingCredentialsSupplier( + httpConstants.authToken(), + httpConstants.embeddingApiKey(), + httpConstants.embeddingAccessId(), + httpConstants.embeddingSecretId()); + + when(requestContext.getHttpHeaders()).thenReturn(httpHeaderAccess); + } + + @Test + public void shouldPassThroughAuthTokenWhenAllConditionsMet() { + // not providing "x-embedding-api-key" and use auth token + when(httpHeaderAccess.getHeader(httpConstants.embeddingApiKey())).thenReturn(null); + when(httpHeaderAccess.getHeader(httpConstants.authToken())).thenReturn("astra-auth-token"); + // Provider config 
is available, Provider supports NONE auth and it's enabled + when(providerConfig.supportedAuthentications()) + .thenReturn( + Collections.singletonMap( + EmbeddingProvidersConfig.EmbeddingProviderConfig.AuthenticationType.NONE, + noneAuthConfig)); + when(noneAuthConfig.enabled()).thenReturn(true); + // Provider has authTokenPassThroughForNoneAuth set to true + when(providerConfig.authTokenPassThroughForNoneAuth()).thenReturn(true); + // no collection auth config + supplier.withAuthConfigFromCollection(null); + + // Act + EmbeddingCredentials credentials = supplier.create(requestContext, providerConfig); + + // Assert + assertThat(credentials.apiKey()).contains("astra-auth-token"); + assertThat(credentials.accessId()).isEmpty(); + assertThat(credentials.secretId()).isEmpty(); + } + + @Test + void shouldUseExplicitEmbeddingApiKeyWhenProvided() { + // Arrange + when(httpHeaderAccess.getHeader(httpConstants.embeddingApiKey())) + .thenReturn("embedding-api-key"); + when(httpHeaderAccess.getHeader(httpConstants.embeddingAccessId())).thenReturn("access-id"); + when(httpHeaderAccess.getHeader(httpConstants.embeddingSecretId())).thenReturn("secret-id"); + + // Act + EmbeddingCredentials credentials = supplier.create(requestContext, providerConfig); + + // Assert + assertThat(credentials.apiKey()).contains("embedding-api-key"); + assertThat(credentials.accessId()).contains("access-id"); + assertThat(credentials.secretId()).contains("secret-id"); + } + + @Test + void shouldUseExplicitEmbeddingApiKeyIfPassThroughIsFalse() { + // Arrange + // User explicitly sends the header, but its value is null (or header not present) + when(httpHeaderAccess.getHeader(httpConstants.embeddingApiKey())).thenReturn(null); + when(httpHeaderAccess.getHeader(httpConstants.embeddingAccessId())).thenReturn("access-id"); + when(httpHeaderAccess.getHeader(httpConstants.embeddingSecretId())).thenReturn("secret-id"); + + // Set passThrough to false, make one condition fail + 
when(providerConfig.authTokenPassThroughForNoneAuth()).thenReturn(false); + + // Act + EmbeddingCredentials credentials = supplier.create(requestContext, providerConfig); + + // Assert + assertThat(credentials.apiKey()).isEmpty(); + assertThat(credentials.accessId()).contains("access-id"); + assertThat(credentials.secretId()).contains("secret-id"); + } + + @Test + void shouldNotPassThroughIfNoneAuthNotSupportedByProvider() { + // Provide auth token + when(httpHeaderAccess.getHeader(httpConstants.authToken())).thenReturn("astra-auth-token"); + when(httpHeaderAccess.getHeader(httpConstants.embeddingApiKey())).thenReturn(null); + // No NONE auth + when(providerConfig.supportedAuthentications()).thenReturn(Collections.emptyMap()); + when(providerConfig.authTokenPassThroughForNoneAuth()).thenReturn(true); + + // Act + EmbeddingCredentials credentials = supplier.create(requestContext, providerConfig); + + // Assert - not replaced with auth token + assertThat(credentials.apiKey()).isEmpty(); + } + + @Test + void shouldNotPassThroughIfNoneAuthSupportedButNotEnabled() { + // Provide auth token + when(httpHeaderAccess.getHeader(httpConstants.authToken())).thenReturn("astra-auth-token"); + when(httpHeaderAccess.getHeader(httpConstants.embeddingApiKey())).thenReturn(null); + // NONE auth present but disabled + when(providerConfig.supportedAuthentications()) + .thenReturn( + Collections.singletonMap( + EmbeddingProvidersConfig.EmbeddingProviderConfig.AuthenticationType.NONE, + noneAuthConfig)); + when(noneAuthConfig.enabled()).thenReturn(false); + when(providerConfig.authTokenPassThroughForNoneAuth()).thenReturn(true); + + // Act + EmbeddingCredentials credentials = supplier.create(requestContext, providerConfig); + + // Assert - not replaced with auth token + assertThat(credentials.apiKey()).isEmpty(); + } + + @Test + void shouldNotPassThroughIfCollectionHasItsOwnAuthConfig() { + // Provide auth token + 
when(httpHeaderAccess.getHeader(httpConstants.authToken())).thenReturn("astra-auth-token"); + when(httpHeaderAccess.getHeader(httpConstants.embeddingApiKey())).thenReturn(null); + // Provider config is available, Provider supports NONE auth and it's enabled + when(providerConfig.supportedAuthentications()) + .thenReturn( + Collections.singletonMap( + EmbeddingProvidersConfig.EmbeddingProviderConfig.AuthenticationType.NONE, + noneAuthConfig)); + when(noneAuthConfig.enabled()).thenReturn(true); + // Provider has authTokenPassThroughForNoneAuth set to true + when(providerConfig.authTokenPassThroughForNoneAuth()).thenReturn(true); + // Collection has some auth config + supplier.withAuthConfigFromCollection(Map.of("providerKey", "shared_creds.providerKey")); + + // Act + EmbeddingCredentials credentials = supplier.create(requestContext, providerConfig); + + // Assert - not replaced with auth token + assertThat(credentials.apiKey()).isEmpty(); + } + + @Test + void shouldNotPassThroughIfProviderConfigIsNull() { + // Arrange + when(httpHeaderAccess.getHeader(httpConstants.embeddingApiKey())).thenReturn(null); + when(httpHeaderAccess.getHeader(httpConstants.authToken())).thenReturn("stargate-auth-token"); + + // Act: Pass null for providerConfig + EmbeddingCredentials credentials = supplier.create(requestContext, null); + + // Assert - not replaced with auth token + assertThat(credentials.apiKey()).isEmpty(); // Falls back to standard credential resolution + } +} diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/AbstractKeyspaceIntegrationTestBase.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/AbstractKeyspaceIntegrationTestBase.java index b9e4e0cb0b..33bbde0683 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/AbstractKeyspaceIntegrationTestBase.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/AbstractKeyspaceIntegrationTestBase.java @@ -6,8 +6,10 @@ import static org.assertj.core.api.Assertions.assertThat; import static 
org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.nullValue; +import static org.junit.jupiter.api.Assertions.fail; import com.datastax.oss.driver.api.core.CqlSession; +import com.datastax.oss.driver.api.core.CqlSessionBuilder; import com.datastax.oss.driver.api.core.cql.SimpleStatement; import com.fasterxml.jackson.core.Base64Variants; import io.restassured.RestAssured; @@ -16,8 +18,6 @@ import io.restassured.specification.RequestSpecification; import io.stargate.sgv2.jsonapi.api.v1.util.IntegrationTestUtils; import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; -import io.stargate.sgv2.jsonapi.service.cqldriver.CQLSessionCache; -import io.stargate.sgv2.jsonapi.service.cqldriver.TenantAwareCqlSessionBuilder; import io.stargate.sgv2.jsonapi.service.embedding.operation.test.CustomITEmbeddingProvider; import io.stargate.sgv2.jsonapi.testresource.StargateTestResource; import io.stargate.sgv2.jsonapi.util.Base64Util; @@ -234,15 +234,26 @@ public static void checkShouldAbsentMetrics(String commandName) { public static void checkDriverMetricsTenantId() { String metrics = given().when().get("/metrics").then().statusCode(200).extract().asString(); + // Example line + // session_cql_requests_seconds{module="sgv2-jsonapi",session="default_tenant",quantile="0.5",} + // 0.238944256 + Optional sessionLevelDriverMetricTenantId = metrics .lines() .filter( line -> - line.startsWith("session_cql_requests_seconds_bucket") - && line.contains("tenant")) + line.startsWith("session_cql_requests_seconds") && line.contains("session=")) .findFirst(); - assertThat(sessionLevelDriverMetricTenantId.isPresent()).isTrue(); + if (!sessionLevelDriverMetricTenantId.isPresent()) { + List lines = metrics.lines().toList(); + long buckets = + lines.stream().filter(line -> line.startsWith("session_cql_requests_seconds")).count(); + fail( + String.format( + "No tenant id found in any of 'session_cql_requests_seconds' entries (%d buckets; %d log lines)", + buckets, lines.size())); + 
} } public static void checkVectorMetrics(String commandName, String sortType) { @@ -280,11 +291,6 @@ public static void checkIndexUsageMetrics(String commandName, boolean vector) { assertThat(countMetrics.size()).isGreaterThan(0); } - /** Utility method for reducing boilerplate code for sending JSON commands */ - protected RequestSpecification givenHeadersAndJson(String json) { - return given().headers(getHeaders()).contentType(ContentType.JSON).body(json); - } - protected String generateBase64EncodedBinaryVector(float[] vector) { { final byte[] byteArray = CqlVectorUtil.floatsToBytes(vector); @@ -327,10 +333,11 @@ private CqlSession createDriverSession() { } else { dc = "datacenter1"; } - var builder = new TenantAwareCqlSessionBuilder("IntegrationTest").withLocalDatacenter(dc); - builder - .addContactPoint(new InetSocketAddress("localhost", port)) - .withAuthCredentials(CQLSessionCache.CASSANDRA, CQLSessionCache.CASSANDRA); + var builder = + new CqlSessionBuilder() + .withLocalDatacenter(dc) + .addContactPoint(new InetSocketAddress("localhost", port)) + .withAuthCredentials("cassandra", "cassandra"); // default admin password :) return builder.build(); } @@ -339,6 +346,16 @@ protected boolean isLexicalAvailableForDB() { return !"true".equals(System.getProperty("testing.db.lexical-disabled")); } + /** Utility method for reducing boilerplate code for sending JSON commands */ + protected RequestSpecification givenHeaders() { + return given().headers(getHeaders()).contentType(ContentType.JSON); + } + + /** Utility method for reducing boilerplate code for sending JSON commands */ + protected RequestSpecification givenHeadersAndJson(String json) { + return givenHeaders().body(json); + } + /** Utility method for reducing boilerplate code for sending JSON commands */ protected ValidatableResponse givenHeadersPostJsonThen(String json) { return givenHeadersAndJson(json).when().post(KeyspaceResource.BASE_PATH, keyspaceName).then(); diff --git 
a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CollectionResourceIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CollectionResourceIntegrationTest.java index d9a434c75c..f74ab898e4 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CollectionResourceIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CollectionResourceIntegrationTest.java @@ -21,7 +21,7 @@ class ClientErrors { @Test public void tokenMissing() { - given() + given() // NOTE: not passing headers, on purpose .contentType(ContentType.JSON) .body("{}") .when() @@ -37,10 +37,7 @@ public void tokenMissing() { @Test public void malformedBody() { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body("wrong") + givenHeadersAndJson("wrong") .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -60,18 +57,13 @@ public void malformedBody() { @Test public void unknownCommand() { - String json = - """ + givenHeadersAndJson( + """ { "unknownCommand": { } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -88,19 +80,14 @@ public void unknownCommand() { @Test public void unknownCommandField() { - String json = - """ + givenHeadersAndJson( + """ { "findOne": { "unknown": "value" } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -118,9 +105,8 @@ public void unknownCommandField() { @Test public void emptyBody() { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) + // Note: no body specified + givenHeaders() .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionBackwardCompatibilityIntegrationTest.java 
b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionBackwardCompatibilityIntegrationTest.java new file mode 100644 index 0000000000..860cc4ff4b --- /dev/null +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionBackwardCompatibilityIntegrationTest.java @@ -0,0 +1,205 @@ +package io.stargate.sgv2.jsonapi.api.v1; + +import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsDDLSuccess; +import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsStatusOnly; +import static net.javacrumbs.jsonunit.JsonMatchers.jsonEquals; +import static org.assertj.core.api.Assertions.assertThat; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; + +import com.datastax.oss.driver.api.core.cql.SimpleStatement; +import io.quarkus.test.common.WithTestResource; +import io.quarkus.test.junit.QuarkusIntegrationTest; +import io.stargate.sgv2.jsonapi.testresource.DseTestResource; +import org.junit.jupiter.api.*; + +@QuarkusIntegrationTest +@WithTestResource(value = DseTestResource.class, restrictToAnnotatedClass = false) +@TestClassOrder(ClassOrderer.OrderAnnotation.class) +public class CreateCollectionBackwardCompatibilityIntegrationTest + extends AbstractKeyspaceIntegrationTestBase { + + @Nested + @TestMethodOrder(MethodOrderer.OrderAnnotation.class) + class CreateCollectionWithLexicalRerankBackwardCompatibility { + private static final String PRE_LEXICAL_RERANK_COLLECTION_NAME = + "pre_lexical_rerank_collection"; + + @Test + @Order(1) + public final void createPreLexicalRerankCollection() { + // NOTE(2025/04/17): Using raw CQL here to precisely simulate the schema state before + // lexical/rerank options were introduced in collection comments. It would be better to use + // non-test code to generate this, but it's embedded in the CreateCollectionOperation. 
Need to + // change in the future + String collectionWithoutLexicalRerank = + """ + CREATE TABLE IF NOT EXISTS "%s"."%s" ( + key frozen> PRIMARY KEY, + array_contains set, + array_size map, + doc_json text, + exist_keys set, + query_bool_values map, + query_dbl_values map, + query_null_values set, + query_text_values map, + query_timestamp_values map, + query_vector_value vector, + tx_id timeuuid + ) WITH comment = '{"collection":{"name":"%s","schema_version":1,"options":{"defaultId":{"type":""}}}}'; + """; + executeCqlStatement( + SimpleStatement.newInstance( + collectionWithoutLexicalRerank.formatted( + keyspaceName, + PRE_LEXICAL_RERANK_COLLECTION_NAME, + PRE_LEXICAL_RERANK_COLLECTION_NAME))); + + // create indexes for the collection + String[] createIndexCqls = { + String.format( + "CREATE CUSTOM INDEX IF NOT EXISTS %s_array_contains ON \"%s\".\"%s\" (values(array_contains)) USING 'StorageAttachedIndex';", + PRE_LEXICAL_RERANK_COLLECTION_NAME, keyspaceName, PRE_LEXICAL_RERANK_COLLECTION_NAME), + String.format( + "CREATE CUSTOM INDEX IF NOT EXISTS %s_array_size ON \"%s\".\"%s\" (entries(array_size)) USING 'StorageAttachedIndex';", + PRE_LEXICAL_RERANK_COLLECTION_NAME, keyspaceName, PRE_LEXICAL_RERANK_COLLECTION_NAME), + String.format( + "CREATE CUSTOM INDEX IF NOT EXISTS %s_exists_keys ON \"%s\".\"%s\" (values(exist_keys)) USING 'StorageAttachedIndex';", + PRE_LEXICAL_RERANK_COLLECTION_NAME, keyspaceName, PRE_LEXICAL_RERANK_COLLECTION_NAME), + String.format( + "CREATE CUSTOM INDEX IF NOT EXISTS %s_query_bool_values ON \"%s\".\"%s\" (entries(query_bool_values)) USING 'StorageAttachedIndex';", + PRE_LEXICAL_RERANK_COLLECTION_NAME, keyspaceName, PRE_LEXICAL_RERANK_COLLECTION_NAME), + String.format( + "CREATE CUSTOM INDEX IF NOT EXISTS %s_query_dbl_values ON \"%s\".\"%s\" (entries(query_dbl_values)) USING 'StorageAttachedIndex';", + PRE_LEXICAL_RERANK_COLLECTION_NAME, keyspaceName, PRE_LEXICAL_RERANK_COLLECTION_NAME), + String.format( + "CREATE CUSTOM INDEX IF NOT 
EXISTS %s_query_null_values ON \"%s\".\"%s\" (values(query_null_values)) USING 'StorageAttachedIndex';", + PRE_LEXICAL_RERANK_COLLECTION_NAME, keyspaceName, PRE_LEXICAL_RERANK_COLLECTION_NAME), + String.format( + "CREATE CUSTOM INDEX IF NOT EXISTS %s_query_text_values ON \"%s\".\"%s\" (entries(query_text_values)) USING 'StorageAttachedIndex';", + PRE_LEXICAL_RERANK_COLLECTION_NAME, keyspaceName, PRE_LEXICAL_RERANK_COLLECTION_NAME), + String.format( + "CREATE CUSTOM INDEX IF NOT EXISTS %s_query_timestamp_values ON \"%s\".\"%s\" (entries(query_timestamp_values)) USING 'StorageAttachedIndex';", + PRE_LEXICAL_RERANK_COLLECTION_NAME, keyspaceName, PRE_LEXICAL_RERANK_COLLECTION_NAME) + }; + for (String indexCql : createIndexCqls) { + assertThat(executeCqlStatement(SimpleStatement.newInstance(indexCql))).isTrue(); + } + + // verify the collection using FindCollection + givenHeadersPostJsonThenOkNoErrors( + """ + { + "findCollections": { + "options" : { + "explain": true + } + } + } + """) + .body("$", responseIsDDLSuccess()) + .body("status.collections", hasSize(1)) + .body( + "status.collections[0]", + jsonEquals( + """ + { + "name": "%s", + "options": { + "lexical": { + "enabled": false + }, + "rerank": { + "enabled": false + } + } + } + """ + .formatted(PRE_LEXICAL_RERANK_COLLECTION_NAME))); + } + + @Test + @Order(2) + public final void createCollectionWithoutLexicalRerankUsingAPI() { + // Can only test if we have BM25 support by backend, otherwise skip the test + Assumptions.assumeTrue(isLexicalAvailableForDB()); + + // verify the preexisting collection(generated by the above CQL) using FindCollection + givenHeadersPostJsonThenOkNoErrors( + """ + { + "findCollections": { + "options" : { + "explain": true + } + } + } + """) + .body("$", responseIsDDLSuccess()) + .body("status.collections", hasSize(1)) + .body( + "status.collections[0]", + jsonEquals( + """ + { + "name": "%s", + "options": { + "lexical": { + "enabled": false + }, + "rerank": { + "enabled": false + } + } 
+ } + """ + .formatted(PRE_LEXICAL_RERANK_COLLECTION_NAME))); + + // create the same collection using API - should not have + // EXISTING_COLLECTION_DIFFERENT_SETTINGS error + givenHeadersPostJsonThenOkNoErrors( + """ + { + "createCollection": { + "name": "%s" + } + } + """ + .formatted(PRE_LEXICAL_RERANK_COLLECTION_NAME)) + .body("$", responseIsStatusOnly()) + .body("status.ok", is(1)); + + // verify the collection using FindCollection again + givenHeadersPostJsonThenOkNoErrors( + """ + { + "findCollections": { + "options" : { + "explain": true + } + } + } + """) + .body("$", responseIsDDLSuccess()) + .body("status.collections", hasSize(1)) + .body( + "status.collections[0]", + jsonEquals( + """ + { + "name": "%s", + "options": { + "lexical": { + "enabled": false + }, + "rerank": { + "enabled": false + } + } + } + """ + .formatted(PRE_LEXICAL_RERANK_COLLECTION_NAME))); + } + } +} diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionIntegrationTest.java index f580697325..78afd91e95 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionIntegrationTest.java @@ -1,15 +1,18 @@ package io.stargate.sgv2.jsonapi.api.v1; -import static io.restassured.RestAssured.given; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.*; import static org.hamcrest.Matchers.*; import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; -import io.restassured.http.ContentType; +import io.stargate.sgv2.jsonapi.exception.SchemaException; import io.stargate.sgv2.jsonapi.testresource.DseTestResource; +import java.util.stream.Stream; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.*; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import 
org.junit.jupiter.params.provider.MethodSource; @QuarkusIntegrationTest @WithTestResource(value = DseTestResource.class, restrictToAnnotatedClass = false) @@ -73,24 +76,15 @@ class CreateCollection { @Test public void happyPath() { final String collectionName = "col" + RandomStringUtils.randomNumeric(16); - String json = - """ + givenHeadersPostJsonThenOk( + """ { "createCollection": { "name": "%s" } } """ - .formatted(collectionName); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + .formatted(collectionName)) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); deleteCollection(collectionName); @@ -98,45 +92,27 @@ public void happyPath() { @Test public void caseSensitive() { - String json = - """ + givenHeadersPostJsonThenOk( + """ { "createCollection": { "name": "%s" } } """ - .formatted("testcollection"); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + .formatted("testcollection")) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); - json = - """ + givenHeadersPostJsonThenOk( + """ { "createCollection": { "name": "%s" } } """ - .formatted("testCollection"); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + .formatted("testCollection")) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); deleteCollection("testcollection"); @@ -146,37 +122,17 @@ public void caseSensitive() { @Test public void duplicateNonVectorCollectionName() { // create a non vector collection - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(createNonVectorCollectionJson) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + 
givenHeadersPostJsonThenOk(createNonVectorCollectionJson) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); // recreate the same non vector collection - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(createNonVectorCollectionJson) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOk(createNonVectorCollectionJson) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); // create a vector collection with the same name - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(createVectorCollection) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() + givenHeadersPostJsonThenOk(createVectorCollection) .body("$", responseIsError()) .body("errors[0].exceptionClass", is("JsonApiException")) .body("errors[0].errorCode", is("EXISTING_COLLECTION_DIFFERENT_SETTINGS")) @@ -191,35 +147,15 @@ public void duplicateNonVectorCollectionName() { @Test public void duplicateVectorCollectionName() { // create a vector collection - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(createVectorCollection) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOk(createVectorCollection) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); // recreate the same vector collection - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(createVectorCollection) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOk(createVectorCollection) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); // create a non vector collection with the same name - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(createNonVectorCollectionJson) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() + 
givenHeadersPostJsonThenOk(createNonVectorCollectionJson) .body("$", responseIsError()) .body("errors[0].exceptionClass", is("JsonApiException")) .body("errors[0].errorCode", is("EXISTING_COLLECTION_DIFFERENT_SETTINGS")) @@ -234,25 +170,11 @@ public void duplicateVectorCollectionName() { @Test public void duplicateVectorCollectionNameWithDiffSetting() { // create a vector collection - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(createVectorCollection) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOk(createVectorCollection) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); // create another vector collection with the same name but different size setting - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(createVectorCollectionWithOtherSizeSettings) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOk(createVectorCollectionWithOtherSizeSettings) .body("$", responseIsError()) .body("errors[0].exceptionClass", is("JsonApiException")) .body("errors[0].errorCode", is("EXISTING_COLLECTION_DIFFERENT_SETTINGS")) @@ -262,14 +184,7 @@ public void duplicateVectorCollectionNameWithDiffSetting() { "trying to create Collection ('simple_collection') with different settings")); // create another vector collection with the same name but different function setting - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(createVectorCollectionWithOtherFunctionSettings) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOk(createVectorCollectionWithOtherFunctionSettings) .body("$", responseIsError()) .body("errors[0].exceptionClass", is("JsonApiException")) .body("errors[0].errorCode", is("EXISTING_COLLECTION_DIFFERENT_SETTINGS")) @@ -302,26 +217,12 @@ public void happyCreateCollectionWithIndexingAllow() { 
"""; // create vector collection with indexing allow option - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(createCollectionRequest) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOk(createCollectionRequest) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); // Also: should be idempotent so try creating again - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(createCollectionRequest) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOk(createCollectionRequest) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); @@ -349,26 +250,12 @@ public void happyCreateCollectionWithIndexingDeny() { } """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(createCollectionRequest) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOk(createCollectionRequest) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); // Also: should be idempotent so try creating again - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(createCollectionRequest) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOk(createCollectionRequest) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); @@ -378,10 +265,7 @@ public void happyCreateCollectionWithIndexingDeny() { // Test to ensure single "*" accepted for "allow" or "deny" but not both @Test public void createCollectionWithIndexingStar() { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -394,19 +278,12 @@ public void createCollectionWithIndexingStar() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", 
responseIsDDLSuccess()) .body("status.ok", is(1)); deleteCollection("simple_collection_indexing_allow_star"); // create vector collection with indexing deny option - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -419,19 +296,12 @@ public void createCollectionWithIndexingStar() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); deleteCollection("simple_collection_indexing_deny_star"); // And then check that we can't use both - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -445,10 +315,6 @@ public void createCollectionWithIndexingStar() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsError()) .body( "errors[0].message", @@ -464,10 +330,7 @@ class CreateCollectionFail { @Test public void failCreateCollectionWithIndexHavingDuplicates() { // create vector collection with error indexing option - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -480,10 +343,6 @@ public void failCreateCollectionWithIndexHavingDuplicates() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsError()) .body( "errors[0].message", @@ -495,10 +354,7 @@ public void failCreateCollectionWithIndexHavingDuplicates() { @Test public void failCreateCollectionWithIndexHavingAllowAndDeny() { // create vector collection with error indexing option - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -512,10 +368,6 @@ public void failCreateCollectionWithIndexHavingAllowAndDeny() { } } """) - .when() - 
.post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsError()) .body( "errors[0].message", @@ -529,10 +381,7 @@ public void failCreateCollectionWithIndexHavingAllowAndDeny() { @Test public void failWithInvalidNameInIndexingDeny() { // create a vector collection - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( // Dollars not allowed in regular field names (can only start operators) """ { @@ -546,10 +395,6 @@ public void failWithInvalidNameInIndexingDeny() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsError()) .body( "errors[0].message", @@ -608,10 +453,7 @@ public void failWithInvalidEscapeCharacterInIndexingDeny() { @Test public void failWithInvalidMainLevelOption() { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -622,10 +464,6 @@ public void failWithInvalidMainLevelOption() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsError()) .body("errors", hasSize(1)) .body("errors[0].errorCode", is("INVALID_CREATE_COLLECTION_OPTIONS")) @@ -638,10 +476,7 @@ public void failWithInvalidMainLevelOption() { @Test public void failWithInvalidIdConfigOption() { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -654,10 +489,6 @@ public void failWithInvalidIdConfigOption() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsError()) .body("errors", hasSize(1)) .body("errors[0].errorCode", is("INVALID_CREATE_COLLECTION_OPTIONS")) @@ -670,10 +501,7 @@ public void failWithInvalidIdConfigOption() { @Test public void failWithInvalidIndexingConfigOption() { - given() - .headers(getHeaders()) - 
.contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -686,10 +514,6 @@ public void failWithInvalidIndexingConfigOption() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsError()) .body("errors", hasSize(1)) .body("errors[0].errorCode", is("INVALID_CREATE_COLLECTION_OPTIONS")) @@ -702,10 +526,7 @@ public void failWithInvalidIndexingConfigOption() { @Test public void failWithInvalidVectorConfigOption() { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -718,10 +539,6 @@ public void failWithInvalidVectorConfigOption() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsError()) .body("errors", hasSize(1)) .body("errors[0].errorCode", is("INVALID_CREATE_COLLECTION_OPTIONS")) @@ -762,26 +579,12 @@ public void happyEmbeddingService() { """; // create vector collection with vector service - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(createCollectionRequest) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOk(createCollectionRequest) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); // Also: should be idempotent so try creating again - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(createCollectionRequest) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOk(createCollectionRequest) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); @@ -791,10 +594,7 @@ public void happyEmbeddingService() { @Test public void failProviderNotSupport() { // create a collection with embedding service provider not support - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + 
givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -818,10 +618,6 @@ public void failProviderNotSupport() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsError()) .body( "errors[0].message", @@ -834,10 +630,7 @@ public void failProviderNotSupport() { @Test public void failUnsupportedModel() { // create a collection with unsupported model name - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -858,11 +651,7 @@ public void failUnsupportedModel() { } } } - """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + """) .body("$", responseIsError()) .body( "errors[0].message", @@ -873,6 +662,48 @@ public void failUnsupportedModel() { } } + private static Stream deprecatedEmbeddingModelSource() { + return Stream.of( + Arguments.of( + "DEPRECATED", + "a-deprecated-nvidia-embedding-model", + SchemaException.Code.DEPRECATED_AI_MODEL), + Arguments.of( + "END_OF_LIFE", + "a-EOL-nvidia-embedding-model", + SchemaException.Code.END_OF_LIFE_AI_MODEL)); + } + + @ParameterizedTest + @MethodSource("deprecatedEmbeddingModelSource") + public void failDeprecatedEOLEmbedModel( + String status, String modelName, SchemaException.Code errorCode) { + givenHeadersPostJsonThenOk( + """ + + { + "createCollection": { + "name": "bad_nvidia_model", + "options": { + "vector": { + "dimension": 1024, + "service": { + "provider": "nvidia", + "modelName": "%s" + } + } + } + } + } + """ + .formatted(modelName)) + .body("$", responseIsError()) + .body( + "errors[0].message", + containsString("The model is: %s. 
It is at %s status".formatted(modelName, status))) + .body("errors[0].errorCode", is(errorCode.name())); + } + @Nested @Order(4) class CreateCollectionWithEmbeddingServiceTestDimension { @@ -912,54 +743,26 @@ public void happyFixDimensionAutoPopulate() { } } } - """; + """; // create vector collection with vector service and no dimension - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(createCollectionWithoutDimension) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOk(createCollectionWithoutDimension) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); // Also: should be idempotent when try creating with correct dimension - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(createCollectionWithDimension) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOk(createCollectionWithDimension) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); deleteCollection("collection_with_vector_service"); // create vector collection with vector service and correct dimension - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(createCollectionWithDimension) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOk(createCollectionWithDimension) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); // Also: should be idempotent when try creating with no dimension - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(createCollectionWithoutDimension) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOk(createCollectionWithoutDimension) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); @@ -969,10 +772,7 @@ public void happyFixDimensionAutoPopulate() { @Test public void 
failNoServiceProviderAndNoDimension() { // create a collection with no dimension and service - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -985,10 +785,6 @@ public void failNoServiceProviderAndNoDimension() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsError()) .body( "errors[0].message", @@ -1001,10 +797,7 @@ public void failNoServiceProviderAndNoDimension() { @Test public void failFixDimensionUnmatchedVectorDimension() { // create a collection with unmatched vector dimension - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -1022,10 +815,6 @@ public void failFixDimensionUnmatchedVectorDimension() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsError()) .body( "errors[0].message", @@ -1073,52 +862,24 @@ public void happyRangeDimensionAutoPopulate() { } """; // create vector collection with vector service and no dimension - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(createCollectionWithoutDimension) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOk(createCollectionWithoutDimension) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); // Also: should be idempotent when try creating with correct dimension - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(createCollectionWithDefaultDimension) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOk(createCollectionWithDefaultDimension) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); deleteCollection("collection_with_vector_service"); // create vector collection with vector service and 
correct dimension - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(createCollectionWithDefaultDimension) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOk(createCollectionWithDefaultDimension) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); // Also: should be idempotent when try creating with no dimension - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(createCollectionWithoutDimension) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOk(createCollectionWithoutDimension) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); @@ -1127,10 +888,7 @@ public void happyRangeDimensionAutoPopulate() { @Test public void happyRangeDimensionInRange() { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -1147,11 +905,7 @@ public void happyRangeDimensionInRange() { } } } - """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + """) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); @@ -1161,10 +915,7 @@ public void happyRangeDimensionInRange() { @Test public void failRangeDimensionNotInRange() { // create a collection with a dimension lower than the min - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -1181,11 +932,7 @@ public void failRangeDimensionNotInRange() { } } } - """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + """) .body("$", responseIsError()) .body( "errors[0].message", @@ -1195,10 +942,7 @@ public void failRangeDimensionNotInRange() { .body("errors[0].exceptionClass", is("JsonApiException")); // create a collection with a dimension higher than the min - given() - .headers(getHeaders()) - 
.contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -1216,10 +960,6 @@ public void failRangeDimensionNotInRange() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsError()) .body( "errors[0].message", @@ -1236,10 +976,7 @@ class CreateCollectionWithEmbeddingServiceTestAuth { @Test public void happyWithNoneAuth() { // create a collection without providing authentication - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -1257,10 +994,6 @@ public void happyWithNoneAuth() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); @@ -1269,10 +1002,7 @@ public void happyWithNoneAuth() { @Test public void failNotExistAuthKey() { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -1293,10 +1023,6 @@ public void failNotExistAuthKey() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsError()) .body( "errors[0].message", @@ -1308,10 +1034,7 @@ public void failNotExistAuthKey() { @Test public void failNoneAndHeaderDisabled() { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -1329,10 +1052,6 @@ public void failNoneAndHeaderDisabled() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsError()) .body( "errors[0].message", @@ -1344,10 +1063,7 @@ public void failNoneAndHeaderDisabled() { @Test public void failInvalidAuthKey() { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ 
-1368,10 +1084,6 @@ public void failInvalidAuthKey() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsError()) .body( "errors[0].message", @@ -1383,10 +1095,7 @@ public void failInvalidAuthKey() { @Test public void happyValidAuthKey() { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -1404,10 +1113,6 @@ public void happyValidAuthKey() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); @@ -1416,10 +1121,7 @@ public void happyValidAuthKey() { @Test public void happyProviderKeyFormat() { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -1440,19 +1142,12 @@ public void happyProviderKeyFormat() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); deleteCollection("collection_with_vector_service"); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -1473,10 +1168,6 @@ public void happyProviderKeyFormat() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); @@ -1490,10 +1181,7 @@ class CreateCollectionWithEmbeddingServiceTestParameters { @Test public void failWithMissingRequiredProviderParameters() { // create a collection without providing required parameters - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -1511,10 +1199,6 @@ public void failWithMissingRequiredProviderParameters() { } } """) - .when() - 
.post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsError()) .body("errors[0].exceptionClass", is("JsonApiException")) .body("errors[0].errorCode", is("INVALID_CREATE_COLLECTION_OPTIONS")) @@ -1527,10 +1211,7 @@ public void failWithMissingRequiredProviderParameters() { @Test public void failWithUnrecognizedProviderParameters() { // create a collection with unrecognized parameters - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -1551,10 +1232,6 @@ public void failWithUnrecognizedProviderParameters() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsError()) .body("errors[0].exceptionClass", is("JsonApiException")) .body("errors[0].errorCode", is("INVALID_CREATE_COLLECTION_OPTIONS")) @@ -1566,10 +1243,7 @@ public void failWithUnrecognizedProviderParameters() { @Test public void failWithUnexpectedProviderParameters() { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -1590,10 +1264,6 @@ public void failWithUnexpectedProviderParameters() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsError()) .body("errors[0].exceptionClass", is("JsonApiException")) .body("errors[0].errorCode", is("INVALID_CREATE_COLLECTION_OPTIONS")) @@ -1606,10 +1276,7 @@ public void failWithUnexpectedProviderParameters() { @Test public void failWithWrongProviderParameterType() { // create a collection with wrong parameter type - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -1630,11 +1297,7 @@ public void failWithWrongProviderParameterType() { } } } - """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + 
""") .body("$", responseIsError()) .body("errors[0].exceptionClass", is("JsonApiException")) .body("errors[0].errorCode", is("INVALID_CREATE_COLLECTION_OPTIONS")) @@ -1646,10 +1309,7 @@ public void failWithWrongProviderParameterType() { @Test public void failWithMissingModelParameters() { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -1670,10 +1330,6 @@ public void failWithMissingModelParameters() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsError()) .body("errors[0].exceptionClass", is("JsonApiException")) .body("errors[0].errorCode", is("INVALID_CREATE_COLLECTION_OPTIONS")) @@ -1686,10 +1342,7 @@ public void failWithMissingModelParameters() { @Test public void failWithUnexpectedModelParameters() { // create a collection with unrecognized parameters - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -1712,10 +1365,6 @@ public void failWithUnexpectedModelParameters() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsError()) .body("errors[0].exceptionClass", is("JsonApiException")) .body("errors[0].errorCode", is("INVALID_CREATE_COLLECTION_OPTIONS")) @@ -1728,10 +1377,7 @@ public void failWithUnexpectedModelParameters() { @Test public void failWithWrongModelParameterType() { // create a collection with wrong parameter type - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -1753,10 +1399,6 @@ public void failWithWrongModelParameterType() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsError()) .body("errors[0].exceptionClass", is("JsonApiException")) .body("errors[0].errorCode", 
is("INVALID_CREATE_COLLECTION_OPTIONS")) @@ -1773,10 +1415,7 @@ class CreateCollectionWithSourceModel { @Test public void happyWithSourceModelAndMetrics() { // create a collection with source model and metric - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -1795,18 +1434,11 @@ public void happyWithSourceModelAndMetrics() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); // verify the collection using FindCollection - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "findCollections": { @@ -1816,10 +1448,6 @@ public void happyWithSourceModelAndMetrics() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsDDLSuccess()) .body("status.collections", hasSize(1)) .body("status.collections[0].options.vector.metric", is("cosine")) @@ -1831,10 +1459,7 @@ public void happyWithSourceModelAndMetrics() { @Test public void happyWithSourceModelOnly() { // create a collection with source model - metric will be auto-populated to 'dot_product' - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -1852,18 +1477,11 @@ public void happyWithSourceModelOnly() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); // verify the collection using FindCollection - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "findCollections": { @@ -1873,10 +1491,6 @@ public void happyWithSourceModelOnly() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", 
responseIsDDLSuccess()) .body("status.collections", hasSize(1)) .body("status.collections[0].options.vector.metric", is("dot_product")) @@ -1888,10 +1502,7 @@ public void happyWithSourceModelOnly() { @Test public void happyWithMetricOnly() { // create a collection with metric - source model will be auto-populated to 'other' - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -1909,18 +1520,11 @@ public void happyWithMetricOnly() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); // verify the collection using FindCollection - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "findCollections": { @@ -1930,10 +1534,6 @@ public void happyWithMetricOnly() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsDDLSuccess()) .body("status.collections", hasSize(1)) .body("status.collections[0].options.vector.metric", is("cosine")) @@ -1946,10 +1546,7 @@ public void happyWithMetricOnly() { public void happyNoSourceModelAndMetric() { // create a collection without sourceModel and metric - source model will be auto-populated to // 'other' and metric to 'cosine' - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -1966,18 +1563,11 @@ public void happyNoSourceModelAndMetric() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); // verify the collection using FindCollection - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "findCollections": { @@ -1987,10 +1577,6 @@ public void happyNoSourceModelAndMetric() 
{ } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsDDLSuccess()) .body("status.collections", hasSize(1)) .body("status.collections[0].options.vector.metric", is("cosine")) @@ -2001,10 +1587,7 @@ public void happyNoSourceModelAndMetric() { @Test public void failWithInvalidSourceModel() { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -2022,10 +1605,6 @@ public void failWithInvalidSourceModel() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsError()) .body("errors[0].exceptionClass", is("JsonApiException")) .body("errors[0].errorCode", is("COMMAND_FIELD_INVALID")) @@ -2037,10 +1616,7 @@ public void failWithInvalidSourceModel() { @Test public void failWithInvalidSourceModelObject() { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -2058,10 +1634,6 @@ public void failWithInvalidSourceModelObject() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsError()) .body("errors[0].exceptionClass", is("JsonApiException")) .body("errors[0].errorCode", is("COMMAND_FIELD_INVALID")) diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionTooManyIndexesIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionTooManyIndexesIntegrationTest.java index 9b723d70d6..d3deb4d4f4 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionTooManyIndexesIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionTooManyIndexesIntegrationTest.java @@ -1,6 +1,5 @@ package io.stargate.sgv2.jsonapi.api.v1; -import static io.restassured.RestAssured.given; import static 
io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsDDLSuccess; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsError; import static org.hamcrest.Matchers.is; @@ -8,7 +7,6 @@ import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; -import io.restassured.http.ContentType; import io.stargate.sgv2.jsonapi.testresource.DseTestResource; import org.junit.jupiter.api.ClassOrderer; import org.junit.jupiter.api.Test; @@ -60,10 +58,7 @@ public void enforceMaxIndexes() { // First create max collections, should work fine for (int i = 1; i <= COLLECTIONS_TO_CREATE; ++i) { String json = createTemplate.formatted(i); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + givenHeadersAndJson(json) .when() .post(KeyspaceResource.BASE_PATH, NS) .then() @@ -72,11 +67,7 @@ public void enforceMaxIndexes() { .body("status.ok", is(1)); } // And then failure - String json = createTemplate.formatted(99); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + givenHeadersAndJson(createTemplate.formatted(99)) .when() .post(KeyspaceResource.BASE_PATH, NS) .then() @@ -91,10 +82,7 @@ public void enforceMaxIndexes() { // But then verify that re-creating an existing one should still succeed // (if using same settings) - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(createTemplate.formatted(1)) + givenHeadersAndJson(createTemplate.formatted(1)) .when() .post(KeyspaceResource.BASE_PATH, NS) .then() diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionTooManyTablesIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionTooManyTablesIntegrationTest.java index 93f71e1113..bc6cff6d9f 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionTooManyTablesIntegrationTest.java +++ 
b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionTooManyTablesIntegrationTest.java @@ -1,13 +1,11 @@ package io.stargate.sgv2.jsonapi.api.v1; -import static io.restassured.RestAssured.given; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsDDLSuccess; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsError; import static org.hamcrest.Matchers.is; import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; -import io.restassured.http.ContentType; import io.stargate.sgv2.jsonapi.testresource.DseTestResource; import org.junit.jupiter.api.ClassOrderer; import org.junit.jupiter.api.Test; @@ -57,11 +55,7 @@ public void enforceMaxCollections() { // First create maximum number of collections for (int i = 1; i <= COLLECTIONS_TO_CREATE; ++i) { - String json = createTemplate.formatted(i); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + givenHeadersAndJson(createTemplate.formatted(i)) .when() .post(KeyspaceResource.BASE_PATH, NS) .then() @@ -70,11 +64,7 @@ public void enforceMaxCollections() { .body("status.ok", is(1)); } // And then failure - String json = createTemplate.formatted(99); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + givenHeadersAndJson(createTemplate.formatted(99)) .when() .post(KeyspaceResource.BASE_PATH, NS) .then() @@ -92,10 +82,7 @@ public void enforceMaxCollections() { // But then verify that re-creating an existing one should still succeed // (if using same settings) - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(createTemplate.formatted(1)) + givenHeadersAndJson(createTemplate.formatted(1)) .when() .post(KeyspaceResource.BASE_PATH, NS) .then() diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionWithLexicalIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionWithLexicalIntegrationTest.java index 
a1f200d92c..d717dec9aa 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionWithLexicalIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionWithLexicalIntegrationTest.java @@ -103,6 +103,35 @@ void createLexicalSimpleEnabledCustom() { deleteCollection(collectionName); } + @Test + void createLexicalAdvancedCustom() { + Assumptions.assumeTrue(isLexicalAvailableForDB()); + + final String collectionName = "coll_lexical_advanced_" + RandomStringUtils.randomNumeric(16); + String json = + createRequestWithLexical( + collectionName, + """ + { + "enabled": true, + "analyzer": { + "tokenizer" : {"name" : "standard"}, + "filters": [ + { "name": "lowercase" }, + { "name": "stop" }, + { "name": "porterstem" }, + { "name": "asciifolding" } + ] + } + } + """); + + givenHeadersPostJsonThenOkNoErrors(json) + .body("$", responseIsDDLSuccess()) + .body("status.ok", is(1)); + deleteCollection(collectionName); + } + @Test void createLexicalSimpleDisabled() { // Fine regardless of whether Lexical available for DB or not @@ -224,7 +253,7 @@ void failCreateLexicalWithDisabledAndAnalyzerObject() { .body( "errors[0].message", containsString( - "When 'lexical' is disabled, 'lexical.analyzer' must either be omitted, JSON null, or an empty JSON object {}.")); + "When 'lexical' is disabled, 'lexical.analyzer' must either be omitted or be JSON null, or")); } @Test @@ -276,11 +305,11 @@ void failCreateLexicalWrongJsonType() { createRequestWithLexical( collectionName, """ - { - "enabled": true, - "analyzer": [ 1, 2, 3 ] - } - """); + { + "enabled": true, + "analyzer": [ 1, 2, 3 ] + } + """); if (isLexicalAvailableForDB()) { givenHeadersPostJsonThenOk(json) @@ -298,6 +327,94 @@ void failCreateLexicalWrongJsonType() { .body("errors[0].errorCode", is("LEXICAL_NOT_AVAILABLE_FOR_DATABASE")); } } + + // [data-api#2011] + @Test + void failCreateLexicalMisspelledTokenizer() { + Assumptions.assumeTrue(isLexicalAvailableForDB()); + + final String 
collectionName = "coll_lexical_" + RandomStringUtils.randomNumeric(16); + String json = + createRequestWithLexical( + collectionName, + """ + { + "enabled": true, + "analyzer": { + "tokeniser": {"name": "standard", "args": {}}, + "filters": [ + {"name": "lowercase"}, + {"name": "stop"}, + {"name": "porterstem"}, + {"name": "asciifolding"} + ], + "extra": 123 + } + } + """); + + givenHeadersPostJsonThenOk(json) + .body("$", responseIsError()) + .body("errors[0].errorCode", is("INVALID_CREATE_COLLECTION_OPTIONS")) + .body( + "errors[0].message", + containsString( + "Invalid fields for 'lexical.analyzer'. Valid fields are: [charFilters, filters, tokenizer], found: [extra, tokeniser]")); + } + + // [data-api#2011] + @Test + void failCreateLexicalNonObjectForTokenizer() { + Assumptions.assumeTrue(isLexicalAvailableForDB()); + + final String collectionName = "coll_lexical_" + RandomStringUtils.randomNumeric(16); + String json = + createRequestWithLexical( + collectionName, + """ + { + "enabled": true, + "analyzer": { + "tokenizer": false + } + } + """); + + givenHeadersPostJsonThenOk(json) + .body("$", responseIsError()) + .body("errors[0].errorCode", is("INVALID_CREATE_COLLECTION_OPTIONS")) + .body( + "errors[0].message", + containsString( + "'tokenizer' property of 'lexical.analyzer' must be JSON Object, is: Boolean")); + } + + // [data-api#2011] + @Test + void failCreateLexicalNonArrayForFilters() { + Assumptions.assumeTrue(isLexicalAvailableForDB()); + + final String collectionName = "coll_lexical_" + RandomStringUtils.randomNumeric(16); + String json = + createRequestWithLexical( + collectionName, + """ + { + "enabled": true, + "analyzer": { + "filters": { } + } + } + """); + + givenHeadersPostJsonThenOk(json) + .body("$", responseIsError()) + .body("errors[0].errorCode", is("INVALID_CREATE_COLLECTION_OPTIONS")) + .body( + "errors[0].message", + containsString( + "'filters' property of 'lexical.analyzer' must be JSON Array, is: Object")); + } } private String 
createRequestWithLexical(String collectionName, String lexicalDef) { diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionWithRerankingIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionWithRerankingIntegrationTest.java index be779789de..00f568b829 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionWithRerankingIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateCollectionWithRerankingIntegrationTest.java @@ -7,9 +7,14 @@ import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; +import io.stargate.sgv2.jsonapi.exception.SchemaException; import io.stargate.sgv2.jsonapi.testresource.DseTestResource; +import java.util.stream.Stream; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.*; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; @QuarkusIntegrationTest @WithTestResource(value = DseTestResource.class, restrictToAnnotatedClass = false) @@ -274,14 +279,14 @@ void failCreateRerankingWithDisabledAndModel() { createRequestWithReranking( collectionName, """ - { - "enabled": false, - "service": { - "provider": "nvidia", - "modelName": "nvidia/llama-3.2-nv-rerankqa-1b-v2" - } - } - """); + { + "enabled": false, + "service": { + "provider": "nvidia", + "modelName": "nvidia/llama-3.2-nv-rerankqa-1b-v2" + } + } + """); givenHeadersPostJsonThenOk(json) .body("$", responseIsError()) @@ -313,11 +318,11 @@ void failMissingServiceProvider() { createRequestWithReranking( collectionName, """ - { - "enabled": true, - "service": {} - } - """); + { + "enabled": true, + "service": {} + } + """); givenHeadersPostJsonThenOk(json) .body("$", responseIsError()) @@ -335,13 +340,13 @@ void failUnknownServiceProvider() { createRequestWithReranking( collectionName, """ - { - "enabled": true, - "service": { 
- "provider": "unknown" + { + "enabled": true, + "service": { + "provider": "unknown" + } } - } - """); + """); givenHeadersPostJsonThenOk(json) .body("$", responseIsError()) @@ -359,13 +364,13 @@ void failMissingServiceModel() { createRequestWithReranking( collectionName, """ - { - "enabled": true, - "service": { - "provider": "nvidia" + { + "enabled": true, + "service": { + "provider": "nvidia" + } } - } - """); + """); givenHeadersPostJsonThenOk(json) .body("$", responseIsError()) @@ -383,14 +388,14 @@ void failUnknownServiceModel() { createRequestWithReranking( collectionName, """ - { - "enabled": true, - "service": { - "provider": "nvidia", - "modelName": "unknown" - } + { + "enabled": true, + "service": { + "provider": "nvidia", + "modelName": "unknown" } - """); + } + """); givenHeadersPostJsonThenOk(json) .body("$", responseIsError()) @@ -408,17 +413,17 @@ void failUnsupportedAuthentication() { createRequestWithReranking( collectionName, """ - { - "enabled": true, - "service": { - "provider": "nvidia", - "modelName": "nvidia/llama-3.2-nv-rerankqa-1b-v2", - "authentication": { - "providerKey" : "myKey" - } - } + { + "enabled": true, + "service": { + "provider": "nvidia", + "modelName": "nvidia/llama-3.2-nv-rerankqa-1b-v2", + "authentication": { + "providerKey" : "myKey" + } } - """); + } + """); givenHeadersPostJsonThenOk(json) .body("$", responseIsError()) @@ -436,17 +441,17 @@ void failUnsupportedParameters() { createRequestWithReranking( collectionName, """ - { - "enabled": true, - "service": { - "provider": "nvidia", - "modelName": "nvidia/llama-3.2-nv-rerankqa-1b-v2", - "parameters": { - "test": "test" - } - } + { + "enabled": true, + "service": { + "provider": "nvidia", + "modelName": "nvidia/llama-3.2-nv-rerankqa-1b-v2", + "parameters": { + "test": "test" + } } - """); + } + """); givenHeadersPostJsonThenOk(json) .body("$", responseIsError()) @@ -457,53 +462,43 @@ void failUnsupportedParameters() { "Reranking provider 'nvidia' currently doesn't 
support any parameters. No parameters should be provided.")); } - @Test - void failDeprecatedModel() { - final String collectionName = "coll_Reranking_" + RandomStringUtils.randomNumeric(16); - String json = - createRequestWithReranking( - collectionName, - """ - { - "enabled": true, - "service": { - "provider": "nvidia", - "modelName": "nvidia/a-random-deprecated-model" - } - } - """); - - givenHeadersPostJsonThenOk(json) - .body("$", responseIsError()) - .body("errors[0].errorCode", is("UNSUPPORTED_PROVIDER_MODEL")) - .body( - "errors[0].message", - containsString( - "The model nvidia/a-random-deprecated-model is at DEPRECATED status.")); + private static Stream deprecatedRerankingModelSource() { + return Stream.of( + Arguments.of( + "DEPRECATED", + "nvidia/a-random-deprecated-model", + SchemaException.Code.DEPRECATED_AI_MODEL), + Arguments.of( + "END_OF_LIFE", + "nvidia/a-random-EOL-model", + SchemaException.Code.END_OF_LIFE_AI_MODEL)); } - @Test - void failEOLModel() { + @ParameterizedTest + @MethodSource("deprecatedRerankingModelSource") + public void failDeprecatedEOLRerankModel( + String status, String modelName, SchemaException.Code errorCode) { final String collectionName = "coll_Reranking_" + RandomStringUtils.randomNumeric(16); String json = createRequestWithReranking( collectionName, - """ - { - "enabled": true, - "service": { - "provider": "nvidia", - "modelName": "nvidia/a-random-EOL-model" - } - } - """); + """ + { + "enabled": true, + "service": { + "provider": "nvidia", + "modelName": "%s" + } + } + """ + .formatted(modelName)); givenHeadersPostJsonThenOk(json) .body("$", responseIsError()) - .body("errors[0].errorCode", is("UNSUPPORTED_PROVIDER_MODEL")) .body( "errors[0].message", - containsString("The model nvidia/a-random-EOL-model is at END_OF_LIFE status.")); + containsString("The model is: %s. 
It is at %s status".formatted(modelName, status))) + .body("errors[0].errorCode", is(errorCode.name())); } } diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateKeyspaceIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateKeyspaceIntegrationTest.java index fe52a32e84..834867b850 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateKeyspaceIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/CreateKeyspaceIntegrationTest.java @@ -1,13 +1,11 @@ package io.stargate.sgv2.jsonapi.api.v1; -import static io.restassured.RestAssured.given; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.*; import static org.hamcrest.Matchers.*; import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; import io.restassured.RestAssured; -import io.restassured.http.ContentType; import io.stargate.sgv2.jsonapi.config.constants.ErrorObjectV2Constants; import io.stargate.sgv2.jsonapi.exception.ErrorFamily; import io.stargate.sgv2.jsonapi.exception.RequestException; @@ -36,20 +34,15 @@ public static void enableLog() { @AfterEach public void deleteKeyspace() { - String json = - """ + givenHeadersAndJson( + """ { "dropKeyspace": { "name": "%s" } } """ - .formatted(DB_NAME); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + .formatted(DB_NAME)) .when() .post(GeneralResource.BASE_PATH) .then() @@ -64,20 +57,15 @@ class CreateKeyspace { @Test public final void happyPath() { - String json = - """ + givenHeadersAndJson( + """ { "createKeyspace": { "name": "%s" } } """ - .formatted(DB_NAME); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + .formatted(DB_NAME)) .when() .post(GeneralResource.BASE_PATH) .then() @@ -88,20 +76,15 @@ public final void happyPath() { @Test public final void alreadyExists() { - String json = - """ + givenHeadersAndJson( + """ { "createKeyspace": { "name": "%s" } } """ - 
.formatted(keyspaceName); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + .formatted(keyspaceName)) .when() .post(GeneralResource.BASE_PATH) .then() @@ -112,8 +95,8 @@ public final void alreadyExists() { @Test public final void withReplicationFactor() { - String json = - """ + givenHeadersAndJson( + """ { "createKeyspace": { "name": "%s", @@ -126,12 +109,7 @@ public final void withReplicationFactor() { } } """ - .formatted(DB_NAME); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + .formatted(DB_NAME)) .when() .post(GeneralResource.BASE_PATH) .then() @@ -142,18 +120,13 @@ public final void withReplicationFactor() { @Test public void invalidCommand() { - String json = - """ + givenHeadersAndJson( + """ { "createKeyspace": { } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(GeneralResource.BASE_PATH) .then() @@ -174,20 +147,15 @@ class DeprecatedCreateNamespace { @Test public final void happyPath() { - String json = - """ + givenHeadersAndJson( + """ { "createNamespace": { "name": "%s" } } """ - .formatted(DB_NAME); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + .formatted(DB_NAME)) .when() .post(GeneralResource.BASE_PATH) .then() @@ -217,20 +185,15 @@ public final void happyPath() { @Test public final void alreadyExists() { - String json = - """ + givenHeadersAndJson( + """ { "createNamespace": { "name": "%s" } } """ - .formatted(keyspaceName); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + .formatted(keyspaceName)) .when() .post(GeneralResource.BASE_PATH) .then() @@ -259,8 +222,8 @@ public final void alreadyExists() { @Test public final void withReplicationFactor() { - String json = - """ + givenHeadersAndJson( + """ { "createNamespace": { "name": "%s", @@ -273,12 +236,7 @@ public final void withReplicationFactor() { } } """ - .formatted(DB_NAME); - - 
given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + .formatted(DB_NAME)) .when() .post(GeneralResource.BASE_PATH) .then() @@ -308,18 +266,13 @@ public final void withReplicationFactor() { @Test public void invalidCommand() { - String json = - """ - { - "createNamespace": { - } - } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + givenHeadersAndJson( + """ + { + "createNamespace": { + } + } + """) .when() .post(GeneralResource.BASE_PATH) .then() diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/DropKeyspaceIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/DropKeyspaceIntegrationTest.java index eeccc001d0..2d30cdc35c 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/DropKeyspaceIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/DropKeyspaceIntegrationTest.java @@ -1,12 +1,10 @@ package io.stargate.sgv2.jsonapi.api.v1; -import static io.restassured.RestAssured.given; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsStatusOnly; import static org.hamcrest.Matchers.*; import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; -import io.restassured.http.ContentType; import io.stargate.sgv2.jsonapi.config.constants.ErrorObjectV2Constants; import io.stargate.sgv2.jsonapi.exception.ErrorFamily; import io.stargate.sgv2.jsonapi.exception.RequestException; @@ -30,20 +28,15 @@ class DropKeyspace { @Test public final void happyPath() { - String json = - """ + givenHeadersAndJson( + """ { "dropKeyspace": { "name": "%s" } } """ - .formatted(keyspaceName); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + .formatted(keyspaceName)) .when() .post(GeneralResource.BASE_PATH) .then() @@ -52,18 +45,13 @@ public final void happyPath() { .body("status.ok", is(1)); // ensure it's dropped - json = - """ + givenHeadersAndJson( + """ { "findKeyspaces": { } } - 
"""; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(GeneralResource.BASE_PATH) .then() @@ -77,39 +65,30 @@ public final void withExistingCollection() { String keyspace = "k%s".formatted(RandomStringUtils.randomAlphanumeric(8)).toLowerCase(); String collection = "c%s".formatted(RandomStringUtils.randomAlphanumeric(8)).toLowerCase(); - String createKeyspace = - """ + givenHeadersAndJson( + """ { "createKeyspace": { "name": "%s" } } """ - .formatted(keyspace); - String createCollection = - """ - { - "createCollection": { - "name": "%s" - } - } - """ - .formatted(collection); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(createKeyspace) + .formatted(keyspace)) .when() .post(GeneralResource.BASE_PATH) .then() .statusCode(200) .body("$", responseIsStatusOnly()) .body("status.ok", is(1)); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(createCollection) + givenHeadersAndJson( + """ + { + "createCollection": { + "name": "%s" + } + } + """ + .formatted(collection)) .when() .post(KeyspaceResource.BASE_PATH, keyspace) .then() @@ -117,20 +96,15 @@ public final void withExistingCollection() { .body("$", responseIsStatusOnly()) .body("status.ok", is(1)); - String json = - """ + givenHeadersAndJson( + """ { "dropKeyspace": { "name": "%s" } } """ - .formatted(keyspace); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + .formatted(keyspace)) .when() .post(GeneralResource.BASE_PATH) .then() @@ -139,18 +113,13 @@ public final void withExistingCollection() { .body("status.ok", is(1)); // ensure it's dropped - json = - """ + givenHeadersAndJson( + """ { "findKeyspaces": { } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(GeneralResource.BASE_PATH) .then() @@ -161,19 +130,14 @@ public final void withExistingCollection() { @Test public final void notExisting() { - String 
json = - """ + givenHeadersAndJson( + """ { "dropKeyspace": { "name": "whatever_not_there" } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(GeneralResource.BASE_PATH) .then() @@ -189,20 +153,15 @@ class DeprecatedDropNamespace { @Test public final void happyPath() { - String json = - """ + givenHeadersAndJson( + """ { "dropNamespace": { "name": "%s" } } """ - .formatted(keyspaceName); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + .formatted(keyspaceName)) .when() .post(GeneralResource.BASE_PATH) .then() @@ -234,18 +193,13 @@ public final void happyPath() { containsString("The new command to use is: dropKeyspace.")); // ensure it's dropped - json = - """ + givenHeadersAndJson( + """ { "findKeyspaces": { } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(GeneralResource.BASE_PATH) .then() @@ -259,39 +213,30 @@ public final void withExistingCollection() { String keyspace = "k%s".formatted(RandomStringUtils.randomAlphanumeric(8)).toLowerCase(); String collection = "c%s".formatted(RandomStringUtils.randomAlphanumeric(8)).toLowerCase(); - String createKeyspace = - """ + givenHeadersAndJson( + """ { "createKeyspace": { "name": "%s" } } """ - .formatted(keyspace); - String createCollection = - """ - { - "createCollection": { - "name": "%s" - } - } - """ - .formatted(collection); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(createKeyspace) + .formatted(keyspace)) .when() .post(GeneralResource.BASE_PATH) .then() .statusCode(200) .body("$", responseIsStatusOnly()) .body("status.ok", is(1)); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(createCollection) + givenHeadersAndJson( + """ + { + "createCollection": { + "name": "%s" + } + } + """ + .formatted(collection)) .when() .post(KeyspaceResource.BASE_PATH, keyspace) .then() @@ -299,20 +244,15 @@ 
public final void withExistingCollection() { .body("$", responseIsStatusOnly()) .body("status.ok", is(1)); - String json = - """ + givenHeadersAndJson( + """ { "dropNamespace": { "name": "%s" } } """ - .formatted(keyspace); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + .formatted(keyspace)) .when() .post(GeneralResource.BASE_PATH) .then() @@ -339,18 +279,13 @@ public final void withExistingCollection() { containsString("The new command to use is: dropKeyspace.")); ; // ensure it's dropped - json = - """ + givenHeadersAndJson( + """ { "findKeyspaces": { } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(GeneralResource.BASE_PATH) .then() @@ -361,19 +296,14 @@ public final void withExistingCollection() { @Test public final void notExisting() { - String json = - """ + givenHeadersAndJson( + """ { "dropNamespace": { "name": "whatever_not_there" } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(GeneralResource.BASE_PATH) .then() diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/EstimatedDocumentCountIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/EstimatedDocumentCountIntegrationTest.java index 77b082f71a..920cc5b90e 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/EstimatedDocumentCountIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/EstimatedDocumentCountIntegrationTest.java @@ -1,14 +1,14 @@ package io.stargate.sgv2.jsonapi.api.v1; -import static io.restassured.RestAssured.given; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsStatusOnly; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsWriteSuccess; +import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.nullValue; +import static 
org.junit.jupiter.api.Assertions.fail; import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; -import io.restassured.http.ContentType; import io.stargate.sgv2.jsonapi.testresource.DseTestResource; import org.junit.jupiter.api.*; import org.slf4j.Logger; @@ -23,14 +23,20 @@ public class EstimatedDocumentCountIntegrationTest extends AbstractCollectionInt private static final Logger LOG = LoggerFactory.getLogger(EstimatedDocumentCountIntegrationTest.class); - public static final int MAX_ITERATIONS = 100; + private static final int MAX_ITERATIONS = 200; + + private static final int DOCS_PER_ITERATION = 4; @Nested @TestMethodOrder(MethodOrderer.OrderAnnotation.class) @Order(1) class Count { - public static final int TIME_TO_SETTLE = 75; + /** + * Time to wait for the estimated document count to settle after a truncate or insertMany (based + * on... observed time needed?) + */ + public static final int TIME_TO_SETTLE_SECS = 75; public static final String JSON_ESTIMATED_COUNT = """ @@ -77,10 +83,6 @@ class Count { { "username": "user4", "indexedObject" : { "0": "value_0", "1": "value_1" } - }, - { - "username": "user5", - "sub_doc" : { "a": 5, "b": { "c": "v1", "d": false } } } ], "options" : { @@ -99,18 +101,19 @@ public void insertDocuments() throws InterruptedException { int tries = 1; while (tries <= MAX_ITERATIONS) { insertMany(); + Thread.sleep(10L); // get count results every N iterations - if (tries % 500 == 0) { + if (tries % 10 == 0) { int estimatedCount = getEstimatedCount(); int actualCount = getActualCount(); - LOG.info( + LOG.warn( "Iteration: " + tries + ", Docs inserted: " - + tries * 5 + + tries * DOCS_PER_ITERATION + ", Actual count: " + actualCount + ", Estimated count: " @@ -125,11 +128,23 @@ public void insertDocuments() throws InterruptedException { } LOG.info( - "Stopping insertion after non-zero estimated count, now waiting " - + TIME_TO_SETTLE + "Stopping insertion after non-zero estimated count, now 
waiting up to " + + TIME_TO_SETTLE_SECS + " seconds for count to settle"); - Thread.sleep(TIME_TO_SETTLE * 1000); - LOG.info("Final estimated count: " + getEstimatedCount()); + + for (int i = 0; i < TIME_TO_SETTLE_SECS; ++i) { + int estimatedCount = getEstimatedCount(); + if (estimatedCount > 0) { + LOG.info("Final estimated count: " + estimatedCount + " -- test passes"); + return; + } + if (i % 10 == 0) { + LOG.info("Estimated count is still zero, waiting..."); + } + Thread.sleep(1000L); + } + + fail("Estimated count is zero after " + TIME_TO_SETTLE_SECS + " seconds of wait."); } /** @@ -139,6 +154,8 @@ public void insertDocuments() throws InterruptedException { @Test @Order(2) public void truncate() throws InterruptedException { + int estimatedCount = getEstimatedCount(); + LOG.info("Estimated count before truncate: " + estimatedCount); String jsonTruncate = """ @@ -147,48 +164,30 @@ public void truncate() throws InterruptedException { } } """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(jsonTruncate) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOk(jsonTruncate) .body("$", responseIsStatusOnly()) .body("status.deletedCount", is(-1)) .body("status.moreData", is(nullValue())); LOG.info( - "Truncated collection, waiting for estimated count to settle for " - + TIME_TO_SETTLE + "Truncated collection, waiting for estimated count to settle for up to " + + TIME_TO_SETTLE_SECS + " seconds"); - Thread.sleep(TIME_TO_SETTLE * 1000); - LOG.info("Final estimated count after truncate: " + getEstimatedCount()); - - // ensure estimated doc count is zero - // does not find the documents - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(JSON_ESTIMATED_COUNT) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) - .body("$", responseIsStatusOnly()) - .body("status.count", is(0)); + + for 
(int i = 0; i < TIME_TO_SETTLE_SECS; ++i) { + if ((estimatedCount = getEstimatedCount()) < 1) { // re-poll every second: the value read before truncate never changes by itself + break; + } + if (i % 10 == 0) { + LOG.info("Estimated count still above (" + estimatedCount + "), waiting..."); + } + Thread.sleep(1000); + } + assertThat(estimatedCount).isLessThan(1); } private int getActualCount() { - return given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(JSON_ACTUAL_COUNT) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + return givenHeadersPostJsonThenOk(JSON_ACTUAL_COUNT) .body("$", responseIsStatusOnly()) .extract() .response() @@ -197,14 +196,7 @@ private int getActualCount() { } private int getEstimatedCount() { - return given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(JSON_ESTIMATED_COUNT) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + return givenHeadersPostJsonThenOk(JSON_ESTIMATED_COUNT) .body("$", responseIsStatusOnly()) .extract() .response() @@ -213,15 +205,7 @@ private int getEstimatedCount() { } private void insertMany() { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(INSERT_MANY) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) - .body("$", responseIsWriteSuccess()); + givenHeadersPostJsonThenOk(INSERT_MANY).body("$", responseIsWriteSuccess()); } } diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindAndRerankCollectionIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindAndRerankCollectionIntegrationTest.java index 5ec34e000c..b22587a087 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindAndRerankCollectionIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindAndRerankCollectionIntegrationTest.java @@ -1,16 +1,15 @@ package io.stargate.sgv2.jsonapi.api.v1; -import static io.restassured.RestAssured.given; -import static 
io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsDDLSuccess; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsError; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; -import io.restassured.http.ContentType; +import io.stargate.sgv2.jsonapi.exception.RequestException; import io.stargate.sgv2.jsonapi.testresource.DseTestResource; import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -23,7 +22,7 @@ @WithTestResource(value = DseTestResource.class, restrictToAnnotatedClass = false) public class FindAndRerankCollectionIntegrationTest extends AbstractCollectionIntegrationTestBase { - // used to cleanup the collection from a previous test, if non null + // used to cleanup the collection from a previous test, if non-null private String cleanupCollectionName = null; @BeforeAll @@ -34,56 +33,12 @@ public final void createDefaultCollection() { @AfterEach public void cleanup() { if (cleanupCollectionName != null) { - var json = - """ - {"dropCollection": {"name": "%s"}} - """ - .formatted(cleanupCollectionName); - given() - .port(getTestPort()) - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) - .body("$", responseIsDDLSuccess()); + String toDelete = cleanupCollectionName; + cleanupCollectionName = null; // clear out just in case + deleteCollection(toDelete); } } - public void errorOnNotEnabled( - String collectionName, String collectionSpec, String errorCode, String errorMessageContains) { - - var createCollection = collectionSpec.formatted(collectionName); - createComplexCollection(createCollection); - - var rerank = - """ - {"findAndRerank": { - "filter": {}, - "projection": {}, - 
"sort": { - "$hybrid": "hybrid sort" - }, - "options": { - "limit" : 10, - "hybridLimits" : 10, - "includeScores": true, - "includeSortVector": false - } - } - } - """; - - givenHeadersPostJsonThen(keyspaceName, collectionName, rerank) - .body("$", responseIsError()) - .body("errors[0].errorCode", is(errorCode)) - .body( - "errors[0].message", - containsString(errorMessageContains.formatted(keyspaceName, collectionName))); - } - @Test public void failOnVectorDisabled() { errorOnNotEnabled( @@ -99,7 +54,7 @@ public void failOnVectorDisabled() { } @Test - public void voidFailOnVectorizeDisabled() { + void failOnVectorizeDisabled() { errorOnNotEnabled( "vectorize_not_enabled", """ @@ -118,7 +73,7 @@ public void voidFailOnVectorizeDisabled() { } @Test - public void voidFailOnLexicalDisabled() { + void failOnLexicalDisabled() { errorOnNotEnabled( "lexical_not_enabled", """ @@ -142,4 +97,78 @@ public void voidFailOnLexicalDisabled() { "LEXICAL_NOT_ENABLED_FOR_COLLECTION", "Lexical search is not enabled for collection"); } + + // https://github.com/stargate/data-api/issues/2057 + @Test + void failOnEmptyRequest() { + // Must not fail for "no lexical available", so skip on DSE + Assumptions.assumeTrue(isLexicalAvailableForDB()); + + String collectionName = "find_rerank_empty_request"; + createCollectionWithCleanup( + collectionName, + """ + { + "name" : "%s", + "options": { + "vector": { + "metric": "cosine", + "dimension": 1024, + "service": { + "provider": "openai", + "modelName": "text-embedding-3-small" + } + }, + "lexical": { + "enabled": true, + "analyzer": "standard" + } + } + } + """); + + givenHeadersPostJsonThen(keyspaceName, collectionName, "{\"findAndRerank\": { } }") + .body("$", responseIsError()) + .body("errors[0].errorCode", is(RequestException.Code.MISSING_RERANK_QUERY_TEXT.name())) + .body( + "errors[0].message", + containsString( + "findAndRerank command is missing the text to use as the query with the reranking")); + } + + private void 
errorOnNotEnabled( + String collectionName, String collectionSpec, String errorCode, String errorMessageContains) { + createCollectionWithCleanup(collectionName, collectionSpec); + + var rerank = + """ + {"findAndRerank": { + "filter": {}, + "projection": {}, + "sort": { + "$hybrid": "hybrid sort" + }, + "options": { + "limit" : 10, + "hybridLimits" : 10, + "includeScores": true, + "includeSortVector": false + } + } + } + """; + + givenHeadersPostJsonThen(keyspaceName, collectionName, rerank) + .body("$", responseIsError()) + .body("errors[0].errorCode", is(errorCode)) + .body( + "errors[0].message", + containsString(errorMessageContains.formatted(keyspaceName, collectionName))); + } + + private void createCollectionWithCleanup(String collectionName, String collectionSpec) { + createComplexCollection(collectionSpec.formatted(collectionName)); + // save the collection name for cleanup, but only after successful creation + cleanupCollectionName = collectionName; + } } diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindCollectionWithLexicalSortIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindCollectionWithLexicalIntegrationTest.java similarity index 70% rename from src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindCollectionWithLexicalSortIntegrationTest.java rename to src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindCollectionWithLexicalIntegrationTest.java index 8ddaca72b4..30fa4b0a39 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindCollectionWithLexicalSortIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindCollectionWithLexicalIntegrationTest.java @@ -1,5 +1,6 @@ package io.stargate.sgv2.jsonapi.api.v1; +import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsError; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsFindSuccess; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsStatusOnly; import static 
net.javacrumbs.jsonunit.JsonMatchers.jsonEquals; @@ -32,7 +33,7 @@ @QuarkusIntegrationTest @WithTestResource(value = DseTestResource.class, restrictToAnnotatedClass = false) @TestClassOrder(ClassOrderer.OrderAnnotation.class) -public class FindCollectionWithLexicalSortIntegrationTest +public class FindCollectionWithLexicalIntegrationTest extends AbstractCollectionIntegrationTestBase { static final String COLLECTION_WITH_LEXICAL = "coll_lexical_sort_" + RandomStringUtils.randomNumeric(16); @@ -40,11 +41,11 @@ public class FindCollectionWithLexicalSortIntegrationTest static final String COLLECTION_WITHOUT_LEXICAL = "coll_no_lexical_sort_" + RandomStringUtils.randomNumeric(16); - static final String DOC1_JSON = lexicalDoc(1, "monkey banana", "value1"); - static final String DOC2_JSON = lexicalDoc(2, "monkey", "value2"); - static final String DOC3_JSON = lexicalDoc(3, "biking fun", "value3"); - static final String DOC4_JSON = lexicalDoc(4, "banana bread with butter", "value4"); - static final String DOC5_JSON = lexicalDoc(5, "fun", "value5"); + static final String DOC1_JSON = lexicalDoc(1, "monkey banana", "value1", "top"); + static final String DOC2_JSON = lexicalDoc(2, "monkey", "value2", "top"); + static final String DOC3_JSON = lexicalDoc(3, "biking fun", "value3", "middle"); + static final String DOC4_JSON = lexicalDoc(4, "banana bread with butter", "value4", "bottom"); + static final String DOC5_JSON = lexicalDoc(5, "fun", "value5", "bottom"); @DisabledIfSystemProperty(named = TEST_PROP_LEXICAL_DISABLED, matches = "true") @Nested @@ -113,6 +114,50 @@ void findManyWithLexicalSort() { .body("data.documents[0]._id", is("lexical-1")) .body("data.documents[1]._id", is("lexical-4")); } + + @Test + void findManyWithOnlyLexicalFilter() { + givenHeadersPostJsonThenOkNoErrors( + keyspaceName, + COLLECTION_WITH_LEXICAL, + """ + { + "find": { + "filter" : { + "$lexical": { + "$match": "biking" + } + } + } + } + """) + .body("$", responseIsFindSuccess()) + 
.body("data.documents", hasSize(1)) + .body("data.documents[0]._id", is("lexical-3")); + } + + @Test + void findManyWithLexicalAndOtherFilter() { + // Lexical brings 2, tag 2; intersection is 1 + givenHeadersPostJsonThenOkNoErrors( + keyspaceName, + COLLECTION_WITH_LEXICAL, + """ + { + "find": { + "filter" : { + "$and": [ + { "$lexical": { "$match": "banana" } }, + { "tag": "bottom" } + ] + } + } + } + """) + .body("$", responseIsFindSuccess()) + .body("data.documents", hasSize(1)) + .body("data.documents[0]._id", is("lexical-4")); + } } @DisabledIfSystemProperty(named = TEST_PROP_LEXICAL_DISABLED, matches = "true") @@ -154,6 +199,24 @@ void findOneWithLexicalSortMonkeyBananas() { // Needs to get "lexical-1" with "monkey banana" .body("data.document", jsonEquals(DOC1_JSON)); } + + @Test + void findOneWithOnlyLexicalFilter() { + givenHeadersPostJsonThenOkNoErrors( + keyspaceName, + COLLECTION_WITH_LEXICAL, + """ + { + "findOne": { + "projection": {"$lexical": 1 }, + "filter" : {"$lexical": {"$match": "bread butter" } } + } + } + """) + .body("$", responseIsFindSuccess()) + // Needs to get "lexical-4" + .body("data.document", jsonEquals(DOC4_JSON)); + } } @DisabledIfSystemProperty(named = TEST_PROP_LEXICAL_DISABLED, matches = "true") @@ -161,7 +224,7 @@ void findOneWithLexicalSortMonkeyBananas() { @Order(3) class FailingCasesFindMany { @Test - void failIfLexicalDisabledForCollection() { + void failSortIfLexicalDisabledForCollection() { givenHeadersPostJsonThenOk( keyspaceName, COLLECTION_WITHOUT_LEXICAL, @@ -177,6 +240,23 @@ void failIfLexicalDisabledForCollection() { .body("errors[0].message", containsString("Lexical search is not enabled")); } + @Test + void failFilterIfLexicalDisabledForCollection() { + givenHeadersPostJsonThenOk( + keyspaceName, + COLLECTION_WITHOUT_LEXICAL, + """ + { + "find": { + "filter" : {"$lexical": {"$match": "banana" } } + } + } + """) + .body("errors", hasSize(1)) + .body("errors[0].errorCode", is("LEXICAL_NOT_ENABLED_FOR_COLLECTION")) + 
.body("errors[0].message", containsString("Lexical search is not enabled")); + } + @Test void failForBadLexicalSortValueType() { givenHeadersPostJsonThenOk( @@ -196,6 +276,26 @@ void failForBadLexicalSortValueType() { containsString("if sorting by '$lexical' value must be String, not Number")); } + @Test + void failForBadLexicalFilterValueType() { + givenHeadersPostJsonThenOk( + keyspaceName, + COLLECTION_WITH_LEXICAL, + """ + { + "find": { + "filter" : {"$lexical": {"$match": [ 1, 2, 3 ] } } + } + } + """) + .body("errors", hasSize(1)) + .body("errors[0].errorCode", is("INVALID_FILTER_EXPRESSION")) + .body( + "errors[0].message", + containsString( + "Invalid filter expression: $match operator must have `String` value, was `Array`")); + } + @Test void failForLexicalSortWithOtherExpressions() { givenHeadersPostJsonThenOk( @@ -217,6 +317,48 @@ void failForLexicalSortWithOtherExpressions() { "errors[0].message", containsString("if sorting by '$lexical' no other sort expressions allowed")); } + + // No way to do "$not" with "$match" (not supported by DBs) + @Test + void failForLexicalFilterWithNot() { + givenHeadersPostJsonThenOk( + keyspaceName, + COLLECTION_WITH_LEXICAL, + """ + { + "find": { + "filter" : {"$not": {"$lexical": {"$match": "banana" } }}} + } + } + """) + .body("errors", hasSize(1)) + .body("errors[0].errorCode", is("INVALID_FILTER_EXPRESSION")) + .body( + "errors[0].message", + containsString( + "Invalid filter expression: cannot use $not to invert $match operator")); + } + + // Can only use $match with $lexical, not $eq, $ne, etc. 
+ @Test + public void failForEqFilteringOnLexical() { + for (String filter : + new String[] { + "{\"$lexical\": \"quick brown fox\"}", "{\"$lexical\": {\"$eq\": \"quick brown fox\"}}" + }) { + givenHeadersPostJsonThenOk( + keyspaceName, + COLLECTION_WITH_LEXICAL, + "{ \"findOne\": { \"filter\" : %s}}".formatted(filter)) + .body("$", responseIsError()) + .body("errors", hasSize(1)) + .body("errors[0].errorCode", is("INVALID_FILTER_EXPRESSION")) + .body( + "errors[0].message", + containsString( + "Cannot filter on '$lexical' field using operator $eq: only $match is supported")); + } + } } @DisabledIfSystemProperty(named = TEST_PROP_LEXICAL_DISABLED, matches = "true") @@ -225,7 +367,7 @@ void failForLexicalSortWithOtherExpressions() { class HappyCasesFindOneAndUpdate { @Test void findOneAndUpdateWithSort() { - final String expectedAfterChange = lexicalDoc(1, "monkey banana", "value1-updated"); + final String expectedAfterChange = lexicalDoc(1, "monkey banana", "value1-updated", "top"); givenHeadersPostJsonThenOkNoErrors( keyspaceName, COLLECTION_WITH_LEXICAL, @@ -265,7 +407,7 @@ void findOneAndUpdateWithSort() { class HappyCasesUpdateOne { @Test void updateOneWithSort() { - final String expectedAfterChange = lexicalDoc(1, "monkey banana", "value1-updated-2"); + final String expectedAfterChange = lexicalDoc(1, "monkey banana", "value1-updated-2", "top"); givenHeadersPostJsonThenOkNoErrors( keyspaceName, COLLECTION_WITH_LEXICAL, @@ -302,7 +444,7 @@ void updateOneWithSort() { class HappyCasesFindOneAndReplace { @Test void findOneAndReplaceWithSort() { - final String expectedAfterChange = lexicalDoc(1, "monkey banana", "value1-replaced"); + final String expectedAfterChange = lexicalDoc(1, "monkey banana", "value1-replaced", "top"); givenHeadersPostJsonThenOkNoErrors( keyspaceName, COLLECTION_WITH_LEXICAL, @@ -364,7 +506,7 @@ void findOneAndDeleteWithSort() { """ { "find": { - "projection": {"_id": 1, "value": 0 } + "projection": {"_id": 1, "value": 0, "tag": 0 } } } """) 
@@ -415,21 +557,22 @@ void deleteOneWithSort() { .body( "data.documents", containsInAnyOrder( - Map.of("_id", "lexical-1"), - Map.of("_id", "lexical-4"), - Map.of("_id", "lexical-5"))); + Map.of("_id", "lexical-1", "tag", "top"), + Map.of("_id", "lexical-4", "tag", "bottom"), + Map.of("_id", "lexical-5", "tag", "bottom"))); } } - static String lexicalDoc(int id, String keywords, String value) { + static String lexicalDoc(int id, String keywords, String value, String tag) { return """ { "_id": "lexical-%d", "$lexical": "%s", - "value": "%s" + "value": "%s", + "tag": "%s" } """ - .formatted(id, keywords, value); + .formatted(id, keywords, value, tag); } } diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindCollectionWithSortIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindCollectionWithSortIntegrationTest.java index 2c6cef8a83..fcf0b3923c 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindCollectionWithSortIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindCollectionWithSortIntegrationTest.java @@ -1,6 +1,5 @@ package io.stargate.sgv2.jsonapi.api.v1; -import static io.restassured.RestAssured.given; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsError; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsFindSuccess; import static net.javacrumbs.jsonunit.JsonMatchers.jsonEquals; @@ -15,7 +14,6 @@ import com.fasterxml.uuid.Generators; import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; -import io.restassured.http.ContentType; import io.stargate.sgv2.jsonapi.testresource.DseTestResource; import io.stargate.sgv2.jsonapi.util.JsonNodeComparator; import java.util.ArrayList; @@ -49,30 +47,23 @@ public void setUp() { @Test public void sortByTextAndNullValue() throws Exception { sortByUserName(testDatas, true); - String json = - """ - { - "find": { - "sort" : {"username" : 1} - } - } - """; JsonNodeFactory 
nodefactory = objectMapper.getNodeFactory(); final ArrayNode arrayNode = nodefactory.arrayNode(20); - for (int i = 0; i < 20; i++) + for (int i = 0; i < 20; i++) { arrayNode.add( objectMapper.readTree( objectMapper .writerWithDefaultPrettyPrinter() .writeValueAsString(testDatas.get(i)))); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + } + givenHeadersPostJsonThenOkNoErrors( + """ + { + "find": { + "sort" : {"username" : 1} + } + } + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(20)) .body("data.documents", jsonEquals(arrayNode.toString())); @@ -81,16 +72,6 @@ public void sortByTextAndNullValue() throws Exception { @Test public void sortWithSkipLimit() throws Exception { sortByUserName(testDatas, true); - String json = - """ - { - "find": { - "sort" : {"username" : 1}, - "options" : {"skip": 10, "limit" : 10} - } - } - """; - JsonNodeFactory nodefactory = objectMapper.getNodeFactory(); final ArrayNode arrayNode = nodefactory.arrayNode(10); for (int i = 0; i < 20; i++) { @@ -103,14 +84,15 @@ public void sortWithSkipLimit() throws Exception { } } - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOkNoErrors( + """ + { + "find": { + "sort" : {"username" : 1}, + "options" : {"skip": 10, "limit" : 10} + } + } + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(10)) .body("data.documents", jsonEquals(arrayNode.toString())); @@ -119,32 +101,24 @@ public void sortWithSkipLimit() throws Exception { @Test public void sortDescendingTextValue() throws Exception { sortByUserName(testDatas, false); - String json = - """ - { - "find": { - "sort" : {"username" : -1} - } - } - """; - JsonNodeFactory nodefactory = 
objectMapper.getNodeFactory(); final ArrayNode arrayNode = nodefactory.arrayNode(20); - for (int i = 0; i < 20; i++) + for (int i = 0; i < 20; i++) { arrayNode.add( objectMapper.readTree( objectMapper .writerWithDefaultPrettyPrinter() .writeValueAsString(testDatas.get(i)))); + } - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOkNoErrors( + """ + { + "find": { + "sort" : {"username" : -1} + } + } + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(20)) .body("data.documents", jsonEquals(arrayNode.toString())); @@ -153,32 +127,23 @@ public void sortDescendingTextValue() throws Exception { @Test public void sortBooleanValueAndMissing() throws Exception { sortByActiveUser(testDatas, true); - String json = - """ - { - "find": { - "sort" : {"activeUser" : 1} - } - } - """; - JsonNodeFactory nodefactory = objectMapper.getNodeFactory(); final ArrayNode arrayNode = nodefactory.arrayNode(20); - for (int i = 0; i < 20; i++) + for (int i = 0; i < 20; i++) { arrayNode.add( objectMapper.readTree( objectMapper .writerWithDefaultPrettyPrinter() .writeValueAsString(testDatas.get(i)))); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + } + givenHeadersPostJsonThenOkNoErrors( + """ + { + "find": { + "sort" : {"activeUser" : 1} + } + } + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(20)) .body("data.documents", jsonEquals(arrayNode.toString())); @@ -187,32 +152,23 @@ public void sortBooleanValueAndMissing() throws Exception { @Test public void sortBooleanValueAndMissingDescending() throws Exception { sortByActiveUser(testDatas, false); - String json = - """ - { - "find": { - "sort" : {"activeUser" : -1} - } - } - """; - 
JsonNodeFactory nodefactory = objectMapper.getNodeFactory(); final ArrayNode arrayNode = nodefactory.arrayNode(20); - for (int i = 0; i < 20; i++) + for (int i = 0; i < 20; i++) { arrayNode.add( objectMapper.readTree( objectMapper .writerWithDefaultPrettyPrinter() .writeValueAsString(testDatas.get(i)))); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + } + givenHeadersPostJsonThenOkNoErrors( + """ + { + "find": { + "sort" : {"activeUser" : -1} + } + } + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(20)) .body("data.documents", jsonEquals(arrayNode.toString())); @@ -221,32 +177,23 @@ public void sortBooleanValueAndMissingDescending() throws Exception { @Test public void sortNumericField() throws Exception { sortByUserId(testDatas, true); - String json = - """ - { - "find": { - "sort" : {"userId" : 1} - } - } - """; - JsonNodeFactory nodefactory = objectMapper.getNodeFactory(); final ArrayNode arrayNode = nodefactory.arrayNode(20); - for (int i = 0; i < 20; i++) + for (int i = 0; i < 20; i++) { arrayNode.add( objectMapper.readTree( objectMapper .writerWithDefaultPrettyPrinter() .writeValueAsString(testDatas.get(i)))); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + } + givenHeadersPostJsonThenOkNoErrors( + """ + { + "find": { + "sort" : {"userId" : 1} + } + } + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(20)) .body("data.documents", jsonEquals(arrayNode.toString())); @@ -255,32 +202,23 @@ public void sortNumericField() throws Exception { @Test public void sortNumericFieldDescending() throws Exception { sortByUserId(testDatas, false); - String json = - """ - { - "find": { - "sort" : {"userId" : -1} - } - } - """; - 
JsonNodeFactory nodefactory = objectMapper.getNodeFactory(); final ArrayNode arrayNode = nodefactory.arrayNode(20); - for (int i = 0; i < 20; i++) + for (int i = 0; i < 20; i++) { arrayNode.add( objectMapper.readTree( objectMapper .writerWithDefaultPrettyPrinter() .writeValueAsString(testDatas.get(i)))); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + } + givenHeadersPostJsonThenOkNoErrors( + """ + { + "find": { + "sort" : {"userId" : -1} + } + } + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(20)) .body("data.documents", jsonEquals(arrayNode.toString())); @@ -293,31 +231,23 @@ public void sortNumericFieldAndFilter() throws Exception { .filter(obj -> (obj instanceof TestData o) && o.activeUser()) .collect(Collectors.toList()); sortByUserId(datas, true); - String json = - """ - { - "find": { - "filter" : {"activeUser" : true}, - "sort" : {"userId" : 1} - } - } - """; - JsonNodeFactory nodefactory = objectMapper.getNodeFactory(); final ArrayNode arrayNode = nodefactory.arrayNode(datas.size()); - for (int i = 0; i < datas.size(); i++) + for (int i = 0; i < datas.size(); i++) { arrayNode.add( objectMapper.readTree( objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(datas.get(i)))); + } - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOkNoErrors( + """ + { + "find": { + "filter" : {"activeUser" : true}, + "sort" : {"userId" : 1} + } + } + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(Math.min(20, datas.size()))) .body("data.documents", jsonEquals(arrayNode.toString())); @@ -326,32 +256,23 @@ public void sortNumericFieldAndFilter() throws Exception { @Test public void sortMultiColumns() throws 
Exception { sortByUserNameUserId(testDatas, true, true); - String json = - """ - { - "find": { - "sort" : {"username" : 1, "userId" : 1} - } - } - """; - JsonNodeFactory nodefactory = objectMapper.getNodeFactory(); final ArrayNode arrayNode = nodefactory.arrayNode(20); - for (int i = 0; i < 20; i++) + for (int i = 0; i < 20; i++) { arrayNode.add( objectMapper.readTree( objectMapper .writerWithDefaultPrettyPrinter() .writeValueAsString(testDatas.get(i)))); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + } + givenHeadersPostJsonThenOkNoErrors( + """ + { + "find": { + "sort" : {"username" : 1, "userId" : 1} + } + } + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(20)) .body("data.documents", jsonEquals(arrayNode.toString())); @@ -364,30 +285,22 @@ public void sortMultiColumnsMixedOrder() throws Exception { .filter(obj -> (obj instanceof TestData o) && o.activeUser()) .collect(Collectors.toList()); sortByUserNameUserId(datas, true, false); - String json = - """ + JsonNodeFactory nodefactory = objectMapper.getNodeFactory(); + final ArrayNode arrayNode = nodefactory.arrayNode(datas.size()); + for (int i = 0; i < datas.size(); i++) { + arrayNode.add( + objectMapper.readTree( + objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(datas.get(i)))); + } + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"activeUser" : true}, "sort" : {"username" : 1, "userId" : -1} } } - """; - - JsonNodeFactory nodefactory = objectMapper.getNodeFactory(); - final ArrayNode arrayNode = nodefactory.arrayNode(datas.size()); - for (int i = 0; i < datas.size(); i++) - arrayNode.add( - objectMapper.readTree( - objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(datas.get(i)))); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - 
.post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(Math.min(20, datas.size()))) .body("data.documents", jsonEquals(arrayNode.toString())); @@ -400,30 +313,22 @@ public void sortByDate() throws Exception { .filter(obj -> (obj instanceof TestData o) && o.activeUser()) .collect(Collectors.toList()); sortByDate(datas, true); - String json = - """ + JsonNodeFactory nodefactory = objectMapper.getNodeFactory(); + final ArrayNode arrayNode = nodefactory.arrayNode(datas.size()); + for (int i = 0; i < datas.size(); i++) { + arrayNode.add( + objectMapper.readTree( + objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(datas.get(i)))); + } + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"activeUser" : true}, "sort" : {"dateValue" : 1} } } - """; - - JsonNodeFactory nodefactory = objectMapper.getNodeFactory(); - final ArrayNode arrayNode = nodefactory.arrayNode(datas.size()); - for (int i = 0; i < datas.size(); i++) - arrayNode.add( - objectMapper.readTree( - objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(datas.get(i)))); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(Math.min(20, datas.size()))) .body("data.documents", jsonEquals(arrayNode.toString())); @@ -436,30 +341,22 @@ public void sortByDateDescending() throws Exception { .filter(obj -> (obj instanceof TestData o) && o.activeUser()) .collect(Collectors.toList()); sortByDate(datas, false); - String json = - """ + JsonNodeFactory nodefactory = objectMapper.getNodeFactory(); + final ArrayNode arrayNode = nodefactory.arrayNode(datas.size()); + for (int i = 0; i < datas.size(); i++) { + arrayNode.add( + objectMapper.readTree( + 
objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(datas.get(i)))); + } + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"activeUser" : true}, "sort" : {"dateValue" : -1} } } - """; - - JsonNodeFactory nodefactory = objectMapper.getNodeFactory(); - final ArrayNode arrayNode = nodefactory.arrayNode(datas.size()); - for (int i = 0; i < datas.size(); i++) - arrayNode.add( - objectMapper.readTree( - objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(datas.get(i)))); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(Math.min(20, datas.size()))) .body("data.documents", jsonEquals(arrayNode.toString())); @@ -475,9 +372,17 @@ public void sortByUUID() throws Exception { sortByUUID(datas, true); // Create a sublist of the first 20 elements List first20Datas = new ArrayList<>(datas.subList(0, Math.min(20, datas.size()))); - - String json = - """ + JsonNodeFactory nodeFactory = objectMapper.getNodeFactory(); + final ArrayNode arrayNode = nodeFactory.arrayNode(first20Datas.size()); + for (int i = 0; i < first20Datas.size(); i++) { + arrayNode.add( + objectMapper.readTree( + objectMapper + .writerWithDefaultPrettyPrinter() + .writeValueAsString(first20Datas.get(i)))); + } + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter":{ @@ -486,24 +391,7 @@ public void sortByUUID() throws Exception { "sort" : {"uuid" : 1} } } - """; - - JsonNodeFactory nodeFactory = objectMapper.getNodeFactory(); - final ArrayNode arrayNode = nodeFactory.arrayNode(first20Datas.size()); - for (int i = 0; i < first20Datas.size(); i++) - arrayNode.add( - objectMapper.readTree( - objectMapper - .writerWithDefaultPrettyPrinter() - .writeValueAsString(first20Datas.get(i)))); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - 
.body(json) - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(Math.min(20, first20Datas.size()))) .body("data.documents", jsonEquals(arrayNode.toString())); @@ -641,22 +529,16 @@ public void setUp() { @Test public void sortFailDueToTooMany() { - String json = - """ + givenHeadersPostJsonThenOk( + keyspaceName, + biggerCollectionName, + """ { "find": { "sort" : {"username" : 1} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, biggerCollectionName) - .then() - .statusCode(200) + """) .body("$", responseIsError()) .body("errors", hasSize(1)) .body("errors[0].exceptionClass", is("JsonApiException")) diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindCollectionsIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindCollectionsIntegrationTest.java index ddeeb62e81..d920b8c760 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindCollectionsIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindCollectionsIntegrationTest.java @@ -1,6 +1,5 @@ package io.stargate.sgv2.jsonapi.api.v1; -import static io.restassured.RestAssured.given; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsDDLSuccess; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsError; import static net.javacrumbs.jsonunit.JsonMatchers.jsonEquals; @@ -13,7 +12,6 @@ import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; -import io.restassured.http.ContentType; import io.stargate.sgv2.jsonapi.testresource.DseTestResource; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.Assumptions; @@ -42,23 +40,13 @@ class FindCollections { * this default keyspace. 
*/ public void checkNamespaceHasNoCollections() { - // then find - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "findCollections": { } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + """) .body("$", responseIsDDLSuccess()) .body("status.collections", hasSize(0)); } @@ -67,10 +55,7 @@ public void checkNamespaceHasNoCollections() { @Order(2) public void happyPath() { // create first - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOkNoErrors( """ { "createCollection": { @@ -78,27 +63,16 @@ public void happyPath() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); // then find - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOkNoErrors( """ { "findCollections": { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsDDLSuccess()) .body("status.collections", hasSize(greaterThanOrEqualTo(1))) .body("status.collections", hasItem("collection1")); @@ -110,8 +84,8 @@ public void happyPathWithExplain() { // To create Collection with Lexical, it must be available for the database Assumptions.assumeTrue(isLexicalAvailableForDB()); - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "createCollection": { "name": "collection2", @@ -137,16 +111,7 @@ public void happyPathWithExplain() { } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + """) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); @@ -197,8 +162,8 @@ public void happyPathWithExplain() { } """; - json = - """ + 
givenHeadersPostJsonThenOkNoErrors( + """ { "findCollections": { "options": { @@ -206,16 +171,7 @@ public void happyPathWithExplain() { } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + """) .body("$", responseIsDDLSuccess()) .body("status.collections", hasSize(2)) .body( @@ -226,10 +182,7 @@ public void happyPathWithExplain() { @Test @Order(4) public void happyPathWithMixedCase() { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOkNoErrors( """ { "createCollection": { @@ -237,27 +190,16 @@ public void happyPathWithMixedCase() { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); // then find - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOkNoErrors( """ { "findCollections": { } } """) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsDDLSuccess()) .body("status.collections", hasSize(greaterThanOrEqualTo(1))) .body("status.collections", hasItem("TableName")); @@ -268,20 +210,15 @@ public void happyPathWithMixedCase() { public void emptyNamespace() { // create namespace first String namespace = "nam" + RandomStringUtils.randomNumeric(16); - String json = - """ + givenHeadersAndJson( + """ { "createNamespace": { "name": "%s" } } """ - .formatted(namespace); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + .formatted(namespace)) .when() .post(GeneralResource.BASE_PATH) .then() @@ -290,18 +227,13 @@ public void emptyNamespace() { .body("status.ok", is(1)); // then find - json = - """ + givenHeadersAndJson( + """ { "findCollections": { } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - 
.body(json) + """) .when() .post(KeyspaceResource.BASE_PATH, namespace) .then() @@ -310,20 +242,15 @@ public void emptyNamespace() { .body("status.collections", hasSize(0)); // cleanup - json = - """ + givenHeadersAndJson( + """ { "dropNamespace": { "name": "%s" } } """ - .formatted(namespace); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + .formatted(namespace)) .when() .post(GeneralResource.BASE_PATH) .then() @@ -335,19 +262,13 @@ public void emptyNamespace() { @Test @Order(6) public void notExistingNamespace() { - // then find - String json = - """ + givenHeadersAndJson( + """ { "findCollections": { } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(KeyspaceResource.BASE_PATH, "should_not_be_there") .then() @@ -365,8 +286,8 @@ public void happyPathIndexingWithExplain() { // To create Collection with Lexical, it must be available for the database Assumptions.assumeTrue(isLexicalAvailableForDB()); - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "createCollection": { "name": "collection4", @@ -386,22 +307,14 @@ public void happyPathIndexingWithExplain() { } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + """) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); String expected1 = """ - {"name":"TableName","options":{ + {"name":"TableName", + "options":{ "lexical": { "enabled": true, "analyzer": "standard" @@ -467,8 +380,9 @@ public void happyPathIndexingWithExplain() { } } """; - json = - """ + + givenHeadersPostJsonThenOkNoErrors( + """ { "findCollections": { "options": { @@ -476,16 +390,7 @@ public void happyPathIndexingWithExplain() { } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - 
.then() - .statusCode(200) + """) .body("$", responseIsDDLSuccess()) .body("status.collections", hasSize(4)) .body( diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindEmbeddingProvidersIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindEmbeddingProvidersIntegrationTest.java index 1e84816451..9be2ec7e88 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindEmbeddingProvidersIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindEmbeddingProvidersIntegrationTest.java @@ -1,14 +1,18 @@ package io.stargate.sgv2.jsonapi.api.v1; -import static io.restassured.RestAssured.given; +import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsError; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsStatusOnly; import static org.hamcrest.Matchers.*; import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; -import io.restassured.http.ContentType; +import io.stargate.sgv2.jsonapi.service.provider.ApiModelSupport; import io.stargate.sgv2.jsonapi.testresource.DseTestResource; +import java.util.stream.Stream; import org.junit.jupiter.api.*; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; @QuarkusIntegrationTest @WithTestResource(value = DseTestResource.class, restrictToAnnotatedClass = false) @@ -20,19 +24,14 @@ class FindEmbeddingProviders { @Test public final void happyPath() { - String json = - """ + // without option specified, only return supported models + givenHeadersAndJson( + """ { "findEmbeddingProviders": { } } - """; - - given() - .port(getTestPort()) - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(GeneralResource.BASE_PATH) .then() @@ -40,7 +39,109 @@ public final void happyPath() { .body("$", responseIsStatusOnly()) .body("status.embeddingProviders", notNullValue()) 
.body("status.embeddingProviders.nvidia.url", notNullValue()) - .body("status.embeddingProviders.nvidia.models[0].vectorDimension", equalTo(1024)); + .body("status.embeddingProviders.nvidia.models[0].vectorDimension", equalTo(1024)) + .body("status.embeddingProviders.nvidia.models[0].name", equalTo("NV-Embed-QA")) + .body( + "status.embeddingProviders.nvidia.models[0].apiModelSupport.status", + equalTo(ApiModelSupport.SupportStatus.SUPPORTED.name())); + } + + private static Stream returnedAllStatus() { + return Stream.of( + // empty string + Arguments.of("\"\""), + // null + Arguments.of("null")); + } + + @ParameterizedTest() + @MethodSource("returnedAllStatus") + public final void returnModelsWithAllStatus(String filterModelStatus) { + givenHeadersAndJson( + """ + { + "findEmbeddingProviders": { + "options": { + "filterModelStatus": %s + } + } + } + """ + .formatted(filterModelStatus)) + .when() + .post(GeneralResource.BASE_PATH) + .then() + .statusCode(200) + .body("$", responseIsStatusOnly()) + .body("status.embeddingProviders", notNullValue()) + .body("status.embeddingProviders.nvidia.models", hasSize(3)) + .body("status.embeddingProviders.nvidia.models[0].name", equalTo("NV-Embed-QA")) + .body( + "status.embeddingProviders.nvidia.models[0].apiModelSupport.status", + equalTo(ApiModelSupport.SupportStatus.SUPPORTED.name())) + .body( + "status.embeddingProviders.nvidia.models[1].name", + equalTo("a-EOL-nvidia-embedding-model")) + .body( + "status.embeddingProviders.nvidia.models[1].apiModelSupport.status", + equalTo(ApiModelSupport.SupportStatus.END_OF_LIFE.name())) + .body( + "status.embeddingProviders.nvidia.models[2].name", + equalTo("a-deprecated-nvidia-embedding-model")) + .body( + "status.embeddingProviders.nvidia.models[2].apiModelSupport.status", + equalTo(ApiModelSupport.SupportStatus.DEPRECATED.name())); + } + + @Test + public final void returnModelsWithSpecifiedStatus() { + givenHeadersAndJson( + """ + { + "findEmbeddingProviders": { + "options": { + 
"filterModelStatus": "deprecated" + } + } + } + """) + .when() + .post(GeneralResource.BASE_PATH) + .then() + .statusCode(200) + .body("$", responseIsStatusOnly()) + .body("status.embeddingProviders", notNullValue()) + .body("status.embeddingProviders.nvidia.models", hasSize(1)) + .body( + "status.embeddingProviders.nvidia.models[0].name", + equalTo("a-deprecated-nvidia-embedding-model")) + .body( + "status.embeddingProviders.nvidia.models[0].apiModelSupport.status", + equalTo(ApiModelSupport.SupportStatus.DEPRECATED.name())); + } + + @Test + public final void failedWithRandomStatus() { + givenHeadersAndJson( + """ + { + "findEmbeddingProviders": { + "options": { + "filterModelStatus": "random" + } + } + } + """) + .when() + .post(GeneralResource.BASE_PATH) + .then() + .statusCode(200) + .body("$", responseIsError()) + .body("errors[0].errorCode", is("COMMAND_FIELD_INVALID")) + .body( + "errors[0].message", + containsString( + "field 'command.options.filterModelStatus' value \"random\" not valid")); } } diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindIntegrationTest.java index c3b97a5596..ba49d4b924 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindIntegrationTest.java @@ -1,13 +1,11 @@ package io.stargate.sgv2.jsonapi.api.v1; -import static io.restassured.RestAssured.given; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.*; import static net.javacrumbs.jsonunit.JsonMatchers.jsonEquals; import static org.hamcrest.Matchers.*; import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; -import io.restassured.http.ContentType; import io.stargate.sgv2.jsonapi.config.OperationsConfig; import io.stargate.sgv2.jsonapi.testresource.DseTestResource; import org.junit.jupiter.api.*; @@ -117,8 +115,10 @@ public void setUp() { @Test public void 
wrongKeyspace() { - String json = - """ + givenHeadersPostJsonThenOk( + "something_else", + collectionName, + """ { "find": { "sort" : {"$vector" : [0.15, 0.1, 0.1, 0.35, 0.55]}, @@ -127,16 +127,9 @@ public void wrongKeyspace() { } } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, "something_else", collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsError()) + .body("errors", hasSize(1)) .body("errors[0].message", is("The provided keyspace does not exist: something_else")) .body("errors[0].errorCode", is("KEYSPACE_DOES_NOT_EXIST")) .body("errors[0].exceptionClass", is("JsonApiException")); @@ -274,10 +267,7 @@ public void byId() { // https://github.com/stargate/jsonapi/issues/572 -- is passing empty Object for "sort" ok? @Test public void byIdEmptySort() { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOkNoErrors( """ { "find": { @@ -288,10 +278,6 @@ public void byIdEmptySort() { } } """) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) .body("$", responseIsFindSuccess()) .body( "data.documents[0]", diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindKeyspacesIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindKeyspacesIntegrationTest.java index aff95c821a..1ae0e031c3 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindKeyspacesIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindKeyspacesIntegrationTest.java @@ -1,12 +1,10 @@ package io.stargate.sgv2.jsonapi.api.v1; -import static io.restassured.RestAssured.given; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsDDLSuccess; import static org.hamcrest.Matchers.*; import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; -import io.restassured.http.ContentType; 
import io.stargate.sgv2.jsonapi.config.constants.ErrorObjectV2Constants; import io.stargate.sgv2.jsonapi.exception.ErrorFamily; import io.stargate.sgv2.jsonapi.exception.RequestException; @@ -29,18 +27,13 @@ class FindKeyspaces { @Test public final void happyPath() { - String json = - """ + givenHeadersAndJson( + """ { "findKeyspaces": { } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(GeneralResource.BASE_PATH) .then() @@ -57,18 +50,13 @@ class DeprecatedFindNamespaces { @Test public final void happyPath() { - String json = - """ + givenHeadersAndJson( + """ { "findNamespaces": { } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(GeneralResource.BASE_PATH) .then() diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindOneAndDeleteIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindOneAndDeleteIntegrationTest.java index 77b1b7153f..8be8c2780a 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindOneAndDeleteIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindOneAndDeleteIntegrationTest.java @@ -1,6 +1,5 @@ package io.stargate.sgv2.jsonapi.api.v1; -import static io.restassured.RestAssured.given; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.*; import static net.javacrumbs.jsonunit.JsonMatchers.jsonEquals; import static org.assertj.core.api.Assertions.assertThat; @@ -8,7 +7,6 @@ import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; -import io.restassured.http.ContentType; import io.stargate.sgv2.jsonapi.testresource.DseTestResource; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicReferenceArray; @@ -23,7 +21,7 @@ public class FindOneAndDeleteIntegrationTest extends AbstractCollectionIntegrati class FindOneAndDelete { @Test public void byId() { - String document = + final 
String document = """ { "_id": "doc3", @@ -33,82 +31,56 @@ public void byId() { """; insertDoc(document); - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "findOneAndDelete": { "filter" : {"_id" : "doc3"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindAndSuccess()) .body("data.document", jsonEquals(document)) .body("status.deletedCount", is(1)); // assert state after update - json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"_id" : "doc3"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(0)); } @Test public void byIdNoData() { - String document = + insertDoc( """ { "_id": "doc3", "username": "user3", "active_user" : true } - """; - insertDoc(document); - - String json = - """ + """); + givenHeadersPostJsonThenOkNoErrors( + """ { "findOneAndDelete": { "filter" : {"_id" : "doc5"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindAndSuccess()) .body("status.deletedCount", is(0)); } @Test public void withSortDesc() { - String document = + final String document = """ { "_id": "doc3", @@ -118,7 +90,7 @@ public void withSortDesc() { """; insertDoc(document); - String document1 = + final String document1 = """ { "_id": "doc2", @@ -128,51 +100,35 @@ public void withSortDesc() { """; insertDoc(document1); - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "findOneAndDelete": { "filter" : {"active_user" : true}, "sort" : 
{"username" : -1} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindAndSuccess()) .body("data.document", jsonEquals(document)) .body("status.deletedCount", is(1)); // assert state after update - json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"_id" : "doc3"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(0)); } @Test public void withSort() { - String document = + final String document = """ { "_id": "doc3", @@ -182,7 +138,7 @@ public void withSort() { """; insertDoc(document); - String document1 = + final String document1 = """ { "_id": "doc2", @@ -192,51 +148,35 @@ public void withSort() { """; insertDoc(document1); - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "findOneAndDelete": { "filter" : {"active_user" : true}, "sort" : {"username" : 1} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindAndSuccess()) .body("data.document", jsonEquals(document1)) .body("status.deletedCount", is(1)); // assert state after update - json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"_id" : "doc2"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(0)); } @Test public void withSortProjection() { 
- String document = + final String document = """ { "_id": "doc3", @@ -246,7 +186,7 @@ public void withSortProjection() { """; insertDoc(document); - String document1 = + final String document1 = """ { "_id": "doc2", @@ -256,15 +196,8 @@ public void withSortProjection() { """; insertDoc(document1); - String expected = - """ - { - "username": "user2" - } - """; - - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "findOneAndDelete": { "filter" : {"active_user" : true}, @@ -272,36 +205,27 @@ public void withSortProjection() { "projection" : { "_id":0, "username":1 } } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindAndSuccess()) - .body("data.document", jsonEquals(expected)) + .body( + "data.document", + jsonEquals( + """ + { + "username": "user2" + } + """)) .body("status.deletedCount", is(1)); // assert state after update - json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"_id" : "doc2"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.document", is(nullValue())); } @@ -329,14 +253,6 @@ class ConcurrentDelete { @Test public void findOneAndDelete() throws Exception { insertDocuments(); - String json = - """ - { - "findOneAndDelete": { - "filter" : {"name" : "Logic Layers"} - } - } - """; int threads = 5; AtomicReferenceArray assertionErrors = new AtomicReferenceArray<>(threads); @@ -348,14 +264,14 @@ public void findOneAndDelete() throws Exception { new Thread( () -> { try { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - 
.then() - .statusCode(200) + givenHeadersPostJsonThenOkNoErrors( + """ + { + "findOneAndDelete": { + "filter" : {"name" : "Logic Layers"} + } + } + """) .body("$", responseIsFindAndSuccess()) .body("status.deletedCount", anyOf(is(0), is(1))); } catch (AssertionError e) { @@ -380,15 +296,6 @@ public void findOneAndDelete() throws Exception { @Test public void findOneAndDeleteProjection() throws Exception { insertDocuments(); - String json = - """ - { - "findOneAndDelete": { - "filter" : {"name" : "Coded Cleats"}, - "projection" : {"name" : 1} - } - } - """; int threads = 5; AtomicReferenceArray assertionErrors = new AtomicReferenceArray<>(threads); CountDownLatch latch = new CountDownLatch(threads); @@ -399,14 +306,15 @@ public void findOneAndDeleteProjection() throws Exception { new Thread( () -> { try { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOkNoErrors( + """ + { + "findOneAndDelete": { + "filter" : {"name" : "Coded Cleats"}, + "projection" : {"name" : 1} + } + } + """) .body("$", responseIsFindAndSuccess()) .body("status.deletedCount", anyOf(is(0), is(1))); } catch (AssertionError e) { @@ -430,8 +338,8 @@ public void findOneAndDeleteProjection() throws Exception { } public void insertDocuments() { - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "insertMany": { "documents": [ @@ -458,14 +366,7 @@ public void insertDocuments() { ] } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() + """) .body("$", responseIsWriteSuccess()) .body("status.insertedIds[0]", not(emptyString())) .statusCode(200); diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindOneAndReplaceIntegrationTest.java 
b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindOneAndReplaceIntegrationTest.java index 4d50d03d53..ff838e66cd 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindOneAndReplaceIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindOneAndReplaceIntegrationTest.java @@ -1,9 +1,9 @@ package io.stargate.sgv2.jsonapi.api.v1; -import static io.restassured.RestAssured.given; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.*; import static net.javacrumbs.jsonunit.JsonMatchers.jsonEquals; import static org.hamcrest.Matchers.any; +import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; @@ -11,7 +11,6 @@ import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; -import io.restassured.http.ContentType; import io.stargate.sgv2.jsonapi.testresource.DseTestResource; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.ClassOrderer; @@ -29,7 +28,7 @@ public class FindOneAndReplaceIntegrationTest extends AbstractCollectionIntegrat class FindOneAndReplace { @Test public void byId() { - String document = + final String document = """ { "_id": "doc3", @@ -39,61 +38,45 @@ public void byId() { """; insertDoc(document); - String expected = - """ - { - "_id": "doc3", - "username": "user3", - "status" : false - } - """; - - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "findOneAndReplace": { "filter" : {"_id" : "doc3"}, "replacement" : { "username": "user3", "status" : false } } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindAndSuccess()) .body("data.document", jsonEquals(document)) .body("status.matchedCount", is(1)) .body("status.modifiedCount", is(1)); 
// assert state after update - json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"_id" : "doc3"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) - .body("data.documents[0]", jsonEquals(expected)); + .body( + "data.documents[0]", + jsonEquals( + """ + { + "_id": "doc3", + "username": "user3", + "status" : false + } + """)); } @Test public void byIdWithId() { - String document = + final String document = """ { "_id": "doc3", @@ -103,61 +86,45 @@ public void byIdWithId() { """; insertDoc(document); - String expected = - """ - { - "_id": "doc3", - "username": "user3", - "status" : false - } - """; - - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "findOneAndReplace": { "filter" : {"_id" : "doc3"}, "replacement" : {"_id" : "doc3", "username": "user3", "status" : false } } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindAndSuccess()) .body("data.document", jsonEquals(document)) .body("status.matchedCount", is(1)) .body("status.modifiedCount", is(1)); // assert state after update - json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"_id" : "doc3"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) - .body("data.documents[0]", jsonEquals(expected)); + .body( + "data.documents[0]", + jsonEquals( + """ + { + "_id": "doc3", + "username": "user3", + "status" : false + } + """)); } @Test public void byIdWithIdNoChange() { - String document 
= + final String document = """ { "_id": "doc3", @@ -167,52 +134,36 @@ public void byIdWithIdNoChange() { """; insertDoc(document); - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "findOneAndReplace": { "filter" : {"_id" : "doc3"}, "replacement" : {"_id" : "doc3", "username": "user3", "active_user" : true } } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindAndSuccess()) .body("data.document", jsonEquals(document)) .body("status.matchedCount", is(1)) .body("status.modifiedCount", is(0)); // assert state after update - json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"_id" : "doc3"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents[0]", jsonEquals(document)); } @Test public void withSort() { - String document = + final String document = """ { "_id": "doc3", @@ -222,7 +173,7 @@ public void withSort() { """; insertDoc(document); - String document1 = + final String document1 = """ { "_id": "doc2", @@ -231,7 +182,7 @@ public void withSort() { } """; insertDoc(document1); - String expected = + final String expected = """ { "_id": "doc2", @@ -240,8 +191,8 @@ public void withSort() { } """; - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "findOneAndReplace": { "filter" : {"active_user" : true}, @@ -250,44 +201,28 @@ public void withSort() { "options" : {"returnDocument" : "after"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", 
responseIsFindAndSuccess()) .body("data.document", jsonEquals(expected)) .body("status.matchedCount", is(1)) .body("status.modifiedCount", is(1)); // assert state after update - json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"_id" : "doc2"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents[0]", jsonEquals(expected)); } @Test public void withUpsert() { - String expected = + final String expected = """ { "_id": "doc2", @@ -296,8 +231,8 @@ public void withUpsert() { } """; - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "findOneAndReplace": { "filter" : {"_id" : "doc2"}, @@ -305,15 +240,7 @@ public void withUpsert() { "options" : {"returnDocument" : "after", "upsert" : true} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindAndSuccess()) .body("data.document", jsonEquals(expected)) .body("status.matchedCount", is(0)) @@ -321,22 +248,14 @@ public void withUpsert() { .body("status.upsertedId", is("doc2")); // assert state after update - json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"_id" : "doc2"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents[0]", jsonEquals(expected)); } @@ -344,8 +263,8 @@ public void withUpsert() { @Test public void withUpsertNewId() { final String newId = "new-id-1234"; - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "findOneAndReplace": 
{ "filter" : {}, @@ -360,15 +279,7 @@ public void withUpsertNewId() { } } """ - .formatted(newId); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + .formatted(newId)) .body("$", responseIsFindAndSuccess()) .body("status.matchedCount", is(0)) .body("status.modifiedCount", is(0)) @@ -377,30 +288,22 @@ public void withUpsertNewId() { .body("status.upsertedId", is(newId)); // assert state after update - json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"username" : "aaronm"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents[0]._id", is(newId)); } @Test public void withUpsertNoId() { - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "findOneAndReplace": { "filter" : {"username" : "username2"}, @@ -408,15 +311,7 @@ public void withUpsertNoId() { "options" : {"returnDocument" : "after", "upsert" : true} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindAndSuccess()) .body("data.document._id", is(notNullValue())) .body("data.document._id", any(String.class)) @@ -426,30 +321,23 @@ public void withUpsertNoId() { .body("status.upsertedId", any(String.class)); // assert state after update - json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"username" : "username2"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) 
.body("$", responseIsFindSuccess()) + .body("data.documents", hasSize(1)) .body("data.documents[0]._id", is(notNullValue())) .body("data.documents[0]._id", any(String.class)); } @Test public void byIdWithDifferentId() { - String document = + final String document = """ { "_id": "doc3", @@ -458,24 +346,15 @@ public void byIdWithDifferentId() { } """; insertDoc(document); - - String json = - """ - { - "findOneAndReplace": { - "filter" : {"_id" : "doc3"}, - "replacement" : {"_id" : "doc4", "username": "user3", "status" : false } - } - } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOk( + """ + { + "findOneAndReplace": { + "filter" : {"_id" : "doc3"}, + "replacement" : {"_id" : "doc4", "username": "user3", "status" : false } + } + } + """) .body("$", responseIsError()) .body("errors[0].errorCode", is("DOCUMENT_REPLACE_DIFFERENT_DOCID")) .body( @@ -486,7 +365,7 @@ public void byIdWithDifferentId() { @Test public void byIdWithEmptyDocument() { - String document = + final String document = """ { "_id": "doc3", @@ -496,54 +375,38 @@ public void byIdWithEmptyDocument() { """; insertDoc(document); - String expected = - """ - { - "_id": "doc3" - } - """; - - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "findOneAndReplace": { "filter" : {"_id" : "doc3"}, "replacement" : {} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindAndSuccess()) .body("data.document", jsonEquals(document)) .body("status.matchedCount", is(1)) .body("status.modifiedCount", is(1)); // assert state after update - json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"_id" : "doc3"} } } - """; - given() - 
.headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) - .body("data.documents[0]", jsonEquals(expected)); + .body( + "data.documents[0]", + jsonEquals( + """ + { + "_id": "doc3" + } + """)); } } @@ -561,17 +424,7 @@ public void byIdProjectionAfter() { } """); - String expectedAfterProjection = - """ - { - "_id": "docProjAfter", - "status" : false - } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOkNoErrors( """ { "findOneAndReplace": { @@ -582,28 +435,21 @@ public void byIdProjectionAfter() { } } """) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) .body("$", responseIsFindAndSuccess()) - .body("data.document", jsonEquals(expectedAfterProjection)) - .body("status.matchedCount", is(1)) - .body("status.modifiedCount", is(1)); - - // assert state after update - String expectedAfterReplace = - """ + .body( + "data.document", + jsonEquals( + """ { "_id": "docProjAfter", - "username": "userP", "status" : false } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + """)) + .body("status.matchedCount", is(1)) + .body("status.modifiedCount", is(1)); + + // assert state after update + givenHeadersPostJsonThenOkNoErrors( """ { "find": { @@ -611,12 +457,17 @@ public void byIdProjectionAfter() { } } """) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) .body("$", responseIsFindSuccess()) - .body("data.documents[0]", jsonEquals(expectedAfterReplace)); + .body( + "data.documents[0]", + jsonEquals( + """ + { + "_id": "docProjAfter", + "username": "userP", + "status" : false + } + """)); } @Test @@ -630,17 +481,7 @@ public void byIdProjectionBefore() { } """); - String expectedWithProjectionBefore = - """ - { 
- "_id": "docProjBefore", - "active_user" : true - } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOkNoErrors( """ { "findOneAndReplace": { @@ -651,28 +492,21 @@ public void byIdProjectionBefore() { } } """) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) .body("$", responseIsFindAndSuccess()) - .body("data.document", jsonEquals(expectedWithProjectionBefore)) + .body( + "data.document", + jsonEquals( + """ + { + "_id": "docProjBefore", + "active_user" : true + } + """)) .body("status.matchedCount", is(1)) .body("status.modifiedCount", is(1)); // assert state after update - String expectedAfterReplace = - """ - { - "_id": "docProjBefore", - "username": "userP", - "status" : false - } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOkNoErrors( """ { "find": { @@ -680,12 +514,17 @@ public void byIdProjectionBefore() { } } """) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) .body("$", responseIsFindSuccess()) - .body("data.documents[0]", jsonEquals(expectedAfterReplace)); + .body( + "data.documents[0]", + jsonEquals( + """ + { + "_id": "docProjBefore", + "username": "userP", + "status" : false + } + """)); } // Reproduction to verify https://github.com/stargate/data-api/issues/1000 @@ -701,10 +540,7 @@ public void projectionBeforeWithoutId() { """); String upsertedId = - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOkNoErrors( """ { "findOneAndReplace": { @@ -715,10 +551,6 @@ public void projectionBeforeWithoutId() { } } """) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) .body("$", responseIsFindAndSuccess()) .body("status.matchedCount", is(0)) .body("status.modifiedCount", is(0)) @@ -731,10 +563,7 @@ public void 
projectionBeforeWithoutId() { // assert state after update String expectedAfterReplace = "{\"_id\":\"%s\"}".formatted(upsertedId); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOkNoErrors( """ { "find": { @@ -743,10 +572,6 @@ public void projectionBeforeWithoutId() { } """ .formatted(upsertedId)) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) .body("$", responseIsFindSuccess()) .body("data.documents[0]", jsonEquals(expectedAfterReplace)); } @@ -762,14 +587,13 @@ public void cleanUpData() { class FindOneAndReplaceFailing { @Test public void tryReplaceWithTooLongNumber() { - String document = + insertDoc( """ { "_id": "tooLongNumber1", "value" : 123 } - """; - insertDoc(document); + """); // Max number length: 100; use 110 String tooLongNumStr = "1234567890".repeat(11); @@ -786,14 +610,8 @@ public void tryReplaceWithTooLongNumber() { } """ .formatted(tooLongNumStr); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOk(json) + .body("errors", hasSize(1)) .body("$", responseIsError()) .body("errors[0].errorCode", is("SHRED_DOC_LIMIT_VIOLATION")) .body( diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindOneAndUpdateIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindOneAndUpdateIntegrationTest.java index 0d3876c110..aafaac0575 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindOneAndUpdateIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindOneAndUpdateIntegrationTest.java @@ -1,6 +1,5 @@ package io.stargate.sgv2.jsonapi.api.v1; -import static io.restassured.RestAssured.given; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.*; import static net.javacrumbs.jsonunit.JsonMatchers.jsonEquals; import static 
org.assertj.core.api.Assertions.assertThat; @@ -16,7 +15,6 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; -import io.restassured.http.ContentType; import io.stargate.sgv2.jsonapi.testresource.DseTestResource; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.ClassOrderer; @@ -279,14 +277,7 @@ public void byColumnUpsert() { } } """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOk(json) .body("$", responseIsFindSuccess()) .body("data.documents[0]", is(notNullValue())) .body("data.documents[0]._id", any(String.class)); diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindOneAndUpdateNoIndexIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindOneAndUpdateNoIndexIntegrationTest.java index a8e8423c7f..0eefba0eb9 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindOneAndUpdateNoIndexIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindOneAndUpdateNoIndexIntegrationTest.java @@ -1,6 +1,5 @@ package io.stargate.sgv2.jsonapi.api.v1; -import static io.restassured.RestAssured.given; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.*; import static net.javacrumbs.jsonunit.JsonMatchers.jsonEquals; import static org.hamcrest.Matchers.containsString; @@ -12,7 +11,6 @@ import com.fasterxml.jackson.databind.node.ObjectNode; import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; -import io.restassured.http.ContentType; import io.stargate.sgv2.jsonapi.config.DocumentLimitsConfig; import io.stargate.sgv2.jsonapi.testresource.DseTestResource; import org.junit.jupiter.api.ClassOrderer; @@ -34,8 +32,8 @@ public class FindOneAndUpdateNoIndexIntegrationTest extends 
AbstractKeyspaceInte class CreateCollection { @Test public void createBaseCollection() { - String json = - """ + givenHeadersPostJsonThenOk( + """ { "createCollection": { "name": "%s", @@ -51,15 +49,7 @@ public void createBaseCollection() { } } """ - .formatted(collectionName); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + .formatted(collectionName)) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); } @@ -70,10 +60,7 @@ public void createBaseCollection() { class FindAndUpdateWithSet { @Test public void byIdAfterUpdate() { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersAndJson( """ { "insertOne": { @@ -93,10 +80,7 @@ public void byIdAfterUpdate() { .statusCode(200) .body("$", responseIsWriteSuccess()); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersAndJson( """ { "findOneAndUpdate": { @@ -147,10 +131,7 @@ public void byIdBeforeUpdate() { } """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersAndJson( """ { "insertOne": { @@ -165,10 +146,7 @@ public void byIdBeforeUpdate() { .statusCode(200) .body("$", responseIsWriteSuccess()); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersAndJson( """ { "findOneAndUpdate": { @@ -205,10 +183,8 @@ class ArraySizeLimit { public void allowNonIndexedBigArray() { insertEmptyDoc("array_size_big_noindex_doc"); final String arrayJson = bigArray(DocumentLimitsConfig.DEFAULT_MAX_ARRAY_LENGTH + 10); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + + givenHeadersAndJson( """ { "findOneAndUpdate": { @@ -236,10 +212,8 @@ public void allowNonIndexedBigArray() { public void failOnIndexedTooBigArray() { insertEmptyDoc("array_size_too_big_doc"); final String arrayJson = 
bigArray(DocumentLimitsConfig.DEFAULT_MAX_ARRAY_LENGTH + 10); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + + givenHeadersAndJson( """ { "findOneAndUpdate": { @@ -275,10 +249,7 @@ public void allowNonIndexedBigObject() { insertEmptyDoc("object_size_big_noindex_doc"); final String objectJson = bigObject(DocumentLimitsConfig.DEFAULT_MAX_OBJECT_PROPERTIES + 10); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersAndJson( """ { "findOneAndUpdate": { @@ -307,10 +278,7 @@ public void failOnIndexedTooBigObject() { insertEmptyDoc("object_size_too_big_doc"); final String objectJson = bigObject(DocumentLimitsConfig.DEFAULT_MAX_OBJECT_PROPERTIES + 10); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersAndJson( """ { "findOneAndUpdate": { @@ -338,10 +306,7 @@ public void failOnIndexedTooBigObject() { } private void insertEmptyDoc(String docId) { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersAndJson( """ { "insertOne": { @@ -361,7 +326,7 @@ private void insertEmptyDoc(String docId) { .body("status.insertedIds[0]", is(docId)); } - private final String bigArray(int elementCount) { + private String bigArray(int elementCount) { final ArrayNode array = MAPPER.createArrayNode(); for (int i = 0; i < elementCount; i++) { array.add(i); @@ -369,7 +334,7 @@ private final String bigArray(int elementCount) { return array.toString(); } - private final String bigObject(int propertyCount) { + private String bigObject(int propertyCount) { final ObjectNode ob = MAPPER.createObjectNode(); for (int i = 0; i < propertyCount; i++) { ob.put("prop" + i, i); diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindOneIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindOneIntegrationTest.java index 17b3ded3d9..66daf17d88 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindOneIntegrationTest.java +++ 
b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindOneIntegrationTest.java @@ -770,7 +770,7 @@ public void failWithInvalidEscape() { .body( "errors[0].message", containsString( - "Invalid filter expression: filter clause path ('pricing&price&aud') is not a valid path.")); + "Invalid filter expression: filter clause path ('pricing&price&aud') is not a valid path: ")); } } @@ -831,24 +831,6 @@ public void failForInvalidObjectIdAsId() { containsString( "Bad JSON Extension value: '$objectId' value has to be 24-digit hexadecimal ObjectId, instead got (\"bogus\")")); } - - // [data-api#1902] - $lexical not allowed in filter - @Test - public void failForFilteringOnLexical() { - for (String filter : - new String[] { - "{\"$lexical\": \"quick brown fox\"}", "{\"$lexical\": {\"eq\": \"quick brown fox\"}}" - }) { - givenHeadersPostJsonThenOk("{ \"findOne\": { \"filter\" : %s}}".formatted(filter)) - .body("$", responseIsError()) - .body("errors", hasSize(1)) - .body("errors[0].errorCode", is("INVALID_FILTER_EXPRESSION")) - .body( - "errors[0].message", - containsString( - "Cannot filter on lexical content field '$lexical': only 'sort' clause supported")); - } - } } @Nested diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindOneWithProjectionIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindOneWithProjectionIntegrationTest.java index 4c22a941b7..51a5d5f5a4 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindOneWithProjectionIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindOneWithProjectionIntegrationTest.java @@ -1,12 +1,10 @@ package io.stargate.sgv2.jsonapi.api.v1; -import static io.restassured.RestAssured.given; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsFindSuccess; import static net.javacrumbs.jsonunit.JsonMatchers.jsonEquals; import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; -import io.restassured.http.ContentType; 
import io.stargate.sgv2.jsonapi.testresource.DseTestResource; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Nested; @@ -68,14 +66,7 @@ public void byIdNestedExclusion() { } """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOkNoErrors(json) .body("$", responseIsFindSuccess()) .body( "data.document", @@ -107,14 +98,7 @@ public void byIdIncludeDates() { } """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOkNoErrors(json) .body("$", responseIsFindSuccess()) .body( "data.document", @@ -152,14 +136,7 @@ public void byIdExcludeDates() { } """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOkNoErrors(json) .body("$", responseIsFindSuccess()) .body( "data.document", @@ -200,10 +177,7 @@ class ProjectionWithJSONExtensions { @Test public void byIdDefaultProjection() { insertDoc(EXT_DOC1); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOkNoErrors( """ { "findOne": { @@ -211,10 +185,6 @@ public void byIdDefaultProjection() { } } """) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) .body("$", responseIsFindSuccess()) .body("data.document", jsonEquals(EXT_DOC1)); } @@ -222,10 +192,7 @@ public void byIdDefaultProjection() { @Test public void byIdIncludeExtValues() { insertDoc(EXT_DOC1); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOkNoErrors( """ { "findOne": { @@ -234,10 +201,6 @@ public 
void byIdIncludeExtValues() { } } """) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) .body("$", responseIsFindSuccess()) .body( "data.document", @@ -255,10 +218,7 @@ public void byIdIncludeExtValues() { @Test public void byIdExcludeExtValues() { insertDoc(EXT_DOC1); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOkNoErrors( """ { "findOne": { @@ -267,10 +227,6 @@ public void byIdExcludeExtValues() { } } """) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) .body("$", responseIsFindSuccess()) .body( "data.document", @@ -307,14 +263,7 @@ public void byIdRootSliceHead() { } """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOkNoErrors(json) .body("$", responseIsFindSuccess()) .body( "data.document", @@ -344,14 +293,7 @@ public void byIdRootSliceTail() { } """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOkNoErrors(json) .body("$", responseIsFindSuccess()) .body( "data.document", @@ -379,8 +321,9 @@ public void byIdRootSliceHeadOverrun() { insertDoc(DOC1_JSON); insertDoc(DOC2_JSON); insertDoc(DOC3_JSON); - String json = - """ + + givenHeadersPostJsonThenOkNoErrors( + """ { "findOne": { "filter" : {"_id" : "doc2"}, @@ -390,16 +333,7 @@ public void byIdRootSliceHeadOverrun() { } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body( "data.document", @@ -418,8 +352,9 @@ 
public void byIdRootSliceTail() { insertDoc(DOC1_JSON); insertDoc(DOC2_JSON); insertDoc(DOC3_JSON); - String json = - """ + + givenHeadersPostJsonThenOkNoErrors( + """ { "findOne": { "filter" : {"_id" : "doc2"}, @@ -431,16 +366,7 @@ public void byIdRootSliceTail() { } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body( "data.document", @@ -460,8 +386,8 @@ public void byIdNestedArraySliceHead() { insertDoc(DOC2_JSON); insertDoc(DOC3_JSON); - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "findOne": { "filter" : {"_id" : "doc2"}, @@ -473,16 +399,7 @@ public void byIdNestedArraySliceHead() { } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body( "data.document", diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindRerankingProvidersIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindRerankingProvidersIntegrationTest.java index 226cec05a7..a55bded50a 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindRerankingProvidersIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindRerankingProvidersIntegrationTest.java @@ -1,16 +1,18 @@ package io.stargate.sgv2.jsonapi.api.v1; -import static io.restassured.RestAssured.given; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsError; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsStatusOnly; import static org.hamcrest.Matchers.*; import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; -import io.restassured.http.ContentType; -import 
io.stargate.sgv2.jsonapi.service.provider.ModelSupport; +import io.stargate.sgv2.jsonapi.service.provider.ApiModelSupport; import io.stargate.sgv2.jsonapi.testresource.DseTestResource; +import java.util.stream.Stream; import org.junit.jupiter.api.*; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; @QuarkusIntegrationTest @WithTestResource(value = DseTestResource.class, restrictToAnnotatedClass = false) @@ -23,19 +25,13 @@ class FindRerankingProviders { @Test public final void defaultSupportModels() { // without option specified, only return supported models - String json = - """ + givenHeadersAndJson( + """ { "findRerankingProviders": { } } - """; - - given() - .port(getTestPort()) - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(GeneralResource.BASE_PATH) .then() @@ -47,80 +43,108 @@ public final void defaultSupportModels() { "status.rerankingProviders.nvidia.models[0].name", equalTo("nvidia/llama-3.2-nv-rerankqa-1b-v2")) .body( - "status.rerankingProviders.nvidia.models[0].modelSupport.status", - equalTo(ModelSupport.SupportStatus.SUPPORTED.name())); + "status.rerankingProviders.nvidia.models[0].apiModelSupport.status", + equalTo(ApiModelSupport.SupportStatus.SUPPORTED.name())); } - @Test - public final void filterByModelStatus() { - String json = - """ - { - "findRerankingProviders": { - "options": { - "includeModelStatus": [ - "DEPRECATED", - "END_OF_LIFE" - ] - } - } - } - """; + private static Stream returnedAllStatus() { + return Stream.of( + // empty string + Arguments.of("\"\""), + // null + Arguments.of("null")); + } - given() - .port(getTestPort()) - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + @ParameterizedTest() + @MethodSource("returnedAllStatus") + public final void returnModelsWithAllStatus(String filterModelStatus) { + givenHeadersAndJson( + """ + { +
"findRerankingProviders": { + "options": { + "filterModelStatus": %s + } + } + } + """ + .formatted(filterModelStatus)) .when() .post(GeneralResource.BASE_PATH) .then() .statusCode(200) .body("$", responseIsStatusOnly()) .body("status.rerankingProviders", notNullValue()) - .body("status.rerankingProviders.nvidia.models", hasSize(2)) + .body("status.rerankingProviders.nvidia.models", hasSize(3)) .body( "status.rerankingProviders.nvidia.models[0].name", equalTo("nvidia/a-random-EOL-model")) .body( - "status.rerankingProviders.nvidia.models[0].modelSupport.status", - equalTo(ModelSupport.SupportStatus.END_OF_LIFE.name())) + "status.rerankingProviders.nvidia.models[0].apiModelSupport.status", + equalTo(ApiModelSupport.SupportStatus.END_OF_LIFE.name())) .body( "status.rerankingProviders.nvidia.models[1].name", equalTo("nvidia/a-random-deprecated-model")) .body( - "status.rerankingProviders.nvidia.models[1].modelSupport.status", - equalTo(ModelSupport.SupportStatus.DEPRECATED.name())); + "status.rerankingProviders.nvidia.models[1].apiModelSupport.status", + equalTo(ApiModelSupport.SupportStatus.DEPRECATED.name())) + .body( + "status.rerankingProviders.nvidia.models[2].name", + equalTo("nvidia/llama-3.2-nv-rerankqa-1b-v2")) + .body( + "status.rerankingProviders.nvidia.models[2].apiModelSupport.status", + equalTo(ApiModelSupport.SupportStatus.SUPPORTED.name())); + } + + @Test + public final void returnModelsWithSpecifiedStatus() { + givenHeadersAndJson( + """ + { + "findRerankingProviders": { + "options": { + "filterModelStatus": "deprecated" + } + } + } + """) + .when() + .post(GeneralResource.BASE_PATH) + .then() + .statusCode(200) + .body("$", responseIsStatusOnly()) + .body("status.rerankingProviders", notNullValue()) + .body("status.rerankingProviders.nvidia.models", hasSize(1)) + .body( + "status.rerankingProviders.nvidia.models[0].name", + equalTo("nvidia/a-random-deprecated-model")) + .body( + "status.rerankingProviders.nvidia.models[0].apiModelSupport.status", + 
equalTo(ApiModelSupport.SupportStatus.DEPRECATED.name())); } @Test public final void failedWithRandomStatus() { - String json = - """ + givenHeadersAndJson( + """ { "findRerankingProviders": { "options": { - "includeModelStatus": [ - "random" - ] + "filterModelStatus": "random" } } } - """; - - given() - .port(getTestPort()) - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(GeneralResource.BASE_PATH) .then() - .statusCode(400) + .statusCode(200) .body("$", responseIsError()) - .body("errors[0].errorCode", is("INVALID_REQUEST_STRUCTURE_MISMATCH")) + .body("errors[0].errorCode", is("COMMAND_FIELD_INVALID")) .body( - "errors[0].message", containsString("not one of the values accepted for Enum class")); + "errors[0].message", + containsString( + "field 'command.options.filterModelStatus' value \"random\" not valid")); } } diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/GeneralResourceIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/GeneralResourceIntegrationTest.java index e61645b108..88bfc76ddb 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/GeneralResourceIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/GeneralResourceIntegrationTest.java @@ -27,7 +27,8 @@ class ClientErrors { @Test public void tokenMissing() { - given() + + given() // No headers added on purpose .contentType(ContentType.JSON) .body("{}") .when() @@ -43,10 +44,7 @@ public void tokenMissing() { @Test public void malformedBody() { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body("{wrong}") + givenHeadersAndJson("{wrong}") .when() .post(GeneralResource.BASE_PATH) .then() @@ -58,18 +56,13 @@ public void malformedBody() { @Test public void unknownCommand() { - String json = - """ + givenHeadersAndJson( + """ { "unknownCommand": { } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(GeneralResource.BASE_PATH) 
.then() @@ -84,9 +77,7 @@ public void unknownCommand() { @Test public void emptyBody() { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) + givenHeaders() .when() .post(GeneralResource.BASE_PATH) .then() diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/InAndNinIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/InAndNinIntegrationTest.java index a1fca91e8e..b4de265d32 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/InAndNinIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/InAndNinIntegrationTest.java @@ -1,13 +1,11 @@ package io.stargate.sgv2.jsonapi.api.v1; -import static io.restassured.RestAssured.given; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.*; import static net.javacrumbs.jsonunit.JsonMatchers.jsonEquals; import static org.hamcrest.Matchers.*; import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; -import io.restassured.http.ContentType; import io.stargate.sgv2.jsonapi.testresource.DseTestResource; import java.util.stream.Stream; import net.javacrumbs.jsonunit.ConfigurableJsonMatcher; @@ -22,15 +20,7 @@ class InAndNinIntegrationTest extends AbstractCollectionIntegrationTestBase { private void insert(String json) { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) - .body("$", responseIsWriteSuccess()); + givenHeadersPostJsonThenOkNoErrors(json).body("$", responseIsWriteSuccess()); } private ConfigurableJsonMatcher[] getJsonEquals(int... docs) { @@ -84,89 +74,89 @@ private ConfigurableJsonMatcher[] getJsonEquals(int... 
docs) { public void setUp() { insert( """ - { - "insertOne": { - "document": { - "_id": "doc1", - "username": "user1", - "active_user" : true, - "date" : {"$date": 1672531200000}, - "age" : 20, - "null_column": null - } - } - } - """); + { + "insertOne": { + "document": { + "_id": "doc1", + "username": "user1", + "active_user" : true, + "date" : {"$date": 1672531200000}, + "age" : 20, + "null_column": null + } + } + } + """); insert( """ - { - "insertOne": { - "document": { - "_id": "doc2", - "username": "user2", - "subdoc" : { - "id" : "abc" - }, - "array" : [ - "value1" - ] - } - } - } - """); + { + "insertOne": { + "document": { + "_id": "doc2", + "username": "user2", + "subdoc" : { + "id" : "abc" + }, + "array" : [ + "value1" + ] + } + } + } + """); insert( """ - { - "insertOne": { - "document": { - "_id": "doc3", - "username": "user3", - "tags" : ["tag1", "tag2", "tag1234567890123456789012345", null, 1, true], - "nestedArray" : [["tag1", "tag2"], ["tag1234567890123456789012345", null]] - } - } - } - """); + { + "insertOne": { + "document": { + "_id": "doc3", + "username": "user3", + "tags" : ["tag1", "tag2", "tag1234567890123456789012345", null, 1, true], + "nestedArray" : [["tag1", "tag2"], ["tag1234567890123456789012345", null]] + } + } + } + """); insert( """ - { - "insertOne": { - "document": { - "_id": "doc4", - "username" : "user4", - "indexedObject" : { "0": "value_0", "1": "value_1" } - } - } - } - """); + { + "insertOne": { + "document": { + "_id": "doc4", + "username" : "user4", + "indexedObject" : { "0": "value_0", "1": "value_1" } + } + } + } + """); insert( """ - { - "insertOne": { - "document": { - "_id": "doc5", - "username": "user5", - "sub_doc" : { "a": 5, "b": { "c": "v1", "d": false } } - } - } - } - """); + { + "insertOne": { + "document": { + "_id": "doc5", + "username": "user5", + "sub_doc" : { "a": 5, "b": { "c": "v1", "d": false } } + } + } + } + """); insert( """ - { - "insertOne": { - "document": { - "_id": {"$date": 6}, - "username": 
"user6" - } - } - } - """); + { + "insertOne": { + "document": { + "_id": {"$date": 6}, + "username": "user6" + } + } + } + """); } @Nested @@ -176,37 +166,29 @@ class In { @Test public void inCondition() { - String json = - """ + // findOne resolves any one of the resolved documents. So the order of the documents in the + // $in clause is not guaranteed. + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"_id" : {"$in": ["doc1", "doc4"]}} } } - """; - - // findOne resolves any one of the resolved documents. So the order of the documents in the - // $in clause is not guaranteed. - String expected1 = - """ - {"_id":"doc1", "username":"user1", "active_user":true, "date" : {"$date": 1672531200000}, "age" : 20, "null_column": null} - """; - String expected2 = - """ - {"_id":"doc4", "username":"user4", "indexedObject":{"0":"value_0","1":"value_1"}} - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(2)) - .body("data.documents", containsInAnyOrder(jsonEquals(expected1), jsonEquals(expected2))); + .body( + "data.documents", + containsInAnyOrder( + jsonEquals( + """ + {"_id":"doc1", "username":"user1", "active_user":true, "date" : {"$date": 1672531200000}, "age" : 20, "null_column": null} + """), + jsonEquals( + """ + {"_id":"doc4", "username":"user4", "indexedObject":{"0":"value_0","1":"value_1"}} + """))); } private static Stream IN_FOR_ID_WITH_LIMIT() { @@ -221,137 +203,102 @@ private static Stream IN_FOR_ID_WITH_LIMIT() { @ParameterizedTest @MethodSource("IN_FOR_ID_WITH_LIMIT") public void inForIdWithLimit(String filter, int limit, int expected) { - String json = + givenHeadersPostJsonThenOkNoErrors( + """ + { + "find": { + "filter" : %s, + "options": {"limit": %s} + } + } """ - { - "find": { - "filter" : %s, - "options": {"limit": 
%s} - } - } - """ - .formatted(filter, limit); - givenHeadersPostJsonThenOkNoErrors(json) + .formatted(filter, limit)) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(expected)); } @Test public void inConditionWithSubDoc() { - String json = - """ - { - "find": { - "filter" : {"sub_doc" : {"$in" : [{ "a": 5, "b": { "c": "v1", "d": false }}]} } - } - } - """; - - String expected5 = - """ + givenHeadersPostJsonThenOkNoErrors( + """ + { + "find": { + "filter" : {"sub_doc" : {"$in" : [{ "a": 5, "b": { "c": "v1", "d": false }}]} } + } + } + """) + .body("$", responseIsFindSuccess()) + .body("data.documents", hasSize(1)) + .body( + "data.documents[0]", + jsonEquals( + """ { "_id": "doc5", "username": "user5", "sub_doc" : { "a": 5, "b": { "c": "v1", "d": false } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) - .body("$", responseIsFindSuccess()) - .body("data.documents", hasSize(1)) - .body("data.documents[0]", jsonEquals(expected5)); + """)); } @Test public void inConditionWithArray() { - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"array" : {"$in" : [["value1"]] } } } } - """; - - String expected = - """ - {"_id":"doc2", "username":"user2", "subdoc":{"id":"abc"},"array":["value1"]} - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(1)) - .body("data.documents[0]", jsonEquals(expected)); + .body( + "data.documents[0]", + jsonEquals( + """ + {"_id":"doc2", "username":"user2", "subdoc":{"id":"abc"},"array":["value1"]} + """)); } @Test public void inConditionWithOtherCondition() { - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + 
""" { "find": { "filter" : {"_id" : {"$in": ["doc1", "doc4"]}, "username" : "user1" } } } - """; - String expected1 = - """ - {"_id":"doc1", "username":"user1", "active_user":true, "date" : {"$date": 1672531200000}, "age" : 20, "null_column": null} - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(1)) - .body("data.documents[0]", jsonEquals(expected1)); + .body( + "data.documents[0]", + jsonEquals( + """ + {"_id":"doc1", "username":"user1", "active_user":true, "date" : {"$date": 1672531200000}, "age" : 20, "null_column": null} + """)); } @Test public void idInConditionEmptyArray() { - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"_id" : {"$in": []}} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(0)); } @Test public void nonIDInConditionEmptyArray() { - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : { @@ -359,23 +306,15 @@ public void nonIDInConditionEmptyArray() { } } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(0)); } @Test public void nonIDInConditionEmptyArrayAnd() { - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : { @@ -392,23 +331,15 @@ public void nonIDInConditionEmptyArrayAnd() { } } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - 
.body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(0)); } @Test public void nonIDInConditionEmptyArrayOr() { - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : { @@ -425,66 +356,41 @@ public void nonIDInConditionEmptyArrayOr() { } } } - """; - String expected1 = - """ - {"_id":"doc1", "username":"user1", "active_user":true, "date" : {"$date": 1672531200000}, "age" : 20, "null_column": null} - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(1)) - .body("data.documents[0]", jsonEquals(expected1)); + .body( + "data.documents[0]", + jsonEquals( + """ + {"_id":"doc1", "username":"user1", "active_user":true, "date" : {"$date": 1672531200000}, "age" : 20, "null_column": null} + """)); } @Test public void inOperatorEmptyArrayWithAdditionalFilters() { - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"username": "user1", "_id" : {"$in": []}} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(0)); } @Test public void inConditionNonArrayArray() { - String json = - """ + givenHeadersPostJsonThenOk( + """ { "find": { "filter" : {"_id" : {"$in": true}} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsError()) .body("errors", hasSize(1)) 
.body("errors[0].exceptionClass", is("JsonApiException")) @@ -494,8 +400,8 @@ public void inConditionNonArrayArray() { @Test public void inConditionNonIdField() { - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : { @@ -503,28 +409,21 @@ public void inConditionNonIdField() { } } } - """; - String expected1 = - """ - {"_id":"doc1", "username":"user1", "active_user":true, "date" : {"$date": 1672531200000}, "age" : 20, "null_column": null} - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(1)) - .body("data.documents[0]", jsonEquals(expected1)); + .body( + "data.documents[0]", + jsonEquals( + """ + {"_id":"doc1", "username":"user1", "active_user":true, "date" : {"$date": 1672531200000}, "age" : 20, "null_column": null} + """)); } @Test public void inConditionNonIdFieldMulti() { - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : { @@ -532,16 +431,7 @@ public void inConditionNonIdFieldMulti() { } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(2)) .body("data.documents", containsInAnyOrder(getJsonEquals(1, 4))); @@ -549,93 +439,72 @@ public void inConditionNonIdFieldMulti() { @Test public void inConditionNonIdFieldIdField() { - String json = - """ - { - "find": { - "filter" : { - "username" : {"$in" : ["user1", "user10"]}, - "_id" : {"$in" : ["doc1", "???"]} - } - } - } - """; - String expected1 = - """ - {"_id":"doc1", "username":"user1", "active_user":true, "date" : {"$date": 1672531200000}, "age" : 20, "null_column": null} - """; - given() - 
.headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOkNoErrors( + """ + { + "find": { + "filter" : { + "username" : {"$in" : ["user1", "user10"]}, + "_id" : {"$in" : ["doc1", "???"]} + } + } + } + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(1)) - .body("data.documents[0]", jsonEquals(expected1)); + .body( + "data.documents[0]", + jsonEquals( + """ + {"_id":"doc1", "username":"user1", "active_user":true, "date" : {"$date": 1672531200000}, "age" : 20, "null_column": null} + """)); } @Test public void inConditionNonIdFieldIdFieldSort() { - String json = - """ - { - "find": { - "filter" : { - "username" : {"$in" : ["user1", "user10"]}, - "_id" : {"$in" : ["doc1", "???"]} - }, - "sort": { "username": -1 } - } - } - """; - String expected1 = - """ - {"_id":"doc1", "username":"user1", "active_user":true, "date" : {"$date": 1672531200000}, "age" : 20, "null_column": null} - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOkNoErrors( + """ + { + "find": { + "filter" : { + "username" : {"$in" : ["user1", "user10"]}, + "_id" : {"$in" : ["doc1", "???"]} + }, + "sort": { "username": -1 } + } + } + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(1)) - .body("data.documents[0]", jsonEquals(expected1)); + .body( + "data.documents[0]", + jsonEquals( + """ + {"_id":"doc1", "username":"user1", "active_user":true, "date" : {"$date": 1672531200000}, "age" : 20, "null_column": null} + """)); } @Test public void inConditionWithDuplicateValues() { - String json = - """ - { - "find": { - "filter" : { - "username" : {"$in" : ["user1", "user1"]}, - "_id" : {"$in" : ["doc1", "???"]} - } - } - } - """; - String expected1 = 
- """ - {"_id":"doc1", "username":"user1", "active_user":true, "date" : {"$date": 1672531200000}, "age" : 20, "null_column": null} - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOkNoErrors( + """ + { + "find": { + "filter" : { + "username" : {"$in" : ["user1", "user1"]}, + "_id" : {"$in" : ["doc1", "???"]} + } + } + } + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(1)) - .body("data.documents[0]", jsonEquals(expected1)); + .body( + "data.documents[0]", + jsonEquals( + """ + {"_id":"doc1", "username":"user1", "active_user":true, "date" : {"$date": 1672531200000}, "age" : 20, "null_column": null} + """)); } } @@ -646,52 +515,34 @@ class Nin { @Test public void nonIdSimpleNinCondition() { - String json = - """ - { - "find": { - "filter" : {"username" : {"$nin": ["user2", "user3","user4","user5","user6"]}} - } - } - """; - - String expected1 = - """ - {"_id":"doc1", "username":"user1", "active_user":true, "date" : {"$date": 1672531200000}, "age" : 20, "null_column": null} - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOkNoErrors( + """ + { + "find": { + "filter" : {"username" : {"$nin": ["user2", "user3","user4","user5","user6"]}} + } + } + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(1)) - .body("data.documents[0]", jsonEquals(expected1)); + .body( + "data.documents[0]", + jsonEquals( + """ + {"_id":"doc1", "username":"user1", "active_user":true, "date" : {"$date": 1672531200000}, "age" : 20, "null_column": null} + """)); } @Test public void ninConditionWithSubDoc() { - String json = - """ - { - "find": { - "filter" : {"sub_doc" : {"$nin": [{ "a": 5, "b": { "c": "v1", 
"d": false } }]}} - } - } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOkNoErrors( + """ + { + "find": { + "filter" : {"sub_doc" : {"$nin": [{ "a": 5, "b": { "c": "v1", "d": false } }]}} + } + } + """) .body("$", responseIsFindSuccess()) // except doc 5 .body("data.documents", hasSize(5)) @@ -701,23 +552,14 @@ public void ninConditionWithSubDoc() { @Test public void ninConditionWithArray() { - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"array" : {"$nin" : [["value1"]] } } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) // except doc 2 .body("data.documents", hasSize(5)) @@ -727,50 +569,31 @@ public void ninConditionWithArray() { @Test public void nonIdNinEmptyArray() { - String json = - """ + // should find everything + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"username" : {"$nin": []}} } } - """; - - // should find everything - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(6)) .body("data.documents", containsInAnyOrder(getJsonEquals(1, 2, 3, 4, 5, 6))); - ; } @Test public void idNinEmptyArray() { - String json = - """ + // should find everything + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"_id" : {"$nin": []}} } } - """; - - // should find everything - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, 
collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(6)) .body("data.documents", containsInAnyOrder(getJsonEquals(1, 2, 3, 4, 5, 6))); @@ -784,32 +607,24 @@ class Combination { @Test public void nonIdInEmptyAndNonIdNinEmptyAnd() { - String json = - """ + // should find nothing + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"username" : {"$in": []}, "age": {"$nin" : []}} } } - """; - - // should find nothing - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(0)); } @Test public void nonIdInEmptyOrNonIdNinEmptyOr() { - String json = - """ + // should find everything + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" :{ @@ -821,17 +636,7 @@ public void nonIdInEmptyOrNonIdNinEmptyOr() { } } } - """; - - // should find everything - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(6)) .body("data.documents", containsInAnyOrder(getJsonEquals(1, 2, 3, 4, 5, 6))); @@ -839,24 +644,15 @@ public void nonIdInEmptyOrNonIdNinEmptyOr() { @Test public void nonIdInEmptyAndIdNinEmptyAnd() { - String json = - """ + // should find nothing + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"username" : {"$in": []}, "_id": {"$nin" : []}} } } - """; - - // should find nothing - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(0)); } diff --git 
a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/IndexingConfigIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/IndexingConfigIntegrationTest.java index e108c8f5ff..4ff7db0f2b 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/IndexingConfigIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/IndexingConfigIntegrationTest.java @@ -1,13 +1,11 @@ package io.stargate.sgv2.jsonapi.api.v1; -import static io.restassured.RestAssured.given; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsError; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsFindSuccess; import static org.hamcrest.Matchers.*; import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; -import io.restassured.http.ContentType; import io.stargate.sgv2.jsonapi.testresource.DseTestResource; import org.junit.jupiter.api.*; @@ -170,22 +168,16 @@ class IndexingConfig { @Test public void filterFieldInDenyOne() { // explicitly deny "address.city", implicitly allow "_id", "name", "address.street" - String filterData = - """ + givenHeadersPostJsonThenOk( + keyspaceName, + denyOneIndexingCollection, + """ { "find": { "filter": {"address.city": "monkey town"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(filterData) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, denyOneIndexingCollection) - .then() - .statusCode(200) + """) .body("$", responseIsError()) .body("errors[0].message", endsWith("filter path 'address.city' is not indexed")) .body("errors[0].errorCode", is("UNINDEXED_FILTER_PATH")) @@ -195,22 +187,16 @@ public void filterFieldInDenyOne() { @Test public void filterVectorFieldInDenyAll() { // explicitly deny "address.city", implicitly allow "_id", "name", "address.street" - String filterData = - """ - { - "find": { - "filter": {"$vector": {"$exists": true}} - } - } - """; - given() - .headers(getHeaders()) - 
.contentType(ContentType.JSON) - .body(filterData) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, denyAllIndexingCollection) - .then() - .statusCode(200) + givenHeadersPostJsonThenOk( + keyspaceName, + denyAllIndexingCollection, + """ + { + "find": { + "filter": {"$vector": {"$exists": true}} + } + } + """) .body("$", responseIsError()) .body("errors[0].message", endsWith("filter path '$vector' is not indexed")) .body("errors[0].errorCode", is("UNINDEXED_FILTER_PATH")) @@ -220,18 +206,14 @@ public void filterVectorFieldInDenyAll() { @Test public void filterFieldNotInDenyOne() { // explicitly deny "address.city", implicitly allow "_id", "name", "address.street" - String filterData1 = - """ + givenHeadersAndJson( + """ { "find": { "filter": {"name": "aaron"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(filterData1) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, denyOneIndexingCollection) .then() @@ -239,8 +221,8 @@ public void filterFieldNotInDenyOne() { .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(1)); // deny "address.city", only this as a string, not "address" as an object - String filterData2 = - """ + givenHeadersAndJson( + """ { "find": { "filter": { @@ -252,19 +234,15 @@ public void filterFieldNotInDenyOne() { } } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(filterData2) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, denyOneIndexingCollection) .then() .statusCode(200) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(1)); - String filterData3 = - """ + givenHeadersAndJson( + """ { "find": { "filter": { @@ -277,11 +255,7 @@ public void filterFieldNotInDenyOne() { } } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(filterData3) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, denyOneIndexingCollection) .then() @@ -296,18 +270,14 @@ public 
void filterFieldNotInDenyOne() { public void filterFieldInDenyMany() { // explicitly deny "name", "address", implicitly allow "_id" // deny "address", "address.city" should also be included - String filterData = - """ + givenHeadersAndJson( + """ { "find": { "filter": {"address.city": "monkey town"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(filterData) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, denyManyIndexingCollection) .then() @@ -320,19 +290,14 @@ public void filterFieldInDenyMany() { @Test public void filterFieldInDenyAll() { - // deny all use "*" - String filterData = - """ + givenHeadersAndJson( + """ { "find": { "filter": {"address.city": "monkey town"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(filterData) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, denyAllIndexingCollection) .then() @@ -345,19 +310,14 @@ public void filterFieldInDenyAll() { @Test public void filterIdInDenyAllWithEqAndIn() { - // deny all use "*" - String filterId1 = - """ + givenHeadersAndJson( + """ { "find": { "filter": {"_id": "1"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(filterId1) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, denyAllIndexingCollection) .then() @@ -365,8 +325,8 @@ public void filterIdInDenyAllWithEqAndIn() { .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(1)); - String filterId2 = - """ + givenHeadersAndJson( + """ { "find": { "filter": { @@ -379,11 +339,7 @@ public void filterIdInDenyAllWithEqAndIn() { } } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(filterId2) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, denyAllIndexingCollection) .then() @@ -394,9 +350,8 @@ public void filterIdInDenyAllWithEqAndIn() { @Test public void filterIdInDenyAllWithoutEqAndIn() { - // deny all use "*" - String 
filterId3 = - """ + givenHeadersAndJson( + """ { "find": { "filter": { @@ -409,11 +364,7 @@ public void filterIdInDenyAllWithoutEqAndIn() { } } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(filterId3) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, denyAllIndexingCollection) .then() @@ -429,18 +380,14 @@ public void filterIdInDenyAllWithoutEqAndIn() { @Test public void filterFieldInAllowOne() { // explicitly allow "name", implicitly deny "_id" "address" - String filterData = - """ + givenHeadersAndJson( + """ { "find": { "filter": {"name": "aaron"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(filterData) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, allowOneIndexingCollection) .then() @@ -452,8 +399,8 @@ public void filterFieldInAllowOne() { @Test public void filterFieldNotInAllowOne() { // explicitly allow "name", implicitly deny "_id" "address.city" "address.street" "address" - String filterData1 = - """ + givenHeadersAndJson( + """ { "find": { "filter": { @@ -466,11 +413,7 @@ public void filterFieldNotInAllowOne() { } } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(filterData1) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, allowOneIndexingCollection) .then() @@ -479,8 +422,8 @@ public void filterFieldNotInAllowOne() { .body("errors[0].message", endsWith("filter path 'address' is not indexed")) .body("errors[0].errorCode", is("UNINDEXED_FILTER_PATH")) .body("errors[0].exceptionClass", is("JsonApiException")); - String filterData2 = - """ + givenHeadersAndJson( + """ { "find": { "filter": { @@ -493,11 +436,7 @@ public void filterFieldNotInAllowOne() { } } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(filterData2) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, allowOneIndexingCollection) .then() @@ -508,18 +447,14 @@ public void 
filterFieldNotInAllowOne() { .body( "errors[0].message", is("_id is not indexed: you can only use $eq or $in as the operator")); - String filterData3 = - """ + givenHeadersAndJson( + """ { "find": { "filter": {"_id": "1"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(filterData3) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, allowOneIndexingCollection) .then() @@ -531,8 +466,8 @@ public void filterFieldNotInAllowOne() { @Test public void filterFieldInAllowMany() { // explicitly allow "name" "address.city", implicitly deny "_id" "address.street" - String filterData = - """ + givenHeadersAndJson( + """ { "find": { "filter": { @@ -561,11 +496,7 @@ public void filterFieldInAllowMany() { } } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(filterData) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, allowManyIndexingCollection) .then() @@ -578,8 +509,8 @@ public void filterFieldInAllowMany() { public void filterFieldNotInAllowMany() { // explicitly allow "name" "address.city", implicitly deny "_id" "address.street" "address" // _id is allowed using in - String filterData1 = - """ + givenHeadersAndJson( + """ { "find": { "filter": { @@ -606,11 +537,7 @@ public void filterFieldNotInAllowMany() { } } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(filterData1) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, allowManyIndexingCollection) .then() @@ -620,8 +547,8 @@ public void filterFieldNotInAllowMany() { .body("errors[0].errorCode", is("UNINDEXED_FILTER_PATH")) .body("errors[0].exceptionClass", is("JsonApiException")); // allow "address.city", only this as a string, not "address" as an object - String filterData2 = - """ + givenHeadersAndJson( + """ { "find": { "filter": { @@ -634,12 +561,7 @@ public void filterFieldNotInAllowMany() { } } } - """; - - given() - .headers(getHeaders()) - 
.contentType(ContentType.JSON) - .body(filterData2) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, allowManyIndexingCollection) .then() @@ -655,8 +577,8 @@ public void incrementalPathInArray() { // explicitly deny "address.city", implicitly allow "_id", "name", "address.street" "address" // String and array in array - no incremental path, the path is "address" - should be allowed // but no data return - String filterData1 = - """ + givenHeadersAndJson( + """ { "find": { "filter": { @@ -666,11 +588,7 @@ public void incrementalPathInArray() { } } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(filterData1) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, denyOneIndexingCollection) .then() @@ -680,8 +598,8 @@ public void incrementalPathInArray() { .body("errors[0].errorCode", is("UNINDEXED_FILTER_PATH")) .body("errors[0].exceptionClass", is("JsonApiException")); // explicitly deny "address.city", implicitly allow "_id", "name", "address.street" "address" - String filterData2 = - """ + givenHeadersAndJson( + """ { "find": { "filter": { @@ -691,11 +609,7 @@ public void incrementalPathInArray() { } } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(filterData2) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, denyOneIndexingCollection) .then() @@ -704,8 +618,8 @@ public void incrementalPathInArray() { .body("data.documents", hasSize(1)); // explicitly deny "address.city", implicitly allow "_id", "name", "address.street" "address" // Object (Hashmap) in array - incremental path is "address.city" - String filterData3 = - """ + givenHeadersAndJson( + """ { "find": { "filter": { @@ -719,11 +633,7 @@ public void incrementalPathInArray() { } } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(filterData3) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, denyOneIndexingCollection) .then() @@ -733,8 +643,8 
@@ public void incrementalPathInArray() { .body("errors[0].errorCode", is("UNINDEXED_FILTER_PATH")) .body("errors[0].exceptionClass", is("JsonApiException")); // explicitly deny "name", "address" "contact.email" - String filterData4 = - """ + givenHeadersAndJson( + """ { "find": { "filter": { @@ -751,11 +661,7 @@ public void incrementalPathInArray() { } } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(filterData4) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, denyManyIndexingCollection) .then() @@ -768,10 +674,8 @@ public void incrementalPathInArray() { @Test public void incrementalPathInMap() { - // explicitly deny "address.city", implicitly allow "_id", "name", "address.street" "address" - // map in map - String filterData1 = - """ + givenHeadersAndJson( + """ { "find": { "filter": { @@ -785,11 +689,7 @@ public void incrementalPathInMap() { } } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(filterData1) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, denyOneIndexingCollection) .then() @@ -802,9 +702,8 @@ public void incrementalPathInMap() { @Test public void sortFieldInAllowMany() { - // explicitly deny "name", "address", implicitly allow "_id", "$vector" - String sortData = - """ + givenHeadersAndJson( + """ { "find": { "sort": { @@ -812,11 +711,7 @@ public void sortFieldInAllowMany() { } } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(sortData) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, denyManyIndexingCollection) .then() @@ -829,8 +724,8 @@ public void sortFieldInAllowMany() { public void sortFieldNotInAllowMany() { // explicitly allow "name" "address.city", implicitly deny "_id" "address.street"; // (and implicitly allow "$vector" as well) - String sortData = - """ + givenHeadersAndJson( + """ { "find": { "sort": { @@ -838,11 +733,7 @@ public void sortFieldNotInAllowMany() { } } } - """; - 
given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(sortData) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, allowManyIndexingCollection) .then() @@ -856,10 +747,7 @@ public void sortFieldNotInAllowMany() { @Test public void fieldNameWithDot() { // allow "pricing.price&.usd", so one document is returned - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersAndJson( """ { "find": { @@ -876,34 +764,28 @@ public void fieldNameWithDot() { .body("data.documents", hasSize(1)); // allow "pricing.price&&jpy", but the path is not escaped, so no document is returned - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersAndJson( """ - { - "find": { - "filter": { - "pricing.price&jpy": 1 - } - } - } - """) + { + "find": { + "filter": { + "pricing.price&jpy": 1 + } + } + } + """) .post(CollectionResource.BASE_PATH, keyspaceName, allowManyIndexingCollection) .then() .statusCode(200) .body("$", responseIsError()) .body( "errors[0].message", - containsString("filter clause path ('pricing.price&jpy') is not a valid path.")) + containsString("filter clause path ('pricing.price&jpy') is not a valid path: ")) .body("errors[0].errorCode", is("INVALID_FILTER_EXPRESSION")) .body("errors[0].exceptionClass", is("JsonApiException")); // allow "metadata.app&.kubernetes&.io/name", so one document is returned - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersAndJson( """ { "find": { @@ -921,10 +803,7 @@ public void fieldNameWithDot() { // did not allow "pricing.price&.aud", even though the path is escaped, no document is // returned - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersAndJson( """ { "find": { diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/InsertInCollectionIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/InsertInCollectionIntegrationTest.java index 
667b27148a..a8332acacb 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/InsertInCollectionIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/InsertInCollectionIntegrationTest.java @@ -1,6 +1,5 @@ package io.stargate.sgv2.jsonapi.api.v1; -import static io.restassured.RestAssured.given; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.*; import static net.javacrumbs.jsonunit.JsonMatchers.jsonEquals; import static org.assertj.core.api.Assertions.assertThat; @@ -13,7 +12,6 @@ import com.fasterxml.jackson.databind.node.ObjectNode; import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; -import io.restassured.http.ContentType; import io.restassured.response.Response; import io.stargate.sgv2.jsonapi.config.DocumentLimitsConfig; import io.stargate.sgv2.jsonapi.config.OperationsConfig; @@ -502,10 +500,7 @@ public void insertDocWithAutoObjectIdKey() { } """; Response response = - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body("{ \"insertOne\": { \"document\": %s }}".formatted(doc)) + givenHeadersAndJson("{ \"insertOne\": { \"document\": %s }}".formatted(doc)) .when() .post(CollectionResource.BASE_PATH, keyspaceName, COLLECTION_WITH_AUTO_OBJECTID) .then() @@ -525,10 +520,7 @@ public void insertDocWithAutoObjectIdKey() { assertThat(objectId).isNotNull(); // And with that, we should be able to find the document - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersAndJson( "{\"find\": { \"filter\" : {\"_id\": {\"$objectId\":\"%s\"}}}}" .formatted(objectId.toString())) .when() @@ -1576,10 +1568,7 @@ public void orderedFailBadKeyspace() { } """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + givenHeadersAndJson(json) .when() .post(CollectionResource.BASE_PATH, "something_else", collectionName) .then() diff --git 
a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/InsertLexicalInCollectionIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/InsertLexicalInCollectionIntegrationTest.java index c2f45cc0a0..c1b1c8ebef 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/InsertLexicalInCollectionIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/InsertLexicalInCollectionIntegrationTest.java @@ -12,6 +12,7 @@ import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; +import io.stargate.sgv2.jsonapi.fixtures.TestTextUtil; import io.stargate.sgv2.jsonapi.testresource.DseTestResource; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.ClassOrderer; @@ -27,6 +28,8 @@ @TestClassOrder(ClassOrderer.OrderAnnotation.class) public class InsertLexicalInCollectionIntegrationTest extends AbstractCollectionIntegrationTestBase { + private static final int MAX_LEXICAL_LENGTH = 8192; + protected InsertLexicalInCollectionIntegrationTest() { super("col_insert_lexical_"); } @@ -163,6 +166,30 @@ public void insertDocWithLexicalOk() { """)); } + @Test + public void insertDocWithLongestLexicalOk() { + final String docId = "lexical-long-ok"; + final String text = TestTextUtil.generateTextDoc(MAX_LEXICAL_LENGTH, " "); + String doc = + """ + { + "_id": "%s", + "$lexical": "%s" + } + """ + .formatted(docId, text); + + givenHeadersPostJsonThenOkNoErrors("{ \"insertOne\": { \"document\": %s }}".formatted(doc)) + .body("$", responseIsWriteSuccess()) + .body("status.insertedIds[0]", is(docId)); + + givenHeadersPostJsonThenOkNoErrors( + "{ \"find\": { \"filter\" : { \"_id\" : \"%s\"}}}}".formatted(docId)) + .body("$", responseIsFindSuccess()) + // NOTE: "$lexical" is not included in the response by default, ensure + .body("data.documents", hasSize(1)); + } + @Test public void failInsertDocWithLexicalIfNotEnabled() { final String 
COLLECTION_WITHOUT_LEXICAL = @@ -200,6 +227,28 @@ public void failInsertDocWithLexicalIfNotEnabled() { // And finally, drop the Collection after use deleteCollection(COLLECTION_WITHOUT_LEXICAL); } + + @Test + public void failInsertDocWithTooLongLexical() { + final String docId = "lexical-too-long"; + // Limit not based on the length of the string, but on total length of unique + // tokens. So need to guesstimate "too big" size + final String text = TestTextUtil.generateTextDoc((int) (MAX_LEXICAL_LENGTH * 1.5), " "); + String doc = + """ + { + "_id": "%s", + "$lexical": "%s" + } + """ + .formatted(docId, text); + + givenHeadersPostJsonThenOk("{ \"insertOne\": { \"document\": %s }}".formatted(doc)) + .body("$", responseIsWritePartialSuccess()) + .body("errors", hasSize(1)) + .body("errors[0].errorCode", is(ErrorCodeV1.LEXICAL_CONTENT_TOO_BIG.name())) + .body("errors[0].message", containsString("Lexical content is too big")); + } } @DisabledIfSystemProperty(named = TEST_PROP_LEXICAL_DISABLED, matches = "true") diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/KeyspaceResourceIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/KeyspaceResourceIntegrationTest.java index 61ef5ab546..fa35e85ef2 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/KeyspaceResourceIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/KeyspaceResourceIntegrationTest.java @@ -20,7 +20,7 @@ class ClientErrors { @Test public void tokenMissing() { - given() + given() // Headers omitted on purpose .contentType(ContentType.JSON) .body("{}") .when() @@ -36,10 +36,7 @@ public void tokenMissing() { @Test public void malformedBody() { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body("{wrong}") + givenHeadersAndJson("{wrong}") .when() .post(KeyspaceResource.BASE_PATH, keyspaceName) .then() @@ -51,18 +48,13 @@ public void malformedBody() { @Test public void unknownCommand() { - String json = - """ + givenHeadersAndJson( + """ 
{ "unknownCommand": { } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(KeyspaceResource.BASE_PATH, keyspaceName) .then() @@ -77,9 +69,7 @@ public void unknownCommand() { @Test public void emptyBody() { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) + givenHeaders() .when() .post(KeyspaceResource.BASE_PATH, keyspaceName) .then() diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/LwtRetryIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/LwtRetryIntegrationTest.java index c6b734f344..e18caccb5d 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/LwtRetryIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/LwtRetryIntegrationTest.java @@ -1,6 +1,5 @@ package io.stargate.sgv2.jsonapi.api.v1; -import static io.restassured.RestAssured.given; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsFindSuccess; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsStatusOnly; import static org.hamcrest.Matchers.anyOf; @@ -9,7 +8,6 @@ import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; -import io.restassured.http.ContentType; import io.stargate.sgv2.jsonapi.testresource.DseTestResource; import java.util.concurrent.CountDownLatch; import org.junit.jupiter.api.AfterEach; @@ -21,38 +19,13 @@ public class LwtRetryIntegrationTest extends AbstractCollectionIntegrationTestBa @RepeatedTest(10) public void mixedOperations() throws Exception { - String document = + insertDoc( """ { "_id": "doc1", "count": 0 } - """; - insertDoc(document); - - String delete = - """ - { - "deleteOne": { - "filter": { - "_id": "doc1" - } - } - } - """; - String update = - """ - { - "updateOne": { - "filter": { - "_id": "doc1" - }, - "update" : { - "$inc": {"count": 1} - } - } - } - """; + """); CountDownLatch latch = new CountDownLatch(2); @@ -60,14 +33,19 @@ public void 
mixedOperations() throws Exception { // but for update we might see delete before read, before update or after new Thread( () -> { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(update) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOkNoErrors( + """ + { + "updateOne": { + "filter": { + "_id": "doc1" + }, + "update" : { + "$inc": {"count": 1} + } + } + } + """) .body("$", responseIsStatusOnly()) .body("status.matchedCount", anyOf(is(0), is(1))) .body("status.modifiedCount", anyOf(is(0), is(1))); @@ -78,14 +56,16 @@ public void mixedOperations() throws Exception { new Thread( () -> { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(delete) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOkNoErrors( + """ + { + "deleteOne": { + "filter": { + "_id": "doc1" + } + } + } + """) .body("$", responseIsStatusOnly()) .body("status.deletedCount", is(1)); @@ -96,21 +76,13 @@ public void mixedOperations() throws Exception { latch.await(); // ensure there's nothing left - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", is(empty())); } diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/PaginationIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/PaginationIntegrationTest.java index 08d9b92966..21ada69ddc 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/PaginationIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/PaginationIntegrationTest.java @@ -1,13 +1,11 @@ package io.stargate.sgv2.jsonapi.api.v1; -import 
static io.restassured.RestAssured.given; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsFindSuccess; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsWriteSuccess; import static org.hamcrest.Matchers.*; import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; -import io.restassured.http.ContentType; import io.stargate.sgv2.jsonapi.testresource.DseTestResource; import org.junit.jupiter.api.*; @@ -43,71 +41,45 @@ public void setUp() { } private void insert(String json) { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) - .body("$", responseIsWriteSuccess()); + givenHeadersPostJsonThenOkNoErrors(json).body("$", responseIsWriteSuccess()); } @Test @Order(2) public void threePagesCheck() { - String json = - """ + String nextPageState = + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { } } - """; - - String nextPageState = - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(defaultPageSize)) .extract() .path("data.nextPageState"); - String json1 = - """ - { - "find": { - "options":{ - "pageState" : "%s" - } - } - } - """ - .formatted(nextPageState); - nextPageState = - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json1) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOkNoErrors( + """ + { + "find": { + "options":{ + "pageState" : "%s" + } + } + } + """ + .formatted(nextPageState)) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(defaultPageSize)) .extract() .path("data.nextPageState"); // 
should be fine with the empty sort clause - String json2 = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "sort": {}, @@ -117,15 +89,7 @@ public void threePagesCheck() { } } """ - .formatted(nextPageState); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json2) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + .formatted(nextPageState)) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(documentAmount - 2 * defaultPageSize)) .body("data.nextPageState", nullValue()); @@ -134,8 +98,8 @@ public void threePagesCheck() { @Test @Order(3) public void pageLimitCheck() { - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "options": { @@ -144,16 +108,7 @@ public void pageLimitCheck() { } } """ - .formatted(documentLimit); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + .formatted(documentLimit)) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(documentLimit)) .body("data.nextPageState", nullValue()); diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/RangeReadIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/RangeReadIntegrationTest.java index d2688d8606..6c5c2fc219 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/RangeReadIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/RangeReadIntegrationTest.java @@ -1,12 +1,8 @@ package io.stargate.sgv2.jsonapi.api.v1; -import static io.restassured.RestAssured.given; -import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsFindSuccess; -import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.responseIsStatusOnly; +import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.*; import static net.javacrumbs.jsonunit.JsonMatchers.jsonEquals; 
-import static org.hamcrest.Matchers.hasSize; -import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.*; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; @@ -14,17 +10,15 @@ import com.fasterxml.jackson.databind.node.JsonNodeFactory; import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; -import io.restassured.http.ContentType; import io.stargate.sgv2.jsonapi.testresource.DseTestResource; import java.util.ArrayList; import java.util.List; -import org.junit.jupiter.api.ClassOrderer; -import org.junit.jupiter.api.MethodOrderer; -import org.junit.jupiter.api.Nested; -import org.junit.jupiter.api.Order; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.TestClassOrder; -import org.junit.jupiter.api.TestMethodOrder; +import java.util.Map; +import java.util.stream.Stream; +import org.junit.jupiter.api.*; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; @QuarkusIntegrationTest @WithTestResource(value = DseTestResource.class, restrictToAnnotatedClass = false) @@ -49,27 +43,20 @@ public void setUp() { public void gt() throws Exception { int[] ids = {24, 25}; List testDatas = getDocuments(ids); - String json = - """ + JsonNodeFactory nodefactory = objectMapper.getNodeFactory(); + final ArrayNode arrayNode = nodefactory.arrayNode(testDatas.size()); + for (int i = 0; i < testDatas.size(); i++) { + arrayNode.add(objectMapper.valueToTree(testDatas.get(i))); + } + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"userId" : {"$gt" : 23}}, "sort" : {"userId" : 1} } } - """; - JsonNodeFactory nodefactory = objectMapper.getNodeFactory(); - final ArrayNode arrayNode = nodefactory.arrayNode(testDatas.size()); - for (int i = 0; i < testDatas.size(); i++) - 
arrayNode.add(objectMapper.valueToTree(testDatas.get(i))); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(2)) .body("data.documents", jsonEquals(arrayNode.toString())); @@ -80,27 +67,20 @@ public void gt() throws Exception { public void gte() throws Exception { int[] ids = {23, 24, 25}; List testDatas = getDocuments(ids); - String json = - """ + JsonNodeFactory nodefactory = objectMapper.getNodeFactory(); + final ArrayNode arrayNode = nodefactory.arrayNode(testDatas.size()); + for (int i = 0; i < testDatas.size(); i++) { + arrayNode.add(objectMapper.valueToTree(testDatas.get(i))); + } + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"userId" : {"$gte" : 23}}, "sort" : {"userId" : 1} } } - """; - JsonNodeFactory nodefactory = objectMapper.getNodeFactory(); - final ArrayNode arrayNode = nodefactory.arrayNode(testDatas.size()); - for (int i = 0; i < testDatas.size(); i++) - arrayNode.add(objectMapper.valueToTree(testDatas.get(i))); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(3)) .body("data.documents", jsonEquals(arrayNode.toString())); @@ -111,27 +91,20 @@ public void gte() throws Exception { public void lt() throws Exception { int[] ids = {1, 2}; List testDatas = getDocuments(ids); - String json = - """ + JsonNodeFactory nodefactory = objectMapper.getNodeFactory(); + final ArrayNode arrayNode = nodefactory.arrayNode(testDatas.size()); + for (int i = 0; i < testDatas.size(); i++) { + arrayNode.add(objectMapper.valueToTree(testDatas.get(i))); + } + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : 
{"userId" : {"$lt" : 3}}, "sort" : {"userId" : 1} } } - """; - JsonNodeFactory nodefactory = objectMapper.getNodeFactory(); - final ArrayNode arrayNode = nodefactory.arrayNode(testDatas.size()); - for (int i = 0; i < testDatas.size(); i++) - arrayNode.add(objectMapper.valueToTree(testDatas.get(i))); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(2)) .body("data.documents", jsonEquals(arrayNode.toString())); @@ -142,27 +115,20 @@ public void lt() throws Exception { public void lte() throws Exception { int[] ids = {1, 2, 3}; List testDatas = getDocuments(ids); - String json = - """ + JsonNodeFactory nodefactory = objectMapper.getNodeFactory(); + final ArrayNode arrayNode = nodefactory.arrayNode(testDatas.size()); + for (int i = 0; i < testDatas.size(); i++) { + arrayNode.add(objectMapper.valueToTree(testDatas.get(i))); + } + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"userId" : {"$lte" : 3}}, "sort" : {"userId" : 1} } } - """; - JsonNodeFactory nodefactory = objectMapper.getNodeFactory(); - final ArrayNode arrayNode = nodefactory.arrayNode(testDatas.size()); - for (int i = 0; i < testDatas.size(); i++) - arrayNode.add(objectMapper.valueToTree(testDatas.get(i))); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(3)) .body("data.documents", jsonEquals(arrayNode.toString())); @@ -173,27 +139,20 @@ public void lte() throws Exception { public void rangeWithDate() throws Exception { int[] ids = {24, 25}; List testDatas = getDocuments(ids); - String json = - """ + JsonNodeFactory nodefactory = objectMapper.getNodeFactory(); + 
final ArrayNode arrayNode = nodefactory.arrayNode(testDatas.size()); + for (int i = 0; i < testDatas.size(); i++) { + arrayNode.add(objectMapper.valueToTree(testDatas.get(i))); + } + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"dateValue" : {"$gt" : {"$date" : 1672531223000}}}, "sort" : {"userId" : 1} } } - """; - JsonNodeFactory nodefactory = objectMapper.getNodeFactory(); - final ArrayNode arrayNode = nodefactory.arrayNode(testDatas.size()); - for (int i = 0; i < testDatas.size(); i++) - arrayNode.add(objectMapper.valueToTree(testDatas.get(i))); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(2)) .body("data.documents", jsonEquals(arrayNode.toString())); @@ -202,51 +161,36 @@ public void rangeWithDate() throws Exception { @Test @Order(8) public void rangeWithText() throws Exception { - String json = - """ + JsonNodeFactory nodefactory = objectMapper.getNodeFactory(); + final ArrayNode arrayNode = nodefactory.arrayNode(testDatas.size()); + for (int i = 0; i < testDatas.size(); i++) { + arrayNode.add(objectMapper.valueToTree(testDatas.get(i))); + } + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"username" : {"$gt" : "user23"}}, "sort" : {"userId" : 1} } } - """; - JsonNodeFactory nodefactory = objectMapper.getNodeFactory(); - final ArrayNode arrayNode = nodefactory.arrayNode(testDatas.size()); - for (int i = 0; i < testDatas.size(); i++) - arrayNode.add(objectMapper.valueToTree(testDatas.get(i))); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", notNullValue()); } @Test @Order(8) - public void 
rangeWithBoolean() throws Exception { - String json = - """ + public void rangeWithBoolean() { + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"activeUser" : {"$gt" : false}}, "sort" : {"userId" : 1} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents", notNullValue()); } @@ -256,24 +200,16 @@ public void rangeWithBoolean() throws Exception { public void gtWithFindOne() throws Exception { int[] ids = {24}; List testDatas = getDocuments(ids); - String json = - """ + final String expected = objectMapper.writeValueAsString(testDatas.get(0)); + givenHeadersPostJsonThenOkNoErrors( + """ { "findOne": { "filter" : {"userId" : {"$gt" : 23}}, "sort" : {"userId" : 1} } } - """; - final String expected = objectMapper.writeValueAsString(testDatas.get(0)); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.document", jsonEquals(expected)); } @@ -283,24 +219,16 @@ public void gtWithFindOne() throws Exception { public void gtWithIDRange() throws Exception { int[] ids = {24, 25}; List testDatas = getDocuments(ids); - String json = - """ + final String expected = objectMapper.writeValueAsString(testDatas.get(0)); + givenHeadersPostJsonThenOkNoErrors( + """ { "findOne": { "filter" : {"_id" : {"$gt" : 23}}, "sort" : {"userId" : 1} } } - """; - final String expected = objectMapper.writeValueAsString(testDatas.get(0)); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) 
.body("data.document", is(notNullValue())) .body("data.document", jsonEquals(expected)); @@ -308,46 +236,30 @@ public void gtWithIDRange() throws Exception { @Test @Order(11) - public void gtWithDeleteOne() throws Exception { - String json = - """ + public void gtWithDeleteOne() { + givenHeadersPostJsonThenOkNoErrors( + """ { "deleteOne": { "filter" : {"userId" : {"$gt" : 23}}, "sort" : {"userId" : 1} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsStatusOnly()) .body("status.deletedCount", is(1)); } @Test @Order(12) - public void gtWithDeleteMany() throws Exception { - String json = - """ + public void gtWithDeleteMany() { + givenHeadersPostJsonThenOkNoErrors( + """ { "deleteMany": { - "filter" : {"userId" : {"$gte" : 23}} } + "filter" : {"userId" : {"$gte" : 23}} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsStatusOnly()) .body("status.deletedCount", is(2)); } @@ -384,7 +296,7 @@ private List getDocuments(int[] docIds) { private void insert(List testDatas) { testDatas.forEach( testData -> { - String json = null; + String json; try { json = objectMapper.writeValueAsString(testData); } catch (JsonProcessingException e) { @@ -397,4 +309,90 @@ private void insert(List testDatas) { record TestData(int _id, String username, int userId, boolean activeUser, DateValue dateValue) {} record DateValue(long $date) {} + + @Nested + @TestMethodOrder(MethodOrderer.OrderAnnotation.class) + @Order(2) + class DocumentIdRange { + + // DocumentId with different types, Date/String/Boolean/BigDecimal + static Stream documentIds() { + return Stream.of( + Arguments.of(Map.of("$date", 1672531200000L)), + Arguments.of("doc1"), + 
Arguments.of(true), + Arguments.of(123)); + } + + @Test + @Order(0) + public void cleanUpCollection() { + givenHeadersPostJsonThenOkNoErrors( + """ + { + "deleteMany": { + "filter": {} + } + } + """) + .body("$", responseIsStatusOnly()) + .body("status.deletedCount", is(-1)) + .body("status.moreData", is(nullValue())); + } + + @ParameterizedTest() + @MethodSource("documentIds") + @Order(1) + public void inserts(Object id) throws Exception { + givenHeadersPostJsonThenOkNoErrors( + """ + { + "insertOne": { + "document": { + "_id": %s + } + } + } + """ + .formatted(objectMapper.writeValueAsString(id))) + .body("$", responseIsWriteSuccess()) + .body("status.insertedIds[0]", is(id)); + } + + @ParameterizedTest() + @MethodSource("documentIds") + @Order(2) + // Take $lte as example, we can use equal to test the filter value against inserted value. + public void rangeTest(Object id) throws Exception { + givenHeadersPostJsonThenOkNoErrors( + """ + { + "find": { + "filter" : {"_id" : {"$lte" : %s}} + } + } + """ + .formatted(objectMapper.writeValueAsString(id))) + .body("$", responseIsFindSuccess()) + .body("data.documents", hasSize(1)) + .body("data.documents[0]._id", is(id)); + } + + @Order(3) + @Test + public void InvalidRangeFilter() { + String filter = + """ + {"_id" : {"$lte" : null}} + """; + givenHeadersPostJsonThenOk("{ \"findOne\": { \"filter\" : %s}}".formatted(filter)) + .body("$", responseIsError()) + .body("errors", hasSize(1)) + .body("errors[0].errorCode", is("INVALID_FILTER_EXPRESSION")) + .body( + "errors[0].message", + containsString( + "$lte operator must have `DATE` or `NUMBER` or `TEXT` or `BOOLEAN` value")); + } + } } diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/UpdateManyIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/UpdateManyIntegrationTest.java index 8dc6e4bfd6..5133c046e9 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/UpdateManyIntegrationTest.java +++ 
b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/UpdateManyIntegrationTest.java @@ -1,6 +1,5 @@ package io.stargate.sgv2.jsonapi.api.v1; -import static io.restassured.RestAssured.given; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.*; import static net.javacrumbs.jsonunit.JsonMatchers.jsonEquals; import static org.hamcrest.Matchers.everyItem; @@ -12,7 +11,6 @@ import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; -import io.restassured.http.ContentType; import io.stargate.sgv2.jsonapi.testresource.DseTestResource; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicReferenceArray; @@ -34,38 +32,30 @@ class UpdateMany { private void insert(int countOfDocument) { for (int i = 1; i <= countOfDocument; i++) { - String json = - """ + insertDoc( + """ { "_id": "doc%s", "username": "user%s", "active_user" : true } - """; - insertDoc(json.formatted(i, i)); + """ + .formatted(i, i)); } } @Test public void byId() { insert(2); - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "updateMany": { "filter" : {"_id" : "doc1"}, "update" : {"$set" : {"active_user": false}} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsStatusOnly()) .body("status.matchedCount", is(1)) .body("status.modifiedCount", is(1)) @@ -73,66 +63,52 @@ public void byId() { .body("status.nextPageState", nullValue()); // assert state after update, first changed document - String expected = - """ + givenHeadersPostJsonThenOkNoErrors( + """ + { + "find": { + "filter" : {"_id" : "doc1"} + } + } + """) + .body("$", responseIsFindSuccess()) + .body( + "data.documents[0]", + jsonEquals( + """ { "_id":"doc1", "username":"user1", "active_user":false } - """; - json = - """ + """)); + + // then not changed document + 
givenHeadersPostJsonThenOkNoErrors( + """ { "find": { - "filter" : {"_id" : "doc1"} + "filter" : {"_id" : "doc2"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) - .body("data.documents[0]", jsonEquals(expected)); - - // then not changed document - expected = - """ + .body( + "data.documents[0]", + jsonEquals( + """ { "_id":"doc2", "username":"user2", "active_user":true } - """; - json = - """ - { - "find": { - "filter" : {"_id" : "doc2"} - } - } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) - .body("$", responseIsFindSuccess()) - .body("data.documents[0]", jsonEquals(expected)); + """)); } @Test public void emptyOptionsAllowed() { - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "updateMany": { "filter" : {"_id" : "doc1"}, @@ -140,16 +116,7 @@ public void emptyOptionsAllowed() { "options": {} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsStatusOnly()) .body("status.matchedCount", is(0)) .body("status.modifiedCount", is(0)) @@ -160,23 +127,15 @@ public void emptyOptionsAllowed() { @Test public void byColumn() { insert(5); - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "updateMany": { "filter" : {"active_user": true}, "update" : {"$set" : {"active_user": false}} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsStatusOnly()) 
.body("status.matchedCount", is(5)) .body("status.modifiedCount", is(5)) @@ -184,21 +143,13 @@ public void byColumn() { .body("status.nextPageState", nullValue()); // assert all updated - json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents.active_user", everyItem(is(false))); } @@ -206,45 +157,29 @@ public void byColumn() { @Test public void limit() { insert(20); - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "updateMany": { "filter" : {"active_user": true}, "update" : {"$set" : {"active_user": false}} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsStatusOnly()) .body("status.matchedCount", is(20)) .body("status.modifiedCount", is(20)) .body("status.moreData", is(true)) .body("status.nextPageState", not(isEmptyOrNullString())); - json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"active_user": false} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents.active_user", everyItem(is(false))); } @@ -252,45 +187,29 @@ public void limit() { @Test public void limitMoreDataFlag() { insert(25); - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "updateMany": { "filter" : {"active_user" : true}, "update" : {"$set" : {"active_user": false}} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - 
.post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsStatusOnly()) .body("status.matchedCount", is(20)) .body("status.modifiedCount", is(20)) .body("status.moreData", is(true)) .body("status.nextPageState", not(isEmptyOrNullString())); - json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"active_user": true} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents.active_user", everyItem(is(true))); } @@ -298,24 +217,16 @@ public void limitMoreDataFlag() { @Test public void updatePagination() { insert(25); - String json = - """ + String nextPageState = + givenHeadersPostJsonThenOkNoErrors( + """ { "updateMany": { "filter" : {"active_user" : true}, "update" : {"$set" : {"new_data": "new_data_value"}} } } - """; - String nextPageState = - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsStatusOnly()) .body("status.matchedCount", is(20)) .body("status.modifiedCount", is(20)) @@ -325,8 +236,8 @@ public void updatePagination() { .body() .path("status.nextPageState"); - json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "updateMany": { "filter" : {"active_user" : true}, @@ -335,36 +246,20 @@ public void updatePagination() { } } """ - .formatted(nextPageState); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + .formatted(nextPageState)) .body("$", responseIsStatusOnly()) .body("status.matchedCount", is(5)) .body("status.modifiedCount", is(5)) 
.body("status.moreData", nullValue()) .body("status.nextPageState", nullValue()); - json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"active_user": true} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents.new_data", everyItem(is("new_data_value"))); } @@ -372,8 +267,8 @@ public void updatePagination() { @Test public void upsert() { insert(5); - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "updateMany": { "filter" : {"_id": "doc6"}, @@ -381,15 +276,7 @@ public void upsert() { "options" : {"upsert" : true} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsStatusOnly()) .body("status.upsertedId", is("doc6")) .body("status.matchedCount", is(0)) @@ -398,38 +285,31 @@ public void upsert() { .body("status.nextPageState", nullValue()); // assert upsert - String expected = - """ - { - "_id":"doc6", - "active_user":false - } - """; - json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"_id" : "doc6"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) - .body("data.documents[0]", jsonEquals(expected)); + .body( + "data.documents[0]", + jsonEquals( + """ + { + "_id":"doc6", + "active_user":false + } + """)); } @Test public void upsertWithSetOnInsert() { insert(2); - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "updateMany": { "filter" : {"_id": "no-such-doc"}, @@ -440,15 +320,7 @@ public void 
upsertWithSetOnInsert() { "options" : {"upsert" : true} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsStatusOnly()) .body("status.upsertedId", is("docX")) .body("status.matchedCount", is(0)) @@ -457,92 +329,69 @@ public void upsertWithSetOnInsert() { .body("status.nextPageState", nullValue()); // assert upsert - String expected = - """ - { - "_id": "docX", - "active_user": true - } - """; - json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"_id" : "docX"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) - .body("data.documents[0]", jsonEquals(expected)); + .body( + "data.documents[0]", + jsonEquals( + """ + { + "_id": "docX", + "active_user": true + } + """)); } @Test public void byIdNoChange() { insert(2); - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "updateMany": { "filter" : {"_id" : "doc1"}, "update" : {"$set" : {"active_user": true}} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsStatusOnly()) .body("status.matchedCount", is(1)) .body("status.modifiedCount", is(0)) .body("status.moreData", nullValue()) .body("status.nextPageState", nullValue()); - String expected = - """ - { - "_id":"doc1", - "username":"user1", - "active_user":true - } - """; - json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"_id" : "doc1"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - 
.post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) - .body("data.documents[0]", jsonEquals(expected)); + .body( + "data.documents[0]", + jsonEquals( + """ + { + "_id":"doc1", + "username":"user1", + "active_user":true + } + """)); } @Test public void upsertManyByColumnUpsert() { - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "updateMany": { "filter" : {"location" : "my_city"}, @@ -550,39 +399,21 @@ public void upsertManyByColumnUpsert() { "options" : {"upsert" : true} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsStatusOnly()) .body("status.upsertedId", is(notNullValue())) .body("status.matchedCount", is(0)) .body("status.modifiedCount", is(0)); // assert state after update - json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"location" : "my_city"} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents[0]", is(notNullValue())); } @@ -590,8 +421,8 @@ public void upsertManyByColumnUpsert() { @Test public void upsertAddFilterColumn() { insert(5); - String json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "updateMany": { "filter" : {"_id": "doc6", "answer" : 42}, @@ -599,15 +430,7 @@ public void upsertAddFilterColumn() { "options" : {"upsert" : true} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsStatusOnly()) .body("status.upsertedId", is("doc6")) 
.body("status.matchedCount", is(0)) @@ -616,32 +439,25 @@ public void upsertAddFilterColumn() { .body("status.nextPageState", nullValue()); // assert state after update - String expected = - """ - { - "_id":"doc6", - "answer": 42, - "active_user": false - } - """; - json = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { "filter" : {"_id" : "doc6"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) - .body("data.documents[0]", jsonEquals(expected)); + .body( + "data.documents[0]", + jsonEquals( + """ + { + "_id":"doc6", + "answer": 42, + "active_user": false + } + """)); } } @@ -667,17 +483,6 @@ public void concurrentUpdates() throws Exception { int threads = 3; CountDownLatch latch = new CountDownLatch(threads); - // find all documents - String updateJson = - """ - { - "updateMany": { - "update" : { - "$inc" : {"count": 1} - } - } - } - """; // start all threads AtomicReferenceArray exceptions = new AtomicReferenceArray<>(threads); for (int i = 0; i < threads; i++) { @@ -685,14 +490,16 @@ public void concurrentUpdates() throws Exception { new Thread( () -> { try { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(updateJson) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + givenHeadersPostJsonThenOkNoErrors( + """ + { + "updateMany": { + "update" : { + "$inc" : {"count": 1} + } + } + } + """) .body("$", responseIsStatusOnly()) .body("status.matchedCount", is(5)) .body("status.modifiedCount", is(5)); @@ -721,21 +528,13 @@ public void concurrentUpdates() throws Exception { } // assert state after all updates - String findJson = - """ + givenHeadersPostJsonThenOkNoErrors( + """ { "find": { } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - 
.body(findJson) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsFindSuccess()) .body("data.documents.count", everyItem(is(3))); } @@ -747,22 +546,14 @@ class ClientErrors { @Test public void invalidCommand() { - String updateJson = - """ + givenHeadersPostJsonThenOk( + """ { "updateMany": { "filter" : {"something" : "matching"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(updateJson) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) + """) .body("$", responseIsError()) .body("errors[0].errorCode", is("COMMAND_FIELD_INVALID")) .body("errors[0].exceptionClass", is("JsonApiException")) diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/VectorSearchIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/VectorSearchIntegrationTest.java index 72685a34de..a725eb8fa6 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/VectorSearchIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/VectorSearchIntegrationTest.java @@ -1,13 +1,11 @@ package io.stargate.sgv2.jsonapi.api.v1; -import static io.restassured.RestAssured.given; import static io.stargate.sgv2.jsonapi.api.v1.ResponseAssertions.*; import static net.javacrumbs.jsonunit.JsonMatchers.jsonEquals; import static org.hamcrest.Matchers.*; import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; -import io.restassured.http.ContentType; import io.restassured.response.Response; import io.stargate.sgv2.jsonapi.api.v1.metrics.JsonApiMetricsConfig; import io.stargate.sgv2.jsonapi.config.DocumentLimitsConfig; @@ -32,8 +30,8 @@ public class VectorSearchIntegrationTest extends AbstractKeyspaceIntegrationTest class CreateCollection { @Test public void happyPathVectorSearch() { - String json = - """ + givenHeadersPostJsonThenOk( + """ { "createCollection": { 
"name" : "my_collection", @@ -45,23 +43,15 @@ public void happyPathVectorSearch() { } } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + """) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); } @Test public void happyPathVectorSearchDefaultFunction() { - String json = - """ + givenHeadersPostJsonThenOk( + """ { "createCollection": { "name" : "my_collection_default_function", @@ -72,15 +62,7 @@ public void happyPathVectorSearchDefaultFunction() { } } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + """) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); } @@ -94,10 +76,7 @@ public void happyPathBigVectorCollection() { public void failForTooBigVector() { final int maxDimension = DocumentLimitsConfig.DEFAULT_MAX_VECTOR_EMBEDDING_LENGTH; final int tooHighDimension = maxDimension + 10; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersPostJsonThenOk( """ { "createCollection": { @@ -112,13 +91,10 @@ public void failForTooBigVector() { } """ .formatted(tooHighDimension)) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) .body("$", responseIsError()) - .body("errors[0].exceptionClass", is("JsonApiException")) + .body("errors", hasSize(1)) .body("errors[0].errorCode", is("VECTOR_SEARCH_TOO_BIG_VALUE")) + .body("errors[0].exceptionClass", is("JsonApiException")) .body( "errors[0].message", startsWith( @@ -131,8 +107,8 @@ public void failForTooBigVector() { @Test public void failForInvalidVectorMetric() { - String json = - """ + givenHeadersPostJsonThenOk( + """ { "createCollection": { "name" : "invalidVectorMetric", @@ -144,16 +120,9 @@ public void failForInvalidVectorMetric() { } } } - """; - given() - 
.headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + """) .body("$", responseIsError()) + .body("errors", hasSize(1)) .body("errors[0].exceptionClass", is("JsonApiException")) .body("errors[0].errorCode", is("COMMAND_FIELD_INVALID")) .body( @@ -168,24 +137,19 @@ public void failForInvalidVectorMetric() { class InsertOneCollection { @Test public void insertVectorSearch() { - String json = - """ - { - "insertOne": { - "document": { - "_id": "1", - "name": "Coded Cleats", - "description": "ChatGPT integrated sneakers that talk to you", - "$vector": [0.25, 0.25, 0.25, 0.25, 0.25] - } - } - } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + givenHeadersAndJson( + """ + { + "insertOne": { + "document": { + "_id": "1", + "name": "Coded Cleats", + "description": "ChatGPT integrated sneakers that talk to you", + "$vector": [0.25, 0.25, 0.25, 0.25, 0.25] + } + } + } + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -193,35 +157,31 @@ public void insertVectorSearch() { .body("$", responseIsWriteSuccess()) .body("status.insertedIds[0]", is("1")); - json = - """ - { - "find": { - "filter" : {"_id" : "1"}, - "projection": { "*": 1 } + givenHeadersAndJson( + """ + { + "find": { + "filter" : {"_id" : "1"}, + "projection": { "*": 1 } + } } - } - """; - String expected = - """ - { - "_id": "1", - "name": "Coded Cleats", - "description": "ChatGPT integrated sneakers that talk to you", - "$vector": [0.25, 0.25, 0.25, 0.25, 0.25] - } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() .statusCode(200) .body("$", responseIsFindSuccess()) - .body("data.documents[0]", jsonEquals(expected)); + .body( + "data.documents[0]", + jsonEquals( + """ + { + "_id": "1", + "name": 
"Coded Cleats", + "description": "ChatGPT integrated sneakers that talk to you", + "$vector": [0.25, 0.25, 0.25, 0.25, 0.25] + } + """)); } // Test to verify vector embedding size can exceed general Array length limit @@ -232,10 +192,7 @@ public void insertBigVectorThenSearch() { insertBigVectorDoc("bigVector1", "Victor", "Big Vectors Rule ok?", vectorStr); // Then verify it was inserted correctly - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersAndJson( """ { "find": { @@ -267,10 +224,7 @@ public void insertBigVectorThenSearch() { """ .formatted(vectorSearchStr); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(findRequest) + givenHeadersAndJson(findRequest) .when() .post(CollectionResource.BASE_PATH, keyspaceName, bigVectorCollectionName) .then() @@ -284,23 +238,18 @@ public void insertBigVectorThenSearch() { @Test public void insertVectorCollectionWithoutVectorData() { - String json = - """ - { - "insertOne": { - "document": { - "_id": "10", - "name": "Coded Cleats", - "description": "ChatGPT integrated sneakers that talk to you" - } - } - } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + givenHeadersAndJson( + """ + { + "insertOne": { + "document": { + "_id": "10", + "name": "Coded Cleats", + "description": "ChatGPT integrated sneakers that talk to you" + } + } + } + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -308,55 +257,46 @@ public void insertVectorCollectionWithoutVectorData() { .body("$", responseIsWriteSuccess()) .body("status.insertedIds[0]", is("10")); - json = - """ - { - "find": { - "filter" : {"_id" : "10"} - } - } - """; - String expected = - """ - { - "_id": "10", - "name": "Coded Cleats", - "description": "ChatGPT integrated sneakers that talk to you" - } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + givenHeadersAndJson( + """ + { + "find": { 
+ "filter" : {"_id" : "10"} + } + } + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() .statusCode(200) .body("$", responseIsFindSuccess()) - .body("data.documents[0]", jsonEquals(expected)); + .body( + "data.documents[0]", + jsonEquals( + """ + { + "_id": "10", + "name": "Coded Cleats", + "description": "ChatGPT integrated sneakers that talk to you" + } + """)); } @Test public void insertEmptyVectorData() { - String json = - """ - { - "insertOne": { - "document": { - "_id": "Invalid", - "name": "Coded Cleats", - "description": "ChatGPT integrated sneakers that talk to you", - "$vector": [] - } - } - } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + givenHeadersAndJson( + """ + { + "insertOne": { + "document": { + "_id": "Invalid", + "name": "Coded Cleats", + "description": "ChatGPT integrated sneakers that talk to you", + "$vector": [] + } + } + } + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -371,24 +311,19 @@ public void insertEmptyVectorData() { @Test public void insertInvalidVectorData() { - String json = - """ - { - "insertOne": { - "document": { - "_id": "Invalid", - "name": "Coded Cleats", - "description": "ChatGPT integrated sneakers that talk to you", - "$vector": [0.11, "abc", true, null] - } - } + givenHeadersAndJson( + """ + { + "insertOne": { + "document": { + "_id": "Invalid", + "name": "Coded Cleats", + "description": "ChatGPT integrated sneakers that talk to you", + "$vector": [0.11, "abc", true, null] } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + } + } + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -417,10 +352,7 @@ public void insertSimpleBinaryVector() { .formatted(id, base64Vector); // insert the document - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body("{ \"insertOne\": { \"document\": %s 
}}".formatted(doc)) + givenHeadersAndJson("{ \"insertOne\": { \"document\": %s }}".formatted(doc)) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -430,10 +362,7 @@ public void insertSimpleBinaryVector() { // get the document and verify the vector value Response response = - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersAndJson( "{\"find\": { \"filter\" : {\"_id\" : \"%s\"}, \"projection\" : {\"$vector\" : 1}}}" .formatted(id)) .when() @@ -482,10 +411,7 @@ public void insertLargeDimensionBinaryVector() { .formatted(id, base64Vector); // insert the document - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body("{ \"insertOne\": { \"document\": %s }}".formatted(doc)) + givenHeadersAndJson("{ \"insertOne\": { \"document\": %s }}".formatted(doc)) .when() .post(CollectionResource.BASE_PATH, keyspaceName, "large_binary_vector_collection") .then() @@ -495,10 +421,7 @@ public void insertLargeDimensionBinaryVector() { // get the document and verify the vector value Response response = - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersAndJson( "{\"find\": { \"filter\" : {\"_id\" : \"%s\"}, \"projection\" : {\"$vector\" : 1}}}" .formatted(id)) .when() @@ -536,10 +459,7 @@ public void failToInsertBinaryVectorWithInvalidBinaryString() { """ .formatted(invalidBinaryString); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body("{ \"insertOne\": { \"document\": %s }}".formatted(doc)) + givenHeadersAndJson("{ \"insertOne\": { \"document\": %s }}".formatted(doc)) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -562,10 +482,7 @@ public void failToInsertBinaryVectorWithInvalidBinaryValue() { "$vector": {"$binary": 1234} } """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body("{ \"insertOne\": { \"document\": %s }}".formatted(doc)) + givenHeadersAndJson("{ 
\"insertOne\": { \"document\": %s }}".formatted(doc)) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -588,10 +505,7 @@ public void failToInsertBinaryVectorWithInvalidVectorObject() { "$vector": {"binary": "PoAAAD6AAAA+gAAAPoAAAD6AAAA="} } """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body("{ \"insertOne\": { \"document\": %s }}".formatted(doc)) + givenHeadersAndJson("{ \"insertOne\": { \"document\": %s }}".formatted(doc)) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -614,10 +528,7 @@ public void failToInsertBinaryVectorWithInvalidDecodedValue() { "$vector": {"$binary": "1234"} } """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body("{ \"insertOne\": { \"document\": %s }}".formatted(doc)) + givenHeadersAndJson("{ \"insertOne\": { \"document\": %s }}".formatted(doc)) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -644,10 +555,7 @@ public void failToInsertBinaryVectorWithUnmatchedVectorDimension() { """ .formatted(base64Vector); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body("{ \"insertOne\": { \"document\": %s }}".formatted(doc)) + givenHeadersAndJson("{ \"insertOne\": { \"document\": %s }}".formatted(doc)) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -667,8 +575,8 @@ public void failToInsertBinaryVectorWithUnmatchedVectorDimension() { class InsertManyCollection { @Test public void insertVectorSearch() { - String json = - """ + givenHeadersAndJson( + """ { "insertMany": { "documents": [ @@ -690,12 +598,7 @@ public void insertVectorSearch() { } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -704,51 +607,42 @@ public void insertVectorSearch() { .body("status.insertedIds[0]", is("2")) 
.body("status.insertedIds[1]", is("3")); - json = - """ + givenHeadersAndJson( + """ { "find": { "filter" : {"_id" : "2"}, "projection": { "*": 1 } } } - """; - String expected = - """ + """) + .when() + .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) + .then() + .statusCode(200) + .body("$", responseIsFindSuccess()) + .body( + "data.documents[0]", + jsonEquals( + """ { "_id": "2", "name": "Logic Layers", "description": "An AI quilt to help you sleep forever", "$vector": [0.25, 0.25, 0.25, 0.25, 0.25] } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) - .then() - .statusCode(200) - .body("$", responseIsFindSuccess()) - .body("data.documents[0]", jsonEquals(expected)); + """)); } } public void insertVectorDocuments() { - String json = - """ + givenHeadersAndJson( + """ { "deleteMany": { } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -757,8 +651,8 @@ public void insertVectorDocuments() { .extract() .path("status.moreData"); - json = - """ + givenHeadersAndJson( + """ { "insertMany": { "documents": [ @@ -783,11 +677,7 @@ public void insertVectorDocuments() { ] } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -811,23 +701,18 @@ public void setUp() { @Test @Order(2) public void happyPath() { - String json = - """ - { - "find": { - "sort" : {"$vector" : [0.15, 0.1, 0.1, 0.35, 0.55]}, - "projection" : {"_id" : 1, "$vector" : 1}, - "options" : { - "limit" : 5 + givenHeadersAndJson( + """ + { + "find": { + "sort" : {"$vector" : [0.15, 0.1, 0.1, 0.35, 0.55]}, + "projection" : {"_id" : 1, "$vector" : 1}, + "options" : { + "limit" : 5 + } } } - } - """; - - given() - 
.headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -846,8 +731,8 @@ public void happyPath() { public void happyPathBinaryVector() { String vectorString = generateBase64EncodedBinaryVector(new float[] {0.15f, 0.1f, 0.1f, 0.35f, 0.55f}); - String json = - """ + givenHeadersAndJson( + """ { "find": { "sort" : {"$vector" : {"$binary" : "%s" } }, @@ -858,12 +743,7 @@ public void happyPathBinaryVector() { } } """ - .formatted(vectorString); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + .formatted(vectorString)) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -880,8 +760,8 @@ public void happyPathBinaryVector() { @Test @Order(2) public void happyPathWithIncludeSortVectorOption() { - String json = - """ + givenHeadersAndJson( + """ { "find": { "sort" : {"$vector" : [0.15, 0.1, 0.1, 0.35, 0.55]}, @@ -892,12 +772,7 @@ public void happyPathWithIncludeSortVectorOption() { } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -915,23 +790,18 @@ public void happyPathWithIncludeSortVectorOption() { @Test @Order(3) public void happyPathWithIncludeAll() { - String json = - """ - { - "find": { - "sort" : {"$vector" : [0.15, 0.1, 0.1, 0.35, 0.55]}, - "projection" : {"*" : 1}, - "options" : { - "limit" : 5 - } - } - } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + givenHeadersAndJson( + """ + { + "find": { + "sort" : {"$vector" : [0.15, 0.1, 0.1, 0.35, 0.55]}, + "projection" : {"*" : 1}, + "options" : { + "limit" : 5 + } + } + } + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -948,8 +818,8 @@ public void happyPathWithIncludeAll() { @Test @Order(4) public void 
happyPathWithExcludeAll() { - String json = - """ + givenHeadersAndJson( + """ { "find": { "sort" : {"$vector" : [0.15, 0.1, 0.1, 0.35, 0.55]}, @@ -959,12 +829,7 @@ public void happyPathWithExcludeAll() { } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -978,8 +843,8 @@ public void happyPathWithExcludeAll() { @Test @Order(5) public void happyPathWithFilter() { - String json = - """ + givenHeadersAndJson( + """ { "find": { "filter" : {"_id" : "1"}, @@ -990,12 +855,7 @@ public void happyPathWithFilter() { } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -1009,8 +869,8 @@ public void happyPathWithFilter() { @Test @Order(6) public void happyPathWithInFilter() { - String json = - """ + givenHeadersAndJson( + """ { "insertOne": { "document": { @@ -1021,20 +881,16 @@ public void happyPathWithInFilter() { } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() .statusCode(200) .body("$", responseIsWriteSuccess()) .body("status.insertedIds[0]", is("xx")); - json = - """ + + givenHeadersAndJson( + """ { "find": { "filter" : { @@ -1048,12 +904,7 @@ public void happyPathWithInFilter() { } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -1066,8 +917,8 @@ public void happyPathWithInFilter() { @Test @Order(7) public void happyPathWithEmptyVector() { - String json = - """ + givenHeadersAndJson( + """ { "find": { "filter" : {"_id" : "1"}, @@ -1077,12 +928,7 @@ public void happyPathWithEmptyVector() { } } } - """; - - given() - 
.headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -1097,8 +943,8 @@ public void happyPathWithEmptyVector() { @Test @Order(8) public void happyPathWithInvalidData() { - String json = - """ + givenHeadersAndJson( + """ { "find": { "filter" : {"_id" : "1"}, @@ -1108,12 +954,7 @@ public void happyPathWithInvalidData() { } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -1128,8 +969,8 @@ public void happyPathWithInvalidData() { @Test @Order(9) public void limitError() { - String json = - """ + givenHeadersAndJson( + """ { "find": { "sort" : {"$vector" : [0.15, 0.1, 0.1, 0.35, 0.55]}, @@ -1139,12 +980,7 @@ public void limitError() { } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -1160,8 +996,8 @@ public void limitError() { @Test @Order(10) public void skipError() { - String json = - """ + givenHeadersAndJson( + """ { "find": { "sort" : {"$vector" : [0.15, 0.1, 0.1, 0.35, 0.55]}, @@ -1171,12 +1007,7 @@ public void skipError() { } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -1202,19 +1033,14 @@ public void setUp() { @Test @Order(2) public void happyPath() { - String json = - """ + givenHeadersAndJson( + """ { "findOne": { "sort" : {"$vector" : [0.15, 0.1, 0.1, 0.35, 0.55]} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -1226,20 +1052,15 @@ public void happyPath() { @Test @Order(3) public void 
happyPathWithIdFilter() { - String json = - """ + givenHeadersAndJson( + """ { "findOne": { "filter" : {"_id" : "1"}, "sort" : {"$vector" : [0.15, 0.1, 0.1, 0.35, 0.55]} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -1251,20 +1072,15 @@ public void happyPathWithIdFilter() { @Test @Order(4) public void failWithEmptyVector() { - String json = - """ + givenHeadersAndJson( + """ { "findOne": { "filter" : {"_id" : "1"}, "sort" : {"$vector" : []} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -1279,20 +1095,15 @@ public void failWithEmptyVector() { @Test @Order(5) public void failWithZerosVector() { - String json = - """ + givenHeadersAndJson( + """ { "findOne": { "filter" : {"_id" : "1"}, "sort" : {"$vector" : [0.0,0.0,0.0,0.0,0.0]} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -1310,20 +1121,15 @@ public void failWithZerosVector() { @Test @Order(6) public void failWithInvalidVectorElements() { - String json = - """ + givenHeadersAndJson( + """ { "findOne": { "filter" : {"_id" : "1"}, "sort" : {"$vector" : [0.11, "abc", true]} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -1339,19 +1145,14 @@ public void failWithInvalidVectorElements() { @Test @Order(7) public void failWithVectorFilter() { - String json = - """ + givenHeadersAndJson( + """ { "findOne": { "filter" : {"$vector" : [ 1, 1, 1, 1, 1 ]} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() 
.post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -1363,7 +1164,7 @@ public void failWithVectorFilter() { .body( "errors[0].message", containsString( - "Cannot filter on '$vector' field using operator '$eq': only '$exists' is supported")); + "Cannot filter on '$vector' field using operator $eq: only $exists is supported")); } } @@ -1380,8 +1181,8 @@ public void setUp() { @Test @Order(2) public void setOperation() { - String json = - """ + givenHeadersAndJson( + """ { "findOneAndUpdate": { "filter" : {"_id": "2"}, @@ -1390,12 +1191,7 @@ public void setOperation() { "options" : {"returnDocument" : "after"} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -1410,8 +1206,8 @@ public void setOperation() { @Test @Order(3) public void unsetOperation() { - String json = - """ + givenHeadersAndJson( + """ { "findOneAndUpdate": { "filter" : {"name": "Coded Cleats"}, @@ -1419,12 +1215,7 @@ public void unsetOperation() { "options" : {"returnDocument" : "after"} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -1439,8 +1230,8 @@ public void unsetOperation() { @Test @Order(4) public void setOnInsertOperation() { - String json = - """ + givenHeadersAndJson( + """ { "findOneAndUpdate": { "filter" : {"_id": "11"}, @@ -1449,12 +1240,7 @@ public void setOnInsertOperation() { "options" : {"returnDocument" : "after", "upsert": true} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -1470,8 +1256,8 @@ public void setOnInsertOperation() { @Test @Order(5) public void errorOperationForVector() { - String json = - """ + givenHeadersAndJson( + """ { 
"findOneAndUpdate": { "filter" : {"_id": "3"}, @@ -1479,12 +1265,7 @@ public void errorOperationForVector() { "options" : {"returnDocument" : "after"} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -1502,10 +1283,7 @@ public void setBigVectorOperation() { insertBigVectorDoc("bigVectorForSet", "Bob", "Desc for Bob.", null); // and verify we have null for it - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersAndJson( """ { "find": { @@ -1525,8 +1303,8 @@ public void setBigVectorOperation() { // then set the vector final String vectorStr = buildVectorElements(7, BIG_VECTOR_SIZE); - String json = - """ + givenHeadersAndJson( + """ { "findOneAndUpdate": { "filter" : {"_id": "bigVectorForSet"}, @@ -1536,12 +1314,7 @@ public void setBigVectorOperation() { } } """ - .formatted(vectorStr); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + .formatted(vectorStr)) .when() .post(CollectionResource.BASE_PATH, keyspaceName, bigVectorCollectionName) .then() @@ -1554,10 +1327,7 @@ public void setBigVectorOperation() { .body("data.document.$vector", hasSize(BIG_VECTOR_SIZE)); // and verify it was set to value with expected size - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersAndJson( """ { "find": { @@ -1586,8 +1356,8 @@ class VectorSearchExtendedCommands { @Order(1) public void findOneAndUpdate() { insertVectorDocuments(); - String json = - """ + givenHeadersAndJson( + """ { "findOneAndUpdate": { "sort" : {"$vector" : [0.15, 0.1, 0.1, 0.35, 0.55]}, @@ -1595,12 +1365,7 @@ public void findOneAndUpdate() { "options" : {"returnDocument" : "after"} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ 
-1616,19 +1381,15 @@ public void findOneAndUpdate() { @Order(2) public void updateOne() { insertVectorDocuments(); - String json = - """ + givenHeadersAndJson( + """ { "updateOne": { "update" : {"$set" : {"new_col": "new_val"}}, "sort" : {"$vector" : [0.15, 0.1, 0.1, 0.35, 0.55]} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -1637,18 +1398,14 @@ public void updateOne() { .body("status.matchedCount", is(1)) .body("status.modifiedCount", is(1)) .body("status.moreData", is(nullValue())); - json = - """ + givenHeadersAndJson( + """ { "findOne": { "filter" : {"_id" : "3"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -1662,8 +1419,8 @@ public void updateOne() { @Order(3) public void findOneAndReplace() { insertVectorDocuments(); - String json = - """ + givenHeadersAndJson( + """ { "findOneAndReplace": { "sort" : {"$vector" : [0.15, 0.1, 0.1, 0.35, 0.55]}, @@ -1672,12 +1429,7 @@ public void findOneAndReplace() { "options" : {"returnDocument" : "after"} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -1694,8 +1446,8 @@ public void findOneAndReplace() { @Order(4) public void findOneAndReplaceWithoutVector() { insertVectorDocuments(); - String json = - """ + givenHeadersAndJson( + """ { "findOneAndReplace": { "sort" : {"$vector" : [0.15, 0.1, 0.1, 0.35, 0.55]}, @@ -1703,12 +1455,7 @@ public void findOneAndReplaceWithoutVector() { "options" : {"returnDocument" : "after"} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -1728,10 
+1475,7 @@ public void findOneAndReplaceWithBigVector() { insertBigVectorDoc("bigVectorForFindReplace", "Alice", "Desc for Alice.", null); // and verify we have null for it - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersAndJson( """ { "find": { @@ -1750,9 +1494,8 @@ public void findOneAndReplaceWithBigVector() { .body("data.documents[0].$vector", is(nullValue())); // then set the vector - final String vectorStr = buildVectorElements(2, BIG_VECTOR_SIZE); - String json = - """ + givenHeadersAndJson( + """ { "findOneAndReplace": { "filter" : {"_id" : "bigVectorForFindReplace"}, @@ -1762,12 +1505,7 @@ public void findOneAndReplaceWithBigVector() { } } """ - .formatted(vectorStr); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + .formatted(buildVectorElements(2, BIG_VECTOR_SIZE))) .when() .post(CollectionResource.BASE_PATH, keyspaceName, bigVectorCollectionName) .then() @@ -1780,10 +1518,7 @@ public void findOneAndReplaceWithBigVector() { .body("data.document.$vector", hasSize(BIG_VECTOR_SIZE)); // and verify it was set to value with expected size - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersAndJson( """ { "find": { @@ -1807,20 +1542,15 @@ public void findOneAndReplaceWithBigVector() { @Order(6) public void findOneAndDelete() { insertVectorDocuments(); - String json = - """ + givenHeadersAndJson( + """ { "findOneAndDelete": { "projection": { "*": 1 }, "sort" : {"$vector" : [0.15, 0.1, 0.1, 0.35, 0.55]} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -1836,20 +1566,15 @@ public void findOneAndDelete() { @Order(7) public void deleteOne() { insertVectorDocuments(); - String json = - """ + givenHeadersAndJson( + """ { "deleteOne": { "filter" : {"$vector" : {"$exists" : true}}, "sort" : {"$vector" : [0.15, 0.1, 
0.1, 0.35, 0.55]} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -1858,19 +1583,14 @@ public void deleteOne() { .body("status.deletedCount", is(1)); // ensure find does not find the document - json = - """ + givenHeadersAndJson( + """ { "findOne": { "filter" : {"_id" : "3"} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -1884,9 +1604,8 @@ public void deleteOne() { public void insertVectorWithUnmatchedSize() { createVectorCollection(keyspaceName, vectorSizeTestCollectionName, 5); // Insert data with $vector array size less than vector index defined size. - final String vectorStrCount3 = buildVectorElements(0, 3); - String jsonVectorStrCount3 = - """ + givenHeadersAndJson( + """ { "insertOne": { "document": { @@ -1898,12 +1617,7 @@ public void insertVectorWithUnmatchedSize() { } } """ - .formatted(vectorStrCount3); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(jsonVectorStrCount3) + .formatted(buildVectorElements(0, 3))) .when() .post(CollectionResource.BASE_PATH, keyspaceName, vectorSizeTestCollectionName) .then() @@ -1917,9 +1631,8 @@ public void insertVectorWithUnmatchedSize() { "Length of vector parameter different from declared '$vector' dimension: root cause =")); // Insert data with $vector array size greater than vector index defined size. 
- final String vectorStrCount7 = buildVectorElements(0, 7); - String jsonVectorStrCount7 = - """ + givenHeadersAndJson( + """ { "insertOne": { "document": { @@ -1931,12 +1644,7 @@ public void insertVectorWithUnmatchedSize() { } } """ - .formatted(vectorStrCount7); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(jsonVectorStrCount7) + .formatted(buildVectorElements(0, 7))) .when() .post(CollectionResource.BASE_PATH, keyspaceName, vectorSizeTestCollectionName) .then() @@ -1954,9 +1662,8 @@ public void insertVectorWithUnmatchedSize() { @Order(9) public void findVectorWithUnmatchedSize() { // Sort clause with $vector array size greater than vector index defined size. - final String vectorStrCount3 = buildVectorElements(0, 3); - String jsonVectorStrCount3 = - """ + givenHeadersAndJson( + """ { "find": { "sort" : {"$vector" : [ %s ]}, @@ -1965,13 +1672,8 @@ public void findVectorWithUnmatchedSize() { } } } - """ - .formatted(vectorStrCount3); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(jsonVectorStrCount3) + """ + .formatted(buildVectorElements(0, 3))) .when() .post(CollectionResource.BASE_PATH, keyspaceName, vectorSizeTestCollectionName) .then() @@ -1985,9 +1687,8 @@ public void findVectorWithUnmatchedSize() { "Length of vector parameter different from declared '$vector' dimension: root cause =")); // Insert data with $vector array size greater than vector index defined size. 
- final String vectorStrCount7 = buildVectorElements(0, 7); - String jsonVectorStrCount7 = - """ + givenHeadersAndJson( + """ { "find": { "sort" : {"$vector" : [ %s ]}, @@ -1997,12 +1698,7 @@ public void findVectorWithUnmatchedSize() { } } """ - .formatted(vectorStrCount7); - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(jsonVectorStrCount7) + .formatted(buildVectorElements(0, 7))) .when() .post(CollectionResource.BASE_PATH, keyspaceName, vectorSizeTestCollectionName) .then() @@ -2026,20 +1722,15 @@ class VectorSearchSimilarityProjection { @Order(1) public void findOneSimilarityOption() { insertVectorDocuments(); - String json = - """ + givenHeadersAndJson( + """ { "findOne": { "sort" : {"$vector" : [0.15, 0.1, 0.1, 0.35, 0.55]}, "options" : {"includeSimilarity" : true} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -2053,20 +1744,15 @@ public void findOneSimilarityOption() { @Order(2) public void findSimilarityOption() { insertVectorDocuments(); - String json = - """ + givenHeadersAndJson( + """ { "find": { "sort" : {"$vector" : [0.15, 0.1, 0.1, 0.35, 0.55]}, "options" : {"includeSimilarity" : true} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -2083,10 +1769,7 @@ public void findSimilarityOption() { } private void createVectorCollection(String namespaceName, String collectionName, int vectorSize) { - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersAndJson( """ { "createCollection": { @@ -2124,10 +1807,7 @@ private void insertBigVectorDoc(String id, String name, String description, Stri // First insert a document with a big vector - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + 
givenHeadersAndJson( """ { "insertOne": { diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/VectorizeSearchIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/VectorizeSearchIntegrationTest.java index f9df954f4b..985a70ca8b 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/VectorizeSearchIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/VectorizeSearchIntegrationTest.java @@ -8,7 +8,6 @@ import com.datastax.oss.driver.api.core.cql.SimpleStatement; import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; -import io.restassured.http.ContentType; import io.stargate.sgv2.jsonapi.testresource.DseTestResource; import java.util.Arrays; import java.util.List; @@ -34,8 +33,8 @@ public class VectorizeSearchIntegrationTest extends AbstractKeyspaceIntegrationT class CreateCollection { @Test public void happyPathVectorSearch() { - String json = - """ + givenHeadersPostJsonThenOk( + """ { "createCollection": { "name": "my_collection_vectorize", @@ -57,20 +56,12 @@ public void happyPathVectorSearch() { } } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + """) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); - json = - """ + givenHeadersPostJsonThenOk( + """ { "createCollection": { "name": "my_collection_vectorize_deny", @@ -95,15 +86,7 @@ public void happyPathVectorSearch() { } } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) - .when() - .post(KeyspaceResource.BASE_PATH, keyspaceName) - .then() - .statusCode(200) + """) .body("$", responseIsDDLSuccess()) .body("status.ok", is(1)); } @@ -143,10 +126,7 @@ public void insertVectorSearch() { initialInputByteSum = findEmbeddingSumFromMetrics(vectorizeInputBytesMetrics); } - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - 
.body(json) + givenHeadersAndJson(json) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -174,10 +154,7 @@ public void insertVectorSearch() { afterCallInputByteSum, initialInputByteSum) .isEqualTo(44L); - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + givenHeadersAndJson(json) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionNameDenyAll) .then() @@ -185,20 +162,15 @@ public void insertVectorSearch() { .body("$", responseIsWriteSuccess()) .body("status.insertedIds[0]", is("1")); - json = - """ + givenHeadersAndJson( + """ { "find": { "filter" : {"_id" : "1"}, "projection": { "$vector": 1 } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -211,8 +183,8 @@ public void insertVectorSearch() { @Test public void insertVectorArrayData() { - String json = - """ + givenHeadersAndJson( + """ { "insertOne": { "document": { @@ -223,12 +195,7 @@ public void insertVectorArrayData() { } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -245,8 +212,8 @@ public void insertVectorArrayData() { @Test public void insertInvalidVectorizeData() { - String json = - """ + givenHeadersAndJson( + """ { "insertOne": { "document": { @@ -257,12 +224,7 @@ public void insertInvalidVectorizeData() { } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -280,8 +242,8 @@ public void insertInvalidVectorizeData() { class InsertManyCollection { @Test public void insertVectorSearch() { - String json = - """ + givenHeadersAndJson( + """ { "insertMany": { "documents": [ @@ -303,12 +265,7 @@ public void 
insertVectorSearch() { } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -352,7 +309,8 @@ public void insertVectorSearch() { assertThat(vectorizeInputBytesMetrics) .satisfies( lines -> { - assertThat(lines.size()).isEqualTo(3); + // aaron: this used to check the number of lines, which is linked to the number of + // percentiles and is very fragile to include in a test lines.forEach( line -> { assertThat(line).contains("embedding_provider=\"CustomITEmbeddingProvider\""); @@ -377,20 +335,15 @@ public void insertVectorSearch() { }); }); - json = - """ + givenHeadersAndJson( + """ { "find": { "filter" : {"_id" : "2"}, "projection": { "$vector": 1 } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -402,18 +355,13 @@ public void insertVectorSearch() { } public void insertVectorDocuments() { - String json = - """ + givenHeadersAndJson( + """ { "deleteMany": { } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -422,8 +370,8 @@ public void insertVectorDocuments() { .extract() .path("status.moreData"); - json = - """ + givenHeadersAndJson( + """ { "insertMany": { "documents": [ @@ -448,11 +396,7 @@ public void insertVectorDocuments() { ] } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -476,8 +420,8 @@ public void setUp() { @Test @Order(2) public void happyPath() { - String json = - """ + givenHeadersAndJson( + """ { "find": { "sort" : {"$vectorize" : "ChatGPT integrated sneakers that talk to you"}, @@ -487,12 +431,7 @@ public void 
happyPath() { } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -515,8 +454,8 @@ public void happyPath() { @Test @Order(3) public void happyPathWithFilter() { - String json = - """ + givenHeadersAndJson( + """ { "find": { "filter" : {"_id" : "1"}, @@ -527,12 +466,7 @@ public void happyPathWithFilter() { } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -545,8 +479,8 @@ public void happyPathWithFilter() { @Test @Order(5) public void happyPathWithInvalidData() { - String json = - """ + givenHeadersAndJson( + """ { "find": { "filter" : {"_id" : "1"}, @@ -556,12 +490,7 @@ public void happyPathWithInvalidData() { } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -578,20 +507,15 @@ public void happyPathWithInvalidData() { @Test @Order(6) public void vectorizeSortDenyAll() { - String json = - """ + givenHeadersAndJson( + """ { "find": { "projection": { "$vector": 1, "$vectorize" : 1 }, "sort" : {"$vectorize" : "ChatGPT integrated sneakers that talk to you"} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionNameDenyAll) .then() @@ -618,20 +542,15 @@ public void setUp() { @Test @Order(2) public void happyPath() { - String json = - """ + givenHeadersAndJson( + """ { "findOne": { "sort" : {"$vectorize" : "ChatGPT integrated sneakers that talk to you"}, "options" : { "includeSortVector" : true } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, 
keyspaceName, collectionName) .then() @@ -644,20 +563,15 @@ public void happyPath() { @Test @Order(3) public void happyPathWithIdFilter() { - String json = - """ + givenHeadersAndJson( + """ { "findOne": { "filter" : {"_id" : "1"}, "sort" : {"$vectorize" : "ChatGPT integrated sneakers that talk to you"} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -669,20 +583,15 @@ public void happyPathWithIdFilter() { @Test @Order(4) public void failWithEmptyVector() { - String json = - """ + givenHeadersAndJson( + """ { "findOne": { "filter" : {"_id" : "1"}, "sort" : {"$vectorize" : []} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -709,8 +618,8 @@ public void setUp() { @Test @Order(2) public void setOperation() { - String json = - """ + givenHeadersAndJson( + """ { "findOneAndUpdate": { "filter" : {"_id": "2"}, @@ -719,12 +628,7 @@ public void setOperation() { "options" : {"returnDocument" : "after"} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -740,8 +644,8 @@ public void setOperation() { @Test @Order(3) public void unsetOperation() { - String json = - """ + givenHeadersAndJson( + """ { "findOneAndUpdate": { "filter" : {"name": "Coded Cleats"}, @@ -749,12 +653,7 @@ public void unsetOperation() { "options" : {"returnDocument" : "after"} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -769,8 +668,8 @@ public void unsetOperation() { @Test @Order(4) public void setOnInsertOperation() { - String json = - """ + 
givenHeadersAndJson( + """ { "findOneAndUpdate": { "filter" : {"_id": "11"}, @@ -779,12 +678,7 @@ public void setOnInsertOperation() { "options" : {"returnDocument" : "after", "upsert": true} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -806,8 +700,8 @@ class VectorSearchExtendedCommands { @Order(1) public void findOneAndUpdate_sortClause() { insertVectorDocuments(); - String json = - """ + givenHeadersAndJson( + """ { "findOneAndUpdate": { "sort" : {"$vectorize" : "A deep learning display that controls your mood"}, @@ -815,12 +709,7 @@ public void findOneAndUpdate_sortClause() { "options" : {"returnDocument" : "after"} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -836,8 +725,8 @@ public void findOneAndUpdate_sortClause() { @Order(2) public void findOneAndUpdate_updateClause() { insertVectorDocuments(); - String json = - """ + givenHeadersAndJson( + """ { "findOneAndUpdate": { "sort" : {"$vectorize" : "A deep learning display that controls your mood"}, @@ -845,12 +734,7 @@ public void findOneAndUpdate_updateClause() { "options" : {"returnDocument" : "after"} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -861,8 +745,8 @@ public void findOneAndUpdate_updateClause() { .body("status.matchedCount", is(1)) .body("status.modifiedCount", is(1)); - json = - """ + givenHeadersAndJson( + """ { "findOne": { "filter" : {"_id" : "3"}, @@ -871,11 +755,7 @@ public void findOneAndUpdate_updateClause() { } } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, 
collectionName) .then() @@ -891,19 +771,15 @@ public void findOneAndUpdate_updateClause() { @Order(3) public void updateOne_sortClause() { insertVectorDocuments(); - String json = - """ + givenHeadersAndJson( + """ { "updateOne": { "update" : {"$set" : {"new_col": "new_val"}}, "sort" : {"$vectorize" : "ChatGPT integrated sneakers that talk to you"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -912,18 +788,15 @@ public void updateOne_sortClause() { .body("status.matchedCount", is(1)) .body("status.modifiedCount", is(1)) .body("status.moreData", is(nullValue())); - json = - """ + + givenHeadersAndJson( + """ { "findOne": { "filter" : {"_id" : "1"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -937,19 +810,15 @@ public void updateOne_sortClause() { @Order(4) public void updateOne_updateClause() { insertVectorDocuments(); - String json = - """ + givenHeadersAndJson( + """ { "updateOne": { "update" : {"$set" : {"new_col": "new_val", "$vectorize":"ChatGPT upgraded"}}, "sort" : {"$vectorize" : "ChatGPT integrated sneakers that talk to you"} } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -958,8 +827,8 @@ public void updateOne_updateClause() { .body("status.matchedCount", is(1)) .body("status.modifiedCount", is(1)) .body("status.moreData", is(nullValue())); - json = - """ + givenHeadersAndJson( + """ { "findOne": { "filter" : {"_id" : "1"}, @@ -968,11 +837,7 @@ public void updateOne_updateClause() { } } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, 
collectionName) .then() @@ -988,8 +853,8 @@ public void updateOne_updateClause() { @Order(5) public void findOneAndReplace() { insertVectorDocuments(); - String json = - """ + givenHeadersAndJson( + """ { "findOneAndReplace": { "projection": { "$vector": 1 }, @@ -998,12 +863,7 @@ public void findOneAndReplace() { "options" : {"returnDocument" : "after"} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -1021,8 +881,8 @@ public void findOneAndReplace() { @Order(6) public void findOneAndReplaceWithoutVector() { insertVectorDocuments(); - String json = - """ + givenHeadersAndJson( + """ { "findOneAndReplace": { "sort" : {"$vectorize" : "ChatGPT integrated sneakers that talk to you"}, @@ -1030,12 +890,7 @@ public void findOneAndReplaceWithoutVector() { "options" : {"returnDocument" : "after"} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -1052,20 +907,15 @@ public void findOneAndReplaceWithoutVector() { @Order(7) public void findOneAndDelete() { insertVectorDocuments(); - String json = - """ + givenHeadersAndJson( + """ { "findOneAndDelete": { "sort" : {"$vectorize" : "ChatGPT integrated sneakers that talk to you"}, "projection": { "*": 1 } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -1081,20 +931,15 @@ public void findOneAndDelete() { @Order(8) public void deleteOne() { insertVectorDocuments(); - String json = - """ + givenHeadersAndJson( + """ { "deleteOne": { "filter" : {"$vector" : {"$exists" : true}}, "sort" : {"$vectorize" : "ChatGPT integrated sneakers that talk to you"} } } - """; - - given() - .headers(getHeaders()) - 
.contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -1103,19 +948,14 @@ public void deleteOne() { .body("status.deletedCount", is(1)); // ensure find does not find the document - json = - """ + givenHeadersAndJson( + """ { "findOne": { "filter" : {"_id" : "1"} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, collectionName) .then() @@ -1127,8 +967,8 @@ public void deleteOne() { @Test @Order(9) public void createDropDifferentVectorDimension() { - String json = - """ + givenHeadersAndJson( + """ { "createCollection": { "name": "cacheTestTable", @@ -1150,11 +990,7 @@ public void createDropDifferentVectorDimension() { } } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(KeyspaceResource.BASE_PATH, keyspaceName) .then() @@ -1163,8 +999,8 @@ public void createDropDifferentVectorDimension() { .body("status.ok", is(1)); // insertOne to trigger the schema cache - json = - """ + givenHeadersAndJson( + """ { "insertOne": { "document": { @@ -1175,12 +1011,7 @@ public void createDropDifferentVectorDimension() { } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, "cacheTestTable") .then() @@ -1189,10 +1020,7 @@ public void createDropDifferentVectorDimension() { .body("status.insertedIds[0]", is("1")); // DeleteCollection, should evict the corresponding schema cache - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body( + givenHeadersAndJson( """ { "deleteCollection": { @@ -1209,8 +1037,8 @@ public void createDropDifferentVectorDimension() { .body("status.ok", is(1)); // Create a new collection with same name, but dimension as 6 - json = - """ + givenHeadersAndJson( + """ { "createCollection": { 
"name": "cacheTestTable", @@ -1232,11 +1060,7 @@ public void createDropDifferentVectorDimension() { } } } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(KeyspaceResource.BASE_PATH, keyspaceName) .then() @@ -1245,8 +1069,8 @@ public void createDropDifferentVectorDimension() { .body("status.ok", is(1)); // insertOne, should use the new collectionSetting, since the outdated one has been evicted - json = - """ + givenHeadersAndJson( + """ { "insertOne": { "document": { @@ -1257,12 +1081,7 @@ public void createDropDifferentVectorDimension() { } } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, "cacheTestTable") .then() @@ -1271,20 +1090,15 @@ public void createDropDifferentVectorDimension() { .body("status.insertedIds[0]", is("1")); // find, verify the dimension - json = - """ + givenHeadersAndJson( + """ { "find": { "projection": { "$vector": 1, "$vectorize" : 1 }, "sort" : {"$vectorize" : "ChatGPT integrated sneakers that talk to you"} } } - """; - - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + """) .when() .post(CollectionResource.BASE_PATH, keyspaceName, "cacheTestTable") .then() @@ -1301,7 +1115,13 @@ public void createDropDifferentVectorDimension() { @Nested @Order(8) @TestMethodOrder(MethodOrderer.OrderAnnotation.class) - class DeprecatedModel { + class UnknownExistingModel { + + // As a best practice, when we deprecate or EOL a model + // we should mark it as such in the configuration + // instead of removing the whole entry, which is bad practice. + // Removing an entry outright should only happen during development; this test captures + // that case and confirms it at least does not return a 500. 
@Test @Order(1) public void findOneAndUpdate_sortClause() { @@ -1327,26 +1147,18 @@ public void findOneAndUpdate_sortClause() { executeCqlStatement( SimpleStatement.newInstance( tableWithBadModel.formatted(keyspaceName, collection, collection))); - String json = - """ - { "findOne": {} } - """; - given() - .headers(getHeaders()) - .contentType(ContentType.JSON) - .body(json) + givenHeadersAndJson("{ \"findOne\": {} } ") .when() .post(CollectionResource.BASE_PATH, keyspaceName, collection) .then() .statusCode(200) .body("$", responseIsError()) .body("errors", hasSize(1)) - .body("errors[0].errorCode", is("VECTORIZE_MODEL_DEPRECATED")) + .body("errors[0].errorCode", is("VECTORIZE_SERVICE_TYPE_UNAVAILABLE")) .body("errors[0].exceptionClass", is("JsonApiException")) .body( "errors[0].message", - containsString( - "Model random is deprecated, supported models for provider 'nvidia' are")); + containsString("unknown model 'random' for service provider 'nvidia'")); } } diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/AbstractTableIntegrationTestBase.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/AbstractTableIntegrationTestBase.java index 6a59cfc67d..dfe05fc1d3 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/AbstractTableIntegrationTestBase.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/AbstractTableIntegrationTestBase.java @@ -6,7 +6,7 @@ import com.fasterxml.jackson.databind.node.ObjectNode; import io.stargate.sgv2.jsonapi.api.v1.AbstractKeyspaceIntegrationTestBase; -/** Abstract class for all table int tests that needs a collection to execute tests in. */ +/** Abstract class for all table int tests that needs a table to execute tests in. 
*/ public class AbstractTableIntegrationTestBase extends AbstractKeyspaceIntegrationTestBase { private static final ObjectMapper MAPPER = new ObjectMapper(); diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/AlterTableIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/AlterTableIntegrationTest.java index 69511e2ed3..b60c0ea14d 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/AlterTableIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/AlterTableIntegrationTest.java @@ -269,7 +269,7 @@ public void addingToInvalidColumn() { .hasSingleApiError( SchemaException.Code.CANNOT_VECTORIZE_UNKNOWN_COLUMNS, SchemaException.class, - "The command attempted to drop the unknown columns: invalid_column."); + "The command attempted to vectorize the unknown columns: invalid_column."); } @Test @@ -285,6 +285,33 @@ public void addingToNonVectorTypeColumn() { SchemaException.class, "The command attempted to vectorize the non-vector columns: age."); } + + private static Stream deprecatedEmbeddingModelSource() { + return Stream.of( + Arguments.of( + "DEPRECATED", + "a-deprecated-nvidia-embedding-model", + SchemaException.Code.DEPRECATED_AI_MODEL), + Arguments.of( + "END_OF_LIFE", + "a-EOL-nvidia-embedding-model", + SchemaException.Code.END_OF_LIFE_AI_MODEL)); + } + + @ParameterizedTest + @MethodSource("deprecatedEmbeddingModelSource") + public void deprecatedEmbeddingModel( + String status, String modelName, SchemaException.Code errorCode) { + assertTableCommand(keyspaceName, testTableName) + .templated() + .alterTable( + "addVectorize", + Map.of("vector_type_1", Map.of("provider", "nvidia", "modelName", modelName))) + .hasSingleApiError( + errorCode, + SchemaException.class, + "The model is: %s. 
It is at %s status".formatted(modelName, status)); + } } @Nested diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/CreateTableIndexIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/CreateTableIndexIntegrationTest.java index 659453e01a..1ab08237ea 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/CreateTableIndexIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/CreateTableIndexIntegrationTest.java @@ -22,6 +22,8 @@ @TestClassOrder(ClassOrderer.OrderAnnotation.class) class CreateTableIndexIntegrationTest extends AbstractTableIntegrationTestBase { String testTableName = "tableForCreateIndexTest"; + String lexicalTableName = "tableForCreateTextIndexTest"; + String vectorTableName = "tableForCreateVectorIndexTest"; private void verifyCreatedIndex(String indexName) { assertTableCommand(keyspaceName, testTableName) @@ -31,8 +33,25 @@ private void verifyCreatedIndex(String indexName) { .hasIndex(indexName); } + private void verifyCreatedTextIndex(String indexName) { + assertTableCommand(keyspaceName, lexicalTableName) + .templated() + .listIndexes(false) + .wasSuccessful() + .hasIndex(indexName); + } + + private void verifyCreatedVectorIndex(String indexName) { + assertTableCommand(keyspaceName, vectorTableName) + .templated() + .listIndexes(false) + .wasSuccessful() + .hasIndex(indexName); + } + @BeforeAll - public final void createSimpleTable() { + public final void createTestTables() { + // Create test tables for indexing: first one for "regular" indexes assertNamespaceCommand(keyspaceName) .templated() .createTable( @@ -60,7 +79,17 @@ public final void createSimpleTable() { Map.of("type", "map", "keyType", "int", "valueType", "text")), Map.entry( "map_type_float_value", - Map.of("type", "map", "keyType", "text", "valueType", "float")), + Map.of("type", "map", "keyType", "text", "valueType", "float"))), + "id") + .wasSuccessful(); + + // and then the second table for vector indexes + 
assertNamespaceCommand(keyspaceName) + .templated() + .createTable( + vectorTableName, + Map.ofEntries( + Map.entry("id", Map.of("type", "text")), Map.entry("vector_type_1", Map.of("type", "vector", "dimension", 1024)), Map.entry("vector_type_2", Map.of("type", "vector", "dimension", 1536)), Map.entry("vector_type_3", Map.of("type", "vector", "dimension", 1024)), @@ -70,11 +99,25 @@ public final void createSimpleTable() { Map.entry("vector_type_7", Map.of("type", "vector", "dimension", 1024))), "id") .wasSuccessful(); + + // and then the third table for text (aka "lexical") indexes + assertNamespaceCommand(keyspaceName) + .templated() + .createTable( + lexicalTableName, + Map.ofEntries( + Map.entry("id", Map.of("type", "text")), + Map.entry("text_field_1", Map.of("type", "text")), + Map.entry("text_field_2", Map.of("type", "text")), + Map.entry("text_field_3", Map.of("type", "text")), + Map.entry("text_field_x", Map.of("type", "text"))), + "id") + .wasSuccessful(); } @Nested @Order(1) - class CreateIndexSuccess { + class CreateRegularIndexSuccess { @Test public void createIndexBasic() { @@ -460,7 +503,7 @@ public void createIndexWithCorrectIndexType() { class CreateVectorIndexSuccess { @Test public void createVectorIndex() { - assertTableCommand(keyspaceName, testTableName) + assertTableCommand(keyspaceName, vectorTableName) .postCreateVectorIndex( """ { @@ -472,12 +515,12 @@ public void createVectorIndex() { """) .wasSuccessful(); - verifyCreatedIndex("vector_type_1_idx"); + verifyCreatedVectorIndex("vector_type_1_idx"); } @Test public void createVectorIndexWithSourceModel() { - assertTableCommand(keyspaceName, testTableName) + assertTableCommand(keyspaceName, vectorTableName) .postCreateVectorIndex( """ { @@ -492,12 +535,12 @@ public void createVectorIndexWithSourceModel() { """) .wasSuccessful(); - verifyCreatedIndex("vector_type_2_idx"); + verifyCreatedVectorIndex("vector_type_2_idx"); } @Test public void createVectorIndexWithMetric() { - 
assertTableCommand(keyspaceName, testTableName) + assertTableCommand(keyspaceName, vectorTableName) .postCreateVectorIndex( """ { @@ -512,12 +555,12 @@ public void createVectorIndexWithMetric() { """) .wasSuccessful(); - verifyCreatedIndex("vector_type_3_idx"); + verifyCreatedVectorIndex("vector_type_3_idx"); } @Test public void createVectorIndexWithMetricAndSourceModel() { - assertTableCommand(keyspaceName, testTableName) + assertTableCommand(keyspaceName, vectorTableName) .postCreateVectorIndex( """ { @@ -533,12 +576,12 @@ public void createVectorIndexWithMetricAndSourceModel() { """) .wasSuccessful(); - verifyCreatedIndex("vector_type_4_idx"); + verifyCreatedVectorIndex("vector_type_4_idx"); } @Test public void createVectorIndexWithCorrectIndexType() { - assertTableCommand(keyspaceName, testTableName) + assertTableCommand(keyspaceName, vectorTableName) .postCreateVectorIndex( """ { @@ -551,13 +594,88 @@ public void createVectorIndexWithCorrectIndexType() { """) .wasSuccessful(); - verifyCreatedIndex("vector_type_6_idx"); + verifyCreatedVectorIndex("vector_type_6_idx"); } } @Nested @Order(3) - class CreateIndexFailure { + class CreateTextIndexSuccess { + // First, a test for the default text index creation (no options specified) + @Test + public void createTextIndexWithDefaults() { + assertTableCommand(keyspaceName, lexicalTableName) + .postCreateTextIndex( + """ + { + "name": "text_field_1_idx", + "definition": { + "column": "text_field_1" + } + } + """) + .wasSuccessful(); + + verifyCreatedTextIndex("text_field_1_idx"); + } + + // Then a test with "named" analyzer like "standard" or "english" + @Test + public void createTextIndexWithNamed() { + assertTableCommand(keyspaceName, lexicalTableName) + .postCreateTextIndex( + """ + { + "name": "text_field_2_idx", + "definition": { + "column": "text_field_2", + "options": { + "analyzer": "english" + } + }, + "options": { + "ifNotExists": true + } + } + """) + .wasSuccessful(); + + 
verifyCreatedTextIndex("text_field_2_idx"); + } + + // Then a test with explicit settings + @Test + public void createTextIndexWithFullDefinition() { + assertTableCommand(keyspaceName, lexicalTableName) + .postCreateTextIndex( + """ + { + "name": "text_field_3_idx", + "definition": { + "column": "text_field_3", + "options": { + "analyzer": { + "tokenizer" : {"name" : "standard"}, + "filters": [ + { "name": "lowercase" }, + { "name": "stop" }, + { "name": "porterstem" }, + { "name": "asciifolding" } + ] + } + } + } + } + """) + .wasSuccessful(); + + verifyCreatedTextIndex("text_field_3_idx"); + } + } + + @Nested + @Order(4) + class CreateRegularIndexFailure { @Test public void createIndexWithEmptyName() { @@ -813,11 +931,11 @@ public void analyzeOptionsForEntriesIndexOnMap(String columnValue) { } @Nested - @Order(4) + @Order(5) class CreateVectorIndexFailure { @Test public void createIndexWithEmptyName() { - assertTableCommand(keyspaceName, testTableName) + assertTableCommand(keyspaceName, vectorTableName) .postCreateIndex( """ { @@ -838,7 +956,7 @@ public void createIndexWithEmptyName() { @Test public void createIndexWithBlankName() { - assertTableCommand(keyspaceName, testTableName) + assertTableCommand(keyspaceName, vectorTableName) .postCreateVectorIndex( """ { @@ -859,7 +977,7 @@ public void createIndexWithBlankName() { @Test public void createIndexWithNameTooLong() { - assertTableCommand(keyspaceName, testTableName) + assertTableCommand(keyspaceName, vectorTableName) .postCreateIndex( """ { @@ -879,7 +997,7 @@ public void createIndexWithNameTooLong() { @Test public void createIndexWithSpecialCharacterInName() { - assertTableCommand(keyspaceName, testTableName) + assertTableCommand(keyspaceName, vectorTableName) .postCreateIndex( """ { @@ -899,7 +1017,7 @@ public void createIndexWithSpecialCharacterInName() { @Test public void tryCreateIndexMissingColumn() { - assertTableCommand(keyspaceName, testTableName) + assertTableCommand(keyspaceName, vectorTableName) 
.postCreateVectorIndex( """ { @@ -918,7 +1036,7 @@ public void tryCreateIndexMissingColumn() { @Test public void invalidSourceModel() { - DataApiCommandSenders.assertTableCommand(keyspaceName, testTableName) + DataApiCommandSenders.assertTableCommand(keyspaceName, vectorTableName) .postCreateVectorIndex( """ { @@ -939,7 +1057,7 @@ public void invalidSourceModel() { @Test public void createVectorIndexWithUnsupportedIndexType() { - assertTableCommand(keyspaceName, testTableName) + assertTableCommand(keyspaceName, vectorTableName) .postCreateVectorIndex( """ { @@ -959,7 +1077,7 @@ public void createVectorIndexWithUnsupportedIndexType() { @Test public void createVectorIndexWithUnknownIndexType() { - assertTableCommand(keyspaceName, testTableName) + assertTableCommand(keyspaceName, vectorTableName) .postCreateVectorIndex( """ { @@ -977,4 +1095,31 @@ public void createVectorIndexWithUnknownIndexType() { "The command used the unknown index type: unknown."); } } + + @Nested + @Order(6) + class CreateTextIndexFailure { + // Definition of the text index must be JSON String or Object; fail if not + @Test + public void failForDefNotStringOrObject() { + assertTableCommand(keyspaceName, lexicalTableName) + .postCreateTextIndex( + """ + { + "name": "text_field_x_idx", + "definition": { + "column": "text_field_x", + "options": { + "analyzer": [1, 2, 3] + } + } + } + """) + .hasSingleApiError( + SchemaException.Code.UNSUPPORTED_JSON_TYPE_FOR_TEXT_INDEX, + SchemaException.class, + "command attempted to create a text index using an unsupported JSON value", + "command used the unsupported JSON value type: Array"); + } + } } diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/CreateTableIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/CreateTableIntegrationTest.java index f669e44a3b..511c12bf46 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/CreateTableIntegrationTest.java +++ 
b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/CreateTableIntegrationTest.java @@ -901,6 +901,66 @@ private static Stream allTableData() { ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.name(), "The provided options are invalid: Model name 'mistral-embed-invalid' for provider 'mistral' is not supported"))); + // vector type with deprecated model + testCases.add( + Arguments.of( + new CreateTableTestData( + """ + { + "name": "deprecatedEmbedModel", + "definition": { + "columns": { + "id": { + "type": "text" + }, + "content": { + "type": "vector", + "dimension": 1024, + "service": { + "provider": "nvidia", + "modelName": "a-deprecated-nvidia-embedding-model" + } + } + }, + "primaryKey": "id" + } + } + """, + "deprecatedEmbedModel", + true, + SchemaException.Code.DEPRECATED_AI_MODEL.name(), + "The model is: a-deprecated-nvidia-embedding-model. It is at DEPRECATED status."))); + + // vector type with end_of_life model + testCases.add( + Arguments.of( + new CreateTableTestData( + """ + { + "name": "deprecatedEmbedModel", + "definition": { + "columns": { + "id": { + "type": "text" + }, + "content": { + "type": "vector", + "dimension": 1024, + "service": { + "provider": "nvidia", + "modelName": "a-EOL-nvidia-embedding-model" + } + } + }, + "primaryKey": "id" + } + } + """, + "deprecatedEmbedModel", + true, + SchemaException.Code.END_OF_LIFE_AI_MODEL.name(), + "The model is: a-EOL-nvidia-embedding-model. 
It is at END_OF_LIFE status."))); + // vector type with dimension mismatch testCases.add( Arguments.of( diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/InsertOneTableIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/InsertOneTableIntegrationTest.java index 7e850c771a..416a3af54e 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/InsertOneTableIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/InsertOneTableIntegrationTest.java @@ -2,6 +2,8 @@ import static io.stargate.sgv2.jsonapi.api.v1.util.DataApiCommandSenders.assertNamespaceCommand; import static io.stargate.sgv2.jsonapi.api.v1.util.DataApiCommandSenders.assertTableCommand; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.containsString; import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; @@ -1827,8 +1829,12 @@ void insertDifferentVectorizeProviders() { } """) .hasSingleApiError( - ErrorCodeV1.EMBEDDING_PROVIDER_CLIENT_ERROR.name(), - "Provider: openai; HTTP Status: 401; Error Message: Incorrect API key provided: test_emb"); + ErrorCodeV1.EMBEDDING_PROVIDER_CLIENT_ERROR, + anyOf( + containsString( + "Provider: openai; HTTP Status: 401; Error Message: Incorrect API key provided: test_emb"), + containsString( + "Provider: jinaAI; HTTP Status: 401; Error Message: \"Unauthorized\""))); } } } diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/ListIndexesIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/ListIndexesIntegrationTest.java index fa6d7f8fb6..551ba73b7d 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/ListIndexesIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/ListIndexesIntegrationTest.java @@ -11,10 +11,13 @@ import com.datastax.oss.driver.api.core.type.DataTypes; import com.datastax.oss.driver.api.querybuilder.SchemaBuilder; import 
com.datastax.oss.driver.api.querybuilder.schema.CreateTable; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; import io.quarkus.test.common.WithTestResource; import io.quarkus.test.junit.QuarkusIntegrationTest; import io.stargate.sgv2.jsonapi.testresource.DseTestResource; import io.stargate.sgv2.jsonapi.util.CqlIdentifierUtil; +import java.util.Map; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.ClassOrderer; import org.junit.jupiter.api.MethodOrderer; @@ -28,22 +31,23 @@ @WithTestResource(value = DseTestResource.class, restrictToAnnotatedClass = false) @TestClassOrder(ClassOrderer.OrderAnnotation.class) public class ListIndexesIntegrationTest extends AbstractTableIntegrationTestBase { + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); private static final String TABLE = "person"; private static final String createIndex = """ - { - "name": "name_idx", - "definition": { - "column": "name", - "options": { - "ascii": true, - "caseSensitive": false, - "normalize": true + { + "name": "name_idx", + "definition": { + "column": "name", + "options": { + "ascii": true, + "caseSensitive": false, + "normalize": true + } + }, + "indexType": "regular" } - }, - "indexType": "regular" - } - """; + """; String createWithoutOptionsOnText = """ @@ -83,53 +87,53 @@ public class ListIndexesIntegrationTest extends AbstractTableIntegrationTestBase String createWithoutOptionsOnIntExpected = """ - { - "name": "age_idx", - "definition": { - "column": "age", - "options": { - } - }, - "indexType": "regular" + { + "name": "age_idx", + "definition": { + "column": "age", + "options": { } - """; + }, + "indexType": "regular" + } + """; String createVectorIndex = """ - { - "name": "content_idx", - "definition": { - "column": "content", - "options": { - "metric": "cosine", - "sourceModel": "openai-v3-small" + { + "name": "content_idx", + "definition": { + "column": "content", + "options": { + 
"metric": "cosine", + "sourceModel": "openai-v3-small" + } + }, + "indexType": "vector" } - }, - "indexType": "vector" - } - """; + """; @BeforeAll public final void createDefaultTablesAndIndexes() { String tableData = """ - { - "name": "%s", - "definition": { - "columns": { - "id": "text", - "age": "int", - "name": "text", - "city": "text", - "content": { - "type": "vector", - "dimension": 1024 + { + "name": "%s", + "definition": { + "columns": { + "id": "text", + "age": "int", + "name": "text", + "city": "text", + "content": { + "type": "vector", + "dimension": 1024 + } + }, + "primaryKey": "id" + } } - }, - "primaryKey": "id" - } - } - """; + """; assertNamespaceCommand(keyspaceName) .postCreateTable(tableData.formatted(TABLE)) .wasSuccessful(); @@ -273,108 +277,108 @@ public final void createPreExistedCqlTable() { } @Test - @Order(2) + @Order(3) public void listIndexesWithDefinition() { // full index on frozen map is unsupported, so the index will have UNKNOWN column in the // definition var expected_idx_set = """ - { - "name": "idx_set", - "definition": { - "column": { - "setColumn": "$values" - }, - "options": { - } - }, - "indexType": "regular" - } - """; + { + "name": "idx_set", + "definition": { + "column": { + "setColumn": "$values" + }, + "options": { + } + }, + "indexType": "regular" + } + """; var expected_idx_map_values = """ - { - "name": "idx_map_values", - "definition": { - "column": { - "mapColumn": "$values" - }, - "options": { - } - }, - "indexType": "regular" - } - """; + { + "name": "idx_map_values", + "definition": { + "column": { + "mapColumn": "$values" + }, + "options": { + } + }, + "indexType": "regular" + } + """; var expected_idx_map_keys = """ - { - "name": "idx_map_keys", - "definition": { - "column": { - "mapColumn": "$keys" - }, - "options": { - } - }, - "indexType": "regular" - } - """ + { + "name": "idx_map_keys", + "definition": { + "column": { + "mapColumn": "$keys" + }, + "options": { + } + }, + "indexType": "regular" + } + """ 
.formatted(keyspaceName, PRE_EXISTED_CQL_TABLE); var expected_idx_map_entries = """ - { - "name": "idx_map_entries", - "definition": { - "column": "mapColumn", - "options": { + { + "name": "idx_map_entries", + "definition": { + "column": "mapColumn", + "options": { + } + }, + "indexType": "regular" } - }, - "indexType": "regular" - } - """; + """; var expected_full_index_frozen_map = """ - { - "name": "idx_full_frozen_map", - "definition": { - "column": "UNKNOWN", - "apiSupport": { - "createIndex": false, - "filter": false, - "cqlDefinition": "CREATE CUSTOM INDEX idx_full_frozen_map ON \\"%s\\".\\"%s\\" (full(\\"frozenMapColumn\\"))\\nUSING 'StorageAttachedIndex'" + { + "name": "idx_full_frozen_map", + "definition": { + "column": "UNKNOWN", + "apiSupport": { + "createIndex": false, + "filter": false, + "cqlDefinition": "CREATE CUSTOM INDEX idx_full_frozen_map ON \\"%s\\".\\"%s\\" (full(\\"frozenMapColumn\\"))\\nUSING 'StorageAttachedIndex'" + } + }, + "indexType": "UNKNOWN" } - }, - "indexType": "UNKNOWN" - } - """ + """ .formatted(keyspaceName, PRE_EXISTED_CQL_TABLE); var expected_idx_list = """ - { - "name": "idx_list", - "definition": { - "column": { - "listColumn": "$values" - }, - "options": { + { + "name": "idx_list", + "definition": { + "column": { + "listColumn": "$values" + }, + "options": { + } + }, + "indexType": "regular" } - }, - "indexType": "regular" - } - """ + """ .formatted(keyspaceName, PRE_EXISTED_CQL_TABLE); var expected_idx_quoted = """ - { - "name": "idx_textQuoted", - "definition": { - "column": "TextQuoted", - "options": { - } - }, - "indexType": "regular" - } - """; + { + "name": "idx_textQuoted", + "definition": { + "column": "TextQuoted", + "options": { + } + }, + "indexType": "regular" + } + """; assertTableCommand(keyspaceName, PRE_EXISTED_CQL_TABLE) .templated() .listIndexes(true) @@ -392,4 +396,124 @@ public void listIndexesWithDefinition() { jsonEquals(expected_idx_quoted))); } } + + @Nested + 
@TestMethodOrder(MethodOrderer.OrderAnnotation.class) + @Order(2) + public class ListTextIndexes { + private static final String lexicalTableName = "text_index_table_for_list_indexes"; + + private static final String TEXT_INDEX_1 = + """ + { + "name": "text_field_1_idx", + "definition": { + "column": "text_field_1" + } + } + """; + + private static final String TEXT_INDEX_2 = + """ + { + "name": "text_field_2_idx", + "definition": { + "column": "text_field_2", + "options": { + "analyzer": "english" + } + } + } + """; + private static final String TEXT_INDEX_3 = + """ + { + "name": "text_field_3_idx", + "definition": { + "column": "text_field_3", + "options": { + "analyzer": { + "tokenizer" : {"name" : "standard"}, + "filters": [ + { "name": "lowercase" } + ] + } + } + } + } + """; + + @BeforeAll + public static void createTestTableAndIndexes() { + assertNamespaceCommand(keyspaceName) + .templated() + .createTable( + lexicalTableName, + Map.ofEntries( + Map.entry("id", Map.of("type", "text")), + Map.entry("text_field_1", Map.of("type", "text")), + Map.entry("text_field_2", Map.of("type", "text")), + Map.entry("text_field_3", Map.of("type", "text"))), + "id") + .wasSuccessful(); + + // 3 tables: one with default text index; one with named analyzer; and last with custom + // settings + assertTableCommand(keyspaceName, lexicalTableName) + .postCreateTextIndex(TEXT_INDEX_1) + .wasSuccessful(); + assertTableCommand(keyspaceName, lexicalTableName) + .postCreateTextIndex(TEXT_INDEX_2) + .wasSuccessful(); + assertTableCommand(keyspaceName, lexicalTableName) + .postCreateTextIndex(TEXT_INDEX_3) + .wasSuccessful(); + } + + @Test + @Order(1) + public void listIndexNamesOnly() { + + assertTableCommand(keyspaceName, lexicalTableName) + .templated() + .listIndexes(false) + .wasSuccessful() + .hasIndexes("text_field_1_idx", "text_field_2_idx", "text_field_3_idx"); + } + + @Test + @Order(2) + public void listIndexesWithDefinitions() { + assertTableCommand(keyspaceName, 
lexicalTableName) + .templated() + .listIndexes(true) + .wasSuccessful() + // Validate that status.indexes has all indexes for the table + .body("status.indexes", hasSize(3)) + // Validate index without options + .body( + "status.indexes", + containsInAnyOrder( // Validate that the indexes are in any order + jsonEquals( + addNamedAnalyzer(readIndexDescAddType(TEXT_INDEX_1), "standard").toString()), + jsonEquals(readIndexDescAddType(TEXT_INDEX_2)), + jsonEquals(readIndexDescAddType(TEXT_INDEX_3)))); + } + + private static ObjectNode readIndexDescAddType(String json) { + ObjectNode ob; + try { + ob = OBJECT_MAPPER.readValue(json, ObjectNode.class); + } catch (Exception e) { + throw new RuntimeException(e); + } + ob.put("indexType", "text"); + return ob; + } + + private ObjectNode addNamedAnalyzer(ObjectNode ob, String analyzerName) { + ob.with("definition").with("options").put("analyzer", analyzerName); + return ob; + } + } } diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/LogicalFilterIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/LogicalFilterIntegrationTest.java new file mode 100644 index 0000000000..cd6aa8f577 --- /dev/null +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/LogicalFilterIntegrationTest.java @@ -0,0 +1,166 @@ +package io.stargate.sgv2.jsonapi.api.v1.tables; + +import static io.stargate.sgv2.jsonapi.api.v1.util.DataApiCommandSenders.assertNamespaceCommand; +import static io.stargate.sgv2.jsonapi.api.v1.util.DataApiCommandSenders.assertTableCommand; + +import io.quarkus.test.common.WithTestResource; +import io.quarkus.test.junit.QuarkusIntegrationTest; +import io.stargate.sgv2.jsonapi.exception.WarningException; +import io.stargate.sgv2.jsonapi.testresource.DseTestResource; +import java.util.Map; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.ClassOrderer; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestClassOrder; + +@QuarkusIntegrationTest 
+@WithTestResource(value = DseTestResource.class, restrictToAnnotatedClass = false) +@TestClassOrder(ClassOrderer.OrderAnnotation.class) +public class LogicalFilterIntegrationTest extends AbstractTableIntegrationTestBase { + + static final String TABLE_WITH_COLUMN_TYPES_INDEXED = "logical_table_indexed"; + + static final Map ALL_COLUMNS = + Map.ofEntries( + Map.entry("id", Map.of("type", "text")), + Map.entry("age", Map.of("type", "int")), + Map.entry("name", Map.of("type", "text")), + Map.entry("is_active", Map.of("type", "boolean")), + Map.entry("total_views", Map.of("type", "bigint"))); + + static final Map> COLUMNS_CAN_BE_SAI_INDEXED = + Map.ofEntries( + // Map.entry("id", Map.of("type", "text")), + Map.entry("age", Map.of("type", "int")), + Map.entry("name", Map.of("type", "text")), + Map.entry("is_active", Map.of("type", "boolean")), + Map.entry("total_views", Map.of("type", "bigint"))); + + static final String SAMPLE_ROW_JSON_1 = + """ + { + "id": "1", + "age": 25, + "name": "John Doe", + "is_active": true, + "total_views": 1000 + } + """; + static final String SAMPLE_ROW_JSON_2 = + """ + { + "id": "2", + "age": 30, + "name": "Jane Smith", + "is_active": false, + "total_views": 2000 + } + """; + static final String SAMPLE_ROW_JSON_3 = + """ + { + "id": "3", + "age": 35, + "name": "Alice Johnson", + "is_active": true, + "total_views": 3000 + } + """; + + @BeforeAll + public final void createDefaultTables() { + // create table + assertNamespaceCommand(keyspaceName) + .templated() + .createTable(TABLE_WITH_COLUMN_TYPES_INDEXED, ALL_COLUMNS, "id") + .wasSuccessful(); + // create index on indexable columns + for (String columnName : COLUMNS_CAN_BE_SAI_INDEXED.keySet()) { + assertTableCommand(keyspaceName, TABLE_WITH_COLUMN_TYPES_INDEXED) + .templated() + .createIndex(TABLE_WITH_COLUMN_TYPES_INDEXED + "_" + columnName, columnName) + .wasSuccessful(); + } + // insert 3 rows + assertTableCommand(keyspaceName, TABLE_WITH_COLUMN_TYPES_INDEXED) + .templated() + 
.insertMany(SAMPLE_ROW_JSON_1, SAMPLE_ROW_JSON_2, SAMPLE_ROW_JSON_3) + .wasSuccessful() + .hasInsertedIdCount(3); + } + + @Test + public void simpleOr() { + var filter = + """ + { + "filter": { + "$or": [ + {"age": {"$eq" : 25}}, + {"name": "Alice Johnson"} + ] + } + } + """; + assertTableCommand(keyspaceName, TABLE_WITH_COLUMN_TYPES_INDEXED) + .postFind(filter) + .hasDocuments(2) + .hasDocumentUnknowingPosition(SAMPLE_ROW_JSON_1) + .hasDocumentUnknowingPosition(SAMPLE_ROW_JSON_3) + .hasNoWarnings() + .hasNoErrors(); + } + + @Test + public void oneLevelNestedAndOr() { + var filter = + """ + { + "filter": { + "is_active": {"$eq" : true}, + "$or": [ + {"age": {"$lt" : 26}}, + {"total_views": {"$gt" : 2999}} + ] + } + } + """; + assertTableCommand(keyspaceName, TABLE_WITH_COLUMN_TYPES_INDEXED) + .postFind(filter) + .hasDocuments(2) + .hasDocumentUnknowingPosition(SAMPLE_ROW_JSON_1) + .hasDocumentUnknowingPosition(SAMPLE_ROW_JSON_3) + .hasNoWarnings() + .hasNoErrors(); + } + + @Test + public void twoLevelNestedAndOr() { + var filter = + """ + { + "filter": { + "is_active": {"$eq" : true}, + "$or": [ + { + "$and": [ + {"age": {"$lt" : 26} }, + {"name": {"$ne" : "John Doe"}} + ] + }, + {"total_views": {"$gt" : 2999}} + ] + } + } + """; + assertTableCommand(keyspaceName, TABLE_WITH_COLUMN_TYPES_INDEXED) + .postFind(filter) + .hasDocuments(1) + .hasDocumentUnknowingPosition(SAMPLE_ROW_JSON_3) + .hasWarning( + 0, + WarningException.Code.NOT_EQUALS_UNSUPPORTED_BY_INDEXING, + "The filter uses $ne (not equals) on columns that, while indexed, are still inefficient to filter on using not equals.") + .hasNoErrors(); + } +} diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/util/DataApiResponseValidator.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/util/DataApiResponseValidator.java index 89b89d479b..a836210fe9 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/util/DataApiResponseValidator.java +++ 
b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/util/DataApiResponseValidator.java @@ -30,7 +30,15 @@ public DataApiResponseValidator(CommandName commandName, ValidatableResponse res this.responseIsError = switch (commandName) { - case DROP_TABLE, DROP_INDEX, CREATE_INDEX, CREATE_TABLE, ALTER_TABLE, FIND_ONE, FIND -> + case DROP_TABLE, + DROP_INDEX, + CREATE_INDEX, + CREATE_TEXT_INDEX, + CREATE_VECTOR_INDEX, + CREATE_TABLE, + ALTER_TABLE, + FIND_ONE, + FIND -> responseIsErrorWithOptionalStatus(); default -> responseIsError(); }; @@ -43,8 +51,9 @@ public DataApiResponseValidator(CommandName commandName, ValidatableResponse res CREATE_TABLE, DROP_TABLE, CREATE_INDEX, - DROP_INDEX, + CREATE_TEXT_INDEX, CREATE_VECTOR_INDEX, + DROP_INDEX, LIST_TABLES, LIST_INDEXES -> responseIsDDLSuccess(); @@ -105,7 +114,13 @@ public DataApiResponseValidator wasSuccessful() { case DELETE_ONE, DELETE_MANY -> { return hasNoErrors(); } - case ALTER_TABLE, CREATE_TABLE, DROP_TABLE, CREATE_INDEX, DROP_INDEX, CREATE_VECTOR_INDEX -> { + case ALTER_TABLE, + CREATE_TABLE, + DROP_TABLE, + CREATE_INDEX, + CREATE_TEXT_INDEX, + CREATE_VECTOR_INDEX, + DROP_INDEX -> { return hasNoErrors().hasStatusOK(); } case LIST_TABLES, LIST_INDEXES -> { @@ -353,6 +368,10 @@ public DataApiResponseValidator hasDocumentInPosition(int position, String docum return body("data.documents[%s]".formatted(position), jsonEquals(documentJSON)); } + public DataApiResponseValidator hasDocumentUnknowingPosition(String documentJSON) { + return body("data.documents", hasItem(jsonEquals(documentJSON))); + } + public DataApiResponseValidator mayFoundSingleDocumentIdByFindOne( FilterException.Code expectedFilterException, String sampleId) { if (expectedFilterException != null) { diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/util/DataApiTableCommandSender.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/util/DataApiTableCommandSender.java index 07d8608ce5..5fc28321d8 100644 --- 
a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/util/DataApiTableCommandSender.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/util/DataApiTableCommandSender.java @@ -68,6 +68,10 @@ public DataApiResponseValidator postListIndexes(String jsonClause) { return postCommand(CommandName.LIST_INDEXES, jsonClause); } + public DataApiResponseValidator postCreateTextIndex(String jsonClause) { + return postCommand(CommandName.CREATE_TEXT_INDEX, jsonClause); + } + public DataApiResponseValidator postCreateVectorIndex(String jsonClause) { return postCommand(CommandName.CREATE_VECTOR_INDEX, jsonClause); } diff --git a/src/test/java/io/stargate/sgv2/jsonapi/fixtures/TestTextUtil.java b/src/test/java/io/stargate/sgv2/jsonapi/fixtures/TestTextUtil.java new file mode 100644 index 0000000000..48d5b4cc29 --- /dev/null +++ b/src/test/java/io/stargate/sgv2/jsonapi/fixtures/TestTextUtil.java @@ -0,0 +1,23 @@ +package io.stargate.sgv2.jsonapi.fixtures; + +import com.github.javafaker.Faker; + +/** Helper methods for generating textual data for use in the tests. */ +public class TestTextUtil { + /** + * Helper method for generating a Document with exactly specified length (in characters), composed + * of Latin words in "Lorem Ipsum" style. + * + * @param targetLength Exact length of the generated string + * @param separator Separator between sentences. 
+ */ + public static String generateTextDoc(int targetLength, String separator) { + Faker faker = new Faker(); + StringBuilder sb = new StringBuilder(targetLength + 100); + + while (sb.length() < targetLength) { + sb.append(faker.lorem().sentence()).append(separator); + } + return sb.substring(0, targetLength); + } +} diff --git a/src/test/java/io/stargate/sgv2/jsonapi/fixtures/tables/TableLogicalRelationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/fixtures/tables/TableLogicalRelationTest.java new file mode 100644 index 0000000000..a6313e787f --- /dev/null +++ b/src/test/java/io/stargate/sgv2/jsonapi/fixtures/tables/TableLogicalRelationTest.java @@ -0,0 +1,199 @@ +package io.stargate.sgv2.jsonapi.service.operation.tables; + +import io.stargate.sgv2.jsonapi.fixtures.testdata.TestData; +import io.stargate.sgv2.jsonapi.fixtures.testdata.TestDataNames; +import io.stargate.sgv2.jsonapi.service.operation.query.DBLogicalExpression; +import io.stargate.sgv2.jsonapi.service.operation.query.TableFilter; +import java.util.Arrays; +import java.util.List; +import org.junit.jupiter.api.Test; + +/** Tests for {@link TableWhereCQLClause} to build correct logicalRelation statement */ +public class TableLogicalRelationTest { + + private static final TestData TEST_DATA = new TestData(); + + private TestDataNames names() { + return TEST_DATA.names; + } + + /** The filter is empty so only an implicitAnd, so the where clause should be empty as well. */ + @Test + public void emptyFilter() { + var fixture = TEST_DATA.tableWhereCQLClause().tableWithAllDataTypes("AND()"); + var expBuilder = fixture.expressionBuilder; + var implicitAnd = expBuilder.rootImplicitAnd; + + fixture + .expressionBuilder() + .replaceRootDBLogicalExpression(implicitAnd) + .applyAndGetOnGoingWhereClause() + .assertNoWhereClause() + .assertNoPositionalValues(); + } + + /** + * Implicit AND with two eq filters, so two filters should be AND together in the where clause. 
+ */ + @Test + public void simpleAnd() { + var fixture = TEST_DATA.tableWhereCQLClause().tableWithAllDataTypes("AND(eq,eq)"); + var expBuilder = fixture.expressionBuilder; + var implicitAnd = expBuilder.rootImplicitAnd; + var dbLogicalExpression = + addFilters( + implicitAnd, + expBuilder.eq(names().CQL_INT_COLUMN), + expBuilder.eq(names().CQL_TEXT_COLUMN)); + fixture + .expressionBuilder() + .replaceRootDBLogicalExpression(dbLogicalExpression) + .applyAndGetOnGoingWhereClause() + .assertWhereCQL( + "WHERE (%s=? AND %s=?)".formatted(names().CQL_INT_COLUMN, names().CQL_TEXT_COLUMN)) + .assertWherePositionalValues(List.of(names().CQL_INT_COLUMN, names().CQL_TEXT_COLUMN)); + } + + /** Implicit AND with an explicit OR includes two eq filters. */ + @Test + public void simpleOR() { + var fixture = TEST_DATA.tableWhereCQLClause().tableWithAllDataTypes("AND(OR(eq,eq))"); + var expBuilder = fixture.expressionBuilder; + var implicitAnd = expBuilder.rootImplicitAnd; + var dbLogicalExpression = + implicitAnd.addSubExpressionReturnCurrent( + or(expBuilder.eq(names().CQL_INT_COLUMN), expBuilder.eq(names().CQL_TEXT_COLUMN))); + fixture + .expressionBuilder() + .replaceRootDBLogicalExpression(dbLogicalExpression) + .applyAndGetOnGoingWhereClause() + .assertWhereCQL( + "WHERE ((%s=? OR %s=?))".formatted(names().CQL_INT_COLUMN, names().CQL_TEXT_COLUMN)) + .assertWherePositionalValues(List.of(names().CQL_INT_COLUMN, names().CQL_TEXT_COLUMN)); + } + + /** Implicit AND with a root-level filter and an explicit OR includes two other eq filters. 
*/ + @Test + public void tableFilterWithLogicalExpressionOR() { + var fixture = TEST_DATA.tableWhereCQLClause().tableWithAllDataTypes("AND(eq, OR(eq,eq))"); + var expBuilder = fixture.expressionBuilder; + var implicitAnd = expBuilder.rootImplicitAnd; + var dbLogicalExpression = + addFilters(implicitAnd, expBuilder.eq(names().CQL_DATE_COLUMN)) + .addSubExpressionReturnCurrent( + or(expBuilder.eq(names().CQL_INT_COLUMN), expBuilder.eq(names().CQL_TEXT_COLUMN))); + fixture + .expressionBuilder() + .replaceRootDBLogicalExpression(dbLogicalExpression) + .applyAndGetOnGoingWhereClause() + .assertWhereCQL( + "WHERE (%s=? AND (%s=? OR %s=?))" + .formatted( + names().CQL_DATE_COLUMN, names().CQL_INT_COLUMN, names().CQL_TEXT_COLUMN)) + .assertWherePositionalValues( + List.of(names().CQL_DATE_COLUMN, names().CQL_INT_COLUMN, names().CQL_TEXT_COLUMN)); + } + + /** + * Implicit AND with a root-level filter and an explicit empty OR. However, the OR is empty, so + * just empty parentheses are added. + */ + @Test + public void tableFilterWithPartialEmptyOR() { + var fixture = TEST_DATA.tableWhereCQLClause().tableWithAllDataTypes("AND(eq, OR())"); + var expBuilder = fixture.expressionBuilder; + var implicitAnd = expBuilder.rootImplicitAnd; + var dbLogicalExpression = + addFilters(implicitAnd, expBuilder.eq(names().CQL_DATE_COLUMN)) + .addSubExpressionReturnCurrent(or()); + fixture + .expressionBuilder() + .replaceRootDBLogicalExpression(dbLogicalExpression) + .applyAndGetOnGoingWhereClause() + .assertWhereCQL("WHERE (%s=? AND ())".formatted(names().CQL_DATE_COLUMN)) + .assertWherePositionalValues(List.of(names().CQL_DATE_COLUMN)); + } + + /** Implicit AND with two explicit ORs, each with two eq filters. 
*/ + @Test + public void twoLogicalExpressionOR() { + var fixture = + TEST_DATA.tableWhereCQLClause().tableWithAllDataTypes("AND(OR(eq,eq), OR(eq,eq))"); + var expBuilder = fixture.expressionBuilder; + var implicitAnd = expBuilder.rootImplicitAnd; + var dbLogicalExpression = + implicitAnd + .addSubExpressionReturnCurrent( + or(expBuilder.eq(names().CQL_INT_COLUMN), expBuilder.eq(names().CQL_TEXT_COLUMN))) + .addSubExpressionReturnCurrent( + or( + expBuilder.eq(names().CQL_DATE_COLUMN), + expBuilder.eq(names().CQL_BOOLEAN_COLUMN))); + fixture + .expressionBuilder() + .replaceRootDBLogicalExpression(dbLogicalExpression) + .applyAndGetOnGoingWhereClause() + .assertWhereCQL( + "WHERE ((%s=? OR %s=?) AND (%s=? OR %s=?))" + .formatted( + names().CQL_INT_COLUMN, + names().CQL_TEXT_COLUMN, + names().CQL_DATE_COLUMN, + names().CQL_BOOLEAN_COLUMN)) + .assertWherePositionalValues( + List.of( + names().CQL_INT_COLUMN, + names().CQL_TEXT_COLUMN, + names().CQL_DATE_COLUMN, + names().CQL_BOOLEAN_COLUMN)); + } + + /** Implicit AND with two explicit ORs, The second OR is nested inside the first OR. */ + @Test + public void nested_AND_OR() { + var fixture = TEST_DATA.tableWhereCQLClause().tableWithAllDataTypes("AND(OR(OR(eq,eq)))"); + var expBuilder = fixture.expressionBuilder; + var implicitAnd = expBuilder.rootImplicitAnd; + var dbLogicalExpression = + implicitAnd.addSubExpressionReturnCurrent( + or().addSubExpressionReturnCurrent( + or( + expBuilder.eq(names().CQL_INT_COLUMN), + expBuilder.eq(names().CQL_TEXT_COLUMN)))); + fixture + .expressionBuilder() + .replaceRootDBLogicalExpression(dbLogicalExpression) + .applyAndGetOnGoingWhereClause() + .assertWhereCQL( + "WHERE (((%s=? 
OR %s=?)))".formatted(names().CQL_INT_COLUMN, names().CQL_TEXT_COLUMN)) + .assertWherePositionalValues(List.of(names().CQL_INT_COLUMN, names().CQL_TEXT_COLUMN)); + } + + // ================================================================================================================== + // Methods below are created to help construct DBLogicalExpression in unit tests + public DBLogicalExpression and() { + return new DBLogicalExpression(DBLogicalExpression.DBLogicalOperator.AND); + } + + public DBLogicalExpression and(TableFilter... filters) { + var and = and(); + Arrays.stream(filters).forEach(and::addFilter); + return and; + } + + public DBLogicalExpression or() { + return new DBLogicalExpression(DBLogicalExpression.DBLogicalOperator.OR); + } + + public DBLogicalExpression or(TableFilter... filters) { + var or = or(); + Arrays.stream(filters).forEach(or::addFilter); + return or; + } + + public DBLogicalExpression addFilters( + DBLogicalExpression dbLogicalExpression, TableFilter... filters) { + Arrays.stream(filters).forEach(dbLogicalExpression::addFilter); + return dbLogicalExpression; + } +} diff --git a/src/test/java/io/stargate/sgv2/jsonapi/fixtures/testdata/LogicalExpressionTestData.java b/src/test/java/io/stargate/sgv2/jsonapi/fixtures/testdata/LogicalExpressionTestData.java index d757d95579..18847a38ec 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/fixtures/testdata/LogicalExpressionTestData.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/fixtures/testdata/LogicalExpressionTestData.java @@ -8,7 +8,12 @@ import com.datastax.oss.driver.api.core.type.DataTypes; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; +import io.stargate.sgv2.jsonapi.exception.checked.MissingJSONCodecException; +import io.stargate.sgv2.jsonapi.exception.checked.ToCQLCodecException; +import io.stargate.sgv2.jsonapi.exception.checked.UnknownColumnException; import io.stargate.sgv2.jsonapi.service.operation.filters.table.*; 
+import io.stargate.sgv2.jsonapi.service.operation.filters.table.codecs.JSONCodec; +import io.stargate.sgv2.jsonapi.service.operation.filters.table.codecs.JSONCodecRegistries; import io.stargate.sgv2.jsonapi.service.operation.query.DBLogicalExpression; import io.stargate.sgv2.jsonapi.service.operation.query.TableFilter; import io.stargate.sgv2.jsonapi.util.CqlIdentifierUtil; @@ -16,53 +21,62 @@ import java.math.BigInteger; import java.util.List; import java.util.stream.Collectors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class LogicalExpressionTestData extends TestDataSuplier { + private static final Logger log = LoggerFactory.getLogger(LogicalExpressionTestData.class); + private static final ObjectMapper MAPPER = new ObjectMapper(); public LogicalExpressionTestData(TestData testData) { super(testData); } - public DBLogicalExpression andExpression(TableMetadata tableMetadata) { + public DBLogicalExpression implicitAndExpression(TableMetadata tableMetadata) { return new DBLogicalExpression(DBLogicalExpression.DBLogicalOperator.AND); } public static class ExpressionBuilder { - public final DBLogicalExpression expression; + public DBLogicalExpression rootImplicitAnd; private final TableMetadata tableMetadata; private final FixtureT fixture; public ExpressionBuilder( - FixtureT fixture, DBLogicalExpression expression, TableMetadata tableMetadata) { + FixtureT fixture, DBLogicalExpression rootImplicitAnd, TableMetadata tableMetadata) { this.fixture = fixture; - this.expression = expression; + this.rootImplicitAnd = rootImplicitAnd; this.tableMetadata = tableMetadata; } + public FixtureT replaceRootDBLogicalExpression(DBLogicalExpression dbLogicalExpression) { + this.rootImplicitAnd = dbLogicalExpression; + return fixture; + } + public FixtureT eqOn(CqlIdentifier column) { - expression.addFilter(eq(tableMetadata.getColumn(column).orElseThrow())); + rootImplicitAnd.addFilter(eq(tableMetadata.getColumn(column).orElseThrow())); return fixture; } 
public FixtureT notEqOn(CqlIdentifier column) { - expression.addFilter(notEq(tableMetadata.getColumn(column).orElseThrow())); + rootImplicitAnd.addFilter(notEq(tableMetadata.getColumn(column).orElseThrow())); return fixture; } public FixtureT gtOn(CqlIdentifier column) { - expression.addFilter(gt(tableMetadata.getColumn(column).orElseThrow())); + rootImplicitAnd.addFilter(gt(tableMetadata.getColumn(column).orElseThrow())); return fixture; } public FixtureT inOn(CqlIdentifier column) { - expression.addFilter(in(tableMetadata.getColumn(column).orElseThrow())); + rootImplicitAnd.addFilter(in(tableMetadata.getColumn(column).orElseThrow())); return fixture; } public FixtureT notInOn(CqlIdentifier column) { - expression.addFilter(nin(tableMetadata.getColumn(column).orElseThrow())); + rootImplicitAnd.addFilter(nin(tableMetadata.getColumn(column).orElseThrow())); return fixture; } @@ -71,23 +85,12 @@ public FixtureT eqAllPrimaryKeys() { return eqAllClusteringKeys(); } - public FixtureT inOnOnePartitionKey( - InTableFilter.Operator inFilterOperator, ColumnMetadata firstPartitionKey) { - if (inFilterOperator == InTableFilter.Operator.IN) { - expression.addFilter(in(firstPartitionKey)); - } - if (inFilterOperator == InTableFilter.Operator.NIN) { - expression.addFilter(nin(firstPartitionKey)); - } - return fixture; - } - public FixtureT eqAllPartitionKeys() { tableMetadata .getPartitionKey() .forEach( columnMetadata -> { - expression.addFilter(eq(columnMetadata)); + rootImplicitAnd.addFilter(eq(columnMetadata)); }); return fixture; } @@ -103,7 +106,7 @@ public FixtureT eqSkipOnePartitionKeys(int skipIndex) { if (index == skipIndex) { continue; } - expression.addFilter(eq(columnMetadata)); + rootImplicitAnd.addFilter(eq(columnMetadata)); } return fixture; } @@ -114,11 +117,22 @@ public FixtureT eqAllClusteringKeys() { .keySet() .forEach( columnMetadata -> { - expression.addFilter(eq(columnMetadata)); + rootImplicitAnd.addFilter(eq(columnMetadata)); }); return fixture; } + 
public FixtureT inOnOnePartitionKey( + InTableFilter.Operator inFilterOperator, ColumnMetadata firstPartitionKey) { + if (inFilterOperator == InTableFilter.Operator.IN) { + rootImplicitAnd.addFilter(in(firstPartitionKey)); + } + if (inFilterOperator == InTableFilter.Operator.NIN) { + rootImplicitAnd.addFilter(nin(firstPartitionKey)); + } + return fixture; + } + public FixtureT eqSkipOneClusteringKeys(int skipIndex) { if (skipIndex >= tableMetadata.getClusteringColumns().size()) { @@ -131,7 +145,7 @@ public FixtureT eqSkipOneClusteringKeys(int skipIndex) { if (index == skipIndex) { continue; } - expression.addFilter(eq(columnMetadata)); + rootImplicitAnd.addFilter(eq(columnMetadata)); } return fixture; } @@ -140,7 +154,7 @@ public FixtureT eqOnlyOneClusteringKey(int index) { ColumnMetadata columnMetadata = tableMetadata.getClusteringColumns().keySet().stream().toList().get(index); - expression.addFilter(eq(columnMetadata)); + rootImplicitAnd.addFilter(eq(columnMetadata)); return fixture; } @@ -157,7 +171,7 @@ public FixtureT eqFirstNonPKOrIndexed() { .filter(columnMetadata -> !allIndexTargets.contains(columnMetadata.getName())) .findFirst() .ifPresentOrElse( - columnMetadata -> expression.addFilter(eq(columnMetadata)), + columnMetadata -> rootImplicitAnd.addFilter(eq(columnMetadata)), () -> { throw new IllegalArgumentException( "Table don't have a column that is NOT on the SAI table to generate test data"); @@ -173,6 +187,10 @@ public static TableFilter eq(ColumnMetadata columnMetadata) { value(columnMetadata.getType())); } + public TableFilter eq(CqlIdentifier columnCqlIdentifier) { + return eq(tableMetadata.getColumn(columnCqlIdentifier).orElseThrow()); + } + public static TableFilter notEq(ColumnMetadata columnMetadata) { return filter( columnMetadata.getName(), @@ -295,7 +313,7 @@ public static Object value(DataType type) { return "P1H30M"; // Handle duration as a string } if (type.equals(DataTypes.INT)) { - return 25; + return 25L; } if 
(type.equals(DataTypes.BIGINT)) { return 1000000L; @@ -408,5 +426,26 @@ public static JsonNode jsonNodeValue(DataType dataType) { throw new IllegalArgumentException( "Did not understand type %s to convert into JsonNode".formatted(dataType)); } + + // Get CqlValue of type that Driver expects + public Object CqlValue(CqlIdentifier column) { + var cqlDataType = tableMetadata.getColumn(column).orElseThrow().getType(); + var javaValue = value(cqlDataType); + JSONCodec codec = null; + try { + codec = + JSONCodecRegistries.DEFAULT_REGISTRY.codecToCQL( + tableMetadata, + column, + ExpressionBuilder.value(tableMetadata.getColumn(column).orElseThrow().getType())); + } catch (UnknownColumnException | ToCQLCodecException | MissingJSONCodecException e) { + throw new IllegalArgumentException(e); + } + try { + return codec.toCQL(javaValue); + } catch (ToCQLCodecException e) { + throw new RuntimeException(e); + } + } } } diff --git a/src/test/java/io/stargate/sgv2/jsonapi/fixtures/testdata/TableWhereCQLClauseTestData.java b/src/test/java/io/stargate/sgv2/jsonapi/fixtures/testdata/TableWhereCQLClauseTestData.java new file mode 100644 index 0000000000..efb15d8946 --- /dev/null +++ b/src/test/java/io/stargate/sgv2/jsonapi/fixtures/testdata/TableWhereCQLClauseTestData.java @@ -0,0 +1,128 @@ +package io.stargate.sgv2.jsonapi.fixtures.testdata; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; + +import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata; +import com.datastax.oss.driver.api.querybuilder.select.Select; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.stargate.sgv2.jsonapi.exception.WithWarnings; +import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject; +import io.stargate.sgv2.jsonapi.service.cqldriver.override.ExtendedSelect; +import 
io.stargate.sgv2.jsonapi.service.operation.query.DBLogicalExpression; +import io.stargate.sgv2.jsonapi.service.operation.tables.TableWhereCQLClause; +import io.stargate.sgv2.jsonapi.util.recordable.Recordable; +import java.util.ArrayList; +import java.util.List; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class TableWhereCQLClauseTestData extends TestDataSuplier { + + private static final Logger LOGGER = LoggerFactory.getLogger(TableWhereCQLClauseTestData.class); + + public TableWhereCQLClauseTestData(TestData testData) { + super(testData); + } + + public TableWhereCQLClauseTestData.TableWhereCQLClauseFixture tableWithAllDataTypes( + String message) { + var tableMetaData = testData.tableMetadata().tableAllDatatypesIndexed(); + return new TableWhereCQLClauseTestData.TableWhereCQLClauseFixture( + message, tableMetaData, testData.logicalExpression().implicitAndExpression(tableMetaData)); + } + + public static class TableWhereCQLClauseFixture implements Recordable { + + private final String message; + private final TableMetadata tableMetadata; + private final TableSchemaObject tableSchemaObject; + public final LogicalExpressionTestData.ExpressionBuilder + expressionBuilder; + + private TableWhereCQLClause