sshkeeper/internal/tunnel/manager.go

193 lines
3.8 KiB
Go

package tunnel
import (
	"encoding/json"
	"errors"
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
	"sync"
	"syscall"
	"time"

	"github.com/mirivlad/sshkeeper/internal/config"
	"github.com/mirivlad/sshkeeper/internal/model"
	"github.com/mirivlad/sshkeeper/internal/ssh"
)
// Package-level tunnel registry shared by all functions in this file.
var (
	// mu is locked by the exported accessors before they touch states.
	mu sync.Mutex
	// states holds the known tunnels, keyed by TunnelState.ID.
	states = make(map[int64]*model.TunnelState)
	// dataDir is the directory that contains the state file; Init sets it.
	dataDir string
)
// Init initializes the tunnel state manager with the data directory.
func Init(dir string) error {
dataDir = dir
return loadStates()
}
// StateFilePath reports where tunnel state is persisted: the file
// "tunnels.json" directly inside the current data directory.
func StateFilePath() string {
	const stateFileName = "tunnels.json"
	return filepath.Join(dataDir, stateFileName)
}
// loadStates reads the tunnel state file and merges its entries into the
// in-memory states map, keyed by tunnel ID.
//
// A missing or zero-length state file means "no saved tunnels" and is not an
// error. Any other read failure, or content that is not a JSON array of
// tunnel states, is returned as a wrapped error. This function does not lock
// mu itself.
func loadStates() error {
	b, err := os.ReadFile(StateFilePath())
	if err != nil {
		if os.IsNotExist(err) {
			return nil
		}
		return fmt.Errorf("read tunnel states: %w", err)
	}
	// A zero-length file carries no state; json.Unmarshal would reject it
	// with "unexpected end of JSON input".
	if len(b) == 0 {
		return nil
	}

	var list []*model.TunnelState
	if err := json.Unmarshal(b, &list); err != nil {
		return fmt.Errorf("unmarshal tunnel states: %w", err)
	}
	for _, s := range list {
		// A JSON null element decodes to a nil pointer; skip it instead of
		// dereferencing it.
		if s == nil {
			continue
		}
		states[s.ID] = s
	}
	return nil
}
func saveStates() error {
path := StateFilePath()
list := make([]*model.TunnelState, 0, len(states))
for _, s := range states {
list = append(list, s)
}
b, err := json.MarshalIndent(list, "", " ")
if err != nil {
return fmt.Errorf("marshal tunnel states: %w", err)
}
os.MkdirAll(dataDir, 0700)
return os.WriteFile(path, b, 0600)
}
// List returns a freshly allocated slice holding every tracked tunnel state.
//
// The slice elements are the same pointers stored in the registry (not
// copies), and their order follows map iteration order.
func List() []*model.TunnelState {
	mu.Lock()
	defer mu.Unlock()

	out := make([]*model.TunnelState, len(states))
	i := 0
	for _, st := range states {
		out[i] = st
		i++
	}
	return out
}
// Get looks up the tunnel state registered under id.
// It returns nil when no tunnel with that ID is tracked.
func Get(id int64) *model.TunnelState {
	mu.Lock()
	defer mu.Unlock()

	if st, found := states[id]; found {
		return st
	}
	return nil
}
// Start launches the configured ssh binary (cfg.SSH.Binary) for server with
// its enabled forwards and registers the resulting process as a new tunnel.
//
// Forwards that are nil or not enabled are ignored. forwardOnly is passed
// through to ssh.BuildSSHArgs unchanged. The child's standard streams are
// left unset, which os/exec connects to the null device.
//
// On success the new state is stored in the registry and returned. Persisting
// it to the state file is best-effort: a save failure does not fail Start,
// because the tunnel process is already running. An error is returned only
// when the process cannot be started.
func Start(cfg *config.Config, server *model.Server, forwards []*model.Forward, forwardOnly bool) (*model.TunnelState, error) {
	mu.Lock()
	defer mu.Unlock()

	// Keep only forwards that are present and enabled, collecting their IDs
	// in the same pass.
	var active []*model.Forward
	forwardIDs := make([]int64, 0, len(forwards))
	for _, f := range forwards {
		if f == nil || !f.Enabled {
			continue
		}
		active = append(active, f)
		forwardIDs = append(forwardIDs, f.ID)
	}

	// exec.Command copies its arguments, so the slice from BuildSSHArgs can
	// be passed directly.
	cmd := exec.Command(cfg.SSH.Binary, ssh.BuildSSHArgs(server, active, forwardOnly)...)
	if err := cmd.Start(); err != nil {
		return nil, fmt.Errorf("start tunnel: %w", err)
	}
	pid := cmd.Process.Pid

	// Reap the child when it exits so it does not linger as a zombie for as
	// long as this process lives. The goroutine ends when the child ends.
	go func() { _ = cmd.Wait() }()

	// Derive the ID and StartedAt from one clock reading. On clocks with
	// coarse resolution two starts can observe the same nanosecond value, so
	// bump the ID until it is unused rather than overwrite a live entry.
	now := time.Now()
	id := now.UnixNano()
	for {
		if _, taken := states[id]; !taken {
			break
		}
		id++
	}

	state := &model.TunnelState{
		ID:          id,
		ServerID:    server.ID,
		ServerAlias: server.Alias,
		Name:        fmt.Sprintf("Tunnel to %s", server.Alias),
		PID:         pid,
		ForwardIDs:  forwardIDs,
		StartedAt:   now,
	}
	states[id] = state

	// Best-effort persistence: the tunnel is up and tracked in memory, so a
	// failed write must not be reported as a failed start.
	_ = saveStates()

	return state, nil
}
// Stop kills the process of the tunnel registered under id, removes the
// tunnel from the registry, and persists the updated registry.
//
// It returns an error when no tunnel with that ID is tracked, when saving the
// registry fails, or when the process could not be killed for a reason other
// than having already finished. The registry entry is removed in all cases
// where the tunnel was found, so a stale entry can always be cleared.
func Stop(id int64) error {
	mu.Lock()
	defer mu.Unlock()

	state, ok := states[id]
	if !ok {
		return fmt.Errorf("tunnel %d not found", id)
	}

	var killErr error
	if state.PID > 0 {
		// A FindProcess error means there is no process to signal, which is
		// the desired end state.
		if proc, err := os.FindProcess(state.PID); err == nil {
			// An already-finished process is fine; anything else means the
			// tunnel may still be running and the caller should know.
			if err := proc.Kill(); err != nil && !errors.Is(err, os.ErrProcessDone) {
				killErr = fmt.Errorf("kill tunnel %d (pid %d): %w", id, state.PID, err)
			}
		}
	}

	delete(states, id)
	if err := saveStates(); err != nil {
		return err
	}
	return killErr
}
// StopAll kills the process of every tracked tunnel, empties the registry,
// and persists the now-empty registry.
//
// Every tunnel is attempted even if an earlier kill fails. The returned error
// is the save error if saving fails; otherwise it is the first kill failure
// that was not simply "process already finished", or nil.
func StopAll() error {
	mu.Lock()
	defer mu.Unlock()

	var firstKillErr error
	for id, state := range states {
		if state.PID > 0 {
			// A FindProcess error means there is nothing left to signal.
			if proc, err := os.FindProcess(state.PID); err == nil {
				err := proc.Kill()
				if err != nil && !errors.Is(err, os.ErrProcessDone) && firstKillErr == nil {
					firstKillErr = fmt.Errorf("kill tunnel %d (pid %d): %w", id, state.PID, err)
				}
			}
		}
		// Deleting the current key while ranging over a map is permitted.
		delete(states, id)
	}

	if err := saveStates(); err != nil {
		return err
	}
	return firstKillErr
}
// IsRunning reports whether the process recorded for tunnel id can still be
// signalled.
//
// It returns false when the tunnel is not tracked, has no positive PID, the
// process cannot be found, or the probe signal fails. The probe only checks
// that a process with the recorded PID exists; it does not verify that the
// process is still the ssh instance that was started.
// NOTE(review): os.Process.Signal on Windows appears to support only Kill, so
// this would always report false there — confirm if Windows is a target.
func IsRunning(id int64) bool {
	mu.Lock()
	defer mu.Unlock()

	state, ok := states[id]
	if !ok || state.PID <= 0 {
		return false
	}

	proc, err := os.FindProcess(state.PID)
	if err != nil {
		return false
	}

	// Signal 0 performs the existence check without delivering a signal.
	// It must be a syscall.Signal value: a nil os.Signal is rejected by
	// Process.Signal as an unsupported signal type, which would make this
	// function report false for every tunnel.
	return proc.Signal(syscall.Signal(0)) == nil
}