Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import io.stargate.sgv2.jsonapi.api.model.command.CollectionOnlyCommand;
import io.stargate.sgv2.jsonapi.api.model.command.CommandName;
import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants;
import io.stargate.sgv2.jsonapi.config.constants.RerankingConstants;
import io.stargate.sgv2.jsonapi.config.constants.ServiceDescConstants;
import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1;
import io.stargate.sgv2.jsonapi.service.schema.collections.DocumentPath;
import io.stargate.sgv2.jsonapi.service.schema.naming.NamingRules;
Expand Down Expand Up @@ -276,28 +276,28 @@ public record RerankServiceDesc(
description = "Registered reranking service provider",
type = SchemaType.STRING,
implementation = String.class)
@JsonProperty(RerankingConstants.RerankingService.PROVIDER)
@JsonProperty(ServiceDescConstants.PROVIDER)
String provider,
@Schema(
description = "Registered reranking service model",
type = SchemaType.STRING,
implementation = String.class)
@JsonProperty(RerankingConstants.RerankingService.MODEL_NAME)
@JsonProperty(ServiceDescConstants.MODEL_NAME)
String modelName,
@Valid
@Nullable
@Schema(
description = "Authentication config for chosen reranking service",
type = SchemaType.OBJECT)
@JsonProperty(RerankingConstants.RerankingService.AUTHENTICATION)
@JsonProperty(ServiceDescConstants.AUTHENTICATION)
@JsonInclude(JsonInclude.Include.NON_NULL)
Map<String, String> authentication,
@Nullable
@Schema(
description =
"Optional parameters that match the messageTemplate provided for the reranking provider",
type = SchemaType.OBJECT)
@JsonProperty(RerankingConstants.RerankingService.PARAMETERS)
@JsonProperty(ServiceDescConstants.PARAMETERS)
@JsonInclude(JsonInclude.Include.NON_NULL)
Map<String, Object> parameters) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import com.fasterxml.jackson.annotation.JsonProperty;
import io.stargate.sgv2.jsonapi.config.constants.VectorConstants;
import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1;
import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants;
import io.stargate.sgv2.jsonapi.service.embedding.operation.HuggingFaceDedicatedEmbeddingProvider;
import io.stargate.sgv2.jsonapi.service.provider.ModelProvider;
import jakarta.validation.Valid;
import jakarta.validation.constraints.*;
import java.util.*;
Expand Down Expand Up @@ -48,24 +49,30 @@ public VectorizeConfig(
String modelName,
Map<String, String> authentication,
Map<String, Object> parameters) {

if (provider == null) {
throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException(
"'provider' in required property for 'vector.service' Object value");
}

this.provider = provider;

// HuggingfaceDedicated does not need user to specify model explicitly
// If user specifies modelName other than endpoint-defined-model, will error out
// By default, huggingfaceDedicated provider use endpoint-defined-model as placeholder
if (ProviderConstants.HUGGINGFACE_DEDICATED.equals(provider)) {
if (ModelProvider.HUGGINGFACE_DEDICATED.apiName().equals(provider)) {
if (modelName == null) {
modelName = ProviderConstants.HUGGINGFACE_DEDICATED_DEFINED_MODEL;
} else if (!modelName.equals(ProviderConstants.HUGGINGFACE_DEDICATED_DEFINED_MODEL)) {
modelName =
HuggingFaceDedicatedEmbeddingProvider.HUGGINGFACE_DEDICATED_ENDPOINT_DEFINED_MODEL;
} else if (!modelName.equals(
HuggingFaceDedicatedEmbeddingProvider.HUGGINGFACE_DEDICATED_ENDPOINT_DEFINED_MODEL)) {
throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException(
"'modelName' is not needed for embedding provider %s explicitly, only '%s' is accepted",
ProviderConstants.HUGGINGFACE_DEDICATED,
ProviderConstants.HUGGINGFACE_DEDICATED_DEFINED_MODEL);
ModelProvider.HUGGINGFACE_DEDICATED,
HuggingFaceDedicatedEmbeddingProvider.HUGGINGFACE_DEDICATED_ENDPOINT_DEFINED_MODEL);
}
}

this.modelName = modelName;
if (authentication != null && !authentication.isEmpty()) {
Map<String, String> updatedAuth = new HashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,16 @@
* EmbeddingCredentials is a record that holds the embedding provider credentials for the embedding
* service passed as header.
*
* <p>Includes the tenantID, so we can fully identify the usage when creating the {@link
* io.stargate.sgv2.jsonapi.service.provider.ModelUsage}
*
* @param tenantId - Tenant Id that called the API.
* @param apiKey - API token for the embedding service
* @param accessId - Access Id used for AWS Bedrock embedding service
* @param secretId - Secret Id used for AWS Bedrock embedding service
*/
public record EmbeddingCredentials(
Optional<String> apiKey, Optional<String> accessId, Optional<String> secretId) {}
String tenantId,
Optional<String> apiKey,
Optional<String> accessId,
Optional<String> secretId) {}
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,14 @@ public EmbeddingCredentials create(
&& collectionSupportsNoneAuth) {
var authToken = requestContext.getHttpHeaders().getHeader(this.authTokenHeaderName);
return new EmbeddingCredentials(
Optional.ofNullable(authToken), Optional.empty(), Optional.empty());
requestContext.getTenantId().orElse(""),
Optional.ofNullable(authToken),
Optional.empty(),
Optional.empty());
}

return new EmbeddingCredentials(
requestContext.getTenantId().orElse(""),
Optional.ofNullable(embeddingApi),
Optional.ofNullable(accessId),
Optional.ofNullable(secretId));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.fasterxml.uuid.Generators;
import com.fasterxml.uuid.NoArgGenerator;
import com.google.common.annotations.VisibleForTesting;
import io.stargate.sgv2.jsonapi.api.request.tenant.DataApiTenantResolver;
import io.stargate.sgv2.jsonapi.api.request.token.DataApiTokenResolver;
import io.stargate.sgv2.jsonapi.config.constants.HttpConstants;
Expand Down Expand Up @@ -35,20 +36,25 @@ public class RequestContext {

private final String userAgent;

/**
* Constructor that will be useful in the offline library mode, where only the tenant will be set
* and accessed.
*
* @param tenantId Tenant Id
*/
public RequestContext(Optional<String> tenantId) {
/** FOR TESTING ONLY - so we can bypass pulling things the headers, still messy, getting better */
@VisibleForTesting
public RequestContext(
Optional<String> tenantId,
Optional<String> cassandraToken,
RerankingCredentials rerankingCredentials,
String userAgent) {
this.tenantId = tenantId;
cassandraToken = Optional.empty();
embeddingCredentialsSupplier = null;
rerankingCredentials = null;
httpHeaders = null;
this.cassandraToken = cassandraToken;
embeddingCredentialsSupplier =
new EmbeddingCredentialsSupplier(
HttpConstants.AUTHENTICATION_TOKEN_HEADER_NAME,
HttpConstants.EMBEDDING_AUTHENTICATION_TOKEN_HEADER_NAME,
HttpConstants.EMBEDDING_AUTHENTICATION_ACCESS_ID_HEADER_NAME,
HttpConstants.EMBEDDING_AUTHENTICATION_SECRET_ID_HEADER_NAME);
this.rerankingCredentials = rerankingCredentials;
this.userAgent = userAgent;
this.httpHeaders = new HttpHeaderAccess(io.vertx.core.MultiMap.caseInsensitiveMultiMap());
requestId = generateRequestId();
userAgent = null;
}

@Inject
Expand Down Expand Up @@ -77,11 +83,14 @@ public RequestContext(
HeaderBasedRerankingKeyResolver.resolveRerankingKey(routingContext);
rerankingCredentials =
rerankingApiKeyFromHeader
.map(apiKey -> new RerankingCredentials(Optional.of(apiKey)))
.map(apiKey -> new RerankingCredentials(this.tenantId.orElse(""), Optional.of(apiKey)))
.orElse(
this.cassandraToken
.map(cassandraToken -> new RerankingCredentials(Optional.of(cassandraToken)))
.orElse(new RerankingCredentials(Optional.empty())));
.map(
cassandraToken ->
new RerankingCredentials(
this.tenantId.orElse(""), Optional.of(cassandraToken)))
.orElse(new RerankingCredentials(this.tenantId.orElse(""), Optional.empty())));
}

private static String generateRequestId() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,8 @@
* resolved from the request header 'reranking-api-key', if it is not present, then we will use the
* cassandra token as the reranking api key. Note, both cassandra token and reranking-api-key could
* be absent in Data API request, although it is invalid for authentication.
*
* <p>Includes the tenantId, so we can fully identify the usage when creating the {@link
* io.stargate.sgv2.jsonapi.service.provider.ModelUsage}
*/
public record RerankingCredentials(Optional<String> apiKey) {}
public record RerankingCredentials(String tenantId, Optional<String> apiKey) {}
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ public Uni<RestResponse<CommandResult>> postCommand(

if (vectorColDef != null && vectorColDef.vectorizeDefinition() != null) {
embeddingProvider =
embeddingProviderFactory.getConfiguration(
embeddingProviderFactory.create(
requestContext.getTenantId(),
requestContext.getCassandraToken(),
vectorColDef.vectorizeDefinition().provider(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,4 @@ interface CollectionRerankingOptions {
String ENABLED = "enabled";
String SERVICE = ServiceDescConstants.SERVICE;
}

interface RerankingService extends ServiceDescConstants {}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package io.stargate.sgv2.jsonapi.config.constants;

/** Common service description constants shared between vector and reranking */
interface ServiceDescConstants {
public interface ServiceDescConstants {
String SERVICE = "service";
String PROVIDER = "provider";
String MODEL_NAME = "modelName";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ public DataVectorizer(
SchemaObject schemaObject) {
this.embeddingProvider = embeddingProvider;
this.nodeFactory = nodeFactory;
this.embeddingCredentials = embeddingCredentials;
this.embeddingCredentials =
Objects.requireNonNull(embeddingCredentials, "embeddingCredentials must not be null");
this.schemaObject = schemaObject;
}

Expand Down Expand Up @@ -175,7 +176,7 @@ public Uni<float[]> vectorize(String vectorizeContent) {
List.of(vectorizeContent),
embeddingCredentials,
EmbeddingProvider.EmbeddingRequestType.INDEX)
.map(EmbeddingProvider.Response::embeddings);
.map(EmbeddingProvider.BatchedEmbeddingResponse::embeddings);
return vectors
.onItem()
.transform(
Expand Down Expand Up @@ -303,7 +304,7 @@ private Uni<List<float[]>> vectorizeTexts(

return embeddingProvider
.vectorize(1, textsToVectorize, embeddingCredentials, requestType)
.map(EmbeddingProvider.Response::embeddings)
.map(EmbeddingProvider.BatchedEmbeddingResponse::embeddings)
.onItem()
.transform(
vectorData -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ public <T extends SchemaObject> Uni<Command> vectorize(

public <T extends SchemaObject> DataVectorizer constructDataVectorizer(
CommandContext<T> commandContext) {

EmbeddingProvider embeddingProvider =
Optional.ofNullable(commandContext.embeddingProvider())
.map(
Expand All @@ -83,6 +84,7 @@ public <T extends SchemaObject> DataVectorizer constructDataVectorizer(
provider,
commandContext.commandName()))
.orElse(null);

return new DataVectorizer(
embeddingProvider,
objectMapper.getNodeFactory(),
Expand All @@ -91,7 +93,7 @@ public <T extends SchemaObject> DataVectorizer constructDataVectorizer(
.getEmbeddingCredentialsSupplier()
.create(
commandContext.requestContext(),
embeddingProvider == null ? null : embeddingProvider.getProviderConfig()),
embeddingProvider == null ? null : embeddingProvider.providerConfig()),
commandContext.schemaObject());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
*/
@ConfigMapping(prefix = "stargate.jsonapi.embedding")
public interface DefaultEmbeddingProviderConfig {

// TODO: WHAT DOES THIS ACTUALLY RETURN ?
@Nullable
Map<String, EmbeddingProvidersConfig.EmbeddingProviderConfig> providers();
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
import java.util.Optional;
import org.eclipse.microprofile.config.spi.Converter;

// TODO: SOME DOCUMENTATION FOR WHAT THIS IS MEANT TO DO!!!
public interface EmbeddingProvidersConfig {

// TODO: WHAT IS THE KEY FOR THIS MAP ?????
Map<String, EmbeddingProviderConfig> providers();

@Nullable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import java.util.*;
import java.util.stream.Collectors;

// TODO: SOME DOCUMENTATION FOR WHAT THIS IS MEANT TO DO!!!
public record EmbeddingProvidersConfigImpl(
Map<String, EmbeddingProviderConfig> providers, CustomConfig custom)
implements EmbeddingProvidersConfig {
Expand Down
Loading