feat: improve tui and vault handling

This commit is contained in:
mirivlad 2026-05-28 02:25:18 +08:00
parent 73c60b9f93
commit e1d709396b
21 changed files with 2292 additions and 224 deletions

View File

@ -3,9 +3,11 @@ package cmd
import ( import (
"fmt" "fmt"
"strings" "strings"
"syscall"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/mirivlad/sshkeeper/internal/model" "github.com/mirivlad/sshkeeper/internal/model"
"golang.org/x/term"
) )
var addFlags struct { var addFlags struct {
@ -19,7 +21,6 @@ var addFlags struct {
displayName string displayName string
notes string notes string
tags string tags string
password string
} }
var addCmd = &cobra.Command{ var addCmd = &cobra.Command{
@ -59,11 +60,21 @@ func addNonInteractive(alias string) error {
server.DisplayName = alias server.DisplayName = alias
} }
// Handle password auth - store in vault // Handle password/passphrase auth — request interactively, never via argv
if server.AuthMethod == model.AuthPassword { if server.AuthMethod == model.AuthPassword || server.AuthMethod == model.AuthKeyPassphrase {
password := addFlags.password secretType := "password"
if password == "" { if server.AuthMethod == model.AuthKeyPassphrase {
return fmt.Errorf("password auth requires --password flag or interactive mode") secretType = "passphrase"
}
fmt.Printf("Enter %s (will be stored in vault, input hidden): ", secretType)
password, err := term.ReadPassword(int(syscall.Stdin))
fmt.Println()
if err != nil {
return fmt.Errorf("read %s: %w", secretType, err)
}
if len(password) == 0 {
return fmt.Errorf("%s cannot be empty", secretType)
} }
v := getOrCreateVault() v := getOrCreateVault()
@ -72,29 +83,14 @@ func addNonInteractive(alias string) error {
} }
vaultKey := fmt.Sprintf("server:%s:ssh_password", alias) vaultKey := fmt.Sprintf("server:%s:ssh_password", alias)
if err := v.Put(vaultKey, "ssh_password", []byte(password)); err != nil { vaultType := "ssh_password"
return fmt.Errorf("store password in vault: %w", err)
}
if err := v.Save(); err != nil {
return fmt.Errorf("save vault: %w", err)
}
}
// Handle key+passphrase - store passphrase in vault
if server.AuthMethod == model.AuthKeyPassphrase { if server.AuthMethod == model.AuthKeyPassphrase {
passphrase := addFlags.password vaultKey = fmt.Sprintf("server:%s:key_passphrase", alias)
if passphrase == "" { vaultType = "key_passphrase"
return fmt.Errorf("key+passphrase auth requires --password flag for the passphrase")
} }
v := getOrCreateVault() if err := v.Put(vaultKey, vaultType, password); err != nil {
if !v.IsUnlocked() { return fmt.Errorf("store %s in vault: %w", secretType, err)
return fmt.Errorf("vault is locked. Run 'sshkeeper vault unlock' first")
}
vaultKey := fmt.Sprintf("server:%s:key_passphrase", alias)
if err := v.Put(vaultKey, "key_passphrase", []byte(passphrase)); err != nil {
return fmt.Errorf("store passphrase in vault: %w", err)
} }
if err := v.Save(); err != nil { if err := v.Save(); err != nil {
return fmt.Errorf("save vault: %w", err) return fmt.Errorf("save vault: %w", err)
@ -132,5 +128,4 @@ func init() {
addCmd.Flags().StringVar(&addFlags.displayName, "display-name", "", "Display name") addCmd.Flags().StringVar(&addFlags.displayName, "display-name", "", "Display name")
addCmd.Flags().StringVar(&addFlags.notes, "notes", "", "Notes") addCmd.Flags().StringVar(&addFlags.notes, "notes", "", "Notes")
addCmd.Flags().StringVar(&addFlags.tags, "tags", "", "Comma-separated tags") addCmd.Flags().StringVar(&addFlags.tags, "tags", "", "Comma-separated tags")
addCmd.Flags().StringVar(&addFlags.password, "password", "", "SSH password or key passphrase (stored in vault)")
} }

View File

@ -27,6 +27,15 @@ var deleteCmd = &cobra.Command{
return fmt.Errorf("delete server: %w", err) return fmt.Errorf("delete server: %w", err)
} }
// Clean up vault secrets for this server
v := getOrCreateVault()
if v.IsUnlocked() {
cleanupServerSecrets(v, alias)
if err := v.Save(); err != nil {
return fmt.Errorf("save vault after cleanup: %w", err)
}
}
fmt.Println("Deleted.") fmt.Println("Deleted.")
return nil return nil
}, },

View File

@ -2,9 +2,11 @@ package cmd
import ( import (
"fmt" "fmt"
"syscall"
"github.com/spf13/cobra"
"github.com/mirivlad/sshkeeper/internal/model" "github.com/mirivlad/sshkeeper/internal/model"
"github.com/spf13/cobra"
"golang.org/x/term"
) )
var editCmd = &cobra.Command{ var editCmd = &cobra.Command{
@ -18,6 +20,8 @@ var editCmd = &cobra.Command{
return fmt.Errorf("server not found: %s", alias) return fmt.Errorf("server not found: %s", alias)
} }
oldAuthMethod := server.AuthMethod
if parsedHost != "" { if parsedHost != "" {
server.Host = parsedHost server.Host = parsedHost
} }
@ -46,6 +50,41 @@ var editCmd = &cobra.Command{
server.Notes = parsedNotes server.Notes = parsedNotes
} }
if parsedAuth != "" && oldAuthMethod != server.AuthMethod {
v := getOrCreateVault()
if v.IsUnlocked() {
var secret string
if server.AuthMethod == model.AuthPassword {
fmt.Print("Enter new password (stored in vault, input hidden): ")
pw, err := term.ReadPassword(int(syscall.Stdin))
fmt.Println()
if err != nil {
return fmt.Errorf("read password: %w", err)
}
if len(pw) > 0 {
secret = string(pw)
}
} else if server.AuthMethod == model.AuthKeyPassphrase {
fmt.Print("Enter key passphrase (stored in vault, input hidden): ")
pw, err := term.ReadPassword(int(syscall.Stdin))
fmt.Println()
if err != nil {
return fmt.Errorf("read passphrase: %w", err)
}
if len(pw) > 0 {
secret = string(pw)
}
}
if err := syncServerSecrets(v, alias, server, secret); err != nil {
return fmt.Errorf("sync vault secrets: %w", err)
}
if err := v.Save(); err != nil {
return fmt.Errorf("save vault: %w", err)
}
}
}
if err := appDB.UpdateServer(server); err != nil { if err := appDB.UpdateServer(server); err != nil {
return fmt.Errorf("update server: %w", err) return fmt.Errorf("update server: %w", err)
} }

View File

@ -74,8 +74,13 @@ var runCmd = &cobra.Command{
return fmt.Errorf("server not found: %s", alias) return fmt.Errorf("server not found: %s", alias)
} }
// Build ssh args with the command // For password auth, use PTY-wrapper with command
sshArgs := buildSSHArgs(server) if server.AuthMethod == model.AuthPassword {
return runWithPassword(server, command)
}
// For key/agent auth — direct execution
sshArgs := ssh.BuildSSHArgs(server)
sshArgs = append(sshArgs, command) sshArgs = append(sshArgs, command)
sshCmd := exec.Command(cfg.SSH.Binary, sshArgs...) sshCmd := exec.Command(cfg.SSH.Binary, sshArgs...)
@ -91,17 +96,21 @@ var runCmd = &cobra.Command{
}, },
} }
func buildSSHArgs(server *model.Server) []string { // runWithPassword runs a command on a server with password auth via PTY-wrapper.
var args []string func runWithPassword(server *model.Server, command string) error {
args = append(args, "-p", fmt.Sprintf("%d", server.Port)) v := getOrCreateVault()
if server.IdentityFile != "" { if !v.IsUnlocked() {
args = append(args, "-i", server.IdentityFile) return fmt.Errorf("vault is locked. Run 'sshkeeper vault unlock' first")
} }
if server.ProxyJump != "" {
args = append(args, "-J", server.ProxyJump) vaultKey := fmt.Sprintf("server:%s:ssh_password", server.Alias)
password, err := v.Get(vaultKey)
if err != nil {
return fmt.Errorf("get password from vault: %w", err)
} }
args = append(args, "-o", "StrictHostKeyChecking=accept-new")
target := fmt.Sprintf("%s@%s", server.User, server.Host) sshArgs := ssh.BuildSSHArgs(server)
args = append(args, target) sshArgs = append(sshArgs, command)
return args
return ssh.ConnectWithPassword(cfg.SSH.Binary, sshArgs, string(password))
} }

85
cmd/secrets.go Normal file
View File

@ -0,0 +1,85 @@
package cmd
import (
"fmt"
"github.com/mirivlad/sshkeeper/internal/model"
"github.com/mirivlad/sshkeeper/internal/ssh"
"github.com/mirivlad/sshkeeper/internal/vault"
)
const (
secretSSHPassword = "ssh_password"
secretKeyPassphrase = "key_passphrase"
secretSudoPassword = "sudo_password"
)
var serverSecretTypes = []string{
secretSSHPassword,
secretKeyPassphrase,
secretSudoPassword,
}
func serverSecretID(alias, secretType string) string {
return fmt.Sprintf("server:%s:%s", alias, secretType)
}
func cleanupServerSecrets(v *vault.Vault, alias string) {
for _, secretType := range serverSecretTypes {
v.Delete(serverSecretID(alias, secretType))
}
}
func syncServerSecrets(v *vault.Vault, oldAlias string, server *model.Server, secret string) error {
if oldAlias == "" {
oldAlias = server.Alias
}
if oldAlias != server.Alias {
for _, secretType := range serverSecretTypes {
oldID := serverSecretID(oldAlias, secretType)
data, err := v.Get(oldID)
if err == nil {
if err := v.Put(serverSecretID(server.Alias, secretType), secretType, data); err != nil {
return err
}
}
v.Delete(oldID)
}
}
switch server.AuthMethod {
case model.AuthPassword:
v.Delete(serverSecretID(server.Alias, secretKeyPassphrase))
if secret != "" {
return v.Put(serverSecretID(server.Alias, secretSSHPassword), secretSSHPassword, []byte(secret))
}
case model.AuthKeyPassphrase:
v.Delete(serverSecretID(server.Alias, secretSSHPassword))
if secret != "" {
return v.Put(serverSecretID(server.Alias, secretKeyPassphrase), secretKeyPassphrase, []byte(secret))
}
default:
v.Delete(serverSecretID(server.Alias, secretSSHPassword))
v.Delete(serverSecretID(server.Alias, secretKeyPassphrase))
}
return nil
}
func deleteVaultSecrets(v *vault.Vault, alias string, secretType string) error {
if secretType != "" {
v.Delete(serverSecretID(alias, secretType))
return nil
}
cleanupServerSecrets(v, alias)
return nil
}
func formTestVaultFunc(getVault ssh.VaultFunc, server *model.Server, formSecret string) ssh.VaultFunc {
return func(serverAlias string, secretType string) (string, error) {
if (secretType == secretSSHPassword || secretType == secretKeyPassphrase) && formSecret != "" {
return formSecret, nil
}
return getVault(serverAlias, secretType)
}
}

186
cmd/secrets_test.go Normal file
View File

@ -0,0 +1,186 @@
package cmd
import (
"fmt"
"path/filepath"
"testing"
"github.com/mirivlad/sshkeeper/internal/model"
"github.com/mirivlad/sshkeeper/internal/vault"
)
func newUnlockedTestVault(t *testing.T) *vault.Vault {
t.Helper()
path := filepath.Join(t.TempDir(), "vault.bin")
if err := vault.Create(path, "master"); err != nil {
t.Fatalf("create vault: %v", err)
}
v := vault.New(path)
if err := v.Unlock("master"); err != nil {
t.Fatalf("unlock vault: %v", err)
}
return v
}
func mustPutSecret(t *testing.T, v *vault.Vault, alias, secretType, value string) {
t.Helper()
if err := v.Put(serverSecretID(alias, secretType), secretType, []byte(value)); err != nil {
t.Fatalf("put %s: %v", secretType, err)
}
}
func mustGetSecret(t *testing.T, v *vault.Vault, alias, secretType string) string {
t.Helper()
data, err := v.Get(serverSecretID(alias, secretType))
if err != nil {
t.Fatalf("get %s: %v", secretType, err)
}
return string(data)
}
func assertSecretMissing(t *testing.T, v *vault.Vault, alias, secretType string) {
t.Helper()
if data, err := v.Get(serverSecretID(alias, secretType)); err == nil {
t.Fatalf("expected %s for %s to be missing, got %q", secretType, alias, string(data))
}
}
func TestCleanupServerSecretsRemovesAuthAndSudoSecrets(t *testing.T) {
v := newUnlockedTestVault(t)
mustPutSecret(t, v, "prod", "ssh_password", "password")
mustPutSecret(t, v, "prod", "key_passphrase", "passphrase")
mustPutSecret(t, v, "prod", "sudo_password", "sudo")
cleanupServerSecrets(v, "prod")
assertSecretMissing(t, v, "prod", "ssh_password")
assertSecretMissing(t, v, "prod", "key_passphrase")
assertSecretMissing(t, v, "prod", "sudo_password")
}
func TestSyncServerSecretsDeletesAuthSecretsForKeyAuth(t *testing.T) {
v := newUnlockedTestVault(t)
mustPutSecret(t, v, "prod", "ssh_password", "password")
mustPutSecret(t, v, "prod", "key_passphrase", "passphrase")
server := &model.Server{Alias: "prod", AuthMethod: model.AuthKey}
if err := syncServerSecrets(v, "prod", server, ""); err != nil {
t.Fatalf("sync secrets: %v", err)
}
assertSecretMissing(t, v, "prod", "ssh_password")
assertSecretMissing(t, v, "prod", "key_passphrase")
}
func TestSyncServerSecretsKeepsExistingPasswordWhenEditPasswordIsBlank(t *testing.T) {
v := newUnlockedTestVault(t)
mustPutSecret(t, v, "prod", "ssh_password", "old-password")
mustPutSecret(t, v, "prod", "key_passphrase", "old-passphrase")
server := &model.Server{Alias: "prod", AuthMethod: model.AuthPassword}
if err := syncServerSecrets(v, "prod", server, ""); err != nil {
t.Fatalf("sync secrets: %v", err)
}
if got := mustGetSecret(t, v, "prod", "ssh_password"); got != "old-password" {
t.Fatalf("expected password to remain, got %q", got)
}
assertSecretMissing(t, v, "prod", "key_passphrase")
}
func TestSyncServerSecretsRenamesSecretsBeforeApplyingAuthCleanup(t *testing.T) {
v := newUnlockedTestVault(t)
mustPutSecret(t, v, "old", "ssh_password", "password")
mustPutSecret(t, v, "old", "key_passphrase", "passphrase")
mustPutSecret(t, v, "old", "sudo_password", "sudo")
server := &model.Server{Alias: "new", AuthMethod: model.AuthKeyPassphrase}
if err := syncServerSecrets(v, "old", server, ""); err != nil {
t.Fatalf("sync secrets: %v", err)
}
assertSecretMissing(t, v, "old", "ssh_password")
assertSecretMissing(t, v, "old", "key_passphrase")
assertSecretMissing(t, v, "old", "sudo_password")
assertSecretMissing(t, v, "new", "ssh_password")
if got := mustGetSecret(t, v, "new", "key_passphrase"); got != "passphrase" {
t.Fatalf("expected key passphrase to move, got %q", got)
}
if got := mustGetSecret(t, v, "new", "sudo_password"); got != "sudo" {
t.Fatalf("expected sudo password to move, got %q", got)
}
}
func TestSyncServerSecretsStoresNewSecretForSelectedAuthMethod(t *testing.T) {
v := newUnlockedTestVault(t)
mustPutSecret(t, v, "prod", "key_passphrase", "old-passphrase")
server := &model.Server{Alias: "prod", AuthMethod: model.AuthPassword}
if err := syncServerSecrets(v, "prod", server, "new-password"); err != nil {
t.Fatalf("sync secrets: %v", err)
}
if got := mustGetSecret(t, v, "prod", "ssh_password"); got != "new-password" {
t.Fatalf("expected new password, got %q", got)
}
assertSecretMissing(t, v, "prod", "key_passphrase")
}
func TestDeleteVaultSecretsForAliasAndType(t *testing.T) {
v := newUnlockedTestVault(t)
mustPutSecret(t, v, "prod", "ssh_password", "password")
mustPutSecret(t, v, "prod", "key_passphrase", "passphrase")
if err := deleteVaultSecrets(v, "prod", "ssh_password"); err != nil {
t.Fatalf("delete vault secret: %v", err)
}
assertSecretMissing(t, v, "prod", "ssh_password")
if got := mustGetSecret(t, v, "prod", "key_passphrase"); got != "passphrase" {
t.Fatalf("expected passphrase to remain, got %q", got)
}
}
func TestDeleteVaultSecretsForAlias(t *testing.T) {
v := newUnlockedTestVault(t)
mustPutSecret(t, v, "prod", "ssh_password", "password")
mustPutSecret(t, v, "prod", "key_passphrase", "passphrase")
if err := deleteVaultSecrets(v, "prod", ""); err != nil {
t.Fatalf("delete vault secrets: %v", err)
}
assertSecretMissing(t, v, "prod", "ssh_password")
assertSecretMissing(t, v, "prod", "key_passphrase")
}
func TestFormTestVaultFuncUsesSavedSecretWhenFormSecretBlank(t *testing.T) {
vaultFunc := formTestVaultFunc(func(serverAlias string, secretType string) (string, error) {
if serverAlias != "prod" || secretType != "ssh_password" {
return "", fmt.Errorf("unexpected secret lookup %s %s", serverAlias, secretType)
}
return "saved-password", nil
}, &model.Server{Alias: "prod", AuthMethod: model.AuthPassword}, "")
got, err := vaultFunc("prod", "ssh_password")
if err != nil {
t.Fatalf("get password: %v", err)
}
if got != "saved-password" {
t.Fatalf("expected saved password, got %q", got)
}
}
func TestFormTestVaultFuncUsesFormSecretWhenProvided(t *testing.T) {
vaultFunc := formTestVaultFunc(func(serverAlias string, secretType string) (string, error) {
return "", fmt.Errorf("saved secret should not be used")
}, &model.Server{Alias: "prod", AuthMethod: model.AuthPassword}, "typed-password")
got, err := vaultFunc("prod", "ssh_password")
if err != nil {
t.Fatalf("get password: %v", err)
}
if got != "typed-password" {
t.Fatalf("expected typed password, got %q", got)
}
}

View File

@ -2,8 +2,11 @@ package cmd
import ( import (
"fmt" "fmt"
"os"
"os/exec"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/mirivlad/sshkeeper/internal/ssh"
) )
var templateCmd = &cobra.Command{ var templateCmd = &cobra.Command{
@ -70,13 +73,15 @@ var runTemplateCmd = &cobra.Command{
alias := args[0] alias := args[0]
templateName := args[1] templateName := args[1]
server, err := appDB.GetServer(alias)
if err != nil {
return fmt.Errorf("server not found: %s", alias)
}
templates, err := appDB.GetCommandTemplates(alias) templates, err := appDB.GetCommandTemplates(alias)
if err != nil { if err != nil {
return fmt.Errorf("list templates: %w", err) return fmt.Errorf("list templates: %w", err)
} }
if len(templates) == 0 {
return fmt.Errorf("server not found or no templates: %s", alias)
}
var command string var command string
for _, t := range templates { for _, t := range templates {
@ -91,7 +96,21 @@ var runTemplateCmd = &cobra.Command{
} }
fmt.Printf("Running '%s' on %s...\n", command, alias) fmt.Printf("Running '%s' on %s...\n", command, alias)
return nil
// Build ssh args with the command
sshArgs := ssh.BuildSSHArgs(server)
sshArgs = append(sshArgs, command)
sshCmd := exec.Command(cfg.SSH.Binary, sshArgs...)
sshCmd.Stdin = os.Stdin
sshCmd.Stdout = os.Stdout
sshCmd.Stderr = os.Stderr
if err := sshCmd.Start(); err != nil {
return fmt.Errorf("start ssh: %w", err)
}
return sshCmd.Wait()
}, },
} }

View File

@ -36,41 +36,43 @@ func runTUI() error {
return appDB.SearchServers(query) return appDB.SearchServers(query)
} }
tui.DeleteServer = func(alias string) error { tui.DeleteServer = func(alias string) error {
return appDB.DeleteServer(alias) if err := appDB.DeleteServer(alias); err != nil {
return err
}
v := getOrCreateVault()
if v.IsUnlocked() {
cleanupServerSecrets(v, alias)
if err := v.Save(); err != nil {
return fmt.Errorf("save vault after cleanup: %w", err)
}
}
return nil
} }
tui.TestConnection = func(server *model.Server) (bool, string) { tui.TestConnection = func(server *model.Server) (bool, string) {
return ssh.Test(cfg, server, vaultFunc) return ssh.Test(cfg, server, vaultFunc)
} }
tui.TestConnectionWithPassword = func(server *model.Server, password string) (bool, string) { tui.TestConnectionWithPassword = func(server *model.Server, password string) (bool, string) {
directVaultFunc := func(sa string, st string) (string, error) { return ssh.Test(cfg, server, formTestVaultFunc(vaultFunc, server, password))
if st == "ssh_password" || st == "key_passphrase" {
return password, nil
} }
return vaultFunc(sa, st) tui.SaveServer = func(server *model.Server, password string, oldAlias string) error {
}
return ssh.Test(cfg, server, directVaultFunc)
}
tui.SaveServer = func(server *model.Server, password string) error {
if password != "" {
v := getOrCreateVault() v := getOrCreateVault()
vaultKey := fmt.Sprintf("server:%s:ssh_password", server.Alias) if v.IsUnlocked() {
secretType := "ssh_password" if err := syncServerSecrets(v, oldAlias, server, password); err != nil {
if server.AuthMethod == model.AuthKeyPassphrase { return fmt.Errorf("sync vault secrets: %w", err)
vaultKey = fmt.Sprintf("server:%s:key_passphrase", server.Alias)
secretType = "key_passphrase"
}
if err := v.Put(vaultKey, secretType, []byte(password)); err != nil {
return fmt.Errorf("store secret: %w", err)
} }
if err := v.Save(); err != nil { if err := v.Save(); err != nil {
return fmt.Errorf("save vault: %w", err) return fmt.Errorf("save vault: %w", err)
} }
} }
existing, _ := appDB.GetServer(server.Alias) lookupAlias := server.Alias
if oldAlias != "" {
lookupAlias = oldAlias
}
existing, _ := appDB.GetServer(lookupAlias)
if existing != nil { if existing != nil {
server.ID = existing.ID server.ID = existing.ID
return appDB.UpdateServer(server) return appDB.UpdateServerByAlias(existing.Alias, server)
} }
return appDB.CreateServer(server) return appDB.CreateServer(server)
} }
@ -84,6 +86,16 @@ func runTUI() error {
tui.DeleteGroup = func(name string) error { tui.DeleteGroup = func(name string) error {
return appDB.DeleteGroup(name) return appDB.DeleteGroup(name)
} }
tui.UpdateTestResult = func(alias string, status model.TestStatus, testErr string) error {
return appDB.UpdateTestResult(alias, status, testErr)
}
tui.HasSecret = func(alias string, secretType string) bool {
v := getOrCreateVault()
if !v.IsUnlocked() {
return false
}
return v.HasSecret(serverSecretID(alias, secretType))
}
// Run TUI in a loop — if user requests connect, handle it and restart TUI // Run TUI in a loop — if user requests connect, handle it and restart TUI
for { for {
@ -132,5 +144,3 @@ func runTUI() error {
return nil return nil
} }
} }

View File

@ -3,11 +3,12 @@ package cmd
import ( import (
"fmt" "fmt"
"os" "os"
"strings"
"syscall" "syscall"
"github.com/spf13/cobra"
"github.com/mirivlad/sshkeeper/internal/config" "github.com/mirivlad/sshkeeper/internal/config"
"github.com/mirivlad/sshkeeper/internal/vault" "github.com/mirivlad/sshkeeper/internal/vault"
"github.com/spf13/cobra"
"golang.org/x/term" "golang.org/x/term"
) )
@ -160,9 +161,75 @@ var vaultChangePasswordCmd = &cobra.Command{
}, },
} }
var vaultListCmd = &cobra.Command{
Use: "list",
Short: "List stored secret metadata",
RunE: func(cmd *cobra.Command, args []string) error {
v := getOrCreateVault()
if !v.IsUnlocked() {
return fmt.Errorf("vault is locked. Unlock first with 'sshkeeper vault unlock'")
}
output, err := formatVaultSecretsList(v)
if err != nil {
return err
}
fmt.Print(output)
return nil
},
}
var vaultDeleteCmd = &cobra.Command{
Use: "delete <alias> [type]",
Short: "Delete stored secrets for a server",
Args: cobra.RangeArgs(1, 2),
RunE: func(cmd *cobra.Command, args []string) error {
alias := args[0]
secretType := ""
if len(args) == 2 {
secretType = args[1]
}
v := getOrCreateVault()
if !v.IsUnlocked() {
return fmt.Errorf("vault is locked. Unlock first with 'sshkeeper vault unlock'")
}
if err := deleteVaultSecrets(v, alias, secretType); err != nil {
return err
}
if err := v.Save(); err != nil {
return fmt.Errorf("save vault: %w", err)
}
if secretType == "" {
fmt.Printf("Deleted secrets for %s.\n", alias)
} else {
fmt.Printf("Deleted %s for %s.\n", secretType, alias)
}
return nil
},
}
func formatVaultSecretsList(v *vault.Vault) (string, error) {
metas, err := v.ListSecrets()
if err != nil {
return "", err
}
if len(metas) == 0 {
return "No secrets stored.\n", nil
}
var b strings.Builder
fmt.Fprintf(&b, "%-24s %-18s\n", "ALIAS", "TYPE")
for _, meta := range metas {
fmt.Fprintf(&b, "%-24s %-18s\n", meta.Alias, meta.Type)
}
return b.String(), nil
}
func init() { func init() {
vaultCmd.AddCommand(vaultUnlockCmd) vaultCmd.AddCommand(vaultUnlockCmd)
vaultCmd.AddCommand(vaultLockCmd) vaultCmd.AddCommand(vaultLockCmd)
vaultCmd.AddCommand(vaultStatusCmd) vaultCmd.AddCommand(vaultStatusCmd)
vaultCmd.AddCommand(vaultChangePasswordCmd) vaultCmd.AddCommand(vaultChangePasswordCmd)
vaultCmd.AddCommand(vaultListCmd)
vaultCmd.AddCommand(vaultDeleteCmd)
} }

40
cmd/vault_test.go Normal file
View File

@ -0,0 +1,40 @@
package cmd
import (
"strings"
"testing"
)
func TestFormatVaultSecretsListDoesNotExposeSecretValues(t *testing.T) {
v := newUnlockedTestVault(t)
mustPutSecret(t, v, "prod", "ssh_password", "super-secret")
mustPutSecret(t, v, "stage", "key_passphrase", "also-secret")
output, err := formatVaultSecretsList(v)
if err != nil {
t.Fatalf("format vault secrets list: %v", err)
}
for _, want := range []string{"prod", "ssh_password", "stage", "key_passphrase"} {
if !strings.Contains(output, want) {
t.Fatalf("expected output to contain %q\noutput:\n%s", want, output)
}
}
for _, secretValue := range []string{"super-secret", "also-secret"} {
if strings.Contains(output, secretValue) {
t.Fatalf("expected output not to expose secret value %q\noutput:\n%s", secretValue, output)
}
}
}
func TestFormatVaultSecretsListHandlesEmptyVault(t *testing.T) {
v := newUnlockedTestVault(t)
output, err := formatVaultSecretsList(v)
if err != nil {
t.Fatalf("format empty vault secrets list: %v", err)
}
if !strings.Contains(output, "No secrets stored.") {
t.Fatalf("expected empty output message, got:\n%s", output)
}
}

View File

@ -0,0 +1,295 @@
# Server List Scalability Implementation Plan
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
**Goal:** Keep the server list usable when the saved server count is larger than the terminal height.
**Architecture:** The list screen should render a bounded table viewport instead of rendering every server row. Selection remains owned by the existing `bubbles/list.Model`, while `viewServerList` derives a visible row range around the selected item and keeps the selected-server detail panel and footer visible.
**Tech Stack:** Go, Bubble Tea, Bubbles list, Lip Gloss, existing TUI tests in `internal/tui/app_test.go`.
---
### Task 1: Add Regression Coverage For Long Server Lists
**Files:**
- Modify: `internal/tui/app_test.go`
- [ ] **Step 1: Add a test for a constrained terminal height**
Add a test that creates more servers than can fit on screen, sets a small terminal size, renders the list, and verifies the selected details and footer are still visible.
```go
func TestServerListViewKeepsDetailsVisibleWithManyServers(t *testing.T) {
servers := make([]*model.Server, 45)
for i := range servers {
servers[i] = &model.Server{
Alias: fmt.Sprintf("server-%02d", i+1),
DisplayName: fmt.Sprintf("Server %02d", i+1),
Host: fmt.Sprintf("host-%02d.example.org", i+1),
Port: 22,
User: "mirivlad",
AuthMethod: model.AuthKey,
LastTestStatus: model.TestUnknown,
}
}
m := New(servers)
updated, _ := m.Update(tea.WindowSizeMsg{Width: 100, Height: 18})
model := updated.(*tuiModel)
view := model.View()
if !strings.Contains(view, "Server 01") {
t.Fatalf("expected first selected server to be visible:\n%s", view)
}
if !strings.Contains(view, "Selected") {
t.Fatalf("expected selected server details to remain visible:\n%s", view)
}
if !strings.Contains(view, "Enter connect") {
t.Fatalf("expected footer to remain visible:\n%s", view)
}
if count := strings.Count(view, "server-"); count >= len(servers) {
t.Fatalf("expected bounded row rendering, rendered %d server aliases", count)
}
}
```
- [ ] **Step 2: Run the focused test and confirm it fails**
Run:
```bash
env GOCACHE=/tmp/sshkeeper-go-cache go test ./internal/tui -run TestServerListViewKeepsDetailsVisibleWithManyServers -count=1
```
Expected: FAIL because the current table renders every server and can push the detail panel/footer below the visible terminal area.
### Task 2: Compute A Visible Row Window
**Files:**
- Modify: `internal/tui/app.go`
- Modify: `internal/tui/app_test.go`
- [ ] **Step 1: Add focused tests for visible range calculation**
Add tests for a helper that computes the inclusive start and exclusive end indexes for rendered rows.
```go
func TestVisibleServerRangeKeepsSelectionInsideWindow(t *testing.T) {
tests := []struct {
name string
total int
selected int
available int
wantStart int
wantEnd int
}{
{name: "first page", total: 40, selected: 0, available: 10, wantStart: 0, wantEnd: 10},
{name: "middle page", total: 40, selected: 20, available: 10, wantStart: 11, wantEnd: 21},
{name: "last page", total: 40, selected: 39, available: 10, wantStart: 30, wantEnd: 40},
{name: "all fit", total: 5, selected: 3, available: 10, wantStart: 0, wantEnd: 5},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
start, end := visibleServerRange(tt.total, tt.selected, tt.available)
if start != tt.wantStart || end != tt.wantEnd {
t.Fatalf("visibleServerRange() = %d, %d; want %d, %d", start, end, tt.wantStart, tt.wantEnd)
}
})
}
}
```
- [ ] **Step 2: Implement `visibleServerRange`**
Add a small helper near `selectedServer`.
```go
func visibleServerRange(total, selected, available int) (int, int) {
if total <= 0 || available <= 0 {
return 0, 0
}
if available >= total {
return 0, total
}
if selected < 0 {
selected = 0
}
if selected >= total {
selected = total - 1
}
start := selected - available + 1
if start < 0 {
start = 0
}
end := start + available
if end > total {
end = total
start = end - available
}
return start, end
}
```
- [ ] **Step 3: Run helper tests**
Run:
```bash
env GOCACHE=/tmp/sshkeeper-go-cache go test ./internal/tui -run TestVisibleServerRangeKeepsSelectionInsideWindow -count=1
```
Expected: PASS.
### Task 3: Render Only Rows That Fit
**Files:**
- Modify: `internal/tui/app.go`
- Modify: `internal/tui/app_test.go`
- [ ] **Step 1: Reserve terminal space for fixed UI blocks**
Add a helper that decides how many server rows may be rendered while keeping the selected details and footer visible.
```go
func (m *tuiModel) visibleServerRows() int {
if m.height <= 0 {
return len(m.servers)
}
const fixedRows = 16
rows := m.height - fixedRows
if rows < 3 {
return 3
}
return rows
}
```
- [ ] **Step 2: Use the visible range in `viewServerList`**
In `viewServerList`, replace the loop over all servers with a bounded loop:
```go
selectedIndex := m.list.Index()
start, end := visibleServerRange(len(m.servers), selectedIndex, m.visibleServerRows())
for _, server := range m.servers[start:end] {
// existing row rendering body stays unchanged
}
```
Then render a compact range hint when rows are hidden:
```go
if len(m.servers) > end-start {
b.WriteString(helpStyle.Render(fmt.Sprintf(" Showing %d-%d of %d", start+1, end, len(m.servers))))
b.WriteString("\n")
}
```
- [ ] **Step 3: Run long-list regression test**
Run:
```bash
env GOCACHE=/tmp/sshkeeper-go-cache go test ./internal/tui -run TestServerListViewKeepsDetailsVisibleWithManyServers -count=1
```
Expected: PASS.
### Task 4: Verify Navigation Still Works
**Files:**
- Modify: `internal/tui/app_test.go`
- [ ] **Step 1: Add a test for moving selection beyond the first window**
Use the existing `m.list.Update` path by sending `tea.KeyDown` messages and confirm the rendered window follows the selected server.
```go
func TestServerListViewScrollsWithSelection(t *testing.T) {
servers := make([]*model.Server, 45)
for i := range servers {
servers[i] = &model.Server{
Alias: fmt.Sprintf("server-%02d", i+1),
DisplayName: fmt.Sprintf("Server %02d", i+1),
Host: fmt.Sprintf("host-%02d.example.org", i+1),
Port: 22,
User: "mirivlad",
AuthMethod: model.AuthKey,
}
}
m := New(servers)
updated, _ := m.Update(tea.WindowSizeMsg{Width: 100, Height: 18})
model := updated.(*tuiModel)
for i := 0; i < 20; i++ {
updated, _ = model.Update(tea.KeyMsg{Type: tea.KeyDown})
model = updated.(*tuiModel)
}
view := model.View()
if !strings.Contains(view, "Server 21") {
t.Fatalf("expected selected server to be visible after navigation:\n%s", view)
}
if !strings.Contains(view, "Showing") {
t.Fatalf("expected range hint for long server list:\n%s", view)
}
}
```
- [ ] **Step 2: Run TUI tests**
Run:
```bash
env GOCACHE=/tmp/sshkeeper-go-cache go test ./internal/tui -count=1
```
Expected: PASS.
### Task 5: Final Verification And Build
**Files:**
- No source edits expected.
- [ ] **Step 1: Run the full test suite**
Run:
```bash
env GOCACHE=/tmp/sshkeeper-go-cache go test ./...
```
Expected: all packages pass.
- [ ] **Step 2: Rebuild the project binary**
Run:
```bash
env GOCACHE=/tmp/sshkeeper-go-cache go build -o bin/sshkeeper .
```
Expected: exit code 0 and updated `bin/sshkeeper`.
- [ ] **Step 3: Commit the implementation**
Run:
```bash
git add internal/tui/app.go internal/tui/app_test.go bin/sshkeeper
git commit -m "fix: keep server list usable with many servers"
```
Expected: commit succeeds.
---
## Self-Review
- Spec coverage: the plan covers the known failure mode for 40+ servers, keeps selected details visible, keeps the footer visible, and preserves existing `bubbles/list` navigation.
- Placeholder scan: no `TBD`, `TODO`, or open-ended implementation placeholders remain.
- Type consistency: helper names and files match the current TUI code shape.

2
go.mod
View File

@ -10,6 +10,7 @@ require (
github.com/creack/pty v1.1.24 github.com/creack/pty v1.1.24
github.com/spf13/cobra v1.10.2 github.com/spf13/cobra v1.10.2
golang.org/x/crypto v0.52.0 golang.org/x/crypto v0.52.0
golang.org/x/sys v0.45.0
golang.org/x/term v0.43.0 golang.org/x/term v0.43.0
modernc.org/sqlite v1.50.1 modernc.org/sqlite v1.50.1
) )
@ -41,7 +42,6 @@ require (
github.com/sahilm/fuzzy v0.1.1 // indirect github.com/sahilm/fuzzy v0.1.1 // indirect
github.com/spf13/pflag v1.0.9 // indirect github.com/spf13/pflag v1.0.9 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
golang.org/x/sys v0.45.0 // indirect
golang.org/x/text v0.37.0 // indirect golang.org/x/text v0.37.0 // indirect
modernc.org/libc v1.72.3 // indirect modernc.org/libc v1.72.3 // indirect
modernc.org/mathutil v1.7.1 // indirect modernc.org/mathutil v1.7.1 // indirect

View File

@ -30,6 +30,17 @@ func (db *DB) UpdateServer(s *model.Server) error {
return err return err
} }
func (db *DB) UpdateServerByAlias(oldAlias string, s *model.Server) error {
_, err := db.conn.Exec(`
UPDATE servers SET
alias=?, display_name=?, host=?, port=?, user=?, auth_method=?,
identity_file=?, proxy_jump=?, group_name=?, notes=?, updated_at=CURRENT_TIMESTAMP
WHERE alias=?`,
s.Alias, s.DisplayName, s.Host, s.Port, s.User, s.AuthMethod,
s.IdentityFile, s.ProxyJump, s.GroupName, s.Notes, oldAlias)
return err
}
func (db *DB) DeleteServer(alias string) error { func (db *DB) DeleteServer(alias string) error {
_, err := db.conn.Exec("DELETE FROM servers WHERE alias=?", alias) _, err := db.conn.Exec("DELETE FROM servers WHERE alias=?", alias)
return err return err

View File

@ -0,0 +1,47 @@
package db
import (
"testing"
"github.com/mirivlad/sshkeeper/internal/model"
)
func TestUpdateServerByAliasCanRenameAlias(t *testing.T) {
db, err := Open(t.TempDir())
if err != nil {
t.Fatalf("open db: %v", err)
}
defer db.Close()
server := &model.Server{
Alias: "old",
Host: "old.example",
Port: 22,
User: "root",
AuthMethod: model.AuthPassword,
}
if err := db.CreateServer(server); err != nil {
t.Fatalf("create server: %v", err)
}
server.Alias = "new"
server.Host = "new.example"
server.AuthMethod = model.AuthKey
if err := db.UpdateServerByAlias("old", server); err != nil {
t.Fatalf("update server by old alias: %v", err)
}
if _, err := db.GetServer("old"); err == nil {
t.Fatal("expected old alias to be gone")
}
got, err := db.GetServer("new")
if err != nil {
t.Fatalf("get new alias: %v", err)
}
if got.ID != server.ID {
t.Fatalf("expected ID to stay %d, got %d", server.ID, got.ID)
}
if got.Host != "new.example" || got.AuthMethod != model.AuthKey {
t.Fatalf("unexpected updated server: %#v", got)
}
}

View File

@ -5,7 +5,6 @@ import (
"os" "os"
"os/exec" "os/exec"
"strings" "strings"
"time"
"github.com/mirivlad/sshkeeper/internal/config" "github.com/mirivlad/sshkeeper/internal/config"
"github.com/mirivlad/sshkeeper/internal/model" "github.com/mirivlad/sshkeeper/internal/model"
@ -14,7 +13,7 @@ import (
type VaultFunc func(serverAlias string, secretType string) (string, error) type VaultFunc func(serverAlias string, secretType string) (string, error)
func Connect(cfg *config.Config, server *model.Server, getVault VaultFunc) error { func Connect(cfg *config.Config, server *model.Server, getVault VaultFunc) error {
args := buildArgs(server) args := BuildSSHArgs(server)
switch server.AuthMethod { switch server.AuthMethod {
case model.AuthPassword: case model.AuthPassword:
@ -25,8 +24,7 @@ func Connect(cfg *config.Config, server *model.Server, getVault VaultFunc) error
return ConnectWithPassword(cfg.SSH.Binary, args, password) return ConnectWithPassword(cfg.SSH.Binary, args, password)
case model.AuthKeyPassphrase: case model.AuthKeyPassphrase:
// For key+passphrase, we need to handle the passphrase // For key+passphrase, let ssh-agent handle it or prompt normally
// For now, let ssh-agent handle it or prompt normally
// TODO: use ssh-agent or similar // TODO: use ssh-agent or similar
fallthrough fallthrough
@ -46,13 +44,11 @@ func Connect(cfg *config.Config, server *model.Server, getVault VaultFunc) error
} }
func Test(cfg *config.Config, server *model.Server, getVault VaultFunc) (bool, string) { func Test(cfg *config.Config, server *model.Server, getVault VaultFunc) (bool, string) {
args := buildArgs(server) args := BuildSSHArgs(server)
args = append(args, "-o", fmt.Sprintf("ConnectTimeout=%d", cfg.SSH.ConnectTimeoutSec)) args = append(args, "-o", fmt.Sprintf("ConnectTimeout=%d", cfg.SSH.ConnectTimeoutSec))
switch server.AuthMethod { switch server.AuthMethod {
case model.AuthPassword: case model.AuthPassword:
// For password auth, we can't use BatchMode
// Use a short timeout and try to connect
args = append(args, "-o", "NumberOfPasswordPrompts=1") args = append(args, "-o", "NumberOfPasswordPrompts=1")
password, err := getVault(server.Alias, "ssh_password") password, err := getVault(server.Alias, "ssh_password")
if err != nil { if err != nil {
@ -81,40 +77,29 @@ func Test(cfg *config.Config, server *model.Server, getVault VaultFunc) (bool, s
} }
} }
// testWithPassword tests SSH connection with password auth via PTY-wrapper.
// It connects, sends the password, runs the test command, and checks the output.
func testWithPassword(cfg *config.Config, args []string, password string) (bool, string) { func testWithPassword(cfg *config.Config, args []string, password string) (bool, string) {
// For password test, we use PTY approach with a short timeout
// This is a simplified version - in production, use ConnectWithPassword
// with a test command
args = append(args, cfg.SSH.TestCommand) args = append(args, cfg.SSH.TestCommand)
cmd := exec.Command(cfg.SSH.Binary, args...) ok, output := connectWithPasswordAndRead(cfg.SSH.Binary, args, password, cfg.SSH.ConnectTimeoutSec)
cmd.Stdin = nil if !ok {
cmd.Stdout = nil return false, output
cmd.Stderr = nil
// Use a timeout
done := make(chan error, 1)
if err := cmd.Start(); err != nil {
return false, err.Error()
} }
go func() { result := strings.TrimSpace(output)
done <- cmd.Wait() if result == "SSHKEEPER_OK" {
}()
select {
case err := <-done:
if err != nil {
return false, err.Error()
}
return true, "" return true, ""
case <-time.After(time.Duration(cfg.SSH.ConnectTimeoutSec) * time.Second):
cmd.Process.Kill()
return false, "connection timeout"
} }
// The output might have the test command echo before the result
if strings.Contains(result, "SSHKEEPER_OK") {
return true, ""
}
return false, result
} }
func buildArgs(server *model.Server) []string { // BuildSSHArgs builds the SSH command arguments for a server profile.
func BuildSSHArgs(server *model.Server) []string {
var args []string var args []string
args = append(args, "-p", fmt.Sprintf("%d", server.Port)) args = append(args, "-p", fmt.Sprintf("%d", server.Port))
@ -127,8 +112,6 @@ func buildArgs(server *model.Server) []string {
args = append(args, "-J", server.ProxyJump) args = append(args, "-J", server.ProxyJump)
} }
// Disable strict host key checking for first connection
// In production, this should be configurable
args = append(args, "-o", "StrictHostKeyChecking=accept-new") args = append(args, "-o", "StrictHostKeyChecking=accept-new")
target := fmt.Sprintf("%s@%s", server.User, server.Host) target := fmt.Sprintf("%s@%s", server.User, server.Host)

View File

@ -7,19 +7,19 @@ import (
"os/exec" "os/exec"
"regexp" "regexp"
"strings" "strings"
"sync"
"sync/atomic"
"syscall" "syscall"
"time" "time"
"github.com/creack/pty" "github.com/creack/pty"
"golang.org/x/sys/unix"
"golang.org/x/term" "golang.org/x/term"
) )
var passwordPromptRe = regexp.MustCompile(`(?i)(password|passphrase).*:\s*$`) var passwordPromptRe = regexp.MustCompile(`(?i)(password|passphrase).*:\s*$`)
// ConnectWithPassword runs SSH through a PTY, detects the password prompt,
// sends the password, and then bridges the user terminal to the SSH session.
func ConnectWithPassword(sshBinary string, args []string, password string) error { func ConnectWithPassword(sshBinary string, args []string, password string) error {
// Start SSH with PTY
cmd := exec.Command(sshBinary, args...) cmd := exec.Command(sshBinary, args...)
cmd.Env = os.Environ() cmd.Env = os.Environ()
cmd.SysProcAttr = &syscall.SysProcAttr{ cmd.SysProcAttr = &syscall.SysProcAttr{
@ -33,18 +33,14 @@ func ConnectWithPassword(sshBinary string, args []string, password string) error
} }
defer ptmx.Close() defer ptmx.Close()
// Save terminal state and set to raw
oldState, err := term.MakeRaw(int(os.Stdin.Fd())) oldState, err := term.MakeRaw(int(os.Stdin.Fd()))
if err != nil { if err != nil {
return fmt.Errorf("set raw terminal: %w", err) return fmt.Errorf("set raw terminal: %w", err)
} }
defer term.Restore(int(os.Stdin.Fd()), oldState)
// Channel to signal when password has been sent passwordSent := new(atomic.Bool)
passwordSent := make(chan bool, 1)
done := make(chan error, 1) done := make(chan error, 1)
// Read from PTY, detect password prompt
go func() { go func() {
buf := make([]byte, 4096) buf := make([]byte, 4096)
var accumulated strings.Builder var accumulated strings.Builder
@ -54,22 +50,17 @@ func ConnectWithPassword(sshBinary string, args []string, password string) error
if n > 0 { if n > 0 {
data := buf[:n] data := buf[:n]
accumulated.Write(data) accumulated.Write(data)
// Write to stdout
os.Stdout.Write(data) os.Stdout.Write(data)
// Check for password prompt if !passwordSent.Load() {
if !<-passwordSent {
text := accumulated.String() text := accumulated.String()
if passwordPromptRe.MatchString(text) { if passwordPromptRe.MatchString(text) {
passwordSent <- true passwordSent.Store(true)
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
ptmx.Write([]byte(password + "\r")) ptmx.Write([]byte(password + "\r"))
continue
} }
} }
// Reset accumulated buffer periodically to avoid unbounded growth
if accumulated.Len() > 8192 { if accumulated.Len() > 8192 {
s := accumulated.String() s := accumulated.String()
accumulated.Reset() accumulated.Reset()
@ -87,14 +78,142 @@ func ConnectWithPassword(sshBinary string, args []string, password string) error
} }
}() }()
// Copy stdin to PTY stopInput, waitInput, restoreInput, err := forwardInputToPTY(ptmx)
go func() { if err != nil {
io.Copy(ptmx, os.Stdin) term.Restore(int(os.Stdin.Fd()), oldState)
}() return err
}
// Wait for command completion
err = cmd.Wait() err = cmd.Wait()
passwordSent <- false // signal to stop close(stopInput)
waitInput()
if restoreErr := restoreInput(); err == nil && restoreErr != nil {
err = restoreErr
}
if restoreErr := term.Restore(int(os.Stdin.Fd()), oldState); err == nil && restoreErr != nil {
err = restoreErr
}
return err return err
} }
func forwardInputToPTY(ptmx *os.File) (chan struct{}, func(), func() error, error) {
stdinFD := int(os.Stdin.Fd())
flags, err := unix.FcntlInt(uintptr(stdinFD), unix.F_GETFL, 0)
if err != nil {
return nil, nil, nil, fmt.Errorf("get stdin flags: %w", err)
}
if err := unix.SetNonblock(stdinFD, true); err != nil {
return nil, nil, nil, fmt.Errorf("set stdin nonblocking: %w", err)
}
stop := make(chan struct{})
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
buf := make([]byte, 4096)
for {
select {
case <-stop:
return
default:
}
n, readErr := unix.Read(stdinFD, buf)
if n > 0 {
if _, writeErr := ptmx.Write(buf[:n]); writeErr != nil {
return
}
}
if readErr == nil {
continue
}
if readErr == unix.EAGAIN || readErr == unix.EWOULDBLOCK {
select {
case <-stop:
return
case <-time.After(10 * time.Millisecond):
continue
}
}
return
}
}()
restore := func() error {
_, err := unix.FcntlInt(uintptr(stdinFD), unix.F_SETFL, flags)
if err != nil {
return fmt.Errorf("restore stdin flags: %w", err)
}
return nil
}
return stop, wg.Wait, restore, nil
}
// connectWithPasswordAndRead runs SSH through a PTY, sends the password,
// collects all output, and returns it. Used for non-interactive testing.
// Returns (true, output) on success, (false, error) on failure.
func connectWithPasswordAndRead(sshBinary string, args []string, password string, timeoutSec int) (bool, string) {
cmd := exec.Command(sshBinary, args...)
cmd.Env = os.Environ()
cmd.SysProcAttr = &syscall.SysProcAttr{
Setsid: true,
Setctty: true,
}
ptmx, err := pty.Start(cmd)
if err != nil {
return false, fmt.Sprintf("start ssh with pty: %v", err)
}
defer ptmx.Close()
passwordSent := false
done := make(chan string, 1)
// Read from PTY, detect password prompt, collect output
go func() {
buf := make([]byte, 4096)
var accumulated strings.Builder
for {
n, err := ptmx.Read(buf)
if n > 0 {
data := buf[:n]
accumulated.Write(data)
// Check for password prompt
if !passwordSent {
text := accumulated.String()
if passwordPromptRe.MatchString(text) {
passwordSent = true
time.Sleep(100 * time.Millisecond)
ptmx.Write([]byte(password + "\r"))
continue
}
}
// Reset accumulated buffer periodically
if accumulated.Len() > 16384 {
s := accumulated.String()
accumulated.Reset()
accumulated.WriteString(s[len(s)-4096:])
}
}
if err != nil {
done <- accumulated.String()
return
}
}
}()
// Wait for command completion or timeout
select {
case output := <-done:
return true, output
case <-time.After(time.Duration(timeoutSec) * time.Second):
cmd.Process.Kill()
return false, "connection timeout"
}
}

63
internal/ssh/pty_test.go Normal file
View File

@ -0,0 +1,63 @@
package ssh
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/creack/pty"
)
func TestConnectWithPasswordDoesNotConsumeInputAfterReturn(t *testing.T) {
script := filepath.Join(t.TempDir(), "fake-ssh")
if err := os.WriteFile(script, []byte("#!/bin/sh\nprintf 'password: '\nsleep 0.1\nexit 0\n"), 0o700); err != nil {
t.Fatalf("write fake ssh: %v", err)
}
master, slave, err := pty.Open()
if err != nil {
t.Fatalf("open stdin pty: %v", err)
}
defer master.Close()
defer slave.Close()
oldStdin := os.Stdin
oldStdout := os.Stdout
stdoutSink, err := os.CreateTemp(t.TempDir(), "stdout")
if err != nil {
t.Fatalf("create stdout sink: %v", err)
}
defer stdoutSink.Close()
os.Stdin = slave
os.Stdout = stdoutSink
defer func() {
os.Stdin = oldStdin
os.Stdout = oldStdout
}()
if err := ConnectWithPassword(script, nil, "secret"); err != nil {
t.Fatalf("connect with password: %v", err)
}
if _, err := master.Write([]byte("\n")); err != nil {
t.Fatalf("write next enter: %v", err)
}
time.Sleep(100 * time.Millisecond)
readDone := make(chan []byte, 1)
go func() {
buf := make([]byte, 1)
n, _ := slave.Read(buf)
readDone <- buf[:n]
}()
select {
case got := <-readDone:
if len(got) != 1 || got[0] != '\n' {
t.Fatalf("expected to read newline, got %q", got)
}
case <-time.After(300 * time.Millisecond):
t.Fatal("expected next Enter to remain readable after ConnectWithPassword returned")
}
}

View File

@ -27,6 +27,9 @@ var (
Bold(true) Bold(true)
normalStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("15")) normalStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("15"))
selectedRowStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("15")).Background(lipgloss.Color("4"))
listHeaderStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("14")).Bold(true)
sectionStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("12")).Bold(true).MarginTop(1)
testOKStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("10")).Bold(true) testOKStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("10")).Bold(true)
testFailStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("9")).Bold(true) testFailStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("9")).Bold(true)
@ -68,8 +71,12 @@ type serverItem struct {
} }
func (i serverItem) Title() string { return i.server.Alias } func (i serverItem) Title() string { return i.server.Alias }
func (i serverItem) Description() string { return fmt.Sprintf("%s@%s:%d %s", i.server.User, i.server.Host, i.server.Port, i.server.AuthMethod) } func (i serverItem) Description() string {
func (i serverItem) FilterValue() string { return i.server.Alias + " " + i.server.DisplayName + " " + i.server.Host + " " + i.server.User } return fmt.Sprintf("%s@%s:%d %s", i.server.User, i.server.Host, i.server.Port, i.server.AuthMethod)
}
func (i serverItem) FilterValue() string {
return i.server.Alias + " " + i.server.DisplayName + " " + i.server.Host + " " + i.server.User
}
// --- External callbacks --- // --- External callbacks ---
@ -80,10 +87,12 @@ var (
TestConnection func(server *model.Server) (bool, string) TestConnection func(server *model.Server) (bool, string)
// TestConnectionWithPassword tests with explicit password (for form test before save) // TestConnectionWithPassword tests with explicit password (for form test before save)
TestConnectionWithPassword func(server *model.Server, password string) (bool, string) TestConnectionWithPassword func(server *model.Server, password string) (bool, string)
SaveServer func(server *model.Server, password string) error SaveServer func(server *model.Server, password string, oldAlias string) error
GetGroups func() ([]string, error) // Returns existing group names UpdateTestResult func(alias string, status model.TestStatus, testErr string) error
RenameGroup func(oldName, newName string) error // Rename group for all servers HasSecret func(alias string, secretType string) bool
DeleteGroup func(name string) error // Remove group from all servers GetGroups func() ([]string, error)
RenameGroup func(oldName, newName string) error
DeleteGroup func(name string) error
) )
// --- Screen type --- // --- Screen type ---
@ -196,8 +205,22 @@ func (m *tuiModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
} }
m.form.testResultTime = time.Now() m.form.testResultTime = time.Now()
m.form.err = nil m.form.err = nil
}
return m, nil return m, nil
}
// Update test status in DB and reload list
if item, ok := m.list.SelectedItem().(serverItem); ok && UpdateTestResult != nil {
status := model.TestUnknown
if msg.ok {
status = model.TestOK
} else if msg.err != "" {
status = model.TestFailed
}
UpdateTestResult(item.server.Alias, status, msg.err)
}
return m, func() tea.Msg {
servers, err := ListServers()
return serversLoadedMsg{servers: servers, err: err}
}
case saveDoneMsg: case saveDoneMsg:
if m.form != nil { if m.form != nil {
@ -323,11 +346,23 @@ func (m *tuiModel) updateSearch(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
func (m *tuiModel) updateForm(msg tea.KeyMsg) (tea.Model, tea.Cmd) { func (m *tuiModel) updateForm(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
if msg.Type == tea.KeyEsc { if msg.Type == tea.KeyEsc {
if m.form != nil && (m.form.showGroupList || m.form.showAuthList) {
updated, cmd := m.form.Update(msg)
if fm, ok := updated.(*formModel); ok {
m.form = fm
}
return m, cmd
}
m.screen = screenList m.screen = screenList
m.form = nil m.form = nil
m.err = nil m.err = nil
m.success = "" m.success = ""
return m, nil // Reload server list after form close
return m, func() tea.Msg {
servers, err := ListServers()
return serversLoadedMsg{servers: servers, err: err}
}
} }
updated, cmd := m.form.Update(msg) updated, cmd := m.form.Update(msg)
@ -342,9 +377,7 @@ func (m *tuiModel) View() string {
switch m.screen { switch m.screen {
case screenList: case screenList:
b.WriteString(m.list.View()) b.WriteString(m.viewServerList())
b.WriteString("\n")
b.WriteString(helpStyle.Render("Enter connect | Ctrl+A add | Ctrl+E edit | Ctrl+D del | Ctrl+T test | Ctrl+F search | Ctrl+Q quit"))
case screenSearch: case screenSearch:
b.WriteString("Search: " + m.searchInput.View() + "\n") b.WriteString("Search: " + m.searchInput.View() + "\n")
@ -366,13 +399,148 @@ func (m *tuiModel) View() string {
return b.String() return b.String()
} }
func (m *tuiModel) viewServerList() string {
var b strings.Builder
selectedAlias := ""
if item, ok := m.list.SelectedItem().(serverItem); ok && item.server != nil {
selectedAlias = item.server.Alias
}
b.WriteString(titleStyle.Render(fmt.Sprintf("sshkeeper %d servers", len(m.servers))))
b.WriteString("\n")
b.WriteString(helpStyle.Render(fmt.Sprintf("Vault unlocked | %s", testSummary(m.servers))))
b.WriteString("\n\n")
b.WriteString(listHeaderStyle.Render(fmt.Sprintf(" %-20s %-20s %-34s %-12s %-10s %s", "NAME", "ALIAS", "TARGET", "AUTH", "GROUP", "STATUS")))
b.WriteString("\n")
if len(m.servers) == 0 {
b.WriteString(helpStyle.Render(" No servers yet. Press Ctrl+A to add one."))
b.WriteString("\n")
} else {
for _, server := range m.servers {
marker := " "
rowStyle := normalStyle
if server.Alias == selectedAlias {
marker = ">"
rowStyle = selectedRowStyle
}
name := server.DisplayName
if name == "" {
name = server.Alias
}
target := fmt.Sprintf("%s@%s:%d", server.User, server.Host, server.Port)
group := server.GroupName
if group == "" {
group = "-"
}
row := fmt.Sprintf("%s %-20s %-20s %-34s %-12s %-10s %s",
marker,
truncate(name, 20),
truncate(server.Alias, 20),
truncate(target, 34),
authLabel(server.AuthMethod),
truncate(group, 10),
testStatusLabel(server),
)
b.WriteString(rowStyle.Render(row))
b.WriteString("\n")
}
}
b.WriteString("\n")
if selectedAlias != "" {
if selected := m.selectedServer(); selected != nil {
b.WriteString(m.viewSelectedServer(selected))
b.WriteString("\n")
}
}
b.WriteString(helpStyle.Render("Enter connect | Ctrl+A add | Ctrl+E edit | Ctrl+D del | Ctrl+T test | Ctrl+F search | Ctrl+Q quit"))
return b.String()
}
func (m *tuiModel) selectedServer() *model.Server {
if item, ok := m.list.SelectedItem().(serverItem); ok && item.server != nil {
return item.server
}
return nil
}
func (m *tuiModel) viewSelectedServer(server *model.Server) string {
displayName := server.DisplayName
if displayName == "" {
displayName = "-"
}
group := server.GroupName
if group == "" {
group = "-"
}
var b strings.Builder
b.WriteString(sectionStyle.Render("Selected"))
b.WriteString("\n")
b.WriteString(fmt.Sprintf(" Alias: %s\n", server.Alias))
b.WriteString(fmt.Sprintf(" Display Name: %s\n", displayName))
b.WriteString(fmt.Sprintf(" Host: %s\n", server.Host))
b.WriteString(fmt.Sprintf(" Port: %d\n", server.Port))
b.WriteString(fmt.Sprintf(" User: %s\n", server.User))
b.WriteString(fmt.Sprintf(" Auth: %s\n", authLabel(server.AuthMethod)))
b.WriteString(fmt.Sprintf(" Group: %s\n", group))
b.WriteString(fmt.Sprintf(" Status: %s\n", testStatusLabel(server)))
return b.String()
}
func testSummary(servers []*model.Server) string {
okCount := 0
failedCount := 0
for _, server := range servers {
switch server.LastTestStatus {
case model.TestOK:
okCount++
case model.TestFailed:
failedCount++
}
}
return fmt.Sprintf("%d OK | %d FAIL", okCount, failedCount)
}
func authLabel(auth model.AuthMethod) string {
switch auth {
case model.AuthPassword:
return "password"
case model.AuthKey:
return "key"
case model.AuthKeyPassphrase:
return "key+phrase"
case model.AuthAgent:
return "agent"
default:
return string(auth)
}
}
func testStatusLabel(server *model.Server) string {
switch server.LastTestStatus {
case model.TestOK:
return "OK"
case model.TestFailed:
if server.LastTestError != "" {
return "FAIL"
}
return "FAIL"
default:
return "?"
}
}
// --- Form model --- // --- Form model ---
type formModel struct { type formModel struct {
edit bool edit bool
server *model.Server server *model.Server
inputs []textinput.Model inputs []textinput.Model
labels []string
password textinput.Model password textinput.Model
passwordLabel string
focusIdx int focusIdx int
testResult string testResult string
testOK bool testOK bool
@ -385,10 +553,22 @@ type formModel struct {
spinner spinner.Model spinner spinner.Model
width int width int
height int height int
groups string // cached groups list (comma-separated for display) groups []string // existing group names
showGroups bool // show group dropdown groupList list.Model // dropdown list for groups
showGroupList bool // whether group dropdown is visible
authList list.Model
showAuthList bool
} }
// groupItem implements list.Item for the group dropdown
type groupItem struct {
name string
}
func (i groupItem) Title() string { return i.name }
func (i groupItem) Description() string { return "" }
func (i groupItem) FilterValue() string { return i.name }
func newFormModel(w, h int) *formModel { func newFormModel(w, h int) *formModel {
inputs := make([]textinput.Model, 10) inputs := make([]textinput.Model, 10)
labels := []string{ labels := []string{
@ -405,12 +585,12 @@ func newFormModel(w, h int) *formModel {
} }
for i, label := range labels { for i, label := range labels {
inputs[i] = textinput.New() inputs[i] = textinput.New()
inputs[i].Placeholder = label inputs[i].Placeholder = placeholderForLabel(label)
inputs[i].CharLimit = 128 inputs[i].CharLimit = 128
} }
pw := textinput.New() pw := textinput.New()
pw.Placeholder = "Password / Passphrase (stored in vault)" pw.Placeholder = "optional"
pw.CharLimit = 256 pw.CharLimit = 256
pw.EchoMode = textinput.EchoPassword pw.EchoMode = textinput.EchoPassword
@ -422,23 +602,74 @@ func newFormModel(w, h int) *formModel {
fm := &formModel{ fm := &formModel{
inputs: inputs, inputs: inputs,
labels: labels,
password: pw, password: pw,
passwordLabel: "Password / Passphrase",
focusIdx: 0, focusIdx: 0,
spinner: s, spinner: s,
width: w, width: w,
height: h, height: h,
} }
fm.authList = newStringList([]string{
string(model.AuthPassword),
string(model.AuthKey),
string(model.AuthKeyPassphrase),
string(model.AuthAgent),
}, "Select auth method", 34, 16)
// Load existing groups // Load existing groups
if GetGroups != nil { if GetGroups != nil {
if groups, err := GetGroups(); err == nil && len(groups) > 0 { if groups, err := GetGroups(); err == nil && len(groups) > 0 {
fm.groups = strings.Join(groups, ", ") fm.groups = groups
fm.groupList = newStringList(groups, "Select group", 30, 8)
} }
} }
fm.updateFocus()
return fm return fm
} }
func placeholderForLabel(label string) string {
switch label {
case "Alias":
return "mail.kp"
case "Display Name":
return "Production mail"
case "Host":
return "mail.example.org"
case "Port":
return "22"
case "User":
return "root"
case "Auth Method (password/key/key_passphrase/agent)":
return "key"
case "Identity File":
return "~/.ssh/id_ed25519"
case "ProxyJump":
return "optional"
case "Group (type new or pick from list)":
return "KP"
case "Notes":
return "optional"
default:
return label
}
}
func newStringList(values []string, title string, width, height int) list.Model {
items := make([]list.Item, len(values))
for i, value := range values {
items[i] = groupItem{name: value}
}
l := list.New(items, list.NewDefaultDelegate(), width, height)
l.SetShowStatusBar(false)
l.SetShowHelp(false)
l.SetShowPagination(false)
l.Title = title
l.Styles.Title = titleStyle
return l
}
func newEditFormModel(s *model.Server, w, h int) *formModel { func newEditFormModel(s *model.Server, w, h int) *formModel {
fm := newFormModel(w, h) fm := newFormModel(w, h)
fm.edit = true fm.edit = true
@ -453,6 +684,21 @@ func newEditFormModel(s *model.Server, w, h int) *formModel {
fm.inputs[7].SetValue(s.ProxyJump) fm.inputs[7].SetValue(s.ProxyJump)
fm.inputs[8].SetValue(s.GroupName) fm.inputs[8].SetValue(s.GroupName)
fm.inputs[9].SetValue(s.Notes) fm.inputs[9].SetValue(s.Notes)
if HasSecret != nil {
switch s.AuthMethod {
case model.AuthPassword:
if HasSecret(s.Alias, "ssh_password") {
fm.passwordLabel = "Password (secret saved; leave blank to keep)"
fm.password.Placeholder = ""
}
case model.AuthKeyPassphrase:
if HasSecret(s.Alias, "key_passphrase") {
fm.passwordLabel = "Key passphrase (secret saved; leave blank to keep)"
fm.password.Placeholder = ""
}
}
}
fm.updateFocus()
return fm return fm
} }
@ -498,6 +744,48 @@ func (fm *formModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return fm, cmd return fm, cmd
} }
// Handle group dropdown
if fm.showGroupList {
switch msg := msg.(type) {
case tea.KeyMsg:
switch msg.Type {
case tea.KeyEsc:
fm.showGroupList = false
return fm, nil
case tea.KeyEnter:
if item, ok := fm.groupList.SelectedItem().(groupItem); ok {
fm.inputs[8].SetValue(item.name)
}
fm.showGroupList = false
return fm, nil
}
}
// Pass other keys to the list
var cmd tea.Cmd
fm.groupList, cmd = fm.groupList.Update(msg)
return fm, cmd
}
if fm.showAuthList {
switch msg := msg.(type) {
case tea.KeyMsg:
switch msg.Type {
case tea.KeyEsc:
fm.showAuthList = false
return fm, nil
case tea.KeyEnter:
if item, ok := fm.authList.SelectedItem().(groupItem); ok {
fm.inputs[5].SetValue(item.name)
}
fm.showAuthList = false
return fm, nil
}
}
var cmd tea.Cmd
fm.authList, cmd = fm.authList.Update(msg)
return fm, cmd
}
switch msg := msg.(type) { switch msg := msg.(type) {
case tea.KeyMsg: case tea.KeyMsg:
switch msg.Type { switch msg.Type {
@ -519,6 +807,17 @@ func (fm *formModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
fm.updateFocus() fm.updateFocus()
return fm, nil return fm, nil
case tea.KeyRunes:
if len(msg.Runes) == 1 && msg.Runes[0] == '/' && !msg.Alt && fm.focusIdx == 5 {
fm.showAuthList = true
return fm, nil
}
// '/' on Group field opens group dropdown
if len(msg.Runes) == 1 && msg.Runes[0] == '/' && !msg.Alt && fm.focusIdx == 8 && len(fm.groups) > 0 {
fm.showGroupList = true
return fm, nil
}
case tea.KeyEnter: case tea.KeyEnter:
switch { switch {
case fm.focusIdx == len(fm.inputs)+1: case fm.focusIdx == len(fm.inputs)+1:
@ -576,20 +875,36 @@ func (fm *formModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
func (fm *formModel) updateFocus() { func (fm *formModel) updateFocus() {
for i := range fm.inputs { for i := range fm.inputs {
fm.inputs[i].Blur() fm.inputs[i].Blur()
fm.inputs[i].Prompt = blurredStyle.Render(fm.inputs[i].Placeholder + ": ") fm.inputs[i].Prompt = blurredStyle.Render(fm.labelAt(i) + ": ")
} }
fm.password.Blur() fm.password.Blur()
fm.password.Prompt = blurredStyle.Render(fm.password.Placeholder + ": ") fm.password.Prompt = blurredStyle.Render(fm.passwordLabel + ": ")
if fm.focusIdx < len(fm.inputs) { if fm.focusIdx < len(fm.inputs) {
fm.inputs[fm.focusIdx].Focus() fm.inputs[fm.focusIdx].Focus()
fm.inputs[fm.focusIdx].Prompt = focusedStyle.Render(fm.inputs[fm.focusIdx].Placeholder + "> ") fm.inputs[fm.focusIdx].Prompt = focusedStyle.Render(fm.labelAt(fm.focusIdx) + "> ")
} else if fm.focusIdx == len(fm.inputs) { } else if fm.focusIdx == len(fm.inputs) {
fm.password.Focus() fm.password.Focus()
fm.password.Prompt = focusedStyle.Render(fm.password.Placeholder + "> ") fm.password.Prompt = focusedStyle.Render(fm.passwordLabel + "> ")
} }
} }
func (fm *formModel) labelAt(index int) string {
if index >= 0 && index < len(fm.labels) {
if index == 5 {
return "Auth Method (/ pick)"
}
if index == 8 {
if len(fm.groups) > 0 {
return "Group (/ pick)"
}
return "Group"
}
return fm.labels[index]
}
return ""
}
func (fm *formModel) runTest() tea.Cmd { func (fm *formModel) runTest() tea.Cmd {
fm.testing = true fm.testing = true
fm.testResult = "" fm.testResult = ""
@ -635,7 +950,11 @@ func (fm *formModel) runSave() tea.Cmd {
if s.Host == "" { if s.Host == "" {
return saveDoneMsg{err: fmt.Errorf("host is required")} return saveDoneMsg{err: fmt.Errorf("host is required")}
} }
err := SaveServer(s, pw) oldAlias := ""
if fm.edit && fm.server != nil {
oldAlias = fm.server.Alias
}
err := SaveServer(s, pw, oldAlias)
return saveDoneMsg{err: err} return saveDoneMsg{err: err}
}, },
) )
@ -672,13 +991,72 @@ func (fm *formModel) View() string {
b.WriteString(titleStyle.Render(title)) b.WriteString(titleStyle.Render(title))
b.WriteString("\n\n") b.WriteString("\n\n")
for i := range fm.inputs { // Calculate visible range based on terminal height
// Reserve lines for: title (2) + password (1) + buttons (3) + help (1) + padding (2) = ~9
reserved := 9
available := fm.height - reserved
if available < 4 {
available = 4
}
numInputs := len(fm.inputs)
startIdx := 0
endIdx := numInputs
// Scroll: keep focused field visible
if numInputs > available {
focusInput := fm.focusIdx
if focusInput >= numInputs {
focusInput = numInputs - 1
}
// Try to show `available` fields centered on focus
startIdx = focusInput - available/2
if startIdx < 0 {
startIdx = 0
}
endIdx = startIdx + available
if endIdx > numInputs {
endIdx = numInputs
startIdx = endIdx - available
if startIdx < 0 {
startIdx = 0
}
}
}
// Show scroll indicator if needed
if startIdx > 0 {
b.WriteString(helpStyle.Render(" ↑ more fields above\n"))
}
for i := startIdx; i < endIdx; i++ {
if section := formSectionTitle(i); section != "" {
b.WriteString(sectionStyle.Render(section))
b.WriteString("\n")
}
if i == 5 {
fm.inputs[i].Placeholder = "password/key/key_passphrase/agent"
}
// Show group hint inline in placeholder for Group field
if i == 8 && len(fm.groups) > 0 && !fm.showGroupList {
fm.inputs[i].Placeholder = truncate(strings.Join(fm.groups, ", "), 25)
}
b.WriteString(fm.inputs[i].View()) b.WriteString(fm.inputs[i].View())
b.WriteString("\n") b.WriteString("\n")
// Show existing groups hint under Group field if i == 5 && fm.showAuthList {
if i == 8 && fm.groups != "" { b.WriteString("\n" + renderDropdown(fm.authList) + "\n")
b.WriteString(helpStyle.Render(" Groups: " + fm.groups + "\n")) b.WriteString(helpStyle.Render("Enter select | Esc cancel"))
return b.String()
} }
if i == 8 && fm.showGroupList {
b.WriteString("\n" + renderDropdown(fm.groupList) + "\n")
b.WriteString(helpStyle.Render("Enter select | Esc cancel"))
return b.String()
}
}
if endIdx < numInputs {
b.WriteString(helpStyle.Render(fmt.Sprintf(" ↓ more fields below (%d-%d of %d)\n", startIdx+1, endIdx, numInputs)))
} }
b.WriteString(fm.password.View()) b.WriteString(fm.password.View())
@ -723,8 +1101,53 @@ func (fm *formModel) View() string {
saveBtn = normalStyle.Render(saveBtn) saveBtn = normalStyle.Render(saveBtn)
} }
b.WriteString("\n" + testBtn + " " + saveBtn + "\n\n") b.WriteString("\n" + sectionStyle.Render("Actions") + "\n")
b.WriteString(helpStyle.Render("Tab/↓ next | ↑ prev | Enter select | Esc back")) b.WriteString(testBtn + " " + saveBtn + "\n\n")
b.WriteString(helpStyle.Render("Tab/↓ next | ↑ prev | / pick list | Enter select | Esc back"))
return b.String() return b.String()
} }
func renderDropdown(l list.Model) string {
var b strings.Builder
b.WriteString(sectionStyle.Render(l.Title))
b.WriteString("\n")
for i, item := range l.Items() {
group, ok := item.(groupItem)
if !ok {
continue
}
prefix := " "
style := normalStyle
if i == l.Index() {
prefix = "> "
style = selectedRowStyle
}
b.WriteString(style.Render(prefix + group.name))
b.WriteString("\n")
}
return strings.TrimRight(b.String(), "\n")
}
func formSectionTitle(index int) string {
switch index {
case 0:
return "Identity"
case 2:
return "Connection"
case 5:
return "Authentication"
case 8:
return "Metadata"
default:
return ""
}
}
// truncate limits a string to maxLen, adding "..." if truncated
func truncate(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen-3] + "..."
}

322
internal/tui/app_test.go Normal file
View File

@ -0,0 +1,322 @@
package tui
import (
"strings"
"testing"
"time"
"github.com/charmbracelet/bubbles/list"
tea "github.com/charmbracelet/bubbletea"
"github.com/mirivlad/sshkeeper/internal/model"
)
func TestServerListViewUsesDashboardLayout(t *testing.T) {
now := time.Date(2026, 5, 28, 1, 50, 0, 0, time.UTC)
m := New([]*model.Server{
{
Alias: "mail.kp",
DisplayName: "Mail",
Host: "mail.example.org",
Port: 222,
User: "mirivlad",
AuthMethod: model.AuthPassword,
GroupName: "KP",
LastTestStatus: model.TestOK,
LastTestAt: &now,
},
{
Alias: "mirv.top",
Host: "mirv.top",
Port: 22,
User: "root",
AuthMethod: model.AuthKey,
LastTestStatus: model.TestUnknown,
},
})
m.width = 100
m.height = 30
m.list.SetSize(100, 24)
view := m.View()
for _, want := range []string{
"sshkeeper",
"2 servers",
"Vault",
"NAME",
"TARGET",
"AUTH",
"GROUP",
"STATUS",
"Mail",
"mail.kp",
"mirivlad@mail.example.org:222",
"KP",
"OK",
"Selected",
"Host: mail.example.org",
"Alias: mail.kp",
"Display Name: Mail",
"Port: 222",
"Enter connect",
} {
if !strings.Contains(view, want) {
t.Fatalf("expected list view to contain %q\nview:\n%s", want, view)
}
}
if strings.Contains(view, "Profiles managed locally") {
t.Fatalf("expected compact status header instead of README text\nview:\n%s", view)
}
}
func TestEscClosesGroupListBeforeLeavingForm(t *testing.T) {
oldGetGroups := GetGroups
GetGroups = func() ([]string, error) {
return []string{"prod", "stage"}, nil
}
defer func() { GetGroups = oldGetGroups }()
m := &tuiModel{
screen: screenForm,
form: newFormModel(80, 24),
}
m.form.focusIdx = 8
m.form.updateFocus()
updated, _ := m.updateForm(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'/'}})
m = updated.(*tuiModel)
if !m.form.showGroupList {
t.Fatal("expected / on the group field to open the group list")
}
updated, _ = m.updateForm(tea.KeyMsg{Type: tea.KeyEsc})
m = updated.(*tuiModel)
if m.screen != screenForm {
t.Fatalf("expected Esc to keep the user in the form, got screen %v", m.screen)
}
if m.form == nil {
t.Fatal("expected form to remain open")
}
if m.form.showGroupList {
t.Fatal("expected Esc to close only the group list")
}
}
func TestEscClosesAuthMethodListBeforeLeavingForm(t *testing.T) {
m := &tuiModel{
screen: screenForm,
form: newFormModel(80, 24),
}
m.form.focusIdx = 5
m.form.updateFocus()
updated, _ := m.updateForm(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'/'}})
m = updated.(*tuiModel)
if !m.form.showAuthList {
t.Fatal("expected / on the auth method field to open the auth method list")
}
updated, _ = m.updateForm(tea.KeyMsg{Type: tea.KeyEsc})
m = updated.(*tuiModel)
if m.screen != screenForm {
t.Fatalf("expected Esc to keep the user in the form, got screen %v", m.screen)
}
if m.form == nil {
t.Fatal("expected form to remain open")
}
if m.form.showAuthList {
t.Fatal("expected Esc to close only the auth method list")
}
}
func TestAuthMethodListSelectsValue(t *testing.T) {
fm := newFormModel(80, 24)
fm.focusIdx = 5
fm.updateFocus()
updated, _ := fm.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'/'}})
fm = updated.(*formModel)
if !fm.showAuthList {
t.Fatal("expected / on the auth method field to open the auth method list")
}
fm.authList.Select(2)
updated, _ = fm.Update(tea.KeyMsg{Type: tea.KeyEnter})
fm = updated.(*formModel)
if fm.showAuthList {
t.Fatal("expected Enter to close auth method list")
}
if got := fm.inputs[5].Value(); got != string(model.AuthKeyPassphrase) {
t.Fatalf("expected auth method %q, got %q", model.AuthKeyPassphrase, got)
}
}
func TestAuthMethodListViewShowsAllOptions(t *testing.T) {
fm := newFormModel(80, 12)
fm.focusIdx = 5
fm.updateFocus()
updated, _ := fm.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'/'}})
fm = updated.(*formModel)
view := fm.View()
authPos := strings.Index(view, "Auth Method")
listPos := strings.Index(view, "Select auth method")
if authPos < 0 || listPos < 0 {
t.Fatalf("expected auth field and auth list title in view\nview:\n%s", view)
}
if listPos < authPos {
t.Fatalf("expected auth method list after auth field\nview:\n%s", view)
}
if between := view[authPos:listPos]; strings.Contains(between, "Identity File") {
t.Fatalf("expected auth method list to render directly under auth field\nview:\n%s", view)
}
if strings.Contains(view, "│") {
t.Fatalf("expected compact auth method dropdown without default list border\nview:\n%s", view)
}
for _, method := range []model.AuthMethod{
model.AuthPassword,
model.AuthKey,
model.AuthKeyPassphrase,
model.AuthAgent,
} {
if !strings.Contains(view, string(method)) {
t.Fatalf("expected auth method list view to contain %q\nview:\n%s", method, view)
}
}
}
func TestGroupListViewRendersDirectlyUnderGroupField(t *testing.T) {
oldGetGroups := GetGroups
GetGroups = func() ([]string, error) {
return []string{"KP", "MY"}, nil
}
defer func() { GetGroups = oldGetGroups }()
fm := newFormModel(80, 24)
fm.focusIdx = 8
fm.updateFocus()
updated, _ := fm.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'/'}})
fm = updated.(*formModel)
view := fm.View()
groupPos := strings.Index(view, "Group")
listPos := strings.Index(view, "Select group")
if groupPos < 0 || listPos < 0 {
t.Fatalf("expected group field and group list title in view\nview:\n%s", view)
}
if listPos < groupPos {
t.Fatalf("expected group list after group field\nview:\n%s", view)
}
if between := view[groupPos:listPos]; strings.Contains(between, "Password") {
t.Fatalf("expected group dropdown to render before password field\nview:\n%s", view)
}
if strings.Contains(view, "│") {
t.Fatalf("expected compact group dropdown without default list border\nview:\n%s", view)
}
}
func TestSelectableFieldHintsAreVisible(t *testing.T) {
oldGetGroups := GetGroups
GetGroups = func() ([]string, error) {
return []string{"KP"}, nil
}
defer func() { GetGroups = oldGetGroups }()
fm := newFormModel(80, 24)
view := fm.View()
for _, want := range []string{"Auth Method (/ pick)", "Group (/ pick)", "/ pick list"} {
if !strings.Contains(view, want) {
t.Fatalf("expected form view to contain selectable-field hint %q\nview:\n%s", want, view)
}
}
}
func TestEditFormShowsSavedSecretMarker(t *testing.T) {
oldHasSecret := HasSecret
HasSecret = func(alias string, secretType string) bool {
return alias == "prod" && secretType == "ssh_password"
}
defer func() { HasSecret = oldHasSecret }()
fm := newEditFormModel(&model.Server{
Alias: "prod",
Host: "example.org",
Port: 22,
User: "root",
AuthMethod: model.AuthPassword,
}, 80, 24)
view := fm.View()
if !strings.Contains(view, "secret saved") {
t.Fatalf("expected edit form to show saved secret marker\nview:\n%s", view)
}
if !strings.Contains(view, "leave blank to keep") {
t.Fatalf("expected edit form to explain blank password keeps saved secret\nview:\n%s", view)
}
if strings.Count(view, "secret saved") != 1 {
t.Fatalf("expected saved secret marker to appear once\nview:\n%s", view)
}
}
func TestFormViewUsesSectionsAndStableLabels(t *testing.T) {
fm := newFormModel(100, 30)
view := fm.View()
for _, want := range []string{
"Identity",
"Connection",
"Authentication",
"Metadata",
"Actions",
"Alias",
"Display Name",
"Auth Method",
"Password / Passphrase",
} {
if !strings.Contains(view, want) {
t.Fatalf("expected form view to contain %q\nview:\n%s", want, view)
}
}
}
func TestFormTestResultDoesNotUpdateSelectedListServer(t *testing.T) {
oldUpdateTestResult := UpdateTestResult
oldListServers := ListServers
defer func() {
UpdateTestResult = oldUpdateTestResult
ListServers = oldListServers
}()
updateCalled := false
UpdateTestResult = func(alias string, status model.TestStatus, testErr string) error {
updateCalled = true
return nil
}
ListServers = func() ([]*model.Server, error) {
t.Fatal("form test result should not reload server list")
return nil, nil
}
selected := &model.Server{Alias: "selected", Host: "example.org", Port: 22, User: "root"}
m := New([]*model.Server{selected})
m.screen = screenForm
m.form = newFormModel(80, 24)
m.list = list.New([]list.Item{serverItem{server: selected}}, list.NewDefaultDelegate(), 80, 20)
updated, cmd := m.Update(testDoneMsg{ok: true})
m = updated.(*tuiModel)
if cmd != nil {
t.Fatal("form test result should not return a reload command")
}
if updateCalled {
t.Fatal("form test result should not update the selected list server")
}
if m.form == nil || m.form.testResult != "Connection OK." {
t.Fatal("expected form to keep its test result")
}
}

View File

@ -8,6 +8,8 @@ import (
"fmt" "fmt"
"io" "io"
"os" "os"
"sort"
"strings"
"sync" "sync"
"time" "time"
@ -20,6 +22,8 @@ const (
saltLen = 32 saltLen = 32
nonceLen = 24 nonceLen = 24
keyLen = 32 keyLen = 32
verifierID = "__sshkeeper_vault_verifier__"
verifierPlaintext = "sshkeeper-vault-verifier-v1"
) )
type KDFMeta struct { type KDFMeta struct {
@ -40,6 +44,7 @@ type Record struct {
type VaultFile struct { type VaultFile struct {
Version int `json:"version"` Version int `json:"version"`
KDF KDFMeta `json:"kdf"` KDF KDFMeta `json:"kdf"`
Verifier *Record `json:"verifier,omitempty"`
Records []Record `json:"records"` Records []Record `json:"records"`
} }
@ -47,14 +52,25 @@ type Vault struct {
mu sync.Mutex mu sync.Mutex
path string path string
masterKey []byte masterKey []byte
records map[string][]byte // id -> plaintext records map[string]secretRecord
modified bool modified bool
} }
type secretRecord struct {
secretType string
plaintext []byte
}
type SecretMeta struct {
ID string
Alias string
Type string
}
func New(path string) *Vault { func New(path string) *Vault {
return &Vault{ return &Vault{
path: path, path: path,
records: make(map[string][]byte), records: make(map[string]secretRecord),
} }
} }
@ -76,20 +92,25 @@ func Create(path string, masterPassword string) error {
kdf := KDFMeta{ kdf := KDFMeta{
Name: "argon2id", Name: "argon2id",
MemoryKiB: 4096, MemoryKiB: 65536,
Iterations: 2, Iterations: 3,
Parallelism: 1, Parallelism: 1,
Salt: base64.StdEncoding.EncodeToString(salt), Salt: base64.StdEncoding.EncodeToString(salt),
} }
fmt.Print("Deriving key...") fmt.Print("Deriving key...")
key := argon2.IDKey([]byte(masterPassword), salt, uint32(kdf.Iterations), uint32(kdf.MemoryKiB)*1024, uint8(kdf.Parallelism), keyLen) key := argon2.IDKey([]byte(masterPassword), salt, uint32(kdf.Iterations), uint32(kdf.MemoryKiB), uint8(kdf.Parallelism), keyLen)
verifier, err := newVerifierRecord(key)
if err != nil {
return fmt.Errorf("create verifier: %w", err)
}
// Verify key is valid by doing a test encrypt/decrypt
vf := VaultFile{ vf := VaultFile{
Version: currentVersion, Version: currentVersion,
KDF: kdf, KDF: kdf,
Verifier: &verifier,
Records: []Record{}, Records: []Record{},
} }
@ -143,24 +164,32 @@ func (v *Vault) Unlock(masterPassword string) error {
return fmt.Errorf("decode salt: %w", err) return fmt.Errorf("decode salt: %w", err)
} }
key := argon2.IDKey([]byte(masterPassword), salt, uint32(vf.KDF.Iterations), uint32(vf.KDF.MemoryKiB)*1024, uint8(vf.KDF.Parallelism), keyLen) key := argon2.IDKey([]byte(masterPassword), salt, uint32(vf.KDF.Iterations), uint32(vf.KDF.MemoryKiB), uint8(vf.KDF.Parallelism), keyLen)
// Try to decrypt first record to verify password if vf.Verifier != nil {
if len(vf.Records) > 0 { if err := verifyRecord(key, *vf.Verifier); err != nil {
return fmt.Errorf("invalid master password")
}
} else if len(vf.Records) > 0 {
if _, err := decryptRecord(key, vf.Records[0]); err != nil { if _, err := decryptRecord(key, vf.Records[0]); err != nil {
return fmt.Errorf("invalid master password") return fmt.Errorf("invalid master password")
} }
} else {
return fmt.Errorf("vault cannot verify master password; recreate empty vault")
} }
v.masterKey = key v.masterKey = key
v.records = make(map[string][]byte) v.records = make(map[string]secretRecord)
for _, rec := range vf.Records { for _, rec := range vf.Records {
plaintext, err := decryptRecord(key, rec) plaintext, err := decryptRecord(key, rec)
if err != nil { if err != nil {
return fmt.Errorf("decrypt record %s: %w", rec.ID, err) return fmt.Errorf("decrypt record %s: %w", rec.ID, err)
} }
v.records[rec.ID] = plaintext v.records[rec.ID] = secretRecord{
secretType: inferSecretType(rec.ID, rec.Type),
plaintext: plaintext,
}
} }
fmt.Println(" done.") fmt.Println(" done.")
@ -178,7 +207,7 @@ func (v *Vault) Lock() {
} }
} }
v.masterKey = nil v.masterKey = nil
v.records = make(map[string][]byte) v.records = make(map[string]secretRecord)
} }
// IsUnlocked returns whether the vault is currently unlocked // IsUnlocked returns whether the vault is currently unlocked
@ -197,7 +226,9 @@ func (v *Vault) Put(id string, secretType string, plaintext []byte) error {
return fmt.Errorf("vault is locked") return fmt.Errorf("vault is locked")
} }
v.records[id] = plaintext data := make([]byte, len(plaintext))
copy(data, plaintext)
v.records[id] = secretRecord{secretType: secretType, plaintext: data}
v.modified = true v.modified = true
return nil return nil
} }
@ -211,16 +242,55 @@ func (v *Vault) Get(id string) ([]byte, error) {
return nil, fmt.Errorf("vault is locked") return nil, fmt.Errorf("vault is locked")
} }
data, ok := v.records[id] record, ok := v.records[id]
if !ok { if !ok {
return nil, fmt.Errorf("secret not found: %s", id) return nil, fmt.Errorf("secret not found: %s", id)
} }
result := make([]byte, len(data)) result := make([]byte, len(record.plaintext))
copy(result, data) copy(result, record.plaintext)
return result, nil return result, nil
} }
func (v *Vault) HasSecret(id string) bool {
v.mu.Lock()
defer v.mu.Unlock()
_, ok := v.records[id]
return ok
}
func (v *Vault) ListSecrets() ([]SecretMeta, error) {
v.mu.Lock()
defer v.mu.Unlock()
if v.masterKey == nil {
return nil, fmt.Errorf("vault is locked")
}
metas := make([]SecretMeta, 0, len(v.records))
for id, record := range v.records {
alias, secretType, ok := parseServerSecretID(id)
if !ok {
continue
}
if record.secretType != "" {
secretType = record.secretType
}
metas = append(metas, SecretMeta{
ID: id,
Alias: alias,
Type: secretType,
})
}
sort.Slice(metas, func(i, j int) bool {
if metas[i].Alias == metas[j].Alias {
return metas[i].Type < metas[j].Type
}
return metas[i].Alias < metas[j].Alias
})
return metas, nil
}
// Delete removes a secret // Delete removes a secret
func (v *Vault) Delete(id string) { func (v *Vault) Delete(id string) {
v.mu.Lock() v.mu.Lock()
@ -245,8 +315,8 @@ func (v *Vault) Save() error {
kdf := KDFMeta{ kdf := KDFMeta{
Name: "argon2id", Name: "argon2id",
MemoryKiB: 4096, MemoryKiB: 65536,
Iterations: 2, Iterations: 3,
Parallelism: 1, Parallelism: 1,
Salt: base64.StdEncoding.EncodeToString(salt), Salt: base64.StdEncoding.EncodeToString(salt),
} }
@ -254,17 +324,23 @@ func (v *Vault) Save() error {
fmt.Print("Deriving key...") fmt.Print("Deriving key...")
var records []Record var records []Record
for id, plaintext := range v.records { for id, record := range v.records {
rec, err := encryptRecord(v.masterKey, id, plaintext) rec, err := encryptRecordWithType(v.masterKey, id, record.secretType, record.plaintext)
if err != nil { if err != nil {
return fmt.Errorf("encrypt record %s: %w", id, err) return fmt.Errorf("encrypt record %s: %w", id, err)
} }
records = append(records, rec) records = append(records, rec)
} }
verifier, err := newVerifierRecord(v.masterKey)
if err != nil {
return fmt.Errorf("create verifier: %w", err)
}
vf := VaultFile{ vf := VaultFile{
Version: currentVersion, Version: currentVersion,
KDF: kdf, KDF: kdf,
Verifier: &verifier,
Records: records, Records: records,
} }
@ -309,12 +385,12 @@ func (v *Vault) ChangePassword(newPassword string) error {
return fmt.Errorf("generate salt: %w", err) return fmt.Errorf("generate salt: %w", err)
} }
newKey := argon2.IDKey([]byte(newPassword), salt, 3, 8192*1024, 1, keyLen) newKey := argon2.IDKey([]byte(newPassword), salt, 3, 65536, 1, keyLen)
kdf := KDFMeta{ kdf := KDFMeta{
Name: "argon2id", Name: "argon2id",
MemoryKiB: 4096, MemoryKiB: 65536,
Iterations: 2, Iterations: 3,
Parallelism: 1, Parallelism: 1,
Salt: base64.StdEncoding.EncodeToString(salt), Salt: base64.StdEncoding.EncodeToString(salt),
} }
@ -322,17 +398,23 @@ func (v *Vault) ChangePassword(newPassword string) error {
fmt.Print("Deriving key...") fmt.Print("Deriving key...")
var records []Record var records []Record
for id, plaintext := range v.records { for id, record := range v.records {
rec, err := encryptRecord(newKey, id, plaintext) rec, err := encryptRecordWithType(newKey, id, record.secretType, record.plaintext)
if err != nil { if err != nil {
return fmt.Errorf("encrypt record: %w", err) return fmt.Errorf("encrypt record: %w", err)
} }
records = append(records, rec) records = append(records, rec)
} }
verifier, err := newVerifierRecord(newKey)
if err != nil {
return fmt.Errorf("create verifier: %w", err)
}
vf := VaultFile{ vf := VaultFile{
Version: currentVersion, Version: currentVersion,
KDF: kdf, KDF: kdf,
Verifier: &verifier,
Records: records, Records: records,
} }
@ -382,6 +464,10 @@ func (v *Vault) getSalt() string {
} }
func encryptRecord(key []byte, id string, plaintext []byte) (Record, error) { func encryptRecord(key []byte, id string, plaintext []byte) (Record, error) {
return encryptRecordWithType(key, id, "", plaintext)
}
func encryptRecordWithType(key []byte, id string, secretType string, plaintext []byte) (Record, error) {
aead, err := chacha20poly1305.NewX(key) aead, err := chacha20poly1305.NewX(key)
if err != nil { if err != nil {
return Record{}, err return Record{}, err
@ -396,11 +482,31 @@ func encryptRecord(key []byte, id string, plaintext []byte) (Record, error) {
return Record{ return Record{
ID: id, ID: id,
Type: secretType,
Nonce: base64.StdEncoding.EncodeToString(nonce), Nonce: base64.StdEncoding.EncodeToString(nonce),
Ciphertext: base64.StdEncoding.EncodeToString(ciphertext), Ciphertext: base64.StdEncoding.EncodeToString(ciphertext),
}, nil }, nil
} }
func inferSecretType(id string, recordType string) string {
if recordType != "" {
return recordType
}
_, secretType, ok := parseServerSecretID(id)
if !ok {
return ""
}
return secretType
}
func parseServerSecretID(id string) (string, string, bool) {
parts := strings.Split(id, ":")
if len(parts) != 3 || parts[0] != "server" || parts[1] == "" || parts[2] == "" {
return "", "", false
}
return parts[1], parts[2], true
}
func decryptRecord(key []byte, rec Record) ([]byte, error) { func decryptRecord(key []byte, rec Record) ([]byte, error) {
aead, err := chacha20poly1305.NewX(key) aead, err := chacha20poly1305.NewX(key)
if err != nil { if err != nil {
@ -425,6 +531,26 @@ func decryptRecord(key []byte, rec Record) ([]byte, error) {
return plaintext, nil return plaintext, nil
} }
func newVerifierRecord(key []byte) (Record, error) {
rec, err := encryptRecord(key, verifierID, []byte(verifierPlaintext))
if err != nil {
return Record{}, err
}
rec.Type = "verifier"
return rec, nil
}
func verifyRecord(key []byte, rec Record) error {
plaintext, err := decryptRecord(key, rec)
if err != nil {
return err
}
if !SecureCompare(string(plaintext), verifierPlaintext) {
return fmt.Errorf("invalid verifier")
}
return nil
}
// VerifyPassword checks if a master password is correct without unlocking // VerifyPassword checks if a master password is correct without unlocking
func VerifyPassword(path string, masterPassword string) (bool, error) { func VerifyPassword(path string, masterPassword string) (bool, error) {
data, err := os.ReadFile(path) data, err := os.ReadFile(path)
@ -442,18 +568,20 @@ func VerifyPassword(path string, masterPassword string) (bool, error) {
return false, err return false, err
} }
key := argon2.IDKey([]byte(masterPassword), salt, uint32(vf.KDF.Iterations), uint32(vf.KDF.MemoryKiB)*1024, uint8(vf.KDF.Parallelism), keyLen) key := argon2.IDKey([]byte(masterPassword), salt, uint32(vf.KDF.Iterations), uint32(vf.KDF.MemoryKiB), uint8(vf.KDF.Parallelism), keyLen)
defer func() { defer func() {
for i := range key { for i := range key {
key[i] = 0 key[i] = 0
} }
}() }()
if len(vf.Records) == 0 { if vf.Verifier != nil {
// Empty vault, try a test encryption return verifyRecord(key, *vf.Verifier) == nil, nil
return true, nil
} }
if len(vf.Records) == 0 {
return false, nil
}
_, err = decryptRecord(key, vf.Records[0]) _, err = decryptRecord(key, vf.Records[0])
return err == nil, nil return err == nil, nil
} }

View File

@ -0,0 +1,218 @@
package vault
import (
"encoding/base64"
"encoding/json"
"os"
"path/filepath"
"testing"
"golang.org/x/crypto/argon2"
)
func TestNewEmptyVaultRejectsWrongPassword(t *testing.T) {
path := filepath.Join(t.TempDir(), "vault.bin")
if err := Create(path, "correct horse"); err != nil {
t.Fatalf("create vault: %v", err)
}
ok, err := VerifyPassword(path, "wrong horse")
if err != nil {
t.Fatalf("verify wrong password: %v", err)
}
if ok {
t.Fatal("expected wrong password to be rejected for a new empty vault")
}
v := New(path)
if err := v.Unlock("wrong horse"); err == nil {
t.Fatal("expected unlock with wrong password to fail for a new empty vault")
}
}
func TestNewEmptyVaultAcceptsCorrectPassword(t *testing.T) {
path := filepath.Join(t.TempDir(), "vault.bin")
if err := Create(path, "correct horse"); err != nil {
t.Fatalf("create vault: %v", err)
}
ok, err := VerifyPassword(path, "correct horse")
if err != nil {
t.Fatalf("verify correct password: %v", err)
}
if !ok {
t.Fatal("expected correct password to be accepted for a new empty vault")
}
v := New(path)
if err := v.Unlock("correct horse"); err != nil {
t.Fatalf("unlock with correct password: %v", err)
}
}
func TestLegacyVaultWithRecordsStillVerifiesByFirstRecord(t *testing.T) {
path := filepath.Join(t.TempDir(), "vault.bin")
salt := []byte("12345678901234567890123456789012")
key := argon2.IDKey([]byte("correct horse"), salt, 3, 65536, 1, keyLen)
rec, err := encryptRecord(key, "server:test:ssh_password", []byte("secret"))
if err != nil {
t.Fatalf("encrypt legacy record: %v", err)
}
data, err := json.Marshal(VaultFile{
Version: currentVersion,
KDF: KDFMeta{
Name: "argon2id",
MemoryKiB: 65536,
Iterations: 3,
Parallelism: 1,
Salt: base64.StdEncoding.EncodeToString(salt),
},
Records: []Record{rec},
})
if err != nil {
t.Fatalf("marshal legacy vault: %v", err)
}
if err := os.WriteFile(path, data, 0o600); err != nil {
t.Fatalf("write legacy vault: %v", err)
}
ok, err := VerifyPassword(path, "correct horse")
if err != nil {
t.Fatalf("verify legacy vault: %v", err)
}
if !ok {
t.Fatal("expected legacy vault with records to accept correct password")
}
ok, err = VerifyPassword(path, "wrong horse")
if err != nil {
t.Fatalf("verify legacy vault with wrong password: %v", err)
}
if ok {
t.Fatal("expected legacy vault with records to reject wrong password")
}
}
func TestLegacyEmptyVaultWithoutVerifierCannotUnlock(t *testing.T) {
path := filepath.Join(t.TempDir(), "vault.bin")
salt := []byte("12345678901234567890123456789012")
data, err := json.Marshal(VaultFile{
Version: currentVersion,
KDF: KDFMeta{
Name: "argon2id",
MemoryKiB: 65536,
Iterations: 3,
Parallelism: 1,
Salt: base64.StdEncoding.EncodeToString(salt),
},
Records: []Record{},
})
if err != nil {
t.Fatalf("marshal legacy empty vault: %v", err)
}
if err := os.WriteFile(path, data, 0o600); err != nil {
t.Fatalf("write legacy empty vault: %v", err)
}
ok, err := VerifyPassword(path, "any password")
if err != nil {
t.Fatalf("verify legacy empty vault: %v", err)
}
if ok {
t.Fatal("expected legacy empty vault without verifier to be unverifiable")
}
v := New(path)
if err := v.Unlock("any password"); err == nil {
t.Fatal("expected legacy empty vault without verifier to reject unlock")
}
}
func TestListSecretsReturnsMetadataWithoutPlaintext(t *testing.T) {
path := filepath.Join(t.TempDir(), "vault.bin")
if err := Create(path, "master"); err != nil {
t.Fatalf("create vault: %v", err)
}
v := New(path)
if err := v.Unlock("master"); err != nil {
t.Fatalf("unlock vault: %v", err)
}
if err := v.Put("server:prod:ssh_password", "ssh_password", []byte("secret-password")); err != nil {
t.Fatalf("put password: %v", err)
}
if err := v.Put("server:prod:key_passphrase", "key_passphrase", []byte("secret-passphrase")); err != nil {
t.Fatalf("put passphrase: %v", err)
}
metas, err := v.ListSecrets()
if err != nil {
t.Fatalf("list secrets: %v", err)
}
if len(metas) != 2 {
t.Fatalf("expected 2 secret metadata records, got %d: %#v", len(metas), metas)
}
if metas[0].Alias != "prod" || metas[0].Type != "key_passphrase" {
t.Fatalf("unexpected first metadata record: %#v", metas[0])
}
if metas[1].Alias != "prod" || metas[1].Type != "ssh_password" {
t.Fatalf("unexpected second metadata record: %#v", metas[1])
}
}
func TestListSecretsPreservesTypesAfterSaveAndUnlock(t *testing.T) {
path := filepath.Join(t.TempDir(), "vault.bin")
if err := Create(path, "master"); err != nil {
t.Fatalf("create vault: %v", err)
}
v := New(path)
if err := v.Unlock("master"); err != nil {
t.Fatalf("unlock vault: %v", err)
}
if err := v.Put("server:prod:ssh_password", "ssh_password", []byte("secret-password")); err != nil {
t.Fatalf("put password: %v", err)
}
if err := v.Save(); err != nil {
t.Fatalf("save vault: %v", err)
}
reopened := New(path)
if err := reopened.Unlock("master"); err != nil {
t.Fatalf("unlock reopened vault: %v", err)
}
metas, err := reopened.ListSecrets()
if err != nil {
t.Fatalf("list reopened secrets: %v", err)
}
if len(metas) != 1 {
t.Fatalf("expected 1 secret metadata record, got %d: %#v", len(metas), metas)
}
if metas[0].ID != "server:prod:ssh_password" || metas[0].Alias != "prod" || metas[0].Type != "ssh_password" {
t.Fatalf("unexpected metadata after reopen: %#v", metas[0])
}
}
func TestHasSecretReportsPresenceWithoutReturningValue(t *testing.T) {
path := filepath.Join(t.TempDir(), "vault.bin")
if err := Create(path, "master"); err != nil {
t.Fatalf("create vault: %v", err)
}
v := New(path)
if err := v.Unlock("master"); err != nil {
t.Fatalf("unlock vault: %v", err)
}
if err := v.Put("server:prod:ssh_password", "ssh_password", []byte("secret-password")); err != nil {
t.Fatalf("put password: %v", err)
}
if !v.HasSecret("server:prod:ssh_password") {
t.Fatal("expected saved password to be reported present")
}
if v.HasSecret("server:prod:key_passphrase") {
t.Fatal("expected missing passphrase to be reported absent")
}
}