diff --git a/README.md b/README.md index 16bd8ac..7396c8c 100644 --- a/README.md +++ b/README.md @@ -44,11 +44,14 @@ sshkeeper init ## Common CLI Commands ```bash -# Add profiles +# Add profiles with flags sshkeeper add web --host 10.0.0.10 --user deploy --auth key sshkeeper add prod --host 10.0.0.20 --user root --auth password sshkeeper add bastion --host bastion.example.org --user admin --auth key_passphrase --identity-file ~/.ssh/id_rsa +# Or use the interactive CLI prompt +sshkeeper add + # Inspect profiles sshkeeper list sshkeeper show web @@ -71,7 +74,10 @@ sshkeeper ssh-config install-include Commands that only read profile metadata, such as `list`, `show`, `search`, `config path`, `group list`, and `export`, do not require the master password. -Commands that need secrets ask for the master password in that process. +Commands that need secrets ask for the master password in that process. Adding +`key` or `agent` profiles does not require unlocking the vault; adding +`password` or `key_passphrase` profiles asks for the master password before +storing the secret. ## TUI diff --git a/cmd/add.go b/cmd/add.go index 59d4771..9e0d397 100644 --- a/cmd/add.go +++ b/cmd/add.go @@ -1,12 +1,16 @@ package cmd import ( + "bufio" "fmt" + "io" + "os" + "strconv" "strings" "syscall" - "github.com/spf13/cobra" "github.com/mirivlad/sshkeeper/internal/model" + "github.com/spf13/cobra" "golang.org/x/term" ) @@ -32,10 +36,18 @@ var addCmd = &cobra.Command{ if len(args) == 1 && addFlags.host != "" { return addNonInteractive(args[0]) } - return fmt.Errorf("interactive add not yet implemented, use: sshkeeper add --host --user --auth ") + return addInteractive() }, } +func addInteractive() error { + server, err := promptServerForAdd(os.Stdin, os.Stdout) + if err != nil { + return err + } + return saveServerWithOptionalSecret(server) +} + func addNonInteractive(alias string) error { server := &model.Server{ Alias: alias, @@ -60,6 +72,10 @@ func addNonInteractive(alias string) error { server.DisplayName = alias } + return saveServerWithOptionalSecret(server) +} + +func saveServerWithOptionalSecret(server *model.Server) error { // Handle password/passphrase auth — request interactively, never via argv if server.AuthMethod == model.AuthPassword || server.AuthMethod == model.AuthKeyPassphrase { secretType := "password" @@ -78,14 +94,14 @@ func addNonInteractive(alias string) error { } v := getOrCreateVault() - if !v.IsUnlocked() { - return fmt.Errorf("vault is locked. Run 'sshkeeper vault unlock' first") + if err := unlockVaultForCommand(v); err != nil { + return err } - vaultKey := fmt.Sprintf("server:%s:ssh_password", alias) + vaultKey := fmt.Sprintf("server:%s:ssh_password", server.Alias) vaultType := "ssh_password" if server.AuthMethod == model.AuthKeyPassphrase { - vaultKey = fmt.Sprintf("server:%s:key_passphrase", alias) + vaultKey = fmt.Sprintf("server:%s:key_passphrase", server.Alias) vaultType = "key_passphrase" } @@ -117,6 +133,111 @@ func addNonInteractive(alias string) error { return nil } +func promptServerForAdd(in io.Reader, out io.Writer) (*model.Server, error) { + reader := bufio.NewReader(in) + + alias, err := promptRequired(reader, out, "Alias") + if err != nil { + return nil, err + } + displayName, err := promptOptional(reader, out, "Display name", alias) + if err != nil { + return nil, err + } + host, err := promptRequired(reader, out, "Host") + if err != nil { + return nil, err + } + portText, err := promptOptional(reader, out, "Port", "22") + if err != nil { + return nil, err + } + port, err := strconv.Atoi(portText) + if err != nil || port <= 0 { + return nil, fmt.Errorf("invalid port: %s", portText) + } + user, err := promptOptional(reader, out, "User", "root") + if err != nil { + return nil, err + } + authText, err := promptOptional(reader, out, "Auth method (password/key/key_passphrase/agent)", string(model.AuthKey)) + if err != nil { + return nil, err + } + authMethod := model.AuthMethod(authText) + if !isSupportedAuthMethod(authMethod) { + return nil, fmt.Errorf("unsupported auth method: %s", authText) + } + identityFile, err := promptOptional(reader, out, "Identity file", "") + if err != nil { + return nil, err + } + proxyJump, err := promptOptional(reader, out, "ProxyJump", "") + if err != nil { + return nil, err + } + groupName, err := promptOptional(reader, out, "Group", "") + if err != nil { + return nil, err + } + notes, err := promptOptional(reader, out, "Notes", "") + if err != nil { + return nil, err + } + + return &model.Server{ + Alias: alias, + DisplayName: displayName, + Host: host, + Port: port, + User: user, + AuthMethod: authMethod, + IdentityFile: identityFile, + ProxyJump: proxyJump, + GroupName: groupName, + Notes: notes, + }, nil +} + +func promptRequired(reader *bufio.Reader, out io.Writer, label string) (string, error) { + for { + value, err := promptOptional(reader, out, label, "") + if err != nil { + return "", err + } + if value != "" { + return value, nil + } + fmt.Fprintf(out, "%s is required.\n", label) + } +} + +func promptOptional(reader *bufio.Reader, out io.Writer, label string, defaultValue string) (string, error) { + if defaultValue == "" { + fmt.Fprintf(out, "%s: ", label) + } else { + fmt.Fprintf(out, "%s [%s]: ", label, defaultValue) + } + line, err := reader.ReadString('\n') + if err != nil && err != io.EOF { + return "", err + } + value := strings.TrimSpace(line) + if value == "" { + return defaultValue, nil + } + return value, nil +} + +func isSupportedAuthMethod(method model.AuthMethod) bool { + switch method { + case model.AuthPassword, model.AuthKey, model.AuthKeyPassphrase, model.AuthAgent: + return true + default: + return false + } +} + func init() { addCmd.Flags().StringVar(&addFlags.host, "host", "", "Server hostname or IP") addCmd.Flags().IntVar(&addFlags.port, "port", 22, "SSH port") diff --git a/cmd/add_test.go b/cmd/add_test.go new file mode 100644 index 0000000..3c17417 --- /dev/null +++ b/cmd/add_test.go @@ -0,0 +1,82 @@ +package cmd + +import ( + "bytes" + "strings" + "testing" + + "github.com/mirivlad/sshkeeper/internal/model" +) + +func TestPromptServerForAddCollectsInteractiveFields(t *testing.T) { + input := strings.NewReader(strings.Join([]string{ + "prod", + "Production", + "prod.example.org", + "2222", + "deploy", + "key", + "~/.ssh/id_prod", + "bastion", + "prod", + "critical host", + "", + }, "\n")) + var output bytes.Buffer + + server, err := promptServerForAdd(input, &output) + if err != nil { + t.Fatalf("prompt server: %v", err) + } + + if server.Alias != "prod" || + server.DisplayName != "Production" || + server.Host != "prod.example.org" || + server.Port != 2222 || + server.User != "deploy" || + server.AuthMethod != model.AuthKey || + server.IdentityFile != "~/.ssh/id_prod" || + server.ProxyJump != "bastion" || + server.GroupName != "prod" || + server.Notes != "critical host" { + t.Fatalf("unexpected server: %#v", server) + } + if strings.Contains(output.String(), "not yet implemented") { + t.Fatalf("interactive add should not report unimplemented:\n%s", output.String()) + } +} + +func TestPromptServerForAddAppliesDefaults(t *testing.T) { + input := strings.NewReader(strings.Join([]string{ + "prod", + "", + "prod.example.org", + "", + "", + "", + "", + "", + "", + "", + "", + }, "\n")) + var output bytes.Buffer + + server, err := promptServerForAdd(input, &output) + if err != nil { + t.Fatalf("prompt server: %v", err) + } + + if server.DisplayName != "prod" { + t.Fatalf("display name default = %q; want alias", server.DisplayName) + } + if server.Port != 22 { + t.Fatalf("port default = %d; want 22", server.Port) + } + if server.User != "root" { + t.Fatalf("user default = %q; want root", server.User) + } + if server.AuthMethod != model.AuthKey { + t.Fatalf("auth default = %q; want key", server.AuthMethod) + } +} diff --git a/cmd/connect.go b/cmd/connect.go index 2c44f45..ce7a606 100644 --- a/cmd/connect.go +++ b/cmd/connect.go @@ -3,9 +3,9 @@ package cmd import ( "fmt" - "github.com/spf13/cobra" "github.com/mirivlad/sshkeeper/internal/model" "github.com/mirivlad/sshkeeper/internal/ssh" + "github.com/spf13/cobra" ) var connectCmd = &cobra.Command{ @@ -23,7 +23,7 @@ var connectCmd = &cobra.Command{ v := getOrCreateVault() vaultFunc := func(serverAlias string, secretType string) (string, error) { if !v.IsUnlocked() { - return "", fmt.Errorf("vault is locked. Run 'sshkeeper vault unlock' first") + return "", fmt.Errorf("%s", vaultLockedProcessMessage()) } key := fmt.Sprintf("server:%s:%s", serverAlias, secretType) data, err := v.Get(key) @@ -64,7 +64,7 @@ var testCmd = &cobra.Command{ v := getOrCreateVault() vaultFunc := func(serverAlias string, secretType string) (string, error) { if !v.IsUnlocked() { - return "", fmt.Errorf("vault is locked. Run 'sshkeeper vault unlock' first") + return "", fmt.Errorf("%s", vaultLockedProcessMessage()) } key := fmt.Sprintf("server:%s:%s", serverAlias, secretType) data, err := v.Get(key) diff --git a/cmd/extra.go b/cmd/extra.go index 78a3d7e..d78917b 100644 --- a/cmd/extra.go +++ b/cmd/extra.go @@ -6,9 +6,9 @@ import ( "os/exec" "strings" - "github.com/spf13/cobra" "github.com/mirivlad/sshkeeper/internal/model" "github.com/mirivlad/sshkeeper/internal/ssh" + "github.com/spf13/cobra" ) var importCmd = &cobra.Command{ @@ -74,9 +74,8 @@ var runCmd = &cobra.Command{ return fmt.Errorf("server not found: %s", alias) } - // For password auth, use PTY-wrapper with command - if server.AuthMethod == model.AuthPassword { - return runWithPassword(server, command) + if server.AuthMethod == model.AuthPassword || server.AuthMethod == model.AuthKeyPassphrase { + return runWithSecret(server, command) } // For key/agent auth — direct execution @@ -96,21 +95,25 @@ var runCmd = &cobra.Command{ }, } -// runWithPassword runs a command on a server with password auth via PTY-wrapper. -func runWithPassword(server *model.Server, command string) error { +// runWithSecret runs a command on a server through the PTY prompt handler. +func runWithSecret(server *model.Server, command string) error { v := getOrCreateVault() if !v.IsUnlocked() { - return fmt.Errorf("vault is locked. Run 'sshkeeper vault unlock' first") + return fmt.Errorf("%s", vaultLockedProcessMessage()) } - vaultKey := fmt.Sprintf("server:%s:ssh_password", server.Alias) - password, err := v.Get(vaultKey) + secretType := "ssh_password" + if server.AuthMethod == model.AuthKeyPassphrase { + secretType = "key_passphrase" + } + vaultKey := fmt.Sprintf("server:%s:%s", server.Alias, secretType) + secret, err := v.Get(vaultKey) if err != nil { - return fmt.Errorf("get password from vault: %w", err) + return fmt.Errorf("get %s from vault: %w", secretType, err) } sshArgs := ssh.BuildSSHArgs(server) sshArgs = append(sshArgs, command) - return ssh.ConnectWithPassword(cfg.SSH.Binary, sshArgs, string(password)) + return ssh.ConnectWithPassword(cfg.SSH.Binary, sshArgs, string(secret)) } diff --git a/cmd/init.go b/cmd/init.go index 96ce17f..c452a70 100644 --- a/cmd/init.go +++ b/cmd/init.go @@ -4,9 +4,9 @@ import ( "fmt" "os" - "github.com/spf13/cobra" "github.com/mirivlad/sshkeeper/internal/config" "github.com/mirivlad/sshkeeper/internal/db" + "github.com/spf13/cobra" ) var initCmd = &cobra.Command{ @@ -47,7 +47,7 @@ var initCmd = &cobra.Command{ fmt.Printf("Created database: %s/sshkeeper.db\n", cfg.DataDir) fmt.Printf("Created vault: %s/vault.bin\n", cfg.DataDir) fmt.Println() - fmt.Println("Next step: run 'sshkeeper vault unlock' to set master password.") + fmt.Println("Next step: run 'sshkeeper' or any command that needs secrets to create the vault master password.") return nil }, } diff --git a/cmd/root.go b/cmd/root.go index e8a125f..4762031 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -122,7 +122,7 @@ func initApp() { vaultInstance = v fmt.Println() - fmt.Println("Vault created and unlocked. You're ready to go!") + fmt.Println("Vault created and unlocked for this command. You're ready to go!") fmt.Println() break } @@ -148,7 +148,7 @@ func initApp() { fmt.Printf("Invalid password. %d attempts remaining.\n", remaining) continue } - fmt.Fprintf(os.Stderr, "Too many failed attempts. Run 'sshkeeper vault unlock' to try again.\n") + fmt.Fprintf(os.Stderr, "Too many failed attempts. Start the command again to retry.\n") os.Exit(1) } @@ -172,7 +172,7 @@ func commandRequiresStartupVaultUnlock(args []string) bool { } switch args[0] { - case "connect", "c", "run", "run-template", "test", "add", "edit", "delete": + case "connect", "c", "run", "run-template", "test", "edit", "delete": return true default: return false diff --git a/cmd/root_test.go b/cmd/root_test.go index bfe3044..7f58db7 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -11,7 +11,9 @@ func TestCommandRequiresStartupVaultUnlock(t *testing.T) { {name: "root tui", args: nil, want: true}, {name: "connect", args: []string{"connect", "prod"}, want: true}, {name: "short connect alias", args: []string{"c", "prod"}, want: true}, - {name: "add can store secrets", args: []string{"add", "prod"}, want: true}, + {name: "add unlocks only if selected auth needs a secret", args: []string{"add", "prod"}, want: false}, + {name: "run can use secrets", args: []string{"run", "prod", "uptime"}, want: true}, + {name: "test can use secrets", args: []string{"test", "prod"}, want: true}, {name: "vault handles its own unlock", args: []string{"vault", "list"}, want: false}, {name: "list only reads database", args: []string{"list"}, want: false}, {name: "show only reads database", args: []string{"show", "prod"}, want: false}, diff --git a/cmd/vault.go b/cmd/vault.go index 51f3048..bf6e9fe 100644 --- a/cmd/vault.go +++ b/cmd/vault.go @@ -220,6 +220,10 @@ func unlockVaultForCommand(v *vault.Vault) error { return nil } +func vaultLockedProcessMessage() string { + return "vault is locked in this process; enter the master password when this command prompts for it" +} + func formatVaultStatus(unlocked bool, exists bool) string { if !exists { return "Vault: not found"