@@ -356,24 +356,25 @@ def test_custom_dist_default_support_point(self, dist_params, size, expected, di
356356 assert_support_point_is_expected (model , expected )
357357
358358 def test_custom_dist_default_support_point_scan (self ):
359- def scan_step (left , right ):
360- x = Uniform .dist (left , right )
361- x_update = collect_default_updates ([x ])
362- return x , x_update
359+ def scan_step (left , right , rng ):
360+ x = Uniform .dist (left , right , rng = rng )
361+ x_update = collect_default_updates ([x ], must_be_shared = False )
362+ return x , x_update [ rng ]
363363
364364 def dist (size ):
365- with pytest .warns (DeprecationWarning , match = "Scan return signature will change" ):
366- xs , updates = scan (
367- fn = scan_step ,
368- sequences = [
369- pt .as_tensor_variable (np .array ([- 4 , - 3 ])),
370- pt .as_tensor_variable (np .array ([- 2 , - 1 ])),
371- ],
372- name = "xs" ,
373- # There's a bug in the ordering of outputs when there's a mapped `None` output
374- # We have to stick with the deprecated API for now
375- return_updates = True ,
376- )
365+ rng = pytensor .shared (np .random .default_rng ())
366+ xs , next_rng = scan (
367+ fn = scan_step ,
368+ sequences = [
369+ pt .as_tensor_variable (np .array ([- 4 , - 3 ])),
370+ pt .as_tensor_variable (np .array ([- 2 , - 1 ])),
371+ ],
372+ outputs_info = [None , rng ],
373+ name = "xs" ,
374+ # There's a bug in the ordering of outputs when there's a mapped `None` output
375+ # We have to stick with the deprecated API for now
376+ return_updates = False ,
377+ )
377378 return xs
378379
379380 with Model () as model :
@@ -674,22 +675,21 @@ def test_chained_custom_dist_bug(self):
674675 batch = 2
675676
676677 def scan_dist (seq , n_steps , size ):
677- def step (s ):
678- innov = Normal .dist ()
678+ rng = pytensor .shared (np .random .default_rng ())
679+
680+ def step (s , rng ):
681+ next_rng , innov = Normal .dist (rng = rng ).owner .outputs
679682 traffic = s + innov
680- return traffic , {innov .owner .inputs [0 ]: innov .owner .outputs [0 ]}
681-
682- with pytest .warns (DeprecationWarning , match = "Scan return signature will change" ):
683- rv_seq , _ = pytensor .scan (
684- fn = step ,
685- sequences = [seq ],
686- outputs_info = [None ],
687- n_steps = n_steps ,
688- strict = True ,
689- # There's a bug in the ordering of outputs when there's a mapped `None` output
690- # We have to stick with the deprecated API for now
691- return_updates = True ,
692- )
683+ return traffic , next_rng
684+
685+ rv_seq , _next_rng = pytensor .scan (
686+ fn = step ,
687+ sequences = [seq ],
688+ outputs_info = [None , rng ],
689+ n_steps = n_steps ,
690+ strict = True ,
691+ return_updates = False ,
692+ )
693693 return rv_seq
694694
695695 def normal_shifted (mu , size ):
0 commit comments