Non puoi selezionare più di 25 argomenti Gli argomenti devono iniziare con una lettera o un numero, possono includere trattini ('-') e possono essere lunghi fino a 35 caratteri.
 
 
 

234 righe
8.1 KiB

  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Threading;
  5. using Cysharp.Threading.Tasks.Internal;
  6. namespace Cysharp.Threading.Tasks.Linq
  7. {
  public static partial class UniTaskAsyncEnumerable
  {
      /// <summary>
      /// Merges two async sequences into a single async sequence.
      /// </summary>
      /// <typeparam name="T">Element type of the sequences.</typeparam>
      /// <param name="first">First source sequence; must not be null.</param>
      /// <param name="second">Second source sequence; must not be null.</param>
      /// <returns>A new <c>Merge&lt;T&gt;</c> instance wrapping both sources.</returns>
      /// <remarks>
      /// Null arguments are rejected through <c>Error.ThrowArgumentNullException</c>
      /// (declared elsewhere; expected to throw <see cref="ArgumentNullException"/>).
      /// </remarks>
      public static IUniTaskAsyncEnumerable<T> Merge<T>(this IUniTaskAsyncEnumerable<T> first, IUniTaskAsyncEnumerable<T> second)
      {
          Error.ThrowArgumentNullException(first, nameof(first));
          Error.ThrowArgumentNullException(second, nameof(second));

          return new Merge<T>(new[] { first, second });
      }

      /// <summary>
      /// Merges three async sequences into a single async sequence.
      /// </summary>
      /// <typeparam name="T">Element type of the sequences.</typeparam>
      /// <param name="first">First source sequence; must not be null.</param>
      /// <param name="second">Second source sequence; must not be null.</param>
      /// <param name="third">Third source sequence; must not be null.</param>
      /// <returns>A new <c>Merge&lt;T&gt;</c> instance wrapping the three sources.</returns>
      public static IUniTaskAsyncEnumerable<T> Merge<T>(this IUniTaskAsyncEnumerable<T> first, IUniTaskAsyncEnumerable<T> second, IUniTaskAsyncEnumerable<T> third)
      {
          Error.ThrowArgumentNullException(first, nameof(first));
          Error.ThrowArgumentNullException(second, nameof(second));
          Error.ThrowArgumentNullException(third, nameof(third));

          return new Merge<T>(new[] { first, second, third });
      }

      /// <summary>
      /// Merges every async sequence contained in <paramref name="sources"/>.
      /// </summary>
      /// <typeparam name="T">Element type of the sequences.</typeparam>
      /// <param name="sources">
      /// The sequences to merge; must not be null. When it already is an array, that array
      /// is handed to <c>Merge&lt;T&gt;</c> as-is (no copy); otherwise it is materialized
      /// once with <c>ToArray()</c>.
      /// </param>
      /// <returns>A new <c>Merge&lt;T&gt;</c> instance wrapping the sources.</returns>
      /// <remarks>
      /// The <c>Merge&lt;T&gt;</c> constructor rejects an empty array through
      /// <c>Error.ThrowArgumentException</c>.
      /// </remarks>
      public static IUniTaskAsyncEnumerable<T> Merge<T>(this IEnumerable<IUniTaskAsyncEnumerable<T>> sources)
      {
          // Same guard as the fixed-arity overloads, so a null collection is reported
          // against the right parameter name instead of failing inside ToArray().
          Error.ThrowArgumentNullException(sources, nameof(sources));

          return sources is IUniTaskAsyncEnumerable<T>[] array
              ? new Merge<T>(array)
              : new Merge<T>(sources.ToArray());
      }

      /// <summary>
      /// Merges the given async sequences.
      /// </summary>
      /// <typeparam name="T">Element type of the sequences.</typeparam>
      /// <param name="sources">
      /// The sequences to merge; must not be null. The array is handed to
      /// <c>Merge&lt;T&gt;</c> as-is (no copy).
      /// </param>
      /// <returns>A new <c>Merge&lt;T&gt;</c> instance wrapping the sources.</returns>
      /// <remarks>
      /// The <c>Merge&lt;T&gt;</c> constructor rejects an empty array through
      /// <c>Error.ThrowArgumentException</c>.
      /// </remarks>
      public static IUniTaskAsyncEnumerable<T> Merge<T>(params IUniTaskAsyncEnumerable<T>[] sources)
      {
          // Without this guard a null array fails later with a NullReferenceException
          // when the Merge<T> constructor reads sources.Length.
          Error.ThrowArgumentNullException(sources, nameof(sources));

          return new Merge<T>(sources);
      }
  }
  /// <summary>
  /// Async sequence that interleaves the items of several source sequences: each call to
  /// <c>MoveNextAsync</c> on its enumerator advances every idle source and completes with
  /// whichever result arrives first; results that arrive while no call is waiting are queued.
  /// </summary>
  /// <typeparam name="T">Element type of the merged sequences.</typeparam>
  internal sealed class Merge<T> : IUniTaskAsyncEnumerable<T>
  {
      readonly IUniTaskAsyncEnumerable<T>[] sources;

      /// <summary>
      /// Stores the sources to merge. The array is kept by reference (no copy).
      /// </summary>
      /// <param name="sources">Source sequences; an empty array is rejected.</param>
      public Merge(IUniTaskAsyncEnumerable<T>[] sources)
      {
          if (sources.Length <= 0)
          {
              Error.ThrowArgumentException("No source async enumerable to merge");
          }

          this.sources = sources;
      }

      /// <summary>
      /// Creates a merging enumerator; every source enumerator is created eagerly with the
      /// same <paramref name="cancellationToken"/>.
      /// </summary>
      public IUniTaskAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
          => new _Merge(sources, cancellationToken);

      /// <summary>Per-source progress as tracked by the enumerator.</summary>
      enum MergeSourceState
      {
          /// <summary>No MoveNextAsync is in flight on the source; it can be advanced.</summary>
          Pending,

          /// <summary>A MoveNextAsync call on the source has been started and not yet observed.</summary>
          Running,

          /// <summary>The source reported that it has no more items.</summary>
          Completed,
      }

      sealed class _Merge : MoveNextSource, IUniTaskAsyncEnumerator<T>
      {
          // Cached delegate so registering a continuation does not allocate a new one each time.
          static readonly Action<object> GetResultAtAction = GetResultAt;

          // Number of sources. Loops are bounded by this value rather than by the array
          // lengths (presumably because a rented array may be longer than requested).
          readonly int length;
          readonly IUniTaskAsyncEnumerator<T>[] enumerators;

          // Guarded by lock (states); the same lock also guards queuedResult.
          readonly MergeSourceState[] states;

          // (value, exception, hasNext) results that arrived after the current
          // MoveNextAsync call had already been completed by another source.
          readonly Queue<(T, Exception, bool)> queuedResult = new Queue<(T, Exception, bool)>();
          readonly CancellationToken cancellationToken;

          // 0 while the current MoveNextAsync call still waits for a result,
          // 1 once some result has claimed it. Only changed through Interlocked.
          int moveNextCompleted;

          // 0 until DisposeAsync has run once.
          int disposed;

          public T Current { get; private set; }

          public _Merge(IUniTaskAsyncEnumerable<T>[] sources, CancellationToken cancellationToken)
          {
              this.cancellationToken = cancellationToken;
              length = sources.Length;
              states = ArrayPool<MergeSourceState>.Shared.Rent(length);
              enumerators = ArrayPool<IUniTaskAsyncEnumerator<T>>.Shared.Rent(length);
              for (var i = 0; i < length; i++)
              {
                  enumerators[i] = sources[i].GetAsyncEnumerator(cancellationToken);
                  // Rented arrays may hold stale data, so every slot in use is initialized.
                  states[i] = MergeSourceState.Pending;
              }
          }

          /// <summary>
          /// Completes with a queued result when one is available; otherwise starts
          /// MoveNextAsync on every source that is currently <c>Pending</c> and completes
          /// when the first of the running sources reports a result.
          /// </summary>
          /// <returns>
          /// A task backed by this instance's completion source (inherited from
          /// <c>MoveNextSource</c>, declared elsewhere).
          /// </returns>
          public UniTask<bool> MoveNextAsync()
          {
              cancellationToken.ThrowIfCancellationRequested();
              completionSource.Reset();
              Interlocked.Exchange(ref moveNextCompleted, 0);

              // Serve an already-buffered result first; the CompareExchange makes sure a
              // source finishing right now cannot complete the same call as well.
              if (HasQueuedResult() && Interlocked.CompareExchange(ref moveNextCompleted, 1, 0) == 0)
              {
                  (T, Exception, bool) value;
                  lock (states)
                  {
                      value = queuedResult.Dequeue();
                  }

                  var resultValue = value.Item1;
                  var exception = value.Item2;
                  var hasNext = value.Item3;
                  if (exception != null)
                  {
                      completionSource.TrySetException(exception);
                  }
                  else
                  {
                      Current = resultValue;
                      completionSource.TrySetResult(hasNext);
                  }

                  return new UniTask<bool>(this, completionSource.Version);
              }

              for (var i = 0; i < length; i++)
              {
                  lock (states)
                  {
                      // Only idle sources are advanced; Running ones already have a call
                      // in flight and Completed ones are finished.
                      if (states[i] != MergeSourceState.Pending)
                      {
                          continue;
                      }

                      states[i] = MergeSourceState.Running;
                  }

                  var awaiter = enumerators[i].MoveNextAsync().GetAwaiter();
                  if (awaiter.IsCompleted)
                  {
                      GetResultAt(i, awaiter);
                  }
                  else
                  {
                      awaiter.SourceOnCompleted(GetResultAtAction, StateTuple.Create(this, i, awaiter));
                  }
              }

              return new UniTask<bool>(this, completionSource.Version);
          }

          /// <summary>
          /// Disposes every source enumerator and returns the rented arrays to their pools.
          /// </summary>
          /// <remarks>
          /// All sources are disposed even when one of them throws; the first failure is
          /// rethrown after the arrays have been returned. Calls after the first are no-ops,
          /// so the arrays are never returned to the pool twice.
          /// </remarks>
          public async UniTask DisposeAsync()
          {
              if (Interlocked.Exchange(ref disposed, 1) != 0)
              {
                  return;
              }

              System.Runtime.ExceptionServices.ExceptionDispatchInfo firstError = null;
              for (var i = 0; i < length; i++)
              {
                  try
                  {
                      await enumerators[i].DisposeAsync();
                  }
                  catch (Exception ex)
                  {
                      // Keep going so the remaining sources are still disposed; remember
                      // only the first failure, with its original stack trace.
                      if (firstError == null)
                      {
                          firstError = System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(ex);
                      }
                  }
              }

              // NOTE(review): a source continuation that is still in flight would observe
              // the cleared arrays after this point - confirm callers only dispose once
              // enumeration has stopped.
              ArrayPool<MergeSourceState>.Shared.Return(states, true);
              ArrayPool<IUniTaskAsyncEnumerator<T>>.Shared.Return(enumerators, true);

              firstError?.Throw();
          }

          // Continuation entry point: unpacks the state tuple and forwards to the instance
          // method. The tuple is disposed once the call returns.
          static void GetResultAt(object state)
          {
              using (var tuple = (StateTuple<_Merge, int, UniTask<bool>.Awaiter>)state)
              {
                  tuple.Item1.GetResultAt(tuple.Item2, tuple.Item3);
              }
          }

          // Observes the outcome of MoveNextAsync on source `index` and either completes
          // the waiting MoveNextAsync call or queues the outcome for a later call.
          void GetResultAt(int index, UniTask<bool>.Awaiter awaiter)
          {
              bool hasNext;
              try
              {
                  hasNext = awaiter.GetResult();
              }
              catch (Exception ex)
              {
                  if (Interlocked.CompareExchange(ref moveNextCompleted, 1, 0) == 0)
                  {
                      completionSource.TrySetException(ex);
                  }
                  else
                  {
                      lock (states)
                      {
                          queuedResult.Enqueue((default, ex, default));
                      }
                  }

                  return;
              }

              var completedAll = false;
              T current = default;
              lock (states)
              {
                  if (hasNext)
                  {
                      // Capture the value BEFORE the source becomes Pending again: once it
                      // is Pending, a concurrent MoveNextAsync may advance this source and
                      // change what its Current returns.
                      current = enumerators[index].Current;
                      states[index] = MergeSourceState.Pending;
                  }
                  else
                  {
                      states[index] = MergeSourceState.Completed;
                      completedAll = IsCompletedAll();
                      if (completedAll)
                      {
                          // The final "false" result carries along whatever the last
                          // finishing source still exposes as Current.
                          current = enumerators[index].Current;
                      }
                  }
              }

              // A source that ended while others are still active produces no result.
              if (!hasNext && !completedAll)
              {
                  return;
              }

              if (Interlocked.CompareExchange(ref moveNextCompleted, 1, 0) == 0)
              {
                  Current = current;
                  completionSource.TrySetResult(!completedAll);
              }
              else
              {
                  lock (states)
                  {
                      queuedResult.Enqueue((current, null, !completedAll));
                  }
              }
          }

          // True when at least one result is buffered for a later MoveNextAsync call.
          bool HasQueuedResult()
          {
              lock (states)
              {
                  return queuedResult.Count > 0;
              }
          }

          // True when every source has reached the Completed state. Takes the lock itself;
          // C# monitors are reentrant, so calling it while already holding the lock is safe.
          bool IsCompletedAll()
          {
              lock (states)
              {
                  for (var i = 0; i < length; i++)
                  {
                      if (states[i] != MergeSourceState.Completed)
                      {
                          return false;
                      }
                  }
              }

              return true;
          }
      }
  }
  209. }