feat: add interactive server creation

This commit is contained in:
mirivlad 2026-05-28 13:51:37 +08:00
parent 4b996032a9
commit 0bd4463819
9 changed files with 246 additions and 28 deletions

View File

@ -44,11 +44,14 @@ sshkeeper init
## Common CLI Commands ## Common CLI Commands
```bash ```bash
# Add profiles # Add profiles with flags
sshkeeper add web --host 10.0.0.10 --user deploy --auth key sshkeeper add web --host 10.0.0.10 --user deploy --auth key
sshkeeper add prod --host 10.0.0.20 --user root --auth password sshkeeper add prod --host 10.0.0.20 --user root --auth password
sshkeeper add bastion --host bastion.example.org --user admin --auth key_passphrase --identity-file ~/.ssh/id_rsa sshkeeper add bastion --host bastion.example.org --user admin --auth key_passphrase --identity-file ~/.ssh/id_rsa
# Or use the interactive CLI prompt
sshkeeper add
# Inspect profiles # Inspect profiles
sshkeeper list sshkeeper list
sshkeeper show web sshkeeper show web
@ -71,7 +74,10 @@ sshkeeper ssh-config install-include
Commands that only read profile metadata, such as `list`, `show`, `search`, Commands that only read profile metadata, such as `list`, `show`, `search`,
`config path`, `group list`, and `export`, do not require the master password. `config path`, `group list`, and `export`, do not require the master password.
Commands that need secrets ask for the master password in that process. Commands that need secrets ask for the master password in that process. Adding
`key` or `agent` profiles does not require unlocking the vault; adding
`password` or `key_passphrase` profiles asks for the master password before
storing the secret.
## TUI ## TUI

View File

@ -1,12 +1,16 @@
package cmd package cmd
import ( import (
"bufio"
"fmt" "fmt"
"io"
"os"
"strconv"
"strings" "strings"
"syscall" "syscall"
"github.com/spf13/cobra"
"github.com/mirivlad/sshkeeper/internal/model" "github.com/mirivlad/sshkeeper/internal/model"
"github.com/spf13/cobra"
"golang.org/x/term" "golang.org/x/term"
) )
@ -32,10 +36,18 @@ var addCmd = &cobra.Command{
if len(args) == 1 && addFlags.host != "" { if len(args) == 1 && addFlags.host != "" {
return addNonInteractive(args[0]) return addNonInteractive(args[0])
} }
return fmt.Errorf("interactive add not yet implemented, use: sshkeeper add <alias> --host <host> --user <user> --auth <method>") return addInteractive()
}, },
} }
func addInteractive() error {
server, err := promptServerForAdd(os.Stdin, os.Stdout)
if err != nil {
return err
}
return saveServerWithOptionalSecret(server)
}
func addNonInteractive(alias string) error { func addNonInteractive(alias string) error {
server := &model.Server{ server := &model.Server{
Alias: alias, Alias: alias,
@ -60,6 +72,10 @@ func addNonInteractive(alias string) error {
server.DisplayName = alias server.DisplayName = alias
} }
return saveServerWithOptionalSecret(server)
}
func saveServerWithOptionalSecret(server *model.Server) error {
// Handle password/passphrase auth — request interactively, never via argv // Handle password/passphrase auth — request interactively, never via argv
if server.AuthMethod == model.AuthPassword || server.AuthMethod == model.AuthKeyPassphrase { if server.AuthMethod == model.AuthPassword || server.AuthMethod == model.AuthKeyPassphrase {
secretType := "password" secretType := "password"
@ -78,14 +94,14 @@ func addNonInteractive(alias string) error {
} }
v := getOrCreateVault() v := getOrCreateVault()
if !v.IsUnlocked() { if err := unlockVaultForCommand(v); err != nil {
return fmt.Errorf("vault is locked. Run 'sshkeeper vault unlock' first") return err
} }
vaultKey := fmt.Sprintf("server:%s:ssh_password", alias) vaultKey := fmt.Sprintf("server:%s:ssh_password", server.Alias)
vaultType := "ssh_password" vaultType := "ssh_password"
if server.AuthMethod == model.AuthKeyPassphrase { if server.AuthMethod == model.AuthKeyPassphrase {
vaultKey = fmt.Sprintf("server:%s:key_passphrase", alias) vaultKey = fmt.Sprintf("server:%s:key_passphrase", server.Alias)
vaultType = "key_passphrase" vaultType = "key_passphrase"
} }
@ -117,6 +133,111 @@ func addNonInteractive(alias string) error {
return nil return nil
} }
func promptServerForAdd(in io.Reader, out io.Writer) (*model.Server, error) {
reader := bufio.NewReader(in)
alias, err := promptRequired(reader, out, "Alias")
if err != nil {
return nil, err
}
displayName, err := promptOptional(reader, out, "Display name", alias)
if err != nil {
return nil, err
}
host, err := promptRequired(reader, out, "Host")
if err != nil {
return nil, err
}
portText, err := promptOptional(reader, out, "Port", "22")
if err != nil {
return nil, err
}
port, err := strconv.Atoi(portText)
if err != nil || port <= 0 {
return nil, fmt.Errorf("invalid port: %s", portText)
}
user, err := promptOptional(reader, out, "User", "root")
if err != nil {
return nil, err
}
authText, err := promptOptional(reader, out, "Auth method (password/key/key_passphrase/agent)", string(model.AuthKey))
if err != nil {
return nil, err
}
authMethod := model.AuthMethod(authText)
if !isSupportedAuthMethod(authMethod) {
return nil, fmt.Errorf("unsupported auth method: %s", authText)
}
identityFile, err := promptOptional(reader, out, "Identity file", "")
if err != nil {
return nil, err
}
proxyJump, err := promptOptional(reader, out, "ProxyJump", "")
if err != nil {
return nil, err
}
groupName, err := promptOptional(reader, out, "Group", "")
if err != nil {
return nil, err
}
notes, err := promptOptional(reader, out, "Notes", "")
if err != nil {
return nil, err
}
return &model.Server{
Alias: alias,
DisplayName: displayName,
Host: host,
Port: port,
User: user,
AuthMethod: authMethod,
IdentityFile: identityFile,
ProxyJump: proxyJump,
GroupName: groupName,
Notes: notes,
}, nil
}
func promptRequired(reader *bufio.Reader, out io.Writer, label string) (string, error) {
for {
value, err := promptOptional(reader, out, label, "")
if err != nil {
return "", err
}
if value != "" {
return value, nil
}
fmt.Fprintf(out, "%s is required.\n", label)
}
}
func promptOptional(reader *bufio.Reader, out io.Writer, label string, defaultValue string) (string, error) {
if defaultValue == "" {
fmt.Fprintf(out, "%s: ", label)
} else {
fmt.Fprintf(out, "%s [%s]: ", label, defaultValue)
}
line, err := reader.ReadString('\n')
if err != nil && err != io.EOF {
return "", err
}
value := strings.TrimSpace(line)
if value == "" {
return defaultValue, nil
}
return value, nil
}
func isSupportedAuthMethod(method model.AuthMethod) bool {
switch method {
case model.AuthPassword, model.AuthKey, model.AuthKeyPassphrase, model.AuthAgent:
return true
default:
return false
}
}
func init() { func init() {
addCmd.Flags().StringVar(&addFlags.host, "host", "", "Server hostname or IP") addCmd.Flags().StringVar(&addFlags.host, "host", "", "Server hostname or IP")
addCmd.Flags().IntVar(&addFlags.port, "port", 22, "SSH port") addCmd.Flags().IntVar(&addFlags.port, "port", 22, "SSH port")

82
cmd/add_test.go Normal file
View File

@ -0,0 +1,82 @@
package cmd
import (
"bytes"
"strings"
"testing"
"github.com/mirivlad/sshkeeper/internal/model"
)
func TestPromptServerForAddCollectsInteractiveFields(t *testing.T) {
input := strings.NewReader(strings.Join([]string{
"prod",
"Production",
"prod.example.org",
"2222",
"deploy",
"key",
"~/.ssh/id_prod",
"bastion",
"prod",
"critical host",
"",
}, "\n"))
var output bytes.Buffer
server, err := promptServerForAdd(input, &output)
if err != nil {
t.Fatalf("prompt server: %v", err)
}
if server.Alias != "prod" ||
server.DisplayName != "Production" ||
server.Host != "prod.example.org" ||
server.Port != 2222 ||
server.User != "deploy" ||
server.AuthMethod != model.AuthKey ||
server.IdentityFile != "~/.ssh/id_prod" ||
server.ProxyJump != "bastion" ||
server.GroupName != "prod" ||
server.Notes != "critical host" {
t.Fatalf("unexpected server: %#v", server)
}
if strings.Contains(output.String(), "not yet implemented") {
t.Fatalf("interactive add should not report unimplemented:\n%s", output.String())
}
}
func TestPromptServerForAddAppliesDefaults(t *testing.T) {
input := strings.NewReader(strings.Join([]string{
"prod",
"",
"prod.example.org",
"",
"",
"",
"",
"",
"",
"",
"",
}, "\n"))
var output bytes.Buffer
server, err := promptServerForAdd(input, &output)
if err != nil {
t.Fatalf("prompt server: %v", err)
}
if server.DisplayName != "prod" {
t.Fatalf("display name default = %q; want alias", server.DisplayName)
}
if server.Port != 22 {
t.Fatalf("port default = %d; want 22", server.Port)
}
if server.User != "root" {
t.Fatalf("user default = %q; want root", server.User)
}
if server.AuthMethod != model.AuthKey {
t.Fatalf("auth default = %q; want key", server.AuthMethod)
}
}

View File

@ -3,9 +3,9 @@ package cmd
import ( import (
"fmt" "fmt"
"github.com/spf13/cobra"
"github.com/mirivlad/sshkeeper/internal/model" "github.com/mirivlad/sshkeeper/internal/model"
"github.com/mirivlad/sshkeeper/internal/ssh" "github.com/mirivlad/sshkeeper/internal/ssh"
"github.com/spf13/cobra"
) )
var connectCmd = &cobra.Command{ var connectCmd = &cobra.Command{
@ -23,7 +23,7 @@ var connectCmd = &cobra.Command{
v := getOrCreateVault() v := getOrCreateVault()
vaultFunc := func(serverAlias string, secretType string) (string, error) { vaultFunc := func(serverAlias string, secretType string) (string, error) {
if !v.IsUnlocked() { if !v.IsUnlocked() {
return "", fmt.Errorf("vault is locked. Run 'sshkeeper vault unlock' first") return "", fmt.Errorf("%s", vaultLockedProcessMessage())
} }
key := fmt.Sprintf("server:%s:%s", serverAlias, secretType) key := fmt.Sprintf("server:%s:%s", serverAlias, secretType)
data, err := v.Get(key) data, err := v.Get(key)
@ -64,7 +64,7 @@ var testCmd = &cobra.Command{
v := getOrCreateVault() v := getOrCreateVault()
vaultFunc := func(serverAlias string, secretType string) (string, error) { vaultFunc := func(serverAlias string, secretType string) (string, error) {
if !v.IsUnlocked() { if !v.IsUnlocked() {
return "", fmt.Errorf("vault is locked. Run 'sshkeeper vault unlock' first") return "", fmt.Errorf("%s", vaultLockedProcessMessage())
} }
key := fmt.Sprintf("server:%s:%s", serverAlias, secretType) key := fmt.Sprintf("server:%s:%s", serverAlias, secretType)
data, err := v.Get(key) data, err := v.Get(key)

View File

@ -6,9 +6,9 @@ import (
"os/exec" "os/exec"
"strings" "strings"
"github.com/spf13/cobra"
"github.com/mirivlad/sshkeeper/internal/model" "github.com/mirivlad/sshkeeper/internal/model"
"github.com/mirivlad/sshkeeper/internal/ssh" "github.com/mirivlad/sshkeeper/internal/ssh"
"github.com/spf13/cobra"
) )
var importCmd = &cobra.Command{ var importCmd = &cobra.Command{
@ -74,9 +74,8 @@ var runCmd = &cobra.Command{
return fmt.Errorf("server not found: %s", alias) return fmt.Errorf("server not found: %s", alias)
} }
// For password auth, use PTY-wrapper with command if server.AuthMethod == model.AuthPassword || server.AuthMethod == model.AuthKeyPassphrase {
if server.AuthMethod == model.AuthPassword { return runWithSecret(server, command)
return runWithPassword(server, command)
} }
// For key/agent auth — direct execution // For key/agent auth — direct execution
@ -96,21 +95,25 @@ var runCmd = &cobra.Command{
}, },
} }
// runWithPassword runs a command on a server with password auth via PTY-wrapper. // runWithSecret runs a command on a server through the PTY prompt handler.
func runWithPassword(server *model.Server, command string) error { func runWithSecret(server *model.Server, command string) error {
v := getOrCreateVault() v := getOrCreateVault()
if !v.IsUnlocked() { if !v.IsUnlocked() {
return fmt.Errorf("vault is locked. Run 'sshkeeper vault unlock' first") return fmt.Errorf("%s", vaultLockedProcessMessage())
} }
vaultKey := fmt.Sprintf("server:%s:ssh_password", server.Alias) secretType := "ssh_password"
password, err := v.Get(vaultKey) if server.AuthMethod == model.AuthKeyPassphrase {
secretType = "key_passphrase"
}
vaultKey := fmt.Sprintf("server:%s:%s", server.Alias, secretType)
secret, err := v.Get(vaultKey)
if err != nil { if err != nil {
return fmt.Errorf("get password from vault: %w", err) return fmt.Errorf("get %s from vault: %w", secretType, err)
} }
sshArgs := ssh.BuildSSHArgs(server) sshArgs := ssh.BuildSSHArgs(server)
sshArgs = append(sshArgs, command) sshArgs = append(sshArgs, command)
return ssh.ConnectWithPassword(cfg.SSH.Binary, sshArgs, string(password)) return ssh.ConnectWithPassword(cfg.SSH.Binary, sshArgs, string(secret))
} }

View File

@ -4,9 +4,9 @@ import (
"fmt" "fmt"
"os" "os"
"github.com/spf13/cobra"
"github.com/mirivlad/sshkeeper/internal/config" "github.com/mirivlad/sshkeeper/internal/config"
"github.com/mirivlad/sshkeeper/internal/db" "github.com/mirivlad/sshkeeper/internal/db"
"github.com/spf13/cobra"
) )
var initCmd = &cobra.Command{ var initCmd = &cobra.Command{
@ -47,7 +47,7 @@ var initCmd = &cobra.Command{
fmt.Printf("Created database: %s/sshkeeper.db\n", cfg.DataDir) fmt.Printf("Created database: %s/sshkeeper.db\n", cfg.DataDir)
fmt.Printf("Created vault: %s/vault.bin\n", cfg.DataDir) fmt.Printf("Created vault: %s/vault.bin\n", cfg.DataDir)
fmt.Println() fmt.Println()
fmt.Println("Next step: run 'sshkeeper vault unlock' to set master password.") fmt.Println("Next step: run 'sshkeeper' or any command that needs secrets to create the vault master password.")
return nil return nil
}, },
} }

View File

@ -122,7 +122,7 @@ func initApp() {
vaultInstance = v vaultInstance = v
fmt.Println() fmt.Println()
fmt.Println("Vault created and unlocked. You're ready to go!") fmt.Println("Vault created and unlocked for this command. You're ready to go!")
fmt.Println() fmt.Println()
break break
} }
@ -148,7 +148,7 @@ func initApp() {
fmt.Printf("Invalid password. %d attempts remaining.\n", remaining) fmt.Printf("Invalid password. %d attempts remaining.\n", remaining)
continue continue
} }
fmt.Fprintf(os.Stderr, "Too many failed attempts. Run 'sshkeeper vault unlock' to try again.\n") fmt.Fprintf(os.Stderr, "Too many failed attempts. Start the command again to retry.\n")
os.Exit(1) os.Exit(1)
} }
@ -172,7 +172,7 @@ func commandRequiresStartupVaultUnlock(args []string) bool {
} }
switch args[0] { switch args[0] {
case "connect", "c", "run", "run-template", "test", "add", "edit", "delete": case "connect", "c", "run", "run-template", "test", "edit", "delete":
return true return true
default: default:
return false return false

View File

@ -11,7 +11,9 @@ func TestCommandRequiresStartupVaultUnlock(t *testing.T) {
{name: "root tui", args: nil, want: true}, {name: "root tui", args: nil, want: true},
{name: "connect", args: []string{"connect", "prod"}, want: true}, {name: "connect", args: []string{"connect", "prod"}, want: true},
{name: "short connect alias", args: []string{"c", "prod"}, want: true}, {name: "short connect alias", args: []string{"c", "prod"}, want: true},
{name: "add can store secrets", args: []string{"add", "prod"}, want: true}, {name: "add unlocks only if selected auth needs a secret", args: []string{"add", "prod"}, want: false},
{name: "run can use secrets", args: []string{"run", "prod", "uptime"}, want: true},
{name: "test can use secrets", args: []string{"test", "prod"}, want: true},
{name: "vault handles its own unlock", args: []string{"vault", "list"}, want: false}, {name: "vault handles its own unlock", args: []string{"vault", "list"}, want: false},
{name: "list only reads database", args: []string{"list"}, want: false}, {name: "list only reads database", args: []string{"list"}, want: false},
{name: "show only reads database", args: []string{"show", "prod"}, want: false}, {name: "show only reads database", args: []string{"show", "prod"}, want: false},

View File

@ -220,6 +220,10 @@ func unlockVaultForCommand(v *vault.Vault) error {
return nil return nil
} }
func vaultLockedProcessMessage() string {
return "vault is locked in this process; enter the master password when this command prompts for it"
}
func formatVaultStatus(unlocked bool, exists bool) string { func formatVaultStatus(unlocked bool, exists bool) string {
if !exists { if !exists {
return "Vault: not found" return "Vault: not found"