/* 03-context.go - Go 语言 Context 包详解 学习目标: 1. 理解 Context 的概念和作用 2. 掌握 Context 的创建和使用方法 3. 学会在并发程序中传递取消信号 4. 了解 Context 的超时和截止时间机制 5. 掌握 Context 的最佳实践 知识点: - Context 接口和实现 - WithCancel, WithTimeout, WithDeadline - WithValue 传递请求范围的数据 - Context 在 HTTP 服务中的应用 - Context 的传播和继承 - 避免 Context 的常见陷阱 */ package main import ( "context" "fmt" "net/http" "sync" "time" ) func main() { fmt.Println("=== Go 语言 Context 包详解 ===\n") // 演示 Context 的基本概念 demonstrateContextBasics() // 演示取消 Context demonstrateCancellation() // 演示超时 Context demonstrateTimeout() // 演示截止时间 Context demonstrateDeadline() // 演示 Context 传递值 demonstrateWithValue() // 演示 Context 在并发中的应用 demonstrateConcurrencyWithContext() // 演示 Context 在 HTTP 中的应用 demonstrateHTTPContext() // 演示 Context 的最佳实践 demonstrateContextBestPractices() } // demonstrateContextBasics 演示 Context 的基本概念 func demonstrateContextBasics() { fmt.Println("1. Context 的基本概念:") // Context 的概念 fmt.Printf(" Context 的概念:\n") fmt.Printf(" - Context 是 Go 语言中用于传递请求范围数据的标准方式\n") fmt.Printf(" - 提供取消信号、超时控制和请求范围值传递\n") fmt.Printf(" - 在 goroutine 之间传递取消信号和截止时间\n") fmt.Printf(" - 避免 goroutine 泄漏和资源浪费\n") // Context 接口 fmt.Printf(" Context 接口:\n") fmt.Printf(" type Context interface {\n") fmt.Printf(" Deadline() (deadline time.Time, ok bool)\n") fmt.Printf(" Done() <-chan struct{}\n") fmt.Printf(" Err() error\n") fmt.Printf(" Value(key interface{}) interface{}\n") fmt.Printf(" }\n") // 创建根 Context fmt.Printf(" 创建根 Context:\n") // Background Context bgCtx := context.Background() fmt.Printf(" Background Context: %v\n", bgCtx) fmt.Printf(" - 通常用作根 Context\n") fmt.Printf(" - 永远不会被取消,没有值,没有截止时间\n") // TODO Context todoCtx := context.TODO() fmt.Printf(" TODO Context: %v\n", todoCtx) fmt.Printf(" - 当不确定使用哪个 Context 时使用\n") fmt.Printf(" - 通常在重构时作为占位符\n") // Context 的方法 fmt.Printf(" Context 的方法:\n") deadline, ok := bgCtx.Deadline() fmt.Printf(" Deadline(): %v, %t (是否有截止时间)\n", deadline, ok) fmt.Printf(" Done(): %v (取消通道)\n", bgCtx.Done()) fmt.Printf(" Err(): %v (错误信息)\n", bgCtx.Err()) fmt.Printf(" Value(key): %v (获取值)\n", bgCtx.Value("key")) fmt.Println() } // demonstrateCancellation 演示取消 Context func demonstrateCancellation() { fmt.Println("2. 取消 Context:") // WithCancel 的使用 fmt.Printf(" WithCancel 的使用:\n") fmt.Printf(" ctx, cancel := context.WithCancel(context.Background())\n") fmt.Printf(" defer cancel() // 确保释放资源\n") // 创建可取消的 Context ctx, cancel := context.WithCancel(context.Background()) defer cancel() // 启动一个 goroutine 监听取消信号 fmt.Printf(" 启动监听取消信号的 goroutine:\n") go func() { select { case <-ctx.Done(): fmt.Printf(" goroutine 收到取消信号: %v\n", ctx.Err()) case <-time.After(5 * time.Second): fmt.Printf(" goroutine 超时退出\n") } }() // 模拟一些工作 time.Sleep(100 * time.Millisecond) fmt.Printf(" 执行取消操作...\n") cancel() // 发送取消信号 // 等待 goroutine 处理取消信号 time.Sleep(100 * time.Millisecond) // 检查 Context 状态 fmt.Printf(" Context 状态:\n") fmt.Printf(" Done(): %v\n", ctx.Done() != nil) fmt.Printf(" Err(): %v\n", ctx.Err()) // 演示取消传播 fmt.Printf(" 取消传播:\n") parentCtx, parentCancel := context.WithCancel(context.Background()) childCtx, childCancel := context.WithCancel(parentCtx) defer parentCancel() defer childCancel() // 取消父 Context parentCancel() // 检查子 Context 是否也被取消 select { case <-childCtx.Done(): fmt.Printf(" 子 Context 也被取消了: %v\n", childCtx.Err()) case <-time.After(100 * time.Millisecond): fmt.Printf(" 子 Context 没有被取消\n") } fmt.Println() } // demonstrateTimeout 演示超时 Context func demonstrateTimeout() { fmt.Println("3. 超时 Context:") // WithTimeout 的使用 fmt.Printf(" WithTimeout 的使用:\n") fmt.Printf(" ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)\n") fmt.Printf(" defer cancel()\n") // 创建超时 Context ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel() // 模拟长时间运行的操作 fmt.Printf(" 模拟长时间操作:\n") start := time.Now() select { case <-time.After(500 * time.Millisecond): fmt.Printf(" 操作完成,耗时: %v\n", time.Since(start)) case <-ctx.Done(): fmt.Printf(" 操作被超时取消,耗时: %v,错误: %v\n", time.Since(start), ctx.Err()) } // 演示超时处理函数 fmt.Printf(" 超时处理函数示例:\n") result, err := doWorkWithTimeout(300 * time.Millisecond) if err != nil { fmt.Printf(" 工作超时: %v\n", err) } else { fmt.Printf(" 工作完成: %s\n", result) } result, err = doWorkWithTimeout(100 * time.Millisecond) if err != nil { fmt.Printf(" 工作超时: %v\n", err) } else { fmt.Printf(" 工作完成: %s\n", result) } // 演示 HTTP 请求超时 fmt.Printf(" HTTP 请求超时示例:\n") err = makeHTTPRequestWithTimeout("https://httpbin.org/delay/1", 500*time.Millisecond) if err != nil { fmt.Printf(" HTTP 请求失败: %v\n", err) } else { fmt.Printf(" HTTP 请求成功\n") } fmt.Println() } // demonstrateDeadline 演示截止时间 Context func demonstrateDeadline() { fmt.Println("4. 截止时间 Context:") // WithDeadline 的使用 fmt.Printf(" WithDeadline 的使用:\n") deadline := time.Now().Add(200 * time.Millisecond) fmt.Printf(" deadline := time.Now().Add(200 * time.Millisecond)\n") fmt.Printf(" ctx, cancel := context.WithDeadline(context.Background(), deadline)\n") ctx, cancel := context.WithDeadline(context.Background(), deadline) defer cancel() // 检查截止时间 ctxDeadline, ok := ctx.Deadline() fmt.Printf(" Context 截止时间:\n") fmt.Printf(" 截止时间: %v\n", ctxDeadline) fmt.Printf(" 有截止时间: %t\n", ok) fmt.Printf(" 距离截止时间: %v\n", time.Until(ctxDeadline)) // 等待截止时间到达 fmt.Printf(" 等待截止时间到达:\n") start := time.Now() <-ctx.Done() fmt.Printf(" Context 在 %v 后被取消,错误: %v\n", time.Since(start), ctx.Err()) // 演示截止时间检查 fmt.Printf(" 截止时间检查示例:\n") checkDeadline(context.Background()) deadlineCtx, deadlineCancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond)) defer deadlineCancel() checkDeadline(deadlineCtx) fmt.Println() } // demonstrateWithValue 演示 Context 传递值 func demonstrateWithValue() { fmt.Println("5. Context 传递值:") // WithValue 的使用 fmt.Printf(" WithValue 的使用:\n") fmt.Printf(" ctx := context.WithValue(context.Background(), \\\"userID\\\", 12345)\n") // 创建带值的 Context ctx := context.WithValue(context.Background(), "userID", 12345) ctx = context.WithValue(ctx, "requestID", "req-abc-123") ctx = context.WithValue(ctx, "traceID", "trace-xyz-789") // 获取值 fmt.Printf(" 获取 Context 中的值:\n") if userID := ctx.Value("userID"); userID != nil { fmt.Printf(" 用户ID: %v\n", userID) } if requestID := ctx.Value("requestID"); requestID != nil { fmt.Printf(" 请求ID: %v\n", requestID) } if traceID := ctx.Value("traceID"); traceID != nil { fmt.Printf(" 追踪ID: %v\n", traceID) } // 值不存在的情况 if sessionID := ctx.Value("sessionID"); sessionID != nil { fmt.Printf(" 会话ID: %v\n", sessionID) } else { fmt.Printf(" 会话ID: 不存在\n") } // 演示类型安全的键 fmt.Printf(" 类型安全的键:\n") type contextKey string const ( userIDKey contextKey = "userID" requestIDKey contextKey = "requestID" ) safeCtx := context.WithValue(context.Background(), userIDKey, 67890) safeCtx = context.WithValue(safeCtx, requestIDKey, "req-def-456") // 使用类型安全的方式获取值 if userID := safeCtx.Value(userIDKey); userID != nil { fmt.Printf(" 安全获取用户ID: %v\n", userID) } // 演示值的传播 fmt.Printf(" 值的传播:\n") processRequest(ctx) // 演示值的最佳实践 fmt.Printf(" 值的最佳实践:\n") fmt.Printf(" 1. 只存储请求范围的数据\n") fmt.Printf(" 2. 使用类型安全的键\n") fmt.Printf(" 3. 不要存储可选参数\n") fmt.Printf(" 4. 键应该是不可导出的\n") fmt.Println() } // demonstrateConcurrencyWithContext 演示 Context 在并发中的应用 func demonstrateConcurrencyWithContext() { fmt.Println("6. Context 在并发中的应用:") // 演示工作池模式 fmt.Printf(" 工作池模式:\n") ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() // 创建工作通道 jobs := make(chan int, 10) results := make(chan string, 10) // 启动工作者 var wg sync.WaitGroup for i := 0; i < 3; i++ { wg.Add(1) go worker(ctx, i, jobs, results, &wg) } // 发送工作 go func() { for i := 1; i <= 8; i++ { select { case jobs <- i: fmt.Printf(" 发送工作 %d\n", i) case <-ctx.Done(): fmt.Printf(" 停止发送工作: %v\n", ctx.Err()) close(jobs) return } time.Sleep(100 * time.Millisecond) } close(jobs) }() // 收集结果 go func() { wg.Wait() close(results) }() // 打印结果 fmt.Printf(" 工作结果:\n") for result := range results { fmt.Printf(" %s\n", result) } // 演示扇出扇入模式 fmt.Printf(" 扇出扇入模式:\n") fanOutFanIn() fmt.Println() } // demonstrateHTTPContext 演示 Context 在 HTTP 中的应用 func demonstrateHTTPContext() { fmt.Println("7. Context 在 HTTP 中的应用:") // HTTP 请求中的 Context fmt.Printf(" HTTP 请求中的 Context:\n") fmt.Printf(" - 每个 HTTP 请求都有一个关联的 Context\n") fmt.Printf(" - 可以通过 r.Context() 获取\n") fmt.Printf(" - 请求取消时 Context 也会被取消\n") // 模拟 HTTP 处理器 fmt.Printf(" HTTP 处理器示例:\n") // 创建模拟请求 req, _ := http.NewRequest("GET", "/api/data", nil) // 添加超时 ctx, cancel := context.WithTimeout(req.Context(), 1*time.Second) defer cancel() req = req.WithContext(ctx) // 模拟处理请求 fmt.Printf(" 处理请求...\n") handleRequest(req) // 演示中间件模式 fmt.Printf(" 中间件模式:\n") fmt.Printf(" func middleware(next http.Handler) http.Handler {\n") fmt.Printf(" return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {\n") fmt.Printf(" ctx := context.WithValue(r.Context(), \\\"requestID\\\", generateID())\n") fmt.Printf(" next.ServeHTTP(w, r.WithContext(ctx))\n") fmt.Printf(" })\n") fmt.Printf(" }\n") // 演示数据库查询超时 fmt.Printf(" 数据库查询超时:\n") queryCtx, queryCancel := context.WithTimeout(context.Background(), 500*time.Millisecond) defer queryCancel() err := queryDatabase(queryCtx, "SELECT * FROM users WHERE id = ?", 123) if err != nil { fmt.Printf(" 数据库查询失败: %v\n", err) } else { fmt.Printf(" 数据库查询成功\n") } fmt.Println() } // demonstrateContextBestPractices 演示 Context 的最佳实践 func demonstrateContextBestPractices() { fmt.Println("8. Context 的最佳实践:") // 最佳实践列表 fmt.Printf(" Context 的最佳实践:\n") fmt.Printf(" 1. 不要将 Context 存储在结构体中\n") fmt.Printf(" 2. Context 应该作为函数的第一个参数\n") fmt.Printf(" 3. 不要传递 nil Context,使用 context.TODO()\n") fmt.Printf(" 4. Context.Value 只用于请求范围的数据\n") fmt.Printf(" 5. 使用 defer cancel() 确保资源释放\n") fmt.Printf(" 6. 不要忽略 Context 的取消信号\n") fmt.Printf(" 7. 使用类型安全的键\n") fmt.Printf(" 8. 避免在 Context 中存储可选参数\n") // 正确的函数签名 fmt.Printf(" 正确的函数签名:\n") fmt.Printf(" func DoSomething(ctx context.Context, arg string) error\n") fmt.Printf(" func (s *Service) Process(ctx context.Context, data []byte) (*Result, error)\n") // 错误的用法 fmt.Printf(" 避免的错误用法:\n") fmt.Printf(" // 错误:不要在结构体中存储 Context\n") fmt.Printf(" type Server struct {\n") fmt.Printf(" ctx context.Context // 错误\n") fmt.Printf(" }\n") fmt.Printf(" \n") fmt.Printf(" // 错误:不要传递 nil Context\n") fmt.Printf(" DoSomething(nil, \\\"data\\\") // 错误\n") // 正确的用法 fmt.Printf(" 正确的用法示例:\n") ctx := context.Background() // 正确的取消处理 processWithCancel(ctx) // 正确的超时处理 processWithTimeout(ctx) // 正确的值传递 processWithValue(ctx) // Context 的性能考虑 fmt.Printf(" 性能考虑:\n") fmt.Printf(" 1. Context.Value 的查找是 O(n) 的\n") fmt.Printf(" 2. 避免在热路径中频繁创建 Context\n") fmt.Printf(" 3. 合理使用 Context 的层次结构\n") fmt.Printf(" 4. 及时调用 cancel 函数释放资源\n") // 常见陷阱 fmt.Printf(" 常见陷阱:\n") fmt.Printf(" 1. 忘记调用 cancel 函数导致资源泄漏\n") fmt.Printf(" 2. 在循环中创建过多的 Context\n") fmt.Printf(" 3. 将 Context 用作可选参数的容器\n") fmt.Printf(" 4. 不检查 Context 的取消信号\n") fmt.Println() } // ========== 辅助函数 ========== // doWorkWithTimeout 带超时的工作函数 func doWorkWithTimeout(timeout time.Duration) (string, error) { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() // 模拟工作 select { case <-time.After(200 * time.Millisecond): return "工作完成", nil case <-ctx.Done(): return "", ctx.Err() } } // makeHTTPRequestWithTimeout 带超时的 HTTP 请求 func makeHTTPRequestWithTimeout(url string, timeout time.Duration) error { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return err } // 模拟 HTTP 请求 select { case <-time.After(300 * time.Millisecond): return nil // 请求成功 case <-ctx.Done(): return ctx.Err() } } // checkDeadline 检查截止时间 func checkDeadline(ctx context.Context) { if deadline, ok := ctx.Deadline(); ok { fmt.Printf(" 有截止时间: %v,剩余时间: %v\n", deadline, time.Until(deadline)) } else { fmt.Printf(" 没有截止时间\n") } } // processRequest 处理请求 func processRequest(ctx context.Context) { if userID := ctx.Value("userID"); userID != nil { fmt.Printf(" 处理用户 %v 的请求\n", userID) } if requestID := ctx.Value("requestID"); requestID != nil { fmt.Printf(" 请求ID: %v\n", requestID) } } // worker 工作者函数 func worker(ctx context.Context, id int, jobs <-chan int, results chan<- string, wg *sync.WaitGroup) { defer wg.Done() for { select { case job, ok := <-jobs: if !ok { fmt.Printf(" 工作者 %d 退出(通道关闭)\n", id) return } // 模拟工作 select { case <-time.After(200 * time.Millisecond): results <- fmt.Sprintf("工作者 %d 完成工作 %d", id, job) case <-ctx.Done(): fmt.Printf(" 工作者 %d 被取消: %v\n", id, ctx.Err()) return } case <-ctx.Done(): fmt.Printf(" 工作者 %d 被取消: %v\n", id, ctx.Err()) return } } } // fanOutFanIn 扇出扇入模式 func fanOutFanIn() { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // 输入通道 input := make(chan int, 5) // 启动多个处理器(扇出) c1 := process(ctx, input) c2 := process(ctx, input) c3 := process(ctx, input) // 合并结果(扇入) output := merge(ctx, c1, c2, c3) // 发送数据 go func() { defer close(input) for i := 1; i <= 6; i++ { select { case input <- i: case <-ctx.Done(): return } } }() // 收集结果 for result := range output { fmt.Printf(" 扇出扇入结果: %s\n", result) } } // process 处理函数 func process(ctx context.Context, input <-chan int) <-chan string { output := make(chan string) go func() { defer close(output) for { select { case n, ok := <-input: if !ok { return } select { case output <- fmt.Sprintf("处理 %d", n*n): case <-ctx.Done(): return } case <-ctx.Done(): return } } }() return output } // merge 合并多个通道 func merge(ctx context.Context, channels ...<-chan string) <-chan string { var wg sync.WaitGroup output := make(chan string) // 为每个输入通道启动一个 goroutine multiplex := func(c <-chan string) { defer wg.Done() for { select { case s, ok := <-c: if !ok { return } select { case output <- s: case <-ctx.Done(): return } case <-ctx.Done(): return } } } wg.Add(len(channels)) for _, c := range channels { go multiplex(c) } // 等待所有 goroutine 完成 go func() { wg.Wait() close(output) }() return output } // handleRequest 处理 HTTP 请求 func handleRequest(req *http.Request) { ctx := req.Context() // 模拟处理 select { case <-time.After(500 * time.Millisecond): fmt.Printf(" 请求处理完成\n") case <-ctx.Done(): fmt.Printf(" 请求被取消: %v\n", ctx.Err()) } } // queryDatabase 查询数据库 func queryDatabase(ctx context.Context, query string, args ...interface{}) error { // 模拟数据库查询 select { case <-time.After(300 * time.Millisecond): return nil // 查询成功 case <-ctx.Done(): return ctx.Err() } } // processWithCancel 带取消的处理 func processWithCancel(ctx context.Context) { ctx, cancel := context.WithCancel(ctx) defer cancel() // 模拟处理 go func() { time.Sleep(100 * time.Millisecond) cancel() // 取消操作 }() select { case <-time.After(200 * time.Millisecond): fmt.Printf(" 处理完成\n") case <-ctx.Done(): fmt.Printf(" 处理被取消\n") } } // processWithTimeout 带超时的处理 func processWithTimeout(ctx context.Context) { ctx, cancel := context.WithTimeout(ctx, 50*time.Millisecond) defer cancel() select { case <-time.After(100 * time.Millisecond): fmt.Printf(" 处理完成\n") case <-ctx.Done(): fmt.Printf(" 处理超时\n") } } // processWithValue 带值传递的处理 func processWithValue(ctx context.Context) { ctx = context.WithValue(ctx, "operation", "process") if op := ctx.Value("operation"); op != nil { fmt.Printf(" 执行操作: %v\n", op) } } /* 运行这个程序: go run 03-context.go Context 的核心概念: 1. Context 是 Go 语言中处理请求范围数据的标准方式 2. 提供取消信号、超时控制和值传递功能 3. 在 goroutine 之间传递取消信号和截止时间 4. 避免 goroutine 泄漏和资源浪费 Context 的主要类型: 1. context.Background(): 根 Context,通常用作起点 2. context.TODO(): 当不确定使用哪个 Context 时使用 3. context.WithCancel(): 创建可取消的 Context 4. context.WithTimeout(): 创建带超时的 Context 5. context.WithDeadline(): 创建带截止时间的 Context 6. context.WithValue(): 创建带值的 Context Context 的最佳实践: 1. 不要将 Context 存储在结构体中 2. Context 应该作为函数的第一个参数 3. 不要传递 nil Context,使用 context.TODO() 4. Context.Value 只用于请求范围的数据 5. 使用 defer cancel() 确保资源释放 6. 不要忽略 Context 的取消信号 7. 使用类型安全的键 8. 避免在 Context 中存储可选参数 Context 的应用场景: 1. HTTP 请求处理 2. 数据库查询超时 3. 并发任务协调 4. 微服务调用链 5. 长时间运行的任务控制 注意事项: 1. Context 是并发安全的 2. Context 的取消会传播到所有子 Context 3. Context.Value 的查找是 O(n) 的 4. 及时调用 cancel 函数释放资源 5. 不要在 Context 中存储可选参数 常见错误: 1. 忘记调用 cancel 函数 2. 将 Context 存储在结构体中 3. 传递 nil Context 4. 不检查 Context 的取消信号 5. 在 Context 中存储过多数据 */