ThreadPool.h
#ifndef __THREADPOOL_H__
#define __THREADPOOL_H__
#include <mutex>
#include <unordered_map>
#include <memory>
#include <queue>
#include <condition_variable>
#include <functional>
#include <thread>
#include <future>
const int DEF_THREAD_CNT = std::thread::hardware_concurrency(); // 默认线程数量
const int MAX_THREAD_THRESHOLD = 1024; // 线程数量最大上限
const int DEF_TASK_CNT = 10; // 默认任务数量
const int MAX_TASK_THRESHOLD = 1024; // 任务数量最大上限
const int WAIT_SECONDS = 60; // 线程超时时间
/* 线程池的工作模式 */
enum PoolMode
{
MODE_CACHED = 1, // 动态模式
MODE_FIXED = 2 // 固定模式
};
// 线程
class Thread
{
using FUNC = std::function<void(int)>;
public:
Thread(FUNC func);
~Thread();
void start();
int getId();
private:
FUNC m_func; // 线程执行的操作
static int m_generateId;
int m_threadId;
};
// 线程池
class ThreadPool
{
using Task = std::function<void()>;
public:
ThreadPool();
~ThreadPool();
// 启动线程池
void start(int initThreadCnt = DEF_THREAD_CNT);
// 设置线程池工作模式
void setMode(PoolMode mode);
// 设置任务数量上限
void setTaskThreshold(int taskThreshold = MAX_TASK_THRESHOLD);
// 设置线程数量最大上限
void setThreadThreshold(int threadThreshold = MAX_THREAD_THRESHOLD);
// 提交任务
template <typename Func, typename... Args>
auto submitTask(Func&& func, Args&&... args) -> std::future<decltype(func(std::forward<Args>(args)...))>;
private:
void threadFunc(int threadId); // 线程执行的操作
private:
PoolMode m_mode; // 线程的工作模式
bool m_running; // 线程运行状态
std::unordered_map<int, std::unique_ptr<Thread>> m_threads; // 线程队列
int m_initThreadCnt; // 初始化线程数量
std::atomic_int m_currThreadCnt; // 当前线程数量
std::atomic_int m_idleThreadCnt; // 空闲线程数量
int m_threadMaxThreshold; // 线程数量最大上限
std::queue<Task> m_queueTask; // 任务队列
int m_taskCnt; // 任务数量
int m_taskMaxThreshold; // 任务数量最大上限
std::condition_variable m_exitCondVar; // 退出线程池条件变量
std::condition_variable m_taskNotFull; // 任务队列非满条件变量
std::condition_variable m_taskNotEmpty; // 任务队列非空条件变量
std::mutex m_mutex;
};
/* 提交任务 */
template<typename Func, typename ...Args>
auto ThreadPool::submitTask(Func&& func, Args&& ...args) -> std::future<decltype(func(std::forward<Args>(args)...))>
{
using RType = decltype(func(std::forward<Args>(args)...)); // 返回值类型
auto task = std::make_shared<std::packaged_task<RType(Args...)>>(func);
auto result = task->get_future();
{
std::unique_lock<std::mutex> lock(m_mutex);
// 判断当前队列是否已满
if (!m_taskNotFull.wait_for(lock, std::chrono::seconds(1), [&]{ return m_queueTask.size() < (size_t)m_taskMaxThreshold;}))
{
// 当前任务队列已满
std::cerr << "submit task failed, task queue is full!!!!" << std::endl; // 输出错误
auto task = std::make_shared<std::packaged_task<RType()>>(
[]()->RType { return RType(); });
(*task)();
return task->get_future();
}
m_queueTask.emplace([task, &args...] { // 添加到任务队列中
(*task)(std::forward<Args>(args)...);
});
m_taskCnt++;
m_taskNotEmpty.notify_all(); // 通知等待任务的线程准备获取任务
}
// MODE_CACHED模式,如果任务大于线程数量,则创建线程来执行这些任务
if (m_mode == MODE_CACHED &&
m_currThreadCnt < m_threadMaxThreshold &&
m_taskCnt > m_idleThreadCnt)
{
int needThreadCnt = m_taskCnt - m_idleThreadCnt; // 当前需要创建的线程数量
// 如果需要创建的线程数 + 当前线程数大于线程数量上限,则设置为两和之差
if ((needThreadCnt + m_currThreadCnt) > m_threadMaxThreshold)
needThreadCnt = (needThreadCnt + m_currThreadCnt) - m_threadMaxThreshold;
std::unique_lock<std::mutex> lock(m_mutex); // 获取锁用来创建线程
for (int i = 0; i < needThreadCnt; i++)
{
std::cout << "创建了新线程用来执行任务" << std::endl;
auto threadPtr = new Thread(std::bind(&ThreadPool::threadFunc, this, std::placeholders::_1));
m_threads.emplace(threadPtr->getId(), std::move(threadPtr));
threadPtr->start(); // 线程运行
m_idleThreadCnt++;
m_currThreadCnt++;
}
}
return result;
}
#endif
ThreadPool.cpp
#include "pch.h"
#include "ThreadPool.h"
////////////////////////////////// 线程
int Thread::m_generateId = 0; // 生成线程id标识符
Thread::Thread(FUNC func)
: m_func(func),
m_threadId(m_generateId++)
{}
Thread::~Thread() {}
int Thread::getId()
{
return m_threadId;
}
void Thread::start()
{
std::thread t(m_func, m_threadId);
t.detach();
}
////////////////////////////////// 线程池
ThreadPool::ThreadPool()
: m_running(false),
m_mode(MODE_FIXED),
m_initThreadCnt(DEF_THREAD_CNT),
m_currThreadCnt(DEF_THREAD_CNT),
m_idleThreadCnt(0),
m_threadMaxThreshold(MAX_THREAD_THRESHOLD),
m_taskCnt(0),
m_taskMaxThreshold(MAX_TASK_THRESHOLD)
{
}
ThreadPool::~ThreadPool()
{
std::unique_lock<std::mutex> lock(m_mutex);
m_running = false;
m_taskNotEmpty.notify_all();
m_exitCondVar.wait(lock, [&] { return m_threads.empty(); });
}
/* 创建线程并启动 */
void ThreadPool::start(int initThreadCnt)
{
if (m_running)
return;
m_running = true;
m_initThreadCnt = m_currThreadCnt = initThreadCnt;
// 创建线程
for (int i = 0; i < m_initThreadCnt; i++)
{
auto threadPtr = new Thread(std::bind(&ThreadPool::threadFunc, this, std::placeholders::_1));
m_threads.emplace(threadPtr->getId(), std::move(threadPtr));
m_idleThreadCnt++;
threadPtr->start(); // 启动线程
}
}
/* 设置线程池工作模式 */
void ThreadPool::setMode(PoolMode mode)
{
if (!m_running)
m_mode = mode;
}
/* 设置任务队列数量上限 */
void ThreadPool::setTaskThreshold(int taskThreshold)
{
if (!m_running)
m_taskMaxThreshold = taskThreshold;
}
/* 设置线程数数量上限 */
void ThreadPool::setThreadThreshold(int threadThreshold)
{
if (!m_running)
m_threadMaxThreshold = threadThreshold;
}
/* 线程执行的操作 */
void ThreadPool::threadFunc(int threadId)
{
printf("线程[%d] 开始执行\r\n", threadId);
auto lastTime = std::chrono::high_resolution_clock::now(); // 上一次执行任务的时间
while (m_running)
{
Task task = nullptr; // 存储从任务队列取出的任务
{
std::unique_lock<std::mutex> lock(m_mutex);
if (m_mode == MODE_CACHED) // 如果线程池是动态增长模式
{
// 等待一秒钟,判断是否没有任务要执行
if (std::cv_status::timeout == m_taskNotEmpty.wait_for(lock, std::chrono::seconds(1)))
{
// 没有任务要执行,获取当前时间
auto now = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::seconds>(now - lastTime);
// 如果当前线程大于等于 WAIT_SECONDS 时间没有执行任务,则回收线程
if (duration.count() >= WAIT_SECONDS)
{
m_idleThreadCnt--;
m_currThreadCnt--;
m_threads.erase(threadId); // 当前线程退出
return;
}
}
}
// 等待获取任务
m_taskNotEmpty.wait(lock, [&] { return !m_running || !m_queueTask.empty(); });
if (m_running)
{
printf("线程[%d] 获取任务\r\n", threadId);
task = m_queueTask.front();
m_queueTask.pop();
m_idleThreadCnt--;
m_taskCnt--;
m_taskNotFull.notify_all();
if (!m_queueTask.empty())
m_taskNotEmpty.notify_all();
}
else
break;
}
if (task)
{
m_taskNotFull.notify_all();
task();
printf("线程[%d] 执行任务完毕\r\n", threadId);
m_idleThreadCnt++;
lastTime = std::chrono::high_resolution_clock::now();
}
}
std::unique_lock<std::mutex> lock(m_mutex);
m_threads.erase(threadId);
m_exitCondVar.notify_all();
printf("线程[%d] 退出\r\n", threadId);
}
测试代码,如何使用
#include <iostream>
#include <chrnon>
#inlcude "ThreadPool.h"
using namespace std;
int test1(int&& a, int& b)
{
b++;
return a + b;
}
int main()
{
ThreadPool pool; // 创建一个线程池
// MODE_FIXED: 固定线程数量模式
// MODE_CACHED: 动态增长线程数量
pool.setMode(MODE_FIXED);
pool.setTaskThreadshold(100); // 设置最大任务数量上限
// MODE_FIXED模式不需要设置
//pool.setThreadThreshold(1024); // 设置最大线程数量上限
pool.start(20); // 启动线程,并设置当前线程池线程数量
// 提交任务给线程池
int a = 1, b = 2;
// 提交普通函数
future<int> f1 = pool.submitTask(test1, std::move(a), b);
// 提交lambda也可以作为任务
future<int> f2 = pool.submitTask([](int start, int end) {
int sum = 0;
for (; start < end; ++start)
sum += start;
return sum
}, 1, 100);
cout << "test1函数返回值: " << f1.get() << " b= " << b << endl;
cout << "lambda函数返回值: " << f2.get() << endl;
return 0;
}