From c2edaa42244160621dfa380f23fc80ebf3250e92 Mon Sep 17 00:00:00 2001 From: mirivlad Date: Wed, 3 Jun 2026 10:32:18 +0800 Subject: [PATCH] =?UTF-8?q?sshkeeper:=20v0.2.0=20=E2=80=94=20Phase=204:=20?= =?UTF-8?q?CLI=20route/forward/tunnel=20commands?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/forward.go | 118 ++++++++++++++++++++++++++++++++++++ cmd/root.go | 5 +- cmd/route.go | 128 ++++++++++++++++++++++++++++++++++++++++ cmd/tunnel.go | 63 ++++++++++++++++++++ internal/ssh/command.go | 28 ++++++++- 5 files changed, 340 insertions(+), 2 deletions(-) create mode 100644 cmd/forward.go create mode 100644 cmd/route.go create mode 100644 cmd/tunnel.go diff --git a/cmd/forward.go b/cmd/forward.go new file mode 100644 index 0000000..92f687b --- /dev/null +++ b/cmd/forward.go @@ -0,0 +1,118 @@ +package cmd + +import ( + "fmt" + "strconv" + + "github.com/mirivlad/sshkeeper/internal/model" + "github.com/spf13/cobra" +) + +// --- Forward commands --- + +var forwardCmd = &cobra.Command{ + Use: "forward", + Short: "Manage port forwards", +} + +var forwardListCmd = &cobra.Command{ + Use: "list ", + Short: "List port forwards for a server", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + alias := args[0] + server, err := appDB.GetServer(alias) + if err != nil { + return fmt.Errorf("server not found: %s", alias) + } + forwards, err := appDB.GetForwards(server.ID) + if err != nil { + return fmt.Errorf("list forwards: %w", err) + } + if len(forwards) == 0 { + fmt.Println("No port forwards configured.") + return nil + } + fmt.Printf("Port forwards for %s:\n", alias) + for _, f := range forwards { + switch f.Type { + case model.ForwardLocal: + fmt.Printf(" [%d] -L %s:%d:%s:%d\n", f.ID, f.LocalAddr, f.LocalPort, f.RemoteAddr, f.RemotePort) + case model.ForwardRemote: + fmt.Printf(" [%d] -R %s:%d:%s:%d\n", f.ID, f.RemoteAddr, f.RemotePort, f.LocalAddr, f.LocalPort) + case model.ForwardDynamic: + fmt.Printf(" [%d] -D %s:%d\n", f.ID, f.LocalAddr, f.LocalPort) + } + } + return nil + }, +} + +var forwardAddCmd = &cobra.Command{ + Use: "add ", + Short: "Add a port forward", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + alias := args[0] + server, err := appDB.GetServer(alias) + if err != nil { + return fmt.Errorf("server not found: %s", alias) + } + + fwdType, _ := cmd.Flags().GetString("type") + localAddr, _ := cmd.Flags().GetString("local-addr") + localPort, _ := cmd.Flags().GetInt("local-port") + remoteAddr, _ := cmd.Flags().GetString("remote-addr") + remotePort, _ := cmd.Flags().GetInt("remote-port") + + fwd := &model.Forward{ + ServerID: server.ID, + Type: model.ForwardType(fwdType), + LocalAddr: localAddr, + LocalPort: localPort, + RemoteAddr: remoteAddr, + RemotePort: remotePort, + } + + if err := appDB.AddForward(fwd.ServerID, fwd.Type, fwd.LocalAddr, fwd.LocalPort, fwd.RemoteAddr, fwd.RemotePort); err != nil { + return fmt.Errorf("add forward: %w", err) + } + fmt.Printf("✓ Forward added [%d]\n", fwd.ID) + return nil + }, +} + +var forwardDeleteCmd = &cobra.Command{ + Use: "delete ", + Short: "Delete a port forward", + Args: cobra.ExactArgs(2), + RunE: func(cmd *cobra.Command, args []string) error { + alias := args[0] + id, err := strconv.ParseInt(args[1], 10, 64) + if err != nil { + return fmt.Errorf("invalid forward ID: %s", args[1]) + } + // Verify server exists + if _, err := appDB.GetServer(alias); err != nil { + return fmt.Errorf("server not found: %s", alias) + } + if err := appDB.DeleteForward(id); err != nil { + return fmt.Errorf("delete forward: %w", err) + } + fmt.Println("✓ Forward deleted") + return nil + }, +} + +func init() { + forwardAddCmd.Flags().String("type", "local", "Forward type: local, remote, dynamic") + forwardAddCmd.Flags().String("local-addr", "0.0.0.0", "Listen address") + forwardAddCmd.Flags().Int("local-port", 0, "Listen port (required)") + forwardAddCmd.MarkFlagRequired("local-port") + forwardAddCmd.Flags().String("remote-addr", "", "Target address") + forwardAddCmd.Flags().Int("remote-port", 0, "Target port") + + forwardCmd.AddCommand(forwardListCmd) + forwardCmd.AddCommand(forwardAddCmd) + forwardCmd.AddCommand(forwardDeleteCmd) +} diff --git a/cmd/root.go b/cmd/root.go index 4762031..28a6466 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -55,6 +55,9 @@ func init() { rootCmd.AddCommand(groupCmd) rootCmd.AddCommand(templateCmd) rootCmd.AddCommand(runTemplateCmd) + rootCmd.AddCommand(routeCmd) + rootCmd.AddCommand(forwardCmd) + rootCmd.AddCommand(tunnelCmd) } func initApp() { @@ -172,7 +175,7 @@ func commandRequiresStartupVaultUnlock(args []string) bool { } switch args[0] { - case "connect", "c", "run", "run-template", "test", "edit", "delete": + case "connect", "c", "run", "run-template", "test", "edit", "delete", "tunnel": return true default: return false diff --git a/cmd/route.go b/cmd/route.go new file mode 100644 index 0000000..fd892af --- /dev/null +++ b/cmd/route.go @@ -0,0 +1,128 @@ +package cmd + +import ( + "fmt" + "strings" + + "github.com/mirivlad/sshkeeper/internal/model" + "github.com/spf13/cobra" +) + +// --- Route commands --- + +var routeCmd = &cobra.Command{ + Use: "route", + Short: "Manage server routes (ProxyJump)", +} + +var routeShowCmd = &cobra.Command{ + Use: "show ", + Short: "Show route for a server", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + alias := args[0] + server, err := appDB.GetServer(alias) + if err != nil { + return fmt.Errorf("server not found: %s", alias) + } + target := fmt.Sprintf("%s@%s:%d", server.User, server.Host, server.Port) + if len(server.Route.Hops) > 0 { + fmt.Printf("Route: %s\n", server.Route.DisplaySummary(target)) + fmt.Printf("Mode: %s\n", server.Route.RouteMode()) + fmt.Printf("ProxyJump: %s\n", server.Route.ProxyJumpString()) + if server.Route.HasProfileLinks() { + fmt.Println("Hops:") + for _, h := range server.Route.Hops { + if h.IsProfile { + fmt.Printf(" - %s (profile)\n", h.Alias) + } else { + fmt.Printf(" - %s (raw)\n", h.Raw) + } + } + } + } else if server.ProxyJump != "" { + fmt.Printf("ProxyJump: %s\n", server.ProxyJump) + } else { + fmt.Println("Direct connection (no route)") + } + return nil + }, +} + +var routeSetCmd = &cobra.Command{ + Use: "set ", + Short: "Set route for a server", + Args: cobra.MinimumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + alias := args[0] + server, err := appDB.GetServer(alias) + if err != nil { + return fmt.Errorf("server not found: %s", alias) + } + + mode, _ := cmd.Flags().GetString("mode") + jumps, _ := cmd.Flags().GetString("jumps") + + if mode == "clear" || jumps == "" { + server.Route = model.Route{} + server.ProxyJump = "" + } else { + parts := strings.Split(jumps, ",") + hops := make([]model.RouteHop, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p == "" { + continue + } + if strings.Contains(p, "@") || strings.Contains(p, ":") { + hops = append(hops, model.RouteHop{Raw: p, IsProfile: false}) + } else { + hops = append(hops, model.RouteHop{Alias: p, IsProfile: true}) + } + } + server.Route = model.Route{Hops: hops} + server.ProxyJump = server.Route.ProxyJumpString() + } + + if err := appDB.UpdateServer(server); err != nil { + return fmt.Errorf("update route: %w", err) + } + + target := fmt.Sprintf("%s@%s:%d", server.User, server.Host, server.Port) + if len(server.Route.Hops) > 0 { + fmt.Printf("✓ Route set: %s\n", server.Route.DisplaySummary(target)) + } else { + fmt.Println("✓ Route cleared (direct connection)") + } + return nil + }, +} + +var routeClearCmd = &cobra.Command{ + Use: "clear ", + Short: "Clear route for a server (set direct)", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + alias := args[0] + server, err := appDB.GetServer(alias) + if err != nil { + return fmt.Errorf("server not found: %s", alias) + } + server.Route = model.Route{} + server.ProxyJump = "" + if err := appDB.UpdateServer(server); err != nil { + return fmt.Errorf("clear route: %w", err) + } + fmt.Println("✓ Route cleared (direct connection)") + return nil + }, +} + +func init() { + routeSetCmd.Flags().String("mode", "via", "Route mode: direct, via, chain, or clear") + routeSetCmd.Flags().String("jumps", "", "Comma-separated jump hosts (aliases or raw addresses)") + + routeCmd.AddCommand(routeShowCmd) + routeCmd.AddCommand(routeSetCmd) + routeCmd.AddCommand(routeClearCmd) +} diff --git a/cmd/tunnel.go b/cmd/tunnel.go new file mode 100644 index 0000000..99a6d9a --- /dev/null +++ b/cmd/tunnel.go @@ -0,0 +1,63 @@ +package cmd + +import ( + "fmt" + + "github.com/mirivlad/sshkeeper/internal/ssh" + "github.com/spf13/cobra" +) + +var tunnelCmd = &cobra.Command{ + Use: "tunnel ", + Short: "Start SSH session with port forwards", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + alias := args[0] + server, err := appDB.GetServer(alias) + if err != nil { + return fmt.Errorf("server not found: %s", alias) + } + + forwardsOnly, _ := cmd.Flags().GetBool("forward-only") + + v := getOrCreateVault() + vaultFunc := func(serverAlias string, secretType string) (string, error) { + if !v.IsUnlocked() { + return "", fmt.Errorf("%s", vaultLockedProcessMessage()) + } + key := fmt.Sprintf("server:%s:%s", serverAlias, secretType) + data, err := v.Get(key) + if err != nil { + return "", err + } + return string(data), nil + } + + // Load forwards + forwards, err := appDB.GetForwards(server.ID) + if err != nil { + return fmt.Errorf("load forwards: %w", err) + } + + if len(forwards) == 0 && forwardsOnly { + return fmt.Errorf("no forwards configured for %s", alias) + } + + if len(forwards) > 0 { + fmt.Printf("Starting tunnel to %s with %d forward(s)...\n", alias, len(forwards)) + } else { + fmt.Printf("Starting session to %s...\n", alias) + } + + sshArgs := ssh.BuildSSHArgs(server, forwards, forwardsOnly) + if forwardsOnly { + fmt.Printf("Tunnel mode (ssh -N). Press Ctrl+C to exit.\n") + } + + return ssh.ConnectWithArgs(cfg, sshArgs, vaultFunc, server) + }, +} + +func init() { + tunnelCmd.Flags().Bool("forward-only", false, "Start tunnel only (ssh -N)") +} diff --git a/internal/ssh/command.go b/internal/ssh/command.go index 52574ab..896f6bb 100644 --- a/internal/ssh/command.go +++ b/internal/ssh/command.go @@ -178,7 +178,33 @@ func testWithPassword(cfg *config.Config, args []string, password string) (bool, return false, result } -// BuildForwardArgs builds SSH port forwarding arguments. +func ConnectWithArgs(cfg *config.Config, args []string, vaultFunc VaultFunc, server *model.Server) error { + switch server.AuthMethod { + case model.AuthPassword: + password, err := vaultFunc(server.Alias, "ssh_password") + if err != nil { + return fmt.Errorf("get password from vault: %w", err) + } + return ConnectWithPassword(cfg.SSH.Binary, args, password) + + case model.AuthKeyPassphrase: + passphrase, err := vaultFunc(server.Alias, "key_passphrase") + if err != nil { + return fmt.Errorf("get key passphrase from vault: %w", err) + } + return ConnectWithPassword(cfg.SSH.Binary, args, passphrase) + + default: + cmd := exec.Command(cfg.SSH.Binary, args...) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Start(); err != nil { + return fmt.Errorf("start ssh: %w", err) + } + return cmd.Wait() + } +} func BuildForwardArgs(forwards []*model.Forward, exitOnForwardFailure bool) []string { var args []string for _, f := range forwards {