all
This commit is contained in:
43
backend/internal/api/handler.go
Normal file
43
backend/internal/api/handler.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"monica-go-completion-backend/internal/completion"
|
||||
)
|
||||
|
||||
type CompletionService interface {
|
||||
Complete(ctx context.Context, req completion.Request) (completion.Response, error)
|
||||
}
|
||||
|
||||
func RegisterRoutes(router *gin.Engine, service CompletionService) {
|
||||
router.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
registerWSRoutes(router, service)
|
||||
|
||||
router.POST("/api/v1/completions/go", func(c *gin.Context) {
|
||||
var req completion.Request
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid JSON payload"})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := service.Complete(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
if errors.Is(err, completion.ErrInvalidRequest) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "completion failed"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
})
|
||||
}
|
||||
171
backend/internal/api/handler_test.go
Normal file
171
backend/internal/api/handler_test.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"monica-go-completion-backend/internal/completion"
|
||||
)
|
||||
|
||||
type fakeCompletionService struct {
|
||||
resp completion.Response
|
||||
err error
|
||||
}
|
||||
|
||||
func (f *fakeCompletionService) Complete(_ context.Context, _ completion.Request) (completion.Response, error) {
|
||||
if f.err != nil {
|
||||
return completion.Response{}, f.err
|
||||
}
|
||||
return f.resp, nil
|
||||
}
|
||||
|
||||
func TestRegisterRoutesCompletionSuccess(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
RegisterRoutes(r, &fakeCompletionService{
|
||||
resp: completion.Response{Items: []completion.Item{{Label: "Println"}}, IsIncomplete: true},
|
||||
})
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"uri": "file:///main.go",
|
||||
"text": "package main",
|
||||
"line": 0,
|
||||
"character": 0,
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/completions/go", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var got struct {
|
||||
Items []completion.Item `json:"items"`
|
||||
IsIncomplete bool `json:"isIncomplete"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
if len(got.Items) != 1 || got.Items[0].Label != "Println" || !got.IsIncomplete {
|
||||
t.Fatalf("unexpected response body: %s", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterRoutesCompletionBadJSON(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
RegisterRoutes(r, &fakeCompletionService{})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/completions/go", bytes.NewReader([]byte("not-json")))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterRoutesCompletionValidationError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
rv := &fakeCompletionService{err: completion.ErrInvalidRequest}
|
||||
RegisterRoutes(r, rv)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"uri": "",
|
||||
"text": "",
|
||||
"line": -1,
|
||||
"character": 0,
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/completions/go", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterRoutesCompletionServerError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
RegisterRoutes(r, &fakeCompletionService{err: errors.New("boom")})
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"uri": "file:///main.go",
|
||||
"text": "package main",
|
||||
"line": 0,
|
||||
"character": 0,
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/completions/go", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("expected status 500, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterRoutesCompletionWebSocketSuccess(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
RegisterRoutes(r, &fakeCompletionService{
|
||||
resp: completion.Response{Items: []completion.Item{{Label: "Println"}}, IsIncomplete: false},
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(r)
|
||||
defer srv.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws/completions/go"
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("websocket dial failed: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
req := map[string]any{
|
||||
"id": "1",
|
||||
"uri": "file:///main.go",
|
||||
"text": "package main",
|
||||
"line": 0,
|
||||
"character": 0,
|
||||
}
|
||||
if err := conn.WriteJSON(req); err != nil {
|
||||
t.Fatalf("WriteJSON() failed: %v", err)
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
ID string `json:"id"`
|
||||
Items []completion.Item `json:"items"`
|
||||
IsIncomplete bool `json:"isIncomplete"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
if err := conn.ReadJSON(&resp); err != nil {
|
||||
t.Fatalf("ReadJSON() failed: %v", err)
|
||||
}
|
||||
|
||||
if resp.Error != "" {
|
||||
t.Fatalf("unexpected websocket error: %s", resp.Error)
|
||||
}
|
||||
if resp.ID != "1" {
|
||||
t.Fatalf("unexpected id: %s", resp.ID)
|
||||
}
|
||||
if len(resp.Items) != 1 || resp.Items[0].Label != "Println" {
|
||||
t.Fatalf("unexpected items: %+v", resp.Items)
|
||||
}
|
||||
}
|
||||
101
backend/internal/api/ws_handler.go
Normal file
101
backend/internal/api/ws_handler.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"monica-go-completion-backend/internal/completion"
|
||||
)
|
||||
|
||||
var wsUpgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 4096,
|
||||
WriteBufferSize: 4096,
|
||||
CheckOrigin: func(_ *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
type wsCompletionRequest struct {
|
||||
ID string `json:"id"`
|
||||
URI string `json:"uri"`
|
||||
Text string `json:"text"`
|
||||
Line int `json:"line"`
|
||||
Character int `json:"character"`
|
||||
}
|
||||
|
||||
type wsCompletionResponse struct {
|
||||
ID string `json:"id"`
|
||||
Items []completion.Item `json:"items,omitempty"`
|
||||
IsIncomplete bool `json:"isIncomplete,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func registerWSRoutes(router *gin.Engine, service CompletionService) {
|
||||
router.GET("/ws/completions/go", func(c *gin.Context) {
|
||||
conn, err := wsUpgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
var writeMu sync.Mutex
|
||||
|
||||
for {
|
||||
_, payload, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
var req wsCompletionRequest
|
||||
if err := json.Unmarshal(payload, &req); err != nil {
|
||||
sendWSResponse(conn, &writeMu, wsCompletionResponse{
|
||||
ID: "",
|
||||
Error: "invalid JSON payload",
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
go func(r wsCompletionRequest) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := service.Complete(ctx, completion.Request{
|
||||
URI: r.URI,
|
||||
Text: r.Text,
|
||||
Line: r.Line,
|
||||
Character: r.Character,
|
||||
})
|
||||
if err != nil {
|
||||
msg := "completion failed"
|
||||
if errors.Is(err, completion.ErrInvalidRequest) {
|
||||
msg = err.Error()
|
||||
}
|
||||
sendWSResponse(conn, &writeMu, wsCompletionResponse{
|
||||
ID: r.ID,
|
||||
Error: msg,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
sendWSResponse(conn, &writeMu, wsCompletionResponse{
|
||||
ID: r.ID,
|
||||
Items: resp.Items,
|
||||
IsIncomplete: resp.IsIncomplete,
|
||||
})
|
||||
}(req)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func sendWSResponse(conn *websocket.Conn, writeMu *sync.Mutex, resp wsCompletionResponse) {
|
||||
writeMu.Lock()
|
||||
defer writeMu.Unlock()
|
||||
_ = conn.WriteJSON(resp)
|
||||
}
|
||||
95
backend/internal/completion/service.go
Normal file
95
backend/internal/completion/service.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package completion
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var ErrInvalidRequest = errors.New("invalid completion request")
|
||||
|
||||
type Request struct {
|
||||
URI string `json:"uri"`
|
||||
Text string `json:"text"`
|
||||
Line int `json:"line"`
|
||||
Character int `json:"character"`
|
||||
}
|
||||
|
||||
type Item struct {
|
||||
Label string `json:"label"`
|
||||
Kind int `json:"kind,omitempty"`
|
||||
Detail string `json:"detail,omitempty"`
|
||||
Documentation string `json:"documentation,omitempty"`
|
||||
InsertText string `json:"insertText,omitempty"`
|
||||
SortText string `json:"sortText,omitempty"`
|
||||
FilterText string `json:"filterText,omitempty"`
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
Items []Item `json:"items"`
|
||||
IsIncomplete bool `json:"isIncomplete"`
|
||||
}
|
||||
|
||||
type Client interface {
|
||||
DidOpen(ctx context.Context, uri, text string, version int) error
|
||||
DidChange(ctx context.Context, uri, text string, version int) error
|
||||
Completion(ctx context.Context, uri string, line, character int) (Response, error)
|
||||
}
|
||||
|
||||
type documentState struct {
|
||||
version int
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
client Client
|
||||
|
||||
mu sync.Mutex
|
||||
documents map[string]*documentState
|
||||
}
|
||||
|
||||
func NewService(client Client) *Service {
|
||||
return &Service{
|
||||
client: client,
|
||||
documents: make(map[string]*documentState),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Complete(ctx context.Context, req Request) (Response, error) {
|
||||
if err := validateRequest(req); err != nil {
|
||||
return Response{}, err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
state, opened := s.documents[req.URI]
|
||||
if !opened {
|
||||
if err := s.client.DidOpen(ctx, req.URI, req.Text, 1); err != nil {
|
||||
return Response{}, fmt.Errorf("didOpen failed: %w", err)
|
||||
}
|
||||
s.documents[req.URI] = &documentState{version: 1}
|
||||
} else {
|
||||
nextVersion := state.version + 1
|
||||
if err := s.client.DidChange(ctx, req.URI, req.Text, nextVersion); err != nil {
|
||||
return Response{}, fmt.Errorf("didChange failed: %w", err)
|
||||
}
|
||||
state.version = nextVersion
|
||||
}
|
||||
|
||||
resp, err := s.client.Completion(ctx, req.URI, req.Line, req.Character)
|
||||
if err != nil {
|
||||
return Response{}, fmt.Errorf("completion failed: %w", err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func validateRequest(req Request) error {
|
||||
if req.URI == "" {
|
||||
return ErrInvalidRequest
|
||||
}
|
||||
if req.Line < 0 || req.Character < 0 {
|
||||
return ErrInvalidRequest
|
||||
}
|
||||
return nil
|
||||
}
|
||||
144
backend/internal/completion/service_test.go
Normal file
144
backend/internal/completion/service_test.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package completion
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type fakeLSPClient struct {
|
||||
openCalls []openCall
|
||||
changeCalls []changeCall
|
||||
completionCalls []completionCall
|
||||
completionResp Response
|
||||
openErr error
|
||||
changeErr error
|
||||
completionErr error
|
||||
}
|
||||
|
||||
type openCall struct {
|
||||
uri string
|
||||
text string
|
||||
version int
|
||||
}
|
||||
|
||||
type changeCall struct {
|
||||
uri string
|
||||
text string
|
||||
version int
|
||||
}
|
||||
|
||||
type completionCall struct {
|
||||
uri string
|
||||
line int
|
||||
character int
|
||||
}
|
||||
|
||||
func (f *fakeLSPClient) DidOpen(_ context.Context, uri, text string, version int) error {
|
||||
f.openCalls = append(f.openCalls, openCall{uri: uri, text: text, version: version})
|
||||
return f.openErr
|
||||
}
|
||||
|
||||
func (f *fakeLSPClient) DidChange(_ context.Context, uri, text string, version int) error {
|
||||
f.changeCalls = append(f.changeCalls, changeCall{uri: uri, text: text, version: version})
|
||||
return f.changeErr
|
||||
}
|
||||
|
||||
func (f *fakeLSPClient) Completion(_ context.Context, uri string, line, character int) (Response, error) {
|
||||
f.completionCalls = append(f.completionCalls, completionCall{uri: uri, line: line, character: character})
|
||||
if f.completionErr != nil {
|
||||
return Response{}, f.completionErr
|
||||
}
|
||||
return f.completionResp, nil
|
||||
}
|
||||
|
||||
func TestServiceCompleteFirstRequestSendsDidOpen(t *testing.T) {
|
||||
fake := &fakeLSPClient{
|
||||
completionResp: Response{Items: []Item{{Label: "Println"}}, IsIncomplete: true},
|
||||
}
|
||||
svc := NewService(fake)
|
||||
|
||||
resp, err := svc.Complete(context.Background(), Request{
|
||||
URI: "file:///main.go",
|
||||
Text: "package main\nfunc main() { fmt.Pr }",
|
||||
Line: 1,
|
||||
Character: 23,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Complete() error = %v", err)
|
||||
}
|
||||
if len(fake.openCalls) != 1 {
|
||||
t.Fatalf("expected 1 didOpen call, got %d", len(fake.openCalls))
|
||||
}
|
||||
if fake.openCalls[0].version != 1 {
|
||||
t.Fatalf("expected didOpen version=1, got %d", fake.openCalls[0].version)
|
||||
}
|
||||
if len(fake.changeCalls) != 0 {
|
||||
t.Fatalf("expected 0 didChange calls, got %d", len(fake.changeCalls))
|
||||
}
|
||||
if len(fake.completionCalls) != 1 {
|
||||
t.Fatalf("expected 1 completion call, got %d", len(fake.completionCalls))
|
||||
}
|
||||
if !resp.IsIncomplete || len(resp.Items) != 1 || resp.Items[0].Label != "Println" {
|
||||
t.Fatalf("unexpected response: %+v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceCompleteSecondRequestUsesDidChange(t *testing.T) {
|
||||
fake := &fakeLSPClient{}
|
||||
svc := NewService(fake)
|
||||
|
||||
_, err := svc.Complete(context.Background(), Request{
|
||||
URI: "file:///main.go",
|
||||
Text: "package main\nfunc main() {}",
|
||||
Line: 1,
|
||||
Character: 12,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("first Complete() error = %v", err)
|
||||
}
|
||||
|
||||
_, err = svc.Complete(context.Background(), Request{
|
||||
URI: "file:///main.go",
|
||||
Text: "package main\nfunc main() { fmt.Pr }",
|
||||
Line: 1,
|
||||
Character: 23,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("second Complete() error = %v", err)
|
||||
}
|
||||
|
||||
if len(fake.openCalls) != 1 {
|
||||
t.Fatalf("expected 1 didOpen call, got %d", len(fake.openCalls))
|
||||
}
|
||||
if len(fake.changeCalls) != 1 {
|
||||
t.Fatalf("expected 1 didChange call, got %d", len(fake.changeCalls))
|
||||
}
|
||||
if fake.changeCalls[0].version != 2 {
|
||||
t.Fatalf("expected didChange version=2, got %d", fake.changeCalls[0].version)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceCompleteValidatesRequest(t *testing.T) {
|
||||
svc := NewService(&fakeLSPClient{})
|
||||
|
||||
_, err := svc.Complete(context.Background(), Request{URI: "", Text: "", Line: -1, Character: 0})
|
||||
if !errors.Is(err, ErrInvalidRequest) {
|
||||
t.Fatalf("expected ErrInvalidRequest, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceCompleteReturnsClientError(t *testing.T) {
|
||||
fake := &fakeLSPClient{openErr: errors.New("open failed")}
|
||||
svc := NewService(fake)
|
||||
|
||||
_, err := svc.Complete(context.Background(), Request{
|
||||
URI: "file:///main.go",
|
||||
Text: "package main",
|
||||
Line: 0,
|
||||
Character: 0,
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
}
|
||||
552
backend/internal/lsp/client.go
Normal file
552
backend/internal/lsp/client.go
Normal file
@@ -0,0 +1,552 @@
|
||||
package lsp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"monica-go-completion-backend/internal/completion"
|
||||
)
|
||||
|
||||
var errClientClosed = errors.New("lsp client closed")
|
||||
|
||||
type Client struct {
|
||||
cmd *exec.Cmd
|
||||
stdin io.WriteCloser
|
||||
|
||||
writeMu sync.Mutex
|
||||
|
||||
nextID atomic.Int64
|
||||
|
||||
workspaceDir string
|
||||
|
||||
pendingMu sync.Mutex
|
||||
pending map[string]chan rpcResponse
|
||||
|
||||
exitCh chan error
|
||||
closeOnce sync.Once
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
type rpcResponse struct {
|
||||
result json.RawMessage
|
||||
err error
|
||||
}
|
||||
|
||||
type rpcError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
type incomingEnvelope struct {
|
||||
ID *json.RawMessage `json:"id,omitempty"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
Error *rpcError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type lspCompletionItem struct {
|
||||
Label string `json:"label"`
|
||||
Kind int `json:"kind,omitempty"`
|
||||
Detail string `json:"detail,omitempty"`
|
||||
InsertText string `json:"insertText,omitempty"`
|
||||
SortText string `json:"sortText,omitempty"`
|
||||
FilterText string `json:"filterText,omitempty"`
|
||||
Documentation json.RawMessage `json:"documentation,omitempty"`
|
||||
}
|
||||
|
||||
type lspCompletionList struct {
|
||||
IsIncomplete bool `json:"isIncomplete"`
|
||||
Items []lspCompletionItem `json:"items"`
|
||||
}
|
||||
|
||||
func NewClient(parent context.Context, goplsPath, rootPath string) (*Client, error) {
|
||||
if goplsPath == "" {
|
||||
goplsPath = "gopls"
|
||||
}
|
||||
if rootPath == "" {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get working directory: %w", err)
|
||||
}
|
||||
rootPath = cwd
|
||||
}
|
||||
|
||||
cmd := exec.Command(goplsPath)
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create stdin pipe: %w", err)
|
||||
}
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create stdout pipe: %w", err)
|
||||
}
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("start gopls: %w", err)
|
||||
}
|
||||
|
||||
client := &Client{
|
||||
cmd: cmd,
|
||||
stdin: stdin,
|
||||
workspaceDir: rootPath,
|
||||
pending: make(map[string]chan rpcResponse),
|
||||
exitCh: make(chan error, 1),
|
||||
}
|
||||
|
||||
go func() {
|
||||
client.exitCh <- cmd.Wait()
|
||||
}()
|
||||
go client.readLoop(stdout)
|
||||
|
||||
initCtx, cancel := context.WithTimeout(parent, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := client.initialize(initCtx, rootPath); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *Client) initialize(ctx context.Context, rootPath string) error {
|
||||
rootURI, err := pathToURI(rootPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("build root uri: %w", err)
|
||||
}
|
||||
|
||||
params := map[string]any{
|
||||
"processId": os.Getpid(),
|
||||
"rootUri": rootURI,
|
||||
"capabilities": map[string]any{
|
||||
"textDocument": map[string]any{
|
||||
"completion": map[string]any{
|
||||
"completionItem": map[string]any{
|
||||
"snippetSupport": true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"workspaceFolders": []map[string]string{
|
||||
{
|
||||
"uri": rootURI,
|
||||
"name": filepath.Base(rootPath),
|
||||
},
|
||||
},
|
||||
"clientInfo": map[string]string{
|
||||
"name": "monica-go-completion-backend",
|
||||
"version": "0.1.0",
|
||||
},
|
||||
}
|
||||
|
||||
var initResult json.RawMessage
|
||||
if err := c.request(ctx, "initialize", params, &initResult); err != nil {
|
||||
return fmt.Errorf("initialize request failed: %w", err)
|
||||
}
|
||||
if err := c.notify("initialized", map[string]any{}); err != nil {
|
||||
return fmt.Errorf("initialized notification failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) DidOpen(ctx context.Context, uri, text string, version int) error {
|
||||
normalizedURI, err := c.normalizeURI(uri)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
params := map[string]any{
|
||||
"textDocument": map[string]any{
|
||||
"uri": normalizedURI,
|
||||
"languageId": "go",
|
||||
"version": version,
|
||||
"text": text,
|
||||
},
|
||||
}
|
||||
return c.notifyWithContext(ctx, "textDocument/didOpen", params)
|
||||
}
|
||||
|
||||
func (c *Client) DidChange(ctx context.Context, uri, text string, version int) error {
|
||||
normalizedURI, err := c.normalizeURI(uri)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
params := map[string]any{
|
||||
"textDocument": map[string]any{
|
||||
"uri": normalizedURI,
|
||||
"version": version,
|
||||
},
|
||||
"contentChanges": []map[string]string{
|
||||
{
|
||||
"text": text,
|
||||
},
|
||||
},
|
||||
}
|
||||
return c.notifyWithContext(ctx, "textDocument/didChange", params)
|
||||
}
|
||||
|
||||
func (c *Client) Completion(ctx context.Context, uri string, line, character int) (completion.Response, error) {
|
||||
normalizedURI, err := c.normalizeURI(uri)
|
||||
if err != nil {
|
||||
return completion.Response{}, err
|
||||
}
|
||||
|
||||
params := map[string]any{
|
||||
"textDocument": map[string]string{
|
||||
"uri": normalizedURI,
|
||||
},
|
||||
"position": map[string]int{
|
||||
"line": line,
|
||||
"character": character,
|
||||
},
|
||||
}
|
||||
|
||||
var raw json.RawMessage
|
||||
if err := c.request(ctx, "textDocument/completion", params, &raw); err != nil {
|
||||
return completion.Response{}, err
|
||||
}
|
||||
|
||||
body := bytes.TrimSpace(raw)
|
||||
if len(body) == 0 || bytes.Equal(body, []byte("null")) {
|
||||
return completion.Response{}, nil
|
||||
}
|
||||
|
||||
if body[0] == '[' {
|
||||
var items []lspCompletionItem
|
||||
if err := json.Unmarshal(body, &items); err != nil {
|
||||
return completion.Response{}, fmt.Errorf("decode completion items: %w", err)
|
||||
}
|
||||
return completion.Response{Items: mapCompletionItems(items)}, nil
|
||||
}
|
||||
|
||||
var list lspCompletionList
|
||||
if err := json.Unmarshal(body, &list); err != nil {
|
||||
return completion.Response{}, fmt.Errorf("decode completion list: %w", err)
|
||||
}
|
||||
return completion.Response{
|
||||
Items: mapCompletionItems(list.Items),
|
||||
IsIncomplete: list.IsIncomplete,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var windowsDrivePattern = regexp.MustCompile(`^[A-Za-z]:`)
|
||||
|
||||
func (c *Client) normalizeURI(rawURI string) (string, error) {
|
||||
if rawURI == "" {
|
||||
return "", errors.New("empty uri")
|
||||
}
|
||||
|
||||
u, err := url.Parse(rawURI)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid uri: %w", err)
|
||||
}
|
||||
if u.Scheme != "file" {
|
||||
return rawURI, nil
|
||||
}
|
||||
|
||||
path := u.Path
|
||||
if path == "" {
|
||||
return rawURI, nil
|
||||
}
|
||||
|
||||
var localPath string
|
||||
trimmed := strings.TrimPrefix(path, "/")
|
||||
if windowsDrivePattern.MatchString(trimmed) {
|
||||
localPath = filepath.FromSlash(trimmed)
|
||||
} else {
|
||||
rel := strings.TrimLeft(trimmed, "/\\")
|
||||
localPath = filepath.Join(c.workspaceDir, filepath.FromSlash(rel))
|
||||
}
|
||||
|
||||
return pathToURI(localPath)
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
c.closeOnce.Do(func() {
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_ = c.request(shutdownCtx, "shutdown", map[string]any{}, nil)
|
||||
_ = c.notify("exit", map[string]any{})
|
||||
_ = c.stdin.Close()
|
||||
|
||||
select {
|
||||
case <-c.exitCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
if c.cmd.Process != nil {
|
||||
_ = c.cmd.Process.Kill()
|
||||
}
|
||||
<-c.exitCh
|
||||
}
|
||||
|
||||
c.closed.Store(true)
|
||||
c.failPending(errClientClosed)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) request(ctx context.Context, method string, params any, out any) error {
|
||||
if c.closed.Load() {
|
||||
return errClientClosed
|
||||
}
|
||||
|
||||
id := c.nextID.Add(1)
|
||||
key := strconv.FormatInt(id, 10)
|
||||
wait := make(chan rpcResponse, 1)
|
||||
|
||||
c.pendingMu.Lock()
|
||||
c.pending[key] = wait
|
||||
c.pendingMu.Unlock()
|
||||
|
||||
msg := map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"method": method,
|
||||
"params": params,
|
||||
}
|
||||
if err := c.writeMessage(msg); err != nil {
|
||||
c.removePending(key)
|
||||
return err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.removePending(key)
|
||||
return ctx.Err()
|
||||
case resp := <-wait:
|
||||
if resp.err != nil {
|
||||
return resp.err
|
||||
}
|
||||
if out != nil && len(resp.result) > 0 && !bytes.Equal(bytes.TrimSpace(resp.result), []byte("null")) {
|
||||
if err := json.Unmarshal(resp.result, out); err != nil {
|
||||
return fmt.Errorf("decode response for %s: %w", method, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) notify(method string, params any) error {
|
||||
return c.notifyWithContext(context.Background(), method, params)
|
||||
}
|
||||
|
||||
func (c *Client) notifyWithContext(ctx context.Context, method string, params any) error {
|
||||
if c.closed.Load() {
|
||||
return errClientClosed
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
msg := map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"method": method,
|
||||
"params": params,
|
||||
}
|
||||
return c.writeMessage(msg)
|
||||
}
|
||||
|
||||
func (c *Client) writeMessage(msg any) error {
|
||||
payload, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal json-rpc message: %w", err)
|
||||
}
|
||||
|
||||
header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(payload))
|
||||
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
|
||||
if _, err := io.WriteString(c.stdin, header); err != nil {
|
||||
return fmt.Errorf("write lsp header: %w", err)
|
||||
}
|
||||
if _, err := c.stdin.Write(payload); err != nil {
|
||||
return fmt.Errorf("write lsp payload: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) readLoop(stdout io.Reader) {
|
||||
reader := bufio.NewReader(stdout)
|
||||
for {
|
||||
body, err := readMessage(reader)
|
||||
if err != nil {
|
||||
c.failPending(fmt.Errorf("read lsp message: %w", err))
|
||||
return
|
||||
}
|
||||
c.handleIncoming(body)
|
||||
}
|
||||
}
|
||||
|
||||
func readMessage(reader *bufio.Reader) ([]byte, error) {
|
||||
contentLength := 0
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
line = strings.TrimRight(line, "\r\n")
|
||||
if line == "" {
|
||||
break
|
||||
}
|
||||
|
||||
parts := strings.SplitN(line, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(parts[0]), "Content-Length") {
|
||||
n, err := strconv.Atoi(strings.TrimSpace(parts[1]))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid content-length %q: %w", parts[1], err)
|
||||
}
|
||||
contentLength = n
|
||||
}
|
||||
}
|
||||
|
||||
if contentLength <= 0 {
|
||||
return nil, errors.New("missing content-length")
|
||||
}
|
||||
|
||||
body := make([]byte, contentLength)
|
||||
if _, err := io.ReadFull(reader, body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func (c *Client) handleIncoming(body []byte) {
|
||||
var envelope incomingEnvelope
|
||||
if err := json.Unmarshal(body, &envelope); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if envelope.ID == nil {
|
||||
return
|
||||
}
|
||||
|
||||
key := normalizeID(*envelope.ID)
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
|
||||
c.pendingMu.Lock()
|
||||
wait := c.pending[key]
|
||||
delete(c.pending, key)
|
||||
c.pendingMu.Unlock()
|
||||
|
||||
if wait == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if envelope.Error != nil {
|
||||
wait <- rpcResponse{
|
||||
err: fmt.Errorf("lsp error code=%d message=%s", envelope.Error.Code, envelope.Error.Message),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
wait <- rpcResponse{result: envelope.Result}
|
||||
}
|
||||
|
||||
func normalizeID(raw json.RawMessage) string {
|
||||
raw = bytes.TrimSpace(raw)
|
||||
if len(raw) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if raw[0] == '"' {
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return string(raw)
|
||||
}
|
||||
|
||||
func (c *Client) removePending(key string) {
|
||||
c.pendingMu.Lock()
|
||||
delete(c.pending, key)
|
||||
c.pendingMu.Unlock()
|
||||
}
|
||||
|
||||
func (c *Client) failPending(err error) {
|
||||
c.pendingMu.Lock()
|
||||
defer c.pendingMu.Unlock()
|
||||
|
||||
for key, wait := range c.pending {
|
||||
delete(c.pending, key)
|
||||
wait <- rpcResponse{err: err}
|
||||
}
|
||||
}
|
||||
|
||||
func mapCompletionItems(items []lspCompletionItem) []completion.Item {
|
||||
out := make([]completion.Item, 0, len(items))
|
||||
for _, it := range items {
|
||||
out = append(out, completion.Item{
|
||||
Label: it.Label,
|
||||
Kind: it.Kind,
|
||||
Detail: it.Detail,
|
||||
Documentation: decodeDocumentation(it.Documentation),
|
||||
InsertText: it.InsertText,
|
||||
SortText: it.SortText,
|
||||
FilterText: it.FilterText,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func decodeDocumentation(raw json.RawMessage) string {
|
||||
raw = bytes.TrimSpace(raw)
|
||||
if len(raw) == 0 || bytes.Equal(raw, []byte("null")) {
|
||||
return ""
|
||||
}
|
||||
|
||||
if raw[0] == '"' {
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return s
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var markup struct {
|
||||
Value string `json:"value"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &markup); err == nil && markup.Value != "" {
|
||||
return markup.Value
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func pathToURI(path string) (string, error) {
|
||||
abs, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
slashed := filepath.ToSlash(abs)
|
||||
if !strings.HasPrefix(slashed, "/") {
|
||||
slashed = "/" + slashed
|
||||
}
|
||||
u := url.URL{Scheme: "file", Path: slashed}
|
||||
return u.String(), nil
|
||||
}
|
||||
49
backend/internal/lsp/client_test.go
Normal file
49
backend/internal/lsp/client_test.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package lsp
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNormalizeURIRebasesRelativeFileURI(t *testing.T) {
|
||||
workspace, err := filepath.Abs(filepath.Join("testdata", "ws"))
|
||||
if err != nil {
|
||||
t.Fatalf("filepath.Abs() error = %v", err)
|
||||
}
|
||||
|
||||
client := &Client{workspaceDir: workspace}
|
||||
got, err := client.normalizeURI("file:///main.go")
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeURI() error = %v", err)
|
||||
}
|
||||
|
||||
want, err := pathToURI(filepath.Join(workspace, "main.go"))
|
||||
if err != nil {
|
||||
t.Fatalf("pathToURI() error = %v", err)
|
||||
}
|
||||
if got != want {
|
||||
t.Fatalf("normalizeURI() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeURIKeepsAbsoluteFileURI(t *testing.T) {
|
||||
workspace, err := filepath.Abs(filepath.Join("testdata", "ws"))
|
||||
if err != nil {
|
||||
t.Fatalf("filepath.Abs() error = %v", err)
|
||||
}
|
||||
|
||||
client := &Client{workspaceDir: workspace}
|
||||
absPath := filepath.Join(workspace, "demo.go")
|
||||
uri, err := pathToURI(absPath)
|
||||
if err != nil {
|
||||
t.Fatalf("pathToURI() error = %v", err)
|
||||
}
|
||||
|
||||
got, err := client.normalizeURI(uri)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeURI() error = %v", err)
|
||||
}
|
||||
if got != uri {
|
||||
t.Fatalf("normalizeURI() = %q, want %q", got, uri)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user