diff --git a/src/DistributedLock.Core/CompositeDistributedSynchronizationHandle.cs b/src/DistributedLock.Core/CompositeDistributedSynchronizationHandle.cs new file mode 100644 index 00000000..7abc1a9f --- /dev/null +++ b/src/DistributedLock.Core/CompositeDistributedSynchronizationHandle.cs @@ -0,0 +1,505 @@ +using Medallion.Threading.Internal; + +namespace Medallion.Threading; + +internal sealed class CompositeDistributedSynchronizationHandle : IDistributedSynchronizationHandle +{ + private readonly IDistributedSynchronizationHandle[] _handles; + private readonly CancellationTokenSource? _linkedLostCts; + private bool _disposed; + + public CompositeDistributedSynchronizationHandle(IReadOnlyList handles) + { + ValidateHandles(handles); + this._handles = handles.ToArray(); + this._linkedLostCts = this.CreateLinkedCancellationTokenSource(); + } + + public CancellationToken HandleLostToken => this._linkedLostCts?.Token ?? CancellationToken.None; + + public void Dispose() + { + if (this._disposed) + { + return; + } + + this._disposed = true; + var errors = this.DisposeHandles(h => h.Dispose()); + this._linkedLostCts?.Dispose(); + ThrowAggregateExceptionIfNeeded(errors, "disposing"); + } + + public async ValueTask DisposeAsync() + { + if (this._disposed) + { + return; + } + + this._disposed = true; + var errors = await this.DisposeHandlesAsync(h => h.DisposeAsync()).ConfigureAwait(false); + this._linkedLostCts?.Dispose(); + ThrowAggregateExceptionIfNeeded(errors, "asynchronously disposing"); + } + + public static async ValueTask TryAcquireAllAsync( + TProvider provider, + Func> acquireFunc, + IReadOnlyList names, + TimeSpan timeout, + CancellationToken cancellationToken) + { + ValidateAcquireParameters(provider, acquireFunc, names); + + var timeoutTracker = new TimeoutTracker(new TimeoutValue(timeout)); + var handles = new List(names.Count); + IDistributedSynchronizationHandle? result = null; + + try + { + foreach (var name in names) + { + var handle = await acquireFunc(provider, name, timeoutTracker.Remaining, cancellationToken) + .ConfigureAwait(false); + + if (handle is null) + { + break; + } + + handles.Add(handle); + + if (timeoutTracker.IsExpired) + { + break; + } + } + + if (handles.Count == names.Count) + { + result = new CompositeDistributedSynchronizationHandle(handles); + } + } + finally + { + if (result is null) + { + await DisposeHandlesAsync(handles).ConfigureAwait(false); + } + } + + return result; + } + + public static async ValueTask AcquireAllAsync( + TProvider provider, + Func> acquireFunc, + IReadOnlyList names, + TimeSpan? timeout, + CancellationToken cancellationToken) + { + var effectiveTimeout = timeout ?? Timeout.InfiniteTimeSpan; + var handle = await TryAcquireAllAsync( + provider, + WrapAcquireFunc(acquireFunc), + names, + effectiveTimeout, + cancellationToken) + .ConfigureAwait(false); + + if (handle is null) + { + throw new TimeoutException($"Timed out after {effectiveTimeout} while acquiring all locks."); + } + + return handle; + } + + public static IDistributedSynchronizationHandle? TryAcquireAll( + TProvider provider, + Func acquireFunc, + IReadOnlyList names, + TimeSpan timeout, + CancellationToken cancellationToken) => + SyncViaAsync.Run( + state => TryAcquireAllAsync( + state.provider, + WrapSyncAcquireFunc(state.acquireFunc), + state.names, + state.timeout, + state.cancellationToken), + (provider, acquireFunc, names, timeout, cancellationToken) + ); + + public static IDistributedSynchronizationHandle AcquireAll( + TProvider provider, + Func acquireFunc, + IReadOnlyList names, + TimeSpan? timeout, + CancellationToken cancellationToken) => + SyncViaAsync.Run( + state => AcquireAllAsync( + state.provider, + WrapSyncAcquireFuncForRequired(state.acquireFunc), + state.names, + state.timeout, + state.cancellationToken), + (provider, acquireFunc, names, timeout, cancellationToken) + ); + + public static async ValueTask TryAcquireAllAsync( + TProvider provider, + Func> + acquireFunc, + IReadOnlyList names, + int maxCount, + TimeSpan timeout, + CancellationToken cancellationToken) + { + ValidateAcquireParameters(provider, acquireFunc, names); + + var timeoutTracker = new TimeoutTracker(new TimeoutValue(timeout)); + var handles = new List(names.Count); + IDistributedSynchronizationHandle? result = null; + + try + { + foreach (var name in names) + { + var handle = await acquireFunc(provider, name, maxCount, timeoutTracker.Remaining, cancellationToken) + .ConfigureAwait(false); + + if (handle is null) + { + break; + } + + handles.Add(handle); + + if (timeoutTracker.IsExpired) + { + break; + } + } + + if (handles.Count == names.Count) + { + result = new CompositeDistributedSynchronizationHandle(handles); + } + } + finally + { + if (result is null) + { + await DisposeHandlesAsync(handles).ConfigureAwait(false); + } + } + + return result; + } + + + public static async ValueTask AcquireAllAsync( + TProvider provider, + Func> + acquireFunc, + IReadOnlyList names, + int maxCount, + TimeSpan? timeout, + CancellationToken cancellationToken) + { + var effectiveTimeout = timeout ?? Timeout.InfiniteTimeSpan; + var handle = await TryAcquireAllAsync( + provider, + WrapAcquireFunc(acquireFunc), + names, + maxCount, + effectiveTimeout, + cancellationToken) + .ConfigureAwait(false); + + if (handle is null) + { + throw new TimeoutException($"Timed out after {effectiveTimeout} while acquiring all locks."); + } + + return handle; + } + + public static IDistributedSynchronizationHandle? TryAcquireAll( + TProvider provider, + Func acquireFunc, + IReadOnlyList names, + int maxCount, + TimeSpan timeout, + CancellationToken cancellationToken) => + SyncViaAsync.Run( + state => TryAcquireAllAsync( + state.provider, + WrapSyncAcquireFunc(state.acquireFunc), + state.names, + state.maxCount, + state.timeout, + state.cancellationToken), + (provider, acquireFunc, names, maxCount, timeout, cancellationToken) + ); + + public static IDistributedSynchronizationHandle AcquireAll( + TProvider provider, + Func acquireFunc, + IReadOnlyList names, + int maxCount, + TimeSpan? timeout, + CancellationToken cancellationToken) => + SyncViaAsync.Run( + state => AcquireAllAsync( + state.provider, + WrapSyncAcquireFuncForRequired(state.acquireFunc), + state.names, + state.maxCount, + state.timeout, + state.cancellationToken), + (provider, acquireFunc, names, maxCount, timeout, cancellationToken) + ); + + private static void ValidateHandles(IReadOnlyList handles) + { + if (handles is null) + { + throw new ArgumentNullException(nameof(handles)); + } + + if (handles.Count == 0) + { + throw new ArgumentException("At least one handle is required", nameof(handles)); + } + + for (var i = 0; i < handles.Count; ++i) + { + if (handles[i] is null) + { + throw new ArgumentException( + $"Handles must not contain null elements; found null at index {i}", + nameof(handles) + ); + } + } + } + + private CancellationTokenSource? CreateLinkedCancellationTokenSource() + { + var cancellableTokens = this._handles + .Select(h => h.HandleLostToken) + .Where(t => t.CanBeCanceled) + .ToArray(); + + return cancellableTokens.Length > 0 + ? CancellationTokenSource.CreateLinkedTokenSource(cancellableTokens) + : null; + } + + private List? DisposeHandles(Action disposeAction) + { + List? errors = null; + + foreach (var handle in this._handles) + { + try + { + disposeAction(handle); + } + catch (Exception ex) + { + (errors ??= []).Add(ex); + } + } + + return errors; + } + + private async ValueTask?> DisposeHandlesAsync( + Func disposeAction) + { + List? errors = null; + + foreach (var handle in this._handles) + { + try + { + await disposeAction(handle).ConfigureAwait(false); + } + catch (Exception ex) + { + (errors ??= []).Add(ex); + } + } + + return errors; + } + + private static void ThrowAggregateExceptionIfNeeded(List? errors, string operation) + { + if (errors is not null && errors.Count > 0) + { + throw new AggregateException( + $"One or more errors occurred while {operation} a composite distributed handle.", errors); + } + } + + private static void ValidateAcquireParameters( + TProvider provider, + Func> acquireFunc, + IReadOnlyList names) + { + if (provider is null) + { + throw new ArgumentNullException(nameof(provider)); + } + + if (acquireFunc is null) + { + throw new ArgumentNullException(nameof(acquireFunc)); + } + + if (names is null) + { + throw new ArgumentNullException(nameof(names)); + } + + if (names.Count == 0) + { + throw new ArgumentException("At least one lock name is required.", nameof(names)); + } + + for (var i = 0; i < names.Count; ++i) + { + if (names[i] is null) + { + throw new ArgumentException( + $"Names must not contain null elements; found null at index {i}", + nameof(names) + ); + } + } + } + + private static void ValidateAcquireParameters( + TProvider provider, + Func> + acquireFunc, + IReadOnlyList names) + { + if (provider is null) + { + throw new ArgumentNullException(nameof(provider)); + } + + if (acquireFunc is null) + { + throw new ArgumentNullException(nameof(acquireFunc)); + } + + if (names is null) + { + throw new ArgumentNullException(nameof(names)); + } + + if (names.Count == 0) + { + throw new ArgumentException("At least one lock name is required.", nameof(names)); + } + + for (var i = 0; i < names.Count; ++i) + { + if (names[i] is null) + { + throw new ArgumentException( + $"Names must not contain null elements; found null at index {i}", + nameof(names) + ); + } + } + } + + private static async ValueTask DisposeHandlesAsync(List handles) + { + foreach (var handle in handles) + { + try + { + await handle.DisposeAsync().ConfigureAwait(false); + } + catch + { + // Suppress exceptions during cleanup + } + } + } + + private static Func> + WrapAcquireFunc( + Func> + acquireFunc) => + async (p, n, t, c) => await acquireFunc(p, n, t, c).ConfigureAwait(false); + + private static Func> + WrapSyncAcquireFunc( + Func acquireFunc) => + (p, n, t, c) => new ValueTask(acquireFunc(p, n, t, c)); + + private static Func> + WrapSyncAcquireFuncForRequired( + Func acquireFunc) => + (p, n, t, c) => + { + var handle = acquireFunc(p, n, t, c); + return handle is not null + ? new ValueTask(handle) + : throw new TimeoutException($"Failed to acquire lock for '{n}'"); + }; + + + private static Func> + WrapAcquireFunc( + Func> + acquireFunc) => + async (p, n, mc, t, c) => await acquireFunc(p, n, mc, t, c).ConfigureAwait(false); + + private static Func> + WrapSyncAcquireFunc( + Func + acquireFunc) => + (p, n, mc, t, c) => new ValueTask(acquireFunc(p, n, mc, t, c)); + + private static Func> + WrapSyncAcquireFuncForRequired( + Func + acquireFunc) => + (p, n, mc, t, c) => + { + var handle = acquireFunc(p, n, mc, t, c); + return handle is not null + ? new ValueTask(handle) + : throw new TimeoutException($"Failed to acquire lock for '{n}'"); + }; + + private readonly struct TimeoutTracker(TimeoutValue timeout) + { + private readonly System.Diagnostics.Stopwatch? _stopwatch = timeout.IsInfinite + ? null + : System.Diagnostics.Stopwatch.StartNew(); + + public TimeSpan Remaining => + this._stopwatch is { Elapsed: var elapsed } + ? elapsed >= timeout.TimeSpan + ? TimeSpan.Zero + : timeout.TimeSpan - elapsed + : Timeout.InfiniteTimeSpan; + + public bool IsExpired => this._stopwatch is not null && this._stopwatch.Elapsed >= timeout.TimeSpan; + } +} \ No newline at end of file diff --git a/src/DistributedLock.Core/DistributedLockProviderExtensions.cs b/src/DistributedLock.Core/DistributedLockProviderExtensions.cs index ea4ad6f5..42dd78f1 100644 --- a/src/DistributedLock.Core/DistributedLockProviderExtensions.cs +++ b/src/DistributedLock.Core/DistributedLockProviderExtensions.cs @@ -7,6 +7,8 @@ namespace Medallion.Threading; /// public static class DistributedLockProviderExtensions { + # region Single Lock Methods + /// /// Equivalent to calling and then /// . @@ -34,4 +36,50 @@ public static IDistributedSynchronizationHandle AcquireLock(this IDistributedLoc /// public static ValueTask AcquireLockAsync(this IDistributedLockProvider provider, string name, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => (provider ?? throw new ArgumentNullException(nameof(provider))).CreateLock(name).AcquireAsync(timeout, cancellationToken); + + # endregion + + # region Composite Lock Methods + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static IDistributedSynchronizationHandle? TryAcquireAllLocks(this IDistributedLockProvider provider, IReadOnlyList names, TimeSpan timeout = default, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.TryAcquireAll( + provider, + static (p, n, t, c) => p.TryAcquireLock(n, t, c), + names, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static IDistributedSynchronizationHandle AcquireAllLocks(this IDistributedLockProvider provider, IReadOnlyList names, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.AcquireAll( + provider, + static (p, n, t, c) => p.AcquireLock(n, t, c), + names, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static ValueTask TryAcquireAllLocksAsync(this IDistributedLockProvider provider, IReadOnlyList names, TimeSpan timeout = default, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.TryAcquireAllAsync( + provider, + static (p, n, t, c) => p.TryAcquireLockAsync(n, t, c), + names, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static ValueTask AcquireAllLocksAsync(this IDistributedLockProvider provider, IReadOnlyList names, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.AcquireAllAsync( + provider, + static (p, n, t, c) => p.AcquireLockAsync(n, t, c), + names, timeout, cancellationToken); + + # endregion } \ No newline at end of file diff --git a/src/DistributedLock.Core/DistributedReaderWriterLockProviderExtensions.cs b/src/DistributedLock.Core/DistributedReaderWriterLockProviderExtensions.cs index 5ac66fac..5e2dadd2 100644 --- a/src/DistributedLock.Core/DistributedReaderWriterLockProviderExtensions.cs +++ b/src/DistributedLock.Core/DistributedReaderWriterLockProviderExtensions.cs @@ -7,6 +7,8 @@ namespace Medallion.Threading; /// public static class DistributedReaderWriterLockProviderExtensions { + # region Single Lock Methods + /// /// Equivalent to calling and then /// . @@ -62,4 +64,90 @@ public static IDistributedSynchronizationHandle AcquireWriteLock(this IDistribut /// public static ValueTask AcquireWriteLockAsync(this IDistributedReaderWriterLockProvider provider, string name, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => (provider ?? throw new ArgumentNullException(nameof(provider))).CreateReaderWriterLock(name).AcquireWriteLockAsync(timeout, cancellationToken); + + # endregion + + # region Composite Lock Methods + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static IDistributedSynchronizationHandle? TryAcquireAllReadLocks(this IDistributedReaderWriterLockProvider provider, IReadOnlyList names, TimeSpan timeout = default, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.TryAcquireAll( + provider, + static (p, n, t, c) => p.TryAcquireReadLock(n, t, c), + names, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static IDistributedSynchronizationHandle AcquireAllReadLocks(this IDistributedReaderWriterLockProvider provider, IReadOnlyList names, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.AcquireAll( + provider, + static (p, n, t, c) => p.AcquireReadLock(n, t, c), + names, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static ValueTask TryAcquireAllReadLocksAsync(this IDistributedReaderWriterLockProvider provider, IReadOnlyList names, TimeSpan timeout = default, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.TryAcquireAllAsync( + provider, + static (p, n, t, c) => p.TryAcquireReadLockAsync(n, t, c), + names, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static ValueTask AcquireAllReadLocksAsync(this IDistributedReaderWriterLockProvider provider, IReadOnlyList names, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.AcquireAllAsync( + provider, + static (p, n, t, c) => p.AcquireReadLockAsync(n, t, c), + names, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static IDistributedSynchronizationHandle? TryAcquireAllWriteLocks(this IDistributedReaderWriterLockProvider provider, IReadOnlyList names, TimeSpan timeout = default, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.TryAcquireAll( + provider, + static (p, n, t, c) => p.TryAcquireWriteLock(n, t, c), + names, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static IDistributedSynchronizationHandle AcquireAllWriteLocks(this IDistributedReaderWriterLockProvider provider, IReadOnlyList names, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.AcquireAll( + provider, + static (p, n, t, c) => p.AcquireWriteLock(n, t, c), + names, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static ValueTask TryAcquireAllWriteLocksAsync(this IDistributedReaderWriterLockProvider provider, IReadOnlyList names, TimeSpan timeout = default, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.TryAcquireAllAsync( + provider, + static (p, n, t, c) => p.TryAcquireWriteLockAsync(n, t, c), + names, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static ValueTask AcquireAllWriteLocksAsync(this IDistributedReaderWriterLockProvider provider, IReadOnlyList names, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.AcquireAllAsync( + provider, + static (p, n, t, c) => p.AcquireWriteLockAsync(n, t, c), + names, timeout, cancellationToken); + + # endregion } \ No newline at end of file diff --git a/src/DistributedLock.Core/DistributedSemaphoreProviderExtensions.cs b/src/DistributedLock.Core/DistributedSemaphoreProviderExtensions.cs index b808b004..dc961498 100644 --- a/src/DistributedLock.Core/DistributedSemaphoreProviderExtensions.cs +++ b/src/DistributedLock.Core/DistributedSemaphoreProviderExtensions.cs @@ -7,6 +7,8 @@ namespace Medallion.Threading; /// public static class DistributedSemaphoreProviderExtensions { + # region Single Lock Methods + /// /// Equivalent to calling and then /// . @@ -34,4 +36,50 @@ public static IDistributedSynchronizationHandle AcquireSemaphore(this IDistribut /// public static ValueTask AcquireSemaphoreAsync(this IDistributedSemaphoreProvider provider, string name, int maxCount, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => (provider ?? throw new ArgumentNullException(nameof(provider))).CreateSemaphore(name, maxCount).AcquireAsync(timeout, cancellationToken); + + # endregion + + # region Composite Lock Methods + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static IDistributedSynchronizationHandle? TryAcquireAllSemaphores(this IDistributedSemaphoreProvider provider, IReadOnlyList names, int maxCount, TimeSpan timeout = default, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.TryAcquireAll( + provider, + static (p, n, mc, t, c) => p.TryAcquireSemaphore(n, mc, t, c), + names, maxCount, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static IDistributedSynchronizationHandle AcquireAllSemaphores(this IDistributedSemaphoreProvider provider, IReadOnlyList names, int maxCount, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.AcquireAll( + provider, + static (p, n, mc, t, c) => p.AcquireSemaphore(n, mc, t, c), + names, maxCount, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static ValueTask TryAcquireAllSemaphoresAsync(this IDistributedSemaphoreProvider provider, IReadOnlyList names, int maxCount, TimeSpan timeout = default, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.TryAcquireAllAsync( + provider, + static (p, n, mc, t, c) => p.TryAcquireSemaphoreAsync(n, mc, t, c), + names, maxCount, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static ValueTask AcquireAllSemaphoresAsync(this IDistributedSemaphoreProvider provider, IReadOnlyList names, int maxCount, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.AcquireAllAsync( + provider, + static (p, n, mc, t, c) => p.AcquireSemaphoreAsync(n, mc, t, c), + names, maxCount, timeout, cancellationToken); + + # endregion } \ No newline at end of file diff --git a/src/DistributedLock.Core/DistributedUpgradeableReaderWriterLockProviderExtensions.cs b/src/DistributedLock.Core/DistributedUpgradeableReaderWriterLockProviderExtensions.cs index 4b8c51f9..1f647384 100644 --- a/src/DistributedLock.Core/DistributedUpgradeableReaderWriterLockProviderExtensions.cs +++ b/src/DistributedLock.Core/DistributedUpgradeableReaderWriterLockProviderExtensions.cs @@ -7,6 +7,8 @@ namespace Medallion.Threading; /// public static class DistributedUpgradeableReaderWriterLockProviderExtensions { + # region Single Lock Methods + /// /// Equivalent to calling and then /// . @@ -34,4 +36,13 @@ public static IDistributedLockUpgradeableHandle AcquireUpgradeableReadLock(this /// public static ValueTask AcquireUpgradeableReadLockAsync(this IDistributedUpgradeableReaderWriterLockProvider provider, string name, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => (provider ?? throw new ArgumentNullException(nameof(provider))).CreateUpgradeableReaderWriterLock(name).AcquireUpgradeableReadLockAsync(timeout, cancellationToken); + + # endregion + + # region Composite Lock Methods + + // Composite methods are not supported for IDistributedUpgradeableReaderWriterLock + // because a composite acquire operation must be able to roll back and upgrade does not support that. + + # endregion } \ No newline at end of file diff --git a/src/DistributedLock.Core/PublicAPI.Unshipped.txt b/src/DistributedLock.Core/PublicAPI.Unshipped.txt index 5f282702..4b44c194 100644 --- a/src/DistributedLock.Core/PublicAPI.Unshipped.txt +++ b/src/DistributedLock.Core/PublicAPI.Unshipped.txt @@ -1 +1,17 @@ - \ No newline at end of file +#nullable enable +static Medallion.Threading.DistributedLockProviderExtensions.AcquireAllLocks(this Medallion.Threading.IDistributedLockProvider! provider, System.Collections.Generic.IReadOnlyList! names, System.TimeSpan? timeout = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> Medallion.Threading.IDistributedSynchronizationHandle! +static Medallion.Threading.DistributedLockProviderExtensions.AcquireAllLocksAsync(this Medallion.Threading.IDistributedLockProvider! provider, System.Collections.Generic.IReadOnlyList! names, System.TimeSpan? timeout = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask +static Medallion.Threading.DistributedLockProviderExtensions.TryAcquireAllLocks(this Medallion.Threading.IDistributedLockProvider! provider, System.Collections.Generic.IReadOnlyList! names, System.TimeSpan timeout = default(System.TimeSpan), System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> Medallion.Threading.IDistributedSynchronizationHandle? +static Medallion.Threading.DistributedLockProviderExtensions.TryAcquireAllLocksAsync(this Medallion.Threading.IDistributedLockProvider! provider, System.Collections.Generic.IReadOnlyList! names, System.TimeSpan timeout = default(System.TimeSpan), System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask +static Medallion.Threading.DistributedReaderWriterLockProviderExtensions.AcquireAllReadLocks(this Medallion.Threading.IDistributedReaderWriterLockProvider! provider, System.Collections.Generic.IReadOnlyList! names, System.TimeSpan? timeout = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> Medallion.Threading.IDistributedSynchronizationHandle! +static Medallion.Threading.DistributedReaderWriterLockProviderExtensions.AcquireAllReadLocksAsync(this Medallion.Threading.IDistributedReaderWriterLockProvider! provider, System.Collections.Generic.IReadOnlyList! names, System.TimeSpan? timeout = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask +static Medallion.Threading.DistributedReaderWriterLockProviderExtensions.AcquireAllWriteLocks(this Medallion.Threading.IDistributedReaderWriterLockProvider! provider, System.Collections.Generic.IReadOnlyList! names, System.TimeSpan? timeout = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> Medallion.Threading.IDistributedSynchronizationHandle! +static Medallion.Threading.DistributedReaderWriterLockProviderExtensions.AcquireAllWriteLocksAsync(this Medallion.Threading.IDistributedReaderWriterLockProvider! provider, System.Collections.Generic.IReadOnlyList! names, System.TimeSpan? timeout = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask +static Medallion.Threading.DistributedReaderWriterLockProviderExtensions.TryAcquireAllReadLocks(this Medallion.Threading.IDistributedReaderWriterLockProvider! provider, System.Collections.Generic.IReadOnlyList! names, System.TimeSpan timeout = default(System.TimeSpan), System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> Medallion.Threading.IDistributedSynchronizationHandle? +static Medallion.Threading.DistributedReaderWriterLockProviderExtensions.TryAcquireAllReadLocksAsync(this Medallion.Threading.IDistributedReaderWriterLockProvider! provider, System.Collections.Generic.IReadOnlyList! names, System.TimeSpan timeout = default(System.TimeSpan), System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask +static Medallion.Threading.DistributedReaderWriterLockProviderExtensions.TryAcquireAllWriteLocks(this Medallion.Threading.IDistributedReaderWriterLockProvider! provider, System.Collections.Generic.IReadOnlyList! names, System.TimeSpan timeout = default(System.TimeSpan), System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> Medallion.Threading.IDistributedSynchronizationHandle? +static Medallion.Threading.DistributedReaderWriterLockProviderExtensions.TryAcquireAllWriteLocksAsync(this Medallion.Threading.IDistributedReaderWriterLockProvider! provider, System.Collections.Generic.IReadOnlyList! names, System.TimeSpan timeout = default(System.TimeSpan), System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask +static Medallion.Threading.DistributedSemaphoreProviderExtensions.AcquireAllSemaphores(this Medallion.Threading.IDistributedSemaphoreProvider! provider, System.Collections.Generic.IReadOnlyList! names, int maxCount, System.TimeSpan? timeout = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> Medallion.Threading.IDistributedSynchronizationHandle! +static Medallion.Threading.DistributedSemaphoreProviderExtensions.AcquireAllSemaphoresAsync(this Medallion.Threading.IDistributedSemaphoreProvider! provider, System.Collections.Generic.IReadOnlyList! names, int maxCount, System.TimeSpan? timeout = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask +static Medallion.Threading.DistributedSemaphoreProviderExtensions.TryAcquireAllSemaphores(this Medallion.Threading.IDistributedSemaphoreProvider! provider, System.Collections.Generic.IReadOnlyList! names, int maxCount, System.TimeSpan timeout = default(System.TimeSpan), System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> Medallion.Threading.IDistributedSynchronizationHandle? +static Medallion.Threading.DistributedSemaphoreProviderExtensions.TryAcquireAllSemaphoresAsync(this Medallion.Threading.IDistributedSemaphoreProvider! provider, System.Collections.Generic.IReadOnlyList! names, int maxCount, System.TimeSpan timeout = default(System.TimeSpan), System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask diff --git a/src/DistributedLock.Tests/Tests/Core/DistributedLockProviderExtensionsTest.cs b/src/DistributedLock.Tests/Tests/Core/DistributedLockProviderExtensionsTest.cs index 074e82d5..9ca1a407 100644 --- a/src/DistributedLock.Tests/Tests/Core/DistributedLockProviderExtensionsTest.cs +++ b/src/DistributedLock.Tests/Tests/Core/DistributedLockProviderExtensionsTest.cs @@ -49,7 +49,7 @@ public void TestCallThrough([Values] bool isTry, [Values] bool isAsync) } void Test( - Expression> providerFunction, + Expression> providerFunction, Expression> lockFunction) { providerFunction.Compile()(mockProvider.Object); @@ -58,4 +58,286 @@ void Test( mockLock.Verify(lockFunction, Times.Once()); } } + + [Test, Combinatorial] + public void TestCompositeArgumentValidation([Values] bool isAsync) + { + var mockProvider = new Mock(); + var validNames = new List { "lock1", "lock2" }; + + if (isAsync) + { + Assert.ThrowsAsync(() => DistributedLockProviderExtensions.TryAcquireAllLocksAsync(null!, validNames).AsTask()); + Assert.ThrowsAsync(() => DistributedLockProviderExtensions.TryAcquireAllLocksAsync(mockProvider.Object, null!).AsTask()); + Assert.ThrowsAsync(() => DistributedLockProviderExtensions.TryAcquireAllLocksAsync(mockProvider.Object, new List()).AsTask()); + Assert.ThrowsAsync(() => DistributedLockProviderExtensions.TryAcquireAllLocksAsync(mockProvider.Object, new List { "lock1", null! }).AsTask()); + + Assert.ThrowsAsync(() => DistributedLockProviderExtensions.AcquireAllLocksAsync(null!, validNames).AsTask()); + Assert.ThrowsAsync(() => DistributedLockProviderExtensions.AcquireAllLocksAsync(mockProvider.Object, null!).AsTask()); + Assert.ThrowsAsync(() => DistributedLockProviderExtensions.AcquireAllLocksAsync(mockProvider.Object, new List()).AsTask()); + Assert.ThrowsAsync(() => DistributedLockProviderExtensions.AcquireAllLocksAsync(mockProvider.Object, new List { "lock1", null! }).AsTask()); + } + else + { + Assert.Throws(() => DistributedLockProviderExtensions.TryAcquireAllLocks(null!, validNames)); + Assert.Throws(() => DistributedLockProviderExtensions.TryAcquireAllLocks(mockProvider.Object, null!)); + Assert.Throws(() => DistributedLockProviderExtensions.TryAcquireAllLocks(mockProvider.Object, new List())); + Assert.Throws(() => DistributedLockProviderExtensions.TryAcquireAllLocks(mockProvider.Object, new List { "lock1", null! })); + + Assert.Throws(() => DistributedLockProviderExtensions.AcquireAllLocks(null!, validNames)); + Assert.Throws(() => DistributedLockProviderExtensions.AcquireAllLocks(mockProvider.Object, null!)); + Assert.Throws(() => DistributedLockProviderExtensions.AcquireAllLocks(mockProvider.Object, new List())); + Assert.Throws(() => DistributedLockProviderExtensions.AcquireAllLocks(mockProvider.Object, new List { "lock1", null! })); + } + } + + [Test, Combinatorial] + public async Task TestCompositePartialAcquisitionFailure([Values] bool isAsync) + { + var mockProvider = new Mock(); + var mockLockA = new Mock(); + var mockLockB = new Mock(); + var mockHandleA = new Mock(); + + mockProvider.Setup(p => p.CreateLock("A")).Returns(mockLockA.Object); + mockProvider.Setup(p => p.CreateLock("B")).Returns(mockLockB.Object); + + if (isAsync) + { + mockLockA.Setup(l => l.TryAcquireAsync(TimeSpan.Zero, default)) + .ReturnsAsync(mockHandleA.Object); + mockLockB.Setup(l => l.TryAcquireAsync(It.IsAny(), default)) + .ReturnsAsync((IDistributedSynchronizationHandle?)null); + } + else + { + mockLockA.Setup(l => l.TryAcquire(TimeSpan.Zero, default)) + .Returns(mockHandleA.Object); + mockLockB.Setup(l => l.TryAcquire(It.IsAny(), default)) + .Returns((IDistributedSynchronizationHandle?)null); + } + + var names = new List { "A", "B" }; + IDistributedSynchronizationHandle? result; + + if (isAsync) + { + result = await mockProvider.Object.TryAcquireAllLocksAsync(names, TimeSpan.Zero, default); + } + else + { + result = mockProvider.Object.TryAcquireAllLocks(names, TimeSpan.Zero, default); + } + + Assert.That(result, Is.Null); + + mockHandleA.Verify(h => h.DisposeAsync(), Times.Once); + } + + [Test, Combinatorial] + public async Task TestCompositeSuccessfulAcquisition([Values] bool isTry, [Values] bool isAsync) + { + var mockProvider = new Mock(); + var mockLockA = new Mock(); + var mockLockB = new Mock(); + var mockHandleA = new Mock(); + var mockHandleB = new Mock(); + + mockProvider.Setup(p => p.CreateLock("A")).Returns(mockLockA.Object); + mockProvider.Setup(p => p.CreateLock("B")).Returns(mockLockB.Object); + + if (isAsync) + { + if (isTry) + { + mockLockA.Setup(l => l.TryAcquireAsync(It.IsAny(), default)) + .ReturnsAsync(mockHandleA.Object); + mockLockB.Setup(l => l.TryAcquireAsync(It.IsAny(), default)) + .ReturnsAsync(mockHandleB.Object); + } + else + { + mockLockA.Setup(l => l.AcquireAsync(It.IsAny(), default)) + .ReturnsAsync(mockHandleA.Object); + mockLockB.Setup(l => l.AcquireAsync(It.IsAny(), default)) + .ReturnsAsync(mockHandleB.Object); + } + } + else + { + if (isTry) + { + mockLockA.Setup(l => l.TryAcquire(It.IsAny(), default)) + .Returns(mockHandleA.Object); + mockLockB.Setup(l => l.TryAcquire(It.IsAny(), default)) + .Returns(mockHandleB.Object); + } + else + { + mockLockA.Setup(l => l.Acquire(It.IsAny(), default)) + .Returns(mockHandleA.Object); + mockLockB.Setup(l => l.Acquire(It.IsAny(), default)) + .Returns(mockHandleB.Object); + } + } + + var names = new List { "A", "B" }; + IDistributedSynchronizationHandle? result; + + if (isAsync) + { + if (isTry) + { + result = await mockProvider.Object.TryAcquireAllLocksAsync(names, TimeSpan.FromSeconds(10), default); + } + else + { + result = await mockProvider.Object.AcquireAllLocksAsync(names, TimeSpan.FromSeconds(10), default); + } + } + else + { + if (isTry) + { + result = mockProvider.Object.TryAcquireAllLocks(names, TimeSpan.FromSeconds(10), default); + } + else + { + result = mockProvider.Object.AcquireAllLocks(names, TimeSpan.FromSeconds(10), default); + } + } + + Assert.That(result, Is.Not.Null); + + mockProvider.Verify(p => p.CreateLock("A"), Times.Once); + mockProvider.Verify(p => p.CreateLock("B"), Times.Once); + + if (isAsync) + { + await result!.DisposeAsync(); + mockHandleA.Verify(h => h.DisposeAsync(), Times.Once); + mockHandleB.Verify(h => h.DisposeAsync(), Times.Once); + } + else + { + result!.Dispose(); + mockHandleA.Verify(h => h.Dispose(), Times.Once); + mockHandleB.Verify(h => h.Dispose(), Times.Once); + } + } + + [Test, Combinatorial] + public void TestCompositeAcquireThrowsOnTimeout([Values] bool isAsync) + { + var mockProvider = new Mock(); + var mockLockA = new Mock(); + + mockProvider.Setup(p => p.CreateLock("A")).Returns(mockLockA.Object); + + if (isAsync) + { + mockLockA.Setup(l => l.TryAcquireAsync(It.IsAny(), default)) + .ReturnsAsync((IDistributedSynchronizationHandle?)null); + } + else + { + mockLockA.Setup(l => l.TryAcquire(It.IsAny(), default)) + .Returns((IDistributedSynchronizationHandle?)null); + } + + var names = new List { "A" }; + + if (isAsync) + { + Assert.ThrowsAsync(() => mockProvider.Object.AcquireAllLocksAsync(names, TimeSpan.Zero, default).AsTask()); + } + else + { + Assert.Throws(() => mockProvider.Object.AcquireAllLocks(names, TimeSpan.Zero, default)); + } + } + + [Test] + public async Task TestCompositeRemainingTimeDistribution() + { + var mockProvider = new Mock(); + var mockLockA = new Mock(); + var mockLockB = new Mock(); + var mockHandleA = new Mock(); + var mockHandleB = new Mock(); + + mockProvider.Setup(p => p.CreateLock("A")).Returns(mockLockA.Object); + mockProvider.Setup(p => p.CreateLock("B")).Returns(mockLockB.Object); + + var capturedTimeouts = new List(); + + mockLockA.Setup(l => l.TryAcquireAsync(It.IsAny(), default)) + .ReturnsAsync((TimeSpan timeout, CancellationToken _) => + { + capturedTimeouts.Add(timeout); + Task.Delay(TimeSpan.FromMilliseconds(100)).Wait(); + return mockHandleA.Object; + }); + + mockLockB.Setup(l => l.TryAcquireAsync(It.IsAny(), default)) + .ReturnsAsync((TimeSpan timeout, CancellationToken _) => + { + capturedTimeouts.Add(timeout); + return mockHandleB.Object; + }); + + var names = new List { "A", "B" }; + var totalTimeout = TimeSpan.FromSeconds(5); + + var result = await mockProvider.Object.TryAcquireAllLocksAsync(names, totalTimeout, default); + + Assert.That(result, Is.Not.Null); + Assert.That(capturedTimeouts.Count, Is.EqualTo(2)); + + Assert.That(capturedTimeouts[0].TotalMilliseconds, Is.EqualTo(totalTimeout.TotalMilliseconds).Within(1.0)); + + Assert.That(capturedTimeouts[1], Is.LessThan(totalTimeout)); + Assert.That(capturedTimeouts[1], Is.GreaterThanOrEqualTo(TimeSpan.Zero)); + + await result!.DisposeAsync(); + } + + [Test] + public async Task TestCompositeHandleLostToken() + { + var mockProvider = new Mock(); + var mockLockA = new Mock(); + var mockLockB = new Mock(); + + var ctsA = new CancellationTokenSource(); + var ctsB = new CancellationTokenSource(); + + var mockHandleA = new Mock(); + var mockHandleB = new Mock(); + + mockHandleA.Setup(h => h.HandleLostToken).Returns(ctsA.Token); + mockHandleB.Setup(h => h.HandleLostToken).Returns(ctsB.Token); + + mockProvider.Setup(p => p.CreateLock("A")).Returns(mockLockA.Object); + mockProvider.Setup(p => p.CreateLock("B")).Returns(mockLockB.Object); + + mockLockA.Setup(l => l.TryAcquireAsync(It.IsAny(), default)) + .ReturnsAsync(mockHandleA.Object); + mockLockB.Setup(l => l.TryAcquireAsync(It.IsAny(), default)) + .ReturnsAsync(mockHandleB.Object); + + var names = new List { "A", "B" }; + var result = await mockProvider.Object.TryAcquireAllLocksAsync(names, TimeSpan.FromSeconds(10), default); + + Assert.That(result, Is.Not.Null); + Assert.That(result!.HandleLostToken.CanBeCanceled, Is.True); + + ctsA.Cancel(); + + Assert.That(result.HandleLostToken.IsCancellationRequested, Is.True); + + await result.DisposeAsync(); + ctsA.Dispose(); + ctsB.Dispose(); + } } diff --git a/src/DistributedLockCodeGen/GenerateProviders.cs b/src/DistributedLockCodeGen/GenerateProviders.cs index ab657b14..b55e0aa8 100644 --- a/src/DistributedLockCodeGen/GenerateProviders.cs +++ b/src/DistributedLockCodeGen/GenerateProviders.cs @@ -10,64 +10,115 @@ namespace DistributedLockCodeGen; [Category("CI")] public class GenerateProviders { - public static readonly IReadOnlyList Interfaces = new[] - { + public static readonly IReadOnlyList Interfaces = + [ "IDistributedLock", "IDistributedReaderWriterLock", "IDistributedUpgradeableReaderWriterLock", "IDistributedSemaphore" + ]; + + private static readonly IReadOnlyDictionary ExcludedInterfacesForCompositeMethods = new Dictionary + { + ["IDistributedUpgradeableReaderWriterLock"] = "a composite acquire operation must be able to roll back and upgrade does not support that." }; [TestCaseSource(nameof(Interfaces))] public void GenerateProviderInterfaceAndExtensions(string interfaceName) { - var interfaceFile = Directory.GetFiles(CodeGenHelpers.SolutionDirectory, interfaceName + ".cs", SearchOption.AllDirectories) + var interfaceFile = Directory + .GetFiles(CodeGenHelpers.SolutionDirectory, interfaceName + ".cs", SearchOption.AllDirectories) .Single(); var providerInterfaceName = interfaceName + "Provider"; var createMethodName = $"Create{interfaceName.Replace("IDistributed", string.Empty)}"; - var providerInterfaceCode = $@"// AUTO-GENERATED -namespace Medallion.Threading; - -/// -/// Acts as a factory for instances of a certain type. This interface may be -/// easier to use than in dependency injection scenarios. -/// -public interface {providerInterfaceName}{(interfaceName == "IDistributedUpgradeableReaderWriterLock" ? ": IDistributedReaderWriterLockProvider" : string.Empty)} -{{ - /// - /// Constructs an instance with the given . - /// - {interfaceName} {createMethodName}(string name{(interfaceName.Contains("Semaphore") ? ", int maxCount" : string.Empty)}); -}}"; + var providerInterfaceCode = $$""" + // AUTO-GENERATED + namespace Medallion.Threading; + + /// + /// Acts as a factory for instances of a certain type. This interface may be + /// easier to use than in dependency injection scenarios. + /// + public interface {{providerInterfaceName}}{{(interfaceName == "IDistributedUpgradeableReaderWriterLock" ? ": IDistributedReaderWriterLockProvider" : string.Empty)}} + { + /// + /// Constructs an instance with the given . + /// + {{interfaceName}} {{createMethodName}}(string name{{(interfaceName.Contains("Semaphore") ? ", int maxCount" : string.Empty)}}); + } + """; var interfaceMethods = Regex.Matches( File.ReadAllText(interfaceFile), @"(?\S+) (?\S+)\((?((?\S*) (?\w+)[^,)]*(\, )?)*)\);", RegexOptions.ExplicitCapture ); - var extensionMethodBodies = interfaceMethods.Cast() + + var extensionSingleMethodBodies = interfaceMethods .Select(m => -$@" /// - /// Equivalent to calling and then - /// ().Select(c => c.Value))})"" />. - /// - public static {m.Groups["returnType"].Value} {GetExtensionMethodName(m.Groups["name"].Value)}(this {providerInterfaceName} provider, string name{(interfaceName.Contains("Semaphore") ? ", int maxCount" : string.Empty)}, {m.Groups["parameters"].Value}) => - (provider ?? throw new ArgumentNullException(nameof(provider))).{createMethodName}(name{(interfaceName.Contains("Semaphore") ? ", maxCount" : string.Empty)}).{m.Groups["name"].Value}({string.Join(", ", m.Groups["parameterName"].Captures.Cast().Select(c => c.Value))});" + $""" + /// + /// Equivalent to calling and then + /// c.Value))})" />. + /// + public static {m.Groups["returnType"].Value} {GetExtensionMethodName(m.Groups["name"].Value)}(this {providerInterfaceName} provider, string name{(interfaceName.Contains("Semaphore") ? ", int maxCount" : string.Empty)}, {m.Groups["parameters"].Value}) => + (provider ?? throw new ArgumentNullException(nameof(provider))).{createMethodName}(name{(interfaceName.Contains("Semaphore") ? ", maxCount" : string.Empty)}).{m.Groups["name"].Value}({string.Join(", ", m.Groups["parameterName"].Captures.Select(c => c.Value))}); + """ ); + var extensionCompositeMethodBodies = ExcludedInterfacesForCompositeMethods.TryGetValue(interfaceName, out var exclusionReason) + ? + [ + $""" + // Composite methods are not supported for {interfaceName} + // because {exclusionReason} + """ + ] + : interfaceMethods + .Select(m => + { + var (extensionMethodName, innerCallName) = GetAllExtensionMethodName(m.Groups["name"].Value); + var isSemaphore = interfaceName.Contains("Semaphore"); + + return $""" + /// + /// Equivalent to calling for each name in and then + /// c.Value))})" /> on each created instance, combining the results into a composite handle. + /// + public static {m.Groups["returnType"].Value} {extensionMethodName}(this {providerInterfaceName} provider, IReadOnlyList names{(isSemaphore ? ", int maxCount" : string.Empty)}, {m.Groups["parameters"].Value}) => + CompositeDistributedSynchronizationHandle.{innerCallName}( + provider, + static (p, n{(isSemaphore ? ", mc" : string.Empty)}, t, c) => p.{GetExtensionMethodName(m.Groups["name"].Value)}(n{(isSemaphore ? ", mc" : string.Empty)}, t, c), + names,{(isSemaphore ? " maxCount," : string.Empty)} timeout, cancellationToken); + """; + } + ); + var providerExtensionsName = providerInterfaceName.TrimStart('I') + "Extensions"; - var providerExtensionsCode = $@"// AUTO-GENERATED + var providerExtensionsCode = $$""" + // AUTO-GENERATED + + namespace Medallion.Threading; + + /// + /// Productivity helper methods for + /// + public static class {{providerExtensionsName}} + { + # region Single Lock Methods + + {{string.Join(Environment.NewLine + Environment.NewLine, extensionSingleMethodBodies)}} + + # endregion -namespace Medallion.Threading; + # region Composite Lock Methods -/// -/// Productivity helper methods for -/// -public static class {providerExtensionsName} -{{ -{string.Join(Environment.NewLine + Environment.NewLine, extensionMethodBodies)} -}}"; + {{string.Join(Environment.NewLine + Environment.NewLine, extensionCompositeMethodBodies)}} + + # endregion + } + """; var changes = new[] { @@ -76,7 +127,8 @@ public static class {providerExtensionsName} } .Select(t => (file: Path.Combine(Path.GetDirectoryName(interfaceFile)!, t.name + ".cs"), t.code)) .Select(t => (t.file, t.code, originalCode: File.Exists(t.file) ? File.ReadAllText(t.file) : string.Empty)) - .Where(t => CodeGenHelpers.NormalizeCodeWhitespace(t.code) != CodeGenHelpers.NormalizeCodeWhitespace(t.originalCode)) + .Where(t => CodeGenHelpers.NormalizeCodeWhitespace(t.code) != + CodeGenHelpers.NormalizeCodeWhitespace(t.originalCode)) .ToList(); changes.ForEach(t => File.WriteAllText(t.file, t.code)); Assert.That(changes.Select(t => t.file), Is.Empty); @@ -85,8 +137,45 @@ string GetExtensionMethodName(string interfaceMethodName) => Regex.IsMatch(interfaceMethodName, "^(Try)?Acquire(Async)?$") // make it more specific to differentiate when one concrete provider implements multiple provider interfaces ? interfaceMethodName.Replace("Async", string.Empty) - + interfaceName.Replace("IDistributed", string.Empty) - + (interfaceMethodName.EndsWith("Async") ? "Async" : string.Empty) + + interfaceName.Replace("IDistributed", string.Empty) + + (interfaceMethodName.EndsWith("Async") ? "Async" : string.Empty) : interfaceMethodName; + + (string extensionMethodName, string innerCallName) GetAllExtensionMethodName(string interfaceMethodName) + { + var isExactAcquire = Regex.IsMatch(interfaceMethodName, "^(Try)?Acquire(Async)?$"); + var isAsync = interfaceMethodName.EndsWith("Async", StringComparison.Ordinal); + var isTryVariant = interfaceMethodName.StartsWith("Try", StringComparison.Ordinal); + + string extensionMethodName; + + if (!isExactAcquire) + { + // e.g. TryAcquireReadLock -> TryAcquireAllReadLocks + // TryAcquireSemaphore -> TryAcquireAllSemaphores + // TryAcquireUpgradeableReadLockAsync -> TryAcquireUpgradeableAllReadLockAsync + extensionMethodName = interfaceMethodName + .Replace("Acquire", "AcquireAll") // Acquire -> AcquireAll + .Replace("Async", string.Empty) // strip Async (add back later) + + "s" // pluralise + + (isAsync ? "Async" : string.Empty); // restore Async if needed + } + else + { + // e.g. TryAcquire -> TryAcquireAllLocks + // AcquireAsync -> AcquireAllLocksAsync + extensionMethodName = interfaceMethodName.Replace("Async", string.Empty) + + "All" + + interfaceName.Replace("IDistributed", string.Empty) + "s" + + (isAsync ? "Async" : string.Empty); + } + + // - “Try…” methods -> TryAcquireAll[Async] + // - plain Acquire… -> AcquireAll[Async] + var innerCallName = (isTryVariant ? "TryAcquireAll" : "AcquireAll") + + (isAsync ? "Async" : string.Empty); + + return (extensionMethodName, innerCallName); + } } -} +} \ No newline at end of file