瀏覽代碼

implement weighted randomized load balancer

serverglen 3 年之前
父節點
當前提交
9c1b4532b9

+ 4 - 0
docs/cn/client.md

@@ -218,6 +218,10 @@ int main() {
 
 随机从列表中选择一台服务器,无需其他设置。和round robin类似,这个算法的前提也是服务器都是类似的。
 
+### wr
+
+即weighted random, 根据服务器列表配置的权重值来选择服务器,服务器被选到的机会正比于其权重值。
+
 ### la
 
 locality-aware,优先选择延时低的下游,直到其延时高于其他机器,无需其他设置。实现原理请查看[Locality-aware load balancing](lalb.md)。

+ 4 - 0
docs/en/client.md

@@ -220,6 +220,10 @@ which is weighted round robin. Choose the next server according to the configure
 
 Randomly choose one server from the list, no other settings. Similarly with round robin, the algorithm assumes that servers to access are similar.
 
+### wr
+
+which is weighted random. Choose the next server according to the configured weight. The chances a server is selected is consistent with its weight.
+
 ### la
 
 which is locality-aware. Perfer servers with lower latencies, until the latency is higher than others, no other settings. Check out [Locality-aware load balancing](lalb.md) for more details.

+ 3 - 0
src/brpc/global.cpp

@@ -43,6 +43,7 @@
 #include "brpc/policy/round_robin_load_balancer.h"
 #include "brpc/policy/weighted_round_robin_load_balancer.h"
 #include "brpc/policy/randomized_load_balancer.h"
+#include "brpc/policy/weighted_randomized_load_balancer.h"
 #include "brpc/policy/locality_aware_load_balancer.h"
 #include "brpc/policy/consistent_hashing_load_balancer.h"
 #include "brpc/policy/hasher.h"
@@ -137,6 +138,7 @@ struct GlobalExtensions {
     RoundRobinLoadBalancer rr_lb;
     WeightedRoundRobinLoadBalancer wrr_lb;
     RandomizedLoadBalancer randomized_lb;
+    WeightedRandomizedLoadBalancer wr_lb;
     LocalityAwareLoadBalancer la_lb;
     ConsistentHashingLoadBalancer ch_mh_lb;
     ConsistentHashingLoadBalancer ch_md5_lb;
@@ -359,6 +361,7 @@ static void GlobalInitializeOrDieImpl() {
     LoadBalancerExtension()->RegisterOrDie("rr", &g_ext->rr_lb);
     LoadBalancerExtension()->RegisterOrDie("wrr", &g_ext->wrr_lb);
     LoadBalancerExtension()->RegisterOrDie("random", &g_ext->randomized_lb);
+    LoadBalancerExtension()->RegisterOrDie("wr", &g_ext->wr_lb);
     LoadBalancerExtension()->RegisterOrDie("la", &g_ext->la_lb);
     LoadBalancerExtension()->RegisterOrDie("c_murmurhash", &g_ext->ch_mh_lb);
     LoadBalancerExtension()->RegisterOrDie("c_md5", &g_ext->ch_md5_lb);

+ 169 - 0
src/brpc/policy/weighted_randomized_load_balancer.cpp

@@ -0,0 +1,169 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+
+#include <algorithm>
+
+#include "butil/fast_rand.h"
+#include "brpc/socket.h"
+#include "brpc/policy/weighted_randomized_load_balancer.h"
+#include "butil/strings/string_number_conversions.h"
+
+namespace brpc {
+namespace policy {
+
+static bool server_compare(const WeightedRandomizedLoadBalancer::Server& lhs, const WeightedRandomizedLoadBalancer::Server& rhs) {
+        return (lhs.current_weight_sum < rhs.current_weight_sum);
+}
+
+bool WeightedRandomizedLoadBalancer::Add(Servers& bg, const ServerId& id) {
+    if (bg.server_list.capacity() < 128) {
+        bg.server_list.reserve(128);
+    }
+    uint32_t weight = 0;
+    if (butil::StringToUint(id.tag, &weight) &&
+        weight > 0) {
+        bool insert_server =
+                 bg.server_map.emplace(id.id, bg.server_list.size()).second;
+        if (insert_server) {
+            uint64_t current_weight_sum = bg.weight_sum + weight;
+            bg.server_list.emplace_back(id.id, weight, current_weight_sum);
+            bg.weight_sum = current_weight_sum;
+            return true;
+        }
+    } else {
+        LOG(ERROR) << "Invalid weight is set: " << id.tag;
+    }
+    return false;
+}
+
+bool WeightedRandomizedLoadBalancer::Remove(Servers& bg, const ServerId& id) {
+    typedef std::map<SocketId, size_t>::iterator MapIter_t;
+    MapIter_t iter = bg.server_map.find(id.id);
+    if (iter != bg.server_map.end()) {
+        size_t index = iter->second;
+        Server remove_server = bg.server_list[index];
+        uint32_t weight_diff = bg.server_list.back().weight - remove_server.weight;
+        bg.weight_sum -= remove_server.weight;
+        bg.server_list[index] = bg.server_list.back();
+        bg.server_list[index].current_weight_sum = remove_server.current_weight_sum + weight_diff;
+        bg.server_map[bg.server_list[index].id] = index;
+        bg.server_list.pop_back();
+        bg.server_map.erase(iter);
+        size_t n = bg.server_list.size();
+        for (index++; index < n; index++) {
+            bg.server_list[index].current_weight_sum += weight_diff;
+        }
+        return true;
+    }
+    return false;
+}
+
+size_t WeightedRandomizedLoadBalancer::BatchAdd(
+    Servers& bg, const std::vector<ServerId>& servers) {
+    size_t count = 0;
+    for (size_t i = 0; i < servers.size(); ++i) {
+        count += !!Add(bg, servers[i]);
+    }
+    return count;
+}
+
+size_t WeightedRandomizedLoadBalancer::BatchRemove(
+    Servers& bg, const std::vector<ServerId>& servers) {
+    size_t count = 0;
+    for (size_t i = 0; i < servers.size(); ++i) {
+        count += !!Remove(bg, servers[i]);
+    }
+    return count;
+}
+
+bool WeightedRandomizedLoadBalancer::AddServer(const ServerId& id) {
+    return _db_servers.Modify(Add, id);
+}
+
+bool WeightedRandomizedLoadBalancer::RemoveServer(const ServerId& id) {
+    return _db_servers.Modify(Remove, id);
+}
+
+size_t WeightedRandomizedLoadBalancer::AddServersInBatch(
+    const std::vector<ServerId>& servers) {
+    return _db_servers.Modify(BatchAdd, servers);
+}
+
+size_t WeightedRandomizedLoadBalancer::RemoveServersInBatch(
+    const std::vector<ServerId>& servers) {
+    return _db_servers.Modify(BatchRemove, servers);
+}
+
+int WeightedRandomizedLoadBalancer::SelectServer(const SelectIn& in, SelectOut* out) {
+    butil::DoublyBufferedData<Servers>::ScopedPtr s;
+    if (_db_servers.Read(&s) != 0) {
+        return ENOMEM;
+    }
+    size_t n = s->server_list.size();
+    if (n == 0) {
+        return ENODATA;
+    }
+    uint64_t weight_sum = s->weight_sum;
+    for (size_t i = 0; i < n; ++i) {
+        uint64_t random_weight = butil::fast_rand_less_than(weight_sum);
+        const Server random_server(0, 0, random_weight);
+        const auto& server = std::lower_bound(s->server_list.begin(), s->server_list.end(), random_server, server_compare);
+        const SocketId id = server->id;
+        if (((i + 1) == n  // always take last chance
+             || !ExcludedServers::IsExcluded(in.excluded, id))
+            && Socket::Address(id, out->ptr) == 0
+            && (*out->ptr)->IsAvailable()) {
+            // We found an available server
+            return 0;
+        }
+    }
+    // After we traversed the whole server list, there is still no
+    // available server
+    return EHOSTDOWN;
+}
+
+LoadBalancer* WeightedRandomizedLoadBalancer::New(
+    const butil::StringPiece&) const {
+    return new (std::nothrow) WeightedRandomizedLoadBalancer;
+}
+
+void WeightedRandomizedLoadBalancer::Destroy() {
+    delete this;
+}
+
+void WeightedRandomizedLoadBalancer::Describe(
+    std::ostream &os, const DescribeOptions& options) {
+    if (!options.verbose) {
+        os << "wr";
+        return;
+    }
+    os << "WeightedRandomized{";
+    butil::DoublyBufferedData<Servers>::ScopedPtr s;
+    if (_db_servers.Read(&s) != 0) {
+        os << "fail to read _db_servers";
+    } else {
+        os << "n=" << s->server_list.size() << ':';
+        for (const auto& server : s->server_list) {
+            os << ' ' << server.id << '(' << server.weight << ')';
+        }
+    }
+    os << '}';
+}
+
+}  // namespace policy
+} // namespace brpc

+ 70 - 0
src/brpc/policy/weighted_randomized_load_balancer.h

@@ -0,0 +1,70 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+
+#ifndef BRPC_POLICY_WEIGHTED_RANDOMIZED_LOAD_BALANCER_H
+#define BRPC_POLICY_WEIGHTED_RANDOMIZED_LOAD_BALANCER_H
+
+#include <map>
+#include <vector>
+#include "butil/containers/doubly_buffered_data.h"
+#include "brpc/load_balancer.h"
+
+namespace brpc {
+namespace policy {
+
+// This LoadBalancer selects server as the assigned weight.
+// Weight is got from tag of ServerId.
+class WeightedRandomizedLoadBalancer : public LoadBalancer {
+public:
+    bool AddServer(const ServerId& id);
+    bool RemoveServer(const ServerId& id);
+    size_t AddServersInBatch(const std::vector<ServerId>& servers);
+    size_t RemoveServersInBatch(const std::vector<ServerId>& servers);
+    int SelectServer(const SelectIn& in, SelectOut* out);
+    LoadBalancer* New(const butil::StringPiece&) const;
+    void Destroy();
+    void Describe(std::ostream& os, const DescribeOptions&);
+
+    struct Server {
+        Server(SocketId s_id = 0, uint32_t s_w = 0, uint64_t s_c_w_s = 0): id(s_id), weight(s_w), current_weight_sum(s_c_w_s) {}
+        SocketId id;
+        uint32_t weight;
+        uint64_t current_weight_sum;
+    };
+    struct Servers {
+        // The value is configured weight and weight_sum for each server.
+        std::vector<Server> server_list;
+        // The value is the index of the server in "server_list".
+        std::map<SocketId, size_t> server_map;
+        uint64_t weight_sum;
+        Servers() : weight_sum(0) {}
+    };
+
+private:
+    static bool Add(Servers& bg, const ServerId& id);
+    static bool Remove(Servers& bg, const ServerId& id);
+    static size_t BatchAdd(Servers& bg, const std::vector<ServerId>& servers);
+    static size_t BatchRemove(Servers& bg, const std::vector<ServerId>& servers);
+
+    butil::DoublyBufferedData<Servers> _db_servers;
+};
+
+}  // namespace policy
+} // namespace brpc
+
+#endif  // BRPC_POLICY_WEIGHTED_RANDOMIZED_LOAD_BALANCER_H

+ 78 - 0
test/brpc_load_balancer_unittest.cpp

@@ -34,6 +34,7 @@
 #include "brpc/excluded_servers.h" 
 #include "brpc/policy/weighted_round_robin_load_balancer.h"
 #include "brpc/policy/round_robin_load_balancer.h"
+#include "brpc/policy/weighted_randomized_load_balancer.h"
 #include "brpc/policy/randomized_load_balancer.h"
 #include "brpc/policy/locality_aware_load_balancer.h"
 #include "brpc/policy/consistent_hashing_load_balancer.h"
@@ -726,6 +727,83 @@ TEST_F(LoadBalancerTest, weighted_round_robin_no_valid_server) {
     brpc::ExcludedServers::Destroy(exclude);
 }
 
+TEST_F(LoadBalancerTest, weighted_randomized) {
+    const char* servers[] = {
+        "10.92.115.19:8831",
+        "10.42.108.25:8832",
+        "10.36.150.31:8833",
+        "10.36.150.32:8899",
+        "10.92.149.48:8834",
+        "10.42.122.201:8835",
+        "10.42.122.202:8836"
+    };
+    std::string weight[] = {"3", "2", "5", "10", "1ab", "-1", "0"};
+    std::map<butil::EndPoint, int> configed_weight;
+    uint64_t configed_weight_sum = 0;
+    brpc::policy::WeightedRandomizedLoadBalancer wrlb;
+    size_t valid_weight_num = 4;
+
+    // Add server to selected list. The server with invalid weight will be skipped.
+    for (size_t i = 0;  i < ARRAY_SIZE(servers); ++i) {
+        const char *addr = servers[i];
+        butil::EndPoint dummy;
+        ASSERT_EQ(0, str2endpoint(addr, &dummy));
+        brpc::ServerId id(8888);
+        brpc::SocketOptions options;
+        options.remote_side = dummy;
+        options.user = new SaveRecycle;
+        ASSERT_EQ(0, brpc::Socket::Create(options, &id.id));
+        id.tag = weight[i];
+        if (i < valid_weight_num) {
+            int weight_num = 0;
+            ASSERT_TRUE(butil::StringToInt(weight[i], &weight_num));
+            configed_weight[dummy] = weight_num;
+            configed_weight_sum += weight_num;
+            EXPECT_TRUE(wrlb.AddServer(id));
+        } else {
+            EXPECT_FALSE(wrlb.AddServer(id));
+        }
+    }
+
+    // Select the best server according to weight configured.
+    // There are 4 valid servers with weight 3, 2, 5 and 10 respectively.
+    // We run SelectServer for multiple times. The result number of each server seleted should be
+    // weight randomized with weight configured.
+    std::map<butil::EndPoint, size_t> select_result;
+    brpc::SocketUniquePtr ptr;
+    brpc::LoadBalancer::SelectIn in = { 0, false, false, 0u, NULL };
+    brpc::LoadBalancer::SelectOut out(&ptr);
+    int run_times = configed_weight_sum * 10;
+    std::vector<butil::EndPoint> select_servers;
+    for (int i = 0; i < run_times; ++i) {
+        EXPECT_EQ(0, wrlb.SelectServer(in, &out));
+        select_servers.emplace_back(ptr->remote_side());
+        ++select_result[ptr->remote_side()];
+    }
+
+    for (const auto& server : select_servers) {
+        std::cout << "weight randomized=" << server << ", ";
+    }
+    std::cout << std::endl;
+
+    // Check whether selected result is weight with expected.
+    EXPECT_EQ(valid_weight_num, select_result.size());
+    std::cout << "configed_weight_sum=" << configed_weight_sum << " run_times=" << run_times << std::endl;
+    for (const auto& result : select_result) {
+        double actual_rate = result.second * 1.0 / run_times;
+        double expect_rate = configed_weight[result.first] * 1.0 / configed_weight_sum;
+        std::cout << result.first << " weight=" << configed_weight[result.first]
+            << " select_times=" << result.second
+            << " actual_rate=" << actual_rate << " expect_rate=" << expect_rate
+            << " expect_rate/2=" << expect_rate/2 << " expect_rate*2=" << expect_rate*2
+            << std::endl;
+        // actual_rate >= expect_rate / 2
+        ASSERT_GE(actual_rate, expect_rate / 2);
+        // actual_rate <= expect_rate * 2
+        ASSERT_LE(actual_rate, expect_rate * 2);
+    }
+}
+
 TEST_F(LoadBalancerTest, health_check_no_valid_server) {
     const char* servers[] = { 
             "10.92.115.19:8832",