rpc_client.h 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. /*
  2. Copyright (c) 2020 Sogou, Inc.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. #ifndef __RPC_CLIENT_H__
  14. #define __RPC_CLIENT_H__
  #include <mutex>
  #include "rpc_types.h"
  #include "rpc_context.h"
  #include "rpc_options.h"
  #include "rpc_global.h"
  #include "rpc_trace_module.h"
  #include "rpc_metrics_module.h"
  21. namespace srpc
  22. {
  23. template<class RPCTYPE>
  24. class RPCClient
  25. {
  26. public:
  27. using TASK = RPCClientTask<typename RPCTYPE::REQ, typename RPCTYPE::RESP>;
  28. protected:
  29. using COMPLEXTASK = WFComplexClientTask<typename RPCTYPE::REQ,
  30. typename RPCTYPE::RESP>;
  31. public:
  32. RPCClient(const std::string& service_name);
  33. virtual ~RPCClient() { };
  34. const RPCTaskParams *get_task_params() const;
  35. const std::string& get_service_name() const;
  36. void task_init(COMPLEXTASK *task) const;
  37. void set_keep_alive(int timeout);
  38. void set_watch_timeout(int timeout);
  39. void add_filter(RPCFilter *filter);
  40. protected:
  41. template<class OUTPUT>
  42. TASK *create_rpc_client_task(const std::string& method_name,
  43. std::function<void (OUTPUT *, RPCContext *)>&& done)
  44. {
  45. std::list<RPCModule *> module;
  46. for (int i = 0; i < SRPC_MODULE_MAX; i++)
  47. {
  48. if (this->modules[i])
  49. module.push_back(this->modules[i]);
  50. }
  51. auto *task = new TASK(this->service_name,
  52. method_name,
  53. &this->params.task_params,
  54. std::move(module),
  55. [done](int status_code, RPCWorker& worker) -> int {
  56. return ClientRPCDoneImpl(status_code, worker, done);
  57. });
  58. this->task_init(task);
  59. return task;
  60. }
  61. void init(const RPCClientParams *params);
  62. std::string service_name;
  63. private:
  64. void __task_init(COMPLEXTASK *task) const;
  65. protected:
  66. RPCClientParams params;
  67. ParsedURI uri;
  68. private:
  69. struct sockaddr_storage ss;
  70. socklen_t ss_len;
  71. bool has_addr_info;
  72. std::mutex mutex;
  73. RPCModule *modules[SRPC_MODULE_MAX] = { 0 };
  74. };
  75. ////////
  // Inline member definitions of RPCClient
  77. template<class RPCTYPE>
  78. inline RPCClient<RPCTYPE>::RPCClient(const std::string& service_name):
  79. params(RPC_CLIENT_PARAMS_DEFAULT),
  80. has_addr_info(false)
  81. {
  82. SRPCGlobal::get_instance();
  83. this->service_name = service_name;
  84. }
  85. template<class RPCTYPE>
  86. inline const RPCTaskParams *RPCClient<RPCTYPE>::get_task_params() const
  87. {
  88. return &this->params.task_params;
  89. }
  90. template<class RPCTYPE>
  91. inline const std::string& RPCClient<RPCTYPE>::get_service_name() const
  92. {
  93. return this->service_name;
  94. }
  95. template<class RPCTYPE>
  96. inline void RPCClient<RPCTYPE>::set_keep_alive(int timeout)
  97. {
  98. this->params.task_params.keep_alive_timeout = timeout;
  99. }
  100. template<class RPCTYPE>
  101. inline void RPCClient<RPCTYPE>::set_watch_timeout(int timeout)
  102. {
  103. this->params.task_params.watch_timeout = timeout;
  104. }
  105. template<class RPCTYPE>
  106. void RPCClient<RPCTYPE>::add_filter(RPCFilter *filter)
  107. {
  108. using CLIENT_TASK = RPCClientTask<typename RPCTYPE::REQ,
  109. typename RPCTYPE::RESP>;
  110. using SERVER_TASK = RPCServerTask<typename RPCTYPE::REQ,
  111. typename RPCTYPE::RESP>;
  112. int type = filter->get_module_type();
  113. this->mutex.lock();
  114. if (type < SRPC_MODULE_MAX && type >= 0)
  115. {
  116. RPCModule *module = this->modules[type];
  117. if (!module)
  118. {
  119. switch (type)
  120. {
  121. case RPCModuleTypeTrace:
  122. module = new RPCTraceModule<SERVER_TASK, CLIENT_TASK>();
  123. break;
  124. case RPCModuleTypeMetrics:
  125. module = new RPCMetricsModule<SERVER_TASK, CLIENT_TASK>();
  126. break;
  127. default:
  128. break;
  129. }
  130. this->modules[type] = module;
  131. }
  132. if (module)
  133. module->add_filter(filter);
  134. }
  135. this->mutex.unlock();
  136. return;
  137. }
  138. template<class RPCTYPE>
  139. inline void RPCClient<RPCTYPE>::init(const RPCClientParams *params)
  140. {
  141. this->params = *params;
  142. if (this->params.task_params.data_type == RPCDataUndefined)
  143. this->params.task_params.data_type = RPCTYPE::default_data_type;
  144. this->has_addr_info = SRPCGlobal::get_instance()->task_init(this->params,
  145. this->uri,
  146. &this->ss,
  147. &this->ss_len);
  148. if (this->params.is_ssl)
  149. {
  150. if (this->params.transport_type == TT_TCP)
  151. this->params.transport_type = TT_TCP_SSL;
  152. else if (this->params.transport_type == TT_SCTP)
  153. this->params.transport_type = TT_SCTP_SSL;
  154. }
  155. else if (this->params.transport_type == TT_TCP_SSL ||
  156. this->params.transport_type == TT_SCTP_SSL)
  157. {
  158. this->params.is_ssl = true;
  159. }
  160. }
  161. template<class RPCTYPE>
  162. inline void RPCClient<RPCTYPE>::__task_init(COMPLEXTASK *task) const
  163. {
  164. if (this->has_addr_info)
  165. {
  166. task->init(this->params.transport_type,
  167. (const struct sockaddr *)&this->ss, this->ss_len, "");
  168. }
  169. else
  170. {
  171. task->init(this->uri);
  172. task->set_transport_type(this->params.transport_type);
  173. }
  174. }
  175. template<class RPCTYPE>
  176. inline void RPCClient<RPCTYPE>::task_init(COMPLEXTASK *task) const
  177. {
  178. __task_init(task);
  179. }
  180. static inline void __set_host_by_uri(const ParsedURI *uri, bool is_ssl,
  181. std::string& header_host)
  182. {
  183. if (uri->host && uri->host[0])
  184. header_host = uri->host;
  185. if (uri->port && uri->port[0])
  186. {
  187. int port = atoi(uri->port);
  188. if (is_ssl)
  189. {
  190. if (port != 443)
  191. {
  192. header_host += ":";
  193. header_host += uri->port;
  194. }
  195. }
  196. else
  197. {
  198. if (port != 80)
  199. {
  200. header_host += ":";
  201. header_host += uri->port;
  202. }
  203. }
  204. }
  205. }
  206. template<>
  207. inline void RPCClient<RPCTYPESRPCHttp>::task_init(COMPLEXTASK *task) const
  208. {
  209. __task_init(task);
  210. std::string header_host;
  211. if (this->has_addr_info)
  212. header_host += this->params.host + ":" + std::to_string(this->params.port);
  213. else
  214. __set_host_by_uri(task->get_current_uri(), this->params.is_ssl, header_host);
  215. task->get_req()->set_header_pair("Host", header_host.c_str());
  216. }
  217. template<>
  218. inline void RPCClient<RPCTYPEThriftHttp>::task_init(COMPLEXTASK *task) const
  219. {
  220. __task_init(task);
  221. std::string header_host;
  222. if (this->has_addr_info)
  223. header_host += this->params.host + ":" + std::to_string(this->params.port);
  224. else
  225. __set_host_by_uri(task->get_current_uri(), this->params.is_ssl, header_host);
  226. task->get_req()->set_header_pair("Host", header_host.c_str());
  227. }
  228. } // namespace srpc
  229. #endif