616 lines
15 KiB
Go
616 lines
15 KiB
Go
package lsp
|
||
|
||
import (
|
||
"bufio"
|
||
"bytes"
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"net/url"
|
||
"os"
|
||
"os/exec"
|
||
"path/filepath"
|
||
"regexp"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
"sync/atomic"
|
||
"time"
|
||
|
||
"monica-go-completion-backend/internal/completion"
|
||
)
|
||
|
||
var (
	// errClientClosed is returned for requests and notifications attempted
	// after the client has been marked closed, and is delivered to requests
	// still pending when Close runs.
	errClientClosed = errors.New("lsp client closed")
)
|
||
|
||
// Client 封装与 Language Server 的 JSON-RPC/LSP 通信。
|
||
type Client struct {
|
||
cmd *exec.Cmd
|
||
stdin io.WriteCloser
|
||
|
||
writeMu sync.Mutex
|
||
|
||
nextID atomic.Int64
|
||
|
||
workspaceDir string
|
||
languageID string
|
||
clientName string
|
||
|
||
pendingMu sync.Mutex
|
||
pending map[string]chan rpcResponse
|
||
|
||
exitCh chan error
|
||
closeOnce sync.Once
|
||
closed atomic.Bool
|
||
}
|
||
|
||
// rpcResponse carries the outcome of one request from the read loop to the
// goroutine waiting for it: either the raw result payload or an error.
type rpcResponse struct {
	err    error
	result json.RawMessage
}
|
||
|
||
// rpcError mirrors the JSON-RPC error object carried in a response's "error"
// member.
type rpcError struct {
	Code    int             `json:"code"`           // numeric error code
	Message string          `json:"message"`        // human-readable description
	Data    json.RawMessage `json:"data,omitempty"` // optional extra payload, kept undecoded
}
|
||
|
||
// incomingEnvelope 表示从服务端读到的响应/通知外层结构。
|
||
type incomingEnvelope struct {
|
||
ID *json.RawMessage `json:"id,omitempty"`
|
||
Result json.RawMessage `json:"result,omitempty"`
|
||
Error *rpcError `json:"error,omitempty"`
|
||
}
|
||
|
||
type (
	// lspCompletionItem is the subset of an LSP CompletionItem the gateway
	// reads. Documentation is kept raw because it may be a plain string or an
	// object; it is decoded later.
	lspCompletionItem struct {
		Label         string          `json:"label"`
		Kind          int             `json:"kind,omitempty"`
		Detail        string          `json:"detail,omitempty"`
		InsertText    string          `json:"insertText,omitempty"`
		SortText      string          `json:"sortText,omitempty"`
		FilterText    string          `json:"filterText,omitempty"`
		Documentation json.RawMessage `json:"documentation,omitempty"`
	}

	// lspCompletionList is the list-shaped textDocument/completion result.
	lspCompletionList struct {
		IsIncomplete bool                `json:"isIncomplete"`
		Items        []lspCompletionItem `json:"items"`
	}
)
|
||
|
||
// Config describes how to launch the language server subprocess and how the
// client identifies itself. NewClient substitutes defaults for an empty
// Command, RootPath, LanguageID or ClientName.
type Config struct {
	// Command is the language server executable.
	Command string
	// Args are the launch arguments (usually including --stdio).
	Args []string
	// RootPath is the workspace root directory.
	RootPath string
	// LanguageID is the languageId sent with didOpen/didChange.
	LanguageID string
	// ClientName is reported as initialize.clientInfo.name.
	ClientName string
}
|
||
|
||
// NewClient 启动语言服务器进程并完成 initialize 握手。
|
||
func NewClient(parent context.Context, cfg Config) (*Client, error) {
|
||
if cfg.Command == "" {
|
||
cfg.Command = "gopls"
|
||
}
|
||
if cfg.RootPath == "" {
|
||
cwd, err := os.Getwd()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("get working directory: %w", err)
|
||
}
|
||
cfg.RootPath = cwd
|
||
}
|
||
if cfg.LanguageID == "" {
|
||
cfg.LanguageID = "go"
|
||
}
|
||
if cfg.ClientName == "" {
|
||
cfg.ClientName = "monica-lsp-gateway"
|
||
}
|
||
|
||
cmd := exec.Command(filepath.FromSlash(cfg.Command), cfg.Args...)
|
||
stdin, err := cmd.StdinPipe()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("create stdin pipe: %w", err)
|
||
}
|
||
stdout, err := cmd.StdoutPipe()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("create stdout pipe: %w", err)
|
||
}
|
||
cmd.Stderr = os.Stderr
|
||
|
||
if err := cmd.Start(); err != nil {
|
||
return nil, fmt.Errorf("start language server %q: %w", cfg.Command, err)
|
||
}
|
||
|
||
client := &Client{
|
||
cmd: cmd,
|
||
stdin: stdin,
|
||
workspaceDir: cfg.RootPath,
|
||
languageID: cfg.LanguageID,
|
||
clientName: cfg.ClientName,
|
||
pending: make(map[string]chan rpcResponse),
|
||
exitCh: make(chan error, 1),
|
||
}
|
||
|
||
go func() {
|
||
client.exitCh <- cmd.Wait()
|
||
}()
|
||
// 独立协程持续读取 stdout 并分发响应。
|
||
go client.readLoop(stdout)
|
||
|
||
initCtx, cancel := context.WithTimeout(parent, 30*time.Second)
|
||
defer cancel()
|
||
|
||
if err := client.initialize(initCtx, cfg.RootPath); err != nil {
|
||
_ = client.Close()
|
||
return nil, err
|
||
}
|
||
|
||
return client, nil
|
||
}
|
||
|
||
// initialize 完成 LSP initialize/initialized 流程。
|
||
func (c *Client) initialize(ctx context.Context, rootPath string) error {
|
||
rootURI, err := pathToURI(rootPath)
|
||
if err != nil {
|
||
return fmt.Errorf("build root uri: %w", err)
|
||
}
|
||
|
||
params := map[string]any{
|
||
"processId": os.Getpid(),
|
||
"rootUri": rootURI,
|
||
"capabilities": map[string]any{
|
||
"textDocument": map[string]any{
|
||
"completion": map[string]any{
|
||
"completionItem": map[string]any{
|
||
"snippetSupport": true,
|
||
},
|
||
},
|
||
},
|
||
},
|
||
"workspaceFolders": []map[string]string{
|
||
{
|
||
"uri": rootURI,
|
||
"name": filepath.Base(rootPath),
|
||
},
|
||
},
|
||
"clientInfo": map[string]string{
|
||
"name": c.clientName,
|
||
"version": "0.1.0",
|
||
},
|
||
}
|
||
|
||
var initResult json.RawMessage
|
||
if err := c.request(ctx, "initialize", params, &initResult); err != nil {
|
||
return fmt.Errorf("initialize request failed: %w", err)
|
||
}
|
||
if err := c.notify("initialized", map[string]any{}); err != nil {
|
||
return fmt.Errorf("initialized notification failed: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// DidOpen 发送 textDocument/didOpen,告知服务端首次打开文档。
|
||
func (c *Client) DidOpen(ctx context.Context, uri, text string, version int) error {
|
||
normalizedURI, err := c.normalizeURI(uri)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
params := map[string]any{
|
||
"textDocument": map[string]any{
|
||
"uri": normalizedURI,
|
||
"languageId": c.languageID,
|
||
"version": version,
|
||
"text": text,
|
||
},
|
||
}
|
||
return c.notifyWithContext(ctx, "textDocument/didOpen", params)
|
||
}
|
||
|
||
// DidChange 发送 textDocument/didChange,推送文档全文与版本号。
|
||
func (c *Client) DidChange(ctx context.Context, uri, text string, version int) error {
|
||
normalizedURI, err := c.normalizeURI(uri)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
params := map[string]any{
|
||
"textDocument": map[string]any{
|
||
"uri": normalizedURI,
|
||
"version": version,
|
||
},
|
||
"contentChanges": []map[string]string{
|
||
{
|
||
"text": text,
|
||
},
|
||
},
|
||
}
|
||
return c.notifyWithContext(ctx, "textDocument/didChange", params)
|
||
}
|
||
|
||
// Completion 调用 textDocument/completion,并兼容两种返回形态。
|
||
func (c *Client) Completion(ctx context.Context, uri string, line, character int) (completion.Response, error) {
|
||
normalizedURI, err := c.normalizeURI(uri)
|
||
if err != nil {
|
||
return completion.Response{}, err
|
||
}
|
||
|
||
params := map[string]any{
|
||
"textDocument": map[string]string{
|
||
"uri": normalizedURI,
|
||
},
|
||
"position": map[string]int{
|
||
"line": line,
|
||
"character": character,
|
||
},
|
||
}
|
||
|
||
var raw json.RawMessage
|
||
if err := c.request(ctx, "textDocument/completion", params, &raw); err != nil {
|
||
return completion.Response{}, err
|
||
}
|
||
|
||
body := bytes.TrimSpace(raw)
|
||
if len(body) == 0 || bytes.Equal(body, []byte("null")) {
|
||
return completion.Response{}, nil
|
||
}
|
||
|
||
if body[0] == '[' {
|
||
// 部分 LSP 服务端直接返回 CompletionItem[]。
|
||
var items []lspCompletionItem
|
||
if err := json.Unmarshal(body, &items); err != nil {
|
||
return completion.Response{}, fmt.Errorf("decode completion items: %w", err)
|
||
}
|
||
return completion.Response{Items: mapCompletionItems(items)}, nil
|
||
}
|
||
|
||
var list lspCompletionList
|
||
if err := json.Unmarshal(body, &list); err != nil {
|
||
return completion.Response{}, fmt.Errorf("decode completion list: %w", err)
|
||
}
|
||
return completion.Response{
|
||
Items: mapCompletionItems(list.Items),
|
||
IsIncomplete: list.IsIncomplete,
|
||
}, nil
|
||
}
|
||
|
||
var (
	// windowsDrivePattern matches a path that begins with a drive letter and
	// a colon, such as "C:".
	windowsDrivePattern = regexp.MustCompile(`^[A-Za-z]:`)
)
|
||
|
||
// languageExtensions maps a lowercase languageId to the canonical file
// extension, including the leading dot, for documents of that language.
var languageExtensions = map[string]string{
	"c":          ".c",
	"cpp":        ".cpp",
	"go":         ".go",
	"java":       ".java",
	"javascript": ".js",
	"python":     ".py",
	"rust":       ".rs",
	"typescript": ".ts",
}
|
||
|
||
// normalizeURI 将相对 file URI 重写为工作区绝对路径 URI,
|
||
// 并将不匹配 languageId 的扩展名(如 .txt)替换为语言对应扩展名。
|
||
func (c *Client) normalizeURI(rawURI string) (string, error) {
|
||
if rawURI == "" {
|
||
return "", errors.New("empty uri")
|
||
}
|
||
|
||
u, err := url.Parse(rawURI)
|
||
if err != nil {
|
||
return "", fmt.Errorf("invalid uri: %w", err)
|
||
}
|
||
if u.Scheme != "file" {
|
||
return rawURI, nil
|
||
}
|
||
|
||
path := u.Path
|
||
if path == "" {
|
||
return rawURI, nil
|
||
}
|
||
|
||
var localPath string
|
||
trimmed := strings.TrimPrefix(path, "/")
|
||
if windowsDrivePattern.MatchString(trimmed) {
|
||
localPath = filepath.FromSlash(trimmed)
|
||
} else {
|
||
rel := strings.TrimLeft(trimmed, "/\\")
|
||
localPath = filepath.Join(c.workspaceDir, filepath.FromSlash(rel))
|
||
}
|
||
|
||
// 若文件扩展名与 languageId 不匹配(如 .txt),替换为语言对应扩展名。
|
||
if expectedExt, ok := languageExtensions[strings.ToLower(c.languageID)]; ok {
|
||
currentExt := strings.ToLower(filepath.Ext(localPath))
|
||
if currentExt != expectedExt {
|
||
localPath = strings.TrimSuffix(localPath, filepath.Ext(localPath)) + expectedExt
|
||
}
|
||
}
|
||
|
||
return pathToURI(localPath)
|
||
}
|
||
|
||
// Close 优雅关闭 LSP 进程并失败通知所有未完成请求。
|
||
func (c *Client) Close() error {
|
||
c.closeOnce.Do(func() {
|
||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||
defer cancel()
|
||
|
||
_ = c.request(shutdownCtx, "shutdown", map[string]any{}, nil)
|
||
_ = c.notify("exit", map[string]any{})
|
||
_ = c.stdin.Close()
|
||
|
||
select {
|
||
case <-c.exitCh:
|
||
case <-time.After(2 * time.Second):
|
||
if c.cmd.Process != nil {
|
||
_ = c.cmd.Process.Kill()
|
||
}
|
||
<-c.exitCh
|
||
}
|
||
|
||
c.closed.Store(true)
|
||
c.failPending(errClientClosed)
|
||
})
|
||
return nil
|
||
}
|
||
|
||
// request 发送 JSON-RPC 请求并阻塞等待对应响应。
|
||
func (c *Client) request(ctx context.Context, method string, params any, out any) error {
|
||
if c.closed.Load() {
|
||
return errClientClosed
|
||
}
|
||
|
||
id := c.nextID.Add(1)
|
||
key := strconv.FormatInt(id, 10)
|
||
wait := make(chan rpcResponse, 1)
|
||
|
||
c.pendingMu.Lock()
|
||
c.pending[key] = wait
|
||
c.pendingMu.Unlock()
|
||
|
||
msg := map[string]any{
|
||
"jsonrpc": "2.0",
|
||
"id": id,
|
||
"method": method,
|
||
"params": params,
|
||
}
|
||
if err := c.writeMessage(msg); err != nil {
|
||
c.removePending(key)
|
||
return err
|
||
}
|
||
|
||
select {
|
||
case <-ctx.Done():
|
||
c.removePending(key)
|
||
return ctx.Err()
|
||
case resp := <-wait:
|
||
if resp.err != nil {
|
||
return resp.err
|
||
}
|
||
if out != nil && len(resp.result) > 0 && !bytes.Equal(bytes.TrimSpace(resp.result), []byte("null")) {
|
||
if err := json.Unmarshal(resp.result, out); err != nil {
|
||
return fmt.Errorf("decode response for %s: %w", method, err)
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
}
|
||
|
||
func (c *Client) notify(method string, params any) error {
|
||
return c.notifyWithContext(context.Background(), method, params)
|
||
}
|
||
|
||
// notifyWithContext 发送 JSON-RPC 通知(无响应)。
|
||
func (c *Client) notifyWithContext(ctx context.Context, method string, params any) error {
|
||
if c.closed.Load() {
|
||
return errClientClosed
|
||
}
|
||
select {
|
||
case <-ctx.Done():
|
||
return ctx.Err()
|
||
default:
|
||
}
|
||
|
||
msg := map[string]any{
|
||
"jsonrpc": "2.0",
|
||
"method": method,
|
||
"params": params,
|
||
}
|
||
return c.writeMessage(msg)
|
||
}
|
||
|
||
// writeMessage 按 LSP framing 写入消息头和消息体。
|
||
func (c *Client) writeMessage(msg any) error {
|
||
payload, err := json.Marshal(msg)
|
||
if err != nil {
|
||
return fmt.Errorf("marshal json-rpc message: %w", err)
|
||
}
|
||
|
||
header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(payload))
|
||
|
||
c.writeMu.Lock()
|
||
defer c.writeMu.Unlock()
|
||
|
||
if _, err := io.WriteString(c.stdin, header); err != nil {
|
||
return fmt.Errorf("write lsp header: %w", err)
|
||
}
|
||
if _, err := c.stdin.Write(payload); err != nil {
|
||
return fmt.Errorf("write lsp payload: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// readLoop 持续读取 LSP 消息并交给 handleIncoming 分发。
|
||
func (c *Client) readLoop(stdout io.Reader) {
|
||
reader := bufio.NewReader(stdout)
|
||
for {
|
||
body, err := readMessage(reader)
|
||
if err != nil {
|
||
c.failPending(fmt.Errorf("read lsp message: %w", err))
|
||
return
|
||
}
|
||
c.handleIncoming(body)
|
||
}
|
||
}
|
||
|
||
// readMessage reads one LSP base-protocol message from reader: header lines
// up to an empty line, then exactly Content-Length bytes of body.
//
// Header names are matched case-insensitively; unknown headers and lines
// without a colon are skipped. It fails if:
//   - the Content-Length value is not an integer;
//   - Content-Length is missing or not positive;
//   - the stream ends early (including io.EOF before any header).
func readMessage(reader *bufio.Reader) ([]byte, error) {
	length := 0
	for {
		rawLine, err := reader.ReadString('\n')
		if err != nil {
			return nil, err
		}
		header := strings.TrimRight(rawLine, "\r\n")
		if header == "" {
			break // a blank line ends the header section
		}
		name, value, found := strings.Cut(header, ":")
		if !found || !strings.EqualFold(strings.TrimSpace(name), "Content-Length") {
			continue
		}
		n, convErr := strconv.Atoi(strings.TrimSpace(value))
		if convErr != nil {
			return nil, fmt.Errorf("invalid content-length %q: %w", value, convErr)
		}
		length = n
	}

	if length <= 0 {
		return nil, errors.New("missing content-length")
	}
	body := make([]byte, length)
	if _, err := io.ReadFull(reader, body); err != nil {
		return nil, err
	}
	return body, nil
}
|
||
|
||
// handleIncoming 将响应按 id 投递给对应等待中的请求通道。
|
||
func (c *Client) handleIncoming(body []byte) {
|
||
var envelope incomingEnvelope
|
||
if err := json.Unmarshal(body, &envelope); err != nil {
|
||
return
|
||
}
|
||
|
||
if envelope.ID == nil {
|
||
return
|
||
}
|
||
|
||
key := normalizeID(*envelope.ID)
|
||
if key == "" {
|
||
return
|
||
}
|
||
|
||
c.pendingMu.Lock()
|
||
wait := c.pending[key]
|
||
delete(c.pending, key)
|
||
c.pendingMu.Unlock()
|
||
|
||
if wait == nil {
|
||
return
|
||
}
|
||
|
||
if envelope.Error != nil {
|
||
wait <- rpcResponse{
|
||
err: fmt.Errorf("lsp error code=%d message=%s", envelope.Error.Code, envelope.Error.Message),
|
||
}
|
||
return
|
||
}
|
||
|
||
wait <- rpcResponse{result: envelope.Result}
|
||
}
|
||
|
||
// normalizeID turns a raw JSON-RPC id into a map key after trimming
// whitespace.
//
// String ids are unquoted, so the id "7" and the number 7 share the key "7".
// Anything else, including a quoted value that fails to decode, is used
// verbatim. An empty id yields "".
func normalizeID(raw json.RawMessage) string {
	trimmed := bytes.TrimSpace(raw)
	if len(trimmed) > 0 && trimmed[0] == '"' {
		var text string
		if json.Unmarshal(trimmed, &text) == nil {
			return text
		}
	}
	return string(trimmed)
}
|
||
|
||
func (c *Client) removePending(key string) {
|
||
c.pendingMu.Lock()
|
||
delete(c.pending, key)
|
||
c.pendingMu.Unlock()
|
||
}
|
||
|
||
// failPending 在连接异常时让所有挂起请求立即失败返回。
|
||
func (c *Client) failPending(err error) {
|
||
c.pendingMu.Lock()
|
||
defer c.pendingMu.Unlock()
|
||
|
||
for key, wait := range c.pending {
|
||
delete(c.pending, key)
|
||
wait <- rpcResponse{err: err}
|
||
}
|
||
}
|
||
|
||
// mapCompletionItems 将 LSP 项结构映射为网关统一输出结构。
|
||
func mapCompletionItems(items []lspCompletionItem) []completion.Item {
|
||
out := make([]completion.Item, 0, len(items))
|
||
for _, it := range items {
|
||
out = append(out, completion.Item{
|
||
Label: it.Label,
|
||
Kind: it.Kind,
|
||
Detail: it.Detail,
|
||
Documentation: decodeDocumentation(it.Documentation),
|
||
InsertText: it.InsertText,
|
||
SortText: it.SortText,
|
||
FilterText: it.FilterText,
|
||
})
|
||
}
|
||
return out
|
||
}
|
||
|
||
// decodeDocumentation extracts text from a completion item's documentation
// field.
//
// The field may be a JSON string or an object with a "value" member
// (MarkupContent). For an object, only value is returned; no markup is
// converted. null, empty, or undecodable input yields "".
func decodeDocumentation(raw json.RawMessage) string {
	doc := bytes.TrimSpace(raw)
	if len(doc) == 0 || bytes.Equal(doc, []byte("null")) {
		return ""
	}

	if doc[0] != '"' {
		var markup struct {
			Value string `json:"value"`
		}
		if json.Unmarshal(doc, &markup) != nil {
			return ""
		}
		return markup.Value
	}

	var plain string
	if json.Unmarshal(doc, &plain) != nil {
		return ""
	}
	return plain
}
|
||
|
||
// pathToURI converts a local filesystem path into a file URI.
//
// Steps:
//  1. Make the path absolute; relative paths resolve against the working
//     directory.
//  2. Convert separators to forward slashes.
//  3. Prepend "/" when missing, as for "C:/..." paths.
//  4. Let net/url percent-encode the result.
func pathToURI(path string) (string, error) {
	absolute, err := filepath.Abs(path)
	if err != nil {
		return "", err
	}
	uriPath := filepath.ToSlash(absolute)
	if !strings.HasPrefix(uriPath, "/") {
		uriPath = "/" + uriPath
	}
	return (&url.URL{Scheme: "file", Path: uriPath}).String(), nil
}
|