Security Guide - zhoudm1743/go-util GitHub Wiki
This guide covers using Go-Util securely, including input validation, data protection, authentication and authorization, and secure configuration.
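```go
// Layered security validation example
func SecureUserRegistration(input UserRegistrationInput) (*User, error) {
	// Layer 1: input validation
	if err := validateUserInput(input); err != nil {
		return nil, fmt.Errorf("input validation failed: %w", err)
	}
	// Layer 2: business rule validation
	if err := validateBusinessRules(input); err != nil {
		return nil, fmt.Errorf("business rule validation failed: %w", err)
	}
	// Layer 3: security policy checks
	if err := checkSecurityPolicy(input); err != nil {
		return nil, fmt.Errorf("security policy check failed: %w", err)
	}
	// Layer 4: sanitize and escape the data
	cleanedInput := sanitizeUserInput(input)
	return createUser(cleanedInput)
}

func validateUserInput(input UserRegistrationInput) error {
	// Validate the input with Go-Util
	if util.Str(input.Email).IsBlank() || !util.Str(input.Email).IsEmail() {
		return errors.New("invalid email format")
	}
	if util.Str(input.Password).Len() < 8 {
		return errors.New("password must be at least 8 characters long")
	}
	// Password strength check. Go's regexp engine (RE2) does not support
	// lookaheads such as (?=.*[a-z]), so each character class is checked
	// with its own pattern.
	pwd := util.Str(input.Password)
	if !pwd.MatchRegex(`[a-z]`) ||
		!pwd.MatchRegex(`[A-Z]`) ||
		!pwd.MatchRegex(`\d`) ||
		!pwd.MatchRegex(`[@$!%*?&]`) {
		return errors.New("password must contain uppercase and lowercase letters, digits, and special characters")
	}
	return nil
}
```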
```go
// Role-based access control (RBAC)
type SecurityContext struct {
	UserID      string
	Roles       []string
	Permissions []string // entries of the form "resource:action"
	SessionID   string
	IPAddress   string
	ExpiresAt   *util.XTime
}

func (sc *SecurityContext) HasPermission(resource string, action string) bool {
	requiredPermission := fmt.Sprintf("%s:%s", resource, action)
	return util.ArraysFromSlice(sc.Permissions).Contains(requiredPermission)
}

func (sc *SecurityContext) HasRole(role string) bool {
	return util.ArraysFromSlice(sc.Roles).Contains(role)
}

// IsExpired treats a context without an expiry time as expired (fail closed).
func (sc *SecurityContext) IsExpired() bool {
	return sc.ExpiresAt == nil || util.Now().After(sc.ExpiresAt)
}

// Security middleware (decorator pattern). It requires the permission
// "resource:METHOD", where METHOD is the HTTP method of the request.
func RequirePermission(resource string) func(http.HandlerFunc) http.HandlerFunc {
	return func(next http.HandlerFunc) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			ctx := getSecurityContext(r)
			if ctx == nil || ctx.IsExpired() {
				http.Error(w, "unauthorized", http.StatusUnauthorized)
				return
			}
			if !ctx.HasPermission(resource, r.Method) {
				http.Error(w, "forbidden", http.StatusForbidden)
				return
			}
			next(w, r)
		}
	}
}
```
```go
// Secure string processing toolkit
type SecureStringProcessor struct {
	maxLength      int
	allowedChars   string // whitelist regex that cleaned input must match
	forbiddenWords []string
}

func NewSecureStringProcessor() *SecureStringProcessor {
	return &SecureStringProcessor{
		maxLength:      1000,
		allowedChars:   `^[a-zA-Z0-9\s\-_.@]+$`,
		forbiddenWords: []string{"script", "javascript", "vbscript", "onload", "onerror"},
	}
}

func (ssp *SecureStringProcessor) ProcessUserInput(input string) (string, error) {
	// Step 1: length check
	if util.Str(input).Len() > ssp.maxLength {
		return "", fmt.Errorf("input exceeds the length limit (%d)", ssp.maxLength)
	}
	// Step 2: empty-value check
	if util.Str(input).IsBlank() {
		return "", errors.New("input must not be empty")
	}
	// Step 3: basic cleanup
	cleaned := util.Str(input).
		Trim().                           // strip leading and trailing whitespace
		ReplaceRegex(`\s+`, " ").         // collapse runs of whitespace
		ReplaceRegex(`[^\x20-\x7E]`, ""). // drop everything outside printable ASCII (including CJK text)
		String()
	// Step 4: malicious content detection
	if err := ssp.detectMaliciousContent(cleaned); err != nil {
		return "", err
	}
	// Step 5: character whitelist
	if !util.Str(cleaned).MatchRegex(ssp.allowedChars) {
		return "", errors.New("input contains illegal characters")
	}
	// Step 6: HTML escaping (defense in depth; the whitelist already excludes <, >, & and quotes)
	escaped := util.Str(cleaned).HTMLEscape().String()
	return escaped, nil
}
```
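The same pipeline can be reproduced with the standard library alone. The sketch below covers steps 1, 2, 3, 5 and 6 and omits the malicious-content scan. `processUserInput` is a hypothetical stand-in. It measures length in bytes, which is an assumption, because this guide does not say whether Go-Util's `Len()` counts bytes or runes.

```go
package main

import (
	"errors"
	"fmt"
	"html"
	"regexp"
	"strings"
)

var (
	whitespaceRun = regexp.MustCompile(`\s+`)
	nonPrintable  = regexp.MustCompile(`[^\x20-\x7E]`)
	allowedChars  = regexp.MustCompile(`^[a-zA-Z0-9\s\-_.@]+$`)
)

// processUserInput is a hypothetical standard-library version of
// ProcessUserInput without the malicious-content scan.
func processUserInput(input string, maxLength int) (string, error) {
	if len(input) > maxLength { // byte length
		return "", fmt.Errorf("input exceeds the length limit (%d)", maxLength)
	}
	if strings.TrimSpace(input) == "" {
		return "", errors.New("input must not be empty")
	}
	cleaned := strings.TrimSpace(input)
	cleaned = whitespaceRun.ReplaceAllString(cleaned, " ")
	cleaned = nonPrintable.ReplaceAllString(cleaned, "")
	if !allowedChars.MatchString(cleaned) {
		return "", errors.New("input contains illegal characters")
	}
	// Defense in depth: the whitelist already excludes <, >, & and quotes
	return html.EscapeString(cleaned), nil
}

func main() {
	out, err := processUserInput("  alice   smith@example.com ", 1000)
	fmt.Printf("%q %v\n", out, err) // "alice smith@example.com" <nil>

	_, err = processUserInput("<b>hi</b>", 1000)
	fmt.Println(err) // input contains illegal characters
}
```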
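```go
func (ssp *SecureStringProcessor) detectMaliciousContent(input string) error {
	lowerInput := util.Str(input).Lower().String()
	// Forbidden words
	for _, word := range ssp.forbiddenWords {
		if util.Str(lowerInput).Contains(word) {
			return fmt.Errorf("input contains forbidden content: %s", word)
		}
	}
	// XSS detection (the patterns are lowercase because the input is lowercased)
	xssPatterns := []string{
		`<script.*?>.*?</script>`,
		`javascript:`,
		`on\w+\s*=`,
		`<iframe.*?>`,
		`<object.*?>`,
		`<embed.*?>`,
	}
	for _, pattern := range xssPatterns {
		if util.Str(lowerInput).MatchRegex(pattern) {
			return errors.New("potential XSS attack detected")
		}
	}
	// SQL injection detection. The keywords are uppercase but the input is
	// lowercased, so the (?i) flag is required for them to match at all.
	// These rules are a heuristic that also rejects ordinary text such as
	// "select a plan" or "O'Brien"; parameterized queries are the real defense.
	sqlPatterns := []string{
		`(?i)\b(SELECT|INSERT|UPDATE|DELETE|DROP|CREATE|ALTER)\b`,
		`(?i)\b(UNION|OR|AND)\b.*\b(SELECT|INSERT|UPDATE|DELETE)\b`,
		`('|"|;|--|\*|/\*)`,
	}
	for _, pattern := range sqlPatterns {
		if util.Str(lowerInput).MatchRegex(pattern) {
			return errors.New("potential SQL injection detected")
		}
	}
	return nil
}
```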
```go
// Secure file upload handling
type SecureFileUploader struct {
	allowedTypes []string
	maxFileSize  int64
	allowedExts  []string
	uploadPath   string
	virusScanner VirusScanner
}

func NewSecureFileUploader() *SecureFileUploader {
	return &SecureFileUploader{
		allowedTypes: []string{"image/jpeg", "image/png", "image/gif", "application/pdf"},
		maxFileSize:  10 * 1024 * 1024, // 10 MB
		allowedExts:  []string{".jpg", ".jpeg", ".png", ".gif", ".pdf"},
		uploadPath:   "/secure/uploads",
	}
}

func (sfu *SecureFileUploader) ProcessUpload(fileHeader *multipart.FileHeader) (*UploadResult, error) {
	// Step 1: file size check
	if fileHeader.Size > sfu.maxFileSize {
		return nil, fmt.Errorf("file size exceeds the limit (%d bytes)", sfu.maxFileSize)
	}
	// Step 2: extension check. filepath.Ext returns "" for names without a dot
	// instead of the whole file name.
	ext := strings.ToLower(filepath.Ext(fileHeader.Filename))
	if !util.ArraysFromSlice(sfu.allowedExts).Contains(ext) {
		return nil, fmt.Errorf("unsupported file type: %q", ext)
	}
	// Step 3: MIME type check based on the file content
	file, err := fileHeader.Open()
	if err != nil {
		return nil, fmt.Errorf("failed to open file: %w", err)
	}
	defer file.Close()
	buffer := make([]byte, 512)
	n, err := file.Read(buffer)
	if err != nil && err != io.EOF {
		return nil, fmt.Errorf("failed to read file: %w", err)
	}
	// Sniff only the bytes actually read; the file may be shorter than 512 bytes
	mimeType := http.DetectContentType(buffer[:n])
	if !util.ArraysFromSlice(sfu.allowedTypes).Contains(mimeType) {
		return nil, fmt.Errorf("invalid MIME type: %s", mimeType)
	}
	// Step 4: sanitize the file name
	safeFilename := sfu.sanitizeFilename(fileHeader.Filename)
	// Step 5: virus scan. Rewind first, because sniffing consumed up to 512 bytes.
	if sfu.virusScanner != nil {
		if _, err := file.Seek(0, io.SeekStart); err != nil {
			return nil, fmt.Errorf("failed to rewind file: %w", err)
		}
		if err := sfu.virusScanner.Scan(file); err != nil {
			return nil, fmt.Errorf("virus scan failed: %w", err)
		}
	}
	// Step 6: save the file, again starting from the first byte
	if _, err := file.Seek(0, io.SeekStart); err != nil {
		return nil, fmt.Errorf("failed to rewind file: %w", err)
	}
	finalPath := filepath.Join(sfu.uploadPath, safeFilename)
	if err := sfu.saveFile(file, finalPath); err != nil {
		return nil, fmt.Errorf("failed to save file: %w", err)
	}
	return &UploadResult{
		Filename:   safeFilename,
		Path:       finalPath,
		Size:       fileHeader.Size,
		MimeType:   mimeType,
		UploadedAt: util.Now(),
	}, nil
}
```
```go
func (sfu *SecureFileUploader) sanitizeFilename(filename string) string {
	// Remove path separators and other dangerous characters
	safe := util.Str(filename).
		ReplaceRegex(`[/\\:*?"<>|]`, "").
		ReplaceRegex(`\.\.+`, ".").
		Trim().
		String()
	// Append a timestamp to avoid name collisions. The precision is one second,
	// so two uploads with the same name in the same second can still collide.
	timestamp := util.Now().FormatCustom("20060102_150405")
	ext := filepath.Ext(safe) // "" when there is no extension
	name := strings.TrimSuffix(safe, ext)
	return fmt.Sprintf("%s_%s%s", name, timestamp, ext)
}
```
```go
// Secure password policy
type PasswordSecurity struct {
	minLength        int
	requireUppercase bool
	requireLowercase bool
	requireDigits    bool
	requireSymbols   bool
	maxAge           time.Duration
	history          int
}

func NewPasswordSecurity() *PasswordSecurity {
	return &PasswordSecurity{
		minLength:        12,
		requireUppercase: true,
		requireLowercase: true,
		requireDigits:    true,
		requireSymbols:   true,
		maxAge:           90 * 24 * time.Hour, // 90 days
		history:          5,                   // remember the last 5 passwords
	}
}

func (ps *PasswordSecurity) ValidatePassword(password string) error {
	pwdStr := util.Str(password)
	// Length check
	if pwdStr.Len() < ps.minLength {
		return fmt.Errorf("password must be at least %d characters long", ps.minLength)
	}
	// Complexity check: one pattern per required character class
	var patterns []string
	var requirements []string
	if ps.requireLowercase {
		patterns = append(patterns, `[a-z]`)
		requirements = append(requirements, "a lowercase letter")
	}
	if ps.requireUppercase {
		patterns = append(patterns, `[A-Z]`)
		requirements = append(requirements, "an uppercase letter")
	}
	if ps.requireDigits {
		patterns = append(patterns, `\d`)
		requirements = append(requirements, "a digit")
	}
	if ps.requireSymbols {
		patterns = append(patterns, `[!@#$%^&*(),.?":{}|<>]`)
		requirements = append(requirements, "a special character")
	}
	for i, pattern := range patterns {
		if !pwdStr.MatchRegex(pattern) {
			return fmt.Errorf("password must contain %s", requirements[i])
		}
	}
	// Common-password check
	if err := ps.checkCommonPasswords(password); err != nil {
		return err
	}
	// Repeated-character check
	if ps.hasRepeatingCharacters(password) {
		return errors.New("password must not contain the same character three or more times in a row")
	}
	return nil
}

func (ps *PasswordSecurity) checkCommonPasswords(password string) error {
	// A small sample list; production systems should check against a large
	// list of breached passwords. Entries must be lowercase because the
	// comparison is case-insensitive.
	commonPasswords := []string{
		"password", "123456", "password123", "admin", "qwerty",
		"12345678", "123456789", "password1", "abc123",
	}
	// A substring match also rejects passwords such as "MyPassword2024!"
	lowerPwd := util.Str(password).Lower().String()
	for _, common := range commonPasswords {
		if util.Str(lowerPwd).Contains(common) {
			return errors.New("password is too simple; please choose a more complex one")
		}
	}
	return nil
}

// hasRepeatingCharacters reports whether the password contains the same
// character three or more times in a row (e.g. "aaa", "111"). Go's regexp
// (RE2) does not support backreferences such as (.)\1, so a loop is used.
func (ps *PasswordSecurity) hasRepeatingCharacters(password string) bool {
	var prev rune
	run := 0
	for i, r := range password {
		if i > 0 && r == prev {
			run++
		} else {
			run = 1
		}
		if run >= 3 {
			return true
		}
		prev = r
	}
	return false
}
```
```go
// Secure password hashing with bcrypt (golang.org/x/crypto/bcrypt).
// A single salted SHA-256 round is far too fast for password storage;
// bcrypt is deliberately slow and generates and embeds its own salt.
func (ps *PasswordSecurity) HashPassword(password string) (string, error) {
	// bcrypt only uses the first 72 bytes of input; recent versions of
	// x/crypto return bcrypt.ErrPasswordTooLong for longer passwords.
	hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return "", fmt.Errorf("failed to hash password: %w", err)
	}
	// The result already encodes the algorithm version, cost, and salt
	return string(hashed), nil
}

func (ps *PasswordSecurity) VerifyPassword(password, storedHash string) bool {
	// CompareHashAndPassword compares the hashes in constant time
	return bcrypt.CompareHashAndPassword([]byte(storedHash), []byte(password)) == nil
}
```
```go
// Secure session manager
type SecureSessionManager struct {
	sessions *util.SafeMap[string, *SecureSession]
	config   *SessionConfig
	cleaner  *SessionCleaner
}

type SecureSession struct {
	ID            string
	UserID        string
	CreatedAt     *util.XTime
	LastAccess    *util.XTime
	IDIssuedAt    *util.XTime // when the current session ID was issued
	ExpiresAt     *util.XTime
	IPAddress     string
	UserAgent     string
	Data          *util.SafeMap[string, interface{}]
	IsActive      bool
	LoginAttempts int
}

type SessionConfig struct {
	MaxAge          time.Duration // absolute session lifetime
	IdleTimeout     time.Duration // maximum time between requests
	MaxSessions     int           // maximum concurrent sessions per user
	SecureCookies   bool
	HttpOnlyCookies bool
	SameSiteStrict  bool
	RegenerateID    bool // rotate the session ID periodically
}

func NewSecureSessionManager(config *SessionConfig) *SecureSessionManager {
	ssm := &SecureSessionManager{
		sessions: util.NewSafeMap[string, *SecureSession](),
		config:   config,
		cleaner:  NewSessionCleaner(),
	}
	// Start the background cleanup goroutine
	go ssm.startCleanupRoutine()
	return ssm
}

func (ssm *SecureSessionManager) CreateSession(userID, ipAddress, userAgent string) (*SecureSession, error) {
	// Enforce the per-user session limit by evicting the oldest session
	userSessions := ssm.getUserSessions(userID)
	if len(userSessions) >= ssm.config.MaxSessions {
		if oldestSession := ssm.findOldestSession(userSessions); oldestSession != nil {
			ssm.DestroySession(oldestSession.ID)
		}
	}
	sessionID, err := ssm.generateSecureSessionID()
	if err != nil {
		return nil, err
	}
	session := &SecureSession{
		ID:         sessionID,
		UserID:     userID,
		CreatedAt:  util.Now(),
		LastAccess: util.Now(),
		IDIssuedAt: util.Now(),
		ExpiresAt:  util.Now().Add(ssm.config.MaxAge),
		IPAddress:  ipAddress,
		UserAgent:  userAgent,
		Data:       util.NewSafeMap[string, interface{}](),
		IsActive:   true,
	}
	ssm.sessions.Set(session.ID, session)
	return session, nil
}

func (ssm *SecureSessionManager) ValidateSession(sessionID, ipAddress, userAgent string) (*SecureSession, error) {
	session, exists := ssm.sessions.SafeGet(sessionID)
	if !exists {
		return nil, errors.New("session not found")
	}
	// Absolute expiry
	if util.Now().After(session.ExpiresAt) {
		ssm.DestroySession(sessionID)
		return nil, errors.New("session expired")
	}
	// Idle timeout
	if util.Now().DiffTime(session.LastAccess) > ssm.config.IdleTimeout {
		ssm.DestroySession(sessionID)
		return nil, errors.New("session idle timeout")
	}
	// IP address binding. Strict binding logs out users whose IP changes
	// legitimately (mobile networks, proxies); make it configurable if needed.
	if session.IPAddress != ipAddress {
		ssm.DestroySession(sessionID)
		return nil, errors.New("IP address mismatch; session terminated")
	}
	// User-Agent binding: a cheap hijacking signal, though the header is client-controlled
	if session.UserAgent != userAgent {
		ssm.DestroySession(sessionID)
		return nil, errors.New("User-Agent mismatch; possible session hijacking")
	}
	// Update the last access time. Session fields are mutated without a lock;
	// guard them with a mutex if a session is shared across goroutines.
	session.LastAccess = util.Now()
	// Optional: rotate the session ID once per hour. The caller must send the
	// new session.ID back to the client (for example in a new cookie).
	if ssm.config.RegenerateID && util.Now().DiffTime(session.IDIssuedAt) > time.Hour {
		newSessionID, err := ssm.generateSecureSessionID()
		if err != nil {
			return nil, err
		}
		ssm.sessions.Delete(sessionID)
		session.ID = newSessionID
		session.IDIssuedAt = util.Now()
		ssm.sessions.Set(newSessionID, session)
	}
	return session, nil
}

func (ssm *SecureSessionManager) generateSecureSessionID() (string, error) {
	// 32 bytes (256 bits) from the OS CSPRNG (crypto/rand), hex-encoded.
	// A timestamp adds no entropy, and hashing cannot strengthen a weak source.
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return "", fmt.Errorf("failed to generate session ID: %w", err)
	}
	return hex.EncodeToString(b), nil
}
```
```go
func (ssm *SecureSessionManager) getUserSessions(userID string) []*SecureSession {
	var userSessions []*SecureSession
	ssm.sessions.ForEach(func(id string, session *SecureSession) bool {
		if session.UserID == userID && session.IsActive {
			userSessions = append(userSessions, session)
		}
		return true // keep iterating
	})
	return userSessions
}

func (ssm *SecureSessionManager) startCleanupRoutine() {
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()
	// This loop never exits; add a stop channel or context if the manager
	// can be shut down.
	for range ticker.C {
		ssm.cleanupExpiredSessions()
	}
}

func (ssm *SecureSessionManager) cleanupExpiredSessions() {
	now := util.Now()
	// Collect expired IDs first, then destroy them, so the map is not
	// modified while it is being iterated
	expiredSessions := util.NewArray[string]()
	ssm.sessions.ForEach(func(id string, session *SecureSession) bool {
		if now.After(session.ExpiresAt) ||
			now.DiffTime(session.LastAccess) > ssm.config.IdleTimeout {
			expiredSessions.Append(id)
		}
		return true
	})
	expiredSessions.ForEach(func(i int, sessionID string) bool {
		ssm.DestroySession(sessionID)
		return true
	})
}
```
```go
// Security event detector
type SecurityEventDetector struct {
	rules        []SecurityRule
	alertManager AlertManager
	logger       SecurityLogger
	metrics      SecurityMetrics
}

type SecurityRule interface {
	// Evaluate reports whether the rule fires for the event, plus a message
	Evaluate(event SecurityEvent) (bool, string)
	Severity() SecuritySeverity
	Name() string
}

type SecurityEvent struct {
	ID        string
	Timestamp *util.XTime
	EventType string
	UserID    string
	IPAddress string
	UserAgent string
	Resource  string
	Action    string
	Status    string
	Details   map[string]interface{}
	Risk      RiskLevel
}

// Login anomaly detection rule
type LoginAnomalyRule struct {
	maxAttempts    int
	timeWindow     time.Duration // window over which the caller counts Details["recent_attempts"]
	geoRestriction bool
}

func (lar *LoginAnomalyRule) Evaluate(event SecurityEvent) (bool, string) {
	if event.EventType != "login_attempt" {
		return false, ""
	}
	// Too many login attempts within the time window
	if attempts, ok := event.Details["recent_attempts"].(int); ok {
		if attempts > lar.maxAttempts {
			return true, fmt.Sprintf("user attempted to log in %d times within %v", attempts, lar.timeWindow)
		}
	}
	// Logins from unexpected regions
	if lar.geoRestriction {
		if country, ok := event.Details["country"].(string); ok {
			allowedCountries := []string{"CN", "US", "GB"}
			if !util.ArraysFromSlice(allowedCountries).Contains(country) {
				return true, fmt.Sprintf("login attempt from a non-allowed region: %s", country)
			}
		}
	}
	return false, ""
}

// Data access anomaly detection rule
type DataAccessAnomalyRule struct {
	normalHours   [2]int // working hours [start, end]; hours start through end are allowed
	maxDataVolume int64  // maximum data volume per access, in bytes
}

func (daar *DataAccessAnomalyRule) Evaluate(event SecurityEvent) (bool, string) {
	if event.EventType != "data_access" {
		return false, ""
	}
	// Access outside working hours
	hour := event.Timestamp.Hour()
	if hour < daar.normalHours[0] || hour > daar.normalHours[1] {
		return true, fmt.Sprintf("data access outside working hours: %02d:00", hour)
	}
	// Unusually large data volume
	if volume, ok := event.Details["data_volume"].(int64); ok {
		if volume > daar.maxDataVolume {
			return true, fmt.Sprintf("abnormal data access volume: %d bytes", volume)
		}
	}
	return false, ""
}

func (sed *SecurityEventDetector) ProcessEvent(event SecurityEvent) {
	// Evaluate every security rule
	for _, rule := range sed.rules {
		if triggered, message := rule.Evaluate(event); triggered {
			alert := SecurityAlert{
				ID:        generateAlertID(),
				Timestamp: util.Now(),
				RuleName:  rule.Name(),
				Severity:  rule.Severity(),
				Event:     event,
				Message:   message,
			}
			// Send the alert
			sed.alertManager.SendAlert(alert)
			// Write a security log entry
			sed.logger.LogSecurityEvent(alert)
			// Update metrics
			sed.metrics.RecordSecurityAlert(rule.Name(), rule.Severity())
			// Respond according to severity
			sed.takeAction(alert)
		}
	}
}

func (sed *SecurityEventDetector) takeAction(alert SecurityAlert) {
	switch alert.Severity {
	case SeverityCritical:
		// Critical threat: block immediately
		sed.blockUser(alert.Event.UserID, "security threat detected")
		sed.blockIP(alert.Event.IPAddress, "malicious behavior detected")
	case SeverityHigh:
		// High risk: increase monitoring
		sed.increaseMonitoring(alert.Event.UserID)
	case SeverityMedium:
		// Medium risk: log and notify the security team
		sed.notifySecurityTeam(alert)
	case SeverityLow:
		// Low risk: already logged in ProcessEvent; no further action
	}
}
```
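```go
// Data masking processor
type DataMaskingProcessor struct {
	rules  map[string]MaskingRule
	config *MaskingConfig
}

type MaskingRule interface {
	Mask(value string) string
	ShouldMask(fieldName string, context string) bool
}

// Email masking rule: keeps the first two characters of the local part
type EmailMaskingRule struct{}

func (emr *EmailMaskingRule) Mask(email string) string {
	if !util.Str(email).IsEmail() {
		return "***"
	}
	parts := util.Str(email).Split("@")
	if len(parts) != 2 {
		return "***"
	}
	username := parts[0]
	domain := parts[1]
	if util.Str(username).Len() <= 2 {
		return fmt.Sprintf("***@%s", domain)
	}
	maskedUsername := util.Str(username).
		Left(2).
		Concat("***").
		String()
	return fmt.Sprintf("%s@%s", maskedUsername, domain)
}

func (emr *EmailMaskingRule) ShouldMask(fieldName string, context string) bool {
	emailFields := []string{"email", "mail", "email_address"}
	fieldLower := util.Str(fieldName).Lower().String()
	return util.ArraysFromSlice(emailFields).Contains(fieldLower)
}

// Phone number masking rule
type PhoneMaskingRule struct{}

func (pmr *PhoneMaskingRule) Mask(phone string) string {
	cleaned := util.Str(phone).ReplaceRegex(`[^\d]`, "").String()
	length := util.Str(cleaned).Len()
	if length < 7 {
		return "***"
	}
	if length == 11 { // mainland China mobile number, e.g. 138****5678
		return util.Str(cleaned).
			Left(3).
			Concat("****").
			Concat(util.Str(cleaned).Right(4).String()).
			String()
	}
	// Other formats: keep the first two and last two digits
	return util.Str(cleaned).
		Left(2).
		Concat("***").
		Concat(util.Str(cleaned).Right(2).String()).
		String()
}

// ShouldMask is required to satisfy MaskingRule. The field names are illustrative.
func (pmr *PhoneMaskingRule) ShouldMask(fieldName string, context string) bool {
	phoneFields := []string{"phone", "mobile", "phone_number", "tel"}
	fieldLower := util.Str(fieldName).Lower().String()
	return util.ArraysFromSlice(phoneFields).Contains(fieldLower)
}

// ID card number masking rule (mainland China resident ID, 15 or 18 characters).
// Keeps the 6-digit region code and the last 4 characters.
type IDCardMaskingRule struct{}

func (idmr *IDCardMaskingRule) Mask(idCard string) string {
	cleaned := util.Str(idCard).ReplaceRegex(`[^\dxX]`, "").String()
	if util.Str(cleaned).Len() != 18 && util.Str(cleaned).Len() != 15 {
		return "***"
	}
	return util.Str(cleaned).
		Left(6).
		Concat("********").
		Concat(util.Str(cleaned).Right(4).String()).
		String()
}

// ShouldMask is required to satisfy MaskingRule. The field names are illustrative.
func (idmr *IDCardMaskingRule) ShouldMask(fieldName string, context string) bool {
	idFields := []string{"id_card", "idcard", "id_number"}
	fieldLower := util.Str(fieldName).Lower().String()
	return util.ArraysFromSlice(idFields).Contains(fieldLower)
}
```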
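The masking rules rely on Go-Util's `Left` and `Right` helpers, and this guide does not say whether those count bytes or runes. For reference, the standard-library sketch below implements the same phone formats and the email rule. It slices the email's local part by rune, so a non-ASCII name is never cut in the middle of a character. `maskPhone` and `maskEmail` are hypothetical names.

```go
package main

import (
	"fmt"
	"strings"
)

func maskPhone(phone string) string {
	// Keep digits only (returning -1 from the mapping drops the rune)
	digits := strings.Map(func(r rune) rune {
		if r >= '0' && r <= '9' {
			return r
		}
		return -1
	}, phone)
	switch {
	case len(digits) < 7:
		return "***"
	case len(digits) == 11: // mainland China mobile number
		return digits[:3] + "****" + digits[7:]
	default:
		return digits[:2] + "***" + digits[len(digits)-2:]
	}
}

func maskEmail(email string) string {
	at := strings.LastIndex(email, "@")
	if at <= 0 || at == len(email)-1 {
		return "***"
	}
	user, domain := []rune(email[:at]), email[at+1:]
	if len(user) <= 2 {
		return "***@" + domain
	}
	return string(user[:2]) + "***@" + domain
}

func main() {
	fmt.Println(maskPhone("138-1234-5678"))        // 138****5678
	fmt.Println(maskPhone("+44 20 7946 0958"))     // 44***58
	fmt.Println(maskEmail("zhangsan@example.com")) // zh***@example.com
	fmt.Println(maskEmail("张三丰@example.com"))      // 张三***@example.com
}
```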
```go
// Batch data masking. Each string value is passed to a rule whose ShouldMask
// accepts the value's field name. Rules should not overlap, because Go map
// iteration order is random.
func (dmp *DataMaskingProcessor) MaskData(data map[string]interface{}, context string) map[string]interface{} {
	masked := make(map[string]interface{}, len(data))
	for field, value := range data {
		if str, ok := value.(string); ok {
			masked[field] = dmp.maskStringValue(field, str, context)
		} else {
			masked[field] = value
		}
	}
	return masked
}

// maskStringValue checks the rules against the actual field name,
// not against the keys of the rules map.
func (dmp *DataMaskingProcessor) maskStringValue(fieldName, value, context string) string {
	for _, rule := range dmp.rules {
		if rule.ShouldMask(fieldName, context) {
			return rule.Mask(value)
		}
	}
	return value
}

// Sensitive data patterns, compiled once with the standard library regexp
// package. \b stops the phone pattern from matching inside longer digit runs
// such as ID card numbers.
var sensitivePatterns = []struct {
	dataType string
	re       *regexp.Regexp
}{
	{"email", regexp.MustCompile(`[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}`)},
	{"phone", regexp.MustCompile(`\b1[3-9]\d{9}\b`)},    // mainland China mobile number
	{"id_card", regexp.MustCompile(`\b\d{17}[\dxX]\b`)}, // 18-character resident ID
}

// DetectSensitiveData returns every match with its byte offsets.
// FindAllStringIndex reports each occurrence separately, whereas looking a
// match up again with Index always returns its first occurrence.
func (dmp *DataMaskingProcessor) DetectSensitiveData(text string) []SensitiveDataMatch {
	var matches []SensitiveDataMatch
	for _, p := range sensitivePatterns {
		for _, loc := range p.re.FindAllStringIndex(text, -1) {
			matches = append(matches, SensitiveDataMatch{
				Type:  p.dataType,
				Value: text[loc[0]:loc[1]],
				Start: loc[0],
				End:   loc[1],
			})
		}
	}
	return matches
}
```
```go
// API security middleware
type APISecurityMiddleware struct {
	rateLimiter   *RateLimiter
	validator     *RequestValidator
	authenticator *Authenticator
	authorizer    *Authorizer
	logger        SecurityLogger
}

func (asm *APISecurityMiddleware) SecureAPIHandler(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		startTime := util.Now()
		// 1. Add security headers first, so that error responses carry them too
		asm.addSecurityHeaders(w)
		// 2. Rate limiting
		if !asm.rateLimiter.AllowRequest(r) {
			asm.logger.LogSecurityEvent("rate_limit_exceeded", r, nil)
			http.Error(w, "too many requests", http.StatusTooManyRequests)
			return
		}
		// 3. Request validation
		if err := asm.validator.ValidateRequest(r); err != nil {
			asm.logger.LogSecurityEvent("invalid_request", r, err)
			http.Error(w, "request validation failed", http.StatusBadRequest)
			return
		}
		// 4. Authentication
		user, err := asm.authenticator.Authenticate(r)
		if err != nil {
			asm.logger.LogSecurityEvent("authentication_failed", r, err)
			http.Error(w, "authentication failed", http.StatusUnauthorized)
			return
		}
		// 5. Authorization
		if !asm.authorizer.Authorize(user, r) {
			asm.logger.LogSecurityEvent("authorization_failed", r, user)
			http.Error(w, "forbidden", http.StatusForbidden)
			return
		}
		// 6. Run the business logic
		next(w, r)
		// 7. Record the access log
		duration := util.Now().DiffTime(startTime)
		asm.logger.LogAPIAccess(user, r, duration)
	}
}

func (asm *APISecurityMiddleware) addSecurityHeaders(w http.ResponseWriter) {
	h := w.Header()
	// Legacy XSS-auditor header: modern browsers ignore it, and OWASP
	// recommends disabling it ("0") and relying on CSP instead
	h.Set("X-XSS-Protection", "0")
	// Prevent MIME type sniffing
	h.Set("X-Content-Type-Options", "nosniff")
	// Forbid embedding the page in frames (clickjacking protection)
	h.Set("X-Frame-Options", "DENY")
	// Enforce HTTPS (browsers honor this header only on HTTPS responses)
	h.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
	// Content Security Policy. 'unsafe-inline' in script-src would re-enable
	// inline-script XSS, so it is omitted.
	h.Set("Content-Security-Policy", "default-src 'self'; script-src 'self'")
	// Referrer policy
	h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
	// Permissions policy: disable geolocation, microphone, and camera
	h.Set("Permissions-Policy", "geolocation=(), microphone=(), camera=()")
}
```
```go
// Request validator
type RequestValidator struct {
	maxBodySize     int64
	allowedMethods  []string
	requiredHeaders []string
}

func (rv *RequestValidator) ValidateRequest(r *http.Request) error {
	// Check the HTTP method
	if !util.ArraysFromSlice(rv.allowedMethods).Contains(r.Method) {
		return fmt.Errorf("unsupported HTTP method: %s", r.Method)
	}
	// Check the declared body size. ContentLength is -1 when the length is
	// unknown (e.g. chunked uploads), so also cap the body itself with
	// http.MaxBytesReader in the handler.
	if r.ContentLength > rv.maxBodySize {
		return fmt.Errorf("request body too large: %d bytes", r.ContentLength)
	}
	// Check required headers
	for _, header := range rv.requiredHeaders {
		if r.Header.Get(header) == "" {
			return fmt.Errorf("missing required header: %s", header)
		}
	}
	// Check Content-Type on requests that carry a body
	if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodPatch {
		contentType := r.Header.Get("Content-Type")
		if contentType == "" {
			return errors.New("missing Content-Type header")
		}
		// The outer group anchors both alternatives; without it,
		// "multipart/form-data" would match anywhere in the header.
		if !util.Str(contentType).MatchRegex(`^(application/(json|x-www-form-urlencoded)|multipart/form-data)(\s*;.*)?$`) {
			return fmt.Errorf("unsupported Content-Type: %s", contentType)
		}
	}
	// Check the request URL
	if err := rv.validateURL(r.URL); err != nil {
		return fmt.Errorf("URL validation failed: %w", err)
	}
	return nil
}

func (rv *RequestValidator) validateURL(u *url.URL) error {
	// Path traversal check
	if util.Str(u.Path).Contains("..") {
		return errors.New("path traversal attack detected")
	}
	// Check every query parameter
	for key, values := range u.Query() {
		for _, value := range values {
			if err := rv.validateQueryParameter(key, value); err != nil {
				return err
			}
		}
	}
	return nil
}

func (rv *RequestValidator) validateQueryParameter(key, value string) error {
	// Parameter name
	if !util.Str(key).MatchRegex(`^[a-zA-Z0-9_-]+$`) {
		return fmt.Errorf("invalid query parameter name: %s", key)
	}
	// Parameter value length
	if util.Str(value).Len() > 1000 {
		return fmt.Errorf("query parameter value too long: %s", key)
	}
	// Heuristic injection check. (?i) is required because the value is
	// lowercased while the keywords are uppercase. The quote/semicolon rule also
	// rejects legitimate values such as "O'Brien"; parameterized queries remain
	// the real defense.
	sqlInjectionPatterns := []string{
		`(?i)\b(SELECT|INSERT|UPDATE|DELETE|DROP|CREATE|ALTER)\b`,
		`(?i)\b(UNION|OR|AND)\b.*\b(SELECT|INSERT|UPDATE|DELETE)\b`,
		`('|"|;|--|\*|/\*)`,
	}
	lowerValue := util.Str(value).Lower().String()
	for _, pattern := range sqlInjectionPatterns {
		if util.Str(lowerValue).MatchRegex(pattern) {
			return fmt.Errorf("potential SQL injection detected in parameter %s", key)
		}
	}
	return nil
}
```
```go
// Security metrics collector
type SecurityMetricsCollector struct {
	metrics *util.SafeMap[string, *SecurityMetric]
	alerts  []SecurityAlert // not synchronized: guard with a mutex under concurrent use
	reports []SecurityReport
	config  *SecurityMetricsConfig
}

type SecurityMetric struct {
	Name        string
	Value       float64
	Timestamp   *util.XTime
	Tags        map[string]string
	Threshold   float64
	IsAnomalous bool
}

func (smc *SecurityMetricsCollector) RecordSecurityEvent(eventType string, severity string, details map[string]interface{}) {
	metricName := fmt.Sprintf("security_events_%s", eventType)
	// Increment the event counter
	smc.incrementMetric(metricName, map[string]string{
		"severity": severity,
		"type":     eventType,
	})
	// Raise an alert when the counter crosses its threshold
	if metric, exists := smc.metrics.SafeGet(metricName); exists {
		if metric.Value > metric.Threshold {
			smc.generateAlert(metricName, metric)
		}
	}
}

func (smc *SecurityMetricsCollector) RecordLoginMetrics(userID string, success bool, ipAddress string) {
	tags := map[string]string{
		"user_id":    userID,
		"ip_address": ipAddress,
		"success":    fmt.Sprintf("%t", success),
	}
	if success {
		smc.incrementMetric("successful_logins", tags)
	} else {
		smc.incrementMetric("failed_logins", tags)
		// Check the failed-login rate
		smc.checkFailedLoginThreshold(userID, ipAddress)
	}
}

func (smc *SecurityMetricsCollector) checkFailedLoginThreshold(userID, ipAddress string) {
	// Count failed logins in the last 5 minutes
	window := 5 * time.Minute
	failedCount := smc.getFailedLoginCount(userID, ipAddress, window)
	if failedCount > 5 {
		alert := SecurityAlert{
			ID:        generateAlertID(),
			Timestamp: util.Now(),
			Type:      "brute_force_attack",
			Severity:  SeverityHigh,
			Message:   fmt.Sprintf("user %s from %s failed to log in %d times within 5 minutes", userID, ipAddress, failedCount),
			Details: map[string]interface{}{
				"user_id":      userID,
				"ip_address":   ipAddress,
				"failed_count": failedCount,
				"time_window":  window.String(),
			},
		}
		smc.alerts = append(smc.alerts, alert)
	}
}
```
```go
// Generate a security report
func (smc *SecurityMetricsCollector) GenerateSecurityReport(period string) *SecurityReport {
	endTime := util.Now()
	var startTime *util.XTime
	switch period {
	case "daily":
		startTime = endTime.SubDays(1)
	case "weekly":
		startTime = endTime.SubDays(7)
	case "monthly":
		startTime = endTime.SubDays(30)
	default:
		startTime = endTime.SubDays(1)
	}
	report := &SecurityReport{
		ID:          generateReportID(),
		Period:      period,
		StartTime:   startTime,
		EndTime:     endTime,
		GeneratedAt: util.Now(),
	}
	// Collect the security metrics for the period
	report.Summary = smc.generateSecuritySummary(startTime, endTime)
	report.ThreatAnalysis = smc.analyzeThreatPatterns(startTime, endTime)
	report.Recommendations = smc.generateRecommendations(report.Summary)
	return report
}

func (smc *SecurityMetricsCollector) generateSecuritySummary(startTime, endTime *util.XTime) SecuritySummary {
	// Walk all metrics recorded within the time range
	var totalEvents int
	var criticalEvents int
	var blockedIPs []string // never populated in this example; fill it from your IP block list
	var topAttackTypes []string
	smc.metrics.ForEach(func(name string, metric *SecurityMetric) bool {
		if metric.Timestamp.Between(startTime, endTime) {
			totalEvents++
			if severity, ok := metric.Tags["severity"]; ok && severity == "critical" {
				criticalEvents++
			}
			if attackType, ok := metric.Tags["type"]; ok {
				topAttackTypes = append(topAttackTypes, attackType)
			}
		}
		return true
	})
	// Count the most frequent attack types
	attackFrequency := util.NewMap[string, int]()
	for _, attackType := range topAttackTypes {
		current := attackFrequency.GetOr(attackType, 0)
		attackFrequency.Set(attackType, current+1)
	}
	topAttacks := attackFrequency.
		ToPairs().
		SortWith(func(p1, p2 util.Pair[string, int]) bool {
			return p1.Value > p2.Value
		}).
		Take(5).
		Map(func(p util.Pair[string, int]) string {
			return p.Key
		}).
		ToSlice()
	return SecuritySummary{
		TotalSecurityEvents: totalEvents,
		CriticalEvents:      criticalEvents,
		BlockedIPs:          len(blockedIPs),
		TopAttackTypes:      topAttacks,
		SecurityScore:       smc.calculateSecurityScore(totalEvents, criticalEvents),
	}
}
```
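If you run into problems while implementing security:
- 🔍 See the FAQ - answers to common security questions
- 🐛 Report an issue - bug reports
- 💡 Feature suggestions - discussion of new features
- 📧 Email support - [email protected]
- 🔒 Security hotline - professional security consulting
🔐 Security is an ongoing process, and Go-Util gives you a solid foundation to build on!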