sshkeeper: v0.2.0 — Phase 4: CLI route/forward/tunnel commands

This commit is contained in:
mirivlad 2026-06-03 10:32:18 +08:00
parent 912b17e1f1
commit c2edaa4224
5 changed files with 340 additions and 2 deletions

118
cmd/forward.go Normal file
View File

@ -0,0 +1,118 @@
package cmd
import (
"fmt"
"strconv"
"github.com/mirivlad/sshkeeper/internal/model"
"github.com/spf13/cobra"
)
// --- Forward commands ---
var forwardCmd = &cobra.Command{
Use: "forward",
Short: "Manage port forwards",
}
var forwardListCmd = &cobra.Command{
Use: "list <alias>",
Short: "List port forwards for a server",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
alias := args[0]
server, err := appDB.GetServer(alias)
if err != nil {
return fmt.Errorf("server not found: %s", alias)
}
forwards, err := appDB.GetForwards(server.ID)
if err != nil {
return fmt.Errorf("list forwards: %w", err)
}
if len(forwards) == 0 {
fmt.Println("No port forwards configured.")
return nil
}
fmt.Printf("Port forwards for %s:\n", alias)
for _, f := range forwards {
switch f.Type {
case model.ForwardLocal:
fmt.Printf(" [%d] -L %s:%d:%s:%d\n", f.ID, f.LocalAddr, f.LocalPort, f.RemoteAddr, f.RemotePort)
case model.ForwardRemote:
fmt.Printf(" [%d] -R %s:%d:%s:%d\n", f.ID, f.RemoteAddr, f.RemotePort, f.LocalAddr, f.LocalPort)
case model.ForwardDynamic:
fmt.Printf(" [%d] -D %s:%d\n", f.ID, f.LocalAddr, f.LocalPort)
}
}
return nil
},
}
var forwardAddCmd = &cobra.Command{
Use: "add <alias>",
Short: "Add a port forward",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
alias := args[0]
server, err := appDB.GetServer(alias)
if err != nil {
return fmt.Errorf("server not found: %s", alias)
}
fwdType, _ := cmd.Flags().GetString("type")
localAddr, _ := cmd.Flags().GetString("local-addr")
localPort, _ := cmd.Flags().GetInt("local-port")
remoteAddr, _ := cmd.Flags().GetString("remote-addr")
remotePort, _ := cmd.Flags().GetInt("remote-port")
fwd := &model.Forward{
ServerID: server.ID,
Type: model.ForwardType(fwdType),
LocalAddr: localAddr,
LocalPort: localPort,
RemoteAddr: remoteAddr,
RemotePort: remotePort,
}
if err := appDB.AddForward(fwd.ServerID, fwd.Type, fwd.LocalAddr, fwd.LocalPort, fwd.RemoteAddr, fwd.RemotePort); err != nil {
return fmt.Errorf("add forward: %w", err)
}
fmt.Printf("✓ Forward added [%d]\n", fwd.ID)
return nil
},
}
var forwardDeleteCmd = &cobra.Command{
Use: "delete <alias> <id>",
Short: "Delete a port forward",
Args: cobra.ExactArgs(2),
RunE: func(cmd *cobra.Command, args []string) error {
alias := args[0]
id, err := strconv.ParseInt(args[1], 10, 64)
if err != nil {
return fmt.Errorf("invalid forward ID: %s", args[1])
}
// Verify server exists
if _, err := appDB.GetServer(alias); err != nil {
return fmt.Errorf("server not found: %s", alias)
}
if err := appDB.DeleteForward(id); err != nil {
return fmt.Errorf("delete forward: %w", err)
}
fmt.Println("✓ Forward deleted")
return nil
},
}
func init() {
forwardAddCmd.Flags().String("type", "local", "Forward type: local, remote, dynamic")
forwardAddCmd.Flags().String("local-addr", "0.0.0.0", "Listen address")
forwardAddCmd.Flags().Int("local-port", 0, "Listen port (required)")
forwardAddCmd.MarkFlagRequired("local-port")
forwardAddCmd.Flags().String("remote-addr", "", "Target address")
forwardAddCmd.Flags().Int("remote-port", 0, "Target port")
forwardCmd.AddCommand(forwardListCmd)
forwardCmd.AddCommand(forwardAddCmd)
forwardCmd.AddCommand(forwardDeleteCmd)
}

View File

@ -55,6 +55,9 @@ func init() {
rootCmd.AddCommand(groupCmd) rootCmd.AddCommand(groupCmd)
rootCmd.AddCommand(templateCmd) rootCmd.AddCommand(templateCmd)
rootCmd.AddCommand(runTemplateCmd) rootCmd.AddCommand(runTemplateCmd)
rootCmd.AddCommand(routeCmd)
rootCmd.AddCommand(forwardCmd)
rootCmd.AddCommand(tunnelCmd)
} }
func initApp() { func initApp() {
@ -172,7 +175,7 @@ func commandRequiresStartupVaultUnlock(args []string) bool {
} }
switch args[0] { switch args[0] {
case "connect", "c", "run", "run-template", "test", "edit", "delete": case "connect", "c", "run", "run-template", "test", "edit", "delete", "tunnel":
return true return true
default: default:
return false return false

128
cmd/route.go Normal file
View File

@ -0,0 +1,128 @@
package cmd
import (
"fmt"
"strings"
"github.com/mirivlad/sshkeeper/internal/model"
"github.com/spf13/cobra"
)
// --- Route commands ---
var routeCmd = &cobra.Command{
Use: "route",
Short: "Manage server routes (ProxyJump)",
}
var routeShowCmd = &cobra.Command{
Use: "show <alias>",
Short: "Show route for a server",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
alias := args[0]
server, err := appDB.GetServer(alias)
if err != nil {
return fmt.Errorf("server not found: %s", alias)
}
target := fmt.Sprintf("%s@%s:%d", server.User, server.Host, server.Port)
if len(server.Route.Hops) > 0 {
fmt.Printf("Route: %s\n", server.Route.DisplaySummary(target))
fmt.Printf("Mode: %s\n", server.Route.RouteMode())
fmt.Printf("ProxyJump: %s\n", server.Route.ProxyJumpString())
if server.Route.HasProfileLinks() {
fmt.Println("Hops:")
for _, h := range server.Route.Hops {
if h.IsProfile {
fmt.Printf(" - %s (profile)\n", h.Alias)
} else {
fmt.Printf(" - %s (raw)\n", h.Raw)
}
}
}
} else if server.ProxyJump != "" {
fmt.Printf("ProxyJump: %s\n", server.ProxyJump)
} else {
fmt.Println("Direct connection (no route)")
}
return nil
},
}
var routeSetCmd = &cobra.Command{
Use: "set <alias>",
Short: "Set route for a server",
Args: cobra.MinimumNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
alias := args[0]
server, err := appDB.GetServer(alias)
if err != nil {
return fmt.Errorf("server not found: %s", alias)
}
mode, _ := cmd.Flags().GetString("mode")
jumps, _ := cmd.Flags().GetString("jumps")
if mode == "clear" || jumps == "" {
server.Route = model.Route{}
server.ProxyJump = ""
} else {
parts := strings.Split(jumps, ",")
hops := make([]model.RouteHop, 0, len(parts))
for _, p := range parts {
p = strings.TrimSpace(p)
if p == "" {
continue
}
if strings.Contains(p, "@") || strings.Contains(p, ":") {
hops = append(hops, model.RouteHop{Raw: p, IsProfile: false})
} else {
hops = append(hops, model.RouteHop{Alias: p, IsProfile: true})
}
}
server.Route = model.Route{Hops: hops}
server.ProxyJump = server.Route.ProxyJumpString()
}
if err := appDB.UpdateServer(server); err != nil {
return fmt.Errorf("update route: %w", err)
}
target := fmt.Sprintf("%s@%s:%d", server.User, server.Host, server.Port)
if len(server.Route.Hops) > 0 {
fmt.Printf("✓ Route set: %s\n", server.Route.DisplaySummary(target))
} else {
fmt.Println("✓ Route cleared (direct connection)")
}
return nil
},
}
var routeClearCmd = &cobra.Command{
Use: "clear <alias>",
Short: "Clear route for a server (set direct)",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
alias := args[0]
server, err := appDB.GetServer(alias)
if err != nil {
return fmt.Errorf("server not found: %s", alias)
}
server.Route = model.Route{}
server.ProxyJump = ""
if err := appDB.UpdateServer(server); err != nil {
return fmt.Errorf("clear route: %w", err)
}
fmt.Println("✓ Route cleared (direct connection)")
return nil
},
}
func init() {
routeSetCmd.Flags().String("mode", "via", "Route mode: direct, via, chain, or clear")
routeSetCmd.Flags().String("jumps", "", "Comma-separated jump hosts (aliases or raw addresses)")
routeCmd.AddCommand(routeShowCmd)
routeCmd.AddCommand(routeSetCmd)
routeCmd.AddCommand(routeClearCmd)
}

63
cmd/tunnel.go Normal file
View File

@ -0,0 +1,63 @@
package cmd
import (
"fmt"
"github.com/mirivlad/sshkeeper/internal/ssh"
"github.com/spf13/cobra"
)
var tunnelCmd = &cobra.Command{
Use: "tunnel <alias>",
Short: "Start SSH session with port forwards",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
alias := args[0]
server, err := appDB.GetServer(alias)
if err != nil {
return fmt.Errorf("server not found: %s", alias)
}
forwardsOnly, _ := cmd.Flags().GetBool("forward-only")
v := getOrCreateVault()
vaultFunc := func(serverAlias string, secretType string) (string, error) {
if !v.IsUnlocked() {
return "", fmt.Errorf("%s", vaultLockedProcessMessage())
}
key := fmt.Sprintf("server:%s:%s", serverAlias, secretType)
data, err := v.Get(key)
if err != nil {
return "", err
}
return string(data), nil
}
// Load forwards
forwards, err := appDB.GetForwards(server.ID)
if err != nil {
return fmt.Errorf("load forwards: %w", err)
}
if len(forwards) == 0 && forwardsOnly {
return fmt.Errorf("no forwards configured for %s", alias)
}
if len(forwards) > 0 {
fmt.Printf("Starting tunnel to %s with %d forward(s)...\n", alias, len(forwards))
} else {
fmt.Printf("Starting session to %s...\n", alias)
}
sshArgs := ssh.BuildSSHArgs(server, forwards, forwardsOnly)
if forwardsOnly {
fmt.Printf("Tunnel mode (ssh -N). Press Ctrl+C to exit.\n")
}
return ssh.ConnectWithArgs(cfg, sshArgs, vaultFunc, server)
},
}
func init() {
tunnelCmd.Flags().Bool("forward-only", false, "Start tunnel only (ssh -N)")
}

View File

@ -178,7 +178,33 @@ func testWithPassword(cfg *config.Config, args []string, password string) (bool,
return false, result return false, result
} }
// BuildForwardArgs builds SSH port forwarding arguments. func ConnectWithArgs(cfg *config.Config, args []string, vaultFunc VaultFunc, server *model.Server) error {
switch server.AuthMethod {
case model.AuthPassword:
password, err := vaultFunc(server.Alias, "ssh_password")
if err != nil {
return fmt.Errorf("get password from vault: %w", err)
}
return ConnectWithPassword(cfg.SSH.Binary, args, password)
case model.AuthKeyPassphrase:
passphrase, err := vaultFunc(server.Alias, "key_passphrase")
if err != nil {
return fmt.Errorf("get key passphrase from vault: %w", err)
}
return ConnectWithPassword(cfg.SSH.Binary, args, passphrase)
default:
cmd := exec.Command(cfg.SSH.Binary, args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Start(); err != nil {
return fmt.Errorf("start ssh: %w", err)
}
return cmd.Wait()
}
}
func BuildForwardArgs(forwards []*model.Forward, exitOnForwardFailure bool) []string { func BuildForwardArgs(forwards []*model.Forward, exitOnForwardFailure bool) []string {
var args []string var args []string
for _, f := range forwards { for _, f := range forwards {