diff --git a/README.md b/README.md index 2f6c1a4..799cf62 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,10 @@ netssh -p 2222 admin@app-server-03 uptime - **Setup wizard** — interactive first-run onboarding; re-run anytime with `netssh configure` - **Shell completion** — install without sudo via `netssh completion install` - **Default SSH user** — set a fallback username once in config instead of typing it every time +- **Default SSH port** — set a fallback port in config; injected as `-p` when not specified on the command line +- **Host shortcuts** — type `web01` instead of `web01.example.com`; configurable domain strip and hyphen folding +- **Shell aliases** — `netssh alias` generates shell aliases for all cached hosts +- **Automatic 24h cache refresh** — the cache is refreshed in the background every 24 hours ## Installation @@ -94,6 +98,12 @@ cache: ssh: default_user: admin # used when no user is specified on the command line + default_port: 2222 # optional; injected as -p if not specified on the command line + +shortcuts: + domains: # strip these suffixes to create short aliases + - .example.com + strip_hyphens: false # if true, fsn1-web01 → fsn1web01 (cache-only resolution) ``` Any value can be overridden with environment variables (`NETSSH_NETBOX_URL`, `NETSSH_NETBOX_TOKEN`, etc.). @@ -139,6 +149,76 @@ netssh root@my-router # user@ prefix wins → ssh root@10.0.0.1 netssh -l ops my-router # -l flag wins → ssh -l ops 10.0.0.1 ``` +### Default SSH port + +Set `ssh.default_port` in the config to use a non-standard port without specifying it every time: + +```sh +netssh my-router # → ssh -p 2222 -l admin 10.0.0.1 +``` + +The default is only applied when `-p` is not already present in the SSH arguments. An explicit `-p` always takes precedence: + +```sh +netssh -p 22 my-router # -p flag wins → ssh -p 22 10.0.0.1 +``` + +### Host shortcuts + +Configure short aliases so you can type `web01` instead of `web01.example.com`: + +```yaml +shortcuts: + domains: + - .example.com + - .example.de + strip_hyphens: false # optional +``` + +Resolution order: + +1. **Exact match** — the input is looked up as-is in the cache. +2. **Domain expansion** — if the input contains no dot, each configured domain is appended and tried (`web01` → `web01.example.com`). If no cache hit, NetBox is queried with the expanded name. +3. **Hyphen folding** (`strip_hyphens: true`) — hyphens are inserted back before looking up in the cache (`fsn1web01` → `fsn1-web01.example.com`). This only works when the host is already cached; run `netssh cache refresh` first. + +| Input | Config | Resolved as | +|-------|--------|-------------| +| `web01` | `domains: [.example.com]` | `web01.example.com` | +| `fsn1web01` | `domains: [.example.com]`, `strip_hyphens: true` | `fsn1-web01.example.com` (cache only) | + +### Shell aliases + +Generate shell aliases for all cached hosts based on the shortcut settings: + +```sh +netssh alias # bash/zsh syntax +netssh alias --shell fish # fish syntax +``` + +Example output: + +```bash +# netssh aliases (3 hosts) — source with: eval "$(netssh alias)" +alias db01='netssh db01.example.com' +alias fsn1web01='netssh fsn1-web01.example.com' +alias web01='netssh web01.example.com' +``` + +Add to your shell startup file: + +```sh +# bash / zsh — in ~/.bashrc or ~/.zshrc +eval "$(netssh alias)" + +# fish — in ~/.config/fish/config.fish +netssh alias --shell fish | source + +# scripts +source <(netssh alias) +``` + +Without a `shortcuts` config, the alias name is derived from the full hostname with dots replaced by underscores (e.g. `web01_example_com`). If two hosts would produce the same alias name, the first entry wins and a warning is printed to stderr. + ### Interactive TUI Run without arguments to open the interactive search: @@ -186,6 +266,25 @@ netssh cache refresh --tag prod --kind vm # combine filters netssh cache clear # wipe the cache ``` +The cache is refreshed automatically in the background every 24 hours. The trigger fires on the next SSH connect or TUI start after the interval has elapsed — there is no delay to the connection itself. The timestamp of the last refresh is stored in `~/.cache/netssh/last_refresh`. + +To also trigger the check at shell startup (before you run any SSH command), install the shell hook: + +```sh +netssh hook install # auto-detects $SHELL +netssh hook install --shell bash +netssh hook install --shell zsh +netssh hook install --shell fish +``` + +This appends a single line (`netssh shell-init`) to your shell profile. To remove it: + +```sh +netssh hook uninstall +``` + +The hook is a no-op when the cache is fresh — it adds no measurable delay to your shell startup. + ### Search (for scripting) ```sh @@ -228,6 +327,59 @@ autoload -Uz compinit && compinit Completions are served from the local cache — no network request on every ``. +## Shell Hook + +The shell hook runs `netssh shell-init` at the start of every new shell session. It checks whether the cache is older than 24 hours and, if so, starts a background refresh. The check reads a single small file and adds no measurable delay to your shell startup. + +### Install + +```sh +netssh hook install # auto-detects $SHELL +netssh hook install --shell bash +netssh hook install --shell zsh +netssh hook install --shell fish +``` + +This appends exactly one line to your shell profile: + +```sh +netssh shell-init # netssh cache auto-refresh +``` + +| Shell | Profile file | +|-------|-------------| +| bash | `~/.bashrc` | +| zsh | `~/.zshrc` | +| fish | `~/.config/fish/config.fish` | + +The install is idempotent — running it again does nothing if the hook is already present. + +Reload your profile after installation: + +```sh +source ~/.bashrc # bash +source ~/.zshrc # zsh +source ~/.config/fish/config.fish # fish +``` + +### Uninstall + +```sh +netssh hook uninstall # auto-detects $SHELL +netssh hook uninstall --shell zsh +``` + +Removes the `netssh shell-init` line from the profile and collapses any blank lines left behind. + +### How it differs from the connect-time trigger + +| Trigger | When it fires | +|---------|--------------| +| Connect / TUI start | On the next SSH command or `netssh` TUI after 24 h | +| Shell hook | On the first new shell session after 24 h | + +Both triggers are non-blocking: the refresh runs in the background and your SSH connection (or prompt) is not delayed. You can install both — they share the same `~/.cache/netssh/last_refresh` timestamp, so the background process runs at most once per 24 hours regardless of how many shells or connections you open. + ## Development ```sh @@ -235,7 +387,7 @@ go test ./... # run all tests go build ./... # build all packages ``` -The test suite covers the cache, NetBox client (via `httptest`), IP resolver chain, SSH argument parser, config loading, and the setup wizard. +The test suite covers the cache, NetBox client (via `httptest`), IP resolver chain, SSH argument parser, config loading, setup wizard, shortcut normalization, and shell hook install/uninstall. ## Disclaimer @@ -246,6 +398,6 @@ This is a **vibe-coded** project: the entire codebase — architecture, implemen 1. `netssh` checks whether the first argument is a known subcommand (`configure`, `search`, `cache`, `completion`). If not, it enters SSH wrapper mode. 2. On first run or when `netbox.url` is empty, the interactive setup wizard starts automatically. 3. It parses the SSH arguments to extract the destination hostname, handling all flags that consume an extra argument (`-p`, `-i`, `-J`, …). -4. It checks the local cache. If the entry exists and is within the TTL, it connects immediately. +4. It resolves the destination: first an exact cache lookup, then shortcut expansion (domain append, hyphen folding) against the cache, and finally a fresh NetBox query with the expanded name. If the cache entry is within the TTL, no network request is made. 5. Otherwise it queries NetBox (`/api/dcim/devices/` and `/api/virtualization/virtual-machines/` in parallel), runs the result through the resolver chain, and caches the IP. 6. It calls `syscall.Exec` to replace itself with `ssh`, substituting the hostname with the resolved IP. diff --git a/cmd/netssh/main.go b/cmd/netssh/main.go index 64e101b..cdc084f 100644 --- a/cmd/netssh/main.go +++ b/cmd/netssh/main.go @@ -4,8 +4,10 @@ import ( "context" "fmt" "os" + "os/exec" "path/filepath" "sort" + "strconv" "strings" "text/tabwriter" "time" @@ -15,9 +17,11 @@ import ( "git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/cache" "git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/config" + "git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/hook" "git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/netbox" "git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/resolver" "git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/setup" + "git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/shortcuts" internalssh "git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/ssh" "git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/tui" ) @@ -27,6 +31,9 @@ var managedSubcommands = map[string]bool{ "configure": true, "search": true, "cache": true, + "alias": true, + "hook": true, + "shell-init": true, "completion": true, "__complete": true, "help": true, @@ -87,9 +94,15 @@ func runSSHWrapper(args []string) { args = append([]string{"-l", cfg.SSH.DefaultUser}, args...) parsed.DestIdx += 2 } + // Inject the configured default port if none was given on the command line. + if cfg.SSH.DefaultPort > 0 && !internalssh.HasPortFlag(args) { + args = append([]string{"-p", strconv.Itoa(cfg.SSH.DefaultPort)}, args...) + parsed.DestIdx += 2 + } c := cache.New(cfg.Cache.Path, cfg.Cache.TTL) _ = c.Load() + maybeBackgroundRefresh(cfg, c) // Cache hit with a fresh TTL — connect directly without querying NetBox. if entry, fresh := c.Get(parsed.Host); fresh { @@ -99,6 +112,24 @@ func runSSHWrapper(args []string) { return } + // Shortcut cache lookup: scan entries whose normalized name matches the input. + // If a stale match is found, use the canonical name for the NetBox re-fetch below. + lookupHost := parsed.Host + shortcutsEnabled := len(cfg.Shortcuts.Domains) > 0 || cfg.Shortcuts.StripHyphens + if shortcutsEnabled { + normalize := shortcuts.MakeNormalizer(cfg.Shortcuts) + if entry, found, fresh := c.GetByShortcut(lookupHost, normalize); found { + if fresh { + c.MarkUsed(entry.Name) + _ = c.Save() + connect(entry.IP, parsed, args) + return + } + // Stale shortcut match — use canonical name for NetBox re-fetch. + lookupHost = entry.Name + } + } + if cfg.NetBox.URL == "" { fatalf("netbox.url is not configured (~/.config/netssh.yaml)") } @@ -107,18 +138,40 @@ func runSSHWrapper(args []string) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - entries, err := nbClient.Search(ctx, parsed.Host, netbox.SearchOptions{}) + entries, err := nbClient.Search(ctx, lookupHost, netbox.SearchOptions{}) if err != nil { fatalf("NetBox search failed: %v", err) } var target *netbox.HostEntry for i, e := range entries { - if strings.EqualFold(e.Name, parsed.Host) { + if strings.EqualFold(e.Name, lookupHost) { target = &entries[i] break } } + + // Domain expansion: if still no match and the input has no dots, try appending + // each configured domain suffix and re-querying NetBox. + if target == nil && !strings.Contains(parsed.Host, ".") && len(cfg.Shortcuts.Domains) > 0 { + for _, domain := range cfg.Shortcuts.Domains { + expanded := parsed.Host + domain + expandedEntries, searchErr := nbClient.Search(ctx, expanded, netbox.SearchOptions{}) + if searchErr != nil { + continue + } + for i, e := range expandedEntries { + if strings.EqualFold(e.Name, expanded) { + target = &expandedEntries[i] + break + } + } + if target != nil { + break + } + } + } + if target == nil { fatalf("host %q not found in NetBox", parsed.Host) } @@ -153,6 +206,7 @@ func runTUI() { c := cache.New(cfg.Cache.Path, cfg.Cache.TTL) _ = c.Load() + maybeBackgroundRefresh(cfg, c) var nbClient *netbox.Client if cfg.NetBox.URL != "" { @@ -188,8 +242,12 @@ func runTUI() { sshArgs = append(sshArgs, "-l", user) } - if host.Port != "" { - sshArgs = append(sshArgs, "-p", host.Port) + port := host.Port + if port == "" && cfg.SSH.DefaultPort > 0 { + port = strconv.Itoa(cfg.SSH.DefaultPort) + } + if port != "" { + sshArgs = append(sshArgs, "-p", port) } sshArgs = append(sshArgs, host.IP) @@ -231,7 +289,7 @@ func rootCmd() *cobra.Command { }, } - root.AddCommand(configureCmd(), searchCmd(), cacheCmd()) + root.AddCommand(configureCmd(), searchCmd(), cacheCmd(), aliasCmd(), hookCmd(), shellInitCmd()) // cobra builds the "completion" command lazily; force init so we can extend it. root.InitDefaultCompletionCmd() @@ -460,6 +518,9 @@ func cacheRefreshCmd() *cobra.Command { if err := c.Save(); err != nil { return err } + if err := c.SetRefreshed(); err != nil { + fmt.Fprintf(os.Stderr, "warning: could not update refresh timestamp: %v\n", err) + } fmt.Printf("%d entries written to cache.\n", len(entries)) return nil }, @@ -469,6 +530,193 @@ func cacheRefreshCmd() *cobra.Command { return cmd } +func aliasCmd() *cobra.Command { + var shell string + cmd := &cobra.Command{ + Use: "alias", + Short: "Print shell aliases for all cached hosts", + Long: `Print shell alias definitions for all cached hosts. +The alias name is the shortened form derived from the configured shortcuts +(domain suffixes stripped, hyphens optionally stripped). + +Source the output in your shell profile: + bash/zsh: eval "$(netssh alias)" + fish: netssh alias --shell fish | source + +Or use in a script: + source <(netssh alias)`, + RunE: func(cmd *cobra.Command, args []string) error { + cfg, err := config.Load() + if err != nil { + return err + } + c := cache.New(cfg.Cache.Path, cfg.Cache.TTL) + if err := c.Load(); err != nil { + return err + } + if shell == "" { + shell = filepath.Base(os.Getenv("SHELL")) + } + + entries := c.All() + sort.Slice(entries, func(i, j int) bool { + return entries[i].Name < entries[j].Name + }) + + // Deduplicate: first host wins when two normalize to the same alias. + seen := make(map[string]string) // alias name → canonical name + var lines []string + for _, e := range entries { + aliasName := shortcuts.AliasName(e.Name, cfg.Shortcuts) + if aliasName == "" { + continue + } + if prev, exists := seen[aliasName]; exists { + fmt.Fprintf(os.Stderr, "netssh: alias %q conflict: %s and %s — skipping %s\n", aliasName, prev, e.Name, e.Name) + continue + } + seen[aliasName] = e.Name + + switch shell { + case "fish": + lines = append(lines, fmt.Sprintf("alias %s 'netssh %s'", aliasName, e.Name)) + default: + lines = append(lines, fmt.Sprintf("alias %s='netssh %s'", aliasName, e.Name)) + } + } + + if len(lines) == 0 { + fmt.Fprintln(os.Stderr, "netssh: cache is empty — run 'netssh cache refresh' first") + return nil + } + + switch shell { + case "fish": + fmt.Printf("# netssh aliases (%d hosts) — source with: netssh alias --shell fish | source\n", len(lines)) + default: + fmt.Printf("# netssh aliases (%d hosts) — source with: eval \"$(netssh alias)\"\n", len(lines)) + } + for _, l := range lines { + fmt.Println(l) + } + return nil + }, + } + cmd.Flags().StringVar(&shell, "shell", "", "Output format: bash, zsh, fish (default: $SHELL)") + return cmd +} + +// shellInitCmd is called at shell startup to trigger a background cache refresh when stale. +// It is intentionally silent — no output on success so it never disrupts shell startup. +func shellInitCmd() *cobra.Command { + return &cobra.Command{ + Use: "shell-init", + Short: "Trigger a background cache refresh if stale (add to shell profile via 'hook install')", + Hidden: true, + RunE: func(cmd *cobra.Command, args []string) error { + cfg, err := config.Load() + if err != nil || cfg.NetBox.URL == "" { + return nil + } + c := cache.New(cfg.Cache.Path, cfg.Cache.TTL) + maybeBackgroundRefresh(cfg, c) + return nil + }, + } +} + +func hookCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "hook", + Short: "Manage shell hooks for automatic cache refresh at shell startup", + } + cmd.AddCommand(hookInstallCmd(), hookUninstallCmd()) + return cmd +} + +func hookInstallCmd() *cobra.Command { + var shell string + cmd := &cobra.Command{ + Use: "install", + Short: "Add netssh shell-init to your shell profile", + Long: `Appends a single line to your shell profile that runs 'netssh shell-init' +on every new shell session. shell-init checks whether the cache is older +than 24 hours and, if so, starts a background refresh — no delay to your prompt. + +After installation, reload your profile or open a new shell.`, + RunE: func(cmd *cobra.Command, args []string) error { + if shell == "" { + shell = filepath.Base(os.Getenv("SHELL")) + } + profile, err := hook.ProfilePath(shell) + if err != nil { + return err + } + installed, err := hook.Install(profile) + if err != nil { + return err + } + if !installed { + fmt.Printf("Hook already installed in %s\n", profile) + return nil + } + fmt.Printf("Hook installed → %s\n%s\n", profile, hook.ReloadNote(profile)) + return nil + }, + } + cmd.Flags().StringVar(&shell, "shell", "", "Shell to install for (default: $SHELL). Supported: bash, zsh, fish") + return cmd +} + +func hookUninstallCmd() *cobra.Command { + var shell string + cmd := &cobra.Command{ + Use: "uninstall", + Short: "Remove netssh shell-init from your shell profile", + RunE: func(cmd *cobra.Command, args []string) error { + if shell == "" { + shell = filepath.Base(os.Getenv("SHELL")) + } + profile, err := hook.ProfilePath(shell) + if err != nil { + return err + } + removed, err := hook.Uninstall(profile) + if err != nil { + return err + } + if !removed { + fmt.Printf("No hook found in %s\n", profile) + return nil + } + fmt.Printf("Hook removed from %s\n", profile) + return nil + }, + } + cmd.Flags().StringVar(&shell, "shell", "", "Shell to uninstall from (default: $SHELL). Supported: bash, zsh, fish") + return cmd +} + +// maybeBackgroundRefresh starts a background `netssh cache refresh` if the cache +// has not been fully refreshed within the last 24 hours. +func maybeBackgroundRefresh(cfg *config.Config, c *cache.Cache) { + if cfg.NetBox.URL == "" { + return + } + if !c.NeedsRefresh(24 * time.Hour) { + return + } + self, err := os.Executable() + if err != nil { + return + } + cmd := exec.Command(self, "cache", "refresh") + cmd.Stdout = nil + cmd.Stderr = nil + cmd.Stdin = nil + _ = cmd.Start() // fire and forget — becomes orphan after parent execs ssh +} + func fatalf(format string, args ...any) { fmt.Fprintf(os.Stderr, "netssh: "+format+"\n", args...) os.Exit(1) diff --git a/internal/cache/cache.go b/internal/cache/cache.go index d76101d..b09f5b1 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -20,10 +20,11 @@ type Entry struct { } type Cache struct { - mu sync.RWMutex - entries map[string]Entry - path string - ttl time.Duration + mu sync.RWMutex + entries map[string]Entry + path string + ttl time.Duration + refreshStamp string // path to last_refresh timestamp file } type diskFormat struct { @@ -31,11 +32,48 @@ type diskFormat struct { } func New(path string, ttlSeconds int) *Cache { - return &Cache{ - entries: make(map[string]Entry), - path: path, - ttl: time.Duration(ttlSeconds) * time.Second, + stamp := "" + if path != "" { + stamp = filepath.Join(filepath.Dir(path), "last_refresh") } + return &Cache{ + entries: make(map[string]Entry), + path: path, + ttl: time.Duration(ttlSeconds) * time.Second, + refreshStamp: stamp, + } +} + +// NeedsRefresh reports whether the last full cache refresh (via `cache refresh`) +// is older than d, or has never happened. Always returns false when no path is set. +func (c *Cache) NeedsRefresh(d time.Duration) bool { + if c.refreshStamp == "" { + return false + } + data, err := os.ReadFile(c.refreshStamp) + if err != nil { + return true // file missing → never refreshed + } + var t time.Time + if err := t.UnmarshalText(data); err != nil { + return true + } + return time.Since(t) >= d +} + +// SetRefreshed records the current time as the last successful full refresh. +func (c *Cache) SetRefreshed() error { + if c.refreshStamp == "" { + return nil + } + data, err := time.Now().MarshalText() + if err != nil { + return err + } + if err := os.MkdirAll(filepath.Dir(c.refreshStamp), 0o755); err != nil { + return err + } + return os.WriteFile(c.refreshStamp, data, 0o644) } func (c *Cache) Load() error { @@ -137,6 +175,21 @@ func (c *Cache) Search(prefix string) []Entry { return out } +// GetByShortcut scans all entries and returns the first whose normalized name matches +// the normalized shortcut. Returns (entry, found, fresh). +func (c *Cache) GetByShortcut(shortcut string, normalize func(string) string) (entry Entry, found bool, fresh bool) { + c.mu.RLock() + defer c.mu.RUnlock() + norm := normalize(shortcut) + for _, e := range c.entries { + if normalize(e.Name) == norm { + isFresh := c.ttl > 0 && time.Since(e.CachedAt) < c.ttl + return e, true, isFresh + } + } + return Entry{}, false, false +} + // Get returns an entry and reports whether it is still within the TTL. func (c *Cache) Get(name string) (entry Entry, fresh bool) { c.mu.RLock() diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go index a8fd719..c57da31 100644 --- a/internal/cache/cache_test.go +++ b/internal/cache/cache_test.go @@ -300,6 +300,164 @@ func TestMarkUsed_RoundtripViaSave(t *testing.T) { } } +// --- GetByShortcut tests --- + +func TestGetByShortcut_MatchFresh(t *testing.T) { + c := New("", 3600) + c.Upsert(Entry{Name: "web01.example.com", IP: "10.0.0.1", Kind: "device"}) + + normalize := func(s string) string { + // strip .example.com + if len(s) > len(".example.com") && s[len(s)-len(".example.com"):] == ".example.com" { + s = s[:len(s)-len(".example.com")] + } + return s + } + + e, found, fresh := c.GetByShortcut("web01", normalize) + if !found { + t.Fatal("expected entry to be found") + } + if !fresh { + t.Error("entry just inserted should be fresh") + } + if e.IP != "10.0.0.1" { + t.Errorf("IP: got %q, want %q", e.IP, "10.0.0.1") + } +} + +func TestGetByShortcut_MatchStale(t *testing.T) { + c := New("", 1) // 1 second TTL + stale := Entry{Name: "web01.example.com", IP: "10.0.0.1", Kind: "device", CachedAt: time.Now().Add(-2 * time.Second)} + c.mu.Lock() + c.entries["web01.example.com"] = stale + c.mu.Unlock() + + normalize := func(s string) string { + if len(s) > len(".example.com") && s[len(s)-len(".example.com"):] == ".example.com" { + s = s[:len(s)-len(".example.com")] + } + return s + } + + _, found, fresh := c.GetByShortcut("web01", normalize) + if !found { + t.Fatal("expected entry to be found even when stale") + } + if fresh { + t.Error("entry older than TTL should not be fresh") + } +} + +func TestGetByShortcut_NotFound(t *testing.T) { + c := New("", 3600) + c.Upsert(Entry{Name: "db01.example.com", IP: "10.0.0.2", Kind: "device"}) + + _, found, _ := c.GetByShortcut("web01", func(s string) string { return s }) + if found { + t.Error("should not find an entry that does not match the shortcut") + } +} + +func TestGetByShortcut_EmptyCache(t *testing.T) { + c := New("", 3600) + _, found, _ := c.GetByShortcut("web01", func(s string) string { return s }) + if found { + t.Error("empty cache should return found=false") + } +} + +func TestGetByShortcut_MultiDomain(t *testing.T) { + c := New("", 3600) + // Entry uses second domain (.example.de) + c.Upsert(Entry{Name: "web01.example.de", IP: "10.0.0.5", Kind: "vm"}) + + normalize := func(s string) string { + for _, suffix := range []string{".example.com", ".example.de"} { + if len(s) > len(suffix) && s[len(s)-len(suffix):] == suffix { + return s[:len(s)-len(suffix)] + } + } + return s + } + + e, found, _ := c.GetByShortcut("web01", normalize) + if !found { + t.Fatal("should match entry with second configured domain") + } + if e.IP != "10.0.0.5" { + t.Errorf("IP: got %q, want %q", e.IP, "10.0.0.5") + } +} + +// --- NeedsRefresh / SetRefreshed tests --- +// These tests require NeedsRefresh(time.Duration) bool and SetRefreshed() error +// to be implemented on *Cache. The refreshStamp field (path to the timestamp file) +// must be set before calling these methods. + +func TestNeedsRefresh_NeverRefreshed(t *testing.T) { + dir := t.TempDir() + stampPath := filepath.Join(dir, "last_refresh") + c := New(filepath.Join(dir, "cache.json"), 3600) + c.refreshStamp = stampPath + + if !c.NeedsRefresh(24 * time.Hour) { + t.Error("NeedsRefresh should return true when last_refresh file does not exist") + } +} + +func TestNeedsRefresh_JustRefreshed(t *testing.T) { + dir := t.TempDir() + stampPath := filepath.Join(dir, "last_refresh") + c := New(filepath.Join(dir, "cache.json"), 3600) + c.refreshStamp = stampPath + + if err := c.SetRefreshed(); err != nil { + t.Fatalf("SetRefreshed: %v", err) + } + if c.NeedsRefresh(24 * time.Hour) { + t.Error("NeedsRefresh should return false immediately after SetRefreshed") + } +} + +func TestNeedsRefresh_Stale(t *testing.T) { + dir := t.TempDir() + stampPath := filepath.Join(dir, "last_refresh") + // Write a timestamp 25 hours ago + oldTime := time.Now().Add(-25 * time.Hour).Format(time.RFC3339) + if err := os.WriteFile(stampPath, []byte(oldTime), 0o644); err != nil { + t.Fatalf("writing stamp: %v", err) + } + + c := New(filepath.Join(dir, "cache.json"), 3600) + c.refreshStamp = stampPath + + if !c.NeedsRefresh(24 * time.Hour) { + t.Error("NeedsRefresh should return true when last_refresh is older than duration") + } +} + +func TestSetRefreshed_Roundtrip(t *testing.T) { + dir := t.TempDir() + stampPath := filepath.Join(dir, "last_refresh") + c := New(filepath.Join(dir, "cache.json"), 3600) + c.refreshStamp = stampPath + + if err := c.SetRefreshed(); err != nil { + t.Fatalf("SetRefreshed: %v", err) + } + + // Just refreshed: 24h window → not stale + if c.NeedsRefresh(24 * time.Hour) { + t.Error("NeedsRefresh(24h) should be false right after SetRefreshed") + } + + // Zero duration: everything is stale + if !c.NeedsRefresh(0) { + t.Error("NeedsRefresh(0) should always return true") + } +} + // tempFile writes content to a temp file and returns its path. func tempFile(t *testing.T, content []byte) string { t.Helper() diff --git a/internal/config/config.go b/internal/config/config.go index 78d1998..892d91f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -10,10 +10,11 @@ import ( ) type Config struct { - NetBox NetBoxConfig `mapstructure:"netbox"` - Resolver ResolverConfig `mapstructure:"resolver"` - Cache CacheConfig `mapstructure:"cache"` - SSH SSHConfig `mapstructure:"ssh"` + NetBox NetBoxConfig `mapstructure:"netbox"` + Resolver ResolverConfig `mapstructure:"resolver"` + Cache CacheConfig `mapstructure:"cache"` + SSH SSHConfig `mapstructure:"ssh"` + Shortcuts ShortcutsConfig `mapstructure:"shortcuts"` } type NetBoxConfig struct { @@ -35,6 +36,12 @@ type CacheConfig struct { type SSHConfig struct { DefaultUser string `mapstructure:"default_user"` + DefaultPort int `mapstructure:"default_port"` +} + +type ShortcutsConfig struct { + Domains []string `mapstructure:"domains"` + StripHyphens bool `mapstructure:"strip_hyphens"` } // Path returns the canonical config file path. diff --git a/internal/hook/hook.go b/internal/hook/hook.go new file mode 100644 index 0000000..bb973cc --- /dev/null +++ b/internal/hook/hook.go @@ -0,0 +1,101 @@ +package hook + +import ( + "fmt" + "os" + "path/filepath" + "strings" +) + +const ( + // Marker is the unique string used to detect an existing hook installation. + Marker = "netssh shell-init" + // Line is the full line written into the shell profile. + Line = "netssh shell-init # netssh cache auto-refresh" +) + +// ProfilePath returns the canonical shell profile path for the given shell name. +func ProfilePath(shell string) (string, error) { + switch shell { + case "bash": + return filepath.Join(os.Getenv("HOME"), ".bashrc"), nil + case "zsh": + return filepath.Join(os.Getenv("HOME"), ".zshrc"), nil + case "fish": + configDir, err := os.UserConfigDir() + if err != nil { + configDir = filepath.Join(os.Getenv("HOME"), ".config") + } + return filepath.Join(configDir, "fish", "config.fish"), nil + default: + return "", fmt.Errorf("unsupported shell %q — supported: bash, zsh, fish", shell) + } +} + +// ReloadNote returns the shell-specific hint for reloading the profile. +func ReloadNote(profilePath string) string { + return "Reload with: source " + profilePath +} + +// IsInstalled reports whether the hook marker is present in the profile file. +// Returns false if the file does not exist or cannot be read. +func IsInstalled(profilePath string) bool { + data, err := os.ReadFile(profilePath) + if err != nil { + return false + } + return strings.Contains(string(data), Marker) +} + +// Install appends the hook line to the profile if it is not already present. +// Returns (true, nil) when newly installed, (false, nil) when already present. +// Creates the profile file (and any parent directories) if they do not exist. +func Install(profilePath string) (installed bool, err error) { + if IsInstalled(profilePath) { + return false, nil + } + if err := os.MkdirAll(filepath.Dir(profilePath), 0o755); err != nil { + return false, fmt.Errorf("creating directory: %w", err) + } + f, err := os.OpenFile(profilePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + return false, fmt.Errorf("opening profile: %w", err) + } + defer f.Close() + if _, err := fmt.Fprintf(f, "\n%s\n", Line); err != nil { + return false, err + } + return true, nil +} + +// Uninstall removes the hook line from the profile. +// Returns (true, nil) when removed, (false, nil) when the hook was not found. +// Returns (false, nil) — not an error — when the file does not exist. +func Uninstall(profilePath string) (removed bool, err error) { + data, err := os.ReadFile(profilePath) + if os.IsNotExist(err) { + return false, nil + } + if err != nil { + return false, fmt.Errorf("reading profile: %w", err) + } + content := string(data) + if !strings.Contains(content, Marker) { + return false, nil + } + + var kept []string + for _, line := range strings.Split(content, "\n") { + if strings.Contains(line, Marker) { + continue + } + kept = append(kept, line) + } + // Collapse triple blank lines that the removal may leave. + cleaned := strings.ReplaceAll(strings.Join(kept, "\n"), "\n\n\n", "\n\n") + + if err := os.WriteFile(profilePath, []byte(cleaned), 0o644); err != nil { + return false, fmt.Errorf("writing profile: %w", err) + } + return true, nil +} diff --git a/internal/hook/hook_test.go b/internal/hook/hook_test.go new file mode 100644 index 0000000..0a6bc0c --- /dev/null +++ b/internal/hook/hook_test.go @@ -0,0 +1,263 @@ +package hook_test + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/hook" +) + +// --- ProfilePath --- + +func TestProfilePath_Bash(t *testing.T) { + orig := os.Getenv("HOME") + os.Setenv("HOME", t.TempDir()) + defer os.Setenv("HOME", orig) + + p, err := hook.ProfilePath("bash") + if err != nil { + t.Fatalf("ProfilePath(bash): %v", err) + } + if !strings.HasSuffix(p, ".bashrc") { + t.Errorf("bash profile should end with .bashrc, got %q", p) + } +} + +func TestProfilePath_Zsh(t *testing.T) { + orig := os.Getenv("HOME") + os.Setenv("HOME", t.TempDir()) + defer os.Setenv("HOME", orig) + + p, err := hook.ProfilePath("zsh") + if err != nil { + t.Fatalf("ProfilePath(zsh): %v", err) + } + if !strings.HasSuffix(p, ".zshrc") { + t.Errorf("zsh profile should end with .zshrc, got %q", p) + } +} + +func TestProfilePath_Fish(t *testing.T) { + dir := t.TempDir() + orig := os.Getenv("XDG_CONFIG_HOME") + os.Setenv("XDG_CONFIG_HOME", dir) + defer os.Setenv("XDG_CONFIG_HOME", orig) + + p, err := hook.ProfilePath("fish") + if err != nil { + t.Fatalf("ProfilePath(fish): %v", err) + } + if !strings.HasSuffix(p, "config.fish") { + t.Errorf("fish profile should end with config.fish, got %q", p) + } +} + +func TestProfilePath_UnknownShell(t *testing.T) { + if _, err := hook.ProfilePath("ksh"); err == nil { + t.Error("unknown shell should return an error") + } +} + +// --- IsInstalled --- + +func TestIsInstalled_Present(t *testing.T) { + profile := writeProfile(t, hook.Line+"\n") + if !hook.IsInstalled(profile) { + t.Error("IsInstalled should return true when marker is present") + } +} + +func TestIsInstalled_Absent(t *testing.T) { + profile := writeProfile(t, "export PATH=$PATH\n") + if hook.IsInstalled(profile) { + t.Error("IsInstalled should return false when marker is absent") + } +} + +func TestIsInstalled_MissingFile(t *testing.T) { + if hook.IsInstalled("/nonexistent/path/.bashrc") { + t.Error("IsInstalled should return false for missing file") + } +} + +// --- Install --- + +func TestInstall_Fresh(t *testing.T) { + profile := filepath.Join(t.TempDir(), ".bashrc") + + installed, err := hook.Install(profile) + if err != nil { + t.Fatalf("Install: %v", err) + } + if !installed { + t.Error("should report installed=true on fresh install") + } + if !hook.IsInstalled(profile) { + t.Error("profile should contain the hook line after Install") + } +} + +func TestInstall_Idempotent(t *testing.T) { + profile := filepath.Join(t.TempDir(), ".bashrc") + + hook.Install(profile) + installed, err := hook.Install(profile) + if err != nil { + t.Fatalf("second Install: %v", err) + } + if installed { + t.Error("second Install should report installed=false") + } + + data, _ := os.ReadFile(profile) + if count := strings.Count(string(data), hook.Marker); count != 1 { + t.Errorf("hook line should appear exactly once, got %d", count) + } +} + +func TestInstall_CreatesProfileAndParentDirs(t *testing.T) { + profile := filepath.Join(t.TempDir(), "nested", "dir", ".zshrc") + + if _, err := hook.Install(profile); err != nil { + t.Fatalf("Install should create missing directories: %v", err) + } + if _, err := os.Stat(profile); err != nil { + t.Error("profile file should have been created") + } +} + +func TestInstall_PreservesExistingContent(t *testing.T) { + profile := writeProfile(t, "export PATH=$PATH:/usr/local/bin\n") + + hook.Install(profile) + + data, _ := os.ReadFile(profile) + content := string(data) + if !strings.Contains(content, "export PATH") { + t.Error("Install should not remove existing content") + } + if !strings.Contains(content, hook.Marker) { + t.Error("hook line should be appended") + } +} + +func TestInstall_AppendedAtEnd(t *testing.T) { + profile := writeProfile(t, "existing line\n") + hook.Install(profile) + + data, _ := os.ReadFile(profile) + lines := strings.Split(strings.TrimRight(string(data), "\n"), "\n") + last := lines[len(lines)-1] + if last != hook.Line { + t.Errorf("hook line should be the last line, got %q", last) + } +} + +// --- Uninstall --- + +func TestUninstall_Removes(t *testing.T) { + profile := writeProfile(t, "export PATH=$PATH\n"+hook.Line+"\nexport FOO=bar\n") + + removed, err := hook.Uninstall(profile) + if err != nil { + t.Fatalf("Uninstall: %v", err) + } + if !removed { + t.Error("should report removed=true") + } + if hook.IsInstalled(profile) { + t.Error("hook line should have been removed") + } +} + +func TestUninstall_PreservesOtherContent(t *testing.T) { + profile := writeProfile(t, "export FOO=bar\n"+hook.Line+"\nexport BAZ=qux\n") + + hook.Uninstall(profile) + + data, _ := os.ReadFile(profile) + content := string(data) + if !strings.Contains(content, "export FOO=bar") { + t.Error("content before hook should be preserved") + } + if !strings.Contains(content, "export BAZ=qux") { + t.Error("content after hook should be preserved") + } +} + +func TestUninstall_NotPresent(t *testing.T) { + profile := writeProfile(t, "export PATH=$PATH\n") + + removed, err := hook.Uninstall(profile) + if err != nil { + t.Fatalf("Uninstall: %v", err) + } + if removed { + t.Error("should report removed=false when hook was not present") + } +} + +func TestUninstall_MissingFile(t *testing.T) { + removed, err := hook.Uninstall("/nonexistent/.bashrc") + if err != nil { + t.Errorf("Uninstall on missing file should not error: %v", err) + } + if removed { + t.Error("should report removed=false for missing file") + } +} + +func TestUninstall_CollapsesExtraBlankLines(t *testing.T) { + // blank line before hook + hook + blank line after = three consecutive newlines after removal + profile := writeProfile(t, "line1\n\n"+hook.Line+"\n\nline2\n") + + hook.Uninstall(profile) + + data, _ := os.ReadFile(profile) + if strings.Contains(string(data), "\n\n\n") { + t.Error("three consecutive blank lines should be collapsed after uninstall") + } +} + +// --- Roundtrip --- + +func TestInstallUninstall_Roundtrip(t *testing.T) { + profile := writeProfile(t, "existing content\n") + + hook.Install(profile) + if !hook.IsInstalled(profile) { + t.Fatal("should be installed after Install") + } + + hook.Uninstall(profile) + if hook.IsInstalled(profile) { + t.Fatal("should not be installed after Uninstall") + } + + data, _ := os.ReadFile(profile) + if !strings.Contains(string(data), "existing content") { + t.Error("original content should survive install+uninstall cycle") + } +} + +// --- ReloadNote --- + +func TestReloadNote_ContainsPath(t *testing.T) { + note := hook.ReloadNote("/home/user/.bashrc") + if !strings.Contains(note, "/home/user/.bashrc") { + t.Errorf("ReloadNote should contain the profile path, got %q", note) + } +} + +// --- helpers --- + +func writeProfile(t *testing.T, content string) string { + t.Helper() + f := filepath.Join(t.TempDir(), "profile") + if err := os.WriteFile(f, []byte(content), 0o644); err != nil { + t.Fatal(err) + } + return f +} diff --git a/internal/setup/wizard.go b/internal/setup/wizard.go index 9cc7403..f0e586b 100644 --- a/internal/setup/wizard.go +++ b/internal/setup/wizard.go @@ -20,10 +20,16 @@ func RunWizard(cfg *config.Config) error { url := cfg.NetBox.URL token := cfg.NetBox.Token defaultUser := cfg.SSH.DefaultUser + defaultPort := "" + if cfg.SSH.DefaultPort > 0 { + defaultPort = strconv.Itoa(cfg.SSH.DefaultPort) + } strategiesRaw := strings.Join(cfg.Resolver.Strategies, ", ") subnets := strings.Join(cfg.Resolver.ManagementSubnets, ", ") interfaceName := cfg.Resolver.InterfaceName cacheTTL := strconv.Itoa(cfg.Cache.TTL) + shortcutDomains := strings.Join(cfg.Shortcuts.Domains, ", ") + stripHyphens := cfg.Shortcuts.StripHyphens if strategiesRaw == "" { strategiesRaw = "primary_ip" @@ -62,6 +68,21 @@ func RunWizard(cfg *config.Config) error { Title("Default SSH user"). Description("Leave empty to use your system user ($USER)."). Value(&defaultUser), + huh.NewInput(). + Title("Default SSH port"). + Description("Leave empty to use the standard port (22)."). + Placeholder("22"). + Value(&defaultPort). + Validate(func(s string) error { + if strings.TrimSpace(s) == "" { + return nil + } + n, err := strconv.Atoi(strings.TrimSpace(s)) + if err != nil || n < 1 || n > 65535 { + return errors.New("must be a port number between 1 and 65535") + } + return nil + }), ).Title("SSH defaults"), huh.NewGroup( @@ -90,6 +111,18 @@ func RunWizard(cfg *config.Config) error { return nil }), ).Title("Resolver & cache"), + + huh.NewGroup( + huh.NewInput(). + Title("Domain suffixes"). + Description("Comma-separated suffixes stripped for shortcuts, e.g. .example.com, .example.de\nAllows typing 'web01' instead of 'web01.example.com'."). + Placeholder(".example.com"). + Value(&shortcutDomains), + huh.NewConfirm(). + Title("Strip hyphens"). + Description("When enabled, fsn1-web01.example.com can be accessed as fsn1web01.\nOnly works for hosts already in the cache."). + Value(&stripHyphens), + ).Title("Shortcuts"), ) if err := form.Run(); err != nil { @@ -109,6 +142,11 @@ func RunWizard(cfg *config.Config) error { ttl, _ := strconv.Atoi(cacheTTL) + port := 0 + if p, err := strconv.Atoi(strings.TrimSpace(defaultPort)); err == nil { + port = p + } + var subnetList []string for _, s := range strings.Split(subnets, ",") { if s = strings.TrimSpace(s); s != "" { @@ -116,6 +154,16 @@ func RunWizard(cfg *config.Config) error { } } + var domainList []string + for _, s := range strings.Split(shortcutDomains, ",") { + if s = strings.TrimSpace(s); s != "" { + if !strings.HasPrefix(s, ".") { + s = "." + s + } + domainList = append(domainList, s) + } + } + out := config.Config{ NetBox: config.NetBoxConfig{ URL: strings.TrimRight(strings.TrimSpace(url), "/"), @@ -124,6 +172,7 @@ func RunWizard(cfg *config.Config) error { }, SSH: config.SSHConfig{ DefaultUser: strings.TrimSpace(defaultUser), + DefaultPort: port, }, Resolver: config.ResolverConfig{ Strategies: strategies, @@ -133,6 +182,10 @@ func RunWizard(cfg *config.Config) error { Cache: config.CacheConfig{ TTL: ttl, }, + Shortcuts: config.ShortcutsConfig{ + Domains: domainList, + StripHyphens: stripHyphens, + }, } return save(out) @@ -200,6 +253,22 @@ func save(cfg config.Config) error { if cfg.SSH.DefaultUser != "" { fmt.Fprintf(&b, " default_user: %q\n", cfg.SSH.DefaultUser) } + if cfg.SSH.DefaultPort > 0 { + fmt.Fprintf(&b, " default_port: %d\n", cfg.SSH.DefaultPort) + } + + if len(cfg.Shortcuts.Domains) > 0 || cfg.Shortcuts.StripHyphens { + b.WriteString("\nshortcuts:\n") + if len(cfg.Shortcuts.Domains) > 0 { + b.WriteString(" domains:\n") + for _, d := range cfg.Shortcuts.Domains { + fmt.Fprintf(&b, " - %s\n", d) + } + } + if cfg.Shortcuts.StripHyphens { + b.WriteString(" strip_hyphens: true\n") + } + } if err := os.WriteFile(path, []byte(b.String()), 0o600); err != nil { return fmt.Errorf("writing config: %w", err) diff --git a/internal/setup/wizard_test.go b/internal/setup/wizard_test.go index 4351c80..1dc7602 100644 --- a/internal/setup/wizard_test.go +++ b/internal/setup/wizard_test.go @@ -116,6 +116,148 @@ func TestSave_CreatesConfigDir(t *testing.T) { } } +func TestSave_DefaultPort(t *testing.T) { + dir := t.TempDir() + orig := os.Getenv("XDG_CONFIG_HOME") + os.Setenv("XDG_CONFIG_HOME", dir) + defer os.Setenv("XDG_CONFIG_HOME", orig) + + cfg := config.Config{ + NetBox: config.NetBoxConfig{URL: "http://x", Token: "t", TokenVersion: 1}, + Cache: config.CacheConfig{TTL: 60}, + SSH: config.SSHConfig{DefaultPort: 2222}, + } + if err := save(cfg); err != nil { + t.Fatalf("save: %v", err) + } + + data, _ := os.ReadFile(filepath.Join(dir, "netssh.yaml")) + if !strings.Contains(string(data), "default_port: 2222") { + t.Errorf("expected default_port: 2222 in config, got:\n%s", string(data)) + } +} + +func TestSave_Shortcuts_WritesSection(t *testing.T) { + dir := t.TempDir() + orig := os.Getenv("XDG_CONFIG_HOME") + os.Setenv("XDG_CONFIG_HOME", dir) + defer os.Setenv("XDG_CONFIG_HOME", orig) + + cfg := config.Config{ + NetBox: config.NetBoxConfig{URL: "http://x", Token: "t", TokenVersion: 1}, + Cache: config.CacheConfig{TTL: 60}, + Shortcuts: config.ShortcutsConfig{ + Domains: []string{".example.com"}, + StripHyphens: true, + }, + } + if err := save(cfg); err != nil { + t.Fatalf("save: %v", err) + } + + data, _ := os.ReadFile(filepath.Join(dir, "netssh.yaml")) + content := string(data) + for _, want := range []string{"shortcuts:", "domains:", ".example.com", "strip_hyphens: true"} { + if !strings.Contains(content, want) { + t.Errorf("expected %q in config, got:\n%s", want, content) + } + } +} + +func TestSave_OmitsDefaultPortWhenZero(t *testing.T) { + dir := t.TempDir() + orig := os.Getenv("XDG_CONFIG_HOME") + os.Setenv("XDG_CONFIG_HOME", dir) + defer os.Setenv("XDG_CONFIG_HOME", orig) + + cfg := config.Config{ + NetBox: config.NetBoxConfig{URL: "http://x", Token: "t", TokenVersion: 1}, + Cache: config.CacheConfig{TTL: 60}, + SSH: config.SSHConfig{DefaultPort: 0}, + } + if err := save(cfg); err != nil { + t.Fatalf("save: %v", err) + } + + data, _ := os.ReadFile(filepath.Join(dir, "netssh.yaml")) + if strings.Contains(string(data), "default_port") { + t.Errorf("default_port should be omitted when zero, got:\n%s", string(data)) + } +} + +func TestSave_OmitsShortcutsWhenEmpty(t *testing.T) { + dir := t.TempDir() + orig := os.Getenv("XDG_CONFIG_HOME") + os.Setenv("XDG_CONFIG_HOME", dir) + defer os.Setenv("XDG_CONFIG_HOME", orig) + + cfg := config.Config{ + NetBox: config.NetBoxConfig{URL: "http://x", Token: "t", TokenVersion: 1}, + Cache: config.CacheConfig{TTL: 60}, + Shortcuts: config.ShortcutsConfig{}, // empty + } + if err := save(cfg); err != nil { + t.Fatalf("save: %v", err) + } + + data, _ := os.ReadFile(filepath.Join(dir, "netssh.yaml")) + if strings.Contains(string(data), "shortcuts:") { + t.Errorf("shortcuts section should be omitted when empty, got:\n%s", string(data)) + } +} + +func TestSave_Roundtrip_WithNewFields(t *testing.T) { + dir := t.TempDir() + orig := os.Getenv("XDG_CONFIG_HOME") + os.Setenv("XDG_CONFIG_HOME", dir) + defer os.Setenv("XDG_CONFIG_HOME", orig) + + original := config.Config{ + NetBox: config.NetBoxConfig{ + URL: "https://netbox.example.com", + Token: "nbt_test", + TokenVersion: 2, + }, + SSH: config.SSHConfig{ + DefaultUser: "admin", + DefaultPort: 2222, + }, + Resolver: config.ResolverConfig{ + Strategies: []string{"primary_ip"}, + }, + Cache: config.CacheConfig{TTL: 3600}, + Shortcuts: config.ShortcutsConfig{ + Domains: []string{".example.com", ".example.de"}, + StripHyphens: true, + }, + } + + if err := save(original); err != nil { + t.Fatalf("save: %v", err) + } + + loaded, err := config.Load() + if err != nil { + t.Fatalf("Load after save: %v", err) + } + + if loaded.SSH.DefaultPort != original.SSH.DefaultPort { + t.Errorf("DefaultPort: got %d, want %d", loaded.SSH.DefaultPort, original.SSH.DefaultPort) + } + if len(loaded.Shortcuts.Domains) != len(original.Shortcuts.Domains) { + t.Errorf("Shortcuts.Domains length: got %d, want %d", len(loaded.Shortcuts.Domains), len(original.Shortcuts.Domains)) + } else { + for i, d := range original.Shortcuts.Domains { + if loaded.Shortcuts.Domains[i] != d { + t.Errorf("Shortcuts.Domains[%d]: got %q, want %q", i, loaded.Shortcuts.Domains[i], d) + } + } + } + if loaded.Shortcuts.StripHyphens != original.Shortcuts.StripHyphens { + t.Errorf("StripHyphens: got %v, want %v", loaded.Shortcuts.StripHyphens, original.Shortcuts.StripHyphens) + } +} + func TestParseStrategies(t *testing.T) { tests := []struct { in string diff --git a/internal/shortcuts/shortcuts.go b/internal/shortcuts/shortcuts.go new file mode 100644 index 0000000..e357ccb --- /dev/null +++ b/internal/shortcuts/shortcuts.go @@ -0,0 +1,42 @@ +package shortcuts + +import ( + "strings" + + "git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/config" +) + +// Normalize strips configured domain suffixes and optionally hyphens from a hostname. +// The result is lowercased for case-insensitive comparison. +func Normalize(name string, cfg config.ShortcutsConfig) string { + s := strings.ToLower(name) + for _, domain := range cfg.Domains { + suffix := strings.ToLower(domain) + if !strings.HasPrefix(suffix, ".") { + suffix = "." + suffix + } + if strings.HasSuffix(s, suffix) { + s = s[:len(s)-len(suffix)] + break + } + } + if cfg.StripHyphens { + s = strings.ReplaceAll(s, "-", "") + } + return s +} + +// MakeNormalizer returns a closure bound to cfg, suitable for Cache.GetByShortcut. +func MakeNormalizer(cfg config.ShortcutsConfig) func(string) string { + return func(name string) string { + return Normalize(name, cfg) + } +} + +// AliasName returns a shell-safe alias name for the given host. +// It normalizes the name (strips configured domains and optionally hyphens) then +// replaces any remaining dots with underscores so the result is a valid identifier. +func AliasName(name string, cfg config.ShortcutsConfig) string { + s := Normalize(name, cfg) + return strings.ReplaceAll(s, ".", "_") +} diff --git a/internal/shortcuts/shortcuts_test.go b/internal/shortcuts/shortcuts_test.go new file mode 100644 index 0000000..2ae5aa4 --- /dev/null +++ b/internal/shortcuts/shortcuts_test.go @@ -0,0 +1,80 @@ +package shortcuts_test + +import ( + "testing" + + "git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/config" + "git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/shortcuts" +) + +func TestNormalize(t *testing.T) { + cfg := config.ShortcutsConfig{ + Domains: []string{".example.com", ".example.de"}, + StripHyphens: true, + } + + cases := []struct { + input string + want string + }{ + {"fsn1-web01.example.com", "fsn1web01"}, + {"fsn1-web01.example.de", "fsn1web01"}, + {"FSN1-WEB01.EXAMPLE.COM", "fsn1web01"}, + {"fsn1-web01", "fsn1web01"}, + {"web01.example.com", "web01"}, + {"web01.other.com", "web01.other.com"}, // unknown domain not stripped + } + + for _, c := range cases { + got := shortcuts.Normalize(c.input, cfg) + if got != c.want { + t.Errorf("Normalize(%q) = %q, want %q", c.input, got, c.want) + } + } +} + +func TestNormalizeNoHyphenStrip(t *testing.T) { + cfg := config.ShortcutsConfig{ + Domains: []string{".example.com"}, + StripHyphens: false, + } + + got := shortcuts.Normalize("fsn1-web01.example.com", cfg) + want := "fsn1-web01" + if got != want { + t.Errorf("Normalize = %q, want %q", got, want) + } +} + +func TestAliasName_WithDomainAndHyphens(t *testing.T) { + cfg := config.ShortcutsConfig{ + Domains: []string{".example.com"}, + StripHyphens: true, + } + got := shortcuts.AliasName("fsn1-web01.example.com", cfg) + want := "fsn1web01" + if got != want { + t.Errorf("AliasName = %q, want %q", got, want) + } +} + +func TestAliasName_DotReplacement(t *testing.T) { + cfg := config.ShortcutsConfig{ + Domains: []string{".example.com"}, + } + // "web01.other.com" — domain not stripped, remaining dots → underscores + got := shortcuts.AliasName("web01.other.com", cfg) + want := "web01_other_com" + if got != want { + t.Errorf("AliasName = %q, want %q", got, want) + } +} + +func TestAliasName_NoConfig(t *testing.T) { + cfg := config.ShortcutsConfig{} + got := shortcuts.AliasName("web01.example.com", cfg) + want := "web01_example_com" + if got != want { + t.Errorf("AliasName = %q, want %q", got, want) + } +} diff --git a/internal/ssh/args.go b/internal/ssh/args.go index d5e834a..63939be 100644 --- a/internal/ssh/args.go +++ b/internal/ssh/args.go @@ -86,6 +86,19 @@ func ReplaceHost(args []string, destIdx int, newHost string) []string { return result } +// HasPortFlag reports whether a port was specified via -p in args. +func HasPortFlag(args []string) bool { + for i, a := range args { + if a == "-p" && i+1 < len(args) { + return true + } + if len(a) > 2 && a[0] == '-' && a[1] == 'p' { + return true + } + } + return false +} + // HasUserFlag reports whether a user was specified via -l in args. // Used to avoid overriding an explicit -l with the configured default user. func HasUserFlag(args []string) bool { diff --git a/internal/ssh/args_test.go b/internal/ssh/args_test.go index 4dd8611..afd8b78 100644 --- a/internal/ssh/args_test.go +++ b/internal/ssh/args_test.go @@ -144,6 +144,37 @@ func TestHasUserFlag_LFlagAtEnd(t *testing.T) { } } +func TestHasPortFlag_FlagSeparated(t *testing.T) { + if !HasPortFlag([]string{"-p", "2222", "host"}) { + t.Error("should detect -p ") + } +} + +func TestHasPortFlag_FlagAttached(t *testing.T) { + if !HasPortFlag([]string{"-p2222", "host"}) { + t.Error("should detect -p (attached form)") + } +} + +func TestHasPortFlag_NotPresent(t *testing.T) { + if HasPortFlag([]string{"-l", "admin", "host"}) { + t.Error("should not detect port flag when absent") + } +} + +func TestHasPortFlag_EmptyArgs(t *testing.T) { + if HasPortFlag([]string{}) { + t.Error("empty args should return false") + } +} + +func TestHasPortFlag_PAtEnd(t *testing.T) { + // -p at the very end with no value — should return false + if HasPortFlag([]string{"-p"}) { + t.Error("-p with no value should return false") + } +} + func assertParsed(t *testing.T, got *ParsedArgs, host, user string, destIdx int) { t.Helper() if got == nil {