Ver Fonte

202 enhance rule matching code logical (#212)

* enhance code logic.

* make expr_rules to local variable.

* update version.
Yang Shuang há 1 ano atrás
pai
commit
fc64dc5f9f

+ 1 - 1
.github/workflows/release.yml

@@ -6,7 +6,7 @@ on:
 env:
   CC: gcc-4.9
   CXX: g++-4.9    
-  ver: dtc-v2.1.1
+  ver: dtc-v2.1.2
 
 jobs:
   build:

+ 3 - 8
src/agent/da.c

@@ -477,14 +477,6 @@ static int da_pre_run(struct instance *dai) {
 
 	log_info("DTC AGENT init.");
 
-#if 0
-	if(re_load_table_key(dtckey) < 0)
-	{
-		log_error("load dtc define error.");
-		return -1;
-	}
-#endif
-
 	if (daemonize) {
 		status = da_daemonize(1);
 		if (status != 0) {
@@ -534,6 +526,9 @@ static int da_pre_run(struct instance *dai) {
 	}
 
 	da_print_run(dai);
+
+	re_load_all_rules();
+	
 	return 0;
 }
 

+ 1 - 23
src/agent/da_msg.c

@@ -214,7 +214,6 @@ static struct msg *_msg_get() {
 	m->sending = 0;
 
 	m->pkt_nr = 0;
-	m->mid = 0;
 	m->ismysql = 0;
 
 	return m;
@@ -551,29 +550,8 @@ struct mbuf *msg_insert_mem_bulk(struct msg *msg,struct mbuf *mbuf,uint8_t *pos,
 uint32_t msg_backend_idx(struct msg *msg, uint8_t *key, uint32_t keylen) {
 	struct conn *conn = msg->owner;
 	struct server_pool *pool = conn->owner;
-	uint32_t i ;
-	struct server_pool *temp_pool = NULL;
-	struct context *ctx = pool->ctx;
-	log_debug("msg backend idx entry");
-	for(i = 0 ; i < array_n(&(ctx->pool)) ; i ++){
-		struct string tmp1, tmp2;
-		temp_pool = (struct server_pool *)array_get(&(ctx->pool), i);
-		string_copy(&tmp2, temp_pool->name.data, temp_pool->name.len);
-		string_upper(&tmp2);		
-		//if(string_compare(&tmp2, &msg->table_name) == 0)
-		if(msg->mid == temp_pool->mid)
-			break;
-		else
-			temp_pool = NULL;			
-	}
 
-	if(temp_pool)
-		return server_pool_idx(temp_pool, key, keylen);
-	else
-	{
-		temp_pool = (struct server_pool *)array_get(&(ctx->pool), 0);
-		return server_pool_idx(temp_pool, NULL, 0);
-	}
+	return server_pool_idx(pool, key, keylen);
 }
 
 static int msg_send_chain(struct context *ctx, struct conn *conn,

+ 0 - 2
src/agent/da_msg.h

@@ -153,9 +153,7 @@ struct msg {
 	enum enum_server_command command; /* mysql request command type */
 	enum enum_agent_admin admin;
 	uint8_t layer;
-	int mid;
 	int ismysql;
-	struct string table_name;
 	union COM_DATA data;
 
 	int err; /* errno on error? */

+ 1 - 28
src/agent/da_request.c

@@ -496,35 +496,8 @@ static void req_forward(struct context *ctx, struct conn *c_conn,
 	}
 	else
 	{
-		for(i = 0 ; i < array_n(&(ctx->pool)) ; i ++){
-			temp_pool = (struct server_pool *)array_get(&(ctx->pool), i);
-			if(msg->mid == 0)
-				break;
-			if(msg->mid == temp_pool->mid)
-				break;
-			else
-				temp_pool = NULL;
-		}
-		if(temp_pool == NULL){
-			log_debug("s_conn null");
-			//client connection is still exist,no swallow
-			msg->done = 1;
-			msg->error = 1;
-			msg->err = MSG_REQ_FORWARD_ERR;
-			if (msg->frag_owner != NULL) {
-				msg->frag_owner->nfrag_done++;
-			}
-			if (req_done(c_conn, msg)) {
-				rsp_forward(ctx, c_conn, msg);
-			}
-			stats_pool_incr(ctx, pool, forward_error);
-			log_error("msg :%" PRIu64 " from c %d ,get s_conn fail!",
-				msg->id, c_conn->fd);
-			return;
-		}
-
 		s_conn =
-			server_pool_conn(ctx, temp_pool, msg);
+			server_pool_conn(ctx, (struct server_pool *) c_conn->owner, msg);
 	}
 
 	if (s_conn == NULL) {

+ 1 - 1
src/agent/da_server.c

@@ -110,7 +110,7 @@ static uint32_t server_pool_hash(struct server_pool *pool, uint8_t *key,
 uint32_t server_pool_idx(struct server_pool *pool, uint8_t *key,
 		uint32_t keylen) {
 	ASSERT(array_n(&pool->server) != 0);
-	//ASSERT(key != NULL && keylen != 0);
+	ASSERT(key != NULL && keylen != 0);
 	uint32_t hash, idx;
 	
 	hash = server_pool_hash(pool, key, keylen);

+ 7 - 95
src/agent/my/my_parse.c

@@ -632,76 +632,16 @@ bool check_cmd_insert(struct string *str)
 		return false;
 }
 
-int get_mid_by_dbname(const char* sessiondb, const char* sql, struct msg* r)
-{
-	int mid = 0;
-	struct context* ctx = NULL;
-	struct conn *c_conn = NULL;
-	int sql_len = 0;
-	int ret = 0;
-	char cmp_string[300] = {0};
-	struct string req_string; 
-	c_conn = r->owner;
-	ctx = conn_to_ctx(c_conn);
-
-	ret = get_table_with_db(sessiondb, sql, &cmp_string);
-	if(ret >= 0)
-	{
-		struct array *pool = &(ctx->pool);
-		int i;
-
-		string_copy(&req_string, cmp_string, strlen(cmp_string));
-		string_upper(&req_string);
-
-		for (i = 0; i < array_n(pool); i++) {
-			struct server_pool *p = (struct server_pool *)array_get(pool, i);
-			struct string xmlname; 
-			if(string_empty(&p->name))
-				continue;
-
-			string_duplicate(&xmlname, &(p->name));
-			string_upper(&xmlname);
-
-			log_info("xml name: %s, cmp string: %s", xmlname.data, req_string.data);
-			if(da_strncmp(xmlname.data, req_string.data, req_string.len) == 0)
-			{
-				mid = p->mid;
-			}
-
-			string_deinit(&xmlname);
-		}
-
-		string_deinit(&req_string);
-	}
-
-	log_info("mid result: %d", mid);
-	return mid;
-}
-
-void get_tablename(struct msg* r, uint8_t* sql, int sql_len)
-{
-	char tablename[260] = {0};
-	if(sql == NULL || sql_len <= 0)
-		return ;
-
-	int ret = sql_parse_table(sql, &tablename);
-	if(ret > 0)
-	{
-		string_copy(&r->table_name, tablename, strlen(tablename));
-	}
-	log_debug("tablename: %s", tablename);
-}
-
 int my_get_route_key(uint8_t *sql, int sql_len, int *start_offset,
-		     int *end_offset, const char* dbname, struct msg* r)
+		     int *end_offset, const char* dbsession, struct msg* r)
 {
 	int i = 0;
 	struct string str, ostr;
 	int ret = 0;
 	int layer = 0;
 	string_init(&str);
-	string_copy(&str, sql, sql_len);
 	string_init(&ostr);
+	string_copy(&str, sql, sql_len);
 	string_copy(&ostr, sql, sql_len);
 
 	if (string_empty(&str))
@@ -711,44 +651,16 @@ int my_get_route_key(uint8_t *sql, int sql_len, int *start_offset,
 		return -9;
 
 	log_debug("sql: %s", str.data);
-	if(dbname && strlen(dbname))
+	if(dbsession && strlen(dbsession))
 	{
-		log_debug("dbname len:%d, dbname: %s", strlen(dbname), dbname);
-	}
-
-	int mid = get_mid_by_dbname(dbname, str.data, r);
-	char conf_path[260] = {0};
-	memset(conf_path, 0, 260);
-	if(mid != 0)
-	{
-		sprintf(conf_path, "../conf/dtc-conf-%d.yaml", mid);
-		r->mid = mid;
-	}
-
-	get_tablename(r, str.data, str.len);
-	if(r->table_name.len > 0)
-		log_debug("table name: %s", r->table_name.data);
-
-	char strkey[260] = {0};
-	memset(strkey, 0, 260);
-	if(strlen(conf_path) > 0)
-	{
-		if(rule_get_key(conf_path, strkey) <= 0)
-		{
-			ret = -5;
-			goto done;
-		}
-		else
-		{
-			log_debug("strkey: %s", strkey);
-		}
+		log_debug("dbsession len:%d, dbsession: %s", strlen(dbsession), dbsession);
 	}
 
-	r->keytype = rule_get_key_type(conf_path);
-	log_debug("strkey type: %d", r->keytype);
+	char strkey[1024] = {0};
+	memset(strkey, 0, 1024);
 
 	//agent sql route, rule engine
-	layer = rule_sql_match(str.data, ostr.data, dbname, strlen(conf_path) > 0 ? conf_path : NULL);
+	layer = rule_sql_match(str.data, ostr.data, dbsession, &strkey, &r->keytype);
 	log_debug("rule layer: %d", layer);
 
 	if(layer != 1)

+ 1 - 1
src/agent/my/my_parse.h

@@ -35,6 +35,6 @@ int my_do_command(struct context *ctx, struct conn *c_conn, struct msg *msg);
 int my_fragment(struct msg *r, uint32_t ncontinuum, struct msg_tqh *frag_msgq);
 
 int my_get_route_key(uint8_t *sql, int sql_len, int *start_offset,
-		     int *end_offset, const char* dbname, struct msg* r);
+		     int *end_offset, const char* dbsession, struct msg* r);
 
 #endif /* _MY_PARSE_H_ */

+ 0 - 42
src/rule/main.cc

@@ -11,8 +11,6 @@
 
 using namespace std;
 
-extern vector<vector<hsql::Expr*> > expr_rules;
-
 int main(int argc, char* argv[])
 {
     printf("hello dtc, ./bin KEY SQL\n");
@@ -31,45 +29,5 @@ int main(int argc, char* argv[])
     
     cout<<"parsing success."<<sql_ast.isValid()<<endl;
 
-#if 0
-    if(re_load_table_key(szkey) < 0)
-        return -1;
-    key = szkey;
-
-    if(key.length() == 0)
-        return -1;
-
-    cout<<"key: "<<key<<endl;
-    cout<<"sql: "<<sql<<endl;
-
-    init_log4cplus();
-    
-    int ret = re_load_rule();
-    if(ret != 0)
-    {
-        log4cplus_error("load rule error:%d", ret);
-        return 0;
-    }
-
-    hsql::SQLParserResult sql_ast;
-    if(re_parse_sql(sql, &sql_ast) != 0)
-        return -1;
-
-    ret = re_match_sql(&sql_ast, expr_rules, &sql_ast);
-    if(ret == 0)
-    {
-        if(re_is_cache_sql(&sql_ast, key))
-        {
-            printf("RULE MATCH : L1 - cache data\n");
-        }
-        else
-        {
-            printf("RULE MATCH : L2 - hot data\n");
-        }
-    }
-    else {
-        printf("RULE MATCH : L3 - full data\n");
-    }
-#endif
     return 0;
 }

+ 55 - 73
src/rule/re_load.cc

@@ -3,21 +3,19 @@
 #include "log.h"
 #include <string>
 #include <iostream>
+#include <fcntl.h>
 #include "re_comm.h"
 
-std::string conf_file = "../conf/dtc.yaml";
-
 using namespace hsql;
-hsql::SQLParserResult rule_ast;
 
-vector<vector<hsql::Expr*> > expr_rules;
+std::map<std::string, std::string> g_map_dtc_yaml;
 
 // load rule from dtc.yaml
-std::string do_get_rule()
+std::string do_get_rule(std::string buf)
 {
     YAML::Node config;
     try {
-        config = YAML::LoadFile(conf_file);
+        config = YAML::Load(buf);
 	} catch (const YAML::Exception &e) {
 		log4cplus_error("config file error:%s\n", e.what());
 		return "";
@@ -52,14 +50,14 @@ int get_rule_condition_num(hsql::Expr* rule)
 }
 
 // parse rule txt to AST.
-int do_parse_rule(std::string rules)
+int do_parse_rule(std::string rules, hsql::SQLParserResult* rule_ast)
 {
     std::string sql = "select * from rules where ";
     sql += rules;
     sql += ";";
     log4cplus_debug("rule sql: %s", sql.c_str());
-    bool r = hsql::SQLParser::parse(sql, &rule_ast);
-    if (r && rule_ast.isValid() && rule_ast.size() > 0)
+    bool r = hsql::SQLParser::parse(sql, rule_ast);
+    if (r && rule_ast->isValid() && rule_ast->size() > 0)
     {
         return 0; 
     }
@@ -76,46 +74,46 @@ int traverse_sub_ast(hsql::Expr* where, vector<hsql::Expr*>* v_expr)
     }
     else
     {
+        log4cplus_debug("type: %d, %d, %s", where->type, where->opType, where->expr->name);
         v_expr->push_back(where);
     }
 }
 
-int traverse_ast(hsql::Expr* where)
+int traverse_ast(hsql::Expr* where, vector<vector<hsql::Expr*> >* expr_rules)
 {
     if(where->isType(kExprOperator) &&  where->opType == kOpOr)
     {
-        traverse_ast(where->expr);
-        traverse_ast(where->expr2);
+        traverse_ast(where->expr, expr_rules);
+        traverse_ast(where->expr2, expr_rules);
     }
     else
     {
         vector<hsql::Expr*> v_expr;
         traverse_sub_ast(where, &v_expr);
 
-        expr_rules.push_back(v_expr);
+        expr_rules->push_back(v_expr);
     }
 
     return 0;
 }
 
 //legitimacy check.
-int do_check_rule()
+int do_check_rule(hsql::SQLParserResult* rule_ast, vector<vector<hsql::Expr*> >* expr_rules)
 {
     hsql::Expr *where = NULL;
-
-    if(rule_ast.size() != 1)
+    if(rule_ast->size() != 1)
         return -1;
 
-    if(rule_ast.getStatement(0)->type() != hsql::kStmtSelect)
+    if(rule_ast->getStatement(0)->type() != hsql::kStmtSelect)
         return -2;
 
-    hsql::SelectStatement* stmt = (SelectStatement*)rule_ast.getStatement(0);
+    hsql::SelectStatement* stmt = (SelectStatement*)rule_ast->getStatement(0);
 
     where = stmt->whereClause;
     if(!where)
         return -3;
 
-    traverse_ast(where);
+    traverse_ast(where, expr_rules);
 
     return 0;
 }
@@ -132,25 +130,58 @@ int do_split_rules()
     return 0;
 }
 
-int re_load_rule()
+std::string load_dtc_yaml_buffer(int mid)
+{
+    char path[260];
+    int i_length = 0;
+    char* file = NULL;
+    
+    sprintf(path, "../conf/dtc-conf-%d.yaml", mid);
+
+    int fd = -1;
+
+	if ((fd = open(path, O_RDONLY)) < 0) 
+    {
+		log4cplus_error("open config file error");
+		return "";
+	}
+
+	printf("open file:%s\n", path);
+	lseek(fd, 0L, SEEK_SET);
+	i_length = lseek(fd, 0L, SEEK_END);
+	lseek(fd, 0L, SEEK_SET);
+	// Attention: memory init here ,need release outside
+	file = (char *)malloc(i_length + 1);
+	int readlen = read(fd, file, i_length);
+	if (readlen < 0 || readlen == 0)
+		return "";
+	file[i_length] = '\0';
+	close(fd);
+	i_length++; // add finish flag length
+    std::string res = file;
+    delete file;
+    return res;
+}
+
+int re_load_rule(std::string buf, hsql::SQLParserResult* rule_ast, vector<vector<hsql::Expr*> >* expr_rules)
 {
     log4cplus_debug("load rule start...");
 
-    if(rule_ast.isValid() && rule_ast.size() > 0)
+    if(rule_ast->isValid() && rule_ast->size() > 0)
         return 0;
 
-    std::string rules = do_get_rule();
+    std::string rules = do_get_rule(buf);
     if(rules.length() <= 0)
         return -1;
 
-    int ret = do_parse_rule(rules);
+    int ret = do_parse_rule(rules, rule_ast);
     if(ret != 0)
     {
         log4cplus_error("match rules parsed failed, %d", ret);
         return -2;
     }
 
-    ret = do_check_rule();
+    ret = do_check_rule(rule_ast, expr_rules);
     if(ret != 0)
     {
         log4cplus_error("match rules check failed, %d", ret);
@@ -160,52 +191,3 @@ int re_load_rule()
     log4cplus_debug("load rule end.");
     return 0;
 }
-
-extern "C" int re_load_table_key(char* key)
-{
-    YAML::Node config;
-    try {
-        config = YAML::LoadFile(conf_file);
-	} catch (const YAML::Exception &e) {
-		log4cplus_error("config file error:%s\n", e.what());
-		return -1;
-	}
-
-    YAML::Node node = config["primary"]["cache"]["field"][0]["name"];
-    if(node)
-    {
-        if(node.as<string>().length() >= 50)
-        {
-            return -1;
-        }
-        strcpy(key, node.as<string>().c_str());
-        return 0;
-    }
-
-    return -1;
-}
-
-std::string re_load_table_name()
-{
-    YAML::Node config;
-    try {
-        config = YAML::LoadFile(conf_file);
-	} catch (const YAML::Exception &e) {
-		log4cplus_error("config file error:%s\n", e.what());
-		return "";
-	}
-
-    YAML::Node node = config["primary"]["table"];
-    if(node)
-    {
-        if(node.as<string>().length() >= 50)
-        {
-            return "";
-        }
-        std::string res = node.as<string>();
-        transform(res.begin(),res.end(),res.begin(),::toupper);
-        return res;
-    }
-
-    return "";
-}

+ 5 - 2
src/rule/re_load.h

@@ -1,7 +1,10 @@
 #include "../libs/hsql/include/SQLParser.h"
 #include "../libs/hsql/include/util/sqlhelper.h"
 #include <string>
+#include <vector>
+
+using namespace std;
 
 int get_rule_condition_num(hsql::Expr* rule);
-int re_load_rule();
-std::string re_load_table_name();
+int re_load_rule(std::string buf, hsql::SQLParserResult* rule_ast, vector<vector<hsql::Expr*> >* expr_rules);
+std::string load_dtc_yaml_buffer(int mid);

+ 122 - 65
src/rule/rule.cc

@@ -10,22 +10,24 @@
 #include "log.h"
 #include "mxml.h"
 #include "yaml-cpp/yaml.h"
+#include "mxml.h"
 
 #define SPECIFIC_L1_SCHEMA "L1"
 #define SPECIFIC_L2_SCHEMA "L2"
 #define SPECIFIC_L3_SCHEMA "L3"
 
+#define AGENT_XML_FILE "../conf/agent.xml"
+
 using namespace std;
 using namespace hsql;
 
-extern vector<vector<hsql::Expr*> > expr_rules;
-extern std::string conf_file;
+extern std::map<std::string, std::string> g_map_dtc_yaml;
 
-std::string get_key_info(std::string conf)
+std::string get_key_info(std::string buf)
 {
     YAML::Node config;
     try {
-        config = YAML::LoadFile(conf);
+        config = YAML::Load(buf);
 	} catch (const YAML::Exception &e) {
 		log4cplus_error("config file error:%s\n", e.what());
 		return "";
@@ -169,13 +171,12 @@ extern "C" int get_statement_value(char* str, int len, const char* strkey, int*
         return -1;
 }
 
-extern "C" int get_table_with_db(const char* sessiondb, const char* sql, char* result)
+std::string get_table_with_db(const char* sessiondb, const char* sql)
 {
-	if(result == NULL)
-		return -1;
+	char result[300];
     hsql::SQLParserResult sql_ast;
     if(re_parse_sql(sql, &sql_ast) != 0)
-        return -2;
+        return "";
 	memset(result, 0, 300);
 
     // Get db name
@@ -191,7 +192,7 @@ extern "C" int get_table_with_db(const char* sessiondb, const char* sql, char* r
     else
     {
         log4cplus_error("no database selected.");
-        return -3;
+        return "";
     }
 
     //Append symbol.
@@ -205,25 +206,24 @@ extern "C" int get_table_with_db(const char* sessiondb, const char* sql, char* r
     }
     else
     {
-        return -4;
+        return "";
     }
 
-    return 0;
+    std::string strres = result;
+    transform(strres.begin(),strres.end(),strres.begin(),::toupper);
+    return strres;
 }
 
-extern "C" int rule_get_key_type(const char* conf)
+int rule_get_key_type(std::string buf)
 {
     YAML::Node config;
-    if(conf == NULL)
-        return -1;
-    
-    if(strlen(conf) <= 0)
+    if(buf.length() == 0)
         return -1;
 
     try {
-        config = YAML::LoadFile(conf);
+        config = YAML::Load(buf);
 	} catch (const YAML::Exception &e) {
-		log4cplus_error("config file(%s) error:%s\n", conf, e.what());
+		log4cplus_error("config buf load error:%s\n", e.what());
 		return -1;
 	}
 
@@ -409,28 +409,82 @@ bool exist_sql_db(hsql::SQLParserResult* ast)
     return false;
 }
 
-extern "C" int rule_sql_match(const char* szsql, const char* osql, const char* dbname, const char* conf)
+bool is_dtc_instance(std::string key)
+{
+    if(key.length() > 0)
+        return true;
+    else
+        return false;
+}
+
+extern "C" int re_load_all_rules()
+{
+    init_log4cplus();
+    FILE *fp = fopen(AGENT_XML_FILE, "r");
+    mxml_node_t *poolnode = NULL;
+
+    if (fp == NULL) {
+        log4cplus_error("conf: failed to open configuration '%s': %s", AGENT_XML_FILE, strerror(errno));
+        return false;
+    }
+    mxml_node_t* tree = mxmlLoadFile(NULL, fp, MXML_TEXT_CALLBACK);
+    if (tree == NULL) {
+        log4cplus_error("mxmlLoadFile error, file: %s", AGENT_XML_FILE);
+        return false;
+    }
+    fclose(fp);
+
+    for (poolnode = mxmlFindElement(tree, tree, "MODULE", NULL, NULL, MXML_DESCEND); 
+        poolnode != NULL;
+		poolnode = mxmlFindElement(poolnode, tree, "MODULE", NULL, NULL, MXML_DESCEND)) 
+    {
+        char* Mid = (char *) mxmlElementGetAttr(poolnode, "Mid");
+        if (Mid == NULL) {
+            log4cplus_error("get Mid from conf '%s' error", AGENT_XML_FILE);
+            mxmlDelete(tree);
+            return false;
+        }
+        int imid = atoi(Mid);
+
+        char* Name = (char *) mxmlElementGetAttr(poolnode, "Name");
+        if (Name == NULL) {
+            log4cplus_error("get Name from conf '%s' error", AGENT_XML_FILE);
+            mxmlDelete(tree);
+            return false;
+        }
+
+        std::string buf = load_dtc_yaml_buffer(imid);
+        if(buf.length() > 0)
+        {
+            log4cplus_debug("push %s into map.", Name);
+            std::string strname = Name;
+            transform(strname.begin(),strname.end(),strname.begin(),::toupper);
+            g_map_dtc_yaml[strname] = buf;
+        }
+        else
+        {
+            log4cplus_error("get dtc: %d yaml buffer error.", imid);
+            return -2;
+        }
+
+    }
+
+    mxmlDelete(tree);
+
+    return 0;
+}
+
+extern "C" int rule_sql_match(const char* szsql, const char* osql, const char* dbsession, char* out_dtckey, int* out_keytype)
 {
     if(!szsql)
         return -1;
         
-    std::string key = "";
+    std::string dtc_key = "";
     std::string sql = szsql;
-    bool flag = false;
 
     init_log4cplus();
 
-    if(conf)
-    {
-        conf_file = std::string(conf);
-        flag = true;
-        log4cplus_debug("flag is true, conf: %s", conf);
-        key = get_key_info(conf_file);
-        if(key.length() == 0)
-            return -1;
-    }
-
-    log4cplus_debug("key len: %d, key: %s, sql len: %d, sql: %s, dbname len: %d, dbname: %s", key.length(), key.c_str(), key.length(), osql, strlen(dbname), std::string(dbname).c_str());
+    log4cplus_debug("input sql: %s", osql);
 
     if(sql.find("WITHOUT@@") != sql.npos)
     {
@@ -449,7 +503,7 @@ extern "C" int rule_sql_match(const char* szsql, const char* osql, const char* d
     if(is_show_db_with_ast(&sql_ast))
     {
         log4cplus_debug("layered: L3, SHOW statment.");
-        return 3;
+        return 2;
     }
 
     if(is_set_with_ast(&sql_ast))
@@ -460,7 +514,7 @@ extern "C" int rule_sql_match(const char* szsql, const char* osql, const char* d
 
     if(is_show_table_with_ast(&sql_ast))
     {
-        if(exist_session_db(dbname))
+        if(exist_session_db(dbsession))
         {
             log4cplus_debug("layered: L2, session db.");
             return 2;
@@ -474,18 +528,43 @@ extern "C" int rule_sql_match(const char* szsql, const char* osql, const char* d
 
     if(is_show_create_table_with_ast(&sql_ast))
     {
-        log4cplus_debug("layered: L2, show create table.");
-        return 2;
+        if(exist_session_db(dbsession))
+        {
+            log4cplus_debug("layered: L2, show create table.");
+            return 2;
+        }
+        else
+        {
+            log4cplus_debug("layered: error, no session db.");
+            return -6;
+        }
     }
 
-    log4cplus_debug("flag: %d %d %d", flag, exist_session_db(dbname), exist_sql_db(&sql_ast));
-    if((exist_session_db(dbname) || (exist_sql_db(&sql_ast))) && flag == false)
+    std::string db_dot_name = get_table_with_db(dbsession, szsql);
+    if(db_dot_name.length() > 0 && g_map_dtc_yaml.count(db_dot_name) > 0)
+    {
+        dtc_key = get_key_info(g_map_dtc_yaml[db_dot_name]);
+        if(dtc_key.length() == 0)
+        {
+            log4cplus_error("get dtc_key from yaml:%s failed.", db_dot_name.c_str());
+            return -1;
+        }
+        strcpy(out_dtckey, dtc_key.c_str());
+        *out_keytype = rule_get_key_type(g_map_dtc_yaml[db_dot_name]);
+    }
+        log4cplus_debug("dtc key len: %d, key: %s, dbname len: %d, dbname: %s", dtc_key.length(), dtc_key.c_str(), strlen(dbsession), std::string(dbsession).c_str());
+
+    log4cplus_debug("Is dtc instance: %d %d %d", is_dtc_instance(dtc_key), exist_session_db(dbsession), exist_sql_db(&sql_ast));
+    if((exist_session_db(dbsession) || (exist_sql_db(&sql_ast))) && !is_dtc_instance(dtc_key))
     {
         log4cplus_debug("layered: L2, db session & single table");
         return 2;
     }
 
-    int ret = re_load_rule();
+    vector<vector<hsql::Expr*> > expr_rules;
+    expr_rules.clear();
+    hsql::SQLParserResult rule_ast;
+    int ret = re_load_rule(g_map_dtc_yaml[db_dot_name], &rule_ast, &expr_rules);
     if(ret != 0)
     {
         log4cplus_error("load rule error:%d", ret);
@@ -542,37 +621,15 @@ extern "C" int rule_sql_match(const char* szsql, const char* osql, const char* d
         log4cplus_debug("temsql: %s", tempsql.c_str());
         ast = &ast2;
     }
-
-    int ext = is_ext_table(&sql_ast, dbname);
-    if(ext == -1)
-    {
-        log4cplus_debug("layered: L2, ext table.");
-        return 2;
-    }
-    else if(ext == -2)
-    {
-        log4cplus_debug("layered: error.");
-        return -2;
-    }
-        
+       
     ret = re_match_sql(&sql_ast, expr_rules, ast);  //rule match
     if(ret == 0 || is_update_delete_type(&sql_ast))
     {
-        if(re_is_cache_sql(&sql_ast, key))  //if exist dtc key.
+        if(re_is_cache_sql(&sql_ast, dtc_key))  //if exist dtc key.
         {
             //L1: DTC cache.
-            std::string tab_name = get_table_name(&sql_ast);
-            std::string conf_tab_name = re_load_table_name();
-            if(tab_name == conf_tab_name)
-            {
-                log4cplus_debug("layered: L1.");
-                return 1;
-            }
-            else
-            {
-                log4cplus_error("layered: L3, table name dismatch: %s, %s", tab_name.c_str(), conf_tab_name.c_str());
-                return 3;
-            }
+            log4cplus_debug("layered: L1.");
+            return 1;
         }
         else
         {

+ 2 - 4
src/rule/rule.h

@@ -3,13 +3,11 @@
 extern "C"{
 #endif
 
-    int rule_sql_match(const char* szsql, const char* osql, const char* dbname, const char* conf);
-    int re_load_table_key(char* key);
+    int rule_sql_match(const char* szsql, const char* osql, const char* dbsession, char* out_dtckey, int* out_keytype);
     int sql_parse_table(const char* szsql, char* out);
-    int rule_get_key_type(const char* conf);
     int rule_get_key(const char* conf, char* out);
     int get_table_with_db(const char* sessiondb, const char* sql, char* result);    
-
+    int re_load_all_rules();
 #ifdef __cplusplus    
 }
 #endif