123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284 |
- /*
- Copyright (c) 2020 Sogou, Inc.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
- #ifndef __RPC_SERVER_H__
- #define __RPC_SERVER_H__
- #include <map>
- #include <string>
- #include <errno.h>
- #include <workflow/WFServer.h>
- #include <workflow/WFHttpServer.h>
- #include "rpc_types.h"
- #include "rpc_service.h"
- #include "rpc_options.h"
- #include "rpc_trace_module.h"
- #include "rpc_metrics_module.h"
- namespace srpc
- {
- template<class RPCTYPE>
- class RPCServer : public WFServer<typename RPCTYPE::REQ,
- typename RPCTYPE::RESP>
- {
- public:
- using REQTYPE = typename RPCTYPE::REQ;
- using RESPTYPE = typename RPCTYPE::RESP;
- using TASK = RPCServerTask<REQTYPE, RESPTYPE>;
- using SERIES = typename TASK::RPCSeries;
- protected:
- using NETWORKTASK = WFNetworkTask<REQTYPE, RESPTYPE>;
- public:
- RPCServer();
- RPCServer(const struct RPCServerParams *params);
- int add_service(RPCService *service);
- const RPCService* find_service(const std::string& name) const;
- void add_filter(RPCFilter *filter);
- protected:
- RPCServer(const struct RPCServerParams *params,
- std::function<void (NETWORKTASK *)>&& process);
- CommSession *new_session(long long seq, CommConnection *conn) override;
- void server_process(NETWORKTASK *task) const;
- private:
- std::mutex mutex;
- std::map<std::string, RPCService *> service_map;
- RPCModule *modules[SRPC_MODULE_MAX] = { NULL };
- };
- ////////
- // inl
- template<class RPCTYPE>
- inline RPCServer<RPCTYPE>::RPCServer():
- WFServer<REQTYPE, RESPTYPE>(&RPC_SERVER_PARAMS_DEFAULT,
- std::bind(&RPCServer::server_process,
- this, std::placeholders::_1))
- {}
- template<class RPCTYPE>
- inline RPCServer<RPCTYPE>::RPCServer(const struct RPCServerParams *params):
- WFServer<REQTYPE, RESPTYPE>(params,
- std::bind(&RPCServer::server_process,
- this, std::placeholders::_1))
- {}
- template<class RPCTYPE>
- inline RPCServer<RPCTYPE>::RPCServer(const struct RPCServerParams *params,
- std::function<void (NETWORKTASK *)>&& process):
- WFServer<REQTYPE, RESPTYPE>(¶ms, std::move(process))
- {}
- template<class RPCTYPE>
- inline int RPCServer<RPCTYPE>::add_service(RPCService* service)
- {
- const auto it = this->service_map.emplace(service->get_name(), service);
- if (!it.second)
- {
- errno = EEXIST;
- return -1;
- }
- return 0;
- }
- template<>
- inline int RPCServer<RPCTYPESRPC>::add_service(RPCService* service)
- {
- const std::string &name = service->get_name();
- const auto it = this->service_map.emplace(name, service);
- if (!it.second)
- {
- errno = EEXIST;
- return -1;
- }
- auto pos = name.find_last_of('.');
- if (pos != std::string::npos)
- this->service_map.emplace(name.substr(pos + 1), service);
- return 0;
- }
- template<>
- inline int RPCServer<RPCTYPESRPCHttp>::add_service(RPCService* service)
- {
- const std::string &name = service->get_name();
- const auto it = this->service_map.emplace(name, service);
- if (!it.second)
- {
- errno = EEXIST;
- return -1;
- }
- auto pos = name.find_last_of('.');
- if (pos != std::string::npos)
- this->service_map.emplace(name.substr(pos + 1), service);
- return 0;
- }
- template<class RPCTYPE>
- void RPCServer<RPCTYPE>::add_filter(RPCFilter *filter)
- {
- using CLIENT_TASK = RPCClientTask<typename RPCTYPE::REQ,
- typename RPCTYPE::RESP>;
- using SERVER_TASK = RPCServerTask<typename RPCTYPE::REQ,
- typename RPCTYPE::RESP>;
- int type = filter->get_module_type();
- this->mutex.lock();
- if (type < SRPC_MODULE_MAX && type >= 0)
- {
- RPCModule *module = this->modules[type];
- if (!module)
- {
- switch (type)
- {
- case RPCModuleTypeTrace:
- module = new RPCTraceModule<SERVER_TASK, CLIENT_TASK>();
- break;
- case RPCModuleTypeMetrics:
- module = new RPCMetricsModule<SERVER_TASK, CLIENT_TASK>();
- break;
- default:
- break;
- }
- this->modules[type] = module;
- }
- if (module)
- module->add_filter(filter);
- }
- this->mutex.unlock();
- return;
- }
- template<class RPCTYPE>
- inline const RPCService *
- RPCServer<RPCTYPE>::find_service(const std::string& name) const
- {
- const auto it = this->service_map.find(name);
- if (it != this->service_map.cend())
- return it->second;
- return NULL;
- }
- template<class RPCTYPE>
- inline CommSession *RPCServer<RPCTYPE>::new_session(long long seq,
- CommConnection *conn)
- {
- /* TODO: Change to a factory function. */
- std::list<RPCModule *> module;
- for (int i = 0; i < SRPC_MODULE_MAX; i++)
- {
- if (this->modules[i])
- module.push_back(this->modules[i]);
- }
- auto *task = new TASK(this, this->process, std::move(module));
- task->set_keep_alive(this->params.keep_alive_timeout);
- task->get_req()->set_size_limit(this->params.request_size_limit);
- return task;
- }
- template<class RPCTYPE>
- void RPCServer<RPCTYPE>::server_process(NETWORKTASK *task) const
- {
- auto *req = task->get_req();
- auto *resp = task->get_resp();
- int status_code;
- if (!req->deserialize_meta())
- status_code = RPCStatusMetaError;
- else
- {
- auto *server_task = static_cast<TASK *>(task);
- RPCModuleData *task_data = server_task->mutable_module_data();
- req->get_meta_module_data(*task_data);
- RPCTYPE::server_reply_init(req, resp);
- auto *service = this->find_service(req->get_service_name());
- if (!service)
- status_code = RPCStatusServiceNotFound;
- else
- {
- auto *rpc = service->find_method(req->get_method_name());
- if (!rpc)
- status_code = RPCStatusMethodNotFound;
- else
- {
- for (auto *module : this->modules)
- {
- if (module)
- module->server_task_begin(server_task, *task_data);
- }
- status_code = req->decompress();
- if (status_code == RPCStatusOK)
- status_code = (*rpc)(server_task->worker);
- }
- }
- SERIES *series = static_cast<SERIES *>(series_of(task));
- series->set_module_data(task_data);
- }
- resp->set_status_code(status_code);
- }
- template<>
- inline const RPCService *
- RPCServer<RPCTYPEThrift>::find_service(const std::string& name) const
- {
- if (this->service_map.empty())
- return NULL;
- return this->service_map.cbegin()->second;
- }
- template<>
- inline const RPCService *
- RPCServer<RPCTYPEThriftHttp>::find_service(const std::string& name) const
- {
- if (this->service_map.empty())
- return NULL;
- return this->service_map.cbegin()->second;
- }
- } // namespace srpc
- #endif
|