diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py index 2ae03bc8dc..dca3ec30e2 100644 --- a/monai/losses/perceptual.py +++ b/monai/losses/perceptual.py @@ -23,7 +23,7 @@ torchvision, _ = optional_import("torchvision") -class PercetualNetworkType(StrEnum): +class PerceptualNetworkType(StrEnum): alex = "alex" vgg = "vgg" squeeze = "squeeze" @@ -49,9 +49,15 @@ class PerceptualLoss(nn.Module): Args: spatial_dims: number of spatial dimensions. - network_type: {``"alex"``, ``"vgg"``, ``"squeeze"``, ``"radimagenet_resnet50"``, - ``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``, ``"resnet50"``} - Specifies the network architecture to use. Defaults to ``"alex"``. + network_type: str | PerceptualNetworkType = PerceptualNetworkType.alex, + One of: + - "alex" + - "vgg" + - "squeeze" + - "radimagenet_resnet50" + - "medicalnet_resnet10_23datasets" + - "medicalnet_resnet50_23datasets" + - "resnet50" is_fake_3d: if True use 2.5D approach for a 3D perceptual loss. fake_3d_ratio: ratio of how many slices per axis are used in the 2.5D approach. cache_dir: path to cache directory to save the pretrained network weights. @@ -70,7 +76,7 @@ class PerceptualLoss(nn.Module): def __init__( self, spatial_dims: int, - network_type: str = PercetualNetworkType.alex, + network_type: str = PerceptualNetworkType.alex, is_fake_3d: bool = True, fake_3d_ratio: float = 0.5, cache_dir: str | None = None, @@ -93,10 +99,10 @@ def __init__( if channel_wise and "medicalnet_" not in network_type: raise ValueError("Channel-wise loss is only compatible with MedicalNet networks.") - if network_type.lower() not in list(PercetualNetworkType): + if network_type.lower() not in list(PerceptualNetworkType): raise ValueError( "Unrecognised criterion entered for Adversarial Loss. Must be one in: %s" - % ", ".join(PercetualNetworkType) + % ", ".join(PerceptualNetworkType) ) if cache_dir: