bthread_work_stealing_queue_unittest.cpp 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <pthread.h>
#include <algorithm> // std::sort
#include <atomic>
#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>
#include <gtest/gtest.h>
#include "butil/time.h"
#include "butil/macros.h"
#include "butil/scoped_lock.h"
#include "bthread/work_stealing_queue.h"
  23. namespace {
  24. typedef size_t value_type;
  25. bool g_stop = false;
  26. const size_t N = 1024*512;
  27. const size_t CAP = 8;
  28. pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER;
  29. void* steal_thread(void* arg) {
  30. std::vector<value_type> *stolen = new std::vector<value_type>;
  31. stolen->reserve(N);
  32. bthread::WorkStealingQueue<value_type> *q =
  33. (bthread::WorkStealingQueue<value_type>*)arg;
  34. value_type val;
  35. while (!g_stop) {
  36. if (q->steal(&val)) {
  37. stolen->push_back(val);
  38. } else {
  39. asm volatile("pause\n": : :"memory");
  40. }
  41. }
  42. return stolen;
  43. }
  44. void* push_thread(void* arg) {
  45. size_t npushed = 0;
  46. value_type seed = 0;
  47. bthread::WorkStealingQueue<value_type> *q =
  48. (bthread::WorkStealingQueue<value_type>*)arg;
  49. while (true) {
  50. pthread_mutex_lock(&mutex);
  51. const bool pushed = q->push(seed);
  52. pthread_mutex_unlock(&mutex);
  53. if (pushed) {
  54. ++seed;
  55. if (++npushed == N) {
  56. g_stop = true;
  57. break;
  58. }
  59. }
  60. }
  61. return NULL;
  62. }
  63. void* pop_thread(void* arg) {
  64. std::vector<value_type> *popped = new std::vector<value_type>;
  65. popped->reserve(N);
  66. bthread::WorkStealingQueue<value_type> *q =
  67. (bthread::WorkStealingQueue<value_type>*)arg;
  68. while (!g_stop) {
  69. value_type val;
  70. pthread_mutex_lock(&mutex);
  71. const bool res = q->pop(&val);
  72. pthread_mutex_unlock(&mutex);
  73. if (res) {
  74. popped->push_back(val);
  75. }
  76. }
  77. return popped;
  78. }
  79. TEST(WSQTest, sanity) {
  80. bthread::WorkStealingQueue<value_type> q;
  81. ASSERT_EQ(0, q.init(CAP));
  82. pthread_t rth[8];
  83. pthread_t wth, pop_th;
  84. for (size_t i = 0; i < ARRAY_SIZE(rth); ++i) {
  85. ASSERT_EQ(0, pthread_create(&rth[i], NULL, steal_thread, &q));
  86. }
  87. ASSERT_EQ(0, pthread_create(&wth, NULL, push_thread, &q));
  88. ASSERT_EQ(0, pthread_create(&pop_th, NULL, pop_thread, &q));
  89. std::vector<value_type> values;
  90. values.reserve(N);
  91. size_t nstolen = 0, npopped = 0;
  92. for (size_t i = 0; i < ARRAY_SIZE(rth); ++i) {
  93. std::vector<value_type>* res = NULL;
  94. pthread_join(rth[i], (void**)&res);
  95. for (size_t j = 0; j < res->size(); ++j, ++nstolen) {
  96. values.push_back((*res)[j]);
  97. }
  98. }
  99. pthread_join(wth, NULL);
  100. std::vector<value_type>* res = NULL;
  101. pthread_join(pop_th, (void**)&res);
  102. for (size_t j = 0; j < res->size(); ++j, ++npopped) {
  103. values.push_back((*res)[j]);
  104. }
  105. value_type val;
  106. while (q.pop(&val)) {
  107. values.push_back(val);
  108. }
  109. std::sort(values.begin(), values.end());
  110. values.resize(std::unique(values.begin(), values.end()) - values.begin());
  111. ASSERT_EQ(N, values.size());
  112. for (size_t i = 0; i < N; ++i) {
  113. ASSERT_EQ(i, values[i]);
  114. }
  115. std::cout << "stolen=" << nstolen
  116. << " popped=" << npopped
  117. << " left=" << (N - nstolen - npopped) << std::endl;
  118. }
  119. } // namespace