From 4de5a74a551b19010d75acfb57093103e912b425 Mon Sep 17 00:00:00 2001 From: mirivlad Date: Sat, 20 Jun 2026 03:20:25 +0800 Subject: [PATCH] fix: sanitize sync error messages, detect non-sync servers, add health check in TestAuth --- internal/core/sync/client.go | 53 ++++++++++++++++++++++++++++++++---- 1 file changed, 47 insertions(+), 6 deletions(-) diff --git a/internal/core/sync/client.go b/internal/core/sync/client.go index 91f8435..a59c292 100644 --- a/internal/core/sync/client.go +++ b/internal/core/sync/client.go @@ -9,6 +9,7 @@ import ( "net/http" "os" "path/filepath" + "strings" "time" ) @@ -124,12 +125,35 @@ func (c *Client) RevokeCurrent() error { // TestAuth checks credentials without creating a device. func (c *Client) TestAuth(serverURL, username, password string) error { - body := map[string]string{"username": username, "password": password} + // First, check if this is a Verstak Sync server + healthURL := strings.TrimSuffix(serverURL, "/") + "/api/v1/health" + req, err := http.NewRequest("GET", healthURL, nil) + if err != nil { + return fmt.Errorf("invalid URL: %w", err) + } + resp, err := c.HTTP.Do(req) + if err != nil { + return fmt.Errorf("connection failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + return fmt.Errorf("not a Verstak Sync server (HTTP %d)", resp.StatusCode) + } + + data, _ := io.ReadAll(resp.Body) + body := string(data) + if !strings.Contains(body, "status") && !strings.Contains(body, "ok") { + return fmt.Errorf("not a Verstak Sync server (unexpected response)") + } + + // Now test actual auth + authBody := map[string]string{"username": username, "password": password} savedURL := c.ServerURL savedKey := c.APIKey c.ServerURL = serverURL c.APIKey = "" - err := c.post("/api/auth/test", body, nil) + err = c.post("/api/auth/test", authBody, nil) c.ServerURL = savedURL c.APIKey = savedKey return err @@ -273,6 +297,13 @@ func (c *Client) DownloadBlob(sha256, destPath string) error { return err } +func minInt(a, b int) int { + if a < b { + return a + } + return b +} + func (c *Client) bearerToken() string { if c.DeviceToken != "" { return c.DeviceToken @@ -301,8 +332,7 @@ func (c *Client) post(path string, body, result interface{}) error { defer resp.Body.Close() if resp.StatusCode >= 400 { - data, _ := io.ReadAll(resp.Body) - return fmt.Errorf("server %d: %s", resp.StatusCode, string(data)) + return c.readErrorBody(resp, resp.StatusCode) } if result != nil { @@ -325,8 +355,7 @@ func (c *Client) get(path string, result interface{}) error { defer resp.Body.Close() if resp.StatusCode >= 400 { - data, _ := io.ReadAll(resp.Body) - return fmt.Errorf("server %d: %s", resp.StatusCode, string(data)) + return c.readErrorBody(resp, resp.StatusCode) } if result != nil { @@ -334,3 +363,15 @@ func (c *Client) get(path string, result interface{}) error { } return nil } + +func (c *Client) readErrorBody(resp *http.Response, statusCode int) error { + buf := make([]byte, 4096) + n, _ := io.ReadFull(resp.Body, buf) + body := string(buf[:minInt(n, 500)]) + + lower := strings.ToLower(body) + if strings.Contains(lower, "