diff --git a/clients/clients.go b/clients/clients.go index 4c64dbd..9223262 100644 --- a/clients/clients.go +++ b/clients/clients.go @@ -43,6 +43,7 @@ type Target struct { RewriteTo string `json:"rewriteTo,omitempty"` DisableIcmp bool `json:"disableIcmp,omitempty"` PortRange []PortRange `json:"portRange,omitempty"` + ResourceId int `json:"resourceId,omitempty"` } type PortRange struct { @@ -196,6 +197,15 @@ func (s *WireGuardService) Close() { s.stopGetConfig = nil } + // Flush access logs before tearing down the tunnel + if s.tnet != nil { + if ph := s.tnet.GetProxyHandler(); ph != nil { + if al := ph.GetAccessLogger(); al != nil { + al.Close() + } + } + } + // Stop the direct UDP relay first s.StopDirectUDPRelay() @@ -663,7 +673,7 @@ func (s *WireGuardService) syncTargets(desiredTargets []Target) error { }) } - s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) + s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId) logger.Info("Added target %s -> %s during sync", target.SourcePrefix, target.DestPrefix) } } @@ -794,6 +804,13 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { s.TunnelIP = tunnelIP.String() + // Configure the access log sender to ship compressed session logs via websocket + s.tnet.SetAccessLogSender(func(data string) error { + return s.client.SendMessageNoLog("newt/access-log", map[string]interface{}{ + "compressed": data, + }) + }) + // Create WireGuard device using the shared bind s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger( device.LogLevelSilent, // Use silent logging by default - could be made configurable @@ -914,7 +931,7 @@ func (s *WireGuardService) ensureTargets(targets []Target) error { if err != nil { return fmt.Errorf("invalid CIDR %s: %v", sp, err) } - s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) + s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId) logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange) } } @@ -1307,7 +1324,7 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) { logger.Info("Invalid CIDR %s: %v", sp, err) continue } - s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) + s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId) logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange) } } @@ -1425,7 +1442,7 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) { logger.Info("Invalid CIDR %s: %v", sp, err) continue } - s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) + s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId) logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange) } } diff --git a/netstack2/access_log.go b/netstack2/access_log.go new file mode 100644 index 0000000..ab0db78 --- /dev/null +++ b/netstack2/access_log.go @@ -0,0 +1,355 @@ +package netstack2 + +import ( + "bytes" + "compress/zlib" + "crypto/rand" + "encoding/base64" + "encoding/hex" + "encoding/json" + "sync" + "time" + + "github.com/fosrl/newt/logger" +) + +const ( + // flushInterval is how often the access logger flushes completed sessions to the server + flushInterval = 60 * time.Second + + // maxBufferedSessions is the max number of completed sessions to buffer before forcing a flush + maxBufferedSessions = 100 +) + +// SendFunc is a callback that sends compressed access log data to the server. +// The data is a base64-encoded zlib-compressed JSON array of AccessSession objects. +type SendFunc func(data string) error + +// AccessSession represents a tracked access session through the proxy +type AccessSession struct { + SessionID string `json:"sessionId"` + ResourceID int `json:"resourceId"` + SourceAddr string `json:"sourceAddr"` + DestAddr string `json:"destAddr"` + Protocol string `json:"protocol"` + StartedAt time.Time `json:"startedAt"` + EndedAt time.Time `json:"endedAt,omitempty"` + BytesTx int64 `json:"bytesTx"` + BytesRx int64 `json:"bytesRx"` +} + +// udpSessionKey identifies a unique UDP "session" by src -> dst +type udpSessionKey struct { + srcAddr string + dstAddr string + protocol string +} + +// AccessLogger tracks access sessions for resources and periodically +// flushes completed sessions to the server via a configurable SendFunc. +type AccessLogger struct { + mu sync.Mutex + sessions map[string]*AccessSession // active sessions: sessionID -> session + udpSessions map[udpSessionKey]*AccessSession // active UDP sessions for dedup + completedSessions []*AccessSession // completed sessions waiting to be flushed + udpTimeout time.Duration + sendFn SendFunc + stopCh chan struct{} + flushDone chan struct{} // closed after the flush goroutine exits +} + +// NewAccessLogger creates a new access logger. +// udpTimeout controls how long a UDP session is kept alive without traffic before being ended. +func NewAccessLogger(udpTimeout time.Duration) *AccessLogger { + al := &AccessLogger{ + sessions: make(map[string]*AccessSession), + udpSessions: make(map[udpSessionKey]*AccessSession), + completedSessions: make([]*AccessSession, 0), + udpTimeout: udpTimeout, + stopCh: make(chan struct{}), + flushDone: make(chan struct{}), + } + go al.backgroundLoop() + return al +} + +// SetSendFunc sets the callback used to send compressed access log batches +// to the server. This can be called after construction once the websocket +// client is available. +func (al *AccessLogger) SetSendFunc(fn SendFunc) { + al.mu.Lock() + defer al.mu.Unlock() + al.sendFn = fn +} + +// generateSessionID creates a random session identifier +func generateSessionID() string { + b := make([]byte, 8) + rand.Read(b) + return hex.EncodeToString(b) +} + +// StartTCPSession logs the start of a TCP session and returns a session ID. +func (al *AccessLogger) StartTCPSession(resourceID int, srcAddr, dstAddr string) string { + sessionID := generateSessionID() + now := time.Now() + + session := &AccessSession{ + SessionID: sessionID, + ResourceID: resourceID, + SourceAddr: srcAddr, + DestAddr: dstAddr, + Protocol: "tcp", + StartedAt: now, + } + + al.mu.Lock() + al.sessions[sessionID] = session + al.mu.Unlock() + + logger.Info("ACCESS START session=%s resource=%d proto=tcp src=%s dst=%s time=%s", + sessionID, resourceID, srcAddr, dstAddr, now.Format(time.RFC3339)) + + return sessionID +} + +// EndTCPSession logs the end of a TCP session and queues it for sending. +func (al *AccessLogger) EndTCPSession(sessionID string) { + now := time.Now() + + al.mu.Lock() + session, ok := al.sessions[sessionID] + if ok { + session.EndedAt = now + delete(al.sessions, sessionID) + al.completedSessions = append(al.completedSessions, session) + } + shouldFlush := len(al.completedSessions) >= maxBufferedSessions + al.mu.Unlock() + + if ok { + duration := now.Sub(session.StartedAt) + logger.Info("ACCESS END session=%s resource=%d proto=tcp src=%s dst=%s started=%s ended=%s duration=%s", + sessionID, session.ResourceID, session.SourceAddr, session.DestAddr, + session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration) + } + + if shouldFlush { + al.flush() + } +} + +// TrackUDPSession starts or returns an existing UDP session. Returns the session ID. +func (al *AccessLogger) TrackUDPSession(resourceID int, srcAddr, dstAddr string) string { + key := udpSessionKey{ + srcAddr: srcAddr, + dstAddr: dstAddr, + protocol: "udp", + } + + al.mu.Lock() + defer al.mu.Unlock() + + if existing, ok := al.udpSessions[key]; ok { + return existing.SessionID + } + + sessionID := generateSessionID() + now := time.Now() + + session := &AccessSession{ + SessionID: sessionID, + ResourceID: resourceID, + SourceAddr: srcAddr, + DestAddr: dstAddr, + Protocol: "udp", + StartedAt: now, + } + + al.sessions[sessionID] = session + al.udpSessions[key] = session + + logger.Info("ACCESS START session=%s resource=%d proto=udp src=%s dst=%s time=%s", + sessionID, resourceID, srcAddr, dstAddr, now.Format(time.RFC3339)) + + return sessionID +} + +// EndUDPSession ends a UDP session and queues it for sending. +func (al *AccessLogger) EndUDPSession(sessionID string) { + now := time.Now() + + al.mu.Lock() + session, ok := al.sessions[sessionID] + if ok { + session.EndedAt = now + delete(al.sessions, sessionID) + key := udpSessionKey{ + srcAddr: session.SourceAddr, + dstAddr: session.DestAddr, + protocol: "udp", + } + delete(al.udpSessions, key) + al.completedSessions = append(al.completedSessions, session) + } + shouldFlush := len(al.completedSessions) >= maxBufferedSessions + al.mu.Unlock() + + if ok { + duration := now.Sub(session.StartedAt) + logger.Info("ACCESS END session=%s resource=%d proto=udp src=%s dst=%s started=%s ended=%s duration=%s", + sessionID, session.ResourceID, session.SourceAddr, session.DestAddr, + session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration) + } + + if shouldFlush { + al.flush() + } +} + +// backgroundLoop handles periodic flushing and stale session reaping. +func (al *AccessLogger) backgroundLoop() { + defer close(al.flushDone) + + flushTicker := time.NewTicker(flushInterval) + defer flushTicker.Stop() + + reapTicker := time.NewTicker(30 * time.Second) + defer reapTicker.Stop() + + for { + select { + case <-al.stopCh: + return + case <-flushTicker.C: + al.flush() + case <-reapTicker.C: + al.reapStaleSessions() + } + } +} + +// reapStaleSessions cleans up UDP sessions that were not properly ended. +func (al *AccessLogger) reapStaleSessions() { + al.mu.Lock() + defer al.mu.Unlock() + + staleThreshold := time.Now().Add(-5 * time.Minute) + + for key, session := range al.udpSessions { + if session.StartedAt.Before(staleThreshold) && session.EndedAt.IsZero() { + now := time.Now() + session.EndedAt = now + duration := now.Sub(session.StartedAt) + logger.Info("ACCESS END (reaped) session=%s resource=%d proto=udp src=%s dst=%s started=%s ended=%s duration=%s", + session.SessionID, session.ResourceID, session.SourceAddr, session.DestAddr, + session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration) + al.completedSessions = append(al.completedSessions, session) + delete(al.sessions, session.SessionID) + delete(al.udpSessions, key) + } + } +} + +// flush drains the completed sessions buffer, compresses with zlib, and sends via the SendFunc. +func (al *AccessLogger) flush() { + al.mu.Lock() + if len(al.completedSessions) == 0 { + al.mu.Unlock() + return + } + batch := al.completedSessions + al.completedSessions = make([]*AccessSession, 0) + sendFn := al.sendFn + al.mu.Unlock() + + if sendFn == nil { + logger.Debug("Access logger: no send function configured, discarding %d sessions", len(batch)) + return + } + + compressed, err := compressSessions(batch) + if err != nil { + logger.Error("Access logger: failed to compress %d sessions: %v", len(batch), err) + return + } + + if err := sendFn(compressed); err != nil { + logger.Error("Access logger: failed to send %d sessions: %v", len(batch), err) + // Re-queue the batch so we don't lose data + al.mu.Lock() + al.completedSessions = append(batch, al.completedSessions...) + // Cap re-queued data to prevent unbounded growth if server is unreachable + if len(al.completedSessions) > maxBufferedSessions*5 { + dropped := len(al.completedSessions) - maxBufferedSessions*5 + al.completedSessions = al.completedSessions[:maxBufferedSessions*5] + logger.Warn("Access logger: buffer overflow, dropped %d oldest sessions", dropped) + } + al.mu.Unlock() + return + } + + logger.Info("Access logger: sent %d sessions to server", len(batch)) +} + +// compressSessions JSON-encodes the sessions, compresses with zlib, and returns +// a base64-encoded string suitable for embedding in a JSON message. +func compressSessions(sessions []*AccessSession) (string, error) { + jsonData, err := json.Marshal(sessions) + if err != nil { + return "", err + } + + var buf bytes.Buffer + w, err := zlib.NewWriterLevel(&buf, zlib.BestCompression) + if err != nil { + return "", err + } + if _, err := w.Write(jsonData); err != nil { + w.Close() + return "", err + } + if err := w.Close(); err != nil { + return "", err + } + + return base64.StdEncoding.EncodeToString(buf.Bytes()), nil +} + +// Close shuts down the background loop, ends all active sessions, +// and performs one final flush to send everything to the server. +func (al *AccessLogger) Close() { + // Signal the background loop to stop + select { + case <-al.stopCh: + // Already closed + return + default: + close(al.stopCh) + } + + // Wait for the background loop to exit so we don't race on flush + <-al.flushDone + + al.mu.Lock() + now := time.Now() + + // End all active sessions and move them to the completed buffer + for _, session := range al.sessions { + if session.EndedAt.IsZero() { + session.EndedAt = now + duration := now.Sub(session.StartedAt) + logger.Info("ACCESS END (shutdown) session=%s resource=%d proto=%s src=%s dst=%s started=%s ended=%s duration=%s", + session.SessionID, session.ResourceID, session.Protocol, session.SourceAddr, session.DestAddr, + session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration) + al.completedSessions = append(al.completedSessions, session) + } + } + + al.sessions = make(map[string]*AccessSession) + al.udpSessions = make(map[udpSessionKey]*AccessSession) + al.mu.Unlock() + + // Final flush to send all remaining sessions to the server + al.flush() +} \ No newline at end of file diff --git a/netstack2/handlers.go b/netstack2/handlers.go index 75c58b2..07c235f 100644 --- a/netstack2/handlers.go +++ b/netstack2/handlers.go @@ -158,6 +158,18 @@ func (h *TCPHandler) handleTCPConn(netstackConn *gonet.TCPConn, id stack.Transpo targetAddr := fmt.Sprintf("%s:%d", actualDstIP, dstPort) + // Look up resource ID and start access session if applicable + var accessSessionID string + if h.proxyHandler != nil { + resourceId := h.proxyHandler.LookupResourceId(srcIP, dstIP, dstPort, uint8(tcp.ProtocolNumber)) + if resourceId != 0 { + if al := h.proxyHandler.GetAccessLogger(); al != nil { + srcAddr := fmt.Sprintf("%s:%d", srcIP, srcPort) + accessSessionID = al.StartTCPSession(resourceId, srcAddr, targetAddr) + } + } + } + // Create context with timeout for connection establishment ctx, cancel := context.WithTimeout(context.Background(), tcpConnectTimeout) defer cancel() @@ -167,11 +179,26 @@ func (h *TCPHandler) handleTCPConn(netstackConn *gonet.TCPConn, id stack.Transpo targetConn, err := d.DialContext(ctx, "tcp", targetAddr) if err != nil { logger.Info("TCP Forwarder: Failed to connect to %s: %v", targetAddr, err) + // End access session on connection failure + if accessSessionID != "" { + if al := h.proxyHandler.GetAccessLogger(); al != nil { + al.EndTCPSession(accessSessionID) + } + } // Connection failed, netstack will handle RST return } defer targetConn.Close() + // End access session when connection closes + if accessSessionID != "" { + defer func() { + if al := h.proxyHandler.GetAccessLogger(); al != nil { + al.EndTCPSession(accessSessionID) + } + }() + } + logger.Info("TCP Forwarder: Successfully connected to %s, starting bidirectional copy", targetAddr) // Bidirectional copy between netstack and target @@ -280,6 +307,27 @@ func (h *UDPHandler) handleUDPConn(netstackConn *gonet.UDPConn, id stack.Transpo targetAddr := fmt.Sprintf("%s:%d", actualDstIP, dstPort) + // Look up resource ID and start access session if applicable + var accessSessionID string + if h.proxyHandler != nil { + resourceId := h.proxyHandler.LookupResourceId(srcIP, dstIP, dstPort, uint8(udp.ProtocolNumber)) + if resourceId != 0 { + if al := h.proxyHandler.GetAccessLogger(); al != nil { + srcAddr := fmt.Sprintf("%s:%d", srcIP, srcPort) + accessSessionID = al.TrackUDPSession(resourceId, srcAddr, targetAddr) + } + } + } + + // End access session when UDP handler returns (timeout or error) + if accessSessionID != "" { + defer func() { + if al := h.proxyHandler.GetAccessLogger(); al != nil { + al.EndUDPSession(accessSessionID) + } + }() + } + // Resolve target address remoteUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr) if err != nil { diff --git a/netstack2/proxy.go b/netstack2/proxy.go index 1b34818..e383fc0 100644 --- a/netstack2/proxy.go +++ b/netstack2/proxy.go @@ -22,6 +22,12 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) +const ( + // udpAccessSessionTimeout is how long a UDP access session stays alive without traffic + // before being considered ended by the access logger + udpAccessSessionTimeout = 120 * time.Second +) + // PortRange represents an allowed range of ports (inclusive) with optional protocol filtering // Protocol can be "tcp", "udp", or "" (empty string means both protocols) type PortRange struct { @@ -46,6 +52,7 @@ type SubnetRule struct { DisableIcmp bool // If true, ICMP traffic is blocked for this subnet RewriteTo string // Optional rewrite address for DNAT - can be IP/CIDR or domain name PortRanges []PortRange // empty slice means all ports allowed + ResourceId int // Optional resource ID from the server for access logging } // GetAllRules returns a copy of all subnet rules @@ -111,10 +118,12 @@ type ProxyHandler struct { natTable map[connKey]*natState reverseNatTable map[reverseConnKey]*natState // Reverse lookup map for O(1) reply packet NAT destRewriteTable map[destKey]netip.Addr // Maps original dest to rewritten dest for handler lookups + resourceTable map[destKey]int // Maps connection key to resource ID for access logging natMu sync.RWMutex enabled bool icmpReplies chan []byte // Channel for ICMP reply packets to be sent back through the tunnel notifiable channel.Notification // Notification handler for triggering reads + accessLogger *AccessLogger // Access logger for tracking sessions } // ProxyHandlerOptions configures the proxy handler @@ -137,7 +146,9 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) { natTable: make(map[connKey]*natState), reverseNatTable: make(map[reverseConnKey]*natState), destRewriteTable: make(map[destKey]netip.Addr), + resourceTable: make(map[destKey]int), icmpReplies: make(chan []byte, 256), // Buffer for ICMP reply packets + accessLogger: NewAccessLogger(udpAccessSessionTimeout), proxyEp: channel.New(1024, uint32(options.MTU), ""), proxyStack: stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ @@ -202,11 +213,11 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) { // destPrefix: The IP prefix of the destination // rewriteTo: Optional address to rewrite destination to - can be IP/CIDR or domain name // If portRanges is nil or empty, all ports are allowed for this subnet -func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) { +func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool, resourceId int) { if p == nil || !p.enabled { return } - p.subnetLookup.AddSubnet(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp) + p.subnetLookup.AddSubnet(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp, resourceId) } // RemoveSubnetRule removes a subnet from the proxy handler @@ -225,6 +236,43 @@ func (p *ProxyHandler) GetAllRules() []SubnetRule { return p.subnetLookup.GetAllRules() } +// LookupResourceId looks up the resource ID for a connection +// Returns 0 if no resource ID is associated with this connection +func (p *ProxyHandler) LookupResourceId(srcIP, dstIP string, dstPort uint16, proto uint8) int { + if p == nil || !p.enabled { + return 0 + } + + key := destKey{ + srcIP: srcIP, + dstIP: dstIP, + dstPort: dstPort, + proto: proto, + } + + p.natMu.RLock() + defer p.natMu.RUnlock() + + return p.resourceTable[key] +} + +// GetAccessLogger returns the access logger for session tracking +func (p *ProxyHandler) GetAccessLogger() *AccessLogger { + if p == nil { + return nil + } + return p.accessLogger +} + +// SetAccessLogSender configures the function used to send compressed access log +// batches to the server. This should be called once the websocket client is available. +func (p *ProxyHandler) SetAccessLogSender(fn SendFunc) { + if p == nil || !p.enabled || p.accessLogger == nil { + return + } + p.accessLogger.SetSendFunc(fn) +} + // LookupDestinationRewrite looks up the rewritten destination for a connection // This is used by TCP/UDP handlers to find the actual target address func (p *ProxyHandler) LookupDestinationRewrite(srcIP, dstIP string, dstPort uint16, proto uint8) (netip.Addr, bool) { @@ -387,8 +435,22 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool { // Check if the source IP, destination IP, port, and protocol match any subnet rule matchedRule := p.subnetLookup.Match(srcAddr, dstAddr, dstPort, protocol) if matchedRule != nil { - logger.Debug("HandleIncomingPacket: Matched rule for %s -> %s (proto=%d, port=%d)", - srcAddr, dstAddr, protocol, dstPort) + logger.Debug("HandleIncomingPacket: Matched rule for %s -> %s (proto=%d, port=%d, resourceId=%d)", + srcAddr, dstAddr, protocol, dstPort, matchedRule.ResourceId) + + // Store resource ID for connections without DNAT as well + if matchedRule.ResourceId != 0 && matchedRule.RewriteTo == "" { + dKey := destKey{ + srcIP: srcAddr.String(), + dstIP: dstAddr.String(), + dstPort: dstPort, + proto: uint8(protocol), + } + p.natMu.Lock() + p.resourceTable[dKey] = matchedRule.ResourceId + p.natMu.Unlock() + } + // Check if we need to perform DNAT if matchedRule.RewriteTo != "" { // Create connection tracking key using original destination @@ -420,6 +482,13 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool { proto: uint8(protocol), } + // Store resource ID for access logging if present + if matchedRule.ResourceId != 0 { + p.natMu.Lock() + p.resourceTable[dKey] = matchedRule.ResourceId + p.natMu.Unlock() + } + // Check if we already have a NAT entry for this connection p.natMu.RLock() existingEntry, exists := p.natTable[key] @@ -720,6 +789,11 @@ func (p *ProxyHandler) Close() error { return nil } + // Shut down access logger + if p.accessLogger != nil { + p.accessLogger.Close() + } + // Close ICMP replies channel if p.icmpReplies != nil { close(p.icmpReplies) diff --git a/netstack2/subnet_lookup.go b/netstack2/subnet_lookup.go index c6ad0d5..317f85c 100644 --- a/netstack2/subnet_lookup.go +++ b/netstack2/subnet_lookup.go @@ -47,7 +47,7 @@ func prefixEqual(a, b netip.Prefix) bool { // AddSubnet adds a subnet rule with source and destination prefixes and optional port restrictions // If portRanges is nil or empty, all ports are allowed for this subnet // rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com") -func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) { +func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool, resourceId int) { sl.mu.Lock() defer sl.mu.Unlock() @@ -57,6 +57,7 @@ func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewrite DisableIcmp: disableIcmp, RewriteTo: rewriteTo, PortRanges: portRanges, + ResourceId: resourceId, } // Canonicalize source prefix to handle host bits correctly diff --git a/netstack2/tun.go b/netstack2/tun.go index b00faea..3183c36 100644 --- a/netstack2/tun.go +++ b/netstack2/tun.go @@ -354,10 +354,10 @@ func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) { // AddProxySubnetRule adds a subnet rule to the proxy handler // If portRanges is nil or empty, all ports are allowed for this subnet // rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com") -func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) { +func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool, resourceId int) { tun := (*netTun)(net) if tun.proxyHandler != nil { - tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp) + tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp, resourceId) } } @@ -385,6 +385,15 @@ func (net *Net) GetProxyHandler() *ProxyHandler { return tun.proxyHandler } +// SetAccessLogSender configures the function used to send compressed access log +// batches to the server. This should be called once the websocket client is available. +func (net *Net) SetAccessLogSender(fn SendFunc) { + tun := (*netTun)(net) + if tun.proxyHandler != nil { + tun.proxyHandler.SetAccessLogSender(fn) + } +} + type PingConn struct { laddr PingAddr raddr PingAddr