mirror of
https://github.com/fosrl/newt.git
synced 2026-04-05 17:36:37 +00:00
Compare commits
43 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cd4782265a | ||
|
|
2e02c9b7a9 | ||
|
|
5c329be1f3 | ||
|
|
732e788c66 | ||
|
|
aa42b3623d | ||
|
|
4f42560e26 | ||
|
|
c2187de482 | ||
|
|
5ced7d6909 | ||
|
|
4e6e79ad21 | ||
|
|
abe6e2e400 | ||
|
|
f432a17c16 | ||
|
|
6f96169ff1 | ||
|
|
575942c4be | ||
|
|
16864fc1d7 | ||
|
|
f925c681d2 | ||
|
|
e01b0ae9c7 | ||
|
|
f4d071fe27 | ||
|
|
8d82460a76 | ||
|
|
5208117c56 | ||
|
|
381f5a619c | ||
|
|
b6f13a1b55 | ||
|
|
cdaf4f7898 | ||
|
|
d4a5ac8682 | ||
|
|
1057013b50 | ||
|
|
fc4b375bf1 | ||
|
|
baca04ee58 | ||
|
|
b43572dd8d | ||
|
|
69019d5655 | ||
|
|
0f57985b6f | ||
|
|
212bdf765a | ||
|
|
b045a0f5d4 | ||
|
|
a2683eb385 | ||
|
|
d3722c2519 | ||
|
|
8fda35db4f | ||
|
|
de4353f2e6 | ||
|
|
13448f76aa | ||
|
|
836144aebf | ||
|
|
d7741df514 | ||
|
|
8e188933a2 | ||
|
|
a13c7c6e65 | ||
|
|
bc44ca1aba | ||
|
|
a76089db98 | ||
|
|
627ec2fdbc |
40
.github/workflows/cicd.yml
vendored
40
.github/workflows/cicd.yml
vendored
@@ -235,17 +235,17 @@ jobs:
|
|||||||
# uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3.7.0
|
# uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3.7.0
|
||||||
|
|
||||||
#- name: Set up Docker Buildx
|
#- name: Set up Docker Buildx
|
||||||
# uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3.12.0
|
# uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||||
|
|
||||||
- name: Log in to Docker Hub
|
- name: Log in to Docker Hub
|
||||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
|
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
|
||||||
with:
|
with:
|
||||||
registry: docker.io
|
registry: docker.io
|
||||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||||
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
|
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
|
||||||
|
|
||||||
- name: Log in to GHCR
|
- name: Log in to GHCR
|
||||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
|
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
|
||||||
with:
|
with:
|
||||||
registry: ghcr.io
|
registry: ghcr.io
|
||||||
username: ${{ github.actor }}
|
username: ${{ github.actor }}
|
||||||
@@ -259,12 +259,12 @@ jobs:
|
|||||||
echo "DOCKERHUB_IMAGE=${DOCKERHUB_IMAGE,,}" >> "$GITHUB_ENV"
|
echo "DOCKERHUB_IMAGE=${DOCKERHUB_IMAGE,,}" >> "$GITHUB_ENV"
|
||||||
|
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3.12.0
|
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||||
|
|
||||||
# Build ONLY amd64 and push arch-specific tag suffixes used later for manifest creation.
|
# Build ONLY amd64 and push arch-specific tag suffixes used later for manifest creation.
|
||||||
- name: Build and push (amd64 -> *:amd64-TAG)
|
- name: Build and push (amd64 -> *:amd64-TAG)
|
||||||
id: build_amd
|
id: build_amd
|
||||||
uses: docker/build-push-action@10e90e3645eae34f1e60eeb005ba3a3d33f178e8 # v6.19.2
|
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
push: true
|
push: true
|
||||||
@@ -363,14 +363,14 @@ jobs:
|
|||||||
echo "Checked out $(git rev-parse --short HEAD) for tag ${TAG}"
|
echo "Checked out $(git rev-parse --short HEAD) for tag ${TAG}"
|
||||||
|
|
||||||
- name: Log in to Docker Hub
|
- name: Log in to Docker Hub
|
||||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
|
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
|
||||||
with:
|
with:
|
||||||
registry: docker.io
|
registry: docker.io
|
||||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||||
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
|
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
|
||||||
|
|
||||||
- name: Log in to GHCR
|
- name: Log in to GHCR
|
||||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
|
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
|
||||||
with:
|
with:
|
||||||
registry: ghcr.io
|
registry: ghcr.io
|
||||||
username: ${{ github.actor }}
|
username: ${{ github.actor }}
|
||||||
@@ -384,12 +384,12 @@ jobs:
|
|||||||
echo "DOCKERHUB_IMAGE=${DOCKERHUB_IMAGE,,}" >> "$GITHUB_ENV"
|
echo "DOCKERHUB_IMAGE=${DOCKERHUB_IMAGE,,}" >> "$GITHUB_ENV"
|
||||||
|
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3.12.0
|
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||||
|
|
||||||
# Build ONLY arm64 and push arch-specific tag suffixes used later for manifest creation.
|
# Build ONLY arm64 and push arch-specific tag suffixes used later for manifest creation.
|
||||||
- name: Build and push (arm64 -> *:arm64-TAG)
|
- name: Build and push (arm64 -> *:arm64-TAG)
|
||||||
id: build_arm
|
id: build_arm
|
||||||
uses: docker/build-push-action@10e90e3645eae34f1e60eeb005ba3a3d33f178e8 # v6.19.2
|
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
push: true
|
push: true
|
||||||
@@ -478,14 +478,14 @@ jobs:
|
|||||||
echo "Checked out $(git rev-parse --short HEAD) for tag ${TAG}"
|
echo "Checked out $(git rev-parse --short HEAD) for tag ${TAG}"
|
||||||
|
|
||||||
- name: Log in to Docker Hub
|
- name: Log in to Docker Hub
|
||||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
|
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
|
||||||
with:
|
with:
|
||||||
registry: docker.io
|
registry: docker.io
|
||||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||||
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
|
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
|
||||||
|
|
||||||
- name: Log in to GHCR
|
- name: Log in to GHCR
|
||||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
|
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
|
||||||
with:
|
with:
|
||||||
registry: ghcr.io
|
registry: ghcr.io
|
||||||
username: ${{ github.actor }}
|
username: ${{ github.actor }}
|
||||||
@@ -502,11 +502,11 @@ jobs:
|
|||||||
uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3.7.0
|
uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3.7.0
|
||||||
|
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3.12.0
|
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||||
|
|
||||||
- name: Build and push (arm/v7 -> *:armv7-TAG)
|
- name: Build and push (arm/v7 -> *:armv7-TAG)
|
||||||
id: build_armv7
|
id: build_armv7
|
||||||
uses: docker/build-push-action@10e90e3645eae34f1e60eeb005ba3a3d33f178e8 # v6.19.2
|
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
push: true
|
push: true
|
||||||
@@ -551,14 +551,14 @@ jobs:
|
|||||||
#PUBLISH_MINOR: ${{ github.event_name == 'workflow_dispatch' && inputs.publish_minor || vars.PUBLISH_MINOR }}
|
#PUBLISH_MINOR: ${{ github.event_name == 'workflow_dispatch' && inputs.publish_minor || vars.PUBLISH_MINOR }}
|
||||||
steps:
|
steps:
|
||||||
- name: Log in to Docker Hub
|
- name: Log in to Docker Hub
|
||||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
|
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
|
||||||
with:
|
with:
|
||||||
registry: docker.io
|
registry: docker.io
|
||||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||||
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
|
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
|
||||||
|
|
||||||
- name: Log in to GHCR
|
- name: Log in to GHCR
|
||||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
|
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
|
||||||
with:
|
with:
|
||||||
registry: ghcr.io
|
registry: ghcr.io
|
||||||
username: ${{ github.actor }}
|
username: ${{ github.actor }}
|
||||||
@@ -572,7 +572,7 @@ jobs:
|
|||||||
echo "DOCKERHUB_IMAGE=${DOCKERHUB_IMAGE,,}" >> "$GITHUB_ENV"
|
echo "DOCKERHUB_IMAGE=${DOCKERHUB_IMAGE,,}" >> "$GITHUB_ENV"
|
||||||
|
|
||||||
- name: Set up Docker Buildx (needed for imagetools)
|
- name: Set up Docker Buildx (needed for imagetools)
|
||||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3.12.0
|
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||||
|
|
||||||
- name: Create & push multi-arch index (GHCR :TAG) via imagetools
|
- name: Create & push multi-arch index (GHCR :TAG) via imagetools
|
||||||
shell: bash
|
shell: bash
|
||||||
@@ -656,14 +656,14 @@ jobs:
|
|||||||
go-version-file: go.mod
|
go-version-file: go.mod
|
||||||
|
|
||||||
- name: Log in to Docker Hub
|
- name: Log in to Docker Hub
|
||||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
|
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
|
||||||
with:
|
with:
|
||||||
registry: docker.io
|
registry: docker.io
|
||||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||||
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
|
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
|
||||||
|
|
||||||
- name: Log in to GHCR
|
- name: Log in to GHCR
|
||||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
|
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
|
||||||
with:
|
with:
|
||||||
registry: ghcr.io
|
registry: ghcr.io
|
||||||
username: ${{ github.actor }}
|
username: ${{ github.actor }}
|
||||||
@@ -687,7 +687,7 @@ jobs:
|
|||||||
sudo apt-get install -y jq
|
sudo apt-get install -y jq
|
||||||
|
|
||||||
- name: Set up Docker Buildx (needed for imagetools)
|
- name: Set up Docker Buildx (needed for imagetools)
|
||||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3.12.0
|
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||||
|
|
||||||
- name: Resolve multi-arch digest refs (by TAG)
|
- name: Resolve multi-arch digest refs (by TAG)
|
||||||
shell: bash
|
shell: bash
|
||||||
@@ -759,7 +759,7 @@ jobs:
|
|||||||
cosign public-key --key env://COSIGN_PRIVATE_KEY >/dev/null
|
cosign public-key --key env://COSIGN_PRIVATE_KEY >/dev/null
|
||||||
|
|
||||||
- name: Generate SBOM (SPDX JSON) from GHCR digest
|
- name: Generate SBOM (SPDX JSON) from GHCR digest
|
||||||
uses: aquasecurity/trivy-action@97e0b3872f55f89b95b2f65b3dbab56962816478 # v0.34.2
|
uses: aquasecurity/trivy-action@57a97c7e7821a5776cebc9bb87c984fa69cba8f1 # v0.35.0
|
||||||
with:
|
with:
|
||||||
image-ref: ${{ env.GHCR_REF }}
|
image-ref: ${{ env.GHCR_REF }}
|
||||||
format: spdx-json
|
format: spdx-json
|
||||||
|
|||||||
2
.github/workflows/stale-bot.yml
vendored
2
.github/workflows/stale-bot.yml
vendored
@@ -14,7 +14,7 @@ jobs:
|
|||||||
stale:
|
stale:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10.1.1
|
- uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10.2.0
|
||||||
with:
|
with:
|
||||||
days-before-stale: 14
|
days-before-stale: 14
|
||||||
days-before-close: 14
|
days-before-close: 14
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package clients
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
@@ -34,6 +36,7 @@ type WgConfig struct {
|
|||||||
IpAddress string `json:"ipAddress"`
|
IpAddress string `json:"ipAddress"`
|
||||||
Peers []Peer `json:"peers"`
|
Peers []Peer `json:"peers"`
|
||||||
Targets []Target `json:"targets"`
|
Targets []Target `json:"targets"`
|
||||||
|
ChainId string `json:"chainId"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Target struct {
|
type Target struct {
|
||||||
@@ -43,6 +46,7 @@ type Target struct {
|
|||||||
RewriteTo string `json:"rewriteTo,omitempty"`
|
RewriteTo string `json:"rewriteTo,omitempty"`
|
||||||
DisableIcmp bool `json:"disableIcmp,omitempty"`
|
DisableIcmp bool `json:"disableIcmp,omitempty"`
|
||||||
PortRange []PortRange `json:"portRange,omitempty"`
|
PortRange []PortRange `json:"portRange,omitempty"`
|
||||||
|
ResourceId int `json:"resourceId,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type PortRange struct {
|
type PortRange struct {
|
||||||
@@ -83,6 +87,7 @@ type WireGuardService struct {
|
|||||||
serverPubKey string
|
serverPubKey string
|
||||||
token string
|
token string
|
||||||
stopGetConfig func()
|
stopGetConfig func()
|
||||||
|
pendingConfigChainId string
|
||||||
// Netstack fields
|
// Netstack fields
|
||||||
tun tun.Device
|
tun tun.Device
|
||||||
tnet *netstack2.Net
|
tnet *netstack2.Net
|
||||||
@@ -107,6 +112,13 @@ type WireGuardService struct {
|
|||||||
wgTesterServer *wgtester.Server
|
wgTesterServer *wgtester.Server
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// generateChainId generates a random chain ID for deduplicating round-trip messages.
|
||||||
|
func generateChainId() string {
|
||||||
|
b := make([]byte, 8)
|
||||||
|
_, _ = rand.Read(b)
|
||||||
|
return hex.EncodeToString(b)
|
||||||
|
}
|
||||||
|
|
||||||
func NewWireGuardService(interfaceName string, port uint16, mtu int, host string, newtId string, wsClient *websocket.Client, dns string, useNativeInterface bool) (*WireGuardService, error) {
|
func NewWireGuardService(interfaceName string, port uint16, mtu int, host string, newtId string, wsClient *websocket.Client, dns string, useNativeInterface bool) (*WireGuardService, error) {
|
||||||
key, err := wgtypes.GeneratePrivateKey()
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -195,6 +207,15 @@ func (s *WireGuardService) Close() {
|
|||||||
s.stopGetConfig = nil
|
s.stopGetConfig = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Flush access logs before tearing down the tunnel
|
||||||
|
if s.tnet != nil {
|
||||||
|
if ph := s.tnet.GetProxyHandler(); ph != nil {
|
||||||
|
if al := ph.GetAccessLogger(); al != nil {
|
||||||
|
al.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Stop the direct UDP relay first
|
// Stop the direct UDP relay first
|
||||||
s.StopDirectUDPRelay()
|
s.StopDirectUDPRelay()
|
||||||
|
|
||||||
@@ -441,9 +462,12 @@ func (s *WireGuardService) LoadRemoteConfig() error {
|
|||||||
s.stopGetConfig()
|
s.stopGetConfig()
|
||||||
s.stopGetConfig = nil
|
s.stopGetConfig = nil
|
||||||
}
|
}
|
||||||
|
chainId := generateChainId()
|
||||||
|
s.pendingConfigChainId = chainId
|
||||||
s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{
|
s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{
|
||||||
"publicKey": s.key.PublicKey().String(),
|
"publicKey": s.key.PublicKey().String(),
|
||||||
"port": s.Port,
|
"port": s.Port,
|
||||||
|
"chainId": chainId,
|
||||||
}, 2*time.Second)
|
}, 2*time.Second)
|
||||||
|
|
||||||
logger.Debug("Requesting WireGuard configuration from remote server")
|
logger.Debug("Requesting WireGuard configuration from remote server")
|
||||||
@@ -468,6 +492,17 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
|
|||||||
logger.Info("Error unmarshaling target data: %v", err)
|
logger.Info("Error unmarshaling target data: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Deduplicate using chainId: discard responses that don't match the
|
||||||
|
// pending request, or that we have already processed.
|
||||||
|
if config.ChainId != "" {
|
||||||
|
if config.ChainId != s.pendingConfigChainId {
|
||||||
|
logger.Debug("Discarding duplicate/stale newt/wg/get-config response (chainId=%s, expected=%s)", config.ChainId, s.pendingConfigChainId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.pendingConfigChainId = "" // consume – further duplicates are rejected
|
||||||
|
}
|
||||||
|
|
||||||
s.config = config
|
s.config = config
|
||||||
|
|
||||||
if s.stopGetConfig != nil {
|
if s.stopGetConfig != nil {
|
||||||
@@ -662,7 +697,7 @@ func (s *WireGuardService) syncTargets(desiredTargets []Target) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
|
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId)
|
||||||
logger.Info("Added target %s -> %s during sync", target.SourcePrefix, target.DestPrefix)
|
logger.Info("Added target %s -> %s during sync", target.SourcePrefix, target.DestPrefix)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -793,6 +828,13 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
|||||||
|
|
||||||
s.TunnelIP = tunnelIP.String()
|
s.TunnelIP = tunnelIP.String()
|
||||||
|
|
||||||
|
// Configure the access log sender to ship compressed session logs via websocket
|
||||||
|
s.tnet.SetAccessLogSender(func(data string) error {
|
||||||
|
return s.client.SendMessageNoLog("newt/access-log", map[string]interface{}{
|
||||||
|
"compressed": data,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
// Create WireGuard device using the shared bind
|
// Create WireGuard device using the shared bind
|
||||||
s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger(
|
s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger(
|
||||||
device.LogLevelSilent, // Use silent logging by default - could be made configurable
|
device.LogLevelSilent, // Use silent logging by default - could be made configurable
|
||||||
@@ -913,7 +955,7 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid CIDR %s: %v", sp, err)
|
return fmt.Errorf("invalid CIDR %s: %v", sp, err)
|
||||||
}
|
}
|
||||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
|
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId)
|
||||||
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
|
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1306,7 +1348,7 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
|
|||||||
logger.Info("Invalid CIDR %s: %v", sp, err)
|
logger.Info("Invalid CIDR %s: %v", sp, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
|
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId)
|
||||||
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
|
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1424,7 +1466,7 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
|
|||||||
logger.Info("Invalid CIDR %s: %v", sp, err)
|
logger.Info("Invalid CIDR %s: %v", sp, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
|
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId)
|
||||||
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
|
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
42
common.go
42
common.go
@@ -8,6 +8,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -285,11 +286,18 @@ func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Clien
|
|||||||
if tunnelID != "" {
|
if tunnelID != "" {
|
||||||
telemetry.IncReconnect(context.Background(), tunnelID, "client", telemetry.ReasonTimeout)
|
telemetry.IncReconnect(context.Background(), tunnelID, "client", telemetry.ReasonTimeout)
|
||||||
}
|
}
|
||||||
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second)
|
pingChainId := generateChainId()
|
||||||
|
pendingPingChainId = pingChainId
|
||||||
|
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{
|
||||||
|
"chainId": pingChainId,
|
||||||
|
}, 3*time.Second)
|
||||||
// Send registration message to the server for backward compatibility
|
// Send registration message to the server for backward compatibility
|
||||||
|
bcChainId := generateChainId()
|
||||||
|
pendingRegisterChainId = bcChainId
|
||||||
err := client.SendMessage("newt/wg/register", map[string]interface{}{
|
err := client.SendMessage("newt/wg/register", map[string]interface{}{
|
||||||
"publicKey": publicKey.String(),
|
"publicKey": publicKey.String(),
|
||||||
"backwardsCompatible": true,
|
"backwardsCompatible": true,
|
||||||
|
"chainId": bcChainId,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to send registration message: %v", err)
|
logger.Error("Failed to send registration message: %v", err)
|
||||||
@@ -509,15 +517,41 @@ func executeUpdownScript(action, proto, target string) (string, error) {
|
|||||||
return target, nil
|
return target, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func sendBlueprint(client *websocket.Client) error {
|
// interpolateBlueprint finds all {{...}} tokens in the raw blueprint bytes and
|
||||||
if blueprintFile == "" {
|
// replaces recognised schemes with their resolved values. Currently supported:
|
||||||
|
//
|
||||||
|
// - env.<VAR> – replaced with the value of the named environment variable
|
||||||
|
//
|
||||||
|
// Any token that does not match a supported scheme is left as-is so that
|
||||||
|
// future schemes (e.g. tag., api.) are preserved rather than silently dropped.
|
||||||
|
func interpolateBlueprint(data []byte) []byte {
|
||||||
|
re := regexp.MustCompile(`\{\{([^}]+)\}\}`)
|
||||||
|
return re.ReplaceAllFunc(data, func(match []byte) []byte {
|
||||||
|
// strip the surrounding {{ }}
|
||||||
|
inner := strings.TrimSpace(string(match[2 : len(match)-2]))
|
||||||
|
|
||||||
|
if strings.HasPrefix(inner, "env.") {
|
||||||
|
varName := strings.TrimPrefix(inner, "env.")
|
||||||
|
return []byte(os.Getenv(varName))
|
||||||
|
}
|
||||||
|
|
||||||
|
// unrecognised scheme – leave the token untouched
|
||||||
|
return match
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendBlueprint(client *websocket.Client, file string) error {
|
||||||
|
if file == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// try to read the blueprint file
|
// try to read the blueprint file
|
||||||
blueprintData, err := os.ReadFile(blueprintFile)
|
blueprintData, err := os.ReadFile(file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to read blueprint file: %v", err)
|
logger.Error("Failed to read blueprint file: %v", err)
|
||||||
} else {
|
} else {
|
||||||
|
// interpolate {{env.VAR}} (and any future schemes) before parsing
|
||||||
|
blueprintData = interpolateBlueprint(blueprintData)
|
||||||
|
|
||||||
// first we should convert the yaml to json and error if the yaml is bad
|
// first we should convert the yaml to json and error if the yaml is bad
|
||||||
var yamlObj interface{}
|
var yamlObj interface{}
|
||||||
var blueprintJsonData string
|
var blueprintJsonData string
|
||||||
|
|||||||
@@ -35,7 +35,7 @@
|
|||||||
inherit version;
|
inherit version;
|
||||||
src = pkgs.nix-gitignore.gitignoreSource [ ] ./.;
|
src = pkgs.nix-gitignore.gitignoreSource [ ] ./.;
|
||||||
|
|
||||||
vendorHash = "sha256-kmQM8Yy5TuOiNpMpUme/2gfE+vrhUK+0AphN+p71wGs=";
|
vendorHash = "sha256-YIcuj1S+ZWAzXZOMZbppTvsDcW1W1Sy8ynfMkzLMQpM=";
|
||||||
|
|
||||||
nativeInstallCheckInputs = [ pkgs.versionCheckHook ];
|
nativeInstallCheckInputs = [ pkgs.versionCheckHook ];
|
||||||
|
|
||||||
|
|||||||
4
go.mod
4
go.mod
@@ -4,7 +4,7 @@ go 1.25.0
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/docker/docker v28.5.2+incompatible
|
github.com/docker/docker v28.5.2+incompatible
|
||||||
github.com/gaissmai/bart v0.26.0
|
github.com/gaissmai/bart v0.26.1
|
||||||
github.com/gorilla/websocket v1.5.3
|
github.com/gorilla/websocket v1.5.3
|
||||||
github.com/prometheus/client_golang v1.23.2
|
github.com/prometheus/client_golang v1.23.2
|
||||||
github.com/vishvananda/netlink v1.3.1
|
github.com/vishvananda/netlink v1.3.1
|
||||||
@@ -24,7 +24,7 @@ require (
|
|||||||
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
|
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||||
google.golang.org/grpc v1.79.1
|
google.golang.org/grpc v1.79.3
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c
|
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c
|
||||||
software.sslmate.com/src/go-pkcs12 v0.7.0
|
software.sslmate.com/src/go-pkcs12 v0.7.0
|
||||||
|
|||||||
8
go.sum
8
go.sum
@@ -26,8 +26,8 @@ github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw
|
|||||||
github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||||
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||||
github.com/gaissmai/bart v0.26.0 h1:xOZ57E9hJLBiQaSyeZa9wgWhGuzfGACgqp4BE77OkO0=
|
github.com/gaissmai/bart v0.26.1 h1:+w4rnLGNlA2GDVn382Tfe3jOsK5vOr5n4KmigJ9lbTo=
|
||||||
github.com/gaissmai/bart v0.26.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
|
github.com/gaissmai/bart v0.26.1/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
|
||||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||||
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||||
@@ -159,8 +159,8 @@ google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57 h1:
|
|||||||
google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57/go.mod h1:kSJwQxqmFXeo79zOmbrALdflXQeAYcUbgS7PbpMknCY=
|
google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57/go.mod h1:kSJwQxqmFXeo79zOmbrALdflXQeAYcUbgS7PbpMknCY=
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57 h1:mWPCjDEyshlQYzBpMNHaEof6UX1PmHcaUODUywQ0uac=
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57 h1:mWPCjDEyshlQYzBpMNHaEof6UX1PmHcaUODUywQ0uac=
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ=
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ=
|
||||||
google.golang.org/grpc v1.79.1 h1:zGhSi45ODB9/p3VAawt9a+O/MULLl9dpizzNNpq7flY=
|
google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE=
|
||||||
google.golang.org/grpc v1.79.1/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ=
|
google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ=
|
||||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
|||||||
@@ -5,7 +5,9 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -365,11 +367,12 @@ func (m *Monitor) performHealthCheck(target *Target) {
|
|||||||
target.LastCheck = time.Now()
|
target.LastCheck = time.Now()
|
||||||
target.LastError = ""
|
target.LastError = ""
|
||||||
|
|
||||||
// Build URL
|
// Build URL (use net.JoinHostPort to properly handle IPv6 addresses with ports)
|
||||||
url := fmt.Sprintf("%s://%s", target.Config.Scheme, target.Config.Hostname)
|
host := target.Config.Hostname
|
||||||
if target.Config.Port > 0 {
|
if target.Config.Port > 0 {
|
||||||
url = fmt.Sprintf("%s:%d", url, target.Config.Port)
|
host = net.JoinHostPort(target.Config.Hostname, strconv.Itoa(target.Config.Port))
|
||||||
}
|
}
|
||||||
|
url := fmt.Sprintf("%s://%s", target.Config.Scheme, host)
|
||||||
if target.Config.Path != "" {
|
if target.Config.Path != "" {
|
||||||
if !strings.HasPrefix(target.Config.Path, "/") {
|
if !strings.HasPrefix(target.Config.Path, "/") {
|
||||||
url += "/"
|
url += "/"
|
||||||
|
|||||||
140
main.go
140
main.go
@@ -3,13 +3,16 @@ package main
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/pprof"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
@@ -45,6 +48,7 @@ type WgData struct {
|
|||||||
TunnelIP string `json:"tunnelIP"`
|
TunnelIP string `json:"tunnelIP"`
|
||||||
Targets TargetsByType `json:"targets"`
|
Targets TargetsByType `json:"targets"`
|
||||||
HealthCheckTargets []healthcheck.Config `json:"healthCheckTargets"`
|
HealthCheckTargets []healthcheck.Config `json:"healthCheckTargets"`
|
||||||
|
ChainId string `json:"chainId"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type TargetsByType struct {
|
type TargetsByType struct {
|
||||||
@@ -58,6 +62,7 @@ type TargetData struct {
|
|||||||
|
|
||||||
type ExitNodeData struct {
|
type ExitNodeData struct {
|
||||||
ExitNodes []ExitNode `json:"exitNodes"`
|
ExitNodes []ExitNode `json:"exitNodes"`
|
||||||
|
ChainId string `json:"chainId"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExitNode represents an exit node with an ID, endpoint, and weight.
|
// ExitNode represents an exit node with an ID, endpoint, and weight.
|
||||||
@@ -127,6 +132,8 @@ var (
|
|||||||
publicKey wgtypes.Key
|
publicKey wgtypes.Key
|
||||||
pingStopChan chan struct{}
|
pingStopChan chan struct{}
|
||||||
stopFunc func()
|
stopFunc func()
|
||||||
|
pendingRegisterChainId string
|
||||||
|
pendingPingChainId string
|
||||||
healthFile string
|
healthFile string
|
||||||
useNativeInterface bool
|
useNativeInterface bool
|
||||||
authorizedKeysFile string
|
authorizedKeysFile string
|
||||||
@@ -147,7 +154,9 @@ var (
|
|||||||
adminAddr string
|
adminAddr string
|
||||||
region string
|
region string
|
||||||
metricsAsyncBytes bool
|
metricsAsyncBytes bool
|
||||||
|
pprofEnabled bool
|
||||||
blueprintFile string
|
blueprintFile string
|
||||||
|
provisioningBlueprintFile string
|
||||||
noCloud bool
|
noCloud bool
|
||||||
|
|
||||||
// New mTLS configuration variables
|
// New mTLS configuration variables
|
||||||
@@ -157,8 +166,24 @@ var (
|
|||||||
|
|
||||||
// Legacy PKCS12 support (deprecated)
|
// Legacy PKCS12 support (deprecated)
|
||||||
tlsPrivateKey string
|
tlsPrivateKey string
|
||||||
|
|
||||||
|
// Provisioning key – exchanged once for a permanent newt ID + secret
|
||||||
|
provisioningKey string
|
||||||
|
|
||||||
|
// Optional name for the site created during provisioning
|
||||||
|
newtName string
|
||||||
|
|
||||||
|
// Path to config file (overrides CONFIG_FILE env var and default location)
|
||||||
|
configFile string
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// generateChainId generates a random chain ID for deduplicating round-trip messages.
|
||||||
|
func generateChainId() string {
|
||||||
|
b := make([]byte, 8)
|
||||||
|
_, _ = rand.Read(b)
|
||||||
|
return hex.EncodeToString(b)
|
||||||
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// Check for subcommands first (only principals exits early)
|
// Check for subcommands first (only principals exits early)
|
||||||
if len(os.Args) > 1 {
|
if len(os.Args) > 1 {
|
||||||
@@ -225,6 +250,7 @@ func runNewtMain(ctx context.Context) {
|
|||||||
adminAddrEnv := os.Getenv("NEWT_ADMIN_ADDR")
|
adminAddrEnv := os.Getenv("NEWT_ADMIN_ADDR")
|
||||||
regionEnv := os.Getenv("NEWT_REGION")
|
regionEnv := os.Getenv("NEWT_REGION")
|
||||||
asyncBytesEnv := os.Getenv("NEWT_METRICS_ASYNC_BYTES")
|
asyncBytesEnv := os.Getenv("NEWT_METRICS_ASYNC_BYTES")
|
||||||
|
pprofEnabledEnv := os.Getenv("NEWT_PPROF_ENABLED")
|
||||||
|
|
||||||
disableClientsEnv := os.Getenv("DISABLE_CLIENTS")
|
disableClientsEnv := os.Getenv("DISABLE_CLIENTS")
|
||||||
disableClients = disableClientsEnv == "true"
|
disableClients = disableClientsEnv == "true"
|
||||||
@@ -259,8 +285,12 @@ func runNewtMain(ctx context.Context) {
|
|||||||
tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT")
|
tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT")
|
||||||
}
|
}
|
||||||
blueprintFile = os.Getenv("BLUEPRINT_FILE")
|
blueprintFile = os.Getenv("BLUEPRINT_FILE")
|
||||||
|
provisioningBlueprintFile = os.Getenv("PROVISIONING_BLUEPRINT_FILE")
|
||||||
noCloudEnv := os.Getenv("NO_CLOUD")
|
noCloudEnv := os.Getenv("NO_CLOUD")
|
||||||
noCloud = noCloudEnv == "true"
|
noCloud = noCloudEnv == "true"
|
||||||
|
provisioningKey = os.Getenv("NEWT_PROVISIONING_KEY")
|
||||||
|
newtName = os.Getenv("NEWT_NAME")
|
||||||
|
configFile = os.Getenv("CONFIG_FILE")
|
||||||
|
|
||||||
if endpoint == "" {
|
if endpoint == "" {
|
||||||
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
|
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
|
||||||
@@ -309,6 +339,15 @@ func runNewtMain(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
// load the prefer endpoint just as a flag
|
// load the prefer endpoint just as a flag
|
||||||
flag.StringVar(&preferEndpoint, "prefer-endpoint", "", "Prefer this endpoint for the connection (if set, will override the endpoint from the server)")
|
flag.StringVar(&preferEndpoint, "prefer-endpoint", "", "Prefer this endpoint for the connection (if set, will override the endpoint from the server)")
|
||||||
|
if provisioningKey == "" {
|
||||||
|
flag.StringVar(&provisioningKey, "provisioning-key", "", "One-time provisioning key used to obtain a newt ID and secret from the server")
|
||||||
|
}
|
||||||
|
if newtName == "" {
|
||||||
|
flag.StringVar(&newtName, "name", "", "Name for the site created during provisioning (supports {{env.VAR}} interpolation)")
|
||||||
|
}
|
||||||
|
if configFile == "" {
|
||||||
|
flag.StringVar(&configFile, "config-file", "", "Path to config file (overrides CONFIG_FILE env var and default location)")
|
||||||
|
}
|
||||||
|
|
||||||
// Add new mTLS flags
|
// Add new mTLS flags
|
||||||
if tlsClientCert == "" {
|
if tlsClientCert == "" {
|
||||||
@@ -356,6 +395,9 @@ func runNewtMain(ctx context.Context) {
|
|||||||
if blueprintFile == "" {
|
if blueprintFile == "" {
|
||||||
flag.StringVar(&blueprintFile, "blueprint-file", "", "Path to blueprint file (if unset, no blueprint will be applied)")
|
flag.StringVar(&blueprintFile, "blueprint-file", "", "Path to blueprint file (if unset, no blueprint will be applied)")
|
||||||
}
|
}
|
||||||
|
if provisioningBlueprintFile == "" {
|
||||||
|
flag.StringVar(&provisioningBlueprintFile, "provisioning-blueprint-file", "", "Path to blueprint file applied once after a provisioning credential exchange (if unset, no provisioning blueprint will be applied)")
|
||||||
|
}
|
||||||
if noCloudEnv == "" {
|
if noCloudEnv == "" {
|
||||||
flag.BoolVar(&noCloud, "no-cloud", false, "Disable cloud failover")
|
flag.BoolVar(&noCloud, "no-cloud", false, "Disable cloud failover")
|
||||||
}
|
}
|
||||||
@@ -390,6 +432,14 @@ func runNewtMain(ctx context.Context) {
|
|||||||
metricsAsyncBytes = v
|
metricsAsyncBytes = v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// pprof debug endpoint toggle
|
||||||
|
if pprofEnabledEnv == "" {
|
||||||
|
flag.BoolVar(&pprofEnabled, "pprof", false, "Enable pprof debug endpoints on admin server")
|
||||||
|
} else {
|
||||||
|
if v, err := strconv.ParseBool(pprofEnabledEnv); err == nil {
|
||||||
|
pprofEnabled = v
|
||||||
|
}
|
||||||
|
}
|
||||||
// Optional region flag (resource attribute)
|
// Optional region flag (resource attribute)
|
||||||
if regionEnv == "" {
|
if regionEnv == "" {
|
||||||
flag.StringVar(®ion, "region", "", "Optional region resource attribute (also NEWT_REGION)")
|
flag.StringVar(®ion, "region", "", "Optional region resource attribute (also NEWT_REGION)")
|
||||||
@@ -485,6 +535,14 @@ func runNewtMain(ctx context.Context) {
|
|||||||
if tel.PrometheusHandler != nil {
|
if tel.PrometheusHandler != nil {
|
||||||
mux.Handle("/metrics", tel.PrometheusHandler)
|
mux.Handle("/metrics", tel.PrometheusHandler)
|
||||||
}
|
}
|
||||||
|
if pprofEnabled {
|
||||||
|
mux.HandleFunc("/debug/pprof/", pprof.Index)
|
||||||
|
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
|
||||||
|
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
|
||||||
|
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
|
||||||
|
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
|
||||||
|
logger.Info("pprof debugging enabled on %s/debug/pprof/", tcfg.AdminAddr)
|
||||||
|
}
|
||||||
admin := &http.Server{
|
admin := &http.Server{
|
||||||
Addr: tcfg.AdminAddr,
|
Addr: tcfg.AdminAddr,
|
||||||
Handler: otelhttp.NewHandler(mux, "newt-admin"),
|
Handler: otelhttp.NewHandler(mux, "newt-admin"),
|
||||||
@@ -567,10 +625,20 @@ func runNewtMain(ctx context.Context) {
|
|||||||
endpoint,
|
endpoint,
|
||||||
30*time.Second,
|
30*time.Second,
|
||||||
opt,
|
opt,
|
||||||
|
websocket.WithConfigFile(configFile),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Failed to create client: %v", err)
|
logger.Fatal("Failed to create client: %v", err)
|
||||||
}
|
}
|
||||||
|
// If a provisioning key was supplied via CLI / env and the config file did
|
||||||
|
// not already carry one, inject it now so provisionIfNeeded() can use it.
|
||||||
|
if provisioningKey != "" && client.GetConfig().ProvisioningKey == "" {
|
||||||
|
client.GetConfig().ProvisioningKey = provisioningKey
|
||||||
|
}
|
||||||
|
if newtName != "" && client.GetConfig().Name == "" {
|
||||||
|
client.GetConfig().Name = newtName
|
||||||
|
}
|
||||||
|
|
||||||
endpoint = client.GetConfig().Endpoint // Update endpoint from config
|
endpoint = client.GetConfig().Endpoint // Update endpoint from config
|
||||||
id = client.GetConfig().ID // Update ID from config
|
id = client.GetConfig().ID // Update ID from config
|
||||||
// Update site labels for metrics with the resolved ID
|
// Update site labels for metrics with the resolved ID
|
||||||
@@ -687,6 +755,24 @@ func runNewtMain(ctx context.Context) {
|
|||||||
defer func() {
|
defer func() {
|
||||||
telemetry.IncSiteRegistration(ctx, regResult)
|
telemetry.IncSiteRegistration(ctx, regResult)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// Deduplicate using chainId: if the server echoes back a chainId we have
|
||||||
|
// already consumed (or one that doesn't match our current pending request),
|
||||||
|
// throw the message away to avoid setting up the tunnel twice.
|
||||||
|
var chainData struct {
|
||||||
|
ChainId string `json:"chainId"`
|
||||||
|
}
|
||||||
|
if jsonBytes, err := json.Marshal(msg.Data); err == nil {
|
||||||
|
_ = json.Unmarshal(jsonBytes, &chainData)
|
||||||
|
}
|
||||||
|
if chainData.ChainId != "" {
|
||||||
|
if chainData.ChainId != pendingRegisterChainId {
|
||||||
|
logger.Debug("Discarding duplicate/stale newt/wg/connect (chainId=%s, expected=%s)", chainData.ChainId, pendingRegisterChainId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pendingRegisterChainId = "" // consume – further duplicates with this id are rejected
|
||||||
|
}
|
||||||
|
|
||||||
if stopFunc != nil {
|
if stopFunc != nil {
|
||||||
stopFunc() // stop the ws from sending more requests
|
stopFunc() // stop the ws from sending more requests
|
||||||
stopFunc = nil // reset stopFunc to nil to avoid double stopping
|
stopFunc = nil // reset stopFunc to nil to avoid double stopping
|
||||||
@@ -871,8 +957,11 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Request exit nodes from the server
|
// Request exit nodes from the server
|
||||||
|
pingChainId := generateChainId()
|
||||||
|
pendingPingChainId = pingChainId
|
||||||
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{
|
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{
|
||||||
"noCloud": noCloud,
|
"noCloud": noCloud,
|
||||||
|
"chainId": pingChainId,
|
||||||
}, 3*time.Second)
|
}, 3*time.Second)
|
||||||
|
|
||||||
logger.Info("Tunnel destroyed, ready for reconnection")
|
logger.Info("Tunnel destroyed, ready for reconnection")
|
||||||
@@ -901,6 +990,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
|||||||
|
|
||||||
client.RegisterHandler("newt/ping/exitNodes", func(msg websocket.WSMessage) {
|
client.RegisterHandler("newt/ping/exitNodes", func(msg websocket.WSMessage) {
|
||||||
logger.Debug("Received ping message")
|
logger.Debug("Received ping message")
|
||||||
|
|
||||||
if stopFunc != nil {
|
if stopFunc != nil {
|
||||||
stopFunc() // stop the ws from sending more requests
|
stopFunc() // stop the ws from sending more requests
|
||||||
stopFunc = nil // reset stopFunc to nil to avoid double stopping
|
stopFunc = nil // reset stopFunc to nil to avoid double stopping
|
||||||
@@ -920,6 +1010,14 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
|||||||
}
|
}
|
||||||
exitNodes := exitNodeData.ExitNodes
|
exitNodes := exitNodeData.ExitNodes
|
||||||
|
|
||||||
|
if exitNodeData.ChainId != "" {
|
||||||
|
if exitNodeData.ChainId != pendingPingChainId {
|
||||||
|
logger.Debug("Discarding duplicate/stale newt/ping/exitNodes (chainId=%s, expected=%s)", exitNodeData.ChainId, pendingPingChainId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pendingPingChainId = "" // consume – further duplicates with this id are rejected
|
||||||
|
}
|
||||||
|
|
||||||
if len(exitNodes) == 0 {
|
if len(exitNodes) == 0 {
|
||||||
logger.Info("No exit nodes provided")
|
logger.Info("No exit nodes provided")
|
||||||
return
|
return
|
||||||
@@ -952,10 +1050,13 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
chainId := generateChainId()
|
||||||
|
pendingRegisterChainId = chainId
|
||||||
stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{
|
stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{
|
||||||
"publicKey": publicKey.String(),
|
"publicKey": publicKey.String(),
|
||||||
"pingResults": pingResults,
|
"pingResults": pingResults,
|
||||||
"newtVersion": newtVersion,
|
"newtVersion": newtVersion,
|
||||||
|
"chainId": chainId,
|
||||||
}, 2*time.Second)
|
}, 2*time.Second)
|
||||||
|
|
||||||
return
|
return
|
||||||
@@ -1055,10 +1156,13 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Send the ping results to the cloud for selection
|
// Send the ping results to the cloud for selection
|
||||||
|
chainId := generateChainId()
|
||||||
|
pendingRegisterChainId = chainId
|
||||||
stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{
|
stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{
|
||||||
"publicKey": publicKey.String(),
|
"publicKey": publicKey.String(),
|
||||||
"pingResults": pingResults,
|
"pingResults": pingResults,
|
||||||
"newtVersion": newtVersion,
|
"newtVersion": newtVersion,
|
||||||
|
"chainId": chainId,
|
||||||
}, 2*time.Second)
|
}, 2*time.Second)
|
||||||
|
|
||||||
logger.Debug("Sent exit node ping results to cloud for selection: pingResults=%+v", pingResults)
|
logger.Debug("Sent exit node ping results to cloud for selection: pingResults=%+v", pingResults)
|
||||||
@@ -1708,8 +1812,11 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
|||||||
stopFunc()
|
stopFunc()
|
||||||
}
|
}
|
||||||
// request from the server the list of nodes to ping
|
// request from the server the list of nodes to ping
|
||||||
|
pingChainId := generateChainId()
|
||||||
|
pendingPingChainId = pingChainId
|
||||||
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{
|
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{
|
||||||
"noCloud": noCloud,
|
"noCloud": noCloud,
|
||||||
|
"chainId": pingChainId,
|
||||||
}, 3*time.Second)
|
}, 3*time.Second)
|
||||||
logger.Debug("Requesting exit nodes from server")
|
logger.Debug("Requesting exit nodes from server")
|
||||||
|
|
||||||
@@ -1718,17 +1825,46 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
|||||||
} else {
|
} else {
|
||||||
logger.Warn("CLIENTS WILL NOT WORK ON THIS VERSION OF NEWT WITH THIS VERSION OF PANGOLIN, PLEASE UPDATE THE SERVER TO 1.13 OR HIGHER OR DOWNGRADE NEWT")
|
logger.Warn("CLIENTS WILL NOT WORK ON THIS VERSION OF NEWT WITH THIS VERSION OF PANGOLIN, PLEASE UPDATE THE SERVER TO 1.13 OR HIGHER OR DOWNGRADE NEWT")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sendBlueprint(client, blueprintFile)
|
||||||
|
if client.WasJustProvisioned() {
|
||||||
|
logger.Info("Provisioning detected – sending provisioning blueprint")
|
||||||
|
sendBlueprint(client, provisioningBlueprintFile)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Resend current health check status for all targets in case the server
|
||||||
|
// missed updates while newt was disconnected.
|
||||||
|
targets := healthMonitor.GetTargets()
|
||||||
|
if len(targets) > 0 {
|
||||||
|
healthStatuses := make(map[int]interface{})
|
||||||
|
for id, target := range targets {
|
||||||
|
healthStatuses[id] = map[string]interface{}{
|
||||||
|
"status": target.Status.String(),
|
||||||
|
"lastCheck": target.LastCheck.Format(time.RFC3339),
|
||||||
|
"checkCount": target.CheckCount,
|
||||||
|
"lastError": target.LastError,
|
||||||
|
"config": target.Config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
logger.Debug("Reconnected: resending health check status for %d targets", len(healthStatuses))
|
||||||
|
if err := client.SendMessage("newt/healthcheck/status", map[string]interface{}{
|
||||||
|
"targets": healthStatuses,
|
||||||
|
}); err != nil {
|
||||||
|
logger.Error("Failed to resend health check status on reconnect: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send registration message to the server for backward compatibility
|
// Send registration message to the server for backward compatibility
|
||||||
|
bcChainId := generateChainId()
|
||||||
|
pendingRegisterChainId = bcChainId
|
||||||
err := client.SendMessage(topicWGRegister, map[string]interface{}{
|
err := client.SendMessage(topicWGRegister, map[string]interface{}{
|
||||||
"publicKey": publicKey.String(),
|
"publicKey": publicKey.String(),
|
||||||
"newtVersion": newtVersion,
|
"newtVersion": newtVersion,
|
||||||
"backwardsCompatible": true,
|
"backwardsCompatible": true,
|
||||||
|
"chainId": bcChainId,
|
||||||
})
|
})
|
||||||
|
|
||||||
sendBlueprint(client)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to send registration message: %v", err)
|
logger.Error("Failed to send registration message: %v", err)
|
||||||
return err
|
return err
|
||||||
|
|||||||
514
netstack2/access_log.go
Normal file
514
netstack2/access_log.go
Normal file
@@ -0,0 +1,514 @@
|
|||||||
|
package netstack2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"compress/zlib"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"net"
|
||||||
|
"sort"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// flushInterval is how often the access logger flushes completed sessions to the server
|
||||||
|
flushInterval = 60 * time.Second
|
||||||
|
|
||||||
|
// maxBufferedSessions is the max number of completed sessions to buffer before forcing a flush
|
||||||
|
maxBufferedSessions = 100
|
||||||
|
|
||||||
|
// sessionGapThreshold is the maximum gap between the end of one connection
|
||||||
|
// and the start of the next for them to be considered part of the same session.
|
||||||
|
// If the gap exceeds this, a new consolidated session is created.
|
||||||
|
sessionGapThreshold = 5 * time.Second
|
||||||
|
|
||||||
|
// minConnectionsToConsolidate is the minimum number of connections in a group
|
||||||
|
// before we bother consolidating. Groups smaller than this are sent as-is.
|
||||||
|
minConnectionsToConsolidate = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
// SendFunc is a callback that sends compressed access log data to the server.
|
||||||
|
// The data is a base64-encoded zlib-compressed JSON array of AccessSession objects.
|
||||||
|
type SendFunc func(data string) error
|
||||||
|
|
||||||
|
// AccessSession represents a tracked access session through the proxy
|
||||||
|
type AccessSession struct {
|
||||||
|
SessionID string `json:"sessionId"`
|
||||||
|
ResourceID int `json:"resourceId"`
|
||||||
|
SourceAddr string `json:"sourceAddr"`
|
||||||
|
DestAddr string `json:"destAddr"`
|
||||||
|
Protocol string `json:"protocol"`
|
||||||
|
StartedAt time.Time `json:"startedAt"`
|
||||||
|
EndedAt time.Time `json:"endedAt,omitempty"`
|
||||||
|
BytesTx int64 `json:"bytesTx"`
|
||||||
|
BytesRx int64 `json:"bytesRx"`
|
||||||
|
ConnectionCount int `json:"connectionCount,omitempty"` // number of raw connections merged into this session (0 or 1 = single)
|
||||||
|
}
|
||||||
|
|
||||||
|
// udpSessionKey identifies a unique UDP "session" by src -> dst
|
||||||
|
type udpSessionKey struct {
|
||||||
|
srcAddr string
|
||||||
|
dstAddr string
|
||||||
|
protocol string
|
||||||
|
}
|
||||||
|
|
||||||
|
// consolidationKey groups connections that may be part of the same logical session.
|
||||||
|
// Source port is intentionally excluded so that many ephemeral-port connections
|
||||||
|
// from the same source IP to the same destination are grouped together.
|
||||||
|
type consolidationKey struct {
|
||||||
|
sourceIP string // IP only, no port
|
||||||
|
destAddr string // full host:port of the destination
|
||||||
|
protocol string
|
||||||
|
resourceID int
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccessLogger tracks access sessions for resources and periodically
|
||||||
|
// flushes completed sessions to the server via a configurable SendFunc.
|
||||||
|
type AccessLogger struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
sessions map[string]*AccessSession // active sessions: sessionID -> session
|
||||||
|
udpSessions map[udpSessionKey]*AccessSession // active UDP sessions for dedup
|
||||||
|
completedSessions []*AccessSession // completed sessions waiting to be flushed
|
||||||
|
udpTimeout time.Duration
|
||||||
|
sendFn SendFunc
|
||||||
|
stopCh chan struct{}
|
||||||
|
flushDone chan struct{} // closed after the flush goroutine exits
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAccessLogger creates a new access logger.
|
||||||
|
// udpTimeout controls how long a UDP session is kept alive without traffic before being ended.
|
||||||
|
func NewAccessLogger(udpTimeout time.Duration) *AccessLogger {
|
||||||
|
al := &AccessLogger{
|
||||||
|
sessions: make(map[string]*AccessSession),
|
||||||
|
udpSessions: make(map[udpSessionKey]*AccessSession),
|
||||||
|
completedSessions: make([]*AccessSession, 0),
|
||||||
|
udpTimeout: udpTimeout,
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
|
flushDone: make(chan struct{}),
|
||||||
|
}
|
||||||
|
go al.backgroundLoop()
|
||||||
|
return al
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSendFunc sets the callback used to send compressed access log batches
|
||||||
|
// to the server. This can be called after construction once the websocket
|
||||||
|
// client is available.
|
||||||
|
func (al *AccessLogger) SetSendFunc(fn SendFunc) {
|
||||||
|
al.mu.Lock()
|
||||||
|
defer al.mu.Unlock()
|
||||||
|
al.sendFn = fn
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateSessionID creates a random session identifier
|
||||||
|
func generateSessionID() string {
|
||||||
|
b := make([]byte, 8)
|
||||||
|
rand.Read(b)
|
||||||
|
return hex.EncodeToString(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartTCPSession logs the start of a TCP session and returns a session ID.
|
||||||
|
func (al *AccessLogger) StartTCPSession(resourceID int, srcAddr, dstAddr string) string {
|
||||||
|
sessionID := generateSessionID()
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
session := &AccessSession{
|
||||||
|
SessionID: sessionID,
|
||||||
|
ResourceID: resourceID,
|
||||||
|
SourceAddr: srcAddr,
|
||||||
|
DestAddr: dstAddr,
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now,
|
||||||
|
}
|
||||||
|
|
||||||
|
al.mu.Lock()
|
||||||
|
al.sessions[sessionID] = session
|
||||||
|
al.mu.Unlock()
|
||||||
|
|
||||||
|
logger.Info("ACCESS START session=%s resource=%d proto=tcp src=%s dst=%s time=%s",
|
||||||
|
sessionID, resourceID, srcAddr, dstAddr, now.Format(time.RFC3339))
|
||||||
|
|
||||||
|
return sessionID
|
||||||
|
}
|
||||||
|
|
||||||
|
// EndTCPSession logs the end of a TCP session and queues it for sending.
|
||||||
|
func (al *AccessLogger) EndTCPSession(sessionID string) {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
al.mu.Lock()
|
||||||
|
session, ok := al.sessions[sessionID]
|
||||||
|
if ok {
|
||||||
|
session.EndedAt = now
|
||||||
|
delete(al.sessions, sessionID)
|
||||||
|
al.completedSessions = append(al.completedSessions, session)
|
||||||
|
}
|
||||||
|
shouldFlush := len(al.completedSessions) >= maxBufferedSessions
|
||||||
|
al.mu.Unlock()
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
duration := now.Sub(session.StartedAt)
|
||||||
|
logger.Info("ACCESS END session=%s resource=%d proto=tcp src=%s dst=%s started=%s ended=%s duration=%s",
|
||||||
|
sessionID, session.ResourceID, session.SourceAddr, session.DestAddr,
|
||||||
|
session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration)
|
||||||
|
}
|
||||||
|
|
||||||
|
if shouldFlush {
|
||||||
|
al.flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackUDPSession starts or returns an existing UDP session. Returns the session ID.
|
||||||
|
func (al *AccessLogger) TrackUDPSession(resourceID int, srcAddr, dstAddr string) string {
|
||||||
|
key := udpSessionKey{
|
||||||
|
srcAddr: srcAddr,
|
||||||
|
dstAddr: dstAddr,
|
||||||
|
protocol: "udp",
|
||||||
|
}
|
||||||
|
|
||||||
|
al.mu.Lock()
|
||||||
|
defer al.mu.Unlock()
|
||||||
|
|
||||||
|
if existing, ok := al.udpSessions[key]; ok {
|
||||||
|
return existing.SessionID
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionID := generateSessionID()
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
session := &AccessSession{
|
||||||
|
SessionID: sessionID,
|
||||||
|
ResourceID: resourceID,
|
||||||
|
SourceAddr: srcAddr,
|
||||||
|
DestAddr: dstAddr,
|
||||||
|
Protocol: "udp",
|
||||||
|
StartedAt: now,
|
||||||
|
}
|
||||||
|
|
||||||
|
al.sessions[sessionID] = session
|
||||||
|
al.udpSessions[key] = session
|
||||||
|
|
||||||
|
logger.Info("ACCESS START session=%s resource=%d proto=udp src=%s dst=%s time=%s",
|
||||||
|
sessionID, resourceID, srcAddr, dstAddr, now.Format(time.RFC3339))
|
||||||
|
|
||||||
|
return sessionID
|
||||||
|
}
|
||||||
|
|
||||||
|
// EndUDPSession ends a UDP session and queues it for sending.
|
||||||
|
func (al *AccessLogger) EndUDPSession(sessionID string) {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
al.mu.Lock()
|
||||||
|
session, ok := al.sessions[sessionID]
|
||||||
|
if ok {
|
||||||
|
session.EndedAt = now
|
||||||
|
delete(al.sessions, sessionID)
|
||||||
|
key := udpSessionKey{
|
||||||
|
srcAddr: session.SourceAddr,
|
||||||
|
dstAddr: session.DestAddr,
|
||||||
|
protocol: "udp",
|
||||||
|
}
|
||||||
|
delete(al.udpSessions, key)
|
||||||
|
al.completedSessions = append(al.completedSessions, session)
|
||||||
|
}
|
||||||
|
shouldFlush := len(al.completedSessions) >= maxBufferedSessions
|
||||||
|
al.mu.Unlock()
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
duration := now.Sub(session.StartedAt)
|
||||||
|
logger.Info("ACCESS END session=%s resource=%d proto=udp src=%s dst=%s started=%s ended=%s duration=%s",
|
||||||
|
sessionID, session.ResourceID, session.SourceAddr, session.DestAddr,
|
||||||
|
session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration)
|
||||||
|
}
|
||||||
|
|
||||||
|
if shouldFlush {
|
||||||
|
al.flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// backgroundLoop handles periodic flushing and stale session reaping.
|
||||||
|
func (al *AccessLogger) backgroundLoop() {
|
||||||
|
defer close(al.flushDone)
|
||||||
|
|
||||||
|
flushTicker := time.NewTicker(flushInterval)
|
||||||
|
defer flushTicker.Stop()
|
||||||
|
|
||||||
|
reapTicker := time.NewTicker(30 * time.Second)
|
||||||
|
defer reapTicker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-al.stopCh:
|
||||||
|
return
|
||||||
|
case <-flushTicker.C:
|
||||||
|
al.flush()
|
||||||
|
case <-reapTicker.C:
|
||||||
|
al.reapStaleSessions()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// reapStaleSessions cleans up UDP sessions that were not properly ended.
|
||||||
|
func (al *AccessLogger) reapStaleSessions() {
|
||||||
|
al.mu.Lock()
|
||||||
|
defer al.mu.Unlock()
|
||||||
|
|
||||||
|
staleThreshold := time.Now().Add(-5 * time.Minute)
|
||||||
|
|
||||||
|
for key, session := range al.udpSessions {
|
||||||
|
if session.StartedAt.Before(staleThreshold) && session.EndedAt.IsZero() {
|
||||||
|
now := time.Now()
|
||||||
|
session.EndedAt = now
|
||||||
|
duration := now.Sub(session.StartedAt)
|
||||||
|
logger.Info("ACCESS END (reaped) session=%s resource=%d proto=udp src=%s dst=%s started=%s ended=%s duration=%s",
|
||||||
|
session.SessionID, session.ResourceID, session.SourceAddr, session.DestAddr,
|
||||||
|
session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration)
|
||||||
|
al.completedSessions = append(al.completedSessions, session)
|
||||||
|
delete(al.sessions, session.SessionID)
|
||||||
|
delete(al.udpSessions, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractIP strips the port from an address string and returns just the IP.
|
||||||
|
// If the address has no port component it is returned as-is.
|
||||||
|
func extractIP(addr string) string {
|
||||||
|
host, _, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
// Might already be a bare IP
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|
||||||
|
// consolidateSessions takes a slice of completed sessions and merges bursts of
|
||||||
|
// short-lived connections from the same source IP to the same destination into
|
||||||
|
// single higher-level session entries.
|
||||||
|
//
|
||||||
|
// The algorithm:
|
||||||
|
// 1. Group sessions by (sourceIP, destAddr, protocol, resourceID).
|
||||||
|
// 2. Within each group, sort by StartedAt.
|
||||||
|
// 3. Walk through the sorted list and merge consecutive sessions whose gap
|
||||||
|
// (previous EndedAt → next StartedAt) is ≤ sessionGapThreshold.
|
||||||
|
// 4. For merged sessions the earliest StartedAt and latest EndedAt are kept,
|
||||||
|
// bytes are summed, and ConnectionCount records how many raw connections
|
||||||
|
// were folded in. If the merged connections used more than one source port,
|
||||||
|
// SourceAddr is set to just the IP (port omitted).
|
||||||
|
// 5. Groups with fewer than minConnectionsToConsolidate members are passed
|
||||||
|
// through unmodified.
|
||||||
|
func consolidateSessions(sessions []*AccessSession) []*AccessSession {
|
||||||
|
if len(sessions) <= 1 {
|
||||||
|
return sessions
|
||||||
|
}
|
||||||
|
|
||||||
|
// Group sessions by consolidation key
|
||||||
|
groups := make(map[consolidationKey][]*AccessSession)
|
||||||
|
for _, s := range sessions {
|
||||||
|
key := consolidationKey{
|
||||||
|
sourceIP: extractIP(s.SourceAddr),
|
||||||
|
destAddr: s.DestAddr,
|
||||||
|
protocol: s.Protocol,
|
||||||
|
resourceID: s.ResourceID,
|
||||||
|
}
|
||||||
|
groups[key] = append(groups[key], s)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make([]*AccessSession, 0, len(sessions))
|
||||||
|
|
||||||
|
for key, group := range groups {
|
||||||
|
// Small groups don't need consolidation
|
||||||
|
if len(group) < minConnectionsToConsolidate {
|
||||||
|
result = append(result, group...)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort the group by start time so we can detect gaps
|
||||||
|
sort.Slice(group, func(i, j int) bool {
|
||||||
|
return group[i].StartedAt.Before(group[j].StartedAt)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Walk through and merge runs that are within the gap threshold
|
||||||
|
var merged []*AccessSession
|
||||||
|
cur := cloneSession(group[0])
|
||||||
|
cur.ConnectionCount = 1
|
||||||
|
sourcePorts := make(map[string]struct{})
|
||||||
|
sourcePorts[cur.SourceAddr] = struct{}{}
|
||||||
|
|
||||||
|
for i := 1; i < len(group); i++ {
|
||||||
|
s := group[i]
|
||||||
|
|
||||||
|
// Determine the gap: from the latest end time we've seen so far to the
|
||||||
|
// start of the next connection.
|
||||||
|
gapRef := cur.EndedAt
|
||||||
|
if gapRef.IsZero() {
|
||||||
|
gapRef = cur.StartedAt
|
||||||
|
}
|
||||||
|
gap := s.StartedAt.Sub(gapRef)
|
||||||
|
|
||||||
|
if gap <= sessionGapThreshold {
|
||||||
|
// Merge into the current consolidated session
|
||||||
|
cur.ConnectionCount++
|
||||||
|
cur.BytesTx += s.BytesTx
|
||||||
|
cur.BytesRx += s.BytesRx
|
||||||
|
sourcePorts[s.SourceAddr] = struct{}{}
|
||||||
|
|
||||||
|
// Extend EndedAt to the latest time
|
||||||
|
endTime := s.EndedAt
|
||||||
|
if endTime.IsZero() {
|
||||||
|
endTime = s.StartedAt
|
||||||
|
}
|
||||||
|
if endTime.After(cur.EndedAt) {
|
||||||
|
cur.EndedAt = endTime
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Gap exceeded — finalize the current session and start a new one
|
||||||
|
finalizeMergedSourceAddr(cur, key.sourceIP, sourcePorts)
|
||||||
|
merged = append(merged, cur)
|
||||||
|
|
||||||
|
cur = cloneSession(s)
|
||||||
|
cur.ConnectionCount = 1
|
||||||
|
sourcePorts = make(map[string]struct{})
|
||||||
|
sourcePorts[s.SourceAddr] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finalize the last accumulated session
|
||||||
|
finalizeMergedSourceAddr(cur, key.sourceIP, sourcePorts)
|
||||||
|
merged = append(merged, cur)
|
||||||
|
|
||||||
|
result = append(result, merged...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloneSession creates a shallow copy of an AccessSession.
|
||||||
|
func cloneSession(s *AccessSession) *AccessSession {
|
||||||
|
cp := *s
|
||||||
|
return &cp
|
||||||
|
}
|
||||||
|
|
||||||
|
// finalizeMergedSourceAddr sets the SourceAddr on a consolidated session.
|
||||||
|
// If multiple distinct source addresses (ports) were seen, the port is
|
||||||
|
// stripped and only the IP is kept so the log isn't misleading.
|
||||||
|
func finalizeMergedSourceAddr(s *AccessSession, sourceIP string, ports map[string]struct{}) {
|
||||||
|
if len(ports) > 1 {
|
||||||
|
// Multiple source ports — just report the IP
|
||||||
|
s.SourceAddr = sourceIP
|
||||||
|
}
|
||||||
|
// Otherwise keep the original SourceAddr which already has ip:port
|
||||||
|
}
|
||||||
|
|
||||||
|
// flush drains the completed sessions buffer, consolidates bursts of
|
||||||
|
// short-lived connections, compresses with zlib, and sends via the SendFunc.
|
||||||
|
func (al *AccessLogger) flush() {
|
||||||
|
al.mu.Lock()
|
||||||
|
if len(al.completedSessions) == 0 {
|
||||||
|
al.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
batch := al.completedSessions
|
||||||
|
al.completedSessions = make([]*AccessSession, 0)
|
||||||
|
sendFn := al.sendFn
|
||||||
|
al.mu.Unlock()
|
||||||
|
|
||||||
|
if sendFn == nil {
|
||||||
|
logger.Debug("Access logger: no send function configured, discarding %d sessions", len(batch))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Consolidate bursts of short-lived connections into higher-level sessions
|
||||||
|
originalCount := len(batch)
|
||||||
|
batch = consolidateSessions(batch)
|
||||||
|
if len(batch) != originalCount {
|
||||||
|
logger.Info("Access logger: consolidated %d raw connections into %d sessions", originalCount, len(batch))
|
||||||
|
}
|
||||||
|
|
||||||
|
compressed, err := compressSessions(batch)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Access logger: failed to compress %d sessions: %v", len(batch), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := sendFn(compressed); err != nil {
|
||||||
|
logger.Error("Access logger: failed to send %d sessions: %v", len(batch), err)
|
||||||
|
// Re-queue the batch so we don't lose data
|
||||||
|
al.mu.Lock()
|
||||||
|
al.completedSessions = append(batch, al.completedSessions...)
|
||||||
|
// Cap re-queued data to prevent unbounded growth if server is unreachable
|
||||||
|
if len(al.completedSessions) > maxBufferedSessions*5 {
|
||||||
|
dropped := len(al.completedSessions) - maxBufferedSessions*5
|
||||||
|
al.completedSessions = al.completedSessions[:maxBufferedSessions*5]
|
||||||
|
logger.Warn("Access logger: buffer overflow, dropped %d oldest sessions", dropped)
|
||||||
|
}
|
||||||
|
al.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Access logger: sent %d sessions to server", len(batch))
|
||||||
|
}
|
||||||
|
|
||||||
|
// compressSessions JSON-encodes the sessions, compresses with zlib, and returns
|
||||||
|
// a base64-encoded string suitable for embedding in a JSON message.
|
||||||
|
func compressSessions(sessions []*AccessSession) (string, error) {
|
||||||
|
jsonData, err := json.Marshal(sessions)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
w, err := zlib.NewWriterLevel(&buf, zlib.BestCompression)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if _, err := w.Write(jsonData); err != nil {
|
||||||
|
w.Close()
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := w.Close(); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return base64.StdEncoding.EncodeToString(buf.Bytes()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close shuts down the background loop, ends all active sessions,
|
||||||
|
// and performs one final flush to send everything to the server.
|
||||||
|
func (al *AccessLogger) Close() {
|
||||||
|
// Signal the background loop to stop
|
||||||
|
select {
|
||||||
|
case <-al.stopCh:
|
||||||
|
// Already closed
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
close(al.stopCh)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the background loop to exit so we don't race on flush
|
||||||
|
<-al.flushDone
|
||||||
|
|
||||||
|
al.mu.Lock()
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
// End all active sessions and move them to the completed buffer
|
||||||
|
for _, session := range al.sessions {
|
||||||
|
if session.EndedAt.IsZero() {
|
||||||
|
session.EndedAt = now
|
||||||
|
duration := now.Sub(session.StartedAt)
|
||||||
|
logger.Info("ACCESS END (shutdown) session=%s resource=%d proto=%s src=%s dst=%s started=%s ended=%s duration=%s",
|
||||||
|
session.SessionID, session.ResourceID, session.Protocol, session.SourceAddr, session.DestAddr,
|
||||||
|
session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration)
|
||||||
|
al.completedSessions = append(al.completedSessions, session)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
al.sessions = make(map[string]*AccessSession)
|
||||||
|
al.udpSessions = make(map[udpSessionKey]*AccessSession)
|
||||||
|
al.mu.Unlock()
|
||||||
|
|
||||||
|
// Final flush to send all remaining sessions to the server
|
||||||
|
al.flush()
|
||||||
|
}
|
||||||
811
netstack2/access_log_test.go
Normal file
811
netstack2/access_log_test.go
Normal file
@@ -0,0 +1,811 @@
|
|||||||
|
package netstack2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExtractIP(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
addr string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"ipv4 with port", "192.168.1.1:12345", "192.168.1.1"},
|
||||||
|
{"ipv4 without port", "192.168.1.1", "192.168.1.1"},
|
||||||
|
{"ipv6 with port", "[::1]:12345", "::1"},
|
||||||
|
{"ipv6 without port", "::1", "::1"},
|
||||||
|
{"empty string", "", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := extractIP(tt.addr)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("extractIP(%q) = %q, want %q", tt.addr, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConsolidateSessions_Empty(t *testing.T) {
|
||||||
|
result := consolidateSessions(nil)
|
||||||
|
if result != nil {
|
||||||
|
t.Errorf("expected nil, got %v", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
result = consolidateSessions([]*AccessSession{})
|
||||||
|
if len(result) != 0 {
|
||||||
|
t.Errorf("expected empty slice, got %d items", len(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConsolidateSessions_SingleSession(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
sessions := []*AccessSession{
|
||||||
|
{
|
||||||
|
SessionID: "abc123",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5000",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now,
|
||||||
|
EndedAt: now.Add(1 * time.Second),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := consolidateSessions(sessions)
|
||||||
|
if len(result) != 1 {
|
||||||
|
t.Fatalf("expected 1 session, got %d", len(result))
|
||||||
|
}
|
||||||
|
if result[0].SourceAddr != "10.0.0.1:5000" {
|
||||||
|
t.Errorf("expected source addr preserved, got %q", result[0].SourceAddr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConsolidateSessions_MergesBurstFromSameSourceIP(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
sessions := []*AccessSession{
|
||||||
|
{
|
||||||
|
SessionID: "s1",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5000",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now,
|
||||||
|
EndedAt: now.Add(100 * time.Millisecond),
|
||||||
|
BytesTx: 100,
|
||||||
|
BytesRx: 200,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SessionID: "s2",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5001",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now.Add(200 * time.Millisecond),
|
||||||
|
EndedAt: now.Add(300 * time.Millisecond),
|
||||||
|
BytesTx: 150,
|
||||||
|
BytesRx: 250,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SessionID: "s3",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5002",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now.Add(400 * time.Millisecond),
|
||||||
|
EndedAt: now.Add(500 * time.Millisecond),
|
||||||
|
BytesTx: 50,
|
||||||
|
BytesRx: 75,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := consolidateSessions(sessions)
|
||||||
|
if len(result) != 1 {
|
||||||
|
t.Fatalf("expected 1 consolidated session, got %d", len(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
s := result[0]
|
||||||
|
if s.ConnectionCount != 3 {
|
||||||
|
t.Errorf("expected ConnectionCount=3, got %d", s.ConnectionCount)
|
||||||
|
}
|
||||||
|
if s.SourceAddr != "10.0.0.1" {
|
||||||
|
t.Errorf("expected source addr to be IP only (multiple ports), got %q", s.SourceAddr)
|
||||||
|
}
|
||||||
|
if s.DestAddr != "192.168.1.100:443" {
|
||||||
|
t.Errorf("expected dest addr preserved, got %q", s.DestAddr)
|
||||||
|
}
|
||||||
|
if s.StartedAt != now {
|
||||||
|
t.Errorf("expected StartedAt to be earliest time")
|
||||||
|
}
|
||||||
|
if s.EndedAt != now.Add(500*time.Millisecond) {
|
||||||
|
t.Errorf("expected EndedAt to be latest time")
|
||||||
|
}
|
||||||
|
expectedTx := int64(300)
|
||||||
|
expectedRx := int64(525)
|
||||||
|
if s.BytesTx != expectedTx {
|
||||||
|
t.Errorf("expected BytesTx=%d, got %d", expectedTx, s.BytesTx)
|
||||||
|
}
|
||||||
|
if s.BytesRx != expectedRx {
|
||||||
|
t.Errorf("expected BytesRx=%d, got %d", expectedRx, s.BytesRx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConsolidateSessions_SameSourcePortPreserved(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
sessions := []*AccessSession{
|
||||||
|
{
|
||||||
|
SessionID: "s1",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5000",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now,
|
||||||
|
EndedAt: now.Add(100 * time.Millisecond),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SessionID: "s2",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5000",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now.Add(200 * time.Millisecond),
|
||||||
|
EndedAt: now.Add(300 * time.Millisecond),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := consolidateSessions(sessions)
|
||||||
|
if len(result) != 1 {
|
||||||
|
t.Fatalf("expected 1 session, got %d", len(result))
|
||||||
|
}
|
||||||
|
if result[0].SourceAddr != "10.0.0.1:5000" {
|
||||||
|
t.Errorf("expected source addr with port preserved when all ports are the same, got %q", result[0].SourceAddr)
|
||||||
|
}
|
||||||
|
if result[0].ConnectionCount != 2 {
|
||||||
|
t.Errorf("expected ConnectionCount=2, got %d", result[0].ConnectionCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConsolidateSessions_GapSplitsSessions(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
// First burst
|
||||||
|
sessions := []*AccessSession{
|
||||||
|
{
|
||||||
|
SessionID: "s1",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5000",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now,
|
||||||
|
EndedAt: now.Add(100 * time.Millisecond),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SessionID: "s2",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5001",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now.Add(200 * time.Millisecond),
|
||||||
|
EndedAt: now.Add(300 * time.Millisecond),
|
||||||
|
},
|
||||||
|
// Big gap here (10 seconds)
|
||||||
|
{
|
||||||
|
SessionID: "s3",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5002",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now.Add(10 * time.Second),
|
||||||
|
EndedAt: now.Add(10*time.Second + 100*time.Millisecond),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SessionID: "s4",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5003",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now.Add(10*time.Second + 200*time.Millisecond),
|
||||||
|
EndedAt: now.Add(10*time.Second + 300*time.Millisecond),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := consolidateSessions(sessions)
|
||||||
|
if len(result) != 2 {
|
||||||
|
t.Fatalf("expected 2 consolidated sessions (gap split), got %d", len(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the sessions by their start time
|
||||||
|
var first, second *AccessSession
|
||||||
|
for _, s := range result {
|
||||||
|
if s.StartedAt.Equal(now) {
|
||||||
|
first = s
|
||||||
|
} else {
|
||||||
|
second = s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if first == nil || second == nil {
|
||||||
|
t.Fatal("could not find both consolidated sessions")
|
||||||
|
}
|
||||||
|
|
||||||
|
if first.ConnectionCount != 2 {
|
||||||
|
t.Errorf("first burst: expected ConnectionCount=2, got %d", first.ConnectionCount)
|
||||||
|
}
|
||||||
|
if second.ConnectionCount != 2 {
|
||||||
|
t.Errorf("second burst: expected ConnectionCount=2, got %d", second.ConnectionCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConsolidateSessions_DifferentDestinationsNotMerged(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
sessions := []*AccessSession{
|
||||||
|
{
|
||||||
|
SessionID: "s1",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5000",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now,
|
||||||
|
EndedAt: now.Add(100 * time.Millisecond),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SessionID: "s2",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5001",
|
||||||
|
DestAddr: "192.168.1.100:8080",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now.Add(200 * time.Millisecond),
|
||||||
|
EndedAt: now.Add(300 * time.Millisecond),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := consolidateSessions(sessions)
|
||||||
|
// Each goes to a different dest port so they should not be merged
|
||||||
|
if len(result) != 2 {
|
||||||
|
t.Fatalf("expected 2 sessions (different destinations), got %d", len(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConsolidateSessions_DifferentProtocolsNotMerged(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
sessions := []*AccessSession{
|
||||||
|
{
|
||||||
|
SessionID: "s1",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5000",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now,
|
||||||
|
EndedAt: now.Add(100 * time.Millisecond),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SessionID: "s2",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5001",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "udp",
|
||||||
|
StartedAt: now.Add(200 * time.Millisecond),
|
||||||
|
EndedAt: now.Add(300 * time.Millisecond),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := consolidateSessions(sessions)
|
||||||
|
if len(result) != 2 {
|
||||||
|
t.Fatalf("expected 2 sessions (different protocols), got %d", len(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConsolidateSessions_DifferentResourceIDsNotMerged(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
sessions := []*AccessSession{
|
||||||
|
{
|
||||||
|
SessionID: "s1",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5000",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now,
|
||||||
|
EndedAt: now.Add(100 * time.Millisecond),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SessionID: "s2",
|
||||||
|
ResourceID: 2,
|
||||||
|
SourceAddr: "10.0.0.1:5001",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now.Add(200 * time.Millisecond),
|
||||||
|
EndedAt: now.Add(300 * time.Millisecond),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := consolidateSessions(sessions)
|
||||||
|
if len(result) != 2 {
|
||||||
|
t.Fatalf("expected 2 sessions (different resource IDs), got %d", len(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConsolidateSessions_DifferentSourceIPsNotMerged(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
sessions := []*AccessSession{
|
||||||
|
{
|
||||||
|
SessionID: "s1",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5000",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now,
|
||||||
|
EndedAt: now.Add(100 * time.Millisecond),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SessionID: "s2",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.2:5001",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now.Add(200 * time.Millisecond),
|
||||||
|
EndedAt: now.Add(300 * time.Millisecond),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := consolidateSessions(sessions)
|
||||||
|
if len(result) != 2 {
|
||||||
|
t.Fatalf("expected 2 sessions (different source IPs), got %d", len(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConsolidateSessions_OutOfOrderInput(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
// Provide sessions out of chronological order to verify sorting
|
||||||
|
sessions := []*AccessSession{
|
||||||
|
{
|
||||||
|
SessionID: "s3",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5002",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now.Add(400 * time.Millisecond),
|
||||||
|
EndedAt: now.Add(500 * time.Millisecond),
|
||||||
|
BytesTx: 30,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SessionID: "s1",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5000",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now,
|
||||||
|
EndedAt: now.Add(100 * time.Millisecond),
|
||||||
|
BytesTx: 10,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SessionID: "s2",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5001",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now.Add(200 * time.Millisecond),
|
||||||
|
EndedAt: now.Add(300 * time.Millisecond),
|
||||||
|
BytesTx: 20,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := consolidateSessions(sessions)
|
||||||
|
if len(result) != 1 {
|
||||||
|
t.Fatalf("expected 1 consolidated session, got %d", len(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
s := result[0]
|
||||||
|
if s.ConnectionCount != 3 {
|
||||||
|
t.Errorf("expected ConnectionCount=3, got %d", s.ConnectionCount)
|
||||||
|
}
|
||||||
|
if s.StartedAt != now {
|
||||||
|
t.Errorf("expected StartedAt to be earliest time")
|
||||||
|
}
|
||||||
|
if s.EndedAt != now.Add(500*time.Millisecond) {
|
||||||
|
t.Errorf("expected EndedAt to be latest time")
|
||||||
|
}
|
||||||
|
if s.BytesTx != 60 {
|
||||||
|
t.Errorf("expected BytesTx=60, got %d", s.BytesTx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConsolidateSessions_ExactlyAtGapThreshold(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
sessions := []*AccessSession{
|
||||||
|
{
|
||||||
|
SessionID: "s1",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5000",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now,
|
||||||
|
EndedAt: now.Add(100 * time.Millisecond),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Starts exactly sessionGapThreshold after s1 ends — should still merge
|
||||||
|
SessionID: "s2",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5001",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now.Add(100*time.Millisecond + sessionGapThreshold),
|
||||||
|
EndedAt: now.Add(100*time.Millisecond + sessionGapThreshold + 50*time.Millisecond),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := consolidateSessions(sessions)
|
||||||
|
if len(result) != 1 {
|
||||||
|
t.Fatalf("expected 1 session (gap exactly at threshold merges), got %d", len(result))
|
||||||
|
}
|
||||||
|
if result[0].ConnectionCount != 2 {
|
||||||
|
t.Errorf("expected ConnectionCount=2, got %d", result[0].ConnectionCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConsolidateSessions_JustOverGapThreshold(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
sessions := []*AccessSession{
|
||||||
|
{
|
||||||
|
SessionID: "s1",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5000",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now,
|
||||||
|
EndedAt: now.Add(100 * time.Millisecond),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Starts 1ms over the gap threshold after s1 ends — should split
|
||||||
|
SessionID: "s2",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5001",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now.Add(100*time.Millisecond + sessionGapThreshold + 1*time.Millisecond),
|
||||||
|
EndedAt: now.Add(100*time.Millisecond + sessionGapThreshold + 50*time.Millisecond),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := consolidateSessions(sessions)
|
||||||
|
if len(result) != 2 {
|
||||||
|
t.Fatalf("expected 2 sessions (gap just over threshold splits), got %d", len(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConsolidateSessions_UDPSessions(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
sessions := []*AccessSession{
|
||||||
|
{
|
||||||
|
SessionID: "u1",
|
||||||
|
ResourceID: 5,
|
||||||
|
SourceAddr: "10.0.0.1:6000",
|
||||||
|
DestAddr: "192.168.1.100:53",
|
||||||
|
Protocol: "udp",
|
||||||
|
StartedAt: now,
|
||||||
|
EndedAt: now.Add(50 * time.Millisecond),
|
||||||
|
BytesTx: 64,
|
||||||
|
BytesRx: 512,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SessionID: "u2",
|
||||||
|
ResourceID: 5,
|
||||||
|
SourceAddr: "10.0.0.1:6001",
|
||||||
|
DestAddr: "192.168.1.100:53",
|
||||||
|
Protocol: "udp",
|
||||||
|
StartedAt: now.Add(100 * time.Millisecond),
|
||||||
|
EndedAt: now.Add(150 * time.Millisecond),
|
||||||
|
BytesTx: 64,
|
||||||
|
BytesRx: 256,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SessionID: "u3",
|
||||||
|
ResourceID: 5,
|
||||||
|
SourceAddr: "10.0.0.1:6002",
|
||||||
|
DestAddr: "192.168.1.100:53",
|
||||||
|
Protocol: "udp",
|
||||||
|
StartedAt: now.Add(200 * time.Millisecond),
|
||||||
|
EndedAt: now.Add(250 * time.Millisecond),
|
||||||
|
BytesTx: 64,
|
||||||
|
BytesRx: 128,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := consolidateSessions(sessions)
|
||||||
|
if len(result) != 1 {
|
||||||
|
t.Fatalf("expected 1 consolidated UDP session, got %d", len(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
s := result[0]
|
||||||
|
if s.Protocol != "udp" {
|
||||||
|
t.Errorf("expected protocol=udp, got %q", s.Protocol)
|
||||||
|
}
|
||||||
|
if s.ConnectionCount != 3 {
|
||||||
|
t.Errorf("expected ConnectionCount=3, got %d", s.ConnectionCount)
|
||||||
|
}
|
||||||
|
if s.SourceAddr != "10.0.0.1" {
|
||||||
|
t.Errorf("expected source addr to be IP only, got %q", s.SourceAddr)
|
||||||
|
}
|
||||||
|
if s.BytesTx != 192 {
|
||||||
|
t.Errorf("expected BytesTx=192, got %d", s.BytesTx)
|
||||||
|
}
|
||||||
|
if s.BytesRx != 896 {
|
||||||
|
t.Errorf("expected BytesRx=896, got %d", s.BytesRx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConsolidateSessions_MixedGroupsSomeConsolidatedSomeNot(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
sessions := []*AccessSession{
|
||||||
|
// Group 1: 3 connections to :443 from same IP — should consolidate
|
||||||
|
{
|
||||||
|
SessionID: "s1",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5000",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now,
|
||||||
|
EndedAt: now.Add(100 * time.Millisecond),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SessionID: "s2",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5001",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now.Add(200 * time.Millisecond),
|
||||||
|
EndedAt: now.Add(300 * time.Millisecond),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SessionID: "s3",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5002",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now.Add(400 * time.Millisecond),
|
||||||
|
EndedAt: now.Add(500 * time.Millisecond),
|
||||||
|
},
|
||||||
|
// Group 2: 1 connection to :8080 from different IP — should pass through
|
||||||
|
{
|
||||||
|
SessionID: "s4",
|
||||||
|
ResourceID: 2,
|
||||||
|
SourceAddr: "10.0.0.2:6000",
|
||||||
|
DestAddr: "192.168.1.200:8080",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now.Add(1 * time.Second),
|
||||||
|
EndedAt: now.Add(2 * time.Second),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := consolidateSessions(sessions)
|
||||||
|
if len(result) != 2 {
|
||||||
|
t.Fatalf("expected 2 sessions total, got %d", len(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
var consolidated, passthrough *AccessSession
|
||||||
|
for _, s := range result {
|
||||||
|
if s.ConnectionCount > 1 {
|
||||||
|
consolidated = s
|
||||||
|
} else {
|
||||||
|
passthrough = s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if consolidated == nil {
|
||||||
|
t.Fatal("expected a consolidated session")
|
||||||
|
}
|
||||||
|
if consolidated.ConnectionCount != 3 {
|
||||||
|
t.Errorf("consolidated: expected ConnectionCount=3, got %d", consolidated.ConnectionCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
if passthrough == nil {
|
||||||
|
t.Fatal("expected a passthrough session")
|
||||||
|
}
|
||||||
|
if passthrough.SessionID != "s4" {
|
||||||
|
t.Errorf("passthrough: expected session s4, got %s", passthrough.SessionID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConsolidateSessions_OverlappingConnections(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
// Connections that overlap in time (not sequential)
|
||||||
|
sessions := []*AccessSession{
|
||||||
|
{
|
||||||
|
SessionID: "s1",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5000",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now,
|
||||||
|
EndedAt: now.Add(5 * time.Second),
|
||||||
|
BytesTx: 100,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SessionID: "s2",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5001",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now.Add(1 * time.Second),
|
||||||
|
EndedAt: now.Add(3 * time.Second),
|
||||||
|
BytesTx: 200,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SessionID: "s3",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5002",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now.Add(2 * time.Second),
|
||||||
|
EndedAt: now.Add(6 * time.Second),
|
||||||
|
BytesTx: 300,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := consolidateSessions(sessions)
|
||||||
|
if len(result) != 1 {
|
||||||
|
t.Fatalf("expected 1 consolidated session, got %d", len(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
s := result[0]
|
||||||
|
if s.ConnectionCount != 3 {
|
||||||
|
t.Errorf("expected ConnectionCount=3, got %d", s.ConnectionCount)
|
||||||
|
}
|
||||||
|
if s.StartedAt != now {
|
||||||
|
t.Error("expected StartedAt to be earliest")
|
||||||
|
}
|
||||||
|
if s.EndedAt != now.Add(6*time.Second) {
|
||||||
|
t.Error("expected EndedAt to be the latest end time")
|
||||||
|
}
|
||||||
|
if s.BytesTx != 600 {
|
||||||
|
t.Errorf("expected BytesTx=600, got %d", s.BytesTx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConsolidateSessions_DoesNotMutateOriginals(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
s1 := &AccessSession{
|
||||||
|
SessionID: "s1",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5000",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now,
|
||||||
|
EndedAt: now.Add(100 * time.Millisecond),
|
||||||
|
BytesTx: 100,
|
||||||
|
}
|
||||||
|
s2 := &AccessSession{
|
||||||
|
SessionID: "s2",
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:5001",
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now.Add(200 * time.Millisecond),
|
||||||
|
EndedAt: now.Add(300 * time.Millisecond),
|
||||||
|
BytesTx: 200,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save original values
|
||||||
|
origS1Addr := s1.SourceAddr
|
||||||
|
origS1Bytes := s1.BytesTx
|
||||||
|
origS2Addr := s2.SourceAddr
|
||||||
|
origS2Bytes := s2.BytesTx
|
||||||
|
|
||||||
|
_ = consolidateSessions([]*AccessSession{s1, s2})
|
||||||
|
|
||||||
|
if s1.SourceAddr != origS1Addr {
|
||||||
|
t.Errorf("s1.SourceAddr was mutated: %q -> %q", origS1Addr, s1.SourceAddr)
|
||||||
|
}
|
||||||
|
if s1.BytesTx != origS1Bytes {
|
||||||
|
t.Errorf("s1.BytesTx was mutated: %d -> %d", origS1Bytes, s1.BytesTx)
|
||||||
|
}
|
||||||
|
if s2.SourceAddr != origS2Addr {
|
||||||
|
t.Errorf("s2.SourceAddr was mutated: %q -> %q", origS2Addr, s2.SourceAddr)
|
||||||
|
}
|
||||||
|
if s2.BytesTx != origS2Bytes {
|
||||||
|
t.Errorf("s2.BytesTx was mutated: %d -> %d", origS2Bytes, s2.BytesTx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConsolidateSessions_ThreeBurstsWithGaps(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
sessions := make([]*AccessSession, 0, 9)
|
||||||
|
|
||||||
|
// Burst 1: 3 connections at t=0
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
sessions = append(sessions, &AccessSession{
|
||||||
|
SessionID: generateSessionID(),
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:" + string(rune('A'+i)),
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now.Add(time.Duration(i*100) * time.Millisecond),
|
||||||
|
EndedAt: now.Add(time.Duration(i*100+50) * time.Millisecond),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Burst 2: 3 connections at t=20s (well past the 5s gap)
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
sessions = append(sessions, &AccessSession{
|
||||||
|
SessionID: generateSessionID(),
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:" + string(rune('D'+i)),
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now.Add(20*time.Second + time.Duration(i*100)*time.Millisecond),
|
||||||
|
EndedAt: now.Add(20*time.Second + time.Duration(i*100+50)*time.Millisecond),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Burst 3: 3 connections at t=40s
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
sessions = append(sessions, &AccessSession{
|
||||||
|
SessionID: generateSessionID(),
|
||||||
|
ResourceID: 1,
|
||||||
|
SourceAddr: "10.0.0.1:" + string(rune('G'+i)),
|
||||||
|
DestAddr: "192.168.1.100:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
StartedAt: now.Add(40*time.Second + time.Duration(i*100)*time.Millisecond),
|
||||||
|
EndedAt: now.Add(40*time.Second + time.Duration(i*100+50)*time.Millisecond),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
result := consolidateSessions(sessions)
|
||||||
|
if len(result) != 3 {
|
||||||
|
t.Fatalf("expected 3 consolidated sessions (3 bursts), got %d", len(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, s := range result {
|
||||||
|
if s.ConnectionCount != 3 {
|
||||||
|
t.Errorf("expected each burst to have ConnectionCount=3, got %d (started=%v)", s.ConnectionCount, s.StartedAt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFinalizeMergedSourceAddr(t *testing.T) {
|
||||||
|
s := &AccessSession{SourceAddr: "10.0.0.1:5000"}
|
||||||
|
ports := map[string]struct{}{"10.0.0.1:5000": {}}
|
||||||
|
finalizeMergedSourceAddr(s, "10.0.0.1", ports)
|
||||||
|
if s.SourceAddr != "10.0.0.1:5000" {
|
||||||
|
t.Errorf("single port: expected addr preserved, got %q", s.SourceAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
s2 := &AccessSession{SourceAddr: "10.0.0.1:5000"}
|
||||||
|
ports2 := map[string]struct{}{"10.0.0.1:5000": {}, "10.0.0.1:5001": {}}
|
||||||
|
finalizeMergedSourceAddr(s2, "10.0.0.1", ports2)
|
||||||
|
if s2.SourceAddr != "10.0.0.1" {
|
||||||
|
t.Errorf("multiple ports: expected IP only, got %q", s2.SourceAddr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCloneSession(t *testing.T) {
|
||||||
|
original := &AccessSession{
|
||||||
|
SessionID: "test",
|
||||||
|
ResourceID: 42,
|
||||||
|
SourceAddr: "1.2.3.4:100",
|
||||||
|
DestAddr: "5.6.7.8:443",
|
||||||
|
Protocol: "tcp",
|
||||||
|
BytesTx: 999,
|
||||||
|
}
|
||||||
|
|
||||||
|
clone := cloneSession(original)
|
||||||
|
|
||||||
|
if clone == original {
|
||||||
|
t.Error("clone should be a different pointer")
|
||||||
|
}
|
||||||
|
if clone.SessionID != original.SessionID {
|
||||||
|
t.Error("clone should have same SessionID")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mutating clone should not affect original
|
||||||
|
clone.BytesTx = 0
|
||||||
|
clone.SourceAddr = "changed"
|
||||||
|
if original.BytesTx != 999 {
|
||||||
|
t.Error("mutating clone affected original BytesTx")
|
||||||
|
}
|
||||||
|
if original.SourceAddr != "1.2.3.4:100" {
|
||||||
|
t.Error("mutating clone affected original SourceAddr")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -158,6 +158,18 @@ func (h *TCPHandler) handleTCPConn(netstackConn *gonet.TCPConn, id stack.Transpo
|
|||||||
|
|
||||||
targetAddr := fmt.Sprintf("%s:%d", actualDstIP, dstPort)
|
targetAddr := fmt.Sprintf("%s:%d", actualDstIP, dstPort)
|
||||||
|
|
||||||
|
// Look up resource ID and start access session if applicable
|
||||||
|
var accessSessionID string
|
||||||
|
if h.proxyHandler != nil {
|
||||||
|
resourceId := h.proxyHandler.LookupResourceId(srcIP, dstIP, dstPort, uint8(tcp.ProtocolNumber))
|
||||||
|
if resourceId != 0 {
|
||||||
|
if al := h.proxyHandler.GetAccessLogger(); al != nil {
|
||||||
|
srcAddr := fmt.Sprintf("%s:%d", srcIP, srcPort)
|
||||||
|
accessSessionID = al.StartTCPSession(resourceId, srcAddr, targetAddr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Create context with timeout for connection establishment
|
// Create context with timeout for connection establishment
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), tcpConnectTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), tcpConnectTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@@ -167,11 +179,26 @@ func (h *TCPHandler) handleTCPConn(netstackConn *gonet.TCPConn, id stack.Transpo
|
|||||||
targetConn, err := d.DialContext(ctx, "tcp", targetAddr)
|
targetConn, err := d.DialContext(ctx, "tcp", targetAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Info("TCP Forwarder: Failed to connect to %s: %v", targetAddr, err)
|
logger.Info("TCP Forwarder: Failed to connect to %s: %v", targetAddr, err)
|
||||||
|
// End access session on connection failure
|
||||||
|
if accessSessionID != "" {
|
||||||
|
if al := h.proxyHandler.GetAccessLogger(); al != nil {
|
||||||
|
al.EndTCPSession(accessSessionID)
|
||||||
|
}
|
||||||
|
}
|
||||||
// Connection failed, netstack will handle RST
|
// Connection failed, netstack will handle RST
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer targetConn.Close()
|
defer targetConn.Close()
|
||||||
|
|
||||||
|
// End access session when connection closes
|
||||||
|
if accessSessionID != "" {
|
||||||
|
defer func() {
|
||||||
|
if al := h.proxyHandler.GetAccessLogger(); al != nil {
|
||||||
|
al.EndTCPSession(accessSessionID)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
logger.Info("TCP Forwarder: Successfully connected to %s, starting bidirectional copy", targetAddr)
|
logger.Info("TCP Forwarder: Successfully connected to %s, starting bidirectional copy", targetAddr)
|
||||||
|
|
||||||
// Bidirectional copy between netstack and target
|
// Bidirectional copy between netstack and target
|
||||||
@@ -280,6 +307,27 @@ func (h *UDPHandler) handleUDPConn(netstackConn *gonet.UDPConn, id stack.Transpo
|
|||||||
|
|
||||||
targetAddr := fmt.Sprintf("%s:%d", actualDstIP, dstPort)
|
targetAddr := fmt.Sprintf("%s:%d", actualDstIP, dstPort)
|
||||||
|
|
||||||
|
// Look up resource ID and start access session if applicable
|
||||||
|
var accessSessionID string
|
||||||
|
if h.proxyHandler != nil {
|
||||||
|
resourceId := h.proxyHandler.LookupResourceId(srcIP, dstIP, dstPort, uint8(udp.ProtocolNumber))
|
||||||
|
if resourceId != 0 {
|
||||||
|
if al := h.proxyHandler.GetAccessLogger(); al != nil {
|
||||||
|
srcAddr := fmt.Sprintf("%s:%d", srcIP, srcPort)
|
||||||
|
accessSessionID = al.TrackUDPSession(resourceId, srcAddr, targetAddr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// End access session when UDP handler returns (timeout or error)
|
||||||
|
if accessSessionID != "" {
|
||||||
|
defer func() {
|
||||||
|
if al := h.proxyHandler.GetAccessLogger(); al != nil {
|
||||||
|
al.EndUDPSession(accessSessionID)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
// Resolve target address
|
// Resolve target address
|
||||||
remoteUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr)
|
remoteUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -22,6 +22,12 @@ import (
|
|||||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// udpAccessSessionTimeout is how long a UDP access session stays alive without traffic
|
||||||
|
// before being considered ended by the access logger
|
||||||
|
udpAccessSessionTimeout = 120 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
// PortRange represents an allowed range of ports (inclusive) with optional protocol filtering
|
// PortRange represents an allowed range of ports (inclusive) with optional protocol filtering
|
||||||
// Protocol can be "tcp", "udp", or "" (empty string means both protocols)
|
// Protocol can be "tcp", "udp", or "" (empty string means both protocols)
|
||||||
type PortRange struct {
|
type PortRange struct {
|
||||||
@@ -46,6 +52,7 @@ type SubnetRule struct {
|
|||||||
DisableIcmp bool // If true, ICMP traffic is blocked for this subnet
|
DisableIcmp bool // If true, ICMP traffic is blocked for this subnet
|
||||||
RewriteTo string // Optional rewrite address for DNAT - can be IP/CIDR or domain name
|
RewriteTo string // Optional rewrite address for DNAT - can be IP/CIDR or domain name
|
||||||
PortRanges []PortRange // empty slice means all ports allowed
|
PortRanges []PortRange // empty slice means all ports allowed
|
||||||
|
ResourceId int // Optional resource ID from the server for access logging
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllRules returns a copy of all subnet rules
|
// GetAllRules returns a copy of all subnet rules
|
||||||
@@ -111,10 +118,12 @@ type ProxyHandler struct {
|
|||||||
natTable map[connKey]*natState
|
natTable map[connKey]*natState
|
||||||
reverseNatTable map[reverseConnKey]*natState // Reverse lookup map for O(1) reply packet NAT
|
reverseNatTable map[reverseConnKey]*natState // Reverse lookup map for O(1) reply packet NAT
|
||||||
destRewriteTable map[destKey]netip.Addr // Maps original dest to rewritten dest for handler lookups
|
destRewriteTable map[destKey]netip.Addr // Maps original dest to rewritten dest for handler lookups
|
||||||
|
resourceTable map[destKey]int // Maps connection key to resource ID for access logging
|
||||||
natMu sync.RWMutex
|
natMu sync.RWMutex
|
||||||
enabled bool
|
enabled bool
|
||||||
icmpReplies chan []byte // Channel for ICMP reply packets to be sent back through the tunnel
|
icmpReplies chan []byte // Channel for ICMP reply packets to be sent back through the tunnel
|
||||||
notifiable channel.Notification // Notification handler for triggering reads
|
notifiable channel.Notification // Notification handler for triggering reads
|
||||||
|
accessLogger *AccessLogger // Access logger for tracking sessions
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProxyHandlerOptions configures the proxy handler
|
// ProxyHandlerOptions configures the proxy handler
|
||||||
@@ -137,7 +146,9 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
|
|||||||
natTable: make(map[connKey]*natState),
|
natTable: make(map[connKey]*natState),
|
||||||
reverseNatTable: make(map[reverseConnKey]*natState),
|
reverseNatTable: make(map[reverseConnKey]*natState),
|
||||||
destRewriteTable: make(map[destKey]netip.Addr),
|
destRewriteTable: make(map[destKey]netip.Addr),
|
||||||
|
resourceTable: make(map[destKey]int),
|
||||||
icmpReplies: make(chan []byte, 256), // Buffer for ICMP reply packets
|
icmpReplies: make(chan []byte, 256), // Buffer for ICMP reply packets
|
||||||
|
accessLogger: NewAccessLogger(udpAccessSessionTimeout),
|
||||||
proxyEp: channel.New(1024, uint32(options.MTU), ""),
|
proxyEp: channel.New(1024, uint32(options.MTU), ""),
|
||||||
proxyStack: stack.New(stack.Options{
|
proxyStack: stack.New(stack.Options{
|
||||||
NetworkProtocols: []stack.NetworkProtocolFactory{
|
NetworkProtocols: []stack.NetworkProtocolFactory{
|
||||||
@@ -202,11 +213,11 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
|
|||||||
// destPrefix: The IP prefix of the destination
|
// destPrefix: The IP prefix of the destination
|
||||||
// rewriteTo: Optional address to rewrite destination to - can be IP/CIDR or domain name
|
// rewriteTo: Optional address to rewrite destination to - can be IP/CIDR or domain name
|
||||||
// If portRanges is nil or empty, all ports are allowed for this subnet
|
// If portRanges is nil or empty, all ports are allowed for this subnet
|
||||||
func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) {
|
func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool, resourceId int) {
|
||||||
if p == nil || !p.enabled {
|
if p == nil || !p.enabled {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
p.subnetLookup.AddSubnet(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp)
|
p.subnetLookup.AddSubnet(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp, resourceId)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveSubnetRule removes a subnet from the proxy handler
|
// RemoveSubnetRule removes a subnet from the proxy handler
|
||||||
@@ -225,6 +236,43 @@ func (p *ProxyHandler) GetAllRules() []SubnetRule {
|
|||||||
return p.subnetLookup.GetAllRules()
|
return p.subnetLookup.GetAllRules()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LookupResourceId looks up the resource ID for a connection
|
||||||
|
// Returns 0 if no resource ID is associated with this connection
|
||||||
|
func (p *ProxyHandler) LookupResourceId(srcIP, dstIP string, dstPort uint16, proto uint8) int {
|
||||||
|
if p == nil || !p.enabled {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
key := destKey{
|
||||||
|
srcIP: srcIP,
|
||||||
|
dstIP: dstIP,
|
||||||
|
dstPort: dstPort,
|
||||||
|
proto: proto,
|
||||||
|
}
|
||||||
|
|
||||||
|
p.natMu.RLock()
|
||||||
|
defer p.natMu.RUnlock()
|
||||||
|
|
||||||
|
return p.resourceTable[key]
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccessLogger returns the access logger for session tracking
|
||||||
|
func (p *ProxyHandler) GetAccessLogger() *AccessLogger {
|
||||||
|
if p == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return p.accessLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAccessLogSender configures the function used to send compressed access log
|
||||||
|
// batches to the server. This should be called once the websocket client is available.
|
||||||
|
func (p *ProxyHandler) SetAccessLogSender(fn SendFunc) {
|
||||||
|
if p == nil || !p.enabled || p.accessLogger == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.accessLogger.SetSendFunc(fn)
|
||||||
|
}
|
||||||
|
|
||||||
// LookupDestinationRewrite looks up the rewritten destination for a connection
|
// LookupDestinationRewrite looks up the rewritten destination for a connection
|
||||||
// This is used by TCP/UDP handlers to find the actual target address
|
// This is used by TCP/UDP handlers to find the actual target address
|
||||||
func (p *ProxyHandler) LookupDestinationRewrite(srcIP, dstIP string, dstPort uint16, proto uint8) (netip.Addr, bool) {
|
func (p *ProxyHandler) LookupDestinationRewrite(srcIP, dstIP string, dstPort uint16, proto uint8) (netip.Addr, bool) {
|
||||||
@@ -387,8 +435,22 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
|
|||||||
// Check if the source IP, destination IP, port, and protocol match any subnet rule
|
// Check if the source IP, destination IP, port, and protocol match any subnet rule
|
||||||
matchedRule := p.subnetLookup.Match(srcAddr, dstAddr, dstPort, protocol)
|
matchedRule := p.subnetLookup.Match(srcAddr, dstAddr, dstPort, protocol)
|
||||||
if matchedRule != nil {
|
if matchedRule != nil {
|
||||||
logger.Debug("HandleIncomingPacket: Matched rule for %s -> %s (proto=%d, port=%d)",
|
logger.Debug("HandleIncomingPacket: Matched rule for %s -> %s (proto=%d, port=%d, resourceId=%d)",
|
||||||
srcAddr, dstAddr, protocol, dstPort)
|
srcAddr, dstAddr, protocol, dstPort, matchedRule.ResourceId)
|
||||||
|
|
||||||
|
// Store resource ID for connections without DNAT as well
|
||||||
|
if matchedRule.ResourceId != 0 && matchedRule.RewriteTo == "" {
|
||||||
|
dKey := destKey{
|
||||||
|
srcIP: srcAddr.String(),
|
||||||
|
dstIP: dstAddr.String(),
|
||||||
|
dstPort: dstPort,
|
||||||
|
proto: uint8(protocol),
|
||||||
|
}
|
||||||
|
p.natMu.Lock()
|
||||||
|
p.resourceTable[dKey] = matchedRule.ResourceId
|
||||||
|
p.natMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
// Check if we need to perform DNAT
|
// Check if we need to perform DNAT
|
||||||
if matchedRule.RewriteTo != "" {
|
if matchedRule.RewriteTo != "" {
|
||||||
// Create connection tracking key using original destination
|
// Create connection tracking key using original destination
|
||||||
@@ -420,6 +482,13 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
|
|||||||
proto: uint8(protocol),
|
proto: uint8(protocol),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Store resource ID for access logging if present
|
||||||
|
if matchedRule.ResourceId != 0 {
|
||||||
|
p.natMu.Lock()
|
||||||
|
p.resourceTable[dKey] = matchedRule.ResourceId
|
||||||
|
p.natMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
// Check if we already have a NAT entry for this connection
|
// Check if we already have a NAT entry for this connection
|
||||||
p.natMu.RLock()
|
p.natMu.RLock()
|
||||||
existingEntry, exists := p.natTable[key]
|
existingEntry, exists := p.natTable[key]
|
||||||
@@ -720,6 +789,11 @@ func (p *ProxyHandler) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Shut down access logger
|
||||||
|
if p.accessLogger != nil {
|
||||||
|
p.accessLogger.Close()
|
||||||
|
}
|
||||||
|
|
||||||
// Close ICMP replies channel
|
// Close ICMP replies channel
|
||||||
if p.icmpReplies != nil {
|
if p.icmpReplies != nil {
|
||||||
close(p.icmpReplies)
|
close(p.icmpReplies)
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ func prefixEqual(a, b netip.Prefix) bool {
|
|||||||
// AddSubnet adds a subnet rule with source and destination prefixes and optional port restrictions
|
// AddSubnet adds a subnet rule with source and destination prefixes and optional port restrictions
|
||||||
// If portRanges is nil or empty, all ports are allowed for this subnet
|
// If portRanges is nil or empty, all ports are allowed for this subnet
|
||||||
// rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com")
|
// rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com")
|
||||||
func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) {
|
func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool, resourceId int) {
|
||||||
sl.mu.Lock()
|
sl.mu.Lock()
|
||||||
defer sl.mu.Unlock()
|
defer sl.mu.Unlock()
|
||||||
|
|
||||||
@@ -57,6 +57,7 @@ func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewrite
|
|||||||
DisableIcmp: disableIcmp,
|
DisableIcmp: disableIcmp,
|
||||||
RewriteTo: rewriteTo,
|
RewriteTo: rewriteTo,
|
||||||
PortRanges: portRanges,
|
PortRanges: portRanges,
|
||||||
|
ResourceId: resourceId,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Canonicalize source prefix to handle host bits correctly
|
// Canonicalize source prefix to handle host bits correctly
|
||||||
|
|||||||
@@ -354,10 +354,10 @@ func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
|
|||||||
// AddProxySubnetRule adds a subnet rule to the proxy handler
|
// AddProxySubnetRule adds a subnet rule to the proxy handler
|
||||||
// If portRanges is nil or empty, all ports are allowed for this subnet
|
// If portRanges is nil or empty, all ports are allowed for this subnet
|
||||||
// rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com")
|
// rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com")
|
||||||
func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) {
|
func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool, resourceId int) {
|
||||||
tun := (*netTun)(net)
|
tun := (*netTun)(net)
|
||||||
if tun.proxyHandler != nil {
|
if tun.proxyHandler != nil {
|
||||||
tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp)
|
tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp, resourceId)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -385,6 +385,15 @@ func (net *Net) GetProxyHandler() *ProxyHandler {
|
|||||||
return tun.proxyHandler
|
return tun.proxyHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetAccessLogSender configures the function used to send compressed access log
|
||||||
|
// batches to the server. This should be called once the websocket client is available.
|
||||||
|
func (net *Net) SetAccessLogSender(fn SendFunc) {
|
||||||
|
tun := (*netTun)(net)
|
||||||
|
if tun.proxyHandler != nil {
|
||||||
|
tun.proxyHandler.SetAccessLogSender(fn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type PingConn struct {
|
type PingConn struct {
|
||||||
laddr PingAddr
|
laddr PingAddr
|
||||||
raddr PingAddr
|
raddr PingAddr
|
||||||
|
|||||||
@@ -21,7 +21,10 @@ import (
|
|||||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
)
|
)
|
||||||
|
|
||||||
const errUnsupportedProtoFmt = "unsupported protocol: %s"
|
const (
|
||||||
|
errUnsupportedProtoFmt = "unsupported protocol: %s"
|
||||||
|
maxUDPPacketSize = 65507
|
||||||
|
)
|
||||||
|
|
||||||
// Target represents a proxy target with its address and port
|
// Target represents a proxy target with its address and port
|
||||||
type Target struct {
|
type Target struct {
|
||||||
@@ -105,14 +108,10 @@ func classifyProxyError(err error) string {
|
|||||||
if errors.Is(err, net.ErrClosed) {
|
if errors.Is(err, net.ErrClosed) {
|
||||||
return "closed"
|
return "closed"
|
||||||
}
|
}
|
||||||
if ne, ok := err.(net.Error); ok {
|
var ne net.Error
|
||||||
if ne.Timeout() {
|
if errors.As(err, &ne) && ne.Timeout() {
|
||||||
return "timeout"
|
return "timeout"
|
||||||
}
|
}
|
||||||
if ne.Temporary() {
|
|
||||||
return "temporary"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
msg := strings.ToLower(err.Error())
|
msg := strings.ToLower(err.Error())
|
||||||
switch {
|
switch {
|
||||||
case strings.Contains(msg, "refused"):
|
case strings.Contains(msg, "refused"):
|
||||||
@@ -437,14 +436,6 @@ func (pm *ProxyManager) Stop() error {
|
|||||||
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
|
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// // Clear the target maps
|
|
||||||
// for k := range pm.tcpTargets {
|
|
||||||
// delete(pm.tcpTargets, k)
|
|
||||||
// }
|
|
||||||
// for k := range pm.udpTargets {
|
|
||||||
// delete(pm.udpTargets, k)
|
|
||||||
// }
|
|
||||||
|
|
||||||
// Give active connections a chance to close gracefully
|
// Give active connections a chance to close gracefully
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
@@ -498,7 +489,7 @@ func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string)
|
|||||||
if !pm.running {
|
if !pm.running {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if ne, ok := err.(net.Error); ok && !ne.Temporary() {
|
if errors.Is(err, net.ErrClosed) {
|
||||||
logger.Info("TCP listener closed, stopping proxy handler for %v", listener.Addr())
|
logger.Info("TCP listener closed, stopping proxy handler for %v", listener.Addr())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -564,7 +555,7 @@ func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
|
func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
|
||||||
buffer := make([]byte, 65507) // Max UDP packet size
|
buffer := make([]byte, maxUDPPacketSize) // Max UDP packet size
|
||||||
clientConns := make(map[string]*net.UDPConn)
|
clientConns := make(map[string]*net.UDPConn)
|
||||||
var clientsMutex sync.RWMutex
|
var clientsMutex sync.RWMutex
|
||||||
|
|
||||||
@@ -583,7 +574,7 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check for connection closed conditions
|
// Check for connection closed conditions
|
||||||
if err == io.EOF || strings.Contains(err.Error(), "use of closed network connection") {
|
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
|
||||||
logger.Info("UDP connection closed, stopping proxy handler")
|
logger.Info("UDP connection closed, stopping proxy handler")
|
||||||
|
|
||||||
// Clean up existing client connections
|
// Clean up existing client connections
|
||||||
@@ -662,10 +653,14 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
|
|||||||
telemetry.IncProxyConnectionEvent(context.Background(), tunnelID, "udp", telemetry.ProxyConnectionClosed)
|
telemetry.IncProxyConnectionEvent(context.Background(), tunnelID, "udp", telemetry.ProxyConnectionClosed)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
buffer := make([]byte, 65507)
|
buffer := make([]byte, maxUDPPacketSize)
|
||||||
for {
|
for {
|
||||||
n, _, err := targetConn.ReadFromUDP(buffer)
|
n, _, err := targetConn.ReadFromUDP(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// Connection closed is normal during cleanup
|
||||||
|
if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) {
|
||||||
|
return // defer will handle cleanup, result stays "success"
|
||||||
|
}
|
||||||
logger.Error("Error reading from target: %v", err)
|
logger.Error("Error reading from target: %v", err)
|
||||||
result = "failure"
|
result = "failure"
|
||||||
return // defer will handle cleanup
|
return // defer will handle cleanup
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ type Client struct {
|
|||||||
onTokenUpdate func(token string)
|
onTokenUpdate func(token string)
|
||||||
writeMux sync.Mutex
|
writeMux sync.Mutex
|
||||||
clientType string // Type of client (e.g., "newt", "olm")
|
clientType string // Type of client (e.g., "newt", "olm")
|
||||||
|
configFilePath string // Optional override for the config file path
|
||||||
tlsConfig TLSConfig
|
tlsConfig TLSConfig
|
||||||
metricsCtxMu sync.RWMutex
|
metricsCtxMu sync.RWMutex
|
||||||
metricsCtx context.Context
|
metricsCtx context.Context
|
||||||
@@ -52,6 +53,7 @@ type Client struct {
|
|||||||
processingMessage bool // Flag to track if a message is currently being processed
|
processingMessage bool // Flag to track if a message is currently being processed
|
||||||
processingMux sync.RWMutex // Protects processingMessage
|
processingMux sync.RWMutex // Protects processingMessage
|
||||||
processingWg sync.WaitGroup // WaitGroup to wait for message processing to complete
|
processingWg sync.WaitGroup // WaitGroup to wait for message processing to complete
|
||||||
|
justProvisioned bool // Set to true when provisionIfNeeded exchanges a key for permanent credentials
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClientOption func(*Client)
|
type ClientOption func(*Client)
|
||||||
@@ -77,6 +79,12 @@ func WithBaseURL(url string) ClientOption {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// WithTLSConfig sets the TLS configuration for the client
|
// WithTLSConfig sets the TLS configuration for the client
|
||||||
|
func WithConfigFile(path string) ClientOption {
|
||||||
|
return func(c *Client) {
|
||||||
|
c.configFilePath = path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func WithTLSConfig(config TLSConfig) ClientOption {
|
func WithTLSConfig(config TLSConfig) ClientOption {
|
||||||
return func(c *Client) {
|
return func(c *Client) {
|
||||||
c.tlsConfig = config
|
c.tlsConfig = config
|
||||||
@@ -95,6 +103,16 @@ func (c *Client) OnTokenUpdate(callback func(token string)) {
|
|||||||
c.onTokenUpdate = callback
|
c.onTokenUpdate = callback
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WasJustProvisioned reports whether the client exchanged a provisioning key
|
||||||
|
// for permanent credentials during the most recent connection attempt. It
|
||||||
|
// consumes the flag – subsequent calls return false until provisioning occurs
|
||||||
|
// again (which, in practice, never happens once credentials are persisted).
|
||||||
|
func (c *Client) WasJustProvisioned() bool {
|
||||||
|
v := c.justProvisioned
|
||||||
|
c.justProvisioned = false
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Client) metricsContext() context.Context {
|
func (c *Client) metricsContext() context.Context {
|
||||||
c.metricsCtxMu.RLock()
|
c.metricsCtxMu.RLock()
|
||||||
defer c.metricsCtxMu.RUnlock()
|
defer c.metricsCtxMu.RUnlock()
|
||||||
@@ -481,6 +499,11 @@ func (c *Client) connectWithRetry() {
|
|||||||
func (c *Client) establishConnection() error {
|
func (c *Client) establishConnection() error {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Exchange provisioning key for permanent credentials if needed.
|
||||||
|
if err := c.provisionIfNeeded(); err != nil {
|
||||||
|
return fmt.Errorf("failed to provision newt credentials: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Get token for authentication
|
// Get token for authentication
|
||||||
token, err := c.getToken()
|
token, err := c.getToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1,16 +1,29 @@
|
|||||||
package websocket
|
package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
func getConfigPath(clientType string) string {
|
func getConfigPath(clientType string, overridePath string) string {
|
||||||
|
if overridePath != "" {
|
||||||
|
return overridePath
|
||||||
|
}
|
||||||
configFile := os.Getenv("CONFIG_FILE")
|
configFile := os.Getenv("CONFIG_FILE")
|
||||||
if configFile == "" {
|
if configFile == "" {
|
||||||
var configDir string
|
var configDir string
|
||||||
@@ -36,7 +49,7 @@ func getConfigPath(clientType string) string {
|
|||||||
|
|
||||||
func (c *Client) loadConfig() error {
|
func (c *Client) loadConfig() error {
|
||||||
originalConfig := *c.config // Store original config to detect changes
|
originalConfig := *c.config // Store original config to detect changes
|
||||||
configPath := getConfigPath(c.clientType)
|
configPath := getConfigPath(c.clientType, c.configFilePath)
|
||||||
|
|
||||||
if c.config.ID != "" && c.config.Secret != "" && c.config.Endpoint != "" {
|
if c.config.ID != "" && c.config.Secret != "" && c.config.Endpoint != "" {
|
||||||
logger.Debug("Config already provided, skipping loading from file")
|
logger.Debug("Config already provided, skipping loading from file")
|
||||||
@@ -83,6 +96,14 @@ func (c *Client) loadConfig() error {
|
|||||||
c.config.Endpoint = config.Endpoint
|
c.config.Endpoint = config.Endpoint
|
||||||
c.baseURL = config.Endpoint
|
c.baseURL = config.Endpoint
|
||||||
}
|
}
|
||||||
|
// Always load the provisioning key from the file if not already set
|
||||||
|
if c.config.ProvisioningKey == "" {
|
||||||
|
c.config.ProvisioningKey = config.ProvisioningKey
|
||||||
|
}
|
||||||
|
// Always load the name from the file if not already set
|
||||||
|
if c.config.Name == "" {
|
||||||
|
c.config.Name = config.Name
|
||||||
|
}
|
||||||
|
|
||||||
// Check if CLI args provided values that override file values
|
// Check if CLI args provided values that override file values
|
||||||
if (!fileHadID && originalConfig.ID != "") ||
|
if (!fileHadID && originalConfig.ID != "") ||
|
||||||
@@ -105,7 +126,7 @@ func (c *Client) saveConfig() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
configPath := getConfigPath(c.clientType)
|
configPath := getConfigPath(c.clientType, c.configFilePath)
|
||||||
data, err := json.MarshalIndent(c.config, "", " ")
|
data, err := json.MarshalIndent(c.config, "", " ")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -118,3 +139,139 @@ func (c *Client) saveConfig() error {
|
|||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// interpolateString replaces {{env.VAR}} tokens in s with the corresponding
|
||||||
|
// environment variable values. Tokens that do not match a supported scheme are
|
||||||
|
// left unchanged, mirroring the blueprint interpolation logic.
|
||||||
|
func interpolateString(s string) string {
|
||||||
|
re := regexp.MustCompile(`\{\{([^}]+)\}\}`)
|
||||||
|
return re.ReplaceAllStringFunc(s, func(match string) string {
|
||||||
|
inner := strings.TrimSpace(match[2 : len(match)-2])
|
||||||
|
if strings.HasPrefix(inner, "env.") {
|
||||||
|
varName := strings.TrimPrefix(inner, "env.")
|
||||||
|
return os.Getenv(varName)
|
||||||
|
}
|
||||||
|
return match
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// provisionIfNeeded checks whether a provisioning key is present and, if so,
|
||||||
|
// exchanges it for a newt ID and secret by calling the registration endpoint.
|
||||||
|
// On success the config is updated in-place and flagged for saving so that
|
||||||
|
// subsequent runs use the permanent credentials directly.
|
||||||
|
func (c *Client) provisionIfNeeded() error {
|
||||||
|
if c.config.ProvisioningKey == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we already have both credentials there is nothing to provision.
|
||||||
|
if c.config.ID != "" && c.config.Secret != "" {
|
||||||
|
logger.Debug("Credentials already present, skipping provisioning")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Provisioning key found – exchanging for newt credentials...")
|
||||||
|
|
||||||
|
baseURL, err := url.Parse(c.baseURL)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse base URL for provisioning: %w", err)
|
||||||
|
}
|
||||||
|
baseEndpoint := strings.TrimRight(baseURL.String(), "/")
|
||||||
|
|
||||||
|
// Interpolate any {{env.VAR}} tokens in the name before sending.
|
||||||
|
name := interpolateString(c.config.Name)
|
||||||
|
|
||||||
|
reqBody := map[string]interface{}{
|
||||||
|
"provisioningKey": c.config.ProvisioningKey,
|
||||||
|
}
|
||||||
|
if name != "" {
|
||||||
|
reqBody["name"] = name
|
||||||
|
}
|
||||||
|
jsonData, err := json.Marshal(reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal provisioning request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(
|
||||||
|
ctx,
|
||||||
|
"POST",
|
||||||
|
baseEndpoint+"/api/v1/auth/newt/register",
|
||||||
|
bytes.NewBuffer(jsonData),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create provisioning request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("X-CSRF-Token", "x-csrf-protection")
|
||||||
|
|
||||||
|
// Mirror the TLS setup used by getToken so mTLS / self-signed CAs work.
|
||||||
|
var tlsCfg *tls.Config
|
||||||
|
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" ||
|
||||||
|
len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
|
||||||
|
tlsCfg, err = c.setupTLS()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to setup TLS for provisioning: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if os.Getenv("SKIP_TLS_VERIFY") == "true" {
|
||||||
|
if tlsCfg == nil {
|
||||||
|
tlsCfg = &tls.Config{}
|
||||||
|
}
|
||||||
|
tlsCfg.InsecureSkipVerify = true
|
||||||
|
logger.Debug("TLS certificate verification disabled for provisioning via SKIP_TLS_VERIFY")
|
||||||
|
}
|
||||||
|
|
||||||
|
httpClient := &http.Client{}
|
||||||
|
if tlsCfg != nil {
|
||||||
|
httpClient.Transport = &http.Transport{TLSClientConfig: tlsCfg}
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("provisioning request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
logger.Debug("Provisioning response body: %s", string(body))
|
||||||
|
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return fmt.Errorf("provisioning endpoint returned status %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var provResp ProvisioningResponse
|
||||||
|
if err := json.Unmarshal(body, &provResp); err != nil {
|
||||||
|
return fmt.Errorf("failed to decode provisioning response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !provResp.Success {
|
||||||
|
return fmt.Errorf("provisioning failed: %s", provResp.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
if provResp.Data.NewtID == "" || provResp.Data.Secret == "" {
|
||||||
|
return fmt.Errorf("provisioning response is missing newt ID or secret")
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Successfully provisioned – newt ID: %s", provResp.Data.NewtID)
|
||||||
|
|
||||||
|
// Persist the returned credentials and clear the one-time provisioning key
|
||||||
|
// so subsequent runs authenticate normally.
|
||||||
|
c.config.ID = provResp.Data.NewtID
|
||||||
|
c.config.Secret = provResp.Data.Secret
|
||||||
|
c.config.ProvisioningKey = ""
|
||||||
|
c.config.Name = ""
|
||||||
|
c.configNeedsSave = true
|
||||||
|
c.justProvisioned = true
|
||||||
|
|
||||||
|
// Save immediately so that if the subsequent connection attempt fails the
|
||||||
|
// provisioning key is already gone from disk and the next retry uses the
|
||||||
|
// permanent credentials instead of trying to provision again.
|
||||||
|
if err := c.saveConfig(); err != nil {
|
||||||
|
logger.Error("Failed to save config after provisioning: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -5,6 +5,8 @@ type Config struct {
|
|||||||
Secret string `json:"secret"`
|
Secret string `json:"secret"`
|
||||||
Endpoint string `json:"endpoint"`
|
Endpoint string `json:"endpoint"`
|
||||||
TlsClientCert string `json:"tlsClientCert"`
|
TlsClientCert string `json:"tlsClientCert"`
|
||||||
|
ProvisioningKey string `json:"provisioningKey,omitempty"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type TokenResponse struct {
|
type TokenResponse struct {
|
||||||
@@ -16,6 +18,15 @@ type TokenResponse struct {
|
|||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ProvisioningResponse struct {
|
||||||
|
Data struct {
|
||||||
|
NewtID string `json:"newtId"`
|
||||||
|
Secret string `json:"secret"`
|
||||||
|
} `json:"data"`
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
type WSMessage struct {
|
type WSMessage struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Data interface{} `json:"data"`
|
Data interface{} `json:"data"`
|
||||||
|
|||||||
Reference in New Issue
Block a user