diff --git a/cmd/netssh/main.go b/cmd/netssh/main.go index 19994aa..a222273 100644 --- a/cmd/netssh/main.go +++ b/cmd/netssh/main.go @@ -99,7 +99,7 @@ func runSSHWrapper(args []string) { if cfg.NetBox.URL == "" { fatalf("netbox.url is not configured (~/.config/netssh.yaml)") } - nbClient := netbox.NewClient(cfg.NetBox.URL, cfg.NetBox.Token) + nbClient := netbox.NewClient(cfg.NetBox.URL, cfg.NetBox.Token, cfg.NetBox.TokenVersion) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -152,7 +152,7 @@ func runTUI() { var nbClient *netbox.Client if cfg.NetBox.URL != "" { - nbClient = netbox.NewClient(cfg.NetBox.URL, cfg.NetBox.Token) + nbClient = netbox.NewClient(cfg.NetBox.URL, cfg.NetBox.Token, cfg.NetBox.TokenVersion) } m := tui.New(nbClient, c) @@ -320,7 +320,7 @@ func cacheRefreshCmd() *cobra.Command { return fmt.Errorf("netbox.url is not configured") } - nbClient := netbox.NewClient(cfg.NetBox.URL, cfg.NetBox.Token) + nbClient := netbox.NewClient(cfg.NetBox.URL, cfg.NetBox.Token, cfg.NetBox.TokenVersion) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..26db733 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,145 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +func writeConfig(t *testing.T, content string) string { + t.Helper() + dir := t.TempDir() + path := filepath.Join(dir, "netssh.yaml") + if err := os.WriteFile(path, []byte(content), 0o600); err != nil { + t.Fatal(err) + } + return dir +} + +func loadFromDir(t *testing.T, dir string) *Config { + t.Helper() + // Override UserConfigDir by pointing viper at our temp dir via env isn't + // straightforward, so we exercise Load() by temporarily changing XDG_CONFIG_HOME. + orig := os.Getenv("XDG_CONFIG_HOME") + os.Setenv("XDG_CONFIG_HOME", dir) + t.Cleanup(func() { os.Setenv("XDG_CONFIG_HOME", orig) }) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load: %v", err) + } + return cfg +} + +func TestLoad_V2TokenVersion_Preserved(t *testing.T) { + dir := writeConfig(t, ` +netbox: + url: "https://netbox.example.com" + token: "nbt_abc123" + token_version: 2 +`) + cfg := loadFromDir(t, dir) + if cfg.NetBox.TokenVersion != 2 { + t.Errorf("TokenVersion: got %d, want 2", cfg.NetBox.TokenVersion) + } +} + +func TestLoad_V1TokenVersion_Preserved(t *testing.T) { + dir := writeConfig(t, ` +netbox: + url: "https://netbox.example.com" + token: "legacyhex123" + token_version: 1 +`) + cfg := loadFromDir(t, dir) + if cfg.NetBox.TokenVersion != 1 { + t.Errorf("TokenVersion: got %d, want 1", cfg.NetBox.TokenVersion) + } +} + +func TestLoad_AutoDetectsV2_WhenFieldMissing(t *testing.T) { + dir := writeConfig(t, ` +netbox: + url: "https://netbox.example.com" + token: "nbt_mytoken" +`) + cfg := loadFromDir(t, dir) + if cfg.NetBox.TokenVersion != 2 { + t.Errorf("TokenVersion: got %d, want 2 (auto-detected from nbt_ prefix)", cfg.NetBox.TokenVersion) + } +} + +func TestLoad_AutoDetectsV1_WhenFieldMissing(t *testing.T) { + dir := writeConfig(t, ` +netbox: + url: "https://netbox.example.com" + token: "abc123def456" +`) + cfg := loadFromDir(t, dir) + if cfg.NetBox.TokenVersion != 1 { + t.Errorf("TokenVersion: got %d, want 1 (auto-detected from plain token)", cfg.NetBox.TokenVersion) + } +} + +func TestLoad_TokenVersionZero_WhenNoToken(t *testing.T) { + dir := writeConfig(t, ` +netbox: + url: "https://netbox.example.com" +`) + cfg := loadFromDir(t, dir) + if cfg.NetBox.TokenVersion != 0 { + t.Errorf("TokenVersion: got %d, want 0 (no token present)", cfg.NetBox.TokenVersion) + } +} + +func TestLoad_Defaults(t *testing.T) { + dir := writeConfig(t, ` +netbox: + url: "https://netbox.example.com" + token: "nbt_x" +`) + cfg := loadFromDir(t, dir) + + if cfg.Cache.TTL != 3600 { + t.Errorf("default cache.ttl: got %d, want 3600", cfg.Cache.TTL) + } + if cfg.Cache.Path == "" { + t.Error("cache.path should be auto-set when empty") + } +} + +func TestLoad_MissingFile_ReturnsEmptyConfig(t *testing.T) { + orig := os.Getenv("XDG_CONFIG_HOME") + os.Setenv("XDG_CONFIG_HOME", t.TempDir()) // dir exists but no netssh.yaml + defer os.Setenv("XDG_CONFIG_HOME", orig) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load on missing file should not error: %v", err) + } + if cfg.NetBox.URL != "" { + t.Errorf("expected empty URL, got %q", cfg.NetBox.URL) + } +} + +func TestLoad_InvalidYAML_ReturnsError(t *testing.T) { + dir := writeConfig(t, "not: valid: yaml: [[[") + orig := os.Getenv("XDG_CONFIG_HOME") + os.Setenv("XDG_CONFIG_HOME", dir) + defer os.Setenv("XDG_CONFIG_HOME", orig) + + _, err := Load() + if err == nil { + t.Error("invalid YAML should return an error") + } +} + +func TestPath_ReturnsNonEmpty(t *testing.T) { + p := Path() + if p == "" { + t.Error("Path() should return a non-empty string") + } + if filepath.Base(p) != "netssh.yaml" { + t.Errorf("Path() base: got %q, want netssh.yaml", filepath.Base(p)) + } +} diff --git a/internal/netbox/client.go b/internal/netbox/client.go index 07894b4..cdfa2a5 100644 --- a/internal/netbox/client.go +++ b/internal/netbox/client.go @@ -11,16 +11,23 @@ import ( ) type Client struct { - baseURL string - token string - httpClient *http.Client + baseURL string + token string + tokenVersion int + httpClient *http.Client } -func NewClient(baseURL, token string) *Client { +// NewClient creates a NetBox API client. Pass tokenVersion=0 to auto-detect +// from the token string (1 for legacy, 2 for nbt_-prefixed tokens). +func NewClient(baseURL, token string, tokenVersion int) *Client { + if tokenVersion == 0 { + tokenVersion = TokenVersion(token) + } return &Client{ - baseURL: strings.TrimRight(baseURL, "/"), - token: token, - httpClient: &http.Client{}, + baseURL: strings.TrimRight(baseURL, "/"), + token: token, + tokenVersion: tokenVersion, + httpClient: &http.Client{}, } } @@ -157,7 +164,7 @@ func (c *Client) get(ctx context.Context, apiURL string, out any) error { if resp.StatusCode == http.StatusForbidden { hint := "check token permissions in NetBox" - if TokenVersion(c.token) == 1 { + if c.tokenVersion == 1 { hint += " — legacy v1 token detected, consider upgrading to a v2 token (starts with nbt_)" } return fmt.Errorf("%s: %s", apiURL, hint) diff --git a/internal/netbox/client_test.go b/internal/netbox/client_test.go index cb50b79..5ac023b 100644 --- a/internal/netbox/client_test.go +++ b/internal/netbox/client_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "strings" "testing" ) @@ -58,7 +59,7 @@ func TestSearch_ReturnsBothDevicesAndVMs(t *testing.T) { }) defer srv.Close() - c := NewClient(srv.URL, "token") + c := NewClient(srv.URL, "token", 0) results, err := c.Search(context.Background(), "") if err != nil { t.Fatalf("Search: %v", err) @@ -87,7 +88,7 @@ func TestSearch_MapsKindCorrectly(t *testing.T) { }) defer srv.Close() - c := NewClient(srv.URL, "token") + c := NewClient(srv.URL, "token", 0) results, _ := c.Search(context.Background(), "") for _, r := range results { @@ -113,7 +114,7 @@ func TestSearch_StripsPrefixFromPrimaryIP(t *testing.T) { }) defer srv.Close() - c := NewClient(srv.URL, "token") + c := NewClient(srv.URL, "token", 0) results, _ := c.Search(context.Background(), "host") if len(results) == 0 { t.Fatal("expected at least one result") @@ -138,7 +139,7 @@ func TestSearch_TagsAreMapped(t *testing.T) { }) defer srv.Close() - c := NewClient(srv.URL, "token") + c := NewClient(srv.URL, "token", 0) results, _ := c.Search(context.Background(), "") if len(results[0].Tags) != 2 { t.Errorf("tags: got %v, want [prod mgmt]", results[0].Tags) @@ -159,7 +160,7 @@ func TestSearch_PartialFailure_ReturnsAvailableResults(t *testing.T) { srv := httptest.NewServer(mux) defer srv.Close() - c := NewClient(srv.URL, "token") + c := NewClient(srv.URL, "token", 0) results, err := c.Search(context.Background(), "") if err != nil { t.Fatalf("partial failure should not return error, got: %v", err) @@ -177,7 +178,7 @@ func TestSearch_BothFail_ReturnsError(t *testing.T) { srv := httptest.NewServer(mux) defer srv.Close() - c := NewClient(srv.URL, "token") + c := NewClient(srv.URL, "token", 0) _, err := c.Search(context.Background(), "") if err == nil { t.Error("both endpoints failing should return an error") @@ -190,7 +191,7 @@ func TestGetIPs_Device(t *testing.T) { }) defer srv.Close() - c := NewClient(srv.URL, "token") + c := NewClient(srv.URL, "token", 0) ips, err := c.GetIPs(context.Background(), HostEntry{ID: 1, Kind: "device"}) if err != nil { t.Fatalf("GetIPs: %v", err) @@ -209,7 +210,7 @@ func TestGetIPs_VM(t *testing.T) { }) defer srv.Close() - c := NewClient(srv.URL, "token") + c := NewClient(srv.URL, "token", 0) ips, err := c.GetIPs(context.Background(), HostEntry{ID: 2, Kind: "vm"}) if err != nil { t.Fatalf("GetIPs: %v", err) @@ -220,7 +221,7 @@ func TestGetIPs_VM(t *testing.T) { } func TestGetIPs_UnknownKind(t *testing.T) { - c := NewClient("http://localhost", "token") + c := NewClient("http://localhost", "token", 0) _, err := c.GetIPs(context.Background(), HostEntry{ID: 1, Kind: "unknown"}) if err == nil { t.Error("unknown kind should return an error") @@ -233,7 +234,7 @@ func TestGetIPsWithFilter(t *testing.T) { }) defer srv.Close() - c := NewClient(srv.URL, "token") + c := NewClient(srv.URL, "token", 0) ips, err := c.GetIPsWithFilter(context.Background(), "device_id=1&interface_name=mgmt0") if err != nil { t.Fatalf("GetIPsWithFilter: %v", err) @@ -243,6 +244,98 @@ func TestGetIPsWithFilter(t *testing.T) { } } +func TestTokenVersion(t *testing.T) { + tests := []struct { + token string + want int + }{ + {"nbt_abc123", 2}, + {"nbt_", 2}, + {"abc123def456", 1}, + {"", 1}, + {"Token abc", 1}, + } + for _, tt := range tests { + if got := TokenVersion(tt.token); got != tt.want { + t.Errorf("TokenVersion(%q) = %d, want %d", tt.token, got, tt.want) + } + } +} + +func TestNewClient_AutoDetectsVersion(t *testing.T) { + c := NewClient("http://localhost", "nbt_secret", 0) + if c.tokenVersion != 2 { + t.Errorf("tokenVersion: got %d, want 2", c.tokenVersion) + } + + c2 := NewClient("http://localhost", "legacytoken", 0) + if c2.tokenVersion != 1 { + t.Errorf("tokenVersion: got %d, want 1", c2.tokenVersion) + } +} + +func TestNewClient_RespectsExplicitVersion(t *testing.T) { + // Explicit version overrides auto-detection. + c := NewClient("http://localhost", "legacytoken", 2) + if c.tokenVersion != 2 { + t.Errorf("tokenVersion: got %d, want 2", c.tokenVersion) + } +} + +func Test403_V1Token_HintsUpgrade(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "forbidden", http.StatusForbidden) + })) + defer srv.Close() + + c := NewClient(srv.URL, "legacytoken", 1) + _, err := c.Search(context.Background(), "host") + if err == nil { + t.Fatal("expected error on 403") + } + if !strings.Contains(err.Error(), "v1 token") { + t.Errorf("expected v1 hint in error, got: %v", err) + } +} + +func Test403_V2Token_NoV1Hint(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "forbidden", http.StatusForbidden) + })) + defer srv.Close() + + c := NewClient(srv.URL, "nbt_secret", 2) + _, err := c.Search(context.Background(), "host") + if err == nil { + t.Fatal("expected error on 403") + } + if strings.Contains(err.Error(), "v1 token") { + t.Errorf("v1 hint should not appear for v2 token, got: %v", err) + } + if !strings.Contains(err.Error(), "check token permissions") { + t.Errorf("expected permissions hint in error, got: %v", err) + } +} + +func TestGet_SendsAuthorizationHeader(t *testing.T) { + var gotAuth string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + w.Header().Set("Content-Type", "application/json") + b, _ := json.Marshal(deviceListResponse()) + w.Write(b) + })) + defer srv.Close() + + c := NewClient(srv.URL, "nbt_mytoken", 2) + c.Search(context.Background(), "") //nolint:errcheck + + want := "Token nbt_mytoken" + if gotAuth != want { + t.Errorf("Authorization header: got %q, want %q", gotAuth, want) + } +} + func TestStripPrefix(t *testing.T) { tests := []struct { in string diff --git a/internal/resolver/management_test.go b/internal/resolver/management_test.go index 56f9d00..f92bb39 100644 --- a/internal/resolver/management_test.go +++ b/internal/resolver/management_test.go @@ -36,7 +36,7 @@ func TestManagementSubnetStrategy_MatchesSubnet(t *testing.T) { defer srv.Close() s, _ := NewManagementSubnetStrategy([]string{"10.0.0.0/8"}) - client := netbox.NewClient(srv.URL, "token") + client := netbox.NewClient(srv.URL, "token", 0) ip, err := s.Resolve(context.Background(), &netbox.HostEntry{ID: 1, Kind: "device"}, client) if err != nil { @@ -52,7 +52,7 @@ func TestManagementSubnetStrategy_NoMatch(t *testing.T) { defer srv.Close() s, _ := NewManagementSubnetStrategy([]string{"10.0.0.0/8"}) - client := netbox.NewClient(srv.URL, "token") + client := netbox.NewClient(srv.URL, "token", 0) _, err := s.Resolve(context.Background(), &netbox.HostEntry{ID: 1, Kind: "device"}, client) if err != ErrNoIP { @@ -65,7 +65,7 @@ func TestManagementSubnetStrategy_FirstMatchWins(t *testing.T) { defer srv.Close() s, _ := NewManagementSubnetStrategy([]string{"10.0.0.0/8"}) - client := netbox.NewClient(srv.URL, "token") + client := netbox.NewClient(srv.URL, "token", 0) ip, err := s.Resolve(context.Background(), &netbox.HostEntry{ID: 1, Kind: "device"}, client) if err != nil { @@ -81,7 +81,7 @@ func TestManagementSubnetStrategy_VMKind(t *testing.T) { defer srv.Close() s, _ := NewManagementSubnetStrategy([]string{"172.16.0.0/12"}) - client := netbox.NewClient(srv.URL, "token") + client := netbox.NewClient(srv.URL, "token", 0) ip, err := s.Resolve(context.Background(), &netbox.HostEntry{ID: 2, Kind: "vm"}, client) if err != nil { @@ -97,7 +97,7 @@ func TestManagementSubnetStrategy_IPv6Subnet(t *testing.T) { defer srv.Close() s, _ := NewManagementSubnetStrategy([]string{"fd00::/8"}) - client := netbox.NewClient(srv.URL, "token") + client := netbox.NewClient(srv.URL, "token", 0) ip, err := s.Resolve(context.Background(), &netbox.HostEntry{ID: 1, Kind: "device"}, client) if err != nil { diff --git a/internal/setup/wizard_test.go b/internal/setup/wizard_test.go new file mode 100644 index 0000000..95b1948 --- /dev/null +++ b/internal/setup/wizard_test.go @@ -0,0 +1,164 @@ +package setup + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/config" +) + +func TestSave_WritesFile(t *testing.T) { + dir := t.TempDir() + orig := os.Getenv("XDG_CONFIG_HOME") + os.Setenv("XDG_CONFIG_HOME", dir) + defer os.Setenv("XDG_CONFIG_HOME", orig) + + cfg := config.Config{ + NetBox: config.NetBoxConfig{ + URL: "https://netbox.example.com", + Token: "nbt_abc123", + TokenVersion: 2, + }, + SSH: config.SSHConfig{DefaultUser: "admin"}, + Resolver: config.ResolverConfig{ + Strategies: []string{"primary_ip", "management_subnet"}, + ManagementSubnets: []string{"10.0.0.0/8"}, + }, + Cache: config.CacheConfig{TTL: 3600}, + } + + if err := save(cfg); err != nil { + t.Fatalf("save: %v", err) + } + + data, err := os.ReadFile(filepath.Join(dir, "netssh.yaml")) + if err != nil { + t.Fatalf("reading saved file: %v", err) + } + content := string(data) + + for _, want := range []string{ + `"https://netbox.example.com"`, + `"nbt_abc123"`, + `token_version: 2`, + `- primary_ip`, + `- management_subnet`, + `- 10.0.0.0/8`, + `ttl: 3600`, + `"admin"`, + } { + if !strings.Contains(content, want) { + t.Errorf("saved config missing %q\nfull content:\n%s", want, content) + } + } +} + +func TestSave_FilePermissions(t *testing.T) { + dir := t.TempDir() + orig := os.Getenv("XDG_CONFIG_HOME") + os.Setenv("XDG_CONFIG_HOME", dir) + defer os.Setenv("XDG_CONFIG_HOME", orig) + + if err := save(config.Config{ + NetBox: config.NetBoxConfig{URL: "http://x", Token: "t", TokenVersion: 1}, + Cache: config.CacheConfig{TTL: 60}, + }); err != nil { + t.Fatalf("save: %v", err) + } + + info, err := os.Stat(filepath.Join(dir, "netssh.yaml")) + if err != nil { + t.Fatalf("stat: %v", err) + } + if perm := info.Mode().Perm(); perm != 0o600 { + t.Errorf("file permissions: got %o, want 600", perm) + } +} + +func TestSave_OmitsEmptyOptionalFields(t *testing.T) { + dir := t.TempDir() + orig := os.Getenv("XDG_CONFIG_HOME") + os.Setenv("XDG_CONFIG_HOME", dir) + defer os.Setenv("XDG_CONFIG_HOME", orig) + + cfg := config.Config{ + NetBox: config.NetBoxConfig{URL: "http://x", Token: "t", TokenVersion: 1}, + Cache: config.CacheConfig{TTL: 60}, + // No DefaultUser, no ManagementSubnets, no InterfaceName + } + if err := save(cfg); err != nil { + t.Fatalf("save: %v", err) + } + + data, _ := os.ReadFile(filepath.Join(dir, "netssh.yaml")) + content := string(data) + + for _, absent := range []string{"default_user", "management_subnets", "interface_name"} { + if strings.Contains(content, absent) { + t.Errorf("config should not contain %q when field is empty\nfull content:\n%s", absent, content) + } + } +} + +func TestSave_CreatesConfigDir(t *testing.T) { + dir := filepath.Join(t.TempDir(), "does", "not", "exist") + orig := os.Getenv("XDG_CONFIG_HOME") + os.Setenv("XDG_CONFIG_HOME", dir) + defer os.Setenv("XDG_CONFIG_HOME", orig) + + if err := save(config.Config{ + NetBox: config.NetBoxConfig{URL: "http://x", Token: "t", TokenVersion: 1}, + Cache: config.CacheConfig{TTL: 60}, + }); err != nil { + t.Fatalf("save should create missing directories: %v", err) + } +} + +func TestSave_RoundtripViaLoad(t *testing.T) { + dir := t.TempDir() + orig := os.Getenv("XDG_CONFIG_HOME") + os.Setenv("XDG_CONFIG_HOME", dir) + defer os.Setenv("XDG_CONFIG_HOME", orig) + + original := config.Config{ + NetBox: config.NetBoxConfig{ + URL: "https://netbox.zb-server.de", + Token: "nbt_supersecret", + TokenVersion: 2, + }, + SSH: config.SSHConfig{DefaultUser: "root"}, + Resolver: config.ResolverConfig{ + Strategies: []string{"primary_ip"}, + ManagementSubnets: []string{"192.168.0.0/16"}, + InterfaceName: "eth0", + }, + Cache: config.CacheConfig{TTL: 7200}, + } + + if err := save(original); err != nil { + t.Fatalf("save: %v", err) + } + + loaded, err := config.Load() + if err != nil { + t.Fatalf("Load after save: %v", err) + } + + if loaded.NetBox.URL != original.NetBox.URL { + t.Errorf("URL: got %q, want %q", loaded.NetBox.URL, original.NetBox.URL) + } + if loaded.NetBox.Token != original.NetBox.Token { + t.Errorf("Token: got %q, want %q", loaded.NetBox.Token, original.NetBox.Token) + } + if loaded.NetBox.TokenVersion != original.NetBox.TokenVersion { + t.Errorf("TokenVersion: got %d, want %d", loaded.NetBox.TokenVersion, original.NetBox.TokenVersion) + } + if loaded.SSH.DefaultUser != original.SSH.DefaultUser { + t.Errorf("DefaultUser: got %q, want %q", loaded.SSH.DefaultUser, original.SSH.DefaultUser) + } + if loaded.Cache.TTL != original.Cache.TTL { + t.Errorf("TTL: got %d, want %d", loaded.Cache.TTL, original.Cache.TTL) + } +}