feat: enhance completion service with session management and language support
- Introduced session management using Redis for tracking active sessions. - Added session claiming and releasing functionality in the completion manager. - Enhanced HTTP and WebSocket completion endpoints to support multiple languages. - Implemented request timeout and maximum body size configurations for API routes. - Updated client-side code to handle session IDs and language parameters in completion requests. - Improved error handling for unsupported languages and session conflicts. - Added tests for the completion manager to ensure proper session handling and cleanup.
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
@@ -14,30 +15,100 @@ type CompletionService interface {
|
||||
Complete(ctx context.Context, req completion.Request) (completion.Response, error)
|
||||
}
|
||||
|
||||
func RegisterRoutes(router *gin.Engine, service CompletionService) {
|
||||
type SessionStatsProvider interface {
|
||||
ActiveSessions() map[string]int
|
||||
}
|
||||
|
||||
type RouteOptions struct {
|
||||
RequestTimeout time.Duration
|
||||
MaxBodyBytes int64
|
||||
}
|
||||
|
||||
func RegisterRoutes(router *gin.Engine, service CompletionService, options ...RouteOptions) {
|
||||
opts := RouteOptions{
|
||||
RequestTimeout: 10 * time.Second,
|
||||
MaxBodyBytes: 2 << 20, // 2MB
|
||||
}
|
||||
if len(options) > 0 {
|
||||
if options[0].RequestTimeout > 0 {
|
||||
opts.RequestTimeout = options[0].RequestTimeout
|
||||
}
|
||||
if options[0].MaxBodyBytes > 0 {
|
||||
opts.MaxBodyBytes = options[0].MaxBodyBytes
|
||||
}
|
||||
}
|
||||
|
||||
router.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "service": "lsp-gateway"})
|
||||
})
|
||||
|
||||
registerWSRoutes(router, service)
|
||||
router.GET("/health/live", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "alive"})
|
||||
})
|
||||
|
||||
router.POST("/api/v1/completions/go", func(c *gin.Context) {
|
||||
router.GET("/health/ready", func(c *gin.Context) {
|
||||
sessions := map[string]int{}
|
||||
if provider, ok := service.(SessionStatsProvider); ok {
|
||||
sessions = provider.ActiveSessions()
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "ready",
|
||||
"sessions": sessions,
|
||||
})
|
||||
})
|
||||
|
||||
registerWSRoutes(router, service, opts)
|
||||
|
||||
handleCompletion := func(c *gin.Context) {
|
||||
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, opts.MaxBodyBytes)
|
||||
var req completion.Request
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid JSON payload"})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := service.Complete(c.Request.Context(), req)
|
||||
routeLang := c.Param("language")
|
||||
if req.Language == "" {
|
||||
req.Language = routeLang
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), opts.RequestTimeout)
|
||||
defer cancel()
|
||||
|
||||
resp, err := service.Complete(ctx, req)
|
||||
if err != nil {
|
||||
if errors.Is(err, completion.ErrInvalidRequest) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, completion.ErrUnsupportedLanguage) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, completion.ErrTooManySessions) {
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
var ownedErr *completion.ErrSessionOwnedByOtherInstance
|
||||
if errors.As(err, &ownedErr) {
|
||||
if ownedErr.OwnerEndpoint != "" {
|
||||
c.Header("X-LSP-Route-To", ownedErr.OwnerEndpoint)
|
||||
}
|
||||
c.JSON(http.StatusConflict, gin.H{
|
||||
"error": err.Error(),
|
||||
"routeTo": ownedErr.OwnerEndpoint,
|
||||
"ownerId": ownedErr.OwnerID,
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "completion failed"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
router.POST("/api/v1/completions/:language", func(c *gin.Context) {
|
||||
handleCompletion(c)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
@@ -24,6 +23,8 @@ var wsUpgrader = websocket.Upgrader{
|
||||
|
||||
type wsCompletionRequest struct {
|
||||
ID string `json:"id"`
|
||||
Language string `json:"language,omitempty"`
|
||||
SessionID string `json:"sessionId,omitempty"`
|
||||
URI string `json:"uri"`
|
||||
Text string `json:"text"`
|
||||
Line int `json:"line"`
|
||||
@@ -34,17 +35,35 @@ type wsCompletionResponse struct {
|
||||
ID string `json:"id"`
|
||||
Items []completion.Item `json:"items,omitempty"`
|
||||
IsIncomplete bool `json:"isIncomplete,omitempty"`
|
||||
RouteTo string `json:"routeTo,omitempty"`
|
||||
OwnerID string `json:"ownerId,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func registerWSRoutes(router *gin.Engine, service CompletionService) {
|
||||
router.GET("/ws/completions/go", func(c *gin.Context) {
|
||||
type wsRPCRequest struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID json.RawMessage `json:"id"`
|
||||
Method string `json:"method"`
|
||||
Params json.RawMessage `json:"params"`
|
||||
}
|
||||
|
||||
type wsRPCResponse struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID json.RawMessage `json:"id"`
|
||||
Result any `json:"result,omitempty"`
|
||||
Error any `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func registerWSRoutes(router *gin.Engine, service CompletionService, opts RouteOptions) {
|
||||
handler := func(c *gin.Context, defaultLanguage string) {
|
||||
conn, err := wsUpgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
conn.SetReadLimit(opts.MaxBodyBytes)
|
||||
|
||||
var writeMu sync.Mutex
|
||||
|
||||
for {
|
||||
@@ -52,45 +71,173 @@ func registerWSRoutes(router *gin.Engine, service CompletionService) {
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
var req wsCompletionRequest
|
||||
if err := json.Unmarshal(payload, &req); err != nil {
|
||||
sendWSResponse(conn, &writeMu, wsCompletionResponse{
|
||||
ID: "",
|
||||
Error: "invalid JSON payload",
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
go func(r wsCompletionRequest) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := service.Complete(ctx, completion.Request{
|
||||
URI: r.URI,
|
||||
Text: r.Text,
|
||||
Line: r.Line,
|
||||
Character: r.Character,
|
||||
})
|
||||
if err != nil {
|
||||
msg := "completion failed"
|
||||
if errors.Is(err, completion.ErrInvalidRequest) {
|
||||
msg = err.Error()
|
||||
}
|
||||
sendWSResponse(conn, &writeMu, wsCompletionResponse{
|
||||
ID: r.ID,
|
||||
Error: msg,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
sendWSResponse(conn, &writeMu, wsCompletionResponse{
|
||||
ID: r.ID,
|
||||
Items: resp.Items,
|
||||
IsIncomplete: resp.IsIncomplete,
|
||||
})
|
||||
}(req)
|
||||
handleWSMessage(conn, &writeMu, service, payload, defaultLanguage, opts)
|
||||
}
|
||||
}
|
||||
|
||||
router.GET("/ws/completions", func(c *gin.Context) {
|
||||
handler(c, "")
|
||||
})
|
||||
router.GET("/ws/completions/:language", func(c *gin.Context) {
|
||||
handler(c, c.Param("language"))
|
||||
})
|
||||
}
|
||||
|
||||
func handleWSMessage(
|
||||
conn *websocket.Conn,
|
||||
writeMu *sync.Mutex,
|
||||
service CompletionService,
|
||||
payload []byte,
|
||||
defaultLanguage string,
|
||||
opts RouteOptions,
|
||||
) {
|
||||
if tryHandleRPCMessage(conn, writeMu, service, payload, defaultLanguage, opts) {
|
||||
return
|
||||
}
|
||||
|
||||
var req wsCompletionRequest
|
||||
if err := json.Unmarshal(payload, &req); err != nil {
|
||||
sendWSResponse(conn, writeMu, wsCompletionResponse{
|
||||
ID: "",
|
||||
Error: "invalid JSON payload",
|
||||
})
|
||||
return
|
||||
}
|
||||
if req.Language == "" {
|
||||
req.Language = defaultLanguage
|
||||
}
|
||||
|
||||
processWSCompletion(conn, writeMu, service, req, opts)
|
||||
}
|
||||
|
||||
func tryHandleRPCMessage(
|
||||
conn *websocket.Conn,
|
||||
writeMu *sync.Mutex,
|
||||
service CompletionService,
|
||||
payload []byte,
|
||||
defaultLanguage string,
|
||||
opts RouteOptions,
|
||||
) bool {
|
||||
var rpcReq wsRPCRequest
|
||||
if err := json.Unmarshal(payload, &rpcReq); err != nil {
|
||||
return false
|
||||
}
|
||||
if rpcReq.JSONRPC != "2.0" || rpcReq.Method == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
if rpcReq.Method != "completion/complete" && rpcReq.Method != "completion.complete" {
|
||||
sendWSRPCResponse(conn, writeMu, wsRPCResponse{
|
||||
JSONRPC: "2.0",
|
||||
ID: rpcReq.ID,
|
||||
Error: map[string]any{
|
||||
"code": -32601,
|
||||
"message": "method not found",
|
||||
},
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
var req wsCompletionRequest
|
||||
if err := json.Unmarshal(rpcReq.Params, &req); err != nil {
|
||||
sendWSRPCResponse(conn, writeMu, wsRPCResponse{
|
||||
JSONRPC: "2.0",
|
||||
ID: rpcReq.ID,
|
||||
Error: map[string]any{
|
||||
"code": -32602,
|
||||
"message": "invalid params",
|
||||
},
|
||||
})
|
||||
return true
|
||||
}
|
||||
if req.Language == "" {
|
||||
req.Language = defaultLanguage
|
||||
}
|
||||
if req.ID == "" {
|
||||
req.ID = string(rpcReq.ID)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opts.RequestTimeout)
|
||||
defer cancel()
|
||||
|
||||
resp, err := service.Complete(ctx, completion.Request{
|
||||
Language: req.Language,
|
||||
SessionID: req.SessionID,
|
||||
URI: req.URI,
|
||||
Text: req.Text,
|
||||
Line: req.Line,
|
||||
Character: req.Character,
|
||||
})
|
||||
if err != nil {
|
||||
sendWSRPCResponse(conn, writeMu, wsRPCResponse{
|
||||
JSONRPC: "2.0",
|
||||
ID: rpcReq.ID,
|
||||
Error: map[string]any{
|
||||
"code": -32000,
|
||||
"message": err.Error(),
|
||||
},
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
sendWSRPCResponse(conn, writeMu, wsRPCResponse{
|
||||
JSONRPC: "2.0",
|
||||
ID: rpcReq.ID,
|
||||
Result: resp,
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
func processWSCompletion(
|
||||
conn *websocket.Conn,
|
||||
writeMu *sync.Mutex,
|
||||
service CompletionService,
|
||||
req wsCompletionRequest,
|
||||
opts RouteOptions,
|
||||
) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opts.RequestTimeout)
|
||||
defer cancel()
|
||||
|
||||
resp, err := service.Complete(ctx, completion.Request{
|
||||
Language: req.Language,
|
||||
SessionID: req.SessionID,
|
||||
URI: req.URI,
|
||||
Text: req.Text,
|
||||
Line: req.Line,
|
||||
Character: req.Character,
|
||||
})
|
||||
if err != nil {
|
||||
msg := "completion failed"
|
||||
routeTo := ""
|
||||
ownerID := ""
|
||||
switch {
|
||||
case errors.Is(err, completion.ErrInvalidRequest):
|
||||
msg = err.Error()
|
||||
case errors.Is(err, completion.ErrUnsupportedLanguage):
|
||||
msg = err.Error()
|
||||
case errors.Is(err, completion.ErrTooManySessions):
|
||||
msg = err.Error()
|
||||
default:
|
||||
var ownedErr *completion.ErrSessionOwnedByOtherInstance
|
||||
if errors.As(err, &ownedErr) {
|
||||
msg = err.Error()
|
||||
routeTo = ownedErr.OwnerEndpoint
|
||||
ownerID = ownedErr.OwnerID
|
||||
}
|
||||
}
|
||||
sendWSResponse(conn, writeMu, wsCompletionResponse{
|
||||
ID: req.ID,
|
||||
RouteTo: routeTo,
|
||||
OwnerID: ownerID,
|
||||
Error: msg,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
sendWSResponse(conn, writeMu, wsCompletionResponse{
|
||||
ID: req.ID,
|
||||
Items: resp.Items,
|
||||
IsIncomplete: resp.IsIncomplete,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -99,3 +246,9 @@ func sendWSResponse(conn *websocket.Conn, writeMu *sync.Mutex, resp wsCompletion
|
||||
defer writeMu.Unlock()
|
||||
_ = conn.WriteJSON(resp)
|
||||
}
|
||||
|
||||
func sendWSRPCResponse(conn *websocket.Conn, writeMu *sync.Mutex, resp wsRPCResponse) {
|
||||
writeMu.Lock()
|
||||
defer writeMu.Unlock()
|
||||
_ = conn.WriteJSON(resp)
|
||||
}
|
||||
|
||||
222
backend/internal/cluster/redis_registry.go
Normal file
222
backend/internal/cluster/redis_registry.go
Normal file
@@ -0,0 +1,222 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
type RedisRegistryConfig struct {
|
||||
Addr string
|
||||
Password string
|
||||
DB int
|
||||
KeyPrefix string
|
||||
InstanceID string
|
||||
InstanceEndpoint string
|
||||
SessionTTL time.Duration
|
||||
InstanceTTL time.Duration
|
||||
HeartbeatInterval time.Duration
|
||||
}
|
||||
|
||||
type RedisRegistry struct {
|
||||
client *redis.Client
|
||||
|
||||
keyPrefix string
|
||||
instanceID string
|
||||
instanceEndpoint string
|
||||
sessionTTL time.Duration
|
||||
instanceTTL time.Duration
|
||||
heartbeatInterval time.Duration
|
||||
|
||||
stopCh chan struct{}
|
||||
stopOnce sync.Once
|
||||
}
|
||||
|
||||
var claimSessionScript = redis.NewScript(`
|
||||
local sessionKey = KEYS[1]
|
||||
local owner = ARGV[1]
|
||||
local now = ARGV[2]
|
||||
local ttl = tonumber(ARGV[3])
|
||||
|
||||
local existing = redis.call('HGET', sessionKey, 'owner')
|
||||
if (not existing) or existing == owner then
|
||||
redis.call('HSET', sessionKey, 'owner', owner, 'updatedAt', now)
|
||||
redis.call('PEXPIRE', sessionKey, ttl)
|
||||
return owner
|
||||
end
|
||||
|
||||
redis.call('PEXPIRE', sessionKey, ttl)
|
||||
return existing
|
||||
`)
|
||||
|
||||
var releaseSessionScript = redis.NewScript(`
|
||||
local sessionKey = KEYS[1]
|
||||
local owner = ARGV[1]
|
||||
local existing = redis.call('HGET', sessionKey, 'owner')
|
||||
if existing == owner then
|
||||
redis.call('DEL', sessionKey)
|
||||
return 1
|
||||
end
|
||||
return 0
|
||||
`)
|
||||
|
||||
func NewRedisRegistry(ctx context.Context, cfg RedisRegistryConfig) (*RedisRegistry, error) {
|
||||
if strings.TrimSpace(cfg.Addr) == "" {
|
||||
return nil, errors.New("redis addr is required")
|
||||
}
|
||||
if strings.TrimSpace(cfg.KeyPrefix) == "" {
|
||||
cfg.KeyPrefix = "lsp-gateway"
|
||||
}
|
||||
if cfg.SessionTTL <= 0 {
|
||||
cfg.SessionTTL = 20 * time.Minute
|
||||
}
|
||||
if cfg.InstanceTTL <= 0 {
|
||||
cfg.InstanceTTL = 30 * time.Second
|
||||
}
|
||||
if cfg.HeartbeatInterval <= 0 {
|
||||
cfg.HeartbeatInterval = 10 * time.Second
|
||||
}
|
||||
if cfg.InstanceID == "" {
|
||||
return nil, errors.New("instance id is required")
|
||||
}
|
||||
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: cfg.Addr,
|
||||
Password: cfg.Password,
|
||||
DB: cfg.DB,
|
||||
})
|
||||
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
return nil, fmt.Errorf("redis ping failed: %w", err)
|
||||
}
|
||||
|
||||
registry := &RedisRegistry{
|
||||
client: client,
|
||||
keyPrefix: cfg.KeyPrefix,
|
||||
instanceID: cfg.InstanceID,
|
||||
instanceEndpoint: cfg.InstanceEndpoint,
|
||||
sessionTTL: cfg.SessionTTL,
|
||||
instanceTTL: cfg.InstanceTTL,
|
||||
heartbeatInterval: cfg.HeartbeatInterval,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
if err := registry.refreshInstance(ctx); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, err
|
||||
}
|
||||
go registry.heartbeatLoop()
|
||||
return registry, nil
|
||||
}
|
||||
|
||||
func (r *RedisRegistry) ClaimSession(
|
||||
ctx context.Context,
|
||||
language string,
|
||||
sessionID string,
|
||||
) (ownerID string, ownerEndpoint string, err error) {
|
||||
key := r.sessionKey(language, sessionID)
|
||||
now := time.Now().UTC().Format(time.RFC3339Nano)
|
||||
raw, err := claimSessionScript.Run(
|
||||
ctx,
|
||||
r.client,
|
||||
[]string{key},
|
||||
r.instanceID,
|
||||
now,
|
||||
r.sessionTTL.Milliseconds(),
|
||||
).Result()
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("claim session failed: %w", err)
|
||||
}
|
||||
|
||||
owner := fmt.Sprint(raw)
|
||||
if owner == "" {
|
||||
return "", "", errors.New("claim session returned empty owner")
|
||||
}
|
||||
if owner == r.instanceID {
|
||||
return owner, r.instanceEndpoint, nil
|
||||
}
|
||||
|
||||
endpoint, err := r.resolveInstanceEndpoint(ctx, owner)
|
||||
if err != nil {
|
||||
return owner, "", err
|
||||
}
|
||||
return owner, endpoint, nil
|
||||
}
|
||||
|
||||
func (r *RedisRegistry) ReleaseSession(ctx context.Context, language, sessionID string) error {
|
||||
key := r.sessionKey(language, sessionID)
|
||||
if _, err := releaseSessionScript.Run(ctx, r.client, []string{key}, r.instanceID).Result(); err != nil {
|
||||
return fmt.Errorf("release session failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RedisRegistry) Close() error {
|
||||
r.stopOnce.Do(func() {
|
||||
close(r.stopCh)
|
||||
})
|
||||
return r.client.Close()
|
||||
}
|
||||
|
||||
func (r *RedisRegistry) heartbeatLoop() {
|
||||
ticker := time.NewTicker(r.heartbeatInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
_ = r.refreshInstance(ctx)
|
||||
cancel()
|
||||
case <-r.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RedisRegistry) refreshInstance(ctx context.Context) error {
|
||||
key := r.instanceKey(r.instanceID)
|
||||
now := time.Now().UTC().Format(time.RFC3339Nano)
|
||||
if err := r.client.HSet(ctx, key, map[string]any{
|
||||
"endpoint": r.instanceEndpoint,
|
||||
"updatedAt": now,
|
||||
}).Err(); err != nil {
|
||||
return fmt.Errorf("refresh instance metadata failed: %w", err)
|
||||
}
|
||||
if err := r.client.Expire(ctx, key, r.instanceTTL).Err(); err != nil {
|
||||
return fmt.Errorf("refresh instance ttl failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RedisRegistry) resolveInstanceEndpoint(ctx context.Context, ownerID string) (string, error) {
|
||||
key := r.instanceKey(ownerID)
|
||||
endpoint, err := r.client.HGet(ctx, key, "endpoint").Result()
|
||||
if err == redis.Nil {
|
||||
return "", nil
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("resolve instance endpoint failed: %w", err)
|
||||
}
|
||||
return strings.TrimSpace(endpoint), nil
|
||||
}
|
||||
|
||||
func (r *RedisRegistry) sessionKey(language, sessionID string) string {
|
||||
return fmt.Sprintf("%s:sessions:%s:%s", r.keyPrefix, normalizePart(language), normalizePart(sessionID))
|
||||
}
|
||||
|
||||
func (r *RedisRegistry) instanceKey(instanceID string) string {
|
||||
return fmt.Sprintf("%s:instances:%s", r.keyPrefix, normalizePart(instanceID))
|
||||
}
|
||||
|
||||
func normalizePart(value string) string {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
return "default"
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
314
backend/internal/completion/manager.go
Normal file
314
backend/internal/completion/manager.go
Normal file
@@ -0,0 +1,314 @@
|
||||
package completion
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ErrUnsupportedLanguage = errors.New("unsupported language")
|
||||
var ErrTooManySessions = errors.New("too many active lsp sessions")
|
||||
|
||||
type RuntimeClient interface {
|
||||
Client
|
||||
Close() error
|
||||
}
|
||||
|
||||
type SessionRegistry interface {
|
||||
ClaimSession(ctx context.Context, language, sessionID string) (ownerID string, ownerEndpoint string, err error)
|
||||
ReleaseSession(ctx context.Context, language, sessionID string) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
type LanguageServerSpec struct {
|
||||
Language string
|
||||
LanguageID string
|
||||
Command string
|
||||
Args []string
|
||||
}
|
||||
|
||||
type ClientFactory func(ctx context.Context, spec LanguageServerSpec, workspaceDir string) (RuntimeClient, error)
|
||||
|
||||
type ManagerConfig struct {
|
||||
WorkspaceDir string
|
||||
MaxSessions int
|
||||
SessionTTL time.Duration
|
||||
CleanupInterval time.Duration
|
||||
InstanceID string
|
||||
Registry SessionRegistry
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
mu sync.Mutex
|
||||
|
||||
config ManagerConfig
|
||||
specByLang map[string]LanguageServerSpec
|
||||
sessions map[string]*managedSession
|
||||
newClient ClientFactory
|
||||
stopCh chan struct{}
|
||||
stoppedOnce sync.Once
|
||||
}
|
||||
|
||||
type managedSession struct {
|
||||
key string
|
||||
sessionID string
|
||||
language string
|
||||
service *Service
|
||||
client RuntimeClient
|
||||
lastUsed time.Time
|
||||
createdAt time.Time
|
||||
}
|
||||
|
||||
type ErrSessionOwnedByOtherInstance struct {
|
||||
OwnerID string
|
||||
OwnerEndpoint string
|
||||
}
|
||||
|
||||
func (e *ErrSessionOwnedByOtherInstance) Error() string {
|
||||
if e.OwnerEndpoint != "" {
|
||||
return fmt.Sprintf("session owned by another instance: %s (%s)", e.OwnerID, e.OwnerEndpoint)
|
||||
}
|
||||
return fmt.Sprintf("session owned by another instance: %s", e.OwnerID)
|
||||
}
|
||||
|
||||
func NewManager(config ManagerConfig, specs []LanguageServerSpec, factory ClientFactory) *Manager {
|
||||
if config.MaxSessions <= 0 {
|
||||
config.MaxSessions = 256
|
||||
}
|
||||
if config.SessionTTL <= 0 {
|
||||
config.SessionTTL = 20 * time.Minute
|
||||
}
|
||||
if config.CleanupInterval <= 0 {
|
||||
config.CleanupInterval = 2 * time.Minute
|
||||
}
|
||||
if strings.TrimSpace(config.InstanceID) == "" {
|
||||
config.InstanceID = "instance-local"
|
||||
}
|
||||
|
||||
specByLang := make(map[string]LanguageServerSpec)
|
||||
for _, spec := range specs {
|
||||
key := normalizeLanguage(spec.Language)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
specByLang[key] = spec
|
||||
}
|
||||
|
||||
m := &Manager{
|
||||
config: config,
|
||||
specByLang: specByLang,
|
||||
sessions: make(map[string]*managedSession),
|
||||
newClient: factory,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
go m.cleanupLoop()
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *Manager) Complete(ctx context.Context, req Request) (Response, error) {
|
||||
language := normalizeLanguage(req.Language)
|
||||
if language == "" {
|
||||
return Response{}, ErrInvalidRequest
|
||||
}
|
||||
|
||||
spec, ok := m.specByLang[language]
|
||||
if !ok {
|
||||
return Response{}, fmt.Errorf("%w: %s", ErrUnsupportedLanguage, req.Language)
|
||||
}
|
||||
|
||||
sessionKey := buildSessionKey(language, req.SessionID)
|
||||
sessionID := normalizeSessionID(req.SessionID)
|
||||
|
||||
if m.config.Registry != nil {
|
||||
ownerID, ownerEndpoint, err := m.config.Registry.ClaimSession(ctx, language, sessionID)
|
||||
if err != nil {
|
||||
return Response{}, err
|
||||
}
|
||||
if ownerID != m.config.InstanceID {
|
||||
return Response{}, &ErrSessionOwnedByOtherInstance{
|
||||
OwnerID: ownerID,
|
||||
OwnerEndpoint: ownerEndpoint,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
session, err := m.getOrCreateSession(ctx, sessionKey, sessionID, spec)
|
||||
if err != nil {
|
||||
return Response{}, err
|
||||
}
|
||||
|
||||
resp, err := session.service.Complete(ctx, req)
|
||||
if err != nil {
|
||||
return Response{}, err
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
if current, ok := m.sessions[sessionKey]; ok {
|
||||
current.lastUsed = time.Now()
|
||||
}
|
||||
m.mu.Unlock()
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (m *Manager) ActiveSessions() map[string]int {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
out := make(map[string]int)
|
||||
for _, session := range m.sessions {
|
||||
out[session.language]++
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (m *Manager) Close() error {
|
||||
m.stoppedOnce.Do(func() {
|
||||
close(m.stopCh)
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for key, session := range m.sessions {
|
||||
if m.config.Registry != nil {
|
||||
_ = m.config.Registry.ReleaseSession(context.Background(), session.language, session.sessionID)
|
||||
}
|
||||
_ = session.client.Close()
|
||||
delete(m.sessions, key)
|
||||
}
|
||||
if m.config.Registry != nil {
|
||||
_ = m.config.Registry.Close()
|
||||
}
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) cleanupLoop() {
|
||||
ticker := time.NewTicker(m.config.CleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
m.cleanupIdleSessions()
|
||||
case <-m.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) cleanupIdleSessions() {
|
||||
cutoff := time.Now().Add(-m.config.SessionTTL)
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for key, session := range m.sessions {
|
||||
if session.lastUsed.After(cutoff) {
|
||||
continue
|
||||
}
|
||||
if m.config.Registry != nil {
|
||||
_ = m.config.Registry.ReleaseSession(context.Background(), session.language, session.sessionID)
|
||||
}
|
||||
_ = session.client.Close()
|
||||
delete(m.sessions, key)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) getOrCreateSession(
|
||||
ctx context.Context,
|
||||
sessionKey string,
|
||||
sessionID string,
|
||||
spec LanguageServerSpec,
|
||||
) (*managedSession, error) {
|
||||
m.mu.Lock()
|
||||
if existing, ok := m.sessions[sessionKey]; ok {
|
||||
existing.lastUsed = time.Now()
|
||||
m.mu.Unlock()
|
||||
return existing, nil
|
||||
}
|
||||
|
||||
if len(m.sessions) >= m.config.MaxSessions {
|
||||
if !m.evictLeastRecentlyUsedLocked() {
|
||||
m.mu.Unlock()
|
||||
return nil, ErrTooManySessions
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
client, err := m.newClient(ctx, spec, m.config.WorkspaceDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
newSession := &managedSession{
|
||||
key: sessionKey,
|
||||
sessionID: sessionID,
|
||||
language: normalizeLanguage(spec.Language),
|
||||
service: NewService(client),
|
||||
client: client,
|
||||
lastUsed: now,
|
||||
createdAt: now,
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if existing, ok := m.sessions[sessionKey]; ok {
|
||||
_ = client.Close()
|
||||
existing.lastUsed = now
|
||||
return existing, nil
|
||||
}
|
||||
m.sessions[sessionKey] = newSession
|
||||
return newSession, nil
|
||||
}
|
||||
|
||||
func (m *Manager) evictLeastRecentlyUsedLocked() bool {
|
||||
if len(m.sessions) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(m.sessions))
|
||||
for key := range m.sessions {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
slices.SortFunc(keys, func(a, b string) int {
|
||||
as := m.sessions[a]
|
||||
bs := m.sessions[b]
|
||||
if as.lastUsed.Before(bs.lastUsed) {
|
||||
return -1
|
||||
}
|
||||
if as.lastUsed.After(bs.lastUsed) {
|
||||
return 1
|
||||
}
|
||||
return strings.Compare(as.key, bs.key)
|
||||
})
|
||||
|
||||
victimKey := keys[0]
|
||||
victim := m.sessions[victimKey]
|
||||
if m.config.Registry != nil {
|
||||
_ = m.config.Registry.ReleaseSession(context.Background(), victim.language, victim.sessionID)
|
||||
}
|
||||
_ = victim.client.Close()
|
||||
delete(m.sessions, victimKey)
|
||||
return true
|
||||
}
|
||||
|
||||
func buildSessionKey(language, sessionID string) string {
|
||||
return language + ":" + normalizeSessionID(sessionID)
|
||||
}
|
||||
|
||||
func normalizeSessionID(sessionID string) string {
|
||||
sid := strings.TrimSpace(sessionID)
|
||||
if sid == "" {
|
||||
return "default"
|
||||
}
|
||||
return sid
|
||||
}
|
||||
|
||||
func normalizeLanguage(language string) string {
|
||||
return strings.ToLower(strings.TrimSpace(language))
|
||||
}
|
||||
125
backend/internal/completion/manager_test.go
Normal file
125
backend/internal/completion/manager_test.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package completion
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type fakeRuntimeClient struct {
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (f *fakeRuntimeClient) DidOpen(_ context.Context, _ string, _ string, _ int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeRuntimeClient) DidChange(_ context.Context, _ string, _ string, _ int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeRuntimeClient) Completion(_ context.Context, _ string, _ int, _ int) (Response, error) {
|
||||
return Response{
|
||||
Items: []Item{{Label: "ok"}},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f *fakeRuntimeClient) Close() error {
|
||||
f.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestManagerCompleteSelectsLanguageAndSession(t *testing.T) {
|
||||
createCount := 0
|
||||
factory := func(_ context.Context, spec LanguageServerSpec, _ string) (RuntimeClient, error) {
|
||||
createCount++
|
||||
if spec.Language != "go" {
|
||||
return nil, fmt.Errorf("unexpected language: %s", spec.Language)
|
||||
}
|
||||
return &fakeRuntimeClient{}, nil
|
||||
}
|
||||
|
||||
m := NewManager(ManagerConfig{WorkspaceDir: "."}, []LanguageServerSpec{
|
||||
{Language: "go", LanguageID: "go", Command: "gopls"},
|
||||
}, factory)
|
||||
defer m.Close()
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
resp, err := m.Complete(context.Background(), Request{
|
||||
Language: "go",
|
||||
SessionID: "s1",
|
||||
URI: "file:///main.go",
|
||||
Text: "package main",
|
||||
Line: 0,
|
||||
Character: 0,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Complete() error = %v", err)
|
||||
}
|
||||
if len(resp.Items) != 1 || resp.Items[0].Label != "ok" {
|
||||
t.Fatalf("unexpected response: %+v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
if createCount != 1 {
|
||||
t.Fatalf("expected one client creation, got %d", createCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerCompleteUnsupportedLanguage(t *testing.T) {
|
||||
m := NewManager(ManagerConfig{}, []LanguageServerSpec{
|
||||
{Language: "go", LanguageID: "go", Command: "gopls"},
|
||||
}, func(_ context.Context, _ LanguageServerSpec, _ string) (RuntimeClient, error) {
|
||||
return &fakeRuntimeClient{}, nil
|
||||
})
|
||||
defer m.Close()
|
||||
|
||||
_, err := m.Complete(context.Background(), Request{
|
||||
Language: "python",
|
||||
URI: "file:///main.py",
|
||||
Text: "print('hi')",
|
||||
Line: 0,
|
||||
Character: 0,
|
||||
})
|
||||
if !errors.Is(err, ErrUnsupportedLanguage) {
|
||||
t.Fatalf("expected ErrUnsupportedLanguage, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerCleanupIdleSession(t *testing.T) {
|
||||
client := &fakeRuntimeClient{}
|
||||
m := NewManager(ManagerConfig{
|
||||
WorkspaceDir: ".",
|
||||
SessionTTL: 30 * time.Millisecond,
|
||||
CleanupInterval: 10 * time.Millisecond,
|
||||
}, []LanguageServerSpec{
|
||||
{Language: "go", LanguageID: "go", Command: "gopls"},
|
||||
}, func(_ context.Context, _ LanguageServerSpec, _ string) (RuntimeClient, error) {
|
||||
return client, nil
|
||||
})
|
||||
defer m.Close()
|
||||
|
||||
_, err := m.Complete(context.Background(), Request{
|
||||
Language: "go",
|
||||
SessionID: "s2",
|
||||
URI: "file:///main.go",
|
||||
Text: "package main",
|
||||
Line: 0,
|
||||
Character: 0,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Complete() error = %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(90 * time.Millisecond)
|
||||
|
||||
sessions := m.ActiveSessions()
|
||||
if sessions["go"] != 0 {
|
||||
t.Fatalf("expected session cleanup, got %+v", sessions)
|
||||
}
|
||||
if !client.closed {
|
||||
t.Fatal("expected client to be closed by cleanup")
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
var ErrInvalidRequest = errors.New("invalid completion request")
|
||||
|
||||
type Request struct {
|
||||
Language string `json:"language,omitempty"`
|
||||
SessionID string `json:"sessionId,omitempty"`
|
||||
URI string `json:"uri"`
|
||||
Text string `json:"text"`
|
||||
Line int `json:"line"`
|
||||
|
||||
@@ -33,6 +33,8 @@ type Client struct {
|
||||
nextID atomic.Int64
|
||||
|
||||
workspaceDir string
|
||||
languageID string
|
||||
clientName string
|
||||
|
||||
pendingMu sync.Mutex
|
||||
pending map[string]chan rpcResponse
|
||||
@@ -74,19 +76,33 @@ type lspCompletionList struct {
|
||||
Items []lspCompletionItem `json:"items"`
|
||||
}
|
||||
|
||||
func NewClient(parent context.Context, goplsPath, rootPath string) (*Client, error) {
|
||||
if goplsPath == "" {
|
||||
goplsPath = "gopls"
|
||||
type Config struct {
|
||||
Command string
|
||||
Args []string
|
||||
RootPath string
|
||||
LanguageID string
|
||||
ClientName string
|
||||
}
|
||||
|
||||
func NewClient(parent context.Context, cfg Config) (*Client, error) {
|
||||
if cfg.Command == "" {
|
||||
cfg.Command = "gopls"
|
||||
}
|
||||
if rootPath == "" {
|
||||
if cfg.RootPath == "" {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get working directory: %w", err)
|
||||
}
|
||||
rootPath = cwd
|
||||
cfg.RootPath = cwd
|
||||
}
|
||||
if cfg.LanguageID == "" {
|
||||
cfg.LanguageID = "go"
|
||||
}
|
||||
if cfg.ClientName == "" {
|
||||
cfg.ClientName = "monica-lsp-gateway"
|
||||
}
|
||||
|
||||
cmd := exec.Command(goplsPath)
|
||||
cmd := exec.Command(cfg.Command, cfg.Args...)
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create stdin pipe: %w", err)
|
||||
@@ -98,13 +114,15 @@ func NewClient(parent context.Context, goplsPath, rootPath string) (*Client, err
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("start gopls: %w", err)
|
||||
return nil, fmt.Errorf("start language server %q: %w", cfg.Command, err)
|
||||
}
|
||||
|
||||
client := &Client{
|
||||
cmd: cmd,
|
||||
stdin: stdin,
|
||||
workspaceDir: rootPath,
|
||||
workspaceDir: cfg.RootPath,
|
||||
languageID: cfg.LanguageID,
|
||||
clientName: cfg.ClientName,
|
||||
pending: make(map[string]chan rpcResponse),
|
||||
exitCh: make(chan error, 1),
|
||||
}
|
||||
@@ -117,7 +135,7 @@ func NewClient(parent context.Context, goplsPath, rootPath string) (*Client, err
|
||||
initCtx, cancel := context.WithTimeout(parent, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := client.initialize(initCtx, rootPath); err != nil {
|
||||
if err := client.initialize(initCtx, cfg.RootPath); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, err
|
||||
}
|
||||
@@ -150,7 +168,7 @@ func (c *Client) initialize(ctx context.Context, rootPath string) error {
|
||||
},
|
||||
},
|
||||
"clientInfo": map[string]string{
|
||||
"name": "monica-go-completion-backend",
|
||||
"name": c.clientName,
|
||||
"version": "0.1.0",
|
||||
},
|
||||
}
|
||||
@@ -174,7 +192,7 @@ func (c *Client) DidOpen(ctx context.Context, uri, text string, version int) err
|
||||
params := map[string]any{
|
||||
"textDocument": map[string]any{
|
||||
"uri": normalizedURI,
|
||||
"languageId": "go",
|
||||
"languageId": c.languageID,
|
||||
"version": version,
|
||||
"text": text,
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user