mirror of
https://github.com/fosrl/newt.git
synced 2026-03-26 20:46:41 +00:00
Saving and sending access logs pass 1
This commit is contained in:
@@ -43,6 +43,7 @@ type Target struct {
|
||||
RewriteTo string `json:"rewriteTo,omitempty"`
|
||||
DisableIcmp bool `json:"disableIcmp,omitempty"`
|
||||
PortRange []PortRange `json:"portRange,omitempty"`
|
||||
ResourceId int `json:"resourceId,omitempty"`
|
||||
}
|
||||
|
||||
type PortRange struct {
|
||||
@@ -196,6 +197,15 @@ func (s *WireGuardService) Close() {
|
||||
s.stopGetConfig = nil
|
||||
}
|
||||
|
||||
// Flush access logs before tearing down the tunnel
|
||||
if s.tnet != nil {
|
||||
if ph := s.tnet.GetProxyHandler(); ph != nil {
|
||||
if al := ph.GetAccessLogger(); al != nil {
|
||||
al.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop the direct UDP relay first
|
||||
s.StopDirectUDPRelay()
|
||||
|
||||
@@ -663,7 +673,7 @@ func (s *WireGuardService) syncTargets(desiredTargets []Target) error {
|
||||
})
|
||||
}
|
||||
|
||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
|
||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId)
|
||||
logger.Info("Added target %s -> %s during sync", target.SourcePrefix, target.DestPrefix)
|
||||
}
|
||||
}
|
||||
@@ -794,6 +804,13 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
||||
|
||||
s.TunnelIP = tunnelIP.String()
|
||||
|
||||
// Configure the access log sender to ship compressed session logs via websocket
|
||||
s.tnet.SetAccessLogSender(func(data string) error {
|
||||
return s.client.SendMessageNoLog("newt/access-log", map[string]interface{}{
|
||||
"compressed": data,
|
||||
})
|
||||
})
|
||||
|
||||
// Create WireGuard device using the shared bind
|
||||
s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger(
|
||||
device.LogLevelSilent, // Use silent logging by default - could be made configurable
|
||||
@@ -914,7 +931,7 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid CIDR %s: %v", sp, err)
|
||||
}
|
||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
|
||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId)
|
||||
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
|
||||
}
|
||||
}
|
||||
@@ -1307,7 +1324,7 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
|
||||
logger.Info("Invalid CIDR %s: %v", sp, err)
|
||||
continue
|
||||
}
|
||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
|
||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId)
|
||||
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
|
||||
}
|
||||
}
|
||||
@@ -1425,7 +1442,7 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
|
||||
logger.Info("Invalid CIDR %s: %v", sp, err)
|
||||
continue
|
||||
}
|
||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
|
||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp, target.ResourceId)
|
||||
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", sp, target.DestPrefix, target.RewriteTo, target.PortRange)
|
||||
}
|
||||
}
|
||||
|
||||
355
netstack2/access_log.go
Normal file
355
netstack2/access_log.go
Normal file
@@ -0,0 +1,355 @@
|
||||
package netstack2
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/zlib"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
)
|
||||
|
||||
const (
	// flushInterval is how often the background loop flushes completed
	// sessions to the server (see backgroundLoop / flush).
	flushInterval = 60 * time.Second

	// maxBufferedSessions is the max number of completed sessions to buffer
	// before forcing an immediate flush from EndTCPSession / EndUDPSession.
	maxBufferedSessions = 100
)
|
||||
|
||||
// SendFunc is a callback that sends compressed access log data to the server.
// The data is a base64-encoded zlib-compressed JSON array of AccessSession
// objects (see compressSessions). Returning a non-nil error causes the batch
// to be re-queued for a later flush attempt.
type SendFunc func(data string) error
|
||||
|
||||
// AccessSession represents a tracked access session through the proxy.
// Completed sessions are JSON-marshaled, zlib-compressed, and shipped to the
// server in batches (see flush and compressSessions).
type AccessSession struct {
	SessionID  string    `json:"sessionId"`  // random hex identifier from generateSessionID
	ResourceID int       `json:"resourceId"` // server-side resource this session maps to
	SourceAddr string    `json:"sourceAddr"` // source "ip:port"
	DestAddr   string    `json:"destAddr"`   // destination "ip:port"
	Protocol   string    `json:"protocol"`   // "tcp" or "udp"
	StartedAt  time.Time `json:"startedAt"`
	EndedAt    time.Time `json:"endedAt,omitempty"` // zero while the session is still active
	// NOTE(review): BytesTx/BytesRx are never written in this file — confirm
	// byte accounting is wired up elsewhere, otherwise they always report 0.
	BytesTx int64 `json:"bytesTx"`
	BytesRx int64 `json:"bytesRx"`
}
|
||||
|
||||
// udpSessionKey identifies a unique UDP "session" by src -> dst so that
// repeated datagrams between the same address pair are deduplicated into a
// single tracked session (see TrackUDPSession).
type udpSessionKey struct {
	srcAddr  string
	dstAddr  string
	protocol string // always "udp" in current callers; kept in the key for future protocols
}
|
||||
|
||||
// AccessLogger tracks access sessions for resources and periodically
// flushes completed sessions to the server via a configurable SendFunc.
// All maps and slices below are guarded by mu.
type AccessLogger struct {
	mu                sync.Mutex
	sessions          map[string]*AccessSession        // active sessions: sessionID -> session
	udpSessions       map[udpSessionKey]*AccessSession // active UDP sessions for dedup
	completedSessions []*AccessSession                 // completed sessions waiting to be flushed
	udpTimeout        time.Duration                    // configured UDP session timeout (see NewAccessLogger)
	sendFn            SendFunc                         // nil until SetSendFunc; batches are discarded while nil
	stopCh            chan struct{}                    // closed by Close to stop backgroundLoop
	flushDone         chan struct{}                    // closed after the flush goroutine exits
}
|
||||
|
||||
// NewAccessLogger creates a new access logger.
|
||||
// udpTimeout controls how long a UDP session is kept alive without traffic before being ended.
|
||||
func NewAccessLogger(udpTimeout time.Duration) *AccessLogger {
|
||||
al := &AccessLogger{
|
||||
sessions: make(map[string]*AccessSession),
|
||||
udpSessions: make(map[udpSessionKey]*AccessSession),
|
||||
completedSessions: make([]*AccessSession, 0),
|
||||
udpTimeout: udpTimeout,
|
||||
stopCh: make(chan struct{}),
|
||||
flushDone: make(chan struct{}),
|
||||
}
|
||||
go al.backgroundLoop()
|
||||
return al
|
||||
}
|
||||
|
||||
// SetSendFunc sets the callback used to send compressed access log batches
|
||||
// to the server. This can be called after construction once the websocket
|
||||
// client is available.
|
||||
func (al *AccessLogger) SetSendFunc(fn SendFunc) {
|
||||
al.mu.Lock()
|
||||
defer al.mu.Unlock()
|
||||
al.sendFn = fn
|
||||
}
|
||||
|
||||
// generateSessionID creates a random 16-character hex session identifier.
//
// The crypto/rand.Read error was previously ignored outright; although it is
// documented to always fill the buffer on supported platforms, handle the
// error explicitly so a session is never tracked under an all-zero ID.
func generateSessionID() string {
	b := make([]byte, 8)
	if _, err := rand.Read(b); err != nil {
		// Extremely unlikely; derive the ID from the current time instead of
		// returning hex of a zeroed buffer.
		n := time.Now().UnixNano()
		for i := range b {
			b[i] = byte(n >> (8 * uint(i)))
		}
	}
	return hex.EncodeToString(b)
}
|
||||
|
||||
// StartTCPSession logs the start of a TCP session and returns a session ID.
|
||||
func (al *AccessLogger) StartTCPSession(resourceID int, srcAddr, dstAddr string) string {
|
||||
sessionID := generateSessionID()
|
||||
now := time.Now()
|
||||
|
||||
session := &AccessSession{
|
||||
SessionID: sessionID,
|
||||
ResourceID: resourceID,
|
||||
SourceAddr: srcAddr,
|
||||
DestAddr: dstAddr,
|
||||
Protocol: "tcp",
|
||||
StartedAt: now,
|
||||
}
|
||||
|
||||
al.mu.Lock()
|
||||
al.sessions[sessionID] = session
|
||||
al.mu.Unlock()
|
||||
|
||||
logger.Info("ACCESS START session=%s resource=%d proto=tcp src=%s dst=%s time=%s",
|
||||
sessionID, resourceID, srcAddr, dstAddr, now.Format(time.RFC3339))
|
||||
|
||||
return sessionID
|
||||
}
|
||||
|
||||
// EndTCPSession logs the end of a TCP session and queues it for sending.
|
||||
func (al *AccessLogger) EndTCPSession(sessionID string) {
|
||||
now := time.Now()
|
||||
|
||||
al.mu.Lock()
|
||||
session, ok := al.sessions[sessionID]
|
||||
if ok {
|
||||
session.EndedAt = now
|
||||
delete(al.sessions, sessionID)
|
||||
al.completedSessions = append(al.completedSessions, session)
|
||||
}
|
||||
shouldFlush := len(al.completedSessions) >= maxBufferedSessions
|
||||
al.mu.Unlock()
|
||||
|
||||
if ok {
|
||||
duration := now.Sub(session.StartedAt)
|
||||
logger.Info("ACCESS END session=%s resource=%d proto=tcp src=%s dst=%s started=%s ended=%s duration=%s",
|
||||
sessionID, session.ResourceID, session.SourceAddr, session.DestAddr,
|
||||
session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration)
|
||||
}
|
||||
|
||||
if shouldFlush {
|
||||
al.flush()
|
||||
}
|
||||
}
|
||||
|
||||
// TrackUDPSession starts or returns an existing UDP session. Returns the session ID.
|
||||
func (al *AccessLogger) TrackUDPSession(resourceID int, srcAddr, dstAddr string) string {
|
||||
key := udpSessionKey{
|
||||
srcAddr: srcAddr,
|
||||
dstAddr: dstAddr,
|
||||
protocol: "udp",
|
||||
}
|
||||
|
||||
al.mu.Lock()
|
||||
defer al.mu.Unlock()
|
||||
|
||||
if existing, ok := al.udpSessions[key]; ok {
|
||||
return existing.SessionID
|
||||
}
|
||||
|
||||
sessionID := generateSessionID()
|
||||
now := time.Now()
|
||||
|
||||
session := &AccessSession{
|
||||
SessionID: sessionID,
|
||||
ResourceID: resourceID,
|
||||
SourceAddr: srcAddr,
|
||||
DestAddr: dstAddr,
|
||||
Protocol: "udp",
|
||||
StartedAt: now,
|
||||
}
|
||||
|
||||
al.sessions[sessionID] = session
|
||||
al.udpSessions[key] = session
|
||||
|
||||
logger.Info("ACCESS START session=%s resource=%d proto=udp src=%s dst=%s time=%s",
|
||||
sessionID, resourceID, srcAddr, dstAddr, now.Format(time.RFC3339))
|
||||
|
||||
return sessionID
|
||||
}
|
||||
|
||||
// EndUDPSession ends a UDP session and queues it for sending.
|
||||
func (al *AccessLogger) EndUDPSession(sessionID string) {
|
||||
now := time.Now()
|
||||
|
||||
al.mu.Lock()
|
||||
session, ok := al.sessions[sessionID]
|
||||
if ok {
|
||||
session.EndedAt = now
|
||||
delete(al.sessions, sessionID)
|
||||
key := udpSessionKey{
|
||||
srcAddr: session.SourceAddr,
|
||||
dstAddr: session.DestAddr,
|
||||
protocol: "udp",
|
||||
}
|
||||
delete(al.udpSessions, key)
|
||||
al.completedSessions = append(al.completedSessions, session)
|
||||
}
|
||||
shouldFlush := len(al.completedSessions) >= maxBufferedSessions
|
||||
al.mu.Unlock()
|
||||
|
||||
if ok {
|
||||
duration := now.Sub(session.StartedAt)
|
||||
logger.Info("ACCESS END session=%s resource=%d proto=udp src=%s dst=%s started=%s ended=%s duration=%s",
|
||||
sessionID, session.ResourceID, session.SourceAddr, session.DestAddr,
|
||||
session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration)
|
||||
}
|
||||
|
||||
if shouldFlush {
|
||||
al.flush()
|
||||
}
|
||||
}
|
||||
|
||||
// backgroundLoop handles periodic flushing and stale session reaping.
|
||||
func (al *AccessLogger) backgroundLoop() {
|
||||
defer close(al.flushDone)
|
||||
|
||||
flushTicker := time.NewTicker(flushInterval)
|
||||
defer flushTicker.Stop()
|
||||
|
||||
reapTicker := time.NewTicker(30 * time.Second)
|
||||
defer reapTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-al.stopCh:
|
||||
return
|
||||
case <-flushTicker.C:
|
||||
al.flush()
|
||||
case <-reapTicker.C:
|
||||
al.reapStaleSessions()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reapStaleSessions cleans up UDP sessions that were not properly ended.
|
||||
func (al *AccessLogger) reapStaleSessions() {
|
||||
al.mu.Lock()
|
||||
defer al.mu.Unlock()
|
||||
|
||||
staleThreshold := time.Now().Add(-5 * time.Minute)
|
||||
|
||||
for key, session := range al.udpSessions {
|
||||
if session.StartedAt.Before(staleThreshold) && session.EndedAt.IsZero() {
|
||||
now := time.Now()
|
||||
session.EndedAt = now
|
||||
duration := now.Sub(session.StartedAt)
|
||||
logger.Info("ACCESS END (reaped) session=%s resource=%d proto=udp src=%s dst=%s started=%s ended=%s duration=%s",
|
||||
session.SessionID, session.ResourceID, session.SourceAddr, session.DestAddr,
|
||||
session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration)
|
||||
al.completedSessions = append(al.completedSessions, session)
|
||||
delete(al.sessions, session.SessionID)
|
||||
delete(al.udpSessions, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// flush drains the completed sessions buffer, compresses with zlib, and sends via the SendFunc.
|
||||
func (al *AccessLogger) flush() {
|
||||
al.mu.Lock()
|
||||
if len(al.completedSessions) == 0 {
|
||||
al.mu.Unlock()
|
||||
return
|
||||
}
|
||||
batch := al.completedSessions
|
||||
al.completedSessions = make([]*AccessSession, 0)
|
||||
sendFn := al.sendFn
|
||||
al.mu.Unlock()
|
||||
|
||||
if sendFn == nil {
|
||||
logger.Debug("Access logger: no send function configured, discarding %d sessions", len(batch))
|
||||
return
|
||||
}
|
||||
|
||||
compressed, err := compressSessions(batch)
|
||||
if err != nil {
|
||||
logger.Error("Access logger: failed to compress %d sessions: %v", len(batch), err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := sendFn(compressed); err != nil {
|
||||
logger.Error("Access logger: failed to send %d sessions: %v", len(batch), err)
|
||||
// Re-queue the batch so we don't lose data
|
||||
al.mu.Lock()
|
||||
al.completedSessions = append(batch, al.completedSessions...)
|
||||
// Cap re-queued data to prevent unbounded growth if server is unreachable
|
||||
if len(al.completedSessions) > maxBufferedSessions*5 {
|
||||
dropped := len(al.completedSessions) - maxBufferedSessions*5
|
||||
al.completedSessions = al.completedSessions[:maxBufferedSessions*5]
|
||||
logger.Warn("Access logger: buffer overflow, dropped %d oldest sessions", dropped)
|
||||
}
|
||||
al.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Access logger: sent %d sessions to server", len(batch))
|
||||
}
|
||||
|
||||
// compressSessions JSON-encodes the sessions, compresses with zlib, and returns
|
||||
// a base64-encoded string suitable for embedding in a JSON message.
|
||||
func compressSessions(sessions []*AccessSession) (string, error) {
|
||||
jsonData, err := json.Marshal(sessions)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
w, err := zlib.NewWriterLevel(&buf, zlib.BestCompression)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err := w.Write(jsonData); err != nil {
|
||||
w.Close()
|
||||
return "", err
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return base64.StdEncoding.EncodeToString(buf.Bytes()), nil
|
||||
}
|
||||
|
||||
// Close shuts down the background loop, ends all active sessions,
// and performs one final flush to send everything to the server.
//
// A second sequential call returns immediately via the stopCh check.
// NOTE(review): two CONCURRENT first calls could both take the default branch
// and double-close stopCh (panic) — confirm Close is invoked from a single
// goroutine, or guard with sync.Once.
func (al *AccessLogger) Close() {
	// Signal the background loop to stop
	select {
	case <-al.stopCh:
		// Already closed
		return
	default:
		close(al.stopCh)
	}

	// Wait for the background loop to exit so we don't race on flush
	<-al.flushDone

	al.mu.Lock()
	now := time.Now()

	// End all active sessions and move them to the completed buffer
	for _, session := range al.sessions {
		if session.EndedAt.IsZero() {
			session.EndedAt = now
			duration := now.Sub(session.StartedAt)
			logger.Info("ACCESS END (shutdown) session=%s resource=%d proto=%s src=%s dst=%s started=%s ended=%s duration=%s",
				session.SessionID, session.ResourceID, session.Protocol, session.SourceAddr, session.DestAddr,
				session.StartedAt.Format(time.RFC3339), now.Format(time.RFC3339), duration)
			al.completedSessions = append(al.completedSessions, session)
		}
	}

	// Reset both maps: every tracked session is now either completed or gone.
	al.sessions = make(map[string]*AccessSession)
	al.udpSessions = make(map[udpSessionKey]*AccessSession)
	al.mu.Unlock()

	// Final flush to send all remaining sessions to the server
	al.flush()
}
|
||||
@@ -158,6 +158,18 @@ func (h *TCPHandler) handleTCPConn(netstackConn *gonet.TCPConn, id stack.Transpo
|
||||
|
||||
targetAddr := fmt.Sprintf("%s:%d", actualDstIP, dstPort)
|
||||
|
||||
// Look up resource ID and start access session if applicable
|
||||
var accessSessionID string
|
||||
if h.proxyHandler != nil {
|
||||
resourceId := h.proxyHandler.LookupResourceId(srcIP, dstIP, dstPort, uint8(tcp.ProtocolNumber))
|
||||
if resourceId != 0 {
|
||||
if al := h.proxyHandler.GetAccessLogger(); al != nil {
|
||||
srcAddr := fmt.Sprintf("%s:%d", srcIP, srcPort)
|
||||
accessSessionID = al.StartTCPSession(resourceId, srcAddr, targetAddr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create context with timeout for connection establishment
|
||||
ctx, cancel := context.WithTimeout(context.Background(), tcpConnectTimeout)
|
||||
defer cancel()
|
||||
@@ -167,11 +179,26 @@ func (h *TCPHandler) handleTCPConn(netstackConn *gonet.TCPConn, id stack.Transpo
|
||||
targetConn, err := d.DialContext(ctx, "tcp", targetAddr)
|
||||
if err != nil {
|
||||
logger.Info("TCP Forwarder: Failed to connect to %s: %v", targetAddr, err)
|
||||
// End access session on connection failure
|
||||
if accessSessionID != "" {
|
||||
if al := h.proxyHandler.GetAccessLogger(); al != nil {
|
||||
al.EndTCPSession(accessSessionID)
|
||||
}
|
||||
}
|
||||
// Connection failed, netstack will handle RST
|
||||
return
|
||||
}
|
||||
defer targetConn.Close()
|
||||
|
||||
// End access session when connection closes
|
||||
if accessSessionID != "" {
|
||||
defer func() {
|
||||
if al := h.proxyHandler.GetAccessLogger(); al != nil {
|
||||
al.EndTCPSession(accessSessionID)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
logger.Info("TCP Forwarder: Successfully connected to %s, starting bidirectional copy", targetAddr)
|
||||
|
||||
// Bidirectional copy between netstack and target
|
||||
@@ -280,6 +307,27 @@ func (h *UDPHandler) handleUDPConn(netstackConn *gonet.UDPConn, id stack.Transpo
|
||||
|
||||
targetAddr := fmt.Sprintf("%s:%d", actualDstIP, dstPort)
|
||||
|
||||
// Look up resource ID and start access session if applicable
|
||||
var accessSessionID string
|
||||
if h.proxyHandler != nil {
|
||||
resourceId := h.proxyHandler.LookupResourceId(srcIP, dstIP, dstPort, uint8(udp.ProtocolNumber))
|
||||
if resourceId != 0 {
|
||||
if al := h.proxyHandler.GetAccessLogger(); al != nil {
|
||||
srcAddr := fmt.Sprintf("%s:%d", srcIP, srcPort)
|
||||
accessSessionID = al.TrackUDPSession(resourceId, srcAddr, targetAddr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// End access session when UDP handler returns (timeout or error)
|
||||
if accessSessionID != "" {
|
||||
defer func() {
|
||||
if al := h.proxyHandler.GetAccessLogger(); al != nil {
|
||||
al.EndUDPSession(accessSessionID)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Resolve target address
|
||||
remoteUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr)
|
||||
if err != nil {
|
||||
|
||||
@@ -22,6 +22,12 @@ import (
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
)
|
||||
|
||||
const (
|
||||
// udpAccessSessionTimeout is how long a UDP access session stays alive without traffic
|
||||
// before being considered ended by the access logger
|
||||
udpAccessSessionTimeout = 120 * time.Second
|
||||
)
|
||||
|
||||
// PortRange represents an allowed range of ports (inclusive) with optional protocol filtering
|
||||
// Protocol can be "tcp", "udp", or "" (empty string means both protocols)
|
||||
type PortRange struct {
|
||||
@@ -46,6 +52,7 @@ type SubnetRule struct {
|
||||
DisableIcmp bool // If true, ICMP traffic is blocked for this subnet
|
||||
RewriteTo string // Optional rewrite address for DNAT - can be IP/CIDR or domain name
|
||||
PortRanges []PortRange // empty slice means all ports allowed
|
||||
ResourceId int // Optional resource ID from the server for access logging
|
||||
}
|
||||
|
||||
// GetAllRules returns a copy of all subnet rules
|
||||
@@ -111,10 +118,12 @@ type ProxyHandler struct {
|
||||
natTable map[connKey]*natState
|
||||
reverseNatTable map[reverseConnKey]*natState // Reverse lookup map for O(1) reply packet NAT
|
||||
destRewriteTable map[destKey]netip.Addr // Maps original dest to rewritten dest for handler lookups
|
||||
resourceTable map[destKey]int // Maps connection key to resource ID for access logging
|
||||
natMu sync.RWMutex
|
||||
enabled bool
|
||||
icmpReplies chan []byte // Channel for ICMP reply packets to be sent back through the tunnel
|
||||
notifiable channel.Notification // Notification handler for triggering reads
|
||||
accessLogger *AccessLogger // Access logger for tracking sessions
|
||||
}
|
||||
|
||||
// ProxyHandlerOptions configures the proxy handler
|
||||
@@ -137,7 +146,9 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
|
||||
natTable: make(map[connKey]*natState),
|
||||
reverseNatTable: make(map[reverseConnKey]*natState),
|
||||
destRewriteTable: make(map[destKey]netip.Addr),
|
||||
resourceTable: make(map[destKey]int),
|
||||
icmpReplies: make(chan []byte, 256), // Buffer for ICMP reply packets
|
||||
accessLogger: NewAccessLogger(udpAccessSessionTimeout),
|
||||
proxyEp: channel.New(1024, uint32(options.MTU), ""),
|
||||
proxyStack: stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{
|
||||
@@ -202,11 +213,11 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
|
||||
// destPrefix: The IP prefix of the destination
|
||||
// rewriteTo: Optional address to rewrite destination to - can be IP/CIDR or domain name
|
||||
// If portRanges is nil or empty, all ports are allowed for this subnet
|
||||
func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) {
|
||||
func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool, resourceId int) {
|
||||
if p == nil || !p.enabled {
|
||||
return
|
||||
}
|
||||
p.subnetLookup.AddSubnet(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp)
|
||||
p.subnetLookup.AddSubnet(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp, resourceId)
|
||||
}
|
||||
|
||||
// RemoveSubnetRule removes a subnet from the proxy handler
|
||||
@@ -225,6 +236,43 @@ func (p *ProxyHandler) GetAllRules() []SubnetRule {
|
||||
return p.subnetLookup.GetAllRules()
|
||||
}
|
||||
|
||||
// LookupResourceId looks up the resource ID for a connection
|
||||
// Returns 0 if no resource ID is associated with this connection
|
||||
func (p *ProxyHandler) LookupResourceId(srcIP, dstIP string, dstPort uint16, proto uint8) int {
|
||||
if p == nil || !p.enabled {
|
||||
return 0
|
||||
}
|
||||
|
||||
key := destKey{
|
||||
srcIP: srcIP,
|
||||
dstIP: dstIP,
|
||||
dstPort: dstPort,
|
||||
proto: proto,
|
||||
}
|
||||
|
||||
p.natMu.RLock()
|
||||
defer p.natMu.RUnlock()
|
||||
|
||||
return p.resourceTable[key]
|
||||
}
|
||||
|
||||
// GetAccessLogger returns the access logger for session tracking
|
||||
func (p *ProxyHandler) GetAccessLogger() *AccessLogger {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
return p.accessLogger
|
||||
}
|
||||
|
||||
// SetAccessLogSender configures the function used to send compressed access log
|
||||
// batches to the server. This should be called once the websocket client is available.
|
||||
func (p *ProxyHandler) SetAccessLogSender(fn SendFunc) {
|
||||
if p == nil || !p.enabled || p.accessLogger == nil {
|
||||
return
|
||||
}
|
||||
p.accessLogger.SetSendFunc(fn)
|
||||
}
|
||||
|
||||
// LookupDestinationRewrite looks up the rewritten destination for a connection
|
||||
// This is used by TCP/UDP handlers to find the actual target address
|
||||
func (p *ProxyHandler) LookupDestinationRewrite(srcIP, dstIP string, dstPort uint16, proto uint8) (netip.Addr, bool) {
|
||||
@@ -387,8 +435,22 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
|
||||
// Check if the source IP, destination IP, port, and protocol match any subnet rule
|
||||
matchedRule := p.subnetLookup.Match(srcAddr, dstAddr, dstPort, protocol)
|
||||
if matchedRule != nil {
|
||||
logger.Debug("HandleIncomingPacket: Matched rule for %s -> %s (proto=%d, port=%d)",
|
||||
srcAddr, dstAddr, protocol, dstPort)
|
||||
logger.Debug("HandleIncomingPacket: Matched rule for %s -> %s (proto=%d, port=%d, resourceId=%d)",
|
||||
srcAddr, dstAddr, protocol, dstPort, matchedRule.ResourceId)
|
||||
|
||||
// Store resource ID for connections without DNAT as well
|
||||
if matchedRule.ResourceId != 0 && matchedRule.RewriteTo == "" {
|
||||
dKey := destKey{
|
||||
srcIP: srcAddr.String(),
|
||||
dstIP: dstAddr.String(),
|
||||
dstPort: dstPort,
|
||||
proto: uint8(protocol),
|
||||
}
|
||||
p.natMu.Lock()
|
||||
p.resourceTable[dKey] = matchedRule.ResourceId
|
||||
p.natMu.Unlock()
|
||||
}
|
||||
|
||||
// Check if we need to perform DNAT
|
||||
if matchedRule.RewriteTo != "" {
|
||||
// Create connection tracking key using original destination
|
||||
@@ -420,6 +482,13 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
|
||||
proto: uint8(protocol),
|
||||
}
|
||||
|
||||
// Store resource ID for access logging if present
|
||||
if matchedRule.ResourceId != 0 {
|
||||
p.natMu.Lock()
|
||||
p.resourceTable[dKey] = matchedRule.ResourceId
|
||||
p.natMu.Unlock()
|
||||
}
|
||||
|
||||
// Check if we already have a NAT entry for this connection
|
||||
p.natMu.RLock()
|
||||
existingEntry, exists := p.natTable[key]
|
||||
@@ -720,6 +789,11 @@ func (p *ProxyHandler) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shut down access logger
|
||||
if p.accessLogger != nil {
|
||||
p.accessLogger.Close()
|
||||
}
|
||||
|
||||
// Close ICMP replies channel
|
||||
if p.icmpReplies != nil {
|
||||
close(p.icmpReplies)
|
||||
|
||||
@@ -47,7 +47,7 @@ func prefixEqual(a, b netip.Prefix) bool {
|
||||
// AddSubnet adds a subnet rule with source and destination prefixes and optional port restrictions
|
||||
// If portRanges is nil or empty, all ports are allowed for this subnet
|
||||
// rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com")
|
||||
func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) {
|
||||
func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool, resourceId int) {
|
||||
sl.mu.Lock()
|
||||
defer sl.mu.Unlock()
|
||||
|
||||
@@ -57,6 +57,7 @@ func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewrite
|
||||
DisableIcmp: disableIcmp,
|
||||
RewriteTo: rewriteTo,
|
||||
PortRanges: portRanges,
|
||||
ResourceId: resourceId,
|
||||
}
|
||||
|
||||
// Canonicalize source prefix to handle host bits correctly
|
||||
|
||||
@@ -354,10 +354,10 @@ func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
|
||||
// AddProxySubnetRule adds a subnet rule to the proxy handler
|
||||
// If portRanges is nil or empty, all ports are allowed for this subnet
|
||||
// rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com")
|
||||
func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) {
|
||||
func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool, resourceId int) {
|
||||
tun := (*netTun)(net)
|
||||
if tun.proxyHandler != nil {
|
||||
tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp)
|
||||
tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp, resourceId)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -385,6 +385,15 @@ func (net *Net) GetProxyHandler() *ProxyHandler {
|
||||
return tun.proxyHandler
|
||||
}
|
||||
|
||||
// SetAccessLogSender configures the function used to send compressed access log
|
||||
// batches to the server. This should be called once the websocket client is available.
|
||||
func (net *Net) SetAccessLogSender(fn SendFunc) {
|
||||
tun := (*netTun)(net)
|
||||
if tun.proxyHandler != nil {
|
||||
tun.proxyHandler.SetAccessLogSender(fn)
|
||||
}
|
||||
}
|
||||
|
||||
type PingConn struct {
|
||||
laddr PingAddr
|
||||
raddr PingAddr
|
||||
|
||||
Reference in New Issue
Block a user