diff --git a/cmd/epp/runner/runner.go b/cmd/epp/runner/runner.go index b2fb299c7..defc51517 100644 --- a/cmd/epp/runner/runner.go +++ b/cmd/epp/runner/runner.go @@ -401,7 +401,6 @@ func (r *Runner) registerInTreePlugins() { plugins.Register(scorer.LoraAffinityScorerType, scorer.LoraAffinityScorerFactory) // Latency predictor plugins plugins.Register(slo_aware_router.SLOAwareRouterPluginType, slo_aware_router.SLOAwareRouterFactory) - plugins.Register(slo_aware_router.SLOAwareProfileHandlerType, slo_aware_router.SLOAwareProfileHandlerFactory) // register filter for test purpose only (used in conformance tests) plugins.Register(testfilter.HeaderBasedTestingFilterType, testfilter.HeaderBasedTestingFilterFactory) // register response received plugin for test purpose only (used in conformance tests) diff --git a/config/charts/inferencepool/templates/epp-config.yaml b/config/charts/inferencepool/templates/epp-config.yaml index f34d5cf21..6f947a929 100644 --- a/config/charts/inferencepool/templates/epp-config.yaml +++ b/config/charts/inferencepool/templates/epp-config.yaml @@ -15,7 +15,7 @@ data: - type: predicted-latency-scorer parameters: {{- with .Values.inferenceExtension.latencyPredictor.sloAwareRouting | default dict }} - samplingMean: {{ .samplingMean | default 100.0 }} + samplingMean: {{ .samplingMean | default 1000.0 }} maxSampledTokens: {{ .maxSampledTokens | default 20 }} sloBufferFactor: {{ .sloBufferFactor | default 1.0 }} negHeadroomTTFTWeight: {{ .negHeadroomTTFTWeight | default 0.8 }} @@ -32,23 +32,14 @@ data: affinityGateTauGlobal: {{ .affinityGateTauGlobal | default 0.99 }} selectionMode: {{ .selectionMode | default "linear" | quote }} {{- end }} - - type: predicted-latency-profile-handler {{- end }} schedulingProfiles: {{- if .Values.inferenceExtension.latencyPredictor.enabled }} - - name: predicted-latency-prefix - plugins: - - pluginRef: prefix-cache-scorer - - name: predicted-latency-no-routing - plugins: - - pluginRef: prefix-cache-scorer - - 
pluginRef: predicted-latency-scorer - weight: 0 - - pluginRef: queue-scorer - - pluginRef: kv-cache-utilization-scorer - - name: predicted-latency-routing + - name: default plugins: - pluginRef: predicted-latency-scorer + featureGates: + - prepareDataPlugins {{- else }} - name: default plugins: diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go index fd03e4bbb..f0f44d51e 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go @@ -228,6 +228,8 @@ func (p *Plugin) PrepareRequestData(ctx context.Context, request *types.LLMReque matchLen := state.PrefixCacheServers[ServerID(pod.GetPod().NamespacedName)] pod.Put(approximateprefix.PrefixCacheMatchInfoKey, approximateprefix.NewPrefixCacheMatchInfo(matchLen, total)) } + // Store the state in plugin state for later use. + p.pluginState.Write(request.RequestId, plugins.StateKey(p.TypedName().String()), state) return nil } @@ -241,6 +243,8 @@ func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques } cycleState.Write(plugins.StateKey(p.TypedName().String()), state) + + // store the state in plugin state for later use in PreRequest. This may go away once we default to prepare request data plugin hook. 
p.pluginState.Write(request.RequestId, plugins.StateKey(p.TypedName().String()), state) log.FromContext(ctx).V(logutil.TRACE).Info("prefix cached state", "cached-servers", state.PrefixCacheServers, "hashes", state.PrefixHashes) // calculate the scores of pods diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/headers.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/headers.go index 4ef23ca94..a89581c1b 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/headers.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/headers.go @@ -45,9 +45,3 @@ func parseFloatHeader(request schedulingtypes.LLMRequest, headerName string) (fl // 3. Return the successfully parsed value return parsedFloat, nil } - -// hasHeader checks if a header key exists in the request headers map. -func hasHeader(request schedulingtypes.LLMRequest, headerName string) bool { - _, ok := request.Headers[headerName] - return ok -} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go index f31932538..ddb59c629 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go @@ -41,7 +41,7 @@ type podPredictionResult struct { } // generatePredictions creates prediction results for all candidate pods -func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, sloCtx *sloRequestContext, candidatePods []schedulingtypes.Pod) ([]podPredictionResult, error) { +func (s *SLOAwareRouter) generatePredictions(ctx context.Context, request *schedulingtypes.LLMRequest, sloCtx *sloRequestContext, candidatePods []schedulingtypes.Pod) ([]podPredictionResult, error) { logger := log.FromContext(ctx) predictions := make([]podPredictionResult, 0, len(candidatePods)) @@ 
-55,7 +55,7 @@ func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedul logger.V(logutil.TRACE).Info("Candidate pod for scheduling", "pod", pod.GetPod().String(), "metrics", pod.GetMetrics().String()) // Get prefix cache score for the pod - prefixCacheScore := s.getPrefixCacheScoreForPod(ctx, state, pod) + prefixCacheScore := sloCtx.prefixCacheScoresForPods[pod.GetPod().String()] sloCtx.prefixCacheScoresForPods[pod.GetPod().String()] = prefixCacheScore logger.V(logutil.DEBUG).Info("Prefix cache score for pod", "pod", pod.GetPod().String(), "prefixCacheScore", prefixCacheScore) @@ -108,19 +108,7 @@ func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedul // updateRequestContextWithPredictions updates the request context with prediction data func (s *SLOAwareRouter) updateRequestContextWithPredictions(sloCtx *sloRequestContext, predictions []podPredictionResult) { - for _, pred := range predictions { - if pred.Error == nil { - podKey := pred.Pod.GetPod().String() - if sloCtx.predictedTTFTForScheduling == nil { - sloCtx.predictedTTFTForScheduling = make(map[string]float64) - } - if sloCtx.predictedTPOTForScheduling == nil { - sloCtx.predictedTPOTForScheduling = make(map[string]float64) - } - sloCtx.predictedTTFTForScheduling[podKey] = pred.TTFT - sloCtx.predictedTPOTForScheduling[podKey] = pred.TPOT - } - } + sloCtx.predictionsForScheduling = predictions } func (s *SLOAwareRouter) validatePrediction( diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/preparedata_hooks.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/preparedata_hooks.go new file mode 100644 index 000000000..a49e42187 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/preparedata_hooks.go @@ -0,0 +1,64 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package slo_aware_router + +import ( + "context" + "math" + + "sigs.k8s.io/controller-runtime/pkg/log" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/plugins/approximateprefix" + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +// PrepareRequestData prepares the SLO context for the request, including parsing SLO headers and gathering prefix cache scores. +func (s *SLOAwareRouter) PrepareRequestData(ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error { + logger := log.FromContext(ctx) + sloCtx := s.getOrMakeSLORequestContext(request) + + s.parseSLOHeaders(ctx, request, sloCtx) + var prefixCacheScore float64 + for _, pod := range pods { + + if prefixCacheInfoRaw, ok := pod.Get(approximateprefix.PrefixCacheMatchInfoKey); ok { + prefixCacheInfo := prefixCacheInfoRaw.(*approximateprefix.PrefixCacheMatchInfo) + prefixCacheScore = float64(prefixCacheInfo.MatchLength()) / float64(prefixCacheInfo.TotalLength()) + if !math.IsNaN(prefixCacheScore) { + logger.V(logutil.DEBUG).Info("Found prefix cache score in pod attribute", "pod", pod.GetPod().String(), "score", prefixCacheScore) + } else { + prefixCacheScore = 0.0 + logger.V(logutil.DEBUG).Info("Prefix cache score is NaN, defaulting to 0", "pod", pod.GetPod().String()) + } + } else { + logger.V(logutil.DEBUG).Info("No prefix cache score found in pod attribute, defaulting to 0", "pod", pod.GetPod().String()) + 
prefixCacheScore = 0.0 + } + sloCtx.prefixCacheScoresForPods[pod.GetPod().String()] = prefixCacheScore + } + + return nil +} + +func (p *SLOAwareRouter) Produces() map[string]any { + return map[string]any{} +} + +func (p *SLOAwareRouter) Consumes() map[string]any { + return map[string]any{approximateprefix.PrefixCacheMatchInfoKey: approximateprefix.PrefixCacheMatchInfo{}} +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go index d91aac2cf..042998c33 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go @@ -64,12 +64,8 @@ type sloRequestContext struct { // TPOTSLO is the target time per output token SLO for the request. avgTPOTSLO float64 - // predictorBasedScheduling indicates whether to use predictor based scheduling. - predictorBasedScheduling bool // predictedTTFTForScheduling is the map of pod names to predicted TTFT values for scheduling. - predictedTTFTForScheduling map[string]float64 - // predictedTPOTForScheduling is the map of pod names to predicted TPOT values for scheduling. 
- predictedTPOTForScheduling map[string]float64 + predictionsForScheduling []podPredictionResult // boolean set if request has valid pod based on predictions hasValidPod bool @@ -77,11 +73,10 @@ type sloRequestContext struct { func newSLORequestContext(request *schedulingtypes.LLMRequest) *sloRequestContext { return &sloRequestContext{ - schedulingRequest: *request, - lastSeenMetrics: make(map[string]*backendmetrics.MetricsState), - prefixCacheScoresForPods: make(map[string]float64), - predictedTTFTForScheduling: make(map[string]float64), - predictedTPOTForScheduling: make(map[string]float64), + schedulingRequest: *request, + lastSeenMetrics: make(map[string]*backendmetrics.MetricsState), + prefixCacheScoresForPods: make(map[string]float64), + predictionsForScheduling: make([]podPredictionResult, 0), } } @@ -245,8 +240,6 @@ func (t *SLOAwareRouter) ResponseComplete(ctx context.Context, request *scheduli } } - logger.V(logutil.TRACE).Info("SLO Aware Routing Mode", "PredictorBasedScheduling", sloCtx.predictorBasedScheduling) - podName := types.NamespacedName{ Name: targetPod.NamespacedName.Name, Namespace: targetPod.NamespacedName.Namespace, diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go index ac7344e30..60d3ba06b 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go @@ -69,7 +69,7 @@ func createTestRouter() *SLOAwareRouter { // Test cases func TestNewSLORequestContext(t *testing.T) { - request := createTestLLMRequest("test", 100, 50, true) + request := createTestLLMRequest("test", 100, 50) ctx := newSLORequestContext(request) @@ -77,15 +77,14 @@ func TestNewSLORequestContext(t *testing.T) { assert.Equal(t, *request, ctx.schedulingRequest) assert.NotNil(t, ctx.lastSeenMetrics) 
assert.NotNil(t, ctx.prefixCacheScoresForPods) - assert.NotNil(t, ctx.predictedTTFTForScheduling) - assert.NotNil(t, ctx.predictedTPOTForScheduling) + assert.NotNil(t, ctx.predictionsForScheduling) assert.Empty(t, ctx.lastSeenMetrics) assert.Empty(t, ctx.prefixCacheScoresForPods) } func TestSLOAwareRouter_SetAndGetSLOContext(t *testing.T) { router := createTestRouter() - request := createTestLLMRequest("test", 100, 50, true) + request := createTestLLMRequest("test", 100, 50) sloCtx := newSLORequestContext(request) // Set context @@ -100,7 +99,7 @@ func TestSLOAwareRouter_SetAndGetSLOContext(t *testing.T) { func TestSLOAwareRouter_GetSLOContext_NotFound(t *testing.T) { router := createTestRouter() - request := createTestLLMRequest("test", 100, 50, true) + request := createTestLLMRequest("test", 100, 50) // Try to get context that doesn't exist ctx, err := router.getSLOContextForRequest(request) @@ -112,7 +111,7 @@ func TestSLOAwareRouter_GetSLOContext_NotFound(t *testing.T) { func TestSLOAwareRouter_DeleteSLOContext(t *testing.T) { router := createTestRouter() - request := createTestLLMRequest("test", 100, 50, true) + request := createTestLLMRequest("test", 100, 50) sloCtx := newSLORequestContext(request) // Set and then delete context @@ -128,7 +127,7 @@ func TestSLOAwareRouter_DeleteSLOContext(t *testing.T) { func TestSLOAwareRouter_PreRequest_NoSchedulingResult(t *testing.T) { router := createTestRouter() ctx := context.Background() - request := createTestLLMRequest("test", 100, 50, true) + request := createTestLLMRequest("test", 100, 50) // Call PreRequest with nil scheduling result router.PreRequest(ctx, request, nil) @@ -141,7 +140,7 @@ func TestSLOAwareRouter_PreRequest_NoSchedulingResult(t *testing.T) { func TestSLOAwareRouter_PreRequest_EmptySchedulingResult(t *testing.T) { router := createTestRouter() ctx := context.Background() - request := createTestLLMRequest("test", 100, 50, true) + request := createTestLLMRequest("test", 100, 50) schedulingResult := 
&schedulingtypes.SchedulingResult{ ProfileResults: map[string]*schedulingtypes.ProfileRunResult{}, @@ -162,7 +161,7 @@ func TestSLOAwareRouter_PreRequest_Success(t *testing.T) { ctx := context.Background() pod := createTestPod("test-pod", 1, 1, 1) - request := createTestLLMRequest("test", 100, 50, true) + request := createTestLLMRequest("test", 100, 50) schedulingResult := createTestSchedulingResult(pod.GetPod()) // Create and set initial SLO context @@ -195,7 +194,7 @@ func TestSLOAwareRouter_PreRequest_AddsToQueue(t *testing.T) { ctx := context.Background() pod := createTestPod("test-pod", 1, 1, 1) - request := createTestLLMRequest("test", 100, 50, true) + request := createTestLLMRequest("test", 100, 50) schedulingResult := createTestSchedulingResult(pod.GetPod()) // Create and set initial SLO context @@ -219,8 +218,8 @@ func TestSLOAwareRouter_PreRequest_QueueAlreadyExists(t *testing.T) { ctx := context.Background() pod := createTestPod("test-pod", 1, 1, 1) - request1 := createTestLLMRequest("test-id-1", 100, 50, true) - request2 := createTestLLMRequest("test-id-2", 100, 50, true) + request1 := createTestLLMRequest("test-id-1", 100, 50) + request2 := createTestLLMRequest("test-id-2", 100, 50) schedulingResult := createTestSchedulingResult(pod.GetPod()) // Create and set initial SLO contexts @@ -250,7 +249,7 @@ func TestSLOAwareRouter_ResponseReceived_NilPredictor(t *testing.T) { ctx := context.Background() pod := createTestPod("test-pod", 1, 1, 1) - request := createTestLLMRequest("test", 100, 50, true) + request := createTestLLMRequest("test", 100, 50) response := &requestcontrol.Response{} sloCtx := newSLORequestContext(request) @@ -270,7 +269,7 @@ func TestSLOAwareRouter_ResponseReceived_NoPod(t *testing.T) { router.latencypredictor = mockPredictor ctx := context.Background() - request := createTestLLMRequest("test", 100, 50, true) + request := createTestLLMRequest("test", 100, 50) response := &requestcontrol.Response{} sloCtx := newSLORequestContext(request) 
@@ -290,7 +289,7 @@ func TestSLOAwareRouter_ResponseReceived_NoContext(t *testing.T) { ctx := context.Background() pod := createTestPod("test-pod", 1, 1, 1) - request := createTestLLMRequest("test", 100, 50, true) + request := createTestLLMRequest("test", 100, 50) response := &requestcontrol.Response{} // Don't set SLO context @@ -306,7 +305,7 @@ func TestSLOAwareRouter_ResponseStreaming_NilPredictor(t *testing.T) { ctx := context.Background() pod := createTestPod("test-pod", 1, 1, 1) - request := createTestLLMRequest("test", 100, 50, true) + request := createTestLLMRequest("test", 100, 50) response := &requestcontrol.Response{} sloCtx := newSLORequestContext(request) @@ -326,7 +325,7 @@ func TestSLOAwareRouter_ResponseStreaming_FirstToken(t *testing.T) { ctx := context.Background() pod := createTestPod("test-pod", 1, 1, 1) - request := createTestLLMRequest("test", 100, 50, true) + request := createTestLLMRequest("test", 100, 50) response := &requestcontrol.Response{} schedulingResult := createTestSchedulingResult(pod.GetPod()) @@ -377,7 +376,7 @@ func TestSLOAwareRouter_ResponseStreaming_SubsequentTokens(t *testing.T) { ctx := context.Background() pod := createTestPod("test-pod", 1, 1, 1) - request := createTestLLMRequest("test", 100, 50, true) + request := createTestLLMRequest("test", 100, 50) response := &requestcontrol.Response{} schedulingResult := createTestSchedulingResult(pod.GetPod()) @@ -425,7 +424,7 @@ func TestSLOAwareRouter_ResponseComplete_QueueNotFound(t *testing.T) { ctx := context.Background() pod := createTestPod("test-pod", 1, 1, 1) - request := createTestLLMRequest("test", 100, 50, true) + request := createTestLLMRequest("test", 100, 50) response := &requestcontrol.Response{} sloCtx := newSLORequestContext(request) @@ -450,7 +449,7 @@ func TestSLOAwareRouter_ResponseStreaming_NoContext(t *testing.T) { ctx := context.Background() pod := createTestPod("test-pod", 1, 1, 1) - request := createTestLLMRequest("test", 100, 50, true) + request := 
createTestLLMRequest("test", 100, 50) response := &requestcontrol.Response{} // Don't set SLO context - should handle gracefully @@ -467,7 +466,7 @@ func TestSLOAwareRouter_ResponseComplete_Success(t *testing.T) { ctx := context.Background() pod := createTestPod("test-pod", 1, 1, 1) - request := createTestLLMRequest("test", 100, 50, true) + request := createTestLLMRequest("test", 100, 50) response := &requestcontrol.Response{} // Create queue and add request @@ -501,7 +500,7 @@ func TestSLOAwareRouter_ResponseComplete_NilPredictor(t *testing.T) { ctx := context.Background() pod := createTestPod("test-pod", 1, 1, 1) - request := createTestLLMRequest("test", 100, 50, true) + request := createTestLLMRequest("test", 100, 50) response := &requestcontrol.Response{} sloCtx := newSLORequestContext(request) @@ -521,7 +520,7 @@ func TestSLOAwareRouter_ResponseComplete_NoPod(t *testing.T) { router.latencypredictor = mockPredictor ctx := context.Background() - request := createTestLLMRequest("test", 100, 50, true) + request := createTestLLMRequest("test", 100, 50) response := &requestcontrol.Response{} sloCtx := newSLORequestContext(request) @@ -542,7 +541,7 @@ func TestSLOAwareRouter_ResponseComplete_NoContext(t *testing.T) { ctx := context.Background() pod := createTestPod("test-pod", 1, 1, 1) - request := createTestLLMRequest("test", 100, 50, true) + request := createTestLLMRequest("test", 100, 50) response := &requestcontrol.Response{} // Don't set SLO context - should handle gracefully @@ -559,7 +558,7 @@ func TestSLOAwareRouter_ResponseComplete_WithMetrics(t *testing.T) { ctx := context.Background() pod := createTestPod("test-pod", 1, 1, 1) - request := createTestLLMRequest("test", 100, 50, true) + request := createTestLLMRequest("test", 100, 50) response := &requestcontrol.Response{} // Create queue @@ -592,7 +591,7 @@ func TestSLOAwareRouter_ResponseComplete_NoSLOs(t *testing.T) { ctx := context.Background() pod := createTestPod("test-pod", 1, 1, 1) - request := 
createTestLLMRequest("test-id", 0, 0, true) // No SLOs + request := createTestLLMRequest("test-id", 0, 0) // No SLOs response := &requestcontrol.Response{} // Create queue @@ -647,14 +646,13 @@ func TestSLOAwareRouter_CheckPredictor_Success(t *testing.T) { } func TestSLORequestContext_Fields(t *testing.T) { - request := createTestLLMRequest("test", 100, 50, true) + request := createTestLLMRequest("test", 100, 50) ctx := newSLORequestContext(request) // Test all field initialization assert.NotNil(t, ctx.lastSeenMetrics) assert.NotNil(t, ctx.prefixCacheScoresForPods) - assert.NotNil(t, ctx.predictedTTFTForScheduling) - assert.NotNil(t, ctx.predictedTPOTForScheduling) + assert.NotNil(t, ctx.predictionsForScheduling) assert.Empty(t, ctx.tpotObservations) assert.Empty(t, ctx.predictedTPOTObservations) assert.Zero(t, ctx.generatedTokenCount) @@ -666,7 +664,7 @@ func TestSLORequestContext_Fields(t *testing.T) { } func TestSLORequestContext_UpdateMetrics(t *testing.T) { - request := createTestLLMRequest("test", 100, 50, true) + request := createTestLLMRequest("test", 100, 50) ctx := newSLORequestContext(request) // Add some metrics @@ -682,23 +680,22 @@ func TestSLORequestContext_UpdateMetrics(t *testing.T) { } func TestSLORequestContext_PredictionData(t *testing.T) { - request := createTestLLMRequest("test", 100, 50, true) + request := createTestLLMRequest("test", 100, 50) ctx := newSLORequestContext(request) + ctx.predictionsForScheduling = make([]podPredictionResult, 0) + // Set prediction data - ctx.predictedTTFTForScheduling["pod1"] = 80.0 - ctx.predictedTPOTForScheduling["pod1"] = 30.0 - ctx.predictedTTFTForScheduling["pod2"] = 90.0 - ctx.predictedTPOTForScheduling["pod2"] = 35.0 + ctx.predictionsForScheduling = append(ctx.predictionsForScheduling, podPredictionResult{Pod: createTestPod("pod1", 0, 0, 0), TTFT: 80.0, TPOT: 25.0}) + ctx.predictionsForScheduling = append(ctx.predictionsForScheduling, podPredictionResult{Pod: createTestPod("pod1", 0, 0, 0), TPOT: 30.0, 
TTFT: 85.0}) - assert.Len(t, ctx.predictedTTFTForScheduling, 2) - assert.Len(t, ctx.predictedTPOTForScheduling, 2) - assert.Equal(t, 80.0, ctx.predictedTTFTForScheduling["pod1"]) - assert.Equal(t, 30.0, ctx.predictedTPOTForScheduling["pod1"]) + assert.Len(t, ctx.predictionsForScheduling, 2) + assert.Equal(t, 80.0, ctx.predictionsForScheduling[0].TTFT) + assert.Equal(t, 30.0, ctx.predictionsForScheduling[1].TPOT) } func TestSLORequestContext_PrefixCacheScores(t *testing.T) { - request := createTestLLMRequest("test", 100, 50, true) + request := createTestLLMRequest("test", 100, 50) ctx := newSLORequestContext(request) // Set prefix cache scores @@ -724,7 +721,7 @@ func TestSLOAwareRouter_ConcurrentContextAccess(t *testing.T) { defer wg.Done() requestID := uuid.New().String() - request := createTestLLMRequest(requestID, 100, 50, true) + request := createTestLLMRequest(requestID, 100, 50) sloCtx := newSLORequestContext(request) // Set context @@ -751,9 +748,9 @@ func TestSLOAwareRouter_MultipleRequests_SamePod(t *testing.T) { ctx := context.Background() pod := createTestPod("test-pod", 1, 1, 1) - request1 := createTestLLMRequest("test-id-1", 100, 50, true) - request2 := createTestLLMRequest("test-id-2", 100, 50, true) - request3 := createTestLLMRequest("test-id-3", 100, 50, true) + request1 := createTestLLMRequest("test-id-1", 100, 50) + request2 := createTestLLMRequest("test-id-2", 100, 50) + request3 := createTestLLMRequest("test-id-3", 100, 50) schedulingResult := createTestSchedulingResult(pod.GetPod()) @@ -782,7 +779,7 @@ func TestSLOAwareRouter_RequestLifecycle_Complete(t *testing.T) { ctx := context.Background() pod := createTestPod("test-pod", 1, 1, 1) - request := createTestLLMRequest("test", 100, 50, true) + request := createTestLLMRequest("test", 100, 50) response := &requestcontrol.Response{} schedulingResult := createTestSchedulingResult(pod.GetPod()) @@ -834,8 +831,8 @@ func TestSLOAwareRouter_MultipleRequests_DifferentPods(t *testing.T) { pod1 := 
createTestPod("test-pod-1", 1, 1, 1) pod2 := createTestPod("test-pod-2", 1, 1, 1) - request1 := createTestLLMRequest("test-id-1", 100, 50, true) - request2 := createTestLLMRequest("test-id-2", 100, 50, true) + request1 := createTestLLMRequest("test-id-1", 100, 50) + request2 := createTestLLMRequest("test-id-2", 100, 50) schedulingResult1 := createTestSchedulingResult(pod1.GetPod()) schedulingResult2 := createTestSchedulingResult(pod2.GetPod()) @@ -899,7 +896,7 @@ func TestSLORequestContext_SLOValidation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - request := createTestLLMRequest("test-id", tt.ttftSLO, tt.tpotSLO, true) + request := createTestLLMRequest("test-id", tt.ttftSLO, tt.tpotSLO) ctx := newSLORequestContext(request) ctx.ttftSLO = tt.ttftSLO ctx.avgTPOTSLO = tt.tpotSLO @@ -921,7 +918,7 @@ func BenchmarkSLOAwareRouter_PreRequest(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { requestID := uuid.New().String() - request := createTestLLMRequest(requestID, 100, 50, true) + request := createTestLLMRequest(requestID, 100, 50) sloCtx := newSLORequestContext(request) sloCtx.avgTPOTSLO = 50 router.setSLOContextForRequest(request, sloCtx) @@ -931,7 +928,7 @@ func BenchmarkSLOAwareRouter_PreRequest(b *testing.B) { func BenchmarkSLOAwareRouter_ContextOperations(b *testing.B) { router := createTestRouter() - request := createTestLLMRequest("test", 100, 50, true) + request := createTestLLMRequest("test", 100, 50) sloCtx := newSLORequestContext(request) b.ResetTimer() @@ -943,7 +940,7 @@ func BenchmarkSLOAwareRouter_ContextOperations(b *testing.B) { } func BenchmarkSLORequestContext_Creation(b *testing.B) { - request := createTestLLMRequest("test", 100, 50, true) + request := createTestLLMRequest("test", 100, 50) b.ResetTimer() for i := 0; i < b.N; i++ { diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go index 
25bb1e8ed..74146adcc 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go @@ -230,7 +230,7 @@ func (s *SLOAwareRouter) epsilonGreedyAffinityGate( // when latency predictions are unavailable func (s *SLOAwareRouter) scoreWithoutPredictions( ctx context.Context, - state *schedulingtypes.CycleState, + sloCtx *sloRequestContext, pods []schedulingtypes.Pod, r *rand.Rand, ) map[schedulingtypes.Pod]float64 { @@ -249,7 +249,7 @@ func (s *SLOAwareRouter) scoreWithoutPredictions( // Build prediction results with only prefix cache scores podResults := make([]podPredictionResult, 0, len(pods)) for _, pod := range pods { - prefixScore := s.getPrefixCacheScoreForPod(ctx, state, pod) + prefixScore := sloCtx.prefixCacheScoresForPods[pod.GetPod().String()] podResults = append(podResults, podPredictionResult{ Pod: pod, PrefixCacheScore: prefixScore, @@ -277,39 +277,21 @@ func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.Cycle sloCtx := s.getOrMakeSLORequestContext(request) - s.parseSLOHeaders(ctx, request, sloCtx) - - for _, pod := range pods { - prefixCacheScore := s.getPrefixCacheScoreForPod(ctx, state, pod) - sloCtx.prefixCacheScoresForPods[pod.GetPod().String()] = prefixCacheScore - } - - // Check if SLOs are provided - if !sloCtx.predictorBasedScheduling { - logger.V(logutil.DEBUG).Info("PredictorBasedScheduling turned off, skipping prediction-based filtering") + predictions, err := s.generatePredictions(ctx, request, sloCtx, pods) + if err != nil || len(predictions) == 0 { + logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter: Error generating predictions, falling back to composite-only scoring") s.setSLOContextForRequest(request, sloCtx) - return nil + return s.scoreWithoutPredictions(ctx, sloCtx, pods, rand.New(rand.NewSource(time.Now().UnixNano()))) } + s.updateRequestContextWithPredictions(sloCtx, predictions) // Initialize scores map with all 
pods having score 0 scores := make(map[schedulingtypes.Pod]float64, len(pods)) for _, pod := range pods { scores[pod] = 0 } - - source := rand.NewSource(time.Now().UnixNano()) - r := rand.New(source) - predictions, err := s.generatePredictions(ctx, state, request, sloCtx, pods) - if err != nil { - logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter: Error generating predictions, falling back to composite-only scoring") - // Fall back to composite-only scoring using prefix cache scores - s.setSLOContextForRequest(request, sloCtx) - return s.scoreWithoutPredictions(ctx, state, pods, r) - } - s.updateRequestContextWithPredictions(sloCtx, predictions) - allPreds := append([]podPredictionResult(nil), predictions...) - allPreds, sticky := s.epsilonGreedyAffinityGate(ctx, allPreds, r, "overall", AffinityGateTauGlobal) + allPreds, sticky := s.epsilonGreedyAffinityGate(ctx, allPreds, rand.New(rand.NewSource(time.Now().UnixNano())), "overall", AffinityGateTauGlobal) // Check if all pods are invalid and all have running requests allPodsInvalid := (sloCtx.ttftSLO > 0 && sloCtx.avgTPOTSLO > 0) @@ -339,7 +321,7 @@ func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.Cycle "positivePods", len(posHeadroomPods), "negativePods", len(negHeadroomPods)) - selectedPod := s.selectPodBasedOnStrategy(ctx, r, allPreds, posHeadroomPods, negHeadroomPods) + selectedPod := s.selectPodBasedOnStrategy(ctx, rand.New(rand.NewSource(time.Now().UnixNano())), allPreds, posHeadroomPods, negHeadroomPods) // Set score = 1 for selected pod, 0 for all others if selectedPod != nil { @@ -357,6 +339,7 @@ func (t *SLOAwareRouter) getOrMakeSLORequestContext(request *schedulingtypes.LLM if err != nil { sloCtx = newSLORequestContext(request) } + return sloCtx } diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_helpers.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_helpers.go index bb97ba346..7b6331959 100644 --- 
a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_helpers.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_helpers.go @@ -41,7 +41,6 @@ func (s *SLOAwareRouter) parseSLOHeaders(ctx context.Context, request *schedulin if err != nil { logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("%v must be a float: %v", tpotSLOHeaderKey, err)}, "SLOAwareRouter: Error parsing TPOT SLO from header") } - sloCtx.predictorBasedScheduling = !hasHeader(*request, "x-prediction-based-scheduling-off") } func (s *SLOAwareRouter) classifyPodsByHeadroom(allPreds []podPredictionResult) (posHeadroomPods, negHeadroomPods []podPredictionResult) { diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go index 8d8f68393..a2c727aed 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go @@ -119,7 +119,7 @@ func createTestPod(name string, kvCacheUsage float64, runningRequestsSize, waiti } } -func createTestLLMRequest(reqID string, ttftSLO, tpotSLO float64, predictionBased bool) *schedulingtypes.LLMRequest { +func createTestLLMRequest(reqID string, ttftSLO, tpotSLO float64) *schedulingtypes.LLMRequest { headers := make(map[string]string) headers[requtil.RequestIdHeaderKey] = reqID if ttftSLO > 0 { @@ -128,9 +128,6 @@ func createTestLLMRequest(reqID string, ttftSLO, tpotSLO float64, predictionBase if tpotSLO > 0 { headers["x-avg-tpot-slo"] = fmt.Sprintf("%f", tpotSLO) } - if !predictionBased { - headers["x-prediction-based-scheduling-off"] = "true" - } return &schedulingtypes.LLMRequest{ Headers: headers, @@ -152,22 +149,11 @@ func TestSLOAwareRouter_Score(t *testing.T) { expectedScores map[string]float64 // Map of pod name to expected score expectNil bool }{ - { - name: "Prediction-based scheduling 
disabled", - predictor: &mockPredictor{}, - strategy: headroomStrategyLeast, - request: createTestLLMRequest("test", 1.0, 0.05, false), // predictionBased = false - pods: []schedulingtypes.Pod{ - createTestPod("pod1", 0.5, 2, 1), // 50% KV cache, 2 running, 1 waiting - createTestPod("pod2", 0.7, 3, 2), // 70% KV cache, 3 running, 2 waiting - }, - expectNil: true, - }, { name: "No predictor configured", predictor: nil, strategy: headroomStrategyLeast, - request: createTestLLMRequest("test", 1.0, 0.05, true), + request: createTestLLMRequest("test", 1.0, 0.05), pods: []schedulingtypes.Pod{ createTestPod("pod1", 0.5, 2, 1), }, @@ -183,7 +169,7 @@ func TestSLOAwareRouter_Score(t *testing.T) { }, }, strategy: headroomStrategyLeast, - request: createTestLLMRequest("test", 1.0, 0.05, true), + request: createTestLLMRequest("test", 1.0, 0.05), pods: []schedulingtypes.Pod{ createTestPod("pod1", 0.5, 2, 1), // 50% KV cache createTestPod("pod2", 0.6, 3, 2), // 60% KV cache @@ -203,7 +189,7 @@ func TestSLOAwareRouter_Score(t *testing.T) { }, }, strategy: headroomStrategyLeast, - request: createTestLLMRequest("test", 1.0, 0.05, true), + request: createTestLLMRequest("test", 1.0, 0.05), pods: []schedulingtypes.Pod{ createTestPod("pod1", 0.8, 5, 3), // 80% KV cache, high load createTestPod("pod2", 0.9, 6, 4), // 90% KV cache, very high load @@ -220,7 +206,7 @@ func TestSLOAwareRouter_Score(t *testing.T) { }, }, strategy: headroomStrategyLeast, - request: createTestLLMRequest("test", 1.0, 0.05, true), + request: createTestLLMRequest("test", 1.0, 0.05), pods: []schedulingtypes.Pod{ createTestPod("pod-positive", 0.3, 1, 0), // Low KV cache, positive headroom createTestPod("pod-negative", 0.9, 6, 4), // High KV cache, negative headroom @@ -234,7 +220,7 @@ func TestSLOAwareRouter_Score(t *testing.T) { err: errors.New("prediction failed"), }, strategy: headroomStrategyLeast, - request: createTestLLMRequest("test", 1.0, 0.05, true), + request: createTestLLMRequest("test", 1.0, 0.05), 
pods: []schedulingtypes.Pod{ createTestPod("pod1", 0.5, 2, 1), createTestPod("pod2", 0.6, 3, 2), @@ -248,7 +234,7 @@ func TestSLOAwareRouter_Score(t *testing.T) { name: "Empty pod list", predictor: &mockPredictor{}, strategy: headroomStrategyLeast, - request: createTestLLMRequest("test", 1.0, 0.05, true), + request: createTestLLMRequest("test", 1.0, 0.05), pods: []schedulingtypes.Pod{}, // Should return empty scores map expectedScores: map[string]float64{}, @@ -348,7 +334,7 @@ func TestSLOAwareRouter_Strategies(t *testing.T) { cfg.HeadroomSelectionStrategy = string(tt.strategy) router := NewSLOAwareRouter(cfg, predictor) - request := createTestLLMRequest("test", 1.0, 0.05, true) + request := createTestLLMRequest("test", 1.0, 0.05) pods := []schedulingtypes.Pod{ createTestPod("pod1", 0.5, 2, 1), createTestPod("pod2", 0.6, 3, 2), diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/slo_aware_profile_handler.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/slo_aware_profile_handler.go deleted file mode 100644 index f377ad55c..000000000 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/slo_aware_profile_handler.go +++ /dev/null @@ -1,149 +0,0 @@ -/* -Copyright 2025 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package slo_aware_router - -import ( - "context" - "encoding/json" - "errors" - "fmt" - - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" -) - -const ( - SLOAwareProfileHandlerType = "predicted-latency-profile-handler" - NoLatencyRoutingProfileName = "predicted-latency-no-routing" - PrefixProfileName = "predicted-latency-prefix" - LatencyRoutingProfileName = "predicted-latency-routing" - - // Boolean header string for whether to use predictor based scheduling - PreictionBasedSchedulingHeaderKey = "x-prediction-based-scheduling-off" -) - -// compile-time type assertion -var _ framework.ProfileHandler = &SLOAwareProfileHandler{} - -// SLOAwareProfileHandlerFactory defines the factory function for SLOAwareProfileHandler. -func SLOAwareProfileHandlerFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { - return NewSLOAwareProfileHandler().WithName(name), nil -} - -// NewSLOAwareProfileHandler initializes a new SLOAwareProfileHandler and returns its pointer. -func NewSLOAwareProfileHandler() *SLOAwareProfileHandler { - return &SLOAwareProfileHandler{ - typedName: plugins.TypedName{Type: SLOAwareProfileHandlerType, Name: SLOAwareProfileHandlerType}, - } -} - -// SLOAwareProfileHandler handles two profiles: the default profile and the SLO profile. -// When the request has PredictorBasedScheduling=true, it uses the SLO profile result to select -// the destination pod. Otherwise, it uses the default profile result. -type SLOAwareProfileHandler struct { - typedName plugins.TypedName -} - -// TypedName returns the type and name tuple of this plugin instance. -func (h *SLOAwareProfileHandler) TypedName() plugins.TypedName { - return h.typedName -} - -// WithName sets the name of the profile handler. 
-func (h *SLOAwareProfileHandler) WithName(name string) *SLOAwareProfileHandler { - h.typedName.Name = name - return h -} - -// Pick selects the SchedulingProfiles to run from the list of candidate profiles, while taking into consideration the request properties and the -// previously executed cycles along with their results. -func (h *SLOAwareProfileHandler) Pick(ctx context.Context, _ *types.CycleState, request *types.LLMRequest, profiles map[string]*framework.SchedulerProfile, - profileResults map[string]*types.ProfileRunResult) map[string]*framework.SchedulerProfile { - - predictorBasedScheduling := !isHeaderPresent(*request, PreictionBasedSchedulingHeaderKey) - - _, prefixExecuted := profileResults[PrefixProfileName] - // if prefix profile was not executed yet, first let the scheduler run it - if !prefixExecuted { - return map[string]*framework.SchedulerProfile{ - PrefixProfileName: profiles[PrefixProfileName], - } - } - - if predictorBasedScheduling { - _, routingExecuted := profileResults[LatencyRoutingProfileName] - // routing profile has not been executed yet - if !routingExecuted { - return map[string]*framework.SchedulerProfile{ - LatencyRoutingProfileName: profiles[LatencyRoutingProfileName], - } - } - } else { - _, defaultExecuted := profileResults[NoLatencyRoutingProfileName] - // predictorBasedScheduling is off, and NoLatencyRoutingProfileName profile has not been executed yet - if !defaultExecuted { - return map[string]*framework.SchedulerProfile{ - NoLatencyRoutingProfileName: profiles[NoLatencyRoutingProfileName], - } - } - } - - // all previous profiles have been executed, nothing more to run - return map[string]*framework.SchedulerProfile{} -} - -// ProcessResults handles the outcome of the profile runs after all profiles ran. -// It may aggregate results, log test profile outputs, or apply custom logic. It specifies in the SchedulingResult the -// key of the primary profile that should be used to get the request selected destination. 
-// When a profile run fails, its result in the profileResults map is nil. -func (h *SLOAwareProfileHandler) ProcessResults(ctx context.Context, _ *types.CycleState, request *types.LLMRequest, profileResults map[string]*types.ProfileRunResult) (*types.SchedulingResult, error) { - - predictorBasedScheduling := !isHeaderPresent(*request, PreictionBasedSchedulingHeaderKey) - - if predictorBasedScheduling { // TODO grab header directly from request.Headers instead of request field - if len(profileResults) < 2 { - return nil, errors.New("SLOAwareProfileHandler requires at least two profiles to operate when predictorBasedScheduling is true") - } - if profileResults[LatencyRoutingProfileName] == nil { // there was an error while running the SLO profile - return nil, fmt.Errorf("failed to run scheduler profile '%s'", LatencyRoutingProfileName) - } - return &types.SchedulingResult{ - ProfileResults: profileResults, - PrimaryProfileName: LatencyRoutingProfileName, - }, nil - } - if len(profileResults) < 1 { - return nil, errors.New("SLOAwareProfileHandler requires at least one profiles to operate when predictorBasedScheduling is false") - } - - if profileResults[NoLatencyRoutingProfileName] == nil { // there was an error while running the default profile - return nil, fmt.Errorf("failed to run scheduler profile '%s'", NoLatencyRoutingProfileName) - } - - return &types.SchedulingResult{ - ProfileResults: profileResults, - PrimaryProfileName: NoLatencyRoutingProfileName, - }, nil -} - -// isHeaderPresent checks if a header key exists in the request headers map. -func isHeaderPresent(request types.LLMRequest, headerName string) bool { - // 1. Get header value from the map - _, ok := request.Headers[headerName] - return ok -}