线程池的出现,是因为频繁地创建和销毁线程开销比较大。通过线程池,一个线程不仅仅是处理一个任务就被销毁,而是可以处理多个任务,任务被处理完时才被销毁。下图是Java ExecutorService 类的结构:
Java ExecutorService
协程池的作用是一样的,实现原理当然也一样。一个协程池的数据结构,在逻辑上必须3类数据:
协程池的容量
task pool 任务池
worker pool 协程池
task pool 比较容易实现,通过一个加锁的链表就可以实现。
对于worker pool,由于协程对用户不暴露任何ID和管理的API,启动后无法从外部主动管理,无法进行池化。我们能做的只有记录并限制协程的数量。
最常用的方法是创建一个缓冲区大小为N的channel,创建协程时,就向channel发送一条数据;协程退出时,就消费一条数据;通过len检查协程的数量;
另一种方法是通过一个原子变量进行控制,创建协程时,将原子变量加一;协程退出时,将原子变量减一;通过 Load 检查协程的数量;我们后面也采用这种方式。
除了这3块数据,还有一点比较关键:活跃的协程能够不断从task pool获取新的任务,以达到重复利用的效果,因此协程需要获得task pool的一个指针。
基于上面提到的理念,我们先把数据结构创建出来:
type pool struct {
cap int32 // 容量
workerCount int32 // worker数量
taskHead *task // task pool 头指针
taskTail *task // task pool 尾指针
taskLock sync.Mutex // task lock
}
type worker struct {
p *pool // 指向pool,需要获取task
}
type task struct {
f func() // 要执行的函数
next *task // 指向下一个task
}
在使用层面上,pool 需要对外暴露一个方法,以实现类似于 go func(... 的功能。在我们的定义里,它应该包含两部分逻辑: 1) 将任务添加到task pool, 供worker消费; 2) 按需创建新的 worker。其逻辑可以这样写:
func (p *pool) Go(f func()) {
t := &task{f: f}
p.taskLock.Lock()
if p.taskHead == nil {
p.taskHead = t
p.taskTail = t
} else {
p.taskTail.next = t
p.taskTail = t;
}
p.taskLock.Unlock()
// 创建新worker
if (atomic.LoadInt32(&p.workerCount) < p.cap) {
atomic.AddInt32(&p.workerCount, 1)
w := &worker{pool: p}
w.run() // run 待实现
}
}
注意,这里会校验协程的数量,并保证不会超过指定的容量。
worker的 run方法中,首先创建一个协程,在协程里不断消费task pool里的task:
func (w *worker) run() {
go func() {
for {
var t *task
w.pool.taskLock.Lock()
// 从 task pool 获取 task
if w.pool.taskHead != nil {
t = w.pool.taskHead
w.pool.taskHead = t.next;
}
// 如果没有任何task,则关闭该worker
if t == nil {
atomic.AddInt32(&w.pool.workerCount, -1)
w.pool.taskLock.Unlock()
return
}
w.pool.taskLock.Unlock()
// 执行函数
t.f()
}
}
}
在 for 循环里,有三步操作:
1. 从 task pool 获取 task
2. 如果获取不到task,则worker直接退出,退出前将workerCount计数减一
3. 获取到task以后,执行该task
至此,一个最基本的协程池的核心逻辑都有了。
在实际生产环境的线程池中,在内存上做了很大的优化。因为 task 和 worker对象的创建和销毁非常频繁,会频繁地触发GC,可以通过 sync.Pool 去管理,以减少不必要的分配和销毁。
还有一块涉及到异常处理的优化,比如在启动goroutine时设置 defer function 去捕获异常并打印函数调用栈:
defer func() {
if r := recover(); r != nil {
msg := fmt.Sprintf("GOPOOL: panic %v: %s", r, debug.Stack())
logger.Errorf(msg)
}
}()
到这里,协程池就叨完了,完整的代码参考 bytedance/gopkg: Github下的util/gopool目录。