From 2e6b451c3d68edfd4907eaf3bd2533ce292fa767 Mon Sep 17 00:00:00 2001 From: mirivlad Date: Sat, 27 Jun 2026 13:39:00 +0800 Subject: [PATCH] Reject revoked legacy sync API keys --- internal/server/middleware.go | 20 +++++++++++--- internal/server/server_test.go | 49 +++++++++++++++++++++++++++++++--- 2 files changed, 61 insertions(+), 8 deletions(-) diff --git a/internal/server/middleware.go b/internal/server/middleware.go index 4181858..950ca97 100644 --- a/internal/server/middleware.go +++ b/internal/server/middleware.go @@ -36,13 +36,25 @@ func (s *Server) requireAuth(w http.ResponseWriter, r *http.Request) (deviceID, s.db.Exec("UPDATE server_devices SET last_seen=? WHERE id=?", time.Now().UTC().Format(time.RFC3339), deviceIDVal.String) return deviceIDVal.String, userIDVal.String, true } - var count int - err = s.db.QueryRow("SELECT COUNT(*) FROM server_devices WHERE api_key=?", key).Scan(&count) - if err != nil || count == 0 { + err = s.db.QueryRow("SELECT id, user_id, revoked_at FROM server_devices WHERE api_key=?", key).Scan(&deviceIDVal, &userIDVal, &revokedAt) + if err != nil { jsonErr(w, 401, "invalid API key") return "", "", false } - return "", "", true + if revokedAt.Valid && revokedAt.String != "" { + jsonErr(w, 401, "device revoked") + return "", "", false + } + if userIDVal.Valid && userIDVal.String != "" { + var blocked int + s.db.QueryRow("SELECT blocked FROM server_users WHERE id=?", userIDVal.String).Scan(&blocked) + if blocked != 0 { + jsonErr(w, 403, "user blocked") + return "", "", false + } + } + s.db.Exec("UPDATE server_devices SET last_seen=? WHERE id=?", time.Now().UTC().Format(time.RFC3339), deviceIDVal.String) + return deviceIDVal.String, userIDVal.String, true } func (s *Server) requireAdmin(w http.ResponseWriter, r *http.Request) bool { diff --git a/internal/server/server_test.go b/internal/server/server_test.go index ba06bed..49a40d8 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -132,7 +132,51 @@ func TestSyncPushPullStoresSequencedOps(t *testing.T) { } } +func TestRevokedLegacyAPIKeyCannotPushOrPull(t *testing.T) { + dir := t.TempDir() + s, err := NewServer(filepath.Join(dir, "test.db"), filepath.Join(dir, "data"), &Config{Port: 47732}) + if err != nil { + t.Fatalf("NewServer: %v", err) + } + defer s.Close() + s.SetupRoutes() + + now := time.Now().UTC().Format(time.RFC3339) + if _, err := s.db.Exec( + "INSERT INTO server_devices (id, name, api_key, last_seen, revoked_at, created_at) VALUES (?, ?, ?, ?, ?, ?)", + "device-revoked", "Revoked Device", "revoked-key", now, now, now, + ); err != nil { + t.Fatalf("insert device: %v", err) + } + ts := httptest.NewServer(s.mux) + defer ts.Close() + + pushStatus, pushResp := postJSONStatus(t, ts.URL+"/api/v1/sync/push", "revoked-key", map[string]interface{}{ + "device_id": "device-revoked", + "ops": []map[string]interface{}{}, + }) + if pushStatus != http.StatusUnauthorized || pushResp["error"] != "device revoked" { + t.Fatalf("push status=%d resp=%#v, want 401 device revoked", pushStatus, pushResp) + } + + pullStatus, pullResp := postJSONStatus(t, ts.URL+"/api/v1/sync/pull", "revoked-key", map[string]interface{}{ + "since_sequence": 0, + }) + if pullStatus != http.StatusUnauthorized || pullResp["error"] != "device revoked" { + t.Fatalf("pull status=%d resp=%#v, want 401 device revoked", pullStatus, pullResp) + } +} + func postJSON(t *testing.T, url, token string, body interface{}) map[string]interface{} { + t.Helper() + status, out := postJSONStatus(t, url, token, body) + if status != http.StatusOK { + t.Fatalf("post %s status = %d", url, status) + } + return out +} + +func postJSONStatus(t *testing.T, url, token string, body interface{}) (int, map[string]interface{}) { t.Helper() var b bytes.Buffer if err := json.NewEncoder(&b).Encode(body); err != nil { @@ -149,12 +193,9 @@ func postJSON(t *testing.T, url, token string, body interface{}) map[string]inte t.Fatalf("post %s: %v", url, err) } defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - t.Fatalf("post %s status = %d", url, resp.StatusCode) - } var out map[string]interface{} if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { t.Fatalf("decode response: %v", err) } - return out + return resp.StatusCode, out }