feat: enhance completion service with session management and language support
- Introduced session management using Redis for tracking active sessions. - Added session claiming and releasing functionality in the completion manager. - Enhanced HTTP and WebSocket completion endpoints to support multiple languages. - Implemented request timeout and maximum body size configurations for API routes. - Updated client-side code to handle session IDs and language parameters in completion requests. - Improved error handling for unsupported languages and session conflicts. - Added tests for the completion manager to ensure proper session handling and cleanup.
This commit is contained in:
@@ -6,7 +6,6 @@ import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
@@ -24,6 +23,8 @@ var wsUpgrader = websocket.Upgrader{
|
||||
|
||||
type wsCompletionRequest struct {
|
||||
ID string `json:"id"`
|
||||
Language string `json:"language,omitempty"`
|
||||
SessionID string `json:"sessionId,omitempty"`
|
||||
URI string `json:"uri"`
|
||||
Text string `json:"text"`
|
||||
Line int `json:"line"`
|
||||
@@ -34,17 +35,35 @@ type wsCompletionResponse struct {
|
||||
ID string `json:"id"`
|
||||
Items []completion.Item `json:"items,omitempty"`
|
||||
IsIncomplete bool `json:"isIncomplete,omitempty"`
|
||||
RouteTo string `json:"routeTo,omitempty"`
|
||||
OwnerID string `json:"ownerId,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func registerWSRoutes(router *gin.Engine, service CompletionService) {
|
||||
router.GET("/ws/completions/go", func(c *gin.Context) {
|
||||
type wsRPCRequest struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID json.RawMessage `json:"id"`
|
||||
Method string `json:"method"`
|
||||
Params json.RawMessage `json:"params"`
|
||||
}
|
||||
|
||||
type wsRPCResponse struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID json.RawMessage `json:"id"`
|
||||
Result any `json:"result,omitempty"`
|
||||
Error any `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func registerWSRoutes(router *gin.Engine, service CompletionService, opts RouteOptions) {
|
||||
handler := func(c *gin.Context, defaultLanguage string) {
|
||||
conn, err := wsUpgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
conn.SetReadLimit(opts.MaxBodyBytes)
|
||||
|
||||
var writeMu sync.Mutex
|
||||
|
||||
for {
|
||||
@@ -52,45 +71,173 @@ func registerWSRoutes(router *gin.Engine, service CompletionService) {
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
var req wsCompletionRequest
|
||||
if err := json.Unmarshal(payload, &req); err != nil {
|
||||
sendWSResponse(conn, &writeMu, wsCompletionResponse{
|
||||
ID: "",
|
||||
Error: "invalid JSON payload",
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
go func(r wsCompletionRequest) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := service.Complete(ctx, completion.Request{
|
||||
URI: r.URI,
|
||||
Text: r.Text,
|
||||
Line: r.Line,
|
||||
Character: r.Character,
|
||||
})
|
||||
if err != nil {
|
||||
msg := "completion failed"
|
||||
if errors.Is(err, completion.ErrInvalidRequest) {
|
||||
msg = err.Error()
|
||||
}
|
||||
sendWSResponse(conn, &writeMu, wsCompletionResponse{
|
||||
ID: r.ID,
|
||||
Error: msg,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
sendWSResponse(conn, &writeMu, wsCompletionResponse{
|
||||
ID: r.ID,
|
||||
Items: resp.Items,
|
||||
IsIncomplete: resp.IsIncomplete,
|
||||
})
|
||||
}(req)
|
||||
handleWSMessage(conn, &writeMu, service, payload, defaultLanguage, opts)
|
||||
}
|
||||
}
|
||||
|
||||
router.GET("/ws/completions", func(c *gin.Context) {
|
||||
handler(c, "")
|
||||
})
|
||||
router.GET("/ws/completions/:language", func(c *gin.Context) {
|
||||
handler(c, c.Param("language"))
|
||||
})
|
||||
}
|
||||
|
||||
func handleWSMessage(
|
||||
conn *websocket.Conn,
|
||||
writeMu *sync.Mutex,
|
||||
service CompletionService,
|
||||
payload []byte,
|
||||
defaultLanguage string,
|
||||
opts RouteOptions,
|
||||
) {
|
||||
if tryHandleRPCMessage(conn, writeMu, service, payload, defaultLanguage, opts) {
|
||||
return
|
||||
}
|
||||
|
||||
var req wsCompletionRequest
|
||||
if err := json.Unmarshal(payload, &req); err != nil {
|
||||
sendWSResponse(conn, writeMu, wsCompletionResponse{
|
||||
ID: "",
|
||||
Error: "invalid JSON payload",
|
||||
})
|
||||
return
|
||||
}
|
||||
if req.Language == "" {
|
||||
req.Language = defaultLanguage
|
||||
}
|
||||
|
||||
processWSCompletion(conn, writeMu, service, req, opts)
|
||||
}
|
||||
|
||||
func tryHandleRPCMessage(
|
||||
conn *websocket.Conn,
|
||||
writeMu *sync.Mutex,
|
||||
service CompletionService,
|
||||
payload []byte,
|
||||
defaultLanguage string,
|
||||
opts RouteOptions,
|
||||
) bool {
|
||||
var rpcReq wsRPCRequest
|
||||
if err := json.Unmarshal(payload, &rpcReq); err != nil {
|
||||
return false
|
||||
}
|
||||
if rpcReq.JSONRPC != "2.0" || rpcReq.Method == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
if rpcReq.Method != "completion/complete" && rpcReq.Method != "completion.complete" {
|
||||
sendWSRPCResponse(conn, writeMu, wsRPCResponse{
|
||||
JSONRPC: "2.0",
|
||||
ID: rpcReq.ID,
|
||||
Error: map[string]any{
|
||||
"code": -32601,
|
||||
"message": "method not found",
|
||||
},
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
var req wsCompletionRequest
|
||||
if err := json.Unmarshal(rpcReq.Params, &req); err != nil {
|
||||
sendWSRPCResponse(conn, writeMu, wsRPCResponse{
|
||||
JSONRPC: "2.0",
|
||||
ID: rpcReq.ID,
|
||||
Error: map[string]any{
|
||||
"code": -32602,
|
||||
"message": "invalid params",
|
||||
},
|
||||
})
|
||||
return true
|
||||
}
|
||||
if req.Language == "" {
|
||||
req.Language = defaultLanguage
|
||||
}
|
||||
if req.ID == "" {
|
||||
req.ID = string(rpcReq.ID)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opts.RequestTimeout)
|
||||
defer cancel()
|
||||
|
||||
resp, err := service.Complete(ctx, completion.Request{
|
||||
Language: req.Language,
|
||||
SessionID: req.SessionID,
|
||||
URI: req.URI,
|
||||
Text: req.Text,
|
||||
Line: req.Line,
|
||||
Character: req.Character,
|
||||
})
|
||||
if err != nil {
|
||||
sendWSRPCResponse(conn, writeMu, wsRPCResponse{
|
||||
JSONRPC: "2.0",
|
||||
ID: rpcReq.ID,
|
||||
Error: map[string]any{
|
||||
"code": -32000,
|
||||
"message": err.Error(),
|
||||
},
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
sendWSRPCResponse(conn, writeMu, wsRPCResponse{
|
||||
JSONRPC: "2.0",
|
||||
ID: rpcReq.ID,
|
||||
Result: resp,
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
func processWSCompletion(
|
||||
conn *websocket.Conn,
|
||||
writeMu *sync.Mutex,
|
||||
service CompletionService,
|
||||
req wsCompletionRequest,
|
||||
opts RouteOptions,
|
||||
) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opts.RequestTimeout)
|
||||
defer cancel()
|
||||
|
||||
resp, err := service.Complete(ctx, completion.Request{
|
||||
Language: req.Language,
|
||||
SessionID: req.SessionID,
|
||||
URI: req.URI,
|
||||
Text: req.Text,
|
||||
Line: req.Line,
|
||||
Character: req.Character,
|
||||
})
|
||||
if err != nil {
|
||||
msg := "completion failed"
|
||||
routeTo := ""
|
||||
ownerID := ""
|
||||
switch {
|
||||
case errors.Is(err, completion.ErrInvalidRequest):
|
||||
msg = err.Error()
|
||||
case errors.Is(err, completion.ErrUnsupportedLanguage):
|
||||
msg = err.Error()
|
||||
case errors.Is(err, completion.ErrTooManySessions):
|
||||
msg = err.Error()
|
||||
default:
|
||||
var ownedErr *completion.ErrSessionOwnedByOtherInstance
|
||||
if errors.As(err, &ownedErr) {
|
||||
msg = err.Error()
|
||||
routeTo = ownedErr.OwnerEndpoint
|
||||
ownerID = ownedErr.OwnerID
|
||||
}
|
||||
}
|
||||
sendWSResponse(conn, writeMu, wsCompletionResponse{
|
||||
ID: req.ID,
|
||||
RouteTo: routeTo,
|
||||
OwnerID: ownerID,
|
||||
Error: msg,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
sendWSResponse(conn, writeMu, wsCompletionResponse{
|
||||
ID: req.ID,
|
||||
Items: resp.Items,
|
||||
IsIncomplete: resp.IsIncomplete,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -99,3 +246,9 @@ func sendWSResponse(conn *websocket.Conn, writeMu *sync.Mutex, resp wsCompletion
|
||||
defer writeMu.Unlock()
|
||||
_ = conn.WriteJSON(resp)
|
||||
}
|
||||
|
||||
func sendWSRPCResponse(conn *websocket.Conn, writeMu *sync.Mutex, resp wsRPCResponse) {
|
||||
writeMu.Lock()
|
||||
defer writeMu.Unlock()
|
||||
_ = conn.WriteJSON(resp)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user