// sshkeeper/internal/ssh/command_test.go
//
// Listing metadata: 392 lines, 12 KiB, Go.

package ssh
import (
"fmt"
"os"
"path/filepath"
"strings"
"testing"
"github.com/mirivlad/sshkeeper/internal/config"
"github.com/mirivlad/sshkeeper/internal/model"
)
// TestKeyPassphraseTestUsesVaultSecret checks that Test answers a key
// passphrase prompt with the secret returned by the vault lookup callback.
//
// The fake ssh binary is a /bin/sh script that prints a passphrase prompt,
// reads one line from stdin, and exits 0 (printing SSHKEEPER_OK) only when
// that line is "key-secret"; any other input prints "denied" and exits 1.
func TestKeyPassphraseTestUsesVaultSecret(t *testing.T) {
	fakeSSH := strings.Join([]string{
		"#!/bin/sh",
		"printf 'Enter passphrase for key: '",
		"IFS= read -r passphrase",
		"if [ \"$passphrase\" = \"key-secret\" ]; then",
		" echo SSHKEEPER_OK",
		" exit 0",
		"fi",
		"echo denied",
		"exit 1",
		"", // trailing newline after the last command
	}, "\n")

	scriptPath := filepath.Join(t.TempDir(), "fake-ssh")
	if err := os.WriteFile(scriptPath, []byte(fakeSSH), 0o700); err != nil {
		t.Fatalf("write fake ssh: %v", err)
	}

	cfg := &config.Config{SSH: config.SSHConfig{
		Binary:            scriptPath,
		ConnectTimeoutSec: 2,
		TestCommand:       "echo SSHKEEPER_OK",
	}}
	server := &model.Server{
		Alias:        "prod",
		Host:         "example.org",
		Port:         22,
		User:         "root",
		AuthMethod:   model.AuthKeyPassphrase,
		IdentityFile: "/tmp/test-key",
	}

	// Only the exact (alias, secret type) pair this server should request is
	// answered; anything else is reported as an unexpected lookup.
	lookup := func(alias string, secretType string) (string, error) {
		if alias == "prod" && secretType == "key_passphrase" {
			return "key-secret", nil
		}
		return "", fmt.Errorf("unexpected secret lookup %s %s", alias, secretType)
	}

	if ok, errText := Test(cfg, server, lookup); !ok {
		t.Fatalf("expected key passphrase test to pass, error: %s", errText)
	}
}
// TestKeyPassphraseTestReportsVaultError checks that when the secret lookup
// callback fails, Test reports failure and its error text contains the lookup
// error prefixed with "vault error:".
func TestKeyPassphraseTestReportsVaultError(t *testing.T) {
	failingLookup := func(string, string) (string, error) {
		return "", fmt.Errorf("missing secret")
	}

	cfg := &config.Config{SSH: config.SSHConfig{
		Binary:            "ssh",
		ConnectTimeoutSec: 1,
		TestCommand:       "echo SSHKEEPER_OK",
	}}
	server := &model.Server{
		Alias:      "prod",
		Host:       "example.org",
		Port:       22,
		User:       "root",
		AuthMethod: model.AuthKeyPassphrase,
	}

	ok, errText := Test(cfg, server, failingLookup)
	switch {
	case ok:
		t.Fatal("expected key passphrase test to fail when vault lookup fails")
	case !strings.Contains(errText, "vault error: missing secret"):
		t.Fatalf("expected vault error, got %q", errText)
	}
}
func TestConnectRunsStartupCommand(t *testing.T) {
dir := t.TempDir()
argsFile := filepath.Join(dir, "args")
script := filepath.Join(dir, "fake-ssh")
if err := os.WriteFile(script, []byte(fmt.Sprintf("#!/bin/sh\nprintf '%%s\\n' \"$@\" > %q\n", argsFile)), 0o700); err != nil {
t.Fatalf("write fake ssh: %v", err)
}
cfg := &config.Config{SSH: config.SSHConfig{Binary: script}}
server := &model.Server{
Alias: "prod",
Host: "example.org",
Port: 22,
User: "root",
AuthMethod: model.AuthKey,
StartupCommand: "tmux attach -t ops",
}
if err := Connect(cfg, server, nil); err != nil {
t.Fatalf("connect: %v", err)
}
data, err := os.ReadFile(argsFile)
if err != nil {
t.Fatalf("read args: %v", err)
}
if !strings.Contains(string(data), "tmux attach -t ops") {
t.Fatalf("expected startup command in ssh args, got:\n%s", data)
}
}
// TestValidateSSHBinaryForWindowsRequiresOpenSSHClient checks that on Windows
// a missing ssh.exe yields an error mentioning the OpenSSH Client feature and
// the PowerShell command that installs it.
func TestValidateSSHBinaryForWindowsRequiresOpenSSHClient(t *testing.T) {
	// Simulated PATH lookup: ssh.exe is never found, and looking up any other
	// name fails the test outright.
	lookPath := func(name string) (string, error) {
		if name == "ssh.exe" {
			return "", os.ErrNotExist
		}
		t.Fatalf("expected lookup for ssh.exe, got %q", name)
		return "", os.ErrNotExist
	}

	err := validateSSHBinaryForOS("windows", "ssh.exe", lookPath)
	if err == nil {
		t.Fatal("expected missing ssh.exe error")
	}

	msg := err.Error()
	if !strings.Contains(msg, "OpenSSH Client") {
		t.Fatalf("expected OpenSSH Client guidance, got %q", err)
	}
	if !strings.Contains(msg, "Add-WindowsCapability -Online -Name OpenSSH.Client~~~~0.0.1.0") {
		t.Fatalf("expected PowerShell install command, got %q", err)
	}
}
// TestValidateSSHBinaryForNonWindowsChecksConfiguredBinary checks that on a
// non-Windows OS the configured binary path itself is the name passed to the
// lookup function, and that a successful lookup produces no error.
func TestValidateSSHBinaryForNonWindowsChecksConfiguredBinary(t *testing.T) {
	const configured = "/usr/bin/ssh"

	var seen string // last name handed to the lookup function
	lookPath := func(name string) (string, error) {
		seen = name
		return configured, nil
	}

	if err := validateSSHBinaryForOS("linux", configured, lookPath); err != nil {
		t.Fatalf("validate ssh binary: %v", err)
	}
	if seen != configured {
		t.Fatalf("expected lookup for configured binary, got %q", seen)
	}
}
func TestBuildSSHArgs_Simple(t *testing.T) {
server := &model.Server{Host: "example.org", Port: 22, User: "root"}
args := BuildSSHArgsSimple(server)
expected := []string{"-p", "22", "-o", "StrictHostKeyChecking=accept-new", "root@example.org"}
if len(args) != len(expected) {
t.Fatalf("expected %d args, got %d: %v", len(expected), len(args), args)
}
for i, a := range args {
if a != expected[i] {
t.Fatalf("arg[%d]: expected %q, got %q", i, expected[i], a)
}
}
}
// TestBuildSSHArgs_CustomPort checks that a non-default port and user show up
// in BuildSSHArgsSimple's output: "-p 2222" leads the argument list (the same
// position TestBuildSSHArgs_Simple pins for port 22) and "deploy@example.org"
// is the final argument.
func TestBuildSSHArgs_CustomPort(t *testing.T) {
	server := &model.Server{Host: "example.org", Port: 2222, User: "deploy"}
	args := BuildSSHArgsSimple(server)

	// Guard the fixed-position indexing below so a short result fails with a
	// readable message instead of panicking with "index out of range".
	if len(args) < 3 {
		t.Fatalf("expected at least 3 args (-p, port, target), got %d: %v", len(args), args)
	}
	if args[0] != "-p" {
		t.Fatalf("expected -p as first arg, got %q (args: %v)", args[0], args)
	}
	if args[1] != "2222" {
		t.Fatalf("expected port 2222, got %s", args[1])
	}
	if args[len(args)-1] != "deploy@example.org" {
		t.Fatalf("expected target deploy@example.org, got %s", args[len(args)-1])
	}
}
func TestBuildSSHArgs_WithIdentityFile(t *testing.T) {
server := &model.Server{Host: "example.org", Port: 22, User: "root", IdentityFile: "~/.ssh/id_ed25519"}
args := BuildSSHArgsSimple(server)
found := false
for _, a := range args {
if a == "-i" {
found = true
break
}
}
if !found {
t.Fatalf("expected -i flag in args: %v", args)
}
}
func TestBuildSSHArgs_WithProxyJump(t *testing.T) {
server := &model.Server{Host: "internal.example.org", Port: 22, User: "root", ProxyJump: "bastion.example.org"}
args := BuildSSHArgsSimple(server)
found := false
for _, a := range args {
if a == "-J" {
found = true
break
}
}
if !found {
t.Fatalf("expected -J flag in args: %v", args)
}
}
// TestBuildSSHArgs_ForwardLocal checks that a local forward passed to
// BuildSSHArgs (not forward-only) produces an "-L" flag.
func TestBuildSSHArgs_ForwardLocal(t *testing.T) {
	server := &model.Server{Host: "db.internal", Port: 22, User: "root"}
	local := &model.Forward{
		Type:       model.ForwardLocal,
		LocalAddr:  "127.0.0.1",
		LocalPort:  15432,
		RemoteAddr: "127.0.0.1",
		RemotePort: 5432,
	}
	args := BuildSSHArgs(server, []*model.Forward{local}, false)

	var sawLocal bool
	for _, a := range args {
		sawLocal = sawLocal || a == "-L"
	}
	if !sawLocal {
		t.Fatalf("expected -L flag in args: %v", args)
	}
}
// TestBuildSSHArgs_ForwardRemote checks that a remote forward passed to
// BuildSSHArgs (not forward-only) produces an "-R" flag.
func TestBuildSSHArgs_ForwardRemote(t *testing.T) {
	server := &model.Server{Host: "web.internal", Port: 22, User: "root"}
	remote := &model.Forward{
		Type:       model.ForwardRemote,
		LocalAddr:  "127.0.0.1",
		LocalPort:  8080,
		RemoteAddr: "0.0.0.0",
		RemotePort: 80,
	}
	args := BuildSSHArgs(server, []*model.Forward{remote}, false)

	occurrences := 0
	for _, a := range args {
		if a == "-R" {
			occurrences++
		}
	}
	if occurrences == 0 {
		t.Fatalf("expected -R flag in args: %v", args)
	}
}
// TestBuildSSHArgs_ForwardDynamic checks that a dynamic (SOCKS) forward passed
// to BuildSSHArgs (not forward-only) produces a "-D" flag.
func TestBuildSSHArgs_ForwardDynamic(t *testing.T) {
	server := &model.Server{Host: "jump.example.org", Port: 22, User: "me"}
	socks := &model.Forward{Type: model.ForwardDynamic, LocalAddr: "127.0.0.1", LocalPort: 1080}
	args := BuildSSHArgs(server, []*model.Forward{socks}, false)

	contains := func(flag string) bool {
		for _, a := range args {
			if a == flag {
				return true
			}
		}
		return false
	}
	if !contains("-D") {
		t.Fatalf("expected -D flag in args: %v", args)
	}
}
// TestBuildSSHArgs_ForwardOnly checks that calling BuildSSHArgs with the
// forward-only flag set to true adds "-N" to the arguments.
func TestBuildSSHArgs_ForwardOnly(t *testing.T) {
	server := &model.Server{Host: "db.internal", Port: 22, User: "root"}
	forwards := []*model.Forward{{
		Type:       model.ForwardLocal,
		LocalAddr:  "127.0.0.1",
		LocalPort:  15432,
		RemoteAddr: "127.0.0.1",
		RemotePort: 5432,
	}}
	args := BuildSSHArgs(server, forwards, true)

	noCommand := false
	for i := range args {
		if args[i] == "-N" {
			noCommand = true
			break
		}
	}
	if !noCommand {
		t.Fatalf("expected -N flag in args: %v", args)
	}
}
// TestBuildSSHArgs_ExitOnForwardFailure checks that when forwards are present,
// BuildSSHArgs includes the "ExitOnForwardFailure=yes" option value somewhere
// in its arguments.
func TestBuildSSHArgs_ExitOnForwardFailure(t *testing.T) {
	const option = "ExitOnForwardFailure=yes"

	server := &model.Server{Host: "db.internal", Port: 22, User: "root"}
	forwards := []*model.Forward{{
		Type:       model.ForwardLocal,
		LocalAddr:  "127.0.0.1",
		LocalPort:  15432,
		RemoteAddr: "127.0.0.1",
		RemotePort: 5432,
	}}
	args := BuildSSHArgs(server, forwards, false)

	present := false
	for i := 0; i < len(args) && !present; i++ {
		present = args[i] == option
	}
	if !present {
		t.Fatalf("expected ExitOnForwardFailure=yes in args: %v", args)
	}
}
// TestBuildSSHArgs_MultipleForwards checks that BuildSSHArgs emits one flag
// per forward: two local forwards give two "-L" flags and one dynamic forward
// gives one "-D" flag.
func TestBuildSSHArgs_MultipleForwards(t *testing.T) {
	server := &model.Server{Host: "multi.internal", Port: 22, User: "root"}
	forwards := []*model.Forward{
		{Type: model.ForwardLocal, LocalAddr: "127.0.0.1", LocalPort: 15432, RemoteAddr: "127.0.0.1", RemotePort: 5432},
		{Type: model.ForwardLocal, LocalAddr: "127.0.0.1", LocalPort: 18080, RemoteAddr: "internal.web", RemotePort: 80},
		{Type: model.ForwardDynamic, LocalAddr: "127.0.0.1", LocalPort: 1080},
	}
	args := BuildSSHArgs(server, forwards, false)

	// Tally every argument; only the flag counts are inspected below.
	tally := make(map[string]int, len(args))
	for _, a := range args {
		tally[a]++
	}

	if n := tally["-L"]; n != 2 {
		t.Fatalf("expected 2 -L flags, got %d: %v", n, args)
	}
	if n := tally["-D"]; n != 1 {
		t.Fatalf("expected 1 -D flag, got %d: %v", n, args)
	}
}
// TestBuildSSHArgs_RouteAndForwards checks that a server with a one-hop Route
// and a local forward gets both a "-J" (jump) flag and an "-L" flag from
// BuildSSHArgs.
func TestBuildSSHArgs_RouteAndForwards(t *testing.T) {
	server := &model.Server{
		Host:  "secure.internal",
		Port:  22,
		User:  "root",
		Route: model.Route{Hops: []model.RouteHop{{Alias: "bastion", IsProfile: true}}},
	}
	forwards := []*model.Forward{{
		Type:       model.ForwardLocal,
		LocalAddr:  "127.0.0.1",
		LocalPort:  15432,
		RemoteAddr: "127.0.0.1",
		RemotePort: 5432,
	}}
	args := BuildSSHArgs(server, forwards, false)

	present := make(map[string]bool, len(args))
	for _, a := range args {
		present[a] = true
	}

	if !present["-J"] {
		t.Fatalf("expected -J flag in args: %v", args)
	}
	if !present["-L"] {
		t.Fatalf("expected -L flag in args: %v", args)
	}
}
// TestBuildForwardArgs_Local pins the exact arguments BuildForwardArgs produces
// for a single local forward with its boolean argument set to true: an "-L"
// spec of the form listen_addr:listen_port:target_host:target_port, followed by
// "-o ExitOnForwardFailure=yes".
func TestBuildForwardArgs_Local(t *testing.T) {
	fwd := &model.Forward{Type: model.ForwardLocal, LocalAddr: "127.0.0.1", LocalPort: 8080, RemoteAddr: "internal.web", RemotePort: 80}
	args := BuildForwardArgs([]*model.Forward{fwd}, true)
	expected := []string{"-L", "127.0.0.1:8080:internal.web:80", "-o", "ExitOnForwardFailure=yes"}

	if len(args) != len(expected) {
		t.Fatalf("expected %v, got %v", expected, args)
	}
	// Compare element by element: a length-only check would accept a
	// malformed forward spec or a wrong option as long as the count matched.
	for i := range expected {
		if args[i] != expected[i] {
			t.Fatalf("arg[%d]: expected %q, got %q (expected %v, got %v)", i, expected[i], args[i], expected, args)
		}
	}
}
// TestBuildForwardArgs_Remote pins the exact arguments BuildForwardArgs produces
// for a single remote forward with its boolean argument set to true: an "-R"
// spec of the form remote_addr:remote_port:local_addr:local_port, followed by
// "-o ExitOnForwardFailure=yes".
func TestBuildForwardArgs_Remote(t *testing.T) {
	fwd := &model.Forward{Type: model.ForwardRemote, LocalAddr: "127.0.0.1", LocalPort: 2222, RemoteAddr: "0.0.0.0", RemotePort: 22}
	args := BuildForwardArgs([]*model.Forward{fwd}, true)
	expected := []string{"-R", "0.0.0.0:22:127.0.0.1:2222", "-o", "ExitOnForwardFailure=yes"}

	if len(args) != len(expected) {
		t.Fatalf("expected %v, got %v", expected, args)
	}
	// Compare element by element: with a length-only check, swapped remote and
	// local halves of the spec would pass unnoticed.
	for i := range expected {
		if args[i] != expected[i] {
			t.Fatalf("arg[%d]: expected %q, got %q (expected %v, got %v)", i, expected[i], args[i], expected, args)
		}
	}
}
// TestBuildForwardArgs_Dynamic pins the exact arguments BuildForwardArgs
// produces for a single dynamic (SOCKS) forward with its boolean argument set
// to true: "-D listen_addr:listen_port" followed by "-o ExitOnForwardFailure=yes".
func TestBuildForwardArgs_Dynamic(t *testing.T) {
	fwd := &model.Forward{Type: model.ForwardDynamic, LocalAddr: "127.0.0.1", LocalPort: 1080}
	args := BuildForwardArgs([]*model.Forward{fwd}, true)
	expected := []string{"-D", "127.0.0.1:1080", "-o", "ExitOnForwardFailure=yes"}

	if len(args) != len(expected) {
		t.Fatalf("expected %v, got %v", expected, args)
	}
	// Compare element by element so the listen address and option value are
	// actually verified, not just the argument count.
	for i := range expected {
		if args[i] != expected[i] {
			t.Fatalf("arg[%d]: expected %q, got %q (expected %v, got %v)", i, expected[i], args[i], expected, args)
		}
	}
}
// TestBuildForwardArgs_Empty checks that a nil forward list yields no
// arguments, even when the boolean argument is true. In particular, no
// standalone ExitOnForwardFailure option is emitted.
func TestBuildForwardArgs_Empty(t *testing.T) {
	if args := BuildForwardArgs(nil, true); len(args) > 0 {
		t.Fatalf("expected no args, got %v", args)
	}
}
// TestForwardHumanExplanation pins the human-readable description produced for
// each forward type when ForwardHumanExplanation is called with "".
//
// The remote and dynamic expectations contain an empty gap ("on will be",
// "through .") where that argument is presumably interpolated. NOTE(review):
// the argument looks like a server label; confirm against the model package.
//
// Each case runs as a named subtest and reports with t.Errorf, so one broken
// description does not hide the others and the failing case is identified.
func TestForwardHumanExplanation(t *testing.T) {
	tests := []struct {
		name string
		fwd  *model.Forward
		want string
	}{
		{"local", &model.Forward{Type: model.ForwardLocal, LocalAddr: "127.0.0.1", LocalPort: 15432, RemoteAddr: "127.0.0.1", RemotePort: 5432}, "Port 127.0.0.1:15432 on this machine will be forwarded through to 127.0.0.1:5432."},
		{"remote", &model.Forward{Type: model.ForwardRemote, LocalAddr: "127.0.0.1", LocalPort: 8080, RemoteAddr: "0.0.0.0", RemotePort: 80}, "Port 0.0.0.0:80 on will be forwarded to 127.0.0.1:8080 on this machine."},
		{"dynamic", &model.Forward{Type: model.ForwardDynamic, LocalAddr: "127.0.0.1", LocalPort: 1080}, "SOCKS proxy on 127.0.0.1:1080 will route traffic through ."},
	}
	for _, tt := range tests {
		// t.Run is synchronous here (no t.Parallel), so capturing tt is safe
		// even before Go 1.22's per-iteration loop variables.
		t.Run(tt.name, func(t *testing.T) {
			if got := tt.fwd.ForwardHumanExplanation(""); got != tt.want {
				t.Errorf("ForwardHumanExplanation() = %q, want %q", got, tt.want)
			}
		})
	}
}
// TestForwardListenTarget pins ForwardListen and ForwardTarget for each
// forward type:
//   - local forwards listen on the local side and target the remote side;
//   - remote forwards do the reverse;
//   - dynamic forwards listen locally and report "SOCKS" as the target.
//
// Each case runs as a named subtest and both accessors are checked with
// t.Errorf, so every mismatch is reported rather than only the first.
func TestForwardListenTarget(t *testing.T) {
	tests := []struct {
		name   string
		fwd    *model.Forward
		listen string
		target string
	}{
		{"local", &model.Forward{Type: model.ForwardLocal, LocalAddr: "127.0.0.1", LocalPort: 15432, RemoteAddr: "127.0.0.1", RemotePort: 5432}, "127.0.0.1:15432", "127.0.0.1:5432"},
		{"remote", &model.Forward{Type: model.ForwardRemote, LocalAddr: "127.0.0.1", LocalPort: 8080, RemoteAddr: "0.0.0.0", RemotePort: 80}, "0.0.0.0:80", "127.0.0.1:8080"},
		{"dynamic", &model.Forward{Type: model.ForwardDynamic, LocalAddr: "127.0.0.1", LocalPort: 1080}, "127.0.0.1:1080", "SOCKS"},
	}
	for _, tt := range tests {
		// t.Run is synchronous here (no t.Parallel), so capturing tt is safe.
		t.Run(tt.name, func(t *testing.T) {
			if got := tt.fwd.ForwardListen(); got != tt.listen {
				t.Errorf("ForwardListen() = %q, want %q", got, tt.listen)
			}
			if got := tt.fwd.ForwardTarget(); got != tt.target {
				t.Errorf("ForwardTarget() = %q, want %q", got, tt.target)
			}
		})
	}
}