diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProviderResponseValidation.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProviderResponseValidation.java index d26e12425d..034a7a3810 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProviderResponseValidation.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProviderResponseValidation.java @@ -7,6 +7,7 @@ import jakarta.ws.rs.client.ClientResponseContext; import jakarta.ws.rs.client.ClientResponseFilter; import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; import java.io.IOException; import java.nio.charset.StandardCharsets; import org.slf4j.Logger; @@ -37,11 +38,19 @@ public class EmbeddingProviderResponseValidation implements ClientResponseFilter @Override public void filter(ClientRequestContext requestContext, ClientResponseContext responseContext) throws JsonApiException { + // If the status is 0, it means something went wrong (maybe a timeout). Directly return and pass // the error to the client if (responseContext.getStatus() == 0) { return; } + + // only validate for successful responses, errors may return non-JSON content, + // e.g. 
an HTTP 401 may just have "Unauthorized" in the response body + if (responseContext.getStatusInfo().getFamily() != Response.Status.Family.SUCCESSFUL) { + return; + } + // Throw error if there is no response body if (!responseContext.hasEntity()) { throw EMBEDDING_PROVIDER_UNEXPECTED_RESPONSE.toApiException( diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AwsBedrockEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AwsBedrockEmbeddingProvider.java index 558f9e993e..368a4aa3c4 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AwsBedrockEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AwsBedrockEmbeddingProvider.java @@ -93,73 +93,72 @@ public Uni vectorize( AwsBasicCredentials.create( embeddingCredentials.accessId().get(), embeddingCredentials.secretId().get()); - try (var bedrockClient = + // NOTE: cannot put this client in a resource block for auto close because it will close + // the connection pool before we pull the async result. 
+ var bedrockClient = BedrockRuntimeAsyncClient.builder() .credentialsProvider(StaticCredentialsProvider.create(awsCreds)) .region(Region.of(vectorizeServiceParameters.get("region").toString())) - .build()) { - - long callStartNano = System.nanoTime(); - - // NOTE: need to use the AWS client for the request, not a Rest Easy, so we cannot use - // all the features from the superclasses such as error mapping and building the model usage - var bytesUsageTracker = new ByteUsageTracker(); - var bedrockFuture = - bedrockClient - .invokeModel( - requestBuilder -> { - try { - var inputData = - OBJECT_WRITER.writeValueAsBytes( - new AwsBedrockEmbeddingRequest(texts.getFirst(), dimension)); - bytesUsageTracker.requestBytes = inputData.length; - requestBuilder.body(SdkBytes.fromByteArray(inputData)).modelId(modelName()); - } catch (JsonProcessingException e) { - throw ErrorCodeV1.EMBEDDING_REQUEST_ENCODING_ERROR.toApiException(); + .build(); + + long callStartNano = System.nanoTime(); + + // NOTE: need to use the AWS client for the request, not a Rest Easy, so we cannot use + // all the features from the superclasses such as error mapping and building the model usage + var bytesUsageTracker = new ByteUsageTracker(); + var bedrockFuture = + bedrockClient + .invokeModel( + requestBuilder -> { + try { + var inputData = + OBJECT_WRITER.writeValueAsBytes( + new AwsBedrockEmbeddingRequest(texts.getFirst(), dimension)); + bytesUsageTracker.requestBytes = inputData.length; + requestBuilder.body(SdkBytes.fromByteArray(inputData)).modelId(modelName()); + } catch (JsonProcessingException e) { + throw ErrorCodeV1.EMBEDDING_REQUEST_ENCODING_ERROR.toApiException(); + } + }) + .thenApply( + rawResponse -> { + try { + // aws docs say do not need to close the stream + var inputStream = rawResponse.body().asInputStream(); + var bedrockResponse = + OBJECT_READER.readValue(inputStream, AwsBedrockEmbeddingResponse.class); + long callDurationNano = System.nanoTime() - callStartNano; + + try (var 
countingOut = + new CountingOutputStream(OutputStream.nullOutputStream())) { + inputStream.transferTo(countingOut); + long responseSize = countingOut.getCount(); + bytesUsageTracker.responseBytes = + responseSize > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) responseSize; } - }) - .thenApply( - rawResponse -> { - try { - // aws docs say do not need to close the stream - var inputStream = rawResponse.body().asInputStream(); - var bedrockResponse = - OBJECT_READER.readValue(inputStream, AwsBedrockEmbeddingResponse.class); - long callDurationNano = System.nanoTime() - callStartNano; - - try (var countingOut = - new CountingOutputStream(OutputStream.nullOutputStream())) { - inputStream.transferTo(countingOut); - long responseSize = countingOut.getCount(); - bytesUsageTracker.responseBytes = - responseSize > Integer.MAX_VALUE - ? Integer.MAX_VALUE - : (int) responseSize; - } - - var modelUsage = - createModelUsage( - embeddingCredentials.tenantId(), - ModelInputType.fromEmbeddingRequestType(embeddingRequestType), - bedrockResponse.inputTextTokenCount(), - bedrockResponse.inputTextTokenCount(), - bytesUsageTracker.requestBytes, - bytesUsageTracker.responseBytes, - callDurationNano); - - return new BatchedEmbeddingResponse( - batchId, List.of(bedrockResponse.embedding), modelUsage); - - } catch (IOException e) { - throw ErrorCodeV1.EMBEDDING_RESPONSE_DECODING_ERROR.toApiException(); - } - }); - return Uni.createFrom() - .completionStage(bedrockFuture) - .onFailure(BedrockRuntimeException.class) - .transform(throwable -> mapBedrockException((BedrockRuntimeException) throwable)); - } + var modelUsage = + createModelUsage( + embeddingCredentials.tenantId(), + ModelInputType.fromEmbeddingRequestType(embeddingRequestType), + bedrockResponse.inputTextTokenCount(), + bedrockResponse.inputTextTokenCount(), + bytesUsageTracker.requestBytes, + bytesUsageTracker.responseBytes, + callDurationNano); + + return new BatchedEmbeddingResponse( + batchId, 
List.of(bedrockResponse.embedding), modelUsage); + + } catch (IOException e) { + throw ErrorCodeV1.EMBEDDING_RESPONSE_DECODING_ERROR.toApiException(); + } + }); + + return Uni.createFrom() + .completionStage(bedrockFuture) + .onFailure(BedrockRuntimeException.class) + .transform(throwable -> mapBedrockException((BedrockRuntimeException) throwable)); } private Throwable mapBedrockException(BedrockRuntimeException bedrockException) { diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceEmbeddingProvider.java index 9b234b7ecd..9ccbc3116a 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceEmbeddingProvider.java @@ -12,8 +12,6 @@ import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; -import jakarta.ws.rs.Path; -import jakarta.ws.rs.PathParam; import jakarta.ws.rs.core.*; import java.net.URI; import java.util.List; @@ -41,9 +39,13 @@ public HuggingFaceEmbeddingProvider( dimension, vectorizeServiceParameters); + var baseUrl = serviceConfig.getBaseUrl(modelName()); + // replace was added in https://github.com/stargate/data-api/pull/2108/files + var actualUrl = replaceParameters(baseUrl, Map.of("modelId", modelName())); + huggingFaceClient = QuarkusRestClientBuilder.newBuilder() - .baseUri(URI.create(serviceConfig.getBaseUrl(modelName()))) + .baseUri(URI.create(actualUrl)) .readTimeout(requestProperties().readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(HuggingFaceEmbeddingProviderClient.class); } @@ -79,7 +81,7 @@ public Uni vectorize( var accessToken = HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(); long callStartNano = System.nanoTime(); - return 
retryHTTPCall(huggingFaceClient.embed(accessToken, modelName(), huggingFaceRequest)) + return retryHTTPCall(huggingFaceClient.embed(accessToken, huggingFaceRequest)) .onItem() .transform( jakartaResponse -> { @@ -136,12 +138,9 @@ public Uni vectorize( @RegisterProvider(ProviderHttpInterceptor.class) public interface HuggingFaceEmbeddingProviderClient { @POST - @Path("/{modelId}") @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) Uni embed( - @HeaderParam("Authorization") String accessToken, - @PathParam("modelId") String modelId, - HuggingFaceEmbeddingRequest request); + @HeaderParam("Authorization") String accessToken, HuggingFaceEmbeddingRequest request); } /** diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MistralEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MistralEmbeddingProvider.java index 1100ed3c27..e588ec39dd 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MistralEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MistralEmbeddingProvider.java @@ -76,6 +76,18 @@ protected String errorMessageJsonPtr() { return "/message"; } + /** + * Mistral for 401 Unauthorized returns a response with no content type and just the text + * "Unauthorized". 
+ */ + @Override + protected String responseErrorMessage(Response jakartaResponse) { + if (jakartaResponse.getStatus() == Response.Status.UNAUTHORIZED.getStatusCode()) { + return Response.Status.UNAUTHORIZED.getReasonPhrase(); + } + return super.responseErrorMessage(jakartaResponse); + } + @Override public Uni vectorize( int batchId, diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VoyageAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VoyageAIEmbeddingProvider.java index 2890e1a2a4..3771d957ce 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VoyageAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VoyageAIEmbeddingProvider.java @@ -40,7 +40,7 @@ public VoyageAIEmbeddingProvider( int dimension, Map vectorizeServiceParameters) { super( - ModelProvider.VERTEXAI, + ModelProvider.VOYAGE_AI, providerConfig, modelConfig, serviceConfig, diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelInputType.java b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelInputType.java index 91c8b31d8a..1206595c72 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelInputType.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelInputType.java @@ -36,4 +36,12 @@ public static Optional fromEmbeddingGateway( default -> Optional.empty(); }; } + + public EmbeddingGateway.ModelUsage.InputType toEmbeddingGateway() { + return switch (this) { + case INPUT_TYPE_UNSPECIFIED -> EmbeddingGateway.ModelUsage.InputType.INPUT_TYPE_UNSPECIFIED; + case INDEX -> EmbeddingGateway.ModelUsage.InputType.INDEX; + case SEARCH -> EmbeddingGateway.ModelUsage.InputType.SEARCH; + }; + } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelType.java b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelType.java index 87348a37ca..637cc15965 100644 --- 
a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelType.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelType.java @@ -23,4 +23,12 @@ public static Optional fromEmbeddingGateway( default -> Optional.empty(); }; } + + public EmbeddingGateway.ModelUsage.ModelType toEmbeddingGateway() { + return switch (this) { + case MODEL_TYPE_UNSPECIFIED -> EmbeddingGateway.ModelUsage.ModelType.MODEL_TYPE_UNSPECIFIED; + case EMBEDDING -> EmbeddingGateway.ModelUsage.ModelType.EMBEDDING; + case RERANKING -> EmbeddingGateway.ModelUsage.ModelType.RERANKING; + }; + } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelUsage.java b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelUsage.java index 8217ab52b4..7877b00674 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelUsage.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelUsage.java @@ -124,6 +124,21 @@ public static ModelUsage fromEmbeddingGateway(EmbeddingGateway.ModelUsage grpcMo grpcModelUsage.getCallDurationNanos()); } + public EmbeddingGateway.ModelUsage toEmbeddingGateway() { + return EmbeddingGateway.ModelUsage.newBuilder() + .setModelProvider(modelProvider.apiName()) + .setModelType(modelType.toEmbeddingGateway()) + .setModelName(modelName) + .setTenantId(tenantId) + .setInputType(inputType.toEmbeddingGateway()) + .setPromptTokens(promptTokens) + .setTotalTokens(totalTokens) + .setRequestBytes(requestBytes) + .setResponseBytes(responseBytes) + .setCallDurationNanos(durationNanos) + .build(); + } + /** * Creates a new model usage that merges this and the other usage, to combine after batching. 
* diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderBase.java b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderBase.java index 1487fad15b..1e209b3aeb 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderBase.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderBase.java @@ -10,6 +10,7 @@ import java.time.Duration; import java.util.Map; import java.util.concurrent.TimeoutException; +import java.util.function.Predicate; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -32,6 +33,14 @@ public abstract class ProviderBase { protected static final Logger LOGGER = LoggerFactory.getLogger(ProviderBase.class); + // There is not a MediaType for text/json in Jakarta + private static final MediaType MEDIATYPE_TEXT_JSON = new MediaType("text", "json"); + + protected static final Predicate IS_JSON_MEDIA_TYPE = + mediaType -> + MediaType.APPLICATION_JSON_TYPE.isCompatible(mediaType) + || MEDIATYPE_TEXT_JSON.isCompatible(mediaType); + private final ModelProvider modelProvider; private final ModelType modelType; @@ -203,12 +212,15 @@ protected String responseErrorMessage(Response jakartaResponse) { MediaType contentType = jakartaResponse.getMediaType(); String raw = jakartaResponse.readEntity(String.class); - if (contentType == null || !MediaType.APPLICATION_JSON_TYPE.isCompatible(contentType)) { - LOGGER.error( - "Non-JSON error response from model provider, modelProvider:{}, modelName: {}, raw:{}", - modelProvider(), - modelName(), - raw); + if (contentType == null || !IS_JSON_MEDIA_TYPE.test(contentType)) { + // we have an error, only need a debug + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "Non-JSON error response from model provider, modelProvider:{}, modelName: {}, raw:{}", + modelProvider(), + modelName(), + raw); + } return raw; } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderHttpInterceptor.java 
b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderHttpInterceptor.java index 2a0fa8f0d1..ecaace768a 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderHttpInterceptor.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderHttpInterceptor.java @@ -44,14 +44,32 @@ public void filter(ClientRequestContext requestContext, ClientResponseContext re long receivedBytes = 0; long sentBytes = 0; + // we may still get called even if the request failed and we did not get a valid HTTP response, + // so sanity check that we have the things we need for processing. + boolean isValid = + responseContext != null + && responseContext.getStatus() > 0 + && responseContext.getHeaders() != null; + + if (!isValid) { + if (LOGGER.isWarnEnabled()) { + LOGGER.warn( + "filter() - Invalid responseContext, skipping sent/received bytes tracking. responseContext is null: {}, getStatus: {}, getHeaders: {}", + responseContext == null, + responseContext != null ? responseContext.getStatus() : "response null", + responseContext != null ? 
responseContext.getHeaders() : "response null"); + } + return; + } + if (LOGGER.isTraceEnabled()) { LOGGER.trace( - "ProviderHttpInterceptor.filter() - requestContext.getUri(): {}, requestContext.getHeaders(): {}", + "filter() - requestContext.getUri(): {}, requestContext.getHeaders(): {}", requestContext.getUri(), requestContext.getStringHeaders()); LOGGER.trace( - "ProviderHttpInterceptor.filter() - responseContext.getStatus(): {}, responseContext.getHeaders(): {}", + "filter() - responseContext.getStatus(): {}, responseContext.getHeaders(): {}", responseContext.getStatus(), responseContext.getHeaders()); } @@ -98,6 +116,16 @@ public static int getReceivedBytes(Response jakartaResponse) { private static int getHeaderInt(Response jakartaResponse, String headerName) { + if (jakartaResponse == null || jakartaResponse.getHeaders() == null) { + // log at trace, because this should be detected in filter() method + if (LOGGER.isTraceEnabled()) { + LOGGER.trace( + "getHeaderInt() - jakartaResponse or headers is null, returning 0 for headerName: {}", + headerName); + } + return 0; + } + var headerString = jakartaResponse.getHeaderString(headerName); if (headerString != null && !headerString.isBlank()) { try { diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/NvidiaRerankingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/NvidiaRerankingProvider.java index 9137815802..75cdabae6a 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/NvidiaRerankingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/NvidiaRerankingProvider.java @@ -89,7 +89,7 @@ protected String errorMessageJsonPtr() { } @Override - protected Uni rerank( + public Uni rerank( int batchId, String query, List passages, RerankingCredentials rerankingCredentials) { // TODO: Move error to v2 diff --git 
a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProvider.java index 65dcdf6208..5629b55238 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProvider.java @@ -90,9 +90,10 @@ public Uni rerank( /** * Subclasses must implement to do the reranking, after the batching is done. * - *

... + *

... NOTE: This is public because the embedding Gateway currently needs to call it, + * use the {@link #rerank(String, List, RerankingCredentials)} method instead. */ - protected abstract Uni rerank( + public abstract Uni rerank( int batchId, String query, List passages, RerankingCredentials rerankingCredentials); @Override diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingClientTestResource.java b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingClientTestResource.java index e790891061..25465f62fb 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingClientTestResource.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingClientTestResource.java @@ -122,14 +122,15 @@ public Map start() { .withHeader(HttpHeaders.CONTENT_TYPE, "application/xml") .withBody("list"))); + // The EmbeddingProviderResponseValidation only validates 2XX status responses, wireMockServer.stubFor( post(urlEqualTo("/v1/embeddings")) .withRequestBody(matchingJsonPath("$.input", containing("text/plain;charset=UTF-8"))) .willReturn( aResponse() .withHeader(HttpHeaders.CONTENT_TYPE, "text/plain;charset=UTF-8") - .withBody("Not Found") - .withStatus(500))); + .withBody("vectors as plain text") + .withStatus(200))); wireMockServer.stubFor( post(urlEqualTo("/v1/embeddings")) diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderErrorMessageTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderErrorMessageTest.java index bec9fa6d44..b422c90df2 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderErrorMessageTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderErrorMessageTest.java @@ -197,7 +197,7 @@ public void testIncorrectContentTypePlainText() { "errorCode", 
ErrorCodeV1.EMBEDDING_PROVIDER_UNEXPECTED_RESPONSE) .hasFieldOrPropertyWithValue( "message", - "The Embedding Provider returned an unexpected response: Expected response Content-Type ('application/json' or 'text/json') from the embedding provider but found 'text/plain;charset=UTF-8'; HTTP Status: 500; The response body is: 'Not Found'."); + "The Embedding Provider returned an unexpected response: Expected response Content-Type ('application/json' or 'text/json') from the embedding provider but found 'text/plain;charset=UTF-8'; HTTP Status: 200; The response body is: 'vectors as plain text'."); } @Test