Skip to content

Commit f347010

Browse files
committed
add cache-mode and cache-option
1 parent fb88d86 commit f347010

File tree

2 files changed

+59
-133
lines changed

2 files changed

+59
-133
lines changed

examples/cli/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,6 @@ Generation Options:
126126
--skip-layers layers to skip for SLG steps (default: [7,8,9])
127127
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
128128
-r, --ref-image reference image for Flux Kontext models (can be used multiple times)
129-
--easycache enable EasyCache for DiT models with optional "threshold,start_percent,end_percent" (default: 0.2,0.15,0.95)
130-
--ucache enable UCache for UNET models with optional "threshold,start_percent,end_percent" (default: 1,0.15,0.95)
129+
--cache-mode caching method: 'easycache' for DiT models, 'ucache' for UNET models (SD1.x/SD2.x/SDXL)
130+
--cache-option cache parameters "threshold,start_percent,end_percent" (default: 0.2,0.15,0.95 for easycache, 1.0,0.15,0.95 for ucache)
131131
```

examples/cli/main.cpp

Lines changed: 57 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,10 +1054,9 @@ struct SDGenerationParams {
10541054
std::vector<int> high_noise_skip_layers = {7, 8, 9};
10551055
sd_sample_params_t high_noise_sample_params;
10561056

1057-
std::string easycache_option;
1057+
std::string cache_mode;
1058+
std::string cache_option;
10581059
sd_easycache_params_t easycache_params;
1059-
1060-
std::string ucache_option;
10611060
sd_ucache_params_t ucache_params;
10621061

10631062
float moe_boundary = 0.875f;
@@ -1378,68 +1377,24 @@ struct SDGenerationParams {
13781377
return 1;
13791378
};
13801379

1381-
auto on_easycache_arg = [&](int argc, const char** argv, int index) {
1382-
const std::string default_values = "0.2,0.15,0.95";
1383-
auto looks_like_value = [](const std::string& token) {
1384-
if (token.empty()) {
1385-
return false;
1386-
}
1387-
if (token[0] != '-') {
1388-
return true;
1389-
}
1390-
if (token.size() == 1) {
1391-
return false;
1392-
}
1393-
unsigned char next = static_cast<unsigned char>(token[1]);
1394-
return std::isdigit(next) || token[1] == '.';
1395-
};
1396-
1397-
std::string option_value;
1398-
int consumed = 0;
1399-
if (index + 1 < argc) {
1400-
std::string next_arg = argv[index + 1];
1401-
if (looks_like_value(next_arg)) {
1402-
option_value = argv_to_utf8(index + 1, argv);
1403-
consumed = 1;
1404-
}
1380+
auto on_cache_mode_arg = [&](int argc, const char** argv, int index) {
1381+
if (++index >= argc) {
1382+
return -1;
14051383
}
1406-
if (option_value.empty()) {
1407-
option_value = default_values;
1384+
cache_mode = argv_to_utf8(index, argv);
1385+
if (cache_mode != "easycache" && cache_mode != "ucache") {
1386+
fprintf(stderr, "error: invalid cache mode '%s', must be 'easycache' or 'ucache'\n", cache_mode.c_str());
1387+
return -1;
14081388
}
1409-
easycache_option = option_value;
1410-
return consumed;
1389+
return 1;
14111390
};
14121391

1413-
auto on_ucache_arg = [&](int argc, const char** argv, int index) {
1414-
const std::string default_values = "1.0,0.15,0.95";
1415-
auto looks_like_value = [](const std::string& token) {
1416-
if (token.empty()) {
1417-
return false;
1418-
}
1419-
if (token[0] != '-') {
1420-
return true;
1421-
}
1422-
if (token.size() == 1) {
1423-
return false;
1424-
}
1425-
unsigned char next = static_cast<unsigned char>(token[1]);
1426-
return std::isdigit(next) || token[1] == '.';
1427-
};
1428-
1429-
std::string option_value;
1430-
int consumed = 0;
1431-
if (index + 1 < argc) {
1432-
std::string next_arg = argv[index + 1];
1433-
if (looks_like_value(next_arg)) {
1434-
option_value = argv_to_utf8(index + 1, argv);
1435-
consumed = 1;
1436-
}
1437-
}
1438-
if (option_value.empty()) {
1439-
option_value = default_values;
1392+
auto on_cache_option_arg = [&](int argc, const char** argv, int index) {
1393+
if (++index >= argc) {
1394+
return -1;
14401395
}
1441-
ucache_option = option_value;
1442-
return consumed;
1396+
cache_option = argv_to_utf8(index, argv);
1397+
return 1;
14431398
};
14441399

14451400
options.manual_options = {
@@ -1474,13 +1429,13 @@ struct SDGenerationParams {
14741429
"reference image for Flux Kontext models (can be used multiple times)",
14751430
on_ref_image_arg},
14761431
{"",
1477-
"--easycache",
1478-
"enable EasyCache for DiT models with optional \"threshold,start_percent,end_percent\" (default: 0.2,0.15,0.95)",
1479-
on_easycache_arg},
1432+
"--cache-mode",
1433+
"caching method: 'easycache' for DiT models, 'ucache' for UNET models (SD1.x/SD2.x/SDXL)",
1434+
on_cache_mode_arg},
14801435
{"",
1481-
"--ucache",
1482-
"enable UCache for UNET models (SD1.x/SD2.x/SDXL) with optional \"threshold,start_percent,end_percent\" (default: 1.0,0.15,0.95)",
1483-
on_ucache_arg},
1436+
"--cache-option",
1437+
"cache parameters \"threshold,start_percent,end_percent\" (default: 0.2,0.15,0.95 for easycache, 1.0,0.15,0.95 for ucache)",
1438+
on_cache_option_arg},
14841439

14851440
};
14861441

@@ -1593,62 +1548,21 @@ struct SDGenerationParams {
15931548
return false;
15941549
}
15951550

1596-
if (!easycache_option.empty()) {
1597-
float values[3] = {0.0f, 0.0f, 0.0f};
1598-
std::stringstream ss(easycache_option);
1599-
std::string token;
1600-
int idx = 0;
1601-
while (std::getline(ss, token, ',')) {
1602-
auto trim = [](std::string& s) {
1603-
const char* whitespace = " \t\r\n";
1604-
auto start = s.find_first_not_of(whitespace);
1605-
if (start == std::string::npos) {
1606-
s.clear();
1607-
return;
1608-
}
1609-
auto end = s.find_last_not_of(whitespace);
1610-
s = s.substr(start, end - start + 1);
1611-
};
1612-
trim(token);
1613-
if (token.empty()) {
1614-
fprintf(stderr, "error: invalid easycache option '%s'\n", easycache_option.c_str());
1615-
return false;
1616-
}
1617-
if (idx >= 3) {
1618-
fprintf(stderr, "error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n");
1619-
return false;
1620-
}
1621-
try {
1622-
values[idx] = std::stof(token);
1623-
} catch (const std::exception&) {
1624-
fprintf(stderr, "error: invalid easycache value '%s'\n", token.c_str());
1625-
return false;
1551+
easycache_params.enabled = false;
1552+
ucache_params.enabled = false;
1553+
1554+
if (!cache_mode.empty()) {
1555+
std::string option_str = cache_option;
1556+
if (option_str.empty()) {
1557+
if (cache_mode == "easycache") {
1558+
option_str = "0.2,0.15,0.95";
1559+
} else {
1560+
option_str = "1.0,0.15,0.95";
16261561
}
1627-
idx++;
1628-
}
1629-
if (idx != 3) {
1630-
fprintf(stderr, "error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n");
1631-
return false;
16321562
}
1633-
if (values[0] < 0.0f) {
1634-
fprintf(stderr, "error: easycache threshold must be non-negative\n");
1635-
return false;
1636-
}
1637-
if (values[1] < 0.0f || values[1] >= 1.0f || values[2] <= 0.0f || values[2] > 1.0f || values[1] >= values[2]) {
1638-
fprintf(stderr, "error: easycache start/end percents must satisfy 0.0 <= start < end <= 1.0\n");
1639-
return false;
1640-
}
1641-
easycache_params.enabled = true;
1642-
easycache_params.reuse_threshold = values[0];
1643-
easycache_params.start_percent = values[1];
1644-
easycache_params.end_percent = values[2];
1645-
} else {
1646-
easycache_params.enabled = false;
1647-
}
16481563

1649-
if (!ucache_option.empty()) {
16501564
float values[3] = {0.0f, 0.0f, 0.0f};
1651-
std::stringstream ss(ucache_option);
1565+
std::stringstream ss(option_str);
16521566
std::string token;
16531567
int idx = 0;
16541568
while (std::getline(ss, token, ',')) {
@@ -1664,39 +1578,45 @@ struct SDGenerationParams {
16641578
};
16651579
trim(token);
16661580
if (token.empty()) {
1667-
fprintf(stderr, "error: invalid ucache option '%s'\n", ucache_option.c_str());
1581+
fprintf(stderr, "error: invalid cache option '%s'\n", option_str.c_str());
16681582
return false;
16691583
}
16701584
if (idx >= 3) {
1671-
fprintf(stderr, "error: ucache expects exactly 3 comma-separated values (threshold,start,end)\n");
1585+
fprintf(stderr, "error: cache option expects exactly 3 comma-separated values (threshold,start,end)\n");
16721586
return false;
16731587
}
16741588
try {
16751589
values[idx] = std::stof(token);
16761590
} catch (const std::exception&) {
1677-
fprintf(stderr, "error: invalid ucache value '%s'\n", token.c_str());
1591+
fprintf(stderr, "error: invalid cache option value '%s'\n", token.c_str());
16781592
return false;
16791593
}
16801594
idx++;
16811595
}
16821596
if (idx != 3) {
1683-
fprintf(stderr, "error: ucache expects exactly 3 comma-separated values (threshold,start,end)\n");
1597+
fprintf(stderr, "error: cache option expects exactly 3 comma-separated values (threshold,start,end)\n");
16841598
return false;
16851599
}
16861600
if (values[0] < 0.0f) {
1687-
fprintf(stderr, "error: ucache threshold must be non-negative\n");
1601+
fprintf(stderr, "error: cache threshold must be non-negative\n");
16881602
return false;
16891603
}
16901604
if (values[1] < 0.0f || values[1] >= 1.0f || values[2] <= 0.0f || values[2] > 1.0f || values[1] >= values[2]) {
1691-
fprintf(stderr, "error: ucache start/end percents must satisfy 0.0 <= start < end <= 1.0\n");
1605+
fprintf(stderr, "error: cache start/end percents must satisfy 0.0 <= start < end <= 1.0\n");
16921606
return false;
16931607
}
1694-
ucache_params.enabled = true;
1695-
ucache_params.reuse_threshold = values[0];
1696-
ucache_params.start_percent = values[1];
1697-
ucache_params.end_percent = values[2];
1698-
} else {
1699-
ucache_params.enabled = false;
1608+
1609+
if (cache_mode == "easycache") {
1610+
easycache_params.enabled = true;
1611+
easycache_params.reuse_threshold = values[0];
1612+
easycache_params.start_percent = values[1];
1613+
easycache_params.end_percent = values[2];
1614+
} else {
1615+
ucache_params.enabled = true;
1616+
ucache_params.reuse_threshold = values[0];
1617+
ucache_params.start_percent = values[1];
1618+
ucache_params.end_percent = values[2];
1619+
}
17001620
}
17011621

17021622
sample_params.guidance.slg.layers = skip_layers.data();
@@ -1791,12 +1711,18 @@ struct SDGenerationParams {
17911711
<< " sample_params: " << sample_params_str << ",\n"
17921712
<< " high_noise_skip_layers: " << vec_to_string(high_noise_skip_layers) << ",\n"
17931713
<< " high_noise_sample_params: " << high_noise_sample_params_str << ",\n"
1794-
<< " easycache_option: \"" << easycache_option << "\",\n"
1714+
<< " cache_mode: \"" << cache_mode << "\",\n"
1715+
<< " cache_option: \"" << cache_option << "\",\n"
17951716
<< " easycache: "
17961717
<< (easycache_params.enabled ? "enabled" : "disabled")
17971718
<< " (threshold=" << easycache_params.reuse_threshold
17981719
<< ", start=" << easycache_params.start_percent
17991720
<< ", end=" << easycache_params.end_percent << "),\n"
1721+
<< " ucache: "
1722+
<< (ucache_params.enabled ? "enabled" : "disabled")
1723+
<< " (threshold=" << ucache_params.reuse_threshold
1724+
<< ", start=" << ucache_params.start_percent
1725+
<< ", end=" << ucache_params.end_percent << "),\n"
18001726
<< " moe_boundary: " << moe_boundary << ",\n"
18011727
<< " video_frames: " << video_frames << ",\n"
18021728
<< " fps: " << fps << ",\n"

0 commit comments

Comments
 (0)