diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/gateway/EmbeddingGatewayClient.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/gateway/EmbeddingGatewayClient.java index 86ccfc4b21..bc2296c66c 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/gateway/EmbeddingGatewayClient.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/gateway/EmbeddingGatewayClient.java @@ -10,6 +10,7 @@ import io.stargate.sgv2.jsonapi.exception.JsonApiException; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProvider; +import io.stargate.sgv2.jsonapi.service.embedding.operation.VectorizeUsage; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -189,9 +190,14 @@ else if (value instanceof Boolean) ErrorCodeV1.valueOf(resp.getError().getErrorCode()), resp.getError().getErrorMessage()); } + VectorizeUsage vectorizeUsage = new VectorizeUsage(provider, modelName); if (resp.getEmbeddingsList() == null) { - return Response.of(batchId, Collections.emptyList()); + return new Response(batchId, Collections.emptyList(), vectorizeUsage); } + EmbeddingGateway.EmbeddingResponse.Usage usage = resp.getUsage(); + vectorizeUsage.setRequestBytes(usage.getInputBytes()); + vectorizeUsage.setResponseBytes(usage.getOutputBytes()); + vectorizeUsage.setTotalTokens(usage.getTotalTokens()); final List vectors = resp.getEmbeddingsList().stream() .map( @@ -203,7 +209,7 @@ else if (value instanceof Boolean) return embedding; }) .toList(); - return Response.of(batchId, vectors); + return new Response(batchId, vectors, vectorizeUsage); }); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AwsBedrockEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AwsBedrockEmbeddingProvider.java index e32819a680..585b03d67b 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AwsBedrockEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AwsBedrockEmbeddingProvider.java @@ -9,12 +9,15 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectReader; import com.fasterxml.jackson.databind.ObjectWriter; +import com.google.common.io.ByteStreams; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; @@ -78,12 +81,14 @@ public Uni vectorize( .credentialsProvider(StaticCredentialsProvider.create(awsCreds)) .region(Region.of(vectorizeServiceParameters.get("region").toString())) .build(); + final VectorizeUsage vectorizeUsage = new VectorizeUsage(ProviderConstants.BEDROCK, modelName); final CompletableFuture invokeModelResponseCompletableFuture = client.invokeModel( request -> { final byte[] inputData; try { inputData = ow.writeValueAsBytes(new EmbeddingRequest(texts.get(0), dimension)); + vectorizeUsage.setRequestBytes(inputData.length); request.body(SdkBytes.fromByteArray(inputData)).modelId(modelName); } catch (JsonProcessingException e) { throw ErrorCodeV1.EMBEDDING_REQUEST_ENCODING_ERROR.toApiException(); @@ -94,10 +99,15 @@ public Uni vectorize( invokeModelResponseCompletableFuture.thenApply( res -> { try { - EmbeddingResponse response = - or.readValue(res.body().asInputStream(), EmbeddingResponse.class); + InputStream inputStream = res.body().asInputStream(); + int receivedBytes = + (int) ByteStreams.copy(inputStream, OutputStream.nullOutputStream()); + EmbeddingResponse response = or.readValue(inputStream, EmbeddingResponse.class); + vectorizeUsage.setResponseBytes(receivedBytes); List vectors = List.of(response.embedding); - return Response.of(batchId, vectors); + // Note: This is input tokens + vectorizeUsage.setTotalTokens(response.inputTextTokenCount()); + return new Response(batchId, vectors, vectorizeUsage); } catch (IOException e) { throw ErrorCodeV1.EMBEDDING_RESPONSE_DECODING_ERROR.toApiException(); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AzureOpenAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AzureOpenAIEmbeddingProvider.java index dcc655e1f8..549b926d6b 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AzureOpenAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AzureOpenAIEmbeddingProvider.java @@ -12,6 +12,7 @@ import io.stargate.sgv2.jsonapi.service.embedding.operation.error.HttpResponseErrorMessageMapper; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.Arrays; import java.util.Collections; @@ -55,11 +56,12 @@ public AzureOpenAIEmbeddingProvider( @RegisterRestClient @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(NetworkUsageInterceptor.class) public interface OpenAIEmbeddingProviderClient { @POST // no path specified, as it is already included in the baseUri @ClientHeaderParam(name = "Content-Type", value = "application/json") - Uni embed( + Uni embed( // API keys as "api-key", MS Entra as "Authorization: Bearer [token] @HeaderParam("api-key") String accessToken, EmbeddingRequest request); @@ -120,7 +122,7 @@ public Uni vectorize( EmbeddingRequest request = new EmbeddingRequest(texts.toArray(textArray), modelName, dimension); // NOTE: NO "Bearer " prefix with API key for Azure OpenAI - Uni response = + Uni response = applyRetry( openAIEmbeddingProviderClient.embed(embeddingCredentials.apiKey().get(), request)); @@ -128,13 +130,26 @@ public Uni vectorize( .onItem() .transform( resp -> { - if (resp.data() == null) { - return Response.of(batchId, Collections.emptyList()); + EmbeddingResponse embeddingResponse = resp.readEntity(EmbeddingResponse.class); + if (embeddingResponse.data() == null) { + return new Response( + batchId, + Collections.emptyList(), + new VectorizeUsage(ProviderConstants.AZURE_OPENAI, modelName)); } - Arrays.sort(resp.data(), (a, b) -> a.index() - b.index()); + int sentBytes = Integer.parseInt(resp.getHeaderString("sent-bytes")); + int receivedBytes = Integer.parseInt(resp.getHeaderString("received-bytes")); + VectorizeUsage vectorizeUsage = + new VectorizeUsage( + sentBytes, + receivedBytes, + embeddingResponse.usage().total_tokens(), + ProviderConstants.AZURE_OPENAI, + modelName); + Arrays.sort(embeddingResponse.data(), (a, b) -> a.index() - b.index()); List vectors = - Arrays.stream(resp.data()).map(data -> data.embedding()).toList(); - return Response.of(batchId, vectors); + Arrays.stream(embeddingResponse.data()).map(data -> data.embedding()).toList(); + return new Response(batchId, vectors, vectorizeUsage); }); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/CohereEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/CohereEmbeddingProvider.java index c03d0a7138..fa17bee637 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/CohereEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/CohereEmbeddingProvider.java @@ -13,6 +13,7 @@ import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; import jakarta.ws.rs.Path; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.Collections; import java.util.List; @@ -47,11 +48,12 @@ public CohereEmbeddingProvider( @RegisterRestClient @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(NetworkUsageInterceptor.class) public interface CohereEmbeddingProviderClient { @POST @Path("/embed") @ClientHeaderParam(name = "Content-Type", value = "application/json") - Uni embed( + Uni embed( @HeaderParam("Authorization") String accessToken, EmbeddingRequest request); @ClientExceptionMapper @@ -108,6 +110,8 @@ protected EmbeddingResponse() {} private List embeddings; + private BilledUnits billed_units; + public List getEmbeddings() { return embeddings; } @@ -115,6 +119,28 @@ public List getEmbeddings() { public void setEmbeddings(List embeddings) { this.embeddings = embeddings; } + + public BilledUnits getBilled_units() { + return billed_units; + } + + public void setBilled_units(BilledUnits billed_units) { + this.billed_units = billed_units; + } + + private static class BilledUnits { + public int input_tokens; + + public BilledUnits() {} + + public int getInput_tokens() { + return input_tokens; + } + + public void setInput_tokens(int input_tokens) { + this.input_tokens = input_tokens; + } + } } // Input type to be used for vector search should "search_query" @@ -135,7 +161,7 @@ public Uni vectorize( EmbeddingRequest request = new EmbeddingRequest(texts.toArray(textArray), modelName, input_type); - Uni response = + Uni response = applyRetry( cohereEmbeddingProviderClient.embed( "Bearer " + embeddingCredentials.apiKey().get(), request)); @@ -144,10 +170,23 @@ public Uni vectorize( .onItem() .transform( resp -> { - if (resp.getEmbeddings() == null) { - return Response.of(batchId, Collections.emptyList()); + EmbeddingResponse embeddingResponse = resp.readEntity(EmbeddingResponse.class); + if (embeddingResponse.getEmbeddings() == null) { + return new Response( + batchId, + Collections.emptyList(), + new VectorizeUsage(ProviderConstants.COHERE, modelName)); } - return Response.of(batchId, resp.getEmbeddings()); + int sentBytes = Integer.parseInt(resp.getHeaderString("sent-bytes")); + int receivedBytes = Integer.parseInt(resp.getHeaderString("received-bytes")); + VectorizeUsage vectorizeUsage = + new VectorizeUsage( + sentBytes, + receivedBytes, + embeddingResponse.getBilled_units().getInput_tokens(), + ProviderConstants.COHERE, + modelName); + return new Response(batchId, embeddingResponse.getEmbeddings(), vectorizeUsage); }); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProvider.java index aa21ae1af4..aedbba6ed8 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProvider.java @@ -175,11 +175,7 @@ protected void checkEmbeddingApiKeyHeader(String providerId, Optional ap * @param batchId - Sequence number for the batch to order the vectors. * @param embeddings - Embedding vectors for the given text inputs. */ - public record Response(int batchId, List embeddings) { - public static Response of(int batchId, List embeddings) { - return new Response(batchId, embeddings); - } - } + public record Response(int batchId, List embeddings, VectorizeUsage vectorizeUsage) {} public enum EmbeddingRequestType { /** This is used when vectorizing data in write operation for indexing */ diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceDedicatedEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceDedicatedEmbeddingProvider.java index f8c2d01fd6..370fbad7ba 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceDedicatedEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceDedicatedEmbeddingProvider.java @@ -11,6 +11,7 @@ import io.stargate.sgv2.jsonapi.service.embedding.operation.error.HttpResponseErrorMessageMapper; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.*; import java.util.concurrent.TimeUnit; @@ -42,10 +43,11 @@ public HuggingFaceDedicatedEmbeddingProvider( @RegisterRestClient @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(NetworkUsageInterceptor.class) public interface HuggingFaceDedicatedEmbeddingProviderClient { @POST @ClientHeaderParam(name = "Content-Type", value = "application/json") - Uni embed( + Uni embed( @HeaderParam("Authorization") String accessToken, EmbeddingRequest request); @ClientExceptionMapper @@ -107,7 +109,7 @@ public Uni vectorize( String[] textArray = new String[texts.size()]; EmbeddingRequest request = new EmbeddingRequest(texts.toArray(textArray)); - Uni response = + Uni response = applyRetry( huggingFaceDedicatedEmbeddingProviderClient.embed( "Bearer " + embeddingCredentials.apiKey().get(), request)); @@ -116,13 +118,26 @@ public Uni vectorize( .onItem() .transform( resp -> { - if (resp.data() == null) { - return Response.of(batchId, Collections.emptyList()); + EmbeddingResponse embeddingResponse = resp.readEntity(EmbeddingResponse.class); + if (embeddingResponse.data() == null) { + return new Response( + batchId, + Collections.emptyList(), + new VectorizeUsage(ProviderConstants.HUGGINGFACE_DEDICATED, modelName)); } - Arrays.sort(resp.data(), (a, b) -> a.index() - b.index()); + Arrays.sort(embeddingResponse.data(), (a, b) -> a.index() - b.index()); List vectors = - Arrays.stream(resp.data()).map(data -> data.embedding()).toList(); - return Response.of(batchId, vectors); + Arrays.stream(embeddingResponse.data()).map(data -> data.embedding()).toList(); + int sentBytes = Integer.parseInt(resp.getHeaderString("sent-bytes")); + int receivedBytes = Integer.parseInt(resp.getHeaderString("received-bytes")); + VectorizeUsage vectorizeUsage = + new VectorizeUsage( + sentBytes, + receivedBytes, + embeddingResponse.usage().total_tokens(), + ProviderConstants.HUGGINGFACE_DEDICATED, + modelName); + return new Response(batchId, vectors, vectorizeUsage); }); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceEmbeddingProvider.java index 7351267226..1299c43d7b 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceEmbeddingProvider.java @@ -1,6 +1,9 @@ package io.stargate.sgv2.jsonapi.service.embedding.operation; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; import io.quarkus.rest.client.reactive.ClientExceptionMapper; import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; import io.smallrye.mutiny.Uni; @@ -13,6 +16,7 @@ import jakarta.ws.rs.POST; import jakarta.ws.rs.Path; import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.Collections; import java.util.List; @@ -25,6 +29,7 @@ public class HuggingFaceEmbeddingProvider extends EmbeddingProvider { private static final String providerId = ProviderConstants.HUGGINGFACE; private final HuggingFaceEmbeddingProviderClient huggingFaceEmbeddingProviderClient; + private static final ObjectMapper objectMapper = new ObjectMapper(); public HuggingFaceEmbeddingProvider( EmbeddingProviderConfigStore.RequestProperties requestProperties, @@ -43,11 +48,12 @@ public HuggingFaceEmbeddingProvider( @RegisterRestClient @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(NetworkUsageInterceptor.class) public interface HuggingFaceEmbeddingProviderClient { @POST @Path("/{modelId}") @ClientHeaderParam(name = "Content-Type", value = "application/json") - Uni> embed( + Uni embed( @HeaderParam("Authorization") String accessToken, @PathParam("modelId") String modelId, EmbeddingRequest request); @@ -76,8 +82,7 @@ private static String getErrorMessage(jakarta.ws.rs.core.Response response) { JsonNode rootNode = response.readEntity(JsonNode.class); // Log the response body logger.info( - String.format( - "Error response from embedding provider '%s': %s", providerId, rootNode.toString())); + "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); // Extract the "error" node JsonNode errorNode = rootNode.path("error"); // Return the text of the "message" node, or the whole response body if it is missing @@ -104,10 +109,25 @@ public Uni vectorize( .onItem() .transform( resp -> { - if (resp == null) { - return Response.of(batchId, Collections.emptyList()); + String json = resp.readEntity(String.class); // Read raw JSON + if (json == null) { + return new Response( + batchId, + Collections.emptyList(), + new VectorizeUsage(ProviderConstants.HUGGINGFACE, modelName)); } - return Response.of(batchId, resp); + List embeddings = null; + try { + embeddings = objectMapper.readValue(json, new TypeReference>() {}); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + int sentBytes = Integer.parseInt(resp.getHeaderString("sent-bytes")); + int receivedBytes = Integer.parseInt(resp.getHeaderString("received-bytes")); + VectorizeUsage vectorizeUsage = + new VectorizeUsage( + sentBytes, receivedBytes, 0, ProviderConstants.HUGGINGFACE, modelName); + return new Response(batchId, embeddings, vectorizeUsage); }); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/JinaAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/JinaAIEmbeddingProvider.java index 741ebbe495..d2e5b218a2 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/JinaAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/JinaAIEmbeddingProvider.java @@ -12,6 +12,7 @@ import io.stargate.sgv2.jsonapi.service.embedding.operation.error.HttpResponseErrorMessageMapper; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.*; import java.util.concurrent.TimeUnit; @@ -50,10 +51,11 @@ public JinaAIEmbeddingProvider( @RegisterRestClient @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(NetworkUsageInterceptor.class) public interface JinaAIEmbeddingProviderClient { @POST @ClientHeaderParam(name = "Content-Type", value = "application/json") - Uni embed( + Uni embed( @HeaderParam("Authorization") String accessToken, EmbeddingRequest request); @ClientExceptionMapper @@ -83,8 +85,7 @@ private static String getErrorMessage(jakarta.ws.rs.core.Response response) { JsonNode rootNode = response.readEntity(JsonNode.class); // Log the response body logger.info( - String.format( - "Error response from embedding provider '%s': %s", providerId, rootNode.toString())); + "Error response from embedding provider '{}}': {}", providerId, rootNode.toString()); // Extract the "detail" node JsonNode detailNode = rootNode.path("detail"); return detailNode.isMissingNode() ? rootNode.toString() : detailNode.toString(); @@ -121,7 +122,7 @@ public Uni vectorize( (String) vectorizeServiceParameters.get("task"), (Boolean) vectorizeServiceParameters.get("late_chunking")); - Uni response = + Uni response = applyRetry( jinaAIEmbeddingProviderClient.embed( "Bearer " + embeddingCredentials.apiKey().get(), request)); @@ -130,13 +131,28 @@ public Uni vectorize( .onItem() .transform( resp -> { - if (resp.data() == null) { - return Response.of(batchId, Collections.emptyList()); + EmbeddingResponse embeddingResponse = resp.readEntity(EmbeddingResponse.class); + if (embeddingResponse.data() == null) { + return new Response( + batchId, + Collections.emptyList(), + new VectorizeUsage(ProviderConstants.JINA_AI, modelName)); } - Arrays.sort(resp.data(), (a, b) -> a.index() - b.index()); + Arrays.sort(embeddingResponse.data(), (a, b) -> a.index() - b.index()); + int sentBytes = Integer.parseInt(resp.getHeaderString("sent-bytes")); + int receivedBytes = Integer.parseInt(resp.getHeaderString("received-bytes")); + VectorizeUsage vectorizeUsage = + new VectorizeUsage( + sentBytes, + receivedBytes, + embeddingResponse.usage().total_tokens(), + ProviderConstants.JINA_AI, + modelName); List vectors = - Arrays.stream(resp.data()).map(EmbeddingResponse.Data::embedding).toList(); - return Response.of(batchId, vectors); + Arrays.stream(embeddingResponse.data()) + .map(EmbeddingResponse.Data::embedding) + .toList(); + return new Response(batchId, vectors, vectorizeUsage); }); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MeteredEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MeteredEmbeddingProvider.java index 1cc2e3dac3..aeb1213248 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MeteredEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MeteredEmbeddingProvider.java @@ -91,11 +91,13 @@ public Uni vectorize( Collections.sort( vectorizedBatches, (a, b) -> Integer.compare(a.batchId(), b.batchId())); List result = new ArrayList<>(); + VectorizeUsage vectorizeUsage = new VectorizeUsage(); for (Response vectorizedBatch : vectorizedBatches) { // create the final ordered result result.addAll(vectorizedBatch.embeddings()); + vectorizeUsage.merge(vectorizedBatch.vectorizeUsage()); } - return Response.of(1, result); + return new Response(1, result, vectorizeUsage); }) .invoke( () -> diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MistralEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MistralEmbeddingProvider.java index fdab0408c1..c447c9d6e2 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MistralEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MistralEmbeddingProvider.java @@ -11,6 +11,7 @@ import io.stargate.sgv2.jsonapi.service.embedding.operation.error.HttpResponseErrorMessageMapper; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.*; import java.util.concurrent.TimeUnit; @@ -44,10 +45,11 @@ public MistralEmbeddingProvider( @RegisterRestClient @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(NetworkUsageInterceptor.class) public interface MistralEmbeddingProviderClient { @POST @ClientHeaderParam(name = "Content-Type", value = "application/json") - Uni embed( + Uni embed( @HeaderParam("Authorization") String accessToken, EmbeddingRequest request); @ClientExceptionMapper @@ -83,8 +85,7 @@ private static String getErrorMessage(jakarta.ws.rs.core.Response response) { JsonNode rootNode = response.readEntity(JsonNode.class); // Log the response body logger.info( - String.format( - "Error response from embedding provider '%s': %s", providerId, rootNode.toString())); + "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); // Extract the "message" node from the root node JsonNode messageNode = rootNode.path("message"); // Return the text of the "message" node, or the whole response body if it is missing @@ -111,7 +112,7 @@ public Uni vectorize( EmbeddingRequest request = new EmbeddingRequest(texts, modelName, "float"); - Uni response = + Uni response = applyRetry( mistralEmbeddingProviderClient.embed( "Bearer " + embeddingCredentials.apiKey().get(), request)); @@ -120,13 +121,26 @@ public Uni vectorize( .onItem() .transform( resp -> { - if (resp.data() == null) { - return Response.of(batchId, Collections.emptyList()); + EmbeddingResponse embeddingResponse = resp.readEntity(EmbeddingResponse.class); + if (embeddingResponse.data() == null) { + return new Response( + batchId, + Collections.emptyList(), + new VectorizeUsage(ProviderConstants.MISTRAL, modelName)); } - Arrays.sort(resp.data(), (a, b) -> a.index() - b.index()); + Arrays.sort(embeddingResponse.data(), (a, b) -> a.index() - b.index()); + int sentBytes = Integer.parseInt(resp.getHeaderString("sent-bytes")); + int receivedBytes = Integer.parseInt(resp.getHeaderString("received-bytes")); + VectorizeUsage vectorizeUsage = + new VectorizeUsage( + sentBytes, + receivedBytes, + embeddingResponse.usage().total_tokens(), + ProviderConstants.MISTRAL, + modelName); List vectors = - Arrays.stream(resp.data()).map(data -> data.embedding()).toList(); - return Response.of(batchId, vectors); + Arrays.stream(embeddingResponse.data()).map(data -> data.embedding()).toList(); + return new Response(batchId, vectors, vectorizeUsage); }); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NetworkUsageInterceptor.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NetworkUsageInterceptor.java new file mode 100644 index 0000000000..3206e05d8c --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NetworkUsageInterceptor.java @@ -0,0 +1,47 @@ +package io.stargate.sgv2.jsonapi.service.embedding.operation; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.io.ByteStreams; +import com.google.common.io.CountingOutputStream; +import jakarta.ws.rs.client.ClientRequestContext; +import jakarta.ws.rs.client.ClientResponseContext; +import jakarta.ws.rs.client.ClientResponseFilter; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class NetworkUsageInterceptor implements ClientResponseFilter { + + private static final Logger LOGGER = LoggerFactory.getLogger(NetworkUsageInterceptor.class); + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); // Jackson object mapper + + @Override + public void filter(ClientRequestContext requestContext, ClientResponseContext responseContext) + throws IOException { + int receivedBytes = 0; + int sentBytes = 0; + if (requestContext.hasEntity()) { + try { + CountingOutputStream cus = new CountingOutputStream(OutputStream.nullOutputStream()); + OBJECT_MAPPER.writeValue(cus, requestContext.getEntity()); + cus.close(); + sentBytes = (int) cus.getCount(); + } catch (Exception e) { + LOGGER.warn("Failed to measure request body size: " + e.getMessage()); + } + } + if (responseContext.hasEntity()) { + receivedBytes = responseContext.getLength(); + if (receivedBytes <= 0) { + // Read the response entity stream to measure its size + InputStream inputStream = responseContext.getEntityStream(); + receivedBytes = (int) ByteStreams.copy(inputStream, OutputStream.nullOutputStream()); + } + } + + responseContext.getHeaders().add("sent-bytes", String.valueOf(sentBytes)); + responseContext.getHeaders().add("received-bytes", String.valueOf(receivedBytes)); + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NvidiaEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NvidiaEmbeddingProvider.java index 4ab62e3f19..d275dda739 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NvidiaEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NvidiaEmbeddingProvider.java @@ -12,6 +12,7 @@ import io.stargate.sgv2.jsonapi.service.embedding.operation.error.HttpResponseErrorMessageMapper; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.Arrays; import java.util.Collections; @@ -47,10 +48,11 @@ public NvidiaEmbeddingProvider( @RegisterRestClient @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(NetworkUsageInterceptor.class) public interface NvidiaEmbeddingProviderClient { @POST @ClientHeaderParam(name = "Content-Type", value = "application/json") - Uni embed( + Uni embed( @HeaderParam("Authorization") String accessToken, EmbeddingRequest request); @ClientExceptionMapper @@ -79,8 +81,7 @@ private static String getErrorMessage(jakarta.ws.rs.core.Response response) { JsonNode rootNode = response.readEntity(JsonNode.class); // Log the response body logger.info( - String.format( - "Error response from embedding provider '%s': %s", providerId, rootNode.toString())); + "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); JsonNode messageNode = rootNode.path("message"); // Return the text of the "message" node, or the whole response body if it is missing return messageNode.isMissingNode() ? rootNode.toString() : messageNode.toString(); @@ -114,20 +115,33 @@ public Uni vectorize( EmbeddingRequest request = new EmbeddingRequest(texts.toArray(textArray), modelName, input_type); - Uni response = + Uni response = applyRetry(nvidiaEmbeddingProviderClient.embed("Bearer ", request)); return response .onItem() .transform( resp -> { - if (resp.data() == null) { - return Response.of(batchId, Collections.emptyList()); + EmbeddingResponse embeddingResponse = resp.readEntity(EmbeddingResponse.class); + if (embeddingResponse.data() == null) { + return new Response( + batchId, + Collections.emptyList(), + new VectorizeUsage(ProviderConstants.NVIDIA, modelName)); } - Arrays.sort(resp.data(), (a, b) -> a.index() - b.index()); + Arrays.sort(embeddingResponse.data(), (a, b) -> a.index() - b.index()); + int sentBytes = Integer.parseInt(resp.getHeaderString("sent-bytes")); + int receivedBytes = Integer.parseInt(resp.getHeaderString("received-bytes")); + VectorizeUsage vectorizeUsage = + new VectorizeUsage( + sentBytes, + receivedBytes, + embeddingResponse.usage().total_tokens(), + ProviderConstants.NVIDIA, + modelName); List vectors = - Arrays.stream(resp.data()).map(data -> data.embedding()).toList(); - return Response.of(batchId, vectors); + Arrays.stream(embeddingResponse.data()).map(data -> data.embedding()).toList(); + return new Response(batchId, vectors, vectorizeUsage); }); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAIEmbeddingProvider.java index 8c5077f2ff..7a22ac1f00 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAIEmbeddingProvider.java @@ -10,6 +10,7 @@ import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; import io.stargate.sgv2.jsonapi.service.embedding.operation.error.HttpResponseErrorMessageMapper; +import jakarta.enterprise.context.control.ActivateRequestContext; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; import jakarta.ws.rs.Path; @@ -23,6 +24,7 @@ import org.eclipse.microprofile.rest.client.annotation.RegisterProvider; import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; +@ActivateRequestContext public class OpenAIEmbeddingProvider extends EmbeddingProvider { private static final String providerId = ProviderConstants.OPENAI; private final OpenAIEmbeddingProviderClient openAIEmbeddingProviderClient; @@ -50,11 +52,12 @@ public OpenAIEmbeddingProvider( @RegisterRestClient @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(NetworkUsageInterceptor.class) public interface OpenAIEmbeddingProviderClient { @POST @Path("/embeddings") @ClientHeaderParam(name = "Content-Type", value = "application/json") - Uni embed( + Uni embed( @HeaderParam("Authorization") String accessToken, @HeaderParam("OpenAI-Organization") String organizationId, @HeaderParam("OpenAI-Project") String projectId, @@ -90,8 +93,7 @@ private static String getErrorMessage(jakarta.ws.rs.core.Response response) { JsonNode rootNode = response.readEntity(JsonNode.class); // Log the response body logger.info( - String.format( - "Error response from embedding provider '%s': %s", providerId, rootNode.toString())); + "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); // Extract the "message" node from the "error" node JsonNode messageNode = rootNode.at("/error/message"); // Return the text of the "message" node, or the whole response body if it is missing @@ -122,25 +124,39 @@ public Uni vectorize( String organizationId = (String) vectorizeServiceParameters.get("organizationId"); String projectId = (String) vectorizeServiceParameters.get("projectId"); - Uni response = + // ✅ Create an instance of NetworkUsageInfo and pass it to request properties + Uni response = applyRetry( openAIEmbeddingProviderClient.embed( "Bearer " + embeddingCredentials.apiKey().get(), organizationId, projectId, - request)); + request)); // Pass the object dynamically return response .onItem() .transform( - resp -> { - if (resp.data() == null) { - return Response.of(batchId, Collections.emptyList()); + res -> { + EmbeddingResponse embeddingResponse = res.readEntity(EmbeddingResponse.class); + if (embeddingResponse.data() == null) { + return new Response( + batchId, + Collections.emptyList(), + new VectorizeUsage(ProviderConstants.OPENAI, modelName)); } - Arrays.sort(resp.data(), (a, b) -> a.index() - b.index()); + int sentBytes = Integer.parseInt(res.getHeaderString("sent-bytes")); + int receivedBytes = Integer.parseInt(res.getHeaderString("received-bytes")); + VectorizeUsage vectorizeUsage = + new VectorizeUsage( + sentBytes, + receivedBytes, + embeddingResponse.usage().total_tokens(), + ProviderConstants.OPENAI, + modelName); + Arrays.sort(embeddingResponse.data(), (a, b) -> a.index() - b.index()); List vectors = - Arrays.stream(resp.data()).map(data -> data.embedding()).toList(); - return Response.of(batchId, vectors); + Arrays.stream(embeddingResponse.data()).map(data -> data.embedding()).toList(); + return new Response(batchId, vectors, vectorizeUsage); }); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/UpstageAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/UpstageAIEmbeddingProvider.java index 8ab72b1a6b..15d954696c 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/UpstageAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/UpstageAIEmbeddingProvider.java @@ -13,6 +13,7 @@ import io.stargate.sgv2.jsonapi.service.embedding.operation.error.HttpResponseErrorMessageMapper; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.Arrays; import java.util.Collections; @@ -48,11 +49,12 @@ public UpstageAIEmbeddingProvider( @RegisterRestClient @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(NetworkUsageInterceptor.class) public interface UpstageAIEmbeddingProviderClient { @POST // no path specified, as it is already included in the baseUri @ClientHeaderParam(name = "Content-Type", value = "application/json") - Uni embed( + Uni embed( @HeaderParam("Authorization") String accessToken, EmbeddingRequest request); @ClientExceptionMapper @@ -88,8 +90,7 @@ private static String getErrorMessage(jakarta.ws.rs.core.Response response) { JsonNode rootNode = response.readEntity(JsonNode.class); // Log the response body logger.info( - String.format( - "Error response from embedding provider '%s': %s", providerId, rootNode.toString())); + "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); // Check if the root node contains a "message" field JsonNode messageNode = rootNode.path("message"); if (!messageNode.isMissingNode()) { @@ -139,7 +140,7 @@ public Uni vectorize( EmbeddingRequest request = new EmbeddingRequest(texts.get(0), modelName); - Uni response = + Uni response = applyRetry( upstageAIEmbeddingProviderClient.embed( "Bearer " + embeddingCredentials.apiKey().get(), request)); @@ -148,13 +149,26 @@ public Uni vectorize( .onItem() .transform( resp -> { - if (resp.data() == null) { - return Response.of(batchId, Collections.emptyList()); + EmbeddingResponse embeddingResponse = resp.readEntity(EmbeddingResponse.class); + if (embeddingResponse.data() == null) { + return new Response( + batchId, + Collections.emptyList(), + new VectorizeUsage(ProviderConstants.UPSTAGE_AI, modelName)); } - Arrays.sort(resp.data(), (a, b) -> a.index() - b.index()); + Arrays.sort(embeddingResponse.data(), (a, b) -> a.index() - b.index()); + int sentBytes = Integer.parseInt(resp.getHeaderString("sent-bytes")); + int receivedBytes = Integer.parseInt(resp.getHeaderString("received-bytes")); + VectorizeUsage vectorizeUsage = + new VectorizeUsage( + sentBytes, + receivedBytes, + embeddingResponse.usage().total_tokens(), + ProviderConstants.UPSTAGE_AI, + modelName); List vectors = - Arrays.stream(resp.data()).map(data -> data.embedding()).toList(); - return Response.of(batchId, vectors); + Arrays.stream(embeddingResponse.data()).map(data -> data.embedding()).toList(); + return new Response(batchId, vectors, vectorizeUsage); }); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VectorizeUsage.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VectorizeUsage.java new file mode 100644 index 0000000000..f804dee67d --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VectorizeUsage.java @@ -0,0 +1,69 @@ +package io.stargate.sgv2.jsonapi.service.embedding.operation; + +/** Used to track the metric at a request level to the embedding service */ +public class VectorizeUsage { + private int requestBytes = 0; + private int responseBytes = 0; + private int totalTokens = 0; + private String provider = ""; + private String model = ""; + + public VectorizeUsage( + int requestBytes, int responseBytes, int totalTokens, String provider, String model) { + this.requestBytes = requestBytes; + this.responseBytes = responseBytes; + this.totalTokens = totalTokens; + this.provider = provider; + this.model = model; + } + + public VectorizeUsage() { + super(); + } + + public VectorizeUsage(String provider, String model) { + super(); + this.provider = provider; + this.model = model; + } + + public void merge(VectorizeUsage vectorizeUsage) { + this.requestBytes += vectorizeUsage.getRequestBytes(); + this.responseBytes += vectorizeUsage.getResponseBytes(); + this.totalTokens += vectorizeUsage.getTotalTokens(); + this.provider = vectorizeUsage.getProvider(); + this.model = vectorizeUsage.getModel(); + } + + public int getRequestBytes() { + return requestBytes; + } + + public int getResponseBytes() { + return responseBytes; + } + + public int getTotalTokens() { + return totalTokens; + } + + public String getProvider() { + return provider; + } + + public String getModel() { + return model; + } + + public void setRequestBytes(int requestBytes) { + this.requestBytes = requestBytes; + } + + public void setResponseBytes(int responseBytes) { + this.responseBytes = responseBytes; + } + + public void setTotalTokens(int totalTokens) { + this.totalTokens = totalTokens; + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VertexAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VertexAIEmbeddingProvider.java index 76efccfd9f..f17dec6d53 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VertexAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VertexAIEmbeddingProvider.java @@ -14,6 +14,7 @@ import jakarta.ws.rs.POST; import jakarta.ws.rs.Path; import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.Collections; import java.util.List; @@ -46,11 +47,12 @@ public VertexAIEmbeddingProvider( @RegisterRestClient @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(NetworkUsageInterceptor.class) public interface VertexAIEmbeddingProviderClient { @POST @Path("/{modelId}:predict") @ClientHeaderParam(name = "Content-Type", value = "application/json") - Uni embed( + Uni embed( @HeaderParam("Authorization") String accessToken, @PathParam("modelId") String modelId, EmbeddingRequest request); @@ -77,8 +79,7 @@ private static String getErrorMessage(jakarta.ws.rs.core.Response response) { JsonNode rootNode = response.readEntity(JsonNode.class); // Log the response body logger.info( - String.format( - "Error response from embedding provider '%s': %s", providerId, rootNode.toString())); + "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); return rootNode.toString(); } } @@ -159,7 +160,7 @@ public Uni vectorize( EmbeddingRequest request = new EmbeddingRequest(texts.stream().map(t -> new EmbeddingRequest.Content(t)).toList()); - Uni serviceResponse = + Uni serviceResponse = applyRetry( vertexAIEmbeddingProviderClient.embed( "Bearer " + embeddingCredentials.apiKey().get(), modelName, request)); @@ -167,15 +168,24 @@ public Uni vectorize( return serviceResponse .onItem() .transform( - response -> { - if (response.getPredictions() == null) { - return Response.of(batchId, Collections.emptyList()); + resp -> { + EmbeddingResponse embeddingResponse = resp.readEntity(EmbeddingResponse.class); + if (embeddingResponse.getPredictions() == null) { + return new Response( + batchId, + Collections.emptyList(), + new VectorizeUsage(ProviderConstants.VERTEXAI, modelName)); } List vectors = - response.getPredictions().stream() + embeddingResponse.getPredictions().stream() .map(prediction -> prediction.getEmbeddings().getValues()) .collect(Collectors.toList()); - return Response.of(batchId, vectors); + int sentBytes = Integer.parseInt(resp.getHeaderString("sent-bytes")); + int receivedBytes = Integer.parseInt(resp.getHeaderString("received-bytes")); + VectorizeUsage vectorizeUsage = + new VectorizeUsage( + sentBytes, receivedBytes, 0, ProviderConstants.VERTEXAI, modelName); + return new Response(batchId, vectors, vectorizeUsage); }); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VoyageAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VoyageAIEmbeddingProvider.java index c7b1289593..62278f9f96 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VoyageAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VoyageAIEmbeddingProvider.java @@ -13,6 +13,7 @@ import io.stargate.sgv2.jsonapi.service.embedding.operation.error.HttpResponseErrorMessageMapper; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.Arrays; import java.util.Collections; @@ -52,11 +53,12 @@ public VoyageAIEmbeddingProvider( @RegisterRestClient @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(NetworkUsageInterceptor.class) public interface VoyageAIEmbeddingProviderClient { @POST // no path specified, as it is already included in the baseUri @ClientHeaderParam(name = "Content-Type", value = "application/json") - Uni embed( + Uni embed( @HeaderParam("Authorization") String accessToken, EmbeddingRequest request); @ClientExceptionMapper @@ -118,7 +120,7 @@ public Uni vectorize( EmbeddingRequest request = new EmbeddingRequest(inputType, texts.toArray(textArray), modelName, autoTruncate); - Uni response = + Uni response = applyRetry( voyageAIEmbeddingProviderClient.embed( "Bearer " + embeddingCredentials.apiKey().get(), request)); @@ -127,13 +129,26 @@ public Uni vectorize( .onItem() .transform( resp -> { - if (resp.data() == null) { - return Response.of(batchId, Collections.emptyList()); + EmbeddingResponse embeddingResponse = resp.readEntity(EmbeddingResponse.class); + if (embeddingResponse.data() == null) { + return new Response( + batchId, + Collections.emptyList(), + new VectorizeUsage(ProviderConstants.VOYAGE_AI, modelName)); } - Arrays.sort(resp.data(), (a, b) -> a.index() - b.index()); + Arrays.sort(embeddingResponse.data(), (a, b) -> a.index() - b.index()); + int sentBytes = Integer.parseInt(resp.getHeaderString("sent-bytes")); + int receivedBytes = Integer.parseInt(resp.getHeaderString("received-bytes")); + VectorizeUsage vectorizeUsage = + new VectorizeUsage( + sentBytes, + receivedBytes, + embeddingResponse.usage().total_tokens(), + ProviderConstants.VOYAGE_AI, + modelName); List vectors = - Arrays.stream(resp.data()).map(data -> data.embedding()).toList(); - return Response.of(batchId, vectors); + Arrays.stream(embeddingResponse.data()).map(data -> data.embedding()).toList(); + return new Response(batchId, vectors, vectorizeUsage); }); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/test/CustomITEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/test/CustomITEmbeddingProvider.java index 1dba9e52ff..17c99639b8 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/test/CustomITEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/test/CustomITEmbeddingProvider.java @@ -4,6 +4,7 @@ import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProvider; +import io.stargate.sgv2.jsonapi.service.embedding.operation.VectorizeUsage; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -74,7 +75,8 @@ public Uni vectorize( EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { List response = new ArrayList<>(texts.size()); - if (texts.size() == 0) return Uni.createFrom().item(Response.of(batchId, response)); + if (texts.size() == 0) + return Uni.createFrom().item(new Response(batchId, response, new VectorizeUsage())); if (!embeddingCredentials.apiKey().isPresent() || !embeddingCredentials.apiKey().get().equals(TEST_API_KEY)) return Uni.createFrom().failure(new RuntimeException("Invalid API Key")); @@ -94,7 +96,7 @@ public Uni vectorize( } } } - return Uni.createFrom().item(Response.of(batchId, response)); + return Uni.createFrom().item(new Response(batchId, response, new VectorizeUsage())); } @Override diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java index dd92b1affb..173ba5f42c 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java @@ -195,7 +195,8 @@ public Uni vectorize( texts.forEach(t -> customResponse.add(new float[] {0.5f, 0.5f, 0.5f})); // add additional vector customResponse.add(new float[] {0.5f, 0.5f, 0.5f}); - return Uni.createFrom().item(Response.of(batchId, customResponse)); + return Uni.createFrom() + .item(new Response(batchId, customResponse, new VectorizeUsage())); } }; List documents = new ArrayList<>(); diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/TestEmbeddingProvider.java b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/TestEmbeddingProvider.java index d2861e9c2b..586bfbfe73 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/TestEmbeddingProvider.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/TestEmbeddingProvider.java @@ -48,7 +48,7 @@ public Uni vectorize( if (t.equals("return 1s")) response.add(new float[] {1.0f, 1.0f, 1.0f}); else response.add(new float[] {0.25f, 0.25f, 0.25f}); }); - return Uni.createFrom().item(Response.of(batchId, response)); + return Uni.createFrom().item(new Response(batchId, response, new VectorizeUsage())); } @Override