feat: add tunnel background cli management
This commit is contained in:
parent
6991eab3c0
commit
7e0a00ff43
19
cmd/root.go
19
cmd/root.go
|
|
@ -183,8 +183,27 @@ func commandRequiresStartupVaultUnlock(args []string) bool {
|
||||||
|
|
||||||
switch args[0] {
|
switch args[0] {
|
||||||
case "connect", "c", "run", "run-template", "test", "edit", "delete", "tunnel":
|
case "connect", "c", "run", "run-template", "test", "edit", "delete", "tunnel":
|
||||||
|
if args[0] == "tunnel" && tunnelCommandSkipsStartupVaultUnlock(args[1:]) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
return true
|
return true
|
||||||
default:
|
default:
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func tunnelCommandSkipsStartupVaultUnlock(args []string) bool {
|
||||||
|
if len(args) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch args[0] {
|
||||||
|
case "list", "stop", "stop-all":
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for _, arg := range args[1:] {
|
||||||
|
if arg == "--background" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,10 @@ func TestCommandRequiresStartupVaultUnlock(t *testing.T) {
|
||||||
{name: "list only reads database", args: []string{"list"}, want: false},
|
{name: "list only reads database", args: []string{"list"}, want: false},
|
||||||
{name: "show only reads database", args: []string{"show", "prod"}, want: false},
|
{name: "show only reads database", args: []string{"show", "prod"}, want: false},
|
||||||
{name: "search only reads database", args: []string{"search", "prod"}, want: false},
|
{name: "search only reads database", args: []string{"search", "prod"}, want: false},
|
||||||
|
{name: "tunnel list only reads state", args: []string{"tunnel", "list"}, want: false},
|
||||||
|
{name: "tunnel stop only edits state", args: []string{"tunnel", "stop", "1"}, want: false},
|
||||||
|
{name: "tunnel stop all only edits state", args: []string{"tunnel", "stop-all"}, want: false},
|
||||||
|
{name: "background tunnel does not need startup vault", args: []string{"tunnel", "prod", "--background"}, want: false},
|
||||||
{name: "config path only reads config", args: []string{"config", "path"}, want: false},
|
{name: "config path only reads config", args: []string{"config", "path"}, want: false},
|
||||||
{name: "help", args: []string{"--help"}, want: false},
|
{name: "help", args: []string{"--help"}, want: false},
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -234,6 +234,11 @@ func runTUI() error {
|
||||||
|
|
||||||
if background {
|
if background {
|
||||||
// Start detached tunnel process
|
// Start detached tunnel process
|
||||||
|
if err := validateBackgroundTunnel(fresh, forwards); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Start tunnel: %v\n", err)
|
||||||
|
servers, _ = appDB.ListServers()
|
||||||
|
continue
|
||||||
|
}
|
||||||
state, err := tunnelpkg.Start(cfg, fresh, forwards, forwardOnly)
|
state, err := tunnelpkg.Start(cfg, fresh, forwards, forwardOnly)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "Start tunnel: %v\n", err)
|
fmt.Fprintf(os.Stderr, "Start tunnel: %v\n", err)
|
||||||
|
|
|
||||||
104
cmd/tunnel.go
104
cmd/tunnel.go
|
|
@ -2,8 +2,11 @@ package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/mirivlad/sshkeeper/internal/model"
|
||||||
"github.com/mirivlad/sshkeeper/internal/ssh"
|
"github.com/mirivlad/sshkeeper/internal/ssh"
|
||||||
|
tunnelpkg "github.com/mirivlad/sshkeeper/internal/tunnel"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -19,6 +22,29 @@ var tunnelCmd = &cobra.Command{
|
||||||
}
|
}
|
||||||
|
|
||||||
forwardsOnly, _ := cmd.Flags().GetBool("forward-only")
|
forwardsOnly, _ := cmd.Flags().GetBool("forward-only")
|
||||||
|
background, _ := cmd.Flags().GetBool("background")
|
||||||
|
|
||||||
|
// Load forwards
|
||||||
|
forwards, err := appDB.GetForwards(server.ID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("load forwards: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if background {
|
||||||
|
if err := validateBackgroundTunnel(server, forwards); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
state, err := tunnelpkg.Start(cfg, server, forwards, true)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fmt.Printf("✓ Tunnel started [%d] PID %d → %s\n", state.ID, state.PID, server.Alias)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(forwards) == 0 && forwardsOnly {
|
||||||
|
return fmt.Errorf("no forwards configured for %s", alias)
|
||||||
|
}
|
||||||
|
|
||||||
v := getOrCreateVault()
|
v := getOrCreateVault()
|
||||||
vaultFunc := func(serverAlias string, secretType string) (string, error) {
|
vaultFunc := func(serverAlias string, secretType string) (string, error) {
|
||||||
|
|
@ -33,16 +59,6 @@ var tunnelCmd = &cobra.Command{
|
||||||
return string(data), nil
|
return string(data), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load forwards
|
|
||||||
forwards, err := appDB.GetForwards(server.ID)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("load forwards: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(forwards) == 0 && forwardsOnly {
|
|
||||||
return fmt.Errorf("no forwards configured for %s", alias)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(forwards) > 0 {
|
if len(forwards) > 0 {
|
||||||
fmt.Printf("Starting tunnel to %s with %d forward(s)...\n", alias, len(forwards))
|
fmt.Printf("Starting tunnel to %s with %d forward(s)...\n", alias, len(forwards))
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -58,6 +74,74 @@ var tunnelCmd = &cobra.Command{
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var tunnelListCmd = &cobra.Command{
|
||||||
|
Use: "list",
|
||||||
|
Short: "List tracked background tunnels",
|
||||||
|
Args: cobra.NoArgs,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
states := tunnelpkg.List()
|
||||||
|
if len(states) == 0 {
|
||||||
|
fmt.Println("No tracked tunnels.")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
fmt.Printf("%-22s %-8s %-10s %s\n", "ID", "PID", "STATUS", "SERVER")
|
||||||
|
for _, state := range states {
|
||||||
|
status := "stopped"
|
||||||
|
if tunnelpkg.IsRunning(state.ID) {
|
||||||
|
status = "running"
|
||||||
|
}
|
||||||
|
fmt.Printf("%-22d %-8d %-10s %s\n", state.ID, state.PID, status, state.ServerAlias)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var tunnelStopCmd = &cobra.Command{
|
||||||
|
Use: "stop <id>",
|
||||||
|
Short: "Stop a tracked background tunnel",
|
||||||
|
Args: cobra.ExactArgs(1),
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
id, err := strconv.ParseInt(args[0], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid tunnel ID: %s", args[0])
|
||||||
|
}
|
||||||
|
if err := tunnelpkg.Stop(id); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fmt.Printf("✓ Tunnel %d stopped\n", id)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var tunnelStopAllCmd = &cobra.Command{
|
||||||
|
Use: "stop-all",
|
||||||
|
Short: "Stop all tracked background tunnels",
|
||||||
|
Args: cobra.NoArgs,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
if err := tunnelpkg.StopAll(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fmt.Println("✓ All tracked tunnels stopped")
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateBackgroundTunnel(server *model.Server, forwards []*model.Forward) error {
|
||||||
|
if server.AuthMethod == model.AuthPassword || server.AuthMethod == model.AuthKeyPassphrase {
|
||||||
|
return fmt.Errorf("background tunnels support only key or agent auth; use foreground tunnel for %s auth", server.AuthMethod)
|
||||||
|
}
|
||||||
|
for _, f := range forwards {
|
||||||
|
if f.Enabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fmt.Errorf("no enabled forwards configured for %s", server.Alias)
|
||||||
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
tunnelCmd.Flags().Bool("forward-only", false, "Start tunnel only (ssh -N)")
|
tunnelCmd.Flags().Bool("forward-only", false, "Start tunnel only (ssh -N)")
|
||||||
|
tunnelCmd.Flags().Bool("background", false, "Start tunnel in background (ssh -N)")
|
||||||
|
tunnelCmd.AddCommand(tunnelListCmd)
|
||||||
|
tunnelCmd.AddCommand(tunnelStopCmd)
|
||||||
|
tunnelCmd.AddCommand(tunnelStopAllCmd)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,46 @@
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/mirivlad/sshkeeper/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTunnelCommandExposesBackgroundAndManagementCommands(t *testing.T) {
|
||||||
|
if flag := tunnelCmd.Flags().Lookup("background"); flag == nil {
|
||||||
|
t.Fatal("expected tunnel --background flag")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, name := range []string{"list", "stop", "stop-all"} {
|
||||||
|
if cmd, _, err := tunnelCmd.Find([]string{name}); err != nil || cmd == nil || cmd.Name() != name {
|
||||||
|
t.Fatalf("expected tunnel subcommand %q, got cmd=%v err=%v", name, cmd, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateBackgroundTunnelRejectsSecretAuth(t *testing.T) {
|
||||||
|
server := &model.Server{Alias: "db", AuthMethod: model.AuthPassword}
|
||||||
|
forwards := []*model.Forward{{Enabled: true, Type: model.ForwardLocal, LocalPort: 15432}}
|
||||||
|
|
||||||
|
err := validateBackgroundTunnel(server, forwards)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected background tunnel to reject password auth")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "key or agent auth") {
|
||||||
|
t.Fatalf("expected auth guidance, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateBackgroundTunnelRequiresEnabledForward(t *testing.T) {
|
||||||
|
server := &model.Server{Alias: "db", AuthMethod: model.AuthKey}
|
||||||
|
forwards := []*model.Forward{{Enabled: false, Type: model.ForwardLocal, LocalPort: 15432}}
|
||||||
|
|
||||||
|
err := validateBackgroundTunnel(server, forwards)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected background tunnel to require enabled forwards")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "no enabled forwards") {
|
||||||
|
t.Fatalf("expected enabled forward error, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -2,11 +2,13 @@ package tunnel
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"sync"
|
"sync"
|
||||||
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/mirivlad/sshkeeper/internal/config"
|
"github.com/mirivlad/sshkeeper/internal/config"
|
||||||
|
|
@ -22,7 +24,11 @@ var (
|
||||||
|
|
||||||
// Init initializes the tunnel state manager with the data directory.
|
// Init initializes the tunnel state manager with the data directory.
|
||||||
func Init(dir string) error {
|
func Init(dir string) error {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
|
||||||
dataDir = dir
|
dataDir = dir
|
||||||
|
states = map[int64]*model.TunnelState{}
|
||||||
return loadStates()
|
return loadStates()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -100,6 +106,7 @@ func Start(cfg *config.Config, server *model.Server, forwards []*model.Forward,
|
||||||
copy(args, sshArgs)
|
copy(args, sshArgs)
|
||||||
|
|
||||||
cmd := exec.Command(cfg.SSH.Binary, args...)
|
cmd := exec.Command(cfg.SSH.Binary, args...)
|
||||||
|
cmd.Env = os.Environ()
|
||||||
cmd.Stdin = nil
|
cmd.Stdin = nil
|
||||||
cmd.Stdout = nil
|
cmd.Stdout = nil
|
||||||
cmd.Stderr = nil
|
cmd.Stderr = nil
|
||||||
|
|
@ -126,8 +133,13 @@ func Start(cfg *config.Config, server *model.Server, forwards []*model.Forward,
|
||||||
|
|
||||||
states[id] = state
|
states[id] = state
|
||||||
if err := saveStates(); err != nil {
|
if err := saveStates(); err != nil {
|
||||||
// Non-fatal: log but don't fail
|
delete(states, id)
|
||||||
_ = err
|
_ = cmd.Process.Kill()
|
||||||
|
return nil, fmt.Errorf("save tunnel state: %w", err)
|
||||||
|
}
|
||||||
|
if err := cmd.Process.Release(); err != nil {
|
||||||
|
delete(states, id)
|
||||||
|
return nil, fmt.Errorf("release tunnel process: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return state, nil
|
return state, nil
|
||||||
|
|
@ -188,5 +200,6 @@ func IsRunning(id int64) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
// Signal 0 just checks if process exists
|
// Signal 0 just checks if process exists
|
||||||
return proc.Signal(os.Signal(nil)) == nil
|
err = proc.Signal(syscall.Signal(0))
|
||||||
|
return err == nil || errors.Is(err, syscall.EPERM)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,45 @@
|
||||||
|
package tunnel
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os/exec"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mirivlad/sshkeeper/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInitClearsPreviousInMemoryStates(t *testing.T) {
|
||||||
|
if err := Init(t.TempDir()); err != nil {
|
||||||
|
t.Fatalf("init first dir: %v", err)
|
||||||
|
}
|
||||||
|
states[123] = &model.TunnelState{ID: 123, ServerAlias: "old", PID: 1}
|
||||||
|
|
||||||
|
if err := Init(t.TempDir()); err != nil {
|
||||||
|
t.Fatalf("init second dir: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := Get(123); got != nil {
|
||||||
|
t.Fatalf("expected init to clear stale state, got %#v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsRunningDetectsLiveProcess(t *testing.T) {
|
||||||
|
if err := Init(t.TempDir()); err != nil {
|
||||||
|
t.Fatalf("init: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command("sleep", "2")
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
t.Fatalf("start sleep: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = cmd.Process.Kill()
|
||||||
|
_, _ = cmd.Process.Wait()
|
||||||
|
})
|
||||||
|
|
||||||
|
states[1] = &model.TunnelState{ID: 1, ServerAlias: "live", PID: cmd.Process.Pid, StartedAt: time.Now()}
|
||||||
|
|
||||||
|
if !IsRunning(1) {
|
||||||
|
t.Fatalf("expected pid %d to be detected as running", cmd.Process.Pid)
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue