feat: add tunnel background cli management

This commit is contained in:
mirivlad 2026-06-06 03:21:30 +08:00
parent 6991eab3c0
commit 7e0a00ff43
7 changed files with 229 additions and 13 deletions

View File

@ -183,8 +183,27 @@ func commandRequiresStartupVaultUnlock(args []string) bool {
switch args[0] { switch args[0] {
case "connect", "c", "run", "run-template", "test", "edit", "delete", "tunnel": case "connect", "c", "run", "run-template", "test", "edit", "delete", "tunnel":
if args[0] == "tunnel" && tunnelCommandSkipsStartupVaultUnlock(args[1:]) {
return false
}
return true return true
default: default:
return false return false
} }
} }
func tunnelCommandSkipsStartupVaultUnlock(args []string) bool {
if len(args) == 0 {
return false
}
switch args[0] {
case "list", "stop", "stop-all":
return true
}
for _, arg := range args[1:] {
if arg == "--background" {
return true
}
}
return false
}

View File

@ -18,6 +18,10 @@ func TestCommandRequiresStartupVaultUnlock(t *testing.T) {
{name: "list only reads database", args: []string{"list"}, want: false}, {name: "list only reads database", args: []string{"list"}, want: false},
{name: "show only reads database", args: []string{"show", "prod"}, want: false}, {name: "show only reads database", args: []string{"show", "prod"}, want: false},
{name: "search only reads database", args: []string{"search", "prod"}, want: false}, {name: "search only reads database", args: []string{"search", "prod"}, want: false},
{name: "tunnel list only reads state", args: []string{"tunnel", "list"}, want: false},
{name: "tunnel stop only edits state", args: []string{"tunnel", "stop", "1"}, want: false},
{name: "tunnel stop all only edits state", args: []string{"tunnel", "stop-all"}, want: false},
{name: "background tunnel does not need startup vault", args: []string{"tunnel", "prod", "--background"}, want: false},
{name: "config path only reads config", args: []string{"config", "path"}, want: false}, {name: "config path only reads config", args: []string{"config", "path"}, want: false},
{name: "help", args: []string{"--help"}, want: false}, {name: "help", args: []string{"--help"}, want: false},
} }

View File

@ -234,6 +234,11 @@ func runTUI() error {
if background { if background {
// Start detached tunnel process // Start detached tunnel process
if err := validateBackgroundTunnel(fresh, forwards); err != nil {
fmt.Fprintf(os.Stderr, "Start tunnel: %v\n", err)
servers, _ = appDB.ListServers()
continue
}
state, err := tunnelpkg.Start(cfg, fresh, forwards, forwardOnly) state, err := tunnelpkg.Start(cfg, fresh, forwards, forwardOnly)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "Start tunnel: %v\n", err) fmt.Fprintf(os.Stderr, "Start tunnel: %v\n", err)

View File

@ -2,8 +2,11 @@ package cmd
import ( import (
"fmt" "fmt"
"strconv"
"github.com/mirivlad/sshkeeper/internal/model"
"github.com/mirivlad/sshkeeper/internal/ssh" "github.com/mirivlad/sshkeeper/internal/ssh"
tunnelpkg "github.com/mirivlad/sshkeeper/internal/tunnel"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -19,6 +22,29 @@ var tunnelCmd = &cobra.Command{
} }
forwardsOnly, _ := cmd.Flags().GetBool("forward-only") forwardsOnly, _ := cmd.Flags().GetBool("forward-only")
background, _ := cmd.Flags().GetBool("background")
// Load forwards
forwards, err := appDB.GetForwards(server.ID)
if err != nil {
return fmt.Errorf("load forwards: %w", err)
}
if background {
if err := validateBackgroundTunnel(server, forwards); err != nil {
return err
}
state, err := tunnelpkg.Start(cfg, server, forwards, true)
if err != nil {
return err
}
fmt.Printf("✓ Tunnel started [%d] PID %d → %s\n", state.ID, state.PID, server.Alias)
return nil
}
if len(forwards) == 0 && forwardsOnly {
return fmt.Errorf("no forwards configured for %s", alias)
}
v := getOrCreateVault() v := getOrCreateVault()
vaultFunc := func(serverAlias string, secretType string) (string, error) { vaultFunc := func(serverAlias string, secretType string) (string, error) {
@ -33,16 +59,6 @@ var tunnelCmd = &cobra.Command{
return string(data), nil return string(data), nil
} }
// Load forwards
forwards, err := appDB.GetForwards(server.ID)
if err != nil {
return fmt.Errorf("load forwards: %w", err)
}
if len(forwards) == 0 && forwardsOnly {
return fmt.Errorf("no forwards configured for %s", alias)
}
if len(forwards) > 0 { if len(forwards) > 0 {
fmt.Printf("Starting tunnel to %s with %d forward(s)...\n", alias, len(forwards)) fmt.Printf("Starting tunnel to %s with %d forward(s)...\n", alias, len(forwards))
} else { } else {
@ -58,6 +74,74 @@ var tunnelCmd = &cobra.Command{
}, },
} }
var tunnelListCmd = &cobra.Command{
Use: "list",
Short: "List tracked background tunnels",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, args []string) error {
states := tunnelpkg.List()
if len(states) == 0 {
fmt.Println("No tracked tunnels.")
return nil
}
fmt.Printf("%-22s %-8s %-10s %s\n", "ID", "PID", "STATUS", "SERVER")
for _, state := range states {
status := "stopped"
if tunnelpkg.IsRunning(state.ID) {
status = "running"
}
fmt.Printf("%-22d %-8d %-10s %s\n", state.ID, state.PID, status, state.ServerAlias)
}
return nil
},
}
var tunnelStopCmd = &cobra.Command{
Use: "stop <id>",
Short: "Stop a tracked background tunnel",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
id, err := strconv.ParseInt(args[0], 10, 64)
if err != nil {
return fmt.Errorf("invalid tunnel ID: %s", args[0])
}
if err := tunnelpkg.Stop(id); err != nil {
return err
}
fmt.Printf("✓ Tunnel %d stopped\n", id)
return nil
},
}
var tunnelStopAllCmd = &cobra.Command{
Use: "stop-all",
Short: "Stop all tracked background tunnels",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, args []string) error {
if err := tunnelpkg.StopAll(); err != nil {
return err
}
fmt.Println("✓ All tracked tunnels stopped")
return nil
},
}
func validateBackgroundTunnel(server *model.Server, forwards []*model.Forward) error {
if server.AuthMethod == model.AuthPassword || server.AuthMethod == model.AuthKeyPassphrase {
return fmt.Errorf("background tunnels support only key or agent auth; use foreground tunnel for %s auth", server.AuthMethod)
}
for _, f := range forwards {
if f.Enabled {
return nil
}
}
return fmt.Errorf("no enabled forwards configured for %s", server.Alias)
}
func init() { func init() {
tunnelCmd.Flags().Bool("forward-only", false, "Start tunnel only (ssh -N)") tunnelCmd.Flags().Bool("forward-only", false, "Start tunnel only (ssh -N)")
tunnelCmd.Flags().Bool("background", false, "Start tunnel in background (ssh -N)")
tunnelCmd.AddCommand(tunnelListCmd)
tunnelCmd.AddCommand(tunnelStopCmd)
tunnelCmd.AddCommand(tunnelStopAllCmd)
} }

46
cmd/tunnel_test.go Normal file
View File

@ -0,0 +1,46 @@
package cmd
import (
"strings"
"testing"
"github.com/mirivlad/sshkeeper/internal/model"
)
func TestTunnelCommandExposesBackgroundAndManagementCommands(t *testing.T) {
if flag := tunnelCmd.Flags().Lookup("background"); flag == nil {
t.Fatal("expected tunnel --background flag")
}
for _, name := range []string{"list", "stop", "stop-all"} {
if cmd, _, err := tunnelCmd.Find([]string{name}); err != nil || cmd == nil || cmd.Name() != name {
t.Fatalf("expected tunnel subcommand %q, got cmd=%v err=%v", name, cmd, err)
}
}
}
func TestValidateBackgroundTunnelRejectsSecretAuth(t *testing.T) {
server := &model.Server{Alias: "db", AuthMethod: model.AuthPassword}
forwards := []*model.Forward{{Enabled: true, Type: model.ForwardLocal, LocalPort: 15432}}
err := validateBackgroundTunnel(server, forwards)
if err == nil {
t.Fatal("expected background tunnel to reject password auth")
}
if !strings.Contains(err.Error(), "key or agent auth") {
t.Fatalf("expected auth guidance, got %v", err)
}
}
func TestValidateBackgroundTunnelRequiresEnabledForward(t *testing.T) {
server := &model.Server{Alias: "db", AuthMethod: model.AuthKey}
forwards := []*model.Forward{{Enabled: false, Type: model.ForwardLocal, LocalPort: 15432}}
err := validateBackgroundTunnel(server, forwards)
if err == nil {
t.Fatal("expected background tunnel to require enabled forwards")
}
if !strings.Contains(err.Error(), "no enabled forwards") {
t.Fatalf("expected enabled forward error, got %v", err)
}
}

View File

@ -2,11 +2,13 @@ package tunnel
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"sync" "sync"
"syscall"
"time" "time"
"github.com/mirivlad/sshkeeper/internal/config" "github.com/mirivlad/sshkeeper/internal/config"
@ -22,7 +24,11 @@ var (
// Init initializes the tunnel state manager with the data directory. // Init initializes the tunnel state manager with the data directory.
func Init(dir string) error { func Init(dir string) error {
mu.Lock()
defer mu.Unlock()
dataDir = dir dataDir = dir
states = map[int64]*model.TunnelState{}
return loadStates() return loadStates()
} }
@ -100,6 +106,7 @@ func Start(cfg *config.Config, server *model.Server, forwards []*model.Forward,
copy(args, sshArgs) copy(args, sshArgs)
cmd := exec.Command(cfg.SSH.Binary, args...) cmd := exec.Command(cfg.SSH.Binary, args...)
cmd.Env = os.Environ()
cmd.Stdin = nil cmd.Stdin = nil
cmd.Stdout = nil cmd.Stdout = nil
cmd.Stderr = nil cmd.Stderr = nil
@ -126,8 +133,13 @@ func Start(cfg *config.Config, server *model.Server, forwards []*model.Forward,
states[id] = state states[id] = state
if err := saveStates(); err != nil { if err := saveStates(); err != nil {
// Non-fatal: log but don't fail delete(states, id)
_ = err _ = cmd.Process.Kill()
return nil, fmt.Errorf("save tunnel state: %w", err)
}
if err := cmd.Process.Release(); err != nil {
delete(states, id)
return nil, fmt.Errorf("release tunnel process: %w", err)
} }
return state, nil return state, nil
@ -188,5 +200,6 @@ func IsRunning(id int64) bool {
return false return false
} }
// Signal 0 just checks if process exists // Signal 0 just checks if process exists
return proc.Signal(os.Signal(nil)) == nil err = proc.Signal(syscall.Signal(0))
return err == nil || errors.Is(err, syscall.EPERM)
} }

View File

@ -0,0 +1,45 @@
package tunnel
import (
"os/exec"
"testing"
"time"
"github.com/mirivlad/sshkeeper/internal/model"
)
func TestInitClearsPreviousInMemoryStates(t *testing.T) {
if err := Init(t.TempDir()); err != nil {
t.Fatalf("init first dir: %v", err)
}
states[123] = &model.TunnelState{ID: 123, ServerAlias: "old", PID: 1}
if err := Init(t.TempDir()); err != nil {
t.Fatalf("init second dir: %v", err)
}
if got := Get(123); got != nil {
t.Fatalf("expected init to clear stale state, got %#v", got)
}
}
func TestIsRunningDetectsLiveProcess(t *testing.T) {
if err := Init(t.TempDir()); err != nil {
t.Fatalf("init: %v", err)
}
cmd := exec.Command("sleep", "2")
if err := cmd.Start(); err != nil {
t.Fatalf("start sleep: %v", err)
}
t.Cleanup(func() {
_ = cmd.Process.Kill()
_, _ = cmd.Process.Wait()
})
states[1] = &model.TunnelState{ID: 1, ServerAlias: "live", PID: cmd.Process.Pid, StartedAt: time.Now()}
if !IsRunning(1) {
t.Fatalf("expected pid %d to be detected as running", cmd.Process.Pid)
}
}