Add core modules (SSH args parser, cache, resolver, NetBox client) with tests
Release / release (push) Failing after 51s

This commit is contained in:
Sebastian Unterschütz
2026-05-23 12:38:41 +02:00
commit 8ef4bbec16
24 changed files with 2524 additions and 0 deletions
+134
View File
@@ -0,0 +1,134 @@
package cache
import (
"encoding/json"
"os"
"path/filepath"
"strings"
"sync"
"time"
)
type Entry struct {
Name string `json:"name"`
IP string `json:"ip"`
Kind string `json:"kind"`
Tags []string `json:"tags,omitempty"`
CachedAt time.Time `json:"cached_at"`
}
type Cache struct {
mu sync.RWMutex
entries map[string]Entry
path string
ttl time.Duration
}
type diskFormat struct {
Entries []Entry `json:"entries"`
}
func New(path string, ttlSeconds int) *Cache {
return &Cache{
entries: make(map[string]Entry),
path: path,
ttl: time.Duration(ttlSeconds) * time.Second,
}
}
func (c *Cache) Load() error {
c.mu.Lock()
defer c.mu.Unlock()
data, err := os.ReadFile(c.path)
if os.IsNotExist(err) {
return nil
}
if err != nil {
return err
}
var df diskFormat
if err := json.Unmarshal(data, &df); err != nil {
return err
}
c.entries = make(map[string]Entry, len(df.Entries))
for _, e := range df.Entries {
c.entries[e.Name] = e
}
return nil
}
func (c *Cache) Save() error {
c.mu.RLock()
df := diskFormat{Entries: make([]Entry, 0, len(c.entries))}
for _, e := range c.entries {
df.Entries = append(df.Entries, e)
}
c.mu.RUnlock()
if err := os.MkdirAll(filepath.Dir(c.path), 0o755); err != nil {
return err
}
data, err := json.MarshalIndent(df, "", " ")
if err != nil {
return err
}
return os.WriteFile(c.path, data, 0o644)
}
func (c *Cache) Upsert(e Entry) {
e.CachedAt = time.Now()
c.mu.Lock()
c.entries[e.Name] = e
c.mu.Unlock()
}
// Search returns all entries whose name starts with prefix (case-insensitive).
// TTL is intentionally ignored — this is used for shell completion.
func (c *Cache) Search(prefix string) []Entry {
c.mu.RLock()
defer c.mu.RUnlock()
prefix = strings.ToLower(prefix)
var out []Entry
for name, e := range c.entries {
if strings.HasPrefix(strings.ToLower(name), prefix) {
out = append(out, e)
}
}
return out
}
// Get returns an entry and reports whether it is still within the TTL.
func (c *Cache) Get(name string) (entry Entry, fresh bool) {
c.mu.RLock()
e, ok := c.entries[name]
c.mu.RUnlock()
if !ok {
return Entry{}, false
}
if c.ttl == 0 {
return e, false
}
return e, time.Since(e.CachedAt) < c.ttl
}
func (c *Cache) Clear() {
c.mu.Lock()
c.entries = make(map[string]Entry)
c.mu.Unlock()
}
func (c *Cache) All() []Entry {
c.mu.RLock()
defer c.mu.RUnlock()
out := make([]Entry, 0, len(c.entries))
for _, e := range c.entries {
out = append(out, e)
}
return out
}
+235
View File
@@ -0,0 +1,235 @@
package cache
import (
"encoding/json"
"os"
"path/filepath"
"testing"
"time"
)
func TestNew(t *testing.T) {
c := New("/tmp/test.json", 60)
if c == nil {
t.Fatal("New returned nil")
}
if c.ttl != 60*time.Second {
t.Errorf("ttl: got %v, want %v", c.ttl, 60*time.Second)
}
}
func TestLoad_MissingFile(t *testing.T) {
c := New("/nonexistent/path/cache.json", 60)
if err := c.Load(); err != nil {
t.Errorf("Load on missing file should not error, got: %v", err)
}
}
func TestLoad_InvalidJSON(t *testing.T) {
f := tempFile(t, []byte("not json"))
c := New(f, 60)
if err := c.Load(); err == nil {
t.Error("Load on invalid JSON should return an error")
}
}
func TestSaveAndLoad_Roundtrip(t *testing.T) {
path := filepath.Join(t.TempDir(), "cache.json")
c := New(path, 3600)
c.Upsert(Entry{Name: "host-a", IP: "10.0.0.1", Kind: "device"})
c.Upsert(Entry{Name: "host-b", IP: "10.0.0.2", Kind: "vm", Tags: []string{"prod"}})
if err := c.Save(); err != nil {
t.Fatalf("Save: %v", err)
}
c2 := New(path, 3600)
if err := c2.Load(); err != nil {
t.Fatalf("Load: %v", err)
}
e, _ := c2.Get("host-a")
if e.IP != "10.0.0.1" {
t.Errorf("host-a IP: got %q, want %q", e.IP, "10.0.0.1")
}
e2, _ := c2.Get("host-b")
if len(e2.Tags) != 1 || e2.Tags[0] != "prod" {
t.Errorf("host-b tags: got %v, want [prod]", e2.Tags)
}
}
func TestSave_CreatesDirectory(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "sub", "dir", "cache.json")
c := New(path, 60)
c.Upsert(Entry{Name: "x", IP: "1.2.3.4", Kind: "device"})
if err := c.Save(); err != nil {
t.Fatalf("Save: %v", err)
}
if _, err := os.Stat(path); err != nil {
t.Errorf("cache file not created: %v", err)
}
}
func TestUpsert_SetsTimestamp(t *testing.T) {
c := New("", 60)
before := time.Now()
c.Upsert(Entry{Name: "h", IP: "1.1.1.1", Kind: "device"})
e, _ := c.Get("h")
if e.CachedAt.Before(before) {
t.Error("CachedAt should be set to current time on Upsert")
}
}
func TestUpsert_Overwrites(t *testing.T) {
c := New("", 60)
c.Upsert(Entry{Name: "host", IP: "10.0.0.1", Kind: "device"})
c.Upsert(Entry{Name: "host", IP: "10.0.0.2", Kind: "device"})
e, _ := c.Get("host")
if e.IP != "10.0.0.2" {
t.Errorf("Upsert should overwrite: got %q, want %q", e.IP, "10.0.0.2")
}
}
func TestSearch_PrefixMatch(t *testing.T) {
c := New("", 60)
c.Upsert(Entry{Name: "app-server-01", IP: "10.0.0.1", Kind: "device"})
c.Upsert(Entry{Name: "app-server-02", IP: "10.0.0.2", Kind: "vm"})
c.Upsert(Entry{Name: "db-server-01", IP: "10.0.0.3", Kind: "device"})
results := c.Search("app")
if len(results) != 2 {
t.Errorf("Search(app): got %d results, want 2", len(results))
}
}
func TestSearch_CaseInsensitive(t *testing.T) {
c := New("", 60)
c.Upsert(Entry{Name: "App-Server", IP: "10.0.0.1", Kind: "device"})
if len(c.Search("app")) != 1 {
t.Error("Search should be case-insensitive")
}
if len(c.Search("APP")) != 1 {
t.Error("Search should be case-insensitive for uppercase")
}
}
func TestSearch_EmptyPrefix(t *testing.T) {
c := New("", 60)
c.Upsert(Entry{Name: "a", IP: "1.1.1.1", Kind: "device"})
c.Upsert(Entry{Name: "b", IP: "2.2.2.2", Kind: "vm"})
if len(c.Search("")) != 2 {
t.Error("Search('') should return all entries")
}
}
func TestSearch_NoMatch(t *testing.T) {
c := New("", 60)
c.Upsert(Entry{Name: "host", IP: "1.1.1.1", Kind: "device"})
if len(c.Search("xyz")) != 0 {
t.Error("Search should return empty slice when no match")
}
}
func TestGet_Fresh(t *testing.T) {
c := New("", 3600)
c.Upsert(Entry{Name: "host", IP: "10.0.0.1", Kind: "device"})
e, fresh := c.Get("host")
if !fresh {
t.Error("entry just inserted should be fresh")
}
if e.IP != "10.0.0.1" {
t.Errorf("IP: got %q, want %q", e.IP, "10.0.0.1")
}
}
func TestGet_Expired(t *testing.T) {
c := New("", 1) // 1 second TTL
e := Entry{Name: "host", IP: "10.0.0.1", Kind: "device", CachedAt: time.Now().Add(-2 * time.Second)}
c.mu.Lock()
c.entries["host"] = e
c.mu.Unlock()
_, fresh := c.Get("host")
if fresh {
t.Error("entry older than TTL should not be fresh")
}
}
func TestGet_ZeroTTL_AlwaysStale(t *testing.T) {
c := New("", 0) // TTL=0 means never fresh for connect mode
c.Upsert(Entry{Name: "host", IP: "10.0.0.1", Kind: "device"})
_, fresh := c.Get("host")
if fresh {
t.Error("TTL=0 should always return fresh=false")
}
}
func TestGet_Missing(t *testing.T) {
c := New("", 60)
_, fresh := c.Get("nonexistent")
if fresh {
t.Error("missing entry should not be fresh")
}
}
func TestClear(t *testing.T) {
c := New("", 60)
c.Upsert(Entry{Name: "a", IP: "1.1.1.1", Kind: "device"})
c.Upsert(Entry{Name: "b", IP: "2.2.2.2", Kind: "vm"})
c.Clear()
if len(c.All()) != 0 {
t.Error("Clear should remove all entries")
}
}
func TestAll(t *testing.T) {
c := New("", 60)
c.Upsert(Entry{Name: "a", IP: "1.1.1.1", Kind: "device"})
c.Upsert(Entry{Name: "b", IP: "2.2.2.2", Kind: "vm"})
all := c.All()
if len(all) != 2 {
t.Errorf("All: got %d entries, want 2", len(all))
}
}
func TestSave_ProducesValidJSON(t *testing.T) {
path := filepath.Join(t.TempDir(), "cache.json")
c := New(path, 60)
c.Upsert(Entry{Name: "host", IP: "10.0.0.1", Kind: "device", Tags: []string{"mgmt"}})
if err := c.Save(); err != nil {
t.Fatalf("Save: %v", err)
}
data, _ := os.ReadFile(path)
var df diskFormat
if err := json.Unmarshal(data, &df); err != nil {
t.Fatalf("saved file is not valid JSON: %v", err)
}
if len(df.Entries) != 1 {
t.Errorf("expected 1 entry in JSON, got %d", len(df.Entries))
}
}
// tempFile writes content to a temp file and returns its path.
func tempFile(t *testing.T, content []byte) string {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), "cache-*.json")
if err != nil {
t.Fatal(err)
}
if _, err := f.Write(content); err != nil {
t.Fatal(err)
}
f.Close()
return f.Name()
}
+76
View File
@@ -0,0 +1,76 @@
package config
import (
"fmt"
"os"
"path/filepath"
"github.com/spf13/viper"
)
type Config struct {
NetBox NetBoxConfig `mapstructure:"netbox"`
Resolver ResolverConfig `mapstructure:"resolver"`
Cache CacheConfig `mapstructure:"cache"`
SSH SSHConfig `mapstructure:"ssh"`
}
type NetBoxConfig struct {
URL string `mapstructure:"url"`
Token string `mapstructure:"token"`
}
type ResolverConfig struct {
Strategies []string `mapstructure:"strategies"`
ManagementSubnets []string `mapstructure:"management_subnets"`
InterfaceName string `mapstructure:"interface_name"`
}
type CacheConfig struct {
TTL int `mapstructure:"ttl"`
Path string `mapstructure:"path"`
}
type SSHConfig struct {
DefaultUser string `mapstructure:"default_user"`
}
func Load() (*Config, error) {
v := viper.New()
v.SetDefault("resolver.strategies", []string{"management_subnet", "primary_ip"})
v.SetDefault("resolver.management_subnets", []string{})
v.SetDefault("cache.ttl", 3600)
configDir, err := os.UserConfigDir()
if err == nil {
v.SetConfigName("netssh")
v.SetConfigType("yaml")
v.AddConfigPath(filepath.Join(configDir))
v.AddConfigPath(".")
}
v.SetEnvPrefix("NETSSH")
v.AutomaticEnv()
if err := v.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
return nil, fmt.Errorf("reading config: %w", err)
}
}
var cfg Config
if err := v.Unmarshal(&cfg); err != nil {
return nil, fmt.Errorf("parsing config: %w", err)
}
if cfg.Cache.Path == "" {
cacheDir, err := os.UserCacheDir()
if err != nil {
cacheDir = filepath.Join(os.Getenv("HOME"), ".cache")
}
cfg.Cache.Path = filepath.Join(cacheDir, "netssh", "hosts.json")
}
return &cfg, nil
}
+194
View File
@@ -0,0 +1,194 @@
package netbox
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"sync"
)
type Client struct {
baseURL string
token string
httpClient *http.Client
}
func NewClient(baseURL, token string) *Client {
return &Client{
baseURL: strings.TrimRight(baseURL, "/"),
token: token,
httpClient: &http.Client{},
}
}
// Search queries devices and VMs in parallel and merges the results.
func (c *Client) Search(ctx context.Context, query string) ([]HostEntry, error) {
var (
mu sync.Mutex
results []HostEntry
errs []error
wg sync.WaitGroup
)
wg.Add(2)
go func() {
defer wg.Done()
devices, err := c.searchDevices(ctx, query)
mu.Lock()
defer mu.Unlock()
if err != nil {
errs = append(errs, fmt.Errorf("devices: %w", err))
return
}
results = append(results, devices...)
}()
go func() {
defer wg.Done()
vms, err := c.searchVMs(ctx, query)
mu.Lock()
defer mu.Unlock()
if err != nil {
errs = append(errs, fmt.Errorf("vms: %w", err))
return
}
results = append(results, vms...)
}()
wg.Wait()
if len(errs) == 2 {
return nil, fmt.Errorf("netbox search failed: %v; %v", errs[0], errs[1])
}
return results, nil
}
// GetIPs returns all IP addresses assigned to a host, used by resolver strategies
// that need more than just the primary IP.
func (c *Client) GetIPs(ctx context.Context, entry HostEntry) ([]string, error) {
var apiURL string
switch entry.Kind {
case "device":
apiURL = fmt.Sprintf("%s/api/ipam/ip-addresses/?device_id=%d&limit=100", c.baseURL, entry.ID)
case "vm":
apiURL = fmt.Sprintf("%s/api/ipam/ip-addresses/?virtual_machine_id=%d&limit=100", c.baseURL, entry.ID)
default:
return nil, fmt.Errorf("unknown host kind: %q", entry.Kind)
}
var resp netboxIPListResponse
if err := c.get(ctx, apiURL, &resp); err != nil {
return nil, err
}
ips := make([]string, 0, len(resp.Results))
for _, r := range resp.Results {
ips = append(ips, stripPrefix(r.Address))
}
return ips, nil
}
// GetIPsWithFilter calls /api/ipam/ip-addresses/ with arbitrary filter query parameters.
func (c *Client) GetIPsWithFilter(ctx context.Context, filterParams string) ([]string, error) {
apiURL := fmt.Sprintf("%s/api/ipam/ip-addresses/?%s&limit=100", c.baseURL, filterParams)
var resp netboxIPListResponse
if err := c.get(ctx, apiURL, &resp); err != nil {
return nil, err
}
ips := make([]string, 0, len(resp.Results))
for _, r := range resp.Results {
ips = append(ips, stripPrefix(r.Address))
}
return ips, nil
}
func (c *Client) searchDevices(ctx context.Context, query string) ([]HostEntry, error) {
apiURL := fmt.Sprintf("%s/api/dcim/devices/?name__ic=%s&limit=50", c.baseURL, url.QueryEscape(query))
var resp netboxListResponse[netboxDevice]
if err := c.get(ctx, apiURL, &resp); err != nil {
return nil, err
}
entries := make([]HostEntry, 0, len(resp.Results))
for _, d := range resp.Results {
entries = append(entries, deviceToEntry(d))
}
return entries, nil
}
func (c *Client) searchVMs(ctx context.Context, query string) ([]HostEntry, error) {
apiURL := fmt.Sprintf("%s/api/virtualization/virtual-machines/?name__ic=%s&limit=50", c.baseURL, url.QueryEscape(query))
var resp netboxListResponse[netboxVM]
if err := c.get(ctx, apiURL, &resp); err != nil {
return nil, err
}
entries := make([]HostEntry, 0, len(resp.Results))
for _, v := range resp.Results {
entries = append(entries, vmToEntry(v))
}
return entries, nil
}
func (c *Client) get(ctx context.Context, apiURL string, out any) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, apiURL, nil)
if err != nil {
return fmt.Errorf("creating request: %w", err)
}
req.Header.Set("Authorization", "Token "+c.token)
req.Header.Set("Accept", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return fmt.Errorf("request to %s: %w", apiURL, err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("netbox returned %d for %s", resp.StatusCode, apiURL)
}
if err := json.NewDecoder(resp.Body).Decode(out); err != nil {
return fmt.Errorf("decoding response: %w", err)
}
return nil
}
func deviceToEntry(d netboxDevice) HostEntry {
e := HostEntry{ID: d.ID, Name: d.Name, Kind: "device"}
if d.PrimaryIP4 != nil {
e.PrimaryIP4 = stripPrefix(d.PrimaryIP4.Address)
}
if d.PrimaryIP6 != nil {
e.PrimaryIP6 = stripPrefix(d.PrimaryIP6.Address)
}
for _, t := range d.Tags {
e.Tags = append(e.Tags, t.Name)
}
return e
}
func vmToEntry(v netboxVM) HostEntry {
e := HostEntry{ID: v.ID, Name: v.Name, Kind: "vm"}
if v.PrimaryIP4 != nil {
e.PrimaryIP4 = stripPrefix(v.PrimaryIP4.Address)
}
if v.PrimaryIP6 != nil {
e.PrimaryIP6 = stripPrefix(v.PrimaryIP6.Address)
}
for _, t := range v.Tags {
e.Tags = append(e.Tags, t.Name)
}
return e
}
// stripPrefix removes the CIDR prefix length from a NetBox IP (e.g. "10.0.1.5/24" → "10.0.1.5").
func stripPrefix(cidr string) string {
if idx := strings.Index(cidr, "/"); idx != -1 {
return cidr[:idx]
}
return cidr
}
+261
View File
@@ -0,0 +1,261 @@
package netbox
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
// newTestServer returns an httptest.Server that serves fixed responses per path.
func newTestServer(t *testing.T, handlers map[string]any) *httptest.Server {
t.Helper()
mux := http.NewServeMux()
for path, body := range handlers {
b, err := json.Marshal(body)
if err != nil {
t.Fatalf("marshalling handler for %s: %v", path, err)
}
captured := b
mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write(captured)
})
}
return httptest.NewServer(mux)
}
func deviceListResponse(devices ...netboxDevice) netboxListResponse[netboxDevice] {
return netboxListResponse[netboxDevice]{Count: len(devices), Results: devices}
}
func vmListResponse(vms ...netboxVM) netboxListResponse[netboxVM] {
return netboxListResponse[netboxVM]{Count: len(vms), Results: vms}
}
func ipListResponse(addrs ...string) netboxIPListResponse {
resp := netboxIPListResponse{Count: len(addrs)}
for _, a := range addrs {
resp.Results = append(resp.Results, struct {
Address string `json:"address"`
Interface *struct {
Name string `json:"name"`
} `json:"assigned_object"`
}{Address: a})
}
return resp
}
func TestSearch_ReturnsBothDevicesAndVMs(t *testing.T) {
srv := newTestServer(t, map[string]any{
"/api/dcim/devices/": deviceListResponse(
netboxDevice{ID: 1, Name: "router-01", PrimaryIP4: &netboxIP{Address: "10.0.0.1/24"}},
),
"/api/virtualization/virtual-machines/": vmListResponse(
netboxVM{ID: 2, Name: "vm-01", PrimaryIP4: &netboxIP{Address: "10.0.0.2/24"}},
),
})
defer srv.Close()
c := NewClient(srv.URL, "token")
results, err := c.Search(context.Background(), "")
if err != nil {
t.Fatalf("Search: %v", err)
}
if len(results) != 2 {
t.Errorf("got %d results, want 2", len(results))
}
names := map[string]bool{}
for _, r := range results {
names[r.Name] = true
}
if !names["router-01"] || !names["vm-01"] {
t.Errorf("missing expected hosts in results: %v", names)
}
}
func TestSearch_MapsKindCorrectly(t *testing.T) {
srv := newTestServer(t, map[string]any{
"/api/dcim/devices/": deviceListResponse(
netboxDevice{ID: 1, Name: "sw-01"},
),
"/api/virtualization/virtual-machines/": vmListResponse(
netboxVM{ID: 2, Name: "vm-01"},
),
})
defer srv.Close()
c := NewClient(srv.URL, "token")
results, _ := c.Search(context.Background(), "")
for _, r := range results {
switch r.Name {
case "sw-01":
if r.Kind != "device" {
t.Errorf("sw-01 kind: got %q, want %q", r.Kind, "device")
}
case "vm-01":
if r.Kind != "vm" {
t.Errorf("vm-01 kind: got %q, want %q", r.Kind, "vm")
}
}
}
}
func TestSearch_StripsPrefixFromPrimaryIP(t *testing.T) {
srv := newTestServer(t, map[string]any{
"/api/dcim/devices/": deviceListResponse(
netboxDevice{ID: 1, Name: "host", PrimaryIP4: &netboxIP{Address: "192.168.1.10/24"}},
),
"/api/virtualization/virtual-machines/": vmListResponse(),
})
defer srv.Close()
c := NewClient(srv.URL, "token")
results, _ := c.Search(context.Background(), "host")
if len(results) == 0 {
t.Fatal("expected at least one result")
}
if results[0].PrimaryIP4 != "192.168.1.10" {
t.Errorf("PrimaryIP4: got %q, want %q", results[0].PrimaryIP4, "192.168.1.10")
}
}
func TestSearch_TagsAreMapped(t *testing.T) {
srv := newTestServer(t, map[string]any{
"/api/dcim/devices/": deviceListResponse(
netboxDevice{
ID: 1,
Name: "host",
Tags: []struct {
Name string `json:"name"`
}{{Name: "prod"}, {Name: "mgmt"}},
},
),
"/api/virtualization/virtual-machines/": vmListResponse(),
})
defer srv.Close()
c := NewClient(srv.URL, "token")
results, _ := c.Search(context.Background(), "")
if len(results[0].Tags) != 2 {
t.Errorf("tags: got %v, want [prod mgmt]", results[0].Tags)
}
}
func TestSearch_PartialFailure_ReturnsAvailableResults(t *testing.T) {
// Only devices endpoint works; VMs returns 500.
mux := http.NewServeMux()
body, _ := json.Marshal(deviceListResponse(netboxDevice{ID: 1, Name: "sw-01"}))
mux.HandleFunc("/api/dcim/devices/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write(body)
})
mux.HandleFunc("/api/virtualization/virtual-machines/", func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "internal error", http.StatusInternalServerError)
})
srv := httptest.NewServer(mux)
defer srv.Close()
c := NewClient(srv.URL, "token")
results, err := c.Search(context.Background(), "")
if err != nil {
t.Fatalf("partial failure should not return error, got: %v", err)
}
if len(results) != 1 || results[0].Name != "sw-01" {
t.Errorf("expected device results, got %v", results)
}
}
func TestSearch_BothFail_ReturnsError(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "error", http.StatusInternalServerError)
})
srv := httptest.NewServer(mux)
defer srv.Close()
c := NewClient(srv.URL, "token")
_, err := c.Search(context.Background(), "")
if err == nil {
t.Error("both endpoints failing should return an error")
}
}
func TestGetIPs_Device(t *testing.T) {
srv := newTestServer(t, map[string]any{
"/api/ipam/ip-addresses/": ipListResponse("10.0.0.1/24", "10.0.0.2/24"),
})
defer srv.Close()
c := NewClient(srv.URL, "token")
ips, err := c.GetIPs(context.Background(), HostEntry{ID: 1, Kind: "device"})
if err != nil {
t.Fatalf("GetIPs: %v", err)
}
if len(ips) != 2 {
t.Errorf("got %d IPs, want 2", len(ips))
}
if ips[0] != "10.0.0.1" || ips[1] != "10.0.0.2" {
t.Errorf("IPs: got %v, want [10.0.0.1 10.0.0.2]", ips)
}
}
func TestGetIPs_VM(t *testing.T) {
srv := newTestServer(t, map[string]any{
"/api/ipam/ip-addresses/": ipListResponse("172.16.0.5/16"),
})
defer srv.Close()
c := NewClient(srv.URL, "token")
ips, err := c.GetIPs(context.Background(), HostEntry{ID: 2, Kind: "vm"})
if err != nil {
t.Fatalf("GetIPs: %v", err)
}
if len(ips) != 1 || ips[0] != "172.16.0.5" {
t.Errorf("IPs: got %v, want [172.16.0.5]", ips)
}
}
func TestGetIPs_UnknownKind(t *testing.T) {
c := NewClient("http://localhost", "token")
_, err := c.GetIPs(context.Background(), HostEntry{ID: 1, Kind: "unknown"})
if err == nil {
t.Error("unknown kind should return an error")
}
}
func TestGetIPsWithFilter(t *testing.T) {
srv := newTestServer(t, map[string]any{
"/api/ipam/ip-addresses/": ipListResponse("10.10.10.1/24"),
})
defer srv.Close()
c := NewClient(srv.URL, "token")
ips, err := c.GetIPsWithFilter(context.Background(), "device_id=1&interface_name=mgmt0")
if err != nil {
t.Fatalf("GetIPsWithFilter: %v", err)
}
if len(ips) != 1 || ips[0] != "10.10.10.1" {
t.Errorf("IPs: got %v, want [10.10.10.1]", ips)
}
}
func TestStripPrefix(t *testing.T) {
tests := []struct {
in string
want string
}{
{"10.0.0.1/24", "10.0.0.1"},
{"::1/128", "::1"},
{"192.168.1.1", "192.168.1.1"}, // no prefix — unchanged
{"", ""},
}
for _, tt := range tests {
if got := stripPrefix(tt.in); got != tt.want {
t.Errorf("stripPrefix(%q) = %q, want %q", tt.in, got, tt.want)
}
}
}
+53
View File
@@ -0,0 +1,53 @@
package netbox
// HostEntry is a unified model for both devices and virtual machines from NetBox.
type HostEntry struct {
ID int
Name string
Kind string // "device" | "vm"
PrimaryIP4 string // e.g. "10.0.1.5" (prefix length stripped)
PrimaryIP6 string
Tags []string
}
// netboxIP represents an IP address as returned by the NetBox API.
type netboxIP struct {
Address string `json:"address"` // CIDR notation, e.g. "10.0.1.5/24"
}
// netboxDevice matches the relevant fields of the NetBox /dcim/devices/ response.
type netboxDevice struct {
ID int `json:"id"`
Name string `json:"name"`
Tags []struct {
Name string `json:"name"`
} `json:"tags"`
PrimaryIP4 *netboxIP `json:"primary_ip4"`
PrimaryIP6 *netboxIP `json:"primary_ip6"`
}
// netboxVM matches the relevant fields of the NetBox /virtualization/virtual-machines/ response.
type netboxVM struct {
ID int `json:"id"`
Name string `json:"name"`
Tags []struct {
Name string `json:"name"`
} `json:"tags"`
PrimaryIP4 *netboxIP `json:"primary_ip4"`
PrimaryIP6 *netboxIP `json:"primary_ip6"`
}
type netboxListResponse[T any] struct {
Count int `json:"count"`
Results []T `json:"results"`
}
type netboxIPListResponse struct {
Count int `json:"count"`
Results []struct {
Address string `json:"address"`
Interface *struct {
Name string `json:"name"`
} `json:"assigned_object"`
} `json:"results"`
}
+57
View File
@@ -0,0 +1,57 @@
package resolver
import (
"context"
"fmt"
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/config"
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/netbox"
)
// Chain tries each strategy in order until one returns an IP.
type Chain struct {
strategies []Strategy
}
// New builds a Chain from the strategy names listed in the resolver config.
func New(cfg config.ResolverConfig) (*Chain, error) {
var strategies []Strategy
for _, name := range cfg.Strategies {
s, err := newStrategy(name, cfg)
if err != nil {
return nil, fmt.Errorf("resolver strategy %q: %w", name, err)
}
strategies = append(strategies, s)
}
return &Chain{strategies: strategies}, nil
}
func (c *Chain) Resolve(ctx context.Context, entry *netbox.HostEntry, client *netbox.Client) (string, error) {
for _, s := range c.strategies {
ip, err := s.Resolve(ctx, entry, client)
if err == nil {
return ip, nil
}
}
return "", fmt.Errorf("no strategy resolved an IP for %q", entry.Name)
}
func newStrategy(name string, cfg config.ResolverConfig) (Strategy, error) {
switch name {
case "primary_ip":
return &PrimaryIPStrategy{}, nil
case "management_subnet":
s, err := NewManagementSubnetStrategy(cfg.ManagementSubnets)
if err != nil {
return nil, err
}
return s, nil
case "interface_name":
if cfg.InterfaceName == "" {
return nil, fmt.Errorf("interface_name strategy requires resolver.interface_name to be set")
}
return &InterfaceNameStrategy{name: cfg.InterfaceName}, nil
default:
return nil, fmt.Errorf("unknown strategy %q", name)
}
}
+155
View File
@@ -0,0 +1,155 @@
package resolver
import (
"context"
"errors"
"testing"
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/config"
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/netbox"
)
// stubStrategy is a test double for Strategy.
type stubStrategy struct {
name string
ip string
err error
}
func (s *stubStrategy) Name() string { return s.name }
func (s *stubStrategy) Resolve(_ context.Context, _ *netbox.HostEntry, _ *netbox.Client) (string, error) {
return s.ip, s.err
}
func TestChain_FirstStrategySucceeds(t *testing.T) {
c := &Chain{strategies: []Strategy{
&stubStrategy{name: "first", ip: "10.0.0.1"},
&stubStrategy{name: "second", ip: "10.0.0.2"},
}}
ip, err := c.Resolve(context.Background(), &netbox.HostEntry{Name: "host"}, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ip != "10.0.0.1" {
t.Errorf("got %q, want first strategy's IP %q", ip, "10.0.0.1")
}
}
func TestChain_FallsBackToNextStrategy(t *testing.T) {
c := &Chain{strategies: []Strategy{
&stubStrategy{name: "first", err: ErrNoIP},
&stubStrategy{name: "second", ip: "10.0.0.2"},
}}
ip, err := c.Resolve(context.Background(), &netbox.HostEntry{Name: "host"}, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ip != "10.0.0.2" {
t.Errorf("got %q, want second strategy's IP %q", ip, "10.0.0.2")
}
}
func TestChain_AllStrategiesFail(t *testing.T) {
c := &Chain{strategies: []Strategy{
&stubStrategy{name: "a", err: ErrNoIP},
&stubStrategy{name: "b", err: errors.New("api error")},
}}
_, err := c.Resolve(context.Background(), &netbox.HostEntry{Name: "host"}, nil)
if err == nil {
t.Error("expected error when all strategies fail")
}
}
func TestChain_EmptyStrategies(t *testing.T) {
c := &Chain{}
_, err := c.Resolve(context.Background(), &netbox.HostEntry{Name: "host"}, nil)
if err == nil {
t.Error("empty chain should return an error")
}
}
func TestNew_PrimaryIP(t *testing.T) {
cfg := config.ResolverConfig{Strategies: []string{"primary_ip"}}
c, err := New(cfg)
if err != nil {
t.Fatalf("New: %v", err)
}
if len(c.strategies) != 1 {
t.Errorf("got %d strategies, want 1", len(c.strategies))
}
if c.strategies[0].Name() != "primary_ip" {
t.Errorf("strategy name: got %q, want %q", c.strategies[0].Name(), "primary_ip")
}
}
func TestNew_ManagementSubnet(t *testing.T) {
cfg := config.ResolverConfig{
Strategies: []string{"management_subnet"},
ManagementSubnets: []string{"10.0.0.0/8"},
}
c, err := New(cfg)
if err != nil {
t.Fatalf("New: %v", err)
}
if c.strategies[0].Name() != "management_subnet" {
t.Errorf("strategy name: got %q, want %q", c.strategies[0].Name(), "management_subnet")
}
}
func TestNew_ManagementSubnet_InvalidCIDR(t *testing.T) {
cfg := config.ResolverConfig{
Strategies: []string{"management_subnet"},
ManagementSubnets: []string{"not-a-cidr"},
}
_, err := New(cfg)
if err == nil {
t.Error("invalid CIDR should return an error")
}
}
func TestNew_InterfaceName(t *testing.T) {
cfg := config.ResolverConfig{
Strategies: []string{"interface_name"},
InterfaceName: "mgmt0",
}
c, err := New(cfg)
if err != nil {
t.Fatalf("New: %v", err)
}
if c.strategies[0].Name() != "interface_name" {
t.Errorf("strategy name: got %q", c.strategies[0].Name())
}
}
func TestNew_InterfaceName_MissingConfig(t *testing.T) {
cfg := config.ResolverConfig{
Strategies: []string{"interface_name"},
InterfaceName: "", // not set
}
_, err := New(cfg)
if err == nil {
t.Error("interface_name without config should return an error")
}
}
func TestNew_UnknownStrategy(t *testing.T) {
cfg := config.ResolverConfig{Strategies: []string{"nonexistent"}}
_, err := New(cfg)
if err == nil {
t.Error("unknown strategy should return an error")
}
}
func TestNew_MultipleStrategies(t *testing.T) {
cfg := config.ResolverConfig{
Strategies: []string{"management_subnet", "primary_ip"},
ManagementSubnets: []string{"10.0.0.0/8"},
}
c, err := New(cfg)
if err != nil {
t.Fatalf("New: %v", err)
}
if len(c.strategies) != 2 {
t.Errorf("got %d strategies, want 2", len(c.strategies))
}
}
+38
View File
@@ -0,0 +1,38 @@
package resolver
import (
"context"
"fmt"
"net/url"
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/netbox"
)
// InterfaceNameStrategy finds the first IP assigned to a named interface (e.g. "mgmt0", "eth0").
type InterfaceNameStrategy struct {
name string
}
func (s *InterfaceNameStrategy) Name() string { return "interface_name" }
func (s *InterfaceNameStrategy) Resolve(ctx context.Context, entry *netbox.HostEntry, client *netbox.Client) (string, error) {
// Build filter parameters for IP addresses attached to the named interface.
var filterParam string
switch entry.Kind {
case "device":
filterParam = fmt.Sprintf("device_id=%d&interface_name=%s", entry.ID, url.QueryEscape(s.name))
case "vm":
filterParam = fmt.Sprintf("virtual_machine_id=%d&vminterface_name=%s", entry.ID, url.QueryEscape(s.name))
default:
return "", fmt.Errorf("unknown kind %q", entry.Kind)
}
ips, err := client.GetIPsWithFilter(ctx, filterParam)
if err != nil {
return "", fmt.Errorf("fetching IPs for interface %q: %w", s.name, err)
}
if len(ips) == 0 {
return "", ErrNoIP
}
return ips[0], nil
}
+53
View File
@@ -0,0 +1,53 @@
package resolver
import (
"context"
"fmt"
"net"
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/netbox"
)
// ManagementSubnetStrategy finds the first IP of a host that falls within
// one of the configured management subnets.
type ManagementSubnetStrategy struct {
subnets []*net.IPNet
}
func NewManagementSubnetStrategy(cidrs []string) (*ManagementSubnetStrategy, error) {
nets := make([]*net.IPNet, 0, len(cidrs))
for _, cidr := range cidrs {
_, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
return nil, fmt.Errorf("invalid CIDR %q: %w", cidr, err)
}
nets = append(nets, ipNet)
}
return &ManagementSubnetStrategy{subnets: nets}, nil
}
func (s *ManagementSubnetStrategy) Name() string { return "management_subnet" }
func (s *ManagementSubnetStrategy) Resolve(ctx context.Context, entry *netbox.HostEntry, client *netbox.Client) (string, error) {
if len(s.subnets) == 0 {
return "", ErrNoIP
}
ips, err := client.GetIPs(ctx, *entry)
if err != nil {
return "", fmt.Errorf("fetching IPs: %w", err)
}
for _, rawIP := range ips {
ip := net.ParseIP(rawIP)
if ip == nil {
continue
}
for _, subnet := range s.subnets {
if subnet.Contains(ip) {
return rawIP, nil
}
}
}
return "", ErrNoIP
}
+109
View File
@@ -0,0 +1,109 @@
package resolver
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/netbox"
)
// newIPServer returns a test server that always responds with the given IP list.
func newIPServer(t *testing.T, ips []string) *httptest.Server {
t.Helper()
type result struct {
Address string `json:"address"`
}
type response struct {
Count int `json:"count"`
Results []result `json:"results"`
}
resp := response{Count: len(ips)}
for _, ip := range ips {
resp.Results = append(resp.Results, result{Address: ip})
}
body, _ := json.Marshal(resp)
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write(body)
}))
}
func TestManagementSubnetStrategy_MatchesSubnet(t *testing.T) {
srv := newIPServer(t, []string{"10.0.1.5/24", "192.168.0.1/24"})
defer srv.Close()
s, _ := NewManagementSubnetStrategy([]string{"10.0.0.0/8"})
client := netbox.NewClient(srv.URL, "token")
ip, err := s.Resolve(context.Background(), &netbox.HostEntry{ID: 1, Kind: "device"}, client)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ip != "10.0.1.5" {
t.Errorf("got %q, want %q", ip, "10.0.1.5")
}
}
func TestManagementSubnetStrategy_NoMatch(t *testing.T) {
srv := newIPServer(t, []string{"192.168.0.1/24"})
defer srv.Close()
s, _ := NewManagementSubnetStrategy([]string{"10.0.0.0/8"})
client := netbox.NewClient(srv.URL, "token")
_, err := s.Resolve(context.Background(), &netbox.HostEntry{ID: 1, Kind: "device"}, client)
if err != ErrNoIP {
t.Errorf("no matching subnet should return ErrNoIP, got %v", err)
}
}
func TestManagementSubnetStrategy_FirstMatchWins(t *testing.T) {
srv := newIPServer(t, []string{"10.0.1.1/24", "10.0.1.2/24"})
defer srv.Close()
s, _ := NewManagementSubnetStrategy([]string{"10.0.0.0/8"})
client := netbox.NewClient(srv.URL, "token")
ip, err := s.Resolve(context.Background(), &netbox.HostEntry{ID: 1, Kind: "device"}, client)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ip != "10.0.1.1" {
t.Errorf("got %q, want first matching IP %q", ip, "10.0.1.1")
}
}
func TestManagementSubnetStrategy_VMKind(t *testing.T) {
srv := newIPServer(t, []string{"172.16.5.10/16"})
defer srv.Close()
s, _ := NewManagementSubnetStrategy([]string{"172.16.0.0/12"})
client := netbox.NewClient(srv.URL, "token")
ip, err := s.Resolve(context.Background(), &netbox.HostEntry{ID: 2, Kind: "vm"}, client)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ip != "172.16.5.10" {
t.Errorf("got %q, want %q", ip, "172.16.5.10")
}
}
func TestManagementSubnetStrategy_IPv6Subnet(t *testing.T) {
srv := newIPServer(t, []string{"fd00::1/64"})
defer srv.Close()
s, _ := NewManagementSubnetStrategy([]string{"fd00::/8"})
client := netbox.NewClient(srv.URL, "token")
ip, err := s.Resolve(context.Background(), &netbox.HostEntry{ID: 1, Kind: "device"}, client)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ip != "fd00::1" {
t.Errorf("got %q, want %q", ip, "fd00::1")
}
}
+23
View File
@@ -0,0 +1,23 @@
package resolver
import (
"context"
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/netbox"
)
// PrimaryIPStrategy returns the primary IP configured in NetBox.
// Prefers IPv4, falls back to IPv6.
type PrimaryIPStrategy struct{}
func (s *PrimaryIPStrategy) Name() string { return "primary_ip" }
func (s *PrimaryIPStrategy) Resolve(_ context.Context, entry *netbox.HostEntry, _ *netbox.Client) (string, error) {
if entry.PrimaryIP4 != "" {
return entry.PrimaryIP4, nil
}
if entry.PrimaryIP6 != "" {
return entry.PrimaryIP6, nil
}
return "", ErrNoIP
}
+91
View File
@@ -0,0 +1,91 @@
package resolver
import (
"context"
"testing"
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/netbox"
)
func TestPrimaryIPStrategy_Name(t *testing.T) {
s := &PrimaryIPStrategy{}
if s.Name() != "primary_ip" {
t.Errorf("Name: got %q, want %q", s.Name(), "primary_ip")
}
}
func TestPrimaryIPStrategy_IPv4(t *testing.T) {
s := &PrimaryIPStrategy{}
e := &netbox.HostEntry{PrimaryIP4: "10.0.0.1", PrimaryIP6: "::1"}
ip, err := s.Resolve(context.Background(), e, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ip != "10.0.0.1" {
t.Errorf("got %q, want IPv4 %q", ip, "10.0.0.1")
}
}
func TestPrimaryIPStrategy_IPv6Fallback(t *testing.T) {
s := &PrimaryIPStrategy{}
e := &netbox.HostEntry{PrimaryIP6: "::1"}
ip, err := s.Resolve(context.Background(), e, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ip != "::1" {
t.Errorf("got %q, want IPv6 %q", ip, "::1")
}
}
func TestPrimaryIPStrategy_NoIP(t *testing.T) {
s := &PrimaryIPStrategy{}
_, err := s.Resolve(context.Background(), &netbox.HostEntry{}, nil)
if err != ErrNoIP {
t.Errorf("got %v, want ErrNoIP", err)
}
}
func TestManagementSubnetStrategy_Name(t *testing.T) {
s, _ := NewManagementSubnetStrategy([]string{"10.0.0.0/8"})
if s.Name() != "management_subnet" {
t.Errorf("Name: got %q, want %q", s.Name(), "management_subnet")
}
}
func TestManagementSubnetStrategy_InvalidCIDR(t *testing.T) {
_, err := NewManagementSubnetStrategy([]string{"not-a-cidr"})
if err == nil {
t.Error("invalid CIDR should return an error")
}
}
func TestManagementSubnetStrategy_EmptyCIDRs(t *testing.T) {
s, _ := NewManagementSubnetStrategy([]string{})
_, err := s.Resolve(context.Background(), &netbox.HostEntry{}, nil)
if err != ErrNoIP {
t.Errorf("empty subnets should return ErrNoIP, got %v", err)
}
}
func TestManagementSubnetStrategy_MultipleCIDRs(t *testing.T) {
_, err := NewManagementSubnetStrategy([]string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16"})
if err != nil {
t.Fatalf("valid CIDRs should not error: %v", err)
}
}
func TestInterfaceNameStrategy_Name(t *testing.T) {
s := &InterfaceNameStrategy{name: "mgmt0"}
if s.Name() != "interface_name" {
t.Errorf("Name: got %q, want %q", s.Name(), "interface_name")
}
}
func TestInterfaceNameStrategy_UnknownKind(t *testing.T) {
s := &InterfaceNameStrategy{name: "eth0"}
_, err := s.Resolve(context.Background(), &netbox.HostEntry{Kind: "unknown"}, nil)
if err == nil {
t.Error("unknown kind should return an error")
}
}
+17
View File
@@ -0,0 +1,17 @@
package resolver
import (
"context"
"errors"
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/netbox"
)
// ErrNoIP is returned when a strategy cannot find a matching IP address.
var ErrNoIP = errors.New("no matching IP found")
// Strategy is a single rule for resolving an IP address from a NetBox host entry.
type Strategy interface {
Name() string
Resolve(ctx context.Context, entry *netbox.HostEntry, client *netbox.Client) (string, error)
}
+109
View File
@@ -0,0 +1,109 @@
package ssh
import "strings"
// flagsWithArg lists all SSH flags that consume the following argument.
var flagsWithArg = map[byte]bool{
'b': true, 'c': true, 'D': true, 'E': true, 'e': true,
'F': true, 'I': true, 'i': true, 'J': true, 'L': true,
'l': true, 'm': true, 'o': true, 'O': true, 'p': true,
'Q': true, 'R': true, 'S': true, 'w': true, 'W': true,
}
// ParsedArgs holds the result of parsing SSH arguments.
type ParsedArgs struct {
Host string // hostname without the user@ prefix
User string // empty if not specified
DestIdx int // index in Args where [user@]host sits
Args []string
}
// Parse scans SSH arguments and extracts the destination ([user@]host).
// Returns nil if no destination is found.
func Parse(args []string) *ParsedArgs {
i := 0
for i < len(args) {
arg := args[i]
// "--" ends option processing
if arg == "--" {
i++
break
}
if strings.HasPrefix(arg, "-") && len(arg) > 1 {
flag := arg[1]
if flagsWithArg[flag] {
if len(arg) > 2 {
// argument is attached, e.g. -p2222
i++
} else {
// argument is the next element, e.g. -p 2222
i += 2
}
} else {
i++
}
continue
}
// first non-flag argument is the destination
host, user := splitUserHost(arg)
return &ParsedArgs{
Host: host,
User: user,
DestIdx: i,
Args: args,
}
}
// handle arguments after "--"
if i < len(args) {
host, user := splitUserHost(args[i])
return &ParsedArgs{
Host: host,
User: user,
DestIdx: i,
Args: args,
}
}
return nil
}
// ReplaceHost returns a copy of args with the destination replaced by newHost,
// preserving any user@ prefix.
func ReplaceHost(args []string, destIdx int, newHost string) []string {
result := make([]string, len(args))
copy(result, args)
original := args[destIdx]
if at := strings.Index(original, "@"); at != -1 {
result[destIdx] = original[:at+1] + newHost
} else {
result[destIdx] = newHost
}
return result
}
// HasUserFlag reports whether a user was specified via -l in args.
// Used to avoid overriding an explicit -l with the configured default user.
func HasUserFlag(args []string) bool {
for i, a := range args {
if a == "-l" && i+1 < len(args) {
return true
}
// handle attached form: -lroot
if len(a) > 2 && a[0] == '-' && a[1] == 'l' {
return true
}
}
return false
}
func splitUserHost(dest string) (host, user string) {
if at := strings.Index(dest, "@"); at != -1 {
return dest[at+1:], dest[:at]
}
return dest, ""
}
+161
View File
@@ -0,0 +1,161 @@
package ssh
import (
"testing"
)
func TestParse_BareHostname(t *testing.T) {
got := Parse([]string{"myhost"})
assertParsed(t, got, "myhost", "", 0)
}
func TestParse_UserAtHost(t *testing.T) {
got := Parse([]string{"admin@myhost"})
assertParsed(t, got, "myhost", "admin", 0)
}
func TestParse_PortFlag_Separated(t *testing.T) {
got := Parse([]string{"-p", "2222", "myhost"})
assertParsed(t, got, "myhost", "", 2)
}
func TestParse_PortFlag_Attached(t *testing.T) {
got := Parse([]string{"-p2222", "myhost"})
assertParsed(t, got, "myhost", "", 1)
}
func TestParse_IdentityFlag(t *testing.T) {
got := Parse([]string{"-i", "/path/to/key", "user@myhost", "ls"})
assertParsed(t, got, "myhost", "user", 2)
}
func TestParse_VerboseFlag(t *testing.T) {
got := Parse([]string{"-v", "myhost"})
assertParsed(t, got, "myhost", "", 1)
}
func TestParse_OptionFlag(t *testing.T) {
got := Parse([]string{"-o", "StrictHostKeyChecking=no", "myhost"})
assertParsed(t, got, "myhost", "", 2)
}
func TestParse_JumpHost(t *testing.T) {
got := Parse([]string{"-J", "jumphost", "-p", "22", "target"})
assertParsed(t, got, "target", "", 4)
}
func TestParse_MultipleFlags(t *testing.T) {
got := Parse([]string{"-v", "-p", "22", "-i", "key", "root@host", "uptime"})
assertParsed(t, got, "host", "root", 5)
}
func TestParse_DoubleDash(t *testing.T) {
got := Parse([]string{"--", "myhost"})
assertParsed(t, got, "myhost", "", 1)
}
func TestParse_DoubleDash_WithFlags(t *testing.T) {
// flags after -- should be treated as destination
got := Parse([]string{"-v", "--", "-not-a-flag"})
assertParsed(t, got, "-not-a-flag", "", 2)
}
func TestParse_NoDestination(t *testing.T) {
got := Parse([]string{"-v", "-p", "2222"})
if got != nil {
t.Errorf("expected nil for args without destination, got %+v", got)
}
}
func TestParse_EmptyArgs(t *testing.T) {
got := Parse([]string{})
if got != nil {
t.Error("empty args should return nil")
}
}
func TestParse_OnlyDoubleDash(t *testing.T) {
got := Parse([]string{"--"})
if got != nil {
t.Error("only -- with no destination should return nil")
}
}
func TestReplaceHost_PlainHost(t *testing.T) {
args := []string{"myhost"}
result := ReplaceHost(args, 0, "10.0.0.1")
if result[0] != "10.0.0.1" {
t.Errorf("got %q, want %q", result[0], "10.0.0.1")
}
}
func TestReplaceHost_PreservesUserPrefix(t *testing.T) {
args := []string{"-p", "22", "admin@myhost", "ls"}
result := ReplaceHost(args, 2, "10.0.0.1")
if result[2] != "admin@10.0.0.1" {
t.Errorf("got %q, want %q", result[2], "admin@10.0.0.1")
}
}
func TestReplaceHost_DoesNotMutateOriginal(t *testing.T) {
args := []string{"myhost"}
_ = ReplaceHost(args, 0, "10.0.0.1")
if args[0] != "myhost" {
t.Error("ReplaceHost must not mutate the original slice")
}
}
func TestReplaceHost_OtherArgsUnchanged(t *testing.T) {
args := []string{"-p", "22", "myhost"}
result := ReplaceHost(args, 2, "10.0.0.1")
if result[0] != "-p" || result[1] != "22" {
t.Errorf("other args should be unchanged: %v", result)
}
}
func TestHasUserFlag_FlagSeparated(t *testing.T) {
if !HasUserFlag([]string{"-l", "admin", "host"}) {
t.Error("should detect -l <user>")
}
}
func TestHasUserFlag_FlagAttached(t *testing.T) {
if !HasUserFlag([]string{"-ladmin", "host"}) {
t.Error("should detect -l<user> (attached form)")
}
}
func TestHasUserFlag_NotPresent(t *testing.T) {
if HasUserFlag([]string{"-p", "22", "host"}) {
t.Error("should not detect user flag when absent")
}
}
func TestHasUserFlag_EmptyArgs(t *testing.T) {
if HasUserFlag([]string{}) {
t.Error("empty args should return false")
}
}
func TestHasUserFlag_LFlagAtEnd(t *testing.T) {
// -l at the very end with no value — should not panic
if HasUserFlag([]string{"-l"}) {
t.Error("-l with no value should return false")
}
}
func assertParsed(t *testing.T, got *ParsedArgs, host, user string, destIdx int) {
t.Helper()
if got == nil {
t.Fatal("Parse returned nil")
}
if got.Host != host {
t.Errorf("host: got %q, want %q", got.Host, host)
}
if got.User != user {
t.Errorf("user: got %q, want %q", got.User, user)
}
if got.DestIdx != destIdx {
t.Errorf("destIdx: got %d, want %d", got.DestIdx, destIdx)
}
}
+19
View File
@@ -0,0 +1,19 @@
package ssh
import (
"fmt"
"os"
"os/exec"
"syscall"
)
// Exec replaces the current process with the native ssh client via syscall.Exec.
// All existing SSH configs, keys, and agent forwarding remain intact.
func Exec(args []string) error {
sshPath, err := exec.LookPath("ssh")
if err != nil {
return fmt.Errorf("ssh not found in PATH: %w", err)
}
argv := append([]string{"ssh"}, args...)
return syscall.Exec(sshPath, argv, os.Environ())
}
+237
View File
@@ -0,0 +1,237 @@
package tui
import (
"context"
"fmt"
"io"
"strings"
"time"
"github.com/charmbracelet/bubbles/list"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/cache"
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/netbox"
)
// SelectedHost is returned when the user confirms a host in the TUI.
type SelectedHost struct {
Name string
IP string
}
// --- bubbletea messages ---
type debounceMsg struct{ query string }
type searchResultMsg struct {
query string
entries []netbox.HostEntry
err error
}
// --- list item ---
type hostItem struct {
name string
ip string
kind string
}
func (h hostItem) Title() string { return h.name }
func (h hostItem) Description() string { return fmt.Sprintf("%s [%s]", h.ip, h.kind) }
func (h hostItem) FilterValue() string { return h.name }
// --- compact list delegate ---
type compactDelegate struct{}
func (d compactDelegate) Height() int { return 1 }
func (d compactDelegate) Spacing() int { return 0 }
func (d compactDelegate) Update(_ tea.Msg, _ *list.Model) tea.Cmd { return nil }
func (d compactDelegate) Render(w io.Writer, m list.Model, index int, item list.Item) {
h, ok := item.(hostItem)
if !ok {
return
}
line := fmt.Sprintf(" %s %s", h.name, lipgloss.NewStyle().Foreground(lipgloss.Color("240")).Render(h.ip))
if index == m.Index() {
line = lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("86")).Render("> " + strings.TrimPrefix(line, " "))
}
fmt.Fprintln(w, line)
}
// --- Model ---
type Model struct {
input textinput.Model
list list.Model
client *netbox.Client
cache *cache.Cache
lastSent string // last query sent to NetBox (or served from cache)
seq int // sequence number to discard stale results
loading bool
err error
selected *SelectedHost
width int
height int
}
func New(client *netbox.Client, c *cache.Cache) *Model {
ti := textinput.New()
ti.Placeholder = "Search hostname…"
ti.Focus()
l := list.New(nil, compactDelegate{}, 0, 0)
l.SetShowHelp(false)
l.SetShowTitle(false)
l.SetShowStatusBar(false)
l.SetFilteringEnabled(false)
return &Model{
input: ti,
list: l,
client: client,
cache: c,
}
}
func (m *Model) Init() tea.Cmd {
return textinput.Blink
}
func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.WindowSizeMsg:
m.width = msg.Width
m.height = msg.Height
m.list.SetSize(msg.Width, msg.Height-4)
return m, nil
case tea.KeyMsg:
switch msg.String() {
case "ctrl+c", "esc":
return m, tea.Quit
case "enter":
if item, ok := m.list.SelectedItem().(hostItem); ok {
m.selected = &SelectedHost{Name: item.name, IP: item.ip}
return m, tea.Quit
}
case "tab":
// Copy the top result into the search field.
if m.list.Items() != nil && len(m.list.Items()) > 0 {
if item, ok := m.list.Items()[0].(hostItem); ok {
m.input.SetValue(item.name)
m.input.CursorEnd()
}
}
return m, nil
}
case debounceMsg:
// Only query if the input has changed since the last request.
q := m.input.Value()
if q == m.lastSent {
return m, nil
}
m.lastSent = q
m.loading = true
m.seq++
seq := m.seq
return m, m.doSearch(q, seq)
case searchResultMsg:
if msg.query != m.lastSent {
return m, nil // discard stale result
}
m.loading = false
if msg.err != nil {
m.err = msg.err
return m, nil
}
items := make([]list.Item, len(msg.entries))
for i, e := range msg.entries {
ip := e.PrimaryIP4
if ip == "" {
ip = e.PrimaryIP6
}
items[i] = hostItem{name: e.Name, ip: ip, kind: e.Kind}
}
m.list.SetItems(items)
m.err = nil
return m, nil
}
// Forward to text input and restart the debounce timer.
var cmds []tea.Cmd
var inputCmd tea.Cmd
m.input, inputCmd = m.input.Update(msg)
cmds = append(cmds, inputCmd)
cmds = append(cmds, m.startDebounce())
var listCmd tea.Cmd
m.list, listCmd = m.list.Update(msg)
cmds = append(cmds, listCmd)
return m, tea.Batch(cmds...)
}
func (m *Model) View() string {
var sb strings.Builder
title := lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("86")).Render("netssh")
sb.WriteString(title + "\n\n")
sb.WriteString(m.input.View() + "\n")
if m.loading {
sb.WriteString(lipgloss.NewStyle().Foreground(lipgloss.Color("240")).Render(" searching…") + "\n")
} else if m.err != nil {
sb.WriteString(lipgloss.NewStyle().Foreground(lipgloss.Color("9")).Render(" error: "+m.err.Error()) + "\n")
} else {
sb.WriteString(m.list.View())
}
return sb.String()
}
// Selected returns the host chosen by the user, or nil if none was selected.
func (m *Model) Selected() *SelectedHost {
return m.selected
}
func (m *Model) startDebounce() tea.Cmd {
return tea.Tick(300*time.Millisecond, func(_ time.Time) tea.Msg {
return debounceMsg{query: m.input.Value()}
})
}
func (m *Model) doSearch(query string, seq int) tea.Cmd {
return func() tea.Msg {
// Return cache hits immediately without a network round-trip.
if m.cache != nil {
if cached := m.cache.Search(query); len(cached) > 0 {
entries := make([]netbox.HostEntry, len(cached))
for i, c := range cached {
entries[i] = netbox.HostEntry{Name: c.Name, PrimaryIP4: c.IP, Kind: c.Kind}
}
return searchResultMsg{query: query, entries: entries}
}
}
if m.client == nil {
return searchResultMsg{query: query, entries: nil}
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
entries, err := m.client.Search(ctx, query)
_ = seq
return searchResultMsg{query: query, entries: entries, err: err}
}
}