Skip to content

Commit 7aa2a62

Browse files
committed
use single unified struct
1 parent 8639fd5 commit 7aa2a62

File tree

3 files changed

+89
-125
lines changed

3 files changed

+89
-125
lines changed

examples/cli/main.cpp

Lines changed: 17 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,8 +1063,7 @@ struct SDGenerationParams {
10631063

10641064
std::string cache_mode;
10651065
std::string cache_option;
1066-
sd_easycache_params_t easycache_params;
1067-
sd_ucache_params_t ucache_params;
1066+
sd_cache_params_t cache_params;
10681067

10691068
float moe_boundary = 0.875f;
10701069
int video_frames = 1;
@@ -1555,8 +1554,7 @@ struct SDGenerationParams {
15551554
return false;
15561555
}
15571556

1558-
easycache_params.enabled = false;
1559-
ucache_params.enabled = false;
1557+
cache_params.mode = SD_CACHE_DISABLED;
15601558

15611559
if (!cache_mode.empty()) {
15621560
std::string option_str = cache_option;
@@ -1617,18 +1615,15 @@ struct SDGenerationParams {
16171615
return false;
16181616
}
16191617

1618+
cache_params.reuse_threshold = values[0];
1619+
cache_params.start_percent = values[1];
1620+
cache_params.end_percent = values[2];
1621+
cache_params.error_decay_rate = values[3];
1622+
cache_params.use_relative_threshold = (values[4] != 0.0f);
16201623
if (cache_mode == "easycache") {
1621-
easycache_params.enabled = true;
1622-
easycache_params.reuse_threshold = values[0];
1623-
easycache_params.start_percent = values[1];
1624-
easycache_params.end_percent = values[2];
1624+
cache_params.mode = SD_CACHE_EASYCACHE;
16251625
} else {
1626-
ucache_params.enabled = true;
1627-
ucache_params.reuse_threshold = values[0];
1628-
ucache_params.start_percent = values[1];
1629-
ucache_params.end_percent = values[2];
1630-
ucache_params.error_decay_rate = values[3];
1631-
ucache_params.use_relative_threshold = (values[4] != 0.0f);
1626+
cache_params.mode = SD_CACHE_UCACHE;
16321627
}
16331628
}
16341629

@@ -1726,16 +1721,12 @@ struct SDGenerationParams {
17261721
<< " high_noise_sample_params: " << high_noise_sample_params_str << ",\n"
17271722
<< " cache_mode: \"" << cache_mode << "\",\n"
17281723
<< " cache_option: \"" << cache_option << "\",\n"
1729-
<< " easycache: "
1730-
<< (easycache_params.enabled ? "enabled" : "disabled")
1731-
<< " (threshold=" << easycache_params.reuse_threshold
1732-
<< ", start=" << easycache_params.start_percent
1733-
<< ", end=" << easycache_params.end_percent << "),\n"
1734-
<< " ucache: "
1735-
<< (ucache_params.enabled ? "enabled" : "disabled")
1736-
<< " (threshold=" << ucache_params.reuse_threshold
1737-
<< ", start=" << ucache_params.start_percent
1738-
<< ", end=" << ucache_params.end_percent << "),\n"
1724+
<< " cache: "
1725+
<< (cache_params.mode == SD_CACHE_DISABLED ? "disabled" :
1726+
(cache_params.mode == SD_CACHE_EASYCACHE ? "easycache" : "ucache"))
1727+
<< " (threshold=" << cache_params.reuse_threshold
1728+
<< ", start=" << cache_params.start_percent
1729+
<< ", end=" << cache_params.end_percent << "),\n"
17391730
<< " moe_boundary: " << moe_boundary << ",\n"
17401731
<< " video_frames: " << video_frames << ",\n"
17411732
<< " fps: " << fps << ",\n"
@@ -2315,8 +2306,7 @@ int main(int argc, const char* argv[]) {
23152306
gen_params.pm_style_strength,
23162307
}, // pm_params
23172308
ctx_params.vae_tiling_params,
2318-
gen_params.easycache_params,
2319-
gen_params.ucache_params,
2309+
gen_params.cache_params,
23202310
};
23212311

23222312
results = generate_image(sd_ctx, &img_gen_params);
@@ -2341,8 +2331,7 @@ int main(int argc, const char* argv[]) {
23412331
gen_params.seed,
23422332
gen_params.video_frames,
23432333
gen_params.vace_strength,
2344-
gen_params.easycache_params,
2345-
gen_params.ucache_params,
2334+
gen_params.cache_params,
23462335
};
23472336

23482337
results = generate_video(sd_ctx, &vid_gen_params, &num_results);

stable-diffusion.cpp

Lines changed: 62 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1477,8 +1477,7 @@ class StableDiffusionGGML {
14771477
ggml_tensor* denoise_mask = nullptr,
14781478
ggml_tensor* vace_context = nullptr,
14791479
float vace_strength = 1.f,
1480-
const sd_easycache_params_t* easycache_params = nullptr,
1481-
const sd_ucache_params_t* ucache_params = nullptr) {
1480+
const sd_cache_params_t* cache_params = nullptr) {
14821481
if (shifted_timestep > 0 && !sd_version_is_sdxl(version)) {
14831482
LOG_WARN("timestep shifting is only supported for SDXL models!");
14841483
shifted_timestep = 0;
@@ -1495,65 +1494,54 @@ class StableDiffusionGGML {
14951494
}
14961495

14971496
EasyCacheState easycache_state;
1497+
UCacheState ucache_state;
14981498
bool easycache_enabled = false;
1499-
if (easycache_params != nullptr && easycache_params->enabled) {
1500-
bool easycache_supported = sd_version_is_dit(version);
1501-
if (!easycache_supported) {
1502-
LOG_WARN("EasyCache requested but not supported for this model type");
1503-
} else {
1504-
EasyCacheConfig easycache_config;
1505-
easycache_config.enabled = true;
1506-
easycache_config.reuse_threshold = std::max(0.0f, easycache_params->reuse_threshold);
1507-
easycache_config.start_percent = easycache_params->start_percent;
1508-
easycache_config.end_percent = easycache_params->end_percent;
1509-
bool percent_valid = easycache_config.start_percent >= 0.0f &&
1510-
easycache_config.start_percent < 1.0f &&
1511-
easycache_config.end_percent > 0.0f &&
1512-
easycache_config.end_percent <= 1.0f &&
1513-
easycache_config.start_percent < easycache_config.end_percent;
1514-
if (!percent_valid) {
1515-
LOG_WARN("EasyCache disabled due to invalid percent range (start=%.3f, end=%.3f)",
1516-
easycache_config.start_percent,
1517-
easycache_config.end_percent);
1499+
bool ucache_enabled = false;
1500+
1501+
if (cache_params != nullptr && cache_params->mode != SD_CACHE_DISABLED) {
1502+
bool percent_valid = cache_params->start_percent >= 0.0f &&
1503+
cache_params->start_percent < 1.0f &&
1504+
cache_params->end_percent > 0.0f &&
1505+
cache_params->end_percent <= 1.0f &&
1506+
cache_params->start_percent < cache_params->end_percent;
1507+
1508+
if (!percent_valid) {
1509+
LOG_WARN("Cache disabled due to invalid percent range (start=%.3f, end=%.3f)",
1510+
cache_params->start_percent,
1511+
cache_params->end_percent);
1512+
} else if (cache_params->mode == SD_CACHE_EASYCACHE) {
1513+
bool easycache_supported = sd_version_is_dit(version);
1514+
if (!easycache_supported) {
1515+
LOG_WARN("EasyCache requested but not supported for this model type");
15181516
} else {
1517+
EasyCacheConfig easycache_config;
1518+
easycache_config.enabled = true;
1519+
easycache_config.reuse_threshold = std::max(0.0f, cache_params->reuse_threshold);
1520+
easycache_config.start_percent = cache_params->start_percent;
1521+
easycache_config.end_percent = cache_params->end_percent;
15191522
easycache_state.init(easycache_config, denoiser.get());
15201523
if (easycache_state.enabled()) {
15211524
easycache_enabled = true;
1522-
LOG_INFO("EasyCache enabled - threshold: %.3f, start_percent: %.2f, end_percent: %.2f",
1525+
LOG_INFO("EasyCache enabled - threshold: %.3f, start: %.2f, end: %.2f",
15231526
easycache_config.reuse_threshold,
15241527
easycache_config.start_percent,
15251528
easycache_config.end_percent);
15261529
} else {
15271530
LOG_WARN("EasyCache requested but could not be initialized for this run");
15281531
}
15291532
}
1530-
}
1531-
}
1532-
1533-
UCacheState ucache_state;
1534-
bool ucache_enabled = false;
1535-
if (ucache_params != nullptr && ucache_params->enabled) {
1536-
bool ucache_supported = sd_version_is_unet(version);
1537-
if (!ucache_supported) {
1538-
LOG_WARN("UCache requested but not supported for this model type (only UNET models)");
1539-
} else {
1540-
UCacheConfig ucache_config;
1541-
ucache_config.enabled = true;
1542-
ucache_config.reuse_threshold = std::max(0.0f, ucache_params->reuse_threshold);
1543-
ucache_config.start_percent = ucache_params->start_percent;
1544-
ucache_config.end_percent = ucache_params->end_percent;
1545-
ucache_config.error_decay_rate = std::max(0.0f, std::min(1.0f, ucache_params->error_decay_rate));
1546-
ucache_config.use_relative_threshold = ucache_params->use_relative_threshold;
1547-
bool percent_valid = ucache_config.start_percent >= 0.0f &&
1548-
ucache_config.start_percent < 1.0f &&
1549-
ucache_config.end_percent > 0.0f &&
1550-
ucache_config.end_percent <= 1.0f &&
1551-
ucache_config.start_percent < ucache_config.end_percent;
1552-
if (!percent_valid) {
1553-
LOG_WARN("UCache disabled due to invalid percent range (start=%.3f, end=%.3f)",
1554-
ucache_config.start_percent,
1555-
ucache_config.end_percent);
1533+
} else if (cache_params->mode == SD_CACHE_UCACHE) {
1534+
bool ucache_supported = sd_version_is_unet(version);
1535+
if (!ucache_supported) {
1536+
LOG_WARN("UCache requested but not supported for this model type (only UNET models)");
15561537
} else {
1538+
UCacheConfig ucache_config;
1539+
ucache_config.enabled = true;
1540+
ucache_config.reuse_threshold = std::max(0.0f, cache_params->reuse_threshold);
1541+
ucache_config.start_percent = cache_params->start_percent;
1542+
ucache_config.end_percent = cache_params->end_percent;
1543+
ucache_config.error_decay_rate = std::max(0.0f, std::min(1.0f, cache_params->error_decay_rate));
1544+
ucache_config.use_relative_threshold = cache_params->use_relative_threshold;
15571545
ucache_state.init(ucache_config, denoiser.get());
15581546
if (ucache_state.enabled()) {
15591547
ucache_enabled = true;
@@ -2601,22 +2589,14 @@ enum lora_apply_mode_t str_to_lora_apply_mode(const char* str) {
26012589
return LORA_APPLY_MODE_COUNT;
26022590
}
26032591

2604-
void sd_easycache_params_init(sd_easycache_params_t* easycache_params) {
2605-
*easycache_params = {};
2606-
easycache_params->enabled = false;
2607-
easycache_params->reuse_threshold = 0.2f;
2608-
easycache_params->start_percent = 0.15f;
2609-
easycache_params->end_percent = 0.95f;
2610-
}
2611-
2612-
void sd_ucache_params_init(sd_ucache_params_t* ucache_params) {
2613-
*ucache_params = {};
2614-
ucache_params->enabled = false;
2615-
ucache_params->reuse_threshold = 1.0f;
2616-
ucache_params->start_percent = 0.15f;
2617-
ucache_params->end_percent = 0.95f;
2618-
ucache_params->error_decay_rate = 1.0f;
2619-
ucache_params->use_relative_threshold = true;
2592+
void sd_cache_params_init(sd_cache_params_t* cache_params) {
2593+
*cache_params = {};
2594+
cache_params->mode = SD_CACHE_DISABLED;
2595+
cache_params->reuse_threshold = 1.0f;
2596+
cache_params->start_percent = 0.15f;
2597+
cache_params->end_percent = 0.95f;
2598+
cache_params->error_decay_rate = 1.0f;
2599+
cache_params->use_relative_threshold = true;
26202600
}
26212601

26222602
void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
@@ -2777,8 +2757,7 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) {
27772757
sd_img_gen_params->control_strength = 0.9f;
27782758
sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f};
27792759
sd_img_gen_params->vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
2780-
sd_easycache_params_init(&sd_img_gen_params->easycache);
2781-
sd_ucache_params_init(&sd_img_gen_params->ucache);
2760+
sd_cache_params_init(&sd_img_gen_params->cache);
27822761
}
27832762

27842763
char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
@@ -2822,12 +2801,18 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
28222801
sd_img_gen_params->pm_params.id_images_count,
28232802
SAFE_STR(sd_img_gen_params->pm_params.id_embed_path),
28242803
BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled));
2804+
const char* cache_mode_str = "disabled";
2805+
if (sd_img_gen_params->cache.mode == SD_CACHE_EASYCACHE) {
2806+
cache_mode_str = "easycache";
2807+
} else if (sd_img_gen_params->cache.mode == SD_CACHE_UCACHE) {
2808+
cache_mode_str = "ucache";
2809+
}
28252810
snprintf(buf + strlen(buf), 4096 - strlen(buf),
2826-
"easycache: %s (threshold=%.3f, start=%.2f, end=%.2f)\n",
2827-
sd_img_gen_params->easycache.enabled ? "enabled" : "disabled",
2828-
sd_img_gen_params->easycache.reuse_threshold,
2829-
sd_img_gen_params->easycache.start_percent,
2830-
sd_img_gen_params->easycache.end_percent);
2811+
"cache: %s (threshold=%.3f, start=%.2f, end=%.2f)\n",
2812+
cache_mode_str,
2813+
sd_img_gen_params->cache.reuse_threshold,
2814+
sd_img_gen_params->cache.start_percent,
2815+
sd_img_gen_params->cache.end_percent);
28312816
free(sample_params_str);
28322817
return buf;
28332818
}
@@ -2844,8 +2829,7 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
28442829
sd_vid_gen_params->video_frames = 6;
28452830
sd_vid_gen_params->moe_boundary = 0.875f;
28462831
sd_vid_gen_params->vace_strength = 1.f;
2847-
sd_easycache_params_init(&sd_vid_gen_params->easycache);
2848-
sd_ucache_params_init(&sd_vid_gen_params->ucache);
2832+
sd_cache_params_init(&sd_vid_gen_params->cache);
28492833
}
28502834

28512835
struct sd_ctx_t {
@@ -2923,8 +2907,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
29232907
bool increase_ref_index,
29242908
ggml_tensor* concat_latent = nullptr,
29252909
ggml_tensor* denoise_mask = nullptr,
2926-
const sd_easycache_params_t* easycache_params = nullptr,
2927-
const sd_ucache_params_t* ucache_params = nullptr) {
2910+
const sd_cache_params_t* cache_params = nullptr) {
29282911
if (seed < 0) {
29292912
// Generally, when using the provided command line, the seed is always >0.
29302913
// However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library
@@ -3213,8 +3196,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
32133196
denoise_mask,
32143197
nullptr,
32153198
1.0f,
3216-
easycache_params,
3217-
ucache_params);
3199+
cache_params);
32183200
int64_t sampling_end = ggml_time_ms();
32193201
if (x_0 != nullptr) {
32203202
// print_ggml_tensor(x_0);
@@ -3548,8 +3530,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
35483530
sd_img_gen_params->increase_ref_index,
35493531
concat_latent,
35503532
denoise_mask,
3551-
&sd_img_gen_params->easycache,
3552-
&sd_img_gen_params->ucache);
3533+
&sd_img_gen_params->cache);
35533534

35543535
size_t t2 = ggml_time_ms();
35553536

@@ -3914,8 +3895,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
39143895
denoise_mask,
39153896
vace_context,
39163897
sd_vid_gen_params->vace_strength,
3917-
&sd_vid_gen_params->easycache,
3918-
&sd_vid_gen_params->ucache);
3898+
&sd_vid_gen_params->cache);
39193899

39203900
int64_t sampling_end = ggml_time_ms();
39213901
LOG_INFO("sampling(high noise) completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
@@ -3952,8 +3932,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
39523932
denoise_mask,
39533933
vace_context,
39543934
sd_vid_gen_params->vace_strength,
3955-
&sd_vid_gen_params->easycache,
3956-
&sd_vid_gen_params->ucache);
3935+
&sd_vid_gen_params->cache);
39573936

39583937
int64_t sampling_end = ggml_time_ms();
39593938
LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);

stable-diffusion.h

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -236,21 +236,20 @@ typedef struct {
236236
float style_strength;
237237
} sd_pm_params_t; // photo maker
238238

239-
typedef struct {
240-
bool enabled;
241-
float reuse_threshold;
242-
float start_percent;
243-
float end_percent;
244-
} sd_easycache_params_t;
239+
enum sd_cache_mode_t {
240+
SD_CACHE_DISABLED = 0,
241+
SD_CACHE_EASYCACHE,
242+
SD_CACHE_UCACHE,
243+
};
245244

246245
typedef struct {
247-
bool enabled;
246+
enum sd_cache_mode_t mode;
248247
float reuse_threshold;
249248
float start_percent;
250249
float end_percent;
251250
float error_decay_rate;
252251
bool use_relative_threshold;
253-
} sd_ucache_params_t;
252+
} sd_cache_params_t;
254253

255254
typedef struct {
256255
bool is_high_noise;
@@ -280,8 +279,7 @@ typedef struct {
280279
float control_strength;
281280
sd_pm_params_t pm_params;
282281
sd_tiling_params_t vae_tiling_params;
283-
sd_easycache_params_t easycache;
284-
sd_ucache_params_t ucache;
282+
sd_cache_params_t cache;
285283
} sd_img_gen_params_t;
286284

287285
typedef struct {
@@ -303,8 +301,7 @@ typedef struct {
303301
int64_t seed;
304302
int video_frames;
305303
float vace_strength;
306-
sd_easycache_params_t easycache;
307-
sd_ucache_params_t ucache;
304+
sd_cache_params_t cache;
308305
} sd_vid_gen_params_t;
309306

310307
typedef struct sd_ctx_t sd_ctx_t;
@@ -334,8 +331,7 @@ SD_API enum preview_t str_to_preview(const char* str);
334331
SD_API const char* sd_lora_apply_mode_name(enum lora_apply_mode_t mode);
335332
SD_API enum lora_apply_mode_t str_to_lora_apply_mode(const char* str);
336333

337-
SD_API void sd_easycache_params_init(sd_easycache_params_t* easycache_params);
338-
SD_API void sd_ucache_params_init(sd_ucache_params_t* ucache_params);
334+
SD_API void sd_cache_params_init(sd_cache_params_t* cache_params);
339335

340336
SD_API void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params);
341337
SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params);

0 commit comments

Comments
 (0)