feat: add tunnel background cli management
This commit is contained in:
parent
6991eab3c0
commit
7e0a00ff43
19
cmd/root.go
19
cmd/root.go
|
|
@ -183,8 +183,27 @@ func commandRequiresStartupVaultUnlock(args []string) bool {
|
|||
|
||||
switch args[0] {
|
||||
case "connect", "c", "run", "run-template", "test", "edit", "delete", "tunnel":
|
||||
if args[0] == "tunnel" && tunnelCommandSkipsStartupVaultUnlock(args[1:]) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func tunnelCommandSkipsStartupVaultUnlock(args []string) bool {
|
||||
if len(args) == 0 {
|
||||
return false
|
||||
}
|
||||
switch args[0] {
|
||||
case "list", "stop", "stop-all":
|
||||
return true
|
||||
}
|
||||
for _, arg := range args[1:] {
|
||||
if arg == "--background" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,6 +18,10 @@ func TestCommandRequiresStartupVaultUnlock(t *testing.T) {
|
|||
{name: "list only reads database", args: []string{"list"}, want: false},
|
||||
{name: "show only reads database", args: []string{"show", "prod"}, want: false},
|
||||
{name: "search only reads database", args: []string{"search", "prod"}, want: false},
|
||||
{name: "tunnel list only reads state", args: []string{"tunnel", "list"}, want: false},
|
||||
{name: "tunnel stop only edits state", args: []string{"tunnel", "stop", "1"}, want: false},
|
||||
{name: "tunnel stop all only edits state", args: []string{"tunnel", "stop-all"}, want: false},
|
||||
{name: "background tunnel does not need startup vault", args: []string{"tunnel", "prod", "--background"}, want: false},
|
||||
{name: "config path only reads config", args: []string{"config", "path"}, want: false},
|
||||
{name: "help", args: []string{"--help"}, want: false},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -234,6 +234,11 @@ func runTUI() error {
|
|||
|
||||
if background {
|
||||
// Start detached tunnel process
|
||||
if err := validateBackgroundTunnel(fresh, forwards); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Start tunnel: %v\n", err)
|
||||
servers, _ = appDB.ListServers()
|
||||
continue
|
||||
}
|
||||
state, err := tunnelpkg.Start(cfg, fresh, forwards, forwardOnly)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Start tunnel: %v\n", err)
|
||||
|
|
|
|||
104
cmd/tunnel.go
104
cmd/tunnel.go
|
|
@ -2,8 +2,11 @@ package cmd
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/mirivlad/sshkeeper/internal/model"
|
||||
"github.com/mirivlad/sshkeeper/internal/ssh"
|
||||
tunnelpkg "github.com/mirivlad/sshkeeper/internal/tunnel"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
|
|
@ -19,6 +22,29 @@ var tunnelCmd = &cobra.Command{
|
|||
}
|
||||
|
||||
forwardsOnly, _ := cmd.Flags().GetBool("forward-only")
|
||||
background, _ := cmd.Flags().GetBool("background")
|
||||
|
||||
// Load forwards
|
||||
forwards, err := appDB.GetForwards(server.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load forwards: %w", err)
|
||||
}
|
||||
|
||||
if background {
|
||||
if err := validateBackgroundTunnel(server, forwards); err != nil {
|
||||
return err
|
||||
}
|
||||
state, err := tunnelpkg.Start(cfg, server, forwards, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("✓ Tunnel started [%d] PID %d → %s\n", state.ID, state.PID, server.Alias)
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(forwards) == 0 && forwardsOnly {
|
||||
return fmt.Errorf("no forwards configured for %s", alias)
|
||||
}
|
||||
|
||||
v := getOrCreateVault()
|
||||
vaultFunc := func(serverAlias string, secretType string) (string, error) {
|
||||
|
|
@ -33,16 +59,6 @@ var tunnelCmd = &cobra.Command{
|
|||
return string(data), nil
|
||||
}
|
||||
|
||||
// Load forwards
|
||||
forwards, err := appDB.GetForwards(server.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load forwards: %w", err)
|
||||
}
|
||||
|
||||
if len(forwards) == 0 && forwardsOnly {
|
||||
return fmt.Errorf("no forwards configured for %s", alias)
|
||||
}
|
||||
|
||||
if len(forwards) > 0 {
|
||||
fmt.Printf("Starting tunnel to %s with %d forward(s)...\n", alias, len(forwards))
|
||||
} else {
|
||||
|
|
@ -58,6 +74,74 @@ var tunnelCmd = &cobra.Command{
|
|||
},
|
||||
}
|
||||
|
||||
var tunnelListCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
Short: "List tracked background tunnels",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
states := tunnelpkg.List()
|
||||
if len(states) == 0 {
|
||||
fmt.Println("No tracked tunnels.")
|
||||
return nil
|
||||
}
|
||||
fmt.Printf("%-22s %-8s %-10s %s\n", "ID", "PID", "STATUS", "SERVER")
|
||||
for _, state := range states {
|
||||
status := "stopped"
|
||||
if tunnelpkg.IsRunning(state.ID) {
|
||||
status = "running"
|
||||
}
|
||||
fmt.Printf("%-22d %-8d %-10s %s\n", state.ID, state.PID, status, state.ServerAlias)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
var tunnelStopCmd = &cobra.Command{
|
||||
Use: "stop <id>",
|
||||
Short: "Stop a tracked background tunnel",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
id, err := strconv.ParseInt(args[0], 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid tunnel ID: %s", args[0])
|
||||
}
|
||||
if err := tunnelpkg.Stop(id); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("✓ Tunnel %d stopped\n", id)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
var tunnelStopAllCmd = &cobra.Command{
|
||||
Use: "stop-all",
|
||||
Short: "Stop all tracked background tunnels",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if err := tunnelpkg.StopAll(); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("✓ All tracked tunnels stopped")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
func validateBackgroundTunnel(server *model.Server, forwards []*model.Forward) error {
|
||||
if server.AuthMethod == model.AuthPassword || server.AuthMethod == model.AuthKeyPassphrase {
|
||||
return fmt.Errorf("background tunnels support only key or agent auth; use foreground tunnel for %s auth", server.AuthMethod)
|
||||
}
|
||||
for _, f := range forwards {
|
||||
if f.Enabled {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("no enabled forwards configured for %s", server.Alias)
|
||||
}
|
||||
|
||||
func init() {
|
||||
tunnelCmd.Flags().Bool("forward-only", false, "Start tunnel only (ssh -N)")
|
||||
tunnelCmd.Flags().Bool("background", false, "Start tunnel in background (ssh -N)")
|
||||
tunnelCmd.AddCommand(tunnelListCmd)
|
||||
tunnelCmd.AddCommand(tunnelStopCmd)
|
||||
tunnelCmd.AddCommand(tunnelStopAllCmd)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,46 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mirivlad/sshkeeper/internal/model"
|
||||
)
|
||||
|
||||
func TestTunnelCommandExposesBackgroundAndManagementCommands(t *testing.T) {
|
||||
if flag := tunnelCmd.Flags().Lookup("background"); flag == nil {
|
||||
t.Fatal("expected tunnel --background flag")
|
||||
}
|
||||
|
||||
for _, name := range []string{"list", "stop", "stop-all"} {
|
||||
if cmd, _, err := tunnelCmd.Find([]string{name}); err != nil || cmd == nil || cmd.Name() != name {
|
||||
t.Fatalf("expected tunnel subcommand %q, got cmd=%v err=%v", name, cmd, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateBackgroundTunnelRejectsSecretAuth(t *testing.T) {
|
||||
server := &model.Server{Alias: "db", AuthMethod: model.AuthPassword}
|
||||
forwards := []*model.Forward{{Enabled: true, Type: model.ForwardLocal, LocalPort: 15432}}
|
||||
|
||||
err := validateBackgroundTunnel(server, forwards)
|
||||
if err == nil {
|
||||
t.Fatal("expected background tunnel to reject password auth")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "key or agent auth") {
|
||||
t.Fatalf("expected auth guidance, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateBackgroundTunnelRequiresEnabledForward(t *testing.T) {
|
||||
server := &model.Server{Alias: "db", AuthMethod: model.AuthKey}
|
||||
forwards := []*model.Forward{{Enabled: false, Type: model.ForwardLocal, LocalPort: 15432}}
|
||||
|
||||
err := validateBackgroundTunnel(server, forwards)
|
||||
if err == nil {
|
||||
t.Fatal("expected background tunnel to require enabled forwards")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no enabled forwards") {
|
||||
t.Fatalf("expected enabled forward error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -2,11 +2,13 @@ package tunnel
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/mirivlad/sshkeeper/internal/config"
|
||||
|
|
@ -22,7 +24,11 @@ var (
|
|||
|
||||
// Init initializes the tunnel state manager with the data directory.
|
||||
func Init(dir string) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
dataDir = dir
|
||||
states = map[int64]*model.TunnelState{}
|
||||
return loadStates()
|
||||
}
|
||||
|
||||
|
|
@ -100,6 +106,7 @@ func Start(cfg *config.Config, server *model.Server, forwards []*model.Forward,
|
|||
copy(args, sshArgs)
|
||||
|
||||
cmd := exec.Command(cfg.SSH.Binary, args...)
|
||||
cmd.Env = os.Environ()
|
||||
cmd.Stdin = nil
|
||||
cmd.Stdout = nil
|
||||
cmd.Stderr = nil
|
||||
|
|
@ -126,8 +133,13 @@ func Start(cfg *config.Config, server *model.Server, forwards []*model.Forward,
|
|||
|
||||
states[id] = state
|
||||
if err := saveStates(); err != nil {
|
||||
// Non-fatal: log but don't fail
|
||||
_ = err
|
||||
delete(states, id)
|
||||
_ = cmd.Process.Kill()
|
||||
return nil, fmt.Errorf("save tunnel state: %w", err)
|
||||
}
|
||||
if err := cmd.Process.Release(); err != nil {
|
||||
delete(states, id)
|
||||
return nil, fmt.Errorf("release tunnel process: %w", err)
|
||||
}
|
||||
|
||||
return state, nil
|
||||
|
|
@ -188,5 +200,6 @@ func IsRunning(id int64) bool {
|
|||
return false
|
||||
}
|
||||
// Signal 0 just checks if process exists
|
||||
return proc.Signal(os.Signal(nil)) == nil
|
||||
err = proc.Signal(syscall.Signal(0))
|
||||
return err == nil || errors.Is(err, syscall.EPERM)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,45 @@
|
|||
package tunnel
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mirivlad/sshkeeper/internal/model"
|
||||
)
|
||||
|
||||
func TestInitClearsPreviousInMemoryStates(t *testing.T) {
|
||||
if err := Init(t.TempDir()); err != nil {
|
||||
t.Fatalf("init first dir: %v", err)
|
||||
}
|
||||
states[123] = &model.TunnelState{ID: 123, ServerAlias: "old", PID: 1}
|
||||
|
||||
if err := Init(t.TempDir()); err != nil {
|
||||
t.Fatalf("init second dir: %v", err)
|
||||
}
|
||||
|
||||
if got := Get(123); got != nil {
|
||||
t.Fatalf("expected init to clear stale state, got %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRunningDetectsLiveProcess(t *testing.T) {
|
||||
if err := Init(t.TempDir()); err != nil {
|
||||
t.Fatalf("init: %v", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("sleep", "2")
|
||||
if err := cmd.Start(); err != nil {
|
||||
t.Fatalf("start sleep: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = cmd.Process.Kill()
|
||||
_, _ = cmd.Process.Wait()
|
||||
})
|
||||
|
||||
states[1] = &model.TunnelState{ID: 1, ServerAlias: "live", PID: cmd.Process.Pid, StartedAt: time.Now()}
|
||||
|
||||
if !IsRunning(1) {
|
||||
t.Fatalf("expected pid %d to be detected as running", cmd.Process.Pid)
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue