Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
e69667c
init
Yuqi-Du Apr 14, 2025
b50241a
java doc
Yuqi-Du Apr 14, 2025
1c1f1d8
WIP - code changes, compiles,
amorton Jun 10, 2025
e1db443
WIP - basics working on laptop, checking regressions
amorton Jun 11, 2025
120993c
tmp
amorton Jun 13, 2025
88a1d2b
finished merge from main
amorton Jun 15, 2025
a024699
fmt
amorton Jun 15, 2025
19fa063
Use ServiceProviderConfig
amorton Jun 15, 2025
7d8a6cb
EmbeddingGatewayClientTest fixes
amorton Jun 16, 2025
28bca03
EmbeddingProviderErrorMessageTest fixes
amorton Jun 16, 2025
7ce2af7
OpenAiEmbeddingClientTest fixes
amorton Jun 16, 2025
501871e
fmt
amorton Jun 16, 2025
618961e
RerankingProviderTest fixes
amorton Jun 16, 2025
9c72b30
CommandResolverWithVectorizerTest fixes
amorton Jun 16, 2025
a213c80
fix for vectorize IT's that used custom provider
amorton Jun 17, 2025
396e652
fmt
amorton Jun 17, 2025
131ad4e
Merge branch 'main' into yuqi/rerank-metering
amorton Jun 17, 2025
2030562
fixes missed from merge
amorton Jun 17, 2025
cc9b48f
code tidy
amorton Jun 17, 2025
9b215b8
InsertOneTableIntegrationTest fixes
amorton Jun 17, 2025
f65e364
fmt
amorton Jun 17, 2025
1a5c956
cody tidy
amorton Jun 17, 2025
599aa92
changes from review
amorton Jun 17, 2025
070f526
fixes for EGW to use the model usage
amorton Jun 23, 2025
83933a6
updates for the Embedding Gateway
amorton Jun 24, 2025
264652d
Merge branch 'main' into ajm/model-usage-for-egw
amorton Jun 24, 2025
cb6b65c
fix for EmbeddingProviderErrorMessageTest
amorton Jun 24, 2025
a3bf81d
Merge branch 'main' into ajm/model-usage-for-egw
Hazel-Datastax Jun 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import jakarta.ws.rs.client.ClientResponseContext;
import jakarta.ws.rs.client.ClientResponseFilter;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import org.slf4j.Logger;
Expand Down Expand Up @@ -37,11 +38,19 @@ public class EmbeddingProviderResponseValidation implements ClientResponseFilter
@Override
public void filter(ClientRequestContext requestContext, ClientResponseContext responseContext)
throws JsonApiException {

// If the status is 0, it means something went wrong (maybe a timeout). Directly return and pass
// the error to the client
if (responseContext.getStatus() == 0) {
return;
}

// only validate for successful responses, errors may return non-JSON content,
// e.g. a HTTP 401 may just have "Unauthorized" in the response body
if (responseContext.getStatusInfo().getFamily() != Response.Status.Family.SUCCESSFUL) {
return;
}

// Throw error if there is no response body
if (!responseContext.hasEntity()) {
throw EMBEDDING_PROVIDER_UNEXPECTED_RESPONSE.toApiException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,73 +93,72 @@ public Uni<BatchedEmbeddingResponse> vectorize(
AwsBasicCredentials.create(
embeddingCredentials.accessId().get(), embeddingCredentials.secretId().get());

try (var bedrockClient =
// NOTE: cannot put this client in a resource block for auto close because it will close
// te connection pool before we pull the async result.
var bedrockClient =
BedrockRuntimeAsyncClient.builder()
.credentialsProvider(StaticCredentialsProvider.create(awsCreds))
.region(Region.of(vectorizeServiceParameters.get("region").toString()))
.build()) {

long callStartNano = System.nanoTime();

// NOTE: need to use the AWS client for the request, not a Rest Easy, so we cannot use
// all the features from the superclasses such as error mapping and building the model usage
var bytesUsageTracker = new ByteUsageTracker();
var bedrockFuture =
bedrockClient
.invokeModel(
requestBuilder -> {
try {
var inputData =
OBJECT_WRITER.writeValueAsBytes(
new AwsBedrockEmbeddingRequest(texts.getFirst(), dimension));
bytesUsageTracker.requestBytes = inputData.length;
requestBuilder.body(SdkBytes.fromByteArray(inputData)).modelId(modelName());
} catch (JsonProcessingException e) {
throw ErrorCodeV1.EMBEDDING_REQUEST_ENCODING_ERROR.toApiException();
.build();

long callStartNano = System.nanoTime();

// NOTE: need to use the AWS client for the request, not a Rest Easy, so we cannot use
// all the features from the superclasses such as error mapping and building the model usage
var bytesUsageTracker = new ByteUsageTracker();
var bedrockFuture =
bedrockClient
.invokeModel(
requestBuilder -> {
try {
var inputData =
OBJECT_WRITER.writeValueAsBytes(
new AwsBedrockEmbeddingRequest(texts.getFirst(), dimension));
bytesUsageTracker.requestBytes = inputData.length;
requestBuilder.body(SdkBytes.fromByteArray(inputData)).modelId(modelName());
} catch (JsonProcessingException e) {
throw ErrorCodeV1.EMBEDDING_REQUEST_ENCODING_ERROR.toApiException();
}
})
.thenApply(
rawResponse -> {
try {
// aws docs say do not need to close the stream
var inputStream = rawResponse.body().asInputStream();
var bedrockResponse =
OBJECT_READER.readValue(inputStream, AwsBedrockEmbeddingResponse.class);
long callDurationNano = System.nanoTime() - callStartNano;

try (var countingOut =
new CountingOutputStream(OutputStream.nullOutputStream())) {
inputStream.transferTo(countingOut);
long responseSize = countingOut.getCount();
bytesUsageTracker.responseBytes =
responseSize > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) responseSize;
}
})
.thenApply(
rawResponse -> {
try {
// aws docs say do not need to close the stream
var inputStream = rawResponse.body().asInputStream();
var bedrockResponse =
OBJECT_READER.readValue(inputStream, AwsBedrockEmbeddingResponse.class);
long callDurationNano = System.nanoTime() - callStartNano;

try (var countingOut =
new CountingOutputStream(OutputStream.nullOutputStream())) {
inputStream.transferTo(countingOut);
long responseSize = countingOut.getCount();
bytesUsageTracker.responseBytes =
responseSize > Integer.MAX_VALUE
? Integer.MAX_VALUE
: (int) responseSize;
}

var modelUsage =
createModelUsage(
embeddingCredentials.tenantId(),
ModelInputType.fromEmbeddingRequestType(embeddingRequestType),
bedrockResponse.inputTextTokenCount(),
bedrockResponse.inputTextTokenCount(),
bytesUsageTracker.requestBytes,
bytesUsageTracker.responseBytes,
callDurationNano);

return new BatchedEmbeddingResponse(
batchId, List.of(bedrockResponse.embedding), modelUsage);

} catch (IOException e) {
throw ErrorCodeV1.EMBEDDING_RESPONSE_DECODING_ERROR.toApiException();
}
});

return Uni.createFrom()
.completionStage(bedrockFuture)
.onFailure(BedrockRuntimeException.class)
.transform(throwable -> mapBedrockException((BedrockRuntimeException) throwable));
}
var modelUsage =
createModelUsage(
embeddingCredentials.tenantId(),
ModelInputType.fromEmbeddingRequestType(embeddingRequestType),
bedrockResponse.inputTextTokenCount(),
bedrockResponse.inputTextTokenCount(),
bytesUsageTracker.requestBytes,
bytesUsageTracker.responseBytes,
callDurationNano);

return new BatchedEmbeddingResponse(
batchId, List.of(bedrockResponse.embedding), modelUsage);

} catch (IOException e) {
throw ErrorCodeV1.EMBEDDING_RESPONSE_DECODING_ERROR.toApiException();
}
});

return Uni.createFrom()
.completionStage(bedrockFuture)
.onFailure(BedrockRuntimeException.class)
.transform(throwable -> mapBedrockException((BedrockRuntimeException) throwable));
}

private Throwable mapBedrockException(BedrockRuntimeException bedrockException) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor;
import jakarta.ws.rs.HeaderParam;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.PathParam;
import jakarta.ws.rs.core.*;
import java.net.URI;
import java.util.List;
Expand Down Expand Up @@ -41,9 +39,13 @@ public HuggingFaceEmbeddingProvider(
dimension,
vectorizeServiceParameters);

var baseUrl = serviceConfig.getBaseUrl(modelName());
// replace was added in https://github.com/stargate/data-api/pull/2108/files
var actualUrl = replaceParameters(baseUrl, Map.of("modelId", modelName()));

huggingFaceClient =
QuarkusRestClientBuilder.newBuilder()
.baseUri(URI.create(serviceConfig.getBaseUrl(modelName())))
.baseUri(URI.create(actualUrl))
.readTimeout(requestProperties().readTimeoutMillis(), TimeUnit.MILLISECONDS)
.build(HuggingFaceEmbeddingProviderClient.class);
}
Expand Down Expand Up @@ -79,7 +81,7 @@ public Uni<BatchedEmbeddingResponse> vectorize(
var accessToken = HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get();

long callStartNano = System.nanoTime();
return retryHTTPCall(huggingFaceClient.embed(accessToken, modelName(), huggingFaceRequest))
return retryHTTPCall(huggingFaceClient.embed(accessToken, huggingFaceRequest))
.onItem()
.transform(
jakartaResponse -> {
Expand Down Expand Up @@ -136,12 +138,9 @@ public Uni<BatchedEmbeddingResponse> vectorize(
@RegisterProvider(ProviderHttpInterceptor.class)
public interface HuggingFaceEmbeddingProviderClient {
@POST
@Path("/{modelId}")
@ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON)
Uni<Response> embed(
@HeaderParam("Authorization") String accessToken,
@PathParam("modelId") String modelId,
HuggingFaceEmbeddingRequest request);
@HeaderParam("Authorization") String accessToken, HuggingFaceEmbeddingRequest request);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,18 @@ protected String errorMessageJsonPtr() {
return "/message";
}

/**
* Mistral for 401 Unauthorized returns a response with no content type and just the text
* "Unauthorized".
*/
@Override
protected String responseErrorMessage(Response jakartaResponse) {
if (jakartaResponse.getStatus() == Response.Status.UNAUTHORIZED.getStatusCode()) {
return Response.Status.UNAUTHORIZED.getReasonPhrase();
}
return super.responseErrorMessage(jakartaResponse);
}

@Override
public Uni<BatchedEmbeddingResponse> vectorize(
int batchId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public VoyageAIEmbeddingProvider(
int dimension,
Map<String, Object> vectorizeServiceParameters) {
super(
ModelProvider.VERTEXAI,
ModelProvider.VOYAGE_AI,
providerConfig,
modelConfig,
serviceConfig,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,12 @@ public static Optional<ModelInputType> fromEmbeddingGateway(
default -> Optional.empty();
};
}

public EmbeddingGateway.ModelUsage.InputType toEmbeddingGateway() {
return switch (this) {
case INPUT_TYPE_UNSPECIFIED -> EmbeddingGateway.ModelUsage.InputType.INPUT_TYPE_UNSPECIFIED;
case INDEX -> EmbeddingGateway.ModelUsage.InputType.INDEX;
case SEARCH -> EmbeddingGateway.ModelUsage.InputType.SEARCH;
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,12 @@ public static Optional<ModelType> fromEmbeddingGateway(
default -> Optional.empty();
};
}

public EmbeddingGateway.ModelUsage.ModelType toEmbeddingGateway() {
return switch (this) {
case MODEL_TYPE_UNSPECIFIED -> EmbeddingGateway.ModelUsage.ModelType.MODEL_TYPE_UNSPECIFIED;
case EMBEDDING -> EmbeddingGateway.ModelUsage.ModelType.EMBEDDING;
case RERANKING -> EmbeddingGateway.ModelUsage.ModelType.RERANKING;
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,21 @@ public static ModelUsage fromEmbeddingGateway(EmbeddingGateway.ModelUsage grpcMo
grpcModelUsage.getCallDurationNanos());
}

public EmbeddingGateway.ModelUsage toEmbeddingGateway() {
return EmbeddingGateway.ModelUsage.newBuilder()
.setModelProvider(modelProvider.apiName())
.setModelType(modelType.toEmbeddingGateway())
.setModelName(modelName)
.setTenantId(tenantId)
.setInputType(inputType.toEmbeddingGateway())
.setPromptTokens(promptTokens)
.setTotalTokens(totalTokens)
.setRequestBytes(requestBytes)
.setResponseBytes(responseBytes)
.setCallDurationNanos(durationNanos)
.build();
}

/**
* Creates a new model usage that merges this and the other usage, to combine after batching.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.time.Duration;
import java.util.Map;
import java.util.concurrent.TimeoutException;
import java.util.function.Predicate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -32,6 +33,14 @@
public abstract class ProviderBase {
protected static final Logger LOGGER = LoggerFactory.getLogger(ProviderBase.class);

// There is not a MediaType for text/json in Jakarta
private static final MediaType MEDIATYPE_TEXT_JSON = new MediaType("text", "json");

protected static final Predicate<MediaType> IS_JSON_MEDIA_TYPE =
mediaType ->
MediaType.APPLICATION_JSON_TYPE.isCompatible(mediaType)
|| MEDIATYPE_TEXT_JSON.isCompatible(mediaType);

private final ModelProvider modelProvider;
private final ModelType modelType;

Expand Down Expand Up @@ -203,12 +212,15 @@ protected String responseErrorMessage(Response jakartaResponse) {
MediaType contentType = jakartaResponse.getMediaType();
String raw = jakartaResponse.readEntity(String.class);

if (contentType == null || !MediaType.APPLICATION_JSON_TYPE.isCompatible(contentType)) {
LOGGER.error(
"Non-JSON error response from model provider, modelProvider:{}, modelName: {}, raw:{}",
modelProvider(),
modelName(),
raw);
if (contentType == null || !IS_JSON_MEDIA_TYPE.test(contentType)) {
// we have an error, only need a debug
if (LOGGER.isDebugEnabled()) {
LOGGER.debug(
"Non-JSON error response from model provider, modelProvider:{}, modelName: {}, raw:{}",
modelProvider(),
modelName(),
raw);
}
return raw;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,32 @@ public void filter(ClientRequestContext requestContext, ClientResponseContext re
long receivedBytes = 0;
long sentBytes = 0;

// we may still get called even if the request failed, and we do not get a valid HTTP response,
// for sanity check that we have the things we need to for processing.
boolean isValid =
responseContext != null
&& responseContext.getStatus() > 0
&& responseContext.getHeaders() != null;

if (!isValid) {
if (LOGGER.isWarnEnabled()) {
LOGGER.warn(
"filter() - Invalid responseContext, skipping sent/received bytes tracking. responseContext is null: {}, getStatus: {}, getHeaders: {}",
responseContext == null,
responseContext != null ? responseContext.getStatus() : "response null",
responseContext != null ? responseContext.getHeaders() : "response null");
}
return;
}

if (LOGGER.isTraceEnabled()) {
LOGGER.trace(
"ProviderHttpInterceptor.filter() - requestContext.getUri(): {}, requestContext.getHeaders(): {}",
"filter() - requestContext.getUri(): {}, requestContext.getHeaders(): {}",
requestContext.getUri(),
requestContext.getStringHeaders());

LOGGER.trace(
"ProviderHttpInterceptor.filter() - responseContext.getStatus(): {}, responseContext.getHeaders(): {}",
"filter() - responseContext.getStatus(): {}, responseContext.getHeaders(): {}",
responseContext.getStatus(),
responseContext.getHeaders());
}
Expand Down Expand Up @@ -98,6 +116,16 @@ public static int getReceivedBytes(Response jakartaResponse) {

private static int getHeaderInt(Response jakartaResponse, String headerName) {

if (jakartaResponse == null || jakartaResponse.getHeaders() == null) {
// log at trace, because this should be detected in filter() method
if (LOGGER.isTraceEnabled()) {
LOGGER.trace(
"getHeaderInt() - jakartaResponse or headers is null, returning 0 for headerName: {}",
headerName);
}
return 0;
}

var headerString = jakartaResponse.getHeaderString(headerName);
if (headerString != null && !headerString.isBlank()) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ protected String errorMessageJsonPtr() {
}

@Override
protected Uni<BatchedRerankingResponse> rerank(
public Uni<BatchedRerankingResponse> rerank(
int batchId, String query, List<String> passages, RerankingCredentials rerankingCredentials) {

// TODO: Move error to v2
Expand Down
Loading