feat: enhance completion service with session management and language support

- Introduced session management using Redis for tracking active sessions.
- Added session claiming and releasing functionality in the completion manager.
- Enhanced HTTP and WebSocket completion endpoints to support multiple languages.
- Implemented request timeout and maximum body size configurations for API routes.
- Updated client-side code to handle session IDs and language parameters in completion requests.
- Improved error handling for unsupported languages and session conflicts.
- Added tests for the completion manager to ensure proper session handling and cleanup.
This commit is contained in:
2026-02-15 16:22:01 +08:00
parent 23decb8687
commit 57afb90bc0
14 changed files with 1334 additions and 138 deletions

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"net/http"
"time"
"github.com/gin-gonic/gin"
@@ -14,30 +15,100 @@ type CompletionService interface {
Complete(ctx context.Context, req completion.Request) (completion.Response, error)
}
func RegisterRoutes(router *gin.Engine, service CompletionService) {
type SessionStatsProvider interface {
ActiveSessions() map[string]int
}
type RouteOptions struct {
RequestTimeout time.Duration
MaxBodyBytes int64
}
func RegisterRoutes(router *gin.Engine, service CompletionService, options ...RouteOptions) {
opts := RouteOptions{
RequestTimeout: 10 * time.Second,
MaxBodyBytes: 2 << 20, // 2MB
}
if len(options) > 0 {
if options[0].RequestTimeout > 0 {
opts.RequestTimeout = options[0].RequestTimeout
}
if options[0].MaxBodyBytes > 0 {
opts.MaxBodyBytes = options[0].MaxBodyBytes
}
}
router.GET("/health", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
c.JSON(http.StatusOK, gin.H{"status": "ok", "service": "lsp-gateway"})
})
registerWSRoutes(router, service)
router.GET("/health/live", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "alive"})
})
router.POST("/api/v1/completions/go", func(c *gin.Context) {
router.GET("/health/ready", func(c *gin.Context) {
sessions := map[string]int{}
if provider, ok := service.(SessionStatsProvider); ok {
sessions = provider.ActiveSessions()
}
c.JSON(http.StatusOK, gin.H{
"status": "ready",
"sessions": sessions,
})
})
registerWSRoutes(router, service, opts)
handleCompletion := func(c *gin.Context) {
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, opts.MaxBodyBytes)
var req completion.Request
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid JSON payload"})
return
}
resp, err := service.Complete(c.Request.Context(), req)
routeLang := c.Param("language")
if req.Language == "" {
req.Language = routeLang
}
ctx, cancel := context.WithTimeout(c.Request.Context(), opts.RequestTimeout)
defer cancel()
resp, err := service.Complete(ctx, req)
if err != nil {
if errors.Is(err, completion.ErrInvalidRequest) {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if errors.Is(err, completion.ErrUnsupportedLanguage) {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if errors.Is(err, completion.ErrTooManySessions) {
c.JSON(http.StatusTooManyRequests, gin.H{"error": err.Error()})
return
}
var ownedErr *completion.ErrSessionOwnedByOtherInstance
if errors.As(err, &ownedErr) {
if ownedErr.OwnerEndpoint != "" {
c.Header("X-LSP-Route-To", ownedErr.OwnerEndpoint)
}
c.JSON(http.StatusConflict, gin.H{
"error": err.Error(),
"routeTo": ownedErr.OwnerEndpoint,
"ownerId": ownedErr.OwnerID,
})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "completion failed"})
return
}
c.JSON(http.StatusOK, resp)
}
router.POST("/api/v1/completions/:language", func(c *gin.Context) {
handleCompletion(c)
})
}

View File

@@ -6,7 +6,6 @@ import (
"errors"
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
@@ -24,6 +23,8 @@ var wsUpgrader = websocket.Upgrader{
type wsCompletionRequest struct {
ID string `json:"id"`
Language string `json:"language,omitempty"`
SessionID string `json:"sessionId,omitempty"`
URI string `json:"uri"`
Text string `json:"text"`
Line int `json:"line"`
@@ -34,17 +35,35 @@ type wsCompletionResponse struct {
ID string `json:"id"`
Items []completion.Item `json:"items,omitempty"`
IsIncomplete bool `json:"isIncomplete,omitempty"`
RouteTo string `json:"routeTo,omitempty"`
OwnerID string `json:"ownerId,omitempty"`
Error string `json:"error,omitempty"`
}
func registerWSRoutes(router *gin.Engine, service CompletionService) {
router.GET("/ws/completions/go", func(c *gin.Context) {
type wsRPCRequest struct {
JSONRPC string `json:"jsonrpc"`
ID json.RawMessage `json:"id"`
Method string `json:"method"`
Params json.RawMessage `json:"params"`
}
type wsRPCResponse struct {
JSONRPC string `json:"jsonrpc"`
ID json.RawMessage `json:"id"`
Result any `json:"result,omitempty"`
Error any `json:"error,omitempty"`
}
func registerWSRoutes(router *gin.Engine, service CompletionService, opts RouteOptions) {
handler := func(c *gin.Context, defaultLanguage string) {
conn, err := wsUpgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
return
}
defer conn.Close()
conn.SetReadLimit(opts.MaxBodyBytes)
var writeMu sync.Mutex
for {
@@ -52,45 +71,173 @@ func registerWSRoutes(router *gin.Engine, service CompletionService) {
if err != nil {
break
}
var req wsCompletionRequest
if err := json.Unmarshal(payload, &req); err != nil {
sendWSResponse(conn, &writeMu, wsCompletionResponse{
ID: "",
Error: "invalid JSON payload",
})
continue
}
go func(r wsCompletionRequest) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
resp, err := service.Complete(ctx, completion.Request{
URI: r.URI,
Text: r.Text,
Line: r.Line,
Character: r.Character,
})
if err != nil {
msg := "completion failed"
if errors.Is(err, completion.ErrInvalidRequest) {
msg = err.Error()
}
sendWSResponse(conn, &writeMu, wsCompletionResponse{
ID: r.ID,
Error: msg,
})
return
}
sendWSResponse(conn, &writeMu, wsCompletionResponse{
ID: r.ID,
Items: resp.Items,
IsIncomplete: resp.IsIncomplete,
})
}(req)
handleWSMessage(conn, &writeMu, service, payload, defaultLanguage, opts)
}
}
router.GET("/ws/completions", func(c *gin.Context) {
handler(c, "")
})
router.GET("/ws/completions/:language", func(c *gin.Context) {
handler(c, c.Param("language"))
})
}
func handleWSMessage(
conn *websocket.Conn,
writeMu *sync.Mutex,
service CompletionService,
payload []byte,
defaultLanguage string,
opts RouteOptions,
) {
if tryHandleRPCMessage(conn, writeMu, service, payload, defaultLanguage, opts) {
return
}
var req wsCompletionRequest
if err := json.Unmarshal(payload, &req); err != nil {
sendWSResponse(conn, writeMu, wsCompletionResponse{
ID: "",
Error: "invalid JSON payload",
})
return
}
if req.Language == "" {
req.Language = defaultLanguage
}
processWSCompletion(conn, writeMu, service, req, opts)
}
func tryHandleRPCMessage(
conn *websocket.Conn,
writeMu *sync.Mutex,
service CompletionService,
payload []byte,
defaultLanguage string,
opts RouteOptions,
) bool {
var rpcReq wsRPCRequest
if err := json.Unmarshal(payload, &rpcReq); err != nil {
return false
}
if rpcReq.JSONRPC != "2.0" || rpcReq.Method == "" {
return false
}
if rpcReq.Method != "completion/complete" && rpcReq.Method != "completion.complete" {
sendWSRPCResponse(conn, writeMu, wsRPCResponse{
JSONRPC: "2.0",
ID: rpcReq.ID,
Error: map[string]any{
"code": -32601,
"message": "method not found",
},
})
return true
}
var req wsCompletionRequest
if err := json.Unmarshal(rpcReq.Params, &req); err != nil {
sendWSRPCResponse(conn, writeMu, wsRPCResponse{
JSONRPC: "2.0",
ID: rpcReq.ID,
Error: map[string]any{
"code": -32602,
"message": "invalid params",
},
})
return true
}
if req.Language == "" {
req.Language = defaultLanguage
}
if req.ID == "" {
req.ID = string(rpcReq.ID)
}
ctx, cancel := context.WithTimeout(context.Background(), opts.RequestTimeout)
defer cancel()
resp, err := service.Complete(ctx, completion.Request{
Language: req.Language,
SessionID: req.SessionID,
URI: req.URI,
Text: req.Text,
Line: req.Line,
Character: req.Character,
})
if err != nil {
sendWSRPCResponse(conn, writeMu, wsRPCResponse{
JSONRPC: "2.0",
ID: rpcReq.ID,
Error: map[string]any{
"code": -32000,
"message": err.Error(),
},
})
return true
}
sendWSRPCResponse(conn, writeMu, wsRPCResponse{
JSONRPC: "2.0",
ID: rpcReq.ID,
Result: resp,
})
return true
}
func processWSCompletion(
conn *websocket.Conn,
writeMu *sync.Mutex,
service CompletionService,
req wsCompletionRequest,
opts RouteOptions,
) {
ctx, cancel := context.WithTimeout(context.Background(), opts.RequestTimeout)
defer cancel()
resp, err := service.Complete(ctx, completion.Request{
Language: req.Language,
SessionID: req.SessionID,
URI: req.URI,
Text: req.Text,
Line: req.Line,
Character: req.Character,
})
if err != nil {
msg := "completion failed"
routeTo := ""
ownerID := ""
switch {
case errors.Is(err, completion.ErrInvalidRequest):
msg = err.Error()
case errors.Is(err, completion.ErrUnsupportedLanguage):
msg = err.Error()
case errors.Is(err, completion.ErrTooManySessions):
msg = err.Error()
default:
var ownedErr *completion.ErrSessionOwnedByOtherInstance
if errors.As(err, &ownedErr) {
msg = err.Error()
routeTo = ownedErr.OwnerEndpoint
ownerID = ownedErr.OwnerID
}
}
sendWSResponse(conn, writeMu, wsCompletionResponse{
ID: req.ID,
RouteTo: routeTo,
OwnerID: ownerID,
Error: msg,
})
return
}
sendWSResponse(conn, writeMu, wsCompletionResponse{
ID: req.ID,
Items: resp.Items,
IsIncomplete: resp.IsIncomplete,
})
}
@@ -99,3 +246,9 @@ func sendWSResponse(conn *websocket.Conn, writeMu *sync.Mutex, resp wsCompletion
defer writeMu.Unlock()
_ = conn.WriteJSON(resp)
}
func sendWSRPCResponse(conn *websocket.Conn, writeMu *sync.Mutex, resp wsRPCResponse) {
writeMu.Lock()
defer writeMu.Unlock()
_ = conn.WriteJSON(resp)
}

View File

@@ -0,0 +1,222 @@
package cluster
import (
"context"
"errors"
"fmt"
"strings"
"sync"
"time"
"github.com/redis/go-redis/v9"
)
type RedisRegistryConfig struct {
Addr string
Password string
DB int
KeyPrefix string
InstanceID string
InstanceEndpoint string
SessionTTL time.Duration
InstanceTTL time.Duration
HeartbeatInterval time.Duration
}
type RedisRegistry struct {
client *redis.Client
keyPrefix string
instanceID string
instanceEndpoint string
sessionTTL time.Duration
instanceTTL time.Duration
heartbeatInterval time.Duration
stopCh chan struct{}
stopOnce sync.Once
}
var claimSessionScript = redis.NewScript(`
local sessionKey = KEYS[1]
local owner = ARGV[1]
local now = ARGV[2]
local ttl = tonumber(ARGV[3])
local existing = redis.call('HGET', sessionKey, 'owner')
if (not existing) or existing == owner then
redis.call('HSET', sessionKey, 'owner', owner, 'updatedAt', now)
redis.call('PEXPIRE', sessionKey, ttl)
return owner
end
redis.call('PEXPIRE', sessionKey, ttl)
return existing
`)
var releaseSessionScript = redis.NewScript(`
local sessionKey = KEYS[1]
local owner = ARGV[1]
local existing = redis.call('HGET', sessionKey, 'owner')
if existing == owner then
redis.call('DEL', sessionKey)
return 1
end
return 0
`)
func NewRedisRegistry(ctx context.Context, cfg RedisRegistryConfig) (*RedisRegistry, error) {
if strings.TrimSpace(cfg.Addr) == "" {
return nil, errors.New("redis addr is required")
}
if strings.TrimSpace(cfg.KeyPrefix) == "" {
cfg.KeyPrefix = "lsp-gateway"
}
if cfg.SessionTTL <= 0 {
cfg.SessionTTL = 20 * time.Minute
}
if cfg.InstanceTTL <= 0 {
cfg.InstanceTTL = 30 * time.Second
}
if cfg.HeartbeatInterval <= 0 {
cfg.HeartbeatInterval = 10 * time.Second
}
if cfg.InstanceID == "" {
return nil, errors.New("instance id is required")
}
client := redis.NewClient(&redis.Options{
Addr: cfg.Addr,
Password: cfg.Password,
DB: cfg.DB,
})
if err := client.Ping(ctx).Err(); err != nil {
return nil, fmt.Errorf("redis ping failed: %w", err)
}
registry := &RedisRegistry{
client: client,
keyPrefix: cfg.KeyPrefix,
instanceID: cfg.InstanceID,
instanceEndpoint: cfg.InstanceEndpoint,
sessionTTL: cfg.SessionTTL,
instanceTTL: cfg.InstanceTTL,
heartbeatInterval: cfg.HeartbeatInterval,
stopCh: make(chan struct{}),
}
if err := registry.refreshInstance(ctx); err != nil {
_ = client.Close()
return nil, err
}
go registry.heartbeatLoop()
return registry, nil
}
func (r *RedisRegistry) ClaimSession(
ctx context.Context,
language string,
sessionID string,
) (ownerID string, ownerEndpoint string, err error) {
key := r.sessionKey(language, sessionID)
now := time.Now().UTC().Format(time.RFC3339Nano)
raw, err := claimSessionScript.Run(
ctx,
r.client,
[]string{key},
r.instanceID,
now,
r.sessionTTL.Milliseconds(),
).Result()
if err != nil {
return "", "", fmt.Errorf("claim session failed: %w", err)
}
owner := fmt.Sprint(raw)
if owner == "" {
return "", "", errors.New("claim session returned empty owner")
}
if owner == r.instanceID {
return owner, r.instanceEndpoint, nil
}
endpoint, err := r.resolveInstanceEndpoint(ctx, owner)
if err != nil {
return owner, "", err
}
return owner, endpoint, nil
}
func (r *RedisRegistry) ReleaseSession(ctx context.Context, language, sessionID string) error {
key := r.sessionKey(language, sessionID)
if _, err := releaseSessionScript.Run(ctx, r.client, []string{key}, r.instanceID).Result(); err != nil {
return fmt.Errorf("release session failed: %w", err)
}
return nil
}
func (r *RedisRegistry) Close() error {
r.stopOnce.Do(func() {
close(r.stopCh)
})
return r.client.Close()
}
func (r *RedisRegistry) heartbeatLoop() {
ticker := time.NewTicker(r.heartbeatInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
_ = r.refreshInstance(ctx)
cancel()
case <-r.stopCh:
return
}
}
}
func (r *RedisRegistry) refreshInstance(ctx context.Context) error {
key := r.instanceKey(r.instanceID)
now := time.Now().UTC().Format(time.RFC3339Nano)
if err := r.client.HSet(ctx, key, map[string]any{
"endpoint": r.instanceEndpoint,
"updatedAt": now,
}).Err(); err != nil {
return fmt.Errorf("refresh instance metadata failed: %w", err)
}
if err := r.client.Expire(ctx, key, r.instanceTTL).Err(); err != nil {
return fmt.Errorf("refresh instance ttl failed: %w", err)
}
return nil
}
func (r *RedisRegistry) resolveInstanceEndpoint(ctx context.Context, ownerID string) (string, error) {
key := r.instanceKey(ownerID)
endpoint, err := r.client.HGet(ctx, key, "endpoint").Result()
if err == redis.Nil {
return "", nil
}
if err != nil {
return "", fmt.Errorf("resolve instance endpoint failed: %w", err)
}
return strings.TrimSpace(endpoint), nil
}
func (r *RedisRegistry) sessionKey(language, sessionID string) string {
return fmt.Sprintf("%s:sessions:%s:%s", r.keyPrefix, normalizePart(language), normalizePart(sessionID))
}
func (r *RedisRegistry) instanceKey(instanceID string) string {
return fmt.Sprintf("%s:instances:%s", r.keyPrefix, normalizePart(instanceID))
}
func normalizePart(value string) string {
trimmed := strings.TrimSpace(value)
if trimmed == "" {
return "default"
}
return trimmed
}

View File

@@ -0,0 +1,314 @@
package completion
import (
"context"
"errors"
"fmt"
"slices"
"strings"
"sync"
"time"
)
var ErrUnsupportedLanguage = errors.New("unsupported language")
var ErrTooManySessions = errors.New("too many active lsp sessions")
type RuntimeClient interface {
Client
Close() error
}
type SessionRegistry interface {
ClaimSession(ctx context.Context, language, sessionID string) (ownerID string, ownerEndpoint string, err error)
ReleaseSession(ctx context.Context, language, sessionID string) error
Close() error
}
type LanguageServerSpec struct {
Language string
LanguageID string
Command string
Args []string
}
type ClientFactory func(ctx context.Context, spec LanguageServerSpec, workspaceDir string) (RuntimeClient, error)
type ManagerConfig struct {
WorkspaceDir string
MaxSessions int
SessionTTL time.Duration
CleanupInterval time.Duration
InstanceID string
Registry SessionRegistry
}
type Manager struct {
mu sync.Mutex
config ManagerConfig
specByLang map[string]LanguageServerSpec
sessions map[string]*managedSession
newClient ClientFactory
stopCh chan struct{}
stoppedOnce sync.Once
}
type managedSession struct {
key string
sessionID string
language string
service *Service
client RuntimeClient
lastUsed time.Time
createdAt time.Time
}
type ErrSessionOwnedByOtherInstance struct {
OwnerID string
OwnerEndpoint string
}
func (e *ErrSessionOwnedByOtherInstance) Error() string {
if e.OwnerEndpoint != "" {
return fmt.Sprintf("session owned by another instance: %s (%s)", e.OwnerID, e.OwnerEndpoint)
}
return fmt.Sprintf("session owned by another instance: %s", e.OwnerID)
}
func NewManager(config ManagerConfig, specs []LanguageServerSpec, factory ClientFactory) *Manager {
if config.MaxSessions <= 0 {
config.MaxSessions = 256
}
if config.SessionTTL <= 0 {
config.SessionTTL = 20 * time.Minute
}
if config.CleanupInterval <= 0 {
config.CleanupInterval = 2 * time.Minute
}
if strings.TrimSpace(config.InstanceID) == "" {
config.InstanceID = "instance-local"
}
specByLang := make(map[string]LanguageServerSpec)
for _, spec := range specs {
key := normalizeLanguage(spec.Language)
if key == "" {
continue
}
specByLang[key] = spec
}
m := &Manager{
config: config,
specByLang: specByLang,
sessions: make(map[string]*managedSession),
newClient: factory,
stopCh: make(chan struct{}),
}
go m.cleanupLoop()
return m
}
func (m *Manager) Complete(ctx context.Context, req Request) (Response, error) {
language := normalizeLanguage(req.Language)
if language == "" {
return Response{}, ErrInvalidRequest
}
spec, ok := m.specByLang[language]
if !ok {
return Response{}, fmt.Errorf("%w: %s", ErrUnsupportedLanguage, req.Language)
}
sessionKey := buildSessionKey(language, req.SessionID)
sessionID := normalizeSessionID(req.SessionID)
if m.config.Registry != nil {
ownerID, ownerEndpoint, err := m.config.Registry.ClaimSession(ctx, language, sessionID)
if err != nil {
return Response{}, err
}
if ownerID != m.config.InstanceID {
return Response{}, &ErrSessionOwnedByOtherInstance{
OwnerID: ownerID,
OwnerEndpoint: ownerEndpoint,
}
}
}
session, err := m.getOrCreateSession(ctx, sessionKey, sessionID, spec)
if err != nil {
return Response{}, err
}
resp, err := session.service.Complete(ctx, req)
if err != nil {
return Response{}, err
}
m.mu.Lock()
if current, ok := m.sessions[sessionKey]; ok {
current.lastUsed = time.Now()
}
m.mu.Unlock()
return resp, nil
}
func (m *Manager) ActiveSessions() map[string]int {
m.mu.Lock()
defer m.mu.Unlock()
out := make(map[string]int)
for _, session := range m.sessions {
out[session.language]++
}
return out
}
func (m *Manager) Close() error {
m.stoppedOnce.Do(func() {
close(m.stopCh)
m.mu.Lock()
defer m.mu.Unlock()
for key, session := range m.sessions {
if m.config.Registry != nil {
_ = m.config.Registry.ReleaseSession(context.Background(), session.language, session.sessionID)
}
_ = session.client.Close()
delete(m.sessions, key)
}
if m.config.Registry != nil {
_ = m.config.Registry.Close()
}
})
return nil
}
func (m *Manager) cleanupLoop() {
ticker := time.NewTicker(m.config.CleanupInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
m.cleanupIdleSessions()
case <-m.stopCh:
return
}
}
}
func (m *Manager) cleanupIdleSessions() {
cutoff := time.Now().Add(-m.config.SessionTTL)
m.mu.Lock()
defer m.mu.Unlock()
for key, session := range m.sessions {
if session.lastUsed.After(cutoff) {
continue
}
if m.config.Registry != nil {
_ = m.config.Registry.ReleaseSession(context.Background(), session.language, session.sessionID)
}
_ = session.client.Close()
delete(m.sessions, key)
}
}
func (m *Manager) getOrCreateSession(
ctx context.Context,
sessionKey string,
sessionID string,
spec LanguageServerSpec,
) (*managedSession, error) {
m.mu.Lock()
if existing, ok := m.sessions[sessionKey]; ok {
existing.lastUsed = time.Now()
m.mu.Unlock()
return existing, nil
}
if len(m.sessions) >= m.config.MaxSessions {
if !m.evictLeastRecentlyUsedLocked() {
m.mu.Unlock()
return nil, ErrTooManySessions
}
}
m.mu.Unlock()
client, err := m.newClient(ctx, spec, m.config.WorkspaceDir)
if err != nil {
return nil, err
}
now := time.Now()
newSession := &managedSession{
key: sessionKey,
sessionID: sessionID,
language: normalizeLanguage(spec.Language),
service: NewService(client),
client: client,
lastUsed: now,
createdAt: now,
}
m.mu.Lock()
defer m.mu.Unlock()
if existing, ok := m.sessions[sessionKey]; ok {
_ = client.Close()
existing.lastUsed = now
return existing, nil
}
m.sessions[sessionKey] = newSession
return newSession, nil
}
func (m *Manager) evictLeastRecentlyUsedLocked() bool {
if len(m.sessions) == 0 {
return false
}
keys := make([]string, 0, len(m.sessions))
for key := range m.sessions {
keys = append(keys, key)
}
slices.SortFunc(keys, func(a, b string) int {
as := m.sessions[a]
bs := m.sessions[b]
if as.lastUsed.Before(bs.lastUsed) {
return -1
}
if as.lastUsed.After(bs.lastUsed) {
return 1
}
return strings.Compare(as.key, bs.key)
})
victimKey := keys[0]
victim := m.sessions[victimKey]
if m.config.Registry != nil {
_ = m.config.Registry.ReleaseSession(context.Background(), victim.language, victim.sessionID)
}
_ = victim.client.Close()
delete(m.sessions, victimKey)
return true
}
func buildSessionKey(language, sessionID string) string {
return language + ":" + normalizeSessionID(sessionID)
}
func normalizeSessionID(sessionID string) string {
sid := strings.TrimSpace(sessionID)
if sid == "" {
return "default"
}
return sid
}
func normalizeLanguage(language string) string {
return strings.ToLower(strings.TrimSpace(language))
}

View File

@@ -0,0 +1,125 @@
package completion
import (
"context"
"errors"
"fmt"
"testing"
"time"
)
type fakeRuntimeClient struct {
closed bool
}
func (f *fakeRuntimeClient) DidOpen(_ context.Context, _ string, _ string, _ int) error {
return nil
}
func (f *fakeRuntimeClient) DidChange(_ context.Context, _ string, _ string, _ int) error {
return nil
}
func (f *fakeRuntimeClient) Completion(_ context.Context, _ string, _ int, _ int) (Response, error) {
return Response{
Items: []Item{{Label: "ok"}},
}, nil
}
func (f *fakeRuntimeClient) Close() error {
f.closed = true
return nil
}
func TestManagerCompleteSelectsLanguageAndSession(t *testing.T) {
createCount := 0
factory := func(_ context.Context, spec LanguageServerSpec, _ string) (RuntimeClient, error) {
createCount++
if spec.Language != "go" {
return nil, fmt.Errorf("unexpected language: %s", spec.Language)
}
return &fakeRuntimeClient{}, nil
}
m := NewManager(ManagerConfig{WorkspaceDir: "."}, []LanguageServerSpec{
{Language: "go", LanguageID: "go", Command: "gopls"},
}, factory)
defer m.Close()
for i := 0; i < 2; i++ {
resp, err := m.Complete(context.Background(), Request{
Language: "go",
SessionID: "s1",
URI: "file:///main.go",
Text: "package main",
Line: 0,
Character: 0,
})
if err != nil {
t.Fatalf("Complete() error = %v", err)
}
if len(resp.Items) != 1 || resp.Items[0].Label != "ok" {
t.Fatalf("unexpected response: %+v", resp)
}
}
if createCount != 1 {
t.Fatalf("expected one client creation, got %d", createCount)
}
}
func TestManagerCompleteUnsupportedLanguage(t *testing.T) {
m := NewManager(ManagerConfig{}, []LanguageServerSpec{
{Language: "go", LanguageID: "go", Command: "gopls"},
}, func(_ context.Context, _ LanguageServerSpec, _ string) (RuntimeClient, error) {
return &fakeRuntimeClient{}, nil
})
defer m.Close()
_, err := m.Complete(context.Background(), Request{
Language: "python",
URI: "file:///main.py",
Text: "print('hi')",
Line: 0,
Character: 0,
})
if !errors.Is(err, ErrUnsupportedLanguage) {
t.Fatalf("expected ErrUnsupportedLanguage, got %v", err)
}
}
func TestManagerCleanupIdleSession(t *testing.T) {
client := &fakeRuntimeClient{}
m := NewManager(ManagerConfig{
WorkspaceDir: ".",
SessionTTL: 30 * time.Millisecond,
CleanupInterval: 10 * time.Millisecond,
}, []LanguageServerSpec{
{Language: "go", LanguageID: "go", Command: "gopls"},
}, func(_ context.Context, _ LanguageServerSpec, _ string) (RuntimeClient, error) {
return client, nil
})
defer m.Close()
_, err := m.Complete(context.Background(), Request{
Language: "go",
SessionID: "s2",
URI: "file:///main.go",
Text: "package main",
Line: 0,
Character: 0,
})
if err != nil {
t.Fatalf("Complete() error = %v", err)
}
time.Sleep(90 * time.Millisecond)
sessions := m.ActiveSessions()
if sessions["go"] != 0 {
t.Fatalf("expected session cleanup, got %+v", sessions)
}
if !client.closed {
t.Fatal("expected client to be closed by cleanup")
}
}

View File

@@ -10,6 +10,8 @@ import (
var ErrInvalidRequest = errors.New("invalid completion request")
type Request struct {
Language string `json:"language,omitempty"`
SessionID string `json:"sessionId,omitempty"`
URI string `json:"uri"`
Text string `json:"text"`
Line int `json:"line"`

View File

@@ -33,6 +33,8 @@ type Client struct {
nextID atomic.Int64
workspaceDir string
languageID string
clientName string
pendingMu sync.Mutex
pending map[string]chan rpcResponse
@@ -74,19 +76,33 @@ type lspCompletionList struct {
Items []lspCompletionItem `json:"items"`
}
func NewClient(parent context.Context, goplsPath, rootPath string) (*Client, error) {
if goplsPath == "" {
goplsPath = "gopls"
type Config struct {
Command string
Args []string
RootPath string
LanguageID string
ClientName string
}
func NewClient(parent context.Context, cfg Config) (*Client, error) {
if cfg.Command == "" {
cfg.Command = "gopls"
}
if rootPath == "" {
if cfg.RootPath == "" {
cwd, err := os.Getwd()
if err != nil {
return nil, fmt.Errorf("get working directory: %w", err)
}
rootPath = cwd
cfg.RootPath = cwd
}
if cfg.LanguageID == "" {
cfg.LanguageID = "go"
}
if cfg.ClientName == "" {
cfg.ClientName = "monica-lsp-gateway"
}
cmd := exec.Command(goplsPath)
cmd := exec.Command(cfg.Command, cfg.Args...)
stdin, err := cmd.StdinPipe()
if err != nil {
return nil, fmt.Errorf("create stdin pipe: %w", err)
@@ -98,13 +114,15 @@ func NewClient(parent context.Context, goplsPath, rootPath string) (*Client, err
cmd.Stderr = os.Stderr
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("start gopls: %w", err)
return nil, fmt.Errorf("start language server %q: %w", cfg.Command, err)
}
client := &Client{
cmd: cmd,
stdin: stdin,
workspaceDir: rootPath,
workspaceDir: cfg.RootPath,
languageID: cfg.LanguageID,
clientName: cfg.ClientName,
pending: make(map[string]chan rpcResponse),
exitCh: make(chan error, 1),
}
@@ -117,7 +135,7 @@ func NewClient(parent context.Context, goplsPath, rootPath string) (*Client, err
initCtx, cancel := context.WithTimeout(parent, 10*time.Second)
defer cancel()
if err := client.initialize(initCtx, rootPath); err != nil {
if err := client.initialize(initCtx, cfg.RootPath); err != nil {
_ = client.Close()
return nil, err
}
@@ -150,7 +168,7 @@ func (c *Client) initialize(ctx context.Context, rootPath string) error {
},
},
"clientInfo": map[string]string{
"name": "monica-go-completion-backend",
"name": c.clientName,
"version": "0.1.0",
},
}
@@ -174,7 +192,7 @@ func (c *Client) DidOpen(ctx context.Context, uri, text string, version int) err
params := map[string]any{
"textDocument": map[string]any{
"uri": normalizedURI,
"languageId": "go",
"languageId": c.languageID,
"version": version,
"text": text,
},