Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.
 
 
 

184 rader
5.5 KiB

  1. using Cysharp.Threading.Tasks.Internal;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Runtime.ExceptionServices;
  5. using System.Threading;
  6. namespace Cysharp.Threading.Tasks
  7. {
  public partial struct UniTask
  {
      /// <summary>
      /// Creates an async sequence that yields one <see cref="WhenEachResult{T}"/> per task,
      /// in the order the tasks finish. A faulted task is reported as a result carrying its
      /// exception rather than ending the sequence. The tasks are not awaited until the
      /// first <c>MoveNextAsync</c> call on an enumerator.
      /// </summary>
      /// <param name="tasks">The tasks to observe.</param>
      public static IUniTaskAsyncEnumerable<WhenEachResult<T>> WhenEach<T>(IEnumerable<UniTask<T>> tasks)
          => new WhenEachEnumerable<T>(tasks);

      /// <summary>
      /// Array/params overload of <c>WhenEach</c>. It yields one <see cref="WhenEachResult{T}"/>
      /// per task, in completion order.
      /// </summary>
      /// <param name="tasks">The tasks to observe.</param>
      public static IUniTaskAsyncEnumerable<WhenEachResult<T>> WhenEach<T>(params UniTask<T>[] tasks)
          => new WhenEachEnumerable<T>(tasks);
  }
  /// <summary>
  /// Outcome of one task observed by <c>UniTask.WhenEach</c>. It holds either the task's
  /// result value or the exception the task faulted with.
  /// </summary>
  public readonly struct WhenEachResult<T>
  {
      /// <summary>The task's result; <c>default</c> when the task faulted.</summary>
      public T Result { get; }

      /// <summary>The exception the task faulted with, or <c>null</c> on success.</summary>
      public Exception Exception { get; }

      /// <summary><c>true</c> when <see cref="Exception"/> is <c>null</c>.</summary>
      public bool IsCompletedSuccessfully => Exception is null;

      /// <summary><c>true</c> when <see cref="Exception"/> is non-null.</summary>
      public bool IsFaulted => !IsCompletedSuccessfully;

      /// <summary>Creates a successful result wrapping <paramref name="result"/>.</summary>
      public WhenEachResult(T result)
      {
          Result = result;
          Exception = null;
      }

      /// <summary>Creates a faulted result; <see cref="Result"/> is left at <c>default</c>.</summary>
      /// <exception cref="ArgumentNullException"><paramref name="exception"/> is <c>null</c>.</exception>
      public WhenEachResult(Exception exception)
      {
          Result = default;
          Exception = exception ?? throw new ArgumentNullException(nameof(exception));
      }

      /// <summary>
      /// Rethrows <see cref="Exception"/> through <see cref="ExceptionDispatchInfo"/>, which keeps
      /// the original stack trace, when this result is faulted. Otherwise it does nothing.
      /// </summary>
      public void TryThrow()
      {
          if (!IsFaulted)
          {
              return;
          }

          ExceptionDispatchInfo.Capture(Exception).Throw();
      }

      /// <summary>
      /// Returns <see cref="Result"/> on success. On a fault it rethrows <see cref="Exception"/>
      /// and preserves its original stack trace.
      /// </summary>
      public T GetResult()
      {
          if (IsCompletedSuccessfully)
          {
              return Result;
          }

          ExceptionDispatchInfo.Capture(Exception).Throw();
          return Result; // not reached: Throw() always throws
      }

      /// <summary>
      /// Returns the result's text (empty for a null result) on success, or
      /// <c>Exception{message}</c> on a fault.
      /// </summary>
      public override string ToString()
          => IsCompletedSuccessfully
              ? (Result?.ToString() ?? "")
              : $"Exception{{{Exception.Message}}}";
  }
  /// <summary>Lifecycle stage of a WhenEach enumerator.</summary>
  internal enum WhenEachState : byte
  {
      /// <summary>No <c>MoveNextAsync</c> call yet; the source tasks are not being observed.</summary>
      NotRunning = 0,

      /// <summary>The source tasks are being observed and results are delivered as they finish.</summary>
      Running = 1,

      /// <summary>Every task has reported, or the enumerator was disposed.</summary>
      Completed = 2,
  }
  /// <summary>
  /// Async sequence behind <c>UniTask.WhenEach</c>. It yields one <see cref="WhenEachResult{T}"/>
  /// per source task, in the order the tasks complete.
  /// </summary>
  internal sealed class WhenEachEnumerable<T> : IUniTaskAsyncEnumerable<WhenEachResult<T>>
  {
      readonly IEnumerable<UniTask<T>> source;

      public WhenEachEnumerable(IEnumerable<UniTask<T>> source)
      {
          this.source = source;
      }

      /// <summary>
      /// Creates a fresh enumerator. The source tasks are not observed until the enumerator's
      /// first <c>MoveNextAsync</c> call.
      /// </summary>
      public IUniTaskAsyncEnumerator<WhenEachResult<T>> GetAsyncEnumerator(CancellationToken cancellationToken = default)
      {
          return new Enumerator(source, cancellationToken);
      }

      sealed class Enumerator : IUniTaskAsyncEnumerator<WhenEachResult<T>>
      {
          readonly IEnumerable<UniTask<T>> source;
          readonly CancellationToken cancellationToken;

          // Both are created on the first MoveNextAsync and stay null until then.
          Channel<WhenEachResult<T>> channel;
          IUniTaskAsyncEnumerator<WhenEachResult<T>> channelEnumerator;

          // Number of source tasks that have reported (success or fault); bumped via Interlocked.
          int completeCount;
          WhenEachState state;

          public Enumerator(IEnumerable<UniTask<T>> source, CancellationToken cancellationToken)
          {
              this.source = source;
              this.cancellationToken = cancellationToken;
          }

          public WhenEachResult<T> Current => channelEnumerator.Current;

          /// <summary>
          /// On the first call, creates the result channel and starts observing every source task.
          /// Every call then advances to the next completed task's result.
          /// </summary>
          /// <exception cref="OperationCanceledException">The enumerator's token is already cancelled.</exception>
          public UniTask<bool> MoveNextAsync()
          {
              cancellationToken.ThrowIfCancellationRequested();

              if (state == WhenEachState.NotRunning)
              {
                  state = WhenEachState.Running;
                  channel = Channel.CreateSingleConsumerUnbounded<WhenEachResult<T>>();
                  channelEnumerator = channel.Reader.ReadAllAsync().GetAsyncEnumerator(cancellationToken);

                  if (source is UniTask<T>[] array)
                  {
                      ConsumeAll(this, array, array.Length);
                  }
                  else
                  {
                      // The rented buffer may be released right after ConsumeAll: by then every
                      // element has already been passed to RunWhenEachTask.
                      using (var rentArray = ArrayPoolUtil.Materialize(source))
                      {
                          ConsumeAll(this, rentArray.Array, rentArray.Length);
                      }
                  }
              }

              return channelEnumerator.MoveNextAsync();
          }

          /// <summary>
          /// Starts observing the first <paramref name="length"/> tasks of <paramref name="array"/>.
          /// </summary>
          static void ConsumeAll(Enumerator self, UniTask<T>[] array, int length)
          {
              if (length == 0)
              {
                  // With no tasks, no RunWhenEachTask would ever reach the completion count.
                  // The channel would stay open and the reader would wait forever, so finish now.
                  self.state = WhenEachState.Completed;
                  self.channel.Writer.TryComplete();
                  return;
              }

              for (int i = 0; i < length; i++)
              {
                  RunWhenEachTask(self, array[i], length).Forget();
              }
          }

          /// <summary>
          /// Awaits one task and writes its success or failure to the channel. The last task to
          /// report (out of <paramref name="length"/>) completes the channel.
          /// </summary>
          static async UniTaskVoid RunWhenEachTask(Enumerator self, UniTask<T> task, int length)
          {
              try
              {
                  var result = await task;
                  self.channel.Writer.TryWrite(new WhenEachResult<T>(result));
              }
              catch (Exception ex)
              {
                  // Faults are delivered as items, not thrown to the consumer.
                  self.channel.Writer.TryWrite(new WhenEachResult<T>(ex));
              }

              if (Interlocked.Increment(ref self.completeCount) == length)
              {
                  self.state = WhenEachState.Completed;
                  self.channel.Writer.TryComplete();
              }
          }

          /// <summary>
          /// Disposes the channel reader. If tasks were started but have not all reported, it also
          /// completes the channel with an <see cref="OperationCanceledException"/>.
          /// </summary>
          public async UniTask DisposeAsync()
          {
              if (channelEnumerator != null)
              {
                  await channelEnumerator.DisposeAsync();
              }

              // channel is null when MoveNextAsync was never called; there is nothing to complete then.
              if (channel != null && state != WhenEachState.Completed)
              {
                  state = WhenEachState.Completed;
                  channel.Writer.TryComplete(new OperationCanceledException());
              }
          }
      }
  }
  156. }