diff --git a/cmd/verstak-server/server.go b/cmd/verstak-server/server.go index 26c4088..3c61cb6 100644 --- a/cmd/verstak-server/server.go +++ b/cmd/verstak-server/server.go @@ -3,14 +3,18 @@ package main import ( "crypto/rand" "crypto/sha256" + "crypto/tls" "database/sql" "encoding/hex" "encoding/json" "fmt" "io" + "net" "net/http" + "net/smtp" "os" "path/filepath" + "regexp" "strings" "sync" "time" @@ -20,6 +24,8 @@ import ( _ "github.com/mattn/go-sqlite3" ) +var passwordRE = regexp.MustCompile(`^[A-Za-z0-9]+$`) + // ============================================================ // Config // ============================================================ @@ -139,6 +145,45 @@ func (ts *tokenStore) Check(tok string) bool { return true } +// userTokenStore embeds tokenStore but also tracks the user_id per token. +type userTokenStore struct { + mu sync.Mutex + tokens map[string]userTokenEntry +} + +type userTokenEntry struct { + UserID string + ExpiresAt time.Time +} + +func newUserTokenStore() *userTokenStore { + return &userTokenStore{tokens: make(map[string]userTokenEntry)} +} + +func (uts *userTokenStore) Create(userID string) string { + uts.mu.Lock() + defer uts.mu.Unlock() + b := make([]byte, 16) + rand.Read(b) + tok := hex.EncodeToString(b) + uts.tokens[tok] = userTokenEntry{UserID: userID, ExpiresAt: time.Now().Add(24 * time.Hour)} + return tok +} + +func (uts *userTokenStore) Check(tok string) (string, bool) { + uts.mu.Lock() + defer uts.mu.Unlock() + entry, ok := uts.tokens[tok] + if !ok { + return "", false + } + if time.Now().After(entry.ExpiresAt) { + delete(uts.tokens, tok) + return "", false + } + return entry.UserID, true +} + // ============================================================ // Server DB schema // ============================================================ @@ -175,6 +220,34 @@ CREATE TABLE IF NOT EXISTS server_blobs ( size INTEGER NOT NULL, created_at TEXT NOT NULL ); + +CREATE TABLE IF NOT EXISTS server_smtp_config ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS server_users ( + id TEXT PRIMARY KEY, + username TEXT NOT NULL UNIQUE, + email TEXT NOT NULL UNIQUE, + password_hash TEXT NOT NULL, + confirmed INTEGER NOT NULL DEFAULT 0, + created_at TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS server_email_tokens ( + token TEXT PRIMARY KEY, + user_id TEXT NOT NULL, + purpose TEXT NOT NULL, + expires_at TEXT NOT NULL, + created_at TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS server_user_devices ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + PRIMARY KEY (user_id, device_id) +); ` // ============================================================ @@ -182,11 +255,12 @@ CREATE TABLE IF NOT EXISTS server_blobs ( // ============================================================ type Server struct { - db *sql.DB - cfg *Config - tokens *tokenStore - blobsDir string - mux *http.ServeMux + db *sql.DB + cfg *Config + tokens *tokenStore + userTokens *userTokenStore + blobsDir string + mux *http.ServeMux } func NewServer(dbPath, dataDir string, cfg *Config) (*Server, error) { @@ -215,10 +289,11 @@ func NewServer(dbPath, dataDir string, cfg *Config) (*Server, error) { } s := &Server{ - db: db, - cfg: cfg, - tokens: newTokenStore(), - blobsDir: blobsDir, + db: db, + cfg: cfg, + tokens: newTokenStore(), + userTokens: newUserTokenStore(), + blobsDir: blobsDir, } s.mux = s.routes() return s, nil @@ -243,6 +318,12 @@ func (s *Server) routes() *http.ServeMux { mux.HandleFunc("/api/v1/sync/push", s.handleSyncPush) mux.HandleFunc("/api/v1/sync/pull", s.handleSyncPull) mux.HandleFunc("/api/v1/blobs/", s.handleBlobs) + mux.HandleFunc("/api/v1/auth/register", s.handleRegister) + mux.HandleFunc("/api/v1/auth/confirm", s.handleConfirm) + mux.HandleFunc("/api/v1/auth/login", s.handleUserLogin) + mux.HandleFunc("/api/v1/auth/forgot", s.handleForgot) + mux.HandleFunc("/api/v1/auth/reset", s.handleReset) + mux.HandleFunc("/api/v1/user/devices", s.handleUserDevices) mux.HandleFunc("/admin/login", s.handleAdminLogin) mux.HandleFunc("/admin/dashboard", s.handleAdminDashboard) mux.HandleFunc("/admin/", s.handleAdminAPI) @@ -293,6 +374,116 @@ func (s *Server) requireAdmin(w http.ResponseWriter, r *http.Request) bool { return true } +// ============================================================ +// SMTP Config +// ============================================================ + +func (s *Server) smtpGet(key string) string { + var val string + s.db.QueryRow("SELECT value FROM server_smtp_config WHERE key=?", key).Scan(&val) + return val +} + +func (s *Server) smtpSet(key, val string) error { + _, err := s.db.Exec("INSERT OR REPLACE INTO server_smtp_config (key, value) VALUES (?, ?)", key, val) + return err +} + +func (s *Server) smtpSend(to, subject, body string) error { + host := s.smtpGet("smtp_host") + port := s.smtpGet("smtp_port") + user := s.smtpGet("smtp_user") + pass := s.smtpGet("smtp_pass") + from := s.smtpGet("smtp_from") + if host == "" || port == "" || from == "" { + return fmt.Errorf("SMTP not configured") + } + addr := net.JoinHostPort(host, port) + msg := []byte("From: " + from + "\r\n" + + "To: " + to + "\r\n" + + "Subject: " + subject + "\r\n" + + "MIME-Version: 1.0\r\n" + + "Content-Type: text/plain; charset=UTF-8\r\n" + + "\r\n" + body + "\r\n") + if user != "" { + auth := smtp.PlainAuth("", user, pass, host) + if port == "465" { + tlsCfg := &tls.Config{ServerName: host} + conn, err := tls.Dial("tcp", addr, tlsCfg) + if err != nil { + return err + } + cl, err := smtp.NewClient(conn, host) + if err != nil { + return err + } + defer cl.Close() + if err := cl.Auth(auth); err != nil { + return err + } + if err := cl.Mail(from); err != nil { + return err + } + if err := cl.Rcpt(to); err != nil { + return err + } + w, err := cl.Data() + if err != nil { + return err + } + _, err = w.Write(msg) + if err != nil { + return err + } + return w.Close() + } + return smtp.SendMail(addr, auth, from, []string{to}, msg) + } + return smtp.SendMail(addr, nil, from, []string{to}, msg) +} + +// ============================================================ +// User helpers +// ============================================================ + +func validatePassword(password string) string { + if len(password) < 8 { + return "Password must be at least 8 characters" + } + if !passwordRE.MatchString(password) { + return "Password must contain only Latin letters and digits" + } + hasLetter := false + hasDigit := false + for _, ch := range password { + if ch >= 'A' && ch <= 'Z' || ch >= 'a' && ch <= 'z' { + hasLetter = true + } + if ch >= '0' && ch <= '9' { + hasDigit = true + } + } + if !hasLetter || !hasDigit { + return "Password must contain both letters and digits" + } + return "" +} + +func (s *Server) requireUser(w http.ResponseWriter, r *http.Request) (string, bool) { + key := r.Header.Get("Authorization") + key = strings.TrimPrefix(key, "Bearer ") + if key == "" { + jsonErr(w, 401, "authorization required") + return "", false + } + userID, ok := s.userTokens.Check(key) + if !ok { + jsonErr(w, 401, "invalid or expired token") + return "", false + } + return userID, true +} + // ============================================================ // Handlers // ============================================================ @@ -333,35 +524,313 @@ func (s *Server) handleDeviceRegister(w http.ResponseWriter, r *http.Request) { return } if req.Username == "" || req.Password == "" { - jsonErr(w, 401, "admin username and password required") + jsonErr(w, 401, "username and password required") return } - if !s.cfg.CheckAdmin(req.Username, req.Password) { - jsonErr(w, 401, "invalid admin credentials") + + // Look up user by username or email. + var userID, hash string + var confirmed int + err := s.db.QueryRow("SELECT id, password_hash, confirmed FROM server_users WHERE username=? OR email=?", + req.Username, strings.ToLower(req.Username)).Scan(&userID, &hash, &confirmed) + if err != nil { + jsonErr(w, 401, "invalid credentials") + return + } + if confirmed == 0 { + jsonErr(w, 403, "email not confirmed") + return + } + if bcrypt.CompareHashAndPassword([]byte(hash), []byte(req.Password)) != nil { + jsonErr(w, 401, "invalid credentials") return } b := make([]byte, 20) rand.Read(b) apiKey := hex.EncodeToString(b) - + deviceID := apiKey[:12] now := time.Now().UTC().Format(time.RFC3339) - _, err := s.db.Exec( + _, err = s.db.Exec( "INSERT INTO server_devices (id, name, api_key, last_seen, created_at) VALUES (?, ?, ?, ?, ?)", - apiKey[:12], req.Name, apiKey, now, now, + deviceID, req.Name, apiKey, now, now, ) if err != nil { jsonErr(w, 500, err.Error()) return } + // Link device to user. + s.db.Exec("INSERT OR IGNORE INTO server_user_devices (user_id, device_id) VALUES (?, ?)", userID, deviceID) jsonOK(w, map[string]interface{}{ - "device_id": apiKey[:12], + "device_id": deviceID, "api_key": apiKey, }) } +// ============================================================ +// Auth / User handlers +// ============================================================ + +func (s *Server) handleRegister(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + jsonErr(w, 405, "POST required") + return + } + var req struct { + Username string `json:"username"` + Email string `json:"email"` + Password string `json:"password"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + jsonErr(w, 400, "invalid JSON") + return + } + if req.Username == "" || req.Email == "" || req.Password == "" { + jsonErr(w, 400, "username, email and password required") + return + } + if err := validatePassword(req.Password); err != "" { + jsonErr(w, 400, err) + return + } + if !strings.Contains(req.Email, "@") || !strings.Contains(req.Email, ".") { + jsonErr(w, 400, "invalid email") + return + } + hash, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost) + if err != nil { + jsonErr(w, 500, "internal error") + return + } + now := time.Now().UTC().Format(time.RFC3339) + id := make([]byte, 12) + rand.Read(id) + userID := hex.EncodeToString(id) + _, err = s.db.Exec( + "INSERT INTO server_users (id, username, email, password_hash, confirmed, created_at) VALUES (?, ?, ?, ?, 0, ?)", + userID, req.Username, strings.ToLower(req.Email), string(hash), now, + ) + if err != nil { + if strings.Contains(err.Error(), "UNIQUE") { + jsonErr(w, 409, "username or email already taken") + return + } + jsonErr(w, 500, err.Error()) + return + } + // Confirmation token. + tok := make([]byte, 24) + rand.Read(tok) + tokenStr := hex.EncodeToString(tok) + exp := time.Now().Add(48 * time.Hour).UTC().Format(time.RFC3339) + s.db.Exec("INSERT INTO server_email_tokens (token, user_id, purpose, expires_at, created_at) VALUES (?, ?, 'confirm', ?, ?)", + tokenStr, userID, exp, now) + // Try to send email. + host := s.smtpGet("smtp_host") + if host != "" { + confirmURL := fmt.Sprintf("%s/confirm?token=%s", s.smtpGet("server_url"), tokenStr) + if confirmURL == "" { + confirmURL = fmt.Sprintf("/api/v1/auth/confirm?token=%s", tokenStr) + } + body := fmt.Sprintf("Welcome to Verstak Sync!\n\nPlease confirm your email by clicking:\n%s\n\nIf you did not register, ignore this message.", confirmURL) + s.smtpSend(req.Email, "Confirm your Verstak Sync account", body) + } + jsonOK(w, map[string]string{"status": "confirmation_sent"}) +} + +func (s *Server) handleConfirm(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" { + jsonErr(w, 405, "GET required") + return + } + tokenStr := r.URL.Query().Get("token") + if tokenStr == "" { + jsonErr(w, 400, "token required") + return + } + var userID, expiresAt string + err := s.db.QueryRow("SELECT user_id, expires_at FROM server_email_tokens WHERE token=? AND purpose='confirm'", + tokenStr).Scan(&userID, &expiresAt) + if err != nil { + jsonErr(w, 400, "invalid or expired token") + return + } + exp, err := time.Parse(time.RFC3339, expiresAt) + if err != nil || time.Now().After(exp) { + jsonErr(w, 400, "token expired") + return + } + s.db.Exec("UPDATE server_users SET confirmed=1 WHERE id=?", userID) + s.db.Exec("DELETE FROM server_email_tokens WHERE token=?", tokenStr) + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.Write([]byte("

Email confirmed

You can now log in.

")) +} + +func (s *Server) handleUserLogin(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + jsonErr(w, 405, "POST required") + return + } + var req struct { + Username string `json:"username"` + Password string `json:"password"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + jsonErr(w, 400, "invalid JSON") + return + } + if req.Username == "" || req.Password == "" { + jsonErr(w, 400, "username and password required") + return + } + var userID, hash string + var confirmed int + err := s.db.QueryRow("SELECT id, password_hash, confirmed FROM server_users WHERE username=? OR email=?", + req.Username, strings.ToLower(req.Username)).Scan(&userID, &hash, &confirmed) + if err != nil { + jsonErr(w, 401, "invalid credentials") + return + } + if confirmed == 0 { + jsonErr(w, 403, "email not confirmed") + return + } + if bcrypt.CompareHashAndPassword([]byte(hash), []byte(req.Password)) != nil { + jsonErr(w, 401, "invalid credentials") + return + } + tok := s.userTokens.Create(userID) + jsonOK(w, map[string]string{"token": tok, "user_id": userID}) +} + +func (s *Server) handleForgot(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + jsonErr(w, 405, "POST required") + return + } + var req struct { + Email string `json:"email"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + jsonErr(w, 400, "invalid JSON") + return + } + if req.Email == "" { + jsonErr(w, 400, "email required") + return + } + var userID string + err := s.db.QueryRow("SELECT id FROM server_users WHERE email=?", strings.ToLower(req.Email)).Scan(&userID) + if err != nil { + jsonOK(w, map[string]string{"status": "if email exists, reset link sent"}) + return + } + tok := make([]byte, 24) + rand.Read(tok) + tokenStr := hex.EncodeToString(tok) + exp := time.Now().Add(1 * time.Hour).UTC().Format(time.RFC3339) + now := time.Now().UTC().Format(time.RFC3339) + s.db.Exec("INSERT INTO server_email_tokens (token, user_id, purpose, expires_at, created_at) VALUES (?, ?, 'reset', ?, ?)", + tokenStr, userID, exp, now) + host := s.smtpGet("smtp_host") + if host != "" { + resetURL := fmt.Sprintf("%s/reset?token=%s", s.smtpGet("server_url"), tokenStr) + if resetURL == "" { + resetURL = fmt.Sprintf("/api/v1/auth/reset?token=%s", tokenStr) + } + body := fmt.Sprintf("Reset your Verstak Sync password:\n\n%s\n\nThis link expires in 1 hour.", resetURL) + s.smtpSend(req.Email, "Verstak Sync password reset", body) + } + jsonOK(w, map[string]string{"status": "if email exists, reset link sent"}) +} + +func (s *Server) handleReset(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + jsonErr(w, 405, "POST required") + return + } + var req struct { + Token string `json:"token"` + NewPassword string `json:"new_password"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + jsonErr(w, 400, "invalid JSON") + return + } + if req.Token == "" || req.NewPassword == "" { + jsonErr(w, 400, "token and new_password required") + return + } + if err := validatePassword(req.NewPassword); err != "" { + jsonErr(w, 400, err) + return + } + var userID, expiresAt string + err := s.db.QueryRow("SELECT user_id, expires_at FROM server_email_tokens WHERE token=? AND purpose='reset'", + req.Token).Scan(&userID, &expiresAt) + if err != nil { + jsonErr(w, 400, "invalid or expired token") + return + } + exp, err := time.Parse(time.RFC3339, expiresAt) + if err != nil || time.Now().After(exp) { + jsonErr(w, 400, "token expired") + return + } + hash, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost) + if err != nil { + jsonErr(w, 500, "internal error") + return + } + s.db.Exec("UPDATE server_users SET password_hash=? WHERE id=?", string(hash), userID) + s.db.Exec("DELETE FROM server_email_tokens WHERE token=?", req.Token) + jsonOK(w, map[string]string{"status": "password reset"}) +} + +func (s *Server) handleUserDevices(w http.ResponseWriter, r *http.Request) { + userID, ok := s.requireUser(w, r) + if !ok { + return + } + if r.Method != "GET" { + jsonErr(w, 405, "GET required") + return + } + rows, err := s.db.Query(` + SELECT d.id, d.name, d.last_seen, d.created_at + FROM server_devices d + JOIN server_user_devices ud ON ud.device_id = d.id + WHERE ud.user_id = ? + ORDER BY d.created_at`, userID) + if err != nil { + jsonErr(w, 500, err.Error()) + return + } + defer rows.Close() + type deviceDTO struct { + ID string `json:"id"` + Name string `json:"name"` + LastSeen string `json:"last_seen"` + CreatedAt string `json:"created_at"` + } + var devices []deviceDTO + for rows.Next() { + var d deviceDTO + var lastSeen sql.NullString + if err := rows.Scan(&d.ID, &d.Name, &lastSeen, &d.CreatedAt); err != nil { + continue + } + d.LastSeen = lastSeen.String + devices = append(devices, d) + } + if devices == nil { + devices = []deviceDTO{} + } + jsonOK(w, map[string]interface{}{"devices": devices}) +} + func (s *Server) handleSyncPush(w http.ResponseWriter, r *http.Request) { if !s.requireAPIKey(w, r) { return @@ -591,32 +1060,43 @@ func (s *Server) handleAdminDashboard(w http.ResponseWriter, r *http.Request) { s.db.QueryRow("SELECT COUNT(*) FROM server_devices").Scan(&deviceCount) s.db.QueryRow("SELECT COUNT(*) FROM server_ops").Scan(&opsCount) + // Load SMTP config for display. + smtpHost := s.smtpGet("smtp_host") + smtpPort := s.smtpGet("smtp_port") + smtpUser := s.smtpGet("smtp_user") + smtpFrom := s.smtpGet("smtp_from") + srvURL := s.smtpGet("server_url") + html := fmt.Sprintf(` Verstak Sync — Admin

Verstak Sync Server

@@ -624,6 +1104,9 @@ form{margin-top:8px}
Устройств: %d
Операций: %d
+ +
+

API-ключи

-

Новый ключ

-
- +

Новый ключ

+ +
-

Health check

-`, deviceCount, opsCount) +
+ +
+

SMTP (для писем)

+
+
+
+
+
+
+
+
+
+
+
+
+
+ +

Health check

+`, deviceCount, opsCount, smtpHost, smtpPort, smtpUser, smtpFrom, srvURL) w.Write([]byte(html)) } @@ -706,8 +1207,22 @@ func (s *Server) handleAdminAPI(w http.ResponseWriter, r *http.Request) { jsonErr(w, 500, err.Error()) return } + s.db.Exec("DELETE FROM server_user_devices WHERE device_id=?", id) jsonOK(w, map[string]string{"status": "deleted"}) + case path == "/api/smtp" && r.Method == "POST": + if err := r.ParseForm(); err != nil { + jsonErr(w, 400, "bad form") + return + } + for _, key := range []string{"smtp_host", "smtp_port", "smtp_user", "smtp_pass", "smtp_from", "server_url"} { + val := r.FormValue(key) + if val != "" { + s.smtpSet(key, val) + } + } + http.Redirect(w, r, "/admin/dashboard", http.StatusFound) + default: jsonErr(w, 404, "not found") }