From 5c53563b41732559330dca5c0fd4e0e231556ca9 Mon Sep 17 00:00:00 2001 From: "mahesh.rajamani" Date: Tue, 11 Feb 2025 11:57:05 -0500 Subject: [PATCH 01/11] First draft version for vectorize usage response --- .../api/model/command/CommandResult.java | 32 +++++++++++ .../api/model/command/VectorizeUsageBean.java | 55 +++++++++++++++++++ .../jsonapi/api/v1/CollectionResource.java | 22 +++++++- .../service/embedding/DataVectorizer.java | 38 +++++++++++-- .../embedding/DataVectorizerService.java | 9 ++- .../operation/EmbeddingProvider.java | 10 +++- .../operation/MeteredEmbeddingProvider.java | 16 +++++- .../operation/NetworkUsageInterceptor.java | 55 +++++++++++++++++++ .../operation/OpenAIEmbeddingProvider.java | 32 ++++++++--- .../operation/VectorizeUsageInfo.java | 41 ++++++++++++++ .../operation/DataVectorizerTest.java | 54 +++++++++++++++--- 11 files changed, 336 insertions(+), 28 deletions(-) create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/api/model/command/VectorizeUsageBean.java create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NetworkUsageInterceptor.java create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VectorizeUsageInfo.java diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/CommandResult.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/CommandResult.java index 5b60582ff2..37faf293a6 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/CommandResult.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/CommandResult.java @@ -138,6 +138,38 @@ public RestResponse toRestResponse() { return RestResponse.ok(this); } + /** + * Create the {@link RestResponse} Maps CommandResult to RestResponse. Except for few selective + * errors, all errors are mapped to http status 200. In case of 401, 500, 502 and 504 response is + * sent with appropriate status code. + * + * @return + */ + public RestResponse toRestResponse(String vectorizeHeader) { + if (null != this.errors()) { + final Optional first = + this.errors().stream() + .filter(error -> error.httpStatus() != Response.Status.OK) + .findFirst(); + + if (first.isPresent()) { + if (vectorizeHeader != null) { + return RestResponse.ResponseBuilder.create(first.get().httpStatus(), this) + .header("vectorize-usage", vectorizeHeader) + .build(); + } else { + return RestResponse.ResponseBuilder.create(first.get().httpStatus(), this).build(); + } + } + } + if (vectorizeHeader != null) { + return RestResponse.ResponseBuilder.create(RestResponse.Status.OK, this) + .header("vectorize-usage", vectorizeHeader) + .build(); + } + return RestResponse.ok(this); + } + /** * returned a new CommandResult with warning message added in status map * diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/VectorizeUsageBean.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/VectorizeUsageBean.java new file mode 100644 index 0000000000..a739b8a06e --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/VectorizeUsageBean.java @@ -0,0 +1,55 @@ +package io.stargate.sgv2.jsonapi.api.model.command; + +import jakarta.enterprise.context.RequestScoped; + +/** + * Bean class which is used to return as response json in header + */ +@RequestScoped +public class VectorizeUsageBean { + private int requestSize; + private int responseSize; + private int totalTokens; + private String provider; + private String model; + + public int getRequestSize() { + return requestSize; + } + + public void setRequestSize(int requestSize) { + this.requestSize = requestSize; + } + + public int getResponseSize() { + return responseSize; + } + + public void setResponseSize(int responseSize) { + this.responseSize = responseSize; + } + + public int getTotalTokens() { + return totalTokens; + } + + public void setTotalTokens(int totalTokens) { + this.totalTokens = totalTokens; + } + + public String getProvider() { + return provider; + } + + public void setProvider(String provider) { + this.provider = provider; + } + + public String getModel() { + return model; + } + + public void setModel(String model) { + this.model = model; + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/v1/CollectionResource.java b/src/main/java/io/stargate/sgv2/jsonapi/api/v1/CollectionResource.java index 9c1aabb6fb..9d3952025e 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/v1/CollectionResource.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/v1/CollectionResource.java @@ -2,6 +2,9 @@ import static io.stargate.sgv2.jsonapi.config.constants.DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.base.Strings; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.model.command.*; import io.stargate.sgv2.jsonapi.api.model.command.impl.AlterTableCommand; @@ -36,6 +39,7 @@ import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorColumnDefinition; import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProvider; import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProviderFactory; +import io.stargate.sgv2.jsonapi.api.model.command.VectorizeUsageBean; import io.stargate.sgv2.jsonapi.service.processor.MeteredCommandProcessor; import jakarta.inject.Inject; import jakarta.validation.Valid; @@ -83,6 +87,10 @@ public class CollectionResource { @Inject private JsonProcessingMetricsReporter jsonProcessingMetricsReporter; + @Inject VectorizeUsageBean vectorizeUsageBean; + + @Inject ObjectMapper objectMapper; + @Inject public CollectionResource(MeteredCommandProcessor meteredCommandProcessor) { this.meteredCommandProcessor = meteredCommandProcessor; @@ -260,6 +268,18 @@ public Uni> postCommand( dataApiRequestInfo, commandContext, command); } }) - .map(commandResult -> commandResult.toRestResponse()); + .map( + commandResult -> { + String vectorize = null; + try { + + if (!Strings.isNullOrEmpty(vectorizeUsageBean.getModel())) { + vectorize = objectMapper.writeValueAsString(vectorizeUsageBean); + } + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + return commandResult.toRestResponse(vectorize); + }); } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java index 47f011feeb..183d7d1f09 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java @@ -19,6 +19,7 @@ import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorColumnDefinition; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorConfig; import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProvider; +import io.stargate.sgv2.jsonapi.api.model.command.VectorizeUsageBean; import io.stargate.sgv2.jsonapi.service.schema.tables.ApiColumnDef; import io.stargate.sgv2.jsonapi.service.schema.tables.ApiTypeName; import io.stargate.sgv2.jsonapi.service.schema.tables.ApiVectorType; @@ -34,6 +35,7 @@ public class DataVectorizer { private final JsonNodeFactory nodeFactory; private final EmbeddingCredentials embeddingCredentials; private final SchemaObject schemaObject; + private final VectorizeUsageBean vectorizeUsageBean; /** * Constructor @@ -48,11 +50,13 @@ public DataVectorizer( EmbeddingProvider embeddingProvider, JsonNodeFactory nodeFactory, EmbeddingCredentials embeddingCredentials, - SchemaObject schemaObject) { + SchemaObject schemaObject, + VectorizeUsageBean vectorizeUsageBean) { this.embeddingProvider = embeddingProvider; this.nodeFactory = nodeFactory; this.embeddingCredentials = embeddingCredentials; this.schemaObject = schemaObject; + this.vectorizeUsageBean = vectorizeUsageBean; } /** @@ -107,7 +111,16 @@ public Uni vectorize(List documents) { vectorizeTexts, embeddingCredentials, EmbeddingProvider.EmbeddingRequestType.INDEX) - .map(res -> res.embeddings()); + .map( + res -> { + vectorizeUsageBean.setRequestSize(res.vectorizeUsageInfo().getRequestSize()); + vectorizeUsageBean.setResponseSize( + res.vectorizeUsageInfo().getResponseSize()); + vectorizeUsageBean.setTotalTokens(res.vectorizeUsageInfo().getTotalTokens()); + vectorizeUsageBean.setProvider(res.vectorizeUsageInfo().getProvider()); + vectorizeUsageBean.setModel(res.vectorizeUsageInfo().getModel()); + return res.embeddings(); + }); return vectors .onItem() .transform( @@ -174,7 +187,15 @@ public Uni vectorize(String vectorizeContent) { List.of(vectorizeContent), embeddingCredentials, EmbeddingProvider.EmbeddingRequestType.INDEX) - .map(EmbeddingProvider.Response::embeddings); + .map( + res -> { + vectorizeUsageBean.setRequestSize(res.vectorizeUsageInfo().getRequestSize()); + vectorizeUsageBean.setResponseSize(res.vectorizeUsageInfo().getResponseSize()); + vectorizeUsageBean.setTotalTokens(res.vectorizeUsageInfo().getTotalTokens()); + vectorizeUsageBean.setProvider(res.vectorizeUsageInfo().getProvider()); + vectorizeUsageBean.setModel(res.vectorizeUsageInfo().getModel()); + return res.embeddings(); + }); return vectors .onItem() .transform( @@ -221,7 +242,16 @@ public Uni vectorize(SortClause sortClause) { List.of(text), embeddingCredentials, EmbeddingProvider.EmbeddingRequestType.SEARCH) - .map(res -> res.embeddings()); + .map( + res -> { + vectorizeUsageBean.setRequestSize(res.vectorizeUsageInfo().getRequestSize()); + vectorizeUsageBean.setResponseSize( + res.vectorizeUsageInfo().getResponseSize()); + vectorizeUsageBean.setTotalTokens(res.vectorizeUsageInfo().getTotalTokens()); + vectorizeUsageBean.setProvider(res.vectorizeUsageInfo().getProvider()); + vectorizeUsageBean.setModel(res.vectorizeUsageInfo().getModel()); + return res.embeddings(); + }); return vectors .onItem() .transform( diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizerService.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizerService.java index 073ed3dac9..0dcbcd2e8c 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizerService.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizerService.java @@ -22,6 +22,7 @@ import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject; import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProvider; import io.stargate.sgv2.jsonapi.service.embedding.operation.MeteredEmbeddingProvider; +import io.stargate.sgv2.jsonapi.api.model.command.VectorizeUsageBean; import io.stargate.sgv2.jsonapi.service.schema.tables.ApiColumnDef; import io.stargate.sgv2.jsonapi.service.schema.tables.ApiSupportDef; import io.stargate.sgv2.jsonapi.service.schema.tables.ApiTypeName; @@ -37,15 +38,18 @@ public class DataVectorizerService { private final ObjectMapper objectMapper; private final MeterRegistry meterRegistry; private final JsonApiMetricsConfig jsonApiMetricsConfig; + private final VectorizeUsageBean vectorizeUsageBean; @Inject public DataVectorizerService( ObjectMapper objectMapper, MeterRegistry meterRegistry, - JsonApiMetricsConfig jsonApiMetricsConfig) { + JsonApiMetricsConfig jsonApiMetricsConfig, + VectorizeUsageBean vectorizeUsageBean) { this.objectMapper = objectMapper; this.meterRegistry = meterRegistry; this.jsonApiMetricsConfig = jsonApiMetricsConfig; + this.vectorizeUsageBean = vectorizeUsageBean; } /** @@ -95,7 +99,8 @@ public DataVectorizer constructDataVectorizer( embeddingProvider, objectMapper.getNodeFactory(), dataApiRequestInfo.getEmbeddingCredentials(), - commandContext.schemaObject()); + commandContext.schemaObject(), + vectorizeUsageBean); } private Uni vectorizeSortClause( diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProvider.java index aa21ae1af4..3ee91bcdb0 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProvider.java @@ -175,9 +175,15 @@ protected void checkEmbeddingApiKeyHeader(String providerId, Optional ap * @param batchId - Sequence number for the batch to order the vectors. * @param embeddings - Embedding vectors for the given text inputs. */ - public record Response(int batchId, List embeddings) { + public record Response( + int batchId, List embeddings, VectorizeUsageInfo vectorizeUsageInfo) { public static Response of(int batchId, List embeddings) { - return new Response(batchId, embeddings); + return new Response(batchId, embeddings, new VectorizeUsageInfo(0, 0, 0, "", "")); + } + + public static Response of( + int batchId, List embeddings, VectorizeUsageInfo vectorizeUsageInfo) { + return new Response(batchId, embeddings, vectorizeUsageInfo); } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MeteredEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MeteredEmbeddingProvider.java index 1cc2e3dac3..69debda5e9 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MeteredEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MeteredEmbeddingProvider.java @@ -91,11 +91,25 @@ public Uni vectorize( Collections.sort( vectorizedBatches, (a, b) -> Integer.compare(a.batchId(), b.batchId())); List result = new ArrayList<>(); + int sentBytes = 0; + int receivedBytes = 0; + int totalToken = 0; + String provider = ""; + String modelName = ""; for (Response vectorizedBatch : vectorizedBatches) { // create the final ordered result result.addAll(vectorizedBatch.embeddings()); + sentBytes += vectorizedBatch.vectorizeUsageInfo().getRequestSize(); + receivedBytes += vectorizedBatch.vectorizeUsageInfo().getResponseSize(); + totalToken += vectorizedBatch.vectorizeUsageInfo().getTotalTokens(); + provider = vectorizedBatch.vectorizeUsageInfo().getProvider(); + modelName = vectorizedBatch.vectorizeUsageInfo().getModel(); } - return Response.of(1, result); + return Response.of( + 1, + result, + new VectorizeUsageInfo( + sentBytes, receivedBytes, totalToken, provider, modelName)); }) .invoke( () -> diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NetworkUsageInterceptor.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NetworkUsageInterceptor.java new file mode 100644 index 0000000000..1b8dab7000 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NetworkUsageInterceptor.java @@ -0,0 +1,55 @@ +package io.stargate.sgv2.jsonapi.service.embedding.operation; + +import com.fasterxml.jackson.databind.ObjectMapper; +import jakarta.ws.rs.client.ClientRequestContext; +import jakarta.ws.rs.client.ClientRequestFilter; +import jakarta.ws.rs.client.ClientResponseContext; +import jakarta.ws.rs.client.ClientResponseFilter; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.logging.Logger; + +public class NetworkUsageInterceptor implements ClientRequestFilter, ClientResponseFilter { + + private static final Logger LOGGER = Logger.getLogger(NetworkUsageInterceptor.class.getName()); + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); // Jackson object mapper + + @Override + public void filter(ClientRequestContext requestContext) throws IOException { + + // **3. Calculate Body size (if present)** + if (requestContext.hasEntity()) { + try { + byte[] requestBody = OBJECT_MAPPER.writeValueAsBytes(requestContext.getEntity()); + requestContext.setProperty("sentBytes", requestBody.length); + } catch (Exception e) { + LOGGER.warning("Failed to measure request body size: " + e.getMessage()); + } + } + } + + @Override + public void filter(ClientRequestContext requestContext, ClientResponseContext responseContext) + throws IOException { + int receivedBytes = 0; + int sentBytes = (int) requestContext.getProperty("sentBytes"); + if (responseContext.hasEntity()) { + // Read the response entity stream to measure its size + InputStream inputStream = responseContext.getEntityStream(); + ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + byte[] buffer = new byte[1024]; + int bytesRead; + while ((bytesRead = inputStream.read(buffer)) != -1) { + byteArrayOutputStream.write(buffer, 0, bytesRead); + receivedBytes += bytesRead; + } + responseContext.setEntityStream( + new ByteArrayInputStream(byteArrayOutputStream.toByteArray())); + } + LOGGER.info("Received Bytes: " + receivedBytes); + responseContext.getHeaders().add("sent-bytes", String.valueOf(sentBytes)); + responseContext.getHeaders().add("received-bytes", String.valueOf(receivedBytes)); + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAIEmbeddingProvider.java index 8c5077f2ff..1b5d328c7a 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAIEmbeddingProvider.java @@ -10,6 +10,7 @@ import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; import io.stargate.sgv2.jsonapi.service.embedding.operation.error.HttpResponseErrorMessageMapper; +import jakarta.enterprise.context.control.ActivateRequestContext; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; import jakarta.ws.rs.Path; @@ -23,6 +24,7 @@ import org.eclipse.microprofile.rest.client.annotation.RegisterProvider; import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; +@ActivateRequestContext public class OpenAIEmbeddingProvider extends EmbeddingProvider { private static final String providerId = ProviderConstants.OPENAI; private final OpenAIEmbeddingProviderClient openAIEmbeddingProviderClient; @@ -50,11 +52,12 @@ public OpenAIEmbeddingProvider( @RegisterRestClient @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(NetworkUsageInterceptor.class) public interface OpenAIEmbeddingProviderClient { @POST @Path("/embeddings") @ClientHeaderParam(name = "Content-Type", value = "application/json") - Uni embed( + Uni embed( @HeaderParam("Authorization") String accessToken, @HeaderParam("OpenAI-Organization") String organizationId, @HeaderParam("OpenAI-Project") String projectId, @@ -122,25 +125,36 @@ public Uni vectorize( String organizationId = (String) vectorizeServiceParameters.get("organizationId"); String projectId = (String) vectorizeServiceParameters.get("projectId"); - Uni response = + // ✅ Create an instance of NetworkUsageInfo and pass it to request properties + Uni response = applyRetry( openAIEmbeddingProviderClient.embed( "Bearer " + embeddingCredentials.apiKey().get(), organizationId, projectId, - request)); + request)); // Pass the object dynamically return response .onItem() .transform( - resp -> { - if (resp.data() == null) { - return Response.of(batchId, Collections.emptyList()); + res -> { + EmbeddingResponse embeddingResponse = res.readEntity(EmbeddingResponse.class); + int sentBytes = Integer.parseInt(res.getHeaderString("sent-bytes")); + int receivedBytes = Integer.parseInt(res.getHeaderString("received-bytes")); + VectorizeUsageInfo vectorizeUsageInfo = + new VectorizeUsageInfo( + sentBytes, + receivedBytes, + embeddingResponse.usage().total_tokens, + "openai", + modelName); + if (embeddingResponse.data() == null) { + return Response.of(batchId, Collections.emptyList(), vectorizeUsageInfo); } - Arrays.sort(resp.data(), (a, b) -> a.index() - b.index()); + Arrays.sort(embeddingResponse.data(), (a, b) -> a.index() - b.index()); List vectors = - Arrays.stream(resp.data()).map(data -> data.embedding()).toList(); - return Response.of(batchId, vectors); + Arrays.stream(embeddingResponse.data()).map(data -> data.embedding()).toList(); + return Response.of(batchId, vectors, vectorizeUsageInfo); }); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VectorizeUsageInfo.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VectorizeUsageInfo.java new file mode 100644 index 0000000000..21fa38feac --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VectorizeUsageInfo.java @@ -0,0 +1,41 @@ +package io.stargate.sgv2.jsonapi.service.embedding.operation; + +/** + * Used to track the metric at a request level to the embedding service + */ +public class VectorizeUsageInfo { + private int requestSize; + private int responseSize; + private int totalTokens; + private String provider; + private String model; + + public VectorizeUsageInfo( + int requestSize, int responseSize, int totalTokens, String provider, String model) { + this.requestSize = requestSize; + this.responseSize = responseSize; + this.totalTokens = totalTokens; + this.provider = provider; + this.model = model; + } + + public int getRequestSize() { + return requestSize; + } + + public int getResponseSize() { + return responseSize; + } + + public int getTotalTokens() { + return totalTokens; + } + + public String getProvider() { + return provider; + } + + public String getModel() { + return model; + } +} diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java index dd92b1affb..18bdeab79b 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java @@ -10,6 +10,7 @@ import io.quarkus.test.junit.TestProfile; import io.smallrye.mutiny.Uni; import io.smallrye.mutiny.helpers.test.UniAssertSubscriber; +import io.stargate.sgv2.jsonapi.api.model.command.VectorizeUsageBean; import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause; import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortExpression; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; @@ -35,6 +36,8 @@ @TestProfile(PropertyBasedOverrideProfile.class) public class DataVectorizerTest { @Inject ObjectMapper objectMapper; + @Inject + VectorizeUsageBean vectorizeUsageBean; private final EmbeddingProvider testService = new TestEmbeddingProvider(); private final CollectionSchemaObject collectionSettings = TestEmbeddingProvider.commandContextWithVectorize.schemaObject(); @@ -52,7 +55,11 @@ public void testTextValues() { } DataVectorizer dataVectorizer = new DataVectorizer( - testService, objectMapper.getNodeFactory(), embeddingCredentials, collectionSettings); + testService, + objectMapper.getNodeFactory(), + embeddingCredentials, + collectionSettings, + vectorizeUsageBean); try { dataVectorizer.vectorize(documents).subscribe().asCompletionStage().get(); } catch (Exception e) { @@ -79,7 +86,11 @@ public void testEmptyValues() { DataVectorizer dataVectorizer = new DataVectorizer( - testService, objectMapper.getNodeFactory(), embeddingCredentials, collectionSettings); + testService, + objectMapper.getNodeFactory(), + embeddingCredentials, + collectionSettings, + vectorizeUsageBean); try { dataVectorizer.vectorize(documents).subscribe().asCompletionStage().get(); } catch (Exception e) { @@ -109,7 +120,11 @@ public void testNonTextValues() { DataVectorizer dataVectorizer = new DataVectorizer( - testService, objectMapper.getNodeFactory(), embeddingCredentials, collectionSettings); + testService, + objectMapper.getNodeFactory(), + embeddingCredentials, + collectionSettings, + vectorizeUsageBean); try { Throwable failure = dataVectorizer @@ -138,7 +153,11 @@ public void testNullValues() { DataVectorizer dataVectorizer = new DataVectorizer( - testService, objectMapper.getNodeFactory(), embeddingCredentials, collectionSettings); + testService, + objectMapper.getNodeFactory(), + embeddingCredentials, + collectionSettings, + vectorizeUsageBean); try { dataVectorizer.vectorize(documents).subscribe().asCompletionStage().get(); } catch (Exception e) { @@ -161,7 +180,11 @@ public void testWithBothVectorFieldValues() { documents.add(document); DataVectorizer dataVectorizer = new DataVectorizer( - testService, objectMapper.getNodeFactory(), embeddingCredentials, collectionSettings); + testService, + objectMapper.getNodeFactory(), + embeddingCredentials, + collectionSettings, + vectorizeUsageBean); try { Throwable failure = dataVectorizer @@ -207,7 +230,8 @@ public Uni vectorize( testProvider, objectMapper.getNodeFactory(), embeddingCredentials, - collectionSettings); + collectionSettings, + vectorizeUsageBean); Throwable failure = dataVectorizer @@ -249,7 +273,11 @@ public void testWithUnmatchedVectorSize() { } DataVectorizer dataVectorizer = new DataVectorizer( - testService, objectMapper.getNodeFactory(), embeddingCredentials, collectionSettings); + testService, + objectMapper.getNodeFactory(), + embeddingCredentials, + collectionSettings, + vectorizeUsageBean); Throwable failure = dataVectorizer @@ -278,7 +306,11 @@ public void sortClauseValues() { SortClause sortClause = new SortClause(sortExpressions); DataVectorizer dataVectorizer = new DataVectorizer( - testService, objectMapper.getNodeFactory(), embeddingCredentials, collectionSettings); + testService, + objectMapper.getNodeFactory(), + embeddingCredentials, + collectionSettings, + vectorizeUsageBean); try { dataVectorizer.vectorize(sortClause).subscribe().asCompletionStage().get(); } catch (Exception e) { @@ -301,7 +333,11 @@ public void vectorize() { arrayNode.add(objectMapper.getNodeFactory().numberNode(0.11f)); DataVectorizer dataVectorizer = new DataVectorizer( - testService, objectMapper.getNodeFactory(), embeddingCredentials, collectionSettings); + testService, + objectMapper.getNodeFactory(), + embeddingCredentials, + collectionSettings, + vectorizeUsageBean); try { final float[] testData = dataVectorizer.vectorize("test data").subscribe().asCompletionStage().get(); From 5efe8da070b7c80c789c91f611d4b2d6d7455d31 Mon Sep 17 00:00:00 2001 From: "mahesh.rajamani" Date: Tue, 11 Feb 2025 12:45:22 -0500 Subject: [PATCH 02/11] Formatted the classes --- .../sgv2/jsonapi/api/model/command/VectorizeUsageBean.java | 4 +--- .../io/stargate/sgv2/jsonapi/api/v1/CollectionResource.java | 2 +- .../sgv2/jsonapi/service/embedding/DataVectorizer.java | 2 +- .../sgv2/jsonapi/service/embedding/DataVectorizerService.java | 2 +- .../service/embedding/operation/NetworkUsageInterceptor.java | 3 +-- .../service/embedding/operation/VectorizeUsageInfo.java | 4 +--- .../service/embedding/operation/DataVectorizerTest.java | 3 +-- 7 files changed, 7 insertions(+), 13 deletions(-) diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/VectorizeUsageBean.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/VectorizeUsageBean.java index a739b8a06e..f0e699bea2 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/VectorizeUsageBean.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/VectorizeUsageBean.java @@ -2,9 +2,7 @@ import jakarta.enterprise.context.RequestScoped; -/** - * Bean class which is used to return as response json in header - */ +/** Bean class which is used to return as response json in header */ @RequestScoped public class VectorizeUsageBean { private int requestSize; diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/v1/CollectionResource.java b/src/main/java/io/stargate/sgv2/jsonapi/api/v1/CollectionResource.java index 9d3952025e..6a412750eb 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/v1/CollectionResource.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/v1/CollectionResource.java @@ -7,6 +7,7 @@ import com.google.common.base.Strings; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.model.command.*; +import io.stargate.sgv2.jsonapi.api.model.command.VectorizeUsageBean; import io.stargate.sgv2.jsonapi.api.model.command.impl.AlterTableCommand; import io.stargate.sgv2.jsonapi.api.model.command.impl.CountDocumentsCommand; import io.stargate.sgv2.jsonapi.api.model.command.impl.CreateIndexCommand; @@ -39,7 +40,6 @@ import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorColumnDefinition; import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProvider; import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProviderFactory; -import io.stargate.sgv2.jsonapi.api.model.command.VectorizeUsageBean; import io.stargate.sgv2.jsonapi.service.processor.MeteredCommandProcessor; import jakarta.inject.Inject; import jakarta.validation.Valid; diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java index 183d7d1f09..69e7a4a9e1 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java @@ -9,6 +9,7 @@ import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.fasterxml.jackson.databind.node.ObjectNode; import io.smallrye.mutiny.Uni; +import io.stargate.sgv2.jsonapi.api.model.command.VectorizeUsageBean; import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause; import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortExpression; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; @@ -19,7 +20,6 @@ import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorColumnDefinition; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorConfig; import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProvider; -import io.stargate.sgv2.jsonapi.api.model.command.VectorizeUsageBean; import io.stargate.sgv2.jsonapi.service.schema.tables.ApiColumnDef; import io.stargate.sgv2.jsonapi.service.schema.tables.ApiTypeName; import io.stargate.sgv2.jsonapi.service.schema.tables.ApiVectorType; diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizerService.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizerService.java index 0dcbcd2e8c..6798f8aab2 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizerService.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizerService.java @@ -9,6 +9,7 @@ import io.micrometer.core.instrument.MeterRegistry; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.model.command.*; +import io.stargate.sgv2.jsonapi.api.model.command.VectorizeUsageBean; import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortExpression; import io.stargate.sgv2.jsonapi.api.model.command.clause.update.UpdateClause; import io.stargate.sgv2.jsonapi.api.model.command.clause.update.UpdateOperator; @@ -22,7 +23,6 @@ import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject; import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProvider; import io.stargate.sgv2.jsonapi.service.embedding.operation.MeteredEmbeddingProvider; -import io.stargate.sgv2.jsonapi.api.model.command.VectorizeUsageBean; import io.stargate.sgv2.jsonapi.service.schema.tables.ApiColumnDef; import io.stargate.sgv2.jsonapi.service.schema.tables.ApiSupportDef; import io.stargate.sgv2.jsonapi.service.schema.tables.ApiTypeName; diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NetworkUsageInterceptor.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NetworkUsageInterceptor.java index 1b8dab7000..bc4436db04 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NetworkUsageInterceptor.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NetworkUsageInterceptor.java @@ -18,8 +18,7 @@ public class NetworkUsageInterceptor implements ClientRequestFilter, ClientRespo @Override public void filter(ClientRequestContext requestContext) throws IOException { - - // **3. Calculate Body size (if present)** + // **1. Calculate Body size (if present)** if (requestContext.hasEntity()) { try { byte[] requestBody = OBJECT_MAPPER.writeValueAsBytes(requestContext.getEntity()); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VectorizeUsageInfo.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VectorizeUsageInfo.java index 21fa38feac..8f2aa48b8c 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VectorizeUsageInfo.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VectorizeUsageInfo.java @@ -1,8 +1,6 @@ package io.stargate.sgv2.jsonapi.service.embedding.operation; -/** - * Used to track the metric at a request level to the embedding service - */ +/** Used to track the metric at a request level to the embedding service */ public class VectorizeUsageInfo { private int requestSize; private int responseSize; diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java index 18bdeab79b..7f1973c2f5 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java @@ -36,8 +36,7 @@ @TestProfile(PropertyBasedOverrideProfile.class) public class DataVectorizerTest { @Inject ObjectMapper objectMapper; - @Inject - VectorizeUsageBean vectorizeUsageBean; + @Inject VectorizeUsageBean vectorizeUsageBean; private final EmbeddingProvider testService = new TestEmbeddingProvider(); private final CollectionSchemaObject collectionSettings = TestEmbeddingProvider.commandContextWithVectorize.schemaObject(); From 7113bffcf45bf3cd71211cd829ff62b288092816 Mon Sep 17 00:00:00 2001 From: "mahesh.rajamani" Date: Wed, 12 Feb 2025 10:23:16 -0500 Subject: [PATCH 03/11] Refactored the `NetworkUsageInterceptor` as per suggestions --- .../operation/NetworkUsageInterceptor.java | 48 ++++++++++--------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NetworkUsageInterceptor.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NetworkUsageInterceptor.java index bc4436db04..ed17b13c9e 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NetworkUsageInterceptor.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NetworkUsageInterceptor.java @@ -1,8 +1,8 @@ package io.stargate.sgv2.jsonapi.service.embedding.operation; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.io.CountingOutputStream; import jakarta.ws.rs.client.ClientRequestContext; -import jakarta.ws.rs.client.ClientRequestFilter; import jakarta.ws.rs.client.ClientResponseContext; import jakarta.ws.rs.client.ClientResponseFilter; import java.io.ByteArrayInputStream; @@ -11,42 +11,44 @@ import java.io.InputStream; import java.util.logging.Logger; -public class NetworkUsageInterceptor implements ClientRequestFilter, ClientResponseFilter { +public class NetworkUsageInterceptor implements ClientResponseFilter { private static final Logger LOGGER = Logger.getLogger(NetworkUsageInterceptor.class.getName()); private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); // Jackson object mapper @Override - public void filter(ClientRequestContext requestContext) throws IOException { - // **1. Calculate Body size (if present)** + public void filter(ClientRequestContext requestContext, ClientResponseContext responseContext) + throws IOException { + int receivedBytes = 0; + int sentBytes = 0; if (requestContext.hasEntity()) { try { - byte[] requestBody = OBJECT_MAPPER.writeValueAsBytes(requestContext.getEntity()); - requestContext.setProperty("sentBytes", requestBody.length); + ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + CountingOutputStream cus = new CountingOutputStream(byteArrayOutputStream); + OBJECT_MAPPER.writeValue(cus, requestContext.getEntity()); + cus.close(); + sentBytes = (int) cus.getCount(); } catch (Exception e) { LOGGER.warning("Failed to measure request body size: " + e.getMessage()); } } - } - - @Override - public void filter(ClientRequestContext requestContext, ClientResponseContext responseContext) - throws IOException { - int receivedBytes = 0; - int sentBytes = (int) requestContext.getProperty("sentBytes"); if (responseContext.hasEntity()) { - // Read the response entity stream to measure its size - InputStream inputStream = responseContext.getEntityStream(); - ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); - byte[] buffer = new byte[1024]; - int bytesRead; - while ((bytesRead = inputStream.read(buffer)) != -1) { - byteArrayOutputStream.write(buffer, 0, bytesRead); - receivedBytes += bytesRead; + receivedBytes = responseContext.getLength(); + if (receivedBytes <= 0) { + // Read the response entity stream to measure its size + InputStream inputStream = responseContext.getEntityStream(); + ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + byte[] buffer = new byte[1024]; + int bytesRead; + while ((bytesRead = inputStream.read(buffer)) != -1) { + byteArrayOutputStream.write(buffer, 0, bytesRead); + receivedBytes += bytesRead; + } + responseContext.setEntityStream( + new ByteArrayInputStream(byteArrayOutputStream.toByteArray())); } - responseContext.setEntityStream( - new ByteArrayInputStream(byteArrayOutputStream.toByteArray())); } + LOGGER.info("Received Bytes: " + receivedBytes); responseContext.getHeaders().add("sent-bytes", String.valueOf(sentBytes)); responseContext.getHeaders().add("received-bytes", String.valueOf(receivedBytes)); From 96c40f2a06735c820cf3b7a94c2d895e8ec32b83 Mon Sep 17 00:00:00 2001 From: "mahesh.rajamani" Date: Wed, 12 Feb 2025 10:31:11 -0500 Subject: [PATCH 04/11] Refactored the `NetworkUsageInterceptor` as per suggestions --- .../embedding/operation/NetworkUsageInterceptor.java | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NetworkUsageInterceptor.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NetworkUsageInterceptor.java index ed17b13c9e..26eff247f1 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NetworkUsageInterceptor.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NetworkUsageInterceptor.java @@ -1,11 +1,11 @@ package io.stargate.sgv2.jsonapi.service.embedding.operation; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.io.ByteStreams; import com.google.common.io.CountingOutputStream; import jakarta.ws.rs.client.ClientRequestContext; import jakarta.ws.rs.client.ClientResponseContext; import jakarta.ws.rs.client.ClientResponseFilter; -import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; @@ -38,18 +38,10 @@ public void filter(ClientRequestContext requestContext, ClientResponseContext re // Read the response entity stream to measure its size InputStream inputStream = responseContext.getEntityStream(); ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); - byte[] buffer = new byte[1024]; - int bytesRead; - while ((bytesRead = inputStream.read(buffer)) != -1) { - byteArrayOutputStream.write(buffer, 0, bytesRead); - receivedBytes += bytesRead; - } - responseContext.setEntityStream( - new ByteArrayInputStream(byteArrayOutputStream.toByteArray())); + receivedBytes = (int) ByteStreams.copy(inputStream, byteArrayOutputStream); } } - LOGGER.info("Received Bytes: " + receivedBytes); responseContext.getHeaders().add("sent-bytes", String.valueOf(sentBytes)); responseContext.getHeaders().add("received-bytes", String.valueOf(receivedBytes)); } From b0d6b3452464dd48f0fcab44d7da75470912adec Mon Sep 17 00:00:00 2001 From: "mahesh.rajamani" Date: Wed, 12 Feb 2025 11:28:46 -0500 Subject: [PATCH 05/11] Fixed the byte size calculation and also fixed the Logger to use slf4j --- .../operation/NetworkUsageInterceptor.java | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NetworkUsageInterceptor.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NetworkUsageInterceptor.java index 26eff247f1..3206e05d8c 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NetworkUsageInterceptor.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NetworkUsageInterceptor.java @@ -6,14 +6,15 @@ import jakarta.ws.rs.client.ClientRequestContext; import jakarta.ws.rs.client.ClientResponseContext; import jakarta.ws.rs.client.ClientResponseFilter; -import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; -import java.util.logging.Logger; +import java.io.OutputStream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class NetworkUsageInterceptor implements ClientResponseFilter { - private static final Logger LOGGER = Logger.getLogger(NetworkUsageInterceptor.class.getName()); + private static final Logger LOGGER = LoggerFactory.getLogger(NetworkUsageInterceptor.class); private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); // Jackson object mapper @Override @@ -23,13 +24,12 @@ public void filter(ClientRequestContext requestContext, ClientResponseContext re int sentBytes = 0; if (requestContext.hasEntity()) { try { - ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); - CountingOutputStream cus = new CountingOutputStream(byteArrayOutputStream); + CountingOutputStream cus = new CountingOutputStream(OutputStream.nullOutputStream()); OBJECT_MAPPER.writeValue(cus, requestContext.getEntity()); cus.close(); sentBytes = (int) cus.getCount(); } catch (Exception e) { - LOGGER.warning("Failed to measure request body size: " + e.getMessage()); + LOGGER.warn("Failed to measure request body size: " + e.getMessage()); } } if (responseContext.hasEntity()) { @@ -37,8 +37,7 @@ public void filter(ClientRequestContext requestContext, ClientResponseContext re if (receivedBytes <= 0) { // Read the response entity stream to measure its size InputStream inputStream = responseContext.getEntityStream(); - ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); - receivedBytes = (int) ByteStreams.copy(inputStream, byteArrayOutputStream); + receivedBytes = (int) ByteStreams.copy(inputStream, OutputStream.nullOutputStream()); } } From 1ad808e2cfd11a7c61c5115a6ca60327966c78b8 Mon Sep 17 00:00:00 2001 From: "mahesh.rajamani" Date: Wed, 12 Feb 2025 13:22:40 -0500 Subject: [PATCH 06/11] Added the bytes usage and token usage to places where available. --- .../service/embedding/DataVectorizer.java | 32 ++++----- .../gateway/EmbeddingGatewayClient.java | 6 +- .../AwsBedrockEmbeddingProvider.java | 16 ++++- .../AzureOpenAIEmbeddingProvider.java | 29 ++++++-- .../operation/CohereEmbeddingProvider.java | 49 +++++++++++-- .../operation/EmbeddingProvider.java | 12 +--- ...HuggingFaceDedicatedEmbeddingProvider.java | 29 ++++++-- .../HuggingFaceEmbeddingProvider.java | 32 +++++++-- .../operation/JinaAIEmbeddingProvider.java | 30 +++++--- .../operation/MeteredEmbeddingProvider.java | 18 +---- .../operation/MistralEmbeddingProvider.java | 32 ++++++--- .../operation/NvidiaEmbeddingProvider.java | 32 ++++++--- .../operation/OpenAIEmbeddingProvider.java | 22 +++--- .../operation/UpstageAIEmbeddingProvider.java | 32 ++++++--- .../embedding/operation/VectorizeUsage.java | 69 +++++++++++++++++++ .../operation/VectorizeUsageInfo.java | 39 ----------- .../operation/VertexAIEmbeddingProvider.java | 28 +++++--- .../operation/VoyageAIEmbeddingProvider.java | 29 ++++++-- .../test/CustomITEmbeddingProvider.java | 6 +- .../operation/DataVectorizerTest.java | 3 +- .../operation/TestEmbeddingProvider.java | 2 +- 21 files changed, 370 insertions(+), 177 deletions(-) create mode 100644 src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VectorizeUsage.java delete mode 100644 src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VectorizeUsageInfo.java diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java index 69e7a4a9e1..3e843b9246 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java @@ -113,12 +113,11 @@ public Uni vectorize(List documents) { EmbeddingProvider.EmbeddingRequestType.INDEX) .map( res -> { - vectorizeUsageBean.setRequestSize(res.vectorizeUsageInfo().getRequestSize()); - vectorizeUsageBean.setResponseSize( - res.vectorizeUsageInfo().getResponseSize()); - vectorizeUsageBean.setTotalTokens(res.vectorizeUsageInfo().getTotalTokens()); - vectorizeUsageBean.setProvider(res.vectorizeUsageInfo().getProvider()); - vectorizeUsageBean.setModel(res.vectorizeUsageInfo().getModel()); + vectorizeUsageBean.setRequestSize(res.vectorizeUsage().getRequestBytes()); + vectorizeUsageBean.setResponseSize(res.vectorizeUsage().getResponseBytes()); + vectorizeUsageBean.setTotalTokens(res.vectorizeUsage().getTotalTokens()); + vectorizeUsageBean.setProvider(res.vectorizeUsage().getProvider()); + vectorizeUsageBean.setModel(res.vectorizeUsage().getModel()); return res.embeddings(); }); return vectors @@ -189,11 +188,11 @@ public Uni vectorize(String vectorizeContent) { EmbeddingProvider.EmbeddingRequestType.INDEX) .map( res -> { - vectorizeUsageBean.setRequestSize(res.vectorizeUsageInfo().getRequestSize()); - vectorizeUsageBean.setResponseSize(res.vectorizeUsageInfo().getResponseSize()); - vectorizeUsageBean.setTotalTokens(res.vectorizeUsageInfo().getTotalTokens()); - vectorizeUsageBean.setProvider(res.vectorizeUsageInfo().getProvider()); - vectorizeUsageBean.setModel(res.vectorizeUsageInfo().getModel()); + vectorizeUsageBean.setRequestSize(res.vectorizeUsage().getRequestBytes()); + vectorizeUsageBean.setResponseSize(res.vectorizeUsage().getResponseBytes()); + vectorizeUsageBean.setTotalTokens(res.vectorizeUsage().getTotalTokens()); + vectorizeUsageBean.setProvider(res.vectorizeUsage().getProvider()); + vectorizeUsageBean.setModel(res.vectorizeUsage().getModel()); return res.embeddings(); }); return vectors @@ -244,12 +243,11 @@ public Uni vectorize(SortClause sortClause) { EmbeddingProvider.EmbeddingRequestType.SEARCH) .map( res -> { - vectorizeUsageBean.setRequestSize(res.vectorizeUsageInfo().getRequestSize()); - vectorizeUsageBean.setResponseSize( - res.vectorizeUsageInfo().getResponseSize()); - vectorizeUsageBean.setTotalTokens(res.vectorizeUsageInfo().getTotalTokens()); - vectorizeUsageBean.setProvider(res.vectorizeUsageInfo().getProvider()); - vectorizeUsageBean.setModel(res.vectorizeUsageInfo().getModel()); + vectorizeUsageBean.setRequestSize(res.vectorizeUsage().getRequestBytes()); + vectorizeUsageBean.setResponseSize(res.vectorizeUsage().getResponseBytes()); + vectorizeUsageBean.setTotalTokens(res.vectorizeUsage().getTotalTokens()); + vectorizeUsageBean.setProvider(res.vectorizeUsage().getProvider()); + vectorizeUsageBean.setModel(res.vectorizeUsage().getModel()); return res.embeddings(); }); return vectors diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/gateway/EmbeddingGatewayClient.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/gateway/EmbeddingGatewayClient.java index 86ccfc4b21..9ec67b03a6 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/gateway/EmbeddingGatewayClient.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/gateway/EmbeddingGatewayClient.java @@ -10,6 +10,7 @@ import io.stargate.sgv2.jsonapi.exception.JsonApiException; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProvider; +import io.stargate.sgv2.jsonapi.service.embedding.operation.VectorizeUsage; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -190,7 +191,8 @@ else if (value instanceof Boolean) resp.getError().getErrorMessage()); } if (resp.getEmbeddingsList() == null) { - return Response.of(batchId, Collections.emptyList()); + return new Response( + batchId, Collections.emptyList(), new VectorizeUsage(provider, modelName)); } final List vectors = resp.getEmbeddingsList().stream() @@ -203,7 +205,7 @@ else if (value instanceof Boolean) return embedding; }) .toList(); - return Response.of(batchId, vectors); + return new Response(batchId, vectors, new VectorizeUsage(provider, modelName)); }); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AwsBedrockEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AwsBedrockEmbeddingProvider.java index e32819a680..585b03d67b 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AwsBedrockEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AwsBedrockEmbeddingProvider.java @@ -9,12 +9,15 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectReader; import com.fasterxml.jackson.databind.ObjectWriter; +import com.google.common.io.ByteStreams; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; @@ -78,12 +81,14 @@ public Uni vectorize( .credentialsProvider(StaticCredentialsProvider.create(awsCreds)) .region(Region.of(vectorizeServiceParameters.get("region").toString())) .build(); + final VectorizeUsage vectorizeUsage = new VectorizeUsage(ProviderConstants.BEDROCK, modelName); final CompletableFuture invokeModelResponseCompletableFuture = client.invokeModel( request -> { final byte[] inputData; try { inputData = ow.writeValueAsBytes(new EmbeddingRequest(texts.get(0), dimension)); + vectorizeUsage.setRequestBytes(inputData.length); request.body(SdkBytes.fromByteArray(inputData)).modelId(modelName); } catch (JsonProcessingException e) { throw ErrorCodeV1.EMBEDDING_REQUEST_ENCODING_ERROR.toApiException(); @@ -94,10 +99,15 @@ public Uni vectorize( invokeModelResponseCompletableFuture.thenApply( res -> { try { - EmbeddingResponse response = - or.readValue(res.body().asInputStream(), EmbeddingResponse.class); + InputStream inputStream = res.body().asInputStream(); + int receivedBytes = + (int) ByteStreams.copy(inputStream, OutputStream.nullOutputStream()); + EmbeddingResponse response = or.readValue(inputStream, EmbeddingResponse.class); + vectorizeUsage.setResponseBytes(receivedBytes); List vectors = List.of(response.embedding); - return Response.of(batchId, vectors); + // Note: This is input tokens + vectorizeUsage.setTotalTokens(response.inputTextTokenCount()); + return new Response(batchId, vectors, vectorizeUsage); } catch (IOException e) { throw ErrorCodeV1.EMBEDDING_RESPONSE_DECODING_ERROR.toApiException(); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AzureOpenAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AzureOpenAIEmbeddingProvider.java index dcc655e1f8..549b926d6b 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AzureOpenAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AzureOpenAIEmbeddingProvider.java @@ -12,6 +12,7 @@ import io.stargate.sgv2.jsonapi.service.embedding.operation.error.HttpResponseErrorMessageMapper; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.Arrays; import java.util.Collections; @@ -55,11 +56,12 @@ public AzureOpenAIEmbeddingProvider( @RegisterRestClient @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(NetworkUsageInterceptor.class) public interface OpenAIEmbeddingProviderClient { @POST // no path specified, as it is already included in the baseUri @ClientHeaderParam(name = "Content-Type", value = "application/json") - Uni embed( + Uni embed( // API keys as "api-key", MS Entra as "Authorization: Bearer [token] @HeaderParam("api-key") String accessToken, EmbeddingRequest request); @@ -120,7 +122,7 @@ public Uni vectorize( EmbeddingRequest request = new EmbeddingRequest(texts.toArray(textArray), modelName, dimension); // NOTE: NO "Bearer " prefix with API key for Azure OpenAI - Uni response = + Uni response = applyRetry( openAIEmbeddingProviderClient.embed(embeddingCredentials.apiKey().get(), request)); @@ -128,13 +130,26 @@ public Uni vectorize( .onItem() .transform( resp -> { - if (resp.data() == null) { - return Response.of(batchId, Collections.emptyList()); + EmbeddingResponse embeddingResponse = resp.readEntity(EmbeddingResponse.class); + if (embeddingResponse.data() == null) { + return new Response( + batchId, + Collections.emptyList(), + new VectorizeUsage(ProviderConstants.AZURE_OPENAI, modelName)); } - Arrays.sort(resp.data(), (a, b) -> a.index() - b.index()); + int sentBytes = Integer.parseInt(resp.getHeaderString("sent-bytes")); + int receivedBytes = Integer.parseInt(resp.getHeaderString("received-bytes")); + VectorizeUsage vectorizeUsage = + new VectorizeUsage( + sentBytes, + receivedBytes, + embeddingResponse.usage().total_tokens(), + ProviderConstants.AZURE_OPENAI, + modelName); + Arrays.sort(embeddingResponse.data(), (a, b) -> a.index() - b.index()); List vectors = - Arrays.stream(resp.data()).map(data -> data.embedding()).toList(); - return Response.of(batchId, vectors); + Arrays.stream(embeddingResponse.data()).map(data -> data.embedding()).toList(); + return new Response(batchId, vectors, vectorizeUsage); }); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/CohereEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/CohereEmbeddingProvider.java index c03d0a7138..fa17bee637 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/CohereEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/CohereEmbeddingProvider.java @@ -13,6 +13,7 @@ import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; import jakarta.ws.rs.Path; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.Collections; import java.util.List; @@ -47,11 +48,12 @@ public CohereEmbeddingProvider( @RegisterRestClient @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(NetworkUsageInterceptor.class) public interface CohereEmbeddingProviderClient { @POST @Path("/embed") @ClientHeaderParam(name = "Content-Type", value = "application/json") - Uni embed( + Uni embed( @HeaderParam("Authorization") String accessToken, EmbeddingRequest request); @ClientExceptionMapper @@ -108,6 +110,8 @@ protected EmbeddingResponse() {} private List embeddings; + private BilledUnits billed_units; + public List getEmbeddings() { return embeddings; } @@ -115,6 +119,28 @@ public List getEmbeddings() { public void setEmbeddings(List embeddings) { this.embeddings = embeddings; } + + public BilledUnits getBilled_units() { + return billed_units; + } + + public void setBilled_units(BilledUnits billed_units) { + this.billed_units = billed_units; + } + + private static class BilledUnits { + public int input_tokens; + + public BilledUnits() {} + + public int getInput_tokens() { + return input_tokens; + } + + public void setInput_tokens(int input_tokens) { + this.input_tokens = input_tokens; + } + } } // Input type to be used for vector search should "search_query" @@ -135,7 +161,7 @@ public Uni vectorize( EmbeddingRequest request = new EmbeddingRequest(texts.toArray(textArray), modelName, input_type); - Uni response = + Uni response = applyRetry( cohereEmbeddingProviderClient.embed( "Bearer " + embeddingCredentials.apiKey().get(), request)); @@ -144,10 +170,23 @@ public Uni vectorize( .onItem() .transform( resp -> { - if (resp.getEmbeddings() == null) { - return Response.of(batchId, Collections.emptyList()); + EmbeddingResponse embeddingResponse = resp.readEntity(EmbeddingResponse.class); + if (embeddingResponse.getEmbeddings() == null) { + return new Response( + batchId, + Collections.emptyList(), + new VectorizeUsage(ProviderConstants.COHERE, modelName)); } - return Response.of(batchId, resp.getEmbeddings()); + int sentBytes = Integer.parseInt(resp.getHeaderString("sent-bytes")); + int receivedBytes = Integer.parseInt(resp.getHeaderString("received-bytes")); + VectorizeUsage vectorizeUsage = + new VectorizeUsage( + sentBytes, + receivedBytes, + embeddingResponse.getBilled_units().getInput_tokens(), + ProviderConstants.COHERE, + modelName); + return new Response(batchId, embeddingResponse.getEmbeddings(), vectorizeUsage); }); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProvider.java index 3ee91bcdb0..aedbba6ed8 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProvider.java @@ -175,17 +175,7 @@ protected void checkEmbeddingApiKeyHeader(String providerId, Optional ap * @param batchId - Sequence number for the batch to order the vectors. * @param embeddings - Embedding vectors for the given text inputs. */ - public record Response( - int batchId, List embeddings, VectorizeUsageInfo vectorizeUsageInfo) { - public static Response of(int batchId, List embeddings) { - return new Response(batchId, embeddings, new VectorizeUsageInfo(0, 0, 0, "", "")); - } - - public static Response of( - int batchId, List embeddings, VectorizeUsageInfo vectorizeUsageInfo) { - return new Response(batchId, embeddings, vectorizeUsageInfo); - } - } + public record Response(int batchId, List embeddings, VectorizeUsage vectorizeUsage) {} public enum EmbeddingRequestType { /** This is used when vectorizing data in write operation for indexing */ diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceDedicatedEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceDedicatedEmbeddingProvider.java index f8c2d01fd6..b591461ddf 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceDedicatedEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceDedicatedEmbeddingProvider.java @@ -11,6 +11,7 @@ import io.stargate.sgv2.jsonapi.service.embedding.operation.error.HttpResponseErrorMessageMapper; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.*; import java.util.concurrent.TimeUnit; @@ -42,10 +43,11 @@ public HuggingFaceDedicatedEmbeddingProvider( @RegisterRestClient @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(NetworkUsageInterceptor.class) public interface HuggingFaceDedicatedEmbeddingProviderClient { @POST @ClientHeaderParam(name = "Content-Type", value = "application/json") - Uni embed( + Uni embed( @HeaderParam("Authorization") String accessToken, EmbeddingRequest request); @ClientExceptionMapper @@ -107,7 +109,7 @@ public Uni vectorize( String[] textArray = new String[texts.size()]; EmbeddingRequest request = new EmbeddingRequest(texts.toArray(textArray)); - Uni response = + Uni response = applyRetry( huggingFaceDedicatedEmbeddingProviderClient.embed( "Bearer " + embeddingCredentials.apiKey().get(), request)); @@ -116,13 +118,26 @@ public Uni vectorize( .onItem() .transform( resp -> { - if (resp.data() == null) { - return Response.of(batchId, Collections.emptyList()); + EmbeddingResponse embeddingResponse = resp.readEntity(EmbeddingResponse.class); + if (embeddingResponse.data() == null) { + return new Response( + batchId, + Collections.emptyList(), + new VectorizeUsage(ProviderConstants.HUGGINGFACE_DEDICATED, modelName)); } - Arrays.sort(resp.data(), (a, b) -> a.index() - b.index()); + Arrays.sort(embeddingResponse.data(), (a, b) -> a.index() - b.index()); List vectors = - Arrays.stream(resp.data()).map(data -> data.embedding()).toList(); - return Response.of(batchId, vectors); + Arrays.stream(embeddingResponse.data()).map(data -> data.embedding()).toList(); + int sentBytes = Integer.parseInt(resp.getHeaderString("sent-bytes")); + int receivedBytes = Integer.parseInt(resp.getHeaderString("received-bytes")); + VectorizeUsage vectorizeUsage = + new VectorizeUsage( + sentBytes, + receivedBytes, + embeddingResponse.usage().total_tokens(), + ProviderConstants.COHERE, + modelName); + return new Response(batchId, vectors, vectorizeUsage); }); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceEmbeddingProvider.java index 7351267226..1299c43d7b 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceEmbeddingProvider.java @@ -1,6 +1,9 @@ package io.stargate.sgv2.jsonapi.service.embedding.operation; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; import io.quarkus.rest.client.reactive.ClientExceptionMapper; import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; import io.smallrye.mutiny.Uni; @@ -13,6 +16,7 @@ import jakarta.ws.rs.POST; import jakarta.ws.rs.Path; import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.Collections; import java.util.List; @@ -25,6 +29,7 @@ public class HuggingFaceEmbeddingProvider extends EmbeddingProvider { private static final String providerId = ProviderConstants.HUGGINGFACE; private final HuggingFaceEmbeddingProviderClient huggingFaceEmbeddingProviderClient; + private static final ObjectMapper objectMapper = new ObjectMapper(); public HuggingFaceEmbeddingProvider( EmbeddingProviderConfigStore.RequestProperties requestProperties, @@ -43,11 +48,12 @@ public HuggingFaceEmbeddingProvider( @RegisterRestClient @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(NetworkUsageInterceptor.class) public interface HuggingFaceEmbeddingProviderClient { @POST @Path("/{modelId}") @ClientHeaderParam(name = "Content-Type", value = "application/json") - Uni> embed( + Uni embed( @HeaderParam("Authorization") String accessToken, @PathParam("modelId") String modelId, EmbeddingRequest request); @@ -76,8 +82,7 @@ private static String getErrorMessage(jakarta.ws.rs.core.Response response) { JsonNode rootNode = response.readEntity(JsonNode.class); // Log the response body logger.info( - String.format( - "Error response from embedding provider '%s': %s", providerId, rootNode.toString())); + "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); // Extract the "error" node JsonNode errorNode = rootNode.path("error"); // Return the text of the "message" node, or the whole response body if it is missing @@ -104,10 +109,25 @@ public Uni vectorize( .onItem() .transform( resp -> { - if (resp == null) { - return Response.of(batchId, Collections.emptyList()); + String json = resp.readEntity(String.class); // Read raw JSON + if (json == null) { + return new Response( + batchId, + Collections.emptyList(), + new VectorizeUsage(ProviderConstants.HUGGINGFACE, modelName)); } - return Response.of(batchId, resp); + List embeddings = null; + try { + embeddings = objectMapper.readValue(json, new TypeReference>() {}); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + int sentBytes = Integer.parseInt(resp.getHeaderString("sent-bytes")); + int receivedBytes = Integer.parseInt(resp.getHeaderString("received-bytes")); + VectorizeUsage vectorizeUsage = + new VectorizeUsage( + sentBytes, receivedBytes, 0, ProviderConstants.HUGGINGFACE, modelName); + return new Response(batchId, embeddings, vectorizeUsage); }); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/JinaAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/JinaAIEmbeddingProvider.java index 741ebbe495..f366e3d726 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/JinaAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/JinaAIEmbeddingProvider.java @@ -12,6 +12,7 @@ import io.stargate.sgv2.jsonapi.service.embedding.operation.error.HttpResponseErrorMessageMapper; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.*; import java.util.concurrent.TimeUnit; @@ -50,10 +51,11 @@ public JinaAIEmbeddingProvider( @RegisterRestClient @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(NetworkUsageInterceptor.class) public interface JinaAIEmbeddingProviderClient { @POST @ClientHeaderParam(name = "Content-Type", value = "application/json") - Uni embed( + Uni embed( @HeaderParam("Authorization") String accessToken, EmbeddingRequest request); @ClientExceptionMapper @@ -83,8 +85,7 @@ private static String getErrorMessage(jakarta.ws.rs.core.Response response) { JsonNode rootNode = response.readEntity(JsonNode.class); // Log the response body logger.info( - String.format( - "Error response from embedding provider '%s': %s", providerId, rootNode.toString())); + "Error response from embedding provider '{}}': {}", providerId, rootNode.toString()); // Extract the "detail" node JsonNode detailNode = rootNode.path("detail"); return detailNode.isMissingNode() ? rootNode.toString() : detailNode.toString(); @@ -121,7 +122,7 @@ public Uni vectorize( (String) vectorizeServiceParameters.get("task"), (Boolean) vectorizeServiceParameters.get("late_chunking")); - Uni response = + Uni response = applyRetry( jinaAIEmbeddingProviderClient.embed( "Bearer " + embeddingCredentials.apiKey().get(), request)); @@ -130,13 +131,26 @@ public Uni vectorize( .onItem() .transform( resp -> { - if (resp.data() == null) { - return Response.of(batchId, Collections.emptyList()); + EmbeddingResponse embeddingResponse = resp.readEntity(EmbeddingResponse.class); + if (embeddingResponse.data() == null) { + return new Response( + batchId, + Collections.emptyList(), + new VectorizeUsage(ProviderConstants.JINA_AI, modelName)); } - Arrays.sort(resp.data(), (a, b) -> a.index() - b.index()); + Arrays.sort(embeddingResponse.data(), (a, b) -> a.index() - b.index()); + int sentBytes = Integer.parseInt(resp.getHeaderString("sent-bytes")); + int receivedBytes = Integer.parseInt(resp.getHeaderString("received-bytes")); + VectorizeUsage vectorizeUsage = + new VectorizeUsage( + sentBytes, + receivedBytes, + embeddingResponse.usage().total_tokens(), + ProviderConstants.JINA_AI, + modelName); List vectors = Arrays.stream(resp.data()).map(EmbeddingResponse.Data::embedding).toList(); - return Response.of(batchId, vectors); + return new Response(batchId, vectors, vectorizeUsage); }); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MeteredEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MeteredEmbeddingProvider.java index 69debda5e9..aeb1213248 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MeteredEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MeteredEmbeddingProvider.java @@ -91,25 +91,13 @@ public Uni vectorize( Collections.sort( vectorizedBatches, (a, b) -> Integer.compare(a.batchId(), b.batchId())); List result = new ArrayList<>(); - int sentBytes = 0; - int receivedBytes = 0; - int totalToken = 0; - String provider = ""; - String modelName = ""; + VectorizeUsage vectorizeUsage = new VectorizeUsage(); for (Response vectorizedBatch : vectorizedBatches) { // create the final ordered result result.addAll(vectorizedBatch.embeddings()); - sentBytes += vectorizedBatch.vectorizeUsageInfo().getRequestSize(); - receivedBytes += vectorizedBatch.vectorizeUsageInfo().getResponseSize(); - totalToken += vectorizedBatch.vectorizeUsageInfo().getTotalTokens(); - provider = vectorizedBatch.vectorizeUsageInfo().getProvider(); - modelName = vectorizedBatch.vectorizeUsageInfo().getModel(); + vectorizeUsage.merge(vectorizedBatch.vectorizeUsage()); } - return Response.of( - 1, - result, - new VectorizeUsageInfo( - sentBytes, receivedBytes, totalToken, provider, modelName)); + return new Response(1, result, vectorizeUsage); }) .invoke( () -> diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MistralEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MistralEmbeddingProvider.java index fdab0408c1..c447c9d6e2 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MistralEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MistralEmbeddingProvider.java @@ -11,6 +11,7 @@ import io.stargate.sgv2.jsonapi.service.embedding.operation.error.HttpResponseErrorMessageMapper; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.*; import java.util.concurrent.TimeUnit; @@ -44,10 +45,11 @@ public MistralEmbeddingProvider( @RegisterRestClient @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(NetworkUsageInterceptor.class) public interface MistralEmbeddingProviderClient { @POST @ClientHeaderParam(name = "Content-Type", value = "application/json") - Uni embed( + Uni embed( @HeaderParam("Authorization") String accessToken, EmbeddingRequest request); @ClientExceptionMapper @@ -83,8 +85,7 @@ private static String getErrorMessage(jakarta.ws.rs.core.Response response) { JsonNode rootNode = response.readEntity(JsonNode.class); // Log the response body logger.info( - String.format( - "Error response from embedding provider '%s': %s", providerId, rootNode.toString())); + "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); // Extract the "message" node from the root node JsonNode messageNode = rootNode.path("message"); // Return the text of the "message" node, or the whole response body if it is missing @@ -111,7 +112,7 @@ public Uni vectorize( EmbeddingRequest request = new EmbeddingRequest(texts, modelName, "float"); - Uni response = + Uni response = applyRetry( mistralEmbeddingProviderClient.embed( "Bearer " + embeddingCredentials.apiKey().get(), request)); @@ -120,13 +121,26 @@ public Uni vectorize( .onItem() .transform( resp -> { - if (resp.data() == null) { - return Response.of(batchId, Collections.emptyList()); + EmbeddingResponse embeddingResponse = resp.readEntity(EmbeddingResponse.class); + if (embeddingResponse.data() == null) { + return new Response( + batchId, + Collections.emptyList(), + new VectorizeUsage(ProviderConstants.MISTRAL, modelName)); } - Arrays.sort(resp.data(), (a, b) -> a.index() - b.index()); + Arrays.sort(embeddingResponse.data(), (a, b) -> a.index() - b.index()); + int sentBytes = Integer.parseInt(resp.getHeaderString("sent-bytes")); + int receivedBytes = Integer.parseInt(resp.getHeaderString("received-bytes")); + VectorizeUsage vectorizeUsage = + new VectorizeUsage( + sentBytes, + receivedBytes, + embeddingResponse.usage().total_tokens(), + ProviderConstants.MISTRAL, + modelName); List vectors = - Arrays.stream(resp.data()).map(data -> data.embedding()).toList(); - return Response.of(batchId, vectors); + Arrays.stream(embeddingResponse.data()).map(data -> data.embedding()).toList(); + return new Response(batchId, vectors, vectorizeUsage); }); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NvidiaEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NvidiaEmbeddingProvider.java index 4ab62e3f19..d275dda739 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NvidiaEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NvidiaEmbeddingProvider.java @@ -12,6 +12,7 @@ import io.stargate.sgv2.jsonapi.service.embedding.operation.error.HttpResponseErrorMessageMapper; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.Arrays; import java.util.Collections; @@ -47,10 +48,11 @@ public NvidiaEmbeddingProvider( @RegisterRestClient @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(NetworkUsageInterceptor.class) public interface NvidiaEmbeddingProviderClient { @POST @ClientHeaderParam(name = "Content-Type", value = "application/json") - Uni embed( + Uni embed( @HeaderParam("Authorization") String accessToken, EmbeddingRequest request); @ClientExceptionMapper @@ -79,8 +81,7 @@ private static String getErrorMessage(jakarta.ws.rs.core.Response response) { JsonNode rootNode = response.readEntity(JsonNode.class); // Log the response body logger.info( - String.format( - "Error response from embedding provider '%s': %s", providerId, rootNode.toString())); + "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); JsonNode messageNode = rootNode.path("message"); // Return the text of the "message" node, or the whole response body if it is missing return messageNode.isMissingNode() ? rootNode.toString() : messageNode.toString(); @@ -114,20 +115,33 @@ public Uni vectorize( EmbeddingRequest request = new EmbeddingRequest(texts.toArray(textArray), modelName, input_type); - Uni response = + Uni response = applyRetry(nvidiaEmbeddingProviderClient.embed("Bearer ", request)); return response .onItem() .transform( resp -> { - if (resp.data() == null) { - return Response.of(batchId, Collections.emptyList()); + EmbeddingResponse embeddingResponse = resp.readEntity(EmbeddingResponse.class); + if (embeddingResponse.data() == null) { + return new Response( + batchId, + Collections.emptyList(), + new VectorizeUsage(ProviderConstants.NVIDIA, modelName)); } - Arrays.sort(resp.data(), (a, b) -> a.index() - b.index()); + Arrays.sort(embeddingResponse.data(), (a, b) -> a.index() - b.index()); + int sentBytes = Integer.parseInt(resp.getHeaderString("sent-bytes")); + int receivedBytes = Integer.parseInt(resp.getHeaderString("received-bytes")); + VectorizeUsage vectorizeUsage = + new VectorizeUsage( + sentBytes, + receivedBytes, + embeddingResponse.usage().total_tokens(), + ProviderConstants.NVIDIA, + modelName); List vectors = - Arrays.stream(resp.data()).map(data -> data.embedding()).toList(); - return Response.of(batchId, vectors); + Arrays.stream(embeddingResponse.data()).map(data -> data.embedding()).toList(); + return new Response(batchId, vectors, vectorizeUsage); }); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAIEmbeddingProvider.java index 1b5d328c7a..7a22ac1f00 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAIEmbeddingProvider.java @@ -93,8 +93,7 @@ private static String getErrorMessage(jakarta.ws.rs.core.Response response) { JsonNode rootNode = response.readEntity(JsonNode.class); // Log the response body logger.info( - String.format( - "Error response from embedding provider '%s': %s", providerId, rootNode.toString())); + "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); // Extract the "message" node from the "error" node JsonNode messageNode = rootNode.at("/error/message"); // Return the text of the "message" node, or the whole response body if it is missing @@ -139,22 +138,25 @@ public Uni vectorize( .transform( res -> { EmbeddingResponse embeddingResponse = res.readEntity(EmbeddingResponse.class); + if (embeddingResponse.data() == null) { + return new Response( + batchId, + Collections.emptyList(), + new VectorizeUsage(ProviderConstants.OPENAI, modelName)); + } int sentBytes = Integer.parseInt(res.getHeaderString("sent-bytes")); int receivedBytes = Integer.parseInt(res.getHeaderString("received-bytes")); - VectorizeUsageInfo vectorizeUsageInfo = - new VectorizeUsageInfo( + VectorizeUsage vectorizeUsage = + new VectorizeUsage( sentBytes, receivedBytes, - embeddingResponse.usage().total_tokens, - "openai", + embeddingResponse.usage().total_tokens(), + ProviderConstants.OPENAI, modelName); - if (embeddingResponse.data() == null) { - return Response.of(batchId, Collections.emptyList(), vectorizeUsageInfo); - } Arrays.sort(embeddingResponse.data(), (a, b) -> a.index() - b.index()); List vectors = Arrays.stream(embeddingResponse.data()).map(data -> data.embedding()).toList(); - return Response.of(batchId, vectors, vectorizeUsageInfo); + return new Response(batchId, vectors, vectorizeUsage); }); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/UpstageAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/UpstageAIEmbeddingProvider.java index 8ab72b1a6b..15d954696c 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/UpstageAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/UpstageAIEmbeddingProvider.java @@ -13,6 +13,7 @@ import io.stargate.sgv2.jsonapi.service.embedding.operation.error.HttpResponseErrorMessageMapper; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.Arrays; import java.util.Collections; @@ -48,11 +49,12 @@ public UpstageAIEmbeddingProvider( @RegisterRestClient @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(NetworkUsageInterceptor.class) public interface UpstageAIEmbeddingProviderClient { @POST // no path specified, as it is already included in the baseUri @ClientHeaderParam(name = "Content-Type", value = "application/json") - Uni embed( + Uni embed( @HeaderParam("Authorization") String accessToken, EmbeddingRequest request); @ClientExceptionMapper @@ -88,8 +90,7 @@ private static String getErrorMessage(jakarta.ws.rs.core.Response response) { JsonNode rootNode = response.readEntity(JsonNode.class); // Log the response body logger.info( - String.format( - "Error response from embedding provider '%s': %s", providerId, rootNode.toString())); + "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); // Check if the root node contains a "message" field JsonNode messageNode = rootNode.path("message"); if (!messageNode.isMissingNode()) { @@ -139,7 +140,7 @@ public Uni vectorize( EmbeddingRequest request = new EmbeddingRequest(texts.get(0), modelName); - Uni response = + Uni response = applyRetry( upstageAIEmbeddingProviderClient.embed( "Bearer " + embeddingCredentials.apiKey().get(), request)); @@ -148,13 +149,26 @@ public Uni vectorize( .onItem() .transform( resp -> { - if (resp.data() == null) { - return Response.of(batchId, Collections.emptyList()); + EmbeddingResponse embeddingResponse = resp.readEntity(EmbeddingResponse.class); + if (embeddingResponse.data() == null) { + return new Response( + batchId, + Collections.emptyList(), + new VectorizeUsage(ProviderConstants.UPSTAGE_AI, modelName)); } - Arrays.sort(resp.data(), (a, b) -> a.index() - b.index()); + Arrays.sort(embeddingResponse.data(), (a, b) -> a.index() - b.index()); + int sentBytes = Integer.parseInt(resp.getHeaderString("sent-bytes")); + int receivedBytes = Integer.parseInt(resp.getHeaderString("received-bytes")); + VectorizeUsage vectorizeUsage = + new VectorizeUsage( + sentBytes, + receivedBytes, + embeddingResponse.usage().total_tokens(), + ProviderConstants.UPSTAGE_AI, + modelName); List vectors = - Arrays.stream(resp.data()).map(data -> data.embedding()).toList(); - return Response.of(batchId, vectors); + Arrays.stream(embeddingResponse.data()).map(data -> data.embedding()).toList(); + return new Response(batchId, vectors, vectorizeUsage); }); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VectorizeUsage.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VectorizeUsage.java new file mode 100644 index 0000000000..f804dee67d --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VectorizeUsage.java @@ -0,0 +1,69 @@ +package io.stargate.sgv2.jsonapi.service.embedding.operation; + +/** Used to track the metric at a request level to the embedding service */ +public class VectorizeUsage { + private int requestBytes = 0; + private int responseBytes = 0; + private int totalTokens = 0; + private String provider = ""; + private String model = ""; + + public VectorizeUsage( + int requestBytes, int responseBytes, int totalTokens, String provider, String model) { + this.requestBytes = requestBytes; + this.responseBytes = responseBytes; + this.totalTokens = totalTokens; + this.provider = provider; + this.model = model; + } + + public VectorizeUsage() { + super(); + } + + public VectorizeUsage(String provider, String model) { + super(); + this.provider = provider; + this.model = model; + } + + public void merge(VectorizeUsage vectorizeUsage) { + this.requestBytes += vectorizeUsage.getRequestBytes(); + this.responseBytes += vectorizeUsage.getResponseBytes(); + this.totalTokens += vectorizeUsage.getTotalTokens(); + this.provider = vectorizeUsage.getProvider(); + this.model = vectorizeUsage.getModel(); + } + + public int getRequestBytes() { + return requestBytes; + } + + public int getResponseBytes() { + return responseBytes; + } + + public int getTotalTokens() { + return totalTokens; + } + + public String getProvider() { + return provider; + } + + public String getModel() { + return model; + } + + public void setRequestBytes(int requestBytes) { + this.requestBytes = requestBytes; + } + + public void setResponseBytes(int responseBytes) { + this.responseBytes = responseBytes; + } + + public void setTotalTokens(int totalTokens) { + this.totalTokens = totalTokens; + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VectorizeUsageInfo.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VectorizeUsageInfo.java deleted file mode 100644 index 8f2aa48b8c..0000000000 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VectorizeUsageInfo.java +++ /dev/null @@ -1,39 +0,0 @@ -package io.stargate.sgv2.jsonapi.service.embedding.operation; - -/** Used to track the metric at a request level to the embedding service */ -public class VectorizeUsageInfo { - private int requestSize; - private int responseSize; - private int totalTokens; - private String provider; - private String model; - - public VectorizeUsageInfo( - int requestSize, int responseSize, int totalTokens, String provider, String model) { - this.requestSize = requestSize; - this.responseSize = responseSize; - this.totalTokens = totalTokens; - this.provider = provider; - this.model = model; - } - - public int getRequestSize() { - return requestSize; - } - - public int getResponseSize() { - return responseSize; - } - - public int getTotalTokens() { - return totalTokens; - } - - public String getProvider() { - return provider; - } - - public String getModel() { - return model; - } -} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VertexAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VertexAIEmbeddingProvider.java index 76efccfd9f..f17dec6d53 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VertexAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VertexAIEmbeddingProvider.java @@ -14,6 +14,7 @@ import jakarta.ws.rs.POST; import jakarta.ws.rs.Path; import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.Collections; import java.util.List; @@ -46,11 +47,12 @@ public VertexAIEmbeddingProvider( @RegisterRestClient @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(NetworkUsageInterceptor.class) public interface VertexAIEmbeddingProviderClient { @POST @Path("/{modelId}:predict") @ClientHeaderParam(name = "Content-Type", value = "application/json") - Uni embed( + Uni embed( @HeaderParam("Authorization") String accessToken, @PathParam("modelId") String modelId, EmbeddingRequest request); @@ -77,8 +79,7 @@ private static String getErrorMessage(jakarta.ws.rs.core.Response response) { JsonNode rootNode = response.readEntity(JsonNode.class); // Log the response body logger.info( - String.format( - "Error response from embedding provider '%s': %s", providerId, rootNode.toString())); + "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); return rootNode.toString(); } } @@ -159,7 +160,7 @@ public Uni vectorize( EmbeddingRequest request = new EmbeddingRequest(texts.stream().map(t -> new EmbeddingRequest.Content(t)).toList()); - Uni serviceResponse = + Uni serviceResponse = applyRetry( vertexAIEmbeddingProviderClient.embed( "Bearer " + embeddingCredentials.apiKey().get(), modelName, request)); @@ -167,15 +168,24 @@ public Uni vectorize( return serviceResponse .onItem() .transform( - response -> { - if (response.getPredictions() == null) { - return Response.of(batchId, Collections.emptyList()); + resp -> { + EmbeddingResponse embeddingResponse = resp.readEntity(EmbeddingResponse.class); + if (embeddingResponse.getPredictions() == null) { + return new Response( + batchId, + Collections.emptyList(), + new VectorizeUsage(ProviderConstants.VERTEXAI, modelName)); } List vectors = - response.getPredictions().stream() + embeddingResponse.getPredictions().stream() .map(prediction -> prediction.getEmbeddings().getValues()) .collect(Collectors.toList()); - return Response.of(batchId, vectors); + int sentBytes = Integer.parseInt(resp.getHeaderString("sent-bytes")); + int receivedBytes = Integer.parseInt(resp.getHeaderString("received-bytes")); + VectorizeUsage vectorizeUsage = + new VectorizeUsage( + sentBytes, receivedBytes, 0, ProviderConstants.VERTEXAI, modelName); + return new Response(batchId, vectors, vectorizeUsage); }); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VoyageAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VoyageAIEmbeddingProvider.java index c7b1289593..62278f9f96 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VoyageAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VoyageAIEmbeddingProvider.java @@ -13,6 +13,7 @@ import io.stargate.sgv2.jsonapi.service.embedding.operation.error.HttpResponseErrorMessageMapper; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.Arrays; import java.util.Collections; @@ -52,11 +53,12 @@ public VoyageAIEmbeddingProvider( @RegisterRestClient @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(NetworkUsageInterceptor.class) public interface VoyageAIEmbeddingProviderClient { @POST // no path specified, as it is already included in the baseUri @ClientHeaderParam(name = "Content-Type", value = "application/json") - Uni embed( + Uni embed( @HeaderParam("Authorization") String accessToken, EmbeddingRequest request); @ClientExceptionMapper @@ -118,7 +120,7 @@ public Uni vectorize( EmbeddingRequest request = new EmbeddingRequest(inputType, texts.toArray(textArray), modelName, autoTruncate); - Uni response = + Uni response = applyRetry( voyageAIEmbeddingProviderClient.embed( "Bearer " + embeddingCredentials.apiKey().get(), request)); @@ -127,13 +129,26 @@ public Uni vectorize( .onItem() .transform( resp -> { - if (resp.data() == null) { - return Response.of(batchId, Collections.emptyList()); + EmbeddingResponse embeddingResponse = resp.readEntity(EmbeddingResponse.class); + if (embeddingResponse.data() == null) { + return new Response( + batchId, + Collections.emptyList(), + new VectorizeUsage(ProviderConstants.VOYAGE_AI, modelName)); } - Arrays.sort(resp.data(), (a, b) -> a.index() - b.index()); + Arrays.sort(embeddingResponse.data(), (a, b) -> a.index() - b.index()); + int sentBytes = Integer.parseInt(resp.getHeaderString("sent-bytes")); + int receivedBytes = Integer.parseInt(resp.getHeaderString("received-bytes")); + VectorizeUsage vectorizeUsage = + new VectorizeUsage( + sentBytes, + receivedBytes, + embeddingResponse.usage().total_tokens(), + ProviderConstants.VOYAGE_AI, + modelName); List vectors = - Arrays.stream(resp.data()).map(data -> data.embedding()).toList(); - return Response.of(batchId, vectors); + Arrays.stream(embeddingResponse.data()).map(data -> data.embedding()).toList(); + return new Response(batchId, vectors, vectorizeUsage); }); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/test/CustomITEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/test/CustomITEmbeddingProvider.java index 1dba9e52ff..17c99639b8 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/test/CustomITEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/test/CustomITEmbeddingProvider.java @@ -4,6 +4,7 @@ import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProvider; +import io.stargate.sgv2.jsonapi.service.embedding.operation.VectorizeUsage; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -74,7 +75,8 @@ public Uni vectorize( EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { List response = new ArrayList<>(texts.size()); - if (texts.size() == 0) return Uni.createFrom().item(Response.of(batchId, response)); + if (texts.size() == 0) + return Uni.createFrom().item(new Response(batchId, response, new VectorizeUsage())); if (!embeddingCredentials.apiKey().isPresent() || !embeddingCredentials.apiKey().get().equals(TEST_API_KEY)) return Uni.createFrom().failure(new RuntimeException("Invalid API Key")); @@ -94,7 +96,7 @@ public Uni vectorize( } } } - return Uni.createFrom().item(Response.of(batchId, response)); + return Uni.createFrom().item(new Response(batchId, response, new VectorizeUsage())); } @Override diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java index 7f1973c2f5..ca078799f9 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java @@ -217,7 +217,8 @@ public Uni vectorize( texts.forEach(t -> customResponse.add(new float[] {0.5f, 0.5f, 0.5f})); // add additional vector customResponse.add(new float[] {0.5f, 0.5f, 0.5f}); - return Uni.createFrom().item(Response.of(batchId, customResponse)); + return Uni.createFrom() + .item(new Response(batchId, customResponse, new VectorizeUsage())); } }; List documents = new ArrayList<>(); diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/TestEmbeddingProvider.java b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/TestEmbeddingProvider.java index d2861e9c2b..586bfbfe73 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/TestEmbeddingProvider.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/TestEmbeddingProvider.java @@ -48,7 +48,7 @@ public Uni vectorize( if (t.equals("return 1s")) response.add(new float[] {1.0f, 1.0f, 1.0f}); else response.add(new float[] {0.25f, 0.25f, 0.25f}); }); - return Uni.createFrom().item(Response.of(batchId, response)); + return Uni.createFrom().item(new Response(batchId, response, new VectorizeUsage())); } @Override From 16c8083f3adbcfaf154d60cec5a4c1af2acdf02e Mon Sep 17 00:00:00 2001 From: "mahesh.rajamani" Date: Wed, 12 Feb 2025 13:33:23 -0500 Subject: [PATCH 07/11] Changes to remove VectorizeUsageBean --- .../api/model/command/CommandResult.java | 32 ----------- .../api/model/command/VectorizeUsageBean.java | 53 ------------------- .../jsonapi/api/v1/CollectionResource.java | 22 +------- .../service/embedding/DataVectorizer.java | 36 ++----------- .../embedding/DataVectorizerService.java | 9 +--- .../operation/DataVectorizerTest.java | 53 ++++--------------- 6 files changed, 16 insertions(+), 189 deletions(-) delete mode 100644 src/main/java/io/stargate/sgv2/jsonapi/api/model/command/VectorizeUsageBean.java diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/CommandResult.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/CommandResult.java index 37faf293a6..5b60582ff2 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/CommandResult.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/CommandResult.java @@ -138,38 +138,6 @@ public RestResponse toRestResponse() { return RestResponse.ok(this); } - /** - * Create the {@link RestResponse} Maps CommandResult to RestResponse. Except for few selective - * errors, all errors are mapped to http status 200. In case of 401, 500, 502 and 504 response is - * sent with appropriate status code. - * - * @return - */ - public RestResponse toRestResponse(String vectorizeHeader) { - if (null != this.errors()) { - final Optional first = - this.errors().stream() - .filter(error -> error.httpStatus() != Response.Status.OK) - .findFirst(); - - if (first.isPresent()) { - if (vectorizeHeader != null) { - return RestResponse.ResponseBuilder.create(first.get().httpStatus(), this) - .header("vectorize-usage", vectorizeHeader) - .build(); - } else { - return RestResponse.ResponseBuilder.create(first.get().httpStatus(), this).build(); - } - } - } - if (vectorizeHeader != null) { - return RestResponse.ResponseBuilder.create(RestResponse.Status.OK, this) - .header("vectorize-usage", vectorizeHeader) - .build(); - } - return RestResponse.ok(this); - } - /** * returned a new CommandResult with warning message added in status map * diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/VectorizeUsageBean.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/VectorizeUsageBean.java deleted file mode 100644 index f0e699bea2..0000000000 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/VectorizeUsageBean.java +++ /dev/null @@ -1,53 +0,0 @@ -package io.stargate.sgv2.jsonapi.api.model.command; - -import jakarta.enterprise.context.RequestScoped; - -/** Bean class which is used to return as response json in header */ -@RequestScoped -public class VectorizeUsageBean { - private int requestSize; - private int responseSize; - private int totalTokens; - private String provider; - private String model; - - public int getRequestSize() { - return requestSize; - } - - public void setRequestSize(int requestSize) { - this.requestSize = requestSize; - } - - public int getResponseSize() { - return responseSize; - } - - public void setResponseSize(int responseSize) { - this.responseSize = responseSize; - } - - public int getTotalTokens() { - return totalTokens; - } - - public void setTotalTokens(int totalTokens) { - this.totalTokens = totalTokens; - } - - public String getProvider() { - return provider; - } - - public void setProvider(String provider) { - this.provider = provider; - } - - public String getModel() { - return model; - } - - public void setModel(String model) { - this.model = model; - } -} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/v1/CollectionResource.java b/src/main/java/io/stargate/sgv2/jsonapi/api/v1/CollectionResource.java index 6a412750eb..9c1aabb6fb 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/v1/CollectionResource.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/v1/CollectionResource.java @@ -2,12 +2,8 @@ import static io.stargate.sgv2.jsonapi.config.constants.DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.common.base.Strings; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.model.command.*; -import io.stargate.sgv2.jsonapi.api.model.command.VectorizeUsageBean; import io.stargate.sgv2.jsonapi.api.model.command.impl.AlterTableCommand; import io.stargate.sgv2.jsonapi.api.model.command.impl.CountDocumentsCommand; import io.stargate.sgv2.jsonapi.api.model.command.impl.CreateIndexCommand; @@ -87,10 +83,6 @@ public class CollectionResource { @Inject private JsonProcessingMetricsReporter jsonProcessingMetricsReporter; - @Inject VectorizeUsageBean vectorizeUsageBean; - - @Inject ObjectMapper objectMapper; - @Inject public CollectionResource(MeteredCommandProcessor meteredCommandProcessor) { this.meteredCommandProcessor = meteredCommandProcessor; @@ -268,18 +260,6 @@ public Uni> postCommand( dataApiRequestInfo, commandContext, command); } }) - .map( - commandResult -> { - String vectorize = null; - try { - - if (!Strings.isNullOrEmpty(vectorizeUsageBean.getModel())) { - vectorize = objectMapper.writeValueAsString(vectorizeUsageBean); - } - } catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - return commandResult.toRestResponse(vectorize); - }); + .map(commandResult -> commandResult.toRestResponse()); } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java index 3e843b9246..31395afda8 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java @@ -9,7 +9,6 @@ import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.fasterxml.jackson.databind.node.ObjectNode; import io.smallrye.mutiny.Uni; -import io.stargate.sgv2.jsonapi.api.model.command.VectorizeUsageBean; import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause; import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortExpression; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; @@ -35,7 +34,6 @@ public class DataVectorizer { private final JsonNodeFactory nodeFactory; private final EmbeddingCredentials embeddingCredentials; private final SchemaObject schemaObject; - private final VectorizeUsageBean vectorizeUsageBean; /** * Constructor @@ -50,13 +48,11 @@ public DataVectorizer( EmbeddingProvider embeddingProvider, JsonNodeFactory nodeFactory, EmbeddingCredentials embeddingCredentials, - SchemaObject schemaObject, - VectorizeUsageBean vectorizeUsageBean) { + SchemaObject schemaObject) { this.embeddingProvider = embeddingProvider; this.nodeFactory = nodeFactory; this.embeddingCredentials = embeddingCredentials; this.schemaObject = schemaObject; - this.vectorizeUsageBean = vectorizeUsageBean; } /** @@ -111,15 +107,7 @@ public Uni vectorize(List documents) { vectorizeTexts, embeddingCredentials, EmbeddingProvider.EmbeddingRequestType.INDEX) - .map( - res -> { - vectorizeUsageBean.setRequestSize(res.vectorizeUsage().getRequestBytes()); - vectorizeUsageBean.setResponseSize(res.vectorizeUsage().getResponseBytes()); - vectorizeUsageBean.setTotalTokens(res.vectorizeUsage().getTotalTokens()); - vectorizeUsageBean.setProvider(res.vectorizeUsage().getProvider()); - vectorizeUsageBean.setModel(res.vectorizeUsage().getModel()); - return res.embeddings(); - }); + .map(res -> res.embeddings()); return vectors .onItem() .transform( @@ -186,15 +174,7 @@ public Uni vectorize(String vectorizeContent) { List.of(vectorizeContent), embeddingCredentials, EmbeddingProvider.EmbeddingRequestType.INDEX) - .map( - res -> { - vectorizeUsageBean.setRequestSize(res.vectorizeUsage().getRequestBytes()); - vectorizeUsageBean.setResponseSize(res.vectorizeUsage().getResponseBytes()); - vectorizeUsageBean.setTotalTokens(res.vectorizeUsage().getTotalTokens()); - vectorizeUsageBean.setProvider(res.vectorizeUsage().getProvider()); - vectorizeUsageBean.setModel(res.vectorizeUsage().getModel()); - return res.embeddings(); - }); + .map(res -> res.embeddings()); return vectors .onItem() .transform( @@ -241,15 +221,7 @@ public Uni vectorize(SortClause sortClause) { List.of(text), embeddingCredentials, EmbeddingProvider.EmbeddingRequestType.SEARCH) - .map( - res -> { - vectorizeUsageBean.setRequestSize(res.vectorizeUsage().getRequestBytes()); - vectorizeUsageBean.setResponseSize(res.vectorizeUsage().getResponseBytes()); - vectorizeUsageBean.setTotalTokens(res.vectorizeUsage().getTotalTokens()); - vectorizeUsageBean.setProvider(res.vectorizeUsage().getProvider()); - vectorizeUsageBean.setModel(res.vectorizeUsage().getModel()); - return res.embeddings(); - }); + .map(res -> res.embeddings()); return vectors .onItem() .transform( diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizerService.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizerService.java index 6798f8aab2..073ed3dac9 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizerService.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizerService.java @@ -9,7 +9,6 @@ import io.micrometer.core.instrument.MeterRegistry; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.model.command.*; -import io.stargate.sgv2.jsonapi.api.model.command.VectorizeUsageBean; import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortExpression; import io.stargate.sgv2.jsonapi.api.model.command.clause.update.UpdateClause; import io.stargate.sgv2.jsonapi.api.model.command.clause.update.UpdateOperator; @@ -38,18 +37,15 @@ public class DataVectorizerService { private final ObjectMapper objectMapper; private final MeterRegistry meterRegistry; private final JsonApiMetricsConfig jsonApiMetricsConfig; - private final VectorizeUsageBean vectorizeUsageBean; @Inject public DataVectorizerService( ObjectMapper objectMapper, MeterRegistry meterRegistry, - JsonApiMetricsConfig jsonApiMetricsConfig, - VectorizeUsageBean vectorizeUsageBean) { + JsonApiMetricsConfig jsonApiMetricsConfig) { this.objectMapper = objectMapper; this.meterRegistry = meterRegistry; this.jsonApiMetricsConfig = jsonApiMetricsConfig; - this.vectorizeUsageBean = vectorizeUsageBean; } /** @@ -99,8 +95,7 @@ public DataVectorizer constructDataVectorizer( embeddingProvider, objectMapper.getNodeFactory(), dataApiRequestInfo.getEmbeddingCredentials(), - commandContext.schemaObject(), - vectorizeUsageBean); + commandContext.schemaObject()); } private Uni vectorizeSortClause( diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java index ca078799f9..173ba5f42c 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java @@ -10,7 +10,6 @@ import io.quarkus.test.junit.TestProfile; import io.smallrye.mutiny.Uni; import io.smallrye.mutiny.helpers.test.UniAssertSubscriber; -import io.stargate.sgv2.jsonapi.api.model.command.VectorizeUsageBean; import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause; import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortExpression; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; @@ -36,7 +35,6 @@ @TestProfile(PropertyBasedOverrideProfile.class) public class DataVectorizerTest { @Inject ObjectMapper objectMapper; - @Inject VectorizeUsageBean vectorizeUsageBean; private final EmbeddingProvider testService = new TestEmbeddingProvider(); private final CollectionSchemaObject collectionSettings = TestEmbeddingProvider.commandContextWithVectorize.schemaObject(); @@ -54,11 +52,7 @@ public void testTextValues() { } DataVectorizer dataVectorizer = new DataVectorizer( - testService, - objectMapper.getNodeFactory(), - embeddingCredentials, - collectionSettings, - vectorizeUsageBean); + testService, objectMapper.getNodeFactory(), embeddingCredentials, collectionSettings); try { dataVectorizer.vectorize(documents).subscribe().asCompletionStage().get(); } catch (Exception e) { @@ -85,11 +79,7 @@ public void testEmptyValues() { DataVectorizer dataVectorizer = new DataVectorizer( - testService, - objectMapper.getNodeFactory(), - embeddingCredentials, - collectionSettings, - vectorizeUsageBean); + testService, objectMapper.getNodeFactory(), embeddingCredentials, collectionSettings); try { dataVectorizer.vectorize(documents).subscribe().asCompletionStage().get(); } catch (Exception e) { @@ -119,11 +109,7 @@ public void testNonTextValues() { DataVectorizer dataVectorizer = new DataVectorizer( - testService, - objectMapper.getNodeFactory(), - embeddingCredentials, - collectionSettings, - vectorizeUsageBean); + testService, objectMapper.getNodeFactory(), embeddingCredentials, collectionSettings); try { Throwable failure = dataVectorizer @@ -152,11 +138,7 @@ public void testNullValues() { DataVectorizer dataVectorizer = new DataVectorizer( - testService, - objectMapper.getNodeFactory(), - embeddingCredentials, - collectionSettings, - vectorizeUsageBean); + testService, objectMapper.getNodeFactory(), embeddingCredentials, collectionSettings); try { dataVectorizer.vectorize(documents).subscribe().asCompletionStage().get(); } catch (Exception e) { @@ -179,11 +161,7 @@ public void testWithBothVectorFieldValues() { documents.add(document); DataVectorizer dataVectorizer = new DataVectorizer( - testService, - objectMapper.getNodeFactory(), - embeddingCredentials, - collectionSettings, - vectorizeUsageBean); + testService, objectMapper.getNodeFactory(), embeddingCredentials, collectionSettings); try { Throwable failure = dataVectorizer @@ -230,8 +208,7 @@ public Uni vectorize( testProvider, objectMapper.getNodeFactory(), embeddingCredentials, - collectionSettings, - vectorizeUsageBean); + collectionSettings); Throwable failure = dataVectorizer @@ -273,11 +250,7 @@ public void testWithUnmatchedVectorSize() { } DataVectorizer dataVectorizer = new DataVectorizer( - testService, - objectMapper.getNodeFactory(), - embeddingCredentials, - collectionSettings, - vectorizeUsageBean); + testService, objectMapper.getNodeFactory(), embeddingCredentials, collectionSettings); Throwable failure = dataVectorizer @@ -306,11 +279,7 @@ public void sortClauseValues() { SortClause sortClause = new SortClause(sortExpressions); DataVectorizer dataVectorizer = new DataVectorizer( - testService, - objectMapper.getNodeFactory(), - embeddingCredentials, - collectionSettings, - vectorizeUsageBean); + testService, objectMapper.getNodeFactory(), embeddingCredentials, collectionSettings); try { dataVectorizer.vectorize(sortClause).subscribe().asCompletionStage().get(); } catch (Exception e) { @@ -333,11 +302,7 @@ public void vectorize() { arrayNode.add(objectMapper.getNodeFactory().numberNode(0.11f)); DataVectorizer dataVectorizer = new DataVectorizer( - testService, - objectMapper.getNodeFactory(), - embeddingCredentials, - collectionSettings, - vectorizeUsageBean); + testService, objectMapper.getNodeFactory(), embeddingCredentials, collectionSettings); try { final float[] testData = dataVectorizer.vectorize("test data").subscribe().asCompletionStage().get(); From 741440a93a2152506db5efafa5665f588271f49b Mon Sep 17 00:00:00 2001 From: "mahesh.rajamani" Date: Wed, 12 Feb 2025 14:15:02 -0500 Subject: [PATCH 08/11] Refactored vectorize usage from the EGW client response. --- .../embedding/gateway/EmbeddingGatewayClient.java | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/gateway/EmbeddingGatewayClient.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/gateway/EmbeddingGatewayClient.java index 9ec67b03a6..bc2296c66c 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/gateway/EmbeddingGatewayClient.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/gateway/EmbeddingGatewayClient.java @@ -190,10 +190,14 @@ else if (value instanceof Boolean) ErrorCodeV1.valueOf(resp.getError().getErrorCode()), resp.getError().getErrorMessage()); } + VectorizeUsage vectorizeUsage = new VectorizeUsage(provider, modelName); if (resp.getEmbeddingsList() == null) { - return new Response( - batchId, Collections.emptyList(), new VectorizeUsage(provider, modelName)); + return new Response(batchId, Collections.emptyList(), vectorizeUsage); } + EmbeddingGateway.EmbeddingResponse.Usage usage = resp.getUsage(); + vectorizeUsage.setRequestBytes(usage.getInputBytes()); + vectorizeUsage.setResponseBytes(usage.getOutputBytes()); + vectorizeUsage.setTotalTokens(usage.getTotalTokens()); final List vectors = resp.getEmbeddingsList().stream() .map( @@ -205,7 +209,7 @@ else if (value instanceof Boolean) return embedding; }) .toList(); - return new Response(batchId, vectors, new VectorizeUsage(provider, modelName)); + return new Response(batchId, vectors, vectorizeUsage); }); } From 42c4b6e255dbc193d7d29ffed1a6602a31828bd2 Mon Sep 17 00:00:00 2001 From: "mahesh.rajamani" Date: Wed, 12 Feb 2025 14:19:53 -0500 Subject: [PATCH 09/11] Fixed compile error --- .../service/embedding/operation/JinaAIEmbeddingProvider.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/JinaAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/JinaAIEmbeddingProvider.java index f366e3d726..d2e5b218a2 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/JinaAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/JinaAIEmbeddingProvider.java @@ -149,7 +149,9 @@ public Uni vectorize( ProviderConstants.JINA_AI, modelName); List vectors = - Arrays.stream(resp.data()).map(EmbeddingResponse.Data::embedding).toList(); + Arrays.stream(embeddingResponse.data()) + .map(EmbeddingResponse.Data::embedding) + .toList(); return new Response(batchId, vectors, vectorizeUsage); }); } From b4eef8daeab914628568b2f1620f5ccbc45e7152 Mon Sep 17 00:00:00 2001 From: "mahesh.rajamani" Date: Wed, 12 Feb 2025 14:59:19 -0500 Subject: [PATCH 10/11] Reverted the unnecessary change --- .../stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java index 31395afda8..47f011feeb 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java @@ -174,7 +174,7 @@ public Uni vectorize(String vectorizeContent) { List.of(vectorizeContent), embeddingCredentials, EmbeddingProvider.EmbeddingRequestType.INDEX) - .map(res -> res.embeddings()); + .map(EmbeddingProvider.Response::embeddings); return vectors .onItem() .transform( From 88a64f75d25bd8ea148cfe25c4743db1570bd75e Mon Sep 17 00:00:00 2001 From: "mahesh.rajamani" Date: Wed, 12 Feb 2025 15:06:12 -0500 Subject: [PATCH 11/11] Fixed typo --- .../operation/HuggingFaceDedicatedEmbeddingProvider.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceDedicatedEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceDedicatedEmbeddingProvider.java index b591461ddf..370fbad7ba 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceDedicatedEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceDedicatedEmbeddingProvider.java @@ -135,7 +135,7 @@ public Uni vectorize( sentBytes, receivedBytes, embeddingResponse.usage().total_tokens(), - ProviderConstants.COHERE, + ProviderConstants.HUGGINGFACE_DEDICATED, modelName); return new Response(batchId, vectors, vectorizeUsage); });