1
0
Эх сурвалжийг харах

用CommonPool替代compleatable,避免大量异步同时执行

wuweifeng10 3 жил өмнө
parent
commit
3e4e382cbe

+ 49 - 19
src/main/java/com/jd/platform/async/executor/Async.java

@@ -15,22 +15,29 @@ import java.util.stream.Collectors;
  * @version 1.0
  */
 public class Async {
-    public static final ThreadPoolExecutor COMMON_POOL =
+    private static final ThreadPoolExecutor COMMON_POOL =
             new ThreadPoolExecutor(Runtime.getRuntime().availableProcessors() * 2, 1024,
                     15L, TimeUnit.SECONDS,
                     new LinkedBlockingQueue<>(),
                     (ThreadFactory) Thread::new);
 
-    public static boolean beginWork(long timeout, ThreadPoolExecutor pool, List<WorkerWrapper> workerWrappers) throws ExecutionException, InterruptedException {
+    private static ExecutorService executorService;
+
+    /**
+     * 出发点
+     */
+    public static boolean beginWork(long timeout, ExecutorService executorService, List<WorkerWrapper> workerWrappers) throws ExecutionException, InterruptedException {
         if(workerWrappers == null || workerWrappers.size() == 0) {
             return false;
         }
+        //保存线程池变量
+        Async.executorService = executorService;
         //定义一个map,存放所有的wrapper,key为wrapper的唯一id,value是该wrapper,可以从value中获取wrapper的result
         Map<String, WorkerWrapper> forParamUseWrappers = new ConcurrentHashMap<>();
         CompletableFuture[] futures = new CompletableFuture[workerWrappers.size()];
         for (int i = 0; i < workerWrappers.size(); i++) {
             WorkerWrapper wrapper = workerWrappers.get(i);
-            futures[i] = CompletableFuture.runAsync(() -> wrapper.work(pool, timeout, forParamUseWrappers), pool);
+            futures[i] = CompletableFuture.runAsync(() -> wrapper.work(executorService, timeout, forParamUseWrappers), executorService);
         }
         try {
             CompletableFuture.allOf(futures).get(timeout, TimeUnit.MILLISECONDS);
@@ -48,12 +55,12 @@ public class Async {
     /**
      * 如果想自定义线程池,请传pool。不自定义的话,就走默认的COMMON_POOL
      */
-    public static boolean beginWork(long timeout, ThreadPoolExecutor pool, WorkerWrapper... workerWrapper) throws ExecutionException, InterruptedException {
+    public static boolean beginWork(long timeout, ExecutorService executorService, WorkerWrapper... workerWrapper) throws ExecutionException, InterruptedException {
         if(workerWrapper == null || workerWrapper.length == 0) {
             return false;
         }
         List<WorkerWrapper> workerWrappers =  Arrays.stream(workerWrapper).collect(Collectors.toList());
-        return beginWork(timeout, pool, workerWrappers);
+        return beginWork(timeout, executorService, workerWrappers);
     }
 
     /**
@@ -71,19 +78,36 @@ public class Async {
             groupCallback = new DefaultGroupCallback();
         }
         IGroupCallback finalGroupCallback = groupCallback;
-        COMMON_POOL.submit(() -> {
-            try {
-                boolean success = beginWork(timeout, COMMON_POOL, workerWrapper);
-                if (success) {
-                    finalGroupCallback.success(Arrays.asList(workerWrapper));
-                } else {
-                    finalGroupCallback.failure(Arrays.asList(workerWrapper), new TimeoutException());
+        if (executorService != null) {
+            executorService.submit(() -> {
+                try {
+                    boolean success = beginWork(timeout, COMMON_POOL, workerWrapper);
+                    if (success) {
+                        finalGroupCallback.success(Arrays.asList(workerWrapper));
+                    } else {
+                        finalGroupCallback.failure(Arrays.asList(workerWrapper), new TimeoutException());
+                    }
+                } catch (ExecutionException | InterruptedException e) {
+                    e.printStackTrace();
+                    finalGroupCallback.failure(Arrays.asList(workerWrapper), e);
                 }
-            } catch (ExecutionException | InterruptedException e) {
-                e.printStackTrace();
-                finalGroupCallback.failure(Arrays.asList(workerWrapper), e);
-            }
-        });
+            });
+        } else {
+            COMMON_POOL.submit(() -> {
+                try {
+                    boolean success = beginWork(timeout, COMMON_POOL, workerWrapper);
+                    if (success) {
+                        finalGroupCallback.success(Arrays.asList(workerWrapper));
+                    } else {
+                        finalGroupCallback.failure(Arrays.asList(workerWrapper), new TimeoutException());
+                    }
+                } catch (ExecutionException | InterruptedException e) {
+                    e.printStackTrace();
+                    finalGroupCallback.failure(Arrays.asList(workerWrapper), e);
+                }
+            });
+        }
+
     }
 
     /**
@@ -102,9 +126,15 @@ public class Async {
 
     }
 
-
+    /**
+     * 关闭线程池
+     */
     public static void shutDown() {
-        COMMON_POOL.shutdown();
+        if (executorService != null) {
+            executorService.shutdown();
+        } else {
+            COMMON_POOL.shutdown();
+        }
     }
 
     public static String getThreadCount() {

+ 17 - 17
src/main/java/com/jd/platform/async/wrapper/WorkerWrapper.java

@@ -12,7 +12,7 @@ import com.jd.platform.async.worker.WorkResult;
 import java.util.*;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
-import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.ExecutorService;
 import java.util.concurrent.atomic.AtomicInteger;
 
 /**
@@ -96,7 +96,7 @@ public class WorkerWrapper<T, V> {
      * 开始工作
      * fromWrapper代表这次work是由哪个上游wrapper发起的
      */
-    private void work(ThreadPoolExecutor poolExecutor, WorkerWrapper fromWrapper, long remainTime, Map<String, WorkerWrapper> forParamUseWrappers) {
+    private void work(ExecutorService executorService, WorkerWrapper fromWrapper, long remainTime, Map<String, WorkerWrapper> forParamUseWrappers) {
         this.forParamUseWrappers = forParamUseWrappers;
         //将自己放到所有wrapper的集合里去
         forParamUseWrappers.put(id, this);
@@ -104,13 +104,13 @@ public class WorkerWrapper<T, V> {
         //总的已经超时了,就快速失败,进行下一个
         if (remainTime <= 0) {
             fastFail(INIT, null);
-            beginNext(poolExecutor, now, remainTime);
+            beginNext(executorService, now, remainTime);
             return;
         }
         //如果自己已经执行过了。
         //可能有多个依赖,其中的一个依赖已经执行完了,并且自己也已开始执行或执行完毕。当另一个依赖执行完毕,又进来该方法时,就不重复处理了
         if (getState() == FINISH || getState() == ERROR) {
-            beginNext(poolExecutor, now, remainTime);
+            beginNext(executorService, now, remainTime);
             return;
         }
 
@@ -119,7 +119,7 @@ public class WorkerWrapper<T, V> {
             //如果自己的next链上有已经出结果或已经开始执行的任务了,自己就不用继续了
             if (!checkNextWrapperResult()) {
                 fastFail(INIT, new SkippedException());
-                beginNext(poolExecutor, now, remainTime);
+                beginNext(executorService, now, remainTime);
                 return;
             }
         }
@@ -127,7 +127,7 @@ public class WorkerWrapper<T, V> {
         //如果没有任何依赖,说明自己就是第一批要执行的
         if (dependWrappers == null || dependWrappers.size() == 0) {
             fire();
-            beginNext(poolExecutor, now, remainTime);
+            beginNext(executorService, now, remainTime);
             return;
         }
 
@@ -139,17 +139,17 @@ public class WorkerWrapper<T, V> {
         //只有一个依赖
         if (dependWrappers.size() == 1) {
             doDependsOneJob(fromWrapper);
-            beginNext(poolExecutor, now, remainTime);
+            beginNext(executorService, now, remainTime);
         } else {
             //有多个依赖时
-            doDependsJobs(poolExecutor, dependWrappers, fromWrapper, now, remainTime);
+            doDependsJobs(executorService, dependWrappers, fromWrapper, now, remainTime);
         }
 
     }
 
 
-    public void work(ThreadPoolExecutor poolExecutor, long remainTime, Map<String, WorkerWrapper> forParamUseWrappers) {
-        work(poolExecutor, null, remainTime, forParamUseWrappers);
+    public void work(ExecutorService executorService, long remainTime, Map<String, WorkerWrapper> forParamUseWrappers) {
+        work(executorService, null, remainTime, forParamUseWrappers);
     }
 
     /**
@@ -179,21 +179,21 @@ public class WorkerWrapper<T, V> {
     /**
      * 进行下一个任务
      */
-    private void beginNext(ThreadPoolExecutor poolExecutor, long now, long remainTime) {
+    private void beginNext(ExecutorService executorService, long now, long remainTime) {
         //花费的时间
         long costTime = SystemClock.now() - now;
         if (nextWrappers == null) {
             return;
         }
         if (nextWrappers.size() == 1) {
-            nextWrappers.get(0).work(poolExecutor, WorkerWrapper.this, remainTime - costTime, forParamUseWrappers);
+            nextWrappers.get(0).work(executorService, WorkerWrapper.this, remainTime - costTime, forParamUseWrappers);
             return;
         }
         CompletableFuture[] futures = new CompletableFuture[nextWrappers.size()];
         for (int i = 0; i < nextWrappers.size(); i++) {
             int finalI = i;
             futures[i] = CompletableFuture.runAsync(() -> nextWrappers.get(finalI)
-                    .work(poolExecutor, WorkerWrapper.this, remainTime - costTime, forParamUseWrappers), poolExecutor);
+                    .work(executorService, WorkerWrapper.this, remainTime - costTime, forParamUseWrappers), executorService);
         }
         try {
             CompletableFuture.allOf(futures).get();
@@ -215,7 +215,7 @@ public class WorkerWrapper<T, V> {
         }
     }
 
-    private synchronized void doDependsJobs(ThreadPoolExecutor poolExecutor, List<DependWrapper> dependWrappers, WorkerWrapper fromWrapper, long now, long remainTime) {
+    private synchronized void doDependsJobs(ExecutorService executorService, List<DependWrapper> dependWrappers, WorkerWrapper fromWrapper, long now, long remainTime) {
         boolean nowDependIsMust = false;
         //创建必须完成的上游wrapper集合
         Set<DependWrapper> mustWrapper = new HashSet<>();
@@ -235,7 +235,7 @@ public class WorkerWrapper<T, V> {
             } else {
                 fire();
             }
-            beginNext(poolExecutor, now, remainTime);
+            beginNext(executorService, now, remainTime);
             return;
         }
 
@@ -271,7 +271,7 @@ public class WorkerWrapper<T, V> {
         //只要有失败的
         if (hasError) {
             fastFail(INIT, null);
-            beginNext(poolExecutor, now, remainTime);
+            beginNext(executorService, now, remainTime);
             return;
         }
 
@@ -280,7 +280,7 @@ public class WorkerWrapper<T, V> {
         if (!existNoFinish) {
             //上游都finish了,进行自己
             fire();
-            beginNext(poolExecutor, now, remainTime);
+            beginNext(executorService, now, remainTime);
             return;
         }
     }

+ 2 - 1
src/test/java/parallel/TestPar.java

@@ -6,6 +6,7 @@ import com.jd.platform.async.executor.timer.SystemClock;
 import com.jd.platform.async.wrapper.WorkerWrapper;
 
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executors;
 
 /**
  * 并行测试
@@ -863,7 +864,7 @@ public class TestPar {
                 .next(last, false)
                 .build();
 
-        Async.beginWork(6000, wrapperW, wrapperW1);
+        Async.beginWork(6000,Executors.newCachedThreadPool(),  wrapperW, wrapperW1);
         Async.shutDown();
     }
 }