@@ -1054,10 +1054,9 @@ struct SDGenerationParams {
10541054 std::vector<int > high_noise_skip_layers = {7 , 8 , 9 };
10551055 sd_sample_params_t high_noise_sample_params;
10561056
1057- std::string easycache_option;
1057+ std::string cache_mode;
1058+ std::string cache_option;
10581059 sd_easycache_params_t easycache_params;
1059-
1060- std::string ucache_option;
10611060 sd_ucache_params_t ucache_params;
10621061
10631062 float moe_boundary = 0 .875f ;
@@ -1378,68 +1377,24 @@ struct SDGenerationParams {
13781377 return 1 ;
13791378 };
13801379
1381- auto on_easycache_arg = [&](int argc, const char ** argv, int index) {
1382- const std::string default_values = " 0.2,0.15,0.95" ;
1383- auto looks_like_value = [](const std::string& token) {
1384- if (token.empty ()) {
1385- return false ;
1386- }
1387- if (token[0 ] != ' -' ) {
1388- return true ;
1389- }
1390- if (token.size () == 1 ) {
1391- return false ;
1392- }
1393- unsigned char next = static_cast <unsigned char >(token[1 ]);
1394- return std::isdigit (next) || token[1 ] == ' .' ;
1395- };
1396-
1397- std::string option_value;
1398- int consumed = 0 ;
1399- if (index + 1 < argc) {
1400- std::string next_arg = argv[index + 1 ];
1401- if (looks_like_value (next_arg)) {
1402- option_value = argv_to_utf8 (index + 1 , argv);
1403- consumed = 1 ;
1404- }
1380+ auto on_cache_mode_arg = [&](int argc, const char ** argv, int index) {
1381+ if (++index >= argc) {
1382+ return -1 ;
14051383 }
1406- if (option_value.empty ()) {
1407- option_value = default_values;
1384+ cache_mode = argv_to_utf8 (index, argv);
1385+ if (cache_mode != " easycache" && cache_mode != " ucache" ) {
1386+ fprintf (stderr, " error: invalid cache mode '%s', must be 'easycache' or 'ucache'\n " , cache_mode.c_str ());
1387+ return -1 ;
14081388 }
1409- easycache_option = option_value;
1410- return consumed;
1389+ return 1 ;
14111390 };
14121391
1413- auto on_ucache_arg = [&](int argc, const char ** argv, int index) {
1414- const std::string default_values = " 1.0,0.15,0.95" ;
1415- auto looks_like_value = [](const std::string& token) {
1416- if (token.empty ()) {
1417- return false ;
1418- }
1419- if (token[0 ] != ' -' ) {
1420- return true ;
1421- }
1422- if (token.size () == 1 ) {
1423- return false ;
1424- }
1425- unsigned char next = static_cast <unsigned char >(token[1 ]);
1426- return std::isdigit (next) || token[1 ] == ' .' ;
1427- };
1428-
1429- std::string option_value;
1430- int consumed = 0 ;
1431- if (index + 1 < argc) {
1432- std::string next_arg = argv[index + 1 ];
1433- if (looks_like_value (next_arg)) {
1434- option_value = argv_to_utf8 (index + 1 , argv);
1435- consumed = 1 ;
1436- }
1437- }
1438- if (option_value.empty ()) {
1439- option_value = default_values;
1392+ auto on_cache_option_arg = [&](int argc, const char ** argv, int index) {
1393+ if (++index >= argc) {
1394+ return -1 ;
14401395 }
1441- ucache_option = option_value ;
1442- return consumed ;
1396+ cache_option = argv_to_utf8 (index, argv) ;
1397+ return 1 ;
14431398 };
14441399
14451400 options.manual_options = {
@@ -1474,13 +1429,13 @@ struct SDGenerationParams {
14741429 " reference image for Flux Kontext models (can be used multiple times)" ,
14751430 on_ref_image_arg},
14761431 {" " ,
1477- " --easycache " ,
1478- " enable EasyCache for DiT models with optional \" threshold,start_percent,end_percent \" (default: 0.2,0.15,0.95 )" ,
1479- on_easycache_arg },
1432+ " --cache-mode " ,
1433+ " caching method: 'easycache' for DiT models, 'ucache' for UNET models (SD1.x/SD2.x/SDXL )" ,
1434+ on_cache_mode_arg },
14801435 {" " ,
1481- " --ucache " ,
1482- " enable UCache for UNET models (SD1.x/SD2.x/SDXL) with optional \" threshold,start_percent,end_percent\" (default: 1.0,0.15,0.95)" ,
1483- on_ucache_arg },
1436+ " --cache-option " ,
1437+ " cache parameters \" threshold,start_percent,end_percent\" (default: 0.2,0.15,0.95 for easycache, 1.0,0.15,0.95 for ucache )" ,
1438+ on_cache_option_arg },
14841439
14851440 };
14861441
@@ -1593,62 +1548,21 @@ struct SDGenerationParams {
15931548 return false ;
15941549 }
15951550
1596- if (!easycache_option.empty ()) {
1597- float values[3 ] = {0 .0f , 0 .0f , 0 .0f };
1598- std::stringstream ss (easycache_option);
1599- std::string token;
1600- int idx = 0 ;
1601- while (std::getline (ss, token, ' ,' )) {
1602- auto trim = [](std::string& s) {
1603- const char * whitespace = " \t\r\n " ;
1604- auto start = s.find_first_not_of (whitespace);
1605- if (start == std::string::npos) {
1606- s.clear ();
1607- return ;
1608- }
1609- auto end = s.find_last_not_of (whitespace);
1610- s = s.substr (start, end - start + 1 );
1611- };
1612- trim (token);
1613- if (token.empty ()) {
1614- fprintf (stderr, " error: invalid easycache option '%s'\n " , easycache_option.c_str ());
1615- return false ;
1616- }
1617- if (idx >= 3 ) {
1618- fprintf (stderr, " error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n " );
1619- return false ;
1620- }
1621- try {
1622- values[idx] = std::stof (token);
1623- } catch (const std::exception&) {
1624- fprintf (stderr, " error: invalid easycache value '%s'\n " , token.c_str ());
1625- return false ;
1551+ easycache_params.enabled = false ;
1552+ ucache_params.enabled = false ;
1553+
1554+ if (!cache_mode.empty ()) {
1555+ std::string option_str = cache_option;
1556+ if (option_str.empty ()) {
1557+ if (cache_mode == " easycache" ) {
1558+ option_str = " 0.2,0.15,0.95" ;
1559+ } else {
1560+ option_str = " 1.0,0.15,0.95" ;
16261561 }
1627- idx++;
1628- }
1629- if (idx != 3 ) {
1630- fprintf (stderr, " error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n " );
1631- return false ;
16321562 }
1633- if (values[0 ] < 0 .0f ) {
1634- fprintf (stderr, " error: easycache threshold must be non-negative\n " );
1635- return false ;
1636- }
1637- if (values[1 ] < 0 .0f || values[1 ] >= 1 .0f || values[2 ] <= 0 .0f || values[2 ] > 1 .0f || values[1 ] >= values[2 ]) {
1638- fprintf (stderr, " error: easycache start/end percents must satisfy 0.0 <= start < end <= 1.0\n " );
1639- return false ;
1640- }
1641- easycache_params.enabled = true ;
1642- easycache_params.reuse_threshold = values[0 ];
1643- easycache_params.start_percent = values[1 ];
1644- easycache_params.end_percent = values[2 ];
1645- } else {
1646- easycache_params.enabled = false ;
1647- }
16481563
1649- if (!ucache_option.empty ()) {
16501564 float values[3 ] = {0 .0f , 0 .0f , 0 .0f };
1651- std::stringstream ss (ucache_option );
1565+ std::stringstream ss (option_str );
16521566 std::string token;
16531567 int idx = 0 ;
16541568 while (std::getline (ss, token, ' ,' )) {
@@ -1664,39 +1578,45 @@ struct SDGenerationParams {
16641578 };
16651579 trim (token);
16661580 if (token.empty ()) {
1667- fprintf (stderr, " error: invalid ucache option '%s'\n " , ucache_option .c_str ());
1581+ fprintf (stderr, " error: invalid cache option '%s'\n " , option_str .c_str ());
16681582 return false ;
16691583 }
16701584 if (idx >= 3 ) {
1671- fprintf (stderr, " error: ucache expects exactly 3 comma-separated values (threshold,start,end)\n " );
1585+ fprintf (stderr, " error: cache option expects exactly 3 comma-separated values (threshold,start,end)\n " );
16721586 return false ;
16731587 }
16741588 try {
16751589 values[idx] = std::stof (token);
16761590 } catch (const std::exception&) {
1677- fprintf (stderr, " error: invalid ucache value '%s'\n " , token.c_str ());
1591+ fprintf (stderr, " error: invalid cache option value '%s'\n " , token.c_str ());
16781592 return false ;
16791593 }
16801594 idx++;
16811595 }
16821596 if (idx != 3 ) {
1683- fprintf (stderr, " error: ucache expects exactly 3 comma-separated values (threshold,start,end)\n " );
1597+ fprintf (stderr, " error: cache option expects exactly 3 comma-separated values (threshold,start,end)\n " );
16841598 return false ;
16851599 }
16861600 if (values[0 ] < 0 .0f ) {
1687- fprintf (stderr, " error: ucache threshold must be non-negative\n " );
1601+ fprintf (stderr, " error: cache threshold must be non-negative\n " );
16881602 return false ;
16891603 }
16901604 if (values[1 ] < 0 .0f || values[1 ] >= 1 .0f || values[2 ] <= 0 .0f || values[2 ] > 1 .0f || values[1 ] >= values[2 ]) {
1691- fprintf (stderr, " error: ucache start/end percents must satisfy 0.0 <= start < end <= 1.0\n " );
1605+ fprintf (stderr, " error: cache start/end percents must satisfy 0.0 <= start < end <= 1.0\n " );
16921606 return false ;
16931607 }
1694- ucache_params.enabled = true ;
1695- ucache_params.reuse_threshold = values[0 ];
1696- ucache_params.start_percent = values[1 ];
1697- ucache_params.end_percent = values[2 ];
1698- } else {
1699- ucache_params.enabled = false ;
1608+
1609+ if (cache_mode == " easycache" ) {
1610+ easycache_params.enabled = true ;
1611+ easycache_params.reuse_threshold = values[0 ];
1612+ easycache_params.start_percent = values[1 ];
1613+ easycache_params.end_percent = values[2 ];
1614+ } else {
1615+ ucache_params.enabled = true ;
1616+ ucache_params.reuse_threshold = values[0 ];
1617+ ucache_params.start_percent = values[1 ];
1618+ ucache_params.end_percent = values[2 ];
1619+ }
17001620 }
17011621
17021622 sample_params.guidance .slg .layers = skip_layers.data ();
@@ -1791,12 +1711,18 @@ struct SDGenerationParams {
17911711 << " sample_params: " << sample_params_str << " ,\n "
17921712 << " high_noise_skip_layers: " << vec_to_string (high_noise_skip_layers) << " ,\n "
17931713 << " high_noise_sample_params: " << high_noise_sample_params_str << " ,\n "
1794- << " easycache_option: \" " << easycache_option << " \" ,\n "
1714+ << " cache_mode: \" " << cache_mode << " \" ,\n "
1715+ << " cache_option: \" " << cache_option << " \" ,\n "
17951716 << " easycache: "
17961717 << (easycache_params.enabled ? " enabled" : " disabled" )
17971718 << " (threshold=" << easycache_params.reuse_threshold
17981719 << " , start=" << easycache_params.start_percent
17991720 << " , end=" << easycache_params.end_percent << " ),\n "
1721+ << " ucache: "
1722+ << (ucache_params.enabled ? " enabled" : " disabled" )
1723+ << " (threshold=" << ucache_params.reuse_threshold
1724+ << " , start=" << ucache_params.start_percent
1725+ << " , end=" << ucache_params.end_percent << " ),\n "
18001726 << " moe_boundary: " << moe_boundary << " ,\n "
18011727 << " video_frames: " << video_frames << " ,\n "
18021728 << " fps: " << fps << " ,\n "
0 commit comments