feat: introduce shortcuts and shell hook support
Release / release (push) Successful in 50s

- **Shortcuts**: Add hostname normalization with domain stripping and hyphen folding. Include alias generation for cached hosts.
- **Shell Hook**: Automate 24h cache refresh trigger with shell startup hook. Add install/uninstall commands for bash, zsh, and fish.
- **Wizard**: Extend setup wizard to configure shortcuts (domains, hyphen stripping) and default SSH port.
- **Cache**: Add `GetByShortcut` for resolving hosts via normalized shortcuts. Implement `NeedsRefresh` / `SetRefreshed` logic for refresh timestamps.
- **Tests**: Comprehensive unit tests for shortcuts, hook installation, cache refresh, and alias generation.
- **Docs**: Update README with shortcuts, shell hook, and default SSH port configuration.
This commit is contained in:
Sebastian Unterschütz
2026-05-27 22:53:24 +02:00
parent d127a3b957
commit 7c902cab3a
13 changed files with 1378 additions and 19 deletions
+61 -8
View File
@@ -20,10 +20,11 @@ type Entry struct {
}
type Cache struct {
mu sync.RWMutex
entries map[string]Entry
path string
ttl time.Duration
mu sync.RWMutex
entries map[string]Entry
path string
ttl time.Duration
refreshStamp string // path to last_refresh timestamp file
}
type diskFormat struct {
@@ -31,11 +32,48 @@ type diskFormat struct {
}
func New(path string, ttlSeconds int) *Cache {
return &Cache{
entries: make(map[string]Entry),
path: path,
ttl: time.Duration(ttlSeconds) * time.Second,
stamp := ""
if path != "" {
stamp = filepath.Join(filepath.Dir(path), "last_refresh")
}
return &Cache{
entries: make(map[string]Entry),
path: path,
ttl: time.Duration(ttlSeconds) * time.Second,
refreshStamp: stamp,
}
}
// NeedsRefresh reports whether the last full cache refresh (via `cache refresh`)
// is older than d, or has never happened. Always returns false when no path is set.
func (c *Cache) NeedsRefresh(d time.Duration) bool {
if c.refreshStamp == "" {
return false
}
data, err := os.ReadFile(c.refreshStamp)
if err != nil {
return true // file missing → never refreshed
}
var t time.Time
if err := t.UnmarshalText(data); err != nil {
return true
}
return time.Since(t) >= d
}
// SetRefreshed records the current time as the last successful full refresh.
func (c *Cache) SetRefreshed() error {
if c.refreshStamp == "" {
return nil
}
data, err := time.Now().MarshalText()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(c.refreshStamp), 0o755); err != nil {
return err
}
return os.WriteFile(c.refreshStamp, data, 0o644)
}
func (c *Cache) Load() error {
@@ -137,6 +175,21 @@ func (c *Cache) Search(prefix string) []Entry {
return out
}
// GetByShortcut scans all entries and returns the first whose normalized name matches
// the normalized shortcut. Returns (entry, found, fresh).
func (c *Cache) GetByShortcut(shortcut string, normalize func(string) string) (entry Entry, found bool, fresh bool) {
c.mu.RLock()
defer c.mu.RUnlock()
norm := normalize(shortcut)
for _, e := range c.entries {
if normalize(e.Name) == norm {
isFresh := c.ttl > 0 && time.Since(e.CachedAt) < c.ttl
return e, true, isFresh
}
}
return Entry{}, false, false
}
// Get returns an entry and reports whether it is still within the TTL.
func (c *Cache) Get(name string) (entry Entry, fresh bool) {
c.mu.RLock()
+158
View File
@@ -300,6 +300,164 @@ func TestMarkUsed_RoundtripViaSave(t *testing.T) {
}
}
// --- GetByShortcut tests ---
func TestGetByShortcut_MatchFresh(t *testing.T) {
c := New("", 3600)
c.Upsert(Entry{Name: "web01.example.com", IP: "10.0.0.1", Kind: "device"})
normalize := func(s string) string {
// strip .example.com
if len(s) > len(".example.com") && s[len(s)-len(".example.com"):] == ".example.com" {
s = s[:len(s)-len(".example.com")]
}
return s
}
e, found, fresh := c.GetByShortcut("web01", normalize)
if !found {
t.Fatal("expected entry to be found")
}
if !fresh {
t.Error("entry just inserted should be fresh")
}
if e.IP != "10.0.0.1" {
t.Errorf("IP: got %q, want %q", e.IP, "10.0.0.1")
}
}
func TestGetByShortcut_MatchStale(t *testing.T) {
c := New("", 1) // 1 second TTL
stale := Entry{Name: "web01.example.com", IP: "10.0.0.1", Kind: "device", CachedAt: time.Now().Add(-2 * time.Second)}
c.mu.Lock()
c.entries["web01.example.com"] = stale
c.mu.Unlock()
normalize := func(s string) string {
if len(s) > len(".example.com") && s[len(s)-len(".example.com"):] == ".example.com" {
s = s[:len(s)-len(".example.com")]
}
return s
}
_, found, fresh := c.GetByShortcut("web01", normalize)
if !found {
t.Fatal("expected entry to be found even when stale")
}
if fresh {
t.Error("entry older than TTL should not be fresh")
}
}
func TestGetByShortcut_NotFound(t *testing.T) {
c := New("", 3600)
c.Upsert(Entry{Name: "db01.example.com", IP: "10.0.0.2", Kind: "device"})
_, found, _ := c.GetByShortcut("web01", func(s string) string { return s })
if found {
t.Error("should not find an entry that does not match the shortcut")
}
}
func TestGetByShortcut_EmptyCache(t *testing.T) {
c := New("", 3600)
_, found, _ := c.GetByShortcut("web01", func(s string) string { return s })
if found {
t.Error("empty cache should return found=false")
}
}
func TestGetByShortcut_MultiDomain(t *testing.T) {
c := New("", 3600)
// Entry uses second domain (.example.de)
c.Upsert(Entry{Name: "web01.example.de", IP: "10.0.0.5", Kind: "vm"})
normalize := func(s string) string {
for _, suffix := range []string{".example.com", ".example.de"} {
if len(s) > len(suffix) && s[len(s)-len(suffix):] == suffix {
return s[:len(s)-len(suffix)]
}
}
return s
}
e, found, _ := c.GetByShortcut("web01", normalize)
if !found {
t.Fatal("should match entry with second configured domain")
}
if e.IP != "10.0.0.5" {
t.Errorf("IP: got %q, want %q", e.IP, "10.0.0.5")
}
}
// --- NeedsRefresh / SetRefreshed tests ---
// These tests require NeedsRefresh(time.Duration) bool and SetRefreshed() error
// to be implemented on *Cache. The refreshStamp field (path to the timestamp file)
// must be set before calling these methods.
func TestNeedsRefresh_NeverRefreshed(t *testing.T) {
dir := t.TempDir()
stampPath := filepath.Join(dir, "last_refresh")
c := New(filepath.Join(dir, "cache.json"), 3600)
c.refreshStamp = stampPath
if !c.NeedsRefresh(24 * time.Hour) {
t.Error("NeedsRefresh should return true when last_refresh file does not exist")
}
}
func TestNeedsRefresh_JustRefreshed(t *testing.T) {
dir := t.TempDir()
stampPath := filepath.Join(dir, "last_refresh")
c := New(filepath.Join(dir, "cache.json"), 3600)
c.refreshStamp = stampPath
if err := c.SetRefreshed(); err != nil {
t.Fatalf("SetRefreshed: %v", err)
}
if c.NeedsRefresh(24 * time.Hour) {
t.Error("NeedsRefresh should return false immediately after SetRefreshed")
}
}
func TestNeedsRefresh_Stale(t *testing.T) {
dir := t.TempDir()
stampPath := filepath.Join(dir, "last_refresh")
// Write a timestamp 25 hours ago
oldTime := time.Now().Add(-25 * time.Hour).Format(time.RFC3339)
if err := os.WriteFile(stampPath, []byte(oldTime), 0o644); err != nil {
t.Fatalf("writing stamp: %v", err)
}
c := New(filepath.Join(dir, "cache.json"), 3600)
c.refreshStamp = stampPath
if !c.NeedsRefresh(24 * time.Hour) {
t.Error("NeedsRefresh should return true when last_refresh is older than duration")
}
}
func TestSetRefreshed_Roundtrip(t *testing.T) {
dir := t.TempDir()
stampPath := filepath.Join(dir, "last_refresh")
c := New(filepath.Join(dir, "cache.json"), 3600)
c.refreshStamp = stampPath
if err := c.SetRefreshed(); err != nil {
t.Fatalf("SetRefreshed: %v", err)
}
// Just refreshed: 24h window → not stale
if c.NeedsRefresh(24 * time.Hour) {
t.Error("NeedsRefresh(24h) should be false right after SetRefreshed")
}
// Zero duration: everything is stale
if !c.NeedsRefresh(0) {
t.Error("NeedsRefresh(0) should always return true")
}
}
// tempFile writes content to a temp file and returns its path.
func tempFile(t *testing.T, content []byte) string {
t.Helper()
+11 -4
View File
@@ -10,10 +10,11 @@ import (
)
type Config struct {
NetBox NetBoxConfig `mapstructure:"netbox"`
Resolver ResolverConfig `mapstructure:"resolver"`
Cache CacheConfig `mapstructure:"cache"`
SSH SSHConfig `mapstructure:"ssh"`
NetBox NetBoxConfig `mapstructure:"netbox"`
Resolver ResolverConfig `mapstructure:"resolver"`
Cache CacheConfig `mapstructure:"cache"`
SSH SSHConfig `mapstructure:"ssh"`
Shortcuts ShortcutsConfig `mapstructure:"shortcuts"`
}
type NetBoxConfig struct {
@@ -35,6 +36,12 @@ type CacheConfig struct {
type SSHConfig struct {
DefaultUser string `mapstructure:"default_user"`
DefaultPort int `mapstructure:"default_port"`
}
type ShortcutsConfig struct {
Domains []string `mapstructure:"domains"`
StripHyphens bool `mapstructure:"strip_hyphens"`
}
// Path returns the canonical config file path.
+101
View File
@@ -0,0 +1,101 @@
package hook
import (
"fmt"
"os"
"path/filepath"
"strings"
)
const (
// Marker is the unique string used to detect an existing hook installation.
Marker = "netssh shell-init"
// Line is the full line written into the shell profile.
Line = "netssh shell-init # netssh cache auto-refresh"
)
// ProfilePath returns the canonical shell profile path for the given shell name.
func ProfilePath(shell string) (string, error) {
switch shell {
case "bash":
return filepath.Join(os.Getenv("HOME"), ".bashrc"), nil
case "zsh":
return filepath.Join(os.Getenv("HOME"), ".zshrc"), nil
case "fish":
configDir, err := os.UserConfigDir()
if err != nil {
configDir = filepath.Join(os.Getenv("HOME"), ".config")
}
return filepath.Join(configDir, "fish", "config.fish"), nil
default:
return "", fmt.Errorf("unsupported shell %q — supported: bash, zsh, fish", shell)
}
}
// ReloadNote returns the shell-specific hint for reloading the profile.
func ReloadNote(profilePath string) string {
return "Reload with: source " + profilePath
}
// IsInstalled reports whether the hook marker is present in the profile file.
// Returns false if the file does not exist or cannot be read.
func IsInstalled(profilePath string) bool {
data, err := os.ReadFile(profilePath)
if err != nil {
return false
}
return strings.Contains(string(data), Marker)
}
// Install appends the hook line to the profile if it is not already present.
// Returns (true, nil) when newly installed, (false, nil) when already present.
// Creates the profile file (and any parent directories) if they do not exist.
func Install(profilePath string) (installed bool, err error) {
if IsInstalled(profilePath) {
return false, nil
}
if err := os.MkdirAll(filepath.Dir(profilePath), 0o755); err != nil {
return false, fmt.Errorf("creating directory: %w", err)
}
f, err := os.OpenFile(profilePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
return false, fmt.Errorf("opening profile: %w", err)
}
defer f.Close()
if _, err := fmt.Fprintf(f, "\n%s\n", Line); err != nil {
return false, err
}
return true, nil
}
// Uninstall removes the hook line from the profile.
// Returns (true, nil) when removed, (false, nil) when the hook was not found.
// Returns (false, nil) — not an error — when the file does not exist.
func Uninstall(profilePath string) (removed bool, err error) {
data, err := os.ReadFile(profilePath)
if os.IsNotExist(err) {
return false, nil
}
if err != nil {
return false, fmt.Errorf("reading profile: %w", err)
}
content := string(data)
if !strings.Contains(content, Marker) {
return false, nil
}
var kept []string
for _, line := range strings.Split(content, "\n") {
if strings.Contains(line, Marker) {
continue
}
kept = append(kept, line)
}
// Collapse triple blank lines that the removal may leave.
cleaned := strings.ReplaceAll(strings.Join(kept, "\n"), "\n\n\n", "\n\n")
if err := os.WriteFile(profilePath, []byte(cleaned), 0o644); err != nil {
return false, fmt.Errorf("writing profile: %w", err)
}
return true, nil
}
+263
View File
@@ -0,0 +1,263 @@
package hook_test
import (
"os"
"path/filepath"
"strings"
"testing"
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/hook"
)
// --- ProfilePath ---
func TestProfilePath_Bash(t *testing.T) {
orig := os.Getenv("HOME")
os.Setenv("HOME", t.TempDir())
defer os.Setenv("HOME", orig)
p, err := hook.ProfilePath("bash")
if err != nil {
t.Fatalf("ProfilePath(bash): %v", err)
}
if !strings.HasSuffix(p, ".bashrc") {
t.Errorf("bash profile should end with .bashrc, got %q", p)
}
}
func TestProfilePath_Zsh(t *testing.T) {
orig := os.Getenv("HOME")
os.Setenv("HOME", t.TempDir())
defer os.Setenv("HOME", orig)
p, err := hook.ProfilePath("zsh")
if err != nil {
t.Fatalf("ProfilePath(zsh): %v", err)
}
if !strings.HasSuffix(p, ".zshrc") {
t.Errorf("zsh profile should end with .zshrc, got %q", p)
}
}
func TestProfilePath_Fish(t *testing.T) {
dir := t.TempDir()
orig := os.Getenv("XDG_CONFIG_HOME")
os.Setenv("XDG_CONFIG_HOME", dir)
defer os.Setenv("XDG_CONFIG_HOME", orig)
p, err := hook.ProfilePath("fish")
if err != nil {
t.Fatalf("ProfilePath(fish): %v", err)
}
if !strings.HasSuffix(p, "config.fish") {
t.Errorf("fish profile should end with config.fish, got %q", p)
}
}
func TestProfilePath_UnknownShell(t *testing.T) {
if _, err := hook.ProfilePath("ksh"); err == nil {
t.Error("unknown shell should return an error")
}
}
// --- IsInstalled ---
func TestIsInstalled_Present(t *testing.T) {
profile := writeProfile(t, hook.Line+"\n")
if !hook.IsInstalled(profile) {
t.Error("IsInstalled should return true when marker is present")
}
}
func TestIsInstalled_Absent(t *testing.T) {
profile := writeProfile(t, "export PATH=$PATH\n")
if hook.IsInstalled(profile) {
t.Error("IsInstalled should return false when marker is absent")
}
}
func TestIsInstalled_MissingFile(t *testing.T) {
if hook.IsInstalled("/nonexistent/path/.bashrc") {
t.Error("IsInstalled should return false for missing file")
}
}
// --- Install ---
func TestInstall_Fresh(t *testing.T) {
profile := filepath.Join(t.TempDir(), ".bashrc")
installed, err := hook.Install(profile)
if err != nil {
t.Fatalf("Install: %v", err)
}
if !installed {
t.Error("should report installed=true on fresh install")
}
if !hook.IsInstalled(profile) {
t.Error("profile should contain the hook line after Install")
}
}
func TestInstall_Idempotent(t *testing.T) {
profile := filepath.Join(t.TempDir(), ".bashrc")
hook.Install(profile)
installed, err := hook.Install(profile)
if err != nil {
t.Fatalf("second Install: %v", err)
}
if installed {
t.Error("second Install should report installed=false")
}
data, _ := os.ReadFile(profile)
if count := strings.Count(string(data), hook.Marker); count != 1 {
t.Errorf("hook line should appear exactly once, got %d", count)
}
}
func TestInstall_CreatesProfileAndParentDirs(t *testing.T) {
profile := filepath.Join(t.TempDir(), "nested", "dir", ".zshrc")
if _, err := hook.Install(profile); err != nil {
t.Fatalf("Install should create missing directories: %v", err)
}
if _, err := os.Stat(profile); err != nil {
t.Error("profile file should have been created")
}
}
func TestInstall_PreservesExistingContent(t *testing.T) {
profile := writeProfile(t, "export PATH=$PATH:/usr/local/bin\n")
hook.Install(profile)
data, _ := os.ReadFile(profile)
content := string(data)
if !strings.Contains(content, "export PATH") {
t.Error("Install should not remove existing content")
}
if !strings.Contains(content, hook.Marker) {
t.Error("hook line should be appended")
}
}
func TestInstall_AppendedAtEnd(t *testing.T) {
profile := writeProfile(t, "existing line\n")
hook.Install(profile)
data, _ := os.ReadFile(profile)
lines := strings.Split(strings.TrimRight(string(data), "\n"), "\n")
last := lines[len(lines)-1]
if last != hook.Line {
t.Errorf("hook line should be the last line, got %q", last)
}
}
// --- Uninstall ---
func TestUninstall_Removes(t *testing.T) {
profile := writeProfile(t, "export PATH=$PATH\n"+hook.Line+"\nexport FOO=bar\n")
removed, err := hook.Uninstall(profile)
if err != nil {
t.Fatalf("Uninstall: %v", err)
}
if !removed {
t.Error("should report removed=true")
}
if hook.IsInstalled(profile) {
t.Error("hook line should have been removed")
}
}
func TestUninstall_PreservesOtherContent(t *testing.T) {
profile := writeProfile(t, "export FOO=bar\n"+hook.Line+"\nexport BAZ=qux\n")
hook.Uninstall(profile)
data, _ := os.ReadFile(profile)
content := string(data)
if !strings.Contains(content, "export FOO=bar") {
t.Error("content before hook should be preserved")
}
if !strings.Contains(content, "export BAZ=qux") {
t.Error("content after hook should be preserved")
}
}
func TestUninstall_NotPresent(t *testing.T) {
profile := writeProfile(t, "export PATH=$PATH\n")
removed, err := hook.Uninstall(profile)
if err != nil {
t.Fatalf("Uninstall: %v", err)
}
if removed {
t.Error("should report removed=false when hook was not present")
}
}
func TestUninstall_MissingFile(t *testing.T) {
removed, err := hook.Uninstall("/nonexistent/.bashrc")
if err != nil {
t.Errorf("Uninstall on missing file should not error: %v", err)
}
if removed {
t.Error("should report removed=false for missing file")
}
}
func TestUninstall_CollapsesExtraBlankLines(t *testing.T) {
// blank line before hook + hook + blank line after = three consecutive newlines after removal
profile := writeProfile(t, "line1\n\n"+hook.Line+"\n\nline2\n")
hook.Uninstall(profile)
data, _ := os.ReadFile(profile)
if strings.Contains(string(data), "\n\n\n") {
t.Error("three consecutive blank lines should be collapsed after uninstall")
}
}
// --- Roundtrip ---
func TestInstallUninstall_Roundtrip(t *testing.T) {
profile := writeProfile(t, "existing content\n")
hook.Install(profile)
if !hook.IsInstalled(profile) {
t.Fatal("should be installed after Install")
}
hook.Uninstall(profile)
if hook.IsInstalled(profile) {
t.Fatal("should not be installed after Uninstall")
}
data, _ := os.ReadFile(profile)
if !strings.Contains(string(data), "existing content") {
t.Error("original content should survive install+uninstall cycle")
}
}
// --- ReloadNote ---
func TestReloadNote_ContainsPath(t *testing.T) {
note := hook.ReloadNote("/home/user/.bashrc")
if !strings.Contains(note, "/home/user/.bashrc") {
t.Errorf("ReloadNote should contain the profile path, got %q", note)
}
}
// --- helpers ---
func writeProfile(t *testing.T, content string) string {
t.Helper()
f := filepath.Join(t.TempDir(), "profile")
if err := os.WriteFile(f, []byte(content), 0o644); err != nil {
t.Fatal(err)
}
return f
}
+69
View File
@@ -20,10 +20,16 @@ func RunWizard(cfg *config.Config) error {
url := cfg.NetBox.URL
token := cfg.NetBox.Token
defaultUser := cfg.SSH.DefaultUser
defaultPort := ""
if cfg.SSH.DefaultPort > 0 {
defaultPort = strconv.Itoa(cfg.SSH.DefaultPort)
}
strategiesRaw := strings.Join(cfg.Resolver.Strategies, ", ")
subnets := strings.Join(cfg.Resolver.ManagementSubnets, ", ")
interfaceName := cfg.Resolver.InterfaceName
cacheTTL := strconv.Itoa(cfg.Cache.TTL)
shortcutDomains := strings.Join(cfg.Shortcuts.Domains, ", ")
stripHyphens := cfg.Shortcuts.StripHyphens
if strategiesRaw == "" {
strategiesRaw = "primary_ip"
@@ -62,6 +68,21 @@ func RunWizard(cfg *config.Config) error {
Title("Default SSH user").
Description("Leave empty to use your system user ($USER).").
Value(&defaultUser),
huh.NewInput().
Title("Default SSH port").
Description("Leave empty to use the standard port (22).").
Placeholder("22").
Value(&defaultPort).
Validate(func(s string) error {
if strings.TrimSpace(s) == "" {
return nil
}
n, err := strconv.Atoi(strings.TrimSpace(s))
if err != nil || n < 1 || n > 65535 {
return errors.New("must be a port number between 1 and 65535")
}
return nil
}),
).Title("SSH defaults"),
huh.NewGroup(
@@ -90,6 +111,18 @@ func RunWizard(cfg *config.Config) error {
return nil
}),
).Title("Resolver & cache"),
huh.NewGroup(
huh.NewInput().
Title("Domain suffixes").
Description("Comma-separated suffixes stripped for shortcuts, e.g. .example.com, .example.de\nAllows typing 'web01' instead of 'web01.example.com'.").
Placeholder(".example.com").
Value(&shortcutDomains),
huh.NewConfirm().
Title("Strip hyphens").
Description("When enabled, fsn1-web01.example.com can be accessed as fsn1web01.\nOnly works for hosts already in the cache.").
Value(&stripHyphens),
).Title("Shortcuts"),
)
if err := form.Run(); err != nil {
@@ -109,6 +142,11 @@ func RunWizard(cfg *config.Config) error {
ttl, _ := strconv.Atoi(cacheTTL)
port := 0
if p, err := strconv.Atoi(strings.TrimSpace(defaultPort)); err == nil {
port = p
}
var subnetList []string
for _, s := range strings.Split(subnets, ",") {
if s = strings.TrimSpace(s); s != "" {
@@ -116,6 +154,16 @@ func RunWizard(cfg *config.Config) error {
}
}
var domainList []string
for _, s := range strings.Split(shortcutDomains, ",") {
if s = strings.TrimSpace(s); s != "" {
if !strings.HasPrefix(s, ".") {
s = "." + s
}
domainList = append(domainList, s)
}
}
out := config.Config{
NetBox: config.NetBoxConfig{
URL: strings.TrimRight(strings.TrimSpace(url), "/"),
@@ -124,6 +172,7 @@ func RunWizard(cfg *config.Config) error {
},
SSH: config.SSHConfig{
DefaultUser: strings.TrimSpace(defaultUser),
DefaultPort: port,
},
Resolver: config.ResolverConfig{
Strategies: strategies,
@@ -133,6 +182,10 @@ func RunWizard(cfg *config.Config) error {
Cache: config.CacheConfig{
TTL: ttl,
},
Shortcuts: config.ShortcutsConfig{
Domains: domainList,
StripHyphens: stripHyphens,
},
}
return save(out)
@@ -200,6 +253,22 @@ func save(cfg config.Config) error {
if cfg.SSH.DefaultUser != "" {
fmt.Fprintf(&b, " default_user: %q\n", cfg.SSH.DefaultUser)
}
if cfg.SSH.DefaultPort > 0 {
fmt.Fprintf(&b, " default_port: %d\n", cfg.SSH.DefaultPort)
}
if len(cfg.Shortcuts.Domains) > 0 || cfg.Shortcuts.StripHyphens {
b.WriteString("\nshortcuts:\n")
if len(cfg.Shortcuts.Domains) > 0 {
b.WriteString(" domains:\n")
for _, d := range cfg.Shortcuts.Domains {
fmt.Fprintf(&b, " - %s\n", d)
}
}
if cfg.Shortcuts.StripHyphens {
b.WriteString(" strip_hyphens: true\n")
}
}
if err := os.WriteFile(path, []byte(b.String()), 0o600); err != nil {
return fmt.Errorf("writing config: %w", err)
+142
View File
@@ -116,6 +116,148 @@ func TestSave_CreatesConfigDir(t *testing.T) {
}
}
func TestSave_DefaultPort(t *testing.T) {
dir := t.TempDir()
orig := os.Getenv("XDG_CONFIG_HOME")
os.Setenv("XDG_CONFIG_HOME", dir)
defer os.Setenv("XDG_CONFIG_HOME", orig)
cfg := config.Config{
NetBox: config.NetBoxConfig{URL: "http://x", Token: "t", TokenVersion: 1},
Cache: config.CacheConfig{TTL: 60},
SSH: config.SSHConfig{DefaultPort: 2222},
}
if err := save(cfg); err != nil {
t.Fatalf("save: %v", err)
}
data, _ := os.ReadFile(filepath.Join(dir, "netssh.yaml"))
if !strings.Contains(string(data), "default_port: 2222") {
t.Errorf("expected default_port: 2222 in config, got:\n%s", string(data))
}
}
func TestSave_Shortcuts_WritesSection(t *testing.T) {
dir := t.TempDir()
orig := os.Getenv("XDG_CONFIG_HOME")
os.Setenv("XDG_CONFIG_HOME", dir)
defer os.Setenv("XDG_CONFIG_HOME", orig)
cfg := config.Config{
NetBox: config.NetBoxConfig{URL: "http://x", Token: "t", TokenVersion: 1},
Cache: config.CacheConfig{TTL: 60},
Shortcuts: config.ShortcutsConfig{
Domains: []string{".example.com"},
StripHyphens: true,
},
}
if err := save(cfg); err != nil {
t.Fatalf("save: %v", err)
}
data, _ := os.ReadFile(filepath.Join(dir, "netssh.yaml"))
content := string(data)
for _, want := range []string{"shortcuts:", "domains:", ".example.com", "strip_hyphens: true"} {
if !strings.Contains(content, want) {
t.Errorf("expected %q in config, got:\n%s", want, content)
}
}
}
func TestSave_OmitsDefaultPortWhenZero(t *testing.T) {
dir := t.TempDir()
orig := os.Getenv("XDG_CONFIG_HOME")
os.Setenv("XDG_CONFIG_HOME", dir)
defer os.Setenv("XDG_CONFIG_HOME", orig)
cfg := config.Config{
NetBox: config.NetBoxConfig{URL: "http://x", Token: "t", TokenVersion: 1},
Cache: config.CacheConfig{TTL: 60},
SSH: config.SSHConfig{DefaultPort: 0},
}
if err := save(cfg); err != nil {
t.Fatalf("save: %v", err)
}
data, _ := os.ReadFile(filepath.Join(dir, "netssh.yaml"))
if strings.Contains(string(data), "default_port") {
t.Errorf("default_port should be omitted when zero, got:\n%s", string(data))
}
}
func TestSave_OmitsShortcutsWhenEmpty(t *testing.T) {
dir := t.TempDir()
orig := os.Getenv("XDG_CONFIG_HOME")
os.Setenv("XDG_CONFIG_HOME", dir)
defer os.Setenv("XDG_CONFIG_HOME", orig)
cfg := config.Config{
NetBox: config.NetBoxConfig{URL: "http://x", Token: "t", TokenVersion: 1},
Cache: config.CacheConfig{TTL: 60},
Shortcuts: config.ShortcutsConfig{}, // empty
}
if err := save(cfg); err != nil {
t.Fatalf("save: %v", err)
}
data, _ := os.ReadFile(filepath.Join(dir, "netssh.yaml"))
if strings.Contains(string(data), "shortcuts:") {
t.Errorf("shortcuts section should be omitted when empty, got:\n%s", string(data))
}
}
func TestSave_Roundtrip_WithNewFields(t *testing.T) {
dir := t.TempDir()
orig := os.Getenv("XDG_CONFIG_HOME")
os.Setenv("XDG_CONFIG_HOME", dir)
defer os.Setenv("XDG_CONFIG_HOME", orig)
original := config.Config{
NetBox: config.NetBoxConfig{
URL: "https://netbox.example.com",
Token: "nbt_test",
TokenVersion: 2,
},
SSH: config.SSHConfig{
DefaultUser: "admin",
DefaultPort: 2222,
},
Resolver: config.ResolverConfig{
Strategies: []string{"primary_ip"},
},
Cache: config.CacheConfig{TTL: 3600},
Shortcuts: config.ShortcutsConfig{
Domains: []string{".example.com", ".example.de"},
StripHyphens: true,
},
}
if err := save(original); err != nil {
t.Fatalf("save: %v", err)
}
loaded, err := config.Load()
if err != nil {
t.Fatalf("Load after save: %v", err)
}
if loaded.SSH.DefaultPort != original.SSH.DefaultPort {
t.Errorf("DefaultPort: got %d, want %d", loaded.SSH.DefaultPort, original.SSH.DefaultPort)
}
if len(loaded.Shortcuts.Domains) != len(original.Shortcuts.Domains) {
t.Errorf("Shortcuts.Domains length: got %d, want %d", len(loaded.Shortcuts.Domains), len(original.Shortcuts.Domains))
} else {
for i, d := range original.Shortcuts.Domains {
if loaded.Shortcuts.Domains[i] != d {
t.Errorf("Shortcuts.Domains[%d]: got %q, want %q", i, loaded.Shortcuts.Domains[i], d)
}
}
}
if loaded.Shortcuts.StripHyphens != original.Shortcuts.StripHyphens {
t.Errorf("StripHyphens: got %v, want %v", loaded.Shortcuts.StripHyphens, original.Shortcuts.StripHyphens)
}
}
func TestParseStrategies(t *testing.T) {
tests := []struct {
in string
+42
View File
@@ -0,0 +1,42 @@
package shortcuts
import (
"strings"
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/config"
)
// Normalize strips configured domain suffixes and optionally hyphens from a hostname.
// The result is lowercased for case-insensitive comparison.
func Normalize(name string, cfg config.ShortcutsConfig) string {
s := strings.ToLower(name)
for _, domain := range cfg.Domains {
suffix := strings.ToLower(domain)
if !strings.HasPrefix(suffix, ".") {
suffix = "." + suffix
}
if strings.HasSuffix(s, suffix) {
s = s[:len(s)-len(suffix)]
break
}
}
if cfg.StripHyphens {
s = strings.ReplaceAll(s, "-", "")
}
return s
}
// MakeNormalizer returns a closure bound to cfg, suitable for Cache.GetByShortcut.
func MakeNormalizer(cfg config.ShortcutsConfig) func(string) string {
return func(name string) string {
return Normalize(name, cfg)
}
}
// AliasName returns a shell-safe alias name for the given host.
// It normalizes the name (strips configured domains and optionally hyphens) then
// replaces any remaining dots with underscores so the result is a valid identifier.
func AliasName(name string, cfg config.ShortcutsConfig) string {
s := Normalize(name, cfg)
return strings.ReplaceAll(s, ".", "_")
}
+80
View File
@@ -0,0 +1,80 @@
package shortcuts_test
import (
"testing"
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/config"
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/shortcuts"
)
func TestNormalize(t *testing.T) {
cfg := config.ShortcutsConfig{
Domains: []string{".example.com", ".example.de"},
StripHyphens: true,
}
cases := []struct {
input string
want string
}{
{"fsn1-web01.example.com", "fsn1web01"},
{"fsn1-web01.example.de", "fsn1web01"},
{"FSN1-WEB01.EXAMPLE.COM", "fsn1web01"},
{"fsn1-web01", "fsn1web01"},
{"web01.example.com", "web01"},
{"web01.other.com", "web01.other.com"}, // unknown domain not stripped
}
for _, c := range cases {
got := shortcuts.Normalize(c.input, cfg)
if got != c.want {
t.Errorf("Normalize(%q) = %q, want %q", c.input, got, c.want)
}
}
}
func TestNormalizeNoHyphenStrip(t *testing.T) {
cfg := config.ShortcutsConfig{
Domains: []string{".example.com"},
StripHyphens: false,
}
got := shortcuts.Normalize("fsn1-web01.example.com", cfg)
want := "fsn1-web01"
if got != want {
t.Errorf("Normalize = %q, want %q", got, want)
}
}
func TestAliasName_WithDomainAndHyphens(t *testing.T) {
cfg := config.ShortcutsConfig{
Domains: []string{".example.com"},
StripHyphens: true,
}
got := shortcuts.AliasName("fsn1-web01.example.com", cfg)
want := "fsn1web01"
if got != want {
t.Errorf("AliasName = %q, want %q", got, want)
}
}
func TestAliasName_DotReplacement(t *testing.T) {
cfg := config.ShortcutsConfig{
Domains: []string{".example.com"},
}
// "web01.other.com" — domain not stripped, remaining dots → underscores
got := shortcuts.AliasName("web01.other.com", cfg)
want := "web01_other_com"
if got != want {
t.Errorf("AliasName = %q, want %q", got, want)
}
}
func TestAliasName_NoConfig(t *testing.T) {
cfg := config.ShortcutsConfig{}
got := shortcuts.AliasName("web01.example.com", cfg)
want := "web01_example_com"
if got != want {
t.Errorf("AliasName = %q, want %q", got, want)
}
}
+13
View File
@@ -86,6 +86,19 @@ func ReplaceHost(args []string, destIdx int, newHost string) []string {
return result
}
// HasPortFlag reports whether a port was specified via -p in args.
func HasPortFlag(args []string) bool {
for i, a := range args {
if a == "-p" && i+1 < len(args) {
return true
}
if len(a) > 2 && a[0] == '-' && a[1] == 'p' {
return true
}
}
return false
}
// HasUserFlag reports whether a user was specified via -l in args.
// Used to avoid overriding an explicit -l with the configured default user.
func HasUserFlag(args []string) bool {
+31
View File
@@ -144,6 +144,37 @@ func TestHasUserFlag_LFlagAtEnd(t *testing.T) {
}
}
func TestHasPortFlag_FlagSeparated(t *testing.T) {
if !HasPortFlag([]string{"-p", "2222", "host"}) {
t.Error("should detect -p <port>")
}
}
func TestHasPortFlag_FlagAttached(t *testing.T) {
if !HasPortFlag([]string{"-p2222", "host"}) {
t.Error("should detect -p<port> (attached form)")
}
}
func TestHasPortFlag_NotPresent(t *testing.T) {
if HasPortFlag([]string{"-l", "admin", "host"}) {
t.Error("should not detect port flag when absent")
}
}
func TestHasPortFlag_EmptyArgs(t *testing.T) {
if HasPortFlag([]string{}) {
t.Error("empty args should return false")
}
}
func TestHasPortFlag_PAtEnd(t *testing.T) {
// -p at the very end with no value — should return false
if HasPortFlag([]string{"-p"}) {
t.Error("-p with no value should return false")
}
}
func assertParsed(t *testing.T, got *ParsedArgs, host, user string, destIdx int) {
t.Helper()
if got == nil {