diff --git a/cmd/add.go b/cmd/add.go index df6b52f..59d4771 100644 --- a/cmd/add.go +++ b/cmd/add.go @@ -3,9 +3,11 @@ package cmd import ( "fmt" "strings" + "syscall" "github.com/spf13/cobra" "github.com/mirivlad/sshkeeper/internal/model" + "golang.org/x/term" ) var addFlags struct { @@ -19,7 +21,6 @@ var addFlags struct { displayName string notes string tags string - password string } var addCmd = &cobra.Command{ @@ -59,11 +60,21 @@ func addNonInteractive(alias string) error { server.DisplayName = alias } - // Handle password auth - store in vault - if server.AuthMethod == model.AuthPassword { - password := addFlags.password - if password == "" { - return fmt.Errorf("password auth requires --password flag or interactive mode") + // Handle password/passphrase auth — request interactively, never via argv + if server.AuthMethod == model.AuthPassword || server.AuthMethod == model.AuthKeyPassphrase { + secretType := "password" + if server.AuthMethod == model.AuthKeyPassphrase { + secretType = "passphrase" + } + + fmt.Printf("Enter %s (will be stored in vault, input hidden): ", secretType) + password, err := term.ReadPassword(int(syscall.Stdin)) + fmt.Println() + if err != nil { + return fmt.Errorf("read %s: %w", secretType, err) + } + if len(password) == 0 { + return fmt.Errorf("%s cannot be empty", secretType) } v := getOrCreateVault() @@ -72,29 +83,14 @@ func addNonInteractive(alias string) error { } vaultKey := fmt.Sprintf("server:%s:ssh_password", alias) - if err := v.Put(vaultKey, "ssh_password", []byte(password)); err != nil { - return fmt.Errorf("store password in vault: %w", err) - } - if err := v.Save(); err != nil { - return fmt.Errorf("save vault: %w", err) - } - } - - // Handle key+passphrase - store passphrase in vault - if server.AuthMethod == model.AuthKeyPassphrase { - passphrase := addFlags.password - if passphrase == "" { - return fmt.Errorf("key+passphrase auth requires --password flag for the passphrase") + vaultType := "ssh_password" + if server.AuthMethod == model.AuthKeyPassphrase { + vaultKey = fmt.Sprintf("server:%s:key_passphrase", alias) + vaultType = "key_passphrase" } - v := getOrCreateVault() - if !v.IsUnlocked() { - return fmt.Errorf("vault is locked. Run 'sshkeeper vault unlock' first") - } - - vaultKey := fmt.Sprintf("server:%s:key_passphrase", alias) - if err := v.Put(vaultKey, "key_passphrase", []byte(passphrase)); err != nil { - return fmt.Errorf("store passphrase in vault: %w", err) + if err := v.Put(vaultKey, vaultType, password); err != nil { + return fmt.Errorf("store %s in vault: %w", secretType, err) } if err := v.Save(); err != nil { return fmt.Errorf("save vault: %w", err) @@ -132,5 +128,4 @@ func init() { addCmd.Flags().StringVar(&addFlags.displayName, "display-name", "", "Display name") addCmd.Flags().StringVar(&addFlags.notes, "notes", "", "Notes") addCmd.Flags().StringVar(&addFlags.tags, "tags", "", "Comma-separated tags") - addCmd.Flags().StringVar(&addFlags.password, "password", "", "SSH password or key passphrase (stored in vault)") } diff --git a/cmd/delete.go b/cmd/delete.go index cc8c737..a32b6cd 100644 --- a/cmd/delete.go +++ b/cmd/delete.go @@ -27,6 +27,15 @@ var deleteCmd = &cobra.Command{ return fmt.Errorf("delete server: %w", err) } + // Clean up vault secrets for this server + v := getOrCreateVault() + if v.IsUnlocked() { + cleanupServerSecrets(v, alias) + if err := v.Save(); err != nil { + return fmt.Errorf("save vault after cleanup: %w", err) + } + } + fmt.Println("Deleted.") return nil }, diff --git a/cmd/edit.go b/cmd/edit.go index 9480407..a24e287 100644 --- a/cmd/edit.go +++ b/cmd/edit.go @@ -2,9 +2,11 @@ package cmd import ( "fmt" + "syscall" - "github.com/spf13/cobra" "github.com/mirivlad/sshkeeper/internal/model" + "github.com/spf13/cobra" + "golang.org/x/term" ) var editCmd = &cobra.Command{ @@ -18,6 +20,8 @@ var editCmd = &cobra.Command{ return fmt.Errorf("server not found: %s", alias) } + oldAuthMethod := server.AuthMethod + if parsedHost != "" { server.Host = parsedHost } @@ -46,6 +50,41 @@ var editCmd = &cobra.Command{ server.Notes = parsedNotes } + if parsedAuth != "" && oldAuthMethod != server.AuthMethod { + v := getOrCreateVault() + if v.IsUnlocked() { + var secret string + if server.AuthMethod == model.AuthPassword { + fmt.Print("Enter new password (stored in vault, input hidden): ") + pw, err := term.ReadPassword(int(syscall.Stdin)) + fmt.Println() + if err != nil { + return fmt.Errorf("read password: %w", err) + } + if len(pw) > 0 { + secret = string(pw) + } + } else if server.AuthMethod == model.AuthKeyPassphrase { + fmt.Print("Enter key passphrase (stored in vault, input hidden): ") + pw, err := term.ReadPassword(int(syscall.Stdin)) + fmt.Println() + if err != nil { + return fmt.Errorf("read passphrase: %w", err) + } + if len(pw) > 0 { + secret = string(pw) + } + } + + if err := syncServerSecrets(v, alias, server, secret); err != nil { + return fmt.Errorf("sync vault secrets: %w", err) + } + if err := v.Save(); err != nil { + return fmt.Errorf("save vault: %w", err) + } + } + } + if err := appDB.UpdateServer(server); err != nil { return fmt.Errorf("update server: %w", err) } @@ -56,15 +95,15 @@ var editCmd = &cobra.Command{ } var ( - parsedHost string - parsedPort int - parsedUser string - parsedAuth string - parsedIdentity string - parsedProxyJump string - parsedGroup string + parsedHost string + parsedPort int + parsedUser string + parsedAuth string + parsedIdentity string + parsedProxyJump string + parsedGroup string parsedDisplayName string - parsedNotes string + parsedNotes string ) func init() { diff --git a/cmd/extra.go b/cmd/extra.go index 7f54d0c..78a3d7e 100644 --- a/cmd/extra.go +++ b/cmd/extra.go @@ -74,8 +74,13 @@ var runCmd = &cobra.Command{ return fmt.Errorf("server not found: %s", alias) } - // Build ssh args with the command - sshArgs := buildSSHArgs(server) + // For password auth, use PTY-wrapper with command + if server.AuthMethod == model.AuthPassword { + return runWithPassword(server, command) + } + + // For key/agent auth — direct execution + sshArgs := ssh.BuildSSHArgs(server) sshArgs = append(sshArgs, command) sshCmd := exec.Command(cfg.SSH.Binary, sshArgs...) @@ -91,17 +96,21 @@ var runCmd = &cobra.Command{ }, } -func buildSSHArgs(server *model.Server) []string { - var args []string - args = append(args, "-p", fmt.Sprintf("%d", server.Port)) - if server.IdentityFile != "" { - args = append(args, "-i", server.IdentityFile) +// runWithPassword runs a command on a server with password auth via PTY-wrapper. +func runWithPassword(server *model.Server, command string) error { + v := getOrCreateVault() + if !v.IsUnlocked() { + return fmt.Errorf("vault is locked. Run 'sshkeeper vault unlock' first") } - if server.ProxyJump != "" { - args = append(args, "-J", server.ProxyJump) + + vaultKey := fmt.Sprintf("server:%s:ssh_password", server.Alias) + password, err := v.Get(vaultKey) + if err != nil { + return fmt.Errorf("get password from vault: %w", err) } - args = append(args, "-o", "StrictHostKeyChecking=accept-new") - target := fmt.Sprintf("%s@%s", server.User, server.Host) - args = append(args, target) - return args + + sshArgs := ssh.BuildSSHArgs(server) + sshArgs = append(sshArgs, command) + + return ssh.ConnectWithPassword(cfg.SSH.Binary, sshArgs, string(password)) } diff --git a/cmd/secrets.go b/cmd/secrets.go new file mode 100644 index 0000000..18352f5 --- /dev/null +++ b/cmd/secrets.go @@ -0,0 +1,85 @@ +package cmd + +import ( + "fmt" + + "github.com/mirivlad/sshkeeper/internal/model" + "github.com/mirivlad/sshkeeper/internal/ssh" + "github.com/mirivlad/sshkeeper/internal/vault" +) + +const ( + secretSSHPassword = "ssh_password" + secretKeyPassphrase = "key_passphrase" + secretSudoPassword = "sudo_password" +) + +var serverSecretTypes = []string{ + secretSSHPassword, + secretKeyPassphrase, + secretSudoPassword, +} + +func serverSecretID(alias, secretType string) string { + return fmt.Sprintf("server:%s:%s", alias, secretType) +} + +func cleanupServerSecrets(v *vault.Vault, alias string) { + for _, secretType := range serverSecretTypes { + v.Delete(serverSecretID(alias, secretType)) + } +} + +func syncServerSecrets(v *vault.Vault, oldAlias string, server *model.Server, secret string) error { + if oldAlias == "" { + oldAlias = server.Alias + } + if oldAlias != server.Alias { + for _, secretType := range serverSecretTypes { + oldID := serverSecretID(oldAlias, secretType) + data, err := v.Get(oldID) + if err == nil { + if err := v.Put(serverSecretID(server.Alias, secretType), secretType, data); err != nil { + return err + } + } + v.Delete(oldID) + } + } + + switch server.AuthMethod { + case model.AuthPassword: + v.Delete(serverSecretID(server.Alias, secretKeyPassphrase)) + if secret != "" { + return v.Put(serverSecretID(server.Alias, secretSSHPassword), secretSSHPassword, []byte(secret)) + } + case model.AuthKeyPassphrase: + v.Delete(serverSecretID(server.Alias, secretSSHPassword)) + if secret != "" { + return v.Put(serverSecretID(server.Alias, secretKeyPassphrase), secretKeyPassphrase, []byte(secret)) + } + default: + v.Delete(serverSecretID(server.Alias, secretSSHPassword)) + v.Delete(serverSecretID(server.Alias, secretKeyPassphrase)) + } + + return nil +} + +func deleteVaultSecrets(v *vault.Vault, alias string, secretType string) error { + if secretType != "" { + v.Delete(serverSecretID(alias, secretType)) + return nil + } + cleanupServerSecrets(v, alias) + return nil +} + +func formTestVaultFunc(getVault ssh.VaultFunc, server *model.Server, formSecret string) ssh.VaultFunc { + return func(serverAlias string, secretType string) (string, error) { + if (secretType == secretSSHPassword || secretType == secretKeyPassphrase) && formSecret != "" { + return formSecret, nil + } + return getVault(serverAlias, secretType) + } +} diff --git a/cmd/secrets_test.go b/cmd/secrets_test.go new file mode 100644 index 0000000..39a839c --- /dev/null +++ b/cmd/secrets_test.go @@ -0,0 +1,186 @@ +package cmd + +import ( + "fmt" + "path/filepath" + "testing" + + "github.com/mirivlad/sshkeeper/internal/model" + "github.com/mirivlad/sshkeeper/internal/vault" +) + +func newUnlockedTestVault(t *testing.T) *vault.Vault { + t.Helper() + path := filepath.Join(t.TempDir(), "vault.bin") + if err := vault.Create(path, "master"); err != nil { + t.Fatalf("create vault: %v", err) + } + v := vault.New(path) + if err := v.Unlock("master"); err != nil { + t.Fatalf("unlock vault: %v", err) + } + return v +} + +func mustPutSecret(t *testing.T, v *vault.Vault, alias, secretType, value string) { + t.Helper() + if err := v.Put(serverSecretID(alias, secretType), secretType, []byte(value)); err != nil { + t.Fatalf("put %s: %v", secretType, err) + } +} + +func mustGetSecret(t *testing.T, v *vault.Vault, alias, secretType string) string { + t.Helper() + data, err := v.Get(serverSecretID(alias, secretType)) + if err != nil { + t.Fatalf("get %s: %v", secretType, err) + } + return string(data) +} + +func assertSecretMissing(t *testing.T, v *vault.Vault, alias, secretType string) { + t.Helper() + if data, err := v.Get(serverSecretID(alias, secretType)); err == nil { + t.Fatalf("expected %s for %s to be missing, got %q", secretType, alias, string(data)) + } +} + +func TestCleanupServerSecretsRemovesAuthAndSudoSecrets(t *testing.T) { + v := newUnlockedTestVault(t) + mustPutSecret(t, v, "prod", "ssh_password", "password") + mustPutSecret(t, v, "prod", "key_passphrase", "passphrase") + mustPutSecret(t, v, "prod", "sudo_password", "sudo") + + cleanupServerSecrets(v, "prod") + + assertSecretMissing(t, v, "prod", "ssh_password") + assertSecretMissing(t, v, "prod", "key_passphrase") + assertSecretMissing(t, v, "prod", "sudo_password") +} + +func TestSyncServerSecretsDeletesAuthSecretsForKeyAuth(t *testing.T) { + v := newUnlockedTestVault(t) + mustPutSecret(t, v, "prod", "ssh_password", "password") + mustPutSecret(t, v, "prod", "key_passphrase", "passphrase") + + server := &model.Server{Alias: "prod", AuthMethod: model.AuthKey} + if err := syncServerSecrets(v, "prod", server, ""); err != nil { + t.Fatalf("sync secrets: %v", err) + } + + assertSecretMissing(t, v, "prod", "ssh_password") + assertSecretMissing(t, v, "prod", "key_passphrase") +} + +func TestSyncServerSecretsKeepsExistingPasswordWhenEditPasswordIsBlank(t *testing.T) { + v := newUnlockedTestVault(t) + mustPutSecret(t, v, "prod", "ssh_password", "old-password") + mustPutSecret(t, v, "prod", "key_passphrase", "old-passphrase") + + server := &model.Server{Alias: "prod", AuthMethod: model.AuthPassword} + if err := syncServerSecrets(v, "prod", server, ""); err != nil { + t.Fatalf("sync secrets: %v", err) + } + + if got := mustGetSecret(t, v, "prod", "ssh_password"); got != "old-password" { + t.Fatalf("expected password to remain, got %q", got) + } + assertSecretMissing(t, v, "prod", "key_passphrase") +} + +func TestSyncServerSecretsRenamesSecretsBeforeApplyingAuthCleanup(t *testing.T) { + v := newUnlockedTestVault(t) + mustPutSecret(t, v, "old", "ssh_password", "password") + mustPutSecret(t, v, "old", "key_passphrase", "passphrase") + mustPutSecret(t, v, "old", "sudo_password", "sudo") + + server := &model.Server{Alias: "new", AuthMethod: model.AuthKeyPassphrase} + if err := syncServerSecrets(v, "old", server, ""); err != nil { + t.Fatalf("sync secrets: %v", err) + } + + assertSecretMissing(t, v, "old", "ssh_password") + assertSecretMissing(t, v, "old", "key_passphrase") + assertSecretMissing(t, v, "old", "sudo_password") + assertSecretMissing(t, v, "new", "ssh_password") + if got := mustGetSecret(t, v, "new", "key_passphrase"); got != "passphrase" { + t.Fatalf("expected key passphrase to move, got %q", got) + } + if got := mustGetSecret(t, v, "new", "sudo_password"); got != "sudo" { + t.Fatalf("expected sudo password to move, got %q", got) + } +} + +func TestSyncServerSecretsStoresNewSecretForSelectedAuthMethod(t *testing.T) { + v := newUnlockedTestVault(t) + mustPutSecret(t, v, "prod", "key_passphrase", "old-passphrase") + + server := &model.Server{Alias: "prod", AuthMethod: model.AuthPassword} + if err := syncServerSecrets(v, "prod", server, "new-password"); err != nil { + t.Fatalf("sync secrets: %v", err) + } + + if got := mustGetSecret(t, v, "prod", "ssh_password"); got != "new-password" { + t.Fatalf("expected new password, got %q", got) + } + assertSecretMissing(t, v, "prod", "key_passphrase") +} + +func TestDeleteVaultSecretsForAliasAndType(t *testing.T) { + v := newUnlockedTestVault(t) + mustPutSecret(t, v, "prod", "ssh_password", "password") + mustPutSecret(t, v, "prod", "key_passphrase", "passphrase") + + if err := deleteVaultSecrets(v, "prod", "ssh_password"); err != nil { + t.Fatalf("delete vault secret: %v", err) + } + + assertSecretMissing(t, v, "prod", "ssh_password") + if got := mustGetSecret(t, v, "prod", "key_passphrase"); got != "passphrase" { + t.Fatalf("expected passphrase to remain, got %q", got) + } +} + +func TestDeleteVaultSecretsForAlias(t *testing.T) { + v := newUnlockedTestVault(t) + mustPutSecret(t, v, "prod", "ssh_password", "password") + mustPutSecret(t, v, "prod", "key_passphrase", "passphrase") + + if err := deleteVaultSecrets(v, "prod", ""); err != nil { + t.Fatalf("delete vault secrets: %v", err) + } + + assertSecretMissing(t, v, "prod", "ssh_password") + assertSecretMissing(t, v, "prod", "key_passphrase") +} + +func TestFormTestVaultFuncUsesSavedSecretWhenFormSecretBlank(t *testing.T) { + vaultFunc := formTestVaultFunc(func(serverAlias string, secretType string) (string, error) { + if serverAlias != "prod" || secretType != "ssh_password" { + return "", fmt.Errorf("unexpected secret lookup %s %s", serverAlias, secretType) + } + return "saved-password", nil + }, &model.Server{Alias: "prod", AuthMethod: model.AuthPassword}, "") + + got, err := vaultFunc("prod", "ssh_password") + if err != nil { + t.Fatalf("get password: %v", err) + } + if got != "saved-password" { + t.Fatalf("expected saved password, got %q", got) + } +} + +func TestFormTestVaultFuncUsesFormSecretWhenProvided(t *testing.T) { + vaultFunc := formTestVaultFunc(func(serverAlias string, secretType string) (string, error) { + return "", fmt.Errorf("saved secret should not be used") + }, &model.Server{Alias: "prod", AuthMethod: model.AuthPassword}, "typed-password") + + got, err := vaultFunc("prod", "ssh_password") + if err != nil { + t.Fatalf("get password: %v", err) + } + if got != "typed-password" { + t.Fatalf("expected typed password, got %q", got) + } +} diff --git a/cmd/template.go b/cmd/template.go index db74f93..3d4a01d 100644 --- a/cmd/template.go +++ b/cmd/template.go @@ -2,8 +2,11 @@ package cmd import ( "fmt" + "os" + "os/exec" "github.com/spf13/cobra" + "github.com/mirivlad/sshkeeper/internal/ssh" ) var templateCmd = &cobra.Command{ @@ -70,13 +73,15 @@ var runTemplateCmd = &cobra.Command{ alias := args[0] templateName := args[1] + server, err := appDB.GetServer(alias) + if err != nil { + return fmt.Errorf("server not found: %s", alias) + } + templates, err := appDB.GetCommandTemplates(alias) if err != nil { return fmt.Errorf("list templates: %w", err) } - if len(templates) == 0 { - return fmt.Errorf("server not found or no templates: %s", alias) - } var command string for _, t := range templates { @@ -91,7 +96,21 @@ var runTemplateCmd = &cobra.Command{ } fmt.Printf("Running '%s' on %s...\n", command, alias) - return nil + + // Build ssh args with the command + sshArgs := ssh.BuildSSHArgs(server) + sshArgs = append(sshArgs, command) + + sshCmd := exec.Command(cfg.SSH.Binary, sshArgs...) + sshCmd.Stdin = os.Stdin + sshCmd.Stdout = os.Stdout + sshCmd.Stderr = os.Stderr + + if err := sshCmd.Start(); err != nil { + return fmt.Errorf("start ssh: %w", err) + } + + return sshCmd.Wait() }, } diff --git a/cmd/tui.go b/cmd/tui.go index 5abc250..5322182 100644 --- a/cmd/tui.go +++ b/cmd/tui.go @@ -36,41 +36,43 @@ func runTUI() error { return appDB.SearchServers(query) } tui.DeleteServer = func(alias string) error { - return appDB.DeleteServer(alias) + if err := appDB.DeleteServer(alias); err != nil { + return err + } + v := getOrCreateVault() + if v.IsUnlocked() { + cleanupServerSecrets(v, alias) + if err := v.Save(); err != nil { + return fmt.Errorf("save vault after cleanup: %w", err) + } + } + return nil } tui.TestConnection = func(server *model.Server) (bool, string) { return ssh.Test(cfg, server, vaultFunc) } tui.TestConnectionWithPassword = func(server *model.Server, password string) (bool, string) { - directVaultFunc := func(sa string, st string) (string, error) { - if st == "ssh_password" || st == "key_passphrase" { - return password, nil - } - return vaultFunc(sa, st) - } - return ssh.Test(cfg, server, directVaultFunc) + return ssh.Test(cfg, server, formTestVaultFunc(vaultFunc, server, password)) } - tui.SaveServer = func(server *model.Server, password string) error { - if password != "" { - v := getOrCreateVault() - vaultKey := fmt.Sprintf("server:%s:ssh_password", server.Alias) - secretType := "ssh_password" - if server.AuthMethod == model.AuthKeyPassphrase { - vaultKey = fmt.Sprintf("server:%s:key_passphrase", server.Alias) - secretType = "key_passphrase" - } - if err := v.Put(vaultKey, secretType, []byte(password)); err != nil { - return fmt.Errorf("store secret: %w", err) + tui.SaveServer = func(server *model.Server, password string, oldAlias string) error { + v := getOrCreateVault() + if v.IsUnlocked() { + if err := syncServerSecrets(v, oldAlias, server, password); err != nil { + return fmt.Errorf("sync vault secrets: %w", err) } if err := v.Save(); err != nil { return fmt.Errorf("save vault: %w", err) } } - existing, _ := appDB.GetServer(server.Alias) + lookupAlias := server.Alias + if oldAlias != "" { + lookupAlias = oldAlias + } + existing, _ := appDB.GetServer(lookupAlias) if existing != nil { server.ID = existing.ID - return appDB.UpdateServer(server) + return appDB.UpdateServerByAlias(existing.Alias, server) } return appDB.CreateServer(server) } @@ -84,6 +86,16 @@ func runTUI() error { tui.DeleteGroup = func(name string) error { return appDB.DeleteGroup(name) } + tui.UpdateTestResult = func(alias string, status model.TestStatus, testErr string) error { + return appDB.UpdateTestResult(alias, status, testErr) + } + tui.HasSecret = func(alias string, secretType string) bool { + v := getOrCreateVault() + if !v.IsUnlocked() { + return false + } + return v.HasSecret(serverSecretID(alias, secretType)) + } // Run TUI in a loop — if user requests connect, handle it and restart TUI for { @@ -132,5 +144,3 @@ func runTUI() error { return nil } } - - diff --git a/cmd/vault.go b/cmd/vault.go index fe3f627..c85ab83 100644 --- a/cmd/vault.go +++ b/cmd/vault.go @@ -3,11 +3,12 @@ package cmd import ( "fmt" "os" + "strings" "syscall" - "github.com/spf13/cobra" "github.com/mirivlad/sshkeeper/internal/config" "github.com/mirivlad/sshkeeper/internal/vault" + "github.com/spf13/cobra" "golang.org/x/term" ) @@ -160,9 +161,75 @@ var vaultChangePasswordCmd = &cobra.Command{ }, } +var vaultListCmd = &cobra.Command{ + Use: "list", + Short: "List stored secret metadata", + RunE: func(cmd *cobra.Command, args []string) error { + v := getOrCreateVault() + if !v.IsUnlocked() { + return fmt.Errorf("vault is locked. Unlock first with 'sshkeeper vault unlock'") + } + output, err := formatVaultSecretsList(v) + if err != nil { + return err + } + fmt.Print(output) + return nil + }, +} + +var vaultDeleteCmd = &cobra.Command{ + Use: "delete [type]", + Short: "Delete stored secrets for a server", + Args: cobra.RangeArgs(1, 2), + RunE: func(cmd *cobra.Command, args []string) error { + alias := args[0] + secretType := "" + if len(args) == 2 { + secretType = args[1] + } + + v := getOrCreateVault() + if !v.IsUnlocked() { + return fmt.Errorf("vault is locked. Unlock first with 'sshkeeper vault unlock'") + } + if err := deleteVaultSecrets(v, alias, secretType); err != nil { + return err + } + if err := v.Save(); err != nil { + return fmt.Errorf("save vault: %w", err) + } + if secretType == "" { + fmt.Printf("Deleted secrets for %s.\n", alias) + } else { + fmt.Printf("Deleted %s for %s.\n", secretType, alias) + } + return nil + }, +} + +func formatVaultSecretsList(v *vault.Vault) (string, error) { + metas, err := v.ListSecrets() + if err != nil { + return "", err + } + if len(metas) == 0 { + return "No secrets stored.\n", nil + } + + var b strings.Builder + fmt.Fprintf(&b, "%-24s %-18s\n", "ALIAS", "TYPE") + for _, meta := range metas { + fmt.Fprintf(&b, "%-24s %-18s\n", meta.Alias, meta.Type) + } + return b.String(), nil +} + func init() { vaultCmd.AddCommand(vaultUnlockCmd) vaultCmd.AddCommand(vaultLockCmd) vaultCmd.AddCommand(vaultStatusCmd) vaultCmd.AddCommand(vaultChangePasswordCmd) + vaultCmd.AddCommand(vaultListCmd) + vaultCmd.AddCommand(vaultDeleteCmd) } diff --git a/cmd/vault_test.go b/cmd/vault_test.go new file mode 100644 index 0000000..b84829f --- /dev/null +++ b/cmd/vault_test.go @@ -0,0 +1,40 @@ +package cmd + +import ( + "strings" + "testing" +) + +func TestFormatVaultSecretsListDoesNotExposeSecretValues(t *testing.T) { + v := newUnlockedTestVault(t) + mustPutSecret(t, v, "prod", "ssh_password", "super-secret") + mustPutSecret(t, v, "stage", "key_passphrase", "also-secret") + + output, err := formatVaultSecretsList(v) + if err != nil { + t.Fatalf("format vault secrets list: %v", err) + } + + for _, want := range []string{"prod", "ssh_password", "stage", "key_passphrase"} { + if !strings.Contains(output, want) { + t.Fatalf("expected output to contain %q\noutput:\n%s", want, output) + } + } + for _, secretValue := range []string{"super-secret", "also-secret"} { + if strings.Contains(output, secretValue) { + t.Fatalf("expected output not to expose secret value %q\noutput:\n%s", secretValue, output) + } + } +} + +func TestFormatVaultSecretsListHandlesEmptyVault(t *testing.T) { + v := newUnlockedTestVault(t) + + output, err := formatVaultSecretsList(v) + if err != nil { + t.Fatalf("format empty vault secrets list: %v", err) + } + if !strings.Contains(output, "No secrets stored.") { + t.Fatalf("expected empty output message, got:\n%s", output) + } +} diff --git a/docs/superpowers/plans/2026-05-28-server-list-scalability.md b/docs/superpowers/plans/2026-05-28-server-list-scalability.md new file mode 100644 index 0000000..fb3d98b --- /dev/null +++ b/docs/superpowers/plans/2026-05-28-server-list-scalability.md @@ -0,0 +1,295 @@ +# Server List Scalability Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Keep the server list usable when the saved server count is larger than the terminal height. + +**Architecture:** The list screen should render a bounded table viewport instead of rendering every server row. Selection remains owned by the existing `bubbles/list.Model`, while `viewServerList` derives a visible row range around the selected item and keeps the selected-server detail panel and footer visible. + +**Tech Stack:** Go, Bubble Tea, Bubbles list, Lip Gloss, existing TUI tests in `internal/tui/app_test.go`. + +--- + +### Task 1: Add Regression Coverage For Long Server Lists + +**Files:** +- Modify: `internal/tui/app_test.go` + +- [ ] **Step 1: Add a test for a constrained terminal height** + +Add a test that creates more servers than can fit on screen, sets a small terminal size, renders the list, and verifies the selected details and footer are still visible. + +```go +func TestServerListViewKeepsDetailsVisibleWithManyServers(t *testing.T) { + servers := make([]*model.Server, 45) + for i := range servers { + servers[i] = &model.Server{ + Alias: fmt.Sprintf("server-%02d", i+1), + DisplayName: fmt.Sprintf("Server %02d", i+1), + Host: fmt.Sprintf("host-%02d.example.org", i+1), + Port: 22, + User: "mirivlad", + AuthMethod: model.AuthKey, + LastTestStatus: model.TestUnknown, + } + } + + m := New(servers) + updated, _ := m.Update(tea.WindowSizeMsg{Width: 100, Height: 18}) + model := updated.(*tuiModel) + + view := model.View() + if !strings.Contains(view, "Server 01") { + t.Fatalf("expected first selected server to be visible:\n%s", view) + } + if !strings.Contains(view, "Selected") { + t.Fatalf("expected selected server details to remain visible:\n%s", view) + } + if !strings.Contains(view, "Enter connect") { + t.Fatalf("expected footer to remain visible:\n%s", view) + } + if count := strings.Count(view, "server-"); count >= len(servers) { + t.Fatalf("expected bounded row rendering, rendered %d server aliases", count) + } +} +``` + +- [ ] **Step 2: Run the focused test and confirm it fails** + +Run: + +```bash +env GOCACHE=/tmp/sshkeeper-go-cache go test ./internal/tui -run TestServerListViewKeepsDetailsVisibleWithManyServers -count=1 +``` + +Expected: FAIL because the current table renders every server and can push the detail panel/footer below the visible terminal area. + +### Task 2: Compute A Visible Row Window + +**Files:** +- Modify: `internal/tui/app.go` +- Modify: `internal/tui/app_test.go` + +- [ ] **Step 1: Add focused tests for visible range calculation** + +Add tests for a helper that computes the inclusive start and exclusive end indexes for rendered rows. + +```go +func TestVisibleServerRangeKeepsSelectionInsideWindow(t *testing.T) { + tests := []struct { + name string + total int + selected int + available int + wantStart int + wantEnd int + }{ + {name: "first page", total: 40, selected: 0, available: 10, wantStart: 0, wantEnd: 10}, + {name: "middle page", total: 40, selected: 20, available: 10, wantStart: 11, wantEnd: 21}, + {name: "last page", total: 40, selected: 39, available: 10, wantStart: 30, wantEnd: 40}, + {name: "all fit", total: 5, selected: 3, available: 10, wantStart: 0, wantEnd: 5}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + start, end := visibleServerRange(tt.total, tt.selected, tt.available) + if start != tt.wantStart || end != tt.wantEnd { + t.Fatalf("visibleServerRange() = %d, %d; want %d, %d", start, end, tt.wantStart, tt.wantEnd) + } + }) + } +} +``` + +- [ ] **Step 2: Implement `visibleServerRange`** + +Add a small helper near `selectedServer`. + +```go +func visibleServerRange(total, selected, available int) (int, int) { + if total <= 0 || available <= 0 { + return 0, 0 + } + if available >= total { + return 0, total + } + if selected < 0 { + selected = 0 + } + if selected >= total { + selected = total - 1 + } + + start := selected - available + 1 + if start < 0 { + start = 0 + } + end := start + available + if end > total { + end = total + start = end - available + } + return start, end +} +``` + +- [ ] **Step 3: Run helper tests** + +Run: + +```bash +env GOCACHE=/tmp/sshkeeper-go-cache go test ./internal/tui -run TestVisibleServerRangeKeepsSelectionInsideWindow -count=1 +``` + +Expected: PASS. + +### Task 3: Render Only Rows That Fit + +**Files:** +- Modify: `internal/tui/app.go` +- Modify: `internal/tui/app_test.go` + +- [ ] **Step 1: Reserve terminal space for fixed UI blocks** + +Add a helper that decides how many server rows may be rendered while keeping the selected details and footer visible. + +```go +func (m *tuiModel) visibleServerRows() int { + if m.height <= 0 { + return len(m.servers) + } + + const fixedRows = 16 + rows := m.height - fixedRows + if rows < 3 { + return 3 + } + return rows +} +``` + +- [ ] **Step 2: Use the visible range in `viewServerList`** + +In `viewServerList`, replace the loop over all servers with a bounded loop: + +```go +selectedIndex := m.list.Index() +start, end := visibleServerRange(len(m.servers), selectedIndex, m.visibleServerRows()) +for _, server := range m.servers[start:end] { + // existing row rendering body stays unchanged +} +``` + +Then render a compact range hint when rows are hidden: + +```go +if len(m.servers) > end-start { + b.WriteString(helpStyle.Render(fmt.Sprintf(" Showing %d-%d of %d", start+1, end, len(m.servers)))) + b.WriteString("\n") +} +``` + +- [ ] **Step 3: Run long-list regression test** + +Run: + +```bash +env GOCACHE=/tmp/sshkeeper-go-cache go test ./internal/tui -run TestServerListViewKeepsDetailsVisibleWithManyServers -count=1 +``` + +Expected: PASS. + +### Task 4: Verify Navigation Still Works + +**Files:** +- Modify: `internal/tui/app_test.go` + +- [ ] **Step 1: Add a test for moving selection beyond the first window** + +Use the existing `m.list.Update` path by sending `tea.KeyDown` messages and confirm the rendered window follows the selected server. + +```go +func TestServerListViewScrollsWithSelection(t *testing.T) { + servers := make([]*model.Server, 45) + for i := range servers { + servers[i] = &model.Server{ + Alias: fmt.Sprintf("server-%02d", i+1), + DisplayName: fmt.Sprintf("Server %02d", i+1), + Host: fmt.Sprintf("host-%02d.example.org", i+1), + Port: 22, + User: "mirivlad", + AuthMethod: model.AuthKey, + } + } + + m := New(servers) + updated, _ := m.Update(tea.WindowSizeMsg{Width: 100, Height: 18}) + model := updated.(*tuiModel) + for i := 0; i < 20; i++ { + updated, _ = model.Update(tea.KeyMsg{Type: tea.KeyDown}) + model = updated.(*tuiModel) + } + + view := model.View() + if !strings.Contains(view, "Server 21") { + t.Fatalf("expected selected server to be visible after navigation:\n%s", view) + } + if !strings.Contains(view, "Showing") { + t.Fatalf("expected range hint for long server list:\n%s", view) + } +} +``` + +- [ ] **Step 2: Run TUI tests** + +Run: + +```bash +env GOCACHE=/tmp/sshkeeper-go-cache go test ./internal/tui -count=1 +``` + +Expected: PASS. + +### Task 5: Final Verification And Build + +**Files:** +- No source edits expected. + +- [ ] **Step 1: Run the full test suite** + +Run: + +```bash +env GOCACHE=/tmp/sshkeeper-go-cache go test ./... +``` + +Expected: all packages pass. + +- [ ] **Step 2: Rebuild the project binary** + +Run: + +```bash +env GOCACHE=/tmp/sshkeeper-go-cache go build -o bin/sshkeeper . +``` + +Expected: exit code 0 and updated `bin/sshkeeper`. + +- [ ] **Step 3: Commit the implementation** + +Run: + +```bash +git add internal/tui/app.go internal/tui/app_test.go bin/sshkeeper +git commit -m "fix: keep server list usable with many servers" +``` + +Expected: commit succeeds. + +--- + +## Self-Review + +- Spec coverage: the plan covers the known failure mode for 40+ servers, keeps selected details visible, keeps the footer visible, and preserves existing `bubbles/list` navigation. +- Placeholder scan: no `TBD`, `TODO`, or open-ended implementation placeholders remain. +- Type consistency: helper names and files match the current TUI code shape. diff --git a/go.mod b/go.mod index a1c15d2..cf00dc6 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/creack/pty v1.1.24 github.com/spf13/cobra v1.10.2 golang.org/x/crypto v0.52.0 + golang.org/x/sys v0.45.0 golang.org/x/term v0.43.0 modernc.org/sqlite v1.50.1 ) @@ -41,7 +42,6 @@ require ( github.com/sahilm/fuzzy v0.1.1 // indirect github.com/spf13/pflag v1.0.9 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect - golang.org/x/sys v0.45.0 // indirect golang.org/x/text v0.37.0 // indirect modernc.org/libc v1.72.3 // indirect modernc.org/mathutil v1.7.1 // indirect diff --git a/internal/db/servers.go b/internal/db/servers.go index 83c43a0..21add61 100644 --- a/internal/db/servers.go +++ b/internal/db/servers.go @@ -30,6 +30,17 @@ func (db *DB) UpdateServer(s *model.Server) error { return err } +func (db *DB) UpdateServerByAlias(oldAlias string, s *model.Server) error { + _, err := db.conn.Exec(` + UPDATE servers SET + alias=?, display_name=?, host=?, port=?, user=?, auth_method=?, + identity_file=?, proxy_jump=?, group_name=?, notes=?, updated_at=CURRENT_TIMESTAMP + WHERE alias=?`, + s.Alias, s.DisplayName, s.Host, s.Port, s.User, s.AuthMethod, + s.IdentityFile, s.ProxyJump, s.GroupName, s.Notes, oldAlias) + return err +} + func (db *DB) DeleteServer(alias string) error { _, err := db.conn.Exec("DELETE FROM servers WHERE alias=?", alias) return err diff --git a/internal/db/servers_test.go b/internal/db/servers_test.go new file mode 100644 index 0000000..2d2c297 --- /dev/null +++ b/internal/db/servers_test.go @@ -0,0 +1,47 @@ +package db + +import ( + "testing" + + "github.com/mirivlad/sshkeeper/internal/model" +) + +func TestUpdateServerByAliasCanRenameAlias(t *testing.T) { + db, err := Open(t.TempDir()) + if err != nil { + t.Fatalf("open db: %v", err) + } + defer db.Close() + + server := &model.Server{ + Alias: "old", + Host: "old.example", + Port: 22, + User: "root", + AuthMethod: model.AuthPassword, + } + if err := db.CreateServer(server); err != nil { + t.Fatalf("create server: %v", err) + } + + server.Alias = "new" + server.Host = "new.example" + server.AuthMethod = model.AuthKey + if err := db.UpdateServerByAlias("old", server); err != nil { + t.Fatalf("update server by old alias: %v", err) + } + + if _, err := db.GetServer("old"); err == nil { + t.Fatal("expected old alias to be gone") + } + got, err := db.GetServer("new") + if err != nil { + t.Fatalf("get new alias: %v", err) + } + if got.ID != server.ID { + t.Fatalf("expected ID to stay %d, got %d", server.ID, got.ID) + } + if got.Host != "new.example" || got.AuthMethod != model.AuthKey { + t.Fatalf("unexpected updated server: %#v", got) + } +} diff --git a/internal/ssh/command.go b/internal/ssh/command.go index 6837ea7..7aa088f 100644 --- a/internal/ssh/command.go +++ b/internal/ssh/command.go @@ -5,7 +5,6 @@ import ( "os" "os/exec" "strings" - "time" "github.com/mirivlad/sshkeeper/internal/config" "github.com/mirivlad/sshkeeper/internal/model" @@ -14,7 +13,7 @@ import ( type VaultFunc func(serverAlias string, secretType string) (string, error) func Connect(cfg *config.Config, server *model.Server, getVault VaultFunc) error { - args := buildArgs(server) + args := BuildSSHArgs(server) switch server.AuthMethod { case model.AuthPassword: @@ -25,8 +24,7 @@ func Connect(cfg *config.Config, server *model.Server, getVault VaultFunc) error return ConnectWithPassword(cfg.SSH.Binary, args, password) case model.AuthKeyPassphrase: - // For key+passphrase, we need to handle the passphrase - // For now, let ssh-agent handle it or prompt normally + // For key+passphrase, let ssh-agent handle it or prompt normally // TODO: use ssh-agent or similar fallthrough @@ -46,13 +44,11 @@ func Connect(cfg *config.Config, server *model.Server, getVault VaultFunc) error } func Test(cfg *config.Config, server *model.Server, getVault VaultFunc) (bool, string) { - args := buildArgs(server) + args := BuildSSHArgs(server) args = append(args, "-o", fmt.Sprintf("ConnectTimeout=%d", cfg.SSH.ConnectTimeoutSec)) switch server.AuthMethod { case model.AuthPassword: - // For password auth, we can't use BatchMode - // Use a short timeout and try to connect args = append(args, "-o", "NumberOfPasswordPrompts=1") password, err := getVault(server.Alias, "ssh_password") if err != nil { @@ -81,40 +77,29 @@ func Test(cfg *config.Config, server *model.Server, getVault VaultFunc) (bool, s } } +// testWithPassword tests SSH connection with password auth via PTY-wrapper. +// It connects, sends the password, runs the test command, and checks the output. func testWithPassword(cfg *config.Config, args []string, password string) (bool, string) { - // For password test, we use PTY approach with a short timeout - // This is a simplified version - in production, use ConnectWithPassword - // with a test command args = append(args, cfg.SSH.TestCommand) - cmd := exec.Command(cfg.SSH.Binary, args...) - cmd.Stdin = nil - cmd.Stdout = nil - cmd.Stderr = nil - - // Use a timeout - done := make(chan error, 1) - if err := cmd.Start(); err != nil { - return false, err.Error() + ok, output := connectWithPasswordAndRead(cfg.SSH.Binary, args, password, cfg.SSH.ConnectTimeoutSec) + if !ok { + return false, output } - go func() { - done <- cmd.Wait() - }() - - select { - case err := <-done: - if err != nil { - return false, err.Error() - } + result := strings.TrimSpace(output) + if result == "SSHKEEPER_OK" { return true, "" - case <-time.After(time.Duration(cfg.SSH.ConnectTimeoutSec) * time.Second): - cmd.Process.Kill() - return false, "connection timeout" } + // The output might have the test command echo before the result + if strings.Contains(result, "SSHKEEPER_OK") { + return true, "" + } + return false, result } -func buildArgs(server *model.Server) []string { +// BuildSSHArgs builds the SSH command arguments for a server profile. +func BuildSSHArgs(server *model.Server) []string { var args []string args = append(args, "-p", fmt.Sprintf("%d", server.Port)) @@ -127,8 +112,6 @@ func buildArgs(server *model.Server) []string { args = append(args, "-J", server.ProxyJump) } - // Disable strict host key checking for first connection - // In production, this should be configurable args = append(args, "-o", "StrictHostKeyChecking=accept-new") target := fmt.Sprintf("%s@%s", server.User, server.Host) diff --git a/internal/ssh/pty.go b/internal/ssh/pty.go index 7e27474..4961405 100644 --- a/internal/ssh/pty.go +++ b/internal/ssh/pty.go @@ -7,19 +7,19 @@ import ( "os/exec" "regexp" "strings" + "sync" + "sync/atomic" "syscall" "time" "github.com/creack/pty" + "golang.org/x/sys/unix" "golang.org/x/term" ) var passwordPromptRe = regexp.MustCompile(`(?i)(password|passphrase).*:\s*$`) -// ConnectWithPassword runs SSH through a PTY, detects the password prompt, -// sends the password, and then bridges the user terminal to the SSH session. func ConnectWithPassword(sshBinary string, args []string, password string) error { - // Start SSH with PTY cmd := exec.Command(sshBinary, args...) cmd.Env = os.Environ() cmd.SysProcAttr = &syscall.SysProcAttr{ @@ -33,18 +33,14 @@ func ConnectWithPassword(sshBinary string, args []string, password string) error } defer ptmx.Close() - // Save terminal state and set to raw oldState, err := term.MakeRaw(int(os.Stdin.Fd())) if err != nil { return fmt.Errorf("set raw terminal: %w", err) } - defer term.Restore(int(os.Stdin.Fd()), oldState) - // Channel to signal when password has been sent - passwordSent := make(chan bool, 1) + passwordSent := new(atomic.Bool) done := make(chan error, 1) - // Read from PTY, detect password prompt go func() { buf := make([]byte, 4096) var accumulated strings.Builder @@ -54,22 +50,17 @@ func ConnectWithPassword(sshBinary string, args []string, password string) error if n > 0 { data := buf[:n] accumulated.Write(data) - - // Write to stdout os.Stdout.Write(data) - // Check for password prompt - if !<-passwordSent { + if !passwordSent.Load() { text := accumulated.String() if passwordPromptRe.MatchString(text) { - passwordSent <- true + passwordSent.Store(true) time.Sleep(100 * time.Millisecond) ptmx.Write([]byte(password + "\r")) - continue } } - // Reset accumulated buffer periodically to avoid unbounded growth if accumulated.Len() > 8192 { s := accumulated.String() accumulated.Reset() @@ -87,14 +78,142 @@ func ConnectWithPassword(sshBinary string, args []string, password string) error } }() - // Copy stdin to PTY - go func() { - io.Copy(ptmx, os.Stdin) - }() + stopInput, waitInput, restoreInput, err := forwardInputToPTY(ptmx) + if err != nil { + term.Restore(int(os.Stdin.Fd()), oldState) + return err + } - // Wait for command completion err = cmd.Wait() - passwordSent <- false // signal to stop + close(stopInput) + waitInput() + if restoreErr := restoreInput(); err == nil && restoreErr != nil { + err = restoreErr + } + if restoreErr := term.Restore(int(os.Stdin.Fd()), oldState); err == nil && restoreErr != nil { + err = restoreErr + } return err } + +func forwardInputToPTY(ptmx *os.File) (chan struct{}, func(), func() error, error) { + stdinFD := int(os.Stdin.Fd()) + flags, err := unix.FcntlInt(uintptr(stdinFD), unix.F_GETFL, 0) + if err != nil { + return nil, nil, nil, fmt.Errorf("get stdin flags: %w", err) + } + if err := unix.SetNonblock(stdinFD, true); err != nil { + return nil, nil, nil, fmt.Errorf("set stdin nonblocking: %w", err) + } + + stop := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + buf := make([]byte, 4096) + for { + select { + case <-stop: + return + default: + } + + n, readErr := unix.Read(stdinFD, buf) + if n > 0 { + if _, writeErr := ptmx.Write(buf[:n]); writeErr != nil { + return + } + } + if readErr == nil { + continue + } + if readErr == unix.EAGAIN || readErr == unix.EWOULDBLOCK { + select { + case <-stop: + return + case <-time.After(10 * time.Millisecond): + continue + } + } + return + } + }() + + restore := func() error { + _, err := unix.FcntlInt(uintptr(stdinFD), unix.F_SETFL, flags) + if err != nil { + return fmt.Errorf("restore stdin flags: %w", err) + } + return nil + } + + return stop, wg.Wait, restore, nil +} + +// connectWithPasswordAndRead runs SSH through a PTY, sends the password, +// collects all output, and returns it. Used for non-interactive testing. +// Returns (true, output) on success, (false, error) on failure. +func connectWithPasswordAndRead(sshBinary string, args []string, password string, timeoutSec int) (bool, string) { + cmd := exec.Command(sshBinary, args...) + cmd.Env = os.Environ() + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setsid: true, + Setctty: true, + } + + ptmx, err := pty.Start(cmd) + if err != nil { + return false, fmt.Sprintf("start ssh with pty: %v", err) + } + defer ptmx.Close() + + passwordSent := false + done := make(chan string, 1) + + // Read from PTY, detect password prompt, collect output + go func() { + buf := make([]byte, 4096) + var accumulated strings.Builder + + for { + n, err := ptmx.Read(buf) + if n > 0 { + data := buf[:n] + accumulated.Write(data) + + // Check for password prompt + if !passwordSent { + text := accumulated.String() + if passwordPromptRe.MatchString(text) { + passwordSent = true + time.Sleep(100 * time.Millisecond) + ptmx.Write([]byte(password + "\r")) + continue + } + } + + // Reset accumulated buffer periodically + if accumulated.Len() > 16384 { + s := accumulated.String() + accumulated.Reset() + accumulated.WriteString(s[len(s)-4096:]) + } + } + if err != nil { + done <- accumulated.String() + return + } + } + }() + + // Wait for command completion or timeout + select { + case output := <-done: + return true, output + case <-time.After(time.Duration(timeoutSec) * time.Second): + cmd.Process.Kill() + return false, "connection timeout" + } +} diff --git a/internal/ssh/pty_test.go b/internal/ssh/pty_test.go new file mode 100644 index 0000000..9d0271f --- /dev/null +++ b/internal/ssh/pty_test.go @@ -0,0 +1,63 @@ +package ssh + +import ( + "os" + "path/filepath" + "testing" + "time" + + "github.com/creack/pty" +) + +func TestConnectWithPasswordDoesNotConsumeInputAfterReturn(t *testing.T) { + script := filepath.Join(t.TempDir(), "fake-ssh") + if err := os.WriteFile(script, []byte("#!/bin/sh\nprintf 'password: '\nsleep 0.1\nexit 0\n"), 0o700); err != nil { + t.Fatalf("write fake ssh: %v", err) + } + + master, slave, err := pty.Open() + if err != nil { + t.Fatalf("open stdin pty: %v", err) + } + defer master.Close() + defer slave.Close() + + oldStdin := os.Stdin + oldStdout := os.Stdout + stdoutSink, err := os.CreateTemp(t.TempDir(), "stdout") + if err != nil { + t.Fatalf("create stdout sink: %v", err) + } + defer stdoutSink.Close() + os.Stdin = slave + os.Stdout = stdoutSink + defer func() { + os.Stdin = oldStdin + os.Stdout = oldStdout + }() + + if err := ConnectWithPassword(script, nil, "secret"); err != nil { + t.Fatalf("connect with password: %v", err) + } + + if _, err := master.Write([]byte("\n")); err != nil { + t.Fatalf("write next enter: %v", err) + } + time.Sleep(100 * time.Millisecond) + + readDone := make(chan []byte, 1) + go func() { + buf := make([]byte, 1) + n, _ := slave.Read(buf) + readDone <- buf[:n] + }() + + select { + case got := <-readDone: + if len(got) != 1 || got[0] != '\n' { + t.Fatalf("expected to read newline, got %q", got) + } + case <-time.After(300 * time.Millisecond): + t.Fatal("expected next Enter to remain readable after ConnectWithPassword returned") + } +} diff --git a/internal/tui/app.go b/internal/tui/app.go index 14a57f4..ff5b5c7 100644 --- a/internal/tui/app.go +++ b/internal/tui/app.go @@ -26,7 +26,10 @@ var ( Background(lipgloss.Color("4")). Bold(true) - normalStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("15")) + normalStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("15")) + selectedRowStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("15")).Background(lipgloss.Color("4")) + listHeaderStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("14")).Bold(true) + sectionStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("12")).Bold(true).MarginTop(1) testOKStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("10")).Bold(true) testFailStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("9")).Bold(true) @@ -67,9 +70,13 @@ type serverItem struct { server *model.Server } -func (i serverItem) Title() string { return i.server.Alias } -func (i serverItem) Description() string { return fmt.Sprintf("%s@%s:%d %s", i.server.User, i.server.Host, i.server.Port, i.server.AuthMethod) } -func (i serverItem) FilterValue() string { return i.server.Alias + " " + i.server.DisplayName + " " + i.server.Host + " " + i.server.User } +func (i serverItem) Title() string { return i.server.Alias } +func (i serverItem) Description() string { + return fmt.Sprintf("%s@%s:%d %s", i.server.User, i.server.Host, i.server.Port, i.server.AuthMethod) +} +func (i serverItem) FilterValue() string { + return i.server.Alias + " " + i.server.DisplayName + " " + i.server.Host + " " + i.server.User +} // --- External callbacks --- @@ -80,10 +87,12 @@ var ( TestConnection func(server *model.Server) (bool, string) // TestConnectionWithPassword tests with explicit password (for form test before save) TestConnectionWithPassword func(server *model.Server, password string) (bool, string) - SaveServer func(server *model.Server, password string) error - GetGroups func() ([]string, error) // Returns existing group names - RenameGroup func(oldName, newName string) error // Rename group for all servers - DeleteGroup func(name string) error // Remove group from all servers + SaveServer func(server *model.Server, password string, oldAlias string) error + UpdateTestResult func(alias string, status model.TestStatus, testErr string) error + HasSecret func(alias string, secretType string) bool + GetGroups func() ([]string, error) + RenameGroup func(oldName, newName string) error + DeleteGroup func(name string) error ) // --- Screen type --- @@ -99,8 +108,8 @@ const ( // --- Result type — returned from TUI to caller --- type TUIResult struct { - Server *model.Server - Action string // "connect" + Server *model.Server + Action string // "connect" } // --- Main TUI model --- @@ -196,8 +205,22 @@ func (m *tuiModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } m.form.testResultTime = time.Now() m.form.err = nil + return m, nil + } + // Update test status in DB and reload list + if item, ok := m.list.SelectedItem().(serverItem); ok && UpdateTestResult != nil { + status := model.TestUnknown + if msg.ok { + status = model.TestOK + } else if msg.err != "" { + status = model.TestFailed + } + UpdateTestResult(item.server.Alias, status, msg.err) + } + return m, func() tea.Msg { + servers, err := ListServers() + return serversLoadedMsg{servers: servers, err: err} } - return m, nil case saveDoneMsg: if m.form != nil { @@ -323,11 +346,23 @@ func (m *tuiModel) updateSearch(msg tea.KeyMsg) (tea.Model, tea.Cmd) { func (m *tuiModel) updateForm(msg tea.KeyMsg) (tea.Model, tea.Cmd) { if msg.Type == tea.KeyEsc { + if m.form != nil && (m.form.showGroupList || m.form.showAuthList) { + updated, cmd := m.form.Update(msg) + if fm, ok := updated.(*formModel); ok { + m.form = fm + } + return m, cmd + } + m.screen = screenList m.form = nil m.err = nil m.success = "" - return m, nil + // Reload server list after form close + return m, func() tea.Msg { + servers, err := ListServers() + return serversLoadedMsg{servers: servers, err: err} + } } updated, cmd := m.form.Update(msg) @@ -342,9 +377,7 @@ func (m *tuiModel) View() string { switch m.screen { case screenList: - b.WriteString(m.list.View()) - b.WriteString("\n") - b.WriteString(helpStyle.Render("Enter connect | Ctrl+A add | Ctrl+E edit | Ctrl+D del | Ctrl+T test | Ctrl+F search | Ctrl+Q quit")) + b.WriteString(m.viewServerList()) case screenSearch: b.WriteString("Search: " + m.searchInput.View() + "\n") @@ -366,13 +399,148 @@ func (m *tuiModel) View() string { return b.String() } +func (m *tuiModel) viewServerList() string { + var b strings.Builder + selectedAlias := "" + if item, ok := m.list.SelectedItem().(serverItem); ok && item.server != nil { + selectedAlias = item.server.Alias + } + + b.WriteString(titleStyle.Render(fmt.Sprintf("sshkeeper %d servers", len(m.servers)))) + b.WriteString("\n") + b.WriteString(helpStyle.Render(fmt.Sprintf("Vault unlocked | %s", testSummary(m.servers)))) + b.WriteString("\n\n") + b.WriteString(listHeaderStyle.Render(fmt.Sprintf(" %-20s %-20s %-34s %-12s %-10s %s", "NAME", "ALIAS", "TARGET", "AUTH", "GROUP", "STATUS"))) + b.WriteString("\n") + + if len(m.servers) == 0 { + b.WriteString(helpStyle.Render(" No servers yet. Press Ctrl+A to add one.")) + b.WriteString("\n") + } else { + for _, server := range m.servers { + marker := " " + rowStyle := normalStyle + if server.Alias == selectedAlias { + marker = ">" + rowStyle = selectedRowStyle + } + name := server.DisplayName + if name == "" { + name = server.Alias + } + target := fmt.Sprintf("%s@%s:%d", server.User, server.Host, server.Port) + group := server.GroupName + if group == "" { + group = "-" + } + row := fmt.Sprintf("%s %-20s %-20s %-34s %-12s %-10s %s", + marker, + truncate(name, 20), + truncate(server.Alias, 20), + truncate(target, 34), + authLabel(server.AuthMethod), + truncate(group, 10), + testStatusLabel(server), + ) + b.WriteString(rowStyle.Render(row)) + b.WriteString("\n") + } + } + + b.WriteString("\n") + if selectedAlias != "" { + if selected := m.selectedServer(); selected != nil { + b.WriteString(m.viewSelectedServer(selected)) + b.WriteString("\n") + } + } + b.WriteString(helpStyle.Render("Enter connect | Ctrl+A add | Ctrl+E edit | Ctrl+D del | Ctrl+T test | Ctrl+F search | Ctrl+Q quit")) + return b.String() +} + +func (m *tuiModel) selectedServer() *model.Server { + if item, ok := m.list.SelectedItem().(serverItem); ok && item.server != nil { + return item.server + } + return nil +} + +func (m *tuiModel) viewSelectedServer(server *model.Server) string { + displayName := server.DisplayName + if displayName == "" { + displayName = "-" + } + group := server.GroupName + if group == "" { + group = "-" + } + + var b strings.Builder + b.WriteString(sectionStyle.Render("Selected")) + b.WriteString("\n") + b.WriteString(fmt.Sprintf(" Alias: %s\n", server.Alias)) + b.WriteString(fmt.Sprintf(" Display Name: %s\n", displayName)) + b.WriteString(fmt.Sprintf(" Host: %s\n", server.Host)) + b.WriteString(fmt.Sprintf(" Port: %d\n", server.Port)) + b.WriteString(fmt.Sprintf(" User: %s\n", server.User)) + b.WriteString(fmt.Sprintf(" Auth: %s\n", authLabel(server.AuthMethod))) + b.WriteString(fmt.Sprintf(" Group: %s\n", group)) + b.WriteString(fmt.Sprintf(" Status: %s\n", testStatusLabel(server))) + return b.String() +} + +func testSummary(servers []*model.Server) string { + okCount := 0 + failedCount := 0 + for _, server := range servers { + switch server.LastTestStatus { + case model.TestOK: + okCount++ + case model.TestFailed: + failedCount++ + } + } + return fmt.Sprintf("%d OK | %d FAIL", okCount, failedCount) +} + +func authLabel(auth model.AuthMethod) string { + switch auth { + case model.AuthPassword: + return "password" + case model.AuthKey: + return "key" + case model.AuthKeyPassphrase: + return "key+phrase" + case model.AuthAgent: + return "agent" + default: + return string(auth) + } +} + +func testStatusLabel(server *model.Server) string { + switch server.LastTestStatus { + case model.TestOK: + return "OK" + case model.TestFailed: + if server.LastTestError != "" { + return "FAIL" + } + return "FAIL" + default: + return "?" + } +} + // --- Form model --- type formModel struct { edit bool server *model.Server inputs []textinput.Model + labels []string password textinput.Model + passwordLabel string focusIdx int testResult string testOK bool @@ -385,10 +553,22 @@ type formModel struct { spinner spinner.Model width int height int - groups string // cached groups list (comma-separated for display) - showGroups bool // show group dropdown + groups []string // existing group names + groupList list.Model // dropdown list for groups + showGroupList bool // whether group dropdown is visible + authList list.Model + showAuthList bool } +// groupItem implements list.Item for the group dropdown +type groupItem struct { + name string +} + +func (i groupItem) Title() string { return i.name } +func (i groupItem) Description() string { return "" } +func (i groupItem) FilterValue() string { return i.name } + func newFormModel(w, h int) *formModel { inputs := make([]textinput.Model, 10) labels := []string{ @@ -405,12 +585,12 @@ func newFormModel(w, h int) *formModel { } for i, label := range labels { inputs[i] = textinput.New() - inputs[i].Placeholder = label + inputs[i].Placeholder = placeholderForLabel(label) inputs[i].CharLimit = 128 } pw := textinput.New() - pw.Placeholder = "Password / Passphrase (stored in vault)" + pw.Placeholder = "optional" pw.CharLimit = 256 pw.EchoMode = textinput.EchoPassword @@ -421,24 +601,75 @@ func newFormModel(w, h int) *formModel { inputs[0].Focus() fm := &formModel{ - inputs: inputs, - password: pw, - focusIdx: 0, - spinner: s, - width: w, - height: h, + inputs: inputs, + labels: labels, + password: pw, + passwordLabel: "Password / Passphrase", + focusIdx: 0, + spinner: s, + width: w, + height: h, } + fm.authList = newStringList([]string{ + string(model.AuthPassword), + string(model.AuthKey), + string(model.AuthKeyPassphrase), + string(model.AuthAgent), + }, "Select auth method", 34, 16) // Load existing groups if GetGroups != nil { if groups, err := GetGroups(); err == nil && len(groups) > 0 { - fm.groups = strings.Join(groups, ", ") + fm.groups = groups + fm.groupList = newStringList(groups, "Select group", 30, 8) } } + fm.updateFocus() return fm } +func placeholderForLabel(label string) string { + switch label { + case "Alias": + return "mail.kp" + case "Display Name": + return "Production mail" + case "Host": + return "mail.example.org" + case "Port": + return "22" + case "User": + return "root" + case "Auth Method (password/key/key_passphrase/agent)": + return "key" + case "Identity File": + return "~/.ssh/id_ed25519" + case "ProxyJump": + return "optional" + case "Group (type new or pick from list)": + return "KP" + case "Notes": + return "optional" + default: + return label + } +} + +func newStringList(values []string, title string, width, height int) list.Model { + items := make([]list.Item, len(values)) + for i, value := range values { + items[i] = groupItem{name: value} + } + l := list.New(items, list.NewDefaultDelegate(), width, height) + l.SetShowStatusBar(false) + l.SetShowHelp(false) + l.SetShowPagination(false) + l.Title = title + l.Styles.Title = titleStyle + return l +} + func newEditFormModel(s *model.Server, w, h int) *formModel { fm := newFormModel(w, h) fm.edit = true @@ -453,6 +684,21 @@ func newEditFormModel(s *model.Server, w, h int) *formModel { fm.inputs[7].SetValue(s.ProxyJump) fm.inputs[8].SetValue(s.GroupName) fm.inputs[9].SetValue(s.Notes) + if HasSecret != nil { + switch s.AuthMethod { + case model.AuthPassword: + if HasSecret(s.Alias, "ssh_password") { + fm.passwordLabel = "Password (secret saved; leave blank to keep)" + fm.password.Placeholder = "" + } + case model.AuthKeyPassphrase: + if HasSecret(s.Alias, "key_passphrase") { + fm.passwordLabel = "Key passphrase (secret saved; leave blank to keep)" + fm.password.Placeholder = "" + } + } + } + fm.updateFocus() return fm } @@ -498,6 +744,48 @@ func (fm *formModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return fm, cmd } + // Handle group dropdown + if fm.showGroupList { + switch msg := msg.(type) { + case tea.KeyMsg: + switch msg.Type { + case tea.KeyEsc: + fm.showGroupList = false + return fm, nil + case tea.KeyEnter: + if item, ok := fm.groupList.SelectedItem().(groupItem); ok { + fm.inputs[8].SetValue(item.name) + } + fm.showGroupList = false + return fm, nil + } + } + // Pass other keys to the list + var cmd tea.Cmd + fm.groupList, cmd = fm.groupList.Update(msg) + return fm, cmd + } + + if fm.showAuthList { + switch msg := msg.(type) { + case tea.KeyMsg: + switch msg.Type { + case tea.KeyEsc: + fm.showAuthList = false + return fm, nil + case tea.KeyEnter: + if item, ok := fm.authList.SelectedItem().(groupItem); ok { + fm.inputs[5].SetValue(item.name) + } + fm.showAuthList = false + return fm, nil + } + } + var cmd tea.Cmd + fm.authList, cmd = fm.authList.Update(msg) + return fm, cmd + } + switch msg := msg.(type) { case tea.KeyMsg: switch msg.Type { @@ -519,6 +807,17 @@ func (fm *formModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { fm.updateFocus() return fm, nil + case tea.KeyRunes: + if len(msg.Runes) == 1 && msg.Runes[0] == '/' && !msg.Alt && fm.focusIdx == 5 { + fm.showAuthList = true + return fm, nil + } + // '/' on Group field opens group dropdown + if len(msg.Runes) == 1 && msg.Runes[0] == '/' && !msg.Alt && fm.focusIdx == 8 && len(fm.groups) > 0 { + fm.showGroupList = true + return fm, nil + } + case tea.KeyEnter: switch { case fm.focusIdx == len(fm.inputs)+1: @@ -576,20 +875,36 @@ func (fm *formModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { func (fm *formModel) updateFocus() { for i := range fm.inputs { fm.inputs[i].Blur() - fm.inputs[i].Prompt = blurredStyle.Render(fm.inputs[i].Placeholder + ": ") + fm.inputs[i].Prompt = blurredStyle.Render(fm.labelAt(i) + ": ") } fm.password.Blur() - fm.password.Prompt = blurredStyle.Render(fm.password.Placeholder + ": ") + fm.password.Prompt = blurredStyle.Render(fm.passwordLabel + ": ") if fm.focusIdx < len(fm.inputs) { fm.inputs[fm.focusIdx].Focus() - fm.inputs[fm.focusIdx].Prompt = focusedStyle.Render(fm.inputs[fm.focusIdx].Placeholder + "> ") + fm.inputs[fm.focusIdx].Prompt = focusedStyle.Render(fm.labelAt(fm.focusIdx) + "> ") } else if fm.focusIdx == len(fm.inputs) { fm.password.Focus() - fm.password.Prompt = focusedStyle.Render(fm.password.Placeholder + "> ") + fm.password.Prompt = focusedStyle.Render(fm.passwordLabel + "> ") } } +func (fm *formModel) labelAt(index int) string { + if index >= 0 && index < len(fm.labels) { + if index == 5 { + return "Auth Method (/ pick)" + } + if index == 8 { + if len(fm.groups) > 0 { + return "Group (/ pick)" + } + return "Group" + } + return fm.labels[index] + } + return "" +} + func (fm *formModel) runTest() tea.Cmd { fm.testing = true fm.testResult = "" @@ -635,7 +950,11 @@ func (fm *formModel) runSave() tea.Cmd { if s.Host == "" { return saveDoneMsg{err: fmt.Errorf("host is required")} } - err := SaveServer(s, pw) + oldAlias := "" + if fm.edit && fm.server != nil { + oldAlias = fm.server.Alias + } + err := SaveServer(s, pw, oldAlias) return saveDoneMsg{err: err} }, ) @@ -672,13 +991,72 @@ func (fm *formModel) View() string { b.WriteString(titleStyle.Render(title)) b.WriteString("\n\n") - for i := range fm.inputs { + // Calculate visible range based on terminal height + // Reserve lines for: title (2) + password (1) + buttons (3) + help (1) + padding (2) = ~9 + reserved := 9 + available := fm.height - reserved + if available < 4 { + available = 4 + } + + numInputs := len(fm.inputs) + startIdx := 0 + endIdx := numInputs + + // Scroll: keep focused field visible + if numInputs > available { + focusInput := fm.focusIdx + if focusInput >= numInputs { + focusInput = numInputs - 1 + } + // Try to show `available` fields centered on focus + startIdx = focusInput - available/2 + if startIdx < 0 { + startIdx = 0 + } + endIdx = startIdx + available + if endIdx > numInputs { + endIdx = numInputs + startIdx = endIdx - available + if startIdx < 0 { + startIdx = 0 + } + } + } + + // Show scroll indicator if needed + if startIdx > 0 { + b.WriteString(helpStyle.Render(" ↑ more fields above\n")) + } + + for i := startIdx; i < endIdx; i++ { + if section := formSectionTitle(i); section != "" { + b.WriteString(sectionStyle.Render(section)) + b.WriteString("\n") + } + if i == 5 { + fm.inputs[i].Placeholder = "password/key/key_passphrase/agent" + } + // Show group hint inline in placeholder for Group field + if i == 8 && len(fm.groups) > 0 && !fm.showGroupList { + fm.inputs[i].Placeholder = truncate(strings.Join(fm.groups, ", "), 25) + } b.WriteString(fm.inputs[i].View()) b.WriteString("\n") - // Show existing groups hint under Group field - if i == 8 && fm.groups != "" { - b.WriteString(helpStyle.Render(" Groups: " + fm.groups + "\n")) + if i == 5 && fm.showAuthList { + b.WriteString("\n" + renderDropdown(fm.authList) + "\n") + b.WriteString(helpStyle.Render("Enter select | Esc cancel")) + return b.String() } + if i == 8 && fm.showGroupList { + b.WriteString("\n" + renderDropdown(fm.groupList) + "\n") + b.WriteString(helpStyle.Render("Enter select | Esc cancel")) + return b.String() + } + } + + if endIdx < numInputs { + b.WriteString(helpStyle.Render(fmt.Sprintf(" ↓ more fields below (%d-%d of %d)\n", startIdx+1, endIdx, numInputs))) } b.WriteString(fm.password.View()) @@ -723,8 +1101,53 @@ func (fm *formModel) View() string { saveBtn = normalStyle.Render(saveBtn) } - b.WriteString("\n" + testBtn + " " + saveBtn + "\n\n") - b.WriteString(helpStyle.Render("Tab/↓ next | ↑ prev | Enter select | Esc back")) + b.WriteString("\n" + sectionStyle.Render("Actions") + "\n") + b.WriteString(testBtn + " " + saveBtn + "\n\n") + b.WriteString(helpStyle.Render("Tab/↓ next | ↑ prev | / pick list | Enter select | Esc back")) return b.String() } + +func renderDropdown(l list.Model) string { + var b strings.Builder + b.WriteString(sectionStyle.Render(l.Title)) + b.WriteString("\n") + for i, item := range l.Items() { + group, ok := item.(groupItem) + if !ok { + continue + } + prefix := " " + style := normalStyle + if i == l.Index() { + prefix = "> " + style = selectedRowStyle + } + b.WriteString(style.Render(prefix + group.name)) + b.WriteString("\n") + } + return strings.TrimRight(b.String(), "\n") +} + +func formSectionTitle(index int) string { + switch index { + case 0: + return "Identity" + case 2: + return "Connection" + case 5: + return "Authentication" + case 8: + return "Metadata" + default: + return "" + } +} + +// truncate limits a string to maxLen, adding "..." if truncated +func truncate(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen-3] + "..." +} diff --git a/internal/tui/app_test.go b/internal/tui/app_test.go new file mode 100644 index 0000000..8aece7e --- /dev/null +++ b/internal/tui/app_test.go @@ -0,0 +1,322 @@ +package tui + +import ( + "strings" + "testing" + "time" + + "github.com/charmbracelet/bubbles/list" + tea "github.com/charmbracelet/bubbletea" + "github.com/mirivlad/sshkeeper/internal/model" +) + +func TestServerListViewUsesDashboardLayout(t *testing.T) { + now := time.Date(2026, 5, 28, 1, 50, 0, 0, time.UTC) + m := New([]*model.Server{ + { + Alias: "mail.kp", + DisplayName: "Mail", + Host: "mail.example.org", + Port: 222, + User: "mirivlad", + AuthMethod: model.AuthPassword, + GroupName: "KP", + LastTestStatus: model.TestOK, + LastTestAt: &now, + }, + { + Alias: "mirv.top", + Host: "mirv.top", + Port: 22, + User: "root", + AuthMethod: model.AuthKey, + LastTestStatus: model.TestUnknown, + }, + }) + m.width = 100 + m.height = 30 + m.list.SetSize(100, 24) + + view := m.View() + for _, want := range []string{ + "sshkeeper", + "2 servers", + "Vault", + "NAME", + "TARGET", + "AUTH", + "GROUP", + "STATUS", + "Mail", + "mail.kp", + "mirivlad@mail.example.org:222", + "KP", + "OK", + "Selected", + "Host: mail.example.org", + "Alias: mail.kp", + "Display Name: Mail", + "Port: 222", + "Enter connect", + } { + if !strings.Contains(view, want) { + t.Fatalf("expected list view to contain %q\nview:\n%s", want, view) + } + } + if strings.Contains(view, "Profiles managed locally") { + t.Fatalf("expected compact status header instead of README text\nview:\n%s", view) + } +} + +func TestEscClosesGroupListBeforeLeavingForm(t *testing.T) { + oldGetGroups := GetGroups + GetGroups = func() ([]string, error) { + return []string{"prod", "stage"}, nil + } + defer func() { GetGroups = oldGetGroups }() + + m := &tuiModel{ + screen: screenForm, + form: newFormModel(80, 24), + } + m.form.focusIdx = 8 + m.form.updateFocus() + + updated, _ := m.updateForm(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'/'}}) + m = updated.(*tuiModel) + if !m.form.showGroupList { + t.Fatal("expected / on the group field to open the group list") + } + + updated, _ = m.updateForm(tea.KeyMsg{Type: tea.KeyEsc}) + m = updated.(*tuiModel) + + if m.screen != screenForm { + t.Fatalf("expected Esc to keep the user in the form, got screen %v", m.screen) + } + if m.form == nil { + t.Fatal("expected form to remain open") + } + if m.form.showGroupList { + t.Fatal("expected Esc to close only the group list") + } +} + +func TestEscClosesAuthMethodListBeforeLeavingForm(t *testing.T) { + m := &tuiModel{ + screen: screenForm, + form: newFormModel(80, 24), + } + m.form.focusIdx = 5 + m.form.updateFocus() + + updated, _ := m.updateForm(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'/'}}) + m = updated.(*tuiModel) + if !m.form.showAuthList { + t.Fatal("expected / on the auth method field to open the auth method list") + } + + updated, _ = m.updateForm(tea.KeyMsg{Type: tea.KeyEsc}) + m = updated.(*tuiModel) + + if m.screen != screenForm { + t.Fatalf("expected Esc to keep the user in the form, got screen %v", m.screen) + } + if m.form == nil { + t.Fatal("expected form to remain open") + } + if m.form.showAuthList { + t.Fatal("expected Esc to close only the auth method list") + } +} + +func TestAuthMethodListSelectsValue(t *testing.T) { + fm := newFormModel(80, 24) + fm.focusIdx = 5 + fm.updateFocus() + + updated, _ := fm.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'/'}}) + fm = updated.(*formModel) + if !fm.showAuthList { + t.Fatal("expected / on the auth method field to open the auth method list") + } + + fm.authList.Select(2) + updated, _ = fm.Update(tea.KeyMsg{Type: tea.KeyEnter}) + fm = updated.(*formModel) + + if fm.showAuthList { + t.Fatal("expected Enter to close auth method list") + } + if got := fm.inputs[5].Value(); got != string(model.AuthKeyPassphrase) { + t.Fatalf("expected auth method %q, got %q", model.AuthKeyPassphrase, got) + } +} + +func TestAuthMethodListViewShowsAllOptions(t *testing.T) { + fm := newFormModel(80, 12) + fm.focusIdx = 5 + fm.updateFocus() + + updated, _ := fm.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'/'}}) + fm = updated.(*formModel) + + view := fm.View() + authPos := strings.Index(view, "Auth Method") + listPos := strings.Index(view, "Select auth method") + if authPos < 0 || listPos < 0 { + t.Fatalf("expected auth field and auth list title in view\nview:\n%s", view) + } + if listPos < authPos { + t.Fatalf("expected auth method list after auth field\nview:\n%s", view) + } + if between := view[authPos:listPos]; strings.Contains(between, "Identity File") { + t.Fatalf("expected auth method list to render directly under auth field\nview:\n%s", view) + } + if strings.Contains(view, "│") { + t.Fatalf("expected compact auth method dropdown without default list border\nview:\n%s", view) + } + for _, method := range []model.AuthMethod{ + model.AuthPassword, + model.AuthKey, + model.AuthKeyPassphrase, + model.AuthAgent, + } { + if !strings.Contains(view, string(method)) { + t.Fatalf("expected auth method list view to contain %q\nview:\n%s", method, view) + } + } +} + +func TestGroupListViewRendersDirectlyUnderGroupField(t *testing.T) { + oldGetGroups := GetGroups + GetGroups = func() ([]string, error) { + return []string{"KP", "MY"}, nil + } + defer func() { GetGroups = oldGetGroups }() + + fm := newFormModel(80, 24) + fm.focusIdx = 8 + fm.updateFocus() + + updated, _ := fm.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'/'}}) + fm = updated.(*formModel) + + view := fm.View() + groupPos := strings.Index(view, "Group") + listPos := strings.Index(view, "Select group") + if groupPos < 0 || listPos < 0 { + t.Fatalf("expected group field and group list title in view\nview:\n%s", view) + } + if listPos < groupPos { + t.Fatalf("expected group list after group field\nview:\n%s", view) + } + if between := view[groupPos:listPos]; strings.Contains(between, "Password") { + t.Fatalf("expected group dropdown to render before password field\nview:\n%s", view) + } + if strings.Contains(view, "│") { + t.Fatalf("expected compact group dropdown without default list border\nview:\n%s", view) + } +} + +func TestSelectableFieldHintsAreVisible(t *testing.T) { + oldGetGroups := GetGroups + GetGroups = func() ([]string, error) { + return []string{"KP"}, nil + } + defer func() { GetGroups = oldGetGroups }() + + fm := newFormModel(80, 24) + view := fm.View() + for _, want := range []string{"Auth Method (/ pick)", "Group (/ pick)", "/ pick list"} { + if !strings.Contains(view, want) { + t.Fatalf("expected form view to contain selectable-field hint %q\nview:\n%s", want, view) + } + } +} + +func TestEditFormShowsSavedSecretMarker(t *testing.T) { + oldHasSecret := HasSecret + HasSecret = func(alias string, secretType string) bool { + return alias == "prod" && secretType == "ssh_password" + } + defer func() { HasSecret = oldHasSecret }() + + fm := newEditFormModel(&model.Server{ + Alias: "prod", + Host: "example.org", + Port: 22, + User: "root", + AuthMethod: model.AuthPassword, + }, 80, 24) + + view := fm.View() + if !strings.Contains(view, "secret saved") { + t.Fatalf("expected edit form to show saved secret marker\nview:\n%s", view) + } + if !strings.Contains(view, "leave blank to keep") { + t.Fatalf("expected edit form to explain blank password keeps saved secret\nview:\n%s", view) + } + if strings.Count(view, "secret saved") != 1 { + t.Fatalf("expected saved secret marker to appear once\nview:\n%s", view) + } +} + +func TestFormViewUsesSectionsAndStableLabels(t *testing.T) { + fm := newFormModel(100, 30) + view := fm.View() + + for _, want := range []string{ + "Identity", + "Connection", + "Authentication", + "Metadata", + "Actions", + "Alias", + "Display Name", + "Auth Method", + "Password / Passphrase", + } { + if !strings.Contains(view, want) { + t.Fatalf("expected form view to contain %q\nview:\n%s", want, view) + } + } +} + +func TestFormTestResultDoesNotUpdateSelectedListServer(t *testing.T) { + oldUpdateTestResult := UpdateTestResult + oldListServers := ListServers + defer func() { + UpdateTestResult = oldUpdateTestResult + ListServers = oldListServers + }() + + updateCalled := false + UpdateTestResult = func(alias string, status model.TestStatus, testErr string) error { + updateCalled = true + return nil + } + ListServers = func() ([]*model.Server, error) { + t.Fatal("form test result should not reload server list") + return nil, nil + } + + selected := &model.Server{Alias: "selected", Host: "example.org", Port: 22, User: "root"} + m := New([]*model.Server{selected}) + m.screen = screenForm + m.form = newFormModel(80, 24) + m.list = list.New([]list.Item{serverItem{server: selected}}, list.NewDefaultDelegate(), 80, 20) + + updated, cmd := m.Update(testDoneMsg{ok: true}) + m = updated.(*tuiModel) + + if cmd != nil { + t.Fatal("form test result should not return a reload command") + } + if updateCalled { + t.Fatal("form test result should not update the selected list server") + } + if m.form == nil || m.form.testResult != "Connection OK." { + t.Fatal("expected form to keep its test result") + } +} diff --git a/internal/vault/vault.go b/internal/vault/vault.go index f756cb7..90ff870 100644 --- a/internal/vault/vault.go +++ b/internal/vault/vault.go @@ -8,6 +8,8 @@ import ( "fmt" "io" "os" + "sort" + "strings" "sync" "time" @@ -16,18 +18,20 @@ import ( ) const ( - currentVersion = 1 - saltLen = 32 - nonceLen = 24 - keyLen = 32 + currentVersion = 1 + saltLen = 32 + nonceLen = 24 + keyLen = 32 + verifierID = "__sshkeeper_vault_verifier__" + verifierPlaintext = "sshkeeper-vault-verifier-v1" ) type KDFMeta struct { - Name string `json:"name"` - MemoryKiB int `json:"memory_kib"` - Iterations int `json:"iterations"` - Parallelism int `json:"parallelism"` - Salt string `json:"salt"` + Name string `json:"name"` + MemoryKiB int `json:"memory_kib"` + Iterations int `json:"iterations"` + Parallelism int `json:"parallelism"` + Salt string `json:"salt"` } type Record struct { @@ -38,23 +42,35 @@ type Record struct { } type VaultFile struct { - Version int `json:"version"` - KDF KDFMeta `json:"kdf"` - Records []Record `json:"records"` + Version int `json:"version"` + KDF KDFMeta `json:"kdf"` + Verifier *Record `json:"verifier,omitempty"` + Records []Record `json:"records"` } type Vault struct { - mu sync.Mutex - path string + mu sync.Mutex + path string masterKey []byte - records map[string][]byte // id -> plaintext - modified bool + records map[string]secretRecord + modified bool +} + +type secretRecord struct { + secretType string + plaintext []byte +} + +type SecretMeta struct { + ID string + Alias string + Type string } func New(path string) *Vault { return &Vault{ path: path, - records: make(map[string][]byte), + records: make(map[string]secretRecord), } } @@ -76,21 +92,26 @@ func Create(path string, masterPassword string) error { kdf := KDFMeta{ Name: "argon2id", - MemoryKiB: 4096, - Iterations: 2, + MemoryKiB: 65536, + Iterations: 3, Parallelism: 1, Salt: base64.StdEncoding.EncodeToString(salt), } fmt.Print("Deriving key...") - key := argon2.IDKey([]byte(masterPassword), salt, uint32(kdf.Iterations), uint32(kdf.MemoryKiB)*1024, uint8(kdf.Parallelism), keyLen) + key := argon2.IDKey([]byte(masterPassword), salt, uint32(kdf.Iterations), uint32(kdf.MemoryKiB), uint8(kdf.Parallelism), keyLen) + + verifier, err := newVerifierRecord(key) + if err != nil { + return fmt.Errorf("create verifier: %w", err) + } - // Verify key is valid by doing a test encrypt/decrypt vf := VaultFile{ - Version: currentVersion, - KDF: kdf, - Records: []Record{}, + Version: currentVersion, + KDF: kdf, + Verifier: &verifier, + Records: []Record{}, } data, err := json.Marshal(vf) @@ -143,24 +164,32 @@ func (v *Vault) Unlock(masterPassword string) error { return fmt.Errorf("decode salt: %w", err) } - key := argon2.IDKey([]byte(masterPassword), salt, uint32(vf.KDF.Iterations), uint32(vf.KDF.MemoryKiB)*1024, uint8(vf.KDF.Parallelism), keyLen) + key := argon2.IDKey([]byte(masterPassword), salt, uint32(vf.KDF.Iterations), uint32(vf.KDF.MemoryKiB), uint8(vf.KDF.Parallelism), keyLen) - // Try to decrypt first record to verify password - if len(vf.Records) > 0 { + if vf.Verifier != nil { + if err := verifyRecord(key, *vf.Verifier); err != nil { + return fmt.Errorf("invalid master password") + } + } else if len(vf.Records) > 0 { if _, err := decryptRecord(key, vf.Records[0]); err != nil { return fmt.Errorf("invalid master password") } + } else { + return fmt.Errorf("vault cannot verify master password; recreate empty vault") } v.masterKey = key - v.records = make(map[string][]byte) + v.records = make(map[string]secretRecord) for _, rec := range vf.Records { plaintext, err := decryptRecord(key, rec) if err != nil { return fmt.Errorf("decrypt record %s: %w", rec.ID, err) } - v.records[rec.ID] = plaintext + v.records[rec.ID] = secretRecord{ + secretType: inferSecretType(rec.ID, rec.Type), + plaintext: plaintext, + } } fmt.Println(" done.") @@ -178,7 +207,7 @@ func (v *Vault) Lock() { } } v.masterKey = nil - v.records = make(map[string][]byte) + v.records = make(map[string]secretRecord) } // IsUnlocked returns whether the vault is currently unlocked @@ -197,7 +226,9 @@ func (v *Vault) Put(id string, secretType string, plaintext []byte) error { return fmt.Errorf("vault is locked") } - v.records[id] = plaintext + data := make([]byte, len(plaintext)) + copy(data, plaintext) + v.records[id] = secretRecord{secretType: secretType, plaintext: data} v.modified = true return nil } @@ -211,16 +242,55 @@ func (v *Vault) Get(id string) ([]byte, error) { return nil, fmt.Errorf("vault is locked") } - data, ok := v.records[id] + record, ok := v.records[id] if !ok { return nil, fmt.Errorf("secret not found: %s", id) } - result := make([]byte, len(data)) - copy(result, data) + result := make([]byte, len(record.plaintext)) + copy(result, record.plaintext) return result, nil } +func (v *Vault) HasSecret(id string) bool { + v.mu.Lock() + defer v.mu.Unlock() + _, ok := v.records[id] + return ok +} + +func (v *Vault) ListSecrets() ([]SecretMeta, error) { + v.mu.Lock() + defer v.mu.Unlock() + + if v.masterKey == nil { + return nil, fmt.Errorf("vault is locked") + } + + metas := make([]SecretMeta, 0, len(v.records)) + for id, record := range v.records { + alias, secretType, ok := parseServerSecretID(id) + if !ok { + continue + } + if record.secretType != "" { + secretType = record.secretType + } + metas = append(metas, SecretMeta{ + ID: id, + Alias: alias, + Type: secretType, + }) + } + sort.Slice(metas, func(i, j int) bool { + if metas[i].Alias == metas[j].Alias { + return metas[i].Type < metas[j].Type + } + return metas[i].Alias < metas[j].Alias + }) + return metas, nil +} + // Delete removes a secret func (v *Vault) Delete(id string) { v.mu.Lock() @@ -245,8 +315,8 @@ func (v *Vault) Save() error { kdf := KDFMeta{ Name: "argon2id", - MemoryKiB: 4096, - Iterations: 2, + MemoryKiB: 65536, + Iterations: 3, Parallelism: 1, Salt: base64.StdEncoding.EncodeToString(salt), } @@ -254,18 +324,24 @@ func (v *Vault) Save() error { fmt.Print("Deriving key...") var records []Record - for id, plaintext := range v.records { - rec, err := encryptRecord(v.masterKey, id, plaintext) + for id, record := range v.records { + rec, err := encryptRecordWithType(v.masterKey, id, record.secretType, record.plaintext) if err != nil { return fmt.Errorf("encrypt record %s: %w", id, err) } records = append(records, rec) } + verifier, err := newVerifierRecord(v.masterKey) + if err != nil { + return fmt.Errorf("create verifier: %w", err) + } + vf := VaultFile{ - Version: currentVersion, - KDF: kdf, - Records: records, + Version: currentVersion, + KDF: kdf, + Verifier: &verifier, + Records: records, } data, err := json.Marshal(vf) @@ -309,12 +385,12 @@ func (v *Vault) ChangePassword(newPassword string) error { return fmt.Errorf("generate salt: %w", err) } - newKey := argon2.IDKey([]byte(newPassword), salt, 3, 8192*1024, 1, keyLen) + newKey := argon2.IDKey([]byte(newPassword), salt, 3, 65536, 1, keyLen) kdf := KDFMeta{ Name: "argon2id", - MemoryKiB: 4096, - Iterations: 2, + MemoryKiB: 65536, + Iterations: 3, Parallelism: 1, Salt: base64.StdEncoding.EncodeToString(salt), } @@ -322,18 +398,24 @@ func (v *Vault) ChangePassword(newPassword string) error { fmt.Print("Deriving key...") var records []Record - for id, plaintext := range v.records { - rec, err := encryptRecord(newKey, id, plaintext) + for id, record := range v.records { + rec, err := encryptRecordWithType(newKey, id, record.secretType, record.plaintext) if err != nil { return fmt.Errorf("encrypt record: %w", err) } records = append(records, rec) } + verifier, err := newVerifierRecord(newKey) + if err != nil { + return fmt.Errorf("create verifier: %w", err) + } + vf := VaultFile{ - Version: currentVersion, - KDF: kdf, - Records: records, + Version: currentVersion, + KDF: kdf, + Verifier: &verifier, + Records: records, } data, err := json.Marshal(vf) @@ -382,6 +464,10 @@ func (v *Vault) getSalt() string { } func encryptRecord(key []byte, id string, plaintext []byte) (Record, error) { + return encryptRecordWithType(key, id, "", plaintext) +} + +func encryptRecordWithType(key []byte, id string, secretType string, plaintext []byte) (Record, error) { aead, err := chacha20poly1305.NewX(key) if err != nil { return Record{}, err @@ -396,11 +482,31 @@ func encryptRecord(key []byte, id string, plaintext []byte) (Record, error) { return Record{ ID: id, + Type: secretType, Nonce: base64.StdEncoding.EncodeToString(nonce), Ciphertext: base64.StdEncoding.EncodeToString(ciphertext), }, nil } +func inferSecretType(id string, recordType string) string { + if recordType != "" { + return recordType + } + _, secretType, ok := parseServerSecretID(id) + if !ok { + return "" + } + return secretType +} + +func parseServerSecretID(id string) (string, string, bool) { + parts := strings.Split(id, ":") + if len(parts) != 3 || parts[0] != "server" || parts[1] == "" || parts[2] == "" { + return "", "", false + } + return parts[1], parts[2], true +} + func decryptRecord(key []byte, rec Record) ([]byte, error) { aead, err := chacha20poly1305.NewX(key) if err != nil { @@ -425,6 +531,26 @@ func decryptRecord(key []byte, rec Record) ([]byte, error) { return plaintext, nil } +func newVerifierRecord(key []byte) (Record, error) { + rec, err := encryptRecord(key, verifierID, []byte(verifierPlaintext)) + if err != nil { + return Record{}, err + } + rec.Type = "verifier" + return rec, nil +} + +func verifyRecord(key []byte, rec Record) error { + plaintext, err := decryptRecord(key, rec) + if err != nil { + return err + } + if !SecureCompare(string(plaintext), verifierPlaintext) { + return fmt.Errorf("invalid verifier") + } + return nil +} + // VerifyPassword checks if a master password is correct without unlocking func VerifyPassword(path string, masterPassword string) (bool, error) { data, err := os.ReadFile(path) @@ -442,18 +568,20 @@ func VerifyPassword(path string, masterPassword string) (bool, error) { return false, err } - key := argon2.IDKey([]byte(masterPassword), salt, uint32(vf.KDF.Iterations), uint32(vf.KDF.MemoryKiB)*1024, uint8(vf.KDF.Parallelism), keyLen) + key := argon2.IDKey([]byte(masterPassword), salt, uint32(vf.KDF.Iterations), uint32(vf.KDF.MemoryKiB), uint8(vf.KDF.Parallelism), keyLen) defer func() { for i := range key { key[i] = 0 } }() - if len(vf.Records) == 0 { - // Empty vault, try a test encryption - return true, nil + if vf.Verifier != nil { + return verifyRecord(key, *vf.Verifier) == nil, nil } + if len(vf.Records) == 0 { + return false, nil + } _, err = decryptRecord(key, vf.Records[0]) return err == nil, nil } diff --git a/internal/vault/vault_test.go b/internal/vault/vault_test.go new file mode 100644 index 0000000..346863d --- /dev/null +++ b/internal/vault/vault_test.go @@ -0,0 +1,218 @@ +package vault + +import ( + "encoding/base64" + "encoding/json" + "os" + "path/filepath" + "testing" + + "golang.org/x/crypto/argon2" +) + +func TestNewEmptyVaultRejectsWrongPassword(t *testing.T) { + path := filepath.Join(t.TempDir(), "vault.bin") + + if err := Create(path, "correct horse"); err != nil { + t.Fatalf("create vault: %v", err) + } + + ok, err := VerifyPassword(path, "wrong horse") + if err != nil { + t.Fatalf("verify wrong password: %v", err) + } + if ok { + t.Fatal("expected wrong password to be rejected for a new empty vault") + } + + v := New(path) + if err := v.Unlock("wrong horse"); err == nil { + t.Fatal("expected unlock with wrong password to fail for a new empty vault") + } +} + +func TestNewEmptyVaultAcceptsCorrectPassword(t *testing.T) { + path := filepath.Join(t.TempDir(), "vault.bin") + + if err := Create(path, "correct horse"); err != nil { + t.Fatalf("create vault: %v", err) + } + + ok, err := VerifyPassword(path, "correct horse") + if err != nil { + t.Fatalf("verify correct password: %v", err) + } + if !ok { + t.Fatal("expected correct password to be accepted for a new empty vault") + } + + v := New(path) + if err := v.Unlock("correct horse"); err != nil { + t.Fatalf("unlock with correct password: %v", err) + } +} + +func TestLegacyVaultWithRecordsStillVerifiesByFirstRecord(t *testing.T) { + path := filepath.Join(t.TempDir(), "vault.bin") + salt := []byte("12345678901234567890123456789012") + key := argon2.IDKey([]byte("correct horse"), salt, 3, 65536, 1, keyLen) + + rec, err := encryptRecord(key, "server:test:ssh_password", []byte("secret")) + if err != nil { + t.Fatalf("encrypt legacy record: %v", err) + } + + data, err := json.Marshal(VaultFile{ + Version: currentVersion, + KDF: KDFMeta{ + Name: "argon2id", + MemoryKiB: 65536, + Iterations: 3, + Parallelism: 1, + Salt: base64.StdEncoding.EncodeToString(salt), + }, + Records: []Record{rec}, + }) + if err != nil { + t.Fatalf("marshal legacy vault: %v", err) + } + if err := os.WriteFile(path, data, 0o600); err != nil { + t.Fatalf("write legacy vault: %v", err) + } + + ok, err := VerifyPassword(path, "correct horse") + if err != nil { + t.Fatalf("verify legacy vault: %v", err) + } + if !ok { + t.Fatal("expected legacy vault with records to accept correct password") + } + + ok, err = VerifyPassword(path, "wrong horse") + if err != nil { + t.Fatalf("verify legacy vault with wrong password: %v", err) + } + if ok { + t.Fatal("expected legacy vault with records to reject wrong password") + } +} + +func TestLegacyEmptyVaultWithoutVerifierCannotUnlock(t *testing.T) { + path := filepath.Join(t.TempDir(), "vault.bin") + salt := []byte("12345678901234567890123456789012") + + data, err := json.Marshal(VaultFile{ + Version: currentVersion, + KDF: KDFMeta{ + Name: "argon2id", + MemoryKiB: 65536, + Iterations: 3, + Parallelism: 1, + Salt: base64.StdEncoding.EncodeToString(salt), + }, + Records: []Record{}, + }) + if err != nil { + t.Fatalf("marshal legacy empty vault: %v", err) + } + if err := os.WriteFile(path, data, 0o600); err != nil { + t.Fatalf("write legacy empty vault: %v", err) + } + + ok, err := VerifyPassword(path, "any password") + if err != nil { + t.Fatalf("verify legacy empty vault: %v", err) + } + if ok { + t.Fatal("expected legacy empty vault without verifier to be unverifiable") + } + + v := New(path) + if err := v.Unlock("any password"); err == nil { + t.Fatal("expected legacy empty vault without verifier to reject unlock") + } +} + +func TestListSecretsReturnsMetadataWithoutPlaintext(t *testing.T) { + path := filepath.Join(t.TempDir(), "vault.bin") + if err := Create(path, "master"); err != nil { + t.Fatalf("create vault: %v", err) + } + v := New(path) + if err := v.Unlock("master"); err != nil { + t.Fatalf("unlock vault: %v", err) + } + if err := v.Put("server:prod:ssh_password", "ssh_password", []byte("secret-password")); err != nil { + t.Fatalf("put password: %v", err) + } + if err := v.Put("server:prod:key_passphrase", "key_passphrase", []byte("secret-passphrase")); err != nil { + t.Fatalf("put passphrase: %v", err) + } + + metas, err := v.ListSecrets() + if err != nil { + t.Fatalf("list secrets: %v", err) + } + if len(metas) != 2 { + t.Fatalf("expected 2 secret metadata records, got %d: %#v", len(metas), metas) + } + if metas[0].Alias != "prod" || metas[0].Type != "key_passphrase" { + t.Fatalf("unexpected first metadata record: %#v", metas[0]) + } + if metas[1].Alias != "prod" || metas[1].Type != "ssh_password" { + t.Fatalf("unexpected second metadata record: %#v", metas[1]) + } +} + +func TestListSecretsPreservesTypesAfterSaveAndUnlock(t *testing.T) { + path := filepath.Join(t.TempDir(), "vault.bin") + if err := Create(path, "master"); err != nil { + t.Fatalf("create vault: %v", err) + } + v := New(path) + if err := v.Unlock("master"); err != nil { + t.Fatalf("unlock vault: %v", err) + } + if err := v.Put("server:prod:ssh_password", "ssh_password", []byte("secret-password")); err != nil { + t.Fatalf("put password: %v", err) + } + if err := v.Save(); err != nil { + t.Fatalf("save vault: %v", err) + } + + reopened := New(path) + if err := reopened.Unlock("master"); err != nil { + t.Fatalf("unlock reopened vault: %v", err) + } + metas, err := reopened.ListSecrets() + if err != nil { + t.Fatalf("list reopened secrets: %v", err) + } + if len(metas) != 1 { + t.Fatalf("expected 1 secret metadata record, got %d: %#v", len(metas), metas) + } + if metas[0].ID != "server:prod:ssh_password" || metas[0].Alias != "prod" || metas[0].Type != "ssh_password" { + t.Fatalf("unexpected metadata after reopen: %#v", metas[0]) + } +} + +func TestHasSecretReportsPresenceWithoutReturningValue(t *testing.T) { + path := filepath.Join(t.TempDir(), "vault.bin") + if err := Create(path, "master"); err != nil { + t.Fatalf("create vault: %v", err) + } + v := New(path) + if err := v.Unlock("master"); err != nil { + t.Fatalf("unlock vault: %v", err) + } + if err := v.Put("server:prod:ssh_password", "ssh_password", []byte("secret-password")); err != nil { + t.Fatalf("put password: %v", err) + } + + if !v.HasSecret("server:prod:ssh_password") { + t.Fatal("expected saved password to be reported present") + } + if v.HasSecret("server:prod:key_passphrase") { + t.Fatal("expected missing passphrase to be reported absent") + } +}