Skip to content

Commit b9a35f5

Browse files
UniPCMultistepScheduler for use_flow_sigmas=True & use_karras_sigmas=True
1 parent abba01c commit b9a35f5

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

tests/schedulers/test_scheduler_unipc.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,3 +399,32 @@ def test_beta_sigmas(self):
399399

400400
def test_exponential_sigmas(self):
401401
self.check_over_configs(use_exponential_sigmas=True)
402+
403+
def test_flow_and_karras_sigmas(self):
404+
self.check_over_configs(use_flow_sigmas=True, use_karras_sigmas=True)
405+
406+
def test_flow_and_karras_sigmas_values(self):
407+
num_train_timesteps = 1000
408+
num_inference_steps = 5
409+
scheduler = UniPCMultistepScheduler(
410+
sigma_min=0.01,
411+
sigma_max=200.0,
412+
use_flow_sigmas=True,
413+
use_karras_sigmas=True,
414+
num_train_timesteps=num_train_timesteps,
415+
)
416+
scheduler.set_timesteps(num_inference_steps=num_inference_steps)
417+
418+
expected_sigmas = [
419+
0.9950248599052429,
420+
0.9787454605102539,
421+
0.8774884343147278,
422+
0.3604971766471863,
423+
0.009900986216962337,
424+
0.0, # 0 appended as default
425+
]
426+
expected_sigmas = torch.tensor(expected_sigmas)
427+
expected_timesteps = (expected_sigmas * num_train_timesteps).to(torch.int64)
428+
expected_timesteps = expected_timesteps[0:-1]
429+
self.assertTrue(torch.allclose(scheduler.sigmas, expected_sigmas))
430+
self.assertTrue(torch.all(expected_timesteps == scheduler.timesteps))

0 commit comments

Comments
 (0)