148 lines
3.8 KiB
Go
148 lines
3.8 KiB
Go
/*
|
||
middleware.go - 中间件
|
||
实现了各种HTTP中间件功能
|
||
*/
|
||
|
||
package server
|
||
|
||
import (
|
||
"log"
|
||
"net/http"
|
||
"runtime/debug"
|
||
"time"
|
||
)
|
||
|
||
// LoggingMiddleware 日志记录中间件
|
||
func LoggingMiddleware(next http.Handler) http.Handler {
|
||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
start := time.Now()
|
||
|
||
// 创建响应记录器来捕获状态码
|
||
recorder := &responseRecorder{
|
||
ResponseWriter: w,
|
||
statusCode: http.StatusOK,
|
||
}
|
||
|
||
// 调用下一个处理器
|
||
next.ServeHTTP(recorder, r)
|
||
|
||
// 记录请求日志
|
||
duration := time.Since(start)
|
||
log.Printf("[%s] %s %s %d %v",
|
||
r.Method,
|
||
r.RequestURI,
|
||
r.RemoteAddr,
|
||
recorder.statusCode,
|
||
duration,
|
||
)
|
||
})
|
||
}
|
||
|
||
// responseRecorder 响应记录器
|
||
type responseRecorder struct {
|
||
http.ResponseWriter
|
||
statusCode int
|
||
}
|
||
|
||
// WriteHeader 记录状态码
|
||
func (rr *responseRecorder) WriteHeader(code int) {
|
||
rr.statusCode = code
|
||
rr.ResponseWriter.WriteHeader(code)
|
||
}
|
||
|
||
// CORSMiddleware CORS中间件
|
||
func CORSMiddleware(next http.Handler) http.Handler {
|
||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
// 设置CORS头
|
||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||
|
||
// 处理预检请求
|
||
if r.Method == "OPTIONS" {
|
||
w.WriteHeader(http.StatusOK)
|
||
return
|
||
}
|
||
|
||
// 调用下一个处理器
|
||
next.ServeHTTP(w, r)
|
||
})
|
||
}
|
||
|
||
// RecoveryMiddleware 恢复中间件(处理panic)
|
||
func RecoveryMiddleware(next http.Handler) http.Handler {
|
||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
defer func() {
|
||
if err := recover(); err != nil {
|
||
// 记录panic信息
|
||
log.Printf("❌ Panic recovered: %v\n%s", err, debug.Stack())
|
||
|
||
// 返回500错误
|
||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||
}
|
||
}()
|
||
|
||
// 调用下一个处理器
|
||
next.ServeHTTP(w, r)
|
||
})
|
||
}
|
||
|
||
// AuthMiddleware 认证中间件(示例)
|
||
func AuthMiddleware(next http.Handler) http.Handler {
|
||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
// 检查Authorization头
|
||
token := r.Header.Get("Authorization")
|
||
if token == "" {
|
||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||
return
|
||
}
|
||
|
||
// 这里可以添加实际的token验证逻辑
|
||
// 为了演示,我们简单检查token是否为"Bearer valid-token"
|
||
if token != "Bearer valid-token" {
|
||
http.Error(w, "Invalid token", http.StatusUnauthorized)
|
||
return
|
||
}
|
||
|
||
// 调用下一个处理器
|
||
next.ServeHTTP(w, r)
|
||
})
|
||
}
|
||
|
||
// RateLimitMiddleware 限流中间件(简化版)
|
||
func RateLimitMiddleware(next http.Handler) http.Handler {
|
||
// 这里可以实现基于IP的限流逻辑
|
||
// 为了简化,我们只是一个示例框架
|
||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
// 实际实现中,这里会检查请求频率
|
||
// 如果超过限制,返回429状态码
|
||
|
||
// 调用下一个处理器
|
||
next.ServeHTTP(w, r)
|
||
})
|
||
}
|
||
|
||
// ContentTypeMiddleware 内容类型中间件
|
||
func ContentTypeMiddleware(contentType string) func(http.Handler) http.Handler {
|
||
return func(next http.Handler) http.Handler {
|
||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
w.Header().Set("Content-Type", contentType)
|
||
next.ServeHTTP(w, r)
|
||
})
|
||
}
|
||
}
|
||
|
||
// SecurityMiddleware 安全头中间件
|
||
func SecurityMiddleware(next http.Handler) http.Handler {
|
||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
// 设置安全相关的HTTP头
|
||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||
w.Header().Set("X-Frame-Options", "DENY")
|
||
w.Header().Set("X-XSS-Protection", "1; mode=block")
|
||
w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
||
|
||
// 调用下一个处理器
|
||
next.ServeHTTP(w, r)
|
||
})
|
||
}
|