Compare commits

...

11 Commits

Author SHA1 Message Date
Viktor Liu
c4a6dafd27 [client] Use GPO DNS Policy Config to configure DNS if present (#3319) 2025-02-13 18:17:18 +01:00
Zoltan Papp
a930c2aecf Fix priority handling (#3313) 2025-02-13 15:48:10 +01:00
Pedro Maia Costa
d48edb9837 fix integration tests (#3311) 2025-02-12 11:16:51 +00:00
Viktor Liu
b41de7fcd1 [client] Enable userspace forwarder conditionally (#3309)
* Enable userspace forwarder conditionally

* Move disable/enable logic
2025-02-12 11:10:49 +01:00
Viktor Liu
18f84f0df5 [client] Check for fwmark support and use fallback routing if not supported (#3220) 2025-02-11 13:09:17 +01:00
Viktor Liu
44407a158a [client] Fix dns handler chain test (#3307) 2025-02-11 12:42:04 +01:00
Viktor Liu
488b697479 [client] Support dns upstream failover for nameserver groups with same match domain (#3178) 2025-02-10 18:13:34 +01:00
Zoltan Papp
5953b43ead [client, relay] Fix/wg watch (#3261)
Fix WireGuard watcher related issues

- Fix race handling between TURN and Relayed reconnection
- Move the WgWatcher logic to separate struct
- Handle timeouts in a more defensive way
- Fix initial Relay client reconnection to the home server
2025-02-10 10:32:50 +01:00
ransomware
58b2eb4b92 [signal] Fix context propagation in signal server (#3251) 2025-02-07 15:05:41 +01:00
Viktor Liu
05415f72ec [client] Add experimental support for userspace routing (#3134) 2025-02-07 14:11:53 +01:00
Pascal Fischer
b7af53ea40 [management] add logs for grpc API (#3298) 2025-02-07 13:51:17 +01:00
107 changed files with 6551 additions and 1193 deletions

View File

@@ -9,6 +9,7 @@ USER netbird:netbird
ENV NB_FOREGROUND_MODE=true ENV NB_FOREGROUND_MODE=true
ENV NB_USE_NETSTACK_MODE=true ENV NB_USE_NETSTACK_MODE=true
ENV NB_ENABLE_NETSTACK_LOCAL_FORWARDING=true
ENV NB_CONFIG=config.json ENV NB_CONFIG=config.json
ENV NB_DAEMON_ADDR=unix://netbird.sock ENV NB_DAEMON_ADDR=unix://netbird.sock
ENV NB_DISABLE_DNS=true ENV NB_DISABLE_DNS=true

137
client/cmd/trace.go Normal file
View File

@@ -0,0 +1,137 @@
package cmd
import (
"fmt"
"math/rand"
"strings"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var traceCmd = &cobra.Command{
Use: "trace <direction> <source-ip> <dest-ip>",
Short: "Trace a packet through the firewall",
Example: `
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
Args: cobra.ExactArgs(3),
RunE: tracePacket,
}
func init() {
debugCmd.AddCommand(traceCmd)
traceCmd.Flags().StringP("protocol", "p", "tcp", "Protocol (tcp/udp/icmp)")
traceCmd.Flags().Uint16("sport", 0, "Source port")
traceCmd.Flags().Uint16("dport", 0, "Destination port")
traceCmd.Flags().Uint8("icmp-type", 0, "ICMP type")
traceCmd.Flags().Uint8("icmp-code", 0, "ICMP code")
traceCmd.Flags().Bool("syn", false, "TCP SYN flag")
traceCmd.Flags().Bool("ack", false, "TCP ACK flag")
traceCmd.Flags().Bool("fin", false, "TCP FIN flag")
traceCmd.Flags().Bool("rst", false, "TCP RST flag")
traceCmd.Flags().Bool("psh", false, "TCP PSH flag")
traceCmd.Flags().Bool("urg", false, "TCP URG flag")
}
func tracePacket(cmd *cobra.Command, args []string) error {
direction := strings.ToLower(args[0])
if direction != "in" && direction != "out" {
return fmt.Errorf("invalid direction: use 'in' or 'out'")
}
protocol := cmd.Flag("protocol").Value.String()
if protocol != "tcp" && protocol != "udp" && protocol != "icmp" {
return fmt.Errorf("invalid protocol: use tcp/udp/icmp")
}
sport, err := cmd.Flags().GetUint16("sport")
if err != nil {
return fmt.Errorf("invalid source port: %v", err)
}
dport, err := cmd.Flags().GetUint16("dport")
if err != nil {
return fmt.Errorf("invalid destination port: %v", err)
}
// For TCP/UDP, generate random ephemeral port (49152-65535) if not specified
if protocol != "icmp" {
if sport == 0 {
sport = uint16(rand.Intn(16383) + 49152)
}
if dport == 0 {
dport = uint16(rand.Intn(16383) + 49152)
}
}
var tcpFlags *proto.TCPFlags
if protocol == "tcp" {
syn, _ := cmd.Flags().GetBool("syn")
ack, _ := cmd.Flags().GetBool("ack")
fin, _ := cmd.Flags().GetBool("fin")
rst, _ := cmd.Flags().GetBool("rst")
psh, _ := cmd.Flags().GetBool("psh")
urg, _ := cmd.Flags().GetBool("urg")
tcpFlags = &proto.TCPFlags{
Syn: syn,
Ack: ack,
Fin: fin,
Rst: rst,
Psh: psh,
Urg: urg,
}
}
icmpType, _ := cmd.Flags().GetUint32("icmp-type")
icmpCode, _ := cmd.Flags().GetUint32("icmp-code")
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.TracePacket(cmd.Context(), &proto.TracePacketRequest{
SourceIp: args[1],
DestinationIp: args[2],
Protocol: protocol,
SourcePort: uint32(sport),
DestinationPort: uint32(dport),
Direction: direction,
TcpFlags: tcpFlags,
IcmpType: &icmpType,
IcmpCode: &icmpCode,
})
if err != nil {
return fmt.Errorf("trace failed: %v", status.Convert(err).Message())
}
printTrace(cmd, args[1], args[2], protocol, sport, dport, resp)
return nil
}
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
cmd.Printf("Packet trace %s:%d -> %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
for _, stage := range resp.Stages {
if stage.ForwardingDetails != nil {
cmd.Printf("%s: %s [%s]\n", stage.Name, stage.Message, *stage.ForwardingDetails)
} else {
cmd.Printf("%s: %s\n", stage.Name, stage.Message)
}
}
disposition := map[bool]string{
true: "\033[32mALLOWED\033[0m", // Green
false: "\033[31mDENIED\033[0m", // Red
}[resp.FinalDisposition]
cmd.Printf("\nFinal disposition: %s\n", disposition)
}

View File

@@ -14,13 +14,13 @@ import (
) )
// NewFirewall creates a firewall manager instance // NewFirewall creates a firewall manager instance
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) { func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
if !iface.IsUserspaceBind() { if !iface.IsUserspaceBind() {
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
} }
// use userspace packet filtering firewall // use userspace packet filtering firewall
fm, err := uspfilter.Create(iface) fm, err := uspfilter.Create(iface, disableServerRoutes)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -33,12 +33,12 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type // FWType is the type for the firewall type
type FWType int type FWType int
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) { func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
// on the linux system we try to user nftables or iptables // on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic // in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers // so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall // for the userspace packet filtering firewall
fm, err := createNativeFirewall(iface, stateManager) fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes)
if !iface.IsUserspaceBind() { if !iface.IsUserspaceBind() {
return fm, err return fm, err
@@ -47,10 +47,10 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewal
if err != nil { if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err) log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
} }
return createUserspaceFirewall(iface, fm) return createUserspaceFirewall(iface, fm, disableServerRoutes)
} }
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) { func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
fm, err := createFW(iface) fm, err := createFW(iface)
if err != nil { if err != nil {
return nil, fmt.Errorf("create firewall: %s", err) return nil, fmt.Errorf("create firewall: %s", err)
@@ -77,12 +77,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) {
} }
} }
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) { func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool) (firewall.Manager, error) {
var errUsp error var errUsp error
if fm != nil { if fm != nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm) fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes)
} else { } else {
fm, errUsp = uspfilter.Create(iface) fm, errUsp = uspfilter.Create(iface, disableServerRoutes)
} }
if errUsp != nil { if errUsp != nil {

View File

@@ -1,6 +1,8 @@
package firewall package firewall
import ( import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
) )
@@ -10,4 +12,6 @@ type IFaceMapper interface {
Address() device.WGAddress Address() device.WGAddress
IsUserspaceBind() bool IsUserspaceBind() bool
SetFilter(device.PacketFilter) error SetFilter(device.PacketFilter) error
GetDevice() *device.FilteredDevice
GetWGDevice() *wgdevice.Device
} }

View File

@@ -213,6 +213,19 @@ func (m *Manager) AllowNetbird() error {
// Flush doesn't need to be implemented for this manager // Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil } func (m *Manager) Flush() error { return nil }
// SetLogLevel sets the log level for the firewall manager
func (m *Manager) SetLogLevel(log.Level) {
// not supported
}
func (m *Manager) EnableRouting() error {
return nil
}
func (m *Manager) DisableRouting() error {
return nil
}
func getConntrackEstablished() []string { func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
} }

View File

@@ -135,7 +135,16 @@ func (r *router) AddRouteFiltering(
} }
rule := genRouteFilteringRuleSpec(params) rule := genRouteFilteringRuleSpec(params)
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil { // Insert DROP rules at the beginning, append ACCEPT rules at the end
var err error
if action == firewall.ActionDrop {
// after the established rule
err = r.iptablesClient.Insert(tableFilter, chainRTFWD, 2, rule...)
} else {
err = r.iptablesClient.Append(tableFilter, chainRTFWD, rule...)
}
if err != nil {
return nil, fmt.Errorf("add route rule: %v", err) return nil, fmt.Errorf("add route rule: %v", err)
} }

View File

@@ -99,6 +99,12 @@ type Manager interface {
// Flush the changes to firewall controller // Flush the changes to firewall controller
Flush() error Flush() error
SetLogLevel(log.Level)
EnableRouting() error
DisableRouting() error
} }
func GenKey(format string, pair RouterPair) string { func GenKey(format string, pair RouterPair) string {

View File

@@ -318,6 +318,19 @@ func (m *Manager) cleanupNetbirdTables() error {
return nil return nil
} }
// SetLogLevel sets the log level for the firewall manager
func (m *Manager) SetLogLevel(log.Level) {
// not supported
}
func (m *Manager) EnableRouting() error {
return nil
}
func (m *Manager) DisableRouting() error {
return nil
}
// Flush rule/chain/set operations from the buffer // Flush rule/chain/set operations from the buffer
// //
// Method also get all rules after flush and refreshes handle values in the rulesets // Method also get all rules after flush and refreshes handle values in the rulesets

View File

@@ -107,7 +107,7 @@ func TestNftablesManager(t *testing.T) {
Kind: expr.VerdictAccept, Kind: expr.VerdictAccept,
}, },
} }
require.ElementsMatch(t, rules[0].Exprs, expectedExprs1, "expected the same expressions") compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
ipToAdd, _ := netip.AddrFromSlice(ip) ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap() add := ipToAdd.Unmap()
@@ -307,3 +307,18 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
stdout, stderr = runIptablesSave(t) stdout, stderr = runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr) verifyIptablesOutput(t, stdout, stderr)
} }
func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) {
t.Helper()
require.Equal(t, len(got), len(want), "expression count mismatch")
for i := range got {
if _, isCounter := got[i].(*expr.Counter); isCounter {
_, wantIsCounter := want[i].(*expr.Counter)
require.True(t, wantIsCounter, "expected Counter at index %d", i)
continue
}
require.Equal(t, got[i], want[i], "expression mismatch at index %d", i)
}
}

View File

@@ -233,7 +233,13 @@ func (r *router) AddRouteFiltering(
UserData: []byte(ruleKey), UserData: []byte(ruleKey),
} }
rule = r.conn.AddRule(rule) // Insert DROP rules at the beginning, append ACCEPT rules at the end
if action == firewall.ActionDrop {
// TODO: Insert after the established rule
rule = r.conn.InsertRule(rule)
} else {
rule = r.conn.AddRule(rule)
}
log.Tracef("Adding route rule %s", spew.Sdump(rule)) log.Tracef("Adding route rule %s", spew.Sdump(rule))
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {

View File

@@ -3,6 +3,11 @@
package uspfilter package uspfilter
import ( import (
"context"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -17,17 +22,29 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
if m.udpTracker != nil { if m.udpTracker != nil {
m.udpTracker.Close() m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
} }
if m.icmpTracker != nil { if m.icmpTracker != nil {
m.icmpTracker.Close() m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
} }
if m.tcpTracker != nil { if m.tcpTracker != nil {
m.tcpTracker.Close() m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
}
if m.forwarder != nil {
m.forwarder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
} }
if m.nativeFirewall != nil { if m.nativeFirewall != nil {

View File

@@ -1,9 +1,11 @@
package uspfilter package uspfilter
import ( import (
"context"
"fmt" "fmt"
"os/exec" "os/exec"
"syscall" "syscall"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -29,17 +31,29 @@ func (m *Manager) Reset(*statemanager.Manager) error {
if m.udpTracker != nil { if m.udpTracker != nil {
m.udpTracker.Close() m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
} }
if m.icmpTracker != nil { if m.icmpTracker != nil {
m.icmpTracker.Close() m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
} }
if m.tcpTracker != nil { if m.tcpTracker != nil {
m.tcpTracker.Close() m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
}
if m.forwarder != nil {
m.forwarder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
} }
if !isWindowsFirewallReachable() { if !isWindowsFirewallReachable() {

View File

@@ -0,0 +1,16 @@
package common
import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
SetFilter(device.PacketFilter) error
Address() iface.WGAddress
GetWGDevice() *wgdevice.Device
GetDevice() *device.FilteredDevice
}

View File

@@ -10,12 +10,11 @@ import (
// BaseConnTrack provides common fields and locking for all connection types // BaseConnTrack provides common fields and locking for all connection types
type BaseConnTrack struct { type BaseConnTrack struct {
SourceIP net.IP SourceIP net.IP
DestIP net.IP DestIP net.IP
SourcePort uint16 SourcePort uint16
DestPort uint16 DestPort uint16
lastSeen atomic.Int64 // Unix nano for atomic access lastSeen atomic.Int64 // Unix nano for atomic access
established atomic.Bool
} }
// these small methods will be inlined by the compiler // these small methods will be inlined by the compiler
@@ -25,16 +24,6 @@ func (b *BaseConnTrack) UpdateLastSeen() {
b.lastSeen.Store(time.Now().UnixNano()) b.lastSeen.Store(time.Now().UnixNano())
} }
// IsEstablished safely checks if connection is established
func (b *BaseConnTrack) IsEstablished() bool {
return b.established.Load()
}
// SetEstablished safely sets the established state
func (b *BaseConnTrack) SetEstablished(state bool) {
b.established.Store(state)
}
// GetLastSeen safely gets the last seen timestamp // GetLastSeen safely gets the last seen timestamp
func (b *BaseConnTrack) GetLastSeen() time.Time { func (b *BaseConnTrack) GetLastSeen() time.Time {
return time.Unix(0, b.lastSeen.Load()) return time.Unix(0, b.lastSeen.Load())

View File

@@ -3,8 +3,14 @@ package conntrack
import ( import (
"net" "net"
"testing" "testing"
"github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
) )
var logger = log.NewFromLogrus(logrus.StandardLogger())
func BenchmarkIPOperations(b *testing.B) { func BenchmarkIPOperations(b *testing.B) {
b.Run("MakeIPAddr", func(b *testing.B) { b.Run("MakeIPAddr", func(b *testing.B) {
ip := net.ParseIP("192.168.1.1") ip := net.ParseIP("192.168.1.1")
@@ -34,37 +40,11 @@ func BenchmarkIPOperations(b *testing.B) {
}) })
} }
func BenchmarkAtomicOperations(b *testing.B) {
conn := &BaseConnTrack{}
b.Run("UpdateLastSeen", func(b *testing.B) {
for i := 0; i < b.N; i++ {
conn.UpdateLastSeen()
}
})
b.Run("IsEstablished", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = conn.IsEstablished()
}
})
b.Run("SetEstablished", func(b *testing.B) {
for i := 0; i < b.N; i++ {
conn.SetEstablished(i%2 == 0)
}
})
b.Run("GetLastSeen", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = conn.GetLastSeen()
}
})
}
// Memory pressure tests // Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) { func BenchmarkMemoryPressure(b *testing.B) {
b.Run("TCPHighLoad", func(b *testing.B) { b.Run("TCPHighLoad", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout) tracker := NewTCPTracker(DefaultTCPTimeout, logger)
defer tracker.Close() defer tracker.Close()
// Generate different IPs // Generate different IPs
@@ -89,7 +69,7 @@ func BenchmarkMemoryPressure(b *testing.B) {
}) })
b.Run("UDPHighLoad", func(b *testing.B) { b.Run("UDPHighLoad", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout) tracker := NewUDPTracker(DefaultUDPTimeout, logger)
defer tracker.Close() defer tracker.Close()
// Generate different IPs // Generate different IPs

View File

@@ -6,6 +6,8 @@ import (
"time" "time"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
) )
const ( const (
@@ -33,6 +35,7 @@ type ICMPConnTrack struct {
// ICMPTracker manages ICMP connection states // ICMPTracker manages ICMP connection states
type ICMPTracker struct { type ICMPTracker struct {
logger *nblog.Logger
connections map[ICMPConnKey]*ICMPConnTrack connections map[ICMPConnKey]*ICMPConnTrack
timeout time.Duration timeout time.Duration
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
@@ -42,12 +45,13 @@ type ICMPTracker struct {
} }
// NewICMPTracker creates a new ICMP connection tracker // NewICMPTracker creates a new ICMP connection tracker
func NewICMPTracker(timeout time.Duration) *ICMPTracker { func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
if timeout == 0 { if timeout == 0 {
timeout = DefaultICMPTimeout timeout = DefaultICMPTimeout
} }
tracker := &ICMPTracker{ tracker := &ICMPTracker{
logger: logger,
connections: make(map[ICMPConnKey]*ICMPConnTrack), connections: make(map[ICMPConnKey]*ICMPConnTrack),
timeout: timeout, timeout: timeout,
cleanupTicker: time.NewTicker(ICMPCleanupInterval), cleanupTicker: time.NewTicker(ICMPCleanupInterval),
@@ -62,7 +66,6 @@ func NewICMPTracker(timeout time.Duration) *ICMPTracker {
// TrackOutbound records an outbound ICMP Echo Request // TrackOutbound records an outbound ICMP Echo Request
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) { func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
key := makeICMPKey(srcIP, dstIP, id, seq) key := makeICMPKey(srcIP, dstIP, id, seq)
now := time.Now().UnixNano()
t.mutex.Lock() t.mutex.Lock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
@@ -80,24 +83,19 @@ func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq u
ID: id, ID: id,
Sequence: seq, Sequence: seq,
} }
conn.lastSeen.Store(now) conn.UpdateLastSeen()
conn.established.Store(true)
t.connections[key] = conn t.connections[key] = conn
t.logger.Trace("New ICMP connection %v", key)
} }
t.mutex.Unlock() t.mutex.Unlock()
conn.lastSeen.Store(now) conn.UpdateLastSeen()
} }
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request // IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool { func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
switch icmpType { if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
case uint8(layers.ICMPv4TypeDestinationUnreachable),
uint8(layers.ICMPv4TypeTimeExceeded):
return true
case uint8(layers.ICMPv4TypeEchoReply):
// continue processing
default:
return false return false
} }
@@ -115,8 +113,7 @@ func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq
return false return false
} }
return conn.IsEstablished() && return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) && ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
conn.ID == id && conn.ID == id &&
conn.Sequence == seq conn.Sequence == seq
@@ -141,6 +138,8 @@ func (t *ICMPTracker) cleanup() {
t.ipPool.Put(conn.SourceIP) t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP) t.ipPool.Put(conn.DestIP)
delete(t.connections, key) delete(t.connections, key)
t.logger.Debug("Removed ICMP connection %v (timeout)", key)
} }
} }
} }

View File

@@ -7,7 +7,7 @@ import (
func BenchmarkICMPTracker(b *testing.B) { func BenchmarkICMPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout) tracker := NewICMPTracker(DefaultICMPTimeout, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")
@@ -20,7 +20,7 @@ func BenchmarkICMPTracker(b *testing.B) {
}) })
b.Run("IsValidInbound", func(b *testing.B) { b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout) tracker := NewICMPTracker(DefaultICMPTimeout, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")

View File

@@ -5,7 +5,10 @@ package conntrack
import ( import (
"net" "net"
"sync" "sync"
"sync/atomic"
"time" "time"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
) )
const ( const (
@@ -61,12 +64,24 @@ type TCPConnKey struct {
// TCPConnTrack represents a TCP connection state // TCPConnTrack represents a TCP connection state
type TCPConnTrack struct { type TCPConnTrack struct {
BaseConnTrack BaseConnTrack
State TCPState State TCPState
established atomic.Bool
sync.RWMutex sync.RWMutex
} }
// IsEstablished safely checks if connection is established
func (t *TCPConnTrack) IsEstablished() bool {
return t.established.Load()
}
// SetEstablished safely sets the established state
func (t *TCPConnTrack) SetEstablished(state bool) {
t.established.Store(state)
}
// TCPTracker manages TCP connection states // TCPTracker manages TCP connection states
type TCPTracker struct { type TCPTracker struct {
logger *nblog.Logger
connections map[ConnKey]*TCPConnTrack connections map[ConnKey]*TCPConnTrack
mutex sync.RWMutex mutex sync.RWMutex
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
@@ -76,8 +91,9 @@ type TCPTracker struct {
} }
// NewTCPTracker creates a new TCP connection tracker // NewTCPTracker creates a new TCP connection tracker
func NewTCPTracker(timeout time.Duration) *TCPTracker { func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
tracker := &TCPTracker{ tracker := &TCPTracker{
logger: logger,
connections: make(map[ConnKey]*TCPConnTrack), connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval), cleanupTicker: time.NewTicker(TCPCleanupInterval),
done: make(chan struct{}), done: make(chan struct{}),
@@ -93,7 +109,6 @@ func NewTCPTracker(timeout time.Duration) *TCPTracker {
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) { func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
// Create key before lock // Create key before lock
key := makeConnKey(srcIP, dstIP, srcPort, dstPort) key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
now := time.Now().UnixNano()
t.mutex.Lock() t.mutex.Lock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
@@ -113,9 +128,11 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
}, },
State: TCPStateNew, State: TCPStateNew,
} }
conn.lastSeen.Store(now) conn.UpdateLastSeen()
conn.established.Store(false) conn.established.Store(false)
t.connections[key] = conn t.connections[key] = conn
t.logger.Trace("New TCP connection: %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
} }
t.mutex.Unlock() t.mutex.Unlock()
@@ -123,7 +140,7 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
conn.Lock() conn.Lock()
t.updateState(conn, flags, true) t.updateState(conn, flags, true)
conn.Unlock() conn.Unlock()
conn.lastSeen.Store(now) conn.UpdateLastSeen()
} }
// IsValidInbound checks if an inbound TCP packet matches a tracked connection // IsValidInbound checks if an inbound TCP packet matches a tracked connection
@@ -171,6 +188,9 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
if flags&TCPRst != 0 { if flags&TCPRst != 0 {
conn.State = TCPStateClosed conn.State = TCPStateClosed
conn.SetEstablished(false) conn.SetEstablished(false)
t.logger.Trace("TCP connection reset: %s:%d -> %s:%d",
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
return return
} }
@@ -227,6 +247,9 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
if flags&TCPAck != 0 { if flags&TCPAck != 0 {
conn.State = TCPStateTimeWait conn.State = TCPStateTimeWait
// Keep established = false from previous state // Keep established = false from previous state
t.logger.Trace("TCP connection closed (simultaneous) - %s:%d -> %s:%d",
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
} }
case TCPStateCloseWait: case TCPStateCloseWait:
@@ -237,11 +260,17 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
case TCPStateLastAck: case TCPStateLastAck:
if flags&TCPAck != 0 { if flags&TCPAck != 0 {
conn.State = TCPStateClosed conn.State = TCPStateClosed
t.logger.Trace("TCP connection gracefully closed: %s:%d -> %s:%d",
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
} }
case TCPStateTimeWait: case TCPStateTimeWait:
// Stay in TIME-WAIT for 2MSL before transitioning to closed // Stay in TIME-WAIT for 2MSL before transitioning to closed
// This is handled by the cleanup routine // This is handled by the cleanup routine
t.logger.Trace("TCP connection completed - %s:%d -> %s:%d",
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
} }
} }
@@ -318,6 +347,8 @@ func (t *TCPTracker) cleanup() {
t.ipPool.Put(conn.SourceIP) t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP) t.ipPool.Put(conn.DestIP)
delete(t.connections, key) delete(t.connections, key)
t.logger.Trace("Cleaned up TCP connection: %s:%d -> %s:%d", conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
} }
} }
} }

View File

@@ -9,7 +9,7 @@ import (
) )
func TestTCPStateMachine(t *testing.T) { func TestTCPStateMachine(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout) tracker := NewTCPTracker(DefaultTCPTimeout, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1") srcIP := net.ParseIP("100.64.0.1")
@@ -154,7 +154,7 @@ func TestTCPStateMachine(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Helper() t.Helper()
tracker = NewTCPTracker(DefaultTCPTimeout) tracker = NewTCPTracker(DefaultTCPTimeout, logger)
tt.test(t) tt.test(t)
}) })
} }
@@ -162,7 +162,7 @@ func TestTCPStateMachine(t *testing.T) {
} }
func TestRSTHandling(t *testing.T) { func TestRSTHandling(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout) tracker := NewTCPTracker(DefaultTCPTimeout, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1") srcIP := net.ParseIP("100.64.0.1")
@@ -233,7 +233,7 @@ func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP,
func BenchmarkTCPTracker(b *testing.B) { func BenchmarkTCPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout) tracker := NewTCPTracker(DefaultTCPTimeout, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")
@@ -246,7 +246,7 @@ func BenchmarkTCPTracker(b *testing.B) {
}) })
b.Run("IsValidInbound", func(b *testing.B) { b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout) tracker := NewTCPTracker(DefaultTCPTimeout, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")
@@ -264,7 +264,7 @@ func BenchmarkTCPTracker(b *testing.B) {
}) })
b.Run("ConcurrentAccess", func(b *testing.B) { b.Run("ConcurrentAccess", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout) tracker := NewTCPTracker(DefaultTCPTimeout, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")
@@ -287,7 +287,7 @@ func BenchmarkTCPTracker(b *testing.B) {
// Benchmark connection cleanup // Benchmark connection cleanup
func BenchmarkCleanup(b *testing.B) { func BenchmarkCleanup(b *testing.B) {
b.Run("TCPCleanup", func(b *testing.B) { b.Run("TCPCleanup", func(b *testing.B) {
tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing tracker := NewTCPTracker(100*time.Millisecond, logger) // Short timeout for testing
defer tracker.Close() defer tracker.Close()
// Pre-populate with expired connections // Pre-populate with expired connections

View File

@@ -4,6 +4,8 @@ import (
"net" "net"
"sync" "sync"
"time" "time"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
) )
const ( const (
@@ -20,6 +22,7 @@ type UDPConnTrack struct {
// UDPTracker manages UDP connection states // UDPTracker manages UDP connection states
type UDPTracker struct { type UDPTracker struct {
logger *nblog.Logger
connections map[ConnKey]*UDPConnTrack connections map[ConnKey]*UDPConnTrack
timeout time.Duration timeout time.Duration
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
@@ -29,12 +32,13 @@ type UDPTracker struct {
} }
// NewUDPTracker creates a new UDP connection tracker // NewUDPTracker creates a new UDP connection tracker
func NewUDPTracker(timeout time.Duration) *UDPTracker { func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
if timeout == 0 { if timeout == 0 {
timeout = DefaultUDPTimeout timeout = DefaultUDPTimeout
} }
tracker := &UDPTracker{ tracker := &UDPTracker{
logger: logger,
connections: make(map[ConnKey]*UDPConnTrack), connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout, timeout: timeout,
cleanupTicker: time.NewTicker(UDPCleanupInterval), cleanupTicker: time.NewTicker(UDPCleanupInterval),
@@ -49,7 +53,6 @@ func NewUDPTracker(timeout time.Duration) *UDPTracker {
// TrackOutbound records an outbound UDP connection // TrackOutbound records an outbound UDP connection
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) { func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
key := makeConnKey(srcIP, dstIP, srcPort, dstPort) key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
now := time.Now().UnixNano()
t.mutex.Lock() t.mutex.Lock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
@@ -67,13 +70,14 @@ func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
DestPort: dstPort, DestPort: dstPort,
}, },
} }
conn.lastSeen.Store(now) conn.UpdateLastSeen()
conn.established.Store(true)
t.connections[key] = conn t.connections[key] = conn
t.logger.Trace("New UDP connection: %v", conn)
} }
t.mutex.Unlock() t.mutex.Unlock()
conn.lastSeen.Store(now) conn.UpdateLastSeen()
} }
// IsValidInbound checks if an inbound packet matches a tracked connection // IsValidInbound checks if an inbound packet matches a tracked connection
@@ -92,8 +96,7 @@ func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
return false return false
} }
return conn.IsEstablished() && return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) && ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
conn.DestPort == srcPort && conn.DestPort == srcPort &&
conn.SourcePort == dstPort conn.SourcePort == dstPort
@@ -120,6 +123,8 @@ func (t *UDPTracker) cleanup() {
t.ipPool.Put(conn.SourceIP) t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP) t.ipPool.Put(conn.DestIP)
delete(t.connections, key) delete(t.connections, key)
t.logger.Trace("Removed UDP connection %v (timeout)", conn)
} }
} }
} }

View File

@@ -29,7 +29,7 @@ func TestNewUDPTracker(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tracker := NewUDPTracker(tt.timeout) tracker := NewUDPTracker(tt.timeout, logger)
assert.NotNil(t, tracker) assert.NotNil(t, tracker)
assert.Equal(t, tt.wantTimeout, tracker.timeout) assert.Equal(t, tt.wantTimeout, tracker.timeout)
assert.NotNil(t, tracker.connections) assert.NotNil(t, tracker.connections)
@@ -40,7 +40,7 @@ func TestNewUDPTracker(t *testing.T) {
} }
func TestUDPTracker_TrackOutbound(t *testing.T) { func TestUDPTracker_TrackOutbound(t *testing.T) {
tracker := NewUDPTracker(DefaultUDPTimeout) tracker := NewUDPTracker(DefaultUDPTimeout, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2") srcIP := net.ParseIP("192.168.1.2")
@@ -58,12 +58,11 @@ func TestUDPTracker_TrackOutbound(t *testing.T) {
assert.True(t, conn.DestIP.Equal(dstIP)) assert.True(t, conn.DestIP.Equal(dstIP))
assert.Equal(t, srcPort, conn.SourcePort) assert.Equal(t, srcPort, conn.SourcePort)
assert.Equal(t, dstPort, conn.DestPort) assert.Equal(t, dstPort, conn.DestPort)
assert.True(t, conn.IsEstablished())
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second) assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
} }
func TestUDPTracker_IsValidInbound(t *testing.T) { func TestUDPTracker_IsValidInbound(t *testing.T) {
tracker := NewUDPTracker(1 * time.Second) tracker := NewUDPTracker(1*time.Second, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2") srcIP := net.ParseIP("192.168.1.2")
@@ -162,6 +161,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
cleanupTicker: time.NewTicker(cleanupInterval), cleanupTicker: time.NewTicker(cleanupInterval),
done: make(chan struct{}), done: make(chan struct{}),
ipPool: NewPreallocatedIPs(), ipPool: NewPreallocatedIPs(),
logger: logger,
} }
// Start cleanup routine // Start cleanup routine
@@ -211,7 +211,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
func BenchmarkUDPTracker(b *testing.B) { func BenchmarkUDPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout) tracker := NewUDPTracker(DefaultUDPTimeout, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")
@@ -224,7 +224,7 @@ func BenchmarkUDPTracker(b *testing.B) {
}) })
b.Run("IsValidInbound", func(b *testing.B) { b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout) tracker := NewUDPTracker(DefaultUDPTimeout, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")

View File

@@ -0,0 +1,81 @@
package forwarder
import (
wgdevice "golang.zx2c4.com/wireguard/device"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)
// endpoint implements stack.LinkEndpoint and handles integration with the wireguard device
type endpoint struct {
logger *nblog.Logger
dispatcher stack.NetworkDispatcher
device *wgdevice.Device
mtu uint32
}
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.dispatcher = dispatcher
}
func (e *endpoint) IsAttached() bool {
return e.dispatcher != nil
}
func (e *endpoint) MTU() uint32 {
return e.mtu
}
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
return stack.CapabilityNone
}
func (e *endpoint) MaxHeaderLength() uint16 {
return 0
}
func (e *endpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
var written int
for _, pkt := range pkts.AsSlice() {
netHeader := header.IPv4(pkt.NetworkHeader().View().AsSlice())
data := stack.PayloadSince(pkt.NetworkHeader())
if data == nil {
continue
}
// Send the packet through WireGuard
address := netHeader.DestinationAddress()
err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice())
if err != nil {
e.logger.Error("CreateOutboundPacket: %v", err)
continue
}
written++
}
return written, nil
}
func (e *endpoint) Wait() {
// not required
}
func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareNone
}
func (e *endpoint) AddHeader(*stack.PacketBuffer) {
// not required
}
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
return true
}

View File

@@ -0,0 +1,166 @@
package forwarder
import (
"context"
"fmt"
"net"
"runtime"
log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)
const (
defaultReceiveWindow = 32768
defaultMaxInFlight = 1024
iosReceiveWindow = 16384
iosMaxInFlight = 256
)
type Forwarder struct {
logger *nblog.Logger
stack *stack.Stack
endpoint *endpoint
udpForwarder *udpForwarder
ctx context.Context
cancel context.CancelFunc
ip net.IP
netstack bool
}
func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwarder, error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{
tcp.NewProtocol,
udp.NewProtocol,
icmp.NewProtocol4,
},
HandleLocal: false,
})
mtu, err := iface.GetDevice().MTU()
if err != nil {
return nil, fmt.Errorf("get MTU: %w", err)
}
nicID := tcpip.NICID(1)
endpoint := &endpoint{
logger: logger,
device: iface.GetWGDevice(),
mtu: uint32(mtu),
}
if err := s.CreateNIC(nicID, endpoint); err != nil {
return nil, fmt.Errorf("failed to create NIC: %v", err)
}
ones, _ := iface.Address().Network.Mask.Size()
protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
PrefixLen: ones,
},
}
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
return nil, fmt.Errorf("failed to add protocol address: %s", err)
}
defaultSubnet, err := tcpip.NewSubnet(
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
)
if err != nil {
return nil, fmt.Errorf("creating default subnet: %w", err)
}
if err := s.SetPromiscuousMode(nicID, true); err != nil {
return nil, fmt.Errorf("set promiscuous mode: %s", err)
}
if err := s.SetSpoofing(nicID, true); err != nil {
return nil, fmt.Errorf("set spoofing: %s", err)
}
s.SetRouteTable([]tcpip.Route{
{
Destination: defaultSubnet,
NIC: nicID,
},
})
ctx, cancel := context.WithCancel(context.Background())
f := &Forwarder{
logger: logger,
stack: s,
endpoint: endpoint,
udpForwarder: newUDPForwarder(mtu, logger),
ctx: ctx,
cancel: cancel,
netstack: netstack,
ip: iface.Address().IP,
}
receiveWindow := defaultReceiveWindow
maxInFlight := defaultMaxInFlight
if runtime.GOOS == "ios" {
receiveWindow = iosReceiveWindow
maxInFlight = iosMaxInFlight
}
tcpForwarder := tcp.NewForwarder(s, receiveWindow, maxInFlight, f.handleTCP)
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
udpForwarder := udp.NewForwarder(s, f.handleUDP)
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)
log.Debugf("forwarder: Initialization complete with NIC %d", nicID)
return f, nil
}
func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
if len(payload) < header.IPv4MinimumSize {
return fmt.Errorf("packet too small: %d bytes", len(payload))
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(payload),
})
defer pkt.DecRef()
if f.endpoint.dispatcher != nil {
f.endpoint.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
}
return nil
}
// Stop gracefully shuts down the forwarder
func (f *Forwarder) Stop() {
f.cancel()
if f.udpForwarder != nil {
f.udpForwarder.Stop()
}
f.stack.Close()
f.stack.Wait()
}
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
if f.netstack && f.ip.Equal(addr.AsSlice()) {
return net.IPv4(127, 0, 0, 1)
}
return addr.AsSlice()
}

View File

@@ -0,0 +1,109 @@
package forwarder
import (
"context"
"net"
"time"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
// handleICMP handles ICMP packets from the network stack
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
defer cancel()
lc := net.ListenConfig{}
// TODO: support non-root
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
if err != nil {
f.logger.Error("Failed to create ICMP socket for %v: %v", id, err)
// This will make netstack reply on behalf of the original destination, that's ok for now
return false
}
defer func() {
if err := conn.Close(); err != nil {
f.logger.Debug("Failed to close ICMP socket: %v", err)
}
}()
dstIP := f.determineDialAddr(id.LocalAddress)
dst := &net.IPAddr{IP: dstIP}
// Get the complete ICMP message (header + data)
fullPacket := stack.PayloadSince(pkt.TransportHeader())
payload := fullPacket.AsSlice()
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
// For Echo Requests, send and handle response
switch icmpHdr.Type() {
case header.ICMPv4Echo:
return f.handleEchoResponse(icmpHdr, payload, dst, conn, id)
case header.ICMPv4EchoReply:
// dont process our own replies
return true
default:
}
// For other ICMP types (Time Exceeded, Destination Unreachable, etc)
_, err = conn.WriteTo(payload, dst)
if err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
return true
}
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
id, icmpHdr.Type(), icmpHdr.Code())
return true
}
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID) bool {
if _, err := conn.WriteTo(payload, dst); err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
return true
}
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
id, icmpHdr.Type(), icmpHdr.Code())
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
return true
}
response := make([]byte, f.endpoint.mtu)
n, _, err := conn.ReadFrom(response)
if err != nil {
if !isTimeout(err) {
f.logger.Error("Failed to read ICMP response: %v", err)
}
return true
}
ipHdr := make([]byte, header.IPv4MinimumSize)
ip := header.IPv4(ipHdr)
ip.Encode(&header.IPv4Fields{
TotalLength: uint16(header.IPv4MinimumSize + n),
TTL: 64,
Protocol: uint8(header.ICMPv4ProtocolNumber),
SrcAddr: id.LocalAddress,
DstAddr: id.RemoteAddress,
})
ip.SetChecksum(^ip.CalculateChecksum())
fullPacket := make([]byte, 0, len(ipHdr)+n)
fullPacket = append(fullPacket, ipHdr...)
fullPacket = append(fullPacket, response[:n]...)
if err := f.InjectIncomingPacket(fullPacket); err != nil {
f.logger.Error("Failed to inject ICMP response: %v", err)
return true
}
f.logger.Trace("Forwarded ICMP echo reply for %v", id)
return true
}

View File

@@ -0,0 +1,90 @@
package forwarder
import (
"context"
"fmt"
"io"
"net"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
)
// handleTCP is called by the TCP forwarder for new connections.
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
id := r.ID()
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
if err != nil {
r.Complete(true)
f.logger.Trace("forwarder: dial error for %v: %v", id, err)
return
}
// Create wait queue for blocking syscalls
wq := waiter.Queue{}
ep, epErr := r.CreateEndpoint(&wq)
if epErr != nil {
f.logger.Error("forwarder: failed to create TCP endpoint: %v", epErr)
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: outConn close error: %v", err)
}
r.Complete(true)
return
}
// Complete the handshake
r.Complete(false)
inConn := gonet.NewTCPConn(&wq, ep)
f.logger.Trace("forwarder: established TCP connection %v", id)
go f.proxyTCP(id, inConn, outConn, ep)
}
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint) {
defer func() {
if err := inConn.Close(); err != nil {
f.logger.Debug("forwarder: inConn close error: %v", err)
}
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: outConn close error: %v", err)
}
ep.Close()
}()
// Create context for managing the proxy goroutines
ctx, cancel := context.WithCancel(f.ctx)
defer cancel()
errChan := make(chan error, 2)
go func() {
_, err := io.Copy(outConn, inConn)
errChan <- err
}()
go func() {
_, err := io.Copy(inConn, outConn)
errChan <- err
}()
select {
case <-ctx.Done():
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", id)
return
case err := <-errChan:
if err != nil && !isClosedError(err) {
f.logger.Error("proxyTCP: copy error: %v", err)
}
f.logger.Trace("forwarder: tearing down TCP connection %v", id)
return
}
}

View File

@@ -0,0 +1,288 @@
package forwarder
import (
"context"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)
const (
udpTimeout = 30 * time.Second
)
type udpPacketConn struct {
conn *gonet.UDPConn
outConn net.Conn
lastSeen atomic.Int64
cancel context.CancelFunc
ep tcpip.Endpoint
}
type udpForwarder struct {
sync.RWMutex
logger *nblog.Logger
conns map[stack.TransportEndpointID]*udpPacketConn
bufPool sync.Pool
ctx context.Context
cancel context.CancelFunc
}
type idleConn struct {
id stack.TransportEndpointID
conn *udpPacketConn
}
func newUDPForwarder(mtu int, logger *nblog.Logger) *udpForwarder {
ctx, cancel := context.WithCancel(context.Background())
f := &udpForwarder{
logger: logger,
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
ctx: ctx,
cancel: cancel,
bufPool: sync.Pool{
New: func() any {
b := make([]byte, mtu)
return &b
},
},
}
go f.cleanup()
return f
}
// Stop stops the UDP forwarder and all active connections
func (f *udpForwarder) Stop() {
f.cancel()
f.Lock()
defer f.Unlock()
for id, conn := range f.conns {
conn.cancel()
if err := conn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP conn close error for %v: %v", id, err)
}
if err := conn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
}
conn.ep.Close()
delete(f.conns, id)
}
}
// cleanup periodically removes idle UDP connections
func (f *udpForwarder) cleanup() {
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for {
select {
case <-f.ctx.Done():
return
case <-ticker.C:
var idleConns []idleConn
f.RLock()
for id, conn := range f.conns {
if conn.getIdleDuration() > udpTimeout {
idleConns = append(idleConns, idleConn{id, conn})
}
}
f.RUnlock()
for _, idle := range idleConns {
idle.conn.cancel()
if err := idle.conn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP conn close error for %v: %v", idle.id, err)
}
if err := idle.conn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", idle.id, err)
}
idle.conn.ep.Close()
f.Lock()
delete(f.conns, idle.id)
f.Unlock()
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id)
}
}
}
}
// handleUDP is called by the UDP forwarder for new packets
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if f.ctx.Err() != nil {
f.logger.Trace("forwarder: context done, dropping UDP packet")
return
}
id := r.ID()
f.udpForwarder.RLock()
_, exists := f.udpForwarder.conns[id]
f.udpForwarder.RUnlock()
if exists {
f.logger.Trace("forwarder: existing UDP connection for %v", id)
return
}
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil {
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
// TODO: Send ICMP error message
return
}
// Create wait queue for blocking syscalls
wq := waiter.Queue{}
ep, epErr := r.CreateEndpoint(&wq)
if epErr != nil {
f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr)
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
}
return
}
inConn := gonet.NewUDPConn(f.stack, &wq, ep)
connCtx, connCancel := context.WithCancel(f.ctx)
pConn := &udpPacketConn{
conn: inConn,
outConn: outConn,
cancel: connCancel,
ep: ep,
}
pConn.updateLastSeen()
f.udpForwarder.Lock()
// Double-check no connection was created while we were setting up
if _, exists := f.udpForwarder.conns[id]; exists {
f.udpForwarder.Unlock()
pConn.cancel()
if err := inConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
}
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
}
return
}
f.udpForwarder.conns[id] = pConn
f.udpForwarder.Unlock()
f.logger.Trace("forwarder: established UDP connection to %v", id)
go f.proxyUDP(connCtx, pConn, id, ep)
}
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
defer func() {
pConn.cancel()
if err := pConn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
}
if err := pConn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
}
ep.Close()
f.udpForwarder.Lock()
delete(f.udpForwarder.conns, id)
f.udpForwarder.Unlock()
}()
errChan := make(chan error, 2)
go func() {
errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
}()
go func() {
errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
}()
select {
case <-ctx.Done():
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", id)
return
case err := <-errChan:
if err != nil && !isClosedError(err) {
f.logger.Error("proxyUDP: copy error: %v", err)
}
f.logger.Trace("forwarder: tearing down UDP connection %v", id)
return
}
}
func (c *udpPacketConn) updateLastSeen() {
c.lastSeen.Store(time.Now().UnixNano())
}
func (c *udpPacketConn) getIdleDuration() time.Duration {
lastSeen := time.Unix(0, c.lastSeen.Load())
return time.Since(lastSeen)
}
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error {
bufp := bufPool.Get().(*[]byte)
defer bufPool.Put(bufp)
buffer := *bufp
if err := src.SetReadDeadline(time.Now().Add(udpTimeout)); err != nil {
return fmt.Errorf("set read deadline: %w", err)
}
if err := src.SetWriteDeadline(time.Now().Add(udpTimeout)); err != nil {
return fmt.Errorf("set write deadline: %w", err)
}
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
n, err := src.Read(buffer)
if err != nil {
if isTimeout(err) {
continue
}
return fmt.Errorf("read from %s: %w", direction, err)
}
_, err = dst.Write(buffer[:n])
if err != nil {
return fmt.Errorf("write to %s: %w", direction, err)
}
c.updateLastSeen()
}
}
}
func isClosedError(err error) bool {
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled)
}
func isTimeout(err error) bool {
var netErr net.Error
if errors.As(err, &netErr) {
return netErr.Timeout()
}
return false
}

View File

@@ -0,0 +1,134 @@
package uspfilter
import (
"fmt"
"net"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
)
type localIPManager struct {
mu sync.RWMutex
// Use bitmap for IPv4 (32 bits * 2^16 = 256KB memory)
ipv4Bitmap [1 << 16]uint32
}
func newLocalIPManager() *localIPManager {
return &localIPManager{}
}
func (m *localIPManager) setBitmapBit(ip net.IP) {
ipv4 := ip.To4()
if ipv4 == nil {
return
}
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
m.ipv4Bitmap[high] |= 1 << (low % 32)
}
func (m *localIPManager) checkBitmapBit(ip net.IP) bool {
ipv4 := ip.To4()
if ipv4 == nil {
return false
}
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
}
func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
if ipv4 := ip.To4(); ipv4 != nil {
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
if int(high) >= len(*newIPv4Bitmap) {
return fmt.Errorf("invalid IPv4 address: %s", ip)
}
ipStr := ip.String()
if _, exists := ipv4Set[ipStr]; !exists {
ipv4Set[ipStr] = struct{}{}
*ipv4Addresses = append(*ipv4Addresses, ipStr)
newIPv4Bitmap[high] |= 1 << (low % 32)
}
}
return nil
}
func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
addrs, err := iface.Addrs()
if err != nil {
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
return
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
default:
continue
}
if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil {
log.Debugf("process IP failed: %v", err)
}
}
}
func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic: %v", r)
}
}()
var newIPv4Bitmap [1 << 16]uint32
ipv4Set := make(map[string]struct{})
var ipv4Addresses []string
// 127.0.0.0/8
high := uint16(127) << 8
for i := uint16(0); i < 256; i++ {
newIPv4Bitmap[high|i] = 0xffffffff
}
if iface != nil {
if err := m.processIP(iface.Address().IP, &newIPv4Bitmap, ipv4Set, &ipv4Addresses); err != nil {
return err
}
}
interfaces, err := net.Interfaces()
if err != nil {
log.Warnf("failed to get interfaces: %v", err)
} else {
for _, intf := range interfaces {
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
}
}
m.mu.Lock()
m.ipv4Bitmap = newIPv4Bitmap
m.mu.Unlock()
log.Debugf("Local IPv4 addresses: %v", ipv4Addresses)
return nil
}
func (m *localIPManager) IsLocalIP(ip net.IP) bool {
m.mu.RLock()
defer m.mu.RUnlock()
if ipv4 := ip.To4(); ipv4 != nil {
return m.checkBitmapBit(ipv4)
}
return false
}

View File

@@ -0,0 +1,270 @@
package uspfilter
import (
"net"
"testing"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface"
)
func TestLocalIPManager(t *testing.T) {
tests := []struct {
name string
setupAddr iface.WGAddress
testIP net.IP
expected bool
}{
{
name: "Localhost range",
setupAddr: iface.WGAddress{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: net.ParseIP("127.0.0.2"),
expected: true,
},
{
name: "Localhost standard address",
setupAddr: iface.WGAddress{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: net.ParseIP("127.0.0.1"),
expected: true,
},
{
name: "Localhost range edge",
setupAddr: iface.WGAddress{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: net.ParseIP("127.255.255.255"),
expected: true,
},
{
name: "Local IP matches",
setupAddr: iface.WGAddress{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: net.ParseIP("192.168.1.1"),
expected: true,
},
{
name: "Local IP doesn't match",
setupAddr: iface.WGAddress{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: net.ParseIP("192.168.1.2"),
expected: false,
},
{
name: "IPv6 address",
setupAddr: iface.WGAddress{
IP: net.ParseIP("fe80::1"),
Network: &net.IPNet{
IP: net.ParseIP("fe80::"),
Mask: net.CIDRMask(64, 128),
},
},
testIP: net.ParseIP("fe80::1"),
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager := newLocalIPManager()
mock := &IFaceMock{
AddressFunc: func() iface.WGAddress {
return tt.setupAddr
},
}
err := manager.UpdateLocalIPs(mock)
require.NoError(t, err)
result := manager.IsLocalIP(tt.testIP)
require.Equal(t, tt.expected, result)
})
}
}
func TestLocalIPManager_AllInterfaces(t *testing.T) {
manager := newLocalIPManager()
mock := &IFaceMock{}
// Get actual local interfaces
interfaces, err := net.Interfaces()
require.NoError(t, err)
var tests []struct {
ip string
expected bool
}
// Add all local interface IPs to test cases
for _, iface := range interfaces {
addrs, err := iface.Addrs()
require.NoError(t, err)
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
default:
continue
}
if ip4 := ip.To4(); ip4 != nil {
tests = append(tests, struct {
ip string
expected bool
}{
ip: ip4.String(),
expected: true,
})
}
}
}
// Add some external IPs as negative test cases
externalIPs := []string{
"8.8.8.8",
"1.1.1.1",
"208.67.222.222",
}
for _, ip := range externalIPs {
tests = append(tests, struct {
ip string
expected bool
}{
ip: ip,
expected: false,
})
}
require.NotEmpty(t, tests, "No test cases generated")
err = manager.UpdateLocalIPs(mock)
require.NoError(t, err)
t.Logf("Testing %d IPs", len(tests))
for _, tt := range tests {
t.Run(tt.ip, func(t *testing.T) {
result := manager.IsLocalIP(net.ParseIP(tt.ip))
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
})
}
}
// MapImplementation is a version using map[string]struct{}
type MapImplementation struct {
localIPs map[string]struct{}
}
func BenchmarkIPChecks(b *testing.B) {
interfaces := make([]net.IP, 16)
for i := range interfaces {
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
}
// Setup bitmap version
bitmapManager := &localIPManager{
ipv4Bitmap: [1 << 16]uint32{},
}
for _, ip := range interfaces[:8] { // Add half of IPs
bitmapManager.setBitmapBit(ip)
}
// Setup map version
mapManager := &MapImplementation{
localIPs: make(map[string]struct{}),
}
for _, ip := range interfaces[:8] {
mapManager.localIPs[ip.String()] = struct{}{}
}
b.Run("Bitmap_Hit", func(b *testing.B) {
ip := interfaces[4]
b.ResetTimer()
for i := 0; i < b.N; i++ {
bitmapManager.checkBitmapBit(ip)
}
})
b.Run("Bitmap_Miss", func(b *testing.B) {
ip := interfaces[12]
b.ResetTimer()
for i := 0; i < b.N; i++ {
bitmapManager.checkBitmapBit(ip)
}
})
b.Run("Map_Hit", func(b *testing.B) {
ip := interfaces[4]
b.ResetTimer()
for i := 0; i < b.N; i++ {
// nolint:gosimple
_, _ = mapManager.localIPs[ip.String()]
}
})
b.Run("Map_Miss", func(b *testing.B) {
ip := interfaces[12]
b.ResetTimer()
for i := 0; i < b.N; i++ {
// nolint:gosimple
_, _ = mapManager.localIPs[ip.String()]
}
})
}
func BenchmarkWGPosition(b *testing.B) {
wgIP := net.ParseIP("10.10.0.1")
// Create two managers - one checks WG IP first, other checks it last
b.Run("WG_First", func(b *testing.B) {
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
bm.setBitmapBit(wgIP)
b.ResetTimer()
for i := 0; i < b.N; i++ {
bm.checkBitmapBit(wgIP)
}
})
b.Run("WG_Last", func(b *testing.B) {
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
// Fill with other IPs first
for i := 0; i < 15; i++ {
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
}
bm.setBitmapBit(wgIP) // Add WG IP last
b.ResetTimer()
for i := 0; i < b.N; i++ {
bm.checkBitmapBit(wgIP)
}
})
}

View File

@@ -0,0 +1,196 @@
// Package logger provides a high-performance, non-blocking logger for userspace networking
package log
import (
"context"
"fmt"
"io"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
)
const (
maxBatchSize = 1024 * 16 // 16KB max batch size
maxMessageSize = 1024 * 2 // 2KB per message
bufferSize = 1024 * 256 // 256KB ring buffer
defaultFlushInterval = 2 * time.Second
)
// Level represents log severity
type Level uint32
const (
LevelPanic Level = iota
LevelFatal
LevelError
LevelWarn
LevelInfo
LevelDebug
LevelTrace
)
var levelStrings = map[Level]string{
LevelPanic: "PANC",
LevelFatal: "FATL",
LevelError: "ERRO",
LevelWarn: "WARN",
LevelInfo: "INFO",
LevelDebug: "DEBG",
LevelTrace: "TRAC",
}
// Logger is a high-performance, non-blocking logger
type Logger struct {
output io.Writer
level atomic.Uint32
buffer *ringBuffer
shutdown chan struct{}
closeOnce sync.Once
wg sync.WaitGroup
// Reusable buffer pool for formatting messages
bufPool sync.Pool
}
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
l := &Logger{
output: logrusLogger.Out,
buffer: newRingBuffer(bufferSize),
shutdown: make(chan struct{}),
bufPool: sync.Pool{
New: func() interface{} {
// Pre-allocate buffer for message formatting
b := make([]byte, 0, maxMessageSize)
return &b
},
},
}
logrusLevel := logrusLogger.GetLevel()
l.level.Store(uint32(logrusLevel))
level := levelStrings[Level(logrusLevel)]
log.Debugf("New uspfilter logger created with loglevel %v", level)
l.wg.Add(1)
go l.worker()
return l
}
func (l *Logger) SetLevel(level Level) {
l.level.Store(uint32(level))
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
}
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...interface{}) {
*buf = (*buf)[:0]
// Timestamp
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
*buf = append(*buf, ' ')
// Level
*buf = append(*buf, levelStrings[level]...)
*buf = append(*buf, ' ')
// Message
if len(args) > 0 {
*buf = append(*buf, fmt.Sprintf(format, args...)...)
} else {
*buf = append(*buf, format...)
}
*buf = append(*buf, '\n')
}
func (l *Logger) log(level Level, format string, args ...interface{}) {
bufp := l.bufPool.Get().(*[]byte)
l.formatMessage(bufp, level, format, args...)
if len(*bufp) > maxMessageSize {
*bufp = (*bufp)[:maxMessageSize]
}
_, _ = l.buffer.Write(*bufp)
l.bufPool.Put(bufp)
}
func (l *Logger) Error(format string, args ...interface{}) {
if l.level.Load() >= uint32(LevelError) {
l.log(LevelError, format, args...)
}
}
func (l *Logger) Warn(format string, args ...interface{}) {
if l.level.Load() >= uint32(LevelWarn) {
l.log(LevelWarn, format, args...)
}
}
func (l *Logger) Info(format string, args ...interface{}) {
if l.level.Load() >= uint32(LevelInfo) {
l.log(LevelInfo, format, args...)
}
}
func (l *Logger) Debug(format string, args ...interface{}) {
if l.level.Load() >= uint32(LevelDebug) {
l.log(LevelDebug, format, args...)
}
}
func (l *Logger) Trace(format string, args ...interface{}) {
if l.level.Load() >= uint32(LevelTrace) {
l.log(LevelTrace, format, args...)
}
}
// worker periodically flushes the buffer
func (l *Logger) worker() {
defer l.wg.Done()
ticker := time.NewTicker(defaultFlushInterval)
defer ticker.Stop()
buf := make([]byte, 0, maxBatchSize)
for {
select {
case <-l.shutdown:
return
case <-ticker.C:
// Read accumulated messages
n, _ := l.buffer.Read(buf[:cap(buf)])
if n == 0 {
continue
}
// Write batch
_, _ = l.output.Write(buf[:n])
}
}
}
// Stop gracefully shuts down the logger
func (l *Logger) Stop(ctx context.Context) error {
done := make(chan struct{})
l.closeOnce.Do(func() {
close(l.shutdown)
})
go func() {
l.wg.Wait()
close(done)
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-done:
return nil
}
}

View File

@@ -0,0 +1,85 @@
package log
import "sync"
// ringBuffer is a simple ring buffer implementation
type ringBuffer struct {
buf []byte
size int
r, w int64 // Read and write positions
mu sync.Mutex
}
func newRingBuffer(size int) *ringBuffer {
return &ringBuffer{
buf: make([]byte, size),
size: size,
}
}
func (r *ringBuffer) Write(p []byte) (n int, err error) {
if len(p) == 0 {
return 0, nil
}
r.mu.Lock()
defer r.mu.Unlock()
if len(p) > r.size {
p = p[:r.size]
}
n = len(p)
// Write data, handling wrap-around
pos := int(r.w % int64(r.size))
writeLen := min(len(p), r.size-pos)
copy(r.buf[pos:], p[:writeLen])
// If we have more data and need to wrap around
if writeLen < len(p) {
copy(r.buf, p[writeLen:])
}
// Update write position
r.w += int64(n)
return n, nil
}
func (r *ringBuffer) Read(p []byte) (n int, err error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.w == r.r {
return 0, nil
}
// Calculate available data accounting for wraparound
available := int(r.w - r.r)
if available < 0 {
available += r.size
}
available = min(available, r.size)
// Limit read to buffer size
toRead := min(available, len(p))
if toRead == 0 {
return 0, nil
}
// Read data, handling wrap-around
pos := int(r.r % int64(r.size))
readLen := min(toRead, r.size-pos)
n = copy(p, r.buf[pos:pos+readLen])
// If we need more data and need to wrap around
if readLen < toRead {
n += copy(p[readLen:toRead], r.buf[:toRead-readLen])
}
// Update read position
r.r += int64(n)
return n, nil
}

View File

@@ -2,14 +2,15 @@ package uspfilter
import ( import (
"net" "net"
"net/netip"
"github.com/google/gopacket" "github.com/google/gopacket"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
) )
// Rule to handle management of rules // PeerRule to handle management of rules
type Rule struct { type PeerRule struct {
id string id string
ip net.IP ip net.IP
ipLayer gopacket.LayerType ipLayer gopacket.LayerType
@@ -24,6 +25,21 @@ type Rule struct {
} }
// GetRuleID returns the rule id // GetRuleID returns the rule id
func (r *Rule) GetRuleID() string { func (r *PeerRule) GetRuleID() string {
return r.id
}
type RouteRule struct {
id string
sources []netip.Prefix
destination netip.Prefix
proto firewall.Protocol
srcPort *firewall.Port
dstPort *firewall.Port
action firewall.Action
}
// GetRuleID returns the rule id
func (r *RouteRule) GetRuleID() string {
return r.id return r.id
} }

View File

@@ -0,0 +1,390 @@
package uspfilter
import (
"fmt"
"net"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
)
type PacketStage int
const (
StageReceived PacketStage = iota
StageConntrack
StagePeerACL
StageRouting
StageRouteACL
StageForwarding
StageCompleted
)
const msgProcessingCompleted = "Processing completed"
func (s PacketStage) String() string {
return map[PacketStage]string{
StageReceived: "Received",
StageConntrack: "Connection Tracking",
StagePeerACL: "Peer ACL",
StageRouting: "Routing",
StageRouteACL: "Route ACL",
StageForwarding: "Forwarding",
StageCompleted: "Completed",
}[s]
}
type ForwarderAction struct {
Action string
RemoteAddr string
Error error
}
type TraceResult struct {
Timestamp time.Time
Stage PacketStage
Message string
Allowed bool
ForwarderAction *ForwarderAction
}
type PacketTrace struct {
SourceIP net.IP
DestinationIP net.IP
Protocol string
SourcePort uint16
DestinationPort uint16
Direction fw.RuleDirection
Results []TraceResult
}
type TCPState struct {
SYN bool
ACK bool
FIN bool
RST bool
PSH bool
URG bool
}
type PacketBuilder struct {
SrcIP net.IP
DstIP net.IP
Protocol fw.Protocol
SrcPort uint16
DstPort uint16
ICMPType uint8
ICMPCode uint8
Direction fw.RuleDirection
PayloadSize int
TCPState *TCPState
}
func (t *PacketTrace) AddResult(stage PacketStage, message string, allowed bool) {
t.Results = append(t.Results, TraceResult{
Timestamp: time.Now(),
Stage: stage,
Message: message,
Allowed: allowed,
})
}
func (t *PacketTrace) AddResultWithForwarder(stage PacketStage, message string, allowed bool, action *ForwarderAction) {
t.Results = append(t.Results, TraceResult{
Timestamp: time.Now(),
Stage: stage,
Message: message,
Allowed: allowed,
ForwarderAction: action,
})
}
func (p *PacketBuilder) Build() ([]byte, error) {
ip := p.buildIPLayer()
pktLayers := []gopacket.SerializableLayer{ip}
transportLayer, err := p.buildTransportLayer(ip)
if err != nil {
return nil, err
}
pktLayers = append(pktLayers, transportLayer...)
if p.PayloadSize > 0 {
payload := make([]byte, p.PayloadSize)
pktLayers = append(pktLayers, gopacket.Payload(payload))
}
return serializePacket(pktLayers)
}
func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
return &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
SrcIP: p.SrcIP,
DstIP: p.DstIP,
}
}
func (p *PacketBuilder) buildTransportLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
switch p.Protocol {
case "tcp":
return p.buildTCPLayer(ip)
case "udp":
return p.buildUDPLayer(ip)
case "icmp":
return p.buildICMPLayer()
default:
return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol)
}
}
func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
tcp := &layers.TCP{
SrcPort: layers.TCPPort(p.SrcPort),
DstPort: layers.TCPPort(p.DstPort),
Window: 65535,
SYN: p.TCPState != nil && p.TCPState.SYN,
ACK: p.TCPState != nil && p.TCPState.ACK,
FIN: p.TCPState != nil && p.TCPState.FIN,
RST: p.TCPState != nil && p.TCPState.RST,
PSH: p.TCPState != nil && p.TCPState.PSH,
URG: p.TCPState != nil && p.TCPState.URG,
}
if err := tcp.SetNetworkLayerForChecksum(ip); err != nil {
return nil, fmt.Errorf("set network layer for TCP checksum: %w", err)
}
return []gopacket.SerializableLayer{tcp}, nil
}
func (p *PacketBuilder) buildUDPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
udp := &layers.UDP{
SrcPort: layers.UDPPort(p.SrcPort),
DstPort: layers.UDPPort(p.DstPort),
}
if err := udp.SetNetworkLayerForChecksum(ip); err != nil {
return nil, fmt.Errorf("set network layer for UDP checksum: %w", err)
}
return []gopacket.SerializableLayer{udp}, nil
}
func (p *PacketBuilder) buildICMPLayer() ([]gopacket.SerializableLayer, error) {
icmp := &layers.ICMPv4{
TypeCode: layers.CreateICMPv4TypeCode(p.ICMPType, p.ICMPCode),
}
if p.ICMPType == layers.ICMPv4TypeEchoRequest || p.ICMPType == layers.ICMPv4TypeEchoReply {
icmp.Id = uint16(1)
icmp.Seq = uint16(1)
}
return []gopacket.SerializableLayer{icmp}, nil
}
func serializePacket(layers []gopacket.SerializableLayer) ([]byte, error) {
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
if err := gopacket.SerializeLayers(buf, opts, layers...); err != nil {
return nil, fmt.Errorf("serialize packet: %w", err)
}
return buf.Bytes(), nil
}
func getIPProtocolNumber(protocol fw.Protocol) int {
switch protocol {
case fw.ProtocolTCP:
return int(layers.IPProtocolTCP)
case fw.ProtocolUDP:
return int(layers.IPProtocolUDP)
case fw.ProtocolICMP:
return int(layers.IPProtocolICMPv4)
default:
return 0
}
}
func (m *Manager) TracePacketFromBuilder(builder *PacketBuilder) (*PacketTrace, error) {
packetData, err := builder.Build()
if err != nil {
return nil, fmt.Errorf("build packet: %w", err)
}
return m.TracePacket(packetData, builder.Direction), nil
}
func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *PacketTrace {
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
trace := &PacketTrace{Direction: direction}
// Initial packet decoding
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageReceived, fmt.Sprintf("Failed to decode packet: %v", err), false)
return trace
}
// Extract base packet info
srcIP, dstIP := m.extractIPs(d)
trace.SourceIP = srcIP
trace.DestinationIP = dstIP
// Determine protocol and ports
switch d.decoded[1] {
case layers.LayerTypeTCP:
trace.Protocol = "TCP"
trace.SourcePort = uint16(d.tcp.SrcPort)
trace.DestinationPort = uint16(d.tcp.DstPort)
case layers.LayerTypeUDP:
trace.Protocol = "UDP"
trace.SourcePort = uint16(d.udp.SrcPort)
trace.DestinationPort = uint16(d.udp.DstPort)
case layers.LayerTypeICMPv4:
trace.Protocol = "ICMP"
}
trace.AddResult(StageReceived, fmt.Sprintf("Received %s packet: %s:%d -> %s:%d",
trace.Protocol, srcIP, trace.SourcePort, dstIP, trace.DestinationPort), true)
if direction == fw.RuleDirectionOUT {
return m.traceOutbound(packetData, trace)
}
return m.traceInbound(packetData, trace, d, srcIP, dstIP)
}
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP net.IP, dstIP net.IP) *PacketTrace {
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
return trace
}
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
return trace
}
if !m.handleRouting(trace) {
return trace
}
if m.nativeRouter {
return m.handleNativeRouter(trace)
}
return m.handleRouteACLs(trace, d, srcIP, dstIP)
}
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) bool {
allowed := m.isValidTrackedConnection(d, srcIP, dstIP)
msg := "No existing connection found"
if allowed {
msg = m.buildConntrackStateMessage(d)
trace.AddResult(StageConntrack, msg, true)
trace.AddResult(StageCompleted, "Packet allowed by connection tracking", true)
return true
}
trace.AddResult(StageConntrack, msg, false)
return false
}
func (m *Manager) buildConntrackStateMessage(d *decoder) string {
msg := "Matched existing connection state"
switch d.decoded[1] {
case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp)
msg += fmt.Sprintf(" (TCP Flags: SYN=%v ACK=%v RST=%v FIN=%v)",
flags&conntrack.TCPSyn != 0,
flags&conntrack.TCPAck != 0,
flags&conntrack.TCPRst != 0,
flags&conntrack.TCPFin != 0)
case layers.LayerTypeICMPv4:
msg += fmt.Sprintf(" (ICMP ID=%d, Seq=%d)", d.icmp4.Id, d.icmp4.Seq)
}
return msg
}
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP net.IP) bool {
if !m.localForwarding {
trace.AddResult(StageRouting, "Local forwarding disabled", false)
trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false)
return true
}
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
msg := "Allowed by peer ACL rules"
if blocked {
msg = "Blocked by peer ACL rules"
}
trace.AddResult(StagePeerACL, msg, !blocked)
if m.netstack {
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", !blocked)
}
trace.AddResult(StageCompleted, msgProcessingCompleted, !blocked)
return true
}
func (m *Manager) handleRouting(trace *PacketTrace) bool {
if !m.routingEnabled {
trace.AddResult(StageRouting, "Routing disabled", false)
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
return false
}
trace.AddResult(StageRouting, "Routing enabled, checking ACLs", true)
return true
}
func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
trace.AddResult(StageRouteACL, "Using native router, skipping ACL checks", true)
trace.AddResult(StageForwarding, "Forwarding via native router", true)
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
return trace
}
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) *PacketTrace {
proto := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
msg := "Allowed by route ACLs"
if !allowed {
msg = "Blocked by route ACLs"
}
trace.AddResult(StageRouteACL, msg, allowed)
if allowed && m.forwarder != nil {
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
}
trace.AddResult(StageCompleted, msgProcessingCompleted, allowed)
return trace
}
func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr string, allowed bool) {
fwdAction := &ForwarderAction{
Action: action,
RemoteAddr: remoteAddr,
}
trace.AddResultWithForwarder(StageForwarding,
fmt.Sprintf("Forwarding to %s", fwdAction.Action), allowed, fwdAction)
}
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
// will create or update the connection state
dropped := m.processOutgoingHooks(packetData)
if dropped {
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
} else {
trace.AddResult(StageCompleted, "Packet allowed (outgoing)", true)
}
return trace
}

View File

@@ -1,11 +1,14 @@
package uspfilter package uspfilter
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"os" "os"
"slices"
"strconv" "strconv"
"strings"
"sync" "sync"
"github.com/google/gopacket" "github.com/google/gopacket"
@@ -14,28 +17,48 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
"github.com/netbirdio/netbird/client/iface/device" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
const layerTypeAll = 0 const layerTypeAll = 0
const EnvDisableConntrack = "NB_DISABLE_CONNTRACK" const (
// EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed.
EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
var ( // EnvDisableUserspaceRouting disables userspace routing, to-be-routed packets will be dropped.
errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall") EnvDisableUserspaceRouting = "NB_DISABLE_USERSPACE_ROUTING"
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
// EnvEnableNetstackLocalForwarding enables forwarding of local traffic to the native stack when running netstack
// Leaving this on by default introduces a security risk as sockets on listening on localhost only will be accessible
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
) )
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
SetFilter(device.PacketFilter) error
Address() iface.WGAddress
}
// RuleSet is a set of rules grouped by a string key // RuleSet is a set of rules grouped by a string key
type RuleSet map[string]Rule type RuleSet map[string]PeerRule
type RouteRules []RouteRule
func (r RouteRules) Sort() {
slices.SortStableFunc(r, func(a, b RouteRule) int {
// Deny rules come first
if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop {
return -1
}
if a.action != firewall.ActionDrop && b.action == firewall.ActionDrop {
return 1
}
return strings.Compare(a.id, b.id)
})
}
// Manager userspace firewall manager // Manager userspace firewall manager
type Manager struct { type Manager struct {
@@ -43,17 +66,34 @@ type Manager struct {
outgoingRules map[string]RuleSet outgoingRules map[string]RuleSet
// incomingRules is used for filtering and hooks // incomingRules is used for filtering and hooks
incomingRules map[string]RuleSet incomingRules map[string]RuleSet
routeRules RouteRules
wgNetwork *net.IPNet wgNetwork *net.IPNet
decoders sync.Pool decoders sync.Pool
wgIface IFaceMapper wgIface common.IFaceMapper
nativeFirewall firewall.Manager nativeFirewall firewall.Manager
mutex sync.RWMutex mutex sync.RWMutex
stateful bool // indicates whether server routes are disabled
disableServerRoutes bool
// indicates whether we forward packets not destined for ourselves
routingEnabled bool
// indicates whether we leave forwarding and filtering to the native firewall
nativeRouter bool
// indicates whether we track outbound connections
stateful bool
// indicates whether wireguards runs in netstack mode
netstack bool
// indicates whether we forward local traffic to the native stack
localForwarding bool
localipmanager *localIPManager
udpTracker *conntrack.UDPTracker udpTracker *conntrack.UDPTracker
icmpTracker *conntrack.ICMPTracker icmpTracker *conntrack.ICMPTracker
tcpTracker *conntrack.TCPTracker tcpTracker *conntrack.TCPTracker
forwarder *forwarder.Forwarder
logger *nblog.Logger
} }
// decoder for packages // decoder for packages
@@ -70,22 +110,44 @@ type decoder struct {
} }
// Create userspace firewall manager constructor // Create userspace firewall manager constructor
func Create(iface IFaceMapper) (*Manager, error) { func Create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error) {
return create(iface) return create(iface, nil, disableServerRoutes)
} }
func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) { func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
mgr, err := create(iface) if nativeFirewall == nil {
return nil, errors.New("native firewall is nil")
}
mgr, err := create(iface, nativeFirewall, disableServerRoutes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
mgr.nativeFirewall = nativeFirewall
return mgr, nil return mgr, nil
} }
func create(iface IFaceMapper) (*Manager, error) { func parseCreateEnv() (bool, bool) {
disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack)) var disableConntrack, enableLocalForwarding bool
var err error
if val := os.Getenv(EnvDisableConntrack); val != "" {
disableConntrack, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvDisableConntrack, err)
}
}
if val := os.Getenv(EnvEnableNetstackLocalForwarding); val != "" {
enableLocalForwarding, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
}
}
return disableConntrack, enableLocalForwarding
}
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
disableConntrack, enableLocalForwarding := parseCreateEnv()
m := &Manager{ m := &Manager{
decoders: sync.Pool{ decoders: sync.Pool{
@@ -101,52 +163,183 @@ func create(iface IFaceMapper) (*Manager, error) {
return d return d
}, },
}, },
outgoingRules: make(map[string]RuleSet), nativeFirewall: nativeFirewall,
incomingRules: make(map[string]RuleSet), outgoingRules: make(map[string]RuleSet),
wgIface: iface, incomingRules: make(map[string]RuleSet),
stateful: !disableConntrack, wgIface: iface,
localipmanager: newLocalIPManager(),
disableServerRoutes: disableServerRoutes,
routingEnabled: false,
stateful: !disableConntrack,
logger: nblog.NewFromLogrus(log.StandardLogger()),
netstack: netstack.IsEnabled(),
// default true for non-netstack, for netstack only if explicitly enabled
localForwarding: !netstack.IsEnabled() || enableLocalForwarding,
}
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
return nil, fmt.Errorf("update local IPs: %w", err)
} }
// Only initialize trackers if stateful mode is enabled
if disableConntrack { if disableConntrack {
log.Info("conntrack is disabled") log.Info("conntrack is disabled")
} else { } else {
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
}
// netstack needs the forwarder for local traffic
if m.netstack && m.localForwarding {
if err := m.initForwarder(); err != nil {
log.Errorf("failed to initialize forwarder: %v", err)
}
}
if err := m.blockInvalidRouted(iface); err != nil {
log.Errorf("failed to block invalid routed traffic: %v", err)
} }
if err := iface.SetFilter(m); err != nil { if err := iface.SetFilter(m); err != nil {
return nil, err return nil, fmt.Errorf("set filter: %w", err)
} }
return m, nil return m, nil
} }
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
if m.forwarder == nil {
return nil
}
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
if err != nil {
return fmt.Errorf("parse wireguard network: %w", err)
}
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
if _, err := m.AddRouteFiltering(
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
wgPrefix,
firewall.ProtocolALL,
nil,
nil,
firewall.ActionDrop,
); err != nil {
return fmt.Errorf("block wg nte : %w", err)
}
// TODO: Block networks that we're a client of
return nil
}
func (m *Manager) determineRouting() error {
var disableUspRouting, forceUserspaceRouter bool
var err error
if val := os.Getenv(EnvDisableUserspaceRouting); val != "" {
disableUspRouting, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvDisableUserspaceRouting, err)
}
}
if val := os.Getenv(EnvForceUserspaceRouter); val != "" {
forceUserspaceRouter, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvForceUserspaceRouter, err)
}
}
switch {
case disableUspRouting:
m.routingEnabled = false
m.nativeRouter = false
log.Info("userspace routing is disabled")
case m.disableServerRoutes:
// if server routes are disabled we will let packets pass to the native stack
m.routingEnabled = true
m.nativeRouter = true
log.Info("server routes are disabled")
case forceUserspaceRouter:
m.routingEnabled = true
m.nativeRouter = false
log.Info("userspace routing is forced")
case !m.netstack && m.nativeFirewall != nil && m.nativeFirewall.IsServerRouteSupported():
// if the OS supports routing natively, then we don't need to filter/route ourselves
// netstack mode won't support native routing as there is no interface
m.routingEnabled = true
m.nativeRouter = true
log.Info("native routing is enabled")
default:
m.routingEnabled = true
m.nativeRouter = false
log.Info("userspace routing enabled by default")
}
if m.routingEnabled && !m.nativeRouter {
return m.initForwarder()
}
return nil
}
// initForwarder initializes the forwarder, it disables routing on errors
func (m *Manager) initForwarder() error {
if m.forwarder != nil {
return nil
}
// Only supported in userspace mode as we need to inject packets back into wireguard directly
intf := m.wgIface.GetWGDevice()
if intf == nil {
m.routingEnabled = false
return errors.New("forwarding not supported")
}
forwarder, err := forwarder.New(m.wgIface, m.logger, m.netstack)
if err != nil {
m.routingEnabled = false
return fmt.Errorf("create forwarder: %w", err)
}
m.forwarder = forwarder
log.Debug("forwarder initialized")
return nil
}
func (m *Manager) Init(*statemanager.Manager) error { func (m *Manager) Init(*statemanager.Manager) error {
return nil return nil
} }
func (m *Manager) IsServerRouteSupported() bool { func (m *Manager) IsServerRouteSupported() bool {
if m.nativeFirewall == nil { return true
return false
} else {
return true
}
} }
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if m.nativeFirewall == nil { if m.nativeRouter && m.nativeFirewall != nil {
return errRouteNotSupported return m.nativeFirewall.AddNatRule(pair)
} }
return m.nativeFirewall.AddNatRule(pair)
// userspace routed packets are always SNATed to the inbound direction
// TODO: implement outbound SNAT
return nil
} }
// RemoveNatRule removes a routing firewall rule // RemoveNatRule removes a routing firewall rule
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
if m.nativeFirewall == nil { if m.nativeRouter && m.nativeFirewall != nil {
return errRouteNotSupported return m.nativeFirewall.RemoveNatRule(pair)
} }
return m.nativeFirewall.RemoveNatRule(pair) return nil
} }
// AddPeerFiltering rule to the firewall // AddPeerFiltering rule to the firewall
@@ -162,7 +355,7 @@ func (m *Manager) AddPeerFiltering(
_ string, _ string,
comment string, comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
r := Rule{ r := PeerRule{
id: uuid.New().String(), id: uuid.New().String(),
ip: ip, ip: ip,
ipLayer: layers.LayerTypeIPv6, ipLayer: layers.LayerTypeIPv6,
@@ -205,18 +398,56 @@ func (m *Manager) AddPeerFiltering(
return []firewall.Rule{&r}, nil return []firewall.Rule{&r}, nil
} }
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) { func (m *Manager) AddRouteFiltering(
if m.nativeFirewall == nil { sources []netip.Prefix,
return nil, errRouteNotSupported destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
if m.nativeRouter && m.nativeFirewall != nil {
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
} }
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
m.mutex.Lock()
defer m.mutex.Unlock()
ruleID := uuid.New().String()
rule := RouteRule{
id: ruleID,
sources: sources,
destination: destination,
proto: proto,
srcPort: sPort,
dstPort: dPort,
action: action,
}
m.routeRules = append(m.routeRules, rule)
m.routeRules.Sort()
return &rule, nil
} }
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
if m.nativeFirewall == nil { if m.nativeRouter && m.nativeFirewall != nil {
return errRouteNotSupported return m.nativeFirewall.DeleteRouteRule(rule)
} }
return m.nativeFirewall.DeleteRouteRule(rule)
m.mutex.Lock()
defer m.mutex.Unlock()
ruleID := rule.GetRuleID()
idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool {
return r.id == ruleID
})
if idx < 0 {
return fmt.Errorf("route rule not found: %s", ruleID)
}
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
return nil
} }
// DeletePeerRule from the firewall by rule definition // DeletePeerRule from the firewall by rule definition
@@ -224,7 +455,7 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
r, ok := rule.(*Rule) r, ok := rule.(*PeerRule)
if !ok { if !ok {
return fmt.Errorf("delete rule: invalid rule type: %T", rule) return fmt.Errorf("delete rule: invalid rule type: %T", rule)
} }
@@ -255,10 +486,14 @@ func (m *Manager) DropOutgoing(packetData []byte) bool {
// DropIncoming filter incoming packets // DropIncoming filter incoming packets
func (m *Manager) DropIncoming(packetData []byte) bool { func (m *Manager) DropIncoming(packetData []byte) bool {
return m.dropFilter(packetData, m.incomingRules) return m.dropFilter(packetData)
}
// UpdateLocalIPs updates the list of local IPs
func (m *Manager) UpdateLocalIPs() error {
return m.localipmanager.UpdateLocalIPs(m.wgIface)
} }
// processOutgoingHooks processes UDP hooks for outgoing packets and tracks TCP/UDP/ICMP
func (m *Manager) processOutgoingHooks(packetData []byte) bool { func (m *Manager) processOutgoingHooks(packetData []byte) bool {
m.mutex.RLock() m.mutex.RLock()
defer m.mutex.RUnlock() defer m.mutex.RUnlock()
@@ -279,18 +514,11 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
return false return false
} }
// Always process UDP hooks // Track all protocols if stateful mode is enabled
if d.decoded[1] == layers.LayerTypeUDP {
// Track UDP state only if enabled
if m.stateful {
m.trackUDPOutbound(d, srcIP, dstIP)
}
return m.checkUDPHooks(d, dstIP, packetData)
}
// Track other protocols only if stateful mode is enabled
if m.stateful { if m.stateful {
switch d.decoded[1] { switch d.decoded[1] {
case layers.LayerTypeUDP:
m.trackUDPOutbound(d, srcIP, dstIP)
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
m.trackTCPOutbound(d, srcIP, dstIP) m.trackTCPOutbound(d, srcIP, dstIP)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
@@ -298,6 +526,11 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
} }
} }
// Process UDP hooks even if stateful mode is disabled
if d.decoded[1] == layers.LayerTypeUDP {
return m.checkUDPHooks(d, dstIP, packetData)
}
return false return false
} }
@@ -379,10 +612,9 @@ func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) {
} }
} }
// dropFilter implements filtering logic for incoming packets // dropFilter implements filtering logic for incoming packets.
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { // If it returns true, the packet should be dropped.
// TODO: Disable router if --disable-server-router is set func (m *Manager) dropFilter(packetData []byte) bool {
m.mutex.RLock() m.mutex.RLock()
defer m.mutex.RUnlock() defer m.mutex.RUnlock()
@@ -395,39 +627,127 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
srcIP, dstIP := m.extractIPs(d) srcIP, dstIP := m.extractIPs(d)
if srcIP == nil { if srcIP == nil {
log.Errorf("unknown layer: %v", d.decoded[0]) m.logger.Error("Unknown network layer: %v", d.decoded[0])
return true return true
} }
if !m.isWireguardTraffic(srcIP, dstIP) { // For all inbound traffic, first check if it matches a tracked connection.
return false // This must happen before any other filtering because the packets are statefully tracked.
}
// Check connection state only if enabled
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) { if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
return false return false
} }
return m.applyRules(srcIP, packetData, rules, d) if m.localipmanager.IsLocalIP(dstIP) {
return m.handleLocalTraffic(d, srcIP, dstIP, packetData)
}
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
}
// handleLocalTraffic handles local traffic.
// If it returns true, the packet should be dropped.
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
if !m.localForwarding {
m.logger.Trace("Dropping local packet (local forwarding disabled): src=%s dst=%s", srcIP, dstIP)
return true
}
if m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) {
m.logger.Trace("Dropping local packet (ACL denied): src=%s dst=%s",
srcIP, dstIP)
return true
}
// if running in netstack mode we need to pass this to the forwarder
if m.netstack {
m.handleNetstackLocalTraffic(packetData)
// don't process this packet further
return true
}
return false
}
func (m *Manager) handleNetstackLocalTraffic(packetData []byte) {
if m.forwarder == nil {
return
}
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
m.logger.Error("Failed to inject local packet: %v", err)
}
}
// handleRoutedTraffic handles routed traffic.
// If it returns true, the packet should be dropped.
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
// Drop if routing is disabled
if !m.routingEnabled {
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
srcIP, dstIP)
return true
}
// Pass to native stack if native router is enabled or forced
if m.nativeRouter {
return false
}
proto := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
if !m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) {
m.logger.Trace("Dropping routed packet (ACL denied): src=%s:%d dst=%s:%d proto=%v",
srcIP, srcPort, dstIP, dstPort, proto)
return true
}
// Let forwarder handle the packet if it passed route ACLs
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
m.logger.Error("Failed to inject incoming packet: %v", err)
}
// Forwarded packets shouldn't reach the native stack, hence they won't be visible in a packet capture
return true
}
func getProtocolFromPacket(d *decoder) firewall.Protocol {
switch d.decoded[1] {
case layers.LayerTypeTCP:
return firewall.ProtocolTCP
case layers.LayerTypeUDP:
return firewall.ProtocolUDP
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return firewall.ProtocolICMP
default:
return firewall.ProtocolALL
}
}
func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
switch d.decoded[1] {
case layers.LayerTypeTCP:
return uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort)
case layers.LayerTypeUDP:
return uint16(d.udp.SrcPort), uint16(d.udp.DstPort)
default:
return 0, 0
}
} }
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool { func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
log.Tracef("couldn't decode layer, err: %s", err) m.logger.Trace("couldn't decode packet, err: %s", err)
return false return false
} }
if len(d.decoded) < 2 { if len(d.decoded) < 2 {
log.Tracef("not enough levels in network packet") m.logger.Trace("packet doesn't have network and transport layers")
return false return false
} }
return true return true
} }
func (m *Manager) isWireguardTraffic(srcIP, dstIP net.IP) bool {
return m.wgNetwork.Contains(srcIP) && m.wgNetwork.Contains(dstIP)
}
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool { func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
switch d.decoded[1] { switch d.decoded[1] {
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
@@ -462,7 +782,22 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
return false return false
} }
func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool { // isSpecialICMP returns true if the packet is a special ICMP packet that should be allowed
func (m *Manager) isSpecialICMP(d *decoder) bool {
if d.decoded[1] != layers.LayerTypeICMPv4 {
return false
}
icmpType := d.icmp4.TypeCode.Type()
return icmpType == layers.ICMPv4TypeDestinationUnreachable ||
icmpType == layers.ICMPv4TypeTimeExceeded
}
func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
if m.isSpecialICMP(d) {
return false
}
if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok { if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok {
return filter return filter
} }
@@ -496,7 +831,7 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
return false return false
} }
func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) { func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *decoder) (bool, bool) {
payloadLayer := d.decoded[1] payloadLayer := d.decoded[1]
for _, rule := range rules { for _, rule := range rules {
if rule.matchByIP && !ip.Equal(rule.ip) { if rule.matchByIP && !ip.Equal(rule.ip) {
@@ -533,6 +868,51 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
return false, false return false, false
} }
// routeACLsPass returns treu if the packet is allowed by the route ACLs
func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, srcPort, dstPort uint16) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
srcAddr := netip.AddrFrom4([4]byte(srcIP.To4()))
dstAddr := netip.AddrFrom4([4]byte(dstIP.To4()))
for _, rule := range m.routeRules {
if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) {
return rule.action == firewall.ActionAccept
}
}
return false
}
func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
if !rule.destination.Contains(dstAddr) {
return false
}
sourceMatched := false
for _, src := range rule.sources {
if src.Contains(srcAddr) {
sourceMatched = true
break
}
}
if !sourceMatched {
return false
}
if rule.proto != firewall.ProtocolALL && rule.proto != proto {
return false
}
if proto == firewall.ProtocolTCP || proto == firewall.ProtocolUDP {
if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
return false
}
}
return true
}
// SetNetwork of the wireguard interface to which filtering applied // SetNetwork of the wireguard interface to which filtering applied
func (m *Manager) SetNetwork(network *net.IPNet) { func (m *Manager) SetNetwork(network *net.IPNet) {
m.wgNetwork = network m.wgNetwork = network
@@ -544,7 +924,7 @@ func (m *Manager) SetNetwork(network *net.IPNet) {
func (m *Manager) AddUDPPacketHook( func (m *Manager) AddUDPPacketHook(
in bool, ip net.IP, dPort uint16, hook func([]byte) bool, in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
) string { ) string {
r := Rule{ r := PeerRule{
id: uuid.New().String(), id: uuid.New().String(),
ip: ip, ip: ip,
protoLayer: layers.LayerTypeUDP, protoLayer: layers.LayerTypeUDP,
@@ -561,12 +941,12 @@ func (m *Manager) AddUDPPacketHook(
m.mutex.Lock() m.mutex.Lock()
if in { if in {
if _, ok := m.incomingRules[r.ip.String()]; !ok { if _, ok := m.incomingRules[r.ip.String()]; !ok {
m.incomingRules[r.ip.String()] = make(map[string]Rule) m.incomingRules[r.ip.String()] = make(map[string]PeerRule)
} }
m.incomingRules[r.ip.String()][r.id] = r m.incomingRules[r.ip.String()][r.id] = r
} else { } else {
if _, ok := m.outgoingRules[r.ip.String()]; !ok { if _, ok := m.outgoingRules[r.ip.String()]; !ok {
m.outgoingRules[r.ip.String()] = make(map[string]Rule) m.outgoingRules[r.ip.String()] = make(map[string]PeerRule)
} }
m.outgoingRules[r.ip.String()][r.id] = r m.outgoingRules[r.ip.String()][r.id] = r
} }
@@ -599,3 +979,41 @@ func (m *Manager) RemovePacketHook(hookID string) error {
} }
return fmt.Errorf("hook with given id not found") return fmt.Errorf("hook with given id not found")
} }
// SetLogLevel sets the log level for the firewall manager
func (m *Manager) SetLogLevel(level log.Level) {
if m.logger != nil {
m.logger.SetLevel(nblog.Level(level))
}
}
func (m *Manager) EnableRouting() error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.determineRouting()
}
func (m *Manager) DisableRouting() error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.forwarder == nil {
return nil
}
m.routingEnabled = false
m.nativeRouter = false
// don't stop forwarder if in use by netstack
if m.netstack && m.localForwarding {
return nil
}
m.forwarder.Stop()
m.forwarder = nil
log.Debug("forwarder stopped")
return nil
}

View File

@@ -1,9 +1,12 @@
//go:build uspbench
package uspfilter package uspfilter
import ( import (
"fmt" "fmt"
"math/rand" "math/rand"
"net" "net"
"net/netip"
"os" "os"
"strings" "strings"
"testing" "testing"
@@ -155,7 +158,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Create manager and basic setup // Create manager and basic setup
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}) }, false)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Reset(nil))
}) })
@@ -185,7 +188,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Measure inbound packet processing // Measure inbound packet processing
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, manager.incomingRules) manager.dropFilter(inbound)
} }
}) })
} }
@@ -200,7 +203,7 @@ func BenchmarkStateScaling(b *testing.B) {
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) { b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}) }, false)
b.Cleanup(func() { b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Reset(nil))
}) })
@@ -228,7 +231,7 @@ func BenchmarkStateScaling(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(testIn, manager.incomingRules) manager.dropFilter(testIn)
} }
}) })
} }
@@ -248,7 +251,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
b.Run(sc.name, func(b *testing.B) { b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}) }, false)
b.Cleanup(func() { b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Reset(nil))
}) })
@@ -269,7 +272,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, manager.incomingRules) manager.dropFilter(inbound)
} }
}) })
} }
@@ -447,7 +450,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
b.Run(sc.name, func(b *testing.B) { b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}) }, false)
b.Cleanup(func() { b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Reset(nil))
}) })
@@ -472,7 +475,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
manager.processOutgoingHooks(syn) manager.processOutgoingHooks(syn)
// SYN-ACK // SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck)) synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, manager.incomingRules) manager.dropFilter(synack)
// ACK // ACK
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)) ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack) manager.processOutgoingHooks(ack)
@@ -481,7 +484,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, manager.incomingRules) manager.dropFilter(inbound)
} }
}) })
} }
@@ -574,7 +577,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}) }, false)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Reset(nil))
}) })
@@ -618,7 +621,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// SYN-ACK // SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, manager.incomingRules) manager.dropFilter(synack)
// ACK // ACK
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
@@ -646,7 +649,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// First outbound data // First outbound data
manager.processOutgoingHooks(outPackets[connIdx]) manager.processOutgoingHooks(outPackets[connIdx])
// Then inbound response - this is what we're actually measuring // Then inbound response - this is what we're actually measuring
manager.dropFilter(inPackets[connIdx], manager.incomingRules) manager.dropFilter(inPackets[connIdx])
} }
}) })
} }
@@ -665,7 +668,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}) }, false)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Reset(nil))
}) })
@@ -754,17 +757,17 @@ func BenchmarkShortLivedConnections(b *testing.B) {
// Connection establishment // Connection establishment
manager.processOutgoingHooks(p.syn) manager.processOutgoingHooks(p.syn)
manager.dropFilter(p.synAck, manager.incomingRules) manager.dropFilter(p.synAck)
manager.processOutgoingHooks(p.ack) manager.processOutgoingHooks(p.ack)
// Data transfer // Data transfer
manager.processOutgoingHooks(p.request) manager.processOutgoingHooks(p.request)
manager.dropFilter(p.response, manager.incomingRules) manager.dropFilter(p.response)
// Connection teardown // Connection teardown
manager.processOutgoingHooks(p.finClient) manager.processOutgoingHooks(p.finClient)
manager.dropFilter(p.ackServer, manager.incomingRules) manager.dropFilter(p.ackServer)
manager.dropFilter(p.finServer, manager.incomingRules) manager.dropFilter(p.finServer)
manager.processOutgoingHooks(p.ackClient) manager.processOutgoingHooks(p.ackClient)
} }
}) })
@@ -784,7 +787,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}) }, false)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Reset(nil))
}) })
@@ -825,7 +828,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, manager.incomingRules) manager.dropFilter(synack)
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck)) uint16(1024+i), 80, uint16(conntrack.TCPAck))
@@ -852,7 +855,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
// Simulate bidirectional traffic // Simulate bidirectional traffic
manager.processOutgoingHooks(outPackets[connIdx]) manager.processOutgoingHooks(outPackets[connIdx])
manager.dropFilter(inPackets[connIdx], manager.incomingRules) manager.dropFilter(inPackets[connIdx])
} }
}) })
}) })
@@ -872,7 +875,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}) }, false)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Reset(nil))
}) })
@@ -949,15 +952,15 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
// Full connection lifecycle // Full connection lifecycle
manager.processOutgoingHooks(p.syn) manager.processOutgoingHooks(p.syn)
manager.dropFilter(p.synAck, manager.incomingRules) manager.dropFilter(p.synAck)
manager.processOutgoingHooks(p.ack) manager.processOutgoingHooks(p.ack)
manager.processOutgoingHooks(p.request) manager.processOutgoingHooks(p.request)
manager.dropFilter(p.response, manager.incomingRules) manager.dropFilter(p.response)
manager.processOutgoingHooks(p.finClient) manager.processOutgoingHooks(p.finClient)
manager.dropFilter(p.ackServer, manager.incomingRules) manager.dropFilter(p.ackServer)
manager.dropFilter(p.finServer, manager.incomingRules) manager.dropFilter(p.finServer)
manager.processOutgoingHooks(p.ackClient) manager.processOutgoingHooks(p.ackClient)
} }
}) })
@@ -996,3 +999,72 @@ func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstP
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test"))) require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test")))
return buf.Bytes() return buf.Bytes()
} }
func BenchmarkRouteACLs(b *testing.B) {
manager := setupRoutedManager(b, "10.10.0.100/16")
// Add several route rules to simulate real-world scenario
rules := []struct {
sources []netip.Prefix
dest netip.Prefix
proto fw.Protocol
port *fw.Port
}{
{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP,
port: &fw.Port{Values: []uint16{80, 443}},
},
{
sources: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/12"),
netip.MustParsePrefix("10.0.0.0/8"),
},
dest: netip.MustParsePrefix("0.0.0.0/0"),
proto: fw.ProtocolICMP,
},
{
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
dest: netip.MustParsePrefix("192.168.0.0/16"),
proto: fw.ProtocolUDP,
port: &fw.Port{Values: []uint16{53}},
},
}
for _, r := range rules {
_, err := manager.AddRouteFiltering(
r.sources,
r.dest,
r.proto,
nil,
r.port,
fw.ActionAccept,
)
if err != nil {
b.Fatal(err)
}
}
// Test cases that exercise different matching scenarios
cases := []struct {
srcIP string
dstIP string
proto fw.Protocol
dstPort uint16
}{
{"100.10.0.1", "192.168.1.100", fw.ProtocolTCP, 443}, // Match first rule
{"172.16.0.1", "8.8.8.8", fw.ProtocolICMP, 0}, // Match second rule
{"1.1.1.1", "192.168.1.53", fw.ProtocolUDP, 53}, // Match third rule
{"192.168.1.1", "10.0.0.1", fw.ProtocolTCP, 8080}, // No match
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
for _, tc := range cases {
srcIP := net.ParseIP(tc.srcIP)
dstIP := net.ParseIP(tc.dstIP)
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -9,17 +9,38 @@ import (
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
wgdevice "golang.zx2c4.com/wireguard/device"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
) )
var logger = log.NewFromLogrus(logrus.StandardLogger())
type IFaceMock struct { type IFaceMock struct {
SetFilterFunc func(device.PacketFilter) error SetFilterFunc func(device.PacketFilter) error
AddressFunc func() iface.WGAddress AddressFunc func() iface.WGAddress
GetWGDeviceFunc func() *wgdevice.Device
GetDeviceFunc func() *device.FilteredDevice
}
func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
if i.GetWGDeviceFunc == nil {
return nil
}
return i.GetWGDeviceFunc()
}
func (i *IFaceMock) GetDevice() *device.FilteredDevice {
if i.GetDeviceFunc == nil {
return nil
}
return i.GetDeviceFunc()
} }
func (i *IFaceMock) SetFilter(iface device.PacketFilter) error { func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
@@ -41,7 +62,7 @@ func TestManagerCreate(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock) m, err := Create(ifaceMock, false)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -61,7 +82,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
}, },
} }
m, err := Create(ifaceMock) m, err := Create(ifaceMock, false)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -95,7 +116,7 @@ func TestManagerDeleteRule(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock) m, err := Create(ifaceMock, false)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -166,12 +187,12 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}) }, false)
require.NoError(t, err) require.NoError(t, err)
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
var addedRule Rule var addedRule PeerRule
if tt.in { if tt.in {
if len(manager.incomingRules[tt.ip.String()]) != 1 { if len(manager.incomingRules[tt.ip.String()]) != 1 {
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules)) t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
@@ -215,7 +236,7 @@ func TestManagerReset(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock) m, err := Create(ifaceMock, false)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -247,9 +268,18 @@ func TestManagerReset(t *testing.T) {
func TestNotMatchByIP(t *testing.T) { func TestNotMatchByIP(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
}
},
} }
m, err := Create(ifaceMock) m, err := Create(ifaceMock, false)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -298,7 +328,7 @@ func TestNotMatchByIP(t *testing.T) {
return return
} }
if m.dropFilter(buf.Bytes(), m.incomingRules) { if m.dropFilter(buf.Bytes()) {
t.Errorf("expected packet to be accepted") t.Errorf("expected packet to be accepted")
return return
} }
@@ -317,7 +347,7 @@ func TestRemovePacketHook(t *testing.T) {
} }
// creating manager instance // creating manager instance
manager, err := Create(iface) manager, err := Create(iface, false)
if err != nil { if err != nil {
t.Fatalf("Failed to create Manager: %s", err) t.Fatalf("Failed to create Manager: %s", err)
} }
@@ -363,7 +393,7 @@ func TestRemovePacketHook(t *testing.T) {
func TestProcessOutgoingHooks(t *testing.T) { func TestProcessOutgoingHooks(t *testing.T) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}) }, false)
require.NoError(t, err) require.NoError(t, err)
manager.wgNetwork = &net.IPNet{ manager.wgNetwork = &net.IPNet{
@@ -371,7 +401,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
Mask: net.CIDRMask(16, 32), Mask: net.CIDRMask(16, 32),
} }
manager.udpTracker.Close() manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond) manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger)
defer func() { defer func() {
require.NoError(t, manager.Reset(nil)) require.NoError(t, manager.Reset(nil))
}() }()
@@ -449,7 +479,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
manager, err := Create(ifaceMock) manager, err := Create(ifaceMock, false)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -476,7 +506,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
func TestStatefulFirewall_UDPTracking(t *testing.T) { func TestStatefulFirewall_UDPTracking(t *testing.T) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}) }, false)
require.NoError(t, err) require.NoError(t, err)
manager.wgNetwork = &net.IPNet{ manager.wgNetwork = &net.IPNet{
@@ -485,7 +515,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
} }
manager.udpTracker.Close() // Close the existing tracker manager.udpTracker.Close() // Close the existing tracker
manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond) manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger)
manager.decoders = sync.Pool{ manager.decoders = sync.Pool{
New: func() any { New: func() any {
d := &decoder{ d := &decoder{
@@ -606,7 +636,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
for _, cp := range checkPoints { for _, cp := range checkPoints {
time.Sleep(cp.sleep) time.Sleep(cp.sleep)
drop = manager.dropFilter(inboundBuf.Bytes(), manager.incomingRules) drop = manager.dropFilter(inboundBuf.Bytes())
require.Equal(t, cp.shouldAllow, !drop, cp.description) require.Equal(t, cp.shouldAllow, !drop, cp.description)
// If the connection should still be valid, verify it exists // If the connection should still be valid, verify it exists
@@ -677,7 +707,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Verify the invalid packet is dropped // Verify the invalid packet is dropped
drop = manager.dropFilter(testBuf.Bytes(), manager.incomingRules) drop = manager.dropFilter(testBuf.Bytes())
require.True(t, drop, tc.description) require.True(t, drop, tc.description)
}) })
} }

View File

@@ -362,7 +362,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
} }
func getFwmark() int { func getFwmark() int {
if runtime.GOOS == "linux" && !nbnet.CustomRoutingDisabled() { if nbnet.AdvancedRouting() {
return nbnet.NetbirdFwmark return nbnet.NetbirdFwmark
} }
return 0 return 0

View File

@@ -3,6 +3,8 @@
package iface package iface
import ( import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
) )
@@ -15,4 +17,5 @@ type WGTunDevice interface {
DeviceName() string DeviceName() string
Close() error Close() error
FilteredDevice() *device.FilteredDevice FilteredDevice() *device.FilteredDevice
Device() *wgdevice.Device
} }

View File

@@ -117,6 +117,11 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice return t.filteredDevice
} }
// Device returns the wireguard device
func (t *TunDevice) Device() *device.Device {
return t.device
}
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided // assignAddr Adds IP address to the tunnel interface and network route based on the range provided
func (t *TunDevice) assignAddr() error { func (t *TunDevice) assignAddr() error {
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String()) cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String())

View File

@@ -9,6 +9,7 @@ import (
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
@@ -151,6 +152,11 @@ func (t *TunKernelDevice) DeviceName() string {
return t.name return t.name
} }
// Device returns the wireguard device, not applicable for kernel devices
func (t *TunKernelDevice) Device() *device.Device {
return nil
}
func (t *TunKernelDevice) FilteredDevice() *FilteredDevice { func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
return nil return nil
} }

View File

@@ -117,3 +117,8 @@ func (t *TunNetstackDevice) DeviceName() string {
func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice { func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice return t.filteredDevice
} }
// Device returns the wireguard device
func (t *TunNetstackDevice) Device() *device.Device {
return t.device
}

View File

@@ -124,6 +124,11 @@ func (t *USPDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice return t.filteredDevice
} }
// Device returns the wireguard device
func (t *USPDevice) Device() *device.Device {
return t.device
}
// assignAddr Adds IP address to the tunnel interface // assignAddr Adds IP address to the tunnel interface
func (t *USPDevice) assignAddr() error { func (t *USPDevice) assignAddr() error {
link := newWGLink(t.name) link := newWGLink(t.name)

View File

@@ -150,6 +150,11 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice return t.filteredDevice
} }
// Device returns the wireguard device
func (t *TunDevice) Device() *device.Device {
return t.device
}
func (t *TunDevice) GetInterfaceGUIDString() (string, error) { func (t *TunDevice) GetInterfaceGUIDString() (string, error) {
if t.nativeTunDevice == nil { if t.nativeTunDevice == nil {
return "", fmt.Errorf("interface has not been initialized yet") return "", fmt.Errorf("interface has not been initialized yet")

View File

@@ -1,6 +1,8 @@
package iface package iface
import ( import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
) )
@@ -13,4 +15,5 @@ type WGTunDevice interface {
DeviceName() string DeviceName() string
Close() error Close() error
FilteredDevice() *device.FilteredDevice FilteredDevice() *device.FilteredDevice
Device() *wgdevice.Device
} }

View File

@@ -11,6 +11,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
@@ -203,6 +205,11 @@ func (w *WGIface) GetDevice() *device.FilteredDevice {
return w.tun.FilteredDevice() return w.tun.FilteredDevice()
} }
// GetWGDevice returns the WireGuard device
func (w *WGIface) GetWGDevice() *wgdevice.Device {
return w.tun.Device()
}
// GetStats returns the last handshake time, rx and tx bytes for the given peer // GetStats returns the last handshake time, rx and tx bytes for the given peer
func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) { func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
return w.configurer.GetStats(peerKey) return w.configurer.GetStats(peerKey)

View File

@@ -4,6 +4,7 @@ import (
"net" "net"
"time" "time"
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
@@ -29,6 +30,7 @@ type MockWGIface struct {
SetFilterFunc func(filter device.PacketFilter) error SetFilterFunc func(filter device.PacketFilter) error
GetFilterFunc func() device.PacketFilter GetFilterFunc func() device.PacketFilter
GetDeviceFunc func() *device.FilteredDevice GetDeviceFunc func() *device.FilteredDevice
GetWGDeviceFunc func() *wgdevice.Device
GetStatsFunc func(peerKey string) (configurer.WGStats, error) GetStatsFunc func(peerKey string) (configurer.WGStats, error)
GetInterfaceGUIDStringFunc func() (string, error) GetInterfaceGUIDStringFunc func() (string, error)
GetProxyFunc func() wgproxy.Proxy GetProxyFunc func() wgproxy.Proxy
@@ -102,11 +104,14 @@ func (m *MockWGIface) GetDevice() *device.FilteredDevice {
return m.GetDeviceFunc() return m.GetDeviceFunc()
} }
func (m *MockWGIface) GetWGDevice() *wgdevice.Device {
return m.GetWGDeviceFunc()
}
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) { func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
return m.GetStatsFunc(peerKey) return m.GetStatsFunc(peerKey)
} }
func (m *MockWGIface) GetProxy() wgproxy.Proxy { func (m *MockWGIface) GetProxy() wgproxy.Proxy {
//TODO implement me return m.GetProxyFunc()
panic("implement me")
} }

View File

@@ -6,6 +6,7 @@ import (
"net" "net"
"time" "time"
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
@@ -32,5 +33,6 @@ type IWGIface interface {
SetFilter(filter device.PacketFilter) error SetFilter(filter device.PacketFilter) error
GetFilter() device.PacketFilter GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice GetDevice() *device.FilteredDevice
GetWGDevice() *wgdevice.Device
GetStats(peerKey string) (configurer.WGStats, error) GetStats(peerKey string) (configurer.WGStats, error)
} }

View File

@@ -4,6 +4,7 @@ import (
"net" "net"
"time" "time"
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
@@ -30,6 +31,7 @@ type IWGIface interface {
SetFilter(filter device.PacketFilter) error SetFilter(filter device.PacketFilter) error
GetFilter() device.PacketFilter GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice GetDevice() *device.FilteredDevice
GetWGDevice() *wgdevice.Device
GetStats(peerKey string) (configurer.WGStats, error) GetStats(peerKey string) (configurer.WGStats, error)
GetInterfaceGUIDString() (string, error) GetInterfaceGUIDString() (string, error)
} }

View File

@@ -49,9 +49,10 @@ func TestDefaultManager(t *testing.T) {
IP: ip, IP: ip,
Network: network, Network: network,
}).AnyTimes() }).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it // we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil) fw, err := firewall.NewFirewall(ifaceMock, nil, false)
if err != nil { if err != nil {
t.Errorf("create firewall: %v", err) t.Errorf("create firewall: %v", err)
return return
@@ -342,9 +343,10 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
IP: ip, IP: ip,
Network: network, Network: network,
}).AnyTimes() }).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it // we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil) fw, err := firewall.NewFirewall(ifaceMock, nil, false)
if err != nil { if err != nil {
t.Errorf("create firewall: %v", err) t.Errorf("create firewall: %v", err)
return return

View File

@@ -8,6 +8,8 @@ import (
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
wgdevice "golang.zx2c4.com/wireguard/device"
iface "github.com/netbirdio/netbird/client/iface" iface "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
) )
@@ -90,3 +92,31 @@ func (mr *MockIFaceMapperMockRecorder) SetFilter(arg0 interface{}) *gomock.Call
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFilter", reflect.TypeOf((*MockIFaceMapper)(nil).SetFilter), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFilter", reflect.TypeOf((*MockIFaceMapper)(nil).SetFilter), arg0)
} }
// GetDevice mocks base method.
func (m *MockIFaceMapper) GetDevice() *device.FilteredDevice {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetDevice")
ret0, _ := ret[0].(*device.FilteredDevice)
return ret0
}
// GetDevice indicates an expected call of GetDevice.
func (mr *MockIFaceMapperMockRecorder) GetDevice() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDevice", reflect.TypeOf((*MockIFaceMapper)(nil).GetDevice))
}
// GetWGDevice mocks base method.
func (m *MockIFaceMapper) GetWGDevice() *wgdevice.Device {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetWGDevice")
ret0, _ := ret[0].(*wgdevice.Device)
return ret0
}
// GetWGDevice indicates an expected call of GetWGDevice.
func (mr *MockIFaceMapperMockRecorder) GetWGDevice() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWGDevice", reflect.TypeOf((*MockIFaceMapper)(nil).GetWGDevice))
}

View File

@@ -31,6 +31,7 @@ import (
relayClient "github.com/netbirdio/netbird/relay/client" relayClient "github.com/netbirdio/netbird/relay/client"
signal "github.com/netbirdio/netbird/signal/client" signal "github.com/netbirdio/netbird/signal/client"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
@@ -109,6 +110,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH) log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
nbnet.Init()
backOff := &backoff.ExponentialBackOff{ backOff := &backoff.ExponentialBackOff{
InitialInterval: time.Second, InitialInterval: time.Second,
RandomizationFactor: 1, RandomizationFactor: 1,

View File

@@ -12,7 +12,7 @@ import (
const ( const (
PriorityDNSRoute = 100 PriorityDNSRoute = 100
PriorityMatchDomain = 50 PriorityMatchDomain = 50
PriorityDefault = 0 PriorityDefault = 1
) )
type SubdomainMatcher interface { type SubdomainMatcher interface {
@@ -26,7 +26,6 @@ type HandlerEntry struct {
Pattern string Pattern string
OrigPattern string OrigPattern string
IsWildcard bool IsWildcard bool
StopHandler handlerWithStop
MatchSubdomains bool MatchSubdomains bool
} }
@@ -64,7 +63,7 @@ func (w *ResponseWriterChain) GetOrigPattern() string {
} }
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority // AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) { func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@@ -78,9 +77,6 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
// First remove any existing handler with same pattern (case-insensitive) and priority // First remove any existing handler with same pattern (case-insensitive) and priority
for i := len(c.handlers) - 1; i >= 0; i-- { for i := len(c.handlers) - 1; i >= 0; i-- {
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority { if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
if c.handlers[i].StopHandler != nil {
c.handlers[i].StopHandler.stop()
}
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...) c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
break break
} }
@@ -101,7 +97,6 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
Pattern: pattern, Pattern: pattern,
OrigPattern: origPattern, OrigPattern: origPattern,
IsWildcard: isWildcard, IsWildcard: isWildcard,
StopHandler: stopHandler,
MatchSubdomains: matchSubdomains, MatchSubdomains: matchSubdomains,
} }
@@ -142,9 +137,6 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
for i := len(c.handlers) - 1; i >= 0; i-- { for i := len(c.handlers) - 1; i >= 0; i-- {
entry := c.handlers[i] entry := c.handlers[i]
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority { if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
if entry.StopHandler != nil {
entry.StopHandler.stop()
}
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...) c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
return return
} }
@@ -180,8 +172,8 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if log.IsLevelEnabled(log.TraceLevel) { if log.IsLevelEnabled(log.TraceLevel) {
log.Tracef("current handlers (%d):", len(handlers)) log.Tracef("current handlers (%d):", len(handlers))
for _, h := range handlers { for _, h := range handlers {
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v priority=%d", log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d",
h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority) h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority)
} }
} }
@@ -206,13 +198,13 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
if !matched { if !matched {
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v matched=false", log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d matched=false",
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard) qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard, entry.Priority)
continue continue
} }
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v", log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d",
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains) qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
chainWriter := &ResponseWriterChain{ chainWriter := &ResponseWriterChain{
ResponseWriter: w, ResponseWriter: w,

View File

@@ -21,9 +21,9 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
dnsRouteHandler := &nbdns.MockHandler{} dnsRouteHandler := &nbdns.MockHandler{}
// Setup handlers with different priorities // Setup handlers with different priorities
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil) chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault)
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain, nil) chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain)
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil) chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute)
// Create test request // Create test request
r := new(dns.Msg) r := new(dns.Msg)
@@ -138,7 +138,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
pattern = "*." + tt.handlerDomain[2:] pattern = "*." + tt.handlerDomain[2:]
} }
chain.AddHandler(pattern, handler, nbdns.PriorityDefault, nil) chain.AddHandler(pattern, handler, nbdns.PriorityDefault)
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA) r.SetQuestion(tt.queryDomain, dns.TypeA)
@@ -253,7 +253,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe() handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe()
} }
chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority, nil) chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority)
} }
// Create and execute request // Create and execute request
@@ -280,9 +280,9 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
handler3 := &nbdns.MockHandler{} handler3 := &nbdns.MockHandler{}
// Add handlers in priority order // Add handlers in priority order
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil) chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute)
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil) chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain)
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil) chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault)
// Create test request // Create test request
r := new(dns.Msg) r := new(dns.Msg)
@@ -416,7 +416,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
if op.action == "add" { if op.action == "add" {
handler := &nbdns.MockHandler{} handler := &nbdns.MockHandler{}
handlers[op.priority] = handler handlers[op.priority] = handler
chain.AddHandler(op.pattern, handler, op.priority, nil) chain.AddHandler(op.pattern, handler, op.priority)
} else { } else {
chain.RemoveHandler(op.pattern, op.priority) chain.RemoveHandler(op.pattern, op.priority)
} }
@@ -471,9 +471,9 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
r.SetQuestion(testQuery, dns.TypeA) r.SetQuestion(testQuery, dns.TypeA)
// Add handlers in mixed order // Add handlers in mixed order
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault, nil) chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute, nil) chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain, nil) chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
// Test 1: Initial state with all three handlers // Test 1: Initial state with all three handlers
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
@@ -653,7 +653,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
handler = mockHandler handler = mockHandler
} }
chain.AddHandler(pattern, handler, h.priority, nil) chain.AddHandler(pattern, handler, h.priority)
} }
// Execute request // Execute request
@@ -795,7 +795,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
if op.action == "add" { if op.action == "add" {
handler := &nbdns.MockSubdomainHandler{Subdomains: op.subdomain} handler := &nbdns.MockSubdomainHandler{Subdomains: op.subdomain}
handlers[op.pattern] = handler handlers[op.pattern] = handler
chain.AddHandler(op.pattern, handler, op.priority, nil) chain.AddHandler(op.pattern, handler, op.priority)
} else { } else {
chain.RemoveHandler(op.pattern, op.priority) chain.RemoveHandler(op.pattern, op.priority)
} }

View File

@@ -1,35 +1,51 @@
package dns package dns
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"strings" "strings"
"syscall"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows/registry" "golang.org/x/sys/windows/registry"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
var (
userenv = syscall.NewLazyDLL("userenv.dll")
// https://learn.microsoft.com/en-us/windows/win32/api/userenv/nf-userenv-refreshpolicyex
refreshPolicyExFn = userenv.NewProc("RefreshPolicyEx")
)
const ( const (
dnsPolicyConfigMatchPath = `SYSTEM\CurrentControlSet\Services\Dnscache\Parameters\DnsPolicyConfig\NetBird-Match` dnsPolicyConfigMatchPath = `SYSTEM\CurrentControlSet\Services\Dnscache\Parameters\DnsPolicyConfig\NetBird-Match`
gpoDnsPolicyRoot = `SOFTWARE\Policies\Microsoft\Windows NT\DNSClient`
gpoDnsPolicyConfigMatchPath = gpoDnsPolicyRoot + `\DnsPolicyConfig\NetBird-Match`
dnsPolicyConfigVersionKey = "Version" dnsPolicyConfigVersionKey = "Version"
dnsPolicyConfigVersionValue = 2 dnsPolicyConfigVersionValue = 2
dnsPolicyConfigNameKey = "Name" dnsPolicyConfigNameKey = "Name"
dnsPolicyConfigGenericDNSServersKey = "GenericDNSServers" dnsPolicyConfigGenericDNSServersKey = "GenericDNSServers"
dnsPolicyConfigConfigOptionsKey = "ConfigOptions" dnsPolicyConfigConfigOptionsKey = "ConfigOptions"
dnsPolicyConfigConfigOptionsValue = 0x8 dnsPolicyConfigConfigOptionsValue = 0x8
)
const (
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces` interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
interfaceConfigNameServerKey = "NameServer" interfaceConfigNameServerKey = "NameServer"
interfaceConfigSearchListKey = "SearchList" interfaceConfigSearchListKey = "SearchList"
// RP_FORCE: Reapply all policies even if no policy change was detected
rpForce = 0x1
) )
type registryConfigurator struct { type registryConfigurator struct {
guid string guid string
routingAll bool routingAll bool
gpo bool
} }
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) { func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
@@ -37,12 +53,20 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return newHostManagerWithGuid(guid)
}
func newHostManagerWithGuid(guid string) (*registryConfigurator, error) { var useGPO bool
k, err := registry.OpenKey(registry.LOCAL_MACHINE, gpoDnsPolicyRoot, registry.QUERY_VALUE)
if err != nil {
log.Debugf("failed to open GPO DNS policy root: %v", err)
} else {
closer(k)
useGPO = true
log.Infof("detected GPO DNS policy configuration, using policy store")
}
return &registryConfigurator{ return &registryConfigurator{
guid: guid, guid: guid,
gpo: useGPO,
}, nil }, nil
} }
@@ -51,30 +75,23 @@ func (r *registryConfigurator) supportCustomPort() bool {
} }
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
var err error
if config.RouteAll { if config.RouteAll {
err = r.addDNSSetupForAll(config.ServerIP) if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
if err != nil {
return fmt.Errorf("add dns setup: %w", err) return fmt.Errorf("add dns setup: %w", err)
} }
} else if r.routingAll { } else if r.routingAll {
err = r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey) if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey); err != nil {
if err != nil {
return fmt.Errorf("delete interface registry key property: %w", err) return fmt.Errorf("delete interface registry key property: %w", err)
} }
r.routingAll = false r.routingAll = false
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP) log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
} }
if err := stateManager.UpdateState(&ShutdownState{Guid: r.guid}); err != nil { if err := stateManager.UpdateState(&ShutdownState{Guid: r.guid, GPO: r.gpo}); err != nil {
log.Errorf("failed to update shutdown state: %s", err) log.Errorf("failed to update shutdown state: %s", err)
} }
var ( var searchDomains, matchDomains []string
searchDomains []string
matchDomains []string
)
for _, dConf := range config.Domains { for _, dConf := range config.Domains {
if dConf.Disabled { if dConf.Disabled {
continue continue
@@ -86,16 +103,16 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
} }
if len(matchDomains) != 0 { if len(matchDomains) != 0 {
err = r.addDNSMatchPolicy(matchDomains, config.ServerIP) if err := r.addDNSMatchPolicy(matchDomains, config.ServerIP); err != nil {
return fmt.Errorf("add dns match policy: %w", err)
}
} else { } else {
err = removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath) if err := r.removeDNSMatchPolicies(); err != nil {
} return fmt.Errorf("remove dns match policies: %w", err)
if err != nil { }
return fmt.Errorf("add dns match policy: %w", err)
} }
err = r.updateSearchDomains(searchDomains) if err := r.updateSearchDomains(searchDomains); err != nil {
if err != nil {
return fmt.Errorf("update search domains: %w", err) return fmt.Errorf("update search domains: %w", err)
} }
@@ -103,9 +120,8 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
} }
func (r *registryConfigurator) addDNSSetupForAll(ip string) error { func (r *registryConfigurator) addDNSSetupForAll(ip string) error {
err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip) if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip); err != nil {
if err != nil { return fmt.Errorf("adding dns setup for all failed: %w", err)
return fmt.Errorf("adding dns setup for all failed with error: %w", err)
} }
r.routingAll = true r.routingAll = true
log.Infof("configured %s:53 as main DNS forwarder for this peer", ip) log.Infof("configured %s:53 as main DNS forwarder for this peer", ip)
@@ -113,64 +129,54 @@ func (r *registryConfigurator) addDNSSetupForAll(ip string) error {
} }
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) error { func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) error {
_, err := registry.OpenKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath, registry.QUERY_VALUE) // if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
if err == nil { // see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
err = registry.DeleteKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath) policyPath := dnsPolicyConfigMatchPath
if err != nil { if r.gpo {
return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %w", dnsPolicyConfigMatchPath, err) policyPath = gpoDnsPolicyConfigMatchPath
}
if err := removeRegistryKeyFromDNSPolicyConfig(policyPath); err != nil {
return fmt.Errorf("remove existing dns policy: %w", err)
}
regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE)
if err != nil {
return fmt.Errorf("create registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err)
}
defer closer(regKey)
if err := regKey.SetDWordValue(dnsPolicyConfigVersionKey, dnsPolicyConfigVersionValue); err != nil {
return fmt.Errorf("set %s: %w", dnsPolicyConfigVersionKey, err)
}
if err := regKey.SetStringsValue(dnsPolicyConfigNameKey, domains); err != nil {
return fmt.Errorf("set %s: %w", dnsPolicyConfigNameKey, err)
}
if err := regKey.SetStringValue(dnsPolicyConfigGenericDNSServersKey, ip); err != nil {
return fmt.Errorf("set %s: %w", dnsPolicyConfigGenericDNSServersKey, err)
}
if err := regKey.SetDWordValue(dnsPolicyConfigConfigOptionsKey, dnsPolicyConfigConfigOptionsValue); err != nil {
return fmt.Errorf("set %s: %w", dnsPolicyConfigConfigOptionsKey, err)
}
if r.gpo {
if err := refreshGroupPolicy(); err != nil {
log.Warnf("failed to refresh group policy: %v", err)
} }
} }
regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath, registry.SET_VALUE) log.Infof("added %d match domains. Domain list: %s", len(domains), domains)
if err != nil {
return fmt.Errorf("unable to create registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %w", dnsPolicyConfigMatchPath, err)
}
err = regKey.SetDWordValue(dnsPolicyConfigVersionKey, dnsPolicyConfigVersionValue)
if err != nil {
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigVersionKey, err)
}
err = regKey.SetStringsValue(dnsPolicyConfigNameKey, domains)
if err != nil {
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigNameKey, err)
}
err = regKey.SetStringValue(dnsPolicyConfigGenericDNSServersKey, ip)
if err != nil {
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigGenericDNSServersKey, err)
}
err = regKey.SetDWordValue(dnsPolicyConfigConfigOptionsKey, dnsPolicyConfigConfigOptionsValue)
if err != nil {
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigConfigOptionsKey, err)
}
log.Infof("added %d match domains to the state. Domain list: %s", len(domains), domains)
return nil
}
func (r *registryConfigurator) restoreHostDNS() error {
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
log.Errorf("remove registry key from dns policy config: %s", err)
}
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigSearchListKey); err != nil {
return fmt.Errorf("remove interface registry key: %w", err)
}
return nil return nil
} }
func (r *registryConfigurator) updateSearchDomains(domains []string) error { func (r *registryConfigurator) updateSearchDomains(domains []string) error {
err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ",")) if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ",")); err != nil {
if err != nil { return fmt.Errorf("update search domains: %w", err)
return fmt.Errorf("adding search domain failed with error: %w", err)
} }
log.Infof("updated search domains: %s", domains)
log.Infof("updated the search domains in the registry with %d domains. Domain list: %s", len(domains), domains)
return nil return nil
} }
@@ -181,11 +187,9 @@ func (r *registryConfigurator) setInterfaceRegistryKeyStringValue(key, value str
} }
defer closer(regKey) defer closer(regKey)
err = regKey.SetStringValue(key, value) if err := regKey.SetStringValue(key, value); err != nil {
if err != nil { return fmt.Errorf("set key %s=%s: %w", key, value, err)
return fmt.Errorf("applying key %s with value \"%s\" for interface failed with error: %w", key, value, err)
} }
return nil return nil
} }
@@ -196,43 +200,91 @@ func (r *registryConfigurator) deleteInterfaceRegistryKeyProperty(propertyKey st
} }
defer closer(regKey) defer closer(regKey)
err = regKey.DeleteValue(propertyKey) if err := regKey.DeleteValue(propertyKey); err != nil {
if err != nil { return fmt.Errorf("delete registry key %s: %w", propertyKey, err)
return fmt.Errorf("deleting registry key %s for interface failed with error: %w", propertyKey, err)
} }
return nil return nil
} }
func (r *registryConfigurator) getInterfaceRegistryKey() (registry.Key, error) { func (r *registryConfigurator) getInterfaceRegistryKey() (registry.Key, error) {
var regKey registry.Key
regKeyPath := interfaceConfigPath + "\\" + r.guid regKeyPath := interfaceConfigPath + "\\" + r.guid
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.SET_VALUE) regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.SET_VALUE)
if err != nil { if err != nil {
return regKey, fmt.Errorf("unable to open the interface registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %w", regKeyPath, err) return regKey, fmt.Errorf("open HKEY_LOCAL_MACHINE\\%s: %w", regKeyPath, err)
} }
return regKey, nil return regKey, nil
} }
func (r *registryConfigurator) restoreUncleanShutdownDNS() error { func (r *registryConfigurator) restoreHostDNS() error {
if err := r.restoreHostDNS(); err != nil { if err := r.removeDNSMatchPolicies(); err != nil {
return fmt.Errorf("restoring dns via registry: %w", err) log.Errorf("remove dns match policies: %s", err)
} }
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigSearchListKey); err != nil {
return fmt.Errorf("remove interface registry key: %w", err)
}
return nil return nil
} }
func (r *registryConfigurator) removeDNSMatchPolicies() error {
var merr *multierror.Error
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove local registry key: %w", err))
}
if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove GPO registry key: %w", err))
}
if err := refreshGroupPolicy(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("refresh group policy: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *registryConfigurator) restoreUncleanShutdownDNS() error {
return r.restoreHostDNS()
}
func removeRegistryKeyFromDNSPolicyConfig(regKeyPath string) error { func removeRegistryKeyFromDNSPolicyConfig(regKeyPath string) error {
k, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.QUERY_VALUE) k, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.QUERY_VALUE)
if err == nil { if err != nil {
defer closer(k) log.Debugf("failed to open HKEY_LOCAL_MACHINE\\%s: %v", regKeyPath, err)
err = registry.DeleteKey(registry.LOCAL_MACHINE, regKeyPath) return nil
if err != nil {
return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %w", regKeyPath, err)
}
} }
closer(k)
if err := registry.DeleteKey(registry.LOCAL_MACHINE, regKeyPath); err != nil {
return fmt.Errorf("delete HKEY_LOCAL_MACHINE\\%s: %w", regKeyPath, err)
}
return nil
}
func refreshGroupPolicy() error {
// refreshPolicyExFn.Call() panics if the func is not found
defer func() {
if r := recover(); r != nil {
log.Errorf("Recovered from panic: %v", r)
}
}()
ret, _, err := refreshPolicyExFn.Call(
// bMachine = TRUE (computer policy)
uintptr(1),
// dwOptions = RP_FORCE
uintptr(rpForce),
)
if ret == 0 {
if err != nil && !errors.Is(err, syscall.Errno(0)) {
return fmt.Errorf("RefreshPolicyEx failed: %w", err)
}
return fmt.Errorf("RefreshPolicyEx failed")
}
return nil return nil
} }

View File

@@ -29,10 +29,15 @@ func (d *localResolver) String() string {
return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap)) return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap))
} }
// ID returns the unique handler ID
func (d *localResolver) id() handlerID {
return "local-resolver"
}
// ServeDNS handles a DNS request // ServeDNS handles a DNS request
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) > 0 { if len(r.Question) > 0 {
log.Tracef("received question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
} }
replyMessage := &dns.Msg{} replyMessage := &dns.Msg{}

View File

@@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"net/netip" "net/netip"
"runtime" "runtime"
"strings"
"sync" "sync"
"github.com/miekg/dns" "github.com/miekg/dns"
@@ -42,7 +41,12 @@ type Server interface {
ProbeAvailability() ProbeAvailability()
} }
type registeredHandlerMap map[string]handlerWithStop type handlerID string
type nsGroupsByDomain struct {
domain string
groups []*nbdns.NameServerGroup
}
// DefaultServer dns server object // DefaultServer dns server object
type DefaultServer struct { type DefaultServer struct {
@@ -52,7 +56,6 @@ type DefaultServer struct {
mux sync.Mutex mux sync.Mutex
service service service service
dnsMuxMap registeredHandlerMap dnsMuxMap registeredHandlerMap
handlerPriorities map[string]int
localResolver *localResolver localResolver *localResolver
wgInterface WGIface wgInterface WGIface
hostManager hostManager hostManager hostManager
@@ -77,14 +80,17 @@ type handlerWithStop interface {
dns.Handler dns.Handler
stop() stop()
probeAvailability() probeAvailability()
id() handlerID
} }
type muxUpdate struct { type handlerWrapper struct {
domain string domain string
handler handlerWithStop handler handlerWithStop
priority int priority int
} }
type registeredHandlerMap map[handlerID]handlerWrapper
// NewDefaultServer returns a new dns server // NewDefaultServer returns a new dns server
func NewDefaultServer( func NewDefaultServer(
ctx context.Context, ctx context.Context,
@@ -158,13 +164,12 @@ func newDefaultServer(
) *DefaultServer { ) *DefaultServer {
ctx, stop := context.WithCancel(ctx) ctx, stop := context.WithCancel(ctx)
defaultServer := &DefaultServer{ defaultServer := &DefaultServer{
ctx: ctx, ctx: ctx,
ctxCancel: stop, ctxCancel: stop,
disableSys: disableSys, disableSys: disableSys,
service: dnsService, service: dnsService,
handlerChain: NewHandlerChain(), handlerChain: NewHandlerChain(),
dnsMuxMap: make(registeredHandlerMap), dnsMuxMap: make(registeredHandlerMap),
handlerPriorities: make(map[string]int),
localResolver: &localResolver{ localResolver: &localResolver{
registeredMap: make(registrationMap), registeredMap: make(registrationMap),
}, },
@@ -192,8 +197,7 @@ func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, p
log.Warn("skipping empty domain") log.Warn("skipping empty domain")
continue continue
} }
s.handlerChain.AddHandler(domain, handler, priority, nil) s.handlerChain.AddHandler(domain, handler, priority)
s.handlerPriorities[domain] = priority
s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain) s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain)
} }
} }
@@ -209,14 +213,15 @@ func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
log.Debugf("deregistering handler %v with priority %d", domains, priority) log.Debugf("deregistering handler %v with priority %d", domains, priority)
for _, domain := range domains { for _, domain := range domains {
if domain == "" {
log.Warn("skipping empty domain")
continue
}
s.handlerChain.RemoveHandler(domain, priority) s.handlerChain.RemoveHandler(domain, priority)
// Only deregister from service if no handlers remain // Only deregister from service if no handlers remain
if !s.handlerChain.HasHandlers(domain) { if !s.handlerChain.HasHandlers(domain) {
if domain == "" {
log.Warn("skipping empty domain")
continue
}
s.service.DeregisterMux(nbdns.NormalizeZone(domain)) s.service.DeregisterMux(nbdns.NormalizeZone(domain))
} }
} }
@@ -283,14 +288,24 @@ func (s *DefaultServer) Stop() {
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones // OnUpdatedHostDNSServer update the DNS servers addresses for root zones
// It will be applied if the mgm server do not enforce DNS settings for root zone // It will be applied if the mgm server do not enforce DNS settings for root zone
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) { func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
s.hostsDNSHolder.set(hostsDnsList) s.hostsDNSHolder.set(hostsDnsList)
_, ok := s.dnsMuxMap[nbdns.RootZone] // Check if there's any root handler
if ok { var hasRootHandler bool
for _, handler := range s.dnsMuxMap {
if handler.domain == nbdns.RootZone {
hasRootHandler = true
break
}
}
if hasRootHandler {
log.Debugf("on new host DNS config but skip to apply it") log.Debugf("on new host DNS config but skip to apply it")
return return
} }
log.Debugf("update host DNS settings: %+v", hostsDnsList) log.Debugf("update host DNS settings: %+v", hostsDnsList)
s.addHostRootZone() s.addHostRootZone()
} }
@@ -364,7 +379,7 @@ func (s *DefaultServer) ProbeAvailability() {
go func(mux handlerWithStop) { go func(mux handlerWithStop) {
defer wg.Done() defer wg.Done()
mux.probeAvailability() mux.probeAvailability()
}(mux) }(mux.handler)
} }
wg.Wait() wg.Wait()
} }
@@ -419,8 +434,8 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
return nil return nil
} }
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) { func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, map[string]nbdns.SimpleRecord, error) {
var muxUpdates []muxUpdate var muxUpdates []handlerWrapper
localRecords := make(map[string]nbdns.SimpleRecord, 0) localRecords := make(map[string]nbdns.SimpleRecord, 0)
for _, customZone := range customZones { for _, customZone := range customZones {
@@ -428,7 +443,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
return nil, nil, fmt.Errorf("received an empty list of records") return nil, nil, fmt.Errorf("received an empty list of records")
} }
muxUpdates = append(muxUpdates, muxUpdate{ muxUpdates = append(muxUpdates, handlerWrapper{
domain: customZone.Domain, domain: customZone.Domain,
handler: s.localResolver, handler: s.localResolver,
priority: PriorityMatchDomain, priority: PriorityMatchDomain,
@@ -446,15 +461,59 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
return muxUpdates, localRecords, nil return muxUpdates, localRecords, nil
} }
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) { func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]handlerWrapper, error) {
var muxUpdates []handlerWrapper
var muxUpdates []muxUpdate
for _, nsGroup := range nameServerGroups { for _, nsGroup := range nameServerGroups {
if len(nsGroup.NameServers) == 0 { if len(nsGroup.NameServers) == 0 {
log.Warn("received a nameserver group with empty nameserver list") log.Warn("received a nameserver group with empty nameserver list")
continue continue
} }
if !nsGroup.Primary && len(nsGroup.Domains) == 0 {
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
}
for _, domain := range nsGroup.Domains {
if domain == "" {
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
}
}
}
groupedNS := groupNSGroupsByDomain(nameServerGroups)
for _, domainGroup := range groupedNS {
basePriority := PriorityMatchDomain
if domainGroup.domain == nbdns.RootZone {
basePriority = PriorityDefault
}
updates, err := s.createHandlersForDomainGroup(domainGroup, basePriority)
if err != nil {
return nil, err
}
muxUpdates = append(muxUpdates, updates...)
}
return muxUpdates, nil
}
func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomain, basePriority int) ([]handlerWrapper, error) {
var muxUpdates []handlerWrapper
for i, nsGroup := range domainGroup.groups {
// Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts
priority := basePriority - i
// Check if we're about to overlap with the next priority tier
if basePriority == PriorityMatchDomain && priority <= PriorityDefault {
log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers",
domainGroup.domain, PriorityMatchDomain-PriorityDefault)
break
}
log.Debugf("creating handler for domain=%s with priority=%d", domainGroup.domain, priority)
handler, err := newUpstreamResolver( handler, err := newUpstreamResolver(
s.ctx, s.ctx,
s.wgInterface.Name(), s.wgInterface.Name(),
@@ -462,10 +521,12 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
s.wgInterface.Address().Network, s.wgInterface.Address().Network,
s.statusRecorder, s.statusRecorder,
s.hostsDNSHolder, s.hostsDNSHolder,
domainGroup.domain,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to create a new upstream resolver, error: %v", err) return nil, fmt.Errorf("create upstream resolver: %v", err)
} }
for _, ns := range nsGroup.NameServers { for _, ns := range nsGroup.NameServers {
if ns.NSType != nbdns.UDPNameServerType { if ns.NSType != nbdns.UDPNameServerType {
log.Warnf("skipping nameserver %s with type %s, this peer supports only %s", log.Warnf("skipping nameserver %s with type %s, this peer supports only %s",
@@ -489,78 +550,47 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
// after some period defined by upstream it tries to reactivate self by calling this hook // after some period defined by upstream it tries to reactivate self by calling this hook
// everything we need here is just to re-apply current configuration because it already // everything we need here is just to re-apply current configuration because it already
// contains this upstream settings (temporal deactivation not removed it) // contains this upstream settings (temporal deactivation not removed it)
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler) handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler, priority)
if nsGroup.Primary { muxUpdates = append(muxUpdates, handlerWrapper{
muxUpdates = append(muxUpdates, muxUpdate{ domain: domainGroup.domain,
domain: nbdns.RootZone, handler: handler,
handler: handler, priority: priority,
priority: PriorityDefault, })
})
continue
}
if len(nsGroup.Domains) == 0 {
handler.stop()
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
}
for _, domain := range nsGroup.Domains {
if domain == "" {
handler.stop()
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
}
muxUpdates = append(muxUpdates, muxUpdate{
domain: domain,
handler: handler,
priority: PriorityMatchDomain,
})
}
} }
return muxUpdates, nil return muxUpdates, nil
} }
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) { func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
muxUpdateMap := make(registeredHandlerMap) // this will introduce a short period of time when the server is not able to handle DNS requests
handlersByPriority := make(map[string]int) for _, existing := range s.dnsMuxMap {
s.deregisterHandler([]string{existing.domain}, existing.priority)
var isContainRootUpdate bool existing.handler.stop()
// First register new handlers
for _, update := range muxUpdates {
s.registerHandler([]string{update.domain}, update.handler, update.priority)
muxUpdateMap[update.domain] = update.handler
handlersByPriority[update.domain] = update.priority
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
existingHandler.stop()
}
if update.domain == nbdns.RootZone {
isContainRootUpdate = true
}
} }
// Then deregister old handlers not in the update muxUpdateMap := make(registeredHandlerMap)
for key, existingHandler := range s.dnsMuxMap { var containsRootUpdate bool
_, found := muxUpdateMap[key]
if !found { for _, update := range muxUpdates {
if !isContainRootUpdate && key == nbdns.RootZone { if update.domain == nbdns.RootZone {
containsRootUpdate = true
}
s.registerHandler([]string{update.domain}, update.handler, update.priority)
muxUpdateMap[update.handler.id()] = update
}
// If there's no root update and we had a root handler, restore it
if !containsRootUpdate {
for _, existing := range s.dnsMuxMap {
if existing.domain == nbdns.RootZone {
s.addHostRootZone() s.addHostRootZone()
existingHandler.stop() break
} else {
existingHandler.stop()
// Deregister with the priority that was used to register
if oldPriority, ok := s.handlerPriorities[key]; ok {
s.deregisterHandler([]string{key}, oldPriority)
}
} }
} }
} }
s.dnsMuxMap = muxUpdateMap s.dnsMuxMap = muxUpdateMap
s.handlerPriorities = handlersByPriority
} }
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) { func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
@@ -593,6 +623,7 @@ func getNSHostPort(ns nbdns.NameServer) string {
func (s *DefaultServer) upstreamCallbacks( func (s *DefaultServer) upstreamCallbacks(
nsGroup *nbdns.NameServerGroup, nsGroup *nbdns.NameServerGroup,
handler dns.Handler, handler dns.Handler,
priority int,
) (deactivate func(error), reactivate func()) { ) (deactivate func(error), reactivate func()) {
var removeIndex map[string]int var removeIndex map[string]int
deactivate = func(err error) { deactivate = func(err error) {
@@ -609,13 +640,13 @@ func (s *DefaultServer) upstreamCallbacks(
if nsGroup.Primary { if nsGroup.Primary {
removeIndex[nbdns.RootZone] = -1 removeIndex[nbdns.RootZone] = -1
s.currentConfig.RouteAll = false s.currentConfig.RouteAll = false
s.deregisterHandler([]string{nbdns.RootZone}, PriorityDefault) s.deregisterHandler([]string{nbdns.RootZone}, priority)
} }
for i, item := range s.currentConfig.Domains { for i, item := range s.currentConfig.Domains {
if _, found := removeIndex[item.Domain]; found { if _, found := removeIndex[item.Domain]; found {
s.currentConfig.Domains[i].Disabled = true s.currentConfig.Domains[i].Disabled = true
s.deregisterHandler([]string{item.Domain}, PriorityMatchDomain) s.deregisterHandler([]string{item.Domain}, priority)
removeIndex[item.Domain] = i removeIndex[item.Domain] = i
} }
} }
@@ -635,8 +666,8 @@ func (s *DefaultServer) upstreamCallbacks(
} }
s.updateNSState(nsGroup, err, false) s.updateNSState(nsGroup, err, false)
} }
reactivate = func() { reactivate = func() {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
@@ -646,7 +677,7 @@ func (s *DefaultServer) upstreamCallbacks(
continue continue
} }
s.currentConfig.Domains[i].Disabled = false s.currentConfig.Domains[i].Disabled = false
s.registerHandler([]string{domain}, handler, PriorityMatchDomain) s.registerHandler([]string{domain}, handler, priority)
} }
l := log.WithField("nameservers", nsGroup.NameServers) l := log.WithField("nameservers", nsGroup.NameServers)
@@ -654,7 +685,7 @@ func (s *DefaultServer) upstreamCallbacks(
if nsGroup.Primary { if nsGroup.Primary {
s.currentConfig.RouteAll = true s.currentConfig.RouteAll = true
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault) s.registerHandler([]string{nbdns.RootZone}, handler, priority)
} }
if s.hostManager != nil { if s.hostManager != nil {
@@ -676,6 +707,7 @@ func (s *DefaultServer) addHostRootZone() {
s.wgInterface.Address().Network, s.wgInterface.Address().Network,
s.statusRecorder, s.statusRecorder,
s.hostsDNSHolder, s.hostsDNSHolder,
nbdns.RootZone,
) )
if err != nil { if err != nil {
log.Errorf("unable to create a new upstream resolver, error: %v", err) log.Errorf("unable to create a new upstream resolver, error: %v", err)
@@ -732,5 +764,34 @@ func generateGroupKey(nsGroup *nbdns.NameServerGroup) string {
for _, ns := range nsGroup.NameServers { for _, ns := range nsGroup.NameServers {
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port)) servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
} }
return fmt.Sprintf("%s_%s_%s", nsGroup.ID, nsGroup.Name, strings.Join(servers, ",")) return fmt.Sprintf("%v_%v", servers, nsGroup.Domains)
}
// groupNSGroupsByDomain groups nameserver groups by their match domains
func groupNSGroupsByDomain(nsGroups []*nbdns.NameServerGroup) []nsGroupsByDomain {
domainMap := make(map[string][]*nbdns.NameServerGroup)
for _, group := range nsGroups {
if group.Primary {
domainMap[nbdns.RootZone] = append(domainMap[nbdns.RootZone], group)
continue
}
for _, domain := range group.Domains {
if domain == "" {
continue
}
domainMap[domain] = append(domainMap[domain], group)
}
}
var result []nsGroupsByDomain
for domain, groups := range domainMap {
result = append(result, nsGroupsByDomain{
domain: domain,
groups: groups,
})
}
return result
} }

View File

@@ -13,6 +13,7 @@ import (
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -88,6 +89,18 @@ func init() {
formatter.SetTextFormatter(log.StandardLogger()) formatter.SetTextFormatter(log.StandardLogger())
} }
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
var srvs []string
for _, srv := range servers {
srvs = append(srvs, getNSHostPort(srv))
}
return &upstreamResolverBase{
domain: domain,
upstreamServers: srvs,
cancel: func() {},
}
}
func TestUpdateDNSServer(t *testing.T) { func TestUpdateDNSServer(t *testing.T) {
nameServers := []nbdns.NameServer{ nameServers := []nbdns.NameServer{
{ {
@@ -140,15 +153,37 @@ func TestUpdateDNSServer(t *testing.T) {
}, },
}, },
}, },
expectedUpstreamMap: registeredHandlerMap{"netbird.io": dummyHandler, "netbird.cloud": dummyHandler, nbdns.RootZone: dummyHandler}, expectedUpstreamMap: registeredHandlerMap{
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{
domain: "netbird.io",
handler: dummyHandler,
priority: PriorityMatchDomain,
},
dummyHandler.id(): handlerWrapper{
domain: "netbird.cloud",
handler: dummyHandler,
priority: PriorityMatchDomain,
},
generateDummyHandler(".", nameServers).id(): handlerWrapper{
domain: nbdns.RootZone,
handler: dummyHandler,
priority: PriorityDefault,
},
},
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
}, },
{ {
name: "New Config Should Succeed", name: "New Config Should Succeed",
initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
initUpstreamMap: registeredHandlerMap{buildRecordKey(zoneRecords[0].Name, 1, 1): dummyHandler}, initUpstreamMap: registeredHandlerMap{
initSerial: 0, generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
inputSerial: 1, domain: buildRecordKey(zoneRecords[0].Name, 1, 1),
handler: dummyHandler,
priority: PriorityMatchDomain,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{ inputUpdate: nbdns.Config{
ServiceEnable: true, ServiceEnable: true,
CustomZones: []nbdns.CustomZone{ CustomZones: []nbdns.CustomZone{
@@ -164,8 +199,19 @@ func TestUpdateDNSServer(t *testing.T) {
}, },
}, },
}, },
expectedUpstreamMap: registeredHandlerMap{"netbird.io": dummyHandler, "netbird.cloud": dummyHandler}, expectedUpstreamMap: registeredHandlerMap{
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{
domain: "netbird.io",
handler: dummyHandler,
priority: PriorityMatchDomain,
},
"local-resolver": handlerWrapper{
domain: "netbird.cloud",
handler: dummyHandler,
priority: PriorityMatchDomain,
},
},
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
}, },
{ {
name: "Smaller Config Serial Should Be Skipped", name: "Smaller Config Serial Should Be Skipped",
@@ -242,9 +288,15 @@ func TestUpdateDNSServer(t *testing.T) {
shouldFail: true, shouldFail: true,
}, },
{ {
name: "Empty Config Should Succeed and Clean Maps", name: "Empty Config Should Succeed and Clean Maps",
initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
initUpstreamMap: registeredHandlerMap{zoneRecords[0].Name: dummyHandler}, initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
domain: zoneRecords[0].Name,
handler: dummyHandler,
priority: PriorityMatchDomain,
},
},
initSerial: 0, initSerial: 0,
inputSerial: 1, inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: true}, inputUpdate: nbdns.Config{ServiceEnable: true},
@@ -252,9 +304,15 @@ func TestUpdateDNSServer(t *testing.T) {
expectedLocalMap: make(registrationMap), expectedLocalMap: make(registrationMap),
}, },
{ {
name: "Disabled Service Should clean map", name: "Disabled Service Should clean map",
initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
initUpstreamMap: registeredHandlerMap{zoneRecords[0].Name: dummyHandler}, initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
domain: zoneRecords[0].Name,
handler: dummyHandler,
priority: PriorityMatchDomain,
},
},
initSerial: 0, initSerial: 0,
inputSerial: 1, inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: false}, inputUpdate: nbdns.Config{ServiceEnable: false},
@@ -421,7 +479,13 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
} }
}() }()
dnsServer.dnsMuxMap = registeredHandlerMap{zoneRecords[0].Name: &localResolver{}} dnsServer.dnsMuxMap = registeredHandlerMap{
"id1": handlerWrapper{
domain: zoneRecords[0].Name,
handler: &localResolver{},
priority: PriorityMatchDomain,
},
}
dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}} dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}}
dnsServer.updateSerial = 0 dnsServer.updateSerial = 0
@@ -562,9 +626,8 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
localResolver: &localResolver{ localResolver: &localResolver{
registeredMap: make(registrationMap), registeredMap: make(registrationMap),
}, },
handlerChain: NewHandlerChain(), handlerChain: NewHandlerChain(),
handlerPriorities: make(map[string]int), hostManager: hostManager,
hostManager: hostManager,
currentConfig: HostDNSConfig{ currentConfig: HostDNSConfig{
Domains: []DomainConfig{ Domains: []DomainConfig{
{false, "domain0", false}, {false, "domain0", false},
@@ -593,7 +656,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
NameServers: []nbdns.NameServer{ NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53}, {IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
}, },
}, nil) }, nil, 0)
deactivate(nil) deactivate(nil)
expected := "domain0,domain2" expected := "domain0,domain2"
@@ -849,7 +912,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
return nil, err return nil, err
} }
pf, err := uspfilter.Create(wgIface) pf, err := uspfilter.Create(wgIface, false)
if err != nil { if err != nil {
t.Fatalf("failed to create uspfilter: %v", err) t.Fatalf("failed to create uspfilter: %v", err)
return nil, err return nil, err
@@ -903,8 +966,8 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
Subdomains: true, Subdomains: true,
} }
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute, nil) chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute)
chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain, nil) chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain)
testCases := []struct { testCases := []struct {
name string name string
@@ -959,3 +1022,421 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
}) })
} }
} }
type mockHandler struct {
Id string
}
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
func (m *mockHandler) stop() {}
func (m *mockHandler) probeAvailability() {}
func (m *mockHandler) id() handlerID { return handlerID(m.Id) }
type mockService struct{}
func (m *mockService) Listen() error { return nil }
func (m *mockService) Stop() {}
func (m *mockService) RuntimeIP() string { return "127.0.0.1" }
func (m *mockService) RuntimePort() int { return 53 }
func (m *mockService) RegisterMux(string, dns.Handler) {}
func (m *mockService) DeregisterMux(string) {}
func TestDefaultServer_UpdateMux(t *testing.T) {
baseMatchHandlers := registeredHandlerMap{
"upstream-group1": {
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityMatchDomain,
},
"upstream-group2": {
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityMatchDomain - 1,
},
}
baseRootHandlers := registeredHandlerMap{
"upstream-root1": {
domain: ".",
handler: &mockHandler{
Id: "upstream-root1",
},
priority: PriorityDefault,
},
"upstream-root2": {
domain: ".",
handler: &mockHandler{
Id: "upstream-root2",
},
priority: PriorityDefault - 1,
},
}
baseMixedHandlers := registeredHandlerMap{
"upstream-group1": {
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityMatchDomain,
},
"upstream-group2": {
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityMatchDomain - 1,
},
"upstream-other": {
domain: "other.com",
handler: &mockHandler{
Id: "upstream-other",
},
priority: PriorityMatchDomain,
},
}
tests := []struct {
name string
initialHandlers registeredHandlerMap
updates []handlerWrapper
expectedHandlers map[string]string // map[handlerID]domain
description string
}{
{
name: "Remove group1 from update",
initialHandlers: baseMatchHandlers,
updates: []handlerWrapper{
// Only group2 remains
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityMatchDomain - 1,
},
},
expectedHandlers: map[string]string{
"upstream-group2": "example.com",
},
description: "When group1 is not included in the update, it should be removed while group2 remains",
},
{
name: "Remove group2 from update",
initialHandlers: baseMatchHandlers,
updates: []handlerWrapper{
// Only group1 remains
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityMatchDomain,
},
},
expectedHandlers: map[string]string{
"upstream-group1": "example.com",
},
description: "When group2 is not included in the update, it should be removed while group1 remains",
},
{
name: "Add group3 in first position",
initialHandlers: baseMatchHandlers,
updates: []handlerWrapper{
// Add group3 with highest priority
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group3",
},
priority: PriorityMatchDomain + 1,
},
// Keep existing groups with their original priorities
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityMatchDomain,
},
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityMatchDomain - 1,
},
},
expectedHandlers: map[string]string{
"upstream-group1": "example.com",
"upstream-group2": "example.com",
"upstream-group3": "example.com",
},
description: "When adding group3 with highest priority, it should be first in chain while maintaining existing groups",
},
{
name: "Add group3 in last position",
initialHandlers: baseMatchHandlers,
updates: []handlerWrapper{
// Keep existing groups with their original priorities
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityMatchDomain,
},
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityMatchDomain - 1,
},
// Add group3 with lowest priority
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group3",
},
priority: PriorityMatchDomain - 2,
},
},
expectedHandlers: map[string]string{
"upstream-group1": "example.com",
"upstream-group2": "example.com",
"upstream-group3": "example.com",
},
description: "When adding group3 with lowest priority, it should be last in chain while maintaining existing groups",
},
// Root zone tests
{
name: "Remove root1 from update",
initialHandlers: baseRootHandlers,
updates: []handlerWrapper{
{
domain: ".",
handler: &mockHandler{
Id: "upstream-root2",
},
priority: PriorityDefault - 1,
},
},
expectedHandlers: map[string]string{
"upstream-root2": ".",
},
description: "When root1 is not included in the update, it should be removed while root2 remains",
},
{
name: "Remove root2 from update",
initialHandlers: baseRootHandlers,
updates: []handlerWrapper{
{
domain: ".",
handler: &mockHandler{
Id: "upstream-root1",
},
priority: PriorityDefault,
},
},
expectedHandlers: map[string]string{
"upstream-root1": ".",
},
description: "When root2 is not included in the update, it should be removed while root1 remains",
},
{
name: "Add root3 in first position",
initialHandlers: baseRootHandlers,
updates: []handlerWrapper{
{
domain: ".",
handler: &mockHandler{
Id: "upstream-root3",
},
priority: PriorityDefault + 1,
},
{
domain: ".",
handler: &mockHandler{
Id: "upstream-root1",
},
priority: PriorityDefault,
},
{
domain: ".",
handler: &mockHandler{
Id: "upstream-root2",
},
priority: PriorityDefault - 1,
},
},
expectedHandlers: map[string]string{
"upstream-root1": ".",
"upstream-root2": ".",
"upstream-root3": ".",
},
description: "When adding root3 with highest priority, it should be first in chain while maintaining existing root handlers",
},
{
name: "Add root3 in last position",
initialHandlers: baseRootHandlers,
updates: []handlerWrapper{
{
domain: ".",
handler: &mockHandler{
Id: "upstream-root1",
},
priority: PriorityDefault,
},
{
domain: ".",
handler: &mockHandler{
Id: "upstream-root2",
},
priority: PriorityDefault - 1,
},
{
domain: ".",
handler: &mockHandler{
Id: "upstream-root3",
},
priority: PriorityDefault - 2,
},
},
expectedHandlers: map[string]string{
"upstream-root1": ".",
"upstream-root2": ".",
"upstream-root3": ".",
},
description: "When adding root3 with lowest priority, it should be last in chain while maintaining existing root handlers",
},
// Mixed domain tests
{
name: "Update with mixed domains - remove one of duplicate domain",
initialHandlers: baseMixedHandlers,
updates: []handlerWrapper{
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityMatchDomain,
},
{
domain: "other.com",
handler: &mockHandler{
Id: "upstream-other",
},
priority: PriorityMatchDomain,
},
},
expectedHandlers: map[string]string{
"upstream-group1": "example.com",
"upstream-other": "other.com",
},
description: "When updating mixed domains, should correctly handle removal of one duplicate while maintaining other domains",
},
{
name: "Update with mixed domains - add new domain",
initialHandlers: baseMixedHandlers,
updates: []handlerWrapper{
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityMatchDomain,
},
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityMatchDomain - 1,
},
{
domain: "other.com",
handler: &mockHandler{
Id: "upstream-other",
},
priority: PriorityMatchDomain,
},
{
domain: "new.com",
handler: &mockHandler{
Id: "upstream-new",
},
priority: PriorityMatchDomain,
},
},
expectedHandlers: map[string]string{
"upstream-group1": "example.com",
"upstream-group2": "example.com",
"upstream-other": "other.com",
"upstream-new": "new.com",
},
description: "When updating mixed domains, should maintain existing duplicates and add new domain",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := &DefaultServer{
dnsMuxMap: tt.initialHandlers,
handlerChain: NewHandlerChain(),
service: &mockService{},
}
// Perform the update
server.updateMux(tt.updates)
// Verify the results
assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxMap),
"Number of handlers after update doesn't match expected")
// Check each expected handler
for id, expectedDomain := range tt.expectedHandlers {
handler, exists := server.dnsMuxMap[handlerID(id)]
assert.True(t, exists, "Expected handler %s not found", id)
if exists {
assert.Equal(t, expectedDomain, handler.domain,
"Domain mismatch for handler %s", id)
}
}
// Verify no unexpected handlers exist
for handlerID := range server.dnsMuxMap {
_, expected := tt.expectedHandlers[string(handlerID)]
assert.True(t, expected, "Unexpected handler found: %s", handlerID)
}
// Verify the handlerChain state and order
previousPriority := 0
for _, chainEntry := range server.handlerChain.handlers {
// Verify priority order
if previousPriority > 0 {
assert.True(t, chainEntry.Priority <= previousPriority,
"Handlers in chain not properly ordered by priority")
}
previousPriority = chainEntry.Priority
// Verify handler exists in mux
foundInMux := false
for _, muxEntry := range server.dnsMuxMap {
if chainEntry.Handler == muxEntry.handler &&
chainEntry.Priority == muxEntry.priority &&
chainEntry.Pattern == dns.Fqdn(muxEntry.domain) {
foundInMux = true
break
}
}
assert.True(t, foundInMux,
"Handler in chain not found in dnsMuxMap")
}
})
}
}

View File

@@ -6,6 +6,7 @@ import (
type ShutdownState struct { type ShutdownState struct {
Guid string Guid string
GPO bool
} }
func (s *ShutdownState) Name() string { func (s *ShutdownState) Name() string {
@@ -13,9 +14,9 @@ func (s *ShutdownState) Name() string {
} }
func (s *ShutdownState) Cleanup() error { func (s *ShutdownState) Cleanup() error {
manager, err := newHostManagerWithGuid(s.Guid) manager := &registryConfigurator{
if err != nil { guid: s.Guid,
return fmt.Errorf("create host manager: %w", err) gpo: s.GPO,
} }
if err := manager.restoreUncleanShutdownDNS(); err != nil { if err := manager.restoreUncleanShutdownDNS(); err != nil {

View File

@@ -2,9 +2,13 @@ package dns
import ( import (
"context" "context"
"crypto/sha256"
"encoding/hex"
"errors" "errors"
"fmt" "fmt"
"net" "net"
"slices"
"strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -40,6 +44,7 @@ type upstreamResolverBase struct {
cancel context.CancelFunc cancel context.CancelFunc
upstreamClient upstreamClient upstreamClient upstreamClient
upstreamServers []string upstreamServers []string
domain string
disabled bool disabled bool
failsCount atomic.Int32 failsCount atomic.Int32
successCount atomic.Int32 successCount atomic.Int32
@@ -53,12 +58,13 @@ type upstreamResolverBase struct {
statusRecorder *peer.Status statusRecorder *peer.Status
} }
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *upstreamResolverBase { func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
return &upstreamResolverBase{ return &upstreamResolverBase{
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
domain: domain,
upstreamTimeout: upstreamTimeout, upstreamTimeout: upstreamTimeout,
reactivatePeriod: reactivatePeriod, reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact, failsTillDeact: failsTillDeact,
@@ -71,6 +77,17 @@ func (u *upstreamResolverBase) String() string {
return fmt.Sprintf("upstream %v", u.upstreamServers) return fmt.Sprintf("upstream %v", u.upstreamServers)
} }
// ID returns the unique handler ID
func (u *upstreamResolverBase) id() handlerID {
servers := slices.Clone(u.upstreamServers)
slices.Sort(servers)
hash := sha256.New()
hash.Write([]byte(u.domain + ":"))
hash.Write([]byte(strings.Join(servers, ",")))
return handlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
}
func (u *upstreamResolverBase) MatchSubdomains() bool { func (u *upstreamResolverBase) MatchSubdomains() bool {
return true return true
} }
@@ -87,7 +104,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
u.checkUpstreamFails(err) u.checkUpstreamFails(err)
}() }()
log.WithField("question", r.Question[0]).Trace("received an upstream question") log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records // set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
if r.Extra == nil { if r.Extra == nil {
r.SetEdns0(4096, false) r.SetEdns0(4096, false)
@@ -96,6 +113,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
select { select {
case <-u.ctx.Done(): case <-u.ctx.Done():
log.Tracef("%s has been stopped", u)
return return
default: default:
} }
@@ -112,41 +130,36 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if err != nil { if err != nil {
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) { if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
log.WithError(err).WithField("upstream", upstream). log.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
Warn("got an error while connecting to upstream")
continue continue
} }
u.failsCount.Add(1) log.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
log.WithError(err).WithField("upstream", upstream). continue
Error("got other error while querying the upstream")
return
} }
if rm == nil { if rm == nil || !rm.Response {
log.WithError(err).WithField("upstream", upstream). log.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
Warn("no response from upstream") continue
return
}
// those checks need to be independent of each other due to memory address issues
if !rm.Response {
log.WithError(err).WithField("upstream", upstream).
Warn("no response from upstream")
return
} }
u.successCount.Add(1) u.successCount.Add(1)
log.Tracef("took %s to query the upstream %s", t, upstream) log.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
err = w.WriteMsg(rm) if err = w.WriteMsg(rm); err != nil {
if err != nil { log.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
log.WithError(err).Error("got an error while writing the upstream resolver response")
} }
// count the fails only if they happen sequentially // count the fails only if they happen sequentially
u.failsCount.Store(0) u.failsCount.Store(0)
return return
} }
u.failsCount.Add(1) u.failsCount.Add(1)
log.Error("all queries to the upstream nameservers failed with timeout") log.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
m := new(dns.Msg)
m.SetRcode(r, dns.RcodeServerFailure)
if err := w.WriteMsg(m); err != nil {
log.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
}
} }
// checkUpstreamFails counts fails and disables or enables upstream resolving // checkUpstreamFails counts fails and disables or enables upstream resolving

View File

@@ -27,8 +27,9 @@ func newUpstreamResolver(
_ *net.IPNet, _ *net.IPNet,
statusRecorder *peer.Status, statusRecorder *peer.Status,
hostsDNSHolder *hostsDNSHolder, hostsDNSHolder *hostsDNSHolder,
domain string,
) (*upstreamResolver, error) { ) (*upstreamResolver, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder) upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
c := &upstreamResolver{ c := &upstreamResolver{
upstreamResolverBase: upstreamResolverBase, upstreamResolverBase: upstreamResolverBase,
hostsDNSHolder: hostsDNSHolder, hostsDNSHolder: hostsDNSHolder,

View File

@@ -23,8 +23,9 @@ func newUpstreamResolver(
_ *net.IPNet, _ *net.IPNet,
statusRecorder *peer.Status, statusRecorder *peer.Status,
_ *hostsDNSHolder, _ *hostsDNSHolder,
domain string,
) (*upstreamResolver, error) { ) (*upstreamResolver, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder) upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
nonIOS := &upstreamResolver{ nonIOS := &upstreamResolver{
upstreamResolverBase: upstreamResolverBase, upstreamResolverBase: upstreamResolverBase,
} }

View File

@@ -30,8 +30,9 @@ func newUpstreamResolver(
net *net.IPNet, net *net.IPNet,
statusRecorder *peer.Status, statusRecorder *peer.Status,
_ *hostsDNSHolder, _ *hostsDNSHolder,
domain string,
) (*upstreamResolverIOS, error) { ) (*upstreamResolverIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder) upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
ios := &upstreamResolverIOS{ ios := &upstreamResolverIOS{
upstreamResolverBase: upstreamResolverBase, upstreamResolverBase: upstreamResolverBase,

View File

@@ -20,6 +20,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
timeout time.Duration timeout time.Duration
cancelCTX bool cancelCTX bool
expectedAnswer string expectedAnswer string
acceptNXDomain bool
}{ }{
{ {
name: "Should Resolve A Record", name: "Should Resolve A Record",
@@ -36,11 +37,11 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
expectedAnswer: "1.1.1.1", expectedAnswer: "1.1.1.1",
}, },
{ {
name: "Should Not Resolve If Can't Connect To Both Servers", name: "Should Not Resolve If Can't Connect To Both Servers",
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA), inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
InputServers: []string{"8.0.0.0:53", "8.0.0.1:53"}, InputServers: []string{"8.0.0.0:53", "8.0.0.1:53"},
timeout: 200 * time.Millisecond, timeout: 200 * time.Millisecond,
responseShouldBeNil: true, acceptNXDomain: true,
}, },
{ {
name: "Should Not Resolve If Parent Context Is Canceled", name: "Should Not Resolve If Parent Context Is Canceled",
@@ -51,14 +52,11 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
responseShouldBeNil: true, responseShouldBeNil: true,
}, },
} }
// should resolve if first upstream times out
// should not write when both fails
// should not resolve if parent context is canceled
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO()) ctx, cancel := context.WithCancel(context.TODO())
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil) resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil, ".")
resolver.upstreamServers = testCase.InputServers resolver.upstreamServers = testCase.InputServers
resolver.upstreamTimeout = testCase.timeout resolver.upstreamTimeout = testCase.timeout
if testCase.cancelCTX { if testCase.cancelCTX {
@@ -84,16 +82,22 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
t.Fatalf("should write a response message") t.Fatalf("should write a response message")
} }
foundAnswer := false if testCase.acceptNXDomain && responseMSG.Rcode == dns.RcodeNameError {
for _, answer := range responseMSG.Answer { return
if strings.Contains(answer.String(), testCase.expectedAnswer) {
foundAnswer = true
break
}
} }
if !foundAnswer { if testCase.expectedAnswer != "" {
t.Errorf("couldn't find the required answer, %s, in the dns response", testCase.expectedAnswer) foundAnswer := false
for _, answer := range responseMSG.Answer {
if strings.Contains(answer.String(), testCase.expectedAnswer) {
foundAnswer = true
break
}
}
if !foundAnswer {
t.Errorf("couldn't find the required answer, %s, in the dns response", testCase.expectedAnswer)
}
} }
}) })
} }

View File

@@ -42,13 +42,13 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/management/domain"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
mgm "github.com/netbirdio/netbird/management/client" mgm "github.com/netbirdio/netbird/management/client"
"github.com/netbirdio/netbird/management/domain"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
auth "github.com/netbirdio/netbird/relay/auth/hmac" auth "github.com/netbirdio/netbird/relay/auth/hmac"
relayClient "github.com/netbirdio/netbird/relay/client" relayClient "github.com/netbirdio/netbird/relay/client"
@@ -193,6 +193,10 @@ type Peer struct {
WgAllowedIps string WgAllowedIps string
} }
type localIpUpdater interface {
UpdateLocalIPs() error
}
// NewEngine creates a new Connection Engine with probes attached // NewEngine creates a new Connection Engine with probes attached
func NewEngine( func NewEngine(
clientCtx context.Context, clientCtx context.Context,
@@ -433,7 +437,7 @@ func (e *Engine) createFirewall() error {
} }
var err error var err error
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager) e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.config.DisableServerRoutes)
if err != nil || e.firewall == nil { if err != nil || e.firewall == nil {
log.Errorf("failed creating firewall manager: %s", err) log.Errorf("failed creating firewall manager: %s", err)
return nil return nil
@@ -883,6 +887,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.acl.ApplyFiltering(networkMap) e.acl.ApplyFiltering(networkMap)
} }
if e.firewall != nil {
if localipfw, ok := e.firewall.(localIpUpdater); ok {
if err := localipfw.UpdateLocalIPs(); err != nil {
log.Errorf("failed to update local IPs: %v", err)
}
}
}
// DNS forwarder // DNS forwarder
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap) dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
dnsRouteDomains := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), networkMap.GetRoutes()) dnsRouteDomains := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), networkMap.GetRoutes())
@@ -1446,6 +1458,11 @@ func (e *Engine) GetRouteManager() routemanager.Manager {
return e.routeManager return e.routeManager
} }
// GetFirewallManager returns the firewall manager
func (e *Engine) GetFirewallManager() manager.Manager {
return e.firewall
}
func findIPFromInterfaceName(ifaceName string) (net.IP, error) { func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
iface, err := net.InterfaceByName(ifaceName) iface, err := net.InterfaceByName(ifaceName)
if err != nil { if err != nil {
@@ -1657,6 +1674,14 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
return nm, nil return nm, nil
} }
// GetWgAddr returns the wireguard address
func (e *Engine) GetWgAddr() net.IP {
if e.wgInterface == nil {
return nil
}
return e.wgInterface.Address().IP
}
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag // updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
func (e *Engine) updateDNSForwarder(enabled bool, domains []string) { func (e *Engine) updateDNSForwarder(enabled bool, domains []string) {
if !enabled { if !enabled {

View File

@@ -2,6 +2,7 @@ package peer
import ( import (
"context" "context"
"fmt"
"math/rand" "math/rand"
"net" "net"
"os" "os"
@@ -28,12 +29,28 @@ import (
type ConnPriority int type ConnPriority int
func (cp ConnPriority) String() string {
switch cp {
case connPriorityNone:
return "None"
case connPriorityRelay:
return "PriorityRelay"
case connPriorityICETurn:
return "PriorityICETurn"
case connPriorityICEP2P:
return "PriorityICEP2P"
default:
return fmt.Sprintf("ConnPriority(%d)", cp)
}
}
const ( const (
defaultWgKeepAlive = 25 * time.Second defaultWgKeepAlive = 25 * time.Second
connPriorityNone ConnPriority = 0
connPriorityRelay ConnPriority = 1 connPriorityRelay ConnPriority = 1
connPriorityICETurn ConnPriority = 1 connPriorityICETurn ConnPriority = 2
connPriorityICEP2P ConnPriority = 2 connPriorityICEP2P ConnPriority = 3
) )
type WgConfig struct { type WgConfig struct {
@@ -66,14 +83,6 @@ type ConnConfig struct {
ICEConfig icemaker.Config ICEConfig icemaker.Config
} }
type WorkerCallbacks struct {
OnRelayReadyCallback func(info RelayConnInfo)
OnRelayStatusChanged func(ConnStatus)
OnICEConnReadyCallback func(ConnPriority, ICEConnInfo)
OnICEStatusChanged func(ConnStatus)
}
type Conn struct { type Conn struct {
log *log.Entry log *log.Entry
mu sync.Mutex mu sync.Mutex
@@ -135,21 +144,11 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
semaphore: semaphore, semaphore: semaphore,
} }
rFns := WorkerRelayCallbacks{
OnConnReady: conn.relayConnectionIsReady,
OnDisconnected: conn.onWorkerRelayStateDisconnected,
}
wFns := WorkerICECallbacks{
OnConnReady: conn.iCEConnectionIsReady,
OnStatusChanged: conn.onWorkerICEStateDisconnected,
}
ctrl := isController(config) ctrl := isController(config)
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, relayManager, rFns) conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, conn, relayManager)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
conn.workerICE, err = NewWorkerICE(ctx, connLog, config, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally, wFns) conn.workerICE, err = NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -304,7 +303,7 @@ func (conn *Conn) GetKey() string {
} }
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected // configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) { func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
@@ -317,9 +316,10 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
return return
} }
conn.log.Debugf("ICE connection is ready") // this never should happen, because Relay is the lower priority and ICE always close the deprecated connection before upgrade
// todo consider to remove this check
if conn.currentConnPriority > priority { if conn.currentConnPriority > priority {
conn.log.Infof("current connection priority (%s) is higher than the new one (%s), do not upgrade connection", conn.currentConnPriority, priority)
conn.statusICE.Set(StatusConnected) conn.statusICE.Set(StatusConnected)
conn.updateIceState(iceConnInfo) conn.updateIceState(iceConnInfo)
return return
@@ -375,8 +375,7 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr) conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
} }
// todo review to make sense to handle connecting and disconnected status also? func (conn *Conn) onICEStateDisconnected() {
func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
@@ -384,7 +383,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
return return
} }
conn.log.Tracef("ICE connection state changed to %s", newState) conn.log.Tracef("ICE connection state changed to disconnected")
if conn.wgProxyICE != nil { if conn.wgProxyICE != nil {
if err := conn.wgProxyICE.CloseConn(); err != nil { if err := conn.wgProxyICE.CloseConn(); err != nil {
@@ -394,7 +393,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
// switch back to relay connection // switch back to relay connection
if conn.isReadyToUpgrade() { if conn.isReadyToUpgrade() {
conn.log.Debugf("ICE disconnected, set Relay to active connection") conn.log.Infof("ICE disconnected, set Relay to active connection")
conn.wgProxyRelay.Work() conn.wgProxyRelay.Work()
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil { if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
@@ -402,12 +401,16 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
} }
conn.workerRelay.EnableWgWatcher(conn.ctx) conn.workerRelay.EnableWgWatcher(conn.ctx)
conn.currentConnPriority = connPriorityRelay conn.currentConnPriority = connPriorityRelay
} else {
conn.log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", connPriorityNone.String())
conn.currentConnPriority = connPriorityNone
} }
changed := conn.statusICE.Get() != newState && newState != StatusConnecting changed := conn.statusICE.Get() != StatusDisconnected
conn.statusICE.Set(newState) if changed {
conn.guard.SetICEConnDisconnected()
conn.guard.SetICEConnDisconnected(changed) }
conn.statusICE.Set(StatusDisconnected)
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,
@@ -422,7 +425,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
} }
} }
func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
@@ -444,7 +447,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String()) conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
if conn.iceP2PIsActive() { if conn.iceP2PIsActive() {
conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority) conn.log.Debugf("do not switch to relay because current priority is: %s", conn.currentConnPriority.String())
conn.setRelayedProxy(wgProxy) conn.setRelayedProxy(wgProxy)
conn.statusRelay.Set(StatusConnected) conn.statusRelay.Set(StatusConnected)
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
@@ -474,7 +477,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr) conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
} }
func (conn *Conn) onWorkerRelayStateDisconnected() { func (conn *Conn) onRelayDisconnected() {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
@@ -497,8 +500,10 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
} }
changed := conn.statusRelay.Get() != StatusDisconnected changed := conn.statusRelay.Get() != StatusDisconnected
if changed {
conn.guard.SetRelayedConnDisconnected()
}
conn.statusRelay.Set(StatusDisconnected) conn.statusRelay.Set(StatusDisconnected)
conn.guard.SetRelayedConnDisconnected(changed)
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,

View File

@@ -29,8 +29,8 @@ type Guard struct {
isConnectedOnAllWay isConnectedFunc isConnectedOnAllWay isConnectedFunc
timeout time.Duration timeout time.Duration
srWatcher *SRWatcher srWatcher *SRWatcher
relayedConnDisconnected chan bool relayedConnDisconnected chan struct{}
iCEConnDisconnected chan bool iCEConnDisconnected chan struct{}
} }
func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard { func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
@@ -41,8 +41,8 @@ func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc,
isConnectedOnAllWay: isConnectedFn, isConnectedOnAllWay: isConnectedFn,
timeout: timeout, timeout: timeout,
srWatcher: srWatcher, srWatcher: srWatcher,
relayedConnDisconnected: make(chan bool, 1), relayedConnDisconnected: make(chan struct{}, 1),
iCEConnDisconnected: make(chan bool, 1), iCEConnDisconnected: make(chan struct{}, 1),
} }
} }
@@ -54,16 +54,16 @@ func (g *Guard) Start(ctx context.Context) {
} }
} }
func (g *Guard) SetRelayedConnDisconnected(changed bool) { func (g *Guard) SetRelayedConnDisconnected() {
select { select {
case g.relayedConnDisconnected <- changed: case g.relayedConnDisconnected <- struct{}{}:
default: default:
} }
} }
func (g *Guard) SetICEConnDisconnected(changed bool) { func (g *Guard) SetICEConnDisconnected() {
select { select {
case g.iCEConnDisconnected <- changed: case g.iCEConnDisconnected <- struct{}{}:
default: default:
} }
} }
@@ -96,19 +96,13 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context) {
g.triggerOfferSending() g.triggerOfferSending()
} }
case changed := <-g.relayedConnDisconnected: case <-g.relayedConnDisconnected:
if !changed {
continue
}
g.log.Debugf("Relay connection changed, reset reconnection ticker") g.log.Debugf("Relay connection changed, reset reconnection ticker")
ticker.Stop() ticker.Stop()
ticker = g.prepareExponentTicker(ctx) ticker = g.prepareExponentTicker(ctx)
tickerChannel = ticker.C tickerChannel = ticker.C
case changed := <-g.iCEConnDisconnected: case <-g.iCEConnDisconnected:
if !changed {
continue
}
g.log.Debugf("ICE connection changed, reset reconnection ticker") g.log.Debugf("ICE connection changed, reset reconnection ticker")
ticker.Stop() ticker.Stop()
ticker = g.prepareExponentTicker(ctx) ticker = g.prepareExponentTicker(ctx)
@@ -138,16 +132,10 @@ func (g *Guard) listenForDisconnectEvents(ctx context.Context) {
g.log.Infof("start listen for reconnect events...") g.log.Infof("start listen for reconnect events...")
for { for {
select { select {
case changed := <-g.relayedConnDisconnected: case <-g.relayedConnDisconnected:
if !changed {
continue
}
g.log.Debugf("Relay connection changed, triggering reconnect") g.log.Debugf("Relay connection changed, triggering reconnect")
g.triggerOfferSending() g.triggerOfferSending()
case changed := <-g.iCEConnDisconnected: case <-g.iCEConnDisconnected:
if !changed {
continue
}
g.log.Debugf("ICE state changed, try to send new offer") g.log.Debugf("ICE state changed, try to send new offer")
g.triggerOfferSending() g.triggerOfferSending()
case <-srReconnectedChan: case <-srReconnectedChan:

View File

@@ -721,7 +721,9 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
func (d *Status) GetDNSStates() []NSGroupState { func (d *Status) GetDNSStates() []NSGroupState {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock() defer d.mux.Unlock()
return d.nsGroupStates
// shallow copy is good enough, as slices fields are currently not updated
return slices.Clone(d.nsGroupStates)
} }
func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo { func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {

View File

@@ -0,0 +1,154 @@
package peer
import (
"context"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/configurer"
)
const (
wgHandshakePeriod = 3 * time.Minute
)
var (
wgHandshakeOvertime = 30 * time.Second // allowed delay in network
checkPeriod = wgHandshakePeriod + wgHandshakeOvertime
)
type WGInterfaceStater interface {
GetStats(key string) (configurer.WGStats, error)
}
type WGWatcher struct {
log *log.Entry
wgIfaceStater WGInterfaceStater
peerKey string
ctx context.Context
ctxCancel context.CancelFunc
ctxLock sync.Mutex
waitGroup sync.WaitGroup
}
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string) *WGWatcher {
return &WGWatcher{
log: log,
wgIfaceStater: wgIfaceStater,
peerKey: peerKey,
}
}
// EnableWgWatcher starts the WireGuard watcher. If it is already enabled, it will return immediately and do nothing.
func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) {
w.log.Debugf("enable WireGuard watcher")
w.ctxLock.Lock()
defer w.ctxLock.Unlock()
if w.ctx != nil && w.ctx.Err() == nil {
w.log.Errorf("WireGuard watcher already enabled")
return
}
ctx, ctxCancel := context.WithCancel(parentCtx)
w.ctx = ctx
w.ctxCancel = ctxCancel
initialHandshake, err := w.wgState()
if err != nil {
w.log.Warnf("failed to read initial wg stats: %v", err)
}
w.waitGroup.Add(1)
go w.periodicHandshakeCheck(ctx, ctxCancel, onDisconnectedFn, initialHandshake)
}
// DisableWgWatcher stops the WireGuard watcher and wait for the watcher to exit
func (w *WGWatcher) DisableWgWatcher() {
w.ctxLock.Lock()
defer w.ctxLock.Unlock()
if w.ctxCancel == nil {
return
}
w.log.Debugf("disable WireGuard watcher")
w.ctxCancel()
w.ctxCancel = nil
w.waitGroup.Wait()
}
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func(), initialHandshake time.Time) {
w.log.Infof("WireGuard watcher started")
defer w.waitGroup.Done()
timer := time.NewTimer(wgHandshakeOvertime)
defer timer.Stop()
defer ctxCancel()
lastHandshake := initialHandshake
for {
select {
case <-timer.C:
handshake, ok := w.handshakeCheck(lastHandshake)
if !ok {
onDisconnectedFn()
return
}
lastHandshake = *handshake
resetTime := time.Until(handshake.Add(checkPeriod))
timer.Reset(resetTime)
w.log.Debugf("WireGuard watcher reset timer: %v", resetTime)
case <-ctx.Done():
w.log.Infof("WireGuard watcher stopped")
return
}
}
}
// handshakeCheck checks the WireGuard handshake and return the new handshake time if it is different from the previous one
func (w *WGWatcher) handshakeCheck(lastHandshake time.Time) (*time.Time, bool) {
handshake, err := w.wgState()
if err != nil {
w.log.Errorf("failed to read wg stats: %v", err)
return nil, false
}
w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake)
// the current know handshake did not change
if handshake.Equal(lastHandshake) {
w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake)
return nil, false
}
// in case if the machine is suspended, the handshake time will be in the past
if handshake.Add(checkPeriod).Before(time.Now()) {
w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake)
return nil, false
}
// error handling for handshake time in the future
if handshake.After(time.Now()) {
w.log.Warnf("WireGuard handshake is in the future, closing relay connection: %v", handshake)
return nil, false
}
return &handshake, true
}
func (w *WGWatcher) wgState() (time.Time, error) {
wgState, err := w.wgIfaceStater.GetStats(w.peerKey)
if err != nil {
return time.Time{}, err
}
return wgState.LastHandshake, nil
}

View File

@@ -0,0 +1,98 @@
package peer
import (
"context"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/configurer"
)
type MocWgIface struct {
initial bool
lastHandshake time.Time
stop bool
}
func (m *MocWgIface) GetStats(key string) (configurer.WGStats, error) {
if !m.initial {
m.initial = true
return configurer.WGStats{}, nil
}
if !m.stop {
m.lastHandshake = time.Now()
}
stats := configurer.WGStats{
LastHandshake: m.lastHandshake,
}
return stats, nil
}
func (m *MocWgIface) disconnect() {
m.stop = true
}
func TestWGWatcher_EnableWgWatcher(t *testing.T) {
checkPeriod = 5 * time.Second
wgHandshakeOvertime = 1 * time.Second
mlog := log.WithField("peer", "tet")
mocWgIface := &MocWgIface{}
watcher := NewWGWatcher(mlog, mocWgIface, "")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
onDisconnected := make(chan struct{}, 1)
watcher.EnableWgWatcher(ctx, func() {
mlog.Infof("onDisconnectedFn")
onDisconnected <- struct{}{}
})
// wait for initial reading
time.Sleep(2 * time.Second)
mocWgIface.disconnect()
select {
case <-onDisconnected:
case <-time.After(10 * time.Second):
t.Errorf("timeout")
}
watcher.DisableWgWatcher()
}
func TestWGWatcher_ReEnable(t *testing.T) {
checkPeriod = 5 * time.Second
wgHandshakeOvertime = 1 * time.Second
mlog := log.WithField("peer", "tet")
mocWgIface := &MocWgIface{}
watcher := NewWGWatcher(mlog, mocWgIface, "")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
onDisconnected := make(chan struct{}, 1)
watcher.EnableWgWatcher(ctx, func() {})
watcher.DisableWgWatcher()
watcher.EnableWgWatcher(ctx, func() {
onDisconnected <- struct{}{}
})
time.Sleep(2 * time.Second)
mocWgIface.disconnect()
select {
case <-onDisconnected:
case <-time.After(10 * time.Second):
t.Errorf("timeout")
}
watcher.DisableWgWatcher()
}

View File

@@ -31,20 +31,15 @@ type ICEConnInfo struct {
RelayedOnLocal bool RelayedOnLocal bool
} }
type WorkerICECallbacks struct {
OnConnReady func(ConnPriority, ICEConnInfo)
OnStatusChanged func(ConnStatus)
}
type WorkerICE struct { type WorkerICE struct {
ctx context.Context ctx context.Context
log *log.Entry log *log.Entry
config ConnConfig config ConnConfig
conn *Conn
signaler *Signaler signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover iFaceDiscover stdnet.ExternalIFaceDiscover
statusRecorder *Status statusRecorder *Status
hasRelayOnLocally bool hasRelayOnLocally bool
conn WorkerICECallbacks
agent *ice.Agent agent *ice.Agent
muxAgent sync.Mutex muxAgent sync.Mutex
@@ -60,16 +55,16 @@ type WorkerICE struct {
lastKnownState ice.ConnectionState lastKnownState ice.ConnectionState
} }
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) { func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) {
w := &WorkerICE{ w := &WorkerICE{
ctx: ctx, ctx: ctx,
log: log, log: log,
config: config, config: config,
conn: conn,
signaler: signaler, signaler: signaler,
iFaceDiscover: ifaceDiscover, iFaceDiscover: ifaceDiscover,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
hasRelayOnLocally: hasRelayOnLocally, hasRelayOnLocally: hasRelayOnLocally,
conn: callBacks,
} }
localUfrag, localPwd, err := icemaker.GenerateICECredentials() localUfrag, localPwd, err := icemaker.GenerateICECredentials()
@@ -154,8 +149,8 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
Relayed: isRelayed(pair), Relayed: isRelayed(pair),
RelayedOnLocal: isRelayCandidate(pair.Local), RelayedOnLocal: isRelayCandidate(pair.Local),
} }
w.log.Debugf("on ICE conn read to use ready") w.log.Debugf("on ICE conn is ready to use")
go w.conn.OnConnReady(selectedPriority(pair), ci) go w.conn.onICEConnectionIsReady(selectedPriority(pair), ci)
} }
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. // OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
@@ -220,7 +215,7 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected: case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
if w.lastKnownState != ice.ConnectionStateDisconnected { if w.lastKnownState != ice.ConnectionStateDisconnected {
w.lastKnownState = ice.ConnectionStateDisconnected w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.OnStatusChanged(StatusDisconnected) w.conn.onICEStateDisconnected()
} }
w.closeAgent(agentCancel) w.closeAgent(agentCancel)
default: default:

View File

@@ -6,52 +6,41 @@ import (
"net" "net"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
relayClient "github.com/netbirdio/netbird/relay/client" relayClient "github.com/netbirdio/netbird/relay/client"
) )
var (
wgHandshakePeriod = 3 * time.Minute
wgHandshakeOvertime = 30 * time.Second
)
type RelayConnInfo struct { type RelayConnInfo struct {
relayedConn net.Conn relayedConn net.Conn
rosenpassPubKey []byte rosenpassPubKey []byte
rosenpassAddr string rosenpassAddr string
} }
type WorkerRelayCallbacks struct {
OnConnReady func(RelayConnInfo)
OnDisconnected func()
}
type WorkerRelay struct { type WorkerRelay struct {
log *log.Entry log *log.Entry
isController bool isController bool
config ConnConfig config ConnConfig
conn *Conn
relayManager relayClient.ManagerService relayManager relayClient.ManagerService
callBacks WorkerRelayCallbacks
relayedConn net.Conn relayedConn net.Conn
relayLock sync.Mutex relayLock sync.Mutex
ctxWgWatch context.Context
ctxCancelWgWatch context.CancelFunc
ctxLock sync.Mutex
relaySupportedOnRemotePeer atomic.Bool relaySupportedOnRemotePeer atomic.Bool
wgWatcher *WGWatcher
} }
func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay { func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService) *WorkerRelay {
r := &WorkerRelay{ r := &WorkerRelay{
log: log, log: log,
isController: ctrl, isController: ctrl,
config: config, config: config,
conn: conn,
relayManager: relayManager, relayManager: relayManager,
callBacks: callbacks, wgWatcher: NewWGWatcher(log, config.WgConfig.WgInterface, config.Key),
} }
return r return r
} }
@@ -87,7 +76,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.relayedConn = relayedConn w.relayedConn = relayedConn
w.relayLock.Unlock() w.relayLock.Unlock()
err = w.relayManager.AddCloseListener(srv, w.onRelayMGDisconnected) err = w.relayManager.AddCloseListener(srv, w.onRelayClientDisconnected)
if err != nil { if err != nil {
log.Errorf("failed to add close listener: %s", err) log.Errorf("failed to add close listener: %s", err)
_ = relayedConn.Close() _ = relayedConn.Close()
@@ -95,7 +84,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
} }
w.log.Debugf("peer conn opened via Relay: %s", srv) w.log.Debugf("peer conn opened via Relay: %s", srv)
go w.callBacks.OnConnReady(RelayConnInfo{ go w.conn.onRelayConnectionIsReady(RelayConnInfo{
relayedConn: relayedConn, relayedConn: relayedConn,
rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey, rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
rosenpassAddr: remoteOfferAnswer.RosenpassAddr, rosenpassAddr: remoteOfferAnswer.RosenpassAddr,
@@ -103,32 +92,11 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
} }
func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) { func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) {
w.log.Debugf("enable WireGuard watcher") w.wgWatcher.EnableWgWatcher(ctx, w.onWGDisconnected)
w.ctxLock.Lock()
defer w.ctxLock.Unlock()
if w.ctxWgWatch != nil && w.ctxWgWatch.Err() == nil {
return
}
ctx, ctxCancel := context.WithCancel(ctx)
w.ctxWgWatch = ctx
w.ctxCancelWgWatch = ctxCancel
w.wgStateCheck(ctx, ctxCancel)
} }
func (w *WorkerRelay) DisableWgWatcher() { func (w *WorkerRelay) DisableWgWatcher() {
w.ctxLock.Lock() w.wgWatcher.DisableWgWatcher()
defer w.ctxLock.Unlock()
if w.ctxCancelWgWatch == nil {
return
}
w.log.Debugf("disable WireGuard watcher")
w.ctxCancelWgWatch()
} }
func (w *WorkerRelay) RelayInstanceAddress() (string, error) { func (w *WorkerRelay) RelayInstanceAddress() (string, error) {
@@ -150,57 +118,17 @@ func (w *WorkerRelay) CloseConn() {
return return
} }
err := w.relayedConn.Close() if err := w.relayedConn.Close(); err != nil {
if err != nil {
w.log.Warnf("failed to close relay connection: %v", err) w.log.Warnf("failed to close relay connection: %v", err)
} }
} }
// wgStateCheck help to check the state of the WireGuard handshake and relay connection func (w *WorkerRelay) onWGDisconnected() {
func (w *WorkerRelay) wgStateCheck(ctx context.Context, ctxCancel context.CancelFunc) { w.relayLock.Lock()
w.log.Debugf("WireGuard watcher started") _ = w.relayedConn.Close()
lastHandshake, err := w.wgState() w.relayLock.Unlock()
if err != nil {
w.log.Warnf("failed to read wg stats: %v", err)
lastHandshake = time.Time{}
}
go func(lastHandshake time.Time) {
timer := time.NewTimer(wgHandshakeOvertime)
defer timer.Stop()
defer ctxCancel()
for {
select {
case <-timer.C:
handshake, err := w.wgState()
if err != nil {
w.log.Errorf("failed to read wg stats: %v", err)
timer.Reset(wgHandshakeOvertime)
continue
}
w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake)
if handshake.Equal(lastHandshake) {
w.log.Infof("WireGuard handshake timed out, closing relay connection: %v", handshake)
w.relayLock.Lock()
_ = w.relayedConn.Close()
w.relayLock.Unlock()
w.callBacks.OnDisconnected()
return
}
resetTime := time.Until(handshake.Add(wgHandshakePeriod + wgHandshakeOvertime))
lastHandshake = handshake
timer.Reset(resetTime)
case <-ctx.Done():
w.log.Debugf("WireGuard watcher stopped")
return
}
}
}(lastHandshake)
w.conn.onRelayDisconnected()
} }
func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool { func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {
@@ -217,20 +145,7 @@ func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress st
return remoteRelayAddress return remoteRelayAddress
} }
func (w *WorkerRelay) wgState() (time.Time, error) { func (w *WorkerRelay) onRelayClientDisconnected() {
wgState, err := w.config.WgConfig.WgInterface.GetStats(w.config.Key) w.wgWatcher.DisableWgWatcher()
if err != nil { go w.conn.onRelayDisconnected()
return time.Time{}, err
}
return wgState.LastHandshake, nil
}
func (w *WorkerRelay) onRelayMGDisconnected() {
w.ctxLock.Lock()
defer w.ctxLock.Unlock()
if w.ctxCancelWgWatch != nil {
w.ctxCancelWgWatch()
}
go w.callBacks.OnDisconnected()
} }

View File

@@ -113,13 +113,14 @@ func NewManager(config ManagerConfig) *DefaultManager {
disableServerRoutes: config.DisableServerRoutes, disableServerRoutes: config.DisableServerRoutes,
} }
useNoop := netstack.IsEnabled() || config.DisableClientRoutes
dm.setupRefCounters(useNoop)
// don't proceed with client routes if it is disabled // don't proceed with client routes if it is disabled
if config.DisableClientRoutes { if config.DisableClientRoutes {
return dm return dm
} }
dm.setupRefCounters()
if runtime.GOOS == "android" { if runtime.GOOS == "android" {
cr := dm.initialClientRoutes(config.InitialRoutes) cr := dm.initialClientRoutes(config.InitialRoutes)
dm.notifier.SetInitialClientRoutes(cr) dm.notifier.SetInitialClientRoutes(cr)
@@ -127,7 +128,7 @@ func NewManager(config ManagerConfig) *DefaultManager {
return dm return dm
} }
func (m *DefaultManager) setupRefCounters() { func (m *DefaultManager) setupRefCounters(useNoop bool) {
m.routeRefCounter = refcounter.New( m.routeRefCounter = refcounter.New(
func(prefix netip.Prefix, _ struct{}) (struct{}, error) { func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface()) return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface())
@@ -137,7 +138,7 @@ func (m *DefaultManager) setupRefCounters() {
}, },
) )
if netstack.IsEnabled() { if useNoop {
m.routeRefCounter = refcounter.New( m.routeRefCounter = refcounter.New(
func(netip.Prefix, struct{}) (struct{}, error) { func(netip.Prefix, struct{}) (struct{}, error) {
return struct{}{}, refcounter.ErrIgnore return struct{}{}, refcounter.ErrIgnore
@@ -285,15 +286,15 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
m.updateClientNetworks(updateSerial, filteredClientRoutes) m.updateClientNetworks(updateSerial, filteredClientRoutes)
m.notifier.OnNewRoutes(filteredClientRoutes) m.notifier.OnNewRoutes(filteredClientRoutes)
} }
m.clientRoutes = newClientRoutesIDMap
if m.serverRouter != nil { if m.serverRouter == nil {
err := m.serverRouter.updateRoutes(newServerRoutesMap) return nil
if err != nil {
return err
}
} }
m.clientRoutes = newClientRoutesIDMap if err := m.serverRouter.updateRoutes(newServerRoutesMap); err != nil {
return fmt.Errorf("update routes: %w", err)
}
return nil return nil
} }
@@ -422,11 +423,6 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
haID := newRoute.GetHAUniqueID() haID := newRoute.GetHAUniqueID()
if newRoute.Peer == m.pubKey { if newRoute.Peer == m.pubKey {
ownNetworkIDs[haID] = true ownNetworkIDs[haID] = true
// only linux is supported for now
if runtime.GOOS != "linux" {
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
continue
}
newServerRoutesMap[newRoute.ID] = newRoute newServerRoutesMap[newRoute.ID] = newRoute
} }
} }
@@ -454,7 +450,7 @@ func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*ro
} }
func isRouteSupported(route *route.Route) bool { func isRouteSupported(route *route.Route) bool {
if !nbnet.CustomRoutingDisabled() || route.IsDynamic() { if netstack.IsEnabled() || !nbnet.CustomRoutingDisabled() || route.IsDynamic() {
return true return true
} }

View File

@@ -71,9 +71,15 @@ func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
} }
if len(m.routes) > 0 { if len(m.routes) > 0 {
err := systemops.EnableIPForwarding() if err := systemops.EnableIPForwarding(); err != nil {
if err != nil { return fmt.Errorf("enable ip forwarding: %w", err)
return err }
if err := m.firewall.EnableRouting(); err != nil {
return fmt.Errorf("enable routing: %w", err)
}
} else {
if err := m.firewall.DisableRouting(); err != nil {
return fmt.Errorf("disable routing: %w", err)
} }
} }

View File

@@ -53,20 +53,6 @@ type ruleParams struct {
description string description string
} }
// isLegacy determines whether to use the legacy routing setup
func isLegacy() bool {
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || nbnet.SkipSocketMark()
}
// setIsLegacy sets the legacy routing setup
func setIsLegacy(b bool) {
if b {
os.Setenv("NB_USE_LEGACY_ROUTING", "true")
} else {
os.Unsetenv("NB_USE_LEGACY_ROUTING")
}
}
func getSetupRules() []ruleParams { func getSetupRules() []ruleParams {
return []ruleParams{ return []ruleParams{
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"}, {100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"},
@@ -87,7 +73,7 @@ func getSetupRules() []ruleParams {
// This table is where a default route or other specific routes received from the management server are configured, // This table is where a default route or other specific routes received from the management server are configured,
// enabling VPN connectivity. // enabling VPN connectivity.
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) { func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) {
if isLegacy() { if !nbnet.AdvancedRouting() {
log.Infof("Using legacy routing setup") log.Infof("Using legacy routing setup")
return r.setupRefCounter(initAddresses, stateManager) return r.setupRefCounter(initAddresses, stateManager)
} }
@@ -103,11 +89,6 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
rules := getSetupRules() rules := getSetupRules()
for _, rule := range rules { for _, rule := range rules {
if err := addRule(rule); err != nil { if err := addRule(rule); err != nil {
if errors.Is(err, syscall.EOPNOTSUPP) {
log.Warnf("Rule operations are not supported, falling back to the legacy routing setup")
setIsLegacy(true)
return r.setupRefCounter(initAddresses, stateManager)
}
return nil, nil, fmt.Errorf("%s: %w", rule.description, err) return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
} }
} }
@@ -130,7 +111,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
// It systematically removes the three rules and any associated routing table entries to ensure a clean state. // It systematically removes the three rules and any associated routing table entries to ensure a clean state.
// The function uses error aggregation to report any errors encountered during the cleanup process. // The function uses error aggregation to report any errors encountered during the cleanup process.
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
if isLegacy() { if !nbnet.AdvancedRouting() {
return r.cleanupRefCounter(stateManager) return r.cleanupRefCounter(stateManager)
} }
@@ -168,7 +149,7 @@ func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) erro
} }
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if isLegacy() { if !nbnet.AdvancedRouting() {
return r.genericAddVPNRoute(prefix, intf) return r.genericAddVPNRoute(prefix, intf)
} }
@@ -191,7 +172,7 @@ func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
} }
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error { func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if isLegacy() { if !nbnet.AdvancedRouting() {
return r.genericRemoveVPNRoute(prefix, intf) return r.genericRemoveVPNRoute(prefix, intf)
} }
@@ -504,7 +485,7 @@ func getAddressFamily(prefix netip.Prefix) int {
} }
func hasSeparateRouting() ([]netip.Prefix, error) { func hasSeparateRouting() ([]netip.Prefix, error) {
if isLegacy() { if !nbnet.AdvancedRouting() {
return GetRoutesFromTable() return GetRoutesFromTable()
} }
return nil, ErrRoutingIsSeparate return nil, ErrRoutingIsSeparate

View File

@@ -85,6 +85,7 @@ var testCases = []testCase{
} }
func TestRouting(t *testing.T) { func TestRouting(t *testing.T) {
nbnet.Init()
for _, tc := range testCases { for _, tc := range testCases {
// todo resolve test execution on freebsd // todo resolve test execution on freebsd
if runtime.GOOS == "freebsd" { if runtime.GOOS == "freebsd" {

View File

@@ -2571,6 +2571,330 @@ func (*SetNetworkMapPersistenceResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{39} return file_daemon_proto_rawDescGZIP(), []int{39}
} }
type TCPFlags struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Syn bool `protobuf:"varint,1,opt,name=syn,proto3" json:"syn,omitempty"`
Ack bool `protobuf:"varint,2,opt,name=ack,proto3" json:"ack,omitempty"`
Fin bool `protobuf:"varint,3,opt,name=fin,proto3" json:"fin,omitempty"`
Rst bool `protobuf:"varint,4,opt,name=rst,proto3" json:"rst,omitempty"`
Psh bool `protobuf:"varint,5,opt,name=psh,proto3" json:"psh,omitempty"`
Urg bool `protobuf:"varint,6,opt,name=urg,proto3" json:"urg,omitempty"`
}
func (x *TCPFlags) Reset() {
*x = TCPFlags{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[40]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *TCPFlags) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*TCPFlags) ProtoMessage() {}
func (x *TCPFlags) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[40]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use TCPFlags.ProtoReflect.Descriptor instead.
func (*TCPFlags) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{40}
}
func (x *TCPFlags) GetSyn() bool {
if x != nil {
return x.Syn
}
return false
}
func (x *TCPFlags) GetAck() bool {
if x != nil {
return x.Ack
}
return false
}
func (x *TCPFlags) GetFin() bool {
if x != nil {
return x.Fin
}
return false
}
func (x *TCPFlags) GetRst() bool {
if x != nil {
return x.Rst
}
return false
}
func (x *TCPFlags) GetPsh() bool {
if x != nil {
return x.Psh
}
return false
}
func (x *TCPFlags) GetUrg() bool {
if x != nil {
return x.Urg
}
return false
}
type TracePacketRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
SourceIp string `protobuf:"bytes,1,opt,name=source_ip,json=sourceIp,proto3" json:"source_ip,omitempty"`
DestinationIp string `protobuf:"bytes,2,opt,name=destination_ip,json=destinationIp,proto3" json:"destination_ip,omitempty"`
Protocol string `protobuf:"bytes,3,opt,name=protocol,proto3" json:"protocol,omitempty"`
SourcePort uint32 `protobuf:"varint,4,opt,name=source_port,json=sourcePort,proto3" json:"source_port,omitempty"`
DestinationPort uint32 `protobuf:"varint,5,opt,name=destination_port,json=destinationPort,proto3" json:"destination_port,omitempty"`
Direction string `protobuf:"bytes,6,opt,name=direction,proto3" json:"direction,omitempty"`
TcpFlags *TCPFlags `protobuf:"bytes,7,opt,name=tcp_flags,json=tcpFlags,proto3,oneof" json:"tcp_flags,omitempty"`
IcmpType *uint32 `protobuf:"varint,8,opt,name=icmp_type,json=icmpType,proto3,oneof" json:"icmp_type,omitempty"`
IcmpCode *uint32 `protobuf:"varint,9,opt,name=icmp_code,json=icmpCode,proto3,oneof" json:"icmp_code,omitempty"`
}
func (x *TracePacketRequest) Reset() {
*x = TracePacketRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[41]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *TracePacketRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*TracePacketRequest) ProtoMessage() {}
func (x *TracePacketRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[41]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use TracePacketRequest.ProtoReflect.Descriptor instead.
func (*TracePacketRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{41}
}
func (x *TracePacketRequest) GetSourceIp() string {
if x != nil {
return x.SourceIp
}
return ""
}
func (x *TracePacketRequest) GetDestinationIp() string {
if x != nil {
return x.DestinationIp
}
return ""
}
func (x *TracePacketRequest) GetProtocol() string {
if x != nil {
return x.Protocol
}
return ""
}
func (x *TracePacketRequest) GetSourcePort() uint32 {
if x != nil {
return x.SourcePort
}
return 0
}
func (x *TracePacketRequest) GetDestinationPort() uint32 {
if x != nil {
return x.DestinationPort
}
return 0
}
func (x *TracePacketRequest) GetDirection() string {
if x != nil {
return x.Direction
}
return ""
}
func (x *TracePacketRequest) GetTcpFlags() *TCPFlags {
if x != nil {
return x.TcpFlags
}
return nil
}
func (x *TracePacketRequest) GetIcmpType() uint32 {
if x != nil && x.IcmpType != nil {
return *x.IcmpType
}
return 0
}
func (x *TracePacketRequest) GetIcmpCode() uint32 {
if x != nil && x.IcmpCode != nil {
return *x.IcmpCode
}
return 0
}
type TraceStage struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"`
Allowed bool `protobuf:"varint,3,opt,name=allowed,proto3" json:"allowed,omitempty"`
ForwardingDetails *string `protobuf:"bytes,4,opt,name=forwarding_details,json=forwardingDetails,proto3,oneof" json:"forwarding_details,omitempty"`
}
func (x *TraceStage) Reset() {
*x = TraceStage{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[42]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *TraceStage) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*TraceStage) ProtoMessage() {}
func (x *TraceStage) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[42]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use TraceStage.ProtoReflect.Descriptor instead.
func (*TraceStage) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{42}
}
func (x *TraceStage) GetName() string {
if x != nil {
return x.Name
}
return ""
}
func (x *TraceStage) GetMessage() string {
if x != nil {
return x.Message
}
return ""
}
func (x *TraceStage) GetAllowed() bool {
if x != nil {
return x.Allowed
}
return false
}
func (x *TraceStage) GetForwardingDetails() string {
if x != nil && x.ForwardingDetails != nil {
return *x.ForwardingDetails
}
return ""
}
type TracePacketResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Stages []*TraceStage `protobuf:"bytes,1,rep,name=stages,proto3" json:"stages,omitempty"`
FinalDisposition bool `protobuf:"varint,2,opt,name=final_disposition,json=finalDisposition,proto3" json:"final_disposition,omitempty"`
}
func (x *TracePacketResponse) Reset() {
*x = TracePacketResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[43]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *TracePacketResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*TracePacketResponse) ProtoMessage() {}
func (x *TracePacketResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[43]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use TracePacketResponse.ProtoReflect.Descriptor instead.
func (*TracePacketResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{43}
}
func (x *TracePacketResponse) GetStages() []*TraceStage {
if x != nil {
return x.Stages
}
return nil
}
func (x *TracePacketResponse) GetFinalDisposition() bool {
if x != nil {
return x.FinalDisposition
}
return false
}
var File_daemon_proto protoreflect.FileDescriptor var File_daemon_proto protoreflect.FileDescriptor
var file_daemon_proto_rawDesc = []byte{ var file_daemon_proto_rawDesc = []byte{
@@ -2920,87 +3244,141 @@ var file_daemon_proto_rawDesc = []byte{
0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22,
0x22, 0x0a, 0x20, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x22, 0x0a, 0x20, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70,
0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f,
0x6e, 0x73, 0x65, 0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x6e, 0x73, 0x65, 0x22, 0x76, 0x0a, 0x08, 0x54, 0x43, 0x50, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x12,
0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x10, 0x0a, 0x03, 0x73, 0x79, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x73, 0x79,
0x50, 0x41, 0x4e, 0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x6e, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x63, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03,
0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x61, 0x63, 0x6b, 0x12, 0x10, 0x0a, 0x03, 0x66, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08,
0x04, 0x57, 0x41, 0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x52, 0x03, 0x66, 0x69, 0x6e, 0x12, 0x10, 0x0a, 0x03, 0x72, 0x73, 0x74, 0x18, 0x04, 0x20, 0x01,
0x05, 0x12, 0x09, 0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x28, 0x08, 0x52, 0x03, 0x72, 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x70, 0x73, 0x68, 0x18, 0x05,
0x54, 0x52, 0x41, 0x43, 0x45, 0x10, 0x07, 0x32, 0x93, 0x09, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x70, 0x73, 0x68, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x67,
0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x75, 0x72, 0x67, 0x22, 0x80, 0x03, 0x0a, 0x12,
0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65,
0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x73, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x70, 0x18,
0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x70, 0x12,
0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x25, 0x0a, 0x0e, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69,
0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x70, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61,
0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x74, 0x69, 0x6f, 0x6e, 0x49, 0x70, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63,
0x6f, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63,
0x6f, 0x6c, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x70, 0x6f, 0x72,
0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x50,
0x6f, 0x72, 0x74, 0x12, 0x29, 0x0a, 0x10, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69,
0x6f, 0x6e, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0f, 0x64,
0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x1c,
0x0a, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28,
0x09, 0x52, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x32, 0x0a, 0x09,
0x74, 0x63, 0x70, 0x5f, 0x66, 0x6c, 0x61, 0x67, 0x73, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32,
0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x54, 0x43, 0x50, 0x46, 0x6c, 0x61, 0x67,
0x73, 0x48, 0x00, 0x52, 0x08, 0x74, 0x63, 0x70, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x88, 0x01, 0x01,
0x12, 0x20, 0x0a, 0x09, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x08, 0x20,
0x01, 0x28, 0x0d, 0x48, 0x01, 0x52, 0x08, 0x69, 0x63, 0x6d, 0x70, 0x54, 0x79, 0x70, 0x65, 0x88,
0x01, 0x01, 0x12, 0x20, 0x0a, 0x09, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18,
0x09, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x02, 0x52, 0x08, 0x69, 0x63, 0x6d, 0x70, 0x43, 0x6f, 0x64,
0x65, 0x88, 0x01, 0x01, 0x42, 0x0c, 0x0a, 0x0a, 0x5f, 0x74, 0x63, 0x70, 0x5f, 0x66, 0x6c, 0x61,
0x67, 0x73, 0x42, 0x0c, 0x0a, 0x0a, 0x5f, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x74, 0x79, 0x70, 0x65,
0x42, 0x0c, 0x0a, 0x0a, 0x5f, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x22, 0x9f,
0x01, 0x0a, 0x0a, 0x54, 0x72, 0x61, 0x63, 0x65, 0x53, 0x74, 0x61, 0x67, 0x65, 0x12, 0x12, 0x0a,
0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d,
0x65, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01,
0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x61,
0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x61, 0x6c,
0x6c, 0x6f, 0x77, 0x65, 0x64, 0x12, 0x32, 0x0a, 0x12, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64,
0x69, 0x6e, 0x67, 0x5f, 0x64, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28,
0x09, 0x48, 0x00, 0x52, 0x11, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x44,
0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x88, 0x01, 0x01, 0x42, 0x15, 0x0a, 0x13, 0x5f, 0x66, 0x6f,
0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x5f, 0x64, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73,
0x22, 0x6e, 0x0a, 0x13, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x52,
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x2a, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x67, 0x65,
0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x53, 0x74, 0x61, 0x67, 0x65, 0x52, 0x06, 0x73, 0x74, 0x61,
0x67, 0x65, 0x73, 0x12, 0x2b, 0x0a, 0x11, 0x66, 0x69, 0x6e, 0x61, 0x6c, 0x5f, 0x64, 0x69, 0x73,
0x70, 0x6f, 0x73, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10,
0x66, 0x69, 0x6e, 0x61, 0x6c, 0x44, 0x69, 0x73, 0x70, 0x6f, 0x73, 0x69, 0x74, 0x69, 0x6f, 0x6e,
0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07,
0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e,
0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12,
0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41,
0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09,
0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41,
0x43, 0x45, 0x10, 0x07, 0x32, 0xdd, 0x09, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53,
0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12,
0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c,
0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b,
0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b,
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c,
0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61,
0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69,
0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55,
0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71,
0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70,
0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74,
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74,
0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61,
0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e,
0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65,
0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e,
0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f,
0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b,
0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x0a, 0x0c, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1b,
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77,
0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x51, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61,
0x0a, 0x0e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b,
0x12, 0x1d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x51, 0x0a, 0x0e, 0x53,
0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1d, 0x2e,
0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74,
0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x64,
0x00, 0x12, 0x53, 0x0a, 0x10, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77,
0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x53,
0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x0a, 0x10, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6b, 0x73, 0x12, 0x1d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65,
0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x74, 0x1a, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63,
0x75, 0x6e, 0x64, 0x6c, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73,
0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64,
0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x6c, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75,
0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b,
0x12, 0x48, 0x0a, 0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e,
0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a,
0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64,
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73,
0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65,
0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c,
0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22,
0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x12,
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61,
0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x43, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65,
0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65,
0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x43, 0x6c, 0x65, 0x61,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e,
0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12,
0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x1a,
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74,
0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x6f, 0x0a, 0x18, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65,
0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52,
0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x6f, 0x0a, 0x18, 0x53, 0x65, 0x74,
0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73,
0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53,
0x74, 0x1a, 0x28, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73,
0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x28,
0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f,
0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x54, 0x72,
0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d,
0x6f, 0x6e, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x54,
0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06,
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
} }
var ( var (
@@ -3016,7 +3394,7 @@ func file_daemon_proto_rawDescGZIP() []byte {
} }
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 1) var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 41) var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 45)
var file_daemon_proto_goTypes = []interface{}{ var file_daemon_proto_goTypes = []interface{}{
(LogLevel)(0), // 0: daemon.LogLevel (LogLevel)(0), // 0: daemon.LogLevel
(*LoginRequest)(nil), // 1: daemon.LoginRequest (*LoginRequest)(nil), // 1: daemon.LoginRequest
@@ -3059,16 +3437,20 @@ var file_daemon_proto_goTypes = []interface{}{
(*DeleteStateResponse)(nil), // 38: daemon.DeleteStateResponse (*DeleteStateResponse)(nil), // 38: daemon.DeleteStateResponse
(*SetNetworkMapPersistenceRequest)(nil), // 39: daemon.SetNetworkMapPersistenceRequest (*SetNetworkMapPersistenceRequest)(nil), // 39: daemon.SetNetworkMapPersistenceRequest
(*SetNetworkMapPersistenceResponse)(nil), // 40: daemon.SetNetworkMapPersistenceResponse (*SetNetworkMapPersistenceResponse)(nil), // 40: daemon.SetNetworkMapPersistenceResponse
nil, // 41: daemon.Network.ResolvedIPsEntry (*TCPFlags)(nil), // 41: daemon.TCPFlags
(*durationpb.Duration)(nil), // 42: google.protobuf.Duration (*TracePacketRequest)(nil), // 42: daemon.TracePacketRequest
(*timestamppb.Timestamp)(nil), // 43: google.protobuf.Timestamp (*TraceStage)(nil), // 43: daemon.TraceStage
(*TracePacketResponse)(nil), // 44: daemon.TracePacketResponse
nil, // 45: daemon.Network.ResolvedIPsEntry
(*durationpb.Duration)(nil), // 46: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 47: google.protobuf.Timestamp
} }
var file_daemon_proto_depIdxs = []int32{ var file_daemon_proto_depIdxs = []int32{
42, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration 46, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
19, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus 19, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
43, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp 47, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
43, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp 47, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
42, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration 46, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration
16, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState 16, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
15, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState 15, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState
14, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState 14, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
@@ -3076,48 +3458,52 @@ var file_daemon_proto_depIdxs = []int32{
17, // 9: daemon.FullStatus.relays:type_name -> daemon.RelayState 17, // 9: daemon.FullStatus.relays:type_name -> daemon.RelayState
18, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState 18, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
25, // 11: daemon.ListNetworksResponse.routes:type_name -> daemon.Network 25, // 11: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
41, // 12: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry 45, // 12: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
0, // 13: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel 0, // 13: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel
0, // 14: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel 0, // 14: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel
32, // 15: daemon.ListStatesResponse.states:type_name -> daemon.State 32, // 15: daemon.ListStatesResponse.states:type_name -> daemon.State
24, // 16: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList 41, // 16: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags
1, // 17: daemon.DaemonService.Login:input_type -> daemon.LoginRequest 43, // 17: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
3, // 18: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest 24, // 18: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
5, // 19: daemon.DaemonService.Up:input_type -> daemon.UpRequest 1, // 19: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
7, // 20: daemon.DaemonService.Status:input_type -> daemon.StatusRequest 3, // 20: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
9, // 21: daemon.DaemonService.Down:input_type -> daemon.DownRequest 5, // 21: daemon.DaemonService.Up:input_type -> daemon.UpRequest
11, // 22: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest 7, // 22: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
20, // 23: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest 9, // 23: daemon.DaemonService.Down:input_type -> daemon.DownRequest
22, // 24: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest 11, // 24: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
22, // 25: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest 20, // 25: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest
26, // 26: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest 22, // 26: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest
28, // 27: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest 22, // 27: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest
30, // 28: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest 26, // 28: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
33, // 29: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest 28, // 29: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
35, // 30: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest 30, // 30: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
37, // 31: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest 33, // 31: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest
39, // 32: daemon.DaemonService.SetNetworkMapPersistence:input_type -> daemon.SetNetworkMapPersistenceRequest 35, // 32: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest
2, // 33: daemon.DaemonService.Login:output_type -> daemon.LoginResponse 37, // 33: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest
4, // 34: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse 39, // 34: daemon.DaemonService.SetNetworkMapPersistence:input_type -> daemon.SetNetworkMapPersistenceRequest
6, // 35: daemon.DaemonService.Up:output_type -> daemon.UpResponse 42, // 35: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest
8, // 36: daemon.DaemonService.Status:output_type -> daemon.StatusResponse 2, // 36: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
10, // 37: daemon.DaemonService.Down:output_type -> daemon.DownResponse 4, // 37: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
12, // 38: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse 6, // 38: daemon.DaemonService.Up:output_type -> daemon.UpResponse
21, // 39: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse 8, // 39: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
23, // 40: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse 10, // 40: daemon.DaemonService.Down:output_type -> daemon.DownResponse
23, // 41: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse 12, // 41: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
27, // 42: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse 21, // 42: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
29, // 43: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse 23, // 43: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
31, // 44: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse 23, // 44: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
34, // 45: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse 27, // 45: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
36, // 46: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse 29, // 46: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
38, // 47: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse 31, // 47: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
40, // 48: daemon.DaemonService.SetNetworkMapPersistence:output_type -> daemon.SetNetworkMapPersistenceResponse 34, // 48: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
33, // [33:49] is the sub-list for method output_type 36, // 49: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
17, // [17:33] is the sub-list for method input_type 38, // 50: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
17, // [17:17] is the sub-list for extension type_name 40, // 51: daemon.DaemonService.SetNetworkMapPersistence:output_type -> daemon.SetNetworkMapPersistenceResponse
17, // [17:17] is the sub-list for extension extendee 44, // 52: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
0, // [0:17] is the sub-list for field type_name 36, // [36:53] is the sub-list for method output_type
19, // [19:36] is the sub-list for method input_type
19, // [19:19] is the sub-list for extension type_name
19, // [19:19] is the sub-list for extension extendee
0, // [0:19] is the sub-list for field type_name
} }
func init() { file_daemon_proto_init() } func init() { file_daemon_proto_init() }
@@ -3606,15 +3992,65 @@ func file_daemon_proto_init() {
return nil return nil
} }
} }
file_daemon_proto_msgTypes[40].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*TCPFlags); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[41].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*TracePacketRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[42].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*TraceStage); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[43].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*TracePacketResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
} }
file_daemon_proto_msgTypes[0].OneofWrappers = []interface{}{} file_daemon_proto_msgTypes[0].OneofWrappers = []interface{}{}
file_daemon_proto_msgTypes[41].OneofWrappers = []interface{}{}
file_daemon_proto_msgTypes[42].OneofWrappers = []interface{}{}
type x struct{} type x struct{}
out := protoimpl.TypeBuilder{ out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{ File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(), GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_daemon_proto_rawDesc, RawDescriptor: file_daemon_proto_rawDesc,
NumEnums: 1, NumEnums: 1,
NumMessages: 41, NumMessages: 45,
NumExtensions: 0, NumExtensions: 0,
NumServices: 1, NumServices: 1,
}, },

View File

@@ -57,6 +57,8 @@ service DaemonService {
// SetNetworkMapPersistence enables or disables network map persistence // SetNetworkMapPersistence enables or disables network map persistence
rpc SetNetworkMapPersistence(SetNetworkMapPersistenceRequest) returns (SetNetworkMapPersistenceResponse) {} rpc SetNetworkMapPersistence(SetNetworkMapPersistenceRequest) returns (SetNetworkMapPersistenceResponse) {}
rpc TracePacket(TracePacketRequest) returns (TracePacketResponse) {}
} }
@@ -356,3 +358,36 @@ message SetNetworkMapPersistenceRequest {
} }
message SetNetworkMapPersistenceResponse {} message SetNetworkMapPersistenceResponse {}
message TCPFlags {
bool syn = 1;
bool ack = 2;
bool fin = 3;
bool rst = 4;
bool psh = 5;
bool urg = 6;
}
message TracePacketRequest {
string source_ip = 1;
string destination_ip = 2;
string protocol = 3;
uint32 source_port = 4;
uint32 destination_port = 5;
string direction = 6;
optional TCPFlags tcp_flags = 7;
optional uint32 icmp_type = 8;
optional uint32 icmp_code = 9;
}
message TraceStage {
string name = 1;
string message = 2;
bool allowed = 3;
optional string forwarding_details = 4;
}
message TracePacketResponse {
repeated TraceStage stages = 1;
bool final_disposition = 2;
}

View File

@@ -51,6 +51,7 @@ type DaemonServiceClient interface {
DeleteState(ctx context.Context, in *DeleteStateRequest, opts ...grpc.CallOption) (*DeleteStateResponse, error) DeleteState(ctx context.Context, in *DeleteStateRequest, opts ...grpc.CallOption) (*DeleteStateResponse, error)
// SetNetworkMapPersistence enables or disables network map persistence // SetNetworkMapPersistence enables or disables network map persistence
SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error) SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error)
TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error)
} }
type daemonServiceClient struct { type daemonServiceClient struct {
@@ -205,6 +206,15 @@ func (c *daemonServiceClient) SetNetworkMapPersistence(ctx context.Context, in *
return out, nil return out, nil
} }
func (c *daemonServiceClient) TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error) {
out := new(TracePacketResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/TracePacket", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// DaemonServiceServer is the server API for DaemonService service. // DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer // All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility // for forward compatibility
@@ -242,6 +252,7 @@ type DaemonServiceServer interface {
DeleteState(context.Context, *DeleteStateRequest) (*DeleteStateResponse, error) DeleteState(context.Context, *DeleteStateRequest) (*DeleteStateResponse, error)
// SetNetworkMapPersistence enables or disables network map persistence // SetNetworkMapPersistence enables or disables network map persistence
SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error) SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error)
TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error)
mustEmbedUnimplementedDaemonServiceServer() mustEmbedUnimplementedDaemonServiceServer()
} }
@@ -297,6 +308,9 @@ func (UnimplementedDaemonServiceServer) DeleteState(context.Context, *DeleteStat
func (UnimplementedDaemonServiceServer) SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error) { func (UnimplementedDaemonServiceServer) SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method SetNetworkMapPersistence not implemented") return nil, status.Errorf(codes.Unimplemented, "method SetNetworkMapPersistence not implemented")
} }
func (UnimplementedDaemonServiceServer) TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method TracePacket not implemented")
}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {} func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service. // UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -598,6 +612,24 @@ func _DaemonService_SetNetworkMapPersistence_Handler(srv interface{}, ctx contex
return interceptor(ctx, in, info, handler) return interceptor(ctx, in, info, handler)
} }
func _DaemonService_TracePacket_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(TracePacketRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).TracePacket(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/TracePacket",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).TracePacket(ctx, req.(*TracePacketRequest))
}
return interceptor(ctx, in, info, handler)
}
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service. // DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
// It's only intended for direct use with grpc.RegisterService, // It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy) // and not to be introspected or modified (even as a copy)
@@ -669,6 +701,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "SetNetworkMapPersistence", MethodName: "SetNetworkMapPersistence",
Handler: _DaemonService_SetNetworkMapPersistence_Handler, Handler: _DaemonService_SetNetworkMapPersistence_Handler,
}, },
{
MethodName: "TracePacket",
Handler: _DaemonService_TracePacket_Handler,
},
}, },
Streams: []grpc.StreamDesc{}, Streams: []grpc.StreamDesc{},
Metadata: "daemon.proto", Metadata: "daemon.proto",

View File

@@ -538,7 +538,24 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (
} }
log.SetLevel(level) log.SetLevel(level)
if s.connectClient == nil {
return nil, fmt.Errorf("connect client not initialized")
}
engine := s.connectClient.Engine()
if engine == nil {
return nil, fmt.Errorf("engine not initialized")
}
fwManager := engine.GetFirewallManager()
if fwManager == nil {
return nil, fmt.Errorf("firewall manager not initialized")
}
fwManager.SetLogLevel(level)
log.Infof("Log level set to %s", level.String()) log.Infof("Log level set to %s", level.String())
return &proto.SetLogLevelResponse{}, nil return &proto.SetLogLevelResponse{}, nil
} }

123
client/server/trace.go Normal file
View File

@@ -0,0 +1,123 @@
package server
import (
"context"
"fmt"
"net"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/proto"
)
type packetTracer interface {
TracePacketFromBuilder(builder *uspfilter.PacketBuilder) (*uspfilter.PacketTrace, error)
}
func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (*proto.TracePacketResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.connectClient == nil {
return nil, fmt.Errorf("connect client not initialized")
}
engine := s.connectClient.Engine()
if engine == nil {
return nil, fmt.Errorf("engine not initialized")
}
fwManager := engine.GetFirewallManager()
if fwManager == nil {
return nil, fmt.Errorf("firewall manager not initialized")
}
tracer, ok := fwManager.(packetTracer)
if !ok {
return nil, fmt.Errorf("firewall manager does not support packet tracing")
}
srcIP := net.ParseIP(req.GetSourceIp())
if req.GetSourceIp() == "self" {
srcIP = engine.GetWgAddr()
}
dstIP := net.ParseIP(req.GetDestinationIp())
if req.GetDestinationIp() == "self" {
dstIP = engine.GetWgAddr()
}
if srcIP == nil || dstIP == nil {
return nil, fmt.Errorf("invalid IP address")
}
var tcpState *uspfilter.TCPState
if flags := req.GetTcpFlags(); flags != nil {
tcpState = &uspfilter.TCPState{
SYN: flags.GetSyn(),
ACK: flags.GetAck(),
FIN: flags.GetFin(),
RST: flags.GetRst(),
PSH: flags.GetPsh(),
URG: flags.GetUrg(),
}
}
var dir fw.RuleDirection
switch req.GetDirection() {
case "in":
dir = fw.RuleDirectionIN
case "out":
dir = fw.RuleDirectionOUT
default:
return nil, fmt.Errorf("invalid direction")
}
var protocol fw.Protocol
switch req.GetProtocol() {
case "tcp":
protocol = fw.ProtocolTCP
case "udp":
protocol = fw.ProtocolUDP
case "icmp":
protocol = fw.ProtocolICMP
default:
return nil, fmt.Errorf("invalid protocolcol")
}
builder := &uspfilter.PacketBuilder{
SrcIP: srcIP,
DstIP: dstIP,
Protocol: protocol,
SrcPort: uint16(req.GetSourcePort()),
DstPort: uint16(req.GetDestinationPort()),
Direction: dir,
TCPState: tcpState,
ICMPType: uint8(req.GetIcmpType()),
ICMPCode: uint8(req.GetIcmpCode()),
}
trace, err := tracer.TracePacketFromBuilder(builder)
if err != nil {
return nil, fmt.Errorf("trace packet: %w", err)
}
resp := &proto.TracePacketResponse{}
for _, result := range trace.Results {
stage := &proto.TraceStage{
Name: result.Stage.String(),
Message: result.Message,
Allowed: result.Allowed,
}
if result.ForwarderAction != nil {
details := fmt.Sprintf("%s to %s", result.ForwarderAction.Action, result.ForwarderAction.RemoteAddr)
stage.ForwardingDetails = &details
}
resp.Stages = append(resp.Stages, stage)
}
if len(trace.Results) > 0 {
resp.FinalDisposition = trace.Results[len(trace.Results)-1].Allowed
}
return resp, nil
}

4
go.mod
View File

@@ -102,6 +102,7 @@ require (
gorm.io/driver/postgres v1.5.7 gorm.io/driver/postgres v1.5.7
gorm.io/driver/sqlite v1.5.7 gorm.io/driver/sqlite v1.5.7
gorm.io/gorm v1.25.12 gorm.io/gorm v1.25.12
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1
) )
require ( require (
@@ -237,7 +238,6 @@ require (
gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect
gvisor.dev/gvisor v0.0.0-20231020174304-db3d49b921f9 // indirect
k8s.io/apimachinery v0.26.2 // indirect k8s.io/apimachinery v0.26.2 // indirect
) )
@@ -245,7 +245,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6

8
go.sum
View File

@@ -535,8 +535,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 h1:Pu/7EukijT09ynHUOzQYW7cC3M/BKU8O4qyN/TvTGoY= github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6 h1:X5h5QgP7uHAv78FWgHV8+WYLjHxK9v3ilkVXT1cpCrQ=
github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM= github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4= github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
@@ -1250,8 +1250,8 @@ gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY= gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY=
gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
gvisor.dev/gvisor v0.0.0-20231020174304-db3d49b921f9 h1:sCEaoA7ZmkuFwa2IR61pl4+RYZPwCJOiaSYT0k+BRf8= gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 h1:qDCwdCWECGnwQSQC01Dpnp09fRHxJs9PbktotUqG+hs=
gvisor.dev/gvisor v0.0.0-20231020174304-db3d49b921f9/go.mod h1:8hmigyCdYtw5xJGfQDJzSH5Ju8XEIDBnpyi8+O6GRt8= gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1/go.mod h1:8hmigyCdYtw5xJGfQDJzSH5Ju8XEIDBnpyi8+O6GRt8=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=

View File

@@ -1,4 +1,7 @@
package rest //go:build integration
// +build integration
package rest_test
import ( import (
"context" "context"
@@ -7,10 +10,12 @@ import (
"net/http" "net/http"
"testing" "testing"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
) )
var ( var (
@@ -33,7 +38,7 @@ var (
) )
func TestAccounts_List_200(t *testing.T) { func TestAccounts_List_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/accounts", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/accounts", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.Account{testAccount}) retBytes, _ := json.Marshal([]api.Account{testAccount})
_, err := w.Write(retBytes) _, err := w.Write(retBytes)
@@ -47,7 +52,7 @@ func TestAccounts_List_200(t *testing.T) {
} }
func TestAccounts_List_Err(t *testing.T) { func TestAccounts_List_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/accounts", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/accounts", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -62,7 +67,7 @@ func TestAccounts_List_Err(t *testing.T) {
} }
func TestAccounts_Update_200(t *testing.T) { func TestAccounts_Update_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/accounts/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/accounts/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method) assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body) reqBytes, err := io.ReadAll(r.Body)
@@ -87,7 +92,7 @@ func TestAccounts_Update_200(t *testing.T) {
} }
func TestAccounts_Update_Err(t *testing.T) { func TestAccounts_Update_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/accounts/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/accounts/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -106,7 +111,7 @@ func TestAccounts_Update_Err(t *testing.T) {
} }
func TestAccounts_Delete_200(t *testing.T) { func TestAccounts_Delete_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/accounts/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/accounts/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method) assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200) w.WriteHeader(200)
@@ -117,7 +122,7 @@ func TestAccounts_Delete_200(t *testing.T) {
} }
func TestAccounts_Delete_Err(t *testing.T) { func TestAccounts_Delete_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/accounts/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/accounts/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404) w.WriteHeader(404)
@@ -131,7 +136,7 @@ func TestAccounts_Delete_Err(t *testing.T) {
} }
func TestAccounts_Integration_List(t *testing.T) { func TestAccounts_Integration_List(t *testing.T) {
withBlackBoxServer(t, func(c *Client) { withBlackBoxServer(t, func(c *rest.Client) {
accounts, err := c.Accounts.List(context.Background()) accounts, err := c.Accounts.List(context.Background())
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, accounts, 1) assert.Len(t, accounts, 1)
@@ -141,7 +146,7 @@ func TestAccounts_Integration_List(t *testing.T) {
} }
func TestAccounts_Integration_Update(t *testing.T) { func TestAccounts_Integration_Update(t *testing.T) {
withBlackBoxServer(t, func(c *Client) { withBlackBoxServer(t, func(c *rest.Client) {
accounts, err := c.Accounts.List(context.Background()) accounts, err := c.Accounts.List(context.Background())
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, accounts, 1) assert.Len(t, accounts, 1)
@@ -157,7 +162,7 @@ func TestAccounts_Integration_Update(t *testing.T) {
// Account deletion on MySQL and PostgreSQL databases causes unknown errors // Account deletion on MySQL and PostgreSQL databases causes unknown errors
// func TestAccounts_Integration_Delete(t *testing.T) { // func TestAccounts_Integration_Delete(t *testing.T) {
// withBlackBoxServer(t, func(c *Client) { // withBlackBoxServer(t, func(c *rest.Client) {
// accounts, err := c.Accounts.List(context.Background()) // accounts, err := c.Accounts.List(context.Background())
// require.NoError(t, err) // require.NoError(t, err)
// assert.Len(t, accounts, 1) // assert.Len(t, accounts, 1)

View File

@@ -1,18 +1,22 @@
package rest //go:build integration
// +build integration
package rest_test
import ( import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/netbirdio/netbird/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
) )
func withMockClient(callback func(*Client, *http.ServeMux)) { func withMockClient(callback func(*rest.Client, *http.ServeMux)) {
mux := &http.ServeMux{} mux := &http.ServeMux{}
server := httptest.NewServer(mux) server := httptest.NewServer(mux)
defer server.Close() defer server.Close()
c := New(server.URL, "ABC") c := rest.New(server.URL, "ABC")
callback(c, mux) callback(c, mux)
} }
@@ -20,11 +24,11 @@ func ptr[T any, PT *T](x T) PT {
return &x return &x
} }
func withBlackBoxServer(t *testing.T, callback func(*Client)) { func withBlackBoxServer(t *testing.T, callback func(*rest.Client)) {
t.Helper() t.Helper()
handler, _, _ := testing_tools.BuildApiBlackBoxWithDBState(t, "../../server/testdata/store.sql", nil, false) handler, _, _ := testing_tools.BuildApiBlackBoxWithDBState(t, "../../server/testdata/store.sql", nil, false)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
c := New(server.URL, "nbp_apTmlmUXHSC4PKmHwtIZNaGr8eqcVI2gMURp") c := rest.New(server.URL, "nbp_apTmlmUXHSC4PKmHwtIZNaGr8eqcVI2gMURp")
callback(c) callback(c)
} }

View File

@@ -1,4 +1,7 @@
package rest //go:build integration
// +build integration
package rest_test
import ( import (
"context" "context"
@@ -7,10 +10,12 @@ import (
"net/http" "net/http"
"testing" "testing"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
) )
var ( var (
@@ -25,7 +30,7 @@ var (
) )
func TestDNSNameserverGroup_List_200(t *testing.T) { func TestDNSNameserverGroup_List_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/nameservers", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/dns/nameservers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.NameserverGroup{testNameserverGroup}) retBytes, _ := json.Marshal([]api.NameserverGroup{testNameserverGroup})
_, err := w.Write(retBytes) _, err := w.Write(retBytes)
@@ -39,7 +44,7 @@ func TestDNSNameserverGroup_List_200(t *testing.T) {
} }
func TestDNSNameserverGroup_List_Err(t *testing.T) { func TestDNSNameserverGroup_List_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/nameservers", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/dns/nameservers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -54,7 +59,7 @@ func TestDNSNameserverGroup_List_Err(t *testing.T) {
} }
func TestDNSNameserverGroup_Get_200(t *testing.T) { func TestDNSNameserverGroup_Get_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testNameserverGroup) retBytes, _ := json.Marshal(testNameserverGroup)
_, err := w.Write(retBytes) _, err := w.Write(retBytes)
@@ -67,7 +72,7 @@ func TestDNSNameserverGroup_Get_200(t *testing.T) {
} }
func TestDNSNameserverGroup_Get_Err(t *testing.T) { func TestDNSNameserverGroup_Get_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -82,7 +87,7 @@ func TestDNSNameserverGroup_Get_Err(t *testing.T) {
} }
func TestDNSNameserverGroup_Create_200(t *testing.T) { func TestDNSNameserverGroup_Create_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/nameservers", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/dns/nameservers", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method) assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body) reqBytes, err := io.ReadAll(r.Body)
@@ -104,7 +109,7 @@ func TestDNSNameserverGroup_Create_200(t *testing.T) {
} }
func TestDNSNameserverGroup_Create_Err(t *testing.T) { func TestDNSNameserverGroup_Create_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/nameservers", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/dns/nameservers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -121,7 +126,7 @@ func TestDNSNameserverGroup_Create_Err(t *testing.T) {
} }
func TestDNSNameserverGroup_Update_200(t *testing.T) { func TestDNSNameserverGroup_Update_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method) assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body) reqBytes, err := io.ReadAll(r.Body)
@@ -143,7 +148,7 @@ func TestDNSNameserverGroup_Update_200(t *testing.T) {
} }
func TestDNSNameserverGroup_Update_Err(t *testing.T) { func TestDNSNameserverGroup_Update_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -160,7 +165,7 @@ func TestDNSNameserverGroup_Update_Err(t *testing.T) {
} }
func TestDNSNameserverGroup_Delete_200(t *testing.T) { func TestDNSNameserverGroup_Delete_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method) assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200) w.WriteHeader(200)
@@ -171,7 +176,7 @@ func TestDNSNameserverGroup_Delete_200(t *testing.T) {
} }
func TestDNSNameserverGroup_Delete_Err(t *testing.T) { func TestDNSNameserverGroup_Delete_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404) w.WriteHeader(404)
@@ -185,7 +190,7 @@ func TestDNSNameserverGroup_Delete_Err(t *testing.T) {
} }
func TestDNSSettings_Get_200(t *testing.T) { func TestDNSSettings_Get_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/settings", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/dns/settings", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testSettings) retBytes, _ := json.Marshal(testSettings)
_, err := w.Write(retBytes) _, err := w.Write(retBytes)
@@ -198,7 +203,7 @@ func TestDNSSettings_Get_200(t *testing.T) {
} }
func TestDNSSettings_Get_Err(t *testing.T) { func TestDNSSettings_Get_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/settings", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/dns/settings", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -213,7 +218,7 @@ func TestDNSSettings_Get_Err(t *testing.T) {
} }
func TestDNSSettings_Update_200(t *testing.T) { func TestDNSSettings_Update_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/settings", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/dns/settings", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method) assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body) reqBytes, err := io.ReadAll(r.Body)
@@ -235,7 +240,7 @@ func TestDNSSettings_Update_200(t *testing.T) {
} }
func TestDNSSettings_Update_Err(t *testing.T) { func TestDNSSettings_Update_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/settings", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/dns/settings", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -267,7 +272,7 @@ func TestDNS_Integration(t *testing.T) {
Primary: true, Primary: true,
SearchDomainsEnabled: false, SearchDomainsEnabled: false,
} }
withBlackBoxServer(t, func(c *Client) { withBlackBoxServer(t, func(c *rest.Client) {
// Create // Create
nsGroup, err := c.DNS.CreateNameserverGroup(context.Background(), nsGroupReq) nsGroup, err := c.DNS.CreateNameserverGroup(context.Background(), nsGroupReq)
require.NoError(t, err) require.NoError(t, err)

View File

@@ -1,4 +1,7 @@
package rest //go:build integration
// +build integration
package rest_test
import ( import (
"context" "context"
@@ -6,10 +9,12 @@ import (
"net/http" "net/http"
"testing" "testing"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
) )
var ( var (
@@ -20,7 +25,7 @@ var (
) )
func TestEvents_List_200(t *testing.T) { func TestEvents_List_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/events", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/events", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.Event{testEvent}) retBytes, _ := json.Marshal([]api.Event{testEvent})
_, err := w.Write(retBytes) _, err := w.Write(retBytes)
@@ -34,7 +39,7 @@ func TestEvents_List_200(t *testing.T) {
} }
func TestEvents_List_Err(t *testing.T) { func TestEvents_List_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/events", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/events", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -49,7 +54,7 @@ func TestEvents_List_Err(t *testing.T) {
} }
func TestEvents_Integration(t *testing.T) { func TestEvents_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *Client) { withBlackBoxServer(t, func(c *rest.Client) {
// Do something that would trigger any event // Do something that would trigger any event
_, err := c.SetupKeys.Create(context.Background(), api.CreateSetupKeyRequest{ _, err := c.SetupKeys.Create(context.Background(), api.CreateSetupKeyRequest{
Ephemeral: ptr(true), Ephemeral: ptr(true),

View File

@@ -1,4 +1,7 @@
package rest //go:build integration
// +build integration
package rest_test
import ( import (
"context" "context"
@@ -6,10 +9,12 @@ import (
"net/http" "net/http"
"testing" "testing"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
) )
var ( var (
@@ -25,7 +30,7 @@ var (
) )
func TestGeo_ListCountries_200(t *testing.T) { func TestGeo_ListCountries_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/locations/countries", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/locations/countries", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.Country{testCountry}) retBytes, _ := json.Marshal([]api.Country{testCountry})
_, err := w.Write(retBytes) _, err := w.Write(retBytes)
@@ -39,7 +44,7 @@ func TestGeo_ListCountries_200(t *testing.T) {
} }
func TestGeo_ListCountries_Err(t *testing.T) { func TestGeo_ListCountries_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/locations/countries", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/locations/countries", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -54,7 +59,7 @@ func TestGeo_ListCountries_Err(t *testing.T) {
} }
func TestGeo_ListCountryCities_200(t *testing.T) { func TestGeo_ListCountryCities_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/locations/countries/Test/cities", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/locations/countries/Test/cities", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.City{testCity}) retBytes, _ := json.Marshal([]api.City{testCity})
_, err := w.Write(retBytes) _, err := w.Write(retBytes)
@@ -68,7 +73,7 @@ func TestGeo_ListCountryCities_200(t *testing.T) {
} }
func TestGeo_ListCountryCities_Err(t *testing.T) { func TestGeo_ListCountryCities_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/locations/countries/Test/cities", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/locations/countries/Test/cities", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -84,7 +89,7 @@ func TestGeo_ListCountryCities_Err(t *testing.T) {
func TestGeo_Integration(t *testing.T) { func TestGeo_Integration(t *testing.T) {
// Blackbox is initialized with empty GeoLocations // Blackbox is initialized with empty GeoLocations
withBlackBoxServer(t, func(c *Client) { withBlackBoxServer(t, func(c *rest.Client) {
countries, err := c.GeoLocation.ListCountries(context.Background()) countries, err := c.GeoLocation.ListCountries(context.Background())
require.NoError(t, err) require.NoError(t, err)
assert.Empty(t, countries) assert.Empty(t, countries)

View File

@@ -1,4 +1,7 @@
package rest //go:build integration
// +build integration
package rest_test
import ( import (
"context" "context"
@@ -7,10 +10,12 @@ import (
"net/http" "net/http"
"testing" "testing"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
) )
var ( var (
@@ -22,7 +27,7 @@ var (
) )
func TestGroups_List_200(t *testing.T) { func TestGroups_List_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/groups", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/groups", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.Group{testGroup}) retBytes, _ := json.Marshal([]api.Group{testGroup})
_, err := w.Write(retBytes) _, err := w.Write(retBytes)
@@ -36,7 +41,7 @@ func TestGroups_List_200(t *testing.T) {
} }
func TestGroups_List_Err(t *testing.T) { func TestGroups_List_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/groups", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/groups", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -51,7 +56,7 @@ func TestGroups_List_Err(t *testing.T) {
} }
func TestGroups_Get_200(t *testing.T) { func TestGroups_Get_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testGroup) retBytes, _ := json.Marshal(testGroup)
_, err := w.Write(retBytes) _, err := w.Write(retBytes)
@@ -64,7 +69,7 @@ func TestGroups_Get_200(t *testing.T) {
} }
func TestGroups_Get_Err(t *testing.T) { func TestGroups_Get_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -79,7 +84,7 @@ func TestGroups_Get_Err(t *testing.T) {
} }
func TestGroups_Create_200(t *testing.T) { func TestGroups_Create_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/groups", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/groups", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method) assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body) reqBytes, err := io.ReadAll(r.Body)
@@ -101,7 +106,7 @@ func TestGroups_Create_200(t *testing.T) {
} }
func TestGroups_Create_Err(t *testing.T) { func TestGroups_Create_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/groups", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/groups", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -118,7 +123,7 @@ func TestGroups_Create_Err(t *testing.T) {
} }
func TestGroups_Update_200(t *testing.T) { func TestGroups_Update_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method) assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body) reqBytes, err := io.ReadAll(r.Body)
@@ -140,7 +145,7 @@ func TestGroups_Update_200(t *testing.T) {
} }
func TestGroups_Update_Err(t *testing.T) { func TestGroups_Update_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -157,7 +162,7 @@ func TestGroups_Update_Err(t *testing.T) {
} }
func TestGroups_Delete_200(t *testing.T) { func TestGroups_Delete_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method) assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200) w.WriteHeader(200)
@@ -168,7 +173,7 @@ func TestGroups_Delete_200(t *testing.T) {
} }
func TestGroups_Delete_Err(t *testing.T) { func TestGroups_Delete_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404) w.WriteHeader(404)
@@ -182,7 +187,7 @@ func TestGroups_Delete_Err(t *testing.T) {
} }
func TestGroups_Integration(t *testing.T) { func TestGroups_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *Client) { withBlackBoxServer(t, func(c *rest.Client) {
groups, err := c.Groups.List(context.Background()) groups, err := c.Groups.List(context.Background())
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, groups, 1) assert.Len(t, groups, 1)

View File

@@ -1,4 +1,7 @@
package rest //go:build integration
// +build integration
package rest_test
import ( import (
"context" "context"
@@ -7,10 +10,12 @@ import (
"net/http" "net/http"
"testing" "testing"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
) )
var ( var (
@@ -30,7 +35,7 @@ var (
) )
func TestNetworks_List_200(t *testing.T) { func TestNetworks_List_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/networks", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.Network{testNetwork}) retBytes, _ := json.Marshal([]api.Network{testNetwork})
_, err := w.Write(retBytes) _, err := w.Write(retBytes)
@@ -44,7 +49,7 @@ func TestNetworks_List_200(t *testing.T) {
} }
func TestNetworks_List_Err(t *testing.T) { func TestNetworks_List_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/networks", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -59,7 +64,7 @@ func TestNetworks_List_Err(t *testing.T) {
} }
func TestNetworks_Get_200(t *testing.T) { func TestNetworks_Get_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testNetwork) retBytes, _ := json.Marshal(testNetwork)
_, err := w.Write(retBytes) _, err := w.Write(retBytes)
@@ -72,7 +77,7 @@ func TestNetworks_Get_200(t *testing.T) {
} }
func TestNetworks_Get_Err(t *testing.T) { func TestNetworks_Get_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -87,7 +92,7 @@ func TestNetworks_Get_Err(t *testing.T) {
} }
func TestNetworks_Create_200(t *testing.T) { func TestNetworks_Create_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/networks", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method) assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body) reqBytes, err := io.ReadAll(r.Body)
@@ -109,7 +114,7 @@ func TestNetworks_Create_200(t *testing.T) {
} }
func TestNetworks_Create_Err(t *testing.T) { func TestNetworks_Create_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/networks", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -126,7 +131,7 @@ func TestNetworks_Create_Err(t *testing.T) {
} }
func TestNetworks_Update_200(t *testing.T) { func TestNetworks_Update_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method) assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body) reqBytes, err := io.ReadAll(r.Body)
@@ -148,7 +153,7 @@ func TestNetworks_Update_200(t *testing.T) {
} }
func TestNetworks_Update_Err(t *testing.T) { func TestNetworks_Update_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -165,7 +170,7 @@ func TestNetworks_Update_Err(t *testing.T) {
} }
func TestNetworks_Delete_200(t *testing.T) { func TestNetworks_Delete_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method) assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200) w.WriteHeader(200)
@@ -176,7 +181,7 @@ func TestNetworks_Delete_200(t *testing.T) {
} }
func TestNetworks_Delete_Err(t *testing.T) { func TestNetworks_Delete_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404) w.WriteHeader(404)
@@ -190,7 +195,7 @@ func TestNetworks_Delete_Err(t *testing.T) {
} }
func TestNetworks_Integration(t *testing.T) { func TestNetworks_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *Client) { withBlackBoxServer(t, func(c *rest.Client) {
network, err := c.Networks.Create(context.Background(), api.NetworkRequest{ network, err := c.Networks.Create(context.Background(), api.NetworkRequest{
Description: ptr("TestNetwork"), Description: ptr("TestNetwork"),
Name: "Test", Name: "Test",
@@ -216,7 +221,7 @@ func TestNetworks_Integration(t *testing.T) {
} }
func TestNetworkResources_List_200(t *testing.T) { func TestNetworkResources_List_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.NetworkResource{testNetworkResource}) retBytes, _ := json.Marshal([]api.NetworkResource{testNetworkResource})
_, err := w.Write(retBytes) _, err := w.Write(retBytes)
@@ -230,7 +235,7 @@ func TestNetworkResources_List_200(t *testing.T) {
} }
func TestNetworkResources_List_Err(t *testing.T) { func TestNetworkResources_List_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -245,7 +250,7 @@ func TestNetworkResources_List_Err(t *testing.T) {
} }
func TestNetworkResources_Get_200(t *testing.T) { func TestNetworkResources_Get_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testNetworkResource) retBytes, _ := json.Marshal(testNetworkResource)
_, err := w.Write(retBytes) _, err := w.Write(retBytes)
@@ -258,7 +263,7 @@ func TestNetworkResources_Get_200(t *testing.T) {
} }
func TestNetworkResources_Get_Err(t *testing.T) { func TestNetworkResources_Get_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -273,7 +278,7 @@ func TestNetworkResources_Get_Err(t *testing.T) {
} }
func TestNetworkResources_Create_200(t *testing.T) { func TestNetworkResources_Create_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method) assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body) reqBytes, err := io.ReadAll(r.Body)
@@ -295,7 +300,7 @@ func TestNetworkResources_Create_200(t *testing.T) {
} }
func TestNetworkResources_Create_Err(t *testing.T) { func TestNetworkResources_Create_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -312,7 +317,7 @@ func TestNetworkResources_Create_Err(t *testing.T) {
} }
func TestNetworkResources_Update_200(t *testing.T) { func TestNetworkResources_Update_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method) assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body) reqBytes, err := io.ReadAll(r.Body)
@@ -334,7 +339,7 @@ func TestNetworkResources_Update_200(t *testing.T) {
} }
func TestNetworkResources_Update_Err(t *testing.T) { func TestNetworkResources_Update_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -351,7 +356,7 @@ func TestNetworkResources_Update_Err(t *testing.T) {
} }
func TestNetworkResources_Delete_200(t *testing.T) { func TestNetworkResources_Delete_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method) assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200) w.WriteHeader(200)
@@ -362,7 +367,7 @@ func TestNetworkResources_Delete_200(t *testing.T) {
} }
func TestNetworkResources_Delete_Err(t *testing.T) { func TestNetworkResources_Delete_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404) w.WriteHeader(404)
@@ -376,7 +381,7 @@ func TestNetworkResources_Delete_Err(t *testing.T) {
} }
func TestNetworkResources_Integration(t *testing.T) { func TestNetworkResources_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *Client) { withBlackBoxServer(t, func(c *rest.Client) {
_, err := c.Networks.Resources("TestNetwork").Create(context.Background(), api.NetworkResourceRequest{ _, err := c.Networks.Resources("TestNetwork").Create(context.Background(), api.NetworkResourceRequest{
Address: "test.com", Address: "test.com",
Description: ptr("Description"), Description: ptr("Description"),
@@ -403,7 +408,7 @@ func TestNetworkResources_Integration(t *testing.T) {
} }
func TestNetworkRouters_List_200(t *testing.T) { func TestNetworkRouters_List_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/routers", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/networks/Meow/routers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.NetworkRouter{testNetworkRouter}) retBytes, _ := json.Marshal([]api.NetworkRouter{testNetworkRouter})
_, err := w.Write(retBytes) _, err := w.Write(retBytes)
@@ -417,7 +422,7 @@ func TestNetworkRouters_List_200(t *testing.T) {
} }
func TestNetworkRouters_List_Err(t *testing.T) { func TestNetworkRouters_List_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/routers", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/networks/Meow/routers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -432,7 +437,7 @@ func TestNetworkRouters_List_Err(t *testing.T) {
} }
func TestNetworkRouters_Get_200(t *testing.T) { func TestNetworkRouters_Get_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testNetworkRouter) retBytes, _ := json.Marshal(testNetworkRouter)
_, err := w.Write(retBytes) _, err := w.Write(retBytes)
@@ -445,7 +450,7 @@ func TestNetworkRouters_Get_200(t *testing.T) {
} }
func TestNetworkRouters_Get_Err(t *testing.T) { func TestNetworkRouters_Get_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -460,7 +465,7 @@ func TestNetworkRouters_Get_Err(t *testing.T) {
} }
func TestNetworkRouters_Create_200(t *testing.T) { func TestNetworkRouters_Create_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/routers", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/networks/Meow/routers", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method) assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body) reqBytes, err := io.ReadAll(r.Body)
@@ -482,7 +487,7 @@ func TestNetworkRouters_Create_200(t *testing.T) {
} }
func TestNetworkRouters_Create_Err(t *testing.T) { func TestNetworkRouters_Create_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/routers", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/networks/Meow/routers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -499,7 +504,7 @@ func TestNetworkRouters_Create_Err(t *testing.T) {
} }
func TestNetworkRouters_Update_200(t *testing.T) { func TestNetworkRouters_Update_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method) assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body) reqBytes, err := io.ReadAll(r.Body)
@@ -521,7 +526,7 @@ func TestNetworkRouters_Update_200(t *testing.T) {
} }
func TestNetworkRouters_Update_Err(t *testing.T) { func TestNetworkRouters_Update_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -538,7 +543,7 @@ func TestNetworkRouters_Update_Err(t *testing.T) {
} }
func TestNetworkRouters_Delete_200(t *testing.T) { func TestNetworkRouters_Delete_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method) assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200) w.WriteHeader(200)
@@ -549,7 +554,7 @@ func TestNetworkRouters_Delete_200(t *testing.T) {
} }
func TestNetworkRouters_Delete_Err(t *testing.T) { func TestNetworkRouters_Delete_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404) w.WriteHeader(404)
@@ -563,7 +568,7 @@ func TestNetworkRouters_Delete_Err(t *testing.T) {
} }
func TestNetworkRouters_Integration(t *testing.T) { func TestNetworkRouters_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *Client) { withBlackBoxServer(t, func(c *rest.Client) {
_, err := c.Networks.Routers("TestNetwork").Create(context.Background(), api.NetworkRouterRequest{ _, err := c.Networks.Routers("TestNetwork").Create(context.Background(), api.NetworkRouterRequest{
Enabled: false, Enabled: false,
Masquerade: false, Masquerade: false,

View File

@@ -1,4 +1,7 @@
package rest //go:build integration
// +build integration
package rest_test
import ( import (
"context" "context"
@@ -7,10 +10,12 @@ import (
"net/http" "net/http"
"testing" "testing"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
) )
var ( var (
@@ -24,7 +29,7 @@ var (
) )
func TestPeers_List_200(t *testing.T) { func TestPeers_List_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/peers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.Peer{testPeer}) retBytes, _ := json.Marshal([]api.Peer{testPeer})
_, err := w.Write(retBytes) _, err := w.Write(retBytes)
@@ -38,7 +43,7 @@ func TestPeers_List_200(t *testing.T) {
} }
func TestPeers_List_Err(t *testing.T) { func TestPeers_List_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/peers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -53,7 +58,7 @@ func TestPeers_List_Err(t *testing.T) {
} }
func TestPeers_Get_200(t *testing.T) { func TestPeers_Get_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testPeer) retBytes, _ := json.Marshal(testPeer)
_, err := w.Write(retBytes) _, err := w.Write(retBytes)
@@ -66,7 +71,7 @@ func TestPeers_Get_200(t *testing.T) {
} }
func TestPeers_Get_Err(t *testing.T) { func TestPeers_Get_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -81,7 +86,7 @@ func TestPeers_Get_Err(t *testing.T) {
} }
func TestPeers_Update_200(t *testing.T) { func TestPeers_Update_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method) assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body) reqBytes, err := io.ReadAll(r.Body)
@@ -103,7 +108,7 @@ func TestPeers_Update_200(t *testing.T) {
} }
func TestPeers_Update_Err(t *testing.T) { func TestPeers_Update_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -120,7 +125,7 @@ func TestPeers_Update_Err(t *testing.T) {
} }
func TestPeers_Delete_200(t *testing.T) { func TestPeers_Delete_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method) assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200) w.WriteHeader(200)
@@ -131,7 +136,7 @@ func TestPeers_Delete_200(t *testing.T) {
} }
func TestPeers_Delete_Err(t *testing.T) { func TestPeers_Delete_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404) w.WriteHeader(404)
@@ -145,7 +150,7 @@ func TestPeers_Delete_Err(t *testing.T) {
} }
func TestPeers_ListAccessiblePeers_200(t *testing.T) { func TestPeers_ListAccessiblePeers_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/accessible-peers", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/peers/Test/accessible-peers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.Peer{testPeer}) retBytes, _ := json.Marshal([]api.Peer{testPeer})
_, err := w.Write(retBytes) _, err := w.Write(retBytes)
@@ -159,7 +164,7 @@ func TestPeers_ListAccessiblePeers_200(t *testing.T) {
} }
func TestPeers_ListAccessiblePeers_Err(t *testing.T) { func TestPeers_ListAccessiblePeers_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/accessible-peers", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/peers/Test/accessible-peers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -174,7 +179,7 @@ func TestPeers_ListAccessiblePeers_Err(t *testing.T) {
} }
func TestPeers_Integration(t *testing.T) { func TestPeers_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *Client) { withBlackBoxServer(t, func(c *rest.Client) {
peers, err := c.Peers.List(context.Background()) peers, err := c.Peers.List(context.Background())
require.NoError(t, err) require.NoError(t, err)
require.NotEmpty(t, peers) require.NotEmpty(t, peers)

View File

@@ -1,4 +1,7 @@
package rest //go:build integration
// +build integration
package rest_test
import ( import (
"context" "context"
@@ -7,10 +10,12 @@ import (
"net/http" "net/http"
"testing" "testing"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
) )
var ( var (
@@ -22,7 +27,7 @@ var (
) )
func TestPolicies_List_200(t *testing.T) { func TestPolicies_List_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/policies", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/policies", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.Policy{testPolicy}) retBytes, _ := json.Marshal([]api.Policy{testPolicy})
_, err := w.Write(retBytes) _, err := w.Write(retBytes)
@@ -36,7 +41,7 @@ func TestPolicies_List_200(t *testing.T) {
} }
func TestPolicies_List_Err(t *testing.T) { func TestPolicies_List_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/policies", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/policies", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -51,7 +56,7 @@ func TestPolicies_List_Err(t *testing.T) {
} }
func TestPolicies_Get_200(t *testing.T) { func TestPolicies_Get_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testPolicy) retBytes, _ := json.Marshal(testPolicy)
_, err := w.Write(retBytes) _, err := w.Write(retBytes)
@@ -64,7 +69,7 @@ func TestPolicies_Get_200(t *testing.T) {
} }
func TestPolicies_Get_Err(t *testing.T) { func TestPolicies_Get_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -79,7 +84,7 @@ func TestPolicies_Get_Err(t *testing.T) {
} }
func TestPolicies_Create_200(t *testing.T) { func TestPolicies_Create_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/policies", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/policies", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method) assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body) reqBytes, err := io.ReadAll(r.Body)
@@ -101,7 +106,7 @@ func TestPolicies_Create_200(t *testing.T) {
} }
func TestPolicies_Create_Err(t *testing.T) { func TestPolicies_Create_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/policies", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/policies", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -118,7 +123,7 @@ func TestPolicies_Create_Err(t *testing.T) {
} }
func TestPolicies_Update_200(t *testing.T) { func TestPolicies_Update_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method) assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body) reqBytes, err := io.ReadAll(r.Body)
@@ -140,7 +145,7 @@ func TestPolicies_Update_200(t *testing.T) {
} }
func TestPolicies_Update_Err(t *testing.T) { func TestPolicies_Update_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -157,7 +162,7 @@ func TestPolicies_Update_Err(t *testing.T) {
} }
func TestPolicies_Delete_200(t *testing.T) { func TestPolicies_Delete_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method) assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200) w.WriteHeader(200)
@@ -168,7 +173,7 @@ func TestPolicies_Delete_200(t *testing.T) {
} }
func TestPolicies_Delete_Err(t *testing.T) { func TestPolicies_Delete_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404) w.WriteHeader(404)
@@ -182,7 +187,7 @@ func TestPolicies_Delete_Err(t *testing.T) {
} }
func TestPolicies_Integration(t *testing.T) { func TestPolicies_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *Client) { withBlackBoxServer(t, func(c *rest.Client) {
policies, err := c.Policies.List(context.Background()) policies, err := c.Policies.List(context.Background())
require.NoError(t, err) require.NoError(t, err)
require.NotEmpty(t, policies) require.NotEmpty(t, policies)

View File

@@ -1,4 +1,7 @@
package rest //go:build integration
// +build integration
package rest_test
import ( import (
"context" "context"
@@ -7,10 +10,12 @@ import (
"net/http" "net/http"
"testing" "testing"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
) )
var ( var (
@@ -21,7 +26,7 @@ var (
) )
func TestPostureChecks_List_200(t *testing.T) { func TestPostureChecks_List_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/posture-checks", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/posture-checks", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.PostureCheck{testPostureCheck}) retBytes, _ := json.Marshal([]api.PostureCheck{testPostureCheck})
_, err := w.Write(retBytes) _, err := w.Write(retBytes)
@@ -35,7 +40,7 @@ func TestPostureChecks_List_200(t *testing.T) {
} }
func TestPostureChecks_List_Err(t *testing.T) { func TestPostureChecks_List_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/posture-checks", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/posture-checks", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -50,7 +55,7 @@ func TestPostureChecks_List_Err(t *testing.T) {
} }
func TestPostureChecks_Get_200(t *testing.T) { func TestPostureChecks_Get_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testPostureCheck) retBytes, _ := json.Marshal(testPostureCheck)
_, err := w.Write(retBytes) _, err := w.Write(retBytes)
@@ -63,7 +68,7 @@ func TestPostureChecks_Get_200(t *testing.T) {
} }
func TestPostureChecks_Get_Err(t *testing.T) { func TestPostureChecks_Get_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -78,7 +83,7 @@ func TestPostureChecks_Get_Err(t *testing.T) {
} }
func TestPostureChecks_Create_200(t *testing.T) { func TestPostureChecks_Create_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/posture-checks", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/posture-checks", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method) assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body) reqBytes, err := io.ReadAll(r.Body)
@@ -100,7 +105,7 @@ func TestPostureChecks_Create_200(t *testing.T) {
} }
func TestPostureChecks_Create_Err(t *testing.T) { func TestPostureChecks_Create_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/posture-checks", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/posture-checks", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -117,7 +122,7 @@ func TestPostureChecks_Create_Err(t *testing.T) {
} }
func TestPostureChecks_Update_200(t *testing.T) { func TestPostureChecks_Update_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method) assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body) reqBytes, err := io.ReadAll(r.Body)
@@ -139,7 +144,7 @@ func TestPostureChecks_Update_200(t *testing.T) {
} }
func TestPostureChecks_Update_Err(t *testing.T) { func TestPostureChecks_Update_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -156,7 +161,7 @@ func TestPostureChecks_Update_Err(t *testing.T) {
} }
func TestPostureChecks_Delete_200(t *testing.T) { func TestPostureChecks_Delete_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method) assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200) w.WriteHeader(200)
@@ -167,7 +172,7 @@ func TestPostureChecks_Delete_200(t *testing.T) {
} }
func TestPostureChecks_Delete_Err(t *testing.T) { func TestPostureChecks_Delete_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404) w.WriteHeader(404)
@@ -181,7 +186,7 @@ func TestPostureChecks_Delete_Err(t *testing.T) {
} }
func TestPostureChecks_Integration(t *testing.T) { func TestPostureChecks_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *Client) { withBlackBoxServer(t, func(c *rest.Client) {
check, err := c.PostureChecks.Create(context.Background(), api.PostureCheckUpdate{ check, err := c.PostureChecks.Create(context.Background(), api.PostureCheckUpdate{
Name: "Test", Name: "Test",
Description: "Testing", Description: "Testing",

View File

@@ -1,4 +1,7 @@
package rest //go:build integration
// +build integration
package rest_test
import ( import (
"context" "context"
@@ -7,10 +10,12 @@ import (
"net/http" "net/http"
"testing" "testing"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
) )
var ( var (
@@ -21,7 +26,7 @@ var (
) )
func TestRoutes_List_200(t *testing.T) { func TestRoutes_List_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/routes", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/routes", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.Route{testRoute}) retBytes, _ := json.Marshal([]api.Route{testRoute})
_, err := w.Write(retBytes) _, err := w.Write(retBytes)
@@ -35,7 +40,7 @@ func TestRoutes_List_200(t *testing.T) {
} }
func TestRoutes_List_Err(t *testing.T) { func TestRoutes_List_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/routes", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/routes", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -50,7 +55,7 @@ func TestRoutes_List_Err(t *testing.T) {
} }
func TestRoutes_Get_200(t *testing.T) { func TestRoutes_Get_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testRoute) retBytes, _ := json.Marshal(testRoute)
_, err := w.Write(retBytes) _, err := w.Write(retBytes)
@@ -63,7 +68,7 @@ func TestRoutes_Get_200(t *testing.T) {
} }
func TestRoutes_Get_Err(t *testing.T) { func TestRoutes_Get_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -78,7 +83,7 @@ func TestRoutes_Get_Err(t *testing.T) {
} }
func TestRoutes_Create_200(t *testing.T) { func TestRoutes_Create_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/routes", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/routes", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method) assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body) reqBytes, err := io.ReadAll(r.Body)
@@ -100,7 +105,7 @@ func TestRoutes_Create_200(t *testing.T) {
} }
func TestRoutes_Create_Err(t *testing.T) { func TestRoutes_Create_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/routes", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/routes", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -117,7 +122,7 @@ func TestRoutes_Create_Err(t *testing.T) {
} }
func TestRoutes_Update_200(t *testing.T) { func TestRoutes_Update_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method) assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body) reqBytes, err := io.ReadAll(r.Body)
@@ -139,7 +144,7 @@ func TestRoutes_Update_200(t *testing.T) {
} }
func TestRoutes_Update_Err(t *testing.T) { func TestRoutes_Update_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -156,7 +161,7 @@ func TestRoutes_Update_Err(t *testing.T) {
} }
func TestRoutes_Delete_200(t *testing.T) { func TestRoutes_Delete_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method) assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200) w.WriteHeader(200)
@@ -167,7 +172,7 @@ func TestRoutes_Delete_200(t *testing.T) {
} }
func TestRoutes_Delete_Err(t *testing.T) { func TestRoutes_Delete_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404) w.WriteHeader(404)
@@ -181,7 +186,7 @@ func TestRoutes_Delete_Err(t *testing.T) {
} }
func TestRoutes_Integration(t *testing.T) { func TestRoutes_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *Client) { withBlackBoxServer(t, func(c *rest.Client) {
route, err := c.Routes.Create(context.Background(), api.RouteRequest{ route, err := c.Routes.Create(context.Background(), api.RouteRequest{
Description: "Meow", Description: "Meow",
Enabled: false, Enabled: false,

View File

@@ -1,4 +1,7 @@
package rest //go:build integration
// +build integration
package rest_test
import ( import (
"context" "context"
@@ -7,10 +10,12 @@ import (
"net/http" "net/http"
"testing" "testing"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
) )
var ( var (
@@ -31,7 +36,7 @@ var (
) )
func TestSetupKeys_List_200(t *testing.T) { func TestSetupKeys_List_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup-keys", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/setup-keys", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.SetupKey{testSetupKey}) retBytes, _ := json.Marshal([]api.SetupKey{testSetupKey})
_, err := w.Write(retBytes) _, err := w.Write(retBytes)
@@ -45,7 +50,7 @@ func TestSetupKeys_List_200(t *testing.T) {
} }
func TestSetupKeys_List_Err(t *testing.T) { func TestSetupKeys_List_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup-keys", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/setup-keys", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -60,7 +65,7 @@ func TestSetupKeys_List_Err(t *testing.T) {
} }
func TestSetupKeys_Get_200(t *testing.T) { func TestSetupKeys_Get_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testSetupKey) retBytes, _ := json.Marshal(testSetupKey)
_, err := w.Write(retBytes) _, err := w.Write(retBytes)
@@ -73,7 +78,7 @@ func TestSetupKeys_Get_200(t *testing.T) {
} }
func TestSetupKeys_Get_Err(t *testing.T) { func TestSetupKeys_Get_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -88,7 +93,7 @@ func TestSetupKeys_Get_Err(t *testing.T) {
} }
func TestSetupKeys_Create_200(t *testing.T) { func TestSetupKeys_Create_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup-keys", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/setup-keys", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method) assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body) reqBytes, err := io.ReadAll(r.Body)
@@ -110,7 +115,7 @@ func TestSetupKeys_Create_200(t *testing.T) {
} }
func TestSetupKeys_Create_Err(t *testing.T) { func TestSetupKeys_Create_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup-keys", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/setup-keys", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -127,7 +132,7 @@ func TestSetupKeys_Create_Err(t *testing.T) {
} }
func TestSetupKeys_Update_200(t *testing.T) { func TestSetupKeys_Update_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method) assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body) reqBytes, err := io.ReadAll(r.Body)
@@ -149,7 +154,7 @@ func TestSetupKeys_Update_200(t *testing.T) {
} }
func TestSetupKeys_Update_Err(t *testing.T) { func TestSetupKeys_Update_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -166,7 +171,7 @@ func TestSetupKeys_Update_Err(t *testing.T) {
} }
func TestSetupKeys_Delete_200(t *testing.T) { func TestSetupKeys_Delete_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method) assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200) w.WriteHeader(200)
@@ -177,7 +182,7 @@ func TestSetupKeys_Delete_200(t *testing.T) {
} }
func TestSetupKeys_Delete_Err(t *testing.T) { func TestSetupKeys_Delete_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404) w.WriteHeader(404)
@@ -191,7 +196,7 @@ func TestSetupKeys_Delete_Err(t *testing.T) {
} }
func TestSetupKeys_Integration(t *testing.T) { func TestSetupKeys_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *Client) { withBlackBoxServer(t, func(c *rest.Client) {
group, err := c.Groups.Create(context.Background(), api.GroupRequest{ group, err := c.Groups.Create(context.Background(), api.GroupRequest{
Name: "Test", Name: "Test",
}) })

View File

@@ -1,4 +1,7 @@
package rest //go:build integration
// +build integration
package rest_test
import ( import (
"context" "context"
@@ -8,10 +11,12 @@ import (
"testing" "testing"
"time" "time"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
) )
var ( var (
@@ -31,7 +36,7 @@ var (
) )
func TestTokens_List_200(t *testing.T) { func TestTokens_List_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/meow/tokens", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/users/meow/tokens", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.PersonalAccessToken{testToken}) retBytes, _ := json.Marshal([]api.PersonalAccessToken{testToken})
_, err := w.Write(retBytes) _, err := w.Write(retBytes)
@@ -45,7 +50,7 @@ func TestTokens_List_200(t *testing.T) {
} }
func TestTokens_List_Err(t *testing.T) { func TestTokens_List_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/meow/tokens", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/users/meow/tokens", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -60,7 +65,7 @@ func TestTokens_List_Err(t *testing.T) {
} }
func TestTokens_Get_200(t *testing.T) { func TestTokens_Get_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/meow/tokens/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/users/meow/tokens/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testToken) retBytes, _ := json.Marshal(testToken)
_, err := w.Write(retBytes) _, err := w.Write(retBytes)
@@ -73,7 +78,7 @@ func TestTokens_Get_200(t *testing.T) {
} }
func TestTokens_Get_Err(t *testing.T) { func TestTokens_Get_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/meow/tokens/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/users/meow/tokens/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -88,7 +93,7 @@ func TestTokens_Get_Err(t *testing.T) {
} }
func TestTokens_Create_200(t *testing.T) { func TestTokens_Create_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/meow/tokens", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/users/meow/tokens", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method) assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body) reqBytes, err := io.ReadAll(r.Body)
@@ -110,7 +115,7 @@ func TestTokens_Create_200(t *testing.T) {
} }
func TestTokens_Create_Err(t *testing.T) { func TestTokens_Create_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/meow/tokens", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/users/meow/tokens", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -127,7 +132,7 @@ func TestTokens_Create_Err(t *testing.T) {
} }
func TestTokens_Delete_200(t *testing.T) { func TestTokens_Delete_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/meow/tokens/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/users/meow/tokens/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method) assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200) w.WriteHeader(200)
@@ -138,7 +143,7 @@ func TestTokens_Delete_200(t *testing.T) {
} }
func TestTokens_Delete_Err(t *testing.T) { func TestTokens_Delete_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/meow/tokens/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/users/meow/tokens/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404) w.WriteHeader(404)
@@ -152,7 +157,7 @@ func TestTokens_Delete_Err(t *testing.T) {
} }
func TestTokens_Integration(t *testing.T) { func TestTokens_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *Client) { withBlackBoxServer(t, func(c *rest.Client) {
tokenClear, err := c.Tokens.Create(context.Background(), "a23efe53-63fb-11ec-90d6-0242ac120003", api.PersonalAccessTokenRequest{ tokenClear, err := c.Tokens.Create(context.Background(), "a23efe53-63fb-11ec-90d6-0242ac120003", api.PersonalAccessTokenRequest{
Name: "Test", Name: "Test",
ExpiresIn: 365, ExpiresIn: 365,

View File

@@ -1,4 +1,7 @@
package rest //go:build integration
// +build integration
package rest_test
import ( import (
"context" "context"
@@ -8,10 +11,12 @@ import (
"testing" "testing"
"time" "time"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
) )
var ( var (
@@ -34,7 +39,7 @@ var (
) )
func TestUsers_List_200(t *testing.T) { func TestUsers_List_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.User{testUser}) retBytes, _ := json.Marshal([]api.User{testUser})
_, err := w.Write(retBytes) _, err := w.Write(retBytes)
@@ -48,7 +53,7 @@ func TestUsers_List_200(t *testing.T) {
} }
func TestUsers_List_Err(t *testing.T) { func TestUsers_List_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -63,7 +68,7 @@ func TestUsers_List_Err(t *testing.T) {
} }
func TestUsers_Create_200(t *testing.T) { func TestUsers_Create_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method) assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body) reqBytes, err := io.ReadAll(r.Body)
@@ -85,7 +90,7 @@ func TestUsers_Create_200(t *testing.T) {
} }
func TestUsers_Create_Err(t *testing.T) { func TestUsers_Create_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -102,7 +107,7 @@ func TestUsers_Create_Err(t *testing.T) {
} }
func TestUsers_Update_200(t *testing.T) { func TestUsers_Update_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/users/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method) assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body) reqBytes, err := io.ReadAll(r.Body)
@@ -125,7 +130,7 @@ func TestUsers_Update_200(t *testing.T) {
} }
func TestUsers_Update_Err(t *testing.T) { func TestUsers_Update_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/users/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400) w.WriteHeader(400)
@@ -142,7 +147,7 @@ func TestUsers_Update_Err(t *testing.T) {
} }
func TestUsers_Delete_200(t *testing.T) { func TestUsers_Delete_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/users/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method) assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200) w.WriteHeader(200)
@@ -153,7 +158,7 @@ func TestUsers_Delete_200(t *testing.T) {
} }
func TestUsers_Delete_Err(t *testing.T) { func TestUsers_Delete_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/Test", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/users/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404) w.WriteHeader(404)
@@ -167,7 +172,7 @@ func TestUsers_Delete_Err(t *testing.T) {
} }
func TestUsers_ResendInvitation_200(t *testing.T) { func TestUsers_ResendInvitation_200(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/Test/invite", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/users/Test/invite", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method) assert.Equal(t, "POST", r.Method)
w.WriteHeader(200) w.WriteHeader(200)
@@ -178,7 +183,7 @@ func TestUsers_ResendInvitation_200(t *testing.T) {
} }
func TestUsers_ResendInvitation_Err(t *testing.T) { func TestUsers_ResendInvitation_Err(t *testing.T) {
withMockClient(func(c *Client, mux *http.ServeMux) { withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/Test/invite", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/api/users/Test/invite", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404) w.WriteHeader(404)
@@ -192,7 +197,7 @@ func TestUsers_ResendInvitation_Err(t *testing.T) {
} }
func TestUsers_Integration(t *testing.T) { func TestUsers_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *Client) { withBlackBoxServer(t, func(c *rest.Client) {
user, err := c.Users.Create(context.Background(), api.UserCreateRequest{ user, err := c.Users.Create(context.Background(), api.UserCreateRequest{
AutoGroups: []string{}, AutoGroups: []string{},
Email: ptr("test@example.com"), Email: ptr("test@example.com"),

View File

@@ -15,6 +15,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
@@ -114,6 +115,18 @@ func NewServer(
} }
func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) { func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) {
ip := ""
p, ok := peer.FromContext(ctx)
if ok {
ip = p.Addr.String()
}
log.WithContext(ctx).Tracef("GetServerKey request from %s", ip)
start := time.Now()
defer func() {
log.WithContext(ctx).Tracef("GetServerKey from %s took %v", ip, time.Since(start))
}()
// todo introduce something more meaningful with the key expiration/rotation // todo introduce something more meaningful with the key expiration/rotation
if s.appMetrics != nil { if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountGetKeyRequest() s.appMetrics.GRPCMetrics().CountGetKeyRequest()
@@ -717,6 +730,12 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
// This is used for initiating an Oauth 2 device authorization grant flow // This is used for initiating an Oauth 2 device authorization grant flow
// which will be used by our clients to Login // which will be used by our clients to Login
func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow request for pubKey: %s", req.WgPubKey)
start := time.Now()
defer func() {
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow for pubKey: %s took %v", req.WgPubKey, time.Since(start))
}()
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil { if err != nil {
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetDeviceAuthorizationFlow request.", req.WgPubKey) errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetDeviceAuthorizationFlow request.", req.WgPubKey)
@@ -769,6 +788,12 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
// This is used for initiating an Oauth 2 pkce authorization grant flow // This is used for initiating an Oauth 2 pkce authorization grant flow
// which will be used by our clients to Login // which will be used by our clients to Login
func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow request for pubKey: %s", req.WgPubKey)
start := time.Now()
defer func() {
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow for pubKey %s took %v", req.WgPubKey, time.Since(start))
}()
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil { if err != nil {
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetPKCEAuthorizationFlow request.", req.WgPubKey) errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetPKCEAuthorizationFlow request.", req.WgPubKey)

View File

@@ -141,7 +141,6 @@ type Client struct {
muInstanceURL sync.Mutex muInstanceURL sync.Mutex
onDisconnectListener func(string) onDisconnectListener func(string)
onConnectedListener func()
listenerMutex sync.Mutex listenerMutex sync.Mutex
} }
@@ -190,7 +189,6 @@ func (c *Client) Connect() error {
c.wgReadLoop.Add(1) c.wgReadLoop.Add(1)
go c.readLoop(c.relayConn) go c.readLoop(c.relayConn)
go c.notifyConnected()
return nil return nil
} }
@@ -238,12 +236,6 @@ func (c *Client) SetOnDisconnectListener(fn func(string)) {
c.onDisconnectListener = fn c.onDisconnectListener = fn
} }
func (c *Client) SetOnConnectedListener(fn func()) {
c.listenerMutex.Lock()
defer c.listenerMutex.Unlock()
c.onConnectedListener = fn
}
// HasConns returns true if there are connections. // HasConns returns true if there are connections.
func (c *Client) HasConns() bool { func (c *Client) HasConns() bool {
c.mu.Lock() c.mu.Lock()
@@ -559,16 +551,6 @@ func (c *Client) notifyDisconnected() {
go c.onDisconnectListener(c.connectionURL) go c.onDisconnectListener(c.connectionURL)
} }
func (c *Client) notifyConnected() {
c.listenerMutex.Lock()
defer c.listenerMutex.Unlock()
if c.onConnectedListener == nil {
return
}
go c.onConnectedListener()
}
func (c *Client) writeCloseMsg() { func (c *Client) writeCloseMsg() {
msg := messages.MarshalCloseMsg() msg := messages.MarshalCloseMsg()
_, err := c.relayConn.Write(msg) _, err := c.relayConn.Write(msg)

View File

@@ -14,8 +14,9 @@ var (
// Guard manage the reconnection tries to the Relay server in case of disconnection event. // Guard manage the reconnection tries to the Relay server in case of disconnection event.
type Guard struct { type Guard struct {
// OnNewRelayClient is a channel that is used to notify the relay client about a new relay client instance. // OnNewRelayClient is a channel that is used to notify the relay manager about a new relay client instance.
OnNewRelayClient chan *Client OnNewRelayClient chan *Client
OnReconnected chan struct{}
serverPicker *ServerPicker serverPicker *ServerPicker
} }
@@ -23,6 +24,7 @@ type Guard struct {
func NewGuard(sp *ServerPicker) *Guard { func NewGuard(sp *ServerPicker) *Guard {
g := &Guard{ g := &Guard{
OnNewRelayClient: make(chan *Client, 1), OnNewRelayClient: make(chan *Client, 1),
OnReconnected: make(chan struct{}, 1),
serverPicker: sp, serverPicker: sp,
} }
return g return g
@@ -39,14 +41,13 @@ func NewGuard(sp *ServerPicker) *Guard {
// - relayClient: The relay client instance that was disconnected. // - relayClient: The relay client instance that was disconnected.
// todo prevent multiple reconnection instances. In the current usage it should not happen, but it is better to prevent // todo prevent multiple reconnection instances. In the current usage it should not happen, but it is better to prevent
func (g *Guard) StartReconnectTrys(ctx context.Context, relayClient *Client) { func (g *Guard) StartReconnectTrys(ctx context.Context, relayClient *Client) {
if relayClient == nil { // try to reconnect to the same server
goto RETRY if ok := g.tryToQuickReconnect(ctx, relayClient); ok {
} g.notifyReconnected()
if g.isServerURLStillValid(relayClient) && g.quickReconnect(ctx, relayClient) {
return return
} }
RETRY: // start a ticker to pick a new server
ticker := exponentTicker(ctx) ticker := exponentTicker(ctx)
defer ticker.Stop() defer ticker.Stop()
@@ -64,6 +65,28 @@ RETRY:
} }
} }
func (g *Guard) tryToQuickReconnect(parentCtx context.Context, rc *Client) bool {
if rc == nil {
return false
}
if !g.isServerURLStillValid(rc) {
return false
}
if cancelled := waiteBeforeRetry(parentCtx); !cancelled {
return false
}
log.Infof("try to reconnect to Relay server: %s", rc.connectionURL)
if err := rc.Connect(); err != nil {
log.Errorf("failed to reconnect to relay server: %s", err)
return false
}
return true
}
func (g *Guard) retry(ctx context.Context) error { func (g *Guard) retry(ctx context.Context) error {
log.Infof("try to pick up a new Relay server") log.Infof("try to pick up a new Relay server")
relayClient, err := g.serverPicker.PickServer(ctx) relayClient, err := g.serverPicker.PickServer(ctx)
@@ -78,23 +101,6 @@ func (g *Guard) retry(ctx context.Context) error {
return nil return nil
} }
func (g *Guard) quickReconnect(parentCtx context.Context, rc *Client) bool {
ctx, cancel := context.WithTimeout(parentCtx, 1500*time.Millisecond)
defer cancel()
<-ctx.Done()
if parentCtx.Err() != nil {
return false
}
log.Infof("try to reconnect to Relay server: %s", rc.connectionURL)
if err := rc.Connect(); err != nil {
log.Errorf("failed to reconnect to relay server: %s", err)
return false
}
return true
}
func (g *Guard) drainRelayClientChan() { func (g *Guard) drainRelayClientChan() {
select { select {
case <-g.OnNewRelayClient: case <-g.OnNewRelayClient:
@@ -111,6 +117,13 @@ func (g *Guard) isServerURLStillValid(rc *Client) bool {
return false return false
} }
func (g *Guard) notifyReconnected() {
select {
case g.OnReconnected <- struct{}{}:
default:
}
}
func exponentTicker(ctx context.Context) *backoff.Ticker { func exponentTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{ bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 2 * time.Second, InitialInterval: 2 * time.Second,
@@ -121,3 +134,15 @@ func exponentTicker(ctx context.Context) *backoff.Ticker {
return backoff.NewTicker(bo) return backoff.NewTicker(bo)
} }
func waiteBeforeRetry(ctx context.Context) bool {
timer := time.NewTimer(1500 * time.Millisecond)
defer timer.Stop()
select {
case <-timer.C:
return true
case <-ctx.Done():
return false
}
}

Some files were not shown because too many files have changed in this diff Show More