Browse Source

word segment algorithm

shzhulin3 2 years ago
parent
commit
7943c89b58

+ 39 - 0
src/comm/segment/bmm_segment.cc

@@ -0,0 +1,39 @@
+#include "bmm_segment.h"
+
+// Nothing to set up here; BmmSegment declares no data members of its own.
+BmmSegment::BmmSegment() {}
+
+// Nothing to release here; BmmSegment declares no data members of its own.
+BmmSegment::~BmmSegment() {}
+
+/*
+ * Backward maximum matching: scans `phrase` from its end towards its start.
+ * At each step the window of up to MAX_WORD_LEN characters ending at the
+ * current position is examined, longest suffix first; the first suffix of two
+ * or more characters accepted by wordValid() becomes a word, otherwise the
+ * last character of the window is emitted on its own.
+ *
+ * phrase   - text to segment (indexed in characters, not bytes).
+ * appid    - application id forwarded to wordValid().
+ * bmm_list - receives the words in text order; they are placed in front of
+ *            whatever the vector already holds.
+ */
+void BmmSegment::ConcreteSplit(iutf8string& phrase, uint32_t appid, vector<string>& bmm_list){
+    const int maxlen = MAX_WORD_LEN;
+    // Words are discovered right-to-left. Collect them here and splice them
+    // into bmm_list once, instead of paying for a front insertion per word.
+    vector<string> reversed_words;
+    int i = phrase.length();
+
+    while (i > 0) {
+        int start = i - maxlen;
+        if (start < 0)
+            start = 0;
+        iutf8string phrase_sub = phrase.utf8substr(start, i - start);
+        const int sub_len = phrase_sub.length();
+        bool matched = false;
+        // j is the start of the candidate suffix; the one-character suffix is
+        // not looked up, it is the fallback below.
+        for (int j = 0; j < sub_len - 1; j++) {
+            iutf8string key = phrase_sub.utf8substr(j, sub_len - j);
+            if (wordValid(key.stlstring(), appid) == true) {
+                reversed_words.push_back(key.stlstring());
+                i -= key.length();
+                matched = true;
+                break;
+            }
+        }
+        if (!matched) {
+            reversed_words.push_back(phrase_sub[sub_len - 1]);
+            i -= 1;
+        }
+    }
+    bmm_list.insert(bmm_list.begin(), reversed_words.rbegin(), reversed_words.rend());
+    return;
+}

+ 35 - 0
src/comm/segment/bmm_segment.h

@@ -0,0 +1,35 @@
+/*
+ * =====================================================================================
+ *
+ *       Filename:  bmm_segment.h
+ *
+ *    Description:  bmm_segment class definition.
+ *
+ *        Version:  1.0
+ *        Created:  08/06/2021
+ *       Revision:  none
+ *       Compiler:  gcc
+ *
+ *         Author:  zhulin, shzhulin3@jd.com
+ *        Company:  JD.com, Inc.
+ *
+ * =====================================================================================
+ */
+
+#ifndef __BMM_SEGMENT_H__
+#define __BMM_SEGMENT_H__
+#include "segment.h"
+
+// Segmenter whose ConcreteSplit() is implemented in bmm_segment.cc as
+// backward maximum matching.
+class BmmSegment: public Segment
+{
+public:
+    BmmSegment();
+    ~BmmSegment();
+    // Implements the pure virtual Segment::ConcreteSplit().
+    virtual void ConcreteSplit(iutf8string& phrase, uint32_t appid, vector<string>& vec);
+};
+
+
+
+#endif

+ 59 - 0
src/comm/segment/custom_segment.cc

@@ -0,0 +1,59 @@
+#include "custom_segment.h"
+#include <dlfcn.h>
+#include "../log.h"
+
+static const char *cache_configfile = "../conf/cache.conf";
+
+// Allocates the config reader and marks the plugin entry point as "not loaded"
+// so it never holds an indeterminate value before Init() succeeds.
+CustomSegment::CustomSegment()
+    : cache_config_(new CConfig()),
+      word_split_func_(NULL)
+{
+}
+
+// Releases the config reader allocated by the constructor.
+CustomSegment::~CustomSegment(){
+    delete cache_config_;  // deleting a null pointer is a no-op
+}
+
+/*
+ * Runs the base-class initialisation, then loads the word-split plugin named
+ * in the "search_cache" section of the cache config file:
+ *   WordSplitSo       - path of the shared library,
+ *   WordSplitFunction - symbol to resolve inside it.
+ *
+ * Returns true when the symbol was resolved and stored in word_split_func_,
+ * false (after logging) on any configuration or loading error.
+ */
+bool CustomSegment::Init(string word_path, string train_path){
+    // NOTE(review): the result of Segment::Init() is not checked, exactly as
+    // before; confirm whether a failed base initialisation should abort here.
+    Segment::Init(word_path, train_path);
+    if (cache_config_->ParseConfig(cache_configfile, "search_cache")) {
+        log_error("no cache config or config file [%s] is error", cache_configfile);
+        return false;
+    }
+    // read the plugin location from the config file
+    const char* so = cache_config_->GetStrVal("search_cache", "WordSplitSo");
+    if (so == NULL) {
+        log_error("has no so.");
+        return false;
+    }
+    const char* fun = cache_config_->GetStrVal("search_cache", "WordSplitFunction");
+    if (fun == NULL) {
+        log_error("has no function.");
+        return false;
+    }
+    void* dll = dlopen(so, RTLD_NOW|RTLD_GLOBAL);
+    if (dll == NULL) {
+        // dlerror() may return NULL; never hand that to a %s conversion.
+        const char* err = dlerror();
+        log_error("dlopen(%s) error: %s", so, err != NULL ? err : "unknown error");
+        return false;
+    }
+    word_split_func_ = (split_interface)dlsym(dll, fun);
+    if (word_split_func_ == NULL) {
+        log_error("word-split plugin function[%s] not found in [%s]", fun, so);
+        dlclose(dll);  // nothing from the library is kept, so do not leak the handle
+        return false;
+    }
+    // On success the handle is deliberately left open: word_split_func_
+    // points into the loaded library.
+    return true;
+}
+
+/*
+ * Delegates segmentation to the plugin function loaded by Init(). The plugin
+ * writes its result into a 100-byte buffer; that text is split on single
+ * spaces and every non-empty piece is appended to vec.
+ *
+ * phrase - text handed to the plugin as a C string.
+ * appid  - unused by this implementation.
+ * vec    - receives the tokens, appended after any existing content.
+ */
+void CustomSegment::ConcreteSplit(iutf8string& phrase, uint32_t appid, vector<string>& vec){
+    if (word_split_func_ == NULL) {
+        log_error("word-split plugin function is not loaded.");
+        return;
+    }
+    char res[100] = {'\0'};
+    word_split_func_(phrase.stlstring().c_str(), res, 100);
+    // Do not rely on the plugin to terminate its output.
+    res[sizeof(res) - 1] = '\0';
+
+    string tmp;
+    for (const char* p = res; *p != '\0'; ++p) {
+        if (*p != ' ') {
+            tmp += *p;
+        } else if (!tmp.empty()) {
+            vec.push_back(tmp);
+            tmp.clear();
+        }
+    }
+    // The final token has no trailing space to trigger the push above.
+    if (!tmp.empty()) {
+        vec.push_back(tmp);
+    }
+}

+ 40 - 0
src/comm/segment/custom_segment.h

@@ -0,0 +1,40 @@
+/*
+ * =====================================================================================
+ *
+ *       Filename:  custom_segment.h
+ *
+ *    Description:  custom_segment class definition.
+ *
+ *        Version:  1.0
+ *        Created:  15/06/2021
+ *       Revision:  none
+ *       Compiler:  gcc
+ *
+ *         Author:  zhulin, shzhulin3@jd.com
+ *        Company:  JD.com, Inc.
+ *
+ * =====================================================================================
+ */
+
+#ifndef __CUSTOM_SEGMENT_H__
+#define __CUSTOM_SEGMENT_H__
+#include "segment.h"
+#include "../config.h"
+
+typedef void (*split_interface)(const char* str, char* res, int len);
+
+// Segmenter that forwards ConcreteSplit() to a function resolved at runtime
+// from a shared library (see Init() in custom_segment.cc).
+class CustomSegment: public Segment
+{
+public:
+    CustomSegment();
+    ~CustomSegment();
+    virtual bool Init(string word_path, string train_path);
+    virtual void ConcreteSplit(iutf8string& phrase, uint32_t appid, vector<string>& vec);
+private:
+    // Not copyable: the destructor deletes cache_config_, so a member-wise
+    // copy would delete it twice. Declared but intentionally not defined.
+    CustomSegment(const CustomSegment&);
+    CustomSegment& operator=(const CustomSegment&);
+
+    CConfig* cache_config_;            // created in the constructor, deleted in the destructor
+    split_interface word_split_func_;  // plugin entry point set by Init()
+};
+
+
+
+#endif

+ 102 - 0
src/comm/segment/dag_segment.cc

@@ -0,0 +1,102 @@
+#include "dag_segment.h"
+#include <math.h>
+#include <float.h>
+
+// Nothing to set up here; DagSegment declares no data members of its own.
+DagSegment::DagSegment() {}
+
+// Nothing to release here; DagSegment declares no data members of its own.
+DagSegment::~DagSegment() {}
+
+/*
+ * Segments `sentence` by building the word graph (getDag), scoring it from
+ * the end backwards (calc) and then walking the best route from position 0.
+ * Consecutive pieces made only of ASCII letters/digits are glued together
+ * into one token; every other piece is emitted as its own token.
+ *
+ * sentence - text to segment (indexed in characters).
+ * appid    - application id forwarded to the dictionary lookups.
+ * vec      - receives the tokens, appended in text order.
+ */
+void DagSegment::ConcreteSplit(iutf8string& sentence, uint32_t appid, vector<string>& vec){
+    map<uint32_t, vector<uint32_t> > dag_map;
+    getDag(sentence, appid, dag_map);
+    map<uint32_t, RouteValue> route;
+    calc(sentence, dag_map, route, appid);
+
+    const uint32_t N = sentence.length();
+    string buf;
+    uint32_t i = 0;
+    while (i < N) {
+        uint32_t j = route[i].idx + 1;
+        // A route entry that does not move forward would make j - i wrap
+        // around and the loop never end; fall back to a one-character step.
+        if (j <= i) {
+            j = i + 1;
+        }
+        if (j > N) {
+            j = N;
+        }
+        const string l_word = sentence.substr(i, j - i);
+        if (isAllAlphaOrDigit(l_word)) {
+            buf += l_word;
+        } else {
+            if (!buf.empty()) {
+                vec.push_back(buf);
+                buf = "";
+            }
+            vec.push_back(l_word);
+        }
+        i = j;
+    }
+    if (!buf.empty()) {
+        vec.push_back(buf);
+    }
+
+    return;
+}
+
+/*
+ * Builds the word graph of utf8_str: for every start index k, map_dag[k]
+ * lists each end index i (inclusive) such that characters k..i pass
+ * wordValid(). When no such word starts at k, the list holds k itself, so
+ * every position has at least the single-character edge.
+ */
+void DagSegment::getDag(iutf8string& utf8_str, uint32_t appid, map<uint32_t, vector<uint32_t> >& map_dag) {
+    const uint32_t N = utf8_str.length();
+    for (uint32_t k = 0; k < N; k++) {
+        vector<uint32_t> tmplist;
+        string frag = utf8_str[k];
+        uint32_t i = k;
+        while (true) {
+            if (wordValid(frag, appid) == true) {
+                tmplist.push_back(i);
+            }
+            // Stop before extending: the previous form asked for a substring
+            // reaching one character past the end of the text on the last pass.
+            if (++i >= N) {
+                break;
+            }
+            frag = utf8_str.substr(k, i + 1 - k);
+        }
+        if (tmplist.empty()) {
+            tmplist.push_back(k);
+        }
+        map_dag[k] = tmplist;
+    }
+    return;
+}
+
+/*
+ * Scores the word graph from the last character back to the first.
+ * For each start index i, route[i] receives
+ *   max_route - the best value of log(freq) - log(TOTAL) + route[end + 1].max_route
+ *               over all end indices listed in map_dag[i],
+ *   idx       - the end index that achieved it.
+ * The frequency comes from word_dict_: the entry for appid 0 is used unless
+ * an entry for `appid` exists, which takes precedence; words without an
+ * entry count as frequency 1.
+ */
+void DagSegment::calc(iutf8string& utf8_str, const map<uint32_t, vector<uint32_t> >& map_dag, map<uint32_t, RouteValue>& route, uint32_t appid) {
+    const uint32_t N = utf8_str.length();
+    route[N] = RouteValue();
+    const double logtotal = log(TOTAL);
+    for (int i = (int)N - 1; i > -1; i--) {
+        const vector<uint32_t>& ends = map_dag.at(i);
+        double max_route = -DBL_MAX;
+        // Default to the single-character step so the stored index can never
+        // point behind i, even if no candidate scores above -DBL_MAX.
+        uint32_t max_idx = i;
+        for (size_t t = 0; t < ends.size(); t++) {
+            const string word = utf8_str.substr(i, ends[t] + 1 - i);
+            uint32_t word_freq = 1;
+            map<string, map<uint32_t, WordInfo> >::const_iterator entry = word_dict_.find(word);
+            if (entry != word_dict_.end()) {
+                const map<uint32_t, WordInfo>& per_app = entry->second;
+                map<uint32_t, WordInfo>::const_iterator hit = per_app.find(0);
+                if (hit != per_app.end()) {
+                    word_freq = hit->second.word_freq;
+                }
+                hit = per_app.find(appid);
+                if (hit != per_app.end()) {
+                    word_freq = hit->second.word_freq;
+                }
+            }
+            // A stored frequency of 0 would give log(0) = -infinity, which
+            // never wins the comparison below; treat it like an unseen word.
+            if (word_freq == 0) {
+                word_freq = 1;
+            }
+            const double route_value = log((double)word_freq) - logtotal + route[ends[t] + 1].max_route;
+            if (route_value > max_route) {
+                max_route = route_value;
+                max_idx = ends[t];
+            }
+        }
+        RouteValue best;
+        best.max_route = max_route;
+        best.idx = max_idx;
+        route[i] = best;
+    }
+}

+ 45 - 0
src/comm/segment/dag_segment.h

@@ -0,0 +1,45 @@
+/*
+ * =====================================================================================
+ *
+ *       Filename:  dag_segment.h
+ *
+ *    Description:  dag_segment class definition.
+ *
+ *        Version:  1.0
+ *        Created:  08/06/2021
+ *       Revision:  none
+ *       Compiler:  gcc
+ *
+ *         Author:  zhulin, shzhulin3@jd.com
+ *        Company:  JD.com, Inc.
+ *
+ * =====================================================================================
+ */
+
+#ifndef __DAG_SEGMENT_H__
+#define __DAG_SEGMENT_H__
+#include "segment.h"
+
+// One cell of the route table: a score and the index it was reached with.
+// Both fields start at zero.
+struct RouteValue {
+    double max_route;
+    uint32_t idx;
+    RouteValue() : max_route(0), idx(0) {}
+};
+
+// Segmenter whose ConcreteSplit() builds a word graph over the phrase and
+// follows the best-scoring route through it (see dag_segment.cc).
+class DagSegment: public Segment
+{
+public:
+    DagSegment();
+    ~DagSegment();
+    // Implements the pure virtual Segment::ConcreteSplit().
+    virtual void ConcreteSplit(iutf8string& phrase, uint32_t appid, vector<string>& vec);
+
+private:
+    // Fills dag_map with, per start index, the end indices of valid words.
+    void getDag(iutf8string& sentence, uint32_t appid, map<uint32_t, vector<uint32_t> >& dag_map);
+    // Fills route with the best score and chosen end index per start index.
+    void calc(iutf8string& sentence, const map<uint32_t, vector<uint32_t> >& dag_map, map<uint32_t, RouteValue>& route, uint32_t appid);
+};
+
+
+
+#endif

+ 37 - 0
src/comm/segment/fmm_segment.cc

@@ -0,0 +1,37 @@
+#include "fmm_segment.h"
+
+// Nothing to set up here; FmmSegment declares no data members of its own.
+FmmSegment::FmmSegment() {}
+
+// Nothing to release here; FmmSegment declares no data members of its own.
+FmmSegment::~FmmSegment() {}
+
+/*
+ * Forward maximum matching: walks `phrase` left to right. At each position
+ * the window of up to MAX_WORD_LEN characters is tried longest prefix first;
+ * the first prefix of two or more characters accepted by wordValid() is
+ * appended to fmm_list, otherwise the single character at the position is.
+ */
+void FmmSegment::ConcreteSplit(iutf8string& phrase, uint32_t appid, vector<string>& fmm_list){
+    const int total = phrase.length();
+    int pos = 0;
+
+    while (pos < total) {
+        int stop = pos + MAX_WORD_LEN;
+        if (stop >= total)
+            stop = total;
+        iutf8string window = phrase.utf8substr(pos, stop - pos);
+
+        // `take` is the prefix length being tried; it reaches 1 only when no
+        // longer prefix was accepted.
+        int take = window.length();
+        while (take > 1) {
+            iutf8string candidate = window.utf8substr(0, take);
+            if (wordValid(candidate.stlstring(), appid) == true) {
+                fmm_list.push_back(candidate.stlstring());
+                pos += candidate.length() - 1;
+                break;
+            }
+            --take;
+        }
+        if (take == 1) {
+            fmm_list.push_back(window[0]);
+        }
+        pos += 1;
+    }
+    return;
+}

+ 35 - 0
src/comm/segment/fmm_segment.h

@@ -0,0 +1,35 @@
+/*
+ * =====================================================================================
+ *
+ *       Filename:  fmm_segment.h
+ *
+ *    Description:  fmm_segment class definition.
+ *
+ *        Version:  1.0
+ *        Created:  08/06/2021
+ *       Revision:  none
+ *       Compiler:  gcc
+ *
+ *         Author:  zhulin, shzhulin3@jd.com
+ *        Company:  JD.com, Inc.
+ *
+ * =====================================================================================
+ */
+
+#ifndef __FMM_SEGMENT_H__
+#define __FMM_SEGMENT_H__
+#include "segment.h"
+
+// Segmenter whose ConcreteSplit() is implemented in fmm_segment.cc as
+// forward maximum matching.
+class FmmSegment: public Segment
+{
+public:
+    FmmSegment();
+    ~FmmSegment();
+    // Implements the pure virtual Segment::ConcreteSplit().
+    virtual void ConcreteSplit(iutf8string& phrase, uint32_t appid, vector<string>& vec);
+};
+
+
+
+#endif

+ 159 - 0
src/comm/segment/hmm_manager.cc

@@ -0,0 +1,159 @@
+#include "hmm_manager.h"
+
+#include <fstream>
+#include "../log.h"
+#include "../utf8_str.h"
+
+// Allocates the training corpus and zeroes the word counter, which Init()
+// increments with `+=` and therefore must not start indeterminate.
+HmmManager::HmmManager()
+    : train_cnt_(0),
+      train_corpus_(new TrainCorpus())
+{
+}
+
+// Releases the training corpus allocated by the constructor.
+HmmManager::~HmmManager(){
+    delete train_corpus_;  // deleting a null pointer is a no-op
+}
+
+/*
+ * Reads the training file (one space-separated sentence per line) and
+ *   - adds the number of kept words to train_cnt_,
+ *   - counts in next_dict_[w1][w2] how often w2 directly follows w1, with
+ *     "<BEG>" before the first and "<END>" after the last word of a line,
+ * then initialises train_corpus_ from the same file.
+ *
+ * train_path - path of the training file.
+ * punct_set  - tokens to ignore; empty tokens are ignored as well.
+ * Returns false (after logging) if the file cannot be opened or the corpus
+ * fails to initialise.
+ */
+bool HmmManager::Init(string train_path, const set<string>& punct_set) {
+    ifstream train_infile;
+    train_infile.open(train_path.c_str());
+    if (train_infile.is_open() == false) {
+        log_error("open file error: %s.\n", train_path.c_str());
+        return false;
+    }
+    const string beg_tag = "<BEG>";
+    const string end_tag = "<END>";
+    string str;
+    while (getline(train_infile, str))
+    {
+        vector<string> str_vec = splitEx(str, " ");
+        vector<string> line_list;
+        for (vector<string>::iterator iter = str_vec.begin(); iter != str_vec.end(); ++iter) {
+            if (punct_set.find(*iter) == punct_set.end() && *iter != "") {
+                line_list.push_back(*iter);
+            }
+        }
+        // A blank or punctuation-only line has no words; the pairing below
+        // would otherwise read line_list[0] of an empty vector.
+        if (line_list.empty()) {
+            continue;
+        }
+        train_cnt_ += line_list.size();
+
+        // Walk the pairs <BEG>,w0  w0,w1 ... w(n-1),<END>.
+        const string* prev = &beg_tag;
+        for (size_t k = 0; k <= line_list.size(); k++) {
+            const string& next = (k < line_list.size()) ? line_list[k] : end_tag;
+            // operator[] creates a missing counter as 0 before the increment.
+            ++next_dict_[*prev][next];
+            prev = &next;
+        }
+    }
+    train_infile.close();
+
+    bool ret = train_corpus_->Init(train_path);
+    if (ret == false) {
+        log_error("train_corpus init error.");
+        return ret;
+    }
+    log_info("total training words length is: %u, next_dict count: %d.", train_cnt_, (int)next_dict_.size());
+
+    return true;
+}
+
+/*
+ * Splits `str` into words using the tag sequence returned by viterbi():
+ * a word ends after every character tagged 'E' or 'S'; characters left over
+ * at the end form a final word.
+ *
+ * str - text to split.
+ * vec - replaced with the resulting words. If viterbi() yields no tag
+ *       sequence matching the text, each character becomes its own word so
+ *       that no text is lost.
+ */
+void HmmManager::HmmSplit(string str, vector<string>& vec){
+    vec.clear();
+    iutf8string utf8_str(str);
+    const size_t char_cnt = (size_t)utf8_str.length();
+    if (char_cnt == 0) {
+        return;
+    }
+
+    vector<char> pos_list = viterbi(str);
+    if (pos_list.size() != char_cnt) {
+        for (size_t i = 0; i < char_cnt; i++) {
+            vec.push_back(utf8_str[i]);
+        }
+        return;
+    }
+
+    string word;
+    for (size_t i = 0; i < char_cnt; i++) {
+        const string ch = utf8_str[i];
+        // A space in the input still acts as a separator, as it did when the
+        // result was assembled as one space-joined string.
+        if (ch != " ") {
+            word += ch;
+        }
+        if (ch == " " || pos_list[i] == 'E' || pos_list[i] == 'S') {
+            if (!word.empty()) {
+                vec.push_back(word);
+                word.clear();
+            }
+        }
+    }
+    if (!word.empty()) {
+        vec.push_back(word);
+    }
+}
+
+/*
+ * Viterbi decoding over the four states 'B', 'M', 'E', 'S' using the start,
+ * transition and emission tables of train_corpus_. A character missing from
+ * a state's emission table uses train_corpus_->MinEmit() instead.
+ *
+ * Returns one state per character of `sentence` (the most probable
+ * sequence), or an empty vector for an empty sentence.
+ */
+vector<char> HmmManager::viterbi(string sentence) {
+    iutf8string utf8_str(sentence);
+    const int char_cnt = utf8_str.length();
+    if (char_cnt <= 0) {
+        return vector<char>();
+    }
+
+    const char states[4] = { 'B','M','E','S' };
+    const size_t state_cnt = sizeof(states);
+    vector<map<char, double> > viterbi_vec;
+    map<char, vector<char> > path;
+    map<char, double> prob_map;
+    for (size_t i = 0; i < state_cnt; i++) {
+        char y = states[i];
+        double emit_value = train_corpus_->MinEmit();
+        if (train_corpus_->emit_dict[y].find(utf8_str[0]) != train_corpus_->emit_dict[y].end()) {
+            emit_value = train_corpus_->emit_dict[y].at(utf8_str[0]);
+        }
+        // best probability of a state sequence ending in y at position 0
+        prob_map[y] = train_corpus_->start_dict[y] * emit_value;
+        path[y].push_back(y);
+    }
+    viterbi_vec.push_back(prob_map);
+
+    for (int j = 1; j < char_cnt; j++) {
+        map<char, vector<char> > new_path;
+        prob_map.clear();
+        for (size_t k = 0; k < state_cnt; k++) {
+            char y = states[k];
+            // The emission term depends only on (y, j); look it up once.
+            double emit_value = train_corpus_->MinEmit();
+            if (train_corpus_->emit_dict[y].find(utf8_str[j]) != train_corpus_->emit_dict[y].end()) {
+                emit_value = train_corpus_->emit_dict[y].at(utf8_str[j]);
+            }
+            double max_prob = 0.0;
+            char state = states[0];
+            for (size_t m = 0; m < state_cnt; m++) {
+                char y0 = states[m];  // transition y0 -> y
+                double prob = viterbi_vec[j - 1][y0] * train_corpus_->trans_dict[y0][y] * emit_value;
+                // Always accept the first candidate: if every product is 0
+                // (e.g. after underflow) a real predecessor is still chosen
+                // and the path keeps one entry per character.
+                if (m == 0 || prob > max_prob) {
+                    max_prob = prob;
+                    state = y0;
+                }
+            }
+            prob_map[y] = max_prob;
+            new_path[y] = path[state];
+            new_path[y].push_back(y);
+        }
+        viterbi_vec.push_back(prob_map);
+        path = new_path;
+    }
+
+    double max_prob = 0.0;
+    char state = states[0];
+    for (size_t i = 0; i < state_cnt; i++) {
+        char y = states[i];
+        const double prob = viterbi_vec[char_cnt - 1][y];
+        if (i == 0 || prob > max_prob) {
+            max_prob = prob;
+            state = y;
+        }
+    }
+    return path[state];
+}
+
+// Number of training words accumulated by Init().
+uint32_t HmmManager::TrainCnt()
+{
+    return train_cnt_;
+}
+
+// Mutable access to the follower counts: NextDict()[w1][w2] is how often w2
+// directly followed w1 in the training file.
+map<string, map<string, int> >& HmmManager::NextDict()
+{
+    return next_dict_;
+}

+ 46 - 0
src/comm/segment/hmm_manager.h

@@ -0,0 +1,46 @@
+/*
+ * =====================================================================================
+ *
+ *       Filename:  hmm_manager.h
+ *
+ *    Description:  hmm manager class definition.
+ *
+ *        Version:  1.0
+ *        Created:  09/06/2021
+ *       Revision:  none
+ *       Compiler:  gcc
+ *
+ *         Author:  zhulin, shzhulin3@jd.com
+ *        Company:  JD.com, Inc.
+ *
+ * =====================================================================================
+ */
+
+#ifndef __HMM_MANAGER_H__
+#define __HMM_MANAGER_H__
+
+#include <string>
+#include <set>
+#include <vector>
+#include <map>
+#include "../trainCorpus.h"
+using namespace std;
+
+// Holds the statistics read from the training file (word count and follower
+// counts) and splits text with a Viterbi decoder over train_corpus_.
+class HmmManager{
+public:
+    HmmManager();
+    ~HmmManager();
+    bool Init(string train_path, const set<string>& punct_set);
+    void HmmSplit(string str, vector<string>& vec);
+    map<string, map<string, int> >& NextDict();
+    uint32_t TrainCnt();
+private:
+    // Not copyable: the destructor deletes train_corpus_, so a member-wise
+    // copy would delete it twice. Declared but intentionally not defined.
+    HmmManager(const HmmManager&);
+    HmmManager& operator=(const HmmManager&);
+
+    vector<char> viterbi(string sentence);
+private:
+    uint32_t train_cnt_;                        // words counted by Init()
+    TrainCorpus* train_corpus_;                 // created in the constructor, deleted in the destructor
+    map<string, map<string, int> > next_dict_;  // next_dict_[w1][w2] = times w2 followed w1
+};
+
+
+#endif

+ 220 - 0
src/comm/segment/ngram_segment.cc

@@ -0,0 +1,220 @@
+#include "ngram_segment.h"
+#include <math.h>
+
+// Nothing to set up here; NgramSegment declares no data members of its own.
+NgramSegment::NgramSegment() {}
+
+// Nothing to release here; NgramSegment declares no data members of its own.
+NgramSegment::~NgramSegment() {}
+
+/*
+ * Segments utf8_str with both forward (fmm) and backward (bmm) maximum
+ * matching and merges the two results. Both word lists are wrapped in
+ * "<BEG>" / "<END>" and walked in parallel by byte offset:
+ *   - where a word starts and ends at the same offsets in both lists, it is
+ *     accepted as is;
+ *   - the differing stretches in between are collected separately, scored
+ *     with calSegProbability() (with the previous accepted word and, unless
+ *     it is "<END>", the next common word as context) and the forward
+ *     variant is taken only when its score is strictly higher.
+ *
+ * parse_list receives the merged words, appended after any existing content
+ * and without the "<BEG>" / "<END>" markers.
+ */
+void NgramSegment::ConcreteSplit(iutf8string& utf8_str, uint32_t appid, vector<string>& parse_list){
+    vector<string> fwd_words;
+    vector<string> bwd_words;
+    fmm(utf8_str, appid, fwd_words);
+    bmm(utf8_str, appid, bwd_words);
+    fwd_words.insert(fwd_words.begin(), "<BEG>");
+    fwd_words.push_back("<END>");
+    bwd_words.insert(bwd_words.begin(), "<BEG>");
+    bwd_words.push_back("<END>");
+
+    // Built locally so the marker removal at the end can never touch
+    // elements the caller already had in parse_list.
+    vector<string> merged;
+    // Words of the current stretch on which the two segmentations disagree.
+    vector<string> fwd_diff;
+    vector<string> bwd_diff;
+    // Byte offset of the current word and its index, per list.
+    size_t fwd_pos = 0;
+    size_t bwd_pos = 0;
+    size_t fwd_cur = 0;
+    size_t bwd_cur = 0;
+
+    // Stop as soon as either list is exhausted so neither is indexed past
+    // its end.
+    while (fwd_cur < fwd_words.size() && bwd_cur < bwd_words.size()) {
+        const string& fwd_word = fwd_words[fwd_cur];
+        const string& bwd_word = bwd_words[bwd_cur];
+        const size_t fwd_end = fwd_pos + fwd_word.size();
+        const size_t bwd_end = bwd_pos + bwd_word.size();
+
+        if (fwd_pos == bwd_pos && fwd_end == bwd_end) {
+            // Same start and same length: the lists agree on this word.
+            if (!fwd_diff.empty()) {
+                // A disagreement just ended; choose between the two variants.
+                const bool at_last = (fwd_cur + 1 >= fwd_words.size());
+                fwd_diff.insert(fwd_diff.begin(), merged.back());
+                bwd_diff.insert(bwd_diff.begin(), merged.back());
+                if (!at_last) {
+                    fwd_diff.push_back(fwd_word);
+                    bwd_diff.push_back(bwd_word);
+                }
+                const double p1 = calSegProbability(fwd_diff);
+                const double p2 = calSegProbability(bwd_diff);
+                const vector<string>& winner = (p1 > p2) ? fwd_diff : bwd_diff;
+                // Drop the context word in front and, if one was added, the
+                // context word at the back.
+                merged.insert(merged.end(), winner.begin() + 1,
+                              at_last ? winner.end() : winner.end() - 1);
+                fwd_diff.clear();
+                bwd_diff.clear();
+            }
+            merged.push_back(fwd_word);
+            fwd_pos = fwd_end;
+            bwd_pos = bwd_end;
+            fwd_cur++;
+            bwd_cur++;
+        }
+        else if (fwd_end == bwd_end) {
+            // Different starts but a common end: both words close the stretch.
+            fwd_diff.push_back(fwd_word);
+            bwd_diff.push_back(bwd_word);
+            fwd_pos = fwd_end;
+            bwd_pos = bwd_end;
+            fwd_cur++;
+            bwd_cur++;
+        }
+        else if (fwd_end > bwd_end) {
+            // The backward word ends first; advance that list.
+            bwd_diff.push_back(bwd_word);
+            bwd_pos = bwd_end;
+            bwd_cur++;
+        }
+        else {
+            // The forward word ends first; advance that list.
+            fwd_diff.push_back(fwd_word);
+            fwd_pos = fwd_end;
+            fwd_cur++;
+        }
+    }
+
+    // Defensive: should the lists ever fall out of step, keep the forward
+    // segmentation for whatever was left unresolved rather than dropping it.
+    if (fwd_cur < fwd_words.size() || !fwd_diff.empty()) {
+        merged.insert(merged.end(), fwd_diff.begin(), fwd_diff.end());
+        merged.insert(merged.end(), fwd_words.begin() + fwd_cur, fwd_words.end());
+    }
+
+    size_t first = 0;
+    size_t last = merged.size();
+    if (first < last && merged[first] == "<BEG>") {
+        ++first;
+    }
+    if (first < last && merged[last - 1] == "<END>") {
+        --last;
+    }
+    parse_list.insert(parse_list.end(), merged.begin() + first, merged.begin() + last);
+    return;
+}
+
+/*
+ * Forward maximum matching: walks `phrase` left to right. At each position
+ * the window of up to MAX_WORD_LEN characters is tried longest prefix first;
+ * the first prefix of two or more characters accepted by wordValid() is
+ * appended to fmm_list, otherwise the single character at the position is.
+ */
+void NgramSegment::fmm(iutf8string& phrase, uint32_t appid, vector<string>& fmm_list) {
+    const int total = phrase.length();
+    int pos = 0;
+
+    while (pos < total) {
+        int stop = pos + MAX_WORD_LEN;
+        if (stop >= total)
+            stop = total;
+        iutf8string window = phrase.utf8substr(pos, stop - pos);
+
+        // `take` is the prefix length being tried; it reaches 1 only when no
+        // longer prefix was accepted.
+        int take = window.length();
+        while (take > 1) {
+            iutf8string candidate = window.utf8substr(0, take);
+            if (wordValid(candidate.stlstring(), appid) == true) {
+                fmm_list.push_back(candidate.stlstring());
+                pos += candidate.length() - 1;
+                break;
+            }
+            --take;
+        }
+        if (take == 1) {
+            fmm_list.push_back(window[0]);
+        }
+        pos += 1;
+    }
+    return;
+}
+
+/*
+ * Backward maximum matching: scans `phrase` from its end towards its start.
+ * At each step the window of up to MAX_WORD_LEN characters ending at the
+ * current position is examined, longest suffix first; the first suffix of two
+ * or more characters accepted by wordValid() becomes a word, otherwise the
+ * last character of the window is emitted on its own. The words end up in
+ * text order in front of whatever bmm_list already holds.
+ */
+void NgramSegment::bmm(iutf8string& phrase, uint32_t appid, vector<string>& bmm_list) {
+    const int maxlen = MAX_WORD_LEN;
+    // Words are discovered right-to-left. Collect them here and splice them
+    // into bmm_list once, instead of paying for a front insertion per word.
+    vector<string> reversed_words;
+    int i = phrase.length();
+
+    while (i > 0) {
+        int start = i - maxlen;
+        if (start < 0)
+            start = 0;
+        iutf8string phrase_sub = phrase.utf8substr(start, i - start);
+        const int sub_len = phrase_sub.length();
+        bool matched = false;
+        // j is the start of the candidate suffix; the one-character suffix is
+        // not looked up, it is the fallback below.
+        for (int j = 0; j < sub_len - 1; j++) {
+            iutf8string key = phrase_sub.utf8substr(j, sub_len - j);
+            if (wordValid(key.stlstring(), appid) == true) {
+                reversed_words.push_back(key.stlstring());
+                i -= key.length();
+                matched = true;
+                break;
+            }
+        }
+        if (!matched) {
+            reversed_words.push_back(phrase_sub[sub_len - 1]);
+            i -= 1;
+        }
+    }
+    bmm_list.insert(bmm_list.begin(), reversed_words.rbegin(), reversed_words.rend());
+    return;
+}
+
+/*
+ * Log-probability of a word sequence under an add-one-smoothed bigram model.
+ * Sums, in log space so the product of small probabilities does not vanish:
+ *   - for every adjacent pair (w1, w2):
+ *       log((count(w1 -> w2) + 1) / (TrainCnt + sum of all followers of w1)),
+ *       or log(1 / TrainCnt) when w1 has no follower table;
+ *   - for the first real word (vec[0], or vec[1] when vec[0] is "<BEG>"):
+ *       log((freq + 1) / (number of words with followers + TrainCnt)),
+ *       where freq is the dictionary frequency (0 if the word is unknown).
+ */
+double NgramSegment::calSegProbability(const vector<string>& vec) {
+    double p = 0;
+    map<string, map<string, int> >& next_dict = hmm_manager_->NextDict();
+    const double train_cnt = hmm_manager_->TrainCnt();
+    const double unigram_total = next_dict.size() + train_cnt;
+
+    for (size_t pos = 0; pos < vec.size(); pos++) {
+        if (pos + 1 < vec.size()) {
+            // conditional probability of the following word
+            const string& word1 = vec[pos];
+            const string& word2 = vec[pos + 1];
+            map<string, map<string, int> >::const_iterator followers = next_dict.find(word1);
+            if (followers == next_dict.end()) {
+                // add-one smoothing for an unseen first word
+                p += log(1.0 / train_cnt);
+            }
+            else {
+                double numerator = 1.0;
+                double denominator = train_cnt;
+                map<string, int>::const_iterator iter = followers->second.begin();
+                for (; iter != followers->second.end(); ++iter) {
+                    if (iter->first == word2) {
+                        numerator += iter->second;
+                    }
+                    denominator += iter->second;
+                }
+                p += log(numerator / denominator);
+            }
+        }
+        // probability of the first real word
+        if ((pos == 0 && vec[pos] != "<BEG>") || (pos == 1 && vec[0] == "<BEG>")) {
+            double word_freq = 0;
+            WordInfo word_info;
+            if (getWordInfo(vec[pos], 0, word_info)) {
+                word_freq = word_info.word_freq;
+            }
+            // The whole sum is the divisor. Without the parentheses the
+            // expression evaluated as freq + 1/size + TrainCnt, i.e. the log
+            // of a large count instead of a probability.
+            if (unigram_total > 0) {
+                p += log((word_freq + 1.0) / unigram_total);
+            }
+        }
+    }
+
+    return p;
+}
+
+/*
+ * Looks `word` up in word_dict_ and copies its entry into word_info.
+ * The entry stored under appid 0 is preferred; the entry for `appid` is used
+ * only when there is none for 0. Returns false, leaving word_info untouched,
+ * when neither exists.
+ */
+bool NgramSegment::getWordInfo(string word, uint32_t appid, WordInfo& word_info) {
+    map<string, map<uint32_t, WordInfo> >::iterator entry = word_dict_.find(word);
+    if (entry == word_dict_.end()) {
+        return false;
+    }
+
+    map<uint32_t, WordInfo>& per_app = entry->second;
+    map<uint32_t, WordInfo>::iterator hit = per_app.find(0);
+    if (hit == per_app.end()) {
+        hit = per_app.find(appid);
+    }
+    if (hit == per_app.end()) {
+        return false;
+    }
+
+    word_info = hit->second;
+    return true;
+}

+ 38 - 0
src/comm/segment/ngram_segment.h

@@ -0,0 +1,38 @@
+/*
+ * =====================================================================================
+ *
+ *       Filename:  ngram_segment.h
+ *
+ *    Description:  ngram_segment class definition.
+ *
+ *        Version:  1.0
+ *        Created:  08/06/2021
+ *       Revision:  none
+ *       Compiler:  gcc
+ *
+ *         Author:  zhulin, shzhulin3@jd.com
+ *        Company:  JD.com, Inc.
+ *
+ * =====================================================================================
+ */
+
+#ifndef __NGRAM_SEGMENT_H__
+#define __NGRAM_SEGMENT_H__
+#include "segment.h"
+
+// Segmenter that runs forward and backward maximum matching and resolves
+// their disagreements with a bigram score (see ngram_segment.cc).
+class NgramSegment: public Segment
+{
+public:
+    NgramSegment();
+    ~NgramSegment();
+    // Implements the pure virtual Segment::ConcreteSplit().
+    virtual void ConcreteSplit(iutf8string& phrase, uint32_t appid, vector<string>& vec);
+
+private:
+    // Forward maximum matching over phrase.
+    void fmm(iutf8string& phrase, uint32_t appid, vector<string>& vec);
+    // Backward maximum matching over phrase.
+    void bmm(iutf8string& phrase, uint32_t appid, vector<string>& vec);
+    // Log-probability score of a word sequence.
+    double calSegProbability(const vector<string>& vec);
+    // Dictionary lookup; returns false when the word has no usable entry.
+    bool getWordInfo(string word, uint32_t appid, WordInfo& word_info);
+};
+
+
+
+#endif

+ 255 - 0
src/comm/segment/segment.cc

@@ -0,0 +1,255 @@
+#include "segment.h"
+
+#include <fstream>
+#include <stdlib.h>
+
+#include "../log.h"
+
+#define ALPHA_DIGIT "01234567891234567890\
+abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghigklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSUVWXYZ"
+
+// Creates the HMM manager owned by this segmenter.
+Segment::Segment()
+{
+    hmm_manager_ = new HmmManager();
+}
+
+// Destroys the HMM manager created by the constructor.
+Segment::~Segment()
+{
+    delete hmm_manager_;  // deleting a null pointer is a no-op
+}
+
+// True when str is one of the entries of alpha_set_, which Init() fills with
+// the characters of ALPHA_DIGIT.
+bool Segment::isAlphaOrDigit(string str) {
+    return alpha_set_.find(str) != alpha_set_.end();
+}
+
+// True when `word` is not a punctuation entry and word_dict_ holds it either
+// under appid 0 or under the given appid.
+bool Segment::wordValid(string word, uint32_t appid) {
+    if (punct_set_.count(word) != 0) {
+        return false;
+    }
+
+    map<string, map<uint32_t, WordInfo> >::iterator entry = word_dict_.find(word);
+    if (entry == word_dict_.end()) {
+        return false;
+    }
+
+    const map<uint32_t, WordInfo>& per_app = entry->second;
+    return per_app.count(0) != 0 || per_app.count(appid) != 0;
+}
+
+/*
+ * Prepares the segmenter:
+ *   1. fills punct_set_ with the punctuation characters listed below and
+ *      alpha_set_ with the characters of ALPHA_DIGIT;
+ *   2. loads the dictionary from word_path, one entry per line with the
+ *      tab-separated fields  word_id, word, appid, word_freq;
+ *   3. initialises the HMM manager from train_path.
+ *
+ * Returns false (after logging) when the dictionary file cannot be opened or
+ * the HMM manager fails to initialise. Lines with fewer than four fields are
+ * logged and skipped.
+ */
+bool Segment::Init(string word_path, string train_path){
+    string en_punct = ",.!?/'\"<>\\:;\n";
+    string punct = ",。!?、;:“”‘’()《》 ";
+    punct = punct.append(en_punct);
+    iutf8string utf8_punct(punct);
+    for (int i = 0; i < utf8_punct.length(); i++) {
+        punct_set_.insert(utf8_punct[i]);
+    }
+
+    string alphadigit = ALPHA_DIGIT;
+    iutf8string utf8_alpha(alphadigit);
+    for (int i = 0; i < utf8_alpha.length(); i++) {
+        alpha_set_.insert(utf8_alpha[i]);
+    }
+
+    string str;
+    ifstream word_infile;
+    word_infile.open(word_path.c_str());
+    if (word_infile.is_open() == false) {
+        log_error("open file error: %s.\n", word_path.c_str());
+        return false;
+    }
+
+    while (getline(word_infile, str))
+    {
+        vector<string> str_vec = splitEx(str, "\t");
+        // Blank or malformed lines would otherwise be indexed out of range.
+        if (str_vec.size() < 4) {
+            log_error("skip invalid word dict line: [%s]", str.c_str());
+            continue;
+        }
+        WordInfo word_info;
+        word_info.word_id = atoi(str_vec[0].c_str());
+        word_info.appid = atoi(str_vec[2].c_str());
+        word_info.word_freq = atoi(str_vec[3].c_str());
+        word_dict_[str_vec[1]][word_info.appid] = word_info;
+    }
+    log_info("word_dict count: %d", (int)word_dict_.size());
+
+    bool ret = hmm_manager_->Init(train_path, punct_set_);
+    if(false == ret){
+        log_error("hmm_manager_ init error: %s.\n", train_path.c_str());
+        return false;
+    }
+    return true;
+}
+
+void Segment::Split(iutf8string& phrase, uint32_t appid, vector<string>& new_res_all, bool hmm_flag){ // cuts phrase into runs, segments the non-alphanumeric runs with ConcreteSplit(), and writes the words to new_res_all (post-processed by dealByHmmMgr() when hmm_flag is true)
+    vector<string> sen_list;
+    set<string> special_set;  // remembers the runs made of letters/digits
+    string tmp_words = "";
+    bool flag = false; // true while tmp_words is collecting a letter/digit run
+    for (int i = 0; i < phrase.length(); i++) {  // pre-processing: cut at punctuation and pull out letter/digit runs, so that only the remaining runs (the original comment calls them consecutive Chinese characters) are segmented
+        if (isAlphaOrDigit(phrase[i])) {
+            if (tmp_words != "" && flag == false) {
+                sen_list.push_back(tmp_words);
+                tmp_words = "";
+            }
+            flag = true;
+            tmp_words += phrase[i];
+        }
+        else if(punct_set_.find(phrase[i]) != punct_set_.end()){
+            if (tmp_words != "") {
+                sen_list.push_back(tmp_words);
+                sen_list.push_back(phrase[i]);
+                if (flag == true) {
+                    special_set.insert(tmp_words);
+                    flag = false;
+                }
+                tmp_words = "";
+            }
+        }
+        else {
+            if (flag == true) {
+                sen_list.push_back(tmp_words);
+                special_set.insert(tmp_words);
+                flag = false;
+                tmp_words = phrase[i];
+            }
+            else {
+                tmp_words += phrase[i];
+            }
+        }
+    }
+    if (tmp_words != "") {
+        sen_list.push_back(tmp_words);
+        if (flag == true) {
+            special_set.insert(tmp_words);
+        }
+    }
+    tmp_words = "";
+    vector<string> res_all;
+    for (int i = 0; i < (int)sen_list.size(); i++) {
+        // special_set holds the letter/digit runs; those are not segmented
+        if (special_set.find(sen_list[i]) == special_set.end() && punct_set_.find(sen_list[i]) == punct_set_.end()) {
+            iutf8string utf8_str(sen_list[i]);
+            vector<string> parse_list;
+            ConcreteSplit(utf8_str, appid, parse_list);
+            res_all.insert(res_all.end(), parse_list.begin(), parse_list.end());
+        }else { // letter/digit runs go into res_all unchanged, punctuation is dropped
+            if(punct_set_.find(sen_list[i]) == punct_set_.end()){
+                res_all.push_back(sen_list[i]);
+            }
+        }
+    }
+
+    if (hmm_flag == false) {
+        new_res_all.assign(res_all.begin(), res_all.end());
+    } else {
+        // use the HMM to discover new words
+        dealByHmmMgr(appid, res_all, new_res_all);
+    }
+
+    return;
+}
+
+/*
+ * Post-processes a segmentation result: runs of consecutive tokens that are
+ * a single multi-byte, non-punctuation character are concatenated, and each
+ * such run is
+ *   - emitted unchanged if it is only one character long,
+ *   - re-split with the HMM manager if the concatenation fails wordValid(),
+ *   - emitted as one word otherwise.
+ * All other tokens are copied to new_res_all as they are, in order.
+ */
+void Segment::dealByHmmMgr(uint32_t appid, const vector<string>& res_all, vector<string>& new_res_all){
+    string pending;
+    const size_t total = res_all.size();
+    // One extra pass (idx == total) flushes a run that reaches the end.
+    for (size_t idx = 0; idx <= total; idx++) {
+        if (idx < total) {
+            const string& token = res_all[idx];
+            iutf8string utf8_token(token);
+            // one character, not punctuation, more than one byte
+            const bool lone_multibyte = utf8_token.length() == 1
+                && punct_set_.find(token) == punct_set_.end()
+                && token.length() > 1;
+            if (lone_multibyte) {
+                pending += token;
+                continue;
+            }
+        }
+
+        if (pending.length() > 0) {
+            iutf8string utf8_pending(pending);
+            if (utf8_pending.length() == 1) {
+                new_res_all.push_back(pending);
+            }
+            else if (wordValid(pending, appid) == false) {
+                // the joined single characters are not a known word: let the HMM split them
+                vector<string> hmm_words;
+                hmm_manager_->HmmSplit(pending, hmm_words);
+                new_res_all.insert(new_res_all.end(), hmm_words.begin(), hmm_words.end());
+            }
+            else {
+                new_res_all.push_back(pending);
+            }
+        }
+        pending = "";
+
+        if (idx < total) {
+            new_res_all.push_back(res_all[idx]);
+        }
+    }
+}
+
+/*
+ * Search-engine mode. Splits `phrase` with Split() and appends one vector per
+ * resulting word to search_res_all. For a word longer than two characters
+ * that is not purely ASCII letters/digits, the vector first lists every
+ * 2-character substring accepted by wordValid(), then (for words longer than
+ * three characters) every accepted 3-character substring; the word itself is
+ * always the last element.
+ */
+void Segment::CutForSearch(iutf8string& phrase, uint32_t appid, vector<vector<string> >& search_res_all) {
+    vector<string> words;
+    Split(phrase, appid, words);
+
+    for (size_t w = 0; w < words.size(); w++) {
+        const string& word = words[w];
+        vector<string> grams;
+        iutf8string utf8_word(word);
+        const int char_cnt = utf8_word.length();
+
+        if (char_cnt > 2 && isAllAlphaOrDigit(word) == false) {
+            const int widest = (char_cnt > 3) ? 3 : 2;
+            for (int width = 2; width <= widest; width++) {
+                for (int start = 0; start + width <= char_cnt; start++) {
+                    string gram = utf8_word.substr(start, width);
+                    if (wordValid(gram, appid) == true) {
+                        grams.push_back(gram);
+                    }
+                }
+            }
+        }
+        grams.push_back(word);
+        search_res_all.push_back(grams);
+    }
+
+    return;
+}
+
+/*
+ * True when every byte of str is an upper-case letter, a lower-case letter
+ * or a digit according to <ctype.h>; an empty string yields true.
+ */
+bool Segment::isAllAlphaOrDigit(string str) {
+    for (size_t i = 0; i < str.size(); i++) {
+        // The <ctype.h> functions require a value representable as unsigned
+        // char (or EOF). Bytes of multi-byte UTF-8 characters are negative
+        // where plain char is signed, so convert before the call.
+        const unsigned char c = static_cast<unsigned char>(str[i]);
+        if (!isupper(c) && !islower(c) && !isdigit(c)) {
+            return false;
+        }
+    }
+    return true;
+}
+
+/*
+ * Appends to search_res every substring of `phrase` that is 1 to n
+ * characters long (n is capped at the phrase length), grouped by length in
+ * increasing order and, within a length, by start position.
+ */
+void Segment::CutNgram(iutf8string& phrase, vector<string>& search_res, uint32_t n) {
+    const uint32_t char_cnt = (uint32_t)phrase.length();
+    const uint32_t widest = (n > char_cnt) ? char_cnt : n;
+    for (size_t width = 1; width <= widest; width++) {
+        const size_t last_start = (size_t)phrase.length() - width;
+        for (size_t start = 0; start <= last_start; start++) {
+            search_res.push_back(phrase.substr(start, width));
+        }
+    }
+}

+ 61 - 0
src/comm/segment/segment.h

@@ -0,0 +1,61 @@
+/*
+ * =====================================================================================
+ *
+ *       Filename:  segment.h
+ *
+ *    Description:  segment class definition.
+ *
+ *        Version:  1.0
+ *        Created:  08/06/2021
+ *       Revision:  none
+ *       Compiler:  gcc
+ *
+ *         Author:  zhulin, shzhulin3@jd.com
+ *        Company:  JD.com, Inc.
+ *
+ * =====================================================================================
+ */
+
+#ifndef __SEGMENT_H__
+#define __SEGMENT_H__
+
+#include "hmm_manager.h"
+#include "../utf8_str.h"
+#define MAX_WORD_LEN 8 // widest window, in UTF-8 characters, a segmenter examines when looking for a word (see BmmSegment::ConcreteSplit)
+#define TOTAL 8000000 // NOTE(review): not used in this header; presumably a total-frequency constant for the dictionary - confirm
+
+// Per-word record stored in the dictionary map; every field starts at zero.
+struct WordInfo {
+    uint32_t word_id;
+    uint32_t word_freq;
+    uint32_t appid;
+
+    WordInfo() : word_id(0), word_freq(0), appid(0) {}
+};
+
+class Segment{ // abstract base for word segmenters; subclasses supply ConcreteSplit()
+public:
+    Segment();
+    virtual ~Segment();
+    virtual bool Init(string word_path, string train_path); // NOTE(review): presumably loads the dictionary and the HMM training data from these paths - definition not visible here, confirm
+    void CutForSearch(iutf8string& phrase, uint32_t appid, vector<vector<string> >& search_res_all); // search-engine mode: one vector per Split() token, holding its valid 2-/3-character sub-words followed by the token itself
+    void CutNgram(iutf8string& phrase, vector<string>& search_res, uint32_t n); // appends every substring of 1..min(n, phrase length) characters to search_res
+    void Split(iutf8string& phrase, uint32_t appid, vector<string>& vec, bool hmm_flag = false); // NOTE(review): presumably hmm_flag enables the dealByHmmMgr() post-pass - confirm against the definition
+    virtual void ConcreteSplit(iutf8string& phrase, uint32_t appid, vector<string>& vec) = 0; // pure virtual: the segmentation strategy each subclass implements
+
+protected:
+    bool isAllAlphaOrDigit(string str); // true when every byte of str passes isupper/islower/isdigit (true for an empty string)
+    bool isAlphaOrDigit(string str); // NOTE(review): definition not visible here - confirm exact contract
+    bool wordValid(string word, uint32_t appid); // NOTE(review): presumably a dictionary lookup of word for this appid - confirm
+    void dealByHmmMgr(uint32_t appid, const vector<string>& old_vec, vector<string>& new_vec); // NOTE(review): presumably post-processes old_vec through hmm_manager_ into new_vec - confirm
+
+protected:
+    map<string, map<uint32_t, WordInfo> > word_dict_; // NOTE(review): presumably word -> (appid -> WordInfo) - confirm
+    set<string> punct_set_; // NOTE(review): presumably punctuation strings - confirm
+    set<string> alpha_set_; // NOTE(review): presumably alphanumeric characters - confirm
+    HmmManager* hmm_manager_; // raw pointer; NOTE(review): ownership is not visible in this header - confirm
+};
+
+#endif

+ 22 - 0
src/comm/utf8_str.cc

@@ -13,6 +13,23 @@ iutf8string::iutf8string(const char* str)
 	refresh();
 }
 
+// Copy constructor: takes over a copy of the source's underlying string and
+// then calls refresh(), exactly as the other constructors do.
+iutf8string::iutf8string(const iutf8string& str)
+{
+	data = str.data;
+	refresh();
+}
+
+// Copy assignment: a no-op for self-assignment. Otherwise it releases the
+// current offerset array, copies the source's underlying string and calls
+// refresh().
+iutf8string& iutf8string::operator=(const iutf8string& str)
+{
+	if (this == &str)
+		return *this;
+
+	// delete[] on a null pointer is a no-op, so no explicit null check is needed.
+	delete[] offerset;
+	offerset = NULL;
+
+	data = str.data;
+	refresh();
+	return *this;
+}
+
 iutf8string::~iutf8string()
 {
 	delete[] offerset;
@@ -23,6 +40,11 @@ string iutf8string::stlstring()
 	return data;
 }
 
+string iutf8string::stlstring() const // const overload: returns a copy of the underlying std::string so const objects (e.g. the copy-constructor argument) can be read
+{
+	return data;
+}
+
 const char* iutf8string::c_str()
 {
 	return data.c_str();

+ 3 - 0
src/comm/utf8_str.h

@@ -29,6 +29,8 @@ public:
 	iutf8string(const std::string& str);
 	iutf8string(const char* str);
 	~iutf8string();
+	iutf8string(const iutf8string& str);
+	iutf8string& operator=(const iutf8string& str);
 
 public:
 	int length();
@@ -36,6 +38,7 @@ public:
 	iutf8string operator + (iutf8string& str);
 	std::string operator [](int index);
 	std::string stlstring();
+	std::string stlstring() const;
 	const char* c_str();
 	iutf8string utf8substr(int u8start_index, int u8_length);
 	std::string substr(int u8start_index, int u8_length);