Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion cmd/epp/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,6 @@ func (r *Runner) registerInTreePlugins() {
plugins.Register(scorer.LoraAffinityScorerType, scorer.LoraAffinityScorerFactory)
// Latency predictor plugins
plugins.Register(slo_aware_router.SLOAwareRouterPluginType, slo_aware_router.SLOAwareRouterFactory)
plugins.Register(slo_aware_router.SLOAwareProfileHandlerType, slo_aware_router.SLOAwareProfileHandlerFactory)
// register filter for test purpose only (used in conformance tests)
plugins.Register(testfilter.HeaderBasedTestingFilterType, testfilter.HeaderBasedTestingFilterFactory)
// register response received plugin for test purpose only (used in conformance tests)
Expand Down
17 changes: 4 additions & 13 deletions config/charts/inferencepool/templates/epp-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ data:
- type: predicted-latency-scorer
parameters:
{{- with .Values.inferenceExtension.latencyPredictor.sloAwareRouting | default dict }}
samplingMean: {{ .samplingMean | default 100.0 }}
samplingMean: {{ .samplingMean | default 1000.0 }}
maxSampledTokens: {{ .maxSampledTokens | default 20 }}
sloBufferFactor: {{ .sloBufferFactor | default 1.0 }}
negHeadroomTTFTWeight: {{ .negHeadroomTTFTWeight | default 0.8 }}
Expand All @@ -32,23 +32,14 @@ data:
affinityGateTauGlobal: {{ .affinityGateTauGlobal | default 0.99 }}
selectionMode: {{ .selectionMode | default "linear" | quote }}
{{- end }}
- type: predicted-latency-profile-handler
{{- end }}
schedulingProfiles:
{{- if .Values.inferenceExtension.latencyPredictor.enabled }}
- name: predicted-latency-prefix
plugins:
- pluginRef: prefix-cache-scorer
- name: predicted-latency-no-routing
plugins:
- pluginRef: prefix-cache-scorer
- pluginRef: predicted-latency-scorer
weight: 0
- pluginRef: queue-scorer
- pluginRef: kv-cache-utilization-scorer
- name: predicted-latency-routing
- name: default
plugins:
- pluginRef: predicted-latency-scorer
featureGates:
- prepareDataPlugins
{{- else }}
- name: default
plugins:
Expand Down
4 changes: 4 additions & 0 deletions pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ func (p *Plugin) PrepareRequestData(ctx context.Context, request *types.LLMReque
matchLen := state.PrefixCacheServers[ServerID(pod.GetPod().NamespacedName)]
pod.Put(approximateprefix.PrefixCacheMatchInfoKey, approximateprefix.NewPrefixCacheMatchInfo(matchLen, total))
}
// Store the state in plugin state for later use.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need to change anything in the prefix plugin?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The prefix scorer does not consume the prefix state in the same way as the predicted latency one?

p.pluginState.Write(request.RequestId, plugins.StateKey(p.TypedName().String()), state)
return nil
}

Expand All @@ -241,6 +243,8 @@ func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques
}

cycleState.Write(plugins.StateKey(p.TypedName().String()), state)

// store the state in plugin state for later use in PreRequest. This may go away once we default to prepare request data plugin hook.
p.pluginState.Write(request.RequestId, plugins.StateKey(p.TypedName().String()), state)
log.FromContext(ctx).V(logutil.TRACE).Info("prefix cached state", "cached-servers", state.PrefixCacheServers, "hashes", state.PrefixHashes)
// calculate the scores of pods
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,3 @@ func parseFloatHeader(request schedulingtypes.LLMRequest, headerName string) (fl
// 3. Return the successfully parsed value
return parsedFloat, nil
}

// hasHeader reports whether headerName is present as a key in the request's
// Headers map. Only key presence is checked; the header value is ignored.
func hasHeader(request schedulingtypes.LLMRequest, headerName string) bool {
	if _, present := request.Headers[headerName]; present {
		return true
	}
	return false
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type podPredictionResult struct {
}

// generatePredictions creates prediction results for all candidate pods
func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, sloCtx *sloRequestContext, candidatePods []schedulingtypes.Pod) ([]podPredictionResult, error) {
func (s *SLOAwareRouter) generatePredictions(ctx context.Context, request *schedulingtypes.LLMRequest, sloCtx *sloRequestContext, candidatePods []schedulingtypes.Pod) ([]podPredictionResult, error) {
logger := log.FromContext(ctx)
predictions := make([]podPredictionResult, 0, len(candidatePods))

Expand All @@ -55,7 +55,7 @@ func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedul
logger.V(logutil.TRACE).Info("Candidate pod for scheduling", "pod", pod.GetPod().String(), "metrics", pod.GetMetrics().String())

// Get prefix cache score for the pod
prefixCacheScore := s.getPrefixCacheScoreForPod(ctx, state, pod)
prefixCacheScore := sloCtx.prefixCacheScoresForPods[pod.GetPod().String()]
sloCtx.prefixCacheScoresForPods[pod.GetPod().String()] = prefixCacheScore

logger.V(logutil.DEBUG).Info("Prefix cache score for pod", "pod", pod.GetPod().String(), "prefixCacheScore", prefixCacheScore)
Expand Down Expand Up @@ -108,19 +108,7 @@ func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedul

// updateRequestContextWithPredictions updates the request context with prediction data
func (s *SLOAwareRouter) updateRequestContextWithPredictions(sloCtx *sloRequestContext, predictions []podPredictionResult) {
for _, pred := range predictions {
if pred.Error == nil {
podKey := pred.Pod.GetPod().String()
if sloCtx.predictedTTFTForScheduling == nil {
sloCtx.predictedTTFTForScheduling = make(map[string]float64)
}
if sloCtx.predictedTPOTForScheduling == nil {
sloCtx.predictedTPOTForScheduling = make(map[string]float64)
}
sloCtx.predictedTTFTForScheduling[podKey] = pred.TTFT
sloCtx.predictedTPOTForScheduling[podKey] = pred.TPOT
}
}
sloCtx.predictionsForScheduling = predictions
}

func (s *SLOAwareRouter) validatePrediction(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
Copyright 2025 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package slo_aware_router

import (
"context"
"math"

"sigs.k8s.io/controller-runtime/pkg/log"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/plugins/approximateprefix"
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)

// PrepareRequestData prepares the SLO context for the request: it parses the SLO
// headers and records a prefix cache score for every candidate pod.
//
// A pod's score is its matched prefix length divided by the total length, both read
// from the pod attribute stored under approximateprefix.PrefixCacheMatchInfoKey.
// The score falls back to 0 when the attribute is missing, is not a non-nil
// *approximateprefix.PrefixCacheMatchInfo, or the ratio is not a finite number
// (NaN for 0/0, +Inf for a non-zero match over a zero total).
//
// It always returns nil.
func (s *SLOAwareRouter) PrepareRequestData(ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error {
	logger := log.FromContext(ctx)
	sloCtx := s.getOrMakeSLORequestContext(request)

	s.parseSLOHeaders(ctx, request, sloCtx)

	for _, pod := range pods {
		podName := pod.GetPod().String()
		prefixCacheScore := 0.0

		prefixCacheInfoRaw, found := pod.Get(approximateprefix.PrefixCacheMatchInfoKey)
		// Comma-ok assertion: an attribute of an unexpected type must not panic the
		// request path. Asserting on a nil interface simply yields ok == false.
		prefixCacheInfo, typed := prefixCacheInfoRaw.(*approximateprefix.PrefixCacheMatchInfo)

		switch {
		case !found:
			logger.V(logutil.DEBUG).Info("No prefix cache score found in pod attribute, defaulting to 0", "pod", podName)
		case !typed || prefixCacheInfo == nil:
			logger.V(logutil.DEBUG).Info("Prefix cache match info has unexpected type or is nil, defaulting to 0", "pod", podName)
		default:
			ratio := float64(prefixCacheInfo.MatchLength()) / float64(prefixCacheInfo.TotalLength())
			switch {
			case math.IsNaN(ratio):
				logger.V(logutil.DEBUG).Info("Prefix cache score is NaN, defaulting to 0", "pod", podName)
			case math.IsInf(ratio, 0):
				// A zero total with a non-zero match divides to Inf, not NaN.
				logger.V(logutil.DEBUG).Info("Prefix cache score is infinite, defaulting to 0", "pod", podName)
			default:
				prefixCacheScore = ratio
				logger.V(logutil.DEBUG).Info("Found prefix cache score in pod attribute", "pod", podName, "score", prefixCacheScore)
			}
		}

		sloCtx.prefixCacheScoresForPods[podName] = prefixCacheScore
	}

	return nil
}

// Produces returns the data keys this plugin publishes. The map is empty and
// non-nil: the plugin declares no produced data.
func (s *SLOAwareRouter) Produces() map[string]any {
	return make(map[string]any)
}

// Consumes returns the data keys this plugin reads. It contains a single entry
// mapping approximateprefix.PrefixCacheMatchInfoKey to a zero-value
// approximateprefix.PrefixCacheMatchInfo.
//
// NOTE(review): the value here is a non-pointer PrefixCacheMatchInfo while
// PrepareRequestData reads the attribute as *PrefixCacheMatchInfo; confirm which
// form the framework expects.
func (s *SLOAwareRouter) Consumes() map[string]any {
	consumed := make(map[string]any, 1)
	consumed[approximateprefix.PrefixCacheMatchInfoKey] = approximateprefix.PrefixCacheMatchInfo{}
	return consumed
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,24 +64,19 @@ type sloRequestContext struct {
// avgTPOTSLO is the target time per output token SLO for the request.
avgTPOTSLO float64

// predictorBasedScheduling indicates whether to use predictor based scheduling.
predictorBasedScheduling bool
// predictedTTFTForScheduling is the map of pod names to predicted TTFT values for scheduling.
predictedTTFTForScheduling map[string]float64
// predictedTPOTForScheduling is the map of pod names to predicted TPOT values for scheduling.
predictedTPOTForScheduling map[string]float64
predictionsForScheduling []podPredictionResult

// boolean set if request has valid pod based on predictions
hasValidPod bool
}

func newSLORequestContext(request *schedulingtypes.LLMRequest) *sloRequestContext {
return &sloRequestContext{
schedulingRequest: *request,
lastSeenMetrics: make(map[string]*backendmetrics.MetricsState),
prefixCacheScoresForPods: make(map[string]float64),
predictedTTFTForScheduling: make(map[string]float64),
predictedTPOTForScheduling: make(map[string]float64),
schedulingRequest: *request,
lastSeenMetrics: make(map[string]*backendmetrics.MetricsState),
prefixCacheScoresForPods: make(map[string]float64),
predictionsForScheduling: make([]podPredictionResult, 0),
}
}

Expand Down Expand Up @@ -245,8 +240,6 @@ func (t *SLOAwareRouter) ResponseComplete(ctx context.Context, request *scheduli
}
}

logger.V(logutil.TRACE).Info("SLO Aware Routing Mode", "PredictorBasedScheduling", sloCtx.predictorBasedScheduling)

podName := types.NamespacedName{
Name: targetPod.NamespacedName.Name,
Namespace: targetPod.NamespacedName.Namespace,
Expand Down
Loading