From 57afb90bc00e45c06cc0e5293f3edf501c065821 Mon Sep 17 00:00:00 2001 From: meowrain Date: Sun, 15 Feb 2026 16:22:01 +0800 Subject: [PATCH] feat: enhance completion service with session management and language support - Introduced session management using Redis for tracking active sessions. - Added session claiming and releasing functionality in the completion manager. - Enhanced HTTP and WebSocket completion endpoints to support multiple languages. - Implemented request timeout and maximum body size configurations for API routes. - Updated client-side code to handle session IDs and language parameters in completion requests. - Improved error handling for unsupported languages and session conflicts. - Added tests for the completion manager to ensure proper session handling and cleanup. --- backend/README.md | 134 +++++---- backend/cmd/server/main.go | 243 +++++++++++++-- backend/go.mod | 3 + backend/go.sum | 10 + backend/internal/api/handler.go | 81 ++++- backend/internal/api/ws_handler.go | 235 ++++++++++++--- backend/internal/cluster/redis_registry.go | 222 ++++++++++++++ backend/internal/completion/manager.go | 314 ++++++++++++++++++++ backend/internal/completion/manager_test.go | 125 ++++++++ backend/internal/completion/service.go | 2 + backend/internal/lsp/client.go | 40 ++- src/api/completion.ts | 41 ++- src/components/MonacoEditor.vue | 18 +- src/types/completion.ts | 4 + 14 files changed, 1334 insertions(+), 138 deletions(-) create mode 100644 backend/internal/cluster/redis_registry.go create mode 100644 backend/internal/completion/manager.go create mode 100644 backend/internal/completion/manager_test.go diff --git a/backend/README.md b/backend/README.md index 96530ee..873df50 100644 --- a/backend/README.md +++ b/backend/README.md @@ -1,6 +1,14 @@ -# Monaco Go Completion Backend +# Monica LSP Gateway -Gin + gopls(JSON-RPC/LSP over stdio) 实现的 Go 代码补全后端。 +面向微服务场景的 LSP 网关(Gin),当前支持: +- Go: `gopls` +- JavaScript / TypeScript: `typescript-language-server --stdio` + +网关职责: +- 对外提供 HTTP + WebSocket 接口 +- 对内通过 JSON-RPC / LSP over stdio 驱动语言服务器 +- 按 `language + sessionId` 维护长生命周期会话 +- 会话空闲自动回收(TTL) ## 运行 @@ -9,80 +17,102 @@ go mod tidy go run ./cmd/server ``` -默认监听 `http://localhost:8080`。 +默认地址:`http://127.0.0.1:8080` -## 环境变量 +## 企业化配置(环境变量) -- `PORT`:HTTP 端口,默认 `8080` -- `GOPLS_PATH`:`gopls` 可执行文件路径,默认 `gopls` -- `WORKSPACE_DIR`:gopls 工作目录,默认当前目录 -- `CORS_ALLOW_ORIGIN`:CORS 允许来源,默认 `*` +- `PORT`:默认 `8080` +- `WORKSPACE_DIR`:LSP 工作目录,默认当前目录 +- `CORS_ALLOW_ORIGIN`:默认 `*` +- `LSP_API_TOKEN`:可选,设置后需在请求头传 `X-API-Key` +- `REQUEST_TIMEOUT`:单请求超时,默认 `10s` +- `MAX_BODY_BYTES`:请求体上限,默认 `2097152`(2MB) +- `SESSION_TTL`:会话空闲回收时间,默认 `20m` +- `SESSION_CLEANUP_INTERVAL`:清理周期,默认 `2m` +- `MAX_SESSIONS`:最大会话数,默认 `256` +- `ENABLE_REDIS_STICKY`:是否启用 Redis 会话外置与粘性路由,默认 `true` +- `REDIS_ADDR`:默认 `10.0.0.10:6379` +- `REDIS_DB`:默认 `1` +- `REDIS_PASSWORD`:默认空 +- `REDIS_KEY_PREFIX`:默认 `lsp-gateway` +- `INSTANCE_ID`:实例唯一标识(建议由部署系统注入) +- `INSTANCE_URL`:实例可回源地址(用于路由提示),默认 `http://127.0.0.1:${PORT}` +- `INSTANCE_TTL`:实例注册 TTL,默认 `30s` +- `INSTANCE_HEARTBEAT_INTERVAL`:实例心跳周期,默认 `10s` -## API +语言服务器命令(可替换为企业内部镜像/封装): +- `GO_LSP_COMMAND`,`GO_LSP_ARGS` +- `JAVASCRIPT_LSP_COMMAND`,`JAVASCRIPT_LSP_ARGS` +- `TYPESCRIPT_LSP_COMMAND`,`TYPESCRIPT_LSP_ARGS` -### 健康检查 +默认 JS/TS 命令: +- `typescript-language-server --stdio` + +## 健康检查 - `GET /health` +- `GET /health/live` +- `GET /health/ready`(含当前会话统计) -### Go 补全 +## HTTP 补全接口 -- `POST /api/v1/completions/go` -- `Content-Type: application/json` +- `POST /api/v1/completions/:language` -### Go 补全(WebSocket) +示例(JavaScript): -- `GET /ws/completions/go` -- 客户端发送: +```json +{ + "sessionId": "editor-1", + "language": "javascript", + "uri": "file:///main.js", + "text": "const a = console.lo", + "line": 0, + "character": 20 +} +``` + +## WebSocket 接口 + +- `GET /ws/completions` +- `GET /ws/completions/:language` + +支持两种消息格式: + +1) 简化消息: ```json { "id": "1", + "sessionId": "editor-1", + "language": "go", "uri": "file:///main.go", - "text": "package main\n\nimport \"fmt\"\n\nfunc main() {\n fmt.Pri\n}", + "text": "package main\n\nimport \"fmt\"\n\nfunc main() {\n\tfmt.Pri\n}", "line": 5, - "character": 11 + "character": 9 } ``` -- 服务端返回: +2) JSON-RPC 风格: ```json { - "id": "1", - "items": [{ "label": "Println" }], - "isIncomplete": false + "jsonrpc": "2.0", + "id": 1, + "method": "completion/complete", + "params": { + "sessionId": "editor-1", + "language": "typescript", + "uri": "file:///main.ts", + "text": "console.lo", + "line": 0, + "character": 10 + } } ``` -请求体: +## 备注 -```json -{ - "uri": "file:///C:/Users/meowr/Desktop/bishe/monica_editor_with_code_completion/backend/playground.go", - "text": "package main\n\nimport \"fmt\"\n\nfunc main() {\n fmt.Pri\n}", - "line": 5, - "character": 11 -} -``` - -说明: -- `uri` 建议使用绝对 `file://` URI(与 `WORKSPACE_DIR` 在同一工作区最稳定)。 -- 也支持 `file:///main.go` 这类相对根路径 URI,后端会自动映射到 `WORKSPACE_DIR` 下。 - `line` / `character` 为 0-based。 -- 每次请求都要传前端当前全文 `text`。 - -响应体: - -```json -{ - "items": [ - { - "label": "Println", - "kind": 3, - "detail": "func(a ...any) (n int, err error)", - "documentation": "Println formats using the default formats..." - } - ], - "isIncomplete": false -} -``` +- `uri` 支持 `file:///main.go` 这类相对根路径,网关会映射到 `WORKSPACE_DIR` 下。 +- 前端建议稳定传 `sessionId`,避免频繁新建语言服务器进程。 +- 多实例场景下,如果请求落到非会话所属实例,服务会返回 `409` 并携带 `routeTo` 与 `X-LSP-Route-To`。 diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index 77403ff..1528e99 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -2,45 +2,103 @@ package main import ( "context" + "fmt" "log" "net/http" "os" "os/signal" + "os/user" + "strconv" + "strings" + "sync/atomic" "syscall" "time" "github.com/gin-gonic/gin" "monica-go-completion-backend/internal/api" + "monica-go-completion-backend/internal/cluster" "monica-go-completion-backend/internal/completion" "monica-go-completion-backend/internal/lsp" ) +var requestIDSeed atomic.Int64 + type config struct { - Port string - GoplsPath string - WorkspaceDir string - AllowOrigin string + Port string + WorkspaceDir string + AllowOrigin string + APIToken string + RequestTimeout time.Duration + MaxBodyBytes int64 + SessionTTL time.Duration + CleanupInterval time.Duration + MaxSessions int + InstanceID string + InstanceURL string + EnableRedis bool + RedisAddr string + RedisPassword string + RedisDB int + RedisKeyPrefix string + InstanceTTL time.Duration + Heartbeat time.Duration + Servers []completion.LanguageServerSpec } func main() { cfg := loadConfig() - lspClient, err := lsp.NewClient(context.Background(), cfg.GoplsPath, cfg.WorkspaceDir) - if err != nil { - log.Fatalf("create gopls client failed: %v", err) + var registry completion.SessionRegistry + if cfg.EnableRedis { + var err error + registry, err = cluster.NewRedisRegistry(context.Background(), cluster.RedisRegistryConfig{ + Addr: cfg.RedisAddr, + Password: cfg.RedisPassword, + DB: cfg.RedisDB, + KeyPrefix: cfg.RedisKeyPrefix, + InstanceID: cfg.InstanceID, + InstanceEndpoint: cfg.InstanceURL, + SessionTTL: cfg.SessionTTL, + InstanceTTL: cfg.InstanceTTL, + HeartbeatInterval: cfg.Heartbeat, + }) + if err != nil { + log.Fatalf("create redis registry failed: %v", err) + } } - defer func() { - _ = lspClient.Close() - }() - completionService := completion.NewService(lspClient) + manager := completion.NewManager(completion.ManagerConfig{ + WorkspaceDir: cfg.WorkspaceDir, + MaxSessions: cfg.MaxSessions, + SessionTTL: cfg.SessionTTL, + CleanupInterval: cfg.CleanupInterval, + InstanceID: cfg.InstanceID, + Registry: registry, + }, cfg.Servers, func(ctx context.Context, spec completion.LanguageServerSpec, workspaceDir string) (completion.RuntimeClient, error) { + return lsp.NewClient(ctx, lsp.Config{ + Command: spec.Command, + Args: spec.Args, + RootPath: workspaceDir, + LanguageID: spec.LanguageID, + ClientName: "monica-lsp-gateway", + }) + }) + defer func() { + _ = manager.Close() + }() router := gin.New() router.Use(gin.Logger()) router.Use(gin.Recovery()) + router.Use(requestIDMiddleware()) router.Use(corsMiddleware(cfg.AllowOrigin)) - api.RegisterRoutes(router, completionService) + router.Use(apiTokenMiddleware(cfg.APIToken)) + + api.RegisterRoutes(router, manager, api.RouteOptions{ + RequestTimeout: cfg.RequestTimeout, + MaxBodyBytes: cfg.MaxBodyBytes, + }) server := &http.Server{ Addr: ":" + cfg.Port, @@ -51,7 +109,13 @@ func main() { } go func() { - log.Printf("completion backend listening on :%s", cfg.Port) + log.Printf( + "lsp gateway listening on :%s, workspace=%s, instanceID=%s, redis=%t", + cfg.Port, + cfg.WorkspaceDir, + cfg.InstanceID, + cfg.EnableRedis, + ) if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Fatalf("http server failed: %v", err) } @@ -61,12 +125,15 @@ func main() { signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) <-sigCh - shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if err := server.Shutdown(shutdownCtx); err != nil { log.Printf("http shutdown failed: %v", err) } + if err := manager.Close(); err != nil { + log.Printf("lsp manager close failed: %v", err) + } } func loadConfig() config { @@ -74,27 +141,133 @@ func loadConfig() config { if err != nil { cwd = "." } + return config{ - Port: getenv("PORT", "8080"), - GoplsPath: getenv("GOPLS_PATH", "gopls"), - WorkspaceDir: getenv("WORKSPACE_DIR", cwd), - AllowOrigin: getenv("CORS_ALLOW_ORIGIN", "*"), + Port: getenv("PORT", "8080"), + WorkspaceDir: getenv("WORKSPACE_DIR", cwd), + AllowOrigin: getenv("CORS_ALLOW_ORIGIN", "*"), + APIToken: strings.TrimSpace(os.Getenv("LSP_API_TOKEN")), + RequestTimeout: getenvDuration("REQUEST_TIMEOUT", 10*time.Second), + MaxBodyBytes: getenvInt64("MAX_BODY_BYTES", 2<<20), + SessionTTL: getenvDuration("SESSION_TTL", 20*time.Minute), + CleanupInterval: getenvDuration("SESSION_CLEANUP_INTERVAL", 2*time.Minute), + MaxSessions: getenvInt("MAX_SESSIONS", 256), + InstanceID: getenv("INSTANCE_ID", defaultInstanceID()), + InstanceURL: getenv("INSTANCE_URL", "http://127.0.0.1:"+getenv("PORT", "8080")), + EnableRedis: getenvBool("ENABLE_REDIS_STICKY", true), + RedisAddr: getenv("REDIS_ADDR", "10.0.0.10:6379"), + RedisPassword: getenv("REDIS_PASSWORD", ""), + RedisDB: getenvInt("REDIS_DB", 1), + RedisKeyPrefix: getenv("REDIS_KEY_PREFIX", "lsp-gateway"), + InstanceTTL: getenvDuration("INSTANCE_TTL", 30*time.Second), + Heartbeat: getenvDuration("INSTANCE_HEARTBEAT_INTERVAL", 10*time.Second), + Servers: []completion.LanguageServerSpec{ + { + Language: "go", + LanguageID: "go", + Command: getenv("GO_LSP_COMMAND", "gopls"), + Args: getenvArgs("GO_LSP_ARGS", ""), + }, + { + Language: "javascript", + LanguageID: "javascript", + Command: getenv("JAVASCRIPT_LSP_COMMAND", "typescript-language-server"), + Args: getenvArgs("JAVASCRIPT_LSP_ARGS", "--stdio"), + }, + { + Language: "typescript", + LanguageID: "typescript", + Command: getenv("TYPESCRIPT_LSP_COMMAND", "typescript-language-server"), + Args: getenvArgs("TYPESCRIPT_LSP_ARGS", "--stdio"), + }, + }, } } +func defaultInstanceID() string { + host, err := os.Hostname() + if err != nil || strings.TrimSpace(host) == "" { + host = "unknown-host" + } + userName := "unknown-user" + if u, err := user.Current(); err == nil && strings.TrimSpace(u.Username) != "" { + userName = strings.ReplaceAll(u.Username, "\\", "-") + } + return fmt.Sprintf("%s-%s-%d", host, userName, os.Getpid()) +} + func getenv(key, fallback string) string { - v := os.Getenv(key) + v := strings.TrimSpace(os.Getenv(key)) if v == "" { return fallback } return v } +func getenvArgs(key, fallback string) []string { + value := getenv(key, fallback) + if value == "" { + return nil + } + return strings.Fields(value) +} + +func getenvInt(key string, fallback int) int { + v := strings.TrimSpace(os.Getenv(key)) + if v == "" { + return fallback + } + parsed, err := strconv.Atoi(v) + if err == nil { + return parsed + } + return fallback +} + +func getenvInt64(key string, fallback int64) int64 { + v := strings.TrimSpace(os.Getenv(key)) + if v == "" { + return fallback + } + parsed, err := strconv.ParseInt(v, 10, 64) + if err == nil { + return parsed + } + return fallback +} + +func getenvDuration(key string, fallback time.Duration) time.Duration { + v := strings.TrimSpace(os.Getenv(key)) + if v == "" { + return fallback + } + parsed, err := time.ParseDuration(v) + if err != nil { + return fallback + } + return parsed +} + +func getenvBool(key string, fallback bool) bool { + v := strings.TrimSpace(strings.ToLower(os.Getenv(key))) + if v == "" { + return fallback + } + switch v { + case "1", "true", "yes", "on": + return true + case "0", "false", "no", "off": + return false + default: + return fallback + } +} + func corsMiddleware(allowOrigin string) gin.HandlerFunc { return func(c *gin.Context) { c.Header("Access-Control-Allow-Origin", allowOrigin) c.Header("Access-Control-Allow-Methods", "GET,POST,OPTIONS") - c.Header("Access-Control-Allow-Headers", "Content-Type,Authorization") + c.Header("Access-Control-Allow-Headers", "Content-Type,Authorization,X-API-Key") if c.Request.Method == http.MethodOptions { c.AbortWithStatus(http.StatusNoContent) @@ -103,3 +276,33 @@ func corsMiddleware(allowOrigin string) gin.HandlerFunc { c.Next() } } + +func apiTokenMiddleware(token string) gin.HandlerFunc { + required := strings.TrimSpace(token) + if required == "" { + return func(c *gin.Context) { + c.Next() + } + } + + return func(c *gin.Context) { + provided := strings.TrimSpace(c.GetHeader("X-API-Key")) + if provided == required { + c.Next() + return + } + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"}) + } +} + +func requestIDMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + rid := strings.TrimSpace(c.GetHeader("X-Request-Id")) + if rid == "" { + rid = fmt.Sprintf("req-%d", requestIDSeed.Add(1)) + } + c.Header("X-Request-Id", rid) + c.Set("requestId", rid) + c.Next() + } +} diff --git a/backend/go.mod b/backend/go.mod index 807c6ff..bbc1dd2 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -5,12 +5,15 @@ go 1.25 require ( github.com/gin-gonic/gin v1.11.0 github.com/gorilla/websocket v1.5.3 + github.com/redis/go-redis/v9 v9.17.3 ) require ( github.com/bytedance/sonic v1.14.0 // indirect github.com/bytedance/sonic/loader v0.3.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/gabriel-vasile/mimetype v1.4.8 // indirect github.com/gin-contrib/sse v1.1.0 // indirect github.com/go-playground/locales v0.14.1 // indirect diff --git a/backend/go.sum b/backend/go.sum index 16eca5e..7e4d4c0 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -1,12 +1,20 @@ +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQQ= github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA= github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM= github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8= github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w= @@ -50,6 +58,8 @@ github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= github.com/quic-go/quic-go v0.54.0 h1:6s1YB9QotYI6Ospeiguknbp2Znb/jZYjZLRXn9kMQBg= github.com/quic-go/quic-go v0.54.0/go.mod h1:e68ZEaCdyviluZmy44P6Iey98v/Wfz6HCjQEm+l8zTY= +github.com/redis/go-redis/v9 v9.17.3 h1:fN29NdNrE17KttK5Ndf20buqfDZwGNgoUr9qjl1DQx4= +github.com/redis/go-redis/v9 v9.17.3/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= diff --git a/backend/internal/api/handler.go b/backend/internal/api/handler.go index 7f79779..39df0f6 100644 --- a/backend/internal/api/handler.go +++ b/backend/internal/api/handler.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net/http" + "time" "github.com/gin-gonic/gin" @@ -14,30 +15,100 @@ type CompletionService interface { Complete(ctx context.Context, req completion.Request) (completion.Response, error) } -func RegisterRoutes(router *gin.Engine, service CompletionService) { +type SessionStatsProvider interface { + ActiveSessions() map[string]int +} + +type RouteOptions struct { + RequestTimeout time.Duration + MaxBodyBytes int64 +} + +func RegisterRoutes(router *gin.Engine, service CompletionService, options ...RouteOptions) { + opts := RouteOptions{ + RequestTimeout: 10 * time.Second, + MaxBodyBytes: 2 << 20, // 2MB + } + if len(options) > 0 { + if options[0].RequestTimeout > 0 { + opts.RequestTimeout = options[0].RequestTimeout + } + if options[0].MaxBodyBytes > 0 { + opts.MaxBodyBytes = options[0].MaxBodyBytes + } + } + router.GET("/health", func(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"status": "ok"}) + c.JSON(http.StatusOK, gin.H{"status": "ok", "service": "lsp-gateway"}) }) - registerWSRoutes(router, service) + router.GET("/health/live", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"status": "alive"}) + }) - router.POST("/api/v1/completions/go", func(c *gin.Context) { + router.GET("/health/ready", func(c *gin.Context) { + sessions := map[string]int{} + if provider, ok := service.(SessionStatsProvider); ok { + sessions = provider.ActiveSessions() + } + c.JSON(http.StatusOK, gin.H{ + "status": "ready", + "sessions": sessions, + }) + }) + + registerWSRoutes(router, service, opts) + + handleCompletion := func(c *gin.Context) { + c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, opts.MaxBodyBytes) var req completion.Request if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid JSON payload"}) return } - resp, err := service.Complete(c.Request.Context(), req) + routeLang := c.Param("language") + if req.Language == "" { + req.Language = routeLang + } + + ctx, cancel := context.WithTimeout(c.Request.Context(), opts.RequestTimeout) + defer cancel() + + resp, err := service.Complete(ctx, req) if err != nil { if errors.Is(err, completion.ErrInvalidRequest) { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } + if errors.Is(err, completion.ErrUnsupportedLanguage) { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if errors.Is(err, completion.ErrTooManySessions) { + c.JSON(http.StatusTooManyRequests, gin.H{"error": err.Error()}) + return + } + var ownedErr *completion.ErrSessionOwnedByOtherInstance + if errors.As(err, &ownedErr) { + if ownedErr.OwnerEndpoint != "" { + c.Header("X-LSP-Route-To", ownedErr.OwnerEndpoint) + } + c.JSON(http.StatusConflict, gin.H{ + "error": err.Error(), + "routeTo": ownedErr.OwnerEndpoint, + "ownerId": ownedErr.OwnerID, + }) + return + } c.JSON(http.StatusInternalServerError, gin.H{"error": "completion failed"}) return } c.JSON(http.StatusOK, resp) + } + + router.POST("/api/v1/completions/:language", func(c *gin.Context) { + handleCompletion(c) }) } diff --git a/backend/internal/api/ws_handler.go b/backend/internal/api/ws_handler.go index b3da09a..893a6da 100644 --- a/backend/internal/api/ws_handler.go +++ b/backend/internal/api/ws_handler.go @@ -6,7 +6,6 @@ import ( "errors" "net/http" "sync" - "time" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" @@ -24,6 +23,8 @@ var wsUpgrader = websocket.Upgrader{ type wsCompletionRequest struct { ID string `json:"id"` + Language string `json:"language,omitempty"` + SessionID string `json:"sessionId,omitempty"` URI string `json:"uri"` Text string `json:"text"` Line int `json:"line"` @@ -34,17 +35,35 @@ type wsCompletionResponse struct { ID string `json:"id"` Items []completion.Item `json:"items,omitempty"` IsIncomplete bool `json:"isIncomplete,omitempty"` + RouteTo string `json:"routeTo,omitempty"` + OwnerID string `json:"ownerId,omitempty"` Error string `json:"error,omitempty"` } -func registerWSRoutes(router *gin.Engine, service CompletionService) { - router.GET("/ws/completions/go", func(c *gin.Context) { +type wsRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id"` + Method string `json:"method"` + Params json.RawMessage `json:"params"` +} + +type wsRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id"` + Result any `json:"result,omitempty"` + Error any `json:"error,omitempty"` +} + +func registerWSRoutes(router *gin.Engine, service CompletionService, opts RouteOptions) { + handler := func(c *gin.Context, defaultLanguage string) { conn, err := wsUpgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { return } defer conn.Close() + conn.SetReadLimit(opts.MaxBodyBytes) + var writeMu sync.Mutex for { @@ -52,45 +71,173 @@ func registerWSRoutes(router *gin.Engine, service CompletionService) { if err != nil { break } - - var req wsCompletionRequest - if err := json.Unmarshal(payload, &req); err != nil { - sendWSResponse(conn, &writeMu, wsCompletionResponse{ - ID: "", - Error: "invalid JSON payload", - }) - continue - } - - go func(r wsCompletionRequest) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - resp, err := service.Complete(ctx, completion.Request{ - URI: r.URI, - Text: r.Text, - Line: r.Line, - Character: r.Character, - }) - if err != nil { - msg := "completion failed" - if errors.Is(err, completion.ErrInvalidRequest) { - msg = err.Error() - } - sendWSResponse(conn, &writeMu, wsCompletionResponse{ - ID: r.ID, - Error: msg, - }) - return - } - - sendWSResponse(conn, &writeMu, wsCompletionResponse{ - ID: r.ID, - Items: resp.Items, - IsIncomplete: resp.IsIncomplete, - }) - }(req) + handleWSMessage(conn, &writeMu, service, payload, defaultLanguage, opts) } + } + + router.GET("/ws/completions", func(c *gin.Context) { + handler(c, "") + }) + router.GET("/ws/completions/:language", func(c *gin.Context) { + handler(c, c.Param("language")) + }) +} + +func handleWSMessage( + conn *websocket.Conn, + writeMu *sync.Mutex, + service CompletionService, + payload []byte, + defaultLanguage string, + opts RouteOptions, +) { + if tryHandleRPCMessage(conn, writeMu, service, payload, defaultLanguage, opts) { + return + } + + var req wsCompletionRequest + if err := json.Unmarshal(payload, &req); err != nil { + sendWSResponse(conn, writeMu, wsCompletionResponse{ + ID: "", + Error: "invalid JSON payload", + }) + return + } + if req.Language == "" { + req.Language = defaultLanguage + } + + processWSCompletion(conn, writeMu, service, req, opts) +} + +func tryHandleRPCMessage( + conn *websocket.Conn, + writeMu *sync.Mutex, + service CompletionService, + payload []byte, + defaultLanguage string, + opts RouteOptions, +) bool { + var rpcReq wsRPCRequest + if err := json.Unmarshal(payload, &rpcReq); err != nil { + return false + } + if rpcReq.JSONRPC != "2.0" || rpcReq.Method == "" { + return false + } + + if rpcReq.Method != "completion/complete" && rpcReq.Method != "completion.complete" { + sendWSRPCResponse(conn, writeMu, wsRPCResponse{ + JSONRPC: "2.0", + ID: rpcReq.ID, + Error: map[string]any{ + "code": -32601, + "message": "method not found", + }, + }) + return true + } + + var req wsCompletionRequest + if err := json.Unmarshal(rpcReq.Params, &req); err != nil { + sendWSRPCResponse(conn, writeMu, wsRPCResponse{ + JSONRPC: "2.0", + ID: rpcReq.ID, + Error: map[string]any{ + "code": -32602, + "message": "invalid params", + }, + }) + return true + } + if req.Language == "" { + req.Language = defaultLanguage + } + if req.ID == "" { + req.ID = string(rpcReq.ID) + } + + ctx, cancel := context.WithTimeout(context.Background(), opts.RequestTimeout) + defer cancel() + + resp, err := service.Complete(ctx, completion.Request{ + Language: req.Language, + SessionID: req.SessionID, + URI: req.URI, + Text: req.Text, + Line: req.Line, + Character: req.Character, + }) + if err != nil { + sendWSRPCResponse(conn, writeMu, wsRPCResponse{ + JSONRPC: "2.0", + ID: rpcReq.ID, + Error: map[string]any{ + "code": -32000, + "message": err.Error(), + }, + }) + return true + } + + sendWSRPCResponse(conn, writeMu, wsRPCResponse{ + JSONRPC: "2.0", + ID: rpcReq.ID, + Result: resp, + }) + return true +} + +func processWSCompletion( + conn *websocket.Conn, + writeMu *sync.Mutex, + service CompletionService, + req wsCompletionRequest, + opts RouteOptions, +) { + ctx, cancel := context.WithTimeout(context.Background(), opts.RequestTimeout) + defer cancel() + + resp, err := service.Complete(ctx, completion.Request{ + Language: req.Language, + SessionID: req.SessionID, + URI: req.URI, + Text: req.Text, + Line: req.Line, + Character: req.Character, + }) + if err != nil { + msg := "completion failed" + routeTo := "" + ownerID := "" + switch { + case errors.Is(err, completion.ErrInvalidRequest): + msg = err.Error() + case errors.Is(err, completion.ErrUnsupportedLanguage): + msg = err.Error() + case errors.Is(err, completion.ErrTooManySessions): + msg = err.Error() + default: + var ownedErr *completion.ErrSessionOwnedByOtherInstance + if errors.As(err, &ownedErr) { + msg = err.Error() + routeTo = ownedErr.OwnerEndpoint + ownerID = ownedErr.OwnerID + } + } + sendWSResponse(conn, writeMu, wsCompletionResponse{ + ID: req.ID, + RouteTo: routeTo, + OwnerID: ownerID, + Error: msg, + }) + return + } + + sendWSResponse(conn, writeMu, wsCompletionResponse{ + ID: req.ID, + Items: resp.Items, + IsIncomplete: resp.IsIncomplete, }) } @@ -99,3 +246,9 @@ func sendWSResponse(conn *websocket.Conn, writeMu *sync.Mutex, resp wsCompletion defer writeMu.Unlock() _ = conn.WriteJSON(resp) } + +func sendWSRPCResponse(conn *websocket.Conn, writeMu *sync.Mutex, resp wsRPCResponse) { + writeMu.Lock() + defer writeMu.Unlock() + _ = conn.WriteJSON(resp) +} diff --git a/backend/internal/cluster/redis_registry.go b/backend/internal/cluster/redis_registry.go new file mode 100644 index 0000000..45d1f19 --- /dev/null +++ b/backend/internal/cluster/redis_registry.go @@ -0,0 +1,222 @@ +package cluster + +import ( + "context" + "errors" + "fmt" + "strings" + "sync" + "time" + + "github.com/redis/go-redis/v9" +) + +type RedisRegistryConfig struct { + Addr string + Password string + DB int + KeyPrefix string + InstanceID string + InstanceEndpoint string + SessionTTL time.Duration + InstanceTTL time.Duration + HeartbeatInterval time.Duration +} + +type RedisRegistry struct { + client *redis.Client + + keyPrefix string + instanceID string + instanceEndpoint string + sessionTTL time.Duration + instanceTTL time.Duration + heartbeatInterval time.Duration + + stopCh chan struct{} + stopOnce sync.Once +} + +var claimSessionScript = redis.NewScript(` +local sessionKey = KEYS[1] +local owner = ARGV[1] +local now = ARGV[2] +local ttl = tonumber(ARGV[3]) + +local existing = redis.call('HGET', sessionKey, 'owner') +if (not existing) or existing == owner then + redis.call('HSET', sessionKey, 'owner', owner, 'updatedAt', now) + redis.call('PEXPIRE', sessionKey, ttl) + return owner +end + +redis.call('PEXPIRE', sessionKey, ttl) +return existing +`) + +var releaseSessionScript = redis.NewScript(` +local sessionKey = KEYS[1] +local owner = ARGV[1] +local existing = redis.call('HGET', sessionKey, 'owner') +if existing == owner then + redis.call('DEL', sessionKey) + return 1 +end +return 0 +`) + +func NewRedisRegistry(ctx context.Context, cfg RedisRegistryConfig) (*RedisRegistry, error) { + if strings.TrimSpace(cfg.Addr) == "" { + return nil, errors.New("redis addr is required") + } + if strings.TrimSpace(cfg.KeyPrefix) == "" { + cfg.KeyPrefix = "lsp-gateway" + } + if cfg.SessionTTL <= 0 { + cfg.SessionTTL = 20 * time.Minute + } + if cfg.InstanceTTL <= 0 { + cfg.InstanceTTL = 30 * time.Second + } + if cfg.HeartbeatInterval <= 0 { + cfg.HeartbeatInterval = 10 * time.Second + } + if cfg.InstanceID == "" { + return nil, errors.New("instance id is required") + } + + client := redis.NewClient(&redis.Options{ + Addr: cfg.Addr, + Password: cfg.Password, + DB: cfg.DB, + }) + + if err := client.Ping(ctx).Err(); err != nil { + return nil, fmt.Errorf("redis ping failed: %w", err) + } + + registry := &RedisRegistry{ + client: client, + keyPrefix: cfg.KeyPrefix, + instanceID: cfg.InstanceID, + instanceEndpoint: cfg.InstanceEndpoint, + sessionTTL: cfg.SessionTTL, + instanceTTL: cfg.InstanceTTL, + heartbeatInterval: cfg.HeartbeatInterval, + stopCh: make(chan struct{}), + } + if err := registry.refreshInstance(ctx); err != nil { + _ = client.Close() + return nil, err + } + go registry.heartbeatLoop() + return registry, nil +} + +func (r *RedisRegistry) ClaimSession( + ctx context.Context, + language string, + sessionID string, +) (ownerID string, ownerEndpoint string, err error) { + key := r.sessionKey(language, sessionID) + now := time.Now().UTC().Format(time.RFC3339Nano) + raw, err := claimSessionScript.Run( + ctx, + r.client, + []string{key}, + r.instanceID, + now, + r.sessionTTL.Milliseconds(), + ).Result() + if err != nil { + return "", "", fmt.Errorf("claim session failed: %w", err) + } + + owner := fmt.Sprint(raw) + if owner == "" { + return "", "", errors.New("claim session returned empty owner") + } + if owner == r.instanceID { + return owner, r.instanceEndpoint, nil + } + + endpoint, err := r.resolveInstanceEndpoint(ctx, owner) + if err != nil { + return owner, "", err + } + return owner, endpoint, nil +} + +func (r *RedisRegistry) ReleaseSession(ctx context.Context, language, sessionID string) error { + key := r.sessionKey(language, sessionID) + if _, err := releaseSessionScript.Run(ctx, r.client, []string{key}, r.instanceID).Result(); err != nil { + return fmt.Errorf("release session failed: %w", err) + } + return nil +} + +func (r *RedisRegistry) Close() error { + r.stopOnce.Do(func() { + close(r.stopCh) + }) + return r.client.Close() +} + +func (r *RedisRegistry) heartbeatLoop() { + ticker := time.NewTicker(r.heartbeatInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + _ = r.refreshInstance(ctx) + cancel() + case <-r.stopCh: + return + } + } +} + +func (r *RedisRegistry) refreshInstance(ctx context.Context) error { + key := r.instanceKey(r.instanceID) + now := time.Now().UTC().Format(time.RFC3339Nano) + if err := r.client.HSet(ctx, key, map[string]any{ + "endpoint": r.instanceEndpoint, + "updatedAt": now, + }).Err(); err != nil { + return fmt.Errorf("refresh instance metadata failed: %w", err) + } + if err := r.client.Expire(ctx, key, r.instanceTTL).Err(); err != nil { + return fmt.Errorf("refresh instance ttl failed: %w", err) + } + return nil +} + +func (r *RedisRegistry) resolveInstanceEndpoint(ctx context.Context, ownerID string) (string, error) { + key := r.instanceKey(ownerID) + endpoint, err := r.client.HGet(ctx, key, "endpoint").Result() + if err == redis.Nil { + return "", nil + } + if err != nil { + return "", fmt.Errorf("resolve instance endpoint failed: %w", err) + } + return strings.TrimSpace(endpoint), nil +} + +func (r *RedisRegistry) sessionKey(language, sessionID string) string { + return fmt.Sprintf("%s:sessions:%s:%s", r.keyPrefix, normalizePart(language), normalizePart(sessionID)) +} + +func (r *RedisRegistry) instanceKey(instanceID string) string { + return fmt.Sprintf("%s:instances:%s", r.keyPrefix, normalizePart(instanceID)) +} + +func normalizePart(value string) string { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return "default" + } + return trimmed +} diff --git a/backend/internal/completion/manager.go b/backend/internal/completion/manager.go new file mode 100644 index 0000000..208837f --- /dev/null +++ b/backend/internal/completion/manager.go @@ -0,0 +1,314 @@ +package completion + +import ( + "context" + "errors" + "fmt" + "slices" + "strings" + "sync" + "time" +) + +var ErrUnsupportedLanguage = errors.New("unsupported language") +var ErrTooManySessions = errors.New("too many active lsp sessions") + +type RuntimeClient interface { + Client + Close() error +} + +type SessionRegistry interface { + ClaimSession(ctx context.Context, language, sessionID string) (ownerID string, ownerEndpoint string, err error) + ReleaseSession(ctx context.Context, language, sessionID string) error + Close() error +} + +type LanguageServerSpec struct { + Language string + LanguageID string + Command string + Args []string +} + +type ClientFactory func(ctx context.Context, spec LanguageServerSpec, workspaceDir string) (RuntimeClient, error) + +type ManagerConfig struct { + WorkspaceDir string + MaxSessions int + SessionTTL time.Duration + CleanupInterval time.Duration + InstanceID string + Registry SessionRegistry +} + +type Manager struct { + mu sync.Mutex + + config ManagerConfig + specByLang map[string]LanguageServerSpec + sessions map[string]*managedSession + newClient ClientFactory + stopCh chan struct{} + stoppedOnce sync.Once +} + +type managedSession struct { + key string + sessionID string + language string + service *Service + client RuntimeClient + lastUsed time.Time + createdAt time.Time +} + +type ErrSessionOwnedByOtherInstance struct { + OwnerID string + OwnerEndpoint string +} + +func (e *ErrSessionOwnedByOtherInstance) Error() string { + if e.OwnerEndpoint != "" { + return fmt.Sprintf("session owned by another instance: %s (%s)", e.OwnerID, e.OwnerEndpoint) + } + return fmt.Sprintf("session owned by another instance: %s", e.OwnerID) +} + +func NewManager(config ManagerConfig, specs []LanguageServerSpec, factory ClientFactory) *Manager { + if config.MaxSessions <= 0 { + config.MaxSessions = 256 + } + if config.SessionTTL <= 0 { + config.SessionTTL = 20 * time.Minute + } + if config.CleanupInterval <= 0 { + config.CleanupInterval = 2 * time.Minute + } + if strings.TrimSpace(config.InstanceID) == "" { + config.InstanceID = "instance-local" + } + + specByLang := make(map[string]LanguageServerSpec) + for _, spec := range specs { + key := normalizeLanguage(spec.Language) + if key == "" { + continue + } + specByLang[key] = spec + } + + m := &Manager{ + config: config, + specByLang: specByLang, + sessions: make(map[string]*managedSession), + newClient: factory, + stopCh: make(chan struct{}), + } + go m.cleanupLoop() + return m +} + +func (m *Manager) Complete(ctx context.Context, req Request) (Response, error) { + language := normalizeLanguage(req.Language) + if language == "" { + return Response{}, ErrInvalidRequest + } + + spec, ok := m.specByLang[language] + if !ok { + return Response{}, fmt.Errorf("%w: %s", ErrUnsupportedLanguage, req.Language) + } + + sessionKey := buildSessionKey(language, req.SessionID) + sessionID := normalizeSessionID(req.SessionID) + + if m.config.Registry != nil { + ownerID, ownerEndpoint, err := m.config.Registry.ClaimSession(ctx, language, sessionID) + if err != nil { + return Response{}, err + } + if ownerID != m.config.InstanceID { + return Response{}, &ErrSessionOwnedByOtherInstance{ + OwnerID: ownerID, + OwnerEndpoint: ownerEndpoint, + } + } + } + + session, err := m.getOrCreateSession(ctx, sessionKey, sessionID, spec) + if err != nil { + return Response{}, err + } + + resp, err := session.service.Complete(ctx, req) + if err != nil { + return Response{}, err + } + + m.mu.Lock() + if current, ok := m.sessions[sessionKey]; ok { + current.lastUsed = time.Now() + } + m.mu.Unlock() + return resp, nil +} + +func (m *Manager) ActiveSessions() map[string]int { + m.mu.Lock() + defer m.mu.Unlock() + + out := make(map[string]int) + for _, session := range m.sessions { + out[session.language]++ + } + return out +} + +func (m *Manager) Close() error { + m.stoppedOnce.Do(func() { + close(m.stopCh) + + m.mu.Lock() + defer m.mu.Unlock() + + for key, session := range m.sessions { + if m.config.Registry != nil { + _ = m.config.Registry.ReleaseSession(context.Background(), session.language, session.sessionID) + } + _ = session.client.Close() + delete(m.sessions, key) + } + if m.config.Registry != nil { + _ = m.config.Registry.Close() + } + }) + return nil +} + +func (m *Manager) cleanupLoop() { + ticker := time.NewTicker(m.config.CleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + m.cleanupIdleSessions() + case <-m.stopCh: + return + } + } +} + +func (m *Manager) cleanupIdleSessions() { + cutoff := time.Now().Add(-m.config.SessionTTL) + + m.mu.Lock() + defer m.mu.Unlock() + + for key, session := range m.sessions { + if session.lastUsed.After(cutoff) { + continue + } + if m.config.Registry != nil { + _ = m.config.Registry.ReleaseSession(context.Background(), session.language, session.sessionID) + } + _ = session.client.Close() + delete(m.sessions, key) + } +} + +func (m *Manager) getOrCreateSession( + ctx context.Context, + sessionKey string, + sessionID string, + spec LanguageServerSpec, +) (*managedSession, error) { + m.mu.Lock() + if existing, ok := m.sessions[sessionKey]; ok { + existing.lastUsed = time.Now() + m.mu.Unlock() + return existing, nil + } + + if len(m.sessions) >= m.config.MaxSessions { + if !m.evictLeastRecentlyUsedLocked() { + m.mu.Unlock() + return nil, ErrTooManySessions + } + } + m.mu.Unlock() + + client, err := m.newClient(ctx, spec, m.config.WorkspaceDir) + if err != nil { + return nil, err + } + + now := time.Now() + newSession := &managedSession{ + key: sessionKey, + sessionID: sessionID, + language: normalizeLanguage(spec.Language), + service: NewService(client), + client: client, + lastUsed: now, + createdAt: now, + } + + m.mu.Lock() + defer m.mu.Unlock() + if existing, ok := m.sessions[sessionKey]; ok { + _ = client.Close() + existing.lastUsed = now + return existing, nil + } + m.sessions[sessionKey] = newSession + return newSession, nil +} + +func (m *Manager) evictLeastRecentlyUsedLocked() bool { + if len(m.sessions) == 0 { + return false + } + + keys := make([]string, 0, len(m.sessions)) + for key := range m.sessions { + keys = append(keys, key) + } + slices.SortFunc(keys, func(a, b string) int { + as := m.sessions[a] + bs := m.sessions[b] + if as.lastUsed.Before(bs.lastUsed) { + return -1 + } + if as.lastUsed.After(bs.lastUsed) { + return 1 + } + return strings.Compare(as.key, bs.key) + }) + + victimKey := keys[0] + victim := m.sessions[victimKey] + if m.config.Registry != nil { + _ = m.config.Registry.ReleaseSession(context.Background(), victim.language, victim.sessionID) + } + _ = victim.client.Close() + delete(m.sessions, victimKey) + return true +} + +func buildSessionKey(language, sessionID string) string { + return language + ":" + normalizeSessionID(sessionID) +} + +func normalizeSessionID(sessionID string) string { + sid := strings.TrimSpace(sessionID) + if sid == "" { + return "default" + } + return sid +} + +func normalizeLanguage(language string) string { + return strings.ToLower(strings.TrimSpace(language)) +} diff --git a/backend/internal/completion/manager_test.go b/backend/internal/completion/manager_test.go new file mode 100644 index 0000000..9acdc0e --- /dev/null +++ b/backend/internal/completion/manager_test.go @@ -0,0 +1,125 @@ +package completion + +import ( + "context" + "errors" + "fmt" + "testing" + "time" +) + +type fakeRuntimeClient struct { + closed bool +} + +func (f *fakeRuntimeClient) DidOpen(_ context.Context, _ string, _ string, _ int) error { + return nil +} + +func (f *fakeRuntimeClient) DidChange(_ context.Context, _ string, _ string, _ int) error { + return nil +} + +func (f *fakeRuntimeClient) Completion(_ context.Context, _ string, _ int, _ int) (Response, error) { + return Response{ + Items: []Item{{Label: "ok"}}, + }, nil +} + +func (f *fakeRuntimeClient) Close() error { + f.closed = true + return nil +} + +func TestManagerCompleteSelectsLanguageAndSession(t *testing.T) { + createCount := 0 + factory := func(_ context.Context, spec LanguageServerSpec, _ string) (RuntimeClient, error) { + createCount++ + if spec.Language != "go" { + return nil, fmt.Errorf("unexpected language: %s", spec.Language) + } + return &fakeRuntimeClient{}, nil + } + + m := NewManager(ManagerConfig{WorkspaceDir: "."}, []LanguageServerSpec{ + {Language: "go", LanguageID: "go", Command: "gopls"}, + }, factory) + defer m.Close() + + for i := 0; i < 2; i++ { + resp, err := m.Complete(context.Background(), Request{ + Language: "go", + SessionID: "s1", + URI: "file:///main.go", + Text: "package main", + Line: 0, + Character: 0, + }) + if err != nil { + t.Fatalf("Complete() error = %v", err) + } + if len(resp.Items) != 1 || resp.Items[0].Label != "ok" { + t.Fatalf("unexpected response: %+v", resp) + } + } + + if createCount != 1 { + t.Fatalf("expected one client creation, got %d", createCount) + } +} + +func TestManagerCompleteUnsupportedLanguage(t *testing.T) { + m := NewManager(ManagerConfig{}, []LanguageServerSpec{ + {Language: "go", LanguageID: "go", Command: "gopls"}, + }, func(_ context.Context, _ LanguageServerSpec, _ string) (RuntimeClient, error) { + return &fakeRuntimeClient{}, nil + }) + defer m.Close() + + _, err := m.Complete(context.Background(), Request{ + Language: "python", + URI: "file:///main.py", + Text: "print('hi')", + Line: 0, + Character: 0, + }) + if !errors.Is(err, ErrUnsupportedLanguage) { + t.Fatalf("expected ErrUnsupportedLanguage, got %v", err) + } +} + +func TestManagerCleanupIdleSession(t *testing.T) { + client := &fakeRuntimeClient{} + m := NewManager(ManagerConfig{ + WorkspaceDir: ".", + SessionTTL: 30 * time.Millisecond, + CleanupInterval: 10 * time.Millisecond, + }, []LanguageServerSpec{ + {Language: "go", LanguageID: "go", Command: "gopls"}, + }, func(_ context.Context, _ LanguageServerSpec, _ string) (RuntimeClient, error) { + return client, nil + }) + defer m.Close() + + _, err := m.Complete(context.Background(), Request{ + Language: "go", + SessionID: "s2", + URI: "file:///main.go", + Text: "package main", + Line: 0, + Character: 0, + }) + if err != nil { + t.Fatalf("Complete() error = %v", err) + } + + time.Sleep(90 * time.Millisecond) + + sessions := m.ActiveSessions() + if sessions["go"] != 0 { + t.Fatalf("expected session cleanup, got %+v", sessions) + } + if !client.closed { + t.Fatal("expected client to be closed by cleanup") + } +} diff --git a/backend/internal/completion/service.go b/backend/internal/completion/service.go index 5184ad4..80456f2 100644 --- a/backend/internal/completion/service.go +++ b/backend/internal/completion/service.go @@ -10,6 +10,8 @@ import ( var ErrInvalidRequest = errors.New("invalid completion request") type Request struct { + Language string `json:"language,omitempty"` + SessionID string `json:"sessionId,omitempty"` URI string `json:"uri"` Text string `json:"text"` Line int `json:"line"` diff --git a/backend/internal/lsp/client.go b/backend/internal/lsp/client.go index 46634b9..2f2eb14 100644 --- a/backend/internal/lsp/client.go +++ b/backend/internal/lsp/client.go @@ -33,6 +33,8 @@ type Client struct { nextID atomic.Int64 workspaceDir string + languageID string + clientName string pendingMu sync.Mutex pending map[string]chan rpcResponse @@ -74,19 +76,33 @@ type lspCompletionList struct { Items []lspCompletionItem `json:"items"` } -func NewClient(parent context.Context, goplsPath, rootPath string) (*Client, error) { - if goplsPath == "" { - goplsPath = "gopls" +type Config struct { + Command string + Args []string + RootPath string + LanguageID string + ClientName string +} + +func NewClient(parent context.Context, cfg Config) (*Client, error) { + if cfg.Command == "" { + cfg.Command = "gopls" } - if rootPath == "" { + if cfg.RootPath == "" { cwd, err := os.Getwd() if err != nil { return nil, fmt.Errorf("get working directory: %w", err) } - rootPath = cwd + cfg.RootPath = cwd + } + if cfg.LanguageID == "" { + cfg.LanguageID = "go" + } + if cfg.ClientName == "" { + cfg.ClientName = "monica-lsp-gateway" } - cmd := exec.Command(goplsPath) + cmd := exec.Command(cfg.Command, cfg.Args...) stdin, err := cmd.StdinPipe() if err != nil { return nil, fmt.Errorf("create stdin pipe: %w", err) @@ -98,13 +114,15 @@ func NewClient(parent context.Context, goplsPath, rootPath string) (*Client, err cmd.Stderr = os.Stderr if err := cmd.Start(); err != nil { - return nil, fmt.Errorf("start gopls: %w", err) + return nil, fmt.Errorf("start language server %q: %w", cfg.Command, err) } client := &Client{ cmd: cmd, stdin: stdin, - workspaceDir: rootPath, + workspaceDir: cfg.RootPath, + languageID: cfg.LanguageID, + clientName: cfg.ClientName, pending: make(map[string]chan rpcResponse), exitCh: make(chan error, 1), } @@ -117,7 +135,7 @@ func NewClient(parent context.Context, goplsPath, rootPath string) (*Client, err initCtx, cancel := context.WithTimeout(parent, 10*time.Second) defer cancel() - if err := client.initialize(initCtx, rootPath); err != nil { + if err := client.initialize(initCtx, cfg.RootPath); err != nil { _ = client.Close() return nil, err } @@ -150,7 +168,7 @@ func (c *Client) initialize(ctx context.Context, rootPath string) error { }, }, "clientInfo": map[string]string{ - "name": "monica-go-completion-backend", + "name": c.clientName, "version": "0.1.0", }, } @@ -174,7 +192,7 @@ func (c *Client) DidOpen(ctx context.Context, uri, text string, version int) err params := map[string]any{ "textDocument": map[string]any{ "uri": normalizedURI, - "languageId": "go", + "languageId": c.languageID, "version": version, "text": text, }, diff --git a/src/api/completion.ts b/src/api/completion.ts index 393276d..c496173 100644 --- a/src/api/completion.ts +++ b/src/api/completion.ts @@ -1,20 +1,23 @@ import type { CompletionRequest, CompletionResponse } from '../types/completion' -const COMPLETION_API_URL = +const COMPLETION_API_BASE_URL = import.meta.env.VITE_COMPLETION_API_URL ?? - 'http://127.0.0.1:8080/api/v1/completions/go' + 'http://127.0.0.1:8080/api/v1/completions' const COMPLETION_WS_URL = import.meta.env.VITE_COMPLETION_WS_URL ?? - 'ws://127.0.0.1:8080/ws/completions/go' + 'ws://127.0.0.1:8080/ws/completions' const WS_TIMEOUT_MS = 1800 interface WSCompletionResponse extends CompletionResponse { id: string error?: string + routeTo?: string + ownerId?: string } interface PendingRequest { + request: CompletionRequest resolve: (value: CompletionResponse) => void timer: number } @@ -41,14 +44,35 @@ export async function fetchCompletions( async function fetchCompletionsByHTTP( request: CompletionRequest, + overrideBaseURL?: string, + rerouteDepth = 0, ): Promise { try { - const response = await fetch(COMPLETION_API_URL, { + const language = encodeURIComponent(request.language) + const baseURL = overrideBaseURL || COMPLETION_API_BASE_URL + const response = await fetch(`${baseURL}/${language}`, { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify(request), }) + if (response.status === 409) { + let rerouteURL = '' + try { + const body = (await response.json()) as { routeTo?: string } + rerouteURL = body.routeTo ?? '' + } catch { + rerouteURL = '' + } + + if (rerouteURL && rerouteDepth < 1) { + const normalizedBase = rerouteURL.endsWith('/') + ? `${rerouteURL}api/v1/completions` + : `${rerouteURL}/api/v1/completions` + return await fetchCompletionsByHTTP(request, normalizedBase, rerouteDepth + 1) + } + } + if (!response.ok) { console.error('Completion API error:', response.status) return { items: [], isIncomplete: false } @@ -73,7 +97,7 @@ async function fetchCompletionsByWS( resolve(await fetchCompletionsByHTTP(request)) }, WS_TIMEOUT_MS) - pending.set(id, { resolve, timer }) + pending.set(id, { request, resolve, timer }) try { socket.send(JSON.stringify({ id, ...request })) @@ -123,6 +147,13 @@ async function getWebSocket(): Promise { pending.delete(payload.id) if (payload.error) { + if (payload.routeTo) { + const routeBase = payload.routeTo.endsWith('/') + ? `${payload.routeTo}api/v1/completions` + : `${payload.routeTo}/api/v1/completions` + void fetchCompletionsByHTTP(entry.request, routeBase).then(entry.resolve) + return + } entry.resolve({ items: [], isIncomplete: false }) return } diff --git a/src/components/MonacoEditor.vue b/src/components/MonacoEditor.vue index 585bdbf..f10e6d3 100644 --- a/src/components/MonacoEditor.vue +++ b/src/components/MonacoEditor.vue @@ -24,6 +24,9 @@ const emit = defineEmits<{ const editorContainer = ref() const editor = shallowRef() let completionDisposable: monaco.IDisposable | null = null +const completionSessionID = `monaco-${Date.now().toString(36)}` + +const lspEnabledLanguages = new Set(['go', 'javascript', 'typescript']) /** Map LSP CompletionItemKind to Monaco CompletionItemKind */ function mapKind(kind?: number): monaco.languages.CompletionItemKind { @@ -82,7 +85,13 @@ function mapKind(kind?: number): monaco.languages.CompletionItemKind { } function getDocumentURI(language: string): string { - const name = `main.${language}` + const extensionByLanguage: Record = { + go: 'go', + javascript: 'js', + typescript: 'ts', + } + const extension = extensionByLanguage[language] ?? language + const name = `main.${extension}` return `file:///${name}` } @@ -91,8 +100,7 @@ function registerCompletionProvider(language: string) { completionDisposable?.dispose() completionDisposable = null - // Currently only Go completion is wired to backend gopls. - if (language !== 'go') { + if (!lspEnabledLanguages.has(language)) { return } @@ -100,7 +108,7 @@ function registerCompletionProvider(language: string) { triggerCharacters: ['.', ':', '('], async provideCompletionItems(model, position) { - if (model.getLanguageId() !== 'go') { + if (!lspEnabledLanguages.has(model.getLanguageId())) { return { suggestions: [] } } @@ -108,6 +116,8 @@ function registerCompletionProvider(language: string) { const word = model.getWordUntilPosition(position) const response = await fetchCompletions({ + language, + sessionId: completionSessionID, uri: getDocumentURI(language), text: code, line: Math.max(position.lineNumber - 1, 0), diff --git a/src/types/completion.ts b/src/types/completion.ts index ad85271..31d8bef 100644 --- a/src/types/completion.ts +++ b/src/types/completion.ts @@ -1,6 +1,10 @@ /** Types for the code completion API */ export interface CompletionRequest { + /** Language key (go/javascript/typescript) */ + language: string + /** Logical client session id */ + sessionId?: string /** Document URI (file://...) */ uri: string /** The full text content of the editor */