Singleflight源码解读

为什么使用Singleflight?

使用Singleflight的目的:解决缓存击穿

什么是缓存击穿?

什么是缓存击穿? 缓存击穿简单来说,就是热点Key失效后,如果突然出现大量请求就会直达数据库,造成数据库负载升高

可能想的一个解决办法是,只允许一个请求访问数据库,并把数据回填到缓存中,后续请求访问缓存即可

而singleflight基本就是这个思路,多个并发请求到来,只有第一个协程执行任务,在将结果复用给其他协程

可以看看代码

1package main
2
3import (
4    "context"
5    "fmt"
6    "sync"
7    "time"
8
9    "golang.org/x/sync/singleflight"
10)
11
12var cache = sync.Map{}
13var g singleflight.Group
14
15func getFromCache(key string) (string, bool) {
16    if val, ok := cache.Load(key); ok {
17        return val.(string), true
18    }
19    return "", false
20}
21
22func setToCache(key, value string) {
23    cache.Store(key, value)
24}
25
26func getFromDB(key string) string {
27    fmt.Println("Querying database for:", key)
28    time.Sleep(100 * time.Millisecond) // 模拟数据库延迟
29    return "value_of_" + key
30}
31
32func GetValue(ctx context.Context, key string) (string, error) {
33    if val, ok := getFromCache(key); ok {
34        return val, nil
35    }
36
37    // 进入 singleflight 防止击穿
38    val, err, _ := g.Do(key, func() (interface{}, error) {
39        value := getFromDB(key)
40        setToCache(key, value)
41        return value, nil
42    })
43    if err != nil {
44        return "", err
45    }
46    return val.(string), nil
47}

但是,这次,出于我本人的好奇,我想看看singleflight具体是如何实现的,而且其源码也不长,总共200多行

源码解读

Group 和 call 结构体

首先,我们先看看Group 这个类型,他拥有Do这个方法

1// Group represents a class of work and forms a namespace in
2// which units of work can be executed with duplicate suppression.
3// 翻译了下:
4// 组代表一类工作,并形成一个命名空间,其中的工作单元可以在重复抑制的情况下执行。
5type Group struct {
6        mu sync.Mutex       // 由于map不能支持并发读写,所以需要互斥锁来保护m
7        m  map[string]*call // 延迟初始化
8}

结构上还是很简单的,大体上就是使用一个map来保存不同key的请求,然后使用sync.Mutex来保护map,防止并发读写

接下来我们看看call及其子类型

1// call is an in-flight or completed singleflight.Do call
2// 翻译了一下: 大致意思是说这个call代表了一个正在执行或者已经完成的请求
3type call struct {
4        wg sync.WaitGroup
5
6        // These fields are written once before the WaitGroup is done
7        // and are only read after the WaitGroup is done.
8        // 这个意思是,下面的val和err只能在"WaitGroup" done 之前才可以也只能被写入一次
9        // 并且只能在"WaitGroup" done 之后被读取
10        val interface{}
11        err error
12
13        // These fields are read and written with the singleflight
14        // mutex held before the WaitGroup is done, and are read but
15        // not written after the WaitGroup is done.
16        //意思就是:
17        // 这些字段在 WaitGroup 完成之前使用 singleflight 互斥锁进行读写,
18        // 在 WaitGroup 完成之后,这些字段会被读取,但不会被写入。
19        dups  int  //这个是记录这个 key 被分享了多少次    
20        chans []chan<- Result // 执行DoChan会被用到
21}
22
23// Result holds the results of Do, so they can be passed
24// on a channel.
25// 意思是
26// 这个Result 保存了Do的结果,因此它们可以通过通道传递
27type Result struct {
28        Val    interface{}
29        Err    error
30        Shared bool
31}

暂时,我们只是了解了 这些类型有哪些字段,但字段的用途可能还有些不了解,别急,我们接着看

Do

我们看看最主要的Do方法,具体的解释,我写在源码中

1func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, err error, shared bool) {
2        // 给group上锁,因为要开始读取g.m
3        g.mu.Lock()
4        // 之前说过这个m是延迟初始化的,所以我们需要判断下是否需要初始化
5        if g.m == nil {
6                g.m = make(map[string]*call)
7        }
8        // 检查这个m中是否有key的这个请求
9        if c, ok := g.m[key]; ok {
10                // 如果有,c.dup++
11                // 这个 dup应该是call被共享的次数
12                c.dups++
13                // 解锁
14                g.mu.Unlock()
15                // 等待call结束
16                c.wg.Wait()
17                   
18                // 走到这一步说明,请求完成,处理结果
19                if e, ok := c.err.(*panicError); ok {
20                        panic(e)
21                } else if c.err == errGoexit {
22                // runtime.Goexit 是 Go 标准库 runtime 包中的一个函数,
23                // 用来 立即终止当前 goroutine 的执行,但 不会终止整个程序或其他 goroutines
24                        runtime.Goexit()
25                }
26                // 返回结果
27                return c.val, c.err, true
28        }
29        // 这里说明,这个是第一个请求的gorountine
30        // 创建一个新的call
31        c := new(call)
32        // wg add 1
33        c.wg.Add(1)
34        // 给map中对应的key赋值
35        g.m[key] = c
36        // 写结束,解锁
37        g.mu.Unlock()
38        
39        // 执行任务
40        g.doCall(c, key, fn)
41        return c.val, c.err, c.dups > 0
42}
43
44
45func (g *Group) doCall(c *call, key string, fn func() (interface{}, error)) {
46        // 标志是否正常返回
47        normalReturn := false
48        // 捕捉到了panic
49        recovered := false
50
51        // use double-defer to distinguish panic from runtime.Goexit,
52        // more details see https://golang.org/cl/134395
53        defer func() {
54                // 如果没有正常返回,并且没有成功捕捉到panic就设置错误errGoexit,使得其他请求退出
55                if !normalReturn && !recovered {
56                        c.err = errGoexit
57                }
58                
59                // 任务结束,需要删除对应的map中对应的key
60                // 由于是对map的写操作,需要加锁
61                g.mu.Lock()
62                defer g.mu.Unlock()
63                
64                // 此次任务完成,wg done
65                c.wg.Done()
66                // 检查对应key的值是否还是c
67                // 如果是,就删除对应的key
68                if g.m[key] == c {
69                        delete(g.m, key)
70                }
71
72                if e, ok := c.err.(*panicError); ok {
73                        // 为了防止等待的通道被永远阻塞,
74                        // 需要确保此panic无法恢复。
75                        if len(c.chans) > 0 {
76                                go panic(e)
77                                select {} // Keep this goroutine around so that it will appear in the crash dump.
78                        } else {
79                                panic(e)
80                        }
81                } else if c.err == errGoexit {
82                        // Already in the process of goexit, no need to call again
83                } else {
84                        // 正常返回
85                        for _, ch := range c.chans {
86                                ch <- Result{c.val, c.err, c.dups > 0}
87                        }
88                }
89        }()
90
91        func() {
92                defer func() {
93                        if !normalReturn {
94                                // 理想情况下,我们应该等到确定了以下情况后再进行堆栈跟踪:
95                                // 这是一个恐慌还是一个 Runtime.GoExit。
96                                // 不幸的是,我们区分这两者的唯一方法是查看
97                                // 恢复是否阻止了 Goroutine 的终止,而
98                                // 当我们知道这一点时,与恐慌相关的堆栈跟踪部分已经被丢弃了。
99                                if r := recover(); r != nil {
100                                        c.err = newPanicError(r)
101                                }
102                        }
103                }()
104                // 这里真正执行具体的任务
105                // 并标记正常返回
106                // 这里有一种特殊情况
107                // 就是fn()中有调用runtime.Goexit()
108                // 就会直接截断gorountine
109                // 但是注册的defer函数还是会被执行
110                c.val, c.err = fn()
111                normalReturn = true
112        }()
113
114        if !normalReturn {
115                recovered = true
116        }
117}

DoChan

其实还有个DoChan方法,但和Do方法不同的是,他返回的是一个channel,当第一个gorountine执行完后,就会往每一个正在等待的gorountine的channel中放入值

1// DoChan is like Do but returns a channel that will receive the
2// results when they are ready.
3//
4// The returned channel will not be closed.
5func (g *Group) DoChan(key string, fn func() (interface{}, error)) <-chan Result {
6        ch := make(chan Result, 1)
7        g.mu.Lock()
8        if g.m == nil {
9                g.m = make(map[string]*call)
10        }
11        if c, ok := g.m[key]; ok {
12                c.dups++
13                c.chans = append(c.chans, ch)
14                g.mu.Unlock()
15                return ch
16        }
17        c := &call{chans: []chan<- Result{ch}}
18        c.wg.Add(1)
19        g.m[key] = c
20        g.mu.Unlock()
21
22        go g.doCall(c, key, fn)
23
24        return ch
25}