Skip to content

Commit a375fd1

Browse files
committed
Bump minor PyTensor dependency
1 parent 4ba1d2f commit a375fd1

File tree

11 files changed

+91
-95
lines changed

11 files changed

+91
-95
lines changed

conda-envs/environment-alternative-backends.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ dependencies:
2222
- numpyro>=0.8.0
2323
- pandas>=0.24.0
2424
- pip
25-
- pytensor>=2.36.0,<2.37
25+
- pytensor>=2.36.1,<2.37
2626
- python-graphviz
2727
- networkx
2828
- rich>=13.7.1

conda-envs/environment-dev.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ dependencies:
1212
- numpy>=1.25.0
1313
- pandas>=0.24.0
1414
- pip
15-
- pytensor>=2.36.0,<2.37
15+
- pytensor>=2.36.1,<2.37
1616
- python-graphviz
1717
- networkx
1818
- scipy>=1.4.1

conda-envs/environment-docs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ dependencies:
1111
- numpy>=1.25.0
1212
- pandas>=0.24.0
1313
- pip
14-
- pytensor>=2.36.0,<2.37
14+
- pytensor>=2.36.1,<2.37
1515
- python-graphviz
1616
- rich>=13.7.1
1717
- scipy>=1.4.1

conda-envs/environment-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ dependencies:
1414
- pandas>=0.24.0
1515
- pip
1616
- polyagamma
17-
- pytensor>=2.36.0,<2.37
17+
- pytensor>=2.36.1,<2.37
1818
- python-graphviz
1919
- networkx
2020
- rich>=13.7.1

conda-envs/windows-environment-dev.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ dependencies:
1212
- numpy>=1.25.0
1313
- pandas>=0.24.0
1414
- pip
15-
- pytensor>=2.36.0,<2.37
15+
- pytensor>=2.36.1,<2.37
1616
- python-graphviz
1717
- networkx
1818
- rich>=13.7.1

conda-envs/windows-environment-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ dependencies:
1515
- pandas>=0.24.0
1616
- pip
1717
- polyagamma
18-
- pytensor>=2.36.0,<2.37
18+
- pytensor>=2.36.1,<2.37
1919
- python-graphviz
2020
- networkx
2121
- rich>=13.7.1

pymc/distributions/multivariate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2264,11 +2264,11 @@ class CAR(Continuous):
22642264
def dist(cls, mu, W, alpha, tau, *args, **kwargs):
22652265
# This variable has an expensive validation check, that we want to constant-fold if possible
22662266
# So it's passed as an explicit input
2267-
from pytensor.sparse import as_sparse_or_tensor_variable, structured_sign
2267+
from pytensor.sparse import as_sparse_or_tensor_variable, sign
22682268

22692269
W = as_sparse_or_tensor_variable(W)
22702270
if isinstance(W.type, pytensor.sparse.SparseTensorType):
2271-
abs_diff = structured_sign(W - W.T) * (W - W.T)
2271+
abs_diff = sign(W - W.T) * (W - W.T)
22722272
W_is_valid = pt.isclose(abs_diff.sum(), 0)
22732273
else:
22742274
W_is_valid = pt.allclose(W, W.T)

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ pandas>=0.24.0
1616
polyagamma
1717
pre-commit>=2.8.0
1818
pymc-sphinx-theme>=0.16.0
19-
pytensor>=2.36.0,<2.37
19+
pytensor>=2.36.1,<2.37
2020
pytest-cov>=2.5
2121
pytest>=3.0
2222
rich>=13.7.1

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ cachetools>=4.2.1
33
cloudpickle
44
numpy>=1.25.0
55
pandas>=0.24.0
6-
pytensor>=2.36.0,<2.37
6+
pytensor>=2.36.1,<2.37
77
rich>=13.7.1
88
scipy>=1.4.1
99
threadpoolctl>=3.1.0,<4.0.0

tests/distributions/test_custom.py

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -356,24 +356,25 @@ def test_custom_dist_default_support_point(self, dist_params, size, expected, di
356356
assert_support_point_is_expected(model, expected)
357357

358358
def test_custom_dist_default_support_point_scan(self):
359-
def scan_step(left, right):
360-
x = Uniform.dist(left, right)
361-
x_update = collect_default_updates([x])
362-
return x, x_update
359+
def scan_step(left, right, rng):
360+
x = Uniform.dist(left, right, rng=rng)
361+
x_update = collect_default_updates([x], must_be_shared=False)
362+
return x, x_update[rng]
363363

364364
def dist(size):
365-
with pytest.warns(DeprecationWarning, match="Scan return signature will change"):
366-
xs, updates = scan(
367-
fn=scan_step,
368-
sequences=[
369-
pt.as_tensor_variable(np.array([-4, -3])),
370-
pt.as_tensor_variable(np.array([-2, -1])),
371-
],
372-
name="xs",
373-
# There's a bug in the ordering of outputs when there's a mapped `None` output
374-
# We have to stick with the deprecated API for now
375-
return_updates=True,
376-
)
365+
rng = pytensor.shared(np.random.default_rng())
366+
xs, next_rng = scan(
367+
fn=scan_step,
368+
sequences=[
369+
pt.as_tensor_variable(np.array([-4, -3])),
370+
pt.as_tensor_variable(np.array([-2, -1])),
371+
],
372+
outputs_info=[None, rng],
373+
name="xs",
374+
# There's a bug in the ordering of outputs when there's a mapped `None` output
375+
# We have to stick with the deprecated API for now
376+
return_updates=False,
377+
)
377378
return xs
378379

379380
with Model() as model:
@@ -674,22 +675,21 @@ def test_chained_custom_dist_bug(self):
674675
batch = 2
675676

676677
def scan_dist(seq, n_steps, size):
677-
def step(s):
678-
innov = Normal.dist()
678+
rng = pytensor.shared(np.random.default_rng())
679+
680+
def step(s, rng):
681+
next_rng, innov = Normal.dist(rng=rng).owner.outputs
679682
traffic = s + innov
680-
return traffic, {innov.owner.inputs[0]: innov.owner.outputs[0]}
681-
682-
with pytest.warns(DeprecationWarning, match="Scan return signature will change"):
683-
rv_seq, _ = pytensor.scan(
684-
fn=step,
685-
sequences=[seq],
686-
outputs_info=[None],
687-
n_steps=n_steps,
688-
strict=True,
689-
# There's a bug in the ordering of outputs when there's a mapped `None` output
690-
# We have to stick with the deprecated API for now
691-
return_updates=True,
692-
)
683+
return traffic, next_rng
684+
685+
rv_seq, _next_rng = pytensor.scan(
686+
fn=step,
687+
sequences=[seq],
688+
outputs_info=[None, rng],
689+
n_steps=n_steps,
690+
strict=True,
691+
return_updates=False,
692+
)
693693
return rv_seq
694694

695695
def normal_shifted(mu, size):

0 commit comments

Comments
 (0)