Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import io.stargate.sgv2.jsonapi.exception.JsonApiException;
import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore;
import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProvider;
import io.stargate.sgv2.jsonapi.service.embedding.operation.VectorizeUsage;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -189,9 +190,14 @@ else if (value instanceof Boolean)
ErrorCodeV1.valueOf(resp.getError().getErrorCode()),
resp.getError().getErrorMessage());
}
VectorizeUsage vectorizeUsage = new VectorizeUsage(provider, modelName);
if (resp.getEmbeddingsList() == null) {
return Response.of(batchId, Collections.emptyList());
return new Response(batchId, Collections.emptyList(), vectorizeUsage);
}
EmbeddingGateway.EmbeddingResponse.Usage usage = resp.getUsage();
vectorizeUsage.setRequestBytes(usage.getInputBytes());
vectorizeUsage.setResponseBytes(usage.getOutputBytes());
vectorizeUsage.setTotalTokens(usage.getTotalTokens());
final List<float[]> vectors =
resp.getEmbeddingsList().stream()
.map(
Expand All @@ -203,7 +209,7 @@ else if (value instanceof Boolean)
return embedding;
})
.toList();
return Response.of(batchId, vectors);
return new Response(batchId, vectors, vectorizeUsage);
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.ObjectReader;
import com.fasterxml.jackson.databind.ObjectWriter;
import com.google.common.io.ByteStreams;
import io.smallrye.mutiny.Uni;
import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials;
import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1;
import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore;
import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
Expand Down Expand Up @@ -78,12 +81,14 @@ public Uni<Response> vectorize(
.credentialsProvider(StaticCredentialsProvider.create(awsCreds))
.region(Region.of(vectorizeServiceParameters.get("region").toString()))
.build();
final VectorizeUsage vectorizeUsage = new VectorizeUsage(ProviderConstants.BEDROCK, modelName);
final CompletableFuture<InvokeModelResponse> invokeModelResponseCompletableFuture =
client.invokeModel(
request -> {
final byte[] inputData;
try {
inputData = ow.writeValueAsBytes(new EmbeddingRequest(texts.get(0), dimension));
vectorizeUsage.setRequestBytes(inputData.length);
request.body(SdkBytes.fromByteArray(inputData)).modelId(modelName);
} catch (JsonProcessingException e) {
throw ErrorCodeV1.EMBEDDING_REQUEST_ENCODING_ERROR.toApiException();
Expand All @@ -94,10 +99,15 @@ public Uni<Response> vectorize(
invokeModelResponseCompletableFuture.thenApply(
res -> {
try {
EmbeddingResponse response =
or.readValue(res.body().asInputStream(), EmbeddingResponse.class);
InputStream inputStream = res.body().asInputStream();
int receivedBytes =
(int) ByteStreams.copy(inputStream, OutputStream.nullOutputStream());
EmbeddingResponse response = or.readValue(inputStream, EmbeddingResponse.class);
vectorizeUsage.setResponseBytes(receivedBytes);
List<float[]> vectors = List.of(response.embedding);
return Response.of(batchId, vectors);
// Note: This is input tokens
vectorizeUsage.setTotalTokens(response.inputTextTokenCount());
return new Response(batchId, vectors, vectorizeUsage);
} catch (IOException e) {
throw ErrorCodeV1.EMBEDDING_RESPONSE_DECODING_ERROR.toApiException();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import io.stargate.sgv2.jsonapi.service.embedding.operation.error.HttpResponseErrorMessageMapper;
import jakarta.ws.rs.HeaderParam;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.core.Response;
import java.net.URI;
import java.util.Arrays;
import java.util.Collections;
Expand Down Expand Up @@ -55,11 +56,12 @@ public AzureOpenAIEmbeddingProvider(

@RegisterRestClient
@RegisterProvider(EmbeddingProviderResponseValidation.class)
@RegisterProvider(NetworkUsageInterceptor.class)
public interface OpenAIEmbeddingProviderClient {
@POST
// no path specified, as it is already included in the baseUri
@ClientHeaderParam(name = "Content-Type", value = "application/json")
Uni<EmbeddingResponse> embed(
Uni<jakarta.ws.rs.core.Response> embed(
// API keys as "api-key", MS Entra as "Authorization: Bearer [token]
@HeaderParam("api-key") String accessToken, EmbeddingRequest request);

Expand Down Expand Up @@ -120,21 +122,34 @@ public Uni<Response> vectorize(
EmbeddingRequest request = new EmbeddingRequest(texts.toArray(textArray), modelName, dimension);

// NOTE: NO "Bearer " prefix with API key for Azure OpenAI
Uni<EmbeddingResponse> response =
Uni<jakarta.ws.rs.core.Response> response =
applyRetry(
openAIEmbeddingProviderClient.embed(embeddingCredentials.apiKey().get(), request));

return response
.onItem()
.transform(
resp -> {
if (resp.data() == null) {
return Response.of(batchId, Collections.emptyList());
EmbeddingResponse embeddingResponse = resp.readEntity(EmbeddingResponse.class);
if (embeddingResponse.data() == null) {
return new Response(
batchId,
Collections.emptyList(),
new VectorizeUsage(ProviderConstants.AZURE_OPENAI, modelName));
}
Arrays.sort(resp.data(), (a, b) -> a.index() - b.index());
int sentBytes = Integer.parseInt(resp.getHeaderString("sent-bytes"));
int receivedBytes = Integer.parseInt(resp.getHeaderString("received-bytes"));
VectorizeUsage vectorizeUsage =
new VectorizeUsage(
sentBytes,
receivedBytes,
embeddingResponse.usage().total_tokens(),
ProviderConstants.AZURE_OPENAI,
modelName);
Arrays.sort(embeddingResponse.data(), (a, b) -> a.index() - b.index());
List<float[]> vectors =
Arrays.stream(resp.data()).map(data -> data.embedding()).toList();
return Response.of(batchId, vectors);
Arrays.stream(embeddingResponse.data()).map(data -> data.embedding()).toList();
return new Response(batchId, vectors, vectorizeUsage);
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import jakarta.ws.rs.HeaderParam;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.core.Response;
import java.net.URI;
import java.util.Collections;
import java.util.List;
Expand Down Expand Up @@ -47,11 +48,12 @@ public CohereEmbeddingProvider(

@RegisterRestClient
@RegisterProvider(EmbeddingProviderResponseValidation.class)
@RegisterProvider(NetworkUsageInterceptor.class)
public interface CohereEmbeddingProviderClient {
@POST
@Path("/embed")
@ClientHeaderParam(name = "Content-Type", value = "application/json")
Uni<EmbeddingResponse> embed(
Uni<jakarta.ws.rs.core.Response> embed(
@HeaderParam("Authorization") String accessToken, EmbeddingRequest request);

@ClientExceptionMapper
Expand Down Expand Up @@ -108,13 +110,37 @@ protected EmbeddingResponse() {}

private List<float[]> embeddings;

private BilledUnits billed_units;

public List<float[]> getEmbeddings() {
return embeddings;
}

public void setEmbeddings(List<float[]> embeddings) {
this.embeddings = embeddings;
}

public BilledUnits getBilled_units() {
return billed_units;
}

public void setBilled_units(BilledUnits billed_units) {
this.billed_units = billed_units;
}

private static class BilledUnits {
public int input_tokens;

public BilledUnits() {}

public int getInput_tokens() {
return input_tokens;
}

public void setInput_tokens(int input_tokens) {
this.input_tokens = input_tokens;
}
}
}

// Input type to be used for vector search should "search_query"
Expand All @@ -135,7 +161,7 @@ public Uni<Response> vectorize(
EmbeddingRequest request =
new EmbeddingRequest(texts.toArray(textArray), modelName, input_type);

Uni<EmbeddingResponse> response =
Uni<jakarta.ws.rs.core.Response> response =
applyRetry(
cohereEmbeddingProviderClient.embed(
"Bearer " + embeddingCredentials.apiKey().get(), request));
Expand All @@ -144,10 +170,23 @@ public Uni<Response> vectorize(
.onItem()
.transform(
resp -> {
if (resp.getEmbeddings() == null) {
return Response.of(batchId, Collections.emptyList());
EmbeddingResponse embeddingResponse = resp.readEntity(EmbeddingResponse.class);
if (embeddingResponse.getEmbeddings() == null) {
return new Response(
batchId,
Collections.emptyList(),
new VectorizeUsage(ProviderConstants.COHERE, modelName));
}
return Response.of(batchId, resp.getEmbeddings());
int sentBytes = Integer.parseInt(resp.getHeaderString("sent-bytes"));
int receivedBytes = Integer.parseInt(resp.getHeaderString("received-bytes"));
VectorizeUsage vectorizeUsage =
new VectorizeUsage(
sentBytes,
receivedBytes,
embeddingResponse.getBilled_units().getInput_tokens(),
ProviderConstants.COHERE,
modelName);
return new Response(batchId, embeddingResponse.getEmbeddings(), vectorizeUsage);
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,7 @@ protected void checkEmbeddingApiKeyHeader(String providerId, Optional<String> ap
* @param batchId - Sequence number for the batch to order the vectors.
* @param embeddings - Embedding vectors for the given text inputs.
*/
public record Response(int batchId, List<float[]> embeddings) {
public static Response of(int batchId, List<float[]> embeddings) {
return new Response(batchId, embeddings);
}
}
public record Response(int batchId, List<float[]> embeddings, VectorizeUsage vectorizeUsage) {}

public enum EmbeddingRequestType {
/** This is used when vectorizing data in write operation for indexing */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import io.stargate.sgv2.jsonapi.service.embedding.operation.error.HttpResponseErrorMessageMapper;
import jakarta.ws.rs.HeaderParam;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.core.Response;
import java.net.URI;
import java.util.*;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -42,10 +43,11 @@ public HuggingFaceDedicatedEmbeddingProvider(

@RegisterRestClient
@RegisterProvider(EmbeddingProviderResponseValidation.class)
@RegisterProvider(NetworkUsageInterceptor.class)
public interface HuggingFaceDedicatedEmbeddingProviderClient {
@POST
@ClientHeaderParam(name = "Content-Type", value = "application/json")
Uni<EmbeddingResponse> embed(
Uni<jakarta.ws.rs.core.Response> embed(
@HeaderParam("Authorization") String accessToken, EmbeddingRequest request);

@ClientExceptionMapper
Expand Down Expand Up @@ -107,7 +109,7 @@ public Uni<Response> vectorize(
String[] textArray = new String[texts.size()];
EmbeddingRequest request = new EmbeddingRequest(texts.toArray(textArray));

Uni<EmbeddingResponse> response =
Uni<jakarta.ws.rs.core.Response> response =
applyRetry(
huggingFaceDedicatedEmbeddingProviderClient.embed(
"Bearer " + embeddingCredentials.apiKey().get(), request));
Expand All @@ -116,13 +118,26 @@ public Uni<Response> vectorize(
.onItem()
.transform(
resp -> {
if (resp.data() == null) {
return Response.of(batchId, Collections.emptyList());
EmbeddingResponse embeddingResponse = resp.readEntity(EmbeddingResponse.class);
if (embeddingResponse.data() == null) {
return new Response(
batchId,
Collections.emptyList(),
new VectorizeUsage(ProviderConstants.HUGGINGFACE_DEDICATED, modelName));
}
Arrays.sort(resp.data(), (a, b) -> a.index() - b.index());
Arrays.sort(embeddingResponse.data(), (a, b) -> a.index() - b.index());
List<float[]> vectors =
Arrays.stream(resp.data()).map(data -> data.embedding()).toList();
return Response.of(batchId, vectors);
Arrays.stream(embeddingResponse.data()).map(data -> data.embedding()).toList();
int sentBytes = Integer.parseInt(resp.getHeaderString("sent-bytes"));
int receivedBytes = Integer.parseInt(resp.getHeaderString("received-bytes"));
VectorizeUsage vectorizeUsage =
new VectorizeUsage(
sentBytes,
receivedBytes,
embeddingResponse.usage().total_tokens(),
ProviderConstants.HUGGINGFACE_DEDICATED,
modelName);
return new Response(batchId, vectors, vectorizeUsage);
});
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package io.stargate.sgv2.jsonapi.service.embedding.operation;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.quarkus.rest.client.reactive.ClientExceptionMapper;
import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder;
import io.smallrye.mutiny.Uni;
Expand All @@ -13,6 +16,7 @@
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.PathParam;
import jakarta.ws.rs.core.Response;
import java.net.URI;
import java.util.Collections;
import java.util.List;
Expand All @@ -25,6 +29,7 @@
public class HuggingFaceEmbeddingProvider extends EmbeddingProvider {
private static final String providerId = ProviderConstants.HUGGINGFACE;
private final HuggingFaceEmbeddingProviderClient huggingFaceEmbeddingProviderClient;
private static final ObjectMapper objectMapper = new ObjectMapper();

public HuggingFaceEmbeddingProvider(
EmbeddingProviderConfigStore.RequestProperties requestProperties,
Expand All @@ -43,11 +48,12 @@ public HuggingFaceEmbeddingProvider(

@RegisterRestClient
@RegisterProvider(EmbeddingProviderResponseValidation.class)
@RegisterProvider(NetworkUsageInterceptor.class)
public interface HuggingFaceEmbeddingProviderClient {
@POST
@Path("/{modelId}")
@ClientHeaderParam(name = "Content-Type", value = "application/json")
Uni<List<float[]>> embed(
Uni<jakarta.ws.rs.core.Response> embed(
@HeaderParam("Authorization") String accessToken,
@PathParam("modelId") String modelId,
EmbeddingRequest request);
Expand Down Expand Up @@ -76,8 +82,7 @@ private static String getErrorMessage(jakarta.ws.rs.core.Response response) {
JsonNode rootNode = response.readEntity(JsonNode.class);
// Log the response body
logger.info(
String.format(
"Error response from embedding provider '%s': %s", providerId, rootNode.toString()));
"Error response from embedding provider '{}': {}", providerId, rootNode.toString());
// Extract the "error" node
JsonNode errorNode = rootNode.path("error");
// Return the text of the "message" node, or the whole response body if it is missing
Expand All @@ -104,10 +109,25 @@ public Uni<Response> vectorize(
.onItem()
.transform(
resp -> {
if (resp == null) {
return Response.of(batchId, Collections.emptyList());
String json = resp.readEntity(String.class); // Read raw JSON
if (json == null) {
return new Response(
batchId,
Collections.emptyList(),
new VectorizeUsage(ProviderConstants.HUGGINGFACE, modelName));
}
return Response.of(batchId, resp);
List<float[]> embeddings = null;
try {
embeddings = objectMapper.readValue(json, new TypeReference<List<float[]>>() {});
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
int sentBytes = Integer.parseInt(resp.getHeaderString("sent-bytes"));
int receivedBytes = Integer.parseInt(resp.getHeaderString("received-bytes"));
VectorizeUsage vectorizeUsage =
new VectorizeUsage(
sentBytes, receivedBytes, 0, ProviderConstants.HUGGINGFACE, modelName);
return new Response(batchId, embeddings, vectorizeUsage);
});
}

Expand Down
Loading
Loading