Compare commits

..

9 Commits

Author SHA1 Message Date
Marc Schäfer
0f6852b681 Disable Build binaries step in cicd.yml
Comment out the Build binaries step in the CI/CD workflow.
2026-02-22 22:17:27 +01:00
Marc Schäfer
2b8e280f2e Update .goreleaser.yaml for release configuration 2026-02-22 22:16:31 +01:00
Marc Schäfer
3a377d43de Add .goreleaser.yaml for project configuration 2026-02-22 22:12:48 +01:00
Marc Schäfer
792057cf6c Merge pull request #30 from marcschaeferger/repo
Repo
2026-02-22 22:03:57 +01:00
Marc Schäfer
57afe91e85 Create nfpm.yaml.tmpl for Newt packaging
Added nfpm.yaml template for packaging configuration.
2026-02-22 22:02:04 +01:00
Marc Schäfer
3389088c43 Add script to publish APT packages to S3
This script publishes APT packages to an S3 bucket, handling GPG signing and CloudFront invalidation.
2026-02-22 22:01:24 +01:00
Marc Schäfer
e73150c187 Update APT publishing workflow configuration
Refactor APT publishing workflow with improved variable handling and script execution.
2026-02-22 22:00:46 +01:00
Marc Schäfer
18556f34b2 Refactor package build process in publish-apt.yml
Refactor nfpm.yaml generation to use Python script and update package naming conventions.
2026-02-22 21:58:56 +01:00
Marc Schäfer
66c235624a Add workflow to publish APT repo to S3/CloudFront
This workflow automates the process of publishing an APT repository to S3/CloudFront upon release events. It includes steps for configuring AWS credentials, installing necessary tools, processing tags, building packages, and uploading the repository.
2026-02-22 21:56:11 +01:00
36 changed files with 1156 additions and 4056 deletions

File diff suppressed because it is too large Load Diff

62
.github/workflows/publish-apt.yml vendored Normal file
View File

@@ -0,0 +1,62 @@
name: Publish APT repo to S3/CloudFront
on:
release:
types: [published]
workflow_dispatch:
inputs:
tag:
description: "Tag to publish (e.g. v1.9.0). Leave empty to use latest release."
required: false
type: string
backfill_all:
description: "Build/publish repo for ALL releases."
required: false
default: false
type: boolean
permissions:
id-token: write
contents: read
jobs:
publish:
runs-on: ubuntu-latest
env:
PKG_NAME: newt
SUITE: stable
COMPONENT: main
REPO_BASE_URL: https://repo.dev.fosrl.io/apt
AWS_REGION: ${{ vars.AWS_REGION }}
S3_BUCKET: ${{ vars.S3_BUCKET }}
S3_PREFIX: ${{ vars.S3_PREFIX }}
CLOUDFRONT_DISTRIBUTION_ID: ${{ vars.CLOUDFRONT_DISTRIBUTION_ID }}
INPUT_TAG: ${{ inputs.tag }}
BACKFILL_ALL: ${{ inputs.backfill_all }}
EVENT_TAG: ${{ github.event.release.tag_name }}
GH_REPO: ${{ github.repository }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Configure AWS credentials (OIDC)
uses: aws-actions/configure-aws-credentials@v4
with:
role-to-assume: ${{ secrets.AWS_ROLE_ARN }}
aws-region: ${{ vars.AWS_REGION }}
- name: Install dependencies
run: sudo apt-get update && sudo apt-get install -y dpkg-dev apt-utils gnupg curl jq gh
- name: Install nfpm
run: curl -fsSL https://github.com/goreleaser/nfpm/releases/latest/download/nfpm_Linux_x86_64.tar.gz | sudo tar -xz -C /usr/local/bin nfpm
- name: Publish APT repo
env:
GH_TOKEN: ${{ github.token }}
APT_GPG_PRIVATE_KEY: ${{ secrets.APT_GPG_PRIVATE_KEY }}
APT_GPG_PASSPHRASE: ${{ secrets.APT_GPG_PASSPHRASE }}
run: ./scripts/publish-apt.sh

View File

@@ -31,7 +31,7 @@ jobs:
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
- name: Set up Go
uses: actions/setup-go@7a3fe6cf4cb3a834922a1244abfce67bcef6a0c5 # v6.2.0
uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # v6.1.0
with:
go-version: 1.25

62
.goreleaser.yaml Normal file
View File

@@ -0,0 +1,62 @@
project_name: newt
release:
# du nutzt Tags wie 1.2.3 und 1.2.3-rc.1 (ohne v)
draft: true
prerelease: auto
name_template: "{{ .Tag }}"
mode: replace
builds:
- id: newt
main: ./main.go # <- falls du cmd/newt hast: ./cmd/newt
binary: newt
env:
- CGO_ENABLED=0
goos:
- linux
- darwin
- windows
- freebsd
goarch:
- amd64
- arm64
goarm:
- "6"
- "7"
flags:
- -trimpath
ldflags:
- -s -w -X main.version={{ .Tag }}
archives:
# Wichtig: format "binary" -> keine tar.gz, sondern raw binary wie bei dir aktuell
- id: raw
builds:
- newt
format: binary
name_template: >-
{{ .ProjectName }}_{{ .Os }}_{{ if eq .Arch "amd64" }}amd64{{ else if eq .Arch "arm64" }}arm64{{ else if eq .Arch "386" }}386{{ else }}{{ .Arch }}{{ end }}{{ if .Arm }}v{{ .Arm }}{{ end }}{{ if .Mips }}_{{ .Mips }}{{ end }}{{ if .Amd64 }}_{{ .Amd64 }}{{ end }}{{ if .Riscv64 }}_{{ .Riscv64 }}{{ end }}{{ if .Os | eq "windows" }}.exe{{ end }}
checksum:
name_template: "checksums.txt"
nfpms:
- id: packages
package_name: newt
builds:
- newt
vendor: fosrl
maintainer: fosrl <repo@fosrl.io>
description: Newt - userspace tunnel client and TCP/UDP proxy
license: AGPL-3.0
formats:
- deb
- rpm
- apk
bindir: /usr/bin
# sorgt dafür, dass die Paketnamen gut pattern-matchbar sind
file_name_template: "{{ .PackageName }}_{{ .Version }}_{{ .Arch }}"
contents:
- src: LICENSE
dst: /usr/share/doc/newt/LICENSE

View File

@@ -1,5 +1,4 @@
# FROM golang:1.25-alpine AS builder
FROM public.ecr.aws/docker/library/golang:1.25-alpine AS builder
FROM golang:1.25-alpine AS builder
# Install git and ca-certificates
RUN apk --no-cache add ca-certificates git tzdata
@@ -17,10 +16,9 @@ RUN go mod download
COPY . .
# Build the application
ARG VERSION=dev
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w -X main.newtVersion=${VERSION}" -o /newt
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -o /newt
FROM public.ecr.aws/docker/library/alpine:3.23 AS runner
FROM alpine:3.23 AS runner
RUN apk --no-cache add ca-certificates tzdata iputils

View File

@@ -2,9 +2,6 @@
all: local
VERSION ?= dev
LDFLAGS = -X main.newtVersion=$(VERSION)
local:
CGO_ENABLED=0 go build -o ./bin/newt
@@ -43,31 +40,31 @@ go-build-release: \
go-build-release-freebsd-arm64
go-build-release-linux-arm64:
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -ldflags "$(LDFLAGS)" -o bin/newt_linux_arm64
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -o bin/newt_linux_arm64
go-build-release-linux-arm32-v7:
CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=7 go build -ldflags "$(LDFLAGS)" -o bin/newt_linux_arm32
CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=7 go build -o bin/newt_linux_arm32
go-build-release-linux-arm32-v6:
CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=6 go build -ldflags "$(LDFLAGS)" -o bin/newt_linux_arm32v6
CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=6 go build -o bin/newt_linux_arm32v6
go-build-release-linux-amd64:
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags "$(LDFLAGS)" -o bin/newt_linux_amd64
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/newt_linux_amd64
go-build-release-linux-riscv64:
CGO_ENABLED=0 GOOS=linux GOARCH=riscv64 go build -ldflags "$(LDFLAGS)" -o bin/newt_linux_riscv64
CGO_ENABLED=0 GOOS=linux GOARCH=riscv64 go build -o bin/newt_linux_riscv64
go-build-release-darwin-arm64:
CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -ldflags "$(LDFLAGS)" -o bin/newt_darwin_arm64
CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -o bin/newt_darwin_arm64
go-build-release-darwin-amd64:
CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -ldflags "$(LDFLAGS)" -o bin/newt_darwin_amd64
CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -o bin/newt_darwin_amd64
go-build-release-windows-amd64:
CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -ldflags "$(LDFLAGS)" -o bin/newt_windows_amd64.exe
CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -o bin/newt_windows_amd64.exe
go-build-release-freebsd-amd64:
CGO_ENABLED=0 GOOS=freebsd GOARCH=amd64 go build -ldflags "$(LDFLAGS)" -o bin/newt_freebsd_amd64
CGO_ENABLED=0 GOOS=freebsd GOARCH=amd64 go build -o bin/newt_freebsd_amd64
go-build-release-freebsd-arm64:
CGO_ENABLED=0 GOOS=freebsd GOARCH=arm64 go build -ldflags "$(LDFLAGS)" -o bin/newt_freebsd_arm64
CGO_ENABLED=0 GOOS=freebsd GOARCH=arm64 go build -o bin/newt_freebsd_arm64

View File

@@ -51,7 +51,6 @@ func startAuthDaemon(ctx context.Context) error {
PrincipalsFilePath: principalsFile,
CACertPath: caCertPath,
Force: true,
GenerateRandomPassword: authDaemonGenerateRandomPassword,
}
srv, err := authdaemon.NewServer(cfg)
@@ -73,6 +72,8 @@ func startAuthDaemon(ctx context.Context) error {
return nil
}
// runPrincipalsCmd executes the principals subcommand logic
func runPrincipalsCmd(args []string) {
opts := struct {

View File

@@ -7,8 +7,8 @@ import (
// ProcessConnection runs the same logic as POST /connection: CA cert, user create/reconcile, principals.
// Use this when DisableHTTPS is true (e.g. embedded in Newt) instead of calling the API.
func (s *Server) ProcessConnection(req ConnectionRequest) {
logger.Info("connection: niceId=%q username=%q metadata.sudoMode=%q metadata.sudoCommands=%v metadata.homedir=%v metadata.groups=%v",
req.NiceId, req.Username, req.Metadata.SudoMode, req.Metadata.SudoCommands, req.Metadata.Homedir, req.Metadata.Groups)
logger.Info("connection: niceId=%q username=%q metadata.sudo=%v metadata.homedir=%v",
req.NiceId, req.Username, req.Metadata.Sudo, req.Metadata.Homedir)
cfg := &s.cfg
if cfg.CACertPath != "" {
@@ -16,7 +16,7 @@ func (s *Server) ProcessConnection(req ConnectionRequest) {
logger.Warn("auth-daemon: write CA cert: %v", err)
}
}
if err := ensureUser(req.Username, req.Metadata, s.cfg.GenerateRandomPassword); err != nil {
if err := ensureUser(req.Username, req.Metadata); err != nil {
logger.Warn("auth-daemon: ensure user: %v", err)
}
if cfg.PrincipalsFilePath != "" {

View File

@@ -4,8 +4,6 @@ package authdaemon
import (
"bufio"
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"os"
@@ -124,73 +122,8 @@ func sudoGroup() string {
return "sudo"
}
// setRandomPassword generates a random password and sets it for username via chpasswd.
// Used when GenerateRandomPassword is true so SSH with PermitEmptyPasswords no can accept the user.
func setRandomPassword(username string) error {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
return fmt.Errorf("generate password: %w", err)
}
password := hex.EncodeToString(b)
cmd := exec.Command("chpasswd")
cmd.Stdin = strings.NewReader(username + ":" + password)
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("chpasswd: %w (output: %s)", err, string(out))
}
return nil
}
const skelDir = "/etc/skel"
// copySkelInto copies files from srcDir (e.g. /etc/skel) into dstDir (e.g. user's home).
// Only creates files that don't already exist. All created paths are chowned to uid:gid.
func copySkelInto(srcDir, dstDir string, uid, gid int) {
entries, err := os.ReadDir(srcDir)
if err != nil {
if !os.IsNotExist(err) {
logger.Warn("auth-daemon: read %s: %v", srcDir, err)
}
return
}
for _, e := range entries {
name := e.Name()
src := filepath.Join(srcDir, name)
dst := filepath.Join(dstDir, name)
if e.IsDir() {
if st, err := os.Stat(dst); err == nil && st.IsDir() {
copySkelInto(src, dst, uid, gid)
continue
}
if err := os.MkdirAll(dst, 0755); err != nil {
logger.Warn("auth-daemon: mkdir %s: %v", dst, err)
continue
}
if err := os.Chown(dst, uid, gid); err != nil {
logger.Warn("auth-daemon: chown %s: %v", dst, err)
}
copySkelInto(src, dst, uid, gid)
continue
}
if _, err := os.Stat(dst); err == nil {
continue
}
data, err := os.ReadFile(src)
if err != nil {
logger.Warn("auth-daemon: read %s: %v", src, err)
continue
}
if err := os.WriteFile(dst, data, 0644); err != nil {
logger.Warn("auth-daemon: write %s: %v", dst, err)
continue
}
if err := os.Chown(dst, uid, gid); err != nil {
logger.Warn("auth-daemon: chown %s: %v", dst, err)
}
}
}
// ensureUser creates the system user if missing, or reconciles sudo and homedir to match meta.
func ensureUser(username string, meta ConnectionMetadata, generateRandomPassword bool) error {
func ensureUser(username string, meta ConnectionMetadata) error {
if username == "" {
return nil
}
@@ -199,49 +132,12 @@ func ensureUser(username string, meta ConnectionMetadata, generateRandomPassword
if _, ok := err.(user.UnknownUserError); !ok {
return fmt.Errorf("lookup user %s: %w", username, err)
}
return createUser(username, meta, generateRandomPassword)
return createUser(username, meta)
}
return reconcileUser(u, meta)
}
// desiredGroups returns the exact list of supplementary groups the user should have:
// meta.Groups plus the sudo group when meta.SudoMode is "full" (deduped).
func desiredGroups(meta ConnectionMetadata) []string {
seen := make(map[string]struct{})
var out []string
for _, g := range meta.Groups {
g = strings.TrimSpace(g)
if g == "" {
continue
}
if _, ok := seen[g]; ok {
continue
}
seen[g] = struct{}{}
out = append(out, g)
}
if meta.SudoMode == "full" {
sg := sudoGroup()
if _, ok := seen[sg]; !ok {
out = append(out, sg)
}
}
return out
}
// setUserGroups sets the user's supplementary groups to exactly groups (local mirrors metadata).
// When groups is empty, clears all supplementary groups (usermod -G "").
func setUserGroups(username string, groups []string) {
list := strings.Join(groups, ",")
cmd := exec.Command("usermod", "-G", list, username)
if out, err := cmd.CombinedOutput(); err != nil {
logger.Warn("auth-daemon: usermod -G %s: %v (output: %s)", list, err, string(out))
} else {
logger.Info("auth-daemon: set %s supplementary groups to %s", username, list)
}
}
func createUser(username string, meta ConnectionMetadata, generateRandomPassword bool) error {
func createUser(username string, meta ConnectionMetadata) error {
args := []string{"-s", "/bin/bash"}
if meta.Homedir {
args = append(args, "-m")
@@ -254,143 +150,75 @@ func createUser(username string, meta ConnectionMetadata, generateRandomPassword
return fmt.Errorf("useradd %s: %w (output: %s)", username, err, string(out))
}
logger.Info("auth-daemon: created user %s (homedir=%v)", username, meta.Homedir)
if generateRandomPassword {
if err := setRandomPassword(username); err != nil {
logger.Warn("auth-daemon: set random password for %s: %v", username, err)
} else {
logger.Info("auth-daemon: set random password for %s (PermitEmptyPasswords no)", username)
}
}
if meta.Homedir {
if u, err := user.Lookup(username); err == nil && u.HomeDir != "" {
uid, gid := mustAtoi(u.Uid), mustAtoi(u.Gid)
copySkelInto(skelDir, u.HomeDir, uid, gid)
}
}
setUserGroups(username, desiredGroups(meta))
switch meta.SudoMode {
case "full":
if err := configurePasswordlessSudo(username); err != nil {
logger.Warn("auth-daemon: configure passwordless sudo for %s: %v", username, err)
}
case "commands":
if len(meta.SudoCommands) > 0 {
if err := configureSudoCommands(username, meta.SudoCommands); err != nil {
logger.Warn("auth-daemon: configure sudo commands for %s: %v", username, err)
}
}
default:
removeSudoers(username)
}
return nil
}
const sudoersFilePrefix = "90-pangolin-"
func sudoersPath(username string) string {
return filepath.Join("/etc/sudoers.d", sudoersFilePrefix+username)
}
// writeSudoersFile writes content to the user's sudoers.d file and validates with visudo.
func writeSudoersFile(username, content string) error {
sudoersFile := sudoersPath(username)
tmpFile := sudoersFile + ".tmp"
if err := os.WriteFile(tmpFile, []byte(content), 0440); err != nil {
return fmt.Errorf("write temp sudoers file: %w", err)
}
cmd := exec.Command("visudo", "-c", "-f", tmpFile)
if meta.Sudo {
group := sudoGroup()
cmd := exec.Command("usermod", "-aG", group, username)
if out, err := cmd.CombinedOutput(); err != nil {
os.Remove(tmpFile)
return fmt.Errorf("visudo validation failed: %w (output: %s)", err, string(out))
logger.Warn("auth-daemon: usermod -aG %s %s: %v (output: %s)", group, username, err, string(out))
} else {
logger.Info("auth-daemon: added %s to %s", username, group)
}
if err := os.Rename(tmpFile, sudoersFile); err != nil {
os.Remove(tmpFile)
return fmt.Errorf("move sudoers file: %w", err)
}
return nil
}
// configurePasswordlessSudo creates a sudoers.d file to allow passwordless sudo for the user.
func configurePasswordlessSudo(username string) error {
content := fmt.Sprintf("# Created by Pangolin auth-daemon\n%s ALL=(ALL) NOPASSWD:ALL\n", username)
if err := writeSudoersFile(username, content); err != nil {
return err
}
logger.Info("auth-daemon: configured passwordless sudo for %s", username)
return nil
}
// configureSudoCommands creates a sudoers.d file allowing only the listed commands (NOPASSWD).
// Each command should be a full path (e.g. /usr/bin/systemctl).
func configureSudoCommands(username string, commands []string) error {
var b strings.Builder
b.WriteString("# Created by Pangolin auth-daemon (restricted commands)\n")
n := 0
for _, c := range commands {
c = strings.TrimSpace(c)
if c == "" {
continue
}
fmt.Fprintf(&b, "%s ALL=(ALL) NOPASSWD: %s\n", username, c)
n++
}
if n == 0 {
return fmt.Errorf("no valid sudo commands")
}
if err := writeSudoersFile(username, b.String()); err != nil {
return err
}
logger.Info("auth-daemon: configured restricted sudo for %s (%d commands)", username, len(commands))
return nil
}
// removeSudoers removes the sudoers.d file for the user.
func removeSudoers(username string) {
sudoersFile := sudoersPath(username)
if err := os.Remove(sudoersFile); err != nil && !os.IsNotExist(err) {
logger.Warn("auth-daemon: remove sudoers for %s: %v", username, err)
} else if err == nil {
logger.Info("auth-daemon: removed sudoers for %s", username)
}
}
func mustAtoi(s string) int {
n, _ := strconv.Atoi(s)
return n
}
func reconcileUser(u *user.User, meta ConnectionMetadata) error {
setUserGroups(u.Username, desiredGroups(meta))
switch meta.SudoMode {
case "full":
if err := configurePasswordlessSudo(u.Username); err != nil {
logger.Warn("auth-daemon: configure passwordless sudo for %s: %v", u.Username, err)
}
case "commands":
if len(meta.SudoCommands) > 0 {
if err := configureSudoCommands(u.Username, meta.SudoCommands); err != nil {
logger.Warn("auth-daemon: configure sudo commands for %s: %v", u.Username, err)
group := sudoGroup()
inGroup, err := userInGroup(u.Username, group)
if err != nil {
logger.Warn("auth-daemon: check group %s: %v", group, err)
inGroup = false
}
if meta.Sudo && !inGroup {
cmd := exec.Command("usermod", "-aG", group, u.Username)
if out, err := cmd.CombinedOutput(); err != nil {
logger.Warn("auth-daemon: usermod -aG %s %s: %v (output: %s)", group, u.Username, err, string(out))
} else {
removeSudoers(u.Username)
logger.Info("auth-daemon: added %s to %s", u.Username, group)
}
} else if !meta.Sudo && inGroup {
cmd := exec.Command("gpasswd", "-d", u.Username, group)
if out, err := cmd.CombinedOutput(); err != nil {
logger.Warn("auth-daemon: gpasswd -d %s %s: %v (output: %s)", u.Username, group, err, string(out))
} else {
logger.Info("auth-daemon: removed %s from %s", u.Username, group)
}
default:
removeSudoers(u.Username)
}
if meta.Homedir && u.HomeDir != "" {
uid, gid := mustAtoi(u.Uid), mustAtoi(u.Gid)
if st, err := os.Stat(u.HomeDir); err != nil || !st.IsDir() {
if err := os.MkdirAll(u.HomeDir, 0755); err != nil {
logger.Warn("auth-daemon: mkdir %s: %v", u.HomeDir, err)
} else {
uid, gid := mustAtoi(u.Uid), mustAtoi(u.Gid)
_ = os.Chown(u.HomeDir, uid, gid)
copySkelInto(skelDir, u.HomeDir, uid, gid)
logger.Info("auth-daemon: created home %s for %s", u.HomeDir, u.Username)
}
} else {
// Ensure .bashrc etc. exist (e.g. home existed but was empty or skel was minimal)
copySkelInto(skelDir, u.HomeDir, uid, gid)
}
}
return nil
}
func userInGroup(username, groupName string) (bool, error) {
// getent group wheel returns "wheel:x:10:user1,user2"
cmd := exec.Command("getent", "group", groupName)
out, err := cmd.Output()
if err != nil {
return false, err
}
parts := strings.SplitN(strings.TrimSpace(string(out)), ":", 4)
if len(parts) < 4 {
return false, nil
}
members := strings.Split(parts[3], ",")
for _, m := range members {
if strings.TrimSpace(m) == username {
return true, nil
}
}
return false, nil
}

View File

@@ -12,7 +12,7 @@ func writeCACertIfNotExists(path, contents string, force bool) error {
}
// ensureUser returns an error on non-Linux.
func ensureUser(username string, meta ConnectionMetadata, generateRandomPassword bool) error {
func ensureUser(username string, meta ConnectionMetadata) error {
return errLinuxOnly
}

View File

@@ -13,10 +13,8 @@ func (s *Server) registerRoutes() {
// ConnectionMetadata is the metadata object in POST /connection.
type ConnectionMetadata struct {
SudoMode string `json:"sudoMode"` // "none" | "full" | "commands"
SudoCommands []string `json:"sudoCommands"` // used when sudoMode is "commands"
Sudo bool `json:"sudo"`
Homedir bool `json:"homedir"`
Groups []string `json:"groups"` // system groups to add the user to
}
// ConnectionRequest is the JSON body for POST /connection.

View File

@@ -29,7 +29,6 @@ type Config struct {
CACertPath string // Required. Where to write the CA cert (e.g. /etc/ssh/ca.pem). No default.
Force bool // If true, overwrite existing CA cert (and other items) when content differs. Default false.
PrincipalsFilePath string // Required. Path to the principals data file (JSON: username -> array of principals). No default.
GenerateRandomPassword bool // If true, set a random password on users when they are provisioned (for SSH PermitEmptyPasswords no).
}
type Server struct {

View File

@@ -38,12 +38,10 @@ type WgConfig struct {
type Target struct {
SourcePrefix string `json:"sourcePrefix"`
SourcePrefixes []string `json:"sourcePrefixes"`
DestPrefix string `json:"destPrefix"`
RewriteTo string `json:"rewriteTo,omitempty"`
DisableIcmp bool `json:"disableIcmp,omitempty"`
PortRange []PortRange `json:"portRange,omitempty"`
ResourceId int `json:"resourceId,omitempty"`
}
type PortRange struct {
@@ -174,7 +172,6 @@ func NewWireGuardService(interfaceName string, port uint16, mtu int, host string
wsClient.RegisterHandler("newt/wg/targets/add", service.handleAddTarget)
wsClient.RegisterHandler("newt/wg/targets/remove", service.handleRemoveTarget)
wsClient.RegisterHandler("newt/wg/targets/update", service.handleUpdateTarget)
wsClient.RegisterHandler("newt/wg/sync", service.handleSyncConfig)
return service, nil
}
@@ -197,15 +194,6 @@ func (s *WireGuardService) Close() {
s.stopGetConfig = nil
}
// Flush access logs before tearing down the tunnel
if s.tnet != nil {
if ph := s.tnet.GetProxyHandler(); ph != nil {
if al := ph.GetAccessLogger(); al != nil {
al.Close()
}
}
}
// Stop the direct UDP relay first
s.StopDirectUDPRelay()
@@ -504,183 +492,6 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
logger.Info("Client connectivity setup. Ready to accept connections from clients!")
}
// SyncConfig represents the configuration sent from server for syncing
type SyncConfig struct {
Targets []Target `json:"targets"`
Peers []Peer `json:"peers"`
}
func (s *WireGuardService) handleSyncConfig(msg websocket.WSMessage) {
var syncConfig SyncConfig
logger.Debug("Received sync message: %v", msg)
logger.Info("Received sync configuration from remote server")
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling sync data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &syncConfig); err != nil {
logger.Error("Error unmarshaling sync data: %v", err)
return
}
// Sync peers
if err := s.syncPeers(syncConfig.Peers); err != nil {
logger.Error("Failed to sync peers: %v", err)
}
// Sync targets
if err := s.syncTargets(syncConfig.Targets); err != nil {
logger.Error("Failed to sync targets: %v", err)
}
}
// syncPeers synchronizes the current peers with the desired state
// It removes peers not in the desired list and adds missing ones
func (s *WireGuardService) syncPeers(desiredPeers []Peer) error {
if s.device == nil {
return fmt.Errorf("WireGuard device is not initialized")
}
// Get current peers from the device
currentConfig, err := s.device.IpcGet()
if err != nil {
return fmt.Errorf("failed to get current device config: %v", err)
}
// Parse current peer public keys
lines := strings.Split(currentConfig, "\n")
currentPeerKeys := make(map[string]bool)
for _, line := range lines {
if strings.HasPrefix(line, "public_key=") {
pubKey := strings.TrimPrefix(line, "public_key=")
currentPeerKeys[pubKey] = true
}
}
// Build a map of desired peers by their public key (normalized)
desiredPeerMap := make(map[string]Peer)
for _, peer := range desiredPeers {
// Normalize the public key for comparison
pubKey, err := wgtypes.ParseKey(peer.PublicKey)
if err != nil {
logger.Warn("Invalid public key in desired peers: %s", peer.PublicKey)
continue
}
normalizedKey := util.FixKey(pubKey.String())
desiredPeerMap[normalizedKey] = peer
}
// Remove peers that are not in the desired list
for currentKey := range currentPeerKeys {
if _, exists := desiredPeerMap[currentKey]; !exists {
// Parse the key back to get the original format for removal
removeConfig := fmt.Sprintf("public_key=%s\nremove=true", currentKey)
if err := s.device.IpcSet(removeConfig); err != nil {
logger.Warn("Failed to remove peer %s during sync: %v", currentKey, err)
} else {
logger.Info("Removed peer %s during sync", currentKey)
}
}
}
// Add peers that are missing
for normalizedKey, peer := range desiredPeerMap {
if _, exists := currentPeerKeys[normalizedKey]; !exists {
if err := s.addPeerToDevice(peer); err != nil {
logger.Warn("Failed to add peer %s during sync: %v", peer.PublicKey, err)
} else {
logger.Info("Added peer %s during sync", peer.PublicKey)
}
}
}
return nil
}
// syncTargets synchronizes the current targets with the desired state
// It removes targets not in the desired list and adds missing ones
func (s *WireGuardService) syncTargets(desiredTargets []Target) error {
if s.tnet == nil {
// Native interface mode - proxy features not available, skip silently
logger.Debug("Skipping target sync - using native interface (no proxy support)")
return nil
}
// Get current rules from the proxy handler
currentRules := s.tnet.GetProxySubnetRules()
// Build a map of current rules by source+dest prefix
type ruleKey struct {
sourcePrefix string
destPrefix string
}
currentRuleMap := make(map[ruleKey]bool)
for _, rule := range currentRules {
key := ruleKey{
sourcePrefix: rule.SourcePrefix.String(),
destPrefix: rule.DestPrefix.String(),
}
currentRuleMap[key] = true
}
// Build a map of desired targets
desiredTargetMap := make(map[ruleKey]Target)
for _, target := range desiredTargets {
key := ruleKey{
sourcePrefix: target.SourcePrefix,
destPrefix: target.DestPrefix,
}
desiredTargetMap[key] = target
}
// Remove targets that are not in the desired list
for _, rule := range currentRules {
key := ruleKey{
sourcePrefix: rule.SourcePrefix.String(),
destPrefix: rule.DestPrefix.String(),
}
if _, exists := desiredTargetMap[key]; !exists {
s.tnet.RemoveProxySubnetRule(rule.SourcePrefix, rule.DestPrefix)
logger.Info("Removed target %s -> %s during sync", rule.SourcePrefix.String(), rule.DestPrefix.String())
}
}
// Add targets that are missing
for key, target := range desiredTargetMap {
if _, exists := currentRuleMap[key]; !exists {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Warn("Invalid source prefix %s during sync: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil {
logger.Warn("Invalid dest prefix %s during sync: %v", target.DestPrefix, err)
continue
}
var portRanges []netstack2.PortRange
for _, pr := range target.PortRange {
portRanges = append(portRanges, netstack2.PortRange{
Min: pr.Min,
Max: pr.Max,
Protocol: pr.Protocol,
})
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId)
logger.Info("Added target %s -> %s during sync", target.SourcePrefix, target.DestPrefix)
}
}
return nil
}
func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
s.mu.Lock()
@@ -804,13 +615,6 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
s.TunnelIP = tunnelIP.String()
// Configure the access log sender to ship compressed session logs via websocket
s.tnet.SetAccessLogSender(func(data string) error {
return s.client.SendMessageNoLog("newt/access-log", map[string]interface{}{
"compressed": data,
})
})
// Create WireGuard device using the shared bind
s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger(
device.LogLevelSilent, // Use silent logging by default - could be made configurable
@@ -891,19 +695,6 @@ func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
return nil
}
// resolveSourcePrefixes returns the effective list of source prefixes for a target,
// supporting both the legacy single SourcePrefix field and the new SourcePrefixes array.
// If SourcePrefixes is non-empty it takes precedence; otherwise SourcePrefix is used.
func resolveSourcePrefixes(target Target) []string {
if len(target.SourcePrefixes) > 0 {
return target.SourcePrefixes
}
if target.SourcePrefix != "" {
return []string{target.SourcePrefix}
}
return nil
}
func (s *WireGuardService) ensureTargets(targets []Target) error {
if s.tnet == nil {
// Native interface mode - proxy features not available, skip silently
@@ -912,6 +703,11 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
}
for _, target := range targets {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
return fmt.Errorf("invalid CIDR %s: %v", target.SourcePrefix, err)
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil {
return fmt.Errorf("invalid CIDR %s: %v", target.DestPrefix, err)
@@ -926,14 +722,9 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
})
}
for _, sp := range resolveSourcePrefixes(target) {
sourcePrefix, err := netip.ParsePrefix(sp)
if err != nil {
return fmt.Errorf("invalid CIDR %s: %v", sp, err)
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange)
}
return nil
@@ -1303,6 +1094,12 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
// Process all targets
for _, target := range targets {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
@@ -1318,15 +1115,9 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
})
}
for _, sp := range resolveSourcePrefixes(target) {
sourcePrefix, err := netip.ParsePrefix(sp)
if err != nil {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange)
}
}
@@ -1355,21 +1146,21 @@ func (s *WireGuardService) handleRemoveTarget(msg websocket.WSMessage) {
// Process all targets
for _, target := range targets {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
continue
}
for _, sp := range resolveSourcePrefixes(target) {
sourcePrefix, err := netip.ParsePrefix(sp)
if err != nil {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix)
logger.Info("Removed target subnet %s with destination %s", sp, target.DestPrefix)
}
logger.Info("Removed target subnet %s with destination %s", target.SourcePrefix, target.DestPrefix)
}
}
@@ -1403,24 +1194,30 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
// Process all update requests
for _, target := range requests.OldTargets {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
continue
}
for _, sp := range resolveSourcePrefixes(target) {
sourcePrefix, err := netip.ParsePrefix(sp)
if err != nil {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix)
logger.Info("Removed target subnet %s with destination %s", sp, target.DestPrefix)
}
logger.Info("Removed target subnet %s with destination %s", target.SourcePrefix, target.DestPrefix)
}
for _, target := range requests.NewTargets {
// Now add the new target
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
@@ -1436,15 +1233,8 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
})
}
for _, sp := range resolveSourcePrefixes(target) {
sourcePrefix, err := netip.ParsePrefix(sp)
if err != nil {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange)
}
}

View File

@@ -5,10 +5,8 @@ import (
"context"
"encoding/json"
"fmt"
"net"
"os"
"os/exec"
"regexp"
"strings"
"time"
@@ -365,62 +363,27 @@ func parseTargetData(data interface{}) (TargetData, error) {
return targetData, nil
}
// parseTargetString parses a target string in the format "listenPort:host:targetPort"
// It properly handles IPv6 addresses which must be in brackets: "listenPort:[ipv6]:targetPort"
// Examples:
// - IPv4: "3001:192.168.1.1:80"
// - IPv6: "3001:[::1]:8080" or "3001:[fd70:1452:b736:4dd5:caca:7db9:c588:f5b3]:80"
//
// Returns listenPort, targetAddress (in host:port format suitable for net.Dial), and error
func parseTargetString(target string) (int, string, error) {
// Find the first colon to extract the listen port
firstColon := strings.Index(target, ":")
if firstColon == -1 {
return 0, "", fmt.Errorf("invalid target format, no colon found: %s", target)
}
listenPortStr := target[:firstColon]
var listenPort int
_, err := fmt.Sscanf(listenPortStr, "%d", &listenPort)
if err != nil {
return 0, "", fmt.Errorf("invalid listen port: %s", listenPortStr)
}
if listenPort <= 0 || listenPort > 65535 {
return 0, "", fmt.Errorf("listen port out of range: %d", listenPort)
}
// The remainder is host:targetPort - use net.SplitHostPort which handles IPv6 brackets
remainder := target[firstColon+1:]
host, targetPort, err := net.SplitHostPort(remainder)
if err != nil {
return 0, "", fmt.Errorf("invalid host:port format '%s': %w", remainder, err)
}
// Reject empty host or target port
if host == "" {
return 0, "", fmt.Errorf("empty host in target: %s", target)
}
if targetPort == "" {
return 0, "", fmt.Errorf("empty target port in target: %s", target)
}
// Reconstruct the target address using JoinHostPort (handles IPv6 properly)
targetAddr := net.JoinHostPort(host, targetPort)
return listenPort, targetAddr, nil
}
func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
for _, t := range targetData.Targets {
// Parse the target string, handling both IPv4 and IPv6 addresses
port, target, err := parseTargetString(t)
// Split the first number off of the target with : separator and use as the port
parts := strings.Split(t, ":")
if len(parts) != 3 {
logger.Info("Invalid target format: %s", t)
continue
}
// Get the port as an int
port := 0
_, err := fmt.Sscanf(parts[0], "%d", &port)
if err != nil {
logger.Info("Invalid target format: %s (%v)", t, err)
logger.Info("Invalid port: %s", parts[0])
continue
}
switch action {
case "add":
target := parts[1] + ":" + parts[2]
// Call updown script if provided
processedTarget := target
if updownScript != "" {
@@ -447,6 +410,8 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
case "remove":
logger.Info("Removing target with port %d", port)
target := parts[1] + ":" + parts[2]
// Call updown script if provided
if updownScript != "" {
_, err := executeUpdownScript(action, proto, target)
@@ -455,7 +420,7 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
}
}
err = pm.RemoveTarget(proto, tunnelIP, port)
err := pm.RemoveTarget(proto, tunnelIP, port)
if err != nil {
logger.Error("Failed to remove target: %v", err)
return err
@@ -510,29 +475,6 @@ func executeUpdownScript(action, proto, target string) (string, error) {
return target, nil
}
// interpolateBlueprint finds all {{...}} tokens in the raw blueprint bytes and
// replaces recognised schemes with their resolved values. Currently supported:
//
// - env.<VAR> replaced with the value of the named environment variable
//
// Any token that does not match a supported scheme is left as-is so that
// future schemes (e.g. tag., api.) are preserved rather than silently dropped.
func interpolateBlueprint(data []byte) []byte {
re := regexp.MustCompile(`\{\{([^}]+)\}\}`)
return re.ReplaceAllFunc(data, func(match []byte) []byte {
// strip the surrounding {{ }}
inner := strings.TrimSpace(string(match[2 : len(match)-2]))
if strings.HasPrefix(inner, "env.") {
varName := strings.TrimPrefix(inner, "env.")
return []byte(os.Getenv(varName))
}
// unrecognised scheme leave the token untouched
return match
})
}
func sendBlueprint(client *websocket.Client) error {
if blueprintFile == "" {
return nil
@@ -542,9 +484,6 @@ func sendBlueprint(client *websocket.Client) error {
if err != nil {
logger.Error("Failed to read blueprint file: %v", err)
} else {
// interpolate {{env.VAR}} (and any future schemes) before parsing
blueprintData = interpolateBlueprint(blueprintData)
// first we should convert the yaml to json and error if the yaml is bad
var yamlObj interface{}
var blueprintJsonData string

View File

@@ -1,212 +0,0 @@
package main
import (
"net"
"testing"
)
func TestParseTargetString(t *testing.T) {
tests := []struct {
name string
input string
wantListenPort int
wantTargetAddr string
wantErr bool
}{
// IPv4 test cases
{
name: "valid IPv4 basic",
input: "3001:192.168.1.1:80",
wantListenPort: 3001,
wantTargetAddr: "192.168.1.1:80",
wantErr: false,
},
{
name: "valid IPv4 localhost",
input: "8080:127.0.0.1:3000",
wantListenPort: 8080,
wantTargetAddr: "127.0.0.1:3000",
wantErr: false,
},
{
name: "valid IPv4 same ports",
input: "443:10.0.0.1:443",
wantListenPort: 443,
wantTargetAddr: "10.0.0.1:443",
wantErr: false,
},
// IPv6 test cases
{
name: "valid IPv6 loopback",
input: "3001:[::1]:8080",
wantListenPort: 3001,
wantTargetAddr: "[::1]:8080",
wantErr: false,
},
{
name: "valid IPv6 full address",
input: "80:[fd70:1452:b736:4dd5:caca:7db9:c588:f5b3]:8080",
wantListenPort: 80,
wantTargetAddr: "[fd70:1452:b736:4dd5:caca:7db9:c588:f5b3]:8080",
wantErr: false,
},
{
name: "valid IPv6 link-local",
input: "443:[fe80::1]:443",
wantListenPort: 443,
wantTargetAddr: "[fe80::1]:443",
wantErr: false,
},
{
name: "valid IPv6 all zeros compressed",
input: "8000:[::]:9000",
wantListenPort: 8000,
wantTargetAddr: "[::]:9000",
wantErr: false,
},
{
name: "valid IPv6 mixed notation",
input: "5000:[::ffff:192.168.1.1]:6000",
wantListenPort: 5000,
wantTargetAddr: "[::ffff:192.168.1.1]:6000",
wantErr: false,
},
// Hostname test cases
{
name: "valid hostname",
input: "8080:example.com:80",
wantListenPort: 8080,
wantTargetAddr: "example.com:80",
wantErr: false,
},
{
name: "valid hostname with subdomain",
input: "443:api.example.com:8443",
wantListenPort: 443,
wantTargetAddr: "api.example.com:8443",
wantErr: false,
},
{
name: "valid localhost hostname",
input: "3000:localhost:3000",
wantListenPort: 3000,
wantTargetAddr: "localhost:3000",
wantErr: false,
},
// Error cases
{
name: "invalid - no colons",
input: "invalid",
wantErr: true,
},
{
name: "invalid - empty string",
input: "",
wantErr: true,
},
{
name: "invalid - non-numeric listen port",
input: "abc:192.168.1.1:80",
wantErr: true,
},
{
name: "invalid - missing target port",
input: "3001:192.168.1.1",
wantErr: true,
},
{
name: "invalid - IPv6 without brackets",
input: "3001:fd70:1452:b736:4dd5:caca:7db9:c588:f5b3:80",
wantErr: true,
},
{
name: "invalid - only listen port",
input: "3001:",
wantErr: true,
},
{
name: "invalid - missing host",
input: "3001::80",
wantErr: true,
},
{
name: "invalid - IPv6 unclosed bracket",
input: "3001:[::1:80",
wantErr: true,
},
{
name: "invalid - listen port zero",
input: "0:192.168.1.1:80",
wantErr: true,
},
{
name: "invalid - listen port negative",
input: "-1:192.168.1.1:80",
wantErr: true,
},
{
name: "invalid - listen port out of range",
input: "70000:192.168.1.1:80",
wantErr: true,
},
{
name: "invalid - empty target port",
input: "3001:192.168.1.1:",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
listenPort, targetAddr, err := parseTargetString(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("parseTargetString(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
return
}
if tt.wantErr {
return // Don't check other values if we expected an error
}
if listenPort != tt.wantListenPort {
t.Errorf("parseTargetString(%q) listenPort = %d, want %d", tt.input, listenPort, tt.wantListenPort)
}
if targetAddr != tt.wantTargetAddr {
t.Errorf("parseTargetString(%q) targetAddr = %q, want %q", tt.input, targetAddr, tt.wantTargetAddr)
}
})
}
}
// TestParseTargetStringNetDialCompatibility verifies that the output is compatible with net.Dial
func TestParseTargetStringNetDialCompatibility(t *testing.T) {
tests := []struct {
name string
input string
}{
{"IPv4", "8080:127.0.0.1:80"},
{"IPv6 loopback", "8080:[::1]:80"},
{"IPv6 full", "8080:[2001:db8::1]:80"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, targetAddr, err := parseTargetString(tt.input)
if err != nil {
t.Fatalf("parseTargetString(%q) unexpected error: %v", tt.input, err)
}
// Verify the format is valid for net.Dial by checking it can be split back
// This doesn't actually dial, just validates the format
_, _, err = net.SplitHostPort(targetAddr)
if err != nil {
t.Errorf("parseTargetString(%q) produced invalid net.Dial format %q: %v", tt.input, targetAddr, err)
}
})
}
}

View File

@@ -1,4 +0,0 @@
{
"endpoint": "http://you.fosrl.io",
"provisioningKey": "spk-xt1opb0fkoqb7qb.hi44jciamqcrdaja4lvz3kp52pl3lssamp6asuyx"
}

View File

@@ -1,4 +0,0 @@
{
"endpoint": "http://you.fosrl.io",
"provisioningKey": "spk-xt1opb0fkoqb7qb.hi44jciamqcrdaja4lvz3kp52pl3lssamp6asuyx"
}

View File

@@ -35,7 +35,7 @@
inherit version;
src = pkgs.nix-gitignore.gitignoreSource [ ] ./.;
vendorHash = "sha256-kmQM8Yy5TuOiNpMpUme/2gfE+vrhUK+0AphN+p71wGs=";
vendorHash = "sha256-Sib6AUCpMgxlMpTc2Esvs+UU0yduVOxWUgT44FHAI+k=";
nativeInstallCheckInputs = [ pkgs.versionCheckHook ];

View File

@@ -1,7 +1,7 @@
#!/bin/sh
#!/bin/bash
# Get Newt - Cross-platform installation script
# Usage: curl -fsSL https://raw.githubusercontent.com/fosrl/newt/refs/heads/main/get-newt.sh | sh
# Usage: curl -fsSL https://raw.githubusercontent.com/fosrl/newt/refs/heads/main/get-newt.sh | bash
set -e
@@ -17,15 +17,15 @@ GITHUB_API_URL="https://api.github.com/repos/${REPO}/releases/latest"
# Function to print colored output
print_status() {
printf '%b[INFO]%b %s\n' "${GREEN}" "${NC}" "$1"
echo -e "${GREEN}[INFO]${NC} $1"
}
print_warning() {
printf '%b[WARN]%b %s\n' "${YELLOW}" "${NC}" "$1"
echo -e "${YELLOW}[WARN]${NC} $1"
}
print_error() {
printf '%b[ERROR]%b %s\n' "${RED}" "${NC}" "$1"
echo -e "${RED}[ERROR]${NC} $1"
}
# Function to get latest version from GitHub API
@@ -113,34 +113,16 @@ get_install_dir() {
if [ "$OS" = "windows" ]; then
echo "$HOME/bin"
else
# Prefer /usr/local/bin for system-wide installation
# Try to use a directory in PATH, fallback to ~/.local/bin
if echo "$PATH" | grep -q "/usr/local/bin"; then
if [ -w "/usr/local/bin" ] 2>/dev/null; then
echo "/usr/local/bin"
fi
}
# Check if we need sudo for installation
needs_sudo() {
local install_dir="$1"
if [ -w "$install_dir" ] 2>/dev/null; then
return 1 # No sudo needed
else
return 0 # Sudo needed
fi
}
# Get the appropriate command prefix (sudo or empty)
get_sudo_cmd() {
local install_dir="$1"
if needs_sudo "$install_dir"; then
if command -v sudo >/dev/null 2>&1; then
echo "sudo"
else
print_error "Cannot write to ${install_dir} and sudo is not available."
print_error "Please run this script as root or install sudo."
exit 1
echo "$HOME/.local/bin"
fi
else
echo ""
echo "$HOME/.local/bin"
fi
fi
}
@@ -148,17 +130,14 @@ get_sudo_cmd() {
install_newt() {
local platform="$1"
local install_dir="$2"
local sudo_cmd="$3"
local binary_name="newt_${platform}"
local exe_suffix=""
# Add .exe suffix for Windows
case "$platform" in
*windows*)
if [[ "$platform" == *"windows"* ]]; then
binary_name="${binary_name}.exe"
exe_suffix=".exe"
;;
esac
fi
local download_url="${BASE_URL}/${binary_name}"
local temp_file="/tmp/newt${exe_suffix}"
@@ -176,18 +155,14 @@ install_newt() {
exit 1
fi
# Make executable before moving
chmod +x "$temp_file"
# Create install directory if it doesn't exist
if [ -n "$sudo_cmd" ]; then
$sudo_cmd mkdir -p "$install_dir"
print_status "Using sudo to install to ${install_dir}"
$sudo_cmd mv "$temp_file" "$final_path"
else
mkdir -p "$install_dir"
# Move binary to install directory
mv "$temp_file" "$final_path"
fi
# Make executable (not needed on Windows, but doesn't hurt)
chmod +x "$final_path"
print_status "newt installed to ${final_path}"
@@ -204,9 +179,9 @@ verify_installation() {
local install_dir="$1"
local exe_suffix=""
case "$PLATFORM" in
*windows*) exe_suffix=".exe" ;;
esac
if [[ "$PLATFORM" == *"windows"* ]]; then
exe_suffix=".exe"
fi
local newt_path="${install_dir}/newt${exe_suffix}"
@@ -240,19 +215,17 @@ main() {
INSTALL_DIR=$(get_install_dir)
print_status "Install directory: ${INSTALL_DIR}"
# Check if we need sudo
SUDO_CMD=$(get_sudo_cmd "$INSTALL_DIR")
if [ -n "$SUDO_CMD" ]; then
print_status "Root privileges required for installation to ${INSTALL_DIR}"
fi
# Install newt
install_newt "$PLATFORM" "$INSTALL_DIR" "$SUDO_CMD"
install_newt "$PLATFORM" "$INSTALL_DIR"
# Verify installation
if verify_installation "$INSTALL_DIR"; then
print_status "newt is ready to use!"
if [[ "$PLATFORM" == *"windows"* ]]; then
print_status "Run 'newt --help' to get started"
else
print_status "Run 'newt --help' to get started"
fi
else
exit 1
fi

49
go.mod
View File

@@ -1,30 +1,29 @@
module github.com/fosrl/newt
go 1.25.0
go 1.25
require (
github.com/docker/docker v28.5.2+incompatible
github.com/gaissmai/bart v0.26.0
github.com/gorilla/websocket v1.5.3
github.com/prometheus/client_golang v1.23.2
github.com/vishvananda/netlink v1.3.1
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.66.0
go.opentelemetry.io/contrib/instrumentation/runtime v0.66.0
go.opentelemetry.io/otel v1.41.0
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.41.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.41.0
go.opentelemetry.io/otel/exporters/prometheus v0.63.0
go.opentelemetry.io/otel/metric v1.41.0
go.opentelemetry.io/otel/sdk v1.41.0
go.opentelemetry.io/otel/sdk/metric v1.41.0
golang.org/x/crypto v0.48.0
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0
go.opentelemetry.io/contrib/instrumentation/runtime v0.64.0
go.opentelemetry.io/otel v1.39.0
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.39.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.39.0
go.opentelemetry.io/otel/exporters/prometheus v0.61.0
go.opentelemetry.io/otel/metric v1.39.0
go.opentelemetry.io/otel/sdk v1.39.0
go.opentelemetry.io/otel/sdk/metric v1.39.0
golang.org/x/crypto v0.46.0
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6
golang.org/x/net v0.51.0
golang.org/x/sys v0.41.0
golang.org/x/net v0.48.0
golang.org/x/sys v0.39.0
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/grpc v1.79.1
google.golang.org/grpc v1.77.0
gopkg.in/yaml.v3 v3.0.1
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c
software.sslmate.com/src/go-pkcs12 v0.7.0
@@ -45,7 +44,7 @@ require (
github.com/go-logr/stdr v1.2.2 // indirect
github.com/google/btree v1.1.3 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/sys/atomicwriter v0.1.0 // indirect
github.com/moby/term v0.5.2 // indirect
@@ -55,23 +54,23 @@ require (
github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.67.5 // indirect
github.com/prometheus/common v0.67.4 // indirect
github.com/prometheus/otlptranslator v1.0.0 // indirect
github.com/prometheus/procfs v0.19.2 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.41.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.38.0 // indirect
go.opentelemetry.io/otel/trace v1.41.0 // indirect
go.opentelemetry.io/otel/trace v1.39.0 // indirect
go.opentelemetry.io/proto/otlp v1.9.0 // indirect
go.yaml.in/yaml/v2 v2.4.3 // indirect
golang.org/x/mod v0.32.0 // indirect
golang.org/x/mod v0.30.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/text v0.34.0 // indirect
golang.org/x/text v0.32.0 // indirect
golang.org/x/time v0.12.0 // indirect
golang.org/x/tools v0.41.0 // indirect
golang.org/x/tools v0.39.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57 // indirect
google.golang.org/protobuf v1.36.11 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect
google.golang.org/protobuf v1.36.10 // indirect
)

94
go.sum
View File

@@ -26,8 +26,6 @@ github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw
github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/gaissmai/bart v0.26.0 h1:xOZ57E9hJLBiQaSyeZa9wgWhGuzfGACgqp4BE77OkO0=
github.com/gaissmai/bart v0.26.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
@@ -43,8 +41,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz+PMpZ14Jynv3O2Zs=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLWMC+vZCkfs+FHv1Vg=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
@@ -77,8 +75,8 @@ github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
github.com/prometheus/common v0.67.5 h1:pIgK94WWlQt1WLwAC5j2ynLaBRDiinoAb86HZHTUGI4=
github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw=
github.com/prometheus/common v0.67.4 h1:yR3NqWO1/UyO1w2PhUvXlGQs/PtFmoveVO0KZ4+Lvsc=
github.com/prometheus/common v0.67.4/go.mod h1:gP0fq6YjjNCLssJCQp0yk4M8W6ikLURwkdd/YKtTbyI=
github.com/prometheus/otlptranslator v1.0.0 h1:s0LJW/iN9dkIH+EnhiD3BlkkP5QVIUVEoIwkU+A6qos=
github.com/prometheus/otlptranslator v1.0.0/go.mod h1:vRYWnXvI6aWGpsdY/mOT/cbeVRBlPWtBNDb7kGR3uKM=
github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws=
@@ -95,56 +93,56 @@ github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zd
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.66.0 h1:PnV4kVnw0zOmwwFkAzCN5O07fw1YOIQor120zrh0AVo=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.66.0/go.mod h1:ofAwF4uinaf8SXdVzzbL4OsxJ3VfeEg3f/F6CeF49/Y=
go.opentelemetry.io/contrib/instrumentation/runtime v0.66.0 h1:JruBNmrPELWjR+PU3fsQBFQRYtsMLQ/zPfbvwDz9I/w=
go.opentelemetry.io/contrib/instrumentation/runtime v0.66.0/go.mod h1:vwNrfL6w1uAE3qX48KFii2Qoqf+NEDP5wNjus+RHz8Y=
go.opentelemetry.io/otel v1.41.0 h1:YlEwVsGAlCvczDILpUXpIpPSL/VPugt7zHThEMLce1c=
go.opentelemetry.io/otel v1.41.0/go.mod h1:Yt4UwgEKeT05QbLwbyHXEwhnjxNO6D8L5PQP51/46dE=
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.41.0 h1:VO3BL6OZXRQ1yQc8W6EVfJzINeJ35BkiHx4MYfoQf44=
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.41.0/go.mod h1:qRDnJ2nv3CQXMK2HUd9K9VtvedsPAce3S+/4LZHjX/s=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.41.0 h1:ao6Oe+wSebTlQ1OEht7jlYTzQKE+pnx/iNywFvTbuuI=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.41.0/go.mod h1:u3T6vz0gh/NVzgDgiwkgLxpsSF6PaPmo2il0apGJbls=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.41.0 h1:mq/Qcf28TWz719lE3/hMB4KkyDuLJIvgJnFGcd0kEUI=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.41.0/go.mod h1:yk5LXEYhsL2htyDNJbEq7fWzNEigeEdV5xBF/Y+kAv0=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 h1:ssfIgGNANqpVFCndZvcuyKbl0g+UAVcbBcqGkG28H0Y=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0/go.mod h1:GQ/474YrbE4Jx8gZ4q5I4hrhUzM6UPzyrqJYV2AqPoQ=
go.opentelemetry.io/contrib/instrumentation/runtime v0.64.0 h1:/+/+UjlXjFcdDlXxKL1PouzX8Z2Vl0OxolRKeBEgYDw=
go.opentelemetry.io/contrib/instrumentation/runtime v0.64.0/go.mod h1:Ldm/PDuzY2DP7IypudopCR3OCOW42NJlN9+mNEroevo=
go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48=
go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8=
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.39.0 h1:cEf8jF6WbuGQWUVcqgyWtTR0kOOAWY1DYZ+UhvdmQPw=
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.39.0/go.mod h1:k1lzV5n5U3HkGvTCJHraTAGJ7MqsgL1wrGwTj1Isfiw=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0 h1:f0cb2XPmrqn4XMy9PNliTgRKJgS5WcL/u0/WRYGz4t0=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0/go.mod h1:vnakAaFckOMiMtOIhFI2MNH4FYrZzXCYxmb1LlhoGz8=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.39.0 h1:in9O8ESIOlwJAEGTkkf34DesGRAc/Pn8qJ7k3r/42LM=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.39.0/go.mod h1:Rp0EXBm5tfnv0WL+ARyO/PHBEaEAT8UUHQ6AGJcSq6c=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.38.0 h1:aTL7F04bJHUlztTsNGJ2l+6he8c+y/b//eR0jjjemT4=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.38.0/go.mod h1:kldtb7jDTeol0l3ewcmd8SDvx3EmIE7lyvqbasU3QC4=
go.opentelemetry.io/otel/exporters/prometheus v0.63.0 h1:OLo1FNb0pBZykLqbKRZolKtGZd0Waqlr240YdMEnhhg=
go.opentelemetry.io/otel/exporters/prometheus v0.63.0/go.mod h1:8yeQAdhrK5xsWuFehO13Dk/Xb9FuhZoVpJfpoNCfJnw=
go.opentelemetry.io/otel/metric v1.41.0 h1:rFnDcs4gRzBcsO9tS8LCpgR0dxg4aaxWlJxCno7JlTQ=
go.opentelemetry.io/otel/metric v1.41.0/go.mod h1:xPvCwd9pU0VN8tPZYzDZV/BMj9CM9vs00GuBjeKhJps=
go.opentelemetry.io/otel/sdk v1.41.0 h1:YPIEXKmiAwkGl3Gu1huk1aYWwtpRLeskpV+wPisxBp8=
go.opentelemetry.io/otel/sdk v1.41.0/go.mod h1:ahFdU0G5y8IxglBf0QBJXgSe7agzjE4GiTJ6HT9ud90=
go.opentelemetry.io/otel/sdk/metric v1.41.0 h1:siZQIYBAUd1rlIWQT2uCxWJxcCO7q3TriaMlf08rXw8=
go.opentelemetry.io/otel/sdk/metric v1.41.0/go.mod h1:HNBuSvT7ROaGtGI50ArdRLUnvRTRGniSUZbxiWxSO8Y=
go.opentelemetry.io/otel/trace v1.41.0 h1:Vbk2co6bhj8L59ZJ6/xFTskY+tGAbOnCtQGVVa9TIN0=
go.opentelemetry.io/otel/trace v1.41.0/go.mod h1:U1NU4ULCoxeDKc09yCWdWe+3QoyweJcISEVa1RBzOis=
go.opentelemetry.io/otel/exporters/prometheus v0.61.0 h1:cCyZS4dr67d30uDyh8etKM2QyDsQ4zC9ds3bdbrVoD0=
go.opentelemetry.io/otel/exporters/prometheus v0.61.0/go.mod h1:iivMuj3xpR2DkUrUya3TPS/Z9h3dz7h01GxU+fQBRNg=
go.opentelemetry.io/otel/metric v1.39.0 h1:d1UzonvEZriVfpNKEVmHXbdf909uGTOQjA0HF0Ls5Q0=
go.opentelemetry.io/otel/metric v1.39.0/go.mod h1:jrZSWL33sD7bBxg1xjrqyDjnuzTUB0x1nBERXd7Ftcs=
go.opentelemetry.io/otel/sdk v1.39.0 h1:nMLYcjVsvdui1B/4FRkwjzoRVsMK8uL/cj0OyhKzt18=
go.opentelemetry.io/otel/sdk v1.39.0/go.mod h1:vDojkC4/jsTJsE+kh+LXYQlbL8CgrEcwmt1ENZszdJE=
go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2WKg+sEJTtB8=
go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew=
go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI=
go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA=
go.opentelemetry.io/proto/otlp v1.9.0 h1:l706jCMITVouPOqEnii2fIAuO3IVGBRPV5ICjceRb/A=
go.opentelemetry.io/proto/otlp v1.9.0/go.mod h1:xE+Cx5E/eEHw+ISFkwPLwCZefwVjY+pqKg1qcK03+/4=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8=
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0=
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0=
golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=
@@ -155,14 +153,14 @@ golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57 h1:JLQynH/LBHfCTSbDWl+py8C+Rg/k1OVH3xfcaiANuF0=
google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57/go.mod h1:kSJwQxqmFXeo79zOmbrALdflXQeAYcUbgS7PbpMknCY=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57 h1:mWPCjDEyshlQYzBpMNHaEof6UX1PmHcaUODUywQ0uac=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ=
google.golang.org/grpc v1.79.1 h1:zGhSi45ODB9/p3VAawt9a+O/MULLl9dpizzNNpq7flY=
google.golang.org/grpc v1.79.1/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 h1:fCvbg86sFXwdrl5LgVcTEvNC+2txB5mgROGmRL5mrls=
google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:+rXWjjaukWZun3mLfjmVnQi18E1AsFbDN9QdJ5YXLto=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 h1:gRkg/vSppuSQoDjxyiGfN4Upv/h/DQmIR10ZU8dh4Ww=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk=
google.golang.org/grpc v1.77.0 h1:wVVY6/8cGA6vvffn+wWK5ToddbgdU3d8MNENr4evgXM=
google.golang.org/grpc v1.77.0/go.mod h1:z0BY1iVj0q8E1uSQCjL9cppRj+gnZjzDnzV0dHhrNig=
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=

View File

@@ -5,9 +5,7 @@ import (
"crypto/tls"
"encoding/json"
"fmt"
"net"
"net/http"
"strconv"
"strings"
"sync"
"time"
@@ -367,12 +365,11 @@ func (m *Monitor) performHealthCheck(target *Target) {
target.LastCheck = time.Now()
target.LastError = ""
// Build URL (use net.JoinHostPort to properly handle IPv6 addresses with ports)
host := target.Config.Hostname
// Build URL
url := fmt.Sprintf("%s://%s", target.Config.Scheme, target.Config.Hostname)
if target.Config.Port > 0 {
host = net.JoinHostPort(target.Config.Hostname, strconv.Itoa(target.Config.Port))
url = fmt.Sprintf("%s:%d", url, target.Config.Port)
}
url := fmt.Sprintf("%s://%s", target.Config.Scheme, host)
if target.Config.Path != "" {
if !strings.HasPrefix(target.Config.Path, "/") {
url += "/"
@@ -524,82 +521,3 @@ func (m *Monitor) DisableTarget(id int) error {
return nil
}
// GetTargetIDs returns a snapshot of all current target IDs.
// The order of the returned slice is unspecified (map iteration order).
func (m *Monitor) GetTargetIDs() []int {
    m.mutex.RLock()
    defer m.mutex.RUnlock()

    snapshot := make([]int, 0, len(m.targets))
    for targetID := range m.targets {
        snapshot = append(snapshot, targetID)
    }
    return snapshot
}
// SyncTargets synchronizes the current targets to match the desired set.
// It removes targets not in the desired set and adds targets that are missing.
// Targets already present are replaced unconditionally so their config is
// always fresh. Returns the first add/update error encountered.
// NOTE(review): an early error return leaves the sync partially applied
// (removals and earlier adds stick) — confirm this is acceptable to callers.
func (m *Monitor) SyncTargets(desiredConfigs []Config) error {
    m.mutex.Lock()
    defer m.mutex.Unlock()

    logger.Info("Syncing health check targets: %d desired targets", len(desiredConfigs))

    // Build a set of desired target IDs for O(1) membership checks.
    desiredIDs := make(map[int]Config)
    for _, config := range desiredConfigs {
        desiredIDs[config.ID] = config
    }

    // Find targets to remove (exist locally but not in desired set).
    var toRemove []int
    for id := range m.targets {
        if _, exists := desiredIDs[id]; !exists {
            toRemove = append(toRemove, id)
        }
    }

    // Remove targets that are not in the desired set, cancelling each
    // target's context so its health-check goroutine stops.
    for _, id := range toRemove {
        logger.Info("Sync: removing health check target %d", id)
        if target, exists := m.targets[id]; exists {
            target.cancel()
            delete(m.targets, id)
        }
    }

    // Add or update targets from the desired set.
    var addedCount, updatedCount int
    for id, config := range desiredIDs {
        if existing, exists := m.targets[id]; exists {
            // Target exists - check if config changed and update if needed
            // For now, we'll replace it to ensure config is up to date
            logger.Debug("Sync: updating health check target %d", id)
            existing.cancel()
            delete(m.targets, id)
            if err := m.addTargetUnsafe(config); err != nil {
                logger.Error("Sync: failed to update target %d: %v", id, err)
                return fmt.Errorf("failed to update target %d: %v", id, err)
            }
            updatedCount++
        } else {
            // Target doesn't exist - add it
            logger.Debug("Sync: adding health check target %d", id)
            if err := m.addTargetUnsafe(config); err != nil {
                logger.Error("Sync: failed to add target %d: %v", id, err)
                return fmt.Errorf("failed to add target %d: %v", id, err)
            }
            addedCount++
        }
    }

    logger.Info("Sync complete: removed %d, added %d, updated %d targets",
        len(toRemove), addedCount, updatedCount)

    // Notify the callback on a fresh goroutine so it cannot deadlock against
    // m.mutex, which is still held here via the deferred unlock.
    if (len(toRemove) > 0 || addedCount > 0 || updatedCount > 0) && m.callback != nil {
        go m.callback(m.getAllTargetsUnsafe())
    }

    return nil
}

257
main.go
View File

@@ -10,7 +10,6 @@ import (
"fmt"
"net"
"net/http"
"net/http/pprof"
"net/netip"
"os"
"os/signal"
@@ -117,7 +116,6 @@ var (
logLevel string
interfaceName string
port uint16
portStr string
disableClients bool
updownScript string
dockerSocket string
@@ -138,7 +136,6 @@ var (
authDaemonPrincipalsFile string
authDaemonCACertPath string
authDaemonEnabled bool
authDaemonGenerateRandomPassword bool
// Build/version (can be overridden via -ldflags "-X main.newtVersion=...")
newtVersion = "version_replaceme"
@@ -148,7 +145,6 @@ var (
adminAddr string
region string
metricsAsyncBytes bool
pprofEnabled bool
blueprintFile string
noCloud bool
@@ -159,12 +155,6 @@ var (
// Legacy PKCS12 support (deprecated)
tlsPrivateKey string
// Provisioning key exchanged once for a permanent newt ID + secret
provisioningKey string
// Path to config file (overrides CONFIG_FILE env var and default location)
configFile string
)
func main() {
@@ -220,12 +210,11 @@ func runNewtMain(ctx context.Context) {
logLevel = os.Getenv("LOG_LEVEL")
updownScript = os.Getenv("UPDOWN_SCRIPT")
interfaceName = os.Getenv("INTERFACE")
portStr = os.Getenv("PORT")
portStr := os.Getenv("PORT")
authDaemonKey = os.Getenv("AD_KEY")
authDaemonPrincipalsFile = os.Getenv("AD_PRINCIPALS_FILE")
authDaemonCACertPath = os.Getenv("AD_CA_CERT_PATH")
authDaemonEnabledEnv := os.Getenv("AUTH_DAEMON_ENABLED")
authDaemonGenerateRandomPasswordEnv := os.Getenv("AD_GENERATE_RANDOM_PASSWORD")
// Metrics/observability env mirrors
metricsEnabledEnv := os.Getenv("NEWT_METRICS_PROMETHEUS_ENABLED")
@@ -233,7 +222,6 @@ func runNewtMain(ctx context.Context) {
adminAddrEnv := os.Getenv("NEWT_ADMIN_ADDR")
regionEnv := os.Getenv("NEWT_REGION")
asyncBytesEnv := os.Getenv("NEWT_METRICS_ASYNC_BYTES")
pprofEnabledEnv := os.Getenv("NEWT_PPROF_ENABLED")
disableClientsEnv := os.Getenv("DISABLE_CLIENTS")
disableClients = disableClientsEnv == "true"
@@ -270,8 +258,6 @@ func runNewtMain(ctx context.Context) {
blueprintFile = os.Getenv("BLUEPRINT_FILE")
noCloudEnv := os.Getenv("NO_CLOUD")
noCloud = noCloudEnv == "true"
provisioningKey = os.Getenv("NEWT_PROVISIONING_KEY")
configFile = os.Getenv("CONFIG_FILE")
if endpoint == "" {
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
@@ -313,19 +299,13 @@ func runNewtMain(ctx context.Context) {
flag.StringVar(&dockerSocket, "docker-socket", "", "Path or address to Docker socket (typically unix:///var/run/docker.sock)")
}
if pingIntervalStr == "" {
flag.StringVar(&pingIntervalStr, "ping-interval", "15s", "Interval for pinging the server (default 15s)")
flag.StringVar(&pingIntervalStr, "ping-interval", "3s", "Interval for pinging the server (default 3s)")
}
if pingTimeoutStr == "" {
flag.StringVar(&pingTimeoutStr, "ping-timeout", "7s", " Timeout for each ping (default 7s)")
flag.StringVar(&pingTimeoutStr, "ping-timeout", "5s", " Timeout for each ping (default 5s)")
}
// load the prefer endpoint just as a flag
flag.StringVar(&preferEndpoint, "prefer-endpoint", "", "Prefer this endpoint for the connection (if set, will override the endpoint from the server)")
if provisioningKey == "" {
flag.StringVar(&provisioningKey, "provisioning-key", "", "One-time provisioning key used to obtain a newt ID and secret from the server")
}
if configFile == "" {
flag.StringVar(&configFile, "config-file", "", "Path to config file (overrides CONFIG_FILE env var and default location)")
}
// Add new mTLS flags
if tlsClientCert == "" {
@@ -347,21 +327,30 @@ func runNewtMain(ctx context.Context) {
if pingIntervalStr != "" {
pingInterval, err = time.ParseDuration(pingIntervalStr)
if err != nil {
fmt.Printf("Invalid PING_INTERVAL value: %s, using default 15 seconds\n", pingIntervalStr)
pingInterval = 15 * time.Second
fmt.Printf("Invalid PING_INTERVAL value: %s, using default 3 seconds\n", pingIntervalStr)
pingInterval = 3 * time.Second
}
} else {
pingInterval = 15 * time.Second
pingInterval = 3 * time.Second
}
if pingTimeoutStr != "" {
pingTimeout, err = time.ParseDuration(pingTimeoutStr)
if err != nil {
fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 7 seconds\n", pingTimeoutStr)
pingTimeout = 7 * time.Second
fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 5 seconds\n", pingTimeoutStr)
pingTimeout = 5 * time.Second
}
} else {
pingTimeout = 7 * time.Second
pingTimeout = 5 * time.Second
}
if portStr != "" {
portInt, err := strconv.Atoi(portStr)
if err != nil {
logger.Warn("Failed to parse PORT, choosing a random port")
} else {
port = uint16(portInt)
}
}
if dockerEnforceNetworkValidation == "" {
@@ -407,14 +396,6 @@ func runNewtMain(ctx context.Context) {
metricsAsyncBytes = v
}
}
// pprof debug endpoint toggle
if pprofEnabledEnv == "" {
flag.BoolVar(&pprofEnabled, "pprof", false, "Enable pprof debug endpoints on admin server")
} else {
if v, err := strconv.ParseBool(pprofEnabledEnv); err == nil {
pprofEnabled = v
}
}
// Optional region flag (resource attribute)
if regionEnv == "" {
flag.StringVar(&region, "region", "", "Optional region resource attribute (also NEWT_REGION)")
@@ -439,13 +420,6 @@ func runNewtMain(ctx context.Context) {
authDaemonEnabled = v
}
}
if authDaemonGenerateRandomPasswordEnv == "" {
flag.BoolVar(&authDaemonGenerateRandomPassword, "ad-generate-random-password", false, "Generate a random password for authenticated users")
} else {
if v, err := strconv.ParseBool(authDaemonGenerateRandomPasswordEnv); err == nil {
authDaemonGenerateRandomPassword = v
}
}
// do a --version check
version := flag.Bool("version", false, "Print the version")
@@ -457,15 +431,6 @@ func runNewtMain(ctx context.Context) {
tlsClientCAs = append(tlsClientCAs, tlsClientCAsFlag...)
}
if portStr != "" {
portInt, err := strconv.Atoi(portStr)
if err != nil {
logger.Warn("Failed to parse PORT, choosing a random port")
} else {
port = uint16(portInt)
}
}
if *version {
fmt.Println("Newt version " + newtVersion)
os.Exit(0)
@@ -510,14 +475,6 @@ func runNewtMain(ctx context.Context) {
if tel.PrometheusHandler != nil {
mux.Handle("/metrics", tel.PrometheusHandler)
}
if pprofEnabled {
mux.HandleFunc("/debug/pprof/", pprof.Index)
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
logger.Info("pprof debugging enabled on %s/debug/pprof/", tcfg.AdminAddr)
}
admin := &http.Server{
Addr: tcfg.AdminAddr,
Handler: otelhttp.NewHandler(mux, "newt-admin"),
@@ -598,19 +555,13 @@ func runNewtMain(ctx context.Context) {
id, // CLI arg takes precedence
secret, // CLI arg takes precedence
endpoint,
30*time.Second,
pingInterval,
pingTimeout,
opt,
websocket.WithConfigFile(configFile),
)
if err != nil {
logger.Fatal("Failed to create client: %v", err)
}
// If a provisioning key was supplied via CLI / env and the config file did
// not already carry one, inject it now so provisionIfNeeded() can use it.
if provisioningKey != "" && client.GetConfig().ProvisioningKey == "" {
client.GetConfig().ProvisioningKey = provisioningKey
}
endpoint = client.GetConfig().Endpoint // Update endpoint from config
id = client.GetConfig().ID // Update ID from config
// Update site labels for metrics with the resolved ID
@@ -996,7 +947,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
"publicKey": publicKey.String(),
"pingResults": pingResults,
"newtVersion": newtVersion,
}, 2*time.Second)
}, 1*time.Second)
return
}
@@ -1099,7 +1050,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
"publicKey": publicKey.String(),
"pingResults": pingResults,
"newtVersion": newtVersion,
}, 2*time.Second)
}, 1*time.Second)
logger.Debug("Sent exit node ping results to cloud for selection: pingResults=%+v", pingResults)
})
@@ -1204,153 +1155,6 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
}
})
// Register handler for syncing targets (TCP, UDP, and health checks)
client.RegisterHandler("newt/sync", func(msg websocket.WSMessage) {
logger.Info("Received sync message")
// if there is no wgData or pm, we can't sync targets
if wgData.TunnelIP == "" || pm == nil {
logger.Info(msgNoTunnelOrProxy)
return
}
// Define the sync data structure
type SyncData struct {
Targets TargetsByType `json:"targets"`
HealthCheckTargets []healthcheck.Config `json:"healthCheckTargets"`
}
var syncData SyncData
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling sync data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &syncData); err != nil {
logger.Error("Error unmarshaling sync data: %v", err)
return
}
logger.Debug("Sync data received: TCP targets=%d, UDP targets=%d, health check targets=%d",
len(syncData.Targets.TCP), len(syncData.Targets.UDP), len(syncData.HealthCheckTargets))
//TODO: TEST AND IMPLEMENT THIS
// // Build sets of desired targets (port -> target string)
// desiredTCP := make(map[int]string)
// for _, t := range syncData.Targets.TCP {
// parts := strings.Split(t, ":")
// if len(parts) != 3 {
// logger.Warn("Invalid TCP target format: %s", t)
// continue
// }
// port := 0
// if _, err := fmt.Sscanf(parts[0], "%d", &port); err != nil {
// logger.Warn("Invalid port in TCP target: %s", parts[0])
// continue
// }
// desiredTCP[port] = parts[1] + ":" + parts[2]
// }
// desiredUDP := make(map[int]string)
// for _, t := range syncData.Targets.UDP {
// parts := strings.Split(t, ":")
// if len(parts) != 3 {
// logger.Warn("Invalid UDP target format: %s", t)
// continue
// }
// port := 0
// if _, err := fmt.Sscanf(parts[0], "%d", &port); err != nil {
// logger.Warn("Invalid port in UDP target: %s", parts[0])
// continue
// }
// desiredUDP[port] = parts[1] + ":" + parts[2]
// }
// // Get current targets from proxy manager
// currentTCP, currentUDP := pm.GetTargets()
// // Sync TCP targets
// // Remove TCP targets not in desired set
// if tcpForIP, ok := currentTCP[wgData.TunnelIP]; ok {
// for port := range tcpForIP {
// if _, exists := desiredTCP[port]; !exists {
// logger.Info("Sync: removing TCP target on port %d", port)
// targetStr := fmt.Sprintf("%d:%s", port, tcpForIP[port])
// updateTargets(pm, "remove", wgData.TunnelIP, "tcp", TargetData{Targets: []string{targetStr}})
// }
// }
// }
// // Add TCP targets that are missing
// for port, target := range desiredTCP {
// needsAdd := true
// if tcpForIP, ok := currentTCP[wgData.TunnelIP]; ok {
// if currentTarget, exists := tcpForIP[port]; exists {
// // Check if target address changed
// if currentTarget == target {
// needsAdd = false
// } else {
// // Target changed, remove old one first
// logger.Info("Sync: updating TCP target on port %d", port)
// targetStr := fmt.Sprintf("%d:%s", port, currentTarget)
// updateTargets(pm, "remove", wgData.TunnelIP, "tcp", TargetData{Targets: []string{targetStr}})
// }
// }
// }
// if needsAdd {
// logger.Info("Sync: adding TCP target on port %d -> %s", port, target)
// targetStr := fmt.Sprintf("%d:%s", port, target)
// updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: []string{targetStr}})
// }
// }
// // Sync UDP targets
// // Remove UDP targets not in desired set
// if udpForIP, ok := currentUDP[wgData.TunnelIP]; ok {
// for port := range udpForIP {
// if _, exists := desiredUDP[port]; !exists {
// logger.Info("Sync: removing UDP target on port %d", port)
// targetStr := fmt.Sprintf("%d:%s", port, udpForIP[port])
// updateTargets(pm, "remove", wgData.TunnelIP, "udp", TargetData{Targets: []string{targetStr}})
// }
// }
// }
// // Add UDP targets that are missing
// for port, target := range desiredUDP {
// needsAdd := true
// if udpForIP, ok := currentUDP[wgData.TunnelIP]; ok {
// if currentTarget, exists := udpForIP[port]; exists {
// // Check if target address changed
// if currentTarget == target {
// needsAdd = false
// } else {
// // Target changed, remove old one first
// logger.Info("Sync: updating UDP target on port %d", port)
// targetStr := fmt.Sprintf("%d:%s", port, currentTarget)
// updateTargets(pm, "remove", wgData.TunnelIP, "udp", TargetData{Targets: []string{targetStr}})
// }
// }
// }
// if needsAdd {
// logger.Info("Sync: adding UDP target on port %d -> %s", port, target)
// targetStr := fmt.Sprintf("%d:%s", port, target)
// updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: []string{targetStr}})
// }
// }
// // Sync health check targets
// if err := healthMonitor.SyncTargets(syncData.HealthCheckTargets); err != nil {
// logger.Error("Failed to sync health check targets: %v", err)
// } else {
// logger.Info("Successfully synced health check targets")
// }
logger.Info("Sync complete")
})
// Register handler for Docker socket check
client.RegisterHandler("newt/socket/check", func(msg websocket.WSMessage) {
logger.Debug("Received Docker socket check request")
@@ -1577,15 +1381,12 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
MessageId int `json:"messageId"`
AgentPort int `json:"agentPort"`
AgentHost string `json:"agentHost"`
ExternalAuthDaemon bool `json:"externalAuthDaemon"`
CACert string `json:"caCert"`
Username string `json:"username"`
NiceID string `json:"niceId"`
Metadata struct {
SudoMode string `json:"sudoMode"`
SudoCommands []string `json:"sudoCommands"`
Sudo bool `json:"sudo"`
Homedir bool `json:"homedir"`
Groups []string `json:"groups"`
} `json:"metadata"`
}
@@ -1605,7 +1406,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
}
// Check if we're running the auth daemon internally
if authDaemonServer != nil && !certData.ExternalAuthDaemon { // if the auth daemon is running internally and the external auth daemon is not enabled
if authDaemonServer != nil {
// Call ProcessConnection directly when running internally
logger.Debug("Calling internal auth daemon ProcessConnection for user %s", certData.Username)
@@ -1614,10 +1415,8 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
NiceId: certData.NiceID,
Username: certData.Username,
Metadata: authdaemon.ConnectionMetadata{
SudoMode: certData.Metadata.SudoMode,
SudoCommands: certData.Metadata.SudoCommands,
Sudo: certData.Metadata.Sudo,
Homedir: certData.Metadata.Homedir,
Groups: certData.Metadata.Groups,
},
})
@@ -1651,10 +1450,8 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
"niceId": certData.NiceID,
"username": certData.Username,
"metadata": map[string]interface{}{
"sudoMode": certData.Metadata.SudoMode,
"sudoCommands": certData.Metadata.SudoCommands,
"sudo": certData.Metadata.Sudo,
"homedir": certData.Metadata.Homedir,
"groups": certData.Metadata.Groups,
},
}
@@ -1833,8 +1630,6 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
pm.Stop()
}
client.SendMessage("newt/disconnecting", map[string]any{})
if client != nil {
client.Close()
}

View File

@@ -1,514 +0,0 @@
package netstack2
import (
"bytes"
"compress/zlib"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
"net"
"sort"
"sync"
"time"
"github.com/fosrl/newt/logger"
)
const (
    // flushInterval is how often the access logger flushes completed
    // sessions to the server, independent of buffer pressure.
    flushInterval = 60 * time.Second
    // maxBufferedSessions is the max number of completed sessions to buffer
    // before forcing an immediate flush (instead of waiting for the ticker).
    maxBufferedSessions = 100
    // sessionGapThreshold is the maximum gap between the end of one connection
    // and the start of the next for them to be considered part of the same session.
    // If the gap exceeds this, a new consolidated session is created.
    sessionGapThreshold = 5 * time.Second
    // minConnectionsToConsolidate is the minimum number of connections in a group
    // before we bother consolidating. Groups smaller than this are sent as-is.
    minConnectionsToConsolidate = 2
)

// SendFunc is a callback that sends compressed access log data to the server.
// The data is a base64-encoded zlib-compressed JSON array of AccessSession objects.
type SendFunc func(data string) error

// AccessSession represents a tracked access session through the proxy.
type AccessSession struct {
    SessionID  string    `json:"sessionId"`
    ResourceID int       `json:"resourceId"`
    SourceAddr string    `json:"sourceAddr"` // client ip:port; ip only after cross-port consolidation
    DestAddr   string    `json:"destAddr"`
    Protocol   string    `json:"protocol"` // "tcp" or "udp"
    StartedAt  time.Time `json:"startedAt"`
    EndedAt    time.Time `json:"endedAt,omitempty"`
    BytesTx    int64     `json:"bytesTx"`
    BytesRx    int64     `json:"bytesRx"`
    // ConnectionCount is the number of raw connections merged into this
    // session (0 or 1 = single, un-consolidated connection).
    ConnectionCount int `json:"connectionCount,omitempty"`
}

// udpSessionKey identifies a unique UDP "session" by src -> dst.
type udpSessionKey struct {
    srcAddr  string
    dstAddr  string
    protocol string
}

// consolidationKey groups connections that may be part of the same logical session.
// Source port is intentionally excluded so that many ephemeral-port connections
// from the same source IP to the same destination are grouped together.
type consolidationKey struct {
    sourceIP   string // IP only, no port
    destAddr   string // full host:port of the destination
    protocol   string
    resourceID int
}

// AccessLogger tracks access sessions for resources and periodically
// flushes completed sessions to the server via a configurable SendFunc.
// All mutable fields are guarded by mu.
type AccessLogger struct {
    mu                sync.Mutex
    sessions          map[string]*AccessSession        // active sessions: sessionID -> session
    udpSessions       map[udpSessionKey]*AccessSession // active UDP sessions for dedup
    completedSessions []*AccessSession                 // completed sessions waiting to be flushed
    udpTimeout        time.Duration
    sendFn            SendFunc
    stopCh            chan struct{}
    flushDone         chan struct{} // closed after the flush goroutine exits
}
// NewAccessLogger creates a new access logger and starts its background
// flush/reap goroutine.
// udpTimeout controls how long a UDP session is kept alive without traffic before being ended.
// NOTE(review): udpTimeout is only stored here; the visible reaper uses a
// fixed 5-minute threshold (see reapStaleSessions) — confirm intended use.
func NewAccessLogger(udpTimeout time.Duration) *AccessLogger {
    al := &AccessLogger{
        sessions:          make(map[string]*AccessSession),
        udpSessions:       make(map[udpSessionKey]*AccessSession),
        completedSessions: make([]*AccessSession, 0),
        udpTimeout:        udpTimeout,
        stopCh:            make(chan struct{}),
        flushDone:         make(chan struct{}),
    }
    go al.backgroundLoop()
    return al
}
// SetSendFunc installs the callback used to ship compressed access-log
// batches to the server. It can be called after construction, once the
// websocket client becomes available.
func (al *AccessLogger) SetSendFunc(fn SendFunc) {
    al.mu.Lock()
    al.sendFn = fn
    al.mu.Unlock()
}
// generateSessionID creates a random session identifier
func generateSessionID() string {
b := make([]byte, 8)
rand.Read(b)
return hex.EncodeToString(b)
}
// StartTCPSession logs the start of a TCP session and returns a session ID.
// The session stays in the active set until EndTCPSession is called with the
// returned ID (or until Close ends everything at shutdown).
func (al *AccessLogger) StartTCPSession(resourceID int, srcAddr, dstAddr string) string {
    sessionID := generateSessionID()
    now := time.Now()
    session := &AccessSession{
        SessionID:  sessionID,
        ResourceID: resourceID,
        SourceAddr: srcAddr,
        DestAddr:   dstAddr,
        Protocol:   "tcp",
        StartedAt:  now,
    }
    al.mu.Lock()
    al.sessions[sessionID] = session
    al.mu.Unlock()
    // Log outside the lock; only local values are read here.
    logger.Info("ACCESS START session=%s resource=%d proto=tcp src=%s dst=%s time=%s",
        sessionID, resourceID, srcAddr, dstAddr, now.Format(time.RFC3339))
    return sessionID
}
// EndTCPSession marks a TCP session as ended, moves it to the completed
// buffer, and forces a flush when the buffer reaches maxBufferedSessions.
// Unknown session IDs are ignored.
func (al *AccessLogger) EndTCPSession(sessionID string) {
    now := time.Now()
    al.mu.Lock()
    session, ok := al.sessions[sessionID]
    if ok {
        session.EndedAt = now
        delete(al.sessions, sessionID)
        al.completedSessions = append(al.completedSessions, session)
    }
    // Decide on flushing while holding the lock; flush() re-acquires the
    // mutex, so it must only be called after unlocking below.
    shouldFlush := len(al.completedSessions) >= maxBufferedSessions
    al.mu.Unlock()
    if ok {
        duration := now.Sub(session.StartedAt)
        logger.Info("ACCESS END session=%s resource=%d proto=tcp src=%s dst=%s started=%s ended=%s duration=%s",
            sessionID, session.ResourceID, session.SourceAddr, session.DestAddr,
            session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration)
    }
    if shouldFlush {
        al.flush()
    }
}
// TrackUDPSession starts or returns an existing UDP session for the
// src -> dst pair. UDP has no connection teardown, so sessions are deduped
// by flow key; they end via EndUDPSession, the stale reaper, or Close.
// Returns the session ID.
func (al *AccessLogger) TrackUDPSession(resourceID int, srcAddr, dstAddr string) string {
    key := udpSessionKey{
        srcAddr:  srcAddr,
        dstAddr:  dstAddr,
        protocol: "udp",
    }
    al.mu.Lock()
    defer al.mu.Unlock()
    // Reuse the live session for this flow if one already exists.
    if existing, ok := al.udpSessions[key]; ok {
        return existing.SessionID
    }
    sessionID := generateSessionID()
    now := time.Now()
    session := &AccessSession{
        SessionID:  sessionID,
        ResourceID: resourceID,
        SourceAddr: srcAddr,
        DestAddr:   dstAddr,
        Protocol:   "udp",
        StartedAt:  now,
    }
    // Index by both session ID (for EndUDPSession) and flow key (for dedup).
    al.sessions[sessionID] = session
    al.udpSessions[key] = session
    logger.Info("ACCESS START session=%s resource=%d proto=udp src=%s dst=%s time=%s",
        sessionID, resourceID, srcAddr, dstAddr, now.Format(time.RFC3339))
    return sessionID
}
// EndUDPSession ends a UDP session, clears both the session-ID and flow-key
// indexes, and queues the session for sending. Unknown IDs are ignored; a
// flush is forced when the completed buffer reaches maxBufferedSessions.
func (al *AccessLogger) EndUDPSession(sessionID string) {
    now := time.Now()
    al.mu.Lock()
    session, ok := al.sessions[sessionID]
    if ok {
        session.EndedAt = now
        delete(al.sessions, sessionID)
        // Rebuild the flow key from the session so the dedup entry is
        // removed and a future packet for the same flow starts fresh.
        key := udpSessionKey{
            srcAddr:  session.SourceAddr,
            dstAddr:  session.DestAddr,
            protocol: "udp",
        }
        delete(al.udpSessions, key)
        al.completedSessions = append(al.completedSessions, session)
    }
    shouldFlush := len(al.completedSessions) >= maxBufferedSessions
    al.mu.Unlock()
    if ok {
        duration := now.Sub(session.StartedAt)
        logger.Info("ACCESS END session=%s resource=%d proto=udp src=%s dst=%s started=%s ended=%s duration=%s",
            sessionID, session.ResourceID, session.SourceAddr, session.DestAddr,
            session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration)
    }
    if shouldFlush {
        al.flush()
    }
}
// backgroundLoop runs until stopCh is closed, periodically flushing the
// completed-session buffer (every flushInterval) and reaping stale UDP
// sessions (every 30s). flushDone is closed on exit so Close can wait for
// this goroutine before performing its final flush.
func (al *AccessLogger) backgroundLoop() {
    defer close(al.flushDone)
    flushTicker := time.NewTicker(flushInterval)
    defer flushTicker.Stop()
    reapTicker := time.NewTicker(30 * time.Second)
    defer reapTicker.Stop()
    for {
        select {
        case <-al.stopCh:
            return
        case <-flushTicker.C:
            al.flush()
        case <-reapTicker.C:
            al.reapStaleSessions()
        }
    }
}
// reapStaleSessions force-ends UDP sessions that were never explicitly
// ended: any open UDP session started more than 5 minutes ago is marked
// ended now and queued for the next flush.
// NOTE(review): staleness is judged from StartedAt, not last activity — a
// long-lived but still-active UDP flow is reaped after 5 minutes; confirm
// this is intended (no per-session activity timestamp is tracked here).
func (al *AccessLogger) reapStaleSessions() {
    al.mu.Lock()
    defer al.mu.Unlock()
    staleThreshold := time.Now().Add(-5 * time.Minute)
    for key, session := range al.udpSessions {
        if session.StartedAt.Before(staleThreshold) && session.EndedAt.IsZero() {
            now := time.Now()
            session.EndedAt = now
            duration := now.Sub(session.StartedAt)
            logger.Info("ACCESS END (reaped) session=%s resource=%d proto=udp src=%s dst=%s started=%s ended=%s duration=%s",
                session.SessionID, session.ResourceID, session.SourceAddr, session.DestAddr,
                session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration)
            al.completedSessions = append(al.completedSessions, session)
            delete(al.sessions, session.SessionID)
            delete(al.udpSessions, key)
        }
    }
}
// extractIP returns just the host portion of a "host:port" address string.
// Addresses without a port component (or unparsable ones) are returned
// unchanged.
func extractIP(addr string) string {
    if host, _, err := net.SplitHostPort(addr); err == nil {
        return host
    }
    // No port present — treat the whole string as the host/IP.
    return addr
}
// consolidateSessions takes a slice of completed sessions and merges bursts of
// short-lived connections from the same source IP to the same destination into
// single higher-level session entries.
//
// The algorithm:
//  1. Group sessions by (sourceIP, destAddr, protocol, resourceID).
//  2. Within each group, sort by StartedAt.
//  3. Walk through the sorted list and merge consecutive sessions whose gap
//     (previous EndedAt → next StartedAt) is ≤ sessionGapThreshold.
//  4. For merged sessions the earliest StartedAt and latest EndedAt are kept,
//     bytes are summed, and ConnectionCount records how many raw connections
//     were folded in. If the merged connections used more than one source port,
//     SourceAddr is set to just the IP (port omitted).
//  5. Groups with fewer than minConnectionsToConsolidate members are passed
//     through unmodified.
//
// The relative ordering of the returned sessions is unspecified, since
// groups are emitted in map-iteration order.
func consolidateSessions(sessions []*AccessSession) []*AccessSession {
    if len(sessions) <= 1 {
        return sessions
    }
    // Group sessions by consolidation key
    groups := make(map[consolidationKey][]*AccessSession)
    for _, s := range sessions {
        key := consolidationKey{
            sourceIP:   extractIP(s.SourceAddr),
            destAddr:   s.DestAddr,
            protocol:   s.Protocol,
            resourceID: s.ResourceID,
        }
        groups[key] = append(groups[key], s)
    }
    result := make([]*AccessSession, 0, len(sessions))
    for key, group := range groups {
        // Small groups don't need consolidation
        if len(group) < minConnectionsToConsolidate {
            result = append(result, group...)
            continue
        }
        // Sort the group by start time so we can detect gaps
        sort.Slice(group, func(i, j int) bool {
            return group[i].StartedAt.Before(group[j].StartedAt)
        })
        // Walk through and merge runs that are within the gap threshold.
        // cur accumulates the in-progress consolidated session; sourcePorts
        // tracks distinct source addresses folded into it.
        var merged []*AccessSession
        cur := cloneSession(group[0])
        cur.ConnectionCount = 1
        sourcePorts := make(map[string]struct{})
        sourcePorts[cur.SourceAddr] = struct{}{}
        for i := 1; i < len(group); i++ {
            s := group[i]
            // Determine the gap: from the latest end time we've seen so far to the
            // start of the next connection.
            gapRef := cur.EndedAt
            if gapRef.IsZero() {
                gapRef = cur.StartedAt
            }
            gap := s.StartedAt.Sub(gapRef)
            if gap <= sessionGapThreshold {
                // Merge into the current consolidated session
                cur.ConnectionCount++
                cur.BytesTx += s.BytesTx
                cur.BytesRx += s.BytesRx
                sourcePorts[s.SourceAddr] = struct{}{}
                // Extend EndedAt to the latest time; a zero EndedAt falls
                // back to the session's start time.
                endTime := s.EndedAt
                if endTime.IsZero() {
                    endTime = s.StartedAt
                }
                if endTime.After(cur.EndedAt) {
                    cur.EndedAt = endTime
                }
            } else {
                // Gap exceeded — finalize the current session and start a new one
                finalizeMergedSourceAddr(cur, key.sourceIP, sourcePorts)
                merged = append(merged, cur)
                cur = cloneSession(s)
                cur.ConnectionCount = 1
                sourcePorts = make(map[string]struct{})
                sourcePorts[s.SourceAddr] = struct{}{}
            }
        }
        // Finalize the last accumulated session
        finalizeMergedSourceAddr(cur, key.sourceIP, sourcePorts)
        merged = append(merged, cur)
        result = append(result, merged...)
    }
    return result
}
// cloneSession returns a shallow copy of s: all field values are duplicated
// but the returned pointer is distinct from the original.
func cloneSession(s *AccessSession) *AccessSession {
    dup := new(AccessSession)
    *dup = *s
    return dup
}
// finalizeMergedSourceAddr decides the SourceAddr of a consolidated session.
// When the merged connections came from more than one source address (i.e.
// multiple ports), the port is dropped and only the IP is reported so the
// log isn't misleading; otherwise the original ip:port value is kept.
func finalizeMergedSourceAddr(s *AccessSession, sourceIP string, ports map[string]struct{}) {
    if len(ports) <= 1 {
        // Single source address — the existing ip:port is still accurate.
        return
    }
    s.SourceAddr = sourceIP
}
// flush drains the completed sessions buffer, consolidates bursts of
// short-lived connections, compresses with zlib, and sends via the SendFunc.
// On send failure the batch is re-queued (capped at 5x maxBufferedSessions)
// so data survives transient server outages.
func (al *AccessLogger) flush() {
    al.mu.Lock()
    if len(al.completedSessions) == 0 {
        al.mu.Unlock()
        return
    }
    // Take ownership of the batch and snapshot the send func, then release
    // the lock so session start/end calls aren't blocked during compression.
    batch := al.completedSessions
    al.completedSessions = make([]*AccessSession, 0)
    sendFn := al.sendFn
    al.mu.Unlock()
    if sendFn == nil {
        logger.Debug("Access logger: no send function configured, discarding %d sessions", len(batch))
        return
    }
    // Consolidate bursts of short-lived connections into higher-level sessions
    originalCount := len(batch)
    batch = consolidateSessions(batch)
    if len(batch) != originalCount {
        logger.Info("Access logger: consolidated %d raw connections into %d sessions", originalCount, len(batch))
    }
    compressed, err := compressSessions(batch)
    if err != nil {
        logger.Error("Access logger: failed to compress %d sessions: %v", len(batch), err)
        return
    }
    if err := sendFn(compressed); err != nil {
        logger.Error("Access logger: failed to send %d sessions: %v", len(batch), err)
        // Re-queue the batch so we don't lose data
        al.mu.Lock()
        al.completedSessions = append(batch, al.completedSessions...)
        // Cap re-queued data to prevent unbounded growth if server is unreachable
        // NOTE(review): truncation keeps the front of the slice, so the
        // newest tail entries are dropped despite the "oldest" wording in
        // the log message below — confirm which is intended.
        if len(al.completedSessions) > maxBufferedSessions*5 {
            dropped := len(al.completedSessions) - maxBufferedSessions*5
            al.completedSessions = al.completedSessions[:maxBufferedSessions*5]
            logger.Warn("Access logger: buffer overflow, dropped %d oldest sessions", dropped)
        }
        al.mu.Unlock()
        return
    }
    logger.Info("Access logger: sent %d sessions to server", len(batch))
}
// compressSessions JSON-encodes the sessions, zlib-compresses the result at
// best compression, and returns it base64-encoded so it can be embedded in a
// JSON message.
func compressSessions(sessions []*AccessSession) (string, error) {
    payload, err := json.Marshal(sessions)
    if err != nil {
        return "", err
    }
    var out bytes.Buffer
    zw, err := zlib.NewWriterLevel(&out, zlib.BestCompression)
    if err != nil {
        return "", err
    }
    // Always close the writer so partial state is released; write errors
    // take precedence over close errors, matching the original behavior.
    _, writeErr := zw.Write(payload)
    closeErr := zw.Close()
    if writeErr != nil {
        return "", writeErr
    }
    if closeErr != nil {
        return "", closeErr
    }
    return base64.StdEncoding.EncodeToString(out.Bytes()), nil
}
// Close shuts down the background loop, ends all active sessions,
// and performs one final flush to send everything to the server.
// A second sequential Close returns immediately (stopCh already closed).
func (al *AccessLogger) Close() {
    // Signal the background loop to stop
    select {
    case <-al.stopCh:
        // Already closed
        return
    default:
        close(al.stopCh)
    }
    // Wait for the background loop to exit so we don't race on flush
    <-al.flushDone
    al.mu.Lock()
    now := time.Now()
    // End all active sessions and move them to the completed buffer
    for _, session := range al.sessions {
        if session.EndedAt.IsZero() {
            session.EndedAt = now
            duration := now.Sub(session.StartedAt)
            logger.Info("ACCESS END (shutdown) session=%s resource=%d proto=%s src=%s dst=%s started=%s ended=%s duration=%s",
                session.SessionID, session.ResourceID, session.Protocol, session.SourceAddr, session.DestAddr,
                session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration)
            al.completedSessions = append(al.completedSessions, session)
        }
    }
    // Reset both indexes; nothing is active after shutdown.
    al.sessions = make(map[string]*AccessSession)
    al.udpSessions = make(map[udpSessionKey]*AccessSession)
    al.mu.Unlock()
    // Final flush to send all remaining sessions to the server
    al.flush()
}

View File

@@ -1,811 +0,0 @@
package netstack2
import (
"testing"
"time"
)
// TestExtractIP checks that extractIP returns only the IP portion of an
// address string, covering IPv4/IPv6 with and without a port and empty input.
func TestExtractIP(t *testing.T) {
    tests := []struct {
        name     string
        addr     string
        expected string
    }{
        {"ipv4 with port", "192.168.1.1:12345", "192.168.1.1"},
        {"ipv4 without port", "192.168.1.1", "192.168.1.1"},
        {"ipv6 with port", "[::1]:12345", "::1"},
        {"ipv6 without port", "::1", "::1"},
        {"empty string", "", ""},
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            result := extractIP(tt.addr)
            if result != tt.expected {
                t.Errorf("extractIP(%q) = %q, want %q", tt.addr, result, tt.expected)
            }
        })
    }
}

// TestConsolidateSessions_Empty verifies nil and empty inputs produce no
// output sessions.
func TestConsolidateSessions_Empty(t *testing.T) {
    result := consolidateSessions(nil)
    if result != nil {
        t.Errorf("expected nil, got %v", result)
    }
    result = consolidateSessions([]*AccessSession{})
    if len(result) != 0 {
        t.Errorf("expected empty slice, got %d items", len(result))
    }
}

// TestConsolidateSessions_SingleSession verifies a lone session passes
// through consolidation with its source address (including port) untouched.
func TestConsolidateSessions_SingleSession(t *testing.T) {
    now := time.Now()
    sessions := []*AccessSession{
        {
            SessionID:  "abc123",
            ResourceID: 1,
            SourceAddr: "10.0.0.1:5000",
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now,
            EndedAt:    now.Add(1 * time.Second),
        },
    }
    result := consolidateSessions(sessions)
    if len(result) != 1 {
        t.Fatalf("expected 1 session, got %d", len(result))
    }
    if result[0].SourceAddr != "10.0.0.1:5000" {
        t.Errorf("expected source addr preserved, got %q", result[0].SourceAddr)
    }
}
// TestConsolidateSessions_MergesBurstFromSameSourceIP verifies that three
// short connections from the same source IP (different ports) to the same
// destination merge into one session: counts, byte totals, and the time span
// are aggregated, and the source collapses to the bare IP.
func TestConsolidateSessions_MergesBurstFromSameSourceIP(t *testing.T) {
    now := time.Now()
    sessions := []*AccessSession{
        {
            SessionID:  "s1",
            ResourceID: 1,
            SourceAddr: "10.0.0.1:5000",
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now,
            EndedAt:    now.Add(100 * time.Millisecond),
            BytesTx:    100,
            BytesRx:    200,
        },
        {
            SessionID:  "s2",
            ResourceID: 1,
            SourceAddr: "10.0.0.1:5001",
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now.Add(200 * time.Millisecond),
            EndedAt:    now.Add(300 * time.Millisecond),
            BytesTx:    150,
            BytesRx:    250,
        },
        {
            SessionID:  "s3",
            ResourceID: 1,
            SourceAddr: "10.0.0.1:5002",
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now.Add(400 * time.Millisecond),
            EndedAt:    now.Add(500 * time.Millisecond),
            BytesTx:    50,
            BytesRx:    75,
        },
    }
    result := consolidateSessions(sessions)
    if len(result) != 1 {
        t.Fatalf("expected 1 consolidated session, got %d", len(result))
    }
    s := result[0]
    if s.ConnectionCount != 3 {
        t.Errorf("expected ConnectionCount=3, got %d", s.ConnectionCount)
    }
    if s.SourceAddr != "10.0.0.1" {
        t.Errorf("expected source addr to be IP only (multiple ports), got %q", s.SourceAddr)
    }
    if s.DestAddr != "192.168.1.100:443" {
        t.Errorf("expected dest addr preserved, got %q", s.DestAddr)
    }
    if s.StartedAt != now {
        t.Errorf("expected StartedAt to be earliest time")
    }
    if s.EndedAt != now.Add(500*time.Millisecond) {
        t.Errorf("expected EndedAt to be latest time")
    }
    // Byte counters should be the sums across all merged connections.
    expectedTx := int64(300)
    expectedRx := int64(525)
    if s.BytesTx != expectedTx {
        t.Errorf("expected BytesTx=%d, got %d", expectedTx, s.BytesTx)
    }
    if s.BytesRx != expectedRx {
        t.Errorf("expected BytesRx=%d, got %d", expectedRx, s.BytesRx)
    }
}

// TestConsolidateSessions_SameSourcePortPreserved verifies that when every
// merged connection shares the same source port, the full host:port form is
// kept instead of collapsing to the bare IP.
func TestConsolidateSessions_SameSourcePortPreserved(t *testing.T) {
    now := time.Now()
    sessions := []*AccessSession{
        {
            SessionID:  "s1",
            ResourceID: 1,
            SourceAddr: "10.0.0.1:5000",
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now,
            EndedAt:    now.Add(100 * time.Millisecond),
        },
        {
            SessionID:  "s2",
            ResourceID: 1,
            SourceAddr: "10.0.0.1:5000",
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now.Add(200 * time.Millisecond),
            EndedAt:    now.Add(300 * time.Millisecond),
        },
    }
    result := consolidateSessions(sessions)
    if len(result) != 1 {
        t.Fatalf("expected 1 session, got %d", len(result))
    }
    if result[0].SourceAddr != "10.0.0.1:5000" {
        t.Errorf("expected source addr with port preserved when all ports are the same, got %q", result[0].SourceAddr)
    }
    if result[0].ConnectionCount != 2 {
        t.Errorf("expected ConnectionCount=2, got %d", result[0].ConnectionCount)
    }
}
// TestConsolidateSessions_GapSplitsSessions verifies that an idle gap larger
// than the consolidation threshold splits connections from the same
// source/destination pair into two separate consolidated sessions.
func TestConsolidateSessions_GapSplitsSessions(t *testing.T) {
    now := time.Now()
    // First burst
    sessions := []*AccessSession{
        {
            SessionID:  "s1",
            ResourceID: 1,
            SourceAddr: "10.0.0.1:5000",
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now,
            EndedAt:    now.Add(100 * time.Millisecond),
        },
        {
            SessionID:  "s2",
            ResourceID: 1,
            SourceAddr: "10.0.0.1:5001",
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now.Add(200 * time.Millisecond),
            EndedAt:    now.Add(300 * time.Millisecond),
        },
        // Big gap here (10 seconds)
        {
            SessionID:  "s3",
            ResourceID: 1,
            SourceAddr: "10.0.0.1:5002",
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now.Add(10 * time.Second),
            EndedAt:    now.Add(10*time.Second + 100*time.Millisecond),
        },
        {
            SessionID:  "s4",
            ResourceID: 1,
            SourceAddr: "10.0.0.1:5003",
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now.Add(10*time.Second + 200*time.Millisecond),
            EndedAt:    now.Add(10*time.Second + 300*time.Millisecond),
        },
    }
    result := consolidateSessions(sessions)
    if len(result) != 2 {
        t.Fatalf("expected 2 consolidated sessions (gap split), got %d", len(result))
    }
    // Find the sessions by their start time
    var first, second *AccessSession
    for _, s := range result {
        if s.StartedAt.Equal(now) {
            first = s
        } else {
            second = s
        }
    }
    if first == nil || second == nil {
        t.Fatal("could not find both consolidated sessions")
    }
    if first.ConnectionCount != 2 {
        t.Errorf("first burst: expected ConnectionCount=2, got %d", first.ConnectionCount)
    }
    if second.ConnectionCount != 2 {
        t.Errorf("second burst: expected ConnectionCount=2, got %d", second.ConnectionCount)
    }
}
// TestConsolidateSessions_DifferentDestinationsNotMerged verifies that
// connections to different destination ports stay separate.
func TestConsolidateSessions_DifferentDestinationsNotMerged(t *testing.T) {
    now := time.Now()
    sessions := []*AccessSession{
        {
            SessionID:  "s1",
            ResourceID: 1,
            SourceAddr: "10.0.0.1:5000",
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now,
            EndedAt:    now.Add(100 * time.Millisecond),
        },
        {
            SessionID:  "s2",
            ResourceID: 1,
            SourceAddr: "10.0.0.1:5001",
            DestAddr:   "192.168.1.100:8080",
            Protocol:   "tcp",
            StartedAt:  now.Add(200 * time.Millisecond),
            EndedAt:    now.Add(300 * time.Millisecond),
        },
    }
    result := consolidateSessions(sessions)
    // Each goes to a different dest port so they should not be merged
    if len(result) != 2 {
        t.Fatalf("expected 2 sessions (different destinations), got %d", len(result))
    }
}

// TestConsolidateSessions_DifferentProtocolsNotMerged verifies that TCP and
// UDP connections are never merged, even with matching endpoints.
func TestConsolidateSessions_DifferentProtocolsNotMerged(t *testing.T) {
    now := time.Now()
    sessions := []*AccessSession{
        {
            SessionID:  "s1",
            ResourceID: 1,
            SourceAddr: "10.0.0.1:5000",
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now,
            EndedAt:    now.Add(100 * time.Millisecond),
        },
        {
            SessionID:  "s2",
            ResourceID: 1,
            SourceAddr: "10.0.0.1:5001",
            DestAddr:   "192.168.1.100:443",
            Protocol:   "udp",
            StartedAt:  now.Add(200 * time.Millisecond),
            EndedAt:    now.Add(300 * time.Millisecond),
        },
    }
    result := consolidateSessions(sessions)
    if len(result) != 2 {
        t.Fatalf("expected 2 sessions (different protocols), got %d", len(result))
    }
}

// TestConsolidateSessions_DifferentResourceIDsNotMerged verifies that
// connections belonging to different resources stay separate.
func TestConsolidateSessions_DifferentResourceIDsNotMerged(t *testing.T) {
    now := time.Now()
    sessions := []*AccessSession{
        {
            SessionID:  "s1",
            ResourceID: 1,
            SourceAddr: "10.0.0.1:5000",
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now,
            EndedAt:    now.Add(100 * time.Millisecond),
        },
        {
            SessionID:  "s2",
            ResourceID: 2,
            SourceAddr: "10.0.0.1:5001",
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now.Add(200 * time.Millisecond),
            EndedAt:    now.Add(300 * time.Millisecond),
        },
    }
    result := consolidateSessions(sessions)
    if len(result) != 2 {
        t.Fatalf("expected 2 sessions (different resource IDs), got %d", len(result))
    }
}

// TestConsolidateSessions_DifferentSourceIPsNotMerged verifies that
// connections from different source IPs stay separate.
func TestConsolidateSessions_DifferentSourceIPsNotMerged(t *testing.T) {
    now := time.Now()
    sessions := []*AccessSession{
        {
            SessionID:  "s1",
            ResourceID: 1,
            SourceAddr: "10.0.0.1:5000",
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now,
            EndedAt:    now.Add(100 * time.Millisecond),
        },
        {
            SessionID:  "s2",
            ResourceID: 1,
            SourceAddr: "10.0.0.2:5001",
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now.Add(200 * time.Millisecond),
            EndedAt:    now.Add(300 * time.Millisecond),
        },
    }
    result := consolidateSessions(sessions)
    if len(result) != 2 {
        t.Fatalf("expected 2 sessions (different source IPs), got %d", len(result))
    }
}
// TestConsolidateSessions_OutOfOrderInput verifies that consolidation sorts
// its input: sessions supplied out of chronological order still merge into
// one session with correct time bounds and byte totals.
func TestConsolidateSessions_OutOfOrderInput(t *testing.T) {
    now := time.Now()
    // Provide sessions out of chronological order to verify sorting
    sessions := []*AccessSession{
        {
            SessionID:  "s3",
            ResourceID: 1,
            SourceAddr: "10.0.0.1:5002",
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now.Add(400 * time.Millisecond),
            EndedAt:    now.Add(500 * time.Millisecond),
            BytesTx:    30,
        },
        {
            SessionID:  "s1",
            ResourceID: 1,
            SourceAddr: "10.0.0.1:5000",
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now,
            EndedAt:    now.Add(100 * time.Millisecond),
            BytesTx:    10,
        },
        {
            SessionID:  "s2",
            ResourceID: 1,
            SourceAddr: "10.0.0.1:5001",
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now.Add(200 * time.Millisecond),
            EndedAt:    now.Add(300 * time.Millisecond),
            BytesTx:    20,
        },
    }
    result := consolidateSessions(sessions)
    if len(result) != 1 {
        t.Fatalf("expected 1 consolidated session, got %d", len(result))
    }
    s := result[0]
    if s.ConnectionCount != 3 {
        t.Errorf("expected ConnectionCount=3, got %d", s.ConnectionCount)
    }
    if s.StartedAt != now {
        t.Errorf("expected StartedAt to be earliest time")
    }
    if s.EndedAt != now.Add(500*time.Millisecond) {
        t.Errorf("expected EndedAt to be latest time")
    }
    if s.BytesTx != 60 {
        t.Errorf("expected BytesTx=60, got %d", s.BytesTx)
    }
}

// TestConsolidateSessions_ExactlyAtGapThreshold verifies the boundary case:
// a gap of exactly sessionGapThreshold still merges (threshold is inclusive).
func TestConsolidateSessions_ExactlyAtGapThreshold(t *testing.T) {
    now := time.Now()
    sessions := []*AccessSession{
        {
            SessionID:  "s1",
            ResourceID: 1,
            SourceAddr: "10.0.0.1:5000",
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now,
            EndedAt:    now.Add(100 * time.Millisecond),
        },
        {
            // Starts exactly sessionGapThreshold after s1 ends — should still merge
            SessionID:  "s2",
            ResourceID: 1,
            SourceAddr: "10.0.0.1:5001",
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now.Add(100*time.Millisecond + sessionGapThreshold),
            EndedAt:    now.Add(100*time.Millisecond + sessionGapThreshold + 50*time.Millisecond),
        },
    }
    result := consolidateSessions(sessions)
    if len(result) != 1 {
        t.Fatalf("expected 1 session (gap exactly at threshold merges), got %d", len(result))
    }
    if result[0].ConnectionCount != 2 {
        t.Errorf("expected ConnectionCount=2, got %d", result[0].ConnectionCount)
    }
}

// TestConsolidateSessions_JustOverGapThreshold verifies the complementary
// boundary case: one millisecond past the threshold splits the sessions.
func TestConsolidateSessions_JustOverGapThreshold(t *testing.T) {
    now := time.Now()
    sessions := []*AccessSession{
        {
            SessionID:  "s1",
            ResourceID: 1,
            SourceAddr: "10.0.0.1:5000",
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now,
            EndedAt:    now.Add(100 * time.Millisecond),
        },
        {
            // Starts 1ms over the gap threshold after s1 ends — should split
            SessionID:  "s2",
            ResourceID: 1,
            SourceAddr: "10.0.0.1:5001",
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now.Add(100*time.Millisecond + sessionGapThreshold + 1*time.Millisecond),
            EndedAt:    now.Add(100*time.Millisecond + sessionGapThreshold + 50*time.Millisecond),
        },
    }
    result := consolidateSessions(sessions)
    if len(result) != 2 {
        t.Fatalf("expected 2 sessions (gap just over threshold splits), got %d", len(result))
    }
}
// TestConsolidateSessions_UDPSessions verifies UDP bursts consolidate the
// same way TCP does: protocol preserved, counts and byte totals aggregated,
// source collapsed to the bare IP when ports differ.
func TestConsolidateSessions_UDPSessions(t *testing.T) {
    now := time.Now()
    sessions := []*AccessSession{
        {
            SessionID:  "u1",
            ResourceID: 5,
            SourceAddr: "10.0.0.1:6000",
            DestAddr:   "192.168.1.100:53",
            Protocol:   "udp",
            StartedAt:  now,
            EndedAt:    now.Add(50 * time.Millisecond),
            BytesTx:    64,
            BytesRx:    512,
        },
        {
            SessionID:  "u2",
            ResourceID: 5,
            SourceAddr: "10.0.0.1:6001",
            DestAddr:   "192.168.1.100:53",
            Protocol:   "udp",
            StartedAt:  now.Add(100 * time.Millisecond),
            EndedAt:    now.Add(150 * time.Millisecond),
            BytesTx:    64,
            BytesRx:    256,
        },
        {
            SessionID:  "u3",
            ResourceID: 5,
            SourceAddr: "10.0.0.1:6002",
            DestAddr:   "192.168.1.100:53",
            Protocol:   "udp",
            StartedAt:  now.Add(200 * time.Millisecond),
            EndedAt:    now.Add(250 * time.Millisecond),
            BytesTx:    64,
            BytesRx:    128,
        },
    }
    result := consolidateSessions(sessions)
    if len(result) != 1 {
        t.Fatalf("expected 1 consolidated UDP session, got %d", len(result))
    }
    s := result[0]
    if s.Protocol != "udp" {
        t.Errorf("expected protocol=udp, got %q", s.Protocol)
    }
    if s.ConnectionCount != 3 {
        t.Errorf("expected ConnectionCount=3, got %d", s.ConnectionCount)
    }
    if s.SourceAddr != "10.0.0.1" {
        t.Errorf("expected source addr to be IP only, got %q", s.SourceAddr)
    }
    if s.BytesTx != 192 {
        t.Errorf("expected BytesTx=192, got %d", s.BytesTx)
    }
    if s.BytesRx != 896 {
        t.Errorf("expected BytesRx=896, got %d", s.BytesRx)
    }
}

// TestConsolidateSessions_MixedGroupsSomeConsolidatedSomeNot verifies that a
// mergeable burst and an unrelated single connection coexist in the output:
// the burst consolidates, the singleton passes through unchanged.
func TestConsolidateSessions_MixedGroupsSomeConsolidatedSomeNot(t *testing.T) {
    now := time.Now()
    sessions := []*AccessSession{
        // Group 1: 3 connections to :443 from same IP — should consolidate
        {
            SessionID:  "s1",
            ResourceID: 1,
            SourceAddr: "10.0.0.1:5000",
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now,
            EndedAt:    now.Add(100 * time.Millisecond),
        },
        {
            SessionID:  "s2",
            ResourceID: 1,
            SourceAddr: "10.0.0.1:5001",
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now.Add(200 * time.Millisecond),
            EndedAt:    now.Add(300 * time.Millisecond),
        },
        {
            SessionID:  "s3",
            ResourceID: 1,
            SourceAddr: "10.0.0.1:5002",
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now.Add(400 * time.Millisecond),
            EndedAt:    now.Add(500 * time.Millisecond),
        },
        // Group 2: 1 connection to :8080 from different IP — should pass through
        {
            SessionID:  "s4",
            ResourceID: 2,
            SourceAddr: "10.0.0.2:6000",
            DestAddr:   "192.168.1.200:8080",
            Protocol:   "tcp",
            StartedAt:  now.Add(1 * time.Second),
            EndedAt:    now.Add(2 * time.Second),
        },
    }
    result := consolidateSessions(sessions)
    if len(result) != 2 {
        t.Fatalf("expected 2 sessions total, got %d", len(result))
    }
    var consolidated, passthrough *AccessSession
    for _, s := range result {
        if s.ConnectionCount > 1 {
            consolidated = s
        } else {
            passthrough = s
        }
    }
    if consolidated == nil {
        t.Fatal("expected a consolidated session")
    }
    if consolidated.ConnectionCount != 3 {
        t.Errorf("consolidated: expected ConnectionCount=3, got %d", consolidated.ConnectionCount)
    }
    if passthrough == nil {
        t.Fatal("expected a passthrough session")
    }
    if passthrough.SessionID != "s4" {
        t.Errorf("passthrough: expected session s4, got %s", passthrough.SessionID)
    }
}
// TestConsolidateSessions_OverlappingConnections verifies that connections
// whose lifetimes overlap (rather than run back-to-back) still merge, with
// the merged span covering the earliest start to the latest end.
func TestConsolidateSessions_OverlappingConnections(t *testing.T) {
    now := time.Now()
    // Connections that overlap in time (not sequential)
    sessions := []*AccessSession{
        {
            SessionID:  "s1",
            ResourceID: 1,
            SourceAddr: "10.0.0.1:5000",
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now,
            EndedAt:    now.Add(5 * time.Second),
            BytesTx:    100,
        },
        {
            SessionID:  "s2",
            ResourceID: 1,
            SourceAddr: "10.0.0.1:5001",
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now.Add(1 * time.Second),
            EndedAt:    now.Add(3 * time.Second),
            BytesTx:    200,
        },
        {
            SessionID:  "s3",
            ResourceID: 1,
            SourceAddr: "10.0.0.1:5002",
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now.Add(2 * time.Second),
            EndedAt:    now.Add(6 * time.Second),
            BytesTx:    300,
        },
    }
    result := consolidateSessions(sessions)
    if len(result) != 1 {
        t.Fatalf("expected 1 consolidated session, got %d", len(result))
    }
    s := result[0]
    if s.ConnectionCount != 3 {
        t.Errorf("expected ConnectionCount=3, got %d", s.ConnectionCount)
    }
    if s.StartedAt != now {
        t.Error("expected StartedAt to be earliest")
    }
    if s.EndedAt != now.Add(6*time.Second) {
        t.Error("expected EndedAt to be the latest end time")
    }
    if s.BytesTx != 600 {
        t.Errorf("expected BytesTx=600, got %d", s.BytesTx)
    }
}

// TestConsolidateSessions_DoesNotMutateOriginals verifies consolidation works
// on copies: the caller's session structs keep their original addresses and
// byte counters after the call.
func TestConsolidateSessions_DoesNotMutateOriginals(t *testing.T) {
    now := time.Now()
    s1 := &AccessSession{
        SessionID:  "s1",
        ResourceID: 1,
        SourceAddr: "10.0.0.1:5000",
        DestAddr:   "192.168.1.100:443",
        Protocol:   "tcp",
        StartedAt:  now,
        EndedAt:    now.Add(100 * time.Millisecond),
        BytesTx:    100,
    }
    s2 := &AccessSession{
        SessionID:  "s2",
        ResourceID: 1,
        SourceAddr: "10.0.0.1:5001",
        DestAddr:   "192.168.1.100:443",
        Protocol:   "tcp",
        StartedAt:  now.Add(200 * time.Millisecond),
        EndedAt:    now.Add(300 * time.Millisecond),
        BytesTx:    200,
    }
    // Save original values
    origS1Addr := s1.SourceAddr
    origS1Bytes := s1.BytesTx
    origS2Addr := s2.SourceAddr
    origS2Bytes := s2.BytesTx
    _ = consolidateSessions([]*AccessSession{s1, s2})
    if s1.SourceAddr != origS1Addr {
        t.Errorf("s1.SourceAddr was mutated: %q -> %q", origS1Addr, s1.SourceAddr)
    }
    if s1.BytesTx != origS1Bytes {
        t.Errorf("s1.BytesTx was mutated: %d -> %d", origS1Bytes, s1.BytesTx)
    }
    if s2.SourceAddr != origS2Addr {
        t.Errorf("s2.SourceAddr was mutated: %q -> %q", origS2Addr, s2.SourceAddr)
    }
    if s2.BytesTx != origS2Bytes {
        t.Errorf("s2.BytesTx was mutated: %d -> %d", origS2Bytes, s2.BytesTx)
    }
}
// TestConsolidateSessions_ThreeBurstsWithGaps verifies that three bursts of
// three connections each, separated by 20s idle gaps, consolidate into
// exactly three sessions of three connections apiece.
func TestConsolidateSessions_ThreeBurstsWithGaps(t *testing.T) {
    now := time.Now()
    sessions := make([]*AccessSession, 0, 9)
    // Burst 1: 3 connections at t=0
    for i := 0; i < 3; i++ {
        sessions = append(sessions, &AccessSession{
            SessionID:  generateSessionID(),
            ResourceID: 1,
            // string(rune('A'+i)) yields "A"/"B"/"C" — not numeric ports, but
            // distinct strings are all consolidation needs for port variety.
            SourceAddr: "10.0.0.1:" + string(rune('A'+i)),
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now.Add(time.Duration(i*100) * time.Millisecond),
            EndedAt:    now.Add(time.Duration(i*100+50) * time.Millisecond),
        })
    }
    // Burst 2: 3 connections at t=20s (well past the 5s gap)
    for i := 0; i < 3; i++ {
        sessions = append(sessions, &AccessSession{
            SessionID:  generateSessionID(),
            ResourceID: 1,
            SourceAddr: "10.0.0.1:" + string(rune('D'+i)),
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now.Add(20*time.Second + time.Duration(i*100)*time.Millisecond),
            EndedAt:    now.Add(20*time.Second + time.Duration(i*100+50)*time.Millisecond),
        })
    }
    // Burst 3: 3 connections at t=40s
    for i := 0; i < 3; i++ {
        sessions = append(sessions, &AccessSession{
            SessionID:  generateSessionID(),
            ResourceID: 1,
            SourceAddr: "10.0.0.1:" + string(rune('G'+i)),
            DestAddr:   "192.168.1.100:443",
            Protocol:   "tcp",
            StartedAt:  now.Add(40*time.Second + time.Duration(i*100)*time.Millisecond),
            EndedAt:    now.Add(40*time.Second + time.Duration(i*100+50)*time.Millisecond),
        })
    }
    result := consolidateSessions(sessions)
    if len(result) != 3 {
        t.Fatalf("expected 3 consolidated sessions (3 bursts), got %d", len(result))
    }
    for _, s := range result {
        if s.ConnectionCount != 3 {
            t.Errorf("expected each burst to have ConnectionCount=3, got %d (started=%v)", s.ConnectionCount, s.StartedAt)
        }
    }
}

// TestFinalizeMergedSourceAddr verifies the source-address finalization rule
// directly: one distinct source address keeps host:port, several collapse to
// the bare IP.
func TestFinalizeMergedSourceAddr(t *testing.T) {
    s := &AccessSession{SourceAddr: "10.0.0.1:5000"}
    ports := map[string]struct{}{"10.0.0.1:5000": {}}
    finalizeMergedSourceAddr(s, "10.0.0.1", ports)
    if s.SourceAddr != "10.0.0.1:5000" {
        t.Errorf("single port: expected addr preserved, got %q", s.SourceAddr)
    }
    s2 := &AccessSession{SourceAddr: "10.0.0.1:5000"}
    ports2 := map[string]struct{}{"10.0.0.1:5000": {}, "10.0.0.1:5001": {}}
    finalizeMergedSourceAddr(s2, "10.0.0.1", ports2)
    if s2.SourceAddr != "10.0.0.1" {
        t.Errorf("multiple ports: expected IP only, got %q", s2.SourceAddr)
    }
}

// TestCloneSession verifies cloneSession returns an independent copy: a new
// pointer with equal fields, whose mutation does not affect the original.
func TestCloneSession(t *testing.T) {
    original := &AccessSession{
        SessionID:  "test",
        ResourceID: 42,
        SourceAddr: "1.2.3.4:100",
        DestAddr:   "5.6.7.8:443",
        Protocol:   "tcp",
        BytesTx:    999,
    }
    clone := cloneSession(original)
    if clone == original {
        t.Error("clone should be a different pointer")
    }
    if clone.SessionID != original.SessionID {
        t.Error("clone should have same SessionID")
    }
    // Mutating clone should not affect original
    clone.BytesTx = 0
    clone.SourceAddr = "changed"
    if original.BytesTx != 999 {
        t.Error("mutating clone affected original BytesTx")
    }
    if original.SourceAddr != "1.2.3.4:100" {
        t.Error("mutating clone affected original SourceAddr")
    }
}

View File

@@ -158,18 +158,6 @@ func (h *TCPHandler) handleTCPConn(netstackConn *gonet.TCPConn, id stack.Transpo
targetAddr := fmt.Sprintf("%s:%d", actualDstIP, dstPort)
// Look up resource ID and start access session if applicable
var accessSessionID string
if h.proxyHandler != nil {
resourceId := h.proxyHandler.LookupResourceId(srcIP, dstIP, dstPort, uint8(tcp.ProtocolNumber))
if resourceId != 0 {
if al := h.proxyHandler.GetAccessLogger(); al != nil {
srcAddr := fmt.Sprintf("%s:%d", srcIP, srcPort)
accessSessionID = al.StartTCPSession(resourceId, srcAddr, targetAddr)
}
}
}
// Create context with timeout for connection establishment
ctx, cancel := context.WithTimeout(context.Background(), tcpConnectTimeout)
defer cancel()
@@ -179,26 +167,11 @@ func (h *TCPHandler) handleTCPConn(netstackConn *gonet.TCPConn, id stack.Transpo
targetConn, err := d.DialContext(ctx, "tcp", targetAddr)
if err != nil {
logger.Info("TCP Forwarder: Failed to connect to %s: %v", targetAddr, err)
// End access session on connection failure
if accessSessionID != "" {
if al := h.proxyHandler.GetAccessLogger(); al != nil {
al.EndTCPSession(accessSessionID)
}
}
// Connection failed, netstack will handle RST
return
}
defer targetConn.Close()
// End access session when connection closes
if accessSessionID != "" {
defer func() {
if al := h.proxyHandler.GetAccessLogger(); al != nil {
al.EndTCPSession(accessSessionID)
}
}()
}
logger.Info("TCP Forwarder: Successfully connected to %s, starting bidirectional copy", targetAddr)
// Bidirectional copy between netstack and target
@@ -307,27 +280,6 @@ func (h *UDPHandler) handleUDPConn(netstackConn *gonet.UDPConn, id stack.Transpo
targetAddr := fmt.Sprintf("%s:%d", actualDstIP, dstPort)
// Look up resource ID and start access session if applicable
var accessSessionID string
if h.proxyHandler != nil {
resourceId := h.proxyHandler.LookupResourceId(srcIP, dstIP, dstPort, uint8(udp.ProtocolNumber))
if resourceId != 0 {
if al := h.proxyHandler.GetAccessLogger(); al != nil {
srcAddr := fmt.Sprintf("%s:%d", srcIP, srcPort)
accessSessionID = al.TrackUDPSession(resourceId, srcAddr, targetAddr)
}
}
}
// End access session when UDP handler returns (timeout or error)
if accessSessionID != "" {
defer func() {
if al := h.proxyHandler.GetAccessLogger(); al != nil {
al.EndUDPSession(accessSessionID)
}
}()
}
// Resolve target address
remoteUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr)
if err != nil {

View File

@@ -22,12 +22,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
const (
// udpAccessSessionTimeout is how long a UDP access session stays alive without traffic
// before being considered ended by the access logger
udpAccessSessionTimeout = 120 * time.Second
)
// PortRange represents an allowed range of ports (inclusive) with optional protocol filtering
// Protocol can be "tcp", "udp", or "" (empty string means both protocols)
type PortRange struct {
@@ -52,24 +46,115 @@ type SubnetRule struct {
DisableIcmp bool // If true, ICMP traffic is blocked for this subnet
RewriteTo string // Optional rewrite address for DNAT - can be IP/CIDR or domain name
PortRanges []PortRange // empty slice means all ports allowed
ResourceId int // Optional resource ID from the server for access logging
}
// GetAllRules returns a copy of all subnet rules
func (sl *SubnetLookup) GetAllRules() []SubnetRule {
// ruleKey is used as a map key for fast O(1) lookups
type ruleKey struct {
sourcePrefix string
destPrefix string
}
// SubnetLookup provides fast IP subnet and port matching with O(1) lookup performance
type SubnetLookup struct {
mu sync.RWMutex
rules map[ruleKey]*SubnetRule // Map for O(1) lookups by prefix combination
}
// NewSubnetLookup creates a new subnet lookup table
func NewSubnetLookup() *SubnetLookup {
return &SubnetLookup{
rules: make(map[ruleKey]*SubnetRule),
}
}
// AddSubnet adds a subnet rule with source and destination prefixes and optional port restrictions
// If portRanges is nil or empty, all ports are allowed for this subnet
// rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com")
func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) {
sl.mu.Lock()
defer sl.mu.Unlock()
key := ruleKey{
sourcePrefix: sourcePrefix.String(),
destPrefix: destPrefix.String(),
}
sl.rules[key] = &SubnetRule{
SourcePrefix: sourcePrefix,
DestPrefix: destPrefix,
DisableIcmp: disableIcmp,
RewriteTo: rewriteTo,
PortRanges: portRanges,
}
}
// RemoveSubnet removes a subnet rule from the lookup table
func (sl *SubnetLookup) RemoveSubnet(sourcePrefix, destPrefix netip.Prefix) {
sl.mu.Lock()
defer sl.mu.Unlock()
key := ruleKey{
sourcePrefix: sourcePrefix.String(),
destPrefix: destPrefix.String(),
}
delete(sl.rules, key)
}
// Match checks if a source IP, destination IP, port, and protocol match any subnet rule
// Returns the matched rule if ALL of these conditions are met:
// - The source IP is in the rule's source prefix
// - The destination IP is in the rule's destination prefix
// - The port is in an allowed range (or no port restrictions exist)
// - The protocol matches (or the port range allows both protocols)
//
// proto should be header.TCPProtocolNumber or header.UDPProtocolNumber
// Returns nil if no rule matches
func (sl *SubnetLookup) Match(srcIP, dstIP netip.Addr, port uint16, proto tcpip.TransportProtocolNumber) *SubnetRule {
sl.mu.RLock()
defer sl.mu.RUnlock()
var rules []SubnetRule
for _, destTriePtr := range sl.sourceTrie.All() {
if destTriePtr == nil {
// Iterate through all rules to find matching source and destination prefixes
// This is O(n) but necessary since we need to check prefix containment, not exact match
for _, rule := range sl.rules {
// Check if source and destination IPs match their respective prefixes
if !rule.SourcePrefix.Contains(srcIP) {
continue
}
for _, rule := range destTriePtr.rules {
rules = append(rules, *rule)
if !rule.DestPrefix.Contains(dstIP) {
continue
}
if rule.DisableIcmp && (proto == header.ICMPv4ProtocolNumber || proto == header.ICMPv6ProtocolNumber) {
// ICMP is disabled for this subnet
return nil
}
// Both IPs match - now check port restrictions
// If no port ranges specified, all ports are allowed
if len(rule.PortRanges) == 0 {
return rule
}
// Check if port and protocol are in any of the allowed ranges
for _, pr := range rule.PortRanges {
if port >= pr.Min && port <= pr.Max {
// Check protocol compatibility
if pr.Protocol == "" {
// Empty protocol means allow both TCP and UDP
return rule
}
// Check if the packet protocol matches the port range protocol
if (pr.Protocol == "tcp" && proto == header.TCPProtocolNumber) ||
(pr.Protocol == "udp" && proto == header.UDPProtocolNumber) {
return rule
}
// Port matches but protocol doesn't - continue checking other ranges
}
}
return rules
}
return nil
}
// connKey uniquely identifies a connection for NAT tracking
@@ -81,17 +166,6 @@ type connKey struct {
proto uint8
}
// reverseConnKey uniquely identifies a connection for reverse NAT lookup (reply direction)
// Key structure: (rewrittenTo, originalSrcIP, originalSrcPort, originalDstPort, proto)
// This allows O(1) lookup of NAT entries for reply packets
type reverseConnKey struct {
rewrittenTo string // The address we rewrote to (becomes src in replies)
originalSrcIP string // Original source IP (becomes dst in replies)
originalSrcPort uint16 // Original source port (becomes dst port in replies)
originalDstPort uint16 // Original destination port (becomes src port in replies)
proto uint8
}
// destKey identifies a destination for handler lookups (without source port since it may change)
type destKey struct {
srcIP string
@@ -116,14 +190,11 @@ type ProxyHandler struct {
icmpHandler *ICMPHandler
subnetLookup *SubnetLookup
natTable map[connKey]*natState
reverseNatTable map[reverseConnKey]*natState // Reverse lookup map for O(1) reply packet NAT
destRewriteTable map[destKey]netip.Addr // Maps original dest to rewritten dest for handler lookups
resourceTable map[destKey]int // Maps connection key to resource ID for access logging
natMu sync.RWMutex
enabled bool
icmpReplies chan []byte // Channel for ICMP reply packets to be sent back through the tunnel
notifiable channel.Notification // Notification handler for triggering reads
accessLogger *AccessLogger // Access logger for tracking sessions
}
// ProxyHandlerOptions configures the proxy handler
@@ -144,11 +215,8 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
enabled: true,
subnetLookup: NewSubnetLookup(),
natTable: make(map[connKey]*natState),
reverseNatTable: make(map[reverseConnKey]*natState),
destRewriteTable: make(map[destKey]netip.Addr),
resourceTable: make(map[destKey]int),
icmpReplies: make(chan []byte, 256), // Buffer for ICMP reply packets
accessLogger: NewAccessLogger(udpAccessSessionTimeout),
proxyEp: channel.New(1024, uint32(options.MTU), ""),
proxyStack: stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
@@ -213,11 +281,11 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
// destPrefix: The IP prefix of the destination
// rewriteTo: Optional address to rewrite destination to - can be IP/CIDR or domain name
// If portRanges is nil or empty, all ports are allowed for this subnet
func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool, resourceId int) {
func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) {
if p == nil || !p.enabled {
return
}
p.subnetLookup.AddSubnet(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp, resourceId)
p.subnetLookup.AddSubnet(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp)
}
// RemoveSubnetRule removes a subnet from the proxy handler
@@ -228,51 +296,6 @@ func (p *ProxyHandler) RemoveSubnetRule(sourcePrefix, destPrefix netip.Prefix) {
p.subnetLookup.RemoveSubnet(sourcePrefix, destPrefix)
}
// GetAllRules returns all subnet rules from the proxy handler
func (p *ProxyHandler) GetAllRules() []SubnetRule {
if p == nil || !p.enabled {
return nil
}
return p.subnetLookup.GetAllRules()
}
// LookupResourceId looks up the resource ID for a connection
// Returns 0 if no resource ID is associated with this connection
func (p *ProxyHandler) LookupResourceId(srcIP, dstIP string, dstPort uint16, proto uint8) int {
if p == nil || !p.enabled {
return 0
}
key := destKey{
srcIP: srcIP,
dstIP: dstIP,
dstPort: dstPort,
proto: proto,
}
p.natMu.RLock()
defer p.natMu.RUnlock()
return p.resourceTable[key]
}
// GetAccessLogger returns the access logger for session tracking
func (p *ProxyHandler) GetAccessLogger() *AccessLogger {
if p == nil {
return nil
}
return p.accessLogger
}
// SetAccessLogSender configures the function used to send compressed access log
// batches to the server. This should be called once the websocket client is available.
func (p *ProxyHandler) SetAccessLogSender(fn SendFunc) {
if p == nil || !p.enabled || p.accessLogger == nil {
return
}
p.accessLogger.SetSendFunc(fn)
}
// LookupDestinationRewrite looks up the rewritten destination for a connection
// This is used by TCP/UDP handlers to find the actual target address
func (p *ProxyHandler) LookupDestinationRewrite(srcIP, dstIP string, dstPort uint16, proto uint8) (netip.Addr, bool) {
@@ -435,22 +458,8 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
// Check if the source IP, destination IP, port, and protocol match any subnet rule
matchedRule := p.subnetLookup.Match(srcAddr, dstAddr, dstPort, protocol)
if matchedRule != nil {
logger.Debug("HandleIncomingPacket: Matched rule for %s -> %s (proto=%d, port=%d, resourceId=%d)",
srcAddr, dstAddr, protocol, dstPort, matchedRule.ResourceId)
// Store resource ID for connections without DNAT as well
if matchedRule.ResourceId != 0 && matchedRule.RewriteTo == "" {
dKey := destKey{
srcIP: srcAddr.String(),
dstIP: dstAddr.String(),
dstPort: dstPort,
proto: uint8(protocol),
}
p.natMu.Lock()
p.resourceTable[dKey] = matchedRule.ResourceId
p.natMu.Unlock()
}
logger.Debug("HandleIncomingPacket: Matched rule for %s -> %s (proto=%d, port=%d)",
srcAddr, dstAddr, protocol, dstPort)
// Check if we need to perform DNAT
if matchedRule.RewriteTo != "" {
// Create connection tracking key using original destination
@@ -482,13 +491,6 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
proto: uint8(protocol),
}
// Store resource ID for access logging if present
if matchedRule.ResourceId != 0 {
p.natMu.Lock()
p.resourceTable[dKey] = matchedRule.ResourceId
p.natMu.Unlock()
}
// Check if we already have a NAT entry for this connection
p.natMu.RLock()
existingEntry, exists := p.natTable[key]
@@ -515,23 +517,10 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
// Store NAT state for this connection
p.natMu.Lock()
natEntry := &natState{
p.natTable[key] = &natState{
originalDst: dstAddr,
rewrittenTo: newDst,
}
p.natTable[key] = natEntry
// Create reverse lookup key for O(1) reply packet lookups
// Key: (rewrittenTo, originalSrcIP, originalSrcPort, originalDstPort, proto)
reverseKey := reverseConnKey{
rewrittenTo: newDst.String(),
originalSrcIP: srcAddr.String(),
originalSrcPort: srcPort,
originalDstPort: dstPort,
proto: uint8(protocol),
}
p.reverseNatTable[reverseKey] = natEntry
// Store destination rewrite for handler lookups
p.destRewriteTable[dKey] = newDst
p.natMu.Unlock()
@@ -730,22 +719,20 @@ func (p *ProxyHandler) ReadOutgoingPacket() *buffer.View {
return view
}
// Look up NAT state for reverse translation using O(1) reverse lookup map
// Key: (rewrittenTo, originalSrcIP, originalSrcPort, originalDstPort, proto)
// For reply packets:
// - reply's srcIP = rewrittenTo (the address we rewrote to)
// - reply's dstIP = originalSrcIP (original source IP)
// - reply's srcPort = originalDstPort (original destination port)
// - reply's dstPort = originalSrcPort (original source port)
// Look up NAT state for reverse translation
// The key uses the original dst (before rewrite), so for replies we need to
// find the entry where the rewritten address matches the current source
p.natMu.RLock()
reverseKey := reverseConnKey{
rewrittenTo: srcIP.String(), // Reply's source is the rewritten address
originalSrcIP: dstIP.String(), // Reply's destination is the original source
originalSrcPort: dstPort, // Reply's destination port is the original source port
originalDstPort: srcPort, // Reply's source port is the original destination port
proto: uint8(protocol),
var natEntry *natState
for k, entry := range p.natTable {
// Match: reply's dst should be original src, reply's src should be rewritten dst
if k.srcIP == dstIP.String() && k.srcPort == dstPort &&
entry.rewrittenTo.String() == srcIP.String() && k.dstPort == srcPort &&
k.proto == uint8(protocol) {
natEntry = entry
break
}
}
natEntry := p.reverseNatTable[reverseKey]
p.natMu.RUnlock()
if natEntry != nil {
@@ -789,11 +776,6 @@ func (p *ProxyHandler) Close() error {
return nil
}
// Shut down access logger
if p.accessLogger != nil {
p.accessLogger.Close()
}
// Close ICMP replies channel
if p.icmpReplies != nil {
close(p.icmpReplies)

View File

@@ -1,207 +0,0 @@
package netstack2
import (
"net/netip"
"sync"
"github.com/gaissmai/bart"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
// SubnetLookup provides fast IP subnet and port matching using BART
// (Binary Aggregated Range Tree) tables; Supernets() gives O(log n)
// prefix lookups instead of O(n) iteration.
//
// Architecture: a two-level BART structure matches both source AND
// destination prefixes.
//   - Level 1: source prefix -> level-2 table
//   - Level 2: destination prefix -> rules
//
// Matching the source first shrinks the set of destination prefixes that
// must be examined for any given packet.
type SubnetLookup struct {
	mu sync.RWMutex // guards sourceTrie and every nested destTrie
	// Level 1 table: source prefix -> destination trie. Only destination
	// tries whose source prefix covers the packet's source address are
	// consulted, which significantly reduces the search space.
	sourceTrie *bart.Table[*destTrie]
}

// destTrie is the level-2 BART table for destination prefixes under one
// source prefix, holding the actual rules.
type destTrie struct {
	trie  *bart.Table[[]*SubnetRule]
	rules []*SubnetRule // All rules for this source prefix (for iteration if needed)
}
// NewSubnetLookup constructs an empty BART-backed subnet lookup table.
func NewSubnetLookup() *SubnetLookup {
	sl := &SubnetLookup{}
	sl.sourceTrie = &bart.Table[*destTrie]{}
	return sl
}
// prefixEqual compares two prefixes after masking to handle host bits correctly.
// For example, 10.0.0.5/24 and 10.0.0.0/24 are treated as equal.
func prefixEqual(a, b netip.Prefix) bool {
return a.Masked() == b.Masked()
}
// AddSubnet registers a subnet rule keyed by (sourcePrefix, destPrefix)
// with optional port restrictions. A nil or empty portRanges allows all
// ports. rewriteTo may be an IP/CIDR (e.g. "192.168.1.1/32") or a domain
// name (e.g. "example.com"). Re-adding the same (source, dest) pair
// overwrites the previous rule.
func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool, resourceId int) {
	sl.mu.Lock()
	defer sl.mu.Unlock()

	newRule := &SubnetRule{
		SourcePrefix: sourcePrefix,
		DestPrefix:   destPrefix,
		DisableIcmp:  disableIcmp,
		RewriteTo:    rewriteTo,
		PortRanges:   portRanges,
		ResourceId:   resourceId,
	}

	// BART masks prefixes internally; canonicalize ours so our own
	// bookkeeping agrees with what the trie stores (10.0.0.5/24 and
	// 10.0.0.0/24 must behave identically).
	srcCanon := sourcePrefix.Masked()
	dstCanon := destPrefix.Masked()

	// Find (or lazily create) the destination trie for this source prefix.
	dt, ok := sl.sourceTrie.Get(srcCanon)
	if !ok {
		dt = &destTrie{
			trie:  &bart.Table[[]*SubnetRule]{},
			rules: make([]*SubnetRule, 0),
		}
		sl.sourceTrie.Insert(srcCanon, dt)
	}

	// Store as a single-element slice: this keeps the historical
	// overwrite semantics for an identical (source, dest) pair.
	dt.trie.Insert(dstCanon, []*SubnetRule{newRule})

	// Rebuild the per-source rule list, dropping any earlier rule for the
	// same canonical (source, dest) pair before appending the new one.
	kept := make([]*SubnetRule, 0, len(dt.rules)+1)
	for _, r := range dt.rules {
		if !prefixEqual(r.DestPrefix, dstCanon) || !prefixEqual(r.SourcePrefix, srcCanon) {
			kept = append(kept, r)
		}
	}
	dt.rules = append(kept, newRule)
}
// RemoveSubnet deletes the rule registered for the exact
// (sourcePrefix, destPrefix) combination, if present. Destination tries
// left empty by the removal are pruned from the source trie.
func (sl *SubnetLookup) RemoveSubnet(sourcePrefix, destPrefix netip.Prefix) {
	sl.mu.Lock()
	defer sl.mu.Unlock()

	// Canonicalize so prefixes added with host bits set are still found.
	srcCanon := sourcePrefix.Masked()
	dstCanon := destPrefix.Masked()

	dt, ok := sl.sourceTrie.Get(srcCanon)
	if !ok {
		return
	}

	// BART masks internally, so deleting the canonical prefix removes the
	// stored entry regardless of how the caller spelled it.
	dt.trie.Delete(dstCanon)

	// Mirror the deletion in the bookkeeping slice using canonical
	// comparison for the same host-bits reason.
	kept := make([]*SubnetRule, 0, len(dt.rules))
	for _, r := range dt.rules {
		if !prefixEqual(r.DestPrefix, dstCanon) || !prefixEqual(r.SourcePrefix, srcCanon) {
			kept = append(kept, r)
		}
	}
	dt.rules = kept

	// Drop the whole destination trie once it holds no prefixes. Size()
	// is the definitive emptiness check, independent of any stale
	// bookkeeping entries.
	if dt.trie.Size() == 0 {
		sl.sourceTrie.Delete(srcCanon)
	}
}
// Match returns the first rule satisfied by the given 4-tuple, i.e. one
// where ALL of the following hold:
//   - srcIP is inside the rule's source prefix
//   - dstIP is inside the rule's destination prefix
//   - port falls in an allowed range (or the rule has no port limits)
//   - the protocol matches (or the range admits both TCP and UDP)
//
// proto should be header.TCPProtocolNumber, header.UDPProtocolNumber, or
// header.ICMPv4ProtocolNumber. Returns nil when nothing matches. Lookups
// use BART's Supernets() for O(log n) prefix matching.
func (sl *SubnetLookup) Match(srcIP, dstIP netip.Addr, port uint16, proto tcpip.TransportProtocolNumber) *SubnetRule {
	sl.mu.RLock()
	defer sl.mu.RUnlock()

	// Represent each address as its full-length prefix (/32 for IPv4,
	// /128 for IPv6) so Supernets() yields every stored prefix that
	// contains it.
	srcHost := netip.PrefixFrom(srcIP, srcIP.BitLen())
	dstHost := netip.PrefixFrom(dstIP, dstIP.BitLen())

	isICMP := proto == header.ICMPv4ProtocolNumber || proto == header.ICMPv6ProtocolNumber

	// Level 1: walk every source prefix covering srcIP.
	for _, dt := range sl.sourceTrie.Supernets(srcHost) {
		if dt == nil {
			continue
		}
		// Level 2: walk every destination prefix covering dstIP.
		for _, candidates := range dt.trie.Supernets(dstHost) {
			if candidates == nil {
				continue
			}
			for _, rule := range candidates {
				// ICMP carries no ports, so it is decided before any
				// port-range logic. NOTE(review): a DisableIcmp rule
				// aborts the whole search rather than trying later
				// rules — preserved from the original behavior.
				if isICMP {
					if rule.DisableIcmp {
						return nil
					}
					return rule
				}
				// A rule without port restrictions matches outright.
				if len(rule.PortRanges) == 0 {
					return rule
				}
				// Otherwise the port AND protocol must be admitted by
				// at least one range.
				for _, pr := range rule.PortRanges {
					if port < pr.Min || port > pr.Max {
						continue
					}
					switch pr.Protocol {
					case "":
						// Empty protocol means both TCP and UDP.
						return rule
					case "tcp":
						if proto == header.TCPProtocolNumber {
							return rule
						}
					case "udp":
						if proto == header.UDPProtocolNumber {
							return rule
						}
					}
					// Port matched but protocol did not — keep scanning
					// the remaining ranges.
				}
			}
		}
	}
	return nil
}

View File

@@ -354,10 +354,10 @@ func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
// AddProxySubnetRule adds a subnet rule to the proxy handler
// If portRanges is nil or empty, all ports are allowed for this subnet
// rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com")
func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool, resourceId int) {
func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) {
tun := (*netTun)(net)
if tun.proxyHandler != nil {
tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp, resourceId)
tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp)
}
}
@@ -369,15 +369,6 @@ func (net *Net) RemoveProxySubnetRule(sourcePrefix, destPrefix netip.Prefix) {
}
}
// GetProxySubnetRules returns all subnet rules known to the proxy
// handler, or nil when no proxy handler is configured.
func (net *Net) GetProxySubnetRules() []SubnetRule {
	tun := (*netTun)(net)
	if tun.proxyHandler == nil {
		return nil
	}
	return tun.proxyHandler.GetAllRules()
}
// GetProxyHandler returns the proxy handler (for advanced use cases)
// Returns nil if proxy is not enabled
func (net *Net) GetProxyHandler() *ProxyHandler {
@@ -385,15 +376,6 @@ func (net *Net) GetProxyHandler() *ProxyHandler {
return tun.proxyHandler
}
// SetAccessLogSender forwards the access-log send function to the proxy
// handler, if one is configured. This should be called once the websocket
// client is available.
func (net *Net) SetAccessLogSender(fn SendFunc) {
	tun := (*netTun)(net)
	if tun.proxyHandler == nil {
		return
	}
	tun.proxyHandler.SetAccessLogSender(fn)
}
type PingConn struct {
laddr PingAddr
raddr PingAddr

View File

@@ -32,7 +32,7 @@ DefaultGroupName={#MyAppName}
DisableProgramGroupPage=yes
; Uncomment the following line to run in non administrative install mode (install for current user only).
;PrivilegesRequired=lowest
OutputBaseFilename=newt_windows_installer
OutputBaseFilename=mysetup
SolidCompression=yes
WizardStyle=modern
; Add this to ensure PATH changes are applied and the system is prompted for a restart if needed

View File

@@ -21,10 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
)
const (
errUnsupportedProtoFmt = "unsupported protocol: %s"
maxUDPPacketSize = 65507
)
const errUnsupportedProtoFmt = "unsupported protocol: %s"
// Target represents a proxy target with its address and port
type Target struct {
@@ -108,10 +105,14 @@ func classifyProxyError(err error) string {
if errors.Is(err, net.ErrClosed) {
return "closed"
}
var ne net.Error
if errors.As(err, &ne) && ne.Timeout() {
if ne, ok := err.(net.Error); ok {
if ne.Timeout() {
return "timeout"
}
if ne.Temporary() {
return "temporary"
}
}
msg := strings.ToLower(err.Error())
switch {
case strings.Contains(msg, "refused"):
@@ -436,6 +437,14 @@ func (pm *ProxyManager) Stop() error {
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
}
// // Clear the target maps
// for k := range pm.tcpTargets {
// delete(pm.tcpTargets, k)
// }
// for k := range pm.udpTargets {
// delete(pm.udpTargets, k)
// }
// Give active connections a chance to close gracefully
time.Sleep(100 * time.Millisecond)
@@ -489,7 +498,7 @@ func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string)
if !pm.running {
return
}
if errors.Is(err, net.ErrClosed) {
if ne, ok := err.(net.Error); ok && !ne.Temporary() {
logger.Info("TCP listener closed, stopping proxy handler for %v", listener.Addr())
return
}
@@ -555,7 +564,7 @@ func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string)
}
func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
buffer := make([]byte, maxUDPPacketSize) // Max UDP packet size
buffer := make([]byte, 65507) // Max UDP packet size
clientConns := make(map[string]*net.UDPConn)
var clientsMutex sync.RWMutex
@@ -574,7 +583,7 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
}
// Check for connection closed conditions
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
if err == io.EOF || strings.Contains(err.Error(), "use of closed network connection") {
logger.Info("UDP connection closed, stopping proxy handler")
// Clean up existing client connections
@@ -653,14 +662,10 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
telemetry.IncProxyConnectionEvent(context.Background(), tunnelID, "udp", telemetry.ProxyConnectionClosed)
}()
buffer := make([]byte, maxUDPPacketSize)
buffer := make([]byte, 65507)
for {
n, _, err := targetConn.ReadFromUDP(buffer)
if err != nil {
// Connection closed is normal during cleanup
if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) {
return // defer will handle cleanup, result stays "success"
}
logger.Error("Error reading from target: %v", err)
result = "failure"
return // defer will handle cleanup
@@ -731,28 +736,3 @@ func (pm *ProxyManager) PrintTargets() {
}
}
}
// GetTargets returns independent copies of the current TCP and UDP target
// tables, each shaped as map[listenIP]map[port]targetAddress, so callers
// can inspect them without holding the manager's lock.
func (pm *ProxyManager) GetTargets() (tcpTargets map[string]map[int]string, udpTargets map[string]map[int]string) {
	pm.mutex.RLock()
	defer pm.mutex.RUnlock()

	// copyTable deep-copies a two-level target map.
	copyTable := func(src map[string]map[int]string) map[string]map[int]string {
		out := make(map[string]map[int]string, len(src))
		for listenIP, ports := range src {
			inner := make(map[int]string, len(ports))
			for port, target := range ports {
				inner[port] = target
			}
			out[listenIP] = inner
		}
		return out
	}

	return copyTable(pm.tcpTargets), copyTable(pm.udpTargets)
}

11
scripts/nfpm.yaml.tmpl Normal file
View File

@@ -0,0 +1,11 @@
# nfpm package template for Newt .deb builds.
# The __PKG_NAME__, __ARCH__, and __VERSION__ placeholders are substituted
# by scripts/publish-apt.sh (via sed) before `nfpm package` runs.
name: __PKG_NAME__
arch: __ARCH__
platform: linux
# Quoted so a substituted value like "1.20" stays a string instead of being
# read as the YAML float 1.2.
version: "__VERSION__"
section: net
priority: optional
maintainer: fosrl
description: Newt - userspace tunnel client and TCP/UDP proxy
contents:
  - src: build/newt
    dst: /usr/bin/newt

139
scripts/publish-apt.sh Normal file
View File

@@ -0,0 +1,139 @@
#!/usr/bin/env bash
# Publish Newt .deb packages from GitHub releases to an APT repository
# hosted on S3 behind CloudFront.
#
# Flow: import the signing key, choose the release tag(s) to process, pull
# the existing repo from S3 (so older versions are kept), build a .deb per
# architecture with nfpm, regenerate and sign the APT metadata, then sync
# the tree back to S3 and invalidate the CloudFront cache.
set -euo pipefail

# ---- required env ----
: "${GH_REPO:?}"
: "${S3_BUCKET:?}"
: "${AWS_REGION:?}"
: "${CLOUDFRONT_DISTRIBUTION_ID:?}"
: "${PKG_NAME:?}"
: "${SUITE:?}"
: "${COMPONENT:?}"
: "${APT_GPG_PRIVATE_KEY:?}"

# Optional prefix inside the bucket; normalize so it ends with "/" when set.
S3_PREFIX="${S3_PREFIX:-}"
if [[ -n "${S3_PREFIX}" && "${S3_PREFIX}" != */ ]]; then
  S3_PREFIX="${S3_PREFIX}/"
fi

# Public base URL reported at the end. Derive a default from the bucket when
# the caller does not provide one — referencing it while unset would abort
# the script under `set -u` after all the work had already succeeded.
REPO_BASE_URL="${REPO_BASE_URL:-https://${S3_BUCKET}/${S3_PREFIX}apt}"

WORKDIR="$(pwd)"
mkdir -p repo/apt assets build

# Import the signing key and resolve its key ID.
echo "${APT_GPG_PRIVATE_KEY}" | gpg --batch --import >/dev/null 2>&1 || true
KEYID="$(gpg --list-secret-keys --with-colons | awk -F: '$1=="sec"{print $5; exit}')"
if [[ -z "${KEYID}" ]]; then
  echo "ERROR: No GPG secret key available after import."
  exit 1
fi

# Determine which tags to process.
TAGS=""
if [[ "${BACKFILL_ALL:-false}" == "true" ]]; then
  echo "Backfill mode: collecting all release tags..."
  TAGS="$(gh release list -R "${GH_REPO}" --limit 200 --json tagName --jq '.[].tagName')"
else
  if [[ -n "${INPUT_TAG:-}" ]]; then
    TAGS="${INPUT_TAG}"
  elif [[ -n "${EVENT_TAG:-}" ]]; then
    TAGS="${EVENT_TAG}"
  else
    echo "No tag provided; using latest release tag..."
    TAGS="$(gh release view -R "${GH_REPO}" --json tagName --jq '.tagName')"
  fi
fi

echo "Tags to process:"
printf '%s\n' "${TAGS}"

# Pull existing repo from S3 so we keep older versions.
echo "Sync existing repo from S3..."
aws s3 sync "s3://${S3_BUCKET}/${S3_PREFIX}apt/" repo/apt/ >/dev/null 2>&1 || true

# Build and add packages for each tag.
while IFS= read -r TAG; do
  [[ -z "${TAG}" ]] && continue
  echo "=== Processing tag: ${TAG} ==="

  rm -rf assets build
  mkdir -p assets build

  gh release download "${TAG}" -R "${GH_REPO}" -p "newt_linux_amd64" -D assets
  gh release download "${TAG}" -R "${GH_REPO}" -p "newt_linux_arm64" -D assets

  VERSION="${TAG#v}"

  for arch in amd64 arm64; do
    bin="assets/newt_linux_${arch}"
    if [[ ! -f "${bin}" ]]; then
      echo "ERROR: Missing release asset: ${bin}"
      exit 1
    fi
    install -Dm755 "${bin}" "build/newt"

    # Create nfpm config from template file (no heredoc here).
    sed \
      -e "s/__PKG_NAME__/${PKG_NAME}/g" \
      -e "s/__ARCH__/${arch}/g" \
      -e "s/__VERSION__/${VERSION}/g" \
      scripts/nfpm.yaml.tmpl > nfpm.yaml

    nfpm package -p deb -f nfpm.yaml -t "build/${PKG_NAME}_${VERSION}_${arch}.deb"
  done

  mkdir -p "repo/apt/pool/${COMPONENT}/${PKG_NAME:0:1}/${PKG_NAME}/"
  cp -v build/*.deb "repo/apt/pool/${COMPONENT}/${PKG_NAME:0:1}/${PKG_NAME}/"
done <<< "${TAGS}"

# Regenerate APT metadata (Packages / Packages.gz per architecture).
# Use absolute paths so each stage is independent of prior `cd`s.
cd "${WORKDIR}/repo/apt"
for arch in amd64 arm64; do
  mkdir -p "dists/${SUITE}/${COMPONENT}/binary-${arch}"
  dpkg-scanpackages -a "${arch}" pool > "dists/${SUITE}/${COMPONENT}/binary-${arch}/Packages"
  gzip -fk "dists/${SUITE}/${COMPONENT}/binary-${arch}/Packages"
done

# Release file with hashes.
cat > apt-ftparchive.conf <<EOF
APT::FTPArchive::Release::Origin "fosrl";
APT::FTPArchive::Release::Label "newt";
APT::FTPArchive::Release::Suite "${SUITE}";
APT::FTPArchive::Release::Codename "${SUITE}";
APT::FTPArchive::Release::Architectures "amd64 arm64";
APT::FTPArchive::Release::Components "${COMPONENT}";
APT::FTPArchive::Release::Description "Newt APT repository";
EOF
apt-ftparchive -c apt-ftparchive.conf release "dists/${SUITE}" > "dists/${SUITE}/Release"

# Sign Release (inline InRelease + detached Release.gpg).
cd "${WORKDIR}/repo/apt/dists/${SUITE}"
gpg --batch --yes --pinentry-mode loopback \
  ${APT_GPG_PASSPHRASE:+--passphrase "${APT_GPG_PASSPHRASE}"} \
  --local-user "${KEYID}" \
  --clearsign -o InRelease Release
gpg --batch --yes --pinentry-mode loopback \
  ${APT_GPG_PASSPHRASE:+--passphrase "${APT_GPG_PASSPHRASE}"} \
  --local-user "${KEYID}" \
  -abs -o Release.gpg Release

# Export the public key INTO the apt tree so it gets uploaded and is
# reachable at .../apt/public.key. (Writing it to the workdir root would
# leave it outside the synced tree, so the CloudFront invalidation below
# would reference a file that never reached S3.)
cd "${WORKDIR}"
gpg --batch --yes --armor --export "${KEYID}" > repo/apt/public.key

# Upload to S3.
echo "Uploading to S3..."
aws s3 sync "${WORKDIR}/repo/apt" "s3://${S3_BUCKET}/${S3_PREFIX}apt/" --delete

# Invalidate the metadata paths so clients see the new release promptly.
echo "CloudFront invalidation..."
aws cloudfront create-invalidation \
  --distribution-id "${CLOUDFRONT_DISTRIBUTION_ID}" \
  --paths "/${S3_PREFIX}apt/dists/*" "/${S3_PREFIX}apt/public.key"

echo "Done. Repo base: ${REPO_BASE_URL}"

View File

@@ -2,7 +2,6 @@ package websocket
import (
"bytes"
"compress/gzip"
"crypto/tls"
"crypto/x509"
"encoding/json"
@@ -38,21 +37,16 @@ type Client struct {
isConnected bool
reconnectMux sync.RWMutex
pingInterval time.Duration
pingTimeout time.Duration
onConnect func() error
onTokenUpdate func(token string)
writeMux sync.Mutex
clientType string // Type of client (e.g., "newt", "olm")
configFilePath string // Optional override for the config file path
tlsConfig TLSConfig
metricsCtxMu sync.RWMutex
metricsCtx context.Context
configNeedsSave bool // Flag to track if config needs to be saved
serverVersion string
configVersion int64 // Latest config version received from server
configVersionMux sync.RWMutex
processingMessage bool // Flag to track if a message is currently being processed
processingMux sync.RWMutex // Protects processingMessage
processingWg sync.WaitGroup // WaitGroup to wait for message processing to complete
}
type ClientOption func(*Client)
@@ -78,12 +72,6 @@ func WithBaseURL(url string) ClientOption {
}
// WithTLSConfig sets the TLS configuration for the client
func WithConfigFile(path string) ClientOption {
return func(c *Client) {
c.configFilePath = path
}
}
func WithTLSConfig(config TLSConfig) ClientOption {
return func(c *Client) {
c.tlsConfig = config
@@ -123,7 +111,7 @@ func (c *Client) MetricsContext() context.Context {
}
// NewClient creates a new websocket client
func NewClient(clientType string, ID, secret string, endpoint string, pingInterval time.Duration, opts ...ClientOption) (*Client, error) {
func NewClient(clientType string, ID, secret string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) {
config := &Config{
ID: ID,
Secret: secret,
@@ -138,6 +126,7 @@ func NewClient(clientType string, ID, secret string, endpoint string, pingInterv
reconnectInterval: 3 * time.Second,
isConnected: false,
pingInterval: pingInterval,
pingTimeout: pingTimeout,
clientType: clientType,
}
@@ -165,20 +154,6 @@ func (c *Client) GetServerVersion() string {
return c.serverVersion
}
// GetConfigVersion returns the latest config version received from server
func (c *Client) GetConfigVersion() int64 {
c.configVersionMux.RLock()
defer c.configVersionMux.RUnlock()
return c.configVersion
}
// setConfigVersion updates the config version
func (c *Client) setConfigVersion(version int64) {
c.configVersionMux.Lock()
defer c.configVersionMux.Unlock()
c.configVersion = version
}
// Connect establishes the WebSocket connection
func (c *Client) Connect() error {
go c.connectWithRetry()
@@ -488,11 +463,6 @@ func (c *Client) connectWithRetry() {
func (c *Client) establishConnection() error {
ctx := context.Background()
// Exchange provisioning key for permanent credentials if needed.
if err := c.provisionIfNeeded(); err != nil {
return fmt.Errorf("failed to provision newt credentials: %w", err)
}
// Get token for authentication
token, err := c.getToken()
if err != nil {
@@ -671,37 +641,24 @@ func (c *Client) setupPKCS12TLS() (*tls.Config, error) {
}
// pingMonitor sends pings at a short interval and triggers reconnect on failure
func (c *Client) sendPing() {
func (c *Client) pingMonitor() {
ticker := time.NewTicker(c.pingInterval)
defer ticker.Stop()
for {
select {
case <-c.done:
return
case <-ticker.C:
if c.conn == nil {
return
}
// Skip ping if a message is currently being processed
c.processingMux.RLock()
isProcessing := c.processingMessage
c.processingMux.RUnlock()
if isProcessing {
logger.Debug("Skipping ping, message is being processed")
return
}
c.configVersionMux.RLock()
configVersion := c.configVersion
c.configVersionMux.RUnlock()
pingMsg := WSMessage{
Type: "newt/ping",
Data: map[string]interface{}{},
ConfigVersion: configVersion,
}
c.writeMux.Lock()
err := c.conn.WriteJSON(pingMsg)
err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout))
if err == nil {
telemetry.IncWSMessage(c.metricsContext(), "out", "ping")
}
c.writeMux.Unlock()
if err != nil {
// Check if we're shutting down before logging error and reconnecting
select {
@@ -717,21 +674,6 @@ func (c *Client) sendPing() {
}
}
}
func (c *Client) pingMonitor() {
// Send an immediate ping as soon as we connect
c.sendPing()
ticker := time.NewTicker(c.pingInterval)
defer ticker.Stop()
for {
select {
case <-c.done:
return
case <-ticker.C:
c.sendPing()
}
}
}
@@ -767,14 +709,11 @@ func (c *Client) readPumpWithDisconnectDetection(started time.Time) {
disconnectResult = "success"
return
default:
msgType, p, err := c.conn.ReadMessage()
var msg WSMessage
err := c.conn.ReadJSON(&msg)
if err == nil {
if msgType == websocket.BinaryMessage {
telemetry.IncWSMessage(c.metricsContext(), "in", "binary")
} else {
telemetry.IncWSMessage(c.metricsContext(), "in", "text")
}
}
if err != nil {
// Check if we're shutting down before logging error
select {
@@ -798,47 +737,9 @@ func (c *Client) readPumpWithDisconnectDetection(started time.Time) {
}
}
// Update config version from incoming message
var data []byte
if msgType == websocket.BinaryMessage {
gr, err := gzip.NewReader(bytes.NewReader(p))
if err != nil {
logger.Error("WebSocket failed to create gzip reader: %v", err)
continue
}
data, err = io.ReadAll(gr)
gr.Close()
if err != nil {
logger.Error("WebSocket failed to decompress message: %v", err)
continue
}
} else {
data = p
}
var msg WSMessage
if err = json.Unmarshal(data, &msg); err != nil {
logger.Error("WebSocket failed to parse message: %v", err)
continue
}
c.setConfigVersion(msg.ConfigVersion)
c.handlersMux.RLock()
if handler, ok := c.handlers[msg.Type]; ok {
// Mark that we're processing a message
c.processingMux.Lock()
c.processingMessage = true
c.processingMux.Unlock()
c.processingWg.Add(1)
handler(msg)
// Mark that we're done processing
c.processingWg.Done()
c.processingMux.Lock()
c.processingMessage = false
c.processingMux.Unlock()
}
c.handlersMux.RUnlock()
}

View File

@@ -1,28 +1,16 @@
package websocket
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/url"
"os"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/fosrl/newt/logger"
)
func getConfigPath(clientType string, overridePath string) string {
if overridePath != "" {
return overridePath
}
func getConfigPath(clientType string) string {
configFile := os.Getenv("CONFIG_FILE")
if configFile == "" {
var configDir string
@@ -48,7 +36,7 @@ func getConfigPath(clientType string, overridePath string) string {
func (c *Client) loadConfig() error {
originalConfig := *c.config // Store original config to detect changes
configPath := getConfigPath(c.clientType, c.configFilePath)
configPath := getConfigPath(c.clientType)
if c.config.ID != "" && c.config.Secret != "" && c.config.Endpoint != "" {
logger.Debug("Config already provided, skipping loading from file")
@@ -95,10 +83,6 @@ func (c *Client) loadConfig() error {
c.config.Endpoint = config.Endpoint
c.baseURL = config.Endpoint
}
// Always load the provisioning key from the file if not already set
if c.config.ProvisioningKey == "" {
c.config.ProvisioningKey = config.ProvisioningKey
}
// Check if CLI args provided values that override file values
if (!fileHadID && originalConfig.ID != "") ||
@@ -121,7 +105,7 @@ func (c *Client) saveConfig() error {
return nil
}
configPath := getConfigPath(c.clientType, c.configFilePath)
configPath := getConfigPath(c.clientType)
data, err := json.MarshalIndent(c.config, "", " ")
if err != nil {
return err
@@ -134,116 +118,3 @@ func (c *Client) saveConfig() error {
}
return err
}
// provisionIfNeeded checks whether a provisioning key is present and, if so,
// exchanges it for a newt ID and secret by calling the registration endpoint.
// On success the config is updated in-place and flagged for saving so that
// subsequent runs use the permanent credentials directly.
func (c *Client) provisionIfNeeded() error {
if c.config.ProvisioningKey == "" {
return nil
}
// If we already have both credentials there is nothing to provision.
if c.config.ID != "" && c.config.Secret != "" {
logger.Debug("Credentials already present, skipping provisioning")
return nil
}
logger.Info("Provisioning key found exchanging for newt credentials...")
baseURL, err := url.Parse(c.baseURL)
if err != nil {
return fmt.Errorf("failed to parse base URL for provisioning: %w", err)
}
baseEndpoint := strings.TrimRight(baseURL.String(), "/")
reqBody := map[string]interface{}{
"provisioningKey": c.config.ProvisioningKey,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return fmt.Errorf("failed to marshal provisioning request: %w", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(
ctx,
"POST",
baseEndpoint+"/api/v1/auth/newt/register",
bytes.NewBuffer(jsonData),
)
if err != nil {
return fmt.Errorf("failed to create provisioning request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-CSRF-Token", "x-csrf-protection")
// Mirror the TLS setup used by getToken so mTLS / self-signed CAs work.
var tlsCfg *tls.Config
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" ||
len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
tlsCfg, err = c.setupTLS()
if err != nil {
return fmt.Errorf("failed to setup TLS for provisioning: %w", err)
}
}
if os.Getenv("SKIP_TLS_VERIFY") == "true" {
if tlsCfg == nil {
tlsCfg = &tls.Config{}
}
tlsCfg.InsecureSkipVerify = true
logger.Debug("TLS certificate verification disabled for provisioning via SKIP_TLS_VERIFY")
}
httpClient := &http.Client{}
if tlsCfg != nil {
httpClient.Transport = &http.Transport{TLSClientConfig: tlsCfg}
}
resp, err := httpClient.Do(req)
if err != nil {
return fmt.Errorf("provisioning request failed: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
logger.Debug("Provisioning response body: %s", string(body))
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return fmt.Errorf("provisioning endpoint returned status %d: %s", resp.StatusCode, string(body))
}
var provResp ProvisioningResponse
if err := json.Unmarshal(body, &provResp); err != nil {
return fmt.Errorf("failed to decode provisioning response: %w", err)
}
if !provResp.Success {
return fmt.Errorf("provisioning failed: %s", provResp.Message)
}
if provResp.Data.NewtID == "" || provResp.Data.Secret == "" {
return fmt.Errorf("provisioning response is missing newt ID or secret")
}
logger.Info("Successfully provisioned newt ID: %s", provResp.Data.NewtID)
// Persist the returned credentials and clear the one-time provisioning key
// so subsequent runs authenticate normally.
c.config.ID = provResp.Data.NewtID
c.config.Secret = provResp.Data.Secret
c.config.ProvisioningKey = ""
c.configNeedsSave = true
// Save immediately so that if the subsequent connection attempt fails the
// provisioning key is already gone from disk and the next retry uses the
// permanent credentials instead of trying to provision again.
if err := c.saveConfig(); err != nil {
logger.Error("Failed to save config after provisioning: %v", err)
}
return nil
}

View File

@@ -5,7 +5,6 @@ type Config struct {
Secret string `json:"secret"`
Endpoint string `json:"endpoint"`
TlsClientCert string `json:"tlsClientCert"`
ProvisioningKey string `json:"provisioningKey,omitempty"`
}
type TokenResponse struct {
@@ -17,17 +16,7 @@ type TokenResponse struct {
Message string `json:"message"`
}
type ProvisioningResponse struct {
Data struct {
NewtID string `json:"newtId"`
Secret string `json:"secret"`
} `json:"data"`
Success bool `json:"success"`
Message string `json:"message"`
}
type WSMessage struct {
Type string `json:"type"`
Data interface{} `json:"data"`
ConfigVersion int64 `json:"configVersion,omitempty"`
}