Skip to content

Commit 47fdba2

Browse files
committed
Get/Set prompt cache
1 parent 94027ac commit 47fdba2

File tree

6 files changed

+30
-6
lines changed

6 files changed

+30
-6
lines changed

TensorStack.StableDiffusion/Pipelines/Flux/FluxBase.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,10 @@ protected override void ValidateOptions(GenerateOptions options)
148148
/// <param name="cancellationToken">The cancellation token.</param>
149149
protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, CancellationToken cancellationToken = default)
150150
{
151+
var cachedPrompt = GetPromptCache(options);
152+
if (cachedPrompt is not null)
153+
return cachedPrompt;
154+
151155
// Tokenize2
152156
var promptTokens = await TokenizePromptAsync(options.Prompt, cancellationToken);
153157
var negativePromptTokens = await TokenizePromptAsync(options.NegativePrompt, cancellationToken);
@@ -179,7 +183,7 @@ protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, C
179183
var negativePromptPooledEmbeds = negativePromptEmbeddings.TextEmbeds;
180184
negativePromptPooledEmbeds = negativePromptPooledEmbeds.Reshape([negativePromptPooledEmbeds.Dimensions[^2], negativePromptPooledEmbeds.Dimensions[^1]]).FirstBatch();
181185

182-
return new PromptResult(promptEmbeds, promptPooledEmbeds, negativePromptEmbeds, negativePromptPooledEmbeds);
186+
return SetPromptCache(options, new PromptResult(promptEmbeds, promptPooledEmbeds, negativePromptEmbeds, negativePromptPooledEmbeds));
183187
}
184188

185189

TensorStack.StableDiffusion/Pipelines/Nitro/NitroBase.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,10 @@ protected override void ValidateOptions(GenerateOptions options)
137137
/// <param name="cancellationToken">The cancellation token.</param>
138138
protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, CancellationToken cancellationToken = default)
139139
{
140+
var cachedPrompt = GetPromptCache(options);
141+
if (cachedPrompt is not null)
142+
return cachedPrompt;
143+
140144
// Conditional Prompt
141145
var promptEmbeds = await TextEncoder.GetLastHiddenState(new TextGeneration.Common.GenerateOptions
142146
{
@@ -159,7 +163,7 @@ protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, C
159163
}, cancellationToken);
160164
}
161165

162-
return new PromptResult(promptEmbeds, default, negativePromptEmbeds, default);
166+
return SetPromptCache(options, new PromptResult(promptEmbeds, default, negativePromptEmbeds, default));
163167
}
164168

165169

TensorStack.StableDiffusion/Pipelines/StableCascade/StableCascadeBase.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,10 @@ protected override void ValidateOptions(GenerateOptions options)
147147
/// <param name="cancellationToken">The cancellation token.</param>
148148
protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, CancellationToken cancellationToken = default)
149149
{
150+
var cachedPrompt = GetPromptCache(options);
151+
if (cachedPrompt is not null)
152+
return cachedPrompt;
153+
150154
// Tokenizer
151155
var promptTokens = await TokenizePromptAsync(options.Prompt, cancellationToken);
152156
var negativePromptTokens = await TokenizePromptAsync(options.NegativePrompt, cancellationToken);
@@ -167,7 +171,7 @@ protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, C
167171
? negativePromptEmbeddings.TextEmbeds
168172
: negativePromptEmbeddings.TextEmbeds.Reshape([1, .. negativePromptEmbeddings.TextEmbeds.Dimensions]);
169173

170-
return new PromptResult(promptEmbeddings.HiddenStates, textEmbeds, negativePromptEmbeddings.HiddenStates, negativeTextEmbeds);
174+
return SetPromptCache(options, new PromptResult(promptEmbeddings.HiddenStates, textEmbeds, negativePromptEmbeddings.HiddenStates, negativeTextEmbeds));
171175
}
172176

173177

TensorStack.StableDiffusion/Pipelines/StableDiffusion/StableDiffusionBase.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,10 @@ protected override void ValidateOptions(GenerateOptions options)
146146
/// <param name="cancellationToken">The cancellation token.</param>
147147
protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, CancellationToken cancellationToken = default)
148148
{
149+
var cachedPrompt = GetPromptCache(options);
150+
if (cachedPrompt is not null)
151+
return cachedPrompt;
152+
149153
// Tokenizer
150154
var promptTokens = await TokenizePromptAsync(options.Prompt, cancellationToken);
151155
var negativePromptTokens = await TokenizePromptAsync(options.NegativePrompt, cancellationToken);
@@ -157,7 +161,7 @@ protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, C
157161
if (options.IsLowMemoryEnabled || options.IsLowMemoryTextEncoderEnabled)
158162
await TextEncoder.UnloadAsync();
159163

160-
return new PromptResult(promptEmbeddings.HiddenStates, promptEmbeddings.TextEmbeds, negativePromptEmbeddings.HiddenStates, negativePromptEmbeddings.TextEmbeds);
164+
return SetPromptCache(options, new PromptResult(promptEmbeddings.HiddenStates, promptEmbeddings.TextEmbeds, negativePromptEmbeddings.HiddenStates, negativePromptEmbeddings.TextEmbeds));
161165
}
162166

163167

TensorStack.StableDiffusion/Pipelines/StableDiffusion3/StableDiffusion3Base.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,10 @@ protected override void ValidateOptions(GenerateOptions options)
165165
/// <param name="cancellationToken">The cancellation token.</param>
166166
protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, CancellationToken cancellationToken = default)
167167
{
168+
var cachedPrompt = GetPromptCache(options);
169+
if (cachedPrompt is not null)
170+
return cachedPrompt;
171+
168172
// Tokenizer
169173
var promptTokens = await TokenizePromptAsync(options.Prompt, cancellationToken);
170174
var negativePromptTokens = await TokenizePromptAsync(options.NegativePrompt, cancellationToken);
@@ -217,7 +221,7 @@ protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, C
217221
negativePromptPooledEmbeds = negativePromptPooledEmbeds.Reshape([negativePromptPooledEmbeds.Dimensions[^2], negativePromptPooledEmbeds.Dimensions[^1]]).FirstBatch();
218222
negativePromptPooledEmbeds = negativePromptPooledEmbeds.Concatenate(negativePromptPooledEmbeds2, 1);
219223

220-
return new PromptResult(promptEmbeds, promptPooledEmbeds, negativePromptEmbeds, negativePromptPooledEmbeds);
224+
return SetPromptCache(options, new PromptResult(promptEmbeds, promptPooledEmbeds, negativePromptEmbeds, negativePromptPooledEmbeds));
221225
}
222226

223227

TensorStack.StableDiffusion/Pipelines/StableDiffusionXL/StableDiffusionXLBase.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,10 @@ protected override void ValidateOptions(GenerateOptions options)
146146
/// <param name="cancellationToken">The cancellation token.</param>
147147
protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, CancellationToken cancellationToken = default)
148148
{
149+
var cachedPrompt = GetPromptCache(options);
150+
if (cachedPrompt is not null)
151+
return cachedPrompt;
152+
149153
// Tokenizer
150154
var promptTokens = await TokenizePromptAsync(options.Prompt, cancellationToken);
151155
var negativePromptTokens = await TokenizePromptAsync(options.NegativePrompt, cancellationToken);
@@ -176,7 +180,7 @@ protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, C
176180
var pooledNegativePromptEmbeds = negativePrompt2Embeddings.TextEmbeds;
177181
var negativePromptEmbeddings = negativePrompt1Embeddings.HiddenStates.Concatenate(negativePrompt2Embeddings.HiddenStates, 2);
178182

179-
return new PromptResult(promptEmbeddings, pooledPromptEmbeds, negativePromptEmbeddings, pooledNegativePromptEmbeds);
183+
return SetPromptCache(options, new PromptResult(promptEmbeddings, pooledPromptEmbeds, negativePromptEmbeddings, pooledNegativePromptEmbeds));
180184
}
181185

182186

0 commit comments

Comments
 (0)