Add core modules (SSH args parser, cache, resolver, NetBox client) with tests
Release / release (push) Failing after 51s
Release / release (push) Failing after 51s
This commit is contained in:
Vendored
+134
@@ -0,0 +1,134 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Entry struct {
|
||||
Name string `json:"name"`
|
||||
IP string `json:"ip"`
|
||||
Kind string `json:"kind"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
CachedAt time.Time `json:"cached_at"`
|
||||
}
|
||||
|
||||
type Cache struct {
|
||||
mu sync.RWMutex
|
||||
entries map[string]Entry
|
||||
path string
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
type diskFormat struct {
|
||||
Entries []Entry `json:"entries"`
|
||||
}
|
||||
|
||||
func New(path string, ttlSeconds int) *Cache {
|
||||
return &Cache{
|
||||
entries: make(map[string]Entry),
|
||||
path: path,
|
||||
ttl: time.Duration(ttlSeconds) * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Cache) Load() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
data, err := os.ReadFile(c.path)
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var df diskFormat
|
||||
if err := json.Unmarshal(data, &df); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.entries = make(map[string]Entry, len(df.Entries))
|
||||
for _, e := range df.Entries {
|
||||
c.entries[e.Name] = e
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Cache) Save() error {
|
||||
c.mu.RLock()
|
||||
df := diskFormat{Entries: make([]Entry, 0, len(c.entries))}
|
||||
for _, e := range c.entries {
|
||||
df.Entries = append(df.Entries, e)
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(c.path), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(df, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(c.path, data, 0o644)
|
||||
}
|
||||
|
||||
func (c *Cache) Upsert(e Entry) {
|
||||
e.CachedAt = time.Now()
|
||||
c.mu.Lock()
|
||||
c.entries[e.Name] = e
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// Search returns all entries whose name starts with prefix (case-insensitive).
|
||||
// TTL is intentionally ignored — this is used for shell completion.
|
||||
func (c *Cache) Search(prefix string) []Entry {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
prefix = strings.ToLower(prefix)
|
||||
var out []Entry
|
||||
for name, e := range c.entries {
|
||||
if strings.HasPrefix(strings.ToLower(name), prefix) {
|
||||
out = append(out, e)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Get returns an entry and reports whether it is still within the TTL.
|
||||
func (c *Cache) Get(name string) (entry Entry, fresh bool) {
|
||||
c.mu.RLock()
|
||||
e, ok := c.entries[name]
|
||||
c.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return Entry{}, false
|
||||
}
|
||||
if c.ttl == 0 {
|
||||
return e, false
|
||||
}
|
||||
return e, time.Since(e.CachedAt) < c.ttl
|
||||
}
|
||||
|
||||
func (c *Cache) Clear() {
|
||||
c.mu.Lock()
|
||||
c.entries = make(map[string]Entry)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *Cache) All() []Entry {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
out := make([]Entry, 0, len(c.entries))
|
||||
for _, e := range c.entries {
|
||||
out = append(out, e)
|
||||
}
|
||||
return out
|
||||
}
|
||||
Vendored
+235
@@ -0,0 +1,235 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
c := New("/tmp/test.json", 60)
|
||||
if c == nil {
|
||||
t.Fatal("New returned nil")
|
||||
}
|
||||
if c.ttl != 60*time.Second {
|
||||
t.Errorf("ttl: got %v, want %v", c.ttl, 60*time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_MissingFile(t *testing.T) {
|
||||
c := New("/nonexistent/path/cache.json", 60)
|
||||
if err := c.Load(); err != nil {
|
||||
t.Errorf("Load on missing file should not error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_InvalidJSON(t *testing.T) {
|
||||
f := tempFile(t, []byte("not json"))
|
||||
c := New(f, 60)
|
||||
if err := c.Load(); err == nil {
|
||||
t.Error("Load on invalid JSON should return an error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAndLoad_Roundtrip(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "cache.json")
|
||||
c := New(path, 3600)
|
||||
|
||||
c.Upsert(Entry{Name: "host-a", IP: "10.0.0.1", Kind: "device"})
|
||||
c.Upsert(Entry{Name: "host-b", IP: "10.0.0.2", Kind: "vm", Tags: []string{"prod"}})
|
||||
|
||||
if err := c.Save(); err != nil {
|
||||
t.Fatalf("Save: %v", err)
|
||||
}
|
||||
|
||||
c2 := New(path, 3600)
|
||||
if err := c2.Load(); err != nil {
|
||||
t.Fatalf("Load: %v", err)
|
||||
}
|
||||
|
||||
e, _ := c2.Get("host-a")
|
||||
if e.IP != "10.0.0.1" {
|
||||
t.Errorf("host-a IP: got %q, want %q", e.IP, "10.0.0.1")
|
||||
}
|
||||
e2, _ := c2.Get("host-b")
|
||||
if len(e2.Tags) != 1 || e2.Tags[0] != "prod" {
|
||||
t.Errorf("host-b tags: got %v, want [prod]", e2.Tags)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSave_CreatesDirectory(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "sub", "dir", "cache.json")
|
||||
c := New(path, 60)
|
||||
c.Upsert(Entry{Name: "x", IP: "1.2.3.4", Kind: "device"})
|
||||
if err := c.Save(); err != nil {
|
||||
t.Fatalf("Save: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
t.Errorf("cache file not created: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsert_SetsTimestamp(t *testing.T) {
|
||||
c := New("", 60)
|
||||
before := time.Now()
|
||||
c.Upsert(Entry{Name: "h", IP: "1.1.1.1", Kind: "device"})
|
||||
e, _ := c.Get("h")
|
||||
if e.CachedAt.Before(before) {
|
||||
t.Error("CachedAt should be set to current time on Upsert")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsert_Overwrites(t *testing.T) {
|
||||
c := New("", 60)
|
||||
c.Upsert(Entry{Name: "host", IP: "10.0.0.1", Kind: "device"})
|
||||
c.Upsert(Entry{Name: "host", IP: "10.0.0.2", Kind: "device"})
|
||||
e, _ := c.Get("host")
|
||||
if e.IP != "10.0.0.2" {
|
||||
t.Errorf("Upsert should overwrite: got %q, want %q", e.IP, "10.0.0.2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_PrefixMatch(t *testing.T) {
|
||||
c := New("", 60)
|
||||
c.Upsert(Entry{Name: "app-server-01", IP: "10.0.0.1", Kind: "device"})
|
||||
c.Upsert(Entry{Name: "app-server-02", IP: "10.0.0.2", Kind: "vm"})
|
||||
c.Upsert(Entry{Name: "db-server-01", IP: "10.0.0.3", Kind: "device"})
|
||||
|
||||
results := c.Search("app")
|
||||
if len(results) != 2 {
|
||||
t.Errorf("Search(app): got %d results, want 2", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_CaseInsensitive(t *testing.T) {
|
||||
c := New("", 60)
|
||||
c.Upsert(Entry{Name: "App-Server", IP: "10.0.0.1", Kind: "device"})
|
||||
|
||||
if len(c.Search("app")) != 1 {
|
||||
t.Error("Search should be case-insensitive")
|
||||
}
|
||||
if len(c.Search("APP")) != 1 {
|
||||
t.Error("Search should be case-insensitive for uppercase")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_EmptyPrefix(t *testing.T) {
|
||||
c := New("", 60)
|
||||
c.Upsert(Entry{Name: "a", IP: "1.1.1.1", Kind: "device"})
|
||||
c.Upsert(Entry{Name: "b", IP: "2.2.2.2", Kind: "vm"})
|
||||
|
||||
if len(c.Search("")) != 2 {
|
||||
t.Error("Search('') should return all entries")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_NoMatch(t *testing.T) {
|
||||
c := New("", 60)
|
||||
c.Upsert(Entry{Name: "host", IP: "1.1.1.1", Kind: "device"})
|
||||
|
||||
if len(c.Search("xyz")) != 0 {
|
||||
t.Error("Search should return empty slice when no match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGet_Fresh(t *testing.T) {
|
||||
c := New("", 3600)
|
||||
c.Upsert(Entry{Name: "host", IP: "10.0.0.1", Kind: "device"})
|
||||
|
||||
e, fresh := c.Get("host")
|
||||
if !fresh {
|
||||
t.Error("entry just inserted should be fresh")
|
||||
}
|
||||
if e.IP != "10.0.0.1" {
|
||||
t.Errorf("IP: got %q, want %q", e.IP, "10.0.0.1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGet_Expired(t *testing.T) {
|
||||
c := New("", 1) // 1 second TTL
|
||||
e := Entry{Name: "host", IP: "10.0.0.1", Kind: "device", CachedAt: time.Now().Add(-2 * time.Second)}
|
||||
c.mu.Lock()
|
||||
c.entries["host"] = e
|
||||
c.mu.Unlock()
|
||||
|
||||
_, fresh := c.Get("host")
|
||||
if fresh {
|
||||
t.Error("entry older than TTL should not be fresh")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGet_ZeroTTL_AlwaysStale(t *testing.T) {
|
||||
c := New("", 0) // TTL=0 means never fresh for connect mode
|
||||
c.Upsert(Entry{Name: "host", IP: "10.0.0.1", Kind: "device"})
|
||||
|
||||
_, fresh := c.Get("host")
|
||||
if fresh {
|
||||
t.Error("TTL=0 should always return fresh=false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGet_Missing(t *testing.T) {
|
||||
c := New("", 60)
|
||||
_, fresh := c.Get("nonexistent")
|
||||
if fresh {
|
||||
t.Error("missing entry should not be fresh")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClear(t *testing.T) {
|
||||
c := New("", 60)
|
||||
c.Upsert(Entry{Name: "a", IP: "1.1.1.1", Kind: "device"})
|
||||
c.Upsert(Entry{Name: "b", IP: "2.2.2.2", Kind: "vm"})
|
||||
c.Clear()
|
||||
|
||||
if len(c.All()) != 0 {
|
||||
t.Error("Clear should remove all entries")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAll(t *testing.T) {
|
||||
c := New("", 60)
|
||||
c.Upsert(Entry{Name: "a", IP: "1.1.1.1", Kind: "device"})
|
||||
c.Upsert(Entry{Name: "b", IP: "2.2.2.2", Kind: "vm"})
|
||||
|
||||
all := c.All()
|
||||
if len(all) != 2 {
|
||||
t.Errorf("All: got %d entries, want 2", len(all))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSave_ProducesValidJSON(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "cache.json")
|
||||
c := New(path, 60)
|
||||
c.Upsert(Entry{Name: "host", IP: "10.0.0.1", Kind: "device", Tags: []string{"mgmt"}})
|
||||
|
||||
if err := c.Save(); err != nil {
|
||||
t.Fatalf("Save: %v", err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(path)
|
||||
var df diskFormat
|
||||
if err := json.Unmarshal(data, &df); err != nil {
|
||||
t.Fatalf("saved file is not valid JSON: %v", err)
|
||||
}
|
||||
if len(df.Entries) != 1 {
|
||||
t.Errorf("expected 1 entry in JSON, got %d", len(df.Entries))
|
||||
}
|
||||
}
|
||||
|
||||
// tempFile writes content to a temp file and returns its path.
|
||||
func tempFile(t *testing.T, content []byte) string {
|
||||
t.Helper()
|
||||
f, err := os.CreateTemp(t.TempDir(), "cache-*.json")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := f.Write(content); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
f.Close()
|
||||
return f.Name()
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
NetBox NetBoxConfig `mapstructure:"netbox"`
|
||||
Resolver ResolverConfig `mapstructure:"resolver"`
|
||||
Cache CacheConfig `mapstructure:"cache"`
|
||||
SSH SSHConfig `mapstructure:"ssh"`
|
||||
}
|
||||
|
||||
type NetBoxConfig struct {
|
||||
URL string `mapstructure:"url"`
|
||||
Token string `mapstructure:"token"`
|
||||
}
|
||||
|
||||
type ResolverConfig struct {
|
||||
Strategies []string `mapstructure:"strategies"`
|
||||
ManagementSubnets []string `mapstructure:"management_subnets"`
|
||||
InterfaceName string `mapstructure:"interface_name"`
|
||||
}
|
||||
|
||||
type CacheConfig struct {
|
||||
TTL int `mapstructure:"ttl"`
|
||||
Path string `mapstructure:"path"`
|
||||
}
|
||||
|
||||
type SSHConfig struct {
|
||||
DefaultUser string `mapstructure:"default_user"`
|
||||
}
|
||||
|
||||
func Load() (*Config, error) {
|
||||
v := viper.New()
|
||||
|
||||
v.SetDefault("resolver.strategies", []string{"management_subnet", "primary_ip"})
|
||||
v.SetDefault("resolver.management_subnets", []string{})
|
||||
v.SetDefault("cache.ttl", 3600)
|
||||
|
||||
configDir, err := os.UserConfigDir()
|
||||
if err == nil {
|
||||
v.SetConfigName("netssh")
|
||||
v.SetConfigType("yaml")
|
||||
v.AddConfigPath(filepath.Join(configDir))
|
||||
v.AddConfigPath(".")
|
||||
}
|
||||
|
||||
v.SetEnvPrefix("NETSSH")
|
||||
v.AutomaticEnv()
|
||||
|
||||
if err := v.ReadInConfig(); err != nil {
|
||||
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
|
||||
return nil, fmt.Errorf("reading config: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
return nil, fmt.Errorf("parsing config: %w", err)
|
||||
}
|
||||
|
||||
if cfg.Cache.Path == "" {
|
||||
cacheDir, err := os.UserCacheDir()
|
||||
if err != nil {
|
||||
cacheDir = filepath.Join(os.Getenv("HOME"), ".cache")
|
||||
}
|
||||
cfg.Cache.Path = filepath.Join(cacheDir, "netssh", "hosts.json")
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
@@ -0,0 +1,194 @@
|
||||
package netbox
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
baseURL string
|
||||
token string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewClient(baseURL, token string) *Client {
|
||||
return &Client{
|
||||
baseURL: strings.TrimRight(baseURL, "/"),
|
||||
token: token,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
||||
// Search queries devices and VMs in parallel and merges the results.
|
||||
func (c *Client) Search(ctx context.Context, query string) ([]HostEntry, error) {
|
||||
var (
|
||||
mu sync.Mutex
|
||||
results []HostEntry
|
||||
errs []error
|
||||
wg sync.WaitGroup
|
||||
)
|
||||
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
devices, err := c.searchDevices(ctx, query)
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("devices: %w", err))
|
||||
return
|
||||
}
|
||||
results = append(results, devices...)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
vms, err := c.searchVMs(ctx, query)
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("vms: %w", err))
|
||||
return
|
||||
}
|
||||
results = append(results, vms...)
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if len(errs) == 2 {
|
||||
return nil, fmt.Errorf("netbox search failed: %v; %v", errs[0], errs[1])
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetIPs returns all IP addresses assigned to a host, used by resolver strategies
|
||||
// that need more than just the primary IP.
|
||||
func (c *Client) GetIPs(ctx context.Context, entry HostEntry) ([]string, error) {
|
||||
var apiURL string
|
||||
switch entry.Kind {
|
||||
case "device":
|
||||
apiURL = fmt.Sprintf("%s/api/ipam/ip-addresses/?device_id=%d&limit=100", c.baseURL, entry.ID)
|
||||
case "vm":
|
||||
apiURL = fmt.Sprintf("%s/api/ipam/ip-addresses/?virtual_machine_id=%d&limit=100", c.baseURL, entry.ID)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown host kind: %q", entry.Kind)
|
||||
}
|
||||
|
||||
var resp netboxIPListResponse
|
||||
if err := c.get(ctx, apiURL, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ips := make([]string, 0, len(resp.Results))
|
||||
for _, r := range resp.Results {
|
||||
ips = append(ips, stripPrefix(r.Address))
|
||||
}
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
// GetIPsWithFilter calls /api/ipam/ip-addresses/ with arbitrary filter query parameters.
|
||||
func (c *Client) GetIPsWithFilter(ctx context.Context, filterParams string) ([]string, error) {
|
||||
apiURL := fmt.Sprintf("%s/api/ipam/ip-addresses/?%s&limit=100", c.baseURL, filterParams)
|
||||
var resp netboxIPListResponse
|
||||
if err := c.get(ctx, apiURL, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ips := make([]string, 0, len(resp.Results))
|
||||
for _, r := range resp.Results {
|
||||
ips = append(ips, stripPrefix(r.Address))
|
||||
}
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
func (c *Client) searchDevices(ctx context.Context, query string) ([]HostEntry, error) {
|
||||
apiURL := fmt.Sprintf("%s/api/dcim/devices/?name__ic=%s&limit=50", c.baseURL, url.QueryEscape(query))
|
||||
var resp netboxListResponse[netboxDevice]
|
||||
if err := c.get(ctx, apiURL, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entries := make([]HostEntry, 0, len(resp.Results))
|
||||
for _, d := range resp.Results {
|
||||
entries = append(entries, deviceToEntry(d))
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func (c *Client) searchVMs(ctx context.Context, query string) ([]HostEntry, error) {
|
||||
apiURL := fmt.Sprintf("%s/api/virtualization/virtual-machines/?name__ic=%s&limit=50", c.baseURL, url.QueryEscape(query))
|
||||
var resp netboxListResponse[netboxVM]
|
||||
if err := c.get(ctx, apiURL, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entries := make([]HostEntry, 0, len(resp.Results))
|
||||
for _, v := range resp.Results {
|
||||
entries = append(entries, vmToEntry(v))
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func (c *Client) get(ctx context.Context, apiURL string, out any) error {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, apiURL, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Token "+c.token)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("request to %s: %w", apiURL, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("netbox returned %d for %s", resp.StatusCode, apiURL)
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(out); err != nil {
|
||||
return fmt.Errorf("decoding response: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func deviceToEntry(d netboxDevice) HostEntry {
|
||||
e := HostEntry{ID: d.ID, Name: d.Name, Kind: "device"}
|
||||
if d.PrimaryIP4 != nil {
|
||||
e.PrimaryIP4 = stripPrefix(d.PrimaryIP4.Address)
|
||||
}
|
||||
if d.PrimaryIP6 != nil {
|
||||
e.PrimaryIP6 = stripPrefix(d.PrimaryIP6.Address)
|
||||
}
|
||||
for _, t := range d.Tags {
|
||||
e.Tags = append(e.Tags, t.Name)
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
func vmToEntry(v netboxVM) HostEntry {
|
||||
e := HostEntry{ID: v.ID, Name: v.Name, Kind: "vm"}
|
||||
if v.PrimaryIP4 != nil {
|
||||
e.PrimaryIP4 = stripPrefix(v.PrimaryIP4.Address)
|
||||
}
|
||||
if v.PrimaryIP6 != nil {
|
||||
e.PrimaryIP6 = stripPrefix(v.PrimaryIP6.Address)
|
||||
}
|
||||
for _, t := range v.Tags {
|
||||
e.Tags = append(e.Tags, t.Name)
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
// stripPrefix removes the CIDR prefix length from a NetBox IP (e.g. "10.0.1.5/24" → "10.0.1.5").
|
||||
func stripPrefix(cidr string) string {
|
||||
if idx := strings.Index(cidr, "/"); idx != -1 {
|
||||
return cidr[:idx]
|
||||
}
|
||||
return cidr
|
||||
}
|
||||
@@ -0,0 +1,261 @@
|
||||
package netbox
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// newTestServer returns an httptest.Server that serves fixed responses per path.
|
||||
func newTestServer(t *testing.T, handlers map[string]any) *httptest.Server {
|
||||
t.Helper()
|
||||
mux := http.NewServeMux()
|
||||
for path, body := range handlers {
|
||||
b, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
t.Fatalf("marshalling handler for %s: %v", path, err)
|
||||
}
|
||||
captured := b
|
||||
mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write(captured)
|
||||
})
|
||||
}
|
||||
return httptest.NewServer(mux)
|
||||
}
|
||||
|
||||
func deviceListResponse(devices ...netboxDevice) netboxListResponse[netboxDevice] {
|
||||
return netboxListResponse[netboxDevice]{Count: len(devices), Results: devices}
|
||||
}
|
||||
|
||||
func vmListResponse(vms ...netboxVM) netboxListResponse[netboxVM] {
|
||||
return netboxListResponse[netboxVM]{Count: len(vms), Results: vms}
|
||||
}
|
||||
|
||||
func ipListResponse(addrs ...string) netboxIPListResponse {
|
||||
resp := netboxIPListResponse{Count: len(addrs)}
|
||||
for _, a := range addrs {
|
||||
resp.Results = append(resp.Results, struct {
|
||||
Address string `json:"address"`
|
||||
Interface *struct {
|
||||
Name string `json:"name"`
|
||||
} `json:"assigned_object"`
|
||||
}{Address: a})
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func TestSearch_ReturnsBothDevicesAndVMs(t *testing.T) {
|
||||
srv := newTestServer(t, map[string]any{
|
||||
"/api/dcim/devices/": deviceListResponse(
|
||||
netboxDevice{ID: 1, Name: "router-01", PrimaryIP4: &netboxIP{Address: "10.0.0.1/24"}},
|
||||
),
|
||||
"/api/virtualization/virtual-machines/": vmListResponse(
|
||||
netboxVM{ID: 2, Name: "vm-01", PrimaryIP4: &netboxIP{Address: "10.0.0.2/24"}},
|
||||
),
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "token")
|
||||
results, err := c.Search(context.Background(), "")
|
||||
if err != nil {
|
||||
t.Fatalf("Search: %v", err)
|
||||
}
|
||||
if len(results) != 2 {
|
||||
t.Errorf("got %d results, want 2", len(results))
|
||||
}
|
||||
|
||||
names := map[string]bool{}
|
||||
for _, r := range results {
|
||||
names[r.Name] = true
|
||||
}
|
||||
if !names["router-01"] || !names["vm-01"] {
|
||||
t.Errorf("missing expected hosts in results: %v", names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_MapsKindCorrectly(t *testing.T) {
|
||||
srv := newTestServer(t, map[string]any{
|
||||
"/api/dcim/devices/": deviceListResponse(
|
||||
netboxDevice{ID: 1, Name: "sw-01"},
|
||||
),
|
||||
"/api/virtualization/virtual-machines/": vmListResponse(
|
||||
netboxVM{ID: 2, Name: "vm-01"},
|
||||
),
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "token")
|
||||
results, _ := c.Search(context.Background(), "")
|
||||
|
||||
for _, r := range results {
|
||||
switch r.Name {
|
||||
case "sw-01":
|
||||
if r.Kind != "device" {
|
||||
t.Errorf("sw-01 kind: got %q, want %q", r.Kind, "device")
|
||||
}
|
||||
case "vm-01":
|
||||
if r.Kind != "vm" {
|
||||
t.Errorf("vm-01 kind: got %q, want %q", r.Kind, "vm")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_StripsPrefixFromPrimaryIP(t *testing.T) {
|
||||
srv := newTestServer(t, map[string]any{
|
||||
"/api/dcim/devices/": deviceListResponse(
|
||||
netboxDevice{ID: 1, Name: "host", PrimaryIP4: &netboxIP{Address: "192.168.1.10/24"}},
|
||||
),
|
||||
"/api/virtualization/virtual-machines/": vmListResponse(),
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "token")
|
||||
results, _ := c.Search(context.Background(), "host")
|
||||
if len(results) == 0 {
|
||||
t.Fatal("expected at least one result")
|
||||
}
|
||||
if results[0].PrimaryIP4 != "192.168.1.10" {
|
||||
t.Errorf("PrimaryIP4: got %q, want %q", results[0].PrimaryIP4, "192.168.1.10")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_TagsAreMapped(t *testing.T) {
|
||||
srv := newTestServer(t, map[string]any{
|
||||
"/api/dcim/devices/": deviceListResponse(
|
||||
netboxDevice{
|
||||
ID: 1,
|
||||
Name: "host",
|
||||
Tags: []struct {
|
||||
Name string `json:"name"`
|
||||
}{{Name: "prod"}, {Name: "mgmt"}},
|
||||
},
|
||||
),
|
||||
"/api/virtualization/virtual-machines/": vmListResponse(),
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "token")
|
||||
results, _ := c.Search(context.Background(), "")
|
||||
if len(results[0].Tags) != 2 {
|
||||
t.Errorf("tags: got %v, want [prod mgmt]", results[0].Tags)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_PartialFailure_ReturnsAvailableResults(t *testing.T) {
|
||||
// Only devices endpoint works; VMs returns 500.
|
||||
mux := http.NewServeMux()
|
||||
body, _ := json.Marshal(deviceListResponse(netboxDevice{ID: 1, Name: "sw-01"}))
|
||||
mux.HandleFunc("/api/dcim/devices/", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write(body)
|
||||
})
|
||||
mux.HandleFunc("/api/virtualization/virtual-machines/", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
})
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "token")
|
||||
results, err := c.Search(context.Background(), "")
|
||||
if err != nil {
|
||||
t.Fatalf("partial failure should not return error, got: %v", err)
|
||||
}
|
||||
if len(results) != 1 || results[0].Name != "sw-01" {
|
||||
t.Errorf("expected device results, got %v", results)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_BothFail_ReturnsError(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "error", http.StatusInternalServerError)
|
||||
})
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "token")
|
||||
_, err := c.Search(context.Background(), "")
|
||||
if err == nil {
|
||||
t.Error("both endpoints failing should return an error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetIPs_Device(t *testing.T) {
|
||||
srv := newTestServer(t, map[string]any{
|
||||
"/api/ipam/ip-addresses/": ipListResponse("10.0.0.1/24", "10.0.0.2/24"),
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "token")
|
||||
ips, err := c.GetIPs(context.Background(), HostEntry{ID: 1, Kind: "device"})
|
||||
if err != nil {
|
||||
t.Fatalf("GetIPs: %v", err)
|
||||
}
|
||||
if len(ips) != 2 {
|
||||
t.Errorf("got %d IPs, want 2", len(ips))
|
||||
}
|
||||
if ips[0] != "10.0.0.1" || ips[1] != "10.0.0.2" {
|
||||
t.Errorf("IPs: got %v, want [10.0.0.1 10.0.0.2]", ips)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetIPs_VM(t *testing.T) {
|
||||
srv := newTestServer(t, map[string]any{
|
||||
"/api/ipam/ip-addresses/": ipListResponse("172.16.0.5/16"),
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "token")
|
||||
ips, err := c.GetIPs(context.Background(), HostEntry{ID: 2, Kind: "vm"})
|
||||
if err != nil {
|
||||
t.Fatalf("GetIPs: %v", err)
|
||||
}
|
||||
if len(ips) != 1 || ips[0] != "172.16.0.5" {
|
||||
t.Errorf("IPs: got %v, want [172.16.0.5]", ips)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetIPs_UnknownKind(t *testing.T) {
|
||||
c := NewClient("http://localhost", "token")
|
||||
_, err := c.GetIPs(context.Background(), HostEntry{ID: 1, Kind: "unknown"})
|
||||
if err == nil {
|
||||
t.Error("unknown kind should return an error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetIPsWithFilter(t *testing.T) {
|
||||
srv := newTestServer(t, map[string]any{
|
||||
"/api/ipam/ip-addresses/": ipListResponse("10.10.10.1/24"),
|
||||
})
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient(srv.URL, "token")
|
||||
ips, err := c.GetIPsWithFilter(context.Background(), "device_id=1&interface_name=mgmt0")
|
||||
if err != nil {
|
||||
t.Fatalf("GetIPsWithFilter: %v", err)
|
||||
}
|
||||
if len(ips) != 1 || ips[0] != "10.10.10.1" {
|
||||
t.Errorf("IPs: got %v, want [10.10.10.1]", ips)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{"10.0.0.1/24", "10.0.0.1"},
|
||||
{"::1/128", "::1"},
|
||||
{"192.168.1.1", "192.168.1.1"}, // no prefix — unchanged
|
||||
{"", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := stripPrefix(tt.in); got != tt.want {
|
||||
t.Errorf("stripPrefix(%q) = %q, want %q", tt.in, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package netbox
|
||||
|
||||
// HostEntry is a unified model for both devices and virtual machines from NetBox.
|
||||
type HostEntry struct {
|
||||
ID int
|
||||
Name string
|
||||
Kind string // "device" | "vm"
|
||||
PrimaryIP4 string // e.g. "10.0.1.5" (prefix length stripped)
|
||||
PrimaryIP6 string
|
||||
Tags []string
|
||||
}
|
||||
|
||||
// netboxIP represents an IP address as returned by the NetBox API.
|
||||
type netboxIP struct {
|
||||
Address string `json:"address"` // CIDR notation, e.g. "10.0.1.5/24"
|
||||
}
|
||||
|
||||
// netboxDevice matches the relevant fields of the NetBox /dcim/devices/ response.
|
||||
type netboxDevice struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Tags []struct {
|
||||
Name string `json:"name"`
|
||||
} `json:"tags"`
|
||||
PrimaryIP4 *netboxIP `json:"primary_ip4"`
|
||||
PrimaryIP6 *netboxIP `json:"primary_ip6"`
|
||||
}
|
||||
|
||||
// netboxVM matches the relevant fields of the NetBox /virtualization/virtual-machines/ response.
|
||||
type netboxVM struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Tags []struct {
|
||||
Name string `json:"name"`
|
||||
} `json:"tags"`
|
||||
PrimaryIP4 *netboxIP `json:"primary_ip4"`
|
||||
PrimaryIP6 *netboxIP `json:"primary_ip6"`
|
||||
}
|
||||
|
||||
type netboxListResponse[T any] struct {
|
||||
Count int `json:"count"`
|
||||
Results []T `json:"results"`
|
||||
}
|
||||
|
||||
type netboxIPListResponse struct {
|
||||
Count int `json:"count"`
|
||||
Results []struct {
|
||||
Address string `json:"address"`
|
||||
Interface *struct {
|
||||
Name string `json:"name"`
|
||||
} `json:"assigned_object"`
|
||||
} `json:"results"`
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/config"
|
||||
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/netbox"
|
||||
)
|
||||
|
||||
// Chain tries each strategy in order until one returns an IP.
|
||||
type Chain struct {
|
||||
strategies []Strategy
|
||||
}
|
||||
|
||||
// New builds a Chain from the strategy names listed in the resolver config.
|
||||
func New(cfg config.ResolverConfig) (*Chain, error) {
|
||||
var strategies []Strategy
|
||||
for _, name := range cfg.Strategies {
|
||||
s, err := newStrategy(name, cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolver strategy %q: %w", name, err)
|
||||
}
|
||||
strategies = append(strategies, s)
|
||||
}
|
||||
return &Chain{strategies: strategies}, nil
|
||||
}
|
||||
|
||||
func (c *Chain) Resolve(ctx context.Context, entry *netbox.HostEntry, client *netbox.Client) (string, error) {
|
||||
for _, s := range c.strategies {
|
||||
ip, err := s.Resolve(ctx, entry, client)
|
||||
if err == nil {
|
||||
return ip, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("no strategy resolved an IP for %q", entry.Name)
|
||||
}
|
||||
|
||||
func newStrategy(name string, cfg config.ResolverConfig) (Strategy, error) {
|
||||
switch name {
|
||||
case "primary_ip":
|
||||
return &PrimaryIPStrategy{}, nil
|
||||
case "management_subnet":
|
||||
s, err := NewManagementSubnetStrategy(cfg.ManagementSubnets)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
case "interface_name":
|
||||
if cfg.InterfaceName == "" {
|
||||
return nil, fmt.Errorf("interface_name strategy requires resolver.interface_name to be set")
|
||||
}
|
||||
return &InterfaceNameStrategy{name: cfg.InterfaceName}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown strategy %q", name)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,155 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/config"
|
||||
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/netbox"
|
||||
)
|
||||
|
||||
// stubStrategy is a test double for Strategy.
|
||||
type stubStrategy struct {
|
||||
name string
|
||||
ip string
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *stubStrategy) Name() string { return s.name }
|
||||
func (s *stubStrategy) Resolve(_ context.Context, _ *netbox.HostEntry, _ *netbox.Client) (string, error) {
|
||||
return s.ip, s.err
|
||||
}
|
||||
|
||||
func TestChain_FirstStrategySucceeds(t *testing.T) {
|
||||
c := &Chain{strategies: []Strategy{
|
||||
&stubStrategy{name: "first", ip: "10.0.0.1"},
|
||||
&stubStrategy{name: "second", ip: "10.0.0.2"},
|
||||
}}
|
||||
ip, err := c.Resolve(context.Background(), &netbox.HostEntry{Name: "host"}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if ip != "10.0.0.1" {
|
||||
t.Errorf("got %q, want first strategy's IP %q", ip, "10.0.0.1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_FallsBackToNextStrategy(t *testing.T) {
|
||||
c := &Chain{strategies: []Strategy{
|
||||
&stubStrategy{name: "first", err: ErrNoIP},
|
||||
&stubStrategy{name: "second", ip: "10.0.0.2"},
|
||||
}}
|
||||
ip, err := c.Resolve(context.Background(), &netbox.HostEntry{Name: "host"}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if ip != "10.0.0.2" {
|
||||
t.Errorf("got %q, want second strategy's IP %q", ip, "10.0.0.2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_AllStrategiesFail(t *testing.T) {
|
||||
c := &Chain{strategies: []Strategy{
|
||||
&stubStrategy{name: "a", err: ErrNoIP},
|
||||
&stubStrategy{name: "b", err: errors.New("api error")},
|
||||
}}
|
||||
_, err := c.Resolve(context.Background(), &netbox.HostEntry{Name: "host"}, nil)
|
||||
if err == nil {
|
||||
t.Error("expected error when all strategies fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_EmptyStrategies(t *testing.T) {
|
||||
c := &Chain{}
|
||||
_, err := c.Resolve(context.Background(), &netbox.HostEntry{Name: "host"}, nil)
|
||||
if err == nil {
|
||||
t.Error("empty chain should return an error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_PrimaryIP(t *testing.T) {
|
||||
cfg := config.ResolverConfig{Strategies: []string{"primary_ip"}}
|
||||
c, err := New(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
if len(c.strategies) != 1 {
|
||||
t.Errorf("got %d strategies, want 1", len(c.strategies))
|
||||
}
|
||||
if c.strategies[0].Name() != "primary_ip" {
|
||||
t.Errorf("strategy name: got %q, want %q", c.strategies[0].Name(), "primary_ip")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_ManagementSubnet(t *testing.T) {
|
||||
cfg := config.ResolverConfig{
|
||||
Strategies: []string{"management_subnet"},
|
||||
ManagementSubnets: []string{"10.0.0.0/8"},
|
||||
}
|
||||
c, err := New(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
if c.strategies[0].Name() != "management_subnet" {
|
||||
t.Errorf("strategy name: got %q, want %q", c.strategies[0].Name(), "management_subnet")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_ManagementSubnet_InvalidCIDR(t *testing.T) {
|
||||
cfg := config.ResolverConfig{
|
||||
Strategies: []string{"management_subnet"},
|
||||
ManagementSubnets: []string{"not-a-cidr"},
|
||||
}
|
||||
_, err := New(cfg)
|
||||
if err == nil {
|
||||
t.Error("invalid CIDR should return an error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_InterfaceName(t *testing.T) {
|
||||
cfg := config.ResolverConfig{
|
||||
Strategies: []string{"interface_name"},
|
||||
InterfaceName: "mgmt0",
|
||||
}
|
||||
c, err := New(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
if c.strategies[0].Name() != "interface_name" {
|
||||
t.Errorf("strategy name: got %q", c.strategies[0].Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_InterfaceName_MissingConfig(t *testing.T) {
|
||||
cfg := config.ResolverConfig{
|
||||
Strategies: []string{"interface_name"},
|
||||
InterfaceName: "", // not set
|
||||
}
|
||||
_, err := New(cfg)
|
||||
if err == nil {
|
||||
t.Error("interface_name without config should return an error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_UnknownStrategy(t *testing.T) {
|
||||
cfg := config.ResolverConfig{Strategies: []string{"nonexistent"}}
|
||||
_, err := New(cfg)
|
||||
if err == nil {
|
||||
t.Error("unknown strategy should return an error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_MultipleStrategies(t *testing.T) {
|
||||
cfg := config.ResolverConfig{
|
||||
Strategies: []string{"management_subnet", "primary_ip"},
|
||||
ManagementSubnets: []string{"10.0.0.0/8"},
|
||||
}
|
||||
c, err := New(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
if len(c.strategies) != 2 {
|
||||
t.Errorf("got %d strategies, want 2", len(c.strategies))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/netbox"
|
||||
)
|
||||
|
||||
// InterfaceNameStrategy finds the first IP assigned to a named interface (e.g. "mgmt0", "eth0").
|
||||
type InterfaceNameStrategy struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (s *InterfaceNameStrategy) Name() string { return "interface_name" }
|
||||
|
||||
func (s *InterfaceNameStrategy) Resolve(ctx context.Context, entry *netbox.HostEntry, client *netbox.Client) (string, error) {
|
||||
// Build filter parameters for IP addresses attached to the named interface.
|
||||
var filterParam string
|
||||
switch entry.Kind {
|
||||
case "device":
|
||||
filterParam = fmt.Sprintf("device_id=%d&interface_name=%s", entry.ID, url.QueryEscape(s.name))
|
||||
case "vm":
|
||||
filterParam = fmt.Sprintf("virtual_machine_id=%d&vminterface_name=%s", entry.ID, url.QueryEscape(s.name))
|
||||
default:
|
||||
return "", fmt.Errorf("unknown kind %q", entry.Kind)
|
||||
}
|
||||
|
||||
ips, err := client.GetIPsWithFilter(ctx, filterParam)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fetching IPs for interface %q: %w", s.name, err)
|
||||
}
|
||||
if len(ips) == 0 {
|
||||
return "", ErrNoIP
|
||||
}
|
||||
return ips[0], nil
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/netbox"
|
||||
)
|
||||
|
||||
// ManagementSubnetStrategy finds the first IP of a host that falls within
|
||||
// one of the configured management subnets.
|
||||
type ManagementSubnetStrategy struct {
|
||||
subnets []*net.IPNet
|
||||
}
|
||||
|
||||
func NewManagementSubnetStrategy(cidrs []string) (*ManagementSubnetStrategy, error) {
|
||||
nets := make([]*net.IPNet, 0, len(cidrs))
|
||||
for _, cidr := range cidrs {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid CIDR %q: %w", cidr, err)
|
||||
}
|
||||
nets = append(nets, ipNet)
|
||||
}
|
||||
return &ManagementSubnetStrategy{subnets: nets}, nil
|
||||
}
|
||||
|
||||
func (s *ManagementSubnetStrategy) Name() string { return "management_subnet" }
|
||||
|
||||
func (s *ManagementSubnetStrategy) Resolve(ctx context.Context, entry *netbox.HostEntry, client *netbox.Client) (string, error) {
|
||||
if len(s.subnets) == 0 {
|
||||
return "", ErrNoIP
|
||||
}
|
||||
|
||||
ips, err := client.GetIPs(ctx, *entry)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fetching IPs: %w", err)
|
||||
}
|
||||
|
||||
for _, rawIP := range ips {
|
||||
ip := net.ParseIP(rawIP)
|
||||
if ip == nil {
|
||||
continue
|
||||
}
|
||||
for _, subnet := range s.subnets {
|
||||
if subnet.Contains(ip) {
|
||||
return rawIP, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", ErrNoIP
|
||||
}
|
||||
@@ -0,0 +1,109 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/netbox"
|
||||
)
|
||||
|
||||
// newIPServer returns a test server that always responds with the given IP list.
|
||||
func newIPServer(t *testing.T, ips []string) *httptest.Server {
|
||||
t.Helper()
|
||||
type result struct {
|
||||
Address string `json:"address"`
|
||||
}
|
||||
type response struct {
|
||||
Count int `json:"count"`
|
||||
Results []result `json:"results"`
|
||||
}
|
||||
resp := response{Count: len(ips)}
|
||||
for _, ip := range ips {
|
||||
resp.Results = append(resp.Results, result{Address: ip})
|
||||
}
|
||||
body, _ := json.Marshal(resp)
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write(body)
|
||||
}))
|
||||
}
|
||||
|
||||
func TestManagementSubnetStrategy_MatchesSubnet(t *testing.T) {
|
||||
srv := newIPServer(t, []string{"10.0.1.5/24", "192.168.0.1/24"})
|
||||
defer srv.Close()
|
||||
|
||||
s, _ := NewManagementSubnetStrategy([]string{"10.0.0.0/8"})
|
||||
client := netbox.NewClient(srv.URL, "token")
|
||||
|
||||
ip, err := s.Resolve(context.Background(), &netbox.HostEntry{ID: 1, Kind: "device"}, client)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if ip != "10.0.1.5" {
|
||||
t.Errorf("got %q, want %q", ip, "10.0.1.5")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagementSubnetStrategy_NoMatch(t *testing.T) {
|
||||
srv := newIPServer(t, []string{"192.168.0.1/24"})
|
||||
defer srv.Close()
|
||||
|
||||
s, _ := NewManagementSubnetStrategy([]string{"10.0.0.0/8"})
|
||||
client := netbox.NewClient(srv.URL, "token")
|
||||
|
||||
_, err := s.Resolve(context.Background(), &netbox.HostEntry{ID: 1, Kind: "device"}, client)
|
||||
if err != ErrNoIP {
|
||||
t.Errorf("no matching subnet should return ErrNoIP, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagementSubnetStrategy_FirstMatchWins(t *testing.T) {
|
||||
srv := newIPServer(t, []string{"10.0.1.1/24", "10.0.1.2/24"})
|
||||
defer srv.Close()
|
||||
|
||||
s, _ := NewManagementSubnetStrategy([]string{"10.0.0.0/8"})
|
||||
client := netbox.NewClient(srv.URL, "token")
|
||||
|
||||
ip, err := s.Resolve(context.Background(), &netbox.HostEntry{ID: 1, Kind: "device"}, client)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if ip != "10.0.1.1" {
|
||||
t.Errorf("got %q, want first matching IP %q", ip, "10.0.1.1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagementSubnetStrategy_VMKind(t *testing.T) {
|
||||
srv := newIPServer(t, []string{"172.16.5.10/16"})
|
||||
defer srv.Close()
|
||||
|
||||
s, _ := NewManagementSubnetStrategy([]string{"172.16.0.0/12"})
|
||||
client := netbox.NewClient(srv.URL, "token")
|
||||
|
||||
ip, err := s.Resolve(context.Background(), &netbox.HostEntry{ID: 2, Kind: "vm"}, client)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if ip != "172.16.5.10" {
|
||||
t.Errorf("got %q, want %q", ip, "172.16.5.10")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagementSubnetStrategy_IPv6Subnet(t *testing.T) {
|
||||
srv := newIPServer(t, []string{"fd00::1/64"})
|
||||
defer srv.Close()
|
||||
|
||||
s, _ := NewManagementSubnetStrategy([]string{"fd00::/8"})
|
||||
client := netbox.NewClient(srv.URL, "token")
|
||||
|
||||
ip, err := s.Resolve(context.Background(), &netbox.HostEntry{ID: 1, Kind: "device"}, client)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if ip != "fd00::1" {
|
||||
t.Errorf("got %q, want %q", ip, "fd00::1")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/netbox"
|
||||
)
|
||||
|
||||
// PrimaryIPStrategy returns the primary IP configured in NetBox.
|
||||
// Prefers IPv4, falls back to IPv6.
|
||||
type PrimaryIPStrategy struct{}
|
||||
|
||||
func (s *PrimaryIPStrategy) Name() string { return "primary_ip" }
|
||||
|
||||
func (s *PrimaryIPStrategy) Resolve(_ context.Context, entry *netbox.HostEntry, _ *netbox.Client) (string, error) {
|
||||
if entry.PrimaryIP4 != "" {
|
||||
return entry.PrimaryIP4, nil
|
||||
}
|
||||
if entry.PrimaryIP6 != "" {
|
||||
return entry.PrimaryIP6, nil
|
||||
}
|
||||
return "", ErrNoIP
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/netbox"
|
||||
)
|
||||
|
||||
func TestPrimaryIPStrategy_Name(t *testing.T) {
|
||||
s := &PrimaryIPStrategy{}
|
||||
if s.Name() != "primary_ip" {
|
||||
t.Errorf("Name: got %q, want %q", s.Name(), "primary_ip")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrimaryIPStrategy_IPv4(t *testing.T) {
|
||||
s := &PrimaryIPStrategy{}
|
||||
e := &netbox.HostEntry{PrimaryIP4: "10.0.0.1", PrimaryIP6: "::1"}
|
||||
ip, err := s.Resolve(context.Background(), e, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if ip != "10.0.0.1" {
|
||||
t.Errorf("got %q, want IPv4 %q", ip, "10.0.0.1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrimaryIPStrategy_IPv6Fallback(t *testing.T) {
|
||||
s := &PrimaryIPStrategy{}
|
||||
e := &netbox.HostEntry{PrimaryIP6: "::1"}
|
||||
ip, err := s.Resolve(context.Background(), e, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if ip != "::1" {
|
||||
t.Errorf("got %q, want IPv6 %q", ip, "::1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrimaryIPStrategy_NoIP(t *testing.T) {
|
||||
s := &PrimaryIPStrategy{}
|
||||
_, err := s.Resolve(context.Background(), &netbox.HostEntry{}, nil)
|
||||
if err != ErrNoIP {
|
||||
t.Errorf("got %v, want ErrNoIP", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagementSubnetStrategy_Name(t *testing.T) {
|
||||
s, _ := NewManagementSubnetStrategy([]string{"10.0.0.0/8"})
|
||||
if s.Name() != "management_subnet" {
|
||||
t.Errorf("Name: got %q, want %q", s.Name(), "management_subnet")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagementSubnetStrategy_InvalidCIDR(t *testing.T) {
|
||||
_, err := NewManagementSubnetStrategy([]string{"not-a-cidr"})
|
||||
if err == nil {
|
||||
t.Error("invalid CIDR should return an error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagementSubnetStrategy_EmptyCIDRs(t *testing.T) {
|
||||
s, _ := NewManagementSubnetStrategy([]string{})
|
||||
_, err := s.Resolve(context.Background(), &netbox.HostEntry{}, nil)
|
||||
if err != ErrNoIP {
|
||||
t.Errorf("empty subnets should return ErrNoIP, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagementSubnetStrategy_MultipleCIDRs(t *testing.T) {
|
||||
_, err := NewManagementSubnetStrategy([]string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16"})
|
||||
if err != nil {
|
||||
t.Fatalf("valid CIDRs should not error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterfaceNameStrategy_Name(t *testing.T) {
|
||||
s := &InterfaceNameStrategy{name: "mgmt0"}
|
||||
if s.Name() != "interface_name" {
|
||||
t.Errorf("Name: got %q, want %q", s.Name(), "interface_name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterfaceNameStrategy_UnknownKind(t *testing.T) {
|
||||
s := &InterfaceNameStrategy{name: "eth0"}
|
||||
_, err := s.Resolve(context.Background(), &netbox.HostEntry{Kind: "unknown"}, nil)
|
||||
if err == nil {
|
||||
t.Error("unknown kind should return an error")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/netbox"
|
||||
)
|
||||
|
||||
// ErrNoIP is returned when a strategy cannot find a matching IP address.
|
||||
var ErrNoIP = errors.New("no matching IP found")
|
||||
|
||||
// Strategy is a single rule for resolving an IP address from a NetBox host entry.
|
||||
type Strategy interface {
|
||||
Name() string
|
||||
Resolve(ctx context.Context, entry *netbox.HostEntry, client *netbox.Client) (string, error)
|
||||
}
|
||||
@@ -0,0 +1,109 @@
|
||||
package ssh
|
||||
|
||||
import "strings"
|
||||
|
||||
// flagsWithArg lists all SSH flags that consume the following argument.
|
||||
var flagsWithArg = map[byte]bool{
|
||||
'b': true, 'c': true, 'D': true, 'E': true, 'e': true,
|
||||
'F': true, 'I': true, 'i': true, 'J': true, 'L': true,
|
||||
'l': true, 'm': true, 'o': true, 'O': true, 'p': true,
|
||||
'Q': true, 'R': true, 'S': true, 'w': true, 'W': true,
|
||||
}
|
||||
|
||||
// ParsedArgs holds the result of parsing SSH arguments.
|
||||
type ParsedArgs struct {
|
||||
Host string // hostname without the user@ prefix
|
||||
User string // empty if not specified
|
||||
DestIdx int // index in Args where [user@]host sits
|
||||
Args []string
|
||||
}
|
||||
|
||||
// Parse scans SSH arguments and extracts the destination ([user@]host).
|
||||
// Returns nil if no destination is found.
|
||||
func Parse(args []string) *ParsedArgs {
|
||||
i := 0
|
||||
for i < len(args) {
|
||||
arg := args[i]
|
||||
|
||||
// "--" ends option processing
|
||||
if arg == "--" {
|
||||
i++
|
||||
break
|
||||
}
|
||||
|
||||
if strings.HasPrefix(arg, "-") && len(arg) > 1 {
|
||||
flag := arg[1]
|
||||
if flagsWithArg[flag] {
|
||||
if len(arg) > 2 {
|
||||
// argument is attached, e.g. -p2222
|
||||
i++
|
||||
} else {
|
||||
// argument is the next element, e.g. -p 2222
|
||||
i += 2
|
||||
}
|
||||
} else {
|
||||
i++
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// first non-flag argument is the destination
|
||||
host, user := splitUserHost(arg)
|
||||
return &ParsedArgs{
|
||||
Host: host,
|
||||
User: user,
|
||||
DestIdx: i,
|
||||
Args: args,
|
||||
}
|
||||
}
|
||||
|
||||
// handle arguments after "--"
|
||||
if i < len(args) {
|
||||
host, user := splitUserHost(args[i])
|
||||
return &ParsedArgs{
|
||||
Host: host,
|
||||
User: user,
|
||||
DestIdx: i,
|
||||
Args: args,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReplaceHost returns a copy of args with the destination replaced by newHost,
|
||||
// preserving any user@ prefix.
|
||||
func ReplaceHost(args []string, destIdx int, newHost string) []string {
|
||||
result := make([]string, len(args))
|
||||
copy(result, args)
|
||||
|
||||
original := args[destIdx]
|
||||
if at := strings.Index(original, "@"); at != -1 {
|
||||
result[destIdx] = original[:at+1] + newHost
|
||||
} else {
|
||||
result[destIdx] = newHost
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// HasUserFlag reports whether a user was specified via -l in args.
|
||||
// Used to avoid overriding an explicit -l with the configured default user.
|
||||
func HasUserFlag(args []string) bool {
|
||||
for i, a := range args {
|
||||
if a == "-l" && i+1 < len(args) {
|
||||
return true
|
||||
}
|
||||
// handle attached form: -lroot
|
||||
if len(a) > 2 && a[0] == '-' && a[1] == 'l' {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func splitUserHost(dest string) (host, user string) {
|
||||
if at := strings.Index(dest, "@"); at != -1 {
|
||||
return dest[at+1:], dest[:at]
|
||||
}
|
||||
return dest, ""
|
||||
}
|
||||
@@ -0,0 +1,161 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParse_BareHostname(t *testing.T) {
|
||||
got := Parse([]string{"myhost"})
|
||||
assertParsed(t, got, "myhost", "", 0)
|
||||
}
|
||||
|
||||
func TestParse_UserAtHost(t *testing.T) {
|
||||
got := Parse([]string{"admin@myhost"})
|
||||
assertParsed(t, got, "myhost", "admin", 0)
|
||||
}
|
||||
|
||||
func TestParse_PortFlag_Separated(t *testing.T) {
|
||||
got := Parse([]string{"-p", "2222", "myhost"})
|
||||
assertParsed(t, got, "myhost", "", 2)
|
||||
}
|
||||
|
||||
func TestParse_PortFlag_Attached(t *testing.T) {
|
||||
got := Parse([]string{"-p2222", "myhost"})
|
||||
assertParsed(t, got, "myhost", "", 1)
|
||||
}
|
||||
|
||||
func TestParse_IdentityFlag(t *testing.T) {
|
||||
got := Parse([]string{"-i", "/path/to/key", "user@myhost", "ls"})
|
||||
assertParsed(t, got, "myhost", "user", 2)
|
||||
}
|
||||
|
||||
func TestParse_VerboseFlag(t *testing.T) {
|
||||
got := Parse([]string{"-v", "myhost"})
|
||||
assertParsed(t, got, "myhost", "", 1)
|
||||
}
|
||||
|
||||
func TestParse_OptionFlag(t *testing.T) {
|
||||
got := Parse([]string{"-o", "StrictHostKeyChecking=no", "myhost"})
|
||||
assertParsed(t, got, "myhost", "", 2)
|
||||
}
|
||||
|
||||
func TestParse_JumpHost(t *testing.T) {
|
||||
got := Parse([]string{"-J", "jumphost", "-p", "22", "target"})
|
||||
assertParsed(t, got, "target", "", 4)
|
||||
}
|
||||
|
||||
func TestParse_MultipleFlags(t *testing.T) {
|
||||
got := Parse([]string{"-v", "-p", "22", "-i", "key", "root@host", "uptime"})
|
||||
assertParsed(t, got, "host", "root", 5)
|
||||
}
|
||||
|
||||
func TestParse_DoubleDash(t *testing.T) {
|
||||
got := Parse([]string{"--", "myhost"})
|
||||
assertParsed(t, got, "myhost", "", 1)
|
||||
}
|
||||
|
||||
func TestParse_DoubleDash_WithFlags(t *testing.T) {
|
||||
// flags after -- should be treated as destination
|
||||
got := Parse([]string{"-v", "--", "-not-a-flag"})
|
||||
assertParsed(t, got, "-not-a-flag", "", 2)
|
||||
}
|
||||
|
||||
func TestParse_NoDestination(t *testing.T) {
|
||||
got := Parse([]string{"-v", "-p", "2222"})
|
||||
if got != nil {
|
||||
t.Errorf("expected nil for args without destination, got %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_EmptyArgs(t *testing.T) {
|
||||
got := Parse([]string{})
|
||||
if got != nil {
|
||||
t.Error("empty args should return nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_OnlyDoubleDash(t *testing.T) {
|
||||
got := Parse([]string{"--"})
|
||||
if got != nil {
|
||||
t.Error("only -- with no destination should return nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceHost_PlainHost(t *testing.T) {
|
||||
args := []string{"myhost"}
|
||||
result := ReplaceHost(args, 0, "10.0.0.1")
|
||||
if result[0] != "10.0.0.1" {
|
||||
t.Errorf("got %q, want %q", result[0], "10.0.0.1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceHost_PreservesUserPrefix(t *testing.T) {
|
||||
args := []string{"-p", "22", "admin@myhost", "ls"}
|
||||
result := ReplaceHost(args, 2, "10.0.0.1")
|
||||
if result[2] != "admin@10.0.0.1" {
|
||||
t.Errorf("got %q, want %q", result[2], "admin@10.0.0.1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceHost_DoesNotMutateOriginal(t *testing.T) {
|
||||
args := []string{"myhost"}
|
||||
_ = ReplaceHost(args, 0, "10.0.0.1")
|
||||
if args[0] != "myhost" {
|
||||
t.Error("ReplaceHost must not mutate the original slice")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceHost_OtherArgsUnchanged(t *testing.T) {
|
||||
args := []string{"-p", "22", "myhost"}
|
||||
result := ReplaceHost(args, 2, "10.0.0.1")
|
||||
if result[0] != "-p" || result[1] != "22" {
|
||||
t.Errorf("other args should be unchanged: %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasUserFlag_FlagSeparated(t *testing.T) {
|
||||
if !HasUserFlag([]string{"-l", "admin", "host"}) {
|
||||
t.Error("should detect -l <user>")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasUserFlag_FlagAttached(t *testing.T) {
|
||||
if !HasUserFlag([]string{"-ladmin", "host"}) {
|
||||
t.Error("should detect -l<user> (attached form)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasUserFlag_NotPresent(t *testing.T) {
|
||||
if HasUserFlag([]string{"-p", "22", "host"}) {
|
||||
t.Error("should not detect user flag when absent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasUserFlag_EmptyArgs(t *testing.T) {
|
||||
if HasUserFlag([]string{}) {
|
||||
t.Error("empty args should return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasUserFlag_LFlagAtEnd(t *testing.T) {
|
||||
// -l at the very end with no value — should not panic
|
||||
if HasUserFlag([]string{"-l"}) {
|
||||
t.Error("-l with no value should return false")
|
||||
}
|
||||
}
|
||||
|
||||
func assertParsed(t *testing.T, got *ParsedArgs, host, user string, destIdx int) {
|
||||
t.Helper()
|
||||
if got == nil {
|
||||
t.Fatal("Parse returned nil")
|
||||
}
|
||||
if got.Host != host {
|
||||
t.Errorf("host: got %q, want %q", got.Host, host)
|
||||
}
|
||||
if got.User != user {
|
||||
t.Errorf("user: got %q, want %q", got.User, user)
|
||||
}
|
||||
if got.DestIdx != destIdx {
|
||||
t.Errorf("destIdx: got %d, want %d", got.DestIdx, destIdx)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// Exec replaces the current process with the native ssh client via syscall.Exec.
|
||||
// All existing SSH configs, keys, and agent forwarding remain intact.
|
||||
func Exec(args []string) error {
|
||||
sshPath, err := exec.LookPath("ssh")
|
||||
if err != nil {
|
||||
return fmt.Errorf("ssh not found in PATH: %w", err)
|
||||
}
|
||||
argv := append([]string{"ssh"}, args...)
|
||||
return syscall.Exec(sshPath, argv, os.Environ())
|
||||
}
|
||||
@@ -0,0 +1,237 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/bubbles/list"
|
||||
"github.com/charmbracelet/bubbles/textinput"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
|
||||
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/cache"
|
||||
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/netbox"
|
||||
)
|
||||
|
||||
// SelectedHost is returned when the user confirms a host in the TUI.
|
||||
type SelectedHost struct {
|
||||
Name string
|
||||
IP string
|
||||
}
|
||||
|
||||
// --- bubbletea messages ---
|
||||
|
||||
type debounceMsg struct{ query string }
|
||||
|
||||
type searchResultMsg struct {
|
||||
query string
|
||||
entries []netbox.HostEntry
|
||||
err error
|
||||
}
|
||||
|
||||
// --- list item ---
|
||||
|
||||
type hostItem struct {
|
||||
name string
|
||||
ip string
|
||||
kind string
|
||||
}
|
||||
|
||||
func (h hostItem) Title() string { return h.name }
|
||||
func (h hostItem) Description() string { return fmt.Sprintf("%s [%s]", h.ip, h.kind) }
|
||||
func (h hostItem) FilterValue() string { return h.name }
|
||||
|
||||
// --- compact list delegate ---
|
||||
|
||||
type compactDelegate struct{}
|
||||
|
||||
func (d compactDelegate) Height() int { return 1 }
|
||||
func (d compactDelegate) Spacing() int { return 0 }
|
||||
func (d compactDelegate) Update(_ tea.Msg, _ *list.Model) tea.Cmd { return nil }
|
||||
func (d compactDelegate) Render(w io.Writer, m list.Model, index int, item list.Item) {
|
||||
h, ok := item.(hostItem)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
line := fmt.Sprintf(" %s %s", h.name, lipgloss.NewStyle().Foreground(lipgloss.Color("240")).Render(h.ip))
|
||||
if index == m.Index() {
|
||||
line = lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("86")).Render("> " + strings.TrimPrefix(line, " "))
|
||||
}
|
||||
fmt.Fprintln(w, line)
|
||||
}
|
||||
|
||||
// --- Model ---
|
||||
|
||||
type Model struct {
|
||||
input textinput.Model
|
||||
list list.Model
|
||||
client *netbox.Client
|
||||
cache *cache.Cache
|
||||
lastSent string // last query sent to NetBox (or served from cache)
|
||||
seq int // sequence number to discard stale results
|
||||
loading bool
|
||||
err error
|
||||
selected *SelectedHost
|
||||
width int
|
||||
height int
|
||||
}
|
||||
|
||||
func New(client *netbox.Client, c *cache.Cache) *Model {
|
||||
ti := textinput.New()
|
||||
ti.Placeholder = "Search hostname…"
|
||||
ti.Focus()
|
||||
|
||||
l := list.New(nil, compactDelegate{}, 0, 0)
|
||||
l.SetShowHelp(false)
|
||||
l.SetShowTitle(false)
|
||||
l.SetShowStatusBar(false)
|
||||
l.SetFilteringEnabled(false)
|
||||
|
||||
return &Model{
|
||||
input: ti,
|
||||
list: l,
|
||||
client: client,
|
||||
cache: c,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) Init() tea.Cmd {
|
||||
return textinput.Blink
|
||||
}
|
||||
|
||||
func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
|
||||
case tea.WindowSizeMsg:
|
||||
m.width = msg.Width
|
||||
m.height = msg.Height
|
||||
m.list.SetSize(msg.Width, msg.Height-4)
|
||||
return m, nil
|
||||
|
||||
case tea.KeyMsg:
|
||||
switch msg.String() {
|
||||
case "ctrl+c", "esc":
|
||||
return m, tea.Quit
|
||||
|
||||
case "enter":
|
||||
if item, ok := m.list.SelectedItem().(hostItem); ok {
|
||||
m.selected = &SelectedHost{Name: item.name, IP: item.ip}
|
||||
return m, tea.Quit
|
||||
}
|
||||
|
||||
case "tab":
|
||||
// Copy the top result into the search field.
|
||||
if m.list.Items() != nil && len(m.list.Items()) > 0 {
|
||||
if item, ok := m.list.Items()[0].(hostItem); ok {
|
||||
m.input.SetValue(item.name)
|
||||
m.input.CursorEnd()
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
case debounceMsg:
|
||||
// Only query if the input has changed since the last request.
|
||||
q := m.input.Value()
|
||||
if q == m.lastSent {
|
||||
return m, nil
|
||||
}
|
||||
m.lastSent = q
|
||||
m.loading = true
|
||||
m.seq++
|
||||
seq := m.seq
|
||||
return m, m.doSearch(q, seq)
|
||||
|
||||
case searchResultMsg:
|
||||
if msg.query != m.lastSent {
|
||||
return m, nil // discard stale result
|
||||
}
|
||||
m.loading = false
|
||||
if msg.err != nil {
|
||||
m.err = msg.err
|
||||
return m, nil
|
||||
}
|
||||
items := make([]list.Item, len(msg.entries))
|
||||
for i, e := range msg.entries {
|
||||
ip := e.PrimaryIP4
|
||||
if ip == "" {
|
||||
ip = e.PrimaryIP6
|
||||
}
|
||||
items[i] = hostItem{name: e.Name, ip: ip, kind: e.Kind}
|
||||
}
|
||||
m.list.SetItems(items)
|
||||
m.err = nil
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Forward to text input and restart the debounce timer.
|
||||
var cmds []tea.Cmd
|
||||
var inputCmd tea.Cmd
|
||||
m.input, inputCmd = m.input.Update(msg)
|
||||
cmds = append(cmds, inputCmd)
|
||||
cmds = append(cmds, m.startDebounce())
|
||||
|
||||
var listCmd tea.Cmd
|
||||
m.list, listCmd = m.list.Update(msg)
|
||||
cmds = append(cmds, listCmd)
|
||||
|
||||
return m, tea.Batch(cmds...)
|
||||
}
|
||||
|
||||
func (m *Model) View() string {
|
||||
var sb strings.Builder
|
||||
|
||||
title := lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("86")).Render("netssh")
|
||||
sb.WriteString(title + "\n\n")
|
||||
sb.WriteString(m.input.View() + "\n")
|
||||
|
||||
if m.loading {
|
||||
sb.WriteString(lipgloss.NewStyle().Foreground(lipgloss.Color("240")).Render(" searching…") + "\n")
|
||||
} else if m.err != nil {
|
||||
sb.WriteString(lipgloss.NewStyle().Foreground(lipgloss.Color("9")).Render(" error: "+m.err.Error()) + "\n")
|
||||
} else {
|
||||
sb.WriteString(m.list.View())
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// Selected returns the host chosen by the user, or nil if none was selected.
|
||||
func (m *Model) Selected() *SelectedHost {
|
||||
return m.selected
|
||||
}
|
||||
|
||||
func (m *Model) startDebounce() tea.Cmd {
|
||||
return tea.Tick(300*time.Millisecond, func(_ time.Time) tea.Msg {
|
||||
return debounceMsg{query: m.input.Value()}
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Model) doSearch(query string, seq int) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
// Return cache hits immediately without a network round-trip.
|
||||
if m.cache != nil {
|
||||
if cached := m.cache.Search(query); len(cached) > 0 {
|
||||
entries := make([]netbox.HostEntry, len(cached))
|
||||
for i, c := range cached {
|
||||
entries[i] = netbox.HostEntry{Name: c.Name, PrimaryIP4: c.IP, Kind: c.Kind}
|
||||
}
|
||||
return searchResultMsg{query: query, entries: entries}
|
||||
}
|
||||
}
|
||||
|
||||
if m.client == nil {
|
||||
return searchResultMsg{query: query, entries: nil}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
entries, err := m.client.Search(ctx, query)
|
||||
_ = seq
|
||||
return searchResultMsg{query: query, entries: entries, err: err}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user