when_all.h 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. /**
  2. * Tencent is pleased to support the open source community by making Tars available.
  3. *
  4. * Copyright (C) 2016THL A29 Limited, a Tencent company. All rights reserved.
  5. *
  6. * Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  7. * in compliance with the License. You may obtain a copy of the License at
  8. *
  9. * https://opensource.org/licenses/BSD-3-Clause
  10. *
  11. * Unless required by applicable law or agreed to in writing, software distributed
  12. * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  13. * CONDITIONS OF ANY KIND, either express or implied. See the License for the
  14. * specific language governing permissions and limitations under the License.
  15. */
  16. #ifndef _TARS_WHEN_ALL_H_
  17. #define _TARS_WHEN_ALL_H_
  18. #include <vector>
  19. #include <atomic>
  20. #include "promise/template_helper.h"
  21. #include "promise/promise.h"
  22. namespace tars {
  23. namespace wa {
  24. template <typename... Futures>
  25. struct MakeFuturesStorageImpl {
  26. using Type = std::tuple<Futures...>;
  27. };
  28. template <typename... Futures>
  29. using MakeFuturesStorage = typename MakeFuturesStorageImpl<
  30. typename std::decay<Futures>::type...>::Type;
  31. template <typename... Futures>
  32. struct FutureList {
  33. using StorgeType = MakeFuturesStorage<Futures...>;
  34. using FutureType = Future<StorgeType>;
  35. using PromiseType = Promise<StorgeType>;
  36. };
  37. template <typename T>
  38. class ParallelCallbackBase {
  39. public:
  40. virtual ~ParallelCallbackBase() {}
  41. protected:
  42. using PromiseAll = Promise<T>;
  43. ParallelCallbackBase(const ParallelCallbackBase&) = delete;
  44. ParallelCallbackBase(ParallelCallbackBase&&) = delete;
  45. ParallelCallbackBase& operator = (const ParallelCallbackBase&) = delete;
  46. ParallelCallbackBase(const Promise<T>& p, const int totalCount)
  47. : m_promise(p)
  48. , m_futures()
  49. , m_waitCount(totalCount)
  50. {
  51. }
  52. PromiseAll m_promise;
  53. T m_futures;
  54. std::atomic<int> m_waitCount;
  55. };
  56. template <typename... Futures>
  57. class ParallelCallback final
  58. : public std::enable_shared_from_this<ParallelCallback<Futures...> >
  59. , public ParallelCallbackBase<
  60. typename FutureList<Futures...>::StorgeType> {
  61. public:
  62. using StorgeType = typename FutureList<Futures...>::StorgeType;
  63. using PromiseAll = typename ParallelCallbackBase<StorgeType>::PromiseAll;
  64. ParallelCallback(const PromiseAll& p)
  65. : ParallelCallbackBase<StorgeType>(p, sizeof...(Futures))
  66. {
  67. }
  68. ~ParallelCallback() override {}
  69. void registerCallback()
  70. {
  71. // do nothing here.
  72. }
  73. template <size_t N, typename T>
  74. void registerCallback(T&& f)
  75. {
  76. f.then(Bind(&ParallelCallback::template onFuture<N>,
  77. this->shared_from_this()));
  78. }
  79. template <size_t N, typename T, typename... Types>
  80. void registerCallback(T&& f, Types&&... fs)
  81. {
  82. registerCallback<N>(std::forward<T>(f));
  83. registerCallback<N+1>(std::forward<Types>(fs)...);
  84. }
  85. template <typename T, typename... Types>
  86. void registerCallback(T&& f, Types&&... fs)
  87. {
  88. registerCallback<0>(std::forward<T>(f), std::forward<Types>(fs)...);
  89. }
  90. template <size_t N>
  91. void onFuture(const typename std::tuple_element<N, StorgeType>::type& f)
  92. {
  93. std::get<N>(this->m_futures) = f;
  94. int waitCount = this->m_waitCount.fetch_sub(1);
  95. if (waitCount > 1)
  96. return;
  97. try {
  98. this->m_promise.setValue(this->m_futures);
  99. } catch (...) {
  100. this->m_promise.setException(currentException());
  101. }
  102. }
  103. };
  104. template <typename Futures>
  105. class ParallelCallback<std::vector<Futures> > final
  106. : public std::enable_shared_from_this<ParallelCallback<std::vector<Futures> > >
  107. , public ParallelCallbackBase<std::vector<Futures> > {
  108. public:
  109. using PromiseAll = typename ParallelCallbackBase<std::vector<Futures> >::PromiseAll;
  110. ParallelCallback(const PromiseAll& p, const size_t count)
  111. : ParallelCallbackBase<std::vector<Futures> >(p, count)
  112. {
  113. this->m_futures.resize(count);
  114. }
  115. ~ParallelCallback() override {}
  116. void onFuture(const size_t n, const Futures& f)
  117. {
  118. this->m_futures[n] = f;
  119. int waitCount = this->m_waitCount.fetch_sub(1);
  120. if (waitCount > 1)
  121. return;
  122. try {
  123. this->m_promise.setValue(this->m_futures);
  124. } catch (...) {
  125. this->m_promise.setException(currentException());
  126. }
  127. }
  128. };
  129. } // end namespace wa(when all)
  130. template <typename... Futures>
  131. typename wa::FutureList<Futures...>::FutureType WhenAll(Futures... f)
  132. {
  133. static_assert((sizeof...(Futures) > 1), "TarsWhenAll need at least two features.");
  134. using PromiseType = typename wa::FutureList<Futures...>::PromiseType;
  135. using WhenAllCallback = wa::ParallelCallback<Futures...>;
  136. PromiseType p;
  137. std::shared_ptr<WhenAllCallback> c = std::make_shared<WhenAllCallback>(p);
  138. c->registerCallback(f...);
  139. return p.getFuture();
  140. }
  141. template <typename T>
  142. Future<std::vector<T> > WhenAll(std::vector<T>& f)
  143. {
  144. static_assert(IsFutureType<T>::value, "T is not a Future type.");
  145. using PromiseType = Promise<std::vector<T> >;
  146. using WhenAllCallback = wa::ParallelCallback<std::vector<T> >;
  147. PromiseType p;
  148. if (f.empty())
  149. p.setValue(std::vector<T>());
  150. else {
  151. std::shared_ptr<WhenAllCallback> c = std::make_shared<WhenAllCallback>(p, f.size());
  152. for (size_t i=0; i<f.size(); ++i)
  153. f[i].then(Bind(&WhenAllCallback::onFuture, c, i));
  154. }
  155. return p.getFuture();
  156. }
  157. } // end namespace tars
  158. #endif