Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d127a3b957 | |||
| cdf750081e | |||
| 574c4dbf58 | |||
| 8fc7896b35 | |||
| da3a280a43 | |||
| a4fa33d224 | |||
| 8ae28b3474 | |||
| 9334003c9e |
@@ -1,5 +1,7 @@
|
||||
# netssh
|
||||
|
||||
> **Vibe-coded project** — this codebase was written entirely by an AI assistant (Claude) without human code review. Use in production at your own risk.
|
||||
|
||||
A transparent SSH wrapper that resolves hostnames via [NetBox](https://netbox.dev/) before connecting.
|
||||
|
||||
Instead of looking up an IP manually, you just type the hostname as it appears in NetBox:
|
||||
@@ -16,8 +18,14 @@ netssh -p 2222 admin@app-server-03 uptime
|
||||
- **Transparent proxy** — replaces itself with `ssh` via `syscall.Exec`, preserving all SSH flags and options
|
||||
- **Flexible IP resolution** — configurable chain of strategies: management subnet, primary IP, or named interface
|
||||
- **Interactive TUI** — fuzzy search with live NetBox queries and 300 ms debouncing (start with `netssh`, no arguments)
|
||||
- **Recently-used list** — TUI opens with your 10 most-recently-connected hosts, no typing needed
|
||||
- **Tag/kind filter** — press `Ctrl+F` in the TUI to filter by `tag:prod` or `kind:vm`
|
||||
- **User/port override** — press `e` in the TUI to override the SSH user or port before connecting
|
||||
- **Persistent cache** — successful lookups are cached to `~/.cache/netssh/hosts.json` for instant shell completion
|
||||
- **Shell completion** — tab-complete hostnames from the cache in zsh, bash, and fish
|
||||
- **Full pagination** — `cache refresh` fetches all hosts from NetBox (not just the first 50)
|
||||
- **Selective refresh** — `cache refresh --tag prod --kind vm` limits what gets synced
|
||||
- **Setup wizard** — interactive first-run onboarding; re-run anytime with `netssh configure`
|
||||
- **Shell completion** — install without sudo via `netssh completion install`
|
||||
- **Default SSH user** — set a fallback username once in config instead of typing it every time
|
||||
|
||||
## Installation
|
||||
@@ -46,12 +54,27 @@ go build -o netssh ./cmd/netssh
|
||||
|
||||
## Configuration
|
||||
|
||||
Create `~/.config/netssh.yaml`:
|
||||
### Interactive wizard
|
||||
|
||||
On first run (when no config exists), `netssh` automatically starts an interactive setup wizard.
|
||||
Re-run it at any time to change settings without editing the file manually:
|
||||
|
||||
```sh
|
||||
netssh configure
|
||||
```
|
||||
|
||||
The wizard walks through NetBox connection, SSH defaults, resolver strategies, and cache TTL,
|
||||
then saves to `~/.config/netssh.yaml`.
|
||||
|
||||
### Manual config
|
||||
|
||||
`~/.config/netssh.yaml`:
|
||||
|
||||
```yaml
|
||||
netbox:
|
||||
url: https://netbox.example.com
|
||||
token: your-api-token-here
|
||||
token: nbt_your-api-token-here # v2 token (nbt_ prefix) recommended
|
||||
token_version: 2 # auto-detected from token; 1 = legacy, 2 = nbt_
|
||||
|
||||
resolver:
|
||||
# Strategies are tried in order; the first to return an IP wins.
|
||||
@@ -73,7 +96,19 @@ ssh:
|
||||
default_user: admin # used when no user is specified on the command line
|
||||
```
|
||||
|
||||
Any value can be overridden with environment variables (`NETSSH_NETBOX_URL`, `NETSSH_NETBOX_TOKEN`, etc.) or will be read from the config file.
|
||||
Any value can be overridden with environment variables (`NETSSH_NETBOX_URL`, `NETSSH_NETBOX_TOKEN`, etc.).
|
||||
|
||||
### API tokens
|
||||
|
||||
NetBox supports two token formats:
|
||||
|
||||
| Format | Example | Notes |
|
||||
|--------|---------|-------|
|
||||
| v2 (recommended) | `nbt_abc123…` | Create in NetBox → Admin → API Tokens |
|
||||
| v1 (legacy) | `abc123def456…` | Older format; still works, but v2 is preferred |
|
||||
|
||||
`netssh` auto-detects the version from the token prefix and stores it as `token_version` in the config.
|
||||
A hint is shown during `netssh configure` if a legacy v1 token is entered.
|
||||
|
||||
## Usage
|
||||
|
||||
@@ -112,19 +147,42 @@ Run without arguments to open the interactive search:
|
||||
netssh
|
||||
```
|
||||
|
||||
The TUI opens with your 10 most-recently-connected hosts. Start typing to search all cached hosts or query NetBox live.
|
||||
|
||||
| Key | Action |
|
||||
|-----|--------|
|
||||
| type | filter hosts (300 ms debounce → NetBox query) |
|
||||
| `Tab` | autocomplete top result into the search field |
|
||||
| `↑` / `↓` | navigate results |
|
||||
| `Enter` | connect to selected host |
|
||||
| `Esc` / `Ctrl+C` | quit |
|
||||
| `e` | open inline editor to override user/port before connecting |
|
||||
| `Ctrl+F` | open/close tag and kind filter (`tag:prod kind:vm`) |
|
||||
| `Esc` / `Ctrl+C` | quit (or close filter/edit if open) |
|
||||
|
||||
**Tag and kind filter** — press `Ctrl+F` to open a second input line:
|
||||
|
||||
```
|
||||
Filter: tag:prod kind:vm
|
||||
```
|
||||
|
||||
Multiple `tag:` values are AND-combined. The filter is applied locally against the cache; when doing a live NetBox search the first tag is also forwarded as a query parameter.
|
||||
|
||||
**User/port override** — press `e` on any highlighted host:
|
||||
|
||||
```
|
||||
Connect as: admin@my-router:22
|
||||
```
|
||||
|
||||
Edit the pre-filled value and press `Enter` to connect. `Esc` cancels. Port 22 is treated as default and omitted from the ssh command.
|
||||
|
||||
### Cache management
|
||||
|
||||
```sh
|
||||
netssh cache list # show all cached entries
|
||||
netssh cache refresh # re-fetch all hosts from NetBox
|
||||
netssh cache refresh # re-fetch ALL hosts from NetBox (paginated)
|
||||
netssh cache refresh --tag prod # only hosts with the "prod" tag
|
||||
netssh cache refresh --kind vm # only virtual machines
|
||||
netssh cache refresh --tag prod --kind vm # combine filters
|
||||
netssh cache clear # wipe the cache
|
||||
```
|
||||
|
||||
@@ -146,28 +204,26 @@ Strategies are tried in the configured order; the first to succeed wins.
|
||||
|
||||
## Shell Completion
|
||||
|
||||
### zsh
|
||||
Install completion for the current user (no sudo required):
|
||||
|
||||
```sh
|
||||
netssh completion zsh > "${fpath[1]}/_netssh"
|
||||
netssh completion install # auto-detects $SHELL
|
||||
netssh completion install --shell bash
|
||||
netssh completion install --shell zsh
|
||||
netssh completion install --shell fish
|
||||
```
|
||||
|
||||
Or add to `.zshrc`:
|
||||
| Shell | Install path |
|
||||
|-------|-------------|
|
||||
| bash | `~/.local/share/bash-completion/completions/netssh` |
|
||||
| zsh | `~/.zfunc/_netssh` |
|
||||
| fish | `~/.config/fish/completions/netssh.fish` |
|
||||
|
||||
For zsh, make sure `~/.zfunc` is in your `fpath` (add to `~/.zshrc`):
|
||||
|
||||
```zsh
|
||||
source <(netssh completion zsh)
|
||||
```
|
||||
|
||||
### bash
|
||||
|
||||
```sh
|
||||
netssh completion bash > /etc/bash_completion.d/netssh
|
||||
```
|
||||
|
||||
### fish
|
||||
|
||||
```sh
|
||||
netssh completion fish > ~/.config/fish/completions/netssh.fish
|
||||
fpath=(~/.zfunc $fpath)
|
||||
autoload -Uz compinit && compinit
|
||||
```
|
||||
|
||||
Completions are served from the local cache — no network request on every `<Tab>`.
|
||||
@@ -179,12 +235,17 @@ go test ./... # run all tests
|
||||
go build ./... # build all packages
|
||||
```
|
||||
|
||||
The test suite covers the cache, NetBox client (via `httptest`), IP resolver chain, and SSH argument parser.
|
||||
The test suite covers the cache, NetBox client (via `httptest`), IP resolver chain, SSH argument parser, config loading, and the setup wizard.
|
||||
|
||||
## Disclaimer
|
||||
|
||||
This is a **vibe-coded** project: the entire codebase — architecture, implementation, tests, and docs — was generated by an AI assistant (Claude by Anthropic). No human has reviewed or audited the code. It works for the author's personal use case, but correctness and security are not guaranteed. Read the source before running it in sensitive environments.
|
||||
|
||||
## How it works
|
||||
|
||||
1. `netssh` checks whether the first argument is a known subcommand (`search`, `cache`, `completion`). If not, it enters SSH wrapper mode.
|
||||
2. It parses the SSH arguments to extract the destination hostname, handling all flags that consume an extra argument (`-p`, `-i`, `-J`, …).
|
||||
3. It checks the local cache. If the entry exists and is within the TTL, it connects immediately.
|
||||
4. Otherwise it queries NetBox (`/api/dcim/devices/` and `/api/virtualization/virtual-machines/` in parallel), runs the result through the resolver chain, and caches the IP.
|
||||
5. It calls `syscall.Exec` to replace itself with `ssh`, substituting the hostname with the resolved IP.
|
||||
1. `netssh` checks whether the first argument is a known subcommand (`configure`, `search`, `cache`, `completion`). If not, it enters SSH wrapper mode.
|
||||
2. On first run or when `netbox.url` is empty, the interactive setup wizard starts automatically.
|
||||
3. It parses the SSH arguments to extract the destination hostname, handling all flags that consume an extra argument (`-p`, `-i`, `-J`, …).
|
||||
4. It checks the local cache. If the entry exists and is within the TTL, it connects immediately.
|
||||
5. Otherwise it queries NetBox (`/api/dcim/devices/` and `/api/virtualization/virtual-machines/` in parallel), runs the result through the resolver chain, and caches the IP.
|
||||
6. It calls `syscall.Exec` to replace itself with `ssh`, substituting the hostname with the resolved IP.
|
||||
|
||||
+133
-21
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"text/tabwriter"
|
||||
@@ -92,6 +93,8 @@ func runSSHWrapper(args []string) {
|
||||
|
||||
// Cache hit with a fresh TTL — connect directly without querying NetBox.
|
||||
if entry, fresh := c.Get(parsed.Host); fresh {
|
||||
c.MarkUsed(parsed.Host)
|
||||
_ = c.Save()
|
||||
connect(entry.IP, parsed, args)
|
||||
return
|
||||
}
|
||||
@@ -99,12 +102,12 @@ func runSSHWrapper(args []string) {
|
||||
if cfg.NetBox.URL == "" {
|
||||
fatalf("netbox.url is not configured (~/.config/netssh.yaml)")
|
||||
}
|
||||
nbClient := netbox.NewClient(cfg.NetBox.URL, cfg.NetBox.Token)
|
||||
nbClient := netbox.NewClient(cfg.NetBox.URL, cfg.NetBox.Token, cfg.NetBox.TokenVersion)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
entries, err := nbClient.Search(ctx, parsed.Host)
|
||||
entries, err := nbClient.Search(ctx, parsed.Host, netbox.SearchOptions{})
|
||||
if err != nil {
|
||||
fatalf("NetBox search failed: %v", err)
|
||||
}
|
||||
@@ -131,6 +134,7 @@ func runSSHWrapper(args []string) {
|
||||
}
|
||||
|
||||
c.Upsert(cache.Entry{Name: target.Name, IP: ip, Kind: target.Kind, Tags: target.Tags})
|
||||
c.MarkUsed(target.Name)
|
||||
_ = c.Save()
|
||||
|
||||
connect(ip, parsed, args)
|
||||
@@ -152,10 +156,10 @@ func runTUI() {
|
||||
|
||||
var nbClient *netbox.Client
|
||||
if cfg.NetBox.URL != "" {
|
||||
nbClient = netbox.NewClient(cfg.NetBox.URL, cfg.NetBox.Token)
|
||||
nbClient = netbox.NewClient(cfg.NetBox.URL, cfg.NetBox.Token, cfg.NetBox.TokenVersion)
|
||||
}
|
||||
|
||||
m := tui.New(nbClient, c)
|
||||
m := tui.New(nbClient, c, cfg.SSH.DefaultUser)
|
||||
p := tea.NewProgram(m, tea.WithAltScreen())
|
||||
final, err := p.Run()
|
||||
if err != nil {
|
||||
@@ -174,11 +178,25 @@ func runTUI() {
|
||||
fmt.Fprintf(os.Stderr, "Connecting to %s (%s)…\n", host.Name, host.IP)
|
||||
|
||||
var sshArgs []string
|
||||
if cfg.SSH.DefaultUser != "" {
|
||||
sshArgs = append(sshArgs, "-l", cfg.SSH.DefaultUser)
|
||||
|
||||
// User override from TUI edit mode takes priority, then config default.
|
||||
user := host.User
|
||||
if user == "" {
|
||||
user = cfg.SSH.DefaultUser
|
||||
}
|
||||
if user != "" {
|
||||
sshArgs = append(sshArgs, "-l", user)
|
||||
}
|
||||
|
||||
if host.Port != "" {
|
||||
sshArgs = append(sshArgs, "-p", host.Port)
|
||||
}
|
||||
|
||||
sshArgs = append(sshArgs, host.IP)
|
||||
|
||||
c.MarkUsed(host.Name)
|
||||
_ = c.Save()
|
||||
|
||||
if err := internalssh.Exec(sshArgs); err != nil {
|
||||
fatalf("%v", err)
|
||||
}
|
||||
@@ -198,20 +216,103 @@ func rootCmd() *cobra.Command {
|
||||
}
|
||||
c := cache.New(cfg.Cache.Path, cfg.Cache.TTL)
|
||||
_ = c.Load()
|
||||
entries := c.Search(toComplete)
|
||||
names := make([]cobra.Completion, len(entries))
|
||||
for i, e := range entries {
|
||||
names[i] = cobra.Completion(e.Name)
|
||||
|
||||
var completions []cobra.Completion
|
||||
for _, e := range c.Search(toComplete) {
|
||||
completions = append(completions, cobra.Completion(e.Name))
|
||||
}
|
||||
return names, cobra.ShellCompDirectiveNoFileComp
|
||||
// Subcommands at the end, after all hostnames.
|
||||
for _, sub := range cmd.Commands() {
|
||||
if sub.IsAvailableCommand() && strings.HasPrefix(sub.Name(), toComplete) {
|
||||
completions = append(completions, cobra.Completion(sub.Name()+"\t"+sub.Short))
|
||||
}
|
||||
}
|
||||
return completions, cobra.ShellCompDirectiveNoFileComp | cobra.ShellCompDirectiveKeepOrder
|
||||
},
|
||||
}
|
||||
|
||||
// cobra automatically adds a "completion" subcommand
|
||||
root.AddCommand(configureCmd(), searchCmd(), cacheCmd())
|
||||
|
||||
// cobra builds the "completion" command lazily; force init so we can extend it.
|
||||
root.InitDefaultCompletionCmd()
|
||||
for _, cmd := range root.Commands() {
|
||||
if cmd.Name() == "completion" {
|
||||
cmd.AddCommand(completionInstallCmd(root))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return root
|
||||
}
|
||||
|
||||
func completionInstallCmd(root *cobra.Command) *cobra.Command {
|
||||
var shell string
|
||||
cmd := &cobra.Command{
|
||||
Use: "install",
|
||||
Short: "Install shell completion for the current user (no sudo required)",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if shell == "" {
|
||||
shell = filepath.Base(os.Getenv("SHELL"))
|
||||
}
|
||||
|
||||
var (
|
||||
dir string
|
||||
file string
|
||||
gen func() ([]byte, error)
|
||||
note string
|
||||
)
|
||||
|
||||
switch shell {
|
||||
case "bash":
|
||||
dir = filepath.Join(os.Getenv("HOME"), ".local", "share", "bash-completion", "completions")
|
||||
file = filepath.Join(dir, "netssh")
|
||||
gen = func() ([]byte, error) {
|
||||
var buf strings.Builder
|
||||
err := root.GenBashCompletionV2(&buf, true)
|
||||
return []byte(buf.String()), err
|
||||
}
|
||||
note = "Reload your shell or run: source " + file
|
||||
case "zsh":
|
||||
dir = filepath.Join(os.Getenv("HOME"), ".zfunc")
|
||||
file = filepath.Join(dir, "_netssh")
|
||||
gen = func() ([]byte, error) {
|
||||
var buf strings.Builder
|
||||
err := root.GenZshCompletion(&buf)
|
||||
return []byte(buf.String()), err
|
||||
}
|
||||
note = "Make sure ~/.zfunc is in your fpath:\n fpath=(~/.zfunc $fpath)\n autoload -Uz compinit && compinit"
|
||||
case "fish":
|
||||
configDir, _ := os.UserConfigDir()
|
||||
dir = filepath.Join(configDir, "fish", "completions")
|
||||
file = filepath.Join(dir, "netssh.fish")
|
||||
gen = func() ([]byte, error) {
|
||||
var buf strings.Builder
|
||||
err := root.GenFishCompletion(&buf, true)
|
||||
return []byte(buf.String()), err
|
||||
}
|
||||
note = "Reload your shell or start a new fish session."
|
||||
default:
|
||||
return fmt.Errorf("unsupported shell %q — use --shell bash|zsh|fish", shell)
|
||||
}
|
||||
|
||||
script, err := gen()
|
||||
if err != nil {
|
||||
return fmt.Errorf("generating completion: %w", err)
|
||||
}
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return fmt.Errorf("creating %s: %w", dir, err)
|
||||
}
|
||||
if err := os.WriteFile(file, script, 0o644); err != nil {
|
||||
return fmt.Errorf("writing %s: %w", file, err)
|
||||
}
|
||||
fmt.Printf("Completion installed → %s\n%s\n", file, note)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
cmd.Flags().StringVar(&shell, "shell", "", "Shell to install for (default: $SHELL). Supported: bash, zsh, fish")
|
||||
return cmd
|
||||
}
|
||||
|
||||
func configureCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "configure",
|
||||
@@ -308,7 +409,8 @@ func cacheClearCmd() *cobra.Command {
|
||||
}
|
||||
|
||||
func cacheRefreshCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
var filterTag, filterKind string
|
||||
cmd := &cobra.Command{
|
||||
Use: "refresh",
|
||||
Short: "Re-fetch all known hosts from NetBox",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
@@ -320,12 +422,12 @@ func cacheRefreshCmd() *cobra.Command {
|
||||
return fmt.Errorf("netbox.url is not configured")
|
||||
}
|
||||
|
||||
nbClient := netbox.NewClient(cfg.NetBox.URL, cfg.NetBox.Token)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
nbClient := netbox.NewClient(cfg.NetBox.URL, cfg.NetBox.Token, cfg.NetBox.TokenVersion)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// An empty query returns up to 50 entries per type.
|
||||
entries, err := nbClient.Search(ctx, "")
|
||||
opts := netbox.SearchOptions{Tag: filterTag, Kind: filterKind}
|
||||
entries, err := nbClient.SearchAll(ctx, "", opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("NetBox: %w", err)
|
||||
}
|
||||
@@ -336,13 +438,20 @@ func cacheRefreshCmd() *cobra.Command {
|
||||
chain, _ := resolver.New(cfg.Resolver)
|
||||
for i := range entries {
|
||||
e := &entries[i]
|
||||
ip := e.PrimaryIP4
|
||||
if ip == "" && chain != nil {
|
||||
resolved, err := chain.Resolve(ctx, e, nbClient)
|
||||
if err == nil {
|
||||
|
||||
var ip string
|
||||
if chain != nil {
|
||||
if resolved, err := chain.Resolve(ctx, e, nbClient); err == nil {
|
||||
ip = resolved
|
||||
}
|
||||
}
|
||||
if ip == "" {
|
||||
ip = e.PrimaryIP4
|
||||
}
|
||||
if ip == "" {
|
||||
ip = e.PrimaryIP6
|
||||
}
|
||||
|
||||
if ip != "" {
|
||||
c.Upsert(cache.Entry{Name: e.Name, IP: ip, Kind: e.Kind, Tags: e.Tags})
|
||||
}
|
||||
@@ -355,6 +464,9 @@ func cacheRefreshCmd() *cobra.Command {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
cmd.Flags().StringVar(&filterTag, "tag", "", "Filter by NetBox tag slug (e.g. prod)")
|
||||
cmd.Flags().StringVar(&filterKind, "kind", "", "Filter by kind: device or vm")
|
||||
return cmd
|
||||
}
|
||||
|
||||
func fatalf(format string, args ...any) {
|
||||
|
||||
Vendored
+35
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -15,6 +16,7 @@ type Entry struct {
|
||||
Kind string `json:"kind"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
CachedAt time.Time `json:"cached_at"`
|
||||
LastUsed time.Time `json:"last_used,omitempty"`
|
||||
}
|
||||
|
||||
type Cache struct {
|
||||
@@ -86,6 +88,39 @@ func (c *Cache) Upsert(e Entry) {
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// MarkUsed records the current time as LastUsed for the named entry.
|
||||
// It is a no-op if the entry does not exist.
|
||||
func (c *Cache) MarkUsed(name string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if e, ok := c.entries[name]; ok {
|
||||
e.LastUsed = time.Now()
|
||||
c.entries[name] = e
|
||||
}
|
||||
}
|
||||
|
||||
// RecentlyUsed returns the n most recently used entries, sorted by LastUsed desc.
|
||||
// Entries that have never been used (LastUsed zero) are excluded.
|
||||
// If n <= 0, all used entries are returned.
|
||||
func (c *Cache) RecentlyUsed(n int) []Entry {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
var used []Entry
|
||||
for _, e := range c.entries {
|
||||
if !e.LastUsed.IsZero() {
|
||||
used = append(used, e)
|
||||
}
|
||||
}
|
||||
sort.Slice(used, func(i, j int) bool {
|
||||
return used[i].LastUsed.After(used[j].LastUsed)
|
||||
})
|
||||
if n > 0 && len(used) > n {
|
||||
used = used[:n]
|
||||
}
|
||||
return used
|
||||
}
|
||||
|
||||
// Search returns all entries whose name starts with prefix (case-insensitive).
|
||||
// TTL is intentionally ignored — this is used for shell completion.
|
||||
func (c *Cache) Search(prefix string) []Entry {
|
||||
|
||||
Vendored
+80
@@ -220,6 +220,86 @@ func TestSave_ProducesValidJSON(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkUsed_SetsLastUsed(t *testing.T) {
|
||||
c := New("", 60)
|
||||
c.Upsert(Entry{Name: "host", IP: "10.0.0.1", Kind: "device"})
|
||||
|
||||
before := time.Now()
|
||||
c.MarkUsed("host")
|
||||
|
||||
e, _ := c.Get("host")
|
||||
if e.LastUsed.Before(before) {
|
||||
t.Error("LastUsed should be set to current time by MarkUsed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkUsed_NoopForMissingEntry(t *testing.T) {
|
||||
c := New("", 60)
|
||||
c.MarkUsed("nonexistent") // should not panic
|
||||
}
|
||||
|
||||
func TestRecentlyUsed_ReturnsTopN(t *testing.T) {
|
||||
c := New("", 60)
|
||||
c.Upsert(Entry{Name: "a", IP: "1.1.1.1", Kind: "device"})
|
||||
c.Upsert(Entry{Name: "b", IP: "2.2.2.2", Kind: "device"})
|
||||
c.Upsert(Entry{Name: "c", IP: "3.3.3.3", Kind: "device"})
|
||||
|
||||
c.MarkUsed("c")
|
||||
time.Sleep(time.Millisecond)
|
||||
c.MarkUsed("a")
|
||||
|
||||
results := c.RecentlyUsed(2)
|
||||
if len(results) != 2 {
|
||||
t.Fatalf("RecentlyUsed(2): got %d results, want 2", len(results))
|
||||
}
|
||||
if results[0].Name != "a" {
|
||||
t.Errorf("first result: got %q, want %q", results[0].Name, "a")
|
||||
}
|
||||
if results[1].Name != "c" {
|
||||
t.Errorf("second result: got %q, want %q", results[1].Name, "c")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecentlyUsed_ExcludesNeverUsed(t *testing.T) {
|
||||
c := New("", 60)
|
||||
c.Upsert(Entry{Name: "used", IP: "1.1.1.1", Kind: "device"})
|
||||
c.Upsert(Entry{Name: "unused", IP: "2.2.2.2", Kind: "device"})
|
||||
c.MarkUsed("used")
|
||||
|
||||
results := c.RecentlyUsed(10)
|
||||
if len(results) != 1 || results[0].Name != "used" {
|
||||
t.Errorf("RecentlyUsed should exclude entries with zero LastUsed, got %v", results)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecentlyUsed_EmptyCache(t *testing.T) {
|
||||
c := New("", 60)
|
||||
if results := c.RecentlyUsed(10); len(results) != 0 {
|
||||
t.Errorf("RecentlyUsed on empty cache: got %d, want 0", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkUsed_RoundtripViaSave(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "cache.json")
|
||||
c := New(path, 3600)
|
||||
c.Upsert(Entry{Name: "host", IP: "10.0.0.1", Kind: "device"})
|
||||
c.MarkUsed("host")
|
||||
|
||||
if err := c.Save(); err != nil {
|
||||
t.Fatalf("Save: %v", err)
|
||||
}
|
||||
|
||||
c2 := New(path, 3600)
|
||||
if err := c2.Load(); err != nil {
|
||||
t.Fatalf("Load: %v", err)
|
||||
}
|
||||
|
||||
results := c2.RecentlyUsed(1)
|
||||
if len(results) != 1 || results[0].Name != "host" {
|
||||
t.Errorf("LastUsed not persisted: %v", results)
|
||||
}
|
||||
}
|
||||
|
||||
// tempFile writes content to a temp file and returns its path.
|
||||
func tempFile(t *testing.T, content []byte) string {
|
||||
t.Helper()
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
@@ -18,6 +19,7 @@ type Config struct {
|
||||
type NetBoxConfig struct {
|
||||
URL string `mapstructure:"url"`
|
||||
Token string `mapstructure:"token"`
|
||||
TokenVersion int `mapstructure:"token_version"`
|
||||
}
|
||||
|
||||
type ResolverConfig struct {
|
||||
@@ -73,6 +75,14 @@ func Load() (*Config, error) {
|
||||
return nil, fmt.Errorf("parsing config: %w", err)
|
||||
}
|
||||
|
||||
if cfg.NetBox.TokenVersion == 0 && cfg.NetBox.Token != "" {
|
||||
if strings.HasPrefix(cfg.NetBox.Token, "nbt_") {
|
||||
cfg.NetBox.TokenVersion = 2
|
||||
} else {
|
||||
cfg.NetBox.TokenVersion = 1
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.Cache.Path == "" {
|
||||
cacheDir, err := os.UserCacheDir()
|
||||
if err != nil {
|
||||
|
||||
@@ -0,0 +1,145 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func writeConfig(t *testing.T, content string) string {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "netssh.yaml")
|
||||
if err := os.WriteFile(path, []byte(content), 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return dir
|
||||
}
|
||||
|
||||
func loadFromDir(t *testing.T, dir string) *Config {
|
||||
t.Helper()
|
||||
// Override UserConfigDir by pointing viper at our temp dir via env isn't
|
||||
// straightforward, so we exercise Load() by temporarily changing XDG_CONFIG_HOME.
|
||||
orig := os.Getenv("XDG_CONFIG_HOME")
|
||||
os.Setenv("XDG_CONFIG_HOME", dir)
|
||||
t.Cleanup(func() { os.Setenv("XDG_CONFIG_HOME", orig) })
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load: %v", err)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func TestLoad_V2TokenVersion_Preserved(t *testing.T) {
|
||||
dir := writeConfig(t, `
|
||||
netbox:
|
||||
url: "https://netbox.example.com"
|
||||
token: "nbt_abc123"
|
||||
token_version: 2
|
||||
`)
|
||||
cfg := loadFromDir(t, dir)
|
||||
if cfg.NetBox.TokenVersion != 2 {
|
||||
t.Errorf("TokenVersion: got %d, want 2", cfg.NetBox.TokenVersion)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_V1TokenVersion_Preserved(t *testing.T) {
|
||||
dir := writeConfig(t, `
|
||||
netbox:
|
||||
url: "https://netbox.example.com"
|
||||
token: "legacyhex123"
|
||||
token_version: 1
|
||||
`)
|
||||
cfg := loadFromDir(t, dir)
|
||||
if cfg.NetBox.TokenVersion != 1 {
|
||||
t.Errorf("TokenVersion: got %d, want 1", cfg.NetBox.TokenVersion)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_AutoDetectsV2_WhenFieldMissing(t *testing.T) {
|
||||
dir := writeConfig(t, `
|
||||
netbox:
|
||||
url: "https://netbox.example.com"
|
||||
token: "nbt_mytoken"
|
||||
`)
|
||||
cfg := loadFromDir(t, dir)
|
||||
if cfg.NetBox.TokenVersion != 2 {
|
||||
t.Errorf("TokenVersion: got %d, want 2 (auto-detected from nbt_ prefix)", cfg.NetBox.TokenVersion)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_AutoDetectsV1_WhenFieldMissing(t *testing.T) {
|
||||
dir := writeConfig(t, `
|
||||
netbox:
|
||||
url: "https://netbox.example.com"
|
||||
token: "abc123def456"
|
||||
`)
|
||||
cfg := loadFromDir(t, dir)
|
||||
if cfg.NetBox.TokenVersion != 1 {
|
||||
t.Errorf("TokenVersion: got %d, want 1 (auto-detected from plain token)", cfg.NetBox.TokenVersion)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_TokenVersionZero_WhenNoToken(t *testing.T) {
|
||||
dir := writeConfig(t, `
|
||||
netbox:
|
||||
url: "https://netbox.example.com"
|
||||
`)
|
||||
cfg := loadFromDir(t, dir)
|
||||
if cfg.NetBox.TokenVersion != 0 {
|
||||
t.Errorf("TokenVersion: got %d, want 0 (no token present)", cfg.NetBox.TokenVersion)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_Defaults(t *testing.T) {
|
||||
dir := writeConfig(t, `
|
||||
netbox:
|
||||
url: "https://netbox.example.com"
|
||||
token: "nbt_x"
|
||||
`)
|
||||
cfg := loadFromDir(t, dir)
|
||||
|
||||
if cfg.Cache.TTL != 3600 {
|
||||
t.Errorf("default cache.ttl: got %d, want 3600", cfg.Cache.TTL)
|
||||
}
|
||||
if cfg.Cache.Path == "" {
|
||||
t.Error("cache.path should be auto-set when empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_MissingFile_ReturnsEmptyConfig(t *testing.T) {
|
||||
orig := os.Getenv("XDG_CONFIG_HOME")
|
||||
os.Setenv("XDG_CONFIG_HOME", t.TempDir()) // dir exists but no netssh.yaml
|
||||
defer os.Setenv("XDG_CONFIG_HOME", orig)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load on missing file should not error: %v", err)
|
||||
}
|
||||
if cfg.NetBox.URL != "" {
|
||||
t.Errorf("expected empty URL, got %q", cfg.NetBox.URL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_InvalidYAML_ReturnsError(t *testing.T) {
|
||||
dir := writeConfig(t, "not: valid: yaml: [[[")
|
||||
orig := os.Getenv("XDG_CONFIG_HOME")
|
||||
os.Setenv("XDG_CONFIG_HOME", dir)
|
||||
defer os.Setenv("XDG_CONFIG_HOME", orig)
|
||||
|
||||
_, err := Load()
|
||||
if err == nil {
|
||||
t.Error("invalid YAML should return an error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPath_ReturnsNonEmpty(t *testing.T) {
|
||||
p := Path()
|
||||
if p == "" {
|
||||
t.Error("Path() should return a non-empty string")
|
||||
}
|
||||
if filepath.Base(p) != "netssh.yaml" {
|
||||
t.Errorf("Path() base: got %q, want netssh.yaml", filepath.Base(p))
|
||||
}
|
||||
}
|
||||
+149
-10
@@ -13,31 +13,41 @@ import (
|
||||
type Client struct {
|
||||
baseURL string
|
||||
token string
|
||||
tokenVersion int
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewClient(baseURL, token string) *Client {
|
||||
// NewClient creates a NetBox API client. Pass tokenVersion=0 to auto-detect
|
||||
// from the token string (1 for legacy, 2 for nbt_-prefixed tokens).
|
||||
func NewClient(baseURL, token string, tokenVersion int) *Client {
|
||||
if tokenVersion == 0 {
|
||||
tokenVersion = TokenVersion(token)
|
||||
}
|
||||
return &Client{
|
||||
baseURL: strings.TrimRight(baseURL, "/"),
|
||||
token: token,
|
||||
tokenVersion: tokenVersion,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
||||
// Search queries devices and VMs in parallel and merges the results.
|
||||
func (c *Client) Search(ctx context.Context, query string) ([]HostEntry, error) {
|
||||
// Search queries up to 50 devices and VMs in parallel and merges the results.
|
||||
// Use SearchOptions to restrict by kind or tag.
|
||||
func (c *Client) Search(ctx context.Context, query string, opts SearchOptions) ([]HostEntry, error) {
|
||||
var (
|
||||
mu sync.Mutex
|
||||
results []HostEntry
|
||||
errs []error
|
||||
wg sync.WaitGroup
|
||||
started int
|
||||
)
|
||||
|
||||
wg.Add(2)
|
||||
|
||||
if opts.Kind != "vm" {
|
||||
started++
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
devices, err := c.searchDevices(ctx, query)
|
||||
devices, err := c.searchDevices(ctx, query, opts.Tag)
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if err != nil {
|
||||
@@ -46,10 +56,14 @@ func (c *Client) Search(ctx context.Context, query string) ([]HostEntry, error)
|
||||
}
|
||||
results = append(results, devices...)
|
||||
}()
|
||||
}
|
||||
|
||||
if opts.Kind != "device" {
|
||||
started++
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
vms, err := c.searchVMs(ctx, query)
|
||||
vms, err := c.searchVMs(ctx, query, opts.Tag)
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if err != nil {
|
||||
@@ -58,13 +72,70 @@ func (c *Client) Search(ctx context.Context, query string) ([]HostEntry, error)
|
||||
}
|
||||
results = append(results, vms...)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if len(errs) == 2 {
|
||||
if len(errs) == started {
|
||||
if started == 1 {
|
||||
return nil, errs[0]
|
||||
}
|
||||
return nil, fmt.Errorf("netbox search failed: %v; %v", errs[0], errs[1])
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// SearchAll paginates through all matching devices and VMs, fetching every page.
|
||||
// Intended for cache refresh; use Search for interactive queries.
|
||||
func (c *Client) SearchAll(ctx context.Context, query string, opts SearchOptions) ([]HostEntry, error) {
|
||||
var (
|
||||
mu sync.Mutex
|
||||
results []HostEntry
|
||||
errs []error
|
||||
wg sync.WaitGroup
|
||||
started int
|
||||
)
|
||||
|
||||
if opts.Kind != "vm" {
|
||||
started++
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
devices, err := c.fetchAllDevices(ctx, query, opts.Tag)
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("devices: %w", err))
|
||||
return
|
||||
}
|
||||
results = append(results, devices...)
|
||||
}()
|
||||
}
|
||||
|
||||
if opts.Kind != "device" {
|
||||
started++
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
vms, err := c.fetchAllVMs(ctx, query, opts.Tag)
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("vms: %w", err))
|
||||
return
|
||||
}
|
||||
results = append(results, vms...)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if len(errs) == started {
|
||||
if started == 1 {
|
||||
return nil, errs[0]
|
||||
}
|
||||
return nil, fmt.Errorf("netbox search failed: %v; %v", errs[0], errs[1])
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
@@ -107,8 +178,11 @@ func (c *Client) GetIPsWithFilter(ctx context.Context, filterParams string) ([]s
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
func (c *Client) searchDevices(ctx context.Context, query string) ([]HostEntry, error) {
|
||||
func (c *Client) searchDevices(ctx context.Context, query, tag string) ([]HostEntry, error) {
|
||||
apiURL := fmt.Sprintf("%s/api/dcim/devices/?name__ic=%s&limit=50", c.baseURL, url.QueryEscape(query))
|
||||
if tag != "" {
|
||||
apiURL += "&tag=" + url.QueryEscape(tag)
|
||||
}
|
||||
var resp netboxListResponse[netboxDevice]
|
||||
if err := c.get(ctx, apiURL, &resp); err != nil {
|
||||
return nil, err
|
||||
@@ -120,8 +194,11 @@ func (c *Client) searchDevices(ctx context.Context, query string) ([]HostEntry,
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func (c *Client) searchVMs(ctx context.Context, query string) ([]HostEntry, error) {
|
||||
func (c *Client) searchVMs(ctx context.Context, query, tag string) ([]HostEntry, error) {
|
||||
apiURL := fmt.Sprintf("%s/api/virtualization/virtual-machines/?name__ic=%s&limit=50", c.baseURL, url.QueryEscape(query))
|
||||
if tag != "" {
|
||||
apiURL += "&tag=" + url.QueryEscape(tag)
|
||||
}
|
||||
var resp netboxListResponse[netboxVM]
|
||||
if err := c.get(ctx, apiURL, &resp); err != nil {
|
||||
return nil, err
|
||||
@@ -133,6 +210,60 @@ func (c *Client) searchVMs(ctx context.Context, query string) ([]HostEntry, erro
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func (c *Client) fetchAllDevices(ctx context.Context, query, tag string) ([]HostEntry, error) {
|
||||
const pageSize = 50
|
||||
var all []HostEntry
|
||||
for offset := 0; ; offset += pageSize {
|
||||
apiURL := fmt.Sprintf("%s/api/dcim/devices/?name__ic=%s&limit=%d&offset=%d",
|
||||
c.baseURL, url.QueryEscape(query), pageSize, offset)
|
||||
if tag != "" {
|
||||
apiURL += "&tag=" + url.QueryEscape(tag)
|
||||
}
|
||||
var resp netboxListResponse[netboxDevice]
|
||||
if err := c.get(ctx, apiURL, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, d := range resp.Results {
|
||||
all = append(all, deviceToEntry(d))
|
||||
}
|
||||
if len(resp.Results) == 0 || len(all) >= resp.Count {
|
||||
break
|
||||
}
|
||||
}
|
||||
return all, nil
|
||||
}
|
||||
|
||||
func (c *Client) fetchAllVMs(ctx context.Context, query, tag string) ([]HostEntry, error) {
|
||||
const pageSize = 50
|
||||
var all []HostEntry
|
||||
for offset := 0; ; offset += pageSize {
|
||||
apiURL := fmt.Sprintf("%s/api/virtualization/virtual-machines/?name__ic=%s&limit=%d&offset=%d",
|
||||
c.baseURL, url.QueryEscape(query), pageSize, offset)
|
||||
if tag != "" {
|
||||
apiURL += "&tag=" + url.QueryEscape(tag)
|
||||
}
|
||||
var resp netboxListResponse[netboxVM]
|
||||
if err := c.get(ctx, apiURL, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, v := range resp.Results {
|
||||
all = append(all, vmToEntry(v))
|
||||
}
|
||||
if len(resp.Results) == 0 || len(all) >= resp.Count {
|
||||
break
|
||||
}
|
||||
}
|
||||
return all, nil
|
||||
}
|
||||
|
||||
// TokenVersion returns 2 for NetBox v2 tokens (nbt_ prefix) or 1 for legacy tokens.
|
||||
func TokenVersion(token string) int {
|
||||
if strings.HasPrefix(token, "nbt_") {
|
||||
return 2
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func (c *Client) get(ctx context.Context, apiURL string, out any) error {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, apiURL, nil)
|
||||
if err != nil {
|
||||
@@ -147,6 +278,14 @@ func (c *Client) get(ctx context.Context, apiURL string, out any) error {
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusForbidden {
|
||||
hint := "check token permissions in NetBox"
|
||||
if c.tokenVersion == 1 {
|
||||
hint += " — legacy v1 token detected, consider upgrading to a v2 token (starts with nbt_)"
|
||||
}
|
||||
return fmt.Errorf("%s: %s", apiURL, hint)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("netbox returned %d for %s", resp.StatusCode, apiURL)
|
||||
}
|
||||
|
||||
+191
-16
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -58,8 +59,8 @@ func TestSearch_ReturnsBothDevicesAndVMs(t *testing.T) {
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "token")
|
||||
results, err := c.Search(context.Background(), "")
|
||||
c := NewClient(srv.URL, "token", 0)
|
||||
results, err := c.Search(context.Background(), "", SearchOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("Search: %v", err)
|
||||
}
|
||||
@@ -87,8 +88,8 @@ func TestSearch_MapsKindCorrectly(t *testing.T) {
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "token")
|
||||
results, _ := c.Search(context.Background(), "")
|
||||
c := NewClient(srv.URL, "token", 0)
|
||||
results, _ := c.Search(context.Background(), "", SearchOptions{})
|
||||
|
||||
for _, r := range results {
|
||||
switch r.Name {
|
||||
@@ -113,8 +114,8 @@ func TestSearch_StripsPrefixFromPrimaryIP(t *testing.T) {
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "token")
|
||||
results, _ := c.Search(context.Background(), "host")
|
||||
c := NewClient(srv.URL, "token", 0)
|
||||
results, _ := c.Search(context.Background(), "host", SearchOptions{})
|
||||
if len(results) == 0 {
|
||||
t.Fatal("expected at least one result")
|
||||
}
|
||||
@@ -138,8 +139,8 @@ func TestSearch_TagsAreMapped(t *testing.T) {
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "token")
|
||||
results, _ := c.Search(context.Background(), "")
|
||||
c := NewClient(srv.URL, "token", 0)
|
||||
results, _ := c.Search(context.Background(), "", SearchOptions{})
|
||||
if len(results[0].Tags) != 2 {
|
||||
t.Errorf("tags: got %v, want [prod mgmt]", results[0].Tags)
|
||||
}
|
||||
@@ -159,8 +160,8 @@ func TestSearch_PartialFailure_ReturnsAvailableResults(t *testing.T) {
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "token")
|
||||
results, err := c.Search(context.Background(), "")
|
||||
c := NewClient(srv.URL, "token", 0)
|
||||
results, err := c.Search(context.Background(), "", SearchOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("partial failure should not return error, got: %v", err)
|
||||
}
|
||||
@@ -177,8 +178,8 @@ func TestSearch_BothFail_ReturnsError(t *testing.T) {
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "token")
|
||||
_, err := c.Search(context.Background(), "")
|
||||
c := NewClient(srv.URL, "token", 0)
|
||||
_, err := c.Search(context.Background(), "", SearchOptions{})
|
||||
if err == nil {
|
||||
t.Error("both endpoints failing should return an error")
|
||||
}
|
||||
@@ -190,7 +191,7 @@ func TestGetIPs_Device(t *testing.T) {
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "token")
|
||||
c := NewClient(srv.URL, "token", 0)
|
||||
ips, err := c.GetIPs(context.Background(), HostEntry{ID: 1, Kind: "device"})
|
||||
if err != nil {
|
||||
t.Fatalf("GetIPs: %v", err)
|
||||
@@ -209,7 +210,7 @@ func TestGetIPs_VM(t *testing.T) {
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "token")
|
||||
c := NewClient(srv.URL, "token", 0)
|
||||
ips, err := c.GetIPs(context.Background(), HostEntry{ID: 2, Kind: "vm"})
|
||||
if err != nil {
|
||||
t.Fatalf("GetIPs: %v", err)
|
||||
@@ -220,7 +221,7 @@ func TestGetIPs_VM(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetIPs_UnknownKind(t *testing.T) {
|
||||
c := NewClient("http://localhost", "token")
|
||||
c := NewClient("http://localhost", "token", 0)
|
||||
_, err := c.GetIPs(context.Background(), HostEntry{ID: 1, Kind: "unknown"})
|
||||
if err == nil {
|
||||
t.Error("unknown kind should return an error")
|
||||
@@ -233,7 +234,7 @@ func TestGetIPsWithFilter(t *testing.T) {
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "token")
|
||||
c := NewClient(srv.URL, "token", 0)
|
||||
ips, err := c.GetIPsWithFilter(context.Background(), "device_id=1&interface_name=mgmt0")
|
||||
if err != nil {
|
||||
t.Fatalf("GetIPsWithFilter: %v", err)
|
||||
@@ -243,6 +244,180 @@ func TestGetIPsWithFilter(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
token string
|
||||
want int
|
||||
}{
|
||||
{"nbt_abc123", 2},
|
||||
{"nbt_", 2},
|
||||
{"abc123def456", 1},
|
||||
{"", 1},
|
||||
{"Token abc", 1},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := TokenVersion(tt.token); got != tt.want {
|
||||
t.Errorf("TokenVersion(%q) = %d, want %d", tt.token, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClient_AutoDetectsVersion(t *testing.T) {
|
||||
c := NewClient("http://localhost", "nbt_secret", 0)
|
||||
if c.tokenVersion != 2 {
|
||||
t.Errorf("tokenVersion: got %d, want 2", c.tokenVersion)
|
||||
}
|
||||
|
||||
c2 := NewClient("http://localhost", "legacytoken", 0)
|
||||
if c2.tokenVersion != 1 {
|
||||
t.Errorf("tokenVersion: got %d, want 1", c2.tokenVersion)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClient_RespectsExplicitVersion(t *testing.T) {
|
||||
// Explicit version overrides auto-detection.
|
||||
c := NewClient("http://localhost", "legacytoken", 2)
|
||||
if c.tokenVersion != 2 {
|
||||
t.Errorf("tokenVersion: got %d, want 2", c.tokenVersion)
|
||||
}
|
||||
}
|
||||
|
||||
func Test403_V1Token_HintsUpgrade(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "forbidden", http.StatusForbidden)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "legacytoken", 1)
|
||||
_, err := c.Search(context.Background(), "host", SearchOptions{})
|
||||
if err == nil {
|
||||
t.Fatal("expected error on 403")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "v1 token") {
|
||||
t.Errorf("expected v1 hint in error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func Test403_V2Token_NoV1Hint(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "forbidden", http.StatusForbidden)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "nbt_secret", 2)
|
||||
_, err := c.Search(context.Background(), "host", SearchOptions{})
|
||||
if err == nil {
|
||||
t.Fatal("expected error on 403")
|
||||
}
|
||||
if strings.Contains(err.Error(), "v1 token") {
|
||||
t.Errorf("v1 hint should not appear for v2 token, got: %v", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "check token permissions") {
|
||||
t.Errorf("expected permissions hint in error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGet_SendsAuthorizationHeader(t *testing.T) {
|
||||
var gotAuth string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotAuth = r.Header.Get("Authorization")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
b, _ := json.Marshal(deviceListResponse())
|
||||
w.Write(b)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "nbt_mytoken", 2)
|
||||
c.Search(context.Background(), "", SearchOptions{}) //nolint:errcheck
|
||||
|
||||
want := "Token nbt_mytoken"
|
||||
if gotAuth != want {
|
||||
t.Errorf("Authorization header: got %q, want %q", gotAuth, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchAll_PaginatesResults(t *testing.T) {
|
||||
// Simulate a NetBox endpoint with 3 total devices split across 2 pages.
|
||||
// count=3 throughout; first page has 2 results, second has 1.
|
||||
callCount := 0
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/api/dcim/devices/", func(w http.ResponseWriter, r *http.Request) {
|
||||
callCount++
|
||||
offset := r.URL.Query().Get("offset")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
var resp netboxListResponse[netboxDevice]
|
||||
switch offset {
|
||||
case "", "0":
|
||||
resp = netboxListResponse[netboxDevice]{
|
||||
Count: 3,
|
||||
Results: []netboxDevice{{ID: 1, Name: "d-01"}, {ID: 2, Name: "d-02"}},
|
||||
}
|
||||
default:
|
||||
resp = netboxListResponse[netboxDevice]{
|
||||
Count: 3,
|
||||
Results: []netboxDevice{{ID: 3, Name: "d-03"}},
|
||||
}
|
||||
}
|
||||
b, _ := json.Marshal(resp)
|
||||
w.Write(b)
|
||||
})
|
||||
mux.HandleFunc("/api/virtualization/virtual-machines/", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
b, _ := json.Marshal(vmListResponse())
|
||||
w.Write(b)
|
||||
})
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "token", 0)
|
||||
results, err := c.SearchAll(context.Background(), "", SearchOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("SearchAll: %v", err)
|
||||
}
|
||||
if len(results) != 3 {
|
||||
t.Errorf("SearchAll: got %d results, want 3", len(results))
|
||||
}
|
||||
if callCount < 2 {
|
||||
t.Errorf("expected at least 2 device API calls for pagination, got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_KindFilterDeviceOnly(t *testing.T) {
|
||||
srv := newTestServer(t, map[string]any{
|
||||
"/api/dcim/devices/": deviceListResponse(
|
||||
netboxDevice{ID: 1, Name: "sw-01"},
|
||||
),
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "token", 0)
|
||||
results, err := c.Search(context.Background(), "", SearchOptions{Kind: "device"})
|
||||
if err != nil {
|
||||
t.Fatalf("Search: %v", err)
|
||||
}
|
||||
if len(results) != 1 || results[0].Name != "sw-01" {
|
||||
t.Errorf("expected 1 device result, got %v", results)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_KindFilterVMOnly(t *testing.T) {
|
||||
srv := newTestServer(t, map[string]any{
|
||||
"/api/virtualization/virtual-machines/": vmListResponse(
|
||||
netboxVM{ID: 1, Name: "vm-01"},
|
||||
),
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "token", 0)
|
||||
results, err := c.Search(context.Background(), "", SearchOptions{Kind: "vm"})
|
||||
if err != nil {
|
||||
t.Fatalf("Search: %v", err)
|
||||
}
|
||||
if len(results) != 1 || results[0].Name != "vm-01" {
|
||||
t.Errorf("expected 1 vm result, got %v", results)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
in string
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
package netbox
|
||||
|
||||
// SearchOptions filters a Search or SearchAll query.
|
||||
// Zero value means no filtering (return all kinds, no tag filter).
|
||||
type SearchOptions struct {
|
||||
Tag string // filter by tag slug; empty = no filter
|
||||
Kind string // "device" | "vm" | "" (both)
|
||||
}
|
||||
|
||||
// HostEntry is a unified model for both devices and virtual machines from NetBox.
|
||||
type HostEntry struct {
|
||||
ID int
|
||||
|
||||
@@ -36,7 +36,7 @@ func TestManagementSubnetStrategy_MatchesSubnet(t *testing.T) {
|
||||
defer srv.Close()
|
||||
|
||||
s, _ := NewManagementSubnetStrategy([]string{"10.0.0.0/8"})
|
||||
client := netbox.NewClient(srv.URL, "token")
|
||||
client := netbox.NewClient(srv.URL, "token", 0)
|
||||
|
||||
ip, err := s.Resolve(context.Background(), &netbox.HostEntry{ID: 1, Kind: "device"}, client)
|
||||
if err != nil {
|
||||
@@ -52,7 +52,7 @@ func TestManagementSubnetStrategy_NoMatch(t *testing.T) {
|
||||
defer srv.Close()
|
||||
|
||||
s, _ := NewManagementSubnetStrategy([]string{"10.0.0.0/8"})
|
||||
client := netbox.NewClient(srv.URL, "token")
|
||||
client := netbox.NewClient(srv.URL, "token", 0)
|
||||
|
||||
_, err := s.Resolve(context.Background(), &netbox.HostEntry{ID: 1, Kind: "device"}, client)
|
||||
if err != ErrNoIP {
|
||||
@@ -65,7 +65,7 @@ func TestManagementSubnetStrategy_FirstMatchWins(t *testing.T) {
|
||||
defer srv.Close()
|
||||
|
||||
s, _ := NewManagementSubnetStrategy([]string{"10.0.0.0/8"})
|
||||
client := netbox.NewClient(srv.URL, "token")
|
||||
client := netbox.NewClient(srv.URL, "token", 0)
|
||||
|
||||
ip, err := s.Resolve(context.Background(), &netbox.HostEntry{ID: 1, Kind: "device"}, client)
|
||||
if err != nil {
|
||||
@@ -81,7 +81,7 @@ func TestManagementSubnetStrategy_VMKind(t *testing.T) {
|
||||
defer srv.Close()
|
||||
|
||||
s, _ := NewManagementSubnetStrategy([]string{"172.16.0.0/12"})
|
||||
client := netbox.NewClient(srv.URL, "token")
|
||||
client := netbox.NewClient(srv.URL, "token", 0)
|
||||
|
||||
ip, err := s.Resolve(context.Background(), &netbox.HostEntry{ID: 2, Kind: "vm"}, client)
|
||||
if err != nil {
|
||||
@@ -97,7 +97,7 @@ func TestManagementSubnetStrategy_IPv6Subnet(t *testing.T) {
|
||||
defer srv.Close()
|
||||
|
||||
s, _ := NewManagementSubnetStrategy([]string{"fd00::/8"})
|
||||
client := netbox.NewClient(srv.URL, "token")
|
||||
client := netbox.NewClient(srv.URL, "token", 0)
|
||||
|
||||
ip, err := s.Resolve(context.Background(), &netbox.HostEntry{ID: 1, Kind: "device"}, client)
|
||||
if err != nil {
|
||||
|
||||
+46
-11
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/charmbracelet/huh"
|
||||
|
||||
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/config"
|
||||
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/netbox"
|
||||
)
|
||||
|
||||
// RunWizard runs the interactive setup form, pre-filled with any existing cfg values.
|
||||
@@ -19,13 +20,13 @@ func RunWizard(cfg *config.Config) error {
|
||||
url := cfg.NetBox.URL
|
||||
token := cfg.NetBox.Token
|
||||
defaultUser := cfg.SSH.DefaultUser
|
||||
strategies := cfg.Resolver.Strategies
|
||||
strategiesRaw := strings.Join(cfg.Resolver.Strategies, ", ")
|
||||
subnets := strings.Join(cfg.Resolver.ManagementSubnets, ", ")
|
||||
interfaceName := cfg.Resolver.InterfaceName
|
||||
cacheTTL := strconv.Itoa(cfg.Cache.TTL)
|
||||
|
||||
if len(strategies) == 0 {
|
||||
strategies = []string{"primary_ip"}
|
||||
if strategiesRaw == "" {
|
||||
strategiesRaw = "primary_ip"
|
||||
}
|
||||
if cacheTTL == "0" {
|
||||
cacheTTL = "3600"
|
||||
@@ -64,15 +65,12 @@ func RunWizard(cfg *config.Config) error {
|
||||
).Title("SSH defaults"),
|
||||
|
||||
huh.NewGroup(
|
||||
huh.NewMultiSelect[string]().
|
||||
huh.NewInput().
|
||||
Title("Resolver strategies").
|
||||
Description("Order matters: first match wins.").
|
||||
Options(
|
||||
huh.NewOption("primary_ip — NetBox primary IPv4/IPv6", "primary_ip"),
|
||||
huh.NewOption("management_subnet — first IP inside a subnet", "management_subnet"),
|
||||
huh.NewOption("interface_name — IP on a named interface", "interface_name"),
|
||||
).
|
||||
Value(&strategies),
|
||||
Description("Comma-separated, in priority order. First match wins.\nAvailable: primary_ip, management_subnet, interface_name").
|
||||
Placeholder("primary_ip, management_subnet").
|
||||
Value(&strategiesRaw).
|
||||
Validate(validateStrategies),
|
||||
huh.NewInput().
|
||||
Title("Management subnets").
|
||||
Description("Comma-separated CIDRs, e.g. 10.0.0.0/8, 192.168.0.0/16\nOnly used when management_subnet strategy is active.").
|
||||
@@ -102,6 +100,13 @@ func RunWizard(cfg *config.Config) error {
|
||||
return err
|
||||
}
|
||||
|
||||
strategies := parseStrategies(strategiesRaw)
|
||||
tokenVersion := netbox.TokenVersion(token)
|
||||
if tokenVersion == 1 {
|
||||
fmt.Fprintln(os.Stderr, "\nHinweis: Du verwendest einen Legacy-Token (v1). Erstelle in NetBox einen v2-Token (beginnt mit nbt_) für bessere Kompatibilität.")
|
||||
fmt.Fprintln(os.Stderr, " NetBox → Admin → API Tokens → Add Token")
|
||||
}
|
||||
|
||||
ttl, _ := strconv.Atoi(cacheTTL)
|
||||
|
||||
var subnetList []string
|
||||
@@ -115,6 +120,7 @@ func RunWizard(cfg *config.Config) error {
|
||||
NetBox: config.NetBoxConfig{
|
||||
URL: strings.TrimRight(strings.TrimSpace(url), "/"),
|
||||
Token: strings.TrimSpace(token),
|
||||
TokenVersion: tokenVersion,
|
||||
},
|
||||
SSH: config.SSHConfig{
|
||||
DefaultUser: strings.TrimSpace(defaultUser),
|
||||
@@ -132,6 +138,34 @@ func RunWizard(cfg *config.Config) error {
|
||||
return save(out)
|
||||
}
|
||||
|
||||
var knownStrategies = map[string]bool{
|
||||
"primary_ip": true,
|
||||
"management_subnet": true,
|
||||
"interface_name": true,
|
||||
}
|
||||
|
||||
func validateStrategies(s string) error {
|
||||
if strings.TrimSpace(s) == "" {
|
||||
return errors.New("at least one strategy is required")
|
||||
}
|
||||
for _, name := range parseStrategies(s) {
|
||||
if !knownStrategies[name] {
|
||||
return fmt.Errorf("unknown strategy %q — available: primary_ip, management_subnet, interface_name", name)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseStrategies(s string) []string {
|
||||
var out []string
|
||||
for _, part := range strings.Split(s, ",") {
|
||||
if name := strings.TrimSpace(part); name != "" {
|
||||
out = append(out, name)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func save(cfg config.Config) error {
|
||||
path := config.Path()
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil {
|
||||
@@ -142,6 +176,7 @@ func save(cfg config.Config) error {
|
||||
b.WriteString("netbox:\n")
|
||||
b.WriteString(fmt.Sprintf(" url: %q\n", cfg.NetBox.URL))
|
||||
b.WriteString(fmt.Sprintf(" token: %q\n", cfg.NetBox.Token))
|
||||
fmt.Fprintf(&b, " token_version: %d\n", cfg.NetBox.TokenVersion)
|
||||
|
||||
b.WriteString("\nresolver:\n")
|
||||
b.WriteString(" strategies:\n")
|
||||
|
||||
@@ -0,0 +1,226 @@
|
||||
package setup
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/config"
|
||||
)
|
||||
|
||||
func TestSave_WritesFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
orig := os.Getenv("XDG_CONFIG_HOME")
|
||||
os.Setenv("XDG_CONFIG_HOME", dir)
|
||||
defer os.Setenv("XDG_CONFIG_HOME", orig)
|
||||
|
||||
cfg := config.Config{
|
||||
NetBox: config.NetBoxConfig{
|
||||
URL: "https://netbox.example.com",
|
||||
Token: "nbt_abc123",
|
||||
TokenVersion: 2,
|
||||
},
|
||||
SSH: config.SSHConfig{DefaultUser: "admin"},
|
||||
Resolver: config.ResolverConfig{
|
||||
Strategies: []string{"primary_ip", "management_subnet"},
|
||||
ManagementSubnets: []string{"10.0.0.0/8"},
|
||||
},
|
||||
Cache: config.CacheConfig{TTL: 3600},
|
||||
}
|
||||
|
||||
if err := save(cfg); err != nil {
|
||||
t.Fatalf("save: %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(dir, "netssh.yaml"))
|
||||
if err != nil {
|
||||
t.Fatalf("reading saved file: %v", err)
|
||||
}
|
||||
content := string(data)
|
||||
|
||||
for _, want := range []string{
|
||||
`"https://netbox.example.com"`,
|
||||
`"nbt_abc123"`,
|
||||
`token_version: 2`,
|
||||
`- primary_ip`,
|
||||
`- management_subnet`,
|
||||
`- 10.0.0.0/8`,
|
||||
`ttl: 3600`,
|
||||
`"admin"`,
|
||||
} {
|
||||
if !strings.Contains(content, want) {
|
||||
t.Errorf("saved config missing %q\nfull content:\n%s", want, content)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSave_FilePermissions(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
orig := os.Getenv("XDG_CONFIG_HOME")
|
||||
os.Setenv("XDG_CONFIG_HOME", dir)
|
||||
defer os.Setenv("XDG_CONFIG_HOME", orig)
|
||||
|
||||
if err := save(config.Config{
|
||||
NetBox: config.NetBoxConfig{URL: "http://x", Token: "t", TokenVersion: 1},
|
||||
Cache: config.CacheConfig{TTL: 60},
|
||||
}); err != nil {
|
||||
t.Fatalf("save: %v", err)
|
||||
}
|
||||
|
||||
info, err := os.Stat(filepath.Join(dir, "netssh.yaml"))
|
||||
if err != nil {
|
||||
t.Fatalf("stat: %v", err)
|
||||
}
|
||||
if perm := info.Mode().Perm(); perm != 0o600 {
|
||||
t.Errorf("file permissions: got %o, want 600", perm)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSave_OmitsEmptyOptionalFields(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
orig := os.Getenv("XDG_CONFIG_HOME")
|
||||
os.Setenv("XDG_CONFIG_HOME", dir)
|
||||
defer os.Setenv("XDG_CONFIG_HOME", orig)
|
||||
|
||||
cfg := config.Config{
|
||||
NetBox: config.NetBoxConfig{URL: "http://x", Token: "t", TokenVersion: 1},
|
||||
Cache: config.CacheConfig{TTL: 60},
|
||||
// No DefaultUser, no ManagementSubnets, no InterfaceName
|
||||
}
|
||||
if err := save(cfg); err != nil {
|
||||
t.Fatalf("save: %v", err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(filepath.Join(dir, "netssh.yaml"))
|
||||
content := string(data)
|
||||
|
||||
for _, absent := range []string{"default_user", "management_subnets", "interface_name"} {
|
||||
if strings.Contains(content, absent) {
|
||||
t.Errorf("config should not contain %q when field is empty\nfull content:\n%s", absent, content)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSave_CreatesConfigDir(t *testing.T) {
|
||||
dir := filepath.Join(t.TempDir(), "does", "not", "exist")
|
||||
orig := os.Getenv("XDG_CONFIG_HOME")
|
||||
os.Setenv("XDG_CONFIG_HOME", dir)
|
||||
defer os.Setenv("XDG_CONFIG_HOME", orig)
|
||||
|
||||
if err := save(config.Config{
|
||||
NetBox: config.NetBoxConfig{URL: "http://x", Token: "t", TokenVersion: 1},
|
||||
Cache: config.CacheConfig{TTL: 60},
|
||||
}); err != nil {
|
||||
t.Fatalf("save should create missing directories: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStrategies(t *testing.T) {
|
||||
tests := []struct {
|
||||
in string
|
||||
want []string
|
||||
}{
|
||||
{"primary_ip", []string{"primary_ip"}},
|
||||
{"management_subnet, primary_ip", []string{"management_subnet", "primary_ip"}},
|
||||
{"primary_ip,management_subnet,interface_name", []string{"primary_ip", "management_subnet", "interface_name"}},
|
||||
{" primary_ip , management_subnet ", []string{"primary_ip", "management_subnet"}},
|
||||
{"", nil},
|
||||
{" , ", nil},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := parseStrategies(tt.in)
|
||||
if len(got) != len(tt.want) {
|
||||
t.Errorf("parseStrategies(%q): got %v, want %v", tt.in, got, tt.want)
|
||||
continue
|
||||
}
|
||||
for i := range got {
|
||||
if got[i] != tt.want[i] {
|
||||
t.Errorf("parseStrategies(%q)[%d]: got %q, want %q", tt.in, i, got[i], tt.want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStrategies_PreservesOrder(t *testing.T) {
|
||||
got := parseStrategies("interface_name, management_subnet, primary_ip")
|
||||
want := []string{"interface_name", "management_subnet", "primary_ip"}
|
||||
for i, s := range got {
|
||||
if s != want[i] {
|
||||
t.Errorf("order not preserved at [%d]: got %q, want %q", i, s, want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateStrategies_Valid(t *testing.T) {
|
||||
cases := []string{
|
||||
"primary_ip",
|
||||
"management_subnet, primary_ip",
|
||||
"interface_name, management_subnet, primary_ip",
|
||||
}
|
||||
for _, c := range cases {
|
||||
if err := validateStrategies(c); err != nil {
|
||||
t.Errorf("validateStrategies(%q) should be valid, got: %v", c, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateStrategies_Invalid(t *testing.T) {
|
||||
cases := []string{
|
||||
"",
|
||||
"unknown_strategy",
|
||||
"primary_ip, typo",
|
||||
}
|
||||
for _, c := range cases {
|
||||
if err := validateStrategies(c); err == nil {
|
||||
t.Errorf("validateStrategies(%q) should return an error", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSave_RoundtripViaLoad(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
orig := os.Getenv("XDG_CONFIG_HOME")
|
||||
os.Setenv("XDG_CONFIG_HOME", dir)
|
||||
defer os.Setenv("XDG_CONFIG_HOME", orig)
|
||||
|
||||
original := config.Config{
|
||||
NetBox: config.NetBoxConfig{
|
||||
URL: "https://netbox.zb-server.de",
|
||||
Token: "nbt_supersecret",
|
||||
TokenVersion: 2,
|
||||
},
|
||||
SSH: config.SSHConfig{DefaultUser: "root"},
|
||||
Resolver: config.ResolverConfig{
|
||||
Strategies: []string{"primary_ip"},
|
||||
ManagementSubnets: []string{"192.168.0.0/16"},
|
||||
InterfaceName: "eth0",
|
||||
},
|
||||
Cache: config.CacheConfig{TTL: 7200},
|
||||
}
|
||||
|
||||
if err := save(original); err != nil {
|
||||
t.Fatalf("save: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := config.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load after save: %v", err)
|
||||
}
|
||||
|
||||
if loaded.NetBox.URL != original.NetBox.URL {
|
||||
t.Errorf("URL: got %q, want %q", loaded.NetBox.URL, original.NetBox.URL)
|
||||
}
|
||||
if loaded.NetBox.Token != original.NetBox.Token {
|
||||
t.Errorf("Token: got %q, want %q", loaded.NetBox.Token, original.NetBox.Token)
|
||||
}
|
||||
if loaded.NetBox.TokenVersion != original.NetBox.TokenVersion {
|
||||
t.Errorf("TokenVersion: got %d, want %d", loaded.NetBox.TokenVersion, original.NetBox.TokenVersion)
|
||||
}
|
||||
if loaded.SSH.DefaultUser != original.SSH.DefaultUser {
|
||||
t.Errorf("DefaultUser: got %q, want %q", loaded.SSH.DefaultUser, original.SSH.DefaultUser)
|
||||
}
|
||||
if loaded.Cache.TTL != original.Cache.TTL {
|
||||
t.Errorf("TTL: got %d, want %d", loaded.Cache.TTL, original.Cache.TTL)
|
||||
}
|
||||
}
|
||||
+245
-17
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -20,6 +21,8 @@ import (
|
||||
type SelectedHost struct {
|
||||
Name string
|
||||
IP string
|
||||
User string // empty = use config default
|
||||
Port string // empty = use default port
|
||||
}
|
||||
|
||||
// --- bubbletea messages ---
|
||||
@@ -30,6 +33,7 @@ type searchResultMsg struct {
|
||||
query string
|
||||
entries []netbox.HostEntry
|
||||
err error
|
||||
recent bool // true when this is a recently-used list, not a search result
|
||||
}
|
||||
|
||||
// --- list item ---
|
||||
@@ -38,6 +42,7 @@ type hostItem struct {
|
||||
name string
|
||||
ip string
|
||||
kind string
|
||||
tags []string
|
||||
}
|
||||
|
||||
func (h hostItem) Title() string { return h.name }
|
||||
@@ -63,27 +68,90 @@ func (d compactDelegate) Render(w io.Writer, m list.Model, index int, item list.
|
||||
fmt.Fprintln(w, line)
|
||||
}
|
||||
|
||||
// --- filter ---
|
||||
|
||||
type filterOpts struct {
|
||||
tags []string
|
||||
kind string
|
||||
}
|
||||
|
||||
func parseFilter(s string) filterOpts {
|
||||
var f filterOpts
|
||||
for _, part := range strings.Fields(s) {
|
||||
if after, ok := strings.CutPrefix(part, "tag:"); ok {
|
||||
f.tags = append(f.tags, after)
|
||||
} else if after, ok := strings.CutPrefix(part, "kind:"); ok {
|
||||
f.kind = after
|
||||
}
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
func applyFilter(entries []netbox.HostEntry, f filterOpts) []netbox.HostEntry {
|
||||
if len(f.tags) == 0 && f.kind == "" {
|
||||
return entries
|
||||
}
|
||||
out := make([]netbox.HostEntry, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
if f.kind != "" && e.Kind != f.kind {
|
||||
continue
|
||||
}
|
||||
if len(f.tags) > 0 {
|
||||
tagSet := make(map[string]bool, len(e.Tags))
|
||||
for _, t := range e.Tags {
|
||||
tagSet[strings.ToLower(t)] = true
|
||||
}
|
||||
allMatch := true
|
||||
for _, want := range f.tags {
|
||||
if !tagSet[strings.ToLower(want)] {
|
||||
allMatch = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allMatch {
|
||||
continue
|
||||
}
|
||||
}
|
||||
out = append(out, e)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// --- Model ---
|
||||
|
||||
type Model struct {
|
||||
input textinput.Model
|
||||
filterInput textinput.Model // tag:X kind:Y filter, toggled with ctrl+f
|
||||
editInput textinput.Model // user@host:port inline editor
|
||||
list list.Model
|
||||
client *netbox.Client
|
||||
cache *cache.Cache
|
||||
defaultUser string
|
||||
lastSent string // last query sent to NetBox (or served from cache)
|
||||
lastResults []netbox.HostEntry // raw results before filter applied
|
||||
seq int // sequence number to discard stale results
|
||||
loading bool
|
||||
err error
|
||||
selected *SelectedHost
|
||||
width int
|
||||
height int
|
||||
recentMode bool // true when showing recent hosts (empty search, initial state)
|
||||
filterOpen bool // Ctrl+F toggles
|
||||
editMode bool // 'e' on selected item
|
||||
}
|
||||
|
||||
func New(client *netbox.Client, c *cache.Cache) *Model {
|
||||
func New(client *netbox.Client, c *cache.Cache, defaultUser string) *Model {
|
||||
ti := textinput.New()
|
||||
ti.Placeholder = "Search hostname…"
|
||||
ti.Focus()
|
||||
|
||||
fi := textinput.New()
|
||||
fi.Placeholder = "tag:prod kind:vm"
|
||||
fi.Prompt = "Filter: "
|
||||
|
||||
ei := textinput.New()
|
||||
ei.Prompt = ""
|
||||
|
||||
l := list.New(nil, compactDelegate{}, 0, 0)
|
||||
l.SetShowHelp(false)
|
||||
l.SetShowTitle(false)
|
||||
@@ -92,14 +160,33 @@ func New(client *netbox.Client, c *cache.Cache) *Model {
|
||||
|
||||
return &Model{
|
||||
input: ti,
|
||||
filterInput: fi,
|
||||
editInput: ei,
|
||||
list: l,
|
||||
client: client,
|
||||
cache: c,
|
||||
defaultUser: defaultUser,
|
||||
recentMode: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) Init() tea.Cmd {
|
||||
return textinput.Blink
|
||||
return tea.Batch(textinput.Blink, m.loadRecent())
|
||||
}
|
||||
|
||||
// loadRecent returns a Cmd that immediately resolves to the recently-used host list.
|
||||
func (m *Model) loadRecent() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
if m.cache == nil {
|
||||
return searchResultMsg{query: "", entries: nil, recent: true}
|
||||
}
|
||||
recent := m.cache.RecentlyUsed(10)
|
||||
entries := make([]netbox.HostEntry, len(recent))
|
||||
for i, e := range recent {
|
||||
entries[i] = netbox.HostEntry{Name: e.Name, PrimaryIP4: e.IP, Kind: e.Kind, Tags: e.Tags}
|
||||
}
|
||||
return searchResultMsg{query: "", entries: entries, recent: true}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
@@ -108,22 +195,83 @@ func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
case tea.WindowSizeMsg:
|
||||
m.width = msg.Width
|
||||
m.height = msg.Height
|
||||
m.list.SetSize(msg.Width, msg.Height-4)
|
||||
extraRows := 4
|
||||
if m.filterOpen {
|
||||
extraRows++
|
||||
}
|
||||
if m.editMode {
|
||||
extraRows += 2
|
||||
}
|
||||
m.list.SetSize(msg.Width, msg.Height-extraRows)
|
||||
return m, nil
|
||||
|
||||
case tea.KeyMsg:
|
||||
// --- Edit mode ---
|
||||
if m.editMode {
|
||||
switch msg.String() {
|
||||
case "esc":
|
||||
m.editMode = false
|
||||
m.editInput.Blur()
|
||||
m.input.Focus()
|
||||
return m, nil
|
||||
case "enter":
|
||||
if item, ok := m.list.SelectedItem().(hostItem); ok {
|
||||
m.selected = m.buildSelected(item)
|
||||
}
|
||||
return m, tea.Quit
|
||||
}
|
||||
var cmd tea.Cmd
|
||||
m.editInput, cmd = m.editInput.Update(msg)
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
// --- Filter focused ---
|
||||
if m.filterOpen {
|
||||
switch msg.String() {
|
||||
case "ctrl+f", "esc":
|
||||
m.filterOpen = false
|
||||
m.filterInput.Blur()
|
||||
m.filterInput.SetValue("")
|
||||
m.input.Focus()
|
||||
m.updateListItems(m.lastResults)
|
||||
return m, nil
|
||||
case "enter":
|
||||
m.filterOpen = false
|
||||
m.filterInput.Blur()
|
||||
m.input.Focus()
|
||||
m.updateListItems(m.lastResults)
|
||||
return m, nil
|
||||
}
|
||||
var cmd tea.Cmd
|
||||
m.filterInput, cmd = m.filterInput.Update(msg)
|
||||
m.updateListItems(m.lastResults)
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
// --- Normal mode ---
|
||||
switch msg.String() {
|
||||
case "ctrl+c", "esc":
|
||||
return m, tea.Quit
|
||||
|
||||
case "ctrl+f":
|
||||
m.filterOpen = true
|
||||
m.input.Blur()
|
||||
m.filterInput.Focus()
|
||||
return m, nil
|
||||
|
||||
case "enter":
|
||||
if item, ok := m.list.SelectedItem().(hostItem); ok {
|
||||
m.selected = &SelectedHost{Name: item.name, IP: item.ip}
|
||||
return m, tea.Quit
|
||||
}
|
||||
|
||||
case "e":
|
||||
if item, ok := m.list.SelectedItem().(hostItem); ok {
|
||||
m.openEdit(item)
|
||||
return m, nil
|
||||
}
|
||||
|
||||
case "tab":
|
||||
// Copy the top result into the search field.
|
||||
if m.list.Items() != nil && len(m.list.Items()) > 0 {
|
||||
if item, ok := m.list.Items()[0].(hostItem); ok {
|
||||
m.input.SetValue(item.name)
|
||||
@@ -134,18 +282,24 @@ func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
|
||||
case debounceMsg:
|
||||
// Only query if the input has changed since the last request.
|
||||
q := m.input.Value()
|
||||
if q == m.lastSent {
|
||||
return m, nil
|
||||
}
|
||||
m.lastSent = q
|
||||
m.recentMode = false
|
||||
m.loading = true
|
||||
m.seq++
|
||||
seq := m.seq
|
||||
return m, m.doSearch(q, seq)
|
||||
|
||||
case searchResultMsg:
|
||||
if msg.recent {
|
||||
m.recentMode = true
|
||||
m.lastResults = msg.entries
|
||||
m.updateListItems(msg.entries)
|
||||
return m, nil
|
||||
}
|
||||
if msg.query != m.lastSent {
|
||||
return m, nil // discard stale result
|
||||
}
|
||||
@@ -154,20 +308,13 @@ func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
m.err = msg.err
|
||||
return m, nil
|
||||
}
|
||||
items := make([]list.Item, len(msg.entries))
|
||||
for i, e := range msg.entries {
|
||||
ip := e.PrimaryIP4
|
||||
if ip == "" {
|
||||
ip = e.PrimaryIP6
|
||||
}
|
||||
items[i] = hostItem{name: e.Name, ip: ip, kind: e.Kind}
|
||||
}
|
||||
m.list.SetItems(items)
|
||||
m.lastResults = msg.entries
|
||||
m.updateListItems(msg.entries)
|
||||
m.err = nil
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Forward to text input and restart the debounce timer.
|
||||
// Forward to search input and restart debounce timer.
|
||||
var cmds []tea.Cmd
|
||||
var inputCmd tea.Cmd
|
||||
m.input, inputCmd = m.input.Update(msg)
|
||||
@@ -188,14 +335,33 @@ func (m *Model) View() string {
|
||||
sb.WriteString(title + "\n\n")
|
||||
sb.WriteString(m.input.View() + "\n")
|
||||
|
||||
if m.filterOpen {
|
||||
sb.WriteString(m.filterInput.View() + "\n")
|
||||
}
|
||||
|
||||
if m.loading {
|
||||
sb.WriteString(lipgloss.NewStyle().Foreground(lipgloss.Color("240")).Render(" searching…") + "\n")
|
||||
} else if m.err != nil {
|
||||
sb.WriteString(lipgloss.NewStyle().Foreground(lipgloss.Color("9")).Render(" error: "+m.err.Error()) + "\n")
|
||||
} else {
|
||||
if m.recentMode && m.input.Value() == "" {
|
||||
if len(m.list.Items()) > 0 {
|
||||
sb.WriteString(lipgloss.NewStyle().Foreground(lipgloss.Color("240")).Render(" Recent connections") + "\n")
|
||||
}
|
||||
}
|
||||
sb.WriteString(m.list.View())
|
||||
}
|
||||
|
||||
if m.editMode {
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(lipgloss.NewStyle().Foreground(lipgloss.Color("240")).Render("Connect as: "))
|
||||
sb.WriteString(m.editInput.View() + "\n")
|
||||
sb.WriteString(lipgloss.NewStyle().Foreground(lipgloss.Color("240")).Render(" enter connect esc cancel") + "\n")
|
||||
} else {
|
||||
hint := "ctrl+c quit enter connect e edit ctrl+f filter"
|
||||
sb.WriteString("\n" + lipgloss.NewStyle().Foreground(lipgloss.Color("240")).Render(hint) + "\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
@@ -211,13 +377,14 @@ func (m *Model) startDebounce() tea.Cmd {
|
||||
}
|
||||
|
||||
func (m *Model) doSearch(query string, seq int) tea.Cmd {
|
||||
opts := m.netboxSearchOpts()
|
||||
return func() tea.Msg {
|
||||
// Return cache hits immediately without a network round-trip.
|
||||
if m.cache != nil {
|
||||
if cached := m.cache.Search(query); len(cached) > 0 {
|
||||
entries := make([]netbox.HostEntry, len(cached))
|
||||
for i, c := range cached {
|
||||
entries[i] = netbox.HostEntry{Name: c.Name, PrimaryIP4: c.IP, Kind: c.Kind}
|
||||
entries[i] = netbox.HostEntry{Name: c.Name, PrimaryIP4: c.IP, Kind: c.Kind, Tags: c.Tags}
|
||||
}
|
||||
return searchResultMsg{query: query, entries: entries}
|
||||
}
|
||||
@@ -230,8 +397,69 @@ func (m *Model) doSearch(query string, seq int) tea.Cmd {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
entries, err := m.client.Search(ctx, query)
|
||||
entries, err := m.client.Search(ctx, query, opts)
|
||||
_ = seq
|
||||
return searchResultMsg{query: query, entries: entries, err: err}
|
||||
}
|
||||
}
|
||||
|
||||
// netboxSearchOpts derives SearchOptions from the current filter input.
|
||||
func (m *Model) netboxSearchOpts() netbox.SearchOptions {
|
||||
f := parseFilter(m.filterInput.Value())
|
||||
var tag string
|
||||
if len(f.tags) > 0 {
|
||||
tag = f.tags[0]
|
||||
}
|
||||
return netbox.SearchOptions{Tag: tag, Kind: f.kind}
|
||||
}
|
||||
|
||||
// updateListItems applies the active filter and sets the list items.
|
||||
func (m *Model) updateListItems(entries []netbox.HostEntry) {
|
||||
f := parseFilter(m.filterInput.Value())
|
||||
filtered := applyFilter(entries, f)
|
||||
items := make([]list.Item, len(filtered))
|
||||
for i, e := range filtered {
|
||||
ip := e.PrimaryIP4
|
||||
if ip == "" {
|
||||
ip = e.PrimaryIP6
|
||||
}
|
||||
items[i] = hostItem{name: e.Name, ip: ip, kind: e.Kind, tags: e.Tags}
|
||||
}
|
||||
m.list.SetItems(items)
|
||||
}
|
||||
|
||||
// openEdit switches to edit mode for the given list item.
|
||||
func (m *Model) openEdit(item hostItem) {
|
||||
m.editMode = true
|
||||
m.input.Blur()
|
||||
|
||||
user := m.defaultUser
|
||||
if user == "" {
|
||||
user = item.name
|
||||
}
|
||||
m.editInput.SetValue(fmt.Sprintf("%s@%s:22", user, item.name))
|
||||
m.editInput.CursorEnd()
|
||||
m.editInput.Focus()
|
||||
}
|
||||
|
||||
// buildSelected parses the editInput to extract user/port overrides.
|
||||
func (m *Model) buildSelected(item hostItem) *SelectedHost {
|
||||
s := strings.TrimSpace(m.editInput.Value())
|
||||
sel := &SelectedHost{Name: item.name, IP: item.ip}
|
||||
|
||||
// Extract user: everything before @
|
||||
if idx := strings.Index(s, "@"); idx != -1 {
|
||||
sel.User = strings.TrimSpace(s[:idx])
|
||||
s = s[idx+1:]
|
||||
}
|
||||
|
||||
// Extract port: after the last colon, if it's a number
|
||||
if idx := strings.LastIndex(s, ":"); idx != -1 {
|
||||
portStr := strings.TrimSpace(s[idx+1:])
|
||||
if _, err := strconv.Atoi(portStr); err == nil && portStr != "22" {
|
||||
sel.Port = portStr
|
||||
}
|
||||
}
|
||||
|
||||
return sel
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user