feat: improve tui and vault handling

This commit is contained in:
mirivlad 2026-05-28 02:25:18 +08:00
parent 73c60b9f93
commit e1d709396b
21 changed files with 2292 additions and 224 deletions

View File

@ -3,9 +3,11 @@ package cmd
import ( import (
"fmt" "fmt"
"strings" "strings"
"syscall"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/mirivlad/sshkeeper/internal/model" "github.com/mirivlad/sshkeeper/internal/model"
"golang.org/x/term"
) )
var addFlags struct { var addFlags struct {
@ -19,7 +21,6 @@ var addFlags struct {
displayName string displayName string
notes string notes string
tags string tags string
password string
} }
var addCmd = &cobra.Command{ var addCmd = &cobra.Command{
@ -59,11 +60,21 @@ func addNonInteractive(alias string) error {
server.DisplayName = alias server.DisplayName = alias
} }
// Handle password auth - store in vault // Handle password/passphrase auth — request interactively, never via argv
if server.AuthMethod == model.AuthPassword { if server.AuthMethod == model.AuthPassword || server.AuthMethod == model.AuthKeyPassphrase {
password := addFlags.password secretType := "password"
if password == "" { if server.AuthMethod == model.AuthKeyPassphrase {
return fmt.Errorf("password auth requires --password flag or interactive mode") secretType = "passphrase"
}
fmt.Printf("Enter %s (will be stored in vault, input hidden): ", secretType)
password, err := term.ReadPassword(int(syscall.Stdin))
fmt.Println()
if err != nil {
return fmt.Errorf("read %s: %w", secretType, err)
}
if len(password) == 0 {
return fmt.Errorf("%s cannot be empty", secretType)
} }
v := getOrCreateVault() v := getOrCreateVault()
@ -72,29 +83,14 @@ func addNonInteractive(alias string) error {
} }
vaultKey := fmt.Sprintf("server:%s:ssh_password", alias) vaultKey := fmt.Sprintf("server:%s:ssh_password", alias)
if err := v.Put(vaultKey, "ssh_password", []byte(password)); err != nil { vaultType := "ssh_password"
return fmt.Errorf("store password in vault: %w", err) if server.AuthMethod == model.AuthKeyPassphrase {
} vaultKey = fmt.Sprintf("server:%s:key_passphrase", alias)
if err := v.Save(); err != nil { vaultType = "key_passphrase"
return fmt.Errorf("save vault: %w", err)
}
}
// Handle key+passphrase - store passphrase in vault
if server.AuthMethod == model.AuthKeyPassphrase {
passphrase := addFlags.password
if passphrase == "" {
return fmt.Errorf("key+passphrase auth requires --password flag for the passphrase")
} }
v := getOrCreateVault() if err := v.Put(vaultKey, vaultType, password); err != nil {
if !v.IsUnlocked() { return fmt.Errorf("store %s in vault: %w", secretType, err)
return fmt.Errorf("vault is locked. Run 'sshkeeper vault unlock' first")
}
vaultKey := fmt.Sprintf("server:%s:key_passphrase", alias)
if err := v.Put(vaultKey, "key_passphrase", []byte(passphrase)); err != nil {
return fmt.Errorf("store passphrase in vault: %w", err)
} }
if err := v.Save(); err != nil { if err := v.Save(); err != nil {
return fmt.Errorf("save vault: %w", err) return fmt.Errorf("save vault: %w", err)
@ -132,5 +128,4 @@ func init() {
addCmd.Flags().StringVar(&addFlags.displayName, "display-name", "", "Display name") addCmd.Flags().StringVar(&addFlags.displayName, "display-name", "", "Display name")
addCmd.Flags().StringVar(&addFlags.notes, "notes", "", "Notes") addCmd.Flags().StringVar(&addFlags.notes, "notes", "", "Notes")
addCmd.Flags().StringVar(&addFlags.tags, "tags", "", "Comma-separated tags") addCmd.Flags().StringVar(&addFlags.tags, "tags", "", "Comma-separated tags")
addCmd.Flags().StringVar(&addFlags.password, "password", "", "SSH password or key passphrase (stored in vault)")
} }

View File

@ -27,6 +27,15 @@ var deleteCmd = &cobra.Command{
return fmt.Errorf("delete server: %w", err) return fmt.Errorf("delete server: %w", err)
} }
// Clean up vault secrets for this server
v := getOrCreateVault()
if v.IsUnlocked() {
cleanupServerSecrets(v, alias)
if err := v.Save(); err != nil {
return fmt.Errorf("save vault after cleanup: %w", err)
}
}
fmt.Println("Deleted.") fmt.Println("Deleted.")
return nil return nil
}, },

View File

@ -2,9 +2,11 @@ package cmd
import ( import (
"fmt" "fmt"
"syscall"
"github.com/spf13/cobra"
"github.com/mirivlad/sshkeeper/internal/model" "github.com/mirivlad/sshkeeper/internal/model"
"github.com/spf13/cobra"
"golang.org/x/term"
) )
var editCmd = &cobra.Command{ var editCmd = &cobra.Command{
@ -18,6 +20,8 @@ var editCmd = &cobra.Command{
return fmt.Errorf("server not found: %s", alias) return fmt.Errorf("server not found: %s", alias)
} }
oldAuthMethod := server.AuthMethod
if parsedHost != "" { if parsedHost != "" {
server.Host = parsedHost server.Host = parsedHost
} }
@ -46,6 +50,41 @@ var editCmd = &cobra.Command{
server.Notes = parsedNotes server.Notes = parsedNotes
} }
if parsedAuth != "" && oldAuthMethod != server.AuthMethod {
v := getOrCreateVault()
if v.IsUnlocked() {
var secret string
if server.AuthMethod == model.AuthPassword {
fmt.Print("Enter new password (stored in vault, input hidden): ")
pw, err := term.ReadPassword(int(syscall.Stdin))
fmt.Println()
if err != nil {
return fmt.Errorf("read password: %w", err)
}
if len(pw) > 0 {
secret = string(pw)
}
} else if server.AuthMethod == model.AuthKeyPassphrase {
fmt.Print("Enter key passphrase (stored in vault, input hidden): ")
pw, err := term.ReadPassword(int(syscall.Stdin))
fmt.Println()
if err != nil {
return fmt.Errorf("read passphrase: %w", err)
}
if len(pw) > 0 {
secret = string(pw)
}
}
if err := syncServerSecrets(v, alias, server, secret); err != nil {
return fmt.Errorf("sync vault secrets: %w", err)
}
if err := v.Save(); err != nil {
return fmt.Errorf("save vault: %w", err)
}
}
}
if err := appDB.UpdateServer(server); err != nil { if err := appDB.UpdateServer(server); err != nil {
return fmt.Errorf("update server: %w", err) return fmt.Errorf("update server: %w", err)
} }
@ -56,15 +95,15 @@ var editCmd = &cobra.Command{
} }
var ( var (
parsedHost string parsedHost string
parsedPort int parsedPort int
parsedUser string parsedUser string
parsedAuth string parsedAuth string
parsedIdentity string parsedIdentity string
parsedProxyJump string parsedProxyJump string
parsedGroup string parsedGroup string
parsedDisplayName string parsedDisplayName string
parsedNotes string parsedNotes string
) )
func init() { func init() {

View File

@ -74,8 +74,13 @@ var runCmd = &cobra.Command{
return fmt.Errorf("server not found: %s", alias) return fmt.Errorf("server not found: %s", alias)
} }
// Build ssh args with the command // For password auth, use PTY-wrapper with command
sshArgs := buildSSHArgs(server) if server.AuthMethod == model.AuthPassword {
return runWithPassword(server, command)
}
// For key/agent auth — direct execution
sshArgs := ssh.BuildSSHArgs(server)
sshArgs = append(sshArgs, command) sshArgs = append(sshArgs, command)
sshCmd := exec.Command(cfg.SSH.Binary, sshArgs...) sshCmd := exec.Command(cfg.SSH.Binary, sshArgs...)
@ -91,17 +96,21 @@ var runCmd = &cobra.Command{
}, },
} }
func buildSSHArgs(server *model.Server) []string { // runWithPassword runs a command on a server with password auth via PTY-wrapper.
var args []string func runWithPassword(server *model.Server, command string) error {
args = append(args, "-p", fmt.Sprintf("%d", server.Port)) v := getOrCreateVault()
if server.IdentityFile != "" { if !v.IsUnlocked() {
args = append(args, "-i", server.IdentityFile) return fmt.Errorf("vault is locked. Run 'sshkeeper vault unlock' first")
} }
if server.ProxyJump != "" {
args = append(args, "-J", server.ProxyJump) vaultKey := fmt.Sprintf("server:%s:ssh_password", server.Alias)
password, err := v.Get(vaultKey)
if err != nil {
return fmt.Errorf("get password from vault: %w", err)
} }
args = append(args, "-o", "StrictHostKeyChecking=accept-new")
target := fmt.Sprintf("%s@%s", server.User, server.Host) sshArgs := ssh.BuildSSHArgs(server)
args = append(args, target) sshArgs = append(sshArgs, command)
return args
return ssh.ConnectWithPassword(cfg.SSH.Binary, sshArgs, string(password))
} }

85
cmd/secrets.go Normal file
View File

@ -0,0 +1,85 @@
package cmd
import (
"fmt"
"github.com/mirivlad/sshkeeper/internal/model"
"github.com/mirivlad/sshkeeper/internal/ssh"
"github.com/mirivlad/sshkeeper/internal/vault"
)
const (
secretSSHPassword = "ssh_password"
secretKeyPassphrase = "key_passphrase"
secretSudoPassword = "sudo_password"
)
var serverSecretTypes = []string{
secretSSHPassword,
secretKeyPassphrase,
secretSudoPassword,
}
func serverSecretID(alias, secretType string) string {
return fmt.Sprintf("server:%s:%s", alias, secretType)
}
func cleanupServerSecrets(v *vault.Vault, alias string) {
for _, secretType := range serverSecretTypes {
v.Delete(serverSecretID(alias, secretType))
}
}
func syncServerSecrets(v *vault.Vault, oldAlias string, server *model.Server, secret string) error {
if oldAlias == "" {
oldAlias = server.Alias
}
if oldAlias != server.Alias {
for _, secretType := range serverSecretTypes {
oldID := serverSecretID(oldAlias, secretType)
data, err := v.Get(oldID)
if err == nil {
if err := v.Put(serverSecretID(server.Alias, secretType), secretType, data); err != nil {
return err
}
}
v.Delete(oldID)
}
}
switch server.AuthMethod {
case model.AuthPassword:
v.Delete(serverSecretID(server.Alias, secretKeyPassphrase))
if secret != "" {
return v.Put(serverSecretID(server.Alias, secretSSHPassword), secretSSHPassword, []byte(secret))
}
case model.AuthKeyPassphrase:
v.Delete(serverSecretID(server.Alias, secretSSHPassword))
if secret != "" {
return v.Put(serverSecretID(server.Alias, secretKeyPassphrase), secretKeyPassphrase, []byte(secret))
}
default:
v.Delete(serverSecretID(server.Alias, secretSSHPassword))
v.Delete(serverSecretID(server.Alias, secretKeyPassphrase))
}
return nil
}
func deleteVaultSecrets(v *vault.Vault, alias string, secretType string) error {
if secretType != "" {
v.Delete(serverSecretID(alias, secretType))
return nil
}
cleanupServerSecrets(v, alias)
return nil
}
func formTestVaultFunc(getVault ssh.VaultFunc, server *model.Server, formSecret string) ssh.VaultFunc {
return func(serverAlias string, secretType string) (string, error) {
if (secretType == secretSSHPassword || secretType == secretKeyPassphrase) && formSecret != "" {
return formSecret, nil
}
return getVault(serverAlias, secretType)
}
}

186
cmd/secrets_test.go Normal file
View File

@ -0,0 +1,186 @@
package cmd
import (
"fmt"
"path/filepath"
"testing"
"github.com/mirivlad/sshkeeper/internal/model"
"github.com/mirivlad/sshkeeper/internal/vault"
)
func newUnlockedTestVault(t *testing.T) *vault.Vault {
t.Helper()
path := filepath.Join(t.TempDir(), "vault.bin")
if err := vault.Create(path, "master"); err != nil {
t.Fatalf("create vault: %v", err)
}
v := vault.New(path)
if err := v.Unlock("master"); err != nil {
t.Fatalf("unlock vault: %v", err)
}
return v
}
func mustPutSecret(t *testing.T, v *vault.Vault, alias, secretType, value string) {
t.Helper()
if err := v.Put(serverSecretID(alias, secretType), secretType, []byte(value)); err != nil {
t.Fatalf("put %s: %v", secretType, err)
}
}
func mustGetSecret(t *testing.T, v *vault.Vault, alias, secretType string) string {
t.Helper()
data, err := v.Get(serverSecretID(alias, secretType))
if err != nil {
t.Fatalf("get %s: %v", secretType, err)
}
return string(data)
}
func assertSecretMissing(t *testing.T, v *vault.Vault, alias, secretType string) {
t.Helper()
if data, err := v.Get(serverSecretID(alias, secretType)); err == nil {
t.Fatalf("expected %s for %s to be missing, got %q", secretType, alias, string(data))
}
}
func TestCleanupServerSecretsRemovesAuthAndSudoSecrets(t *testing.T) {
v := newUnlockedTestVault(t)
mustPutSecret(t, v, "prod", "ssh_password", "password")
mustPutSecret(t, v, "prod", "key_passphrase", "passphrase")
mustPutSecret(t, v, "prod", "sudo_password", "sudo")
cleanupServerSecrets(v, "prod")
assertSecretMissing(t, v, "prod", "ssh_password")
assertSecretMissing(t, v, "prod", "key_passphrase")
assertSecretMissing(t, v, "prod", "sudo_password")
}
func TestSyncServerSecretsDeletesAuthSecretsForKeyAuth(t *testing.T) {
v := newUnlockedTestVault(t)
mustPutSecret(t, v, "prod", "ssh_password", "password")
mustPutSecret(t, v, "prod", "key_passphrase", "passphrase")
server := &model.Server{Alias: "prod", AuthMethod: model.AuthKey}
if err := syncServerSecrets(v, "prod", server, ""); err != nil {
t.Fatalf("sync secrets: %v", err)
}
assertSecretMissing(t, v, "prod", "ssh_password")
assertSecretMissing(t, v, "prod", "key_passphrase")
}
func TestSyncServerSecretsKeepsExistingPasswordWhenEditPasswordIsBlank(t *testing.T) {
v := newUnlockedTestVault(t)
mustPutSecret(t, v, "prod", "ssh_password", "old-password")
mustPutSecret(t, v, "prod", "key_passphrase", "old-passphrase")
server := &model.Server{Alias: "prod", AuthMethod: model.AuthPassword}
if err := syncServerSecrets(v, "prod", server, ""); err != nil {
t.Fatalf("sync secrets: %v", err)
}
if got := mustGetSecret(t, v, "prod", "ssh_password"); got != "old-password" {
t.Fatalf("expected password to remain, got %q", got)
}
assertSecretMissing(t, v, "prod", "key_passphrase")
}
func TestSyncServerSecretsRenamesSecretsBeforeApplyingAuthCleanup(t *testing.T) {
v := newUnlockedTestVault(t)
mustPutSecret(t, v, "old", "ssh_password", "password")
mustPutSecret(t, v, "old", "key_passphrase", "passphrase")
mustPutSecret(t, v, "old", "sudo_password", "sudo")
server := &model.Server{Alias: "new", AuthMethod: model.AuthKeyPassphrase}
if err := syncServerSecrets(v, "old", server, ""); err != nil {
t.Fatalf("sync secrets: %v", err)
}
assertSecretMissing(t, v, "old", "ssh_password")
assertSecretMissing(t, v, "old", "key_passphrase")
assertSecretMissing(t, v, "old", "sudo_password")
assertSecretMissing(t, v, "new", "ssh_password")
if got := mustGetSecret(t, v, "new", "key_passphrase"); got != "passphrase" {
t.Fatalf("expected key passphrase to move, got %q", got)
}
if got := mustGetSecret(t, v, "new", "sudo_password"); got != "sudo" {
t.Fatalf("expected sudo password to move, got %q", got)
}
}
func TestSyncServerSecretsStoresNewSecretForSelectedAuthMethod(t *testing.T) {
v := newUnlockedTestVault(t)
mustPutSecret(t, v, "prod", "key_passphrase", "old-passphrase")
server := &model.Server{Alias: "prod", AuthMethod: model.AuthPassword}
if err := syncServerSecrets(v, "prod", server, "new-password"); err != nil {
t.Fatalf("sync secrets: %v", err)
}
if got := mustGetSecret(t, v, "prod", "ssh_password"); got != "new-password" {
t.Fatalf("expected new password, got %q", got)
}
assertSecretMissing(t, v, "prod", "key_passphrase")
}
func TestDeleteVaultSecretsForAliasAndType(t *testing.T) {
v := newUnlockedTestVault(t)
mustPutSecret(t, v, "prod", "ssh_password", "password")
mustPutSecret(t, v, "prod", "key_passphrase", "passphrase")
if err := deleteVaultSecrets(v, "prod", "ssh_password"); err != nil {
t.Fatalf("delete vault secret: %v", err)
}
assertSecretMissing(t, v, "prod", "ssh_password")
if got := mustGetSecret(t, v, "prod", "key_passphrase"); got != "passphrase" {
t.Fatalf("expected passphrase to remain, got %q", got)
}
}
func TestDeleteVaultSecretsForAlias(t *testing.T) {
v := newUnlockedTestVault(t)
mustPutSecret(t, v, "prod", "ssh_password", "password")
mustPutSecret(t, v, "prod", "key_passphrase", "passphrase")
if err := deleteVaultSecrets(v, "prod", ""); err != nil {
t.Fatalf("delete vault secrets: %v", err)
}
assertSecretMissing(t, v, "prod", "ssh_password")
assertSecretMissing(t, v, "prod", "key_passphrase")
}
func TestFormTestVaultFuncUsesSavedSecretWhenFormSecretBlank(t *testing.T) {
vaultFunc := formTestVaultFunc(func(serverAlias string, secretType string) (string, error) {
if serverAlias != "prod" || secretType != "ssh_password" {
return "", fmt.Errorf("unexpected secret lookup %s %s", serverAlias, secretType)
}
return "saved-password", nil
}, &model.Server{Alias: "prod", AuthMethod: model.AuthPassword}, "")
got, err := vaultFunc("prod", "ssh_password")
if err != nil {
t.Fatalf("get password: %v", err)
}
if got != "saved-password" {
t.Fatalf("expected saved password, got %q", got)
}
}
func TestFormTestVaultFuncUsesFormSecretWhenProvided(t *testing.T) {
vaultFunc := formTestVaultFunc(func(serverAlias string, secretType string) (string, error) {
return "", fmt.Errorf("saved secret should not be used")
}, &model.Server{Alias: "prod", AuthMethod: model.AuthPassword}, "typed-password")
got, err := vaultFunc("prod", "ssh_password")
if err != nil {
t.Fatalf("get password: %v", err)
}
if got != "typed-password" {
t.Fatalf("expected typed password, got %q", got)
}
}

View File

@ -2,8 +2,11 @@ package cmd
import ( import (
"fmt" "fmt"
"os"
"os/exec"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/mirivlad/sshkeeper/internal/ssh"
) )
var templateCmd = &cobra.Command{ var templateCmd = &cobra.Command{
@ -70,13 +73,15 @@ var runTemplateCmd = &cobra.Command{
alias := args[0] alias := args[0]
templateName := args[1] templateName := args[1]
server, err := appDB.GetServer(alias)
if err != nil {
return fmt.Errorf("server not found: %s", alias)
}
templates, err := appDB.GetCommandTemplates(alias) templates, err := appDB.GetCommandTemplates(alias)
if err != nil { if err != nil {
return fmt.Errorf("list templates: %w", err) return fmt.Errorf("list templates: %w", err)
} }
if len(templates) == 0 {
return fmt.Errorf("server not found or no templates: %s", alias)
}
var command string var command string
for _, t := range templates { for _, t := range templates {
@ -91,7 +96,21 @@ var runTemplateCmd = &cobra.Command{
} }
fmt.Printf("Running '%s' on %s...\n", command, alias) fmt.Printf("Running '%s' on %s...\n", command, alias)
return nil
// Build ssh args with the command
sshArgs := ssh.BuildSSHArgs(server)
sshArgs = append(sshArgs, command)
sshCmd := exec.Command(cfg.SSH.Binary, sshArgs...)
sshCmd.Stdin = os.Stdin
sshCmd.Stdout = os.Stdout
sshCmd.Stderr = os.Stderr
if err := sshCmd.Start(); err != nil {
return fmt.Errorf("start ssh: %w", err)
}
return sshCmd.Wait()
}, },
} }

View File

@ -36,41 +36,43 @@ func runTUI() error {
return appDB.SearchServers(query) return appDB.SearchServers(query)
} }
tui.DeleteServer = func(alias string) error { tui.DeleteServer = func(alias string) error {
return appDB.DeleteServer(alias) if err := appDB.DeleteServer(alias); err != nil {
return err
}
v := getOrCreateVault()
if v.IsUnlocked() {
cleanupServerSecrets(v, alias)
if err := v.Save(); err != nil {
return fmt.Errorf("save vault after cleanup: %w", err)
}
}
return nil
} }
tui.TestConnection = func(server *model.Server) (bool, string) { tui.TestConnection = func(server *model.Server) (bool, string) {
return ssh.Test(cfg, server, vaultFunc) return ssh.Test(cfg, server, vaultFunc)
} }
tui.TestConnectionWithPassword = func(server *model.Server, password string) (bool, string) { tui.TestConnectionWithPassword = func(server *model.Server, password string) (bool, string) {
directVaultFunc := func(sa string, st string) (string, error) { return ssh.Test(cfg, server, formTestVaultFunc(vaultFunc, server, password))
if st == "ssh_password" || st == "key_passphrase" {
return password, nil
}
return vaultFunc(sa, st)
}
return ssh.Test(cfg, server, directVaultFunc)
} }
tui.SaveServer = func(server *model.Server, password string) error { tui.SaveServer = func(server *model.Server, password string, oldAlias string) error {
if password != "" { v := getOrCreateVault()
v := getOrCreateVault() if v.IsUnlocked() {
vaultKey := fmt.Sprintf("server:%s:ssh_password", server.Alias) if err := syncServerSecrets(v, oldAlias, server, password); err != nil {
secretType := "ssh_password" return fmt.Errorf("sync vault secrets: %w", err)
if server.AuthMethod == model.AuthKeyPassphrase {
vaultKey = fmt.Sprintf("server:%s:key_passphrase", server.Alias)
secretType = "key_passphrase"
}
if err := v.Put(vaultKey, secretType, []byte(password)); err != nil {
return fmt.Errorf("store secret: %w", err)
} }
if err := v.Save(); err != nil { if err := v.Save(); err != nil {
return fmt.Errorf("save vault: %w", err) return fmt.Errorf("save vault: %w", err)
} }
} }
existing, _ := appDB.GetServer(server.Alias) lookupAlias := server.Alias
if oldAlias != "" {
lookupAlias = oldAlias
}
existing, _ := appDB.GetServer(lookupAlias)
if existing != nil { if existing != nil {
server.ID = existing.ID server.ID = existing.ID
return appDB.UpdateServer(server) return appDB.UpdateServerByAlias(existing.Alias, server)
} }
return appDB.CreateServer(server) return appDB.CreateServer(server)
} }
@ -84,6 +86,16 @@ func runTUI() error {
tui.DeleteGroup = func(name string) error { tui.DeleteGroup = func(name string) error {
return appDB.DeleteGroup(name) return appDB.DeleteGroup(name)
} }
tui.UpdateTestResult = func(alias string, status model.TestStatus, testErr string) error {
return appDB.UpdateTestResult(alias, status, testErr)
}
tui.HasSecret = func(alias string, secretType string) bool {
v := getOrCreateVault()
if !v.IsUnlocked() {
return false
}
return v.HasSecret(serverSecretID(alias, secretType))
}
// Run TUI in a loop — if user requests connect, handle it and restart TUI // Run TUI in a loop — if user requests connect, handle it and restart TUI
for { for {
@ -132,5 +144,3 @@ func runTUI() error {
return nil return nil
} }
} }

View File

@ -3,11 +3,12 @@ package cmd
import ( import (
"fmt" "fmt"
"os" "os"
"strings"
"syscall" "syscall"
"github.com/spf13/cobra"
"github.com/mirivlad/sshkeeper/internal/config" "github.com/mirivlad/sshkeeper/internal/config"
"github.com/mirivlad/sshkeeper/internal/vault" "github.com/mirivlad/sshkeeper/internal/vault"
"github.com/spf13/cobra"
"golang.org/x/term" "golang.org/x/term"
) )
@ -160,9 +161,75 @@ var vaultChangePasswordCmd = &cobra.Command{
}, },
} }
var vaultListCmd = &cobra.Command{
Use: "list",
Short: "List stored secret metadata",
RunE: func(cmd *cobra.Command, args []string) error {
v := getOrCreateVault()
if !v.IsUnlocked() {
return fmt.Errorf("vault is locked. Unlock first with 'sshkeeper vault unlock'")
}
output, err := formatVaultSecretsList(v)
if err != nil {
return err
}
fmt.Print(output)
return nil
},
}
var vaultDeleteCmd = &cobra.Command{
Use: "delete <alias> [type]",
Short: "Delete stored secrets for a server",
Args: cobra.RangeArgs(1, 2),
RunE: func(cmd *cobra.Command, args []string) error {
alias := args[0]
secretType := ""
if len(args) == 2 {
secretType = args[1]
}
v := getOrCreateVault()
if !v.IsUnlocked() {
return fmt.Errorf("vault is locked. Unlock first with 'sshkeeper vault unlock'")
}
if err := deleteVaultSecrets(v, alias, secretType); err != nil {
return err
}
if err := v.Save(); err != nil {
return fmt.Errorf("save vault: %w", err)
}
if secretType == "" {
fmt.Printf("Deleted secrets for %s.\n", alias)
} else {
fmt.Printf("Deleted %s for %s.\n", secretType, alias)
}
return nil
},
}
func formatVaultSecretsList(v *vault.Vault) (string, error) {
metas, err := v.ListSecrets()
if err != nil {
return "", err
}
if len(metas) == 0 {
return "No secrets stored.\n", nil
}
var b strings.Builder
fmt.Fprintf(&b, "%-24s %-18s\n", "ALIAS", "TYPE")
for _, meta := range metas {
fmt.Fprintf(&b, "%-24s %-18s\n", meta.Alias, meta.Type)
}
return b.String(), nil
}
func init() { func init() {
vaultCmd.AddCommand(vaultUnlockCmd) vaultCmd.AddCommand(vaultUnlockCmd)
vaultCmd.AddCommand(vaultLockCmd) vaultCmd.AddCommand(vaultLockCmd)
vaultCmd.AddCommand(vaultStatusCmd) vaultCmd.AddCommand(vaultStatusCmd)
vaultCmd.AddCommand(vaultChangePasswordCmd) vaultCmd.AddCommand(vaultChangePasswordCmd)
vaultCmd.AddCommand(vaultListCmd)
vaultCmd.AddCommand(vaultDeleteCmd)
} }

40
cmd/vault_test.go Normal file
View File

@ -0,0 +1,40 @@
package cmd
import (
"strings"
"testing"
)
func TestFormatVaultSecretsListDoesNotExposeSecretValues(t *testing.T) {
v := newUnlockedTestVault(t)
mustPutSecret(t, v, "prod", "ssh_password", "super-secret")
mustPutSecret(t, v, "stage", "key_passphrase", "also-secret")
output, err := formatVaultSecretsList(v)
if err != nil {
t.Fatalf("format vault secrets list: %v", err)
}
for _, want := range []string{"prod", "ssh_password", "stage", "key_passphrase"} {
if !strings.Contains(output, want) {
t.Fatalf("expected output to contain %q\noutput:\n%s", want, output)
}
}
for _, secretValue := range []string{"super-secret", "also-secret"} {
if strings.Contains(output, secretValue) {
t.Fatalf("expected output not to expose secret value %q\noutput:\n%s", secretValue, output)
}
}
}
func TestFormatVaultSecretsListHandlesEmptyVault(t *testing.T) {
v := newUnlockedTestVault(t)
output, err := formatVaultSecretsList(v)
if err != nil {
t.Fatalf("format empty vault secrets list: %v", err)
}
if !strings.Contains(output, "No secrets stored.") {
t.Fatalf("expected empty output message, got:\n%s", output)
}
}

View File

@ -0,0 +1,295 @@
# Server List Scalability Implementation Plan
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
**Goal:** Keep the server list usable when the saved server count is larger than the terminal height.
**Architecture:** The list screen should render a bounded table viewport instead of rendering every server row. Selection remains owned by the existing `bubbles/list.Model`, while `viewServerList` derives a visible row range around the selected item and keeps the selected-server detail panel and footer visible.
**Tech Stack:** Go, Bubble Tea, Bubbles list, Lip Gloss, existing TUI tests in `internal/tui/app_test.go`.
---
### Task 1: Add Regression Coverage For Long Server Lists
**Files:**
- Modify: `internal/tui/app_test.go`
- [ ] **Step 1: Add a test for a constrained terminal height**
Add a test that creates more servers than can fit on screen, sets a small terminal size, renders the list, and verifies the selected details and footer are still visible.
```go
func TestServerListViewKeepsDetailsVisibleWithManyServers(t *testing.T) {
servers := make([]*model.Server, 45)
for i := range servers {
servers[i] = &model.Server{
Alias: fmt.Sprintf("server-%02d", i+1),
DisplayName: fmt.Sprintf("Server %02d", i+1),
Host: fmt.Sprintf("host-%02d.example.org", i+1),
Port: 22,
User: "mirivlad",
AuthMethod: model.AuthKey,
LastTestStatus: model.TestUnknown,
}
}
m := New(servers)
updated, _ := m.Update(tea.WindowSizeMsg{Width: 100, Height: 18})
model := updated.(*tuiModel)
view := model.View()
if !strings.Contains(view, "Server 01") {
t.Fatalf("expected first selected server to be visible:\n%s", view)
}
if !strings.Contains(view, "Selected") {
t.Fatalf("expected selected server details to remain visible:\n%s", view)
}
if !strings.Contains(view, "Enter connect") {
t.Fatalf("expected footer to remain visible:\n%s", view)
}
if count := strings.Count(view, "server-"); count >= len(servers) {
t.Fatalf("expected bounded row rendering, rendered %d server aliases", count)
}
}
```
- [ ] **Step 2: Run the focused test and confirm it fails**
Run:
```bash
env GOCACHE=/tmp/sshkeeper-go-cache go test ./internal/tui -run TestServerListViewKeepsDetailsVisibleWithManyServers -count=1
```
Expected: FAIL because the current table renders every server and can push the detail panel/footer below the visible terminal area.
### Task 2: Compute A Visible Row Window
**Files:**
- Modify: `internal/tui/app.go`
- Modify: `internal/tui/app_test.go`
- [ ] **Step 1: Add focused tests for visible range calculation**
Add tests for a helper that computes the inclusive start and exclusive end indexes for rendered rows.
```go
func TestVisibleServerRangeKeepsSelectionInsideWindow(t *testing.T) {
tests := []struct {
name string
total int
selected int
available int
wantStart int
wantEnd int
}{
{name: "first page", total: 40, selected: 0, available: 10, wantStart: 0, wantEnd: 10},
{name: "middle page", total: 40, selected: 20, available: 10, wantStart: 11, wantEnd: 21},
{name: "last page", total: 40, selected: 39, available: 10, wantStart: 30, wantEnd: 40},
{name: "all fit", total: 5, selected: 3, available: 10, wantStart: 0, wantEnd: 5},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
start, end := visibleServerRange(tt.total, tt.selected, tt.available)
if start != tt.wantStart || end != tt.wantEnd {
t.Fatalf("visibleServerRange() = %d, %d; want %d, %d", start, end, tt.wantStart, tt.wantEnd)
}
})
}
}
```
- [ ] **Step 2: Implement `visibleServerRange`**
Add a small helper near `selectedServer`.
```go
func visibleServerRange(total, selected, available int) (int, int) {
if total <= 0 || available <= 0 {
return 0, 0
}
if available >= total {
return 0, total
}
if selected < 0 {
selected = 0
}
if selected >= total {
selected = total - 1
}
start := selected - available + 1
if start < 0 {
start = 0
}
end := start + available
if end > total {
end = total
start = end - available
}
return start, end
}
```
- [ ] **Step 3: Run helper tests**
Run:
```bash
env GOCACHE=/tmp/sshkeeper-go-cache go test ./internal/tui -run TestVisibleServerRangeKeepsSelectionInsideWindow -count=1
```
Expected: PASS.
### Task 3: Render Only Rows That Fit
**Files:**
- Modify: `internal/tui/app.go`
- Modify: `internal/tui/app_test.go`
- [ ] **Step 1: Reserve terminal space for fixed UI blocks**
Add a helper that decides how many server rows may be rendered while keeping the selected details and footer visible.
```go
func (m *tuiModel) visibleServerRows() int {
if m.height <= 0 {
return len(m.servers)
}
const fixedRows = 16
rows := m.height - fixedRows
if rows < 3 {
return 3
}
return rows
}
```
- [ ] **Step 2: Use the visible range in `viewServerList`**
In `viewServerList`, replace the loop over all servers with a bounded loop:
```go
selectedIndex := m.list.Index()
start, end := visibleServerRange(len(m.servers), selectedIndex, m.visibleServerRows())
for _, server := range m.servers[start:end] {
// existing row rendering body stays unchanged
}
```
Then render a compact range hint when rows are hidden:
```go
if len(m.servers) > end-start {
b.WriteString(helpStyle.Render(fmt.Sprintf(" Showing %d-%d of %d", start+1, end, len(m.servers))))
b.WriteString("\n")
}
```
- [ ] **Step 3: Run long-list regression test**
Run:
```bash
env GOCACHE=/tmp/sshkeeper-go-cache go test ./internal/tui -run TestServerListViewKeepsDetailsVisibleWithManyServers -count=1
```
Expected: PASS.
### Task 4: Verify Navigation Still Works
**Files:**
- Modify: `internal/tui/app_test.go`
- [ ] **Step 1: Add a test for moving selection beyond the first window**
Use the existing `m.list.Update` path by sending `tea.KeyDown` messages and confirm the rendered window follows the selected server.
```go
func TestServerListViewScrollsWithSelection(t *testing.T) {
servers := make([]*model.Server, 45)
for i := range servers {
servers[i] = &model.Server{
Alias: fmt.Sprintf("server-%02d", i+1),
DisplayName: fmt.Sprintf("Server %02d", i+1),
Host: fmt.Sprintf("host-%02d.example.org", i+1),
Port: 22,
User: "mirivlad",
AuthMethod: model.AuthKey,
}
}
m := New(servers)
updated, _ := m.Update(tea.WindowSizeMsg{Width: 100, Height: 18})
model := updated.(*tuiModel)
for i := 0; i < 20; i++ {
updated, _ = model.Update(tea.KeyMsg{Type: tea.KeyDown})
model = updated.(*tuiModel)
}
view := model.View()
if !strings.Contains(view, "Server 21") {
t.Fatalf("expected selected server to be visible after navigation:\n%s", view)
}
if !strings.Contains(view, "Showing") {
t.Fatalf("expected range hint for long server list:\n%s", view)
}
}
```
- [ ] **Step 2: Run TUI tests**
Run:
```bash
env GOCACHE=/tmp/sshkeeper-go-cache go test ./internal/tui -count=1
```
Expected: PASS.
### Task 5: Final Verification And Build
**Files:**
- No source edits expected.
- [ ] **Step 1: Run the full test suite**
Run:
```bash
env GOCACHE=/tmp/sshkeeper-go-cache go test ./...
```
Expected: all packages pass.
- [ ] **Step 2: Rebuild the project binary**
Run:
```bash
env GOCACHE=/tmp/sshkeeper-go-cache go build -o bin/sshkeeper .
```
Expected: exit code 0 and updated `bin/sshkeeper`.
- [ ] **Step 3: Commit the implementation**
Run:
```bash
git add internal/tui/app.go internal/tui/app_test.go bin/sshkeeper
git commit -m "fix: keep server list usable with many servers"
```
Expected: commit succeeds.
---
## Self-Review
- Spec coverage: the plan covers the known failure mode for 40+ servers, keeps selected details visible, keeps the footer visible, and preserves existing `bubbles/list` navigation.
- Placeholder scan: no `TBD`, `TODO`, or open-ended implementation placeholders remain.
- Type consistency: helper names and files match the current TUI code shape.

2
go.mod
View File

@ -10,6 +10,7 @@ require (
github.com/creack/pty v1.1.24 github.com/creack/pty v1.1.24
github.com/spf13/cobra v1.10.2 github.com/spf13/cobra v1.10.2
golang.org/x/crypto v0.52.0 golang.org/x/crypto v0.52.0
golang.org/x/sys v0.45.0
golang.org/x/term v0.43.0 golang.org/x/term v0.43.0
modernc.org/sqlite v1.50.1 modernc.org/sqlite v1.50.1
) )
@ -41,7 +42,6 @@ require (
github.com/sahilm/fuzzy v0.1.1 // indirect github.com/sahilm/fuzzy v0.1.1 // indirect
github.com/spf13/pflag v1.0.9 // indirect github.com/spf13/pflag v1.0.9 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
golang.org/x/sys v0.45.0 // indirect
golang.org/x/text v0.37.0 // indirect golang.org/x/text v0.37.0 // indirect
modernc.org/libc v1.72.3 // indirect modernc.org/libc v1.72.3 // indirect
modernc.org/mathutil v1.7.1 // indirect modernc.org/mathutil v1.7.1 // indirect

View File

@ -30,6 +30,17 @@ func (db *DB) UpdateServer(s *model.Server) error {
return err return err
} }
func (db *DB) UpdateServerByAlias(oldAlias string, s *model.Server) error {
_, err := db.conn.Exec(`
UPDATE servers SET
alias=?, display_name=?, host=?, port=?, user=?, auth_method=?,
identity_file=?, proxy_jump=?, group_name=?, notes=?, updated_at=CURRENT_TIMESTAMP
WHERE alias=?`,
s.Alias, s.DisplayName, s.Host, s.Port, s.User, s.AuthMethod,
s.IdentityFile, s.ProxyJump, s.GroupName, s.Notes, oldAlias)
return err
}
func (db *DB) DeleteServer(alias string) error { func (db *DB) DeleteServer(alias string) error {
_, err := db.conn.Exec("DELETE FROM servers WHERE alias=?", alias) _, err := db.conn.Exec("DELETE FROM servers WHERE alias=?", alias)
return err return err

View File

@ -0,0 +1,47 @@
package db
import (
"testing"
"github.com/mirivlad/sshkeeper/internal/model"
)
func TestUpdateServerByAliasCanRenameAlias(t *testing.T) {
db, err := Open(t.TempDir())
if err != nil {
t.Fatalf("open db: %v", err)
}
defer db.Close()
server := &model.Server{
Alias: "old",
Host: "old.example",
Port: 22,
User: "root",
AuthMethod: model.AuthPassword,
}
if err := db.CreateServer(server); err != nil {
t.Fatalf("create server: %v", err)
}
server.Alias = "new"
server.Host = "new.example"
server.AuthMethod = model.AuthKey
if err := db.UpdateServerByAlias("old", server); err != nil {
t.Fatalf("update server by old alias: %v", err)
}
if _, err := db.GetServer("old"); err == nil {
t.Fatal("expected old alias to be gone")
}
got, err := db.GetServer("new")
if err != nil {
t.Fatalf("get new alias: %v", err)
}
if got.ID != server.ID {
t.Fatalf("expected ID to stay %d, got %d", server.ID, got.ID)
}
if got.Host != "new.example" || got.AuthMethod != model.AuthKey {
t.Fatalf("unexpected updated server: %#v", got)
}
}

View File

@ -5,7 +5,6 @@ import (
"os" "os"
"os/exec" "os/exec"
"strings" "strings"
"time"
"github.com/mirivlad/sshkeeper/internal/config" "github.com/mirivlad/sshkeeper/internal/config"
"github.com/mirivlad/sshkeeper/internal/model" "github.com/mirivlad/sshkeeper/internal/model"
@ -14,7 +13,7 @@ import (
type VaultFunc func(serverAlias string, secretType string) (string, error) type VaultFunc func(serverAlias string, secretType string) (string, error)
func Connect(cfg *config.Config, server *model.Server, getVault VaultFunc) error { func Connect(cfg *config.Config, server *model.Server, getVault VaultFunc) error {
args := buildArgs(server) args := BuildSSHArgs(server)
switch server.AuthMethod { switch server.AuthMethod {
case model.AuthPassword: case model.AuthPassword:
@ -25,8 +24,7 @@ func Connect(cfg *config.Config, server *model.Server, getVault VaultFunc) error
return ConnectWithPassword(cfg.SSH.Binary, args, password) return ConnectWithPassword(cfg.SSH.Binary, args, password)
case model.AuthKeyPassphrase: case model.AuthKeyPassphrase:
// For key+passphrase, we need to handle the passphrase // For key+passphrase, let ssh-agent handle it or prompt normally
// For now, let ssh-agent handle it or prompt normally
// TODO: use ssh-agent or similar // TODO: use ssh-agent or similar
fallthrough fallthrough
@ -46,13 +44,11 @@ func Connect(cfg *config.Config, server *model.Server, getVault VaultFunc) error
} }
func Test(cfg *config.Config, server *model.Server, getVault VaultFunc) (bool, string) { func Test(cfg *config.Config, server *model.Server, getVault VaultFunc) (bool, string) {
args := buildArgs(server) args := BuildSSHArgs(server)
args = append(args, "-o", fmt.Sprintf("ConnectTimeout=%d", cfg.SSH.ConnectTimeoutSec)) args = append(args, "-o", fmt.Sprintf("ConnectTimeout=%d", cfg.SSH.ConnectTimeoutSec))
switch server.AuthMethod { switch server.AuthMethod {
case model.AuthPassword: case model.AuthPassword:
// For password auth, we can't use BatchMode
// Use a short timeout and try to connect
args = append(args, "-o", "NumberOfPasswordPrompts=1") args = append(args, "-o", "NumberOfPasswordPrompts=1")
password, err := getVault(server.Alias, "ssh_password") password, err := getVault(server.Alias, "ssh_password")
if err != nil { if err != nil {
@ -81,40 +77,29 @@ func Test(cfg *config.Config, server *model.Server, getVault VaultFunc) (bool, s
} }
} }
// testWithPassword tests SSH connection with password auth via PTY-wrapper.
// It connects, sends the password, runs the test command, and checks the output.
func testWithPassword(cfg *config.Config, args []string, password string) (bool, string) { func testWithPassword(cfg *config.Config, args []string, password string) (bool, string) {
// For password test, we use PTY approach with a short timeout
// This is a simplified version - in production, use ConnectWithPassword
// with a test command
args = append(args, cfg.SSH.TestCommand) args = append(args, cfg.SSH.TestCommand)
cmd := exec.Command(cfg.SSH.Binary, args...) ok, output := connectWithPasswordAndRead(cfg.SSH.Binary, args, password, cfg.SSH.ConnectTimeoutSec)
cmd.Stdin = nil if !ok {
cmd.Stdout = nil return false, output
cmd.Stderr = nil
// Use a timeout
done := make(chan error, 1)
if err := cmd.Start(); err != nil {
return false, err.Error()
} }
go func() { result := strings.TrimSpace(output)
done <- cmd.Wait() if result == "SSHKEEPER_OK" {
}()
select {
case err := <-done:
if err != nil {
return false, err.Error()
}
return true, "" return true, ""
case <-time.After(time.Duration(cfg.SSH.ConnectTimeoutSec) * time.Second):
cmd.Process.Kill()
return false, "connection timeout"
} }
// The output might have the test command echo before the result
if strings.Contains(result, "SSHKEEPER_OK") {
return true, ""
}
return false, result
} }
func buildArgs(server *model.Server) []string { // BuildSSHArgs builds the SSH command arguments for a server profile.
func BuildSSHArgs(server *model.Server) []string {
var args []string var args []string
args = append(args, "-p", fmt.Sprintf("%d", server.Port)) args = append(args, "-p", fmt.Sprintf("%d", server.Port))
@ -127,8 +112,6 @@ func buildArgs(server *model.Server) []string {
args = append(args, "-J", server.ProxyJump) args = append(args, "-J", server.ProxyJump)
} }
// Disable strict host key checking for first connection
// In production, this should be configurable
args = append(args, "-o", "StrictHostKeyChecking=accept-new") args = append(args, "-o", "StrictHostKeyChecking=accept-new")
target := fmt.Sprintf("%s@%s", server.User, server.Host) target := fmt.Sprintf("%s@%s", server.User, server.Host)

View File

@ -7,19 +7,19 @@ import (
"os/exec" "os/exec"
"regexp" "regexp"
"strings" "strings"
"sync"
"sync/atomic"
"syscall" "syscall"
"time" "time"
"github.com/creack/pty" "github.com/creack/pty"
"golang.org/x/sys/unix"
"golang.org/x/term" "golang.org/x/term"
) )
var passwordPromptRe = regexp.MustCompile(`(?i)(password|passphrase).*:\s*$`) var passwordPromptRe = regexp.MustCompile(`(?i)(password|passphrase).*:\s*$`)
// ConnectWithPassword runs SSH through a PTY, detects the password prompt,
// sends the password, and then bridges the user terminal to the SSH session.
func ConnectWithPassword(sshBinary string, args []string, password string) error { func ConnectWithPassword(sshBinary string, args []string, password string) error {
// Start SSH with PTY
cmd := exec.Command(sshBinary, args...) cmd := exec.Command(sshBinary, args...)
cmd.Env = os.Environ() cmd.Env = os.Environ()
cmd.SysProcAttr = &syscall.SysProcAttr{ cmd.SysProcAttr = &syscall.SysProcAttr{
@ -33,18 +33,14 @@ func ConnectWithPassword(sshBinary string, args []string, password string) error
} }
defer ptmx.Close() defer ptmx.Close()
// Save terminal state and set to raw
oldState, err := term.MakeRaw(int(os.Stdin.Fd())) oldState, err := term.MakeRaw(int(os.Stdin.Fd()))
if err != nil { if err != nil {
return fmt.Errorf("set raw terminal: %w", err) return fmt.Errorf("set raw terminal: %w", err)
} }
defer term.Restore(int(os.Stdin.Fd()), oldState)
// Channel to signal when password has been sent passwordSent := new(atomic.Bool)
passwordSent := make(chan bool, 1)
done := make(chan error, 1) done := make(chan error, 1)
// Read from PTY, detect password prompt
go func() { go func() {
buf := make([]byte, 4096) buf := make([]byte, 4096)
var accumulated strings.Builder var accumulated strings.Builder
@ -54,22 +50,17 @@ func ConnectWithPassword(sshBinary string, args []string, password string) error
if n > 0 { if n > 0 {
data := buf[:n] data := buf[:n]
accumulated.Write(data) accumulated.Write(data)
// Write to stdout
os.Stdout.Write(data) os.Stdout.Write(data)
// Check for password prompt if !passwordSent.Load() {
if !<-passwordSent {
text := accumulated.String() text := accumulated.String()
if passwordPromptRe.MatchString(text) { if passwordPromptRe.MatchString(text) {
passwordSent <- true passwordSent.Store(true)
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
ptmx.Write([]byte(password + "\r")) ptmx.Write([]byte(password + "\r"))
continue
} }
} }
// Reset accumulated buffer periodically to avoid unbounded growth
if accumulated.Len() > 8192 { if accumulated.Len() > 8192 {
s := accumulated.String() s := accumulated.String()
accumulated.Reset() accumulated.Reset()
@ -87,14 +78,142 @@ func ConnectWithPassword(sshBinary string, args []string, password string) error
} }
}() }()
// Copy stdin to PTY stopInput, waitInput, restoreInput, err := forwardInputToPTY(ptmx)
go func() { if err != nil {
io.Copy(ptmx, os.Stdin) term.Restore(int(os.Stdin.Fd()), oldState)
}() return err
}
// Wait for command completion
err = cmd.Wait() err = cmd.Wait()
passwordSent <- false // signal to stop close(stopInput)
waitInput()
if restoreErr := restoreInput(); err == nil && restoreErr != nil {
err = restoreErr
}
if restoreErr := term.Restore(int(os.Stdin.Fd()), oldState); err == nil && restoreErr != nil {
err = restoreErr
}
return err return err
} }
func forwardInputToPTY(ptmx *os.File) (chan struct{}, func(), func() error, error) {
stdinFD := int(os.Stdin.Fd())
flags, err := unix.FcntlInt(uintptr(stdinFD), unix.F_GETFL, 0)
if err != nil {
return nil, nil, nil, fmt.Errorf("get stdin flags: %w", err)
}
if err := unix.SetNonblock(stdinFD, true); err != nil {
return nil, nil, nil, fmt.Errorf("set stdin nonblocking: %w", err)
}
stop := make(chan struct{})
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
buf := make([]byte, 4096)
for {
select {
case <-stop:
return
default:
}
n, readErr := unix.Read(stdinFD, buf)
if n > 0 {
if _, writeErr := ptmx.Write(buf[:n]); writeErr != nil {
return
}
}
if readErr == nil {
continue
}
if readErr == unix.EAGAIN || readErr == unix.EWOULDBLOCK {
select {
case <-stop:
return
case <-time.After(10 * time.Millisecond):
continue
}
}
return
}
}()
restore := func() error {
_, err := unix.FcntlInt(uintptr(stdinFD), unix.F_SETFL, flags)
if err != nil {
return fmt.Errorf("restore stdin flags: %w", err)
}
return nil
}
return stop, wg.Wait, restore, nil
}
// connectWithPasswordAndRead runs SSH through a PTY, sends the password,
// collects all output, and returns it. Used for non-interactive testing.
// Returns (true, output) on success, (false, error) on failure.
func connectWithPasswordAndRead(sshBinary string, args []string, password string, timeoutSec int) (bool, string) {
cmd := exec.Command(sshBinary, args...)
cmd.Env = os.Environ()
cmd.SysProcAttr = &syscall.SysProcAttr{
Setsid: true,
Setctty: true,
}
ptmx, err := pty.Start(cmd)
if err != nil {
return false, fmt.Sprintf("start ssh with pty: %v", err)
}
defer ptmx.Close()
passwordSent := false
done := make(chan string, 1)
// Read from PTY, detect password prompt, collect output
go func() {
buf := make([]byte, 4096)
var accumulated strings.Builder
for {
n, err := ptmx.Read(buf)
if n > 0 {
data := buf[:n]
accumulated.Write(data)
// Check for password prompt
if !passwordSent {
text := accumulated.String()
if passwordPromptRe.MatchString(text) {
passwordSent = true
time.Sleep(100 * time.Millisecond)
ptmx.Write([]byte(password + "\r"))
continue
}
}
// Reset accumulated buffer periodically
if accumulated.Len() > 16384 {
s := accumulated.String()
accumulated.Reset()
accumulated.WriteString(s[len(s)-4096:])
}
}
if err != nil {
done <- accumulated.String()
return
}
}
}()
// Wait for command completion or timeout
select {
case output := <-done:
return true, output
case <-time.After(time.Duration(timeoutSec) * time.Second):
cmd.Process.Kill()
return false, "connection timeout"
}
}

63
internal/ssh/pty_test.go Normal file
View File

@ -0,0 +1,63 @@
package ssh
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/creack/pty"
)
func TestConnectWithPasswordDoesNotConsumeInputAfterReturn(t *testing.T) {
script := filepath.Join(t.TempDir(), "fake-ssh")
if err := os.WriteFile(script, []byte("#!/bin/sh\nprintf 'password: '\nsleep 0.1\nexit 0\n"), 0o700); err != nil {
t.Fatalf("write fake ssh: %v", err)
}
master, slave, err := pty.Open()
if err != nil {
t.Fatalf("open stdin pty: %v", err)
}
defer master.Close()
defer slave.Close()
oldStdin := os.Stdin
oldStdout := os.Stdout
stdoutSink, err := os.CreateTemp(t.TempDir(), "stdout")
if err != nil {
t.Fatalf("create stdout sink: %v", err)
}
defer stdoutSink.Close()
os.Stdin = slave
os.Stdout = stdoutSink
defer func() {
os.Stdin = oldStdin
os.Stdout = oldStdout
}()
if err := ConnectWithPassword(script, nil, "secret"); err != nil {
t.Fatalf("connect with password: %v", err)
}
if _, err := master.Write([]byte("\n")); err != nil {
t.Fatalf("write next enter: %v", err)
}
time.Sleep(100 * time.Millisecond)
readDone := make(chan []byte, 1)
go func() {
buf := make([]byte, 1)
n, _ := slave.Read(buf)
readDone <- buf[:n]
}()
select {
case got := <-readDone:
if len(got) != 1 || got[0] != '\n' {
t.Fatalf("expected to read newline, got %q", got)
}
case <-time.After(300 * time.Millisecond):
t.Fatal("expected next Enter to remain readable after ConnectWithPassword returned")
}
}

View File

@ -26,7 +26,10 @@ var (
Background(lipgloss.Color("4")). Background(lipgloss.Color("4")).
Bold(true) Bold(true)
normalStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("15")) normalStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("15"))
selectedRowStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("15")).Background(lipgloss.Color("4"))
listHeaderStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("14")).Bold(true)
sectionStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("12")).Bold(true).MarginTop(1)
testOKStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("10")).Bold(true) testOKStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("10")).Bold(true)
testFailStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("9")).Bold(true) testFailStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("9")).Bold(true)
@ -67,9 +70,13 @@ type serverItem struct {
server *model.Server server *model.Server
} }
func (i serverItem) Title() string { return i.server.Alias } func (i serverItem) Title() string { return i.server.Alias }
func (i serverItem) Description() string { return fmt.Sprintf("%s@%s:%d %s", i.server.User, i.server.Host, i.server.Port, i.server.AuthMethod) } func (i serverItem) Description() string {
func (i serverItem) FilterValue() string { return i.server.Alias + " " + i.server.DisplayName + " " + i.server.Host + " " + i.server.User } return fmt.Sprintf("%s@%s:%d %s", i.server.User, i.server.Host, i.server.Port, i.server.AuthMethod)
}
func (i serverItem) FilterValue() string {
return i.server.Alias + " " + i.server.DisplayName + " " + i.server.Host + " " + i.server.User
}
// --- External callbacks --- // --- External callbacks ---
@ -80,10 +87,12 @@ var (
TestConnection func(server *model.Server) (bool, string) TestConnection func(server *model.Server) (bool, string)
// TestConnectionWithPassword tests with explicit password (for form test before save) // TestConnectionWithPassword tests with explicit password (for form test before save)
TestConnectionWithPassword func(server *model.Server, password string) (bool, string) TestConnectionWithPassword func(server *model.Server, password string) (bool, string)
SaveServer func(server *model.Server, password string) error SaveServer func(server *model.Server, password string, oldAlias string) error
GetGroups func() ([]string, error) // Returns existing group names UpdateTestResult func(alias string, status model.TestStatus, testErr string) error
RenameGroup func(oldName, newName string) error // Rename group for all servers HasSecret func(alias string, secretType string) bool
DeleteGroup func(name string) error // Remove group from all servers GetGroups func() ([]string, error)
RenameGroup func(oldName, newName string) error
DeleteGroup func(name string) error
) )
// --- Screen type --- // --- Screen type ---
@ -99,8 +108,8 @@ const (
// --- Result type — returned from TUI to caller --- // --- Result type — returned from TUI to caller ---
type TUIResult struct { type TUIResult struct {
Server *model.Server Server *model.Server
Action string // "connect" Action string // "connect"
} }
// --- Main TUI model --- // --- Main TUI model ---
@ -196,8 +205,22 @@ func (m *tuiModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
} }
m.form.testResultTime = time.Now() m.form.testResultTime = time.Now()
m.form.err = nil m.form.err = nil
return m, nil
}
// Update test status in DB and reload list
if item, ok := m.list.SelectedItem().(serverItem); ok && UpdateTestResult != nil {
status := model.TestUnknown
if msg.ok {
status = model.TestOK
} else if msg.err != "" {
status = model.TestFailed
}
UpdateTestResult(item.server.Alias, status, msg.err)
}
return m, func() tea.Msg {
servers, err := ListServers()
return serversLoadedMsg{servers: servers, err: err}
} }
return m, nil
case saveDoneMsg: case saveDoneMsg:
if m.form != nil { if m.form != nil {
@ -323,11 +346,23 @@ func (m *tuiModel) updateSearch(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
func (m *tuiModel) updateForm(msg tea.KeyMsg) (tea.Model, tea.Cmd) { func (m *tuiModel) updateForm(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
if msg.Type == tea.KeyEsc { if msg.Type == tea.KeyEsc {
if m.form != nil && (m.form.showGroupList || m.form.showAuthList) {
updated, cmd := m.form.Update(msg)
if fm, ok := updated.(*formModel); ok {
m.form = fm
}
return m, cmd
}
m.screen = screenList m.screen = screenList
m.form = nil m.form = nil
m.err = nil m.err = nil
m.success = "" m.success = ""
return m, nil // Reload server list after form close
return m, func() tea.Msg {
servers, err := ListServers()
return serversLoadedMsg{servers: servers, err: err}
}
} }
updated, cmd := m.form.Update(msg) updated, cmd := m.form.Update(msg)
@ -342,9 +377,7 @@ func (m *tuiModel) View() string {
switch m.screen { switch m.screen {
case screenList: case screenList:
b.WriteString(m.list.View()) b.WriteString(m.viewServerList())
b.WriteString("\n")
b.WriteString(helpStyle.Render("Enter connect | Ctrl+A add | Ctrl+E edit | Ctrl+D del | Ctrl+T test | Ctrl+F search | Ctrl+Q quit"))
case screenSearch: case screenSearch:
b.WriteString("Search: " + m.searchInput.View() + "\n") b.WriteString("Search: " + m.searchInput.View() + "\n")
@ -366,13 +399,148 @@ func (m *tuiModel) View() string {
return b.String() return b.String()
} }
func (m *tuiModel) viewServerList() string {
var b strings.Builder
selectedAlias := ""
if item, ok := m.list.SelectedItem().(serverItem); ok && item.server != nil {
selectedAlias = item.server.Alias
}
b.WriteString(titleStyle.Render(fmt.Sprintf("sshkeeper %d servers", len(m.servers))))
b.WriteString("\n")
b.WriteString(helpStyle.Render(fmt.Sprintf("Vault unlocked | %s", testSummary(m.servers))))
b.WriteString("\n\n")
b.WriteString(listHeaderStyle.Render(fmt.Sprintf(" %-20s %-20s %-34s %-12s %-10s %s", "NAME", "ALIAS", "TARGET", "AUTH", "GROUP", "STATUS")))
b.WriteString("\n")
if len(m.servers) == 0 {
b.WriteString(helpStyle.Render(" No servers yet. Press Ctrl+A to add one."))
b.WriteString("\n")
} else {
for _, server := range m.servers {
marker := " "
rowStyle := normalStyle
if server.Alias == selectedAlias {
marker = ">"
rowStyle = selectedRowStyle
}
name := server.DisplayName
if name == "" {
name = server.Alias
}
target := fmt.Sprintf("%s@%s:%d", server.User, server.Host, server.Port)
group := server.GroupName
if group == "" {
group = "-"
}
row := fmt.Sprintf("%s %-20s %-20s %-34s %-12s %-10s %s",
marker,
truncate(name, 20),
truncate(server.Alias, 20),
truncate(target, 34),
authLabel(server.AuthMethod),
truncate(group, 10),
testStatusLabel(server),
)
b.WriteString(rowStyle.Render(row))
b.WriteString("\n")
}
}
b.WriteString("\n")
if selectedAlias != "" {
if selected := m.selectedServer(); selected != nil {
b.WriteString(m.viewSelectedServer(selected))
b.WriteString("\n")
}
}
b.WriteString(helpStyle.Render("Enter connect | Ctrl+A add | Ctrl+E edit | Ctrl+D del | Ctrl+T test | Ctrl+F search | Ctrl+Q quit"))
return b.String()
}
func (m *tuiModel) selectedServer() *model.Server {
if item, ok := m.list.SelectedItem().(serverItem); ok && item.server != nil {
return item.server
}
return nil
}
func (m *tuiModel) viewSelectedServer(server *model.Server) string {
displayName := server.DisplayName
if displayName == "" {
displayName = "-"
}
group := server.GroupName
if group == "" {
group = "-"
}
var b strings.Builder
b.WriteString(sectionStyle.Render("Selected"))
b.WriteString("\n")
b.WriteString(fmt.Sprintf(" Alias: %s\n", server.Alias))
b.WriteString(fmt.Sprintf(" Display Name: %s\n", displayName))
b.WriteString(fmt.Sprintf(" Host: %s\n", server.Host))
b.WriteString(fmt.Sprintf(" Port: %d\n", server.Port))
b.WriteString(fmt.Sprintf(" User: %s\n", server.User))
b.WriteString(fmt.Sprintf(" Auth: %s\n", authLabel(server.AuthMethod)))
b.WriteString(fmt.Sprintf(" Group: %s\n", group))
b.WriteString(fmt.Sprintf(" Status: %s\n", testStatusLabel(server)))
return b.String()
}
func testSummary(servers []*model.Server) string {
okCount := 0
failedCount := 0
for _, server := range servers {
switch server.LastTestStatus {
case model.TestOK:
okCount++
case model.TestFailed:
failedCount++
}
}
return fmt.Sprintf("%d OK | %d FAIL", okCount, failedCount)
}
func authLabel(auth model.AuthMethod) string {
switch auth {
case model.AuthPassword:
return "password"
case model.AuthKey:
return "key"
case model.AuthKeyPassphrase:
return "key+phrase"
case model.AuthAgent:
return "agent"
default:
return string(auth)
}
}
func testStatusLabel(server *model.Server) string {
switch server.LastTestStatus {
case model.TestOK:
return "OK"
case model.TestFailed:
if server.LastTestError != "" {
return "FAIL"
}
return "FAIL"
default:
return "?"
}
}
// --- Form model --- // --- Form model ---
type formModel struct { type formModel struct {
edit bool edit bool
server *model.Server server *model.Server
inputs []textinput.Model inputs []textinput.Model
labels []string
password textinput.Model password textinput.Model
passwordLabel string
focusIdx int focusIdx int
testResult string testResult string
testOK bool testOK bool
@ -385,10 +553,22 @@ type formModel struct {
spinner spinner.Model spinner spinner.Model
width int width int
height int height int
groups string // cached groups list (comma-separated for display) groups []string // existing group names
showGroups bool // show group dropdown groupList list.Model // dropdown list for groups
showGroupList bool // whether group dropdown is visible
authList list.Model
showAuthList bool
} }
// groupItem implements list.Item for the group dropdown
type groupItem struct {
name string
}
func (i groupItem) Title() string { return i.name }
func (i groupItem) Description() string { return "" }
func (i groupItem) FilterValue() string { return i.name }
func newFormModel(w, h int) *formModel { func newFormModel(w, h int) *formModel {
inputs := make([]textinput.Model, 10) inputs := make([]textinput.Model, 10)
labels := []string{ labels := []string{
@ -405,12 +585,12 @@ func newFormModel(w, h int) *formModel {
} }
for i, label := range labels { for i, label := range labels {
inputs[i] = textinput.New() inputs[i] = textinput.New()
inputs[i].Placeholder = label inputs[i].Placeholder = placeholderForLabel(label)
inputs[i].CharLimit = 128 inputs[i].CharLimit = 128
} }
pw := textinput.New() pw := textinput.New()
pw.Placeholder = "Password / Passphrase (stored in vault)" pw.Placeholder = "optional"
pw.CharLimit = 256 pw.CharLimit = 256
pw.EchoMode = textinput.EchoPassword pw.EchoMode = textinput.EchoPassword
@ -421,24 +601,75 @@ func newFormModel(w, h int) *formModel {
inputs[0].Focus() inputs[0].Focus()
fm := &formModel{ fm := &formModel{
inputs: inputs, inputs: inputs,
password: pw, labels: labels,
focusIdx: 0, password: pw,
spinner: s, passwordLabel: "Password / Passphrase",
width: w, focusIdx: 0,
height: h, spinner: s,
width: w,
height: h,
} }
fm.authList = newStringList([]string{
string(model.AuthPassword),
string(model.AuthKey),
string(model.AuthKeyPassphrase),
string(model.AuthAgent),
}, "Select auth method", 34, 16)
// Load existing groups // Load existing groups
if GetGroups != nil { if GetGroups != nil {
if groups, err := GetGroups(); err == nil && len(groups) > 0 { if groups, err := GetGroups(); err == nil && len(groups) > 0 {
fm.groups = strings.Join(groups, ", ") fm.groups = groups
fm.groupList = newStringList(groups, "Select group", 30, 8)
} }
} }
fm.updateFocus()
return fm return fm
} }
func placeholderForLabel(label string) string {
switch label {
case "Alias":
return "mail.kp"
case "Display Name":
return "Production mail"
case "Host":
return "mail.example.org"
case "Port":
return "22"
case "User":
return "root"
case "Auth Method (password/key/key_passphrase/agent)":
return "key"
case "Identity File":
return "~/.ssh/id_ed25519"
case "ProxyJump":
return "optional"
case "Group (type new or pick from list)":
return "KP"
case "Notes":
return "optional"
default:
return label
}
}
func newStringList(values []string, title string, width, height int) list.Model {
items := make([]list.Item, len(values))
for i, value := range values {
items[i] = groupItem{name: value}
}
l := list.New(items, list.NewDefaultDelegate(), width, height)
l.SetShowStatusBar(false)
l.SetShowHelp(false)
l.SetShowPagination(false)
l.Title = title
l.Styles.Title = titleStyle
return l
}
func newEditFormModel(s *model.Server, w, h int) *formModel { func newEditFormModel(s *model.Server, w, h int) *formModel {
fm := newFormModel(w, h) fm := newFormModel(w, h)
fm.edit = true fm.edit = true
@ -453,6 +684,21 @@ func newEditFormModel(s *model.Server, w, h int) *formModel {
fm.inputs[7].SetValue(s.ProxyJump) fm.inputs[7].SetValue(s.ProxyJump)
fm.inputs[8].SetValue(s.GroupName) fm.inputs[8].SetValue(s.GroupName)
fm.inputs[9].SetValue(s.Notes) fm.inputs[9].SetValue(s.Notes)
if HasSecret != nil {
switch s.AuthMethod {
case model.AuthPassword:
if HasSecret(s.Alias, "ssh_password") {
fm.passwordLabel = "Password (secret saved; leave blank to keep)"
fm.password.Placeholder = ""
}
case model.AuthKeyPassphrase:
if HasSecret(s.Alias, "key_passphrase") {
fm.passwordLabel = "Key passphrase (secret saved; leave blank to keep)"
fm.password.Placeholder = ""
}
}
}
fm.updateFocus()
return fm return fm
} }
@ -498,6 +744,48 @@ func (fm *formModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return fm, cmd return fm, cmd
} }
// Handle group dropdown
if fm.showGroupList {
switch msg := msg.(type) {
case tea.KeyMsg:
switch msg.Type {
case tea.KeyEsc:
fm.showGroupList = false
return fm, nil
case tea.KeyEnter:
if item, ok := fm.groupList.SelectedItem().(groupItem); ok {
fm.inputs[8].SetValue(item.name)
}
fm.showGroupList = false
return fm, nil
}
}
// Pass other keys to the list
var cmd tea.Cmd
fm.groupList, cmd = fm.groupList.Update(msg)
return fm, cmd
}
if fm.showAuthList {
switch msg := msg.(type) {
case tea.KeyMsg:
switch msg.Type {
case tea.KeyEsc:
fm.showAuthList = false
return fm, nil
case tea.KeyEnter:
if item, ok := fm.authList.SelectedItem().(groupItem); ok {
fm.inputs[5].SetValue(item.name)
}
fm.showAuthList = false
return fm, nil
}
}
var cmd tea.Cmd
fm.authList, cmd = fm.authList.Update(msg)
return fm, cmd
}
switch msg := msg.(type) { switch msg := msg.(type) {
case tea.KeyMsg: case tea.KeyMsg:
switch msg.Type { switch msg.Type {
@ -519,6 +807,17 @@ func (fm *formModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
fm.updateFocus() fm.updateFocus()
return fm, nil return fm, nil
case tea.KeyRunes:
if len(msg.Runes) == 1 && msg.Runes[0] == '/' && !msg.Alt && fm.focusIdx == 5 {
fm.showAuthList = true
return fm, nil
}
// '/' on Group field opens group dropdown
if len(msg.Runes) == 1 && msg.Runes[0] == '/' && !msg.Alt && fm.focusIdx == 8 && len(fm.groups) > 0 {
fm.showGroupList = true
return fm, nil
}
case tea.KeyEnter: case tea.KeyEnter:
switch { switch {
case fm.focusIdx == len(fm.inputs)+1: case fm.focusIdx == len(fm.inputs)+1:
@ -576,20 +875,36 @@ func (fm *formModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
func (fm *formModel) updateFocus() { func (fm *formModel) updateFocus() {
for i := range fm.inputs { for i := range fm.inputs {
fm.inputs[i].Blur() fm.inputs[i].Blur()
fm.inputs[i].Prompt = blurredStyle.Render(fm.inputs[i].Placeholder + ": ") fm.inputs[i].Prompt = blurredStyle.Render(fm.labelAt(i) + ": ")
} }
fm.password.Blur() fm.password.Blur()
fm.password.Prompt = blurredStyle.Render(fm.password.Placeholder + ": ") fm.password.Prompt = blurredStyle.Render(fm.passwordLabel + ": ")
if fm.focusIdx < len(fm.inputs) { if fm.focusIdx < len(fm.inputs) {
fm.inputs[fm.focusIdx].Focus() fm.inputs[fm.focusIdx].Focus()
fm.inputs[fm.focusIdx].Prompt = focusedStyle.Render(fm.inputs[fm.focusIdx].Placeholder + "> ") fm.inputs[fm.focusIdx].Prompt = focusedStyle.Render(fm.labelAt(fm.focusIdx) + "> ")
} else if fm.focusIdx == len(fm.inputs) { } else if fm.focusIdx == len(fm.inputs) {
fm.password.Focus() fm.password.Focus()
fm.password.Prompt = focusedStyle.Render(fm.password.Placeholder + "> ") fm.password.Prompt = focusedStyle.Render(fm.passwordLabel + "> ")
} }
} }
func (fm *formModel) labelAt(index int) string {
if index >= 0 && index < len(fm.labels) {
if index == 5 {
return "Auth Method (/ pick)"
}
if index == 8 {
if len(fm.groups) > 0 {
return "Group (/ pick)"
}
return "Group"
}
return fm.labels[index]
}
return ""
}
func (fm *formModel) runTest() tea.Cmd { func (fm *formModel) runTest() tea.Cmd {
fm.testing = true fm.testing = true
fm.testResult = "" fm.testResult = ""
@ -635,7 +950,11 @@ func (fm *formModel) runSave() tea.Cmd {
if s.Host == "" { if s.Host == "" {
return saveDoneMsg{err: fmt.Errorf("host is required")} return saveDoneMsg{err: fmt.Errorf("host is required")}
} }
err := SaveServer(s, pw) oldAlias := ""
if fm.edit && fm.server != nil {
oldAlias = fm.server.Alias
}
err := SaveServer(s, pw, oldAlias)
return saveDoneMsg{err: err} return saveDoneMsg{err: err}
}, },
) )
@ -672,13 +991,72 @@ func (fm *formModel) View() string {
b.WriteString(titleStyle.Render(title)) b.WriteString(titleStyle.Render(title))
b.WriteString("\n\n") b.WriteString("\n\n")
for i := range fm.inputs { // Calculate visible range based on terminal height
// Reserve lines for: title (2) + password (1) + buttons (3) + help (1) + padding (2) = ~9
reserved := 9
available := fm.height - reserved
if available < 4 {
available = 4
}
numInputs := len(fm.inputs)
startIdx := 0
endIdx := numInputs
// Scroll: keep focused field visible
if numInputs > available {
focusInput := fm.focusIdx
if focusInput >= numInputs {
focusInput = numInputs - 1
}
// Try to show `available` fields centered on focus
startIdx = focusInput - available/2
if startIdx < 0 {
startIdx = 0
}
endIdx = startIdx + available
if endIdx > numInputs {
endIdx = numInputs
startIdx = endIdx - available
if startIdx < 0 {
startIdx = 0
}
}
}
// Show scroll indicator if needed
if startIdx > 0 {
b.WriteString(helpStyle.Render(" ↑ more fields above\n"))
}
for i := startIdx; i < endIdx; i++ {
if section := formSectionTitle(i); section != "" {
b.WriteString(sectionStyle.Render(section))
b.WriteString("\n")
}
if i == 5 {
fm.inputs[i].Placeholder = "password/key/key_passphrase/agent"
}
// Show group hint inline in placeholder for Group field
if i == 8 && len(fm.groups) > 0 && !fm.showGroupList {
fm.inputs[i].Placeholder = truncate(strings.Join(fm.groups, ", "), 25)
}
b.WriteString(fm.inputs[i].View()) b.WriteString(fm.inputs[i].View())
b.WriteString("\n") b.WriteString("\n")
// Show existing groups hint under Group field if i == 5 && fm.showAuthList {
if i == 8 && fm.groups != "" { b.WriteString("\n" + renderDropdown(fm.authList) + "\n")
b.WriteString(helpStyle.Render(" Groups: " + fm.groups + "\n")) b.WriteString(helpStyle.Render("Enter select | Esc cancel"))
return b.String()
} }
if i == 8 && fm.showGroupList {
b.WriteString("\n" + renderDropdown(fm.groupList) + "\n")
b.WriteString(helpStyle.Render("Enter select | Esc cancel"))
return b.String()
}
}
if endIdx < numInputs {
b.WriteString(helpStyle.Render(fmt.Sprintf(" ↓ more fields below (%d-%d of %d)\n", startIdx+1, endIdx, numInputs)))
} }
b.WriteString(fm.password.View()) b.WriteString(fm.password.View())
@ -723,8 +1101,53 @@ func (fm *formModel) View() string {
saveBtn = normalStyle.Render(saveBtn) saveBtn = normalStyle.Render(saveBtn)
} }
b.WriteString("\n" + testBtn + " " + saveBtn + "\n\n") b.WriteString("\n" + sectionStyle.Render("Actions") + "\n")
b.WriteString(helpStyle.Render("Tab/↓ next | ↑ prev | Enter select | Esc back")) b.WriteString(testBtn + " " + saveBtn + "\n\n")
b.WriteString(helpStyle.Render("Tab/↓ next | ↑ prev | / pick list | Enter select | Esc back"))
return b.String() return b.String()
} }
func renderDropdown(l list.Model) string {
var b strings.Builder
b.WriteString(sectionStyle.Render(l.Title))
b.WriteString("\n")
for i, item := range l.Items() {
group, ok := item.(groupItem)
if !ok {
continue
}
prefix := " "
style := normalStyle
if i == l.Index() {
prefix = "> "
style = selectedRowStyle
}
b.WriteString(style.Render(prefix + group.name))
b.WriteString("\n")
}
return strings.TrimRight(b.String(), "\n")
}
func formSectionTitle(index int) string {
switch index {
case 0:
return "Identity"
case 2:
return "Connection"
case 5:
return "Authentication"
case 8:
return "Metadata"
default:
return ""
}
}
// truncate limits a string to maxLen, adding "..." if truncated
func truncate(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen-3] + "..."
}

322
internal/tui/app_test.go Normal file
View File

@ -0,0 +1,322 @@
package tui
import (
"strings"
"testing"
"time"
"github.com/charmbracelet/bubbles/list"
tea "github.com/charmbracelet/bubbletea"
"github.com/mirivlad/sshkeeper/internal/model"
)
func TestServerListViewUsesDashboardLayout(t *testing.T) {
now := time.Date(2026, 5, 28, 1, 50, 0, 0, time.UTC)
m := New([]*model.Server{
{
Alias: "mail.kp",
DisplayName: "Mail",
Host: "mail.example.org",
Port: 222,
User: "mirivlad",
AuthMethod: model.AuthPassword,
GroupName: "KP",
LastTestStatus: model.TestOK,
LastTestAt: &now,
},
{
Alias: "mirv.top",
Host: "mirv.top",
Port: 22,
User: "root",
AuthMethod: model.AuthKey,
LastTestStatus: model.TestUnknown,
},
})
m.width = 100
m.height = 30
m.list.SetSize(100, 24)
view := m.View()
for _, want := range []string{
"sshkeeper",
"2 servers",
"Vault",
"NAME",
"TARGET",
"AUTH",
"GROUP",
"STATUS",
"Mail",
"mail.kp",
"mirivlad@mail.example.org:222",
"KP",
"OK",
"Selected",
"Host: mail.example.org",
"Alias: mail.kp",
"Display Name: Mail",
"Port: 222",
"Enter connect",
} {
if !strings.Contains(view, want) {
t.Fatalf("expected list view to contain %q\nview:\n%s", want, view)
}
}
if strings.Contains(view, "Profiles managed locally") {
t.Fatalf("expected compact status header instead of README text\nview:\n%s", view)
}
}
func TestEscClosesGroupListBeforeLeavingForm(t *testing.T) {
oldGetGroups := GetGroups
GetGroups = func() ([]string, error) {
return []string{"prod", "stage"}, nil
}
defer func() { GetGroups = oldGetGroups }()
m := &tuiModel{
screen: screenForm,
form: newFormModel(80, 24),
}
m.form.focusIdx = 8
m.form.updateFocus()
updated, _ := m.updateForm(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'/'}})
m = updated.(*tuiModel)
if !m.form.showGroupList {
t.Fatal("expected / on the group field to open the group list")
}
updated, _ = m.updateForm(tea.KeyMsg{Type: tea.KeyEsc})
m = updated.(*tuiModel)
if m.screen != screenForm {
t.Fatalf("expected Esc to keep the user in the form, got screen %v", m.screen)
}
if m.form == nil {
t.Fatal("expected form to remain open")
}
if m.form.showGroupList {
t.Fatal("expected Esc to close only the group list")
}
}
func TestEscClosesAuthMethodListBeforeLeavingForm(t *testing.T) {
m := &tuiModel{
screen: screenForm,
form: newFormModel(80, 24),
}
m.form.focusIdx = 5
m.form.updateFocus()
updated, _ := m.updateForm(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'/'}})
m = updated.(*tuiModel)
if !m.form.showAuthList {
t.Fatal("expected / on the auth method field to open the auth method list")
}
updated, _ = m.updateForm(tea.KeyMsg{Type: tea.KeyEsc})
m = updated.(*tuiModel)
if m.screen != screenForm {
t.Fatalf("expected Esc to keep the user in the form, got screen %v", m.screen)
}
if m.form == nil {
t.Fatal("expected form to remain open")
}
if m.form.showAuthList {
t.Fatal("expected Esc to close only the auth method list")
}
}
func TestAuthMethodListSelectsValue(t *testing.T) {
fm := newFormModel(80, 24)
fm.focusIdx = 5
fm.updateFocus()
updated, _ := fm.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'/'}})
fm = updated.(*formModel)
if !fm.showAuthList {
t.Fatal("expected / on the auth method field to open the auth method list")
}
fm.authList.Select(2)
updated, _ = fm.Update(tea.KeyMsg{Type: tea.KeyEnter})
fm = updated.(*formModel)
if fm.showAuthList {
t.Fatal("expected Enter to close auth method list")
}
if got := fm.inputs[5].Value(); got != string(model.AuthKeyPassphrase) {
t.Fatalf("expected auth method %q, got %q", model.AuthKeyPassphrase, got)
}
}
func TestAuthMethodListViewShowsAllOptions(t *testing.T) {
fm := newFormModel(80, 12)
fm.focusIdx = 5
fm.updateFocus()
updated, _ := fm.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'/'}})
fm = updated.(*formModel)
view := fm.View()
authPos := strings.Index(view, "Auth Method")
listPos := strings.Index(view, "Select auth method")
if authPos < 0 || listPos < 0 {
t.Fatalf("expected auth field and auth list title in view\nview:\n%s", view)
}
if listPos < authPos {
t.Fatalf("expected auth method list after auth field\nview:\n%s", view)
}
if between := view[authPos:listPos]; strings.Contains(between, "Identity File") {
t.Fatalf("expected auth method list to render directly under auth field\nview:\n%s", view)
}
if strings.Contains(view, "│") {
t.Fatalf("expected compact auth method dropdown without default list border\nview:\n%s", view)
}
for _, method := range []model.AuthMethod{
model.AuthPassword,
model.AuthKey,
model.AuthKeyPassphrase,
model.AuthAgent,
} {
if !strings.Contains(view, string(method)) {
t.Fatalf("expected auth method list view to contain %q\nview:\n%s", method, view)
}
}
}
func TestGroupListViewRendersDirectlyUnderGroupField(t *testing.T) {
oldGetGroups := GetGroups
GetGroups = func() ([]string, error) {
return []string{"KP", "MY"}, nil
}
defer func() { GetGroups = oldGetGroups }()
fm := newFormModel(80, 24)
fm.focusIdx = 8
fm.updateFocus()
updated, _ := fm.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'/'}})
fm = updated.(*formModel)
view := fm.View()
groupPos := strings.Index(view, "Group")
listPos := strings.Index(view, "Select group")
if groupPos < 0 || listPos < 0 {
t.Fatalf("expected group field and group list title in view\nview:\n%s", view)
}
if listPos < groupPos {
t.Fatalf("expected group list after group field\nview:\n%s", view)
}
if between := view[groupPos:listPos]; strings.Contains(between, "Password") {
t.Fatalf("expected group dropdown to render before password field\nview:\n%s", view)
}
if strings.Contains(view, "│") {
t.Fatalf("expected compact group dropdown without default list border\nview:\n%s", view)
}
}
func TestSelectableFieldHintsAreVisible(t *testing.T) {
oldGetGroups := GetGroups
GetGroups = func() ([]string, error) {
return []string{"KP"}, nil
}
defer func() { GetGroups = oldGetGroups }()
fm := newFormModel(80, 24)
view := fm.View()
for _, want := range []string{"Auth Method (/ pick)", "Group (/ pick)", "/ pick list"} {
if !strings.Contains(view, want) {
t.Fatalf("expected form view to contain selectable-field hint %q\nview:\n%s", want, view)
}
}
}
func TestEditFormShowsSavedSecretMarker(t *testing.T) {
oldHasSecret := HasSecret
HasSecret = func(alias string, secretType string) bool {
return alias == "prod" && secretType == "ssh_password"
}
defer func() { HasSecret = oldHasSecret }()
fm := newEditFormModel(&model.Server{
Alias: "prod",
Host: "example.org",
Port: 22,
User: "root",
AuthMethod: model.AuthPassword,
}, 80, 24)
view := fm.View()
if !strings.Contains(view, "secret saved") {
t.Fatalf("expected edit form to show saved secret marker\nview:\n%s", view)
}
if !strings.Contains(view, "leave blank to keep") {
t.Fatalf("expected edit form to explain blank password keeps saved secret\nview:\n%s", view)
}
if strings.Count(view, "secret saved") != 1 {
t.Fatalf("expected saved secret marker to appear once\nview:\n%s", view)
}
}
func TestFormViewUsesSectionsAndStableLabels(t *testing.T) {
fm := newFormModel(100, 30)
view := fm.View()
for _, want := range []string{
"Identity",
"Connection",
"Authentication",
"Metadata",
"Actions",
"Alias",
"Display Name",
"Auth Method",
"Password / Passphrase",
} {
if !strings.Contains(view, want) {
t.Fatalf("expected form view to contain %q\nview:\n%s", want, view)
}
}
}
func TestFormTestResultDoesNotUpdateSelectedListServer(t *testing.T) {
oldUpdateTestResult := UpdateTestResult
oldListServers := ListServers
defer func() {
UpdateTestResult = oldUpdateTestResult
ListServers = oldListServers
}()
updateCalled := false
UpdateTestResult = func(alias string, status model.TestStatus, testErr string) error {
updateCalled = true
return nil
}
ListServers = func() ([]*model.Server, error) {
t.Fatal("form test result should not reload server list")
return nil, nil
}
selected := &model.Server{Alias: "selected", Host: "example.org", Port: 22, User: "root"}
m := New([]*model.Server{selected})
m.screen = screenForm
m.form = newFormModel(80, 24)
m.list = list.New([]list.Item{serverItem{server: selected}}, list.NewDefaultDelegate(), 80, 20)
updated, cmd := m.Update(testDoneMsg{ok: true})
m = updated.(*tuiModel)
if cmd != nil {
t.Fatal("form test result should not return a reload command")
}
if updateCalled {
t.Fatal("form test result should not update the selected list server")
}
if m.form == nil || m.form.testResult != "Connection OK." {
t.Fatal("expected form to keep its test result")
}
}

View File

@ -8,6 +8,8 @@ import (
"fmt" "fmt"
"io" "io"
"os" "os"
"sort"
"strings"
"sync" "sync"
"time" "time"
@ -16,18 +18,20 @@ import (
) )
const ( const (
currentVersion = 1 currentVersion = 1
saltLen = 32 saltLen = 32
nonceLen = 24 nonceLen = 24
keyLen = 32 keyLen = 32
verifierID = "__sshkeeper_vault_verifier__"
verifierPlaintext = "sshkeeper-vault-verifier-v1"
) )
type KDFMeta struct { type KDFMeta struct {
Name string `json:"name"` Name string `json:"name"`
MemoryKiB int `json:"memory_kib"` MemoryKiB int `json:"memory_kib"`
Iterations int `json:"iterations"` Iterations int `json:"iterations"`
Parallelism int `json:"parallelism"` Parallelism int `json:"parallelism"`
Salt string `json:"salt"` Salt string `json:"salt"`
} }
type Record struct { type Record struct {
@ -38,23 +42,35 @@ type Record struct {
} }
type VaultFile struct { type VaultFile struct {
Version int `json:"version"` Version int `json:"version"`
KDF KDFMeta `json:"kdf"` KDF KDFMeta `json:"kdf"`
Records []Record `json:"records"` Verifier *Record `json:"verifier,omitempty"`
Records []Record `json:"records"`
} }
type Vault struct { type Vault struct {
mu sync.Mutex mu sync.Mutex
path string path string
masterKey []byte masterKey []byte
records map[string][]byte // id -> plaintext records map[string]secretRecord
modified bool modified bool
}
type secretRecord struct {
secretType string
plaintext []byte
}
type SecretMeta struct {
ID string
Alias string
Type string
} }
func New(path string) *Vault { func New(path string) *Vault {
return &Vault{ return &Vault{
path: path, path: path,
records: make(map[string][]byte), records: make(map[string]secretRecord),
} }
} }
@ -76,21 +92,26 @@ func Create(path string, masterPassword string) error {
kdf := KDFMeta{ kdf := KDFMeta{
Name: "argon2id", Name: "argon2id",
MemoryKiB: 4096, MemoryKiB: 65536,
Iterations: 2, Iterations: 3,
Parallelism: 1, Parallelism: 1,
Salt: base64.StdEncoding.EncodeToString(salt), Salt: base64.StdEncoding.EncodeToString(salt),
} }
fmt.Print("Deriving key...") fmt.Print("Deriving key...")
key := argon2.IDKey([]byte(masterPassword), salt, uint32(kdf.Iterations), uint32(kdf.MemoryKiB)*1024, uint8(kdf.Parallelism), keyLen) key := argon2.IDKey([]byte(masterPassword), salt, uint32(kdf.Iterations), uint32(kdf.MemoryKiB), uint8(kdf.Parallelism), keyLen)
verifier, err := newVerifierRecord(key)
if err != nil {
return fmt.Errorf("create verifier: %w", err)
}
// Verify key is valid by doing a test encrypt/decrypt
vf := VaultFile{ vf := VaultFile{
Version: currentVersion, Version: currentVersion,
KDF: kdf, KDF: kdf,
Records: []Record{}, Verifier: &verifier,
Records: []Record{},
} }
data, err := json.Marshal(vf) data, err := json.Marshal(vf)
@ -143,24 +164,32 @@ func (v *Vault) Unlock(masterPassword string) error {
return fmt.Errorf("decode salt: %w", err) return fmt.Errorf("decode salt: %w", err)
} }
key := argon2.IDKey([]byte(masterPassword), salt, uint32(vf.KDF.Iterations), uint32(vf.KDF.MemoryKiB)*1024, uint8(vf.KDF.Parallelism), keyLen) key := argon2.IDKey([]byte(masterPassword), salt, uint32(vf.KDF.Iterations), uint32(vf.KDF.MemoryKiB), uint8(vf.KDF.Parallelism), keyLen)
// Try to decrypt first record to verify password if vf.Verifier != nil {
if len(vf.Records) > 0 { if err := verifyRecord(key, *vf.Verifier); err != nil {
return fmt.Errorf("invalid master password")
}
} else if len(vf.Records) > 0 {
if _, err := decryptRecord(key, vf.Records[0]); err != nil { if _, err := decryptRecord(key, vf.Records[0]); err != nil {
return fmt.Errorf("invalid master password") return fmt.Errorf("invalid master password")
} }
} else {
return fmt.Errorf("vault cannot verify master password; recreate empty vault")
} }
v.masterKey = key v.masterKey = key
v.records = make(map[string][]byte) v.records = make(map[string]secretRecord)
for _, rec := range vf.Records { for _, rec := range vf.Records {
plaintext, err := decryptRecord(key, rec) plaintext, err := decryptRecord(key, rec)
if err != nil { if err != nil {
return fmt.Errorf("decrypt record %s: %w", rec.ID, err) return fmt.Errorf("decrypt record %s: %w", rec.ID, err)
} }
v.records[rec.ID] = plaintext v.records[rec.ID] = secretRecord{
secretType: inferSecretType(rec.ID, rec.Type),
plaintext: plaintext,
}
} }
fmt.Println(" done.") fmt.Println(" done.")
@ -178,7 +207,7 @@ func (v *Vault) Lock() {
} }
} }
v.masterKey = nil v.masterKey = nil
v.records = make(map[string][]byte) v.records = make(map[string]secretRecord)
} }
// IsUnlocked returns whether the vault is currently unlocked // IsUnlocked returns whether the vault is currently unlocked
@ -197,7 +226,9 @@ func (v *Vault) Put(id string, secretType string, plaintext []byte) error {
return fmt.Errorf("vault is locked") return fmt.Errorf("vault is locked")
} }
v.records[id] = plaintext data := make([]byte, len(plaintext))
copy(data, plaintext)
v.records[id] = secretRecord{secretType: secretType, plaintext: data}
v.modified = true v.modified = true
return nil return nil
} }
@ -211,16 +242,55 @@ func (v *Vault) Get(id string) ([]byte, error) {
return nil, fmt.Errorf("vault is locked") return nil, fmt.Errorf("vault is locked")
} }
data, ok := v.records[id] record, ok := v.records[id]
if !ok { if !ok {
return nil, fmt.Errorf("secret not found: %s", id) return nil, fmt.Errorf("secret not found: %s", id)
} }
result := make([]byte, len(data)) result := make([]byte, len(record.plaintext))
copy(result, data) copy(result, record.plaintext)
return result, nil return result, nil
} }
func (v *Vault) HasSecret(id string) bool {
v.mu.Lock()
defer v.mu.Unlock()
_, ok := v.records[id]
return ok
}
func (v *Vault) ListSecrets() ([]SecretMeta, error) {
v.mu.Lock()
defer v.mu.Unlock()
if v.masterKey == nil {
return nil, fmt.Errorf("vault is locked")
}
metas := make([]SecretMeta, 0, len(v.records))
for id, record := range v.records {
alias, secretType, ok := parseServerSecretID(id)
if !ok {
continue
}
if record.secretType != "" {
secretType = record.secretType
}
metas = append(metas, SecretMeta{
ID: id,
Alias: alias,
Type: secretType,
})
}
sort.Slice(metas, func(i, j int) bool {
if metas[i].Alias == metas[j].Alias {
return metas[i].Type < metas[j].Type
}
return metas[i].Alias < metas[j].Alias
})
return metas, nil
}
// Delete removes a secret // Delete removes a secret
func (v *Vault) Delete(id string) { func (v *Vault) Delete(id string) {
v.mu.Lock() v.mu.Lock()
@ -245,8 +315,8 @@ func (v *Vault) Save() error {
kdf := KDFMeta{ kdf := KDFMeta{
Name: "argon2id", Name: "argon2id",
MemoryKiB: 4096, MemoryKiB: 65536,
Iterations: 2, Iterations: 3,
Parallelism: 1, Parallelism: 1,
Salt: base64.StdEncoding.EncodeToString(salt), Salt: base64.StdEncoding.EncodeToString(salt),
} }
@ -254,18 +324,24 @@ func (v *Vault) Save() error {
fmt.Print("Deriving key...") fmt.Print("Deriving key...")
var records []Record var records []Record
for id, plaintext := range v.records { for id, record := range v.records {
rec, err := encryptRecord(v.masterKey, id, plaintext) rec, err := encryptRecordWithType(v.masterKey, id, record.secretType, record.plaintext)
if err != nil { if err != nil {
return fmt.Errorf("encrypt record %s: %w", id, err) return fmt.Errorf("encrypt record %s: %w", id, err)
} }
records = append(records, rec) records = append(records, rec)
} }
verifier, err := newVerifierRecord(v.masterKey)
if err != nil {
return fmt.Errorf("create verifier: %w", err)
}
vf := VaultFile{ vf := VaultFile{
Version: currentVersion, Version: currentVersion,
KDF: kdf, KDF: kdf,
Records: records, Verifier: &verifier,
Records: records,
} }
data, err := json.Marshal(vf) data, err := json.Marshal(vf)
@ -309,12 +385,12 @@ func (v *Vault) ChangePassword(newPassword string) error {
return fmt.Errorf("generate salt: %w", err) return fmt.Errorf("generate salt: %w", err)
} }
newKey := argon2.IDKey([]byte(newPassword), salt, 3, 8192*1024, 1, keyLen) newKey := argon2.IDKey([]byte(newPassword), salt, 3, 65536, 1, keyLen)
kdf := KDFMeta{ kdf := KDFMeta{
Name: "argon2id", Name: "argon2id",
MemoryKiB: 4096, MemoryKiB: 65536,
Iterations: 2, Iterations: 3,
Parallelism: 1, Parallelism: 1,
Salt: base64.StdEncoding.EncodeToString(salt), Salt: base64.StdEncoding.EncodeToString(salt),
} }
@ -322,18 +398,24 @@ func (v *Vault) ChangePassword(newPassword string) error {
fmt.Print("Deriving key...") fmt.Print("Deriving key...")
var records []Record var records []Record
for id, plaintext := range v.records { for id, record := range v.records {
rec, err := encryptRecord(newKey, id, plaintext) rec, err := encryptRecordWithType(newKey, id, record.secretType, record.plaintext)
if err != nil { if err != nil {
return fmt.Errorf("encrypt record: %w", err) return fmt.Errorf("encrypt record: %w", err)
} }
records = append(records, rec) records = append(records, rec)
} }
verifier, err := newVerifierRecord(newKey)
if err != nil {
return fmt.Errorf("create verifier: %w", err)
}
vf := VaultFile{ vf := VaultFile{
Version: currentVersion, Version: currentVersion,
KDF: kdf, KDF: kdf,
Records: records, Verifier: &verifier,
Records: records,
} }
data, err := json.Marshal(vf) data, err := json.Marshal(vf)
@ -382,6 +464,10 @@ func (v *Vault) getSalt() string {
} }
func encryptRecord(key []byte, id string, plaintext []byte) (Record, error) { func encryptRecord(key []byte, id string, plaintext []byte) (Record, error) {
return encryptRecordWithType(key, id, "", plaintext)
}
func encryptRecordWithType(key []byte, id string, secretType string, plaintext []byte) (Record, error) {
aead, err := chacha20poly1305.NewX(key) aead, err := chacha20poly1305.NewX(key)
if err != nil { if err != nil {
return Record{}, err return Record{}, err
@ -396,11 +482,31 @@ func encryptRecord(key []byte, id string, plaintext []byte) (Record, error) {
return Record{ return Record{
ID: id, ID: id,
Type: secretType,
Nonce: base64.StdEncoding.EncodeToString(nonce), Nonce: base64.StdEncoding.EncodeToString(nonce),
Ciphertext: base64.StdEncoding.EncodeToString(ciphertext), Ciphertext: base64.StdEncoding.EncodeToString(ciphertext),
}, nil }, nil
} }
func inferSecretType(id string, recordType string) string {
if recordType != "" {
return recordType
}
_, secretType, ok := parseServerSecretID(id)
if !ok {
return ""
}
return secretType
}
func parseServerSecretID(id string) (string, string, bool) {
parts := strings.Split(id, ":")
if len(parts) != 3 || parts[0] != "server" || parts[1] == "" || parts[2] == "" {
return "", "", false
}
return parts[1], parts[2], true
}
func decryptRecord(key []byte, rec Record) ([]byte, error) { func decryptRecord(key []byte, rec Record) ([]byte, error) {
aead, err := chacha20poly1305.NewX(key) aead, err := chacha20poly1305.NewX(key)
if err != nil { if err != nil {
@ -425,6 +531,26 @@ func decryptRecord(key []byte, rec Record) ([]byte, error) {
return plaintext, nil return plaintext, nil
} }
func newVerifierRecord(key []byte) (Record, error) {
rec, err := encryptRecord(key, verifierID, []byte(verifierPlaintext))
if err != nil {
return Record{}, err
}
rec.Type = "verifier"
return rec, nil
}
func verifyRecord(key []byte, rec Record) error {
plaintext, err := decryptRecord(key, rec)
if err != nil {
return err
}
if !SecureCompare(string(plaintext), verifierPlaintext) {
return fmt.Errorf("invalid verifier")
}
return nil
}
// VerifyPassword checks if a master password is correct without unlocking // VerifyPassword checks if a master password is correct without unlocking
func VerifyPassword(path string, masterPassword string) (bool, error) { func VerifyPassword(path string, masterPassword string) (bool, error) {
data, err := os.ReadFile(path) data, err := os.ReadFile(path)
@ -442,18 +568,20 @@ func VerifyPassword(path string, masterPassword string) (bool, error) {
return false, err return false, err
} }
key := argon2.IDKey([]byte(masterPassword), salt, uint32(vf.KDF.Iterations), uint32(vf.KDF.MemoryKiB)*1024, uint8(vf.KDF.Parallelism), keyLen) key := argon2.IDKey([]byte(masterPassword), salt, uint32(vf.KDF.Iterations), uint32(vf.KDF.MemoryKiB), uint8(vf.KDF.Parallelism), keyLen)
defer func() { defer func() {
for i := range key { for i := range key {
key[i] = 0 key[i] = 0
} }
}() }()
if len(vf.Records) == 0 { if vf.Verifier != nil {
// Empty vault, try a test encryption return verifyRecord(key, *vf.Verifier) == nil, nil
return true, nil
} }
if len(vf.Records) == 0 {
return false, nil
}
_, err = decryptRecord(key, vf.Records[0]) _, err = decryptRecord(key, vf.Records[0])
return err == nil, nil return err == nil, nil
} }

View File

@ -0,0 +1,218 @@
package vault
import (
"encoding/base64"
"encoding/json"
"os"
"path/filepath"
"testing"
"golang.org/x/crypto/argon2"
)
func TestNewEmptyVaultRejectsWrongPassword(t *testing.T) {
path := filepath.Join(t.TempDir(), "vault.bin")
if err := Create(path, "correct horse"); err != nil {
t.Fatalf("create vault: %v", err)
}
ok, err := VerifyPassword(path, "wrong horse")
if err != nil {
t.Fatalf("verify wrong password: %v", err)
}
if ok {
t.Fatal("expected wrong password to be rejected for a new empty vault")
}
v := New(path)
if err := v.Unlock("wrong horse"); err == nil {
t.Fatal("expected unlock with wrong password to fail for a new empty vault")
}
}
func TestNewEmptyVaultAcceptsCorrectPassword(t *testing.T) {
path := filepath.Join(t.TempDir(), "vault.bin")
if err := Create(path, "correct horse"); err != nil {
t.Fatalf("create vault: %v", err)
}
ok, err := VerifyPassword(path, "correct horse")
if err != nil {
t.Fatalf("verify correct password: %v", err)
}
if !ok {
t.Fatal("expected correct password to be accepted for a new empty vault")
}
v := New(path)
if err := v.Unlock("correct horse"); err != nil {
t.Fatalf("unlock with correct password: %v", err)
}
}
func TestLegacyVaultWithRecordsStillVerifiesByFirstRecord(t *testing.T) {
path := filepath.Join(t.TempDir(), "vault.bin")
salt := []byte("12345678901234567890123456789012")
key := argon2.IDKey([]byte("correct horse"), salt, 3, 65536, 1, keyLen)
rec, err := encryptRecord(key, "server:test:ssh_password", []byte("secret"))
if err != nil {
t.Fatalf("encrypt legacy record: %v", err)
}
data, err := json.Marshal(VaultFile{
Version: currentVersion,
KDF: KDFMeta{
Name: "argon2id",
MemoryKiB: 65536,
Iterations: 3,
Parallelism: 1,
Salt: base64.StdEncoding.EncodeToString(salt),
},
Records: []Record{rec},
})
if err != nil {
t.Fatalf("marshal legacy vault: %v", err)
}
if err := os.WriteFile(path, data, 0o600); err != nil {
t.Fatalf("write legacy vault: %v", err)
}
ok, err := VerifyPassword(path, "correct horse")
if err != nil {
t.Fatalf("verify legacy vault: %v", err)
}
if !ok {
t.Fatal("expected legacy vault with records to accept correct password")
}
ok, err = VerifyPassword(path, "wrong horse")
if err != nil {
t.Fatalf("verify legacy vault with wrong password: %v", err)
}
if ok {
t.Fatal("expected legacy vault with records to reject wrong password")
}
}
func TestLegacyEmptyVaultWithoutVerifierCannotUnlock(t *testing.T) {
path := filepath.Join(t.TempDir(), "vault.bin")
salt := []byte("12345678901234567890123456789012")
data, err := json.Marshal(VaultFile{
Version: currentVersion,
KDF: KDFMeta{
Name: "argon2id",
MemoryKiB: 65536,
Iterations: 3,
Parallelism: 1,
Salt: base64.StdEncoding.EncodeToString(salt),
},
Records: []Record{},
})
if err != nil {
t.Fatalf("marshal legacy empty vault: %v", err)
}
if err := os.WriteFile(path, data, 0o600); err != nil {
t.Fatalf("write legacy empty vault: %v", err)
}
ok, err := VerifyPassword(path, "any password")
if err != nil {
t.Fatalf("verify legacy empty vault: %v", err)
}
if ok {
t.Fatal("expected legacy empty vault without verifier to be unverifiable")
}
v := New(path)
if err := v.Unlock("any password"); err == nil {
t.Fatal("expected legacy empty vault without verifier to reject unlock")
}
}
func TestListSecretsReturnsMetadataWithoutPlaintext(t *testing.T) {
path := filepath.Join(t.TempDir(), "vault.bin")
if err := Create(path, "master"); err != nil {
t.Fatalf("create vault: %v", err)
}
v := New(path)
if err := v.Unlock("master"); err != nil {
t.Fatalf("unlock vault: %v", err)
}
if err := v.Put("server:prod:ssh_password", "ssh_password", []byte("secret-password")); err != nil {
t.Fatalf("put password: %v", err)
}
if err := v.Put("server:prod:key_passphrase", "key_passphrase", []byte("secret-passphrase")); err != nil {
t.Fatalf("put passphrase: %v", err)
}
metas, err := v.ListSecrets()
if err != nil {
t.Fatalf("list secrets: %v", err)
}
if len(metas) != 2 {
t.Fatalf("expected 2 secret metadata records, got %d: %#v", len(metas), metas)
}
if metas[0].Alias != "prod" || metas[0].Type != "key_passphrase" {
t.Fatalf("unexpected first metadata record: %#v", metas[0])
}
if metas[1].Alias != "prod" || metas[1].Type != "ssh_password" {
t.Fatalf("unexpected second metadata record: %#v", metas[1])
}
}
func TestListSecretsPreservesTypesAfterSaveAndUnlock(t *testing.T) {
path := filepath.Join(t.TempDir(), "vault.bin")
if err := Create(path, "master"); err != nil {
t.Fatalf("create vault: %v", err)
}
v := New(path)
if err := v.Unlock("master"); err != nil {
t.Fatalf("unlock vault: %v", err)
}
if err := v.Put("server:prod:ssh_password", "ssh_password", []byte("secret-password")); err != nil {
t.Fatalf("put password: %v", err)
}
if err := v.Save(); err != nil {
t.Fatalf("save vault: %v", err)
}
reopened := New(path)
if err := reopened.Unlock("master"); err != nil {
t.Fatalf("unlock reopened vault: %v", err)
}
metas, err := reopened.ListSecrets()
if err != nil {
t.Fatalf("list reopened secrets: %v", err)
}
if len(metas) != 1 {
t.Fatalf("expected 1 secret metadata record, got %d: %#v", len(metas), metas)
}
if metas[0].ID != "server:prod:ssh_password" || metas[0].Alias != "prod" || metas[0].Type != "ssh_password" {
t.Fatalf("unexpected metadata after reopen: %#v", metas[0])
}
}
func TestHasSecretReportsPresenceWithoutReturningValue(t *testing.T) {
path := filepath.Join(t.TempDir(), "vault.bin")
if err := Create(path, "master"); err != nil {
t.Fatalf("create vault: %v", err)
}
v := New(path)
if err := v.Unlock("master"); err != nil {
t.Fatalf("unlock vault: %v", err)
}
if err := v.Put("server:prod:ssh_password", "ssh_password", []byte("secret-password")); err != nil {
t.Fatalf("put password: %v", err)
}
if !v.HasSecret("server:prod:ssh_password") {
t.Fatal("expected saved password to be reported present")
}
if v.HasSecret("server:prod:key_passphrase") {
t.Fatal("expected missing passphrase to be reported absent")
}
}