diff --git a/cmd/forward.go b/cmd/forward.go index a825de9..1401509 100644 --- a/cmd/forward.go +++ b/cmd/forward.go @@ -135,9 +135,19 @@ var forwardEditCmd = &cobra.Command{ return fmt.Errorf("invalid forward ID: %s", args[0]) } - // For now, just toggle enabled + fwd, err := appDB.GetForward(id) + if err != nil { + return fmt.Errorf("forward not found: %d", id) + } + enabled, _ := cmd.Flags().GetBool("enabled") - _ = enabled + if cmd.Flags().Changed("enabled") { + fwd.Enabled = enabled + } + if err := appDB.UpdateForward(fwd); err != nil { + return fmt.Errorf("update forward: %w", err) + } + fmt.Printf("✓ Forward %d updated\n", id) return nil }, diff --git a/cmd/forward_test.go b/cmd/forward_test.go new file mode 100644 index 0000000..45e319c --- /dev/null +++ b/cmd/forward_test.go @@ -0,0 +1,59 @@ +package cmd + +import ( + "strconv" + "testing" + + "github.com/mirivlad/sshkeeper/internal/db" + "github.com/mirivlad/sshkeeper/internal/model" + "github.com/spf13/cobra" +) + +func TestForwardEditUpdatesEnabledFlag(t *testing.T) { + testDB, err := db.Open(t.TempDir()) + if err != nil { + t.Fatalf("open db: %v", err) + } + defer testDB.Close() + + previousDB := appDB + appDB = testDB + t.Cleanup(func() { appDB = previousDB }) + + server := &model.Server{Alias: "web", Host: "web.example.org", Port: 22, User: "root", AuthMethod: model.AuthKey} + if err := appDB.CreateServer(server); err != nil { + t.Fatalf("create server: %v", err) + } + forwardID, err := appDB.AddForward(&model.Forward{ + ServerID: server.ID, + Name: "SOCKS", + Type: model.ForwardDynamic, + LocalAddr: "127.0.0.1", + LocalPort: 1080, + Enabled: true, + }) + if err != nil { + t.Fatalf("add forward: %v", err) + } + + cmd := &cobra.Command{} + cmd.Flags().Bool("enabled", true, "Enable/disable forward") + if err := cmd.Flags().Set("enabled", "false"); err != nil { + t.Fatalf("set enabled flag: %v", err) + } + + if err := forwardEditCmd.RunE(cmd, []string{strconv.FormatInt(forwardID, 10)}); err != nil { + t.Fatalf("edit forward: %v", err) + } + + forwards, err := appDB.GetForwards(server.ID) + if err != nil { + t.Fatalf("get forwards: %v", err) + } + if len(forwards) != 1 { + t.Fatalf("expected one forward, got %d", len(forwards)) + } + if forwards[0].Enabled { + t.Fatal("expected forward to be disabled") + } +} diff --git a/internal/db/servers.go b/internal/db/servers.go index 421d170..955506d 100644 --- a/internal/db/servers.go +++ b/internal/db/servers.go @@ -173,8 +173,26 @@ func (db *DB) SearchServers(query string) ([]*model.Server, error) { last_test_at, last_test_status, last_test_error FROM servers WHERE alias LIKE ? OR display_name LIKE ? OR host LIKE ? OR user LIKE ? - OR group_name LIKE ? OR notes LIKE ? OR proxy_jump LIKE ? - ORDER BY alias`, pattern, pattern, pattern, pattern, pattern, pattern, pattern) + OR group_name LIKE ? OR notes LIKE ? OR proxy_jump LIKE ? OR route_hops LIKE ? + OR EXISTS ( + SELECT 1 FROM server_tags st + JOIN tags t ON t.id = st.tag_id + WHERE st.server_id = servers.id AND t.name LIKE ? + ) + OR EXISTS ( + SELECT 1 FROM forwards f + WHERE f.server_id = servers.id + AND ( + f.name LIKE ? OR f.description LIKE ? + OR f.local_addr LIKE ? OR f.remote_addr LIKE ? + OR CAST(f.local_port AS TEXT) LIKE ? + OR CAST(f.remote_port AS TEXT) LIKE ? + ) + ) + ORDER BY alias`, + pattern, pattern, pattern, pattern, pattern, pattern, pattern, pattern, + pattern, + pattern, pattern, pattern, pattern, pattern, pattern) if err != nil { return nil, err } @@ -361,6 +379,18 @@ func (db *DB) GetForwards(serverID int64) ([]*model.Forward, error) { return forwards, rows.Err() } +func (db *DB) GetForward(forwardID int64) (*model.Forward, error) { + var f model.Forward + err := db.conn.QueryRow(` + SELECT id, server_id, name, description, type, local_addr, local_port, remote_addr, remote_port, enabled + FROM forwards WHERE id=?`, forwardID).Scan( + &f.ID, &f.ServerID, &f.Name, &f.Description, &f.Type, &f.LocalAddr, &f.LocalPort, &f.RemoteAddr, &f.RemotePort, &f.Enabled) + if err != nil { + return nil, err + } + return &f, nil +} + func (db *DB) DeleteForward(forwardID int64) error { _, err := db.conn.Exec("DELETE FROM forwards WHERE id=?", forwardID) return err diff --git a/internal/db/servers_test.go b/internal/db/servers_test.go index 5af1961..91d74ba 100644 --- a/internal/db/servers_test.go +++ b/internal/db/servers_test.go @@ -232,3 +232,53 @@ func TestTagManagementCRUD(t *testing.T) { t.Fatalf("tags after delete: %#v", got.Tags) } } + +func TestSearchServersMatchesTagsRoutesAndForwardPorts(t *testing.T) { + db, err := Open(t.TempDir()) + if err != nil { + t.Fatalf("open db: %v", err) + } + defer db.Close() + + server := &model.Server{ + Alias: "db", + Host: "db.internal", + Port: 22, + User: "postgres", + AuthMethod: model.AuthKey, + Route: model.Route{Hops: []model.RouteHop{ + {Alias: "bastion", IsProfile: true}, + {Raw: "dmz.example.org", IsProfile: false}, + }}, + } + if err := db.CreateServer(server); err != nil { + t.Fatalf("create server: %v", err) + } + if err := db.SetServerTags(server.ID, []string{"database"}); err != nil { + t.Fatalf("set tags: %v", err) + } + if _, err := db.AddForward(&model.Forward{ + ServerID: server.ID, + Name: "Postgres", + Type: model.ForwardLocal, + LocalAddr: "127.0.0.1", + LocalPort: 15432, + RemoteAddr: "127.0.0.1", + RemotePort: 5432, + Enabled: true, + }); err != nil { + t.Fatalf("add forward: %v", err) + } + + for _, query := range []string{"database", "dmz.example.org", "15432", "5432"} { + t.Run(query, func(t *testing.T) { + results, err := db.SearchServers(query) + if err != nil { + t.Fatalf("search servers: %v", err) + } + if len(results) != 1 || results[0].Alias != "db" { + t.Fatalf("search %q returned %#v", query, results) + } + }) + } +} diff --git a/internal/tui/app_test.go b/internal/tui/app_test.go index 51eddea..0492db6 100644 --- a/internal/tui/app_test.go +++ b/internal/tui/app_test.go @@ -766,6 +766,22 @@ func TestForwardSaveErrorStaysOnForm(t *testing.T) { } } +func TestForwardDynamicFormFocusDoesNotPanic(t *testing.T) { + fm := newForwardFormModel(1, 100, 30) + fm.currentType = model.ForwardDynamic + fm.typeIdx = typeIndex(model.ForwardDynamic) + fm.focusIdx = 2 + 3 + + defer func() { + if r := recover(); r != nil { + t.Fatalf("dynamic forward focus should not panic, got %v", r) + } + }() + + fm.updateFocus() + _ = fm.View() +} + func TestActionMenuClosesOnAllActions(t *testing.T) { server := &model.Server{ID: 1, Alias: "web", Host: "web.example.org", Port: 22, User: "root"} m := New([]*model.Server{server}) diff --git a/internal/tui/forward.go b/internal/tui/forward.go index 99f887b..b4f35bb 100644 --- a/internal/tui/forward.go +++ b/internal/tui/forward.go @@ -240,16 +240,21 @@ func (fm *forwardFormModel) visibleFields() []int { } func (fm *forwardFormModel) labelForField(idx int) string { + var labels []string switch fm.currentType { case model.ForwardLocal: - return []string{"Listen Address", "Listen Port", "Target Host", "Target Port"}[idx] + labels = []string{"Listen Address", "Listen Port", "Target Host", "Target Port"} case model.ForwardRemote: - return []string{"Remote Listen Addr", "Remote Listen Port", "Local Target Host", "Local Target Port"}[idx] + labels = []string{"Remote Listen Addr", "Remote Listen Port", "Local Target Host", "Local Target Port"} case model.ForwardDynamic: - return []string{"Listen Address", "Listen Port"}[idx] + labels = []string{"Listen Address", "Listen Port"} default: return "" } + if idx < 0 || idx >= len(labels) { + return "" + } + return labels[idx] } func (fm *forwardFormModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {