package api import ( "bytes" "context" "encoding/json" "errors" "net/http" "net/http/httptest" "strings" "testing" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "monica-go-completion-backend/internal/completion" ) type fakeCompletionService struct { resp completion.Response err error } func (f *fakeCompletionService) Complete(_ context.Context, _ completion.Request) (completion.Response, error) { if f.err != nil { return completion.Response{}, f.err } return f.resp, nil } // 验证 HTTP 补全接口的成功路径。 func TestRegisterRoutesCompletionSuccess(t *testing.T) { gin.SetMode(gin.TestMode) r := gin.New() RegisterRoutes(r, &fakeCompletionService{ resp: completion.Response{Items: []completion.Item{{Label: "Println"}}, IsIncomplete: true}, }) body, _ := json.Marshal(map[string]any{ "uri": "file:///main.go", "text": "package main", "line": 0, "character": 0, }) req := httptest.NewRequest(http.MethodPost, "/api/v1/completions/go", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() r.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Fatalf("expected status 200, got %d", w.Code) } var got struct { Items []completion.Item `json:"items"` IsIncomplete bool `json:"isIncomplete"` } if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil { t.Fatalf("failed to decode response: %v", err) } if len(got.Items) != 1 || got.Items[0].Label != "Println" || !got.IsIncomplete { t.Fatalf("unexpected response body: %s", w.Body.String()) } } // 验证非法 JSON 会返回 400。 func TestRegisterRoutesCompletionBadJSON(t *testing.T) { gin.SetMode(gin.TestMode) r := gin.New() RegisterRoutes(r, &fakeCompletionService{}) req := httptest.NewRequest(http.MethodPost, "/api/v1/completions/go", bytes.NewReader([]byte("not-json"))) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() r.ServeHTTP(w, req) if w.Code != http.StatusBadRequest { t.Fatalf("expected status 400, got %d", w.Code) } } // 验证业务校验错误会映射为 400。 func TestRegisterRoutesCompletionValidationError(t *testing.T) { gin.SetMode(gin.TestMode) r := gin.New() rv := &fakeCompletionService{err: completion.ErrInvalidRequest} RegisterRoutes(r, rv) body, _ := json.Marshal(map[string]any{ "uri": "", "text": "", "line": -1, "character": 0, }) w := httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, "/api/v1/completions/go", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") r.ServeHTTP(w, req) if w.Code != http.StatusBadRequest { t.Fatalf("expected status 400, got %d", w.Code) } } // 验证未知内部错误会映射为 500。 func TestRegisterRoutesCompletionServerError(t *testing.T) { gin.SetMode(gin.TestMode) r := gin.New() RegisterRoutes(r, &fakeCompletionService{err: errors.New("boom")}) body, _ := json.Marshal(map[string]any{ "uri": "file:///main.go", "text": "package main", "line": 0, "character": 0, }) w := httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, "/api/v1/completions/go", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") r.ServeHTTP(w, req) if w.Code != http.StatusInternalServerError { t.Fatalf("expected status 500, got %d", w.Code) } } // 验证 WebSocket 补全协议的基础成功流程。 func TestRegisterRoutesCompletionWebSocketSuccess(t *testing.T) { gin.SetMode(gin.TestMode) r := gin.New() RegisterRoutes(r, &fakeCompletionService{ resp: completion.Response{Items: []completion.Item{{Label: "Println"}}, IsIncomplete: false}, }) srv := httptest.NewServer(r) defer srv.Close() wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws/completions/go" conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) if err != nil { t.Fatalf("websocket dial failed: %v", err) } defer conn.Close() req := map[string]any{ "id": "1", "uri": "file:///main.go", "text": "package main", "line": 0, "character": 0, } if err := conn.WriteJSON(req); err != nil { t.Fatalf("WriteJSON() failed: %v", err) } var resp struct { ID string `json:"id"` Items []completion.Item `json:"items"` IsIncomplete bool `json:"isIncomplete"` Error string `json:"error"` } if err := conn.ReadJSON(&resp); err != nil { t.Fatalf("ReadJSON() failed: %v", err) } if resp.Error != "" { t.Fatalf("unexpected websocket error: %s", resp.Error) } if resp.ID != "1" { t.Fatalf("unexpected id: %s", resp.ID) } if len(resp.Items) != 1 || resp.Items[0].Label != "Println" { t.Fatalf("unexpected items: %+v", resp.Items) } }