feat: improve tui and vault handling
This commit is contained in:
parent
73c60b9f93
commit
e1d709396b
51
cmd/add.go
51
cmd/add.go
|
|
@ -3,9 +3,11 @@ package cmd
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"github.com/mirivlad/sshkeeper/internal/model"
|
"github.com/mirivlad/sshkeeper/internal/model"
|
||||||
|
"golang.org/x/term"
|
||||||
)
|
)
|
||||||
|
|
||||||
var addFlags struct {
|
var addFlags struct {
|
||||||
|
|
@ -19,7 +21,6 @@ var addFlags struct {
|
||||||
displayName string
|
displayName string
|
||||||
notes string
|
notes string
|
||||||
tags string
|
tags string
|
||||||
password string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var addCmd = &cobra.Command{
|
var addCmd = &cobra.Command{
|
||||||
|
|
@ -59,11 +60,21 @@ func addNonInteractive(alias string) error {
|
||||||
server.DisplayName = alias
|
server.DisplayName = alias
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle password auth - store in vault
|
// Handle password/passphrase auth — request interactively, never via argv
|
||||||
if server.AuthMethod == model.AuthPassword {
|
if server.AuthMethod == model.AuthPassword || server.AuthMethod == model.AuthKeyPassphrase {
|
||||||
password := addFlags.password
|
secretType := "password"
|
||||||
if password == "" {
|
if server.AuthMethod == model.AuthKeyPassphrase {
|
||||||
return fmt.Errorf("password auth requires --password flag or interactive mode")
|
secretType = "passphrase"
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Enter %s (will be stored in vault, input hidden): ", secretType)
|
||||||
|
password, err := term.ReadPassword(int(syscall.Stdin))
|
||||||
|
fmt.Println()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read %s: %w", secretType, err)
|
||||||
|
}
|
||||||
|
if len(password) == 0 {
|
||||||
|
return fmt.Errorf("%s cannot be empty", secretType)
|
||||||
}
|
}
|
||||||
|
|
||||||
v := getOrCreateVault()
|
v := getOrCreateVault()
|
||||||
|
|
@ -72,29 +83,14 @@ func addNonInteractive(alias string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
vaultKey := fmt.Sprintf("server:%s:ssh_password", alias)
|
vaultKey := fmt.Sprintf("server:%s:ssh_password", alias)
|
||||||
if err := v.Put(vaultKey, "ssh_password", []byte(password)); err != nil {
|
vaultType := "ssh_password"
|
||||||
return fmt.Errorf("store password in vault: %w", err)
|
if server.AuthMethod == model.AuthKeyPassphrase {
|
||||||
}
|
vaultKey = fmt.Sprintf("server:%s:key_passphrase", alias)
|
||||||
if err := v.Save(); err != nil {
|
vaultType = "key_passphrase"
|
||||||
return fmt.Errorf("save vault: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle key+passphrase - store passphrase in vault
|
|
||||||
if server.AuthMethod == model.AuthKeyPassphrase {
|
|
||||||
passphrase := addFlags.password
|
|
||||||
if passphrase == "" {
|
|
||||||
return fmt.Errorf("key+passphrase auth requires --password flag for the passphrase")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
v := getOrCreateVault()
|
if err := v.Put(vaultKey, vaultType, password); err != nil {
|
||||||
if !v.IsUnlocked() {
|
return fmt.Errorf("store %s in vault: %w", secretType, err)
|
||||||
return fmt.Errorf("vault is locked. Run 'sshkeeper vault unlock' first")
|
|
||||||
}
|
|
||||||
|
|
||||||
vaultKey := fmt.Sprintf("server:%s:key_passphrase", alias)
|
|
||||||
if err := v.Put(vaultKey, "key_passphrase", []byte(passphrase)); err != nil {
|
|
||||||
return fmt.Errorf("store passphrase in vault: %w", err)
|
|
||||||
}
|
}
|
||||||
if err := v.Save(); err != nil {
|
if err := v.Save(); err != nil {
|
||||||
return fmt.Errorf("save vault: %w", err)
|
return fmt.Errorf("save vault: %w", err)
|
||||||
|
|
@ -132,5 +128,4 @@ func init() {
|
||||||
addCmd.Flags().StringVar(&addFlags.displayName, "display-name", "", "Display name")
|
addCmd.Flags().StringVar(&addFlags.displayName, "display-name", "", "Display name")
|
||||||
addCmd.Flags().StringVar(&addFlags.notes, "notes", "", "Notes")
|
addCmd.Flags().StringVar(&addFlags.notes, "notes", "", "Notes")
|
||||||
addCmd.Flags().StringVar(&addFlags.tags, "tags", "", "Comma-separated tags")
|
addCmd.Flags().StringVar(&addFlags.tags, "tags", "", "Comma-separated tags")
|
||||||
addCmd.Flags().StringVar(&addFlags.password, "password", "", "SSH password or key passphrase (stored in vault)")
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,15 @@ var deleteCmd = &cobra.Command{
|
||||||
return fmt.Errorf("delete server: %w", err)
|
return fmt.Errorf("delete server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clean up vault secrets for this server
|
||||||
|
v := getOrCreateVault()
|
||||||
|
if v.IsUnlocked() {
|
||||||
|
cleanupServerSecrets(v, alias)
|
||||||
|
if err := v.Save(); err != nil {
|
||||||
|
return fmt.Errorf("save vault after cleanup: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fmt.Println("Deleted.")
|
fmt.Println("Deleted.")
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
|
|
||||||
57
cmd/edit.go
57
cmd/edit.go
|
|
@ -2,9 +2,11 @@ package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
|
||||||
"github.com/mirivlad/sshkeeper/internal/model"
|
"github.com/mirivlad/sshkeeper/internal/model"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"golang.org/x/term"
|
||||||
)
|
)
|
||||||
|
|
||||||
var editCmd = &cobra.Command{
|
var editCmd = &cobra.Command{
|
||||||
|
|
@ -18,6 +20,8 @@ var editCmd = &cobra.Command{
|
||||||
return fmt.Errorf("server not found: %s", alias)
|
return fmt.Errorf("server not found: %s", alias)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
oldAuthMethod := server.AuthMethod
|
||||||
|
|
||||||
if parsedHost != "" {
|
if parsedHost != "" {
|
||||||
server.Host = parsedHost
|
server.Host = parsedHost
|
||||||
}
|
}
|
||||||
|
|
@ -46,6 +50,41 @@ var editCmd = &cobra.Command{
|
||||||
server.Notes = parsedNotes
|
server.Notes = parsedNotes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if parsedAuth != "" && oldAuthMethod != server.AuthMethod {
|
||||||
|
v := getOrCreateVault()
|
||||||
|
if v.IsUnlocked() {
|
||||||
|
var secret string
|
||||||
|
if server.AuthMethod == model.AuthPassword {
|
||||||
|
fmt.Print("Enter new password (stored in vault, input hidden): ")
|
||||||
|
pw, err := term.ReadPassword(int(syscall.Stdin))
|
||||||
|
fmt.Println()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read password: %w", err)
|
||||||
|
}
|
||||||
|
if len(pw) > 0 {
|
||||||
|
secret = string(pw)
|
||||||
|
}
|
||||||
|
} else if server.AuthMethod == model.AuthKeyPassphrase {
|
||||||
|
fmt.Print("Enter key passphrase (stored in vault, input hidden): ")
|
||||||
|
pw, err := term.ReadPassword(int(syscall.Stdin))
|
||||||
|
fmt.Println()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read passphrase: %w", err)
|
||||||
|
}
|
||||||
|
if len(pw) > 0 {
|
||||||
|
secret = string(pw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := syncServerSecrets(v, alias, server, secret); err != nil {
|
||||||
|
return fmt.Errorf("sync vault secrets: %w", err)
|
||||||
|
}
|
||||||
|
if err := v.Save(); err != nil {
|
||||||
|
return fmt.Errorf("save vault: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := appDB.UpdateServer(server); err != nil {
|
if err := appDB.UpdateServer(server); err != nil {
|
||||||
return fmt.Errorf("update server: %w", err)
|
return fmt.Errorf("update server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
@ -56,15 +95,15 @@ var editCmd = &cobra.Command{
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
parsedHost string
|
parsedHost string
|
||||||
parsedPort int
|
parsedPort int
|
||||||
parsedUser string
|
parsedUser string
|
||||||
parsedAuth string
|
parsedAuth string
|
||||||
parsedIdentity string
|
parsedIdentity string
|
||||||
parsedProxyJump string
|
parsedProxyJump string
|
||||||
parsedGroup string
|
parsedGroup string
|
||||||
parsedDisplayName string
|
parsedDisplayName string
|
||||||
parsedNotes string
|
parsedNotes string
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|
|
||||||
35
cmd/extra.go
35
cmd/extra.go
|
|
@ -74,8 +74,13 @@ var runCmd = &cobra.Command{
|
||||||
return fmt.Errorf("server not found: %s", alias)
|
return fmt.Errorf("server not found: %s", alias)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build ssh args with the command
|
// For password auth, use PTY-wrapper with command
|
||||||
sshArgs := buildSSHArgs(server)
|
if server.AuthMethod == model.AuthPassword {
|
||||||
|
return runWithPassword(server, command)
|
||||||
|
}
|
||||||
|
|
||||||
|
// For key/agent auth — direct execution
|
||||||
|
sshArgs := ssh.BuildSSHArgs(server)
|
||||||
sshArgs = append(sshArgs, command)
|
sshArgs = append(sshArgs, command)
|
||||||
|
|
||||||
sshCmd := exec.Command(cfg.SSH.Binary, sshArgs...)
|
sshCmd := exec.Command(cfg.SSH.Binary, sshArgs...)
|
||||||
|
|
@ -91,17 +96,21 @@ var runCmd = &cobra.Command{
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildSSHArgs(server *model.Server) []string {
|
// runWithPassword runs a command on a server with password auth via PTY-wrapper.
|
||||||
var args []string
|
func runWithPassword(server *model.Server, command string) error {
|
||||||
args = append(args, "-p", fmt.Sprintf("%d", server.Port))
|
v := getOrCreateVault()
|
||||||
if server.IdentityFile != "" {
|
if !v.IsUnlocked() {
|
||||||
args = append(args, "-i", server.IdentityFile)
|
return fmt.Errorf("vault is locked. Run 'sshkeeper vault unlock' first")
|
||||||
}
|
}
|
||||||
if server.ProxyJump != "" {
|
|
||||||
args = append(args, "-J", server.ProxyJump)
|
vaultKey := fmt.Sprintf("server:%s:ssh_password", server.Alias)
|
||||||
|
password, err := v.Get(vaultKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get password from vault: %w", err)
|
||||||
}
|
}
|
||||||
args = append(args, "-o", "StrictHostKeyChecking=accept-new")
|
|
||||||
target := fmt.Sprintf("%s@%s", server.User, server.Host)
|
sshArgs := ssh.BuildSSHArgs(server)
|
||||||
args = append(args, target)
|
sshArgs = append(sshArgs, command)
|
||||||
return args
|
|
||||||
|
return ssh.ConnectWithPassword(cfg.SSH.Binary, sshArgs, string(password))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,85 @@
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/mirivlad/sshkeeper/internal/model"
|
||||||
|
"github.com/mirivlad/sshkeeper/internal/ssh"
|
||||||
|
"github.com/mirivlad/sshkeeper/internal/vault"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
secretSSHPassword = "ssh_password"
|
||||||
|
secretKeyPassphrase = "key_passphrase"
|
||||||
|
secretSudoPassword = "sudo_password"
|
||||||
|
)
|
||||||
|
|
||||||
|
var serverSecretTypes = []string{
|
||||||
|
secretSSHPassword,
|
||||||
|
secretKeyPassphrase,
|
||||||
|
secretSudoPassword,
|
||||||
|
}
|
||||||
|
|
||||||
|
func serverSecretID(alias, secretType string) string {
|
||||||
|
return fmt.Sprintf("server:%s:%s", alias, secretType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupServerSecrets(v *vault.Vault, alias string) {
|
||||||
|
for _, secretType := range serverSecretTypes {
|
||||||
|
v.Delete(serverSecretID(alias, secretType))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func syncServerSecrets(v *vault.Vault, oldAlias string, server *model.Server, secret string) error {
|
||||||
|
if oldAlias == "" {
|
||||||
|
oldAlias = server.Alias
|
||||||
|
}
|
||||||
|
if oldAlias != server.Alias {
|
||||||
|
for _, secretType := range serverSecretTypes {
|
||||||
|
oldID := serverSecretID(oldAlias, secretType)
|
||||||
|
data, err := v.Get(oldID)
|
||||||
|
if err == nil {
|
||||||
|
if err := v.Put(serverSecretID(server.Alias, secretType), secretType, data); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
v.Delete(oldID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch server.AuthMethod {
|
||||||
|
case model.AuthPassword:
|
||||||
|
v.Delete(serverSecretID(server.Alias, secretKeyPassphrase))
|
||||||
|
if secret != "" {
|
||||||
|
return v.Put(serverSecretID(server.Alias, secretSSHPassword), secretSSHPassword, []byte(secret))
|
||||||
|
}
|
||||||
|
case model.AuthKeyPassphrase:
|
||||||
|
v.Delete(serverSecretID(server.Alias, secretSSHPassword))
|
||||||
|
if secret != "" {
|
||||||
|
return v.Put(serverSecretID(server.Alias, secretKeyPassphrase), secretKeyPassphrase, []byte(secret))
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
v.Delete(serverSecretID(server.Alias, secretSSHPassword))
|
||||||
|
v.Delete(serverSecretID(server.Alias, secretKeyPassphrase))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func deleteVaultSecrets(v *vault.Vault, alias string, secretType string) error {
|
||||||
|
if secretType != "" {
|
||||||
|
v.Delete(serverSecretID(alias, secretType))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cleanupServerSecrets(v, alias)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func formTestVaultFunc(getVault ssh.VaultFunc, server *model.Server, formSecret string) ssh.VaultFunc {
|
||||||
|
return func(serverAlias string, secretType string) (string, error) {
|
||||||
|
if (secretType == secretSSHPassword || secretType == secretKeyPassphrase) && formSecret != "" {
|
||||||
|
return formSecret, nil
|
||||||
|
}
|
||||||
|
return getVault(serverAlias, secretType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,186 @@
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/mirivlad/sshkeeper/internal/model"
|
||||||
|
"github.com/mirivlad/sshkeeper/internal/vault"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newUnlockedTestVault(t *testing.T) *vault.Vault {
|
||||||
|
t.Helper()
|
||||||
|
path := filepath.Join(t.TempDir(), "vault.bin")
|
||||||
|
if err := vault.Create(path, "master"); err != nil {
|
||||||
|
t.Fatalf("create vault: %v", err)
|
||||||
|
}
|
||||||
|
v := vault.New(path)
|
||||||
|
if err := v.Unlock("master"); err != nil {
|
||||||
|
t.Fatalf("unlock vault: %v", err)
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustPutSecret(t *testing.T, v *vault.Vault, alias, secretType, value string) {
|
||||||
|
t.Helper()
|
||||||
|
if err := v.Put(serverSecretID(alias, secretType), secretType, []byte(value)); err != nil {
|
||||||
|
t.Fatalf("put %s: %v", secretType, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustGetSecret(t *testing.T, v *vault.Vault, alias, secretType string) string {
|
||||||
|
t.Helper()
|
||||||
|
data, err := v.Get(serverSecretID(alias, secretType))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("get %s: %v", secretType, err)
|
||||||
|
}
|
||||||
|
return string(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertSecretMissing(t *testing.T, v *vault.Vault, alias, secretType string) {
|
||||||
|
t.Helper()
|
||||||
|
if data, err := v.Get(serverSecretID(alias, secretType)); err == nil {
|
||||||
|
t.Fatalf("expected %s for %s to be missing, got %q", secretType, alias, string(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCleanupServerSecretsRemovesAuthAndSudoSecrets(t *testing.T) {
|
||||||
|
v := newUnlockedTestVault(t)
|
||||||
|
mustPutSecret(t, v, "prod", "ssh_password", "password")
|
||||||
|
mustPutSecret(t, v, "prod", "key_passphrase", "passphrase")
|
||||||
|
mustPutSecret(t, v, "prod", "sudo_password", "sudo")
|
||||||
|
|
||||||
|
cleanupServerSecrets(v, "prod")
|
||||||
|
|
||||||
|
assertSecretMissing(t, v, "prod", "ssh_password")
|
||||||
|
assertSecretMissing(t, v, "prod", "key_passphrase")
|
||||||
|
assertSecretMissing(t, v, "prod", "sudo_password")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncServerSecretsDeletesAuthSecretsForKeyAuth(t *testing.T) {
|
||||||
|
v := newUnlockedTestVault(t)
|
||||||
|
mustPutSecret(t, v, "prod", "ssh_password", "password")
|
||||||
|
mustPutSecret(t, v, "prod", "key_passphrase", "passphrase")
|
||||||
|
|
||||||
|
server := &model.Server{Alias: "prod", AuthMethod: model.AuthKey}
|
||||||
|
if err := syncServerSecrets(v, "prod", server, ""); err != nil {
|
||||||
|
t.Fatalf("sync secrets: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assertSecretMissing(t, v, "prod", "ssh_password")
|
||||||
|
assertSecretMissing(t, v, "prod", "key_passphrase")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncServerSecretsKeepsExistingPasswordWhenEditPasswordIsBlank(t *testing.T) {
|
||||||
|
v := newUnlockedTestVault(t)
|
||||||
|
mustPutSecret(t, v, "prod", "ssh_password", "old-password")
|
||||||
|
mustPutSecret(t, v, "prod", "key_passphrase", "old-passphrase")
|
||||||
|
|
||||||
|
server := &model.Server{Alias: "prod", AuthMethod: model.AuthPassword}
|
||||||
|
if err := syncServerSecrets(v, "prod", server, ""); err != nil {
|
||||||
|
t.Fatalf("sync secrets: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := mustGetSecret(t, v, "prod", "ssh_password"); got != "old-password" {
|
||||||
|
t.Fatalf("expected password to remain, got %q", got)
|
||||||
|
}
|
||||||
|
assertSecretMissing(t, v, "prod", "key_passphrase")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncServerSecretsRenamesSecretsBeforeApplyingAuthCleanup(t *testing.T) {
|
||||||
|
v := newUnlockedTestVault(t)
|
||||||
|
mustPutSecret(t, v, "old", "ssh_password", "password")
|
||||||
|
mustPutSecret(t, v, "old", "key_passphrase", "passphrase")
|
||||||
|
mustPutSecret(t, v, "old", "sudo_password", "sudo")
|
||||||
|
|
||||||
|
server := &model.Server{Alias: "new", AuthMethod: model.AuthKeyPassphrase}
|
||||||
|
if err := syncServerSecrets(v, "old", server, ""); err != nil {
|
||||||
|
t.Fatalf("sync secrets: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assertSecretMissing(t, v, "old", "ssh_password")
|
||||||
|
assertSecretMissing(t, v, "old", "key_passphrase")
|
||||||
|
assertSecretMissing(t, v, "old", "sudo_password")
|
||||||
|
assertSecretMissing(t, v, "new", "ssh_password")
|
||||||
|
if got := mustGetSecret(t, v, "new", "key_passphrase"); got != "passphrase" {
|
||||||
|
t.Fatalf("expected key passphrase to move, got %q", got)
|
||||||
|
}
|
||||||
|
if got := mustGetSecret(t, v, "new", "sudo_password"); got != "sudo" {
|
||||||
|
t.Fatalf("expected sudo password to move, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncServerSecretsStoresNewSecretForSelectedAuthMethod(t *testing.T) {
|
||||||
|
v := newUnlockedTestVault(t)
|
||||||
|
mustPutSecret(t, v, "prod", "key_passphrase", "old-passphrase")
|
||||||
|
|
||||||
|
server := &model.Server{Alias: "prod", AuthMethod: model.AuthPassword}
|
||||||
|
if err := syncServerSecrets(v, "prod", server, "new-password"); err != nil {
|
||||||
|
t.Fatalf("sync secrets: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := mustGetSecret(t, v, "prod", "ssh_password"); got != "new-password" {
|
||||||
|
t.Fatalf("expected new password, got %q", got)
|
||||||
|
}
|
||||||
|
assertSecretMissing(t, v, "prod", "key_passphrase")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteVaultSecretsForAliasAndType(t *testing.T) {
|
||||||
|
v := newUnlockedTestVault(t)
|
||||||
|
mustPutSecret(t, v, "prod", "ssh_password", "password")
|
||||||
|
mustPutSecret(t, v, "prod", "key_passphrase", "passphrase")
|
||||||
|
|
||||||
|
if err := deleteVaultSecrets(v, "prod", "ssh_password"); err != nil {
|
||||||
|
t.Fatalf("delete vault secret: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assertSecretMissing(t, v, "prod", "ssh_password")
|
||||||
|
if got := mustGetSecret(t, v, "prod", "key_passphrase"); got != "passphrase" {
|
||||||
|
t.Fatalf("expected passphrase to remain, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteVaultSecretsForAlias(t *testing.T) {
|
||||||
|
v := newUnlockedTestVault(t)
|
||||||
|
mustPutSecret(t, v, "prod", "ssh_password", "password")
|
||||||
|
mustPutSecret(t, v, "prod", "key_passphrase", "passphrase")
|
||||||
|
|
||||||
|
if err := deleteVaultSecrets(v, "prod", ""); err != nil {
|
||||||
|
t.Fatalf("delete vault secrets: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assertSecretMissing(t, v, "prod", "ssh_password")
|
||||||
|
assertSecretMissing(t, v, "prod", "key_passphrase")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormTestVaultFuncUsesSavedSecretWhenFormSecretBlank(t *testing.T) {
|
||||||
|
vaultFunc := formTestVaultFunc(func(serverAlias string, secretType string) (string, error) {
|
||||||
|
if serverAlias != "prod" || secretType != "ssh_password" {
|
||||||
|
return "", fmt.Errorf("unexpected secret lookup %s %s", serverAlias, secretType)
|
||||||
|
}
|
||||||
|
return "saved-password", nil
|
||||||
|
}, &model.Server{Alias: "prod", AuthMethod: model.AuthPassword}, "")
|
||||||
|
|
||||||
|
got, err := vaultFunc("prod", "ssh_password")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("get password: %v", err)
|
||||||
|
}
|
||||||
|
if got != "saved-password" {
|
||||||
|
t.Fatalf("expected saved password, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormTestVaultFuncUsesFormSecretWhenProvided(t *testing.T) {
|
||||||
|
vaultFunc := formTestVaultFunc(func(serverAlias string, secretType string) (string, error) {
|
||||||
|
return "", fmt.Errorf("saved secret should not be used")
|
||||||
|
}, &model.Server{Alias: "prod", AuthMethod: model.AuthPassword}, "typed-password")
|
||||||
|
|
||||||
|
got, err := vaultFunc("prod", "ssh_password")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("get password: %v", err)
|
||||||
|
}
|
||||||
|
if got != "typed-password" {
|
||||||
|
t.Fatalf("expected typed password, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -2,8 +2,11 @@ package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
"github.com/mirivlad/sshkeeper/internal/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
var templateCmd = &cobra.Command{
|
var templateCmd = &cobra.Command{
|
||||||
|
|
@ -70,13 +73,15 @@ var runTemplateCmd = &cobra.Command{
|
||||||
alias := args[0]
|
alias := args[0]
|
||||||
templateName := args[1]
|
templateName := args[1]
|
||||||
|
|
||||||
|
server, err := appDB.GetServer(alias)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("server not found: %s", alias)
|
||||||
|
}
|
||||||
|
|
||||||
templates, err := appDB.GetCommandTemplates(alias)
|
templates, err := appDB.GetCommandTemplates(alias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("list templates: %w", err)
|
return fmt.Errorf("list templates: %w", err)
|
||||||
}
|
}
|
||||||
if len(templates) == 0 {
|
|
||||||
return fmt.Errorf("server not found or no templates: %s", alias)
|
|
||||||
}
|
|
||||||
|
|
||||||
var command string
|
var command string
|
||||||
for _, t := range templates {
|
for _, t := range templates {
|
||||||
|
|
@ -91,7 +96,21 @@ var runTemplateCmd = &cobra.Command{
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("Running '%s' on %s...\n", command, alias)
|
fmt.Printf("Running '%s' on %s...\n", command, alias)
|
||||||
return nil
|
|
||||||
|
// Build ssh args with the command
|
||||||
|
sshArgs := ssh.BuildSSHArgs(server)
|
||||||
|
sshArgs = append(sshArgs, command)
|
||||||
|
|
||||||
|
sshCmd := exec.Command(cfg.SSH.Binary, sshArgs...)
|
||||||
|
sshCmd.Stdin = os.Stdin
|
||||||
|
sshCmd.Stdout = os.Stdout
|
||||||
|
sshCmd.Stderr = os.Stderr
|
||||||
|
|
||||||
|
if err := sshCmd.Start(); err != nil {
|
||||||
|
return fmt.Errorf("start ssh: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return sshCmd.Wait()
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
56
cmd/tui.go
56
cmd/tui.go
|
|
@ -36,41 +36,43 @@ func runTUI() error {
|
||||||
return appDB.SearchServers(query)
|
return appDB.SearchServers(query)
|
||||||
}
|
}
|
||||||
tui.DeleteServer = func(alias string) error {
|
tui.DeleteServer = func(alias string) error {
|
||||||
return appDB.DeleteServer(alias)
|
if err := appDB.DeleteServer(alias); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
v := getOrCreateVault()
|
||||||
|
if v.IsUnlocked() {
|
||||||
|
cleanupServerSecrets(v, alias)
|
||||||
|
if err := v.Save(); err != nil {
|
||||||
|
return fmt.Errorf("save vault after cleanup: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
tui.TestConnection = func(server *model.Server) (bool, string) {
|
tui.TestConnection = func(server *model.Server) (bool, string) {
|
||||||
return ssh.Test(cfg, server, vaultFunc)
|
return ssh.Test(cfg, server, vaultFunc)
|
||||||
}
|
}
|
||||||
tui.TestConnectionWithPassword = func(server *model.Server, password string) (bool, string) {
|
tui.TestConnectionWithPassword = func(server *model.Server, password string) (bool, string) {
|
||||||
directVaultFunc := func(sa string, st string) (string, error) {
|
return ssh.Test(cfg, server, formTestVaultFunc(vaultFunc, server, password))
|
||||||
if st == "ssh_password" || st == "key_passphrase" {
|
|
||||||
return password, nil
|
|
||||||
}
|
|
||||||
return vaultFunc(sa, st)
|
|
||||||
}
|
|
||||||
return ssh.Test(cfg, server, directVaultFunc)
|
|
||||||
}
|
}
|
||||||
tui.SaveServer = func(server *model.Server, password string) error {
|
tui.SaveServer = func(server *model.Server, password string, oldAlias string) error {
|
||||||
if password != "" {
|
v := getOrCreateVault()
|
||||||
v := getOrCreateVault()
|
if v.IsUnlocked() {
|
||||||
vaultKey := fmt.Sprintf("server:%s:ssh_password", server.Alias)
|
if err := syncServerSecrets(v, oldAlias, server, password); err != nil {
|
||||||
secretType := "ssh_password"
|
return fmt.Errorf("sync vault secrets: %w", err)
|
||||||
if server.AuthMethod == model.AuthKeyPassphrase {
|
|
||||||
vaultKey = fmt.Sprintf("server:%s:key_passphrase", server.Alias)
|
|
||||||
secretType = "key_passphrase"
|
|
||||||
}
|
|
||||||
if err := v.Put(vaultKey, secretType, []byte(password)); err != nil {
|
|
||||||
return fmt.Errorf("store secret: %w", err)
|
|
||||||
}
|
}
|
||||||
if err := v.Save(); err != nil {
|
if err := v.Save(); err != nil {
|
||||||
return fmt.Errorf("save vault: %w", err)
|
return fmt.Errorf("save vault: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
existing, _ := appDB.GetServer(server.Alias)
|
lookupAlias := server.Alias
|
||||||
|
if oldAlias != "" {
|
||||||
|
lookupAlias = oldAlias
|
||||||
|
}
|
||||||
|
existing, _ := appDB.GetServer(lookupAlias)
|
||||||
if existing != nil {
|
if existing != nil {
|
||||||
server.ID = existing.ID
|
server.ID = existing.ID
|
||||||
return appDB.UpdateServer(server)
|
return appDB.UpdateServerByAlias(existing.Alias, server)
|
||||||
}
|
}
|
||||||
return appDB.CreateServer(server)
|
return appDB.CreateServer(server)
|
||||||
}
|
}
|
||||||
|
|
@ -84,6 +86,16 @@ func runTUI() error {
|
||||||
tui.DeleteGroup = func(name string) error {
|
tui.DeleteGroup = func(name string) error {
|
||||||
return appDB.DeleteGroup(name)
|
return appDB.DeleteGroup(name)
|
||||||
}
|
}
|
||||||
|
tui.UpdateTestResult = func(alias string, status model.TestStatus, testErr string) error {
|
||||||
|
return appDB.UpdateTestResult(alias, status, testErr)
|
||||||
|
}
|
||||||
|
tui.HasSecret = func(alias string, secretType string) bool {
|
||||||
|
v := getOrCreateVault()
|
||||||
|
if !v.IsUnlocked() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return v.HasSecret(serverSecretID(alias, secretType))
|
||||||
|
}
|
||||||
|
|
||||||
// Run TUI in a loop — if user requests connect, handle it and restart TUI
|
// Run TUI in a loop — if user requests connect, handle it and restart TUI
|
||||||
for {
|
for {
|
||||||
|
|
@ -132,5 +144,3 @@ func runTUI() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
69
cmd/vault.go
69
cmd/vault.go
|
|
@ -3,11 +3,12 @@ package cmd
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
|
||||||
"github.com/mirivlad/sshkeeper/internal/config"
|
"github.com/mirivlad/sshkeeper/internal/config"
|
||||||
"github.com/mirivlad/sshkeeper/internal/vault"
|
"github.com/mirivlad/sshkeeper/internal/vault"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
"golang.org/x/term"
|
"golang.org/x/term"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -160,9 +161,75 @@ var vaultChangePasswordCmd = &cobra.Command{
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var vaultListCmd = &cobra.Command{
|
||||||
|
Use: "list",
|
||||||
|
Short: "List stored secret metadata",
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
v := getOrCreateVault()
|
||||||
|
if !v.IsUnlocked() {
|
||||||
|
return fmt.Errorf("vault is locked. Unlock first with 'sshkeeper vault unlock'")
|
||||||
|
}
|
||||||
|
output, err := formatVaultSecretsList(v)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fmt.Print(output)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var vaultDeleteCmd = &cobra.Command{
|
||||||
|
Use: "delete <alias> [type]",
|
||||||
|
Short: "Delete stored secrets for a server",
|
||||||
|
Args: cobra.RangeArgs(1, 2),
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
alias := args[0]
|
||||||
|
secretType := ""
|
||||||
|
if len(args) == 2 {
|
||||||
|
secretType = args[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
v := getOrCreateVault()
|
||||||
|
if !v.IsUnlocked() {
|
||||||
|
return fmt.Errorf("vault is locked. Unlock first with 'sshkeeper vault unlock'")
|
||||||
|
}
|
||||||
|
if err := deleteVaultSecrets(v, alias, secretType); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := v.Save(); err != nil {
|
||||||
|
return fmt.Errorf("save vault: %w", err)
|
||||||
|
}
|
||||||
|
if secretType == "" {
|
||||||
|
fmt.Printf("Deleted secrets for %s.\n", alias)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Deleted %s for %s.\n", secretType, alias)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatVaultSecretsList(v *vault.Vault) (string, error) {
|
||||||
|
metas, err := v.ListSecrets()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if len(metas) == 0 {
|
||||||
|
return "No secrets stored.\n", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var b strings.Builder
|
||||||
|
fmt.Fprintf(&b, "%-24s %-18s\n", "ALIAS", "TYPE")
|
||||||
|
for _, meta := range metas {
|
||||||
|
fmt.Fprintf(&b, "%-24s %-18s\n", meta.Alias, meta.Type)
|
||||||
|
}
|
||||||
|
return b.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
vaultCmd.AddCommand(vaultUnlockCmd)
|
vaultCmd.AddCommand(vaultUnlockCmd)
|
||||||
vaultCmd.AddCommand(vaultLockCmd)
|
vaultCmd.AddCommand(vaultLockCmd)
|
||||||
vaultCmd.AddCommand(vaultStatusCmd)
|
vaultCmd.AddCommand(vaultStatusCmd)
|
||||||
vaultCmd.AddCommand(vaultChangePasswordCmd)
|
vaultCmd.AddCommand(vaultChangePasswordCmd)
|
||||||
|
vaultCmd.AddCommand(vaultListCmd)
|
||||||
|
vaultCmd.AddCommand(vaultDeleteCmd)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,40 @@
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFormatVaultSecretsListDoesNotExposeSecretValues(t *testing.T) {
|
||||||
|
v := newUnlockedTestVault(t)
|
||||||
|
mustPutSecret(t, v, "prod", "ssh_password", "super-secret")
|
||||||
|
mustPutSecret(t, v, "stage", "key_passphrase", "also-secret")
|
||||||
|
|
||||||
|
output, err := formatVaultSecretsList(v)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("format vault secrets list: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, want := range []string{"prod", "ssh_password", "stage", "key_passphrase"} {
|
||||||
|
if !strings.Contains(output, want) {
|
||||||
|
t.Fatalf("expected output to contain %q\noutput:\n%s", want, output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, secretValue := range []string{"super-secret", "also-secret"} {
|
||||||
|
if strings.Contains(output, secretValue) {
|
||||||
|
t.Fatalf("expected output not to expose secret value %q\noutput:\n%s", secretValue, output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatVaultSecretsListHandlesEmptyVault(t *testing.T) {
|
||||||
|
v := newUnlockedTestVault(t)
|
||||||
|
|
||||||
|
output, err := formatVaultSecretsList(v)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("format empty vault secrets list: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "No secrets stored.") {
|
||||||
|
t.Fatalf("expected empty output message, got:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,295 @@
|
||||||
|
# Server List Scalability Implementation Plan
|
||||||
|
|
||||||
|
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
|
||||||
|
|
||||||
|
**Goal:** Keep the server list usable when the saved server count is larger than the terminal height.
|
||||||
|
|
||||||
|
**Architecture:** The list screen should render a bounded table viewport instead of rendering every server row. Selection remains owned by the existing `bubbles/list.Model`, while `viewServerList` derives a visible row range around the selected item and keeps the selected-server detail panel and footer visible.
|
||||||
|
|
||||||
|
**Tech Stack:** Go, Bubble Tea, Bubbles list, Lip Gloss, existing TUI tests in `internal/tui/app_test.go`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 1: Add Regression Coverage For Long Server Lists
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `internal/tui/app_test.go`
|
||||||
|
|
||||||
|
- [ ] **Step 1: Add a test for a constrained terminal height**
|
||||||
|
|
||||||
|
Add a test that creates more servers than can fit on screen, sets a small terminal size, renders the list, and verifies the selected details and footer are still visible.
|
||||||
|
|
||||||
|
```go
|
||||||
|
func TestServerListViewKeepsDetailsVisibleWithManyServers(t *testing.T) {
|
||||||
|
servers := make([]*model.Server, 45)
|
||||||
|
for i := range servers {
|
||||||
|
servers[i] = &model.Server{
|
||||||
|
Alias: fmt.Sprintf("server-%02d", i+1),
|
||||||
|
DisplayName: fmt.Sprintf("Server %02d", i+1),
|
||||||
|
Host: fmt.Sprintf("host-%02d.example.org", i+1),
|
||||||
|
Port: 22,
|
||||||
|
User: "mirivlad",
|
||||||
|
AuthMethod: model.AuthKey,
|
||||||
|
LastTestStatus: model.TestUnknown,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m := New(servers)
|
||||||
|
updated, _ := m.Update(tea.WindowSizeMsg{Width: 100, Height: 18})
|
||||||
|
model := updated.(*tuiModel)
|
||||||
|
|
||||||
|
view := model.View()
|
||||||
|
if !strings.Contains(view, "Server 01") {
|
||||||
|
t.Fatalf("expected first selected server to be visible:\n%s", view)
|
||||||
|
}
|
||||||
|
if !strings.Contains(view, "Selected") {
|
||||||
|
t.Fatalf("expected selected server details to remain visible:\n%s", view)
|
||||||
|
}
|
||||||
|
if !strings.Contains(view, "Enter connect") {
|
||||||
|
t.Fatalf("expected footer to remain visible:\n%s", view)
|
||||||
|
}
|
||||||
|
if count := strings.Count(view, "server-"); count >= len(servers) {
|
||||||
|
t.Fatalf("expected bounded row rendering, rendered %d server aliases", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 2: Run the focused test and confirm it fails**
|
||||||
|
|
||||||
|
Run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
env GOCACHE=/tmp/sshkeeper-go-cache go test ./internal/tui -run TestServerListViewKeepsDetailsVisibleWithManyServers -count=1
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected: FAIL because the current table renders every server and can push the detail panel/footer below the visible terminal area.
|
||||||
|
|
||||||
|
### Task 2: Compute A Visible Row Window
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `internal/tui/app.go`
|
||||||
|
- Modify: `internal/tui/app_test.go`
|
||||||
|
|
||||||
|
- [ ] **Step 1: Add focused tests for visible range calculation**
|
||||||
|
|
||||||
|
Add tests for a helper that computes the inclusive start and exclusive end indexes for rendered rows.
|
||||||
|
|
||||||
|
```go
|
||||||
|
func TestVisibleServerRangeKeepsSelectionInsideWindow(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
total int
|
||||||
|
selected int
|
||||||
|
available int
|
||||||
|
wantStart int
|
||||||
|
wantEnd int
|
||||||
|
}{
|
||||||
|
{name: "first page", total: 40, selected: 0, available: 10, wantStart: 0, wantEnd: 10},
|
||||||
|
{name: "middle page", total: 40, selected: 20, available: 10, wantStart: 11, wantEnd: 21},
|
||||||
|
{name: "last page", total: 40, selected: 39, available: 10, wantStart: 30, wantEnd: 40},
|
||||||
|
{name: "all fit", total: 5, selected: 3, available: 10, wantStart: 0, wantEnd: 5},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
start, end := visibleServerRange(tt.total, tt.selected, tt.available)
|
||||||
|
if start != tt.wantStart || end != tt.wantEnd {
|
||||||
|
t.Fatalf("visibleServerRange() = %d, %d; want %d, %d", start, end, tt.wantStart, tt.wantEnd)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 2: Implement `visibleServerRange`**
|
||||||
|
|
||||||
|
Add a small helper near `selectedServer`.
|
||||||
|
|
||||||
|
```go
|
||||||
|
func visibleServerRange(total, selected, available int) (int, int) {
|
||||||
|
if total <= 0 || available <= 0 {
|
||||||
|
return 0, 0
|
||||||
|
}
|
||||||
|
if available >= total {
|
||||||
|
return 0, total
|
||||||
|
}
|
||||||
|
if selected < 0 {
|
||||||
|
selected = 0
|
||||||
|
}
|
||||||
|
if selected >= total {
|
||||||
|
selected = total - 1
|
||||||
|
}
|
||||||
|
|
||||||
|
start := selected - available + 1
|
||||||
|
if start < 0 {
|
||||||
|
start = 0
|
||||||
|
}
|
||||||
|
end := start + available
|
||||||
|
if end > total {
|
||||||
|
end = total
|
||||||
|
start = end - available
|
||||||
|
}
|
||||||
|
return start, end
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 3: Run helper tests**
|
||||||
|
|
||||||
|
Run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
env GOCACHE=/tmp/sshkeeper-go-cache go test ./internal/tui -run TestVisibleServerRangeKeepsSelectionInsideWindow -count=1
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected: PASS.
|
||||||
|
|
||||||
|
### Task 3: Render Only Rows That Fit
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `internal/tui/app.go`
|
||||||
|
- Modify: `internal/tui/app_test.go`
|
||||||
|
|
||||||
|
- [ ] **Step 1: Reserve terminal space for fixed UI blocks**
|
||||||
|
|
||||||
|
Add a helper that decides how many server rows may be rendered while keeping the selected details and footer visible.
|
||||||
|
|
||||||
|
```go
|
||||||
|
func (m *tuiModel) visibleServerRows() int {
|
||||||
|
if m.height <= 0 {
|
||||||
|
return len(m.servers)
|
||||||
|
}
|
||||||
|
|
||||||
|
const fixedRows = 16
|
||||||
|
rows := m.height - fixedRows
|
||||||
|
if rows < 3 {
|
||||||
|
return 3
|
||||||
|
}
|
||||||
|
return rows
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 2: Use the visible range in `viewServerList`**
|
||||||
|
|
||||||
|
In `viewServerList`, replace the loop over all servers with a bounded loop:
|
||||||
|
|
||||||
|
```go
|
||||||
|
selectedIndex := m.list.Index()
|
||||||
|
start, end := visibleServerRange(len(m.servers), selectedIndex, m.visibleServerRows())
|
||||||
|
for _, server := range m.servers[start:end] {
|
||||||
|
// existing row rendering body stays unchanged
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Then render a compact range hint when rows are hidden:
|
||||||
|
|
||||||
|
```go
|
||||||
|
if len(m.servers) > end-start {
|
||||||
|
b.WriteString(helpStyle.Render(fmt.Sprintf(" Showing %d-%d of %d", start+1, end, len(m.servers))))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 3: Run long-list regression test**
|
||||||
|
|
||||||
|
Run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
env GOCACHE=/tmp/sshkeeper-go-cache go test ./internal/tui -run TestServerListViewKeepsDetailsVisibleWithManyServers -count=1
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected: PASS.
|
||||||
|
|
||||||
|
### Task 4: Verify Navigation Still Works
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `internal/tui/app_test.go`
|
||||||
|
|
||||||
|
- [ ] **Step 1: Add a test for moving selection beyond the first window**
|
||||||
|
|
||||||
|
Use the existing `m.list.Update` path by sending `tea.KeyDown` messages and confirm the rendered window follows the selected server.
|
||||||
|
|
||||||
|
```go
|
||||||
|
func TestServerListViewScrollsWithSelection(t *testing.T) {
|
||||||
|
servers := make([]*model.Server, 45)
|
||||||
|
for i := range servers {
|
||||||
|
servers[i] = &model.Server{
|
||||||
|
Alias: fmt.Sprintf("server-%02d", i+1),
|
||||||
|
DisplayName: fmt.Sprintf("Server %02d", i+1),
|
||||||
|
Host: fmt.Sprintf("host-%02d.example.org", i+1),
|
||||||
|
Port: 22,
|
||||||
|
User: "mirivlad",
|
||||||
|
AuthMethod: model.AuthKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m := New(servers)
|
||||||
|
updated, _ := m.Update(tea.WindowSizeMsg{Width: 100, Height: 18})
|
||||||
|
model := updated.(*tuiModel)
|
||||||
|
for i := 0; i < 20; i++ {
|
||||||
|
updated, _ = model.Update(tea.KeyMsg{Type: tea.KeyDown})
|
||||||
|
model = updated.(*tuiModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
view := model.View()
|
||||||
|
if !strings.Contains(view, "Server 21") {
|
||||||
|
t.Fatalf("expected selected server to be visible after navigation:\n%s", view)
|
||||||
|
}
|
||||||
|
if !strings.Contains(view, "Showing") {
|
||||||
|
t.Fatalf("expected range hint for long server list:\n%s", view)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 2: Run TUI tests**
|
||||||
|
|
||||||
|
Run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
env GOCACHE=/tmp/sshkeeper-go-cache go test ./internal/tui -count=1
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected: PASS.
|
||||||
|
|
||||||
|
### Task 5: Final Verification And Build
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- No source edits expected.
|
||||||
|
|
||||||
|
- [ ] **Step 1: Run the full test suite**
|
||||||
|
|
||||||
|
Run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
env GOCACHE=/tmp/sshkeeper-go-cache go test ./...
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected: all packages pass.
|
||||||
|
|
||||||
|
- [ ] **Step 2: Rebuild the project binary**
|
||||||
|
|
||||||
|
Run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
env GOCACHE=/tmp/sshkeeper-go-cache go build -o bin/sshkeeper .
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected: exit code 0 and updated `bin/sshkeeper`.
|
||||||
|
|
||||||
|
- [ ] **Step 3: Commit the implementation**
|
||||||
|
|
||||||
|
Run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add internal/tui/app.go internal/tui/app_test.go bin/sshkeeper
|
||||||
|
git commit -m "fix: keep server list usable with many servers"
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected: commit succeeds.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Self-Review
|
||||||
|
|
||||||
|
- Spec coverage: the plan covers the known failure mode for 40+ servers, keeps selected details visible, keeps the footer visible, and preserves existing `bubbles/list` navigation.
|
||||||
|
- Placeholder scan: no `TBD`, `TODO`, or open-ended implementation placeholders remain.
|
||||||
|
- Type consistency: helper names and files match the current TUI code shape.
|
||||||
2
go.mod
2
go.mod
|
|
@ -10,6 +10,7 @@ require (
|
||||||
github.com/creack/pty v1.1.24
|
github.com/creack/pty v1.1.24
|
||||||
github.com/spf13/cobra v1.10.2
|
github.com/spf13/cobra v1.10.2
|
||||||
golang.org/x/crypto v0.52.0
|
golang.org/x/crypto v0.52.0
|
||||||
|
golang.org/x/sys v0.45.0
|
||||||
golang.org/x/term v0.43.0
|
golang.org/x/term v0.43.0
|
||||||
modernc.org/sqlite v1.50.1
|
modernc.org/sqlite v1.50.1
|
||||||
)
|
)
|
||||||
|
|
@ -41,7 +42,6 @@ require (
|
||||||
github.com/sahilm/fuzzy v0.1.1 // indirect
|
github.com/sahilm/fuzzy v0.1.1 // indirect
|
||||||
github.com/spf13/pflag v1.0.9 // indirect
|
github.com/spf13/pflag v1.0.9 // indirect
|
||||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||||
golang.org/x/sys v0.45.0 // indirect
|
|
||||||
golang.org/x/text v0.37.0 // indirect
|
golang.org/x/text v0.37.0 // indirect
|
||||||
modernc.org/libc v1.72.3 // indirect
|
modernc.org/libc v1.72.3 // indirect
|
||||||
modernc.org/mathutil v1.7.1 // indirect
|
modernc.org/mathutil v1.7.1 // indirect
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,17 @@ func (db *DB) UpdateServer(s *model.Server) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (db *DB) UpdateServerByAlias(oldAlias string, s *model.Server) error {
|
||||||
|
_, err := db.conn.Exec(`
|
||||||
|
UPDATE servers SET
|
||||||
|
alias=?, display_name=?, host=?, port=?, user=?, auth_method=?,
|
||||||
|
identity_file=?, proxy_jump=?, group_name=?, notes=?, updated_at=CURRENT_TIMESTAMP
|
||||||
|
WHERE alias=?`,
|
||||||
|
s.Alias, s.DisplayName, s.Host, s.Port, s.User, s.AuthMethod,
|
||||||
|
s.IdentityFile, s.ProxyJump, s.GroupName, s.Notes, oldAlias)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (db *DB) DeleteServer(alias string) error {
|
func (db *DB) DeleteServer(alias string) error {
|
||||||
_, err := db.conn.Exec("DELETE FROM servers WHERE alias=?", alias)
|
_, err := db.conn.Exec("DELETE FROM servers WHERE alias=?", alias)
|
||||||
return err
|
return err
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,47 @@
|
||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/mirivlad/sshkeeper/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUpdateServerByAliasCanRenameAlias(t *testing.T) {
|
||||||
|
db, err := Open(t.TempDir())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open db: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
server := &model.Server{
|
||||||
|
Alias: "old",
|
||||||
|
Host: "old.example",
|
||||||
|
Port: 22,
|
||||||
|
User: "root",
|
||||||
|
AuthMethod: model.AuthPassword,
|
||||||
|
}
|
||||||
|
if err := db.CreateServer(server); err != nil {
|
||||||
|
t.Fatalf("create server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
server.Alias = "new"
|
||||||
|
server.Host = "new.example"
|
||||||
|
server.AuthMethod = model.AuthKey
|
||||||
|
if err := db.UpdateServerByAlias("old", server); err != nil {
|
||||||
|
t.Fatalf("update server by old alias: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := db.GetServer("old"); err == nil {
|
||||||
|
t.Fatal("expected old alias to be gone")
|
||||||
|
}
|
||||||
|
got, err := db.GetServer("new")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("get new alias: %v", err)
|
||||||
|
}
|
||||||
|
if got.ID != server.ID {
|
||||||
|
t.Fatalf("expected ID to stay %d, got %d", server.ID, got.ID)
|
||||||
|
}
|
||||||
|
if got.Host != "new.example" || got.AuthMethod != model.AuthKey {
|
||||||
|
t.Fatalf("unexpected updated server: %#v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/mirivlad/sshkeeper/internal/config"
|
"github.com/mirivlad/sshkeeper/internal/config"
|
||||||
"github.com/mirivlad/sshkeeper/internal/model"
|
"github.com/mirivlad/sshkeeper/internal/model"
|
||||||
|
|
@ -14,7 +13,7 @@ import (
|
||||||
type VaultFunc func(serverAlias string, secretType string) (string, error)
|
type VaultFunc func(serverAlias string, secretType string) (string, error)
|
||||||
|
|
||||||
func Connect(cfg *config.Config, server *model.Server, getVault VaultFunc) error {
|
func Connect(cfg *config.Config, server *model.Server, getVault VaultFunc) error {
|
||||||
args := buildArgs(server)
|
args := BuildSSHArgs(server)
|
||||||
|
|
||||||
switch server.AuthMethod {
|
switch server.AuthMethod {
|
||||||
case model.AuthPassword:
|
case model.AuthPassword:
|
||||||
|
|
@ -25,8 +24,7 @@ func Connect(cfg *config.Config, server *model.Server, getVault VaultFunc) error
|
||||||
return ConnectWithPassword(cfg.SSH.Binary, args, password)
|
return ConnectWithPassword(cfg.SSH.Binary, args, password)
|
||||||
|
|
||||||
case model.AuthKeyPassphrase:
|
case model.AuthKeyPassphrase:
|
||||||
// For key+passphrase, we need to handle the passphrase
|
// For key+passphrase, let ssh-agent handle it or prompt normally
|
||||||
// For now, let ssh-agent handle it or prompt normally
|
|
||||||
// TODO: use ssh-agent or similar
|
// TODO: use ssh-agent or similar
|
||||||
fallthrough
|
fallthrough
|
||||||
|
|
||||||
|
|
@ -46,13 +44,11 @@ func Connect(cfg *config.Config, server *model.Server, getVault VaultFunc) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test(cfg *config.Config, server *model.Server, getVault VaultFunc) (bool, string) {
|
func Test(cfg *config.Config, server *model.Server, getVault VaultFunc) (bool, string) {
|
||||||
args := buildArgs(server)
|
args := BuildSSHArgs(server)
|
||||||
args = append(args, "-o", fmt.Sprintf("ConnectTimeout=%d", cfg.SSH.ConnectTimeoutSec))
|
args = append(args, "-o", fmt.Sprintf("ConnectTimeout=%d", cfg.SSH.ConnectTimeoutSec))
|
||||||
|
|
||||||
switch server.AuthMethod {
|
switch server.AuthMethod {
|
||||||
case model.AuthPassword:
|
case model.AuthPassword:
|
||||||
// For password auth, we can't use BatchMode
|
|
||||||
// Use a short timeout and try to connect
|
|
||||||
args = append(args, "-o", "NumberOfPasswordPrompts=1")
|
args = append(args, "-o", "NumberOfPasswordPrompts=1")
|
||||||
password, err := getVault(server.Alias, "ssh_password")
|
password, err := getVault(server.Alias, "ssh_password")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -81,40 +77,29 @@ func Test(cfg *config.Config, server *model.Server, getVault VaultFunc) (bool, s
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// testWithPassword tests SSH connection with password auth via PTY-wrapper.
|
||||||
|
// It connects, sends the password, runs the test command, and checks the output.
|
||||||
func testWithPassword(cfg *config.Config, args []string, password string) (bool, string) {
|
func testWithPassword(cfg *config.Config, args []string, password string) (bool, string) {
|
||||||
// For password test, we use PTY approach with a short timeout
|
|
||||||
// This is a simplified version - in production, use ConnectWithPassword
|
|
||||||
// with a test command
|
|
||||||
args = append(args, cfg.SSH.TestCommand)
|
args = append(args, cfg.SSH.TestCommand)
|
||||||
|
|
||||||
cmd := exec.Command(cfg.SSH.Binary, args...)
|
ok, output := connectWithPasswordAndRead(cfg.SSH.Binary, args, password, cfg.SSH.ConnectTimeoutSec)
|
||||||
cmd.Stdin = nil
|
if !ok {
|
||||||
cmd.Stdout = nil
|
return false, output
|
||||||
cmd.Stderr = nil
|
|
||||||
|
|
||||||
// Use a timeout
|
|
||||||
done := make(chan error, 1)
|
|
||||||
if err := cmd.Start(); err != nil {
|
|
||||||
return false, err.Error()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
result := strings.TrimSpace(output)
|
||||||
done <- cmd.Wait()
|
if result == "SSHKEEPER_OK" {
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case err := <-done:
|
|
||||||
if err != nil {
|
|
||||||
return false, err.Error()
|
|
||||||
}
|
|
||||||
return true, ""
|
return true, ""
|
||||||
case <-time.After(time.Duration(cfg.SSH.ConnectTimeoutSec) * time.Second):
|
|
||||||
cmd.Process.Kill()
|
|
||||||
return false, "connection timeout"
|
|
||||||
}
|
}
|
||||||
|
// The output might have the test command echo before the result
|
||||||
|
if strings.Contains(result, "SSHKEEPER_OK") {
|
||||||
|
return true, ""
|
||||||
|
}
|
||||||
|
return false, result
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildArgs(server *model.Server) []string {
|
// BuildSSHArgs builds the SSH command arguments for a server profile.
|
||||||
|
func BuildSSHArgs(server *model.Server) []string {
|
||||||
var args []string
|
var args []string
|
||||||
|
|
||||||
args = append(args, "-p", fmt.Sprintf("%d", server.Port))
|
args = append(args, "-p", fmt.Sprintf("%d", server.Port))
|
||||||
|
|
@ -127,8 +112,6 @@ func buildArgs(server *model.Server) []string {
|
||||||
args = append(args, "-J", server.ProxyJump)
|
args = append(args, "-J", server.ProxyJump)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Disable strict host key checking for first connection
|
|
||||||
// In production, this should be configurable
|
|
||||||
args = append(args, "-o", "StrictHostKeyChecking=accept-new")
|
args = append(args, "-o", "StrictHostKeyChecking=accept-new")
|
||||||
|
|
||||||
target := fmt.Sprintf("%s@%s", server.User, server.Host)
|
target := fmt.Sprintf("%s@%s", server.User, server.Host)
|
||||||
|
|
|
||||||
|
|
@ -7,19 +7,19 @@ import (
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/creack/pty"
|
"github.com/creack/pty"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
"golang.org/x/term"
|
"golang.org/x/term"
|
||||||
)
|
)
|
||||||
|
|
||||||
var passwordPromptRe = regexp.MustCompile(`(?i)(password|passphrase).*:\s*$`)
|
var passwordPromptRe = regexp.MustCompile(`(?i)(password|passphrase).*:\s*$`)
|
||||||
|
|
||||||
// ConnectWithPassword runs SSH through a PTY, detects the password prompt,
|
|
||||||
// sends the password, and then bridges the user terminal to the SSH session.
|
|
||||||
func ConnectWithPassword(sshBinary string, args []string, password string) error {
|
func ConnectWithPassword(sshBinary string, args []string, password string) error {
|
||||||
// Start SSH with PTY
|
|
||||||
cmd := exec.Command(sshBinary, args...)
|
cmd := exec.Command(sshBinary, args...)
|
||||||
cmd.Env = os.Environ()
|
cmd.Env = os.Environ()
|
||||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||||
|
|
@ -33,18 +33,14 @@ func ConnectWithPassword(sshBinary string, args []string, password string) error
|
||||||
}
|
}
|
||||||
defer ptmx.Close()
|
defer ptmx.Close()
|
||||||
|
|
||||||
// Save terminal state and set to raw
|
|
||||||
oldState, err := term.MakeRaw(int(os.Stdin.Fd()))
|
oldState, err := term.MakeRaw(int(os.Stdin.Fd()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("set raw terminal: %w", err)
|
return fmt.Errorf("set raw terminal: %w", err)
|
||||||
}
|
}
|
||||||
defer term.Restore(int(os.Stdin.Fd()), oldState)
|
|
||||||
|
|
||||||
// Channel to signal when password has been sent
|
passwordSent := new(atomic.Bool)
|
||||||
passwordSent := make(chan bool, 1)
|
|
||||||
done := make(chan error, 1)
|
done := make(chan error, 1)
|
||||||
|
|
||||||
// Read from PTY, detect password prompt
|
|
||||||
go func() {
|
go func() {
|
||||||
buf := make([]byte, 4096)
|
buf := make([]byte, 4096)
|
||||||
var accumulated strings.Builder
|
var accumulated strings.Builder
|
||||||
|
|
@ -54,22 +50,17 @@ func ConnectWithPassword(sshBinary string, args []string, password string) error
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
data := buf[:n]
|
data := buf[:n]
|
||||||
accumulated.Write(data)
|
accumulated.Write(data)
|
||||||
|
|
||||||
// Write to stdout
|
|
||||||
os.Stdout.Write(data)
|
os.Stdout.Write(data)
|
||||||
|
|
||||||
// Check for password prompt
|
if !passwordSent.Load() {
|
||||||
if !<-passwordSent {
|
|
||||||
text := accumulated.String()
|
text := accumulated.String()
|
||||||
if passwordPromptRe.MatchString(text) {
|
if passwordPromptRe.MatchString(text) {
|
||||||
passwordSent <- true
|
passwordSent.Store(true)
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
ptmx.Write([]byte(password + "\r"))
|
ptmx.Write([]byte(password + "\r"))
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset accumulated buffer periodically to avoid unbounded growth
|
|
||||||
if accumulated.Len() > 8192 {
|
if accumulated.Len() > 8192 {
|
||||||
s := accumulated.String()
|
s := accumulated.String()
|
||||||
accumulated.Reset()
|
accumulated.Reset()
|
||||||
|
|
@ -87,14 +78,142 @@ func ConnectWithPassword(sshBinary string, args []string, password string) error
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Copy stdin to PTY
|
stopInput, waitInput, restoreInput, err := forwardInputToPTY(ptmx)
|
||||||
go func() {
|
if err != nil {
|
||||||
io.Copy(ptmx, os.Stdin)
|
term.Restore(int(os.Stdin.Fd()), oldState)
|
||||||
}()
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Wait for command completion
|
|
||||||
err = cmd.Wait()
|
err = cmd.Wait()
|
||||||
passwordSent <- false // signal to stop
|
close(stopInput)
|
||||||
|
waitInput()
|
||||||
|
|
||||||
|
if restoreErr := restoreInput(); err == nil && restoreErr != nil {
|
||||||
|
err = restoreErr
|
||||||
|
}
|
||||||
|
if restoreErr := term.Restore(int(os.Stdin.Fd()), oldState); err == nil && restoreErr != nil {
|
||||||
|
err = restoreErr
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func forwardInputToPTY(ptmx *os.File) (chan struct{}, func(), func() error, error) {
|
||||||
|
stdinFD := int(os.Stdin.Fd())
|
||||||
|
flags, err := unix.FcntlInt(uintptr(stdinFD), unix.F_GETFL, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, fmt.Errorf("get stdin flags: %w", err)
|
||||||
|
}
|
||||||
|
if err := unix.SetNonblock(stdinFD, true); err != nil {
|
||||||
|
return nil, nil, nil, fmt.Errorf("set stdin nonblocking: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stop := make(chan struct{})
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
buf := make([]byte, 4096)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-stop:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
n, readErr := unix.Read(stdinFD, buf)
|
||||||
|
if n > 0 {
|
||||||
|
if _, writeErr := ptmx.Write(buf[:n]); writeErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if readErr == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if readErr == unix.EAGAIN || readErr == unix.EWOULDBLOCK {
|
||||||
|
select {
|
||||||
|
case <-stop:
|
||||||
|
return
|
||||||
|
case <-time.After(10 * time.Millisecond):
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
restore := func() error {
|
||||||
|
_, err := unix.FcntlInt(uintptr(stdinFD), unix.F_SETFL, flags)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("restore stdin flags: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return stop, wg.Wait, restore, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// connectWithPasswordAndRead runs SSH through a PTY, sends the password,
|
||||||
|
// collects all output, and returns it. Used for non-interactive testing.
|
||||||
|
// Returns (true, output) on success, (false, error) on failure.
|
||||||
|
func connectWithPasswordAndRead(sshBinary string, args []string, password string, timeoutSec int) (bool, string) {
|
||||||
|
cmd := exec.Command(sshBinary, args...)
|
||||||
|
cmd.Env = os.Environ()
|
||||||
|
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||||
|
Setsid: true,
|
||||||
|
Setctty: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
ptmx, err := pty.Start(cmd)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Sprintf("start ssh with pty: %v", err)
|
||||||
|
}
|
||||||
|
defer ptmx.Close()
|
||||||
|
|
||||||
|
passwordSent := false
|
||||||
|
done := make(chan string, 1)
|
||||||
|
|
||||||
|
// Read from PTY, detect password prompt, collect output
|
||||||
|
go func() {
|
||||||
|
buf := make([]byte, 4096)
|
||||||
|
var accumulated strings.Builder
|
||||||
|
|
||||||
|
for {
|
||||||
|
n, err := ptmx.Read(buf)
|
||||||
|
if n > 0 {
|
||||||
|
data := buf[:n]
|
||||||
|
accumulated.Write(data)
|
||||||
|
|
||||||
|
// Check for password prompt
|
||||||
|
if !passwordSent {
|
||||||
|
text := accumulated.String()
|
||||||
|
if passwordPromptRe.MatchString(text) {
|
||||||
|
passwordSent = true
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
ptmx.Write([]byte(password + "\r"))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset accumulated buffer periodically
|
||||||
|
if accumulated.Len() > 16384 {
|
||||||
|
s := accumulated.String()
|
||||||
|
accumulated.Reset()
|
||||||
|
accumulated.WriteString(s[len(s)-4096:])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
done <- accumulated.String()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for command completion or timeout
|
||||||
|
select {
|
||||||
|
case output := <-done:
|
||||||
|
return true, output
|
||||||
|
case <-time.After(time.Duration(timeoutSec) * time.Second):
|
||||||
|
cmd.Process.Kill()
|
||||||
|
return false, "connection timeout"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,63 @@
|
||||||
|
package ssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/creack/pty"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConnectWithPasswordDoesNotConsumeInputAfterReturn(t *testing.T) {
|
||||||
|
script := filepath.Join(t.TempDir(), "fake-ssh")
|
||||||
|
if err := os.WriteFile(script, []byte("#!/bin/sh\nprintf 'password: '\nsleep 0.1\nexit 0\n"), 0o700); err != nil {
|
||||||
|
t.Fatalf("write fake ssh: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
master, slave, err := pty.Open()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open stdin pty: %v", err)
|
||||||
|
}
|
||||||
|
defer master.Close()
|
||||||
|
defer slave.Close()
|
||||||
|
|
||||||
|
oldStdin := os.Stdin
|
||||||
|
oldStdout := os.Stdout
|
||||||
|
stdoutSink, err := os.CreateTemp(t.TempDir(), "stdout")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create stdout sink: %v", err)
|
||||||
|
}
|
||||||
|
defer stdoutSink.Close()
|
||||||
|
os.Stdin = slave
|
||||||
|
os.Stdout = stdoutSink
|
||||||
|
defer func() {
|
||||||
|
os.Stdin = oldStdin
|
||||||
|
os.Stdout = oldStdout
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := ConnectWithPassword(script, nil, "secret"); err != nil {
|
||||||
|
t.Fatalf("connect with password: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := master.Write([]byte("\n")); err != nil {
|
||||||
|
t.Fatalf("write next enter: %v", err)
|
||||||
|
}
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
readDone := make(chan []byte, 1)
|
||||||
|
go func() {
|
||||||
|
buf := make([]byte, 1)
|
||||||
|
n, _ := slave.Read(buf)
|
||||||
|
readDone <- buf[:n]
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case got := <-readDone:
|
||||||
|
if len(got) != 1 || got[0] != '\n' {
|
||||||
|
t.Fatalf("expected to read newline, got %q", got)
|
||||||
|
}
|
||||||
|
case <-time.After(300 * time.Millisecond):
|
||||||
|
t.Fatal("expected next Enter to remain readable after ConnectWithPassword returned")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -26,7 +26,10 @@ var (
|
||||||
Background(lipgloss.Color("4")).
|
Background(lipgloss.Color("4")).
|
||||||
Bold(true)
|
Bold(true)
|
||||||
|
|
||||||
normalStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("15"))
|
normalStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("15"))
|
||||||
|
selectedRowStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("15")).Background(lipgloss.Color("4"))
|
||||||
|
listHeaderStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("14")).Bold(true)
|
||||||
|
sectionStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("12")).Bold(true).MarginTop(1)
|
||||||
|
|
||||||
testOKStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("10")).Bold(true)
|
testOKStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("10")).Bold(true)
|
||||||
testFailStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("9")).Bold(true)
|
testFailStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("9")).Bold(true)
|
||||||
|
|
@ -67,9 +70,13 @@ type serverItem struct {
|
||||||
server *model.Server
|
server *model.Server
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i serverItem) Title() string { return i.server.Alias }
|
func (i serverItem) Title() string { return i.server.Alias }
|
||||||
func (i serverItem) Description() string { return fmt.Sprintf("%s@%s:%d %s", i.server.User, i.server.Host, i.server.Port, i.server.AuthMethod) }
|
func (i serverItem) Description() string {
|
||||||
func (i serverItem) FilterValue() string { return i.server.Alias + " " + i.server.DisplayName + " " + i.server.Host + " " + i.server.User }
|
return fmt.Sprintf("%s@%s:%d %s", i.server.User, i.server.Host, i.server.Port, i.server.AuthMethod)
|
||||||
|
}
|
||||||
|
func (i serverItem) FilterValue() string {
|
||||||
|
return i.server.Alias + " " + i.server.DisplayName + " " + i.server.Host + " " + i.server.User
|
||||||
|
}
|
||||||
|
|
||||||
// --- External callbacks ---
|
// --- External callbacks ---
|
||||||
|
|
||||||
|
|
@ -80,10 +87,12 @@ var (
|
||||||
TestConnection func(server *model.Server) (bool, string)
|
TestConnection func(server *model.Server) (bool, string)
|
||||||
// TestConnectionWithPassword tests with explicit password (for form test before save)
|
// TestConnectionWithPassword tests with explicit password (for form test before save)
|
||||||
TestConnectionWithPassword func(server *model.Server, password string) (bool, string)
|
TestConnectionWithPassword func(server *model.Server, password string) (bool, string)
|
||||||
SaveServer func(server *model.Server, password string) error
|
SaveServer func(server *model.Server, password string, oldAlias string) error
|
||||||
GetGroups func() ([]string, error) // Returns existing group names
|
UpdateTestResult func(alias string, status model.TestStatus, testErr string) error
|
||||||
RenameGroup func(oldName, newName string) error // Rename group for all servers
|
HasSecret func(alias string, secretType string) bool
|
||||||
DeleteGroup func(name string) error // Remove group from all servers
|
GetGroups func() ([]string, error)
|
||||||
|
RenameGroup func(oldName, newName string) error
|
||||||
|
DeleteGroup func(name string) error
|
||||||
)
|
)
|
||||||
|
|
||||||
// --- Screen type ---
|
// --- Screen type ---
|
||||||
|
|
@ -99,8 +108,8 @@ const (
|
||||||
// --- Result type — returned from TUI to caller ---
|
// --- Result type — returned from TUI to caller ---
|
||||||
|
|
||||||
type TUIResult struct {
|
type TUIResult struct {
|
||||||
Server *model.Server
|
Server *model.Server
|
||||||
Action string // "connect"
|
Action string // "connect"
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Main TUI model ---
|
// --- Main TUI model ---
|
||||||
|
|
@ -196,8 +205,22 @@ func (m *tuiModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||||
}
|
}
|
||||||
m.form.testResultTime = time.Now()
|
m.form.testResultTime = time.Now()
|
||||||
m.form.err = nil
|
m.form.err = nil
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
// Update test status in DB and reload list
|
||||||
|
if item, ok := m.list.SelectedItem().(serverItem); ok && UpdateTestResult != nil {
|
||||||
|
status := model.TestUnknown
|
||||||
|
if msg.ok {
|
||||||
|
status = model.TestOK
|
||||||
|
} else if msg.err != "" {
|
||||||
|
status = model.TestFailed
|
||||||
|
}
|
||||||
|
UpdateTestResult(item.server.Alias, status, msg.err)
|
||||||
|
}
|
||||||
|
return m, func() tea.Msg {
|
||||||
|
servers, err := ListServers()
|
||||||
|
return serversLoadedMsg{servers: servers, err: err}
|
||||||
}
|
}
|
||||||
return m, nil
|
|
||||||
|
|
||||||
case saveDoneMsg:
|
case saveDoneMsg:
|
||||||
if m.form != nil {
|
if m.form != nil {
|
||||||
|
|
@ -323,11 +346,23 @@ func (m *tuiModel) updateSearch(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||||
|
|
||||||
func (m *tuiModel) updateForm(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
func (m *tuiModel) updateForm(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||||
if msg.Type == tea.KeyEsc {
|
if msg.Type == tea.KeyEsc {
|
||||||
|
if m.form != nil && (m.form.showGroupList || m.form.showAuthList) {
|
||||||
|
updated, cmd := m.form.Update(msg)
|
||||||
|
if fm, ok := updated.(*formModel); ok {
|
||||||
|
m.form = fm
|
||||||
|
}
|
||||||
|
return m, cmd
|
||||||
|
}
|
||||||
|
|
||||||
m.screen = screenList
|
m.screen = screenList
|
||||||
m.form = nil
|
m.form = nil
|
||||||
m.err = nil
|
m.err = nil
|
||||||
m.success = ""
|
m.success = ""
|
||||||
return m, nil
|
// Reload server list after form close
|
||||||
|
return m, func() tea.Msg {
|
||||||
|
servers, err := ListServers()
|
||||||
|
return serversLoadedMsg{servers: servers, err: err}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
updated, cmd := m.form.Update(msg)
|
updated, cmd := m.form.Update(msg)
|
||||||
|
|
@ -342,9 +377,7 @@ func (m *tuiModel) View() string {
|
||||||
|
|
||||||
switch m.screen {
|
switch m.screen {
|
||||||
case screenList:
|
case screenList:
|
||||||
b.WriteString(m.list.View())
|
b.WriteString(m.viewServerList())
|
||||||
b.WriteString("\n")
|
|
||||||
b.WriteString(helpStyle.Render("Enter connect | Ctrl+A add | Ctrl+E edit | Ctrl+D del | Ctrl+T test | Ctrl+F search | Ctrl+Q quit"))
|
|
||||||
|
|
||||||
case screenSearch:
|
case screenSearch:
|
||||||
b.WriteString("Search: " + m.searchInput.View() + "\n")
|
b.WriteString("Search: " + m.searchInput.View() + "\n")
|
||||||
|
|
@ -366,13 +399,148 @@ func (m *tuiModel) View() string {
|
||||||
return b.String()
|
return b.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *tuiModel) viewServerList() string {
|
||||||
|
var b strings.Builder
|
||||||
|
selectedAlias := ""
|
||||||
|
if item, ok := m.list.SelectedItem().(serverItem); ok && item.server != nil {
|
||||||
|
selectedAlias = item.server.Alias
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString(titleStyle.Render(fmt.Sprintf("sshkeeper %d servers", len(m.servers))))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(helpStyle.Render(fmt.Sprintf("Vault unlocked | %s", testSummary(m.servers))))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(listHeaderStyle.Render(fmt.Sprintf(" %-20s %-20s %-34s %-12s %-10s %s", "NAME", "ALIAS", "TARGET", "AUTH", "GROUP", "STATUS")))
|
||||||
|
b.WriteString("\n")
|
||||||
|
|
||||||
|
if len(m.servers) == 0 {
|
||||||
|
b.WriteString(helpStyle.Render(" No servers yet. Press Ctrl+A to add one."))
|
||||||
|
b.WriteString("\n")
|
||||||
|
} else {
|
||||||
|
for _, server := range m.servers {
|
||||||
|
marker := " "
|
||||||
|
rowStyle := normalStyle
|
||||||
|
if server.Alias == selectedAlias {
|
||||||
|
marker = ">"
|
||||||
|
rowStyle = selectedRowStyle
|
||||||
|
}
|
||||||
|
name := server.DisplayName
|
||||||
|
if name == "" {
|
||||||
|
name = server.Alias
|
||||||
|
}
|
||||||
|
target := fmt.Sprintf("%s@%s:%d", server.User, server.Host, server.Port)
|
||||||
|
group := server.GroupName
|
||||||
|
if group == "" {
|
||||||
|
group = "-"
|
||||||
|
}
|
||||||
|
row := fmt.Sprintf("%s %-20s %-20s %-34s %-12s %-10s %s",
|
||||||
|
marker,
|
||||||
|
truncate(name, 20),
|
||||||
|
truncate(server.Alias, 20),
|
||||||
|
truncate(target, 34),
|
||||||
|
authLabel(server.AuthMethod),
|
||||||
|
truncate(group, 10),
|
||||||
|
testStatusLabel(server),
|
||||||
|
)
|
||||||
|
b.WriteString(rowStyle.Render(row))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString("\n")
|
||||||
|
if selectedAlias != "" {
|
||||||
|
if selected := m.selectedServer(); selected != nil {
|
||||||
|
b.WriteString(m.viewSelectedServer(selected))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b.WriteString(helpStyle.Render("Enter connect | Ctrl+A add | Ctrl+E edit | Ctrl+D del | Ctrl+T test | Ctrl+F search | Ctrl+Q quit"))
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *tuiModel) selectedServer() *model.Server {
|
||||||
|
if item, ok := m.list.SelectedItem().(serverItem); ok && item.server != nil {
|
||||||
|
return item.server
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *tuiModel) viewSelectedServer(server *model.Server) string {
|
||||||
|
displayName := server.DisplayName
|
||||||
|
if displayName == "" {
|
||||||
|
displayName = "-"
|
||||||
|
}
|
||||||
|
group := server.GroupName
|
||||||
|
if group == "" {
|
||||||
|
group = "-"
|
||||||
|
}
|
||||||
|
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString(sectionStyle.Render("Selected"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(fmt.Sprintf(" Alias: %s\n", server.Alias))
|
||||||
|
b.WriteString(fmt.Sprintf(" Display Name: %s\n", displayName))
|
||||||
|
b.WriteString(fmt.Sprintf(" Host: %s\n", server.Host))
|
||||||
|
b.WriteString(fmt.Sprintf(" Port: %d\n", server.Port))
|
||||||
|
b.WriteString(fmt.Sprintf(" User: %s\n", server.User))
|
||||||
|
b.WriteString(fmt.Sprintf(" Auth: %s\n", authLabel(server.AuthMethod)))
|
||||||
|
b.WriteString(fmt.Sprintf(" Group: %s\n", group))
|
||||||
|
b.WriteString(fmt.Sprintf(" Status: %s\n", testStatusLabel(server)))
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func testSummary(servers []*model.Server) string {
|
||||||
|
okCount := 0
|
||||||
|
failedCount := 0
|
||||||
|
for _, server := range servers {
|
||||||
|
switch server.LastTestStatus {
|
||||||
|
case model.TestOK:
|
||||||
|
okCount++
|
||||||
|
case model.TestFailed:
|
||||||
|
failedCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%d OK | %d FAIL", okCount, failedCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func authLabel(auth model.AuthMethod) string {
|
||||||
|
switch auth {
|
||||||
|
case model.AuthPassword:
|
||||||
|
return "password"
|
||||||
|
case model.AuthKey:
|
||||||
|
return "key"
|
||||||
|
case model.AuthKeyPassphrase:
|
||||||
|
return "key+phrase"
|
||||||
|
case model.AuthAgent:
|
||||||
|
return "agent"
|
||||||
|
default:
|
||||||
|
return string(auth)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStatusLabel(server *model.Server) string {
|
||||||
|
switch server.LastTestStatus {
|
||||||
|
case model.TestOK:
|
||||||
|
return "OK"
|
||||||
|
case model.TestFailed:
|
||||||
|
if server.LastTestError != "" {
|
||||||
|
return "FAIL"
|
||||||
|
}
|
||||||
|
return "FAIL"
|
||||||
|
default:
|
||||||
|
return "?"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// --- Form model ---
|
// --- Form model ---
|
||||||
|
|
||||||
type formModel struct {
|
type formModel struct {
|
||||||
edit bool
|
edit bool
|
||||||
server *model.Server
|
server *model.Server
|
||||||
inputs []textinput.Model
|
inputs []textinput.Model
|
||||||
|
labels []string
|
||||||
password textinput.Model
|
password textinput.Model
|
||||||
|
passwordLabel string
|
||||||
focusIdx int
|
focusIdx int
|
||||||
testResult string
|
testResult string
|
||||||
testOK bool
|
testOK bool
|
||||||
|
|
@ -385,10 +553,22 @@ type formModel struct {
|
||||||
spinner spinner.Model
|
spinner spinner.Model
|
||||||
width int
|
width int
|
||||||
height int
|
height int
|
||||||
groups string // cached groups list (comma-separated for display)
|
groups []string // existing group names
|
||||||
showGroups bool // show group dropdown
|
groupList list.Model // dropdown list for groups
|
||||||
|
showGroupList bool // whether group dropdown is visible
|
||||||
|
authList list.Model
|
||||||
|
showAuthList bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// groupItem implements list.Item for the group dropdown
|
||||||
|
type groupItem struct {
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i groupItem) Title() string { return i.name }
|
||||||
|
func (i groupItem) Description() string { return "" }
|
||||||
|
func (i groupItem) FilterValue() string { return i.name }
|
||||||
|
|
||||||
func newFormModel(w, h int) *formModel {
|
func newFormModel(w, h int) *formModel {
|
||||||
inputs := make([]textinput.Model, 10)
|
inputs := make([]textinput.Model, 10)
|
||||||
labels := []string{
|
labels := []string{
|
||||||
|
|
@ -405,12 +585,12 @@ func newFormModel(w, h int) *formModel {
|
||||||
}
|
}
|
||||||
for i, label := range labels {
|
for i, label := range labels {
|
||||||
inputs[i] = textinput.New()
|
inputs[i] = textinput.New()
|
||||||
inputs[i].Placeholder = label
|
inputs[i].Placeholder = placeholderForLabel(label)
|
||||||
inputs[i].CharLimit = 128
|
inputs[i].CharLimit = 128
|
||||||
}
|
}
|
||||||
|
|
||||||
pw := textinput.New()
|
pw := textinput.New()
|
||||||
pw.Placeholder = "Password / Passphrase (stored in vault)"
|
pw.Placeholder = "optional"
|
||||||
pw.CharLimit = 256
|
pw.CharLimit = 256
|
||||||
pw.EchoMode = textinput.EchoPassword
|
pw.EchoMode = textinput.EchoPassword
|
||||||
|
|
||||||
|
|
@ -421,24 +601,75 @@ func newFormModel(w, h int) *formModel {
|
||||||
inputs[0].Focus()
|
inputs[0].Focus()
|
||||||
|
|
||||||
fm := &formModel{
|
fm := &formModel{
|
||||||
inputs: inputs,
|
inputs: inputs,
|
||||||
password: pw,
|
labels: labels,
|
||||||
focusIdx: 0,
|
password: pw,
|
||||||
spinner: s,
|
passwordLabel: "Password / Passphrase",
|
||||||
width: w,
|
focusIdx: 0,
|
||||||
height: h,
|
spinner: s,
|
||||||
|
width: w,
|
||||||
|
height: h,
|
||||||
}
|
}
|
||||||
|
fm.authList = newStringList([]string{
|
||||||
|
string(model.AuthPassword),
|
||||||
|
string(model.AuthKey),
|
||||||
|
string(model.AuthKeyPassphrase),
|
||||||
|
string(model.AuthAgent),
|
||||||
|
}, "Select auth method", 34, 16)
|
||||||
|
|
||||||
// Load existing groups
|
// Load existing groups
|
||||||
if GetGroups != nil {
|
if GetGroups != nil {
|
||||||
if groups, err := GetGroups(); err == nil && len(groups) > 0 {
|
if groups, err := GetGroups(); err == nil && len(groups) > 0 {
|
||||||
fm.groups = strings.Join(groups, ", ")
|
fm.groups = groups
|
||||||
|
fm.groupList = newStringList(groups, "Select group", 30, 8)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fm.updateFocus()
|
||||||
return fm
|
return fm
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func placeholderForLabel(label string) string {
|
||||||
|
switch label {
|
||||||
|
case "Alias":
|
||||||
|
return "mail.kp"
|
||||||
|
case "Display Name":
|
||||||
|
return "Production mail"
|
||||||
|
case "Host":
|
||||||
|
return "mail.example.org"
|
||||||
|
case "Port":
|
||||||
|
return "22"
|
||||||
|
case "User":
|
||||||
|
return "root"
|
||||||
|
case "Auth Method (password/key/key_passphrase/agent)":
|
||||||
|
return "key"
|
||||||
|
case "Identity File":
|
||||||
|
return "~/.ssh/id_ed25519"
|
||||||
|
case "ProxyJump":
|
||||||
|
return "optional"
|
||||||
|
case "Group (type new or pick from list)":
|
||||||
|
return "KP"
|
||||||
|
case "Notes":
|
||||||
|
return "optional"
|
||||||
|
default:
|
||||||
|
return label
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStringList(values []string, title string, width, height int) list.Model {
|
||||||
|
items := make([]list.Item, len(values))
|
||||||
|
for i, value := range values {
|
||||||
|
items[i] = groupItem{name: value}
|
||||||
|
}
|
||||||
|
l := list.New(items, list.NewDefaultDelegate(), width, height)
|
||||||
|
l.SetShowStatusBar(false)
|
||||||
|
l.SetShowHelp(false)
|
||||||
|
l.SetShowPagination(false)
|
||||||
|
l.Title = title
|
||||||
|
l.Styles.Title = titleStyle
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
func newEditFormModel(s *model.Server, w, h int) *formModel {
|
func newEditFormModel(s *model.Server, w, h int) *formModel {
|
||||||
fm := newFormModel(w, h)
|
fm := newFormModel(w, h)
|
||||||
fm.edit = true
|
fm.edit = true
|
||||||
|
|
@ -453,6 +684,21 @@ func newEditFormModel(s *model.Server, w, h int) *formModel {
|
||||||
fm.inputs[7].SetValue(s.ProxyJump)
|
fm.inputs[7].SetValue(s.ProxyJump)
|
||||||
fm.inputs[8].SetValue(s.GroupName)
|
fm.inputs[8].SetValue(s.GroupName)
|
||||||
fm.inputs[9].SetValue(s.Notes)
|
fm.inputs[9].SetValue(s.Notes)
|
||||||
|
if HasSecret != nil {
|
||||||
|
switch s.AuthMethod {
|
||||||
|
case model.AuthPassword:
|
||||||
|
if HasSecret(s.Alias, "ssh_password") {
|
||||||
|
fm.passwordLabel = "Password (secret saved; leave blank to keep)"
|
||||||
|
fm.password.Placeholder = ""
|
||||||
|
}
|
||||||
|
case model.AuthKeyPassphrase:
|
||||||
|
if HasSecret(s.Alias, "key_passphrase") {
|
||||||
|
fm.passwordLabel = "Key passphrase (secret saved; leave blank to keep)"
|
||||||
|
fm.password.Placeholder = ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fm.updateFocus()
|
||||||
return fm
|
return fm
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -498,6 +744,48 @@ func (fm *formModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||||
return fm, cmd
|
return fm, cmd
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle group dropdown
|
||||||
|
if fm.showGroupList {
|
||||||
|
switch msg := msg.(type) {
|
||||||
|
case tea.KeyMsg:
|
||||||
|
switch msg.Type {
|
||||||
|
case tea.KeyEsc:
|
||||||
|
fm.showGroupList = false
|
||||||
|
return fm, nil
|
||||||
|
case tea.KeyEnter:
|
||||||
|
if item, ok := fm.groupList.SelectedItem().(groupItem); ok {
|
||||||
|
fm.inputs[8].SetValue(item.name)
|
||||||
|
}
|
||||||
|
fm.showGroupList = false
|
||||||
|
return fm, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Pass other keys to the list
|
||||||
|
var cmd tea.Cmd
|
||||||
|
fm.groupList, cmd = fm.groupList.Update(msg)
|
||||||
|
return fm, cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
if fm.showAuthList {
|
||||||
|
switch msg := msg.(type) {
|
||||||
|
case tea.KeyMsg:
|
||||||
|
switch msg.Type {
|
||||||
|
case tea.KeyEsc:
|
||||||
|
fm.showAuthList = false
|
||||||
|
return fm, nil
|
||||||
|
case tea.KeyEnter:
|
||||||
|
if item, ok := fm.authList.SelectedItem().(groupItem); ok {
|
||||||
|
fm.inputs[5].SetValue(item.name)
|
||||||
|
}
|
||||||
|
fm.showAuthList = false
|
||||||
|
return fm, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var cmd tea.Cmd
|
||||||
|
fm.authList, cmd = fm.authList.Update(msg)
|
||||||
|
return fm, cmd
|
||||||
|
}
|
||||||
|
|
||||||
switch msg := msg.(type) {
|
switch msg := msg.(type) {
|
||||||
case tea.KeyMsg:
|
case tea.KeyMsg:
|
||||||
switch msg.Type {
|
switch msg.Type {
|
||||||
|
|
@ -519,6 +807,17 @@ func (fm *formModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||||
fm.updateFocus()
|
fm.updateFocus()
|
||||||
return fm, nil
|
return fm, nil
|
||||||
|
|
||||||
|
case tea.KeyRunes:
|
||||||
|
if len(msg.Runes) == 1 && msg.Runes[0] == '/' && !msg.Alt && fm.focusIdx == 5 {
|
||||||
|
fm.showAuthList = true
|
||||||
|
return fm, nil
|
||||||
|
}
|
||||||
|
// '/' on Group field opens group dropdown
|
||||||
|
if len(msg.Runes) == 1 && msg.Runes[0] == '/' && !msg.Alt && fm.focusIdx == 8 && len(fm.groups) > 0 {
|
||||||
|
fm.showGroupList = true
|
||||||
|
return fm, nil
|
||||||
|
}
|
||||||
|
|
||||||
case tea.KeyEnter:
|
case tea.KeyEnter:
|
||||||
switch {
|
switch {
|
||||||
case fm.focusIdx == len(fm.inputs)+1:
|
case fm.focusIdx == len(fm.inputs)+1:
|
||||||
|
|
@ -576,20 +875,36 @@ func (fm *formModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||||
func (fm *formModel) updateFocus() {
|
func (fm *formModel) updateFocus() {
|
||||||
for i := range fm.inputs {
|
for i := range fm.inputs {
|
||||||
fm.inputs[i].Blur()
|
fm.inputs[i].Blur()
|
||||||
fm.inputs[i].Prompt = blurredStyle.Render(fm.inputs[i].Placeholder + ": ")
|
fm.inputs[i].Prompt = blurredStyle.Render(fm.labelAt(i) + ": ")
|
||||||
}
|
}
|
||||||
fm.password.Blur()
|
fm.password.Blur()
|
||||||
fm.password.Prompt = blurredStyle.Render(fm.password.Placeholder + ": ")
|
fm.password.Prompt = blurredStyle.Render(fm.passwordLabel + ": ")
|
||||||
|
|
||||||
if fm.focusIdx < len(fm.inputs) {
|
if fm.focusIdx < len(fm.inputs) {
|
||||||
fm.inputs[fm.focusIdx].Focus()
|
fm.inputs[fm.focusIdx].Focus()
|
||||||
fm.inputs[fm.focusIdx].Prompt = focusedStyle.Render(fm.inputs[fm.focusIdx].Placeholder + "> ")
|
fm.inputs[fm.focusIdx].Prompt = focusedStyle.Render(fm.labelAt(fm.focusIdx) + "> ")
|
||||||
} else if fm.focusIdx == len(fm.inputs) {
|
} else if fm.focusIdx == len(fm.inputs) {
|
||||||
fm.password.Focus()
|
fm.password.Focus()
|
||||||
fm.password.Prompt = focusedStyle.Render(fm.password.Placeholder + "> ")
|
fm.password.Prompt = focusedStyle.Render(fm.passwordLabel + "> ")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (fm *formModel) labelAt(index int) string {
|
||||||
|
if index >= 0 && index < len(fm.labels) {
|
||||||
|
if index == 5 {
|
||||||
|
return "Auth Method (/ pick)"
|
||||||
|
}
|
||||||
|
if index == 8 {
|
||||||
|
if len(fm.groups) > 0 {
|
||||||
|
return "Group (/ pick)"
|
||||||
|
}
|
||||||
|
return "Group"
|
||||||
|
}
|
||||||
|
return fm.labels[index]
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func (fm *formModel) runTest() tea.Cmd {
|
func (fm *formModel) runTest() tea.Cmd {
|
||||||
fm.testing = true
|
fm.testing = true
|
||||||
fm.testResult = ""
|
fm.testResult = ""
|
||||||
|
|
@ -635,7 +950,11 @@ func (fm *formModel) runSave() tea.Cmd {
|
||||||
if s.Host == "" {
|
if s.Host == "" {
|
||||||
return saveDoneMsg{err: fmt.Errorf("host is required")}
|
return saveDoneMsg{err: fmt.Errorf("host is required")}
|
||||||
}
|
}
|
||||||
err := SaveServer(s, pw)
|
oldAlias := ""
|
||||||
|
if fm.edit && fm.server != nil {
|
||||||
|
oldAlias = fm.server.Alias
|
||||||
|
}
|
||||||
|
err := SaveServer(s, pw, oldAlias)
|
||||||
return saveDoneMsg{err: err}
|
return saveDoneMsg{err: err}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
@ -672,13 +991,72 @@ func (fm *formModel) View() string {
|
||||||
b.WriteString(titleStyle.Render(title))
|
b.WriteString(titleStyle.Render(title))
|
||||||
b.WriteString("\n\n")
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
for i := range fm.inputs {
|
// Calculate visible range based on terminal height
|
||||||
|
// Reserve lines for: title (2) + password (1) + buttons (3) + help (1) + padding (2) = ~9
|
||||||
|
reserved := 9
|
||||||
|
available := fm.height - reserved
|
||||||
|
if available < 4 {
|
||||||
|
available = 4
|
||||||
|
}
|
||||||
|
|
||||||
|
numInputs := len(fm.inputs)
|
||||||
|
startIdx := 0
|
||||||
|
endIdx := numInputs
|
||||||
|
|
||||||
|
// Scroll: keep focused field visible
|
||||||
|
if numInputs > available {
|
||||||
|
focusInput := fm.focusIdx
|
||||||
|
if focusInput >= numInputs {
|
||||||
|
focusInput = numInputs - 1
|
||||||
|
}
|
||||||
|
// Try to show `available` fields centered on focus
|
||||||
|
startIdx = focusInput - available/2
|
||||||
|
if startIdx < 0 {
|
||||||
|
startIdx = 0
|
||||||
|
}
|
||||||
|
endIdx = startIdx + available
|
||||||
|
if endIdx > numInputs {
|
||||||
|
endIdx = numInputs
|
||||||
|
startIdx = endIdx - available
|
||||||
|
if startIdx < 0 {
|
||||||
|
startIdx = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Show scroll indicator if needed
|
||||||
|
if startIdx > 0 {
|
||||||
|
b.WriteString(helpStyle.Render(" ↑ more fields above\n"))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := startIdx; i < endIdx; i++ {
|
||||||
|
if section := formSectionTitle(i); section != "" {
|
||||||
|
b.WriteString(sectionStyle.Render(section))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
if i == 5 {
|
||||||
|
fm.inputs[i].Placeholder = "password/key/key_passphrase/agent"
|
||||||
|
}
|
||||||
|
// Show group hint inline in placeholder for Group field
|
||||||
|
if i == 8 && len(fm.groups) > 0 && !fm.showGroupList {
|
||||||
|
fm.inputs[i].Placeholder = truncate(strings.Join(fm.groups, ", "), 25)
|
||||||
|
}
|
||||||
b.WriteString(fm.inputs[i].View())
|
b.WriteString(fm.inputs[i].View())
|
||||||
b.WriteString("\n")
|
b.WriteString("\n")
|
||||||
// Show existing groups hint under Group field
|
if i == 5 && fm.showAuthList {
|
||||||
if i == 8 && fm.groups != "" {
|
b.WriteString("\n" + renderDropdown(fm.authList) + "\n")
|
||||||
b.WriteString(helpStyle.Render(" Groups: " + fm.groups + "\n"))
|
b.WriteString(helpStyle.Render("Enter select | Esc cancel"))
|
||||||
|
return b.String()
|
||||||
}
|
}
|
||||||
|
if i == 8 && fm.showGroupList {
|
||||||
|
b.WriteString("\n" + renderDropdown(fm.groupList) + "\n")
|
||||||
|
b.WriteString(helpStyle.Render("Enter select | Esc cancel"))
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if endIdx < numInputs {
|
||||||
|
b.WriteString(helpStyle.Render(fmt.Sprintf(" ↓ more fields below (%d-%d of %d)\n", startIdx+1, endIdx, numInputs)))
|
||||||
}
|
}
|
||||||
|
|
||||||
b.WriteString(fm.password.View())
|
b.WriteString(fm.password.View())
|
||||||
|
|
@ -723,8 +1101,53 @@ func (fm *formModel) View() string {
|
||||||
saveBtn = normalStyle.Render(saveBtn)
|
saveBtn = normalStyle.Render(saveBtn)
|
||||||
}
|
}
|
||||||
|
|
||||||
b.WriteString("\n" + testBtn + " " + saveBtn + "\n\n")
|
b.WriteString("\n" + sectionStyle.Render("Actions") + "\n")
|
||||||
b.WriteString(helpStyle.Render("Tab/↓ next | ↑ prev | Enter select | Esc back"))
|
b.WriteString(testBtn + " " + saveBtn + "\n\n")
|
||||||
|
b.WriteString(helpStyle.Render("Tab/↓ next | ↑ prev | / pick list | Enter select | Esc back"))
|
||||||
|
|
||||||
return b.String()
|
return b.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func renderDropdown(l list.Model) string {
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString(sectionStyle.Render(l.Title))
|
||||||
|
b.WriteString("\n")
|
||||||
|
for i, item := range l.Items() {
|
||||||
|
group, ok := item.(groupItem)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
prefix := " "
|
||||||
|
style := normalStyle
|
||||||
|
if i == l.Index() {
|
||||||
|
prefix = "> "
|
||||||
|
style = selectedRowStyle
|
||||||
|
}
|
||||||
|
b.WriteString(style.Render(prefix + group.name))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
return strings.TrimRight(b.String(), "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func formSectionTitle(index int) string {
|
||||||
|
switch index {
|
||||||
|
case 0:
|
||||||
|
return "Identity"
|
||||||
|
case 2:
|
||||||
|
return "Connection"
|
||||||
|
case 5:
|
||||||
|
return "Authentication"
|
||||||
|
case 8:
|
||||||
|
return "Metadata"
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// truncate limits a string to maxLen, adding "..." if truncated
|
||||||
|
func truncate(s string, maxLen int) string {
|
||||||
|
if len(s) <= maxLen {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return s[:maxLen-3] + "..."
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,322 @@
|
||||||
|
package tui
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/charmbracelet/bubbles/list"
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
"github.com/mirivlad/sshkeeper/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServerListViewUsesDashboardLayout(t *testing.T) {
|
||||||
|
now := time.Date(2026, 5, 28, 1, 50, 0, 0, time.UTC)
|
||||||
|
m := New([]*model.Server{
|
||||||
|
{
|
||||||
|
Alias: "mail.kp",
|
||||||
|
DisplayName: "Mail",
|
||||||
|
Host: "mail.example.org",
|
||||||
|
Port: 222,
|
||||||
|
User: "mirivlad",
|
||||||
|
AuthMethod: model.AuthPassword,
|
||||||
|
GroupName: "KP",
|
||||||
|
LastTestStatus: model.TestOK,
|
||||||
|
LastTestAt: &now,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Alias: "mirv.top",
|
||||||
|
Host: "mirv.top",
|
||||||
|
Port: 22,
|
||||||
|
User: "root",
|
||||||
|
AuthMethod: model.AuthKey,
|
||||||
|
LastTestStatus: model.TestUnknown,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
m.width = 100
|
||||||
|
m.height = 30
|
||||||
|
m.list.SetSize(100, 24)
|
||||||
|
|
||||||
|
view := m.View()
|
||||||
|
for _, want := range []string{
|
||||||
|
"sshkeeper",
|
||||||
|
"2 servers",
|
||||||
|
"Vault",
|
||||||
|
"NAME",
|
||||||
|
"TARGET",
|
||||||
|
"AUTH",
|
||||||
|
"GROUP",
|
||||||
|
"STATUS",
|
||||||
|
"Mail",
|
||||||
|
"mail.kp",
|
||||||
|
"mirivlad@mail.example.org:222",
|
||||||
|
"KP",
|
||||||
|
"OK",
|
||||||
|
"Selected",
|
||||||
|
"Host: mail.example.org",
|
||||||
|
"Alias: mail.kp",
|
||||||
|
"Display Name: Mail",
|
||||||
|
"Port: 222",
|
||||||
|
"Enter connect",
|
||||||
|
} {
|
||||||
|
if !strings.Contains(view, want) {
|
||||||
|
t.Fatalf("expected list view to contain %q\nview:\n%s", want, view)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.Contains(view, "Profiles managed locally") {
|
||||||
|
t.Fatalf("expected compact status header instead of README text\nview:\n%s", view)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEscClosesGroupListBeforeLeavingForm(t *testing.T) {
|
||||||
|
oldGetGroups := GetGroups
|
||||||
|
GetGroups = func() ([]string, error) {
|
||||||
|
return []string{"prod", "stage"}, nil
|
||||||
|
}
|
||||||
|
defer func() { GetGroups = oldGetGroups }()
|
||||||
|
|
||||||
|
m := &tuiModel{
|
||||||
|
screen: screenForm,
|
||||||
|
form: newFormModel(80, 24),
|
||||||
|
}
|
||||||
|
m.form.focusIdx = 8
|
||||||
|
m.form.updateFocus()
|
||||||
|
|
||||||
|
updated, _ := m.updateForm(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'/'}})
|
||||||
|
m = updated.(*tuiModel)
|
||||||
|
if !m.form.showGroupList {
|
||||||
|
t.Fatal("expected / on the group field to open the group list")
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, _ = m.updateForm(tea.KeyMsg{Type: tea.KeyEsc})
|
||||||
|
m = updated.(*tuiModel)
|
||||||
|
|
||||||
|
if m.screen != screenForm {
|
||||||
|
t.Fatalf("expected Esc to keep the user in the form, got screen %v", m.screen)
|
||||||
|
}
|
||||||
|
if m.form == nil {
|
||||||
|
t.Fatal("expected form to remain open")
|
||||||
|
}
|
||||||
|
if m.form.showGroupList {
|
||||||
|
t.Fatal("expected Esc to close only the group list")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEscClosesAuthMethodListBeforeLeavingForm(t *testing.T) {
|
||||||
|
m := &tuiModel{
|
||||||
|
screen: screenForm,
|
||||||
|
form: newFormModel(80, 24),
|
||||||
|
}
|
||||||
|
m.form.focusIdx = 5
|
||||||
|
m.form.updateFocus()
|
||||||
|
|
||||||
|
updated, _ := m.updateForm(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'/'}})
|
||||||
|
m = updated.(*tuiModel)
|
||||||
|
if !m.form.showAuthList {
|
||||||
|
t.Fatal("expected / on the auth method field to open the auth method list")
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, _ = m.updateForm(tea.KeyMsg{Type: tea.KeyEsc})
|
||||||
|
m = updated.(*tuiModel)
|
||||||
|
|
||||||
|
if m.screen != screenForm {
|
||||||
|
t.Fatalf("expected Esc to keep the user in the form, got screen %v", m.screen)
|
||||||
|
}
|
||||||
|
if m.form == nil {
|
||||||
|
t.Fatal("expected form to remain open")
|
||||||
|
}
|
||||||
|
if m.form.showAuthList {
|
||||||
|
t.Fatal("expected Esc to close only the auth method list")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthMethodListSelectsValue(t *testing.T) {
|
||||||
|
fm := newFormModel(80, 24)
|
||||||
|
fm.focusIdx = 5
|
||||||
|
fm.updateFocus()
|
||||||
|
|
||||||
|
updated, _ := fm.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'/'}})
|
||||||
|
fm = updated.(*formModel)
|
||||||
|
if !fm.showAuthList {
|
||||||
|
t.Fatal("expected / on the auth method field to open the auth method list")
|
||||||
|
}
|
||||||
|
|
||||||
|
fm.authList.Select(2)
|
||||||
|
updated, _ = fm.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||||
|
fm = updated.(*formModel)
|
||||||
|
|
||||||
|
if fm.showAuthList {
|
||||||
|
t.Fatal("expected Enter to close auth method list")
|
||||||
|
}
|
||||||
|
if got := fm.inputs[5].Value(); got != string(model.AuthKeyPassphrase) {
|
||||||
|
t.Fatalf("expected auth method %q, got %q", model.AuthKeyPassphrase, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthMethodListViewShowsAllOptions(t *testing.T) {
|
||||||
|
fm := newFormModel(80, 12)
|
||||||
|
fm.focusIdx = 5
|
||||||
|
fm.updateFocus()
|
||||||
|
|
||||||
|
updated, _ := fm.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'/'}})
|
||||||
|
fm = updated.(*formModel)
|
||||||
|
|
||||||
|
view := fm.View()
|
||||||
|
authPos := strings.Index(view, "Auth Method")
|
||||||
|
listPos := strings.Index(view, "Select auth method")
|
||||||
|
if authPos < 0 || listPos < 0 {
|
||||||
|
t.Fatalf("expected auth field and auth list title in view\nview:\n%s", view)
|
||||||
|
}
|
||||||
|
if listPos < authPos {
|
||||||
|
t.Fatalf("expected auth method list after auth field\nview:\n%s", view)
|
||||||
|
}
|
||||||
|
if between := view[authPos:listPos]; strings.Contains(between, "Identity File") {
|
||||||
|
t.Fatalf("expected auth method list to render directly under auth field\nview:\n%s", view)
|
||||||
|
}
|
||||||
|
if strings.Contains(view, "│") {
|
||||||
|
t.Fatalf("expected compact auth method dropdown without default list border\nview:\n%s", view)
|
||||||
|
}
|
||||||
|
for _, method := range []model.AuthMethod{
|
||||||
|
model.AuthPassword,
|
||||||
|
model.AuthKey,
|
||||||
|
model.AuthKeyPassphrase,
|
||||||
|
model.AuthAgent,
|
||||||
|
} {
|
||||||
|
if !strings.Contains(view, string(method)) {
|
||||||
|
t.Fatalf("expected auth method list view to contain %q\nview:\n%s", method, view)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGroupListViewRendersDirectlyUnderGroupField(t *testing.T) {
|
||||||
|
oldGetGroups := GetGroups
|
||||||
|
GetGroups = func() ([]string, error) {
|
||||||
|
return []string{"KP", "MY"}, nil
|
||||||
|
}
|
||||||
|
defer func() { GetGroups = oldGetGroups }()
|
||||||
|
|
||||||
|
fm := newFormModel(80, 24)
|
||||||
|
fm.focusIdx = 8
|
||||||
|
fm.updateFocus()
|
||||||
|
|
||||||
|
updated, _ := fm.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'/'}})
|
||||||
|
fm = updated.(*formModel)
|
||||||
|
|
||||||
|
view := fm.View()
|
||||||
|
groupPos := strings.Index(view, "Group")
|
||||||
|
listPos := strings.Index(view, "Select group")
|
||||||
|
if groupPos < 0 || listPos < 0 {
|
||||||
|
t.Fatalf("expected group field and group list title in view\nview:\n%s", view)
|
||||||
|
}
|
||||||
|
if listPos < groupPos {
|
||||||
|
t.Fatalf("expected group list after group field\nview:\n%s", view)
|
||||||
|
}
|
||||||
|
if between := view[groupPos:listPos]; strings.Contains(between, "Password") {
|
||||||
|
t.Fatalf("expected group dropdown to render before password field\nview:\n%s", view)
|
||||||
|
}
|
||||||
|
if strings.Contains(view, "│") {
|
||||||
|
t.Fatalf("expected compact group dropdown without default list border\nview:\n%s", view)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectableFieldHintsAreVisible(t *testing.T) {
|
||||||
|
oldGetGroups := GetGroups
|
||||||
|
GetGroups = func() ([]string, error) {
|
||||||
|
return []string{"KP"}, nil
|
||||||
|
}
|
||||||
|
defer func() { GetGroups = oldGetGroups }()
|
||||||
|
|
||||||
|
fm := newFormModel(80, 24)
|
||||||
|
view := fm.View()
|
||||||
|
for _, want := range []string{"Auth Method (/ pick)", "Group (/ pick)", "/ pick list"} {
|
||||||
|
if !strings.Contains(view, want) {
|
||||||
|
t.Fatalf("expected form view to contain selectable-field hint %q\nview:\n%s", want, view)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEditFormShowsSavedSecretMarker(t *testing.T) {
|
||||||
|
oldHasSecret := HasSecret
|
||||||
|
HasSecret = func(alias string, secretType string) bool {
|
||||||
|
return alias == "prod" && secretType == "ssh_password"
|
||||||
|
}
|
||||||
|
defer func() { HasSecret = oldHasSecret }()
|
||||||
|
|
||||||
|
fm := newEditFormModel(&model.Server{
|
||||||
|
Alias: "prod",
|
||||||
|
Host: "example.org",
|
||||||
|
Port: 22,
|
||||||
|
User: "root",
|
||||||
|
AuthMethod: model.AuthPassword,
|
||||||
|
}, 80, 24)
|
||||||
|
|
||||||
|
view := fm.View()
|
||||||
|
if !strings.Contains(view, "secret saved") {
|
||||||
|
t.Fatalf("expected edit form to show saved secret marker\nview:\n%s", view)
|
||||||
|
}
|
||||||
|
if !strings.Contains(view, "leave blank to keep") {
|
||||||
|
t.Fatalf("expected edit form to explain blank password keeps saved secret\nview:\n%s", view)
|
||||||
|
}
|
||||||
|
if strings.Count(view, "secret saved") != 1 {
|
||||||
|
t.Fatalf("expected saved secret marker to appear once\nview:\n%s", view)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormViewUsesSectionsAndStableLabels(t *testing.T) {
|
||||||
|
fm := newFormModel(100, 30)
|
||||||
|
view := fm.View()
|
||||||
|
|
||||||
|
for _, want := range []string{
|
||||||
|
"Identity",
|
||||||
|
"Connection",
|
||||||
|
"Authentication",
|
||||||
|
"Metadata",
|
||||||
|
"Actions",
|
||||||
|
"Alias",
|
||||||
|
"Display Name",
|
||||||
|
"Auth Method",
|
||||||
|
"Password / Passphrase",
|
||||||
|
} {
|
||||||
|
if !strings.Contains(view, want) {
|
||||||
|
t.Fatalf("expected form view to contain %q\nview:\n%s", want, view)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormTestResultDoesNotUpdateSelectedListServer(t *testing.T) {
|
||||||
|
oldUpdateTestResult := UpdateTestResult
|
||||||
|
oldListServers := ListServers
|
||||||
|
defer func() {
|
||||||
|
UpdateTestResult = oldUpdateTestResult
|
||||||
|
ListServers = oldListServers
|
||||||
|
}()
|
||||||
|
|
||||||
|
updateCalled := false
|
||||||
|
UpdateTestResult = func(alias string, status model.TestStatus, testErr string) error {
|
||||||
|
updateCalled = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
ListServers = func() ([]*model.Server, error) {
|
||||||
|
t.Fatal("form test result should not reload server list")
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
selected := &model.Server{Alias: "selected", Host: "example.org", Port: 22, User: "root"}
|
||||||
|
m := New([]*model.Server{selected})
|
||||||
|
m.screen = screenForm
|
||||||
|
m.form = newFormModel(80, 24)
|
||||||
|
m.list = list.New([]list.Item{serverItem{server: selected}}, list.NewDefaultDelegate(), 80, 20)
|
||||||
|
|
||||||
|
updated, cmd := m.Update(testDoneMsg{ok: true})
|
||||||
|
m = updated.(*tuiModel)
|
||||||
|
|
||||||
|
if cmd != nil {
|
||||||
|
t.Fatal("form test result should not return a reload command")
|
||||||
|
}
|
||||||
|
if updateCalled {
|
||||||
|
t.Fatal("form test result should not update the selected list server")
|
||||||
|
}
|
||||||
|
if m.form == nil || m.form.testResult != "Connection OK." {
|
||||||
|
t.Fatal("expected form to keep its test result")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -8,6 +8,8 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
|
@ -16,18 +18,20 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
currentVersion = 1
|
currentVersion = 1
|
||||||
saltLen = 32
|
saltLen = 32
|
||||||
nonceLen = 24
|
nonceLen = 24
|
||||||
keyLen = 32
|
keyLen = 32
|
||||||
|
verifierID = "__sshkeeper_vault_verifier__"
|
||||||
|
verifierPlaintext = "sshkeeper-vault-verifier-v1"
|
||||||
)
|
)
|
||||||
|
|
||||||
type KDFMeta struct {
|
type KDFMeta struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
MemoryKiB int `json:"memory_kib"`
|
MemoryKiB int `json:"memory_kib"`
|
||||||
Iterations int `json:"iterations"`
|
Iterations int `json:"iterations"`
|
||||||
Parallelism int `json:"parallelism"`
|
Parallelism int `json:"parallelism"`
|
||||||
Salt string `json:"salt"`
|
Salt string `json:"salt"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Record struct {
|
type Record struct {
|
||||||
|
|
@ -38,23 +42,35 @@ type Record struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type VaultFile struct {
|
type VaultFile struct {
|
||||||
Version int `json:"version"`
|
Version int `json:"version"`
|
||||||
KDF KDFMeta `json:"kdf"`
|
KDF KDFMeta `json:"kdf"`
|
||||||
Records []Record `json:"records"`
|
Verifier *Record `json:"verifier,omitempty"`
|
||||||
|
Records []Record `json:"records"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Vault struct {
|
type Vault struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
path string
|
path string
|
||||||
masterKey []byte
|
masterKey []byte
|
||||||
records map[string][]byte // id -> plaintext
|
records map[string]secretRecord
|
||||||
modified bool
|
modified bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type secretRecord struct {
|
||||||
|
secretType string
|
||||||
|
plaintext []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type SecretMeta struct {
|
||||||
|
ID string
|
||||||
|
Alias string
|
||||||
|
Type string
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(path string) *Vault {
|
func New(path string) *Vault {
|
||||||
return &Vault{
|
return &Vault{
|
||||||
path: path,
|
path: path,
|
||||||
records: make(map[string][]byte),
|
records: make(map[string]secretRecord),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -76,21 +92,26 @@ func Create(path string, masterPassword string) error {
|
||||||
|
|
||||||
kdf := KDFMeta{
|
kdf := KDFMeta{
|
||||||
Name: "argon2id",
|
Name: "argon2id",
|
||||||
MemoryKiB: 4096,
|
MemoryKiB: 65536,
|
||||||
Iterations: 2,
|
Iterations: 3,
|
||||||
Parallelism: 1,
|
Parallelism: 1,
|
||||||
Salt: base64.StdEncoding.EncodeToString(salt),
|
Salt: base64.StdEncoding.EncodeToString(salt),
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Print("Deriving key...")
|
fmt.Print("Deriving key...")
|
||||||
|
|
||||||
key := argon2.IDKey([]byte(masterPassword), salt, uint32(kdf.Iterations), uint32(kdf.MemoryKiB)*1024, uint8(kdf.Parallelism), keyLen)
|
key := argon2.IDKey([]byte(masterPassword), salt, uint32(kdf.Iterations), uint32(kdf.MemoryKiB), uint8(kdf.Parallelism), keyLen)
|
||||||
|
|
||||||
|
verifier, err := newVerifierRecord(key)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create verifier: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Verify key is valid by doing a test encrypt/decrypt
|
|
||||||
vf := VaultFile{
|
vf := VaultFile{
|
||||||
Version: currentVersion,
|
Version: currentVersion,
|
||||||
KDF: kdf,
|
KDF: kdf,
|
||||||
Records: []Record{},
|
Verifier: &verifier,
|
||||||
|
Records: []Record{},
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := json.Marshal(vf)
|
data, err := json.Marshal(vf)
|
||||||
|
|
@ -143,24 +164,32 @@ func (v *Vault) Unlock(masterPassword string) error {
|
||||||
return fmt.Errorf("decode salt: %w", err)
|
return fmt.Errorf("decode salt: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
key := argon2.IDKey([]byte(masterPassword), salt, uint32(vf.KDF.Iterations), uint32(vf.KDF.MemoryKiB)*1024, uint8(vf.KDF.Parallelism), keyLen)
|
key := argon2.IDKey([]byte(masterPassword), salt, uint32(vf.KDF.Iterations), uint32(vf.KDF.MemoryKiB), uint8(vf.KDF.Parallelism), keyLen)
|
||||||
|
|
||||||
// Try to decrypt first record to verify password
|
if vf.Verifier != nil {
|
||||||
if len(vf.Records) > 0 {
|
if err := verifyRecord(key, *vf.Verifier); err != nil {
|
||||||
|
return fmt.Errorf("invalid master password")
|
||||||
|
}
|
||||||
|
} else if len(vf.Records) > 0 {
|
||||||
if _, err := decryptRecord(key, vf.Records[0]); err != nil {
|
if _, err := decryptRecord(key, vf.Records[0]); err != nil {
|
||||||
return fmt.Errorf("invalid master password")
|
return fmt.Errorf("invalid master password")
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("vault cannot verify master password; recreate empty vault")
|
||||||
}
|
}
|
||||||
|
|
||||||
v.masterKey = key
|
v.masterKey = key
|
||||||
v.records = make(map[string][]byte)
|
v.records = make(map[string]secretRecord)
|
||||||
|
|
||||||
for _, rec := range vf.Records {
|
for _, rec := range vf.Records {
|
||||||
plaintext, err := decryptRecord(key, rec)
|
plaintext, err := decryptRecord(key, rec)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("decrypt record %s: %w", rec.ID, err)
|
return fmt.Errorf("decrypt record %s: %w", rec.ID, err)
|
||||||
}
|
}
|
||||||
v.records[rec.ID] = plaintext
|
v.records[rec.ID] = secretRecord{
|
||||||
|
secretType: inferSecretType(rec.ID, rec.Type),
|
||||||
|
plaintext: plaintext,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println(" done.")
|
fmt.Println(" done.")
|
||||||
|
|
@ -178,7 +207,7 @@ func (v *Vault) Lock() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
v.masterKey = nil
|
v.masterKey = nil
|
||||||
v.records = make(map[string][]byte)
|
v.records = make(map[string]secretRecord)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsUnlocked returns whether the vault is currently unlocked
|
// IsUnlocked returns whether the vault is currently unlocked
|
||||||
|
|
@ -197,7 +226,9 @@ func (v *Vault) Put(id string, secretType string, plaintext []byte) error {
|
||||||
return fmt.Errorf("vault is locked")
|
return fmt.Errorf("vault is locked")
|
||||||
}
|
}
|
||||||
|
|
||||||
v.records[id] = plaintext
|
data := make([]byte, len(plaintext))
|
||||||
|
copy(data, plaintext)
|
||||||
|
v.records[id] = secretRecord{secretType: secretType, plaintext: data}
|
||||||
v.modified = true
|
v.modified = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
@ -211,16 +242,55 @@ func (v *Vault) Get(id string) ([]byte, error) {
|
||||||
return nil, fmt.Errorf("vault is locked")
|
return nil, fmt.Errorf("vault is locked")
|
||||||
}
|
}
|
||||||
|
|
||||||
data, ok := v.records[id]
|
record, ok := v.records[id]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("secret not found: %s", id)
|
return nil, fmt.Errorf("secret not found: %s", id)
|
||||||
}
|
}
|
||||||
|
|
||||||
result := make([]byte, len(data))
|
result := make([]byte, len(record.plaintext))
|
||||||
copy(result, data)
|
copy(result, record.plaintext)
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (v *Vault) HasSecret(id string) bool {
|
||||||
|
v.mu.Lock()
|
||||||
|
defer v.mu.Unlock()
|
||||||
|
_, ok := v.records[id]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *Vault) ListSecrets() ([]SecretMeta, error) {
|
||||||
|
v.mu.Lock()
|
||||||
|
defer v.mu.Unlock()
|
||||||
|
|
||||||
|
if v.masterKey == nil {
|
||||||
|
return nil, fmt.Errorf("vault is locked")
|
||||||
|
}
|
||||||
|
|
||||||
|
metas := make([]SecretMeta, 0, len(v.records))
|
||||||
|
for id, record := range v.records {
|
||||||
|
alias, secretType, ok := parseServerSecretID(id)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if record.secretType != "" {
|
||||||
|
secretType = record.secretType
|
||||||
|
}
|
||||||
|
metas = append(metas, SecretMeta{
|
||||||
|
ID: id,
|
||||||
|
Alias: alias,
|
||||||
|
Type: secretType,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
sort.Slice(metas, func(i, j int) bool {
|
||||||
|
if metas[i].Alias == metas[j].Alias {
|
||||||
|
return metas[i].Type < metas[j].Type
|
||||||
|
}
|
||||||
|
return metas[i].Alias < metas[j].Alias
|
||||||
|
})
|
||||||
|
return metas, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Delete removes a secret
|
// Delete removes a secret
|
||||||
func (v *Vault) Delete(id string) {
|
func (v *Vault) Delete(id string) {
|
||||||
v.mu.Lock()
|
v.mu.Lock()
|
||||||
|
|
@ -245,8 +315,8 @@ func (v *Vault) Save() error {
|
||||||
|
|
||||||
kdf := KDFMeta{
|
kdf := KDFMeta{
|
||||||
Name: "argon2id",
|
Name: "argon2id",
|
||||||
MemoryKiB: 4096,
|
MemoryKiB: 65536,
|
||||||
Iterations: 2,
|
Iterations: 3,
|
||||||
Parallelism: 1,
|
Parallelism: 1,
|
||||||
Salt: base64.StdEncoding.EncodeToString(salt),
|
Salt: base64.StdEncoding.EncodeToString(salt),
|
||||||
}
|
}
|
||||||
|
|
@ -254,18 +324,24 @@ func (v *Vault) Save() error {
|
||||||
fmt.Print("Deriving key...")
|
fmt.Print("Deriving key...")
|
||||||
|
|
||||||
var records []Record
|
var records []Record
|
||||||
for id, plaintext := range v.records {
|
for id, record := range v.records {
|
||||||
rec, err := encryptRecord(v.masterKey, id, plaintext)
|
rec, err := encryptRecordWithType(v.masterKey, id, record.secretType, record.plaintext)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("encrypt record %s: %w", id, err)
|
return fmt.Errorf("encrypt record %s: %w", id, err)
|
||||||
}
|
}
|
||||||
records = append(records, rec)
|
records = append(records, rec)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
verifier, err := newVerifierRecord(v.masterKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create verifier: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
vf := VaultFile{
|
vf := VaultFile{
|
||||||
Version: currentVersion,
|
Version: currentVersion,
|
||||||
KDF: kdf,
|
KDF: kdf,
|
||||||
Records: records,
|
Verifier: &verifier,
|
||||||
|
Records: records,
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := json.Marshal(vf)
|
data, err := json.Marshal(vf)
|
||||||
|
|
@ -309,12 +385,12 @@ func (v *Vault) ChangePassword(newPassword string) error {
|
||||||
return fmt.Errorf("generate salt: %w", err)
|
return fmt.Errorf("generate salt: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
newKey := argon2.IDKey([]byte(newPassword), salt, 3, 8192*1024, 1, keyLen)
|
newKey := argon2.IDKey([]byte(newPassword), salt, 3, 65536, 1, keyLen)
|
||||||
|
|
||||||
kdf := KDFMeta{
|
kdf := KDFMeta{
|
||||||
Name: "argon2id",
|
Name: "argon2id",
|
||||||
MemoryKiB: 4096,
|
MemoryKiB: 65536,
|
||||||
Iterations: 2,
|
Iterations: 3,
|
||||||
Parallelism: 1,
|
Parallelism: 1,
|
||||||
Salt: base64.StdEncoding.EncodeToString(salt),
|
Salt: base64.StdEncoding.EncodeToString(salt),
|
||||||
}
|
}
|
||||||
|
|
@ -322,18 +398,24 @@ func (v *Vault) ChangePassword(newPassword string) error {
|
||||||
fmt.Print("Deriving key...")
|
fmt.Print("Deriving key...")
|
||||||
|
|
||||||
var records []Record
|
var records []Record
|
||||||
for id, plaintext := range v.records {
|
for id, record := range v.records {
|
||||||
rec, err := encryptRecord(newKey, id, plaintext)
|
rec, err := encryptRecordWithType(newKey, id, record.secretType, record.plaintext)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("encrypt record: %w", err)
|
return fmt.Errorf("encrypt record: %w", err)
|
||||||
}
|
}
|
||||||
records = append(records, rec)
|
records = append(records, rec)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
verifier, err := newVerifierRecord(newKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create verifier: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
vf := VaultFile{
|
vf := VaultFile{
|
||||||
Version: currentVersion,
|
Version: currentVersion,
|
||||||
KDF: kdf,
|
KDF: kdf,
|
||||||
Records: records,
|
Verifier: &verifier,
|
||||||
|
Records: records,
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := json.Marshal(vf)
|
data, err := json.Marshal(vf)
|
||||||
|
|
@ -382,6 +464,10 @@ func (v *Vault) getSalt() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func encryptRecord(key []byte, id string, plaintext []byte) (Record, error) {
|
func encryptRecord(key []byte, id string, plaintext []byte) (Record, error) {
|
||||||
|
return encryptRecordWithType(key, id, "", plaintext)
|
||||||
|
}
|
||||||
|
|
||||||
|
func encryptRecordWithType(key []byte, id string, secretType string, plaintext []byte) (Record, error) {
|
||||||
aead, err := chacha20poly1305.NewX(key)
|
aead, err := chacha20poly1305.NewX(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Record{}, err
|
return Record{}, err
|
||||||
|
|
@ -396,11 +482,31 @@ func encryptRecord(key []byte, id string, plaintext []byte) (Record, error) {
|
||||||
|
|
||||||
return Record{
|
return Record{
|
||||||
ID: id,
|
ID: id,
|
||||||
|
Type: secretType,
|
||||||
Nonce: base64.StdEncoding.EncodeToString(nonce),
|
Nonce: base64.StdEncoding.EncodeToString(nonce),
|
||||||
Ciphertext: base64.StdEncoding.EncodeToString(ciphertext),
|
Ciphertext: base64.StdEncoding.EncodeToString(ciphertext),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func inferSecretType(id string, recordType string) string {
|
||||||
|
if recordType != "" {
|
||||||
|
return recordType
|
||||||
|
}
|
||||||
|
_, secretType, ok := parseServerSecretID(id)
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return secretType
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseServerSecretID(id string) (string, string, bool) {
|
||||||
|
parts := strings.Split(id, ":")
|
||||||
|
if len(parts) != 3 || parts[0] != "server" || parts[1] == "" || parts[2] == "" {
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
return parts[1], parts[2], true
|
||||||
|
}
|
||||||
|
|
||||||
func decryptRecord(key []byte, rec Record) ([]byte, error) {
|
func decryptRecord(key []byte, rec Record) ([]byte, error) {
|
||||||
aead, err := chacha20poly1305.NewX(key)
|
aead, err := chacha20poly1305.NewX(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -425,6 +531,26 @@ func decryptRecord(key []byte, rec Record) ([]byte, error) {
|
||||||
return plaintext, nil
|
return plaintext, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newVerifierRecord(key []byte) (Record, error) {
|
||||||
|
rec, err := encryptRecord(key, verifierID, []byte(verifierPlaintext))
|
||||||
|
if err != nil {
|
||||||
|
return Record{}, err
|
||||||
|
}
|
||||||
|
rec.Type = "verifier"
|
||||||
|
return rec, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyRecord(key []byte, rec Record) error {
|
||||||
|
plaintext, err := decryptRecord(key, rec)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !SecureCompare(string(plaintext), verifierPlaintext) {
|
||||||
|
return fmt.Errorf("invalid verifier")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// VerifyPassword checks if a master password is correct without unlocking
|
// VerifyPassword checks if a master password is correct without unlocking
|
||||||
func VerifyPassword(path string, masterPassword string) (bool, error) {
|
func VerifyPassword(path string, masterPassword string) (bool, error) {
|
||||||
data, err := os.ReadFile(path)
|
data, err := os.ReadFile(path)
|
||||||
|
|
@ -442,18 +568,20 @@ func VerifyPassword(path string, masterPassword string) (bool, error) {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
key := argon2.IDKey([]byte(masterPassword), salt, uint32(vf.KDF.Iterations), uint32(vf.KDF.MemoryKiB)*1024, uint8(vf.KDF.Parallelism), keyLen)
|
key := argon2.IDKey([]byte(masterPassword), salt, uint32(vf.KDF.Iterations), uint32(vf.KDF.MemoryKiB), uint8(vf.KDF.Parallelism), keyLen)
|
||||||
defer func() {
|
defer func() {
|
||||||
for i := range key {
|
for i := range key {
|
||||||
key[i] = 0
|
key[i] = 0
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if len(vf.Records) == 0 {
|
if vf.Verifier != nil {
|
||||||
// Empty vault, try a test encryption
|
return verifyRecord(key, *vf.Verifier) == nil, nil
|
||||||
return true, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(vf.Records) == 0 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
_, err = decryptRecord(key, vf.Records[0])
|
_, err = decryptRecord(key, vf.Records[0])
|
||||||
return err == nil, nil
|
return err == nil, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,218 @@
|
||||||
|
package vault
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/argon2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewEmptyVaultRejectsWrongPassword(t *testing.T) {
|
||||||
|
path := filepath.Join(t.TempDir(), "vault.bin")
|
||||||
|
|
||||||
|
if err := Create(path, "correct horse"); err != nil {
|
||||||
|
t.Fatalf("create vault: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := VerifyPassword(path, "wrong horse")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("verify wrong password: %v", err)
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
t.Fatal("expected wrong password to be rejected for a new empty vault")
|
||||||
|
}
|
||||||
|
|
||||||
|
v := New(path)
|
||||||
|
if err := v.Unlock("wrong horse"); err == nil {
|
||||||
|
t.Fatal("expected unlock with wrong password to fail for a new empty vault")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewEmptyVaultAcceptsCorrectPassword(t *testing.T) {
|
||||||
|
path := filepath.Join(t.TempDir(), "vault.bin")
|
||||||
|
|
||||||
|
if err := Create(path, "correct horse"); err != nil {
|
||||||
|
t.Fatalf("create vault: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := VerifyPassword(path, "correct horse")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("verify correct password: %v", err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected correct password to be accepted for a new empty vault")
|
||||||
|
}
|
||||||
|
|
||||||
|
v := New(path)
|
||||||
|
if err := v.Unlock("correct horse"); err != nil {
|
||||||
|
t.Fatalf("unlock with correct password: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLegacyVaultWithRecordsStillVerifiesByFirstRecord(t *testing.T) {
|
||||||
|
path := filepath.Join(t.TempDir(), "vault.bin")
|
||||||
|
salt := []byte("12345678901234567890123456789012")
|
||||||
|
key := argon2.IDKey([]byte("correct horse"), salt, 3, 65536, 1, keyLen)
|
||||||
|
|
||||||
|
rec, err := encryptRecord(key, "server:test:ssh_password", []byte("secret"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encrypt legacy record: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(VaultFile{
|
||||||
|
Version: currentVersion,
|
||||||
|
KDF: KDFMeta{
|
||||||
|
Name: "argon2id",
|
||||||
|
MemoryKiB: 65536,
|
||||||
|
Iterations: 3,
|
||||||
|
Parallelism: 1,
|
||||||
|
Salt: base64.StdEncoding.EncodeToString(salt),
|
||||||
|
},
|
||||||
|
Records: []Record{rec},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshal legacy vault: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(path, data, 0o600); err != nil {
|
||||||
|
t.Fatalf("write legacy vault: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := VerifyPassword(path, "correct horse")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("verify legacy vault: %v", err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected legacy vault with records to accept correct password")
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err = VerifyPassword(path, "wrong horse")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("verify legacy vault with wrong password: %v", err)
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
t.Fatal("expected legacy vault with records to reject wrong password")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLegacyEmptyVaultWithoutVerifierCannotUnlock(t *testing.T) {
|
||||||
|
path := filepath.Join(t.TempDir(), "vault.bin")
|
||||||
|
salt := []byte("12345678901234567890123456789012")
|
||||||
|
|
||||||
|
data, err := json.Marshal(VaultFile{
|
||||||
|
Version: currentVersion,
|
||||||
|
KDF: KDFMeta{
|
||||||
|
Name: "argon2id",
|
||||||
|
MemoryKiB: 65536,
|
||||||
|
Iterations: 3,
|
||||||
|
Parallelism: 1,
|
||||||
|
Salt: base64.StdEncoding.EncodeToString(salt),
|
||||||
|
},
|
||||||
|
Records: []Record{},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshal legacy empty vault: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(path, data, 0o600); err != nil {
|
||||||
|
t.Fatalf("write legacy empty vault: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := VerifyPassword(path, "any password")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("verify legacy empty vault: %v", err)
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
t.Fatal("expected legacy empty vault without verifier to be unverifiable")
|
||||||
|
}
|
||||||
|
|
||||||
|
v := New(path)
|
||||||
|
if err := v.Unlock("any password"); err == nil {
|
||||||
|
t.Fatal("expected legacy empty vault without verifier to reject unlock")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListSecretsReturnsMetadataWithoutPlaintext(t *testing.T) {
|
||||||
|
path := filepath.Join(t.TempDir(), "vault.bin")
|
||||||
|
if err := Create(path, "master"); err != nil {
|
||||||
|
t.Fatalf("create vault: %v", err)
|
||||||
|
}
|
||||||
|
v := New(path)
|
||||||
|
if err := v.Unlock("master"); err != nil {
|
||||||
|
t.Fatalf("unlock vault: %v", err)
|
||||||
|
}
|
||||||
|
if err := v.Put("server:prod:ssh_password", "ssh_password", []byte("secret-password")); err != nil {
|
||||||
|
t.Fatalf("put password: %v", err)
|
||||||
|
}
|
||||||
|
if err := v.Put("server:prod:key_passphrase", "key_passphrase", []byte("secret-passphrase")); err != nil {
|
||||||
|
t.Fatalf("put passphrase: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
metas, err := v.ListSecrets()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("list secrets: %v", err)
|
||||||
|
}
|
||||||
|
if len(metas) != 2 {
|
||||||
|
t.Fatalf("expected 2 secret metadata records, got %d: %#v", len(metas), metas)
|
||||||
|
}
|
||||||
|
if metas[0].Alias != "prod" || metas[0].Type != "key_passphrase" {
|
||||||
|
t.Fatalf("unexpected first metadata record: %#v", metas[0])
|
||||||
|
}
|
||||||
|
if metas[1].Alias != "prod" || metas[1].Type != "ssh_password" {
|
||||||
|
t.Fatalf("unexpected second metadata record: %#v", metas[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListSecretsPreservesTypesAfterSaveAndUnlock(t *testing.T) {
|
||||||
|
path := filepath.Join(t.TempDir(), "vault.bin")
|
||||||
|
if err := Create(path, "master"); err != nil {
|
||||||
|
t.Fatalf("create vault: %v", err)
|
||||||
|
}
|
||||||
|
v := New(path)
|
||||||
|
if err := v.Unlock("master"); err != nil {
|
||||||
|
t.Fatalf("unlock vault: %v", err)
|
||||||
|
}
|
||||||
|
if err := v.Put("server:prod:ssh_password", "ssh_password", []byte("secret-password")); err != nil {
|
||||||
|
t.Fatalf("put password: %v", err)
|
||||||
|
}
|
||||||
|
if err := v.Save(); err != nil {
|
||||||
|
t.Fatalf("save vault: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
reopened := New(path)
|
||||||
|
if err := reopened.Unlock("master"); err != nil {
|
||||||
|
t.Fatalf("unlock reopened vault: %v", err)
|
||||||
|
}
|
||||||
|
metas, err := reopened.ListSecrets()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("list reopened secrets: %v", err)
|
||||||
|
}
|
||||||
|
if len(metas) != 1 {
|
||||||
|
t.Fatalf("expected 1 secret metadata record, got %d: %#v", len(metas), metas)
|
||||||
|
}
|
||||||
|
if metas[0].ID != "server:prod:ssh_password" || metas[0].Alias != "prod" || metas[0].Type != "ssh_password" {
|
||||||
|
t.Fatalf("unexpected metadata after reopen: %#v", metas[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasSecretReportsPresenceWithoutReturningValue(t *testing.T) {
|
||||||
|
path := filepath.Join(t.TempDir(), "vault.bin")
|
||||||
|
if err := Create(path, "master"); err != nil {
|
||||||
|
t.Fatalf("create vault: %v", err)
|
||||||
|
}
|
||||||
|
v := New(path)
|
||||||
|
if err := v.Unlock("master"); err != nil {
|
||||||
|
t.Fatalf("unlock vault: %v", err)
|
||||||
|
}
|
||||||
|
if err := v.Put("server:prod:ssh_password", "ssh_password", []byte("secret-password")); err != nil {
|
||||||
|
t.Fatalf("put password: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !v.HasSecret("server:prod:ssh_password") {
|
||||||
|
t.Fatal("expected saved password to be reported present")
|
||||||
|
}
|
||||||
|
if v.HasSecret("server:prod:key_passphrase") {
|
||||||
|
t.Fatal("expected missing passphrase to be reported absent")
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue