/**
 * Tencent is pleased to support the open source community by making Tars available.
 *
 * Copyright (C) 2016THL A29 Limited, a Tencent company. All rights reserved.
 *
 * Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
 * in compliance with the License. You may obtain a copy of the License at
 *
 * https://opensource.org/licenses/BSD-3-Clause
 *
 * Unless required by applicable law or agreed to in writing, software distributed
 * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 * CONDITIONS OF ANY KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations under the License.
 */

#ifndef _TARS_WHEN_ALL_H_
#define _TARS_WHEN_ALL_H_

// NOTE(review): in this copy of the file every angle-bracket template
// parameter/argument list appears to have been stripped. Examples are the two
// bare #include lines below and "template struct ..." with no parameter list.
// As written, the code will not compile. The comments in this file describe
// the evident intent; restore the lists from upstream Tars before building.
#include
#include
#include "promise/template_helper.h"
#include "promise/promise.h"

namespace tars
{

// Implementation details of WhenAll ("wa" = when all).
namespace wa
{

// Maps a list of future types to the type used to store all of their results.
// Presumably a std::tuple of the futures. TODO confirm once the stripped
// template arguments are restored.
template struct MakeFuturesStorageImpl
{
    using Type = std::tuple;
};

// Storage type for a pack of futures. Each type is passed through std::decay
// first, so references and cv-qualifiers do not leak into the stored tuple.
template using MakeFuturesStorage = typename MakeFuturesStorageImpl<
    typename std::decay::type...>::Type;

// Bundles the types WhenAll needs for a given pack of futures:
//  - StorgeType : the storage for all inputs (name is a misspelling of
//                 "Storage"; kept as-is because other code refers to it),
//  - FutureType : the Future returned to the caller,
//  - PromiseType: the Promise that WhenAll fulfils.
template struct FutureList
{
    using StorgeType = MakeFuturesStorage;
    using FutureType = Future;
    using PromiseType = Promise;
};

// State shared by the WhenAll callback objects:
//  - the promise that will be fulfilled,
//  - the storage that collects each completed future (m_futures),
//  - an atomic countdown of futures still outstanding (m_waitCount).
// Copy construction, move construction and copy assignment are deleted,
// because the object is shared by every registered continuation and must not
// be duplicated.
template class ParallelCallbackBase
{
public:
    virtual ~ParallelCallbackBase() {}

protected:
    using PromiseAll = Promise;

    ParallelCallbackBase(const ParallelCallbackBase&) = delete;
    ParallelCallbackBase(ParallelCallbackBase&&) = delete;
    ParallelCallbackBase& operator = (const ParallelCallbackBase&) = delete;

    // p          : promise fulfilled once every input future has completed.
    // totalCount : number of futures to wait for; seeds m_waitCount.
    ParallelCallbackBase(const Promise& p, const int totalCount)
        : m_promise(p)
        , m_futures()
        , m_waitCount(totalCount)
    {
    }

    PromiseAll m_promise;
    T m_futures;
    std::atomic m_waitCount;
};

// Collector for a fixed, compile-time list of futures (the variadic WhenAll).
// Inherits enable_shared_from_this so each registered continuation can hold a
// shared_ptr to this object (see registerCallback).
template class ParallelCallback final
    : public std::enable_shared_from_this >
    , public ParallelCallbackBase< typename FutureList::StorgeType>
{
public:
    using StorgeType = typename FutureList::StorgeType;
    using PromiseAll = typename ParallelCallbackBase::PromiseAll;

    // Waits for sizeof...(Futures) completions before fulfilling p.
    ParallelCallback(const PromiseAll& p)
        : ParallelCallbackBase(p, sizeof...(Futures))
    {
    }

    ~ParallelCallback() override {}

    // No-op overload. Presumably the terminator of the recursive registration
    // below.
    void registerCallback()
    {
        // do nothing here.
    }

    // Attaches onFuture (for this future's slot) as a continuation of f.
    // A shared_ptr to this object (shared_from_this) is bound into the
    // continuation, which presumably keeps the collector alive until every
    // future has completed. Confirm against Bind/then semantics.
    template void registerCallback(T&& f)
    {
        f.then(Bind(&ParallelCallback::template onFuture, this->shared_from_this()));
    }

    // Registers the first future, then recurses on the rest. Presumably each
    // call advances the tuple slot index (the stripped template arguments
    // carried it).
    template void registerCallback(T&& f, Types&&... fs)
    {
        registerCallback(std::forward(f));
        registerCallback(std::forward(fs)...);
    }

    // Entry point used by WhenAll: starts registration at slot 0.
    template void registerCallback(T&& f, Types&&... fs)
    {
        registerCallback<0>(std::forward(f), std::forward(fs)...);
    }

    // Continuation invoked when one input future completes.
    // Stores the completed future in its tuple slot, then decrements the
    // countdown. fetch_sub returns the value *before* the decrement, so only
    // the call that observes 1 (the last completion, assuming each registered
    // future calls this exactly once) goes on to fulfil the promise.
    // fetch_sub uses the default seq_cst ordering, which also makes the other
    // callers' slot writes visible to that last caller.
    template void onFuture(const typename std::tuple_element::type& f)
    {
        std::get(this->m_futures) = f;
        int waitCount = this->m_waitCount.fetch_sub(1);
        if (waitCount > 1)
            return;

        // If delivering the value throws, report the exception through the
        // same promise instead of letting it escape the continuation.
        try {
            this->m_promise.setValue(this->m_futures);
        } catch (...) {
            this->m_promise.setException(currentException());
        }
    }
};

// Specialization for a runtime-sized std::vector of futures (the vector
// overload of WhenAll). Completed futures are stored by index in a vector
// sized up front.
template class ParallelCallback > final
    : public std::enable_shared_from_this > >
    , public ParallelCallbackBase >
{
public:
    using PromiseAll = typename ParallelCallbackBase >::PromiseAll;

    // p     : promise fulfilled once all `count` futures have completed.
    // count : number of futures; also the size of the result vector.
    // NOTE(review): the base constructor takes the count as `const int`, so a
    // size_t count larger than INT_MAX would be narrowed. Confirm that such
    // sizes cannot occur.
    ParallelCallback(const PromiseAll& p, const size_t count)
        : ParallelCallbackBase >(p, count)
    {
        this->m_futures.resize(count);
    }

    ~ParallelCallback() override {}

    // Continuation for the n-th input future.
    // Stores f at index n. The call that observes the countdown's previous
    // value 1 (the last completion) then fulfils the promise with the whole
    // vector, reporting any exception from setValue through setException.
    void onFuture(const size_t n, const Futures& f)
    {
        this->m_futures[n] = f;
        int waitCount = this->m_waitCount.fetch_sub(1);
        if (waitCount > 1)
            return;

        try {
            this->m_promise.setValue(this->m_futures);
        } catch (...) {
            this->m_promise.setException(currentException());
        }
    }
};

} // end namespace wa(when all)

// Returns a future that completes after every one of the given futures has
// completed. Its value holds all of the input futures (see FutureList).
// At least two futures are required (enforced by static_assert).
// The collector `c` is only a local here. After registration it is referenced
// by the continuations attached to each input future.
// NOTE(review): the static_assert message says "features" where "futures" is
// meant. It is left unchanged here because it is program text.
template typename wa::FutureList::FutureType WhenAll(Futures... f)
{
    static_assert((sizeof...(Futures) > 1), "TarsWhenAll need at least two features.");

    using PromiseType = typename wa::FutureList::PromiseType;
    using WhenAllCallback = wa::ParallelCallback;

    PromiseType p;
    std::shared_ptr c = std::make_shared(p);
    c->registerCallback(f...);

    return p.getFuture();
}

// Vector overload: returns a future whose value is the vector of input
// futures, available once all of them have completed.
// An empty input yields an immediately-fulfilled future holding an empty
// vector. Otherwise a collector sized to f.size() is created and callbacks are
// attached per element; the rest of this function is not visible here.
// The element type must satisfy IsFutureType (enforced by static_assert).
// NOTE(review): f is taken by non-const reference. Check whether callers rely
// on this, or whether const& would do.
template Future > WhenAll(std::vector& f)
{
    static_assert(IsFutureType::value, "T is not a Future type.");

    using PromiseType = Promise >;
    using WhenAllCallback = wa::ParallelCallback >;

    PromiseType p;
    if (f.empty())
        p.setValue(std::vector());
    else {
        std::shared_ptr c = std::make_shared(p, f.size());
        for (size_t i=0; i