sshkeeper/internal/db/servers.go

409 lines
11 KiB
Go

package db
import (
"database/sql"
"sort"
"strings"
"time"
"github.com/mirivlad/sshkeeper/internal/model"
)
// CreateServer inserts s into the servers table. On success the ID of the new
// row is stored in s.ID. Tags and forwards on s are not persisted here.
func (db *DB) CreateServer(s *model.Server) error {
	result, err := db.conn.Exec(`
INSERT INTO servers (alias, display_name, host, port, user, auth_method, identity_file, proxy_jump, group_name, notes, startup_command)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		s.Alias, s.DisplayName, s.Host, s.Port, s.User, s.AuthMethod, s.IdentityFile, s.ProxyJump, s.GroupName, s.Notes, s.StartupCommand)
	if err != nil {
		return err
	}
	// Report a LastInsertId failure instead of silently leaving s.ID at zero;
	// callers use s.ID to attach tags, and ID 0 would link them to nothing.
	id, err := result.LastInsertId()
	if err != nil {
		return err
	}
	s.ID = id
	return nil
}
// UpdateServer overwrites the editable columns of the server whose alias equals
// s.Alias and sets updated_at to CURRENT_TIMESTAMP. The alias itself is never
// changed here (see UpdateServerByAlias). No error is returned when no row
// matches, because the affected-row count is not checked.
func (db *DB) UpdateServer(s *model.Server) error {
	const stmt = `
UPDATE servers SET
display_name=?, host=?, port=?, user=?, auth_method=?,
identity_file=?, proxy_jump=?, group_name=?, notes=?, startup_command=?, updated_at=CURRENT_TIMESTAMP
WHERE alias=?`
	args := []interface{}{
		s.DisplayName, s.Host, s.Port, s.User, s.AuthMethod,
		s.IdentityFile, s.ProxyJump, s.GroupName, s.Notes, s.StartupCommand,
		s.Alias, // WHERE key
	}
	if _, err := db.conn.Exec(stmt, args...); err != nil {
		return err
	}
	return nil
}
// UpdateServerByAlias locates the server currently stored under oldAlias and
// overwrites all editable columns, including the alias itself (s.Alias), which
// allows renaming. updated_at is set to CURRENT_TIMESTAMP. No error is returned
// when oldAlias matches no row.
func (db *DB) UpdateServerByAlias(oldAlias string, s *model.Server) error {
	const stmt = `
UPDATE servers SET
alias=?, display_name=?, host=?, port=?, user=?, auth_method=?,
identity_file=?, proxy_jump=?, group_name=?, notes=?, startup_command=?, updated_at=CURRENT_TIMESTAMP
WHERE alias=?`
	args := []interface{}{
		s.Alias, s.DisplayName, s.Host, s.Port, s.User, s.AuthMethod,
		s.IdentityFile, s.ProxyJump, s.GroupName, s.Notes, s.StartupCommand,
		oldAlias, // WHERE key: the alias before the rename
	}
	if _, err := db.conn.Exec(stmt, args...); err != nil {
		return err
	}
	return nil
}
// DeleteServer removes the servers row with the given alias. Deleting an alias
// that does not exist is not reported as an error.
func (db *DB) DeleteServer(alias string) error {
	const stmt = "DELETE FROM servers WHERE alias=?"
	if _, err := db.conn.Exec(stmt, alias); err != nil {
		return err
	}
	return nil
}
func (db *DB) GetServer(alias string) (*model.Server, error) {
var s model.Server
var lastConnected, lastTest sql.NullTime
err := db.conn.QueryRow(`
SELECT id, alias, display_name, host, port, user, auth_method,
identity_file, proxy_jump, group_name, notes, startup_command,
created_at, updated_at, last_connected_at,
last_test_at, last_test_status, last_test_error
FROM servers WHERE alias=?`, alias).Scan(
&s.ID, &s.Alias, &s.DisplayName, &s.Host, &s.Port, &s.User, &s.AuthMethod,
&s.IdentityFile, &s.ProxyJump, &s.GroupName, &s.Notes, &s.StartupCommand,
&s.CreatedAt, &s.UpdatedAt, &lastConnected,
&lastTest, &s.LastTestStatus, &s.LastTestError)
if err != nil {
return nil, err
}
if lastConnected.Valid {
s.LastConnectedAt = &lastConnected.Time
}
if lastTest.Valid {
s.LastTestAt = &lastTest.Time
}
tags, err := db.GetServerTags(s.ID)
if err != nil {
return nil, err
}
s.Tags = tags
return &s, nil
}
func (db *DB) ListServers() ([]*model.Server, error) {
rows, err := db.conn.Query(`
SELECT id, alias, display_name, host, port, user, auth_method,
identity_file, proxy_jump, group_name, notes, startup_command,
created_at, updated_at, last_connected_at,
last_test_at, last_test_status, last_test_error
FROM servers ORDER BY alias`)
if err != nil {
return nil, err
}
defer rows.Close()
var servers []*model.Server
for rows.Next() {
var s model.Server
var lastConnected, lastTest sql.NullTime
err := rows.Scan(
&s.ID, &s.Alias, &s.DisplayName, &s.Host, &s.Port, &s.User, &s.AuthMethod,
&s.IdentityFile, &s.ProxyJump, &s.GroupName, &s.Notes, &s.StartupCommand,
&s.CreatedAt, &s.UpdatedAt, &lastConnected,
&lastTest, &s.LastTestStatus, &s.LastTestError)
if err != nil {
return nil, err
}
if lastConnected.Valid {
s.LastConnectedAt = &lastConnected.Time
}
if lastTest.Valid {
s.LastTestAt = &lastTest.Time
}
tags, err := db.GetServerTags(s.ID)
if err != nil {
return nil, err
}
s.Tags = tags
servers = append(servers, &s)
}
return servers, rows.Err()
}
func (db *DB) SearchServers(query string) ([]*model.Server, error) {
pattern := "%" + query + "%"
rows, err := db.conn.Query(`
SELECT id, alias, display_name, host, port, user, auth_method,
identity_file, proxy_jump, group_name, notes, startup_command,
created_at, updated_at, last_connected_at,
last_test_at, last_test_status, last_test_error
FROM servers
WHERE alias LIKE ? OR display_name LIKE ? OR host LIKE ? OR user LIKE ? OR group_name LIKE ?
ORDER BY alias`, pattern, pattern, pattern, pattern, pattern)
if err != nil {
return nil, err
}
defer rows.Close()
var servers []*model.Server
for rows.Next() {
var s model.Server
var lastConnected, lastTest sql.NullTime
err := rows.Scan(
&s.ID, &s.Alias, &s.DisplayName, &s.Host, &s.Port, &s.User, &s.AuthMethod,
&s.IdentityFile, &s.ProxyJump, &s.GroupName, &s.Notes, &s.StartupCommand,
&s.CreatedAt, &s.UpdatedAt, &lastConnected,
&lastTest, &s.LastTestStatus, &s.LastTestError)
if err != nil {
return nil, err
}
if lastConnected.Valid {
s.LastConnectedAt = &lastConnected.Time
}
if lastTest.Valid {
s.LastTestAt = &lastTest.Time
}
tags, err := db.GetServerTags(s.ID)
if err != nil {
return nil, err
}
s.Tags = tags
servers = append(servers, &s)
}
return servers, rows.Err()
}
// UpdateTestResult records the outcome of a connection test for the server with
// the given alias. last_test_at is set to CURRENT_TIMESTAMP, and status and
// testErr are stored as given. No error is returned when alias matches no row.
func (db *DB) UpdateTestResult(alias string, status model.TestStatus, testErr string) error {
	const stmt = `
UPDATE servers SET last_test_at=CURRENT_TIMESTAMP, last_test_status=?, last_test_error=?
WHERE alias=?`
	if _, err := db.conn.Exec(stmt, status, testErr, alias); err != nil {
		return err
	}
	return nil
}
// UpdateLastConnected stamps last_connected_at with CURRENT_TIMESTAMP for the
// server with the given alias. No error is returned when alias matches no row.
func (db *DB) UpdateLastConnected(alias string) error {
	const stmt = "UPDATE servers SET last_connected_at=CURRENT_TIMESTAMP WHERE alias=?"
	if _, err := db.conn.Exec(stmt, alias); err != nil {
		return err
	}
	return nil
}
// Tag methods
// AddTagToServer links the tag named tagName to the server with ID serverID.
// The tag row is created first if no tag with that name exists. Surrounding
// whitespace is trimmed, and a blank name is a no-op. The link is written with
// INSERT OR IGNORE, so re-adding an existing link is silently skipped.
// NOTE(review): that relies on a unique (server_id, tag_id) constraint on
// server_tags; confirm against the schema.
func (db *DB) AddTagToServer(serverID int64, tagName string) error {
	tagName = strings.TrimSpace(tagName)
	if tagName == "" {
		return nil
	}
	var tagID int64
	err := db.conn.QueryRow("SELECT id FROM tags WHERE name=?", tagName).Scan(&tagID)
	switch {
	case err == sql.ErrNoRows:
		result, insErr := db.conn.Exec("INSERT INTO tags (name) VALUES (?)", tagName)
		if insErr != nil {
			return insErr
		}
		// Without a valid ID the link below would point at tag 0.
		if tagID, insErr = result.LastInsertId(); insErr != nil {
			return insErr
		}
	case err != nil:
		return err
	}
	_, err = db.conn.Exec("INSERT OR IGNORE INTO server_tags (server_id, tag_id) VALUES (?, ?)", serverID, tagID)
	return err
}
// SetServerTags replaces the set of tags linked to serverID with tagNames.
// Names are trimmed, and blanks and duplicates are dropped (via
// uniqueCleanStrings) before linking. All existing links for the server are
// deleted first, then the new ones are added one at a time without a
// transaction. An error part-way through therefore leaves only some of the new
// tags attached. Tag rows themselves are never deleted here.
func (db *DB) SetServerTags(serverID int64, tagNames []string) error {
	_, err := db.conn.Exec("DELETE FROM server_tags WHERE server_id=?", serverID)
	if err != nil {
		return err
	}
	cleaned := uniqueCleanStrings(tagNames)
	for i := 0; i < len(cleaned); i++ {
		if err = db.AddTagToServer(serverID, cleaned[i]); err != nil {
			return err
		}
	}
	return nil
}
// ListTags returns every name in the tags table ordered by name, including tags
// not linked to any server. It returns nil when the table is empty.
func (db *DB) ListTags() ([]string, error) {
	rows, err := db.conn.Query("SELECT name FROM tags ORDER BY name")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var names []string
	for rows.Next() {
		var n string
		if scanErr := rows.Scan(&n); scanErr != nil {
			return nil, scanErr
		}
		names = append(names, n)
	}
	return names, rows.Err()
}
// RenameTag changes the name of the tag called oldName to newName. Both names
// are whitespace-trimmed, and the call is a no-op if either ends up empty. No
// error is returned when no tag is named oldName.
func (db *DB) RenameTag(oldName, newName string) error {
	from, to := strings.TrimSpace(oldName), strings.TrimSpace(newName)
	if from == "" || to == "" {
		return nil
	}
	if _, err := db.conn.Exec("UPDATE tags SET name=? WHERE name=?", to, from); err != nil {
		return err
	}
	return nil
}
// DeleteTag removes the tag row with the given (whitespace-trimmed) name. A
// blank name is a no-op. Whether server_tags links are removed too depends on
// the schema's foreign keys (NOTE(review): not visible here; confirm cascade).
func (db *DB) DeleteTag(name string) error {
	trimmed := strings.TrimSpace(name)
	if trimmed == "" {
		return nil
	}
	if _, err := db.conn.Exec("DELETE FROM tags WHERE name=?", trimmed); err != nil {
		return err
	}
	return nil
}
// GetServerTags returns the names of all tags linked to serverID through
// server_tags, ordered by name. It returns nil (not an error) when the server
// has no tags.
func (db *DB) GetServerTags(serverID int64) ([]string, error) {
	const q = `
SELECT t.name FROM tags t
JOIN server_tags st ON st.tag_id = t.id
WHERE st.server_id = ?
ORDER BY t.name`
	rows, err := db.conn.Query(q, serverID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var names []string
	for rows.Next() {
		var tag string
		if scanErr := rows.Scan(&tag); scanErr != nil {
			return nil, scanErr
		}
		names = append(names, tag)
	}
	return names, rows.Err()
}
// Forward methods
// AddForward stores a port-forward definition of kind fwdType for serverID,
// mapping localAddr:localPort to remoteAddr:remotePort. Values are written as
// given, with no validation.
func (db *DB) AddForward(serverID int64, fwdType model.ForwardType, localAddr string, localPort int, remoteAddr string, remotePort int) error {
	const stmt = `
INSERT INTO forwards (server_id, type, local_addr, local_port, remote_addr, remote_port)
VALUES (?, ?, ?, ?, ?, ?)`
	if _, err := db.conn.Exec(stmt, serverID, fwdType, localAddr, localPort, remoteAddr, remotePort); err != nil {
		return err
	}
	return nil
}
// GetForwards returns the port forwards configured for serverID. They are
// ordered by id (presumably insertion order) because SQL gives no ordering
// guarantee without ORDER BY. It returns nil when the server has none, and
// returns no partial list when iteration fails.
func (db *DB) GetForwards(serverID int64) ([]*model.Forward, error) {
	rows, err := db.conn.Query(`
SELECT id, server_id, type, local_addr, local_port, remote_addr, remote_port
FROM forwards WHERE server_id=?
ORDER BY id`, serverID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var forwards []*model.Forward
	for rows.Next() {
		var f model.Forward
		if err := rows.Scan(&f.ID, &f.ServerID, &f.Type, &f.LocalAddr, &f.LocalPort, &f.RemoteAddr, &f.RemotePort); err != nil {
			return nil, err
		}
		forwards = append(forwards, &f)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return forwards, nil
}
// The "time" package is imported by this file but not otherwise referenced in
// it. This blank declaration keeps the import from being rejected as unused.
var _ = time.Time{}
// CreateCommandTemplate inserts t into global_command_templates. On success the
// ID of the new row is stored in t.ID.
func (db *DB) CreateCommandTemplate(t *model.CommandTemplate) error {
	result, err := db.conn.Exec(
		"INSERT INTO global_command_templates (name, command, description) VALUES (?, ?, ?)",
		t.Name, t.Command, t.Description)
	if err != nil {
		return err
	}
	// Report a LastInsertId failure instead of silently leaving t.ID at zero.
	id, err := result.LastInsertId()
	if err != nil {
		return err
	}
	t.ID = id
	return nil
}
// GetCommandTemplate returns the global command template called name. When no
// such template exists the error is sql.ErrNoRows, returned unchanged.
func (db *DB) GetCommandTemplate(name string) (*model.CommandTemplate, error) {
	tmpl := new(model.CommandTemplate)
	row := db.conn.QueryRow(`
SELECT id, name, command, description
FROM global_command_templates WHERE name=?`, name)
	if err := row.Scan(&tmpl.ID, &tmpl.Name, &tmpl.Command, &tmpl.Description); err != nil {
		return nil, err
	}
	return tmpl, nil
}
// ListCommandTemplates returns every global command template ordered by name.
// It returns nil when there are none.
func (db *DB) ListCommandTemplates() ([]*model.CommandTemplate, error) {
	const q = `
SELECT id, name, command, description
FROM global_command_templates
ORDER BY name`
	rows, err := db.conn.Query(q)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var out []*model.CommandTemplate
	for rows.Next() {
		tmpl := new(model.CommandTemplate)
		if scanErr := rows.Scan(&tmpl.ID, &tmpl.Name, &tmpl.Command, &tmpl.Description); scanErr != nil {
			return nil, scanErr
		}
		out = append(out, tmpl)
	}
	return out, rows.Err()
}
// UpdateCommandTemplate overwrites the template currently named oldName with
// t's name, command and description (so it may also rename it), and sets
// updated_at to CURRENT_TIMESTAMP. No error is returned when oldName matches
// no row.
func (db *DB) UpdateCommandTemplate(oldName string, t *model.CommandTemplate) error {
	const stmt = `
UPDATE global_command_templates
SET name=?, command=?, description=?, updated_at=CURRENT_TIMESTAMP
WHERE name=?`
	if _, err := db.conn.Exec(stmt, t.Name, t.Command, t.Description, oldName); err != nil {
		return err
	}
	return nil
}
// DeleteCommandTemplate removes the global command template called name.
// Deleting a name that does not exist is not reported as an error.
func (db *DB) DeleteCommandTemplate(name string) error {
	const stmt = "DELETE FROM global_command_templates WHERE name=?"
	if _, err := db.conn.Exec(stmt, name); err != nil {
		return err
	}
	return nil
}
// uniqueCleanStrings trims surrounding whitespace from each value, drops empty
// results and repeats, and returns the survivors sorted ascending. It returns
// nil when nothing survives. The input slice is not modified.
func uniqueCleanStrings(values []string) []string {
	var out []string
	present := make(map[string]struct{}, len(values))
	for i := range values {
		v := strings.TrimSpace(values[i])
		if v == "" {
			continue
		}
		if _, dup := present[v]; dup {
			continue
		}
		present[v] = struct{}{}
		out = append(out, v)
	}
	sort.Strings(out)
	return out
}
// GetGroups returns the distinct non-empty group names assigned to servers,
// ordered by group_name. Only names are returned, with no per-group server
// counts. Rows whose group_name is '' (or NULL) are excluded. It returns nil
// when no server belongs to a group.
func (db *DB) GetGroups() ([]string, error) {
	rows, err := db.conn.Query(`
SELECT group_name FROM servers
WHERE group_name != ''
GROUP BY group_name
ORDER BY group_name`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var groups []string
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return nil, err
		}
		groups = append(groups, name)
	}
	if err := rows.Err(); err != nil {
		// Don't hand back a partially read list alongside the error.
		return nil, err
	}
	return groups, nil
}
// RenameGroup moves every server whose group_name is oldName to newName and
// sets their updated_at to CURRENT_TIMESTAMP. The names are used verbatim,
// without trimming or emptiness checks.
func (db *DB) RenameGroup(oldName, newName string) error {
	const stmt = "UPDATE servers SET group_name = ?, updated_at = CURRENT_TIMESTAMP WHERE group_name = ?"
	if _, err := db.conn.Exec(stmt, newName, oldName); err != nil {
		return err
	}
	return nil
}
// DeleteGroup clears the group assignment of every server in group name by
// setting group_name to '' and updated_at to CURRENT_TIMESTAMP. The servers
// themselves are kept.
func (db *DB) DeleteGroup(name string) error {
	const stmt = "UPDATE servers SET group_name = '', updated_at = CURRENT_TIMESTAMP WHERE group_name = ?"
	if _, err := db.conn.Exec(stmt, name); err != nil {
		return err
	}
	return nil
}