@@ -1477,7 +1477,7 @@ struct SDGenerationParams {
14771477 on_cache_mode_arg},
14781478 {" " ,
14791479 " --cache-option" ,
1480- " cache params - legacy: \" threshold,start,end[ ,decay,relative] \" , cache-dit: \" Fn ,Bn,threshold,warmup\" (default: 8,0,0.08,8) " ,
1480+ " named cache params: easycache/ucache: threshold= ,start= ,end= ,decay= ,relative= | cache-dit: Fn= ,Bn= ,threshold= ,warmup= " ,
14811481 on_cache_option_arg},
14821482 {" " ,
14831483 " --scm-mask" ,
@@ -1606,88 +1606,125 @@ struct SDGenerationParams {
16061606 cache_params.mode = SD_CACHE_DISABLED;
16071607
16081608 if (!cache_mode.empty ()) {
1609- std::string option_str = cache_option;
1610- if (option_str.empty ()) {
1611- if (cache_mode == " easycache" ) {
1612- option_str = " 0.2,0.15,0.95" ;
1613- } else if (cache_mode == " ucache" ) {
1614- option_str = " 1.0,0.15,0.95" ;
1615- } else if (cache_mode == " dbcache" || cache_mode == " taylorseer" || cache_mode == " cache-dit" ) {
1616- option_str = " 8,0,0.08,8" ;
1617- }
1618- }
1619-
1620- float values[5 ] = {0 .0f , 0 .0f , 0 .0f , 1 .0f , 1 .0f };
1621- std::stringstream ss (option_str);
1622- std::string token;
1623- int idx = 0 ;
16241609 auto trim = [](std::string& s) {
16251610 const char * whitespace = " \t\r\n " ;
1626- auto start = s.find_first_not_of (whitespace);
1627- if (start == std::string::npos) {
1628- s.clear ();
1629- return ;
1630- }
1611+ auto start = s.find_first_not_of (whitespace);
1612+ if (start == std::string::npos) { s.clear (); return ; }
16311613 auto end = s.find_last_not_of (whitespace);
1632- s = s.substr (start, end - start + 1 );
1614+ s = s.substr (start, end - start + 1 );
16331615 };
1634- while (std::getline (ss, token, ' ,' )) {
1635- trim (token);
1636- if (token.empty ()) {
1637- fprintf (stderr, " error: invalid cache option '%s'\n " , option_str.c_str ());
1638- return false ;
1639- }
1640- if (idx >= 5 ) {
1641- fprintf (stderr, " error: cache option expects 3-5 comma-separated values (threshold,start,end[,decay,relative])\n " );
1642- return false ;
1616+
1617+ auto parse_named_params = [&](const std::string& opt_str) -> bool {
1618+ std::stringstream ss (opt_str);
1619+ std::string token;
1620+ while (std::getline (ss, token, ' ,' )) {
1621+ trim (token);
1622+ if (token.empty ()) continue ;
1623+
1624+ size_t eq_pos = token.find (' =' );
1625+ if (eq_pos == std::string::npos) {
1626+ fprintf (stderr, " error: invalid named parameter '%s', expected key=value\n " , token.c_str ());
1627+ return false ;
1628+ }
1629+
1630+ std::string key = token.substr (0 , eq_pos);
1631+ std::string val = token.substr (eq_pos + 1 );
1632+ trim (key);
1633+ trim (val);
1634+
1635+ if (key.empty () || val.empty ()) {
1636+ fprintf (stderr, " error: invalid named parameter '%s'\n " , token.c_str ());
1637+ return false ;
1638+ }
1639+
1640+ try {
1641+ if (key == " threshold" ) {
1642+ if (cache_mode == " easycache" || cache_mode == " ucache" ) {
1643+ cache_params.reuse_threshold = std::stof (val);
1644+ } else {
1645+ cache_params.residual_diff_threshold = std::stof (val);
1646+ }
1647+ } else if (key == " start" ) {
1648+ cache_params.start_percent = std::stof (val);
1649+ } else if (key == " end" ) {
1650+ cache_params.end_percent = std::stof (val);
1651+ } else if (key == " decay" ) {
1652+ cache_params.error_decay_rate = std::stof (val);
1653+ } else if (key == " relative" ) {
1654+ cache_params.use_relative_threshold = (std::stof (val) != 0 .0f );
1655+ } else if (key == " Fn" || key == " fn" ) {
1656+ cache_params.Fn_compute_blocks = std::stoi (val);
1657+ } else if (key == " Bn" || key == " bn" ) {
1658+ cache_params.Bn_compute_blocks = std::stoi (val);
1659+ } else if (key == " warmup" ) {
1660+ cache_params.max_warmup_steps = std::stoi (val);
1661+ } else {
1662+ fprintf (stderr, " error: unknown cache parameter '%s'\n " , key.c_str ());
1663+ return false ;
1664+ }
1665+ } catch (const std::exception&) {
1666+ fprintf (stderr, " error: invalid value '%s' for parameter '%s'\n " , val.c_str (), key.c_str ());
1667+ return false ;
1668+ }
16431669 }
1644- try {
1645- values[idx] = std::stof (token);
1646- } catch (const std::exception&) {
1647- fprintf (stderr, " error: invalid cache option value '%s'\n " , token.c_str ());
1670+ return true ;
1671+ };
1672+
1673+ if (cache_mode == " easycache" ) {
1674+ cache_params.mode = SD_CACHE_EASYCACHE;
1675+ cache_params.reuse_threshold = 0 .2f ;
1676+ cache_params.start_percent = 0 .15f ;
1677+ cache_params.end_percent = 0 .95f ;
1678+ cache_params.error_decay_rate = 1 .0f ;
1679+ cache_params.use_relative_threshold = true ;
1680+ } else if (cache_mode == " ucache" ) {
1681+ cache_params.mode = SD_CACHE_UCACHE;
1682+ cache_params.reuse_threshold = 1 .0f ;
1683+ cache_params.start_percent = 0 .15f ;
1684+ cache_params.end_percent = 0 .95f ;
1685+ cache_params.error_decay_rate = 1 .0f ;
1686+ cache_params.use_relative_threshold = true ;
1687+ } else if (cache_mode == " dbcache" ) {
1688+ cache_params.mode = SD_CACHE_DBCACHE;
1689+ cache_params.Fn_compute_blocks = 8 ;
1690+ cache_params.Bn_compute_blocks = 0 ;
1691+ cache_params.residual_diff_threshold = 0 .08f ;
1692+ cache_params.max_warmup_steps = 8 ;
1693+ } else if (cache_mode == " taylorseer" ) {
1694+ cache_params.mode = SD_CACHE_TAYLORSEER;
1695+ cache_params.Fn_compute_blocks = 8 ;
1696+ cache_params.Bn_compute_blocks = 0 ;
1697+ cache_params.residual_diff_threshold = 0 .08f ;
1698+ cache_params.max_warmup_steps = 8 ;
1699+ } else if (cache_mode == " cache-dit" ) {
1700+ cache_params.mode = SD_CACHE_CACHE_DIT;
1701+ cache_params.Fn_compute_blocks = 8 ;
1702+ cache_params.Bn_compute_blocks = 0 ;
1703+ cache_params.residual_diff_threshold = 0 .08f ;
1704+ cache_params.max_warmup_steps = 8 ;
1705+ } else {
1706+ fprintf (stderr, " error: unknown cache mode '%s'\n " , cache_mode.c_str ());
1707+ return false ;
1708+ }
1709+
1710+ if (!cache_option.empty ()) {
1711+ if (!parse_named_params (cache_option)) {
16481712 return false ;
16491713 }
1650- idx++;
16511714 }
1715+
16521716 if (cache_mode == " easycache" || cache_mode == " ucache" ) {
1653- if (idx < 3 ) {
1654- fprintf (stderr, " error: cache option expects at least 3 comma-separated values (threshold,start,end)\n " );
1655- return false ;
1656- }
1657- if (values[0 ] < 0 .0f ) {
1717+ if (cache_params.reuse_threshold < 0 .0f ) {
16581718 fprintf (stderr, " error: cache threshold must be non-negative\n " );
16591719 return false ;
16601720 }
1661- if (values[1 ] < 0 .0f || values[1 ] >= 1 .0f || values[2 ] <= 0 .0f || values[2 ] > 1 .0f || values[1 ] >= values[2 ]) {
1721+ if (cache_params.start_percent < 0 .0f || cache_params.start_percent >= 1 .0f ||
1722+ cache_params.end_percent <= 0 .0f || cache_params.end_percent > 1 .0f ||
1723+ cache_params.start_percent >= cache_params.end_percent ) {
16621724 fprintf (stderr, " error: cache start/end percents must satisfy 0.0 <= start < end <= 1.0\n " );
16631725 return false ;
16641726 }
16651727 }
1666-
1667- if (cache_mode == " easycache" || cache_mode == " ucache" ) {
1668- cache_params.reuse_threshold = values[0 ];
1669- cache_params.start_percent = values[1 ];
1670- cache_params.end_percent = values[2 ];
1671- cache_params.error_decay_rate = values[3 ];
1672- cache_params.use_relative_threshold = (values[4 ] != 0 .0f );
1673- if (cache_mode == " easycache" ) {
1674- cache_params.mode = SD_CACHE_EASYCACHE;
1675- } else {
1676- cache_params.mode = SD_CACHE_UCACHE;
1677- }
1678- } else {
1679- cache_params.Fn_compute_blocks = (idx >= 1 ) ? static_cast <int >(values[0 ]) : 8 ;
1680- cache_params.Bn_compute_blocks = (idx >= 2 ) ? static_cast <int >(values[1 ]) : 0 ;
1681- cache_params.residual_diff_threshold = (idx >= 3 ) ? values[2 ] : 0 .08f ;
1682- cache_params.max_warmup_steps = (idx >= 4 ) ? static_cast <int >(values[3 ]) : 8 ;
1683- if (cache_mode == " dbcache" ) {
1684- cache_params.mode = SD_CACHE_DBCACHE;
1685- } else if (cache_mode == " taylorseer" ) {
1686- cache_params.mode = SD_CACHE_TAYLORSEER;
1687- } else {
1688- cache_params.mode = SD_CACHE_CACHE_DIT;
1689- }
1690- }
16911728 }
16921729
16931730 if (cache_params.mode == SD_CACHE_DBCACHE ||
0 commit comments