diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/CreateCollectionCommand.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/CreateCollectionCommand.java index 9ad0427695..dc22ef4ad4 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/CreateCollectionCommand.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/CreateCollectionCommand.java @@ -5,7 +5,7 @@ import io.stargate.sgv2.jsonapi.api.model.command.CollectionOnlyCommand; import io.stargate.sgv2.jsonapi.api.model.command.CommandName; import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants; -import io.stargate.sgv2.jsonapi.config.constants.RerankingConstants; +import io.stargate.sgv2.jsonapi.config.constants.ServiceDescConstants; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.service.schema.collections.DocumentPath; import io.stargate.sgv2.jsonapi.service.schema.naming.NamingRules; @@ -276,20 +276,20 @@ public record RerankServiceDesc( description = "Registered reranking service provider", type = SchemaType.STRING, implementation = String.class) - @JsonProperty(RerankingConstants.RerankingService.PROVIDER) + @JsonProperty(ServiceDescConstants.PROVIDER) String provider, @Schema( description = "Registered reranking service model", type = SchemaType.STRING, implementation = String.class) - @JsonProperty(RerankingConstants.RerankingService.MODEL_NAME) + @JsonProperty(ServiceDescConstants.MODEL_NAME) String modelName, @Valid @Nullable @Schema( description = "Authentication config for chosen reranking service", type = SchemaType.OBJECT) - @JsonProperty(RerankingConstants.RerankingService.AUTHENTICATION) + @JsonProperty(ServiceDescConstants.AUTHENTICATION) @JsonInclude(JsonInclude.Include.NON_NULL) Map authentication, @Nullable @@ -297,7 +297,7 @@ public record RerankServiceDesc( description = "Optional parameters that match the messageTemplate provided for the reranking provider", 
type = SchemaType.OBJECT) - @JsonProperty(RerankingConstants.RerankingService.PARAMETERS) + @JsonProperty(ServiceDescConstants.PARAMETERS) @JsonInclude(JsonInclude.Include.NON_NULL) Map parameters) { diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/VectorizeConfig.java b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/VectorizeConfig.java index e4e5888f9d..fd1b631642 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/VectorizeConfig.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/VectorizeConfig.java @@ -4,7 +4,8 @@ import com.fasterxml.jackson.annotation.JsonProperty; import io.stargate.sgv2.jsonapi.config.constants.VectorConstants; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; +import io.stargate.sgv2.jsonapi.service.embedding.operation.HuggingFaceDedicatedEmbeddingProvider; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import jakarta.validation.Valid; import jakarta.validation.constraints.*; import java.util.*; @@ -48,24 +49,30 @@ public VectorizeConfig( String modelName, Map authentication, Map parameters) { + if (provider == null) { throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( "'provider' in required property for 'vector.service' Object value"); } + this.provider = provider; + // HuggingfaceDedicated does not need user to specify model explicitly // If user specifies modelName other than endpoint-defined-model, will error out // By default, huggingfaceDedicated provider use endpoint-defined-model as placeholder - if (ProviderConstants.HUGGINGFACE_DEDICATED.equals(provider)) { + if (ModelProvider.HUGGINGFACE_DEDICATED.apiName().equals(provider)) { if (modelName == null) { - modelName = ProviderConstants.HUGGINGFACE_DEDICATED_DEFINED_MODEL; - } else if (!modelName.equals(ProviderConstants.HUGGINGFACE_DEDICATED_DEFINED_MODEL)) { + modelName = 
+ HuggingFaceDedicatedEmbeddingProvider.HUGGINGFACE_DEDICATED_ENDPOINT_DEFINED_MODEL; + } else if (!modelName.equals( + HuggingFaceDedicatedEmbeddingProvider.HUGGINGFACE_DEDICATED_ENDPOINT_DEFINED_MODEL)) { throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( "'modelName' is not needed for embedding provider %s explicitly, only '%s' is accepted", - ProviderConstants.HUGGINGFACE_DEDICATED, - ProviderConstants.HUGGINGFACE_DEDICATED_DEFINED_MODEL); + ModelProvider.HUGGINGFACE_DEDICATED, + HuggingFaceDedicatedEmbeddingProvider.HUGGINGFACE_DEDICATED_ENDPOINT_DEFINED_MODEL); } } + this.modelName = modelName; if (authentication != null && !authentication.isEmpty()) { Map updatedAuth = new HashMap<>(); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/request/EmbeddingCredentials.java b/src/main/java/io/stargate/sgv2/jsonapi/api/request/EmbeddingCredentials.java index 5706f61fce..171cb61745 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/request/EmbeddingCredentials.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/request/EmbeddingCredentials.java @@ -6,9 +6,16 @@ * EmbeddingCredentials is a record that holds the embedding provider credentials for the embedding * service passed as header. * + *

Includes the tenantID, so we can fully identify the usage when creating the {@link + * io.stargate.sgv2.jsonapi.service.provider.ModelUsage} + * + * @param tenantId - Tenant Id that called the API. * @param apiKey - API token for the embedding service * @param accessId - Access Id used for AWS Bedrock embedding service * @param secretId - Secret Id used for AWS Bedrock embedding service */ public record EmbeddingCredentials( - Optional apiKey, Optional accessId, Optional secretId) {} + String tenantId, + Optional apiKey, + Optional accessId, + Optional secretId) {} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/request/EmbeddingCredentialsSupplier.java b/src/main/java/io/stargate/sgv2/jsonapi/api/request/EmbeddingCredentialsSupplier.java index a05460c296..e479b4c94c 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/request/EmbeddingCredentialsSupplier.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/request/EmbeddingCredentialsSupplier.java @@ -79,10 +79,14 @@ public EmbeddingCredentials create( && collectionSupportsNoneAuth) { var authToken = requestContext.getHttpHeaders().getHeader(this.authTokenHeaderName); return new EmbeddingCredentials( - Optional.ofNullable(authToken), Optional.empty(), Optional.empty()); + requestContext.getTenantId().orElse(""), + Optional.ofNullable(authToken), + Optional.empty(), + Optional.empty()); } return new EmbeddingCredentials( + requestContext.getTenantId().orElse(""), Optional.ofNullable(embeddingApi), Optional.ofNullable(accessId), Optional.ofNullable(secretId)); diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/request/RequestContext.java b/src/main/java/io/stargate/sgv2/jsonapi/api/request/RequestContext.java index 0a7ab0d09d..c1f0f1738b 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/request/RequestContext.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/request/RequestContext.java @@ -2,6 +2,7 @@ import com.fasterxml.uuid.Generators; import com.fasterxml.uuid.NoArgGenerator; +import 
com.google.common.annotations.VisibleForTesting; import io.stargate.sgv2.jsonapi.api.request.tenant.DataApiTenantResolver; import io.stargate.sgv2.jsonapi.api.request.token.DataApiTokenResolver; import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; @@ -35,20 +36,25 @@ public class RequestContext { private final String userAgent; - /** - * Constructor that will be useful in the offline library mode, where only the tenant will be set - * and accessed. - * - * @param tenantId Tenant Id - */ - public RequestContext(Optional tenantId) { + /** FOR TESTING ONLY - so we can bypass pulling things the headers, still messy, getting better */ + @VisibleForTesting + public RequestContext( + Optional tenantId, + Optional cassandraToken, + RerankingCredentials rerankingCredentials, + String userAgent) { this.tenantId = tenantId; - cassandraToken = Optional.empty(); - embeddingCredentialsSupplier = null; - rerankingCredentials = null; - httpHeaders = null; + this.cassandraToken = cassandraToken; + embeddingCredentialsSupplier = + new EmbeddingCredentialsSupplier( + HttpConstants.AUTHENTICATION_TOKEN_HEADER_NAME, + HttpConstants.EMBEDDING_AUTHENTICATION_TOKEN_HEADER_NAME, + HttpConstants.EMBEDDING_AUTHENTICATION_ACCESS_ID_HEADER_NAME, + HttpConstants.EMBEDDING_AUTHENTICATION_SECRET_ID_HEADER_NAME); + this.rerankingCredentials = rerankingCredentials; + this.userAgent = userAgent; + this.httpHeaders = new HttpHeaderAccess(io.vertx.core.MultiMap.caseInsensitiveMultiMap()); requestId = generateRequestId(); - userAgent = null; } @Inject @@ -77,11 +83,14 @@ public RequestContext( HeaderBasedRerankingKeyResolver.resolveRerankingKey(routingContext); rerankingCredentials = rerankingApiKeyFromHeader - .map(apiKey -> new RerankingCredentials(Optional.of(apiKey))) + .map(apiKey -> new RerankingCredentials(this.tenantId.orElse(""), Optional.of(apiKey))) .orElse( this.cassandraToken - .map(cassandraToken -> new RerankingCredentials(Optional.of(cassandraToken))) - .orElse(new 
RerankingCredentials(Optional.empty()))); + .map( + cassandraToken -> + new RerankingCredentials( + this.tenantId.orElse(""), Optional.of(cassandraToken))) + .orElse(new RerankingCredentials(this.tenantId.orElse(""), Optional.empty()))); } private static String generateRequestId() { diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/request/RerankingCredentials.java b/src/main/java/io/stargate/sgv2/jsonapi/api/request/RerankingCredentials.java index 874075cb32..89a0d8d1c0 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/request/RerankingCredentials.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/request/RerankingCredentials.java @@ -7,5 +7,8 @@ * resolved from the request header 'reranking-api-key', if it is not present, then we will use the * cassandra token as the reranking api key. Note, both cassandra token and reranking-api-key could * be absent in Data API request, although it is invalid for authentication. + * + *

Includes the tenantId, so we can fully identify the usage when creating the {@link + * io.stargate.sgv2.jsonapi.service.provider.ModelUsage} */ -public record RerankingCredentials(Optional apiKey) {} +public record RerankingCredentials(String tenantId, Optional apiKey) {} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/api/v1/CollectionResource.java b/src/main/java/io/stargate/sgv2/jsonapi/api/v1/CollectionResource.java index 8e1cfb48a4..4240db973b 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/api/v1/CollectionResource.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/api/v1/CollectionResource.java @@ -255,7 +255,7 @@ public Uni> postCommand( if (vectorColDef != null && vectorColDef.vectorizeDefinition() != null) { embeddingProvider = - embeddingProviderFactory.getConfiguration( + embeddingProviderFactory.create( requestContext.getTenantId(), requestContext.getCassandraToken(), vectorColDef.vectorizeDefinition().provider(), diff --git a/src/main/java/io/stargate/sgv2/jsonapi/config/constants/RerankingConstants.java b/src/main/java/io/stargate/sgv2/jsonapi/config/constants/RerankingConstants.java index e34f36bc2f..a166409ab3 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/config/constants/RerankingConstants.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/config/constants/RerankingConstants.java @@ -6,6 +6,4 @@ interface CollectionRerankingOptions { String ENABLED = "enabled"; String SERVICE = ServiceDescConstants.SERVICE; } - - interface RerankingService extends ServiceDescConstants {} } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/config/constants/ServiceDescConstants.java b/src/main/java/io/stargate/sgv2/jsonapi/config/constants/ServiceDescConstants.java index a2442489df..5e3536bfd3 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/config/constants/ServiceDescConstants.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/config/constants/ServiceDescConstants.java @@ -1,7 +1,7 @@ package io.stargate.sgv2.jsonapi.config.constants; /** Common 
service description constants shared between vector and reranking */ -interface ServiceDescConstants { +public interface ServiceDescConstants { String SERVICE = "service"; String PROVIDER = "provider"; String MODEL_NAME = "modelName"; diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java index c893394568..f26afda2fa 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java @@ -52,7 +52,8 @@ public DataVectorizer( SchemaObject schemaObject) { this.embeddingProvider = embeddingProvider; this.nodeFactory = nodeFactory; - this.embeddingCredentials = embeddingCredentials; + this.embeddingCredentials = + Objects.requireNonNull(embeddingCredentials, "embeddingCredentials must not be null"); this.schemaObject = schemaObject; } @@ -175,7 +176,7 @@ public Uni vectorize(String vectorizeContent) { List.of(vectorizeContent), embeddingCredentials, EmbeddingProvider.EmbeddingRequestType.INDEX) - .map(EmbeddingProvider.Response::embeddings); + .map(EmbeddingProvider.BatchedEmbeddingResponse::embeddings); return vectors .onItem() .transform( @@ -303,7 +304,7 @@ private Uni> vectorizeTexts( return embeddingProvider .vectorize(1, textsToVectorize, embeddingCredentials, requestType) - .map(EmbeddingProvider.Response::embeddings) + .map(EmbeddingProvider.BatchedEmbeddingResponse::embeddings) .onItem() .transform( vectorData -> { diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizerService.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizerService.java index d9226c6f82..d69243fc71 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizerService.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizerService.java @@ -72,6 +72,7 @@ public Uni vectorize( 
public DataVectorizer constructDataVectorizer( CommandContext commandContext) { + EmbeddingProvider embeddingProvider = Optional.ofNullable(commandContext.embeddingProvider()) .map( @@ -83,6 +84,7 @@ public DataVectorizer constructDataVectorizer( provider, commandContext.commandName())) .orElse(null); + return new DataVectorizer( embeddingProvider, objectMapper.getNodeFactory(), @@ -91,7 +93,7 @@ public DataVectorizer constructDataVectorizer( .getEmbeddingCredentialsSupplier() .create( commandContext.requestContext(), - embeddingProvider == null ? null : embeddingProvider.getProviderConfig()), + embeddingProvider == null ? null : embeddingProvider.providerConfig()), commandContext.schemaObject()); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/DefaultEmbeddingProviderConfig.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/DefaultEmbeddingProviderConfig.java index 6d25b26844..7032e02133 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/DefaultEmbeddingProviderConfig.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/DefaultEmbeddingProviderConfig.java @@ -10,6 +10,8 @@ */ @ConfigMapping(prefix = "stargate.jsonapi.embedding") public interface DefaultEmbeddingProviderConfig { + + // TODO: WHAT DOES THIS ACTUALLY RETURN ? 
@Nullable Map providers(); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProviderConfigStore.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProviderConfigStore.java deleted file mode 100644 index 06f638d5f9..0000000000 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProviderConfigStore.java +++ /dev/null @@ -1,87 +0,0 @@ -package io.stargate.sgv2.jsonapi.service.embedding.configuration; - -import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; -import java.util.Map; -import java.util.Optional; - -public interface EmbeddingProviderConfigStore { - - record ServiceConfig( - String serviceName, - String serviceProvider, - String baseUrl, - // `implementationClass` is the custom class that implements the EmbeddingProvider interface - Optional> implementationClass, - RequestProperties requestConfiguration, - Map> modelUrlOverrides) { - - public static ServiceConfig provider( - String serviceName, - String serviceProvider, - String baseUrl, - RequestProperties requestConfiguration, - Map> modelUrlOverrides) { - return new ServiceConfig( - serviceName, serviceProvider, baseUrl, null, requestConfiguration, modelUrlOverrides); - } - - public static ServiceConfig custom(Optional> implementationClass) { - return new ServiceConfig( - ProviderConstants.CUSTOM, - ProviderConstants.CUSTOM, - null, - implementationClass, - null, - Map.of()); - } - - public String getBaseUrl(String modelName) { - if (modelUrlOverrides != null && modelUrlOverrides.get(modelName) == null) { - // modelUrlOverride is a work-around for self-hosted nvidia models with different url. - // This is bad, initial design should have url in model level instead of provider level. - // As best practice, when we deprecate or EOL a model: - // we must mark the status in the configuration, - // instead of removing the whole configuration entry. 
- throw ErrorCodeV1.VECTORIZE_SERVICE_TYPE_UNAVAILABLE.toApiException( - "unknown model '%s' for service provider '%s'", modelName, serviceProvider); - } - return modelUrlOverrides != null ? modelUrlOverrides.get(modelName).orElse(baseUrl) : baseUrl; - } - } - - record RequestProperties( - int atMostRetries, - int initialBackOffMillis, - int readTimeoutMillis, - int maxBackOffMillis, - double jitter, - Optional requestTypeQuery, - Optional requestTypeIndex, - // `maxBatchSize` is the maximum number of documents to be sent in a single request to be - // embedding provider - int maxBatchSize) { - public static RequestProperties of( - int atMostRetries, - int initialBackOffMillis, - int readTimeoutMillis, - int maxBackOffMillis, - double jitter, - Optional requestTypeQuery, - Optional requestTypeIndex, - int maxBatchSize) { - return new RequestProperties( - atMostRetries, - initialBackOffMillis, - readTimeoutMillis, - maxBackOffMillis, - jitter, - requestTypeQuery, - requestTypeIndex, - maxBatchSize); - } - } - - void saveConfiguration(Optional tenant, ServiceConfig serviceConfig); - - ServiceConfig getConfiguration(Optional tenant, String serviceName); -} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProvidersConfig.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProvidersConfig.java index 1a262f688f..cde7ad3819 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProvidersConfig.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProvidersConfig.java @@ -13,7 +13,10 @@ import java.util.Optional; import org.eclipse.microprofile.config.spi.Converter; +// TODO: SOME DOCUMENTATION FOR WHAT THIS IS MEANT TO DO!!! public interface EmbeddingProvidersConfig { + + // TODO: WHAT IS THE KEY FOR THIS MAP ????? 
Map providers(); @Nullable diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProvidersConfigImpl.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProvidersConfigImpl.java index 601d5efeea..b76bd9536f 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProvidersConfigImpl.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProvidersConfigImpl.java @@ -5,6 +5,7 @@ import java.util.*; import java.util.stream.Collectors; +// TODO: SOME DOCUMENTATION FOR WHAT THIS IS MEANT TO DO!!! public record EmbeddingProvidersConfigImpl( Map providers, CustomConfig custom) implements EmbeddingProvidersConfig { diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProvidersConfigProducer.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProvidersConfigProducer.java index b0c078c586..744b2c424a 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProvidersConfigProducer.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProvidersConfigProducer.java @@ -15,6 +15,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +// TODO: SOME DOCUMENTATION FOR WHAT THIS IS MEANT TO DO!!! 
public class EmbeddingProvidersConfigProducer { private static final Logger LOG = LoggerFactory.getLogger(EmbeddingProvidersConfigProducer.class); @@ -33,14 +34,17 @@ EmbeddingProvidersConfig produce( OperationsConfig operationsConfig, EmbeddingProvidersConfig.CustomConfig customConfig, @GrpcClient("embedding") EmbeddingServiceGrpc.EmbeddingServiceBlockingStub embeddingService) { + EmbeddingProvidersConfig defaultConfig = new EmbeddingProvidersConfigImpl(defaultEmbeddingProviderConfig.providers(), customConfig); + // defaultEmbeddingProviderConfig is what we mapped from embedding-providers-config.yaml // and will be used if embedding-gateway is not enabled if (!operationsConfig.enableEmbeddingGateway()) { LOG.info("embedding gateway disabled, use default embedding config"); return defaultConfig; } + LOG.info("embedding gateway enabled, fetch supported providers from embedding gateway"); final EmbeddingGateway.GetSupportedProvidersRequest getSupportedProvidersRequest = EmbeddingGateway.GetSupportedProvidersRequest.newBuilder().build(); @@ -61,6 +65,7 @@ EmbeddingProvidersConfig produce( private EmbeddingProvidersConfig grpcResponseToConfig( EmbeddingGateway.GetSupportedProvidersResponse getSupportedProvidersResponse, EmbeddingProvidersConfig.CustomConfig customConfig) { + Map providerMap = new HashMap<>(); // traverse EmbeddingProvidersConfig in Grpc GetSupportedProvidersResponse @@ -113,8 +118,10 @@ private EmbeddingProvidersConfig grpcResponseToConfig( // 3. 
construct modelConfig list for the provider List providerModelList = new ArrayList<>(); + final List grpcProviderConfigModelsList = grpcProviderConfig.getModelsList(); + for (EmbeddingGateway.GetSupportedProvidersResponse.ProviderConfig.ModelConfig grpcModelConfig : grpcProviderConfigModelsList) { // construct parameterConfig List for the model diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/PropertyBasedEmbeddingProviderConfigStore.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/PropertyBasedEmbeddingProviderConfigStore.java deleted file mode 100644 index 0b6cc19946..0000000000 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/PropertyBasedEmbeddingProviderConfigStore.java +++ /dev/null @@ -1,56 +0,0 @@ -package io.stargate.sgv2.jsonapi.service.embedding.configuration; - -import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; -import jakarta.enterprise.context.ApplicationScoped; -import jakarta.inject.Inject; -import java.util.HashMap; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; - -@ApplicationScoped -public class PropertyBasedEmbeddingProviderConfigStore implements EmbeddingProviderConfigStore { - - @Inject private EmbeddingProvidersConfig config; - - @Override - public void saveConfiguration(Optional tenant, ServiceConfig serviceConfig) { - throw ErrorCodeV1.SERVER_INTERNAL_ERROR.toApiException( - "PropertyBasedEmbeddingProviderConfigStore.saveConfiguration() not implemented"); - } - - @Override - public EmbeddingProviderConfigStore.ServiceConfig getConfiguration( - Optional tenant, String serviceName) { - // already checked if the service exists and enabled in CreateCollectionCommandResolver - if (serviceName.equals(ProviderConstants.CUSTOM)) { - return ServiceConfig.custom(config.custom().clazz()); - } - if (config.providers().get(serviceName) == null - || !config.providers().get(serviceName).enabled()) { - throw 
ErrorCodeV1.VECTORIZE_SERVICE_TYPE_UNAVAILABLE.toApiException(serviceName); - } - - final var properties = config.providers().get(serviceName).properties(); - Map> modelwiseServiceUrlOverrides = - Objects.requireNonNull(config.providers().get(serviceName).models()).stream() - .collect( - HashMap::new, - (map, modelConfig) -> map.put(modelConfig.name(), modelConfig.serviceUrlOverride()), - HashMap::putAll); - return ServiceConfig.provider( - serviceName, - serviceName, - config.providers().get(serviceName).url().orElse(null), - RequestProperties.of( - properties.atMostRetries(), - properties.initialBackOffMillis(), - properties.readTimeoutMillis(), - properties.maxBackOffMillis(), - properties.jitter(), - properties.taskTypeRead(), - properties.taskTypeStore(), - properties.maxBatchSize()), - modelwiseServiceUrlOverrides); - } -} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/PropertyBasedServiceConfigStore.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/PropertyBasedServiceConfigStore.java new file mode 100644 index 0000000000..51e99a688f --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/PropertyBasedServiceConfigStore.java @@ -0,0 +1,77 @@ +package io.stargate.sgv2.jsonapi.service.embedding.configuration; + +import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +/** + * aaron - 17 june 2025 - as far as I can tell there is a single implementation of this interface, + * not sure why this is an interface, why this class exists, and why it is ApplicationScoped. 
+ */ +@ApplicationScoped +public class PropertyBasedServiceConfigStore implements ServiceConfigStore { + + @Inject private EmbeddingProvidersConfig providersConfig; + + @Override + public ServiceConfigStore.ServiceConfig getConfiguration(ModelProvider modelProvider) { + + // already checked if the service exists and enabled in CreateCollectionCommandResolver + if (modelProvider == ModelProvider.CUSTOM) { + Objects.requireNonNull( + providersConfig.custom(), "ModelProvider is CUSTOM configuration has null custom config"); + Objects.requireNonNull( + providersConfig.custom().clazz(), "ModelProvider is CUSTOM configuration has null class"); + + /** + * See {@link DseTestResource} for where the implementationClass is set, and {@link + * PropertyBasedServiceConfigStore} for where it is read + */ + return ServiceConfig.forCustomProvider( + providersConfig + .custom() + .clazz() + .orElseThrow( + () -> + new IllegalStateException( + "ModelProvider is CUSTOM but no class is provided in configuration"))); + } + + var providerConfig = providersConfig.providers().get(modelProvider.apiName()); + if (providerConfig == null || !providerConfig.enabled()) { + throw ErrorCodeV1.VECTORIZE_SERVICE_TYPE_UNAVAILABLE.toApiException(modelProvider.apiName()); + } + + Objects.requireNonNull( + providerConfig.models(), + "ModelProvider configuration has null models, provider: " + modelProvider.apiName()); + + // aaron 16 June 2025 - not sure what this is doing, left in place for now + Map> modelwiseServiceUrlOverrides = + providerConfig.models().stream() + .collect( + HashMap::new, + (map, modelConfig) -> map.put(modelConfig.name(), modelConfig.serviceUrlOverride()), + HashMap::putAll); + + var requestProperties = providerConfig.properties(); + return ServiceConfig.forKnownProvider( + modelProvider, + providerConfig.url().orElse(null), + new ServiceRequestProperties( + requestProperties.atMostRetries(), + requestProperties.initialBackOffMillis(), + requestProperties.readTimeoutMillis(), + 
requestProperties.maxBackOffMillis(), + requestProperties.jitter(), + requestProperties.taskTypeRead(), + requestProperties.taskTypeStore(), + requestProperties.maxBatchSize()), + modelwiseServiceUrlOverrides); + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/ProviderConstants.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/ProviderConstants.java deleted file mode 100644 index e47bc6738a..0000000000 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/ProviderConstants.java +++ /dev/null @@ -1,21 +0,0 @@ -package io.stargate.sgv2.jsonapi.service.embedding.configuration; - -public final class ProviderConstants { - public static final String OPENAI = "openai"; - public static final String AZURE_OPENAI = "azureOpenAI"; - public static final String HUGGINGFACE = "huggingface"; - public static final String HUGGINGFACE_DEDICATED = "huggingfaceDedicated"; - public static final String HUGGINGFACE_DEDICATED_DEFINED_MODEL = "endpoint-defined-model"; - public static final String VERTEXAI = "vertexai"; - public static final String COHERE = "cohere"; - public static final String NVIDIA = "nvidia"; - public static final String UPSTAGE_AI = "upstageAI"; - public static final String VOYAGE_AI = "voyageAI"; - public static final String JINA_AI = "jinaAI"; - public static final String CUSTOM = "custom"; - public static final String MISTRAL = "mistral"; - public static final String BEDROCK = "bedrock"; - - // Private constructor to prevent instantiation - private ProviderConstants() {} -} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/ServiceConfigStore.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/ServiceConfigStore.java new file mode 100644 index 0000000000..cdc2e721d2 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/ServiceConfigStore.java @@ -0,0 +1,93 @@ +package 
io.stargate.sgv2.jsonapi.service.embedding.configuration; + +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +/** + * aaron - 16 june 2025 - This used to be called the EmbeddingProviderConfigStore. + * + *

I think this is config that merges together the provider and model config. The main thing it + * does is 1) know where to get the name of the class for the custom config provider and 2) + * provide getBaseUrl() which coalesces the baseUrl with any model-specific overrides. Both of these + * things can be improved and this thing removed. + */ +public interface ServiceConfigStore { + + record ServiceConfig( + ModelProvider modelProvider, + // aaron 16 june 2025 - DANGER here, there is a method that changes the baseUrl use + // getBaseUrl() + // will refactor later + String baseUrl, + // `implementationClass` is the custom class that implements the EmbeddingProvider interface + Optional> implementationClass, + ServiceRequestProperties requestProperties, + Map> modelUrlOverrides) { + + public static ServiceConfig forKnownProvider( + ModelProvider modelProvider, + String baseUrl, + ServiceRequestProperties requestConfiguration, + Map> modelUrlOverrides) { + + return new ServiceConfig( + modelProvider, baseUrl, Optional.empty(), requestConfiguration, modelUrlOverrides); + } + + /** + * See {@link DseTestResource} for where the implementationClass is set, and {@link + * PropertyBasedServiceConfigStore} for where it is read + */ + public static ServiceConfig forCustomProvider(Class implementationClass) { + Objects.requireNonNull(implementationClass, "implementationClass must not be null"); + + // null for modelUrlOverrides important to say there is none available, see getBaseUrl() + return new ServiceConfig( + ModelProvider.CUSTOM, null, Optional.of(implementationClass), null, null); + } + + public String getBaseUrl(String modelName) { + + // aaron 16 june 2025 - leaving below for how this used to work, I think before I did some + // refactoring + // this method was not called all the time. Now it is, so if there is no model + // override just return the baseUrl. 
+ + // if (modelUrlOverrides != null && modelUrlOverrides.get(modelName) == null) { + // // modelUrlOverride is a work-around for self-hosted nvidia models with different + // url. + // // This is bad, initial design should have url in model level instead of provider + // level. + // // As best practice, when we deprecate or EOL a model: + // // we must mark the status in the configuration, + // // instead of removing the whole configuration entry. + // throw ErrorCodeV1.VECTORIZE_SERVICE_TYPE_UNAVAILABLE.toApiException( + // "unknown model '%s' for service provider '%s'", modelName, modelProvider); + // } + // return modelUrlOverrides != null ? modelUrlOverrides.get(modelName).orElse(baseUrl) : + // baseUrl; + + if (modelUrlOverrides == null) { + return baseUrl; + } + var override = modelUrlOverrides.get(modelName); + return override == null ? baseUrl : override.orElse(baseUrl); + } + } + + record ServiceRequestProperties( + int atMostRetries, + int initialBackOffMillis, + int readTimeoutMillis, + int maxBackOffMillis, + double jitter, + Optional requestTypeQuery, + Optional requestTypeIndex, + // `maxBatchSize` is the maximum number of documents to be sent in a single request to be + // embedding provider + int maxBatchSize) {} + + ServiceConfig getConfiguration(ModelProvider modelProvider); +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/gateway/EmbeddingGatewayClient.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/gateway/EmbeddingGatewayClient.java index 86ccfc4b21..6d53be8a95 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/gateway/EmbeddingGatewayClient.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/gateway/EmbeddingGatewayClient.java @@ -8,9 +8,10 @@ import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.exception.JsonApiException; -import 
io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.ServiceConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProvider; -import java.util.Collections; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -33,57 +34,48 @@ public class EmbeddingGatewayClient extends EmbeddingProvider { /** Map to the value of `Token` in the header */ private static final String DATA_API_TOKEN = "DATA_API_TOKEN"; - private EmbeddingProviderConfigStore.RequestProperties requestProperties; - - private String provider; - - private int dimension; + private ServiceConfigStore.ServiceRequestProperties requestProperties; private Optional tenant; private Optional authToken; - private String modelName; - private String baseUrl; - private EmbeddingService embeddingService; - private Map vectorizeServiceParameter; + private EmbeddingService grpcGatewayClient; Map authentication; private String commandName; - /** - * @param requestProperties - * @param provider - Embedding provider `openai`, `cohere`, etc - * @param dimension - Dimension of the embedding to be returned - * @param tenant - Tenant id {aka database id} - * @param authToken - Auth token for the tenant - * @param baseUrl - base url of the embedding client - * @param modelName - Model name for the embedding provider - * @param embeddingService - Embedding service client - * @param vectorizeServiceParameter - Additional parameters for the vectorize service - */ + /** */ public EmbeddingGatewayClient( - EmbeddingProviderConfigStore.RequestProperties requestProperties, - String provider, + ModelProvider modelProvider, + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig, + 
EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig modelConfig, + ServiceConfigStore.ServiceConfig serviceConfig, int dimension, + Map vectorizeServiceParameters, Optional tenant, Optional authToken, - String baseUrl, - String modelName, - EmbeddingService embeddingService, - Map vectorizeServiceParameter, + EmbeddingService grpcGatewayClient, Map authentication, String commandName) { - this.requestProperties = requestProperties; - this.provider = provider; - this.dimension = dimension; + super( + modelProvider, + providerConfig, + modelConfig, + serviceConfig, + dimension, + vectorizeServiceParameters); + this.tenant = tenant; this.authToken = authToken; - this.modelName = modelName; - this.baseUrl = baseUrl; - this.embeddingService = embeddingService; - this.vectorizeServiceParameter = vectorizeServiceParameter; + this.grpcGatewayClient = grpcGatewayClient; this.authentication = authentication; this.commandName = commandName; } + @Override + protected String errorMessageJsonPtr() { + // not used , this is passing through the grpc error + return ""; + } + /** * Vectorize the given list of texts * @@ -93,48 +85,52 @@ public EmbeddingGatewayClient( * @return */ @Override - public Uni vectorize( + public Uni vectorize( int batchId, List texts, EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { - Map - grpcVectorizeServiceParameter = new HashMap<>(); - if (vectorizeServiceParameter != null) { - vectorizeServiceParameter.forEach( + + var gatewayRequestParams = + new HashMap< + String, EmbeddingGateway.ProviderEmbedRequest.EmbeddingRequest.ParameterValue>(); + + if (vectorizeServiceParameters != null) { + vectorizeServiceParameters.forEach( (key, value) -> { if (value instanceof String) - grpcVectorizeServiceParameter.put( + gatewayRequestParams.put( key, EmbeddingGateway.ProviderEmbedRequest.EmbeddingRequest.ParameterValue.newBuilder() .setStrValue((String) value) .build()); else if (value instanceof Integer) - 
grpcVectorizeServiceParameter.put( + gatewayRequestParams.put( key, EmbeddingGateway.ProviderEmbedRequest.EmbeddingRequest.ParameterValue.newBuilder() .setIntValue((Integer) value) .build()); else if (value instanceof Float) - grpcVectorizeServiceParameter.put( + gatewayRequestParams.put( key, EmbeddingGateway.ProviderEmbedRequest.EmbeddingRequest.ParameterValue.newBuilder() .setFloatValue((Float) value) .build()); else if (value instanceof Boolean) - grpcVectorizeServiceParameter.put( + gatewayRequestParams.put( key, EmbeddingGateway.ProviderEmbedRequest.EmbeddingRequest.ParameterValue.newBuilder() .setBoolValue((Boolean) value) .build()); }); } - EmbeddingGateway.ProviderEmbedRequest.EmbeddingRequest embeddingRequest = + + var gatewayEmbedding = EmbeddingGateway.ProviderEmbedRequest.EmbeddingRequest.newBuilder() - .setModelName(modelName) + .setModelName(modelName()) .setDimensions(dimension) .setCommandName(commandName) - .putAllParameters(grpcVectorizeServiceParameter) + .putAllParameters(gatewayRequestParams) .setInputType( embeddingRequestType == EmbeddingRequestType.INDEX ? 
EmbeddingGateway.ProviderEmbedRequest.EmbeddingRequest.InputType.INDEX @@ -142,58 +138,60 @@ else if (value instanceof Boolean) .addAllInputs(texts) .build(); - final EmbeddingGateway.ProviderEmbedRequest.ProviderContext.Builder builder = + var contextBuilder = EmbeddingGateway.ProviderEmbedRequest.ProviderContext.newBuilder() - .setProviderName(provider) - .setTenantId(tenant.orElse(DEFAULT_TENANT_ID)); - // Add the value of `Token` in the header - builder.putAuthTokens(DATA_API_TOKEN, authToken.orElse("")); - // Add the value of `x-embedding-api-key` in the header - if (embeddingCredentials.apiKey().isPresent()) { - builder.putAuthTokens(EMBEDDING_API_KEY, embeddingCredentials.apiKey().get()); - } - // Add the value of `x-embedding-access-id` in the header - if (embeddingCredentials.accessId().isPresent()) { - builder.putAuthTokens(EMBEDDING_ACCESS_ID, embeddingCredentials.accessId().get()); - } - // Add the value of `x-embedding-secret-id` in the header - if (embeddingCredentials.secretId().isPresent()) { - builder.putAuthTokens(EMBEDDING_SECRET_ID, embeddingCredentials.secretId().get()); - } + .setProviderName(modelProvider().apiName()) + .setTenantId(tenant.orElse(DEFAULT_TENANT_ID)) + .putAuthTokens(DATA_API_TOKEN, authToken.orElse("")); + + embeddingCredentials + .apiKey() + .ifPresent(v -> contextBuilder.putAuthTokens(EMBEDDING_API_KEY, v)); + embeddingCredentials + .accessId() + .ifPresent(v -> contextBuilder.putAuthTokens(EMBEDDING_ACCESS_ID, v)); + embeddingCredentials + .secretId() + .ifPresent(v -> contextBuilder.putAuthTokens(EMBEDDING_SECRET_ID, v)); + // Add the `authentication` (sync service key) in the createCollection command if (authentication != null) { - builder.putAllAuthTokens(authentication); + contextBuilder.putAllAuthTokens(authentication); } - EmbeddingGateway.ProviderEmbedRequest.ProviderContext providerContext = builder.build(); - EmbeddingGateway.ProviderEmbedRequest providerEmbedRequest = + var gatewayRequest = 
EmbeddingGateway.ProviderEmbedRequest.newBuilder() - .setEmbeddingRequest(embeddingRequest) - .setProviderContext(providerContext) + .setEmbeddingRequest(gatewayEmbedding) + .setProviderContext(contextBuilder.build()) .build(); + + // aaron 17 June 2025 - unsure why this error handling was not in the uni pipeline below + // kept it as is when refactoring Uni embeddingResponse; try { - embeddingResponse = embeddingService.embed(providerEmbedRequest); + embeddingResponse = grpcGatewayClient.embed(gatewayRequest); } catch (StatusRuntimeException e) { if (e.getStatus().getCode().equals(Status.Code.DEADLINE_EXCEEDED)) { throw ErrorCodeV1.EMBEDDING_PROVIDER_TIMEOUT.toApiException(e, e.getMessage()); } throw e; } + return embeddingResponse .onItem() .transform( - resp -> { - if (resp.hasError()) { + gatewayResponse -> { + // TODO : move to V2 error + if (gatewayResponse.hasError()) { throw new JsonApiException( - ErrorCodeV1.valueOf(resp.getError().getErrorCode()), - resp.getError().getErrorMessage()); - } - if (resp.getEmbeddingsList() == null) { - return Response.of(batchId, Collections.emptyList()); + ErrorCodeV1.valueOf(gatewayResponse.getError().getErrorCode()), + gatewayResponse.getError().getErrorMessage()); } + // aaron - 10 June 2025 - previous code would silently swallow no data returned + // but grpc will make sure gatewayResponse.getEmbeddingsList() is never null + final List vectors = - resp.getEmbeddingsList().stream() + gatewayResponse.getEmbeddingsList().stream() .map( data -> { float[] embedding = new float[data.getEmbeddingCount()]; @@ -203,7 +201,8 @@ else if (value instanceof Boolean) return embedding; }) .toList(); - return Response.of(batchId, vectors); + return new BatchedEmbeddingResponse( + batchId, vectors, createModelUsage(gatewayResponse.getModelUsage())); }); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AwsBedrockEmbeddingProvider.java 
b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AwsBedrockEmbeddingProvider.java index 7903fe6de7..558f9e993e 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AwsBedrockEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AwsBedrockEmbeddingProvider.java @@ -9,169 +9,210 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectReader; import com.fasterxml.jackson.databind.ObjectWriter; +import com.google.common.io.CountingOutputStream; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.ServiceConfigStore; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import jakarta.ws.rs.core.Response; import java.io.IOException; +import java.io.OutputStream; import java.util.List; import java.util.Map; -import java.util.concurrent.CompletableFuture; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; import software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeException; -import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; /** Provider implementation for AWS Bedrock. To start we support only Titan embedding models. 
*/ public class AwsBedrockEmbeddingProvider extends EmbeddingProvider { - private static final String providerId = ProviderConstants.BEDROCK; - private static final ObjectWriter ow = new ObjectMapper().writer(); - private static final ObjectReader or = new ObjectMapper().reader(); + private static final ObjectWriter OBJECT_WRITER = new ObjectMapper().writer(); + private static final ObjectReader OBJECT_READER = new ObjectMapper().reader(); public AwsBedrockEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties requestProperties, - String baseUrl, - EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig model, + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig, + EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig modelConfig, + ServiceConfigStore.ServiceConfig serviceConfig, int dimension, - Map vectorizeServiceParameters, - EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig) { + Map vectorizeServiceParameters) { super( - requestProperties, - baseUrl, - model, - acceptsTitanAIDimensions(model.name()) ? dimension : 0, - vectorizeServiceParameters, - providerConfig); + ModelProvider.BEDROCK, + providerConfig, + modelConfig, + serviceConfig, + acceptsTitanAIDimensions(modelConfig.name()) ? 
dimension : 0, + vectorizeServiceParameters); } @Override - public Uni vectorize( + protected String errorMessageJsonPtr() { + // not used in this provider, has custom error handling + return ""; + } + + @Override + public Uni vectorize( int batchId, List texts, EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { - // Check if using an EOF model + checkEOLModelUsage(); + + // the config should mean we only do a batch of 1, sanity checking + if (texts.size() != 1) { + throw new IllegalArgumentException( + "AWS Bedrock embedding provider only supports a single text input per request, but received: " + + texts.size()); + } + + // TODO: move to V2 errors if (embeddingCredentials.accessId().isEmpty() && embeddingCredentials.secretId().isEmpty()) { throw ErrorCodeV1.EMBEDDING_PROVIDER_AUTHENTICATION_KEYS_NOT_PROVIDED.toApiException( "Both '%s' and '%s' are missing in the header for provider '%s'", EMBEDDING_AUTHENTICATION_ACCESS_ID_HEADER_NAME, EMBEDDING_AUTHENTICATION_SECRET_ID_HEADER_NAME, - providerId); + modelProvider().apiName()); } else if (embeddingCredentials.accessId().isEmpty()) { throw ErrorCodeV1.EMBEDDING_PROVIDER_AUTHENTICATION_KEYS_NOT_PROVIDED.toApiException( "'%s' is missing in the header for provider '%s'", - EMBEDDING_AUTHENTICATION_ACCESS_ID_HEADER_NAME, providerId); + EMBEDDING_AUTHENTICATION_ACCESS_ID_HEADER_NAME, modelProvider().apiName()); } else if (embeddingCredentials.secretId().isEmpty()) { throw ErrorCodeV1.EMBEDDING_PROVIDER_AUTHENTICATION_KEYS_NOT_PROVIDED.toApiException( "'%s' is missing in the header for provider '%s'", - EMBEDDING_AUTHENTICATION_SECRET_ID_HEADER_NAME, providerId); + EMBEDDING_AUTHENTICATION_SECRET_ID_HEADER_NAME, modelProvider().apiName()); } - AwsBasicCredentials awsCreds = + var awsCreds = AwsBasicCredentials.create( embeddingCredentials.accessId().get(), embeddingCredentials.secretId().get()); - BedrockRuntimeAsyncClient client = + try (var bedrockClient = 
BedrockRuntimeAsyncClient.builder() .credentialsProvider(StaticCredentialsProvider.create(awsCreds)) .region(Region.of(vectorizeServiceParameters.get("region").toString())) - .build(); - final CompletableFuture invokeModelResponseCompletableFuture = - client.invokeModel( - request -> { - final byte[] inputData; - try { - inputData = ow.writeValueAsBytes(new EmbeddingRequest(texts.get(0), dimension)); - request.body(SdkBytes.fromByteArray(inputData)).modelId(model.name()); - } catch (JsonProcessingException e) { - throw ErrorCodeV1.EMBEDDING_REQUEST_ENCODING_ERROR.toApiException(); - } - }); - - final CompletableFuture responseCompletableFuture = - invokeModelResponseCompletableFuture.thenApply( - res -> { - try { - EmbeddingResponse response = - or.readValue(res.body().asInputStream(), EmbeddingResponse.class); - List vectors = List.of(response.embedding); - return Response.of(batchId, vectors); - } catch (IOException e) { - throw ErrorCodeV1.EMBEDDING_RESPONSE_DECODING_ERROR.toApiException(); - } - }); - - return Uni.createFrom() - .completionStage(responseCompletableFuture) - .onFailure(BedrockRuntimeException.class) - .transform( - error -> { - BedrockRuntimeException bedrockRuntimeException = (BedrockRuntimeException) error; - // Status code == 408 and 504 for timeout - if (bedrockRuntimeException.statusCode() - == jakarta.ws.rs.core.Response.Status.REQUEST_TIMEOUT.getStatusCode() - || bedrockRuntimeException.statusCode() - == jakarta.ws.rs.core.Response.Status.GATEWAY_TIMEOUT.getStatusCode()) { - return ErrorCodeV1.EMBEDDING_PROVIDER_TIMEOUT.toApiException( - "Provider: %s; HTTP Status: %s; Error Message: %s", - providerId, - bedrockRuntimeException.statusCode(), - bedrockRuntimeException.getMessage()); - } - - // Status code == 429 - if (bedrockRuntimeException.statusCode() - == jakarta.ws.rs.core.Response.Status.TOO_MANY_REQUESTS.getStatusCode()) { - return ErrorCodeV1.EMBEDDING_PROVIDER_RATE_LIMITED.toApiException( - "Provider: %s; HTTP Status: %s; Error 
Message: %s", - providerId, - bedrockRuntimeException.statusCode(), - bedrockRuntimeException.getMessage()); - } - - // Status code in 4XX other than 429 - if (bedrockRuntimeException.statusCode() > 400 - && bedrockRuntimeException.statusCode() < 500) { - return ErrorCodeV1.EMBEDDING_PROVIDER_CLIENT_ERROR.toApiException( - "Provider: %s; HTTP Status: %s; Error Message: %s", - providerId, - bedrockRuntimeException.statusCode(), - bedrockRuntimeException.getMessage()); - } - - // Status code in 5XX - if (bedrockRuntimeException.statusCode() >= 500) { - return ErrorCodeV1.EMBEDDING_PROVIDER_SERVER_ERROR.toApiException( - "Provider: %s; HTTP Status: %s; Error Message: %s", - providerId, - bedrockRuntimeException.statusCode(), - bedrockRuntimeException.getMessage()); - } - - // All other errors, Should never happen as all errors are covered above - return ErrorCodeV1.EMBEDDING_PROVIDER_UNEXPECTED_RESPONSE.toApiException( - "Provider: %s; HTTP Status: %s; Error Message: %s", - providerId, - bedrockRuntimeException.statusCode(), - bedrockRuntimeException.getMessage()); - }); + .build()) { + + long callStartNano = System.nanoTime(); + + // NOTE: need to use the AWS client for the request, not a Rest Easy, so we cannot use + // all the features from the superclasses such as error mapping and building the model usage + var bytesUsageTracker = new ByteUsageTracker(); + var bedrockFuture = + bedrockClient + .invokeModel( + requestBuilder -> { + try { + var inputData = + OBJECT_WRITER.writeValueAsBytes( + new AwsBedrockEmbeddingRequest(texts.getFirst(), dimension)); + bytesUsageTracker.requestBytes = inputData.length; + requestBuilder.body(SdkBytes.fromByteArray(inputData)).modelId(modelName()); + } catch (JsonProcessingException e) { + throw ErrorCodeV1.EMBEDDING_REQUEST_ENCODING_ERROR.toApiException(); + } + }) + .thenApply( + rawResponse -> { + try { + // aws docs say do not need to close the stream + var inputStream = rawResponse.body().asInputStream(); + var 
bedrockResponse = + OBJECT_READER.readValue(inputStream, AwsBedrockEmbeddingResponse.class); + long callDurationNano = System.nanoTime() - callStartNano; + + try (var countingOut = + new CountingOutputStream(OutputStream.nullOutputStream())) { + inputStream.transferTo(countingOut); + long responseSize = countingOut.getCount(); + bytesUsageTracker.responseBytes = + responseSize > Integer.MAX_VALUE + ? Integer.MAX_VALUE + : (int) responseSize; + } + + var modelUsage = + createModelUsage( + embeddingCredentials.tenantId(), + ModelInputType.fromEmbeddingRequestType(embeddingRequestType), + bedrockResponse.inputTextTokenCount(), + bedrockResponse.inputTextTokenCount(), + bytesUsageTracker.requestBytes, + bytesUsageTracker.responseBytes, + callDurationNano); + + return new BatchedEmbeddingResponse( + batchId, List.of(bedrockResponse.embedding), modelUsage); + + } catch (IOException e) { + throw ErrorCodeV1.EMBEDDING_RESPONSE_DECODING_ERROR.toApiException(); + } + }); + + return Uni.createFrom() + .completionStage(bedrockFuture) + .onFailure(BedrockRuntimeException.class) + .transform(throwable -> mapBedrockException((BedrockRuntimeException) throwable)); + } } - private record EmbeddingRequest( - String inputText, @JsonInclude(value = JsonInclude.Include.NON_DEFAULT) int dimensions) {} + private Throwable mapBedrockException(BedrockRuntimeException bedrockException) { - @JsonIgnoreProperties(ignoreUnknown = true) // ignore possible extra fields without error - private record EmbeddingResponse(float[] embedding, int inputTextTokenCount) {} + if (bedrockException.statusCode() == Response.Status.REQUEST_TIMEOUT.getStatusCode() + || bedrockException.statusCode() == Response.Status.GATEWAY_TIMEOUT.getStatusCode()) { + return ErrorCodeV1.EMBEDDING_PROVIDER_TIMEOUT.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), bedrockException.statusCode(), bedrockException.getMessage()); + } - @Override - public int maxBatchSize() { - return 
requestProperties.maxBatchSize(); + if (bedrockException.statusCode() == Response.Status.TOO_MANY_REQUESTS.getStatusCode()) { + return ErrorCodeV1.EMBEDDING_PROVIDER_RATE_LIMITED.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), bedrockException.statusCode(), bedrockException.getMessage()); + } + + if (bedrockException.statusCode() > 400 && bedrockException.statusCode() < 500) { + return ErrorCodeV1.EMBEDDING_PROVIDER_CLIENT_ERROR.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), bedrockException.statusCode(), bedrockException.getMessage()); + } + + if (bedrockException.statusCode() >= 500) { + return ErrorCodeV1.EMBEDDING_PROVIDER_SERVER_ERROR.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), bedrockException.statusCode(), bedrockException.getMessage()); + } + + // All other errors, Should never happen as all errors are covered above + return ErrorCodeV1.EMBEDDING_PROVIDER_UNEXPECTED_RESPONSE.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), bedrockException.statusCode(), bedrockException.getMessage()); } + + private static class ByteUsageTracker { + int requestBytes = 0; + int responseBytes = 0; + } + + /** + * Request structure of the AWS Bedrock REST service. + * + *

.. + */ + public record AwsBedrockEmbeddingRequest( + String inputText, @JsonInclude(value = JsonInclude.Include.NON_DEFAULT) int dimensions) {} + + /** + * Response structure of the AWS Bedrock REST service. + * + *

.. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + private record AwsBedrockEmbeddingResponse(float[] embedding, int inputTextTokenCount) {} } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AzureOpenAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AzureOpenAIEmbeddingProvider.java index f6de08edd1..033581bd48 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AzureOpenAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AzureOpenAIEmbeddingProvider.java @@ -2,23 +2,22 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.databind.JsonNode; -import io.quarkus.rest.client.reactive.ClientExceptionMapper; import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; -import io.stargate.sgv2.jsonapi.service.embedding.operation.error.EmbeddingProviderErrorMapper; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.ServiceConfigStore; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.Arrays; 
-import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -32,125 +31,141 @@ * details of REST API being called. */ public class AzureOpenAIEmbeddingProvider extends EmbeddingProvider { - private static final String providerId = ProviderConstants.AZURE_OPENAI; - private final OpenAIEmbeddingProviderClient openAIEmbeddingProviderClient; + + private final AzureOpenAIEmbeddingProviderClient azureClient; public AzureOpenAIEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties requestProperties, - String baseUrl, - EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig model, + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig, + EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig modelConfig, + ServiceConfigStore.ServiceConfig serviceConfig, int dimension, - Map vectorizeServiceParameters, - EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig) { + Map vectorizeServiceParameters) { // One special case: legacy "ada-002" model does not accept "dimension" parameter super( - requestProperties, - baseUrl, - model, - acceptsOpenAIDimensions(model.name()) ? dimension : 0, - vectorizeServiceParameters, - providerConfig); - - String actualUrl = replaceParameters(baseUrl, vectorizeServiceParameters); - openAIEmbeddingProviderClient = + ModelProvider.AZURE_OPENAI, + providerConfig, + modelConfig, + serviceConfig, + acceptsOpenAIDimensions(modelConfig.name()) ? 
dimension : 0, + vectorizeServiceParameters); + + String actualUrl = + replaceParameters(serviceConfig.getBaseUrl(modelName()), vectorizeServiceParameters); + azureClient = QuarkusRestClientBuilder.newBuilder() .baseUri(URI.create(actualUrl)) - .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) - .build(OpenAIEmbeddingProviderClient.class); + .readTimeout(requestProperties().readTimeoutMillis(), TimeUnit.MILLISECONDS) + .build(AzureOpenAIEmbeddingProviderClient.class); } - @RegisterRestClient - @RegisterProvider(EmbeddingProviderResponseValidation.class) - public interface OpenAIEmbeddingProviderClient { - @POST - // no path specified, as it is already included in the baseUri - @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) - Uni embed( - // API keys as "api-key", MS Entra as "Authorization: Bearer [token] - @HeaderParam("api-key") String accessToken, EmbeddingRequest request); - - @ClientExceptionMapper - static RuntimeException mapException(jakarta.ws.rs.core.Response response) { - String errorMessage = getErrorMessage(response); - return EmbeddingProviderErrorMapper.mapToAPIException(providerId, response, errorMessage); - } - - /** - * Extract the error message from the response body. The example response body is: - * - *

-     * {
-     *   "error": {
-     *     "code": "401",
-     *     "message": "Access denied due to invalid subscription key or wrong API endpoint. Make sure to provide a valid key for an active subscription and use a correct regional API endpoint for your resource."
-     *   }
-     * }
-     * 
- * - * @param response The response body as a String. - * @return The error message extracted from the response body. - */ - private static String getErrorMessage(jakarta.ws.rs.core.Response response) { - // Get the whole response body - JsonNode rootNode = response.readEntity(JsonNode.class); - // Log the response body - logger.error( - "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); - // Extract the "message" node from the "error" node - JsonNode messageNode = rootNode.at("/error/message"); - // Return the text of the "message" node, or the whole response body if it is missing - return messageNode.isMissingNode() ? rootNode.toString() : messageNode.toString(); - } - } - - private record EmbeddingRequest( - String[] input, - String model, - @JsonInclude(value = JsonInclude.Include.NON_DEFAULT) int dimensions) {} - - @JsonIgnoreProperties(ignoreUnknown = true) // ignore possible extra fields without error - private record EmbeddingResponse(String object, Data[] data, String model, Usage usage) { - @JsonIgnoreProperties(ignoreUnknown = true) - private record Data(String object, int index, float[] embedding) {} - - @JsonIgnoreProperties(ignoreUnknown = true) - private record Usage(int prompt_tokens, int total_tokens) {} + /** + * The example response body is: + * + *
+   * {
+   *   "error": {
+   *     "code": "401",
+   *     "message": "Access denied due to invalid subscription key or wrong API endpoint. Make sure to provide a valid key for an active subscription and use a correct regional API endpoint for your resource."
+   *   }
+   * }
+   * 
+ */ + @Override + protected String errorMessageJsonPtr() { + return "/error/message"; } @Override - public Uni vectorize( + public Uni vectorize( int batchId, List texts, EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { - // Check if using an EOF model + checkEOLModelUsage(); - checkEmbeddingApiKeyHeader(providerId, embeddingCredentials.apiKey()); - String[] textArray = new String[texts.size()]; - EmbeddingRequest request = - new EmbeddingRequest(texts.toArray(textArray), model.name(), dimension); + checkEmbeddingApiKeyHeader(embeddingCredentials.apiKey()); + var azureRequest = + new AzureOpenAIEmbeddingRequest( + texts.toArray(new String[texts.size()]), modelName(), dimension); - // NOTE: NO "Bearer " prefix with API key for Azure OpenAI - Uni response = - applyRetry( - openAIEmbeddingProviderClient.embed(embeddingCredentials.apiKey().get(), request)); + // TODO: V2 error + // aaron 8 June 2025 - old code had NO comment to explain what happens if the API key is empty. + // NOTE: NO "Bearer " prefix with API key for Azure + var accessToken = embeddingCredentials.apiKey().get(); - return response + long callStartNano = System.nanoTime(); + return retryHTTPCall(azureClient.embed(accessToken, azureRequest)) .onItem() .transform( - resp -> { - if (resp.data() == null) { - return Response.of(batchId, Collections.emptyList()); + jakartaResponse -> { + var azureResponse = + decodeResponse(jakartaResponse, AzureOpenAIEmbeddingResponse.class); + long callDurationNano = System.nanoTime() - callStartNano; + + // aaron - 10 June 2025 - previous code would silently swallow no data returned + // and return an empty result. If we made a request we should get a response. 
+ if (azureResponse.data() == null) { + throwEmptyData(jakartaResponse); } - Arrays.sort(resp.data(), (a, b) -> a.index() - b.index()); + + Arrays.sort(azureResponse.data(), (a, b) -> a.index() - b.index()); List vectors = - Arrays.stream(resp.data()).map(data -> data.embedding()).toList(); - return Response.of(batchId, vectors); + Arrays.stream(azureResponse.data()) + .map(AzureOpenAIEmbeddingResponse.Data::embedding) + .toList(); + + var modelUsage = + createModelUsage( + embeddingCredentials.tenantId(), + ModelInputType.fromEmbeddingRequestType(embeddingRequestType), + azureResponse.usage().prompt_tokens(), + azureResponse.usage().total_tokens(), + jakartaResponse, + callDurationNano); + return new BatchedEmbeddingResponse(batchId, vectors, modelUsage); }); } - @Override - public int maxBatchSize() { - return requestProperties.maxBatchSize(); + /** + * REST client interface for the Azure Open AI Embedding Service. + * + *

.. + */ + @RegisterRestClient + @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(ProviderHttpInterceptor.class) + public interface AzureOpenAIEmbeddingProviderClient { + // no path specified, as it is already included in the baseUri + @POST + @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) + Uni embed( + // API keys as "api-key", MS Entra as "Authorization: Bearer [token] + @HeaderParam("api-key") String accessToken, AzureOpenAIEmbeddingRequest request); + } + + /** + * Request structure of the Azure Open AI REST service. + * + *

.. + */ + public record AzureOpenAIEmbeddingRequest( + String[] input, + String model, + @JsonInclude(value = JsonInclude.Include.NON_DEFAULT) int dimensions) {} + + /** + * Response structure of the Azure Open AI REST service. + * + *

.. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + private record AzureOpenAIEmbeddingResponse( + String object, Data[] data, String model, Usage usage) { + + @JsonIgnoreProperties(ignoreUnknown = true) + private record Data(String object, int index, float[] embedding) {} + + @JsonIgnoreProperties(ignoreUnknown = true) + private record Usage(int prompt_tokens, int total_tokens) {} } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/CohereEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/CohereEmbeddingProvider.java index 62ea3ddb88..ca6fb86d3a 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/CohereEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/CohereEmbeddingProvider.java @@ -1,24 +1,25 @@ package io.stargate.sgv2.jsonapi.service.embedding.operation; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.JsonNode; -import io.quarkus.rest.client.reactive.ClientExceptionMapper; import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; -import io.stargate.sgv2.jsonapi.service.embedding.operation.error.EmbeddingProviderErrorMapper; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.ServiceConfigStore; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import 
io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; import jakarta.ws.rs.Path; import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; import java.net.URI; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -31,136 +32,157 @@ * of chosen Cohere model. */ public class CohereEmbeddingProvider extends EmbeddingProvider { - private static final String providerId = ProviderConstants.COHERE; - private final CohereEmbeddingProviderClient cohereEmbeddingProviderClient; + + private final CohereEmbeddingProviderClient cohereClient; public CohereEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties requestProperties, - String baseUrl, - EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig model, + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig, + EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig modelConfig, + ServiceConfigStore.ServiceConfig serviceConfig, int dimension, - Map vectorizeServiceParameters, - EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig) { - super(requestProperties, baseUrl, model, dimension, vectorizeServiceParameters, providerConfig); - - cohereEmbeddingProviderClient = + Map vectorizeServiceParameters) { + super( + ModelProvider.COHERE, + providerConfig, + modelConfig, + serviceConfig, + dimension, + vectorizeServiceParameters); + + cohereClient = QuarkusRestClientBuilder.newBuilder() - .baseUri(URI.create(baseUrl)) - .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) + .baseUri(URI.create(serviceConfig.getBaseUrl(modelName()))) + .readTimeout(requestProperties().readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(CohereEmbeddingProviderClient.class); } - @RegisterRestClient - 
@RegisterProvider(EmbeddingProviderResponseValidation.class) - public interface CohereEmbeddingProviderClient { - @POST - @Path("/embed") - @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) - Uni embed( - @HeaderParam("Authorization") String accessToken, EmbeddingRequest request); - - @ClientExceptionMapper - static RuntimeException mapException(jakarta.ws.rs.core.Response response) { - String errorMessage = getErrorMessage(response); - return EmbeddingProviderErrorMapper.mapToAPIException(providerId, response, errorMessage); - } - - /** - * Extract the error message from the response body. The example response body is: - * - *

-     * {
-     *   "message": "invalid api token"
-     * }
-     *
-     * 429 response body:
-     * {
-     *   "data": "string"
-     * }
-     * 
- * - * @param response The response body as a String. - * @return The error message extracted from the response body. - */ - private static String getErrorMessage(jakarta.ws.rs.core.Response response) { - // Get the whole response body - JsonNode rootNode = response.readEntity(JsonNode.class); - // Log the response body - logger.error( - "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); - // Check if the root node contains a "message" field - JsonNode messageNode = rootNode.path("message"); - if (!messageNode.isMissingNode()) { - return messageNode.toString(); - } - // Check if the root node contains a "data" field - JsonNode dataNode = rootNode.path("data"); - if (!dataNode.isMissingNode()) { - return dataNode.toString(); - } - // Return the whole response body if no message or data field is found - return rootNode.toString(); - } + @Override + protected String errorMessageJsonPtr() { + // overriding the function that calls this + return ""; } - private record EmbeddingRequest(String[] texts, String model, String input_type) {} - - // @JsonIgnoreProperties({"id", "texts", "meta", "response_type"}) - @JsonIgnoreProperties(ignoreUnknown = true) // ignore possible extra fields without error - private static class EmbeddingResponse { - - protected EmbeddingResponse() {} - - private List embeddings; + /** + * The example response body is: + * + *
+   * {
+   *   "message": "invalid api token"
+   * }
+   *
+   * 429 response body:
+   * {
+   *   "data": "string"
+   * }
+   */
+  @Override
+  protected String responseErrorMessage(JsonNode rootNode) {
 
-    public List getEmbeddings() {
-      return embeddings;
+    JsonNode messageNode = rootNode.path("message");
+    if (!messageNode.isMissingNode()) {
+      return messageNode.toString();
     }
 
-    public void setEmbeddings(List embeddings) {
-      this.embeddings = embeddings;
+    JsonNode dataNode = rootNode.path("data");
+    if (!dataNode.isMissingNode()) {
+      return dataNode.toString();
     }
-  }
 
-  // Input type to be used for vector search should "search_query"
-  private static final String SEARCH_QUERY = "search_query";
-  private static final String SEARCH_DOCUMENT = "search_document";
+    // Return the whole response body if no message or data field is found
+    return rootNode.toString();
+  }
 
   @Override
-  public Uni vectorize(
+  public Uni vectorize(
       int batchId,
       List texts,
       EmbeddingCredentials embeddingCredentials,
       EmbeddingRequestType embeddingRequestType) {
-    // Check if using an EOF model
+
     checkEOLModelUsage();
-    checkEmbeddingApiKeyHeader(providerId, embeddingCredentials.apiKey());
+    checkEmbeddingApiKeyHeader(embeddingCredentials.apiKey());
 
-    String[] textArray = new String[texts.size()];
-    String input_type =
-        embeddingRequestType == EmbeddingRequestType.INDEX ? SEARCH_DOCUMENT : SEARCH_QUERY;
-    EmbeddingRequest request =
-        new EmbeddingRequest(texts.toArray(textArray), model.name(), input_type);
+    // Input type to be used for vector search should be "search_query"
+    var input_type =
+        embeddingRequestType == EmbeddingRequestType.INDEX ? "search_document" : "search_query";
+    var cohereRequest =
+        new CohereEmbeddingRequest(
+            texts.toArray(new String[texts.size()]), modelName(), input_type);
 
-    Uni response =
-        applyRetry(
-            cohereEmbeddingProviderClient.embed(
-                HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(),
-                request));
+    // TODO: V2 error
+    // aaron 8 June 2025 - old code had NO comment to explain what happens if the API key is empty.
+    var accessToken = HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get();
 
-    return response
+    long callStartNano = System.nanoTime();
+
+    return retryHTTPCall(cohereClient.embed(accessToken, cohereRequest))
         .onItem()
         .transform(
-            resp -> {
-              if (resp.getEmbeddings() == null) {
-                return Response.of(batchId, Collections.emptyList());
+            jakartaResponse -> {
+              var cohereResponse = decodeResponse(jakartaResponse, CohereEmbeddingResponse.class);
+              long callDurationNano = System.nanoTime() - callStartNano;
+
+              // aaron - 10 June 2025 - previous code would silently swallow no data returned
+              // and return an empty result. If we made a request we should get a response.
+              if (cohereResponse.embeddings() == null) {
+                throwEmptyData(jakartaResponse);
               }
-              return Response.of(batchId, resp.getEmbeddings());
+
+              var modelUsage =
+                  createModelUsage(
+                      embeddingCredentials.tenantId(),
+                      ModelInputType.fromEmbeddingRequestType(embeddingRequestType),
+                      cohereResponse.meta().billed_units().input_tokens(),
+                      cohereResponse.meta().billed_units().input_tokens(),
+                      jakartaResponse,
+                      callDurationNano);
+              return new BatchedEmbeddingResponse(
+                  batchId, cohereResponse.embeddings().values(), modelUsage);
             });
   }
 
-  @Override
-  public int maxBatchSize() {
-    return requestProperties.maxBatchSize();
+  /**
+   * REST client interface for the Cohere Embedding Service.
+   *
+   * 

.. + */ + @RegisterRestClient + @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(ProviderHttpInterceptor.class) + public interface CohereEmbeddingProviderClient { + @POST + @Path("/embed") + @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) + Uni embed( + @HeaderParam("Authorization") String accessToken, CohereEmbeddingRequest request); + } + + /** + * Request structure of the Cohere REST service. + * + *

.. + */ + public record CohereEmbeddingRequest(String[] texts, String model, String input_type) {} + + /** + * Response structure of the Cohere REST service. + * + *

aaron - 9 June 2025, change from class to record, check git if this breaks. + * https://docs.cohere.com/reference/embed#response + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public record CohereEmbeddingResponse( + String id, List texts, Embeddings embeddings, Meta meta) { + @JsonIgnoreProperties(ignoreUnknown = true) + public record Embeddings(@JsonProperty("float") List values) {} + + @JsonIgnoreProperties(ignoreUnknown = true) + public record Meta(ApiVersion api_version, BilledUnits billed_units, List warnings) { + @JsonIgnoreProperties(ignoreUnknown = true) + public record ApiVersion(String version, boolean is_experimental) {} + + @JsonIgnoreProperties(ignoreUnknown = true) + public record BilledUnits(int input_tokens) {} + } } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProvider.java index e64b20d1cc..7781a96961 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProvider.java @@ -2,110 +2,74 @@ import static io.stargate.sgv2.jsonapi.config.constants.HttpConstants.EMBEDDING_AUTHENTICATION_TOKEN_HEADER_NAME; import static io.stargate.sgv2.jsonapi.exception.ErrorCodeV1.EMBEDDING_PROVIDER_API_KEY_MISSING; +import static jakarta.ws.rs.core.Response.Status.Family.CLIENT_ERROR; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.exception.JsonApiException; -import io.stargate.sgv2.jsonapi.exception.SchemaException; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; -import io.stargate.sgv2.jsonapi.service.provider.ApiModelSupport; 
+import io.stargate.sgv2.jsonapi.service.embedding.configuration.ServiceConfigStore; +import io.stargate.sgv2.jsonapi.service.provider.*; import io.stargate.sgv2.jsonapi.util.recordable.Recordable; +import jakarta.ws.rs.core.Response; import java.time.Duration; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.concurrent.TimeoutException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * Interface that accepts a list of texts that needs to be vectorized and returns embeddings based - * of chosen model. - */ -public abstract class EmbeddingProvider { - protected static final Logger logger = LoggerFactory.getLogger(EmbeddingProvider.class); - protected final EmbeddingProviderConfigStore.RequestProperties requestProperties; - protected final String baseUrl; - protected final EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig model; +/** A provider for Embedding models , using {@link ModelType#EMBEDDING} */ +public abstract class EmbeddingProvider extends ProviderBase { + + protected static final Logger LOGGER = LoggerFactory.getLogger(EmbeddingProvider.class); + + // IMPORTANT: all of these config objects have some form of a request properties config, + // use the one from the serviceConfig, as it should be the most specific for this + // schema object. 
We should be able to remove ServiceConfig later - aaron 16 June 2025 + // use {@link #requestProperties()} to access the request properties + protected final EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig; + protected final EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig modelConfig; + protected final ServiceConfigStore.ServiceConfig serviceConfig; + protected final int dimension; protected final Map vectorizeServiceParameters; - protected final EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig; - /** Default constructor */ - protected EmbeddingProvider() { - this(null, null, null, 0, null, null); - } + protected final Duration initialBackOffDuration; + protected final Duration maxBackOffDuration; - /** Constructs an EmbeddingProvider with the specified configuration. */ protected EmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties requestProperties, - String baseUrl, - EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig model, + ModelProvider modelProvider, + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig, + EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig modelConfig, + ServiceConfigStore.ServiceConfig serviceConfig, int dimension, - Map vectorizeServiceParameters, - EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig) { - this.requestProperties = requestProperties; - this.baseUrl = baseUrl; - this.model = model; + Map vectorizeServiceParameters) { + super(modelProvider, ModelType.EMBEDDING); + + this.providerConfig = providerConfig; + this.modelConfig = modelConfig; + this.serviceConfig = serviceConfig; this.dimension = dimension; this.vectorizeServiceParameters = vectorizeServiceParameters; - this.providerConfig = providerConfig; + + this.initialBackOffDuration = Duration.ofMillis(requestProperties().initialBackOffMillis()); + this.maxBackOffDuration = Duration.ofMillis(requestProperties().maxBackOffMillis()); } - public 
EmbeddingProvidersConfig.EmbeddingProviderConfig getProviderConfig() { - return providerConfig; + @Override + public String modelName() { + return modelConfig.name(); } - /** - * Applies a retry mechanism with backoff and jitter to the Uni returned by the embed() method, - * which makes an HTTP request to a third-party service. - * - * @param The type of the item emitted by the Uni. - * @param uni The Uni to which the retry mechanism should be applied. - * @return A Uni that will retry on the specified failures with the configured backoff and jitter. - */ - protected Uni applyRetry(Uni uni) { - return uni.onFailure( - throwable -> - (throwable.getCause() != null - && throwable.getCause() instanceof JsonApiException jae - && jae.getErrorCode() == ErrorCodeV1.EMBEDDING_PROVIDER_TIMEOUT) - || throwable instanceof TimeoutException) - .retry() - .withBackOff( - Duration.ofMillis(requestProperties.initialBackOffMillis()), - Duration.ofMillis(requestProperties.maxBackOffMillis())) - .withJitter(requestProperties.jitter()) - .atMost(requestProperties.atMostRetries()); + @Override + public ApiModelSupport modelSupport() { + return modelConfig.apiModelSupport(); } - /** - * Checks if the vectorization will use an END_OF_LIFE model and throws an exception if it is. - * - *

As part of embedding model deprecation ability, any read and write with vectorization in an - * END_OF_LIFE model will throw an exception. - * - *

Note, SUPPORTED and DEPRECATED models are still allowed to be used in read and write. - * - *

This method should be called before any vectorization operation. - */ - protected void checkEOLModelUsage() { - // Validate if the model is END_OF_LIFE - if (model.apiModelSupport().status() == ApiModelSupport.SupportStatus.END_OF_LIFE) { - throw SchemaException.Code.END_OF_LIFE_AI_MODEL.get( - Map.of( - "model", - model.name(), - "modelStatus", - model.apiModelSupport().status().name(), - "message", - model - .apiModelSupport() - .message() - .orElse("The model is no longer supported (reached its end-of-life)."))); - } + public EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig() { + return providerConfig; } /** @@ -116,7 +80,7 @@ protected void checkEOLModelUsage() { * @param embeddingRequestType Type of request (INDEX or SEARCH) * @return VectorResponse */ - public abstract Uni vectorize( + public abstract Uni vectorize( int batchId, List texts, EmbeddingCredentials embeddingCredentials, @@ -127,7 +91,17 @@ public abstract Uni vectorize( * * @return */ - public abstract int maxBatchSize(); + public int maxBatchSize() { + return requestProperties().maxBatchSize(); + } + + /** + * Use this to get the properties for the request, including the URL , see comment at the top of + * class + */ + protected ServiceConfigStore.ServiceRequestProperties requestProperties() { + return serviceConfig.requestProperties(); + } /** * Helper method that has logic wrt whether OpenAI (azure or regular) accepts {@code "dimensions"} @@ -165,9 +139,9 @@ protected static boolean acceptsTitanAIDimensions(String modelName) { } /** - * Helper method to replace parameters in a messageTemplate string with values from a map: - * placeholders are of form {@code {parameterName}} and matching value to look for in the map is - * String {@code "parameterName"}. + * Replace parameters in a messageTemplate string with values from a map: placeholders are of form + * {@code {parameterName}} and matching value to look for in the map is String {@code + * "parameterName"}. 
* * @param template Template with placeholders to replace * @param parameters Parameters to replace in the messageTemplate @@ -198,30 +172,108 @@ protected String replaceParameters(String template, Map paramete return baseUrl.toString(); } - /** Helper method to check if the API key is present in the header */ - protected void checkEmbeddingApiKeyHeader(String providerId, Optional apiKey) { + /** Check if the API key is present in the header */ + protected void checkEmbeddingApiKeyHeader(Optional apiKey) { + if (apiKey.isEmpty()) { throw EMBEDDING_PROVIDER_API_KEY_MISSING.toApiException( "header value `%s` is missing for embedding provider: %s", - EMBEDDING_AUTHENTICATION_TOKEN_HEADER_NAME, providerId); + EMBEDDING_AUTHENTICATION_TOKEN_HEADER_NAME, modelProvider().apiName()); } } + @Override + protected Duration initialBackOffDuration() { + return initialBackOffDuration; + } + + @Override + protected Duration maxBackOffDuration() { + return maxBackOffDuration; + } + + @Override + protected double jitter() { + return requestProperties().jitter(); + } + + @Override + protected int atMostRetries() { + return requestProperties().atMostRetries(); + } + + @Override + protected boolean decideRetry(Throwable throwable) { + + var retry = + (throwable.getCause() instanceof JsonApiException jae + && jae.getErrorCode() == ErrorCodeV1.EMBEDDING_PROVIDER_TIMEOUT); + + return retry || super.decideRetry(throwable); + } + + /** Maps an HTTP response to a V1 JsonApiException */ + @Override + protected RuntimeException mapHTTPError(Response jakartaResponse, String errorMessage) { + + if (jakartaResponse.getStatus() == Response.Status.REQUEST_TIMEOUT.getStatusCode() + || jakartaResponse.getStatus() == Response.Status.GATEWAY_TIMEOUT.getStatusCode()) { + return ErrorCodeV1.EMBEDDING_PROVIDER_TIMEOUT.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), jakartaResponse.getStatus(), errorMessage); + } + + // Status code == 429 + if 
(jakartaResponse.getStatus() == Response.Status.TOO_MANY_REQUESTS.getStatusCode()) { + return ErrorCodeV1.EMBEDDING_PROVIDER_RATE_LIMITED.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), jakartaResponse.getStatus(), errorMessage); + } + + // Status code in 4XX other than 429 + if (jakartaResponse.getStatusInfo().getFamily() == CLIENT_ERROR) { + return ErrorCodeV1.EMBEDDING_PROVIDER_CLIENT_ERROR.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), jakartaResponse.getStatus(), errorMessage); + } + + // Status code in 5XX + if (jakartaResponse.getStatusInfo().getFamily() == Response.Status.Family.SERVER_ERROR) { + return ErrorCodeV1.EMBEDDING_PROVIDER_SERVER_ERROR.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), jakartaResponse.getStatus(), errorMessage); + } + + // All other errors, Should never happen as all errors are covered above + return ErrorCodeV1.EMBEDDING_PROVIDER_UNEXPECTED_RESPONSE.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), jakartaResponse.getStatus(), errorMessage); + } + + /** Call this from the subclass when the response from the provider is empty */ + protected void throwEmptyData(Response jakartaResponse) { + throw ErrorCodeV1.EMBEDDING_PROVIDER_UNEXPECTED_RESPONSE.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), + jakartaResponse.getStatus(), + "ModelProvider returned empty data for model %s".formatted(modelName())); + } + /** * Record to hold the batchId and embedding vectors * * @param batchId - Sequence number for the batch to order the vectors. * @param embeddings - Embedding vectors for the given text inputs. 
*/ - public record Response(int batchId, List embeddings) implements Recordable { - - public static Response of(int batchId, List embeddings) { - return new Response(batchId, embeddings); - } + public record BatchedEmbeddingResponse( + int batchId, List embeddings, ModelUsage modelUsage) implements Recordable { @Override public DataRecorder recordTo(DataRecorder dataRecorder) { - return dataRecorder.append("batchId", batchId).append("embeddings", embeddings); + return dataRecorder + .append("batchId", batchId) + .append("embeddings", embeddings) + .append("modelUsage", modelUsage); } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderFactory.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderFactory.java index c2b577c236..51ffe68b17 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderFactory.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderFactory.java @@ -4,53 +4,58 @@ import io.stargate.embedding.gateway.EmbeddingService; import io.stargate.sgv2.jsonapi.config.OperationsConfig; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.ServiceConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.gateway.EmbeddingGatewayClient; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import jakarta.enterprise.context.ApplicationScoped; import jakarta.enterprise.inject.Instance; import jakarta.inject.Inject; import java.util.Map; import java.util.Optional; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; @ApplicationScoped public class 
EmbeddingProviderFactory { - @Inject Instance embeddingProviderConfigStore; + private static final Logger LOGGER = LoggerFactory.getLogger(EmbeddingProviderFactory.class); + + // aaron 16 june 2025 - unclear which is in Instance<> left as is for now + @Inject Instance embeddingProviderConfigStore; @Inject EmbeddingProvidersConfig embeddingProvidersConfig; - @Inject OperationsConfig config; + @Inject OperationsConfig operationsConfig; @GrpcClient("embedding") - EmbeddingService embeddingService; + EmbeddingService grpcGatewayClient; + @FunctionalInterface interface ProviderConstructor { EmbeddingProvider create( - EmbeddingProviderConfigStore.RequestProperties requestProperties, - String baseUrl, - EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig model, + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig, + EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig modelConfig, + ServiceConfigStore.ServiceConfig serviceConfig, int dimension, - Map vectorizeServiceParameter, - EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig); + Map vectorizeServiceParameter); } - private static final Map providersMap = - // alphabetic order + // Immutable map, not concurrency concerns. 
+ private static final Map EMBEDDING_PROVIDER_CTORS = Map.ofEntries( - Map.entry(ProviderConstants.AZURE_OPENAI, AzureOpenAIEmbeddingProvider::new), - Map.entry(ProviderConstants.COHERE, CohereEmbeddingProvider::new), - Map.entry(ProviderConstants.HUGGINGFACE, HuggingFaceEmbeddingProvider::new), + Map.entry(ModelProvider.AZURE_OPENAI, AzureOpenAIEmbeddingProvider::new), + Map.entry(ModelProvider.BEDROCK, AwsBedrockEmbeddingProvider::new), + Map.entry(ModelProvider.COHERE, CohereEmbeddingProvider::new), + Map.entry(ModelProvider.HUGGINGFACE, HuggingFaceEmbeddingProvider::new), Map.entry( - ProviderConstants.HUGGINGFACE_DEDICATED, HuggingFaceDedicatedEmbeddingProvider::new), - Map.entry(ProviderConstants.JINA_AI, JinaAIEmbeddingProvider::new), - Map.entry(ProviderConstants.MISTRAL, MistralEmbeddingProvider::new), - Map.entry(ProviderConstants.NVIDIA, NvidiaEmbeddingProvider::new), - Map.entry(ProviderConstants.OPENAI, OpenAIEmbeddingProvider::new), - Map.entry(ProviderConstants.UPSTAGE_AI, UpstageAIEmbeddingProvider::new), - Map.entry(ProviderConstants.VERTEXAI, VertexAIEmbeddingProvider::new), - Map.entry(ProviderConstants.VOYAGE_AI, VoyageAIEmbeddingProvider::new), - Map.entry(ProviderConstants.BEDROCK, AwsBedrockEmbeddingProvider::new)); - - public EmbeddingProvider getConfiguration( + ModelProvider.HUGGINGFACE_DEDICATED, HuggingFaceDedicatedEmbeddingProvider::new), + Map.entry(ModelProvider.JINA_AI, JinaAIEmbeddingProvider::new), + Map.entry(ModelProvider.MISTRAL, MistralEmbeddingProvider::new), + Map.entry(ModelProvider.NVIDIA, NvidiaEmbeddingProvider::new), + Map.entry(ModelProvider.OPENAI, OpenAIEmbeddingProvider::new), + Map.entry(ModelProvider.UPSTAGE_AI, UpstageAIEmbeddingProvider::new), + Map.entry(ModelProvider.VERTEXAI, VertexAIEmbeddingProvider::new), + Map.entry(ModelProvider.VOYAGE_AI, VoyageAIEmbeddingProvider::new)); + + public EmbeddingProvider create( Optional tenant, Optional authToken, String serviceName, @@ -59,13 +64,32 @@ public 
EmbeddingProvider getConfiguration( Map vectorizeServiceParameters, Map authentication, String commandName) { - if (vectorizeServiceParameters == null) { - vectorizeServiceParameters = Map.of(); + + if (LOGGER.isTraceEnabled()) { + LOGGER.trace( + "create() - tenant: {}, serviceName: {}, modelName: {}, dimension: {}, vectorizeServiceParameters: {}, commandName: {}", + tenant, + serviceName, + modelName, + dimension, + vectorizeServiceParameters, + commandName); } - return addService( + + // aaron 7 June 2025, the code previously threw this error when the name from the config was not + // found in the code, but this is a serious error that should not happen, it should be more like + // a IllegalState. + var modelProvider = + ModelProvider.fromApiName(serviceName) + .orElseThrow( + () -> + ErrorCodeV1.VECTORIZE_SERVICE_TYPE_UNAVAILABLE.toApiException( + "unknown service provider '%s'", serviceName)); + + return create( tenant, authToken, - serviceName, + modelProvider, modelName, dimension, vectorizeServiceParameters, @@ -73,41 +97,46 @@ public EmbeddingProvider getConfiguration( commandName); } - private synchronized EmbeddingProvider addService( + public EmbeddingProvider create( Optional tenant, Optional authToken, - String serviceName, + ModelProvider modelProvider, String modelName, int dimension, Map vectorizeServiceParameters, Map authentication, String commandName) { - final EmbeddingProviderConfigStore.ServiceConfig configuration = - embeddingProviderConfigStore.get().getConfiguration(tenant, serviceName); + if (vectorizeServiceParameters == null) { + vectorizeServiceParameters = Map.of(); + } - if (config.enableEmbeddingGateway()) { - return new EmbeddingGatewayClient( - configuration.requestConfiguration(), - configuration.serviceProvider(), - dimension, + if (LOGGER.isTraceEnabled()) { + LOGGER.trace( + "create() - tenant: {}, modelProvider: {}, modelName: {}, dimension: {}, vectorizeServiceParameters: {}, commandName: {}", tenant, - authToken, - 
configuration.getBaseUrl(modelName), + modelProvider, modelName, - embeddingService, + dimension, vectorizeServiceParameters, - authentication, commandName); } - // CUSTOM is for test only - if (configuration.serviceProvider().equals(ProviderConstants.CUSTOM)) { - Optional> clazz = configuration.implementationClass(); - if (!clazz.isPresent()) { + // WARNING: aaron 15 june 2025, Refactored this, it was very messy + // leaving full types here because the names are very, very confusing + + ServiceConfigStore.ServiceConfig serviceConfig = + embeddingProviderConfigStore.get().getConfiguration(modelProvider); + + if (serviceConfig.modelProvider().equals(ModelProvider.CUSTOM)) { + // CUSTOM is for test only, but we cannot really check that here + // checking this and existing because the rest of the function is validating models etc exist. + Optional> clazz = serviceConfig.implementationClass(); + if (clazz.isEmpty()) { throw ErrorCodeV1.VECTORIZE_SERVICE_TYPE_UNAVAILABLE.toApiException( "custom class undefined"); } + try { return (EmbeddingProvider) clazz.get().getConstructor(int.class).newInstance(dimension); } catch (Exception e) { @@ -117,34 +146,45 @@ private synchronized EmbeddingProvider addService( } } - ProviderConstructor ctor = providersMap.get(configuration.serviceProvider()); - if (ctor == null) { - throw ErrorCodeV1.VECTORIZE_SERVICE_TYPE_UNAVAILABLE.toApiException( - "unknown service provider '%s'", configuration.serviceProvider()); - } - - // Get the provider, then get the model. 
- var providerConfig = embeddingProvidersConfig.providers().get(configuration.serviceProvider()); + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig = + embeddingProvidersConfig.providers().get(serviceConfig.modelProvider().apiName()); if (providerConfig == null) { throw ErrorCodeV1.VECTORIZE_SERVICE_TYPE_UNAVAILABLE.toApiException( - "unknown service provider '%s'", configuration.serviceProvider()); + "unknown service provider '%s'", serviceConfig.modelProvider()); } - EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig model = - embeddingProvidersConfig.providers().get(configuration.serviceProvider()).models().stream() + + EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig modelConfig = + providerConfig.models().stream() .filter(m -> m.name().equals(modelName)) .findFirst() .orElseThrow( () -> ErrorCodeV1.VECTORIZE_SERVICE_TYPE_UNAVAILABLE.toApiException( "unknown model '%s' for service provider '%s'", - modelName, configuration.serviceProvider())); + modelName, serviceConfig.modelProvider())); + + if (operationsConfig.enableEmbeddingGateway()) { + return new EmbeddingGatewayClient( + modelProvider, + providerConfig, + modelConfig, + serviceConfig, + dimension, + vectorizeServiceParameters, + tenant, + authToken, + grpcGatewayClient, + authentication, + commandName); + } + + var ctor = EMBEDDING_PROVIDER_CTORS.get(modelProvider); + if (ctor == null) { + throw new IllegalStateException( + "ModelProvider does not have a constructor: " + modelProvider); + } return ctor.create( - configuration.requestConfiguration(), - configuration.getBaseUrl(modelName), - model, - dimension, - vectorizeServiceParameters, - providerConfig); + providerConfig, modelConfig, serviceConfig, dimension, vectorizeServiceParameters); } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceDedicatedEmbeddingProvider.java 
b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceDedicatedEmbeddingProvider.java index 09bf39c00a..1e4c1c053c 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceDedicatedEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceDedicatedEmbeddingProvider.java @@ -1,21 +1,21 @@ package io.stargate.sgv2.jsonapi.service.embedding.operation; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.databind.JsonNode; -import io.quarkus.rest.client.reactive.ClientExceptionMapper; import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; -import io.stargate.sgv2.jsonapi.service.embedding.operation.error.EmbeddingProviderErrorMapper; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.ServiceConfigStore; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.*; import java.util.concurrent.TimeUnit; @@ -24,121 +24,142 @@ import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; public class 
HuggingFaceDedicatedEmbeddingProvider extends EmbeddingProvider { - private static final String providerId = ProviderConstants.HUGGINGFACE_DEDICATED; - private final HuggingFaceDedicatedEmbeddingProviderClient - huggingFaceDedicatedEmbeddingProviderClient; + + public static final String HUGGINGFACE_DEDICATED_ENDPOINT_DEFINED_MODEL = + "endpoint-defined-model"; + + private final HuggingFaceDedicatedEmbeddingProviderClient huggingFaceClient; public HuggingFaceDedicatedEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties requestProperties, - String baseUrl, - EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig model, + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig, + EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig modelConfig, + ServiceConfigStore.ServiceConfig serviceConfig, int dimension, - Map vectorizeServiceParameters, - EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig) { - super(requestProperties, baseUrl, model, dimension, vectorizeServiceParameters, providerConfig); + Map vectorizeServiceParameters) { + super( + ModelProvider.HUGGINGFACE_DEDICATED, + providerConfig, + modelConfig, + serviceConfig, + dimension, + vectorizeServiceParameters); // replace placeholders: endPointName, regionName, cloudName - String dedicatedApiUrl = replaceParameters(baseUrl, vectorizeServiceParameters); - huggingFaceDedicatedEmbeddingProviderClient = + String dedicatedApiUrl = + replaceParameters(serviceConfig.getBaseUrl(modelName()), vectorizeServiceParameters); + huggingFaceClient = QuarkusRestClientBuilder.newBuilder() .baseUri(URI.create(dedicatedApiUrl)) - .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) + .readTimeout(requestProperties().readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(HuggingFaceDedicatedEmbeddingProviderClient.class); } - @RegisterRestClient - @RegisterProvider(EmbeddingProviderResponseValidation.class) - public interface 
HuggingFaceDedicatedEmbeddingProviderClient { - @POST - @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) - Uni embed( - @HeaderParam("Authorization") String accessToken, EmbeddingRequest request); - - @ClientExceptionMapper - static RuntimeException mapException(jakarta.ws.rs.core.Response response) { - String errorMessage = getErrorMessage(response); - return EmbeddingProviderErrorMapper.mapToAPIException(providerId, response, errorMessage); - } - - /** - * Extract the error message from the response body. The example response body is: - * - *

-     * {
-     *   "message": "Batch size error",
-     *   "type": "validation"
-     * }
-     *
-     * {
-     *   "message": "Model is overloaded",
-     *   "type": "overloaded"
-     * }
-     * 
- * - * @param response The response body as a String. - * @return The error message extracted from the response body. - */ - private static String getErrorMessage(jakarta.ws.rs.core.Response response) { - // Get the whole response body - JsonNode rootNode = response.readEntity(JsonNode.class); - // Log the response body - logger.error( - "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); - // Extract the "message" node - JsonNode messageNode = rootNode.path("message"); - // Return the text of the "message" node, or the whole response body if it is missing - return messageNode.isMissingNode() ? rootNode.toString() : messageNode.toString(); - } - } - - // huggingfaceDedicated, Test Embeddings Inference, openAI compatible route - // https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/openai_embed - private record EmbeddingRequest(String[] input) {} - - @JsonIgnoreProperties(ignoreUnknown = true) // ignore possible extra fields without error - private record EmbeddingResponse(String object, Data[] data, String model, Usage usage) { - @JsonIgnoreProperties(ignoreUnknown = true) - private record Data(String object, int index, float[] embedding) {} - - @JsonIgnoreProperties(ignoreUnknown = true) - private record Usage(int prompt_tokens, int total_tokens) {} + /** + * The example response body is: + * + *
+   * {
+   *   "message": "Batch size error",
+   *   "type": "validation"
+   * }
+   *
+   * {
+   *   "message": "Model is overloaded",
+   *   "type": "overloaded"
+   * }
+   */
+  @Override
+  protected String errorMessageJsonPtr() {
+    return "/message";
   }
 
   @Override
-  public Uni vectorize(
+  public Uni vectorize(
       int batchId,
       List texts,
       EmbeddingCredentials embeddingCredentials,
       EmbeddingRequestType embeddingRequestType) {
-    // Check if using an EOF model
+
     checkEOLModelUsage();
-    checkEmbeddingApiKeyHeader(providerId, embeddingCredentials.apiKey());
+    checkEmbeddingApiKeyHeader(embeddingCredentials.apiKey());
 
-    String[] textArray = new String[texts.size()];
-    EmbeddingRequest request = new EmbeddingRequest(texts.toArray(textArray));
+    var huggingFaceRequest =
+        new HuggingFaceDedicatedEmbeddingRequest(texts.toArray(new String[texts.size()]));
 
-    Uni response =
-        applyRetry(
-            huggingFaceDedicatedEmbeddingProviderClient.embed(
-                HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(),
-                request));
+    // TODO: V2 error
+    // aaron 8 June 2025 - old code had NO comment to explain what happens if the API key is empty.
+    var accessToken = HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get();
 
-    return response
+    long callStartNano = System.nanoTime();
+    return retryHTTPCall(huggingFaceClient.embed(accessToken, huggingFaceRequest))
         .onItem()
         .transform(
-            resp -> {
-              if (resp.data() == null) {
-                return Response.of(batchId, Collections.emptyList());
+            jakartaResponse -> {
+              var huggingFaceResponse =
+                  decodeResponse(jakartaResponse, HuggingFaceDedicatedEmbeddingResponse.class);
+              long callDurationNano = System.nanoTime() - callStartNano;
+
+              // aaron - 10 June 2025 - previous code would silently swallow no data returned
+              // and return an empty result. If we made a request we should get a response.
+              if (huggingFaceResponse.data() == null) {
+                throwEmptyData(jakartaResponse);
               }
-              Arrays.sort(resp.data(), (a, b) -> a.index() - b.index());
+
+              Arrays.sort(huggingFaceResponse.data(), (a, b) -> a.index() - b.index());
               List vectors =
-                  Arrays.stream(resp.data()).map(data -> data.embedding()).toList();
-              return Response.of(batchId, vectors);
+                  Arrays.stream(huggingFaceResponse.data())
+                      .map(HuggingFaceDedicatedEmbeddingResponse.Data::embedding)
+                      .toList();
+
+              var modelUsage =
+                  createModelUsage(
+                      embeddingCredentials.tenantId(),
+                      ModelInputType.fromEmbeddingRequestType(embeddingRequestType),
+                      huggingFaceResponse.usage().prompt_tokens(),
+                      huggingFaceResponse.usage().total_tokens(),
+                      jakartaResponse,
+                      callDurationNano);
+              return new BatchedEmbeddingResponse(batchId, vectors, modelUsage);
             });
   }
 
-  @Override
-  public int maxBatchSize() {
-    return requestProperties.maxBatchSize();
+  /**
+   * REST client interface for the HuggingFace Dedicated Embedding Service.
+   *
+   * 

.. + */ + @RegisterRestClient + @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(ProviderHttpInterceptor.class) + public interface HuggingFaceDedicatedEmbeddingProviderClient { + @POST + @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) + Uni embed( + @HeaderParam("Authorization") String accessToken, + HuggingFaceDedicatedEmbeddingRequest request); + } + + /** + * Request structure of the HuggingFace Dedicated REST service. + * + *

huggingfaceDedicated, Text Embeddings Inference, openAI compatible route + * https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/openai_embed + */ + public record HuggingFaceDedicatedEmbeddingRequest(String[] input) {} + + /** + * Response structure of the HuggingFace Dedicated REST service. + * + *

.. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + private record HuggingFaceDedicatedEmbeddingResponse( + String object, Data[] data, String model, Usage usage) { + + @JsonIgnoreProperties(ignoreUnknown = true) + private record Data(String object, int index, float[] embedding) {} + + @JsonIgnoreProperties(ignoreUnknown = true) + private record Usage(int prompt_tokens, int total_tokens) {} } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceEmbeddingProvider.java index 4ce4c07e41..9b234b7ecd 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceEmbeddingProvider.java @@ -1,22 +1,21 @@ package io.stargate.sgv2.jsonapi.service.embedding.operation; -import com.fasterxml.jackson.databind.JsonNode; -import io.quarkus.rest.client.reactive.ClientExceptionMapper; import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; -import io.stargate.sgv2.jsonapi.service.embedding.operation.error.EmbeddingProviderErrorMapper; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.ServiceConfigStore; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import 
io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; -import jakarta.ws.rs.core.HttpHeaders; -import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.core.*; import java.net.URI; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -25,98 +24,132 @@ import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; public class HuggingFaceEmbeddingProvider extends EmbeddingProvider { - private static final String providerId = ProviderConstants.HUGGINGFACE; - private final HuggingFaceEmbeddingProviderClient huggingFaceEmbeddingProviderClient; + + private final HuggingFaceEmbeddingProviderClient huggingFaceClient; public HuggingFaceEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties requestProperties, - String baseUrl, - EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig model, + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig, + EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig modelConfig, + ServiceConfigStore.ServiceConfig serviceConfig, int dimension, - Map vectorizeServiceParameters, - EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig) { - super(requestProperties, baseUrl, model, dimension, vectorizeServiceParameters, providerConfig); - - String actualUrl = replaceParameters(baseUrl, Map.of("modelId", model.name())); + Map vectorizeServiceParameters) { + super( + ModelProvider.HUGGINGFACE, + providerConfig, + modelConfig, + serviceConfig, + dimension, + vectorizeServiceParameters); - huggingFaceEmbeddingProviderClient = + huggingFaceClient = QuarkusRestClientBuilder.newBuilder() - .baseUri(URI.create(actualUrl)) - .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) + .baseUri(URI.create(serviceConfig.getBaseUrl(modelName()))) + 
.readTimeout(requestProperties().readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(HuggingFaceEmbeddingProviderClient.class); } - @RegisterRestClient - @RegisterProvider(EmbeddingProviderResponseValidation.class) - public interface HuggingFaceEmbeddingProviderClient { - @POST - @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) - Uni> embed( - @HeaderParam("Authorization") String accessToken, EmbeddingRequest request); - - @ClientExceptionMapper - static RuntimeException mapException(jakarta.ws.rs.core.Response response) { - String errorMessage = getErrorMessage(response); - return EmbeddingProviderErrorMapper.mapToAPIException(providerId, response, errorMessage); - } - - /** - * Extracts the error message from the response body. The example response body is: - * - *

-     * {
-     *   "error": "Authorization header is correct, but the token seems invalid"
-     * }
-     * 
- * - * @param response The response body as a String. - * @return The error message extracted from the response body, or null if the message is not - * found. - */ - private static String getErrorMessage(jakarta.ws.rs.core.Response response) { - // Get the whole response body - JsonNode rootNode = response.readEntity(JsonNode.class); - // Log the response body - logger.error( - "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); - // Extract the "error" node - JsonNode errorNode = rootNode.path("error"); - // Return the text of the "message" node, or the whole response body if it is missing - return errorNode.isMissingNode() ? rootNode.toString() : errorNode.toString(); - } - } - - private record EmbeddingRequest(List inputs, Options options) { - public record Options(boolean waitForModel) {} + /** + * The example response body is: + * + *
+   * {
+   *   "error": "Authorization header is correct, but the token seems invalid"
+   * }
+   * 
+ */ + @Override + protected String errorMessageJsonPtr() { + return "/error"; } @Override - public Uni vectorize( + public Uni vectorize( int batchId, List texts, EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { - // Check if using an EOF model + checkEOLModelUsage(); - checkEmbeddingApiKeyHeader(providerId, embeddingCredentials.apiKey()); - EmbeddingRequest request = new EmbeddingRequest(texts, new EmbeddingRequest.Options(true)); + checkEmbeddingApiKeyHeader(embeddingCredentials.apiKey()); + var huggingFaceRequest = + new HuggingFaceEmbeddingRequest(texts, new HuggingFaceEmbeddingRequest.Options(true)); - return applyRetry( - huggingFaceEmbeddingProviderClient.embed( - HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(), - request)) + // TODO: V2 error + // aaron 8 June 2025 - old code had NO comment to explain what happens if the API key is empty. + var accessToken = HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(); + + long callStartNano = System.nanoTime(); + return retryHTTPCall(huggingFaceClient.embed(accessToken, modelName(), huggingFaceRequest)) .onItem() .transform( - resp -> { - if (resp == null) { - return Response.of(batchId, Collections.emptyList()); - } - return Response.of(batchId, resp); + jakartaResponse -> { + + // NOTE: Boxing happening here, as the response is a JSON array of arrays of floats. + // should return zero length list if entity is null or empty. + // TODO: how to deserialise without boxing ?
+ List vectorsBoxed = jakartaResponse.readEntity(new GenericType<>() {}); + long callDurationNano = System.nanoTime() - callStartNano; + + List vectorsUnboxed = + vectorsBoxed.stream() + .map( + vector -> { + if (vector == null) { + return new float[0]; // Handle null vectors + } + float[] unboxed = new float[vector.length]; + for (int i = 0; i < vector.length; i++) { + unboxed[i] = vector[i]; + } + return unboxed; + }) + .toList(); + + // The hugging face API we are calling does not return usage information, there may be + // a + // newer version of the API that does, but for now we will not return usage + // information. + // https://huggingface.co/blog/getting-started-with-embeddings + var modelUsage = + createModelUsage( + embeddingCredentials.tenantId(), + ModelInputType.fromEmbeddingRequestType(embeddingRequestType), + 0, + 0, + jakartaResponse, + callDurationNano); + return new BatchedEmbeddingResponse(batchId, vectorsUnboxed, modelUsage); }); } - @Override - public int maxBatchSize() { - return requestProperties.maxBatchSize(); + /** + * REST client interface for the HuggingFace Embedding Service. + * + *

.. NOTE: the response is just a JSON array of arrays of floats, e.g.: + * + *

+   *   [[-0.123, 0.456, ...], [-0.789, 0.012, ...], ...]
+   * 
+ */ + @RegisterRestClient + @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(ProviderHttpInterceptor.class) + public interface HuggingFaceEmbeddingProviderClient { + @POST + @Path("/{modelId}") + @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) + Uni embed( + @HeaderParam("Authorization") String accessToken, + @PathParam("modelId") String modelId, + HuggingFaceEmbeddingRequest request); + } + + /** + * Request structure of the HuggingFace REST service. + * + *

.. + */ + public record HuggingFaceEmbeddingRequest(List inputs, Options options) { + public record Options(boolean waitForModel) {} } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/JinaAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/JinaAIEmbeddingProvider.java index 15d062d1a5..815e34b976 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/JinaAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/JinaAIEmbeddingProvider.java @@ -2,21 +2,21 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.databind.JsonNode; -import io.quarkus.rest.client.reactive.ClientExceptionMapper; import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; -import io.stargate.sgv2.jsonapi.service.embedding.operation.error.EmbeddingProviderErrorMapper; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.ServiceConfigStore; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; import 
java.net.URI; import java.util.*; import java.util.concurrent.TimeUnit; @@ -30,130 +30,141 @@ * called. */ public class JinaAIEmbeddingProvider extends EmbeddingProvider { - private static final String providerId = ProviderConstants.JINA_AI; - private final JinaAIEmbeddingProviderClient jinaAIEmbeddingProviderClient; + + private final JinaAIEmbeddingProviderClient jinaClient; public JinaAIEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties requestProperties, - String baseUrl, - EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig model, + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig, + EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig modelConfig, + ServiceConfigStore.ServiceConfig serviceConfig, int dimension, - Map vectorizeServiceParameters, - EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig) { + Map vectorizeServiceParameters) { super( - requestProperties, - baseUrl, - model, - acceptsJinaAIDimensions(model.name()) ? dimension : 0, - vectorizeServiceParameters, - providerConfig); - - jinaAIEmbeddingProviderClient = + ModelProvider.JINA_AI, + providerConfig, + modelConfig, + serviceConfig, + acceptsJinaAIDimensions(modelConfig.name()) ? 
dimension : 0, + vectorizeServiceParameters); + + jinaClient = QuarkusRestClientBuilder.newBuilder() - .baseUri(URI.create(baseUrl)) - .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) + .baseUri(URI.create(serviceConfig.getBaseUrl(modelName()))) + .readTimeout(requestProperties().readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(JinaAIEmbeddingProviderClient.class); } - @RegisterRestClient - @RegisterProvider(EmbeddingProviderResponseValidation.class) - public interface JinaAIEmbeddingProviderClient { - @POST - @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) - Uni embed( - @HeaderParam("Authorization") String accessToken, EmbeddingRequest request); - - @ClientExceptionMapper - static RuntimeException mapException(jakarta.ws.rs.core.Response response) { - String errorMessage = getErrorMessage(response); - return EmbeddingProviderErrorMapper.mapToAPIException(providerId, response, errorMessage); - } - - /** - * Extract the error message from the response body. The example response body is: - * - *

-     * {
-     *    "detail": "ValidationError(model='TextDoc', errors=[{'loc': ('text',), 'msg': 'Single text cannot exceed 8192 tokens. 10454 tokens given.', 'type': 'value_error'}])"
-     * }
-     * 
- * - *
-     *     {"detail":"Failed to authenticate with the provided api key."}
-     * 
- * - * @param response The response body as a String. - * @return The error message extracted from the response body. - */ - private static String getErrorMessage(jakarta.ws.rs.core.Response response) { - // Get the whole response body - JsonNode rootNode = response.readEntity(JsonNode.class); - // Log the response body - logger.error( - "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); - // Extract the "detail" node - JsonNode detailNode = rootNode.path("detail"); - return detailNode.isMissingNode() ? rootNode.toString() : detailNode.toString(); - } - } - - // By default, Jina Text Encoding Format is float - private record EmbeddingRequest( - List input, - String model, - @JsonInclude(value = JsonInclude.Include.NON_DEFAULT) int dimensions, - @JsonInclude(value = JsonInclude.Include.NON_NULL) String task, - @JsonInclude(value = JsonInclude.Include.NON_NULL) Boolean late_chunking) {} - - @JsonIgnoreProperties(ignoreUnknown = true) // ignore possible extra fields without error - private record EmbeddingResponse(String object, Data[] data, String model, Usage usage) { - @JsonIgnoreProperties(ignoreUnknown = true) - private record Data(String object, int index, float[] embedding) {} - - @JsonIgnoreProperties(ignoreUnknown = true) - private record Usage(int prompt_tokens, int total_tokens) {} + /** + * Extract the error message from the response body. The example response body is: + * + *
+   * {
+   *    "detail": "ValidationError(model='TextDoc', errors=[{'loc': ('text',), 'msg': 'Single text cannot exceed 8192 tokens. 10454 tokens given.', 'type': 'value_error'}])"
+   * }
+   * 
+ * + *
+   *     {"detail":"Failed to authenticate with the provided api key."}
+   * 
+ */ + @Override + protected String errorMessageJsonPtr() { + return "/detail"; } @Override - public Uni vectorize( + public Uni vectorize( int batchId, List texts, EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { - // Check if using an EOF model + checkEOLModelUsage(); - checkEmbeddingApiKeyHeader(providerId, embeddingCredentials.apiKey()); + checkEmbeddingApiKeyHeader(embeddingCredentials.apiKey()); - EmbeddingRequest request = - new EmbeddingRequest( + var jinaRequest = + new JinaEmbeddingRequest( texts, - model.name(), + modelName(), dimension, (String) vectorizeServiceParameters.get("task"), (Boolean) vectorizeServiceParameters.get("late_chunking")); - Uni response = - applyRetry( - jinaAIEmbeddingProviderClient.embed( - HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(), - request)); + // TODO: V2 error + // aaron 8 June 2025 - old code had NO comment to explain what happens if the API key is empty. + var accessToken = HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(); - return response + long callStartNano = System.nanoTime(); + return retryHTTPCall(jinaClient.embed(accessToken, jinaRequest)) .onItem() .transform( - resp -> { - if (resp.data() == null) { - return Response.of(batchId, Collections.emptyList()); + jakartaResponse -> { + var jinaResponse = decodeResponse(jakartaResponse, JinaEmbeddingResponse.class); + long callDurationNano = System.nanoTime() - callStartNano; + + // aaron - 10 June 2025 - previous code would silently swallow no data returned + // and return an empty result. If we made a request we should get a response. 
+ if (jinaResponse.data() == null) { + throwEmptyData(jakartaResponse); } - Arrays.sort(resp.data(), (a, b) -> a.index() - b.index()); + + Arrays.sort(jinaResponse.data(), (a, b) -> a.index() - b.index()); List vectors = - Arrays.stream(resp.data()).map(EmbeddingResponse.Data::embedding).toList(); - return Response.of(batchId, vectors); + Arrays.stream(jinaResponse.data()) + .map(JinaEmbeddingResponse.Data::embedding) + .toList(); + + var modelUsage = + createModelUsage( + embeddingCredentials.tenantId(), + ModelInputType.fromEmbeddingRequestType(embeddingRequestType), + jinaResponse.usage().prompt_tokens(), + jinaResponse.usage().total_tokens(), + jakartaResponse, + callDurationNano); + return new BatchedEmbeddingResponse(batchId, vectors, modelUsage); }); } - @Override - public int maxBatchSize() { - return requestProperties.maxBatchSize(); + /** + * REST client interface for the Jina Embedding Service. + * + *

.. + */ + @RegisterRestClient + @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(ProviderHttpInterceptor.class) + public interface JinaAIEmbeddingProviderClient { + @POST + @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) + Uni embed( + @HeaderParam("Authorization") String accessToken, JinaEmbeddingRequest request); + } + + /** + * Request structure of the Jina REST service. + * + *

By default, Jina Text Encoding Format is float + */ + public record JinaEmbeddingRequest( + List input, + String model, + @JsonInclude(value = JsonInclude.Include.NON_DEFAULT) int dimensions, + @JsonInclude(value = JsonInclude.Include.NON_NULL) String task, + @JsonInclude(value = JsonInclude.Include.NON_NULL) Boolean late_chunking) {} + + /** + * Response structure of the Jina REST service. + * + *

.. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public record JinaEmbeddingResponse(String object, Data[] data, String model, Usage usage) { + + @JsonIgnoreProperties(ignoreUnknown = true) + public record Data(String object, int index, float[] embedding) {} + + @JsonIgnoreProperties(ignoreUnknown = true) + public record Usage(int prompt_tokens, int total_tokens) {} } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MeteredEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MeteredEmbeddingProvider.java index 5851d5ba85..20afc92e13 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MeteredEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MeteredEmbeddingProvider.java @@ -11,10 +11,15 @@ import io.stargate.sgv2.jsonapi.api.request.RequestContext; import io.stargate.sgv2.jsonapi.api.v1.metrics.JsonApiMetricsConfig; import io.stargate.sgv2.jsonapi.metrics.MetricsConstants; +import io.stargate.sgv2.jsonapi.service.provider.ModelUsage; +import io.stargate.sgv2.jsonapi.util.recordable.PrettyPrintable; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Objects; import org.apache.commons.lang3.tuple.Pair; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Provides a metered version of an {@link EmbeddingProvider}, adding metrics collection to the @@ -23,6 +28,10 @@ * input texts. 
*/ public class MeteredEmbeddingProvider extends EmbeddingProvider { + private static final Logger LOGGER = LoggerFactory.getLogger(MeteredEmbeddingProvider.class); + + private static final String UNKNOWN_TENANT_ID = "unknown"; + private final MeterRegistry meterRegistry; private final JsonApiMetricsConfig jsonApiMetricsConfig; private final RequestContext requestContext; @@ -35,6 +44,16 @@ public MeteredEmbeddingProvider( RequestContext requestContext, EmbeddingProvider embeddingProvider, String commandName) { + // aaron 9 June 2025 - we need to remove this "metered" design pattern, for now just pass the + // config through + super( + embeddingProvider.modelProvider(), + embeddingProvider.providerConfig, + embeddingProvider.modelConfig, + embeddingProvider.serviceConfig, + embeddingProvider.dimension, + embeddingProvider.vectorizeServiceParameters); + this.meterRegistry = meterRegistry; this.jsonApiMetricsConfig = jsonApiMetricsConfig; this.requestContext = requestContext; @@ -42,6 +61,12 @@ public MeteredEmbeddingProvider( this.commandName = commandName; } + @Override + protected String errorMessageJsonPtr() { + // not used we are just passing through + return ""; + } + /** * Vectorizes a list of texts, adding metrics collection for the duration of the vectorization * call and the size of the input texts. @@ -52,11 +77,16 @@ public MeteredEmbeddingProvider( * @return a {@link Uni} that will provide the list of vectorized texts, as arrays of floats. 
*/ @Override - public Uni vectorize( + public Uni vectorize( int batchId, List texts, EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { + + Objects.requireNonNull(texts, "texts must not be null"); + Objects.requireNonNull(embeddingCredentials, "embeddingCredentials must not be null"); + Objects.requireNonNull(embeddingRequestType, "embeddingRequestType type must not be null"); + // String bytes metrics for vectorize DistributionSummary ds = DistributionSummary.builder(jsonApiMetricsConfig.vectorizeInputBytesMetrics()) @@ -94,11 +124,24 @@ public Uni vectorize( Collections.sort( vectorizedBatches, (a, b) -> Integer.compare(a.batchId(), b.batchId())); List result = new ArrayList<>(); - for (Response vectorizedBatch : vectorizedBatches) { + + ModelUsage aggregatedModelUsage = null; + for (BatchedEmbeddingResponse vectorizedBatch : vectorizedBatches) { + + aggregatedModelUsage = + aggregatedModelUsage == null + ? vectorizedBatch.modelUsage() + : aggregatedModelUsage.merge(vectorizedBatch.modelUsage()); // create the final ordered result result.addAll(vectorizedBatch.embeddings()); } - return Response.of(1, result); + var embeddingResponse = new BatchedEmbeddingResponse(1, result, aggregatedModelUsage); + if (LOGGER.isTraceEnabled()) { + LOGGER.trace( + "Vectorize call completed, aggregatedModelUsage: {}", + PrettyPrintable.print(aggregatedModelUsage)); + } + return embeddingResponse; }) .invoke( () -> diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MistralEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MistralEmbeddingProvider.java index b0962c7f8f..1100ed3c27 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MistralEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MistralEmbeddingProvider.java @@ -1,21 +1,21 @@ package io.stargate.sgv2.jsonapi.service.embedding.operation; import 
com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.databind.JsonNode; -import io.quarkus.rest.client.reactive.ClientExceptionMapper; import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; -import io.stargate.sgv2.jsonapi.service.embedding.operation.error.EmbeddingProviderErrorMapper; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.ServiceConfigStore; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.*; import java.util.concurrent.TimeUnit; @@ -29,122 +29,137 @@ * REST API being called. 
*/ public class MistralEmbeddingProvider extends EmbeddingProvider { - private static final String providerId = ProviderConstants.MISTRAL; - private final MistralEmbeddingProviderClient mistralEmbeddingProviderClient; + + private final MistralEmbeddingProviderClient mistralClient; public MistralEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties requestProperties, - String baseUrl, - EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig model, + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig, + EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig modelConfig, + ServiceConfigStore.ServiceConfig serviceConfig, int dimension, - Map vectorizeServiceParameters, - EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig) { - super(requestProperties, baseUrl, model, dimension, vectorizeServiceParameters, providerConfig); + Map vectorizeServiceParameters) { + super( + ModelProvider.MISTRAL, + providerConfig, + modelConfig, + serviceConfig, + dimension, + vectorizeServiceParameters); - mistralEmbeddingProviderClient = + mistralClient = QuarkusRestClientBuilder.newBuilder() - .baseUri(URI.create(baseUrl)) - .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) + .baseUri(URI.create(serviceConfig.getBaseUrl(modelName()))) + .readTimeout(requestProperties().readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(MistralEmbeddingProviderClient.class); } - @RegisterRestClient - @RegisterProvider(EmbeddingProviderResponseValidation.class) - public interface MistralEmbeddingProviderClient { - @POST - @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) - Uni embed( - @HeaderParam("Authorization") String accessToken, EmbeddingRequest request); - - @ClientExceptionMapper - static RuntimeException mapException(jakarta.ws.rs.core.Response response) { - String errorMessage = getErrorMessage(response); - return EmbeddingProviderErrorMapper.mapToAPIException(providerId, response, 
errorMessage); - } - - /** - * Extracts the error message from the response body. The example response body is: - * - *

-     * {
-     *   "message":"Unauthorized",
-     *   "request_id":"1383ed1b472cb85fdfaa9624515d2d0e"
-     * }
-     *
-     * {
-     *   "object":"error",
-     *   "message":"Input is too long. Max length is 8192 got 10970",
-     *   "type":"invalid_request_error",
-     *   "param":null,
-     *   "code":null
-     * }
-     * 
- * - * @param response The response body as a String. - * @return The error message extracted from the response body, or null if the message is not - * found. - */ - private static String getErrorMessage(jakarta.ws.rs.core.Response response) { - // Get the whole response body - JsonNode rootNode = response.readEntity(JsonNode.class); - // Log the response body - logger.info( - String.format( - "Error response from embedding provider '%s': %s", providerId, rootNode.toString())); - // Extract the "message" node from the root node - JsonNode messageNode = rootNode.path("message"); - // Return the text of the "message" node, or the whole response body if it is missing - return messageNode.isMissingNode() ? rootNode.toString() : messageNode.toString(); - } - } - - private record EmbeddingRequest(List input, String model, String encoding_format) {} - - @JsonIgnoreProperties(ignoreUnknown = true) // ignore possible extra fields without error - private record EmbeddingResponse( - String id, String object, Data[] data, String model, Usage usage) { - @JsonIgnoreProperties(ignoreUnknown = true) - private record Data(String object, int index, float[] embedding) {} - - @JsonIgnoreProperties(ignoreUnknown = true) - private record Usage( - int prompt_tokens, int total_tokens, int completion_tokens, int request_count) {} + /** + * The example response body is: + * + *
+   * {
+   *   "message":"Unauthorized",
+   *   "request_id":"1383ed1b472cb85fdfaa9624515d2d0e"
+   * }
+   *
+   * {
+   *   "object":"error",
+   *   "message":"Input is too long. Max length is 8192 got 10970",
+   *   "type":"invalid_request_error",
+   *   "param":null,
+   *   "code":null
+   * }
+   * 
+ */ + @Override + protected String errorMessageJsonPtr() { + return "/message"; } @Override - public Uni vectorize( + public Uni vectorize( int batchId, List texts, EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { - // Check if using an EOF model + checkEOLModelUsage(); - checkEmbeddingApiKeyHeader(providerId, embeddingCredentials.apiKey()); + checkEmbeddingApiKeyHeader(embeddingCredentials.apiKey()); - EmbeddingRequest request = new EmbeddingRequest(texts, model.name(), "float"); + var mistralRequest = new MistralEmbeddingRequest(texts, modelName(), "float"); + // TODO: V2 error + // aaron 8 June 2025 - old code had NO comment to explain what happens if the API key is empty. + var accessToken = HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(); - Uni response = - applyRetry( - mistralEmbeddingProviderClient.embed( - HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(), - request)); + long callStartNano = System.nanoTime(); - return response + return retryHTTPCall(mistralClient.embed(accessToken, mistralRequest)) .onItem() .transform( - resp -> { - if (resp.data() == null) { - return Response.of(batchId, Collections.emptyList()); + jakartaResponse -> { + var mistralResponse = decodeResponse(jakartaResponse, MistralEmbeddingResponse.class); + long callDurationNano = System.nanoTime() - callStartNano; + + // aaron - 10 June 2025 - previous code would silently swallow no data returned + // and return an empty result. If we made a request we should get a response. 
+ if (mistralResponse.data() == null) { + throwEmptyData(jakartaResponse); } - Arrays.sort(resp.data(), (a, b) -> a.index() - b.index()); + + Arrays.sort(mistralResponse.data(), (a, b) -> a.index() - b.index()); List vectors = - Arrays.stream(resp.data()).map(data -> data.embedding()).toList(); - return Response.of(batchId, vectors); + Arrays.stream(mistralResponse.data()) + .map(MistralEmbeddingResponse.Data::embedding) + .toList(); + + var modelUsage = + createModelUsage( + embeddingCredentials.tenantId(), + ModelInputType.fromEmbeddingRequestType(embeddingRequestType), + mistralResponse.usage().prompt_tokens(), + mistralResponse.usage().total_tokens(), + jakartaResponse, + callDurationNano); + return new BatchedEmbeddingResponse(batchId, vectors, modelUsage); }); } - @Override - public int maxBatchSize() { - return requestProperties.maxBatchSize(); + /** + * REST client interface for the Mistral Embedding Service. + * + *

.. + */ + @RegisterRestClient + @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(ProviderHttpInterceptor.class) + public interface MistralEmbeddingProviderClient { + @POST + @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) + Uni embed( + @HeaderParam("Authorization") String accessToken, MistralEmbeddingRequest request); + } + + /** + * Request structure of the Mistral REST service. + * + *

.. + */ + public record MistralEmbeddingRequest(List input, String model, String encoding_format) {} + + /** + * Response structure of the Mistral REST service. + * + *

.. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + private record MistralEmbeddingResponse( + String id, String object, Data[] data, String model, Usage usage) { + + @JsonIgnoreProperties(ignoreUnknown = true) + private record Data(String object, int index, float[] embedding) {} + + @JsonIgnoreProperties(ignoreUnknown = true) + private record Usage( + int prompt_tokens, int total_tokens, int completion_tokens, int request_count) {} } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NvidiaEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NvidiaEmbeddingProvider.java index a7d362a213..3ab41810cd 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NvidiaEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/NvidiaEmbeddingProvider.java @@ -1,24 +1,23 @@ package io.stargate.sgv2.jsonapi.service.embedding.operation; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.databind.JsonNode; -import io.quarkus.rest.client.reactive.ClientExceptionMapper; import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; -import io.stargate.sgv2.jsonapi.service.embedding.operation.error.EmbeddingProviderErrorMapper; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.ServiceConfigStore; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import 
io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -31,116 +30,128 @@ * of chosen Nvidia model. */ public class NvidiaEmbeddingProvider extends EmbeddingProvider { - private static final String providerId = ProviderConstants.NVIDIA; - private final NvidiaEmbeddingProviderClient nvidiaEmbeddingProviderClient; + + private final NvidiaEmbeddingProviderClient nvidiaClient; public NvidiaEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties requestProperties, - String baseUrl, - EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig model, + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig, + EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig modelConfig, + ServiceConfigStore.ServiceConfig serviceConfig, int dimension, - Map vectorizeServiceParameters, - EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig) { - super(requestProperties, baseUrl, model, dimension, vectorizeServiceParameters, providerConfig); - - nvidiaEmbeddingProviderClient = + Map vectorizeServiceParameters) { + super( + ModelProvider.NVIDIA, + providerConfig, + modelConfig, + serviceConfig, + dimension, + vectorizeServiceParameters); + + nvidiaClient = QuarkusRestClientBuilder.newBuilder() - .baseUri(URI.create(baseUrl)) - .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) + .baseUri(URI.create(serviceConfig.getBaseUrl(modelName()))) + .readTimeout(requestProperties().readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(NvidiaEmbeddingProviderClient.class); } - @RegisterRestClient - 
@RegisterProvider(EmbeddingProviderResponseValidation.class) - public interface NvidiaEmbeddingProviderClient { - @POST - @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) - Uni embed( - @HeaderParam("Authorization") String accessToken, EmbeddingRequest request); - - @ClientExceptionMapper - static RuntimeException mapException(jakarta.ws.rs.core.Response response) { - String errorMessage = getErrorMessage(response); - return EmbeddingProviderErrorMapper.mapToAPIException(providerId, response, errorMessage); - } - - /** - * Extract the error message from the response body. The example response body is: - * - *

-     * {
-     *   "object": "error",
-     *   "message": "Input length exceeds the maximum token length of the model",
-     *   "detail": {},
-     *   "type": "invalid_request_error"
-     * }
-     * 
- * - * @param response The response body as a String. - * @return The error message extracted from the response body. - */ - private static String getErrorMessage(jakarta.ws.rs.core.Response response) { - // Get the whole response body - JsonNode rootNode = response.readEntity(JsonNode.class); - // Log the response body - logger.error( - "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); - JsonNode messageNode = rootNode.path("message"); - // Return the text of the "message" node, or the whole response body if it is missing - return messageNode.isMissingNode() ? rootNode.toString() : messageNode.toString(); - } - } - - private record EmbeddingRequest(String[] input, String model, String input_type) {} - - @JsonIgnoreProperties(ignoreUnknown = true) // ignore possible extra fields without error - private record EmbeddingResponse(Data[] data, String model, Usage usage) { - @JsonIgnoreProperties(ignoreUnknown = true) - private record Data(int index, float[] embedding) {} - - @JsonIgnoreProperties(ignoreUnknown = true) - private record Usage(int prompt_tokens, int total_tokens) {} + /** + * The example response body is: + * + *
+   * {
+   *   "object": "error",
+   *   "message": "Input length exceeds the maximum token length of the model",
+   *   "detail": {},
+   *   "type": "invalid_request_error"
+   * }
+   * 
+ */ + @Override + protected String errorMessageJsonPtr() { + return "/message"; } - private static final String PASSAGE = "passage"; - private static final String QUERY = "query"; - @Override - public Uni vectorize( + public Uni vectorize( int batchId, List texts, EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { - // Check if using an EOF model - checkEOLModelUsage(); - String[] textArray = new String[texts.size()]; - String input_type = embeddingRequestType == EmbeddingRequestType.INDEX ? PASSAGE : QUERY; - EmbeddingRequest request = - new EmbeddingRequest(texts.toArray(textArray), model.name(), input_type); + checkEOLModelUsage(); + var input_type = embeddingRequestType == EmbeddingRequestType.INDEX ? "passage" : "query"; + var nvidiaRequest = + new NvidiaEmbeddingRequest( + texts.toArray(new String[texts.size()]), modelName(), input_type); - Uni response = - applyRetry( - nvidiaEmbeddingProviderClient.embed( - HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().orElse(""), - request)); + // TODO: XXX No token to pass with the nvidia request for now. This will change on main merge + var accessToken = HttpConstants.BEARER_PREFIX_FOR_API_KEY; - return response + long callStartNano = System.nanoTime(); + return retryHTTPCall(nvidiaClient.embed(accessToken, nvidiaRequest)) .onItem() .transform( - resp -> { - if (resp.data() == null) { - return Response.of(batchId, Collections.emptyList()); + jakartaResponse -> { + var nvidiaResponse = decodeResponse(jakartaResponse, NvidiaEmbeddingResponse.class); + long callDurationNano = System.nanoTime() - callStartNano; + + // aaron - 10 June 2025 - previous code would silently swallow no data returned + // and return an empty result. If we made a request we should get a response. 
+ if (nvidiaResponse.data() == null) { + throwEmptyData(jakartaResponse); } - Arrays.sort(resp.data(), (a, b) -> a.index() - b.index()); + + Arrays.sort(nvidiaResponse.data(), (a, b) -> a.index() - b.index()); List vectors = - Arrays.stream(resp.data()).map(data -> data.embedding()).toList(); - return Response.of(batchId, vectors); + Arrays.stream(nvidiaResponse.data()) + .map(NvidiaEmbeddingResponse.Data::embedding) + .toList(); + + var modelUsage = + createModelUsage( + embeddingCredentials.tenantId(), + ModelInputType.fromEmbeddingRequestType(embeddingRequestType), + nvidiaResponse.usage().prompt_tokens(), + nvidiaResponse.usage().total_tokens(), + jakartaResponse, + callDurationNano); + return new BatchedEmbeddingResponse(batchId, vectors, modelUsage); }); } - @Override - public int maxBatchSize() { - return requestProperties.maxBatchSize(); + /** + * REST client interface for the NVidia Embedding Service. + * + *

.. + */ + @RegisterRestClient + @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(ProviderHttpInterceptor.class) + public interface NvidiaEmbeddingProviderClient { + @POST + @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) + Uni embed( + @HeaderParam("Authorization") String accessToken, NvidiaEmbeddingRequest request); + } + + /** + * Request structure of the Nvidia REST service. + * + *

.. + */ + public record NvidiaEmbeddingRequest(String[] input, String model, String input_type) {} + + /** + * Response structure of the Nvidia REST service. + * + *

.. + */ + @JsonIgnoreProperties(ignoreUnknown = true) // ignore possible extra fields without error + private record NvidiaEmbeddingResponse(Data[] data, String model, Usage usage) { + @JsonIgnoreProperties(ignoreUnknown = true) + private record Data(int index, float[] embedding) {} + + @JsonIgnoreProperties(ignoreUnknown = true) + private record Usage(int prompt_tokens, int total_tokens) {} } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAIEmbeddingProvider.java index 1d271988dd..3aba5d0b65 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAIEmbeddingProvider.java @@ -2,25 +2,24 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.databind.JsonNode; -import io.quarkus.rest.client.reactive.ClientExceptionMapper; import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; -import io.stargate.sgv2.jsonapi.service.embedding.operation.error.EmbeddingProviderErrorMapper; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.ServiceConfigStore; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; 
+import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; import jakarta.ws.rs.Path; import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -29,135 +28,143 @@ import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; public class OpenAIEmbeddingProvider extends EmbeddingProvider { - private static final String providerId = ProviderConstants.OPENAI; - private final OpenAIEmbeddingProviderClient openAIEmbeddingProviderClient; + + private final OpenAIEmbeddingProviderClient openAIClient; public OpenAIEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties requestProperties, - String baseUrl, - EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig model, + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig, + EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig modelConfig, + ServiceConfigStore.ServiceConfig serviceConfig, int dimension, - Map vectorizeServiceParameters, - EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig) { + Map vectorizeServiceParameters) { // One special case: legacy "ada-002" model does not accept "dimension" parameter super( - requestProperties, - baseUrl, - model, - acceptsOpenAIDimensions(model.name()) ? dimension : 0, - vectorizeServiceParameters, - providerConfig); - - openAIEmbeddingProviderClient = + ModelProvider.OPENAI, + providerConfig, + modelConfig, + serviceConfig, + acceptsOpenAIDimensions(modelConfig.name()) ? 
dimension : 0, + vectorizeServiceParameters); + + openAIClient = QuarkusRestClientBuilder.newBuilder() - .baseUri(URI.create(baseUrl)) - .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) + .baseUri(URI.create(serviceConfig.getBaseUrl(modelName()))) + .readTimeout(requestProperties().readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(OpenAIEmbeddingProviderClient.class); } + /** + * The example response body is: + * + *

+   * {
+   *   "error": {
+   *     "message": "You exceeded your current quota, please check your plan and billing details. For
+   *                 more information on this error, read the docs:
+   *                 https://platform.openai.com/docs/guides/error-codes/api-errors.",
+   *     "type": "insufficient_quota",
+   *     "param": null,
+   *     "code": "insufficient_quota"
+   *   }
+   * }
+   * 
+ */ + @Override + protected String errorMessageJsonPtr() { + return "/error/message"; + } + + @Override + public Uni vectorize( + int batchId, + List texts, + EmbeddingCredentials embeddingCredentials, + EmbeddingRequestType embeddingRequestType) { + + checkEOLModelUsage(); + checkEmbeddingApiKeyHeader(embeddingCredentials.apiKey()); + + var openAiRequest = + new OpenAiEmbeddingRequest(texts.toArray(new String[texts.size()]), modelName(), dimension); + var organizationId = (String) vectorizeServiceParameters.get("organizationId"); + var projectId = (String) vectorizeServiceParameters.get("projectId"); + + // TODO: V2 error + // aaron 8 June 2025 - old code had NO comment to explain what happens if the API key is empty. + var accessToken = HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(); + + final long callStartNano = System.nanoTime(); + return retryHTTPCall(openAIClient.embed(accessToken, organizationId, projectId, openAiRequest)) + .onItem() + .transform( + jakartaResponse -> { + var openAiResponse = decodeResponse(jakartaResponse, OpenAiEmbeddingResponse.class); + long callDurationNano = System.nanoTime() - callStartNano; + + // aaron - 10 June 2025 - previous code would silently swallow no data returned + // and return an empty result. If we made a request we should get a response. 
+ if (openAiResponse.data() == null) { + throwEmptyData(jakartaResponse); + } + Arrays.sort(openAiResponse.data(), (a, b) -> a.index() - b.index()); + List vectors = + Arrays.stream(openAiResponse.data()) + .map(OpenAiEmbeddingResponse.Data::embedding) + .toList(); + + var modelUsage = + createModelUsage( + embeddingCredentials.tenantId(), + ModelInputType.fromEmbeddingRequestType(embeddingRequestType), + openAiResponse.usage().prompt_tokens(), + openAiResponse.usage().total_tokens(), + jakartaResponse, + callDurationNano); + + return new BatchedEmbeddingResponse(batchId, vectors, modelUsage); + }); + } + + /** + * REST client interface for the OpenAI Embedding Service. + * + *

.. + */ @RegisterRestClient @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(ProviderHttpInterceptor.class) public interface OpenAIEmbeddingProviderClient { @POST @Path("/embeddings") @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) - Uni embed( + Uni embed( @HeaderParam("Authorization") String accessToken, @HeaderParam("OpenAI-Organization") String organizationId, @HeaderParam("OpenAI-Project") String projectId, - EmbeddingRequest request); - - @ClientExceptionMapper - static RuntimeException mapException(jakarta.ws.rs.core.Response response) { - String errorMessage = getErrorMessage(response); - return EmbeddingProviderErrorMapper.mapToAPIException(providerId, response, errorMessage); - } - - /** - * Extract the error message from the response body. The example response body is: - * - *

-     * {
-     *   "error": {
-     *     "message": "You exceeded your current quota, please check your plan and billing details. For
-     *                 more information on this error, read the docs:
-     *                 https://platform.openai.com/docs/guides/error-codes/api-errors.",
-     *     "type": "insufficient_quota",
-     *     "param": null,
-     *     "code": "insufficient_quota"
-     *   }
-     * }
-     * 
- * - * @param response The response body as a String. - * @return The error message extracted from the response body. - */ - private static String getErrorMessage(jakarta.ws.rs.core.Response response) { - // Get the whole response body - JsonNode rootNode = response.readEntity(JsonNode.class); - // Log the response body - logger.error( - "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); - // Extract the "message" node from the "error" node - JsonNode messageNode = rootNode.at("/error/message"); - // Return the text of the "message" node, or the whole response body if it is missing - return messageNode.isMissingNode() ? rootNode.toString() : messageNode.asText(); - } + OpenAiEmbeddingRequest request); } - private record EmbeddingRequest( + /** + * Request structure of the OpenAI REST service. + * + *

.. + */ + public record OpenAiEmbeddingRequest( String[] input, String model, @JsonInclude(value = JsonInclude.Include.NON_DEFAULT) int dimensions) {} - @JsonIgnoreProperties(ignoreUnknown = true) // ignore possible extra fields without error - private record EmbeddingResponse(String object, Data[] data, String model, Usage usage) { + /** + * Response structure of the OpenAI REST service. + * + *

.. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + private record OpenAiEmbeddingResponse(String object, Data[] data, String model, Usage usage) { @JsonIgnoreProperties(ignoreUnknown = true) private record Data(String object, int index, float[] embedding) {} @JsonIgnoreProperties(ignoreUnknown = true) private record Usage(int prompt_tokens, int total_tokens) {} } - - @Override - public Uni vectorize( - int batchId, - List texts, - EmbeddingCredentials embeddingCredentials, - EmbeddingRequestType embeddingRequestType) { - // Check if using an EOF model - checkEOLModelUsage(); - checkEmbeddingApiKeyHeader(providerId, embeddingCredentials.apiKey()); - String[] textArray = new String[texts.size()]; - EmbeddingRequest request = - new EmbeddingRequest(texts.toArray(textArray), model.name(), dimension); - String organizationId = (String) vectorizeServiceParameters.get("organizationId"); - String projectId = (String) vectorizeServiceParameters.get("projectId"); - - Uni response = - applyRetry( - openAIEmbeddingProviderClient.embed( - HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(), - organizationId, - projectId, - request)); - - return response - .onItem() - .transform( - resp -> { - if (resp.data() == null) { - return Response.of(batchId, Collections.emptyList()); - } - Arrays.sort(resp.data(), (a, b) -> a.index() - b.index()); - List vectors = - Arrays.stream(resp.data()).map(data -> data.embedding()).toList(); - return Response.of(batchId, vectors); - }); - } - - @Override - public int maxBatchSize() { - return requestProperties.maxBatchSize(); - } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/UpstageAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/UpstageAIEmbeddingProvider.java index 5046dd855c..a0e8b7e38f 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/UpstageAIEmbeddingProvider.java +++ 
b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/UpstageAIEmbeddingProvider.java @@ -2,24 +2,24 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.databind.JsonNode; -import io.quarkus.rest.client.reactive.ClientExceptionMapper; import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; -import io.stargate.sgv2.jsonapi.service.embedding.operation.error.EmbeddingProviderErrorMapper; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.ServiceConfigStore; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -28,115 +28,99 @@ import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; public class UpstageAIEmbeddingProvider extends EmbeddingProvider { - private static final String providerId = ProviderConstants.UPSTAGE_AI; + private static final String UPSTAGE_MODEL_SUFFIX_QUERY = "-query"; private static final String 
UPSTAGE_MODEL_SUFFIX_PASSAGE = "-passage"; + private final String modelNamePrefix; - private final UpstageAIEmbeddingProviderClient upstageAIEmbeddingProviderClient; + private final UpstageAIEmbeddingProviderClient upstageClient; public UpstageAIEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties requestProperties, - String baseUrl, - EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig model, + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig, + EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig modelConfig, + ServiceConfigStore.ServiceConfig serviceConfig, int dimension, - Map vectorizeServiceParameters, - EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig) { - super(requestProperties, baseUrl, model, dimension, vectorizeServiceParameters, providerConfig); + Map vectorizeServiceParameters) { + super( + ModelProvider.UPSTAGE_AI, + providerConfig, + modelConfig, + serviceConfig, + dimension, + vectorizeServiceParameters); - this.modelNamePrefix = model.name(); - upstageAIEmbeddingProviderClient = + this.modelNamePrefix = modelConfig.name(); + upstageClient = QuarkusRestClientBuilder.newBuilder() - .baseUri(URI.create(baseUrl)) - .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) + .baseUri(URI.create(serviceConfig.getBaseUrl(modelName()))) + .readTimeout(requestProperties().readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(UpstageAIEmbeddingProviderClient.class); } - @RegisterRestClient - @RegisterProvider(EmbeddingProviderResponseValidation.class) - public interface UpstageAIEmbeddingProviderClient { - @POST - // no path specified, as it is already included in the baseUri - @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) - Uni embed( - @HeaderParam("Authorization") String accessToken, EmbeddingRequest request); - - @ClientExceptionMapper - static RuntimeException mapException(jakarta.ws.rs.core.Response response) { - String errorMessage = 
getErrorMessage(response); - return EmbeddingProviderErrorMapper.mapToAPIException(providerId, response, errorMessage); - } - - /** - * Extracts the error message from the response body. The example response body is: - * - *

-     * {
-     *   "message": "Unauthorized"
-     * }
-     *
-     * {
-     *   "error": {
-     *     "message": "This model's maximum context length is 4000 tokens. however you requested 10969 tokens. Please reduce your prompt.",
-     *     "type": "invalid_request_error",
-     *     "param": null,
-     *     "code": null
-     *   }
-     * }
-     * 
- * - * @param response The response body as a String. - * @return The error message extracted from the response body, or null if the message is not - * found. - */ - private static String getErrorMessage(jakarta.ws.rs.core.Response response) { - // Get the whole response body - JsonNode rootNode = response.readEntity(JsonNode.class); - // Log the response body - logger.error( - "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); - // Check if the root node contains a "message" field - JsonNode messageNode = rootNode.path("message"); - if (!messageNode.isMissingNode()) { - return messageNode.asText(); - } - // If the "message" field is not found, check for the nested "error" object - JsonNode errorMessageNode = rootNode.at("/error/message"); - if (!errorMessageNode.isMissingNode()) { - return errorMessageNode.asText(); - } - // Return the whole response body if no message is found - return rootNode.toString(); - } + @Override + protected String errorMessageJsonPtr() { + // overriding the function that calls this + return ""; } - // NOTE: "input" is a single String, not array of Constants! - record EmbeddingRequest(String input, String model) {} + /** + * Extracts the error message from the response body. The example response body is: + * + *
+   * {
+   *   "message": "Unauthorized"
+   * }
+   *
+   * {
+   *   "error": {
+   *     "message": "This model's maximum context length is 4000 tokens. however you requested 10969 tokens. Please reduce your prompt.",
+   *     "type": "invalid_request_error",
+   *     "param": null,
+   *     "code": null
+   *   }
+   * }
+   * 
+ */ + @Override + protected String responseErrorMessage(Response jakartaResponse) { - @JsonIgnoreProperties(ignoreUnknown = true) // ignore possible extra fields without error - record EmbeddingResponse(Data[] data, String model, Usage usage) { - @JsonIgnoreProperties(ignoreUnknown = true) - record Data(int index, float[] embedding) {} + JsonNode rootNode = jakartaResponse.readEntity(JsonNode.class); - @JsonIgnoreProperties(ignoreUnknown = true) - record Usage(int prompt_tokens, int total_tokens) {} + // Check if the root node contains a "message" field + JsonNode messageNode = rootNode.path("message"); + if (!messageNode.isMissingNode()) { + return messageNode.asText(); + } + + // If the "message" field is not found, check for the nested "error" object + JsonNode errorMessageNode = rootNode.at("/error/message"); + if (!errorMessageNode.isMissingNode()) { + return errorMessageNode.asText(); + } + // Return the whole response body if no message is found + return rootNode.toString(); } @Override - public Uni vectorize( + public Uni vectorize( int batchId, List texts, EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { - // Check if using an EOF model + checkEOLModelUsage(); - checkEmbeddingApiKeyHeader(providerId, embeddingCredentials.apiKey()); + checkEmbeddingApiKeyHeader(embeddingCredentials.apiKey()); + // Oddity: Implementation does not support batching, so we only accept "batches" // of 1 String, fail for others if (texts.size() != 1) { + // TODO: This should be IllegalArgumentException + // Temporary fail message: with re-batching will give better information throw ErrorCodeV1.INVALID_VECTORIZE_VALUE_TYPE.toApiException( "UpstageAI only supports vectorization of 1 text at a time, got " + texts.size()); } + // Another oddity: model name used as prefix final String modelName = modelNamePrefix @@ -144,30 +128,81 @@ public Uni vectorize( ? 
UPSTAGE_MODEL_SUFFIX_QUERY : UPSTAGE_MODEL_SUFFIX_PASSAGE); - EmbeddingRequest request = new EmbeddingRequest(texts.get(0), modelName); + var upstageRequest = new UpstageEmbeddingRequest(texts.getFirst(), modelName); - Uni response = - applyRetry( - upstageAIEmbeddingProviderClient.embed( - HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(), - request)); + // TODO: V2 error + // aaron 8 June 2025 - old code had NO comment to explain what happens if the API key is empty. + var accessToken = HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(); - return response + long callStartNano = System.nanoTime(); + return retryHTTPCall(upstageClient.embed(accessToken, upstageRequest)) .onItem() .transform( - resp -> { - if (resp.data() == null) { - return Response.of(batchId, Collections.emptyList()); + jakartaResponse -> { + var upstageResponse = decodeResponse(jakartaResponse, UpstageEmbeddingResponse.class); + long callDurationNano = System.nanoTime() - callStartNano; + + // aaron - 10 June 2025 - previous code would silently swallow no data returned + // and return an empty result. If we made a request we should get a response. + if (upstageResponse.data() == null) { + throwEmptyData(jakartaResponse); } - Arrays.sort(resp.data(), (a, b) -> a.index() - b.index()); + + // aaron - 11 june 2025 - prev code would sort upstageResponse.data() BUT per above we + // only support a batch size of 1, so no need to sort. 
+ List vectors = - Arrays.stream(resp.data()).map(data -> data.embedding()).toList(); - return Response.of(batchId, vectors); + Arrays.stream(upstageResponse.data()) + .map(UpstageEmbeddingResponse.Data::embedding) + .toList(); + + var modelUsage = + createModelUsage( + embeddingCredentials.tenantId(), + ModelInputType.fromEmbeddingRequestType(embeddingRequestType), + upstageResponse.usage().prompt_tokens(), + upstageResponse.usage().total_tokens(), + jakartaResponse, + callDurationNano); + return new BatchedEmbeddingResponse(batchId, vectors, modelUsage); }); } - @Override - public int maxBatchSize() { - return requestProperties.maxBatchSize(); + /** + * REST client interface for the Upstage Embedding Service. + * + *

.. + */ + @RegisterRestClient + @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(ProviderHttpInterceptor.class) + public interface UpstageAIEmbeddingProviderClient { + @POST + // no path specified, as it is already included in the baseUri + @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) + Uni embed( + @HeaderParam("Authorization") String accessToken, UpstageEmbeddingRequest request); + } + + /** + * Request structure of the Upstage REST service. + * + *

NOTE: "input" is a single String, not array of Constants! + */ + public record UpstageEmbeddingRequest(String input, String model) {} + + /** + * Response structure of the Upstage REST service. + * + *

.. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public record UpstageEmbeddingResponse(Data[] data, String model, Usage usage) { + + @JsonIgnoreProperties(ignoreUnknown = true) + record Data(int index, float[] embedding) {} + + @JsonIgnoreProperties(ignoreUnknown = true) + record Usage(int prompt_tokens, int total_tokens) {} } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VertexAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VertexAIEmbeddingProvider.java index 07456d8c1d..edb4f8dddb 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VertexAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VertexAIEmbeddingProvider.java @@ -3,196 +3,178 @@ import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.databind.JsonNode; -import io.quarkus.rest.client.reactive.ClientExceptionMapper; import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; -import io.stargate.sgv2.jsonapi.service.embedding.operation.error.EmbeddingProviderErrorMapper; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.ServiceConfigStore; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import 
io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; import jakarta.ws.rs.Path; import jakarta.ws.rs.PathParam; import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; import java.net.URI; -import java.util.Collections; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; import org.eclipse.microprofile.rest.client.annotation.ClientHeaderParam; import org.eclipse.microprofile.rest.client.annotation.RegisterProvider; import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; public class VertexAIEmbeddingProvider extends EmbeddingProvider { - private static final String providerId = ProviderConstants.VERTEXAI; - private final VertexAIEmbeddingProviderClient vertexAIEmbeddingProviderClient; + + private final VertexAIEmbeddingProviderClient vertexClient; public VertexAIEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties requestProperties, - String baseUrl, - EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig model, + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig, + EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig modelConfig, + ServiceConfigStore.ServiceConfig serviceConfig, int dimension, - Map serviceParameters, - EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig) { - super(requestProperties, baseUrl, model, dimension, serviceParameters, providerConfig); - - String actualUrl = replaceParameters(baseUrl, serviceParameters); - vertexAIEmbeddingProviderClient = + Map vectorizeServiceParameters) { + super( + ModelProvider.VERTEXAI, + providerConfig, + modelConfig, + serviceConfig, + dimension, + vectorizeServiceParameters); + + String actualUrl = + replaceParameters(serviceConfig.getBaseUrl(modelName()), vectorizeServiceParameters); + vertexClient = 
QuarkusRestClientBuilder.newBuilder() .baseUri(URI.create(actualUrl)) - .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) + .readTimeout(requestProperties().readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(VertexAIEmbeddingProviderClient.class); } - @RegisterRestClient - @RegisterProvider(EmbeddingProviderResponseValidation.class) - public interface VertexAIEmbeddingProviderClient { - @POST - @Path("/{modelId}:predict") - @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) - Uni embed( - @HeaderParam("Authorization") String accessToken, - @PathParam("modelId") String modelId, - EmbeddingRequest request); - - @ClientExceptionMapper - static RuntimeException mapException(jakarta.ws.rs.core.Response response) { - String errorMessage = getErrorMessage(response); - return EmbeddingProviderErrorMapper.mapToAPIException(providerId, response, errorMessage); - } - - /** - * TODO: Add customized error message extraction logic here.
- * Extract the error message from the response body. The example response body is: - * - *

-     *
-     * 
- * - * @param response The response body as a String. - * @return The error message extracted from the response body. - */ - private static String getErrorMessage(jakarta.ws.rs.core.Response response) { - // Get the whole response body - JsonNode rootNode = response.readEntity(JsonNode.class); - // Log the response body - logger.error( - "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); - return rootNode.toString(); - } + @Override + protected String errorMessageJsonPtr() { + // overriding the call that needs this. + return null; } - private record EmbeddingRequest(List instances) { - public record Content(String content) {} + @Override + protected String responseErrorMessage(Response jakartaResponse) { + // aaron 9 june 2025 - this is what it did originally, just get the whole response body + + // Get the whole response body + JsonNode rootNode = jakartaResponse.readEntity(JsonNode.class); + return rootNode.toString(); } - @JsonIgnoreProperties(ignoreUnknown = true) // ignore possible extra fields without error - private static class EmbeddingResponse { - public EmbeddingResponse() {} + @Override + public Uni vectorize( + int batchId, + List texts, + EmbeddingCredentials embeddingCredentials, + EmbeddingRequestType embeddingRequestType) { - private List predictions; + checkEOLModelUsage(); + checkEmbeddingApiKeyHeader(embeddingCredentials.apiKey()); - @JsonIgnore private Object metadata; + var vertexRequest = + new VertexEmbeddingRequest( + texts.stream().map(VertexEmbeddingRequest.Content::new).toList()); - public List getPredictions() { - return predictions; - } + // TODO: V2 error + // aaron 8 June 2025 - old code had NO comment to explain what happens if the API key is empty. 
+ var accessToken = HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(); - public void setPredictions(List predictions) { - this.predictions = predictions; - } + long callStartNano = System.nanoTime(); + return retryHTTPCall(vertexClient.embed(accessToken, modelName(), vertexRequest)) + .onItem() + .transform( + jakartaResponse -> { + var vertexResponse = decodeResponse(jakartaResponse, VertexEmbeddingResponse.class); + long callDurationNano = System.nanoTime() - callStartNano; + + // aaron - 10 June 2025 - previous code would silently swallow no data returned + // and return an empty result. If we made a request we should get a response. + if (vertexResponse.predictions() == null) { + throwEmptyData(jakartaResponse); + } - public Object getMetadata() { - return metadata; - } + // token usage is for each of the embeddings, need to sum it up + int total_tokens = 0; + List vectors = new ArrayList<>(vertexResponse.predictions().size()); + for (var prediction : vertexResponse.predictions()) { + vectors.add(prediction.embeddings().values); + total_tokens += prediction.embeddings().statistics().token_count; + } - public void setMetadata(Object metadata) { - this.metadata = metadata; - } + // Docs say the token_count in the response is the "Number of tokens of the input + // text." + // so seems safe to use this as the prompt_tokens and total_tokens + // https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#response_body + var modelUsage = + createModelUsage( + embeddingCredentials.tenantId(), + ModelInputType.fromEmbeddingRequestType(embeddingRequestType), + total_tokens, + total_tokens, + jakartaResponse, + callDurationNano); + return new BatchedEmbeddingResponse(batchId, vectors, modelUsage); + }); + } - @JsonIgnoreProperties(ignoreUnknown = true) - protected static class Prediction { - public Prediction() {} + /** + * REST client interface for the Vertex Embedding Service. + * + *

.. + */ + @RegisterRestClient + @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(ProviderHttpInterceptor.class) + public interface VertexAIEmbeddingProviderClient { - private Embeddings embeddings; + @POST + @Path("/{modelId}:predict") + @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) + Uni embed( + @HeaderParam("Authorization") String accessToken, + @PathParam("modelId") String modelId, + VertexEmbeddingRequest request); + } - public Embeddings getEmbeddings() { - return embeddings; - } + /** + * Request structure of the Vertex REST service. + * + *

.. + */ + private record VertexEmbeddingRequest(List instances) { + public record Content(String content) {} + } - public void setEmbeddings(Embeddings embeddings) { - this.embeddings = embeddings; - } + /** + * Response structure of the Vertex REST service. + * + *

.. aaron - 10 June 2025 - this used to be a class, moved to be a record for consistency + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public record VertexEmbeddingResponse( + List predictions, + // aaron 10 june 2025, could not see metadata in API docs, but it was in the old code. + // https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#response_body + @JsonIgnore Object metadata) { + @JsonIgnoreProperties(ignoreUnknown = true) + public record Prediction(Embeddings embeddings) { @JsonIgnoreProperties(ignoreUnknown = true) - protected static class Embeddings { - public Embeddings() {} - - private float[] values; - - @JsonIgnore private Object statistics; - - public float[] getValues() { - return values; - } + public record Embeddings(float[] values, Statistics statistics) { - public void setValues(float[] values) { - this.values = values; - } - - public Object getStatistics() { - return statistics; - } - - public void setStatistics(Object statistics) { - this.statistics = statistics; - } + @JsonIgnoreProperties(ignoreUnknown = true) + public record Statistics(boolean truncated, int token_count) {} } } } - - @Override - public Uni vectorize( - int batchId, - List texts, - EmbeddingCredentials embeddingCredentials, - EmbeddingRequestType embeddingRequestType) { - // Check if using an EOF model - checkEOLModelUsage(); - checkEmbeddingApiKeyHeader(providerId, embeddingCredentials.apiKey()); - EmbeddingRequest request = - new EmbeddingRequest(texts.stream().map(t -> new EmbeddingRequest.Content(t)).toList()); - - Uni serviceResponse = - applyRetry( - vertexAIEmbeddingProviderClient.embed( - HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(), - model.name(), - request)); - - return serviceResponse - .onItem() - .transform( - response -> { - if (response.getPredictions() == null) { - return Response.of(batchId, Collections.emptyList()); - } - List vectors = - response.getPredictions().stream() - 
.map(prediction -> prediction.getEmbeddings().getValues()) - .collect(Collectors.toList()); - return Response.of(batchId, vectors); - }); - } - - @Override - public int maxBatchSize() { - return requestProperties.maxBatchSize(); - } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VoyageAIEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VoyageAIEmbeddingProvider.java index c15e291fcb..2890e1a2a4 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VoyageAIEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VoyageAIEmbeddingProvider.java @@ -2,24 +2,23 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.databind.JsonNode; -import io.quarkus.rest.client.reactive.ClientExceptionMapper; import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderResponseValidation; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; -import io.stargate.sgv2.jsonapi.service.embedding.operation.error.EmbeddingProviderErrorMapper; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.ServiceConfigStore; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; import 
jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -28,125 +27,155 @@ import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; public class VoyageAIEmbeddingProvider extends EmbeddingProvider { - private static final String providerId = ProviderConstants.VOYAGE_AI; - private final VoyageAIEmbeddingProviderClient voyageAIEmbeddingProviderClient; + + private final VoyageAIEmbeddingProviderClient voyageClient; + private final String requestTypeQuery, requestTypeIndex; private final Boolean autoTruncate; public VoyageAIEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties requestProperties, - String baseUrl, - EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig model, + EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig, + EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig modelConfig, + ServiceConfigStore.ServiceConfig serviceConfig, int dimension, - Map serviceParameters, - EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig) { - super(requestProperties, baseUrl, model, dimension, serviceParameters, providerConfig); + Map vectorizeServiceParameters) { + super( + ModelProvider.VOYAGE_AI, + providerConfig, + modelConfig, + serviceConfig, + dimension, + vectorizeServiceParameters); // use configured input_type if available - requestTypeQuery = requestProperties.requestTypeQuery().orElse(null); - requestTypeIndex = requestProperties.requestTypeIndex().orElse(null); - Object v = (serviceParameters == null) ? null : serviceParameters.get("autoTruncate"); + requestTypeQuery = requestProperties().requestTypeQuery().orElse(null); + requestTypeIndex = requestProperties().requestTypeIndex().orElse(null); + + Object v = + (vectorizeServiceParameters == null) + ?
null + : vectorizeServiceParameters.get("autoTruncate"); autoTruncate = (v instanceof Boolean) ? (Boolean) v : null; - voyageAIEmbeddingProviderClient = + voyageClient = QuarkusRestClientBuilder.newBuilder() - .baseUri(URI.create(baseUrl)) - .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) + .baseUri(URI.create(serviceConfig.getBaseUrl(modelName()))) + .readTimeout(requestProperties().readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(VoyageAIEmbeddingProviderClient.class); } - @RegisterRestClient - @RegisterProvider(EmbeddingProviderResponseValidation.class) - public interface VoyageAIEmbeddingProviderClient { - @POST - // no path specified, as it is already included in the baseUri - @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) - Uni embed( - @HeaderParam("Authorization") String accessToken, EmbeddingRequest request); - - @ClientExceptionMapper - static RuntimeException mapException(jakarta.ws.rs.core.Response response) { - String errorMessage = getErrorMessage(response); - return EmbeddingProviderErrorMapper.mapToAPIException(providerId, response, errorMessage); - } - - /** - * Extract the error message from the response body. The example response body is: - * - *

-     * {"detail":"You have not yet added your payment method in the billing page and will have reduced rate limits of 3 RPM and 10K TPM.  Please add your payment method in the billing page (https://dash.voyageai.com/billing/payment-methods) to unlock our standard rate limits (https://docs.voyageai.com/docs/rate-limits).  Even with payment methods entered, the free tokens (50M tokens per model) will still apply."}
-     *
-     * {"detail":"Provided API key is invalid."}
-     * 
- * - * @param response The response body as a String. - * @return The error message extracted from the response body. - */ - private static String getErrorMessage(jakarta.ws.rs.core.Response response) { - // Get the whole response body - JsonNode rootNode = response.readEntity(JsonNode.class); - // Log the response body - logger.error( - "Error response from embedding provider '{}': {}", providerId, rootNode.toString()); - // Extract the "detail" node - JsonNode detailNode = rootNode.path("detail"); - // Return the text of the "detail" node, or the full response body if it is missing - return detailNode.isMissingNode() ? rootNode.toString() : detailNode.toString(); - } - } - - record EmbeddingRequest( - @JsonInclude(JsonInclude.Include.NON_EMPTY) String input_type, - String[] input, - String model, - @JsonInclude(JsonInclude.Include.NON_NULL) Boolean truncation) {} - - @JsonIgnoreProperties(ignoreUnknown = true) // ignore possible extra fields without error - record EmbeddingResponse(Data[] data, String model, Usage usage) { - @JsonIgnoreProperties(ignoreUnknown = true) - record Data(int index, float[] embedding) {} - - @JsonIgnoreProperties(ignoreUnknown = true) - record Usage(int total_tokens) {} + /** + * Response body with an error will look like below: + * + *
+   * {"detail":"You have not yet added your payment method in the billing page and will have reduced rate limits of 3 RPM and 10K TPM.  Please add your payment method in the billing page (https://dash.voyageai.com/billing/payment-methods) to unlock our standard rate limits (https://docs.voyageai.com/docs/rate-limits).  Even with payment methods entered, the free tokens (50M tokens per model) will still apply."}
+   *
+   * {"detail":"Provided API key is invalid."}
+   * 
+ */ + @Override + protected String errorMessageJsonPtr() { + return "/detail"; } @Override - public Uni vectorize( + public Uni vectorize( int batchId, List texts, EmbeddingCredentials embeddingCredentials, EmbeddingRequestType embeddingRequestType) { - // Check if using an EOF model - checkEOLModelUsage(); - checkEmbeddingApiKeyHeader(providerId, embeddingCredentials.apiKey()); + + checkEmbeddingApiKeyHeader(embeddingCredentials.apiKey()); + + // TODO: remove the requestTypeQuery and requestTypeIndex from config ! + // aaron 8 June 2025 - this looks like the term to use for query and index is in config, but + // there is + // NO handling of when this config is not set final String inputType = (embeddingRequestType == EmbeddingRequestType.SEARCH) ? requestTypeQuery : requestTypeIndex; - String[] textArray = new String[texts.size()]; - EmbeddingRequest request = - new EmbeddingRequest(inputType, texts.toArray(textArray), model.name(), autoTruncate); - Uni response = - applyRetry( - voyageAIEmbeddingProviderClient.embed( - HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(), - request)); + var voyageRequest = + new VoyageEmbeddingRequest( + inputType, texts.toArray(new String[texts.size()]), modelName(), autoTruncate); + + // TODO: V2 error + // aaron 8 June 2025 - old code had NO comment to explain what happens if the API key is empty.
+ var accessToken = HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get(); - return response + long callStartNano = System.nanoTime(); + return retryHTTPCall(voyageClient.embed(accessToken, voyageRequest)) .onItem() .transform( - resp -> { - if (resp.data() == null) { - return Response.of(batchId, Collections.emptyList()); + jakartaResponse -> { + var voyageResponse = decodeResponse(jakartaResponse, VoyageEmbeddingResponse.class); + long callDurationNano = System.nanoTime() - callStartNano; + + // aaron - 10 June 2025 - previous code would silently swallow no data returned + // and return an empty result. If we made a request we should get a response. + if (voyageResponse.data() == null) { + throwEmptyData(jakartaResponse); } - Arrays.sort(resp.data(), (a, b) -> a.index() - b.index()); + + // TODO: WHY SORT ? + Arrays.sort(voyageResponse.data(), (a, b) -> Integer.compare(a.index(), b.index())); + List vectors = - Arrays.stream(resp.data()).map(data -> data.embedding()).toList(); - return Response.of(batchId, vectors); + Arrays.stream(voyageResponse.data()) + .map(VoyageEmbeddingResponse.Data::embedding) + .toList(); + + var modelUsage = + createModelUsage( + embeddingCredentials.tenantId(), + ModelInputType.fromEmbeddingRequestType(embeddingRequestType), + 0, + voyageResponse.usage.total_tokens, + jakartaResponse, + callDurationNano); + return new BatchedEmbeddingResponse(batchId, vectors, modelUsage); }); } - @Override - public int maxBatchSize() { - return requestProperties.maxBatchSize(); + /** + * REST client interface for the Voyage Embedding Service. + * + *

.. + */ + @RegisterRestClient + @RegisterProvider(EmbeddingProviderResponseValidation.class) + @RegisterProvider(ProviderHttpInterceptor.class) + public interface VoyageAIEmbeddingProviderClient { + @POST + // no path specified, as it is already included in the baseUri + @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) + Uni embed( + @HeaderParam("Authorization") String accessToken, VoyageEmbeddingRequest request); + } + + /** + * Request structure of the Voyage REST service. + * + *

.. + */ + public record VoyageEmbeddingRequest( + @JsonInclude(JsonInclude.Include.NON_EMPTY) String input_type, + String[] input, + String model, + @JsonInclude(JsonInclude.Include.NON_NULL) Boolean truncation) {} + + /** + * Response structure of the Voyage REST service. + * + *

.. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public record VoyageEmbeddingResponse(Data[] data, String model, Usage usage) { + @JsonIgnoreProperties(ignoreUnknown = true) + record Data(int index, float[] embedding) {} + + @JsonIgnoreProperties(ignoreUnknown = true) + record Usage(int total_tokens) { + // Voyage API does not return prompt_tokens + } } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/error/EmbeddingProviderErrorMapper.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/error/EmbeddingProviderErrorMapper.java deleted file mode 100644 index 203123ef9a..0000000000 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/error/EmbeddingProviderErrorMapper.java +++ /dev/null @@ -1,54 +0,0 @@ -package io.stargate.sgv2.jsonapi.service.embedding.operation.error; - -import static jakarta.ws.rs.core.Response.Status.Family.CLIENT_ERROR; - -import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; -import jakarta.ws.rs.core.Response; - -public class EmbeddingProviderErrorMapper { - /** - * Maps an HTTP response to a corresponding API exception. Individual providers can override this - * method to provide custom exception handling. 
- * - * @param providerName the name of the provider - * @param response the HTTP response - * @param message the error message from provider - * @return a JsonApiException that corresponds to the specific HTTP response status - */ - public static RuntimeException mapToAPIException( - String providerName, Response response, String message) { - // Status code == 408 and 504 for timeout - if (response.getStatus() == Response.Status.REQUEST_TIMEOUT.getStatusCode() - || response.getStatus() == Response.Status.GATEWAY_TIMEOUT.getStatusCode()) { - return ErrorCodeV1.EMBEDDING_PROVIDER_TIMEOUT.toApiException( - "Provider: %s; HTTP Status: %s; Error Message: %s", - providerName, response.getStatus(), message); - } - - // Status code == 429 - if (response.getStatus() == Response.Status.TOO_MANY_REQUESTS.getStatusCode()) { - return ErrorCodeV1.EMBEDDING_PROVIDER_RATE_LIMITED.toApiException( - "Provider: %s; HTTP Status: %s; Error Message: %s", - providerName, response.getStatus(), message); - } - - // Status code in 4XX other than 429 - if (response.getStatusInfo().getFamily() == CLIENT_ERROR) { - return ErrorCodeV1.EMBEDDING_PROVIDER_CLIENT_ERROR.toApiException( - "Provider: %s; HTTP Status: %s; Error Message: %s", - providerName, response.getStatus(), message); - } - - // Status code in 5XX - if (response.getStatusInfo().getFamily() == Response.Status.Family.SERVER_ERROR) { - return ErrorCodeV1.EMBEDDING_PROVIDER_SERVER_ERROR.toApiException( - "Provider: %s; HTTP Status: %s; Error Message: %s", - providerName, response.getStatus(), message); - } - - // All other errors, Should never happen as all errors are covered above - return ErrorCodeV1.EMBEDDING_PROVIDER_UNEXPECTED_RESPONSE.toApiException( - "Provider: %s; HTTP Status: %s; Error Message: %s", - providerName, response.getStatus(), message); - } -} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/error/RerankingResponseErrorMessageMapper.java 
b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/error/RerankingResponseErrorMessageMapper.java deleted file mode 100644 index fe0f73b792..0000000000 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/error/RerankingResponseErrorMessageMapper.java +++ /dev/null @@ -1,54 +0,0 @@ -package io.stargate.sgv2.jsonapi.service.embedding.operation.error; - -import static jakarta.ws.rs.core.Response.Status.Family.CLIENT_ERROR; - -import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; -import jakarta.ws.rs.core.Response; - -public class RerankingResponseErrorMessageMapper { - /** - * Maps an HTTP response to a corresponding API exception. Individual reranking providers can - * override this method to provide custom exception handling. - * - * @param providerName the name of the reranking provider - * @param response the HTTP response - * @param message the error message from reranking provider - * @return a JsonApiException that corresponds to the specific HTTP response status - */ - public static RuntimeException mapToAPIException( - String providerName, Response response, String message) { - // Status code == 408 and 504 for timeout - if (response.getStatus() == Response.Status.REQUEST_TIMEOUT.getStatusCode() - || response.getStatus() == Response.Status.GATEWAY_TIMEOUT.getStatusCode()) { - return ErrorCodeV1.RERANKING_PROVIDER_TIMEOUT.toApiException( - "Provider: %s; HTTP Status: %s; Error Message: %s", - providerName, response.getStatus(), message); - } - - // Status code == 429 - if (response.getStatus() == Response.Status.TOO_MANY_REQUESTS.getStatusCode()) { - return ErrorCodeV1.RERANKING_PROVIDER_RATE_LIMITED.toApiException( - "Provider: %s; HTTP Status: %s; Error Message: %s", - providerName, response.getStatus(), message); - } - - // Status code in 4XX other than 429 - if (response.getStatusInfo().getFamily() == CLIENT_ERROR) { - return ErrorCodeV1.RERANKING_PROVIDER_CLIENT_ERROR.toApiException( - "Provider: %s; HTTP Status: 
%s; Error Message: %s", - providerName, response.getStatus(), message); - } - - // Status code in 5XX - if (response.getStatusInfo().getFamily() == Response.Status.Family.SERVER_ERROR) { - return ErrorCodeV1.RERANKING_PROVIDER_SERVER_ERROR.toApiException( - "Provider: %s; HTTP Status: %s; Error Message: %s", - providerName, response.getStatus(), message); - } - - // All other errors, Should never happen as all errors are covered above - return ErrorCodeV1.RERANKING_PROVIDER_UNEXPECTED_RESPONSE.toApiException( - "Provider: %s; HTTP Status: %s; Error Message: %s", - providerName, response.getStatus(), message); - } -} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/test/CustomITEmbeddingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/test/CustomITEmbeddingProvider.java index f7ea7f24c2..b93da3e5bc 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/test/CustomITEmbeddingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/test/CustomITEmbeddingProvider.java @@ -4,8 +4,11 @@ import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfigImpl; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.ServiceConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProvider; import io.stargate.sgv2.jsonapi.service.provider.ApiModelSupport; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import java.util.*; /** @@ -31,25 +34,63 @@ public class CustomITEmbeddingProvider extends EmbeddingProvider { private final int dimension; + private static final EmbeddingProvidersConfigImpl.EmbeddingProviderConfigImpl + .RequestPropertiesImpl + REQUEST_PROPERTIES = + new 
EmbeddingProvidersConfigImpl.EmbeddingProviderConfigImpl.RequestPropertiesImpl( + 3, 10, 100, 100, 0.5, Optional.empty(), Optional.empty(), Optional.empty(), 10); + + private static final EmbeddingProvidersConfigImpl.EmbeddingProviderConfigImpl PROVIDER_CONFIG = + new EmbeddingProvidersConfigImpl.EmbeddingProviderConfigImpl( + ModelProvider.CUSTOM.apiName(), + true, + Optional.of("http://testing.com"), + false, + Map.of(), + List.of(), + REQUEST_PROPERTIES, + List.of()); + + private static final EmbeddingProvidersConfigImpl.EmbeddingProviderConfigImpl.ModelConfigImpl + MODEL_CONFIG = + new EmbeddingProvidersConfigImpl.EmbeddingProviderConfigImpl.ModelConfigImpl( + "test-model", + new ApiModelSupport.ApiModelSupportImpl( + ApiModelSupport.SupportStatus.SUPPORTED, Optional.empty()), + Optional.of(5), + List.of(), + Map.of(), + Optional.empty()); + + private static final ServiceConfigStore.ServiceConfig SERVICE_CONFIG = + new ServiceConfigStore.ServiceConfig( + ModelProvider.CUSTOM, + "http://testing.com", + Optional.empty(), + new ServiceConfigStore.ServiceRequestProperties( + REQUEST_PROPERTIES.atMostRetries(), + REQUEST_PROPERTIES.initialBackOffMillis(), + REQUEST_PROPERTIES.readTimeoutMillis(), + REQUEST_PROPERTIES.maxBackOffMillis(), + REQUEST_PROPERTIES.jitter(), + REQUEST_PROPERTIES.taskTypeRead(), + REQUEST_PROPERTIES.taskTypeStore(), + REQUEST_PROPERTIES.maxBatchSize()), + Map.of()); + public CustomITEmbeddingProvider(int dimension) { - // construct the test modelConfig - super( - null, - null, - new EmbeddingProvidersConfigImpl.EmbeddingProviderConfigImpl.ModelConfigImpl( - "testModel", - new ApiModelSupport.ApiModelSupportImpl( - ApiModelSupport.SupportStatus.SUPPORTED, Optional.empty()), - Optional.of(dimension), - List.of(), - Map.of(), - Optional.empty()), - dimension, - Map.of(), - null); + // aaron 9 June 2025 - refactoring , I think none of the super class is used, so passing dummy + // values + super(ModelProvider.CUSTOM, PROVIDER_CONFIG, 
MODEL_CONFIG, SERVICE_CONFIG, 5, Map.of()); + this.dimension = dimension; } + @Override + protected String errorMessageJsonPtr() { + return ""; + } + static { TEST_DATA_DIMENSION_5.put( "ChatGPT integrated sneakers that talk to you", @@ -83,7 +124,7 @@ public CustomITEmbeddingProvider(int dimension) { } @Override - public Uni vectorize( + public Uni vectorize( int batchId, List texts, EmbeddingCredentials embeddingCredentials, @@ -93,10 +134,23 @@ public Uni vectorize( checkEOLModelUsage(); List response = new ArrayList<>(texts.size()); - if (texts.size() == 0) return Uni.createFrom().item(Response.of(batchId, response)); - if (!embeddingCredentials.apiKey().isPresent() - || !embeddingCredentials.apiKey().get().equals(TEST_API_KEY)) + if (texts.isEmpty()) { + var modelUsage = + createModelUsage( + embeddingCredentials.tenantId(), + ModelInputType.fromEmbeddingRequestType(embeddingRequestType), + 0, + 0, + 0, + 0, + 0); + return Uni.createFrom().item(new BatchedEmbeddingResponse(batchId, response, modelUsage)); + } + if (embeddingCredentials.apiKey().isEmpty() + || !embeddingCredentials.apiKey().get().equals(TEST_API_KEY)) { return Uni.createFrom().failure(new RuntimeException("Invalid API Key")); + } + for (String text : texts) { if (dimension == 5) { if (TEST_DATA_DIMENSION_5.containsKey(text)) { @@ -113,7 +167,17 @@ public Uni vectorize( } } } - return Uni.createFrom().item(Response.of(batchId, response)); + + var modelUsage = + createModelUsage( + embeddingCredentials.tenantId(), + ModelInputType.fromEmbeddingRequestType(embeddingRequestType), + 0, + 0, + 0, + 0, + 0); + return Uni.createFrom().item(new BatchedEmbeddingResponse(batchId, response, modelUsage)); } @Override diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/embeddings/EmbeddingTask.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/embeddings/EmbeddingTask.java index f8f939ce1c..e222b26330 100644 --- 
a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/embeddings/EmbeddingTask.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/embeddings/EmbeddingTask.java @@ -77,7 +77,7 @@ protected EmbeddingTask.EmbeddingResultSupplier buildResultSupplier( commandContext .requestContext() .getEmbeddingCredentialsSupplier() - .create(commandContext.requestContext(), embeddingProvider.getProviderConfig()), + .create(commandContext.requestContext(), embeddingProvider.providerConfig()), requestType), embeddingActions, vectorizeTexts); @@ -111,14 +111,14 @@ public static class EmbeddingResultSupplier implements BaseTask.UniSupplier embeddingTask; protected final CommandContext commandContext; - protected final BaseTask.UniSupplier supplier; + protected final BaseTask.UniSupplier supplier; protected final List actions; private final List vectorizeTexts; EmbeddingResultSupplier( EmbeddingTask embeddingTask, CommandContext commandContext, - BaseTask.UniSupplier supplier, + BaseTask.UniSupplier supplier, List actions, List vectorizeTexts) { this.embeddingTask = embeddingTask; @@ -181,7 +181,7 @@ private EmbeddingTaskResult(List rawVectors, List embeddingTask, CommandContext commandContext, - EmbeddingProvider.Response providerResponse, + EmbeddingProvider.BatchedEmbeddingResponse providerResponse, List actions) { commandContext diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/embeddings/EmbeddingTaskBuilder.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/embeddings/EmbeddingTaskBuilder.java index 51d7e1f387..b19f5b49d4 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/embeddings/EmbeddingTaskBuilder.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/embeddings/EmbeddingTaskBuilder.java @@ -81,7 +81,7 @@ public EmbeddingTask build() { var embeddingProvider = commandContext .embeddingProviderFactory() - .getConfiguration( + .create( commandContext.requestContext().getTenantId(), 
commandContext.requestContext().getCassandraToken(), vectorizeDefinition.provider(), diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/reranking/RerankingTask.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/reranking/RerankingTask.java index 2ff333b436..2b5795554c 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/reranking/RerankingTask.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/reranking/RerankingTask.java @@ -275,7 +275,8 @@ public Uni get() { RerankingTaskResult.create( requestTracing, rerankingProvider, - new RerankingProvider.RerankingResponse(List.of()), + new RerankingProvider.RerankingResponse( + List.of(), rerankingProvider.createEmptyModelUsage(credentials)), unrankedDocs, limit)); } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelInputType.java b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelInputType.java new file mode 100644 index 0000000000..91c8b31d8a --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelInputType.java @@ -0,0 +1,39 @@ +package io.stargate.sgv2.jsonapi.service.provider; + +import io.stargate.embedding.gateway.EmbeddingGateway; +import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProvider; +import java.util.Optional; + +/** + * If the model usage was for indexing data or searching data + * + *

Keeps in parity with the grpc proto definition in embedding_gateway.proto. + */ +public enum ModelInputType { + + /** + * The input type is not specified, for parity with grpc, or where it does not make sense to such + * as for a reranking model. + */ + INPUT_TYPE_UNSPECIFIED, + INDEX, + SEARCH; + + public static ModelInputType fromEmbeddingRequestType( + EmbeddingProvider.EmbeddingRequestType embeddingRequestType) { + return switch (embeddingRequestType) { + case INDEX -> INDEX; + case SEARCH -> SEARCH; + }; + } + + public static Optional fromEmbeddingGateway( + EmbeddingGateway.ModelUsage.InputType inputType) { + return switch (inputType) { + case INPUT_TYPE_UNSPECIFIED -> Optional.of(INPUT_TYPE_UNSPECIFIED); + case INDEX -> Optional.of(INDEX); + case SEARCH -> Optional.of(SEARCH); + default -> Optional.empty(); + }; + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelProvider.java new file mode 100644 index 0000000000..501dc5a583 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelProvider.java @@ -0,0 +1,57 @@ +package io.stargate.sgv2.jsonapi.service.provider; + +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +/** + * Identifier for a Model Provider. + * + *

The list here needs to sync with the list used in the yaml config files; this is the + * canonical list of Model Providers that we know about. aaron 17 june 2025 - This used to be a + * series of string consts in a class called ProviderConstants + */ +public enum ModelProvider { + AZURE_OPENAI("azureOpenAI"), + BEDROCK("bedrock"), + COHERE("cohere"), + CUSTOM("custom"), + HUGGINGFACE("huggingface"), + HUGGINGFACE_DEDICATED("huggingfaceDedicated"), + HUGGINGFACE_DEDICATED_DEFINED_MODEL("endpoint-defined-model"), + JINA_AI("jinaAI"), + MISTRAL("mistral"), + NVIDIA("nvidia"), + OPENAI("openai"), + UPSTAGE_AI("upstageAI"), + VERTEXAI("vertexai"), + VOYAGE_AI("voyageAI"); + + private static final Map API_NAME_TO_PROVIDER; + + static { + API_NAME_TO_PROVIDER = new HashMap<>(); + for (ModelProvider provider : ModelProvider.values()) { + API_NAME_TO_PROVIDER.put(provider.apiName(), provider); + } + } + + private final String apiName; + + ModelProvider(String apiName) { + this.apiName = apiName; + } + + public String apiName() { + return apiName; + } + + public static Optional fromApiName(String apiName) { + return Optional.ofNullable(API_NAME_TO_PROVIDER.get(apiName)); + } + + @Override + public String toString() { + return apiName; + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelType.java b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelType.java new file mode 100644 index 0000000000..87348a37ca --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelType.java @@ -0,0 +1,26 @@ +package io.stargate.sgv2.jsonapi.service.provider; + +import io.stargate.embedding.gateway.EmbeddingGateway; +import java.util.Optional; + +/** + * The type of model that was used, such as embedding or reranking. + * + *

Keeps in parity with the grpc proto definition in embedding_gateway.proto + */ +public enum ModelType { + /** The model type is not specified, for parity with grpc */ + MODEL_TYPE_UNSPECIFIED, + EMBEDDING, + RERANKING; + + public static Optional fromEmbeddingGateway( + EmbeddingGateway.ModelUsage.ModelType modelType) { + return switch (modelType) { + case MODEL_TYPE_UNSPECIFIED -> Optional.of(MODEL_TYPE_UNSPECIFIED); + case EMBEDDING -> Optional.of(EMBEDDING); + case RERANKING -> Optional.of(RERANKING); + default -> Optional.empty(); + }; + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelUsage.java b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelUsage.java new file mode 100644 index 0000000000..8217ab52b4 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelUsage.java @@ -0,0 +1,218 @@ +package io.stargate.sgv2.jsonapi.service.provider; + +import io.stargate.embedding.gateway.EmbeddingGateway; +import io.stargate.sgv2.jsonapi.util.recordable.PrettyPrintable; +import io.stargate.sgv2.jsonapi.util.recordable.Recordable; +import java.util.Objects; + +/** + * Usage of a model, any model, recorded for billing or metrics purposes. + * + *

When doing batching, create one instance and then use {@link #merge(ModelUsage)} to combine. + * Note that the durations are added , use the batchCount to get average duration. + */ +public final class ModelUsage implements Recordable { + + private final ModelProvider modelProvider; + private final ModelType modelType; + private final String modelName; + private final String tenantId; + private final ModelInputType inputType; + private final int promptTokens; + private final int totalTokens; + private final int requestBytes; + private final int responseBytes; + private final long durationNanos; + private final int batchCount; + + public ModelUsage( + ModelProvider modelProvider, + ModelType modelType, + String modelName, + String tenantId, + ModelInputType inputType, + int promptTokens, + int totalTokens, + int requestBytes, + int responseBytes, + long durationNanos) { + this( + modelProvider, + modelType, + modelName, + tenantId, + inputType, + promptTokens, + totalTokens, + requestBytes, + responseBytes, + durationNanos, + 1); + } + + private ModelUsage( + ModelProvider modelProvider, + ModelType modelType, + String modelName, + String tenantId, + ModelInputType inputType, + int promptTokens, + int totalTokens, + int requestBytes, + int responseBytes, + long durationNanos, + int batchCount) { + this.modelProvider = Objects.requireNonNull(modelProvider, "modelProvider must not be null"); + this.modelType = Objects.requireNonNull(modelType, "modelType must not be null"); + this.modelName = Objects.requireNonNull(modelName, "modelName must not be null"); + this.tenantId = Objects.requireNonNull(tenantId, "tenantId must not be null"); + this.inputType = Objects.requireNonNull(inputType, "inputType must not be null"); + if (promptTokens < 0) { + throw new IllegalArgumentException("promptTokens must not be negative"); + } + this.promptTokens = promptTokens; + if (totalTokens < 0) { + throw new IllegalArgumentException("totalTokens must not be negative"); + } + 
this.totalTokens = totalTokens; + if (requestBytes < 0) { + throw new IllegalArgumentException("requestBytes must not be negative"); + } + this.requestBytes = requestBytes; + if (responseBytes < 0) { + throw new IllegalArgumentException("responseBytes must not be negative"); + } + this.responseBytes = responseBytes; + if (durationNanos < 0) { + throw new IllegalArgumentException("durationNanos must not be negative"); + } + this.durationNanos = durationNanos; + if (batchCount < 1) { + throw new IllegalArgumentException("batchCount must be at least 1"); + } + this.batchCount = batchCount; + } + + /** Create a ModelUsage from an EmbeddingGateway.ModelUsage. grpc object */ + public static ModelUsage fromEmbeddingGateway(EmbeddingGateway.ModelUsage grpcModelUsage) { + + return new ModelUsage( + ModelProvider.fromApiName(grpcModelUsage.getModelProvider()) + .orElseThrow( + () -> + new IllegalArgumentException( + "ModelUsage() - Unknown grpcModelUsage.getModelProvider(): '%s'" + .formatted(grpcModelUsage.getModelProvider()))), + ModelType.fromEmbeddingGateway(grpcModelUsage.getModelType()) + .orElseThrow( + () -> + new IllegalArgumentException( + "ModelUsage() - Unknown grpcModelUsage.getModelType(): '%s'" + .formatted(grpcModelUsage.getModelType()))), + grpcModelUsage.getModelName(), + grpcModelUsage.getTenantId(), + ModelInputType.fromEmbeddingGateway(grpcModelUsage.getInputType()) + .orElseThrow( + () -> + new IllegalArgumentException( + "Unknown Embedding Gateway modelInputType: " + + grpcModelUsage.getInputType())), + grpcModelUsage.getPromptTokens(), + grpcModelUsage.getTotalTokens(), + grpcModelUsage.getRequestBytes(), + grpcModelUsage.getResponseBytes(), + grpcModelUsage.getCallDurationNanos()); + } + + /** + * Creates a new model usage that merges this and the other usage, to combine after batching. + * + * @return A new ModelUsage instance that combines the properties of this and the other usage. 
+ */ + public ModelUsage merge(ModelUsage other) { + + Objects.requireNonNull(other, "other must not be null"); + if (!this.modelProvider.equals(other.modelProvider) + || !this.modelType.equals(other.modelType) + || !this.modelName.equals(other.modelName) + || !this.tenantId.equals(other.tenantId) + || !this.inputType.equals(other.inputType)) { + throw new IllegalArgumentException( + "Cannot merge ModelUsage with different properties, this: %s, other: %s" + .formatted(PrettyPrintable.print(this), PrettyPrintable.print(other))); + } + + return new ModelUsage( + this.modelProvider, + this.modelType, + this.modelName, + this.tenantId, + this.inputType, + this.promptTokens + other.promptTokens, + this.totalTokens + other.totalTokens, + this.requestBytes + other.requestBytes, + this.responseBytes + other.responseBytes, + this.durationNanos + other.durationNanos, + this.batchCount + other.batchCount); + } + + public ModelProvider modelProvider() { + return modelProvider; + } + + public ModelType modelType() { + return modelType; + } + + public String modelName() { + return modelName; + } + + public String tenantId() { + return tenantId; + } + + public ModelInputType inputType() { + return inputType; + } + + public int promptTokens() { + return promptTokens; + } + + public int totalTokens() { + return totalTokens; + } + + public int requestBytes() { + return requestBytes; + } + + public int responseBytes() { + return responseBytes; + } + + public long durationNanos() { + return durationNanos; + } + + public int batchCount() { + return batchCount; + } + + @Override + public DataRecorder recordTo(DataRecorder dataRecorder) { + return dataRecorder + .append("modelProvider", modelProvider) + .append("modelType", modelType) + .append("modelName", modelName) + .append("tenantId", tenantId) + .append("inputType", inputType) + .append("promptTokens", promptTokens) + .append("totalTokens", totalTokens) + .append("requestBytes", requestBytes) + .append("responseBytes", 
responseBytes) + .append("durationNanos", durationNanos) + .append("batchCount", batchCount); + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderBase.java b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderBase.java new file mode 100644 index 0000000000..1487fad15b --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderBase.java @@ -0,0 +1,332 @@ +package io.stargate.sgv2.jsonapi.service.provider; + +import com.fasterxml.jackson.databind.JsonNode; +import io.smallrye.mutiny.Uni; +import io.stargate.embedding.gateway.EmbeddingGateway; +import io.stargate.sgv2.jsonapi.exception.SchemaException; +import jakarta.ws.rs.WebApplicationException; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; +import java.time.Duration; +import java.util.Map; +import java.util.concurrent.TimeoutException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Base for model providers of any model type, such as embedding and reranking. + * + *

Notes for implementors: + * + *

    + *
  • Define the rest easy client to return a {@link Response} and call {@link + * #retryHTTPCall(Uni)} to manage retries and backoff. + *
  • Once you have the response, call {@link #decodeResponse(Response, Class)} to decode the + * response to a specific class. + *
+ * + *

. The Embedding and Rerank code *does not* share model configs, but they can & should do, so + * we cannot pass the model into the base until we refactor the code. That is why there are + * properties that could be removed or refactored if we had more common config. + */ +public abstract class ProviderBase { + protected static final Logger LOGGER = LoggerFactory.getLogger(ProviderBase.class); + + private final ModelProvider modelProvider; + private final ModelType modelType; + + protected ProviderBase(ModelProvider modelProvider, ModelType modelType) { + this.modelProvider = modelProvider; + this.modelType = modelType; + } + + public ModelProvider modelProvider() { + return modelProvider; + } + + public abstract String modelName(); + + public abstract ApiModelSupport modelSupport(); + + /** + * Called to map the HTTP response to an API exception, subclasses should override this method to + * provide a specific mapping for the model provider. + * + *

This method is called after the error message is extracted from the response. + * + * @param response The raw HTTP response from the model provider + * @param errorMessage The error message extracted from the response + * @return The mapped exception to later throw, should not return null. + */ + protected abstract RuntimeException mapHTTPError(Response response, String errorMessage); + + /** + * Last in the chain to extract the error message from the response JSON, see {@link + * #responseErrorMessage(Response)} + * + *

+ * + * @return JSON Pointer path that will be used with {@link JsonNode#at(String)} to extract the + * error message node. + */ + protected abstract String errorMessageJsonPtr(); + + protected abstract Duration initialBackOffDuration(); + + protected abstract Duration maxBackOffDuration(); + + protected abstract double jitter(); + + protected abstract int atMostRetries(); + + /** + * Retries the HTTP call with backoff and jitter, and translates the response to a API exception. + */ + protected Uni retryHTTPCall(Uni uni) { + + return uni + // Catch *any* web exception from jakarta rest client + .onFailure(WebApplicationException.class) + // and recover with the jakarta response, so we can translate to API exception + .recoverWithItem(ex -> ((WebApplicationException) ex).getResponse()) + .onItem() + // handle the response, throws if there is an error + .transform(this::handleHTTPResponse) + // decide if we want to retry + .onFailure(this::decideRetry) + .retry() + .withBackOff(initialBackOffDuration(), maxBackOffDuration()) + .withJitter(jitter()) + .atMost(atMostRetries()); + } + + /** + * Called to determine if the operation should be retried based on the throwable. + * + *

Subclasses should normally override, and then call the base if they do not want to retry. + * + * @param throwable Exception, either the API Exception mapped from the jakarta response. Or any + * other error if the rest client throws non WebApplicationException + * @return true if the operation should be retried, false otherwise. + */ + protected boolean decideRetry(Throwable throwable) { + return throwable instanceof TimeoutException; + } + + /** + * Called to process the HTTP response from the model provider, called for both successful and + * error responses. This function determines if the response is an error. + * + *

Implementations should throw any exceptions created from the response + * + * @param jakartaResponse Raw HTTP response from the model provider, which may be an error + * response. + * @return The original response if it is successful, or throws an exception if it is an error. + */ + protected Response handleHTTPResponse(Response jakartaResponse) { + + if (LOGGER.isTraceEnabled()) { + LOGGER.trace( + "handleHTTPResponse() - got response, modelProvider: {}, modelName: {}, response.status: {}, response.headers: {}", + modelProvider(), + modelName(), + jakartaResponse.getStatus(), + jakartaResponse.getHeaders()); + } + + if (jakartaResponse.getStatus() >= 400) { + var runtimeException = mapHTTPError(jakartaResponse); + if (runtimeException != null) { + if (LOGGER.isTraceEnabled()) { + LOGGER.trace( + "handleHTTPResponse() - http response mapped to error, runtimeException: {}", + runtimeException.toString()); + } + throw runtimeException; + } + + throw new IllegalStateException( + String.format( + "Unhandled error from model provider, modelProvider: %s, modelName: %s, status: %d, responseBody: %s", + modelProvider(), + modelName(), + jakartaResponse.getStatus(), + jakartaResponse.readEntity(String.class))); + } + return jakartaResponse; + } + + /** + * Called to map the HTTP response to an API exception, subclasses should override {@link + * #mapHTTPError(Response, String)} which is called after the error message is extracted from the + * response. + * + *

Should only be called when the response status is >= 400, i.e. an error response. + * + * @param jakartaResponse The raw HTTP response from the model provider + * @return The mapped exception to later throw, should not return null. + */ + protected RuntimeException mapHTTPError(Response jakartaResponse) { + + var errorMessage = responseErrorMessage(jakartaResponse); + // this is the main "error" log when the response is an error + LOGGER.error( + "Error response from model provider, modelProvider: {}, modelName:{}, http.status: {}, error: {}", + modelProvider, + modelName(), + jakartaResponse.getStatus(), + errorMessage); + + var mappedException = mapHTTPError(jakartaResponse, errorMessage); + if (mappedException != null) { + return mappedException; + } + + return new IllegalStateException( + String.format( + "Unhandled error from model provider, modelProvider: %s, modelName: %s, status: %d, responseBody: %s", + modelProvider(), + modelName(), + jakartaResponse.getStatus(), + jakartaResponse.readEntity(String.class))); + } + + /** + * First in the chain to extract the error message from the response, the easiest thing for + * subclasses is to override {@link #errorMessageJsonPtr()} to provide the JSON Pointer to get a + * single error message from the response JSON. + * + *

This method decodes the JSON response and calls {@link #responseErrorMessage(JsonNode)} + */ + protected String responseErrorMessage(Response jakartaResponse) { + + MediaType contentType = jakartaResponse.getMediaType(); + String raw = jakartaResponse.readEntity(String.class); + + if (contentType == null || !MediaType.APPLICATION_JSON_TYPE.isCompatible(contentType)) { + LOGGER.error( + "Non-JSON error response from model provider, modelProvider:{}, modelName: {}, raw:{}", + modelProvider(), + modelName(), + raw); + return raw; + } + + JsonNode rootNode = null; + try { + rootNode = jakartaResponse.readEntity(JsonNode.class); + } catch (Exception e) { + // If we cannot read the response as JsonNode, log the error and return the raw response + LOGGER.error( + "Error parsing error JSON from reranking provider, modelProvider: {}, modelName: {}", + modelProvider, + modelName(), + e); + } + + return (rootNode == null) ? raw : responseErrorMessage(rootNode); + } + + /** + * Second in the chain to extract the error message from the response JSON, this is called with + * the decoded JSON node. The easiest thing for subclasses is to override {@link + * #errorMessageJsonPtr()} + */ + protected String responseErrorMessage(JsonNode rootNode) { + + var messageNode = rootNode.at(errorMessageJsonPtr()); + return messageNode.isMissingNode() ? rootNode.toString() : messageNode.toString(); + } + + /** + * Utility method to decode the response (JSON entity) to a specific class, and log if there is an + * error. + * + *

Because decoding happens after all the retries, it will not be mapped into an API exception + * through the same process as making the HTTP calls. + */ + protected T decodeResponse(Response jakartaResponse, Class responseClass) { + try { + return jakartaResponse.readEntity(responseClass); + } catch (Throwable e) { + LOGGER.error( + "decodeResponse() - error decoding response modelProvider: {}, modelName: {}, responseClass: {}", + modelProvider(), + modelName(), + responseClass.getName(), + e); + // rethrow so it can be handled elsewhere, we just want to log the error + throw e; + } + } + + /** + * Checks if the vectorization will use an END_OF_LIFE model and throws an exception if it is. + * + *

As part of embedding model deprecation ability, any read and write with vectorization in an + * END_OF_LIFE model will throw an exception. + * + *

Note, SUPPORTED and DEPRECATED models are still allowed to be used in read and write. + * + *

This method should be called before any vectorization operation. + */ + protected void checkEOLModelUsage() { + + if (modelSupport().status() == ApiModelSupport.SupportStatus.END_OF_LIFE) { + throw SchemaException.Code.END_OF_LIFE_AI_MODEL.get( + Map.of( + "model", + modelName(), + "modelStatus", + modelSupport().status().name(), + "message", + modelSupport() + .message() + .orElse("The model is no longer supported (reached its end-of-life)."))); + } + } + + protected ModelUsage createModelUsage( + String tenantId, + ModelInputType modelInputType, + int promptTokens, + int totalTokens, + int requestBytes, + int responseBytes, + long durationNanos) { + + return new ModelUsage( + modelProvider, + modelType, + modelName(), + tenantId, + modelInputType, + promptTokens, + totalTokens, + requestBytes, + responseBytes, + durationNanos); + } + + protected ModelUsage createModelUsage( + String tenantId, + ModelInputType modelInputType, + int promptTokens, + int totalTokens, + Response jakartaResponse, + long durationNanos) { + + return createModelUsage( + tenantId, + modelInputType, + promptTokens, + totalTokens, + ProviderHttpInterceptor.getSentBytes(jakartaResponse), + ProviderHttpInterceptor.getReceivedBytes(jakartaResponse), + durationNanos); + } + + protected ModelUsage createModelUsage(EmbeddingGateway.ModelUsage gatewayModelUsage) { + return ModelUsage.fromEmbeddingGateway(gatewayModelUsage); + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderHttpInterceptor.java b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderHttpInterceptor.java new file mode 100644 index 0000000000..2a0fa8f0d1 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderHttpInterceptor.java @@ -0,0 +1,111 @@ +package io.stargate.sgv2.jsonapi.service.provider; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.io.CountingOutputStream; +import jakarta.ws.rs.client.ClientRequestContext; +import 
jakarta.ws.rs.client.ClientResponseContext; +import jakarta.ws.rs.client.ClientResponseFilter; +import jakarta.ws.rs.core.Response; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.OutputStream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This class is to track the usage at the http request level to the embedding or reranking provider + * model service. + * + *

E.G. When a providerClient registered the interceptor + * as @RegisterProvider(ProviderHttpInterceptor.class), the interceptor will intercept the http + * request and response, then add the sent-bytes and received-bytes to the response headers in the + * response context. + * + *

Note, if provider already returned content-length in the response header, then the interceptor + * will reuse it and won't calculate the response size. + */ +public class ProviderHttpInterceptor implements ClientResponseFilter { + + private static final Logger LOGGER = LoggerFactory.getLogger(ProviderHttpInterceptor.class); + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + /** Header name to track the sent_bytes to the provider (extra detailed to avoid collisions) */ + private static final String SENT_BYTES_HEADER = "data-api-model-usage-sent-bytes"; + + /** + * Header name to track the received_bytes from the provider (extra detailed to avoid collisions) + */ + private static final String RECEIVED_BYTES_HEADER = "data-api-model-usage-received-bytes"; + + @Override + public void filter(ClientRequestContext requestContext, ClientResponseContext responseContext) + throws IOException { + + long receivedBytes = 0; + long sentBytes = 0; + + if (LOGGER.isTraceEnabled()) { + LOGGER.trace( + "ProviderHttpInterceptor.filter() - requestContext.getUri(): {}, requestContext.getHeaders(): {}", + requestContext.getUri(), + requestContext.getStringHeaders()); + + LOGGER.trace( + "ProviderHttpInterceptor.filter() - responseContext.getStatus(): {}, responseContext.getHeaders(): {}", + responseContext.getStatus(), + responseContext.getHeaders()); + } + + // Parse the request entity stream to measure its size. + if (requestContext.hasEntity()) { + try (var cus = new CountingOutputStream(OutputStream.nullOutputStream())) { + OBJECT_MAPPER.writeValue(cus, requestContext.getEntity()); + sentBytes = cus.getCount(); + + } catch (Exception e) { + if (LOGGER.isWarnEnabled()) { + LOGGER.warn("Failed to measure request body size.", e); + } + } + } + + // Use the content-length if present, otherwise parse the response entity stream to measure its + // size. 
+ if (responseContext.hasEntity()) { + receivedBytes = responseContext.getLength(); + + // if provider does not return content-length in the response header. + if (receivedBytes <= 0) { + // IMPORTANT - need to reset the entity stream so it can be read again, we have not + // decoded this into objects yet. + byte[] body = responseContext.getEntityStream().readAllBytes(); + receivedBytes = body.length; + responseContext.setEntityStream(new ByteArrayInputStream(body)); + } + } + + responseContext.getHeaders().add(SENT_BYTES_HEADER, String.valueOf(sentBytes)); + responseContext.getHeaders().add(RECEIVED_BYTES_HEADER, String.valueOf(receivedBytes)); + } + + public static int getSentBytes(Response jakartaResponse) { + return getHeaderInt(jakartaResponse, SENT_BYTES_HEADER); + } + + public static int getReceivedBytes(Response jakartaResponse) { + return getHeaderInt(jakartaResponse, RECEIVED_BYTES_HEADER); + } + + private static int getHeaderInt(Response jakartaResponse, String headerName) { + + var headerString = jakartaResponse.getHeaderString(headerName); + if (headerString != null && !headerString.isBlank()) { + try { + return Integer.parseInt(headerString); + } catch (NumberFormatException e) { + LOGGER.warn("Failed to parse headerName:{}, headerString:{}", headerName, headerString, e); + } + } + return 0; + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/gateway/RerankingEGWClient.java b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/gateway/RerankingEGWClient.java index 2beb42afe1..e9857049c5 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/gateway/RerankingEGWClient.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/gateway/RerankingEGWClient.java @@ -8,6 +8,7 @@ import io.stargate.sgv2.jsonapi.api.request.RerankingCredentials; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.exception.JsonApiException; +import 
io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import io.stargate.sgv2.jsonapi.service.reranking.configuration.RerankingProvidersConfig; import io.stargate.sgv2.jsonapi.service.reranking.operation.RerankingProvider; import java.util.*; @@ -18,82 +19,73 @@ public class RerankingEGWClient extends RerankingProvider { private static final String DEFAULT_TENANT_ID = "default"; - /** - * This string acts as key of authTokens map, for passing Data API token to EGW in grpc request. - */ + /** Key of authTokens map, for passing Data API token to EGW in grpc request. */ private static final String DATA_API_TOKEN = "DATA_API_TOKEN"; - /** - * This string acts as key of authTokens map, for passing Reranking API key to EGW in grpc - * request. - */ + /** Key in the authTokens map, for passing Reranking API key to EGW in grpc request. */ private static final String RERANKING_API_KEY = "RERANKING_API_KEY"; - private final String provider; private final Optional tenant; private final Optional authToken; - private final String modelName; - private final RerankingService rerankingGrpcService; + private final RerankingService grpcGatewayService; Map authentication; private final String commandName; public RerankingEGWClient( - String baseUrl, - RerankingProvidersConfig.RerankingProviderConfig.ModelConfig.RequestProperties - requestProperties, - String provider, + ModelProvider modelProvider, + RerankingProvidersConfig.RerankingProviderConfig.ModelConfig modelConfig, Optional tenant, Optional authToken, - String modelName, - RerankingService rerankingGrpcService, + RerankingService grpcGatewayService, Map authentication, String commandName) { - super(baseUrl, modelName, requestProperties); - this.provider = provider; + super(modelProvider, modelConfig); + this.tenant = tenant; this.authToken = authToken; - this.modelName = modelName; - this.rerankingGrpcService = rerankingGrpcService; + this.grpcGatewayService = grpcGatewayService; this.authentication = authentication; 
this.commandName = commandName; } @Override - public Uni rerank( + protected String errorMessageJsonPtr() { + // not used here, we are just passing through. + return ""; + } + + @Override + public Uni rerank( int batchId, String query, List passages, RerankingCredentials rerankingCredentials) { - // Build the reranking provider request in grpc request - final EmbeddingGateway.ProviderRerankingRequest.RerankingRequest rerankingRequest = + var gatewayReranking = EmbeddingGateway.ProviderRerankingRequest.RerankingRequest.newBuilder() - .setModelName(modelName) + .setModelName(modelName()) .setQuery(query) .addAllPassages(passages) + // TODO: Why is the command name passed here ? Can it be removed ? .setCommandName(commandName) .build(); - // Build the reranking provider context in grpc request var contextBuilder = EmbeddingGateway.ProviderRerankingRequest.ProviderContext.newBuilder() - .setProviderName(provider) + .setProviderName(modelProvider().apiName()) .setTenantId(tenant.orElse(DEFAULT_TENANT_ID)) .putAuthTokens(DATA_API_TOKEN, authToken.orElse("")); + rerankingCredentials + .apiKey() + .ifPresent(v -> contextBuilder.putAuthTokens(RERANKING_API_KEY, v)); - if (rerankingCredentials.apiKey().isPresent()) { - contextBuilder.putAuthTokens(RERANKING_API_KEY, rerankingCredentials.apiKey().get()); - } - final EmbeddingGateway.ProviderRerankingRequest.ProviderContext providerContext = - contextBuilder.build(); - - // Built the Grpc request - final EmbeddingGateway.ProviderRerankingRequest grpcRerankingRequest = + var gatewayRequest = EmbeddingGateway.ProviderRerankingRequest.newBuilder() - .setRerankingRequest(rerankingRequest) - .setProviderContext(providerContext) + .setRerankingRequest(gatewayReranking) + .setProviderContext(contextBuilder.build()) .build(); - Uni grpcRerankingResponse; + // TODO: XXX Why is this error handling here not part of the uni pipeline? 
+ Uni gatewayRerankingUni; try { - grpcRerankingResponse = rerankingGrpcService.rerank(grpcRerankingRequest); + gatewayRerankingUni = grpcGatewayService.rerank(gatewayRequest); } catch (StatusRuntimeException e) { if (e.getStatus().getCode().equals(Status.Code.DEADLINE_EXCEEDED)) { throw ErrorCodeV1.RERANKING_PROVIDER_TIMEOUT.toApiException(e, e.getMessage()); @@ -101,21 +93,23 @@ public Uni rerank( throw e; } - return grpcRerankingResponse + return gatewayRerankingUni .onItem() .transform( - resp -> { - if (resp.hasError()) { + gatewayResponse -> { + if (gatewayResponse.hasError()) { + // TODO : move to V2 error throw new JsonApiException( - ErrorCodeV1.valueOf(resp.getError().getErrorCode()), - resp.getError().getErrorMessage()); + ErrorCodeV1.valueOf(gatewayResponse.getError().getErrorCode()), + gatewayResponse.getError().getErrorMessage()); } - return RerankingBatchResponse.of( + + return new BatchedRerankingResponse( batchId, - resp.getRanksList().stream() + gatewayResponse.getRanksList().stream() .map(rank -> new Rank(rank.getIndex(), rank.getScore())) .collect(Collectors.toList()), - new Usage(resp.getUsage().getPromptTokens(), resp.getUsage().getTotalTokens())); + createModelUsage(gatewayResponse.getModelUsage())); }); } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/NvidiaRerankingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/NvidiaRerankingProvider.java index a22ef6173d..9137815802 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/NvidiaRerankingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/NvidiaRerankingProvider.java @@ -1,21 +1,21 @@ package io.stargate.sgv2.jsonapi.service.reranking.operation; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.databind.JsonNode; -import io.quarkus.rest.client.reactive.ClientExceptionMapper; import 
io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.RerankingCredentials; import io.stargate.sgv2.jsonapi.config.constants.HttpConstants; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; -import io.stargate.sgv2.jsonapi.service.embedding.operation.error.RerankingResponseErrorMessageMapper; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor; import io.stargate.sgv2.jsonapi.service.reranking.configuration.RerankingProviderResponseValidation; import io.stargate.sgv2.jsonapi.service.reranking.configuration.RerankingProvidersConfig; import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.POST; import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.*; import java.util.concurrent.TimeUnit; @@ -59,109 +59,125 @@ */ public class NvidiaRerankingProvider extends RerankingProvider { - private static final String providerId = ProviderConstants.NVIDIA; - private final NvidiaRerankingClient nvidiaRerankingClient; - - // Nvidia Reranking Service supports truncate or error when the passage is too long. - // Data API use NONE as default, means the reranking request will error out if there is a query - // and - // passage pair that exceeds allowed token size 8192 - // https://docs.nvidia.com/nim/nemo-retriever/text-reranking/latest/using-reranking.html#token-limits-truncation + private final NvidiaRerankingClient nvidiaClient; + + /** + * Nvidia Reranking Service supports truncation or error behavior when the passage is too long. + * + *

The Data API uses {@code NONE} as the default, which means the reranking request will error + * out if there is a query and passage pair that exceeds the allowed token size of 8192. + * + *

See: + * https://docs.nvidia.com/nim/nemo-retriever/text-reranking/latest/using-reranking.html#token-limits-truncation + */ private static final String TRUNCATE_PASSAGE = "NONE"; public NvidiaRerankingProvider( - String baseUrl, - String modelName, - RerankingProvidersConfig.RerankingProviderConfig.ModelConfig.RequestProperties - requestProperties) { - super(baseUrl, modelName, requestProperties); - nvidiaRerankingClient = + RerankingProvidersConfig.RerankingProviderConfig.ModelConfig modelConfig) { + super(ModelProvider.NVIDIA, modelConfig); + + nvidiaClient = QuarkusRestClientBuilder.newBuilder() - .baseUri(URI.create(baseUrl)) - .readTimeout(requestProperties.readTimeoutMillis(), TimeUnit.MILLISECONDS) + .baseUri(URI.create(modelConfig.url())) + .readTimeout(modelConfig.properties().readTimeoutMillis(), TimeUnit.MILLISECONDS) .build(NvidiaRerankingClient.class); } + @Override + protected String errorMessageJsonPtr() { + return "/message"; + } + + @Override + protected Uni rerank( + int batchId, String query, List passages, RerankingCredentials rerankingCredentials) { + + // TODO: Move error to v2 + var accessToken = + rerankingCredentials + .apiKey() + .map(apiKey -> HttpConstants.BEARER_PREFIX_FOR_API_KEY + apiKey) + .orElseThrow( + () -> + ErrorCodeV1.RERANKING_PROVIDER_AUTHENTICATION_KEYS_NOT_PROVIDED.toApiException( + "In order to rerank, please provide the reranking API key.")); + + var nvidiaRequest = + new NvidiaRerankingRequest( + modelName(), + new NvidiaRerankingRequest.TextWrapper(query), + passages.stream().map(NvidiaRerankingRequest.TextWrapper::new).toList(), + TRUNCATE_PASSAGE); + + final long callStartNano = System.nanoTime(); + return retryHTTPCall(nvidiaClient.rerank(accessToken, nvidiaRequest)) + .onItem() + .transform( + jakartaResponse -> { + var nvidiaResponse = decodeResponse(jakartaResponse, NvidiaRerankingResponse.class); + long callDurationNano = System.nanoTime() - callStartNano; + + // converting from the specific Nvidia response to 
the generic BatchedRerankingResponse + var ranks = + nvidiaResponse.rankings().stream() + .map(rank -> new Rank(rank.index(), rank.logit())) + .toList(); + + var modelUsage = + createModelUsage( + rerankingCredentials.tenantId(), + ModelInputType.INPUT_TYPE_UNSPECIFIED, + nvidiaResponse.usage().prompt_tokens, + nvidiaResponse.usage().total_tokens, + jakartaResponse, + callDurationNano); + return new BatchedRerankingResponse(batchId, ranks, modelUsage); + }); + } + + /** + * REST client interface for the Nvidia Reranking Service. + * + *

.. + */ @RegisterRestClient @RegisterProvider(RerankingProviderResponseValidation.class) + @RegisterProvider(ProviderHttpInterceptor.class) public interface NvidiaRerankingClient { @POST @ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON) - Uni rerank( - @HeaderParam("Authorization") String accessToken, RerankingRequest request); - - @ClientExceptionMapper - static RuntimeException mapException(jakarta.ws.rs.core.Response response) { - String errorMessage = getErrorMessage(response); - return RerankingResponseErrorMessageMapper.mapToAPIException( - providerId, response, errorMessage); - } - - private static String getErrorMessage(jakarta.ws.rs.core.Response response) { - // Get the whole response body - JsonNode rootNode = response.readEntity(JsonNode.class); - // Log the response body - logger.error( - "Error response from reranking provider '{}': {}", providerId, rootNode.toString()); - JsonNode messageNode = rootNode.path("message"); - // Return the text of the "message" node, or the whole response body if it is missing - return messageNode.isMissingNode() ? rootNode.toString() : messageNode.toString(); - } + Uni rerank( + @HeaderParam("Authorization") String accessToken, NvidiaRerankingRequest request); } - /** reranking request to the Nvidia Reranking Service */ - private record RerankingRequest( + /** + * Request structure of the NVIDIA REST service. + * + *

.. + */ + public record NvidiaRerankingRequest( String model, TextWrapper query, List passages, String truncate) { + /** * query and passage string needs to be wrapped in with text key for request to the Nvidia * Reranking Service. E.G. { "text": "which way should i go?" } */ - private record TextWrapper(String text) {} + record TextWrapper(String text) {} } - /** reranking response from the Nvidia reranking Service */ + /** + * Response structure of the NVIDIA REST service. + * + *

.. + */ @JsonIgnoreProperties(ignoreUnknown = true) - private record RerankingResponse(List rankings, Usage usage) { - @JsonIgnoreProperties(ignoreUnknown = true) - private record Ranking(int index, float logit) {} + record NvidiaRerankingResponse(List rankings, NvidiaUsage usage) { @JsonIgnoreProperties(ignoreUnknown = true) - private record Usage(int prompt_tokens, int total_tokens) {} - } - - @Override - public Uni rerank( - int batchId, String query, List passages, RerankingCredentials rerankingCredentials) { - - RerankingRequest request = - new RerankingRequest( - modelName, - new RerankingRequest.TextWrapper(query), - passages.stream().map(RerankingRequest.TextWrapper::new).toList(), - TRUNCATE_PASSAGE); + record NvidiaRanking(int index, float logit) {} - if (rerankingCredentials.apiKey().isEmpty()) { - throw ErrorCodeV1.RERANKING_PROVIDER_AUTHENTICATION_KEYS_NOT_PROVIDED.toApiException( - "In order to rerank, please provide the reranking API key."); - } - - Uni response = - applyRetry( - nvidiaRerankingClient.rerank( - HttpConstants.BEARER_PREFIX_FOR_API_KEY + rerankingCredentials.apiKey().get(), - request)); - - return response - .onItem() - .transform( - resp -> { - List ranks = - resp.rankings().stream() - .map(rank -> new Rank(rank.index(), rank.logit())) - .toList(); - Usage usage = new Usage(resp.usage().prompt_tokens(), resp.usage().total_tokens()); - return RerankingBatchResponse.of(batchId, ranks, usage); - }); + @JsonIgnoreProperties(ignoreUnknown = true) + record NvidiaUsage(int prompt_tokens, int total_tokens) {} } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProvider.java b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProvider.java index db870d2bca..65dcdf6208 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProvider.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProvider.java @@ -1,61 +1,84 
@@ package io.stargate.sgv2.jsonapi.service.reranking.operation; +import static jakarta.ws.rs.core.Response.Status.Family.CLIENT_ERROR; + import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.api.request.RerankingCredentials; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.exception.JsonApiException; +import io.stargate.sgv2.jsonapi.service.provider.*; import io.stargate.sgv2.jsonapi.service.reranking.configuration.RerankingProvidersConfig; +import jakarta.ws.rs.core.Response; import java.time.Duration; import java.util.ArrayList; import java.util.Comparator; import java.util.List; -import java.util.concurrent.TimeoutException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public abstract class RerankingProvider { - protected static final Logger logger = LoggerFactory.getLogger(RerankingProvider.class); - protected final String baseUrl; - protected final String modelName; - protected final RerankingProvidersConfig.RerankingProviderConfig.ModelConfig.RequestProperties - requestProperties; +/** A provider for reranking models, using {@link ModelType#RERANKING} */ +public abstract class RerankingProvider extends ProviderBase { + + protected static final Logger LOGGER = LoggerFactory.getLogger(RerankingProvider.class); + + protected final RerankingProvidersConfig.RerankingProviderConfig.ModelConfig modelConfig; + + protected final Duration initialBackOffDuration; + + protected final Duration maxBackOffDuration; protected RerankingProvider( - String baseUrl, - String modelName, - RerankingProvidersConfig.RerankingProviderConfig.ModelConfig.RequestProperties - requestProperties) { - this.baseUrl = baseUrl; - this.modelName = modelName; - this.requestProperties = requestProperties; + ModelProvider modelProvider, + RerankingProvidersConfig.RerankingProviderConfig.ModelConfig modelConfig) { + super(modelProvider, ModelType.RERANKING); + + this.modelConfig = modelConfig; + + this.initialBackOffDuration = + 
Duration.ofMillis(modelConfig.properties().initialBackOffMillis()); + this.maxBackOffDuration = Duration.ofMillis(modelConfig.properties().maxBackOffMillis()); } + @Override public String modelName() { - return modelName; + return modelConfig.name(); + } + + @Override + public ApiModelSupport modelSupport() { + return modelConfig.apiModelSupport(); + } + + public ModelUsage createEmptyModelUsage(RerankingCredentials rerankingCredentials) { + return createModelUsage( + rerankingCredentials.tenantId(), ModelInputType.INPUT_TYPE_UNSPECIFIED, 0, 0, 0, 0, 0); } /** - * Gather the results from all batch reranking calls, adjust the indices, so they refer to the - * original passages list, and return a final RerankingResponse as the original order of the - * passages with the reranking score. + * Reranks the texts, batching as needed, and returns a final RerankingResponse as the original + * order of the passages with the reranking score. * - *

E.G. if the original passages list is ["a", "b", "c", "d", "e"] and the micro batch is 2, - * then API will do 3 batch reranking calls: ["a", "b"], ["c", "d"], ["e"]. 3 response will be - * returned: + *

E.G. if the original passages list is ["a", "b", "c", "d", "e"] and the micro + * batch is 2, then the API will do 3 batch reranking calls: ["a", "b"], ["c", "d"], ["e"]. + * 3 responses will be returned: * *

    - *
  • batch 0: [{index:1, score:x1}, {index:0, score:x2}] - *
  • batch 1: [{index:0, score:x3}, {index:1, score:x4}] - *
  • batch 2: [{index:0, score:x5}] + *
  • batch 0: [{index:1, score:x1}, {index:0, score:x2}] + *
  • batch 1: [{index:0, score:x3}, {index:1, score:x4}] + *
  • batch 2: [{index:0, score:x5}] *
* - * Then this method will adjust the indices and return the final response: [{index:0, score:x1}, - * {index:1, score:x2}, {index:2, score:x3}, {index:3, score:x4}, {index:4, score:x5}] + * Then this method will adjust the indices and return the final response: + * [{index:0, score:x1}, + * {index:1, score:x2}, {index:2, score:x3}, {index:3, score:x4}, {index:4, score:x5}] */ public Uni rerank( String query, List passages, RerankingCredentials rerankingCredentials) { + + // TODO: what to do if passages is empty? List> passageBatches = createPassageBatches(passages); - List> batchRerankings = new ArrayList<>(); + List> batchRerankings = new ArrayList<>(); + for (int batchId = 0; batchId < passageBatches.size(); batchId++) { batchRerankings.add( rerank(batchId, query, passageBatches.get(batchId), rerankingCredentials)); @@ -64,67 +87,138 @@ public Uni rerank( return Uni.join().all(batchRerankings).andFailFast().map(this::aggregateRanks); } + /** + * Subclasses must implement to do the reranking, after the batching is done. + * + *

... + */ + protected abstract Uni rerank( + int batchId, String query, List passages, RerankingCredentials rerankingCredentials); + + @Override + protected Duration initialBackOffDuration() { + return initialBackOffDuration; + } + + @Override + protected Duration maxBackOffDuration() { + return maxBackOffDuration; + } + + @Override + protected double jitter() { + return modelConfig.properties().jitter(); + } + + @Override + protected int atMostRetries() { + return modelConfig.properties().atMostRetries(); + } + + @Override + protected boolean decideRetry(Throwable throwable) { + + var retry = + (throwable.getCause() instanceof JsonApiException jae + && jae.getErrorCode() == ErrorCodeV1.RERANKING_PROVIDER_TIMEOUT); + + return retry || super.decideRetry(throwable); + } + + @Override + protected RuntimeException mapHTTPError(Response jakartaResponse, String errorMessage) { + + // TODO: move to V2 errors + + if (jakartaResponse.getStatus() == Response.Status.REQUEST_TIMEOUT.getStatusCode() + || jakartaResponse.getStatus() == Response.Status.GATEWAY_TIMEOUT.getStatusCode()) { + + return ErrorCodeV1.RERANKING_PROVIDER_TIMEOUT.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), jakartaResponse.getStatus(), errorMessage); + } + + if (jakartaResponse.getStatus() == Response.Status.TOO_MANY_REQUESTS.getStatusCode()) { + + return ErrorCodeV1.RERANKING_PROVIDER_RATE_LIMITED.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), jakartaResponse.getStatus(), errorMessage); + } + + if (jakartaResponse.getStatusInfo().getFamily() == CLIENT_ERROR) { + + return ErrorCodeV1.RERANKING_PROVIDER_CLIENT_ERROR.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), jakartaResponse.getStatus(), errorMessage); + } + + if (jakartaResponse.getStatusInfo().getFamily() == Response.Status.Family.SERVER_ERROR) { + + return 
ErrorCodeV1.RERANKING_PROVIDER_SERVER_ERROR.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), jakartaResponse.getStatus(), errorMessage); + } + + // All other errors, Should never happen as all errors are covered above + return ErrorCodeV1.RERANKING_PROVIDER_UNEXPECTED_RESPONSE.toApiException( + "Provider: %s; HTTP Status: %s; Error Message: %s", + modelProvider().apiName(), jakartaResponse.getStatus(), errorMessage); + } + /** Create batches of passages to be reranked. */ private List> createPassageBatches(List passages) { + List> batches = new ArrayList<>(); - for (int i = 0; i < passages.size(); i += requestProperties.maxBatchSize()) { + for (int i = 0; i < passages.size(); i += modelConfig.properties().maxBatchSize()) { batches.add( - passages.subList(i, Math.min(i + requestProperties.maxBatchSize(), passages.size()))); + passages.subList( + i, Math.min(i + modelConfig.properties().maxBatchSize(), passages.size()))); } return batches; } /** Aggregate the ranks from all batched reranking calls. */ - private RerankingResponse aggregateRanks(List batchResponses) { + private RerankingResponse aggregateRanks(List batchResponses) { + List finalRanks = new ArrayList<>(); - for (RerankingBatchResponse batchResponse : batchResponses) { - int batchStartIndex = batchResponse.batchId() * requestProperties.maxBatchSize(); + ModelUsage aggregatedModelUsage = null; + + for (BatchedRerankingResponse batchResponse : batchResponses) { + int batchStartIndex = batchResponse.batchId() * modelConfig.properties().maxBatchSize(); + + aggregatedModelUsage = + aggregatedModelUsage == null + ? batchResponse.modelUsage() + : aggregatedModelUsage.merge(batchResponse.modelUsage()); for (Rank rank : batchResponse.ranks()) { finalRanks.add(new Rank(batchStartIndex + rank.index(), rank.score())); } } // This is the original order of all the passages. 
finalRanks.sort(Comparator.comparingInt(Rank::index)); - return new RerankingResponse(finalRanks); - } - - public record RerankingResponse(List ranks) {} - - /** Micro batch rerank method, which will rerank a batch of passages. */ - public abstract Uni rerank( - int batchId, String query, List passages, RerankingCredentials rerankingCredentials); - - /** The response of a batch rerank call. */ - public record RerankingBatchResponse(int batchId, List ranks, Usage usage) { - public static RerankingBatchResponse of(int batchId, List rankings, Usage usage) { - return new RerankingBatchResponse(batchId, rankings, usage); - } + return new RerankingResponse(finalRanks, aggregatedModelUsage); } - public record Rank(int index, float score) {} + /** + * Unbatched reranking response, returned from the public {@link #rerank(String, List, + * RerankingCredentials)} + * + *

... + */ + public record RerankingResponse(List ranks, ModelUsage modelUsage) {} - public record Usage(int prompt_tokens, int total_tokens) {} + /** + * Batched reranking response, returned from the protected {@link #rerank(int, String, List, + * RerankingCredentials)} + * + *

... + */ + public record BatchedRerankingResponse(int batchId, List ranks, ModelUsage modelUsage) {} /** - * Applies a retry mechanism with backoff and jitter to the Uni returned by the rerank() method, - * which makes an HTTP request to a third-party service. + * Individual rank and the index of the input passage. * - * @param The type of the item emitted by the Uni. - * @param uni The Uni to which the retry mechanism should be applied. - * @return A Uni that will retry on the specified failures with the configured backoff and jitter. + *

... */ - protected Uni applyRetry(Uni uni) { - return uni.onFailure( - throwable -> - (throwable.getCause() != null - && throwable.getCause() instanceof JsonApiException jae - && jae.getErrorCode() == ErrorCodeV1.EMBEDDING_PROVIDER_TIMEOUT) - || throwable instanceof TimeoutException) - .retry() - .withBackOff( - Duration.ofMillis(requestProperties.initialBackOffMillis()), - Duration.ofMillis(requestProperties.maxBackOffMillis())) - .withJitter(requestProperties.jitter()) - .atMost(requestProperties.atMostRetries()); - } + public record Rank(int index, float score) {} } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProviderFactory.java b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProviderFactory.java index 308521dc56..c0d1917e1e 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProviderFactory.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProviderFactory.java @@ -4,59 +4,86 @@ import io.stargate.embedding.gateway.RerankingService; import io.stargate.sgv2.jsonapi.config.OperationsConfig; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import io.stargate.sgv2.jsonapi.service.reranking.configuration.RerankingProvidersConfig; import io.stargate.sgv2.jsonapi.service.reranking.gateway.RerankingEGWClient; import jakarta.enterprise.context.ApplicationScoped; import jakarta.inject.Inject; import java.util.Map; import java.util.Optional; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; @ApplicationScoped public class RerankingProviderFactory { + private static final Logger LOGGER = LoggerFactory.getLogger(RerankingProviderFactory.class); + @Inject RerankingProvidersConfig rerankingConfig; @Inject OperationsConfig operationsConfig; @GrpcClient("embedding") - RerankingService rerankingGrpcService; + RerankingService grpcGatewayService; + 
@FunctionalInterface interface ProviderConstructor { RerankingProvider create( - String baseUrl, - String modelName, - RerankingProvidersConfig.RerankingProviderConfig.ModelConfig.RequestProperties - requestProperties); + RerankingProvidersConfig.RerankingProviderConfig.ModelConfig modelConfig); } - private static final Map RERANKING_PROVIDER_CONSTRUCTOR_MAP = - Map.ofEntries(Map.entry("nvidia", NvidiaRerankingProvider::new)); + private static final Map RERANKING_PROVIDER_CTORS = + Map.ofEntries(Map.entry(ModelProvider.NVIDIA, NvidiaRerankingProvider::new)); - public RerankingProvider getConfiguration( + public RerankingProvider create( Optional tenant, Optional authToken, String serviceName, String modelName, Map authentication, String commandName) { - return addService(tenant, authToken, serviceName, modelName, authentication, commandName); + + if (LOGGER.isTraceEnabled()) { + LOGGER.trace( + "create() - tenant: {}, serviceName: {}, modelName: {}, commandName: {}", + tenant, + serviceName, + modelName, + commandName); + } + + var modelProvider = + ModelProvider.fromApiName(serviceName) + .orElseThrow( + () -> + new IllegalArgumentException( + String.format("Unknown reranking service provider '%s'", serviceName))); + return create(tenant, authToken, modelProvider, modelName, authentication, commandName); } - private synchronized RerankingProvider addService( + private synchronized RerankingProvider create( Optional tenant, Optional authToken, - String serviceName, + ModelProvider modelProvider, String modelName, Map authentication, String commandName) { - final RerankingProvidersConfig.RerankingProviderConfig configuration = - rerankingConfig.providers().get(serviceName); - RerankingProviderFactory.ProviderConstructor ctor = - RERANKING_PROVIDER_CONSTRUCTOR_MAP.get(serviceName); - if (ctor == null) { + + if (LOGGER.isTraceEnabled()) { + LOGGER.trace( + "create() - tenant: {}, modelProvider: {}, modelName: {}, commandName: {}", + tenant, + modelProvider, + 
modelName, + commandName); + } + + var providerConfig = rerankingConfig.providers().get(modelProvider.apiName()); + if (providerConfig == null) { throw ErrorCodeV1.RERANKING_SERVICE_TYPE_UNAVAILABLE.toApiException( - "unknown service provider '%s'", serviceName); + "unknown reranking service provider '%s'", modelProvider.apiName()); } + var modelConfig = - configuration.models().stream() + providerConfig.models().stream() .filter(model -> model.name().equals(modelName)) .findFirst() .orElseThrow( @@ -67,18 +94,21 @@ private synchronized RerankingProvider addService( if (operationsConfig.enableEmbeddingGateway()) { // return the reranking Grpc client to embedding gateway service return new RerankingEGWClient( - modelConfig.url(), - modelConfig.properties(), - serviceName, + modelProvider, + modelConfig, tenant, authToken, - modelName, - rerankingGrpcService, + grpcGatewayService, authentication, commandName); } - return ctor.create(modelConfig.url(), modelConfig.name(), modelConfig.properties()); + RerankingProviderFactory.ProviderConstructor ctor = RERANKING_PROVIDER_CTORS.get(modelProvider); + if (ctor == null) { + throw ErrorCodeV1.RERANKING_SERVICE_TYPE_UNAVAILABLE.toApiException( + "unknown service provider '%s'", modelProvider.apiName()); + } + return ctor.create(modelConfig); } public RerankingProvidersConfig getRerankingConfig() { diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindAndRerankOperationBuilder.java b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindAndRerankOperationBuilder.java index 77fd333f96..b8257f7df9 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindAndRerankOperationBuilder.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/FindAndRerankOperationBuilder.java @@ -194,7 +194,7 @@ private void checkSupported() { RerankingProvider rerankingProvider = commandContext .rerankingProviderFactory() - .getConfiguration( + .create( commandContext.requestContext().getTenantId(), 
commandContext.requestContext().getCassandraToken(), providerConfig.provider(), diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/VectorizeConfigValidator.java b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/VectorizeConfigValidator.java index 6ea3ebdcd8..93a5710340 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/VectorizeConfigValidator.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/resolver/VectorizeConfigValidator.java @@ -6,8 +6,8 @@ import io.stargate.sgv2.jsonapi.exception.JsonApiException; import io.stargate.sgv2.jsonapi.exception.SchemaException; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; import io.stargate.sgv2.jsonapi.service.provider.ApiModelSupport; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import jakarta.enterprise.context.ApplicationScoped; import jakarta.inject.Inject; import java.util.ArrayList; @@ -53,7 +53,7 @@ public VectorizeConfigValidator( */ public Integer validateService(VectorizeConfig userConfig, Integer userVectorDimension) { // Only for internal tests - if (userConfig.provider().equals(ProviderConstants.CUSTOM)) { + if (userConfig.provider().equals(ModelProvider.CUSTOM.apiName())) { return userVectorDimension; } // Check if the service provider exists and is enabled @@ -329,10 +329,10 @@ private Integer validateModelAndDimension( // Find the model configuration by matching the model name // 1. 
huggingfaceDedicated does not require model, but requires dimension - if (userConfig.provider().equals(ProviderConstants.HUGGINGFACE_DEDICATED)) { + if (userConfig.provider().equals(ModelProvider.HUGGINGFACE_DEDICATED.apiName())) { if (userVectorDimension == null) { throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException( - "'dimension' is needed for provider %s", ProviderConstants.HUGGINGFACE_DEDICATED); + "'dimension' is needed for provider %s", ModelProvider.HUGGINGFACE_DEDICATED.apiName()); } } diff --git a/src/main/proto/embedding_gateway.proto b/src/main/proto/embedding_gateway.proto index f2172f9110..1305cc4134 100644 --- a/src/main/proto/embedding_gateway.proto +++ b/src/main/proto/embedding_gateway.proto @@ -5,39 +5,39 @@ option java_package = "io.stargate.embedding.gateway"; package stargate; // The request message that is sent to embedding gateway gRPC API message ProviderEmbedRequest { - ProviderContext provider_context = 1; - EmbeddingRequest embedding_request = 2; + ProviderContext provider_context = 1; + EmbeddingRequest embedding_request = 2; // The provider context message for the embedding gateway gRPC API message ProviderContext { - string provider_name = 1; - string tenant_id = 2; - map auth_tokens = 3; + string provider_name = 1; + string tenant_id = 2; + map auth_tokens = 3; } // The request message for the embedding gateway gRPC API message EmbeddingRequest { // The model name for the embedding request - string model_name = 1; + string model_name = 1; // The dimensions of the embedding, some providers supports multiple dimensions - optional int32 dimensions = 2; + optional int32 dimensions = 2; // The parameter value, used when provided needs user specified parameters - map parameters = 3; + map parameters = 3; // The input type for the embedding request - InputType input_type = 4; + InputType input_type = 4; // The input data that needs to be vectorized - repeated string inputs = 5; + repeated string inputs = 5; // The command 
contains vectorize - string command_name = 6; + string command_name = 6; // The parameter value, used when provided needs user specified parameters message ParameterValue { oneof ParameterValueOneOf { - string str_value = 1; - int32 int_value = 2; - float float_value = 3; - bool bool_value = 4; + string str_value = 1; + int32 int_value = 2; + float float_value = 3; + bool bool_value = 4; } } @@ -53,28 +53,16 @@ message ProviderEmbedRequest { // The response message for the embedding gateway gRPC API if successful message EmbeddingResponse { - Usage usage = 1; - repeated FloatEmbedding embeddings = 2; + ModelUsage modelUsage = 1; + repeated FloatEmbedding embeddings = 2; ErrorResponse error = 3; // The embedding response message message FloatEmbedding { // The index of the embedding corresponding to the input - int32 index = 1; + int32 index = 1; // The embedding values - repeated float embedding = 2; - } - - // The usage statistics for the embedding gateway gRPC API on successful response - message Usage { - string provider_name = 1; - string model_name = 2; - string tenant_id = 3; - int32 prompt_tokens = 4; - int32 total_tokens = 5; - int32 input_bytes = 6; - int32 output_bytes = 7; - int32 call_duration_us = 8; + repeated float embedding = 2; } // The error response message for the embedding gateway gRPC API @@ -90,20 +78,20 @@ message GetSupportedProvidersRequest {} // The response message for the get supported providers gRPC API if successful message GetSupportedProvidersResponse { - map supportedProviders = 1; - ErrorResponse error = 2; + map supportedProviders = 1; + ErrorResponse error = 2; // ProviderConfig message represents configuration for an embedding provider. 
message ProviderConfig { - string displayName = 1; - bool enabled = 2; + string displayName = 1; + bool enabled = 2; optional string url = 3; // No AuthenticationType Enum, since enum can not be key of map in grpc message map supported_authentications = 4; - repeated ParameterConfig parameters = 5; - RequestProperties properties = 6; - repeated ModelConfig models = 7; - bool authTokenPassThroughForNoneAuth = 8; + repeated ParameterConfig parameters = 5; + RequestProperties properties = 6; + repeated ModelConfig models = 7; + bool authTokenPassThroughForNoneAuth = 8; message AuthenticationConfig{ @@ -207,20 +195,20 @@ service EmbeddingService { // The reranking request message that is sent to embedding gateway gRPC API message ProviderRerankingRequest { - ProviderContext provider_context = 1; - RerankingRequest Reranking_request = 2; + ProviderContext provider_context = 1; + RerankingRequest Reranking_request = 2; message ProviderContext { - string provider_name = 1; - string tenant_id = 2; - map auth_tokens = 3; + string provider_name = 1; + string tenant_id = 2; + map auth_tokens = 3; } message RerankingRequest { // The model name for the reranking request - string model_name = 1; + string model_name = 1; // The query text for the reranking request - string query = 2; + string query = 2; // The passages texts for the reranking request - repeated string passages = 3; + repeated string passages = 3; // The command contains reranking string command_name = 4; } @@ -229,22 +217,16 @@ message ProviderRerankingRequest { // The reranking response message for the embedding gateway gRPC API if successful message RerankingResponse { - Usage usage = 1; - repeated Rank ranks = 2; - ErrorResponse error = 3; + ModelUsage modelUsage = 1; + repeated Rank ranks = 2; + ErrorResponse error = 3; // Reranking result for each passage message Rank { // The rank index of the passage - int32 index = 1; + int32 index = 1; // The rank score value of the passage - float score = 2; - } - - // 
The usage statistics of reranking for the embedding gateway gRPC API on successful response - message Usage { - int32 prompt_tokens = 1; - int32 total_tokens = 2; + float score = 2; } message ErrorResponse { @@ -258,7 +240,7 @@ message GetSupportedRerankingProvidersRequest {} // The response message for the get supported reranking providers gRPC API if successful message GetSupportedRerankingProvidersResponse { - map supportedProviders = 1; + map supportedProviders = 1; ErrorResponse error = 2; // ProviderConfig message represents configuration for an reranking provider. @@ -316,3 +298,35 @@ service RerankingService { rpc Rerank (ProviderRerankingRequest) returns (RerankingResponse) {} rpc GetSupportedRerankingProviders (GetSupportedRerankingProvidersRequest) returns (GetSupportedRerankingProvidersResponse){} } + +// Common structure for all model usage tracking, is included in response messages +message ModelUsage { + string model_provider = 1; + ModelType model_type = 2; + string model_name = 3; + string tenant_id = 4; + InputType input_type = 5; + // tokens sent in the request + int32 prompt_tokens = 6; + // total tokens the request will be billed for + int32 total_tokens = 7; + // number of bytes in the outgoing http request sent to the provider + int32 request_bytes = 8; + // number of bytes in the response received from the provider + int32 response_bytes = 9; + int64 call_duration_nanos = 10; + + // If the model usage was for indexing data or searching data + enum InputType { + INPUT_TYPE_UNSPECIFIED = 0; + INDEX = 1; + SEARCH = 2; + } + + enum ModelType { + MODEL_TYPE_UNSPECIFIED = 0; + EMBEDDING = 1; + RERANKING = 2; + } +} + diff --git a/src/main/resources/embedding-providers-config.yaml b/src/main/resources/embedding-providers-config.yaml index 0549879a93..06da3fea6d 100644 --- a/src/main/resources/embedding-providers-config.yaml +++ b/src/main/resources/embedding-providers-config.yaml @@ -272,7 +272,7 @@ stargate: required: false default-value: true 
help: "If set to false, text that exceeds the token limit causes the request to fail. The default value is true." - # OUT OF SCOPE FOR INITIAL PREVIEW + # COHERE was OUT OF SCOPE FOR INITIAL PREVIEW, TODO: decide if we want to enable, drop if not. cohere: # see https://docs.cohere.com/reference/embed display-name: Cohere diff --git a/src/test/java/io/stargate/sgv2/jsonapi/TestConstants.java b/src/test/java/io/stargate/sgv2/jsonapi/TestConstants.java index d0536c24e9..8eb4a9edfb 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/TestConstants.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/TestConstants.java @@ -1,13 +1,16 @@ package io.stargate.sgv2.jsonapi; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import io.micrometer.core.instrument.MeterRegistry; import io.stargate.sgv2.jsonapi.api.model.command.CommandConfig; import io.stargate.sgv2.jsonapi.api.model.command.CommandContext; +import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentialsSupplier; import io.stargate.sgv2.jsonapi.api.request.RequestContext; +import io.stargate.sgv2.jsonapi.api.request.RerankingCredentials; import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants; import io.stargate.sgv2.jsonapi.metrics.JsonProcessingMetricsReporter; import io.stargate.sgv2.jsonapi.service.cqldriver.CQLSessionCache; @@ -110,9 +113,17 @@ public CommandContext collectionContext( JsonProcessingMetricsReporter metricsReporter, EmbeddingProvider embeddingProvider) { + var embeddingCredentials = mock(EmbeddingCredentials.class); + when(embeddingCredentials.tenantId()).thenReturn("test-tenant"); + when(embeddingCredentials.apiKey()).thenReturn(Optional.of("test-apiKey")); + when(embeddingCredentials.accessId()).thenReturn(Optional.of("test-accessId")); + when(embeddingCredentials.secretId()).thenReturn(Optional.of("test-secretId")); + + var 
embeddingCredentialsSupplier = mock(EmbeddingCredentialsSupplier.class); + when(embeddingCredentialsSupplier.create(any(), any())).thenReturn(embeddingCredentials); + var requestContext = mock(RequestContext.class); - when(requestContext.getEmbeddingCredentialsSupplier()) - .thenReturn(mock(EmbeddingCredentialsSupplier.class)); + when(requestContext.getEmbeddingCredentialsSupplier()).thenReturn(embeddingCredentialsSupplier); when(requestContext.getTenantId()).thenReturn(Optional.of("test-tenant")); return CommandContext.builderSupplier() @@ -135,6 +146,14 @@ public CommandContext keyspaceContext() { TEST_COMMAND_NAME, KEYSPACE_SCHEMA_OBJECT, mock(JsonProcessingMetricsReporter.class)); } + public RequestContext requestContext() { + return new RequestContext( + Optional.of("test-tenant"), + Optional.empty(), + new RerankingCredentials("test-tenant", Optional.empty()), + "test-user-agent"); + } + public CommandContext keyspaceContext( String commandName, KeyspaceSchemaObject schema, @@ -150,7 +169,7 @@ public CommandContext keyspaceContext( .withMeterRegistry(mock(MeterRegistry.class)) .getBuilder(schema) .withCommandName(commandName) - .withRequestContext(new RequestContext(Optional.of("test-tenant"))) + .withRequestContext(requestContext()) .build(); } @@ -164,7 +183,7 @@ public CommandContext databaseContext() { .withMeterRegistry(mock(MeterRegistry.class)) .getBuilder(DATABASE_SCHEMA_OBJECT) .withCommandName(TEST_COMMAND_NAME) - .withRequestContext(new RequestContext(Optional.of("test-tenant"))) + .withRequestContext(requestContext()) .build(); } } diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindCollectionWithLexicalIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindCollectionWithLexicalIntegrationTest.java index 8581dda456..f4afcb2537 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindCollectionWithLexicalIntegrationTest.java +++ 
b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/FindCollectionWithLexicalIntegrationTest.java @@ -56,16 +56,16 @@ void createCollectionWithLexical() { // Create a Collection with default Lexical settings createComplexCollection( """ - { - "name": "%s", - "options" : { - "lexical": { - "enabled": true, - "analyzer": "standard" - } - } + { + "name": "%s", + "options" : { + "lexical": { + "enabled": true, + "analyzer": "standard" } - """ + } + } + """ .formatted(COLLECTION_WITH_LEXICAL)); // And then insert 5 documents insertDoc(COLLECTION_WITH_LEXICAL, DOC1_JSON); @@ -80,15 +80,15 @@ void createCollectionWithoutLexical() { // Create a Collection with lexical feature disabled createComplexCollection( """ - { - "name": "%s", - "options" : { - "lexical": { - "enabled": false - } - } + { + "name": "%s", + "options" : { + "lexical": { + "enabled": false } - """ + } + } + """ .formatted(COLLECTION_WITHOUT_LEXICAL)); } } @@ -103,12 +103,12 @@ void findManyWithLexicalSort() { keyspaceName, COLLECTION_WITH_LEXICAL, """ - { - "find": { - "sort" : {"$lexical": "banana" } - } - } - """) + { + "find": { + "sort" : {"$lexical": "banana" } + } + } + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(2)) .body("data.documents[0]._id", is("lexical-1")) @@ -121,16 +121,16 @@ void findManyWithOnlyLexicalFilter() { keyspaceName, COLLECTION_WITH_LEXICAL, """ - { - "find": { - "filter" : { - "$lexical": { - "$match": "biking" - } - } - } - } - """) + { + "find": { + "filter" : { + "$lexical": { + "$match": "biking" + } + } + } + } + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(1)) .body("data.documents[0]._id", is("lexical-3")); @@ -143,17 +143,17 @@ void findManyWithLexicalAndOtherFilter() { keyspaceName, COLLECTION_WITH_LEXICAL, """ - { - "find": { - "filter" : { - "$and": [ - { "$lexical": { "$match": "banana" } }, - { "tag": "bottom" } - ] - } - } - } - """) + { + "find": { + "filter" : { + "$and": [ + { "$lexical": { "$match": 
"banana" } }, + { "tag": "bottom" } + ] + } + } + } + """) .body("$", responseIsFindSuccess()) .body("data.documents", hasSize(1)) .body("data.documents[0]._id", is("lexical-4")); @@ -170,13 +170,13 @@ void findOneWithLexicalSortBiking() { keyspaceName, COLLECTION_WITH_LEXICAL, """ - { - "findOne": { - "projection": {"$lexical": 1 }, - "sort" : {"$lexical": "biking" } - } - } - """) + { + "findOne": { + "projection": {"$lexical": 1 }, + "sort" : {"$lexical": "biking" } + } + } + """) .body("$", responseIsFindSuccess()) // Needs to get "lexical-3" with "biking fun" .body("data.document", jsonEquals(DOC3_JSON)); @@ -188,13 +188,13 @@ void findOneWithLexicalSortMonkeyBananas() { keyspaceName, COLLECTION_WITH_LEXICAL, """ - { - "findOne": { - "projection": {"$lexical": 1 }, - "sort" : {"$lexical": "monkey banana" } - } - } - """) + { + "findOne": { + "projection": {"$lexical": 1 }, + "sort" : {"$lexical": "monkey banana" } + } + } + """) .body("$", responseIsFindSuccess()) // Needs to get "lexical-1" with "monkey banana" .body("data.document", jsonEquals(DOC1_JSON)); @@ -206,13 +206,13 @@ void findOneWithOnlyLexicalFilter() { keyspaceName, COLLECTION_WITH_LEXICAL, """ - { - "findOne": { - "projection": {"$lexical": 1 }, - "filter" : {"$lexical": {"$match": "bread butter" } } - } - } - """) + { + "findOne": { + "projection": {"$lexical": 1 }, + "filter" : {"$lexical": {"$match": "bread butter" } } + } + } + """) .body("$", responseIsFindSuccess()) // Needs to get "lexical-4" .body("data.document", jsonEquals(DOC4_JSON)); @@ -229,12 +229,12 @@ void failSortIfLexicalDisabledForCollection() { keyspaceName, COLLECTION_WITHOUT_LEXICAL, """ - { - "find": { - "sort" : {"$lexical": "banana" } - } - } - """) + { + "find": { + "sort" : {"$lexical": "banana" } + } + } + """) .body("errors", hasSize(1)) .body("errors[0].errorCode", is("LEXICAL_NOT_ENABLED_FOR_COLLECTION")) .body( @@ -248,12 +248,12 @@ void failFilterIfLexicalDisabledForCollection() { keyspaceName, 
COLLECTION_WITHOUT_LEXICAL, """ - { - "find": { - "filter" : {"$lexical": {"$match": "banana" } } - } - } - """) + { + "find": { + "filter" : {"$lexical": {"$match": "banana" } } + } + } + """) .body("errors", hasSize(1)) .body("errors[0].errorCode", is("LEXICAL_NOT_ENABLED_FOR_COLLECTION")) .body( @@ -267,12 +267,12 @@ void failForBadLexicalSortValueType() { keyspaceName, COLLECTION_WITH_LEXICAL, """ - { - "find": { - "sort" : {"$lexical": -1 } - } - } - """) + { + "find": { + "sort" : {"$lexical": -1 } + } + } + """) .body("errors", hasSize(1)) .body("errors[0].errorCode", is("INVALID_SORT_CLAUSE")) .body( @@ -286,12 +286,12 @@ void failForBadLexicalFilterValueType() { keyspaceName, COLLECTION_WITH_LEXICAL, """ - { - "find": { - "filter" : {"$lexical": {"$match": [ 1, 2, 3 ] } } - } - } - """) + { + "find": { + "filter" : {"$lexical": {"$match": [ 1, 2, 3 ] } } + } + } + """) .body("errors", hasSize(1)) .body("errors[0].errorCode", is("INVALID_FILTER_EXPRESSION")) .body( @@ -306,15 +306,15 @@ void failForLexicalSortWithOtherExpressions() { keyspaceName, COLLECTION_WITH_LEXICAL, """ - { - "find": { - "sort" : { - "a": 1, - "$lexical": "bananas" - } - } - } - """) + { + "find": { + "sort" : { + "a": 1, + "$lexical": "bananas" + } + } + } + """) .body("errors", hasSize(1)) .body("errors[0].errorCode", is("INVALID_SORT_CLAUSE")) .body( @@ -329,12 +329,12 @@ void failForLexicalFilterWithNot() { keyspaceName, COLLECTION_WITH_LEXICAL, """ - { - "find": { - "filter" : {"$not": {"$lexical": {"$match": "banana" } }}} - } - } - """) + { + "find": { + "filter" : {"$not": {"$lexical": {"$match": "banana" } }}} + } + } + """) .body("errors", hasSize(1)) .body("errors[0].errorCode", is("INVALID_FILTER_EXPRESSION")) .body( @@ -376,15 +376,15 @@ void findOneAndUpdateWithSort() { keyspaceName, COLLECTION_WITH_LEXICAL, """ - { - "findOneAndUpdate": { - "sort": { "$lexical": "banana" }, - "update" : {"$set" : {"value": "value1-updated"}}, - "projection": {"$lexical": 1 }, - 
"options": {"returnDocument": "after"} - } - } - """) + { + "findOneAndUpdate": { + "sort": { "$lexical": "banana" }, + "update" : {"$set" : {"value": "value1-updated"}}, + "projection": {"$lexical": 1 }, + "options": {"returnDocument": "after"} + } + } + """) .body("data.document", jsonEquals(expectedAfterChange)) .body("status.matchedCount", is(1)) .body("status.modifiedCount", is(1)); @@ -393,13 +393,13 @@ void findOneAndUpdateWithSort() { keyspaceName, COLLECTION_WITH_LEXICAL, """ - { - "findOne": { - "filter" : {"_id" : "lexical-1"}, - "projection": {"*": 1 } - } - } - """) + { + "findOne": { + "filter" : {"_id" : "lexical-1"}, + "projection": {"*": 1 } + } + } + """) .body("$", responseIsFindSuccess()) .body("data.document", jsonEquals(expectedAfterChange)); } @@ -416,13 +416,13 @@ void updateOneWithSort() { keyspaceName, COLLECTION_WITH_LEXICAL, """ - { - "updateOne": { - "sort": { "$lexical": "banana" }, - "update" : {"$set" : {"value": "value1-updated-2"}} - } - } - """) + { + "updateOne": { + "sort": { "$lexical": "banana" }, + "update" : {"$set" : {"value": "value1-updated-2"}} + } + } + """) .body("status.matchedCount", is(1)) .body("status.modifiedCount", is(1)); // Plus query to check that the document was updated @@ -430,13 +430,13 @@ void updateOneWithSort() { keyspaceName, COLLECTION_WITH_LEXICAL, """ - { - "findOne": { - "filter" : {"_id" : "lexical-1"}, - "projection": {"*": 1 } - } - } - """) + { + "findOne": { + "filter" : {"_id" : "lexical-1"}, + "projection": {"*": 1 } + } + } + """) .body("$", responseIsFindSuccess()) .body("data.document", jsonEquals(expectedAfterChange)); } @@ -453,15 +453,15 @@ void findOneAndReplaceWithSort() { keyspaceName, COLLECTION_WITH_LEXICAL, """ - { - "findOneAndReplace": { - "sort": { "$lexical": "banana" }, - "replacement" : %s, - "projection": {"$lexical": 1 }, - "options": {"returnDocument": "after"} - } - } - """ + { + "findOneAndReplace": { + "sort": { "$lexical": "banana" }, + "replacement" : %s, + 
"projection": {"$lexical": 1 }, + "options": {"returnDocument": "after"} + } + } + """ .formatted(expectedAfterChange)) .body("data.document", jsonEquals(expectedAfterChange)) .body("status.matchedCount", is(1)) @@ -471,13 +471,13 @@ void findOneAndReplaceWithSort() { keyspaceName, COLLECTION_WITH_LEXICAL, """ - { - "findOne": { - "filter" : {"_id" : "lexical-1"}, - "projection": {"*": 1 } - } - } - """) + { + "findOne": { + "filter" : {"_id" : "lexical-1"}, + "projection": {"*": 1 } + } + } + """) .body("$", responseIsFindSuccess()) .body("data.document", jsonEquals(expectedAfterChange)); } @@ -493,13 +493,13 @@ void findOneAndDeleteWithSort() { keyspaceName, COLLECTION_WITH_LEXICAL, """ - { - "findOneAndDelete": { - "sort": { "$lexical": "monkey" }, - "projection": {"$lexical": 1 } - } - } - """) + { + "findOneAndDelete": { + "sort": { "$lexical": "monkey" }, + "projection": {"$lexical": 1 } + } + } + """) .body("status.deletedCount", is(1)) .body("data.document", jsonEquals(DOC2_JSON)); @@ -508,12 +508,12 @@ void findOneAndDeleteWithSort() { keyspaceName, COLLECTION_WITH_LEXICAL, """ - { - "find": { - "projection": {"_id": 1, "value": 0, "tag": 0 } - } - } - """) + { + "find": { + "projection": {"_id": 1, "value": 0, "tag": 0 } + } + } + """) .body("$", responseIsFindSuccess()) .body( "data.documents", @@ -536,13 +536,13 @@ void deleteOneWithSort() { keyspaceName, COLLECTION_WITH_LEXICAL, """ - { - "deleteOne": { - "filter": { }, - "sort": { "$lexical": "biking" } - } - } - """) + { + "deleteOne": { + "filter": { }, + "sort": { "$lexical": "biking" } + } + } + """) .body("$", responseIsStatusOnly()) .body("status.deletedCount", is(1)); @@ -551,12 +551,12 @@ void deleteOneWithSort() { keyspaceName, COLLECTION_WITH_LEXICAL, """ - { - "find": { - "projection": {"_id": 1, "value": 0 } - } - } - """) + { + "find": { + "projection": {"_id": 1, "value": 0 } + } + } + """) .body("$", responseIsFindSuccess()) .body( "data.documents", diff --git 
a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/InsertOneTableIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/InsertOneTableIntegrationTest.java index 416a3af54e..498da2bdc6 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/InsertOneTableIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/InsertOneTableIntegrationTest.java @@ -1735,7 +1735,7 @@ void insertDifferentVectorizeDimensions() { """) .hasSingleApiError( ErrorCodeV1.EMBEDDING_PROVIDER_CLIENT_ERROR.name(), - "The Embedding Provider returned a HTTP client error: Provider: openai; HTTP Status: 401; Error Message: Incorrect API key provided: test_emb"); + "The Embedding Provider returned a HTTP client error: Provider: openai; HTTP Status: 401; Error Message: \"Incorrect API key provided: test_emb"); } @Order(2) @@ -1783,7 +1783,7 @@ void insertDifferentVectorizeModels() { """) .hasSingleApiError( ErrorCodeV1.EMBEDDING_PROVIDER_CLIENT_ERROR.name(), - "The Embedding Provider returned a HTTP client error: Provider: openai; HTTP Status: 401; Error Message: Incorrect API key provided: test_emb"); + "The Embedding Provider returned a HTTP client error: Provider: openai; HTTP Status: 401; Error Message: \"Incorrect API key provided: test_emb"); } @Order(3) @@ -1832,7 +1832,7 @@ void insertDifferentVectorizeProviders() { ErrorCodeV1.EMBEDDING_PROVIDER_CLIENT_ERROR, anyOf( containsString( - "Provider: openai; HTTP Status: 401; Error Message: Incorrect API key provided: test_emb"), + "Provider: openai; HTTP Status: 401; Error Message: \"Incorrect API key provided: test_emb"), containsString( "Provider: jinaAI; HTTP Status: 401; Error Message: \"Unauthorized\""))); } diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java index 1b8e5d446e..957fe8ef1d 100644 --- 
a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/DataVectorizerTest.java @@ -20,6 +20,7 @@ import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorConfig; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorizeDefinition; import io.stargate.sgv2.jsonapi.service.embedding.DataVectorizer; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; import io.stargate.sgv2.jsonapi.service.schema.EmbeddingSourceModel; import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction; import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionLexicalConfig; @@ -43,7 +44,7 @@ public class DataVectorizerTest { TestEmbeddingProvider.TEST_EMBEDDING_PROVIDER; private final EmbeddingProvider testService = testEmbeddingProvider; private final EmbeddingCredentials embeddingCredentials = - new EmbeddingCredentials(Optional.empty(), Optional.empty(), Optional.empty()); + new EmbeddingCredentials("test-tenant", Optional.empty(), Optional.empty(), Optional.empty()); private CollectionSchemaObject collectionSettings = null; @@ -197,7 +198,7 @@ public void testWithUnmatchedVectorsNumber() { TestEmbeddingProvider testProvider = new TestEmbeddingProvider() { @Override - public Uni vectorize( + public Uni vectorize( int batchId, List texts, EmbeddingCredentials embeddingCredentials, @@ -206,7 +207,18 @@ public Uni vectorize( texts.forEach(t -> customResponse.add(new float[] {0.5f, 0.5f, 0.5f})); // add additional vector customResponse.add(new float[] {0.5f, 0.5f, 0.5f}); - return Uni.createFrom().item(Response.of(batchId, customResponse)); + + var modelUsage = + createModelUsage( + embeddingCredentials.tenantId(), + ModelInputType.fromEmbeddingRequestType(embeddingRequestType), + 0, + 0, + 0, + 0, + 0); + return Uni.createFrom() + .item(new BatchedEmbeddingResponse(batchId, customResponse, modelUsage)); } }; List documents = new 
ArrayList<>(); diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingClientTestResource.java b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingClientTestResource.java index 00eaddb770..e790891061 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingClientTestResource.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingClientTestResource.java @@ -9,11 +9,27 @@ import jakarta.ws.rs.core.MediaType; import java.util.Map; +// TODO: WRITE SOME DAM DOCUMENTATION ! @QuarkusTest public class EmbeddingClientTestResource implements QuarkusTestResourceLifecycleManager { private WireMockServer wireMockServer; + // NOTE: These are the host and path to use with this lifecycle manager. + // previously the start() methods returned below to override quarkus config properties, + // return Map.of( + // "stargate.jsonapi.embedding.providers.nvidia.url", + // wireMockServer.baseUrl() + "/v1/embeddings", + // "stargate.jsonapi.embedding.providers.openai.url", + // wireMockServer.baseUrl() + "/v1/"); + + public static final String HOST = "http://localhost:8080"; + public static final String NVIDIA_PATH = "/v1/embeddings"; + public static final String OPENAI_PATH = "/v1"; + + public static final String NVIDIA_URL = HOST + NVIDIA_PATH; + public static final String OPENAI_URL = HOST + OPENAI_PATH; + @Override public Map start() { wireMockServer = new WireMockServer(); @@ -179,11 +195,7 @@ public Map start() { .withStatus(401) .withStatusMessage("Unauthorized"))); - return Map.of( - "stargate.jsonapi.embedding.providers.nvidia.url", - wireMockServer.baseUrl() + "/v1/embeddings", - "stargate.jsonapi.embedding.providers.openai.url", - wireMockServer.baseUrl() + "/v1/"); + return Map.of(); } @Override diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingGatewayClientTest.java 
b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingGatewayClientTest.java index a355a2dce7..5d31cc84be 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingGatewayClientTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingGatewayClientTest.java @@ -14,8 +14,14 @@ import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.exception.JsonApiException; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfigImpl; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.ServiceConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.gateway.EmbeddingGatewayClient; +import io.stargate.sgv2.jsonapi.service.provider.ApiModelSupport; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import io.stargate.sgv2.jsonapi.service.provider.ModelType; import io.stargate.sgv2.jsonapi.testresource.NoGlobalResourcesTestProfile; import java.util.Arrays; import java.util.List; @@ -30,11 +36,55 @@ public class EmbeddingGatewayClientTest { public static final String TESTING_COMMAND_NAME = "test_command"; private final EmbeddingCredentials embeddingCredentials = - new EmbeddingCredentials(Optional.empty(), Optional.empty(), Optional.empty()); + new EmbeddingCredentials("test-tenant", Optional.empty(), Optional.empty(), Optional.empty()); + + private static final EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig MODEL_CONFIG = + new EmbeddingProvidersConfigImpl.EmbeddingProviderConfigImpl.ModelConfigImpl( + "testModel", + new ApiModelSupport.ApiModelSupportImpl( + 
ApiModelSupport.SupportStatus.SUPPORTED, Optional.empty()), + Optional.empty(), + List.of(), + Map.of(), + Optional.empty()); + + private static final EmbeddingProvidersConfigImpl.EmbeddingProviderConfigImpl + .RequestPropertiesImpl + REQUEST_PROPERTIES = + new EmbeddingProvidersConfigImpl.EmbeddingProviderConfigImpl.RequestPropertiesImpl( + 3, 10, 100, 100, 0.5, Optional.empty(), Optional.empty(), Optional.empty(), 10); + + private static final EmbeddingProvidersConfigImpl.EmbeddingProviderConfigImpl PROVIDER_CONFIG = + new EmbeddingProvidersConfigImpl.EmbeddingProviderConfigImpl( + ModelProvider.CUSTOM.apiName(), + true, + Optional.of("http://testing.com"), + false, + Map.of(), + List.of(), + REQUEST_PROPERTIES, + List.of()); + + private final ServiceConfigStore.ServiceConfig SERVICE_CONFIG = + new ServiceConfigStore.ServiceConfig( + ModelProvider.CUSTOM, + "http://testing.com", + Optional.empty(), + new ServiceConfigStore.ServiceRequestProperties( + REQUEST_PROPERTIES.atMostRetries(), + REQUEST_PROPERTIES.initialBackOffMillis(), + REQUEST_PROPERTIES.readTimeoutMillis(), + REQUEST_PROPERTIES.maxBackOffMillis(), + REQUEST_PROPERTIES.jitter(), + REQUEST_PROPERTIES.taskTypeRead(), + REQUEST_PROPERTIES.taskTypeStore(), + REQUEST_PROPERTIES.maxBatchSize()), + null); // aaron -passing null here to bypass url overrides // for [data-api#1088] (NPE for VoyageAI provider) @Test void verifyDirectConstructionWithNullServiceParameters() { + List providerCtors = Arrays.asList( AzureOpenAIEmbeddingProvider::new, @@ -47,56 +97,62 @@ void verifyDirectConstructionWithNullServiceParameters() { UpstageAIEmbeddingProvider::new, VertexAIEmbeddingProvider::new, VoyageAIEmbeddingProvider::new); + for (EmbeddingProviderFactory.ProviderConstructor ctor : providerCtors) { - EmbeddingProviderConfigStore.RequestProperties requestProperties = - EmbeddingProviderConfigStore.RequestProperties.of( - 3, 5, 5000, 5, 0.5, Optional.empty(), Optional.empty(), 2048); - assertThat( - ctor.create( - 
requestProperties, - "baseUrl", - TestEmbeddingProvider.TEST_MODEL_CONFIG, - 5, - null, - null)) - .isNotNull(); + + assertThat(ctor.create(PROVIDER_CONFIG, MODEL_CONFIG, SERVICE_CONFIG, 5, null)).isNotNull(); } } @Test void handleValidResponse() { + + var floatEmbeddingBuilder = + EmbeddingGateway.EmbeddingResponse.FloatEmbedding.newBuilder() + .addEmbedding(0.5f) + .addEmbedding(0.5f) + .addEmbedding(0.5f) + .addEmbedding(0.5f) + .addEmbedding(0.5f); + + var modelUsageBuilder = + EmbeddingGateway.ModelUsage.newBuilder() + .setModelProvider(ModelProvider.OPENAI.apiName()) + .setModelType(EmbeddingGateway.ModelUsage.ModelType.EMBEDDING) + .setModelName("test-model") + .setTenantId("test-tenant") + .setInputType(EmbeddingGateway.ModelUsage.InputType.INDEX) + .setPromptTokens(5) + .setTotalTokens(5) + .setRequestBytes(100) + .setResponseBytes(100) + .setCallDurationNanos(20000); + + var embeddingResonseBuilder = + EmbeddingGateway.EmbeddingResponse.newBuilder() + .addEmbeddings(floatEmbeddingBuilder.build()) + .addEmbeddings(floatEmbeddingBuilder.build()) + .setModelUsage(modelUsageBuilder.build()); + EmbeddingService embeddingService = mock(EmbeddingService.class); - final EmbeddingGateway.EmbeddingResponse.Builder builder = - EmbeddingGateway.EmbeddingResponse.newBuilder(); - EmbeddingGateway.EmbeddingResponse.FloatEmbedding.Builder floatEmbeddingBuilder = - EmbeddingGateway.EmbeddingResponse.FloatEmbedding.newBuilder(); - floatEmbeddingBuilder - .addEmbedding(0.5f) - .addEmbedding(0.5f) - .addEmbedding(0.5f) - .addEmbedding(0.5f) - .addEmbedding(0.5f); - builder - .addEmbeddings(floatEmbeddingBuilder.build()) - .addEmbeddings(floatEmbeddingBuilder.build()); + when(embeddingService.embed(any())) + .thenReturn(Uni.createFrom().item(embeddingResonseBuilder.build())); - when(embeddingService.embed(any())).thenReturn(Uni.createFrom().item(builder.build())); EmbeddingGatewayClient embeddingGatewayClient = new EmbeddingGatewayClient( - 
EmbeddingProviderConfigStore.RequestProperties.of( - 5, 5, 5, 5, 0.5, Optional.empty(), Optional.empty(), 2048), - "openai", + ModelProvider.OPENAI, + PROVIDER_CONFIG, + MODEL_CONFIG, + SERVICE_CONFIG, 1536, + Map.of(), Optional.of("default"), Optional.of("default"), - "https://api.openai.com/v1/", - "text-embedding-3-small", embeddingService, Map.of(), - Map.of(), TESTING_COMMAND_NAME); - final EmbeddingProvider.Response response = + final EmbeddingProvider.BatchedEmbeddingResponse response = embeddingGatewayClient .vectorize( 1, @@ -114,36 +170,51 @@ void handleValidResponse() { assertThat(response.embeddings().size()).isEqualTo(2); assertThat(response.embeddings().get(0).length).isEqualTo(5); assertThat(response.embeddings().get(1).length).isEqualTo(5); + + assertThat(response.modelUsage()).isNotNull(); + assertThat(response.modelUsage().modelProvider()).isEqualTo(ModelProvider.OPENAI); + assertThat(response.modelUsage().modelType()).isEqualTo(ModelType.EMBEDDING); + assertThat(response.modelUsage().modelName()).isEqualTo("test-model"); + assertThat(response.modelUsage().tenantId()).isEqualTo("test-tenant"); + assertThat(response.modelUsage().inputType()).isEqualTo(ModelInputType.INDEX); + + assertThat(response.modelUsage().promptTokens()).isEqualTo(5); + assertThat(response.modelUsage().totalTokens()).isEqualTo(5); + assertThat(response.modelUsage().requestBytes()).isEqualTo(100); + assertThat(response.modelUsage().responseBytes()).isEqualTo(100); + assertThat(response.modelUsage().durationNanos()).isEqualTo(20000); + assertThat(response.modelUsage().batchCount()).isEqualTo(1); } @Test void handleError() { EmbeddingService embeddingService = mock(EmbeddingService.class); - final EmbeddingGateway.EmbeddingResponse.Builder builder = + EmbeddingGateway.EmbeddingResponse.Builder builder = EmbeddingGateway.EmbeddingResponse.newBuilder(); EmbeddingGateway.EmbeddingResponse.ErrorResponse.Builder errorResponseBuilder = 
EmbeddingGateway.EmbeddingResponse.ErrorResponse.newBuilder(); - final JsonApiException apiException = + JsonApiException apiException = ErrorCodeV1.EMBEDDING_PROVIDER_RATE_LIMITED.toApiException( "Error Code : %s response description : %s", 429, "Too Many Requests"); + errorResponseBuilder .setErrorCode(apiException.getErrorCode().name()) .setErrorMessage(apiException.getMessage()); builder.setError(errorResponseBuilder.build()); when(embeddingService.embed(any())).thenReturn(Uni.createFrom().item(builder.build())); + EmbeddingGatewayClient embeddingGatewayClient = new EmbeddingGatewayClient( - EmbeddingProviderConfigStore.RequestProperties.of( - 5, 5, 5, 5, 0.5, Optional.empty(), Optional.empty(), 2048), - "openai", + ModelProvider.OPENAI, + PROVIDER_CONFIG, + MODEL_CONFIG, + SERVICE_CONFIG, 1536, + Map.of(), Optional.of("default"), Optional.of("default"), - "https://api.openai.com/v1/", - "text-embedding-3-small", embeddingService, Map.of(), - Map.of(), TESTING_COMMAND_NAME); Throwable result = diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderErrorMessageTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderErrorMessageTest.java index 0079e083b3..bec9fa6d44 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderErrorMessageTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderErrorMessageTest.java @@ -8,60 +8,103 @@ import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.exception.JsonApiException; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfigImpl; +import 
io.stargate.sgv2.jsonapi.service.embedding.configuration.ServiceConfigStore; import io.stargate.sgv2.jsonapi.service.provider.ApiModelSupport; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import jakarta.inject.Inject; import jakarta.ws.rs.core.MediaType; +import java.time.Duration; import java.util.List; import java.util.Map; import java.util.Optional; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; +/** + * NOTE: this test relies on the {@link EmbeddingClientTestResource} to mock the server responses + */ @QuarkusTest @WithTestResource(EmbeddingClientTestResource.class) public class EmbeddingProviderErrorMessageTest { + + @Inject EmbeddingProvidersConfig embeddingProvidersConfig; + private static final int DEFAULT_DIMENSIONS = 0; private final EmbeddingCredentials embeddingCredentials = - new EmbeddingCredentials(Optional.of("test"), Optional.empty(), Optional.empty()); + new EmbeddingCredentials( + "test-tenant", Optional.of("test"), Optional.empty(), Optional.empty()); - private final EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig testModel = + private final EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig MODEL_CONFIG = new EmbeddingProvidersConfigImpl.EmbeddingProviderConfigImpl.ModelConfigImpl( - "test-model", + "testModel", new ApiModelSupport.ApiModelSupportImpl( ApiModelSupport.SupportStatus.SUPPORTED, Optional.empty()), - Optional.of(123), + Optional.empty(), List.of(), Map.of(), Optional.empty()); - @Inject EmbeddingProvidersConfig config; + private final EmbeddingProvidersConfigImpl.EmbeddingProviderConfigImpl.RequestPropertiesImpl + REQUEST_PROPERTIES = + new EmbeddingProvidersConfigImpl.EmbeddingProviderConfigImpl.RequestPropertiesImpl( + 3, 10, 10000, 10000, 0.5, Optional.empty(), Optional.empty(), Optional.empty(), 10); + + private final EmbeddingProvidersConfigImpl.EmbeddingProviderConfigImpl PROVIDER_CONFIG = + new EmbeddingProvidersConfigImpl.EmbeddingProviderConfigImpl( + 
ModelProvider.NVIDIA.apiName(), + true, + Optional.of( + EmbeddingClientTestResource + .NVIDIA_URL), // path important see EmbeddingProviderErrorMessageTest + false, + Map.of(), + List.of(), + REQUEST_PROPERTIES, + List.of()); + + private final ServiceConfigStore.ServiceConfig SERVICE_CONFIG = + new ServiceConfigStore.ServiceConfig( + ModelProvider.NVIDIA, + EmbeddingClientTestResource + .NVIDIA_URL, // path important see EmbeddingProviderErrorMessageTest + Optional.empty(), + new ServiceConfigStore.ServiceRequestProperties( + REQUEST_PROPERTIES.atMostRetries(), + REQUEST_PROPERTIES.initialBackOffMillis(), + REQUEST_PROPERTIES.readTimeoutMillis(), + REQUEST_PROPERTIES.maxBackOffMillis(), + REQUEST_PROPERTIES.jitter(), + REQUEST_PROPERTIES.taskTypeRead(), + REQUEST_PROPERTIES.taskTypeStore(), + REQUEST_PROPERTIES.maxBatchSize()), + Map.of()); + + private NvidiaEmbeddingProvider createProvider() { + return new NvidiaEmbeddingProvider( + PROVIDER_CONFIG, MODEL_CONFIG, SERVICE_CONFIG, DEFAULT_DIMENSIONS, null); + } + + private Throwable vectorizeWithError(String text) { + + return createProvider() + .vectorize( + 1, List.of(text), embeddingCredentials, EmbeddingProvider.EmbeddingRequestType.INDEX) + .subscribe() + .withSubscriber(UniAssertSubscriber.create()) + .awaitFailure() + .getFailure(); + } @Nested class NvidiaEmbeddingProviderTest { @Test public void test429() throws Exception { - Throwable exception = - new NvidiaEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties.of( - 2, 100, 3000, 100, 0.5, Optional.empty(), Optional.empty(), 10), - config.providers().get("nvidia").url().get(), - testModel, - DEFAULT_DIMENSIONS, - null, - null) - .vectorize( - 1, - List.of("429"), - embeddingCredentials, - EmbeddingProvider.EmbeddingRequestType.INDEX) - .subscribe() - .withSubscriber(UniAssertSubscriber.create()) - .awaitFailure() - .getFailure(); + + var exception = vectorizeWithError("429"); + assertThat(exception) .isInstanceOf(JsonApiException.class) 
.hasFieldOrPropertyWithValue("errorCode", ErrorCodeV1.EMBEDDING_PROVIDER_RATE_LIMITED) @@ -72,24 +115,9 @@ public void test429() throws Exception { @Test public void test4xx() throws Exception { - Throwable exception = - new NvidiaEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties.of( - 2, 100, 3000, 100, 0.5, Optional.empty(), Optional.empty(), 10), - config.providers().get("nvidia").url().get(), - testModel, - DEFAULT_DIMENSIONS, - null, - null) - .vectorize( - 1, - List.of("400"), - embeddingCredentials, - EmbeddingProvider.EmbeddingRequestType.INDEX) - .subscribe() - .withSubscriber(UniAssertSubscriber.create()) - .awaitFailure() - .getFailure(); + + var exception = vectorizeWithError("400"); + assertThat(exception) .isInstanceOf(JsonApiException.class) .hasFieldOrPropertyWithValue("errorCode", ErrorCodeV1.EMBEDDING_PROVIDER_CLIENT_ERROR) @@ -100,24 +128,9 @@ public void test4xx() throws Exception { @Test public void test5xx() throws Exception { - Throwable exception = - new NvidiaEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties.of( - 2, 100, 3000, 100, 0.5, Optional.empty(), Optional.empty(), 10), - config.providers().get("nvidia").url().get(), - testModel, - DEFAULT_DIMENSIONS, - null, - null) - .vectorize( - 1, - List.of("503"), - embeddingCredentials, - EmbeddingProvider.EmbeddingRequestType.INDEX) - .subscribe() - .withSubscriber(UniAssertSubscriber.create()) - .awaitFailure() - .getFailure(); + + var exception = vectorizeWithError("503"); + assertThat(exception) .isInstanceOf(JsonApiException.class) .hasFieldOrPropertyWithValue("errorCode", ErrorCodeV1.EMBEDDING_PROVIDER_SERVER_ERROR) @@ -128,24 +141,9 @@ public void test5xx() throws Exception { @Test public void testRetryError() throws Exception { - Throwable exception = - new NvidiaEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties.of( - 2, 100, 3000, 100, 0.5, Optional.empty(), Optional.empty(), 10), - config.providers().get("nvidia").url().get(), - 
testModel, - DEFAULT_DIMENSIONS, - null, - null) - .vectorize( - 1, - List.of("408"), - embeddingCredentials, - EmbeddingProvider.EmbeddingRequestType.INDEX) - .subscribe() - .withSubscriber(UniAssertSubscriber.create()) - .awaitFailure() - .getFailure(); + + var exception = vectorizeWithError("408"); + assertThat(exception) .isInstanceOf(JsonApiException.class) .hasFieldOrPropertyWithValue("errorCode", ErrorCodeV1.EMBEDDING_PROVIDER_TIMEOUT) @@ -156,15 +154,9 @@ public void testRetryError() throws Exception { @Test public void testCorrectHeaderAndBody() { - final EmbeddingProvider.Response result = - new NvidiaEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties.of( - 2, 100, 3000, 100, 0.5, Optional.empty(), Optional.empty(), 10), - config.providers().get("nvidia").url().get(), - testModel, - DEFAULT_DIMENSIONS, - null, - null) + + final EmbeddingProvider.BatchedEmbeddingResponse result = + createProvider() .vectorize( 1, List.of(MediaType.APPLICATION_JSON), @@ -172,8 +164,9 @@ public void testCorrectHeaderAndBody() { EmbeddingProvider.EmbeddingRequestType.INDEX) .subscribe() .withSubscriber(UniAssertSubscriber.create()) - .awaitItem() + .awaitItem(Duration.ofDays(1)) .getItem(); + assertThat(result).isNotNull(); assertThat(result.batchId()).isEqualTo(1); assertThat(result.embeddings()).isNotNull(); @@ -181,24 +174,9 @@ public void testCorrectHeaderAndBody() { @Test public void testIncorrectContentTypeXML() { - Throwable exception = - new NvidiaEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties.of( - 2, 100, 3000, 100, 0.5, Optional.empty(), Optional.empty(), 10), - config.providers().get("nvidia").url().get(), - testModel, - DEFAULT_DIMENSIONS, - null, - null) - .vectorize( - 1, - List.of("application/xml"), - embeddingCredentials, - EmbeddingProvider.EmbeddingRequestType.INDEX) - .subscribe() - .withSubscriber(UniAssertSubscriber.create()) - .awaitFailure() - .getFailure(); + + var exception = 
vectorizeWithError("application/xml"); + assertThat(exception) .isInstanceOf(JsonApiException.class) .hasFieldOrPropertyWithValue( @@ -210,24 +188,9 @@ public void testIncorrectContentTypeXML() { @Test public void testIncorrectContentTypePlainText() { - Throwable exception = - new NvidiaEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties.of( - 2, 100, 3000, 100, 0.5, Optional.empty(), Optional.empty(), 10), - config.providers().get("nvidia").url().get(), - testModel, - DEFAULT_DIMENSIONS, - null, - null) - .vectorize( - 1, - List.of("text/plain;charset=UTF-8"), - embeddingCredentials, - EmbeddingProvider.EmbeddingRequestType.INDEX) - .subscribe() - .withSubscriber(UniAssertSubscriber.create()) - .awaitFailure() - .getFailure(); + + var exception = vectorizeWithError("text/plain;charset=UTF-8"); + assertThat(exception) .isInstanceOf(JsonApiException.class) .hasFieldOrPropertyWithValue( @@ -239,24 +202,9 @@ public void testIncorrectContentTypePlainText() { @Test public void testNoJsonResponse() { - Throwable exception = - new NvidiaEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties.of( - 2, 100, 3000, 100, 0.5, Optional.empty(), Optional.empty(), 10), - config.providers().get("nvidia").url().get(), - testModel, - DEFAULT_DIMENSIONS, - null, - null) - .vectorize( - 1, - List.of("no json body"), - embeddingCredentials, - EmbeddingProvider.EmbeddingRequestType.INDEX) - .subscribe() - .withSubscriber(UniAssertSubscriber.create()) - .awaitFailure() - .getFailure(); + + var exception = vectorizeWithError("no json body"); + assertThat(exception) .isInstanceOf(JsonApiException.class) .hasFieldOrPropertyWithValue( @@ -268,27 +216,16 @@ public void testNoJsonResponse() { @Test public void testEmptyJsonResponse() { - final EmbeddingProvider.Response result = - new NvidiaEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties.of( - 2, 100, 3000, 100, 0.5, Optional.empty(), Optional.empty(), 10), - 
config.providers().get("nvidia").url().get(), - testModel, - DEFAULT_DIMENSIONS, - null, - null) - .vectorize( - 1, - List.of("empty json body"), - embeddingCredentials, - EmbeddingProvider.EmbeddingRequestType.INDEX) - .subscribe() - .withSubscriber(UniAssertSubscriber.create()) - .awaitItem() - .getItem(); - assertThat(result).isNotNull(); - assertThat(result.batchId()).isEqualTo(1); - assertThat(result.embeddings()).isNotNull(); + + var exception = vectorizeWithError("empty json body"); + + assertThat(exception) + .isInstanceOf(JsonApiException.class) + .hasFieldOrPropertyWithValue( + "errorCode", ErrorCodeV1.EMBEDDING_PROVIDER_UNEXPECTED_RESPONSE) + .hasFieldOrPropertyWithValue( + "message", + "The Embedding Provider returned an unexpected response: Provider: nvidia; HTTP Status: 200; Error Message: ModelProvider returned empty data for model testModel"); } } } diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAiEmbeddingClientTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAiEmbeddingClientTest.java index 9c885e4714..cf7dcdcf22 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAiEmbeddingClientTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/OpenAiEmbeddingClientTest.java @@ -8,10 +8,11 @@ import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.exception.JsonApiException; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfigImpl; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.ServiceConfigStore; import io.stargate.sgv2.jsonapi.service.provider.ApiModelSupport; +import 
io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import jakarta.inject.Inject; import jakarta.ws.rs.core.MediaType; import java.util.List; @@ -20,16 +21,20 @@ import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; +/** + * NOTE: this test relies on the {@link EmbeddingClientTestResource} to mock the server responses + */ @QuarkusTest @WithTestResource(EmbeddingClientTestResource.class) public class OpenAiEmbeddingClientTest { - @Inject EmbeddingProvidersConfig config; + @Inject EmbeddingProvidersConfig embeddingProvidersConfig; private final EmbeddingCredentials embeddingCredentials = - new EmbeddingCredentials(Optional.of("test"), Optional.empty(), Optional.empty()); + new EmbeddingCredentials( + "test-tenant", Optional.of("test"), Optional.empty(), Optional.empty()); - private final EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig testModel = + private final EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig MODEL_CONFIG = new EmbeddingProvidersConfigImpl.EmbeddingProviderConfigImpl.ModelConfigImpl( "test-model", new ApiModelSupport.ApiModelSupportImpl( @@ -39,30 +44,79 @@ public class OpenAiEmbeddingClientTest { Map.of(), Optional.empty()); + private final EmbeddingProvidersConfigImpl.EmbeddingProviderConfigImpl.RequestPropertiesImpl + REQUEST_PROPERTIES = + new EmbeddingProvidersConfigImpl.EmbeddingProviderConfigImpl.RequestPropertiesImpl( + 3, 10, 100, 100, 0.5, Optional.empty(), Optional.empty(), Optional.empty(), 10); + + private final EmbeddingProvidersConfigImpl.EmbeddingProviderConfigImpl PROVIDER_CONFIG = + new EmbeddingProvidersConfigImpl.EmbeddingProviderConfigImpl( + ModelProvider.OPENAI.apiName(), + true, + Optional.of(EmbeddingClientTestResource.OPENAI_URL), + false, + Map.of(), + List.of(), + REQUEST_PROPERTIES, + List.of()); + + private final ServiceConfigStore.ServiceConfig SERVICE_CONFIG = + new ServiceConfigStore.ServiceConfig( + ModelProvider.OPENAI, + EmbeddingClientTestResource + .OPENAI_URL, // 
path important see EmbeddingProviderErrorMessageTest + Optional.empty(), + new ServiceConfigStore.ServiceRequestProperties( + REQUEST_PROPERTIES.atMostRetries(), + REQUEST_PROPERTIES.initialBackOffMillis(), + REQUEST_PROPERTIES.readTimeoutMillis(), + REQUEST_PROPERTIES.maxBackOffMillis(), + REQUEST_PROPERTIES.jitter(), + REQUEST_PROPERTIES.taskTypeRead(), + REQUEST_PROPERTIES.taskTypeStore(), + REQUEST_PROPERTIES.maxBatchSize()), + Map.of()); + + private OpenAIEmbeddingProvider createProvider(Map vectorizeServiceParameters) { + return new OpenAIEmbeddingProvider( + PROVIDER_CONFIG, MODEL_CONFIG, SERVICE_CONFIG, 3, vectorizeServiceParameters); + } + + private EmbeddingProvider.BatchedEmbeddingResponse runVectorize( + EmbeddingProvider embeddingProvider, List texts) { + + return embeddingProvider + .vectorize(1, texts, embeddingCredentials, EmbeddingProvider.EmbeddingRequestType.INDEX) + .subscribe() + .withSubscriber(UniAssertSubscriber.create()) + .awaitItem() + .getItem(); + } + + private Throwable vectorizeWithError(EmbeddingProvider embeddingProvider, String text) { + + return embeddingProvider + .vectorize( + 1, List.of(text), embeddingCredentials, EmbeddingProvider.EmbeddingRequestType.INDEX) + .subscribe() + .withSubscriber(UniAssertSubscriber.create()) + .awaitFailure() + .getFailure(); + } + @Nested class OpenAiEmbeddingTest { + @Test public void happyPath() throws Exception { - final EmbeddingProvider.Response response = - new OpenAIEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties.of( - 2, 100, 3000, 100, 0.5, Optional.empty(), Optional.empty(), 10), - config.providers().get("openai").url().get(), - testModel, - 3, - Map.of("organizationId", "org-id", "projectId", "project-id"), - null) - .vectorize( - 1, - List.of("some data"), - embeddingCredentials, - EmbeddingProvider.EmbeddingRequestType.INDEX) - .subscribe() - .withSubscriber(UniAssertSubscriber.create()) - .awaitItem() - .getItem(); + + var response = + runVectorize( + 
createProvider(Map.of("organizationId", "org-id", "projectId", "project-id")), + List.of("some data")); + assertThat(response) - .isInstanceOf(EmbeddingProvider.Response.class) + .isInstanceOf(EmbeddingProvider.BatchedEmbeddingResponse.class) .satisfies( r -> { assertThat(r.embeddings()).isNotNull(); @@ -73,26 +127,11 @@ public void happyPath() throws Exception { @Test public void onlyToken() throws Exception { - final EmbeddingProvider.Response response = - new OpenAIEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties.of( - 2, 100, 3000, 100, 0.5, Optional.empty(), Optional.empty(), 10), - config.providers().get("openai").url().get(), - testModel, - 3, - Map.of(), - null) - .vectorize( - 1, - List.of(MediaType.APPLICATION_JSON), - embeddingCredentials, - EmbeddingProvider.EmbeddingRequestType.INDEX) - .subscribe() - .withSubscriber(UniAssertSubscriber.create()) - .awaitItem() - .getItem(); + + var response = runVectorize(createProvider(Map.of()), List.of(MediaType.APPLICATION_JSON)); + assertThat(response) - .isInstanceOf(EmbeddingProvider.Response.class) + .isInstanceOf(EmbeddingProvider.BatchedEmbeddingResponse.class) .satisfies( r -> { assertThat(r.embeddings()).isNotNull(); @@ -103,24 +142,12 @@ public void onlyToken() throws Exception { @Test public void invalidOrg() throws Exception { - Throwable exception = - new OpenAIEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties.of( - 2, 100, 3000, 100, 0.5, Optional.empty(), Optional.empty(), 10), - config.providers().get("openai").url().get(), - testModel, - 3, - Map.of("organizationId", "invalid org", "projectId", "project-id"), - null) - .vectorize( - 1, - List.of("some data"), - embeddingCredentials, - EmbeddingProvider.EmbeddingRequestType.INDEX) - .subscribe() - .withSubscriber(UniAssertSubscriber.create()) - .awaitFailure() - .getFailure(); + + var exception = + vectorizeWithError( + createProvider(Map.of("organizationId", "invalid org", "projectId", "project-id")), + "some 
data"); + assertThat(exception) .isInstanceOf(JsonApiException.class) .hasFieldOrPropertyWithValue("errorCode", ErrorCodeV1.EMBEDDING_PROVIDER_CLIENT_ERROR) @@ -131,24 +158,12 @@ public void invalidOrg() throws Exception { @Test public void invalidProject() throws Exception { - Throwable exception = - new OpenAIEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties.of( - 2, 100, 3000, 100, 0.5, Optional.empty(), Optional.empty(), 10), - config.providers().get("openai").url().get(), - testModel, - 3, - Map.of("organizationId", "org-id", "projectId", "invalid proj"), - null) - .vectorize( - 1, - List.of("some data"), - embeddingCredentials, - EmbeddingProvider.EmbeddingRequestType.INDEX) - .subscribe() - .withSubscriber(UniAssertSubscriber.create()) - .awaitFailure() - .getFailure(); + + var exception = + vectorizeWithError( + createProvider(Map.of("organizationId", "org-id", "projectId", "invalid proj")), + "some data"); + assertThat(exception) .isInstanceOf(JsonApiException.class) .hasFieldOrPropertyWithValue("errorCode", ErrorCodeV1.EMBEDDING_PROVIDER_CLIENT_ERROR) diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/TestEmbeddingProvider.java b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/TestEmbeddingProvider.java index b0c521f310..8e10edb672 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/TestEmbeddingProvider.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/TestEmbeddingProvider.java @@ -1,7 +1,5 @@ package io.stargate.sgv2.jsonapi.service.embedding.operation; -import static org.mockito.Mockito.mock; - import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.TestConstants; import io.stargate.sgv2.jsonapi.api.model.command.CommandContext; @@ -10,10 +8,12 @@ import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorColumnDefinition; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorConfig; import 
io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorizeDefinition; -import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfig; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProvidersConfigImpl; +import io.stargate.sgv2.jsonapi.service.embedding.configuration.ServiceConfigStore; import io.stargate.sgv2.jsonapi.service.provider.ApiModelSupport; +import io.stargate.sgv2.jsonapi.service.provider.ModelInputType; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; import io.stargate.sgv2.jsonapi.service.schema.EmbeddingSourceModel; import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction; import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionLexicalConfig; @@ -27,19 +27,9 @@ public class TestEmbeddingProvider extends EmbeddingProvider { - public TestEmbeddingProvider( - EmbeddingProviderConfigStore.RequestProperties requestProperties, - String baseUrl, - EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig model, - int dimension, - Map vectorizeServiceParameters, - EmbeddingProvidersConfig.EmbeddingProviderConfig providerConfig) { - super(requestProperties, baseUrl, model, dimension, vectorizeServiceParameters, providerConfig); - } - - public TestEmbeddingProvider() {} + private final TestConstants TEST_CONSTANTS = new TestConstants(); - public static final EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig + private static final EmbeddingProvidersConfig.EmbeddingProviderConfig.ModelConfig TEST_MODEL_CONFIG = new EmbeddingProvidersConfigImpl.EmbeddingProviderConfigImpl.ModelConfigImpl( "testModel", @@ -50,22 +40,50 @@ public TestEmbeddingProvider() {} Map.of(), Optional.empty()); - public static final TestEmbeddingProvider TEST_EMBEDDING_PROVIDER = - new TestEmbeddingProvider( - null, - null, - TEST_MODEL_CONFIG, - 3, + private static final 
EmbeddingProvidersConfigImpl.EmbeddingProviderConfigImpl + .RequestPropertiesImpl + REQUEST_PROPERTIES = + new EmbeddingProvidersConfigImpl.EmbeddingProviderConfigImpl.RequestPropertiesImpl( + 3, 10, 100, 100, 0.5, Optional.empty(), Optional.empty(), Optional.empty(), 10); + + private static final EmbeddingProvidersConfigImpl.EmbeddingProviderConfigImpl PROVIDER_CONFIG = + new EmbeddingProvidersConfigImpl.EmbeddingProviderConfigImpl( + ModelProvider.CUSTOM.apiName(), + true, + Optional.of("http://testing.com"), + false, Map.of(), - mock(EmbeddingProvidersConfig.EmbeddingProviderConfig.class)); + List.of(), + REQUEST_PROPERTIES, + List.of()); + + private static final ServiceConfigStore.ServiceConfig SERVICE_CONFIG = + new ServiceConfigStore.ServiceConfig( + ModelProvider.CUSTOM, + "http://testing.com", + Optional.empty(), + new ServiceConfigStore.ServiceRequestProperties( + REQUEST_PROPERTIES.atMostRetries(), + REQUEST_PROPERTIES.initialBackOffMillis(), + REQUEST_PROPERTIES.readTimeoutMillis(), + REQUEST_PROPERTIES.maxBackOffMillis(), + REQUEST_PROPERTIES.jitter(), + REQUEST_PROPERTIES.taskTypeRead(), + REQUEST_PROPERTIES.taskTypeStore(), + REQUEST_PROPERTIES.maxBatchSize()), + Map.of()); - private TestConstants testConstants = new TestConstants(); + public static final TestEmbeddingProvider TEST_EMBEDDING_PROVIDER = new TestEmbeddingProvider(); + + public TestEmbeddingProvider() { + super(ModelProvider.CUSTOM, PROVIDER_CONFIG, TEST_MODEL_CONFIG, SERVICE_CONFIG, 3, Map.of()); + } public CommandContext commandContextWithVectorize() { - return testConstants.collectionContext( + return TEST_CONSTANTS.collectionContext( "testCommand", new CollectionSchemaObject( - testConstants.SCHEMA_OBJECT_NAME, + TEST_CONSTANTS.SCHEMA_OBJECT_NAME, null, IdConfig.defaultIdConfig(), VectorConfig.fromColumnDefinitions( @@ -84,7 +102,13 @@ public CommandContext commandContextWithVectorize() { } @Override - public Uni vectorize( + protected String errorMessageJsonPtr() { + // not used in 
tests + return ""; + } + + @Override + public Uni vectorize( int batchId, List texts, EmbeddingCredentials embeddingCredentials, @@ -95,7 +119,17 @@ public Uni vectorize( if (t.equals("return 1s")) response.add(new float[] {1.0f, 1.0f, 1.0f}); else response.add(new float[] {0.25f, 0.25f, 0.25f}); }); - return Uni.createFrom().item(Response.of(batchId, response)); + + var modelUsage = + createModelUsage( + embeddingCredentials.tenantId(), + ModelInputType.fromEmbeddingRequestType(embeddingRequestType), + 0, + 0, + 0, + 0, + 0); + return Uni.createFrom().item(new BatchedEmbeddingResponse(batchId, response, modelUsage)); } @Override diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/operation/tables/WriteableTableRowBuilderTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/operation/tables/WriteableTableRowBuilderTest.java index 26a5e5fb18..899524ee8d 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/operation/tables/WriteableTableRowBuilderTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/operation/tables/WriteableTableRowBuilderTest.java @@ -7,9 +7,9 @@ import static org.mockito.Mockito.mock; import io.micrometer.core.instrument.MeterRegistry; +import io.stargate.sgv2.jsonapi.TestConstants; import io.stargate.sgv2.jsonapi.api.model.command.CommandConfig; import io.stargate.sgv2.jsonapi.api.model.command.CommandContext; -import io.stargate.sgv2.jsonapi.api.request.RequestContext; import io.stargate.sgv2.jsonapi.config.feature.ApiFeatures; import io.stargate.sgv2.jsonapi.exception.DocumentException; import io.stargate.sgv2.jsonapi.fixtures.*; @@ -32,7 +32,6 @@ import io.stargate.sgv2.jsonapi.util.recordable.PrettyPrintable; import java.util.ArrayList; import java.util.List; -import java.util.Optional; import java.util.stream.Stream; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; @@ -44,6 +43,8 @@ public class WriteableTableRowBuilderTest { private static final Logger LOGGER 
= LoggerFactory.getLogger(WriteableTableRowBuilderTest.class); + private static final TestConstants TEST_CONSTANTS = new TestConstants(); + private static void logFixture(String testName, JsonContainerFixture fixture) { // 24-Jan-2025, tatu: This produces thousands of lines noise in logs, so let's // change to TRACE level (from INFO) @@ -68,7 +69,7 @@ private static WriteableTableRow buildRow(JsonContainerFixture fixture) { .getBuilder(fixture.cqlFixture().tableSchemaObject()) .withEmbeddingProvider(mock(EmbeddingProvider.class)) .withCommandName("testCommand") - .withRequestContext(new RequestContext(Optional.of("test-tenant"))) + .withRequestContext(TEST_CONSTANTS.requestContext()) .withApiFeatures(ApiFeatures.empty()) .build(); diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/NvidiaRerankingClientTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/NvidiaRerankingClientTest.java deleted file mode 100644 index fa4187bae7..0000000000 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/NvidiaRerankingClientTest.java +++ /dev/null @@ -1,65 +0,0 @@ -package io.stargate.sgv2.jsonapi.service.reranking; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.*; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -import io.quarkus.test.junit.QuarkusTest; -import io.quarkus.test.junit.TestProfile; -import io.smallrye.mutiny.Uni; -import io.smallrye.mutiny.helpers.test.UniAssertSubscriber; -import io.stargate.sgv2.jsonapi.api.request.RerankingCredentials; -import io.stargate.sgv2.jsonapi.service.reranking.operation.NvidiaRerankingProvider; -import io.stargate.sgv2.jsonapi.service.reranking.operation.RerankingProvider; -import io.stargate.sgv2.jsonapi.testresource.NoGlobalResourcesTestProfile; -import java.util.List; -import java.util.Optional; -import java.util.stream.IntStream; -import org.junit.jupiter.api.Test; - -/** - * Tests for the 
RerankEGWClient class. Mocking the embedding gateway service to test the grpc - * rerank API. - */ -@QuarkusTest -@TestProfile(NoGlobalResourcesTestProfile.Impl.class) -public class NvidiaRerankingClientTest { - - private static final RerankingCredentials RERANK_CREDENTIALS = - new RerankingCredentials(Optional.of("mocked data api token")); - - @Test - void handleValidResponse() { - NvidiaRerankingProvider nvidiaRerankingProvider = mock(NvidiaRerankingProvider.class); - when(nvidiaRerankingProvider.rerank(anyInt(), any(), any(), any())) - .thenAnswer( - invocation -> { - List ranks = - IntStream.range(0, 2) - .mapToObj(i -> new RerankingProvider.Rank(i, i == 0 ? 0.1f : 1f)) - .toList(); - return Uni.createFrom() - .item( - new RerankingProvider.RerankingBatchResponse( - 1, ranks, new RerankingProvider.Usage(0, 0))); - }); - - final RerankingProvider.RerankingBatchResponse response = - nvidiaRerankingProvider - .rerank(1, "apple", List.of("orange", "apple"), RERANK_CREDENTIALS) - .subscribe() - .withSubscriber(UniAssertSubscriber.create()) - .awaitItem() - .getItem(); - - assertThat(response).isNotNull(); - assertThat(response.batchId()).isEqualTo(1); - assertThat(response.ranks()).isNotEmpty(); - assertThat(response.ranks().size()).isEqualTo(2); - assertThat(response.ranks().get(0).index()).isEqualTo(0); - assertThat(response.ranks().get(0).score()).isEqualTo(0.1f); - assertThat(response.ranks().get(1).index()).isEqualTo(1); - assertThat(response.ranks().get(1).score()).isEqualTo(1f); - } -} diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/TestRerankingProvider.java b/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/TestRerankingProvider.java deleted file mode 100644 index cca8bd2eb1..0000000000 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/TestRerankingProvider.java +++ /dev/null @@ -1,42 +0,0 @@ -package io.stargate.sgv2.jsonapi.service.reranking; - -import io.smallrye.mutiny.Uni; -import 
io.stargate.sgv2.jsonapi.api.request.RerankingCredentials; -import io.stargate.sgv2.jsonapi.service.reranking.configuration.RerankingProvidersConfig; -import io.stargate.sgv2.jsonapi.service.reranking.configuration.RerankingProvidersConfigImpl; -import io.stargate.sgv2.jsonapi.service.reranking.operation.RerankingProvider; -import java.util.ArrayList; -import java.util.List; - -/** Mock a test reranking provider that returns ranks based on query and passages */ -public class TestRerankingProvider extends RerankingProvider { - - protected TestRerankingProvider( - String baseUrl, - String modelName, - RerankingProvidersConfig.RerankingProviderConfig.ModelConfig.RequestProperties - requestProperties) { - super(baseUrl, modelName, requestProperties); - } - - protected TestRerankingProvider(int maxBatchSize) { - super( - "mockUrl", - "mockModel", - new RerankingProvidersConfigImpl.RerankingProviderConfigImpl.ModelConfigImpl - .RequestPropertiesImpl(3, 100, 5000, 500, 0.5, maxBatchSize)); - } - - @Override - public Uni rerank( - int batchId, String query, List passages, RerankingCredentials rerankCredentials) { - List ranks = new ArrayList<>(passages.size()); - for (int i = 0; i < passages.size(); i++) { - String passage = passages.get(i); - float score = passage.equals(query) ? 
1.0f : (float) Math.random(); // Example scoring logic - ranks.add(new Rank(i, score)); - } - ranks.sort((o1, o2) -> Float.compare(o2.score(), o1.score())); // Descending order - return Uni.createFrom().item(RerankingBatchResponse.of(batchId, ranks, new Usage(0, 0))); - } -} diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/RerankingGatewayClientTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/gateway/RerankingGatewayClientTest.java similarity index 63% rename from src/test/java/io/stargate/sgv2/jsonapi/service/reranking/RerankingGatewayClientTest.java rename to src/test/java/io/stargate/sgv2/jsonapi/service/reranking/gateway/RerankingGatewayClientTest.java index f25d9da763..e9427385b0 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/RerankingGatewayClientTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/gateway/RerankingGatewayClientTest.java @@ -1,4 +1,4 @@ -package io.stargate.sgv2.jsonapi.service.reranking; +package io.stargate.sgv2.jsonapi.service.reranking.gateway; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; @@ -14,7 +14,11 @@ import io.stargate.sgv2.jsonapi.api.request.RerankingCredentials; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.exception.JsonApiException; -import io.stargate.sgv2.jsonapi.service.reranking.gateway.RerankingEGWClient; +import io.stargate.sgv2.jsonapi.service.provider.ApiModelSupport; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import io.stargate.sgv2.jsonapi.service.provider.ModelType; +import io.stargate.sgv2.jsonapi.service.reranking.configuration.RerankingProvidersConfig; +import io.stargate.sgv2.jsonapi.service.reranking.configuration.RerankingProvidersConfigImpl; import io.stargate.sgv2.jsonapi.service.reranking.operation.RerankingProvider; import io.stargate.sgv2.jsonapi.testresource.NoGlobalResourcesTestProfile; import 
java.util.List; @@ -34,7 +38,26 @@ public class RerankingGatewayClientTest { public static final String TESTING_COMMAND_NAME = "test_command"; private static final RerankingCredentials RERANK_CREDENTIALS = - new RerankingCredentials(Optional.of("mocked reranking api key")); + new RerankingCredentials("test-tenant", Optional.of("mocked reranking api key")); + + private static final RerankingProvidersConfigImpl.RerankingProviderConfigImpl.ModelConfigImpl + .RequestPropertiesImpl + REQUEST_PROPERTIES = + new RerankingProvidersConfigImpl.RerankingProviderConfigImpl.ModelConfigImpl + .RequestPropertiesImpl(3, 10, 100, 100, 0.5, 10); + + private static final RerankingProvidersConfig.RerankingProviderConfig.ModelConfig MODEL_CONFIG = + new RerankingProvidersConfigImpl.RerankingProviderConfigImpl.ModelConfigImpl( + "testModel", + new ApiModelSupport.ApiModelSupportImpl( + ApiModelSupport.SupportStatus.SUPPORTED, Optional.empty()), + false, + "http://testing.com", + REQUEST_PROPERTIES); + + private static final RerankingProvidersConfigImpl.RerankingProviderConfigImpl PROVIDER_CONFIG = + new RerankingProvidersConfigImpl.RerankingProviderConfigImpl( + false, "test", true, Map.of(), List.of()); @Test void handleValidResponse() { @@ -55,22 +78,31 @@ void handleValidResponse() { .build()) .toList(); builder.addAllRanks(ranks); + // mock model usage + builder.setModelUsage( + EmbeddingGateway.ModelUsage.newBuilder() + .setModelType(EmbeddingGateway.ModelUsage.ModelType.RERANKING) + .setModelProvider(ModelProvider.NVIDIA.apiName()) + .setModelName("llama-3.2-nv-rerankqa-1b-v2") + .setPromptTokens(10) + .setTotalTokens(20) + .setRequestBytes(100) + .setResponseBytes(200) + .build()); when(rerankService.rerank(any())).thenReturn(Uni.createFrom().item(builder.build())); // Create a RerankEGWClient instance RerankingEGWClient rerankEGWClient = new RerankingEGWClient( - "https://xxx", - null, - "xxx", + ModelProvider.NVIDIA, + MODEL_CONFIG, Optional.of("default"), 
Optional.of("default"), - "xxx", rerankService, Map.of(), TESTING_COMMAND_NAME); - final RerankingProvider.RerankingBatchResponse response = + final RerankingProvider.BatchedRerankingResponse response = rerankEGWClient .rerank(1, "apple", List.of("orange", "apple"), RERANK_CREDENTIALS) .subscribe() @@ -86,10 +118,20 @@ void handleValidResponse() { assertThat(response.ranks().get(0).score()).isEqualTo(1f); assertThat(response.ranks().get(1).index()).isEqualTo(0); assertThat(response.ranks().get(1).score()).isEqualTo(0.1f); + + assertThat(response.modelUsage()).isNotNull(); + assertThat(response.modelUsage().modelType()).isEqualTo(ModelType.RERANKING); + assertThat(response.modelUsage().modelProvider()).isEqualTo(ModelProvider.NVIDIA); + assertThat(response.modelUsage().modelName()).isEqualTo("llama-3.2-nv-rerankqa-1b-v2"); + assertThat(response.modelUsage().promptTokens()).isEqualTo(10); + assertThat(response.modelUsage().totalTokens()).isEqualTo(20); + assertThat(response.modelUsage().requestBytes()).isEqualTo(100); + assertThat(response.modelUsage().responseBytes()).isEqualTo(200); } @Test void handleError() { + RerankingService rerankService = mock(RerankingService.class); final EmbeddingGateway.RerankingResponse.Builder builder = EmbeddingGateway.RerankingResponse.newBuilder(); @@ -106,12 +148,10 @@ void handleError() { // Create a RerankEGWClient instance RerankingEGWClient rerankEGWClient = new RerankingEGWClient( - "https://xxx", - null, - "xxx", + ModelProvider.NVIDIA, + MODEL_CONFIG, Optional.of("default"), Optional.of("default"), - "xxx", rerankService, Map.of(), TESTING_COMMAND_NAME); diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/RerankingProviderTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProviderTest.java similarity index 93% rename from src/test/java/io/stargate/sgv2/jsonapi/service/reranking/RerankingProviderTest.java rename to 
src/test/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProviderTest.java index e1e4a65573..776cb4dc87 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/RerankingProviderTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/operation/RerankingProviderTest.java @@ -1,4 +1,4 @@ -package io.stargate.sgv2.jsonapi.service.reranking; +package io.stargate.sgv2.jsonapi.service.reranking.operation; import static org.assertj.core.api.Assertions.assertThat; @@ -6,7 +6,6 @@ import io.quarkus.test.junit.TestProfile; import io.smallrye.mutiny.helpers.test.UniAssertSubscriber; import io.stargate.sgv2.jsonapi.api.request.RerankingCredentials; -import io.stargate.sgv2.jsonapi.service.reranking.operation.RerankingProvider; import io.stargate.sgv2.jsonapi.testresource.NoGlobalResourcesTestProfile; import java.util.List; import java.util.Optional; @@ -18,7 +17,7 @@ public class RerankingProviderTest { private static final RerankingCredentials RERANK_CREDENTIALS = - new RerankingCredentials(Optional.of("mocked reranking api key")); + new RerankingCredentials("test-tenant", Optional.of("mocked reranking api key")); @Test @SuppressWarnings("unchecked") diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/operation/TestRerankingProvider.java b/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/operation/TestRerankingProvider.java new file mode 100644 index 0000000000..42fa5335e7 --- /dev/null +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/reranking/operation/TestRerankingProvider.java @@ -0,0 +1,74 @@ +package io.stargate.sgv2.jsonapi.service.reranking.operation; + +import io.smallrye.mutiny.Uni; +import io.stargate.sgv2.jsonapi.api.request.RerankingCredentials; +import io.stargate.sgv2.jsonapi.service.provider.ApiModelSupport; +import io.stargate.sgv2.jsonapi.service.provider.ModelProvider; +import io.stargate.sgv2.jsonapi.service.reranking.configuration.RerankingProvidersConfig; +import 
io.stargate.sgv2.jsonapi.service.reranking.configuration.RerankingProvidersConfigImpl; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/** Mock a test reranking provider that returns ranks based on query and passages */ +public class TestRerankingProvider extends RerankingProvider { + + private static final RerankingCredentials RERANK_CREDENTIALS = + new RerankingCredentials("test-tenant", Optional.of("mocked reranking api key")); + + private static final RerankingProvidersConfigImpl.RerankingProviderConfigImpl.ModelConfigImpl + .RequestPropertiesImpl + REQUEST_PROPERTIES = + new RerankingProvidersConfigImpl.RerankingProviderConfigImpl.ModelConfigImpl + .RequestPropertiesImpl(3, 10, 100, 100, 0.5, 10); + + private static final RerankingProvidersConfig.RerankingProviderConfig.ModelConfig MODEL_CONFIG = + new RerankingProvidersConfigImpl.RerankingProviderConfigImpl.ModelConfigImpl( + "testModel", + new ApiModelSupport.ApiModelSupportImpl( + ApiModelSupport.SupportStatus.SUPPORTED, Optional.empty()), + false, + "http://testing.com", + REQUEST_PROPERTIES); + + private static final RerankingProvidersConfigImpl.RerankingProviderConfigImpl PROVIDER_CONFIG = + new RerankingProvidersConfigImpl.RerankingProviderConfigImpl( + false, "test", true, Map.of(), List.of()); + + protected TestRerankingProvider(int maxBatchSize) { + super( + ModelProvider.CUSTOM, + new RerankingProvidersConfigImpl.RerankingProviderConfigImpl.ModelConfigImpl( + "testModel", + new ApiModelSupport.ApiModelSupportImpl( + ApiModelSupport.SupportStatus.SUPPORTED, Optional.empty()), + false, + "http://testing.com", + new RerankingProvidersConfigImpl.RerankingProviderConfigImpl.ModelConfigImpl + .RequestPropertiesImpl(3, 100, 5000, 500, 0.5, maxBatchSize))); + } + + @Override + protected String errorMessageJsonPtr() { + // not used in tests + return ""; + } + + @Override + public Uni rerank( + int batchId, String query, List passages, 
RerankingCredentials rerankCredentials) { + + List ranks = new ArrayList<>(passages.size()); + for (int i = 0; i < passages.size(); i++) { + String passage = passages.get(i); + float score = passage.equals(query) ? 1.0f : (float) Math.random(); // Example scoring logic + ranks.add(new Rank(i, score)); + } + + ranks.sort((o1, o2) -> Float.compare(o2.score(), o1.score())); // Descending order + return Uni.createFrom() + .item( + new BatchedRerankingResponse(batchId, ranks, createEmptyModelUsage(rerankCredentials))); + } +} diff --git a/src/test/resources/application.yaml b/src/test/resources/application.yaml index 9fca7902e8..a1c5c8785a 100644 --- a/src/test/resources/application.yaml +++ b/src/test/resources/application.yaml @@ -13,6 +13,8 @@ quarkus: enabled: false test-port: 9080 log: + console: + format: "%-5p [%t] %d{yyyy-MM-dd HH:mm:ss,SSS} %F:%L - %m%n" category: # production log level for this categoy is DEBUG, way too noisy for tests 'io.stargate':