瀏覽代碼

Revert srpc::HttpServerTask.

Xie Han 2 月之前
父節點
當前提交
0faf8d6dc8
共有 2 個文件被更改,包括 151 次插入6 次删除
  1. 139 2
      src/http/http_task.cc
  2. 12 4
      src/http/http_task.h

+ 139 - 2
src/http/http_task.cc

@@ -368,10 +368,32 @@ void HttpServerTask::handle(int state, int error)
 {
 	if (state == WFT_STATE_TOREPLY)
 	{
+		HttpRequest *req = this->get_req();
+
+		// from WFHttpServerTask::handle()
+		req_is_alive_ = req->is_keep_alive();
+		if (req_is_alive_ && req->has_keep_alive_header())
+		{
+			HttpHeaderCursor req_cursor(req);
+			struct HttpMessageHeader header;
+
+			header.name = "Keep-Alive";
+			header.name_len = strlen("Keep-Alive");
+			req_has_keep_alive_header_ = req_cursor.find(&header);
+			if (req_has_keep_alive_header_)
+			{
+				req_keep_alive_.assign((const char *)header.value,
+										header.value_len);
+			}
+		}
+
+		this->state = WFT_STATE_TOREPLY;
+		this->target = this->get_target();
+
 		// fill module data from request to series
 		ModuleSeries *series = new ModuleSeries(this);
 
-		http_get_header_module_data(this->get_req(), this->module_data_);
+		http_get_header_module_data(req, this->module_data_);
 		for (auto *module : this->modules_)
 		{
 			if (module)
@@ -383,14 +405,129 @@ void HttpServerTask::handle(int state, int error)
 	}
 	else if (this->state == WFT_STATE_TOREPLY)
 	{
+		this->state = state;
+		this->error = error;
+		if (error == ETIMEDOUT)
+			this->timeout_reason = TOR_TRANSMIT_TIMEOUT;
+
 		// prepare module_data from series to response
 		for (auto *module : modules_)
 			module->server_task_end(this, this->module_data_);
 
 		http_set_header_module_data(this->module_data_, this->get_resp());
+
+		this->subtask_done();
+	}
+	else
+		delete this;
+}
+
+CommMessageOut *HttpServerTask::message_out()
+{
+	HttpResponse *resp = this->get_resp();
+	struct HttpMessageHeader header;
+
+	if (!resp->get_http_version())
+		resp->set_http_version("HTTP/1.1");
+
+	const char *status_code_str = resp->get_status_code();
+	if (!status_code_str || !resp->get_reason_phrase())
+	{
+		int status_code;
+
+		if (status_code_str)
+			status_code = atoi(status_code_str);
+		else
+			status_code = HttpStatusOK;
+
+		HttpUtil::set_response_status(resp, status_code);
+	}
+
+	if (!resp->is_chunked() && !resp->has_content_length_header())
+	{
+		char buf[32];
+		header.name = "Content-Length";
+		header.name_len = strlen("Content-Length");
+		header.value = buf;
+		header.value_len = sprintf(buf, "%zu", resp->get_output_body_size());
+		resp->add_header(&header);
+	}
+
+	bool is_alive;
+
+	if (resp->has_connection_header())
+		is_alive = resp->is_keep_alive();
+	else
+		is_alive = req_is_alive_;
+
+	if (!is_alive)
+		this->keep_alive_timeo = 0;
+	else
+	{
+		//req---Connection: Keep-Alive
+		//req---Keep-Alive: timeout=5,max=100
+
+		if (req_has_keep_alive_header_)
+		{
+			int flag = 0;
+			std::vector<std::string> params = StringUtil::split(req_keep_alive_, ',');
+
+			for (const auto& kv : params)
+			{
+				std::vector<std::string> arr = StringUtil::split(kv, '=');
+				if (arr.size() < 2)
+					arr.emplace_back("0");
+
+				std::string key = StringUtil::strip(arr[0]);
+				std::string val = StringUtil::strip(arr[1]);
+				if (!(flag & 1) && strcasecmp(key.c_str(), "timeout") == 0)
+				{
+					flag |= 1;
+					// keep_alive_timeo = 5000ms when Keep-Alive: timeout=5
+					this->keep_alive_timeo = 1000 * atoi(val.c_str());
+					if (flag == 3)
+						break;
+				}
+				else if (!(flag & 2) && strcasecmp(key.c_str(), "max") == 0)
+				{
+					flag |= 2;
+					if (this->get_seq() >= atoi(val.c_str()))
+					{
+						this->keep_alive_timeo = 0;
+						break;
+					}
+
+					if (flag == 3)
+						break;
+				}
+			}
+		}
+
+		if ((unsigned int)this->keep_alive_timeo > HTTP_KEEPALIVE_MAX)
+			this->keep_alive_timeo = HTTP_KEEPALIVE_MAX;
+		//if (this->keep_alive_timeo < 0 || this->keep_alive_timeo > HTTP_KEEPALIVE_MAX)
+
+	}
+
+	if (!resp->has_connection_header())
+	{
+		header.name = "Connection";
+		header.name_len = 10;
+		if (this->keep_alive_timeo == 0)
+		{
+			header.value = "close";
+			header.value_len = 5;
+		}
+		else
+		{
+			header.value = "Keep-Alive";
+			header.value_len = 10;
+		}
+
+		resp->add_header(&header);
 	}
 
-	WFHttpServerTask::handle(state, error);
+	return this->WFServerTask::message_out();
 }
 
 } // end namespace srpc

+ 12 - 4
src/http/http_task.h

@@ -23,7 +23,7 @@
 #include <string>
 #include "workflow/HttpUtil.h"
 #include "workflow/WFTaskFactory.h"
-#include "workflow/WFHttpServerTask.h"
+#include "workflow/WFGlobal.h"
 #include "rpc_module.h"
 
 namespace srpc
@@ -84,13 +84,16 @@ private:
 	std::list<RPCModule *> modules_;
 };
 
-class HttpServerTask : public WFHttpServerTask
+class HttpServerTask : public WFServerTask<protocol::HttpRequest,
+										   protocol::HttpResponse>
 {
 public:
 	HttpServerTask(CommService *service,
 				   std::list<RPCModule *>&& modules,
 				   std::function<void (WFHttpTask *)>& process) :
-		WFHttpServerTask(service, process),
+		WFServerTask(service, WFGlobal::get_scheduler(), process),
+		req_is_alive_(false),
+		req_has_keep_alive_header_(false),
 		modules_(std::move(modules))
 	{}
 
@@ -99,7 +102,8 @@ public:
 	bool is_ssl() const { return this->is_ssl_; }
 	unsigned short listen_port() const { return this->listen_port_; }
 
-	class ModuleSeries : public Series
+	class ModuleSeries : public WFServerTask<protocol::HttpRequest,
+											 protocol::HttpResponse>::Series
 	{
 	public:
 		ModuleSeries(WFServerTask<protocol::HttpRequest,
@@ -128,8 +132,12 @@ public:
 
 protected:
 	virtual void handle(int state, int error);
+	virtual CommMessageOut *message_out();
 
 protected:
+	bool req_is_alive_;
+	bool req_has_keep_alive_header_;
+	std::string req_keep_alive_;
 	RPCModuleData module_data_;
 	std::list<RPCModule *> modules_;
 	bool is_ssl_;