Add transparent proxy inspection engine with envoy sidecar support

This commit is contained in:
Viktor Liu
2026-04-11 18:07:46 +02:00
parent 5259e5df51
commit afbddae472
65 changed files with 10428 additions and 763 deletions

2
.gitignore vendored
View File

@@ -33,3 +33,5 @@ infrastructure_files/setup-*.env
vendor/
/netbird
client/netbird-electron/
management/server/types/testdata/comparison/
management/server/types/testdata/*.json

View File

@@ -364,6 +364,28 @@ func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
return nil
}
// AddTProxyRule adds TPROXY redirect rules for the transparent proxy.
func (m *Manager) AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddTProxyRule(ruleID, sources, dstPorts, redirectPort)
}
// RemoveTProxyRule removes TPROXY redirect rules by ID.
func (m *Manager) RemoveTProxyRule(ruleID string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveTProxyRule(ruleID)
}
// AddUDPInspectionHook is a no-op for iptables (kernel-mode firewall has no userspace packet hooks).
func (m *Manager) AddUDPInspectionHook(_ uint16, _ func([]byte) bool) string { return "" }
// RemoveUDPInspectionHook is a no-op for iptables.
func (m *Manager) RemoveUDPInspectionHook(_ string) {}
func (m *Manager) initNoTrackChain() error {
if err := m.cleanupNoTrackChain(); err != nil {
log.Debugf("cleanup notrack chain: %v", err)

View File

@@ -89,6 +89,8 @@ type router struct {
stateManager *statemanager.Manager
ipFwdState *ipfwdstate.IPForwardingState
tproxyRules []tproxyRuleEntry
}
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint16) (*router, error) {
@@ -1109,3 +1111,92 @@ func (r *router) addPrefixToIPSet(name string, prefix netip.Prefix) error {
func (r *router) destroyIPSet(name string) error {
return ipset.Destroy(name)
}
// AddTProxyRule adds iptables nat PREROUTING REDIRECT rules for transparent proxy interception.
// Traffic from sources on dstPorts arriving on the WG interface is redirected
// to the transparent proxy listener on redirectPort.
func (r *router) AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error {
portStr := fmt.Sprintf("%d", redirectPort)
for _, proto := range []string{"tcp", "udp"} {
srcSpecs := r.buildSourceSpecs(sources)
for _, srcSpec := range srcSpecs {
if len(dstPorts) == 0 {
rule := append(srcSpec,
"-i", r.wgIface.Name(),
"-p", proto,
"-j", "REDIRECT",
"--to-ports", portStr,
)
if err := r.iptablesClient.AppendUnique(tableNat, chainRTRDR, rule...); err != nil {
return fmt.Errorf("add redirect rule %s/%s: %w", ruleID, proto, err)
}
r.tproxyRules = append(r.tproxyRules, tproxyRuleEntry{
ruleID: ruleID,
table: tableNat,
chain: chainRTRDR,
spec: rule,
})
} else {
for _, port := range dstPorts {
rule := append(srcSpec,
"-i", r.wgIface.Name(),
"-p", proto,
"--dport", fmt.Sprintf("%d", port),
"-j", "REDIRECT",
"--to-ports", portStr,
)
if err := r.iptablesClient.AppendUnique(tableNat, chainRTRDR, rule...); err != nil {
return fmt.Errorf("add redirect rule %s/%s/%d: %w", ruleID, proto, port, err)
}
r.tproxyRules = append(r.tproxyRules, tproxyRuleEntry{
ruleID: ruleID,
table: tableNat,
chain: chainRTRDR,
spec: rule,
})
}
}
}
}
return nil
}
// RemoveTProxyRule removes all iptables REDIRECT rules for the given ruleID.
func (r *router) RemoveTProxyRule(ruleID string) error {
var remaining []tproxyRuleEntry
for _, entry := range r.tproxyRules {
if entry.ruleID != ruleID {
remaining = append(remaining, entry)
continue
}
if err := r.iptablesClient.DeleteIfExists(entry.table, entry.chain, entry.spec...); err != nil {
log.Debugf("remove tproxy rule %s: %v", ruleID, err)
}
}
r.tproxyRules = remaining
return nil
}
type tproxyRuleEntry struct {
ruleID string
table string
chain string
spec []string
}
func (r *router) buildSourceSpecs(sources []netip.Prefix) [][]string {
if len(sources) == 0 {
return [][]string{{}} // empty spec = match any source
}
specs := make([][]string, 0, len(sources))
for _, src := range sources {
specs = append(specs, []string{"-s", src.String()})
}
return specs
}

View File

@@ -180,6 +180,22 @@ type Manager interface {
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
// This prevents conntrack from interfering with WireGuard proxy communication.
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
// AddTProxyRule adds TPROXY redirect rules for specific source CIDRs and destination ports.
// Traffic from sources on dstPorts is redirected to the transparent proxy on redirectPort.
// Empty dstPorts means redirect all ports.
AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error
// RemoveTProxyRule removes TPROXY redirect rules by ID.
RemoveTProxyRule(ruleID string) error
// AddUDPInspectionHook registers a hook that inspects UDP packets before forwarding.
// The hook receives the raw packet and returns true to drop it.
// Used for QUIC SNI-based blocking. Returns a hook ID for removal.
AddUDPInspectionHook(dstPort uint16, hook func(packet []byte) bool) string
// RemoveUDPInspectionHook removes a previously registered inspection hook.
RemoveUDPInspectionHook(hookID string)
}
func GenKey(format string, pair RouterPair) string {

View File

@@ -482,6 +482,28 @@ func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
return nil
}
// AddTProxyRule adds TPROXY redirect rules for the transparent proxy.
func (m *Manager) AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddTProxyRule(ruleID, sources, dstPorts, redirectPort)
}
// RemoveTProxyRule removes TPROXY redirect rules by ID.
func (m *Manager) RemoveTProxyRule(ruleID string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveTProxyRule(ruleID)
}
// AddUDPInspectionHook is a no-op for nftables (kernel-mode firewall has no userspace packet hooks).
func (m *Manager) AddUDPInspectionHook(_ uint16, _ func([]byte) bool) string { return "" }
// RemoveUDPInspectionHook is a no-op for nftables.
func (m *Manager) RemoveUDPInspectionHook(_ string) {}
func (m *Manager) initNoTrackChains(table *nftables.Table) error {
m.notrackOutputChain = m.rConn.AddChain(&nftables.Chain{
Name: chainNameRawOutput,

View File

@@ -77,6 +77,7 @@ type router struct {
ipFwdState *ipfwdstate.IPForwardingState
legacyManagement bool
mtu uint16
}
func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*router, error) {
@@ -2137,3 +2138,227 @@ func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any
},
}, nil
}
// AddTProxyRule adds nftables TPROXY redirect rules in the mangle prerouting chain.
// Traffic from sources on dstPorts arriving on the WG interface is redirected to
// the transparent proxy listener on redirectPort.
// Separate rules are created for TCP and UDP protocols.
func (r *router) AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
// Use the nat redirect chain for DNAT rules.
// TPROXY doesn't work on WG kernel interfaces (socket assignment silently fails),
// so we use DNAT to 127.0.0.1:proxy_port instead. The proxy reads the original
// destination via SO_ORIGINAL_DST (conntrack).
chain := r.chains[chainNameRoutingRdr]
if chain == nil {
return fmt.Errorf("nat redirect chain not initialized")
}
for _, proto := range []uint8{unix.IPPROTO_TCP, unix.IPPROTO_UDP} {
protoName := "tcp"
if proto == unix.IPPROTO_UDP {
protoName = "udp"
}
ruleKey := fmt.Sprintf("tproxy-%s-%s", ruleID, protoName)
if existing, ok := r.rules[ruleKey]; ok && existing.Handle != 0 {
if err := r.decrementSetCounter(existing); err != nil {
log.Debugf("decrement set counter for %s: %v", ruleKey, err)
}
if err := r.conn.DelRule(existing); err != nil {
log.Debugf("remove existing tproxy rule %s: %v", ruleKey, err)
}
delete(r.rules, ruleKey)
}
exprs, err := r.buildRedirectExprs(proto, sources, dstPorts, redirectPort)
if err != nil {
return fmt.Errorf("build redirect exprs for %s: %w", protoName, err)
}
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: chain,
Exprs: exprs,
UserData: []byte(ruleKey),
})
}
// Accept redirected packets in the ACL input chain. After REDIRECT, the
// destination port becomes the proxy port. Without this rule, the ACL filter
// drops the packet. We match on ct state dnat so only REDIRECT'd connections
// are accepted: direct connections to the proxy port are blocked.
inputAcceptKey := fmt.Sprintf("tproxy-%s-input", ruleID)
if _, ok := r.rules[inputAcceptKey]; !ok {
inputChain := &nftables.Chain{
Name: "netbird-acl-input-rules",
Table: r.workTable,
}
r.rules[inputAcceptKey] = r.conn.InsertRule(&nftables.Rule{
Table: r.workTable,
Chain: inputChain,
Exprs: []expr.Any{
// Only accept connections that were REDIRECT'd (ct status dnat)
&expr.Ct{Register: 1, Key: expr.CtKeySTATUS},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(0x20), // IPS_DST_NAT
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(0),
},
// Accept both TCP and UDP redirected to the proxy port.
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(redirectPort),
},
&expr.Verdict{Kind: expr.VerdictAccept},
},
UserData: []byte(inputAcceptKey),
})
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush tproxy rules for %s: %w", ruleID, err)
}
return nil
}
// RemoveTProxyRule removes TPROXY redirect rules by ID (both TCP and UDP variants).
func (r *router) RemoveTProxyRule(ruleID string) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
var removed int
for _, suffix := range []string{"tcp", "udp", "input"} {
ruleKey := fmt.Sprintf("tproxy-%s-%s", ruleID, suffix)
rule, ok := r.rules[ruleKey]
if !ok {
continue
}
if rule.Handle == 0 {
delete(r.rules, ruleKey)
continue
}
if err := r.decrementSetCounter(rule); err != nil {
log.Debugf("decrement set counter for %s: %v", ruleKey, err)
}
if err := r.conn.DelRule(rule); err != nil {
log.Debugf("delete tproxy rule %s: %v", ruleKey, err)
}
delete(r.rules, ruleKey)
removed++
}
if removed > 0 {
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush tproxy rule removal for %s: %w", ruleID, err)
}
}
return nil
}
// buildRedirectExprs builds nftables expressions for a REDIRECT rule.
// Matches WG interface ingress, source CIDRs, destination ports, then REDIRECTs to the proxy port.
func (r *router) buildRedirectExprs(proto uint8, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) ([]expr.Any, error) {
var exprs []expr.Any
exprs = append(exprs,
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname(r.wgIface.Name())},
)
exprs = append(exprs,
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{proto}},
)
// Source CIDRs use the named ipset shared with route rules.
if len(sources) > 0 {
srcSet := firewall.NewPrefixSet(sources)
srcExprs, err := r.getIpSet(srcSet, sources, true)
if err != nil {
return nil, fmt.Errorf("get source ipset: %w", err)
}
exprs = append(exprs, srcExprs...)
}
if len(dstPorts) == 1 {
exprs = append(exprs,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(dstPorts[0]),
},
)
} else if len(dstPorts) > 1 {
setElements := make([]nftables.SetElement, len(dstPorts))
for i, p := range dstPorts {
setElements[i] = nftables.SetElement{Key: binaryutil.BigEndian.PutUint16(p)}
}
portSet := &nftables.Set{
Table: r.workTable,
Anonymous: true,
Constant: true,
KeyType: nftables.TypeInetService,
}
if err := r.conn.AddSet(portSet, setElements); err != nil {
return nil, fmt.Errorf("create port set: %w", err)
}
exprs = append(exprs,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Lookup{
SourceRegister: 1,
SetName: portSet.Name,
SetID: portSet.ID,
},
)
}
// REDIRECT to local proxy port. Changes the destination to the interface's
// primary address + specified port. Conntrack tracks the original destination,
// readable via SO_ORIGINAL_DST.
exprs = append(exprs,
&expr.Immediate{Register: 1, Data: binaryutil.BigEndian.PutUint16(redirectPort)},
&expr.Redir{
RegisterProtoMin: 1,
},
)
return exprs, nil
}

View File

@@ -641,6 +641,45 @@ func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
return m.nativeFirewall.SetupEBPFProxyNoTrack(proxyPort, wgPort)
}
// AddTProxyRule delegates to the native firewall for TPROXY rules.
// In userspace mode (no native firewall), this is a no-op since the
// forwarder intercepts traffic directly.
func (m *Manager) AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error {
if m.nativeFirewall == nil {
return nil
}
return m.nativeFirewall.AddTProxyRule(ruleID, sources, dstPorts, redirectPort)
}
// AddUDPInspectionHook registers a hook for QUIC/UDP inspection via the packet filter.
func (m *Manager) AddUDPInspectionHook(dstPort uint16, hook func(packet []byte) bool) string {
m.SetUDPPacketHook(netip.Addr{}, dstPort, hook)
return "udp-inspection"
}
// RemoveUDPInspectionHook removes a previously registered inspection hook.
func (m *Manager) RemoveUDPInspectionHook(_ string) {
m.SetUDPPacketHook(netip.Addr{}, 0, nil)
}
// RemoveTProxyRule delegates to the native firewall for TPROXY rules.
func (m *Manager) RemoveTProxyRule(ruleID string) error {
if m.nativeFirewall == nil {
return nil
}
return m.nativeFirewall.RemoveTProxyRule(ruleID)
}
// IsLocalIP reports whether the given IP belongs to the local machine.
func (m *Manager) IsLocalIP(ip netip.Addr) bool {
return m.localipmanager.IsLocalIP(ip)
}
// GetForwarder returns the userspace packet forwarder, or nil if not initialized.
func (m *Manager) GetForwarder() *forwarder.Forwarder {
return m.forwarder.Load()
}
// UpdateSet updates the rule destinations associated with the given set
// by merging the existing prefixes with the new ones, then deduplicating.
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {

View File

@@ -7,6 +7,7 @@ import (
"net/netip"
"runtime"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
@@ -21,6 +22,7 @@ import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/inspect"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
@@ -46,6 +48,10 @@ type Forwarder struct {
netstack bool
hasRawICMPAccess bool
pingSemaphore chan struct{}
// proxy is the optional inspection engine.
// When set, TCP connections are handed to the engine for protocol detection
// and rule evaluation. Swapped atomically for lock-free hot-path access.
proxy atomic.Pointer[inspect.Proxy]
}
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
@@ -79,7 +85,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
}
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
return nil, fmt.Errorf("failed to add protocol address: %s", err)
return nil, fmt.Errorf("add protocol address: %s", err)
}
defaultSubnet, err := tcpip.NewSubnet(
@@ -155,6 +161,13 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
return nil
}
// SetProxy sets the inspection engine. When set, TCP connections are handed
// to it for protocol detection and rule evaluation instead of direct relay.
// Pass nil to disable inspection.
func (f *Forwarder) SetProxy(p *inspect.Proxy) {
f.proxy.Store(p)
}
// Stop gracefully shuts down the forwarder
func (f *Forwarder) Stop() {
f.cancel()
@@ -167,6 +180,25 @@ func (f *Forwarder) Stop() {
f.stack.Wait()
}
// CheckUDPPacket inspects a UDP payload against proxy rules before injection.
// This is called by the filter for QUIC SNI-based blocking.
// Returns true if the packet should be allowed, false if it should be dropped.
func (f *Forwarder) CheckUDPPacket(payload []byte, srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) bool {
p := f.proxy.Load()
if p == nil {
return true
}
dst := netip.AddrPortFrom(dstIP, dstPort)
src := inspect.SourceInfo{
IP: srcIP,
PolicyID: inspect.PolicyID(ruleID),
}
action := p.HandleUDPPacket(payload, dst, src)
return action != inspect.ActionBlock
}
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
if f.netstack && f.ip.Equal(addr) {
return net.IPv4(127, 0, 0, 1)

View File

@@ -16,6 +16,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
"github.com/netbirdio/netbird/client/inspect"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
@@ -23,6 +24,86 @@ import (
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
id := r.ID()
// If the inspection engine is configured, accept the connection first and hand it off.
if p := f.proxy.Load(); p != nil {
f.handleTCPWithInspection(r, id, p)
return
}
f.handleTCPDirect(r, id)
}
// handleTCPWithInspection accepts the connection and hands it to the inspection
// engine. For allow decisions, the forwarder does its own relay (passthrough).
// For block/inspect, the engine handles everything internally.
func (f *Forwarder) handleTCPWithInspection(r *tcp.ForwarderRequest, id stack.TransportEndpointID, p *inspect.Proxy) {
flowID := uuid.New()
f.sendTCPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
wq := waiter.Queue{}
ep, epErr := r.CreateEndpoint(&wq)
if epErr != nil {
f.logger.Error1("forwarder: create TCP endpoint for inspection: %v", epErr)
r.Complete(true)
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
return
}
r.Complete(false)
inConn := gonet.NewTCPConn(&wq, ep)
srcIP := netip.AddrFrom4(id.RemoteAddress.As4())
dstIP := netip.AddrFrom4(id.LocalAddress.As4())
dst := netip.AddrPortFrom(dstIP, id.LocalPort)
var policyID []byte
if ruleID, ok := f.getRuleID(srcIP, dstIP, id.RemotePort, id.LocalPort); ok {
policyID = ruleID
}
src := inspect.SourceInfo{
IP: srcIP,
PolicyID: inspect.PolicyID(policyID),
}
f.logger.Trace1("forwarder: handing TCP %v to inspection engine", epID(id))
go func() {
result, err := p.InspectTCP(f.ctx, inConn, dst, src)
if err != nil && err != inspect.ErrBlocked {
f.logger.Debug2("forwarder: inspection error for %v: %v", epID(id), err)
}
// Passthrough: engine returned allow, forwarder does the relay.
if result.PassthroughConn != nil {
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, dialErr := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
if dialErr != nil {
f.logger.Trace2("forwarder: passthrough dial error for %v: %v", epID(id), dialErr)
if closeErr := result.PassthroughConn.Close(); closeErr != nil {
f.logger.Debug1("forwarder: close passthrough conn: %v", closeErr)
}
ep.Close()
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
return
}
f.proxyTCPPassthrough(id, result.PassthroughConn, outConn, ep, flowID)
return
}
// Engine handled it (block/inspect/HTTP). Capture stats and clean up.
var rxPackets, txPackets uint64
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
rxPackets = tcpStats.SegmentsSent.Value()
txPackets = tcpStats.SegmentsReceived.Value()
}
ep.Close()
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, rxPackets, txPackets)
}()
}
// handleTCPDirect handles TCP connections with direct relay (no proxy).
func (f *Forwarder) handleTCPDirect(r *tcp.ForwarderRequest, id stack.TransportEndpointID) {
flowID := uuid.New()
f.sendTCPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
@@ -42,7 +123,6 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
return
}
// Create wait queue for blocking syscalls
wq := waiter.Queue{}
ep, epErr := r.CreateEndpoint(&wq)
@@ -55,7 +135,6 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
return
}
// Complete the handshake
r.Complete(false)
inConn := gonet.NewTCPConn(&wq, ep)
@@ -73,7 +152,6 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
go func() {
<-ctx.Done()
// Close connections and endpoint.
if err := inConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug1("forwarder: inConn close error: %v", err)
}
@@ -132,6 +210,66 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets)
}
// proxyTCPPassthrough relays traffic between a peeked inbound connection
// (from the inspection engine passthrough) and the outbound connection.
// It accepts net.Conn for inConn since the inspection engine wraps it in a peekConn.
func (f *Forwarder) proxyTCPPassthrough(id stack.TransportEndpointID, inConn net.Conn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
ctx, cancel := context.WithCancel(f.ctx)
defer cancel()
go func() {
<-ctx.Done()
if err := inConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug1("forwarder: passthrough inConn close: %v", err)
}
if err := outConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug1("forwarder: passthrough outConn close: %v", err)
}
ep.Close()
}()
var wg sync.WaitGroup
wg.Add(2)
var (
bytesIn int64
bytesOut int64
errIn error
errOut error
)
go func() {
bytesIn, errIn = io.Copy(outConn, inConn)
cancel()
wg.Done()
}()
go func() {
bytesOut, errOut = io.Copy(inConn, outConn)
cancel()
wg.Done()
}()
wg.Wait()
if errIn != nil && !isClosedError(errIn) {
f.logger.Error2("proxyTCPPassthrough: copy error (in→out) for %s: %v", epID(id), errIn)
}
if errOut != nil && !isClosedError(errOut) {
f.logger.Error2("proxyTCPPassthrough: copy error (out→in) for %s: %v", epID(id), errOut)
}
var rxPackets, txPackets uint64
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
rxPackets = tcpStats.SegmentsSent.Value()
txPackets = tcpStats.SegmentsReceived.Value()
}
f.logger.Trace5("forwarder: passthrough TCP %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesOut, txPackets, bytesIn)
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesOut), uint64(bytesIn), rxPackets, txPackets)
}
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
dstIp := netip.AddrFrom4(id.LocalAddress.As4())

212
client/inspect/config.go Normal file
View File

@@ -0,0 +1,212 @@
package inspect
import (
"crypto"
"crypto/x509"
"net"
"net/netip"
"net/url"
"strings"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/shared/management/domain"
)
// InspectResult holds the outcome of connection inspection.
type InspectResult struct {
// Action is the rule evaluation result.
Action Action
// PassthroughConn is the client connection with buffered peeked bytes.
// Non-nil only when Action is ActionAllow and the caller should relay
// (TLS passthrough or non-HTTP/TLS protocol). The caller takes ownership
// and is responsible for closing this connection.
PassthroughConn net.Conn
}
const (
// DefaultTProxyPort is the default TPROXY listener port for kernel mode.
// Override with NB_TPROXY_PORT environment variable.
DefaultTProxyPort = 22080
)
// Action determines how the proxy handles a matched connection.
type Action string
const (
// ActionAllow passes the connection through without decryption.
ActionAllow Action = "allow"
// ActionBlock denies the connection.
ActionBlock Action = "block"
// ActionInspect decrypts (MITM) and inspects the connection.
ActionInspect Action = "inspect"
)
// ProxyMode determines the proxy operating mode.
type ProxyMode string
const (
// ModeBuiltin uses the built-in proxy with rules and optional ICAP.
ModeBuiltin ProxyMode = "builtin"
// ModeEnvoy runs a local envoy sidecar for L7 processing.
// Go manages envoy lifecycle, config generation, and rule evaluation.
// USP path forwards via PROXY protocol v2; kernel path uses nftables redirect.
ModeEnvoy ProxyMode = "envoy"
// ModeExternal forwards all traffic to an external proxy.
ModeExternal ProxyMode = "external"
)
// PolicyID is the management policy identifier associated with a connection.
type PolicyID []byte
// MatchDomain reports whether target matches the pattern.
// If pattern starts with "*.", it matches any subdomain (but not the base itself).
// Otherwise it requires an exact match.
func MatchDomain(pattern, target domain.Domain) bool {
p := pattern.PunycodeString()
t := target.PunycodeString()
if strings.HasPrefix(p, "*.") {
base := p[2:]
return strings.HasSuffix(t, "."+base)
}
return p == t
}
// SourceInfo carries source identity context for rule evaluation.
// The source may be a direct WireGuard peer or a host behind
// a site-to-site gateway.
type SourceInfo struct {
// IP is the original source address from the packet.
IP netip.Addr
// PolicyID is the management policy that allowed this traffic
// through route ACLs.
PolicyID PolicyID
}
// ProtoType identifies a protocol handled by the proxy.
type ProtoType string
const (
ProtoHTTP ProtoType = "http"
ProtoHTTPS ProtoType = "https"
ProtoH2 ProtoType = "h2"
ProtoH3 ProtoType = "h3"
ProtoWebSocket ProtoType = "websocket"
ProtoOther ProtoType = "other"
)
// Rule defines a proxy inspection/filtering rule.
type Rule struct {
// ID uniquely identifies this rule.
ID id.RuleID
// Sources are the source CIDRs this rule applies to.
// Includes both direct peer IPs and routed networks behind gateways.
Sources []netip.Prefix
// Domains are the destination domain patterns to match (via SNI or Host header).
// Supports exact match ("example.com") and wildcard ("*.example.com").
Domains []domain.Domain
// Networks are the destination CIDRs to match.
Networks []netip.Prefix
// Ports are the destination ports to match. Empty means all ports.
Ports []uint16
// Protocols restricts which protocols this rule applies to.
// Empty means all protocols.
Protocols []ProtoType
// Paths are URL path patterns to match (HTTP only, requires inspect for HTTPS).
// Supports prefix ("/api/"), exact ("/login"), and wildcard ("/admin/*").
// Empty means all paths.
Paths []string
// Action determines what to do with matched connections.
Action Action
// Priority controls evaluation order. Lower values are evaluated first.
Priority int
}
// ICAPConfig holds ICAP service configuration.
type ICAPConfig struct {
// ReqModURL is the ICAP REQMOD service URL (e.g., icap://server:1344/reqmod).
ReqModURL *url.URL
// RespModURL is the ICAP RESPMOD service URL (e.g., icap://server:1344/respmod).
RespModURL *url.URL
// MaxConnections is the connection pool size. Zero uses a default.
MaxConnections int
}
// TLSConfig holds the MITM CA configuration for TLS inspection.
type TLSConfig struct {
// CA is the certificate authority used to sign dynamic certificates.
CA *x509.Certificate
// CAKey is the CA's private key.
CAKey crypto.PrivateKey
}
// Config holds the transparent proxy configuration.
type Config struct {
// Enabled controls whether the proxy is active.
Enabled bool
// Mode selects built-in or external proxy operation.
Mode ProxyMode
// ExternalURL is the upstream proxy URL for ModeExternal.
// Supports http:// and socks5:// schemes.
ExternalURL *url.URL
// DefaultAction applies when no rule matches a connection.
DefaultAction Action
// RedirectSources are the source CIDRs whose traffic should be intercepted.
// Admin decides: "activate for these users/subnets."
// Used for both kernel TPROXY rules and userspace forwarder source filtering.
RedirectSources []netip.Prefix
// RedirectPorts are the destination ports to intercept. Empty means all ports.
RedirectPorts []uint16
// Rules are the proxy inspection/filtering rules, evaluated in Priority order.
Rules []Rule
// ICAP holds ICAP service configuration. Nil disables ICAP.
ICAP *ICAPConfig
// TLS holds the MITM CA. Nil means no MITM capability (ActionInspect rules ignored).
TLS *TLSConfig
// Envoy configuration (ModeEnvoy only)
Envoy *EnvoyConfig
// ListenAddr is the TPROXY listen address for kernel mode.
// Zero value disables the TPROXY listener.
ListenAddr netip.AddrPort
// WGNetwork is the WireGuard overlay network prefix.
// The proxy blocks dialing destinations inside this network.
WGNetwork netip.Prefix
// LocalIPChecker reports whether an IP belongs to the routing peer.
// Used to prevent SSRF to local services. May be nil.
LocalIPChecker LocalIPChecker
}
// EnvoyConfig holds configuration for the envoy sidecar mode.
type EnvoyConfig struct {
// BinaryPath is the path to the envoy binary.
// Empty means search $PATH for "envoy".
BinaryPath string
// AdminPort is the port for envoy's admin API (health checks, stats).
// Zero means auto-assign.
AdminPort uint16
// Snippets are user-provided config fragments merged into the generated bootstrap.
Snippets *EnvoySnippets
}
// EnvoySnippets holds user-provided YAML fragments for envoy config customization.
// Only safe snippet types are allowed: filters (HTTP and network) and clusters
// needed as dependencies for filter services. Listeners and bootstrap overrides
// are not exposed since we manage the listener and bootstrap.
type EnvoySnippets struct {
// HTTPFilters is YAML injected into the HCM filter chain before the router filter.
// Used for ext_authz, rate limiting, Lua, Wasm, RBAC, JWT auth, etc.
HTTPFilters string
// NetworkFilters is YAML injected into the TLS filter chain before tcp_proxy.
// Used for network-level RBAC, rate limiting, ext_authz on raw TCP.
NetworkFilters string
// Clusters is YAML for additional upstream clusters referenced by filters.
// Needed when filters call external services (ext_authz backend, rate limit service).
Clusters string
}

View File

@@ -0,0 +1,93 @@
package inspect
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/domain"
)
func TestMatchDomain(t *testing.T) {
tests := []struct {
name string
pattern string
target string
want bool
}{
{
name: "exact match",
pattern: "example.com",
target: "example.com",
want: true,
},
{
name: "exact no match",
pattern: "example.com",
target: "other.com",
want: false,
},
{
name: "wildcard matches subdomain",
pattern: "*.example.com",
target: "foo.example.com",
want: true,
},
{
name: "wildcard matches deep subdomain",
pattern: "*.example.com",
target: "a.b.c.example.com",
want: true,
},
{
name: "wildcard does not match base",
pattern: "*.example.com",
target: "example.com",
want: false,
},
{
name: "wildcard does not match unrelated",
pattern: "*.example.com",
target: "foo.other.com",
want: false,
},
{
name: "case insensitive exact match",
pattern: "Example.COM",
target: "example.com",
want: true,
},
{
name: "case insensitive wildcard match",
pattern: "*.Example.COM",
target: "FOO.example.com",
want: true,
},
{
name: "wildcard does not match partial suffix",
pattern: "*.example.com",
target: "notexample.com",
want: false,
},
{
name: "unicode domain punycode match",
pattern: "*.münchen.de",
target: "sub.xn--mnchen-3ya.de",
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pattern, err := domain.FromString(tt.pattern)
require.NoError(t, err)
target, err := domain.FromString(tt.target)
require.NoError(t, err)
got := MatchDomain(pattern, target)
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -0,0 +1,25 @@
package inspect
import (
"net"
"syscall"
)
// newOutboundDialer creates a net.Dialer that clears the socket fwmark.
// In kernel TPROXY mode, accepted connections inherit the TPROXY fwmark.
// Without clearing it, outbound connections from the proxy would match
// the ip rule (fwmark -> local loopback) and loop back to the proxy
// instead of reaching the real destination.
func newOutboundDialer() net.Dialer {
return net.Dialer{
Control: func(_, _ string, c syscall.RawConn) error {
var sockErr error
if err := c.Control(func(fd uintptr) {
sockErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, 0)
}); err != nil {
return err
}
return sockErr
},
}
}

View File

@@ -0,0 +1,11 @@
//go:build !linux
package inspect
import "net"
// newOutboundDialer returns a plain dialer on non-Linux platforms.
// TPROXY is Linux-only, so no fwmark clearing is needed.
func newOutboundDialer() net.Dialer {
return net.Dialer{}
}

298
client/inspect/envoy.go Normal file
View File

@@ -0,0 +1,298 @@
package inspect
import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
const (
envoyStartTimeout = 15 * time.Second
envoyHealthInterval = 500 * time.Millisecond
envoyStopTimeout = 10 * time.Second
envoyDrainTime = 5
)
// envoyManager manages the lifecycle of an envoy sidecar process.
type envoyManager struct {
log *log.Entry
cmd *exec.Cmd
configPath string
listenPort uint16
adminPort uint16
cancel context.CancelFunc
blockPagePath string
mu sync.Mutex
running bool
}
// startEnvoy finds the envoy binary, generates config, and spawns the process.
// It blocks until envoy reports healthy or the timeout expires.
func startEnvoy(ctx context.Context, logger *log.Entry, config Config) (*envoyManager, error) {
envCfg := config.Envoy
if envCfg == nil {
return nil, fmt.Errorf("envoy config is nil")
}
binaryPath, err := findEnvoyBinary(envCfg.BinaryPath)
if err != nil {
return nil, fmt.Errorf("find envoy binary: %w", err)
}
// Pick admin port
adminPort := envCfg.AdminPort
if adminPort == 0 {
p, err := findFreePort()
if err != nil {
return nil, fmt.Errorf("find free admin port: %w", err)
}
adminPort = p
}
// Pick listener port
listenPort, err := findFreePort()
if err != nil {
return nil, fmt.Errorf("find free listener port: %w", err)
}
// Use a private temp directory (0700) to prevent local attackers from
// replacing the config file between write and envoy read.
configDir, err := os.MkdirTemp("", "nb-envoy-*")
if err != nil {
return nil, fmt.Errorf("create envoy config directory: %w", err)
}
// Write the block page HTML for envoy's direct_response to reference.
blockPagePath := filepath.Join(configDir, "block.html")
blockHTML := fmt.Sprintf(blockPageHTML, "blocked domain", "this domain")
if err := os.WriteFile(blockPagePath, []byte(blockHTML), 0600); err != nil {
return nil, fmt.Errorf("write envoy block page: %w", err)
}
// Generate config with the block page path embedded.
bootstrap, err := generateBootstrap(config, listenPort, adminPort, blockPagePath)
if err != nil {
return nil, fmt.Errorf("generate envoy bootstrap: %w", err)
}
configPath := filepath.Join(configDir, "bootstrap.yaml")
if err := os.WriteFile(configPath, bootstrap, 0600); err != nil {
return nil, fmt.Errorf("write envoy config: %w", err)
}
ctx, cancel := context.WithCancel(ctx)
cmd := exec.CommandContext(ctx, binaryPath,
"-c", configPath,
"--drain-time-s", fmt.Sprintf("%d", envoyDrainTime),
)
// Pipe envoy output to our logger.
cmd.Stdout = &logWriter{entry: logger, level: log.DebugLevel}
cmd.Stderr = &logWriter{entry: logger, level: log.WarnLevel}
if err := cmd.Start(); err != nil {
cancel()
os.Remove(configPath)
return nil, fmt.Errorf("start envoy: %w", err)
}
mgr := &envoyManager{
log: logger,
cmd: cmd,
configPath: configPath,
listenPort: listenPort,
adminPort: adminPort,
blockPagePath: blockPagePath,
cancel: cancel,
running: true,
}
// Wait for envoy to become healthy.
if err := mgr.waitHealthy(ctx); err != nil {
mgr.Stop()
return nil, fmt.Errorf("wait for envoy readiness: %w", err)
}
logger.Infof("inspect: envoy started (pid=%d, listen=%d, admin=%d)", cmd.Process.Pid, listenPort, adminPort)
// Monitor process exit in background.
go mgr.monitor()
return mgr, nil
}
// ListenAddr returns the address envoy listens on for forwarded connections.
func (m *envoyManager) ListenAddr() netip.AddrPort {
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), m.listenPort)
}
// AdminAddr returns the envoy admin API address.
func (m *envoyManager) AdminAddr() string {
return fmt.Sprintf("127.0.0.1:%d", m.adminPort)
}
// Reload writes a new config and sends SIGHUP to envoy.
func (m *envoyManager) Reload(config Config) error {
m.mu.Lock()
defer m.mu.Unlock()
if !m.running {
return fmt.Errorf("envoy is not running")
}
bootstrap, err := generateBootstrap(config, m.listenPort, m.adminPort, m.blockPagePath)
if err != nil {
return fmt.Errorf("generate envoy bootstrap: %w", err)
}
if err := os.WriteFile(m.configPath, bootstrap, 0600); err != nil {
return fmt.Errorf("write envoy config: %w", err)
}
if err := signalReload(m.cmd.Process); err != nil {
return fmt.Errorf("signal envoy reload: %w", err)
}
m.log.Debugf("inspect: envoy config reloaded")
return nil
}
// Healthy checks the envoy admin API /ready endpoint.
func (m *envoyManager) Healthy() bool {
resp, err := http.Get(fmt.Sprintf("http://%s/ready", m.AdminAddr()))
if err != nil {
return false
}
defer resp.Body.Close()
return resp.StatusCode == http.StatusOK
}
// Stop terminates the envoy process and cleans up.
func (m *envoyManager) Stop() {
m.mu.Lock()
defer m.mu.Unlock()
if !m.running {
return
}
m.running = false
m.cancel()
if m.cmd.Process != nil {
done := make(chan struct{})
go func() {
m.cmd.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(envoyStopTimeout):
m.log.Warnf("inspect: envoy did not exit in %s, killing", envoyStopTimeout)
m.cmd.Process.Kill()
<-done
}
}
os.RemoveAll(filepath.Dir(m.configPath))
m.log.Infof("inspect: envoy stopped")
}
// waitHealthy polls the admin API until envoy is ready or timeout.
func (m *envoyManager) waitHealthy(ctx context.Context) error {
deadline := time.After(envoyStartTimeout)
ticker := time.NewTicker(envoyHealthInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-deadline:
return fmt.Errorf("envoy not ready after %s", envoyStartTimeout)
case <-ticker.C:
if m.Healthy() {
return nil
}
}
}
}
// monitor watches for unexpected envoy exits.
func (m *envoyManager) monitor() {
err := m.cmd.Wait()
m.mu.Lock()
wasRunning := m.running
m.running = false
m.mu.Unlock()
if wasRunning {
m.log.Errorf("inspect: envoy exited unexpectedly: %v", err)
}
}
// findEnvoyBinary resolves the envoy binary path.
func findEnvoyBinary(configPath string) (string, error) {
if configPath != "" {
if _, err := os.Stat(configPath); err != nil {
return "", fmt.Errorf("envoy binary not found at %s: %w", configPath, err)
}
return configPath, nil
}
path, err := exec.LookPath("envoy")
if err != nil {
return "", fmt.Errorf("envoy not found in PATH: %w", err)
}
return path, nil
}
// findFreePort asks the OS for an available TCP port.
func findFreePort() (uint16, error) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return 0, err
}
port := uint16(ln.Addr().(*net.TCPAddr).Port)
ln.Close()
return port, nil
}
// logWriter adapts log.Entry to io.Writer for piping process output.
type logWriter struct {
entry *log.Entry
level log.Level
}
func (w *logWriter) Write(p []byte) (int, error) {
msg := strings.TrimRight(string(p), "\n\r")
if msg == "" {
return len(p), nil
}
switch w.level {
case log.WarnLevel:
w.entry.Warn(msg)
default:
w.entry.Debug(msg)
}
return len(p), nil
}
// Ensure logWriter satisfies io.Writer.
var _ io.Writer = (*logWriter)(nil)

View File

@@ -0,0 +1,382 @@
package inspect
import (
"bytes"
"fmt"
"strings"
"text/template"
)
// envoyBootstrapTmpl generates the full envoy bootstrap with rule translation.
// TLS rules become per-SNI filter chains; HTTP rules become per-domain virtual hosts.
var envoyBootstrapTmpl = template.Must(template.New("bootstrap").Funcs(template.FuncMap{
"quote": func(s string) string { return fmt.Sprintf("%q", s) },
}).Parse(`node:
id: netbird-inspect
cluster: netbird
admin:
address:
socket_address:
address: 127.0.0.1
port_value: {{.AdminPort}}
static_resources:
listeners:
- name: inspect_listener
address:
socket_address:
address: 127.0.0.1
port_value: {{.ListenPort}}
listener_filters:
- name: envoy.filters.listener.proxy_protocol
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.listener.proxy_protocol.v3.ProxyProtocol
- name: envoy.filters.listener.tls_inspector
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.listener.tls_inspector.v3.TlsInspector
filter_chains:
{{- /* TLS filter chains: per-SNI block/allow + default */ -}}
{{- range .TLSChains}}
- filter_chain_match:
transport_protocol: tls
{{- if .ServerNames}}
server_names:
{{- range .ServerNames}}
- {{quote .}}
{{- end}}
{{- end}}
filters:
{{$.NetworkFiltersSnippet}} - name: envoy.filters.network.tcp_proxy
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.network.tcp_proxy.v3.TcpProxy
stat_prefix: {{.StatPrefix}}
cluster: original_dst
access_log:
- name: envoy.access_loggers.stderr
typed_config:
"@type": type.googleapis.com/envoy.extensions.access_loggers.stream.v3.StderrAccessLog
log_format:
text_format: "[%START_TIME%] tcp %DOWNSTREAM_REMOTE_ADDRESS% -> %UPSTREAM_HOST% %RESPONSE_FLAGS% %DURATION%ms\n"
{{- end}}
{{- /* Plain HTTP filter chain with per-domain virtual hosts */}}
- filters:
- name: envoy.filters.network.http_connection_manager
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
stat_prefix: inspect_http
access_log:
- name: envoy.access_loggers.stderr
typed_config:
"@type": type.googleapis.com/envoy.extensions.access_loggers.stream.v3.StderrAccessLog
log_format:
text_format: "[%START_TIME%] http %DOWNSTREAM_REMOTE_ADDRESS% %REQ(:AUTHORITY)% %REQ(:METHOD)% %REQ(X-ENVOY-ORIGINAL-PATH?:PATH)% %RESPONSE_CODE% %RESPONSE_FLAGS% %DURATION%ms\n"
http_filters:
{{.HTTPFiltersSnippet}} - name: envoy.filters.http.router
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
route_config:
virtual_hosts:
{{- range .VirtualHosts}}
- name: {{.Name}}
domains: [{{.DomainsStr}}]
routes:
{{- range .Routes}}
- match:
prefix: "{{if .PathPrefix}}{{.PathPrefix}}{{else}}/{{end}}"
{{- if .Block}}
direct_response:
status: 403
body:
filename: "{{$.BlockPagePath}}"
{{- else}}
route:
cluster: original_dst
{{- end}}
{{- end}}
{{- end}}
clusters:
- name: original_dst
type: ORIGINAL_DST
lb_policy: CLUSTER_PROVIDED
connect_timeout: 10s
{{.ExtraClusters}}`))
// tlsChain represents a TLS filter chain entry for the template.
// All TLS chains are passthrough (block decisions happen in Go before envoy).
type tlsChain struct {
// ServerNames restricts this chain to specific SNIs. Empty is catch-all.
ServerNames []string
StatPrefix string
}
// envoyRoute represents a single route entry within a virtual host.
type envoyRoute struct {
// PathPrefix for envoy prefix match. Empty means catch-all "/".
PathPrefix string
Block bool
}
// virtualHost represents an HTTP virtual host entry for the template.
type virtualHost struct {
Name string
// DomainsStr is pre-formatted for the template: "a", "b".
DomainsStr string
Routes []envoyRoute
}
type bootstrapData struct {
AdminPort uint16
ListenPort uint16
BlockPagePath string
TLSChains []tlsChain
VirtualHosts []virtualHost
HTTPFiltersSnippet string
NetworkFiltersSnippet string
ExtraClusters string
}
// generateBootstrap produces the envoy bootstrap YAML from the inspect config.
// Translates inspection rules into envoy-native per-SNI and per-domain routing.
// blockPagePath is the path to the HTML block page file served by direct_response.
func generateBootstrap(config Config, listenPort, adminPort uint16, blockPagePath string) ([]byte, error) {
data := bootstrapData{
AdminPort: adminPort,
BlockPagePath: blockPagePath,
ListenPort: listenPort,
TLSChains: buildTLSChains(config),
VirtualHosts: buildVirtualHosts(config),
}
if config.Envoy != nil && config.Envoy.Snippets != nil {
s := config.Envoy.Snippets
data.HTTPFiltersSnippet = indentSnippet(s.HTTPFilters, 18)
data.NetworkFiltersSnippet = indentSnippet(s.NetworkFilters, 12)
data.ExtraClusters = indentSnippet(s.Clusters, 4)
}
var buf bytes.Buffer
if err := envoyBootstrapTmpl.Execute(&buf, data); err != nil {
return nil, fmt.Errorf("execute bootstrap template: %w", err)
}
return buf.Bytes(), nil
}
// buildTLSChains translates inspection rules into envoy TLS filter chains.
// Block rules -> per-SNI chain routing to blackhole.
// Allow rules (when default=block) -> per-SNI chain routing to original_dst.
// Default chain follows DefaultAction.
func buildTLSChains(config Config) []tlsChain {
// TLS block decisions happen in Go before forwarding to envoy, so we only
// generate allow/passthrough chains here. Envoy can't cleanly close a TLS
// connection without completing a handshake, so blocked SNIs never reach envoy.
var allowed []string
for _, rule := range config.Rules {
if !ruleTouchesProtocol(rule, ProtoHTTPS, ProtoH2) {
continue
}
for _, d := range rule.Domains {
sni := d.PunycodeString()
if rule.Action == ActionAllow || rule.Action == ActionInspect {
allowed = append(allowed, sni)
}
}
}
var chains []tlsChain
if len(allowed) > 0 && config.DefaultAction == ActionBlock {
chains = append(chains, tlsChain{
ServerNames: allowed,
StatPrefix: "tls_allowed",
})
}
// Default catch-all: passthrough (blocked SNIs never arrive here)
chains = append(chains, tlsChain{
StatPrefix: "tls_default",
})
return chains
}
// buildVirtualHosts translates inspection rules into envoy HTTP virtual hosts.
// Groups rules by domain, generates per-path routes within each virtual host.
func buildVirtualHosts(config Config) []virtualHost {
// Group rules by domain for per-domain virtual hosts.
type domainRules struct {
domains []string
routes []envoyRoute
}
domainRouteMap := make(map[string][]envoyRoute)
for _, rule := range config.Rules {
if !ruleTouchesProtocol(rule, ProtoHTTP, ProtoWebSocket) {
continue
}
isBlock := rule.Action == ActionBlock
// Rules without domains or paths are handled by the default action.
if len(rule.Domains) == 0 && len(rule.Paths) == 0 {
continue
}
// Build routes for this rule's paths
var routes []envoyRoute
if len(rule.Paths) > 0 {
for _, p := range rule.Paths {
// Convert our path patterns to envoy prefix match.
// Strip trailing * for envoy prefix matching.
prefix := strings.TrimSuffix(p, "*")
routes = append(routes, envoyRoute{PathPrefix: prefix, Block: isBlock})
}
} else {
routes = append(routes, envoyRoute{Block: isBlock})
}
if len(rule.Domains) > 0 {
for _, d := range rule.Domains {
host := d.PunycodeString()
domainRouteMap[host] = append(domainRouteMap[host], routes...)
}
} else {
// No domain: applies to all, add to default host
domainRouteMap["*"] = append(domainRouteMap["*"], routes...)
}
}
var hosts []virtualHost
idx := 0
// Per-domain virtual hosts with path routes
for domain, routes := range domainRouteMap {
if domain == "*" {
continue
}
// Add a catch-all route after path-specific routes.
// The catch-all follows the default action.
routes = append(routes, envoyRoute{Block: config.DefaultAction == ActionBlock})
hosts = append(hosts, virtualHost{
Name: fmt.Sprintf("domain_%d", idx),
DomainsStr: fmt.Sprintf("%q", domain),
Routes: routes,
})
idx++
}
// Default virtual host (catch-all for unmatched domains)
defaultRoutes := domainRouteMap["*"]
defaultRoutes = append(defaultRoutes, envoyRoute{Block: config.DefaultAction == ActionBlock})
hosts = append(hosts, virtualHost{
Name: "default",
DomainsStr: `"*"`,
Routes: defaultRoutes,
})
return hosts
}
// ruleTouchesProtocol returns true if the rule's protocol list includes any of the given protocols,
// or if the protocol list is empty (matches all).
func ruleTouchesProtocol(rule Rule, protos ...ProtoType) bool {
if len(rule.Protocols) == 0 {
return true
}
for _, rp := range rule.Protocols {
for _, p := range protos {
if rp == p {
return true
}
}
}
return false
}
// indentSnippet prepends each line of the YAML snippet with the given number of spaces.
// Returns empty string if snippet is empty.
func indentSnippet(snippet string, spaces int) string {
if snippet == "" {
return ""
}
prefix := make([]byte, spaces)
for i := range prefix {
prefix[i] = ' '
}
var buf bytes.Buffer
for i, line := range bytes.Split([]byte(snippet), []byte("\n")) {
if i > 0 {
buf.WriteByte('\n')
}
if len(line) > 0 {
buf.Write(prefix)
buf.Write(line)
}
}
buf.WriteByte('\n')
return buf.String()
}
// ValidateSnippets checks that user-provided snippets are safe to inject
// into the envoy config. Returns an error describing the first violation found.
//
// Validation rules:
// - Each snippet must be valid YAML (prevents syntax-level injection)
// - Snippets must not contain YAML document separators (--- or ...) that could
// break out of the indentation context
// - Snippets must only contain list items (starting with "- ") at the top level,
// matching what envoy expects for filters and clusters
func ValidateSnippets(snippets *EnvoySnippets) error {
if snippets == nil {
return nil
}
fields := []struct {
name string
value string
}{
{"http_filters", snippets.HTTPFilters},
{"network_filters", snippets.NetworkFilters},
{"clusters", snippets.Clusters},
}
for _, f := range fields {
if f.value == "" {
continue
}
if err := validateSnippetYAML(f.name, f.value); err != nil {
return err
}
}
return nil
}
func validateSnippetYAML(name, snippet string) error {
// Check for YAML document markers that could break template structure.
for _, line := range strings.Split(snippet, "\n") {
trimmed := strings.TrimSpace(line)
if trimmed == "---" || trimmed == "..." {
return fmt.Errorf("snippet %q: YAML document separators (--- or ...) are not allowed", name)
}
}
// Verify it's valid YAML by checking it doesn't cause template execution issues.
// We can't import yaml.v3 here without adding a dependency, so we do structural checks.
// Check for null bytes or control characters that could confuse YAML parsers.
for i, b := range []byte(snippet) {
if b == 0 {
return fmt.Errorf("snippet %q: null byte at position %d", name, i)
}
if b < 0x09 || (b > 0x0D && b < 0x20 && b != 0x1B) {
return fmt.Errorf("snippet %q: control character 0x%02x at position %d", name, b, i)
}
}
return nil
}

View File

@@ -0,0 +1,88 @@
package inspect
import (
"context"
"encoding/binary"
"fmt"
"net"
"net/netip"
)
// PROXY protocol v2 constants (RFC 7239 / HAProxy spec)
var proxyV2Signature = [12]byte{
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51,
0x55, 0x49, 0x54, 0x0A,
}
const (
proxyV2VersionCommand = 0x21 // version 2, PROXY command
proxyV2FamilyTCP4 = 0x11 // AF_INET, STREAM
proxyV2FamilyTCP6 = 0x21 // AF_INET6, STREAM
)
// forwardToEnvoy forwards a connection to the given envoy sidecar via PROXY protocol v2.
// The caller provides the envoy manager snapshot to avoid accessing p.envoy without lock.
func (p *Proxy) forwardToEnvoy(ctx context.Context, pconn *peekConn, dst netip.AddrPort, src SourceInfo, em *envoyManager) error {
envoyAddr := em.ListenAddr()
conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", envoyAddr.String())
if err != nil {
return fmt.Errorf("dial envoy at %s: %w", envoyAddr, err)
}
defer func() {
if err := conn.Close(); err != nil {
p.log.Debugf("close envoy conn: %v", err)
}
}()
if err := writeProxyV2Header(conn, src.IP, dst); err != nil {
return fmt.Errorf("write PROXY v2 header: %w", err)
}
p.log.Tracef("envoy: forwarded %s -> %s via PROXY v2", src.IP, dst)
return relay(ctx, pconn, conn)
}
// writeProxyV2Header writes a PROXY protocol v2 header to w.
// The header encodes the original source IP and the destination address:port.
func writeProxyV2Header(w net.Conn, srcIP netip.Addr, dst netip.AddrPort) error {
srcIP = srcIP.Unmap()
dstIP := dst.Addr().Unmap()
var (
family byte
addrs []byte
)
if srcIP.Is4() && dstIP.Is4() {
family = proxyV2FamilyTCP4
s4 := srcIP.As4()
d4 := dstIP.As4()
addrs = make([]byte, 12) // 4+4+2+2
copy(addrs[0:4], s4[:])
copy(addrs[4:8], d4[:])
binary.BigEndian.PutUint16(addrs[8:10], 0) // src port unknown
binary.BigEndian.PutUint16(addrs[10:12], dst.Port())
} else {
family = proxyV2FamilyTCP6
s16 := srcIP.As16()
d16 := dstIP.As16()
addrs = make([]byte, 36) // 16+16+2+2
copy(addrs[0:16], s16[:])
copy(addrs[16:32], d16[:])
binary.BigEndian.PutUint16(addrs[32:34], 0) // src port unknown
binary.BigEndian.PutUint16(addrs[34:36], dst.Port())
}
// Header: signature(12) + ver_cmd(1) + family(1) + len(2) + addrs
header := make([]byte, 16+len(addrs))
copy(header[0:12], proxyV2Signature[:])
header[12] = proxyV2VersionCommand
header[13] = family
binary.BigEndian.PutUint16(header[14:16], uint16(len(addrs)))
copy(header[16:], addrs)
_, err := w.Write(header)
return err
}

View File

@@ -0,0 +1,13 @@
//go:build !windows
package inspect
import (
"os"
"syscall"
)
// signalReload sends SIGHUP to the envoy process to trigger config reload.
func signalReload(p *os.Process) error {
return p.Signal(syscall.SIGHUP)
}

View File

@@ -0,0 +1,13 @@
//go:build windows
package inspect
import (
"fmt"
"os"
)
// signalReload is not supported on Windows. Envoy must be restarted.
func signalReload(_ *os.Process) error {
return fmt.Errorf("envoy config reload via signal not supported on Windows")
}

229
client/inspect/external.go Normal file
View File

@@ -0,0 +1,229 @@
package inspect
import (
"bufio"
"context"
"encoding/base64"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"net/url"
"time"
)
const (
externalDialTimeout = 10 * time.Second
)
// handleExternal forwards the connection to an external proxy.
// For TLS connections, it uses HTTP CONNECT to tunnel through the proxy.
// For HTTP connections, it rewrites the request to use the proxy.
func (p *Proxy) handleExternal(ctx context.Context, pconn *peekConn, dst netip.AddrPort) error {
p.mu.RLock()
proxyURL := p.config.ExternalURL
p.mu.RUnlock()
if proxyURL == nil {
return fmt.Errorf("external proxy URL not configured")
}
switch proxyURL.Scheme {
case "http", "https":
return p.externalHTTPProxy(ctx, pconn, dst, proxyURL)
case "socks5":
return p.externalSOCKS5(ctx, pconn, dst, proxyURL)
default:
return fmt.Errorf("unsupported external proxy scheme: %s", proxyURL.Scheme)
}
}
// externalHTTPProxy tunnels through an HTTP proxy using CONNECT.
func (p *Proxy) externalHTTPProxy(ctx context.Context, pconn *peekConn, dst netip.AddrPort, proxyURL *url.URL) error {
proxyAddr := proxyURL.Host
if _, _, err := net.SplitHostPort(proxyAddr); err != nil {
proxyAddr = net.JoinHostPort(proxyAddr, "8080")
}
proxyConn, err := (&net.Dialer{Timeout: externalDialTimeout}).DialContext(ctx, "tcp", proxyAddr)
if err != nil {
return fmt.Errorf("dial external proxy %s: %w", proxyAddr, err)
}
defer func() {
if err := proxyConn.Close(); err != nil {
p.log.Debugf("close external proxy conn: %v", err)
}
}()
connectReq := fmt.Sprintf("CONNECT %s HTTP/1.1\r\nHost: %s\r\n", dst.String(), dst.String())
if proxyURL.User != nil {
connectReq += "Proxy-Authorization: Basic " + basicAuth(proxyURL.User) + "\r\n"
}
connectReq += "\r\n"
if _, err := io.WriteString(proxyConn, connectReq); err != nil {
return fmt.Errorf("send CONNECT to proxy: %w", err)
}
resp, err := http.ReadResponse(bufio.NewReader(proxyConn), nil)
if err != nil {
return fmt.Errorf("read CONNECT response: %w", err)
}
if err := resp.Body.Close(); err != nil {
p.log.Debugf("close CONNECT resp body: %v", err)
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("proxy CONNECT failed: %s", resp.Status)
}
return relay(ctx, pconn, proxyConn)
}
// externalSOCKS5 tunnels through a SOCKS5 proxy.
func (p *Proxy) externalSOCKS5(ctx context.Context, pconn *peekConn, dst netip.AddrPort, proxyURL *url.URL) error {
proxyAddr := proxyURL.Host
if _, _, err := net.SplitHostPort(proxyAddr); err != nil {
proxyAddr = net.JoinHostPort(proxyAddr, "1080")
}
proxyConn, err := (&net.Dialer{Timeout: externalDialTimeout}).DialContext(ctx, "tcp", proxyAddr)
if err != nil {
return fmt.Errorf("dial SOCKS5 proxy %s: %w", proxyAddr, err)
}
defer func() {
if err := proxyConn.Close(); err != nil {
p.log.Debugf("close SOCKS5 proxy conn: %v", err)
}
}()
if err := socks5Handshake(proxyConn, dst, proxyURL.User); err != nil {
return fmt.Errorf("SOCKS5 handshake: %w", err)
}
return relay(ctx, pconn, proxyConn)
}
// socks5Handshake performs the SOCKS5 handshake to connect through the proxy.
func socks5Handshake(conn net.Conn, dst netip.AddrPort, userinfo *url.Userinfo) error {
needAuth := userinfo != nil
// Greeting
var methods []byte
if needAuth {
methods = []byte{0x00, 0x02} // no auth, username/password
} else {
methods = []byte{0x00} // no auth
}
greeting := append([]byte{0x05, byte(len(methods))}, methods...)
if _, err := conn.Write(greeting); err != nil {
return fmt.Errorf("send greeting: %w", err)
}
// Server method selection
var methodResp [2]byte
if _, err := io.ReadFull(conn, methodResp[:]); err != nil {
return fmt.Errorf("read method selection: %w", err)
}
if methodResp[0] != 0x05 {
return fmt.Errorf("unexpected SOCKS version: %d", methodResp[0])
}
// Handle authentication if selected
if methodResp[1] == 0x02 {
if err := socks5Auth(conn, userinfo); err != nil {
return err
}
} else if methodResp[1] != 0x00 {
return fmt.Errorf("unsupported SOCKS5 auth method: %d", methodResp[1])
}
// Connection request
addr := dst.Addr()
var addrBytes []byte
if addr.Is4() {
a4 := addr.As4()
addrBytes = append([]byte{0x01}, a4[:]...) // IPv4
} else {
a16 := addr.As16()
addrBytes = append([]byte{0x04}, a16[:]...) // IPv6
}
port := dst.Port()
connectReq := append([]byte{0x05, 0x01, 0x00}, addrBytes...)
connectReq = append(connectReq, byte(port>>8), byte(port))
if _, err := conn.Write(connectReq); err != nil {
return fmt.Errorf("send connect request: %w", err)
}
// Read response (minimum 10 bytes for IPv4)
var respHeader [4]byte
if _, err := io.ReadFull(conn, respHeader[:]); err != nil {
return fmt.Errorf("read connect response: %w", err)
}
if respHeader[1] != 0x00 {
return fmt.Errorf("SOCKS5 connect failed: status %d", respHeader[1])
}
// Skip bound address
switch respHeader[3] {
case 0x01: // IPv4
var skip [4 + 2]byte
if _, err := io.ReadFull(conn, skip[:]); err != nil {
return fmt.Errorf("read SOCKS5 bound IPv4 address: %w", err)
}
case 0x04: // IPv6
var skip [16 + 2]byte
if _, err := io.ReadFull(conn, skip[:]); err != nil {
return fmt.Errorf("read SOCKS5 bound IPv6 address: %w", err)
}
case 0x03: // Domain
var dLen [1]byte
if _, err := io.ReadFull(conn, dLen[:]); err != nil {
return fmt.Errorf("read domain length: %w", err)
}
skip := make([]byte, int(dLen[0])+2)
if _, err := io.ReadFull(conn, skip); err != nil {
return fmt.Errorf("read SOCKS5 bound domain address: %w", err)
}
}
return nil
}
func socks5Auth(conn net.Conn, userinfo *url.Userinfo) error {
if userinfo == nil {
return fmt.Errorf("SOCKS5 auth required but no credentials provided")
}
user := userinfo.Username()
pass, _ := userinfo.Password()
// Username/password auth (RFC 1929)
auth := []byte{0x01, byte(len(user))}
auth = append(auth, []byte(user)...)
auth = append(auth, byte(len(pass)))
auth = append(auth, []byte(pass)...)
if _, err := conn.Write(auth); err != nil {
return fmt.Errorf("send auth: %w", err)
}
var resp [2]byte
if _, err := io.ReadFull(conn, resp[:]); err != nil {
return fmt.Errorf("read auth response: %w", err)
}
if resp[1] != 0x00 {
return fmt.Errorf("SOCKS5 auth failed: status %d", resp[1])
}
return nil
}
func basicAuth(userinfo *url.Userinfo) string {
user := userinfo.Username()
pass, _ := userinfo.Password()
return base64.StdEncoding.EncodeToString([]byte(user + ":" + pass))
}

532
client/inspect/http.go Normal file
View File

@@ -0,0 +1,532 @@
package inspect
import (
"bufio"
"context"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"strings"
"sync"
"time"
"github.com/netbirdio/netbird/shared/management/domain"
)
const (
headerUpgrade = "Upgrade"
valueWebSocket = "websocket"
)
// inspectHTTP runs the HTTP inspection pipeline on decrypted traffic.
// It handles HTTP/1.1 (request-response loop), HTTP/2 (via Go stdlib reverse proxy),
// and WebSocket upgrade detection.
func (p *Proxy) inspectHTTP(ctx context.Context, client, remote net.Conn, dst netip.AddrPort, sni domain.Domain, src SourceInfo, proto string) error {
if proto == "h2" {
return p.inspectH2(ctx, client, remote, dst, sni, src)
}
return p.inspectH1(ctx, client, remote, dst, sni, src)
}
// inspectH1 handles HTTP/1.1 request-response inspection in a loop.
func (p *Proxy) inspectH1(ctx context.Context, client, remote net.Conn, dst netip.AddrPort, sni domain.Domain, src SourceInfo) error {
clientReader := bufio.NewReader(client)
remoteReader := bufio.NewReader(remote)
for {
if ctx.Err() != nil {
return ctx.Err()
}
// Set idle timeout between requests to prevent connection hogging.
if err := client.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil {
return fmt.Errorf("set idle deadline: %w", err)
}
req, err := http.ReadRequest(clientReader)
if err != nil {
if isClosedErr(err) {
return nil
}
return fmt.Errorf("read HTTP request: %w", err)
}
if err := client.SetReadDeadline(time.Time{}); err != nil {
return fmt.Errorf("clear read deadline: %w", err)
}
// Re-evaluate rules based on Host header if SNI was empty
host := hostFromRequest(req, sni)
// Domain fronting: Host header doesn't match TLS SNI
if isDomainFronting(req, sni) {
p.log.Debugf("domain fronting detected: SNI=%s Host=%s", sni.PunycodeString(), host.PunycodeString())
writeBlockResponse(client, req, host)
return ErrBlocked
}
proto := ProtoHTTP
if isWebSocketUpgrade(req) {
proto = ProtoWebSocket
}
action := p.evaluateAction(src.IP, host, dst, proto, req.URL.Path)
if action == ActionBlock {
p.log.Debugf("block: HTTP %s %s (host=%s)", req.Method, req.URL.Path, host.PunycodeString())
writeBlockResponse(client, req, host)
return ErrBlocked
}
p.log.Tracef("allow: HTTP %s %s (host=%s, action=%s)", req.Method, req.URL.Path, host.PunycodeString(), action)
// ICAP REQMOD: send request for inspection.
// Snapshot ICAP client under lock to avoid use-after-close races.
p.mu.RLock()
icap := p.icap
p.mu.RUnlock()
if icap != nil {
modified, err := icap.ReqMod(req)
if err != nil {
p.log.Debugf("ICAP REQMOD error for %s: %v", host.PunycodeString(), err)
// Fail-closed: block on ICAP error
writeBlockResponse(client, req, host)
return fmt.Errorf("ICAP REQMOD: %w", err)
}
req = modified
}
if isWebSocketUpgrade(req) {
return p.handleWebSocket(ctx, req, client, clientReader, remote, remoteReader)
}
removeHopByHopHeaders(req.Header)
if err := req.Write(remote); err != nil {
return fmt.Errorf("forward request: %w", err)
}
resp, err := http.ReadResponse(remoteReader, req)
if err != nil {
return fmt.Errorf("read HTTP response: %w", err)
}
// ICAP RESPMOD: send response for inspection
if icap != nil {
modified, err := icap.RespMod(req, resp)
if err != nil {
p.log.Debugf("ICAP RESPMOD error for %s: %v", host.PunycodeString(), err)
if err := resp.Body.Close(); err != nil {
p.log.Debugf("close resp body: %v", err)
}
writeBlockResponse(client, req, host)
return fmt.Errorf("ICAP RESPMOD: %w", err)
}
resp = modified
}
removeHopByHopHeaders(resp.Header)
if err := resp.Write(client); err != nil {
if closeErr := resp.Body.Close(); closeErr != nil {
p.log.Debugf("close resp body: %v", closeErr)
}
return fmt.Errorf("forward response: %w", err)
}
if err := resp.Body.Close(); err != nil {
p.log.Debugf("close resp body: %v", err)
}
// Connection: close means we're done
if resp.Close || req.Close {
return nil
}
}
}
// inspectH2 proxies HTTP/2 traffic using Go's http stack.
// Client and remote are already-established TLS connections with h2 negotiated.
func (p *Proxy) inspectH2(ctx context.Context, client, remote net.Conn, dst netip.AddrPort, sni domain.Domain, src SourceInfo) error {
// For h2 MITM inspection, we use a local http.Server reading from the client
// connection and an http.Transport writing to the remote connection.
//
// The transport is configured to use the existing TLS connection to the
// real server. The handler inspects each request/response pair.
transport := &http.Transport{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return remote, nil
},
DialTLSContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return remote, nil
},
ForceAttemptHTTP2: true,
}
handler := &h2InspectionHandler{
proxy: p,
transport: transport,
dst: dst,
sni: sni,
src: src,
}
server := &http.Server{
Handler: handler,
}
// Serve the single client connection.
// ServeConn blocks until the connection is done.
errCh := make(chan error, 1)
go func() {
// http.Server doesn't have a direct ServeConn for h2,
// so we use Serve with a single-connection listener.
ln := &singleConnListener{conn: client}
errCh <- server.Serve(ln)
}()
select {
case <-ctx.Done():
if err := server.Close(); err != nil {
p.log.Debugf("close h2 server: %v", err)
}
return ctx.Err()
case err := <-errCh:
if err == http.ErrServerClosed {
return nil
}
return err
}
}
// h2InspectionHandler inspects each HTTP/2 request/response pair.
type h2InspectionHandler struct {
proxy *Proxy
transport http.RoundTripper
dst netip.AddrPort
sni domain.Domain
src SourceInfo
}
func (h *h2InspectionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
host := hostFromRequest(req, h.sni)
if isDomainFronting(req, h.sni) {
h.proxy.log.Debugf("domain fronting detected: SNI=%s Host=%s", h.sni.PunycodeString(), host.PunycodeString())
writeBlockPage(w, host)
return
}
action := h.proxy.evaluateAction(h.src.IP, host, h.dst, ProtoH2, req.URL.Path)
if action == ActionBlock {
h.proxy.log.Debugf("block: H2 %s %s (host=%s)", req.Method, req.URL.Path, host.PunycodeString())
writeBlockPage(w, host)
return
}
// ICAP REQMOD
if h.proxy.icap != nil {
modified, err := h.proxy.icap.ReqMod(req)
if err != nil {
h.proxy.log.Debugf("ICAP REQMOD error for %s: %v", host.PunycodeString(), err)
writeBlockPage(w, host)
return
}
req = modified
}
// Forward to upstream
req.URL.Scheme = "https"
req.URL.Host = h.sni.PunycodeString()
req.RequestURI = ""
resp, err := h.transport.RoundTrip(req)
if err != nil {
h.proxy.log.Debugf("h2 upstream error for %s: %v", host.PunycodeString(), err)
http.Error(w, "Bad Gateway", http.StatusBadGateway)
return
}
defer func() {
if err := resp.Body.Close(); err != nil {
h.proxy.log.Debugf("close h2 resp body: %v", err)
}
}()
// ICAP RESPMOD
if h.proxy.icap != nil {
modified, err := h.proxy.icap.RespMod(req, resp)
if err != nil {
h.proxy.log.Debugf("ICAP RESPMOD error for %s: %v", host.PunycodeString(), err)
writeBlockPage(w, host)
return
}
resp = modified
}
// Copy response headers and body
for k, vals := range resp.Header {
for _, v := range vals {
w.Header().Add(k, v)
}
}
w.WriteHeader(resp.StatusCode)
if _, err := io.Copy(w, resp.Body); err != nil {
h.proxy.log.Debugf("h2 response copy error: %v", err)
}
}
// handleWebSocket completes the WebSocket upgrade and relays frames bidirectionally.
func (p *Proxy) handleWebSocket(ctx context.Context, req *http.Request, client io.ReadWriter, clientReader *bufio.Reader, remote io.ReadWriter, remoteReader *bufio.Reader) error {
if err := req.Write(remote); err != nil {
return fmt.Errorf("forward WebSocket upgrade: %w", err)
}
resp, err := http.ReadResponse(remoteReader, req)
if err != nil {
return fmt.Errorf("read WebSocket upgrade response: %w", err)
}
if err := resp.Write(client); err != nil {
if closeErr := resp.Body.Close(); closeErr != nil {
p.log.Debugf("close ws resp body: %v", closeErr)
}
return fmt.Errorf("forward WebSocket upgrade response: %w", err)
}
if err := resp.Body.Close(); err != nil {
p.log.Debugf("close ws resp body: %v", err)
}
if resp.StatusCode != http.StatusSwitchingProtocols {
return fmt.Errorf("WebSocket upgrade rejected: status %d", resp.StatusCode)
}
p.log.Tracef("allow: WebSocket upgrade for %s", req.Host)
// Relay WebSocket frames bidirectionally.
// clientReader/remoteReader may have buffered data.
clientConn := mergeReadWriter(clientReader, client)
remoteConn := mergeReadWriter(remoteReader, remote)
return relayRW(ctx, clientConn, remoteConn)
}
// hostFromRequest extracts a domain.Domain from the HTTP request Host header,
// falling back to the SNI if Host is empty or an IP.
func hostFromRequest(req *http.Request, fallback domain.Domain) domain.Domain {
host := req.Host
if host == "" {
return fallback
}
// Strip port if present
if h, _, err := net.SplitHostPort(host); err == nil {
host = h
}
// If it's an IP address, use the SNI fallback
if _, err := netip.ParseAddr(host); err == nil {
return fallback
}
d, err := domain.FromString(host)
if err != nil {
return fallback
}
return d
}
// isDomainFronting detects domain fronting: the Host header doesn't match the
// SNI used during the TLS handshake. Only meaningful when SNI is non-empty
// (i.e., we're in MITM mode and know the original SNI).
func isDomainFronting(req *http.Request, sni domain.Domain) bool {
if sni == "" {
return false
}
host := hostFromRequest(req, "")
if host == "" {
return false
}
// Host should match SNI or be a subdomain of SNI
if host == sni {
return false
}
// Allow www.example.com when SNI is example.com
sniStr := sni.PunycodeString()
hostStr := host.PunycodeString()
if strings.HasSuffix(hostStr, "."+sniStr) {
return false
}
return true
}
func isWebSocketUpgrade(req *http.Request) bool {
return strings.EqualFold(req.Header.Get(headerUpgrade), valueWebSocket)
}
// writeBlockPage writes the styled HTML block page to an http.ResponseWriter (H2 path).
func writeBlockPage(w http.ResponseWriter, host domain.Domain) {
hostname := host.PunycodeString()
body := fmt.Sprintf(blockPageHTML, hostname, hostname)
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.Header().Set("Cache-Control", "no-store")
w.WriteHeader(http.StatusForbidden)
io.WriteString(w, body)
}
func writeBlockResponse(w io.Writer, _ *http.Request, host domain.Domain) {
hostname := host.PunycodeString()
body := fmt.Sprintf(blockPageHTML, hostname, hostname)
resp := &http.Response{
StatusCode: http.StatusForbidden,
ProtoMajor: 1,
ProtoMinor: 1,
Header: make(http.Header),
ContentLength: int64(len(body)),
Body: io.NopCloser(strings.NewReader(body)),
}
resp.Header.Set("Content-Type", "text/html; charset=utf-8")
resp.Header.Set("Connection", "close")
resp.Header.Set("Cache-Control", "no-store")
_ = resp.Write(w)
}
// blockPageHTML is the self-contained HTML block page.
// Uses NetBird dark theme with orange accent. Two format args: page title domain, displayed domain.
const blockPageHTML = `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width,initial-scale=1">
<title>Blocked - %s</title>
<style>
*{margin:0;padding:0;box-sizing:border-box}
body{background:#181a1d;color:#d1d5db;font-family:-apple-system,BlinkMacSystemFont,"Segoe UI",Roboto,sans-serif;min-height:100vh;display:flex;align-items:center;justify-content:center}
.c{text-align:center;max-width:460px;padding:2rem}
.shield{width:56px;height:56px;margin:0 auto 1.5rem;border-radius:16px;background:#2b2f33;display:flex;align-items:center;justify-content:center}
.shield svg{width:28px;height:28px;color:#f68330}
.code{font-size:.8rem;font-weight:500;color:#f68330;font-family:ui-monospace,monospace;letter-spacing:.05em;margin-bottom:.5rem}
h1{font-size:1.5rem;font-weight:600;color:#f4f4f5;margin-bottom:.5rem}
p{font-size:.95rem;line-height:1.5;color:#9ca3af;margin-bottom:1.75rem}
.domain{display:inline-block;background:#25282d;border:1px solid #32363d;border-radius:6px;padding:.15rem .5rem;font-family:ui-monospace,monospace;font-size:.85rem;color:#d1d5db}
.footer{font-size:.7rem;color:#6b7280;margin-top:2rem;letter-spacing:.03em}
.footer a{color:#6b7280;text-decoration:none}
.footer a:hover{color:#9ca3af}
</style>
</head>
<body>
<div class="c">
<div class="shield"><svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor"><path stroke-linecap="round" stroke-linejoin="round" d="M12 9v3.75m0-10.036A11.959 11.959 0 0 1 3.598 6 11.99 11.99 0 0 0 3 9.75c0 5.592 3.824 10.29 9 11.622 5.176-1.332 9-6.03 9-11.622 0-1.31-.21-2.571-.598-3.751A11.96 11.96 0 0 0 12 3.714Z"/></svg></div>
<div class="code">403 BLOCKED</div>
<h1>Access Denied</h1>
<p>This connection to <span class="domain">%s</span> has been blocked by your organization's network policy.</p>
<div class="footer">Protected by <a href="https://netbird.io" target="_blank" rel="noopener">NetBird</a></div>
</div>
</body>
</html>`
// singleConnListener is a net.Listener that yields a single connection.
type singleConnListener struct {
conn net.Conn
once sync.Once
ch chan struct{}
}
func (l *singleConnListener) Accept() (net.Conn, error) {
var accepted bool
l.once.Do(func() {
l.ch = make(chan struct{})
accepted = true
})
if accepted {
return l.conn, nil
}
// Block until Close
<-l.ch
return nil, net.ErrClosed
}
func (l *singleConnListener) Close() error {
l.once.Do(func() {
l.ch = make(chan struct{})
})
select {
case <-l.ch:
default:
close(l.ch)
}
return nil
}
func (l *singleConnListener) Addr() net.Addr {
return l.conn.LocalAddr()
}
type readWriter struct {
io.Reader
io.Writer
}
func mergeReadWriter(r io.Reader, w io.Writer) io.ReadWriter {
return &readWriter{Reader: r, Writer: w}
}
// relayRW copies data bidirectionally between two ReadWriters.
func relayRW(ctx context.Context, a, b io.ReadWriter) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
errCh := make(chan error, 2)
go func() {
_, err := io.Copy(b, a)
cancel()
errCh <- err
}()
go func() {
_, err := io.Copy(a, b)
cancel()
errCh <- err
}()
var firstErr error
for range 2 {
if err := <-errCh; err != nil && firstErr == nil {
if !isClosedErr(err) {
firstErr = err
}
}
}
return firstErr
}
// hopByHopHeaders are HTTP/1.1 headers that apply to a single connection
// and must not be forwarded by a proxy (RFC 7230, Section 6.1).
var hopByHopHeaders = []string{
"Connection",
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"TE",
"Trailers",
"Transfer-Encoding",
"Upgrade",
}
// removeHopByHopHeaders strips hop-by-hop headers from h.
// Also removes headers listed in the Connection header value.
func removeHopByHopHeaders(h http.Header) {
// First, remove any headers named in the Connection header
for _, connHeader := range h["Connection"] {
for _, name := range strings.Split(connHeader, ",") {
h.Del(strings.TrimSpace(name))
}
}
for _, name := range hopByHopHeaders {
h.Del(name)
}
}

479
client/inspect/icap.go Normal file
View File

@@ -0,0 +1,479 @@
package inspect
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"net/http"
"net/textproto"
"net/url"
"strconv"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
const (
icapVersion = "ICAP/1.0"
icapDefaultPort = "1344"
icapConnTimeout = 30 * time.Second
icapRWTimeout = 60 * time.Second
icapMaxPoolSize = 8
icapIdleTimeout = 60 * time.Second
icapMaxRespSize = 4 * 1024 * 1024 // 4 MB
)
// ICAPClient implements an ICAP (RFC 3507) client with persistent connection pooling.
type ICAPClient struct {
reqModURL *url.URL
respModURL *url.URL
pool chan *icapConn
mu sync.Mutex
log *log.Entry
maxPool int
}
type icapConn struct {
conn net.Conn
reader *bufio.Reader
lastUse time.Time
}
// NewICAPClient creates an ICAP client. Either or both URLs may be nil
// to disable that mode.
func NewICAPClient(logger *log.Entry, cfg *ICAPConfig) *ICAPClient {
maxPool := cfg.MaxConnections
if maxPool <= 0 {
maxPool = icapMaxPoolSize
}
return &ICAPClient{
reqModURL: cfg.ReqModURL,
respModURL: cfg.RespModURL,
pool: make(chan *icapConn, maxPool),
log: logger,
maxPool: maxPool,
}
}
// ReqMod sends an HTTP request to the ICAP REQMOD service for inspection.
// Returns the (possibly modified) request, or the original if ICAP returns 204.
// Returns nil, nil if REQMOD is not configured.
func (c *ICAPClient) ReqMod(req *http.Request) (*http.Request, error) {
if c.reqModURL == nil {
return req, nil
}
var reqBuf bytes.Buffer
if err := req.Write(&reqBuf); err != nil {
return nil, fmt.Errorf("serialize request: %w", err)
}
respBody, err := c.send("REQMOD", c.reqModURL, reqBuf.Bytes(), nil)
if err != nil {
return nil, err
}
if respBody == nil {
return req, nil
}
modified, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(respBody)))
if err != nil {
return nil, fmt.Errorf("parse ICAP modified request: %w", err)
}
return modified, nil
}
// RespMod sends an HTTP response to the ICAP RESPMOD service for inspection.
// Returns the (possibly modified) response, or the original if ICAP returns 204.
// Returns nil, nil if RESPMOD is not configured.
func (c *ICAPClient) RespMod(req *http.Request, resp *http.Response) (*http.Response, error) {
if c.respModURL == nil {
return resp, nil
}
var reqBuf bytes.Buffer
if err := req.Write(&reqBuf); err != nil {
return nil, fmt.Errorf("serialize request: %w", err)
}
var respBuf bytes.Buffer
if err := resp.Write(&respBuf); err != nil {
return nil, fmt.Errorf("serialize response: %w", err)
}
respBody, err := c.send("RESPMOD", c.respModURL, reqBuf.Bytes(), respBuf.Bytes())
if err != nil {
return nil, err
}
if respBody == nil {
// 204 No Content: ICAP server didn't modify the response.
// Reconstruct from the buffered copy since resp.Body was consumed by Write.
reconstructed, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respBuf.Bytes())), req)
if err != nil {
return nil, fmt.Errorf("reconstruct response after ICAP 204: %w", err)
}
return reconstructed, nil
}
modified, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respBody)), req)
if err != nil {
return nil, fmt.Errorf("parse ICAP modified response: %w", err)
}
return modified, nil
}
// Close drains and closes all pooled connections.
func (c *ICAPClient) Close() {
close(c.pool)
for ic := range c.pool {
if err := ic.conn.Close(); err != nil {
c.log.Debugf("close ICAP connection: %v", err)
}
}
}
// send executes an ICAP request and returns the encapsulated body from the response.
// Returns nil body for 204 No Content (no modification).
// Retries once on stale pooled connection (EOF on read).
func (c *ICAPClient) send(method string, serviceURL *url.URL, reqData, respData []byte) ([]byte, error) {
statusCode, headers, body, err := c.trySend(method, serviceURL, reqData, respData)
if err != nil && isStaleConnErr(err) {
// Retry once with a fresh connection (stale pool entry).
c.log.Debugf("ICAP %s: retrying after stale connection: %v", method, err)
statusCode, headers, body, err = c.trySend(method, serviceURL, reqData, respData)
}
if err != nil {
return nil, err
}
switch statusCode {
case 204:
return nil, nil
case 200:
return body, nil
default:
c.log.Debugf("ICAP %s returned status %d, headers: %v", method, statusCode, headers)
return nil, fmt.Errorf("ICAP %s: status %d", method, statusCode)
}
}
func (c *ICAPClient) trySend(method string, serviceURL *url.URL, reqData, respData []byte) (int, textproto.MIMEHeader, []byte, error) {
ic, err := c.getConn(serviceURL)
if err != nil {
return 0, nil, nil, fmt.Errorf("get ICAP connection: %w", err)
}
if err := c.writeRequest(ic, method, serviceURL, reqData, respData); err != nil {
if closeErr := ic.conn.Close(); closeErr != nil {
c.log.Debugf("close ICAP conn after write error: %v", closeErr)
}
return 0, nil, nil, fmt.Errorf("write ICAP %s: %w", method, err)
}
statusCode, headers, body, err := c.readResponse(ic)
if err != nil {
if closeErr := ic.conn.Close(); closeErr != nil {
c.log.Debugf("close ICAP conn after read error: %v", closeErr)
}
return 0, nil, nil, fmt.Errorf("read ICAP response: %w", err)
}
c.putConn(ic)
return statusCode, headers, body, nil
}
func isStaleConnErr(err error) bool {
if err == nil {
return false
}
s := err.Error()
return strings.Contains(s, "EOF") || strings.Contains(s, "broken pipe") || strings.Contains(s, "connection reset")
}
func (c *ICAPClient) writeRequest(ic *icapConn, method string, serviceURL *url.URL, reqData, respData []byte) error {
if err := ic.conn.SetWriteDeadline(time.Now().Add(icapRWTimeout)); err != nil {
return fmt.Errorf("set write deadline: %w", err)
}
// For RESPMOD, split the serialized HTTP response into headers and body.
// The body must be sent chunked per RFC 3507.
var respHdr, respBody []byte
if respData != nil {
if idx := bytes.Index(respData, []byte("\r\n\r\n")); idx >= 0 {
respHdr = respData[:idx+4] // include the \r\n\r\n separator
respBody = respData[idx+4:]
} else {
respHdr = respData
}
}
var buf bytes.Buffer
// Request line
fmt.Fprintf(&buf, "%s %s %s\r\n", method, serviceURL.String(), icapVersion)
// Headers
host := serviceURL.Host
fmt.Fprintf(&buf, "Host: %s\r\n", host)
fmt.Fprintf(&buf, "Connection: keep-alive\r\n")
fmt.Fprintf(&buf, "Allow: 204\r\n")
// Build Encapsulated header
offset := 0
var encapParts []string
if reqData != nil {
encapParts = append(encapParts, fmt.Sprintf("req-hdr=%d", offset))
offset += len(reqData)
}
if respHdr != nil {
encapParts = append(encapParts, fmt.Sprintf("res-hdr=%d", offset))
offset += len(respHdr)
}
if len(respBody) > 0 {
encapParts = append(encapParts, fmt.Sprintf("res-body=%d", offset))
} else {
encapParts = append(encapParts, fmt.Sprintf("null-body=%d", offset))
}
fmt.Fprintf(&buf, "Encapsulated: %s\r\n", strings.Join(encapParts, ", "))
fmt.Fprintf(&buf, "\r\n")
// Encapsulated sections
if reqData != nil {
buf.Write(reqData)
}
if respHdr != nil {
buf.Write(respHdr)
}
// Body in chunked encoding (only when there is an actual body section).
// Per RFC 3507 Section 4.4.1, null-body must not include any entity data.
if len(respBody) > 0 {
fmt.Fprintf(&buf, "%x\r\n", len(respBody))
buf.Write(respBody)
buf.WriteString("\r\n")
buf.WriteString("0\r\n\r\n")
}
_, err := ic.conn.Write(buf.Bytes())
return err
}
func (c *ICAPClient) readResponse(ic *icapConn) (int, textproto.MIMEHeader, []byte, error) {
if err := ic.conn.SetReadDeadline(time.Now().Add(icapRWTimeout)); err != nil {
return 0, nil, nil, fmt.Errorf("set read deadline: %w", err)
}
tp := textproto.NewReader(ic.reader)
// Status line: "ICAP/1.0 200 OK"
statusLine, err := tp.ReadLine()
if err != nil {
return 0, nil, nil, fmt.Errorf("read status line: %w", err)
}
statusCode, err := parseICAPStatus(statusLine)
if err != nil {
return 0, nil, nil, err
}
// Headers
headers, err := tp.ReadMIMEHeader()
if err != nil {
return statusCode, nil, nil, fmt.Errorf("read ICAP headers: %w", err)
}
if statusCode == 204 {
return statusCode, headers, nil, nil
}
// Read encapsulated body based on Encapsulated header
body, err := c.readEncapsulatedBody(ic.reader, headers)
if err != nil {
return statusCode, headers, nil, fmt.Errorf("read encapsulated body: %w", err)
}
return statusCode, headers, body, nil
}
func (c *ICAPClient) readEncapsulatedBody(r *bufio.Reader, headers textproto.MIMEHeader) ([]byte, error) {
encap := headers.Get("Encapsulated")
if encap == "" {
return nil, nil
}
// Find the body offset from the Encapsulated header.
// The last section with a non-zero offset is the body.
// Read everything from the reader as the encapsulated content.
var totalSize int
parts := strings.Split(encap, ",")
for _, part := range parts {
part = strings.TrimSpace(part)
eqIdx := strings.Index(part, "=")
if eqIdx < 0 {
continue
}
offset, err := strconv.Atoi(strings.TrimSpace(part[eqIdx+1:]))
if err != nil {
continue
}
if offset > totalSize {
totalSize = offset
}
}
// Read all available encapsulated data (headers + body)
// The body section uses chunked encoding per RFC 3507
var buf bytes.Buffer
if totalSize > 0 {
// Read the header sections (everything before the body offset)
headerBytes := make([]byte, totalSize)
if _, err := io.ReadFull(r, headerBytes); err != nil {
return nil, fmt.Errorf("read encapsulated headers: %w", err)
}
buf.Write(headerBytes)
}
// Read chunked body
chunked := newChunkedReader(r)
body, err := io.ReadAll(io.LimitReader(chunked, icapMaxRespSize))
if err != nil {
return nil, fmt.Errorf("read chunked body: %w", err)
}
buf.Write(body)
return buf.Bytes(), nil
}
func (c *ICAPClient) getConn(serviceURL *url.URL) (*icapConn, error) {
// Try to get a pooled connection
for {
select {
case ic := <-c.pool:
if time.Since(ic.lastUse) > icapIdleTimeout {
if err := ic.conn.Close(); err != nil {
c.log.Debugf("close idle ICAP connection: %v", err)
}
continue
}
return ic, nil
default:
return c.dialConn(serviceURL)
}
}
}
func (c *ICAPClient) putConn(ic *icapConn) {
c.mu.Lock()
defer c.mu.Unlock()
ic.lastUse = time.Now()
select {
case c.pool <- ic:
default:
// Pool full, close connection.
if err := ic.conn.Close(); err != nil {
c.log.Debugf("close excess ICAP connection: %v", err)
}
}
}
func (c *ICAPClient) dialConn(serviceURL *url.URL) (*icapConn, error) {
host := serviceURL.Host
if _, _, err := net.SplitHostPort(host); err != nil {
host = net.JoinHostPort(host, icapDefaultPort)
}
conn, err := net.DialTimeout("tcp", host, icapConnTimeout)
if err != nil {
return nil, fmt.Errorf("dial ICAP %s: %w", host, err)
}
return &icapConn{
conn: conn,
reader: bufio.NewReader(conn),
lastUse: time.Now(),
}, nil
}
func parseICAPStatus(line string) (int, error) {
// "ICAP/1.0 200 OK"
parts := strings.SplitN(line, " ", 3)
if len(parts) < 2 {
return 0, fmt.Errorf("malformed ICAP status line: %q", line)
}
code, err := strconv.Atoi(parts[1])
if err != nil {
return 0, fmt.Errorf("parse ICAP status code %q: %w", parts[1], err)
}
return code, nil
}
// chunkedReader reads ICAP chunked encoding (same as HTTP chunked, terminated by "0\r\n\r\n").
type chunkedReader struct {
r *bufio.Reader
remaining int
done bool
}
func newChunkedReader(r *bufio.Reader) *chunkedReader {
return &chunkedReader{r: r}
}
func (cr *chunkedReader) Read(p []byte) (int, error) {
if cr.done {
return 0, io.EOF
}
if cr.remaining == 0 {
// Read chunk size line
line, err := cr.r.ReadString('\n')
if err != nil {
return 0, err
}
line = strings.TrimSpace(line)
// Strip any chunk extensions
if idx := strings.Index(line, ";"); idx >= 0 {
line = line[:idx]
}
size, err := strconv.ParseInt(line, 16, 64)
if err != nil {
return 0, fmt.Errorf("parse chunk size %q: %w", line, err)
}
if size == 0 {
cr.done = true
// Consume trailing \r\n
_, _ = cr.r.ReadString('\n')
return 0, io.EOF
}
if size < 0 || size > icapMaxRespSize {
return 0, fmt.Errorf("chunk size %d out of range (max %d)", size, icapMaxRespSize)
}
cr.remaining = int(size)
}
toRead := len(p)
if toRead > cr.remaining {
toRead = cr.remaining
}
n, err := cr.r.Read(p[:toRead])
cr.remaining -= n
if cr.remaining == 0 {
// Consume chunk-terminating \r\n
_, _ = cr.r.ReadString('\n')
}
return n, err
}

View File

@@ -0,0 +1,21 @@
//go:build !linux
package inspect
import (
"fmt"
"net"
"net/netip"
log "github.com/sirupsen/logrus"
)
// newTPROXYListener is not supported on non-Linux platforms.
func newTPROXYListener(_ *log.Entry, addr netip.AddrPort, _ netip.Prefix) (net.Listener, error) {
return nil, fmt.Errorf("TPROXY listener not supported on this platform (requested %s)", addr)
}
// getOriginalDst is not supported on non-Linux platforms.
func getOriginalDst(_ net.Conn) (netip.AddrPort, error) {
return netip.AddrPort{}, fmt.Errorf("SO_ORIGINAL_DST not supported on this platform")
}

View File

@@ -0,0 +1,89 @@
package inspect
import (
"fmt"
"net"
"net/netip"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
// newTPROXYListener creates a TCP listener for the transparent proxy.
// After nftables REDIRECT, accepted connections have LocalAddr = WG_IP:proxy_port.
// The original destination is retrieved via getsockopt(SO_ORIGINAL_DST).
func newTPROXYListener(logger *log.Entry, addr netip.AddrPort, _ netip.Prefix) (net.Listener, error) {
ln, err := net.Listen("tcp", addr.String())
if err != nil {
return nil, fmt.Errorf("listen on %s: %w", addr, err)
}
logger.Infof("inspect: listener started on %s", ln.Addr())
return ln, nil
}
// getOriginalDst reads the original destination from conntrack via SO_ORIGINAL_DST.
// This is set by the kernel when the connection was REDIRECT'd/DNAT'd.
// Tries IPv4 first, then falls back to IPv6 (IP6T_SO_ORIGINAL_DST).
func getOriginalDst(conn net.Conn) (netip.AddrPort, error) {
tc, ok := conn.(*net.TCPConn)
if !ok {
return netip.AddrPort{}, fmt.Errorf("not a TCPConn")
}
raw, err := tc.SyscallConn()
if err != nil {
return netip.AddrPort{}, fmt.Errorf("get syscall conn: %w", err)
}
var origDst netip.AddrPort
var sockErr error
if err := raw.Control(func(fd uintptr) {
// Try IPv4 first (SO_ORIGINAL_DST = 80)
var sa4 unix.RawSockaddrInet4
sa4Len := uint32(unsafe.Sizeof(sa4))
_, _, errno := unix.Syscall6(
unix.SYS_GETSOCKOPT,
fd,
unix.SOL_IP,
80, // SO_ORIGINAL_DST
uintptr(unsafe.Pointer(&sa4)),
uintptr(unsafe.Pointer(&sa4Len)),
0,
)
if errno == 0 {
addr := netip.AddrFrom4(sa4.Addr)
port := uint16(sa4.Port>>8) | uint16(sa4.Port<<8)
origDst = netip.AddrPortFrom(addr.Unmap(), port)
return
}
// Fall back to IPv6 (IP6T_SO_ORIGINAL_DST = 80 on SOL_IPV6)
var sa6 unix.RawSockaddrInet6
sa6Len := uint32(unsafe.Sizeof(sa6))
_, _, errno = unix.Syscall6(
unix.SYS_GETSOCKOPT,
fd,
unix.SOL_IPV6,
80, // IP6T_SO_ORIGINAL_DST
uintptr(unsafe.Pointer(&sa6)),
uintptr(unsafe.Pointer(&sa6Len)),
0,
)
if errno != 0 {
sockErr = fmt.Errorf("getsockopt SO_ORIGINAL_DST (v4 and v6): %w", errno)
return
}
addr := netip.AddrFrom16(sa6.Addr)
port := uint16(sa6.Port>>8) | uint16(sa6.Port<<8)
origDst = netip.AddrPortFrom(addr.Unmap(), port)
}); err != nil {
return netip.AddrPort{}, fmt.Errorf("control raw conn: %w", err)
}
if sockErr != nil {
return netip.AddrPort{}, sockErr
}
return origDst, nil
}

200
client/inspect/mitm.go Normal file
View File

@@ -0,0 +1,200 @@
package inspect
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"math/big"
mrand "math/rand/v2"
"sync"
"time"
)
const (
// certCacheSize is the maximum number of cached leaf certificates.
certCacheSize = 1024
// certTTL is how long generated certificates remain valid.
certTTL = 24 * time.Hour
)
// certCache is a bounded LRU cache for generated TLS certificates.
type certCache struct {
mu sync.Mutex
entries map[string]*certEntry
// order tracks LRU eviction, most recent at end.
order []string
maxSize int
}
type certEntry struct {
cert *tls.Certificate
expiresAt time.Time
}
func newCertCache(maxSize int) *certCache {
return &certCache{
entries: make(map[string]*certEntry, maxSize),
order: make([]string, 0, maxSize),
maxSize: maxSize,
}
}
func (c *certCache) get(hostname string) (*tls.Certificate, bool) {
c.mu.Lock()
defer c.mu.Unlock()
entry, ok := c.entries[hostname]
if !ok {
return nil, false
}
if time.Now().After(entry.expiresAt) {
c.removeLocked(hostname)
return nil, false
}
// Move to end (most recently used)
c.touchLocked(hostname)
return entry.cert, true
}
func (c *certCache) put(hostname string, cert *tls.Certificate) {
c.mu.Lock()
defer c.mu.Unlock()
// Jitter the TTL by +/- 20% to prevent thundering herd on expiry.
jitter := time.Duration(float64(certTTL) * (0.8 + 0.4*mrand.Float64()))
if _, exists := c.entries[hostname]; exists {
c.entries[hostname] = &certEntry{
cert: cert,
expiresAt: time.Now().Add(jitter),
}
c.touchLocked(hostname)
return
}
// Evict oldest if at capacity
for len(c.entries) >= c.maxSize && len(c.order) > 0 {
c.removeLocked(c.order[0])
}
c.entries[hostname] = &certEntry{
cert: cert,
expiresAt: time.Now().Add(jitter),
}
c.order = append(c.order, hostname)
}
func (c *certCache) touchLocked(hostname string) {
for i, h := range c.order {
if h == hostname {
c.order = append(c.order[:i], c.order[i+1:]...)
c.order = append(c.order, hostname)
return
}
}
}
func (c *certCache) removeLocked(hostname string) {
delete(c.entries, hostname)
for i, h := range c.order {
if h == hostname {
c.order = append(c.order[:i], c.order[i+1:]...)
return
}
}
}
// CertProvider generates TLS certificates on the fly, signed by a CA.
// Generated certificates are cached in an LRU cache.
type CertProvider struct {
ca *x509.Certificate
caKey crypto.PrivateKey
cache *certCache
}
// NewCertProvider creates a certificate provider using the given CA.
func NewCertProvider(ca *x509.Certificate, caKey crypto.PrivateKey) *CertProvider {
return &CertProvider{
ca: ca,
caKey: caKey,
cache: newCertCache(certCacheSize),
}
}
// GetCertificate returns a TLS certificate for the given hostname,
// generating and caching one if necessary.
func (p *CertProvider) GetCertificate(hostname string) (*tls.Certificate, error) {
if cert, ok := p.cache.get(hostname); ok {
return cert, nil
}
cert, err := p.generateCert(hostname)
if err != nil {
return nil, fmt.Errorf("generate cert for %s: %w", hostname, err)
}
p.cache.put(hostname, cert)
return cert, nil
}
// GetTLSConfig returns a tls.Config that dynamically provides certificates
// for any hostname using the MITM CA.
func (p *CertProvider) GetTLSConfig() *tls.Config {
return &tls.Config{
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
return p.GetCertificate(hello.ServerName)
},
NextProtos: []string{"h2", "http/1.1"},
MinVersion: tls.VersionTLS12,
}
}
func (p *CertProvider) generateCert(hostname string) (*tls.Certificate, error) {
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
if err != nil {
return nil, fmt.Errorf("generate serial number: %w", err)
}
now := time.Now()
template := &x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: hostname,
},
NotBefore: now.Add(-5 * time.Minute),
NotAfter: now.Add(certTTL),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageServerAuth,
},
DNSNames: []string{hostname},
}
leafKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, fmt.Errorf("generate leaf key: %w", err)
}
certDER, err := x509.CreateCertificate(rand.Reader, template, p.ca, &leafKey.PublicKey, p.caKey)
if err != nil {
return nil, fmt.Errorf("sign leaf certificate: %w", err)
}
leafCert, err := x509.ParseCertificate(certDER)
if err != nil {
return nil, fmt.Errorf("parse generated certificate: %w", err)
}
return &tls.Certificate{
Certificate: [][]byte{certDER, p.ca.Raw},
PrivateKey: leafKey,
Leaf: leafCert,
}, nil
}

133
client/inspect/mitm_test.go Normal file
View File

@@ -0,0 +1,133 @@
package inspect
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func generateTestCA(t *testing.T) (*x509.Certificate, *ecdsa.PrivateKey) {
t.Helper()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
CommonName: "Test CA",
},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(24 * time.Hour),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
IsCA: true,
}
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
require.NoError(t, err)
cert, err := x509.ParseCertificate(certDER)
require.NoError(t, err)
return cert, key
}
func TestCertProvider_GetCertificate(t *testing.T) {
ca, caKey := generateTestCA(t)
provider := NewCertProvider(ca, caKey)
cert, err := provider.GetCertificate("example.com")
require.NoError(t, err)
require.NotNil(t, cert)
// Verify the leaf certificate
assert.Equal(t, "example.com", cert.Leaf.Subject.CommonName)
assert.Contains(t, cert.Leaf.DNSNames, "example.com")
// Verify chain: leaf + CA
assert.Len(t, cert.Certificate, 2)
// Verify leaf is signed by our CA
pool := x509.NewCertPool()
pool.AddCert(ca)
_, err = cert.Leaf.Verify(x509.VerifyOptions{
Roots: pool,
})
require.NoError(t, err)
}
func TestCertProvider_CachesResults(t *testing.T) {
ca, caKey := generateTestCA(t)
provider := NewCertProvider(ca, caKey)
cert1, err := provider.GetCertificate("cached.example.com")
require.NoError(t, err)
cert2, err := provider.GetCertificate("cached.example.com")
require.NoError(t, err)
// Same pointer = cached
assert.Equal(t, cert1, cert2)
}
func TestCertProvider_DifferentHostsDifferentCerts(t *testing.T) {
ca, caKey := generateTestCA(t)
provider := NewCertProvider(ca, caKey)
cert1, err := provider.GetCertificate("a.example.com")
require.NoError(t, err)
cert2, err := provider.GetCertificate("b.example.com")
require.NoError(t, err)
assert.NotEqual(t, cert1.Leaf.SerialNumber, cert2.Leaf.SerialNumber)
}
func TestCertProvider_TLSConfigHandshake(t *testing.T) {
ca, caKey := generateTestCA(t)
provider := NewCertProvider(ca, caKey)
tlsConfig := provider.GetTLSConfig()
require.NotNil(t, tlsConfig)
require.NotNil(t, tlsConfig.GetCertificate)
// Simulate a ClientHelloInfo
hello := &tls.ClientHelloInfo{
ServerName: "handshake.example.com",
}
cert, err := tlsConfig.GetCertificate(hello)
require.NoError(t, err)
assert.Equal(t, "handshake.example.com", cert.Leaf.Subject.CommonName)
}
func TestCertCache_Eviction(t *testing.T) {
cache := newCertCache(3)
for i := range 5 {
hostname := string(rune('a'+i)) + ".example.com"
cache.put(hostname, &tls.Certificate{})
}
// Only 3 should remain (c, d, e - the most recent)
assert.Len(t, cache.entries, 3)
_, ok := cache.get("a.example.com")
assert.False(t, ok, "oldest entry should be evicted")
_, ok = cache.get("b.example.com")
assert.False(t, ok, "second oldest should be evicted")
_, ok = cache.get("e.example.com")
assert.True(t, ok, "newest entry should exist")
}

109
client/inspect/peek.go Normal file
View File

@@ -0,0 +1,109 @@
package inspect
import (
"bytes"
"fmt"
"io"
"net"
)
// peekConn wraps a net.Conn with a buffer that allows reading ahead
// without consuming data. Subsequent Read calls return the buffered
// bytes first, then read from the underlying connection.
type peekConn struct {
net.Conn
buf bytes.Buffer
// peeked holds the raw bytes that were peeked, available for replay.
peeked []byte
}
// newPeekConn wraps conn for peek-ahead reading.
func newPeekConn(conn net.Conn) *peekConn {
return &peekConn{Conn: conn}
}
// Peek reads exactly n bytes from the connection without consuming them.
// The peeked bytes are replayed on subsequent Read calls.
// Peek may only be called once; calling it again returns an error.
func (c *peekConn) Peek(n int) ([]byte, error) {
if c.peeked != nil {
return nil, fmt.Errorf("peek already called")
}
buf := make([]byte, n)
if _, err := io.ReadFull(c.Conn, buf); err != nil {
return nil, fmt.Errorf("peek %d bytes: %w", n, err)
}
c.peeked = buf
c.buf.Write(buf)
return buf, nil
}
// PeekAll reads up to n bytes, returning whatever is available.
// Unlike Peek, it does not require exactly n bytes.
func (c *peekConn) PeekAll(n int) ([]byte, error) {
if c.peeked != nil {
return nil, fmt.Errorf("peek already called")
}
buf := make([]byte, n)
nr, err := c.Conn.Read(buf)
if nr > 0 {
c.peeked = buf[:nr]
c.buf.Write(c.peeked)
}
if err != nil && nr == 0 {
return nil, fmt.Errorf("peek: %w", err)
}
return c.peeked, nil
}
// PeekMore extends the peeked buffer to at least n total bytes.
// The buffer is reset and refilled with the extended data.
// The returned slice is the internal peeked buffer; callers must not
// retain references from prior Peek/PeekMore calls after calling this.
func (c *peekConn) PeekMore(n int) ([]byte, error) {
if len(c.peeked) >= n {
return c.peeked[:n], nil
}
remaining := n - len(c.peeked)
extra := make([]byte, remaining)
if _, err := io.ReadFull(c.Conn, extra); err != nil {
return nil, fmt.Errorf("peek more %d bytes: %w", remaining, err)
}
// Pre-allocate to avoid reallocation detaching previously returned slices.
combined := make([]byte, 0, n)
combined = append(combined, c.peeked...)
combined = append(combined, extra...)
c.peeked = combined
c.buf.Reset()
c.buf.Write(c.peeked)
return c.peeked, nil
}
// Peeked returns the bytes that were peeked so far, or nil if Peek hasn't been called.
func (c *peekConn) Peeked() []byte {
return c.peeked
}
// Read returns buffered peek data first, then reads from the underlying connection.
func (c *peekConn) Read(p []byte) (int, error) {
if c.buf.Len() > 0 {
return c.buf.Read(p)
}
return c.Conn.Read(p)
}
// reader returns an io.Reader that replays buffered bytes then reads from conn.
func (c *peekConn) reader() io.Reader {
if c.buf.Len() > 0 {
return io.MultiReader(&c.buf, c.Conn)
}
return c.Conn
}

482
client/inspect/proxy.go Normal file
View File

@@ -0,0 +1,482 @@
package inspect
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
// ErrBlocked is returned when a connection is denied by proxy policy.
var ErrBlocked = errors.New("connection blocked by proxy policy")
const (
// headerReadTimeout is the deadline for reading the initial protocol header.
// Prevents slow loris attacks where a client opens a connection but sends data slowly.
headerReadTimeout = 10 * time.Second
// idleTimeout is the deadline for idle connections between HTTP requests.
idleTimeout = 120 * time.Second
)
// Proxy is the inspection engine for traffic passing through a NetBird
// routing peer. It handles protocol detection, rule evaluation, MITM TLS
// decryption, ICAP delegation, and external proxy forwarding.
type Proxy struct {
config Config
rules *RuleEngine
certs *CertProvider
icap *ICAPClient
// envoy is nil unless mode is ModeEnvoy.
envoy *envoyManager
// dialer is the outbound dialer (with SO_MARK cleared on Linux).
dialer net.Dialer
log *log.Entry
// wgNetwork is the WG overlay prefix; dial targets inside it are blocked.
wgNetwork netip.Prefix
// localIPs reports the routing peer's own IPs; dial targets are blocked.
localIPs LocalIPChecker
// listener is the TPROXY/REDIRECT listener for kernel mode.
listener net.Listener
mu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
}
// LocalIPChecker reports whether an IP belongs to the local machine.
type LocalIPChecker interface {
IsLocalIP(netip.Addr) bool
}
// New creates a transparent proxy with the given configuration.
func New(ctx context.Context, logger *log.Entry, config Config) (*Proxy, error) {
ctx, cancel := context.WithCancel(ctx)
p := &Proxy{
config: config,
rules: NewRuleEngine(logger, config.DefaultAction),
dialer: newOutboundDialer(),
log: logger,
wgNetwork: config.WGNetwork,
localIPs: config.LocalIPChecker,
ctx: ctx,
cancel: cancel,
}
p.rules.UpdateRules(config.Rules, config.DefaultAction)
// Initialize MITM certificate provider
if config.TLS != nil {
p.certs = NewCertProvider(config.TLS.CA, config.TLS.CAKey)
}
// Initialize ICAP client
if config.ICAP != nil {
p.icap = NewICAPClient(logger, config.ICAP)
}
// Start envoy sidecar if configured
if config.Mode == ModeEnvoy {
envoyLog := logger.WithField("sidecar", "envoy")
em, err := startEnvoy(ctx, envoyLog, config)
if err != nil {
cancel()
return nil, fmt.Errorf("start envoy sidecar: %w", err)
}
p.envoy = em
}
// Start TPROXY listener for kernel mode
if config.ListenAddr.IsValid() {
ln, err := newTPROXYListener(logger, config.ListenAddr, netip.Prefix{})
if err != nil {
cancel()
return nil, fmt.Errorf("start TPROXY listener on %s: %w", config.ListenAddr, err)
}
p.listener = ln
go p.acceptLoop(ln)
}
return p, nil
}
// HandleTCP is the entry point for TCP connections from the userspace forwarder.
// It determines the protocol (TLS or plaintext HTTP), evaluates rules,
// and either blocks, passes through, inspects, or forwards to an external proxy.
func (p *Proxy) HandleTCP(ctx context.Context, clientConn net.Conn, dst netip.AddrPort, src SourceInfo) error {
defer func() {
if err := clientConn.Close(); err != nil {
p.log.Debugf("close client conn: %v", err)
}
}()
p.mu.RLock()
mode := p.config.Mode
p.mu.RUnlock()
if mode == ModeExternal {
pconn := newPeekConn(clientConn)
return p.handleExternal(ctx, pconn, dst)
}
// Envoy and builtin modes both peek the protocol header for rule evaluation.
// Envoy mode forwards non-blocked traffic to envoy; builtin mode handles all locally.
// TLS blocks are handled by Go (instant close) since envoy can't cleanly RST a TLS connection.
// Built-in and envoy mode: peek 5 bytes (TLS record header size) to determine protocol.
// Set a read deadline to prevent slow loris attacks.
if err := clientConn.SetReadDeadline(time.Now().Add(headerReadTimeout)); err != nil {
return fmt.Errorf("set read deadline: %w", err)
}
pconn := newPeekConn(clientConn)
header, err := pconn.Peek(5)
if err != nil {
return fmt.Errorf("peek protocol header: %w", err)
}
if err := clientConn.SetReadDeadline(time.Time{}); err != nil {
return fmt.Errorf("clear read deadline: %w", err)
}
if isTLSHandshake(header[0]) {
return p.handleTLS(ctx, pconn, dst, src)
}
if isHTTPMethod(header) {
return p.handlePlainHTTP(ctx, pconn, dst, src)
}
// Not TLS and not HTTP: evaluate rules with ProtoOther.
// If no rule explicitly allows "other", this falls through to the default action.
action := p.rules.Evaluate(src.IP, "", dst.Addr(), dst.Port(), ProtoOther, "")
if action == ActionAllow {
remote, err := p.dialTCP(ctx, dst)
if err != nil {
return fmt.Errorf("dial for passthrough: %w", err)
}
defer func() {
if err := remote.Close(); err != nil {
p.log.Debugf("close remote conn: %v", err)
}
}()
return relay(ctx, pconn, remote)
}
p.log.Debugf("block: non-HTTP/TLS to %s (action=%s, first bytes: %x)", dst, action, header)
return ErrBlocked
}
// InspectTCP evaluates rules for a TCP connection and returns the result.
// Unlike HandleTCP, it can return early for allow decisions, letting the caller
// handle the relay (USP forwarder passthrough optimization).
//
// When InspectResult.PassthroughConn is non-nil, ownership transfers to the caller:
// the caller must close the connection and relay traffic. The engine does not close it.
//
// When PassthroughConn is nil, the engine handled everything internally
// (block, inspect/MITM, or plain HTTP inspection) and closed the connection.
func (p *Proxy) InspectTCP(ctx context.Context, clientConn net.Conn, dst netip.AddrPort, src SourceInfo) (InspectResult, error) {
p.mu.RLock()
mode := p.config.Mode
envoy := p.envoy
p.mu.RUnlock()
// External mode: handle internally, engine owns the connection.
if mode == ModeExternal {
defer func() {
if err := clientConn.Close(); err != nil {
p.log.Debugf("close client conn: %v", err)
}
}()
pconn := newPeekConn(clientConn)
err := p.handleExternal(ctx, pconn, dst)
return InspectResult{Action: ActionAllow}, err
}
// Peek protocol header.
if err := clientConn.SetReadDeadline(time.Now().Add(headerReadTimeout)); err != nil {
clientConn.Close()
return InspectResult{}, fmt.Errorf("set read deadline: %w", err)
}
pconn := newPeekConn(clientConn)
header, err := pconn.Peek(5)
if err != nil {
clientConn.Close()
return InspectResult{}, fmt.Errorf("peek protocol header: %w", err)
}
if err := clientConn.SetReadDeadline(time.Time{}); err != nil {
clientConn.Close()
return InspectResult{}, fmt.Errorf("clear read deadline: %w", err)
}
// TLS: may return passthrough for allow.
if isTLSHandshake(header[0]) {
result, err := p.inspectTLS(ctx, pconn, dst, src)
if err != nil && result.PassthroughConn == nil {
clientConn.Close()
return result, err
}
// Envoy mode: forward allowed TLS to envoy instead of returning passthrough.
if result.PassthroughConn != nil && envoy != nil {
defer clientConn.Close()
envoyErr := p.forwardToEnvoy(ctx, pconn, dst, src, envoy)
return InspectResult{Action: ActionAllow}, envoyErr
}
return result, err
}
// Plain HTTP: in envoy mode, forward to envoy for L7 processing.
// In builtin mode, inspect per-request locally.
if isHTTPMethod(header) {
defer func() {
if err := clientConn.Close(); err != nil {
p.log.Debugf("close client conn: %v", err)
}
}()
if envoy != nil {
err := p.forwardToEnvoy(ctx, pconn, dst, src, envoy)
return InspectResult{Action: ActionAllow}, err
}
err := p.handlePlainHTTP(ctx, pconn, dst, src)
return InspectResult{Action: ActionInspect}, err
}
// Other protocol: evaluate rules.
action := p.rules.Evaluate(src.IP, "", dst.Addr(), dst.Port(), ProtoOther, "")
if action == ActionAllow {
// Envoy mode: forward to envoy.
if envoy != nil {
defer clientConn.Close()
err := p.forwardToEnvoy(ctx, pconn, dst, src, envoy)
return InspectResult{Action: ActionAllow}, err
}
return InspectResult{Action: ActionAllow, PassthroughConn: pconn}, nil
}
p.log.Debugf("block: non-HTTP/TLS to %s (action=%s, first bytes: %x)", dst, action, header)
clientConn.Close()
return InspectResult{Action: ActionBlock}, ErrBlocked
}
// HandleUDPPacket inspects a UDP packet for QUIC Initial packets.
// Returns the action to take: ActionAllow to continue normal forwarding,
// ActionBlock to drop the packet.
// Non-QUIC packets always return ActionAllow.
func (p *Proxy) HandleUDPPacket(data []byte, dst netip.AddrPort, src SourceInfo) Action {
if len(data) < 5 {
return ActionAllow
}
// Check for QUIC Long Header
if data[0]&0x80 == 0 {
return ActionAllow
}
sni, err := ExtractQUICSNI(data)
if err != nil {
// Can't parse QUIC, allow through (could be non-QUIC UDP)
p.log.Tracef("QUIC SNI extraction failed for %s: %v", dst, err)
return ActionAllow
}
if sni == "" {
return ActionAllow
}
action := p.rules.Evaluate(src.IP, sni, dst.Addr(), dst.Port(), ProtoH3, "")
if action == ActionBlock {
p.log.Debugf("block: QUIC to %s (SNI=%s)", dst, sni.PunycodeString())
return ActionBlock
}
// QUIC can't be MITMed, treat Inspect as Allow
if action == ActionInspect {
p.log.Debugf("allow: QUIC to %s (SNI=%s), MITM not supported for QUIC", dst, sni.PunycodeString())
} else {
p.log.Tracef("allow: QUIC to %s (SNI=%s)", dst, sni.PunycodeString())
}
return ActionAllow
}
// handlePlainHTTP handles plaintext HTTP connections.
func (p *Proxy) handlePlainHTTP(ctx context.Context, pconn *peekConn, dst netip.AddrPort, src SourceInfo) error {
remote, err := p.dialTCP(ctx, dst)
if err != nil {
return fmt.Errorf("dial %s: %w", dst, err)
}
defer func() {
if err := remote.Close(); err != nil {
p.log.Debugf("close remote for %s: %v", dst, err)
}
}()
// For plaintext HTTP, always inspect (we can see the traffic)
return p.inspectHTTP(ctx, pconn, remote, dst, "", src, "http/1.1")
}
// UpdateConfig replaces the inspection engine configuration at runtime.
func (p *Proxy) UpdateConfig(config Config) {
p.log.Debugf("config update: mode=%s rules=%d default=%s has_tls=%v has_icap=%v",
config.Mode, len(config.Rules), config.DefaultAction, config.TLS != nil, config.ICAP != nil)
p.mu.Lock()
p.config = config
p.rules.UpdateRules(config.Rules, config.DefaultAction)
// Update MITM provider
if config.TLS != nil {
p.certs = NewCertProvider(config.TLS.CA, config.TLS.CAKey)
} else {
p.certs = nil
}
// Swap ICAP client under lock, close the old one outside to avoid blocking.
var oldICAP *ICAPClient
if config.ICAP != nil {
oldICAP = p.icap
p.icap = NewICAPClient(p.log, config.ICAP)
} else {
oldICAP = p.icap
p.icap = nil
}
// If switching away from envoy mode, clear and stop the old envoy.
var oldEnvoy *envoyManager
if config.Mode != ModeEnvoy && p.envoy != nil {
oldEnvoy = p.envoy
p.envoy = nil
}
envoy := p.envoy
p.mu.Unlock()
if oldICAP != nil {
oldICAP.Close()
}
if oldEnvoy != nil {
oldEnvoy.Stop()
}
// Reload envoy config if still in envoy mode.
if envoy != nil && config.Mode == ModeEnvoy {
if err := envoy.Reload(config); err != nil {
p.log.Errorf("inspect: envoy config reload: %v", err)
}
}
}
// Mode returns the current proxy operating mode.
func (p *Proxy) Mode() ProxyMode {
p.mu.RLock()
defer p.mu.RUnlock()
return p.config.Mode
}
// ListenPort returns the port to use for kernel-mode nftables REDIRECT.
// For builtin mode: the TPROXY listener port.
// For envoy mode: the envoy listener port (nftables redirects directly to envoy).
// Returns 0 if no listener is active.
func (p *Proxy) ListenPort() uint16 {
p.mu.RLock()
envoy := p.envoy
p.mu.RUnlock()
if envoy != nil {
return envoy.listenPort
}
if p.listener == nil {
return 0
}
tcpAddr, ok := p.listener.Addr().(*net.TCPAddr)
if !ok {
return 0
}
return uint16(tcpAddr.Port)
}
// Close shuts down the proxy and releases resources.
func (p *Proxy) Close() error {
p.cancel()
p.mu.Lock()
envoy := p.envoy
p.envoy = nil
icap := p.icap
p.icap = nil
p.mu.Unlock()
if envoy != nil {
envoy.Stop()
}
if p.listener != nil {
if err := p.listener.Close(); err != nil {
p.log.Debugf("close TPROXY listener: %v", err)
}
}
if icap != nil {
icap.Close()
}
return nil
}
// acceptLoop accepts connections from the redirected listener (kernel mode).
// Connections arrive via nftables REDIRECT; original destination is read from conntrack.
func (p *Proxy) acceptLoop(ln net.Listener) {
for {
conn, err := ln.Accept()
if err != nil {
if p.ctx.Err() != nil {
return
}
p.log.Debugf("accept error: %v", err)
continue
}
go func() {
// Read original destination from conntrack (SO_ORIGINAL_DST).
// nftables REDIRECT changes dst to the local WG IP:proxy_port,
// but conntrack preserves the real destination.
dstAddr, err := getOriginalDst(conn)
if err != nil {
p.log.Debugf("get original dst: %v", err)
if closeErr := conn.Close(); closeErr != nil {
p.log.Debugf("close conn: %v", closeErr)
}
return
}
p.log.Tracef("accepted: %s -> %s (original dst %s)",
conn.RemoteAddr(), conn.LocalAddr(), dstAddr)
srcAddr, err := netip.ParseAddrPort(conn.RemoteAddr().String())
if err != nil {
p.log.Debugf("parse source: %v", err)
if closeErr := conn.Close(); closeErr != nil {
p.log.Debugf("close conn: %v", closeErr)
}
return
}
src := SourceInfo{
IP: srcAddr.Addr().Unmap(),
}
if err := p.HandleTCP(p.ctx, conn, dstAddr, src); err != nil && !errors.Is(err, ErrBlocked) {
p.log.Debugf("connection to %s: %v", dstAddr, err)
}
}()
}
}

388
client/inspect/quic.go Normal file
View File

@@ -0,0 +1,388 @@
package inspect
import (
"crypto/aes"
"crypto/cipher"
"crypto/sha256"
"encoding/binary"
"fmt"
"io"
"golang.org/x/crypto/hkdf"
"github.com/netbirdio/netbird/shared/management/domain"
)
// QUIC version constants
const (
quicV1Version uint32 = 0x00000001
quicV2Version uint32 = 0x6b3343cf
)
// quicV1Salt is the initial salt for QUIC v1 (RFC 9001 Section 5.2).
var quicV1Salt = []byte{
0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3,
0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad,
0xcc, 0xbb, 0x7f, 0x0a,
}
// quicV2Salt is the initial salt for QUIC v2 (RFC 9369).
var quicV2Salt = []byte{
0x0d, 0xed, 0xe3, 0xde, 0xf7, 0x00, 0xa6, 0xdb,
0x81, 0x93, 0x81, 0xbe, 0x6e, 0x26, 0x9d, 0xcb,
0xf9, 0xbd, 0x2e, 0xd9,
}
// ExtractQUICSNI extracts the SNI from a QUIC Initial packet.
// The Initial packet's encryption uses well-known keys derived from the
// Destination Connection ID, so any observer can decrypt it (by design).
func ExtractQUICSNI(data []byte) (domain.Domain, error) {
if len(data) < 5 {
return "", fmt.Errorf("packet too short")
}
// Check for QUIC Long Header (form bit set)
if data[0]&0x80 == 0 {
return "", fmt.Errorf("not a QUIC long header packet")
}
// Version
version := binary.BigEndian.Uint32(data[1:5])
var salt []byte
var initialLabel, keyLabel, ivLabel, hpLabel string
switch version {
case quicV1Version:
salt = quicV1Salt
initialLabel = "client in"
keyLabel = "quic key"
ivLabel = "quic iv"
hpLabel = "quic hp"
case quicV2Version:
salt = quicV2Salt
initialLabel = "client in"
keyLabel = "quicv2 key"
ivLabel = "quicv2 iv"
hpLabel = "quicv2 hp"
default:
return "", fmt.Errorf("unsupported QUIC version: 0x%08x", version)
}
// Parse Long Header
if len(data) < 6 {
return "", fmt.Errorf("packet too short for DCID length")
}
dcidLen := int(data[5])
if len(data) < 6+dcidLen+1 {
return "", fmt.Errorf("packet too short for DCID")
}
dcid := data[6 : 6+dcidLen]
scidLenOff := 6 + dcidLen
scidLen := int(data[scidLenOff])
tokenLenOff := scidLenOff + 1 + scidLen
if tokenLenOff >= len(data) {
return "", fmt.Errorf("packet too short for token length")
}
// Token length is a variable-length integer
tokenLen, tokenLenSize, err := readVarInt(data[tokenLenOff:])
if err != nil {
return "", fmt.Errorf("read token length: %w", err)
}
payloadLenOff := tokenLenOff + tokenLenSize + int(tokenLen)
if payloadLenOff >= len(data) {
return "", fmt.Errorf("packet too short for payload length")
}
// Payload length is a variable-length integer
payloadLen, payloadLenSize, err := readVarInt(data[payloadLenOff:])
if err != nil {
return "", fmt.Errorf("read payload length: %w", err)
}
pnOffset := payloadLenOff + payloadLenSize
if pnOffset+4 > len(data) {
return "", fmt.Errorf("packet too short for packet number")
}
// Derive initial keys
clientKey, clientIV, clientHP, err := deriveInitialKeys(dcid, salt, initialLabel, keyLabel, ivLabel, hpLabel)
if err != nil {
return "", fmt.Errorf("derive initial keys: %w", err)
}
// Remove header protection
sampleOffset := pnOffset + 4 // sample starts 4 bytes after pn offset
if sampleOffset+16 > len(data) {
return "", fmt.Errorf("packet too short for HP sample")
}
sample := data[sampleOffset : sampleOffset+16]
hpBlock, err := aes.NewCipher(clientHP)
if err != nil {
return "", fmt.Errorf("create HP cipher: %w", err)
}
mask := make([]byte, 16)
hpBlock.Encrypt(mask, sample)
// Unmask header byte
header := make([]byte, len(data))
copy(header, data)
header[0] ^= mask[0] & 0x0f // Long header: low 4 bits
// Determine packet number length
pnLen := int(header[0]&0x03) + 1
// Unmask packet number
for i := 0; i < pnLen; i++ {
header[pnOffset+i] ^= mask[1+i]
}
// Reconstruct packet number
var pn uint32
for i := 0; i < pnLen; i++ {
pn = (pn << 8) | uint32(header[pnOffset+i])
}
// Build nonce
nonce := make([]byte, len(clientIV))
copy(nonce, clientIV)
for i := 0; i < 4; i++ {
nonce[len(nonce)-1-i] ^= byte(pn >> (8 * i))
}
// Decrypt payload
block, err := aes.NewCipher(clientKey)
if err != nil {
return "", fmt.Errorf("create AES cipher: %w", err)
}
aead, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("create AEAD: %w", err)
}
encryptedPayload := header[pnOffset+pnLen : pnOffset+int(payloadLen)]
aad := header[:pnOffset+pnLen]
plaintext, err := aead.Open(nil, nonce, encryptedPayload, aad)
if err != nil {
return "", fmt.Errorf("decrypt QUIC payload: %w", err)
}
// Parse CRYPTO frames to extract ClientHello
clientHello, err := extractCryptoFrames(plaintext)
if err != nil {
return "", fmt.Errorf("extract CRYPTO frames: %w", err)
}
info, err := parseHelloBody(clientHello)
return info.SNI, err
}
// deriveInitialKeys derives the client's initial encryption keys from the DCID.
func deriveInitialKeys(dcid, salt []byte, initialLabel, keyLabel, ivLabel, hpLabel string) (key, iv, hp []byte, err error) {
// initial_secret = HKDF-Extract(salt, DCID)
initialSecret := hkdf.Extract(sha256.New, dcid, salt)
// client_initial_secret = HKDF-Expand-Label(initial_secret, initialLabel, "", 32)
clientSecret, err := hkdfExpandLabel(initialSecret, initialLabel, nil, 32)
if err != nil {
return nil, nil, nil, fmt.Errorf("derive client secret: %w", err)
}
// client_key = HKDF-Expand-Label(client_secret, keyLabel, "", 16)
key, err = hkdfExpandLabel(clientSecret, keyLabel, nil, 16)
if err != nil {
return nil, nil, nil, fmt.Errorf("derive key: %w", err)
}
// client_iv = HKDF-Expand-Label(client_secret, ivLabel, "", 12)
iv, err = hkdfExpandLabel(clientSecret, ivLabel, nil, 12)
if err != nil {
return nil, nil, nil, fmt.Errorf("derive IV: %w", err)
}
// client_hp = HKDF-Expand-Label(client_secret, hpLabel, "", 16)
hp, err = hkdfExpandLabel(clientSecret, hpLabel, nil, 16)
if err != nil {
return nil, nil, nil, fmt.Errorf("derive HP key: %w", err)
}
return key, iv, hp, nil
}
// hkdfExpandLabel implements TLS 1.3 HKDF-Expand-Label.
func hkdfExpandLabel(secret []byte, label string, context []byte, length int) ([]byte, error) {
// HkdfLabel = struct {
// uint16 length;
// opaque label<7..255> = "tls13 " + Label;
// opaque context<0..255> = Context;
// }
fullLabel := "tls13 " + label
hkdfLabel := make([]byte, 2+1+len(fullLabel)+1+len(context))
binary.BigEndian.PutUint16(hkdfLabel[0:2], uint16(length))
hkdfLabel[2] = byte(len(fullLabel))
copy(hkdfLabel[3:], fullLabel)
hkdfLabel[3+len(fullLabel)] = byte(len(context))
if len(context) > 0 {
copy(hkdfLabel[4+len(fullLabel):], context)
}
expander := hkdf.Expand(sha256.New, secret, hkdfLabel)
out := make([]byte, length)
if _, err := io.ReadFull(expander, out); err != nil {
return nil, err
}
return out, nil
}
// maxCryptoFrameSize limits total CRYPTO frame data to prevent memory exhaustion.
const maxCryptoFrameSize = 64 * 1024
// extractCryptoFrames reassembles CRYPTO frame data from QUIC frames.
func extractCryptoFrames(frames []byte) ([]byte, error) {
var result []byte
pos := 0
for pos < len(frames) {
frameType := frames[pos]
switch {
case frameType == 0x00:
// PADDING frame
pos++
case frameType == 0x06:
// CRYPTO frame
pos++
offset, n, err := readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read crypto offset: %w", err)
}
pos += n
_ = offset // We assume ordered, offset 0 for Initial
dataLen, n, err := readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read crypto data length: %w", err)
}
pos += n
end := pos + int(dataLen)
if end > len(frames) {
return nil, fmt.Errorf("CRYPTO frame data truncated")
}
result = append(result, frames[pos:end]...)
if len(result) > maxCryptoFrameSize {
return nil, fmt.Errorf("CRYPTO frame data exceeds %d bytes", maxCryptoFrameSize)
}
pos = end
case frameType == 0x01:
// PING frame
pos++
case frameType == 0x02 || frameType == 0x03:
// ACK frame - skip
pos++
// Largest Acknowledged
_, n, err := readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read ACK: %w", err)
}
pos += n
// ACK Delay
_, n, err = readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read ACK delay: %w", err)
}
pos += n
// ACK Range Count
rangeCount, n, err := readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read ACK range count: %w", err)
}
pos += n
// First ACK Range
_, n, err = readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read first ACK range: %w", err)
}
pos += n
// Additional ranges
for i := uint64(0); i < rangeCount; i++ {
_, n, err = readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read ACK gap: %w", err)
}
pos += n
_, n, err = readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read ACK range: %w", err)
}
pos += n
}
// ECN counts for type 0x03
if frameType == 0x03 {
for range 3 {
_, n, err = readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read ECN count: %w", err)
}
pos += n
}
}
default:
// Unknown frame type, stop parsing
if len(result) > 0 {
return result, nil
}
return nil, fmt.Errorf("unknown QUIC frame type: 0x%02x at offset %d", frameType, pos)
}
}
if len(result) == 0 {
return nil, fmt.Errorf("no CRYPTO frames found")
}
return result, nil
}
// readVarInt reads a QUIC variable-length integer.
// Returns (value, bytes consumed, error).
func readVarInt(data []byte) (uint64, int, error) {
if len(data) == 0 {
return 0, 0, fmt.Errorf("empty data for varint")
}
prefix := data[0] >> 6
length := 1 << prefix
if len(data) < length {
return 0, 0, fmt.Errorf("varint truncated: need %d, have %d", length, len(data))
}
var val uint64
switch length {
case 1:
val = uint64(data[0] & 0x3f)
case 2:
val = uint64(binary.BigEndian.Uint16(data[:2])) & 0x3fff
case 4:
val = uint64(binary.BigEndian.Uint32(data[:4])) & 0x3fffffff
case 8:
val = binary.BigEndian.Uint64(data[:8]) & 0x3fffffffffffffff
}
return val, length, nil
}

View File

@@ -0,0 +1,99 @@
package inspect
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestReadVarInt(t *testing.T) {
tests := []struct {
name string
data []byte
want uint64
n int
}{
{
name: "1 byte value",
data: []byte{0x25},
want: 37,
n: 1,
},
{
name: "2 byte value",
data: []byte{0x7b, 0xbd},
want: 15293,
n: 2,
},
{
name: "4 byte value",
data: []byte{0x9d, 0x7f, 0x3e, 0x7d},
want: 494878333,
n: 4,
},
{
name: "zero",
data: []byte{0x00},
want: 0,
n: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
val, n, err := readVarInt(tt.data)
require.NoError(t, err)
assert.Equal(t, tt.want, val)
assert.Equal(t, tt.n, n)
})
}
}
func TestReadVarInt_Empty(t *testing.T) {
_, _, err := readVarInt(nil)
require.Error(t, err)
}
func TestReadVarInt_Truncated(t *testing.T) {
// 2-byte prefix but only 1 byte
_, _, err := readVarInt([]byte{0x40})
require.Error(t, err)
}
func TestExtractQUICSNI_NotLongHeader(t *testing.T) {
// Short header packet (form bit not set)
data := make([]byte, 100)
data[0] = 0x40 // short header
_, err := ExtractQUICSNI(data)
require.Error(t, err)
assert.Contains(t, err.Error(), "not a QUIC long header")
}
func TestExtractQUICSNI_UnsupportedVersion(t *testing.T) {
data := make([]byte, 100)
data[0] = 0xC0 // long header
// Version 0xdeadbeef
data[1] = 0xde
data[2] = 0xad
data[3] = 0xbe
data[4] = 0xef
_, err := ExtractQUICSNI(data)
require.Error(t, err)
assert.Contains(t, err.Error(), "unsupported QUIC version")
}
func TestExtractQUICSNI_TooShort(t *testing.T) {
_, err := ExtractQUICSNI([]byte{0xC0, 0x00})
require.Error(t, err)
}
func TestHkdfExpandLabel(t *testing.T) {
// Smoke test: ensure it returns the right length and doesn't error
secret := make([]byte, 32)
result, err := hkdfExpandLabel(secret, "quic key", nil, 16)
require.NoError(t, err)
assert.Len(t, result, 16)
}

253
client/inspect/rules.go Normal file
View File

@@ -0,0 +1,253 @@
package inspect
import (
"net/netip"
"slices"
"sort"
"strings"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/shared/management/domain"
)
// RuleEngine evaluates proxy rules against connection metadata.
// It is safe for concurrent use.
type RuleEngine struct {
mu sync.RWMutex
rules []Rule
// defaultAction applies when no rule matches.
defaultAction Action
log *log.Entry
}
// NewRuleEngine creates a rule engine with the given default action.
func NewRuleEngine(logger *log.Entry, defaultAction Action) *RuleEngine {
return &RuleEngine{
defaultAction: defaultAction,
log: logger,
}
}
// UpdateRules replaces the rule set and default action. Rules are sorted by priority.
func (e *RuleEngine) UpdateRules(rules []Rule, defaultAction Action) {
sorted := make([]Rule, len(rules))
copy(sorted, rules)
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].Priority < sorted[j].Priority
})
e.mu.Lock()
e.rules = sorted
e.defaultAction = defaultAction
e.mu.Unlock()
}
// EvalResult holds the outcome of a rule evaluation.
type EvalResult struct {
Action Action
RuleID id.RuleID
}
// Evaluate determines the action for a connection based on the rule set.
// Pass empty path for connection-level evaluation (TLS/SNI), non-empty for request-level (HTTP).
func (e *RuleEngine) Evaluate(src netip.Addr, dstDomain domain.Domain, dstAddr netip.Addr, dstPort uint16, proto ProtoType, path string) Action {
r := e.EvaluateWithResult(src, dstDomain, dstAddr, dstPort, proto, path)
return r.Action
}
// EvaluateWithResult is like Evaluate but also returns the matched rule ID.
func (e *RuleEngine) EvaluateWithResult(src netip.Addr, dstDomain domain.Domain, dstAddr netip.Addr, dstPort uint16, proto ProtoType, path string) EvalResult {
e.mu.RLock()
defer e.mu.RUnlock()
for i := range e.rules {
rule := &e.rules[i]
if e.ruleMatches(rule, src, dstDomain, dstAddr, dstPort, proto, path) {
e.log.Tracef("rule %s matched: action=%s src=%s domain=%s dst=%s:%d proto=%s path=%s",
rule.ID, rule.Action, src, dstDomain.SafeString(), dstAddr, dstPort, proto, path)
return EvalResult{Action: rule.Action, RuleID: rule.ID}
}
}
e.log.Tracef("no rule matched, default=%s: src=%s domain=%s dst=%s:%d proto=%s path=%s",
e.defaultAction, src, dstDomain.SafeString(), dstAddr, dstPort, proto, path)
return EvalResult{Action: e.defaultAction}
}
// HasPathRulesForDomain returns true if any rule matching the domain has non-empty Paths.
// Used to force MITM inspection when path-level rules exist (paths are only visible after decryption).
func (e *RuleEngine) HasPathRulesForDomain(dstDomain domain.Domain) bool {
e.mu.RLock()
defer e.mu.RUnlock()
for i := range e.rules {
if len(e.rules[i].Paths) > 0 && e.matchDomain(&e.rules[i], dstDomain) {
return true
}
}
return false
}
// ruleMatches checks whether all non-empty fields of a rule match.
// Empty fields are treated as "match any".
// All specified fields must match (AND logic).
func (e *RuleEngine) ruleMatches(rule *Rule, src netip.Addr, dstDomain domain.Domain, dstAddr netip.Addr, dstPort uint16, proto ProtoType, path string) bool {
if !e.matchSource(rule, src) {
return false
}
if !e.matchDomain(rule, dstDomain) {
return false
}
if !e.matchNetwork(rule, dstAddr) {
return false
}
if !e.matchPort(rule, dstPort) {
return false
}
if !e.matchProtocol(rule, proto) {
return false
}
if !e.matchPaths(rule, path) {
return false
}
return true
}
// matchSource returns true if src matches any of the rule's source CIDRs,
// or if no source CIDRs are specified (match any).
func (e *RuleEngine) matchSource(rule *Rule, src netip.Addr) bool {
if len(rule.Sources) == 0 {
return true
}
for _, prefix := range rule.Sources {
if prefix.Contains(src) {
return true
}
}
return false
}
// matchDomain returns true if dstDomain matches any of the rule's domain patterns,
// or if no domain patterns are specified (match any).
func (e *RuleEngine) matchDomain(rule *Rule, dstDomain domain.Domain) bool {
if len(rule.Domains) == 0 {
return true
}
// If we have domain rules but no domain to match against (e.g., raw IP connection),
// the domain condition does not match.
if dstDomain == "" {
return false
}
for _, pattern := range rule.Domains {
if MatchDomain(pattern, dstDomain) {
return true
}
}
return false
}
// matchNetwork returns true if dstAddr is within any of the rule's destination CIDRs,
// or if no destination CIDRs are specified (match any).
func (e *RuleEngine) matchNetwork(rule *Rule, dstAddr netip.Addr) bool {
if len(rule.Networks) == 0 {
return true
}
for _, prefix := range rule.Networks {
if prefix.Contains(dstAddr) {
return true
}
}
return false
}
// matchProtocol returns true if proto matches any of the rule's protocols,
// or if no protocols are specified (match any).
func (e *RuleEngine) matchProtocol(rule *Rule, proto ProtoType) bool {
if len(rule.Protocols) == 0 {
return true
}
for _, p := range rule.Protocols {
if p == proto {
return true
}
}
return false
}
// matchPort returns true if dstPort matches any of the rule's destination ports,
// or if no ports are specified (match any).
func (e *RuleEngine) matchPort(rule *Rule, dstPort uint16) bool {
if len(rule.Ports) == 0 {
return true
}
return slices.Contains(rule.Ports, dstPort)
}
// matchPaths returns true if path matches any of the rule's path patterns,
// or if no paths are specified (match any). Empty path (connection-level eval) matches all.
func (e *RuleEngine) matchPaths(rule *Rule, path string) bool {
if len(rule.Paths) == 0 {
return true
}
// Connection-level (path=""): rules with paths don't match at connection level.
// HasPathRulesForDomain forces the connection to inspect, so paths are
// checked per-request once the HTTP request is visible.
if path == "" {
return false
}
for _, pattern := range rule.Paths {
if matchPath(pattern, path) {
return true
}
}
return false
}
// matchPath checks if a URL path matches a pattern.
// Supports: exact ("/login"), prefix with wildcard ("/api/*"),
// and contains ("*/admin/*"). A bare "*" matches everything.
func matchPath(pattern, path string) bool {
if pattern == "*" {
return true
}
hasLeadingStar := strings.HasPrefix(pattern, "*")
hasTrailingStar := strings.HasSuffix(pattern, "*")
switch {
case hasLeadingStar && hasTrailingStar:
// */admin/* = contains
middle := strings.Trim(pattern, "*")
return strings.Contains(path, middle)
case hasTrailingStar:
// /api/* = prefix
prefix := strings.TrimSuffix(pattern, "*")
return strings.HasPrefix(path, prefix)
case hasLeadingStar:
// *.json = suffix
suffix := strings.TrimPrefix(pattern, "*")
return strings.HasSuffix(path, suffix)
default:
// exact
return path == pattern
}
}

View File

@@ -0,0 +1,338 @@
package inspect
import (
"net/netip"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/shared/management/domain"
)
func testLogger() *log.Entry {
return log.WithField("test", true)
}
func mustDomain(t *testing.T, s string) domain.Domain {
t.Helper()
d, err := domain.FromString(s)
require.NoError(t, err)
return d
}
func TestRuleEngine_Evaluate(t *testing.T) {
tests := []struct {
name string
rules []Rule
defaultAction Action
src netip.Addr
dstDomain domain.Domain
dstAddr netip.Addr
dstPort uint16
want Action
}{
{
name: "no rules returns default allow",
defaultAction: ActionAllow,
src: netip.MustParseAddr("10.0.0.1"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionAllow,
},
{
name: "no rules returns default block",
defaultAction: ActionBlock,
src: netip.MustParseAddr("10.0.0.1"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionBlock,
},
{
name: "domain exact match blocks",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Domains: []domain.Domain{mustDomain(t, "malware.example.com")},
Action: ActionBlock,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstDomain: mustDomain(t, "malware.example.com"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionBlock,
},
{
name: "domain wildcard match blocks",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Domains: []domain.Domain{mustDomain(t, "*.evil.com")},
Action: ActionBlock,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstDomain: mustDomain(t, "phishing.evil.com"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionBlock,
},
{
name: "domain wildcard does not match base",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Domains: []domain.Domain{mustDomain(t, "*.evil.com")},
Action: ActionBlock,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstDomain: mustDomain(t, "evil.com"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionAllow,
},
{
name: "case insensitive domain match",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Domains: []domain.Domain{mustDomain(t, "Example.COM")},
Action: ActionBlock,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstDomain: mustDomain(t, "EXAMPLE.com"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionBlock,
},
{
name: "source CIDR match",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
Action: ActionInspect,
},
},
src: netip.MustParseAddr("192.168.1.50"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionInspect,
},
{
name: "source CIDR no match",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
Action: ActionBlock,
},
},
src: netip.MustParseAddr("10.0.0.5"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionAllow,
},
{
name: "destination network match",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Networks: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
Action: ActionInspect,
},
},
src: netip.MustParseAddr("192.168.1.1"),
dstAddr: netip.MustParseAddr("10.50.0.1"),
dstPort: 80,
want: ActionInspect,
},
{
name: "port match",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Ports: []uint16{443, 8443},
Action: ActionInspect,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionInspect,
},
{
name: "port no match",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Ports: []uint16{443, 8443},
Action: ActionBlock,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 22,
want: ActionAllow,
},
{
name: "priority ordering first match wins",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("allow-internal"),
Domains: []domain.Domain{mustDomain(t, "*.internal.corp")},
Action: ActionAllow,
Priority: 1,
},
{
ID: id.RuleID("inspect-all"),
Action: ActionInspect,
Priority: 10,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstDomain: mustDomain(t, "api.internal.corp"),
dstAddr: netip.MustParseAddr("10.1.0.5"),
dstPort: 443,
want: ActionAllow,
},
{
name: "all fields must match (AND logic)",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
Domains: []domain.Domain{mustDomain(t, "*.evil.com")},
Ports: []uint16{443},
Action: ActionBlock,
},
},
// Source matches, domain matches, but port doesn't
src: netip.MustParseAddr("192.168.1.10"),
dstDomain: mustDomain(t, "phish.evil.com"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 8080,
want: ActionAllow,
},
{
name: "empty domain with domain rule does not match",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Domains: []domain.Domain{mustDomain(t, "example.com")},
Action: ActionBlock,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstDomain: "", // raw IP connection, no SNI
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionAllow,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
engine := NewRuleEngine(testLogger(), tt.defaultAction)
engine.UpdateRules(tt.rules, tt.defaultAction)
got := engine.Evaluate(tt.src, tt.dstDomain, tt.dstAddr, tt.dstPort, "", "")
assert.Equal(t, tt.want, got)
})
}
}
func TestRuleEngine_ProtocolMatching(t *testing.T) {
engine := NewRuleEngine(testLogger(), ActionAllow)
engine.UpdateRules([]Rule{
{
ID: "block-websocket",
Protocols: []ProtoType{ProtoWebSocket},
Action: ActionBlock,
Priority: 1,
},
{
ID: "inspect-h2",
Protocols: []ProtoType{ProtoH2},
Action: ActionInspect,
Priority: 2,
},
}, ActionAllow)
src := netip.MustParseAddr("10.0.0.1")
dst := netip.MustParseAddr("1.2.3.4")
// WebSocket: blocked by rule
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoWebSocket, ""))
// HTTP/2: inspected by rule
assert.Equal(t, ActionInspect, engine.Evaluate(src, "", dst, 443, ProtoH2, ""))
// Plain HTTP: no protocol rule matches, default allow
assert.Equal(t, ActionAllow, engine.Evaluate(src, "", dst, 80, ProtoHTTP, ""))
// HTTPS: no protocol rule matches, default allow
assert.Equal(t, ActionAllow, engine.Evaluate(src, "", dst, 443, ProtoHTTPS, ""))
// QUIC/H3: no protocol rule matches, default allow
assert.Equal(t, ActionAllow, engine.Evaluate(src, "", dst, 443, ProtoH3, ""))
// Empty protocol (unknown): no protocol rule matches, default allow
assert.Equal(t, ActionAllow, engine.Evaluate(src, "", dst, 443, "", ""))
}
func TestRuleEngine_EmptyProtocolsMatchAll(t *testing.T) {
engine := NewRuleEngine(testLogger(), ActionAllow)
engine.UpdateRules([]Rule{
{
ID: "block-all-protos",
Action: ActionBlock,
// No Protocols field = match all protocols
Priority: 1,
},
}, ActionAllow)
src := netip.MustParseAddr("10.0.0.1")
dst := netip.MustParseAddr("1.2.3.4")
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoHTTP, ""))
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoHTTPS, ""))
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoWebSocket, ""))
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoH2, ""))
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, "", ""))
}
func TestRuleEngine_UpdateRulesSortsByPriority(t *testing.T) {
engine := NewRuleEngine(testLogger(), ActionAllow)
engine.UpdateRules([]Rule{
{ID: "c", Priority: 30, Action: ActionBlock},
{ID: "a", Priority: 10, Action: ActionInspect},
{ID: "b", Priority: 20, Action: ActionAllow},
}, ActionAllow)
engine.mu.RLock()
defer engine.mu.RUnlock()
require.Len(t, engine.rules, 3)
assert.Equal(t, id.RuleID("a"), engine.rules[0].ID)
assert.Equal(t, id.RuleID("b"), engine.rules[1].ID)
assert.Equal(t, id.RuleID("c"), engine.rules[2].ID)
}

287
client/inspect/sni.go Normal file
View File

@@ -0,0 +1,287 @@
package inspect
import (
"encoding/binary"
"fmt"
"io"
"github.com/netbirdio/netbird/shared/management/domain"
)
const (
recordTypeHandshake = 0x16
handshakeTypeClientHello = 0x01
extensionTypeSNI = 0x0000
extensionTypeALPN = 0x0010
sniTypeHostName = 0x00
// maxClientHelloSize is the maximum ClientHello size we'll read.
// Real-world ClientHellos are typically under 1KB but can reach ~16KB with
// many extensions (post-quantum key shares, etc.).
maxClientHelloSize = 16384
)
// ClientHelloInfo holds data extracted from a TLS ClientHello.
type ClientHelloInfo struct {
SNI domain.Domain
ALPN []string
}
// isTLSHandshake reports whether the first byte indicates a TLS handshake record.
func isTLSHandshake(b byte) bool {
return b == recordTypeHandshake
}
// httpMethods lists the first bytes of valid HTTP method tokens.
var httpMethods = [][]byte{
[]byte("GET "),
[]byte("POST"),
[]byte("PUT "),
[]byte("DELE"),
[]byte("HEAD"),
[]byte("OPTI"),
[]byte("PATC"),
[]byte("CONN"),
[]byte("TRAC"),
}
// isHTTPMethod reports whether the peeked bytes look like the start of an HTTP request.
func isHTTPMethod(b []byte) bool {
if len(b) < 4 {
return false
}
for _, m := range httpMethods {
if b[0] == m[0] && b[1] == m[1] && b[2] == m[2] && b[3] == m[3] {
return true
}
}
return false
}
// parseClientHello reads a TLS ClientHello from r and returns SNI and ALPN.
func parseClientHello(r io.Reader) (ClientHelloInfo, error) {
// TLS record header: type(1) + version(2) + length(2)
var recordHeader [5]byte
if _, err := io.ReadFull(r, recordHeader[:]); err != nil {
return ClientHelloInfo{}, fmt.Errorf("read TLS record header: %w", err)
}
if recordHeader[0] != recordTypeHandshake {
return ClientHelloInfo{}, fmt.Errorf("not a TLS handshake record (type=%d)", recordHeader[0])
}
recordLen := int(binary.BigEndian.Uint16(recordHeader[3:5]))
if recordLen < 4 || recordLen > maxClientHelloSize {
return ClientHelloInfo{}, fmt.Errorf("invalid TLS record length: %d", recordLen)
}
// Read the full handshake message
msg := make([]byte, recordLen)
if _, err := io.ReadFull(r, msg); err != nil {
return ClientHelloInfo{}, fmt.Errorf("read handshake message: %w", err)
}
return parseClientHelloMsg(msg)
}
// extractSNI reads a TLS ClientHello from r and returns the SNI hostname.
// Returns empty domain if no SNI extension is present.
func extractSNI(r io.Reader) (domain.Domain, error) {
info, err := parseClientHello(r)
return info.SNI, err
}
// extractSNIFromBytes parses SNI from raw bytes that start with the TLS record header.
func extractSNIFromBytes(data []byte) (domain.Domain, error) {
info, err := parseClientHelloFromBytes(data)
return info.SNI, err
}
// parseClientHelloFromBytes parses a ClientHello from raw bytes starting with the TLS record header.
func parseClientHelloFromBytes(data []byte) (ClientHelloInfo, error) {
if len(data) < 5 {
return ClientHelloInfo{}, fmt.Errorf("data too short for TLS record header")
}
if data[0] != recordTypeHandshake {
return ClientHelloInfo{}, fmt.Errorf("not a TLS handshake record (type=%d)", data[0])
}
recordLen := int(binary.BigEndian.Uint16(data[3:5]))
if recordLen < 4 {
return ClientHelloInfo{}, fmt.Errorf("invalid TLS record length: %d", recordLen)
}
end := 5 + recordLen
if end > len(data) {
return ClientHelloInfo{}, fmt.Errorf("TLS record truncated: need %d, have %d", end, len(data))
}
return parseClientHelloMsg(data[5:end])
}
// parseClientHelloMsg extracts SNI and ALPN from a raw ClientHello handshake message.
// msg starts at the handshake type byte.
func parseClientHelloMsg(msg []byte) (ClientHelloInfo, error) {
if len(msg) < 4 {
return ClientHelloInfo{}, fmt.Errorf("handshake message too short")
}
if msg[0] != handshakeTypeClientHello {
return ClientHelloInfo{}, fmt.Errorf("not a ClientHello (type=%d)", msg[0])
}
// Handshake header: type(1) + length(3)
helloLen := int(msg[1])<<16 | int(msg[2])<<8 | int(msg[3])
if helloLen+4 > len(msg) {
return ClientHelloInfo{}, fmt.Errorf("ClientHello truncated")
}
hello := msg[4 : 4+helloLen]
return parseHelloBody(hello)
}
// parseHelloBody parses the ClientHello body (after handshake header)
// and extracts SNI and ALPN.
func parseHelloBody(hello []byte) (ClientHelloInfo, error) {
// ClientHello structure:
// version(2) + random(32) + session_id_len(1) + session_id(var)
// + cipher_suites_len(2) + cipher_suites(var)
// + compression_len(1) + compression(var)
// + extensions_len(2) + extensions(var)
var info ClientHelloInfo
if len(hello) < 35 {
return info, fmt.Errorf("ClientHello body too short")
}
pos := 2 + 32 // skip version + random
// Skip session ID
if pos >= len(hello) {
return info, fmt.Errorf("ClientHello truncated at session ID")
}
sessionIDLen := int(hello[pos])
pos += 1 + sessionIDLen
// Skip cipher suites
if pos+2 > len(hello) {
return info, fmt.Errorf("ClientHello truncated at cipher suites")
}
cipherLen := int(binary.BigEndian.Uint16(hello[pos : pos+2]))
pos += 2 + cipherLen
// Skip compression methods
if pos >= len(hello) {
return info, fmt.Errorf("ClientHello truncated at compression")
}
compLen := int(hello[pos])
pos += 1 + compLen
// Extensions
if pos+2 > len(hello) {
return info, nil
}
extLen := int(binary.BigEndian.Uint16(hello[pos : pos+2]))
pos += 2
extEnd := pos + extLen
if extEnd > len(hello) {
return info, fmt.Errorf("extensions block truncated")
}
// Walk extensions looking for SNI and ALPN
for pos+4 <= extEnd {
extType := binary.BigEndian.Uint16(hello[pos : pos+2])
extDataLen := int(binary.BigEndian.Uint16(hello[pos+2 : pos+4]))
pos += 4
if pos+extDataLen > extEnd {
return info, fmt.Errorf("extension data truncated")
}
switch extType {
case extensionTypeSNI:
sni, err := parseSNIExtension(hello[pos : pos+extDataLen])
if err != nil {
return info, err
}
info.SNI = sni
case extensionTypeALPN:
info.ALPN = parseALPNExtension(hello[pos : pos+extDataLen])
}
pos += extDataLen
}
return info, nil
}
// parseALPNExtension parses the ALPN extension data and returns protocol names.
// ALPN extension: list_length(2) + entries (each: len(1) + protocol_name(var))
func parseALPNExtension(data []byte) []string {
if len(data) < 2 {
return nil
}
listLen := int(binary.BigEndian.Uint16(data[0:2]))
if listLen+2 > len(data) {
return nil
}
var protocols []string
pos := 2
end := 2 + listLen
for pos < end {
if pos >= len(data) {
break
}
nameLen := int(data[pos])
pos++
if pos+nameLen > end {
break
}
protocols = append(protocols, string(data[pos:pos+nameLen]))
pos += nameLen
}
return protocols
}
// parseSNIExtension parses the SNI extension data and returns the hostname.
func parseSNIExtension(data []byte) (domain.Domain, error) {
// SNI extension: list_length(2) + entries
if len(data) < 2 {
return "", fmt.Errorf("SNI extension too short")
}
listLen := int(binary.BigEndian.Uint16(data[0:2]))
if listLen+2 > len(data) {
return "", fmt.Errorf("SNI list truncated")
}
pos := 2
end := 2 + listLen
for pos+3 <= end {
nameType := data[pos]
nameLen := int(binary.BigEndian.Uint16(data[pos+1 : pos+3]))
pos += 3
if pos+nameLen > end {
return "", fmt.Errorf("SNI name truncated")
}
if nameType == sniTypeHostName {
hostname := string(data[pos : pos+nameLen])
return domain.FromString(hostname)
}
pos += nameLen
}
return "", nil
}

109
client/inspect/sni_test.go Normal file
View File

@@ -0,0 +1,109 @@
package inspect
import (
"bytes"
"crypto/tls"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExtractSNI(t *testing.T) {
tests := []struct {
name string
sni string
wantSNI string
wantErr bool
}{
{
name: "standard domain",
sni: "example.com",
wantSNI: "example.com",
},
{
name: "subdomain",
sni: "api.staging.example.com",
wantSNI: "api.staging.example.com",
},
{
name: "mixed case normalized to lowercase",
sni: "Example.COM",
wantSNI: "example.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
clientHello := buildClientHello(t, tt.sni)
sni, err := extractSNI(bytes.NewReader(clientHello))
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantSNI, sni.PunycodeString())
})
}
}
func TestExtractSNI_NotTLS(t *testing.T) {
// HTTP request instead of TLS
data := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
_, err := extractSNI(bytes.NewReader(data))
require.Error(t, err)
assert.Contains(t, err.Error(), "not a TLS handshake")
}
func TestExtractSNI_Truncated(t *testing.T) {
// Just the record header, no body
data := []byte{0x16, 0x03, 0x01, 0x00, 0x05}
_, err := extractSNI(bytes.NewReader(data))
require.Error(t, err)
}
func TestExtractSNIFromBytes(t *testing.T) {
clientHello := buildClientHello(t, "test.example.com")
sni, err := extractSNIFromBytes(clientHello)
require.NoError(t, err)
assert.Equal(t, "test.example.com", sni.PunycodeString())
}
// buildClientHello generates a real TLS ClientHello with the given SNI.
func buildClientHello(t *testing.T, serverName string) []byte {
t.Helper()
// Use a pipe to capture the ClientHello bytes
clientConn, serverConn := net.Pipe()
done := make(chan []byte, 1)
go func() {
buf := make([]byte, 4096)
n, _ := serverConn.Read(buf)
done <- buf[:n]
serverConn.Close()
}()
tlsConn := tls.Client(clientConn, &tls.Config{
ServerName: serverName,
InsecureSkipVerify: true,
})
// Trigger the handshake (will fail since server isn't TLS, but we capture the ClientHello)
go func() {
_ = tlsConn.Handshake()
tlsConn.Close()
}()
clientHello := <-done
clientConn.Close()
require.True(t, len(clientHello) > 5, "ClientHello too short")
require.Equal(t, byte(0x16), clientHello[0], "not a TLS handshake record")
return clientHello
}

287
client/inspect/tls.go Normal file
View File

@@ -0,0 +1,287 @@
package inspect
import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/netip"
"github.com/netbirdio/netbird/shared/management/domain"
)
// handleTLS processes a TLS connection for the kernel-mode path: extracts SNI,
// evaluates rules, and handles the connection internally.
// In envoy mode, allowed connections are forwarded to envoy instead of direct relay.
func (p *Proxy) handleTLS(ctx context.Context, pconn *peekConn, dst netip.AddrPort, src SourceInfo) error {
result, err := p.inspectTLS(ctx, pconn, dst, src)
if err != nil {
return err
}
if result.PassthroughConn != nil {
p.mu.RLock()
envoy := p.envoy
p.mu.RUnlock()
if envoy != nil {
return p.forwardToEnvoy(ctx, pconn, dst, src, envoy)
}
return p.tlsPassthrough(ctx, pconn, dst, "")
}
return nil
}
// inspectTLS extracts SNI, evaluates rules, and returns the result.
// For ActionAllow: returns the peekConn as PassthroughConn (caller relays).
// For ActionBlock/ActionInspect: handles internally and returns nil PassthroughConn.
func (p *Proxy) inspectTLS(ctx context.Context, pconn *peekConn, dst netip.AddrPort, src SourceInfo) (InspectResult, error) {
// The first 5 bytes (TLS record header) are already peeked.
// Extend to read the full TLS record so bytes remain in the buffer for passthrough.
peeked := pconn.Peeked()
recordLen := int(peeked[3])<<8 | int(peeked[4])
if _, err := pconn.PeekMore(5 + recordLen); err != nil {
return InspectResult{}, fmt.Errorf("read TLS record: %w", err)
}
hello, err := parseClientHelloFromBytes(pconn.Peeked())
if err != nil {
return InspectResult{}, fmt.Errorf("parse ClientHello: %w", err)
}
sni := hello.SNI
proto := protoFromALPN(hello.ALPN)
// Connection-level evaluation: pass empty path.
action := p.evaluateAction(src.IP, sni, dst, proto, "")
// If any rule for this domain has path patterns, force inspect so paths can
// be checked per-request after MITM decryption.
if action == ActionAllow && p.rules.HasPathRulesForDomain(sni) {
p.log.Debugf("upgrading to inspect for %s (path rules exist)", sni.PunycodeString())
action = ActionInspect
}
// Snapshot cert provider under lock for use in this connection.
p.mu.RLock()
certs := p.certs
p.mu.RUnlock()
switch action {
case ActionBlock:
p.log.Debugf("block: TLS to %s (SNI=%s)", dst, sni.PunycodeString())
if certs != nil {
return InspectResult{Action: ActionBlock}, p.tlsBlockPage(ctx, pconn, sni, certs)
}
return InspectResult{Action: ActionBlock}, ErrBlocked
case ActionAllow:
p.log.Tracef("allow: TLS passthrough to %s (SNI=%s)", dst, sni.PunycodeString())
return InspectResult{Action: ActionAllow, PassthroughConn: pconn}, nil
case ActionInspect:
if certs == nil {
p.log.Warnf("allow: %s (inspect requested but no MITM CA configured)", sni.PunycodeString())
return InspectResult{Action: ActionAllow, PassthroughConn: pconn}, nil
}
err := p.tlsMITM(ctx, pconn, dst, sni, src, certs)
return InspectResult{Action: ActionInspect}, err
default:
p.log.Warnf("block: unknown action %q for %s", action, sni.PunycodeString())
return InspectResult{Action: ActionBlock}, ErrBlocked
}
}
// tlsBlockPage completes a MITM TLS handshake with the client using a dynamic
// certificate, then serves an HTTP 403 block page so the user sees a clear
// message instead of a cryptic SSL error.
func (p *Proxy) tlsBlockPage(ctx context.Context, pconn *peekConn, sni domain.Domain, certs *CertProvider) error {
hostname := sni.PunycodeString()
// Force HTTP/1.1 only: block pages are simple responses, no need for h2
tlsCfg := certs.GetTLSConfig()
tlsCfg.NextProtos = []string{"http/1.1"}
clientTLS := tls.Server(pconn, tlsCfg)
if err := clientTLS.HandshakeContext(ctx); err != nil {
// Client may not trust our CA, handshake fails. That's expected.
return fmt.Errorf("block page TLS handshake for %s: %w", hostname, err)
}
defer func() {
if err := clientTLS.Close(); err != nil {
p.log.Debugf("close block page TLS for %s: %v", hostname, err)
}
}()
writeBlockResponse(clientTLS, nil, sni)
return ErrBlocked
}
// tlsPassthrough connects to the destination and relays encrypted traffic
// without decryption. The peeked ClientHello bytes are replayed.
func (p *Proxy) tlsPassthrough(ctx context.Context, pconn *peekConn, dst netip.AddrPort, sni domain.Domain) error {
remote, err := p.dialTCP(ctx, dst)
if err != nil {
return fmt.Errorf("dial %s: %w", dst, err)
}
defer func() {
if err := remote.Close(); err != nil {
p.log.Debugf("close remote for %s: %v", dst, err)
}
}()
p.log.Tracef("allow: TLS passthrough to %s (SNI=%s)", dst, sni.PunycodeString())
return relay(ctx, pconn, remote)
}
// tlsMITM terminates the client TLS connection with a dynamic certificate,
// establishes a new TLS connection to the real destination, and runs the
// HTTP inspection pipeline on the decrypted traffic.
func (p *Proxy) tlsMITM(ctx context.Context, pconn *peekConn, dst netip.AddrPort, sni domain.Domain, src SourceInfo, certs *CertProvider) error {
hostname := sni.PunycodeString()
// TLS handshake with client using dynamic cert
clientTLS := tls.Server(pconn, certs.GetTLSConfig())
if err := clientTLS.HandshakeContext(ctx); err != nil {
return fmt.Errorf("client TLS handshake for %s: %w", hostname, err)
}
defer func() {
if err := clientTLS.Close(); err != nil {
p.log.Debugf("close client TLS for %s: %v", hostname, err)
}
}()
// TLS connection to real destination
remoteTLS, err := p.dialTLS(ctx, dst, hostname)
if err != nil {
return fmt.Errorf("dial TLS %s (%s): %w", dst, hostname, err)
}
defer func() {
if err := remoteTLS.Close(); err != nil {
p.log.Debugf("close remote TLS for %s: %v", hostname, err)
}
}()
negotiatedProto := clientTLS.ConnectionState().NegotiatedProtocol
p.log.Tracef("inspect: MITM established for %s (proto=%s)", hostname, negotiatedProto)
return p.inspectHTTP(ctx, clientTLS, remoteTLS, dst, sni, src, negotiatedProto)
}
// dialTLS connects to the destination with TLS, verifying the real server certificate.
func (p *Proxy) dialTLS(ctx context.Context, dst netip.AddrPort, serverName string) (net.Conn, error) {
rawConn, err := p.dialTCP(ctx, dst)
if err != nil {
return nil, err
}
tlsConn := tls.Client(rawConn, &tls.Config{
ServerName: serverName,
NextProtos: []string{"h2", "http/1.1"},
MinVersion: tls.VersionTLS12,
})
if err := tlsConn.HandshakeContext(ctx); err != nil {
if closeErr := rawConn.Close(); closeErr != nil {
p.log.Debugf("close raw conn after TLS handshake failure: %v", closeErr)
}
return nil, fmt.Errorf("TLS handshake with %s: %w", serverName, err)
}
return tlsConn, nil
}
// protoFromALPN maps TLS ALPN protocol names to proxy ProtoType.
// Falls back to ProtoHTTPS when no recognized ALPN is present.
func protoFromALPN(alpn []string) ProtoType {
for _, p := range alpn {
switch p {
case "h2":
return ProtoH2
case "h3": // unlikely in TLS, but handle anyway
return ProtoH3
}
}
// No ALPN or only "http/1.1": treat as HTTPS
return ProtoHTTPS
}
// relay copies data bidirectionally between client and remote until one
// side closes or the context is cancelled.
func relay(ctx context.Context, client, remote net.Conn) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
errCh := make(chan error, 2)
go func() {
_, err := io.Copy(remote, client)
cancel()
errCh <- err
}()
go func() {
_, err := io.Copy(client, remote)
cancel()
errCh <- err
}()
var firstErr error
for range 2 {
if err := <-errCh; err != nil && firstErr == nil {
if !isClosedErr(err) {
firstErr = err
}
}
}
return firstErr
}
// evaluateAction runs rule evaluation and resolves the effective action.
// Pass empty path for connection-level (TLS), non-empty for request-level (HTTP).
func (p *Proxy) evaluateAction(src netip.Addr, sni domain.Domain, dst netip.AddrPort, proto ProtoType, path string) Action {
return p.rules.Evaluate(src, sni, dst.Addr(), dst.Port(), proto, path)
}
// dialTCP dials the destination, blocking connections to loopback, link-local,
// multicast, and WG overlay network addresses.
func (p *Proxy) dialTCP(ctx context.Context, dst netip.AddrPort) (net.Conn, error) {
ip := dst.Addr().Unmap()
if err := p.validateDialTarget(ip); err != nil {
return nil, fmt.Errorf("dial %s: %w", dst, err)
}
return p.dialer.DialContext(ctx, "tcp", dst.String())
}
// validateDialTarget blocks destinations that should never be dialed by the proxy.
// Mirrors the route validation in systemops.validateRoute.
func (p *Proxy) validateDialTarget(addr netip.Addr) error {
switch {
case !addr.IsValid():
return fmt.Errorf("invalid address")
case addr.IsLoopback():
return fmt.Errorf("loopback address not allowed")
case addr.IsLinkLocalUnicast(), addr.IsLinkLocalMulticast(), addr.IsInterfaceLocalMulticast():
return fmt.Errorf("link-local address not allowed")
case addr.IsMulticast():
return fmt.Errorf("multicast address not allowed")
case p.wgNetwork.IsValid() && p.wgNetwork.Contains(addr):
return fmt.Errorf("overlay network address not allowed")
case p.localIPs != nil && p.localIPs.IsLocalIP(addr):
return fmt.Errorf("local address not allowed")
}
return nil
}
func isClosedErr(err error) bool {
if err == nil {
return false
}
return err == io.EOF ||
err == io.ErrClosedPipe ||
err == net.ErrClosed ||
err == context.Canceled
}

View File

@@ -562,6 +562,9 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
MTU: selectMTU(config.MTU, peerConfig.Mtu),
LogPath: logPath,
InspectionCACertPath: config.InspectionCACertPath,
InspectionCAKeyPath: config.InspectionCAKeyPath,
ProfileConfig: config,
}

View File

@@ -31,6 +31,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/inspect"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/dns"
@@ -136,6 +137,12 @@ type EngineConfig struct {
MTU uint16
// InspectionCACertPath is a local CA cert for transparent proxy MITM.
// Takes priority over management-pushed CA.
InspectionCACertPath string
// InspectionCAKeyPath is the corresponding private key.
InspectionCAKeyPath string
// for debug bundle generation
ProfileConfig *profilemanager.Config
@@ -222,6 +229,10 @@ type Engine struct {
latestSyncResponse *mgmProto.SyncResponse
flowManager nftypes.FlowManager
// transparentProxy is the transparent forward proxy for traffic inspection.
transparentProxy *inspect.Proxy
udpInspectionHookID string
// auto-update
updateManager *updater.Manager
@@ -1272,6 +1283,9 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
// Transparent proxy
e.updateTransparentProxy(networkMap.GetTransparentProxyConfig())
// Ingress forward rules
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
if err != nil {
@@ -1695,6 +1709,8 @@ func (e *Engine) parseNATExternalIPMappings() []string {
func (e *Engine) close() {
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
e.stopTransparentProxy()
if e.wgInterface != nil {
if err := e.wgInterface.Close(); err != nil {
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)

View File

@@ -0,0 +1,571 @@
package internal
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net/netip"
"net/url"
"os"
"strconv"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
"github.com/netbirdio/netbird/client/inspect"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// updateTransparentProxy processes transparent proxy configuration from the network map.
func (e *Engine) updateTransparentProxy(cfg *mgmProto.TransparentProxyConfig) {
if cfg == nil || !cfg.Enabled {
if cfg == nil {
log.Tracef("inspect: config is nil")
} else {
log.Tracef("inspect: config disabled")
}
// Only stop if explicitly disabled. Don't stop on nil config to avoid
// a gap during policy edits where management briefly pushes empty config.
if cfg != nil && !cfg.Enabled {
e.stopTransparentProxy()
}
return
}
log.Debugf("inspect: config received: enabled=%v mode=%v default_action=%v rules=%d has_ca=%v",
cfg.Enabled, cfg.Mode, cfg.DefaultAction, len(cfg.Rules), len(cfg.CaCertPem) > 0)
// BlockInbound prevents adding TPROXY rules since kernel TPROXY bypasses ACLs.
// The userspace forwarder path still works as it operates within the forwarder hook.
if e.config.BlockInbound {
log.Warnf("inspect: BlockInbound is set, skipping redirect rules (userspace path still active)")
}
proxyConfig, err := toProxyConfig(cfg)
if err != nil {
log.Errorf("inspect: parse config: %v", err)
e.stopTransparentProxy()
return
}
// CA priority: local config > management-pushed > auto-generated self-signed.
// Local wins over mgmt to prevent compromised management from injecting a CA.
e.resolveInspectionCA(&proxyConfig)
if e.transparentProxy != nil {
// Mode change requires full recreate (envoy lifecycle, listener changes).
if proxyConfig.Mode != e.transparentProxy.Mode() {
log.Infof("inspect: mode changed to %s, recreating engine", proxyConfig.Mode)
e.stopTransparentProxy()
} else {
e.transparentProxy.UpdateConfig(proxyConfig)
e.syncTProxyRules(proxyConfig)
e.syncUDPInspectionHook()
return
}
}
if e.wgInterface != nil {
proxyConfig.WGNetwork = e.wgInterface.Address().Network
proxyConfig.ListenAddr = netip.AddrPortFrom(
e.wgInterface.Address().IP.Unmap(),
proxyConfig.ListenAddr.Port(),
)
}
// Pass local IP checker for SSRF prevention
if checker, ok := e.firewall.(inspect.LocalIPChecker); ok {
proxyConfig.LocalIPChecker = checker
}
p, err := inspect.New(e.ctx, log.WithField("component", "inspect"), proxyConfig)
if err != nil {
log.Errorf("inspect: start engine: %v", err)
return
}
e.transparentProxy = p
e.attachProxyToForwarder(p)
e.syncTProxyRules(proxyConfig)
e.syncUDPInspectionHook()
log.Infof("inspect: engine started (mode=%s, rules=%d)", proxyConfig.Mode, len(proxyConfig.Rules))
}
// stopTransparentProxy shuts down the transparent proxy and removes interception.
func (e *Engine) stopTransparentProxy() {
if e.transparentProxy == nil {
return
}
e.attachProxyToForwarder(nil)
e.removeTProxyRule()
e.removeUDPInspectionHook()
if err := e.transparentProxy.Close(); err != nil {
log.Debugf("inspect: close engine: %v", err)
}
e.transparentProxy = nil
log.Info("inspect: engine stopped")
}
const tproxyRuleID = "tproxy-redirect"
// syncTProxyRules adds a TPROXY rule via the firewall manager to intercept
// matching traffic on the WG interface and redirect it to the proxy socket.
func (e *Engine) syncTProxyRules(config inspect.Config) {
if e.config.BlockInbound {
e.removeTProxyRule()
return
}
var listenPort uint16
if e.transparentProxy != nil {
listenPort = e.transparentProxy.ListenPort()
}
if listenPort == 0 {
e.removeTProxyRule()
return
}
if e.firewall == nil {
return
}
dstPorts := make([]uint16, len(config.RedirectPorts))
copy(dstPorts, config.RedirectPorts)
log.Debugf("inspect: syncing redirect rules: listen port %d, redirect ports %v, sources %v",
listenPort, dstPorts, config.RedirectSources)
if err := e.firewall.AddTProxyRule(tproxyRuleID, config.RedirectSources, dstPorts, listenPort); err != nil {
log.Errorf("inspect: add redirect rule: %v", err)
return
}
}
// removeTProxyRule removes the TPROXY redirect rule.
func (e *Engine) removeTProxyRule() {
if e.firewall == nil {
return
}
if err := e.firewall.RemoveTProxyRule(tproxyRuleID); err != nil {
log.Debugf("inspect: remove redirect rule: %v", err)
}
}
// syncUDPInspectionHook registers a UDP packet hook on port 443 for QUIC SNI blocking.
// The hook is called by the USP filter for each UDP packet matching the port,
// allowing the inspection engine to extract QUIC SNI and block by domain.
func (e *Engine) syncUDPInspectionHook() {
e.removeUDPInspectionHook()
if e.firewall == nil || e.transparentProxy == nil {
return
}
p := e.transparentProxy
hookID := e.firewall.AddUDPInspectionHook(443, func(packet []byte) bool {
srcIP, dstIP, dstPort, udpPayload, ok := parseUDPPacket(packet)
if !ok {
return false
}
src := inspect.SourceInfo{IP: srcIP}
dst := netip.AddrPortFrom(dstIP, dstPort)
action := p.HandleUDPPacket(udpPayload, dst, src)
return action == inspect.ActionBlock
})
e.udpInspectionHookID = hookID
log.Debugf("inspect: registered UDP inspection hook on port 443 (id=%s)", hookID)
}
// removeUDPInspectionHook removes the QUIC inspection hook.
func (e *Engine) removeUDPInspectionHook() {
if e.udpInspectionHookID == "" || e.firewall == nil {
return
}
e.firewall.RemoveUDPInspectionHook(e.udpInspectionHookID)
e.udpInspectionHookID = ""
}
// parseUDPPacket extracts source/destination IP, destination port, and UDP
// payload from a raw IP packet. Supports both IPv4 and IPv6.
func parseUDPPacket(packet []byte) (srcIP, dstIP netip.Addr, dstPort uint16, payload []byte, ok bool) {
if len(packet) < 1 {
return srcIP, dstIP, 0, nil, false
}
version := packet[0] >> 4
var udpOffset int
switch version {
case 4:
if len(packet) < 20 {
return srcIP, dstIP, 0, nil, false
}
ihl := int(packet[0]&0x0f) * 4
if len(packet) < ihl+8 {
return srcIP, dstIP, 0, nil, false
}
var srcOK, dstOK bool
srcIP, srcOK = netip.AddrFromSlice(packet[12:16])
dstIP, dstOK = netip.AddrFromSlice(packet[16:20])
if !srcOK || !dstOK {
return srcIP, dstIP, 0, nil, false
}
udpOffset = ihl
case 6:
// IPv6 fixed header is 40 bytes. Next header must be UDP (17).
if len(packet) < 48 { // 40 header + 8 UDP
return srcIP, dstIP, 0, nil, false
}
nextHeader := packet[6]
if nextHeader != 17 { // not UDP (may have extension headers)
return srcIP, dstIP, 0, nil, false
}
var srcOK, dstOK bool
srcIP, srcOK = netip.AddrFromSlice(packet[8:24])
dstIP, dstOK = netip.AddrFromSlice(packet[24:40])
if !srcOK || !dstOK {
return srcIP, dstIP, 0, nil, false
}
udpOffset = 40
default:
return srcIP, dstIP, 0, nil, false
}
srcIP = srcIP.Unmap()
dstIP = dstIP.Unmap()
dstPort = uint16(packet[udpOffset+2])<<8 | uint16(packet[udpOffset+3])
payload = packet[udpOffset+8:]
return srcIP, dstIP, dstPort, payload, true
}
// attachProxyToForwarder sets or clears the proxy on the userspace forwarder.
func (e *Engine) attachProxyToForwarder(p *inspect.Proxy) {
type forwarderGetter interface {
GetForwarder() *forwarder.Forwarder
}
if fg, ok := e.firewall.(forwarderGetter); ok {
if fwd := fg.GetForwarder(); fwd != nil {
fwd.SetProxy(p)
}
}
}
// toProxyConfig converts a proto TransparentProxyConfig to the inspect.Config type.
func toProxyConfig(cfg *mgmProto.TransparentProxyConfig) (inspect.Config, error) {
config := inspect.Config{
Enabled: cfg.Enabled,
DefaultAction: toProxyAction(cfg.DefaultAction),
}
switch cfg.Mode {
case mgmProto.TransparentProxyMode_TP_MODE_ENVOY:
config.Mode = inspect.ModeEnvoy
case mgmProto.TransparentProxyMode_TP_MODE_EXTERNAL:
config.Mode = inspect.ModeExternal
default:
config.Mode = inspect.ModeBuiltin
}
if cfg.ExternalProxyUrl != "" {
u, err := url.Parse(cfg.ExternalProxyUrl)
if err != nil {
return inspect.Config{}, fmt.Errorf("parse external proxy URL: %w", err)
}
config.ExternalURL = u
}
for _, s := range cfg.RedirectSources {
prefix, err := netip.ParsePrefix(s)
if err != nil {
return inspect.Config{}, fmt.Errorf("parse redirect source %q: %w", s, err)
}
config.RedirectSources = append(config.RedirectSources, prefix)
}
for _, p := range cfg.RedirectPorts {
config.RedirectPorts = append(config.RedirectPorts, uint16(p))
}
// TPROXY listen port: fixed default, overridable via env var.
if config.Mode == inspect.ModeBuiltin {
port := uint16(inspect.DefaultTProxyPort)
if v := os.Getenv("NB_TPROXY_PORT"); v != "" {
if p, err := strconv.ParseUint(v, 10, 16); err == nil {
port = uint16(p)
} else {
log.Warnf("invalid NB_TPROXY_PORT %q, using default %d", v, inspect.DefaultTProxyPort)
}
}
config.ListenAddr = netip.AddrPortFrom(netip.IPv4Unspecified(), port)
}
for _, r := range cfg.Rules {
rule, err := toProxyRule(r)
if err != nil {
return inspect.Config{}, fmt.Errorf("parse rule %q: %w", r.Id, err)
}
config.Rules = append(config.Rules, rule)
}
if cfg.Icap != nil {
icapCfg, err := toICAPConfig(cfg.Icap)
if err != nil {
return inspect.Config{}, fmt.Errorf("parse ICAP config: %w", err)
}
config.ICAP = icapCfg
}
if len(cfg.CaCertPem) > 0 && len(cfg.CaKeyPem) > 0 {
tlsCfg, err := parseTLSConfig(cfg.CaCertPem, cfg.CaKeyPem)
if err != nil {
return inspect.Config{}, fmt.Errorf("parse TLS config: %w", err)
}
config.TLS = tlsCfg
}
if config.Mode == inspect.ModeEnvoy {
envCfg := &inspect.EnvoyConfig{
BinaryPath: cfg.EnvoyBinaryPath,
AdminPort: uint16(cfg.EnvoyAdminPort),
}
if cfg.EnvoySnippets != nil {
envCfg.Snippets = &inspect.EnvoySnippets{
HTTPFilters: cfg.EnvoySnippets.HttpFilters,
NetworkFilters: cfg.EnvoySnippets.NetworkFilters,
Clusters: cfg.EnvoySnippets.Clusters,
}
}
config.Envoy = envCfg
}
return config, nil
}
func toProxyRule(r *mgmProto.TransparentProxyRule) (inspect.Rule, error) {
rule := inspect.Rule{
ID: id.RuleID(r.Id),
Action: toProxyAction(r.Action),
Priority: int(r.Priority),
}
for _, d := range r.Domains {
dom, err := domain.FromString(d)
if err != nil {
return inspect.Rule{}, fmt.Errorf("parse domain %q: %w", d, err)
}
rule.Domains = append(rule.Domains, dom)
}
for _, n := range r.Networks {
prefix, err := netip.ParsePrefix(n)
if err != nil {
return inspect.Rule{}, fmt.Errorf("parse network %q: %w", n, err)
}
rule.Networks = append(rule.Networks, prefix)
}
for _, p := range r.Ports {
rule.Ports = append(rule.Ports, uint16(p))
}
for _, proto := range r.Protocols {
rule.Protocols = append(rule.Protocols, toProxyProtoType(proto))
}
rule.Paths = r.Paths
return rule, nil
}
func toProxyProtoType(p mgmProto.TransparentProxyProtocol) inspect.ProtoType {
switch p {
case mgmProto.TransparentProxyProtocol_TP_PROTO_HTTP:
return inspect.ProtoHTTP
case mgmProto.TransparentProxyProtocol_TP_PROTO_HTTPS:
return inspect.ProtoHTTPS
case mgmProto.TransparentProxyProtocol_TP_PROTO_H2:
return inspect.ProtoH2
case mgmProto.TransparentProxyProtocol_TP_PROTO_H3:
return inspect.ProtoH3
case mgmProto.TransparentProxyProtocol_TP_PROTO_WEBSOCKET:
return inspect.ProtoWebSocket
case mgmProto.TransparentProxyProtocol_TP_PROTO_OTHER:
return inspect.ProtoOther
default:
return ""
}
}
func toProxyAction(a mgmProto.TransparentProxyAction) inspect.Action {
switch a {
case mgmProto.TransparentProxyAction_TP_ACTION_BLOCK:
return inspect.ActionBlock
case mgmProto.TransparentProxyAction_TP_ACTION_INSPECT:
return inspect.ActionInspect
default:
return inspect.ActionAllow
}
}
func toICAPConfig(cfg *mgmProto.TransparentProxyICAPConfig) (*inspect.ICAPConfig, error) {
icap := &inspect.ICAPConfig{
MaxConnections: int(cfg.MaxConnections),
}
if cfg.ReqmodUrl != "" {
u, err := url.Parse(cfg.ReqmodUrl)
if err != nil {
return nil, fmt.Errorf("parse ICAP reqmod URL: %w", err)
}
icap.ReqModURL = u
}
if cfg.RespmodUrl != "" {
u, err := url.Parse(cfg.RespmodUrl)
if err != nil {
return nil, fmt.Errorf("parse ICAP respmod URL: %w", err)
}
icap.RespModURL = u
}
return icap, nil
}
func parseTLSConfig(certPEM, keyPEM []byte) (*inspect.TLSConfig, error) {
block, _ := pem.Decode(certPEM)
if block == nil {
return nil, fmt.Errorf("decode CA certificate PEM")
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, fmt.Errorf("parse CA certificate: %w", err)
}
keyBlock, _ := pem.Decode(keyPEM)
if keyBlock == nil {
return nil, fmt.Errorf("decode CA key PEM")
}
key, err := x509.ParseECPrivateKey(keyBlock.Bytes)
if err != nil {
// Try PKCS8 as fallback
pkcs8Key, pkcs8Err := x509.ParsePKCS8PrivateKey(keyBlock.Bytes)
if pkcs8Err != nil {
return nil, fmt.Errorf("parse CA private key (tried EC and PKCS8): %w", err)
}
return &inspect.TLSConfig{CA: cert, CAKey: pkcs8Key}, nil
}
return &inspect.TLSConfig{CA: cert, CAKey: key}, nil
}
// resolveInspectionCA sets the TLS config on the proxy config using priority:
// 1. Local config file CA (InspectionCACertPath/InspectionCAKeyPath)
// 2. Management-pushed CA (already parsed in toProxyConfig)
// 3. Auto-generated self-signed CA (ephemeral, for testing)
// Local always wins to prevent a compromised management server from injecting a CA.
func (e *Engine) resolveInspectionCA(config *inspect.Config) {
// 1. Local CA from config file or env vars
certPath := e.config.InspectionCACertPath
keyPath := e.config.InspectionCAKeyPath
if certPath == "" {
certPath = os.Getenv("NB_INSPECTION_CA_CERT")
}
if keyPath == "" {
keyPath = os.Getenv("NB_INSPECTION_CA_KEY")
}
if certPath != "" && keyPath != "" {
certPEM, err := os.ReadFile(certPath)
if err != nil {
log.Errorf("read local inspection CA cert %s: %v", certPath, err)
return
}
keyPEM, err := os.ReadFile(keyPath)
if err != nil {
log.Errorf("read local inspection CA key %s: %v", keyPath, err)
return
}
tlsCfg, err := parseTLSConfig(certPEM, keyPEM)
if err != nil {
log.Errorf("parse local inspection CA: %v", err)
return
}
log.Infof("inspect: using local CA from %s", certPath)
config.TLS = tlsCfg
return
}
// 2. Management-pushed CA (already set by toProxyConfig)
if config.TLS != nil {
log.Infof("inspect: using management-pushed CA")
return
}
// 3. Auto-generate self-signed CA for testing / accept-cert UX
tlsCfg, err := generateSelfSignedCA()
if err != nil {
log.Errorf("generate self-signed inspection CA: %v", err)
return
}
log.Infof("inspect: using auto-generated self-signed CA (clients will see certificate warnings)")
config.TLS = tlsCfg
}
// generateSelfSignedCA creates an ephemeral ECDSA P-256 CA certificate.
// Clients will see certificate warnings but can choose to accept.
func generateSelfSignedCA() (*inspect.TLSConfig, error) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, fmt.Errorf("generate CA key: %w", err)
}
serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
if err != nil {
return nil, fmt.Errorf("generate serial: %w", err)
}
template := &x509.Certificate{
SerialNumber: serial,
Subject: pkix.Name{
Organization: []string{"NetBird Transparent Proxy"},
CommonName: "NetBird Inspection CA (auto-generated)",
},
NotBefore: time.Now().Add(-1 * time.Hour),
NotAfter: time.Now().Add(365 * 24 * time.Hour),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
IsCA: true,
MaxPathLen: 0,
}
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
if err != nil {
return nil, fmt.Errorf("create CA certificate: %w", err)
}
cert, err := x509.ParseCertificate(certDER)
if err != nil {
return nil, fmt.Errorf("parse generated CA certificate: %w", err)
}
return &inspect.TLSConfig{CA: cert, CAKey: key}, nil
}

View File

@@ -0,0 +1,279 @@
package internal
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/inspect"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
func TestToProxyConfig_Basic(t *testing.T) {
cfg := &mgmProto.TransparentProxyConfig{
Enabled: true,
Mode: mgmProto.TransparentProxyMode_TP_MODE_BUILTIN,
DefaultAction: mgmProto.TransparentProxyAction_TP_ACTION_ALLOW,
RedirectSources: []string{
"10.0.0.0/24",
"192.168.1.0/24",
},
RedirectPorts: []uint32{80, 443},
Rules: []*mgmProto.TransparentProxyRule{
{
Id: "block-evil",
Domains: []string{"*.evil.com", "malware.example.com"},
Action: mgmProto.TransparentProxyAction_TP_ACTION_BLOCK,
Priority: 1,
},
{
Id: "inspect-internal",
Domains: []string{"*.internal.corp"},
Networks: []string{"10.1.0.0/16"},
Ports: []uint32{443, 8443},
Action: mgmProto.TransparentProxyAction_TP_ACTION_INSPECT,
Priority: 10,
},
},
ListenPort: 8443,
}
config, err := toProxyConfig(cfg)
require.NoError(t, err)
assert.True(t, config.Enabled)
assert.Equal(t, inspect.ModeBuiltin, config.Mode)
assert.Equal(t, inspect.ActionAllow, config.DefaultAction)
require.Len(t, config.RedirectSources, 2)
assert.Equal(t, "10.0.0.0/24", config.RedirectSources[0].String())
assert.Equal(t, "192.168.1.0/24", config.RedirectSources[1].String())
require.Len(t, config.RedirectPorts, 2)
assert.Equal(t, uint16(80), config.RedirectPorts[0])
assert.Equal(t, uint16(443), config.RedirectPorts[1])
require.Len(t, config.Rules, 2)
// Rule 1: block evil domains
assert.Equal(t, "block-evil", string(config.Rules[0].ID))
assert.Equal(t, inspect.ActionBlock, config.Rules[0].Action)
assert.Equal(t, 1, config.Rules[0].Priority)
require.Len(t, config.Rules[0].Domains, 2)
assert.Equal(t, "*.evil.com", config.Rules[0].Domains[0].PunycodeString())
assert.Equal(t, "malware.example.com", config.Rules[0].Domains[1].PunycodeString())
// Rule 2: inspect internal
assert.Equal(t, "inspect-internal", string(config.Rules[1].ID))
assert.Equal(t, inspect.ActionInspect, config.Rules[1].Action)
assert.Equal(t, 10, config.Rules[1].Priority)
require.Len(t, config.Rules[1].Networks, 1)
assert.Equal(t, "10.1.0.0/16", config.Rules[1].Networks[0].String())
require.Len(t, config.Rules[1].Ports, 2)
// Listen address
assert.True(t, config.ListenAddr.IsValid())
assert.Equal(t, uint16(8443), config.ListenAddr.Port())
}
func TestToProxyConfig_ExternalMode(t *testing.T) {
cfg := &mgmProto.TransparentProxyConfig{
Enabled: true,
Mode: mgmProto.TransparentProxyMode_TP_MODE_EXTERNAL,
ExternalProxyUrl: "http://proxy.corp:8080",
DefaultAction: mgmProto.TransparentProxyAction_TP_ACTION_BLOCK,
}
config, err := toProxyConfig(cfg)
require.NoError(t, err)
assert.Equal(t, inspect.ModeExternal, config.Mode)
assert.Equal(t, inspect.ActionBlock, config.DefaultAction)
require.NotNil(t, config.ExternalURL)
assert.Equal(t, "http", config.ExternalURL.Scheme)
assert.Equal(t, "proxy.corp:8080", config.ExternalURL.Host)
}
func TestToProxyConfig_ICAP(t *testing.T) {
cfg := &mgmProto.TransparentProxyConfig{
Enabled: true,
Icap: &mgmProto.TransparentProxyICAPConfig{
ReqmodUrl: "icap://icap-server:1344/reqmod",
RespmodUrl: "icap://icap-server:1344/respmod",
MaxConnections: 16,
},
}
config, err := toProxyConfig(cfg)
require.NoError(t, err)
require.NotNil(t, config.ICAP)
assert.Equal(t, "icap", config.ICAP.ReqModURL.Scheme)
assert.Equal(t, "icap-server:1344", config.ICAP.ReqModURL.Host)
assert.Equal(t, "/reqmod", config.ICAP.ReqModURL.Path)
assert.Equal(t, "/respmod", config.ICAP.RespModURL.Path)
assert.Equal(t, 16, config.ICAP.MaxConnections)
}
func TestToProxyConfig_Empty(t *testing.T) {
cfg := &mgmProto.TransparentProxyConfig{
Enabled: true,
}
config, err := toProxyConfig(cfg)
require.NoError(t, err)
assert.True(t, config.Enabled)
assert.Equal(t, inspect.ModeBuiltin, config.Mode)
assert.Equal(t, inspect.ActionAllow, config.DefaultAction)
assert.Empty(t, config.RedirectSources)
assert.Empty(t, config.RedirectPorts)
assert.Empty(t, config.Rules)
assert.Nil(t, config.ICAP)
assert.Nil(t, config.TLS)
assert.False(t, config.ListenAddr.IsValid())
}
func TestToProxyConfig_InvalidSource(t *testing.T) {
cfg := &mgmProto.TransparentProxyConfig{
Enabled: true,
RedirectSources: []string{"not-a-cidr"},
}
_, err := toProxyConfig(cfg)
require.Error(t, err)
assert.Contains(t, err.Error(), "parse redirect source")
}
func TestToProxyConfig_InvalidNetwork(t *testing.T) {
cfg := &mgmProto.TransparentProxyConfig{
Enabled: true,
Rules: []*mgmProto.TransparentProxyRule{
{
Id: "bad",
Networks: []string{"not-a-cidr"},
},
},
}
_, err := toProxyConfig(cfg)
require.Error(t, err)
assert.Contains(t, err.Error(), "parse network")
}
func TestToProxyAction(t *testing.T) {
assert.Equal(t, inspect.ActionAllow, toProxyAction(mgmProto.TransparentProxyAction_TP_ACTION_ALLOW))
assert.Equal(t, inspect.ActionBlock, toProxyAction(mgmProto.TransparentProxyAction_TP_ACTION_BLOCK))
assert.Equal(t, inspect.ActionInspect, toProxyAction(mgmProto.TransparentProxyAction_TP_ACTION_INSPECT))
// Unknown defaults to allow
assert.Equal(t, inspect.ActionAllow, toProxyAction(99))
}
func TestParseUDPPacket_IPv4(t *testing.T) {
// Build a minimal IPv4/UDP packet: 20-byte IPv4 header + 8-byte UDP header + payload
packet := make([]byte, 20+8+4)
// IPv4 header: version=4, IHL=5 (20 bytes)
packet[0] = 0x45
// Protocol = UDP (17)
packet[9] = 17
// Source IP: 10.0.0.1
packet[12], packet[13], packet[14], packet[15] = 10, 0, 0, 1
// Dest IP: 192.168.1.1
packet[16], packet[17], packet[18], packet[19] = 192, 168, 1, 1
// UDP source port: 54321 (0xD431)
packet[20] = 0xD4
packet[21] = 0x31
// UDP dest port: 443 (0x01BB)
packet[22] = 0x01
packet[23] = 0xBB
// Payload
packet[28] = 0xDE
packet[29] = 0xAD
packet[30] = 0xBE
packet[31] = 0xEF
srcIP, dstIP, dstPort, payload, ok := parseUDPPacket(packet)
require.True(t, ok)
assert.Equal(t, "10.0.0.1", srcIP.String())
assert.Equal(t, "192.168.1.1", dstIP.String())
assert.Equal(t, uint16(443), dstPort)
assert.Equal(t, []byte{0xDE, 0xAD, 0xBE, 0xEF}, payload)
}
func TestParseUDPPacket_IPv6(t *testing.T) {
// Build a minimal IPv6/UDP packet: 40-byte IPv6 header + 8-byte UDP header + payload
packet := make([]byte, 40+8+4)
// Version = 6 (0x60 in high nibble)
packet[0] = 0x60
// Payload length: 8 (UDP header) + 4 (payload)
packet[4] = 0
packet[5] = 12
// Next header: UDP (17)
packet[6] = 17
// Source: 2001:db8::1
packet[8] = 0x20
packet[9] = 0x01
packet[10] = 0x0d
packet[11] = 0xb8
packet[23] = 0x01
// Dest: 2001:db8::2
packet[24] = 0x20
packet[25] = 0x01
packet[26] = 0x0d
packet[27] = 0xb8
packet[39] = 0x02
// UDP source port: 54321 (0xD431)
packet[40] = 0xD4
packet[41] = 0x31
// UDP dest port: 443 (0x01BB)
packet[42] = 0x01
packet[43] = 0xBB
// Payload
packet[48] = 0xCA
packet[49] = 0xFE
packet[50] = 0xBA
packet[51] = 0xBE
srcIP, dstIP, dstPort, payload, ok := parseUDPPacket(packet)
require.True(t, ok)
assert.Equal(t, "2001:db8::1", srcIP.String())
assert.Equal(t, "2001:db8::2", dstIP.String())
assert.Equal(t, uint16(443), dstPort)
assert.Equal(t, []byte{0xCA, 0xFE, 0xBA, 0xBE}, payload)
}
func TestParseUDPPacket_TooShort(t *testing.T) {
_, _, _, _, ok := parseUDPPacket(nil)
assert.False(t, ok)
_, _, _, _, ok = parseUDPPacket([]byte{0x45, 0x00})
assert.False(t, ok)
}
func TestParseUDPPacket_IPv6ExtensionHeader(t *testing.T) {
// IPv6 with next header != UDP should be rejected
packet := make([]byte, 48)
packet[0] = 0x60
packet[6] = 6 // TCP, not UDP
_, _, _, _, ok := parseUDPPacket(packet)
assert.False(t, ok, "should reject IPv6 packets with non-UDP next header")
}
func TestParseUDPPacket_IPv4MappedIPv6(t *testing.T) {
// IPv4 packet with normal addresses should Unmap correctly
packet := make([]byte, 28)
packet[0] = 0x45
packet[9] = 17
packet[12], packet[13], packet[14], packet[15] = 127, 0, 0, 1
packet[16], packet[17], packet[18], packet[19] = 10, 0, 0, 1
packet[22] = 0x01
packet[23] = 0xBB
srcIP, dstIP, _, _, ok := parseUDPPacket(packet)
require.True(t, ok)
assert.True(t, srcIP.Is4(), "should be plain IPv4, not mapped")
assert.True(t, dstIP.Is4(), "should be plain IPv4, not mapped")
}

View File

@@ -97,6 +97,9 @@ type ConfigInput struct {
LazyConnectionEnabled *bool
MTU *uint16
InspectionCACertPath string
InspectionCAKeyPath string
}
// Config Configuration type
@@ -171,6 +174,13 @@ type Config struct {
LazyConnectionEnabled bool
MTU uint16
// InspectionCACertPath is the path to a PEM CA certificate for transparent proxy MITM.
// Local CA takes priority over management-pushed CA.
InspectionCACertPath string
// InspectionCAKeyPath is the path to the PEM CA private key for transparent proxy MITM.
InspectionCAKeyPath string
}
var ConfigDirOverride string
@@ -603,6 +613,17 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.InspectionCACertPath != "" && input.InspectionCACertPath != config.InspectionCACertPath {
log.Infof("updating inspection CA cert path to %s", input.InspectionCACertPath)
config.InspectionCACertPath = input.InspectionCACertPath
updated = true
}
if input.InspectionCAKeyPath != "" && input.InspectionCAKeyPath != config.InspectionCAKeyPath {
log.Infof("updating inspection CA key path to %s", input.InspectionCAKeyPath)
config.InspectionCAKeyPath = input.InspectionCAKeyPath
updated = true
}
return updated, nil
}

View File

@@ -156,6 +156,10 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
response.NetworkMap.ForwardingRules = forwardingRules
}
if networkMap.TransparentProxyConfig != nil {
response.NetworkMap.TransparentProxyConfig = networkMap.TransparentProxyConfig.ToProto()
}
if networkMap.AuthorizedUsers != nil {
hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
userIDClaim := auth.DefaultUserIDClaim

View File

@@ -113,6 +113,10 @@ type Manager interface {
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error)
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
GetInspectionPolicy(ctx context.Context, accountID, policyID, userID string) (*types.InspectionPolicy, error)
SaveInspectionPolicy(ctx context.Context, accountID, userID string, policy *types.InspectionPolicy, create bool) (*types.InspectionPolicy, error)
DeleteInspectionPolicy(ctx context.Context, accountID, policyID, userID string) error
ListInspectionPolicies(ctx context.Context, accountID, userID string) ([]*types.InspectionPolicy, error)
GetIdpManager() idp.Manager
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)

View File

@@ -51,6 +51,7 @@ import (
"github.com/netbirdio/netbird/management/server/http/handlers/instance"
"github.com/netbirdio/netbird/management/server/http/handlers/networks"
"github.com/netbirdio/netbird/management/server/http/handlers/peers"
inspectionHandler "github.com/netbirdio/netbird/management/server/http/handlers/inspection"
"github.com/netbirdio/netbird/management/server/http/handlers/policies"
"github.com/netbirdio/netbird/management/server/http/handlers/routes"
"github.com/netbirdio/netbird/management/server/http/handlers/setup_keys"
@@ -162,6 +163,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
setup_keys.AddEndpoints(accountManager, router)
policies.AddEndpoints(accountManager, LocationManager, router)
policies.AddPostureCheckEndpoints(accountManager, LocationManager, router)
inspectionHandler.AddEndpoints(accountManager, router)
policies.AddLocationsEndpoints(accountManager, LocationManager, permissionsManager, router)
groups.AddEndpoints(accountManager, router)
routes.AddEndpoints(accountManager, router)

View File

@@ -0,0 +1,288 @@
package inspection
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
// Handler manages inspection policy CRUD operations.
type Handler struct {
accountManager account.Manager
}
// AddEndpoints registers the inspection policy API endpoints.
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
h := &Handler{accountManager: accountManager}
router.HandleFunc("/inspection-policies", h.list).Methods("GET", "OPTIONS")
router.HandleFunc("/inspection-policies", h.create).Methods("POST", "OPTIONS")
router.HandleFunc("/inspection-policies/{policyId}", h.get).Methods("GET", "OPTIONS")
router.HandleFunc("/inspection-policies/{policyId}", h.update).Methods("PUT", "OPTIONS")
router.HandleFunc("/inspection-policies/{policyId}", h.remove).Methods("DELETE", "OPTIONS")
}
func (h *Handler) list(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
policies, err := h.accountManager.ListInspectionPolicies(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
result := make([]*api.InspectionPolicy, 0, len(policies))
for _, p := range policies {
result = append(result, toAPIResponse(p))
}
util.WriteJSONObject(r.Context(), w, result)
}
func (h *Handler) create(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.InspectionPolicyMinimum
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "decode request: %v", err), w)
return
}
policy := fromAPIRequest(&req)
saved, err := h.accountManager.SaveInspectionPolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policy, true)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toAPIResponse(saved))
}
func (h *Handler) get(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
policyID := mux.Vars(r)["policyId"]
policy, err := h.accountManager.GetInspectionPolicy(r.Context(), userAuth.AccountId, policyID, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toAPIResponse(policy))
}
func (h *Handler) update(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
policyID := mux.Vars(r)["policyId"]
var req api.InspectionPolicyMinimum
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "decode request: %v", err), w)
return
}
policy := fromAPIRequest(&req)
policy.ID = policyID
saved, err := h.accountManager.SaveInspectionPolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policy, false)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toAPIResponse(saved))
}
func (h *Handler) remove(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
policyID := mux.Vars(r)["policyId"]
if err := h.accountManager.DeleteInspectionPolicy(r.Context(), userAuth.AccountId, policyID, userAuth.UserId); err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, struct{}{})
}
func toAPIResponse(p *types.InspectionPolicy) *api.InspectionPolicy {
id := p.ID
resp := &api.InspectionPolicy{
Id: &id,
Name: p.Name,
Enabled: p.Enabled,
}
if p.Description != "" {
resp.Description = &p.Description
}
if p.Mode != "" {
mode := api.InspectionPolicyMode(p.Mode)
resp.Mode = &mode
}
if p.ExternalURL != "" {
resp.ExternalUrl = &p.ExternalURL
}
if p.DefaultAction != "" {
da := api.InspectionPolicyDefaultAction(p.DefaultAction)
resp.DefaultAction = &da
}
if len(p.RedirectPorts) > 0 {
resp.RedirectPorts = &p.RedirectPorts
}
if p.CACertPEM != "" {
resp.CaCertPem = &p.CACertPEM
}
if p.CAKeyPEM != "" {
resp.CaKeyPem = &p.CAKeyPEM
}
if p.EnvoyBinaryPath != "" {
resp.EnvoyBinaryPath = &p.EnvoyBinaryPath
}
if p.EnvoyAdminPort != 0 {
port := int(p.EnvoyAdminPort)
resp.EnvoyAdminPort = &port
}
if p.ICAP != nil {
resp.Icap = &api.InspectionICAPConfig{}
if p.ICAP.ReqModURL != "" {
resp.Icap.ReqmodUrl = &p.ICAP.ReqModURL
}
if p.ICAP.RespModURL != "" {
resp.Icap.RespmodUrl = &p.ICAP.RespModURL
}
if p.ICAP.MaxConnections != 0 {
resp.Icap.MaxConnections = &p.ICAP.MaxConnections
}
}
rules := make([]api.InspectionPolicyRule, 0, len(p.Rules))
for _, r := range p.Rules {
rule := api.InspectionPolicyRule{
Action: api.InspectionPolicyRuleAction(r.Action),
Priority: r.Priority,
}
if len(r.Domains) > 0 {
rule.Domains = &r.Domains
}
if len(r.Networks) > 0 {
rule.Networks = &r.Networks
}
if len(r.Protocols) > 0 {
protos := make([]api.InspectionPolicyRuleProtocols, len(r.Protocols))
for i, proto := range r.Protocols {
protos[i] = api.InspectionPolicyRuleProtocols(proto)
}
rule.Protocols = &protos
}
if len(r.Paths) > 0 {
rule.Paths = &r.Paths
}
rules = append(rules, rule)
}
resp.Rules = rules
return resp
}
func fromAPIRequest(req *api.InspectionPolicyMinimum) *types.InspectionPolicy {
p := &types.InspectionPolicy{
Name: req.Name,
Enabled: req.Enabled,
}
if req.Description != nil {
p.Description = *req.Description
}
if req.Mode != nil {
p.Mode = string(*req.Mode)
}
if req.ExternalUrl != nil {
p.ExternalURL = *req.ExternalUrl
}
if req.DefaultAction != nil {
p.DefaultAction = string(*req.DefaultAction)
}
if req.RedirectPorts != nil {
p.RedirectPorts = *req.RedirectPorts
}
if req.CaCertPem != nil {
p.CACertPEM = *req.CaCertPem
}
if req.CaKeyPem != nil {
p.CAKeyPEM = *req.CaKeyPem
}
if req.EnvoyBinaryPath != nil {
p.EnvoyBinaryPath = *req.EnvoyBinaryPath
}
if req.EnvoyAdminPort != nil {
p.EnvoyAdminPort = *req.EnvoyAdminPort
}
if req.Icap != nil {
p.ICAP = &types.InspectionICAPConfig{}
if req.Icap.ReqmodUrl != nil {
p.ICAP.ReqModURL = *req.Icap.ReqmodUrl
}
if req.Icap.RespmodUrl != nil {
p.ICAP.RespModURL = *req.Icap.RespmodUrl
}
if req.Icap.MaxConnections != nil {
p.ICAP.MaxConnections = *req.Icap.MaxConnections
}
}
for _, r := range req.Rules {
rule := types.InspectionPolicyRule{
Action: string(r.Action),
Priority: r.Priority,
}
if r.Domains != nil {
rule.Domains = *r.Domains
}
if r.Networks != nil {
rule.Networks = *r.Networks
}
if r.Protocols != nil {
for _, proto := range *r.Protocols {
rule.Protocols = append(rule.Protocols, string(proto))
}
}
if r.Paths != nil {
rule.Paths = *r.Paths
}
p.Rules = append(p.Rules, rule)
}
return p
}

View File

@@ -281,6 +281,10 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
policy.SourcePostureChecks = *req.SourcePostureChecks
}
if req.InspectionPolicies != nil {
policy.InspectionPolicies = *req.InspectionPolicies
}
policy, err := h.accountManager.SavePolicy(r.Context(), accountID, userID, policy, create)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -377,6 +381,7 @@ func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy {
Description: &policy.Description,
Enabled: policy.Enabled,
SourcePostureChecks: policy.SourcePostureChecks,
InspectionPolicies: &policy.InspectionPolicies,
}
for _, r := range policy.Rules {
rID := r.ID

View File

@@ -0,0 +1,52 @@
package server
import (
"context"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
// GetInspectionPolicy returns an inspection policy by ID.
func (am *DefaultAccountManager) GetInspectionPolicy(ctx context.Context, accountID, policyID, userID string) (*types.InspectionPolicy, error) {
return am.Store.GetInspectionPolicyByID(ctx, store.LockingStrengthShare, accountID, policyID)
}
// SaveInspectionPolicy creates or updates an inspection policy.
func (am *DefaultAccountManager) SaveInspectionPolicy(ctx context.Context, accountID, userID string, policy *types.InspectionPolicy, create bool) (*types.InspectionPolicy, error) {
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if create {
policy.ID = xid.New().String()
}
policy.AccountID = accountID
return transaction.SaveInspectionPolicy(ctx, store.LockingStrengthUpdate, policy)
})
if err != nil {
return nil, err
}
am.UpdateAccountPeers(ctx, accountID)
return policy, nil
}
// DeleteInspectionPolicy removes an inspection policy.
func (am *DefaultAccountManager) DeleteInspectionPolicy(ctx context.Context, accountID, policyID, userID string) error {
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
return transaction.DeleteInspectionPolicy(ctx, store.LockingStrengthUpdate, accountID, policyID)
})
if err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
return nil
}
// ListInspectionPolicies returns all inspection policies for the account.
func (am *DefaultAccountManager) ListInspectionPolicies(ctx context.Context, accountID, userID string) ([]*types.InspectionPolicy, error) {
return am.Store.GetAccountInspectionPolicies(ctx, store.LockingStrengthShare, accountID)
}

View File

@@ -909,6 +909,26 @@ func (am *MockAccountManager) ListPostureChecks(ctx context.Context, accountID,
return nil, status.Errorf(codes.Unimplemented, "method ListPostureChecks is not implemented")
}
// GetInspectionPolicy mocks GetInspectionPolicy of the AccountManager interface
func (am *MockAccountManager) GetInspectionPolicy(ctx context.Context, accountID, policyID, userID string) (*types.InspectionPolicy, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetInspectionPolicy is not implemented")
}
// SaveInspectionPolicy mocks SaveInspectionPolicy of the AccountManager interface
func (am *MockAccountManager) SaveInspectionPolicy(ctx context.Context, accountID, userID string, policy *types.InspectionPolicy, create bool) (*types.InspectionPolicy, error) {
return nil, status.Errorf(codes.Unimplemented, "method SaveInspectionPolicy is not implemented")
}
// DeleteInspectionPolicy mocks DeleteInspectionPolicy of the AccountManager interface
func (am *MockAccountManager) DeleteInspectionPolicy(ctx context.Context, accountID, policyID, userID string) error {
return status.Errorf(codes.Unimplemented, "method DeleteInspectionPolicy is not implemented")
}
// ListInspectionPolicies mocks ListInspectionPolicies of the AccountManager interface
func (am *MockAccountManager) ListInspectionPolicies(ctx context.Context, accountID, userID string) ([]*types.InspectionPolicy, error) {
return nil, status.Errorf(codes.Unimplemented, "method ListInspectionPolicies is not implemented")
}
// GetIdpManager mocks GetIdpManager of the AccountManager interface
func (am *MockAccountManager) GetIdpManager() idp.Manager {
if am.GetIdpManagerFunc != nil {

View File

@@ -18,6 +18,28 @@ type NetworkRouter struct {
Masquerade bool
Metric int
Enabled bool
Inspection *InspectionConfig `gorm:"serializer:json"`
}
// InspectionConfig holds traffic inspection settings for a routing peer.
// L7 inspection rules are stored separately as ProxyRule entities.
type InspectionConfig struct {
Enabled bool `json:"enabled"`
Mode string `json:"mode"` // "builtin" or "external"
ExternalURL string `json:"external_url"`
DefaultAction string `json:"default_action"` // "allow", "block", "inspect"
RedirectPorts []int `json:"redirect_ports"`
ICAP *InspectionICAP `json:"icap,omitempty"`
CACertPEM string `json:"ca_cert_pem,omitempty"`
CAKeyPEM string `json:"ca_key_pem,omitempty"`
ListenPort int `json:"listen_port"`
}
// InspectionICAP holds ICAP service configuration.
type InspectionICAP struct {
ReqModURL string `json:"reqmod_url"`
RespModURL string `json:"respmod_url"`
MaxConnections int `json:"max_connections"`
}
func NewNetworkRouter(accountID string, networkID string, peer string, peerGroups []string, masquerade bool, metric int, enabled bool) (*NetworkRouter, error) {
@@ -38,7 +60,7 @@ func NewNetworkRouter(accountID string, networkID string, peer string, peerGroup
}
func (n *NetworkRouter) ToAPIResponse() *api.NetworkRouter {
return &api.NetworkRouter{
resp := &api.NetworkRouter{
Id: n.ID,
Peer: &n.Peer,
PeerGroups: &n.PeerGroups,
@@ -46,6 +68,12 @@ func (n *NetworkRouter) ToAPIResponse() *api.NetworkRouter {
Metric: n.Metric,
Enabled: n.Enabled,
}
if n.Inspection != nil {
resp.Inspection = inspectionToAPI(n.Inspection)
}
return resp
}
func (n *NetworkRouter) FromAPIRequest(req *api.NetworkRouterRequest) {
@@ -60,10 +88,11 @@ func (n *NetworkRouter) FromAPIRequest(req *api.NetworkRouterRequest) {
n.Masquerade = req.Masquerade
n.Metric = req.Metric
n.Enabled = req.Enabled
n.Inspection = inspectionFromAPI(req.Inspection)
}
func (n *NetworkRouter) Copy() *NetworkRouter {
return &NetworkRouter{
c := &NetworkRouter{
ID: n.ID,
NetworkID: n.NetworkID,
AccountID: n.AccountID,
@@ -73,6 +102,108 @@ func (n *NetworkRouter) Copy() *NetworkRouter {
Metric: n.Metric,
Enabled: n.Enabled,
}
if n.Inspection != nil {
insp := *n.Inspection
c.Inspection = &insp
}
return c
}
func inspectionToAPI(c *InspectionConfig) *api.RouterInspectionConfig {
if c == nil {
return nil
}
mode := api.RouterInspectionConfigMode(c.Mode)
defaultAction := api.RouterInspectionConfigDefaultAction(c.DefaultAction)
resp := &api.RouterInspectionConfig{
Enabled: c.Enabled,
Mode: &mode,
DefaultAction: &defaultAction,
}
if c.ExternalURL != "" {
resp.ExternalUrl = &c.ExternalURL
}
if len(c.RedirectPorts) > 0 {
resp.RedirectPorts = &c.RedirectPorts
}
if c.CACertPEM != "" {
resp.CaCertPem = &c.CACertPEM
}
if c.CAKeyPEM != "" {
resp.CaKeyPem = &c.CAKeyPEM
}
if c.ICAP != nil {
icap := api.InspectionICAPConfig{}
if c.ICAP.ReqModURL != "" {
icap.ReqmodUrl = &c.ICAP.ReqModURL
}
if c.ICAP.RespModURL != "" {
icap.RespmodUrl = &c.ICAP.RespModURL
}
if c.ICAP.MaxConnections > 0 {
icap.MaxConnections = &c.ICAP.MaxConnections
}
resp.Icap = &icap
}
return resp
}
func inspectionFromAPI(c *api.RouterInspectionConfig) *InspectionConfig {
if c == nil {
return nil
}
insp := &InspectionConfig{
Enabled: c.Enabled,
}
if c.Mode != nil {
insp.Mode = string(*c.Mode)
}
if c.DefaultAction != nil {
insp.DefaultAction = string(*c.DefaultAction)
}
if c.ExternalUrl != nil {
insp.ExternalURL = *c.ExternalUrl
}
if c.RedirectPorts != nil {
insp.RedirectPorts = *c.RedirectPorts
}
if c.CaCertPem != nil {
insp.CACertPEM = *c.CaCertPem
}
if c.CaKeyPem != nil {
insp.CAKeyPEM = *c.CaKeyPem
}
if c.Icap != nil {
insp.ICAP = &InspectionICAP{}
if c.Icap.ReqmodUrl != nil {
insp.ICAP.ReqModURL = *c.Icap.ReqmodUrl
}
if c.Icap.RespmodUrl != nil {
insp.ICAP.RespModURL = *c.Icap.RespmodUrl
}
if c.Icap.MaxConnections != nil {
insp.ICAP.MaxConnections = *c.Icap.MaxConnections
}
}
return insp
}
func derefInt(p *int) int {
if p == nil {
return 0
}
return *p
}
func (n *NetworkRouter) EventMeta(network *types.Network) map[string]any {

View File

@@ -3,6 +3,7 @@ package server
import (
"context"
_ "embed"
"slices"
"github.com/rs/xid"
@@ -150,6 +151,12 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a
return false, nil
}
// Inspection policy changes always affect peers since they control
// the transparent proxy config pushed in the network map.
if !slices.Equal(existingPolicy.InspectionPolicies, policy.InspectionPolicies) {
return true, nil
}
for _, rule := range existingPolicy.Rules {
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
return true, nil

View File

@@ -134,7 +134,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &rpservice.Service{}, &rpservice.Target{}, &domain.Domain{},
&accesslogs.AccessLogEntry{}, &proxy.Proxy{},
&accesslogs.AccessLogEntry{}, &proxy.Proxy{}, &types.InspectionPolicy{},
)
if err != nil {
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
@@ -1115,6 +1115,7 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
Preload("RoutesG").
Preload("NameServerGroupsG").
Preload("PostureChecks").
Preload("InspectionPolicies").
Preload("Networks").
Preload("NetworkRouters").
Preload("NetworkResources").
@@ -3877,6 +3878,63 @@ func (s *SqlStore) DeletePostureChecks(ctx context.Context, accountID, postureCh
return nil
}
// GetAccountInspectionPolicies returns all inspection policies for the account.
// CA cert and key are decrypted after loading.
func (s *SqlStore) GetAccountInspectionPolicies(ctx context.Context, _ LockingStrength, accountID string) ([]*types.InspectionPolicy, error) {
var policies []*types.InspectionPolicy
result := s.db.Where("account_id = ?", accountID).Find(&policies)
if result.Error != nil {
return nil, status.Errorf(status.Internal, "failed to get inspection policies: %s", result.Error)
}
for _, p := range policies {
if err := p.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt inspection policy %s: %w", p.ID, err)
}
}
return policies, nil
}
// GetInspectionPolicyByID returns an inspection policy by ID.
// CA cert and key are decrypted after loading.
func (s *SqlStore) GetInspectionPolicyByID(ctx context.Context, _ LockingStrength, accountID, policyID string) (*types.InspectionPolicy, error) {
var policy types.InspectionPolicy
result := s.db.Where(accountAndIDQueryCondition, accountID, policyID).First(&policy)
if result.Error != nil {
return nil, status.Errorf(status.Internal, "failed to get inspection policy: %s", result.Error)
}
if err := policy.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt inspection policy %s: %w", policyID, err)
}
return &policy, nil
}
// SaveInspectionPolicy saves an inspection policy to the database.
// CA cert and key are encrypted before storage.
func (s *SqlStore) SaveInspectionPolicy(ctx context.Context, _ LockingStrength, policy *types.InspectionPolicy) error {
if err := policy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt inspection policy: %w", err)
}
result := s.db.Save(policy)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save inspection policy: %s", result.Error)
return status.Errorf(status.Internal, "failed to save inspection policy")
}
return nil
}
// DeleteInspectionPolicy deletes an inspection policy from the database.
func (s *SqlStore) DeleteInspectionPolicy(ctx context.Context, _ LockingStrength, accountID, policyID string) error {
result := s.db.Delete(&types.InspectionPolicy{}, accountAndIDQueryCondition, accountID, policyID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete inspection policy: %s", result.Error)
return status.Errorf(status.Internal, "failed to delete inspection policy")
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "inspection policy %s not found", policyID)
}
return nil
}
// GetAccountRoutes retrieves network routes for an account.
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
tx := s.db

View File

@@ -0,0 +1,151 @@
package store
import (
"context"
"runtime"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/types"
)
func TestSqlStore_InspectionPolicyCRUD(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
ctx := context.Background()
accountID := "test-account-inspection"
// Create account first
account := newAccountWithId(ctx, accountID, "test-user", "example.com")
err := store.SaveAccount(ctx, account)
require.NoError(t, err)
// Create
policy := &types.InspectionPolicy{
ID: "ip-1",
AccountID: accountID,
Name: "Block gambling",
Description: "Block all gambling sites",
Enabled: true,
Rules: []types.InspectionPolicyRule{
{
Domains: []string{"*.gambling.com", "*.betting.com"},
Action: "block",
Priority: 1,
Protocols: []string{"https"},
},
{
Domains: []string{"*.malware.org"},
Networks: []string{"10.0.0.0/8"},
Action: "block",
Priority: 2,
},
},
}
err = store.SaveInspectionPolicy(ctx, LockingStrengthUpdate, policy)
require.NoError(t, err)
// Read
got, err := store.GetInspectionPolicyByID(ctx, LockingStrengthShare, accountID, "ip-1")
require.NoError(t, err)
assert.Equal(t, "Block gambling", got.Name)
assert.Equal(t, "Block all gambling sites", got.Description)
assert.True(t, got.Enabled)
require.Len(t, got.Rules, 2)
assert.Equal(t, []string{"*.gambling.com", "*.betting.com"}, got.Rules[0].Domains)
assert.Equal(t, "block", got.Rules[0].Action)
assert.Equal(t, []string{"https"}, got.Rules[0].Protocols)
assert.Equal(t, []string{"10.0.0.0/8"}, got.Rules[1].Networks)
// List
policies, err := store.GetAccountInspectionPolicies(ctx, LockingStrengthShare, accountID)
require.NoError(t, err)
require.Len(t, policies, 1)
assert.Equal(t, "ip-1", policies[0].ID)
// Update
policy.Name = "Block gambling updated"
policy.Rules = append(policy.Rules, types.InspectionPolicyRule{
Domains: []string{"*.phishing.net"},
Action: "inspect",
Priority: 3,
})
err = store.SaveInspectionPolicy(ctx, LockingStrengthUpdate, policy)
require.NoError(t, err)
got, err = store.GetInspectionPolicyByID(ctx, LockingStrengthShare, accountID, "ip-1")
require.NoError(t, err)
assert.Equal(t, "Block gambling updated", got.Name)
require.Len(t, got.Rules, 3)
assert.Equal(t, "inspect", got.Rules[2].Action)
// Delete
err = store.DeleteInspectionPolicy(ctx, LockingStrengthUpdate, accountID, "ip-1")
require.NoError(t, err)
_, err = store.GetInspectionPolicyByID(ctx, LockingStrengthShare, accountID, "ip-1")
assert.Error(t, err)
policies, err = store.GetAccountInspectionPolicies(ctx, LockingStrengthShare, accountID)
require.NoError(t, err)
assert.Empty(t, policies)
})
}
func TestSqlStore_InspectionPolicyNotFound(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
ctx := context.Background()
accountID := "test-account-no-ip"
account := newAccountWithId(ctx, accountID, "test-user", "example.com")
err := store.SaveAccount(ctx, account)
require.NoError(t, err)
_, err = store.GetInspectionPolicyByID(ctx, LockingStrengthShare, accountID, "nonexistent")
assert.Error(t, err)
})
}
func TestSqlStore_InspectionPolicyIsolation(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
ctx := context.Background()
// Create two accounts
acc1 := newAccountWithId(ctx, "acc-1", "user-1", "one.com")
acc2 := newAccountWithId(ctx, "acc-2", "user-2", "two.com")
require.NoError(t, store.SaveAccount(ctx, acc1))
require.NoError(t, store.SaveAccount(ctx, acc2))
// Save policy for acc-1
policy := &types.InspectionPolicy{
ID: "ip-acc1",
AccountID: "acc-1",
Name: "Account 1 policy",
Enabled: true,
Rules: []types.InspectionPolicyRule{{Action: "block", Priority: 1}},
}
require.NoError(t, store.SaveInspectionPolicy(ctx, LockingStrengthUpdate, policy))
// acc-2 should not see acc-1's policy
policies, err := store.GetAccountInspectionPolicies(ctx, LockingStrengthShare, "acc-2")
require.NoError(t, err)
assert.Empty(t, policies)
// acc-2 should not be able to get acc-1's policy by ID
_, err = store.GetInspectionPolicyByID(ctx, LockingStrengthShare, "acc-2", "ip-acc1")
assert.Error(t, err)
})
}

View File

@@ -143,6 +143,11 @@ type Store interface {
SavePostureChecks(ctx context.Context, postureCheck *posture.Checks) error
DeletePostureChecks(ctx context.Context, accountID, postureChecksID string) error
GetAccountInspectionPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.InspectionPolicy, error)
GetInspectionPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.InspectionPolicy, error)
SaveInspectionPolicy(ctx context.Context, lockStrength LockingStrength, policy *types.InspectionPolicy) error
DeleteInspectionPolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string, hostname string) ([]string, error)
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
AddPeerToGroup(ctx context.Context, accountID, peerId string, groupID string) error

File diff suppressed because it is too large Load Diff

View File

@@ -101,6 +101,7 @@ type Account struct {
NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"`
DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"`
InspectionPolicies []*InspectionPolicy `gorm:"foreignKey:AccountID;references:id"`
Services []*service.Service `gorm:"foreignKey:AccountID;references:id"`
Domains []*proxydomain.Domain `gorm:"foreignKey:AccountID;references:id"`
// Settings is a dictionary of Account settings

View File

@@ -81,6 +81,12 @@ func (a *Account) GetPeerNetworkMapComponents(
return nil
}
// Build inspection policies map for the network map builder
inspectionPoliciesMap := make(map[string]*InspectionPolicy, len(a.InspectionPolicies))
for _, ip := range a.InspectionPolicies {
inspectionPoliciesMap[ip.ID] = ip
}
components := &NetworkMapComponents{
PeerID: peerID,
Network: a.Network.Copy(),
@@ -91,6 +97,7 @@ func (a *Account) GetPeerNetworkMapComponents(
NetworkResources: make([]*resourceTypes.NetworkResource, 0),
PostureFailedPeers: make(map[string]map[string]struct{}, len(a.PostureChecks)),
RouterPeers: make(map[string]*nbpeer.Peer),
InspectionPolicies: inspectionPoliciesMap,
}
components.AccountSettings = &AccountSettingsInfo{

View File

@@ -0,0 +1,158 @@
package types
import (
"fmt"
"github.com/rs/xid"
"github.com/netbirdio/netbird/util/crypt"
)
// InspectionPolicy is a reusable set of L7 inspection rules with proxy configuration.
// Referenced by policies via InspectionPolicies field, similar to posture checks.
// Contains both what to inspect (rules) and how to inspect (CA, ICAP, mode).
type InspectionPolicy struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
Name string
Description string
Enabled bool
Rules []InspectionPolicyRule `gorm:"serializer:json"`
// Mode is the proxy operation mode: "builtin", "envoy", or "external".
Mode string `json:"mode"`
// ExternalURL is the upstream proxy URL (HTTP CONNECT or SOCKS5) for external mode.
ExternalURL string `json:"external_url"`
// DefaultAction applies when no rule matches: "allow", "block", or "inspect".
DefaultAction string `json:"default_action"`
// Redirect ports: which destination ports to intercept at L4.
// Empty means all ports.
RedirectPorts []int `gorm:"serializer:json" json:"redirect_ports"`
// MITM CA certificate and key (PEM-encoded)
CACertPEM string `json:"ca_cert_pem"`
CAKeyPEM string `json:"ca_key_pem"`
// ICAP configuration for external content scanning
ICAP *InspectionICAPConfig `gorm:"serializer:json" json:"icap"`
// Envoy sidecar configuration (mode "envoy" only)
EnvoyBinaryPath string `json:"envoy_binary_path"`
EnvoyAdminPort int `json:"envoy_admin_port"`
EnvoySnippets *InspectionEnvoySnippets `gorm:"serializer:json" json:"envoy_snippets"`
}
// InspectionEnvoySnippets holds user-provided YAML fragments for envoy config customization.
// Only safe snippet types are exposed: filters and clusters. Listeners and bootstrap
// overrides are not allowed since the envoy instance is fully managed.
type InspectionEnvoySnippets struct {
HTTPFilters string `json:"http_filters"`
NetworkFilters string `json:"network_filters"`
Clusters string `json:"clusters"`
}
// InspectionICAPConfig holds ICAP protocol settings.
type InspectionICAPConfig struct {
ReqModURL string `json:"reqmod_url"`
RespModURL string `json:"respmod_url"`
MaxConnections int `json:"max_connections"`
}
// InspectionPolicyRule is an L7 rule within an inspection policy.
// No source or network references: sources come from the referencing policy,
// the destination network/routing peer is derived from the policy's destination.
type InspectionPolicyRule struct {
Domains []string `json:"domains"`
// Networks restricts this rule to specific destination CIDRs.
Networks []string `json:"networks"`
// Protocols this rule applies to: "http", "https", "h2", "h3", "websocket", "other".
Protocols []string `json:"protocols"`
// Paths are URL path patterns: "/api/", "/login", "/admin/*".
Paths []string `json:"paths"`
Action string `json:"action"`
Priority int `json:"priority"`
}
// NewInspectionPolicy creates a new InspectionPolicy with a generated ID.
func NewInspectionPolicy(accountID, name, description string, enabled bool) *InspectionPolicy {
return &InspectionPolicy{
ID: xid.New().String(),
AccountID: accountID,
Name: name,
Description: description,
Enabled: enabled,
}
}
// Copy returns a deep copy.
func (p *InspectionPolicy) Copy() *InspectionPolicy {
c := *p
c.Rules = make([]InspectionPolicyRule, len(p.Rules))
for i, r := range p.Rules {
c.Rules[i] = r
c.Rules[i].Domains = append([]string{}, r.Domains...)
c.Rules[i].Networks = append([]string{}, r.Networks...)
c.Rules[i].Protocols = append([]string{}, r.Protocols...)
}
c.RedirectPorts = append([]int{}, p.RedirectPorts...)
if p.ICAP != nil {
icap := *p.ICAP
c.ICAP = &icap
}
return &c
}
// EncryptSensitiveData encrypts CA cert and key in place.
func (p *InspectionPolicy) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
if enc == nil {
return nil
}
var err error
if p.CACertPEM != "" {
p.CACertPEM, err = enc.Encrypt(p.CACertPEM)
if err != nil {
return fmt.Errorf("encrypt ca_cert_pem: %w", err)
}
}
if p.CAKeyPEM != "" {
p.CAKeyPEM, err = enc.Encrypt(p.CAKeyPEM)
if err != nil {
return fmt.Errorf("encrypt ca_key_pem: %w", err)
}
}
return nil
}
// DecryptSensitiveData decrypts CA cert and key in place.
func (p *InspectionPolicy) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
if enc == nil {
return nil
}
var err error
if p.CACertPEM != "" {
p.CACertPEM, err = enc.Decrypt(p.CACertPEM)
if err != nil {
return fmt.Errorf("decrypt ca_cert_pem: %w", err)
}
}
if p.CAKeyPEM != "" {
p.CAKeyPEM, err = enc.Decrypt(p.CAKeyPEM)
if err != nil {
return fmt.Errorf("decrypt ca_key_pem: %w", err)
}
}
return nil
}
// HasDomainOnly returns true if this rule matches by domain and has no CIDR destinations.
func (r *InspectionPolicyRule) HasDomainOnly() bool {
return len(r.Domains) > 0 && len(r.Networks) == 0
}
// HasCIDRDestination returns true if this rule specifies destination CIDRs.
func (r *InspectionPolicyRule) HasCIDRDestination() bool {
return len(r.Networks) > 0
}

View File

@@ -37,9 +37,10 @@ type NetworkMap struct {
OfflinePeers []*nbpeer.Peer
FirewallRules []*FirewallRule
RoutesFirewallRules []*RouteFirewallRule
ForwardingRules []*ForwardingRule
AuthorizedUsers map[string]map[string]struct{}
EnableSSH bool
ForwardingRules []*ForwardingRule
TransparentProxyConfig *TransparentProxyConfig
AuthorizedUsers map[string]map[string]struct{}
EnableSSH bool
}
func (nm *NetworkMap) Merge(other *NetworkMap) {

View File

@@ -45,6 +45,13 @@ type NetworkMapComponents struct {
PostureFailedPeers map[string]map[string]struct{}
RouterPeers map[string]*nbpeer.Peer
// TransparentProxyConfig is the account-level transparent proxy configuration.
// Nil if no proxy is configured at account level.
TransparentProxyConfig *TransparentProxyConfig
// InspectionPolicies are reusable inspection rule sets referenced by policies.
InspectionPolicies map[string]*InspectionPolicy
}
type AccountSettingsInfo struct {
@@ -155,16 +162,21 @@ func (c *NetworkMapComponents) Calculate(ctx context.Context) *NetworkMap {
dnsUpdate.NameServerGroups = c.getPeerNSGroupsFromGroups(targetPeerID, peerGroups)
}
// Build transparent proxy config if this peer is a routing peer with inspection enabled.
// Falls back to the account-level config if set.
tpConfig := c.getTransparentProxyConfig(targetPeerID, isRouter)
return &NetworkMap{
Peers: peersToConnectIncludingRouters,
Network: c.Network.Copy(),
Routes: append(networkResourcesRoutes, routesUpdate...),
DNSConfig: dnsUpdate,
OfflinePeers: expiredPeers,
FirewallRules: firewallRules,
RoutesFirewallRules: append(networkResourcesFirewallRules, routesFirewallRules...),
AuthorizedUsers: authorizedUsers,
EnableSSH: sshEnabled,
Peers: peersToConnectIncludingRouters,
Network: c.Network.Copy(),
Routes: append(networkResourcesRoutes, routesUpdate...),
DNSConfig: dnsUpdate,
OfflinePeers: expiredPeers,
FirewallRules: firewallRules,
RoutesFirewallRules: append(networkResourcesFirewallRules, routesFirewallRules...),
TransparentProxyConfig: tpConfig,
AuthorizedUsers: authorizedUsers,
EnableSSH: sshEnabled,
}
}
@@ -526,7 +538,6 @@ func (c *NetworkMapComponents) getRoutingPeerRoutes(peerID string) (enabledRoute
return enabledRoutes, disabledRoutes
}
func (c *NetworkMapComponents) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route {
var filteredRoutes []*route.Route
for _, r := range routes {
@@ -899,3 +910,198 @@ func (c *NetworkMapComponents) addNetworksRoutingPeers(
return peersToConnect
}
// getTransparentProxyConfig builds a TransparentProxyConfig for a routing peer
// by checking if any ACL policy targeting its networks has inspection policies attached.
func (c *NetworkMapComponents) getTransparentProxyConfig(peerID string, isRouter bool) *TransparentProxyConfig {
if c.TransparentProxyConfig != nil {
return c.TransparentProxyConfig
}
if !isRouter {
return nil
}
var networkIDs []string
for networkID, routers := range c.RoutersMap {
if _, ok := routers[peerID]; ok {
networkIDs = append(networkIDs, networkID)
}
}
if len(networkIDs) == 0 {
return nil
}
return c.buildTransparentProxyFromPolicies(networkIDs, peerID)
}
// buildTransparentProxyFromPolicies builds a TransparentProxyConfig from inspection
// policies attached to ACL policies targeting the given networks.
// Proxy infra config (CA, ICAP, mode) comes from the first inspection policy found.
// Rules are aggregated from all inspection policies.
func (c *NetworkMapComponents) buildTransparentProxyFromPolicies(networkIDs []string, peerID string) *TransparentProxyConfig {
var config *TransparentProxyConfig
// Accumulate redirect sources across all networks.
allSources := make(map[string]struct{})
for _, networkID := range networkIDs {
networkPolicies := c.getPoliciesForNetwork(networkID)
for _, policy := range networkPolicies {
for _, ipID := range policy.InspectionPolicies {
ip, ok := c.InspectionPolicies[ipID]
if !ok || !ip.Enabled {
continue
}
// First inspection policy sets the infra config
if config == nil {
config = &TransparentProxyConfig{
Enabled: true,
ExternalURL: ip.ExternalURL,
DefaultAction: toTransparentProxyAction(ip.DefaultAction),
CACertPEM: []byte(ip.CACertPEM),
CAKeyPEM: []byte(ip.CAKeyPEM),
}
switch ip.Mode {
case "envoy":
config.Mode = TransparentProxyModeEnvoy
case "external":
config.Mode = TransparentProxyModeExternal
default:
config.Mode = TransparentProxyModeBuiltin
}
for _, p := range ip.RedirectPorts {
config.RedirectPorts = append(config.RedirectPorts, uint16(p))
}
if ip.ICAP != nil {
config.ICAP = &TransparentProxyICAPConfig{
ReqModURL: ip.ICAP.ReqModURL,
RespModURL: ip.ICAP.RespModURL,
MaxConnections: ip.ICAP.MaxConnections,
}
}
if ip.Mode == "envoy" {
config.EnvoyBinaryPath = ip.EnvoyBinaryPath
config.EnvoyAdminPort = uint16(ip.EnvoyAdminPort)
if ip.EnvoySnippets != nil {
config.EnvoySnippets = &TransparentProxyEnvoySnippets{
HTTPFilters: ip.EnvoySnippets.HTTPFilters,
NetworkFilters: ip.EnvoySnippets.NetworkFilters,
Clusters: ip.EnvoySnippets.Clusters,
}
}
}
}
// Aggregate rules from all inspection policies
for _, pr := range ip.Rules {
rule := TransparentProxyRule{
ID: ip.ID,
Domains: pr.Domains,
Networks: pr.Networks,
Protocols: pr.Protocols,
Paths: pr.Paths,
Action: toTransparentProxyAction(pr.Action),
Priority: pr.Priority,
}
config.Rules = append(config.Rules, rule)
}
}
}
// Collect sources for this network
for _, src := range c.deriveRedirectSourcesFromPolicies(networkID, peerID) {
allSources[src] = struct{}{}
}
}
if config != nil {
config.RedirectSources = make([]string, 0, len(allSources))
for src := range allSources {
config.RedirectSources = append(config.RedirectSources, src)
}
}
return config
}
// deriveRedirectSourcesFromPolicies collects source peer IPs from policies
// that target the given network and have inspection policies attached.
func (c *NetworkMapComponents) deriveRedirectSourcesFromPolicies(networkID, routingPeerID string) []string {
sourceSet := make(map[string]struct{})
for _, policy := range c.getPoliciesForNetwork(networkID) {
if len(policy.InspectionPolicies) == 0 {
continue
}
peerIDs := c.getUniquePeerIDsFromGroupsIDs(policy.SourceGroups())
for _, peerID := range peerIDs {
if peerID == routingPeerID {
continue
}
peer := c.GetPeerInfo(peerID)
if peer != nil && peer.IP != nil {
sourceSet[peer.IP.String()+"/32"] = struct{}{}
}
}
}
sources := make([]string, 0, len(sourceSet))
for s := range sourceSet {
sources = append(sources, s)
}
return sources
}
// getPoliciesForNetwork returns all unique policies that have inspection policies attached
// and target resources belonging to the given network.
func (c *NetworkMapComponents) getPoliciesForNetwork(networkID string) []*Policy {
seen := make(map[string]bool)
var result []*Policy
add := func(policy *Policy) {
if len(policy.InspectionPolicies) == 0 || seen[policy.ID] {
return
}
seen[policy.ID] = true
result = append(result, policy)
}
// Only include policies that target resources in the given network.
networkResourceIDs := make(map[string]struct{})
for _, resource := range c.NetworkResources {
if resource.NetworkID == networkID {
networkResourceIDs[resource.ID] = struct{}{}
}
}
for resourceID, policies := range c.ResourcePoliciesMap {
if _, ok := networkResourceIDs[resourceID]; !ok {
continue
}
for _, policy := range policies {
add(policy)
}
}
// Also check classic policies whose destination groups contain peers in this network.
for _, policy := range c.Policies {
add(policy)
}
return result
}
func toTransparentProxyAction(s string) TransparentProxyAction {
switch s {
case "allow":
return TransparentProxyActionAllow
case "inspect":
return TransparentProxyActionInspect
default:
return TransparentProxyActionBlock
}
}

View File

@@ -73,6 +73,10 @@ type Policy struct {
// SourcePostureChecks are ID references to Posture checks for policy source groups
SourcePostureChecks []string `gorm:"serializer:json"`
// InspectionPolicies are ID references to inspection policies applied to traffic matching this policy.
// When set, traffic is routed through a transparent proxy on the destination network's routing peers.
InspectionPolicies []string `gorm:"serializer:json"`
}
// Copy returns a copy of the policy.
@@ -85,11 +89,13 @@ func (p *Policy) Copy() *Policy {
Enabled: p.Enabled,
Rules: make([]*PolicyRule, len(p.Rules)),
SourcePostureChecks: make([]string, len(p.SourcePostureChecks)),
InspectionPolicies: make([]string, len(p.InspectionPolicies)),
}
for i, r := range p.Rules {
c.Rules[i] = r.Copy()
}
copy(c.SourcePostureChecks, p.SourcePostureChecks)
copy(c.InspectionPolicies, p.InspectionPolicies)
return c
}

View File

@@ -0,0 +1,91 @@
package types
import (
"net/netip"
"slices"
)
// ProxyRouteSet collects and deduplicates the routes that need to be pushed to
// source peers for transparent proxy rules. CIDR rules create specific routes;
// domain-only rules require a catch-all (0.0.0.0/0).
type ProxyRouteSet struct {
// routes is the deduplicated set of destination prefixes to route through the proxy.
routes map[netip.Prefix]struct{}
// needsCatchAll is true if any rule has domains without CIDRs.
needsCatchAll bool
}
// NewProxyRouteSet creates a new route set.
func NewProxyRouteSet() *ProxyRouteSet {
return &ProxyRouteSet{
routes: make(map[netip.Prefix]struct{}),
}
}
// AddFromRule adds route entries derived from a proxy rule's destinations.
// - CIDR destinations create specific routes
// - Domain-only rules (no CIDRs) trigger a catch-all route
// - Rules with neither domains nor CIDRs also trigger catch-all (match all traffic)
func (s *ProxyRouteSet) AddFromRule(rule *InspectionPolicyRule) {
if rule.HasCIDRDestination() {
for _, cidr := range rule.Networks {
prefix, err := netip.ParsePrefix(cidr)
if err != nil {
continue
}
s.routes[prefix] = struct{}{}
}
return
}
// Domain-only or no destination: need catch-all
s.needsCatchAll = true
}
// Routes returns the deduplicated list of prefixes to route through the proxy.
// If any rule requires catch-all, returns only ["0.0.0.0/0"] since it subsumes
// all specific CIDRs.
func (s *ProxyRouteSet) Routes() []netip.Prefix {
if s.needsCatchAll {
return []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}
}
result := make([]netip.Prefix, 0, len(s.routes))
for prefix := range s.routes {
result = append(result, prefix)
}
// Sort for deterministic output
slices.SortFunc(result, func(a, b netip.Prefix) int {
if c := a.Addr().Compare(b.Addr()); c != 0 {
return c
}
return a.Bits() - b.Bits()
})
// Remove CIDRs that are subsets of larger CIDRs
return deduplicatePrefixes(result)
}
// deduplicatePrefixes removes prefixes that are contained within other prefixes.
// Input must be sorted.
func deduplicatePrefixes(prefixes []netip.Prefix) []netip.Prefix {
if len(prefixes) <= 1 {
return prefixes
}
var result []netip.Prefix
for _, p := range prefixes {
subsumed := false
for _, existing := range result {
if existing.Contains(p.Addr()) && existing.Bits() <= p.Bits() {
subsumed = true
break
}
}
if !subsumed {
result = append(result, p)
}
}
return result
}

View File

@@ -0,0 +1,81 @@
package types
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestProxyRouteSet_CIDROnly(t *testing.T) {
s := NewProxyRouteSet()
s.AddFromRule(&InspectionPolicyRule{
Networks: []string{"10.0.0.0/8", "172.16.0.0/12"},
})
s.AddFromRule(&InspectionPolicyRule{
Networks: []string{"192.168.0.0/16"},
})
routes := s.Routes()
require.Len(t, routes, 3)
assert.Equal(t, netip.MustParsePrefix("10.0.0.0/8"), routes[0])
assert.Equal(t, netip.MustParsePrefix("172.16.0.0/12"), routes[1])
assert.Equal(t, netip.MustParsePrefix("192.168.0.0/16"), routes[2])
}
func TestProxyRouteSet_DomainOnlyForceCatchAll(t *testing.T) {
s := NewProxyRouteSet()
s.AddFromRule(&InspectionPolicyRule{
Domains: []string{"*.gambling.com"},
})
s.AddFromRule(&InspectionPolicyRule{
Networks: []string{"10.0.0.0/8"},
})
routes := s.Routes()
require.Len(t, routes, 1)
assert.Equal(t, netip.MustParsePrefix("0.0.0.0/0"), routes[0])
}
func TestProxyRouteSet_EmptyDestinationForceCatchAll(t *testing.T) {
s := NewProxyRouteSet()
s.AddFromRule(&InspectionPolicyRule{
Action: "block",
// No domains, no networks: match all traffic
})
routes := s.Routes()
require.Len(t, routes, 1)
assert.Equal(t, netip.MustParsePrefix("0.0.0.0/0"), routes[0])
}
func TestProxyRouteSet_DeduplicateSubsets(t *testing.T) {
s := NewProxyRouteSet()
s.AddFromRule(&InspectionPolicyRule{
Networks: []string{"10.0.0.0/8"},
})
s.AddFromRule(&InspectionPolicyRule{
Networks: []string{"10.1.0.0/16"}, // subset of 10.0.0.0/8
})
s.AddFromRule(&InspectionPolicyRule{
Networks: []string{"10.1.2.0/24"}, // subset of both
})
routes := s.Routes()
require.Len(t, routes, 1)
assert.Equal(t, netip.MustParsePrefix("10.0.0.0/8"), routes[0])
}
func TestProxyRouteSet_DuplicateCIDRs(t *testing.T) {
s := NewProxyRouteSet()
s.AddFromRule(&InspectionPolicyRule{
Networks: []string{"10.0.0.0/8"},
})
s.AddFromRule(&InspectionPolicyRule{
Networks: []string{"10.0.0.0/8"}, // duplicate
})
routes := s.Routes()
require.Len(t, routes, 1)
}

View File

@@ -0,0 +1,158 @@
package types
import (
proto "github.com/netbirdio/netbird/shared/management/proto"
)
// TransparentProxyAction determines the proxy behavior for matched connections.
type TransparentProxyAction int
const (
TransparentProxyActionAllow TransparentProxyAction = 0
TransparentProxyActionBlock TransparentProxyAction = 1
TransparentProxyActionInspect TransparentProxyAction = 2
)
// TransparentProxyMode selects built-in or external proxy operation.
type TransparentProxyMode int
const (
TransparentProxyModeBuiltin TransparentProxyMode = 0
TransparentProxyModeExternal TransparentProxyMode = 1
TransparentProxyModeEnvoy TransparentProxyMode = 2
)
// TransparentProxyConfig holds the transparent proxy configuration for a routing peer.
type TransparentProxyConfig struct {
Enabled bool
Mode TransparentProxyMode
ExternalURL string
DefaultAction TransparentProxyAction
// RedirectSources is the set of source CIDRs to intercept.
RedirectSources []string
RedirectPorts []uint16
Rules []TransparentProxyRule
ICAP *TransparentProxyICAPConfig
CACertPEM []byte
CAKeyPEM []byte
ListenPort uint16
// Envoy sidecar fields (ModeEnvoy only)
EnvoyBinaryPath string
EnvoyAdminPort uint16
EnvoySnippets *TransparentProxyEnvoySnippets
}
// TransparentProxyEnvoySnippets holds user-provided YAML fragments for envoy config.
type TransparentProxyEnvoySnippets struct {
HTTPFilters string
NetworkFilters string
Clusters string
}
// TransparentProxyRule is an L7 inspection rule evaluated by the proxy engine.
type TransparentProxyRule struct {
ID string
// Domains are domain patterns, e.g. "*.example.com".
Domains []string
// Networks restricts this rule to specific destination CIDRs.
Networks []string
// Protocols this rule applies to: "http", "https", "h2", "h3", "websocket", "other".
Protocols []string
// Paths are URL path patterns: "/api/", "/login", "/admin/*".
Paths []string
Action TransparentProxyAction
Priority int
}
// TransparentProxyICAPConfig holds ICAP service configuration.
type TransparentProxyICAPConfig struct {
ReqModURL string
RespModURL string
MaxConnections int
}
// ToProto converts the internal config to the proto representation.
func (c *TransparentProxyConfig) ToProto() *proto.TransparentProxyConfig {
if c == nil {
return nil
}
pc := &proto.TransparentProxyConfig{
Enabled: c.Enabled,
Mode: proto.TransparentProxyMode(c.Mode),
DefaultAction: proto.TransparentProxyAction(c.DefaultAction),
CaCertPem: c.CACertPEM,
CaKeyPem: c.CAKeyPEM,
ListenPort: uint32(c.ListenPort),
}
if c.ExternalURL != "" {
pc.ExternalProxyUrl = c.ExternalURL
}
if c.Mode == TransparentProxyModeEnvoy {
pc.EnvoyBinaryPath = c.EnvoyBinaryPath
pc.EnvoyAdminPort = uint32(c.EnvoyAdminPort)
if c.EnvoySnippets != nil {
pc.EnvoySnippets = &proto.TransparentProxyEnvoySnippets{
HttpFilters: c.EnvoySnippets.HTTPFilters,
NetworkFilters: c.EnvoySnippets.NetworkFilters,
Clusters: c.EnvoySnippets.Clusters,
}
}
}
pc.RedirectSources = c.RedirectSources
redirectPorts := make([]uint32, len(c.RedirectPorts))
for i, p := range c.RedirectPorts {
redirectPorts[i] = uint32(p)
}
pc.RedirectPorts = redirectPorts
for _, r := range c.Rules {
pr := &proto.TransparentProxyRule{
Id: r.ID,
Domains: r.Domains,
Networks: r.Networks,
Paths: r.Paths,
Action: proto.TransparentProxyAction(r.Action),
Priority: int32(r.Priority),
}
for _, p := range r.Protocols {
pr.Protocols = append(pr.Protocols, stringToProtoProtocol(p))
}
pc.Rules = append(pc.Rules, pr)
}
if c.ICAP != nil {
pc.Icap = &proto.TransparentProxyICAPConfig{
ReqmodUrl: c.ICAP.ReqModURL,
RespmodUrl: c.ICAP.RespModURL,
MaxConnections: int32(c.ICAP.MaxConnections),
}
}
return pc
}
// stringToProtoProtocol maps a protocol string to its proto enum value.
func stringToProtoProtocol(s string) proto.TransparentProxyProtocol {
switch s {
case "http":
return proto.TransparentProxyProtocol_TP_PROTO_HTTP
case "https":
return proto.TransparentProxyProtocol_TP_PROTO_HTTPS
case "h2":
return proto.TransparentProxyProtocol_TP_PROTO_H2
case "h3":
return proto.TransparentProxyProtocol_TP_PROTO_H3
case "websocket":
return proto.TransparentProxyProtocol_TP_PROTO_WEBSOCKET
case "other":
return proto.TransparentProxyProtocol_TP_PROTO_OTHER
default:
return proto.TransparentProxyProtocol_TP_PROTO_ALL
}
}

View File

@@ -1533,6 +1533,12 @@ components:
items:
type: string
example: "chacdk86lnnboviihd70"
inspection_policies:
description: Inspection policy IDs applied to traffic matching this policy. When set, traffic is routed through a transparent proxy on the destination network's routing peers.
type: array
items:
type: string
example: "chacdk86lnnboviihd71"
rules:
description: Policy rule object for policy UI editor
type: array
@@ -1551,6 +1557,12 @@ components:
items:
type: string
example: "chacdk86lnnboviihd70"
inspection_policies:
description: Inspection policy IDs applied to traffic matching this policy
type: array
items:
type: string
example: "chacdk86lnnboviihd71"
rules:
description: Policy rule object for policy UI editor
type: array
@@ -1573,6 +1585,12 @@ components:
items:
type: string
example: "chacdk86lnnboviihd70"
inspection_policies:
description: Inspection policy IDs applied to traffic matching this policy
type: array
items:
type: string
example: "chacdk86lnnboviihd71"
rules:
description: Policy rule object for policy UI editor
type: array
@@ -2050,6 +2068,9 @@ components:
description: Network router status
type: boolean
example: true
inspection:
description: Optional traffic inspection configuration. When enabled, traffic through this routing peer is transparently proxied and inspected.
$ref: '#/components/schemas/RouterInspectionConfig'
required:
# Only one property has to be set
#- peer
@@ -2057,6 +2078,174 @@ components:
- metric
- masquerade
- enabled
RouterInspectionConfig:
type: object
properties:
enabled:
description: Whether traffic inspection is active on this routing peer
type: boolean
example: false
mode:
description: Inspection mode
type: string
enum: [ "builtin", "envoy", "external" ]
example: builtin
external_url:
description: External proxy URL (http:// or socks5://) when mode is external
type: string
example: "http://proxy.corp:8080"
default_action:
description: Action when no inspection rule matches
type: string
enum: [ "allow", "block", "inspect" ]
example: allow
redirect_ports:
description: Destination ports to intercept. Empty means all ports.
type: array
items:
type: integer
example: [80, 443]
icap:
description: ICAP service configuration for external content scanning
$ref: '#/components/schemas/InspectionICAPConfig'
ca_cert_pem:
description: PEM-encoded CA certificate for MITM TLS inspection
type: string
ca_key_pem:
description: PEM-encoded CA private key for MITM TLS inspection
type: string
envoy_binary_path:
description: Path to envoy binary when mode is envoy. Empty searches $PATH.
type: string
envoy_admin_port:
description: Envoy admin API port for health checks. 0 picks a free port.
type: integer
required:
- enabled
InspectionPolicyMinimum:
type: object
properties:
name:
description: Human-readable name for this inspection policy
type: string
example: "Corporate web filtering"
description:
description: Description
type: string
enabled:
description: Whether this inspection policy is active
type: boolean
example: true
rules:
description: L7 inspection rules
type: array
items:
$ref: '#/components/schemas/InspectionPolicyRule'
mode:
description: Proxy operation mode
type: string
enum: [ "builtin", "envoy", "external" ]
example: "builtin"
external_url:
description: External proxy URL (HTTP CONNECT or SOCKS5) when mode is external
type: string
example: "socks5://proxy.corp.com:1080"
default_action:
description: Action for recognized traffic when no rule matches
type: string
enum: [ "allow", "block", "inspect" ]
example: "allow"
redirect_ports:
description: Destination ports to intercept at L4. Empty means all ports.
type: array
items:
type: integer
example: [80, 443]
ca_cert_pem:
description: PEM-encoded CA certificate for MITM TLS inspection
type: string
ca_key_pem:
description: PEM-encoded CA private key for MITM TLS inspection
type: string
envoy_binary_path:
description: Path to envoy binary when mode is envoy. Empty searches $PATH.
type: string
envoy_admin_port:
description: Envoy admin API port for health checks. 0 picks a free port.
type: integer
icap:
description: ICAP configuration for external content scanning
$ref: '#/components/schemas/InspectionICAPConfig'
required:
- name
- enabled
- rules
InspectionPolicy:
allOf:
- type: object
properties:
id:
description: Inspection Policy ID
type: string
readOnly: true
required:
- id
- $ref: '#/components/schemas/InspectionPolicyMinimum'
InspectionPolicyRule:
type: object
properties:
domains:
description: Domain patterns to match via SNI or Host header. Supports wildcards (*.example.com).
type: array
items:
type: string
example: ["*.gambling.com", "*.betting.com"]
networks:
description: Destination CIDRs for optional L7 destination filtering
type: array
items:
type: string
example: ["10.0.0.0/8"]
protocols:
description: Protocols this rule applies to. Empty means all.
type: array
items:
type: string
enum: [ "http", "https", "h2", "h3", "websocket", "other" ]
example: ["https", "h2"]
paths:
description: URL path patterns. Exact ("/login"), prefix ("/api/*"), contains ("*/admin/*"). HTTPS requires inspect (MITM). Empty means all paths.
type: array
items:
type: string
example: ["/admin/*", "/api/internal/*"]
action:
description: What to do with matched connections
type: string
enum: [ "allow", "block", "inspect" ]
example: block
priority:
description: Evaluation order. Lower values are evaluated first.
type: integer
example: 1
required:
- action
- priority
InspectionICAPConfig:
type: object
properties:
reqmod_url:
description: ICAP REQMOD service URL
type: string
example: "icap://icap-server:1344/reqmod"
respmod_url:
description: ICAP RESPMOD service URL
type: string
example: "icap://icap-server:1344/respmod"
max_connections:
description: Maximum ICAP connection pool size
type: integer
example: 8
NetworkRouter:
allOf:
- type: object
@@ -7410,6 +7599,144 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/inspection-policies:
get:
summary: List all Inspection Policies
description: Returns a list of all reusable inspection policy rule sets
tags: [ Inspection Policies ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
responses:
'200':
description: A JSON Array of Inspection Policies
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/InspectionPolicy'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
post:
summary: Create an Inspection Policy
description: Creates a reusable inspection policy rule set
tags: [ Inspection Policies ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/InspectionPolicyMinimum'
responses:
'200':
description: An Inspection Policy object
content:
application/json:
schema:
$ref: '#/components/schemas/InspectionPolicy'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/inspection-policies/{policyId}:
get:
summary: Get an Inspection Policy
description: Returns an inspection policy rule set
tags: [ Inspection Policies ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: policyId
required: true
schema:
type: string
responses:
'200':
description: An Inspection Policy object
content:
application/json:
schema:
$ref: '#/components/schemas/InspectionPolicy'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
put:
summary: Update an Inspection Policy
description: Updates an inspection policy rule set
tags: [ Inspection Policies ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: policyId
required: true
schema:
type: string
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/InspectionPolicyMinimum'
responses:
'200':
description: An Inspection Policy object
content:
application/json:
schema:
$ref: '#/components/schemas/InspectionPolicy'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
delete:
summary: Delete an Inspection Policy
description: Deletes an inspection policy rule set
tags: [ Inspection Policies ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: policyId
required: true
schema:
type: string
responses:
'200':
description: Successfully deleted
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/dns/nameservers:
get:
summary: List all Nameserver Groups

View File

@@ -584,6 +584,141 @@ func (e IngressPortAllocationRequestPortRangeProtocol) Valid() bool {
}
}
// Defines values for InspectionPolicyDefaultAction.
const (
InspectionPolicyDefaultActionAllow InspectionPolicyDefaultAction = "allow"
InspectionPolicyDefaultActionBlock InspectionPolicyDefaultAction = "block"
InspectionPolicyDefaultActionInspect InspectionPolicyDefaultAction = "inspect"
)
// Valid indicates whether the value is a known member of the InspectionPolicyDefaultAction enum.
func (e InspectionPolicyDefaultAction) Valid() bool {
switch e {
case InspectionPolicyDefaultActionAllow:
return true
case InspectionPolicyDefaultActionBlock:
return true
case InspectionPolicyDefaultActionInspect:
return true
default:
return false
}
}
// Defines values for InspectionPolicyMode.
const (
InspectionPolicyModeBuiltin InspectionPolicyMode = "builtin"
InspectionPolicyModeEnvoy InspectionPolicyMode = "envoy"
InspectionPolicyModeExternal InspectionPolicyMode = "external"
)
// Valid indicates whether the value is a known member of the InspectionPolicyMode enum.
func (e InspectionPolicyMode) Valid() bool {
switch e {
case InspectionPolicyModeBuiltin:
return true
case InspectionPolicyModeEnvoy:
return true
case InspectionPolicyModeExternal:
return true
default:
return false
}
}
// Defines values for InspectionPolicyMinimumDefaultAction.
const (
InspectionPolicyMinimumDefaultActionAllow InspectionPolicyMinimumDefaultAction = "allow"
InspectionPolicyMinimumDefaultActionBlock InspectionPolicyMinimumDefaultAction = "block"
InspectionPolicyMinimumDefaultActionInspect InspectionPolicyMinimumDefaultAction = "inspect"
)
// Valid indicates whether the value is a known member of the InspectionPolicyMinimumDefaultAction enum.
func (e InspectionPolicyMinimumDefaultAction) Valid() bool {
switch e {
case InspectionPolicyMinimumDefaultActionAllow:
return true
case InspectionPolicyMinimumDefaultActionBlock:
return true
case InspectionPolicyMinimumDefaultActionInspect:
return true
default:
return false
}
}
// Defines values for InspectionPolicyMinimumMode.
const (
InspectionPolicyMinimumModeBuiltin InspectionPolicyMinimumMode = "builtin"
InspectionPolicyMinimumModeEnvoy InspectionPolicyMinimumMode = "envoy"
InspectionPolicyMinimumModeExternal InspectionPolicyMinimumMode = "external"
)
// Valid indicates whether the value is a known member of the InspectionPolicyMinimumMode enum.
func (e InspectionPolicyMinimumMode) Valid() bool {
switch e {
case InspectionPolicyMinimumModeBuiltin:
return true
case InspectionPolicyMinimumModeEnvoy:
return true
case InspectionPolicyMinimumModeExternal:
return true
default:
return false
}
}
// Defines values for InspectionPolicyRuleAction.
const (
InspectionPolicyRuleActionAllow InspectionPolicyRuleAction = "allow"
InspectionPolicyRuleActionBlock InspectionPolicyRuleAction = "block"
InspectionPolicyRuleActionInspect InspectionPolicyRuleAction = "inspect"
)
// Valid indicates whether the value is a known member of the InspectionPolicyRuleAction enum.
func (e InspectionPolicyRuleAction) Valid() bool {
switch e {
case InspectionPolicyRuleActionAllow:
return true
case InspectionPolicyRuleActionBlock:
return true
case InspectionPolicyRuleActionInspect:
return true
default:
return false
}
}
// Defines values for InspectionPolicyRuleProtocols.
const (
InspectionPolicyRuleProtocolsH2 InspectionPolicyRuleProtocols = "h2"
InspectionPolicyRuleProtocolsH3 InspectionPolicyRuleProtocols = "h3"
InspectionPolicyRuleProtocolsHttp InspectionPolicyRuleProtocols = "http"
InspectionPolicyRuleProtocolsHttps InspectionPolicyRuleProtocols = "https"
InspectionPolicyRuleProtocolsOther InspectionPolicyRuleProtocols = "other"
InspectionPolicyRuleProtocolsWebsocket InspectionPolicyRuleProtocols = "websocket"
)
// Valid indicates whether the value is a known member of the InspectionPolicyRuleProtocols enum.
func (e InspectionPolicyRuleProtocols) Valid() bool {
switch e {
case InspectionPolicyRuleProtocolsH2:
return true
case InspectionPolicyRuleProtocolsH3:
return true
case InspectionPolicyRuleProtocolsHttp:
return true
case InspectionPolicyRuleProtocolsHttps:
return true
case InspectionPolicyRuleProtocolsOther:
return true
case InspectionPolicyRuleProtocolsWebsocket:
return true
default:
return false
}
}
// Defines values for IntegrationResponsePlatform.
const (
IntegrationResponsePlatformDatadog IntegrationResponsePlatform = "datadog"
@@ -896,6 +1031,48 @@ func (e ReverseProxyDomainType) Valid() bool {
}
}
// Defines values for RouterInspectionConfigDefaultAction.
const (
RouterInspectionConfigDefaultActionAllow RouterInspectionConfigDefaultAction = "allow"
RouterInspectionConfigDefaultActionBlock RouterInspectionConfigDefaultAction = "block"
RouterInspectionConfigDefaultActionInspect RouterInspectionConfigDefaultAction = "inspect"
)
// Valid indicates whether the value is a known member of the RouterInspectionConfigDefaultAction enum.
func (e RouterInspectionConfigDefaultAction) Valid() bool {
switch e {
case RouterInspectionConfigDefaultActionAllow:
return true
case RouterInspectionConfigDefaultActionBlock:
return true
case RouterInspectionConfigDefaultActionInspect:
return true
default:
return false
}
}
// Defines values for RouterInspectionConfigMode.
const (
RouterInspectionConfigModeBuiltin RouterInspectionConfigMode = "builtin"
RouterInspectionConfigModeEnvoy RouterInspectionConfigMode = "envoy"
RouterInspectionConfigModeExternal RouterInspectionConfigMode = "external"
)
// Valid indicates whether the value is a known member of the RouterInspectionConfigMode enum.
func (e RouterInspectionConfigMode) Valid() bool {
switch e {
case RouterInspectionConfigModeBuiltin:
return true
case RouterInspectionConfigModeEnvoy:
return true
case RouterInspectionConfigModeExternal:
return true
default:
return false
}
}
// Defines values for SentinelOneMatchAttributesNetworkStatus.
const (
SentinelOneMatchAttributesNetworkStatusConnected SentinelOneMatchAttributesNetworkStatus = "connected"
@@ -2464,6 +2641,140 @@ type IngressPortAllocationRequestPortRange struct {
// IngressPortAllocationRequestPortRangeProtocol The protocol accepted by the port range
type IngressPortAllocationRequestPortRangeProtocol string
// InspectionICAPConfig defines model for InspectionICAPConfig.
type InspectionICAPConfig struct {
// MaxConnections Maximum ICAP connection pool size
MaxConnections *int `json:"max_connections,omitempty"`
// ReqmodUrl ICAP REQMOD service URL
ReqmodUrl *string `json:"reqmod_url,omitempty"`
// RespmodUrl ICAP RESPMOD service URL
RespmodUrl *string `json:"respmod_url,omitempty"`
}
// InspectionPolicy defines model for InspectionPolicy.
type InspectionPolicy struct {
// CaCertPem PEM-encoded CA certificate for MITM TLS inspection
CaCertPem *string `json:"ca_cert_pem,omitempty"`
// CaKeyPem PEM-encoded CA private key for MITM TLS inspection
CaKeyPem *string `json:"ca_key_pem,omitempty"`
// DefaultAction Action for recognized traffic when no rule matches
DefaultAction *InspectionPolicyDefaultAction `json:"default_action,omitempty"`
// Description Description
Description *string `json:"description,omitempty"`
// Enabled Whether this inspection policy is active
Enabled bool `json:"enabled"`
// EnvoyAdminPort Envoy admin API port for health checks. 0 picks a free port.
EnvoyAdminPort *int `json:"envoy_admin_port,omitempty"`
// EnvoyBinaryPath Path to envoy binary when mode is envoy. Empty searches $PATH.
EnvoyBinaryPath *string `json:"envoy_binary_path,omitempty"`
// ExternalUrl External proxy URL (HTTP CONNECT or SOCKS5) when mode is external
ExternalUrl *string `json:"external_url,omitempty"`
Icap *InspectionICAPConfig `json:"icap,omitempty"`
// Id Inspection Policy ID
Id *string `json:"id,omitempty"`
// Mode Proxy operation mode
Mode *InspectionPolicyMode `json:"mode,omitempty"`
// Name Human-readable name for this inspection policy
Name string `json:"name"`
// RedirectPorts Destination ports to intercept at L4. Empty means all ports.
RedirectPorts *[]int `json:"redirect_ports,omitempty"`
// Rules L7 inspection rules
Rules []InspectionPolicyRule `json:"rules"`
}
// InspectionPolicyDefaultAction Action for recognized traffic when no rule matches
type InspectionPolicyDefaultAction string
// InspectionPolicyMode Proxy operation mode
type InspectionPolicyMode string
// InspectionPolicyMinimum defines model for InspectionPolicyMinimum.
type InspectionPolicyMinimum struct {
// CaCertPem PEM-encoded CA certificate for MITM TLS inspection
CaCertPem *string `json:"ca_cert_pem,omitempty"`
// CaKeyPem PEM-encoded CA private key for MITM TLS inspection
CaKeyPem *string `json:"ca_key_pem,omitempty"`
// DefaultAction Action for recognized traffic when no rule matches
DefaultAction *InspectionPolicyMinimumDefaultAction `json:"default_action,omitempty"`
// Description Description
Description *string `json:"description,omitempty"`
// Enabled Whether this inspection policy is active
Enabled bool `json:"enabled"`
// EnvoyAdminPort Envoy admin API port for health checks. 0 picks a free port.
EnvoyAdminPort *int `json:"envoy_admin_port,omitempty"`
// EnvoyBinaryPath Path to envoy binary when mode is envoy. Empty searches $PATH.
EnvoyBinaryPath *string `json:"envoy_binary_path,omitempty"`
// ExternalUrl External proxy URL (HTTP CONNECT or SOCKS5) when mode is external
ExternalUrl *string `json:"external_url,omitempty"`
Icap *InspectionICAPConfig `json:"icap,omitempty"`
// Mode Proxy operation mode
Mode *InspectionPolicyMinimumMode `json:"mode,omitempty"`
// Name Human-readable name for this inspection policy
Name string `json:"name"`
// RedirectPorts Destination ports to intercept at L4. Empty means all ports.
RedirectPorts *[]int `json:"redirect_ports,omitempty"`
// Rules L7 inspection rules
Rules []InspectionPolicyRule `json:"rules"`
}
// InspectionPolicyMinimumDefaultAction Action for recognized traffic when no rule matches
type InspectionPolicyMinimumDefaultAction string
// InspectionPolicyMinimumMode Proxy operation mode
type InspectionPolicyMinimumMode string
// InspectionPolicyRule defines model for InspectionPolicyRule.
type InspectionPolicyRule struct {
// Action What to do with matched connections
Action InspectionPolicyRuleAction `json:"action"`
// Domains Domain patterns to match via SNI or Host header. Supports wildcards (*.example.com).
Domains *[]string `json:"domains,omitempty"`
// Networks Destination CIDRs for optional L7 destination filtering
Networks *[]string `json:"networks,omitempty"`
// Paths URL path patterns. Exact ("/login"), prefix ("/api/*"), contains ("*/admin/*"). HTTPS requires inspect (MITM). Empty means all paths.
Paths *[]string `json:"paths,omitempty"`
// Priority Evaluation order. Lower values are evaluated first.
Priority int `json:"priority"`
// Protocols Protocols this rule applies to. Empty means all.
Protocols *[]InspectionPolicyRuleProtocols `json:"protocols,omitempty"`
}
// InspectionPolicyRuleAction What to do with matched connections
type InspectionPolicyRuleAction string
// InspectionPolicyRuleProtocols defines model for InspectionPolicyRule.Protocols.
type InspectionPolicyRuleProtocols string
// InstanceStatus Instance status information
type InstanceStatus struct {
// SetupRequired Indicates whether the instance requires initial setup
@@ -2774,7 +3085,8 @@ type NetworkRouter struct {
Enabled bool `json:"enabled"`
// Id Network Router Id
Id string `json:"id"`
Id string `json:"id"`
Inspection *RouterInspectionConfig `json:"inspection,omitempty"`
// Masquerade Indicate if peer should masquerade traffic to this route's prefix
Masquerade bool `json:"masquerade"`
@@ -2792,7 +3104,8 @@ type NetworkRouter struct {
// NetworkRouterRequest defines model for NetworkRouterRequest.
type NetworkRouterRequest struct {
// Enabled Network router status
Enabled bool `json:"enabled"`
Enabled bool `json:"enabled"`
Inspection *RouterInspectionConfig `json:"inspection,omitempty"`
// Masquerade Indicate if peer should masquerade traffic to this route's prefix
Masquerade bool `json:"masquerade"`
@@ -3380,6 +3693,9 @@ type Policy struct {
// Id Policy ID
Id *string `json:"id,omitempty"`
// InspectionPolicies Inspection policy IDs applied to traffic matching this policy
InspectionPolicies *[]string `json:"inspection_policies,omitempty"`
// Name Policy name identifier
Name string `json:"name"`
@@ -3398,6 +3714,9 @@ type PolicyCreate struct {
// Enabled Policy status
Enabled bool `json:"enabled"`
// InspectionPolicies Inspection policy IDs applied to traffic matching this policy
InspectionPolicies *[]string `json:"inspection_policies,omitempty"`
// Name Policy name identifier
Name string `json:"name"`
@@ -3558,6 +3877,9 @@ type PolicyUpdate struct {
// Enabled Policy status
Enabled bool `json:"enabled"`
// InspectionPolicies Inspection policy IDs applied to traffic matching this policy. When set, traffic is routed through a transparent proxy on the destination network's routing peers.
InspectionPolicies *[]string `json:"inspection_policies,omitempty"`
// Name Policy name identifier
Name string `json:"name"`
@@ -3874,6 +4196,43 @@ type RouteRequest struct {
SkipAutoApply *bool `json:"skip_auto_apply,omitempty"`
}
// RouterInspectionConfig defines model for RouterInspectionConfig.
type RouterInspectionConfig struct {
// CaCertPem PEM-encoded CA certificate for MITM TLS inspection
CaCertPem *string `json:"ca_cert_pem,omitempty"`
// CaKeyPem PEM-encoded CA private key for MITM TLS inspection
CaKeyPem *string `json:"ca_key_pem,omitempty"`
// DefaultAction Action when no inspection rule matches
DefaultAction *RouterInspectionConfigDefaultAction `json:"default_action,omitempty"`
// Enabled Whether traffic inspection is active on this routing peer
Enabled bool `json:"enabled"`
// EnvoyAdminPort Envoy admin API port for health checks. 0 picks a free port.
EnvoyAdminPort *int `json:"envoy_admin_port,omitempty"`
// EnvoyBinaryPath Path to envoy binary when mode is envoy. Empty searches $PATH.
EnvoyBinaryPath *string `json:"envoy_binary_path,omitempty"`
// ExternalUrl External proxy URL (http:// or socks5://) when mode is external
ExternalUrl *string `json:"external_url,omitempty"`
Icap *InspectionICAPConfig `json:"icap,omitempty"`
// Mode Inspection mode
Mode *RouterInspectionConfigMode `json:"mode,omitempty"`
// RedirectPorts Destination ports to intercept. Empty means all ports.
RedirectPorts *[]int `json:"redirect_ports,omitempty"`
}
// RouterInspectionConfigDefaultAction Action when no inspection rule matches
type RouterInspectionConfigDefaultAction string
// RouterInspectionConfigMode Inspection mode
type RouterInspectionConfigMode string
// RulePortRange Policy rule affected ports range
type RulePortRange struct {
// End The ending port of the range
@@ -4959,6 +5318,12 @@ type PostApiIngressPeersJSONRequestBody = IngressPeerCreateRequest
// PutApiIngressPeersIngressPeerIdJSONRequestBody defines body for PutApiIngressPeersIngressPeerId for application/json ContentType.
type PutApiIngressPeersIngressPeerIdJSONRequestBody = IngressPeerUpdateRequest
// PostApiInspectionPoliciesJSONRequestBody defines body for PostApiInspectionPolicies for application/json ContentType.
type PostApiInspectionPoliciesJSONRequestBody = InspectionPolicyMinimum
// PutApiInspectionPoliciesPolicyIdJSONRequestBody defines body for PutApiInspectionPoliciesPolicyId for application/json ContentType.
type PutApiInspectionPoliciesPolicyIdJSONRequestBody = InspectionPolicyMinimum
// CreateAzureIntegrationJSONRequestBody defines body for CreateAzureIntegration for application/json ContentType.
type CreateAzureIntegrationJSONRequestBody = CreateAzureIntegrationRequest

File diff suppressed because it is too large Load Diff

View File

@@ -387,6 +387,9 @@ message NetworkMap {
// SSHAuth represents SSH authorization configuration
SSHAuth sshAuth = 13;
// TransparentProxyConfig represents transparent proxy configuration for this peer
TransparentProxyConfig transparentProxyConfig = 14;
}
message SSHAuth {
@@ -684,3 +687,90 @@ message StopExposeRequest {
}
message StopExposeResponse {}
// TransparentProxyConfig configures the transparent forward proxy on a routing peer.
message TransparentProxyConfig {
bool enabled = 1;
TransparentProxyMode mode = 2;
// External proxy URL for MODE_EXTERNAL (http:// or socks5://)
string externalProxyUrl = 3;
TransparentProxyAction defaultAction = 4;
// L3/L4 interception: which traffic gets redirected to the proxy.
// Admin decides: activate for these users/subnets on these ports.
// Used for both kernel TPROXY rules and userspace forwarder source filtering.
repeated string redirectSources = 5;
// Destination ports to intercept. Empty means all ports.
repeated uint32 redirectPorts = 6;
// L7 inspection rules: what the proxy does with intercepted traffic.
repeated TransparentProxyRule rules = 7;
TransparentProxyICAPConfig icap = 8;
// MITM CA certificate in PEM format
bytes caCertPem = 9;
// MITM CA private key in PEM format
bytes caKeyPem = 10;
// TPROXY listen port for kernel mode. 0 means auto-assign.
uint32 listenPort = 11;
// Envoy sidecar configuration (MODE_ENVOY only)
string envoyBinaryPath = 12;
uint32 envoyAdminPort = 13;
TransparentProxyEnvoySnippets envoySnippets = 14;
}
enum TransparentProxyMode {
TP_MODE_BUILTIN = 0;
TP_MODE_EXTERNAL = 1;
TP_MODE_ENVOY = 2;
}
enum TransparentProxyAction {
TP_ACTION_ALLOW = 0;
TP_ACTION_BLOCK = 1;
TP_ACTION_INSPECT = 2;
}
enum TransparentProxyProtocol {
TP_PROTO_ALL = 0;
TP_PROTO_HTTP = 1;
TP_PROTO_HTTPS = 2;
TP_PROTO_H2 = 3;
TP_PROTO_H3 = 4;
TP_PROTO_WEBSOCKET = 5;
TP_PROTO_OTHER = 6;
}
// TransparentProxyRule is an L7 inspection rule evaluated by the proxy engine.
message TransparentProxyRule {
string id = 1;
// Domain patterns to match via SNI or Host header (e.g., *.example.com)
repeated string domains = 2;
// Destination CIDRs for optional L7 destination filtering
repeated string networks = 3;
// Destination ports for optional per-rule port filtering
repeated uint32 ports = 4;
TransparentProxyAction action = 5;
int32 priority = 6;
// Protocols to match. Empty means all protocols.
repeated TransparentProxyProtocol protocols = 7;
// URL path patterns to match (HTTP only, requires inspect for HTTPS).
// Supports prefix ("/api/"), exact ("/login"), and wildcard ("/admin/*").
repeated string paths = 8;
}
message TransparentProxyICAPConfig {
string reqmodUrl = 1;
string respmodUrl = 2;
int32 maxConnections = 3;
}
message TransparentProxyEnvoySnippets {
// YAML injected into the HCM filter chain before the router filter.
string httpFilters = 1;
// YAML for additional upstream clusters referenced by filters.
string clusters = 2;
// YAML injected into the TLS filter chain before tcp_proxy (L4 filters).
string networkFilters = 3;
}