diff --git a/src/psfmachine/utils.py b/src/psfmachine/utils.py index 38d2c5a..0901b96 100644 --- a/src/psfmachine/utils.py +++ b/src/psfmachine/utils.py @@ -183,8 +183,10 @@ def _make_A_polar(phi, r, cut_r=6, rmin=1, rmax=18, n_r_knots=12, n_phi_knots=15 return X1 -def spline1d(x, knots, degree=3): +def spline1d(x, knots, degree=3, include_knots=False): """Make a spline in a 1D variable `x`""" + if include_knots: + x = np.hstack([knots.min(), x, knots.max()]) X = sparse.csr_matrix( np.asarray( dmatrix( @@ -193,6 +195,9 @@ def spline1d(x, knots, degree=3): ) ) ) + if include_knots: + X = X[1:-1] + x = x[1:-1] if not X.shape[0] == x.shape[0]: raise ValueError("`patsy` has made the wrong matrix.") X = X[:, np.asarray(X.sum(axis=0) != 0)[0]] @@ -200,14 +205,22 @@ def spline1d(x, knots, degree=3): def _make_A_cartesian(x, y, n_knots=10, radius=3.0, knot_spacing_type="sqrt", degree=3): + # Must be odd + n_odd_knots = n_knots if n_knots % 2 == 1 else n_knots + 1 if knot_spacing_type == "sqrt": - knots = np.linspace(-np.sqrt(radius), np.sqrt(radius), n_knots) - knots = np.sign(knots) * knots ** 2 + x_knots = np.linspace(-np.sqrt(radius), np.sqrt(radius), n_odd_knots) + x_knots = np.sign(x_knots) * x_knots ** 2 + y_knots = np.linspace(-np.sqrt(radius), np.sqrt(radius), n_odd_knots) + y_knots = np.sign(y_knots) * y_knots ** 2 else: - knots = np.linspace(-radius, radius, n_knots) - x_spline = spline1d(x, knots=knots, degree=degree) - y_spline = spline1d(y, knots=knots, degree=degree) + x_knots = np.linspace(-radius, radius, n_odd_knots) + y_knots = np.linspace(-radius, radius, n_odd_knots) + x_spline = spline1d(x, knots=x_knots, degree=degree, include_knots=True) + y_spline = spline1d(y, knots=y_knots, degree=degree, include_knots=True) + + x_spline = x_spline[:, np.asarray(x_spline.sum(axis=0))[0] != 0] + y_spline = y_spline[:, np.asarray(y_spline.sum(axis=0))[0] != 0] X = sparse.hstack( [x_spline.multiply(y_spline[:, idx]) for idx in range(y_spline.shape[1])], format="csr", diff --git a/tests/test_perturbation.py b/tests/test_perturbation.py index 532fcea..db1d95d 100644 --- a/tests/test_perturbation.py +++ b/tests/test_perturbation.py @@ -93,7 +93,7 @@ def test_perturbation_matrix3d(): p3 = PerturbationMatrix3D( time=time, dx=dx, dy=dy, nknots=4, radius=5, resolution=5, poly_order=1 ) - assert p3.cartesian_matrix.shape == (169, 64) + assert p3.cartesian_matrix.shape == (169, 81) assert p3.vectors.shape == (10, 2) assert p3.shape == ( p3.cartesian_matrix.shape[0] * p3.ntime,