diff --git a/monai/transforms/croppad/functional.py b/monai/transforms/croppad/functional.py index 653db43bc5..8b15a585c6 100644 --- a/monai/transforms/croppad/functional.py +++ b/monai/transforms/croppad/functional.py @@ -96,12 +96,8 @@ def pad_nd( return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs) try: _pad = _np_pad - if mode in {"constant", "reflect", "edge", "replicate", "wrap", "circular"} and img.dtype not in { - torch.int16, - torch.int64, - torch.bool, - torch.uint8, - }: + if mode in {"constant", "reflect", "edge", "replicate", "wrap", "circular"}: + # Try PyTorch pad for these modes; fallback to NumPy on error. _pad = _pt_pad return _pad(img, pad_width=to_pad, mode=mode, **kwargs) except (ValueError, TypeError, RuntimeError) as err: diff --git a/tests/transforms/croppad/test_pad_nd_dtypes.py b/tests/transforms/croppad/test_pad_nd_dtypes.py new file mode 100644 index 0000000000..b619745cc3 --- /dev/null +++ b/tests/transforms/croppad/test_pad_nd_dtypes.py @@ -0,0 +1,58 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +from unittest.mock import Mock, patch + +import pytest +import torch + +import monai.transforms.croppad.functional as F +from monai.transforms.croppad.functional import pad_nd + + +def test_pad_uses_pt_for_bool(): + img = torch.ones((1, 4, 4), dtype=torch.bool) + to_pad = [(0, 0), (1, 1), (2, 2)] + with patch.object(F, "_pt_pad", wraps=F._pt_pad) as mock_pt, patch.object(F, "_np_pad", wraps=F._np_pad) as mock_np: + out = pad_nd(img, to_pad, mode="constant", value=0) + + assert mock_pt.called + assert not mock_np.called + assert out.dtype == img.dtype + + +def test_pad_falls_back_to_np_if_pt_raises(): + img = torch.ones((1, 4, 4), dtype=torch.bool) + to_pad = [(0, 0), (1, 1), (2, 2)] + with ( + patch.object(F, "_pt_pad", new=Mock(side_effect=NotImplementedError("no"))) as mock_pt, + patch.object(F, "_np_pad", wraps=F._np_pad) as mock_np, + ): + out = pad_nd(img, to_pad, mode="constant", value=0) + + assert mock_pt.called + assert mock_np.called + assert out.dtype == img.dtype + + +@pytest.mark.parametrize( + "dtype", [torch.bool, torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8, torch.float32] +) +def test_pad_dtype_no_error_and_dtype_preserved(dtype): + img = torch.ones((1, 4, 4), dtype=dtype) + to_pad = [(0, 0), (1, 1), (2, 2)] + out = pad_nd(img, to_pad, mode="constant", value=0) + + assert out.shape == (1, 6, 8) + assert out.dtype == img.dtype