diff --git a/cmd/root.go b/cmd/root.go index 9604d1e..65d8956 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -183,8 +183,27 @@ func commandRequiresStartupVaultUnlock(args []string) bool { switch args[0] { case "connect", "c", "run", "run-template", "test", "edit", "delete", "tunnel": + if args[0] == "tunnel" && tunnelCommandSkipsStartupVaultUnlock(args[1:]) { + return false + } return true default: return false } } + +func tunnelCommandSkipsStartupVaultUnlock(args []string) bool { + if len(args) == 0 { + return false + } + switch args[0] { + case "list", "stop", "stop-all": + return true + } + for _, arg := range args[1:] { + if arg == "--background" { + return true + } + } + return false +} diff --git a/cmd/root_test.go b/cmd/root_test.go index 7f58db7..2db750b 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -18,6 +18,10 @@ func TestCommandRequiresStartupVaultUnlock(t *testing.T) { {name: "list only reads database", args: []string{"list"}, want: false}, {name: "show only reads database", args: []string{"show", "prod"}, want: false}, {name: "search only reads database", args: []string{"search", "prod"}, want: false}, + {name: "tunnel list only reads state", args: []string{"tunnel", "list"}, want: false}, + {name: "tunnel stop only edits state", args: []string{"tunnel", "stop", "1"}, want: false}, + {name: "tunnel stop all only edits state", args: []string{"tunnel", "stop-all"}, want: false}, + {name: "background tunnel does not need startup vault", args: []string{"tunnel", "prod", "--background"}, want: false}, {name: "config path only reads config", args: []string{"config", "path"}, want: false}, {name: "help", args: []string{"--help"}, want: false}, } diff --git a/cmd/tui.go b/cmd/tui.go index 263c03a..a35f89b 100644 --- a/cmd/tui.go +++ b/cmd/tui.go @@ -234,6 +234,11 @@ func runTUI() error { if background { // Start detached tunnel process + if err := validateBackgroundTunnel(fresh, forwards); err != nil { + fmt.Fprintf(os.Stderr, "Start tunnel: %v\n", err) + servers, _ = appDB.ListServers() + continue + } state, err := tunnelpkg.Start(cfg, fresh, forwards, forwardOnly) if err != nil { fmt.Fprintf(os.Stderr, "Start tunnel: %v\n", err) diff --git a/cmd/tunnel.go b/cmd/tunnel.go index 99a6d9a..a9534b1 100644 --- a/cmd/tunnel.go +++ b/cmd/tunnel.go @@ -2,8 +2,11 @@ package cmd import ( "fmt" + "strconv" + "github.com/mirivlad/sshkeeper/internal/model" "github.com/mirivlad/sshkeeper/internal/ssh" + tunnelpkg "github.com/mirivlad/sshkeeper/internal/tunnel" "github.com/spf13/cobra" ) @@ -19,6 +22,29 @@ var tunnelCmd = &cobra.Command{ } forwardsOnly, _ := cmd.Flags().GetBool("forward-only") + background, _ := cmd.Flags().GetBool("background") + + // Load forwards + forwards, err := appDB.GetForwards(server.ID) + if err != nil { + return fmt.Errorf("load forwards: %w", err) + } + + if background { + if err := validateBackgroundTunnel(server, forwards); err != nil { + return err + } + state, err := tunnelpkg.Start(cfg, server, forwards, true) + if err != nil { + return err + } + fmt.Printf("✓ Tunnel started [%d] PID %d → %s\n", state.ID, state.PID, server.Alias) + return nil + } + + if len(forwards) == 0 && forwardsOnly { + return fmt.Errorf("no forwards configured for %s", alias) + } v := getOrCreateVault() vaultFunc := func(serverAlias string, secretType string) (string, error) { @@ -33,16 +59,6 @@ var tunnelCmd = &cobra.Command{ return string(data), nil } - // Load forwards - forwards, err := appDB.GetForwards(server.ID) - if err != nil { - return fmt.Errorf("load forwards: %w", err) - } - - if len(forwards) == 0 && forwardsOnly { - return fmt.Errorf("no forwards configured for %s", alias) - } - if len(forwards) > 0 { fmt.Printf("Starting tunnel to %s with %d forward(s)...\n", alias, len(forwards)) } else { @@ -58,6 +74,74 @@ var tunnelCmd = &cobra.Command{ }, } +var tunnelListCmd = &cobra.Command{ + Use: "list", + Short: "List tracked background tunnels", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + states := tunnelpkg.List() + if len(states) == 0 { + fmt.Println("No tracked tunnels.") + return nil + } + fmt.Printf("%-22s %-8s %-10s %s\n", "ID", "PID", "STATUS", "SERVER") + for _, state := range states { + status := "stopped" + if tunnelpkg.IsRunning(state.ID) { + status = "running" + } + fmt.Printf("%-22d %-8d %-10s %s\n", state.ID, state.PID, status, state.ServerAlias) + } + return nil + }, +} + +var tunnelStopCmd = &cobra.Command{ + Use: "stop ", + Short: "Stop a tracked background tunnel", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + id, err := strconv.ParseInt(args[0], 10, 64) + if err != nil { + return fmt.Errorf("invalid tunnel ID: %s", args[0]) + } + if err := tunnelpkg.Stop(id); err != nil { + return err + } + fmt.Printf("✓ Tunnel %d stopped\n", id) + return nil + }, +} + +var tunnelStopAllCmd = &cobra.Command{ + Use: "stop-all", + Short: "Stop all tracked background tunnels", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + if err := tunnelpkg.StopAll(); err != nil { + return err + } + fmt.Println("✓ All tracked tunnels stopped") + return nil + }, +} + +func validateBackgroundTunnel(server *model.Server, forwards []*model.Forward) error { + if server.AuthMethod == model.AuthPassword || server.AuthMethod == model.AuthKeyPassphrase { + return fmt.Errorf("background tunnels support only key or agent auth; use foreground tunnel for %s auth", server.AuthMethod) + } + for _, f := range forwards { + if f.Enabled { + return nil + } + } + return fmt.Errorf("no enabled forwards configured for %s", server.Alias) +} + func init() { tunnelCmd.Flags().Bool("forward-only", false, "Start tunnel only (ssh -N)") + tunnelCmd.Flags().Bool("background", false, "Start tunnel in background (ssh -N)") + tunnelCmd.AddCommand(tunnelListCmd) + tunnelCmd.AddCommand(tunnelStopCmd) + tunnelCmd.AddCommand(tunnelStopAllCmd) } diff --git a/cmd/tunnel_test.go b/cmd/tunnel_test.go new file mode 100644 index 0000000..964c89d --- /dev/null +++ b/cmd/tunnel_test.go @@ -0,0 +1,46 @@ +package cmd + +import ( + "strings" + "testing" + + "github.com/mirivlad/sshkeeper/internal/model" +) + +func TestTunnelCommandExposesBackgroundAndManagementCommands(t *testing.T) { + if flag := tunnelCmd.Flags().Lookup("background"); flag == nil { + t.Fatal("expected tunnel --background flag") + } + + for _, name := range []string{"list", "stop", "stop-all"} { + if cmd, _, err := tunnelCmd.Find([]string{name}); err != nil || cmd == nil || cmd.Name() != name { + t.Fatalf("expected tunnel subcommand %q, got cmd=%v err=%v", name, cmd, err) + } + } +} + +func TestValidateBackgroundTunnelRejectsSecretAuth(t *testing.T) { + server := &model.Server{Alias: "db", AuthMethod: model.AuthPassword} + forwards := []*model.Forward{{Enabled: true, Type: model.ForwardLocal, LocalPort: 15432}} + + err := validateBackgroundTunnel(server, forwards) + if err == nil { + t.Fatal("expected background tunnel to reject password auth") + } + if !strings.Contains(err.Error(), "key or agent auth") { + t.Fatalf("expected auth guidance, got %v", err) + } +} + +func TestValidateBackgroundTunnelRequiresEnabledForward(t *testing.T) { + server := &model.Server{Alias: "db", AuthMethod: model.AuthKey} + forwards := []*model.Forward{{Enabled: false, Type: model.ForwardLocal, LocalPort: 15432}} + + err := validateBackgroundTunnel(server, forwards) + if err == nil { + t.Fatal("expected background tunnel to require enabled forwards") + } + if !strings.Contains(err.Error(), "no enabled forwards") { + t.Fatalf("expected enabled forward error, got %v", err) + } +} diff --git a/internal/tunnel/manager.go b/internal/tunnel/manager.go index e2f6929..227f624 100644 --- a/internal/tunnel/manager.go +++ b/internal/tunnel/manager.go @@ -2,11 +2,13 @@ package tunnel import ( "encoding/json" + "errors" "fmt" "os" "os/exec" "path/filepath" "sync" + "syscall" "time" "github.com/mirivlad/sshkeeper/internal/config" @@ -22,7 +24,11 @@ var ( // Init initializes the tunnel state manager with the data directory. func Init(dir string) error { + mu.Lock() + defer mu.Unlock() + dataDir = dir + states = map[int64]*model.TunnelState{} return loadStates() } @@ -100,6 +106,7 @@ func Start(cfg *config.Config, server *model.Server, forwards []*model.Forward, copy(args, sshArgs) cmd := exec.Command(cfg.SSH.Binary, args...) + cmd.Env = os.Environ() cmd.Stdin = nil cmd.Stdout = nil cmd.Stderr = nil @@ -126,8 +133,13 @@ func Start(cfg *config.Config, server *model.Server, forwards []*model.Forward, states[id] = state if err := saveStates(); err != nil { - // Non-fatal: log but don't fail - _ = err + delete(states, id) + _ = cmd.Process.Kill() + return nil, fmt.Errorf("save tunnel state: %w", err) + } + if err := cmd.Process.Release(); err != nil { + delete(states, id) + return nil, fmt.Errorf("release tunnel process: %w", err) } return state, nil @@ -188,5 +200,6 @@ func IsRunning(id int64) bool { return false } // Signal 0 just checks if process exists - return proc.Signal(os.Signal(nil)) == nil + err = proc.Signal(syscall.Signal(0)) + return err == nil || errors.Is(err, syscall.EPERM) } diff --git a/internal/tunnel/manager_test.go b/internal/tunnel/manager_test.go new file mode 100644 index 0000000..9a4bab4 --- /dev/null +++ b/internal/tunnel/manager_test.go @@ -0,0 +1,45 @@ +package tunnel + +import ( + "os/exec" + "testing" + "time" + + "github.com/mirivlad/sshkeeper/internal/model" +) + +func TestInitClearsPreviousInMemoryStates(t *testing.T) { + if err := Init(t.TempDir()); err != nil { + t.Fatalf("init first dir: %v", err) + } + states[123] = &model.TunnelState{ID: 123, ServerAlias: "old", PID: 1} + + if err := Init(t.TempDir()); err != nil { + t.Fatalf("init second dir: %v", err) + } + + if got := Get(123); got != nil { + t.Fatalf("expected init to clear stale state, got %#v", got) + } +} + +func TestIsRunningDetectsLiveProcess(t *testing.T) { + if err := Init(t.TempDir()); err != nil { + t.Fatalf("init: %v", err) + } + + cmd := exec.Command("sleep", "2") + if err := cmd.Start(); err != nil { + t.Fatalf("start sleep: %v", err) + } + t.Cleanup(func() { + _ = cmd.Process.Kill() + _, _ = cmd.Process.Wait() + }) + + states[1] = &model.TunnelState{ID: 1, ServerAlias: "live", PID: cmd.Process.Pid, StartedAt: time.Now()} + + if !IsRunning(1) { + t.Fatalf("expected pid %d to be detected as running", cmd.Process.Pid) + } +}