diff --git a/internal/server/handlers_api.go b/internal/server/handlers_api.go new file mode 100644 index 0000000..e43164f --- /dev/null +++ b/internal/server/handlers_api.go @@ -0,0 +1,548 @@ +package server + +import ( + "crypto/rand" + "encoding/hex" + "encoding/json" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + "golang.org/x/crypto/bcrypt" +) + +func (s *Server) handleNotFound(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/" { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Write([]byte("Verstak Sync Server\n")) + return + } + jsonErr(w, 404, "not found") +} + +func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { + jsonOK(w, map[string]interface{}{ + "status": "ok", + "version": "verstak-server/v1", + "time": time.Now().UTC().Format(time.RFC3339), + }) +} + +func (s *Server) handleClientPair(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + jsonErr(w, 405, "POST required") + return + } + ip := r.RemoteAddr + if idx := strings.LastIndex(ip, ":"); idx >= 0 { + ip = ip[:idx] + } + if !s.pairLimit.allow(ip) { + s.auditLog("rate_limit_exceeded", "", "", ip, "pair rate limit exceeded") + jsonErr(w, 429, "too many attempts") + return + } + var req struct { + Login string `json:"login"` + Password string `json:"password"` + DeviceName string `json:"device_name"` + ClientVersion string `json:"client_version"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + jsonErr(w, 400, "bad json") + return + } + if req.Login == "" || req.Password == "" { + jsonErr(w, 400, "login and password required") + return + } + if req.DeviceName == "" { + req.DeviceName = "unknown" + } + var userID, hash string + var confirmed, blocked int + err := s.db.QueryRow("SELECT id, password_hash, confirmed, blocked FROM server_users WHERE username=? OR email=?", + req.Login, strings.ToLower(req.Login)).Scan(&userID, &hash, &confirmed, &blocked) + if err != nil { + s.auditLog("device_auth_failed", "", "", ip, "pair: user not found: "+req.Login) + jsonErr(w, 401, "invalid credentials") + return + } + if blocked != 0 { + s.auditLog("device_auth_failed", userID, "", ip, "pair: user blocked") + jsonErr(w, 403, "account blocked") + return + } + if confirmed == 0 { + s.auditLog("device_auth_failed", userID, "", ip, "pair: email not confirmed") + jsonErr(w, 403, "email not confirmed") + return + } + if bcrypt.CompareHashAndPassword([]byte(hash), []byte(req.Password)) != nil { + s.auditLog("device_auth_failed", userID, "", ip, "pair: wrong password") + jsonErr(w, 401, "invalid credentials") + return + } + devID := make([]byte, 12) + rand.Read(devID) + deviceID := "dev_" + hex.EncodeToString(devID) + token, prefix, suffix := genDeviceToken() + tokenHash := sha256Hex(token) + now := time.Now().UTC().Format(time.RFC3339) + apiKey := make([]byte, 20) + rand.Read(apiKey) + _, err = s.db.Exec(`INSERT INTO server_devices + (id, name, api_key, token_hash, token_prefix, token_suffix, user_id, client_version, last_ip, last_seen, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + deviceID, req.DeviceName, hex.EncodeToString(apiKey), tokenHash, prefix, suffix, + userID, req.ClientVersion, ip, now, now) + if err != nil { + jsonErr(w, 500, err.Error()) + return + } + s.db.Exec("INSERT OR IGNORE INTO server_user_devices (user_id, device_id) VALUES (?, ?)", userID, deviceID) + s.db.Exec("UPDATE server_users SET last_seen=? WHERE id=?", now, userID) + s.pairLimit.reset(ip) + s.auditLog("device_paired", userID, deviceID, ip, "device paired: "+req.DeviceName) + jsonOK(w, map[string]interface{}{ + "user_id": userID, + "device_id": deviceID, + "device_token": token, + "server_time": now, + "initial_sync_cursor": 0, + }) +} + +func (s *Server) handleAuthTest(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + jsonErr(w, 405, "POST required") + return + } + var req struct { + Username string `json:"username"` + Password string `json:"password"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + jsonErr(w, 400, "bad json") + return + } + if req.Username == "" || req.Password == "" { + jsonErr(w, 400, "username and password required") + return + } + var hash string + var confirmed, blocked int + err := s.db.QueryRow("SELECT password_hash, confirmed, blocked FROM server_users WHERE username=? OR email=?", + req.Username, strings.ToLower(req.Username)).Scan(&hash, &confirmed, &blocked) + if err != nil { + jsonErr(w, 401, "invalid credentials") + return + } + if blocked != 0 { + jsonErr(w, 403, "account blocked") + return + } + if confirmed == 0 { + jsonErr(w, 403, "email not confirmed") + return + } + if bcrypt.CompareHashAndPassword([]byte(hash), []byte(req.Password)) != nil { + jsonErr(w, 401, "invalid credentials") + return + } + jsonOK(w, map[string]string{"status": "ok"}) +} + +func (s *Server) handleClientRevoke(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + jsonErr(w, 405, "POST required") + return + } + tok := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") + if tok == "" { + jsonErr(w, 401, "token required") + return + } + hash := sha256Hex(tok) + var deviceID, userID string + err := s.db.QueryRow("SELECT id, user_id FROM server_devices WHERE token_hash=?", hash).Scan(&deviceID, &userID) + if err != nil { + jsonErr(w, 401, "invalid token") + return + } + now := time.Now().UTC().Format(time.RFC3339) + s.db.Exec("UPDATE server_devices SET revoked_at=? WHERE id=?", now, deviceID) + s.auditLog("device_revoked", userID, deviceID, r.RemoteAddr, "device revoked by user") + jsonOK(w, map[string]string{"status": "revoked"}) +} + +func (s *Server) handleClientRevokeDevice(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + jsonErr(w, 405, "POST required") + return + } + tok := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") + if tok == "" { + jsonErr(w, 401, "token required") + return + } + hash := sha256Hex(tok) + var curUserID string + err := s.db.QueryRow("SELECT user_id FROM server_devices WHERE token_hash=?", hash).Scan(&curUserID) + if err != nil || curUserID == "" { + jsonErr(w, 401, "invalid token") + return + } + var req struct { + DeviceID string `json:"device_id"` + Password string `json:"password"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + jsonErr(w, 400, "invalid JSON") + return + } + if req.DeviceID == "" || req.Password == "" { + jsonErr(w, 400, "device_id and password required") + return + } + var pwHash string + err = s.db.QueryRow("SELECT password_hash FROM server_users WHERE id=?", curUserID).Scan(&pwHash) + if err != nil { + jsonErr(w, 403, "access denied") + return + } + if bcrypt.CompareHashAndPassword([]byte(pwHash), []byte(req.Password)) != nil { + jsonErr(w, 403, "wrong password") + return + } + var devUserID string + err = s.db.QueryRow("SELECT user_id FROM server_devices WHERE id=?", req.DeviceID).Scan(&devUserID) + if err != nil { + jsonErr(w, 404, "device not found") + return + } + if devUserID != curUserID { + jsonErr(w, 403, "device does not belong to you") + return + } + now := time.Now().UTC().Format(time.RFC3339) + s.db.Exec("UPDATE server_devices SET revoked_at=? WHERE id=?", now, req.DeviceID) + s.auditLog("device_revoked", curUserID, req.DeviceID, r.RemoteAddr, "device revoked via API") + jsonOK(w, map[string]string{"status": "revoked"}) +} + +func (s *Server) handleClientMe(w http.ResponseWriter, r *http.Request) { + tok := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") + if tok == "" { + jsonErr(w, 401, "token required") + return + } + hash := sha256Hex(tok) + var deviceID, userID, name, clientVer, lastSeen, revokedAt, createdAt string + err := s.db.QueryRow(`SELECT d.id, d.user_id, d.name, COALESCE(d.client_version,''), COALESCE(d.last_seen,''), COALESCE(d.revoked_at,''), d.created_at + FROM server_devices d WHERE d.token_hash=?`, hash). + Scan(&deviceID, &userID, &name, &clientVer, &lastSeen, &revokedAt, &createdAt) + if err != nil { + jsonErr(w, 401, "invalid token") + return + } + var username string + s.db.QueryRow("SELECT username FROM server_users WHERE id=?", userID).Scan(&username) + jsonOK(w, map[string]interface{}{ + "device_id": deviceID, + "user_id": userID, + "username": username, + "device_name": name, + "client_version": clientVer, + "last_seen": lastSeen, + "revoked_at": revokedAt, + "created_at": createdAt, + }) +} + +func (s *Server) handleDeviceRegister(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + jsonErr(w, 405, "POST required") + return + } + var req struct { + Name string `json:"name"` + Username string `json:"username"` + Password string `json:"password"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + jsonErr(w, 400, "invalid JSON") + return + } + if req.Name == "" { + jsonErr(w, 400, "name required") + return + } + if req.Username == "" || req.Password == "" { + jsonErr(w, 401, "username and password required") + return + } + var userID, hash string + var confirmed, blocked int + err := s.db.QueryRow("SELECT id, password_hash, confirmed, blocked FROM server_users WHERE username=? OR email=?", + req.Username, strings.ToLower(req.Username)).Scan(&userID, &hash, &confirmed, &blocked) + if err != nil { + jsonErr(w, 401, "invalid credentials") + return + } + if blocked != 0 { + jsonErr(w, 403, "account blocked") + return + } + if confirmed == 0 { + jsonErr(w, 403, "email not confirmed") + return + } + if bcrypt.CompareHashAndPassword([]byte(hash), []byte(req.Password)) != nil { + jsonErr(w, 401, "invalid credentials") + return + } + b := make([]byte, 20) + rand.Read(b) + apiKey := hex.EncodeToString(b) + deviceID := apiKey[:12] + now := time.Now().UTC().Format(time.RFC3339) + _, err = s.db.Exec( + "INSERT INTO server_devices (id, name, api_key, last_seen, created_at) VALUES (?, ?, ?, ?, ?)", + deviceID, req.Name, apiKey, now, now, + ) + if err != nil { + jsonErr(w, 500, err.Error()) + return + } + s.db.Exec("INSERT OR IGNORE INTO server_user_devices (user_id, device_id) VALUES (?, ?)", userID, deviceID) + jsonOK(w, map[string]interface{}{ + "device_id": deviceID, + "api_key": apiKey, + }) +} + +func (s *Server) handleSyncPush(w http.ResponseWriter, r *http.Request) { + _, _, ok := s.requireAuth(w, r) + if !ok { + return + } + if r.Method != "POST" { + jsonErr(w, 405, "POST required") + return + } + var req struct { + DeviceID string `json:"device_id"` + IdempotencyKey string `json:"idempotency_key"` + Ops []struct { + OpID string `json:"op_id"` + EntityType string `json:"entity_type"` + EntityID string `json:"entity_id"` + OpType string `json:"op_type"` + PayloadJSON string `json:"payload_json"` + ClientSequence int `json:"client_sequence"` + LastSeenServerSeq int `json:"last_seen_server_seq"` + CreatedAt string `json:"created_at"` + } `json:"ops"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + jsonErr(w, 400, "invalid JSON: "+err.Error()) + return + } + if req.IdempotencyKey != "" { + var cachedJSON string + err := s.db.QueryRow("SELECT response_json FROM server_idempotency_keys WHERE idempotency_key=?", req.IdempotencyKey).Scan(&cachedJSON) + if err == nil { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(cachedJSON)) + return + } + } + now := time.Now().UTC().Format(time.RFC3339) + var accepted []string + var conflicts []map[string]interface{} + for _, op := range req.Ops { + if op.OpID == "" || op.EntityType == "" || op.EntityID == "" || op.OpType == "" { + continue + } + if op.LastSeenServerSeq > 0 { + conflictRows, err := s.db.Query(` + SELECT op_id, device_id, op_type, server_sequence FROM server_ops + WHERE entity_type=? AND entity_id=? AND device_id!=? + AND server_sequence > ? AND op_type != 'delete' + ORDER BY server_sequence`, op.EntityType, op.EntityID, req.DeviceID, op.LastSeenServerSeq) + if err == nil { + for conflictRows.Next() { + var cOpID, cDevID, cOpType string + var cSeq int + conflictRows.Scan(&cOpID, &cDevID, &cOpType, &cSeq) + conflicts = append(conflicts, map[string]interface{}{ + "op_id": cOpID, + "device_id": cDevID, + "op_type": cOpType, + "server_sequence": cSeq, + "entity_type": op.EntityType, + "entity_id": op.EntityID, + }) + } + conflictRows.Close() + } + } + res, err := s.db.Exec( + `INSERT OR IGNORE INTO server_ops (op_id, server_sequence, device_id, entity_type, entity_id, op_type, payload_json, idempotency_key, client_sequence, last_seen_server_seq, created_at, pushed_at) + VALUES (?, NULL, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + op.OpID, req.DeviceID, op.EntityType, op.EntityID, op.OpType, op.PayloadJSON, + req.IdempotencyKey, op.ClientSequence, op.LastSeenServerSeq, op.CreatedAt, now, + ) + if err != nil { + continue + } + n, _ := res.RowsAffected() + if n == 0 { + continue + } + seqRes, err := s.db.Exec("INSERT INTO server_revisions (op_id, device_id) VALUES (?, ?)", op.OpID, req.DeviceID) + if err != nil { + continue + } + seq, _ := seqRes.LastInsertId() + s.db.Exec("UPDATE server_ops SET server_sequence=? WHERE op_id=?", seq, op.OpID) + if op.OpType == "delete" { + s.db.Exec(`INSERT OR REPLACE INTO server_tombstones (entity_type, entity_id, op_id, deleted_at) VALUES (?, ?, ?, ?)`, + op.EntityType, op.EntityID, op.OpID, now) + } + accepted = append(accepted, op.OpID) + } + resp := map[string]interface{}{ + "accepted": accepted, + "count": len(accepted), + "conflicts": conflicts, + } + if req.IdempotencyKey != "" { + if respJSON, err := json.Marshal(resp); err == nil { + s.db.Exec("INSERT OR IGNORE INTO server_idempotency_keys (idempotency_key, response_json, created_at) VALUES (?, ?, ?)", + req.IdempotencyKey, string(respJSON), now) + } + } + jsonOK(w, resp) +} + +func (s *Server) handleSyncPull(w http.ResponseWriter, r *http.Request) { + _, _, ok := s.requireAuth(w, r) + if !ok { + return + } + if r.Method != "POST" { + jsonErr(w, 405, "POST required") + return + } + var req struct { + SinceSequence int `json:"since_sequence"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + jsonErr(w, 400, "invalid JSON") + return + } + var serverSeq int + s.db.QueryRow("SELECT COALESCE(MAX(server_sequence), 0) FROM server_ops").Scan(&serverSeq) + rows, err := s.db.Query(` + SELECT op_id, server_sequence, device_id, entity_type, entity_id, op_type, payload_json, created_at + FROM server_ops + WHERE server_sequence > ? AND server_sequence IS NOT NULL + ORDER BY server_sequence`, req.SinceSequence) + if err != nil { + jsonErr(w, 500, err.Error()) + return + } + defer rows.Close() + type opDTO struct { + OpID string `json:"op_id"` + ServerSequence int `json:"server_sequence"` + DeviceID string `json:"device_id"` + EntityType string `json:"entity_type"` + EntityID string `json:"entity_id"` + OpType string `json:"op_type"` + PayloadJSON string `json:"payload_json"` + CreatedAt string `json:"created_at"` + } + var ops []opDTO + for rows.Next() { + var o opDTO + if err := rows.Scan(&o.OpID, &o.ServerSequence, &o.DeviceID, &o.EntityType, &o.EntityID, &o.OpType, &o.PayloadJSON, &o.CreatedAt); err != nil { + continue + } + ops = append(ops, o) + } + jsonOK(w, map[string]interface{}{ + "server_sequence": serverSeq, + "ops": ops, + }) +} + +func (s *Server) handleBlobs(w http.ResponseWriter, r *http.Request) { + _, _, ok := s.requireAuth(w, r) + if !ok { + return + } + switch r.Method { + case "POST": + if err := r.ParseMultipartForm(200 << 20); err != nil { + jsonErr(w, 400, "multipart error: "+err.Error()) + return + } + file, header, err := r.FormFile("file") + if err != nil { + jsonErr(w, 400, "file field required") + return + } + defer file.Close() + data, err := io.ReadAll(file) + if err != nil { + jsonErr(w, 500, "read error") + return + } + hash := sha256Hex(string(data)) + blobDir := filepath.Join(s.blobsDir, hash[:2], hash[2:4]) + if err := os.MkdirAll(blobDir, 0750); err != nil { + jsonErr(w, 500, "mkdir error") + return + } + blobPath := filepath.Join(blobDir, hash) + if err := os.WriteFile(blobPath, data, 0640); err != nil { + jsonErr(w, 500, "write error") + return + } + _ = header + now := time.Now().UTC().Format(time.RFC3339) + s.db.Exec("INSERT OR IGNORE INTO server_blobs (sha256, size, created_at) VALUES (?, ?, ?)", + hash, len(data), now) + jsonOK(w, map[string]interface{}{ + "sha256": hash, + "size": len(data), + }) + case "GET": + shaHex := strings.TrimPrefix(r.URL.Path, "/api/v1/blobs/") + if len(shaHex) != 64 { + jsonErr(w, 400, "invalid SHA-256") + return + } + blobPath := filepath.Join(s.blobsDir, shaHex[:2], shaHex[2:4], shaHex) + if _, err := os.Stat(blobPath); os.IsNotExist(err) { + jsonErr(w, 404, "blob not found") + return + } + data, err := os.ReadFile(blobPath) + if err != nil { + jsonErr(w, 500, "read error") + return + } + w.Header().Set("Content-Type", "application/octet-stream") + w.Header().Set("Content-Disposition", "attachment; filename=\""+shaHex+"\"") + w.Write(data) + default: + jsonErr(w, 405, "method not allowed") + } +} diff --git a/internal/server/handlers_auth.go b/internal/server/handlers_auth.go new file mode 100644 index 0000000..d626694 --- /dev/null +++ b/internal/server/handlers_auth.go @@ -0,0 +1,216 @@ +package server + +import ( + "crypto/rand" + "encoding/hex" + "encoding/json" + "log" + "net/http" + "strings" + "time" + + "golang.org/x/crypto/bcrypt" +) + +func (s *Server) handleRegister(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + jsonErr(w, 405, "POST required") + return + } + var req struct { + Username string `json:"username"` + Email string `json:"email"` + Password string `json:"password"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + jsonErr(w, 400, "invalid JSON") + return + } + if req.Username == "" || req.Email == "" || req.Password == "" { + jsonErr(w, 400, "username, email and password required") + return + } + if err := validatePassword(req.Password); err != "" { + jsonErr(w, 400, err) + return + } + if !strings.Contains(req.Email, "@") || !strings.Contains(req.Email, ".") { + jsonErr(w, 400, "invalid email") + return + } + hash, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost) + if err != nil { + jsonErr(w, 500, "internal error") + return + } + now := time.Now().UTC().Format(time.RFC3339) + id := make([]byte, 12) + rand.Read(id) + userID := hex.EncodeToString(id) + _, err = s.db.Exec( + "INSERT INTO server_users (id, username, email, password_hash, confirmed, created_at) VALUES (?, ?, ?, ?, 0, ?)", + userID, req.Username, strings.ToLower(req.Email), string(hash), now, + ) + if err != nil { + if strings.Contains(err.Error(), "UNIQUE") { + jsonErr(w, 409, "username or email already taken") + return + } + jsonErr(w, 500, err.Error()) + return + } + tok := make([]byte, 24) + rand.Read(tok) + tokenStr := hex.EncodeToString(tok) + exp := time.Now().Add(48 * time.Hour).UTC().Format(time.RFC3339) + s.db.Exec("INSERT INTO server_email_tokens (token, user_id, purpose, expires_at, created_at) VALUES (?, ?, 'confirm', ?, ?)", + tokenStr, userID, exp, now) + log.Printf("register: confirmation token=%s for user %s", tokenStr, req.Username) + jsonOK(w, map[string]string{"status": "confirmation_sent"}) +} + +func (s *Server) handleConfirm(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" { + jsonErr(w, 405, "GET required") + return + } + tokenStr := r.URL.Query().Get("token") + if tokenStr == "" { + jsonErr(w, 400, "token required") + return + } + var userID, expiresAt string + err := s.db.QueryRow("SELECT user_id, expires_at FROM server_email_tokens WHERE token=? AND purpose='confirm'", + tokenStr).Scan(&userID, &expiresAt) + if err != nil { + jsonErr(w, 400, "invalid or expired token") + return + } + exp, err := time.Parse(time.RFC3339, expiresAt) + if err != nil || time.Now().After(exp) { + jsonErr(w, 400, "token expired") + return + } + s.db.Exec("UPDATE server_users SET confirmed=1 WHERE id=?", userID) + log.Printf("confirm: user %s confirmed email", userID) + s.db.Exec("DELETE FROM server_email_tokens WHERE token=?", tokenStr) + jsonOK(w, map[string]string{"status": "confirmed"}) +} + +func (s *Server) handleUserLogin(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + jsonErr(w, 405, "POST required") + return + } + var req struct { + Username string `json:"username"` + Password string `json:"password"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + jsonErr(w, 400, "invalid JSON") + return + } + if req.Username == "" || req.Password == "" { + jsonErr(w, 400, "username and password required") + return + } + var userID, hash string + var confirmed, blocked int + err := s.db.QueryRow("SELECT id, password_hash, confirmed, blocked FROM server_users WHERE username=? OR email=?", + req.Username, strings.ToLower(req.Username)).Scan(&userID, &hash, &confirmed, &blocked) + if err != nil { + jsonErr(w, 401, "invalid credentials") + return + } + if blocked != 0 { + jsonErr(w, 403, "account blocked") + return + } + if confirmed == 0 { + jsonErr(w, 403, "email not confirmed") + return + } + if bcrypt.CompareHashAndPassword([]byte(hash), []byte(req.Password)) != nil { + jsonErr(w, 401, "invalid credentials") + return + } + s.db.Exec("UPDATE server_users SET last_seen=? WHERE id=?", time.Now().UTC().Format(time.RFC3339), userID) + tok := s.userTokens.Create(userID) + jsonOK(w, map[string]string{"token": tok, "user_id": userID}) +} + +func (s *Server) handleForgot(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + jsonErr(w, 405, "POST required") + return + } + var req struct { + Email string `json:"email"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + jsonErr(w, 400, "invalid JSON") + return + } + if req.Email == "" { + jsonErr(w, 400, "email required") + return + } + var userID string + err := s.db.QueryRow("SELECT id FROM server_users WHERE email=?", strings.ToLower(req.Email)).Scan(&userID) + if err != nil { + jsonOK(w, map[string]string{"status": "if email exists, reset link sent"}) + return + } + tok := make([]byte, 24) + rand.Read(tok) + tokenStr := hex.EncodeToString(tok) + exp := time.Now().Add(1 * time.Hour).UTC().Format(time.RFC3339) + now := time.Now().UTC().Format(time.RFC3339) + s.db.Exec("INSERT INTO server_email_tokens (token, user_id, purpose, expires_at, created_at) VALUES (?, ?, 'reset', ?, ?)", + tokenStr, userID, exp, now) + log.Printf("forgot: reset token=%s for user %s", tokenStr, userID) + jsonOK(w, map[string]string{"status": "if email exists, reset link sent"}) +} + +func (s *Server) handleReset(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + jsonErr(w, 405, "POST required") + return + } + var req struct { + Token string `json:"token"` + NewPassword string `json:"new_password"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + jsonErr(w, 400, "invalid JSON") + return + } + if req.Token == "" || req.NewPassword == "" { + jsonErr(w, 400, "token and new_password required") + return + } + if err := validatePassword(req.NewPassword); err != "" { + jsonErr(w, 400, err) + return + } + var userID, expiresAt string + err := s.db.QueryRow("SELECT user_id, expires_at FROM server_email_tokens WHERE token=? AND purpose='reset'", + req.Token).Scan(&userID, &expiresAt) + if err != nil { + jsonErr(w, 400, "invalid or expired token") + return + } + exp, err := time.Parse(time.RFC3339, expiresAt) + if err != nil || time.Now().After(exp) { + jsonErr(w, 400, "token expired") + return + } + hash, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost) + if err != nil { + jsonErr(w, 500, "internal error") + return + } + s.db.Exec("UPDATE server_users SET password_hash=? WHERE id=?", string(hash), userID) + s.db.Exec("DELETE FROM server_email_tokens WHERE token=?", req.Token) + jsonOK(w, map[string]string{"status": "password reset"}) +} diff --git a/internal/server/helpers.go b/internal/server/helpers.go new file mode 100644 index 0000000..8fcda22 --- /dev/null +++ b/internal/server/helpers.go @@ -0,0 +1,24 @@ +package server + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "net/http" +) + +func jsonOK(w http.ResponseWriter, v interface{}) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(v) +} + +func jsonErr(w http.ResponseWriter, code int, msg string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + json.NewEncoder(w).Encode(map[string]string{"error": msg}) +} + +func sha256Hex(s string) string { + h := sha256.Sum256([]byte(s)) + return hex.EncodeToString(h[:]) +} diff --git a/internal/server/middleware.go b/internal/server/middleware.go new file mode 100644 index 0000000..4181858 --- /dev/null +++ b/internal/server/middleware.go @@ -0,0 +1,72 @@ +package server + +import ( + "database/sql" + "net/http" + "strings" + "time" +) + +func (s *Server) requireAuth(w http.ResponseWriter, r *http.Request) (deviceID, userID string, ok bool) { + key := r.Header.Get("Authorization") + key = strings.TrimPrefix(key, "Bearer ") + if key == "" { + key = r.URL.Query().Get("api_key") + } + if key == "" { + jsonErr(w, 401, "API key required") + return "", "", false + } + hash := sha256Hex(key) + var deviceIDVal, userIDVal, revokedAt sql.NullString + err := s.db.QueryRow("SELECT id, user_id, revoked_at FROM server_devices WHERE token_hash=?", hash).Scan(&deviceIDVal, &userIDVal, &revokedAt) + if err == nil { + if revokedAt.Valid && revokedAt.String != "" { + jsonErr(w, 401, "device revoked") + return "", "", false + } + if userIDVal.Valid && userIDVal.String != "" { + var blocked int + s.db.QueryRow("SELECT blocked FROM server_users WHERE id=?", userIDVal.String).Scan(&blocked) + if blocked != 0 { + jsonErr(w, 403, "user blocked") + return "", "", false + } + } + s.db.Exec("UPDATE server_devices SET last_seen=? WHERE id=?", time.Now().UTC().Format(time.RFC3339), deviceIDVal.String) + return deviceIDVal.String, userIDVal.String, true + } + var count int + err = s.db.QueryRow("SELECT COUNT(*) FROM server_devices WHERE api_key=?", key).Scan(&count) + if err != nil || count == 0 { + jsonErr(w, 401, "invalid API key") + return "", "", false + } + return "", "", true +} + +func (s *Server) requireAdmin(w http.ResponseWriter, r *http.Request) bool { + cookie, err := r.Cookie("session") + if err != nil || !s.tokens.Check(cookie.Value) { + http.Redirect(w, r, "/admin/login", http.StatusFound) + return false + } + return true +} + +type PasswordError string + +const ( + ErrPasswordTooShort PasswordError = "PASSWORD_TOO_SHORT" + ErrPasswordTooLong PasswordError = "PASSWORD_TOO_LONG" +) + +func validatePassword(password string) string { + if len(password) < 8 { + return string(ErrPasswordTooShort) + } + if len(password) > 256 { + return string(ErrPasswordTooLong) + } + return "" +} diff --git a/internal/server/routes.go b/internal/server/routes.go new file mode 100644 index 0000000..8307a40 --- /dev/null +++ b/internal/server/routes.go @@ -0,0 +1,20 @@ +package server + +func (s *Server) routes() { + s.mux.HandleFunc("/api/v1/health", s.handleHealth) + s.mux.HandleFunc("/api/v1/device/register", s.handleDeviceRegister) + s.mux.HandleFunc("/api/v1/sync/push", s.handleSyncPush) + s.mux.HandleFunc("/api/v1/sync/pull", s.handleSyncPull) + s.mux.HandleFunc("/api/v1/blobs/", s.handleBlobs) + s.mux.HandleFunc("/api/client/pair", s.handleClientPair) + s.mux.HandleFunc("/api/auth/test", s.handleAuthTest) + s.mux.HandleFunc("/api/client/revoke-current", s.handleClientRevoke) + s.mux.HandleFunc("/api/client/me", s.handleClientMe) + s.mux.HandleFunc("/api/client/revoke-device", s.handleClientRevokeDevice) + s.mux.HandleFunc("/api/v1/auth/register", s.handleRegister) + s.mux.HandleFunc("/api/v1/auth/confirm", s.handleConfirm) + s.mux.HandleFunc("/api/v1/auth/login", s.handleUserLogin) + s.mux.HandleFunc("/api/v1/auth/forgot", s.handleForgot) + s.mux.HandleFunc("/api/v1/auth/reset", s.handleReset) + s.mux.HandleFunc("/", s.handleNotFound) +} diff --git a/internal/server/schema.go b/internal/server/schema.go index 73cad8c..7a74f3e 100644 --- a/internal/server/schema.go +++ b/internal/server/schema.go @@ -62,6 +62,26 @@ CREATE TABLE IF NOT EXISTS server_idempotency_keys ( created_at TEXT NOT NULL ); +CREATE TABLE IF NOT EXISTS server_email_tokens ( + token TEXT PRIMARY KEY, + user_id TEXT NOT NULL, + purpose TEXT NOT NULL, + expires_at TEXT NOT NULL, + created_at TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS server_revisions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + op_id TEXT NOT NULL, + device_id TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS server_blobs ( + sha256 TEXT PRIMARY KEY, + size INTEGER NOT NULL, + created_at TEXT NOT NULL +); + CREATE TABLE IF NOT EXISTS server_audit_log ( id INTEGER PRIMARY KEY AUTOINCREMENT, event_type TEXT NOT NULL, diff --git a/internal/server/server.go b/internal/server/server.go index a4fe072..3ba25e1 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -158,6 +158,10 @@ func NewServer(dbPath, dataDir string, cfg *Config) (*Server, error) { return s, nil } +func (s *Server) SetupRoutes() { + s.routes() +} + func (s *Server) Close() error { return s.db.Close() } diff --git a/internal/server/tokens.go b/internal/server/tokens.go new file mode 100644 index 0000000..b727097 --- /dev/null +++ b/internal/server/tokens.go @@ -0,0 +1,15 @@ +package server + +import ( + "crypto/rand" + "encoding/hex" +) + +func genDeviceToken() (token, prefix, suffix string) { + b := make([]byte, 32) + rand.Read(b) + token = hex.EncodeToString(b) + prefix = token[:8] + suffix = token[len(token)-4:] + return +}