mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-12 19:09:54 +00:00
Compare commits
7 Commits
refactor/n
...
test/timin
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
95365f41c3 | ||
|
|
b50c060bcc | ||
|
|
eae4f4be12 | ||
|
|
9eb2faef7f | ||
|
|
a2313a5ba4 | ||
|
|
8c108ccad3 | ||
|
|
86eff0d750 |
@@ -50,6 +50,12 @@ const (
|
||||
|
||||
var errNatNotSupported = errors.New("nat not supported with userspace firewall")
|
||||
|
||||
// serviceKey represents a protocol/port combination for netstack service registry
|
||||
type serviceKey struct {
|
||||
protocol gopacket.LayerType
|
||||
port uint16
|
||||
}
|
||||
|
||||
// RuleSet is a set of rules grouped by a string key
|
||||
type RuleSet map[string]PeerRule
|
||||
|
||||
@@ -113,6 +119,9 @@ type Manager struct {
|
||||
portDNATEnabled atomic.Bool
|
||||
portDNATRules []portDNATRule
|
||||
portDNATMutex sync.RWMutex
|
||||
|
||||
netstackServices map[serviceKey]struct{}
|
||||
netstackServiceMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// decoder for packages
|
||||
@@ -203,6 +212,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
localForwarding: enableLocalForwarding,
|
||||
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||
portDNATRules: []portDNATRule{},
|
||||
netstackServices: make(map[serviceKey]struct{}),
|
||||
}
|
||||
m.routingEnabled.Store(false)
|
||||
|
||||
@@ -838,9 +848,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
|
||||
return true
|
||||
}
|
||||
|
||||
// If requested we pass local traffic to internal interfaces to the forwarder.
|
||||
// netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder.
|
||||
if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) {
|
||||
if m.shouldForward(d, dstIP) {
|
||||
return m.handleForwardedLocalTraffic(packetData)
|
||||
}
|
||||
|
||||
@@ -1274,3 +1282,86 @@ func (m *Manager) DisableRouting() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterNetstackService registers a service as listening on the netstack for the given protocol and port
|
||||
func (m *Manager) RegisterNetstackService(protocol nftypes.Protocol, port uint16) {
|
||||
m.netstackServiceMutex.Lock()
|
||||
defer m.netstackServiceMutex.Unlock()
|
||||
layerType := m.protocolToLayerType(protocol)
|
||||
key := serviceKey{protocol: layerType, port: port}
|
||||
m.netstackServices[key] = struct{}{}
|
||||
m.logger.Debug3("RegisterNetstackService: registered %s:%d (layerType=%s)", protocol, port, layerType)
|
||||
m.logger.Debug1("RegisterNetstackService: current registry size: %d", len(m.netstackServices))
|
||||
}
|
||||
|
||||
// UnregisterNetstackService removes a service from the netstack registry
|
||||
func (m *Manager) UnregisterNetstackService(protocol nftypes.Protocol, port uint16) {
|
||||
m.netstackServiceMutex.Lock()
|
||||
defer m.netstackServiceMutex.Unlock()
|
||||
layerType := m.protocolToLayerType(protocol)
|
||||
key := serviceKey{protocol: layerType, port: port}
|
||||
delete(m.netstackServices, key)
|
||||
m.logger.Debug2("Unregistered netstack service on protocol %s port %d", protocol, port)
|
||||
}
|
||||
|
||||
// protocolToLayerType converts nftypes.Protocol to gopacket.LayerType for internal use
|
||||
func (m *Manager) protocolToLayerType(protocol nftypes.Protocol) gopacket.LayerType {
|
||||
switch protocol {
|
||||
case nftypes.TCP:
|
||||
return layers.LayerTypeTCP
|
||||
case nftypes.UDP:
|
||||
return layers.LayerTypeUDP
|
||||
case nftypes.ICMP:
|
||||
return layers.LayerTypeICMPv4
|
||||
default:
|
||||
return gopacket.LayerType(0) // Invalid/unknown
|
||||
}
|
||||
}
|
||||
|
||||
// shouldForward determines if a packet should be forwarded to the forwarder.
|
||||
// The forwarder handles routing packets to the native OS network stack.
|
||||
// Returns true if packet should go to the forwarder, false if it should go to netstack listeners or the native stack directly.
|
||||
func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool {
|
||||
// not enabled, never forward
|
||||
if !m.localForwarding {
|
||||
return false
|
||||
}
|
||||
|
||||
// netstack always needs to forward because it's lacking a native interface
|
||||
// exception for registered netstack services, those should go to netstack listeners
|
||||
if m.netstack {
|
||||
return !m.hasMatchingNetstackService(d)
|
||||
}
|
||||
|
||||
// traffic to our other local interfaces (not NetBird IP) - always forward
|
||||
if dstIP != m.wgIface.Address().IP {
|
||||
return true
|
||||
}
|
||||
|
||||
// traffic to our NetBird IP, not netstack mode - send to netstack listeners
|
||||
return false
|
||||
}
|
||||
|
||||
// hasMatchingNetstackService checks if there's a registered netstack service for this packet
|
||||
func (m *Manager) hasMatchingNetstackService(d *decoder) bool {
|
||||
if len(d.decoded) < 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
var dstPort uint16
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
dstPort = uint16(d.tcp.DstPort)
|
||||
case layers.LayerTypeUDP:
|
||||
dstPort = uint16(d.udp.DstPort)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
key := serviceKey{protocol: d.decoded[1], port: dstPort}
|
||||
m.netstackServiceMutex.RLock()
|
||||
_, exists := m.netstackServices[key]
|
||||
m.netstackServiceMutex.RUnlock()
|
||||
|
||||
return exists
|
||||
}
|
||||
|
||||
@@ -83,4 +83,3 @@ func TestCacheMiss(t *testing.T) {
|
||||
t.Fatalf("expected cache miss, got=%v ok=%v", got, ok)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
@@ -33,7 +34,7 @@ type firewaller interface {
|
||||
}
|
||||
|
||||
type DNSForwarder struct {
|
||||
listenAddress string
|
||||
listenAddress netip.AddrPort
|
||||
ttl uint32
|
||||
statusRecorder *peer.Status
|
||||
|
||||
@@ -47,9 +48,11 @@ type DNSForwarder struct {
|
||||
firewall firewaller
|
||||
resolver resolver
|
||||
cache *cache
|
||||
|
||||
wgIface wgIface
|
||||
}
|
||||
|
||||
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
|
||||
func NewDNSForwarder(listenAddress netip.AddrPort, ttl uint32, firewall firewaller, statusRecorder *peer.Status, wgIface wgIface) *DNSForwarder {
|
||||
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
|
||||
return &DNSForwarder{
|
||||
listenAddress: listenAddress,
|
||||
@@ -58,30 +61,46 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat
|
||||
statusRecorder: statusRecorder,
|
||||
resolver: net.DefaultResolver,
|
||||
cache: newCache(),
|
||||
wgIface: wgIface,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
|
||||
log.Infof("starting DNS forwarder on address=%s", f.listenAddress)
|
||||
var netstackNet *netstack.Net
|
||||
if f.wgIface != nil {
|
||||
netstackNet = f.wgIface.GetNet()
|
||||
}
|
||||
|
||||
addrDesc := f.listenAddress.String()
|
||||
if netstackNet != nil {
|
||||
addrDesc = fmt.Sprintf("netstack %s", f.listenAddress)
|
||||
}
|
||||
log.Infof("starting DNS forwarder on address=%s", addrDesc)
|
||||
|
||||
udpLn, err := f.createUDPListener(netstackNet)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create UDP listener: %w", err)
|
||||
}
|
||||
|
||||
tcpLn, err := f.createTCPListener(netstackNet)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create TCP listener: %w", err)
|
||||
}
|
||||
|
||||
// UDP server
|
||||
mux := dns.NewServeMux()
|
||||
f.mux = mux
|
||||
mux.HandleFunc(".", f.handleDNSQueryUDP)
|
||||
f.dnsServer = &dns.Server{
|
||||
Addr: f.listenAddress,
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
PacketConn: udpLn,
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
// TCP server
|
||||
tcpMux := dns.NewServeMux()
|
||||
f.tcpMux = tcpMux
|
||||
tcpMux.HandleFunc(".", f.handleDNSQueryTCP)
|
||||
f.tcpServer = &dns.Server{
|
||||
Addr: f.listenAddress,
|
||||
Net: "tcp",
|
||||
Handler: tcpMux,
|
||||
Listener: tcpLn,
|
||||
Handler: tcpMux,
|
||||
}
|
||||
|
||||
f.UpdateDomains(entries)
|
||||
@@ -89,18 +108,33 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
|
||||
errCh := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
log.Infof("DNS UDP listener running on %s", f.listenAddress)
|
||||
errCh <- f.dnsServer.ListenAndServe()
|
||||
log.Infof("DNS UDP listener running on %s", addrDesc)
|
||||
errCh <- f.dnsServer.ActivateAndServe()
|
||||
}()
|
||||
go func() {
|
||||
log.Infof("DNS TCP listener running on %s", f.listenAddress)
|
||||
errCh <- f.tcpServer.ListenAndServe()
|
||||
log.Infof("DNS TCP listener running on %s", addrDesc)
|
||||
errCh <- f.tcpServer.ActivateAndServe()
|
||||
}()
|
||||
|
||||
// return the first error we get (e.g. bind failure or shutdown)
|
||||
return <-errCh
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) createUDPListener(netstackNet *netstack.Net) (net.PacketConn, error) {
|
||||
if netstackNet != nil {
|
||||
return netstackNet.ListenUDPAddrPort(f.listenAddress)
|
||||
}
|
||||
|
||||
return net.ListenUDP("udp", net.UDPAddrFromAddrPort(f.listenAddress))
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) createTCPListener(netstackNet *netstack.Net) (net.Listener, error) {
|
||||
if netstackNet != nil {
|
||||
return netstackNet.ListenTCPAddrPort(f.listenAddress)
|
||||
}
|
||||
|
||||
return net.ListenTCP("tcp", net.TCPAddrFromAddrPort(f.listenAddress))
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
@@ -297,7 +297,7 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil)
|
||||
}
|
||||
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString(tt.configuredDomain)
|
||||
@@ -402,7 +402,7 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
|
||||
// Set up forwarder
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
// Create entries and track sets
|
||||
@@ -489,7 +489,7 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
|
||||
mockFirewall := &MockFirewall{}
|
||||
mockResolver := &MockResolver{}
|
||||
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
// Configure a single domain
|
||||
@@ -584,7 +584,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||
|
||||
d, err := domain.FromString(tt.configured)
|
||||
require.NoError(t, err)
|
||||
@@ -616,7 +616,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
||||
func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
||||
// Test that large UDP responses are truncated with TC bit set
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, _ := domain.FromString("example.com")
|
||||
@@ -652,7 +652,7 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
||||
// a subsequent upstream failure still returns a successful response from cache.
|
||||
func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString("example.com")
|
||||
@@ -696,7 +696,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
||||
// Verifies that cache normalization works across casing and trailing dot variations.
|
||||
func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString("ExAmPlE.CoM")
|
||||
@@ -742,7 +742,7 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
||||
mockFirewall := &MockFirewall{}
|
||||
mockResolver := &MockResolver{}
|
||||
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
// Set up complex overlapping patterns
|
||||
@@ -804,7 +804,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
|
||||
mockFirewall := &MockFirewall{}
|
||||
mockResolver := &MockResolver{}
|
||||
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString("example.com")
|
||||
@@ -925,7 +925,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
|
||||
|
||||
func TestDNSForwarder_EmptyQuery(t *testing.T) {
|
||||
// Test handling of malformed query with no questions
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||
|
||||
query := &dns.Msg{}
|
||||
// Don't set any question
|
||||
|
||||
@@ -10,9 +10,11 @@ import (
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@@ -24,6 +26,12 @@ const (
|
||||
envServerPort = "NB_DNS_FORWARDER_PORT"
|
||||
)
|
||||
|
||||
// wgIface defines the interface for WireGuard interface operations needed by the DNS forwarder.
|
||||
type wgIface interface {
|
||||
GetNet() *netstack.Net
|
||||
Address() wgaddr.Address
|
||||
}
|
||||
|
||||
// ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list.
|
||||
type ForwarderEntry struct {
|
||||
Domain domain.Domain
|
||||
@@ -34,7 +42,7 @@ type ForwarderEntry struct {
|
||||
type Manager struct {
|
||||
firewall firewall.Manager
|
||||
statusRecorder *peer.Status
|
||||
localAddr netip.Addr
|
||||
wgIface wgIface
|
||||
serverPort uint16
|
||||
|
||||
fwRules []firewall.Rule
|
||||
@@ -42,7 +50,7 @@ type Manager struct {
|
||||
dnsForwarder *DNSForwarder
|
||||
}
|
||||
|
||||
func NewManager(fw firewall.Manager, statusRecorder *peer.Status, localAddr netip.Addr) *Manager {
|
||||
func NewManager(fw firewall.Manager, statusRecorder *peer.Status, wgIface wgIface) *Manager {
|
||||
serverPort := nbdns.ForwarderServerPort
|
||||
if envPort := os.Getenv(envServerPort); envPort != "" {
|
||||
if port, err := strconv.ParseUint(envPort, 10, 16); err == nil && port > 0 {
|
||||
@@ -56,7 +64,7 @@ func NewManager(fw firewall.Manager, statusRecorder *peer.Status, localAddr neti
|
||||
return &Manager{
|
||||
firewall: fw,
|
||||
statusRecorder: statusRecorder,
|
||||
localAddr: localAddr,
|
||||
wgIface: wgIface,
|
||||
serverPort: serverPort,
|
||||
}
|
||||
}
|
||||
@@ -71,21 +79,25 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.localAddr.IsValid() && m.firewall != nil {
|
||||
if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
localAddr := m.wgIface.Address().IP
|
||||
|
||||
if localAddr.IsValid() && m.firewall != nil {
|
||||
if err := m.firewall.AddInboundDNAT(localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
log.Warnf("failed to add DNS UDP DNAT rule: %v", err)
|
||||
} else {
|
||||
log.Infof("added DNS UDP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort)
|
||||
log.Infof("added DNS UDP DNAT rule: %s:%d -> %s:%d", localAddr, nbdns.ForwarderClientPort, localAddr, m.serverPort)
|
||||
}
|
||||
|
||||
if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
if err := m.firewall.AddInboundDNAT(localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
log.Warnf("failed to add DNS TCP DNAT rule: %v", err)
|
||||
} else {
|
||||
log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort)
|
||||
log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", localAddr, nbdns.ForwarderClientPort, localAddr, m.serverPort)
|
||||
}
|
||||
}
|
||||
|
||||
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", m.serverPort), dnsTTL, m.firewall, m.statusRecorder)
|
||||
listenAddress := netip.AddrPortFrom(localAddr, m.serverPort)
|
||||
m.dnsForwarder = NewDNSForwarder(listenAddress, dnsTTL, m.firewall, m.statusRecorder, m.wgIface)
|
||||
|
||||
go func() {
|
||||
if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
|
||||
// todo handle close error if it is exists
|
||||
@@ -111,12 +123,13 @@ func (m *Manager) Stop(ctx context.Context) error {
|
||||
|
||||
var mErr *multierror.Error
|
||||
|
||||
if m.localAddr.IsValid() && m.firewall != nil {
|
||||
if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
localAddr := m.wgIface.Address().IP
|
||||
if localAddr.IsValid() && m.firewall != nil {
|
||||
if err := m.firewall.RemoveInboundDNAT(localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
mErr = multierror.Append(mErr, fmt.Errorf("remove DNS UDP DNAT rule: %w", err))
|
||||
}
|
||||
|
||||
if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
if err := m.firewall.RemoveInboundDNAT(localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
mErr = multierror.Append(mErr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1855,35 +1855,69 @@ func (e *Engine) updateDNSForwarder(
|
||||
}
|
||||
|
||||
if !enabled {
|
||||
if e.dnsForwardMgr == nil {
|
||||
return
|
||||
}
|
||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||
log.Errorf("failed to stop DNS forward: %v", err)
|
||||
}
|
||||
e.stopDNSForwarder()
|
||||
return
|
||||
}
|
||||
|
||||
if len(fwdEntries) > 0 {
|
||||
if e.dnsForwardMgr == nil {
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, localAddr)
|
||||
|
||||
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
||||
log.Errorf("failed to start DNS forward: %v", err)
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
|
||||
log.Infof("started domain router service with %d entries", len(fwdEntries))
|
||||
e.startDNSForwarder(fwdEntries)
|
||||
} else {
|
||||
e.dnsForwardMgr.UpdateDomains(fwdEntries)
|
||||
}
|
||||
} else if e.dnsForwardMgr != nil {
|
||||
log.Infof("disable domain router service")
|
||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||
log.Errorf("failed to stop DNS forward: %v", err)
|
||||
}
|
||||
e.stopDNSForwarder()
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) startDNSForwarder(fwdEntries []*dnsfwd.ForwarderEntry) {
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, e.wgInterface)
|
||||
e.registerDNSServices()
|
||||
|
||||
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
||||
log.Errorf("failed to start DNS forward: %v", err)
|
||||
e.dnsForwardMgr = nil
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("started domain router service with %d entries", len(fwdEntries))
|
||||
}
|
||||
|
||||
func (e *Engine) stopDNSForwarder() {
|
||||
if e.dnsForwardMgr == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||
log.Errorf("failed to stop DNS forward: %v", err)
|
||||
}
|
||||
|
||||
e.unregisterDNSServices()
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
|
||||
func (e *Engine) registerDNSServices() {
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
if registrar, ok := e.firewall.(interface {
|
||||
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.RegisterNetstackService(nftypes.UDP, nbdns.ForwarderServerPort)
|
||||
registrar.RegisterNetstackService(nftypes.TCP, nbdns.ForwarderServerPort)
|
||||
log.Debugf("registered DNS forwarder service with netstack for UDP/TCP:%d", nbdns.ForwarderServerPort)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) unregisterDNSServices() {
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
if registrar, ok := e.firewall.(interface {
|
||||
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.UnregisterNetstackService(nftypes.UDP, nbdns.ForwarderServerPort)
|
||||
registrar.UnregisterNetstackService(nftypes.TCP, nbdns.ForwarderServerPort)
|
||||
log.Debugf("unregistered DNS forwarder service with netstack for UDP/TCP:%d", nbdns.ForwarderServerPort)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build (darwin && !ios) || dragonfly || freebsd || netbsd || openbsd
|
||||
//go:build dragonfly || freebsd || netbsd || openbsd
|
||||
|
||||
package networkmonitor
|
||||
|
||||
@@ -6,21 +6,19 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
||||
fd, err := prepareFd()
|
||||
if err != nil {
|
||||
return fmt.Errorf("open routing socket: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := unix.Close(fd)
|
||||
if err != nil && !errors.Is(err, unix.EBADF) {
|
||||
@@ -28,72 +26,5 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) er
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
buf := make([]byte, 2048)
|
||||
n, err := unix.Read(fd, buf)
|
||||
if err != nil {
|
||||
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
|
||||
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if n < unix.SizeofRtMsghdr {
|
||||
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
|
||||
continue
|
||||
}
|
||||
|
||||
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
|
||||
|
||||
switch msg.Type {
|
||||
// handle route changes
|
||||
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||
route, err := parseRouteMessage(buf[:n])
|
||||
if err != nil {
|
||||
log.Debugf("Network monitor: error parsing routing message: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if route.Dst.Bits() != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
intf := "<nil>"
|
||||
if route.Interface != nil {
|
||||
intf = route.Interface.Name
|
||||
}
|
||||
switch msg.Type {
|
||||
case unix.RTM_ADD:
|
||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
||||
return nil
|
||||
case unix.RTM_DELETE:
|
||||
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
|
||||
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
|
||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse RIB: %v", err)
|
||||
}
|
||||
|
||||
if len(msgs) != 1 {
|
||||
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
|
||||
}
|
||||
|
||||
msg, ok := msgs[0].(*route.RouteMessage)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
|
||||
}
|
||||
|
||||
return systemops.MsgToRoute(msg)
|
||||
return routeCheck(ctx, fd, nexthopv4, nexthopv6)
|
||||
}
|
||||
|
||||
92
client/internal/networkmonitor/check_change_common.go
Normal file
92
client/internal/networkmonitor/check_change_common.go
Normal file
@@ -0,0 +1,92 @@
|
||||
//go:build dragonfly || freebsd || netbsd || openbsd || darwin
|
||||
|
||||
package networkmonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
func prepareFd() (int, error) {
|
||||
return unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
||||
}
|
||||
|
||||
func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
buf := make([]byte, 2048)
|
||||
n, err := unix.Read(fd, buf)
|
||||
if err != nil {
|
||||
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
|
||||
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if n < unix.SizeofRtMsghdr {
|
||||
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
|
||||
continue
|
||||
}
|
||||
|
||||
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
|
||||
|
||||
switch msg.Type {
|
||||
// handle route changes
|
||||
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||
route, err := parseRouteMessage(buf[:n])
|
||||
if err != nil {
|
||||
log.Debugf("Network monitor: error parsing routing message: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if route.Dst.Bits() != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
intf := "<nil>"
|
||||
if route.Interface != nil {
|
||||
intf = route.Interface.Name
|
||||
}
|
||||
switch msg.Type {
|
||||
case unix.RTM_ADD:
|
||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
||||
return nil
|
||||
case unix.RTM_DELETE:
|
||||
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
|
||||
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
|
||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse RIB: %v", err)
|
||||
}
|
||||
|
||||
if len(msgs) != 1 {
|
||||
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
|
||||
}
|
||||
|
||||
msg, ok := msgs[0].(*route.RouteMessage)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
|
||||
}
|
||||
|
||||
return systemops.MsgToRoute(msg)
|
||||
}
|
||||
149
client/internal/networkmonitor/check_change_darwin.go
Normal file
149
client/internal/networkmonitor/check_change_darwin.go
Normal file
@@ -0,0 +1,149 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package networkmonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"os/exec"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
// todo: refactor to not use static functions
|
||||
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||
fd, err := prepareFd()
|
||||
if err != nil {
|
||||
return fmt.Errorf("open routing socket: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := unix.Close(fd); err != nil {
|
||||
if !errors.Is(err, unix.EBADF) {
|
||||
log.Warnf("Network monitor: failed to close routing socket: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
routeChanged := make(chan struct{})
|
||||
go func() {
|
||||
_ = routeCheck(ctx, fd, nexthopv4, nexthopv6)
|
||||
close(routeChanged)
|
||||
}()
|
||||
|
||||
wakeUp := make(chan struct{})
|
||||
go func() {
|
||||
wakeUpListen(ctx)
|
||||
close(wakeUp)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-routeChanged:
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
log.Infof("route change detected")
|
||||
return nil
|
||||
case <-wakeUp:
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
log.Infof("wakeup detected")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func wakeUpListen(ctx context.Context) {
|
||||
log.Infof("start to watch for system wakeups")
|
||||
var (
|
||||
initialHash uint32
|
||||
err error
|
||||
)
|
||||
|
||||
// Keep retrying until initial sysctl succeeds or context is canceled
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info("exit from wakeUpListen initial hash detection due to context cancellation")
|
||||
return
|
||||
default:
|
||||
initialHash, err = readSleepTimeHash()
|
||||
if err != nil {
|
||||
log.Errorf("failed to detect initial sleep time: %v", err)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info("exit from wakeUpListen initial hash detection due to context cancellation")
|
||||
return
|
||||
case <-time.After(3 * time.Second):
|
||||
continue
|
||||
}
|
||||
}
|
||||
log.Debugf("initial wakeup hash: %d", initialHash)
|
||||
break
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info("context canceled, stopping wakeUpListen")
|
||||
return
|
||||
|
||||
case <-ticker.C:
|
||||
newHash, err := readSleepTimeHash()
|
||||
if err != nil {
|
||||
log.Errorf("failed to read sleep time hash: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if newHash == initialHash {
|
||||
log.Tracef("no wakeup detected")
|
||||
continue
|
||||
}
|
||||
|
||||
upOut, err := exec.Command("uptime").Output()
|
||||
if err != nil {
|
||||
log.Errorf("failed to run uptime command: %v", err)
|
||||
upOut = []byte("unknown")
|
||||
}
|
||||
log.Infof("Wakeup detected: %d -> %d, uptime: %s", initialHash, newHash, upOut)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func readSleepTimeHash() (uint32, error) {
|
||||
cmd := exec.Command("sysctl", "kern.sleeptime")
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to run sysctl: %w", err)
|
||||
}
|
||||
|
||||
h, err := hash(out)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to compute hash: %w", err)
|
||||
}
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
func hash(data []byte) (uint32, error) {
|
||||
hasher := fnv.New32a() // Create a new 32-bit FNV-1a hasher
|
||||
if _, err := hasher.Write(data); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return hasher.Sum32(), nil
|
||||
}
|
||||
@@ -88,6 +88,7 @@ func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
|
||||
event := make(chan struct{}, 1)
|
||||
go nw.checkChanges(ctx, event, nexthop4, nexthop6)
|
||||
|
||||
log.Infof("start watching for network changes")
|
||||
// debounce changes
|
||||
timer := time.NewTimer(0)
|
||||
timer.Stop()
|
||||
|
||||
4
go.mod
4
go.mod
@@ -76,7 +76,7 @@ require (
|
||||
github.com/pion/transport/v3 v3.0.7
|
||||
github.com/pion/turn/v3 v3.0.1
|
||||
github.com/prometheus/client_golang v1.22.0
|
||||
github.com/quic-go/quic-go v0.48.2
|
||||
github.com/quic-go/quic-go v0.49.1
|
||||
github.com/redis/go-redis/v9 v9.7.3
|
||||
github.com/rs/xid v1.3.0
|
||||
github.com/shirou/gopsutil/v3 v3.24.4
|
||||
@@ -241,7 +241,7 @@ require (
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk v1.35.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.35.0 // indirect
|
||||
go.uber.org/mock v0.4.0 // indirect
|
||||
go.uber.org/mock v0.5.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/image v0.18.0 // indirect
|
||||
golang.org/x/text v0.27.0 // indirect
|
||||
|
||||
8
go.sum
8
go.sum
@@ -590,8 +590,8 @@ github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ
|
||||
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
|
||||
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
||||
github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE=
|
||||
github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs=
|
||||
github.com/quic-go/quic-go v0.49.1 h1:e5JXpUyF0f2uFjckQzD8jTghZrOUK1xxDqqZhlwixo0=
|
||||
github.com/quic-go/quic-go v0.49.1/go.mod h1:s2wDnmCdooUQBmQfpUSTCYBl1/D4FcqbULMMkASvR6s=
|
||||
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM=
|
||||
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA=
|
||||
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
|
||||
@@ -749,8 +749,8 @@ go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v8
|
||||
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
|
||||
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
|
||||
go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
|
||||
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
|
||||
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
|
||||
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
|
||||
@@ -757,6 +757,7 @@ func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID st
|
||||
// If the user doesn't have an account, it creates one using the provided domain.
|
||||
// Returns the account ID or an error if none is found or created.
|
||||
func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) {
|
||||
defer util.TimeTrack(ctx, "GetAccountIDByUserID")()
|
||||
if userID == "" {
|
||||
return "", status.Errorf(status.NotFound, "no valid userID provided")
|
||||
}
|
||||
@@ -785,6 +786,8 @@ func isNil(i idp.Manager) bool {
|
||||
|
||||
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
||||
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
||||
defer util.TimeTrack(ctx, "addAccountIDToIDPAppMeta")()
|
||||
|
||||
if !isNil(am.idpManager) {
|
||||
// user can be nil if it wasn't found (e.g., just created)
|
||||
user, err := am.lookupUserInCache(ctx, userID, accountID)
|
||||
@@ -1043,6 +1046,8 @@ func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accoun
|
||||
func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, userAuth nbcontext.UserAuth,
|
||||
primaryDomain bool,
|
||||
) error {
|
||||
defer util.TimeTrack(ctx, "updateAccountDomainAttributesIfNotUpToDate")()
|
||||
|
||||
if userAuth.Domain == "" {
|
||||
log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", userAuth)
|
||||
return nil
|
||||
@@ -1091,6 +1096,8 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
|
||||
domainAccountID string,
|
||||
userAuth nbcontext.UserAuth,
|
||||
) error {
|
||||
defer util.TimeTrack(ctx, "handleExistingUserAccount")()
|
||||
|
||||
primaryDomain := domainAccountID == "" || userAccountID == domainAccountID
|
||||
err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, userAuth, primaryDomain)
|
||||
if err != nil {
|
||||
@@ -1109,6 +1116,8 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
|
||||
// addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
|
||||
// otherwise it will create a new account and make it primary account for the domain.
|
||||
func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) {
|
||||
defer util.TimeTrack(ctx, "addNewPrivateAccount")()
|
||||
|
||||
if userAuth.UserId == "" {
|
||||
return "", fmt.Errorf("user ID is empty")
|
||||
}
|
||||
@@ -1140,6 +1149,8 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) {
|
||||
defer util.TimeTrack(ctx, "addNewUserToDomainAccount")()
|
||||
|
||||
newUser := types.NewRegularUser(userAuth.UserId)
|
||||
newUser.AccountID = domainAccountID
|
||||
|
||||
@@ -1304,6 +1315,7 @@ func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, ac
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
defer util.TimeTrack(ctx, "GetAccountIDFromUserAuth")()
|
||||
if userAuth.UserId == "" {
|
||||
return "", "", errors.New(emptyUserID)
|
||||
}
|
||||
@@ -1348,6 +1360,7 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
|
||||
// and propagates changes to peers if group propagation is enabled.
|
||||
// requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager
|
||||
func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
defer util.TimeTrack(ctx, "SyncUserJWTGroups")()
|
||||
if userAuth.IsChild || userAuth.IsPAT {
|
||||
return nil
|
||||
}
|
||||
@@ -1506,6 +1519,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
||||
//
|
||||
// UserAuth IsChild -> checks that account exists
|
||||
func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) {
|
||||
defer util.TimeTrack(ctx, "getAccountIDWithAuthorizationClaims")()
|
||||
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
|
||||
userAuth.UserId, userAuth.AccountId, userAuth.Domain, userAuth.DomainCategory)
|
||||
|
||||
@@ -1559,6 +1573,8 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
|
||||
return am.addNewPrivateAccount(ctx, domainAccountID, userAuth)
|
||||
}
|
||||
func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) {
|
||||
defer util.TimeTrack(ctx, "getPrivateDomainWithGlobalLock")()
|
||||
|
||||
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
|
||||
if handleNotFound(err) != nil {
|
||||
|
||||
@@ -1585,6 +1601,8 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) {
|
||||
defer util.TimeTrack(ctx, "handlePrivateAccountWithIDFromClaim")()
|
||||
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
|
||||
|
||||
@@ -88,9 +88,13 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
||||
conns = 1
|
||||
}
|
||||
|
||||
sql.SetMaxOpenConns(conns)
|
||||
sql.SetMaxOpenConns(10)
|
||||
sql.SetMaxIdleConns(10)
|
||||
sql.SetConnMaxLifetime(time.Hour)
|
||||
sql.SetConnMaxIdleTime(3 * time.Minute)
|
||||
|
||||
log.WithContext(ctx).Infof("Set max open db connections to %d", conns)
|
||||
log.WithContext(ctx).Infof("Set max open db connections to %d, max idle to %d, max lifetime to %v, max idle time to %v",
|
||||
conns, conns, time.Hour, 3*time.Minute)
|
||||
|
||||
if skipMigration {
|
||||
log.WithContext(ctx).Infof("skipping migration")
|
||||
@@ -500,6 +504,8 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string)
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) {
|
||||
defer util.TimeTrack(ctx, "GetAccountIDByPrivateDomain")()
|
||||
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
@@ -575,6 +581,8 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
|
||||
defer util.TimeTrack(ctx, "GetUserByUserID")()
|
||||
|
||||
ctx, cancel := getDebuggingCtx(ctx)
|
||||
defer cancel()
|
||||
|
||||
@@ -940,6 +948,8 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error) {
|
||||
defer util.TimeTrack(ctx, "GetAccountIDByUserID")()
|
||||
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
@@ -1127,6 +1137,8 @@ func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength Locking
|
||||
|
||||
// SaveUserLastLogin stores the last login time for a user in DB.
|
||||
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
|
||||
defer util.TimeTrack(ctx, "SyncUserJWTGroups")()
|
||||
|
||||
ctx, cancel := getDebuggingCtx(ctx)
|
||||
defer cancel()
|
||||
|
||||
@@ -1784,6 +1796,7 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
|
||||
|
||||
// AccountExists checks whether an account exists by the given ID.
|
||||
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
|
||||
defer util.TimeTrack(ctx, "AccountExists")()
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
@@ -1804,6 +1817,8 @@ func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStreng
|
||||
|
||||
// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID.
|
||||
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) {
|
||||
defer util.TimeTrack(ctx, "GetAccountDomainAndCategory")()
|
||||
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
|
||||
@@ -178,6 +178,8 @@ func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*t
|
||||
// GetUser looks up a user by provided nbContext.UserAuths.
|
||||
// Expects account to have been created already.
|
||||
func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbContext.UserAuth) (*types.User, error) {
|
||||
defer util.TimeTrack(ctx, "SyncUserJWTGroups")()
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Difference returns the elements in `a` that aren't in `b`.
|
||||
func Difference(a, b []string) []string {
|
||||
mb := make(map[string]struct{}, len(b))
|
||||
@@ -50,3 +57,13 @@ func contains[T comparableObject[T]](slice []T, element T) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TimeTrack(ctx context.Context, name string) func() {
|
||||
start := time.Now()
|
||||
return func() {
|
||||
elapsed := time.Since(start)
|
||||
if elapsed > time.Second {
|
||||
log.WithContext(ctx).Infof("Slow Call: [%s] took %s", name, elapsed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user