Compare commits

...

17 Commits
1.10.4 ... http

Author SHA1 Message Date
Owen
092535441e Pass the new data down from the websocket 2026-04-09 16:13:19 -04:00
Owen
5848c8d4b4 Adjust to use data saved inside of the subnet rule 2026-04-09 16:04:11 -04:00
Owen
47c646bc33 Basic http is working 2026-04-09 11:43:26 -04:00
Owen Schwartz
7e1e3408d5 Merge pull request #302 from LaurenceJJones/fix/config-file-provision-save
fix: allow empty config file bootstrap before provisioning
2026-04-08 21:58:07 -04:00
Laurence
d7c3c38d24 fix: allow empty config file bootstrap before provisioning
Treat an empty CONFIG_FILE as initial state instead of failing JSON parse, so provisioning can proceed and credentials can be saved. Ref: fosrl/pangolin#2812
2026-04-08 14:13:13 +01:00
Owen
27e471942e Add CODEOWNERS 2026-04-07 11:34:18 -04:00
Owen
184bfb12d6 Delete bad bp 2026-04-03 17:36:48 -04:00
Owen
2e02c9b7a9 Remove files 2026-04-03 16:49:09 -04:00
Owen
f4d071fe27 Add provisioning blueprint file 2026-04-02 21:39:59 -04:00
Owen
8d82460a76 Send health checks to the server on reconnect 2026-03-31 17:06:07 -07:00
Owen
5208117c56 Add name to provisioning 2026-03-30 17:18:22 -07:00
Owen
381f5a619c Merge branch 'main' into logging-provision 2026-03-29 21:19:53 -07:00
Owen
fc4b375bf1 Allow blueprint interpolation for env vars 2026-03-26 20:05:04 -07:00
Owen
baca04ee58 Add --config-file 2026-03-26 17:31:04 -07:00
Owen
b43572dd8d Provisioning key working 2026-03-26 17:23:19 -07:00
Owen
69019d5655 Process log to form sessions 2026-03-24 17:26:44 -07:00
Owen
0f57985b6f Saving and sending access logs pass 1 2026-03-23 16:39:01 -07:00
16 changed files with 2263 additions and 107 deletions

1
.github/CODEOWNERS vendored Normal file
View File

@@ -0,0 +1 @@
* @oschwartz10612 @miloschwartz

View File

@@ -1,37 +0,0 @@
resources:
resource-nice-id:
name: this is my resource
protocol: http
full-domain: level1.test3.example.com
host-header: example.com
tls-server-name: example.com
auth:
pincode: 123456
password: sadfasdfadsf
sso-enabled: true
sso-roles:
- Member
sso-users:
- owen@pangolin.net
whitelist-users:
- owen@pangolin.net
targets:
# - site: glossy-plains-viscacha-rat
- hostname: localhost
method: http
port: 8000
healthcheck:
port: 8000
hostname: localhost
# - site: glossy-plains-viscacha-rat
- hostname: localhost
method: http
port: 8001
resource-nice-id2:
name: this is other resource
protocol: tcp
proxy-port: 3000
targets:
# - site: glossy-plains-viscacha-rat
- hostname: localhost
port: 3000

View File

@@ -40,12 +40,17 @@ type WgConfig struct {
}
type Target struct {
SourcePrefix string `json:"sourcePrefix"`
SourcePrefixes []string `json:"sourcePrefixes"`
DestPrefix string `json:"destPrefix"`
RewriteTo string `json:"rewriteTo,omitempty"`
DisableIcmp bool `json:"disableIcmp,omitempty"`
PortRange []PortRange `json:"portRange,omitempty"`
SourcePrefix string `json:"sourcePrefix"`
SourcePrefixes []string `json:"sourcePrefixes"`
DestPrefix string `json:"destPrefix"`
RewriteTo string `json:"rewriteTo,omitempty"`
DisableIcmp bool `json:"disableIcmp,omitempty"`
PortRange []PortRange `json:"portRange,omitempty"`
ResourceId int `json:"resourceId,omitempty"`
Protocol string `json:"protocol,omitempty"` // for now practically either http or https
HTTPTargets []netstack2.HTTPTarget `json:"httpTargets,omitempty"` // for http protocol, list of downstream services to load balance across
TLSCert string `json:"tlsCert,omitempty"` // PEM-encoded certificate for incoming HTTPS termination
TLSKey string `json:"tlsKey,omitempty"` // PEM-encoded private key for incoming HTTPS termination
}
type PortRange struct {
@@ -73,18 +78,18 @@ type PeerReading struct {
}
type WireGuardService struct {
interfaceName string
mtu int
client *websocket.Client
config WgConfig
key wgtypes.Key
newtId string
lastReadings map[string]PeerReading
mu sync.Mutex
Port uint16
host string
serverPubKey string
token string
interfaceName string
mtu int
client *websocket.Client
config WgConfig
key wgtypes.Key
newtId string
lastReadings map[string]PeerReading
mu sync.Mutex
Port uint16
host string
serverPubKey string
token string
stopGetConfig func()
pendingConfigChainId string
// Netstack fields
@@ -206,6 +211,15 @@ func (s *WireGuardService) Close() {
s.stopGetConfig = nil
}
// Flush access logs before tearing down the tunnel
if s.tnet != nil {
if ph := s.tnet.GetProxyHandler(); ph != nil {
if al := ph.GetAccessLogger(); al != nil {
al.Close()
}
}
}
// Stop the direct UDP relay first
s.StopDirectUDPRelay()
@@ -687,7 +701,18 @@ func (s *WireGuardService) syncTargets(desiredTargets []Target) error {
})
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
s.tnet.AddProxySubnetRule(netstack2.SubnetRule{
SourcePrefix: sourcePrefix,
DestPrefix: destPrefix,
RewriteTo: target.RewriteTo,
PortRanges: portRanges,
DisableIcmp: target.DisableIcmp,
ResourceId: target.ResourceId,
Protocol: target.Protocol,
HTTPTargets: target.HTTPTargets,
TLSCert: target.TLSCert,
TLSKey: target.TLSKey,
})
logger.Info("Added target %s -> %s during sync", target.SourcePrefix, target.DestPrefix)
}
}
@@ -818,6 +843,13 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
s.TunnelIP = tunnelIP.String()
// Configure the access log sender to ship compressed session logs via websocket
s.tnet.SetAccessLogSender(func(data string) error {
return s.client.SendMessageNoLog("newt/access-log", map[string]interface{}{
"compressed": data,
})
})
// Create WireGuard device using the shared bind
s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger(
device.LogLevelSilent, // Use silent logging by default - could be made configurable
@@ -938,7 +970,18 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
if err != nil {
return fmt.Errorf("invalid CIDR %s: %v", sp, err)
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
s.tnet.AddProxySubnetRule(netstack2.SubnetRule{
SourcePrefix: sourcePrefix,
DestPrefix: destPrefix,
RewriteTo: target.RewriteTo,
PortRanges: portRanges,
DisableIcmp: target.DisableIcmp,
ResourceId: target.ResourceId,
Protocol: target.Protocol,
HTTPTargets: target.HTTPTargets,
TLSCert: target.TLSCert,
TLSKey: target.TLSKey,
})
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
}
}
@@ -1331,7 +1374,18 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
s.tnet.AddProxySubnetRule(netstack2.SubnetRule{
SourcePrefix: sourcePrefix,
DestPrefix: destPrefix,
RewriteTo: target.RewriteTo,
PortRanges: portRanges,
DisableIcmp: target.DisableIcmp,
ResourceId: target.ResourceId,
Protocol: target.Protocol,
HTTPTargets: target.HTTPTargets,
TLSCert: target.TLSCert,
TLSKey: target.TLSKey,
})
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
}
}
@@ -1449,7 +1503,18 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
logger.Info("Invalid CIDR %s: %v", sp, err)
continue
}
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
s.tnet.AddProxySubnetRule(netstack2.SubnetRule{
SourcePrefix: sourcePrefix,
DestPrefix: destPrefix,
RewriteTo: target.RewriteTo,
PortRanges: portRanges,
DisableIcmp: target.DisableIcmp,
ResourceId: target.ResourceId,
Protocol: target.Protocol,
HTTPTargets: target.HTTPTargets,
TLSCert: target.TLSCert,
TLSKey: target.TLSKey,
})
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
}
}

View File

@@ -8,6 +8,7 @@ import (
"net"
"os"
"os/exec"
"regexp"
"strings"
"time"
@@ -516,15 +517,41 @@ func executeUpdownScript(action, proto, target string) (string, error) {
return target, nil
}
func sendBlueprint(client *websocket.Client) error {
if blueprintFile == "" {
// interpolateBlueprint finds all {{...}} tokens in the raw blueprint bytes and
// replaces recognised schemes with their resolved values. Currently supported:
//
// - env.<VAR> replaced with the value of the named environment variable
//
// Any token that does not match a supported scheme is left as-is so that
// future schemes (e.g. tag., api.) are preserved rather than silently dropped.
func interpolateBlueprint(data []byte) []byte {
re := regexp.MustCompile(`\{\{([^}]+)\}\}`)
return re.ReplaceAllFunc(data, func(match []byte) []byte {
// strip the surrounding {{ }}
inner := strings.TrimSpace(string(match[2 : len(match)-2]))
if strings.HasPrefix(inner, "env.") {
varName := strings.TrimPrefix(inner, "env.")
return []byte(os.Getenv(varName))
}
// unrecognised scheme leave the token untouched
return match
})
}
func sendBlueprint(client *websocket.Client, file string) error {
if file == "" {
return nil
}
// try to read the blueprint file
blueprintData, err := os.ReadFile(blueprintFile)
blueprintData, err := os.ReadFile(file)
if err != nil {
logger.Error("Failed to read blueprint file: %v", err)
} else {
// interpolate {{env.VAR}} (and any future schemes) before parsing
blueprintData = interpolateBlueprint(blueprintData)
// first we should convert the yaml to json and error if the yaml is bad
var yamlObj interface{}
var blueprintJsonData string

70
main.go
View File

@@ -155,8 +155,9 @@ var (
region string
metricsAsyncBytes bool
pprofEnabled bool
blueprintFile string
noCloud bool
blueprintFile string
provisioningBlueprintFile string
noCloud bool
// New mTLS configuration variables
tlsClientCert string
@@ -165,6 +166,15 @@ var (
// Legacy PKCS12 support (deprecated)
tlsPrivateKey string
// Provisioning key exchanged once for a permanent newt ID + secret
provisioningKey string
// Optional name for the site created during provisioning
newtName string
// Path to config file (overrides CONFIG_FILE env var and default location)
configFile string
)
// generateChainId generates a random chain ID for deduplicating round-trip messages.
@@ -275,8 +285,12 @@ func runNewtMain(ctx context.Context) {
tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT")
}
blueprintFile = os.Getenv("BLUEPRINT_FILE")
provisioningBlueprintFile = os.Getenv("PROVISIONING_BLUEPRINT_FILE")
noCloudEnv := os.Getenv("NO_CLOUD")
noCloud = noCloudEnv == "true"
provisioningKey = os.Getenv("NEWT_PROVISIONING_KEY")
newtName = os.Getenv("NEWT_NAME")
configFile = os.Getenv("CONFIG_FILE")
if endpoint == "" {
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
@@ -325,6 +339,15 @@ func runNewtMain(ctx context.Context) {
}
// load the prefer endpoint just as a flag
flag.StringVar(&preferEndpoint, "prefer-endpoint", "", "Prefer this endpoint for the connection (if set, will override the endpoint from the server)")
if provisioningKey == "" {
flag.StringVar(&provisioningKey, "provisioning-key", "", "One-time provisioning key used to obtain a newt ID and secret from the server")
}
if newtName == "" {
flag.StringVar(&newtName, "name", "", "Name for the site created during provisioning (supports {{env.VAR}} interpolation)")
}
if configFile == "" {
flag.StringVar(&configFile, "config-file", "", "Path to config file (overrides CONFIG_FILE env var and default location)")
}
// Add new mTLS flags
if tlsClientCert == "" {
@@ -372,6 +395,9 @@ func runNewtMain(ctx context.Context) {
if blueprintFile == "" {
flag.StringVar(&blueprintFile, "blueprint-file", "", "Path to blueprint file (if unset, no blueprint will be applied)")
}
if provisioningBlueprintFile == "" {
flag.StringVar(&provisioningBlueprintFile, "provisioning-blueprint-file", "", "Path to blueprint file applied once after a provisioning credential exchange (if unset, no provisioning blueprint will be applied)")
}
if noCloudEnv == "" {
flag.BoolVar(&noCloud, "no-cloud", false, "Disable cloud failover")
}
@@ -599,10 +625,20 @@ func runNewtMain(ctx context.Context) {
endpoint,
30*time.Second,
opt,
websocket.WithConfigFile(configFile),
)
if err != nil {
logger.Fatal("Failed to create client: %v", err)
}
// If a provisioning key was supplied via CLI / env and the config file did
// not already carry one, inject it now so provisionIfNeeded() can use it.
if provisioningKey != "" && client.GetConfig().ProvisioningKey == "" {
client.GetConfig().ProvisioningKey = provisioningKey
}
if newtName != "" && client.GetConfig().Name == "" {
client.GetConfig().Name = newtName
}
endpoint = client.GetConfig().Endpoint // Update endpoint from config
id = client.GetConfig().ID // Update ID from config
// Update site labels for metrics with the resolved ID
@@ -1789,6 +1825,34 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
} else {
logger.Warn("CLIENTS WILL NOT WORK ON THIS VERSION OF NEWT WITH THIS VERSION OF PANGOLIN, PLEASE UPDATE THE SERVER TO 1.13 OR HIGHER OR DOWNGRADE NEWT")
}
sendBlueprint(client, blueprintFile)
if client.WasJustProvisioned() {
logger.Info("Provisioning detected sending provisioning blueprint")
sendBlueprint(client, provisioningBlueprintFile)
}
} else {
// Resend current health check status for all targets in case the server
// missed updates while newt was disconnected.
targets := healthMonitor.GetTargets()
if len(targets) > 0 {
healthStatuses := make(map[int]interface{})
for id, target := range targets {
healthStatuses[id] = map[string]interface{}{
"status": target.Status.String(),
"lastCheck": target.LastCheck.Format(time.RFC3339),
"checkCount": target.CheckCount,
"lastError": target.LastError,
"config": target.Config,
}
}
logger.Debug("Reconnected: resending health check status for %d targets", len(healthStatuses))
if err := client.SendMessage("newt/healthcheck/status", map[string]interface{}{
"targets": healthStatuses,
}); err != nil {
logger.Error("Failed to resend health check status on reconnect: %v", err)
}
}
}
// Send registration message to the server for backward compatibility
@@ -1801,8 +1865,6 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
"chainId": bcChainId,
})
sendBlueprint(client)
if err != nil {
logger.Error("Failed to send registration message: %v", err)
return err

514
netstack2/access_log.go Normal file
View File

@@ -0,0 +1,514 @@
package netstack2
import (
"bytes"
"compress/zlib"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
"net"
"sort"
"sync"
"time"
"github.com/fosrl/newt/logger"
)
// Tunables for access-log batching and burst consolidation.
const (
	// flushInterval is how often the access logger flushes completed sessions to the server.
	flushInterval = 60 * time.Second
	// maxBufferedSessions is the max number of completed sessions to buffer before forcing a flush.
	maxBufferedSessions = 100
	// sessionGapThreshold is the maximum gap between the end of one connection
	// and the start of the next for them to be considered part of the same session.
	// If the gap exceeds this, a new consolidated session is created.
	sessionGapThreshold = 5 * time.Second
	// minConnectionsToConsolidate is the minimum number of connections in a group
	// before we bother consolidating. Groups smaller than this are sent as-is.
	minConnectionsToConsolidate = 2
)
// SendFunc is a callback that sends compressed access log data to the server.
// The data is a base64-encoded zlib-compressed JSON array of AccessSession objects.
type SendFunc func(data string) error

// AccessSession represents a tracked access session through the proxy.
type AccessSession struct {
	SessionID       string    `json:"sessionId"`
	ResourceID      int       `json:"resourceId"`
	SourceAddr      string    `json:"sourceAddr"`
	DestAddr        string    `json:"destAddr"`
	Protocol        string    `json:"protocol"`
	StartedAt       time.Time `json:"startedAt"`
	EndedAt         time.Time `json:"endedAt,omitempty"`
	BytesTx         int64     `json:"bytesTx"`
	BytesRx         int64     `json:"bytesRx"`
	ConnectionCount int       `json:"connectionCount,omitempty"` // number of raw connections merged into this session (0 or 1 = single)
}

// udpSessionKey identifies a unique UDP "session" by src -> dst.
type udpSessionKey struct {
	srcAddr  string
	dstAddr  string
	protocol string
}

// consolidationKey groups connections that may be part of the same logical session.
// Source port is intentionally excluded so that many ephemeral-port connections
// from the same source IP to the same destination are grouped together.
type consolidationKey struct {
	sourceIP   string // IP only, no port
	destAddr   string // full host:port of the destination
	protocol   string
	resourceID int
}

// AccessLogger tracks access sessions for resources and periodically
// flushes completed sessions to the server via a configurable SendFunc.
// All maps/slices are guarded by mu; the background goroutine started by
// NewAccessLogger is stopped via stopCh and waited on via flushDone.
type AccessLogger struct {
	mu                sync.Mutex
	sessions          map[string]*AccessSession        // active sessions: sessionID -> session
	udpSessions       map[udpSessionKey]*AccessSession // active UDP sessions for dedup
	completedSessions []*AccessSession                 // completed sessions waiting to be flushed
	udpTimeout        time.Duration                    // NOTE(review): stored but not read by any visible method (reaping uses a fixed 5m) — confirm intended use
	sendFn            SendFunc
	stopCh            chan struct{}
	flushDone         chan struct{} // closed after the flush goroutine exits
}
// NewAccessLogger creates a new access logger and starts its background
// flush/reap goroutine.
// udpTimeout controls how long a UDP session is kept alive without traffic
// before being ended.
func NewAccessLogger(udpTimeout time.Duration) *AccessLogger {
	a := &AccessLogger{
		udpTimeout: udpTimeout,
		stopCh:     make(chan struct{}),
		flushDone:  make(chan struct{}),
	}
	a.sessions = make(map[string]*AccessSession)
	a.udpSessions = make(map[udpSessionKey]*AccessSession)
	a.completedSessions = make([]*AccessSession, 0)
	go a.backgroundLoop()
	return a
}
// SetSendFunc installs the callback used to ship compressed access log
// batches to the server. It may be called after construction, once the
// websocket client becomes available.
func (al *AccessLogger) SetSendFunc(fn SendFunc) {
	al.mu.Lock()
	al.sendFn = fn
	al.mu.Unlock()
}
// generateSessionID creates a random session identifier
func generateSessionID() string {
b := make([]byte, 8)
rand.Read(b)
return hex.EncodeToString(b)
}
// StartTCPSession logs the start of a TCP session and returns a session ID.
func (al *AccessLogger) StartTCPSession(resourceID int, srcAddr, dstAddr string) string {
	start := time.Now()
	s := &AccessSession{
		SessionID:  generateSessionID(),
		ResourceID: resourceID,
		SourceAddr: srcAddr,
		DestAddr:   dstAddr,
		Protocol:   "tcp",
		StartedAt:  start,
	}

	// Register while holding the lock; log outside it.
	al.mu.Lock()
	al.sessions[s.SessionID] = s
	al.mu.Unlock()

	logger.Info("ACCESS START session=%s resource=%d proto=tcp src=%s dst=%s time=%s",
		s.SessionID, resourceID, srcAddr, dstAddr, start.Format(time.RFC3339))
	return s.SessionID
}
// EndTCPSession logs the end of a TCP session and queues it for sending.
// Unknown session IDs are ignored. If the completed buffer has reached
// maxBufferedSessions, a flush is triggered immediately.
func (al *AccessLogger) EndTCPSession(sessionID string) {
	now := time.Now()
	al.mu.Lock()
	session, ok := al.sessions[sessionID]
	if ok {
		session.EndedAt = now
		// Move from the active map to the completed buffer.
		delete(al.sessions, sessionID)
		al.completedSessions = append(al.completedSessions, session)
	}
	shouldFlush := len(al.completedSessions) >= maxBufferedSessions
	al.mu.Unlock()
	// Log and flush outside the lock: flush() re-acquires al.mu.
	if ok {
		duration := now.Sub(session.StartedAt)
		logger.Info("ACCESS END session=%s resource=%d proto=tcp src=%s dst=%s started=%s ended=%s duration=%s",
			sessionID, session.ResourceID, session.SourceAddr, session.DestAddr,
			session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration)
	}
	if shouldFlush {
		al.flush()
	}
}
// TrackUDPSession starts or returns an existing UDP session. Returns the session ID.
func (al *AccessLogger) TrackUDPSession(resourceID int, srcAddr, dstAddr string) string {
	al.mu.Lock()
	defer al.mu.Unlock()

	key := udpSessionKey{srcAddr: srcAddr, dstAddr: dstAddr, protocol: "udp"}
	// Deduplicate: reuse the live session for this src/dst pair if present.
	if existing, ok := al.udpSessions[key]; ok {
		return existing.SessionID
	}

	now := time.Now()
	s := &AccessSession{
		SessionID:  generateSessionID(),
		ResourceID: resourceID,
		SourceAddr: srcAddr,
		DestAddr:   dstAddr,
		Protocol:   "udp",
		StartedAt:  now,
	}
	al.sessions[s.SessionID] = s
	al.udpSessions[key] = s

	logger.Info("ACCESS START session=%s resource=%d proto=udp src=%s dst=%s time=%s",
		s.SessionID, resourceID, srcAddr, dstAddr, now.Format(time.RFC3339))
	return s.SessionID
}
// EndUDPSession ends a UDP session and queues it for sending.
// It removes the session from both the active map and the UDP dedup map;
// unknown session IDs are ignored. If the completed buffer has reached
// maxBufferedSessions, a flush is triggered immediately.
func (al *AccessLogger) EndUDPSession(sessionID string) {
	now := time.Now()
	al.mu.Lock()
	session, ok := al.sessions[sessionID]
	if ok {
		session.EndedAt = now
		delete(al.sessions, sessionID)
		// Also drop the dedup entry so a new session can be started for
		// this src/dst pair.
		key := udpSessionKey{
			srcAddr:  session.SourceAddr,
			dstAddr:  session.DestAddr,
			protocol: "udp",
		}
		delete(al.udpSessions, key)
		al.completedSessions = append(al.completedSessions, session)
	}
	shouldFlush := len(al.completedSessions) >= maxBufferedSessions
	al.mu.Unlock()
	// Log and flush outside the lock: flush() re-acquires al.mu.
	if ok {
		duration := now.Sub(session.StartedAt)
		logger.Info("ACCESS END session=%s resource=%d proto=udp src=%s dst=%s started=%s ended=%s duration=%s",
			sessionID, session.ResourceID, session.SourceAddr, session.DestAddr,
			session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration)
	}
	if shouldFlush {
		al.flush()
	}
}
// backgroundLoop handles periodic flushing and stale session reaping.
// It runs until stopCh is closed, closing flushDone on exit so Close
// can wait for it.
func (al *AccessLogger) backgroundLoop() {
	flushTicker := time.NewTicker(flushInterval)
	reapTicker := time.NewTicker(30 * time.Second)
	defer func() {
		reapTicker.Stop()
		flushTicker.Stop()
		close(al.flushDone)
	}()

	for {
		select {
		case <-al.stopCh:
			return
		case <-reapTicker.C:
			al.reapStaleSessions()
		case <-flushTicker.C:
			al.flush()
		}
	}
}
// reapStaleSessions cleans up UDP sessions that were not properly ended.
// Any UDP session that started more than 5 minutes ago and has no end time
// is force-ended and queued for the next flush.
// NOTE(review): the threshold is hard-coded to 5 minutes; the configured
// al.udpTimeout is never consulted here — confirm whether it should be.
func (al *AccessLogger) reapStaleSessions() {
	al.mu.Lock()
	defer al.mu.Unlock()
	staleThreshold := time.Now().Add(-5 * time.Minute)
	for key, session := range al.udpSessions {
		if session.StartedAt.Before(staleThreshold) && session.EndedAt.IsZero() {
			now := time.Now()
			session.EndedAt = now
			duration := now.Sub(session.StartedAt)
			logger.Info("ACCESS END (reaped) session=%s resource=%d proto=udp src=%s dst=%s started=%s ended=%s duration=%s",
				session.SessionID, session.ResourceID, session.SourceAddr, session.DestAddr,
				session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration)
			// Move to the completed buffer and drop from both active maps.
			al.completedSessions = append(al.completedSessions, session)
			delete(al.sessions, session.SessionID)
			delete(al.udpSessions, key)
		}
	}
}
// extractIP strips the port from an address string and returns just the IP.
// If the address has no port component it is returned as-is.
func extractIP(addr string) string {
	if host, _, err := net.SplitHostPort(addr); err == nil {
		return host
	}
	// No port component (e.g. a bare IP) — return unchanged.
	return addr
}
// consolidateSessions takes a slice of completed sessions and merges bursts of
// short-lived connections from the same source IP to the same destination into
// single higher-level session entries.
//
// The algorithm:
//  1. Group sessions by (sourceIP, destAddr, protocol, resourceID).
//  2. Within each group, sort by StartedAt.
//  3. Walk through the sorted list and merge consecutive sessions whose gap
//     (previous EndedAt → next StartedAt) is ≤ sessionGapThreshold.
//  4. For merged sessions the earliest StartedAt and latest EndedAt are kept,
//     bytes are summed, and ConnectionCount records how many raw connections
//     were folded in. If the merged connections used more than one source port,
//     SourceAddr is set to just the IP (port omitted).
//  5. Groups with fewer than minConnectionsToConsolidate members are passed
//     through unmodified.
//
// The output order is not deterministic across calls because groups are
// drained in map-iteration order.
func consolidateSessions(sessions []*AccessSession) []*AccessSession {
	// 0 or 1 sessions: nothing to merge.
	if len(sessions) <= 1 {
		return sessions
	}
	// Group sessions by consolidation key
	groups := make(map[consolidationKey][]*AccessSession)
	for _, s := range sessions {
		key := consolidationKey{
			sourceIP:   extractIP(s.SourceAddr),
			destAddr:   s.DestAddr,
			protocol:   s.Protocol,
			resourceID: s.ResourceID,
		}
		groups[key] = append(groups[key], s)
	}
	result := make([]*AccessSession, 0, len(sessions))
	for key, group := range groups {
		// Small groups don't need consolidation
		if len(group) < minConnectionsToConsolidate {
			result = append(result, group...)
			continue
		}
		// Sort the group by start time so we can detect gaps
		sort.Slice(group, func(i, j int) bool {
			return group[i].StartedAt.Before(group[j].StartedAt)
		})
		// Walk through and merge runs that are within the gap threshold
		var merged []*AccessSession
		// cur accumulates the current run, seeded from a copy of the first
		// session (so the originals are never mutated).
		cur := cloneSession(group[0])
		cur.ConnectionCount = 1
		sourcePorts := make(map[string]struct{})
		sourcePorts[cur.SourceAddr] = struct{}{}
		for i := 1; i < len(group); i++ {
			s := group[i]
			// Determine the gap: from the latest end time we've seen so far to the
			// start of the next connection.
			gapRef := cur.EndedAt
			if gapRef.IsZero() {
				gapRef = cur.StartedAt
			}
			gap := s.StartedAt.Sub(gapRef)
			if gap <= sessionGapThreshold {
				// Merge into the current consolidated session
				cur.ConnectionCount++
				cur.BytesTx += s.BytesTx
				cur.BytesRx += s.BytesRx
				sourcePorts[s.SourceAddr] = struct{}{}
				// Extend EndedAt to the latest time
				endTime := s.EndedAt
				if endTime.IsZero() {
					endTime = s.StartedAt
				}
				if endTime.After(cur.EndedAt) {
					cur.EndedAt = endTime
				}
			} else {
				// Gap exceeded — finalize the current session and start a new one
				finalizeMergedSourceAddr(cur, key.sourceIP, sourcePorts)
				merged = append(merged, cur)
				cur = cloneSession(s)
				cur.ConnectionCount = 1
				sourcePorts = make(map[string]struct{})
				sourcePorts[s.SourceAddr] = struct{}{}
			}
		}
		// Finalize the last accumulated session
		finalizeMergedSourceAddr(cur, key.sourceIP, sourcePorts)
		merged = append(merged, cur)
		result = append(result, merged...)
	}
	return result
}
// cloneSession creates a shallow copy of an AccessSession.
func cloneSession(s *AccessSession) *AccessSession {
	dup := new(AccessSession)
	*dup = *s
	return dup
}
// finalizeMergedSourceAddr sets the SourceAddr on a consolidated session.
// If multiple distinct source addresses (ports) were seen, the port is
// stripped and only the IP is kept so the log isn't misleading.
func finalizeMergedSourceAddr(s *AccessSession, sourceIP string, ports map[string]struct{}) {
	if len(ports) <= 1 {
		// A single source address was seen — keep the original ip:port.
		return
	}
	// Multiple source ports were merged; report only the IP.
	s.SourceAddr = sourceIP
}
// flush drains the completed sessions buffer, consolidates bursts of
// short-lived connections, compresses with zlib, and sends via the SendFunc.
// On send failure the (already consolidated) batch is re-queued at the front
// of the buffer, capped at 5x maxBufferedSessions to bound memory if the
// server stays unreachable.
func (al *AccessLogger) flush() {
	// Swap the buffer out under the lock, then do the slow work (compress,
	// network send) without holding it.
	al.mu.Lock()
	if len(al.completedSessions) == 0 {
		al.mu.Unlock()
		return
	}
	batch := al.completedSessions
	al.completedSessions = make([]*AccessSession, 0)
	sendFn := al.sendFn
	al.mu.Unlock()
	if sendFn == nil {
		logger.Debug("Access logger: no send function configured, discarding %d sessions", len(batch))
		return
	}
	// Consolidate bursts of short-lived connections into higher-level sessions
	originalCount := len(batch)
	batch = consolidateSessions(batch)
	if len(batch) != originalCount {
		logger.Info("Access logger: consolidated %d raw connections into %d sessions", originalCount, len(batch))
	}
	compressed, err := compressSessions(batch)
	if err != nil {
		logger.Error("Access logger: failed to compress %d sessions: %v", len(batch), err)
		return
	}
	if err := sendFn(compressed); err != nil {
		logger.Error("Access logger: failed to send %d sessions: %v", len(batch), err)
		// Re-queue the batch so we don't lose data
		al.mu.Lock()
		al.completedSessions = append(batch, al.completedSessions...)
		// Cap re-queued data to prevent unbounded growth if server is unreachable
		if len(al.completedSessions) > maxBufferedSessions*5 {
			dropped := len(al.completedSessions) - maxBufferedSessions*5
			al.completedSessions = al.completedSessions[:maxBufferedSessions*5]
			logger.Warn("Access logger: buffer overflow, dropped %d oldest sessions", dropped)
		}
		al.mu.Unlock()
		return
	}
	logger.Info("Access logger: sent %d sessions to server", len(batch))
}
// compressSessions JSON-encodes the sessions, compresses with zlib, and returns
// a base64-encoded string suitable for embedding in a JSON message.
func compressSessions(sessions []*AccessSession) (string, error) {
	payload, err := json.Marshal(sessions)
	if err != nil {
		return "", err
	}

	var compressed bytes.Buffer
	zw, err := zlib.NewWriterLevel(&compressed, zlib.BestCompression)
	if err != nil {
		return "", err
	}
	if _, err = zw.Write(payload); err != nil {
		zw.Close()
		return "", err
	}
	// Close flushes the remaining compressed data; its error matters.
	if err = zw.Close(); err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(compressed.Bytes()), nil
}
// Close shuts down the background loop, ends all active sessions,
// and performs one final flush to send everything to the server.
// Calling Close after a previous Close has completed is a no-op.
// NOTE(review): two goroutines calling Close at the same instant could both
// take the default branch and double-close stopCh (panic) — consider
// guarding with sync.Once if concurrent Close is possible.
func (al *AccessLogger) Close() {
	// Signal the background loop to stop
	select {
	case <-al.stopCh:
		// Already closed
		return
	default:
		close(al.stopCh)
	}
	// Wait for the background loop to exit so we don't race on flush
	<-al.flushDone
	al.mu.Lock()
	now := time.Now()
	// End all active sessions and move them to the completed buffer
	for _, session := range al.sessions {
		if session.EndedAt.IsZero() {
			session.EndedAt = now
			duration := now.Sub(session.StartedAt)
			logger.Info("ACCESS END (shutdown) session=%s resource=%d proto=%s src=%s dst=%s started=%s ended=%s duration=%s",
				session.SessionID, session.ResourceID, session.Protocol, session.SourceAddr, session.DestAddr,
				session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration)
			al.completedSessions = append(al.completedSessions, session)
		}
	}
	al.sessions = make(map[string]*AccessSession)
	al.udpSessions = make(map[udpSessionKey]*AccessSession)
	al.mu.Unlock()
	// Final flush to send all remaining sessions to the server
	al.flush()
}

View File

@@ -0,0 +1,811 @@
package netstack2
import (
"testing"
"time"
)
// TestExtractIP exercises extractIP across IPv4/IPv6 addresses with and
// without ports, plus the empty-string edge case.
func TestExtractIP(t *testing.T) {
	tests := []struct {
		name     string
		addr     string
		expected string
	}{
		{"ipv4 with port", "192.168.1.1:12345", "192.168.1.1"},
		{"ipv4 without port", "192.168.1.1", "192.168.1.1"},
		{"ipv6 with port", "[::1]:12345", "::1"},
		{"ipv6 without port", "::1", "::1"},
		{"empty string", "", ""},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := extractIP(tt.addr)
			if result != tt.expected {
				t.Errorf("extractIP(%q) = %q, want %q", tt.addr, result, tt.expected)
			}
		})
	}
}
// TestConsolidateSessions_Empty verifies that nil and empty inputs pass
// through consolidateSessions unchanged.
func TestConsolidateSessions_Empty(t *testing.T) {
	result := consolidateSessions(nil)
	if result != nil {
		t.Errorf("expected nil, got %v", result)
	}
	result = consolidateSessions([]*AccessSession{})
	if len(result) != 0 {
		t.Errorf("expected empty slice, got %d items", len(result))
	}
}
// TestConsolidateSessions_SingleSession verifies that a lone session is
// passed through untouched (below the consolidation threshold), with its
// full ip:port source address preserved.
func TestConsolidateSessions_SingleSession(t *testing.T) {
	now := time.Now()
	sessions := []*AccessSession{
		{
			SessionID:  "abc123",
			ResourceID: 1,
			SourceAddr: "10.0.0.1:5000",
			DestAddr:   "192.168.1.100:443",
			Protocol:   "tcp",
			StartedAt:  now,
			EndedAt:    now.Add(1 * time.Second),
		},
	}
	result := consolidateSessions(sessions)
	if len(result) != 1 {
		t.Fatalf("expected 1 session, got %d", len(result))
	}
	if result[0].SourceAddr != "10.0.0.1:5000" {
		t.Errorf("expected source addr preserved, got %q", result[0].SourceAddr)
	}
}
// TestConsolidateSessions_MergesBurstFromSameSourceIP verifies that three
// closely-spaced connections from the same source IP (different ephemeral
// ports) to the same destination merge into one session: counts increment,
// bytes sum, the time span covers earliest start to latest end, and the
// source address collapses to the bare IP because multiple ports were used.
func TestConsolidateSessions_MergesBurstFromSameSourceIP(t *testing.T) {
	now := time.Now()
	sessions := []*AccessSession{
		{
			SessionID:  "s1",
			ResourceID: 1,
			SourceAddr: "10.0.0.1:5000",
			DestAddr:   "192.168.1.100:443",
			Protocol:   "tcp",
			StartedAt:  now,
			EndedAt:    now.Add(100 * time.Millisecond),
			BytesTx:    100,
			BytesRx:    200,
		},
		{
			SessionID:  "s2",
			ResourceID: 1,
			SourceAddr: "10.0.0.1:5001",
			DestAddr:   "192.168.1.100:443",
			Protocol:   "tcp",
			StartedAt:  now.Add(200 * time.Millisecond),
			EndedAt:    now.Add(300 * time.Millisecond),
			BytesTx:    150,
			BytesRx:    250,
		},
		{
			SessionID:  "s3",
			ResourceID: 1,
			SourceAddr: "10.0.0.1:5002",
			DestAddr:   "192.168.1.100:443",
			Protocol:   "tcp",
			StartedAt:  now.Add(400 * time.Millisecond),
			EndedAt:    now.Add(500 * time.Millisecond),
			BytesTx:    50,
			BytesRx:    75,
		},
	}
	result := consolidateSessions(sessions)
	if len(result) != 1 {
		t.Fatalf("expected 1 consolidated session, got %d", len(result))
	}
	s := result[0]
	if s.ConnectionCount != 3 {
		t.Errorf("expected ConnectionCount=3, got %d", s.ConnectionCount)
	}
	if s.SourceAddr != "10.0.0.1" {
		t.Errorf("expected source addr to be IP only (multiple ports), got %q", s.SourceAddr)
	}
	if s.DestAddr != "192.168.1.100:443" {
		t.Errorf("expected dest addr preserved, got %q", s.DestAddr)
	}
	if s.StartedAt != now {
		t.Errorf("expected StartedAt to be earliest time")
	}
	if s.EndedAt != now.Add(500*time.Millisecond) {
		t.Errorf("expected EndedAt to be latest time")
	}
	// Bytes should be the sum across all three raw connections.
	expectedTx := int64(300)
	expectedRx := int64(525)
	if s.BytesTx != expectedTx {
		t.Errorf("expected BytesTx=%d, got %d", expectedTx, s.BytesTx)
	}
	if s.BytesRx != expectedRx {
		t.Errorf("expected BytesRx=%d, got %d", expectedRx, s.BytesRx)
	}
}
// TestConsolidateSessions_SameSourcePortPreserved verifies that when every
// connection in a burst uses the identical source host:port, the merged
// session keeps the full address rather than stripping the port.
func TestConsolidateSessions_SameSourcePortPreserved(t *testing.T) {
	now := time.Now()
	conn := func(id string, start, end time.Duration) *AccessSession {
		return &AccessSession{
			SessionID:  id,
			ResourceID: 1,
			SourceAddr: "10.0.0.1:5000",
			DestAddr:   "192.168.1.100:443",
			Protocol:   "tcp",
			StartedAt:  now.Add(start),
			EndedAt:    now.Add(end),
		}
	}
	got := consolidateSessions([]*AccessSession{
		conn("s1", 0, 100*time.Millisecond),
		conn("s2", 200*time.Millisecond, 300*time.Millisecond),
	})
	if len(got) != 1 {
		t.Fatalf("expected 1 session, got %d", len(got))
	}
	if got[0].SourceAddr != "10.0.0.1:5000" {
		t.Errorf("expected source addr with port preserved when all ports are the same, got %q", got[0].SourceAddr)
	}
	if got[0].ConnectionCount != 2 {
		t.Errorf("expected ConnectionCount=2, got %d", got[0].ConnectionCount)
	}
}
// TestConsolidateSessions_GapSplitsSessions verifies that two bursts from the
// same source, separated by a quiet period well beyond the gap threshold, are
// reported as two distinct consolidated sessions.
func TestConsolidateSessions_GapSplitsSessions(t *testing.T) {
	now := time.Now()
	conn := func(id, src string, start, end time.Duration) *AccessSession {
		return &AccessSession{
			SessionID:  id,
			ResourceID: 1,
			SourceAddr: src,
			DestAddr:   "192.168.1.100:443",
			Protocol:   "tcp",
			StartedAt:  now.Add(start),
			EndedAt:    now.Add(end),
		}
	}
	got := consolidateSessions([]*AccessSession{
		// First burst.
		conn("s1", "10.0.0.1:5000", 0, 100*time.Millisecond),
		conn("s2", "10.0.0.1:5001", 200*time.Millisecond, 300*time.Millisecond),
		// Second burst after a 10-second silence.
		conn("s3", "10.0.0.1:5002", 10*time.Second, 10*time.Second+100*time.Millisecond),
		conn("s4", "10.0.0.1:5003", 10*time.Second+200*time.Millisecond, 10*time.Second+300*time.Millisecond),
	})
	if len(got) != 2 {
		t.Fatalf("expected 2 consolidated sessions (gap split), got %d", len(got))
	}
	// Identify the two bursts by their start times.
	var first, second *AccessSession
	for _, s := range got {
		if s.StartedAt.Equal(now) {
			first = s
		} else {
			second = s
		}
	}
	if first == nil || second == nil {
		t.Fatal("could not find both consolidated sessions")
	}
	if first.ConnectionCount != 2 {
		t.Errorf("first burst: expected ConnectionCount=2, got %d", first.ConnectionCount)
	}
	if second.ConnectionCount != 2 {
		t.Errorf("second burst: expected ConnectionCount=2, got %d", second.ConnectionCount)
	}
}
// TestConsolidateSessions_DifferentDestinationsNotMerged verifies that
// connections to different destination ports are never merged, even when the
// source IP and timing would otherwise qualify.
func TestConsolidateSessions_DifferentDestinationsNotMerged(t *testing.T) {
	now := time.Now()
	pair := []*AccessSession{
		{SessionID: "s1", ResourceID: 1, SourceAddr: "10.0.0.1:5000", DestAddr: "192.168.1.100:443", Protocol: "tcp", StartedAt: now, EndedAt: now.Add(100 * time.Millisecond)},
		{SessionID: "s2", ResourceID: 1, SourceAddr: "10.0.0.1:5001", DestAddr: "192.168.1.100:8080", Protocol: "tcp", StartedAt: now.Add(200 * time.Millisecond), EndedAt: now.Add(300 * time.Millisecond)},
	}
	// Destination ports differ (443 vs 8080), so no merge may occur.
	got := consolidateSessions(pair)
	if len(got) != 2 {
		t.Fatalf("expected 2 sessions (different destinations), got %d", len(got))
	}
}
// TestConsolidateSessions_DifferentProtocolsNotMerged verifies that a TCP and
// a UDP connection to the same destination are kept apart.
func TestConsolidateSessions_DifferentProtocolsNotMerged(t *testing.T) {
	now := time.Now()
	pair := []*AccessSession{
		{SessionID: "s1", ResourceID: 1, SourceAddr: "10.0.0.1:5000", DestAddr: "192.168.1.100:443", Protocol: "tcp", StartedAt: now, EndedAt: now.Add(100 * time.Millisecond)},
		{SessionID: "s2", ResourceID: 1, SourceAddr: "10.0.0.1:5001", DestAddr: "192.168.1.100:443", Protocol: "udp", StartedAt: now.Add(200 * time.Millisecond), EndedAt: now.Add(300 * time.Millisecond)},
	}
	got := consolidateSessions(pair)
	if len(got) != 2 {
		t.Fatalf("expected 2 sessions (different protocols), got %d", len(got))
	}
}
// TestConsolidateSessions_DifferentResourceIDsNotMerged verifies that sessions
// belonging to different resources never collapse into one record.
func TestConsolidateSessions_DifferentResourceIDsNotMerged(t *testing.T) {
	now := time.Now()
	pair := []*AccessSession{
		{SessionID: "s1", ResourceID: 1, SourceAddr: "10.0.0.1:5000", DestAddr: "192.168.1.100:443", Protocol: "tcp", StartedAt: now, EndedAt: now.Add(100 * time.Millisecond)},
		{SessionID: "s2", ResourceID: 2, SourceAddr: "10.0.0.1:5001", DestAddr: "192.168.1.100:443", Protocol: "tcp", StartedAt: now.Add(200 * time.Millisecond), EndedAt: now.Add(300 * time.Millisecond)},
	}
	got := consolidateSessions(pair)
	if len(got) != 2 {
		t.Fatalf("expected 2 sessions (different resource IDs), got %d", len(got))
	}
}
// TestConsolidateSessions_DifferentSourceIPsNotMerged verifies that traffic
// from two distinct client IPs is never combined into one session.
func TestConsolidateSessions_DifferentSourceIPsNotMerged(t *testing.T) {
	now := time.Now()
	pair := []*AccessSession{
		{SessionID: "s1", ResourceID: 1, SourceAddr: "10.0.0.1:5000", DestAddr: "192.168.1.100:443", Protocol: "tcp", StartedAt: now, EndedAt: now.Add(100 * time.Millisecond)},
		{SessionID: "s2", ResourceID: 1, SourceAddr: "10.0.0.2:5001", DestAddr: "192.168.1.100:443", Protocol: "tcp", StartedAt: now.Add(200 * time.Millisecond), EndedAt: now.Add(300 * time.Millisecond)},
	}
	got := consolidateSessions(pair)
	if len(got) != 2 {
		t.Fatalf("expected 2 sessions (different source IPs), got %d", len(got))
	}
}
// TestConsolidateSessions_OutOfOrderInput feeds the sessions in
// non-chronological order to confirm that consolidation sorts before merging.
func TestConsolidateSessions_OutOfOrderInput(t *testing.T) {
	now := time.Now()
	conn := func(id, src string, start, end time.Duration, tx int64) *AccessSession {
		return &AccessSession{
			SessionID:  id,
			ResourceID: 1,
			SourceAddr: src,
			DestAddr:   "192.168.1.100:443",
			Protocol:   "tcp",
			StartedAt:  now.Add(start),
			EndedAt:    now.Add(end),
			BytesTx:    tx,
		}
	}
	// Deliberately listed as s3, s1, s2.
	got := consolidateSessions([]*AccessSession{
		conn("s3", "10.0.0.1:5002", 400*time.Millisecond, 500*time.Millisecond, 30),
		conn("s1", "10.0.0.1:5000", 0, 100*time.Millisecond, 10),
		conn("s2", "10.0.0.1:5001", 200*time.Millisecond, 300*time.Millisecond, 20),
	})
	if len(got) != 1 {
		t.Fatalf("expected 1 consolidated session, got %d", len(got))
	}
	merged := got[0]
	if merged.ConnectionCount != 3 {
		t.Errorf("expected ConnectionCount=3, got %d", merged.ConnectionCount)
	}
	if merged.StartedAt != now {
		t.Errorf("expected StartedAt to be earliest time")
	}
	if merged.EndedAt != now.Add(500*time.Millisecond) {
		t.Errorf("expected EndedAt to be latest time")
	}
	if merged.BytesTx != 60 {
		t.Errorf("expected BytesTx=60, got %d", merged.BytesTx)
	}
}
// TestConsolidateSessions_ExactlyAtGapThreshold verifies the boundary case: a
// follow-up connection starting exactly sessionGapThreshold after the previous
// one ended still belongs to the same consolidated session.
func TestConsolidateSessions_ExactlyAtGapThreshold(t *testing.T) {
	now := time.Now()
	firstEnd := 100 * time.Millisecond
	got := consolidateSessions([]*AccessSession{
		{SessionID: "s1", ResourceID: 1, SourceAddr: "10.0.0.1:5000", DestAddr: "192.168.1.100:443", Protocol: "tcp", StartedAt: now, EndedAt: now.Add(firstEnd)},
		{SessionID: "s2", ResourceID: 1, SourceAddr: "10.0.0.1:5001", DestAddr: "192.168.1.100:443", Protocol: "tcp", StartedAt: now.Add(firstEnd + sessionGapThreshold), EndedAt: now.Add(firstEnd + sessionGapThreshold + 50*time.Millisecond)},
	})
	if len(got) != 1 {
		t.Fatalf("expected 1 session (gap exactly at threshold merges), got %d", len(got))
	}
	if got[0].ConnectionCount != 2 {
		t.Errorf("expected ConnectionCount=2, got %d", got[0].ConnectionCount)
	}
}
// TestConsolidateSessions_JustOverGapThreshold verifies the other side of the
// boundary: one millisecond beyond the gap threshold forces a split.
func TestConsolidateSessions_JustOverGapThreshold(t *testing.T) {
	now := time.Now()
	firstEnd := 100 * time.Millisecond
	got := consolidateSessions([]*AccessSession{
		{SessionID: "s1", ResourceID: 1, SourceAddr: "10.0.0.1:5000", DestAddr: "192.168.1.100:443", Protocol: "tcp", StartedAt: now, EndedAt: now.Add(firstEnd)},
		{SessionID: "s2", ResourceID: 1, SourceAddr: "10.0.0.1:5001", DestAddr: "192.168.1.100:443", Protocol: "tcp", StartedAt: now.Add(firstEnd + sessionGapThreshold + 1*time.Millisecond), EndedAt: now.Add(firstEnd + sessionGapThreshold + 50*time.Millisecond)},
	})
	if len(got) != 2 {
		t.Fatalf("expected 2 sessions (gap just over threshold splits), got %d", len(got))
	}
}
// TestConsolidateSessions_UDPSessions verifies that UDP bursts (e.g. DNS-style
// queries from rotating ephemeral ports) consolidate just like TCP, keeping
// the protocol label and summing the traffic counters.
func TestConsolidateSessions_UDPSessions(t *testing.T) {
	now := time.Now()
	query := func(id, src string, start time.Duration, rx int64) *AccessSession {
		return &AccessSession{
			SessionID:  id,
			ResourceID: 5,
			SourceAddr: src,
			DestAddr:   "192.168.1.100:53",
			Protocol:   "udp",
			StartedAt:  now.Add(start),
			EndedAt:    now.Add(start + 50*time.Millisecond),
			BytesTx:    64,
			BytesRx:    rx,
		}
	}
	got := consolidateSessions([]*AccessSession{
		query("u1", "10.0.0.1:6000", 0, 512),
		query("u2", "10.0.0.1:6001", 100*time.Millisecond, 256),
		query("u3", "10.0.0.1:6002", 200*time.Millisecond, 128),
	})
	if len(got) != 1 {
		t.Fatalf("expected 1 consolidated UDP session, got %d", len(got))
	}
	merged := got[0]
	if merged.Protocol != "udp" {
		t.Errorf("expected protocol=udp, got %q", merged.Protocol)
	}
	if merged.ConnectionCount != 3 {
		t.Errorf("expected ConnectionCount=3, got %d", merged.ConnectionCount)
	}
	if merged.SourceAddr != "10.0.0.1" {
		t.Errorf("expected source addr to be IP only, got %q", merged.SourceAddr)
	}
	if merged.BytesTx != 192 {
		t.Errorf("expected BytesTx=192, got %d", merged.BytesTx)
	}
	if merged.BytesRx != 896 {
		t.Errorf("expected BytesRx=896, got %d", merged.BytesRx)
	}
}
// TestConsolidateSessions_MixedGroupsSomeConsolidatedSomeNot verifies that a
// burst and an unrelated single connection can coexist in one input: the burst
// is merged while the lone connection passes through untouched.
func TestConsolidateSessions_MixedGroupsSomeConsolidatedSomeNot(t *testing.T) {
	now := time.Now()
	burstConn := func(id, src string, start, end time.Duration) *AccessSession {
		return &AccessSession{
			SessionID:  id,
			ResourceID: 1,
			SourceAddr: src,
			DestAddr:   "192.168.1.100:443",
			Protocol:   "tcp",
			StartedAt:  now.Add(start),
			EndedAt:    now.Add(end),
		}
	}
	got := consolidateSessions([]*AccessSession{
		// Burst: three connections to :443 from the same IP.
		burstConn("s1", "10.0.0.1:5000", 0, 100*time.Millisecond),
		burstConn("s2", "10.0.0.1:5001", 200*time.Millisecond, 300*time.Millisecond),
		burstConn("s3", "10.0.0.1:5002", 400*time.Millisecond, 500*time.Millisecond),
		// Lone connection from a different IP to a different resource/destination.
		{
			SessionID:  "s4",
			ResourceID: 2,
			SourceAddr: "10.0.0.2:6000",
			DestAddr:   "192.168.1.200:8080",
			Protocol:   "tcp",
			StartedAt:  now.Add(1 * time.Second),
			EndedAt:    now.Add(2 * time.Second),
		},
	})
	if len(got) != 2 {
		t.Fatalf("expected 2 sessions total, got %d", len(got))
	}
	var consolidated, passthrough *AccessSession
	for _, s := range got {
		if s.ConnectionCount > 1 {
			consolidated = s
		} else {
			passthrough = s
		}
	}
	if consolidated == nil {
		t.Fatal("expected a consolidated session")
	}
	if consolidated.ConnectionCount != 3 {
		t.Errorf("consolidated: expected ConnectionCount=3, got %d", consolidated.ConnectionCount)
	}
	if passthrough == nil {
		t.Fatal("expected a passthrough session")
	}
	if passthrough.SessionID != "s4" {
		t.Errorf("passthrough: expected session s4, got %s", passthrough.SessionID)
	}
}
// TestConsolidateSessions_OverlappingConnections verifies that connections
// whose lifetimes overlap (rather than running back-to-back) still merge, with
// the span covering the earliest start through the latest end.
func TestConsolidateSessions_OverlappingConnections(t *testing.T) {
	now := time.Now()
	conn := func(id, src string, start, end time.Duration, tx int64) *AccessSession {
		return &AccessSession{
			SessionID:  id,
			ResourceID: 1,
			SourceAddr: src,
			DestAddr:   "192.168.1.100:443",
			Protocol:   "tcp",
			StartedAt:  now.Add(start),
			EndedAt:    now.Add(end),
			BytesTx:    tx,
		}
	}
	got := consolidateSessions([]*AccessSession{
		conn("s1", "10.0.0.1:5000", 0, 5*time.Second, 100),
		conn("s2", "10.0.0.1:5001", 1*time.Second, 3*time.Second, 200),
		conn("s3", "10.0.0.1:5002", 2*time.Second, 6*time.Second, 300),
	})
	if len(got) != 1 {
		t.Fatalf("expected 1 consolidated session, got %d", len(got))
	}
	merged := got[0]
	if merged.ConnectionCount != 3 {
		t.Errorf("expected ConnectionCount=3, got %d", merged.ConnectionCount)
	}
	if merged.StartedAt != now {
		t.Error("expected StartedAt to be earliest")
	}
	if merged.EndedAt != now.Add(6*time.Second) {
		t.Error("expected EndedAt to be the latest end time")
	}
	if merged.BytesTx != 600 {
		t.Errorf("expected BytesTx=600, got %d", merged.BytesTx)
	}
}
// TestConsolidateSessions_DoesNotMutateOriginals verifies that consolidation
// works on copies: the caller's AccessSession values must keep their original
// addresses and counters after the call.
func TestConsolidateSessions_DoesNotMutateOriginals(t *testing.T) {
	now := time.Now()
	s1 := &AccessSession{
		SessionID:  "s1",
		ResourceID: 1,
		SourceAddr: "10.0.0.1:5000",
		DestAddr:   "192.168.1.100:443",
		Protocol:   "tcp",
		StartedAt:  now,
		EndedAt:    now.Add(100 * time.Millisecond),
		BytesTx:    100,
	}
	s2 := &AccessSession{
		SessionID:  "s2",
		ResourceID: 1,
		SourceAddr: "10.0.0.1:5001",
		DestAddr:   "192.168.1.100:443",
		Protocol:   "tcp",
		StartedAt:  now.Add(200 * time.Millisecond),
		EndedAt:    now.Add(300 * time.Millisecond),
		BytesTx:    200,
	}
	// Snapshot the fields the merge logic rewrites on its working copies.
	wantS1Addr, wantS1Bytes := s1.SourceAddr, s1.BytesTx
	wantS2Addr, wantS2Bytes := s2.SourceAddr, s2.BytesTx
	_ = consolidateSessions([]*AccessSession{s1, s2})
	if s1.SourceAddr != wantS1Addr {
		t.Errorf("s1.SourceAddr was mutated: %q -> %q", wantS1Addr, s1.SourceAddr)
	}
	if s1.BytesTx != wantS1Bytes {
		t.Errorf("s1.BytesTx was mutated: %d -> %d", wantS1Bytes, s1.BytesTx)
	}
	if s2.SourceAddr != wantS2Addr {
		t.Errorf("s2.SourceAddr was mutated: %q -> %q", wantS2Addr, s2.SourceAddr)
	}
	if s2.BytesTx != wantS2Bytes {
		t.Errorf("s2.BytesTx was mutated: %d -> %d", wantS2Bytes, s2.BytesTx)
	}
}
// TestConsolidateSessions_ThreeBurstsWithGaps verifies that three bursts of
// three connections each, separated by 20-second silences, produce exactly
// three consolidated sessions with three connections apiece.
func TestConsolidateSessions_ThreeBurstsWithGaps(t *testing.T) {
	now := time.Now()
	sessions := make([]*AccessSession, 0, 9)
	// Fix: the original fixture built source ports from 'A'+i, producing
	// non-numeric host:port strings such as "10.0.0.1:A". Use digit runes so
	// every SourceAddr is a syntactically valid address; the ports stay unique
	// within each burst, which is what multi-port merging keys on.
	port := 0
	addBurst := func(offset time.Duration) {
		for i := 0; i < 3; i++ {
			sessions = append(sessions, &AccessSession{
				SessionID:  generateSessionID(),
				ResourceID: 1,
				SourceAddr: "10.0.0.1:" + string(rune('0'+port)),
				DestAddr:   "192.168.1.100:443",
				Protocol:   "tcp",
				StartedAt:  now.Add(offset + time.Duration(i*100)*time.Millisecond),
				EndedAt:    now.Add(offset + time.Duration(i*100+50)*time.Millisecond),
			})
			port++
		}
	}
	addBurst(0)                // burst 1 at t=0
	addBurst(20 * time.Second) // burst 2, well past the gap threshold
	addBurst(40 * time.Second) // burst 3
	result := consolidateSessions(sessions)
	if len(result) != 3 {
		t.Fatalf("expected 3 consolidated sessions (3 bursts), got %d", len(result))
	}
	for _, s := range result {
		if s.ConnectionCount != 3 {
			t.Errorf("expected each burst to have ConnectionCount=3, got %d (started=%v)", s.ConnectionCount, s.StartedAt)
		}
	}
}
// TestFinalizeMergedSourceAddr covers both outcomes of source-address
// finalization: a single observed port keeps host:port, while multiple ports
// reduce the address to the bare IP.
func TestFinalizeMergedSourceAddr(t *testing.T) {
	single := &AccessSession{SourceAddr: "10.0.0.1:5000"}
	finalizeMergedSourceAddr(single, "10.0.0.1", map[string]struct{}{"10.0.0.1:5000": {}})
	if single.SourceAddr != "10.0.0.1:5000" {
		t.Errorf("single port: expected addr preserved, got %q", single.SourceAddr)
	}
	multi := &AccessSession{SourceAddr: "10.0.0.1:5000"}
	finalizeMergedSourceAddr(multi, "10.0.0.1", map[string]struct{}{
		"10.0.0.1:5000": {},
		"10.0.0.1:5001": {},
	})
	if multi.SourceAddr != "10.0.0.1" {
		t.Errorf("multiple ports: expected IP only, got %q", multi.SourceAddr)
	}
}
// TestCloneSession verifies that cloneSession produces an independent copy:
// same field values, distinct pointer, and no aliasing back to the original.
func TestCloneSession(t *testing.T) {
	src := &AccessSession{
		SessionID:  "test",
		ResourceID: 42,
		SourceAddr: "1.2.3.4:100",
		DestAddr:   "5.6.7.8:443",
		Protocol:   "tcp",
		BytesTx:    999,
	}
	dup := cloneSession(src)
	if dup == src {
		t.Error("clone should be a different pointer")
	}
	if dup.SessionID != src.SessionID {
		t.Error("clone should have same SessionID")
	}
	// Writes to the copy must not leak back into the source.
	dup.BytesTx = 0
	dup.SourceAddr = "changed"
	if src.BytesTx != 999 {
		t.Error("mutating clone affected original BytesTx")
	}
	if src.SourceAddr != "1.2.3.4:100" {
		t.Error("mutating clone affected original SourceAddr")
	}
}

View File

@@ -137,14 +137,31 @@ func (h *TCPHandler) InstallTCPHandler() error {
// handleTCPConn handles a TCP connection by proxying it to the actual target
func (h *TCPHandler) handleTCPConn(netstackConn *gonet.TCPConn, id stack.TransportEndpointID) {
defer netstackConn.Close()
// Extract source and target address from the connection ID
// Extract source and target address from the connection ID first so they
// are available for HTTP routing before any defer is set up.
srcIP := id.RemoteAddress.String()
srcPort := id.RemotePort
dstIP := id.LocalAddress.String()
dstPort := id.LocalPort
// For HTTP/HTTPS ports, look up the matching subnet rule. If the rule has
// Protocol configured, hand the connection off to the HTTP handler which
// takes full ownership of the lifecycle (the defer close must not be
// installed before this point).
if (dstPort == 80 || dstPort == 443) && h.proxyHandler != nil && h.proxyHandler.httpHandler != nil {
srcAddr, _ := netip.ParseAddr(srcIP)
dstAddr, _ := netip.ParseAddr(dstIP)
rule := h.proxyHandler.subnetLookup.Match(srcAddr, dstAddr, dstPort, tcp.ProtocolNumber)
if rule != nil && rule.Protocol != "" {
logger.Info("TCP Forwarder: Routing %s:%d -> %s:%d to HTTP handler (%s)",
srcIP, srcPort, dstIP, dstPort, rule.Protocol)
h.proxyHandler.httpHandler.HandleConn(netstackConn, rule)
return
}
}
defer netstackConn.Close()
logger.Info("TCP Forwarder: Handling connection %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
// Check if there's a destination rewrite for this connection (e.g., localhost targets)
@@ -158,6 +175,18 @@ func (h *TCPHandler) handleTCPConn(netstackConn *gonet.TCPConn, id stack.Transpo
targetAddr := fmt.Sprintf("%s:%d", actualDstIP, dstPort)
// Look up resource ID and start access session if applicable
var accessSessionID string
if h.proxyHandler != nil {
resourceId := h.proxyHandler.LookupResourceId(srcIP, dstIP, dstPort, uint8(tcp.ProtocolNumber))
if resourceId != 0 {
if al := h.proxyHandler.GetAccessLogger(); al != nil {
srcAddr := fmt.Sprintf("%s:%d", srcIP, srcPort)
accessSessionID = al.StartTCPSession(resourceId, srcAddr, targetAddr)
}
}
}
// Create context with timeout for connection establishment
ctx, cancel := context.WithTimeout(context.Background(), tcpConnectTimeout)
defer cancel()
@@ -167,11 +196,26 @@ func (h *TCPHandler) handleTCPConn(netstackConn *gonet.TCPConn, id stack.Transpo
targetConn, err := d.DialContext(ctx, "tcp", targetAddr)
if err != nil {
logger.Info("TCP Forwarder: Failed to connect to %s: %v", targetAddr, err)
// End access session on connection failure
if accessSessionID != "" {
if al := h.proxyHandler.GetAccessLogger(); al != nil {
al.EndTCPSession(accessSessionID)
}
}
// Connection failed, netstack will handle RST
return
}
defer targetConn.Close()
// End access session when connection closes
if accessSessionID != "" {
defer func() {
if al := h.proxyHandler.GetAccessLogger(); al != nil {
al.EndTCPSession(accessSessionID)
}
}()
}
logger.Info("TCP Forwarder: Successfully connected to %s, starting bidirectional copy", targetAddr)
// Bidirectional copy between netstack and target
@@ -280,6 +324,27 @@ func (h *UDPHandler) handleUDPConn(netstackConn *gonet.UDPConn, id stack.Transpo
targetAddr := fmt.Sprintf("%s:%d", actualDstIP, dstPort)
// Look up resource ID and start access session if applicable
var accessSessionID string
if h.proxyHandler != nil {
resourceId := h.proxyHandler.LookupResourceId(srcIP, dstIP, dstPort, uint8(udp.ProtocolNumber))
if resourceId != 0 {
if al := h.proxyHandler.GetAccessLogger(); al != nil {
srcAddr := fmt.Sprintf("%s:%d", srcIP, srcPort)
accessSessionID = al.TrackUDPSession(resourceId, srcAddr, targetAddr)
}
}
}
// End access session when UDP handler returns (timeout or error)
if accessSessionID != "" {
defer func() {
if al := h.proxyHandler.GetAccessLogger(); al != nil {
al.EndUDPSession(accessSessionID)
}
}()
}
// Resolve target address
remoteUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr)
if err != nil {

318
netstack2/http_handler.go Normal file
View File

@@ -0,0 +1,318 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package netstack2
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/http/httputil"
"net/url"
"sync"
"github.com/fosrl/newt/logger"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
// ---------------------------------------------------------------------------
// HTTPTarget
// ---------------------------------------------------------------------------
// HTTPTarget describes a single downstream HTTP or HTTPS service that the
// proxy should forward requests to.
type HTTPTarget struct {
	DestAddr string `json:"destAddr"` // IP address or hostname of the downstream service
	DestPort uint16 `json:"destPort"` // TCP port of the downstream service
	UseHTTPS bool   `json:"useHttps"` // When true the outbound leg uses HTTPS (certificate verification is skipped — see getProxy)
}
// ---------------------------------------------------------------------------
// HTTPHandler
// ---------------------------------------------------------------------------
// HTTPHandler intercepts TCP connections from the netstack forwarder on ports
// 80 and 443 and services them as HTTP or HTTPS, reverse-proxying each request
// to downstream targets specified by the matching SubnetRule.
//
// HTTP and raw TCP are fully separate: a connection is only routed here when
// its SubnetRule has Protocol set ("http" or "https"). All other connections
// on those ports fall through to the normal raw-TCP path.
//
// Incoming TLS termination (Protocol == "https") is performed per-connection
// using the certificate and key stored in the rule, so different subnet rules
// can present different certificates without sharing any state.
//
// Outbound connections to downstream targets honour HTTPTarget.UseHTTPS
// independently of the incoming protocol.
type HTTPHandler struct {
	stack        *stack.Stack   // netstack instance (stored by NewHTTPHandler; not read elsewhere in this file — presumably reserved for future use)
	proxyHandler *ProxyHandler  // owning ProxyHandler that routes connections here
	listener     *chanListener  // channel-backed listener feeding the internal server
	server       *http.Server   // always speaks plain HTTP; TLS is unwrapped in HandleConn
	// proxyCache holds pre-built *httputil.ReverseProxy values keyed by the
	// canonical target URL string ("scheme://host:port"). Building a proxy is
	// cheap, but reusing one preserves the underlying http.Transport connection
	// pool, which matters for throughput.
	proxyCache sync.Map // map[string]*httputil.ReverseProxy
	// tlsCache holds pre-parsed *tls.Config values keyed by the concatenation
	// of the PEM certificate and key. Parsing a keypair is relatively expensive
	// and the same cert is likely reused across many connections.
	tlsCache sync.Map // map[string]*tls.Config
}
// ---------------------------------------------------------------------------
// chanListener net.Listener backed by a channel
// ---------------------------------------------------------------------------
// chanListener implements net.Listener by receiving net.Conn values over a
// buffered channel. This lets the netstack TCP forwarder hand off connections
// directly to a running http.Server without any real OS socket.
type chanListener struct {
	connCh chan net.Conn // buffered queue of handed-off connections (capacity 128)
	closed chan struct{} // closed exactly once to signal shutdown to Accept and send
	once   sync.Once     // guards the close of `closed`
}
// newChanListener builds a ready-to-use chanListener with a buffered
// connection queue and an open shutdown-signal channel.
func newChanListener() *chanListener {
	l := &chanListener{}
	l.connCh = make(chan net.Conn, 128)
	l.closed = make(chan struct{})
	return l
}
// Accept blocks until a connection is available or the listener is closed.
// After Close it returns net.ErrClosed, matching net.Listener semantics.
func (l *chanListener) Accept() (net.Conn, error) {
	select {
	case <-l.closed:
		return nil, net.ErrClosed
	case conn, ok := <-l.connCh:
		if ok {
			return conn, nil
		}
		return nil, net.ErrClosed
	}
}
// Close shuts down the listener; subsequent Accept calls return net.ErrClosed.
// Connections that were delivered via send but never picked up by Accept are
// drained and closed here so they are not leaked when the server stops.
func (l *chanListener) Close() error {
	l.once.Do(func() {
		close(l.closed)
		// Best-effort drain of the buffered queue. A concurrent send racing
		// this drain could still enqueue one more connection (select may pick
		// either ready case), so this cleans up the common path only.
		for {
			select {
			case conn := <-l.connCh:
				conn.Close()
			default:
				return
			}
		}
	})
	return nil
}
// Addr returns a placeholder address (the listener has no real OS socket).
func (l *chanListener) Addr() net.Addr {
	var placeholder net.TCPAddr
	return &placeholder
}
// send delivers conn to the listener. Returns false if the listener is already
// closed, in which case the caller is responsible for closing conn.
func (l *chanListener) send(conn net.Conn) bool {
	var delivered bool
	select {
	case l.connCh <- conn:
		delivered = true
	case <-l.closed:
		// Shutdown won the race; refuse delivery.
	}
	return delivered
}
// ---------------------------------------------------------------------------
// httpConnCtx conn wrapper that carries a SubnetRule through the listener
// ---------------------------------------------------------------------------
// httpConnCtx wraps a net.Conn so the matching SubnetRule can be passed
// through the chanListener into the http.Server's ConnContext callback,
// making it available to request handlers via the request context.
type httpConnCtx struct {
	net.Conn              // the (possibly TLS-wrapped) client connection
	rule     *SubnetRule // subnet rule that routed this connection to the HTTP handler
}
// connCtxKey is the unexported context key used to store a *SubnetRule on the
// per-connection context created by http.Server.ConnContext.
// (Zero-size struct: a collision-proof key type per the context package's convention.)
type connCtxKey struct{}
// ---------------------------------------------------------------------------
// Constructor and lifecycle
// ---------------------------------------------------------------------------
// NewHTTPHandler creates an HTTPHandler attached to the given stack and
// ProxyHandler. Call Start to begin serving connections.
func NewHTTPHandler(s *stack.Stack, ph *ProxyHandler) *HTTPHandler {
	h := &HTTPHandler{}
	h.stack = s
	h.proxyHandler = ph
	return h
}
// Start launches the internal http.Server that services connections delivered
// via HandleConn. The server runs for the lifetime of the HTTPHandler; call
// Close to stop it.
func (h *HTTPHandler) Start() error {
	h.listener = newChanListener()
	// Per-connection hook: recover the SubnetRule smuggled through the
	// listener by httpConnCtx and expose it on the request context so
	// handleRequest can retrieve it without any global state.
	attachRule := func(ctx context.Context, c net.Conn) context.Context {
		cc, ok := c.(*httpConnCtx)
		if !ok {
			return ctx
		}
		return context.WithValue(ctx, connCtxKey{}, cc.rule)
	}
	h.server = &http.Server{
		Handler:     http.HandlerFunc(h.handleRequest),
		ConnContext: attachRule,
	}
	go func() {
		err := h.server.Serve(h.listener)
		if err != nil && err != http.ErrServerClosed {
			logger.Error("HTTP handler: server exited unexpectedly: %v", err)
		}
	}()
	logger.Info("HTTP handler: ready — routing determined per SubnetRule on ports 80/443")
	return nil
}
// HandleConn accepts a TCP connection from the netstack forwarder together
// with the SubnetRule that matched it. Ownership of conn transfers here: the
// caller must NOT close it after this call.
//
// For rule.Protocol == "https" the connection is wrapped with tls.Server
// using the certificate and key stored on the rule; the handshake happens
// lazily on the first Read issued by the HTTP server. The internal server
// itself always speaks plain HTTP — TLS is fully unwrapped at this layer.
func (h *HTTPHandler) HandleConn(conn net.Conn, rule *SubnetRule) {
	effective := net.Conn(conn)
	if rule.Protocol == "https" {
		cfg, err := h.getTLSConfig(rule)
		if err != nil {
			logger.Error("HTTP handler: cannot build TLS config for connection from %s: %v",
				conn.RemoteAddr(), err)
			conn.Close()
			return
		}
		// Handshake is deferred until the server's first Read.
		effective = tls.Server(conn, cfg)
	}
	delivered := h.listener.send(&httpConnCtx{Conn: effective, rule: rule})
	if !delivered {
		// Listener already shut down; close the orphaned connection ourselves.
		effective.Close()
	}
}
// Close gracefully shuts down the HTTP server and the underlying channel
// listener, causing the goroutine started in Start to exit. A server Close
// error is returned immediately, before the listener is touched.
func (h *HTTPHandler) Close() error {
	if srv := h.server; srv != nil {
		if err := srv.Close(); err != nil {
			return err
		}
	}
	if l := h.listener; l != nil {
		l.Close()
	}
	return nil
}
// ---------------------------------------------------------------------------
// Internal helpers
// ---------------------------------------------------------------------------
// getTLSConfig returns a *tls.Config for the cert/key pair in rule, using a
// cache to avoid re-parsing the same keypair on every connection. The cache
// key concatenates the PEM cert and key, so different rules sharing the same
// material hit the same entry.
func (h *HTTPHandler) getTLSConfig(rule *SubnetRule) (*tls.Config, error) {
	key := rule.TLSCert + "|" + rule.TLSKey
	if cached, ok := h.tlsCache.Load(key); ok {
		return cached.(*tls.Config), nil
	}
	pair, err := tls.X509KeyPair([]byte(rule.TLSCert), []byte(rule.TLSKey))
	if err != nil {
		return nil, fmt.Errorf("failed to parse TLS keypair: %w", err)
	}
	fresh := &tls.Config{Certificates: []tls.Certificate{pair}}
	// Racing goroutines may both build a config; LoadOrStore keeps a single
	// winner and the loser's work is simply discarded.
	winner, _ := h.tlsCache.LoadOrStore(key, fresh)
	return winner.(*tls.Config), nil
}
// getProxy returns a cached *httputil.ReverseProxy for the given target,
// creating one on first use. Reusing the proxy preserves its http.Transport
// connection pool, avoiding repeated TCP/TLS handshakes to the downstream.
func (h *HTTPHandler) getProxy(target HTTPTarget) *httputil.ReverseProxy {
	scheme := "http"
	if target.UseHTTPS {
		scheme = "https"
	}
	hostPort := fmt.Sprintf("%s:%d", target.DestAddr, target.DestPort)
	cacheKey := scheme + "://" + hostPort
	if cached, ok := h.proxyCache.Load(cacheKey); ok {
		return cached.(*httputil.ReverseProxy)
	}
	rp := httputil.NewSingleHostReverseProxy(&url.URL{Scheme: scheme, Host: hostPort})
	if target.UseHTTPS {
		// Allow self-signed certificates on downstream HTTPS targets.
		rp.Transport = &http.Transport{
			TLSClientConfig: &tls.Config{
				InsecureSkipVerify: true, //nolint:gosec // downstream self-signed certs are a supported configuration
			},
		}
	}
	rp.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
		logger.Error("HTTP handler: upstream error (%s %s -> %s): %v",
			r.Method, r.URL.RequestURI(), cacheKey, err)
		http.Error(w, "Bad Gateway", http.StatusBadGateway)
	}
	winner, _ := h.proxyCache.LoadOrStore(cacheKey, rp)
	return winner.(*httputil.ReverseProxy)
}
// handleRequest is the http.Handler entry point. It retrieves the SubnetRule
// attached to the connection by ConnContext, selects the first configured
// downstream target, and forwards the request via the cached ReverseProxy.
//
// TODO: add host/path-based routing across multiple HTTPTargets once the
// configuration model evolves beyond a single target per rule.
func (h *HTTPHandler) handleRequest(w http.ResponseWriter, r *http.Request) {
	rule, _ := r.Context().Value(connCtxKey{}).(*SubnetRule)
	if rule == nil || len(rule.HTTPTargets) == 0 {
		logger.Error("HTTP handler: no downstream targets for request %s %s", r.Method, r.URL.RequestURI())
		http.Error(w, "no targets configured", http.StatusBadGateway)
		return
	}
	target := rule.HTTPTargets[0]
	proxy := h.getProxy(target)
	scheme := "http"
	if target.UseHTTPS {
		scheme = "https"
	}
	logger.Info("HTTP handler: %s %s -> %s://%s:%d",
		r.Method, r.URL.RequestURI(), scheme, target.DestAddr, target.DestPort)
	proxy.ServeHTTP(w, r)
}

View File

@@ -22,6 +22,12 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
const (
// udpAccessSessionTimeout is how long a UDP access session stays alive without traffic
// before being considered ended by the access logger
udpAccessSessionTimeout = 120 * time.Second
)
// PortRange represents an allowed range of ports (inclusive) with optional protocol filtering
// Protocol can be "tcp", "udp", or "" (empty string means both protocols)
type PortRange struct {
@@ -46,6 +52,15 @@ type SubnetRule struct {
DisableIcmp bool // If true, ICMP traffic is blocked for this subnet
RewriteTo string // Optional rewrite address for DNAT - can be IP/CIDR or domain name
PortRanges []PortRange // empty slice means all ports allowed
ResourceId int // Optional resource ID from the server for access logging
// HTTP proxy configuration (optional).
// When Protocol is non-empty the TCP connection is handled by HTTPHandler
// instead of the raw TCP forwarder.
Protocol string // "", "http", or "https" — controls the incoming (client-facing) protocol
HTTPTargets []HTTPTarget // downstream services to proxy requests to
TLSCert string // PEM-encoded certificate for incoming HTTPS termination
TLSKey string // PEM-encoded private key for incoming HTTPS termination
}
// GetAllRules returns a copy of all subnet rules
@@ -107,14 +122,17 @@ type ProxyHandler struct {
tcpHandler *TCPHandler
udpHandler *UDPHandler
icmpHandler *ICMPHandler
httpHandler *HTTPHandler
subnetLookup *SubnetLookup
natTable map[connKey]*natState
reverseNatTable map[reverseConnKey]*natState // Reverse lookup map for O(1) reply packet NAT
destRewriteTable map[destKey]netip.Addr // Maps original dest to rewritten dest for handler lookups
resourceTable map[destKey]int // Maps connection key to resource ID for access logging
natMu sync.RWMutex
enabled bool
icmpReplies chan []byte // Channel for ICMP reply packets to be sent back through the tunnel
notifiable channel.Notification // Notification handler for triggering reads
accessLogger *AccessLogger // Access logger for tracking sessions
}
// ProxyHandlerOptions configures the proxy handler
@@ -137,7 +155,9 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
natTable: make(map[connKey]*natState),
reverseNatTable: make(map[reverseConnKey]*natState),
destRewriteTable: make(map[destKey]netip.Addr),
resourceTable: make(map[destKey]int),
icmpReplies: make(chan []byte, 256), // Buffer for ICMP reply packets
accessLogger: NewAccessLogger(udpAccessSessionTimeout),
proxyEp: channel.New(1024, uint32(options.MTU), ""),
proxyStack: stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
@@ -153,12 +173,21 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
}),
}
// Initialize TCP handler if enabled
// Initialize TCP handler if enabled. The HTTP handler piggybacks on the
// TCP forwarder — TCPHandler.handleTCPConn checks the subnet rule for
// ports 80/443 and routes matching connections to the HTTP handler, so
// the HTTP handler is always initialised alongside TCP.
if options.EnableTCP {
handler.tcpHandler = NewTCPHandler(handler.proxyStack, handler)
if err := handler.tcpHandler.InstallTCPHandler(); err != nil {
return nil, fmt.Errorf("failed to install TCP handler: %v", err)
}
handler.httpHandler = NewHTTPHandler(handler.proxyStack, handler)
if err := handler.httpHandler.Start(); err != nil {
return nil, fmt.Errorf("failed to start HTTP handler: %v", err)
}
logger.Debug("ProxyHandler: HTTP handler enabled")
}
// Initialize UDP handler if enabled
@@ -197,16 +226,14 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
return handler, nil
}
// AddSubnetRule adds a subnet with optional port restrictions to the proxy handler
// sourcePrefix: The IP prefix of the peer sending the data
// destPrefix: The IP prefix of the destination
// rewriteTo: Optional address to rewrite destination to - can be IP/CIDR or domain name
// If portRanges is nil or empty, all ports are allowed for this subnet
func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) {
// AddSubnetRule adds a subnet rule to the proxy handler.
// HTTP proxy behaviour is configured via rule.Protocol, rule.HTTPTargets,
// rule.TLSCert, and rule.TLSKey; leave Protocol empty for raw TCP/UDP.
func (p *ProxyHandler) AddSubnetRule(rule SubnetRule) {
if p == nil || !p.enabled {
return
}
p.subnetLookup.AddSubnet(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp)
p.subnetLookup.AddSubnet(rule)
}
// RemoveSubnetRule removes a subnet from the proxy handler
@@ -225,6 +252,43 @@ func (p *ProxyHandler) GetAllRules() []SubnetRule {
return p.subnetLookup.GetAllRules()
}
// LookupResourceId looks up the resource ID recorded for a connection tuple
// (source IP, destination IP, destination port, protocol number).
// Returns 0 when the handler is nil/disabled or when no resource ID is
// associated with this connection.
func (p *ProxyHandler) LookupResourceId(srcIP, dstIP string, dstPort uint16, proto uint8) int {
	if p == nil || !p.enabled {
		return 0
	}
	p.natMu.RLock()
	defer p.natMu.RUnlock()
	// Missing entries yield the map zero value, which is the documented 0.
	return p.resourceTable[destKey{
		srcIP:   srcIP,
		dstIP:   dstIP,
		dstPort: dstPort,
		proto:   proto,
	}]
}
// GetAccessLogger returns the access logger used for session tracking.
// A nil receiver yields a nil logger, so callers may chain without a nil check.
func (p *ProxyHandler) GetAccessLogger() *AccessLogger {
	if p != nil {
		return p.accessLogger
	}
	return nil
}
// SetAccessLogSender configures the function used to send compressed access log
// batches to the server. This should be called once the websocket client is available.
// Silently ignored when the handler is nil, disabled, or has no access logger.
func (p *ProxyHandler) SetAccessLogSender(fn SendFunc) {
	if p != nil && p.enabled && p.accessLogger != nil {
		p.accessLogger.SetSendFunc(fn)
	}
}
// LookupDestinationRewrite looks up the rewritten destination for a connection
// This is used by TCP/UDP handlers to find the actual target address
func (p *ProxyHandler) LookupDestinationRewrite(srcIP, dstIP string, dstPort uint16, proto uint8) (netip.Addr, bool) {
@@ -387,8 +451,22 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
// Check if the source IP, destination IP, port, and protocol match any subnet rule
matchedRule := p.subnetLookup.Match(srcAddr, dstAddr, dstPort, protocol)
if matchedRule != nil {
logger.Debug("HandleIncomingPacket: Matched rule for %s -> %s (proto=%d, port=%d)",
srcAddr, dstAddr, protocol, dstPort)
logger.Debug("HandleIncomingPacket: Matched rule for %s -> %s (proto=%d, port=%d, resourceId=%d)",
srcAddr, dstAddr, protocol, dstPort, matchedRule.ResourceId)
// Store resource ID for connections without DNAT as well
if matchedRule.ResourceId != 0 && matchedRule.RewriteTo == "" {
dKey := destKey{
srcIP: srcAddr.String(),
dstIP: dstAddr.String(),
dstPort: dstPort,
proto: uint8(protocol),
}
p.natMu.Lock()
p.resourceTable[dKey] = matchedRule.ResourceId
p.natMu.Unlock()
}
// Check if we need to perform DNAT
if matchedRule.RewriteTo != "" {
// Create connection tracking key using original destination
@@ -420,6 +498,13 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
proto: uint8(protocol),
}
// Store resource ID for access logging if present
if matchedRule.ResourceId != 0 {
p.natMu.Lock()
p.resourceTable[dKey] = matchedRule.ResourceId
p.natMu.Unlock()
}
// Check if we already have a NAT entry for this connection
p.natMu.RLock()
existingEntry, exists := p.natTable[key]
@@ -720,6 +805,16 @@ func (p *ProxyHandler) Close() error {
return nil
}
// Shut down access logger
if p.accessLogger != nil {
p.accessLogger.Close()
}
// Shut down HTTP handler
if p.httpHandler != nil {
p.httpHandler.Close()
}
// Close ICMP replies channel
if p.icmpReplies != nil {
close(p.icmpReplies)

View File

@@ -44,23 +44,18 @@ func prefixEqual(a, b netip.Prefix) bool {
return a.Masked() == b.Masked()
}
// AddSubnet adds a subnet rule with source and destination prefixes and optional port restrictions
// If portRanges is nil or empty, all ports are allowed for this subnet
// rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com")
func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) {
// AddSubnet adds a subnet rule to the lookup table.
// If rule.PortRanges is nil or empty, all ports are allowed.
// rule.RewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com").
// HTTP proxy behaviour is driven by rule.Protocol, rule.HTTPTargets, rule.TLSCert, and rule.TLSKey.
func (sl *SubnetLookup) AddSubnet(rule SubnetRule) {
sl.mu.Lock()
defer sl.mu.Unlock()
rule := &SubnetRule{
SourcePrefix: sourcePrefix,
DestPrefix: destPrefix,
DisableIcmp: disableIcmp,
RewriteTo: rewriteTo,
PortRanges: portRanges,
}
rulePtr := &rule
// Canonicalize source prefix to handle host bits correctly
canonicalSourcePrefix := sourcePrefix.Masked()
canonicalSourcePrefix := rule.SourcePrefix.Masked()
// Get or create destination trie for this source prefix
destTriePtr, exists := sl.sourceTrie.Get(canonicalSourcePrefix)
@@ -75,12 +70,12 @@ func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewrite
// Canonicalize destination prefix to handle host bits correctly
// BART masks prefixes internally, so we need to match that behavior in our bookkeeping
canonicalDestPrefix := destPrefix.Masked()
canonicalDestPrefix := rule.DestPrefix.Masked()
// Add rule to destination trie
// Original behavior: overwrite if same (sourcePrefix, destPrefix) exists
// Store as single-element slice to match original overwrite behavior
destTriePtr.trie.Insert(canonicalDestPrefix, []*SubnetRule{rule})
destTriePtr.trie.Insert(canonicalDestPrefix, []*SubnetRule{rulePtr})
// Update destTriePtr.rules - remove old rule with same canonical prefix if exists, then add new one
// Use canonical comparison to handle cases like 10.0.0.5/24 vs 10.0.0.0/24
@@ -90,7 +85,7 @@ func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewrite
newRules = append(newRules, r)
}
}
newRules = append(newRules, rule)
newRules = append(newRules, rulePtr)
destTriePtr.rules = newRules
}

View File

@@ -351,13 +351,13 @@ func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
return net.DialUDP(laddr, nil)
}
// AddProxySubnetRule adds a subnet rule to the proxy handler
// If portRanges is nil or empty, all ports are allowed for this subnet
// rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com")
func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) {
// AddProxySubnetRule adds a subnet rule to the proxy handler.
// HTTP proxy behaviour is configured via rule.Protocol, rule.HTTPTargets,
// rule.TLSCert, and rule.TLSKey; leave Protocol empty for raw TCP/UDP.
func (net *Net) AddProxySubnetRule(rule SubnetRule) {
tun := (*netTun)(net)
if tun.proxyHandler != nil {
tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp)
tun.proxyHandler.AddSubnetRule(rule)
}
}
@@ -385,6 +385,15 @@ func (net *Net) GetProxyHandler() *ProxyHandler {
return tun.proxyHandler
}
// SetAccessLogSender configures the function used to send compressed access log
// batches to the server. This should be called once the websocket client is available.
// A tunnel without a proxy handler ignores the call.
func (net *Net) SetAccessLogSender(fn SendFunc) {
	if tun := (*netTun)(net); tun.proxyHandler != nil {
		tun.proxyHandler.SetAccessLogSender(fn)
	}
}
type PingConn struct {
laddr PingAddr
raddr PingAddr

View File

@@ -42,6 +42,7 @@ type Client struct {
onTokenUpdate func(token string)
writeMux sync.Mutex
clientType string // Type of client (e.g., "newt", "olm")
configFilePath string // Optional override for the config file path
tlsConfig TLSConfig
metricsCtxMu sync.RWMutex
metricsCtx context.Context
@@ -52,6 +53,7 @@ type Client struct {
processingMessage bool // Flag to track if a message is currently being processed
processingMux sync.RWMutex // Protects processingMessage
processingWg sync.WaitGroup // WaitGroup to wait for message processing to complete
justProvisioned bool // Set to true when provisionIfNeeded exchanges a key for permanent credentials
}
type ClientOption func(*Client)
@@ -77,6 +79,12 @@ func WithBaseURL(url string) ClientOption {
}
// WithConfigFile overrides the default path used to load and save the
// client's config file (otherwise derived from CONFIG_FILE or the OS
// config directory).
func WithConfigFile(path string) ClientOption {
	return func(c *Client) {
		c.configFilePath = path
	}
}
func WithTLSConfig(config TLSConfig) ClientOption {
return func(c *Client) {
c.tlsConfig = config
@@ -95,6 +103,16 @@ func (c *Client) OnTokenUpdate(callback func(token string)) {
c.onTokenUpdate = callback
}
// WasJustProvisioned reports whether the client exchanged a provisioning key
// for permanent credentials during the most recent connection attempt. It
// consumes the flag: subsequent calls return false until provisioning occurs
// again (which, in practice, never happens once credentials are persisted).
//
// NOTE(review): the flag is read and cleared without holding a lock — confirm
// this is only ever called from the goroutine that drives connection attempts.
func (c *Client) WasJustProvisioned() bool {
	v := c.justProvisioned
	c.justProvisioned = false
	return v
}
func (c *Client) metricsContext() context.Context {
c.metricsCtxMu.RLock()
defer c.metricsCtxMu.RUnlock()
@@ -481,6 +499,11 @@ func (c *Client) connectWithRetry() {
func (c *Client) establishConnection() error {
ctx := context.Background()
// Exchange provisioning key for permanent credentials if needed.
if err := c.provisionIfNeeded(); err != nil {
return fmt.Errorf("failed to provision newt credentials: %w", err)
}
// Get token for authentication
token, err := c.getToken()
if err != nil {

View File

@@ -1,16 +1,29 @@
package websocket
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/url"
"os"
"path/filepath"
"regexp"
"runtime"
"strings"
"time"
"github.com/fosrl/newt/logger"
)
func getConfigPath(clientType string) string {
func getConfigPath(clientType string, overridePath string) string {
if overridePath != "" {
return overridePath
}
configFile := os.Getenv("CONFIG_FILE")
if configFile == "" {
var configDir string
@@ -36,7 +49,7 @@ func getConfigPath(clientType string) string {
func (c *Client) loadConfig() error {
originalConfig := *c.config // Store original config to detect changes
configPath := getConfigPath(c.clientType)
configPath := getConfigPath(c.clientType, c.configFilePath)
if c.config.ID != "" && c.config.Secret != "" && c.config.Endpoint != "" {
logger.Debug("Config already provided, skipping loading from file")
@@ -58,6 +71,11 @@ func (c *Client) loadConfig() error {
}
return err
}
if len(bytes.TrimSpace(data)) == 0 {
logger.Info("Config file at %s is empty, will initialize it with provided values", configPath)
c.configNeedsSave = true
return nil
}
var config Config
if err := json.Unmarshal(data, &config); err != nil {
@@ -83,6 +101,14 @@ func (c *Client) loadConfig() error {
c.config.Endpoint = config.Endpoint
c.baseURL = config.Endpoint
}
// Always load the provisioning key from the file if not already set
if c.config.ProvisioningKey == "" {
c.config.ProvisioningKey = config.ProvisioningKey
}
// Always load the name from the file if not already set
if c.config.Name == "" {
c.config.Name = config.Name
}
// Check if CLI args provided values that override file values
if (!fileHadID && originalConfig.ID != "") ||
@@ -105,7 +131,7 @@ func (c *Client) saveConfig() error {
return nil
}
configPath := getConfigPath(c.clientType)
configPath := getConfigPath(c.clientType, c.configFilePath)
data, err := json.MarshalIndent(c.config, "", " ")
if err != nil {
return err
@@ -118,3 +144,139 @@ func (c *Client) saveConfig() error {
}
return err
}
// envTokenRe matches {{...}} interpolation tokens. Compiled once at package
// scope so interpolateString does not pay a regexp compile on every call.
var envTokenRe = regexp.MustCompile(`\{\{([^}]+)\}\}`)

// interpolateString replaces {{env.VAR}} tokens in s with the corresponding
// environment variable values. Tokens that do not match a supported scheme are
// left unchanged, mirroring the blueprint interpolation logic. An env token
// whose variable is unset resolves to the empty string (os.Getenv semantics).
func interpolateString(s string) string {
	return envTokenRe.ReplaceAllStringFunc(s, func(match string) string {
		// Strip the {{ }} delimiters and any surrounding whitespace.
		inner := strings.TrimSpace(match[2 : len(match)-2])
		if strings.HasPrefix(inner, "env.") {
			return os.Getenv(strings.TrimPrefix(inner, "env."))
		}
		// Unknown scheme: leave the token untouched.
		return match
	})
}
// provisionIfNeeded checks whether a provisioning key is present and, if so,
// exchanges it for a newt ID and secret by calling the registration endpoint.
// On success the config is updated in-place and flagged for saving so that
// subsequent runs use the permanent credentials directly.
//
// Returns nil both when no provisioning is required and on success; any
// network, TLS, or server-side failure is returned wrapped with context.
func (c *Client) provisionIfNeeded() error {
	// No key: nothing to do.
	if c.config.ProvisioningKey == "" {
		return nil
	}
	// If we already have both credentials there is nothing to provision.
	if c.config.ID != "" && c.config.Secret != "" {
		logger.Debug("Credentials already present, skipping provisioning")
		return nil
	}
	logger.Info("Provisioning key found exchanging for newt credentials...")
	baseURL, err := url.Parse(c.baseURL)
	if err != nil {
		return fmt.Errorf("failed to parse base URL for provisioning: %w", err)
	}
	// Normalize: drop trailing slashes so path concatenation below is clean.
	baseEndpoint := strings.TrimRight(baseURL.String(), "/")
	// Interpolate any {{env.VAR}} tokens in the name before sending.
	name := interpolateString(c.config.Name)
	reqBody := map[string]interface{}{
		"provisioningKey": c.config.ProvisioningKey,
	}
	// Name is optional; omit the field entirely when empty.
	if name != "" {
		reqBody["name"] = name
	}
	jsonData, err := json.Marshal(reqBody)
	if err != nil {
		return fmt.Errorf("failed to marshal provisioning request: %w", err)
	}
	// Bound the whole exchange (DNS, TLS handshake, request, response) to 30s.
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	req, err := http.NewRequestWithContext(
		ctx,
		"POST",
		baseEndpoint+"/api/v1/auth/newt/register",
		bytes.NewBuffer(jsonData),
	)
	if err != nil {
		return fmt.Errorf("failed to create provisioning request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("X-CSRF-Token", "x-csrf-protection")
	// Mirror the TLS setup used by getToken so mTLS / self-signed CAs work.
	var tlsCfg *tls.Config
	if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" ||
		len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
		tlsCfg, err = c.setupTLS()
		if err != nil {
			return fmt.Errorf("failed to setup TLS for provisioning: %w", err)
		}
	}
	// Opt-in escape hatch for self-signed endpoints; mirrors the websocket path.
	if os.Getenv("SKIP_TLS_VERIFY") == "true" {
		if tlsCfg == nil {
			tlsCfg = &tls.Config{}
		}
		tlsCfg.InsecureSkipVerify = true
		logger.Debug("TLS certificate verification disabled for provisioning via SKIP_TLS_VERIFY")
	}
	// One-shot client; the 30s context above bounds the request lifetime.
	httpClient := &http.Client{}
	if tlsCfg != nil {
		httpClient.Transport = &http.Transport{TLSClientConfig: tlsCfg}
	}
	resp, err := httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("provisioning request failed: %w", err)
	}
	defer resp.Body.Close()
	// Read error deliberately ignored: an empty body still produces a useful
	// status-code error below.
	body, _ := io.ReadAll(resp.Body)
	logger.Debug("Provisioning response body: %s", string(body))
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return fmt.Errorf("provisioning endpoint returned status %d: %s", resp.StatusCode, string(body))
	}
	var provResp ProvisioningResponse
	if err := json.Unmarshal(body, &provResp); err != nil {
		return fmt.Errorf("failed to decode provisioning response: %w", err)
	}
	// The server can return 2xx with a logical failure; check the flag too.
	if !provResp.Success {
		return fmt.Errorf("provisioning failed: %s", provResp.Message)
	}
	if provResp.Data.NewtID == "" || provResp.Data.Secret == "" {
		return fmt.Errorf("provisioning response is missing newt ID or secret")
	}
	logger.Info("Successfully provisioned newt ID: %s", provResp.Data.NewtID)
	// Persist the returned credentials and clear the one-time provisioning key
	// so subsequent runs authenticate normally.
	c.config.ID = provResp.Data.NewtID
	c.config.Secret = provResp.Data.Secret
	c.config.ProvisioningKey = ""
	c.config.Name = ""
	c.configNeedsSave = true
	c.justProvisioned = true
	// Save immediately so that if the subsequent connection attempt fails the
	// provisioning key is already gone from disk and the next retry uses the
	// permanent credentials instead of trying to provision again.
	if err := c.saveConfig(); err != nil {
		// Non-fatal: credentials are in memory; worst case the key is reused once.
		logger.Error("Failed to save config after provisioning: %v", err)
	}
	return nil
}

35
websocket/config_test.go Normal file
View File

@@ -0,0 +1,35 @@
package websocket
import (
"os"
"path/filepath"
"testing"
)
// TestLoadConfig_EmptyFileMarksConfigForSave verifies that an empty config
// file is treated as initial state: loadConfig must not error and must flag
// the config for saving so provisioning can persist credentials later.
func TestLoadConfig_EmptyFileMarksConfigForSave(t *testing.T) {
	t.Setenv("CONFIG_FILE", "")
	cfgPath := filepath.Join(t.TempDir(), "config.json")
	if err := os.WriteFile(cfgPath, nil, 0o644); err != nil {
		t.Fatalf("failed to create empty config file: %v", err)
	}
	c := &Client{
		config: &Config{
			Endpoint:        "https://example.com",
			ProvisioningKey: "spk-test",
		},
		clientType:     "newt",
		configFilePath: cfgPath,
	}
	if err := c.loadConfig(); err != nil {
		t.Fatalf("loadConfig returned error for empty file: %v", err)
	}
	if !c.configNeedsSave {
		t.Fatal("expected empty config file to mark configNeedsSave")
	}
}

View File

@@ -1,10 +1,12 @@
package websocket
// Config holds the client's persisted identity and connection settings,
// serialized to and from the JSON config file.
type Config struct {
	ID              string `json:"id"`
	Secret          string `json:"secret"`
	Endpoint        string `json:"endpoint"`
	TlsClientCert   string `json:"tlsClientCert"`
	ProvisioningKey string `json:"provisioningKey,omitempty"` // one-time key, cleared after exchange
	Name            string `json:"name,omitempty"`            // optional display name sent at provisioning
}
type TokenResponse struct {
@@ -16,8 +18,17 @@ type TokenResponse struct {
Message string `json:"message"`
}
// ProvisioningResponse is the JSON payload returned by the
// /api/v1/auth/newt/register endpoint when a provisioning key is exchanged
// for permanent credentials.
type ProvisioningResponse struct {
	Data struct {
		NewtID string `json:"newtId"` // permanent newt ID issued by the server
		Secret string `json:"secret"` // permanent secret paired with NewtID
	} `json:"data"`
	Success bool `json:"success"` // false signals a logical failure even on a 2xx status
	Message string `json:"message"` // human-readable status or error message
}
// WSMessage is the envelope for every message exchanged over the websocket.
type WSMessage struct {
	Type          string      `json:"type"`                    // discriminator used to route the message
	Data          interface{} `json:"data"`                    // payload; concrete shape depends on Type
	ConfigVersion int64       `json:"configVersion,omitempty"` // optional server config version
}