rpc_context.inl 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. /*
  2. Copyright (c) 2020 Sogou, Inc.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. #include <mutex>
  14. #include <condition_variable>
  15. #include <workflow/WFTask.h>
  16. #include "rpc_message.h"
  17. #include "rpc_module.h"
  18. namespace srpc
  19. {
  20. template<class T>
  21. struct ThriftReceiver
  22. {
  23. std::mutex mutex;
  24. std::condition_variable cond;
  25. RPCSyncContext ctx;
  26. T output;
  27. bool is_done = false;
  28. };
  29. template<class RPCREQ, class RPCRESP>
  30. class RPCContextImpl : public RPCContext
  31. {
  32. public:
  33. long long get_seqid() const override
  34. {
  35. return task_->get_task_seq();
  36. }
  37. std::string get_remote_ip() const override
  38. {
  39. char ip_str[INET6_ADDRSTRLEN + 1] = { 0 };
  40. struct sockaddr_storage addr;
  41. socklen_t addrlen = sizeof (addr);
  42. if (this->get_peer_addr((struct sockaddr *)&addr, &addrlen) == 0)
  43. {
  44. if (addr.ss_family == AF_INET)
  45. {
  46. struct sockaddr_in *sin = (struct sockaddr_in *)(&addr);
  47. inet_ntop(AF_INET, &sin->sin_addr, ip_str, INET_ADDRSTRLEN);
  48. }
  49. else if (addr.ss_family == AF_INET6)
  50. {
  51. struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)(&addr);
  52. inet_ntop(AF_INET6, &sin6->sin6_addr, ip_str, INET6_ADDRSTRLEN);
  53. }
  54. }
  55. return std::string(ip_str);
  56. }
  57. int get_peer_addr(struct sockaddr *addr, socklen_t *addrlen) const override
  58. {
  59. return task_->get_peer_addr(addr, addrlen);
  60. }
  61. const std::string& get_service_name() const override
  62. {
  63. return task_->get_req()->get_service_name();
  64. }
  65. const std::string& get_method_name() const override
  66. {
  67. return task_->get_req()->get_method_name();
  68. }
  69. void set_data_type(RPCDataType type) override
  70. {
  71. task_->get_resp()->set_data_type(type);
  72. }
  73. void set_compress_type(RPCCompressType type) override
  74. {
  75. task_->get_resp()->set_compress_type(type);
  76. }
  77. void set_send_timeout(int timeout) override
  78. {
  79. task_->set_send_timeout(timeout);
  80. }
  81. void set_keep_alive(int timeout) override
  82. {
  83. task_->set_keep_alive(timeout);
  84. }
  85. SeriesWork *get_series() const override
  86. {
  87. return series_of(task_);
  88. }
  89. void *get_user_data() const override
  90. {
  91. return task_->user_data;
  92. }
  93. bool get_attachment(const char **attachment, size_t *len) const override
  94. {
  95. if (this->is_server_task())
  96. return task_->get_req()->get_attachment_nocopy(attachment, len);
  97. else
  98. return task_->get_resp()->get_attachment_nocopy(attachment, len);
  99. }
  100. bool get_http_header(const std::string& name, std::string& value) const override
  101. {
  102. if (this->is_server_task())
  103. return task_->get_req()->get_http_header(name, value);
  104. else
  105. return task_->get_resp()->get_http_header(name, value);
  106. }
  107. public:
  108. // for client-done
  109. bool success() const override
  110. {
  111. return task_->get_resp()->get_status_code() == RPCStatusOK;
  112. }
  113. int get_status_code() const override
  114. {
  115. return task_->get_resp()->get_status_code();
  116. }
  117. const char *get_errmsg() const override
  118. {
  119. return task_->get_resp()->get_errmsg();
  120. }
  121. int get_error() const override
  122. {
  123. return task_->get_resp()->get_error();
  124. }
  125. int get_timeout_reason() const override
  126. {
  127. return task_->get_timeout_reason();
  128. }
  129. public:
  130. // for server-process
  131. void set_attachment_nocopy(const char *attachment, size_t len) override
  132. {
  133. task_->get_resp()->set_attachment_nocopy(attachment, len);
  134. }
  135. void set_reply_callback(std::function<void (RPCContext *ctx)> cb) override
  136. {
  137. if (this->is_server_task())
  138. {
  139. if (cb)
  140. {
  141. task_->set_callback([this, cb](SubTask *task) {
  142. cb(this);
  143. });
  144. }
  145. else
  146. task_->set_callback(nullptr);
  147. }
  148. }
  149. bool set_http_code(int code) override
  150. {
  151. if (this->is_server_task())
  152. return task_->get_resp()->set_http_code(code);
  153. return false;
  154. }
  155. bool set_http_header(const std::string& name, const std::string& value) override
  156. {
  157. if (this->is_server_task())
  158. return task_->get_resp()->set_http_header(name, value);
  159. return false;
  160. }
  161. bool add_http_header(const std::string& name, const std::string& value) override
  162. {
  163. if (this->is_server_task())
  164. return task_->get_resp()->add_http_header(name, value);
  165. return false;
  166. }
  167. bool log(const RPCLogVector& fields) override
  168. {
  169. if (this->is_server_task() && module_data_)
  170. {
  171. std::string key;
  172. std::string value;
  173. RPCCommon::log_format(key, value, fields);
  174. module_data_->emplace(std::move(key), std::move(value));
  175. return true;
  176. }
  177. return false;
  178. }
  179. bool add_baggage(const std::string& key, const std::string& value) override
  180. {
  181. if (this->is_server_task() && module_data_)
  182. {
  183. (*module_data_)[key] = value;
  184. return true;
  185. }
  186. return false;
  187. }
  188. bool get_baggage(const std::string& key, std::string& value) override
  189. {
  190. if (module_data_)
  191. {
  192. const auto it = module_data_->find(key);
  193. if (it != module_data_->cend())
  194. {
  195. value = it->second;
  196. return true;
  197. }
  198. }
  199. return false;
  200. }
  201. void set_json_add_whitespace(bool on) override
  202. {
  203. if (this->is_server_task())
  204. task_->get_resp()->set_json_add_whitespace(on);
  205. }
  206. void set_json_always_print_enums_as_ints(bool on) override
  207. {
  208. if (this->is_server_task())
  209. task_->get_resp()->set_json_enums_as_ints(on);
  210. }
  211. void set_json_preserve_proto_field_names(bool on) override
  212. {
  213. if (this->is_server_task())
  214. task_->get_resp()->set_json_preserve_names(on);
  215. }
  216. void set_json_always_print_primitive_fields(bool on) override
  217. {
  218. if (this->is_server_task())
  219. task_->get_resp()->set_json_print_primitive(on);
  220. }
  221. //void noreply() override;
  222. //WFConnection *get_connection() override;
  223. public:
  224. RPCContextImpl(WFNetworkTask<RPCREQ, RPCRESP> *task,
  225. RPCModuleData *module_data) :
  226. task_(task),
  227. module_data_(module_data)
  228. {
  229. }
  230. protected:
  231. bool is_server_task() const
  232. {
  233. return task_->get_state() == WFT_STATE_TOREPLY
  234. || task_->get_state() == WFT_STATE_NOREPLY;
  235. }
  236. protected:
  237. WFNetworkTask<RPCREQ, RPCRESP> *task_;
  238. private:
  239. RPCModuleData *module_data_;
  240. };
  241. } // namespace srpc