diff --git a/docs/user_guide/examples/tutorial_Argofloats.ipynb b/docs/user_guide/examples/tutorial_Argofloats.ipynb
index 3ca0841ab..94cb9808c 100644
--- a/docs/user_guide/examples/tutorial_Argofloats.ipynb
+++ b/docs/user_guide/examples/tutorial_Argofloats.ipynb
@@ -26,7 +26,7 @@
    "source": [
     "import numpy as np\n",
     "\n",
-    "# Define the new Kernels that mimic Argo vertical movement\n",
-    "driftdepth = 1000  # maximum depth in m\n",
+    "# Define the new Kernel that mimics Argo vertical movement\n",
+    "driftdepth = 1000  # drift depth in m\n",
     "maxdepth = 2000  # maximum depth in m\n",
     "vertical_speed = 0.10  # sink and rise speed in m/s\n",
@@ -34,62 +34,54 @@
     "drifttime = 9 * 86400  # time of deep drift in seconds\n",
     "\n",
     "\n",
-    "def ArgoPhase1(particles, fieldset):\n",
-    "    def SinkingPhase(p):\n",
-    "        \"\"\"Phase 0: Sinking with vertical_speed until depth is driftdepth\"\"\"\n",
-    "        p.dz += vertical_speed * particles.dt\n",
-    "        p.cycle_phase = np.where(p.z + p.dz >= driftdepth, 1, p.cycle_phase)\n",
-    "        p.dz = np.where(p.z + p.dz >= driftdepth, driftdepth - p.z, p.dz)\n",
+    "def ArgoVerticalMovement(particles, fieldset):\n",
+    "    # Split particles based on their current cycle_phase\n",
+    "    ptcls0 = particles[particles.cycle_phase == 0]\n",
+    "    ptcls1 = particles[particles.cycle_phase == 1]\n",
+    "    ptcls2 = particles[particles.cycle_phase == 2]\n",
+    "    ptcls3 = particles[particles.cycle_phase == 3]\n",
+    "    ptcls4 = particles[particles.cycle_phase == 4]\n",
+    "\n",
+    "    # Phase 0: Sinking with vertical_speed until depth is driftdepth\n",
+    "    ptcls0.dz += vertical_speed * ptcls0.dt\n",
+    "    ptcls0.cycle_phase = np.where(\n",
+    "        ptcls0.z + ptcls0.dz >= driftdepth, 1, ptcls0.cycle_phase\n",
+    "    )\n",
+    "    ptcls0.dz = np.where(\n",
+    "        ptcls0.z + ptcls0.dz >= driftdepth, driftdepth - ptcls0.z, ptcls0.dz\n",
+    "    )\n",
+    "\n",
+    "    # Phase 1: Drifting at depth for drifttime seconds\n",
+    "    ptcls1.drift_age += ptcls1.dt\n",
+    "    ptcls1.cycle_phase = np.where(ptcls1.drift_age >= drifttime, 2, ptcls1.cycle_phase)\n",
+    "    ptcls1.drift_age = np.where(ptcls1.drift_age >= drifttime, 0, ptcls1.drift_age)\n",
+    "\n",
+    "    # Phase 2: Sinking further to maxdepth\n",
+    "    ptcls2.dz += vertical_speed * ptcls2.dt\n",
+    "    ptcls2.cycle_phase = np.where(\n",
+    "        ptcls2.z + ptcls2.dz >= maxdepth, 3, ptcls2.cycle_phase\n",
+    "    )\n",
+    "    ptcls2.dz = np.where(\n",
+    "        ptcls2.z + ptcls2.dz >= maxdepth, maxdepth - ptcls2.z, ptcls2.dz\n",
+    "    )\n",
+    "\n",
+    "    # Phase 3: Rising with vertical_speed until at surface\n",
+    "    ptcls3.dz -= vertical_speed * ptcls3.dt\n",
+    "    ptcls3.temp = fieldset.thetao[ptcls3.time, ptcls3.z, ptcls3.lat, ptcls3.lon]\n",
+    "    ptcls3.cycle_phase = np.where(\n",
+    "        ptcls3.z + ptcls3.dz <= fieldset.mindepth, 4, ptcls3.cycle_phase\n",
+    "    )\n",
+    "    ptcls3.dz = np.where(\n",
+    "        ptcls3.z + ptcls3.dz <= fieldset.mindepth,\n",
+    "        fieldset.mindepth - ptcls3.z,\n",
+    "        ptcls3.dz,\n",
+    "    )\n",
+    "\n",
+    "    # Phase 4: Transmitting at surface until cycletime is reached\n",
+    "    ptcls4.cycle_phase = np.where(ptcls4.cycle_age >= cycletime, 0, ptcls4.cycle_phase)\n",
+    "    ptcls4.cycle_age = np.where(ptcls4.cycle_age >= cycletime, 0, ptcls4.cycle_age)\n",
+    "    ptcls4.temp = np.nan  # no temperature measurement when at surface\n",
     "\n",
-    "    SinkingPhase(particles[particles.cycle_phase == 0])\n",
-    "\n",
-    "\n",
-    "def ArgoPhase2(particles, fieldset):\n",
-    "    def DriftingPhase(p):\n",
-    "        \"\"\"Phase 1: Drifting at depth for drifttime seconds\"\"\"\n",
-    "        p.drift_age += particles.dt\n",
-    "        p.cycle_phase = np.where(p.drift_age >= drifttime, 2, p.cycle_phase)\n",
-    "        p.drift_age = np.where(p.drift_age >= drifttime, 0, p.drift_age)\n",
-    "\n",
-    "    DriftingPhase(particles[particles.cycle_phase == 1])\n",
-    "\n",
-    "\n",
-    "def ArgoPhase3(particles, fieldset):\n",
-    "    def SecondSinkingPhase(p):\n",
-    "        \"\"\"Phase 2: Sinking further to maxdepth\"\"\"\n",
-    "        p.dz += vertical_speed * particles.dt\n",
-    "        p.cycle_phase = np.where(p.z + p.dz >= maxdepth, 3, p.cycle_phase)\n",
-    "        p.dz = np.where(p.z + p.dz >= maxdepth, maxdepth - p.z, p.dz)\n",
-    "\n",
-    "    SecondSinkingPhase(particles[particles.cycle_phase == 2])\n",
-    "\n",
-    "\n",
-    "def ArgoPhase4(particles, fieldset):\n",
-    "    def RisingPhase(p):\n",
-    "        \"\"\"Phase 3: Rising with vertical_speed until at surface\"\"\"\n",
-    "        p.dz -= vertical_speed * particles.dt\n",
-    "        p.temp = fieldset.thetao[p.time, p.z, p.lat, p.lon]\n",
-    "        p.cycle_phase = np.where(p.z + p.dz <= fieldset.mindepth, 4, p.cycle_phase)\n",
-    "        p.dz = np.where(\n",
-    "            p.z + p.dz <= fieldset.mindepth,\n",
-    "            fieldset.mindepth - p.z,\n",
-    "            p.dz,\n",
-    "        )\n",
-    "\n",
-    "    RisingPhase(particles[particles.cycle_phase == 3])\n",
-    "\n",
-    "\n",
-    "def ArgoPhase5(particles, fieldset):\n",
-    "    def TransmittingPhase(p):\n",
-    "        \"\"\"Phase 4: Transmitting at surface until cycletime is reached\"\"\"\n",
-    "        p.cycle_phase = np.where(p.cycle_age >= cycletime, 0, p.cycle_phase)\n",
-    "        p.cycle_age = np.where(p.cycle_age >= cycletime, 0, p.cycle_age)\n",
-    "        p.temp = np.nan  # no temperature measurement when at surface\n",
-    "\n",
-    "    TransmittingPhase(particles[particles.cycle_phase == 4])\n",
-    "\n",
-    "\n",
-    "def ArgoPhase6(particles, fieldset):\n",
     "    particles.cycle_age += particles.dt  # update cycle_age"
   ]
  },
@@ -136,9 +128,7 @@
    "ArgoParticle = parcels.Particle.add_variable(\n",
    "    [\n",
    "        parcels.Variable(\"cycle_phase\", dtype=np.int32, initial=0.0),\n",
-    "        parcels.Variable(\n",
-    "            \"cycle_age\", dtype=np.float32, initial=0.0\n",
-    "        ),  # TODO update to \"timedelta64[s]\"\n",
+    "        parcels.Variable(\"cycle_age\", dtype=np.float32, initial=0.0),\n",
    "        parcels.Variable(\"drift_age\", dtype=np.float32, initial=0.0),\n",
    "        parcels.Variable(\"temp\", dtype=np.float32, initial=np.nan),\n",
    "    ]\n",
@@ -155,12 +145,7 @@
    "\n",
    "# combine Argo vertical movement kernel with built-in Advection kernel\n",
    "kernels = [\n",
-    "    ArgoPhase1,\n",
-    "    ArgoPhase2,\n",
-    "    ArgoPhase3,\n",
-    "    ArgoPhase4,\n",
-    "    ArgoPhase5,\n",
-    "    ArgoPhase6,\n",
+    "    ArgoVerticalMovement,\n",
    "    parcels.kernels.AdvectionRK4,\n",
    "]\n",
    "\n",
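The rewritten kernel relies on `np.where` to apply each phase transition only to the particles that satisfy the condition, instead of dispatching to one kernel per phase. A minimal standalone sketch of that update pattern, using made-up depths and a hypothetical 60 s timestep rather than the tutorial's values:

```python
import numpy as np

# Toy stand-ins for the particle arrays updated by ArgoVerticalMovement
driftdepth = 1000.0
vertical_speed = 0.10
dt = 60.0  # hypothetical timestep in seconds
z = np.array([995.0, 990.0, 400.0])  # particle depths in m
dz = np.zeros(3)  # pending vertical displacement
cycle_phase = np.zeros(3, dtype=np.int32)

# Phase 0 update: all phase-0 particles sink, and only those that would pass
# driftdepth switch phase and get their displacement clamped
dz += vertical_speed * dt
reached = z + dz >= driftdepth
cycle_phase = np.where(reached, 1, cycle_phase)
dz = np.where(reached, driftdepth - z, dz)

print(cycle_phase)  # [1 0 0] -> only the first particle reaches driftdepth
print(dz)           # [5. 6. 6.] -> its sink step is clamped to 5 m
```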
" p.drift_age = np.where(p.drift_age >= drifttime, 0, p.drift_age)\n", - "\n", - " DriftingPhase(particles[particles.cycle_phase == 1])\n", - "\n", - "\n", - "def ArgoPhase3(particles, fieldset):\n", - " def SecondSinkingPhase(p):\n", - " \"\"\"Phase 2: Sinking further to maxdepth\"\"\"\n", - " p.dz += vertical_speed * particles.dt\n", - " p.cycle_phase = np.where(p.z + p.dz >= maxdepth, 3, p.cycle_phase)\n", - " p.dz = np.where(p.z + p.dz >= maxdepth, maxdepth - p.z, p.dz)\n", - "\n", - " SecondSinkingPhase(particles[particles.cycle_phase == 2])\n", - "\n", - "\n", - "def ArgoPhase4(particles, fieldset):\n", - " def RisingPhase(p):\n", - " \"\"\"Phase 3: Rising with vertical_speed until at surface\"\"\"\n", - " p.dz -= vertical_speed * particles.dt\n", - " p.temp = fieldset.thetao[p.time, p.z, p.lat, p.lon]\n", - " p.cycle_phase = np.where(p.z + p.dz <= fieldset.mindepth, 4, p.cycle_phase)\n", - " p.dz = np.where(\n", - " p.z + p.dz <= fieldset.mindepth,\n", - " fieldset.mindepth - p.z,\n", - " p.dz,\n", - " )\n", - "\n", - " RisingPhase(particles[particles.cycle_phase == 3])\n", - "\n", - "\n", - "def ArgoPhase5(particles, fieldset):\n", - " def TransmittingPhase(p):\n", - " \"\"\"Phase 4: Transmitting at surface until cycletime is reached\"\"\"\n", - " p.cycle_phase = np.where(p.cycle_age >= cycletime, 0, p.cycle_phase)\n", - " p.cycle_age = np.where(p.cycle_age >= cycletime, 0, p.cycle_age)\n", - " p.temp = np.nan # no temperature measurement when at surface\n", - "\n", - " TransmittingPhase(particles[particles.cycle_phase == 4])\n", - "\n", - "\n", - "def ArgoPhase6(particles, fieldset):\n", " particles.cycle_age += particles.dt # update cycle_age" ] }, @@ -136,9 +128,7 @@ "ArgoParticle = parcels.Particle.add_variable(\n", " [\n", " parcels.Variable(\"cycle_phase\", dtype=np.int32, initial=0.0),\n", - " parcels.Variable(\n", - " \"cycle_age\", dtype=np.float32, initial=0.0\n", - " ), # TODO update to \"timedelta64[s]\"\n", + " parcels.Variable(\"cycle_age\", dtype=np.float32, initial=0.0),\n", " parcels.Variable(\"drift_age\", dtype=np.float32, initial=0.0),\n", " parcels.Variable(\"temp\", dtype=np.float32, initial=np.nan),\n", " ]\n", @@ -155,12 +145,7 @@ "\n", "# combine Argo vertical movement kernel with built-in Advection kernel\n", "kernels = [\n", - " ArgoPhase1,\n", - " ArgoPhase2,\n", - " ArgoPhase3,\n", - " ArgoPhase4,\n", - " ArgoPhase5,\n", - " ArgoPhase6,\n", + " ArgoVerticalMovement,\n", " parcels.kernels.AdvectionRK4,\n", "]\n", "\n", diff --git a/docs/user_guide/examples/tutorial_interaction.ipynb b/docs/user_guide/examples/tutorial_interaction.ipynb index 47f10e9d5..09d8079e7 100644 --- a/docs/user_guide/examples/tutorial_interaction.ipynb +++ b/docs/user_guide/examples/tutorial_interaction.ipynb @@ -293,18 +293,9 @@ " larger_idx = np.where(mass_j > mass_i, pair_j, pair_i)\n", " smaller_idx = np.where(mass_j > mass_i, pair_i, pair_j)\n", "\n", - " # perform transfer and mark deletions\n", - " # TODO note that we use temporary arrays for indexing because of KernelParticle bug (GH #2143)\n", - " masses = particles.mass\n", - " states = particles.state\n", - "\n", " # transfer mass from smaller to larger and mark smaller for deletion\n", - " masses[larger_idx] += particles.mass[smaller_idx]\n", - " states[smaller_idx] = parcels.StatusCode.Delete\n", - "\n", - " # TODO use particle variables directly after KernelParticle bug (GH #2143) is fixed\n", - " particles.mass = masses\n", - " particles.state = states" + " particles.mass[larger_idx] += 
diff --git a/src/parcels/__init__.py b/src/parcels/__init__.py
index 177eceb7b..1d1a9af75 100644
--- a/src/parcels/__init__.py
+++ b/src/parcels/__init__.py
@@ -17,7 +17,6 @@
     Variable,
     Particle,
     ParticleClass,
-    KernelParticle,  # ? remove?
 )
 from parcels._core.field import Field, VectorField
 from parcels._core.basegrid import BaseGrid
@@ -87,8 +86,6 @@
     "logger",
     "download_example_dataset",
     "list_example_datasets",
-    # (marked for potential removal)
-    "KernelParticle",
 ]
 
 _stdlib_warnings.warn(
diff --git a/src/parcels/_core/field.py b/src/parcels/_core/field.py
index bbeff928b..988636f13 100644
--- a/src/parcels/_core/field.py
+++ b/src/parcels/_core/field.py
@@ -14,7 +14,7 @@
     _unitconverters_map,
 )
 from parcels._core.index_search import GRID_SEARCH_ERROR, LEFT_OUT_OF_BOUNDS, RIGHT_OUT_OF_BOUNDS, _search_time_index
-from parcels._core.particle import KernelParticle
+from parcels._core.particlesetview import ParticleSetView
 from parcels._core.statuscodes import (
     AllParcelsErrorCodes,
     StatusCode,
@@ -35,9 +35,9 @@
 
 
 def _deal_with_errors(error, key, vector_type: VectorType):
-    if isinstance(key, KernelParticle):
+    if isinstance(key, ParticleSetView):
         key.state = AllParcelsErrorCodes[type(error)]
-    elif isinstance(key[-1], KernelParticle):
+    elif isinstance(key[-1], ParticleSetView):
         key[-1].state = AllParcelsErrorCodes[type(error)]
     else:
         raise RuntimeError(f"{error}. Error could not be handled because particles was not part of the Field Sampling.")
@@ -229,7 +229,7 @@ def eval(self, time: datetime, z, y, x, particles=None, applyConversion=True):
     def __getitem__(self, key):
         self._check_velocitysampling()
         try:
-            if isinstance(key, KernelParticle):
+            if isinstance(key, ParticleSetView):
                 return self.eval(key.time, key.z, key.lat, key.lon, key)
             else:
                 return self.eval(*key)
@@ -330,7 +330,7 @@ def eval(self, time: datetime, z, y, x, particles=None, applyConversion=True):
     def __getitem__(self, key):
         try:
-            if isinstance(key, KernelParticle):
+            if isinstance(key, ParticleSetView):
                 return self.eval(key.time, key.z, key.lat, key.lon, key)
             else:
                 return self.eval(*key)
diff --git a/src/parcels/_core/particle.py b/src/parcels/_core/particle.py
index 2587d3750..a9a187f30 100644
--- a/src/parcels/_core/particle.py
+++ b/src/parcels/_core/particle.py
@@ -11,7 +11,7 @@
 from parcels._core.utils.time import TimeInterval
 from parcels._reprs import _format_list_items_multiline
 
-__all__ = ["KernelParticle", "Particle", "ParticleClass", "Variable"]
+__all__ = ["Particle", "ParticleClass", "Variable"]
 
 _TO_WRITE_OPTIONS = [True, False, "once"]
 
@@ -116,30 +116,6 @@ def add_variable(self, variable: Variable | list[Variable]):
         return ParticleClass(variables=self.variables + variable)
 
 
-class KernelParticle:
-    """Simple class to be used in a kernel that links a particle (on the kernel level) to a particle dataset."""
-
-    def __init__(self, data, index):
-        self._data = data
-        self._index = index
-
-    def __getattr__(self, name):
-        return self._data[name][self._index]
-
-    def __setattr__(self, name, value):
-        if name in ["_data", "_index"]:
-            object.__setattr__(self, name, value)
-        else:
-            self._data[name][self._index] = value
-
-    def __getitem__(self, index):
-        self._index = index
-        return self
-
-    def __len__(self):
-        return len(self._index)
-
-
 def _assert_no_duplicate_variable_names(*, existing_vars: list[Variable], new_vars: list[Variable]):
     existing_names = {var.name for var in existing_vars}
     for var in new_vars:
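With `KernelParticle` removed from both the public namespace and `particle.py`, downstream code that type-checked against it has to follow the same move as `field.py` above. A hedged migration sketch (`handles_particles` is a hypothetical helper, not a Parcels API, and the internal import path is intentionally private):

```python
# before (no longer works):
#   from parcels import KernelParticle
#   if isinstance(key, KernelParticle): ...

# after, mirroring the isinstance checks now used in field.py:
from parcels._core.particlesetview import ParticleSetView

def handles_particles(key) -> bool:
    return isinstance(key, ParticleSetView)
```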
diff --git a/src/parcels/_core/particleset.py b/src/parcels/_core/particleset.py
index 29c91d4ad..8dea8efaa 100644
--- a/src/parcels/_core/particleset.py
+++ b/src/parcels/_core/particleset.py
@@ -11,7 +11,8 @@
 from parcels._core.converters import _convert_to_flat_array
 from parcels._core.kernel import Kernel
-from parcels._core.particle import KernelParticle, Particle, create_particle_data
+from parcels._core.particle import Particle, create_particle_data
+from parcels._core.particlesetview import ParticleSetView
 from parcels._core.statuscodes import StatusCode
 from parcels._core.utils.time import (
     TimeInterval,
@@ -166,7 +167,7 @@ def __getattr__(self, name):
 
     def __getitem__(self, index):
-        """Get a single particle by index."""
-        return KernelParticle(self._data, index=index)
+        """Get a view of the particles selected by index, mask or slice."""
+        return ParticleSetView(self._data, index=index)
 
     def __setattr__(self, name, value):
         if name in ["_data"]:
diff --git a/src/parcels/_core/particlesetview.py b/src/parcels/_core/particlesetview.py
new file mode 100644
index 000000000..c0ce88c04
--- /dev/null
+++ b/src/parcels/_core/particlesetview.py
@@ -0,0 +1,294 @@
+import numpy as np
+
+
+class ParticleSetView:
+    """View onto a subset of a ParticleSet, used inside kernels to read from and write to the parent particle data."""
+
+    def __init__(self, data, index):
+        self._data = data
+        self._index = index
+
+    def __getattr__(self, name):
+        # Return a proxy that behaves like the underlying numpy array but
+        # writes back into the parent arrays when sliced/modified. This
+        # enables constructs like `particles.dlon[mask] += vals` to update
+        # the parent arrays rather than temporary copies.
+        if name in self._data:
+            # If this ParticleSetView represents a single particle (integer
+            # index), return the underlying scalar directly to preserve
+            # user-facing semantics (e.g., `pset[0].time` should be a number).
+            if isinstance(self._index, (int, np.integer)):
+                return self._data[name][self._index]
+            if isinstance(self._index, np.ndarray) and self._index.ndim == 0:
+                return self._data[name][int(self._index)]
+            return ParticleSetViewArray(self._data, self._index, name)
+        return self._data[name][self._index]
+
+    def __setattr__(self, name, value):
+        if name in ["_data", "_index"]:
+            object.__setattr__(self, name, value)
+        else:
+            self._data[name][self._index] = value
+
+    def __getitem__(self, index):
+        # normalize single-element tuple indexing (e.g., (inds,))
+        if isinstance(index, tuple) and len(index) == 1:
+            index = index[0]
+
+        base = self._index
+        new_index = np.zeros_like(base, dtype=bool)
+
+        # Boolean mask (could be local-length or global-length)
+        if isinstance(index, (np.ndarray, list)) and np.asarray(index).dtype == bool:
+            arr = np.asarray(index)
+            if arr.size == base.size:
+                # global mask
+                new_index = arr
+            elif arr.size == int(np.sum(base)):
+                new_index[base] = arr
+            else:
+                raise ValueError(
+                    f"Boolean index has incompatible length {arr.size} for selection of size {int(np.sum(base))}"
+                )
+            return ParticleSetView(self._data, new_index)
+
+        # Integer array/list, slice or single integer relative to the local view
+        # (boolean masks were handled above). Normalize and map to global
+        # particle indices for both boolean-base and integer-base `self._index`.
+        if isinstance(index, (np.ndarray, list, slice, int)):
+            # convert list/ndarray to ndarray, keep slice/int as-is
+            idx = np.asarray(index) if isinstance(index, (np.ndarray, list)) else index
+            if base.dtype == bool:
+                particle_idxs = np.flatnonzero(base)
+                sel = particle_idxs[idx]
+            else:
+                base_arr = np.asarray(base)
+                sel = base_arr[idx]
+            new_index[sel] = True
+            return ParticleSetView(self._data, new_index)
+
+        # Fallback: try to assign directly (preserves previous behaviour for other index types)
+        try:
+            new_index[base] = index
+            return ParticleSetView(self._data, new_index)
+        except Exception as e:
+            raise TypeError(f"Unsupported index type for ParticleSetView.__getitem__: {type(index)!r}") from e
+
+    def __len__(self):
+        return len(self._index)
+
+
+def _unwrap(other):
+    """Return ndarray for ParticleSetViewArray or the value unchanged."""
+    return other.__array__() if isinstance(other, ParticleSetViewArray) else other
+
+
+def _asarray(other):
+    """Return numpy array for ParticleSetViewArray, otherwise return the argument."""
+    return np.asarray(other.__array__()) if isinstance(other, ParticleSetViewArray) else other
+
+
+class ParticleSetViewArray:
+    """Array-like proxy for a ParticleSetView that writes through to the parent arrays when mutated."""
+
+    def __init__(self, data, index, name):
+        self._data = data
+        self._index = index
+        self._name = name
+
+    def __array__(self, dtype=None):
+        arr = self._data[self._name][self._index]
+        return arr.astype(dtype) if dtype is not None else arr
+
+    def __repr__(self):
+        return repr(self.__array__())
+
+    def __len__(self):
+        return len(self.__array__())
+
+    def _to_global_index(self, subindex=None):
+        """Return a global index (boolean mask or integer indices) that
+        addresses the parent arrays. If `subindex` is provided it selects
+        within the current local view and maps back to the global index.
+        """
+        base = self._index
+        if subindex is None:
+            return base
+
+        # If subindex is a boolean array, support both local-length masks
+        # (length == base.sum()) and global-length masks (length == base.size).
+        if isinstance(subindex, (np.ndarray, list)) and np.asarray(subindex).dtype == bool:
+            arr = np.asarray(subindex)
+            if arr.size == base.size:
+                # already a global mask
+                return arr
+            if arr.size == int(np.sum(base)):
+                global_mask = np.zeros_like(base, dtype=bool)
+                global_mask[base] = arr
+                return global_mask
+            raise ValueError(
+                f"Boolean index has incompatible length {arr.size} for selection of size {int(np.sum(base))}"
+            )
+
+        # Handle tuple indexing where the first axis indexes particles
+        # and later axes index into the per-particle array shape (e.g. ei[:, igrid])
+        if isinstance(subindex, tuple):
+            first, *rest = subindex
+            # map the first index (local selection) to global particle indices
+            if base.dtype == bool:
+                particle_idxs = np.flatnonzero(base)
+                first_arr = np.asarray(first) if isinstance(first, (np.ndarray, list)) else first
+                sel = particle_idxs[first_arr]
+            else:
+                base_arr = np.asarray(base)
+                sel = base_arr[first]
+
+            # if rest contains a single int (e.g., column), return tuple index
+            if len(rest) == 1:
+                return (sel, rest[0])
+            # return full tuple (sel, ...) for higher-dim cases
+            return tuple([sel] + rest)
+
+        # If base is a boolean mask over the parent array and subindex is
+        # an integer or slice relative to the local view, map it to integer
+        # indices in the parent array.
+        if base.dtype == bool:
+            if isinstance(subindex, (slice, int)):
+                rel = np.flatnonzero(base)[subindex]
+                return rel
+            # If subindex is an integer/array selection (relative to the
+            # local view) map those to global integer indices.
+            arr = np.asarray(subindex)
+            if arr.dtype != bool:
+                particle_idxs = np.flatnonzero(base)
+                sel = particle_idxs[arr]
+                return sel
+            # Otherwise treat subindex as a boolean mask relative to the
+            # local view and expand to a global boolean mask.
+            global_mask = np.zeros_like(base, dtype=bool)
+            global_mask[base] = arr
+            return global_mask
+
+        # If base is an array of integer indices
+        base_arr = np.asarray(base)
+        try:
+            return base_arr[subindex]
+        except Exception:
+            return base_arr[np.asarray(subindex, dtype=bool)]
+
+    def __getitem__(self, subindex):
+        # Handle tuple indexing (e.g. [:, igrid]) by applying the tuple
+        # to the local selection first. This covers the common case
+        # `particles.ei[:, igrid]` where `ei` is a 2D parent array and the
+        # second index selects the grid index.
+        if isinstance(subindex, tuple):
+            local = self._data[self._name][self._index]
+            return local[subindex]
+
+        new_index = self._to_global_index(subindex)
+        return ParticleSetViewArray(self._data, new_index, self._name)
+
+    def __setitem__(self, subindex, value):
+        tgt = self._to_global_index(subindex)
+        self._data[self._name][tgt] = value
+
+    # in-place ops must write back into the parent array
+    def __iadd__(self, other):
+        vals = self._data[self._name][self._index] + _unwrap(other)
+        self._data[self._name][self._index] = vals
+        return self
+
+    def __isub__(self, other):
+        vals = self._data[self._name][self._index] - _unwrap(other)
+        self._data[self._name][self._index] = vals
+        return self
+
+    def __imul__(self, other):
+        vals = self._data[self._name][self._index] * _unwrap(other)
+        self._data[self._name][self._index] = vals
+        return self
+
+    # Provide simple numpy-like evaluation for binary ops by delegating to ndarray
+    def __add__(self, other):
+        return self.__array__() + _unwrap(other)
+
+    def __sub__(self, other):
+        return self.__array__() - _unwrap(other)
+
+    def __mul__(self, other):
+        return self.__array__() * _unwrap(other)
+
+    def __truediv__(self, other):
+        return self.__array__() / _unwrap(other)
+
+    def __floordiv__(self, other):
+        return self.__array__() // _unwrap(other)
+
+    def __pow__(self, other):
+        return self.__array__() ** _unwrap(other)
+
+    def __neg__(self):
+        return -self.__array__()
+
+    def __pos__(self):
+        return +self.__array__()
+
+    def __abs__(self):
+        return abs(self.__array__())
+
+    # Right-hand operations to handle cases like `scalar - ParticleSetViewArray`
+    def __radd__(self, other):
+        return _unwrap(other) + self.__array__()
+
+    def __rsub__(self, other):
+        return _unwrap(other) - self.__array__()
+
+    def __rmul__(self, other):
+        return _unwrap(other) * self.__array__()
+
+    def __rtruediv__(self, other):
+        return _unwrap(other) / self.__array__()
+
+    def __rfloordiv__(self, other):
+        return _unwrap(other) // self.__array__()
+
+    def __rpow__(self, other):
+        return _unwrap(other) ** self.__array__()
+
+    # Comparison operators should return plain numpy boolean arrays so that
+    # expressions like `mask = particles.gridID == gid` produce an ndarray
+    # usable for indexing (rather than another ParticleSetViewArray).
+    def __eq__(self, other):
+        left = np.asarray(self.__array__())
+        right = _asarray(other)
+        return left == right
+
+    def __ne__(self, other):
+        left = np.asarray(self.__array__())
+        right = _asarray(other)
+        return left != right
+
+    def __lt__(self, other):
+        left = np.asarray(self.__array__())
+        right = _asarray(other)
+        return left < right
+
+    def __le__(self, other):
+        left = np.asarray(self.__array__())
+        right = _asarray(other)
+        return left <= right
+
+    def __gt__(self, other):
+        left = np.asarray(self.__array__())
+        right = _asarray(other)
+        return left > right
+
+    def __ge__(self, other):
+        left = np.asarray(self.__array__())
+        right = _asarray(other)
+        return left >= right
+
+    # Allow attribute access like .dtype etc. by forwarding to the ndarray
+    def __getattr__(self, item):
+        arr = self.__array__()
+        return getattr(arr, item)
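Taken together, `ParticleSetView.__getitem__` and the scalar/proxy split in `__getattr__` mean that masks can be chained and that assignments on any view reach the parent arrays. A minimal sketch against a plain dict of numpy arrays (the same structure `_data` has; constructing the view directly like this is only for illustration, since kernels normally receive it ready-made):

```python
import numpy as np

from parcels._core.particlesetview import ParticleSetView

data = {"lon": np.array([0.1, 0.4, 0.6, 0.9]), "cycle_phase": np.array([0, 0, 1, 1])}
particles = ParticleSetView(data, index=np.ones(4, dtype=bool))

phase1 = particles[particles.cycle_phase == 1]  # global-length boolean mask
west = phase1[phase1.lon < 0.7]  # local-length mask, remapped to global indices
west.lon = -1.0  # writes through to data["lon"]

print(data["lon"])  # [ 0.1  0.4 -1.   0.9]
```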
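The `ParticleSetViewArray` proxy is what makes the kernel idiom `particles.dlon[mask] += vals` safe: slicing an attribute returns another proxy bound to remapped global indices, the augmented operators write back, and plain arithmetic or comparisons evaluate to ordinary ndarrays. A sketch under the same toy-data assumption as above:

```python
import numpy as np

from parcels._core.particlesetview import ParticleSetView

data = {"lon": np.linspace(0, 1, 5), "dlon": np.zeros(5), "temp": np.full(5, 10.0)}
particles = ParticleSetView(data, index=np.ones(5, dtype=bool))

mask = particles.lon < 0.5  # comparison yields a plain ndarray mask
particles.dlon[mask] += 0.1  # __getitem__ remaps, __iadd__ writes back

warmer = particles.temp + 1.0  # plain __add__ evaluates to an ndarray copy

print(data["dlon"])  # [0.1 0.1 0.  0.  0. ]
print(data["temp"])  # [10. 10. 10. 10. 10.] -> untouched by the plain addition
print(warmer)        # [11. 11. 11. 11. 11.]
```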
diff --git a/tests/test_particleset_execute.py b/tests/test_particleset_execute.py
index 60be3da19..c998677bb 100644
--- a/tests/test_particleset_execute.py
+++ b/tests/test_particleset_execute.py
@@ -433,13 +433,7 @@ def PythonFail(particles, fieldset):  # pragma: no cover
     [
         ("Lat1", [0, 1]),
         ("Lat2", [2, 0]),
-        pytest.param(
-            "Lat1and2",
-            [2, 1],
-            marks=pytest.mark.xfail(
-                reason="Will be fixed alongside GH #2143 . Failing due to https://github.com/OceanParcels/Parcels/pull/2199#issuecomment-3285278876."
-            ),
-        ),
+        ("Lat1and2", [2, 1]),
         ("Lat1then2", [2, 1]),
     ],
 )
diff --git a/tests/test_particlesetview.py b/tests/test_particlesetview.py
new file mode 100644
index 000000000..847878f54
--- /dev/null
+++ b/tests/test_particlesetview.py
@@ -0,0 +1,182 @@
+import numpy as np
+import pytest
+
+from parcels import Field, FieldSet, Particle, ParticleSet, Variable, VectorField, XGrid
+from parcels._core.statuscodes import StatusCode
+from parcels._datasets.structured.generic import datasets as datasets_structured
+from parcels.interpolators import XLinear
+
+
+@pytest.fixture
+def fieldset() -> FieldSet:
+    ds = datasets_structured["ds_2d_left"]
+    grid = XGrid.from_dataset(ds, mesh="flat")
+    U = Field("U", ds["U_A_grid"], grid, interp_method=XLinear)
+    V = Field("V", ds["V_A_grid"], grid, interp_method=XLinear)
+    UV = VectorField("UV", U, V)
+    return FieldSet([U, V, UV])
+
+
+def test_execution_changing_particle_mask(fieldset):
+    """Test that particle masks can change during kernel execution."""
+    npart = 10
+    initial_lons = np.linspace(0, 1, npart)
+    pset = ParticleSet(fieldset, lon=initial_lons.copy(), lat=np.zeros(npart))
+
+    def IncrementLowLon(particles, fieldset):  # pragma: no cover
+        # Increment lon for particles with lon < 0.5
+        # The mask changes as particles cross the threshold
+        particles[particles.lon < 0.5].dlon += 0.1
+
+    pset.execute(IncrementLowLon, runtime=np.timedelta64(5, "s"), dt=np.timedelta64(1, "s"))
+
+    # Particles that started below lon=0.5 should have moved,
+    # particles that started at or above lon=0.5 should not have moved
+    particles_started_low = initial_lons < 0.5
+    particles_started_high = initial_lons >= 0.5
+
+    # Low particles should have increased lon
+    assert np.all(pset.lon[particles_started_low] > initial_lons[particles_started_low])
+    # High particles should not have moved
+    assert np.allclose(pset.lon[particles_started_high], initial_lons[particles_started_high], atol=1e-6)
+
+
+def test_particle_mask_conditional_state_changes(fieldset):
+    """Test setting particle state based on a condition using particle masks."""
+    npart = 10
+    initial_lons = np.linspace(0, 1, npart)
+    pset = ParticleSet(fieldset, lon=initial_lons.copy(), lat=np.zeros(npart))
+
+    def StopFastParticles(particles, fieldset):  # pragma: no cover
+        # Stop particles that have moved beyond lon=0.5
+        particles[particles.lon > 0.5].state = StatusCode.StopExecution
+
+    def AdvanceLon(particles, fieldset):  # pragma: no cover
+        particles.dlon += 0.2
+
+    pset.execute([AdvanceLon, StopFastParticles], runtime=np.timedelta64(5, "s"), dt=np.timedelta64(1, "s"))
+
+    # All particles should have stopped once they crossed lon > 0.5
+    # Verify all final positions are > 0.5 (since they stop after crossing)
+    assert np.all(pset.lon > 0.5)
+    # Particles keep their initial ordering: those that started lower stop
+    # later, but at a lower final lon than those that started higher
+    assert pset.lon[0] < pset.lon[-1]
+
+
+def test_particle_mask_conditional_updates(fieldset):
+    """Test applying different updates to different particle subsets using masks."""
+    npart = 20
+    MyParticle = Particle.add_variable(Variable("temp", initial=10.0))
+    pset = ParticleSet(fieldset, lon=np.linspace(0, 1, npart), lat=np.zeros(npart), pclass=MyParticle)
+
+    def ConditionalHeating(particles, fieldset):  # pragma: no cover
+        # Warm particles on the left, cool particles on the right
+        particles[particles.lon < 0.5].temp += 1.0
+        particles[particles.lon >= 0.5].temp -= 0.5
+
+    pset.execute(ConditionalHeating, runtime=np.timedelta64(4, "s"), dt=np.timedelta64(1, "s"))
+
+    # After 5 timesteps (t = 0, 1, 2, 3, 4): left particles should be at 15.0, right at 7.5
+    left_particles = pset.lon < 0.5
+    right_particles = pset.lon >= 0.5
+    assert np.allclose(pset.temp[left_particles], 15.0, atol=1e-6)
+    assert np.allclose(pset.temp[right_particles], 7.5, atol=1e-6)
+
+
+def test_particle_mask_progressive_changes(fieldset):
+    """Test masks that change dynamically as particle properties change during execution."""
+    npart = 10
+    # Start all particles at lon=0, they will progressively move right
+    pset = ParticleSet(fieldset, lon=np.zeros(npart), lat=np.linspace(0, 1, npart))
+
+    def MoveAndStopAtBoundary(particles, fieldset):  # pragma: no cover
+        # Move all particles right
+        particles.dlon += 0.15
+        # Stop particles that cross lon=0.5
+        particles[particles.lon + particles.dlon > 0.5].state = StatusCode.StopExecution
+
+    pset.execute(MoveAndStopAtBoundary, runtime=np.timedelta64(10, "s"), dt=np.timedelta64(1, "s"))
+
+    # All particles should have stopped just after crossing lon=0.5, at lon=0.60
+    # After first step: all reach 0.15
+    # After second step: all reach 0.30
+    # After third step: all reach 0.45
+    # After fourth step: all would reach 0.60, so they stop
+    assert np.all(pset.lon <= 0.6)
+    assert np.all(pset.lon >= 0.45)  # At least 3 steps completed
+
+
+def test_particle_mask_multiple_sequential_operations(fieldset):
+    """Test applying multiple different mask operations in sequence within one kernel."""
+    npart = 30
+    MyParticle = Particle.add_variable([Variable("group", initial=0), Variable("counter", initial=0)])
+
+    # Divide particles into three groups by initial position
+    lons = np.linspace(0, 1, npart)
+    pset = ParticleSet(fieldset, lon=lons, lat=np.zeros(npart), pclass=MyParticle)
+
+    def MultiMaskOperations(particles, fieldset):  # pragma: no cover
+        # Classify particles into groups based on lon
+        particles[particles.lon < 0.33].group = 1
+        particles[(particles.lon >= 0.33) & (particles.lon < 0.67)].group = 2
+        particles[particles.lon >= 0.67].group = 3
+
+        # Apply different operations to each group
+        particles[particles.group == 1].counter += 1
+        particles[particles.group == 2].counter += 2
+        particles[particles.group == 3].counter += 3
+
+    pset.execute(MultiMaskOperations, runtime=np.timedelta64(5, "s"), dt=np.timedelta64(1, "s"))
+
+    # Verify groups were assigned correctly and counters incremented appropriately
+    group1 = pset.lon < 0.33
+    group2 = (pset.lon >= 0.33) & (pset.lon < 0.67)
+    group3 = pset.lon >= 0.67
+
+    assert np.allclose(pset.counter[group1], 6, atol=1e-6)  # 6 timesteps * 1
+    assert np.allclose(pset.counter[group2], 12, atol=1e-6)  # 6 timesteps * 2
+    assert np.allclose(pset.counter[group3], 18, atol=1e-6)  # 6 timesteps * 3
+
+
+def test_particle_mask_empty_mask_handling(fieldset):
+    """Test that kernels handle empty masks (no particles matching condition) correctly."""
+    npart = 10
+    MyParticle = Particle.add_variable(Variable("modified", initial=0))
+    # All particles start at lon > 0
+    pset = ParticleSet(fieldset, lon=np.linspace(0.1, 1.0, npart), lat=np.zeros(npart), pclass=MyParticle)
+
+    def ModifyNegativeLon(particles, fieldset):  # pragma: no cover
+        # This mask should be empty (no particles have lon < 0)
+        particles[particles.lon < 0].modified = 1
+        # This should affect all particles
+        particles.dlon += 0.01
+
+    # Should execute without errors even though the first mask is always empty
+    pset.execute(ModifyNegativeLon, runtime=np.timedelta64(3, "s"), dt=np.timedelta64(1, "s"))
+
+    # No particles should have been modified
+    assert np.all(pset.modified == 0)
+    # But all should have moved
+    assert np.all(pset.lon > 0.1)
+
+
+def test_particle_mask_with_delete_state(fieldset):
+    """Test using particle masks to delete particles based on conditions."""
+    npart = 20
+    pset = ParticleSet(fieldset, lon=np.linspace(0, 1, npart), lat=np.zeros(npart))
+    initial_size = pset.size
+
+    def DeleteEdgeParticles(particles, fieldset):  # pragma: no cover
+        # Delete particles at the edges
+        particles[(particles.lon < 0.2) | (particles.lon > 0.8)].state = StatusCode.Delete
+
+    def MoveLon(particles, fieldset):  # pragma: no cover
+        particles.dlon += 0.01
+
+    pset.execute([DeleteEdgeParticles, MoveLon], runtime=np.timedelta64(2, "s"), dt=np.timedelta64(1, "s"))
+
+    # Should have deleted edge particles
+    assert pset.size < initial_size
+    # Remaining particles should be in the middle range
+    assert np.all((pset.lon >= 0.2) & (pset.lon <= 0.8))
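For completeness, the mask-based kernel style these tests exercise, shown end-to-end as a rough usage sketch. It reuses the fixture's generic test dataset, so names like `ds_2d_left` and `U_A_grid` come from the test suite rather than user-facing docs, and `NudgeWesternParticles` is a made-up kernel:

```python
import numpy as np

from parcels import Field, FieldSet, ParticleSet, VectorField, XGrid
from parcels._datasets.structured.generic import datasets as datasets_structured
from parcels.interpolators import XLinear

ds = datasets_structured["ds_2d_left"]
grid = XGrid.from_dataset(ds, mesh="flat")
U = Field("U", ds["U_A_grid"], grid, interp_method=XLinear)
V = Field("V", ds["V_A_grid"], grid, interp_method=XLinear)
fieldset = FieldSet([U, V, VectorField("UV", U, V)])

pset = ParticleSet(fieldset, lon=np.linspace(0, 1, 10), lat=np.zeros(10))

def NudgeWesternParticles(particles, fieldset):
    # one vectorised, masked update instead of a per-particle if-statement
    particles[particles.lon < 0.5].dlon += 0.1

pset.execute(NudgeWesternParticles, runtime=np.timedelta64(5, "s"), dt=np.timedelta64(1, "s"))
print(pset.lon)
```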