Files
netbird/client/internal/netflow/conntrack/conntrack.go

439 lines
10 KiB
Go

//go:build linux && !android
package conntrack
import (
"encoding/binary"
"fmt"
"net/netip"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
nfct "github.com/ti-mo/conntrack"
"github.com/ti-mo/netfilter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
nbnet "github.com/netbirdio/netbird/client/net"
)
// Tuning parameters for the conntrack event listener and its reconnect backoff.
const (
// defaultChannelSize is the buffer capacity of the channel that carries conntrack events
// from the listener to the receiver goroutine.
defaultChannelSize = 100
// reconnectInitInterval is the initial interval of the reconnect backoff.
reconnectInitInterval = 5 * time.Second
// reconnectMaxInterval is the upper bound for a single reconnect backoff interval.
reconnectMaxInterval = 5 * time.Minute
// reconnectRandomization is the randomization factor passed to the reconnect backoff.
reconnectRandomization = 0.5
)
// listener abstracts a netlink conntrack connection for testability.
// The connection returned by nfct.Dial is used as the default implementation.
type listener interface {
// Listen subscribes to the given netlink groups and delivers events on evChan.
// The returned error channel is read by the receiver goroutine to detect a failed listener.
Listen(evChan chan<- nfct.Event, numWorkers uint8, groups []netfilter.NetlinkGroup) (chan error, error)
// Close closes the underlying connection.
Close() error
}
// ConnTrack manages kernel-based conntrack events
type ConnTrack struct {
// flowLogger receives one stored event per relevant conntrack event.
flowLogger nftypes.FlowLogger
// iface supplies the interface address used by the mark-less fallbacks.
iface nftypes.IFaceMapper
// conn is the active conntrack connection; nil while no connection is open.
conn listener
// mux guards conn and started.
mux sync.Mutex
// dial opens a new conntrack connection; replaceable through WithDialer.
dial func() (listener, error)
// instanceID is the namespace from which flow IDs are derived.
instanceID uuid.UUID
// started reports whether a listening cycle is currently active.
started bool
// done carries the stop signal to the receiver goroutine; it is buffered with capacity 1.
done chan struct{}
// NOTE(review): not referenced in this part of the file; presumably records whether
// EnableAccounting changed a sysctl that RestoreAccounting must revert — confirm.
sysctlModified bool
}
// DialFunc is a constructor for netlink conntrack connections.
// It is called once by Start and again on every reconnect attempt.
type DialFunc func() (listener, error)
// Option configures a ConnTrack instance.
// Options are applied by New, in the order given, after the defaults have been set.
type Option func(*ConnTrack)
// WithDialer overrides the default netlink dialer, primarily for testing.
// A nil dial is ignored and the existing dialer is kept, so that Start never
// calls a nil function.
func WithDialer(dial DialFunc) Option {
	return func(c *ConnTrack) {
		if dial == nil {
			return
		}
		c.dial = dial
	}
}
// defaultDial opens a conntrack netlink connection with the library's default configuration.
// On failure it returns an untyped nil listener: forwarding the dial results directly would
// wrap a nil connection pointer in a non-nil interface value.
func defaultDial() (listener, error) {
	conn, err := nfct.Dial(nil)
	if err != nil {
		return nil, err
	}
	return conn, nil
}
// New builds a connection tracker that reports kernel conntrack events to flowLogger.
// The tracker starts with a fresh instance ID, the default netlink dialer and a
// one-slot stop channel; the given options are applied afterwards, in order.
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper, opts ...Option) *ConnTrack {
	tracker := new(ConnTrack)
	tracker.flowLogger = flowLogger
	tracker.iface = iface
	tracker.instanceID = uuid.New()
	tracker.dial = defaultDial
	tracker.done = make(chan struct{}, 1)

	for i := range opts {
		opts[i](tracker)
	}

	return tracker
}
// Start dials the conntrack connection, subscribes to new/destroy events and launches the
// receiver goroutine. Calling it while already started is a no-op.
//
// When enableCounters is true, EnableAccounting is called before dialing. If dialing or
// listening fails, RestoreAccounting is called and the wrapped error is returned.
func (c *ConnTrack) Start(enableCounters bool) error {
	c.mux.Lock()
	defer c.mux.Unlock()

	if c.started {
		return nil
	}

	log.Info("Starting conntrack event listening")

	if enableCounters {
		c.EnableAccounting()
	}

	nl, dialErr := c.dial()
	if dialErr != nil {
		c.RestoreAccounting()
		return fmt.Errorf("dial conntrack: %w", dialErr)
	}
	c.conn = nl

	evCh := make(chan nfct.Event, defaultChannelSize)
	groups := []netfilter.NetlinkGroup{
		netfilter.GroupCTNew,
		netfilter.GroupCTDestroy,
	}

	errCh, listenErr := nl.Listen(evCh, 1, groups)
	if listenErr != nil {
		// Undo the partial setup before reporting the failure.
		if closeErr := c.conn.Close(); closeErr != nil {
			log.Errorf("Error closing conntrack connection: %v", closeErr)
		}
		c.conn = nil
		c.RestoreAccounting()
		return fmt.Errorf("start conntrack listener: %w", listenErr)
	}

	// Discard a stop signal left over from a previous cycle, without blocking.
	select {
	case <-c.done:
	default:
	}

	c.started = true
	go c.receiverRoutine(evCh, errCh)

	return nil
}
// receiverRoutine is the event loop of a listening cycle. It forwards every event to
// handleEvent, swaps in the channels returned by handleListenerError after a listener
// failure, and exits on a stop signal or when no new channels could be obtained.
func (c *ConnTrack) receiverRoutine(events chan nfct.Event, errChan chan error) {
	for {
		select {
		case <-c.done:
			return
		case listenErr := <-errChan:
			events, errChan = c.handleListenerError(listenErr)
			if events == nil {
				// Shutdown was requested while reconnecting.
				return
			}
		case ev := <-events:
			c.handleEvent(ev)
		}
	}
}
// handleListenerError logs the listener failure, closes the failed connection and then
// tries to reconnect.
// It returns the channels of the new listener, or nil channels if shutdown was requested.
func (c *ConnTrack) handleListenerError(err error) (chan nfct.Event, chan error) {
	log.Warnf("conntrack event listener failed: %v", err)

	c.closeConn()

	events, errChan := c.reconnect()
	return events, errChan
}
// closeConn closes the current connection, if any, and clears it under the mutex.
// A close error is only logged at debug level.
func (c *ConnTrack) closeConn() {
	c.mux.Lock()
	defer c.mux.Unlock()

	if c.conn == nil {
		return
	}

	closeErr := c.conn.Close()
	if closeErr != nil {
		log.Debugf("close conntrack connection: %v", closeErr)
	}
	c.conn = nil
}
// reconnect attempts to re-establish the conntrack netlink listener with exponential backoff.
// It retries indefinitely until both dial and Listen succeed or shutdown is requested.
//
// Returns the new event and error channels on success, or nil channels if shutdown was
// requested, either through c.done during a backoff wait or because c.started was already
// cleared when the new connection became ready.
//
// NOTE(review): if a concurrent Start drains the stop signal before this goroutine reads it,
// the loop keeps running against the new cycle; fixing that needs a per-cycle stop channel.
func (c *ConnTrack) reconnect() (chan nfct.Event, chan error) {
	bo := &backoff.ExponentialBackOff{
		InitialInterval:     reconnectInitInterval,
		RandomizationFactor: reconnectRandomization,
		Multiplier:          backoff.DefaultMultiplier,
		MaxInterval:         reconnectMaxInterval,
		MaxElapsedTime:      0, // retry indefinitely
		Clock:               backoff.SystemClock,
	}
	bo.Reset()

	for {
		delay := bo.NextBackOff()
		log.Infof("reconnecting conntrack listener in %s", delay)

		// An explicit timer can be released as soon as shutdown is requested, instead of
		// staying pending for the rest of the backoff interval.
		timer := time.NewTimer(delay)
		select {
		case <-c.done:
			timer.Stop()
			// Stop and Close clear c.started in the same critical section in which they
			// signal c.done, so it is not written here. A Start that ran in the meantime
			// has set it to true for the new cycle, and clearing it again would make that
			// cycle look stopped.
			return nil, nil
		case <-timer.C:
		}

		conn, err := c.dial()
		if err != nil {
			log.Warnf("reconnect conntrack dial: %v", err)
			continue
		}

		events := make(chan nfct.Event, defaultChannelSize)
		errChan, err := conn.Listen(events, 1, []netfilter.NetlinkGroup{
			netfilter.GroupCTNew,
			netfilter.GroupCTDestroy,
		})
		if err != nil {
			log.Warnf("reconnect conntrack listen: %v", err)
			if closeErr := conn.Close(); closeErr != nil {
				log.Debugf("close conntrack connection: %v", closeErr)
			}
			continue
		}

		c.mux.Lock()
		if !c.started {
			// Stop() ran while we were reconnecting.
			c.mux.Unlock()
			if closeErr := conn.Close(); closeErr != nil {
				log.Debugf("close conntrack connection: %v", closeErr)
			}
			return nil, nil
		}
		c.conn = conn
		c.mux.Unlock()

		log.Infof("conntrack listener reconnected successfully")
		return events, errChan
	}
}
// Stop ends the current listening cycle: it signals the receiver goroutine, closes the
// connection, marks the tracker as stopped and calls RestoreAccounting.
// A close error is logged rather than returned. Calling Stop when not started is a no-op.
func (c *ConnTrack) Stop() {
	c.mux.Lock()
	defer c.mux.Unlock()

	if !c.started {
		return
	}

	log.Info("Stopping conntrack event listening")

	// Non-blocking send: a signal that is already pending is sufficient.
	select {
	case c.done <- struct{}{}:
	default:
	}

	if nl := c.conn; nl != nil {
		if closeErr := nl.Close(); closeErr != nil {
			log.Errorf("Error closing conntrack connection: %v", closeErr)
		}
		c.conn = nil
	}

	c.started = false
	c.RestoreAccounting()
}
// Close ends the current listening cycle like Stop, but returns the connection's close
// error, wrapped, instead of logging it. Calling Close when not started returns nil.
func (c *ConnTrack) Close() error {
	c.mux.Lock()
	defer c.mux.Unlock()

	if !c.started {
		return nil
	}

	// Non-blocking send: a signal that is already pending is sufficient.
	select {
	case c.done <- struct{}{}:
	default:
	}
	c.started = false

	var closeErr error
	if nl := c.conn; nl != nil {
		closeErr = nl.Close()
		c.conn = nil
	}

	c.RestoreAccounting()

	if closeErr == nil {
		return nil
	}
	return fmt.Errorf("close conntrack: %w", closeErr)
}
// handleEvent converts a conntrack event into a flow event and stores it in the flow logger.
// Events without a flow, events that are neither new nor destroy, flows with an unknown
// protocol and flows rejected by relevantFlow are dropped.
func (c *ConnTrack) handleEvent(event nfct.Event) {
	if event.Flow == nil {
		return
	}

	// Only connection creation and teardown are reported.
	eventType, eventStr := nftypes.TypeStart, "New"
	switch event.Type {
	case nfct.EventNew:
	case nfct.EventDestroy:
		eventType, eventStr = nftypes.TypeEnd, "Ended"
	default:
		return
	}

	flow := *event.Flow
	tuple := flow.TupleOrig

	proto := nftypes.Protocol(tuple.Proto.Protocol)
	if proto == nftypes.ProtocolUnknown {
		return
	}

	srcIP, dstIP := tuple.IP.SourceAddress, tuple.IP.DestinationAddress
	if !c.relevantFlow(flow.Mark, srcIP, dstIP) {
		return
	}

	// Ports apply to TCP/UDP/SCTP, type and code to ICMP; everything else stays zero.
	var (
		srcPort, dstPort   uint16
		icmpType, icmpCode uint8
	)
	switch proto {
	case nftypes.ICMP:
		icmpType, icmpCode = tuple.Proto.ICMPType, tuple.Proto.ICMPCode
	case nftypes.TCP, nftypes.UDP, nftypes.SCTP:
		srcPort, dstPort = tuple.Proto.SourcePort, tuple.Proto.DestinationPort
	}

	direction := c.inferDirection(flow.Mark, srcIP, dstIP)

	log.Tracef("%s %s %s connection: %s:%d → %s:%d", eventStr, direction, proto, srcIP, srcPort, dstIP, dstPort)

	c.flowLogger.StoreEvent(nftypes.EventFields{
		FlowID:     c.getFlowID(flow.ID),
		Type:       eventType,
		Direction:  direction,
		Protocol:   proto,
		SourceIP:   srcIP,
		DestIP:     dstIP,
		SourcePort: srcPort,
		DestPort:   dstPort,
		ICMPType:   icmpType,
		ICMPCode:   icmpCode,
		RxPackets:  c.mapRxPackets(flow, direction),
		TxPackets:  c.mapTxPackets(flow, direction),
		RxBytes:    c.mapRxBytes(flow, direction),
		TxBytes:    c.mapTxBytes(flow, direction),
	})
}
// relevantFlow reports whether a flow should be tracked: either it carries a data plane
// mark, or its source or destination lies inside the interface's network.
func (c *ConnTrack) relevantFlow(mark uint32, srcIP, dstIP netip.Addr) bool {
	if !nbnet.IsDataPlaneMark(mark) {
		// No mark rules in place: fall back to matching on the interface network.
		network := c.iface.Address().Network
		if network.Contains(srcIP) {
			return true
		}
		return network.Contains(dstIP)
	}

	return true
}
// mapRxPackets returns the received packet count for the flow.
// Ingress flows receive in the original direction; all other flows in the reply direction.
func (c *ConnTrack) mapRxPackets(flow nfct.Flow, direction nftypes.Direction) uint64 {
	rx := flow.CountersReply
	if direction == nftypes.Ingress {
		rx = flow.CountersOrig
	}
	return rx.Packets
}
// mapTxPackets returns the transmitted packet count for the flow.
// Ingress flows transmit in the reply direction; all other flows in the original direction.
func (c *ConnTrack) mapTxPackets(flow nfct.Flow, direction nftypes.Direction) uint64 {
	tx := flow.CountersOrig
	if direction == nftypes.Ingress {
		tx = flow.CountersReply
	}
	return tx.Packets
}
// mapRxBytes returns the received byte count for the flow.
// Ingress flows receive in the original direction; all other flows in the reply direction.
func (c *ConnTrack) mapRxBytes(flow nfct.Flow, direction nftypes.Direction) uint64 {
	rx := flow.CountersReply
	if direction == nftypes.Ingress {
		rx = flow.CountersOrig
	}
	return rx.Bytes
}
// mapTxBytes returns the transmitted byte count for the flow.
// Ingress flows transmit in the reply direction; all other flows in the original direction.
func (c *ConnTrack) mapTxBytes(flow nfct.Flow, direction nftypes.Direction) uint64 {
	tx := flow.CountersOrig
	if direction == nftypes.Ingress {
		tx = flow.CountersReply
	}
	return tx.Bytes
}
// getFlowID derives a flow UUID from the conntrack ID, using the tracker's instance ID as
// the SHA-1 namespace. The conntrack ID is hashed as four big-endian bytes.
func (c *ConnTrack) getFlowID(conntrackID uint32) uuid.UUID {
	idBytes := make([]byte, 4)
	binary.BigEndian.PutUint32(idBytes, conntrackID)

	return uuid.NewSHA1(c.instanceID, idBytes)
}
// inferDirection determines the direction of a flow.
// The data plane mark decides when it is set. Otherwise the flow's addresses are compared
// with the interface address and network; DirectionUnknown is returned when nothing matches.
func (c *ConnTrack) inferDirection(mark uint32, srcIP, dstIP netip.Addr) nftypes.Direction {
	switch mark {
	case nbnet.DataPlaneMarkIn:
		return nftypes.Ingress
	case nbnet.DataPlaneMarkOut:
		return nftypes.Egress
	}

	// fallback if marks are not set
	// Read the address once so that IP and Network come from the same snapshot.
	addr := c.iface.Address()

	switch {
	case addr.IP == srcIP:
		return nftypes.Egress
	case addr.IP == dstIP:
		return nftypes.Ingress
	case addr.Network.Contains(srcIP):
		// netbird network -> resource network
		return nftypes.Ingress
	case addr.Network.Contains(dstIP):
		// resource network -> netbird network
		return nftypes.Egress
	}

	return nftypes.DirectionUnknown
}