Go 学习笔记(二十八)Go 实现工作池的两种方式

本文原创地址:博客园骏马金龙Go 基础系列:Go 实现工作池的两种方式

worker pool 简介

worker pool 其实就是线程池 thread pool。对于 go 来说,直接使用的是 goroutine 而非线程,不过这里仍然以线程来解释线程池。

在线程池模型中, 有 2 个队列一个池子:任务队列、已完成任务队列和线程池 。其中已完成任务队列可能存在也可能不存在,依据实际需求而定。

只要有任务进来,就会放进任务队列中。只要线程执行完了一个任务,就将任务放进已完成任务队列,有时候还会将任务的处理结果也放进已完成队列中。

worker pool 中包含了一堆的线程 (worker,对 go 而言每个 worker 就是一个 goroutine),这些线程嗷嗷待哺,等待着为它们分配任务,或者自己去任务队列中取任务。取得任务后更新任务队列,然后执行任务,并将执行完成的任务放进已完成队列。

下图来自 wiki:

1.png

在 Go 中有两种方式可以实现工作池:传统的互斥锁、channel。

传统互斥锁机制的工作池

假设 Go 中的任务的定义形式为:

type Task struct {
	...
}

每次有任务进来时,都将任务放在任务队列中。

使用传统的互斥锁方式实现,任务队列的定义结构大概如下:

type Queue struct{
	M     sync.Mutex
	Tasks []Task
}

然后在执行任务的函数中加上 Lock()和 Unlock()。例如:

func Worker(queue *Queue) {
	for {
		// Lock()和Unlock()之间的是critical section
		queue.M.Lock()
		// 取出任务
		task := queue.Tasks[0]
		// 更新任务队列
		queue.Tasks = queue.Tasks[1:]
		queue.M.Unlock()
		// 在此goroutine中执行任务
		process(task)
	}
}

假如在线程池中激活了 100 个 goroutine 来执行 Worker()。Lock() 和 Unlock()保证了在同一时间点只能有一个 goroutine 取得任务并随之更新任务列表,取任务和更新任务队列都是 critical section 中的代码,它们是具有原子性。然后这个 goroutine 可以执行自己取得的任务。于此同时,其它 goroutine 可以争夺互斥锁,只要争抢到互斥锁,就可以取得任务并更新任务列表。当某个 goroutine 执行完 process(task),它将因为 for 循环再次参与互斥锁的争抢。

上面只是给出了一点主要的代码段,要实现完整的线程池,还有很多额外的代码。

通过互斥锁,上面的一切操作都是线程安全的。但问题在于加锁 / 解锁的机制比较重量级,当 worker(即 goroutine) 的数量足够多,锁机制的实现将出现瓶颈。

通过 buffered channel 实现工作池

在 Go 中,也能用 buffered channel 实现工作池。

示例代码很长,所以这里先拆分解释每一部分,最后给出完整的代码段。

在下面的示例中,每个 worker 的工作都是计算每个数值的位数相加之和。例如给定一个数值 234,worker 则计算2+3+4=9。这里交给 worker 的数值是随机生成的 [0,999) 范围内的数值。

这个示例有几个核心功能需要先解释,也是通过 channel 实现线程池的一般功能:

  • 创建一个 task buffered channel,并通过 allocate() 函数将生成的任务存放到 task buffered channel 中
  • 创建一个 goroutine pool,每个 goroutine 监听 task buffered channel,并从中取出任务
  • goroutine 执行任务后,将结果写入到 result buffered channel 中
  • 从 result buffered channel 中取出计算结果并输出

首先,创建 Task 和 Result 两个结构,并创建它们的通道:

type Task struct {
	ID      int
	randnum int
}

type Result struct {
	task    Task
	result  int
}

var tasks = make(chan Task, 10)
var results = make(chan Result, 10)

这里,每个 Task 都有自己的 ID,以及该任务将要被 worker 计算的随机数。每个 Result 都包含了 worker 的计算结果 result 以及这个结果对应的 task,这样从 Result 中就可以取出任务信息以及计算结果。

另外,两个通道都是 buffered channel,容量都是 10。每个 worker 都会监听 tasks 通道,并取出其中的任务进行计算,然后将计算结果和任务自身放进 results 通道中。

然后是计算位数之和的函数 process(),它将作为 worker 的工作任务之一。

func process(num int) int {
	sum := 0
	for num != 0 {
		digit := num % 10
		sum += digit
		num /= 10
	}
	time.Sleep(2 * time.Second)
	return sum
}

这个计算过程其实很简单,但随后还睡眠了 2 秒,用来假装执行一个计算任务是需要一点时间的。

然后是 worker(),它监听 tasks 通道并取出任务进行计算,并将结果放进 results 通道。

func worker(wg *WaitGroup){
	defer wg.Done()
	for task := range tasks {
		result := Result{task, process(task.randnum)}
		results <- result
	}
}

上面的代码很容易理解,只要 tasks channel 不关闭,就会一直监听该 channel。需要注意的是,该函数使用指针类型的*WaitGroup作为参数,不能直接使用值类型的WaitGroup作为参数,否则会使得每个 worker 都有一个自己的 WaitGroup。

然后是创建工作池的函数 createWorkerPool(),它有一个数值参数,表示要创建多少个 worker。

func createWorkerPool(numOfWorkers int) {
	var wg sync.WaitGroup
	for i := 0; i < numOfWorkers; i++ {
		wg.Add(1)
		go worker(&wg)
	}
	wg.Wait()
	close(results)
}

创建工作池时,首先创建一个 WaitGroup 的值 wg,这个 wg 被工作池中的所有 goroutine 共享,每创建一个 goroutine 都 wg.Add(1)。创建完所有的 goroutine 后等待所有的 groutine 都执行完它们的任务,只要有一个任务还没有执行完,这个函数就会被 Wait() 阻塞。当所有任务都执行完成后,关闭 results 通道,因为没有结果再需要向该通道写了。

当然,这里是否需要关闭 results 通道,是由稍后的 range 迭代这个通道决定的,不关闭这个通道会一直阻塞 range,最终导致死锁。

工作池部分已经完成了。现在需要使用 allocate() 函数分配任务:生成一大堆的随机数,然后将 Task 放进 tasks 通道。该函数有一个代表创建任务数量的数值参数:

func allocate(numOfTasks int) {
	for i := 0; i < numOfTasks; i++ {
		randnum := rand.Intn(999)
		task := Task{i, randnum}
		tasks <- task
	}
	close(tasks)
}

注意,最后需要关闭 tasks 通道,因为所有任务都分配完之后,没有任务再需要分配。当然,这里之所以需要关闭 tasks 通道,是因为 worker() 中使用了 range 迭代 tasks 通道,如果不关闭这个通道,worker 将在取完所有任务后一直阻塞,最终导致死锁。

再接着的是取出 results 通道中的结果进行输出,函数名为 getResult():

func getResult(done chan bool) {
	for result := range results {
		fmt.Printf("Task id %d, randnum %d , sum %d\n", result.task.id, result.task.randnum, result.result)
	}
	done <- true
}

getResult()中使用了一个 done 参数,这个参数是一个信号通道,用来表示 results 中的所有结果都取出来并处理完成了,这个通道不一定要用 bool 类型,任何类型皆可,它不用来传数据,仅用来返回可读,所以上面直接 close(done) 的效果也一样。通过下面的 main() 函数,就能理解 done 信号通道的作用。

最后还差 main() 函数:

func main() {
	// 记录起始终止时间,用来测试完成所有任务耗费时长
	startTime := time.Now()
	
	numOfWorkers := 20
	numOfTasks := 100
	// 创建任务到任务队列中
	go allocate(numOfTasks)
	// 创建工作池
	go createWorkerPool(numOfWorkers)
	// 取得结果
	var done = make(chan bool)
	go getResult(done)

	// 如果results中还有数据,将阻塞在此
	// 直到发送了信号给done通道
	<- done
	endTime := time.Now()
	diff := endTime.Sub(startTime)
	fmt.Println("total time taken ", diff.Seconds(), "seconds")
}

上面分配了 20 个 worker,这 20 个 worker 总共需要处理的任务数量为 100。但注意,无论是 tasks 还是 results 通道,容量都是 10,意味着任务队列最长只能是 10 个任务。

下面是完整的代码段:

package main

import (
	"fmt"
	"math/rand"
	"sync"
	"time"
)

type Task struct {
	id      int
	randnum int
}
type Result struct {
	task   Task
	result int
}

var tasks = make(chan Task, 10)
var results = make(chan Result, 10)

func process(num int) int {
	sum := 0
	for num != 0 {
		digit := num % 10
		sum += digit
		num /= 10
	}
	time.Sleep(2 * time.Second)
	return sum
}
func worker(wg *sync.WaitGroup) {
	defer wg.Done()
	for task := range tasks {
		result := Result{task, process(task.randnum)}
		results <- result
	}
}
func createWorkerPool(numOfWorkers int) {
	var wg sync.WaitGroup
	for i := 0; i < numOfWorkers; i++ {
		wg.Add(1)
		go worker(&wg)
	}
	wg.Wait()
	close(results)
}
func allocate(numOfTasks int) {
	for i := 0; i < numOfTasks; i++ {
		randnum := rand.Intn(999)
		task := Task{i, randnum}
		tasks <- task
	}
	close(tasks)
}
func getResult(done chan bool) {
	for result := range results {
		fmt.Printf("Task id %d, randnum %d , sum %d\n", result.task.id, result.task.randnum, result.result)
	}
	done <- true
}
func main() {
	startTime := time.Now()
	numOfWorkers := 20
	numOfTasks := 100

	var done = make(chan bool)
	go getResult(done)
	go allocate(numOfTasks)
	go createWorkerPool(numOfWorkers)
	// 必须在allocate()和getResult()之后创建工作池
	<-done
	endTime := time.Now()
	diff := endTime.Sub(startTime)
	fmt.Println("total time taken ", diff.Seconds(), "seconds")
}

执行结果:

Task id 19, randnum 914 , sum 14
Task id 9, randnum 150 , sum 6
Task id 15, randnum 215 , sum 8
............
Task id 97, randnum 315 , sum 9
Task id 99, randnum 641 , sum 11
total time taken  10.0174705 seconds

总共花费 10 秒。

可以试着将任务数量、worker 数量修改修改,看看它们的性能比例情况。例如,将 worker 数量设置为 99,将需要 4 秒,将 worker 数量设置为 10,将需要 20 秒。
上一篇 Go 学习笔记(二十七)互斥锁 Mutex 和读写锁 RWMutex 用法详述
Go 学习笔记(目录)
下一篇 Go 学习笔记(二十九)惰性数值生成器