package cn.gtmap.hlw.core.util.thread;

import cn.gtmap.hlw.core.util.mdc.ThreadMdcUtil;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.stereotype.Component;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.*;

/**
 * @Author admin
 * @Date 2023/5/10 16:06
 * @Version v1.0
 * @Description 自定义线程池：实现链路id赋值
 **/
@Slf4j
@Component
public class ThreadPoolMdcExecutor {

    private static Logger logger = LoggerFactory.getLogger(ThreadPoolMdcExecutor.class);

    private static Integer CORE_POOL_SSIZE = 10;

    private static Integer MAX_POOL_SSIZE = 50;

    private static Integer QUEUE_CAPACITY = 200;

    private static String THREAD_NAME_PREFIX = "myExecutor--";

    private static Integer KEEP_ALIVE_SECONDS = 60;

    private static Integer AWAIT_TERMINATION_SECONDS = 60;


    public static ThreadPoolTaskExecutor taskExecutor() {
        ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
        //设置线程池参数信息
        taskExecutor.setCorePoolSize(CORE_POOL_SSIZE);
        taskExecutor.setMaxPoolSize(MAX_POOL_SSIZE);
        taskExecutor.setQueueCapacity(QUEUE_CAPACITY);
        taskExecutor.setKeepAliveSeconds(KEEP_ALIVE_SECONDS);
        taskExecutor.setThreadNamePrefix(THREAD_NAME_PREFIX);
        taskExecutor.setWaitForTasksToCompleteOnShutdown(true);
        //线程池中任务等待时间，超过等待时间直接销毁
        taskExecutor.setAwaitTerminationSeconds(AWAIT_TERMINATION_SECONDS);
        //修改拒绝策略为使用当前线程执行 ：CallerRunsPolicy
        taskExecutor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
        //初始化线程池
        taskExecutor.initialize();
        return taskExecutor;
    }

    public static void execute(Runnable task) {
        taskExecutor().execute(ThreadMdcUtil.wrap(task, MDC.getCopyOfContextMap()));
    }

    public static void execute(Runnable task, boolean wait) {
        ThreadPoolTaskExecutor executor = (ThreadPoolTaskExecutor) taskExecutor();
        executor.execute(ThreadMdcUtil.wrap(task, MDC.getCopyOfContextMap()));
        if (wait) {
            shutDownThread(executor);
        }
    }

    /**
     * 等待线程池中所有线程结束后关闭
     */
    public static void shutDownThread(ThreadPoolTaskExecutor taskExecutor) {
        //线程池
        ThreadPoolExecutor threadPoolExecutor = taskExecutor.getThreadPoolExecutor();
        //lst 队列任务
        BlockingQueue<Runnable> queue = threadPoolExecutor.getQueue();
        while (true) {
            int count = threadPoolExecutor.getActiveCount();
            int queueSize = queue.size();
            if (count == 0 && queueSize == 0) {
                taskExecutor.destroy();
                break;
            } else {
                log.info("线程池尚在工作中，当前触发{}个线程，队列中存在{}个任务排队", count, queueSize);
                try {
                    Thread.sleep(500L);
                } catch (Exception e) {
                    log.error("shutDownThread", e);
                }
            }
        }
    }

    public static <T> List<T> invokeAllTasks(List<Callable<T>> tasks) throws InterruptedException {
        ThreadPoolTaskExecutor executor = taskExecutor();
        List<Future<T>> futures = executor.getThreadPoolExecutor().invokeAll(tasks);
        List<T> results = new ArrayList<>();
        for (Future<T> future : futures) {
            try {
                T t = future.get();
                if (t != null) {
                    results.add(future.get());
                }
            } catch (ExecutionException e) {
                logger.error("Task execution failed: {}", e.getMessage(), e);
                if (e.getCause() instanceof NullPointerException) {
                    logger.error("NullPointerException in task", e.getCause());
                }
            }
        }
        shutDownThread(executor);
        return results;
    }

//    public static <T> List<T> invokeAllTasks(List<Callable<T>> tasks) throws InterruptedException {
//        ThreadPoolTaskExecutor executor = taskExecutor();
//        List<Future<T>> futures = executor.getThreadPoolExecutor().invokeAll(tasks);
//        List<T> results = new ArrayList<>();
//        for (Future<T> future : futures) {
//            try {
//                results.add(future.get());
//            } catch (ExecutionException e) {
//                logger.error("Task execution failed", e);
//            }
//        }
//        shutDownThread(executor);
//        return results;
//    }
}
