feat: improve tui and vault handling
This commit is contained in:
parent
73c60b9f93
commit
e1d709396b
51
cmd/add.go
51
cmd/add.go
|
|
@ -3,9 +3,11 @@ package cmd
|
|||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/mirivlad/sshkeeper/internal/model"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
var addFlags struct {
|
||||
|
|
@ -19,7 +21,6 @@ var addFlags struct {
|
|||
displayName string
|
||||
notes string
|
||||
tags string
|
||||
password string
|
||||
}
|
||||
|
||||
var addCmd = &cobra.Command{
|
||||
|
|
@ -59,11 +60,21 @@ func addNonInteractive(alias string) error {
|
|||
server.DisplayName = alias
|
||||
}
|
||||
|
||||
// Handle password auth - store in vault
|
||||
if server.AuthMethod == model.AuthPassword {
|
||||
password := addFlags.password
|
||||
if password == "" {
|
||||
return fmt.Errorf("password auth requires --password flag or interactive mode")
|
||||
// Handle password/passphrase auth — request interactively, never via argv
|
||||
if server.AuthMethod == model.AuthPassword || server.AuthMethod == model.AuthKeyPassphrase {
|
||||
secretType := "password"
|
||||
if server.AuthMethod == model.AuthKeyPassphrase {
|
||||
secretType = "passphrase"
|
||||
}
|
||||
|
||||
fmt.Printf("Enter %s (will be stored in vault, input hidden): ", secretType)
|
||||
password, err := term.ReadPassword(int(syscall.Stdin))
|
||||
fmt.Println()
|
||||
if err != nil {
|
||||
return fmt.Errorf("read %s: %w", secretType, err)
|
||||
}
|
||||
if len(password) == 0 {
|
||||
return fmt.Errorf("%s cannot be empty", secretType)
|
||||
}
|
||||
|
||||
v := getOrCreateVault()
|
||||
|
|
@ -72,29 +83,14 @@ func addNonInteractive(alias string) error {
|
|||
}
|
||||
|
||||
vaultKey := fmt.Sprintf("server:%s:ssh_password", alias)
|
||||
if err := v.Put(vaultKey, "ssh_password", []byte(password)); err != nil {
|
||||
return fmt.Errorf("store password in vault: %w", err)
|
||||
}
|
||||
if err := v.Save(); err != nil {
|
||||
return fmt.Errorf("save vault: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle key+passphrase - store passphrase in vault
|
||||
if server.AuthMethod == model.AuthKeyPassphrase {
|
||||
passphrase := addFlags.password
|
||||
if passphrase == "" {
|
||||
return fmt.Errorf("key+passphrase auth requires --password flag for the passphrase")
|
||||
vaultType := "ssh_password"
|
||||
if server.AuthMethod == model.AuthKeyPassphrase {
|
||||
vaultKey = fmt.Sprintf("server:%s:key_passphrase", alias)
|
||||
vaultType = "key_passphrase"
|
||||
}
|
||||
|
||||
v := getOrCreateVault()
|
||||
if !v.IsUnlocked() {
|
||||
return fmt.Errorf("vault is locked. Run 'sshkeeper vault unlock' first")
|
||||
}
|
||||
|
||||
vaultKey := fmt.Sprintf("server:%s:key_passphrase", alias)
|
||||
if err := v.Put(vaultKey, "key_passphrase", []byte(passphrase)); err != nil {
|
||||
return fmt.Errorf("store passphrase in vault: %w", err)
|
||||
if err := v.Put(vaultKey, vaultType, password); err != nil {
|
||||
return fmt.Errorf("store %s in vault: %w", secretType, err)
|
||||
}
|
||||
if err := v.Save(); err != nil {
|
||||
return fmt.Errorf("save vault: %w", err)
|
||||
|
|
@ -132,5 +128,4 @@ func init() {
|
|||
addCmd.Flags().StringVar(&addFlags.displayName, "display-name", "", "Display name")
|
||||
addCmd.Flags().StringVar(&addFlags.notes, "notes", "", "Notes")
|
||||
addCmd.Flags().StringVar(&addFlags.tags, "tags", "", "Comma-separated tags")
|
||||
addCmd.Flags().StringVar(&addFlags.password, "password", "", "SSH password or key passphrase (stored in vault)")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,6 +27,15 @@ var deleteCmd = &cobra.Command{
|
|||
return fmt.Errorf("delete server: %w", err)
|
||||
}
|
||||
|
||||
// Clean up vault secrets for this server
|
||||
v := getOrCreateVault()
|
||||
if v.IsUnlocked() {
|
||||
cleanupServerSecrets(v, alias)
|
||||
if err := v.Save(); err != nil {
|
||||
return fmt.Errorf("save vault after cleanup: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("Deleted.")
|
||||
return nil
|
||||
},
|
||||
|
|
|
|||
57
cmd/edit.go
57
cmd/edit.go
|
|
@ -2,9 +2,11 @@ package cmd
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"syscall"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/mirivlad/sshkeeper/internal/model"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
var editCmd = &cobra.Command{
|
||||
|
|
@ -18,6 +20,8 @@ var editCmd = &cobra.Command{
|
|||
return fmt.Errorf("server not found: %s", alias)
|
||||
}
|
||||
|
||||
oldAuthMethod := server.AuthMethod
|
||||
|
||||
if parsedHost != "" {
|
||||
server.Host = parsedHost
|
||||
}
|
||||
|
|
@ -46,6 +50,41 @@ var editCmd = &cobra.Command{
|
|||
server.Notes = parsedNotes
|
||||
}
|
||||
|
||||
if parsedAuth != "" && oldAuthMethod != server.AuthMethod {
|
||||
v := getOrCreateVault()
|
||||
if v.IsUnlocked() {
|
||||
var secret string
|
||||
if server.AuthMethod == model.AuthPassword {
|
||||
fmt.Print("Enter new password (stored in vault, input hidden): ")
|
||||
pw, err := term.ReadPassword(int(syscall.Stdin))
|
||||
fmt.Println()
|
||||
if err != nil {
|
||||
return fmt.Errorf("read password: %w", err)
|
||||
}
|
||||
if len(pw) > 0 {
|
||||
secret = string(pw)
|
||||
}
|
||||
} else if server.AuthMethod == model.AuthKeyPassphrase {
|
||||
fmt.Print("Enter key passphrase (stored in vault, input hidden): ")
|
||||
pw, err := term.ReadPassword(int(syscall.Stdin))
|
||||
fmt.Println()
|
||||
if err != nil {
|
||||
return fmt.Errorf("read passphrase: %w", err)
|
||||
}
|
||||
if len(pw) > 0 {
|
||||
secret = string(pw)
|
||||
}
|
||||
}
|
||||
|
||||
if err := syncServerSecrets(v, alias, server, secret); err != nil {
|
||||
return fmt.Errorf("sync vault secrets: %w", err)
|
||||
}
|
||||
if err := v.Save(); err != nil {
|
||||
return fmt.Errorf("save vault: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := appDB.UpdateServer(server); err != nil {
|
||||
return fmt.Errorf("update server: %w", err)
|
||||
}
|
||||
|
|
@ -56,15 +95,15 @@ var editCmd = &cobra.Command{
|
|||
}
|
||||
|
||||
var (
|
||||
parsedHost string
|
||||
parsedPort int
|
||||
parsedUser string
|
||||
parsedAuth string
|
||||
parsedIdentity string
|
||||
parsedProxyJump string
|
||||
parsedGroup string
|
||||
parsedHost string
|
||||
parsedPort int
|
||||
parsedUser string
|
||||
parsedAuth string
|
||||
parsedIdentity string
|
||||
parsedProxyJump string
|
||||
parsedGroup string
|
||||
parsedDisplayName string
|
||||
parsedNotes string
|
||||
parsedNotes string
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
|
|
|||
35
cmd/extra.go
35
cmd/extra.go
|
|
@ -74,8 +74,13 @@ var runCmd = &cobra.Command{
|
|||
return fmt.Errorf("server not found: %s", alias)
|
||||
}
|
||||
|
||||
// Build ssh args with the command
|
||||
sshArgs := buildSSHArgs(server)
|
||||
// For password auth, use PTY-wrapper with command
|
||||
if server.AuthMethod == model.AuthPassword {
|
||||
return runWithPassword(server, command)
|
||||
}
|
||||
|
||||
// For key/agent auth — direct execution
|
||||
sshArgs := ssh.BuildSSHArgs(server)
|
||||
sshArgs = append(sshArgs, command)
|
||||
|
||||
sshCmd := exec.Command(cfg.SSH.Binary, sshArgs...)
|
||||
|
|
@ -91,17 +96,21 @@ var runCmd = &cobra.Command{
|
|||
},
|
||||
}
|
||||
|
||||
func buildSSHArgs(server *model.Server) []string {
|
||||
var args []string
|
||||
args = append(args, "-p", fmt.Sprintf("%d", server.Port))
|
||||
if server.IdentityFile != "" {
|
||||
args = append(args, "-i", server.IdentityFile)
|
||||
// runWithPassword runs a command on a server with password auth via PTY-wrapper.
|
||||
func runWithPassword(server *model.Server, command string) error {
|
||||
v := getOrCreateVault()
|
||||
if !v.IsUnlocked() {
|
||||
return fmt.Errorf("vault is locked. Run 'sshkeeper vault unlock' first")
|
||||
}
|
||||
if server.ProxyJump != "" {
|
||||
args = append(args, "-J", server.ProxyJump)
|
||||
|
||||
vaultKey := fmt.Sprintf("server:%s:ssh_password", server.Alias)
|
||||
password, err := v.Get(vaultKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get password from vault: %w", err)
|
||||
}
|
||||
args = append(args, "-o", "StrictHostKeyChecking=accept-new")
|
||||
target := fmt.Sprintf("%s@%s", server.User, server.Host)
|
||||
args = append(args, target)
|
||||
return args
|
||||
|
||||
sshArgs := ssh.BuildSSHArgs(server)
|
||||
sshArgs = append(sshArgs, command)
|
||||
|
||||
return ssh.ConnectWithPassword(cfg.SSH.Binary, sshArgs, string(password))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,85 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/mirivlad/sshkeeper/internal/model"
|
||||
"github.com/mirivlad/sshkeeper/internal/ssh"
|
||||
"github.com/mirivlad/sshkeeper/internal/vault"
|
||||
)
|
||||
|
||||
const (
|
||||
secretSSHPassword = "ssh_password"
|
||||
secretKeyPassphrase = "key_passphrase"
|
||||
secretSudoPassword = "sudo_password"
|
||||
)
|
||||
|
||||
var serverSecretTypes = []string{
|
||||
secretSSHPassword,
|
||||
secretKeyPassphrase,
|
||||
secretSudoPassword,
|
||||
}
|
||||
|
||||
func serverSecretID(alias, secretType string) string {
|
||||
return fmt.Sprintf("server:%s:%s", alias, secretType)
|
||||
}
|
||||
|
||||
func cleanupServerSecrets(v *vault.Vault, alias string) {
|
||||
for _, secretType := range serverSecretTypes {
|
||||
v.Delete(serverSecretID(alias, secretType))
|
||||
}
|
||||
}
|
||||
|
||||
func syncServerSecrets(v *vault.Vault, oldAlias string, server *model.Server, secret string) error {
|
||||
if oldAlias == "" {
|
||||
oldAlias = server.Alias
|
||||
}
|
||||
if oldAlias != server.Alias {
|
||||
for _, secretType := range serverSecretTypes {
|
||||
oldID := serverSecretID(oldAlias, secretType)
|
||||
data, err := v.Get(oldID)
|
||||
if err == nil {
|
||||
if err := v.Put(serverSecretID(server.Alias, secretType), secretType, data); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
v.Delete(oldID)
|
||||
}
|
||||
}
|
||||
|
||||
switch server.AuthMethod {
|
||||
case model.AuthPassword:
|
||||
v.Delete(serverSecretID(server.Alias, secretKeyPassphrase))
|
||||
if secret != "" {
|
||||
return v.Put(serverSecretID(server.Alias, secretSSHPassword), secretSSHPassword, []byte(secret))
|
||||
}
|
||||
case model.AuthKeyPassphrase:
|
||||
v.Delete(serverSecretID(server.Alias, secretSSHPassword))
|
||||
if secret != "" {
|
||||
return v.Put(serverSecretID(server.Alias, secretKeyPassphrase), secretKeyPassphrase, []byte(secret))
|
||||
}
|
||||
default:
|
||||
v.Delete(serverSecretID(server.Alias, secretSSHPassword))
|
||||
v.Delete(serverSecretID(server.Alias, secretKeyPassphrase))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func deleteVaultSecrets(v *vault.Vault, alias string, secretType string) error {
|
||||
if secretType != "" {
|
||||
v.Delete(serverSecretID(alias, secretType))
|
||||
return nil
|
||||
}
|
||||
cleanupServerSecrets(v, alias)
|
||||
return nil
|
||||
}
|
||||
|
||||
func formTestVaultFunc(getVault ssh.VaultFunc, server *model.Server, formSecret string) ssh.VaultFunc {
|
||||
return func(serverAlias string, secretType string) (string, error) {
|
||||
if (secretType == secretSSHPassword || secretType == secretKeyPassphrase) && formSecret != "" {
|
||||
return formSecret, nil
|
||||
}
|
||||
return getVault(serverAlias, secretType)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,186 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/mirivlad/sshkeeper/internal/model"
|
||||
"github.com/mirivlad/sshkeeper/internal/vault"
|
||||
)
|
||||
|
||||
func newUnlockedTestVault(t *testing.T) *vault.Vault {
|
||||
t.Helper()
|
||||
path := filepath.Join(t.TempDir(), "vault.bin")
|
||||
if err := vault.Create(path, "master"); err != nil {
|
||||
t.Fatalf("create vault: %v", err)
|
||||
}
|
||||
v := vault.New(path)
|
||||
if err := v.Unlock("master"); err != nil {
|
||||
t.Fatalf("unlock vault: %v", err)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func mustPutSecret(t *testing.T, v *vault.Vault, alias, secretType, value string) {
|
||||
t.Helper()
|
||||
if err := v.Put(serverSecretID(alias, secretType), secretType, []byte(value)); err != nil {
|
||||
t.Fatalf("put %s: %v", secretType, err)
|
||||
}
|
||||
}
|
||||
|
||||
func mustGetSecret(t *testing.T, v *vault.Vault, alias, secretType string) string {
|
||||
t.Helper()
|
||||
data, err := v.Get(serverSecretID(alias, secretType))
|
||||
if err != nil {
|
||||
t.Fatalf("get %s: %v", secretType, err)
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
|
||||
func assertSecretMissing(t *testing.T, v *vault.Vault, alias, secretType string) {
|
||||
t.Helper()
|
||||
if data, err := v.Get(serverSecretID(alias, secretType)); err == nil {
|
||||
t.Fatalf("expected %s for %s to be missing, got %q", secretType, alias, string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanupServerSecretsRemovesAuthAndSudoSecrets(t *testing.T) {
|
||||
v := newUnlockedTestVault(t)
|
||||
mustPutSecret(t, v, "prod", "ssh_password", "password")
|
||||
mustPutSecret(t, v, "prod", "key_passphrase", "passphrase")
|
||||
mustPutSecret(t, v, "prod", "sudo_password", "sudo")
|
||||
|
||||
cleanupServerSecrets(v, "prod")
|
||||
|
||||
assertSecretMissing(t, v, "prod", "ssh_password")
|
||||
assertSecretMissing(t, v, "prod", "key_passphrase")
|
||||
assertSecretMissing(t, v, "prod", "sudo_password")
|
||||
}
|
||||
|
||||
func TestSyncServerSecretsDeletesAuthSecretsForKeyAuth(t *testing.T) {
|
||||
v := newUnlockedTestVault(t)
|
||||
mustPutSecret(t, v, "prod", "ssh_password", "password")
|
||||
mustPutSecret(t, v, "prod", "key_passphrase", "passphrase")
|
||||
|
||||
server := &model.Server{Alias: "prod", AuthMethod: model.AuthKey}
|
||||
if err := syncServerSecrets(v, "prod", server, ""); err != nil {
|
||||
t.Fatalf("sync secrets: %v", err)
|
||||
}
|
||||
|
||||
assertSecretMissing(t, v, "prod", "ssh_password")
|
||||
assertSecretMissing(t, v, "prod", "key_passphrase")
|
||||
}
|
||||
|
||||
func TestSyncServerSecretsKeepsExistingPasswordWhenEditPasswordIsBlank(t *testing.T) {
|
||||
v := newUnlockedTestVault(t)
|
||||
mustPutSecret(t, v, "prod", "ssh_password", "old-password")
|
||||
mustPutSecret(t, v, "prod", "key_passphrase", "old-passphrase")
|
||||
|
||||
server := &model.Server{Alias: "prod", AuthMethod: model.AuthPassword}
|
||||
if err := syncServerSecrets(v, "prod", server, ""); err != nil {
|
||||
t.Fatalf("sync secrets: %v", err)
|
||||
}
|
||||
|
||||
if got := mustGetSecret(t, v, "prod", "ssh_password"); got != "old-password" {
|
||||
t.Fatalf("expected password to remain, got %q", got)
|
||||
}
|
||||
assertSecretMissing(t, v, "prod", "key_passphrase")
|
||||
}
|
||||
|
||||
func TestSyncServerSecretsRenamesSecretsBeforeApplyingAuthCleanup(t *testing.T) {
|
||||
v := newUnlockedTestVault(t)
|
||||
mustPutSecret(t, v, "old", "ssh_password", "password")
|
||||
mustPutSecret(t, v, "old", "key_passphrase", "passphrase")
|
||||
mustPutSecret(t, v, "old", "sudo_password", "sudo")
|
||||
|
||||
server := &model.Server{Alias: "new", AuthMethod: model.AuthKeyPassphrase}
|
||||
if err := syncServerSecrets(v, "old", server, ""); err != nil {
|
||||
t.Fatalf("sync secrets: %v", err)
|
||||
}
|
||||
|
||||
assertSecretMissing(t, v, "old", "ssh_password")
|
||||
assertSecretMissing(t, v, "old", "key_passphrase")
|
||||
assertSecretMissing(t, v, "old", "sudo_password")
|
||||
assertSecretMissing(t, v, "new", "ssh_password")
|
||||
if got := mustGetSecret(t, v, "new", "key_passphrase"); got != "passphrase" {
|
||||
t.Fatalf("expected key passphrase to move, got %q", got)
|
||||
}
|
||||
if got := mustGetSecret(t, v, "new", "sudo_password"); got != "sudo" {
|
||||
t.Fatalf("expected sudo password to move, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncServerSecretsStoresNewSecretForSelectedAuthMethod(t *testing.T) {
|
||||
v := newUnlockedTestVault(t)
|
||||
mustPutSecret(t, v, "prod", "key_passphrase", "old-passphrase")
|
||||
|
||||
server := &model.Server{Alias: "prod", AuthMethod: model.AuthPassword}
|
||||
if err := syncServerSecrets(v, "prod", server, "new-password"); err != nil {
|
||||
t.Fatalf("sync secrets: %v", err)
|
||||
}
|
||||
|
||||
if got := mustGetSecret(t, v, "prod", "ssh_password"); got != "new-password" {
|
||||
t.Fatalf("expected new password, got %q", got)
|
||||
}
|
||||
assertSecretMissing(t, v, "prod", "key_passphrase")
|
||||
}
|
||||
|
||||
func TestDeleteVaultSecretsForAliasAndType(t *testing.T) {
|
||||
v := newUnlockedTestVault(t)
|
||||
mustPutSecret(t, v, "prod", "ssh_password", "password")
|
||||
mustPutSecret(t, v, "prod", "key_passphrase", "passphrase")
|
||||
|
||||
if err := deleteVaultSecrets(v, "prod", "ssh_password"); err != nil {
|
||||
t.Fatalf("delete vault secret: %v", err)
|
||||
}
|
||||
|
||||
assertSecretMissing(t, v, "prod", "ssh_password")
|
||||
if got := mustGetSecret(t, v, "prod", "key_passphrase"); got != "passphrase" {
|
||||
t.Fatalf("expected passphrase to remain, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteVaultSecretsForAlias(t *testing.T) {
|
||||
v := newUnlockedTestVault(t)
|
||||
mustPutSecret(t, v, "prod", "ssh_password", "password")
|
||||
mustPutSecret(t, v, "prod", "key_passphrase", "passphrase")
|
||||
|
||||
if err := deleteVaultSecrets(v, "prod", ""); err != nil {
|
||||
t.Fatalf("delete vault secrets: %v", err)
|
||||
}
|
||||
|
||||
assertSecretMissing(t, v, "prod", "ssh_password")
|
||||
assertSecretMissing(t, v, "prod", "key_passphrase")
|
||||
}
|
||||
|
||||
func TestFormTestVaultFuncUsesSavedSecretWhenFormSecretBlank(t *testing.T) {
|
||||
vaultFunc := formTestVaultFunc(func(serverAlias string, secretType string) (string, error) {
|
||||
if serverAlias != "prod" || secretType != "ssh_password" {
|
||||
return "", fmt.Errorf("unexpected secret lookup %s %s", serverAlias, secretType)
|
||||
}
|
||||
return "saved-password", nil
|
||||
}, &model.Server{Alias: "prod", AuthMethod: model.AuthPassword}, "")
|
||||
|
||||
got, err := vaultFunc("prod", "ssh_password")
|
||||
if err != nil {
|
||||
t.Fatalf("get password: %v", err)
|
||||
}
|
||||
if got != "saved-password" {
|
||||
t.Fatalf("expected saved password, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormTestVaultFuncUsesFormSecretWhenProvided(t *testing.T) {
|
||||
vaultFunc := formTestVaultFunc(func(serverAlias string, secretType string) (string, error) {
|
||||
return "", fmt.Errorf("saved secret should not be used")
|
||||
}, &model.Server{Alias: "prod", AuthMethod: model.AuthPassword}, "typed-password")
|
||||
|
||||
got, err := vaultFunc("prod", "ssh_password")
|
||||
if err != nil {
|
||||
t.Fatalf("get password: %v", err)
|
||||
}
|
||||
if got != "typed-password" {
|
||||
t.Fatalf("expected typed password, got %q", got)
|
||||
}
|
||||
}
|
||||
|
|
@ -2,8 +2,11 @@ package cmd
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/mirivlad/sshkeeper/internal/ssh"
|
||||
)
|
||||
|
||||
var templateCmd = &cobra.Command{
|
||||
|
|
@ -70,13 +73,15 @@ var runTemplateCmd = &cobra.Command{
|
|||
alias := args[0]
|
||||
templateName := args[1]
|
||||
|
||||
server, err := appDB.GetServer(alias)
|
||||
if err != nil {
|
||||
return fmt.Errorf("server not found: %s", alias)
|
||||
}
|
||||
|
||||
templates, err := appDB.GetCommandTemplates(alias)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list templates: %w", err)
|
||||
}
|
||||
if len(templates) == 0 {
|
||||
return fmt.Errorf("server not found or no templates: %s", alias)
|
||||
}
|
||||
|
||||
var command string
|
||||
for _, t := range templates {
|
||||
|
|
@ -91,7 +96,21 @@ var runTemplateCmd = &cobra.Command{
|
|||
}
|
||||
|
||||
fmt.Printf("Running '%s' on %s...\n", command, alias)
|
||||
return nil
|
||||
|
||||
// Build ssh args with the command
|
||||
sshArgs := ssh.BuildSSHArgs(server)
|
||||
sshArgs = append(sshArgs, command)
|
||||
|
||||
sshCmd := exec.Command(cfg.SSH.Binary, sshArgs...)
|
||||
sshCmd.Stdin = os.Stdin
|
||||
sshCmd.Stdout = os.Stdout
|
||||
sshCmd.Stderr = os.Stderr
|
||||
|
||||
if err := sshCmd.Start(); err != nil {
|
||||
return fmt.Errorf("start ssh: %w", err)
|
||||
}
|
||||
|
||||
return sshCmd.Wait()
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
56
cmd/tui.go
56
cmd/tui.go
|
|
@ -36,41 +36,43 @@ func runTUI() error {
|
|||
return appDB.SearchServers(query)
|
||||
}
|
||||
tui.DeleteServer = func(alias string) error {
|
||||
return appDB.DeleteServer(alias)
|
||||
if err := appDB.DeleteServer(alias); err != nil {
|
||||
return err
|
||||
}
|
||||
v := getOrCreateVault()
|
||||
if v.IsUnlocked() {
|
||||
cleanupServerSecrets(v, alias)
|
||||
if err := v.Save(); err != nil {
|
||||
return fmt.Errorf("save vault after cleanup: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
tui.TestConnection = func(server *model.Server) (bool, string) {
|
||||
return ssh.Test(cfg, server, vaultFunc)
|
||||
}
|
||||
tui.TestConnectionWithPassword = func(server *model.Server, password string) (bool, string) {
|
||||
directVaultFunc := func(sa string, st string) (string, error) {
|
||||
if st == "ssh_password" || st == "key_passphrase" {
|
||||
return password, nil
|
||||
}
|
||||
return vaultFunc(sa, st)
|
||||
}
|
||||
return ssh.Test(cfg, server, directVaultFunc)
|
||||
return ssh.Test(cfg, server, formTestVaultFunc(vaultFunc, server, password))
|
||||
}
|
||||
tui.SaveServer = func(server *model.Server, password string) error {
|
||||
if password != "" {
|
||||
v := getOrCreateVault()
|
||||
vaultKey := fmt.Sprintf("server:%s:ssh_password", server.Alias)
|
||||
secretType := "ssh_password"
|
||||
if server.AuthMethod == model.AuthKeyPassphrase {
|
||||
vaultKey = fmt.Sprintf("server:%s:key_passphrase", server.Alias)
|
||||
secretType = "key_passphrase"
|
||||
}
|
||||
if err := v.Put(vaultKey, secretType, []byte(password)); err != nil {
|
||||
return fmt.Errorf("store secret: %w", err)
|
||||
tui.SaveServer = func(server *model.Server, password string, oldAlias string) error {
|
||||
v := getOrCreateVault()
|
||||
if v.IsUnlocked() {
|
||||
if err := syncServerSecrets(v, oldAlias, server, password); err != nil {
|
||||
return fmt.Errorf("sync vault secrets: %w", err)
|
||||
}
|
||||
if err := v.Save(); err != nil {
|
||||
return fmt.Errorf("save vault: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
existing, _ := appDB.GetServer(server.Alias)
|
||||
lookupAlias := server.Alias
|
||||
if oldAlias != "" {
|
||||
lookupAlias = oldAlias
|
||||
}
|
||||
existing, _ := appDB.GetServer(lookupAlias)
|
||||
if existing != nil {
|
||||
server.ID = existing.ID
|
||||
return appDB.UpdateServer(server)
|
||||
return appDB.UpdateServerByAlias(existing.Alias, server)
|
||||
}
|
||||
return appDB.CreateServer(server)
|
||||
}
|
||||
|
|
@ -84,6 +86,16 @@ func runTUI() error {
|
|||
tui.DeleteGroup = func(name string) error {
|
||||
return appDB.DeleteGroup(name)
|
||||
}
|
||||
tui.UpdateTestResult = func(alias string, status model.TestStatus, testErr string) error {
|
||||
return appDB.UpdateTestResult(alias, status, testErr)
|
||||
}
|
||||
tui.HasSecret = func(alias string, secretType string) bool {
|
||||
v := getOrCreateVault()
|
||||
if !v.IsUnlocked() {
|
||||
return false
|
||||
}
|
||||
return v.HasSecret(serverSecretID(alias, secretType))
|
||||
}
|
||||
|
||||
// Run TUI in a loop — if user requests connect, handle it and restart TUI
|
||||
for {
|
||||
|
|
@ -132,5 +144,3 @@ func runTUI() error {
|
|||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
69
cmd/vault.go
69
cmd/vault.go
|
|
@ -3,11 +3,12 @@ package cmd
|
|||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/mirivlad/sshkeeper/internal/config"
|
||||
"github.com/mirivlad/sshkeeper/internal/vault"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
|
|
@ -160,9 +161,75 @@ var vaultChangePasswordCmd = &cobra.Command{
|
|||
},
|
||||
}
|
||||
|
||||
var vaultListCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
Short: "List stored secret metadata",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
v := getOrCreateVault()
|
||||
if !v.IsUnlocked() {
|
||||
return fmt.Errorf("vault is locked. Unlock first with 'sshkeeper vault unlock'")
|
||||
}
|
||||
output, err := formatVaultSecretsList(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Print(output)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
var vaultDeleteCmd = &cobra.Command{
|
||||
Use: "delete <alias> [type]",
|
||||
Short: "Delete stored secrets for a server",
|
||||
Args: cobra.RangeArgs(1, 2),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
alias := args[0]
|
||||
secretType := ""
|
||||
if len(args) == 2 {
|
||||
secretType = args[1]
|
||||
}
|
||||
|
||||
v := getOrCreateVault()
|
||||
if !v.IsUnlocked() {
|
||||
return fmt.Errorf("vault is locked. Unlock first with 'sshkeeper vault unlock'")
|
||||
}
|
||||
if err := deleteVaultSecrets(v, alias, secretType); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v.Save(); err != nil {
|
||||
return fmt.Errorf("save vault: %w", err)
|
||||
}
|
||||
if secretType == "" {
|
||||
fmt.Printf("Deleted secrets for %s.\n", alias)
|
||||
} else {
|
||||
fmt.Printf("Deleted %s for %s.\n", secretType, alias)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
func formatVaultSecretsList(v *vault.Vault) (string, error) {
|
||||
metas, err := v.ListSecrets()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(metas) == 0 {
|
||||
return "No secrets stored.\n", nil
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
fmt.Fprintf(&b, "%-24s %-18s\n", "ALIAS", "TYPE")
|
||||
for _, meta := range metas {
|
||||
fmt.Fprintf(&b, "%-24s %-18s\n", meta.Alias, meta.Type)
|
||||
}
|
||||
return b.String(), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
vaultCmd.AddCommand(vaultUnlockCmd)
|
||||
vaultCmd.AddCommand(vaultLockCmd)
|
||||
vaultCmd.AddCommand(vaultStatusCmd)
|
||||
vaultCmd.AddCommand(vaultChangePasswordCmd)
|
||||
vaultCmd.AddCommand(vaultListCmd)
|
||||
vaultCmd.AddCommand(vaultDeleteCmd)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,40 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFormatVaultSecretsListDoesNotExposeSecretValues(t *testing.T) {
|
||||
v := newUnlockedTestVault(t)
|
||||
mustPutSecret(t, v, "prod", "ssh_password", "super-secret")
|
||||
mustPutSecret(t, v, "stage", "key_passphrase", "also-secret")
|
||||
|
||||
output, err := formatVaultSecretsList(v)
|
||||
if err != nil {
|
||||
t.Fatalf("format vault secrets list: %v", err)
|
||||
}
|
||||
|
||||
for _, want := range []string{"prod", "ssh_password", "stage", "key_passphrase"} {
|
||||
if !strings.Contains(output, want) {
|
||||
t.Fatalf("expected output to contain %q\noutput:\n%s", want, output)
|
||||
}
|
||||
}
|
||||
for _, secretValue := range []string{"super-secret", "also-secret"} {
|
||||
if strings.Contains(output, secretValue) {
|
||||
t.Fatalf("expected output not to expose secret value %q\noutput:\n%s", secretValue, output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatVaultSecretsListHandlesEmptyVault(t *testing.T) {
|
||||
v := newUnlockedTestVault(t)
|
||||
|
||||
output, err := formatVaultSecretsList(v)
|
||||
if err != nil {
|
||||
t.Fatalf("format empty vault secrets list: %v", err)
|
||||
}
|
||||
if !strings.Contains(output, "No secrets stored.") {
|
||||
t.Fatalf("expected empty output message, got:\n%s", output)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,295 @@
|
|||
# Server List Scalability Implementation Plan
|
||||
|
||||
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
|
||||
|
||||
**Goal:** Keep the server list usable when the saved server count is larger than the terminal height.
|
||||
|
||||
**Architecture:** The list screen should render a bounded table viewport instead of rendering every server row. Selection remains owned by the existing `bubbles/list.Model`, while `viewServerList` derives a visible row range around the selected item and keeps the selected-server detail panel and footer visible.
|
||||
|
||||
**Tech Stack:** Go, Bubble Tea, Bubbles list, Lip Gloss, existing TUI tests in `internal/tui/app_test.go`.
|
||||
|
||||
---
|
||||
|
||||
### Task 1: Add Regression Coverage For Long Server Lists
|
||||
|
||||
**Files:**
|
||||
- Modify: `internal/tui/app_test.go`
|
||||
|
||||
- [ ] **Step 1: Add a test for a constrained terminal height**
|
||||
|
||||
Add a test that creates more servers than can fit on screen, sets a small terminal size, renders the list, and verifies the selected details and footer are still visible.
|
||||
|
||||
```go
|
||||
func TestServerListViewKeepsDetailsVisibleWithManyServers(t *testing.T) {
|
||||
servers := make([]*model.Server, 45)
|
||||
for i := range servers {
|
||||
servers[i] = &model.Server{
|
||||
Alias: fmt.Sprintf("server-%02d", i+1),
|
||||
DisplayName: fmt.Sprintf("Server %02d", i+1),
|
||||
Host: fmt.Sprintf("host-%02d.example.org", i+1),
|
||||
Port: 22,
|
||||
User: "mirivlad",
|
||||
AuthMethod: model.AuthKey,
|
||||
LastTestStatus: model.TestUnknown,
|
||||
}
|
||||
}
|
||||
|
||||
m := New(servers)
|
||||
updated, _ := m.Update(tea.WindowSizeMsg{Width: 100, Height: 18})
|
||||
model := updated.(*tuiModel)
|
||||
|
||||
view := model.View()
|
||||
if !strings.Contains(view, "Server 01") {
|
||||
t.Fatalf("expected first selected server to be visible:\n%s", view)
|
||||
}
|
||||
if !strings.Contains(view, "Selected") {
|
||||
t.Fatalf("expected selected server details to remain visible:\n%s", view)
|
||||
}
|
||||
if !strings.Contains(view, "Enter connect") {
|
||||
t.Fatalf("expected footer to remain visible:\n%s", view)
|
||||
}
|
||||
if count := strings.Count(view, "server-"); count >= len(servers) {
|
||||
t.Fatalf("expected bounded row rendering, rendered %d server aliases", count)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
- [ ] **Step 2: Run the focused test and confirm it fails**
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
env GOCACHE=/tmp/sshkeeper-go-cache go test ./internal/tui -run TestServerListViewKeepsDetailsVisibleWithManyServers -count=1
|
||||
```
|
||||
|
||||
Expected: FAIL because the current table renders every server and can push the detail panel/footer below the visible terminal area.
|
||||
|
||||
### Task 2: Compute A Visible Row Window
|
||||
|
||||
**Files:**
|
||||
- Modify: `internal/tui/app.go`
|
||||
- Modify: `internal/tui/app_test.go`
|
||||
|
||||
- [ ] **Step 1: Add focused tests for visible range calculation**
|
||||
|
||||
Add tests for a helper that computes the inclusive start and exclusive end indexes for rendered rows.
|
||||
|
||||
```go
|
||||
func TestVisibleServerRangeKeepsSelectionInsideWindow(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
total int
|
||||
selected int
|
||||
available int
|
||||
wantStart int
|
||||
wantEnd int
|
||||
}{
|
||||
{name: "first page", total: 40, selected: 0, available: 10, wantStart: 0, wantEnd: 10},
|
||||
{name: "middle page", total: 40, selected: 20, available: 10, wantStart: 11, wantEnd: 21},
|
||||
{name: "last page", total: 40, selected: 39, available: 10, wantStart: 30, wantEnd: 40},
|
||||
{name: "all fit", total: 5, selected: 3, available: 10, wantStart: 0, wantEnd: 5},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
start, end := visibleServerRange(tt.total, tt.selected, tt.available)
|
||||
if start != tt.wantStart || end != tt.wantEnd {
|
||||
t.Fatalf("visibleServerRange() = %d, %d; want %d, %d", start, end, tt.wantStart, tt.wantEnd)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
- [ ] **Step 2: Implement `visibleServerRange`**
|
||||
|
||||
Add a small helper near `selectedServer`.
|
||||
|
||||
```go
|
||||
func visibleServerRange(total, selected, available int) (int, int) {
|
||||
if total <= 0 || available <= 0 {
|
||||
return 0, 0
|
||||
}
|
||||
if available >= total {
|
||||
return 0, total
|
||||
}
|
||||
if selected < 0 {
|
||||
selected = 0
|
||||
}
|
||||
if selected >= total {
|
||||
selected = total - 1
|
||||
}
|
||||
|
||||
start := selected - available + 1
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
end := start + available
|
||||
if end > total {
|
||||
end = total
|
||||
start = end - available
|
||||
}
|
||||
return start, end
|
||||
}
|
||||
```
|
||||
|
||||
- [ ] **Step 3: Run helper tests**
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
env GOCACHE=/tmp/sshkeeper-go-cache go test ./internal/tui -run TestVisibleServerRangeKeepsSelectionInsideWindow -count=1
|
||||
```
|
||||
|
||||
Expected: PASS.
|
||||
|
||||
### Task 3: Render Only Rows That Fit
|
||||
|
||||
**Files:**
|
||||
- Modify: `internal/tui/app.go`
|
||||
- Modify: `internal/tui/app_test.go`
|
||||
|
||||
- [ ] **Step 1: Reserve terminal space for fixed UI blocks**
|
||||
|
||||
Add a helper that decides how many server rows may be rendered while keeping the selected details and footer visible.
|
||||
|
||||
```go
|
||||
func (m *tuiModel) visibleServerRows() int {
|
||||
if m.height <= 0 {
|
||||
return len(m.servers)
|
||||
}
|
||||
|
||||
const fixedRows = 16
|
||||
rows := m.height - fixedRows
|
||||
if rows < 3 {
|
||||
return 3
|
||||
}
|
||||
return rows
|
||||
}
|
||||
```
|
||||
|
||||
- [ ] **Step 2: Use the visible range in `viewServerList`**
|
||||
|
||||
In `viewServerList`, replace the loop over all servers with a bounded loop:
|
||||
|
||||
```go
|
||||
selectedIndex := m.list.Index()
|
||||
start, end := visibleServerRange(len(m.servers), selectedIndex, m.visibleServerRows())
|
||||
for _, server := range m.servers[start:end] {
|
||||
// existing row rendering body stays unchanged
|
||||
}
|
||||
```
|
||||
|
||||
Then render a compact range hint when rows are hidden:
|
||||
|
||||
```go
|
||||
if len(m.servers) > end-start {
|
||||
b.WriteString(helpStyle.Render(fmt.Sprintf(" Showing %d-%d of %d", start+1, end, len(m.servers))))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
```
|
||||
|
||||
- [ ] **Step 3: Run long-list regression test**
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
env GOCACHE=/tmp/sshkeeper-go-cache go test ./internal/tui -run TestServerListViewKeepsDetailsVisibleWithManyServers -count=1
|
||||
```
|
||||
|
||||
Expected: PASS.
|
||||
|
||||
### Task 4: Verify Navigation Still Works
|
||||
|
||||
**Files:**
|
||||
- Modify: `internal/tui/app_test.go`
|
||||
|
||||
- [ ] **Step 1: Add a test for moving selection beyond the first window**
|
||||
|
||||
Use the existing `m.list.Update` path by sending `tea.KeyDown` messages and confirm the rendered window follows the selected server.
|
||||
|
||||
```go
|
||||
func TestServerListViewScrollsWithSelection(t *testing.T) {
|
||||
servers := make([]*model.Server, 45)
|
||||
for i := range servers {
|
||||
servers[i] = &model.Server{
|
||||
Alias: fmt.Sprintf("server-%02d", i+1),
|
||||
DisplayName: fmt.Sprintf("Server %02d", i+1),
|
||||
Host: fmt.Sprintf("host-%02d.example.org", i+1),
|
||||
Port: 22,
|
||||
User: "mirivlad",
|
||||
AuthMethod: model.AuthKey,
|
||||
}
|
||||
}
|
||||
|
||||
m := New(servers)
|
||||
updated, _ := m.Update(tea.WindowSizeMsg{Width: 100, Height: 18})
|
||||
model := updated.(*tuiModel)
|
||||
for i := 0; i < 20; i++ {
|
||||
updated, _ = model.Update(tea.KeyMsg{Type: tea.KeyDown})
|
||||
model = updated.(*tuiModel)
|
||||
}
|
||||
|
||||
view := model.View()
|
||||
if !strings.Contains(view, "Server 21") {
|
||||
t.Fatalf("expected selected server to be visible after navigation:\n%s", view)
|
||||
}
|
||||
if !strings.Contains(view, "Showing") {
|
||||
t.Fatalf("expected range hint for long server list:\n%s", view)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
- [ ] **Step 2: Run TUI tests**
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
env GOCACHE=/tmp/sshkeeper-go-cache go test ./internal/tui -count=1
|
||||
```
|
||||
|
||||
Expected: PASS.
|
||||
|
||||
### Task 5: Final Verification And Build
|
||||
|
||||
**Files:**
|
||||
- No source edits expected.
|
||||
|
||||
- [ ] **Step 1: Run the full test suite**
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
env GOCACHE=/tmp/sshkeeper-go-cache go test ./...
|
||||
```
|
||||
|
||||
Expected: all packages pass.
|
||||
|
||||
- [ ] **Step 2: Rebuild the project binary**
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
env GOCACHE=/tmp/sshkeeper-go-cache go build -o bin/sshkeeper .
|
||||
```
|
||||
|
||||
Expected: exit code 0 and updated `bin/sshkeeper`.
|
||||
|
||||
- [ ] **Step 3: Commit the implementation**
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
git add internal/tui/app.go internal/tui/app_test.go bin/sshkeeper
|
||||
git commit -m "fix: keep server list usable with many servers"
|
||||
```
|
||||
|
||||
Expected: commit succeeds.
|
||||
|
||||
---
|
||||
|
||||
## Self-Review
|
||||
|
||||
- Spec coverage: the plan covers the known failure mode for 40+ servers, keeps selected details visible, keeps the footer visible, and preserves existing `bubbles/list` navigation.
|
||||
- Placeholder scan: no `TBD`, `TODO`, or open-ended implementation placeholders remain.
|
||||
- Type consistency: helper names and files match the current TUI code shape.
|
||||
2
go.mod
2
go.mod
|
|
@ -10,6 +10,7 @@ require (
|
|||
github.com/creack/pty v1.1.24
|
||||
github.com/spf13/cobra v1.10.2
|
||||
golang.org/x/crypto v0.52.0
|
||||
golang.org/x/sys v0.45.0
|
||||
golang.org/x/term v0.43.0
|
||||
modernc.org/sqlite v1.50.1
|
||||
)
|
||||
|
|
@ -41,7 +42,6 @@ require (
|
|||
github.com/sahilm/fuzzy v0.1.1 // indirect
|
||||
github.com/spf13/pflag v1.0.9 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
golang.org/x/sys v0.45.0 // indirect
|
||||
golang.org/x/text v0.37.0 // indirect
|
||||
modernc.org/libc v1.72.3 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
|
|
|
|||
|
|
@ -30,6 +30,17 @@ func (db *DB) UpdateServer(s *model.Server) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (db *DB) UpdateServerByAlias(oldAlias string, s *model.Server) error {
|
||||
_, err := db.conn.Exec(`
|
||||
UPDATE servers SET
|
||||
alias=?, display_name=?, host=?, port=?, user=?, auth_method=?,
|
||||
identity_file=?, proxy_jump=?, group_name=?, notes=?, updated_at=CURRENT_TIMESTAMP
|
||||
WHERE alias=?`,
|
||||
s.Alias, s.DisplayName, s.Host, s.Port, s.User, s.AuthMethod,
|
||||
s.IdentityFile, s.ProxyJump, s.GroupName, s.Notes, oldAlias)
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *DB) DeleteServer(alias string) error {
|
||||
_, err := db.conn.Exec("DELETE FROM servers WHERE alias=?", alias)
|
||||
return err
|
||||
|
|
|
|||
|
|
@ -0,0 +1,47 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mirivlad/sshkeeper/internal/model"
|
||||
)
|
||||
|
||||
func TestUpdateServerByAliasCanRenameAlias(t *testing.T) {
|
||||
db, err := Open(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("open db: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
server := &model.Server{
|
||||
Alias: "old",
|
||||
Host: "old.example",
|
||||
Port: 22,
|
||||
User: "root",
|
||||
AuthMethod: model.AuthPassword,
|
||||
}
|
||||
if err := db.CreateServer(server); err != nil {
|
||||
t.Fatalf("create server: %v", err)
|
||||
}
|
||||
|
||||
server.Alias = "new"
|
||||
server.Host = "new.example"
|
||||
server.AuthMethod = model.AuthKey
|
||||
if err := db.UpdateServerByAlias("old", server); err != nil {
|
||||
t.Fatalf("update server by old alias: %v", err)
|
||||
}
|
||||
|
||||
if _, err := db.GetServer("old"); err == nil {
|
||||
t.Fatal("expected old alias to be gone")
|
||||
}
|
||||
got, err := db.GetServer("new")
|
||||
if err != nil {
|
||||
t.Fatalf("get new alias: %v", err)
|
||||
}
|
||||
if got.ID != server.ID {
|
||||
t.Fatalf("expected ID to stay %d, got %d", server.ID, got.ID)
|
||||
}
|
||||
if got.Host != "new.example" || got.AuthMethod != model.AuthKey {
|
||||
t.Fatalf("unexpected updated server: %#v", got)
|
||||
}
|
||||
}
|
||||
|
|
@ -5,7 +5,6 @@ import (
|
|||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mirivlad/sshkeeper/internal/config"
|
||||
"github.com/mirivlad/sshkeeper/internal/model"
|
||||
|
|
@ -14,7 +13,7 @@ import (
|
|||
type VaultFunc func(serverAlias string, secretType string) (string, error)
|
||||
|
||||
func Connect(cfg *config.Config, server *model.Server, getVault VaultFunc) error {
|
||||
args := buildArgs(server)
|
||||
args := BuildSSHArgs(server)
|
||||
|
||||
switch server.AuthMethod {
|
||||
case model.AuthPassword:
|
||||
|
|
@ -25,8 +24,7 @@ func Connect(cfg *config.Config, server *model.Server, getVault VaultFunc) error
|
|||
return ConnectWithPassword(cfg.SSH.Binary, args, password)
|
||||
|
||||
case model.AuthKeyPassphrase:
|
||||
// For key+passphrase, we need to handle the passphrase
|
||||
// For now, let ssh-agent handle it or prompt normally
|
||||
// For key+passphrase, let ssh-agent handle it or prompt normally
|
||||
// TODO: use ssh-agent or similar
|
||||
fallthrough
|
||||
|
||||
|
|
@ -46,13 +44,11 @@ func Connect(cfg *config.Config, server *model.Server, getVault VaultFunc) error
|
|||
}
|
||||
|
||||
func Test(cfg *config.Config, server *model.Server, getVault VaultFunc) (bool, string) {
|
||||
args := buildArgs(server)
|
||||
args := BuildSSHArgs(server)
|
||||
args = append(args, "-o", fmt.Sprintf("ConnectTimeout=%d", cfg.SSH.ConnectTimeoutSec))
|
||||
|
||||
switch server.AuthMethod {
|
||||
case model.AuthPassword:
|
||||
// For password auth, we can't use BatchMode
|
||||
// Use a short timeout and try to connect
|
||||
args = append(args, "-o", "NumberOfPasswordPrompts=1")
|
||||
password, err := getVault(server.Alias, "ssh_password")
|
||||
if err != nil {
|
||||
|
|
@ -81,40 +77,29 @@ func Test(cfg *config.Config, server *model.Server, getVault VaultFunc) (bool, s
|
|||
}
|
||||
}
|
||||
|
||||
// testWithPassword tests SSH connection with password auth via PTY-wrapper.
|
||||
// It connects, sends the password, runs the test command, and checks the output.
|
||||
func testWithPassword(cfg *config.Config, args []string, password string) (bool, string) {
|
||||
// For password test, we use PTY approach with a short timeout
|
||||
// This is a simplified version - in production, use ConnectWithPassword
|
||||
// with a test command
|
||||
args = append(args, cfg.SSH.TestCommand)
|
||||
|
||||
cmd := exec.Command(cfg.SSH.Binary, args...)
|
||||
cmd.Stdin = nil
|
||||
cmd.Stdout = nil
|
||||
cmd.Stderr = nil
|
||||
|
||||
// Use a timeout
|
||||
done := make(chan error, 1)
|
||||
if err := cmd.Start(); err != nil {
|
||||
return false, err.Error()
|
||||
ok, output := connectWithPasswordAndRead(cfg.SSH.Binary, args, password, cfg.SSH.ConnectTimeoutSec)
|
||||
if !ok {
|
||||
return false, output
|
||||
}
|
||||
|
||||
go func() {
|
||||
done <- cmd.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
return false, err.Error()
|
||||
}
|
||||
result := strings.TrimSpace(output)
|
||||
if result == "SSHKEEPER_OK" {
|
||||
return true, ""
|
||||
case <-time.After(time.Duration(cfg.SSH.ConnectTimeoutSec) * time.Second):
|
||||
cmd.Process.Kill()
|
||||
return false, "connection timeout"
|
||||
}
|
||||
// The output might have the test command echo before the result
|
||||
if strings.Contains(result, "SSHKEEPER_OK") {
|
||||
return true, ""
|
||||
}
|
||||
return false, result
|
||||
}
|
||||
|
||||
func buildArgs(server *model.Server) []string {
|
||||
// BuildSSHArgs builds the SSH command arguments for a server profile.
|
||||
func BuildSSHArgs(server *model.Server) []string {
|
||||
var args []string
|
||||
|
||||
args = append(args, "-p", fmt.Sprintf("%d", server.Port))
|
||||
|
|
@ -127,8 +112,6 @@ func buildArgs(server *model.Server) []string {
|
|||
args = append(args, "-J", server.ProxyJump)
|
||||
}
|
||||
|
||||
// Disable strict host key checking for first connection
|
||||
// In production, this should be configurable
|
||||
args = append(args, "-o", "StrictHostKeyChecking=accept-new")
|
||||
|
||||
target := fmt.Sprintf("%s@%s", server.User, server.Host)
|
||||
|
|
|
|||
|
|
@ -7,19 +7,19 @@ import (
|
|||
"os/exec"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/creack/pty"
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
var passwordPromptRe = regexp.MustCompile(`(?i)(password|passphrase).*:\s*$`)
|
||||
|
||||
// ConnectWithPassword runs SSH through a PTY, detects the password prompt,
|
||||
// sends the password, and then bridges the user terminal to the SSH session.
|
||||
func ConnectWithPassword(sshBinary string, args []string, password string) error {
|
||||
// Start SSH with PTY
|
||||
cmd := exec.Command(sshBinary, args...)
|
||||
cmd.Env = os.Environ()
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
|
|
@ -33,18 +33,14 @@ func ConnectWithPassword(sshBinary string, args []string, password string) error
|
|||
}
|
||||
defer ptmx.Close()
|
||||
|
||||
// Save terminal state and set to raw
|
||||
oldState, err := term.MakeRaw(int(os.Stdin.Fd()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("set raw terminal: %w", err)
|
||||
}
|
||||
defer term.Restore(int(os.Stdin.Fd()), oldState)
|
||||
|
||||
// Channel to signal when password has been sent
|
||||
passwordSent := make(chan bool, 1)
|
||||
passwordSent := new(atomic.Bool)
|
||||
done := make(chan error, 1)
|
||||
|
||||
// Read from PTY, detect password prompt
|
||||
go func() {
|
||||
buf := make([]byte, 4096)
|
||||
var accumulated strings.Builder
|
||||
|
|
@ -54,22 +50,17 @@ func ConnectWithPassword(sshBinary string, args []string, password string) error
|
|||
if n > 0 {
|
||||
data := buf[:n]
|
||||
accumulated.Write(data)
|
||||
|
||||
// Write to stdout
|
||||
os.Stdout.Write(data)
|
||||
|
||||
// Check for password prompt
|
||||
if !<-passwordSent {
|
||||
if !passwordSent.Load() {
|
||||
text := accumulated.String()
|
||||
if passwordPromptRe.MatchString(text) {
|
||||
passwordSent <- true
|
||||
passwordSent.Store(true)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
ptmx.Write([]byte(password + "\r"))
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Reset accumulated buffer periodically to avoid unbounded growth
|
||||
if accumulated.Len() > 8192 {
|
||||
s := accumulated.String()
|
||||
accumulated.Reset()
|
||||
|
|
@ -87,14 +78,142 @@ func ConnectWithPassword(sshBinary string, args []string, password string) error
|
|||
}
|
||||
}()
|
||||
|
||||
// Copy stdin to PTY
|
||||
go func() {
|
||||
io.Copy(ptmx, os.Stdin)
|
||||
}()
|
||||
stopInput, waitInput, restoreInput, err := forwardInputToPTY(ptmx)
|
||||
if err != nil {
|
||||
term.Restore(int(os.Stdin.Fd()), oldState)
|
||||
return err
|
||||
}
|
||||
|
||||
// Wait for command completion
|
||||
err = cmd.Wait()
|
||||
passwordSent <- false // signal to stop
|
||||
close(stopInput)
|
||||
waitInput()
|
||||
|
||||
if restoreErr := restoreInput(); err == nil && restoreErr != nil {
|
||||
err = restoreErr
|
||||
}
|
||||
if restoreErr := term.Restore(int(os.Stdin.Fd()), oldState); err == nil && restoreErr != nil {
|
||||
err = restoreErr
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func forwardInputToPTY(ptmx *os.File) (chan struct{}, func(), func() error, error) {
|
||||
stdinFD := int(os.Stdin.Fd())
|
||||
flags, err := unix.FcntlInt(uintptr(stdinFD), unix.F_GETFL, 0)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("get stdin flags: %w", err)
|
||||
}
|
||||
if err := unix.SetNonblock(stdinFD, true); err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("set stdin nonblocking: %w", err)
|
||||
}
|
||||
|
||||
stop := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
n, readErr := unix.Read(stdinFD, buf)
|
||||
if n > 0 {
|
||||
if _, writeErr := ptmx.Write(buf[:n]); writeErr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if readErr == nil {
|
||||
continue
|
||||
}
|
||||
if readErr == unix.EAGAIN || readErr == unix.EWOULDBLOCK {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
continue
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
restore := func() error {
|
||||
_, err := unix.FcntlInt(uintptr(stdinFD), unix.F_SETFL, flags)
|
||||
if err != nil {
|
||||
return fmt.Errorf("restore stdin flags: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return stop, wg.Wait, restore, nil
|
||||
}
|
||||
|
||||
// connectWithPasswordAndRead runs SSH through a PTY, sends the password,
|
||||
// collects all output, and returns it. Used for non-interactive testing.
|
||||
// Returns (true, output) on success, (false, error) on failure.
|
||||
func connectWithPasswordAndRead(sshBinary string, args []string, password string, timeoutSec int) (bool, string) {
|
||||
cmd := exec.Command(sshBinary, args...)
|
||||
cmd.Env = os.Environ()
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
Setsid: true,
|
||||
Setctty: true,
|
||||
}
|
||||
|
||||
ptmx, err := pty.Start(cmd)
|
||||
if err != nil {
|
||||
return false, fmt.Sprintf("start ssh with pty: %v", err)
|
||||
}
|
||||
defer ptmx.Close()
|
||||
|
||||
passwordSent := false
|
||||
done := make(chan string, 1)
|
||||
|
||||
// Read from PTY, detect password prompt, collect output
|
||||
go func() {
|
||||
buf := make([]byte, 4096)
|
||||
var accumulated strings.Builder
|
||||
|
||||
for {
|
||||
n, err := ptmx.Read(buf)
|
||||
if n > 0 {
|
||||
data := buf[:n]
|
||||
accumulated.Write(data)
|
||||
|
||||
// Check for password prompt
|
||||
if !passwordSent {
|
||||
text := accumulated.String()
|
||||
if passwordPromptRe.MatchString(text) {
|
||||
passwordSent = true
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
ptmx.Write([]byte(password + "\r"))
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Reset accumulated buffer periodically
|
||||
if accumulated.Len() > 16384 {
|
||||
s := accumulated.String()
|
||||
accumulated.Reset()
|
||||
accumulated.WriteString(s[len(s)-4096:])
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
done <- accumulated.String()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for command completion or timeout
|
||||
select {
|
||||
case output := <-done:
|
||||
return true, output
|
||||
case <-time.After(time.Duration(timeoutSec) * time.Second):
|
||||
cmd.Process.Kill()
|
||||
return false, "connection timeout"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,63 @@
|
|||
package ssh
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/creack/pty"
|
||||
)
|
||||
|
||||
func TestConnectWithPasswordDoesNotConsumeInputAfterReturn(t *testing.T) {
|
||||
script := filepath.Join(t.TempDir(), "fake-ssh")
|
||||
if err := os.WriteFile(script, []byte("#!/bin/sh\nprintf 'password: '\nsleep 0.1\nexit 0\n"), 0o700); err != nil {
|
||||
t.Fatalf("write fake ssh: %v", err)
|
||||
}
|
||||
|
||||
master, slave, err := pty.Open()
|
||||
if err != nil {
|
||||
t.Fatalf("open stdin pty: %v", err)
|
||||
}
|
||||
defer master.Close()
|
||||
defer slave.Close()
|
||||
|
||||
oldStdin := os.Stdin
|
||||
oldStdout := os.Stdout
|
||||
stdoutSink, err := os.CreateTemp(t.TempDir(), "stdout")
|
||||
if err != nil {
|
||||
t.Fatalf("create stdout sink: %v", err)
|
||||
}
|
||||
defer stdoutSink.Close()
|
||||
os.Stdin = slave
|
||||
os.Stdout = stdoutSink
|
||||
defer func() {
|
||||
os.Stdin = oldStdin
|
||||
os.Stdout = oldStdout
|
||||
}()
|
||||
|
||||
if err := ConnectWithPassword(script, nil, "secret"); err != nil {
|
||||
t.Fatalf("connect with password: %v", err)
|
||||
}
|
||||
|
||||
if _, err := master.Write([]byte("\n")); err != nil {
|
||||
t.Fatalf("write next enter: %v", err)
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
readDone := make(chan []byte, 1)
|
||||
go func() {
|
||||
buf := make([]byte, 1)
|
||||
n, _ := slave.Read(buf)
|
||||
readDone <- buf[:n]
|
||||
}()
|
||||
|
||||
select {
|
||||
case got := <-readDone:
|
||||
if len(got) != 1 || got[0] != '\n' {
|
||||
t.Fatalf("expected to read newline, got %q", got)
|
||||
}
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
t.Fatal("expected next Enter to remain readable after ConnectWithPassword returned")
|
||||
}
|
||||
}
|
||||
|
|
@ -26,7 +26,10 @@ var (
|
|||
Background(lipgloss.Color("4")).
|
||||
Bold(true)
|
||||
|
||||
normalStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("15"))
|
||||
normalStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("15"))
|
||||
selectedRowStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("15")).Background(lipgloss.Color("4"))
|
||||
listHeaderStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("14")).Bold(true)
|
||||
sectionStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("12")).Bold(true).MarginTop(1)
|
||||
|
||||
testOKStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("10")).Bold(true)
|
||||
testFailStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("9")).Bold(true)
|
||||
|
|
@ -67,9 +70,13 @@ type serverItem struct {
|
|||
server *model.Server
|
||||
}
|
||||
|
||||
func (i serverItem) Title() string { return i.server.Alias }
|
||||
func (i serverItem) Description() string { return fmt.Sprintf("%s@%s:%d %s", i.server.User, i.server.Host, i.server.Port, i.server.AuthMethod) }
|
||||
func (i serverItem) FilterValue() string { return i.server.Alias + " " + i.server.DisplayName + " " + i.server.Host + " " + i.server.User }
|
||||
func (i serverItem) Title() string { return i.server.Alias }
|
||||
func (i serverItem) Description() string {
|
||||
return fmt.Sprintf("%s@%s:%d %s", i.server.User, i.server.Host, i.server.Port, i.server.AuthMethod)
|
||||
}
|
||||
func (i serverItem) FilterValue() string {
|
||||
return i.server.Alias + " " + i.server.DisplayName + " " + i.server.Host + " " + i.server.User
|
||||
}
|
||||
|
||||
// --- External callbacks ---
|
||||
|
||||
|
|
@ -80,10 +87,12 @@ var (
|
|||
TestConnection func(server *model.Server) (bool, string)
|
||||
// TestConnectionWithPassword tests with explicit password (for form test before save)
|
||||
TestConnectionWithPassword func(server *model.Server, password string) (bool, string)
|
||||
SaveServer func(server *model.Server, password string) error
|
||||
GetGroups func() ([]string, error) // Returns existing group names
|
||||
RenameGroup func(oldName, newName string) error // Rename group for all servers
|
||||
DeleteGroup func(name string) error // Remove group from all servers
|
||||
SaveServer func(server *model.Server, password string, oldAlias string) error
|
||||
UpdateTestResult func(alias string, status model.TestStatus, testErr string) error
|
||||
HasSecret func(alias string, secretType string) bool
|
||||
GetGroups func() ([]string, error)
|
||||
RenameGroup func(oldName, newName string) error
|
||||
DeleteGroup func(name string) error
|
||||
)
|
||||
|
||||
// --- Screen type ---
|
||||
|
|
@ -99,8 +108,8 @@ const (
|
|||
// --- Result type — returned from TUI to caller ---
|
||||
|
||||
type TUIResult struct {
|
||||
Server *model.Server
|
||||
Action string // "connect"
|
||||
Server *model.Server
|
||||
Action string // "connect"
|
||||
}
|
||||
|
||||
// --- Main TUI model ---
|
||||
|
|
@ -196,8 +205,22 @@ func (m *tuiModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||
}
|
||||
m.form.testResultTime = time.Now()
|
||||
m.form.err = nil
|
||||
return m, nil
|
||||
}
|
||||
// Update test status in DB and reload list
|
||||
if item, ok := m.list.SelectedItem().(serverItem); ok && UpdateTestResult != nil {
|
||||
status := model.TestUnknown
|
||||
if msg.ok {
|
||||
status = model.TestOK
|
||||
} else if msg.err != "" {
|
||||
status = model.TestFailed
|
||||
}
|
||||
UpdateTestResult(item.server.Alias, status, msg.err)
|
||||
}
|
||||
return m, func() tea.Msg {
|
||||
servers, err := ListServers()
|
||||
return serversLoadedMsg{servers: servers, err: err}
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case saveDoneMsg:
|
||||
if m.form != nil {
|
||||
|
|
@ -323,11 +346,23 @@ func (m *tuiModel) updateSearch(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
|||
|
||||
func (m *tuiModel) updateForm(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||
if msg.Type == tea.KeyEsc {
|
||||
if m.form != nil && (m.form.showGroupList || m.form.showAuthList) {
|
||||
updated, cmd := m.form.Update(msg)
|
||||
if fm, ok := updated.(*formModel); ok {
|
||||
m.form = fm
|
||||
}
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
m.screen = screenList
|
||||
m.form = nil
|
||||
m.err = nil
|
||||
m.success = ""
|
||||
return m, nil
|
||||
// Reload server list after form close
|
||||
return m, func() tea.Msg {
|
||||
servers, err := ListServers()
|
||||
return serversLoadedMsg{servers: servers, err: err}
|
||||
}
|
||||
}
|
||||
|
||||
updated, cmd := m.form.Update(msg)
|
||||
|
|
@ -342,9 +377,7 @@ func (m *tuiModel) View() string {
|
|||
|
||||
switch m.screen {
|
||||
case screenList:
|
||||
b.WriteString(m.list.View())
|
||||
b.WriteString("\n")
|
||||
b.WriteString(helpStyle.Render("Enter connect | Ctrl+A add | Ctrl+E edit | Ctrl+D del | Ctrl+T test | Ctrl+F search | Ctrl+Q quit"))
|
||||
b.WriteString(m.viewServerList())
|
||||
|
||||
case screenSearch:
|
||||
b.WriteString("Search: " + m.searchInput.View() + "\n")
|
||||
|
|
@ -366,13 +399,148 @@ func (m *tuiModel) View() string {
|
|||
return b.String()
|
||||
}
|
||||
|
||||
func (m *tuiModel) viewServerList() string {
|
||||
var b strings.Builder
|
||||
selectedAlias := ""
|
||||
if item, ok := m.list.SelectedItem().(serverItem); ok && item.server != nil {
|
||||
selectedAlias = item.server.Alias
|
||||
}
|
||||
|
||||
b.WriteString(titleStyle.Render(fmt.Sprintf("sshkeeper %d servers", len(m.servers))))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(helpStyle.Render(fmt.Sprintf("Vault unlocked | %s", testSummary(m.servers))))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(listHeaderStyle.Render(fmt.Sprintf(" %-20s %-20s %-34s %-12s %-10s %s", "NAME", "ALIAS", "TARGET", "AUTH", "GROUP", "STATUS")))
|
||||
b.WriteString("\n")
|
||||
|
||||
if len(m.servers) == 0 {
|
||||
b.WriteString(helpStyle.Render(" No servers yet. Press Ctrl+A to add one."))
|
||||
b.WriteString("\n")
|
||||
} else {
|
||||
for _, server := range m.servers {
|
||||
marker := " "
|
||||
rowStyle := normalStyle
|
||||
if server.Alias == selectedAlias {
|
||||
marker = ">"
|
||||
rowStyle = selectedRowStyle
|
||||
}
|
||||
name := server.DisplayName
|
||||
if name == "" {
|
||||
name = server.Alias
|
||||
}
|
||||
target := fmt.Sprintf("%s@%s:%d", server.User, server.Host, server.Port)
|
||||
group := server.GroupName
|
||||
if group == "" {
|
||||
group = "-"
|
||||
}
|
||||
row := fmt.Sprintf("%s %-20s %-20s %-34s %-12s %-10s %s",
|
||||
marker,
|
||||
truncate(name, 20),
|
||||
truncate(server.Alias, 20),
|
||||
truncate(target, 34),
|
||||
authLabel(server.AuthMethod),
|
||||
truncate(group, 10),
|
||||
testStatusLabel(server),
|
||||
)
|
||||
b.WriteString(rowStyle.Render(row))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
b.WriteString("\n")
|
||||
if selectedAlias != "" {
|
||||
if selected := m.selectedServer(); selected != nil {
|
||||
b.WriteString(m.viewSelectedServer(selected))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
}
|
||||
b.WriteString(helpStyle.Render("Enter connect | Ctrl+A add | Ctrl+E edit | Ctrl+D del | Ctrl+T test | Ctrl+F search | Ctrl+Q quit"))
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m *tuiModel) selectedServer() *model.Server {
|
||||
if item, ok := m.list.SelectedItem().(serverItem); ok && item.server != nil {
|
||||
return item.server
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *tuiModel) viewSelectedServer(server *model.Server) string {
|
||||
displayName := server.DisplayName
|
||||
if displayName == "" {
|
||||
displayName = "-"
|
||||
}
|
||||
group := server.GroupName
|
||||
if group == "" {
|
||||
group = "-"
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString(sectionStyle.Render("Selected"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(fmt.Sprintf(" Alias: %s\n", server.Alias))
|
||||
b.WriteString(fmt.Sprintf(" Display Name: %s\n", displayName))
|
||||
b.WriteString(fmt.Sprintf(" Host: %s\n", server.Host))
|
||||
b.WriteString(fmt.Sprintf(" Port: %d\n", server.Port))
|
||||
b.WriteString(fmt.Sprintf(" User: %s\n", server.User))
|
||||
b.WriteString(fmt.Sprintf(" Auth: %s\n", authLabel(server.AuthMethod)))
|
||||
b.WriteString(fmt.Sprintf(" Group: %s\n", group))
|
||||
b.WriteString(fmt.Sprintf(" Status: %s\n", testStatusLabel(server)))
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func testSummary(servers []*model.Server) string {
|
||||
okCount := 0
|
||||
failedCount := 0
|
||||
for _, server := range servers {
|
||||
switch server.LastTestStatus {
|
||||
case model.TestOK:
|
||||
okCount++
|
||||
case model.TestFailed:
|
||||
failedCount++
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("%d OK | %d FAIL", okCount, failedCount)
|
||||
}
|
||||
|
||||
func authLabel(auth model.AuthMethod) string {
|
||||
switch auth {
|
||||
case model.AuthPassword:
|
||||
return "password"
|
||||
case model.AuthKey:
|
||||
return "key"
|
||||
case model.AuthKeyPassphrase:
|
||||
return "key+phrase"
|
||||
case model.AuthAgent:
|
||||
return "agent"
|
||||
default:
|
||||
return string(auth)
|
||||
}
|
||||
}
|
||||
|
||||
func testStatusLabel(server *model.Server) string {
|
||||
switch server.LastTestStatus {
|
||||
case model.TestOK:
|
||||
return "OK"
|
||||
case model.TestFailed:
|
||||
if server.LastTestError != "" {
|
||||
return "FAIL"
|
||||
}
|
||||
return "FAIL"
|
||||
default:
|
||||
return "?"
|
||||
}
|
||||
}
|
||||
|
||||
// --- Form model ---
|
||||
|
||||
type formModel struct {
|
||||
edit bool
|
||||
server *model.Server
|
||||
inputs []textinput.Model
|
||||
labels []string
|
||||
password textinput.Model
|
||||
passwordLabel string
|
||||
focusIdx int
|
||||
testResult string
|
||||
testOK bool
|
||||
|
|
@ -385,10 +553,22 @@ type formModel struct {
|
|||
spinner spinner.Model
|
||||
width int
|
||||
height int
|
||||
groups string // cached groups list (comma-separated for display)
|
||||
showGroups bool // show group dropdown
|
||||
groups []string // existing group names
|
||||
groupList list.Model // dropdown list for groups
|
||||
showGroupList bool // whether group dropdown is visible
|
||||
authList list.Model
|
||||
showAuthList bool
|
||||
}
|
||||
|
||||
// groupItem implements list.Item for the group dropdown
|
||||
type groupItem struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (i groupItem) Title() string { return i.name }
|
||||
func (i groupItem) Description() string { return "" }
|
||||
func (i groupItem) FilterValue() string { return i.name }
|
||||
|
||||
func newFormModel(w, h int) *formModel {
|
||||
inputs := make([]textinput.Model, 10)
|
||||
labels := []string{
|
||||
|
|
@ -405,12 +585,12 @@ func newFormModel(w, h int) *formModel {
|
|||
}
|
||||
for i, label := range labels {
|
||||
inputs[i] = textinput.New()
|
||||
inputs[i].Placeholder = label
|
||||
inputs[i].Placeholder = placeholderForLabel(label)
|
||||
inputs[i].CharLimit = 128
|
||||
}
|
||||
|
||||
pw := textinput.New()
|
||||
pw.Placeholder = "Password / Passphrase (stored in vault)"
|
||||
pw.Placeholder = "optional"
|
||||
pw.CharLimit = 256
|
||||
pw.EchoMode = textinput.EchoPassword
|
||||
|
||||
|
|
@ -421,24 +601,75 @@ func newFormModel(w, h int) *formModel {
|
|||
inputs[0].Focus()
|
||||
|
||||
fm := &formModel{
|
||||
inputs: inputs,
|
||||
password: pw,
|
||||
focusIdx: 0,
|
||||
spinner: s,
|
||||
width: w,
|
||||
height: h,
|
||||
inputs: inputs,
|
||||
labels: labels,
|
||||
password: pw,
|
||||
passwordLabel: "Password / Passphrase",
|
||||
focusIdx: 0,
|
||||
spinner: s,
|
||||
width: w,
|
||||
height: h,
|
||||
}
|
||||
fm.authList = newStringList([]string{
|
||||
string(model.AuthPassword),
|
||||
string(model.AuthKey),
|
||||
string(model.AuthKeyPassphrase),
|
||||
string(model.AuthAgent),
|
||||
}, "Select auth method", 34, 16)
|
||||
|
||||
// Load existing groups
|
||||
if GetGroups != nil {
|
||||
if groups, err := GetGroups(); err == nil && len(groups) > 0 {
|
||||
fm.groups = strings.Join(groups, ", ")
|
||||
fm.groups = groups
|
||||
fm.groupList = newStringList(groups, "Select group", 30, 8)
|
||||
}
|
||||
}
|
||||
|
||||
fm.updateFocus()
|
||||
return fm
|
||||
}
|
||||
|
||||
func placeholderForLabel(label string) string {
|
||||
switch label {
|
||||
case "Alias":
|
||||
return "mail.kp"
|
||||
case "Display Name":
|
||||
return "Production mail"
|
||||
case "Host":
|
||||
return "mail.example.org"
|
||||
case "Port":
|
||||
return "22"
|
||||
case "User":
|
||||
return "root"
|
||||
case "Auth Method (password/key/key_passphrase/agent)":
|
||||
return "key"
|
||||
case "Identity File":
|
||||
return "~/.ssh/id_ed25519"
|
||||
case "ProxyJump":
|
||||
return "optional"
|
||||
case "Group (type new or pick from list)":
|
||||
return "KP"
|
||||
case "Notes":
|
||||
return "optional"
|
||||
default:
|
||||
return label
|
||||
}
|
||||
}
|
||||
|
||||
func newStringList(values []string, title string, width, height int) list.Model {
|
||||
items := make([]list.Item, len(values))
|
||||
for i, value := range values {
|
||||
items[i] = groupItem{name: value}
|
||||
}
|
||||
l := list.New(items, list.NewDefaultDelegate(), width, height)
|
||||
l.SetShowStatusBar(false)
|
||||
l.SetShowHelp(false)
|
||||
l.SetShowPagination(false)
|
||||
l.Title = title
|
||||
l.Styles.Title = titleStyle
|
||||
return l
|
||||
}
|
||||
|
||||
func newEditFormModel(s *model.Server, w, h int) *formModel {
|
||||
fm := newFormModel(w, h)
|
||||
fm.edit = true
|
||||
|
|
@ -453,6 +684,21 @@ func newEditFormModel(s *model.Server, w, h int) *formModel {
|
|||
fm.inputs[7].SetValue(s.ProxyJump)
|
||||
fm.inputs[8].SetValue(s.GroupName)
|
||||
fm.inputs[9].SetValue(s.Notes)
|
||||
if HasSecret != nil {
|
||||
switch s.AuthMethod {
|
||||
case model.AuthPassword:
|
||||
if HasSecret(s.Alias, "ssh_password") {
|
||||
fm.passwordLabel = "Password (secret saved; leave blank to keep)"
|
||||
fm.password.Placeholder = ""
|
||||
}
|
||||
case model.AuthKeyPassphrase:
|
||||
if HasSecret(s.Alias, "key_passphrase") {
|
||||
fm.passwordLabel = "Key passphrase (secret saved; leave blank to keep)"
|
||||
fm.password.Placeholder = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
fm.updateFocus()
|
||||
return fm
|
||||
}
|
||||
|
||||
|
|
@ -498,6 +744,48 @@ func (fm *formModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||
return fm, cmd
|
||||
}
|
||||
|
||||
// Handle group dropdown
|
||||
if fm.showGroupList {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
switch msg.Type {
|
||||
case tea.KeyEsc:
|
||||
fm.showGroupList = false
|
||||
return fm, nil
|
||||
case tea.KeyEnter:
|
||||
if item, ok := fm.groupList.SelectedItem().(groupItem); ok {
|
||||
fm.inputs[8].SetValue(item.name)
|
||||
}
|
||||
fm.showGroupList = false
|
||||
return fm, nil
|
||||
}
|
||||
}
|
||||
// Pass other keys to the list
|
||||
var cmd tea.Cmd
|
||||
fm.groupList, cmd = fm.groupList.Update(msg)
|
||||
return fm, cmd
|
||||
}
|
||||
|
||||
if fm.showAuthList {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
switch msg.Type {
|
||||
case tea.KeyEsc:
|
||||
fm.showAuthList = false
|
||||
return fm, nil
|
||||
case tea.KeyEnter:
|
||||
if item, ok := fm.authList.SelectedItem().(groupItem); ok {
|
||||
fm.inputs[5].SetValue(item.name)
|
||||
}
|
||||
fm.showAuthList = false
|
||||
return fm, nil
|
||||
}
|
||||
}
|
||||
var cmd tea.Cmd
|
||||
fm.authList, cmd = fm.authList.Update(msg)
|
||||
return fm, cmd
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
switch msg.Type {
|
||||
|
|
@ -519,6 +807,17 @@ func (fm *formModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||
fm.updateFocus()
|
||||
return fm, nil
|
||||
|
||||
case tea.KeyRunes:
|
||||
if len(msg.Runes) == 1 && msg.Runes[0] == '/' && !msg.Alt && fm.focusIdx == 5 {
|
||||
fm.showAuthList = true
|
||||
return fm, nil
|
||||
}
|
||||
// '/' on Group field opens group dropdown
|
||||
if len(msg.Runes) == 1 && msg.Runes[0] == '/' && !msg.Alt && fm.focusIdx == 8 && len(fm.groups) > 0 {
|
||||
fm.showGroupList = true
|
||||
return fm, nil
|
||||
}
|
||||
|
||||
case tea.KeyEnter:
|
||||
switch {
|
||||
case fm.focusIdx == len(fm.inputs)+1:
|
||||
|
|
@ -576,20 +875,36 @@ func (fm *formModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||
func (fm *formModel) updateFocus() {
|
||||
for i := range fm.inputs {
|
||||
fm.inputs[i].Blur()
|
||||
fm.inputs[i].Prompt = blurredStyle.Render(fm.inputs[i].Placeholder + ": ")
|
||||
fm.inputs[i].Prompt = blurredStyle.Render(fm.labelAt(i) + ": ")
|
||||
}
|
||||
fm.password.Blur()
|
||||
fm.password.Prompt = blurredStyle.Render(fm.password.Placeholder + ": ")
|
||||
fm.password.Prompt = blurredStyle.Render(fm.passwordLabel + ": ")
|
||||
|
||||
if fm.focusIdx < len(fm.inputs) {
|
||||
fm.inputs[fm.focusIdx].Focus()
|
||||
fm.inputs[fm.focusIdx].Prompt = focusedStyle.Render(fm.inputs[fm.focusIdx].Placeholder + "> ")
|
||||
fm.inputs[fm.focusIdx].Prompt = focusedStyle.Render(fm.labelAt(fm.focusIdx) + "> ")
|
||||
} else if fm.focusIdx == len(fm.inputs) {
|
||||
fm.password.Focus()
|
||||
fm.password.Prompt = focusedStyle.Render(fm.password.Placeholder + "> ")
|
||||
fm.password.Prompt = focusedStyle.Render(fm.passwordLabel + "> ")
|
||||
}
|
||||
}
|
||||
|
||||
func (fm *formModel) labelAt(index int) string {
|
||||
if index >= 0 && index < len(fm.labels) {
|
||||
if index == 5 {
|
||||
return "Auth Method (/ pick)"
|
||||
}
|
||||
if index == 8 {
|
||||
if len(fm.groups) > 0 {
|
||||
return "Group (/ pick)"
|
||||
}
|
||||
return "Group"
|
||||
}
|
||||
return fm.labels[index]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (fm *formModel) runTest() tea.Cmd {
|
||||
fm.testing = true
|
||||
fm.testResult = ""
|
||||
|
|
@ -635,7 +950,11 @@ func (fm *formModel) runSave() tea.Cmd {
|
|||
if s.Host == "" {
|
||||
return saveDoneMsg{err: fmt.Errorf("host is required")}
|
||||
}
|
||||
err := SaveServer(s, pw)
|
||||
oldAlias := ""
|
||||
if fm.edit && fm.server != nil {
|
||||
oldAlias = fm.server.Alias
|
||||
}
|
||||
err := SaveServer(s, pw, oldAlias)
|
||||
return saveDoneMsg{err: err}
|
||||
},
|
||||
)
|
||||
|
|
@ -672,13 +991,72 @@ func (fm *formModel) View() string {
|
|||
b.WriteString(titleStyle.Render(title))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
for i := range fm.inputs {
|
||||
// Calculate visible range based on terminal height
|
||||
// Reserve lines for: title (2) + password (1) + buttons (3) + help (1) + padding (2) = ~9
|
||||
reserved := 9
|
||||
available := fm.height - reserved
|
||||
if available < 4 {
|
||||
available = 4
|
||||
}
|
||||
|
||||
numInputs := len(fm.inputs)
|
||||
startIdx := 0
|
||||
endIdx := numInputs
|
||||
|
||||
// Scroll: keep focused field visible
|
||||
if numInputs > available {
|
||||
focusInput := fm.focusIdx
|
||||
if focusInput >= numInputs {
|
||||
focusInput = numInputs - 1
|
||||
}
|
||||
// Try to show `available` fields centered on focus
|
||||
startIdx = focusInput - available/2
|
||||
if startIdx < 0 {
|
||||
startIdx = 0
|
||||
}
|
||||
endIdx = startIdx + available
|
||||
if endIdx > numInputs {
|
||||
endIdx = numInputs
|
||||
startIdx = endIdx - available
|
||||
if startIdx < 0 {
|
||||
startIdx = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Show scroll indicator if needed
|
||||
if startIdx > 0 {
|
||||
b.WriteString(helpStyle.Render(" ↑ more fields above\n"))
|
||||
}
|
||||
|
||||
for i := startIdx; i < endIdx; i++ {
|
||||
if section := formSectionTitle(i); section != "" {
|
||||
b.WriteString(sectionStyle.Render(section))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
if i == 5 {
|
||||
fm.inputs[i].Placeholder = "password/key/key_passphrase/agent"
|
||||
}
|
||||
// Show group hint inline in placeholder for Group field
|
||||
if i == 8 && len(fm.groups) > 0 && !fm.showGroupList {
|
||||
fm.inputs[i].Placeholder = truncate(strings.Join(fm.groups, ", "), 25)
|
||||
}
|
||||
b.WriteString(fm.inputs[i].View())
|
||||
b.WriteString("\n")
|
||||
// Show existing groups hint under Group field
|
||||
if i == 8 && fm.groups != "" {
|
||||
b.WriteString(helpStyle.Render(" Groups: " + fm.groups + "\n"))
|
||||
if i == 5 && fm.showAuthList {
|
||||
b.WriteString("\n" + renderDropdown(fm.authList) + "\n")
|
||||
b.WriteString(helpStyle.Render("Enter select | Esc cancel"))
|
||||
return b.String()
|
||||
}
|
||||
if i == 8 && fm.showGroupList {
|
||||
b.WriteString("\n" + renderDropdown(fm.groupList) + "\n")
|
||||
b.WriteString(helpStyle.Render("Enter select | Esc cancel"))
|
||||
return b.String()
|
||||
}
|
||||
}
|
||||
|
||||
if endIdx < numInputs {
|
||||
b.WriteString(helpStyle.Render(fmt.Sprintf(" ↓ more fields below (%d-%d of %d)\n", startIdx+1, endIdx, numInputs)))
|
||||
}
|
||||
|
||||
b.WriteString(fm.password.View())
|
||||
|
|
@ -723,8 +1101,53 @@ func (fm *formModel) View() string {
|
|||
saveBtn = normalStyle.Render(saveBtn)
|
||||
}
|
||||
|
||||
b.WriteString("\n" + testBtn + " " + saveBtn + "\n\n")
|
||||
b.WriteString(helpStyle.Render("Tab/↓ next | ↑ prev | Enter select | Esc back"))
|
||||
b.WriteString("\n" + sectionStyle.Render("Actions") + "\n")
|
||||
b.WriteString(testBtn + " " + saveBtn + "\n\n")
|
||||
b.WriteString(helpStyle.Render("Tab/↓ next | ↑ prev | / pick list | Enter select | Esc back"))
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func renderDropdown(l list.Model) string {
|
||||
var b strings.Builder
|
||||
b.WriteString(sectionStyle.Render(l.Title))
|
||||
b.WriteString("\n")
|
||||
for i, item := range l.Items() {
|
||||
group, ok := item.(groupItem)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
prefix := " "
|
||||
style := normalStyle
|
||||
if i == l.Index() {
|
||||
prefix = "> "
|
||||
style = selectedRowStyle
|
||||
}
|
||||
b.WriteString(style.Render(prefix + group.name))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
return strings.TrimRight(b.String(), "\n")
|
||||
}
|
||||
|
||||
func formSectionTitle(index int) string {
|
||||
switch index {
|
||||
case 0:
|
||||
return "Identity"
|
||||
case 2:
|
||||
return "Connection"
|
||||
case 5:
|
||||
return "Authentication"
|
||||
case 8:
|
||||
return "Metadata"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// truncate limits a string to maxLen, adding "..." if truncated
|
||||
func truncate(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
return s
|
||||
}
|
||||
return s[:maxLen-3] + "..."
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,322 @@
|
|||
package tui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/bubbles/list"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/mirivlad/sshkeeper/internal/model"
|
||||
)
|
||||
|
||||
func TestServerListViewUsesDashboardLayout(t *testing.T) {
|
||||
now := time.Date(2026, 5, 28, 1, 50, 0, 0, time.UTC)
|
||||
m := New([]*model.Server{
|
||||
{
|
||||
Alias: "mail.kp",
|
||||
DisplayName: "Mail",
|
||||
Host: "mail.example.org",
|
||||
Port: 222,
|
||||
User: "mirivlad",
|
||||
AuthMethod: model.AuthPassword,
|
||||
GroupName: "KP",
|
||||
LastTestStatus: model.TestOK,
|
||||
LastTestAt: &now,
|
||||
},
|
||||
{
|
||||
Alias: "mirv.top",
|
||||
Host: "mirv.top",
|
||||
Port: 22,
|
||||
User: "root",
|
||||
AuthMethod: model.AuthKey,
|
||||
LastTestStatus: model.TestUnknown,
|
||||
},
|
||||
})
|
||||
m.width = 100
|
||||
m.height = 30
|
||||
m.list.SetSize(100, 24)
|
||||
|
||||
view := m.View()
|
||||
for _, want := range []string{
|
||||
"sshkeeper",
|
||||
"2 servers",
|
||||
"Vault",
|
||||
"NAME",
|
||||
"TARGET",
|
||||
"AUTH",
|
||||
"GROUP",
|
||||
"STATUS",
|
||||
"Mail",
|
||||
"mail.kp",
|
||||
"mirivlad@mail.example.org:222",
|
||||
"KP",
|
||||
"OK",
|
||||
"Selected",
|
||||
"Host: mail.example.org",
|
||||
"Alias: mail.kp",
|
||||
"Display Name: Mail",
|
||||
"Port: 222",
|
||||
"Enter connect",
|
||||
} {
|
||||
if !strings.Contains(view, want) {
|
||||
t.Fatalf("expected list view to contain %q\nview:\n%s", want, view)
|
||||
}
|
||||
}
|
||||
if strings.Contains(view, "Profiles managed locally") {
|
||||
t.Fatalf("expected compact status header instead of README text\nview:\n%s", view)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEscClosesGroupListBeforeLeavingForm(t *testing.T) {
|
||||
oldGetGroups := GetGroups
|
||||
GetGroups = func() ([]string, error) {
|
||||
return []string{"prod", "stage"}, nil
|
||||
}
|
||||
defer func() { GetGroups = oldGetGroups }()
|
||||
|
||||
m := &tuiModel{
|
||||
screen: screenForm,
|
||||
form: newFormModel(80, 24),
|
||||
}
|
||||
m.form.focusIdx = 8
|
||||
m.form.updateFocus()
|
||||
|
||||
updated, _ := m.updateForm(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'/'}})
|
||||
m = updated.(*tuiModel)
|
||||
if !m.form.showGroupList {
|
||||
t.Fatal("expected / on the group field to open the group list")
|
||||
}
|
||||
|
||||
updated, _ = m.updateForm(tea.KeyMsg{Type: tea.KeyEsc})
|
||||
m = updated.(*tuiModel)
|
||||
|
||||
if m.screen != screenForm {
|
||||
t.Fatalf("expected Esc to keep the user in the form, got screen %v", m.screen)
|
||||
}
|
||||
if m.form == nil {
|
||||
t.Fatal("expected form to remain open")
|
||||
}
|
||||
if m.form.showGroupList {
|
||||
t.Fatal("expected Esc to close only the group list")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEscClosesAuthMethodListBeforeLeavingForm(t *testing.T) {
|
||||
m := &tuiModel{
|
||||
screen: screenForm,
|
||||
form: newFormModel(80, 24),
|
||||
}
|
||||
m.form.focusIdx = 5
|
||||
m.form.updateFocus()
|
||||
|
||||
updated, _ := m.updateForm(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'/'}})
|
||||
m = updated.(*tuiModel)
|
||||
if !m.form.showAuthList {
|
||||
t.Fatal("expected / on the auth method field to open the auth method list")
|
||||
}
|
||||
|
||||
updated, _ = m.updateForm(tea.KeyMsg{Type: tea.KeyEsc})
|
||||
m = updated.(*tuiModel)
|
||||
|
||||
if m.screen != screenForm {
|
||||
t.Fatalf("expected Esc to keep the user in the form, got screen %v", m.screen)
|
||||
}
|
||||
if m.form == nil {
|
||||
t.Fatal("expected form to remain open")
|
||||
}
|
||||
if m.form.showAuthList {
|
||||
t.Fatal("expected Esc to close only the auth method list")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMethodListSelectsValue(t *testing.T) {
|
||||
fm := newFormModel(80, 24)
|
||||
fm.focusIdx = 5
|
||||
fm.updateFocus()
|
||||
|
||||
updated, _ := fm.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'/'}})
|
||||
fm = updated.(*formModel)
|
||||
if !fm.showAuthList {
|
||||
t.Fatal("expected / on the auth method field to open the auth method list")
|
||||
}
|
||||
|
||||
fm.authList.Select(2)
|
||||
updated, _ = fm.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
fm = updated.(*formModel)
|
||||
|
||||
if fm.showAuthList {
|
||||
t.Fatal("expected Enter to close auth method list")
|
||||
}
|
||||
if got := fm.inputs[5].Value(); got != string(model.AuthKeyPassphrase) {
|
||||
t.Fatalf("expected auth method %q, got %q", model.AuthKeyPassphrase, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMethodListViewShowsAllOptions(t *testing.T) {
|
||||
fm := newFormModel(80, 12)
|
||||
fm.focusIdx = 5
|
||||
fm.updateFocus()
|
||||
|
||||
updated, _ := fm.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'/'}})
|
||||
fm = updated.(*formModel)
|
||||
|
||||
view := fm.View()
|
||||
authPos := strings.Index(view, "Auth Method")
|
||||
listPos := strings.Index(view, "Select auth method")
|
||||
if authPos < 0 || listPos < 0 {
|
||||
t.Fatalf("expected auth field and auth list title in view\nview:\n%s", view)
|
||||
}
|
||||
if listPos < authPos {
|
||||
t.Fatalf("expected auth method list after auth field\nview:\n%s", view)
|
||||
}
|
||||
if between := view[authPos:listPos]; strings.Contains(between, "Identity File") {
|
||||
t.Fatalf("expected auth method list to render directly under auth field\nview:\n%s", view)
|
||||
}
|
||||
if strings.Contains(view, "│") {
|
||||
t.Fatalf("expected compact auth method dropdown without default list border\nview:\n%s", view)
|
||||
}
|
||||
for _, method := range []model.AuthMethod{
|
||||
model.AuthPassword,
|
||||
model.AuthKey,
|
||||
model.AuthKeyPassphrase,
|
||||
model.AuthAgent,
|
||||
} {
|
||||
if !strings.Contains(view, string(method)) {
|
||||
t.Fatalf("expected auth method list view to contain %q\nview:\n%s", method, view)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroupListViewRendersDirectlyUnderGroupField(t *testing.T) {
|
||||
oldGetGroups := GetGroups
|
||||
GetGroups = func() ([]string, error) {
|
||||
return []string{"KP", "MY"}, nil
|
||||
}
|
||||
defer func() { GetGroups = oldGetGroups }()
|
||||
|
||||
fm := newFormModel(80, 24)
|
||||
fm.focusIdx = 8
|
||||
fm.updateFocus()
|
||||
|
||||
updated, _ := fm.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'/'}})
|
||||
fm = updated.(*formModel)
|
||||
|
||||
view := fm.View()
|
||||
groupPos := strings.Index(view, "Group")
|
||||
listPos := strings.Index(view, "Select group")
|
||||
if groupPos < 0 || listPos < 0 {
|
||||
t.Fatalf("expected group field and group list title in view\nview:\n%s", view)
|
||||
}
|
||||
if listPos < groupPos {
|
||||
t.Fatalf("expected group list after group field\nview:\n%s", view)
|
||||
}
|
||||
if between := view[groupPos:listPos]; strings.Contains(between, "Password") {
|
||||
t.Fatalf("expected group dropdown to render before password field\nview:\n%s", view)
|
||||
}
|
||||
if strings.Contains(view, "│") {
|
||||
t.Fatalf("expected compact group dropdown without default list border\nview:\n%s", view)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectableFieldHintsAreVisible(t *testing.T) {
|
||||
oldGetGroups := GetGroups
|
||||
GetGroups = func() ([]string, error) {
|
||||
return []string{"KP"}, nil
|
||||
}
|
||||
defer func() { GetGroups = oldGetGroups }()
|
||||
|
||||
fm := newFormModel(80, 24)
|
||||
view := fm.View()
|
||||
for _, want := range []string{"Auth Method (/ pick)", "Group (/ pick)", "/ pick list"} {
|
||||
if !strings.Contains(view, want) {
|
||||
t.Fatalf("expected form view to contain selectable-field hint %q\nview:\n%s", want, view)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEditFormShowsSavedSecretMarker(t *testing.T) {
|
||||
oldHasSecret := HasSecret
|
||||
HasSecret = func(alias string, secretType string) bool {
|
||||
return alias == "prod" && secretType == "ssh_password"
|
||||
}
|
||||
defer func() { HasSecret = oldHasSecret }()
|
||||
|
||||
fm := newEditFormModel(&model.Server{
|
||||
Alias: "prod",
|
||||
Host: "example.org",
|
||||
Port: 22,
|
||||
User: "root",
|
||||
AuthMethod: model.AuthPassword,
|
||||
}, 80, 24)
|
||||
|
||||
view := fm.View()
|
||||
if !strings.Contains(view, "secret saved") {
|
||||
t.Fatalf("expected edit form to show saved secret marker\nview:\n%s", view)
|
||||
}
|
||||
if !strings.Contains(view, "leave blank to keep") {
|
||||
t.Fatalf("expected edit form to explain blank password keeps saved secret\nview:\n%s", view)
|
||||
}
|
||||
if strings.Count(view, "secret saved") != 1 {
|
||||
t.Fatalf("expected saved secret marker to appear once\nview:\n%s", view)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormViewUsesSectionsAndStableLabels(t *testing.T) {
|
||||
fm := newFormModel(100, 30)
|
||||
view := fm.View()
|
||||
|
||||
for _, want := range []string{
|
||||
"Identity",
|
||||
"Connection",
|
||||
"Authentication",
|
||||
"Metadata",
|
||||
"Actions",
|
||||
"Alias",
|
||||
"Display Name",
|
||||
"Auth Method",
|
||||
"Password / Passphrase",
|
||||
} {
|
||||
if !strings.Contains(view, want) {
|
||||
t.Fatalf("expected form view to contain %q\nview:\n%s", want, view)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormTestResultDoesNotUpdateSelectedListServer(t *testing.T) {
|
||||
oldUpdateTestResult := UpdateTestResult
|
||||
oldListServers := ListServers
|
||||
defer func() {
|
||||
UpdateTestResult = oldUpdateTestResult
|
||||
ListServers = oldListServers
|
||||
}()
|
||||
|
||||
updateCalled := false
|
||||
UpdateTestResult = func(alias string, status model.TestStatus, testErr string) error {
|
||||
updateCalled = true
|
||||
return nil
|
||||
}
|
||||
ListServers = func() ([]*model.Server, error) {
|
||||
t.Fatal("form test result should not reload server list")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
selected := &model.Server{Alias: "selected", Host: "example.org", Port: 22, User: "root"}
|
||||
m := New([]*model.Server{selected})
|
||||
m.screen = screenForm
|
||||
m.form = newFormModel(80, 24)
|
||||
m.list = list.New([]list.Item{serverItem{server: selected}}, list.NewDefaultDelegate(), 80, 20)
|
||||
|
||||
updated, cmd := m.Update(testDoneMsg{ok: true})
|
||||
m = updated.(*tuiModel)
|
||||
|
||||
if cmd != nil {
|
||||
t.Fatal("form test result should not return a reload command")
|
||||
}
|
||||
if updateCalled {
|
||||
t.Fatal("form test result should not update the selected list server")
|
||||
}
|
||||
if m.form == nil || m.form.testResult != "Connection OK." {
|
||||
t.Fatal("expected form to keep its test result")
|
||||
}
|
||||
}
|
||||
|
|
@ -8,6 +8,8 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
|
@ -16,18 +18,20 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
currentVersion = 1
|
||||
saltLen = 32
|
||||
nonceLen = 24
|
||||
keyLen = 32
|
||||
currentVersion = 1
|
||||
saltLen = 32
|
||||
nonceLen = 24
|
||||
keyLen = 32
|
||||
verifierID = "__sshkeeper_vault_verifier__"
|
||||
verifierPlaintext = "sshkeeper-vault-verifier-v1"
|
||||
)
|
||||
|
||||
type KDFMeta struct {
|
||||
Name string `json:"name"`
|
||||
MemoryKiB int `json:"memory_kib"`
|
||||
Iterations int `json:"iterations"`
|
||||
Parallelism int `json:"parallelism"`
|
||||
Salt string `json:"salt"`
|
||||
Name string `json:"name"`
|
||||
MemoryKiB int `json:"memory_kib"`
|
||||
Iterations int `json:"iterations"`
|
||||
Parallelism int `json:"parallelism"`
|
||||
Salt string `json:"salt"`
|
||||
}
|
||||
|
||||
type Record struct {
|
||||
|
|
@ -38,23 +42,35 @@ type Record struct {
|
|||
}
|
||||
|
||||
type VaultFile struct {
|
||||
Version int `json:"version"`
|
||||
KDF KDFMeta `json:"kdf"`
|
||||
Records []Record `json:"records"`
|
||||
Version int `json:"version"`
|
||||
KDF KDFMeta `json:"kdf"`
|
||||
Verifier *Record `json:"verifier,omitempty"`
|
||||
Records []Record `json:"records"`
|
||||
}
|
||||
|
||||
type Vault struct {
|
||||
mu sync.Mutex
|
||||
path string
|
||||
mu sync.Mutex
|
||||
path string
|
||||
masterKey []byte
|
||||
records map[string][]byte // id -> plaintext
|
||||
modified bool
|
||||
records map[string]secretRecord
|
||||
modified bool
|
||||
}
|
||||
|
||||
type secretRecord struct {
|
||||
secretType string
|
||||
plaintext []byte
|
||||
}
|
||||
|
||||
type SecretMeta struct {
|
||||
ID string
|
||||
Alias string
|
||||
Type string
|
||||
}
|
||||
|
||||
func New(path string) *Vault {
|
||||
return &Vault{
|
||||
path: path,
|
||||
records: make(map[string][]byte),
|
||||
records: make(map[string]secretRecord),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -76,21 +92,26 @@ func Create(path string, masterPassword string) error {
|
|||
|
||||
kdf := KDFMeta{
|
||||
Name: "argon2id",
|
||||
MemoryKiB: 4096,
|
||||
Iterations: 2,
|
||||
MemoryKiB: 65536,
|
||||
Iterations: 3,
|
||||
Parallelism: 1,
|
||||
Salt: base64.StdEncoding.EncodeToString(salt),
|
||||
}
|
||||
|
||||
fmt.Print("Deriving key...")
|
||||
|
||||
key := argon2.IDKey([]byte(masterPassword), salt, uint32(kdf.Iterations), uint32(kdf.MemoryKiB)*1024, uint8(kdf.Parallelism), keyLen)
|
||||
key := argon2.IDKey([]byte(masterPassword), salt, uint32(kdf.Iterations), uint32(kdf.MemoryKiB), uint8(kdf.Parallelism), keyLen)
|
||||
|
||||
verifier, err := newVerifierRecord(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create verifier: %w", err)
|
||||
}
|
||||
|
||||
// Verify key is valid by doing a test encrypt/decrypt
|
||||
vf := VaultFile{
|
||||
Version: currentVersion,
|
||||
KDF: kdf,
|
||||
Records: []Record{},
|
||||
Version: currentVersion,
|
||||
KDF: kdf,
|
||||
Verifier: &verifier,
|
||||
Records: []Record{},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(vf)
|
||||
|
|
@ -143,24 +164,32 @@ func (v *Vault) Unlock(masterPassword string) error {
|
|||
return fmt.Errorf("decode salt: %w", err)
|
||||
}
|
||||
|
||||
key := argon2.IDKey([]byte(masterPassword), salt, uint32(vf.KDF.Iterations), uint32(vf.KDF.MemoryKiB)*1024, uint8(vf.KDF.Parallelism), keyLen)
|
||||
key := argon2.IDKey([]byte(masterPassword), salt, uint32(vf.KDF.Iterations), uint32(vf.KDF.MemoryKiB), uint8(vf.KDF.Parallelism), keyLen)
|
||||
|
||||
// Try to decrypt first record to verify password
|
||||
if len(vf.Records) > 0 {
|
||||
if vf.Verifier != nil {
|
||||
if err := verifyRecord(key, *vf.Verifier); err != nil {
|
||||
return fmt.Errorf("invalid master password")
|
||||
}
|
||||
} else if len(vf.Records) > 0 {
|
||||
if _, err := decryptRecord(key, vf.Records[0]); err != nil {
|
||||
return fmt.Errorf("invalid master password")
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("vault cannot verify master password; recreate empty vault")
|
||||
}
|
||||
|
||||
v.masterKey = key
|
||||
v.records = make(map[string][]byte)
|
||||
v.records = make(map[string]secretRecord)
|
||||
|
||||
for _, rec := range vf.Records {
|
||||
plaintext, err := decryptRecord(key, rec)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decrypt record %s: %w", rec.ID, err)
|
||||
}
|
||||
v.records[rec.ID] = plaintext
|
||||
v.records[rec.ID] = secretRecord{
|
||||
secretType: inferSecretType(rec.ID, rec.Type),
|
||||
plaintext: plaintext,
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println(" done.")
|
||||
|
|
@ -178,7 +207,7 @@ func (v *Vault) Lock() {
|
|||
}
|
||||
}
|
||||
v.masterKey = nil
|
||||
v.records = make(map[string][]byte)
|
||||
v.records = make(map[string]secretRecord)
|
||||
}
|
||||
|
||||
// IsUnlocked returns whether the vault is currently unlocked
|
||||
|
|
@ -197,7 +226,9 @@ func (v *Vault) Put(id string, secretType string, plaintext []byte) error {
|
|||
return fmt.Errorf("vault is locked")
|
||||
}
|
||||
|
||||
v.records[id] = plaintext
|
||||
data := make([]byte, len(plaintext))
|
||||
copy(data, plaintext)
|
||||
v.records[id] = secretRecord{secretType: secretType, plaintext: data}
|
||||
v.modified = true
|
||||
return nil
|
||||
}
|
||||
|
|
@ -211,16 +242,55 @@ func (v *Vault) Get(id string) ([]byte, error) {
|
|||
return nil, fmt.Errorf("vault is locked")
|
||||
}
|
||||
|
||||
data, ok := v.records[id]
|
||||
record, ok := v.records[id]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("secret not found: %s", id)
|
||||
}
|
||||
|
||||
result := make([]byte, len(data))
|
||||
copy(result, data)
|
||||
result := make([]byte, len(record.plaintext))
|
||||
copy(result, record.plaintext)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (v *Vault) HasSecret(id string) bool {
|
||||
v.mu.Lock()
|
||||
defer v.mu.Unlock()
|
||||
_, ok := v.records[id]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (v *Vault) ListSecrets() ([]SecretMeta, error) {
|
||||
v.mu.Lock()
|
||||
defer v.mu.Unlock()
|
||||
|
||||
if v.masterKey == nil {
|
||||
return nil, fmt.Errorf("vault is locked")
|
||||
}
|
||||
|
||||
metas := make([]SecretMeta, 0, len(v.records))
|
||||
for id, record := range v.records {
|
||||
alias, secretType, ok := parseServerSecretID(id)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if record.secretType != "" {
|
||||
secretType = record.secretType
|
||||
}
|
||||
metas = append(metas, SecretMeta{
|
||||
ID: id,
|
||||
Alias: alias,
|
||||
Type: secretType,
|
||||
})
|
||||
}
|
||||
sort.Slice(metas, func(i, j int) bool {
|
||||
if metas[i].Alias == metas[j].Alias {
|
||||
return metas[i].Type < metas[j].Type
|
||||
}
|
||||
return metas[i].Alias < metas[j].Alias
|
||||
})
|
||||
return metas, nil
|
||||
}
|
||||
|
||||
// Delete removes a secret
|
||||
func (v *Vault) Delete(id string) {
|
||||
v.mu.Lock()
|
||||
|
|
@ -245,8 +315,8 @@ func (v *Vault) Save() error {
|
|||
|
||||
kdf := KDFMeta{
|
||||
Name: "argon2id",
|
||||
MemoryKiB: 4096,
|
||||
Iterations: 2,
|
||||
MemoryKiB: 65536,
|
||||
Iterations: 3,
|
||||
Parallelism: 1,
|
||||
Salt: base64.StdEncoding.EncodeToString(salt),
|
||||
}
|
||||
|
|
@ -254,18 +324,24 @@ func (v *Vault) Save() error {
|
|||
fmt.Print("Deriving key...")
|
||||
|
||||
var records []Record
|
||||
for id, plaintext := range v.records {
|
||||
rec, err := encryptRecord(v.masterKey, id, plaintext)
|
||||
for id, record := range v.records {
|
||||
rec, err := encryptRecordWithType(v.masterKey, id, record.secretType, record.plaintext)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt record %s: %w", id, err)
|
||||
}
|
||||
records = append(records, rec)
|
||||
}
|
||||
|
||||
verifier, err := newVerifierRecord(v.masterKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create verifier: %w", err)
|
||||
}
|
||||
|
||||
vf := VaultFile{
|
||||
Version: currentVersion,
|
||||
KDF: kdf,
|
||||
Records: records,
|
||||
Version: currentVersion,
|
||||
KDF: kdf,
|
||||
Verifier: &verifier,
|
||||
Records: records,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(vf)
|
||||
|
|
@ -309,12 +385,12 @@ func (v *Vault) ChangePassword(newPassword string) error {
|
|||
return fmt.Errorf("generate salt: %w", err)
|
||||
}
|
||||
|
||||
newKey := argon2.IDKey([]byte(newPassword), salt, 3, 8192*1024, 1, keyLen)
|
||||
newKey := argon2.IDKey([]byte(newPassword), salt, 3, 65536, 1, keyLen)
|
||||
|
||||
kdf := KDFMeta{
|
||||
Name: "argon2id",
|
||||
MemoryKiB: 4096,
|
||||
Iterations: 2,
|
||||
MemoryKiB: 65536,
|
||||
Iterations: 3,
|
||||
Parallelism: 1,
|
||||
Salt: base64.StdEncoding.EncodeToString(salt),
|
||||
}
|
||||
|
|
@ -322,18 +398,24 @@ func (v *Vault) ChangePassword(newPassword string) error {
|
|||
fmt.Print("Deriving key...")
|
||||
|
||||
var records []Record
|
||||
for id, plaintext := range v.records {
|
||||
rec, err := encryptRecord(newKey, id, plaintext)
|
||||
for id, record := range v.records {
|
||||
rec, err := encryptRecordWithType(newKey, id, record.secretType, record.plaintext)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt record: %w", err)
|
||||
}
|
||||
records = append(records, rec)
|
||||
}
|
||||
|
||||
verifier, err := newVerifierRecord(newKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create verifier: %w", err)
|
||||
}
|
||||
|
||||
vf := VaultFile{
|
||||
Version: currentVersion,
|
||||
KDF: kdf,
|
||||
Records: records,
|
||||
Version: currentVersion,
|
||||
KDF: kdf,
|
||||
Verifier: &verifier,
|
||||
Records: records,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(vf)
|
||||
|
|
@ -382,6 +464,10 @@ func (v *Vault) getSalt() string {
|
|||
}
|
||||
|
||||
func encryptRecord(key []byte, id string, plaintext []byte) (Record, error) {
|
||||
return encryptRecordWithType(key, id, "", plaintext)
|
||||
}
|
||||
|
||||
func encryptRecordWithType(key []byte, id string, secretType string, plaintext []byte) (Record, error) {
|
||||
aead, err := chacha20poly1305.NewX(key)
|
||||
if err != nil {
|
||||
return Record{}, err
|
||||
|
|
@ -396,11 +482,31 @@ func encryptRecord(key []byte, id string, plaintext []byte) (Record, error) {
|
|||
|
||||
return Record{
|
||||
ID: id,
|
||||
Type: secretType,
|
||||
Nonce: base64.StdEncoding.EncodeToString(nonce),
|
||||
Ciphertext: base64.StdEncoding.EncodeToString(ciphertext),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func inferSecretType(id string, recordType string) string {
|
||||
if recordType != "" {
|
||||
return recordType
|
||||
}
|
||||
_, secretType, ok := parseServerSecretID(id)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return secretType
|
||||
}
|
||||
|
||||
func parseServerSecretID(id string) (string, string, bool) {
|
||||
parts := strings.Split(id, ":")
|
||||
if len(parts) != 3 || parts[0] != "server" || parts[1] == "" || parts[2] == "" {
|
||||
return "", "", false
|
||||
}
|
||||
return parts[1], parts[2], true
|
||||
}
|
||||
|
||||
func decryptRecord(key []byte, rec Record) ([]byte, error) {
|
||||
aead, err := chacha20poly1305.NewX(key)
|
||||
if err != nil {
|
||||
|
|
@ -425,6 +531,26 @@ func decryptRecord(key []byte, rec Record) ([]byte, error) {
|
|||
return plaintext, nil
|
||||
}
|
||||
|
||||
func newVerifierRecord(key []byte) (Record, error) {
|
||||
rec, err := encryptRecord(key, verifierID, []byte(verifierPlaintext))
|
||||
if err != nil {
|
||||
return Record{}, err
|
||||
}
|
||||
rec.Type = "verifier"
|
||||
return rec, nil
|
||||
}
|
||||
|
||||
func verifyRecord(key []byte, rec Record) error {
|
||||
plaintext, err := decryptRecord(key, rec)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !SecureCompare(string(plaintext), verifierPlaintext) {
|
||||
return fmt.Errorf("invalid verifier")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyPassword checks if a master password is correct without unlocking
|
||||
func VerifyPassword(path string, masterPassword string) (bool, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
|
|
@ -442,18 +568,20 @@ func VerifyPassword(path string, masterPassword string) (bool, error) {
|
|||
return false, err
|
||||
}
|
||||
|
||||
key := argon2.IDKey([]byte(masterPassword), salt, uint32(vf.KDF.Iterations), uint32(vf.KDF.MemoryKiB)*1024, uint8(vf.KDF.Parallelism), keyLen)
|
||||
key := argon2.IDKey([]byte(masterPassword), salt, uint32(vf.KDF.Iterations), uint32(vf.KDF.MemoryKiB), uint8(vf.KDF.Parallelism), keyLen)
|
||||
defer func() {
|
||||
for i := range key {
|
||||
key[i] = 0
|
||||
}
|
||||
}()
|
||||
|
||||
if len(vf.Records) == 0 {
|
||||
// Empty vault, try a test encryption
|
||||
return true, nil
|
||||
if vf.Verifier != nil {
|
||||
return verifyRecord(key, *vf.Verifier) == nil, nil
|
||||
}
|
||||
|
||||
if len(vf.Records) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
_, err = decryptRecord(key, vf.Records[0])
|
||||
return err == nil, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,218 @@
|
|||
package vault
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
)
|
||||
|
||||
func TestNewEmptyVaultRejectsWrongPassword(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "vault.bin")
|
||||
|
||||
if err := Create(path, "correct horse"); err != nil {
|
||||
t.Fatalf("create vault: %v", err)
|
||||
}
|
||||
|
||||
ok, err := VerifyPassword(path, "wrong horse")
|
||||
if err != nil {
|
||||
t.Fatalf("verify wrong password: %v", err)
|
||||
}
|
||||
if ok {
|
||||
t.Fatal("expected wrong password to be rejected for a new empty vault")
|
||||
}
|
||||
|
||||
v := New(path)
|
||||
if err := v.Unlock("wrong horse"); err == nil {
|
||||
t.Fatal("expected unlock with wrong password to fail for a new empty vault")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewEmptyVaultAcceptsCorrectPassword(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "vault.bin")
|
||||
|
||||
if err := Create(path, "correct horse"); err != nil {
|
||||
t.Fatalf("create vault: %v", err)
|
||||
}
|
||||
|
||||
ok, err := VerifyPassword(path, "correct horse")
|
||||
if err != nil {
|
||||
t.Fatalf("verify correct password: %v", err)
|
||||
}
|
||||
if !ok {
|
||||
t.Fatal("expected correct password to be accepted for a new empty vault")
|
||||
}
|
||||
|
||||
v := New(path)
|
||||
if err := v.Unlock("correct horse"); err != nil {
|
||||
t.Fatalf("unlock with correct password: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLegacyVaultWithRecordsStillVerifiesByFirstRecord(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "vault.bin")
|
||||
salt := []byte("12345678901234567890123456789012")
|
||||
key := argon2.IDKey([]byte("correct horse"), salt, 3, 65536, 1, keyLen)
|
||||
|
||||
rec, err := encryptRecord(key, "server:test:ssh_password", []byte("secret"))
|
||||
if err != nil {
|
||||
t.Fatalf("encrypt legacy record: %v", err)
|
||||
}
|
||||
|
||||
data, err := json.Marshal(VaultFile{
|
||||
Version: currentVersion,
|
||||
KDF: KDFMeta{
|
||||
Name: "argon2id",
|
||||
MemoryKiB: 65536,
|
||||
Iterations: 3,
|
||||
Parallelism: 1,
|
||||
Salt: base64.StdEncoding.EncodeToString(salt),
|
||||
},
|
||||
Records: []Record{rec},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal legacy vault: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(path, data, 0o600); err != nil {
|
||||
t.Fatalf("write legacy vault: %v", err)
|
||||
}
|
||||
|
||||
ok, err := VerifyPassword(path, "correct horse")
|
||||
if err != nil {
|
||||
t.Fatalf("verify legacy vault: %v", err)
|
||||
}
|
||||
if !ok {
|
||||
t.Fatal("expected legacy vault with records to accept correct password")
|
||||
}
|
||||
|
||||
ok, err = VerifyPassword(path, "wrong horse")
|
||||
if err != nil {
|
||||
t.Fatalf("verify legacy vault with wrong password: %v", err)
|
||||
}
|
||||
if ok {
|
||||
t.Fatal("expected legacy vault with records to reject wrong password")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLegacyEmptyVaultWithoutVerifierCannotUnlock(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "vault.bin")
|
||||
salt := []byte("12345678901234567890123456789012")
|
||||
|
||||
data, err := json.Marshal(VaultFile{
|
||||
Version: currentVersion,
|
||||
KDF: KDFMeta{
|
||||
Name: "argon2id",
|
||||
MemoryKiB: 65536,
|
||||
Iterations: 3,
|
||||
Parallelism: 1,
|
||||
Salt: base64.StdEncoding.EncodeToString(salt),
|
||||
},
|
||||
Records: []Record{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal legacy empty vault: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(path, data, 0o600); err != nil {
|
||||
t.Fatalf("write legacy empty vault: %v", err)
|
||||
}
|
||||
|
||||
ok, err := VerifyPassword(path, "any password")
|
||||
if err != nil {
|
||||
t.Fatalf("verify legacy empty vault: %v", err)
|
||||
}
|
||||
if ok {
|
||||
t.Fatal("expected legacy empty vault without verifier to be unverifiable")
|
||||
}
|
||||
|
||||
v := New(path)
|
||||
if err := v.Unlock("any password"); err == nil {
|
||||
t.Fatal("expected legacy empty vault without verifier to reject unlock")
|
||||
}
|
||||
}
|
||||
|
||||
func TestListSecretsReturnsMetadataWithoutPlaintext(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "vault.bin")
|
||||
if err := Create(path, "master"); err != nil {
|
||||
t.Fatalf("create vault: %v", err)
|
||||
}
|
||||
v := New(path)
|
||||
if err := v.Unlock("master"); err != nil {
|
||||
t.Fatalf("unlock vault: %v", err)
|
||||
}
|
||||
if err := v.Put("server:prod:ssh_password", "ssh_password", []byte("secret-password")); err != nil {
|
||||
t.Fatalf("put password: %v", err)
|
||||
}
|
||||
if err := v.Put("server:prod:key_passphrase", "key_passphrase", []byte("secret-passphrase")); err != nil {
|
||||
t.Fatalf("put passphrase: %v", err)
|
||||
}
|
||||
|
||||
metas, err := v.ListSecrets()
|
||||
if err != nil {
|
||||
t.Fatalf("list secrets: %v", err)
|
||||
}
|
||||
if len(metas) != 2 {
|
||||
t.Fatalf("expected 2 secret metadata records, got %d: %#v", len(metas), metas)
|
||||
}
|
||||
if metas[0].Alias != "prod" || metas[0].Type != "key_passphrase" {
|
||||
t.Fatalf("unexpected first metadata record: %#v", metas[0])
|
||||
}
|
||||
if metas[1].Alias != "prod" || metas[1].Type != "ssh_password" {
|
||||
t.Fatalf("unexpected second metadata record: %#v", metas[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestListSecretsPreservesTypesAfterSaveAndUnlock(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "vault.bin")
|
||||
if err := Create(path, "master"); err != nil {
|
||||
t.Fatalf("create vault: %v", err)
|
||||
}
|
||||
v := New(path)
|
||||
if err := v.Unlock("master"); err != nil {
|
||||
t.Fatalf("unlock vault: %v", err)
|
||||
}
|
||||
if err := v.Put("server:prod:ssh_password", "ssh_password", []byte("secret-password")); err != nil {
|
||||
t.Fatalf("put password: %v", err)
|
||||
}
|
||||
if err := v.Save(); err != nil {
|
||||
t.Fatalf("save vault: %v", err)
|
||||
}
|
||||
|
||||
reopened := New(path)
|
||||
if err := reopened.Unlock("master"); err != nil {
|
||||
t.Fatalf("unlock reopened vault: %v", err)
|
||||
}
|
||||
metas, err := reopened.ListSecrets()
|
||||
if err != nil {
|
||||
t.Fatalf("list reopened secrets: %v", err)
|
||||
}
|
||||
if len(metas) != 1 {
|
||||
t.Fatalf("expected 1 secret metadata record, got %d: %#v", len(metas), metas)
|
||||
}
|
||||
if metas[0].ID != "server:prod:ssh_password" || metas[0].Alias != "prod" || metas[0].Type != "ssh_password" {
|
||||
t.Fatalf("unexpected metadata after reopen: %#v", metas[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasSecretReportsPresenceWithoutReturningValue(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "vault.bin")
|
||||
if err := Create(path, "master"); err != nil {
|
||||
t.Fatalf("create vault: %v", err)
|
||||
}
|
||||
v := New(path)
|
||||
if err := v.Unlock("master"); err != nil {
|
||||
t.Fatalf("unlock vault: %v", err)
|
||||
}
|
||||
if err := v.Put("server:prod:ssh_password", "ssh_password", []byte("secret-password")); err != nil {
|
||||
t.Fatalf("put password: %v", err)
|
||||
}
|
||||
|
||||
if !v.HasSecret("server:prod:ssh_password") {
|
||||
t.Fatal("expected saved password to be reported present")
|
||||
}
|
||||
if v.HasSecret("server:prod:key_passphrase") {
|
||||
t.Fatal("expected missing passphrase to be reported absent")
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue