Skip to content

Commit 148bfdf

Browse files
committed
add decay rate and relative threshold
1 parent f347010 commit 148bfdf

File tree

5 files changed

+151
-42
lines changed

5 files changed

+151
-42
lines changed

examples/cli/README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,5 +127,7 @@ Generation Options:
127127
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
128128
-r, --ref-image reference image for Flux Kontext models (can be used multiple times)
129129
--cache-mode caching method: 'easycache' for DiT models, 'ucache' for UNET models (SD1.x/SD2.x/SDXL)
130-
--cache-option cache parameters "threshold,start_percent,end_percent" (default: 0.2,0.15,0.95 for easycache, 1.0,0.15,0.95 for ucache)
130+
--cache-option cache parameters: easycache uses "threshold,start,end" (default: 0.2,0.15,0.95).
131+
ucache uses "threshold,start,end[,decay,relative]" (default: 1.0,0.15,0.95,1.0,1).
132+
decay: factor applied to the accumulated error before each new estimate is added (0.0-1.0, 1.0 = no decay); relative: scale the threshold by the reference output norm (0 or 1)
131133
```

examples/cli/main.cpp

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1434,7 +1434,7 @@ struct SDGenerationParams {
14341434
on_cache_mode_arg},
14351435
{"",
 "--cache-option",
 // Help text must match the parser: three required values plus the optional
 // ucache-only pair "decay,relative" (five values at most). There is no
 // "warmup" field; advertising one makes users pass a sixth value, which
 // the parser rejects.
 "cache parameters \"threshold,start,end[,decay,relative]\" (default: 0.2,0.15,0.95 for easycache; 1.0,0.15,0.95,1.0,1 for ucache; decay and relative are ucache-only)",
 on_cache_option_arg},
14391439

14401440
};
@@ -1561,28 +1561,32 @@ struct SDGenerationParams {
15611561
}
15621562
}
15631563

1564-
float values[3] = {0.0f, 0.0f, 0.0f};
1564+
// Format: threshold,start,end[,decay,relative]
1565+
// - values[0-2]: threshold, start_percent, end_percent (required)
1566+
// - values[3]: error_decay_rate (optional, default: 1.0)
1567+
// - values[4]: use_relative_threshold (optional, 0 or 1, default: 1)
1568+
float values[5] = {0.0f, 0.0f, 0.0f, 1.0f, 1.0f};
15651569
std::stringstream ss(option_str);
15661570
std::string token;
15671571
int idx = 0;
1572+
// Strips leading and trailing blanks (space, tab, CR, LF) from `s` in place.
// A string made up only of blanks (or an empty one) ends up empty.
auto trim = [](std::string& s) {
    const char* const blanks = " \t\r\n";

    const auto first_kept = s.find_first_not_of(blanks);
    if (first_kept == std::string::npos) {
        s.clear();
        return;
    }

    // Cut the tail first so `first_kept` stays a valid index for the head cut.
    const auto last_kept = s.find_last_not_of(blanks);
    s.erase(last_kept + 1);
    s.erase(0, first_kept);
};
15681582
while (std::getline(ss, token, ',')) {
1569-
auto trim = [](std::string& s) {
1570-
const char* whitespace = " \t\r\n";
1571-
auto start = s.find_first_not_of(whitespace);
1572-
if (start == std::string::npos) {
1573-
s.clear();
1574-
return;
1575-
}
1576-
auto end = s.find_last_not_of(whitespace);
1577-
s = s.substr(start, end - start + 1);
1578-
};
15791583
trim(token);
15801584
if (token.empty()) {
15811585
fprintf(stderr, "error: invalid cache option '%s'\n", option_str.c_str());
15821586
return false;
15831587
}
1584-
if (idx >= 3) {
1585-
fprintf(stderr, "error: cache option expects exactly 3 comma-separated values (threshold,start,end)\n");
1588+
if (idx >= 5) {
1589+
fprintf(stderr, "error: cache option expects 3-5 comma-separated values (threshold,start,end[,decay,relative])\n");
15861590
return false;
15871591
}
15881592
try {
@@ -1593,8 +1597,8 @@ struct SDGenerationParams {
15931597
}
15941598
idx++;
15951599
}
1596-
if (idx != 3) {
1597-
fprintf(stderr, "error: cache option expects exactly 3 comma-separated values (threshold,start,end)\n");
1600+
if (idx < 3) {
1601+
fprintf(stderr, "error: cache option expects at least 3 comma-separated values (threshold,start,end)\n");
15981602
return false;
15991603
}
16001604
if (values[0] < 0.0f) {
@@ -1612,10 +1616,12 @@ struct SDGenerationParams {
16121616
easycache_params.start_percent = values[1];
16131617
easycache_params.end_percent = values[2];
16141618
} else {
1615-
ucache_params.enabled = true;
1616-
ucache_params.reuse_threshold = values[0];
1617-
ucache_params.start_percent = values[1];
1618-
ucache_params.end_percent = values[2];
1619+
ucache_params.enabled = true;
1620+
ucache_params.reuse_threshold = values[0];
1621+
ucache_params.start_percent = values[1];
1622+
ucache_params.end_percent = values[2];
1623+
ucache_params.error_decay_rate = values[3];
1624+
ucache_params.use_relative_threshold = (values[4] != 0.0f);
16191625
}
16201626
}
16211627

stable-diffusion.cpp

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1548,10 +1548,12 @@ class StableDiffusionGGML {
15481548
LOG_WARN("UCache requested but not supported for this model type (only UNET models)");
15491549
} else {
15501550
UCacheConfig ucache_config;
1551-
ucache_config.enabled = true;
1552-
ucache_config.reuse_threshold = std::max(0.0f, ucache_params->reuse_threshold);
1553-
ucache_config.start_percent = ucache_params->start_percent;
1554-
ucache_config.end_percent = ucache_params->end_percent;
1551+
ucache_config.enabled = true;
1552+
ucache_config.reuse_threshold = std::max(0.0f, ucache_params->reuse_threshold);
1553+
ucache_config.start_percent = ucache_params->start_percent;
1554+
ucache_config.end_percent = ucache_params->end_percent;
1555+
ucache_config.error_decay_rate = std::max(0.0f, std::min(1.0f, ucache_params->error_decay_rate));
1556+
ucache_config.use_relative_threshold = ucache_params->use_relative_threshold;
15551557
bool percent_valid = ucache_config.start_percent >= 0.0f &&
15561558
ucache_config.start_percent < 1.0f &&
15571559
ucache_config.end_percent > 0.0f &&
@@ -1565,10 +1567,12 @@ class StableDiffusionGGML {
15651567
ucache_state.init(ucache_config, denoiser.get());
15661568
if (ucache_state.enabled()) {
15671569
ucache_enabled = true;
1568-
LOG_INFO("UCache enabled - threshold: %.3f, start_percent: %.2f, end_percent: %.2f",
1570+
LOG_INFO("UCache enabled - threshold: %.3f, start: %.2f, end: %.2f, decay: %.2f, relative: %s",
15691571
ucache_config.reuse_threshold,
15701572
ucache_config.start_percent,
1571-
ucache_config.end_percent);
1573+
ucache_config.end_percent,
1574+
ucache_config.error_decay_rate,
1575+
ucache_config.use_relative_threshold ? "true" : "false");
15721576
} else {
15731577
LOG_WARN("UCache requested but could not be initialized for this run");
15741578
}
@@ -2616,11 +2620,13 @@ void sd_easycache_params_init(sd_easycache_params_t* easycache_params) {
26162620
}
26172621

26182622
void sd_ucache_params_init(sd_ucache_params_t* ucache_params) {
2619-
*ucache_params = {};
2620-
ucache_params->enabled = false;
2621-
ucache_params->reuse_threshold = 1.0f;
2622-
ucache_params->start_percent = 0.15f;
2623-
ucache_params->end_percent = 0.95f;
2623+
*ucache_params = {};
2624+
ucache_params->enabled = false;
2625+
ucache_params->reuse_threshold = 1.0f;
2626+
ucache_params->start_percent = 0.15f;
2627+
ucache_params->end_percent = 0.95f;
2628+
ucache_params->error_decay_rate = 1.0f;
2629+
ucache_params->use_relative_threshold = true;
26242630
}
26252631

26262632
void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {

stable-diffusion.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,8 @@ typedef struct {
247247
float reuse_threshold;
248248
float start_percent;
249249
float end_percent;
250+
float error_decay_rate;
251+
bool use_relative_threshold;
250252
} sd_ucache_params_t;
251253

252254
typedef struct {

ucache.hpp

Lines changed: 103 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,15 @@
1010
#include "ggml_extend.hpp"
1111

1212
struct UCacheConfig {
13-
bool enabled = false;
14-
float reuse_threshold = 1.0f;
15-
float start_percent = 0.15f;
16-
float end_percent = 0.95f;
13+
bool enabled = false;
14+
float reuse_threshold = 1.0f;
15+
float start_percent = 0.15f;
16+
float end_percent = 0.95f;
17+
float error_decay_rate = 1.0f;
18+
bool use_relative_threshold = true;
19+
bool adaptive_threshold = true;
20+
float early_step_multiplier = 0.5f;
21+
float late_step_multiplier = 1.5f;
1722
};
1823

1924
struct UCacheCacheEntry {
@@ -44,6 +49,45 @@ struct UCacheState {
4449
bool has_last_input_change = false;
4550
int total_steps_skipped = 0;
4651
int current_step_index = -1;
52+
int steps_computed_since_active = 0;
53+
float accumulated_error = 0.0f;
54+
float reference_output_norm = 0.0f;
55+
56+
// Running statistics over the samples handed to record(): sums for the
// averages, plus the smallest and largest accepted change rate.
struct BlockMetrics {
    float sum_transformation_rate = 0.0f;  // sum of accepted change_rate samples
    float sum_output_norm         = 0.0f;  // sum of output_norm over the same samples
    int sample_count              = 0;     // number of accepted samples
    // Starts at the largest float so the first accepted sample becomes the minimum.
    float min_change_rate = std::numeric_limits<float>::max();
    float max_change_rate = 0.0f;

    // Puts every field back to its initial value.
    void reset() {
        sum_transformation_rate = 0.0f;
        sum_output_norm         = 0.0f;
        sample_count            = 0;
        min_change_rate         = std::numeric_limits<float>::max();
        max_change_rate         = 0.0f;
    }

    // Adds one sample. A sample is accepted only when change_rate is finite
    // and strictly positive AND output_norm is finite; anything else is
    // dropped as a whole. The output_norm check matters because a single
    // NaN/inf norm would otherwise stick in sum_output_norm and make
    // avg_output_norm() non-finite for the rest of the run.
    void record(float change_rate, float output_norm) {
        const bool rate_ok = std::isfinite(change_rate) && change_rate > 0.0f;
        if (!rate_ok || !std::isfinite(output_norm)) {
            return;
        }

        sum_transformation_rate += change_rate;
        sum_output_norm += output_norm;
        sample_count++;

        if (change_rate < min_change_rate) {
            min_change_rate = change_rate;
        }
        if (change_rate > max_change_rate) {
            max_change_rate = change_rate;
        }
    }

    // Mean of the accepted change rates; 0 when nothing has been recorded.
    float avg_transformation_rate() const {
        if (sample_count <= 0) {
            return 0.0f;
        }
        return sum_transformation_rate / static_cast<float>(sample_count);
    }

    // Mean of the accepted output norms; 0 when nothing has been recorded.
    float avg_output_norm() const {
        if (sample_count <= 0) {
            return 0.0f;
        }
        return sum_output_norm / static_cast<float>(sample_count);
    }
};
89+
BlockMetrics block_metrics;
90+
int total_active_steps = 0;
4791

4892
void reset_runtime() {
4993
initial_step = true;
@@ -64,6 +108,11 @@ struct UCacheState {
64108
has_last_input_change = false;
65109
total_steps_skipped = 0;
66110
current_step_index = -1;
111+
steps_computed_since_active = 0;
112+
accumulated_error = 0.0f;
113+
reference_output_norm = 0.0f;
114+
block_metrics.reset();
115+
total_active_steps = 0;
67116
}
68117

69118
void init(const UCacheConfig& cfg, Denoiser* d) {
@@ -114,6 +163,7 @@ struct UCacheState {
114163
return;
115164
}
116165
step_active = true;
166+
total_active_steps++;
117167
}
118168

119169
bool step_is_active() const {
@@ -124,6 +174,31 @@ struct UCacheState {
124174
return enabled() && step_active && skip_current_step;
125175
}
126176

177+
float get_adaptive_threshold(int estimated_total_steps = 0) const {
178+
float base_threshold = config.reuse_threshold;
179+
180+
if (!config.adaptive_threshold) {
181+
return base_threshold;
182+
}
183+
184+
int effective_total = estimated_total_steps;
185+
if (effective_total <= 0) {
186+
effective_total = std::max(20, steps_computed_since_active * 2);
187+
}
188+
189+
float progress = (effective_total > 0) ?
190+
(static_cast<float>(steps_computed_since_active) / effective_total) : 0.0f;
191+
192+
float multiplier = 1.0f;
193+
if (progress < 0.2f) {
194+
multiplier = config.early_step_multiplier;
195+
} else if (progress > 0.8f) {
196+
multiplier = config.late_step_multiplier;
197+
}
198+
199+
return base_threshold * multiplier;
200+
}
201+
127202
bool has_cache(const SDCondition* cond) const {
128203
auto it = cache_diffs.find(cond);
129204
return it != cache_diffs.end() && !it->second.diff.empty();
@@ -212,15 +287,18 @@ struct UCacheState {
212287
last_input_change > 0.0f && output_prev_norm > 0.0f) {
213288

214289
float approx_output_change_rate = (relative_transformation_rate * last_input_change) / output_prev_norm;
215-
cumulative_change_rate += approx_output_change_rate;
290+
accumulated_error = accumulated_error * config.error_decay_rate + approx_output_change_rate;
291+
292+
float effective_threshold = get_adaptive_threshold();
293+
if (config.use_relative_threshold && reference_output_norm > 0.0f) {
294+
effective_threshold = effective_threshold * reference_output_norm;
295+
}
216296

217-
if (cumulative_change_rate < config.reuse_threshold) {
297+
if (accumulated_error < effective_threshold) {
218298
skip_current_step = true;
219299
total_steps_skipped++;
220300
apply_cache(cond, input, output);
221301
return true;
222-
} else {
223-
cumulative_change_rate = 0.0f;
224302
}
225303
}
226304

@@ -270,16 +348,31 @@ struct UCacheState {
270348
output_prev_norm = (ne > 0) ? (mean_abs / static_cast<float>(ne)) : 0.0f;
271349
has_output_prev_norm = output_prev_norm > 0.0f;
272350

351+
if (reference_output_norm == 0.0f) {
352+
reference_output_norm = output_prev_norm;
353+
}
354+
273355
if (has_last_input_change && last_input_change > 0.0f && output_change > 0.0f) {
274356
float rate = output_change / last_input_change;
275357
if (std::isfinite(rate)) {
276358
relative_transformation_rate = rate;
277359
has_relative_transformation_rate = true;
360+
block_metrics.record(rate, output_prev_norm);
278361
}
279362
}
280363

281-
cumulative_change_rate = 0.0f;
282-
has_last_input_change = false;
364+
has_last_input_change = false;
365+
}
366+
367+
void log_block_metrics() const {
368+
if (block_metrics.sample_count > 0) {
369+
LOG_INFO("UCacheBlockMetrics: samples=%d, avg_rate=%.4f, min=%.4f, max=%.4f, avg_norm=%.4f",
370+
block_metrics.sample_count,
371+
block_metrics.avg_transformation_rate(),
372+
block_metrics.min_change_rate,
373+
block_metrics.max_change_rate,
374+
block_metrics.avg_output_norm());
375+
}
283376
}
284377
};
285378

0 commit comments

Comments
 (0)