package ws

import (
	"encoding/json"
	"errors"
	"log"
	"time"

	"github.com/gorilla/websocket"
)

const (
	// writeWait is the deadline applied to every write on the connection.
	writeWait = 10 * time.Second

	// pongWait is how long a read may wait before the connection is
	// considered dead; it is re-armed each time a pong frame arrives.
	pongWait = 60 * time.Second

	// pingPeriod is the interval at which WritePump sends ping frames.
	// It is 9/10 of pongWait so a ping always goes out before the read
	// deadline expires.
	pingPeriod = (pongWait * 9) / 10

	// maxMessageSize is the read limit, in bytes, set on the connection.
	maxMessageSize = 65536
)

// ErrSendChannelFull is returned by SendMessage when the client's outbound
// buffer is full and the message was dropped instead of being queued.
var ErrSendChannelFull = errors.New("ws: send channel full")

// Client bundles one WebSocket connection with its hub, its outbound queue
// and the identifiers of the user/session it belongs to.
type Client struct {
	Hub       *Hub
	Conn      *websocket.Conn
	Send      chan []byte // outbound messages, drained by WritePump
	UserID    string
	SessionID string
}

// NewClient builds a Client for conn with a 256-message outbound buffer.
func NewClient(hub *Hub, conn *websocket.Conn, userID, sessionID string) *Client {
	return &Client{
		Hub:       hub,
		Conn:      conn,
		Send:      make(chan []byte, 256),
		UserID:    userID,
		SessionID: sessionID,
	}
}

// ReadPump reads messages from the connection until a read fails, then
// hands the client to Hub.unregister and closes the connection.
//
// Each message is decoded as a ClientMessage. Messages that fail to decode
// are logged and skipped. A message of type "ping" is answered with a
// "pong" ServerMessage and is not forwarded. Every other message is passed
// to onMessage, which may be nil.
func (c *Client) ReadPump(onMessage func(client *Client, msg ClientMessage)) {
	defer func() {
		c.Hub.unregister <- c
		c.Conn.Close()
	}()

	c.Conn.SetReadLimit(maxMessageSize)
	c.Conn.SetReadDeadline(time.Now().Add(pongWait))
	// Every pong frame pushes the read deadline forward.
	c.Conn.SetPongHandler(func(string) error {
		c.Conn.SetReadDeadline(time.Now().Add(pongWait))
		return nil
	})

	for {
		_, rawMessage, err := c.Conn.ReadMessage()
		if err != nil {
			// The listed close codes are treated as expected and not logged.
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure, websocket.CloseAbnormalClosure) {
				log.Printf("[WS] 读取错误: user=%s err=%v", c.UserID, err)
			}
			return
		}

		var msg ClientMessage
		if err := json.Unmarshal(rawMessage, &msg); err != nil {
			log.Printf("[WS] 消息解析失败: user=%s err=%v", c.UserID, err)
			continue
		}

		// Application-level ping: reply with a pong.
		if msg.Type == "ping" {
			pong := ServerMessage{
				Type:      "pong",
				Timestamp: time.Now().UnixMilli(),
			}
			// Use the non-blocking SendMessage: a blocking send here would
			// stall the read loop forever if WritePump has already exited
			// and the buffer is full. SendMessage logs the buffer-full case
			// itself, so only other failures (marshalling) are logged here.
			if err := c.SendMessage(pong); err != nil && !errors.Is(err, ErrSendChannelFull) {
				log.Printf("[WS] pong发送失败: user=%s err=%v", c.UserID, err)
			}
			continue
		}

		if onMessage != nil {
			onMessage(c, msg)
		}
	}
}

// WritePump forwards messages from c.Send to the connection as text frames
// and sends a ping frame every pingPeriod. It is the only place in this
// file that writes to Conn. It returns, closing the connection, when Send
// is closed or any write fails.
func (c *Client) WritePump() {
	ticker := time.NewTicker(pingPeriod)
	defer func() {
		ticker.Stop()
		c.Conn.Close()
	}()

	for {
		select {
		case message, ok := <-c.Send:
			c.Conn.SetWriteDeadline(time.Now().Add(writeWait))
			if !ok {
				// Send was closed (by the Hub, per the original author's
				// note). Tell the peer we are closing; the write error is
				// ignored because the connection is torn down either way.
				c.Conn.WriteMessage(websocket.CloseMessage, []byte{})
				return
			}
			if err := c.Conn.WriteMessage(websocket.TextMessage, message); err != nil {
				log.Printf("[WS] 写入错误: user=%s err=%v", c.UserID, err)
				return
			}

		case <-ticker.C:
			c.Conn.SetWriteDeadline(time.Now().Add(writeWait))
			if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil {
				return
			}
		}
	}
}

// SendMessage JSON-encodes msg and queues it on c.Send without blocking.
//
// It returns the json.Marshal error if encoding fails, and
// ErrSendChannelFull if the buffer is full; in the latter case the message
// is dropped and a warning is logged.
//
// NOTE(review): a send on a closed channel panics. Confirm that the Hub
// only closes Send once no goroutine can still call SendMessage.
func (c *Client) SendMessage(msg ServerMessage) error {
	data, err := json.Marshal(msg)
	if err != nil {
		return err
	}

	select {
	case c.Send <- data:
		return nil
	default:
		// Buffer full: log and report the drop so the caller can react.
		log.Printf("[WS] 发送通道已满,丢弃消息: type=%s user=%s", msg.Type, c.UserID)
		return ErrSendChannelFull
	}
}