From 8ef4bbec160fd21d2ed8c0f23e66af851b2a2c9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Untersch=C3=BCtz?= Date: Sat, 23 May 2026 12:38:41 +0200 Subject: [PATCH] Add core modules (SSH args parser, cache, resolver, NetBox client) with tests --- .gitea/workflows/release.yml | 71 ++++++++ .gitignore | 21 +++ README.md | 190 +++++++++++++++++++ go.mod | 44 +++++ go.sum | 82 +++++++++ install.sh | 94 ++++++++++ internal/cache/cache.go | 134 ++++++++++++++ internal/cache/cache_test.go | 235 ++++++++++++++++++++++++ internal/config/config.go | 76 ++++++++ internal/netbox/client.go | 194 ++++++++++++++++++++ internal/netbox/client_test.go | 261 +++++++++++++++++++++++++++ internal/netbox/models.go | 53 ++++++ internal/resolver/chain.go | 57 ++++++ internal/resolver/chain_test.go | 155 ++++++++++++++++ internal/resolver/interface_name.go | 38 ++++ internal/resolver/management.go | 53 ++++++ internal/resolver/management_test.go | 109 +++++++++++ internal/resolver/primary_ip.go | 23 +++ internal/resolver/primary_ip_test.go | 91 ++++++++++ internal/resolver/strategy.go | 17 ++ internal/ssh/args.go | 109 +++++++++++ internal/ssh/args_test.go | 161 +++++++++++++++++ internal/ssh/exec.go | 19 ++ internal/tui/model.go | 237 ++++++++++++++++++++++++ 24 files changed, 2524 insertions(+) create mode 100644 .gitea/workflows/release.yml create mode 100644 .gitignore create mode 100644 README.md create mode 100644 go.mod create mode 100644 go.sum create mode 100755 install.sh create mode 100644 internal/cache/cache.go create mode 100644 internal/cache/cache_test.go create mode 100644 internal/config/config.go create mode 100644 internal/netbox/client.go create mode 100644 internal/netbox/client_test.go create mode 100644 internal/netbox/models.go create mode 100644 internal/resolver/chain.go create mode 100644 internal/resolver/chain_test.go create mode 100644 internal/resolver/interface_name.go create mode 100644 internal/resolver/management.go create mode 100644 internal/resolver/management_test.go create mode 100644 internal/resolver/primary_ip.go create mode 100644 internal/resolver/primary_ip_test.go create mode 100644 internal/resolver/strategy.go create mode 100644 internal/ssh/args.go create mode 100644 internal/ssh/args_test.go create mode 100644 internal/ssh/exec.go create mode 100644 internal/tui/model.go diff --git a/.gitea/workflows/release.yml b/.gitea/workflows/release.yml new file mode 100644 index 0000000..4f121c6 --- /dev/null +++ b/.gitea/workflows/release.yml @@ -0,0 +1,71 @@ +name: Release + +on: + push: + tags: + - 'v*' + +jobs: + release: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-go@v5 + with: + go-version-file: go.mod + cache: true + + - name: Run tests + run: go test ./... + + - name: Build binaries + run: | + mkdir -p dist + for platform in linux/amd64 linux/arm64 darwin/amd64 darwin/arm64; do + GOOS=${platform%/*} + GOARCH=${platform#*/} + out="dist/netssh_${GOOS}_${GOARCH}" + echo "→ $out" + CGO_ENABLED=0 GOOS=$GOOS GOARCH=$GOARCH \ + go build -trimpath -ldflags="-s -w" \ + -o "$out" ./cmd/netssh + done + + - name: Generate checksums + working-directory: dist + run: sha256sum netssh_* > checksums.txt + + - name: Create release + id: create_release + run: | + TAG="${{ github.ref_name }}" + + RELEASE=$(curl -sf -X POST \ + -H "Authorization: token ${{ secrets.GITEA_TOKEN }}" \ + -H "Content-Type: application/json" \ + "${{ github.server_url }}/api/v1/repos/${{ github.repository }}/releases" \ + -d "{ + \"tag_name\": \"$TAG\", + \"name\": \"$TAG\", + \"body\": \"## Installation\n\n\`\`\`sh\ncurl -fsSL https://git.zb-server.de/Sebi/ssh-netbox-wrapper/raw/branch/main/install.sh | bash\n\`\`\`\", + \"draft\": false, + \"prerelease\": false + }") + + RELEASE_ID=$(echo "$RELEASE" | python3 -c "import sys,json; print(json.load(sys.stdin)['id'])") + echo "release_id=$RELEASE_ID" >> "$GITHUB_OUTPUT" + + - name: Upload assets + run: | + RELEASE_ID="${{ steps.create_release.outputs.release_id }}" + for file in dist/netssh_* dist/checksums.txt; do + name=$(basename "$file") + echo "↑ $name" + curl -sf -X POST \ + -H "Authorization: token ${{ secrets.GITEA_TOKEN }}" \ + -H "Content-Type: application/octet-stream" \ + "${{ github.server_url }}/api/v1/repos/${{ github.repository }}/releases/${RELEASE_ID}/assets?name=${name}" \ + --data-binary "@$file" + done diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5a83cff --- /dev/null +++ b/.gitignore @@ -0,0 +1,21 @@ +# binary +netssh +dist/ + +# Go test artifacts +*.test +*.out + +# environment / secrets +.env +*.env + +# editor +.idea/ +.vscode/ +*.swp +*~ + +# OS +.DS_Store +Thumbs.db diff --git a/README.md b/README.md new file mode 100644 index 0000000..767dce2 --- /dev/null +++ b/README.md @@ -0,0 +1,190 @@ +# netssh + +A transparent SSH wrapper that resolves hostnames via [NetBox](https://netbox.dev/) before connecting. + +Instead of looking up an IP manually, you just type the hostname as it appears in NetBox: + +```sh +netssh my-router-01 +netssh -p 2222 admin@app-server-03 uptime +``` + +`netssh` looks up the host in NetBox, resolves the right IP using a configurable strategy chain, and replaces the process with the native `ssh` binary — so all your existing SSH configs, keys, and agent forwarding work without any changes. + +## Features + +- **Transparent proxy** — replaces itself with `ssh` via `syscall.Exec`, preserving all SSH flags and options +- **Flexible IP resolution** — configurable chain of strategies: management subnet, primary IP, or named interface +- **Interactive TUI** — fuzzy search with live NetBox queries and 300 ms debouncing (start with `netssh`, no arguments) +- **Persistent cache** — successful lookups are cached to `~/.cache/netssh/hosts.json` for instant shell completion +- **Shell completion** — tab-complete hostnames from the cache in zsh, bash, and fish +- **Default SSH user** — set a fallback username once in config instead of typing it every time + +## Installation + +### One-liner (Linux & macOS) + +```sh +curl -fsSL https://git.zb-server.de/Sebi/ssh-netbox-wrapper/raw/branch/main/install.sh | bash +``` + +The script detects your OS and architecture, downloads the matching binary from the [latest release](https://git.zb-server.de/Sebi/ssh-netbox-wrapper/releases/latest), verifies the SHA-256 checksum, and installs to `/usr/local/bin/netssh` (using `sudo` only if necessary). + +To install to a custom directory: + +```sh +INSTALL_DIR=~/.local/bin curl -fsSL https://git.zb-server.de/Sebi/ssh-netbox-wrapper/raw/branch/main/install.sh | bash +``` + +### Build from source + +```sh +git clone ssh://git@git.zb-server.de:30022/Sebi/ssh-netbox-wrapper.git +cd ssh-netbox-wrapper +go build -o netssh ./cmd/netssh +``` + +## Configuration + +Create `~/.config/netssh.yaml`: + +```yaml +netbox: + url: https://netbox.example.com + token: your-api-token-here + +resolver: + # Strategies are tried in order; the first to return an IP wins. + strategies: + - management_subnet + - primary_ip + # Used by the management_subnet strategy. + management_subnets: + - 10.0.0.0/8 + - 172.16.0.0/12 + # Used by the interface_name strategy. + interface_name: mgmt0 + +cache: + ttl: 3600 # seconds; 0 = always query NetBox on connect (cache still used for completion) + # path: ~/.cache/netssh/hosts.json # default + +ssh: + default_user: admin # used when no user is specified on the command line +``` + +Any value can be overridden with environment variables (`NETSSH_NETBOX_URL`, `NETSSH_NETBOX_TOKEN`, etc.) or will be read from the config file. + +## Usage + +### SSH wrapper mode + +Pass any SSH flags and a NetBox hostname: + +```sh +netssh my-router-01 +netssh -p 2222 admin@app-server-03 uptime +netssh -i ~/.ssh/id_rsa -o StrictHostKeyChecking=no db-primary +``` + +The process is replaced by `ssh` with the resolved IP — your `~/.ssh/config`, agent, and keys all work as normal. + +### Default username + +Set `ssh.default_user` in the config to avoid typing a username every time: + +```sh +netssh my-router # → ssh -l admin 10.0.0.1 +``` + +The default is only applied when no user is specified on the command line. An explicit user always takes precedence: + +```sh +netssh root@my-router # user@ prefix wins → ssh root@10.0.0.1 +netssh -l ops my-router # -l flag wins → ssh -l ops 10.0.0.1 +``` + +### Interactive TUI + +Run without arguments to open the interactive search: + +```sh +netssh +``` + +| Key | Action | +|-----|--------| +| type | filter hosts (300 ms debounce → NetBox query) | +| `Tab` | autocomplete top result into the search field | +| `↑` / `↓` | navigate results | +| `Enter` | connect to selected host | +| `Esc` / `Ctrl+C` | quit | + +### Cache management + +```sh +netssh cache list # show all cached entries +netssh cache refresh # re-fetch all hosts from NetBox +netssh cache clear # wipe the cache +``` + +### Search (for scripting) + +```sh +netssh search app- # prints matching hostnames, one per line +``` + +## IP Resolution Strategies + +Strategies are tried in the configured order; the first to succeed wins. + +| Name | Description | +|------|-------------| +| `primary_ip` | Returns the `primary_ip4` (or `primary_ip6`) set in NetBox. No extra API call. | +| `management_subnet` | Fetches all IPs for the host and returns the first one matching a configured CIDR. | +| `interface_name` | Fetches IPs attached to a specific named interface (e.g. `mgmt0`). | + +## Shell Completion + +### zsh + +```sh +netssh completion zsh > "${fpath[1]}/_netssh" +``` + +Or add to `.zshrc`: + +```zsh +source <(netssh completion zsh) +``` + +### bash + +```sh +netssh completion bash > /etc/bash_completion.d/netssh +``` + +### fish + +```sh +netssh completion fish > ~/.config/fish/completions/netssh.fish +``` + +Completions are served from the local cache — no network request on every ``. + +## Development + +```sh +go test ./... # run all tests +go build ./... # build all packages +``` + +The test suite covers the cache, NetBox client (via `httptest`), IP resolver chain, and SSH argument parser. + +## How it works + +1. `netssh` checks whether the first argument is a known subcommand (`search`, `cache`, `completion`). If not, it enters SSH wrapper mode. +2. It parses the SSH arguments to extract the destination hostname, handling all flags that consume an extra argument (`-p`, `-i`, `-J`, …). +3. It checks the local cache. If the entry exists and is within the TTL, it connects immediately. +4. Otherwise it queries NetBox (`/api/dcim/devices/` and `/api/virtualization/virtual-machines/` in parallel), runs the result through the resolver chain, and caches the IP. +5. It calls `syscall.Exec` to replace itself with `ssh`, substituting the hostname with the resolved IP. diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..d34627a --- /dev/null +++ b/go.mod @@ -0,0 +1,44 @@ +module git.zb-server.de/Sebi/ssh-netbox-wrapper + +go 1.26.3 + +require ( + github.com/atotto/clipboard v0.1.4 // indirect + github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect + github.com/charmbracelet/bubbles v1.0.0 // indirect + github.com/charmbracelet/bubbletea v1.3.10 // indirect + github.com/charmbracelet/colorprofile v0.4.1 // indirect + github.com/charmbracelet/lipgloss v1.1.0 // indirect + github.com/charmbracelet/x/ansi v0.11.6 // indirect + github.com/charmbracelet/x/cellbuf v0.0.15 // indirect + github.com/charmbracelet/x/term v0.2.2 // indirect + github.com/clipperhouse/displaywidth v0.9.0 // indirect + github.com/clipperhouse/stringish v0.1.1 // indirect + github.com/clipperhouse/uax29/v2 v2.5.0 // indirect + github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect + github.com/fsnotify/fsnotify v1.9.0 // indirect + github.com/go-viper/mapstructure/v2 v2.4.0 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/lucasb-eyer/go-colorful v1.3.0 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-localereader v0.0.1 // indirect + github.com/mattn/go-runewidth v0.0.19 // indirect + github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect + github.com/muesli/cancelreader v0.2.2 // indirect + github.com/muesli/termenv v0.16.0 // indirect + github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/rivo/uniseg v0.4.7 // indirect + github.com/sagikazarmark/locafero v0.11.0 // indirect + github.com/sahilm/fuzzy v0.1.1 // indirect + github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect + github.com/spf13/afero v1.15.0 // indirect + github.com/spf13/cast v1.10.0 // indirect + github.com/spf13/cobra v1.10.2 // indirect + github.com/spf13/pflag v1.0.10 // indirect + github.com/spf13/viper v1.21.0 // indirect + github.com/subosito/gotenv v1.6.0 // indirect + github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect + go.yaml.in/yaml/v3 v3.0.4 // indirect + golang.org/x/sys v0.38.0 // indirect + golang.org/x/text v0.28.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..13116ea --- /dev/null +++ b/go.sum @@ -0,0 +1,82 @@ +github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= +github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= +github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= +github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= +github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc= +github.com/charmbracelet/bubbles v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E= +github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw= +github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4= +github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk= +github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk= +github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY= +github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30= +github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8= +github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ= +github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI= +github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q= +github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk= +github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI= +github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA= +github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA= +github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs= +github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA= +github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U= +github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= +github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag= +github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= +github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= +github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw= +github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= +github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= +github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= +github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= +github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= +github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= +github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc= +github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik= +github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA= +github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y= +github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw= +github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U= +github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= +github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg= +github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= +github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= +github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= +github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= +github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU= +github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= +go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/install.sh b/install.sh new file mode 100755 index 0000000..9528aa5 --- /dev/null +++ b/install.sh @@ -0,0 +1,94 @@ +#!/usr/bin/env bash +set -euo pipefail + +REPO="Sebi/ssh-netbox-wrapper" +BASE_URL="https://git.zb-server.de" +BINARY="netssh" +INSTALL_DIR="${INSTALL_DIR:-/usr/local/bin}" + +# --- helpers ----------------------------------------------------------------- + +red() { printf '\033[31m%s\033[0m\n' "$*"; } +green() { printf '\033[32m%s\033[0m\n' "$*"; } +bold() { printf '\033[1m%s\033[0m\n' "$*"; } +info() { printf ' %s\n' "$*"; } + +die() { red "error: $*" >&2; exit 1; } + +need() { + command -v "$1" &>/dev/null || die "'$1' is required but not installed" +} + +# --- detect OS / arch -------------------------------------------------------- + +detect_os() { + case "$(uname -s)" in + Linux) echo linux ;; + Darwin) echo darwin ;; + *) die "unsupported OS: $(uname -s)" ;; + esac +} + +detect_arch() { + case "$(uname -m)" in + x86_64|amd64) echo amd64 ;; + aarch64|arm64) echo arm64 ;; + *) die "unsupported architecture: $(uname -m)" ;; + esac +} + +# --- main -------------------------------------------------------------------- + +need curl + +OS=$(detect_os) +ARCH=$(detect_arch) + +bold "netssh installer" +info "platform : $OS/$ARCH" +info "target : $INSTALL_DIR/$BINARY" + +# Fetch the latest release tag from the Gitea API. +API_URL="$BASE_URL/api/v1/repos/$REPO/releases/latest" +TAG=$(curl -sf "$API_URL" | python3 -c "import sys,json; print(json.load(sys.stdin)['tag_name'])") \ + || die "could not fetch latest release from $API_URL" + +ASSET="${BINARY}_${OS}_${ARCH}" +DOWNLOAD_URL="$BASE_URL/$REPO/releases/download/$TAG/$ASSET" +CHECKSUM_URL="$BASE_URL/$REPO/releases/download/$TAG/checksums.txt" + +info "version : $TAG" +echo + +# Download binary and checksums into a temp directory. +TMP=$(mktemp -d) +trap 'rm -rf "$TMP"' EXIT + +info "downloading $ASSET..." +curl -fL --progress-bar -o "$TMP/$ASSET" "$DOWNLOAD_URL" + +info "verifying checksum..." +curl -sf -o "$TMP/checksums.txt" "$CHECKSUM_URL" \ + || { info "warning: could not fetch checksums, skipping verification"; } + +if [[ -f "$TMP/checksums.txt" ]]; then + # checksums.txt was built with `sha256sum` in the dist/ dir, so entries look + # like "abc123 netssh_linux_amd64". We need to cd into the temp dir first. + (cd "$TMP" && grep "$ASSET" checksums.txt | sha256sum --check --status) \ + || die "checksum mismatch — aborting installation" + info "checksum OK" +fi + +chmod +x "$TMP/$ASSET" + +# Install — try without sudo first, fall back if needed. +if [[ -w "$INSTALL_DIR" ]]; then + mv "$TMP/$ASSET" "$INSTALL_DIR/$BINARY" +else + info "sudo required to write to $INSTALL_DIR" + sudo mv "$TMP/$ASSET" "$INSTALL_DIR/$BINARY" +fi + +echo +green "✓ netssh $TAG installed to $INSTALL_DIR/$BINARY" +info "run 'netssh --help' to get started" diff --git a/internal/cache/cache.go b/internal/cache/cache.go new file mode 100644 index 0000000..347fa56 --- /dev/null +++ b/internal/cache/cache.go @@ -0,0 +1,134 @@ +package cache + +import ( + "encoding/json" + "os" + "path/filepath" + "strings" + "sync" + "time" +) + +type Entry struct { + Name string `json:"name"` + IP string `json:"ip"` + Kind string `json:"kind"` + Tags []string `json:"tags,omitempty"` + CachedAt time.Time `json:"cached_at"` +} + +type Cache struct { + mu sync.RWMutex + entries map[string]Entry + path string + ttl time.Duration +} + +type diskFormat struct { + Entries []Entry `json:"entries"` +} + +func New(path string, ttlSeconds int) *Cache { + return &Cache{ + entries: make(map[string]Entry), + path: path, + ttl: time.Duration(ttlSeconds) * time.Second, + } +} + +func (c *Cache) Load() error { + c.mu.Lock() + defer c.mu.Unlock() + + data, err := os.ReadFile(c.path) + if os.IsNotExist(err) { + return nil + } + if err != nil { + return err + } + + var df diskFormat + if err := json.Unmarshal(data, &df); err != nil { + return err + } + + c.entries = make(map[string]Entry, len(df.Entries)) + for _, e := range df.Entries { + c.entries[e.Name] = e + } + return nil +} + +func (c *Cache) Save() error { + c.mu.RLock() + df := diskFormat{Entries: make([]Entry, 0, len(c.entries))} + for _, e := range c.entries { + df.Entries = append(df.Entries, e) + } + c.mu.RUnlock() + + if err := os.MkdirAll(filepath.Dir(c.path), 0o755); err != nil { + return err + } + + data, err := json.MarshalIndent(df, "", " ") + if err != nil { + return err + } + return os.WriteFile(c.path, data, 0o644) +} + +func (c *Cache) Upsert(e Entry) { + e.CachedAt = time.Now() + c.mu.Lock() + c.entries[e.Name] = e + c.mu.Unlock() +} + +// Search returns all entries whose name starts with prefix (case-insensitive). +// TTL is intentionally ignored — this is used for shell completion. +func (c *Cache) Search(prefix string) []Entry { + c.mu.RLock() + defer c.mu.RUnlock() + + prefix = strings.ToLower(prefix) + var out []Entry + for name, e := range c.entries { + if strings.HasPrefix(strings.ToLower(name), prefix) { + out = append(out, e) + } + } + return out +} + +// Get returns an entry and reports whether it is still within the TTL. +func (c *Cache) Get(name string) (entry Entry, fresh bool) { + c.mu.RLock() + e, ok := c.entries[name] + c.mu.RUnlock() + + if !ok { + return Entry{}, false + } + if c.ttl == 0 { + return e, false + } + return e, time.Since(e.CachedAt) < c.ttl +} + +func (c *Cache) Clear() { + c.mu.Lock() + c.entries = make(map[string]Entry) + c.mu.Unlock() +} + +func (c *Cache) All() []Entry { + c.mu.RLock() + defer c.mu.RUnlock() + out := make([]Entry, 0, len(c.entries)) + for _, e := range c.entries { + out = append(out, e) + } + return out +} diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go new file mode 100644 index 0000000..11ee8f3 --- /dev/null +++ b/internal/cache/cache_test.go @@ -0,0 +1,235 @@ +package cache + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + "time" +) + +func TestNew(t *testing.T) { + c := New("/tmp/test.json", 60) + if c == nil { + t.Fatal("New returned nil") + } + if c.ttl != 60*time.Second { + t.Errorf("ttl: got %v, want %v", c.ttl, 60*time.Second) + } +} + +func TestLoad_MissingFile(t *testing.T) { + c := New("/nonexistent/path/cache.json", 60) + if err := c.Load(); err != nil { + t.Errorf("Load on missing file should not error, got: %v", err) + } +} + +func TestLoad_InvalidJSON(t *testing.T) { + f := tempFile(t, []byte("not json")) + c := New(f, 60) + if err := c.Load(); err == nil { + t.Error("Load on invalid JSON should return an error") + } +} + +func TestSaveAndLoad_Roundtrip(t *testing.T) { + path := filepath.Join(t.TempDir(), "cache.json") + c := New(path, 3600) + + c.Upsert(Entry{Name: "host-a", IP: "10.0.0.1", Kind: "device"}) + c.Upsert(Entry{Name: "host-b", IP: "10.0.0.2", Kind: "vm", Tags: []string{"prod"}}) + + if err := c.Save(); err != nil { + t.Fatalf("Save: %v", err) + } + + c2 := New(path, 3600) + if err := c2.Load(); err != nil { + t.Fatalf("Load: %v", err) + } + + e, _ := c2.Get("host-a") + if e.IP != "10.0.0.1" { + t.Errorf("host-a IP: got %q, want %q", e.IP, "10.0.0.1") + } + e2, _ := c2.Get("host-b") + if len(e2.Tags) != 1 || e2.Tags[0] != "prod" { + t.Errorf("host-b tags: got %v, want [prod]", e2.Tags) + } +} + +func TestSave_CreatesDirectory(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "sub", "dir", "cache.json") + c := New(path, 60) + c.Upsert(Entry{Name: "x", IP: "1.2.3.4", Kind: "device"}) + if err := c.Save(); err != nil { + t.Fatalf("Save: %v", err) + } + if _, err := os.Stat(path); err != nil { + t.Errorf("cache file not created: %v", err) + } +} + +func TestUpsert_SetsTimestamp(t *testing.T) { + c := New("", 60) + before := time.Now() + c.Upsert(Entry{Name: "h", IP: "1.1.1.1", Kind: "device"}) + e, _ := c.Get("h") + if e.CachedAt.Before(before) { + t.Error("CachedAt should be set to current time on Upsert") + } +} + +func TestUpsert_Overwrites(t *testing.T) { + c := New("", 60) + c.Upsert(Entry{Name: "host", IP: "10.0.0.1", Kind: "device"}) + c.Upsert(Entry{Name: "host", IP: "10.0.0.2", Kind: "device"}) + e, _ := c.Get("host") + if e.IP != "10.0.0.2" { + t.Errorf("Upsert should overwrite: got %q, want %q", e.IP, "10.0.0.2") + } +} + +func TestSearch_PrefixMatch(t *testing.T) { + c := New("", 60) + c.Upsert(Entry{Name: "app-server-01", IP: "10.0.0.1", Kind: "device"}) + c.Upsert(Entry{Name: "app-server-02", IP: "10.0.0.2", Kind: "vm"}) + c.Upsert(Entry{Name: "db-server-01", IP: "10.0.0.3", Kind: "device"}) + + results := c.Search("app") + if len(results) != 2 { + t.Errorf("Search(app): got %d results, want 2", len(results)) + } +} + +func TestSearch_CaseInsensitive(t *testing.T) { + c := New("", 60) + c.Upsert(Entry{Name: "App-Server", IP: "10.0.0.1", Kind: "device"}) + + if len(c.Search("app")) != 1 { + t.Error("Search should be case-insensitive") + } + if len(c.Search("APP")) != 1 { + t.Error("Search should be case-insensitive for uppercase") + } +} + +func TestSearch_EmptyPrefix(t *testing.T) { + c := New("", 60) + c.Upsert(Entry{Name: "a", IP: "1.1.1.1", Kind: "device"}) + c.Upsert(Entry{Name: "b", IP: "2.2.2.2", Kind: "vm"}) + + if len(c.Search("")) != 2 { + t.Error("Search('') should return all entries") + } +} + +func TestSearch_NoMatch(t *testing.T) { + c := New("", 60) + c.Upsert(Entry{Name: "host", IP: "1.1.1.1", Kind: "device"}) + + if len(c.Search("xyz")) != 0 { + t.Error("Search should return empty slice when no match") + } +} + +func TestGet_Fresh(t *testing.T) { + c := New("", 3600) + c.Upsert(Entry{Name: "host", IP: "10.0.0.1", Kind: "device"}) + + e, fresh := c.Get("host") + if !fresh { + t.Error("entry just inserted should be fresh") + } + if e.IP != "10.0.0.1" { + t.Errorf("IP: got %q, want %q", e.IP, "10.0.0.1") + } +} + +func TestGet_Expired(t *testing.T) { + c := New("", 1) // 1 second TTL + e := Entry{Name: "host", IP: "10.0.0.1", Kind: "device", CachedAt: time.Now().Add(-2 * time.Second)} + c.mu.Lock() + c.entries["host"] = e + c.mu.Unlock() + + _, fresh := c.Get("host") + if fresh { + t.Error("entry older than TTL should not be fresh") + } +} + +func TestGet_ZeroTTL_AlwaysStale(t *testing.T) { + c := New("", 0) // TTL=0 means never fresh for connect mode + c.Upsert(Entry{Name: "host", IP: "10.0.0.1", Kind: "device"}) + + _, fresh := c.Get("host") + if fresh { + t.Error("TTL=0 should always return fresh=false") + } +} + +func TestGet_Missing(t *testing.T) { + c := New("", 60) + _, fresh := c.Get("nonexistent") + if fresh { + t.Error("missing entry should not be fresh") + } +} + +func TestClear(t *testing.T) { + c := New("", 60) + c.Upsert(Entry{Name: "a", IP: "1.1.1.1", Kind: "device"}) + c.Upsert(Entry{Name: "b", IP: "2.2.2.2", Kind: "vm"}) + c.Clear() + + if len(c.All()) != 0 { + t.Error("Clear should remove all entries") + } +} + +func TestAll(t *testing.T) { + c := New("", 60) + c.Upsert(Entry{Name: "a", IP: "1.1.1.1", Kind: "device"}) + c.Upsert(Entry{Name: "b", IP: "2.2.2.2", Kind: "vm"}) + + all := c.All() + if len(all) != 2 { + t.Errorf("All: got %d entries, want 2", len(all)) + } +} + +func TestSave_ProducesValidJSON(t *testing.T) { + path := filepath.Join(t.TempDir(), "cache.json") + c := New(path, 60) + c.Upsert(Entry{Name: "host", IP: "10.0.0.1", Kind: "device", Tags: []string{"mgmt"}}) + + if err := c.Save(); err != nil { + t.Fatalf("Save: %v", err) + } + + data, _ := os.ReadFile(path) + var df diskFormat + if err := json.Unmarshal(data, &df); err != nil { + t.Fatalf("saved file is not valid JSON: %v", err) + } + if len(df.Entries) != 1 { + t.Errorf("expected 1 entry in JSON, got %d", len(df.Entries)) + } +} + +// tempFile writes content to a temp file and returns its path. +func tempFile(t *testing.T, content []byte) string { + t.Helper() + f, err := os.CreateTemp(t.TempDir(), "cache-*.json") + if err != nil { + t.Fatal(err) + } + if _, err := f.Write(content); err != nil { + t.Fatal(err) + } + f.Close() + return f.Name() +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..d49b96a --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,76 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/spf13/viper" +) + +type Config struct { + NetBox NetBoxConfig `mapstructure:"netbox"` + Resolver ResolverConfig `mapstructure:"resolver"` + Cache CacheConfig `mapstructure:"cache"` + SSH SSHConfig `mapstructure:"ssh"` +} + +type NetBoxConfig struct { + URL string `mapstructure:"url"` + Token string `mapstructure:"token"` +} + +type ResolverConfig struct { + Strategies []string `mapstructure:"strategies"` + ManagementSubnets []string `mapstructure:"management_subnets"` + InterfaceName string `mapstructure:"interface_name"` +} + +type CacheConfig struct { + TTL int `mapstructure:"ttl"` + Path string `mapstructure:"path"` +} + +type SSHConfig struct { + DefaultUser string `mapstructure:"default_user"` +} + +func Load() (*Config, error) { + v := viper.New() + + v.SetDefault("resolver.strategies", []string{"management_subnet", "primary_ip"}) + v.SetDefault("resolver.management_subnets", []string{}) + v.SetDefault("cache.ttl", 3600) + + configDir, err := os.UserConfigDir() + if err == nil { + v.SetConfigName("netssh") + v.SetConfigType("yaml") + v.AddConfigPath(filepath.Join(configDir)) + v.AddConfigPath(".") + } + + v.SetEnvPrefix("NETSSH") + v.AutomaticEnv() + + if err := v.ReadInConfig(); err != nil { + if _, ok := err.(viper.ConfigFileNotFoundError); !ok { + return nil, fmt.Errorf("reading config: %w", err) + } + } + + var cfg Config + if err := v.Unmarshal(&cfg); err != nil { + return nil, fmt.Errorf("parsing config: %w", err) + } + + if cfg.Cache.Path == "" { + cacheDir, err := os.UserCacheDir() + if err != nil { + cacheDir = filepath.Join(os.Getenv("HOME"), ".cache") + } + cfg.Cache.Path = filepath.Join(cacheDir, "netssh", "hosts.json") + } + + return &cfg, nil +} diff --git a/internal/netbox/client.go b/internal/netbox/client.go new file mode 100644 index 0000000..f140255 --- /dev/null +++ b/internal/netbox/client.go @@ -0,0 +1,194 @@ +package netbox + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "sync" +) + +type Client struct { + baseURL string + token string + httpClient *http.Client +} + +func NewClient(baseURL, token string) *Client { + return &Client{ + baseURL: strings.TrimRight(baseURL, "/"), + token: token, + httpClient: &http.Client{}, + } +} + +// Search queries devices and VMs in parallel and merges the results. +func (c *Client) Search(ctx context.Context, query string) ([]HostEntry, error) { + var ( + mu sync.Mutex + results []HostEntry + errs []error + wg sync.WaitGroup + ) + + wg.Add(2) + + go func() { + defer wg.Done() + devices, err := c.searchDevices(ctx, query) + mu.Lock() + defer mu.Unlock() + if err != nil { + errs = append(errs, fmt.Errorf("devices: %w", err)) + return + } + results = append(results, devices...) + }() + + go func() { + defer wg.Done() + vms, err := c.searchVMs(ctx, query) + mu.Lock() + defer mu.Unlock() + if err != nil { + errs = append(errs, fmt.Errorf("vms: %w", err)) + return + } + results = append(results, vms...) + }() + + wg.Wait() + + if len(errs) == 2 { + return nil, fmt.Errorf("netbox search failed: %v; %v", errs[0], errs[1]) + } + + return results, nil +} + +// GetIPs returns all IP addresses assigned to a host, used by resolver strategies +// that need more than just the primary IP. +func (c *Client) GetIPs(ctx context.Context, entry HostEntry) ([]string, error) { + var apiURL string + switch entry.Kind { + case "device": + apiURL = fmt.Sprintf("%s/api/ipam/ip-addresses/?device_id=%d&limit=100", c.baseURL, entry.ID) + case "vm": + apiURL = fmt.Sprintf("%s/api/ipam/ip-addresses/?virtual_machine_id=%d&limit=100", c.baseURL, entry.ID) + default: + return nil, fmt.Errorf("unknown host kind: %q", entry.Kind) + } + + var resp netboxIPListResponse + if err := c.get(ctx, apiURL, &resp); err != nil { + return nil, err + } + + ips := make([]string, 0, len(resp.Results)) + for _, r := range resp.Results { + ips = append(ips, stripPrefix(r.Address)) + } + return ips, nil +} + +// GetIPsWithFilter calls /api/ipam/ip-addresses/ with arbitrary filter query parameters. +func (c *Client) GetIPsWithFilter(ctx context.Context, filterParams string) ([]string, error) { + apiURL := fmt.Sprintf("%s/api/ipam/ip-addresses/?%s&limit=100", c.baseURL, filterParams) + var resp netboxIPListResponse + if err := c.get(ctx, apiURL, &resp); err != nil { + return nil, err + } + ips := make([]string, 0, len(resp.Results)) + for _, r := range resp.Results { + ips = append(ips, stripPrefix(r.Address)) + } + return ips, nil +} + +func (c *Client) searchDevices(ctx context.Context, query string) ([]HostEntry, error) { + apiURL := fmt.Sprintf("%s/api/dcim/devices/?name__ic=%s&limit=50", c.baseURL, url.QueryEscape(query)) + var resp netboxListResponse[netboxDevice] + if err := c.get(ctx, apiURL, &resp); err != nil { + return nil, err + } + entries := make([]HostEntry, 0, len(resp.Results)) + for _, d := range resp.Results { + entries = append(entries, deviceToEntry(d)) + } + return entries, nil +} + +func (c *Client) searchVMs(ctx context.Context, query string) ([]HostEntry, error) { + apiURL := fmt.Sprintf("%s/api/virtualization/virtual-machines/?name__ic=%s&limit=50", c.baseURL, url.QueryEscape(query)) + var resp netboxListResponse[netboxVM] + if err := c.get(ctx, apiURL, &resp); err != nil { + return nil, err + } + entries := make([]HostEntry, 0, len(resp.Results)) + for _, v := range resp.Results { + entries = append(entries, vmToEntry(v)) + } + return entries, nil +} + +func (c *Client) get(ctx context.Context, apiURL string, out any) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, apiURL, nil) + if err != nil { + return fmt.Errorf("creating request: %w", err) + } + req.Header.Set("Authorization", "Token "+c.token) + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("request to %s: %w", apiURL, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("netbox returned %d for %s", resp.StatusCode, apiURL) + } + + if err := json.NewDecoder(resp.Body).Decode(out); err != nil { + return fmt.Errorf("decoding response: %w", err) + } + return nil +} + +func deviceToEntry(d netboxDevice) HostEntry { + e := HostEntry{ID: d.ID, Name: d.Name, Kind: "device"} + if d.PrimaryIP4 != nil { + e.PrimaryIP4 = stripPrefix(d.PrimaryIP4.Address) + } + if d.PrimaryIP6 != nil { + e.PrimaryIP6 = stripPrefix(d.PrimaryIP6.Address) + } + for _, t := range d.Tags { + e.Tags = append(e.Tags, t.Name) + } + return e +} + +func vmToEntry(v netboxVM) HostEntry { + e := HostEntry{ID: v.ID, Name: v.Name, Kind: "vm"} + if v.PrimaryIP4 != nil { + e.PrimaryIP4 = stripPrefix(v.PrimaryIP4.Address) + } + if v.PrimaryIP6 != nil { + e.PrimaryIP6 = stripPrefix(v.PrimaryIP6.Address) + } + for _, t := range v.Tags { + e.Tags = append(e.Tags, t.Name) + } + return e +} + +// stripPrefix removes the CIDR prefix length from a NetBox IP (e.g. "10.0.1.5/24" → "10.0.1.5"). +func stripPrefix(cidr string) string { + if idx := strings.Index(cidr, "/"); idx != -1 { + return cidr[:idx] + } + return cidr +} diff --git a/internal/netbox/client_test.go b/internal/netbox/client_test.go new file mode 100644 index 0000000..cb50b79 --- /dev/null +++ b/internal/netbox/client_test.go @@ -0,0 +1,261 @@ +package netbox + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +// newTestServer returns an httptest.Server that serves fixed responses per path. +func newTestServer(t *testing.T, handlers map[string]any) *httptest.Server { + t.Helper() + mux := http.NewServeMux() + for path, body := range handlers { + b, err := json.Marshal(body) + if err != nil { + t.Fatalf("marshalling handler for %s: %v", path, err) + } + captured := b + mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write(captured) + }) + } + return httptest.NewServer(mux) +} + +func deviceListResponse(devices ...netboxDevice) netboxListResponse[netboxDevice] { + return netboxListResponse[netboxDevice]{Count: len(devices), Results: devices} +} + +func vmListResponse(vms ...netboxVM) netboxListResponse[netboxVM] { + return netboxListResponse[netboxVM]{Count: len(vms), Results: vms} +} + +func ipListResponse(addrs ...string) netboxIPListResponse { + resp := netboxIPListResponse{Count: len(addrs)} + for _, a := range addrs { + resp.Results = append(resp.Results, struct { + Address string `json:"address"` + Interface *struct { + Name string `json:"name"` + } `json:"assigned_object"` + }{Address: a}) + } + return resp +} + +func TestSearch_ReturnsBothDevicesAndVMs(t *testing.T) { + srv := newTestServer(t, map[string]any{ + "/api/dcim/devices/": deviceListResponse( + netboxDevice{ID: 1, Name: "router-01", PrimaryIP4: &netboxIP{Address: "10.0.0.1/24"}}, + ), + "/api/virtualization/virtual-machines/": vmListResponse( + netboxVM{ID: 2, Name: "vm-01", PrimaryIP4: &netboxIP{Address: "10.0.0.2/24"}}, + ), + }) + defer srv.Close() + + c := NewClient(srv.URL, "token") + results, err := c.Search(context.Background(), "") + if err != nil { + t.Fatalf("Search: %v", err) + } + if len(results) != 2 { + t.Errorf("got %d results, want 2", len(results)) + } + + names := map[string]bool{} + for _, r := range results { + names[r.Name] = true + } + if !names["router-01"] || !names["vm-01"] { + t.Errorf("missing expected hosts in results: %v", names) + } +} + +func TestSearch_MapsKindCorrectly(t *testing.T) { + srv := newTestServer(t, map[string]any{ + "/api/dcim/devices/": deviceListResponse( + netboxDevice{ID: 1, Name: "sw-01"}, + ), + "/api/virtualization/virtual-machines/": vmListResponse( + netboxVM{ID: 2, Name: "vm-01"}, + ), + }) + defer srv.Close() + + c := NewClient(srv.URL, "token") + results, _ := c.Search(context.Background(), "") + + for _, r := range results { + switch r.Name { + case "sw-01": + if r.Kind != "device" { + t.Errorf("sw-01 kind: got %q, want %q", r.Kind, "device") + } + case "vm-01": + if r.Kind != "vm" { + t.Errorf("vm-01 kind: got %q, want %q", r.Kind, "vm") + } + } + } +} + +func TestSearch_StripsPrefixFromPrimaryIP(t *testing.T) { + srv := newTestServer(t, map[string]any{ + "/api/dcim/devices/": deviceListResponse( + netboxDevice{ID: 1, Name: "host", PrimaryIP4: &netboxIP{Address: "192.168.1.10/24"}}, + ), + "/api/virtualization/virtual-machines/": vmListResponse(), + }) + defer srv.Close() + + c := NewClient(srv.URL, "token") + results, _ := c.Search(context.Background(), "host") + if len(results) == 0 { + t.Fatal("expected at least one result") + } + if results[0].PrimaryIP4 != "192.168.1.10" { + t.Errorf("PrimaryIP4: got %q, want %q", results[0].PrimaryIP4, "192.168.1.10") + } +} + +func TestSearch_TagsAreMapped(t *testing.T) { + srv := newTestServer(t, map[string]any{ + "/api/dcim/devices/": deviceListResponse( + netboxDevice{ + ID: 1, + Name: "host", + Tags: []struct { + Name string `json:"name"` + }{{Name: "prod"}, {Name: "mgmt"}}, + }, + ), + "/api/virtualization/virtual-machines/": vmListResponse(), + }) + defer srv.Close() + + c := NewClient(srv.URL, "token") + results, _ := c.Search(context.Background(), "") + if len(results[0].Tags) != 2 { + t.Errorf("tags: got %v, want [prod mgmt]", results[0].Tags) + } +} + +func TestSearch_PartialFailure_ReturnsAvailableResults(t *testing.T) { + // Only devices endpoint works; VMs returns 500. + mux := http.NewServeMux() + body, _ := json.Marshal(deviceListResponse(netboxDevice{ID: 1, Name: "sw-01"})) + mux.HandleFunc("/api/dcim/devices/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write(body) + }) + mux.HandleFunc("/api/virtualization/virtual-machines/", func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "internal error", http.StatusInternalServerError) + }) + srv := httptest.NewServer(mux) + defer srv.Close() + + c := NewClient(srv.URL, "token") + results, err := c.Search(context.Background(), "") + if err != nil { + t.Fatalf("partial failure should not return error, got: %v", err) + } + if len(results) != 1 || results[0].Name != "sw-01" { + t.Errorf("expected device results, got %v", results) + } +} + +func TestSearch_BothFail_ReturnsError(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "error", http.StatusInternalServerError) + }) + srv := httptest.NewServer(mux) + defer srv.Close() + + c := NewClient(srv.URL, "token") + _, err := c.Search(context.Background(), "") + if err == nil { + t.Error("both endpoints failing should return an error") + } +} + +func TestGetIPs_Device(t *testing.T) { + srv := newTestServer(t, map[string]any{ + "/api/ipam/ip-addresses/": ipListResponse("10.0.0.1/24", "10.0.0.2/24"), + }) + defer srv.Close() + + c := NewClient(srv.URL, "token") + ips, err := c.GetIPs(context.Background(), HostEntry{ID: 1, Kind: "device"}) + if err != nil { + t.Fatalf("GetIPs: %v", err) + } + if len(ips) != 2 { + t.Errorf("got %d IPs, want 2", len(ips)) + } + if ips[0] != "10.0.0.1" || ips[1] != "10.0.0.2" { + t.Errorf("IPs: got %v, want [10.0.0.1 10.0.0.2]", ips) + } +} + +func TestGetIPs_VM(t *testing.T) { + srv := newTestServer(t, map[string]any{ + "/api/ipam/ip-addresses/": ipListResponse("172.16.0.5/16"), + }) + defer srv.Close() + + c := NewClient(srv.URL, "token") + ips, err := c.GetIPs(context.Background(), HostEntry{ID: 2, Kind: "vm"}) + if err != nil { + t.Fatalf("GetIPs: %v", err) + } + if len(ips) != 1 || ips[0] != "172.16.0.5" { + t.Errorf("IPs: got %v, want [172.16.0.5]", ips) + } +} + +func TestGetIPs_UnknownKind(t *testing.T) { + c := NewClient("http://localhost", "token") + _, err := c.GetIPs(context.Background(), HostEntry{ID: 1, Kind: "unknown"}) + if err == nil { + t.Error("unknown kind should return an error") + } +} + +func TestGetIPsWithFilter(t *testing.T) { + srv := newTestServer(t, map[string]any{ + "/api/ipam/ip-addresses/": ipListResponse("10.10.10.1/24"), + }) + defer srv.Close() + + c := NewClient(srv.URL, "token") + ips, err := c.GetIPsWithFilter(context.Background(), "device_id=1&interface_name=mgmt0") + if err != nil { + t.Fatalf("GetIPsWithFilter: %v", err) + } + if len(ips) != 1 || ips[0] != "10.10.10.1" { + t.Errorf("IPs: got %v, want [10.10.10.1]", ips) + } +} + +func TestStripPrefix(t *testing.T) { + tests := []struct { + in string + want string + }{ + {"10.0.0.1/24", "10.0.0.1"}, + {"::1/128", "::1"}, + {"192.168.1.1", "192.168.1.1"}, // no prefix — unchanged + {"", ""}, + } + for _, tt := range tests { + if got := stripPrefix(tt.in); got != tt.want { + t.Errorf("stripPrefix(%q) = %q, want %q", tt.in, got, tt.want) + } + } +} diff --git a/internal/netbox/models.go b/internal/netbox/models.go new file mode 100644 index 0000000..6e28d3b --- /dev/null +++ b/internal/netbox/models.go @@ -0,0 +1,53 @@ +package netbox + +// HostEntry is a unified model for both devices and virtual machines from NetBox. +type HostEntry struct { + ID int + Name string + Kind string // "device" | "vm" + PrimaryIP4 string // e.g. "10.0.1.5" (prefix length stripped) + PrimaryIP6 string + Tags []string +} + +// netboxIP represents an IP address as returned by the NetBox API. +type netboxIP struct { + Address string `json:"address"` // CIDR notation, e.g. "10.0.1.5/24" +} + +// netboxDevice matches the relevant fields of the NetBox /dcim/devices/ response. +type netboxDevice struct { + ID int `json:"id"` + Name string `json:"name"` + Tags []struct { + Name string `json:"name"` + } `json:"tags"` + PrimaryIP4 *netboxIP `json:"primary_ip4"` + PrimaryIP6 *netboxIP `json:"primary_ip6"` +} + +// netboxVM matches the relevant fields of the NetBox /virtualization/virtual-machines/ response. +type netboxVM struct { + ID int `json:"id"` + Name string `json:"name"` + Tags []struct { + Name string `json:"name"` + } `json:"tags"` + PrimaryIP4 *netboxIP `json:"primary_ip4"` + PrimaryIP6 *netboxIP `json:"primary_ip6"` +} + +type netboxListResponse[T any] struct { + Count int `json:"count"` + Results []T `json:"results"` +} + +type netboxIPListResponse struct { + Count int `json:"count"` + Results []struct { + Address string `json:"address"` + Interface *struct { + Name string `json:"name"` + } `json:"assigned_object"` + } `json:"results"` +} diff --git a/internal/resolver/chain.go b/internal/resolver/chain.go new file mode 100644 index 0000000..bc693e9 --- /dev/null +++ b/internal/resolver/chain.go @@ -0,0 +1,57 @@ +package resolver + +import ( + "context" + "fmt" + + "git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/config" + "git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/netbox" +) + +// Chain tries each strategy in order until one returns an IP. +type Chain struct { + strategies []Strategy +} + +// New builds a Chain from the strategy names listed in the resolver config. +func New(cfg config.ResolverConfig) (*Chain, error) { + var strategies []Strategy + for _, name := range cfg.Strategies { + s, err := newStrategy(name, cfg) + if err != nil { + return nil, fmt.Errorf("resolver strategy %q: %w", name, err) + } + strategies = append(strategies, s) + } + return &Chain{strategies: strategies}, nil +} + +func (c *Chain) Resolve(ctx context.Context, entry *netbox.HostEntry, client *netbox.Client) (string, error) { + for _, s := range c.strategies { + ip, err := s.Resolve(ctx, entry, client) + if err == nil { + return ip, nil + } + } + return "", fmt.Errorf("no strategy resolved an IP for %q", entry.Name) +} + +func newStrategy(name string, cfg config.ResolverConfig) (Strategy, error) { + switch name { + case "primary_ip": + return &PrimaryIPStrategy{}, nil + case "management_subnet": + s, err := NewManagementSubnetStrategy(cfg.ManagementSubnets) + if err != nil { + return nil, err + } + return s, nil + case "interface_name": + if cfg.InterfaceName == "" { + return nil, fmt.Errorf("interface_name strategy requires resolver.interface_name to be set") + } + return &InterfaceNameStrategy{name: cfg.InterfaceName}, nil + default: + return nil, fmt.Errorf("unknown strategy %q", name) + } +} diff --git a/internal/resolver/chain_test.go b/internal/resolver/chain_test.go new file mode 100644 index 0000000..b40fb73 --- /dev/null +++ b/internal/resolver/chain_test.go @@ -0,0 +1,155 @@ +package resolver + +import ( + "context" + "errors" + "testing" + + "git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/config" + "git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/netbox" +) + +// stubStrategy is a test double for Strategy. +type stubStrategy struct { + name string + ip string + err error +} + +func (s *stubStrategy) Name() string { return s.name } +func (s *stubStrategy) Resolve(_ context.Context, _ *netbox.HostEntry, _ *netbox.Client) (string, error) { + return s.ip, s.err +} + +func TestChain_FirstStrategySucceeds(t *testing.T) { + c := &Chain{strategies: []Strategy{ + &stubStrategy{name: "first", ip: "10.0.0.1"}, + &stubStrategy{name: "second", ip: "10.0.0.2"}, + }} + ip, err := c.Resolve(context.Background(), &netbox.HostEntry{Name: "host"}, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ip != "10.0.0.1" { + t.Errorf("got %q, want first strategy's IP %q", ip, "10.0.0.1") + } +} + +func TestChain_FallsBackToNextStrategy(t *testing.T) { + c := &Chain{strategies: []Strategy{ + &stubStrategy{name: "first", err: ErrNoIP}, + &stubStrategy{name: "second", ip: "10.0.0.2"}, + }} + ip, err := c.Resolve(context.Background(), &netbox.HostEntry{Name: "host"}, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ip != "10.0.0.2" { + t.Errorf("got %q, want second strategy's IP %q", ip, "10.0.0.2") + } +} + +func TestChain_AllStrategiesFail(t *testing.T) { + c := &Chain{strategies: []Strategy{ + &stubStrategy{name: "a", err: ErrNoIP}, + &stubStrategy{name: "b", err: errors.New("api error")}, + }} + _, err := c.Resolve(context.Background(), &netbox.HostEntry{Name: "host"}, nil) + if err == nil { + t.Error("expected error when all strategies fail") + } +} + +func TestChain_EmptyStrategies(t *testing.T) { + c := &Chain{} + _, err := c.Resolve(context.Background(), &netbox.HostEntry{Name: "host"}, nil) + if err == nil { + t.Error("empty chain should return an error") + } +} + +func TestNew_PrimaryIP(t *testing.T) { + cfg := config.ResolverConfig{Strategies: []string{"primary_ip"}} + c, err := New(cfg) + if err != nil { + t.Fatalf("New: %v", err) + } + if len(c.strategies) != 1 { + t.Errorf("got %d strategies, want 1", len(c.strategies)) + } + if c.strategies[0].Name() != "primary_ip" { + t.Errorf("strategy name: got %q, want %q", c.strategies[0].Name(), "primary_ip") + } +} + +func TestNew_ManagementSubnet(t *testing.T) { + cfg := config.ResolverConfig{ + Strategies: []string{"management_subnet"}, + ManagementSubnets: []string{"10.0.0.0/8"}, + } + c, err := New(cfg) + if err != nil { + t.Fatalf("New: %v", err) + } + if c.strategies[0].Name() != "management_subnet" { + t.Errorf("strategy name: got %q, want %q", c.strategies[0].Name(), "management_subnet") + } +} + +func TestNew_ManagementSubnet_InvalidCIDR(t *testing.T) { + cfg := config.ResolverConfig{ + Strategies: []string{"management_subnet"}, + ManagementSubnets: []string{"not-a-cidr"}, + } + _, err := New(cfg) + if err == nil { + t.Error("invalid CIDR should return an error") + } +} + +func TestNew_InterfaceName(t *testing.T) { + cfg := config.ResolverConfig{ + Strategies: []string{"interface_name"}, + InterfaceName: "mgmt0", + } + c, err := New(cfg) + if err != nil { + t.Fatalf("New: %v", err) + } + if c.strategies[0].Name() != "interface_name" { + t.Errorf("strategy name: got %q", c.strategies[0].Name()) + } +} + +func TestNew_InterfaceName_MissingConfig(t *testing.T) { + cfg := config.ResolverConfig{ + Strategies: []string{"interface_name"}, + InterfaceName: "", // not set + } + _, err := New(cfg) + if err == nil { + t.Error("interface_name without config should return an error") + } +} + +func TestNew_UnknownStrategy(t *testing.T) { + cfg := config.ResolverConfig{Strategies: []string{"nonexistent"}} + _, err := New(cfg) + if err == nil { + t.Error("unknown strategy should return an error") + } +} + +func TestNew_MultipleStrategies(t *testing.T) { + cfg := config.ResolverConfig{ + Strategies: []string{"management_subnet", "primary_ip"}, + ManagementSubnets: []string{"10.0.0.0/8"}, + } + c, err := New(cfg) + if err != nil { + t.Fatalf("New: %v", err) + } + if len(c.strategies) != 2 { + t.Errorf("got %d strategies, want 2", len(c.strategies)) + } +} diff --git a/internal/resolver/interface_name.go b/internal/resolver/interface_name.go new file mode 100644 index 0000000..37ed9a2 --- /dev/null +++ b/internal/resolver/interface_name.go @@ -0,0 +1,38 @@ +package resolver + +import ( + "context" + "fmt" + "net/url" + + "git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/netbox" +) + +// InterfaceNameStrategy finds the first IP assigned to a named interface (e.g. "mgmt0", "eth0"). +type InterfaceNameStrategy struct { + name string +} + +func (s *InterfaceNameStrategy) Name() string { return "interface_name" } + +func (s *InterfaceNameStrategy) Resolve(ctx context.Context, entry *netbox.HostEntry, client *netbox.Client) (string, error) { + // Build filter parameters for IP addresses attached to the named interface. + var filterParam string + switch entry.Kind { + case "device": + filterParam = fmt.Sprintf("device_id=%d&interface_name=%s", entry.ID, url.QueryEscape(s.name)) + case "vm": + filterParam = fmt.Sprintf("virtual_machine_id=%d&vminterface_name=%s", entry.ID, url.QueryEscape(s.name)) + default: + return "", fmt.Errorf("unknown kind %q", entry.Kind) + } + + ips, err := client.GetIPsWithFilter(ctx, filterParam) + if err != nil { + return "", fmt.Errorf("fetching IPs for interface %q: %w", s.name, err) + } + if len(ips) == 0 { + return "", ErrNoIP + } + return ips[0], nil +} diff --git a/internal/resolver/management.go b/internal/resolver/management.go new file mode 100644 index 0000000..63ff261 --- /dev/null +++ b/internal/resolver/management.go @@ -0,0 +1,53 @@ +package resolver + +import ( + "context" + "fmt" + "net" + + "git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/netbox" +) + +// ManagementSubnetStrategy finds the first IP of a host that falls within +// one of the configured management subnets. +type ManagementSubnetStrategy struct { + subnets []*net.IPNet +} + +func NewManagementSubnetStrategy(cidrs []string) (*ManagementSubnetStrategy, error) { + nets := make([]*net.IPNet, 0, len(cidrs)) + for _, cidr := range cidrs { + _, ipNet, err := net.ParseCIDR(cidr) + if err != nil { + return nil, fmt.Errorf("invalid CIDR %q: %w", cidr, err) + } + nets = append(nets, ipNet) + } + return &ManagementSubnetStrategy{subnets: nets}, nil +} + +func (s *ManagementSubnetStrategy) Name() string { return "management_subnet" } + +func (s *ManagementSubnetStrategy) Resolve(ctx context.Context, entry *netbox.HostEntry, client *netbox.Client) (string, error) { + if len(s.subnets) == 0 { + return "", ErrNoIP + } + + ips, err := client.GetIPs(ctx, *entry) + if err != nil { + return "", fmt.Errorf("fetching IPs: %w", err) + } + + for _, rawIP := range ips { + ip := net.ParseIP(rawIP) + if ip == nil { + continue + } + for _, subnet := range s.subnets { + if subnet.Contains(ip) { + return rawIP, nil + } + } + } + return "", ErrNoIP +} diff --git a/internal/resolver/management_test.go b/internal/resolver/management_test.go new file mode 100644 index 0000000..56f9d00 --- /dev/null +++ b/internal/resolver/management_test.go @@ -0,0 +1,109 @@ +package resolver + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/netbox" +) + +// newIPServer returns a test server that always responds with the given IP list. +func newIPServer(t *testing.T, ips []string) *httptest.Server { + t.Helper() + type result struct { + Address string `json:"address"` + } + type response struct { + Count int `json:"count"` + Results []result `json:"results"` + } + resp := response{Count: len(ips)} + for _, ip := range ips { + resp.Results = append(resp.Results, result{Address: ip}) + } + body, _ := json.Marshal(resp) + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write(body) + })) +} + +func TestManagementSubnetStrategy_MatchesSubnet(t *testing.T) { + srv := newIPServer(t, []string{"10.0.1.5/24", "192.168.0.1/24"}) + defer srv.Close() + + s, _ := NewManagementSubnetStrategy([]string{"10.0.0.0/8"}) + client := netbox.NewClient(srv.URL, "token") + + ip, err := s.Resolve(context.Background(), &netbox.HostEntry{ID: 1, Kind: "device"}, client) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ip != "10.0.1.5" { + t.Errorf("got %q, want %q", ip, "10.0.1.5") + } +} + +func TestManagementSubnetStrategy_NoMatch(t *testing.T) { + srv := newIPServer(t, []string{"192.168.0.1/24"}) + defer srv.Close() + + s, _ := NewManagementSubnetStrategy([]string{"10.0.0.0/8"}) + client := netbox.NewClient(srv.URL, "token") + + _, err := s.Resolve(context.Background(), &netbox.HostEntry{ID: 1, Kind: "device"}, client) + if err != ErrNoIP { + t.Errorf("no matching subnet should return ErrNoIP, got %v", err) + } +} + +func TestManagementSubnetStrategy_FirstMatchWins(t *testing.T) { + srv := newIPServer(t, []string{"10.0.1.1/24", "10.0.1.2/24"}) + defer srv.Close() + + s, _ := NewManagementSubnetStrategy([]string{"10.0.0.0/8"}) + client := netbox.NewClient(srv.URL, "token") + + ip, err := s.Resolve(context.Background(), &netbox.HostEntry{ID: 1, Kind: "device"}, client) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ip != "10.0.1.1" { + t.Errorf("got %q, want first matching IP %q", ip, "10.0.1.1") + } +} + +func TestManagementSubnetStrategy_VMKind(t *testing.T) { + srv := newIPServer(t, []string{"172.16.5.10/16"}) + defer srv.Close() + + s, _ := NewManagementSubnetStrategy([]string{"172.16.0.0/12"}) + client := netbox.NewClient(srv.URL, "token") + + ip, err := s.Resolve(context.Background(), &netbox.HostEntry{ID: 2, Kind: "vm"}, client) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ip != "172.16.5.10" { + t.Errorf("got %q, want %q", ip, "172.16.5.10") + } +} + +func TestManagementSubnetStrategy_IPv6Subnet(t *testing.T) { + srv := newIPServer(t, []string{"fd00::1/64"}) + defer srv.Close() + + s, _ := NewManagementSubnetStrategy([]string{"fd00::/8"}) + client := netbox.NewClient(srv.URL, "token") + + ip, err := s.Resolve(context.Background(), &netbox.HostEntry{ID: 1, Kind: "device"}, client) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ip != "fd00::1" { + t.Errorf("got %q, want %q", ip, "fd00::1") + } +} diff --git a/internal/resolver/primary_ip.go b/internal/resolver/primary_ip.go new file mode 100644 index 0000000..637041d --- /dev/null +++ b/internal/resolver/primary_ip.go @@ -0,0 +1,23 @@ +package resolver + +import ( + "context" + + "git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/netbox" +) + +// PrimaryIPStrategy returns the primary IP configured in NetBox. +// Prefers IPv4, falls back to IPv6. +type PrimaryIPStrategy struct{} + +func (s *PrimaryIPStrategy) Name() string { return "primary_ip" } + +func (s *PrimaryIPStrategy) Resolve(_ context.Context, entry *netbox.HostEntry, _ *netbox.Client) (string, error) { + if entry.PrimaryIP4 != "" { + return entry.PrimaryIP4, nil + } + if entry.PrimaryIP6 != "" { + return entry.PrimaryIP6, nil + } + return "", ErrNoIP +} diff --git a/internal/resolver/primary_ip_test.go b/internal/resolver/primary_ip_test.go new file mode 100644 index 0000000..4fa01e1 --- /dev/null +++ b/internal/resolver/primary_ip_test.go @@ -0,0 +1,91 @@ +package resolver + +import ( + "context" + "testing" + + "git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/netbox" +) + +func TestPrimaryIPStrategy_Name(t *testing.T) { + s := &PrimaryIPStrategy{} + if s.Name() != "primary_ip" { + t.Errorf("Name: got %q, want %q", s.Name(), "primary_ip") + } +} + +func TestPrimaryIPStrategy_IPv4(t *testing.T) { + s := &PrimaryIPStrategy{} + e := &netbox.HostEntry{PrimaryIP4: "10.0.0.1", PrimaryIP6: "::1"} + ip, err := s.Resolve(context.Background(), e, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ip != "10.0.0.1" { + t.Errorf("got %q, want IPv4 %q", ip, "10.0.0.1") + } +} + +func TestPrimaryIPStrategy_IPv6Fallback(t *testing.T) { + s := &PrimaryIPStrategy{} + e := &netbox.HostEntry{PrimaryIP6: "::1"} + ip, err := s.Resolve(context.Background(), e, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ip != "::1" { + t.Errorf("got %q, want IPv6 %q", ip, "::1") + } +} + +func TestPrimaryIPStrategy_NoIP(t *testing.T) { + s := &PrimaryIPStrategy{} + _, err := s.Resolve(context.Background(), &netbox.HostEntry{}, nil) + if err != ErrNoIP { + t.Errorf("got %v, want ErrNoIP", err) + } +} + +func TestManagementSubnetStrategy_Name(t *testing.T) { + s, _ := NewManagementSubnetStrategy([]string{"10.0.0.0/8"}) + if s.Name() != "management_subnet" { + t.Errorf("Name: got %q, want %q", s.Name(), "management_subnet") + } +} + +func TestManagementSubnetStrategy_InvalidCIDR(t *testing.T) { + _, err := NewManagementSubnetStrategy([]string{"not-a-cidr"}) + if err == nil { + t.Error("invalid CIDR should return an error") + } +} + +func TestManagementSubnetStrategy_EmptyCIDRs(t *testing.T) { + s, _ := NewManagementSubnetStrategy([]string{}) + _, err := s.Resolve(context.Background(), &netbox.HostEntry{}, nil) + if err != ErrNoIP { + t.Errorf("empty subnets should return ErrNoIP, got %v", err) + } +} + +func TestManagementSubnetStrategy_MultipleCIDRs(t *testing.T) { + _, err := NewManagementSubnetStrategy([]string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16"}) + if err != nil { + t.Fatalf("valid CIDRs should not error: %v", err) + } +} + +func TestInterfaceNameStrategy_Name(t *testing.T) { + s := &InterfaceNameStrategy{name: "mgmt0"} + if s.Name() != "interface_name" { + t.Errorf("Name: got %q, want %q", s.Name(), "interface_name") + } +} + +func TestInterfaceNameStrategy_UnknownKind(t *testing.T) { + s := &InterfaceNameStrategy{name: "eth0"} + _, err := s.Resolve(context.Background(), &netbox.HostEntry{Kind: "unknown"}, nil) + if err == nil { + t.Error("unknown kind should return an error") + } +} diff --git a/internal/resolver/strategy.go b/internal/resolver/strategy.go new file mode 100644 index 0000000..3559eb5 --- /dev/null +++ b/internal/resolver/strategy.go @@ -0,0 +1,17 @@ +package resolver + +import ( + "context" + "errors" + + "git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/netbox" +) + +// ErrNoIP is returned when a strategy cannot find a matching IP address. +var ErrNoIP = errors.New("no matching IP found") + +// Strategy is a single rule for resolving an IP address from a NetBox host entry. +type Strategy interface { + Name() string + Resolve(ctx context.Context, entry *netbox.HostEntry, client *netbox.Client) (string, error) +} diff --git a/internal/ssh/args.go b/internal/ssh/args.go new file mode 100644 index 0000000..d5e834a --- /dev/null +++ b/internal/ssh/args.go @@ -0,0 +1,109 @@ +package ssh + +import "strings" + +// flagsWithArg lists all SSH flags that consume the following argument. +var flagsWithArg = map[byte]bool{ + 'b': true, 'c': true, 'D': true, 'E': true, 'e': true, + 'F': true, 'I': true, 'i': true, 'J': true, 'L': true, + 'l': true, 'm': true, 'o': true, 'O': true, 'p': true, + 'Q': true, 'R': true, 'S': true, 'w': true, 'W': true, +} + +// ParsedArgs holds the result of parsing SSH arguments. +type ParsedArgs struct { + Host string // hostname without the user@ prefix + User string // empty if not specified + DestIdx int // index in Args where [user@]host sits + Args []string +} + +// Parse scans SSH arguments and extracts the destination ([user@]host). +// Returns nil if no destination is found. +func Parse(args []string) *ParsedArgs { + i := 0 + for i < len(args) { + arg := args[i] + + // "--" ends option processing + if arg == "--" { + i++ + break + } + + if strings.HasPrefix(arg, "-") && len(arg) > 1 { + flag := arg[1] + if flagsWithArg[flag] { + if len(arg) > 2 { + // argument is attached, e.g. -p2222 + i++ + } else { + // argument is the next element, e.g. -p 2222 + i += 2 + } + } else { + i++ + } + continue + } + + // first non-flag argument is the destination + host, user := splitUserHost(arg) + return &ParsedArgs{ + Host: host, + User: user, + DestIdx: i, + Args: args, + } + } + + // handle arguments after "--" + if i < len(args) { + host, user := splitUserHost(args[i]) + return &ParsedArgs{ + Host: host, + User: user, + DestIdx: i, + Args: args, + } + } + + return nil +} + +// ReplaceHost returns a copy of args with the destination replaced by newHost, +// preserving any user@ prefix. +func ReplaceHost(args []string, destIdx int, newHost string) []string { + result := make([]string, len(args)) + copy(result, args) + + original := args[destIdx] + if at := strings.Index(original, "@"); at != -1 { + result[destIdx] = original[:at+1] + newHost + } else { + result[destIdx] = newHost + } + return result +} + +// HasUserFlag reports whether a user was specified via -l in args. +// Used to avoid overriding an explicit -l with the configured default user. +func HasUserFlag(args []string) bool { + for i, a := range args { + if a == "-l" && i+1 < len(args) { + return true + } + // handle attached form: -lroot + if len(a) > 2 && a[0] == '-' && a[1] == 'l' { + return true + } + } + return false +} + +func splitUserHost(dest string) (host, user string) { + if at := strings.Index(dest, "@"); at != -1 { + return dest[at+1:], dest[:at] + } + return dest, "" +} diff --git a/internal/ssh/args_test.go b/internal/ssh/args_test.go new file mode 100644 index 0000000..4dd8611 --- /dev/null +++ b/internal/ssh/args_test.go @@ -0,0 +1,161 @@ +package ssh + +import ( + "testing" +) + +func TestParse_BareHostname(t *testing.T) { + got := Parse([]string{"myhost"}) + assertParsed(t, got, "myhost", "", 0) +} + +func TestParse_UserAtHost(t *testing.T) { + got := Parse([]string{"admin@myhost"}) + assertParsed(t, got, "myhost", "admin", 0) +} + +func TestParse_PortFlag_Separated(t *testing.T) { + got := Parse([]string{"-p", "2222", "myhost"}) + assertParsed(t, got, "myhost", "", 2) +} + +func TestParse_PortFlag_Attached(t *testing.T) { + got := Parse([]string{"-p2222", "myhost"}) + assertParsed(t, got, "myhost", "", 1) +} + +func TestParse_IdentityFlag(t *testing.T) { + got := Parse([]string{"-i", "/path/to/key", "user@myhost", "ls"}) + assertParsed(t, got, "myhost", "user", 2) +} + +func TestParse_VerboseFlag(t *testing.T) { + got := Parse([]string{"-v", "myhost"}) + assertParsed(t, got, "myhost", "", 1) +} + +func TestParse_OptionFlag(t *testing.T) { + got := Parse([]string{"-o", "StrictHostKeyChecking=no", "myhost"}) + assertParsed(t, got, "myhost", "", 2) +} + +func TestParse_JumpHost(t *testing.T) { + got := Parse([]string{"-J", "jumphost", "-p", "22", "target"}) + assertParsed(t, got, "target", "", 4) +} + +func TestParse_MultipleFlags(t *testing.T) { + got := Parse([]string{"-v", "-p", "22", "-i", "key", "root@host", "uptime"}) + assertParsed(t, got, "host", "root", 5) +} + +func TestParse_DoubleDash(t *testing.T) { + got := Parse([]string{"--", "myhost"}) + assertParsed(t, got, "myhost", "", 1) +} + +func TestParse_DoubleDash_WithFlags(t *testing.T) { + // flags after -- should be treated as destination + got := Parse([]string{"-v", "--", "-not-a-flag"}) + assertParsed(t, got, "-not-a-flag", "", 2) +} + +func TestParse_NoDestination(t *testing.T) { + got := Parse([]string{"-v", "-p", "2222"}) + if got != nil { + t.Errorf("expected nil for args without destination, got %+v", got) + } +} + +func TestParse_EmptyArgs(t *testing.T) { + got := Parse([]string{}) + if got != nil { + t.Error("empty args should return nil") + } +} + +func TestParse_OnlyDoubleDash(t *testing.T) { + got := Parse([]string{"--"}) + if got != nil { + t.Error("only -- with no destination should return nil") + } +} + +func TestReplaceHost_PlainHost(t *testing.T) { + args := []string{"myhost"} + result := ReplaceHost(args, 0, "10.0.0.1") + if result[0] != "10.0.0.1" { + t.Errorf("got %q, want %q", result[0], "10.0.0.1") + } +} + +func TestReplaceHost_PreservesUserPrefix(t *testing.T) { + args := []string{"-p", "22", "admin@myhost", "ls"} + result := ReplaceHost(args, 2, "10.0.0.1") + if result[2] != "admin@10.0.0.1" { + t.Errorf("got %q, want %q", result[2], "admin@10.0.0.1") + } +} + +func TestReplaceHost_DoesNotMutateOriginal(t *testing.T) { + args := []string{"myhost"} + _ = ReplaceHost(args, 0, "10.0.0.1") + if args[0] != "myhost" { + t.Error("ReplaceHost must not mutate the original slice") + } +} + +func TestReplaceHost_OtherArgsUnchanged(t *testing.T) { + args := []string{"-p", "22", "myhost"} + result := ReplaceHost(args, 2, "10.0.0.1") + if result[0] != "-p" || result[1] != "22" { + t.Errorf("other args should be unchanged: %v", result) + } +} + +func TestHasUserFlag_FlagSeparated(t *testing.T) { + if !HasUserFlag([]string{"-l", "admin", "host"}) { + t.Error("should detect -l ") + } +} + +func TestHasUserFlag_FlagAttached(t *testing.T) { + if !HasUserFlag([]string{"-ladmin", "host"}) { + t.Error("should detect -l (attached form)") + } +} + +func TestHasUserFlag_NotPresent(t *testing.T) { + if HasUserFlag([]string{"-p", "22", "host"}) { + t.Error("should not detect user flag when absent") + } +} + +func TestHasUserFlag_EmptyArgs(t *testing.T) { + if HasUserFlag([]string{}) { + t.Error("empty args should return false") + } +} + +func TestHasUserFlag_LFlagAtEnd(t *testing.T) { + // -l at the very end with no value — should not panic + if HasUserFlag([]string{"-l"}) { + t.Error("-l with no value should return false") + } +} + +func assertParsed(t *testing.T, got *ParsedArgs, host, user string, destIdx int) { + t.Helper() + if got == nil { + t.Fatal("Parse returned nil") + } + if got.Host != host { + t.Errorf("host: got %q, want %q", got.Host, host) + } + if got.User != user { + t.Errorf("user: got %q, want %q", got.User, user) + } + if got.DestIdx != destIdx { + t.Errorf("destIdx: got %d, want %d", got.DestIdx, destIdx) + } +} diff --git a/internal/ssh/exec.go b/internal/ssh/exec.go new file mode 100644 index 0000000..fd55bb5 --- /dev/null +++ b/internal/ssh/exec.go @@ -0,0 +1,19 @@ +package ssh + +import ( + "fmt" + "os" + "os/exec" + "syscall" +) + +// Exec replaces the current process with the native ssh client via syscall.Exec. +// All existing SSH configs, keys, and agent forwarding remain intact. +func Exec(args []string) error { + sshPath, err := exec.LookPath("ssh") + if err != nil { + return fmt.Errorf("ssh not found in PATH: %w", err) + } + argv := append([]string{"ssh"}, args...) + return syscall.Exec(sshPath, argv, os.Environ()) +} diff --git a/internal/tui/model.go b/internal/tui/model.go new file mode 100644 index 0000000..9740c7e --- /dev/null +++ b/internal/tui/model.go @@ -0,0 +1,237 @@ +package tui + +import ( + "context" + "fmt" + "io" + "strings" + "time" + + "github.com/charmbracelet/bubbles/list" + "github.com/charmbracelet/bubbles/textinput" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + + "git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/cache" + "git.zb-server.de/Sebi/ssh-netbox-wrapper/internal/netbox" +) + +// SelectedHost is returned when the user confirms a host in the TUI. +type SelectedHost struct { + Name string + IP string +} + +// --- bubbletea messages --- + +type debounceMsg struct{ query string } + +type searchResultMsg struct { + query string + entries []netbox.HostEntry + err error +} + +// --- list item --- + +type hostItem struct { + name string + ip string + kind string +} + +func (h hostItem) Title() string { return h.name } +func (h hostItem) Description() string { return fmt.Sprintf("%s [%s]", h.ip, h.kind) } +func (h hostItem) FilterValue() string { return h.name } + +// --- compact list delegate --- + +type compactDelegate struct{} + +func (d compactDelegate) Height() int { return 1 } +func (d compactDelegate) Spacing() int { return 0 } +func (d compactDelegate) Update(_ tea.Msg, _ *list.Model) tea.Cmd { return nil } +func (d compactDelegate) Render(w io.Writer, m list.Model, index int, item list.Item) { + h, ok := item.(hostItem) + if !ok { + return + } + line := fmt.Sprintf(" %s %s", h.name, lipgloss.NewStyle().Foreground(lipgloss.Color("240")).Render(h.ip)) + if index == m.Index() { + line = lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("86")).Render("> " + strings.TrimPrefix(line, " ")) + } + fmt.Fprintln(w, line) +} + +// --- Model --- + +type Model struct { + input textinput.Model + list list.Model + client *netbox.Client + cache *cache.Cache + lastSent string // last query sent to NetBox (or served from cache) + seq int // sequence number to discard stale results + loading bool + err error + selected *SelectedHost + width int + height int +} + +func New(client *netbox.Client, c *cache.Cache) *Model { + ti := textinput.New() + ti.Placeholder = "Search hostname…" + ti.Focus() + + l := list.New(nil, compactDelegate{}, 0, 0) + l.SetShowHelp(false) + l.SetShowTitle(false) + l.SetShowStatusBar(false) + l.SetFilteringEnabled(false) + + return &Model{ + input: ti, + list: l, + client: client, + cache: c, + } +} + +func (m *Model) Init() tea.Cmd { + return textinput.Blink +} + +func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + + case tea.WindowSizeMsg: + m.width = msg.Width + m.height = msg.Height + m.list.SetSize(msg.Width, msg.Height-4) + return m, nil + + case tea.KeyMsg: + switch msg.String() { + case "ctrl+c", "esc": + return m, tea.Quit + + case "enter": + if item, ok := m.list.SelectedItem().(hostItem); ok { + m.selected = &SelectedHost{Name: item.name, IP: item.ip} + return m, tea.Quit + } + + case "tab": + // Copy the top result into the search field. + if m.list.Items() != nil && len(m.list.Items()) > 0 { + if item, ok := m.list.Items()[0].(hostItem); ok { + m.input.SetValue(item.name) + m.input.CursorEnd() + } + } + return m, nil + } + + case debounceMsg: + // Only query if the input has changed since the last request. + q := m.input.Value() + if q == m.lastSent { + return m, nil + } + m.lastSent = q + m.loading = true + m.seq++ + seq := m.seq + return m, m.doSearch(q, seq) + + case searchResultMsg: + if msg.query != m.lastSent { + return m, nil // discard stale result + } + m.loading = false + if msg.err != nil { + m.err = msg.err + return m, nil + } + items := make([]list.Item, len(msg.entries)) + for i, e := range msg.entries { + ip := e.PrimaryIP4 + if ip == "" { + ip = e.PrimaryIP6 + } + items[i] = hostItem{name: e.Name, ip: ip, kind: e.Kind} + } + m.list.SetItems(items) + m.err = nil + return m, nil + } + + // Forward to text input and restart the debounce timer. + var cmds []tea.Cmd + var inputCmd tea.Cmd + m.input, inputCmd = m.input.Update(msg) + cmds = append(cmds, inputCmd) + cmds = append(cmds, m.startDebounce()) + + var listCmd tea.Cmd + m.list, listCmd = m.list.Update(msg) + cmds = append(cmds, listCmd) + + return m, tea.Batch(cmds...) +} + +func (m *Model) View() string { + var sb strings.Builder + + title := lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("86")).Render("netssh") + sb.WriteString(title + "\n\n") + sb.WriteString(m.input.View() + "\n") + + if m.loading { + sb.WriteString(lipgloss.NewStyle().Foreground(lipgloss.Color("240")).Render(" searching…") + "\n") + } else if m.err != nil { + sb.WriteString(lipgloss.NewStyle().Foreground(lipgloss.Color("9")).Render(" error: "+m.err.Error()) + "\n") + } else { + sb.WriteString(m.list.View()) + } + + return sb.String() +} + +// Selected returns the host chosen by the user, or nil if none was selected. +func (m *Model) Selected() *SelectedHost { + return m.selected +} + +func (m *Model) startDebounce() tea.Cmd { + return tea.Tick(300*time.Millisecond, func(_ time.Time) tea.Msg { + return debounceMsg{query: m.input.Value()} + }) +} + +func (m *Model) doSearch(query string, seq int) tea.Cmd { + return func() tea.Msg { + // Return cache hits immediately without a network round-trip. + if m.cache != nil { + if cached := m.cache.Search(query); len(cached) > 0 { + entries := make([]netbox.HostEntry, len(cached)) + for i, c := range cached { + entries[i] = netbox.HostEntry{Name: c.Name, PrimaryIP4: c.IP, Kind: c.Kind} + } + return searchResultMsg{query: query, entries: entries} + } + } + + if m.client == nil { + return searchResultMsg{query: query, entries: nil} + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + entries, err := m.client.Search(ctx, query) + _ = seq + return searchResultMsg{query: query, entries: entries, err: err} + } +}