package api import ( "context" "encoding/json" "errors" "net/http" "sync" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "go.uber.org/zap" "monica-go-completion-backend/internal/completion" ) var wsUpgrader = websocket.Upgrader{ ReadBufferSize: 4096, WriteBufferSize: 4096, CheckOrigin: func(_ *http.Request) bool { return true }, } // wsCompletionRequest 是普通 WS 消息的补全请求格式。 type wsCompletionRequest struct { ID string `json:"id"` Language string `json:"language,omitempty"` SessionID string `json:"sessionId,omitempty"` URI string `json:"uri"` Text string `json:"text"` Line int `json:"line"` Character int `json:"character"` } // wsCompletionResponse 是普通 WS 消息的补全响应格式。 type wsCompletionResponse struct { ID string `json:"id"` Items []completion.Item `json:"items,omitempty"` IsIncomplete bool `json:"isIncomplete,omitempty"` RouteTo string `json:"routeTo,omitempty"` OwnerID string `json:"ownerId,omitempty"` Error string `json:"error,omitempty"` } // wsRPCRequest/wsRPCResponse 用于兼容 JSON-RPC 2.0 客户端。 type wsRPCRequest struct { JSONRPC string `json:"jsonrpc"` ID json.RawMessage `json:"id"` Method string `json:"method"` Params json.RawMessage `json:"params"` } type wsRPCResponse struct { JSONRPC string `json:"jsonrpc"` ID json.RawMessage `json:"id"` Result any `json:"result,omitempty"` Error any `json:"error,omitempty"` } // registerWSRoutes 注册 WebSocket 补全入口(含可选语言路由)。 func registerWSRoutes(router *gin.Engine, service CompletionService, opts RouteOptions) { handler := func(c *gin.Context, defaultLanguage string) { conn, err := wsUpgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { zap.L().Warn( "ws upgrade failed", zap.String("traceId", traceIDFromGin(c)), zap.String("path", c.Request.URL.Path), zap.String("clientIP", c.ClientIP()), zap.Error(err), ) return } defer conn.Close() traceID := traceIDFromGin(c) zap.L().Info( "ws connection opened", zap.String("traceId", traceID), zap.String("path", c.Request.URL.Path), zap.String("language", defaultLanguage), zap.String("clientIP", c.ClientIP()), ) defer func() { zap.L().Info( "ws connection closed", zap.String("traceId", traceID), zap.String("path", c.Request.URL.Path), zap.String("language", defaultLanguage), zap.String("clientIP", c.ClientIP()), ) }() conn.SetReadLimit(opts.MaxBodyBytes) var writeMu sync.Mutex for { // 单连接串行读取消息,写操作通过 writeMu 保证并发安全。 _, payload, err := conn.ReadMessage() if err != nil { zap.L().Info( "ws read loop ended", zap.String("traceId", traceID), zap.String("path", c.Request.URL.Path), zap.Error(err), ) break } handleWSMessage(conn, &writeMu, service, payload, defaultLanguage, traceID, c.ClientIP(), opts) } } router.GET("/ws/completions", func(c *gin.Context) { handler(c, "") }) router.GET("/ws/completions/:language", func(c *gin.Context) { handler(c, c.Param("language")) }) } // handleWSMessage 先尝试按 JSON-RPC 处理;失败后回退到普通 JSON 协议。 func handleWSMessage( conn *websocket.Conn, writeMu *sync.Mutex, service CompletionService, payload []byte, defaultLanguage string, traceID string, clientIP string, opts RouteOptions, ) { if tryHandleRPCMessage(conn, writeMu, service, payload, defaultLanguage, traceID, clientIP, opts) { return } var req wsCompletionRequest if err := json.Unmarshal(payload, &req); err != nil { zap.L().Warn( "ws invalid json payload", zap.String("traceId", traceID), zap.String("clientIP", clientIP), zap.Error(err), ) sendWSResponse(conn, writeMu, wsCompletionResponse{ ID: "", Error: "invalid JSON payload", }) return } if req.Language == "" { req.Language = defaultLanguage } processWSCompletion(conn, writeMu, service, req, traceID, clientIP, opts) } // tryHandleRPCMessage 处理 JSON-RPC 2.0 请求,返回 true 表示消息已消费。 func tryHandleRPCMessage( conn *websocket.Conn, writeMu *sync.Mutex, service CompletionService, payload []byte, defaultLanguage string, traceID string, clientIP string, opts RouteOptions, ) bool { var rpcReq wsRPCRequest if err := json.Unmarshal(payload, &rpcReq); err != nil { return false } if rpcReq.JSONRPC != "2.0" || rpcReq.Method == "" { return false } if rpcReq.Method != "completion/complete" && rpcReq.Method != "completion.complete" { // 非补全方法按 JSON-RPC 规范返回 method not found。 sendWSRPCResponse(conn, writeMu, wsRPCResponse{ JSONRPC: "2.0", ID: rpcReq.ID, Error: map[string]any{ "code": -32601, "message": "method not found", }, }) return true } var req wsCompletionRequest if err := json.Unmarshal(rpcReq.Params, &req); err != nil { zap.L().Warn( "ws rpc invalid params", zap.String("traceId", traceID), zap.String("clientIP", clientIP), zap.String("method", rpcReq.Method), zap.Error(err), ) // 参数反序列化失败按 invalid params 处理。 sendWSRPCResponse(conn, writeMu, wsRPCResponse{ JSONRPC: "2.0", ID: rpcReq.ID, Error: map[string]any{ "code": -32602, "message": "invalid params", }, }) return true } if req.Language == "" { req.Language = defaultLanguage } if req.ID == "" { // 兼容未在 params 提供业务 ID 的客户端。 req.ID = string(rpcReq.ID) } ctx, cancel := context.WithTimeout(context.Background(), opts.RequestTimeout) defer cancel() resp, err := service.Complete(ctx, completion.Request{ Language: req.Language, SessionID: req.SessionID, URI: req.URI, Text: req.Text, Line: req.Line, Character: req.Character, }) if err != nil { zap.L().Warn( "ws rpc completion failed", zap.String("traceId", traceID), zap.String("clientIP", clientIP), zap.String("method", rpcReq.Method), zap.String("language", req.Language), zap.String("sessionId", req.SessionID), zap.String("uri", req.URI), zap.Error(err), ) sendWSRPCResponse(conn, writeMu, wsRPCResponse{ JSONRPC: "2.0", ID: rpcReq.ID, Error: map[string]any{ "code": -32000, "message": err.Error(), }, }) return true } sendWSRPCResponse(conn, writeMu, wsRPCResponse{ JSONRPC: "2.0", ID: rpcReq.ID, Result: resp, }) return true } // processWSCompletion 处理普通 WS 协议下的补全请求。 func processWSCompletion( conn *websocket.Conn, writeMu *sync.Mutex, service CompletionService, req wsCompletionRequest, traceID string, clientIP string, opts RouteOptions, ) { ctx, cancel := context.WithTimeout(context.Background(), opts.RequestTimeout) defer cancel() resp, err := service.Complete(ctx, completion.Request{ Language: req.Language, SessionID: req.SessionID, URI: req.URI, Text: req.Text, Line: req.Line, Character: req.Character, }) if err != nil { msg := "completion failed" routeTo := "" ownerID := "" switch { case errors.Is(err, completion.ErrInvalidRequest): msg = err.Error() case errors.Is(err, completion.ErrUnsupportedLanguage): msg = err.Error() case errors.Is(err, completion.ErrTooManySessions): msg = err.Error() default: var ownedErr *completion.ErrSessionOwnedByOtherInstance if errors.As(err, &ownedErr) { // 会话在其他实例上时返回路由提示,客户端可重连对应节点。 msg = err.Error() routeTo = ownedErr.OwnerEndpoint ownerID = ownedErr.OwnerID } } zap.L().Warn( "ws completion failed", zap.String("traceId", traceID), zap.String("clientIP", clientIP), zap.String("language", req.Language), zap.String("sessionId", req.SessionID), zap.String("uri", req.URI), zap.String("routeTo", routeTo), zap.String("ownerId", ownerID), zap.Error(err), ) sendWSResponse(conn, writeMu, wsCompletionResponse{ ID: req.ID, RouteTo: routeTo, OwnerID: ownerID, Error: msg, }) return } zap.L().Info( "ws completion success", zap.String("traceId", traceID), zap.String("clientIP", clientIP), zap.String("language", req.Language), zap.String("sessionId", req.SessionID), zap.String("uri", req.URI), zap.Int("items", len(resp.Items)), ) sendWSResponse(conn, writeMu, wsCompletionResponse{ ID: req.ID, Items: resp.Items, IsIncomplete: resp.IsIncomplete, }) } // sendWSResponse 统一串行写回普通 WS 响应。 func sendWSResponse(conn *websocket.Conn, writeMu *sync.Mutex, resp wsCompletionResponse) { writeMu.Lock() defer writeMu.Unlock() _ = conn.WriteJSON(resp) } // sendWSRPCResponse 统一串行写回 JSON-RPC 响应。 func sendWSRPCResponse(conn *websocket.Conn, writeMu *sync.Mutex, resp wsRPCResponse) { writeMu.Lock() defer writeMu.Unlock() _ = conn.WriteJSON(resp) }