- **Shortcuts**: Add hostname normalization with domain stripping and hyphen folding. Include alias generation for cached hosts. - **Shell Hook**: Automate 24h cache refresh trigger with shell startup hook. Add install/uninstall commands for bash, zsh, and fish. - **Wizard**: Extend setup wizard to configure shortcuts (domains, hyphen stripping) and default SSH port. - **Cache**: Add `GetByShortcut` for resolving hosts via normalized shortcuts. Implement `NeedsRefresh` / `SetRefreshed` logic for refresh timestamps. - **Tests**: Comprehensive unit tests for shortcuts, hook installation, cache refresh, and alias generation. - **Docs**: Update README with shortcuts, shell hook, and default SSH port configuration.
This commit is contained in:
Vendored
+61
-8
@@ -20,10 +20,11 @@ type Entry struct {
|
||||
}
|
||||
|
||||
type Cache struct {
|
||||
mu sync.RWMutex
|
||||
entries map[string]Entry
|
||||
path string
|
||||
ttl time.Duration
|
||||
mu sync.RWMutex
|
||||
entries map[string]Entry
|
||||
path string
|
||||
ttl time.Duration
|
||||
refreshStamp string // path to last_refresh timestamp file
|
||||
}
|
||||
|
||||
type diskFormat struct {
|
||||
@@ -31,11 +32,48 @@ type diskFormat struct {
|
||||
}
|
||||
|
||||
func New(path string, ttlSeconds int) *Cache {
|
||||
return &Cache{
|
||||
entries: make(map[string]Entry),
|
||||
path: path,
|
||||
ttl: time.Duration(ttlSeconds) * time.Second,
|
||||
stamp := ""
|
||||
if path != "" {
|
||||
stamp = filepath.Join(filepath.Dir(path), "last_refresh")
|
||||
}
|
||||
return &Cache{
|
||||
entries: make(map[string]Entry),
|
||||
path: path,
|
||||
ttl: time.Duration(ttlSeconds) * time.Second,
|
||||
refreshStamp: stamp,
|
||||
}
|
||||
}
|
||||
|
||||
// NeedsRefresh reports whether the last full cache refresh (via `cache refresh`)
|
||||
// is older than d, or has never happened. Always returns false when no path is set.
|
||||
func (c *Cache) NeedsRefresh(d time.Duration) bool {
|
||||
if c.refreshStamp == "" {
|
||||
return false
|
||||
}
|
||||
data, err := os.ReadFile(c.refreshStamp)
|
||||
if err != nil {
|
||||
return true // file missing → never refreshed
|
||||
}
|
||||
var t time.Time
|
||||
if err := t.UnmarshalText(data); err != nil {
|
||||
return true
|
||||
}
|
||||
return time.Since(t) >= d
|
||||
}
|
||||
|
||||
// SetRefreshed records the current time as the last successful full refresh.
|
||||
func (c *Cache) SetRefreshed() error {
|
||||
if c.refreshStamp == "" {
|
||||
return nil
|
||||
}
|
||||
data, err := time.Now().MarshalText()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(c.refreshStamp), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(c.refreshStamp, data, 0o644)
|
||||
}
|
||||
|
||||
func (c *Cache) Load() error {
|
||||
@@ -137,6 +175,21 @@ func (c *Cache) Search(prefix string) []Entry {
|
||||
return out
|
||||
}
|
||||
|
||||
// GetByShortcut scans all entries and returns the first whose normalized name matches
|
||||
// the normalized shortcut. Returns (entry, found, fresh).
|
||||
func (c *Cache) GetByShortcut(shortcut string, normalize func(string) string) (entry Entry, found bool, fresh bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
norm := normalize(shortcut)
|
||||
for _, e := range c.entries {
|
||||
if normalize(e.Name) == norm {
|
||||
isFresh := c.ttl > 0 && time.Since(e.CachedAt) < c.ttl
|
||||
return e, true, isFresh
|
||||
}
|
||||
}
|
||||
return Entry{}, false, false
|
||||
}
|
||||
|
||||
// Get returns an entry and reports whether it is still within the TTL.
|
||||
func (c *Cache) Get(name string) (entry Entry, fresh bool) {
|
||||
c.mu.RLock()
|
||||
|
||||
Vendored
+158
@@ -300,6 +300,164 @@ func TestMarkUsed_RoundtripViaSave(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// --- GetByShortcut tests ---
|
||||
|
||||
func TestGetByShortcut_MatchFresh(t *testing.T) {
|
||||
c := New("", 3600)
|
||||
c.Upsert(Entry{Name: "web01.example.com", IP: "10.0.0.1", Kind: "device"})
|
||||
|
||||
normalize := func(s string) string {
|
||||
// strip .example.com
|
||||
if len(s) > len(".example.com") && s[len(s)-len(".example.com"):] == ".example.com" {
|
||||
s = s[:len(s)-len(".example.com")]
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
e, found, fresh := c.GetByShortcut("web01", normalize)
|
||||
if !found {
|
||||
t.Fatal("expected entry to be found")
|
||||
}
|
||||
if !fresh {
|
||||
t.Error("entry just inserted should be fresh")
|
||||
}
|
||||
if e.IP != "10.0.0.1" {
|
||||
t.Errorf("IP: got %q, want %q", e.IP, "10.0.0.1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetByShortcut_MatchStale(t *testing.T) {
|
||||
c := New("", 1) // 1 second TTL
|
||||
stale := Entry{Name: "web01.example.com", IP: "10.0.0.1", Kind: "device", CachedAt: time.Now().Add(-2 * time.Second)}
|
||||
c.mu.Lock()
|
||||
c.entries["web01.example.com"] = stale
|
||||
c.mu.Unlock()
|
||||
|
||||
normalize := func(s string) string {
|
||||
if len(s) > len(".example.com") && s[len(s)-len(".example.com"):] == ".example.com" {
|
||||
s = s[:len(s)-len(".example.com")]
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
_, found, fresh := c.GetByShortcut("web01", normalize)
|
||||
if !found {
|
||||
t.Fatal("expected entry to be found even when stale")
|
||||
}
|
||||
if fresh {
|
||||
t.Error("entry older than TTL should not be fresh")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetByShortcut_NotFound(t *testing.T) {
|
||||
c := New("", 3600)
|
||||
c.Upsert(Entry{Name: "db01.example.com", IP: "10.0.0.2", Kind: "device"})
|
||||
|
||||
_, found, _ := c.GetByShortcut("web01", func(s string) string { return s })
|
||||
if found {
|
||||
t.Error("should not find an entry that does not match the shortcut")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetByShortcut_EmptyCache(t *testing.T) {
|
||||
c := New("", 3600)
|
||||
_, found, _ := c.GetByShortcut("web01", func(s string) string { return s })
|
||||
if found {
|
||||
t.Error("empty cache should return found=false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetByShortcut_MultiDomain(t *testing.T) {
|
||||
c := New("", 3600)
|
||||
// Entry uses second domain (.example.de)
|
||||
c.Upsert(Entry{Name: "web01.example.de", IP: "10.0.0.5", Kind: "vm"})
|
||||
|
||||
normalize := func(s string) string {
|
||||
for _, suffix := range []string{".example.com", ".example.de"} {
|
||||
if len(s) > len(suffix) && s[len(s)-len(suffix):] == suffix {
|
||||
return s[:len(s)-len(suffix)]
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
e, found, _ := c.GetByShortcut("web01", normalize)
|
||||
if !found {
|
||||
t.Fatal("should match entry with second configured domain")
|
||||
}
|
||||
if e.IP != "10.0.0.5" {
|
||||
t.Errorf("IP: got %q, want %q", e.IP, "10.0.0.5")
|
||||
}
|
||||
}
|
||||
|
||||
// --- NeedsRefresh / SetRefreshed tests ---
|
||||
// These tests require NeedsRefresh(time.Duration) bool and SetRefreshed() error
|
||||
// to be implemented on *Cache. The refreshStamp field (path to the timestamp file)
|
||||
// must be set before calling these methods.
|
||||
|
||||
func TestNeedsRefresh_NeverRefreshed(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
stampPath := filepath.Join(dir, "last_refresh")
|
||||
c := New(filepath.Join(dir, "cache.json"), 3600)
|
||||
c.refreshStamp = stampPath
|
||||
|
||||
if !c.NeedsRefresh(24 * time.Hour) {
|
||||
t.Error("NeedsRefresh should return true when last_refresh file does not exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNeedsRefresh_JustRefreshed(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
stampPath := filepath.Join(dir, "last_refresh")
|
||||
c := New(filepath.Join(dir, "cache.json"), 3600)
|
||||
c.refreshStamp = stampPath
|
||||
|
||||
if err := c.SetRefreshed(); err != nil {
|
||||
t.Fatalf("SetRefreshed: %v", err)
|
||||
}
|
||||
if c.NeedsRefresh(24 * time.Hour) {
|
||||
t.Error("NeedsRefresh should return false immediately after SetRefreshed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNeedsRefresh_Stale(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
stampPath := filepath.Join(dir, "last_refresh")
|
||||
// Write a timestamp 25 hours ago
|
||||
oldTime := time.Now().Add(-25 * time.Hour).Format(time.RFC3339)
|
||||
if err := os.WriteFile(stampPath, []byte(oldTime), 0o644); err != nil {
|
||||
t.Fatalf("writing stamp: %v", err)
|
||||
}
|
||||
|
||||
c := New(filepath.Join(dir, "cache.json"), 3600)
|
||||
c.refreshStamp = stampPath
|
||||
|
||||
if !c.NeedsRefresh(24 * time.Hour) {
|
||||
t.Error("NeedsRefresh should return true when last_refresh is older than duration")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetRefreshed_Roundtrip(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
stampPath := filepath.Join(dir, "last_refresh")
|
||||
c := New(filepath.Join(dir, "cache.json"), 3600)
|
||||
c.refreshStamp = stampPath
|
||||
|
||||
if err := c.SetRefreshed(); err != nil {
|
||||
t.Fatalf("SetRefreshed: %v", err)
|
||||
}
|
||||
|
||||
// Just refreshed: 24h window → not stale
|
||||
if c.NeedsRefresh(24 * time.Hour) {
|
||||
t.Error("NeedsRefresh(24h) should be false right after SetRefreshed")
|
||||
}
|
||||
|
||||
// Zero duration: everything is stale
|
||||
if !c.NeedsRefresh(0) {
|
||||
t.Error("NeedsRefresh(0) should always return true")
|
||||
}
|
||||
}
|
||||
|
||||
// tempFile writes content to a temp file and returns its path.
|
||||
func tempFile(t *testing.T, content []byte) string {
|
||||
t.Helper()
|
||||
|
||||
@@ -10,10 +10,11 @@ import (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
NetBox NetBoxConfig `mapstructure:"netbox"`
|
||||
Resolver ResolverConfig `mapstructure:"resolver"`
|
||||
Cache CacheConfig `mapstructure:"cache"`
|
||||
SSH SSHConfig `mapstructure:"ssh"`
|
||||
NetBox NetBoxConfig `mapstructure:"netbox"`
|
||||
Resolver ResolverConfig `mapstructure:"resolver"`
|
||||
Cache CacheConfig `mapstructure:"cache"`
|
||||
SSH SSHConfig `mapstructure:"ssh"`
|
||||
Shortcuts ShortcutsConfig `mapstructure:"shortcuts"`
|
||||
}
|
||||
|
||||
type NetBoxConfig struct {
|
||||
@@ -35,6 +36,12 @@ type CacheConfig struct {
|
||||
|
||||
type SSHConfig struct {
|
||||
DefaultUser string `mapstructure:"default_user"`
|
||||
DefaultPort int `mapstructure:"default_port"`
|
||||
}
|
||||
|
||||
type ShortcutsConfig struct {
|
||||
Domains []string `mapstructure:"domains"`
|
||||
StripHyphens bool `mapstructure:"strip_hyphens"`
|
||||
}
|
||||
|
||||
// Path returns the canonical config file path.
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
package hook
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
// Marker is the unique string used to detect an existing hook installation.
|
||||
Marker = "netssh shell-init"
|
||||
// Line is the full line written into the shell profile.
|
||||
Line = "netssh shell-init # netssh cache auto-refresh"
|
||||
)
|
||||
|
||||
// ProfilePath returns the canonical shell profile path for the given shell name.
|
||||
func ProfilePath(shell string) (string, error) {
|
||||
switch shell {
|
||||
case "bash":
|
||||
return filepath.Join(os.Getenv("HOME"), ".bashrc"), nil
|
||||
case "zsh":
|
||||
return filepath.Join(os.Getenv("HOME"), ".zshrc"), nil
|
||||
case "fish":
|
||||
configDir, err := os.UserConfigDir()
|
||||
if err != nil {
|
||||
configDir = filepath.Join(os.Getenv("HOME"), ".config")
|
||||
}
|
||||
return filepath.Join(configDir, "fish", "config.fish"), nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported shell %q — supported: bash, zsh, fish", shell)
|
||||
}
|
||||
}
|
||||
|
||||
// ReloadNote returns the shell-specific hint for reloading the profile.
|
||||
func ReloadNote(profilePath string) string {
|
||||
return "Reload with: source " + profilePath
|
||||
}
|
||||
|
||||
// IsInstalled reports whether the hook marker is present in the profile file.
|
||||
// Returns false if the file does not exist or cannot be read.
|
||||
func IsInstalled(profilePath string) bool {
|
||||
data, err := os.ReadFile(profilePath)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(string(data), Marker)
|
||||
}
|
||||
|
||||
// Install appends the hook line to the profile if it is not already present.
|
||||
// Returns (true, nil) when newly installed, (false, nil) when already present.
|
||||
// Creates the profile file (and any parent directories) if they do not exist.
|
||||
func Install(profilePath string) (installed bool, err error) {
|
||||
if IsInstalled(profilePath) {
|
||||
return false, nil
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(profilePath), 0o755); err != nil {
|
||||
return false, fmt.Errorf("creating directory: %w", err)
|
||||
}
|
||||
f, err := os.OpenFile(profilePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("opening profile: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
if _, err := fmt.Fprintf(f, "\n%s\n", Line); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Uninstall removes the hook line from the profile.
|
||||
// Returns (true, nil) when removed, (false, nil) when the hook was not found.
|
||||
// Returns (false, nil) — not an error — when the file does not exist.
|
||||
func Uninstall(profilePath string) (removed bool, err error) {
|
||||
data, err := os.ReadFile(profilePath)
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("reading profile: %w", err)
|
||||
}
|
||||
content := string(data)
|
||||
if !strings.Contains(content, Marker) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
var kept []string
|
||||
for _, line := range strings.Split(content, "\n") {
|
||||
if strings.Contains(line, Marker) {
|
||||
continue
|
||||
}
|
||||
kept = append(kept, line)
|
||||
}
|
||||
// Collapse triple blank lines that the removal may leave.
|
||||
cleaned := strings.ReplaceAll(strings.Join(kept, "\n"), "\n\n\n", "\n\n")
|
||||
|
||||
if err := os.WriteFile(profilePath, []byte(cleaned), 0o644); err != nil {
|
||||
return false, fmt.Errorf("writing profile: %w", err)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
@@ -0,0 +1,263 @@
|
||||
package hook_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/hook"
|
||||
)
|
||||
|
||||
// --- ProfilePath ---
|
||||
|
||||
func TestProfilePath_Bash(t *testing.T) {
|
||||
orig := os.Getenv("HOME")
|
||||
os.Setenv("HOME", t.TempDir())
|
||||
defer os.Setenv("HOME", orig)
|
||||
|
||||
p, err := hook.ProfilePath("bash")
|
||||
if err != nil {
|
||||
t.Fatalf("ProfilePath(bash): %v", err)
|
||||
}
|
||||
if !strings.HasSuffix(p, ".bashrc") {
|
||||
t.Errorf("bash profile should end with .bashrc, got %q", p)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProfilePath_Zsh(t *testing.T) {
|
||||
orig := os.Getenv("HOME")
|
||||
os.Setenv("HOME", t.TempDir())
|
||||
defer os.Setenv("HOME", orig)
|
||||
|
||||
p, err := hook.ProfilePath("zsh")
|
||||
if err != nil {
|
||||
t.Fatalf("ProfilePath(zsh): %v", err)
|
||||
}
|
||||
if !strings.HasSuffix(p, ".zshrc") {
|
||||
t.Errorf("zsh profile should end with .zshrc, got %q", p)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProfilePath_Fish(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
orig := os.Getenv("XDG_CONFIG_HOME")
|
||||
os.Setenv("XDG_CONFIG_HOME", dir)
|
||||
defer os.Setenv("XDG_CONFIG_HOME", orig)
|
||||
|
||||
p, err := hook.ProfilePath("fish")
|
||||
if err != nil {
|
||||
t.Fatalf("ProfilePath(fish): %v", err)
|
||||
}
|
||||
if !strings.HasSuffix(p, "config.fish") {
|
||||
t.Errorf("fish profile should end with config.fish, got %q", p)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProfilePath_UnknownShell(t *testing.T) {
|
||||
if _, err := hook.ProfilePath("ksh"); err == nil {
|
||||
t.Error("unknown shell should return an error")
|
||||
}
|
||||
}
|
||||
|
||||
// --- IsInstalled ---
|
||||
|
||||
func TestIsInstalled_Present(t *testing.T) {
|
||||
profile := writeProfile(t, hook.Line+"\n")
|
||||
if !hook.IsInstalled(profile) {
|
||||
t.Error("IsInstalled should return true when marker is present")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsInstalled_Absent(t *testing.T) {
|
||||
profile := writeProfile(t, "export PATH=$PATH\n")
|
||||
if hook.IsInstalled(profile) {
|
||||
t.Error("IsInstalled should return false when marker is absent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsInstalled_MissingFile(t *testing.T) {
|
||||
if hook.IsInstalled("/nonexistent/path/.bashrc") {
|
||||
t.Error("IsInstalled should return false for missing file")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Install ---
|
||||
|
||||
func TestInstall_Fresh(t *testing.T) {
|
||||
profile := filepath.Join(t.TempDir(), ".bashrc")
|
||||
|
||||
installed, err := hook.Install(profile)
|
||||
if err != nil {
|
||||
t.Fatalf("Install: %v", err)
|
||||
}
|
||||
if !installed {
|
||||
t.Error("should report installed=true on fresh install")
|
||||
}
|
||||
if !hook.IsInstalled(profile) {
|
||||
t.Error("profile should contain the hook line after Install")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstall_Idempotent(t *testing.T) {
|
||||
profile := filepath.Join(t.TempDir(), ".bashrc")
|
||||
|
||||
hook.Install(profile)
|
||||
installed, err := hook.Install(profile)
|
||||
if err != nil {
|
||||
t.Fatalf("second Install: %v", err)
|
||||
}
|
||||
if installed {
|
||||
t.Error("second Install should report installed=false")
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(profile)
|
||||
if count := strings.Count(string(data), hook.Marker); count != 1 {
|
||||
t.Errorf("hook line should appear exactly once, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstall_CreatesProfileAndParentDirs(t *testing.T) {
|
||||
profile := filepath.Join(t.TempDir(), "nested", "dir", ".zshrc")
|
||||
|
||||
if _, err := hook.Install(profile); err != nil {
|
||||
t.Fatalf("Install should create missing directories: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(profile); err != nil {
|
||||
t.Error("profile file should have been created")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstall_PreservesExistingContent(t *testing.T) {
|
||||
profile := writeProfile(t, "export PATH=$PATH:/usr/local/bin\n")
|
||||
|
||||
hook.Install(profile)
|
||||
|
||||
data, _ := os.ReadFile(profile)
|
||||
content := string(data)
|
||||
if !strings.Contains(content, "export PATH") {
|
||||
t.Error("Install should not remove existing content")
|
||||
}
|
||||
if !strings.Contains(content, hook.Marker) {
|
||||
t.Error("hook line should be appended")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstall_AppendedAtEnd(t *testing.T) {
|
||||
profile := writeProfile(t, "existing line\n")
|
||||
hook.Install(profile)
|
||||
|
||||
data, _ := os.ReadFile(profile)
|
||||
lines := strings.Split(strings.TrimRight(string(data), "\n"), "\n")
|
||||
last := lines[len(lines)-1]
|
||||
if last != hook.Line {
|
||||
t.Errorf("hook line should be the last line, got %q", last)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Uninstall ---
|
||||
|
||||
func TestUninstall_Removes(t *testing.T) {
|
||||
profile := writeProfile(t, "export PATH=$PATH\n"+hook.Line+"\nexport FOO=bar\n")
|
||||
|
||||
removed, err := hook.Uninstall(profile)
|
||||
if err != nil {
|
||||
t.Fatalf("Uninstall: %v", err)
|
||||
}
|
||||
if !removed {
|
||||
t.Error("should report removed=true")
|
||||
}
|
||||
if hook.IsInstalled(profile) {
|
||||
t.Error("hook line should have been removed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUninstall_PreservesOtherContent(t *testing.T) {
|
||||
profile := writeProfile(t, "export FOO=bar\n"+hook.Line+"\nexport BAZ=qux\n")
|
||||
|
||||
hook.Uninstall(profile)
|
||||
|
||||
data, _ := os.ReadFile(profile)
|
||||
content := string(data)
|
||||
if !strings.Contains(content, "export FOO=bar") {
|
||||
t.Error("content before hook should be preserved")
|
||||
}
|
||||
if !strings.Contains(content, "export BAZ=qux") {
|
||||
t.Error("content after hook should be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUninstall_NotPresent(t *testing.T) {
|
||||
profile := writeProfile(t, "export PATH=$PATH\n")
|
||||
|
||||
removed, err := hook.Uninstall(profile)
|
||||
if err != nil {
|
||||
t.Fatalf("Uninstall: %v", err)
|
||||
}
|
||||
if removed {
|
||||
t.Error("should report removed=false when hook was not present")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUninstall_MissingFile(t *testing.T) {
|
||||
removed, err := hook.Uninstall("/nonexistent/.bashrc")
|
||||
if err != nil {
|
||||
t.Errorf("Uninstall on missing file should not error: %v", err)
|
||||
}
|
||||
if removed {
|
||||
t.Error("should report removed=false for missing file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUninstall_CollapsesExtraBlankLines(t *testing.T) {
|
||||
// blank line before hook + hook + blank line after = three consecutive newlines after removal
|
||||
profile := writeProfile(t, "line1\n\n"+hook.Line+"\n\nline2\n")
|
||||
|
||||
hook.Uninstall(profile)
|
||||
|
||||
data, _ := os.ReadFile(profile)
|
||||
if strings.Contains(string(data), "\n\n\n") {
|
||||
t.Error("three consecutive blank lines should be collapsed after uninstall")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Roundtrip ---
|
||||
|
||||
func TestInstallUninstall_Roundtrip(t *testing.T) {
|
||||
profile := writeProfile(t, "existing content\n")
|
||||
|
||||
hook.Install(profile)
|
||||
if !hook.IsInstalled(profile) {
|
||||
t.Fatal("should be installed after Install")
|
||||
}
|
||||
|
||||
hook.Uninstall(profile)
|
||||
if hook.IsInstalled(profile) {
|
||||
t.Fatal("should not be installed after Uninstall")
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(profile)
|
||||
if !strings.Contains(string(data), "existing content") {
|
||||
t.Error("original content should survive install+uninstall cycle")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ReloadNote ---
|
||||
|
||||
func TestReloadNote_ContainsPath(t *testing.T) {
|
||||
note := hook.ReloadNote("/home/user/.bashrc")
|
||||
if !strings.Contains(note, "/home/user/.bashrc") {
|
||||
t.Errorf("ReloadNote should contain the profile path, got %q", note)
|
||||
}
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func writeProfile(t *testing.T, content string) string {
|
||||
t.Helper()
|
||||
f := filepath.Join(t.TempDir(), "profile")
|
||||
if err := os.WriteFile(f, []byte(content), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return f
|
||||
}
|
||||
@@ -20,10 +20,16 @@ func RunWizard(cfg *config.Config) error {
|
||||
url := cfg.NetBox.URL
|
||||
token := cfg.NetBox.Token
|
||||
defaultUser := cfg.SSH.DefaultUser
|
||||
defaultPort := ""
|
||||
if cfg.SSH.DefaultPort > 0 {
|
||||
defaultPort = strconv.Itoa(cfg.SSH.DefaultPort)
|
||||
}
|
||||
strategiesRaw := strings.Join(cfg.Resolver.Strategies, ", ")
|
||||
subnets := strings.Join(cfg.Resolver.ManagementSubnets, ", ")
|
||||
interfaceName := cfg.Resolver.InterfaceName
|
||||
cacheTTL := strconv.Itoa(cfg.Cache.TTL)
|
||||
shortcutDomains := strings.Join(cfg.Shortcuts.Domains, ", ")
|
||||
stripHyphens := cfg.Shortcuts.StripHyphens
|
||||
|
||||
if strategiesRaw == "" {
|
||||
strategiesRaw = "primary_ip"
|
||||
@@ -62,6 +68,21 @@ func RunWizard(cfg *config.Config) error {
|
||||
Title("Default SSH user").
|
||||
Description("Leave empty to use your system user ($USER).").
|
||||
Value(&defaultUser),
|
||||
huh.NewInput().
|
||||
Title("Default SSH port").
|
||||
Description("Leave empty to use the standard port (22).").
|
||||
Placeholder("22").
|
||||
Value(&defaultPort).
|
||||
Validate(func(s string) error {
|
||||
if strings.TrimSpace(s) == "" {
|
||||
return nil
|
||||
}
|
||||
n, err := strconv.Atoi(strings.TrimSpace(s))
|
||||
if err != nil || n < 1 || n > 65535 {
|
||||
return errors.New("must be a port number between 1 and 65535")
|
||||
}
|
||||
return nil
|
||||
}),
|
||||
).Title("SSH defaults"),
|
||||
|
||||
huh.NewGroup(
|
||||
@@ -90,6 +111,18 @@ func RunWizard(cfg *config.Config) error {
|
||||
return nil
|
||||
}),
|
||||
).Title("Resolver & cache"),
|
||||
|
||||
huh.NewGroup(
|
||||
huh.NewInput().
|
||||
Title("Domain suffixes").
|
||||
Description("Comma-separated suffixes stripped for shortcuts, e.g. .example.com, .example.de\nAllows typing 'web01' instead of 'web01.example.com'.").
|
||||
Placeholder(".example.com").
|
||||
Value(&shortcutDomains),
|
||||
huh.NewConfirm().
|
||||
Title("Strip hyphens").
|
||||
Description("When enabled, fsn1-web01.example.com can be accessed as fsn1web01.\nOnly works for hosts already in the cache.").
|
||||
Value(&stripHyphens),
|
||||
).Title("Shortcuts"),
|
||||
)
|
||||
|
||||
if err := form.Run(); err != nil {
|
||||
@@ -109,6 +142,11 @@ func RunWizard(cfg *config.Config) error {
|
||||
|
||||
ttl, _ := strconv.Atoi(cacheTTL)
|
||||
|
||||
port := 0
|
||||
if p, err := strconv.Atoi(strings.TrimSpace(defaultPort)); err == nil {
|
||||
port = p
|
||||
}
|
||||
|
||||
var subnetList []string
|
||||
for _, s := range strings.Split(subnets, ",") {
|
||||
if s = strings.TrimSpace(s); s != "" {
|
||||
@@ -116,6 +154,16 @@ func RunWizard(cfg *config.Config) error {
|
||||
}
|
||||
}
|
||||
|
||||
var domainList []string
|
||||
for _, s := range strings.Split(shortcutDomains, ",") {
|
||||
if s = strings.TrimSpace(s); s != "" {
|
||||
if !strings.HasPrefix(s, ".") {
|
||||
s = "." + s
|
||||
}
|
||||
domainList = append(domainList, s)
|
||||
}
|
||||
}
|
||||
|
||||
out := config.Config{
|
||||
NetBox: config.NetBoxConfig{
|
||||
URL: strings.TrimRight(strings.TrimSpace(url), "/"),
|
||||
@@ -124,6 +172,7 @@ func RunWizard(cfg *config.Config) error {
|
||||
},
|
||||
SSH: config.SSHConfig{
|
||||
DefaultUser: strings.TrimSpace(defaultUser),
|
||||
DefaultPort: port,
|
||||
},
|
||||
Resolver: config.ResolverConfig{
|
||||
Strategies: strategies,
|
||||
@@ -133,6 +182,10 @@ func RunWizard(cfg *config.Config) error {
|
||||
Cache: config.CacheConfig{
|
||||
TTL: ttl,
|
||||
},
|
||||
Shortcuts: config.ShortcutsConfig{
|
||||
Domains: domainList,
|
||||
StripHyphens: stripHyphens,
|
||||
},
|
||||
}
|
||||
|
||||
return save(out)
|
||||
@@ -200,6 +253,22 @@ func save(cfg config.Config) error {
|
||||
if cfg.SSH.DefaultUser != "" {
|
||||
fmt.Fprintf(&b, " default_user: %q\n", cfg.SSH.DefaultUser)
|
||||
}
|
||||
if cfg.SSH.DefaultPort > 0 {
|
||||
fmt.Fprintf(&b, " default_port: %d\n", cfg.SSH.DefaultPort)
|
||||
}
|
||||
|
||||
if len(cfg.Shortcuts.Domains) > 0 || cfg.Shortcuts.StripHyphens {
|
||||
b.WriteString("\nshortcuts:\n")
|
||||
if len(cfg.Shortcuts.Domains) > 0 {
|
||||
b.WriteString(" domains:\n")
|
||||
for _, d := range cfg.Shortcuts.Domains {
|
||||
fmt.Fprintf(&b, " - %s\n", d)
|
||||
}
|
||||
}
|
||||
if cfg.Shortcuts.StripHyphens {
|
||||
b.WriteString(" strip_hyphens: true\n")
|
||||
}
|
||||
}
|
||||
|
||||
if err := os.WriteFile(path, []byte(b.String()), 0o600); err != nil {
|
||||
return fmt.Errorf("writing config: %w", err)
|
||||
|
||||
@@ -116,6 +116,148 @@ func TestSave_CreatesConfigDir(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSave_DefaultPort(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
orig := os.Getenv("XDG_CONFIG_HOME")
|
||||
os.Setenv("XDG_CONFIG_HOME", dir)
|
||||
defer os.Setenv("XDG_CONFIG_HOME", orig)
|
||||
|
||||
cfg := config.Config{
|
||||
NetBox: config.NetBoxConfig{URL: "http://x", Token: "t", TokenVersion: 1},
|
||||
Cache: config.CacheConfig{TTL: 60},
|
||||
SSH: config.SSHConfig{DefaultPort: 2222},
|
||||
}
|
||||
if err := save(cfg); err != nil {
|
||||
t.Fatalf("save: %v", err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(filepath.Join(dir, "netssh.yaml"))
|
||||
if !strings.Contains(string(data), "default_port: 2222") {
|
||||
t.Errorf("expected default_port: 2222 in config, got:\n%s", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSave_Shortcuts_WritesSection(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
orig := os.Getenv("XDG_CONFIG_HOME")
|
||||
os.Setenv("XDG_CONFIG_HOME", dir)
|
||||
defer os.Setenv("XDG_CONFIG_HOME", orig)
|
||||
|
||||
cfg := config.Config{
|
||||
NetBox: config.NetBoxConfig{URL: "http://x", Token: "t", TokenVersion: 1},
|
||||
Cache: config.CacheConfig{TTL: 60},
|
||||
Shortcuts: config.ShortcutsConfig{
|
||||
Domains: []string{".example.com"},
|
||||
StripHyphens: true,
|
||||
},
|
||||
}
|
||||
if err := save(cfg); err != nil {
|
||||
t.Fatalf("save: %v", err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(filepath.Join(dir, "netssh.yaml"))
|
||||
content := string(data)
|
||||
for _, want := range []string{"shortcuts:", "domains:", ".example.com", "strip_hyphens: true"} {
|
||||
if !strings.Contains(content, want) {
|
||||
t.Errorf("expected %q in config, got:\n%s", want, content)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSave_OmitsDefaultPortWhenZero(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
orig := os.Getenv("XDG_CONFIG_HOME")
|
||||
os.Setenv("XDG_CONFIG_HOME", dir)
|
||||
defer os.Setenv("XDG_CONFIG_HOME", orig)
|
||||
|
||||
cfg := config.Config{
|
||||
NetBox: config.NetBoxConfig{URL: "http://x", Token: "t", TokenVersion: 1},
|
||||
Cache: config.CacheConfig{TTL: 60},
|
||||
SSH: config.SSHConfig{DefaultPort: 0},
|
||||
}
|
||||
if err := save(cfg); err != nil {
|
||||
t.Fatalf("save: %v", err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(filepath.Join(dir, "netssh.yaml"))
|
||||
if strings.Contains(string(data), "default_port") {
|
||||
t.Errorf("default_port should be omitted when zero, got:\n%s", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSave_OmitsShortcutsWhenEmpty(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
orig := os.Getenv("XDG_CONFIG_HOME")
|
||||
os.Setenv("XDG_CONFIG_HOME", dir)
|
||||
defer os.Setenv("XDG_CONFIG_HOME", orig)
|
||||
|
||||
cfg := config.Config{
|
||||
NetBox: config.NetBoxConfig{URL: "http://x", Token: "t", TokenVersion: 1},
|
||||
Cache: config.CacheConfig{TTL: 60},
|
||||
Shortcuts: config.ShortcutsConfig{}, // empty
|
||||
}
|
||||
if err := save(cfg); err != nil {
|
||||
t.Fatalf("save: %v", err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(filepath.Join(dir, "netssh.yaml"))
|
||||
if strings.Contains(string(data), "shortcuts:") {
|
||||
t.Errorf("shortcuts section should be omitted when empty, got:\n%s", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSave_Roundtrip_WithNewFields(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
orig := os.Getenv("XDG_CONFIG_HOME")
|
||||
os.Setenv("XDG_CONFIG_HOME", dir)
|
||||
defer os.Setenv("XDG_CONFIG_HOME", orig)
|
||||
|
||||
original := config.Config{
|
||||
NetBox: config.NetBoxConfig{
|
||||
URL: "https://netbox.example.com",
|
||||
Token: "nbt_test",
|
||||
TokenVersion: 2,
|
||||
},
|
||||
SSH: config.SSHConfig{
|
||||
DefaultUser: "admin",
|
||||
DefaultPort: 2222,
|
||||
},
|
||||
Resolver: config.ResolverConfig{
|
||||
Strategies: []string{"primary_ip"},
|
||||
},
|
||||
Cache: config.CacheConfig{TTL: 3600},
|
||||
Shortcuts: config.ShortcutsConfig{
|
||||
Domains: []string{".example.com", ".example.de"},
|
||||
StripHyphens: true,
|
||||
},
|
||||
}
|
||||
|
||||
if err := save(original); err != nil {
|
||||
t.Fatalf("save: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := config.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load after save: %v", err)
|
||||
}
|
||||
|
||||
if loaded.SSH.DefaultPort != original.SSH.DefaultPort {
|
||||
t.Errorf("DefaultPort: got %d, want %d", loaded.SSH.DefaultPort, original.SSH.DefaultPort)
|
||||
}
|
||||
if len(loaded.Shortcuts.Domains) != len(original.Shortcuts.Domains) {
|
||||
t.Errorf("Shortcuts.Domains length: got %d, want %d", len(loaded.Shortcuts.Domains), len(original.Shortcuts.Domains))
|
||||
} else {
|
||||
for i, d := range original.Shortcuts.Domains {
|
||||
if loaded.Shortcuts.Domains[i] != d {
|
||||
t.Errorf("Shortcuts.Domains[%d]: got %q, want %q", i, loaded.Shortcuts.Domains[i], d)
|
||||
}
|
||||
}
|
||||
}
|
||||
if loaded.Shortcuts.StripHyphens != original.Shortcuts.StripHyphens {
|
||||
t.Errorf("StripHyphens: got %v, want %v", loaded.Shortcuts.StripHyphens, original.Shortcuts.StripHyphens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStrategies(t *testing.T) {
|
||||
tests := []struct {
|
||||
in string
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
package shortcuts
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/config"
|
||||
)
|
||||
|
||||
// Normalize strips configured domain suffixes and optionally hyphens from a hostname.
|
||||
// The result is lowercased for case-insensitive comparison.
|
||||
func Normalize(name string, cfg config.ShortcutsConfig) string {
|
||||
s := strings.ToLower(name)
|
||||
for _, domain := range cfg.Domains {
|
||||
suffix := strings.ToLower(domain)
|
||||
if !strings.HasPrefix(suffix, ".") {
|
||||
suffix = "." + suffix
|
||||
}
|
||||
if strings.HasSuffix(s, suffix) {
|
||||
s = s[:len(s)-len(suffix)]
|
||||
break
|
||||
}
|
||||
}
|
||||
if cfg.StripHyphens {
|
||||
s = strings.ReplaceAll(s, "-", "")
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// MakeNormalizer returns a closure bound to cfg, suitable for Cache.GetByShortcut.
|
||||
func MakeNormalizer(cfg config.ShortcutsConfig) func(string) string {
|
||||
return func(name string) string {
|
||||
return Normalize(name, cfg)
|
||||
}
|
||||
}
|
||||
|
||||
// AliasName returns a shell-safe alias name for the given host.
|
||||
// It normalizes the name (strips configured domains and optionally hyphens) then
|
||||
// replaces any remaining dots with underscores so the result is a valid identifier.
|
||||
func AliasName(name string, cfg config.ShortcutsConfig) string {
|
||||
s := Normalize(name, cfg)
|
||||
return strings.ReplaceAll(s, ".", "_")
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
package shortcuts_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/config"
|
||||
"git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/shortcuts"
|
||||
)
|
||||
|
||||
func TestNormalize(t *testing.T) {
|
||||
cfg := config.ShortcutsConfig{
|
||||
Domains: []string{".example.com", ".example.de"},
|
||||
StripHyphens: true,
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"fsn1-web01.example.com", "fsn1web01"},
|
||||
{"fsn1-web01.example.de", "fsn1web01"},
|
||||
{"FSN1-WEB01.EXAMPLE.COM", "fsn1web01"},
|
||||
{"fsn1-web01", "fsn1web01"},
|
||||
{"web01.example.com", "web01"},
|
||||
{"web01.other.com", "web01.other.com"}, // unknown domain not stripped
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
got := shortcuts.Normalize(c.input, cfg)
|
||||
if got != c.want {
|
||||
t.Errorf("Normalize(%q) = %q, want %q", c.input, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeNoHyphenStrip(t *testing.T) {
|
||||
cfg := config.ShortcutsConfig{
|
||||
Domains: []string{".example.com"},
|
||||
StripHyphens: false,
|
||||
}
|
||||
|
||||
got := shortcuts.Normalize("fsn1-web01.example.com", cfg)
|
||||
want := "fsn1-web01"
|
||||
if got != want {
|
||||
t.Errorf("Normalize = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAliasName_WithDomainAndHyphens(t *testing.T) {
|
||||
cfg := config.ShortcutsConfig{
|
||||
Domains: []string{".example.com"},
|
||||
StripHyphens: true,
|
||||
}
|
||||
got := shortcuts.AliasName("fsn1-web01.example.com", cfg)
|
||||
want := "fsn1web01"
|
||||
if got != want {
|
||||
t.Errorf("AliasName = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAliasName_DotReplacement(t *testing.T) {
|
||||
cfg := config.ShortcutsConfig{
|
||||
Domains: []string{".example.com"},
|
||||
}
|
||||
// "web01.other.com" — domain not stripped, remaining dots → underscores
|
||||
got := shortcuts.AliasName("web01.other.com", cfg)
|
||||
want := "web01_other_com"
|
||||
if got != want {
|
||||
t.Errorf("AliasName = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAliasName_NoConfig(t *testing.T) {
|
||||
cfg := config.ShortcutsConfig{}
|
||||
got := shortcuts.AliasName("web01.example.com", cfg)
|
||||
want := "web01_example_com"
|
||||
if got != want {
|
||||
t.Errorf("AliasName = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
@@ -86,6 +86,19 @@ func ReplaceHost(args []string, destIdx int, newHost string) []string {
|
||||
return result
|
||||
}
|
||||
|
||||
// HasPortFlag reports whether a port was specified via -p in args.
|
||||
func HasPortFlag(args []string) bool {
|
||||
for i, a := range args {
|
||||
if a == "-p" && i+1 < len(args) {
|
||||
return true
|
||||
}
|
||||
if len(a) > 2 && a[0] == '-' && a[1] == 'p' {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasUserFlag reports whether a user was specified via -l in args.
|
||||
// Used to avoid overriding an explicit -l with the configured default user.
|
||||
func HasUserFlag(args []string) bool {
|
||||
|
||||
@@ -144,6 +144,37 @@ func TestHasUserFlag_LFlagAtEnd(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasPortFlag_FlagSeparated(t *testing.T) {
|
||||
if !HasPortFlag([]string{"-p", "2222", "host"}) {
|
||||
t.Error("should detect -p <port>")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasPortFlag_FlagAttached(t *testing.T) {
|
||||
if !HasPortFlag([]string{"-p2222", "host"}) {
|
||||
t.Error("should detect -p<port> (attached form)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasPortFlag_NotPresent(t *testing.T) {
|
||||
if HasPortFlag([]string{"-l", "admin", "host"}) {
|
||||
t.Error("should not detect port flag when absent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasPortFlag_EmptyArgs(t *testing.T) {
|
||||
if HasPortFlag([]string{}) {
|
||||
t.Error("empty args should return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasPortFlag_PAtEnd(t *testing.T) {
|
||||
// -p at the very end with no value — should return false
|
||||
if HasPortFlag([]string{"-p"}) {
|
||||
t.Error("-p with no value should return false")
|
||||
}
|
||||
}
|
||||
|
||||
func assertParsed(t *testing.T, got *ParsedArgs, host, user string, destIdx int) {
|
||||
t.Helper()
|
||||
if got == nil {
|
||||
|
||||
Reference in New Issue
Block a user