Kaynağa Gözat

Feat : support setting transport_type in RPCClientParams (#366)

liyingxin 2 ay önce
ebeveyn
işleme
5eedc287c3

+ 0 - 25
src/generator/printer.h

@@ -1305,31 +1305,6 @@ inline %sClient::%sClient(const struct srpc::RPCClientParams *params):
 		this->register_compress_type(srpc::RPCCompressGzip);
 )";
 
-	std::string client_class_constructor_params_format = R"(
-	%sClient(struct srpc::RPCServiceParams& params) :
-			srpc::RPCService("%s")
-	{
-		this->params.service_params = params;
-
-		if (!params.url.empty())
-		{
-			URIParser::parse(params.url, this->params.uri);
-			this->params.has_uri = true;
-		} else {
-			this->params.has_uri = false;
-		}
-
-		if (!params.host.empty() && !params.port.empty())
-		{
-			get_addr_info(params.host.c_str(), params.port.c_str(), &this->params.ai);
-			this->params.has_addr_info = true;
-		} else {
-			this->params.has_addr_info = false;
-		}
-
-		this->register_compress_type(srpc::RPCCompressGzip);
-)";
-
 	std::string client_class_functions = R"(
 	const struct srpc::RPCClientParams *get_params()
 	{

+ 17 - 2
src/rpc_client.h

@@ -180,18 +180,33 @@ inline void RPCClient<RPCTYPE>::init(const RPCClientParams *params)
 																this->uri,
 																&this->ss,
 																&this->ss_len);
+
+	if (this->params.is_ssl)
+	{
+		if (this->params.transport_type == TT_TCP)
+			this->params.transport_type = TT_TCP_SSL;
+		else if (this->params.transport_type == TT_SCTP)
+			this->params.transport_type = TT_SCTP_SSL;
+	}
+	else if (this->params.transport_type == TT_TCP_SSL ||
+			 this->params.transport_type == TT_SCTP_SSL)
+	{
+		this->params.is_ssl = true;
+	}
 }
 
 template<class RPCTYPE>
 inline void RPCClient<RPCTYPE>::__task_init(COMPLEXTASK *task) const
 {
 	if (this->has_addr_info)
-		task->init(this->params.is_ssl ? TT_TCP_SSL : TT_TCP,
+	{
+		task->init(this->params.transport_type,
 				   (const struct sockaddr *)&this->ss, this->ss_len, "");
+	}
 	else
 	{
 		task->init(this->uri);
-		task->set_transport_type(this->params.is_ssl ? TT_TCP_SSL : TT_TCP);
+		task->set_transport_type(this->params.transport_type);
 	}
 }
 

+ 2 - 0
src/rpc_options.h

@@ -39,6 +39,7 @@ struct RPCTaskParams
 struct RPCClientParams
 {
 	RPCTaskParams task_params;
+	enum TransportType transport_type;
 //host + port + is_ssl
 	std::string host;
 	unsigned short port;
@@ -71,6 +72,7 @@ static constexpr struct RPCTaskParams RPC_TASK_PARAMS_DEFAULT =
 static const struct RPCClientParams RPC_CLIENT_PARAMS_DEFAULT =
 {
 /*	.task_params		=	*/	RPC_TASK_PARAMS_DEFAULT,
+/*	.transport_type		=	*/	TT_TCP,
 /*	.host				=	*/	"",
 /*	.port				=	*/	SRPC_DEFAULT_PORT,
 /*	.is_ssl				=	*/	false,

+ 7 - 4
tools/srpc_config.cc

@@ -409,6 +409,13 @@ static std::string ctl_client_load_params_format = R"(
 	RPCClientParams params = RPC_CLIENT_PARAMS_DEFAULT;
 	params.host = config.client_host();
 	params.port = config.client_port();
+	params.transport_type = config.client_transport_type();
+	params.is_ssl = config.client_is_ssl();
+	params.url = config.client_url();
+	params.callee_timeout = config.client_callee_timeout();
+	params.caller = config.client_caller();
+
+	params.task_params.retry_max = config.client_retry_max();
 )";
 
 static std::string ctl_client_main_params_format = R"(
@@ -465,10 +472,6 @@ void ControlGenerator::ControlPrinter::print_server_load_config()
 void ControlGenerator::ControlPrinter::print_client_params()
 {
 	fprintf(this->out_file, "%s", ctl_client_load_params_format.c_str());
-	fprintf(this->out_file, "\tparams.is_ssl = config.client_is_ssl();\n");
-	fprintf(this->out_file, "\tparams.url = config.client_url();\n");
-	fprintf(this->out_file, "\tparams.caller = config.client_caller();\n");
-
 	// TODO: set client task params
 }
 

+ 1 - 1
tools/templates/common/config.json

@@ -6,10 +6,10 @@
 
   "client":
   {
+    "transport_type": "TT_TCP",
     "remote_host": "127.0.0.1",
     "remote_port": 8080,
     "is_ssl" : false,
-    "redirect_max": 2,
     "retry_max": 1,
     "user_name": "root",
 	"password": "",

+ 16 - 1
tools/templates/config/config_full.cc

@@ -245,12 +245,27 @@ void RPCConfig::load_server()
 
 void RPCConfig::load_client()
 {
+    if (this->data["client"].has("transport_type"))
+    {
+        std::string type = this->data["client"]["transport_type"].get<std::string>();
+        if (type == "TT_SCTP")
+            this->c_transport_type = TT_SCTP;
+        else if (type == "TT_UDP")
+            this->c_transport_type = TT_UDP;
+    }
+
     if (this->data["client"].has("remote_host"))
         this->c_host = this->data["client"]["remote_host"].get<std::string>();
 
     if (this->data["client"].has("remote_port"))
         this->c_port = this->data["client"]["remote_port"];
 
+    if (this->data["client"].has("is_ssl"))
+        this->c_is_ssl = this->data["client"]["is_ssl"];
+
+    if (this->data["client"].has("callee_timeout"))
+        this->c_callee_timeout = this->data["client"]["callee_timeout"];
+
     if (this->data["client"].has("redirect_max"))
         this->c_redirect_max = this->data["client"]["redirect_max"];
 
@@ -405,7 +420,7 @@ void RPCConfig::load_trace()
                 report_interval = it["report_interval_ms"];
 
             auto *filter = new RPCTraceOpenTelemetry(url,
-													 OTLP_TRACES_PATH,
+                                                     OTLP_TRACES_PATH,
                                                      redirect_max,
                                                      retry_max,
                                                      spans_per_second,

+ 7 - 3
tools/templates/config/config_full.h

@@ -43,13 +43,15 @@ public:
     unsigned short server_port() const { return this->s_port; }
     const char *server_cert_file() const { return this->s_cert_file.c_str(); }
     const char *server_file_key() const { return this->s_file_key.c_str(); }
-    unsigned short client_port() const { return this->c_port; }
+    enum TransportType client_transport_type() const { return this->c_transport_type; }
     const char *client_host() const { return this->c_host.c_str(); }
+    unsigned short client_port() const { return this->c_port; }
     bool client_is_ssl() const { return this->c_is_ssl; }
     const char *client_url() const { return this->c_url.c_str(); }
+    int client_callee_timeout() const { return this->c_callee_timeout; }
+    const char *client_caller() const { return this->c_caller.c_str(); }
     int redirect_max() const { return this->c_redirect_max; }
     int retry_max() const { return this->c_retry_max; }
-    const char *client_caller() const { return this->c_caller.c_str(); }
     const char *client_user_name() const { return this->c_user_name.c_str(); }
     const char *client_password() const { return this->c_password.c_str(); }
     const char *get_root_path() const { return this->root_path.c_str(); }
@@ -72,13 +74,15 @@ private:
     unsigned short s_port;
     std::string s_cert_file;
     std::string s_file_key;
+    enum TransportType c_transport_type;
     std::string c_host;
     unsigned short c_port;
     bool c_is_ssl;
     std::string c_url;
+    int c_callee_timeout;
+    std::string c_caller;
     int c_redirect_max;
     int c_retry_max;
-    std::string c_caller;
     std::string c_user_name;
     std::string c_password;
     std::string root_path;

+ 0 - 1
tools/templates/proxy/client_rpc.conf

@@ -4,7 +4,6 @@
     "remote_host": "127.0.0.1",
     "remote_port": 1411,
     "is_ssl" : false,
-    "redirect_max": 2,
     "retry_max": 1,
     "callee" : "rpc_client"
   }

+ 0 - 1
tools/templates/rpc/client.conf

@@ -4,7 +4,6 @@
     "remote_host": "127.0.0.1",
     "remote_port": 1412,
     "is_ssl" : false,
-    "redirect_max": 2,
     "retry_max": 1,
     "callee" : "rpc_client"
   }