Compare commits

...

29 Commits
1.10.2 ... dev

Author SHA1 Message Date
Owen Schwartz
a2683eb385 Merge pull request #274 from LaurenceJJones/refactor/proxy-cleanup-basics
refactor(proxy): cleanup basics - constants, remove dead code, fix de…
2026-03-18 15:39:43 -07:00
Owen Schwartz
d3722c2519 Merge pull request #280 from LaurenceJJones/fix/healthcheck-ipv6
fix(healthcheck): Support ipv6 healthchecks
2026-03-18 15:38:15 -07:00
Laurence
8fda35db4f fix(healthcheck): Support ipv6 healthchecks
Currently we are doing fmt.sprintf on hostname and port which will not properly handle ipv6 addresses, instead of changing pangolin to send bracketed address a simply net.join can do this for us since we dont need to parse a formatted string
2026-03-18 13:37:31 +00:00
Owen Schwartz
de4353f2e6 Merge pull request #269 from LaurenceJJones/feature/pprof-endpoint
feat(admin): Add pprof endpoints
2026-03-17 11:42:08 -07:00
Owen
8161fa6626 Bump ping interval up 2026-03-16 14:33:40 -07:00
Owen
24dfb3a8a2 Remove redundant info 2026-03-16 13:50:45 -07:00
Laurence
13448f76aa refactor(proxy): cleanup basics - constants, remove dead code, fix deprecated calls
- Add maxUDPPacketSize constant to replace magic number 65507
- Remove commented-out code in Stop()
- Replace deprecated ne.Temporary() with errors.Is(err, net.ErrClosed)
- Use errors.As instead of type assertion for net.Error
- Use errors.Is for closed connection checks instead of string matching
- Handle closed connection gracefully when reading from UDP target
2026-03-16 14:11:14 +00:00
Owen
d4ebb3e2af Send disconnecting message 2026-03-15 17:42:03 -07:00
Owen
bf029b7bb2 Clean up to match olm 2026-03-14 11:57:37 -07:00
Owen
745d2dbc7e Merge branch 'dev' into msg-opt 2026-03-13 17:10:49 -07:00
Owen
c7b01288e0 Clean up previous logging 2026-03-13 11:45:36 -07:00
Owen
539e595c48 Add optional compression 2026-03-12 17:49:05 -07:00
Laurence
836144aebf feat(admin): Add pprof endpoints
To aid us in debugging user issues with memory or leaks we need to be able for the user to configure pprof, wait and then provide us the output files to see where memory/leaks occur in actual runtimes
2026-03-12 09:22:50 +00:00
Owen
a1df3d7ff0 Merge branch 'dev' of github.com:fosrl/newt into dev 2026-03-11 17:28:16 -07:00
Laurence
d68a13ea1f feat(installer): prefer /usr/local/bin and improve POSIX compatibility
- Always install to /usr/local/bin instead of ~/.local/bin
  - Use sudo automatically when write access is needed
  - Replace bash-specific syntax with POSIX equivalents:
    - Change shebang from #!/bin/bash to #!/bin/sh
    - Replace [[ == *pattern* ]] with case statements
    - Replace echo -e with printf for colored output
  - Script now works with dash, ash, busybox sh, and bash
2026-03-10 10:01:28 -07:00
Owen
accac75a53 Set newt version in dockerfile 2026-03-08 11:26:35 -07:00
Laurence
768415f90b Parse target strings with IPv6 support and strict validation
Add parseTargetString() for listenPort:host:targetPort using net.SplitHostPort/JoinHostPort. Replace manual split in updateTargets; fix err shadowing on remove. Validate listen port 1–65535 and reject empty host/port; use %w for errors. Add tests for IPv4, IPv6, hostnames, and invalid cases.
2026-03-07 21:32:36 -08:00
Owen
da9825d030 Merge branch 'main' into dev 2026-03-07 12:34:45 -08:00
Owen
afdb1fc977 Make sure to set version and fix prepare issue 2026-03-07 12:32:49 -08:00
Owen
1bd1133ac2 Make sure to skip prepare 2026-03-07 10:36:18 -08:00
Owen
fac0f5b197 Build full arn 2026-03-07 10:17:14 -08:00
Owen
e68b65683f Temp lets ignore the sync messages 2026-03-06 15:14:48 -08:00
Owen
7d6825132b Merge branch 'dev' into msg-opt 2026-03-03 16:56:41 -08:00
Owen
6371e980d2 Update the get all rules 2026-03-03 16:11:32 -08:00
Owen
15ea631b96 Mutex on handlers, slight change to ping message and handler 2026-03-02 20:56:36 -08:00
Owen
4e854b5f96 Working on message versioning 2026-03-02 20:56:18 -08:00
Owen
287eef0f44 Add version and send it down 2026-03-02 18:27:26 -08:00
Owen
f982e6b629 Merge branch 'dev' into msg-opt 2026-03-02 18:13:55 -08:00
Owen
039ae07b7b Support prefixes sent from server 2026-03-02 18:11:20 -08:00
14 changed files with 1063 additions and 203 deletions

View File

@@ -136,7 +136,7 @@ jobs:
build-amd: build-amd:
name: Build image (linux/amd64) name: Build image (linux/amd64)
needs: [pre-run, prepare] needs: [pre-run, prepare]
if: ${{ needs.pre-run.result == 'success' && ((github.event_name == 'push' && github.actor != 'github-actions[bot]') || (github.event_name == 'workflow_dispatch' && (needs.prepare.result == 'success' || needs.prepare.result == 'skipped'))) }} if: ${{ needs.pre-run.result == 'success' && ((github.event_name == 'push' && github.actor != 'github-actions[bot]' && needs.prepare.result == 'skipped') || (github.event_name == 'workflow_dispatch' && (needs.prepare.result == 'success' || needs.prepare.result == 'skipped'))) }}
runs-on: [self-hosted, linux, x64] runs-on: [self-hosted, linux, x64]
timeout-minutes: 120 timeout-minutes: 120
env: env:
@@ -269,6 +269,7 @@ jobs:
context: . context: .
push: true push: true
platforms: linux/amd64 platforms: linux/amd64
build-args: VERSION=${{ env.TAG }}
tags: | tags: |
${{ env.GHCR_IMAGE }}:amd64-${{ env.TAG }} ${{ env.GHCR_IMAGE }}:amd64-${{ env.TAG }}
${{ env.DOCKERHUB_IMAGE }}:amd64-${{ env.TAG }} ${{ env.DOCKERHUB_IMAGE }}:amd64-${{ env.TAG }}
@@ -293,7 +294,7 @@ jobs:
build-arm: build-arm:
name: Build image (linux/arm64) name: Build image (linux/arm64)
needs: [pre-run, prepare] needs: [pre-run, prepare]
if: ${{ needs.pre-run.result == 'success' && ((github.event_name == 'push' && github.actor != 'github-actions[bot]') || (github.event_name == 'workflow_dispatch' && (needs.prepare.result == 'success' || needs.prepare.result == 'skipped'))) }} if: ${{ needs.pre-run.result == 'success' && ((github.event_name == 'push' && github.actor != 'github-actions[bot]' && needs.prepare.result == 'skipped') || (github.event_name == 'workflow_dispatch' && (needs.prepare.result == 'success' || needs.prepare.result == 'skipped'))) }}
runs-on: [self-hosted, linux, arm64] # NOTE: ensure label exists on runner runs-on: [self-hosted, linux, arm64] # NOTE: ensure label exists on runner
timeout-minutes: 120 timeout-minutes: 120
env: env:
@@ -393,6 +394,7 @@ jobs:
context: . context: .
push: true push: true
platforms: linux/arm64 platforms: linux/arm64
build-args: VERSION=${{ env.TAG }}
tags: | tags: |
${{ env.GHCR_IMAGE }}:arm64-${{ env.TAG }} ${{ env.GHCR_IMAGE }}:arm64-${{ env.TAG }}
${{ env.DOCKERHUB_IMAGE }}:arm64-${{ env.TAG }} ${{ env.DOCKERHUB_IMAGE }}:arm64-${{ env.TAG }}
@@ -417,7 +419,7 @@ jobs:
build-armv7: build-armv7:
name: Build image (linux/arm/v7) name: Build image (linux/arm/v7)
needs: [pre-run, prepare] needs: [pre-run, prepare]
if: ${{ needs.pre-run.result == 'success' && ((github.event_name == 'push' && github.actor != 'github-actions[bot]') || (github.event_name == 'workflow_dispatch' && (needs.prepare.result == 'success' || needs.prepare.result == 'skipped'))) }} if: ${{ needs.pre-run.result == 'success' && ((github.event_name == 'push' && github.actor != 'github-actions[bot]' && needs.prepare.result == 'skipped') || (github.event_name == 'workflow_dispatch' && (needs.prepare.result == 'success' || needs.prepare.result == 'skipped'))) }}
runs-on: [self-hosted, linux, arm64] runs-on: [self-hosted, linux, arm64]
timeout-minutes: 120 timeout-minutes: 120
env: env:
@@ -509,6 +511,7 @@ jobs:
context: . context: .
push: true push: true
platforms: linux/arm/v7 platforms: linux/arm/v7
build-args: VERSION=${{ env.TAG }}
tags: | tags: |
${{ env.GHCR_IMAGE }}:armv7-${{ env.TAG }} ${{ env.GHCR_IMAGE }}:armv7-${{ env.TAG }}
${{ env.DOCKERHUB_IMAGE }}:armv7-${{ env.TAG }} ${{ env.DOCKERHUB_IMAGE }}:armv7-${{ env.TAG }}
@@ -887,7 +890,7 @@ jobs:
shell: bash shell: bash
run: | run: |
set -euo pipefail set -euo pipefail
make -j 10 go-build-release tag="${TAG}" make -j 10 go-build-release VERSION="${TAG}"
- name: Create GitHub Release (draft) - name: Create GitHub Release (draft)
uses: softprops/action-gh-release@5be0e66d93ac7ed76da52eca8bb058f665c3a5fe # v2.4.2 uses: softprops/action-gh-release@5be0e66d93ac7ed76da52eca8bb058f665c3a5fe # v2.4.2

View File

@@ -17,7 +17,8 @@ RUN go mod download
COPY . . COPY . .
# Build the application # Build the application
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -o /newt ARG VERSION=dev
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w -X main.newtVersion=${VERSION}" -o /newt
FROM public.ecr.aws/docker/library/alpine:3.23 AS runner FROM public.ecr.aws/docker/library/alpine:3.23 AS runner

View File

@@ -2,6 +2,9 @@
all: local all: local
VERSION ?= dev
LDFLAGS = -X main.newtVersion=$(VERSION)
local: local:
CGO_ENABLED=0 go build -o ./bin/newt CGO_ENABLED=0 go build -o ./bin/newt
@@ -40,31 +43,31 @@ go-build-release: \
go-build-release-freebsd-arm64 go-build-release-freebsd-arm64
go-build-release-linux-arm64: go-build-release-linux-arm64:
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -o bin/newt_linux_arm64 CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -ldflags "$(LDFLAGS)" -o bin/newt_linux_arm64
go-build-release-linux-arm32-v7: go-build-release-linux-arm32-v7:
CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=7 go build -o bin/newt_linux_arm32 CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=7 go build -ldflags "$(LDFLAGS)" -o bin/newt_linux_arm32
go-build-release-linux-arm32-v6: go-build-release-linux-arm32-v6:
CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=6 go build -o bin/newt_linux_arm32v6 CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=6 go build -ldflags "$(LDFLAGS)" -o bin/newt_linux_arm32v6
go-build-release-linux-amd64: go-build-release-linux-amd64:
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/newt_linux_amd64 CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags "$(LDFLAGS)" -o bin/newt_linux_amd64
go-build-release-linux-riscv64: go-build-release-linux-riscv64:
CGO_ENABLED=0 GOOS=linux GOARCH=riscv64 go build -o bin/newt_linux_riscv64 CGO_ENABLED=0 GOOS=linux GOARCH=riscv64 go build -ldflags "$(LDFLAGS)" -o bin/newt_linux_riscv64
go-build-release-darwin-arm64: go-build-release-darwin-arm64:
CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -o bin/newt_darwin_arm64 CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -ldflags "$(LDFLAGS)" -o bin/newt_darwin_arm64
go-build-release-darwin-amd64: go-build-release-darwin-amd64:
CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -o bin/newt_darwin_amd64 CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -ldflags "$(LDFLAGS)" -o bin/newt_darwin_amd64
go-build-release-windows-amd64: go-build-release-windows-amd64:
CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -o bin/newt_windows_amd64.exe CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -ldflags "$(LDFLAGS)" -o bin/newt_windows_amd64.exe
go-build-release-freebsd-amd64: go-build-release-freebsd-amd64:
CGO_ENABLED=0 GOOS=freebsd GOARCH=amd64 go build -o bin/newt_freebsd_amd64 CGO_ENABLED=0 GOOS=freebsd GOARCH=amd64 go build -ldflags "$(LDFLAGS)" -o bin/newt_freebsd_amd64
go-build-release-freebsd-arm64: go-build-release-freebsd-arm64:
CGO_ENABLED=0 GOOS=freebsd GOARCH=arm64 go build -o bin/newt_freebsd_arm64 CGO_ENABLED=0 GOOS=freebsd GOARCH=arm64 go build -ldflags "$(LDFLAGS)" -o bin/newt_freebsd_arm64

View File

@@ -37,11 +37,12 @@ type WgConfig struct {
} }
type Target struct { type Target struct {
SourcePrefix string `json:"sourcePrefix"` SourcePrefix string `json:"sourcePrefix"`
DestPrefix string `json:"destPrefix"` SourcePrefixes []string `json:"sourcePrefixes"`
RewriteTo string `json:"rewriteTo,omitempty"` DestPrefix string `json:"destPrefix"`
DisableIcmp bool `json:"disableIcmp,omitempty"` RewriteTo string `json:"rewriteTo,omitempty"`
PortRange []PortRange `json:"portRange,omitempty"` DisableIcmp bool `json:"disableIcmp,omitempty"`
PortRange []PortRange `json:"portRange,omitempty"`
} }
type PortRange struct { type PortRange struct {
@@ -112,8 +113,6 @@ func NewWireGuardService(interfaceName string, port uint16, mtu int, host string
return nil, fmt.Errorf("failed to generate private key: %v", err) return nil, fmt.Errorf("failed to generate private key: %v", err)
} }
logger.Debug("+++++++++++++++++++++++++++++++= the port is %d", port)
if port == 0 { if port == 0 {
// Find an available port // Find an available port
portRandom, err := util.FindAvailableUDPPort(49152, 65535) portRandom, err := util.FindAvailableUDPPort(49152, 65535)
@@ -174,6 +173,7 @@ func NewWireGuardService(interfaceName string, port uint16, mtu int, host string
wsClient.RegisterHandler("newt/wg/targets/add", service.handleAddTarget) wsClient.RegisterHandler("newt/wg/targets/add", service.handleAddTarget)
wsClient.RegisterHandler("newt/wg/targets/remove", service.handleRemoveTarget) wsClient.RegisterHandler("newt/wg/targets/remove", service.handleRemoveTarget)
wsClient.RegisterHandler("newt/wg/targets/update", service.handleUpdateTarget) wsClient.RegisterHandler("newt/wg/targets/update", service.handleUpdateTarget)
wsClient.RegisterHandler("newt/wg/sync", service.handleSyncConfig)
return service, nil return service, nil
} }
@@ -279,7 +279,7 @@ func (s *WireGuardService) StartHolepunch(publicKey string, endpoint string, rel
} }
if relayPort == 0 { if relayPort == 0 {
relayPort = 21820 relayPort = 21820
} }
// Convert websocket.ExitNode to holepunch.ExitNode // Convert websocket.ExitNode to holepunch.ExitNode
@@ -494,6 +494,183 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
logger.Info("Client connectivity setup. Ready to accept connections from clients!") logger.Info("Client connectivity setup. Ready to accept connections from clients!")
} }
// SyncConfig represents the configuration sent from server for syncing
type SyncConfig struct {
Targets []Target `json:"targets"`
Peers []Peer `json:"peers"`
}
func (s *WireGuardService) handleSyncConfig(msg websocket.WSMessage) {
var syncConfig SyncConfig
logger.Debug("Received sync message: %v", msg)
logger.Info("Received sync configuration from remote server")
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling sync data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &syncConfig); err != nil {
logger.Error("Error unmarshaling sync data: %v", err)
return
}
// Sync peers
if err := s.syncPeers(syncConfig.Peers); err != nil {
logger.Error("Failed to sync peers: %v", err)
}
// Sync targets
if err := s.syncTargets(syncConfig.Targets); err != nil {
logger.Error("Failed to sync targets: %v", err)
}
}
// syncPeers synchronizes the current peers with the desired state
// It removes peers not in the desired list and adds missing ones
func (s *WireGuardService) syncPeers(desiredPeers []Peer) error {
if s.device == nil {
return fmt.Errorf("WireGuard device is not initialized")
}
// Get current peers from the device
currentConfig, err := s.device.IpcGet()
if err != nil {
return fmt.Errorf("failed to get current device config: %v", err)
}
// Parse current peer public keys
lines := strings.Split(currentConfig, "\n")
currentPeerKeys := make(map[string]bool)
for _, line := range lines {
if strings.HasPrefix(line, "public_key=") {
pubKey := strings.TrimPrefix(line, "public_key=")
currentPeerKeys[pubKey] = true
}
}
// Build a map of desired peers by their public key (normalized)
desiredPeerMap := make(map[string]Peer)
for _, peer := range desiredPeers {
// Normalize the public key for comparison
pubKey, err := wgtypes.ParseKey(peer.PublicKey)
if err != nil {
logger.Warn("Invalid public key in desired peers: %s", peer.PublicKey)
continue
}
normalizedKey := util.FixKey(pubKey.String())
desiredPeerMap[normalizedKey] = peer
}
// Remove peers that are not in the desired list
for currentKey := range currentPeerKeys {
if _, exists := desiredPeerMap[currentKey]; !exists {
// Parse the key back to get the original format for removal
removeConfig := fmt.Sprintf("public_key=%s\nremove=true", currentKey)
if err := s.device.IpcSet(removeConfig); err != nil {
logger.Warn("Failed to remove peer %s during sync: %v", currentKey, err)
} else {
logger.Info("Removed peer %s during sync", currentKey)
}
}
}
// Add peers that are missing
for normalizedKey, peer := range desiredPeerMap {
if _, exists := currentPeerKeys[normalizedKey]; !exists {
if err := s.addPeerToDevice(peer); err != nil {
logger.Warn("Failed to add peer %s during sync: %v", peer.PublicKey, err)
} else {
logger.Info("Added peer %s during sync", peer.PublicKey)
}
}
}
return nil
}
// syncTargets synchronizes the current targets with the desired state
// It removes targets not in the desired list and adds missing ones
func (s *WireGuardService) syncTargets(desiredTargets []Target) error {
if s.tnet == nil {
// Native interface mode - proxy features not available, skip silently
logger.Debug("Skipping target sync - using native interface (no proxy support)")
return nil
}
// Get current rules from the proxy handler
currentRules := s.tnet.GetProxySubnetRules()
// Build a map of current rules by source+dest prefix
type ruleKey struct {
sourcePrefix string
destPrefix string
}
currentRuleMap := make(map[ruleKey]bool)
for _, rule := range currentRules {
key := ruleKey{
sourcePrefix: rule.SourcePrefix.String(),
destPrefix: rule.DestPrefix.String(),
}
currentRuleMap[key] = true
}
// Build a map of desired targets
desiredTargetMap := make(map[ruleKey]Target)
for _, target := range desiredTargets {
key := ruleKey{
sourcePrefix: target.SourcePrefix,
destPrefix: target.DestPrefix,
}
desiredTargetMap[key] = target
}
// Remove targets that are not in the desired list
for _, rule := range currentRules {
key := ruleKey{
sourcePrefix: rule.SourcePrefix.String(),
destPrefix: rule.DestPrefix.String(),
}
if _, exists := desiredTargetMap[key]; !exists {
s.tnet.RemoveProxySubnetRule(rule.SourcePrefix, rule.DestPrefix)
logger.Info("Removed target %s -> %s during sync", rule.SourcePrefix.String(), rule.DestPrefix.String())
}
}
// Add targets that are missing
for key, target := range desiredTargetMap {
if _, exists := currentRuleMap[key]; !exists {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Warn("Invalid source prefix %s during sync: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil {
logger.Warn("Invalid dest prefix %s during sync: %v", target.DestPrefix, err)
continue
}
var portRanges []netstack2.PortRange
for _, pr := range target.PortRange {
portRanges = append(portRanges, netstack2.PortRange{
Min: pr.Min,
Max: pr.Max,
Protocol: pr.Protocol,
})
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
logger.Info("Added target %s -> %s during sync", target.SourcePrefix, target.DestPrefix)
}
}
return nil
}
func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
s.mu.Lock() s.mu.Lock()
@@ -697,6 +874,19 @@ func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
return nil return nil
} }
// resolveSourcePrefixes returns the effective list of source prefixes for a target,
// supporting both the legacy single SourcePrefix field and the new SourcePrefixes array.
// If SourcePrefixes is non-empty it takes precedence; otherwise SourcePrefix is used.
func resolveSourcePrefixes(target Target) []string {
if len(target.SourcePrefixes) > 0 {
return target.SourcePrefixes
}
if target.SourcePrefix != "" {
return []string{target.SourcePrefix}
}
return nil
}
func (s *WireGuardService) ensureTargets(targets []Target) error { func (s *WireGuardService) ensureTargets(targets []Target) error {
if s.tnet == nil { if s.tnet == nil {
// Native interface mode - proxy features not available, skip silently // Native interface mode - proxy features not available, skip silently
@@ -705,11 +895,6 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
} }
for _, target := range targets { for _, target := range targets {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
return fmt.Errorf("invalid CIDR %s: %v", target.SourcePrefix, err)
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix) destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil { if err != nil {
return fmt.Errorf("invalid CIDR %s: %v", target.DestPrefix, err) return fmt.Errorf("invalid CIDR %s: %v", target.DestPrefix, err)
@@ -724,9 +909,14 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
}) })
} }
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) for _, sp := range resolveSourcePrefixes(target) {
sourcePrefix, err := netip.ParsePrefix(sp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v disableIcmp: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange, target.DisableIcmp) if err != nil {
return fmt.Errorf("invalid CIDR %s: %v", sp, err)
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
}
} }
return nil return nil
@@ -1045,7 +1235,7 @@ func (s *WireGuardService) processPeerBandwidth(publicKey string, rxBytes, txByt
BytesOut: bytesOutMB, BytesOut: bytesOutMB,
} }
} }
return nil return nil
} }
} }
@@ -1096,12 +1286,6 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
// Process all targets // Process all targets
for _, target := range targets { for _, target := range targets {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix) destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil { if err != nil {
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err) logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
@@ -1111,15 +1295,21 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
var portRanges []netstack2.PortRange var portRanges []netstack2.PortRange
for _, pr := range target.PortRange { for _, pr := range target.PortRange {
portRanges = append(portRanges, netstack2.PortRange{ portRanges = append(portRanges, netstack2.PortRange{
Min: pr.Min, Min: pr.Min,
Max: pr.Max, Max: pr.Max,
Protocol: pr.Protocol, Protocol: pr.Protocol,
}) })
} }
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) for _, sp := range resolveSourcePrefixes(target) {
sourcePrefix, err := netip.ParsePrefix(sp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v disableIcmp: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange, target.DisableIcmp) if err != nil {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
}
} }
} }
@@ -1148,21 +1338,21 @@ func (s *WireGuardService) handleRemoveTarget(msg websocket.WSMessage) {
// Process all targets // Process all targets
for _, target := range targets { for _, target := range targets {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix) destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil { if err != nil {
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err) logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
continue continue
} }
s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix) for _, sp := range resolveSourcePrefixes(target) {
sourcePrefix, err := netip.ParsePrefix(sp)
logger.Info("Removed target subnet %s with destination %s", target.SourcePrefix, target.DestPrefix) if err != nil {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix)
logger.Info("Removed target subnet %s with destination %s", sp, target.DestPrefix)
}
} }
} }
@@ -1196,30 +1386,24 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
// Process all update requests // Process all update requests
for _, target := range requests.OldTargets { for _, target := range requests.OldTargets {
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix) destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil { if err != nil {
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err) logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
continue continue
} }
s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix) for _, sp := range resolveSourcePrefixes(target) {
logger.Info("Removed target subnet %s with destination %s", target.SourcePrefix, target.DestPrefix) sourcePrefix, err := netip.ParsePrefix(sp)
if err != nil {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.RemoveProxySubnetRule(sourcePrefix, destPrefix)
logger.Info("Removed target subnet %s with destination %s", sp, target.DestPrefix)
}
} }
for _, target := range requests.NewTargets { for _, target := range requests.NewTargets {
// Now add the new target
sourcePrefix, err := netip.ParsePrefix(target.SourcePrefix)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.SourcePrefix, err)
continue
}
destPrefix, err := netip.ParsePrefix(target.DestPrefix) destPrefix, err := netip.ParsePrefix(target.DestPrefix)
if err != nil { if err != nil {
logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err) logger.Info("Invalid CIDR %s: %v", target.DestPrefix, err)
@@ -1229,14 +1413,21 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
var portRanges []netstack2.PortRange var portRanges []netstack2.PortRange
for _, pr := range target.PortRange { for _, pr := range target.PortRange {
portRanges = append(portRanges, netstack2.PortRange{ portRanges = append(portRanges, netstack2.PortRange{
Min: pr.Min, Min: pr.Min,
Max: pr.Max, Max: pr.Max,
Protocol: pr.Protocol, Protocol: pr.Protocol,
}) })
} }
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) for _, sp := range resolveSourcePrefixes(target) {
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v disableIcmp: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange, target.DisableIcmp) sourcePrefix, err := netip.ParsePrefix(sp)
if err != nil {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
}
} }
} }

View File

@@ -5,6 +5,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net"
"os" "os"
"os/exec" "os/exec"
"strings" "strings"
@@ -363,27 +364,62 @@ func parseTargetData(data interface{}) (TargetData, error) {
return targetData, nil return targetData, nil
} }
// parseTargetString parses a target string in the format "listenPort:host:targetPort"
// It properly handles IPv6 addresses which must be in brackets: "listenPort:[ipv6]:targetPort"
// Examples:
// - IPv4: "3001:192.168.1.1:80"
// - IPv6: "3001:[::1]:8080" or "3001:[fd70:1452:b736:4dd5:caca:7db9:c588:f5b3]:80"
//
// Returns listenPort, targetAddress (in host:port format suitable for net.Dial), and error
func parseTargetString(target string) (int, string, error) {
// Find the first colon to extract the listen port
firstColon := strings.Index(target, ":")
if firstColon == -1 {
return 0, "", fmt.Errorf("invalid target format, no colon found: %s", target)
}
listenPortStr := target[:firstColon]
var listenPort int
_, err := fmt.Sscanf(listenPortStr, "%d", &listenPort)
if err != nil {
return 0, "", fmt.Errorf("invalid listen port: %s", listenPortStr)
}
if listenPort <= 0 || listenPort > 65535 {
return 0, "", fmt.Errorf("listen port out of range: %d", listenPort)
}
// The remainder is host:targetPort - use net.SplitHostPort which handles IPv6 brackets
remainder := target[firstColon+1:]
host, targetPort, err := net.SplitHostPort(remainder)
if err != nil {
return 0, "", fmt.Errorf("invalid host:port format '%s': %w", remainder, err)
}
// Reject empty host or target port
if host == "" {
return 0, "", fmt.Errorf("empty host in target: %s", target)
}
if targetPort == "" {
return 0, "", fmt.Errorf("empty target port in target: %s", target)
}
// Reconstruct the target address using JoinHostPort (handles IPv6 properly)
targetAddr := net.JoinHostPort(host, targetPort)
return listenPort, targetAddr, nil
}
func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error { func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
for _, t := range targetData.Targets { for _, t := range targetData.Targets {
// Split the first number off of the target with : separator and use as the port // Parse the target string, handling both IPv4 and IPv6 addresses
parts := strings.Split(t, ":") port, target, err := parseTargetString(t)
if len(parts) != 3 {
logger.Info("Invalid target format: %s", t)
continue
}
// Get the port as an int
port := 0
_, err := fmt.Sscanf(parts[0], "%d", &port)
if err != nil { if err != nil {
logger.Info("Invalid port: %s", parts[0]) logger.Info("Invalid target format: %s (%v)", t, err)
continue continue
} }
switch action { switch action {
case "add": case "add":
target := parts[1] + ":" + parts[2]
// Call updown script if provided // Call updown script if provided
processedTarget := target processedTarget := target
if updownScript != "" { if updownScript != "" {
@@ -410,8 +446,6 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
case "remove": case "remove":
logger.Info("Removing target with port %d", port) logger.Info("Removing target with port %d", port)
target := parts[1] + ":" + parts[2]
// Call updown script if provided // Call updown script if provided
if updownScript != "" { if updownScript != "" {
_, err := executeUpdownScript(action, proto, target) _, err := executeUpdownScript(action, proto, target)
@@ -420,7 +454,7 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
} }
} }
err := pm.RemoveTarget(proto, tunnelIP, port) err = pm.RemoveTarget(proto, tunnelIP, port)
if err != nil { if err != nil {
logger.Error("Failed to remove target: %v", err) logger.Error("Failed to remove target: %v", err)
return err return err

212
common_test.go Normal file
View File

@@ -0,0 +1,212 @@
package main
import (
"net"
"testing"
)
func TestParseTargetString(t *testing.T) {
tests := []struct {
name string
input string
wantListenPort int
wantTargetAddr string
wantErr bool
}{
// IPv4 test cases
{
name: "valid IPv4 basic",
input: "3001:192.168.1.1:80",
wantListenPort: 3001,
wantTargetAddr: "192.168.1.1:80",
wantErr: false,
},
{
name: "valid IPv4 localhost",
input: "8080:127.0.0.1:3000",
wantListenPort: 8080,
wantTargetAddr: "127.0.0.1:3000",
wantErr: false,
},
{
name: "valid IPv4 same ports",
input: "443:10.0.0.1:443",
wantListenPort: 443,
wantTargetAddr: "10.0.0.1:443",
wantErr: false,
},
// IPv6 test cases
{
name: "valid IPv6 loopback",
input: "3001:[::1]:8080",
wantListenPort: 3001,
wantTargetAddr: "[::1]:8080",
wantErr: false,
},
{
name: "valid IPv6 full address",
input: "80:[fd70:1452:b736:4dd5:caca:7db9:c588:f5b3]:8080",
wantListenPort: 80,
wantTargetAddr: "[fd70:1452:b736:4dd5:caca:7db9:c588:f5b3]:8080",
wantErr: false,
},
{
name: "valid IPv6 link-local",
input: "443:[fe80::1]:443",
wantListenPort: 443,
wantTargetAddr: "[fe80::1]:443",
wantErr: false,
},
{
name: "valid IPv6 all zeros compressed",
input: "8000:[::]:9000",
wantListenPort: 8000,
wantTargetAddr: "[::]:9000",
wantErr: false,
},
{
name: "valid IPv6 mixed notation",
input: "5000:[::ffff:192.168.1.1]:6000",
wantListenPort: 5000,
wantTargetAddr: "[::ffff:192.168.1.1]:6000",
wantErr: false,
},
// Hostname test cases
{
name: "valid hostname",
input: "8080:example.com:80",
wantListenPort: 8080,
wantTargetAddr: "example.com:80",
wantErr: false,
},
{
name: "valid hostname with subdomain",
input: "443:api.example.com:8443",
wantListenPort: 443,
wantTargetAddr: "api.example.com:8443",
wantErr: false,
},
{
name: "valid localhost hostname",
input: "3000:localhost:3000",
wantListenPort: 3000,
wantTargetAddr: "localhost:3000",
wantErr: false,
},
// Error cases
{
name: "invalid - no colons",
input: "invalid",
wantErr: true,
},
{
name: "invalid - empty string",
input: "",
wantErr: true,
},
{
name: "invalid - non-numeric listen port",
input: "abc:192.168.1.1:80",
wantErr: true,
},
{
name: "invalid - missing target port",
input: "3001:192.168.1.1",
wantErr: true,
},
{
name: "invalid - IPv6 without brackets",
input: "3001:fd70:1452:b736:4dd5:caca:7db9:c588:f5b3:80",
wantErr: true,
},
{
name: "invalid - only listen port",
input: "3001:",
wantErr: true,
},
{
name: "invalid - missing host",
input: "3001::80",
wantErr: true,
},
{
name: "invalid - IPv6 unclosed bracket",
input: "3001:[::1:80",
wantErr: true,
},
{
name: "invalid - listen port zero",
input: "0:192.168.1.1:80",
wantErr: true,
},
{
name: "invalid - listen port negative",
input: "-1:192.168.1.1:80",
wantErr: true,
},
{
name: "invalid - listen port out of range",
input: "70000:192.168.1.1:80",
wantErr: true,
},
{
name: "invalid - empty target port",
input: "3001:192.168.1.1:",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
listenPort, targetAddr, err := parseTargetString(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("parseTargetString(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
return
}
if tt.wantErr {
return // Don't check other values if we expected an error
}
if listenPort != tt.wantListenPort {
t.Errorf("parseTargetString(%q) listenPort = %d, want %d", tt.input, listenPort, tt.wantListenPort)
}
if targetAddr != tt.wantTargetAddr {
t.Errorf("parseTargetString(%q) targetAddr = %q, want %q", tt.input, targetAddr, tt.wantTargetAddr)
}
})
}
}
// TestParseTargetStringNetDialCompatibility verifies that the output is compatible with net.Dial
func TestParseTargetStringNetDialCompatibility(t *testing.T) {
tests := []struct {
name string
input string
}{
{"IPv4", "8080:127.0.0.1:80"},
{"IPv6 loopback", "8080:[::1]:80"},
{"IPv6 full", "8080:[2001:db8::1]:80"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, targetAddr, err := parseTargetString(tt.input)
if err != nil {
t.Fatalf("parseTargetString(%q) unexpected error: %v", tt.input, err)
}
// Verify the format is valid for net.Dial by checking it can be split back
// This doesn't actually dial, just validates the format
_, _, err = net.SplitHostPort(targetAddr)
if err != nil {
t.Errorf("parseTargetString(%q) produced invalid net.Dial format %q: %v", tt.input, targetAddr, err)
}
})
}
}

View File

@@ -1,7 +1,7 @@
#!/bin/bash #!/bin/sh
# Get Newt - Cross-platform installation script # Get Newt - Cross-platform installation script
# Usage: curl -fsSL https://raw.githubusercontent.com/fosrl/newt/refs/heads/main/get-newt.sh | bash # Usage: curl -fsSL https://raw.githubusercontent.com/fosrl/newt/refs/heads/main/get-newt.sh | sh
set -e set -e
@@ -17,15 +17,15 @@ GITHUB_API_URL="https://api.github.com/repos/${REPO}/releases/latest"
# Function to print colored output # Function to print colored output
print_status() { print_status() {
echo -e "${GREEN}[INFO]${NC} $1" printf '%b[INFO]%b %s\n' "${GREEN}" "${NC}" "$1"
} }
print_warning() { print_warning() {
echo -e "${YELLOW}[WARN]${NC} $1" printf '%b[WARN]%b %s\n' "${YELLOW}" "${NC}" "$1"
} }
print_error() { print_error() {
echo -e "${RED}[ERROR]${NC} $1" printf '%b[ERROR]%b %s\n' "${RED}" "${NC}" "$1"
} }
# Function to get latest version from GitHub API # Function to get latest version from GitHub API
@@ -113,16 +113,34 @@ get_install_dir() {
if [ "$OS" = "windows" ]; then if [ "$OS" = "windows" ]; then
echo "$HOME/bin" echo "$HOME/bin"
else else
# Try to use a directory in PATH, fallback to ~/.local/bin # Prefer /usr/local/bin for system-wide installation
if echo "$PATH" | grep -q "/usr/local/bin"; then echo "/usr/local/bin"
if [ -w "/usr/local/bin" ] 2>/dev/null; then fi
echo "/usr/local/bin" }
else
echo "$HOME/.local/bin" # Check if we need sudo for installation
fi needs_sudo() {
local install_dir="$1"
if [ -w "$install_dir" ] 2>/dev/null; then
return 1 # No sudo needed
else
return 0 # Sudo needed
fi
}
# Get the appropriate command prefix (sudo or empty)
get_sudo_cmd() {
local install_dir="$1"
if needs_sudo "$install_dir"; then
if command -v sudo >/dev/null 2>&1; then
echo "sudo"
else else
echo "$HOME/.local/bin" print_error "Cannot write to ${install_dir} and sudo is not available."
print_error "Please run this script as root or install sudo."
exit 1
fi fi
else
echo ""
fi fi
} }
@@ -130,21 +148,24 @@ get_install_dir() {
install_newt() { install_newt() {
local platform="$1" local platform="$1"
local install_dir="$2" local install_dir="$2"
local sudo_cmd="$3"
local binary_name="newt_${platform}" local binary_name="newt_${platform}"
local exe_suffix="" local exe_suffix=""
# Add .exe suffix for Windows # Add .exe suffix for Windows
if [[ "$platform" == *"windows"* ]]; then case "$platform" in
binary_name="${binary_name}.exe" *windows*)
exe_suffix=".exe" binary_name="${binary_name}.exe"
fi exe_suffix=".exe"
;;
esac
local download_url="${BASE_URL}/${binary_name}" local download_url="${BASE_URL}/${binary_name}"
local temp_file="/tmp/newt${exe_suffix}" local temp_file="/tmp/newt${exe_suffix}"
local final_path="${install_dir}/newt${exe_suffix}" local final_path="${install_dir}/newt${exe_suffix}"
print_status "Downloading newt from ${download_url}" print_status "Downloading newt from ${download_url}"
# Download the binary # Download the binary
if command -v curl >/dev/null 2>&1; then if command -v curl >/dev/null 2>&1; then
curl -fsSL "$download_url" -o "$temp_file" curl -fsSL "$download_url" -o "$temp_file"
@@ -154,18 +175,22 @@ install_newt() {
print_error "Neither curl nor wget is available. Please install one of them." print_error "Neither curl nor wget is available. Please install one of them."
exit 1 exit 1
fi fi
# Make executable before moving
chmod +x "$temp_file"
# Create install directory if it doesn't exist # Create install directory if it doesn't exist
mkdir -p "$install_dir" if [ -n "$sudo_cmd" ]; then
$sudo_cmd mkdir -p "$install_dir"
# Move binary to install directory print_status "Using sudo to install to ${install_dir}"
mv "$temp_file" "$final_path" $sudo_cmd mv "$temp_file" "$final_path"
else
# Make executable (not needed on Windows, but doesn't hurt) mkdir -p "$install_dir"
chmod +x "$final_path" mv "$temp_file" "$final_path"
fi
print_status "newt installed to ${final_path}" print_status "newt installed to ${final_path}"
# Check if install directory is in PATH # Check if install directory is in PATH
if ! echo "$PATH" | grep -q "$install_dir"; then if ! echo "$PATH" | grep -q "$install_dir"; then
print_warning "Install directory ${install_dir} is not in your PATH." print_warning "Install directory ${install_dir} is not in your PATH."
@@ -179,9 +204,9 @@ verify_installation() {
local install_dir="$1" local install_dir="$1"
local exe_suffix="" local exe_suffix=""
if [[ "$PLATFORM" == *"windows"* ]]; then case "$PLATFORM" in
exe_suffix=".exe" *windows*) exe_suffix=".exe" ;;
fi esac
local newt_path="${install_dir}/newt${exe_suffix}" local newt_path="${install_dir}/newt${exe_suffix}"
@@ -198,34 +223,36 @@ verify_installation() {
# Main installation process # Main installation process
main() { main() {
print_status "Installing latest version of newt..." print_status "Installing latest version of newt..."
# Get latest version # Get latest version
print_status "Fetching latest version from GitHub..." print_status "Fetching latest version from GitHub..."
VERSION=$(get_latest_version) VERSION=$(get_latest_version)
print_status "Latest version: v${VERSION}" print_status "Latest version: v${VERSION}"
# Set base URL with the fetched version # Set base URL with the fetched version
BASE_URL="https://github.com/${REPO}/releases/download/${VERSION}" BASE_URL="https://github.com/${REPO}/releases/download/${VERSION}"
# Detect platform # Detect platform
PLATFORM=$(detect_platform) PLATFORM=$(detect_platform)
print_status "Detected platform: ${PLATFORM}" print_status "Detected platform: ${PLATFORM}"
# Get install directory # Get install directory
INSTALL_DIR=$(get_install_dir) INSTALL_DIR=$(get_install_dir)
print_status "Install directory: ${INSTALL_DIR}" print_status "Install directory: ${INSTALL_DIR}"
# Check if we need sudo
SUDO_CMD=$(get_sudo_cmd "$INSTALL_DIR")
if [ -n "$SUDO_CMD" ]; then
print_status "Root privileges required for installation to ${INSTALL_DIR}"
fi
# Install newt # Install newt
install_newt "$PLATFORM" "$INSTALL_DIR" install_newt "$PLATFORM" "$INSTALL_DIR" "$SUDO_CMD"
# Verify installation # Verify installation
if verify_installation "$INSTALL_DIR"; then if verify_installation "$INSTALL_DIR"; then
print_status "newt is ready to use!" print_status "newt is ready to use!"
if [[ "$PLATFORM" == *"windows"* ]]; then print_status "Run 'newt --help' to get started"
print_status "Run 'newt --help' to get started"
else
print_status "Run 'newt --help' to get started"
fi
else else
exit 1 exit 1
fi fi

View File

@@ -5,7 +5,9 @@ import (
"crypto/tls" "crypto/tls"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net"
"net/http" "net/http"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -365,11 +367,12 @@ func (m *Monitor) performHealthCheck(target *Target) {
target.LastCheck = time.Now() target.LastCheck = time.Now()
target.LastError = "" target.LastError = ""
// Build URL // Build URL (use net.JoinHostPort to properly handle IPv6 addresses with ports)
url := fmt.Sprintf("%s://%s", target.Config.Scheme, target.Config.Hostname) host := target.Config.Hostname
if target.Config.Port > 0 { if target.Config.Port > 0 {
url = fmt.Sprintf("%s:%d", url, target.Config.Port) host = net.JoinHostPort(target.Config.Hostname, strconv.Itoa(target.Config.Port))
} }
url := fmt.Sprintf("%s://%s", target.Config.Scheme, host)
if target.Config.Path != "" { if target.Config.Path != "" {
if !strings.HasPrefix(target.Config.Path, "/") { if !strings.HasPrefix(target.Config.Path, "/") {
url += "/" url += "/"
@@ -521,3 +524,82 @@ func (m *Monitor) DisableTarget(id int) error {
return nil return nil
} }
// GetTargetIDs returns a slice of all current target IDs
func (m *Monitor) GetTargetIDs() []int {
m.mutex.RLock()
defer m.mutex.RUnlock()
ids := make([]int, 0, len(m.targets))
for id := range m.targets {
ids = append(ids, id)
}
return ids
}
// SyncTargets synchronizes the current targets to match the desired set.
// It removes targets not in the desired set and adds targets that are missing.
func (m *Monitor) SyncTargets(desiredConfigs []Config) error {
m.mutex.Lock()
defer m.mutex.Unlock()
logger.Info("Syncing health check targets: %d desired targets", len(desiredConfigs))
// Build a set of desired target IDs
desiredIDs := make(map[int]Config)
for _, config := range desiredConfigs {
desiredIDs[config.ID] = config
}
// Find targets to remove (exist but not in desired set)
var toRemove []int
for id := range m.targets {
if _, exists := desiredIDs[id]; !exists {
toRemove = append(toRemove, id)
}
}
// Remove targets that are not in the desired set
for _, id := range toRemove {
logger.Info("Sync: removing health check target %d", id)
if target, exists := m.targets[id]; exists {
target.cancel()
delete(m.targets, id)
}
}
// Add or update targets from the desired set
var addedCount, updatedCount int
for id, config := range desiredIDs {
if existing, exists := m.targets[id]; exists {
// Target exists - check if config changed and update if needed
// For now, we'll replace it to ensure config is up to date
logger.Debug("Sync: updating health check target %d", id)
existing.cancel()
delete(m.targets, id)
if err := m.addTargetUnsafe(config); err != nil {
logger.Error("Sync: failed to update target %d: %v", id, err)
return fmt.Errorf("failed to update target %d: %v", id, err)
}
updatedCount++
} else {
// Target doesn't exist - add it
logger.Debug("Sync: adding health check target %d", id)
if err := m.addTargetUnsafe(config); err != nil {
logger.Error("Sync: failed to add target %d: %v", id, err)
return fmt.Errorf("failed to add target %d: %v", id, err)
}
addedCount++
}
}
logger.Info("Sync complete: removed %d, added %d, updated %d targets",
len(toRemove), addedCount, updatedCount)
// Notify callback if any changes were made
if (len(toRemove) > 0 || addedCount > 0 || updatedCount > 0) && m.callback != nil {
go m.callback(m.getAllTargetsUnsafe())
}
return nil
}

193
main.go
View File

@@ -10,6 +10,7 @@ import (
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"net/http/pprof"
"net/netip" "net/netip"
"os" "os"
"os/signal" "os/signal"
@@ -147,6 +148,7 @@ var (
adminAddr string adminAddr string
region string region string
metricsAsyncBytes bool metricsAsyncBytes bool
pprofEnabled bool
blueprintFile string blueprintFile string
noCloud bool noCloud bool
@@ -225,6 +227,7 @@ func runNewtMain(ctx context.Context) {
adminAddrEnv := os.Getenv("NEWT_ADMIN_ADDR") adminAddrEnv := os.Getenv("NEWT_ADMIN_ADDR")
regionEnv := os.Getenv("NEWT_REGION") regionEnv := os.Getenv("NEWT_REGION")
asyncBytesEnv := os.Getenv("NEWT_METRICS_ASYNC_BYTES") asyncBytesEnv := os.Getenv("NEWT_METRICS_ASYNC_BYTES")
pprofEnabledEnv := os.Getenv("NEWT_PPROF_ENABLED")
disableClientsEnv := os.Getenv("DISABLE_CLIENTS") disableClientsEnv := os.Getenv("DISABLE_CLIENTS")
disableClients = disableClientsEnv == "true" disableClients = disableClientsEnv == "true"
@@ -302,10 +305,10 @@ func runNewtMain(ctx context.Context) {
flag.StringVar(&dockerSocket, "docker-socket", "", "Path or address to Docker socket (typically unix:///var/run/docker.sock)") flag.StringVar(&dockerSocket, "docker-socket", "", "Path or address to Docker socket (typically unix:///var/run/docker.sock)")
} }
if pingIntervalStr == "" { if pingIntervalStr == "" {
flag.StringVar(&pingIntervalStr, "ping-interval", "3s", "Interval for pinging the server (default 3s)") flag.StringVar(&pingIntervalStr, "ping-interval", "15s", "Interval for pinging the server (default 15s)")
} }
if pingTimeoutStr == "" { if pingTimeoutStr == "" {
flag.StringVar(&pingTimeoutStr, "ping-timeout", "5s", " Timeout for each ping (default 5s)") flag.StringVar(&pingTimeoutStr, "ping-timeout", "7s", " Timeout for each ping (default 7s)")
} }
// load the prefer endpoint just as a flag // load the prefer endpoint just as a flag
flag.StringVar(&preferEndpoint, "prefer-endpoint", "", "Prefer this endpoint for the connection (if set, will override the endpoint from the server)") flag.StringVar(&preferEndpoint, "prefer-endpoint", "", "Prefer this endpoint for the connection (if set, will override the endpoint from the server)")
@@ -330,21 +333,21 @@ func runNewtMain(ctx context.Context) {
if pingIntervalStr != "" { if pingIntervalStr != "" {
pingInterval, err = time.ParseDuration(pingIntervalStr) pingInterval, err = time.ParseDuration(pingIntervalStr)
if err != nil { if err != nil {
fmt.Printf("Invalid PING_INTERVAL value: %s, using default 3 seconds\n", pingIntervalStr) fmt.Printf("Invalid PING_INTERVAL value: %s, using default 15 seconds\n", pingIntervalStr)
pingInterval = 3 * time.Second pingInterval = 15 * time.Second
} }
} else { } else {
pingInterval = 3 * time.Second pingInterval = 15 * time.Second
} }
if pingTimeoutStr != "" { if pingTimeoutStr != "" {
pingTimeout, err = time.ParseDuration(pingTimeoutStr) pingTimeout, err = time.ParseDuration(pingTimeoutStr)
if err != nil { if err != nil {
fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 5 seconds\n", pingTimeoutStr) fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 7 seconds\n", pingTimeoutStr)
pingTimeout = 5 * time.Second pingTimeout = 7 * time.Second
} }
} else { } else {
pingTimeout = 5 * time.Second pingTimeout = 7 * time.Second
} }
if dockerEnforceNetworkValidation == "" { if dockerEnforceNetworkValidation == "" {
@@ -390,6 +393,14 @@ func runNewtMain(ctx context.Context) {
metricsAsyncBytes = v metricsAsyncBytes = v
} }
} }
// pprof debug endpoint toggle
if pprofEnabledEnv == "" {
flag.BoolVar(&pprofEnabled, "pprof", false, "Enable pprof debug endpoints on admin server")
} else {
if v, err := strconv.ParseBool(pprofEnabledEnv); err == nil {
pprofEnabled = v
}
}
// Optional region flag (resource attribute) // Optional region flag (resource attribute)
if regionEnv == "" { if regionEnv == "" {
flag.StringVar(&region, "region", "", "Optional region resource attribute (also NEWT_REGION)") flag.StringVar(&region, "region", "", "Optional region resource attribute (also NEWT_REGION)")
@@ -485,6 +496,14 @@ func runNewtMain(ctx context.Context) {
if tel.PrometheusHandler != nil { if tel.PrometheusHandler != nil {
mux.Handle("/metrics", tel.PrometheusHandler) mux.Handle("/metrics", tel.PrometheusHandler)
} }
if pprofEnabled {
mux.HandleFunc("/debug/pprof/", pprof.Index)
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
logger.Info("pprof debugging enabled on %s/debug/pprof/", tcfg.AdminAddr)
}
admin := &http.Server{ admin := &http.Server{
Addr: tcfg.AdminAddr, Addr: tcfg.AdminAddr,
Handler: otelhttp.NewHandler(mux, "newt-admin"), Handler: otelhttp.NewHandler(mux, "newt-admin"),
@@ -565,8 +584,7 @@ func runNewtMain(ctx context.Context) {
id, // CLI arg takes precedence id, // CLI arg takes precedence
secret, // CLI arg takes precedence secret, // CLI arg takes precedence
endpoint, endpoint,
pingInterval, 30*time.Second,
pingTimeout,
opt, opt,
) )
if err != nil { if err != nil {
@@ -618,8 +636,6 @@ func runNewtMain(ctx context.Context) {
var connected bool var connected bool
var wgData WgData var wgData WgData
var dockerEventMonitor *docker.EventMonitor var dockerEventMonitor *docker.EventMonitor
logger.Debug("++++++++++++++++++++++ the port is %d", port)
if !disableClients { if !disableClients {
setupClients(client) setupClients(client)
@@ -959,7 +975,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
"publicKey": publicKey.String(), "publicKey": publicKey.String(),
"pingResults": pingResults, "pingResults": pingResults,
"newtVersion": newtVersion, "newtVersion": newtVersion,
}, 1*time.Second) }, 2*time.Second)
return return
} }
@@ -1062,7 +1078,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
"publicKey": publicKey.String(), "publicKey": publicKey.String(),
"pingResults": pingResults, "pingResults": pingResults,
"newtVersion": newtVersion, "newtVersion": newtVersion,
}, 1*time.Second) }, 2*time.Second)
logger.Debug("Sent exit node ping results to cloud for selection: pingResults=%+v", pingResults) logger.Debug("Sent exit node ping results to cloud for selection: pingResults=%+v", pingResults)
}) })
@@ -1167,6 +1183,153 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
} }
}) })
// Register handler for syncing targets (TCP, UDP, and health checks)
client.RegisterHandler("newt/sync", func(msg websocket.WSMessage) {
logger.Info("Received sync message")
// if there is no wgData or pm, we can't sync targets
if wgData.TunnelIP == "" || pm == nil {
logger.Info(msgNoTunnelOrProxy)
return
}
// Define the sync data structure
type SyncData struct {
Targets TargetsByType `json:"targets"`
HealthCheckTargets []healthcheck.Config `json:"healthCheckTargets"`
}
var syncData SyncData
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling sync data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &syncData); err != nil {
logger.Error("Error unmarshaling sync data: %v", err)
return
}
logger.Debug("Sync data received: TCP targets=%d, UDP targets=%d, health check targets=%d",
len(syncData.Targets.TCP), len(syncData.Targets.UDP), len(syncData.HealthCheckTargets))
//TODO: TEST AND IMPLEMENT THIS
// // Build sets of desired targets (port -> target string)
// desiredTCP := make(map[int]string)
// for _, t := range syncData.Targets.TCP {
// parts := strings.Split(t, ":")
// if len(parts) != 3 {
// logger.Warn("Invalid TCP target format: %s", t)
// continue
// }
// port := 0
// if _, err := fmt.Sscanf(parts[0], "%d", &port); err != nil {
// logger.Warn("Invalid port in TCP target: %s", parts[0])
// continue
// }
// desiredTCP[port] = parts[1] + ":" + parts[2]
// }
// desiredUDP := make(map[int]string)
// for _, t := range syncData.Targets.UDP {
// parts := strings.Split(t, ":")
// if len(parts) != 3 {
// logger.Warn("Invalid UDP target format: %s", t)
// continue
// }
// port := 0
// if _, err := fmt.Sscanf(parts[0], "%d", &port); err != nil {
// logger.Warn("Invalid port in UDP target: %s", parts[0])
// continue
// }
// desiredUDP[port] = parts[1] + ":" + parts[2]
// }
// // Get current targets from proxy manager
// currentTCP, currentUDP := pm.GetTargets()
// // Sync TCP targets
// // Remove TCP targets not in desired set
// if tcpForIP, ok := currentTCP[wgData.TunnelIP]; ok {
// for port := range tcpForIP {
// if _, exists := desiredTCP[port]; !exists {
// logger.Info("Sync: removing TCP target on port %d", port)
// targetStr := fmt.Sprintf("%d:%s", port, tcpForIP[port])
// updateTargets(pm, "remove", wgData.TunnelIP, "tcp", TargetData{Targets: []string{targetStr}})
// }
// }
// }
// // Add TCP targets that are missing
// for port, target := range desiredTCP {
// needsAdd := true
// if tcpForIP, ok := currentTCP[wgData.TunnelIP]; ok {
// if currentTarget, exists := tcpForIP[port]; exists {
// // Check if target address changed
// if currentTarget == target {
// needsAdd = false
// } else {
// // Target changed, remove old one first
// logger.Info("Sync: updating TCP target on port %d", port)
// targetStr := fmt.Sprintf("%d:%s", port, currentTarget)
// updateTargets(pm, "remove", wgData.TunnelIP, "tcp", TargetData{Targets: []string{targetStr}})
// }
// }
// }
// if needsAdd {
// logger.Info("Sync: adding TCP target on port %d -> %s", port, target)
// targetStr := fmt.Sprintf("%d:%s", port, target)
// updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: []string{targetStr}})
// }
// }
// // Sync UDP targets
// // Remove UDP targets not in desired set
// if udpForIP, ok := currentUDP[wgData.TunnelIP]; ok {
// for port := range udpForIP {
// if _, exists := desiredUDP[port]; !exists {
// logger.Info("Sync: removing UDP target on port %d", port)
// targetStr := fmt.Sprintf("%d:%s", port, udpForIP[port])
// updateTargets(pm, "remove", wgData.TunnelIP, "udp", TargetData{Targets: []string{targetStr}})
// }
// }
// }
// // Add UDP targets that are missing
// for port, target := range desiredUDP {
// needsAdd := true
// if udpForIP, ok := currentUDP[wgData.TunnelIP]; ok {
// if currentTarget, exists := udpForIP[port]; exists {
// // Check if target address changed
// if currentTarget == target {
// needsAdd = false
// } else {
// // Target changed, remove old one first
// logger.Info("Sync: updating UDP target on port %d", port)
// targetStr := fmt.Sprintf("%d:%s", port, currentTarget)
// updateTargets(pm, "remove", wgData.TunnelIP, "udp", TargetData{Targets: []string{targetStr}})
// }
// }
// }
// if needsAdd {
// logger.Info("Sync: adding UDP target on port %d -> %s", port, target)
// targetStr := fmt.Sprintf("%d:%s", port, target)
// updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: []string{targetStr}})
// }
// }
// // Sync health check targets
// if err := healthMonitor.SyncTargets(syncData.HealthCheckTargets); err != nil {
// logger.Error("Failed to sync health check targets: %v", err)
// } else {
// logger.Info("Successfully synced health check targets")
// }
logger.Info("Sync complete")
})
// Register handler for Docker socket check // Register handler for Docker socket check
client.RegisterHandler("newt/socket/check", func(msg websocket.WSMessage) { client.RegisterHandler("newt/socket/check", func(msg websocket.WSMessage) {
logger.Debug("Received Docker socket check request") logger.Debug("Received Docker socket check request")
@@ -1649,6 +1812,8 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
pm.Stop() pm.Stop()
} }
client.SendMessage("newt/disconnecting", map[string]any{})
if client != nil { if client != nil {
client.Close() client.Close()
} }

View File

@@ -48,6 +48,23 @@ type SubnetRule struct {
PortRanges []PortRange // empty slice means all ports allowed PortRanges []PortRange // empty slice means all ports allowed
} }
// GetAllRules returns a copy of all subnet rules
func (sl *SubnetLookup) GetAllRules() []SubnetRule {
sl.mu.RLock()
defer sl.mu.RUnlock()
var rules []SubnetRule
for _, destTriePtr := range sl.sourceTrie.All() {
if destTriePtr == nil {
continue
}
for _, rule := range destTriePtr.rules {
rules = append(rules, *rule)
}
}
return rules
}
// connKey uniquely identifies a connection for NAT tracking // connKey uniquely identifies a connection for NAT tracking
type connKey struct { type connKey struct {
srcIP string srcIP string
@@ -200,6 +217,14 @@ func (p *ProxyHandler) RemoveSubnetRule(sourcePrefix, destPrefix netip.Prefix) {
p.subnetLookup.RemoveSubnet(sourcePrefix, destPrefix) p.subnetLookup.RemoveSubnet(sourcePrefix, destPrefix)
} }
// GetAllRules returns all subnet rules from the proxy handler
func (p *ProxyHandler) GetAllRules() []SubnetRule {
if p == nil || !p.enabled {
return nil
}
return p.subnetLookup.GetAllRules()
}
// LookupDestinationRewrite looks up the rewritten destination for a connection // LookupDestinationRewrite looks up the rewritten destination for a connection
// This is used by TCP/UDP handlers to find the actual target address // This is used by TCP/UDP handlers to find the actual target address
func (p *ProxyHandler) LookupDestinationRewrite(srcIP, dstIP string, dstPort uint16, proto uint8) (netip.Addr, bool) { func (p *ProxyHandler) LookupDestinationRewrite(srcIP, dstIP string, dstPort uint16, proto uint8) (netip.Addr, bool) {

View File

@@ -369,6 +369,15 @@ func (net *Net) RemoveProxySubnetRule(sourcePrefix, destPrefix netip.Prefix) {
} }
} }
// GetProxySubnetRules returns all subnet rules from the proxy handler
func (net *Net) GetProxySubnetRules() []SubnetRule {
tun := (*netTun)(net)
if tun.proxyHandler != nil {
return tun.proxyHandler.GetAllRules()
}
return nil
}
// GetProxyHandler returns the proxy handler (for advanced use cases) // GetProxyHandler returns the proxy handler (for advanced use cases)
// Returns nil if proxy is not enabled // Returns nil if proxy is not enabled
func (net *Net) GetProxyHandler() *ProxyHandler { func (net *Net) GetProxyHandler() *ProxyHandler {

View File

@@ -21,7 +21,10 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
) )
const errUnsupportedProtoFmt = "unsupported protocol: %s" const (
errUnsupportedProtoFmt = "unsupported protocol: %s"
maxUDPPacketSize = 65507
)
// Target represents a proxy target with its address and port // Target represents a proxy target with its address and port
type Target struct { type Target struct {
@@ -105,13 +108,9 @@ func classifyProxyError(err error) string {
if errors.Is(err, net.ErrClosed) { if errors.Is(err, net.ErrClosed) {
return "closed" return "closed"
} }
if ne, ok := err.(net.Error); ok { var ne net.Error
if ne.Timeout() { if errors.As(err, &ne) && ne.Timeout() {
return "timeout" return "timeout"
}
if ne.Temporary() {
return "temporary"
}
} }
msg := strings.ToLower(err.Error()) msg := strings.ToLower(err.Error())
switch { switch {
@@ -437,14 +436,6 @@ func (pm *ProxyManager) Stop() error {
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...) pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
} }
// // Clear the target maps
// for k := range pm.tcpTargets {
// delete(pm.tcpTargets, k)
// }
// for k := range pm.udpTargets {
// delete(pm.udpTargets, k)
// }
// Give active connections a chance to close gracefully // Give active connections a chance to close gracefully
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
@@ -498,7 +489,7 @@ func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string)
if !pm.running { if !pm.running {
return return
} }
if ne, ok := err.(net.Error); ok && !ne.Temporary() { if errors.Is(err, net.ErrClosed) {
logger.Info("TCP listener closed, stopping proxy handler for %v", listener.Addr()) logger.Info("TCP listener closed, stopping proxy handler for %v", listener.Addr())
return return
} }
@@ -564,7 +555,7 @@ func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string)
} }
func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) { func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
buffer := make([]byte, 65507) // Max UDP packet size buffer := make([]byte, maxUDPPacketSize) // Max UDP packet size
clientConns := make(map[string]*net.UDPConn) clientConns := make(map[string]*net.UDPConn)
var clientsMutex sync.RWMutex var clientsMutex sync.RWMutex
@@ -583,7 +574,7 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
} }
// Check for connection closed conditions // Check for connection closed conditions
if err == io.EOF || strings.Contains(err.Error(), "use of closed network connection") { if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
logger.Info("UDP connection closed, stopping proxy handler") logger.Info("UDP connection closed, stopping proxy handler")
// Clean up existing client connections // Clean up existing client connections
@@ -662,10 +653,14 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
telemetry.IncProxyConnectionEvent(context.Background(), tunnelID, "udp", telemetry.ProxyConnectionClosed) telemetry.IncProxyConnectionEvent(context.Background(), tunnelID, "udp", telemetry.ProxyConnectionClosed)
}() }()
buffer := make([]byte, 65507) buffer := make([]byte, maxUDPPacketSize)
for { for {
n, _, err := targetConn.ReadFromUDP(buffer) n, _, err := targetConn.ReadFromUDP(buffer)
if err != nil { if err != nil {
// Connection closed is normal during cleanup
if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) {
return // defer will handle cleanup, result stays "success"
}
logger.Error("Error reading from target: %v", err) logger.Error("Error reading from target: %v", err)
result = "failure" result = "failure"
return // defer will handle cleanup return // defer will handle cleanup
@@ -736,3 +731,28 @@ func (pm *ProxyManager) PrintTargets() {
} }
} }
} }
// GetTargets returns a copy of the current TCP and UDP targets
// Returns map[listenIP]map[port]targetAddress for both TCP and UDP
func (pm *ProxyManager) GetTargets() (tcpTargets map[string]map[int]string, udpTargets map[string]map[int]string) {
pm.mutex.RLock()
defer pm.mutex.RUnlock()
tcpTargets = make(map[string]map[int]string)
for listenIP, targets := range pm.tcpTargets {
tcpTargets[listenIP] = make(map[int]string)
for port, targetAddr := range targets {
tcpTargets[listenIP][port] = targetAddr
}
}
udpTargets = make(map[string]map[int]string)
for listenIP, targets := range pm.udpTargets {
udpTargets[listenIP] = make(map[int]string)
for port, targetAddr := range targets {
udpTargets[listenIP][port] = targetAddr
}
}
return tcpTargets, udpTargets
}

View File

@@ -2,6 +2,7 @@ package websocket
import ( import (
"bytes" "bytes"
"compress/gzip"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
@@ -37,7 +38,6 @@ type Client struct {
isConnected bool isConnected bool
reconnectMux sync.RWMutex reconnectMux sync.RWMutex
pingInterval time.Duration pingInterval time.Duration
pingTimeout time.Duration
onConnect func() error onConnect func() error
onTokenUpdate func(token string) onTokenUpdate func(token string)
writeMux sync.Mutex writeMux sync.Mutex
@@ -47,6 +47,11 @@ type Client struct {
metricsCtx context.Context metricsCtx context.Context
configNeedsSave bool // Flag to track if config needs to be saved configNeedsSave bool // Flag to track if config needs to be saved
serverVersion string serverVersion string
configVersion int64 // Latest config version received from server
configVersionMux sync.RWMutex
processingMessage bool // Flag to track if a message is currently being processed
processingMux sync.RWMutex // Protects processingMessage
processingWg sync.WaitGroup // WaitGroup to wait for message processing to complete
} }
type ClientOption func(*Client) type ClientOption func(*Client)
@@ -111,7 +116,7 @@ func (c *Client) MetricsContext() context.Context {
} }
// NewClient creates a new websocket client // NewClient creates a new websocket client
func NewClient(clientType string, ID, secret string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) { func NewClient(clientType string, ID, secret string, endpoint string, pingInterval time.Duration, opts ...ClientOption) (*Client, error) {
config := &Config{ config := &Config{
ID: ID, ID: ID,
Secret: secret, Secret: secret,
@@ -126,7 +131,6 @@ func NewClient(clientType string, ID, secret string, endpoint string, pingInterv
reconnectInterval: 3 * time.Second, reconnectInterval: 3 * time.Second,
isConnected: false, isConnected: false,
pingInterval: pingInterval, pingInterval: pingInterval,
pingTimeout: pingTimeout,
clientType: clientType, clientType: clientType,
} }
@@ -154,6 +158,20 @@ func (c *Client) GetServerVersion() string {
return c.serverVersion return c.serverVersion
} }
// GetConfigVersion returns the latest config version received from server
func (c *Client) GetConfigVersion() int64 {
c.configVersionMux.RLock()
defer c.configVersionMux.RUnlock()
return c.configVersion
}
// setConfigVersion updates the config version
func (c *Client) setConfigVersion(version int64) {
c.configVersionMux.Lock()
defer c.configVersionMux.Unlock()
c.configVersion = version
}
// Connect establishes the WebSocket connection // Connect establishes the WebSocket connection
func (c *Client) Connect() error { func (c *Client) Connect() error {
go c.connectWithRetry() go c.connectWithRetry()
@@ -641,7 +659,57 @@ func (c *Client) setupPKCS12TLS() (*tls.Config, error) {
} }
// pingMonitor sends pings at a short interval and triggers reconnect on failure // pingMonitor sends pings at a short interval and triggers reconnect on failure
func (c *Client) sendPing() {
if c.conn == nil {
return
}
// Skip ping if a message is currently being processed
c.processingMux.RLock()
isProcessing := c.processingMessage
c.processingMux.RUnlock()
if isProcessing {
logger.Debug("Skipping ping, message is being processed")
return
}
c.configVersionMux.RLock()
configVersion := c.configVersion
c.configVersionMux.RUnlock()
pingMsg := WSMessage{
Type: "newt/ping",
Data: map[string]interface{}{},
ConfigVersion: configVersion,
}
c.writeMux.Lock()
err := c.conn.WriteJSON(pingMsg)
if err == nil {
telemetry.IncWSMessage(c.metricsContext(), "out", "ping")
}
c.writeMux.Unlock()
if err != nil {
// Check if we're shutting down before logging error and reconnecting
select {
case <-c.done:
// Expected during shutdown
return
default:
logger.Error("Ping failed: %v", err)
telemetry.IncWSKeepaliveFailure(c.metricsContext(), "ping_write")
telemetry.IncWSReconnect(c.metricsContext(), "ping_write")
c.reconnect()
return
}
}
}
func (c *Client) pingMonitor() { func (c *Client) pingMonitor() {
// Send an immediate ping as soon as we connect
c.sendPing()
ticker := time.NewTicker(c.pingInterval) ticker := time.NewTicker(c.pingInterval)
defer ticker.Stop() defer ticker.Stop()
@@ -650,29 +718,7 @@ func (c *Client) pingMonitor() {
case <-c.done: case <-c.done:
return return
case <-ticker.C: case <-ticker.C:
if c.conn == nil { c.sendPing()
return
}
c.writeMux.Lock()
err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout))
if err == nil {
telemetry.IncWSMessage(c.metricsContext(), "out", "ping")
}
c.writeMux.Unlock()
if err != nil {
// Check if we're shutting down before logging error and reconnecting
select {
case <-c.done:
// Expected during shutdown
return
default:
logger.Error("Ping failed: %v", err)
telemetry.IncWSKeepaliveFailure(c.metricsContext(), "ping_write")
telemetry.IncWSReconnect(c.metricsContext(), "ping_write")
c.reconnect()
return
}
}
} }
} }
} }
@@ -709,10 +755,13 @@ func (c *Client) readPumpWithDisconnectDetection(started time.Time) {
disconnectResult = "success" disconnectResult = "success"
return return
default: default:
var msg WSMessage msgType, p, err := c.conn.ReadMessage()
err := c.conn.ReadJSON(&msg)
if err == nil { if err == nil {
telemetry.IncWSMessage(c.metricsContext(), "in", "text") if msgType == websocket.BinaryMessage {
telemetry.IncWSMessage(c.metricsContext(), "in", "binary")
} else {
telemetry.IncWSMessage(c.metricsContext(), "in", "text")
}
} }
if err != nil { if err != nil {
// Check if we're shutting down before logging error // Check if we're shutting down before logging error
@@ -737,9 +786,47 @@ func (c *Client) readPumpWithDisconnectDetection(started time.Time) {
} }
} }
// Update config version from incoming message
var data []byte
if msgType == websocket.BinaryMessage {
gr, err := gzip.NewReader(bytes.NewReader(p))
if err != nil {
logger.Error("WebSocket failed to create gzip reader: %v", err)
continue
}
data, err = io.ReadAll(gr)
gr.Close()
if err != nil {
logger.Error("WebSocket failed to decompress message: %v", err)
continue
}
} else {
data = p
}
var msg WSMessage
if err = json.Unmarshal(data, &msg); err != nil {
logger.Error("WebSocket failed to parse message: %v", err)
continue
}
c.setConfigVersion(msg.ConfigVersion)
c.handlersMux.RLock() c.handlersMux.RLock()
if handler, ok := c.handlers[msg.Type]; ok { if handler, ok := c.handlers[msg.Type]; ok {
// Mark that we're processing a message
c.processingMux.Lock()
c.processingMessage = true
c.processingMux.Unlock()
c.processingWg.Add(1)
handler(msg) handler(msg)
// Mark that we're done processing
c.processingWg.Done()
c.processingMux.Lock()
c.processingMessage = false
c.processingMux.Unlock()
} }
c.handlersMux.RUnlock() c.handlersMux.RUnlock()
} }

View File

@@ -17,6 +17,7 @@ type TokenResponse struct {
} }
type WSMessage struct { type WSMessage struct {
Type string `json:"type"` Type string `json:"type"`
Data interface{} `json:"data"` Data interface{} `json:"data"`
ConfigVersion int64 `json:"configVersion,omitempty"`
} }