From 0f57985b6ffbd32d5755c78152b2a86a06145f37 Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 23 Mar 2026 16:39:01 -0700 Subject: [PATCH 1/9] Saving and sending access logs pass 1 --- clients/clients.go | 25 ++- netstack2/access_log.go | 355 +++++++++++++++++++++++++++++++++++++ netstack2/handlers.go | 48 +++++ netstack2/proxy.go | 82 ++++++++- netstack2/subnet_lookup.go | 3 +- netstack2/tun.go | 13 +- 6 files changed, 515 insertions(+), 11 deletions(-) create mode 100644 netstack2/access_log.go diff --git a/clients/clients.go b/clients/clients.go index 4c64dbd..9223262 100644 --- a/clients/clients.go +++ b/clients/clients.go @@ -43,6 +43,7 @@ type Target struct { RewriteTo string `json:"rewriteTo,omitempty"` DisableIcmp bool `json:"disableIcmp,omitempty"` PortRange []PortRange `json:"portRange,omitempty"` + ResourceId int `json:"resourceId,omitempty"` } type PortRange struct { @@ -196,6 +197,15 @@ func (s *WireGuardService) Close() { s.stopGetConfig = nil } + // Flush access logs before tearing down the tunnel + if s.tnet != nil { + if ph := s.tnet.GetProxyHandler(); ph != nil { + if al := ph.GetAccessLogger(); al != nil { + al.Close() + } + } + } + // Stop the direct UDP relay first s.StopDirectUDPRelay() @@ -663,7 +673,7 @@ func (s *WireGuardService) syncTargets(desiredTargets []Target) error { }) } - s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) + s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId) logger.Info("Added target %s -> %s during sync", target.SourcePrefix, target.DestPrefix) } } @@ -794,6 +804,13 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { s.TunnelIP = tunnelIP.String() + // Configure the access log sender to ship compressed session logs via websocket + s.tnet.SetAccessLogSender(func(data string) error { + return s.client.SendMessageNoLog("newt/access-log", map[string]interface{}{ + "compressed": data, + }) + }) + // Create WireGuard device using the shared bind s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger( device.LogLevelSilent, // Use silent logging by default - could be made configurable @@ -914,7 +931,7 @@ func (s *WireGuardService) ensureTargets(targets []Target) error { if err != nil { return fmt.Errorf("invalid CIDR %s: %v", sp, err) } - s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) + s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId) logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange) } } @@ -1307,7 +1324,7 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) { logger.Info("Invalid CIDR %s: %v", sp, err) continue } - s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) + s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId) logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange) } } @@ -1425,7 +1442,7 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) { logger.Info("Invalid CIDR %s: %v", sp, err) continue } - s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) + s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId) logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange) } } diff --git a/netstack2/access_log.go b/netstack2/access_log.go new file mode 100644 index 0000000..ab0db78 --- /dev/null +++ b/netstack2/access_log.go @@ -0,0 +1,355 @@ +package netstack2 + +import ( + "bytes" + "compress/zlib" + "crypto/rand" + "encoding/base64" + "encoding/hex" + "encoding/json" + "sync" + "time" + + "github.com/fosrl/newt/logger" +) + +const ( + // flushInterval is how often the access logger flushes completed sessions to the server + flushInterval = 60 * time.Second + + // maxBufferedSessions is the max number of completed sessions to buffer before forcing a flush + maxBufferedSessions = 100 +) + +// SendFunc is a callback that sends compressed access log data to the server. +// The data is a base64-encoded zlib-compressed JSON array of AccessSession objects. +type SendFunc func(data string) error + +// AccessSession represents a tracked access session through the proxy +type AccessSession struct { + SessionID string `json:"sessionId"` + ResourceID int `json:"resourceId"` + SourceAddr string `json:"sourceAddr"` + DestAddr string `json:"destAddr"` + Protocol string `json:"protocol"` + StartedAt time.Time `json:"startedAt"` + EndedAt time.Time `json:"endedAt,omitempty"` + BytesTx int64 `json:"bytesTx"` + BytesRx int64 `json:"bytesRx"` +} + +// udpSessionKey identifies a unique UDP "session" by src -> dst +type udpSessionKey struct { + srcAddr string + dstAddr string + protocol string +} + +// AccessLogger tracks access sessions for resources and periodically +// flushes completed sessions to the server via a configurable SendFunc. +type AccessLogger struct { + mu sync.Mutex + sessions map[string]*AccessSession // active sessions: sessionID -> session + udpSessions map[udpSessionKey]*AccessSession // active UDP sessions for dedup + completedSessions []*AccessSession // completed sessions waiting to be flushed + udpTimeout time.Duration + sendFn SendFunc + stopCh chan struct{} + flushDone chan struct{} // closed after the flush goroutine exits +} + +// NewAccessLogger creates a new access logger. +// udpTimeout controls how long a UDP session is kept alive without traffic before being ended. +func NewAccessLogger(udpTimeout time.Duration) *AccessLogger { + al := &AccessLogger{ + sessions: make(map[string]*AccessSession), + udpSessions: make(map[udpSessionKey]*AccessSession), + completedSessions: make([]*AccessSession, 0), + udpTimeout: udpTimeout, + stopCh: make(chan struct{}), + flushDone: make(chan struct{}), + } + go al.backgroundLoop() + return al +} + +// SetSendFunc sets the callback used to send compressed access log batches +// to the server. This can be called after construction once the websocket +// client is available. +func (al *AccessLogger) SetSendFunc(fn SendFunc) { + al.mu.Lock() + defer al.mu.Unlock() + al.sendFn = fn +} + +// generateSessionID creates a random session identifier +func generateSessionID() string { + b := make([]byte, 8) + rand.Read(b) + return hex.EncodeToString(b) +} + +// StartTCPSession logs the start of a TCP session and returns a session ID. +func (al *AccessLogger) StartTCPSession(resourceID int, srcAddr, dstAddr string) string { + sessionID := generateSessionID() + now := time.Now() + + session := &AccessSession{ + SessionID: sessionID, + ResourceID: resourceID, + SourceAddr: srcAddr, + DestAddr: dstAddr, + Protocol: "tcp", + StartedAt: now, + } + + al.mu.Lock() + al.sessions[sessionID] = session + al.mu.Unlock() + + logger.Info("ACCESS START session=%s resource=%d proto=tcp src=%s dst=%s time=%s", + sessionID, resourceID, srcAddr, dstAddr, now.Format(time.RFC3339)) + + return sessionID +} + +// EndTCPSession logs the end of a TCP session and queues it for sending. +func (al *AccessLogger) EndTCPSession(sessionID string) { + now := time.Now() + + al.mu.Lock() + session, ok := al.sessions[sessionID] + if ok { + session.EndedAt = now + delete(al.sessions, sessionID) + al.completedSessions = append(al.completedSessions, session) + } + shouldFlush := len(al.completedSessions) >= maxBufferedSessions + al.mu.Unlock() + + if ok { + duration := now.Sub(session.StartedAt) + logger.Info("ACCESS END session=%s resource=%d proto=tcp src=%s dst=%s started=%s ended=%s duration=%s", + sessionID, session.ResourceID, session.SourceAddr, session.DestAddr, + session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration) + } + + if shouldFlush { + al.flush() + } +} + +// TrackUDPSession starts or returns an existing UDP session. Returns the session ID. +func (al *AccessLogger) TrackUDPSession(resourceID int, srcAddr, dstAddr string) string { + key := udpSessionKey{ + srcAddr: srcAddr, + dstAddr: dstAddr, + protocol: "udp", + } + + al.mu.Lock() + defer al.mu.Unlock() + + if existing, ok := al.udpSessions[key]; ok { + return existing.SessionID + } + + sessionID := generateSessionID() + now := time.Now() + + session := &AccessSession{ + SessionID: sessionID, + ResourceID: resourceID, + SourceAddr: srcAddr, + DestAddr: dstAddr, + Protocol: "udp", + StartedAt: now, + } + + al.sessions[sessionID] = session + al.udpSessions[key] = session + + logger.Info("ACCESS START session=%s resource=%d proto=udp src=%s dst=%s time=%s", + sessionID, resourceID, srcAddr, dstAddr, now.Format(time.RFC3339)) + + return sessionID +} + +// EndUDPSession ends a UDP session and queues it for sending. +func (al *AccessLogger) EndUDPSession(sessionID string) { + now := time.Now() + + al.mu.Lock() + session, ok := al.sessions[sessionID] + if ok { + session.EndedAt = now + delete(al.sessions, sessionID) + key := udpSessionKey{ + srcAddr: session.SourceAddr, + dstAddr: session.DestAddr, + protocol: "udp", + } + delete(al.udpSessions, key) + al.completedSessions = append(al.completedSessions, session) + } + shouldFlush := len(al.completedSessions) >= maxBufferedSessions + al.mu.Unlock() + + if ok { + duration := now.Sub(session.StartedAt) + logger.Info("ACCESS END session=%s resource=%d proto=udp src=%s dst=%s started=%s ended=%s duration=%s", + sessionID, session.ResourceID, session.SourceAddr, session.DestAddr, + session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration) + } + + if shouldFlush { + al.flush() + } +} + +// backgroundLoop handles periodic flushing and stale session reaping. +func (al *AccessLogger) backgroundLoop() { + defer close(al.flushDone) + + flushTicker := time.NewTicker(flushInterval) + defer flushTicker.Stop() + + reapTicker := time.NewTicker(30 * time.Second) + defer reapTicker.Stop() + + for { + select { + case <-al.stopCh: + return + case <-flushTicker.C: + al.flush() + case <-reapTicker.C: + al.reapStaleSessions() + } + } +} + +// reapStaleSessions cleans up UDP sessions that were not properly ended. +func (al *AccessLogger) reapStaleSessions() { + al.mu.Lock() + defer al.mu.Unlock() + + staleThreshold := time.Now().Add(-5 * time.Minute) + + for key, session := range al.udpSessions { + if session.StartedAt.Before(staleThreshold) && session.EndedAt.IsZero() { + now := time.Now() + session.EndedAt = now + duration := now.Sub(session.StartedAt) + logger.Info("ACCESS END (reaped) session=%s resource=%d proto=udp src=%s dst=%s started=%s ended=%s duration=%s", + session.SessionID, session.ResourceID, session.SourceAddr, session.DestAddr, + session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration) + al.completedSessions = append(al.completedSessions, session) + delete(al.sessions, session.SessionID) + delete(al.udpSessions, key) + } + } +} + +// flush drains the completed sessions buffer, compresses with zlib, and sends via the SendFunc. +func (al *AccessLogger) flush() { + al.mu.Lock() + if len(al.completedSessions) == 0 { + al.mu.Unlock() + return + } + batch := al.completedSessions + al.completedSessions = make([]*AccessSession, 0) + sendFn := al.sendFn + al.mu.Unlock() + + if sendFn == nil { + logger.Debug("Access logger: no send function configured, discarding %d sessions", len(batch)) + return + } + + compressed, err := compressSessions(batch) + if err != nil { + logger.Error("Access logger: failed to compress %d sessions: %v", len(batch), err) + return + } + + if err := sendFn(compressed); err != nil { + logger.Error("Access logger: failed to send %d sessions: %v", len(batch), err) + // Re-queue the batch so we don't lose data + al.mu.Lock() + al.completedSessions = append(batch, al.completedSessions...) + // Cap re-queued data to prevent unbounded growth if server is unreachable + if len(al.completedSessions) > maxBufferedSessions*5 { + dropped := len(al.completedSessions) - maxBufferedSessions*5 + al.completedSessions = al.completedSessions[:maxBufferedSessions*5] + logger.Warn("Access logger: buffer overflow, dropped %d oldest sessions", dropped) + } + al.mu.Unlock() + return + } + + logger.Info("Access logger: sent %d sessions to server", len(batch)) +} + +// compressSessions JSON-encodes the sessions, compresses with zlib, and returns +// a base64-encoded string suitable for embedding in a JSON message. +func compressSessions(sessions []*AccessSession) (string, error) { + jsonData, err := json.Marshal(sessions) + if err != nil { + return "", err + } + + var buf bytes.Buffer + w, err := zlib.NewWriterLevel(&buf, zlib.BestCompression) + if err != nil { + return "", err + } + if _, err := w.Write(jsonData); err != nil { + w.Close() + return "", err + } + if err := w.Close(); err != nil { + return "", err + } + + return base64.StdEncoding.EncodeToString(buf.Bytes()), nil +} + +// Close shuts down the background loop, ends all active sessions, +// and performs one final flush to send everything to the server. +func (al *AccessLogger) Close() { + // Signal the background loop to stop + select { + case <-al.stopCh: + // Already closed + return + default: + close(al.stopCh) + } + + // Wait for the background loop to exit so we don't race on flush + <-al.flushDone + + al.mu.Lock() + now := time.Now() + + // End all active sessions and move them to the completed buffer + for _, session := range al.sessions { + if session.EndedAt.IsZero() { + session.EndedAt = now + duration := now.Sub(session.StartedAt) + logger.Info("ACCESS END (shutdown) session=%s resource=%d proto=%s src=%s dst=%s started=%s ended=%s duration=%s", + session.SessionID, session.ResourceID, session.Protocol, session.SourceAddr, session.DestAddr, + session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration) + al.completedSessions = append(al.completedSessions, session) + } + } + + al.sessions = make(map[string]*AccessSession) + al.udpSessions = make(map[udpSessionKey]*AccessSession) + al.mu.Unlock() + + // Final flush to send all remaining sessions to the server + al.flush() +} \ No newline at end of file diff --git a/netstack2/handlers.go b/netstack2/handlers.go index 75c58b2..07c235f 100644 --- a/netstack2/handlers.go +++ b/netstack2/handlers.go @@ -158,6 +158,18 @@ func (h *TCPHandler) handleTCPConn(netstackConn *gonet.TCPConn, id stack.Transpo targetAddr := fmt.Sprintf("%s:%d", actualDstIP, dstPort) + // Look up resource ID and start access session if applicable + var accessSessionID string + if h.proxyHandler != nil { + resourceId := h.proxyHandler.LookupResourceId(srcIP, dstIP, dstPort, uint8(tcp.ProtocolNumber)) + if resourceId != 0 { + if al := h.proxyHandler.GetAccessLogger(); al != nil { + srcAddr := fmt.Sprintf("%s:%d", srcIP, srcPort) + accessSessionID = al.StartTCPSession(resourceId, srcAddr, targetAddr) + } + } + } + // Create context with timeout for connection establishment ctx, cancel := context.WithTimeout(context.Background(), tcpConnectTimeout) defer cancel() @@ -167,11 +179,26 @@ func (h *TCPHandler) handleTCPConn(netstackConn *gonet.TCPConn, id stack.Transpo targetConn, err := d.DialContext(ctx, "tcp", targetAddr) if err != nil { logger.Info("TCP Forwarder: Failed to connect to %s: %v", targetAddr, err) + // End access session on connection failure + if accessSessionID != "" { + if al := h.proxyHandler.GetAccessLogger(); al != nil { + al.EndTCPSession(accessSessionID) + } + } // Connection failed, netstack will handle RST return } defer targetConn.Close() + // End access session when connection closes + if accessSessionID != "" { + defer func() { + if al := h.proxyHandler.GetAccessLogger(); al != nil { + al.EndTCPSession(accessSessionID) + } + }() + } + logger.Info("TCP Forwarder: Successfully connected to %s, starting bidirectional copy", targetAddr) // Bidirectional copy between netstack and target @@ -280,6 +307,27 @@ func (h *UDPHandler) handleUDPConn(netstackConn *gonet.UDPConn, id stack.Transpo targetAddr := fmt.Sprintf("%s:%d", actualDstIP, dstPort) + // Look up resource ID and start access session if applicable + var accessSessionID string + if h.proxyHandler != nil { + resourceId := h.proxyHandler.LookupResourceId(srcIP, dstIP, dstPort, uint8(udp.ProtocolNumber)) + if resourceId != 0 { + if al := h.proxyHandler.GetAccessLogger(); al != nil { + srcAddr := fmt.Sprintf("%s:%d", srcIP, srcPort) + accessSessionID = al.TrackUDPSession(resourceId, srcAddr, targetAddr) + } + } + } + + // End access session when UDP handler returns (timeout or error) + if accessSessionID != "" { + defer func() { + if al := h.proxyHandler.GetAccessLogger(); al != nil { + al.EndUDPSession(accessSessionID) + } + }() + } + // Resolve target address remoteUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr) if err != nil { diff --git a/netstack2/proxy.go b/netstack2/proxy.go index 1b34818..e383fc0 100644 --- a/netstack2/proxy.go +++ b/netstack2/proxy.go @@ -22,6 +22,12 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) +const ( + // udpAccessSessionTimeout is how long a UDP access session stays alive without traffic + // before being considered ended by the access logger + udpAccessSessionTimeout = 120 * time.Second +) + // PortRange represents an allowed range of ports (inclusive) with optional protocol filtering // Protocol can be "tcp", "udp", or "" (empty string means both protocols) type PortRange struct { @@ -46,6 +52,7 @@ type SubnetRule struct { DisableIcmp bool // If true, ICMP traffic is blocked for this subnet RewriteTo string // Optional rewrite address for DNAT - can be IP/CIDR or domain name PortRanges []PortRange // empty slice means all ports allowed + ResourceId int // Optional resource ID from the server for access logging } // GetAllRules returns a copy of all subnet rules @@ -111,10 +118,12 @@ type ProxyHandler struct { natTable map[connKey]*natState reverseNatTable map[reverseConnKey]*natState // Reverse lookup map for O(1) reply packet NAT destRewriteTable map[destKey]netip.Addr // Maps original dest to rewritten dest for handler lookups + resourceTable map[destKey]int // Maps connection key to resource ID for access logging natMu sync.RWMutex enabled bool icmpReplies chan []byte // Channel for ICMP reply packets to be sent back through the tunnel notifiable channel.Notification // Notification handler for triggering reads + accessLogger *AccessLogger // Access logger for tracking sessions } // ProxyHandlerOptions configures the proxy handler @@ -137,7 +146,9 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) { natTable: make(map[connKey]*natState), reverseNatTable: make(map[reverseConnKey]*natState), destRewriteTable: make(map[destKey]netip.Addr), + resourceTable: make(map[destKey]int), icmpReplies: make(chan []byte, 256), // Buffer for ICMP reply packets + accessLogger: NewAccessLogger(udpAccessSessionTimeout), proxyEp: channel.New(1024, uint32(options.MTU), ""), proxyStack: stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ @@ -202,11 +213,11 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) { // destPrefix: The IP prefix of the destination // rewriteTo: Optional address to rewrite destination to - can be IP/CIDR or domain name // If portRanges is nil or empty, all ports are allowed for this subnet -func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) { +func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool, resourceId int) { if p == nil || !p.enabled { return } - p.subnetLookup.AddSubnet(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp) + p.subnetLookup.AddSubnet(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp, resourceId) } // RemoveSubnetRule removes a subnet from the proxy handler @@ -225,6 +236,43 @@ func (p *ProxyHandler) GetAllRules() []SubnetRule { return p.subnetLookup.GetAllRules() } +// LookupResourceId looks up the resource ID for a connection +// Returns 0 if no resource ID is associated with this connection +func (p *ProxyHandler) LookupResourceId(srcIP, dstIP string, dstPort uint16, proto uint8) int { + if p == nil || !p.enabled { + return 0 + } + + key := destKey{ + srcIP: srcIP, + dstIP: dstIP, + dstPort: dstPort, + proto: proto, + } + + p.natMu.RLock() + defer p.natMu.RUnlock() + + return p.resourceTable[key] +} + +// GetAccessLogger returns the access logger for session tracking +func (p *ProxyHandler) GetAccessLogger() *AccessLogger { + if p == nil { + return nil + } + return p.accessLogger +} + +// SetAccessLogSender configures the function used to send compressed access log +// batches to the server. This should be called once the websocket client is available. +func (p *ProxyHandler) SetAccessLogSender(fn SendFunc) { + if p == nil || !p.enabled || p.accessLogger == nil { + return + } + p.accessLogger.SetSendFunc(fn) +} + // LookupDestinationRewrite looks up the rewritten destination for a connection // This is used by TCP/UDP handlers to find the actual target address func (p *ProxyHandler) LookupDestinationRewrite(srcIP, dstIP string, dstPort uint16, proto uint8) (netip.Addr, bool) { @@ -387,8 +435,22 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool { // Check if the source IP, destination IP, port, and protocol match any subnet rule matchedRule := p.subnetLookup.Match(srcAddr, dstAddr, dstPort, protocol) if matchedRule != nil { - logger.Debug("HandleIncomingPacket: Matched rule for %s -> %s (proto=%d, port=%d)", - srcAddr, dstAddr, protocol, dstPort) + logger.Debug("HandleIncomingPacket: Matched rule for %s -> %s (proto=%d, port=%d, resourceId=%d)", + srcAddr, dstAddr, protocol, dstPort, matchedRule.ResourceId) + + // Store resource ID for connections without DNAT as well + if matchedRule.ResourceId != 0 && matchedRule.RewriteTo == "" { + dKey := destKey{ + srcIP: srcAddr.String(), + dstIP: dstAddr.String(), + dstPort: dstPort, + proto: uint8(protocol), + } + p.natMu.Lock() + p.resourceTable[dKey] = matchedRule.ResourceId + p.natMu.Unlock() + } + // Check if we need to perform DNAT if matchedRule.RewriteTo != "" { // Create connection tracking key using original destination @@ -420,6 +482,13 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool { proto: uint8(protocol), } + // Store resource ID for access logging if present + if matchedRule.ResourceId != 0 { + p.natMu.Lock() + p.resourceTable[dKey] = matchedRule.ResourceId + p.natMu.Unlock() + } + // Check if we already have a NAT entry for this connection p.natMu.RLock() existingEntry, exists := p.natTable[key] @@ -720,6 +789,11 @@ func (p *ProxyHandler) Close() error { return nil } + // Shut down access logger + if p.accessLogger != nil { + p.accessLogger.Close() + } + // Close ICMP replies channel if p.icmpReplies != nil { close(p.icmpReplies) diff --git a/netstack2/subnet_lookup.go b/netstack2/subnet_lookup.go index c6ad0d5..317f85c 100644 --- a/netstack2/subnet_lookup.go +++ b/netstack2/subnet_lookup.go @@ -47,7 +47,7 @@ func prefixEqual(a, b netip.Prefix) bool { // AddSubnet adds a subnet rule with source and destination prefixes and optional port restrictions // If portRanges is nil or empty, all ports are allowed for this subnet // rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com") -func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) { +func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool, resourceId int) { sl.mu.Lock() defer sl.mu.Unlock() @@ -57,6 +57,7 @@ func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewrite DisableIcmp: disableIcmp, RewriteTo: rewriteTo, PortRanges: portRanges, + ResourceId: resourceId, } // Canonicalize source prefix to handle host bits correctly diff --git a/netstack2/tun.go b/netstack2/tun.go index b00faea..3183c36 100644 --- a/netstack2/tun.go +++ b/netstack2/tun.go @@ -354,10 +354,10 @@ func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) { // AddProxySubnetRule adds a subnet rule to the proxy handler // If portRanges is nil or empty, all ports are allowed for this subnet // rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com") -func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) { +func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool, resourceId int) { tun := (*netTun)(net) if tun.proxyHandler != nil { - tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp) + tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp, resourceId) } } @@ -385,6 +385,15 @@ func (net *Net) GetProxyHandler() *ProxyHandler { return tun.proxyHandler } +// SetAccessLogSender configures the function used to send compressed access log +// batches to the server. This should be called once the websocket client is available. +func (net *Net) SetAccessLogSender(fn SendFunc) { + tun := (*netTun)(net) + if tun.proxyHandler != nil { + tun.proxyHandler.SetAccessLogSender(fn) + } +} + type PingConn struct { laddr PingAddr raddr PingAddr From 69019d565567b09012061183c6afa9851c01e1b5 Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 24 Mar 2026 17:26:44 -0700 Subject: [PATCH 2/9] Process log to form sessions --- netstack2/access_log.go | 179 +++++++- netstack2/access_log_test.go | 811 +++++++++++++++++++++++++++++++++++ 2 files changed, 980 insertions(+), 10 deletions(-) create mode 100644 netstack2/access_log_test.go diff --git a/netstack2/access_log.go b/netstack2/access_log.go index ab0db78..de71296 100644 --- a/netstack2/access_log.go +++ b/netstack2/access_log.go @@ -7,6 +7,8 @@ import ( "encoding/base64" "encoding/hex" "encoding/json" + "net" + "sort" "sync" "time" @@ -19,6 +21,15 @@ const ( // maxBufferedSessions is the max number of completed sessions to buffer before forcing a flush maxBufferedSessions = 100 + + // sessionGapThreshold is the maximum gap between the end of one connection + // and the start of the next for them to be considered part of the same session. + // If the gap exceeds this, a new consolidated session is created. + sessionGapThreshold = 5 * time.Second + + // minConnectionsToConsolidate is the minimum number of connections in a group + // before we bother consolidating. Groups smaller than this are sent as-is. + minConnectionsToConsolidate = 2 ) // SendFunc is a callback that sends compressed access log data to the server. @@ -27,15 +38,16 @@ type SendFunc func(data string) error // AccessSession represents a tracked access session through the proxy type AccessSession struct { - SessionID string `json:"sessionId"` - ResourceID int `json:"resourceId"` - SourceAddr string `json:"sourceAddr"` - DestAddr string `json:"destAddr"` - Protocol string `json:"protocol"` - StartedAt time.Time `json:"startedAt"` - EndedAt time.Time `json:"endedAt,omitempty"` - BytesTx int64 `json:"bytesTx"` - BytesRx int64 `json:"bytesRx"` + SessionID string `json:"sessionId"` + ResourceID int `json:"resourceId"` + SourceAddr string `json:"sourceAddr"` + DestAddr string `json:"destAddr"` + Protocol string `json:"protocol"` + StartedAt time.Time `json:"startedAt"` + EndedAt time.Time `json:"endedAt,omitempty"` + BytesTx int64 `json:"bytesTx"` + BytesRx int64 `json:"bytesRx"` + ConnectionCount int `json:"connectionCount,omitempty"` // number of raw connections merged into this session (0 or 1 = single) } // udpSessionKey identifies a unique UDP "session" by src -> dst @@ -45,6 +57,16 @@ type udpSessionKey struct { protocol string } +// consolidationKey groups connections that may be part of the same logical session. +// Source port is intentionally excluded so that many ephemeral-port connections +// from the same source IP to the same destination are grouped together. +type consolidationKey struct { + sourceIP string // IP only, no port + destAddr string // full host:port of the destination + protocol string + resourceID int +} + // AccessLogger tracks access sessions for resources and periodically // flushes completed sessions to the server via a configurable SendFunc. type AccessLogger struct { @@ -251,7 +273,137 @@ func (al *AccessLogger) reapStaleSessions() { } } -// flush drains the completed sessions buffer, compresses with zlib, and sends via the SendFunc. +// extractIP strips the port from an address string and returns just the IP. +// If the address has no port component it is returned as-is. +func extractIP(addr string) string { + host, _, err := net.SplitHostPort(addr) + if err != nil { + // Might already be a bare IP + return addr + } + return host +} + +// consolidateSessions takes a slice of completed sessions and merges bursts of +// short-lived connections from the same source IP to the same destination into +// single higher-level session entries. +// +// The algorithm: +// 1. Group sessions by (sourceIP, destAddr, protocol, resourceID). +// 2. Within each group, sort by StartedAt. +// 3. Walk through the sorted list and merge consecutive sessions whose gap +// (previous EndedAt → next StartedAt) is ≤ sessionGapThreshold. +// 4. For merged sessions the earliest StartedAt and latest EndedAt are kept, +// bytes are summed, and ConnectionCount records how many raw connections +// were folded in. If the merged connections used more than one source port, +// SourceAddr is set to just the IP (port omitted). +// 5. Groups with fewer than minConnectionsToConsolidate members are passed +// through unmodified. +func consolidateSessions(sessions []*AccessSession) []*AccessSession { + if len(sessions) <= 1 { + return sessions + } + + // Group sessions by consolidation key + groups := make(map[consolidationKey][]*AccessSession) + for _, s := range sessions { + key := consolidationKey{ + sourceIP: extractIP(s.SourceAddr), + destAddr: s.DestAddr, + protocol: s.Protocol, + resourceID: s.ResourceID, + } + groups[key] = append(groups[key], s) + } + + result := make([]*AccessSession, 0, len(sessions)) + + for key, group := range groups { + // Small groups don't need consolidation + if len(group) < minConnectionsToConsolidate { + result = append(result, group...) + continue + } + + // Sort the group by start time so we can detect gaps + sort.Slice(group, func(i, j int) bool { + return group[i].StartedAt.Before(group[j].StartedAt) + }) + + // Walk through and merge runs that are within the gap threshold + var merged []*AccessSession + cur := cloneSession(group[0]) + cur.ConnectionCount = 1 + sourcePorts := make(map[string]struct{}) + sourcePorts[cur.SourceAddr] = struct{}{} + + for i := 1; i < len(group); i++ { + s := group[i] + + // Determine the gap: from the latest end time we've seen so far to the + // start of the next connection. + gapRef := cur.EndedAt + if gapRef.IsZero() { + gapRef = cur.StartedAt + } + gap := s.StartedAt.Sub(gapRef) + + if gap <= sessionGapThreshold { + // Merge into the current consolidated session + cur.ConnectionCount++ + cur.BytesTx += s.BytesTx + cur.BytesRx += s.BytesRx + sourcePorts[s.SourceAddr] = struct{}{} + + // Extend EndedAt to the latest time + endTime := s.EndedAt + if endTime.IsZero() { + endTime = s.StartedAt + } + if endTime.After(cur.EndedAt) { + cur.EndedAt = endTime + } + } else { + // Gap exceeded — finalize the current session and start a new one + finalizeMergedSourceAddr(cur, key.sourceIP, sourcePorts) + merged = append(merged, cur) + + cur = cloneSession(s) + cur.ConnectionCount = 1 + sourcePorts = make(map[string]struct{}) + sourcePorts[s.SourceAddr] = struct{}{} + } + } + + // Finalize the last accumulated session + finalizeMergedSourceAddr(cur, key.sourceIP, sourcePorts) + merged = append(merged, cur) + + result = append(result, merged...) + } + + return result +} + +// cloneSession creates a shallow copy of an AccessSession. +func cloneSession(s *AccessSession) *AccessSession { + cp := *s + return &cp +} + +// finalizeMergedSourceAddr sets the SourceAddr on a consolidated session. +// If multiple distinct source addresses (ports) were seen, the port is +// stripped and only the IP is kept so the log isn't misleading. +func finalizeMergedSourceAddr(s *AccessSession, sourceIP string, ports map[string]struct{}) { + if len(ports) > 1 { + // Multiple source ports — just report the IP + s.SourceAddr = sourceIP + } + // Otherwise keep the original SourceAddr which already has ip:port +} + +// flush drains the completed sessions buffer, consolidates bursts of +// short-lived connections, compresses with zlib, and sends via the SendFunc. func (al *AccessLogger) flush() { al.mu.Lock() if len(al.completedSessions) == 0 { @@ -268,6 +420,13 @@ func (al *AccessLogger) flush() { return } + // Consolidate bursts of short-lived connections into higher-level sessions + originalCount := len(batch) + batch = consolidateSessions(batch) + if len(batch) != originalCount { + logger.Info("Access logger: consolidated %d raw connections into %d sessions", originalCount, len(batch)) + } + compressed, err := compressSessions(batch) if err != nil { logger.Error("Access logger: failed to compress %d sessions: %v", len(batch), err) diff --git a/netstack2/access_log_test.go b/netstack2/access_log_test.go new file mode 100644 index 0000000..fc98054 --- /dev/null +++ b/netstack2/access_log_test.go @@ -0,0 +1,811 @@ +package netstack2 + +import ( + "testing" + "time" +) + +func TestExtractIP(t *testing.T) { + tests := []struct { + name string + addr string + expected string + }{ + {"ipv4 with port", "192.168.1.1:12345", "192.168.1.1"}, + {"ipv4 without port", "192.168.1.1", "192.168.1.1"}, + {"ipv6 with port", "[::1]:12345", "::1"}, + {"ipv6 without port", "::1", "::1"}, + {"empty string", "", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractIP(tt.addr) + if result != tt.expected { + t.Errorf("extractIP(%q) = %q, want %q", tt.addr, result, tt.expected) + } + }) + } +} + +func TestConsolidateSessions_Empty(t *testing.T) { + result := consolidateSessions(nil) + if result != nil { + t.Errorf("expected nil, got %v", result) + } + + result = consolidateSessions([]*AccessSession{}) + if len(result) != 0 { + t.Errorf("expected empty slice, got %d items", len(result)) + } +} + +func TestConsolidateSessions_SingleSession(t *testing.T) { + now := time.Now() + sessions := []*AccessSession{ + { + SessionID: "abc123", + ResourceID: 1, + SourceAddr: "10.0.0.1:5000", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now, + EndedAt: now.Add(1 * time.Second), + }, + } + + result := consolidateSessions(sessions) + if len(result) != 1 { + t.Fatalf("expected 1 session, got %d", len(result)) + } + if result[0].SourceAddr != "10.0.0.1:5000" { + t.Errorf("expected source addr preserved, got %q", result[0].SourceAddr) + } +} + +func TestConsolidateSessions_MergesBurstFromSameSourceIP(t *testing.T) { + now := time.Now() + sessions := []*AccessSession{ + { + SessionID: "s1", + ResourceID: 1, + SourceAddr: "10.0.0.1:5000", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now, + EndedAt: now.Add(100 * time.Millisecond), + BytesTx: 100, + BytesRx: 200, + }, + { + SessionID: "s2", + ResourceID: 1, + SourceAddr: "10.0.0.1:5001", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now.Add(200 * time.Millisecond), + EndedAt: now.Add(300 * time.Millisecond), + BytesTx: 150, + BytesRx: 250, + }, + { + SessionID: "s3", + ResourceID: 1, + SourceAddr: "10.0.0.1:5002", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now.Add(400 * time.Millisecond), + EndedAt: now.Add(500 * time.Millisecond), + BytesTx: 50, + BytesRx: 75, + }, + } + + result := consolidateSessions(sessions) + if len(result) != 1 { + t.Fatalf("expected 1 consolidated session, got %d", len(result)) + } + + s := result[0] + if s.ConnectionCount != 3 { + t.Errorf("expected ConnectionCount=3, got %d", s.ConnectionCount) + } + if s.SourceAddr != "10.0.0.1" { + t.Errorf("expected source addr to be IP only (multiple ports), got %q", s.SourceAddr) + } + if s.DestAddr != "192.168.1.100:443" { + t.Errorf("expected dest addr preserved, got %q", s.DestAddr) + } + if s.StartedAt != now { + t.Errorf("expected StartedAt to be earliest time") + } + if s.EndedAt != now.Add(500*time.Millisecond) { + t.Errorf("expected EndedAt to be latest time") + } + expectedTx := int64(300) + expectedRx := int64(525) + if s.BytesTx != expectedTx { + t.Errorf("expected BytesTx=%d, got %d", expectedTx, s.BytesTx) + } + if s.BytesRx != expectedRx { + t.Errorf("expected BytesRx=%d, got %d", expectedRx, s.BytesRx) + } +} + +func TestConsolidateSessions_SameSourcePortPreserved(t *testing.T) { + now := time.Now() + sessions := []*AccessSession{ + { + SessionID: "s1", + ResourceID: 1, + SourceAddr: "10.0.0.1:5000", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now, + EndedAt: now.Add(100 * time.Millisecond), + }, + { + SessionID: "s2", + ResourceID: 1, + SourceAddr: "10.0.0.1:5000", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now.Add(200 * time.Millisecond), + EndedAt: now.Add(300 * time.Millisecond), + }, + } + + result := consolidateSessions(sessions) + if len(result) != 1 { + t.Fatalf("expected 1 session, got %d", len(result)) + } + if result[0].SourceAddr != "10.0.0.1:5000" { + t.Errorf("expected source addr with port preserved when all ports are the same, got %q", result[0].SourceAddr) + } + if result[0].ConnectionCount != 2 { + t.Errorf("expected ConnectionCount=2, got %d", result[0].ConnectionCount) + } +} + +func TestConsolidateSessions_GapSplitsSessions(t *testing.T) { + now := time.Now() + + // First burst + sessions := []*AccessSession{ + { + SessionID: "s1", + ResourceID: 1, + SourceAddr: "10.0.0.1:5000", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now, + EndedAt: now.Add(100 * time.Millisecond), + }, + { + SessionID: "s2", + ResourceID: 1, + SourceAddr: "10.0.0.1:5001", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now.Add(200 * time.Millisecond), + EndedAt: now.Add(300 * time.Millisecond), + }, + // Big gap here (10 seconds) + { + SessionID: "s3", + ResourceID: 1, + SourceAddr: "10.0.0.1:5002", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now.Add(10 * time.Second), + EndedAt: now.Add(10*time.Second + 100*time.Millisecond), + }, + { + SessionID: "s4", + ResourceID: 1, + SourceAddr: "10.0.0.1:5003", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now.Add(10*time.Second + 200*time.Millisecond), + EndedAt: now.Add(10*time.Second + 300*time.Millisecond), + }, + } + + result := consolidateSessions(sessions) + if len(result) != 2 { + t.Fatalf("expected 2 consolidated sessions (gap split), got %d", len(result)) + } + + // Find the sessions by their start time + var first, second *AccessSession + for _, s := range result { + if s.StartedAt.Equal(now) { + first = s + } else { + second = s + } + } + + if first == nil || second == nil { + t.Fatal("could not find both consolidated sessions") + } + + if first.ConnectionCount != 2 { + t.Errorf("first burst: expected ConnectionCount=2, got %d", first.ConnectionCount) + } + if second.ConnectionCount != 2 { + t.Errorf("second burst: expected ConnectionCount=2, got %d", second.ConnectionCount) + } +} + +func TestConsolidateSessions_DifferentDestinationsNotMerged(t *testing.T) { + now := time.Now() + sessions := []*AccessSession{ + { + SessionID: "s1", + ResourceID: 1, + SourceAddr: "10.0.0.1:5000", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now, + EndedAt: now.Add(100 * time.Millisecond), + }, + { + SessionID: "s2", + ResourceID: 1, + SourceAddr: "10.0.0.1:5001", + DestAddr: "192.168.1.100:8080", + Protocol: "tcp", + StartedAt: now.Add(200 * time.Millisecond), + EndedAt: now.Add(300 * time.Millisecond), + }, + } + + result := consolidateSessions(sessions) + // Each goes to a different dest port so they should not be merged + if len(result) != 2 { + t.Fatalf("expected 2 sessions (different destinations), got %d", len(result)) + } +} + +func TestConsolidateSessions_DifferentProtocolsNotMerged(t *testing.T) { + now := time.Now() + sessions := []*AccessSession{ + { + SessionID: "s1", + ResourceID: 1, + SourceAddr: "10.0.0.1:5000", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now, + EndedAt: now.Add(100 * time.Millisecond), + }, + { + SessionID: "s2", + ResourceID: 1, + SourceAddr: "10.0.0.1:5001", + DestAddr: "192.168.1.100:443", + Protocol: "udp", + StartedAt: now.Add(200 * time.Millisecond), + EndedAt: now.Add(300 * time.Millisecond), + }, + } + + result := consolidateSessions(sessions) + if len(result) != 2 { + t.Fatalf("expected 2 sessions (different protocols), got %d", len(result)) + } +} + +func TestConsolidateSessions_DifferentResourceIDsNotMerged(t *testing.T) { + now := time.Now() + sessions := []*AccessSession{ + { + SessionID: "s1", + ResourceID: 1, + SourceAddr: "10.0.0.1:5000", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now, + EndedAt: now.Add(100 * time.Millisecond), + }, + { + SessionID: "s2", + ResourceID: 2, + SourceAddr: "10.0.0.1:5001", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now.Add(200 * time.Millisecond), + EndedAt: now.Add(300 * time.Millisecond), + }, + } + + result := consolidateSessions(sessions) + if len(result) != 2 { + t.Fatalf("expected 2 sessions (different resource IDs), got %d", len(result)) + } +} + +func TestConsolidateSessions_DifferentSourceIPsNotMerged(t *testing.T) { + now := time.Now() + sessions := []*AccessSession{ + { + SessionID: "s1", + ResourceID: 1, + SourceAddr: "10.0.0.1:5000", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now, + EndedAt: now.Add(100 * time.Millisecond), + }, + { + SessionID: "s2", + ResourceID: 1, + SourceAddr: "10.0.0.2:5001", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now.Add(200 * time.Millisecond), + EndedAt: now.Add(300 * time.Millisecond), + }, + } + + result := consolidateSessions(sessions) + if len(result) != 2 { + t.Fatalf("expected 2 sessions (different source IPs), got %d", len(result)) + } +} + +func TestConsolidateSessions_OutOfOrderInput(t *testing.T) { + now := time.Now() + // Provide sessions out of chronological order to verify sorting + sessions := []*AccessSession{ + { + SessionID: "s3", + ResourceID: 1, + SourceAddr: "10.0.0.1:5002", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now.Add(400 * time.Millisecond), + EndedAt: now.Add(500 * time.Millisecond), + BytesTx: 30, + }, + { + SessionID: "s1", + ResourceID: 1, + SourceAddr: "10.0.0.1:5000", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now, + EndedAt: now.Add(100 * time.Millisecond), + BytesTx: 10, + }, + { + SessionID: "s2", + ResourceID: 1, + SourceAddr: "10.0.0.1:5001", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now.Add(200 * time.Millisecond), + EndedAt: now.Add(300 * time.Millisecond), + BytesTx: 20, + }, + } + + result := consolidateSessions(sessions) + if len(result) != 1 { + t.Fatalf("expected 1 consolidated session, got %d", len(result)) + } + + s := result[0] + if s.ConnectionCount != 3 { + t.Errorf("expected ConnectionCount=3, got %d", s.ConnectionCount) + } + if s.StartedAt != now { + t.Errorf("expected StartedAt to be earliest time") + } + if s.EndedAt != now.Add(500*time.Millisecond) { + t.Errorf("expected EndedAt to be latest time") + } + if s.BytesTx != 60 { + t.Errorf("expected BytesTx=60, got %d", s.BytesTx) + } +} + +func TestConsolidateSessions_ExactlyAtGapThreshold(t *testing.T) { + now := time.Now() + sessions := []*AccessSession{ + { + SessionID: "s1", + ResourceID: 1, + SourceAddr: "10.0.0.1:5000", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now, + EndedAt: now.Add(100 * time.Millisecond), + }, + { + // Starts exactly sessionGapThreshold after s1 ends — should still merge + SessionID: "s2", + ResourceID: 1, + SourceAddr: "10.0.0.1:5001", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now.Add(100*time.Millisecond + sessionGapThreshold), + EndedAt: now.Add(100*time.Millisecond + sessionGapThreshold + 50*time.Millisecond), + }, + } + + result := consolidateSessions(sessions) + if len(result) != 1 { + t.Fatalf("expected 1 session (gap exactly at threshold merges), got %d", len(result)) + } + if result[0].ConnectionCount != 2 { + t.Errorf("expected ConnectionCount=2, got %d", result[0].ConnectionCount) + } +} + +func TestConsolidateSessions_JustOverGapThreshold(t *testing.T) { + now := time.Now() + sessions := []*AccessSession{ + { + SessionID: "s1", + ResourceID: 1, + SourceAddr: "10.0.0.1:5000", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now, + EndedAt: now.Add(100 * time.Millisecond), + }, + { + // Starts 1ms over the gap threshold after s1 ends — should split + SessionID: "s2", + ResourceID: 1, + SourceAddr: "10.0.0.1:5001", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now.Add(100*time.Millisecond + sessionGapThreshold + 1*time.Millisecond), + EndedAt: now.Add(100*time.Millisecond + sessionGapThreshold + 50*time.Millisecond), + }, + } + + result := consolidateSessions(sessions) + if len(result) != 2 { + t.Fatalf("expected 2 sessions (gap just over threshold splits), got %d", len(result)) + } +} + +func TestConsolidateSessions_UDPSessions(t *testing.T) { + now := time.Now() + sessions := []*AccessSession{ + { + SessionID: "u1", + ResourceID: 5, + SourceAddr: "10.0.0.1:6000", + DestAddr: "192.168.1.100:53", + Protocol: "udp", + StartedAt: now, + EndedAt: now.Add(50 * time.Millisecond), + BytesTx: 64, + BytesRx: 512, + }, + { + SessionID: "u2", + ResourceID: 5, + SourceAddr: "10.0.0.1:6001", + DestAddr: "192.168.1.100:53", + Protocol: "udp", + StartedAt: now.Add(100 * time.Millisecond), + EndedAt: now.Add(150 * time.Millisecond), + BytesTx: 64, + BytesRx: 256, + }, + { + SessionID: "u3", + ResourceID: 5, + SourceAddr: "10.0.0.1:6002", + DestAddr: "192.168.1.100:53", + Protocol: "udp", + StartedAt: now.Add(200 * time.Millisecond), + EndedAt: now.Add(250 * time.Millisecond), + BytesTx: 64, + BytesRx: 128, + }, + } + + result := consolidateSessions(sessions) + if len(result) != 1 { + t.Fatalf("expected 1 consolidated UDP session, got %d", len(result)) + } + + s := result[0] + if s.Protocol != "udp" { + t.Errorf("expected protocol=udp, got %q", s.Protocol) + } + if s.ConnectionCount != 3 { + t.Errorf("expected ConnectionCount=3, got %d", s.ConnectionCount) + } + if s.SourceAddr != "10.0.0.1" { + t.Errorf("expected source addr to be IP only, got %q", s.SourceAddr) + } + if s.BytesTx != 192 { + t.Errorf("expected BytesTx=192, got %d", s.BytesTx) + } + if s.BytesRx != 896 { + t.Errorf("expected BytesRx=896, got %d", s.BytesRx) + } +} + +func TestConsolidateSessions_MixedGroupsSomeConsolidatedSomeNot(t *testing.T) { + now := time.Now() + sessions := []*AccessSession{ + // Group 1: 3 connections to :443 from same IP — should consolidate + { + SessionID: "s1", + ResourceID: 1, + SourceAddr: "10.0.0.1:5000", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now, + EndedAt: now.Add(100 * time.Millisecond), + }, + { + SessionID: "s2", + ResourceID: 1, + SourceAddr: "10.0.0.1:5001", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now.Add(200 * time.Millisecond), + EndedAt: now.Add(300 * time.Millisecond), + }, + { + SessionID: "s3", + ResourceID: 1, + SourceAddr: "10.0.0.1:5002", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now.Add(400 * time.Millisecond), + EndedAt: now.Add(500 * time.Millisecond), + }, + // Group 2: 1 connection to :8080 from different IP — should pass through + { + SessionID: "s4", + ResourceID: 2, + SourceAddr: "10.0.0.2:6000", + DestAddr: "192.168.1.200:8080", + Protocol: "tcp", + StartedAt: now.Add(1 * time.Second), + EndedAt: now.Add(2 * time.Second), + }, + } + + result := consolidateSessions(sessions) + if len(result) != 2 { + t.Fatalf("expected 2 sessions total, got %d", len(result)) + } + + var consolidated, passthrough *AccessSession + for _, s := range result { + if s.ConnectionCount > 1 { + consolidated = s + } else { + passthrough = s + } + } + + if consolidated == nil { + t.Fatal("expected a consolidated session") + } + if consolidated.ConnectionCount != 3 { + t.Errorf("consolidated: expected ConnectionCount=3, got %d", consolidated.ConnectionCount) + } + + if passthrough == nil { + t.Fatal("expected a passthrough session") + } + if passthrough.SessionID != "s4" { + t.Errorf("passthrough: expected session s4, got %s", passthrough.SessionID) + } +} + +func TestConsolidateSessions_OverlappingConnections(t *testing.T) { + now := time.Now() + // Connections that overlap in time (not sequential) + sessions := []*AccessSession{ + { + SessionID: "s1", + ResourceID: 1, + SourceAddr: "10.0.0.1:5000", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now, + EndedAt: now.Add(5 * time.Second), + BytesTx: 100, + }, + { + SessionID: "s2", + ResourceID: 1, + SourceAddr: "10.0.0.1:5001", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now.Add(1 * time.Second), + EndedAt: now.Add(3 * time.Second), + BytesTx: 200, + }, + { + SessionID: "s3", + ResourceID: 1, + SourceAddr: "10.0.0.1:5002", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now.Add(2 * time.Second), + EndedAt: now.Add(6 * time.Second), + BytesTx: 300, + }, + } + + result := consolidateSessions(sessions) + if len(result) != 1 { + t.Fatalf("expected 1 consolidated session, got %d", len(result)) + } + + s := result[0] + if s.ConnectionCount != 3 { + t.Errorf("expected ConnectionCount=3, got %d", s.ConnectionCount) + } + if s.StartedAt != now { + t.Error("expected StartedAt to be earliest") + } + if s.EndedAt != now.Add(6*time.Second) { + t.Error("expected EndedAt to be the latest end time") + } + if s.BytesTx != 600 { + t.Errorf("expected BytesTx=600, got %d", s.BytesTx) + } +} + +func TestConsolidateSessions_DoesNotMutateOriginals(t *testing.T) { + now := time.Now() + s1 := &AccessSession{ + SessionID: "s1", + ResourceID: 1, + SourceAddr: "10.0.0.1:5000", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now, + EndedAt: now.Add(100 * time.Millisecond), + BytesTx: 100, + } + s2 := &AccessSession{ + SessionID: "s2", + ResourceID: 1, + SourceAddr: "10.0.0.1:5001", + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now.Add(200 * time.Millisecond), + EndedAt: now.Add(300 * time.Millisecond), + BytesTx: 200, + } + + // Save original values + origS1Addr := s1.SourceAddr + origS1Bytes := s1.BytesTx + origS2Addr := s2.SourceAddr + origS2Bytes := s2.BytesTx + + _ = consolidateSessions([]*AccessSession{s1, s2}) + + if s1.SourceAddr != origS1Addr { + t.Errorf("s1.SourceAddr was mutated: %q -> %q", origS1Addr, s1.SourceAddr) + } + if s1.BytesTx != origS1Bytes { + t.Errorf("s1.BytesTx was mutated: %d -> %d", origS1Bytes, s1.BytesTx) + } + if s2.SourceAddr != origS2Addr { + t.Errorf("s2.SourceAddr was mutated: %q -> %q", origS2Addr, s2.SourceAddr) + } + if s2.BytesTx != origS2Bytes { + t.Errorf("s2.BytesTx was mutated: %d -> %d", origS2Bytes, s2.BytesTx) + } +} + +func TestConsolidateSessions_ThreeBurstsWithGaps(t *testing.T) { + now := time.Now() + + sessions := make([]*AccessSession, 0, 9) + + // Burst 1: 3 connections at t=0 + for i := 0; i < 3; i++ { + sessions = append(sessions, &AccessSession{ + SessionID: generateSessionID(), + ResourceID: 1, + SourceAddr: "10.0.0.1:" + string(rune('A'+i)), + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now.Add(time.Duration(i*100) * time.Millisecond), + EndedAt: now.Add(time.Duration(i*100+50) * time.Millisecond), + }) + } + + // Burst 2: 3 connections at t=20s (well past the 5s gap) + for i := 0; i < 3; i++ { + sessions = append(sessions, &AccessSession{ + SessionID: generateSessionID(), + ResourceID: 1, + SourceAddr: "10.0.0.1:" + string(rune('D'+i)), + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now.Add(20*time.Second + time.Duration(i*100)*time.Millisecond), + EndedAt: now.Add(20*time.Second + time.Duration(i*100+50)*time.Millisecond), + }) + } + + // Burst 3: 3 connections at t=40s + for i := 0; i < 3; i++ { + sessions = append(sessions, &AccessSession{ + SessionID: generateSessionID(), + ResourceID: 1, + SourceAddr: "10.0.0.1:" + string(rune('G'+i)), + DestAddr: "192.168.1.100:443", + Protocol: "tcp", + StartedAt: now.Add(40*time.Second + time.Duration(i*100)*time.Millisecond), + EndedAt: now.Add(40*time.Second + time.Duration(i*100+50)*time.Millisecond), + }) + } + + result := consolidateSessions(sessions) + if len(result) != 3 { + t.Fatalf("expected 3 consolidated sessions (3 bursts), got %d", len(result)) + } + + for _, s := range result { + if s.ConnectionCount != 3 { + t.Errorf("expected each burst to have ConnectionCount=3, got %d (started=%v)", s.ConnectionCount, s.StartedAt) + } + } +} + +func TestFinalizeMergedSourceAddr(t *testing.T) { + s := &AccessSession{SourceAddr: "10.0.0.1:5000"} + ports := map[string]struct{}{"10.0.0.1:5000": {}} + finalizeMergedSourceAddr(s, "10.0.0.1", ports) + if s.SourceAddr != "10.0.0.1:5000" { + t.Errorf("single port: expected addr preserved, got %q", s.SourceAddr) + } + + s2 := &AccessSession{SourceAddr: "10.0.0.1:5000"} + ports2 := map[string]struct{}{"10.0.0.1:5000": {}, "10.0.0.1:5001": {}} + finalizeMergedSourceAddr(s2, "10.0.0.1", ports2) + if s2.SourceAddr != "10.0.0.1" { + t.Errorf("multiple ports: expected IP only, got %q", s2.SourceAddr) + } +} + +func TestCloneSession(t *testing.T) { + original := &AccessSession{ + SessionID: "test", + ResourceID: 42, + SourceAddr: "1.2.3.4:100", + DestAddr: "5.6.7.8:443", + Protocol: "tcp", + BytesTx: 999, + } + + clone := cloneSession(original) + + if clone == original { + t.Error("clone should be a different pointer") + } + if clone.SessionID != original.SessionID { + t.Error("clone should have same SessionID") + } + + // Mutating clone should not affect original + clone.BytesTx = 0 + clone.SourceAddr = "changed" + if original.BytesTx != 999 { + t.Error("mutating clone affected original BytesTx") + } + if original.SourceAddr != "1.2.3.4:100" { + t.Error("mutating clone affected original SourceAddr") + } +} \ No newline at end of file From b43572dd8d4aa03a223ee7d987668f4354ab9163 Mon Sep 17 00:00:00 2001 From: Owen Date: Thu, 26 Mar 2026 17:23:19 -0700 Subject: [PATCH 3/9] Provisioning key working --- main.go | 13 +++++ websocket/client.go | 5 ++ websocket/config.go | 126 ++++++++++++++++++++++++++++++++++++++++++++ websocket/types.go | 20 +++++-- 4 files changed, 159 insertions(+), 5 deletions(-) diff --git a/main.go b/main.go index 3646a27..94b3b48 100644 --- a/main.go +++ b/main.go @@ -159,6 +159,9 @@ var ( // Legacy PKCS12 support (deprecated) tlsPrivateKey string + + // Provisioning key – exchanged once for a permanent newt ID + secret + provisioningKey string ) func main() { @@ -264,6 +267,7 @@ func runNewtMain(ctx context.Context) { blueprintFile = os.Getenv("BLUEPRINT_FILE") noCloudEnv := os.Getenv("NO_CLOUD") noCloud = noCloudEnv == "true" + provisioningKey = os.Getenv("NEWT_PROVISIONING_KEY") if endpoint == "" { flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server") @@ -312,6 +316,9 @@ func runNewtMain(ctx context.Context) { } // load the prefer endpoint just as a flag flag.StringVar(&preferEndpoint, "prefer-endpoint", "", "Prefer this endpoint for the connection (if set, will override the endpoint from the server)") + if provisioningKey == "" { + flag.StringVar(&provisioningKey, "provisioning-key", "", "One-time provisioning key used to obtain a newt ID and secret from the server") + } // Add new mTLS flags if tlsClientCert == "" { @@ -590,6 +597,12 @@ func runNewtMain(ctx context.Context) { if err != nil { logger.Fatal("Failed to create client: %v", err) } + // If a provisioning key was supplied via CLI / env and the config file did + // not already carry one, inject it now so provisionIfNeeded() can use it. + if provisioningKey != "" && client.GetConfig().ProvisioningKey == "" { + client.GetConfig().ProvisioningKey = provisioningKey + } + endpoint = client.GetConfig().Endpoint // Update endpoint from config id = client.GetConfig().ID // Update ID from config // Update site labels for metrics with the resolved ID diff --git a/websocket/client.go b/websocket/client.go index 533771b..e645a6f 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -481,6 +481,11 @@ func (c *Client) connectWithRetry() { func (c *Client) establishConnection() error { ctx := context.Background() + // Exchange provisioning key for permanent credentials if needed. + if err := c.provisionIfNeeded(); err != nil { + return fmt.Errorf("failed to provision newt credentials: %w", err) + } + // Get token for authentication token, err := c.getToken() if err != nil { diff --git a/websocket/config.go b/websocket/config.go index 72c9164..8ae7ff5 100644 --- a/websocket/config.go +++ b/websocket/config.go @@ -1,11 +1,20 @@ package websocket import ( + "bytes" + "context" + "crypto/tls" "encoding/json" + "fmt" + "io" "log" + "net/http" + "net/url" "os" "path/filepath" "runtime" + "strings" + "time" "github.com/fosrl/newt/logger" ) @@ -83,6 +92,10 @@ func (c *Client) loadConfig() error { c.config.Endpoint = config.Endpoint c.baseURL = config.Endpoint } + // Always load the provisioning key from the file if not already set + if c.config.ProvisioningKey == "" { + c.config.ProvisioningKey = config.ProvisioningKey + } // Check if CLI args provided values that override file values if (!fileHadID && originalConfig.ID != "") || @@ -118,3 +131,116 @@ func (c *Client) saveConfig() error { } return err } + +// provisionIfNeeded checks whether a provisioning key is present and, if so, +// exchanges it for a newt ID and secret by calling the registration endpoint. +// On success the config is updated in-place and flagged for saving so that +// subsequent runs use the permanent credentials directly. +func (c *Client) provisionIfNeeded() error { + if c.config.ProvisioningKey == "" { + return nil + } + + // If we already have both credentials there is nothing to provision. + if c.config.ID != "" && c.config.Secret != "" { + logger.Debug("Credentials already present, skipping provisioning") + return nil + } + + logger.Info("Provisioning key found – exchanging for newt credentials...") + + baseURL, err := url.Parse(c.baseURL) + if err != nil { + return fmt.Errorf("failed to parse base URL for provisioning: %w", err) + } + baseEndpoint := strings.TrimRight(baseURL.String(), "/") + + reqBody := map[string]interface{}{ + "provisioningKey": c.config.ProvisioningKey, + } + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal provisioning request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext( + ctx, + "POST", + baseEndpoint+"/api/v1/auth/newt/register", + bytes.NewBuffer(jsonData), + ) + if err != nil { + return fmt.Errorf("failed to create provisioning request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-CSRF-Token", "x-csrf-protection") + + // Mirror the TLS setup used by getToken so mTLS / self-signed CAs work. + var tlsCfg *tls.Config + if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || + len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" { + tlsCfg, err = c.setupTLS() + if err != nil { + return fmt.Errorf("failed to setup TLS for provisioning: %w", err) + } + } + if os.Getenv("SKIP_TLS_VERIFY") == "true" { + if tlsCfg == nil { + tlsCfg = &tls.Config{} + } + tlsCfg.InsecureSkipVerify = true + logger.Debug("TLS certificate verification disabled for provisioning via SKIP_TLS_VERIFY") + } + + httpClient := &http.Client{} + if tlsCfg != nil { + httpClient.Transport = &http.Transport{TLSClientConfig: tlsCfg} + } + + resp, err := httpClient.Do(req) + if err != nil { + return fmt.Errorf("provisioning request failed: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + logger.Debug("Provisioning response body: %s", string(body)) + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("provisioning endpoint returned status %d: %s", resp.StatusCode, string(body)) + } + + var provResp ProvisioningResponse + if err := json.Unmarshal(body, &provResp); err != nil { + return fmt.Errorf("failed to decode provisioning response: %w", err) + } + + if !provResp.Success { + return fmt.Errorf("provisioning failed: %s", provResp.Message) + } + + if provResp.Data.NewtID == "" || provResp.Data.Secret == "" { + return fmt.Errorf("provisioning response is missing newt ID or secret") + } + + logger.Info("Successfully provisioned – newt ID: %s", provResp.Data.NewtID) + + // Persist the returned credentials and clear the one-time provisioning key + // so subsequent runs authenticate normally. + c.config.ID = provResp.Data.NewtID + c.config.Secret = provResp.Data.Secret + c.config.ProvisioningKey = "" + c.configNeedsSave = true + + // Save immediately so that if the subsequent connection attempt fails the + // provisioning key is already gone from disk and the next retry uses the + // permanent credentials instead of trying to provision again. + if err := c.saveConfig(); err != nil { + logger.Error("Failed to save config after provisioning: %v", err) + } + + return nil +} \ No newline at end of file diff --git a/websocket/types.go b/websocket/types.go index 381f7a1..2b32dae 100644 --- a/websocket/types.go +++ b/websocket/types.go @@ -1,10 +1,11 @@ package websocket type Config struct { - ID string `json:"id"` - Secret string `json:"secret"` - Endpoint string `json:"endpoint"` - TlsClientCert string `json:"tlsClientCert"` + ID string `json:"id"` + Secret string `json:"secret"` + Endpoint string `json:"endpoint"` + TlsClientCert string `json:"tlsClientCert"` + ProvisioningKey string `json:"provisioningKey,omitempty"` } type TokenResponse struct { @@ -16,8 +17,17 @@ type TokenResponse struct { Message string `json:"message"` } +type ProvisioningResponse struct { + Data struct { + NewtID string `json:"newtId"` + Secret string `json:"secret"` + } `json:"data"` + Success bool `json:"success"` + Message string `json:"message"` +} + type WSMessage struct { Type string `json:"type"` Data interface{} `json:"data"` ConfigVersion int64 `json:"configVersion,omitempty"` -} +} \ No newline at end of file From baca04ee58916e3df6a1f74ff77bf7a52dd45e89 Mon Sep 17 00:00:00 2001 From: Owen Date: Thu, 26 Mar 2026 17:31:04 -0700 Subject: [PATCH 4/9] Add --config-file --- config.json | 4 ++++ config.json.bak | 4 ++++ main.go | 8 ++++++++ websocket/client.go | 7 +++++++ websocket/config.go | 9 ++++++--- 5 files changed, 29 insertions(+), 3 deletions(-) create mode 100644 config.json create mode 100644 config.json.bak diff --git a/config.json b/config.json new file mode 100644 index 0000000..fac9795 --- /dev/null +++ b/config.json @@ -0,0 +1,4 @@ +{ + "endpoint": "http://you.fosrl.io", + "provisioningKey": "spk-xt1opb0fkoqb7qb.hi44jciamqcrdaja4lvz3kp52pl3lssamp6asuyx" +} \ No newline at end of file diff --git a/config.json.bak b/config.json.bak new file mode 100644 index 0000000..fac9795 --- /dev/null +++ b/config.json.bak @@ -0,0 +1,4 @@ +{ + "endpoint": "http://you.fosrl.io", + "provisioningKey": "spk-xt1opb0fkoqb7qb.hi44jciamqcrdaja4lvz3kp52pl3lssamp6asuyx" +} \ No newline at end of file diff --git a/main.go b/main.go index 94b3b48..0af8773 100644 --- a/main.go +++ b/main.go @@ -162,6 +162,9 @@ var ( // Provisioning key – exchanged once for a permanent newt ID + secret provisioningKey string + + // Path to config file (overrides CONFIG_FILE env var and default location) + configFile string ) func main() { @@ -268,6 +271,7 @@ func runNewtMain(ctx context.Context) { noCloudEnv := os.Getenv("NO_CLOUD") noCloud = noCloudEnv == "true" provisioningKey = os.Getenv("NEWT_PROVISIONING_KEY") + configFile = os.Getenv("CONFIG_FILE") if endpoint == "" { flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server") @@ -319,6 +323,9 @@ func runNewtMain(ctx context.Context) { if provisioningKey == "" { flag.StringVar(&provisioningKey, "provisioning-key", "", "One-time provisioning key used to obtain a newt ID and secret from the server") } + if configFile == "" { + flag.StringVar(&configFile, "config-file", "", "Path to config file (overrides CONFIG_FILE env var and default location)") + } // Add new mTLS flags if tlsClientCert == "" { @@ -593,6 +600,7 @@ func runNewtMain(ctx context.Context) { endpoint, 30*time.Second, opt, + websocket.WithConfigFile(configFile), ) if err != nil { logger.Fatal("Failed to create client: %v", err) diff --git a/websocket/client.go b/websocket/client.go index e645a6f..49cf414 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -42,6 +42,7 @@ type Client struct { onTokenUpdate func(token string) writeMux sync.Mutex clientType string // Type of client (e.g., "newt", "olm") + configFilePath string // Optional override for the config file path tlsConfig TLSConfig metricsCtxMu sync.RWMutex metricsCtx context.Context @@ -77,6 +78,12 @@ func WithBaseURL(url string) ClientOption { } // WithTLSConfig sets the TLS configuration for the client +func WithConfigFile(path string) ClientOption { + return func(c *Client) { + c.configFilePath = path + } +} + func WithTLSConfig(config TLSConfig) ClientOption { return func(c *Client) { c.tlsConfig = config diff --git a/websocket/config.go b/websocket/config.go index 8ae7ff5..4fb6513 100644 --- a/websocket/config.go +++ b/websocket/config.go @@ -19,7 +19,10 @@ import ( "github.com/fosrl/newt/logger" ) -func getConfigPath(clientType string) string { +func getConfigPath(clientType string, overridePath string) string { + if overridePath != "" { + return overridePath + } configFile := os.Getenv("CONFIG_FILE") if configFile == "" { var configDir string @@ -45,7 +48,7 @@ func getConfigPath(clientType string) string { func (c *Client) loadConfig() error { originalConfig := *c.config // Store original config to detect changes - configPath := getConfigPath(c.clientType) + configPath := getConfigPath(c.clientType, c.configFilePath) if c.config.ID != "" && c.config.Secret != "" && c.config.Endpoint != "" { logger.Debug("Config already provided, skipping loading from file") @@ -118,7 +121,7 @@ func (c *Client) saveConfig() error { return nil } - configPath := getConfigPath(c.clientType) + configPath := getConfigPath(c.clientType, c.configFilePath) data, err := json.MarshalIndent(c.config, "", " ") if err != nil { return err From fc4b375bf1fcf3f457c4d8730a55d8488d6cb87f Mon Sep 17 00:00:00 2001 From: Owen Date: Thu, 26 Mar 2026 20:05:04 -0700 Subject: [PATCH 5/9] Allow blueprint interpolation for env vars --- common.go | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/common.go b/common.go index 4701411..34e3cd0 100644 --- a/common.go +++ b/common.go @@ -8,6 +8,7 @@ import ( "net" "os" "os/exec" + "regexp" "strings" "time" @@ -509,6 +510,29 @@ func executeUpdownScript(action, proto, target string) (string, error) { return target, nil } +// interpolateBlueprint finds all {{...}} tokens in the raw blueprint bytes and +// replaces recognised schemes with their resolved values. Currently supported: +// +// - env. – replaced with the value of the named environment variable +// +// Any token that does not match a supported scheme is left as-is so that +// future schemes (e.g. tag., api.) are preserved rather than silently dropped. +func interpolateBlueprint(data []byte) []byte { + re := regexp.MustCompile(`\{\{([^}]+)\}\}`) + return re.ReplaceAllFunc(data, func(match []byte) []byte { + // strip the surrounding {{ }} + inner := strings.TrimSpace(string(match[2 : len(match)-2])) + + if strings.HasPrefix(inner, "env.") { + varName := strings.TrimPrefix(inner, "env.") + return []byte(os.Getenv(varName)) + } + + // unrecognised scheme – leave the token untouched + return match + }) +} + func sendBlueprint(client *websocket.Client) error { if blueprintFile == "" { return nil @@ -518,6 +542,9 @@ func sendBlueprint(client *websocket.Client) error { if err != nil { logger.Error("Failed to read blueprint file: %v", err) } else { + // interpolate {{env.VAR}} (and any future schemes) before parsing + blueprintData = interpolateBlueprint(blueprintData) + // first we should convert the yaml to json and error if the yaml is bad var yamlObj interface{} var blueprintJsonData string From 5208117c56e3e2ece7898cb958349d96b1dedaff Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 30 Mar 2026 17:18:22 -0700 Subject: [PATCH 6/9] Add name to provisioning --- main.go | 10 ++++++++++ websocket/config.go | 27 +++++++++++++++++++++++++++ websocket/types.go | 1 + 3 files changed, 38 insertions(+) diff --git a/main.go b/main.go index c573ee2..6ad1c2f 100644 --- a/main.go +++ b/main.go @@ -169,6 +169,9 @@ var ( // Provisioning key – exchanged once for a permanent newt ID + secret provisioningKey string + // Optional name for the site created during provisioning + newtName string + // Path to config file (overrides CONFIG_FILE env var and default location) configFile string ) @@ -284,6 +287,7 @@ func runNewtMain(ctx context.Context) { noCloudEnv := os.Getenv("NO_CLOUD") noCloud = noCloudEnv == "true" provisioningKey = os.Getenv("NEWT_PROVISIONING_KEY") + newtName = os.Getenv("NEWT_NAME") configFile = os.Getenv("CONFIG_FILE") if endpoint == "" { @@ -336,6 +340,9 @@ func runNewtMain(ctx context.Context) { if provisioningKey == "" { flag.StringVar(&provisioningKey, "provisioning-key", "", "One-time provisioning key used to obtain a newt ID and secret from the server") } + if newtName == "" { + flag.StringVar(&newtName, "name", "", "Name for the site created during provisioning (supports {{env.VAR}} interpolation)") + } if configFile == "" { flag.StringVar(&configFile, "config-file", "", "Path to config file (overrides CONFIG_FILE env var and default location)") } @@ -623,6 +630,9 @@ func runNewtMain(ctx context.Context) { if provisioningKey != "" && client.GetConfig().ProvisioningKey == "" { client.GetConfig().ProvisioningKey = provisioningKey } + if newtName != "" && client.GetConfig().Name == "" { + client.GetConfig().Name = newtName + } endpoint = client.GetConfig().Endpoint // Update endpoint from config id = client.GetConfig().ID // Update ID from config diff --git a/websocket/config.go b/websocket/config.go index 4fb6513..d727a41 100644 --- a/websocket/config.go +++ b/websocket/config.go @@ -12,6 +12,7 @@ import ( "net/url" "os" "path/filepath" + "regexp" "runtime" "strings" "time" @@ -99,6 +100,10 @@ func (c *Client) loadConfig() error { if c.config.ProvisioningKey == "" { c.config.ProvisioningKey = config.ProvisioningKey } + // Always load the name from the file if not already set + if c.config.Name == "" { + c.config.Name = config.Name + } // Check if CLI args provided values that override file values if (!fileHadID && originalConfig.ID != "") || @@ -135,6 +140,21 @@ func (c *Client) saveConfig() error { return err } +// interpolateString replaces {{env.VAR}} tokens in s with the corresponding +// environment variable values. Tokens that do not match a supported scheme are +// left unchanged, mirroring the blueprint interpolation logic. +func interpolateString(s string) string { + re := regexp.MustCompile(`\{\{([^}]+)\}\}`) + return re.ReplaceAllStringFunc(s, func(match string) string { + inner := strings.TrimSpace(match[2 : len(match)-2]) + if strings.HasPrefix(inner, "env.") { + varName := strings.TrimPrefix(inner, "env.") + return os.Getenv(varName) + } + return match + }) +} + // provisionIfNeeded checks whether a provisioning key is present and, if so, // exchanges it for a newt ID and secret by calling the registration endpoint. // On success the config is updated in-place and flagged for saving so that @@ -158,9 +178,15 @@ func (c *Client) provisionIfNeeded() error { } baseEndpoint := strings.TrimRight(baseURL.String(), "/") + // Interpolate any {{env.VAR}} tokens in the name before sending. + name := interpolateString(c.config.Name) + reqBody := map[string]interface{}{ "provisioningKey": c.config.ProvisioningKey, } + if name != "" { + reqBody["name"] = name + } jsonData, err := json.Marshal(reqBody) if err != nil { return fmt.Errorf("failed to marshal provisioning request: %w", err) @@ -236,6 +262,7 @@ func (c *Client) provisionIfNeeded() error { c.config.ID = provResp.Data.NewtID c.config.Secret = provResp.Data.Secret c.config.ProvisioningKey = "" + c.config.Name = "" c.configNeedsSave = true // Save immediately so that if the subsequent connection attempt fails the diff --git a/websocket/types.go b/websocket/types.go index 2b32dae..195e06f 100644 --- a/websocket/types.go +++ b/websocket/types.go @@ -6,6 +6,7 @@ type Config struct { Endpoint string `json:"endpoint"` TlsClientCert string `json:"tlsClientCert"` ProvisioningKey string `json:"provisioningKey,omitempty"` + Name string `json:"name,omitempty"` } type TokenResponse struct { From 8d82460a76ef47dd14d36d115f6e14871c5b67fe Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 31 Mar 2026 17:06:07 -0700 Subject: [PATCH 7/9] Send health checks to the server on reconnect --- main.go | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/main.go b/main.go index 6ad1c2f..a5c4581 100644 --- a/main.go +++ b/main.go @@ -1820,6 +1820,30 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( } else { logger.Warn("CLIENTS WILL NOT WORK ON THIS VERSION OF NEWT WITH THIS VERSION OF PANGOLIN, PLEASE UPDATE THE SERVER TO 1.13 OR HIGHER OR DOWNGRADE NEWT") } + + sendBlueprint(client) + } else { + // Resend current health check status for all targets in case the server + // missed updates while newt was disconnected. + targets := healthMonitor.GetTargets() + if len(targets) > 0 { + healthStatuses := make(map[int]interface{}) + for id, target := range targets { + healthStatuses[id] = map[string]interface{}{ + "status": target.Status.String(), + "lastCheck": target.LastCheck.Format(time.RFC3339), + "checkCount": target.CheckCount, + "lastError": target.LastError, + "config": target.Config, + } + } + logger.Debug("Reconnected: resending health check status for %d targets", len(healthStatuses)) + if err := client.SendMessage("newt/healthcheck/status", map[string]interface{}{ + "targets": healthStatuses, + }); err != nil { + logger.Error("Failed to resend health check status on reconnect: %v", err) + } + } } // Send registration message to the server for backward compatibility @@ -1832,8 +1856,6 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( "chainId": bcChainId, }) - sendBlueprint(client) - if err != nil { logger.Error("Failed to send registration message: %v", err) return err From f4d071fe27f1c7a6c9b54e5123262b81d41ac7eb Mon Sep 17 00:00:00 2001 From: Owen Date: Thu, 2 Apr 2026 21:39:59 -0400 Subject: [PATCH 8/9] Add provisioning blueprint file --- common.go | 6 +++--- main.go | 15 ++++++++++++--- websocket/client.go | 11 +++++++++++ websocket/config.go | 1 + 4 files changed, 27 insertions(+), 6 deletions(-) diff --git a/common.go b/common.go index e215813..4e1ed00 100644 --- a/common.go +++ b/common.go @@ -540,12 +540,12 @@ func interpolateBlueprint(data []byte) []byte { }) } -func sendBlueprint(client *websocket.Client) error { - if blueprintFile == "" { +func sendBlueprint(client *websocket.Client, file string) error { + if file == "" { return nil } // try to read the blueprint file - blueprintData, err := os.ReadFile(blueprintFile) + blueprintData, err := os.ReadFile(file) if err != nil { logger.Error("Failed to read blueprint file: %v", err) } else { diff --git a/main.go b/main.go index a5c4581..d5f2a96 100644 --- a/main.go +++ b/main.go @@ -155,8 +155,9 @@ var ( region string metricsAsyncBytes bool pprofEnabled bool - blueprintFile string - noCloud bool + blueprintFile string + provisioningBlueprintFile string + noCloud bool // New mTLS configuration variables tlsClientCert string @@ -284,6 +285,7 @@ func runNewtMain(ctx context.Context) { tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT") } blueprintFile = os.Getenv("BLUEPRINT_FILE") + provisioningBlueprintFile = os.Getenv("PROVISIONING_BLUEPRINT_FILE") noCloudEnv := os.Getenv("NO_CLOUD") noCloud = noCloudEnv == "true" provisioningKey = os.Getenv("NEWT_PROVISIONING_KEY") @@ -393,6 +395,9 @@ func runNewtMain(ctx context.Context) { if blueprintFile == "" { flag.StringVar(&blueprintFile, "blueprint-file", "", "Path to blueprint file (if unset, no blueprint will be applied)") } + if provisioningBlueprintFile == "" { + flag.StringVar(&provisioningBlueprintFile, "provisioning-blueprint-file", "", "Path to blueprint file applied once after a provisioning credential exchange (if unset, no provisioning blueprint will be applied)") + } if noCloudEnv == "" { flag.BoolVar(&noCloud, "no-cloud", false, "Disable cloud failover") } @@ -1821,7 +1826,11 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( logger.Warn("CLIENTS WILL NOT WORK ON THIS VERSION OF NEWT WITH THIS VERSION OF PANGOLIN, PLEASE UPDATE THE SERVER TO 1.13 OR HIGHER OR DOWNGRADE NEWT") } - sendBlueprint(client) + sendBlueprint(client, blueprintFile) + if client.WasJustProvisioned() { + logger.Info("Provisioning detected – sending provisioning blueprint") + sendBlueprint(client, provisioningBlueprintFile) + } } else { // Resend current health check status for all targets in case the server // missed updates while newt was disconnected. diff --git a/websocket/client.go b/websocket/client.go index 49cf414..6990bd2 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -53,6 +53,7 @@ type Client struct { processingMessage bool // Flag to track if a message is currently being processed processingMux sync.RWMutex // Protects processingMessage processingWg sync.WaitGroup // WaitGroup to wait for message processing to complete + justProvisioned bool // Set to true when provisionIfNeeded exchanges a key for permanent credentials } type ClientOption func(*Client) @@ -102,6 +103,16 @@ func (c *Client) OnTokenUpdate(callback func(token string)) { c.onTokenUpdate = callback } +// WasJustProvisioned reports whether the client exchanged a provisioning key +// for permanent credentials during the most recent connection attempt. It +// consumes the flag – subsequent calls return false until provisioning occurs +// again (which, in practice, never happens once credentials are persisted). +func (c *Client) WasJustProvisioned() bool { + v := c.justProvisioned + c.justProvisioned = false + return v +} + func (c *Client) metricsContext() context.Context { c.metricsCtxMu.RLock() defer c.metricsCtxMu.RUnlock() diff --git a/websocket/config.go b/websocket/config.go index d727a41..39f1bd2 100644 --- a/websocket/config.go +++ b/websocket/config.go @@ -264,6 +264,7 @@ func (c *Client) provisionIfNeeded() error { c.config.ProvisioningKey = "" c.config.Name = "" c.configNeedsSave = true + c.justProvisioned = true // Save immediately so that if the subsequent connection attempt fails the // provisioning key is already gone from disk and the next retry uses the From 2e02c9b7a93cb9c96ebbdbf097e75e55f7b3a20b Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 3 Apr 2026 16:49:09 -0400 Subject: [PATCH 9/9] Remove files --- config.json | 4 ---- config.json.bak | 4 ---- 2 files changed, 8 deletions(-) delete mode 100644 config.json delete mode 100644 config.json.bak diff --git a/config.json b/config.json deleted file mode 100644 index fac9795..0000000 --- a/config.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "endpoint": "http://you.fosrl.io", - "provisioningKey": "spk-xt1opb0fkoqb7qb.hi44jciamqcrdaja4lvz3kp52pl3lssamp6asuyx" -} \ No newline at end of file diff --git a/config.json.bak b/config.json.bak deleted file mode 100644 index fac9795..0000000 --- a/config.json.bak +++ /dev/null @@ -1,4 +0,0 @@ -{ - "endpoint": "http://you.fosrl.io", - "provisioningKey": "spk-xt1opb0fkoqb7qb.hi44jciamqcrdaja4lvz3kp52pl3lssamp6asuyx" -} \ No newline at end of file