223 lines
5.9 KiB
Go
223 lines
5.9 KiB
Go
package api
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/gorilla/websocket"
|
|
|
|
"monica-go-completion-backend/internal/completion"
|
|
)
|
|
|
|
type fakeCompletionService struct {
|
|
resp completion.Response
|
|
err error
|
|
}
|
|
|
|
func (f *fakeCompletionService) Complete(_ context.Context, _ completion.Request) (completion.Response, error) {
|
|
if f.err != nil {
|
|
return completion.Response{}, f.err
|
|
}
|
|
return f.resp, nil
|
|
}
|
|
|
|
type fakeLSPStatusProvider struct {
|
|
status map[string]any
|
|
}
|
|
|
|
func (f *fakeLSPStatusProvider) LspServiceStatus() map[string]any {
|
|
return f.status
|
|
}
|
|
|
|
// 验证 HTTP 补全接口的成功路径。
|
|
func TestRegisterRoutesCompletionSuccess(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
r := gin.New()
|
|
RegisterRoutes(r, &fakeCompletionService{
|
|
resp: completion.Response{Items: []completion.Item{{Label: "Println"}}, IsIncomplete: true},
|
|
})
|
|
|
|
body, _ := json.Marshal(map[string]any{
|
|
"uri": "file:///main.go",
|
|
"text": "package main",
|
|
"line": 0,
|
|
"character": 0,
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/completions/go", bytes.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
var got struct {
|
|
Items []completion.Item `json:"items"`
|
|
IsIncomplete bool `json:"isIncomplete"`
|
|
}
|
|
if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil {
|
|
t.Fatalf("failed to decode response: %v", err)
|
|
}
|
|
if len(got.Items) != 1 || got.Items[0].Label != "Println" || !got.IsIncomplete {
|
|
t.Fatalf("unexpected response body: %s", w.Body.String())
|
|
}
|
|
}
|
|
|
|
// 验证非法 JSON 会返回 400。
|
|
func TestRegisterRoutesCompletionBadJSON(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
r := gin.New()
|
|
RegisterRoutes(r, &fakeCompletionService{})
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/completions/go", bytes.NewReader([]byte("not-json")))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusBadRequest {
|
|
t.Fatalf("expected status 400, got %d", w.Code)
|
|
}
|
|
}
|
|
|
|
// 验证业务校验错误会映射为 400。
|
|
func TestRegisterRoutesCompletionValidationError(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
r := gin.New()
|
|
rv := &fakeCompletionService{err: completion.ErrInvalidRequest}
|
|
RegisterRoutes(r, rv)
|
|
|
|
body, _ := json.Marshal(map[string]any{
|
|
"uri": "",
|
|
"text": "",
|
|
"line": -1,
|
|
"character": 0,
|
|
})
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/completions/go", bytes.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
r.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusBadRequest {
|
|
t.Fatalf("expected status 400, got %d", w.Code)
|
|
}
|
|
}
|
|
|
|
// 验证未知内部错误会映射为 500。
|
|
func TestRegisterRoutesCompletionServerError(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
r := gin.New()
|
|
RegisterRoutes(r, &fakeCompletionService{err: errors.New("boom")})
|
|
|
|
body, _ := json.Marshal(map[string]any{
|
|
"uri": "file:///main.go",
|
|
"text": "package main",
|
|
"line": 0,
|
|
"character": 0,
|
|
})
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/completions/go", bytes.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
r.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusInternalServerError {
|
|
t.Fatalf("expected status 500, got %d", w.Code)
|
|
}
|
|
}
|
|
|
|
// 验证 WebSocket 补全协议的基础成功流程。
|
|
func TestRegisterRoutesCompletionWebSocketSuccess(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
r := gin.New()
|
|
RegisterRoutes(r, &fakeCompletionService{
|
|
resp: completion.Response{Items: []completion.Item{{Label: "Println"}}, IsIncomplete: false},
|
|
})
|
|
|
|
srv := httptest.NewServer(r)
|
|
defer srv.Close()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws/completions/go"
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
|
if err != nil {
|
|
t.Fatalf("websocket dial failed: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
req := map[string]any{
|
|
"id": "1",
|
|
"uri": "file:///main.go",
|
|
"text": "package main",
|
|
"line": 0,
|
|
"character": 0,
|
|
}
|
|
if err := conn.WriteJSON(req); err != nil {
|
|
t.Fatalf("WriteJSON() failed: %v", err)
|
|
}
|
|
|
|
var resp struct {
|
|
ID string `json:"id"`
|
|
Items []completion.Item `json:"items"`
|
|
IsIncomplete bool `json:"isIncomplete"`
|
|
Error string `json:"error"`
|
|
}
|
|
if err := conn.ReadJSON(&resp); err != nil {
|
|
t.Fatalf("ReadJSON() failed: %v", err)
|
|
}
|
|
|
|
if resp.Error != "" {
|
|
t.Fatalf("unexpected websocket error: %s", resp.Error)
|
|
}
|
|
if resp.ID != "1" {
|
|
t.Fatalf("unexpected id: %s", resp.ID)
|
|
}
|
|
if len(resp.Items) != 1 || resp.Items[0].Label != "Println" {
|
|
t.Fatalf("unexpected items: %+v", resp.Items)
|
|
}
|
|
}
|
|
|
|
// 验证 /health/lsp-status 会返回语言探测状态快照。
|
|
func TestRegisterRoutesLspStatus(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
r := gin.New()
|
|
RegisterRoutes(r, &fakeCompletionService{}, RouteOptions{
|
|
LSPStatusProvider: &fakeLSPStatusProvider{
|
|
status: map[string]any{
|
|
"go": map[string]any{
|
|
"online": true,
|
|
},
|
|
},
|
|
},
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/health/lsp-status", nil)
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
var got map[string]any
|
|
if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil {
|
|
t.Fatalf("failed to decode response: %v", err)
|
|
}
|
|
if got["status"] != "ok" {
|
|
t.Fatalf("expected status ok, got %#v", got["status"])
|
|
}
|
|
languages, ok := got["languages"].(map[string]any)
|
|
if !ok {
|
|
t.Fatalf("languages field type mismatch: %#v", got["languages"])
|
|
}
|
|
if _, ok := languages["go"]; !ok {
|
|
t.Fatalf("expected go language status, got %#v", languages)
|
|
}
|
|
}
|