package api import ( "context" "encoding/json" "errors" "net/http" "sync" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "monica-go-completion-backend/internal/completion" ) var wsUpgrader = websocket.Upgrader{ ReadBufferSize: 4096, WriteBufferSize: 4096, CheckOrigin: func(_ *http.Request) bool { return true }, } type wsCompletionRequest struct { ID string `json:"id"` Language string `json:"language,omitempty"` SessionID string `json:"sessionId,omitempty"` URI string `json:"uri"` Text string `json:"text"` Line int `json:"line"` Character int `json:"character"` } type wsCompletionResponse struct { ID string `json:"id"` Items []completion.Item `json:"items,omitempty"` IsIncomplete bool `json:"isIncomplete,omitempty"` RouteTo string `json:"routeTo,omitempty"` OwnerID string `json:"ownerId,omitempty"` Error string `json:"error,omitempty"` } type wsRPCRequest struct { JSONRPC string `json:"jsonrpc"` ID json.RawMessage `json:"id"` Method string `json:"method"` Params json.RawMessage `json:"params"` } type wsRPCResponse struct { JSONRPC string `json:"jsonrpc"` ID json.RawMessage `json:"id"` Result any `json:"result,omitempty"` Error any `json:"error,omitempty"` } func registerWSRoutes(router *gin.Engine, service CompletionService, opts RouteOptions) { handler := func(c *gin.Context, defaultLanguage string) { conn, err := wsUpgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { return } defer conn.Close() conn.SetReadLimit(opts.MaxBodyBytes) var writeMu sync.Mutex for { _, payload, err := conn.ReadMessage() if err != nil { break } handleWSMessage(conn, &writeMu, service, payload, defaultLanguage, opts) } } router.GET("/ws/completions", func(c *gin.Context) { handler(c, "") }) router.GET("/ws/completions/:language", func(c *gin.Context) { handler(c, c.Param("language")) }) } func handleWSMessage( conn *websocket.Conn, writeMu *sync.Mutex, service CompletionService, payload []byte, defaultLanguage string, opts RouteOptions, ) { if tryHandleRPCMessage(conn, writeMu, service, payload, defaultLanguage, opts) { return } var req wsCompletionRequest if err := json.Unmarshal(payload, &req); err != nil { sendWSResponse(conn, writeMu, wsCompletionResponse{ ID: "", Error: "invalid JSON payload", }) return } if req.Language == "" { req.Language = defaultLanguage } processWSCompletion(conn, writeMu, service, req, opts) } func tryHandleRPCMessage( conn *websocket.Conn, writeMu *sync.Mutex, service CompletionService, payload []byte, defaultLanguage string, opts RouteOptions, ) bool { var rpcReq wsRPCRequest if err := json.Unmarshal(payload, &rpcReq); err != nil { return false } if rpcReq.JSONRPC != "2.0" || rpcReq.Method == "" { return false } if rpcReq.Method != "completion/complete" && rpcReq.Method != "completion.complete" { sendWSRPCResponse(conn, writeMu, wsRPCResponse{ JSONRPC: "2.0", ID: rpcReq.ID, Error: map[string]any{ "code": -32601, "message": "method not found", }, }) return true } var req wsCompletionRequest if err := json.Unmarshal(rpcReq.Params, &req); err != nil { sendWSRPCResponse(conn, writeMu, wsRPCResponse{ JSONRPC: "2.0", ID: rpcReq.ID, Error: map[string]any{ "code": -32602, "message": "invalid params", }, }) return true } if req.Language == "" { req.Language = defaultLanguage } if req.ID == "" { req.ID = string(rpcReq.ID) } ctx, cancel := context.WithTimeout(context.Background(), opts.RequestTimeout) defer cancel() resp, err := service.Complete(ctx, completion.Request{ Language: req.Language, SessionID: req.SessionID, URI: req.URI, Text: req.Text, Line: req.Line, Character: req.Character, }) if err != nil { sendWSRPCResponse(conn, writeMu, wsRPCResponse{ JSONRPC: "2.0", ID: rpcReq.ID, Error: map[string]any{ "code": -32000, "message": err.Error(), }, }) return true } sendWSRPCResponse(conn, writeMu, wsRPCResponse{ JSONRPC: "2.0", ID: rpcReq.ID, Result: resp, }) return true } func processWSCompletion( conn *websocket.Conn, writeMu *sync.Mutex, service CompletionService, req wsCompletionRequest, opts RouteOptions, ) { ctx, cancel := context.WithTimeout(context.Background(), opts.RequestTimeout) defer cancel() resp, err := service.Complete(ctx, completion.Request{ Language: req.Language, SessionID: req.SessionID, URI: req.URI, Text: req.Text, Line: req.Line, Character: req.Character, }) if err != nil { msg := "completion failed" routeTo := "" ownerID := "" switch { case errors.Is(err, completion.ErrInvalidRequest): msg = err.Error() case errors.Is(err, completion.ErrUnsupportedLanguage): msg = err.Error() case errors.Is(err, completion.ErrTooManySessions): msg = err.Error() default: var ownedErr *completion.ErrSessionOwnedByOtherInstance if errors.As(err, &ownedErr) { msg = err.Error() routeTo = ownedErr.OwnerEndpoint ownerID = ownedErr.OwnerID } } sendWSResponse(conn, writeMu, wsCompletionResponse{ ID: req.ID, RouteTo: routeTo, OwnerID: ownerID, Error: msg, }) return } sendWSResponse(conn, writeMu, wsCompletionResponse{ ID: req.ID, Items: resp.Items, IsIncomplete: resp.IsIncomplete, }) } func sendWSResponse(conn *websocket.Conn, writeMu *sync.Mutex, resp wsCompletionResponse) { writeMu.Lock() defer writeMu.Unlock() _ = conn.WriteJSON(resp) } func sendWSRPCResponse(conn *websocket.Conn, writeMu *sync.Mutex, resp wsRPCResponse) { writeMu.Lock() defer writeMu.Unlock() _ = conn.WriteJSON(resp) }