Files
golang/golang-learning/09-advanced/03-context.go
2025-08-24 13:01:09 +08:00

766 lines
21 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/*
03-context.go - Go 语言 Context 包详解
学习目标:
1. 理解 Context 的概念和作用
2. 掌握 Context 的创建和使用方法
3. 学会在并发程序中传递取消信号
4. 了解 Context 的超时和截止时间机制
5. 掌握 Context 的最佳实践
知识点:
- Context 接口和实现
- WithCancel, WithTimeout, WithDeadline
- WithValue 传递请求范围的数据
- Context 在 HTTP 服务中的应用
- Context 的传播和继承
- 避免 Context 的常见陷阱
*/
package main
import (
"context"
"fmt"
"net/http"
"sync"
"time"
)
func main() {
fmt.Println("=== Go 语言 Context 包详解 ===\n")
// 演示 Context 的基本概念
demonstrateContextBasics()
// 演示取消 Context
demonstrateCancellation()
// 演示超时 Context
demonstrateTimeout()
// 演示截止时间 Context
demonstrateDeadline()
// 演示 Context 传递值
demonstrateWithValue()
// 演示 Context 在并发中的应用
demonstrateConcurrencyWithContext()
// 演示 Context 在 HTTP 中的应用
demonstrateHTTPContext()
// 演示 Context 的最佳实践
demonstrateContextBestPractices()
}
// demonstrateContextBasics 演示 Context 的基本概念
func demonstrateContextBasics() {
fmt.Println("1. Context 的基本概念:")
// Context 的概念
fmt.Printf(" Context 的概念:\n")
fmt.Printf(" - Context 是 Go 语言中用于传递请求范围数据的标准方式\n")
fmt.Printf(" - 提供取消信号、超时控制和请求范围值传递\n")
fmt.Printf(" - 在 goroutine 之间传递取消信号和截止时间\n")
fmt.Printf(" - 避免 goroutine 泄漏和资源浪费\n")
// Context 接口
fmt.Printf(" Context 接口:\n")
fmt.Printf(" type Context interface {\n")
fmt.Printf(" Deadline() (deadline time.Time, ok bool)\n")
fmt.Printf(" Done() <-chan struct{}\n")
fmt.Printf(" Err() error\n")
fmt.Printf(" Value(key interface{}) interface{}\n")
fmt.Printf(" }\n")
// 创建根 Context
fmt.Printf(" 创建根 Context:\n")
// Background Context
bgCtx := context.Background()
fmt.Printf(" Background Context: %v\n", bgCtx)
fmt.Printf(" - 通常用作根 Context\n")
fmt.Printf(" - 永远不会被取消,没有值,没有截止时间\n")
// TODO Context
todoCtx := context.TODO()
fmt.Printf(" TODO Context: %v\n", todoCtx)
fmt.Printf(" - 当不确定使用哪个 Context 时使用\n")
fmt.Printf(" - 通常在重构时作为占位符\n")
// Context 的方法
fmt.Printf(" Context 的方法:\n")
deadline, ok := bgCtx.Deadline()
fmt.Printf(" Deadline(): %v, %t (是否有截止时间)\n", deadline, ok)
fmt.Printf(" Done(): %v (取消通道)\n", bgCtx.Done())
fmt.Printf(" Err(): %v (错误信息)\n", bgCtx.Err())
fmt.Printf(" Value(key): %v (获取值)\n", bgCtx.Value("key"))
fmt.Println()
}
// demonstrateCancellation 演示取消 Context
func demonstrateCancellation() {
fmt.Println("2. 取消 Context:")
// WithCancel 的使用
fmt.Printf(" WithCancel 的使用:\n")
fmt.Printf(" ctx, cancel := context.WithCancel(context.Background())\n")
fmt.Printf(" defer cancel() // 确保释放资源\n")
// 创建可取消的 Context
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// 启动一个 goroutine 监听取消信号
fmt.Printf(" 启动监听取消信号的 goroutine:\n")
go func() {
select {
case <-ctx.Done():
fmt.Printf(" goroutine 收到取消信号: %v\n", ctx.Err())
case <-time.After(5 * time.Second):
fmt.Printf(" goroutine 超时退出\n")
}
}()
// 模拟一些工作
time.Sleep(100 * time.Millisecond)
fmt.Printf(" 执行取消操作...\n")
cancel() // 发送取消信号
// 等待 goroutine 处理取消信号
time.Sleep(100 * time.Millisecond)
// 检查 Context 状态
fmt.Printf(" Context 状态:\n")
fmt.Printf(" Done(): %v\n", ctx.Done() != nil)
fmt.Printf(" Err(): %v\n", ctx.Err())
// 演示取消传播
fmt.Printf(" 取消传播:\n")
parentCtx, parentCancel := context.WithCancel(context.Background())
childCtx, childCancel := context.WithCancel(parentCtx)
defer parentCancel()
defer childCancel()
// 取消父 Context
parentCancel()
// 检查子 Context 是否也被取消
select {
case <-childCtx.Done():
fmt.Printf(" 子 Context 也被取消了: %v\n", childCtx.Err())
case <-time.After(100 * time.Millisecond):
fmt.Printf(" 子 Context 没有被取消\n")
}
fmt.Println()
}
// demonstrateTimeout 演示超时 Context
func demonstrateTimeout() {
fmt.Println("3. 超时 Context:")
// WithTimeout 的使用
fmt.Printf(" WithTimeout 的使用:\n")
fmt.Printf(" ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)\n")
fmt.Printf(" defer cancel()\n")
// 创建超时 Context
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
// 模拟长时间运行的操作
fmt.Printf(" 模拟长时间操作:\n")
start := time.Now()
select {
case <-time.After(500 * time.Millisecond):
fmt.Printf(" 操作完成,耗时: %v\n", time.Since(start))
case <-ctx.Done():
fmt.Printf(" 操作被超时取消,耗时: %v错误: %v\n", time.Since(start), ctx.Err())
}
// 演示超时处理函数
fmt.Printf(" 超时处理函数示例:\n")
result, err := doWorkWithTimeout(300 * time.Millisecond)
if err != nil {
fmt.Printf(" 工作超时: %v\n", err)
} else {
fmt.Printf(" 工作完成: %s\n", result)
}
result, err = doWorkWithTimeout(100 * time.Millisecond)
if err != nil {
fmt.Printf(" 工作超时: %v\n", err)
} else {
fmt.Printf(" 工作完成: %s\n", result)
}
// 演示 HTTP 请求超时
fmt.Printf(" HTTP 请求超时示例:\n")
err = makeHTTPRequestWithTimeout("https://httpbin.org/delay/1", 500*time.Millisecond)
if err != nil {
fmt.Printf(" HTTP 请求失败: %v\n", err)
} else {
fmt.Printf(" HTTP 请求成功\n")
}
fmt.Println()
}
// demonstrateDeadline 演示截止时间 Context
func demonstrateDeadline() {
fmt.Println("4. 截止时间 Context:")
// WithDeadline 的使用
fmt.Printf(" WithDeadline 的使用:\n")
deadline := time.Now().Add(200 * time.Millisecond)
fmt.Printf(" deadline := time.Now().Add(200 * time.Millisecond)\n")
fmt.Printf(" ctx, cancel := context.WithDeadline(context.Background(), deadline)\n")
ctx, cancel := context.WithDeadline(context.Background(), deadline)
defer cancel()
// 检查截止时间
ctxDeadline, ok := ctx.Deadline()
fmt.Printf(" Context 截止时间:\n")
fmt.Printf(" 截止时间: %v\n", ctxDeadline)
fmt.Printf(" 有截止时间: %t\n", ok)
fmt.Printf(" 距离截止时间: %v\n", time.Until(ctxDeadline))
// 等待截止时间到达
fmt.Printf(" 等待截止时间到达:\n")
start := time.Now()
<-ctx.Done()
fmt.Printf(" Context 在 %v 后被取消,错误: %v\n", time.Since(start), ctx.Err())
// 演示截止时间检查
fmt.Printf(" 截止时间检查示例:\n")
checkDeadline(context.Background())
deadlineCtx, deadlineCancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond))
defer deadlineCancel()
checkDeadline(deadlineCtx)
fmt.Println()
}
// demonstrateWithValue 演示 Context 传递值
func demonstrateWithValue() {
fmt.Println("5. Context 传递值:")
// WithValue 的使用
fmt.Printf(" WithValue 的使用:\n")
fmt.Printf(" ctx := context.WithValue(context.Background(), \\\"userID\\\", 12345)\n")
// 创建带值的 Context
ctx := context.WithValue(context.Background(), "userID", 12345)
ctx = context.WithValue(ctx, "requestID", "req-abc-123")
ctx = context.WithValue(ctx, "traceID", "trace-xyz-789")
// 获取值
fmt.Printf(" 获取 Context 中的值:\n")
if userID := ctx.Value("userID"); userID != nil {
fmt.Printf(" 用户ID: %v\n", userID)
}
if requestID := ctx.Value("requestID"); requestID != nil {
fmt.Printf(" 请求ID: %v\n", requestID)
}
if traceID := ctx.Value("traceID"); traceID != nil {
fmt.Printf(" 追踪ID: %v\n", traceID)
}
// 值不存在的情况
if sessionID := ctx.Value("sessionID"); sessionID != nil {
fmt.Printf(" 会话ID: %v\n", sessionID)
} else {
fmt.Printf(" 会话ID: 不存在\n")
}
// 演示类型安全的键
fmt.Printf(" 类型安全的键:\n")
type contextKey string
const (
userIDKey contextKey = "userID"
requestIDKey contextKey = "requestID"
)
safeCtx := context.WithValue(context.Background(), userIDKey, 67890)
safeCtx = context.WithValue(safeCtx, requestIDKey, "req-def-456")
// 使用类型安全的方式获取值
if userID := safeCtx.Value(userIDKey); userID != nil {
fmt.Printf(" 安全获取用户ID: %v\n", userID)
}
// 演示值的传播
fmt.Printf(" 值的传播:\n")
processRequest(ctx)
// 演示值的最佳实践
fmt.Printf(" 值的最佳实践:\n")
fmt.Printf(" 1. 只存储请求范围的数据\n")
fmt.Printf(" 2. 使用类型安全的键\n")
fmt.Printf(" 3. 不要存储可选参数\n")
fmt.Printf(" 4. 键应该是不可导出的\n")
fmt.Println()
}
// demonstrateConcurrencyWithContext 演示 Context 在并发中的应用
func demonstrateConcurrencyWithContext() {
fmt.Println("6. Context 在并发中的应用:")
// 演示工作池模式
fmt.Printf(" 工作池模式:\n")
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
// 创建工作通道
jobs := make(chan int, 10)
results := make(chan string, 10)
// 启动工作者
var wg sync.WaitGroup
for i := 0; i < 3; i++ {
wg.Add(1)
go worker(ctx, i, jobs, results, &wg)
}
// 发送工作
go func() {
for i := 1; i <= 8; i++ {
select {
case jobs <- i:
fmt.Printf(" 发送工作 %d\n", i)
case <-ctx.Done():
fmt.Printf(" 停止发送工作: %v\n", ctx.Err())
close(jobs)
return
}
time.Sleep(100 * time.Millisecond)
}
close(jobs)
}()
// 收集结果
go func() {
wg.Wait()
close(results)
}()
// 打印结果
fmt.Printf(" 工作结果:\n")
for result := range results {
fmt.Printf(" %s\n", result)
}
// 演示扇出扇入模式
fmt.Printf(" 扇出扇入模式:\n")
fanOutFanIn()
fmt.Println()
}
// demonstrateHTTPContext 演示 Context 在 HTTP 中的应用
func demonstrateHTTPContext() {
fmt.Println("7. Context 在 HTTP 中的应用:")
// HTTP 请求中的 Context
fmt.Printf(" HTTP 请求中的 Context:\n")
fmt.Printf(" - 每个 HTTP 请求都有一个关联的 Context\n")
fmt.Printf(" - 可以通过 r.Context() 获取\n")
fmt.Printf(" - 请求取消时 Context 也会被取消\n")
// 模拟 HTTP 处理器
fmt.Printf(" HTTP 处理器示例:\n")
// 创建模拟请求
req, _ := http.NewRequest("GET", "/api/data", nil)
// 添加超时
ctx, cancel := context.WithTimeout(req.Context(), 1*time.Second)
defer cancel()
req = req.WithContext(ctx)
// 模拟处理请求
fmt.Printf(" 处理请求...\n")
handleRequest(req)
// 演示中间件模式
fmt.Printf(" 中间件模式:\n")
fmt.Printf(" func middleware(next http.Handler) http.Handler {\n")
fmt.Printf(" return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {\n")
fmt.Printf(" ctx := context.WithValue(r.Context(), \\\"requestID\\\", generateID())\n")
fmt.Printf(" next.ServeHTTP(w, r.WithContext(ctx))\n")
fmt.Printf(" })\n")
fmt.Printf(" }\n")
// 演示数据库查询超时
fmt.Printf(" 数据库查询超时:\n")
queryCtx, queryCancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer queryCancel()
err := queryDatabase(queryCtx, "SELECT * FROM users WHERE id = ?", 123)
if err != nil {
fmt.Printf(" 数据库查询失败: %v\n", err)
} else {
fmt.Printf(" 数据库查询成功\n")
}
fmt.Println()
}
// demonstrateContextBestPractices 演示 Context 的最佳实践
func demonstrateContextBestPractices() {
fmt.Println("8. Context 的最佳实践:")
// 最佳实践列表
fmt.Printf(" Context 的最佳实践:\n")
fmt.Printf(" 1. 不要将 Context 存储在结构体中\n")
fmt.Printf(" 2. Context 应该作为函数的第一个参数\n")
fmt.Printf(" 3. 不要传递 nil Context使用 context.TODO()\n")
fmt.Printf(" 4. Context.Value 只用于请求范围的数据\n")
fmt.Printf(" 5. 使用 defer cancel() 确保资源释放\n")
fmt.Printf(" 6. 不要忽略 Context 的取消信号\n")
fmt.Printf(" 7. 使用类型安全的键\n")
fmt.Printf(" 8. 避免在 Context 中存储可选参数\n")
// 正确的函数签名
fmt.Printf(" 正确的函数签名:\n")
fmt.Printf(" func DoSomething(ctx context.Context, arg string) error\n")
fmt.Printf(" func (s *Service) Process(ctx context.Context, data []byte) (*Result, error)\n")
// 错误的用法
fmt.Printf(" 避免的错误用法:\n")
fmt.Printf(" // 错误:不要在结构体中存储 Context\n")
fmt.Printf(" type Server struct {\n")
fmt.Printf(" ctx context.Context // 错误\n")
fmt.Printf(" }\n")
fmt.Printf(" \n")
fmt.Printf(" // 错误:不要传递 nil Context\n")
fmt.Printf(" DoSomething(nil, \\\"data\\\") // 错误\n")
// 正确的用法
fmt.Printf(" 正确的用法示例:\n")
ctx := context.Background()
// 正确的取消处理
processWithCancel(ctx)
// 正确的超时处理
processWithTimeout(ctx)
// 正确的值传递
processWithValue(ctx)
// Context 的性能考虑
fmt.Printf(" 性能考虑:\n")
fmt.Printf(" 1. Context.Value 的查找是 O(n) 的\n")
fmt.Printf(" 2. 避免在热路径中频繁创建 Context\n")
fmt.Printf(" 3. 合理使用 Context 的层次结构\n")
fmt.Printf(" 4. 及时调用 cancel 函数释放资源\n")
// 常见陷阱
fmt.Printf(" 常见陷阱:\n")
fmt.Printf(" 1. 忘记调用 cancel 函数导致资源泄漏\n")
fmt.Printf(" 2. 在循环中创建过多的 Context\n")
fmt.Printf(" 3. 将 Context 用作可选参数的容器\n")
fmt.Printf(" 4. 不检查 Context 的取消信号\n")
fmt.Println()
}
// ========== 辅助函数 ==========
// doWorkWithTimeout 带超时的工作函数
func doWorkWithTimeout(timeout time.Duration) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
// 模拟工作
select {
case <-time.After(200 * time.Millisecond):
return "工作完成", nil
case <-ctx.Done():
return "", ctx.Err()
}
}
// makeHTTPRequestWithTimeout 带超时的 HTTP 请求
func makeHTTPRequestWithTimeout(url string, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return err
}
// 模拟 HTTP 请求
select {
case <-time.After(300 * time.Millisecond):
return nil // 请求成功
case <-ctx.Done():
return ctx.Err()
}
}
// checkDeadline 检查截止时间
func checkDeadline(ctx context.Context) {
if deadline, ok := ctx.Deadline(); ok {
fmt.Printf(" 有截止时间: %v剩余时间: %v\n", deadline, time.Until(deadline))
} else {
fmt.Printf(" 没有截止时间\n")
}
}
// processRequest 处理请求
func processRequest(ctx context.Context) {
if userID := ctx.Value("userID"); userID != nil {
fmt.Printf(" 处理用户 %v 的请求\n", userID)
}
if requestID := ctx.Value("requestID"); requestID != nil {
fmt.Printf(" 请求ID: %v\n", requestID)
}
}
// worker 工作者函数
func worker(ctx context.Context, id int, jobs <-chan int, results chan<- string, wg *sync.WaitGroup) {
defer wg.Done()
for {
select {
case job, ok := <-jobs:
if !ok {
fmt.Printf(" 工作者 %d 退出(通道关闭)\n", id)
return
}
// 模拟工作
select {
case <-time.After(200 * time.Millisecond):
results <- fmt.Sprintf("工作者 %d 完成工作 %d", id, job)
case <-ctx.Done():
fmt.Printf(" 工作者 %d 被取消: %v\n", id, ctx.Err())
return
}
case <-ctx.Done():
fmt.Printf(" 工作者 %d 被取消: %v\n", id, ctx.Err())
return
}
}
}
// fanOutFanIn 扇出扇入模式
func fanOutFanIn() {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
// 输入通道
input := make(chan int, 5)
// 启动多个处理器(扇出)
c1 := process(ctx, input)
c2 := process(ctx, input)
c3 := process(ctx, input)
// 合并结果(扇入)
output := merge(ctx, c1, c2, c3)
// 发送数据
go func() {
defer close(input)
for i := 1; i <= 6; i++ {
select {
case input <- i:
case <-ctx.Done():
return
}
}
}()
// 收集结果
for result := range output {
fmt.Printf(" 扇出扇入结果: %s\n", result)
}
}
// process 处理函数
func process(ctx context.Context, input <-chan int) <-chan string {
output := make(chan string)
go func() {
defer close(output)
for {
select {
case n, ok := <-input:
if !ok {
return
}
select {
case output <- fmt.Sprintf("处理 %d", n*n):
case <-ctx.Done():
return
}
case <-ctx.Done():
return
}
}
}()
return output
}
// merge 合并多个通道
func merge(ctx context.Context, channels ...<-chan string) <-chan string {
var wg sync.WaitGroup
output := make(chan string)
// 为每个输入通道启动一个 goroutine
multiplex := func(c <-chan string) {
defer wg.Done()
for {
select {
case s, ok := <-c:
if !ok {
return
}
select {
case output <- s:
case <-ctx.Done():
return
}
case <-ctx.Done():
return
}
}
}
wg.Add(len(channels))
for _, c := range channels {
go multiplex(c)
}
// 等待所有 goroutine 完成
go func() {
wg.Wait()
close(output)
}()
return output
}
// handleRequest 处理 HTTP 请求
func handleRequest(req *http.Request) {
ctx := req.Context()
// 模拟处理
select {
case <-time.After(500 * time.Millisecond):
fmt.Printf(" 请求处理完成\n")
case <-ctx.Done():
fmt.Printf(" 请求被取消: %v\n", ctx.Err())
}
}
// queryDatabase 查询数据库
func queryDatabase(ctx context.Context, query string, args ...interface{}) error {
// 模拟数据库查询
select {
case <-time.After(300 * time.Millisecond):
return nil // 查询成功
case <-ctx.Done():
return ctx.Err()
}
}
// processWithCancel 带取消的处理
func processWithCancel(ctx context.Context) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
// 模拟处理
go func() {
time.Sleep(100 * time.Millisecond)
cancel() // 取消操作
}()
select {
case <-time.After(200 * time.Millisecond):
fmt.Printf(" 处理完成\n")
case <-ctx.Done():
fmt.Printf(" 处理被取消\n")
}
}
// processWithTimeout 带超时的处理
func processWithTimeout(ctx context.Context) {
ctx, cancel := context.WithTimeout(ctx, 50*time.Millisecond)
defer cancel()
select {
case <-time.After(100 * time.Millisecond):
fmt.Printf(" 处理完成\n")
case <-ctx.Done():
fmt.Printf(" 处理超时\n")
}
}
// processWithValue 带值传递的处理
func processWithValue(ctx context.Context) {
ctx = context.WithValue(ctx, "operation", "process")
if op := ctx.Value("operation"); op != nil {
fmt.Printf(" 执行操作: %v\n", op)
}
}
/*
运行这个程序:
go run 03-context.go
Context 的核心概念:
1. Context 是 Go 语言中处理请求范围数据的标准方式
2. 提供取消信号、超时控制和值传递功能
3. 在 goroutine 之间传递取消信号和截止时间
4. 避免 goroutine 泄漏和资源浪费
Context 的主要类型:
1. context.Background(): 根 Context通常用作起点
2. context.TODO(): 当不确定使用哪个 Context 时使用
3. context.WithCancel(): 创建可取消的 Context
4. context.WithTimeout(): 创建带超时的 Context
5. context.WithDeadline(): 创建带截止时间的 Context
6. context.WithValue(): 创建带值的 Context
Context 的最佳实践:
1. 不要将 Context 存储在结构体中
2. Context 应该作为函数的第一个参数
3. 不要传递 nil Context使用 context.TODO()
4. Context.Value 只用于请求范围的数据
5. 使用 defer cancel() 确保资源释放
6. 不要忽略 Context 的取消信号
7. 使用类型安全的键
8. 避免在 Context 中存储可选参数
Context 的应用场景:
1. HTTP 请求处理
2. 数据库查询超时
3. 并发任务协调
4. 微服务调用链
5. 长时间运行的任务控制
注意事项:
1. Context 是并发安全的
2. Context 的取消会传播到所有子 Context
3. Context.Value 的查找是 O(n) 的
4. 及时调用 cancel 函数释放资源
5. 不要在 Context 中存储可选参数
常见错误:
1. 忘记调用 cancel 函数
2. 将 Context 存储在结构体中
3. 传递 nil Context
4. 不检查 Context 的取消信号
5. 在 Context 中存储过多数据
*/