Add TUI tag and template management

This commit is contained in:
mirivlad 2026-05-28 17:36:50 +08:00
parent 0bd4463819
commit c2b0e57f3a
17 changed files with 2129 additions and 195 deletions

View File

@ -65,7 +65,13 @@ sshkeeper run web "uptime"
# Groups and templates # Groups and templates
sshkeeper group list sshkeeper group list
sshkeeper template list web sshkeeper template list
sshkeeper template add uptime "uptime"
sshkeeper run-template web uptime
# Tags and startup command
sshkeeper add web --host 10.0.0.10 --user deploy --auth key --tags prod,web --startup-command "tmux attach -t ops"
sshkeeper edit web --tags prod,web --startup-command "tmux attach -t ops"
# OpenSSH config # OpenSSH config
sshkeeper ssh-config generate sshkeeper ssh-config generate
@ -86,13 +92,22 @@ Running `sshkeeper` without arguments opens the TUI.
| Key | Action | | Key | Action |
| --- | --- | | --- | --- |
| Enter | Connect to selected server | | Enter | Connect to selected server |
| Ctrl+R | Pick and run a command template on the selected servers |
| Insert | Select or unselect a server, then move to the next row |
| Ctrl+A | Add server | | Ctrl+A | Add server |
| Ctrl+E | Edit server | | Ctrl+E | Edit server |
| Ctrl+D | Delete server | | Ctrl+D | Delete server |
| Ctrl+T | Test connection | | Ctrl+T | Test connection |
| Ctrl+F | Search | | Ctrl+F | Search |
| Ctrl+G | Manage tags |
| Ctrl+P | Manage global command templates |
| Ctrl+Q / Ctrl+C | Quit | | Ctrl+Q / Ctrl+C | Quit |
Templates are global entities and can run on any server. Foreground template
runs leave the TUI, show the SSH session in the terminal, and then return to the
TUI. Background runs execute the command and show per-server output in a result
screen.
In add/edit forms: In add/edit forms:
| Key | Action | | Key | Action |

View File

@ -24,6 +24,7 @@ var addFlags struct {
groupName string groupName string
displayName string displayName string
notes string notes string
startup string
tags string tags string
} }
@ -50,16 +51,17 @@ func addInteractive() error {
func addNonInteractive(alias string) error { func addNonInteractive(alias string) error {
server := &model.Server{ server := &model.Server{
Alias: alias, Alias: alias,
DisplayName: addFlags.displayName, DisplayName: addFlags.displayName,
Host: addFlags.host, Host: addFlags.host,
Port: addFlags.port, Port: addFlags.port,
User: addFlags.user, User: addFlags.user,
AuthMethod: model.AuthMethod(addFlags.authMethod), AuthMethod: model.AuthMethod(addFlags.authMethod),
IdentityFile: addFlags.identityFile, IdentityFile: addFlags.identityFile,
ProxyJump: addFlags.proxyJump, ProxyJump: addFlags.proxyJump,
GroupName: addFlags.groupName, GroupName: addFlags.groupName,
Notes: addFlags.notes, Notes: addFlags.notes,
StartupCommand: addFlags.startup,
} }
if server.Port == 0 { if server.Port == 0 {
@ -118,14 +120,11 @@ func saveServerWithOptionalSecret(server *model.Server) error {
} }
if addFlags.tags != "" { if addFlags.tags != "" {
tagList := strings.Split(addFlags.tags, ",") server.Tags = strings.Split(addFlags.tags, ",")
for _, t := range tagList { }
t = strings.TrimSpace(t) if len(server.Tags) > 0 {
if t != "" { if err := appDB.SetServerTags(server.ID, server.Tags); err != nil {
if err := appDB.AddTagToServer(server.ID, t); err != nil { return fmt.Errorf("set tags: %w", err)
return fmt.Errorf("add tag %s: %w", t, err)
}
}
} }
} }
@ -184,18 +183,28 @@ func promptServerForAdd(in io.Reader, out io.Writer) (*model.Server, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
startupCommand, err := promptOptional(reader, out, "Startup command", "")
if err != nil {
return nil, err
}
tagsText, err := promptOptional(reader, out, "Tags (comma-separated)", "")
if err != nil {
return nil, err
}
return &model.Server{ return &model.Server{
Alias: alias, Alias: alias,
DisplayName: displayName, DisplayName: displayName,
Host: host, Host: host,
Port: port, Port: port,
User: user, User: user,
AuthMethod: authMethod, AuthMethod: authMethod,
IdentityFile: identityFile, IdentityFile: identityFile,
ProxyJump: proxyJump, ProxyJump: proxyJump,
GroupName: groupName, GroupName: groupName,
Notes: notes, Notes: notes,
StartupCommand: startupCommand,
Tags: strings.Split(tagsText, ","),
}, nil }, nil
} }
@ -248,5 +257,6 @@ func init() {
addCmd.Flags().StringVar(&addFlags.groupName, "group", "", "Server group") addCmd.Flags().StringVar(&addFlags.groupName, "group", "", "Server group")
addCmd.Flags().StringVar(&addFlags.displayName, "display-name", "", "Display name") addCmd.Flags().StringVar(&addFlags.displayName, "display-name", "", "Display name")
addCmd.Flags().StringVar(&addFlags.notes, "notes", "", "Notes") addCmd.Flags().StringVar(&addFlags.notes, "notes", "", "Notes")
addCmd.Flags().StringVar(&addFlags.startup, "startup-command", "", "Command to run after connecting")
addCmd.Flags().StringVar(&addFlags.tags, "tags", "", "Comma-separated tags") addCmd.Flags().StringVar(&addFlags.tags, "tags", "", "Comma-separated tags")
} }

View File

@ -20,6 +20,8 @@ func TestPromptServerForAddCollectsInteractiveFields(t *testing.T) {
"bastion", "bastion",
"prod", "prod",
"critical host", "critical host",
"tmux attach -t prod",
"prod,web",
"", "",
}, "\n")) }, "\n"))
var output bytes.Buffer var output bytes.Buffer
@ -38,9 +40,13 @@ func TestPromptServerForAddCollectsInteractiveFields(t *testing.T) {
server.IdentityFile != "~/.ssh/id_prod" || server.IdentityFile != "~/.ssh/id_prod" ||
server.ProxyJump != "bastion" || server.ProxyJump != "bastion" ||
server.GroupName != "prod" || server.GroupName != "prod" ||
server.Notes != "critical host" { server.Notes != "critical host" ||
server.StartupCommand != "tmux attach -t prod" {
t.Fatalf("unexpected server: %#v", server) t.Fatalf("unexpected server: %#v", server)
} }
if strings.Join(server.Tags, ",") != "prod,web" {
t.Fatalf("tags = %#v", server.Tags)
}
if strings.Contains(output.String(), "not yet implemented") { if strings.Contains(output.String(), "not yet implemented") {
t.Fatalf("interactive add should not report unimplemented:\n%s", output.String()) t.Fatalf("interactive add should not report unimplemented:\n%s", output.String())
} }
@ -59,6 +65,8 @@ func TestPromptServerForAddAppliesDefaults(t *testing.T) {
"", "",
"", "",
"", "",
"",
"",
}, "\n")) }, "\n"))
var output bytes.Buffer var output bytes.Buffer

View File

@ -2,6 +2,7 @@ package cmd
import ( import (
"fmt" "fmt"
"strings"
"syscall" "syscall"
"github.com/mirivlad/sshkeeper/internal/model" "github.com/mirivlad/sshkeeper/internal/model"
@ -49,6 +50,13 @@ var editCmd = &cobra.Command{
if parsedNotes != "" { if parsedNotes != "" {
server.Notes = parsedNotes server.Notes = parsedNotes
} }
if parsedStartup != "" {
server.StartupCommand = parsedStartup
}
tagsChanged := cmd.Flags().Changed("tags")
if tagsChanged {
server.Tags = strings.Split(parsedTags, ",")
}
if parsedAuth != "" && oldAuthMethod != server.AuthMethod { if parsedAuth != "" && oldAuthMethod != server.AuthMethod {
v := getOrCreateVault() v := getOrCreateVault()
@ -88,6 +96,11 @@ var editCmd = &cobra.Command{
if err := appDB.UpdateServer(server); err != nil { if err := appDB.UpdateServer(server); err != nil {
return fmt.Errorf("update server: %w", err) return fmt.Errorf("update server: %w", err)
} }
if tagsChanged {
if err := appDB.SetServerTags(server.ID, server.Tags); err != nil {
return fmt.Errorf("set tags: %w", err)
}
}
fmt.Println("Saved.") fmt.Println("Saved.")
return nil return nil
@ -104,6 +117,8 @@ var (
parsedGroup string parsedGroup string
parsedDisplayName string parsedDisplayName string
parsedNotes string parsedNotes string
parsedStartup string
parsedTags string
) )
func init() { func init() {
@ -116,4 +131,6 @@ func init() {
editCmd.Flags().StringVar(&parsedGroup, "group", "", "Server group") editCmd.Flags().StringVar(&parsedGroup, "group", "", "Server group")
editCmd.Flags().StringVar(&parsedDisplayName, "display-name", "", "Display name") editCmd.Flags().StringVar(&parsedDisplayName, "display-name", "", "Display name")
editCmd.Flags().StringVar(&parsedNotes, "notes", "", "Notes") editCmd.Flags().StringVar(&parsedNotes, "notes", "", "Notes")
editCmd.Flags().StringVar(&parsedStartup, "startup-command", "", "Command to run after connecting")
editCmd.Flags().StringVar(&parsedTags, "tags", "", "Comma-separated tags")
} }

View File

@ -2,8 +2,6 @@ package cmd
import ( import (
"fmt" "fmt"
"os"
"os/exec"
"strings" "strings"
"github.com/mirivlad/sshkeeper/internal/model" "github.com/mirivlad/sshkeeper/internal/model"
@ -74,46 +72,23 @@ var runCmd = &cobra.Command{
return fmt.Errorf("server not found: %s", alias) return fmt.Errorf("server not found: %s", alias)
} }
if server.AuthMethod == model.AuthPassword || server.AuthMethod == model.AuthKeyPassphrase { return runCommandOnServer(server, command)
return runWithSecret(server, command)
}
// For key/agent auth — direct execution
sshArgs := ssh.BuildSSHArgs(server)
sshArgs = append(sshArgs, command)
sshCmd := exec.Command(cfg.SSH.Binary, sshArgs...)
sshCmd.Stdin = os.Stdin
sshCmd.Stdout = os.Stdout
sshCmd.Stderr = os.Stderr
if err := sshCmd.Start(); err != nil {
return fmt.Errorf("start ssh: %w", err)
}
return sshCmd.Wait()
}, },
} }
// runWithSecret runs a command on a server through the PTY prompt handler. func runCommandOnServer(server *model.Server, command string) error {
func runWithSecret(server *model.Server, command string) error { return ssh.RunCommand(cfg, server, commandVaultFunc, command)
}
func commandVaultFunc(serverAlias string, secretType string) (string, error) {
v := getOrCreateVault() v := getOrCreateVault()
if !v.IsUnlocked() { if !v.IsUnlocked() {
return fmt.Errorf("%s", vaultLockedProcessMessage()) return "", fmt.Errorf("%s", vaultLockedProcessMessage())
} }
vaultKey := fmt.Sprintf("server:%s:%s", serverAlias, secretType)
secretType := "ssh_password" data, err := v.Get(vaultKey)
if server.AuthMethod == model.AuthKeyPassphrase {
secretType = "key_passphrase"
}
vaultKey := fmt.Sprintf("server:%s:%s", server.Alias, secretType)
secret, err := v.Get(vaultKey)
if err != nil { if err != nil {
return fmt.Errorf("get %s from vault: %w", secretType, err) return "", err
} }
return string(data), nil
sshArgs := ssh.BuildSSHArgs(server)
sshArgs = append(sshArgs, command)
return ssh.ConnectWithPassword(cfg.SSH.Binary, sshArgs, string(secret))
} }

View File

@ -2,6 +2,7 @@ package cmd
import ( import (
"fmt" "fmt"
"strings"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -32,6 +33,12 @@ var showCmd = &cobra.Command{
if server.GroupName != "" { if server.GroupName != "" {
fmt.Printf("Group: %s\n", server.GroupName) fmt.Printf("Group: %s\n", server.GroupName)
} }
if len(server.Tags) > 0 {
fmt.Printf("Tags: %s\n", strings.Join(server.Tags, ", "))
}
if server.StartupCommand != "" {
fmt.Printf("Startup Cmd: %s\n", server.StartupCommand)
}
if server.Notes != "" { if server.Notes != "" {
fmt.Printf("Notes: %s\n", server.Notes) fmt.Printf("Notes: %s\n", server.Notes)
} }

View File

@ -2,30 +2,22 @@ package cmd
import ( import (
"fmt" "fmt"
"os"
"os/exec"
"github.com/mirivlad/sshkeeper/internal/model"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/mirivlad/sshkeeper/internal/ssh"
) )
var templateCmd = &cobra.Command{ var templateCmd = &cobra.Command{
Use: "template", Use: "template",
Short: "Command template management", Short: "Global command template management",
} }
var templateListCmd = &cobra.Command{ var templateListCmd = &cobra.Command{
Use: "list <alias>", Use: "list",
Short: "List command templates for a server", Short: "List global command templates",
Args: cobra.ExactArgs(1), Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
alias := args[0] templates, err := appDB.ListCommandTemplates()
_, err := appDB.GetServer(alias)
if err != nil {
return fmt.Errorf("server not found: %s", alias)
}
templates, err := appDB.GetCommandTemplates(alias)
if err != nil { if err != nil {
return fmt.Errorf("list templates: %w", err) return fmt.Errorf("list templates: %w", err)
} }
@ -36,27 +28,19 @@ var templateListCmd = &cobra.Command{
} }
for _, t := range templates { for _, t := range templates {
fmt.Printf(" %-15s %s\n", t.Name, t.Command) fmt.Printf(" %-20s %s\n", t.Name, t.Command)
} }
return nil return nil
}, },
} }
var templateAddCmd = &cobra.Command{ var templateAddCmd = &cobra.Command{
Use: "add <alias> <name> <command>", Use: "add <name> <command>",
Short: "Add a command template", Short: "Add a global command template",
Args: cobra.ExactArgs(3), Args: cobra.ExactArgs(2),
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
alias := args[0] t := &model.CommandTemplate{Name: args[0], Command: args[1]}
name := args[1] if err := appDB.CreateCommandTemplate(t); err != nil {
command := args[2]
server, err := appDB.GetServer(alias)
if err != nil {
return fmt.Errorf("server not found: %s", alias)
}
if err := appDB.AddCommandTemplate(server.ID, name, command); err != nil {
return fmt.Errorf("add template: %w", err) return fmt.Errorf("add template: %w", err)
} }
@ -65,9 +49,38 @@ var templateAddCmd = &cobra.Command{
}, },
} }
var templateEditCmd = &cobra.Command{
Use: "edit <old-name> <name> <command>",
Short: "Edit a global command template",
Args: cobra.ExactArgs(3),
RunE: func(cmd *cobra.Command, args []string) error {
t := &model.CommandTemplate{Name: args[1], Command: args[2]}
if err := appDB.UpdateCommandTemplate(args[0], t); err != nil {
return fmt.Errorf("edit template: %w", err)
}
fmt.Println("Template saved.")
return nil
},
}
var templateDeleteCmd = &cobra.Command{
Use: "delete <name>",
Short: "Delete a global command template",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
if err := appDB.DeleteCommandTemplate(args[0]); err != nil {
return fmt.Errorf("delete template: %w", err)
}
fmt.Println("Template deleted.")
return nil
},
}
var runTemplateCmd = &cobra.Command{ var runTemplateCmd = &cobra.Command{
Use: "run-template <alias> <template>", Use: "run-template <alias> <template>",
Short: "Run a command template on a server", Short: "Run a global command template on a server",
Args: cobra.ExactArgs(2), Args: cobra.ExactArgs(2),
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
alias := args[0] alias := args[0]
@ -78,43 +91,19 @@ var runTemplateCmd = &cobra.Command{
return fmt.Errorf("server not found: %s", alias) return fmt.Errorf("server not found: %s", alias)
} }
templates, err := appDB.GetCommandTemplates(alias) template, err := appDB.GetCommandTemplate(templateName)
if err != nil { if err != nil {
return fmt.Errorf("list templates: %w", err)
}
var command string
for _, t := range templates {
if t.Name == templateName {
command = t.Command
break
}
}
if command == "" {
return fmt.Errorf("template not found: %s", templateName) return fmt.Errorf("template not found: %s", templateName)
} }
fmt.Printf("Running '%s' on %s...\n", command, alias) fmt.Printf("Running '%s' on %s...\n", template.Command, alias)
return runCommandOnServer(server, template.Command)
// Build ssh args with the command
sshArgs := ssh.BuildSSHArgs(server)
sshArgs = append(sshArgs, command)
sshCmd := exec.Command(cfg.SSH.Binary, sshArgs...)
sshCmd.Stdin = os.Stdin
sshCmd.Stdout = os.Stdout
sshCmd.Stderr = os.Stderr
if err := sshCmd.Start(); err != nil {
return fmt.Errorf("start ssh: %w", err)
}
return sshCmd.Wait()
}, },
} }
func init() { func init() {
templateCmd.AddCommand(templateListCmd) templateCmd.AddCommand(templateListCmd)
templateCmd.AddCommand(templateAddCmd) templateCmd.AddCommand(templateAddCmd)
templateCmd.AddCommand(templateEditCmd)
templateCmd.AddCommand(templateDeleteCmd)
} }

View File

@ -72,9 +72,15 @@ func runTUI() error {
existing, _ := appDB.GetServer(lookupAlias) existing, _ := appDB.GetServer(lookupAlias)
if existing != nil { if existing != nil {
server.ID = existing.ID server.ID = existing.ID
return appDB.UpdateServerByAlias(existing.Alias, server) if err := appDB.UpdateServerByAlias(existing.Alias, server); err != nil {
return err
}
return appDB.SetServerTags(existing.ID, server.Tags)
} }
return appDB.CreateServer(server) if err := appDB.CreateServer(server); err != nil {
return err
}
return appDB.SetServerTags(server.ID, server.Tags)
} }
tui.GetGroups = func() ([]string, error) { tui.GetGroups = func() ([]string, error) {
@ -86,6 +92,38 @@ func runTUI() error {
tui.DeleteGroup = func(name string) error { tui.DeleteGroup = func(name string) error {
return appDB.DeleteGroup(name) return appDB.DeleteGroup(name)
} }
tui.ListTags = func() ([]string, error) {
return appDB.ListTags()
}
tui.RenameTag = func(oldName, newName string) error {
return appDB.RenameTag(oldName, newName)
}
tui.DeleteTag = func(name string) error {
return appDB.DeleteTag(name)
}
tui.SetServerTags = func(server *model.Server, tags []string) error {
server.Tags = tags
return appDB.SetServerTags(server.ID, tags)
}
tui.ListCommandTemplates = func() ([]*model.CommandTemplate, error) {
return appDB.ListCommandTemplates()
}
tui.SaveCommandTemplate = func(oldName string, template *model.CommandTemplate) error {
if oldName == "" {
return appDB.CreateCommandTemplate(template)
}
return appDB.UpdateCommandTemplate(oldName, template)
}
tui.DeleteCommandTemplate = func(name string) error {
return appDB.DeleteCommandTemplate(name)
}
tui.RunTemplateBackground = func(server *model.Server, command string) (string, error) {
fresh, err := appDB.GetServer(server.Alias)
if err != nil {
return "", err
}
return ssh.RunCommandOutput(cfg, fresh, vaultFunc, command)
}
tui.UpdateTestResult = func(alias string, status model.TestStatus, testErr string) error { tui.UpdateTestResult = func(alias string, status model.TestStatus, testErr string) error {
return appDB.UpdateTestResult(alias, status, testErr) return appDB.UpdateTestResult(alias, status, testErr)
} }
@ -139,6 +177,27 @@ func runTUI() error {
servers, _ = appDB.ListServers() servers, _ = appDB.ListServers()
continue continue
} }
if result != nil && result.Action == "run_template_foreground" && len(result.Servers) > 0 {
for _, server := range result.Servers {
fresh, err := appDB.GetServer(server.Alias)
if err != nil {
fmt.Fprintf(os.Stderr, "Server not found: %s\n", server.Alias)
continue
}
fmt.Printf("Running template %q on %s...\n", result.TemplateName, fresh.Alias)
if err := ssh.RunCommand(cfg, fresh, vaultFunc, result.Command); err != nil {
fmt.Fprintf(os.Stderr, "Command error on %s: %v\n", fresh.Alias, err)
}
appDB.UpdateLastConnected(fresh.Alias)
}
fmt.Println("\n[Press Enter to return to sshkeeper]")
buf := make([]byte, 1)
os.Stdin.Read(buf)
servers, _ = appDB.ListServers()
continue
}
// Normal quit (q or Esc) // Normal quit (q or Esc)
return nil return nil

View File

@ -0,0 +1,51 @@
# TUI Tags And Global Templates Implementation Plan
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
**Goal:** Add full TUI support for tags and global command templates, including multi-server template execution.
**Architecture:** Move command templates to global entities while preserving legacy server-scoped rows as importable data. Add server `startup_command`, richer tag CRUD helpers, and TUI screens for tag/template management. Template execution uses existing OpenSSH command construction, with foreground execution returning control to the terminal and background execution collecting per-server results in a TUI results screen.
**Tech Stack:** Go, Bubble Tea, Bubbles list/textinput, SQLite, existing Cobra CLI and TUI package.
---
### Task 1: Data Model And CLI
- [x] Add `StartupCommand` to `model.Server`.
- [x] Add DB schema migration helpers for `servers.startup_command` and global `global_command_templates`.
- [x] Add DB CRUD for global templates and tag management.
- [x] Update CLI `template` commands to manage global templates.
- [x] Update `run-template` to run a global template on a server.
- [x] Add focused DB and CLI tests.
### Task 2: TUI Tags
- [x] Add tags to server form as comma-separated input.
- [x] Persist tags on add/edit.
- [x] Show tags in selected-server details.
- [x] Add a tag management screen with list, add, rename, delete, assign/remove selected servers.
- [x] Add tests for tag rendering and callbacks.
### Task 3: TUI Templates And Selection
- [x] Add multi-selection state toggled by Insert.
- [x] Show selected markers and selected count in list footer.
- [x] Add global template picker opened by Shift+Enter.
- [x] Add template manager screen with list/add/edit/delete.
- [x] Add startup command field to server form.
- [x] Add tests for selection and template screen state.
### Task 4: Template Execution
- [x] Add foreground template execution result path that exits TUI and lets caller run SSH.
- [x] Add background execution callback and results screen.
- [x] Run background template on selected servers, collecting stdout/stderr/status.
- [x] Add tests for run mode selection and result rendering.
### Task 5: Verification
- [x] Run `env GOCACHE=/tmp/sshkeeper-go-cache go test ./...`.
- [x] Run `env GOCACHE=/tmp/sshkeeper-go-cache go build -o bin/sshkeeper .`.
- [ ] Smoke-test CLI template CRUD on temporary XDG paths.
- [ ] Commit final implementation.

View File

@ -34,12 +34,71 @@ func Open(dataDir string) (*DB, error) {
if err := db.migrate(); err != nil { if err := db.migrate(); err != nil {
return nil, fmt.Errorf("migrate: %w", err) return nil, fmt.Errorf("migrate: %w", err)
} }
if err := db.ensureSchema(); err != nil {
return nil, fmt.Errorf("ensure schema: %w", err)
}
os.Chmod(dbPath, 0600) os.Chmod(dbPath, 0600)
return db, nil return db, nil
} }
func (db *DB) ensureSchema() error {
hasStartupCommand, err := db.hasColumn("servers", "startup_command")
if err != nil {
return err
}
if !hasStartupCommand {
if _, err := db.conn.Exec("ALTER TABLE servers ADD COLUMN startup_command TEXT NOT NULL DEFAULT ''"); err != nil {
return fmt.Errorf("add startup_command: %w", err)
}
}
_, err = db.conn.Exec(`
CREATE TABLE IF NOT EXISTS global_command_templates (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE,
command TEXT NOT NULL,
description TEXT NOT NULL DEFAULT '',
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
)`)
if err != nil {
return fmt.Errorf("create global templates: %w", err)
}
_, err = db.conn.Exec(`
INSERT OR IGNORE INTO global_command_templates (name, command)
SELECT name, command FROM command_templates`)
if err != nil {
return fmt.Errorf("copy legacy templates: %w", err)
}
return nil
}
func (db *DB) hasColumn(tableName, columnName string) (bool, error) {
rows, err := db.conn.Query("PRAGMA table_info(" + tableName + ")")
if err != nil {
return false, err
}
defer rows.Close()
for rows.Next() {
var cid int
var name, typ string
var notNull int
var defaultValue sql.NullString
var pk int
if err := rows.Scan(&cid, &name, &typ, &notNull, &defaultValue, &pk); err != nil {
return false, err
}
if name == columnName {
return true, nil
}
}
return false, rows.Err()
}
func (db *DB) Close() error { func (db *DB) Close() error {
return db.conn.Close() return db.conn.Close()
} }

View File

@ -2,6 +2,8 @@ package db
import ( import (
"database/sql" "database/sql"
"sort"
"strings"
"time" "time"
"github.com/mirivlad/sshkeeper/internal/model" "github.com/mirivlad/sshkeeper/internal/model"
@ -9,9 +11,9 @@ import (
func (db *DB) CreateServer(s *model.Server) error { func (db *DB) CreateServer(s *model.Server) error {
result, err := db.conn.Exec(` result, err := db.conn.Exec(`
INSERT INTO servers (alias, display_name, host, port, user, auth_method, identity_file, proxy_jump, group_name, notes) INSERT INTO servers (alias, display_name, host, port, user, auth_method, identity_file, proxy_jump, group_name, notes, startup_command)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
s.Alias, s.DisplayName, s.Host, s.Port, s.User, s.AuthMethod, s.IdentityFile, s.ProxyJump, s.GroupName, s.Notes) s.Alias, s.DisplayName, s.Host, s.Port, s.User, s.AuthMethod, s.IdentityFile, s.ProxyJump, s.GroupName, s.Notes, s.StartupCommand)
if err != nil { if err != nil {
return err return err
} }
@ -23,10 +25,10 @@ func (db *DB) UpdateServer(s *model.Server) error {
_, err := db.conn.Exec(` _, err := db.conn.Exec(`
UPDATE servers SET UPDATE servers SET
display_name=?, host=?, port=?, user=?, auth_method=?, display_name=?, host=?, port=?, user=?, auth_method=?,
identity_file=?, proxy_jump=?, group_name=?, notes=?, updated_at=CURRENT_TIMESTAMP identity_file=?, proxy_jump=?, group_name=?, notes=?, startup_command=?, updated_at=CURRENT_TIMESTAMP
WHERE alias=?`, WHERE alias=?`,
s.DisplayName, s.Host, s.Port, s.User, s.AuthMethod, s.DisplayName, s.Host, s.Port, s.User, s.AuthMethod,
s.IdentityFile, s.ProxyJump, s.GroupName, s.Notes, s.Alias) s.IdentityFile, s.ProxyJump, s.GroupName, s.Notes, s.StartupCommand, s.Alias)
return err return err
} }
@ -34,10 +36,10 @@ func (db *DB) UpdateServerByAlias(oldAlias string, s *model.Server) error {
_, err := db.conn.Exec(` _, err := db.conn.Exec(`
UPDATE servers SET UPDATE servers SET
alias=?, display_name=?, host=?, port=?, user=?, auth_method=?, alias=?, display_name=?, host=?, port=?, user=?, auth_method=?,
identity_file=?, proxy_jump=?, group_name=?, notes=?, updated_at=CURRENT_TIMESTAMP identity_file=?, proxy_jump=?, group_name=?, notes=?, startup_command=?, updated_at=CURRENT_TIMESTAMP
WHERE alias=?`, WHERE alias=?`,
s.Alias, s.DisplayName, s.Host, s.Port, s.User, s.AuthMethod, s.Alias, s.DisplayName, s.Host, s.Port, s.User, s.AuthMethod,
s.IdentityFile, s.ProxyJump, s.GroupName, s.Notes, oldAlias) s.IdentityFile, s.ProxyJump, s.GroupName, s.Notes, s.StartupCommand, oldAlias)
return err return err
} }
@ -51,12 +53,12 @@ func (db *DB) GetServer(alias string) (*model.Server, error) {
var lastConnected, lastTest sql.NullTime var lastConnected, lastTest sql.NullTime
err := db.conn.QueryRow(` err := db.conn.QueryRow(`
SELECT id, alias, display_name, host, port, user, auth_method, SELECT id, alias, display_name, host, port, user, auth_method,
identity_file, proxy_jump, group_name, notes, identity_file, proxy_jump, group_name, notes, startup_command,
created_at, updated_at, last_connected_at, created_at, updated_at, last_connected_at,
last_test_at, last_test_status, last_test_error last_test_at, last_test_status, last_test_error
FROM servers WHERE alias=?`, alias).Scan( FROM servers WHERE alias=?`, alias).Scan(
&s.ID, &s.Alias, &s.DisplayName, &s.Host, &s.Port, &s.User, &s.AuthMethod, &s.ID, &s.Alias, &s.DisplayName, &s.Host, &s.Port, &s.User, &s.AuthMethod,
&s.IdentityFile, &s.ProxyJump, &s.GroupName, &s.Notes, &s.IdentityFile, &s.ProxyJump, &s.GroupName, &s.Notes, &s.StartupCommand,
&s.CreatedAt, &s.UpdatedAt, &lastConnected, &s.CreatedAt, &s.UpdatedAt, &lastConnected,
&lastTest, &s.LastTestStatus, &s.LastTestError) &lastTest, &s.LastTestStatus, &s.LastTestError)
if err != nil { if err != nil {
@ -68,13 +70,18 @@ func (db *DB) GetServer(alias string) (*model.Server, error) {
if lastTest.Valid { if lastTest.Valid {
s.LastTestAt = &lastTest.Time s.LastTestAt = &lastTest.Time
} }
tags, err := db.GetServerTags(s.ID)
if err != nil {
return nil, err
}
s.Tags = tags
return &s, nil return &s, nil
} }
func (db *DB) ListServers() ([]*model.Server, error) { func (db *DB) ListServers() ([]*model.Server, error) {
rows, err := db.conn.Query(` rows, err := db.conn.Query(`
SELECT id, alias, display_name, host, port, user, auth_method, SELECT id, alias, display_name, host, port, user, auth_method,
identity_file, proxy_jump, group_name, notes, identity_file, proxy_jump, group_name, notes, startup_command,
created_at, updated_at, last_connected_at, created_at, updated_at, last_connected_at,
last_test_at, last_test_status, last_test_error last_test_at, last_test_status, last_test_error
FROM servers ORDER BY alias`) FROM servers ORDER BY alias`)
@ -89,7 +96,7 @@ func (db *DB) ListServers() ([]*model.Server, error) {
var lastConnected, lastTest sql.NullTime var lastConnected, lastTest sql.NullTime
err := rows.Scan( err := rows.Scan(
&s.ID, &s.Alias, &s.DisplayName, &s.Host, &s.Port, &s.User, &s.AuthMethod, &s.ID, &s.Alias, &s.DisplayName, &s.Host, &s.Port, &s.User, &s.AuthMethod,
&s.IdentityFile, &s.ProxyJump, &s.GroupName, &s.Notes, &s.IdentityFile, &s.ProxyJump, &s.GroupName, &s.Notes, &s.StartupCommand,
&s.CreatedAt, &s.UpdatedAt, &lastConnected, &s.CreatedAt, &s.UpdatedAt, &lastConnected,
&lastTest, &s.LastTestStatus, &s.LastTestError) &lastTest, &s.LastTestStatus, &s.LastTestError)
if err != nil { if err != nil {
@ -101,6 +108,11 @@ func (db *DB) ListServers() ([]*model.Server, error) {
if lastTest.Valid { if lastTest.Valid {
s.LastTestAt = &lastTest.Time s.LastTestAt = &lastTest.Time
} }
tags, err := db.GetServerTags(s.ID)
if err != nil {
return nil, err
}
s.Tags = tags
servers = append(servers, &s) servers = append(servers, &s)
} }
return servers, rows.Err() return servers, rows.Err()
@ -110,7 +122,7 @@ func (db *DB) SearchServers(query string) ([]*model.Server, error) {
pattern := "%" + query + "%" pattern := "%" + query + "%"
rows, err := db.conn.Query(` rows, err := db.conn.Query(`
SELECT id, alias, display_name, host, port, user, auth_method, SELECT id, alias, display_name, host, port, user, auth_method,
identity_file, proxy_jump, group_name, notes, identity_file, proxy_jump, group_name, notes, startup_command,
created_at, updated_at, last_connected_at, created_at, updated_at, last_connected_at,
last_test_at, last_test_status, last_test_error last_test_at, last_test_status, last_test_error
FROM servers FROM servers
@ -127,7 +139,7 @@ func (db *DB) SearchServers(query string) ([]*model.Server, error) {
var lastConnected, lastTest sql.NullTime var lastConnected, lastTest sql.NullTime
err := rows.Scan( err := rows.Scan(
&s.ID, &s.Alias, &s.DisplayName, &s.Host, &s.Port, &s.User, &s.AuthMethod, &s.ID, &s.Alias, &s.DisplayName, &s.Host, &s.Port, &s.User, &s.AuthMethod,
&s.IdentityFile, &s.ProxyJump, &s.GroupName, &s.Notes, &s.IdentityFile, &s.ProxyJump, &s.GroupName, &s.Notes, &s.StartupCommand,
&s.CreatedAt, &s.UpdatedAt, &lastConnected, &s.CreatedAt, &s.UpdatedAt, &lastConnected,
&lastTest, &s.LastTestStatus, &s.LastTestError) &lastTest, &s.LastTestStatus, &s.LastTestError)
if err != nil { if err != nil {
@ -139,6 +151,11 @@ func (db *DB) SearchServers(query string) ([]*model.Server, error) {
if lastTest.Valid { if lastTest.Valid {
s.LastTestAt = &lastTest.Time s.LastTestAt = &lastTest.Time
} }
tags, err := db.GetServerTags(s.ID)
if err != nil {
return nil, err
}
s.Tags = tags
servers = append(servers, &s) servers = append(servers, &s)
} }
return servers, rows.Err() return servers, rows.Err()
@ -158,6 +175,10 @@ func (db *DB) UpdateLastConnected(alias string) error {
// Tag methods // Tag methods
func (db *DB) AddTagToServer(serverID int64, tagName string) error { func (db *DB) AddTagToServer(serverID int64, tagName string) error {
tagName = strings.TrimSpace(tagName)
if tagName == "" {
return nil
}
var tagID int64 var tagID int64
err := db.conn.QueryRow("SELECT id FROM tags WHERE name=?", tagName).Scan(&tagID) err := db.conn.QueryRow("SELECT id FROM tags WHERE name=?", tagName).Scan(&tagID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -173,6 +194,55 @@ func (db *DB) AddTagToServer(serverID int64, tagName string) error {
return err return err
} }
func (db *DB) SetServerTags(serverID int64, tagNames []string) error {
if _, err := db.conn.Exec("DELETE FROM server_tags WHERE server_id=?", serverID); err != nil {
return err
}
for _, tagName := range uniqueCleanStrings(tagNames) {
if err := db.AddTagToServer(serverID, tagName); err != nil {
return err
}
}
return nil
}
func (db *DB) ListTags() ([]string, error) {
rows, err := db.conn.Query("SELECT name FROM tags ORDER BY name")
if err != nil {
return nil, err
}
defer rows.Close()
var tags []string
for rows.Next() {
var tag string
if err := rows.Scan(&tag); err != nil {
return nil, err
}
tags = append(tags, tag)
}
return tags, rows.Err()
}
func (db *DB) RenameTag(oldName, newName string) error {
oldName = strings.TrimSpace(oldName)
newName = strings.TrimSpace(newName)
if oldName == "" || newName == "" {
return nil
}
_, err := db.conn.Exec("UPDATE tags SET name=? WHERE name=?", newName, oldName)
return err
}
func (db *DB) DeleteTag(name string) error {
name = strings.TrimSpace(name)
if name == "" {
return nil
}
_, err := db.conn.Exec("DELETE FROM tags WHERE name=?", name)
return err
}
func (db *DB) GetServerTags(serverID int64) ([]string, error) { func (db *DB) GetServerTags(serverID int64) ([]string, error) {
rows, err := db.conn.Query(` rows, err := db.conn.Query(`
SELECT t.name FROM tags t SELECT t.name FROM tags t
@ -227,21 +297,33 @@ func (db *DB) GetForwards(serverID int64) ([]*model.Forward, error) {
// Ensure time import is used // Ensure time import is used
var _ time.Time var _ time.Time
// Command template methods func (db *DB) CreateCommandTemplate(t *model.CommandTemplate) error {
func (db *DB) AddCommandTemplate(serverID int64, name, command string) error { result, err := db.conn.Exec(
_, err := db.conn.Exec( "INSERT INTO global_command_templates (name, command, description) VALUES (?, ?, ?)",
"INSERT INTO command_templates (server_id, name, command) VALUES (?, ?, ?)", t.Name, t.Command, t.Description)
serverID, name, command) if err != nil {
return err
}
t.ID, _ = result.LastInsertId()
return err return err
} }
func (db *DB) GetCommandTemplates(serverAlias string) ([]*model.CommandTemplate, error) { func (db *DB) GetCommandTemplate(name string) (*model.CommandTemplate, error) {
var t model.CommandTemplate
err := db.conn.QueryRow(`
SELECT id, name, command, description
FROM global_command_templates WHERE name=?`, name).Scan(&t.ID, &t.Name, &t.Command, &t.Description)
if err != nil {
return nil, err
}
return &t, nil
}
func (db *DB) ListCommandTemplates() ([]*model.CommandTemplate, error) {
rows, err := db.conn.Query(` rows, err := db.conn.Query(`
SELECT ct.id, ct.server_id, ct.name, ct.command SELECT id, name, command, description
FROM command_templates ct FROM global_command_templates
JOIN servers s ON s.id = ct.server_id ORDER BY name`)
WHERE s.alias = ?
ORDER BY ct.name`, serverAlias)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -250,7 +332,7 @@ func (db *DB) GetCommandTemplates(serverAlias string) ([]*model.CommandTemplate,
var templates []*model.CommandTemplate var templates []*model.CommandTemplate
for rows.Next() { for rows.Next() {
var t model.CommandTemplate var t model.CommandTemplate
if err := rows.Scan(&t.ID, &t.ServerID, &t.Name, &t.Command); err != nil { if err := rows.Scan(&t.ID, &t.Name, &t.Command, &t.Description); err != nil {
return nil, err return nil, err
} }
templates = append(templates, &t) templates = append(templates, &t)
@ -258,6 +340,34 @@ func (db *DB) GetCommandTemplates(serverAlias string) ([]*model.CommandTemplate,
return templates, rows.Err() return templates, rows.Err()
} }
func (db *DB) UpdateCommandTemplate(oldName string, t *model.CommandTemplate) error {
_, err := db.conn.Exec(`
UPDATE global_command_templates
SET name=?, command=?, description=?, updated_at=CURRENT_TIMESTAMP
WHERE name=?`, t.Name, t.Command, t.Description, oldName)
return err
}
func (db *DB) DeleteCommandTemplate(name string) error {
_, err := db.conn.Exec("DELETE FROM global_command_templates WHERE name=?", name)
return err
}
func uniqueCleanStrings(values []string) []string {
seen := map[string]bool{}
var result []string
for _, value := range values {
value = strings.TrimSpace(value)
if value == "" || seen[value] {
continue
}
seen[value] = true
result = append(result, value)
}
sort.Strings(result)
return result
}
// GetGroups returns all unique group names with server count // GetGroups returns all unique group names with server count
func (db *DB) GetGroups() ([]string, error) { func (db *DB) GetGroups() ([]string, error) {
rows, err := db.conn.Query(` rows, err := db.conn.Query(`

View File

@ -45,3 +45,190 @@ func TestUpdateServerByAliasCanRenameAlias(t *testing.T) {
t.Fatalf("unexpected updated server: %#v", got) t.Fatalf("unexpected updated server: %#v", got)
} }
} }
func TestServerPersistsStartupCommandAndTags(t *testing.T) {
db, err := Open(t.TempDir())
if err != nil {
t.Fatalf("open db: %v", err)
}
defer db.Close()
server := &model.Server{
Alias: "prod",
DisplayName: "Production",
Host: "prod.example",
Port: 2222,
User: "deploy",
AuthMethod: model.AuthKey,
StartupCommand: "tmux attach -t ops",
Tags: []string{"prod", "db"},
}
if err := db.CreateServer(server); err != nil {
t.Fatalf("create server: %v", err)
}
if err := db.SetServerTags(server.ID, server.Tags); err != nil {
t.Fatalf("set tags: %v", err)
}
got, err := db.GetServer("prod")
if err != nil {
t.Fatalf("get server: %v", err)
}
if got.StartupCommand != "tmux attach -t ops" {
t.Fatalf("startup command = %q", got.StartupCommand)
}
if len(got.Tags) != 2 || got.Tags[0] != "db" || got.Tags[1] != "prod" {
t.Fatalf("unexpected tags: %#v", got.Tags)
}
got.StartupCommand = "uptime"
got.Tags = []string{"web"}
if err := db.UpdateServerByAlias("prod", got); err != nil {
t.Fatalf("update server: %v", err)
}
if err := db.SetServerTags(got.ID, got.Tags); err != nil {
t.Fatalf("replace tags: %v", err)
}
reopened, err := db.GetServer("prod")
if err != nil {
t.Fatalf("get updated server: %v", err)
}
if reopened.StartupCommand != "uptime" {
t.Fatalf("updated startup command = %q", reopened.StartupCommand)
}
if len(reopened.Tags) != 1 || reopened.Tags[0] != "web" {
t.Fatalf("updated tags: %#v", reopened.Tags)
}
}
func TestGlobalCommandTemplateCRUD(t *testing.T) {
db, err := Open(t.TempDir())
if err != nil {
t.Fatalf("open db: %v", err)
}
defer db.Close()
tpl := &model.CommandTemplate{Name: "uptime", Command: "uptime"}
if err := db.CreateCommandTemplate(tpl); err != nil {
t.Fatalf("create template: %v", err)
}
if tpl.ID == 0 {
t.Fatal("expected template ID")
}
got, err := db.GetCommandTemplate("uptime")
if err != nil {
t.Fatalf("get template: %v", err)
}
if got.Command != "uptime" {
t.Fatalf("template command = %q", got.Command)
}
got.Name = "load"
got.Command = "cat /proc/loadavg"
if err := db.UpdateCommandTemplate("uptime", got); err != nil {
t.Fatalf("update template: %v", err)
}
if _, err := db.GetCommandTemplate("uptime"); err == nil {
t.Fatal("expected old template name to be gone")
}
updated, err := db.GetCommandTemplate("load")
if err != nil {
t.Fatalf("get renamed template: %v", err)
}
if updated.Command != "cat /proc/loadavg" {
t.Fatalf("updated command = %q", updated.Command)
}
templates, err := db.ListCommandTemplates()
if err != nil {
t.Fatalf("list templates: %v", err)
}
if len(templates) != 1 || templates[0].Name != "load" {
t.Fatalf("unexpected templates: %#v", templates)
}
if err := db.DeleteCommandTemplate("load"); err != nil {
t.Fatalf("delete template: %v", err)
}
if templates, err := db.ListCommandTemplates(); err != nil || len(templates) != 0 {
t.Fatalf("expected no templates, got %#v err %v", templates, err)
}
}
func TestLegacyCommandTemplatesAreCopiedToGlobalTemplates(t *testing.T) {
db, err := Open(t.TempDir())
if err != nil {
t.Fatalf("open db: %v", err)
}
defer db.Close()
server := &model.Server{Alias: "prod", Host: "prod.example", Port: 22, User: "root", AuthMethod: model.AuthKey}
if err := db.CreateServer(server); err != nil {
t.Fatalf("create server: %v", err)
}
if _, err := db.conn.Exec(
"INSERT INTO command_templates (server_id, name, command) VALUES (?, ?, ?)",
server.ID, "legacy-uptime", "uptime",
); err != nil {
t.Fatalf("insert legacy template: %v", err)
}
if err := db.ensureSchema(); err != nil {
t.Fatalf("ensure schema: %v", err)
}
got, err := db.GetCommandTemplate("legacy-uptime")
if err != nil {
t.Fatalf("get copied template: %v", err)
}
if got.Command != "uptime" {
t.Fatalf("copied command = %q", got.Command)
}
}
func TestTagManagementCRUD(t *testing.T) {
db, err := Open(t.TempDir())
if err != nil {
t.Fatalf("open db: %v", err)
}
defer db.Close()
server := &model.Server{Alias: "prod", Host: "prod.example", Port: 22, User: "root", AuthMethod: model.AuthKey}
if err := db.CreateServer(server); err != nil {
t.Fatalf("create server: %v", err)
}
if err := db.SetServerTags(server.ID, []string{"prod", "web"}); err != nil {
t.Fatalf("set tags: %v", err)
}
tags, err := db.ListTags()
if err != nil {
t.Fatalf("list tags: %v", err)
}
if len(tags) != 2 || tags[0] != "prod" || tags[1] != "web" {
t.Fatalf("unexpected tags: %#v", tags)
}
if err := db.RenameTag("web", "frontend"); err != nil {
t.Fatalf("rename tag: %v", err)
}
got, err := db.GetServer("prod")
if err != nil {
t.Fatalf("get server: %v", err)
}
if len(got.Tags) != 2 || got.Tags[0] != "frontend" || got.Tags[1] != "prod" {
t.Fatalf("renamed tags: %#v", got.Tags)
}
if err := db.DeleteTag("prod"); err != nil {
t.Fatalf("delete tag: %v", err)
}
got, err = db.GetServer("prod")
if err != nil {
t.Fatalf("get after delete: %v", err)
}
if len(got.Tags) != 1 || got.Tags[0] != "frontend" {
t.Fatalf("tags after delete: %#v", got.Tags)
}
}

View File

@ -31,6 +31,7 @@ type Server struct {
ProxyJump string `json:"proxy_jump"` ProxyJump string `json:"proxy_jump"`
GroupName string `json:"group_name"` GroupName string `json:"group_name"`
Notes string `json:"notes"` Notes string `json:"notes"`
StartupCommand string `json:"startup_command"`
Tags []string `json:"tags"` Tags []string `json:"tags"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
@ -50,10 +51,10 @@ const (
) )
type Secret struct { type Secret struct {
ID string `json:"id"` ID string `json:"id"`
Type SecretType `json:"type"` Type SecretType `json:"type"`
Nonce []byte `json:"nonce"` Nonce []byte `json:"nonce"`
Data []byte `json:"data"` Data []byte `json:"data"`
} }
type ForwardType string type ForwardType string
@ -80,8 +81,9 @@ type Tag struct {
} }
type CommandTemplate struct { type CommandTemplate struct {
ID int64 `json:"id"` ID int64 `json:"id"`
ServerID int64 `json:"server_id"` ServerID int64 `json:"server_id"`
Name string `json:"name"` Name string `json:"name"`
Command string `json:"command"` Command string `json:"command"`
Description string `json:"description"`
} }

View File

@ -14,6 +14,9 @@ type VaultFunc func(serverAlias string, secretType string) (string, error)
func Connect(cfg *config.Config, server *model.Server, getVault VaultFunc) error { func Connect(cfg *config.Config, server *model.Server, getVault VaultFunc) error {
args := BuildSSHArgs(server) args := BuildSSHArgs(server)
if strings.TrimSpace(server.StartupCommand) != "" {
args = append(args, server.StartupCommand)
}
switch server.AuthMethod { switch server.AuthMethod {
case model.AuthPassword: case model.AuthPassword:
@ -45,6 +48,73 @@ func Connect(cfg *config.Config, server *model.Server, getVault VaultFunc) error
} }
} }
func RunCommand(cfg *config.Config, server *model.Server, getVault VaultFunc, command string) error {
args := BuildSSHArgs(server)
args = append(args, command)
switch server.AuthMethod {
case model.AuthPassword:
password, err := getVault(server.Alias, "ssh_password")
if err != nil {
return fmt.Errorf("get password from vault: %w", err)
}
return ConnectWithPassword(cfg.SSH.Binary, args, password)
case model.AuthKeyPassphrase:
passphrase, err := getVault(server.Alias, "key_passphrase")
if err != nil {
return fmt.Errorf("get key passphrase from vault: %w", err)
}
return ConnectWithPassword(cfg.SSH.Binary, args, passphrase)
default:
cmd := exec.Command(cfg.SSH.Binary, args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Start(); err != nil {
return fmt.Errorf("start ssh: %w", err)
}
return cmd.Wait()
}
}
func RunCommandOutput(cfg *config.Config, server *model.Server, getVault VaultFunc, command string) (string, error) {
args := BuildSSHArgs(server)
args = append(args, "-o", fmt.Sprintf("ConnectTimeout=%d", cfg.SSH.ConnectTimeoutSec))
switch server.AuthMethod {
case model.AuthPassword:
args = append(args, "-o", "NumberOfPasswordPrompts=1", command)
password, err := getVault(server.Alias, "ssh_password")
if err != nil {
return "", fmt.Errorf("get password from vault: %w", err)
}
ok, output := connectWithPasswordAndRead(cfg.SSH.Binary, args, password, cfg.SSH.ConnectTimeoutSec)
if !ok {
return output, fmt.Errorf("ssh command failed")
}
return output, nil
case model.AuthKeyPassphrase:
args = append(args, "-o", "NumberOfPasswordPrompts=1", command)
passphrase, err := getVault(server.Alias, "key_passphrase")
if err != nil {
return "", fmt.Errorf("get key passphrase from vault: %w", err)
}
ok, output := connectWithPasswordAndRead(cfg.SSH.Binary, args, passphrase, cfg.SSH.ConnectTimeoutSec)
if !ok {
return output, fmt.Errorf("ssh command failed")
}
return output, nil
default:
args = append(args, "-o", "BatchMode=yes", command)
cmd := exec.Command(cfg.SSH.Binary, args...)
output, err := cmd.CombinedOutput()
if err != nil {
return string(output), err
}
return string(output), nil
}
}
func Test(cfg *config.Config, server *model.Server, getVault VaultFunc) (bool, string) { func Test(cfg *config.Config, server *model.Server, getVault VaultFunc) (bool, string) {
args := BuildSSHArgs(server) args := BuildSSHArgs(server)
args = append(args, "-o", fmt.Sprintf("ConnectTimeout=%d", cfg.SSH.ConnectTimeoutSec)) args = append(args, "-o", fmt.Sprintf("ConnectTimeout=%d", cfg.SSH.ConnectTimeoutSec))

View File

@ -81,3 +81,36 @@ func TestKeyPassphraseTestReportsVaultError(t *testing.T) {
t.Fatalf("expected vault error, got %q", errText) t.Fatalf("expected vault error, got %q", errText)
} }
} }
func TestConnectRunsStartupCommand(t *testing.T) {
dir := t.TempDir()
argsFile := filepath.Join(dir, "args")
script := filepath.Join(dir, "fake-ssh")
if err := os.WriteFile(script, []byte(fmt.Sprintf(`#!/bin/sh
printf '%%s\n' "$@" > %q
`, argsFile)), 0o700); err != nil {
t.Fatalf("write fake ssh: %v", err)
}
cfg := &config.Config{SSH: config.SSHConfig{Binary: script}}
server := &model.Server{
Alias: "prod",
Host: "example.org",
Port: 22,
User: "root",
AuthMethod: model.AuthKey,
StartupCommand: "tmux attach -t ops",
}
if err := Connect(cfg, server, nil); err != nil {
t.Fatalf("connect: %v", err)
}
data, err := os.ReadFile(argsFile)
if err != nil {
t.Fatalf("read args: %v", err)
}
if !strings.Contains(string(data), "tmux attach -t ops") {
t.Fatalf("expected startup command in ssh args, got:\n%s", data)
}
}

File diff suppressed because it is too large Load Diff

View File

@ -8,6 +8,7 @@ import (
"github.com/charmbracelet/bubbles/list" "github.com/charmbracelet/bubbles/list"
tea "github.com/charmbracelet/bubbletea" tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/mirivlad/sshkeeper/internal/model" "github.com/mirivlad/sshkeeper/internal/model"
) )
@ -58,7 +59,8 @@ func TestServerListViewUsesDashboardLayout(t *testing.T) {
"Alias: mail.kp", "Alias: mail.kp",
"Display Name: Mail", "Display Name: Mail",
"Port: 222", "Port: 222",
"Enter connect", "Enter",
"connect",
} { } {
if !strings.Contains(view, want) { if !strings.Contains(view, want) {
t.Fatalf("expected list view to contain %q\nview:\n%s", want, view) t.Fatalf("expected list view to contain %q\nview:\n%s", want, view)
@ -94,7 +96,7 @@ func TestServerListViewKeepsDetailsVisibleWithManyServers(t *testing.T) {
if !strings.Contains(view, "Selected") { if !strings.Contains(view, "Selected") {
t.Fatalf("expected selected server details to remain visible:\n%s", view) t.Fatalf("expected selected server details to remain visible:\n%s", view)
} }
if !strings.Contains(view, "Enter connect") { if !strings.Contains(view, "Enter") || !strings.Contains(view, "connect") {
t.Fatalf("expected footer to remain visible:\n%s", view) t.Fatalf("expected footer to remain visible:\n%s", view)
} }
if count := strings.Count(view, "server-"); count >= len(servers) { if count := strings.Count(view, "server-"); count >= len(servers) {
@ -102,6 +104,84 @@ func TestServerListViewKeepsDetailsVisibleWithManyServers(t *testing.T) {
} }
} }
func TestServerListHelpWrapsOnNarrowTerminal(t *testing.T) {
m := New([]*model.Server{
{Alias: "one", Host: "one.example.org", Port: 22, User: "root", AuthMethod: model.AuthKey},
})
m.width = 72
m.height = 24
view := m.View()
for _, line := range strings.Split(view, "\n") {
if strings.Contains(line, "Enter") && strings.Contains(line, "connect") && lipgloss.Width(line) > 72 {
t.Fatalf("expected help line to be bounded, got width %d: %q\nview:\n%s", lipgloss.Width(line), line, view)
}
}
for _, want := range []string{"Ctrl+R", "run tpl", "Ctrl+P", "tpl mgr"} {
if !strings.Contains(view, want) {
t.Fatalf("expected help to contain %q\nview:\n%s", want, view)
}
}
if strings.Contains(view, "Shift+Enter") {
t.Fatalf("expected help to omit Shift+Enter\nview:\n%s", view)
}
}
func TestServerListHelpWrapsSelectionAndResultHints(t *testing.T) {
m := New([]*model.Server{
{Alias: "one", Host: "one.example.org", Port: 22, User: "root", AuthMethod: model.AuthKey},
})
m.width = 90
lines := wrapHelpItems(m.listHelpItems(2, true), m.width-2)
if len(lines) < 2 {
t.Fatalf("expected wrapped help, got %#v", lines)
}
for _, line := range lines {
plain := plainHelpLine(line)
if len(plain) > m.width-2 {
t.Fatalf("help line too long: len=%d line=%q", len(plain), plain)
}
}
var plainLines []string
for _, line := range lines {
plainLines = append(plainLines, plainHelpLine(line))
}
joined := strings.Join(plainLines, "\n")
for _, want := range []string{"Ins: select (2 selected)", "Esc: clear result", "Ctrl+P: tpl mgr", "Ctrl+Q: quit"} {
if !strings.Contains(joined, want) {
t.Fatalf("expected wrapped help to contain %q\nlines:%#v", want, lines)
}
}
}
func TestServerListFooterUsesColonFormatAndColoredHotkeys(t *testing.T) {
m := New([]*model.Server{
{Alias: "one", Host: "one.example.org", Port: 22, User: "root", AuthMethod: model.AuthKey},
{Alias: "two", Host: "two.example.org", Port: 22, User: "root", AuthMethod: model.AuthKey},
})
m.width = 90
m.height = 30
m.selected["one"] = true
view := m.View()
for _, want := range []string{"Ins", ": select (1 selected)", "Ctrl+A", ": add"} {
if !strings.Contains(view, want) {
t.Fatalf("expected footer to contain %q\nview:\n%s", want, view)
}
}
if hotkeyStyle.GetForeground() == nil {
t.Fatal("expected hotkey style to define a foreground color")
}
lines := strings.Split(view, "\n")
if got := len(lines); got != m.height {
t.Fatalf("expected footer to be pinned to bottom with %d lines, got %d\nview:\n%s", m.height, got, view)
}
if !strings.Contains(lines[len(lines)-1], "Ctrl+Q") {
t.Fatalf("expected final footer line at terminal bottom, got %q\nview:\n%s", lines[len(lines)-1], view)
}
}
func TestVisibleServerRangeKeepsSelectionInsideWindow(t *testing.T) { func TestVisibleServerRangeKeepsSelectionInsideWindow(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@ -317,7 +397,7 @@ func TestSelectableFieldHintsAreVisible(t *testing.T) {
fm := newFormModel(80, 24) fm := newFormModel(80, 24)
view := fm.View() view := fm.View()
for _, want := range []string{"Auth Method (/ pick)", "Group (/ pick)", "/ pick list"} { for _, want := range []string{"Auth Method (/ pick)", "Group (/ pick)", "pick list"} {
if !strings.Contains(view, want) { if !strings.Contains(view, want) {
t.Fatalf("expected form view to contain selectable-field hint %q\nview:\n%s", want, view) t.Fatalf("expected form view to contain selectable-field hint %q\nview:\n%s", want, view)
} }
@ -409,3 +489,209 @@ func TestFormTestResultDoesNotUpdateSelectedListServer(t *testing.T) {
t.Fatal("expected form to keep its test result") t.Fatal("expected form to keep its test result")
} }
} }
func TestServerFormBuildsStartupCommandAndTags(t *testing.T) {
fm := newFormModel(100, 30)
fm.inputs[0].SetValue("prod")
fm.inputs[2].SetValue("prod.example.org")
fm.inputs[3].SetValue("22")
fm.inputs[4].SetValue("root")
fm.inputs[5].SetValue(string(model.AuthKey))
fm.inputs[10].SetValue("tmux attach -t ops")
fm.inputs[11].SetValue("prod, web, prod")
server := fm.buildServer()
if server.StartupCommand != "tmux attach -t ops" {
t.Fatalf("startup command = %q", server.StartupCommand)
}
if got := strings.Join(server.Tags, ","); got != "prod,web" {
t.Fatalf("tags = %q", got)
}
}
func TestInsertTogglesServerSelection(t *testing.T) {
m := New([]*model.Server{
{Alias: "one", Host: "one.example.org", Port: 22, User: "root", AuthMethod: model.AuthKey},
{Alias: "two", Host: "two.example.org", Port: 22, User: "root", AuthMethod: model.AuthKey},
})
updated, _ := m.updateList(tea.KeyMsg{Type: tea.KeyInsert})
model := updated.(*tuiModel)
if !model.selected["one"] {
t.Fatal("expected Insert to select current server")
}
if model.list.Index() != 1 {
t.Fatalf("expected Insert to advance to next server, index = %d", model.list.Index())
}
updated, _ = model.updateList(tea.KeyMsg{Type: tea.KeyInsert})
model = updated.(*tuiModel)
if !model.selected["two"] {
t.Fatal("expected second Insert to select next server")
}
}
func TestCtrlROpensTemplatePicker(t *testing.T) {
oldListCommandTemplates := ListCommandTemplates
defer func() { ListCommandTemplates = oldListCommandTemplates }()
ListCommandTemplates = func() ([]*model.CommandTemplate, error) {
return []*model.CommandTemplate{{Name: "uptime", Command: "uptime"}}, nil
}
m := New([]*model.Server{
{Alias: "one", Host: "one.example.org", Port: 22, User: "root", AuthMethod: model.AuthKey},
})
updated, cmd := m.updateList(tea.KeyMsg{Type: tea.KeyCtrlR})
model := updated.(*tuiModel)
if model.screen != screenTemplatePicker {
t.Fatalf("expected Ctrl+R to open template picker, screen = %v", model.screen)
}
if cmd == nil {
t.Fatal("expected Ctrl+R to load templates")
}
}
func TestShiftEnterDoesNotOpenTemplatePicker(t *testing.T) {
m := New([]*model.Server{
{Alias: "one", Host: "one.example.org", Port: 22, User: "root", AuthMethod: model.AuthKey},
})
updated, _ := m.updateList(tea.KeyMsg{Type: tea.KeyEnter})
model := updated.(*tuiModel)
if model.screen == screenTemplatePicker {
t.Fatal("Shift+Enter fallback should not be wired; Ctrl+R is the template shortcut")
}
}
func TestTemplatePickerForegroundUsesSelectedServers(t *testing.T) {
servers := []*model.Server{
{Alias: "one", Host: "one.example.org", Port: 22, User: "root", AuthMethod: model.AuthKey},
{Alias: "two", Host: "two.example.org", Port: 22, User: "root", AuthMethod: model.AuthKey},
}
m := New(servers)
m.selected["one"] = true
m.selected["two"] = true
m.pendingTemplate = &model.CommandTemplate{Name: "uptime", Command: "uptime"}
m.screen = screenTemplateMode
updated, cmd := m.updateTemplateMode(tea.KeyMsg{Type: tea.KeyEnter})
model := updated.(*tuiModel)
if cmd == nil {
t.Fatal("expected foreground template mode to return a command")
}
msg := cmd()
req, ok := msg.(templateRunRequestMsg)
if !ok {
t.Fatalf("expected templateRunRequestMsg, got %T", msg)
}
if len(req.servers) != 2 || req.command != "uptime" || req.templateName != "uptime" {
t.Fatalf("unexpected template request: %#v", req)
}
model.Update(req)
if model.Result() == nil || model.Result().Action != "run_template_foreground" {
t.Fatalf("expected foreground result, got %#v", model.Result())
}
}
func TestTemplateModeUsesControlShortcuts(t *testing.T) {
m := New([]*model.Server{
{Alias: "one", Host: "one.example.org", Port: 22, User: "root", AuthMethod: model.AuthKey},
})
m.pendingTemplate = &model.CommandTemplate{Name: "uptime", Command: "uptime"}
m.screen = screenTemplateMode
_, cmd := m.updateTemplateMode(tea.KeyMsg{Type: tea.KeyCtrlF})
if cmd == nil {
t.Fatal("expected Ctrl+F to start foreground run")
}
called := false
oldRunTemplateBackground := RunTemplateBackground
RunTemplateBackground = func(server *model.Server, command string) (string, error) {
called = true
return "ok", nil
}
defer func() { RunTemplateBackground = oldRunTemplateBackground }()
_, cmd = m.updateTemplateMode(tea.KeyMsg{Type: tea.KeyCtrlB})
if cmd == nil {
t.Fatal("expected Ctrl+B to start background run")
}
cmd()
if !called {
t.Fatal("expected Ctrl+B command to call background runner")
}
view := m.viewTemplateMode()
for _, want := range []string{"Ctrl+F (Enter)", "Foreground", "Ctrl+B", "Background"} {
if !strings.Contains(view, want) {
t.Fatalf("expected view to contain %q\nview:\n%s", want, view)
}
}
}
func TestBackgroundRunReturnsToListAndShowsResultPanel(t *testing.T) {
m := New([]*model.Server{
{Alias: "one", Host: "one.example.org", Port: 22, User: "root", AuthMethod: model.AuthKey},
})
m.screen = screenTemplateMode
updated, _ := m.Update(backgroundRunDoneMsg{results: []templateRunResult{
{Alias: "one", Output: "Distributor ID:\tDebian\nRelease:\t11"},
}})
model := updated.(*tuiModel)
if model.screen != screenList {
t.Fatalf("expected background result to return to server list, got screen %v", model.screen)
}
view := model.View()
for _, want := range []string{"sshkeeper", "Last Background Run", "one", "OK", "Distributor ID:", "Release:"} {
if !strings.Contains(view, want) {
t.Fatalf("expected view to contain %q\nview:\n%s", want, view)
}
}
if strings.Contains(view, "Background Results") {
t.Fatalf("expected inline list panel, not standalone background screen\nview:\n%s", view)
}
}
func TestBackgroundRunShowsOutputForSelectedServerInMultiRun(t *testing.T) {
m := New([]*model.Server{
{Alias: "one", Host: "one.example.org", Port: 22, User: "root", AuthMethod: model.AuthKey},
{Alias: "two", Host: "two.example.org", Port: 22, User: "root", AuthMethod: model.AuthKey},
})
m.bgResults = []templateRunResult{
{Alias: "one", Output: "one output"},
{Alias: "two", Output: "two output"},
}
view := m.View()
if !strings.Contains(view, "one output") || strings.Contains(view, "two output") {
t.Fatalf("expected selected server output only\nview:\n%s", view)
}
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyDown})
model := updated.(*tuiModel)
view = model.View()
if !strings.Contains(view, "two output") || strings.Contains(view, "one output") {
t.Fatalf("expected output to follow selected server\nview:\n%s", view)
}
}
func TestBackgroundOutputLinesArePaddedAndTabsExpanded(t *testing.T) {
m := New([]*model.Server{
{Alias: "one", Host: "one.example.org", Port: 22, User: "root", AuthMethod: model.AuthKey},
})
m.width = 48
m.bgResults = []templateRunResult{{Alias: "one", Output: "Distributor ID:\tDebian"}}
view := m.View()
if strings.Contains(view, "\t") {
t.Fatalf("expected tabs to be expanded\nview:\n%s", view)
}
for _, line := range strings.Split(view, "\n") {
if strings.Contains(line, "Distributor ID:") && len(line) < 48 {
t.Fatalf("expected output line to be padded to clear stale chars, len=%d line=%q", len(line), line)
}
}
}